use crate::runtime::{Arc, Mutex}; use crate::{yield_fn, BatchFn, WaitForWorkFn}; use std::collections::{HashMap, HashSet}; use std::fmt::Debug; use std::hash::{BuildHasher, Hash}; use std::iter::IntoIterator; pub trait Cache { type Key; type Val; fn get(&mut self, key: &Self::Key) -> Option<&Self::Val>; fn insert(&mut self, key: Self::Key, val: Self::Val); fn remove(&mut self, key: &Self::Key) -> Option; fn clear(&mut self); } impl Cache for HashMap where K: Eq + Hash, { type Key = K; type Val = V; #[inline] fn get(&mut self, key: &K) -> Option<&V> { HashMap::get(self, key) } #[inline] fn insert(&mut self, key: K, val: V) { HashMap::insert(self, key, val); } #[inline] fn remove(&mut self, key: &K) -> Option { HashMap::remove(self, key) } #[inline] fn clear(&mut self) { HashMap::clear(self) } } struct State> where C: Cache, { completed: C, pending: HashSet, } impl State where C: Cache, { fn with_cache(cache: C) -> Self { State { completed: cache, pending: HashSet::new(), } } } pub struct Loader> where K: Eq + Hash + Clone, V: Clone, F: BatchFn, C: Cache, { state: Arc>>, load_fn: Arc>, wait_for_work_fn: Arc, max_batch_size: usize, } impl Clone for Loader where K: Eq + Hash + Clone, V: Clone, F: BatchFn, C: Cache, { fn clone(&self) -> Self { Loader { state: self.state.clone(), max_batch_size: self.max_batch_size, load_fn: self.load_fn.clone(), wait_for_work_fn: self.wait_for_work_fn.clone(), } } } #[allow(clippy::implicit_hasher)] impl Loader> where K: Eq + Hash + Clone + Debug, V: Clone, F: BatchFn, { pub fn new(load_fn: F) -> Loader> { Loader::with_cache(load_fn, HashMap::new()) } } impl Loader where K: Eq + Hash + Clone + Debug, V: Clone, F: BatchFn, C: Cache, { pub fn with_cache(load_fn: F, cache: C) -> Loader { Loader { state: Arc::new(Mutex::new(State::with_cache(cache))), load_fn: Arc::new(Mutex::new(load_fn)), max_batch_size: 200, wait_for_work_fn: Arc::new(yield_fn(10)), } } pub fn with_max_batch_size(mut self, max_batch_size: usize) -> Self { self.max_batch_size = max_batch_size; self } pub fn with_yield_count(mut self, yield_count: usize) -> Self { self.wait_for_work_fn = Arc::new(yield_fn(yield_count)); self } /// Replaces the yielding for work behavior with an arbitrary future. Rather than yielding /// the runtime repeatedly this will generate and `.await` a future of your choice. /// ***This is incompatible with*** [`Self::with_yield_count()`]. pub fn with_custom_wait_for_work(mut self, wait_for_work_fn: impl WaitForWorkFn) -> Self { self.wait_for_work_fn = Arc::new(wait_for_work_fn); self } pub fn max_batch_size(&self) -> usize { self.max_batch_size } pub async fn load(&self, key: K) -> Option { let mut state = self.state.lock().await; if let Some(v) = state.completed.get(&key) { return Some((*v).clone()); } if !state.pending.contains(&key) { state.pending.insert(key.clone()); if state.pending.len() >= self.max_batch_size { let keys = state.pending.drain().collect::>(); let mut load_fn = self.load_fn.lock().await; let load_ret = load_fn.load(keys.as_ref()).await; drop(load_fn); for (k, v) in load_ret.into_iter() { state.completed.insert(k, v); } return state.completed.get(&key).cloned(); } } drop(state); (self.wait_for_work_fn)().await; let mut state = self.state.lock().await; if let Some(v) = state.completed.get(&key) { return Some((*v).clone()); } if !state.pending.is_empty() { let keys = state.pending.drain().collect::>(); let mut load_fn = self.load_fn.lock().await; let load_ret = load_fn.load(keys.as_ref()).await; drop(load_fn); for (k, v) in load_ret.into_iter() { state.completed.insert(k, v); } } state.completed.get(&key).cloned() } pub async fn load_many(&self, keys: Vec) -> HashMap { let mut state = self.state.lock().await; let mut ret = HashMap::new(); let mut rest = Vec::new(); for key in keys.into_iter() { if let Some(v) = state.completed.get(&key).cloned() { ret.insert(key, v); continue; } if !state.pending.contains(&key) { state.pending.insert(key.clone()); if state.pending.len() >= self.max_batch_size { let keys = state.pending.drain().collect::>(); let mut load_fn = self.load_fn.lock().await; let load_ret = load_fn.load(keys.as_ref()).await; drop(load_fn); for (k, v) in load_ret.into_iter() { state.completed.insert(k, v); } } } rest.push(key); } drop(state); (self.wait_for_work_fn)().await; if !rest.is_empty() { let mut state = self.state.lock().await; if !state.pending.is_empty() { let keys = state.pending.drain().collect::>(); let mut load_fn = self.load_fn.lock().await; let load_ret = load_fn.load(keys.as_ref()).await; drop(load_fn); for (k, v) in load_ret.into_iter() { state.completed.insert(k, v); } } for key in rest.into_iter() { if let Some(v) = state.completed.get(&key).cloned() { ret.insert(key, v); } } } ret } pub async fn prime(&self, key: K, val: V) { let mut state = self.state.lock().await; state.completed.insert(key, val); } pub async fn prime_many(&self, values: impl IntoIterator) { let mut state = self.state.lock().await; for (k, v) in values.into_iter() { state.completed.insert(k, v); } } pub async fn clear(&self, key: K) { let mut state = self.state.lock().await; state.completed.remove(&key); } pub async fn clear_all(&self) { let mut state = self.state.lock().await; state.completed.clear() } }