wip
1use crate::wake::WakeArray;
2use futures_compat::LocalWaker;
3use futures_core::FusedFuture;
4use futures_util::maybe_done::MaybeDone;
5use futures_util::maybe_done::maybe_done;
6use std::pin::Pin;
7use std::task::Poll;
8
9/// from [futures-concurrency](https://github.com/yoshuawuyts/futures-concurrency/tree/main)
10/// Wait for all futures to complete.
11///
12/// Awaits multiple futures simultaneously, returning the output of the futures
13/// in the same container type they were created once all complete.
14pub trait Join {
15 /// The resulting output type.
16 type Output;
17
18 /// The [`ScopedFuture`] implementation returned by this method.
19 type Future: futures_core::Future<LocalWaker, Output = Self::Output>;
20
21 /// Waits for multiple futures to complete.
22 ///
23 /// Awaits multiple futures simultaneously, returning the output of the futures
24 /// in the same container type they we're created once all complete.
25 ///
26 /// This function returns a new future which polls all futures concurrently.
27 fn join(self) -> Self::Future;
28}
29
30pub trait JoinExt {
31 fn along_with<Fut>(self, other: Fut) -> Join2<Self, Fut>
32 where
33 Self: Sized + futures_core::Future<LocalWaker>,
34 Fut: futures_core::Future<LocalWaker>,
35 {
36 (self, other).join()
37 }
38}
39
40impl<T> JoinExt for T where T: futures_core::Future<LocalWaker> {}
41
42macro_rules! impl_join_tuple {
43 ($namespace:ident $StructName:ident $($F:ident)+) => {
44 mod $namespace {
45 #[repr(u8)]
46 pub(super) enum Indexes { $($F,)+ }
47 pub(super) const LEN: usize = [$(Indexes::$F,)+].len();
48 }
49
50 #[allow(non_snake_case)]
51 #[must_use = "futures do nothing unless you `.await` or poll them"]
52 pub struct $StructName<$($F: futures_core::Future<LocalWaker>),+> {
53 $($F: MaybeDone<$F>,)*
54 wake_array: WakeArray<{$namespace::LEN}>,
55 }
56
57 impl<$($F: futures_core::Future<LocalWaker>),+> futures_core::Future<LocalWaker> for $StructName<$($F),+>
58 {
59 type Output = ($($F::Output),+);
60
61 #[allow(non_snake_case)]
62 fn poll(self: Pin<&mut Self>, waker: Pin<&LocalWaker>) -> Poll<Self::Output> {
63 let this = unsafe { self.get_unchecked_mut() };
64
65 let wake_array = unsafe { Pin::new_unchecked(&this.wake_array) };
66 $(
67 // TODO debug_assert_matches is nightly https://github.com/rust-lang/rust/issues/82775
68 debug_assert!(!matches!(this.$F, MaybeDone::Gone), "do not poll futures after they return Poll::Ready");
69 let mut $F = unsafe { Pin::new_unchecked(&mut this.$F) };
70 )+
71
72 wake_array.register_parent(waker);
73
74 let mut ready = true;
75
76 $(
77 let index = $namespace::Indexes::$F as usize;
78 let waker = unsafe { wake_array.child_guard_ptr(index).unwrap_unchecked() };
79
80 // ready if MaybeDone is Done or just completed (converted to Done)
81 // unsafe / against Future api contract to poll after Gone/Future is finished
82 ready &= if unsafe { dbg!(wake_array.take_woken(index).unwrap_unchecked()) } {
83 $F.as_mut().poll(waker).is_ready()
84 } else {
85 $F.is_terminated()
86 };
87 )+
88
89 if ready {
90 Poll::Ready((
91 $(
92 // SAFETY:
93 // `ready == true` when all futures are complete.
94 // Once a future is not `MaybeDoneState::Future`, it transitions to `Done`,
95 // so we know the result of `take_output` must be `Some`.
96 unsafe {
97 $F.take_output().unwrap_unchecked()
98 },
99 )*
100 ))
101 } else {
102 Poll::Pending
103 }
104 }
105 }
106
107 impl<$($F: futures_core::Future<LocalWaker>),+> Join for ($($F),+) {
108 type Output = ($($F::Output),*);
109 type Future = $StructName<$($F),+>;
110
111 #[allow(non_snake_case)]
112 fn join(self) -> Self::Future {
113 let ($($F),+) = self;
114
115 $StructName {
116 $($F: maybe_done($F),)*
117 wake_array: WakeArray::new(),
118 }
119 }
120 }
121 };
122}
123
124impl_join_tuple!(join2 Join2 A B);
125impl_join_tuple!(join3 Join3 A B C);
126impl_join_tuple!(join4 Join4 A B C D);
127impl_join_tuple!(join5 Join5 A B C D E);
128impl_join_tuple!(join6 Join6 A B C D E F);
129impl_join_tuple!(join7 Join7 A B C D E F G);
130impl_join_tuple!(join8 Join8 A B C D E F G H);
131impl_join_tuple!(join9 Join9 A B C D E F G H I);
132impl_join_tuple!(join10 Join10 A B C D E F G H I J);
133impl_join_tuple!(join11 Join11 A B C D E F G H I J K);
134impl_join_tuple!(join12 Join12 A B C D E F G H I J K L);
135
#[cfg(test)]
mod tests {
    // No `#![no_std]` here: it is a crate-level attribute and has no effect in a
    // submodule (this module also uses `std::pin`).

