// (GitHub raw-view header residue: branch `main`, 200 lines, 6.5 kB)
1use futures_util::LocalWaker; 2 3use crate::wake::WakeArray; 4use std::pin::Pin; 5use std::task::Poll; 6 7/// from [futures-concurrency](https://github.com/yoshuawuyts/futures-concurrency/tree/main) 8/// Wait for the first future to complete. 9/// 10/// Awaits multiple future at once, returning as soon as one completes. The 11/// other futures are cancelled. 12pub trait Race { 13 /// The resulting output type. 14 type Output; 15 16 /// The [`ScopedFuture`] implementation returned by this method. 17 type Future: futures_core::Future<LocalWaker, Output = Self::Output>; 18 19 /// Wait for the first future to complete. 20 /// 21 /// Awaits multiple futures at once, returning as soon as one completes. The 22 /// other futures are cancelled. 23 /// 24 /// This function returns a new future which polls all futures concurrently. 25 fn race(self) -> Self::Future; 26} 27 28pub trait RaceExt<'scope> { 29 fn race_with<Fut>(self, other: Fut) -> Race2<Self, Fut> 30 where 31 Self: Sized + futures_core::Future<LocalWaker>, 32 Fut: futures_core::Future<LocalWaker>, 33 { 34 (self, other).race() 35 } 36} 37 38impl<'scope, T> RaceExt<'scope> for T where T: futures_core::Future<LocalWaker> {} 39 40macro_rules! 
impl_race_tuple { 41 ($namespace:ident $StructName:ident $OutputsName:ident $($F:ident)+) => { 42 mod $namespace { 43 #[repr(u8)] 44 pub(super) enum Indexes { $($F,)+ } 45 pub(super) const LEN: usize = [$(Indexes::$F,)+].len(); 46 } 47 48 pub enum $OutputsName<$($F,)+> { 49 $($F($F),)+ 50 } 51 52 impl<$($F: std::fmt::Debug,)+> std::fmt::Debug for $OutputsName<$($F,)+> { 53 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 54 match self {$( 55 Self::$F(x) => 56 f.debug_tuple(std::stringify!($F)) 57 .field(x) 58 .finish(), 59 )+} 60 } 61 } 62 63 impl<$($F: PartialEq,)+> PartialEq for $OutputsName<$($F,)+> { 64 fn eq(&self, other: &Self) -> bool { 65 match (self, other) { 66 $((Self::$F(a), Self::$F(b)) => a == b,)+ 67 _ => false 68 } 69 } 70 } 71 72 #[allow(non_snake_case)] 73 #[must_use = "futures do nothing unless you `.await` or poll them"] 74 pub struct $StructName<$($F: futures_core::Future<LocalWaker>),+> { 75 $($F: $F,)* 76 wake_array: WakeArray<{$namespace::LEN}>, 77 } 78 79 impl<'scope, $($F: futures_core::Future<LocalWaker>),+> futures_core::Future<LocalWaker> 80 for $StructName<$($F),+> 81 { 82 type Output = $OutputsName<$($F::Output,)+>; 83 84 #[allow(non_snake_case)] 85 fn poll(self: Pin<&mut Self>, waker: Pin<&LocalWaker>) -> Poll<Self::Output> { 86 let this = unsafe { self.get_unchecked_mut() }; 87 88 let wake_array = unsafe { Pin::new_unchecked(&this.wake_array) }; 89 $( 90 let mut $F = unsafe { Pin::new_unchecked(&mut this.$F) }; 91 )+ 92 93 wake_array.register_parent(waker); 94 95 $( 96 let index = $namespace::Indexes::$F as usize; 97 let waker = unsafe { wake_array.child_guard_ptr(index).unwrap_unchecked() }; 98 99 // this is safe because we know index < LEN 100 if unsafe { wake_array.take_woken(index).unwrap_unchecked() } { 101 if let Poll::Ready(res) = $F.as_mut().poll(waker) { 102 return Poll::Ready($OutputsName::$F(res)); 103 } 104 } 105 )+ 106 107 Poll::Pending 108 } 109 } 110 111 impl<'scope, $($F: 
futures_core::Future<LocalWaker>),+> Race for ($($F),+) { 112 type Output = $OutputsName<$($F::Output),*>; 113 type Future = $StructName<$($F),+>; 114 115 #[allow(non_snake_case)] 116 fn race(self) -> Self::Future { 117 let ($($F),+) = self; 118 119 $StructName { 120 $($F: $F,)* 121 wake_array: WakeArray::new(), 122 } 123 } 124 } 125 }; 126} 127 128impl_race_tuple!(race2 Race2 RaceOutputs2 A B); 129impl_race_tuple!(race3 Race3 RaceOutputs3 A B C); 130impl_race_tuple!(race4 Race4 RaceOutputs4 A B C D); 131impl_race_tuple!(race5 Race5 RaceOutputs5 A B C D E); 132impl_race_tuple!(race6 Race6 RaceOutputs6 A B C D E F); 133impl_race_tuple!(race7 Race7 RaceOutputs7 A B C D E F G); 134impl_race_tuple!(race8 Race8 RaceOutputs8 A B C D E F G H); 135impl_race_tuple!(race9 Race9 RaceOutputs9 A B C D E F G H I); 136impl_race_tuple!(race10 Race10 RaceOutputs10 A B C D E F G H I J); 137impl_race_tuple!(race11 Race11 RaceOutputs11 A B C D E F G H I J K); 138impl_race_tuple!(race12 Race12 RaceOutputs12 A B C D E F G H I J K L); 139 140#[cfg(test)] 141mod tests { 142 #![no_std] 143 144 use std::pin; 145 146 use futures_core::Future; 147 use futures_util::{dummy_guard, poll_fn}; 148 149 use crate::wake::local_wake; 150 151 use super::*; 152 153 #[test] 154 fn counters() { 155 let mut x1 = 0; 156 let mut x2 = 0; 157 let f1 = poll_fn(|waker| { 158 local_wake(waker); 159 x1 += 1; 160 if x1 == 4 { 161 Poll::Ready(x1) 162 } else { 163 Poll::Pending 164 } 165 }); 166 let f2 = poll_fn(|waker| { 167 local_wake(waker); 168 x2 += 1; 169 if x2 == 2 { 170 Poll::Ready(x2) 171 } else { 172 Poll::Pending 173 } 174 }); 175 let guard = pin::pin!(dummy_guard()); 176 let mut race = pin::pin!((f1, f2).race()); 177 assert_eq!(race.as_mut().poll(guard.as_ref()), Poll::Pending); 178 assert_eq!(race.poll(guard.as_ref()), Poll::Ready(RaceOutputs2::B(2))); 179 } 180 181 #[test] 182 fn never_wake() { 183 let f1 = poll_fn(|_| Poll::<i32>::Pending); 184 let f2 = poll_fn(|_| Poll::<i32>::Pending); 185 let mut 
race = pin::pin!((f1, f2).race()); 186 let guard = pin::pin!(dummy_guard()); 187 for _ in 0..10 { 188 assert_eq!(race.as_mut().poll(guard.as_ref()), Poll::Pending); 189 } 190 } 191 192 #[test] 193 fn basic() { 194 let f1 = poll_fn(|_| Poll::Ready(1)); 195 let f2 = poll_fn(|_| Poll::Ready(2)); 196 let race = pin::pin!(f1.race_with(f2)); 197 let guard = pin::pin!(dummy_guard()); 198 assert_eq!(race.poll(guard.as_ref()), Poll::Ready(RaceOutputs2::A(1))); 199 } 200}