Microservice that brings 2FA to self-hosted PDSes.
1use dashmap::DashMap;
2use std::sync::Arc;
3use std::time::{Duration, Instant};
4
/// A single cached DID-to-handle resolution, stamped with the moment it was stored.
#[derive(Debug, Clone)]
struct CachedHandle {
    /// The resolved handle for the DID this entry is keyed under.
    handle: String,
    /// Insertion time; the owning cache compares its age against the TTL.
    cached_at: Instant,
}
10
11/// A thread-safe cache for DID-to-handle resolutions with TTL expiration.
12#[derive(Clone)]
13pub struct HandleCache {
14 cache: Arc<DashMap<String, CachedHandle>>,
15 ttl: Duration,
16}
17
18impl Default for HandleCache {
19 fn default() -> Self {
20 Self::new()
21 }
22}
23
24impl HandleCache {
25 /// Creates a new HandleCache with a default TTL of 1 hour.
26 pub fn new() -> Self {
27 Self::with_ttl(Duration::from_secs(3600))
28 }
29
30 /// Creates a new HandleCache with a custom TTL.
31 pub fn with_ttl(ttl: Duration) -> Self {
32 Self {
33 cache: Arc::new(DashMap::new()),
34 ttl,
35 }
36 }
37
38 /// Gets a cached handle for the given DID, if it exists and hasn't expired.
39 pub fn get(&self, did: &str) -> Option<String> {
40 let entry = self.cache.get(did)?;
41 if entry.cached_at.elapsed() > self.ttl {
42 drop(entry);
43 self.cache.remove(did);
44 return None;
45 }
46 Some(entry.handle.clone())
47 }
48
49 /// Inserts a DID-to-handle mapping into the cache.
50 pub fn insert(&self, did: String, handle: String) {
51 self.cache.insert(
52 did,
53 CachedHandle {
54 handle,
55 cached_at: Instant::now(),
56 },
57 );
58 }
59
60 /// Removes expired entries from the cache.
61 /// Call this periodically to prevent memory growth.
62 pub fn cleanup(&self) {
63 self.cache.retain(|_, v| v.cached_at.elapsed() <= self.ttl);
64 }
65
66 /// Returns the number of entries in the cache.
67 pub fn len(&self) -> usize {
68 self.cache.len()
69 }
70
71 /// Returns true if the cache is empty.
72 pub fn is_empty(&self) -> bool {
73 self.cache.is_empty()
74 }
75}
76
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cache_insert_and_get() {
        let cache = HandleCache::new();
        cache.insert("did:plc:test".into(), "test.handle.com".into());
        assert_eq!(cache.get("did:plc:test"), Some("test.handle.com".into()));
    }

    #[test]
    fn test_cache_miss() {
        let cache = HandleCache::new();
        assert_eq!(cache.get("did:plc:nonexistent"), None);
    }

    #[test]
    fn test_cache_expiration() {
        let cache = HandleCache::with_ttl(Duration::from_millis(1));
        cache.insert("did:plc:test".into(), "test.handle.com".into());
        std::thread::sleep(Duration::from_millis(10));
        assert_eq!(cache.get("did:plc:test"), None);
    }

    /// An expired lookup should evict the stale entry, not merely hide it.
    #[test]
    fn test_expired_get_evicts_entry() {
        let cache = HandleCache::with_ttl(Duration::from_millis(1));
        cache.insert("did:plc:test".into(), "test.handle.com".into());
        assert_eq!(cache.len(), 1);
        std::thread::sleep(Duration::from_millis(10));
        assert_eq!(cache.get("did:plc:test"), None);
        assert!(cache.is_empty());
    }

    /// Re-inserting a DID replaces its handle instead of adding a second entry.
    #[test]
    fn test_insert_overwrites_existing_entry() {
        let cache = HandleCache::new();
        cache.insert("did:plc:test".into(), "old.handle.com".into());
        cache.insert("did:plc:test".into(), "new.handle.com".into());
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get("did:plc:test"), Some("new.handle.com".into()));
    }

    /// `cleanup` must drop expired entries while keeping fresh ones.
    #[test]
    fn test_cleanup_removes_only_expired_entries() {
        let cache = HandleCache::with_ttl(Duration::from_millis(200));
        cache.insert("did:plc:old".into(), "old.handle.com".into());
        std::thread::sleep(Duration::from_millis(250));
        cache.insert("did:plc:new".into(), "new.handle.com".into());
        cache.cleanup();
        assert_eq!(cache.len(), 1);
        assert_eq!(cache.get("did:plc:new"), Some("new.handle.com".into()));
    }
}