this repo has no description
at fmt 4.9 kB view raw
1use crate::web_helpers::internal_error; 2use atrium_api::xrpc::http::StatusCode; 3use atrium_api::xrpc::http::request::Parts; 4use axum::extract::{FromRef, FromRequestParts}; 5use bb8::{Pool, PooledConnection}; 6use bb8_redis::RedisConnectionManager; 7use bb8_redis::redis::{AsyncCommands, FromRedisValue, RedisResult}; 8use serde::{Deserialize, Serialize}; 9use thiserror::Error; 10 11pub const ATRIUM_SESSION_STORE_PREFIX: &str = "atrium_session:"; 12pub const ATRIUM_STATE_STORE_KEY: &str = "atrium_state:"; 13 14pub const TOWER_SESSION_KEY: &str = "tower_session:"; 15 16pub fn create_prefixed_key(prefix: &str, key: &str) -> String { 17 format!("{}{}", prefix, key) 18} 19 20pub struct Cache<'a> { 21 pub redis_pool: PooledConnection<'a, RedisConnectionManager>, 22} 23 24#[derive(Debug, Error)] 25pub enum RedisFetchErrors { 26 FromDbError, 27 ParseError, 28 Other(String), 29} 30 31impl std::fmt::Display for RedisFetchErrors { 32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 33 match self { 34 RedisFetchErrors::FromDbError => write!(f, "Error fetching from Redis database"), 35 RedisFetchErrors::ParseError => write!(f, "Error parsing Redis data"), 36 RedisFetchErrors::Other(msg) => write!(f, "Other error: {}", msg), 37 } 38 } 39} 40 41impl<'a> Cache<'a> { 42 pub fn new(redis_pool: PooledConnection<'a, RedisConnectionManager>) -> Self { 43 Self { redis_pool } 44 } 45 46 /// Writes a value to the cache at the given key 47 pub async fn write_to_cache<T: Serialize>( 48 &mut self, 49 redis_key: String, 50 data: T, 51 ) -> RedisResult<String> { 52 self.redis_pool 53 .set_ex( 54 redis_key.clone(), 55 serde_json::to_string(&data).unwrap(), 56 3600, 57 ) 58 .await 59 } 60 61 /// Writes to the cache with an expiration in seconds. After x seconds it clears that key 62 pub async fn write_to_cache_with_seconds<T: Serialize>( 63 &mut self, 64 redis_key: &str, 65 data: T, 66 seconds: u64, 67 ) -> RedisResult<()> { 68 self.redis_pool 69 .set_ex(redis_key, serde_json::to_string(&data).unwrap(), seconds) 70 .await 71 } 72 73 /// Fetches a saved JSON object from the cache 74 pub async fn fetch_redis_json_object<T: for<'de> Deserialize<'de>>( 75 &mut self, 76 redis_key: &str, 77 ) -> Result<Option<T>, RedisFetchErrors> { 78 let val: RedisResult<Option<String>> = self.redis_pool.get(redis_key).await; 79 80 match val { 81 Ok(val) => match val { 82 None => Ok(None), 83 Some(val) => Ok(serde_json::from_str(&val).map_err(|err| { 84 log::error!("Error parsing redis data: {}", err); 85 RedisFetchErrors::ParseError 86 }))?, 87 }, 88 Err(_) => Err(RedisFetchErrors::FromDbError), 89 } 90 } 91 92 /// Gets a value for a key 93 pub async fn fetch_redis<T: FromRedisValue>( 94 &mut self, 95 redis_key: &str, 96 ) -> Result<T, RedisFetchErrors> { 97 let val = self 98 .redis_pool 99 .get(redis_key) 100 .await 101 .map_err(|_| RedisFetchErrors::FromDbError)?; 102 103 Ok(val) 104 } 105 106 /// Gets or sets a value for a given key 107 pub async fn get_or_set<T, F, Fut>( 108 &mut self, 109 redis_key: &str, 110 seconds: u64, 111 fallback_fn: F, 112 ) -> Result<T, RedisFetchErrors> 113 where 114 T: for<'de> Deserialize<'de> + Serialize, 115 F: FnOnce() -> Fut, 116 Fut: std::future::Future<Output = Result<T, RedisFetchErrors>>, 117 { 118 // Try to get from cache first 119 match self.fetch_redis_json_object::<T>(redis_key).await { 120 Ok(Some(val)) => Ok(val), 121 Ok(None) => { 122 // If not in cache or error, execute the fallback function 123 let result = fallback_fn().await?; 124 125 // Write the result to cache 126 self.write_to_cache_with_seconds(redis_key, &result, seconds) 127 .await 128 .map_err(|err| { 129 log::error!("Error fetching from redis: {}", err); 130 RedisFetchErrors::FromDbError 131 })?; 132 133 Ok(result) 134 } 135 Err(err) => Err(err), 136 } 137 } 138} 139 140pub type ConnectionPool = Pool<RedisConnectionManager>; 141 142pub struct CacheConnection<'a>(pub Cache<'a>); 143 144impl<'a, S> FromRequestParts<S> for CacheConnection<'a> 145where 146 ConnectionPool: FromRef<S>, 147 S: Send + Sync, 148{ 149 type Rejection = (StatusCode, String); 150 151 async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> { 152 let pool = ConnectionPool::from_ref(state); 153 154 let conn = pool.get_owned().await.map_err(internal_error)?; 155 let cache = Cache { redis_pool: conn }; 156 Ok(Self(cache)) 157 } 158}