use crate::web_helpers::internal_error; use atrium_api::xrpc::http::StatusCode; use atrium_api::xrpc::http::request::Parts; use axum::extract::{FromRef, FromRequestParts}; use bb8::{Pool, PooledConnection}; use bb8_redis::RedisConnectionManager; use bb8_redis::redis::{AsyncCommands, FromRedisValue, RedisResult}; use serde::{Deserialize, Serialize}; use thiserror::Error; pub const ATRIUM_SESSION_STORE_PREFIX: &str = "atrium_session:"; pub const ATRIUM_STATE_STORE_KEY: &str = "atrium_state:"; pub fn create_prefixed_key(prefix: &str, key: &str) -> String { format!("{}{}", prefix, key) } pub struct Cache<'a> { redis_pool: PooledConnection<'a, RedisConnectionManager>, } #[derive(Debug, Error)] pub enum RedisFetchErrors { FromDbError, ParseError, Other(String), } impl std::fmt::Display for RedisFetchErrors { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { RedisFetchErrors::FromDbError => write!(f, "Error fetching from Redis database"), RedisFetchErrors::ParseError => write!(f, "Error parsing Redis data"), RedisFetchErrors::Other(msg) => write!(f, "Other error: {}", msg), } } } impl<'a> Cache<'a> { pub fn new(redis_pool: PooledConnection<'a, RedisConnectionManager>) -> Self { Self { redis_pool } } /// Writes a value to the cache at the given key pub async fn write_to_cache( &mut self, redis_key: String, data: T, ) -> RedisResult { self.redis_pool .set_ex( redis_key.clone(), serde_json::to_string(&data).unwrap(), 3600, ) .await } /// Writes to the cache with an expiration in seconds. After x seconds it clears that key pub async fn write_to_cache_with_seconds( &mut self, redis_key: &str, data: T, seconds: u64, ) -> RedisResult<()> { self.redis_pool .set_ex(redis_key, serde_json::to_string(&data).unwrap(), seconds) .await } /// Fetches a saved JSON object from the cache pub async fn fetch_redis_json_object Deserialize<'de>>( &mut self, redis_key: &str, ) -> Result, RedisFetchErrors> { let val: RedisResult> = self.redis_pool.get(redis_key).await; match val { Ok(val) => match val { None => Ok(None), Some(val) => Ok(serde_json::from_str(&val).map_err(|err| { log::error!("Error parsing redis data: {}", err); RedisFetchErrors::ParseError }))?, }, Err(_) => Err(RedisFetchErrors::FromDbError), } } /// Gets a value for a key pub async fn fetch_redis( &mut self, redis_key: &str, ) -> Result { let val = self .redis_pool .get(redis_key) .await .map_err(|_| RedisFetchErrors::FromDbError)?; Ok(val) } /// Gets or sets a value for a given key pub async fn get_or_set( &mut self, redis_key: &str, seconds: u64, fallback_fn: F, ) -> Result where T: for<'de> Deserialize<'de> + Serialize, F: FnOnce() -> Fut, Fut: std::future::Future>, { // Try to get from cache first match self.fetch_redis_json_object::(redis_key).await { Ok(Some(val)) => Ok(val), Ok(None) => { // If not in cache or error, execute the fallback function let result = fallback_fn().await?; // Write the result to cache self.write_to_cache_with_seconds(redis_key, &result, seconds) .await .map_err(|err| { log::error!("Error fetching from redis: {}", err); RedisFetchErrors::FromDbError })?; Ok(result) } Err(err) => Err(err), } } } pub type ConnectionPool = Pool; pub struct CacheConnection<'a>(pub Cache<'a>); impl<'a, S> FromRequestParts for CacheConnection<'a> where ConnectionPool: FromRef, S: Send + Sync, { type Rejection = (StatusCode, String); async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result { let pool = ConnectionPool::from_ref(state); let conn = pool.get_owned().await.map_err(internal_error)?; let cache = Cache { redis_pool: conn }; Ok(Self(cache)) } }