at main 6.7 kB view raw
1use crate::wake::WakeArray; 2use futures_compat::LocalWaker; 3use futures_core::FusedFuture; 4use futures_util::maybe_done::MaybeDone; 5use futures_util::maybe_done::maybe_done; 6use std::pin::Pin; 7use std::task::Poll; 8 9/// from [futures-concurrency](https://github.com/yoshuawuyts/futures-concurrency/tree/main) 10/// Wait for all futures to complete. 11/// 12/// Awaits multiple futures simultaneously, returning the output of the futures 13/// in the same container type they were created once all complete. 14pub trait Join { 15 /// The resulting output type. 16 type Output; 17 18 /// The [`ScopedFuture`] implementation returned by this method. 19 type Future: futures_core::Future<LocalWaker, Output = Self::Output>; 20 21 /// Waits for multiple futures to complete. 22 /// 23 /// Awaits multiple futures simultaneously, returning the output of the futures 24 /// in the same container type they we're created once all complete. 25 /// 26 /// This function returns a new future which polls all futures concurrently. 27 fn join(self) -> Self::Future; 28} 29 30pub trait JoinExt { 31 fn along_with<Fut>(self, other: Fut) -> Join2<Self, Fut> 32 where 33 Self: Sized + futures_core::Future<LocalWaker>, 34 Fut: futures_core::Future<LocalWaker>, 35 { 36 (self, other).join() 37 } 38} 39 40impl<T> JoinExt for T where T: futures_core::Future<LocalWaker> {} 41 42macro_rules! impl_join_tuple { 43 ($namespace:ident $StructName:ident $($F:ident)+) => { 44 mod $namespace { 45 #[repr(u8)] 46 pub(super) enum Indexes { $($F,)+ } 47 pub(super) const LEN: usize = [$(Indexes::$F,)+].len(); 48 } 49 50 #[allow(non_snake_case)] 51 #[must_use = "futures do nothing unless you `.await` or poll them"] 52 pub struct $StructName<$($F: futures_core::Future<LocalWaker>),+> { 53 $($F: MaybeDone<$F>,)* 54 wake_array: WakeArray<{$namespace::LEN}>, 55 } 56 57 impl<$($F: futures_core::Future<LocalWaker>),+> futures_core::Future<LocalWaker> for $StructName<$($F),+> 58 { 59 type Output = ($($F::Output),+); 60 61 #[allow(non_snake_case)] 62 fn poll(self: Pin<&mut Self>, waker: Pin<&LocalWaker>) -> Poll<Self::Output> { 63 let this = unsafe { self.get_unchecked_mut() }; 64 65 let wake_array = unsafe { Pin::new_unchecked(&this.wake_array) }; 66 $( 67 // TODO debug_assert_matches is nightly https://github.com/rust-lang/rust/issues/82775 68 debug_assert!(!matches!(this.$F, MaybeDone::Gone), "do not poll futures after they return Poll::Ready"); 69 let mut $F = unsafe { Pin::new_unchecked(&mut this.$F) }; 70 )+ 71 72 wake_array.register_parent(waker); 73 74 let mut ready = true; 75 76 $( 77 let index = $namespace::Indexes::$F as usize; 78 let waker = unsafe { wake_array.child_guard_ptr(index).unwrap_unchecked() }; 79 80 // ready if MaybeDone is Done or just completed (converted to Done) 81 // unsafe / against Future api contract to poll after Gone/Future is finished 82 ready &= if unsafe { dbg!(wake_array.take_woken(index).unwrap_unchecked()) } { 83 $F.as_mut().poll(waker).is_ready() 84 } else { 85 $F.is_terminated() 86 }; 87 )+ 88 89 if ready { 90 Poll::Ready(( 91 $( 92 // SAFETY: 93 // `ready == true` when all futures are complete. 94 // Once a future is not `MaybeDoneState::Future`, it transitions to `Done`, 95 // so we know the result of `take_output` must be `Some`. 96 unsafe { 97 $F.take_output().unwrap_unchecked() 98 }, 99 )* 100 )) 101 } else { 102 Poll::Pending 103 } 104 } 105 } 106 107 impl<$($F: futures_core::Future<LocalWaker>),+> Join for ($($F),+) { 108 type Output = ($($F::Output),*); 109 type Future = $StructName<$($F),+>; 110 111 #[allow(non_snake_case)] 112 fn join(self) -> Self::Future { 113 let ($($F),+) = self; 114 115 $StructName { 116 $($F: maybe_done($F),)* 117 wake_array: WakeArray::new(), 118 } 119 } 120 } 121 }; 122} 123 124impl_join_tuple!(join2 Join2 A B); 125impl_join_tuple!(join3 Join3 A B C); 126impl_join_tuple!(join4 Join4 A B C D); 127impl_join_tuple!(join5 Join5 A B C D E); 128impl_join_tuple!(join6 Join6 A B C D E F); 129impl_join_tuple!(join7 Join7 A B C D E F G); 130impl_join_tuple!(join8 Join8 A B C D E F G H); 131impl_join_tuple!(join9 Join9 A B C D E F G H I); 132impl_join_tuple!(join10 Join10 A B C D E F G H I J); 133impl_join_tuple!(join11 Join11 A B C D E F G H I J K); 134impl_join_tuple!(join12 Join12 A B C D E F G H I J K L); 135 136#[cfg(test)] 137mod tests { 138 #![no_std] 139 140 use futures_core::Future; 141 use futures_util::{dummy_guard, poll_fn}; 142 143 use crate::wake::local_wake; 144 145 use super::*; 146 147 use std::pin; 148 149 #[test] 150 fn counters() { 151 let mut x1 = 0; 152 let mut x2 = 0; 153 let f1 = poll_fn(|waker| { 154 local_wake(waker); 155 x1 += 1; 156 if x1 == 4 { 157 Poll::Ready(x1) 158 } else { 159 Poll::Pending 160 } 161 }); 162 let f2 = poll_fn(|waker| { 163 local_wake(waker); 164 x2 += 1; 165 if x2 == 5 { 166 Poll::Ready(x2) 167 } else { 168 Poll::Pending 169 } 170 }); 171 let guard = pin::pin!(dummy_guard()); 172 let mut join = pin::pin!((f1, f2).join()); 173 for _ in 0..4 { 174 assert_eq!(join.as_mut().poll(guard.as_ref()), Poll::Pending); 175 } 176 assert_eq!(join.poll(guard.as_ref()), Poll::Ready((4, 5))); 177 } 178 179 #[test] 180 fn never_wake() { 181 let f1 = poll_fn(|_| Poll::<i32>::Ready(0)); 182 let f2 = poll_fn(|_| Poll::<i32>::Pending); 183 let guard = pin::pin!(dummy_guard()); 184 let mut join = pin::pin!((f1, f2).join()); 185 for _ in 0..10 { 186 assert_eq!(join.as_mut().poll(guard.as_ref()), Poll::Pending); 187 } 188 } 189 190 #[test] 191 fn immediate() { 192 let f1 = poll_fn(|_| Poll::Ready(1)); 193 let f2 = poll_fn(|_| Poll::Ready(2)); 194 let join = pin::pin!(f1.along_with(f2)); 195 let guard = pin::pin!(dummy_guard()); 196 assert_eq!(join.poll(guard.as_ref()), Poll::Ready((1, 2))); 197 } 198}