this repo has no description
at main 4.8 kB view raw
1use crate::web_helpers::internal_error; 2use atrium_api::xrpc::http::StatusCode; 3use atrium_api::xrpc::http::request::Parts; 4use axum::extract::{FromRef, FromRequestParts}; 5use bb8::{Pool, PooledConnection}; 6use bb8_redis::RedisConnectionManager; 7use bb8_redis::redis::{AsyncCommands, FromRedisValue, RedisResult}; 8use serde::{Deserialize, Serialize}; 9use thiserror::Error; 10 11pub const ATRIUM_SESSION_STORE_PREFIX: &str = "atrium_session:"; 12pub const ATRIUM_STATE_STORE_KEY: &str = "atrium_state:"; 13 14pub fn create_prefixed_key(prefix: &str, key: &str) -> String { 15 format!("{}{}", prefix, key) 16} 17 18pub struct Cache<'a> { 19 redis_pool: PooledConnection<'a, RedisConnectionManager>, 20} 21 22#[derive(Debug, Error)] 23pub enum RedisFetchErrors { 24 FromDbError, 25 ParseError, 26 Other(String), 27} 28 29impl std::fmt::Display for RedisFetchErrors { 30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 31 match self { 32 RedisFetchErrors::FromDbError => write!(f, "Error fetching from Redis database"), 33 RedisFetchErrors::ParseError => write!(f, "Error parsing Redis data"), 34 RedisFetchErrors::Other(msg) => write!(f, "Other error: {}", msg), 35 } 36 } 37} 38 39impl<'a> Cache<'a> { 40 pub fn new(redis_pool: PooledConnection<'a, RedisConnectionManager>) -> Self { 41 Self { redis_pool } 42 } 43 44 /// Writes a value to the cache at the given key 45 pub async fn write_to_cache<T: Serialize>( 46 &mut self, 47 redis_key: String, 48 data: T, 49 ) -> RedisResult<String> { 50 self.redis_pool 51 .set_ex( 52 redis_key.clone(), 53 serde_json::to_string(&data).unwrap(), 54 3600, 55 ) 56 .await 57 } 58 59 /// Writes to the cache with an expiration in seconds. After x seconds it clears that key 60 pub async fn write_to_cache_with_seconds<T: Serialize>( 61 &mut self, 62 redis_key: &str, 63 data: T, 64 seconds: u64, 65 ) -> RedisResult<()> { 66 self.redis_pool 67 .set_ex(redis_key, serde_json::to_string(&data).unwrap(), seconds) 68 .await 69 } 70 71 /// Fetches a saved JSON object from the cache 72 pub async fn fetch_redis_json_object<T: for<'de> Deserialize<'de>>( 73 &mut self, 74 redis_key: &str, 75 ) -> Result<Option<T>, RedisFetchErrors> { 76 let val: RedisResult<Option<String>> = self.redis_pool.get(redis_key).await; 77 78 match val { 79 Ok(val) => match val { 80 None => Ok(None), 81 Some(val) => Ok(serde_json::from_str(&val).map_err(|err| { 82 log::error!("Error parsing redis data: {}", err); 83 RedisFetchErrors::ParseError 84 }))?, 85 }, 86 Err(_) => Err(RedisFetchErrors::FromDbError), 87 } 88 } 89 90 /// Gets a value for a key 91 pub async fn fetch_redis<T: FromRedisValue>( 92 &mut self, 93 redis_key: &str, 94 ) -> Result<T, RedisFetchErrors> { 95 let val = self 96 .redis_pool 97 .get(redis_key) 98 .await 99 .map_err(|_| RedisFetchErrors::FromDbError)?; 100 101 Ok(val) 102 } 103 104 /// Gets or sets a value for a given key 105 pub async fn get_or_set<T, F, Fut>( 106 &mut self, 107 redis_key: &str, 108 seconds: u64, 109 fallback_fn: F, 110 ) -> Result<T, RedisFetchErrors> 111 where 112 T: for<'de> Deserialize<'de> + Serialize, 113 F: FnOnce() -> Fut, 114 Fut: std::future::Future<Output = Result<T, RedisFetchErrors>>, 115 { 116 // Try to get from cache first 117 match self.fetch_redis_json_object::<T>(redis_key).await { 118 Ok(Some(val)) => Ok(val), 119 Ok(None) => { 120 // If not in cache or error, execute the fallback function 121 let result = fallback_fn().await?; 122 123 // Write the result to cache 124 self.write_to_cache_with_seconds(redis_key, &result, seconds) 125 .await 126 .map_err(|err| { 127 log::error!("Error fetching from redis: {}", err); 128 RedisFetchErrors::FromDbError 129 })?; 130 131 Ok(result) 132 } 133 Err(err) => Err(err), 134 } 135 } 136} 137 138pub type ConnectionPool = Pool<RedisConnectionManager>; 139 140pub struct CacheConnection<'a>(pub Cache<'a>); 141 142impl<'a, S> FromRequestParts<S> for CacheConnection<'a> 143where 144 ConnectionPool: FromRef<S>, 145 S: Send + Sync, 146{ 147 type Rejection = (StatusCode, String); 148 149 async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { 150 let pool = ConnectionPool::from_ref(state); 151 152 let conn = pool.get_owned().await.map_err(internal_error)?; 153 let cache = Cache { redis_pool: conn }; 154 Ok(Self(cache)) 155 } 156}