// WIP: work in progress.
1use futures_util::LocalWaker;
2
3use crate::wake::WakeArray;
4use std::pin::Pin;
5use std::task::Poll;
6
7/// from [futures-concurrency](https://github.com/yoshuawuyts/futures-concurrency/tree/main)
8/// Wait for the first future to complete.
9///
10/// Awaits multiple future at once, returning as soon as one completes. The
11/// other futures are cancelled.
12pub trait Race {
13 /// The resulting output type.
14 type Output;
15
16 /// The [`ScopedFuture`] implementation returned by this method.
17 type Future: futures_core::Future<LocalWaker, Output = Self::Output>;
18
19 /// Wait for the first future to complete.
20 ///
21 /// Awaits multiple futures at once, returning as soon as one completes. The
22 /// other futures are cancelled.
23 ///
24 /// This function returns a new future which polls all futures concurrently.
25 fn race(self) -> Self::Future;
26}
27
28pub trait RaceExt<'scope> {
29 fn race_with<Fut>(self, other: Fut) -> Race2<Self, Fut>
30 where
31 Self: Sized + futures_core::Future<LocalWaker>,
32 Fut: futures_core::Future<LocalWaker>,
33 {
34 (self, other).race()
35 }
36}
37
38impl<'scope, T> RaceExt<'scope> for T where T: futures_core::Future<LocalWaker> {}
39
// Generates, for one tuple arity, everything needed to race that many futures:
//
// * `mod $namespace`: a `#[repr(u8)]` enum whose discriminants are each
//   future's slot index in the `WakeArray`, plus `LEN` (the arity);
// * `$OutputsName`: an enum with one variant per future carrying its output;
// * `$StructName`: the future returned by `Race::race`;
// * the `Race` impl for the tuple `($($F),+)`.
//
// The type-parameter identifiers double as enum variant names, struct field
// names and local variable names, so they must be unique in one invocation.
macro_rules! impl_race_tuple {
    ($namespace:ident $StructName:ident $OutputsName:ident $($F:ident)+) => {
        mod $namespace {
            /// Slot index of each future inside the wake array.
            #[repr(u8)]
            pub(super) enum Indexes { $($F,)+ }

            /// Number of futures raced at this arity.
            pub(super) const LEN: usize = [$(Indexes::$F,)+].len();
        }

        /// Output of a race: the variant names the position of the winning
        /// future and carries its output.
        #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
        pub enum $OutputsName<$($F,)+> {
            $($F($F),)+
        }

        /// Future returned by [`Race::race`] for this tuple arity.
        // Fields are named after the type parameters, hence non-snake-case.
        #[allow(non_snake_case)]
        #[must_use = "futures do nothing unless you `.await` or poll them"]
        pub struct $StructName<$($F: futures_core::Future<LocalWaker>),+> {
            $($F: $F,)+
            /// Set once a child has produced the race's output; guards
            /// against polling (possibly completed) children again.
            done: bool,
            wake_array: WakeArray<{$namespace::LEN}>,
        }

        impl<$($F: futures_core::Future<LocalWaker>),+> futures_core::Future<LocalWaker>
            for $StructName<$($F),+>
        {
            type Output = $OutputsName<$($F::Output,)+>;

            // Locals are named after the type parameters, hence non-snake-case.
            #[allow(non_snake_case)]
            fn poll(self: Pin<&mut Self>, waker: Pin<&LocalWaker>) -> Poll<Self::Output> {
                // SAFETY: nothing below moves out of `this`. Every field that
                // needs pinning is re-pinned in place (structural pinning), and
                // this type has neither a `Drop` impl nor a manual `Unpin` impl.
                let this = unsafe { self.get_unchecked_mut() };

                assert!(!this.done, "Futures must not be polled after completing");

                // SAFETY: `wake_array` lives inside the pinned struct and is
                // never moved out of it.
                let wake_array = unsafe { Pin::new_unchecked(&this.wake_array) };
                $(
                    // SAFETY: each child future is structurally pinned (see above).
                    let mut $F = unsafe { Pin::new_unchecked(&mut this.$F) };
                )+

                // Hand the caller's waker to the wake array as the parent of
                // the per-child wakers.
                wake_array.register_parent(waker);

                // Children are visited in tuple order, so on a tie the earliest
                // ready child wins.
                $(
                    let index = $namespace::Indexes::$F as usize;
                    // SAFETY: `index` is a discriminant of `Indexes`, which has
                    // exactly `LEN` variants, so `index < LEN`. This assumes these
                    // accessors only return `None` for an out-of-range index.
                    let child_waker =
                        unsafe { wake_array.child_guard_ptr(index).unwrap_unchecked() };

                    // Only poll children whose wake flag was set; `take_woken`
                    // presumably also clears it (TODO confirm in `crate::wake`).
                    // SAFETY: same `index < LEN` bound as above.
                    if unsafe { wake_array.take_woken(index).unwrap_unchecked() } {
                        if let Poll::Ready(res) = $F.as_mut().poll(child_waker) {
                            this.done = true;
                            return Poll::Ready($OutputsName::$F(res));
                        }
                    }
                )+

                Poll::Pending
            }
        }

        impl<$($F: futures_core::Future<LocalWaker>),+> Race for ($($F),+) {
            type Output = $OutputsName<$($F::Output),+>;
            type Future = $StructName<$($F),+>;

            // Locals are named after the type parameters, hence non-snake-case.
            #[allow(non_snake_case)]
            fn race(self) -> Self::Future {
                let ($($F),+) = self;

                $StructName {
                    $($F: $F,)+
                    done: false,
                    wake_array: WakeArray::new(),
                }
            }
        }
    };
}
127
128impl_race_tuple!(race2 Race2 RaceOutputs2 A B);
129impl_race_tuple!(race3 Race3 RaceOutputs3 A B C);
130impl_race_tuple!(race4 Race4 RaceOutputs4 A B C D);
131impl_race_tuple!(race5 Race5 RaceOutputs5 A B C D E);
132impl_race_tuple!(race6 Race6 RaceOutputs6 A B C D E F);
133impl_race_tuple!(race7 Race7 RaceOutputs7 A B C D E F G);
134impl_race_tuple!(race8 Race8 RaceOutputs8 A B C D E F G H);
135impl_race_tuple!(race9 Race9 RaceOutputs9 A B C D E F G H I);
136impl_race_tuple!(race10 Race10 RaceOutputs10 A B C D E F G H I J);
137impl_race_tuple!(race11 Race11 RaceOutputs11 A B C D E F G H I J K);
138impl_race_tuple!(race12 Race12 RaceOutputs12 A B C D E F G H I J K L);
139
#[cfg(test)]
mod tests {
    use std::pin;

    use futures_core::Future;
    use futures_util::{dummy_guard, poll_fn};

    use crate::wake::local_wake;

    use super::*;

    /// Both children re-wake themselves on every poll. The second child reaches
    /// its target count first, so it wins on the second poll.
    #[test]
    fn counters() {
        let mut x1 = 0;
        let mut x2 = 0;
        let f1 = poll_fn(|waker| {
            local_wake(waker);
            x1 += 1;
            if x1 == 4 {
                Poll::Ready(x1)
            } else {
                Poll::Pending
            }
        });
        let f2 = poll_fn(|waker| {
            local_wake(waker);
            x2 += 1;
            if x2 == 2 {
                Poll::Ready(x2)
            } else {
                Poll::Pending
            }
        });
        let guard = pin::pin!(dummy_guard());
        let mut race = pin::pin!((f1, f2).race());
        assert_eq!(race.as_mut().poll(guard.as_ref()), Poll::Ready(RaceOutputs2::B(2)).map(|_| unreachable!()).or_pending());
        assert_eq!(race.poll(guard.as_ref()), Poll::Ready(RaceOutputs2::B(2)));
    }
}
200}