    use futures_core::Future;
    use futures_util::{dummy_guard, poll_fn};

    use crate::wake::local_wake;

    use super::*;

    use std::pin;

    /// Both children call `local_wake` on every poll. `f1` finishes on its 4th
    /// poll and `f2` on its 5th, so the join stays pending for four polls and
    /// yields both outputs on the fifth.
    #[test]
    fn counters() {
        let mut x1 = 0;
        let mut x2 = 0;
        let f1 = poll_fn(|waker| {
            local_wake(waker);
            x1 += 1;
            if x1 == 4 {
                Poll::Ready(x1)
            } else {
                Poll::Pending
            }
        });
        let f2 = poll_fn(|waker| {
            local_wake(waker);
            x2 += 1;
            if x2 == 5 {
                Poll::Ready(x2)
            } else {
                Poll::Pending
            }
        });
        let guard = pin::pin!(dummy_guard());
        let mut join = pin::pin!((f1, f2).join());
        for _ in 0..4 {
            assert_eq!(join.as_mut().poll(guard.as_ref()), Poll::Pending);
        }
        assert_eq!(join.poll(guard.as_ref()), Poll::Ready((4, 5)));
    }

    /// A child that stays pending and never wakes keeps the whole join pending,
    /// even though its sibling completed immediately.
    #[test]
    fn never_wake() {
        let f1 = poll_fn(|_| Poll::<i32>::Ready(0));
        let f2 = poll_fn(|_| Poll::<i32>::Pending);
        let guard = pin::pin!(dummy_guard());
        let mut join = pin::pin!((f1, f2).join());
        for _ in 0..10 {
            assert_eq!(join.as_mut().poll(guard.as_ref()), Poll::Pending);
        }
    }

    /// Children that are ready on their first poll make the join ready on its
    /// first poll.
    #[test]
    fn immediate() {
        let f1 = poll_fn(|_| Poll::Ready(1));
        let f2 = poll_fn(|_| Poll::Ready(2));
        let join = pin::pin!(f1.along_with(f2));
        let guard = pin::pin!(dummy_guard());
        assert_eq!(join.poll(guard.as_ref()), Poll::Ready((1, 2)));
    }

    /// Polling a join again after it returned `Poll::Ready` is a contract
    /// violation and must panic rather than read a taken output.
    #[test]
    #[should_panic(expected = "do not poll futures after they return Poll::Ready")]
    fn poll_after_ready_panics() {
        let f1 = poll_fn(|_| Poll::Ready(1));
        let f2 = poll_fn(|_| Poll::Ready(2));
        let guard = pin::pin!(dummy_guard());
        let mut join = pin::pin!(f1.along_with(f2));
        assert_eq!(join.as_mut().poll(guard.as_ref()), Poll::Ready((1, 2)));
        let _ = join.as_mut().poll(guard.as_ref());
    }
}