use crate::runtime::{Arc, Mutex}; use crate::{yield_fn, BatchFn, WaitForWorkFn}; use std::collections::{HashMap, HashSet}; use std::hash::Hash; use std::iter::IntoIterator; pub trait AsyncCache { type Key; type Val; async fn get(&mut self, key: &Self::Key) -> Option; async fn insert(&mut self, key: Self::Key, val: Self::Val); async fn remove(&mut self, key: &Self::Key) -> Option; async fn clear(&mut self); } struct State where C: AsyncCache, { completed: C, pending: HashSet, } impl State where C: AsyncCache, { fn with_cache(cache: C) -> Self { State { completed: cache, pending: HashSet::new(), } } } #[derive(Clone)] pub struct Loader where K: Eq + Hash + Clone, V: Clone, F: BatchFn, C: AsyncCache, { state: Arc>>, load_fn: Arc>, wait_for_work_fn: Arc, max_batch_size: usize, } impl Loader where K: Eq + Hash + Clone, V: Clone, F: BatchFn, C: AsyncCache, { pub fn new(load_fn: F, cache: C) -> Self { Loader { state: Arc::new(Mutex::new(State::with_cache(cache))), load_fn: Arc::new(Mutex::new(load_fn)), max_batch_size: 200, wait_for_work_fn: Arc::new(yield_fn(10)), } } pub fn with_max_batch_size(mut self, max_batch_size: usize) -> Self { self.max_batch_size = max_batch_size; self } pub fn with_yield_count(mut self, yield_count: usize) -> Self { self.wait_for_work_fn = Arc::new(yield_fn(yield_count)); self } /// Replaces the yielding for work behavior with an arbitrary future. Rather than yielding /// the runtime repeatedly this will generate and `.await` a future of your choice. /// ***This is incompatible with*** [`Self::with_yield_count()`]. pub fn with_custom_wait_for_work(mut self, wait_for_work_fn: impl WaitForWorkFn) -> Self { self.wait_for_work_fn = Arc::new(wait_for_work_fn); self } pub fn max_batch_size(&self) -> usize { self.max_batch_size } pub async fn load(&self, key: K) -> Option { let mut state = self.state.lock().await; if let Some(v) = state.completed.get(&key).await { return Some(v.clone()); } if !state.pending.contains(&key) { state.pending.insert(key.clone()); if state.pending.len() >= self.max_batch_size { let keys = state.pending.drain().collect::>(); let mut load_fn = self.load_fn.lock().await; let load_ret = load_fn.load(keys.as_ref()).await; drop(load_fn); for (k, v) in load_ret.into_iter() { state.completed.insert(k, v).await; } return state.completed.get(&key).await.clone(); } } drop(state); (self.wait_for_work_fn)().await; let mut state = self.state.lock().await; if let Some(v) = state.completed.get(&key).await { return Some(v.clone()); } if !state.pending.is_empty() { let keys = state.pending.drain().collect::>(); let mut load_fn = self.load_fn.lock().await; let load_ret = load_fn.load(keys.as_ref()).await; drop(load_fn); for (k, v) in load_ret.into_iter() { state.completed.insert(k, v).await; } } state.completed.get(&key).await.clone() } pub async fn load_many(&self, keys: Vec) -> HashMap { let mut state = self.state.lock().await; let mut ret = HashMap::new(); let mut rest = Vec::new(); for key in keys.into_iter() { if let Some(v) = state.completed.get(&key).await.clone() { ret.insert(key, v); continue; } if !state.pending.contains(&key) { state.pending.insert(key.clone()); if state.pending.len() >= self.max_batch_size { let keys = state.pending.drain().collect::>(); let mut load_fn = self.load_fn.lock().await; let load_ret = load_fn.load(keys.as_ref()).await; drop(load_fn); for (k, v) in load_ret.into_iter() { state.completed.insert(k, v).await; } } } rest.push(key); } drop(state); (self.wait_for_work_fn)().await; if !rest.is_empty() { let mut state = self.state.lock().await; if !state.pending.is_empty() { let keys = state.pending.drain().collect::>(); let mut load_fn = self.load_fn.lock().await; let load_ret = load_fn.load(keys.as_ref()).await; drop(load_fn); for (k, v) in load_ret.into_iter() { state.completed.insert(k, v).await; } } for key in rest.into_iter() { if let Some(v) = state.completed.get(&key).await.clone() { ret.insert(key, v); } } } ret } pub async fn prime(&self, key: K, val: V) { let mut state = self.state.lock().await; state.completed.insert(key, val).await; } pub async fn prime_many(&self, values: impl IntoIterator) { let mut state = self.state.lock().await; for (k, v) in values.into_iter() { state.completed.insert(k, v).await; } } pub async fn clear(&self, key: K) { let mut state = self.state.lock().await; state.completed.remove(&key).await; } pub async fn clear_all(&self) { let mut state = self.state.lock().await; state.completed.clear().await } }