use crate::web_helpers::internal_error;
use atrium_api::xrpc::http::StatusCode;
use atrium_api::xrpc::http::request::Parts;
use axum::extract::{FromRef, FromRequestParts};
use bb8::{Pool, PooledConnection};
use bb8_redis::RedisConnectionManager;
use bb8_redis::redis::{AsyncCommands, ErrorKind, FromRedisValue, RedisError, RedisResult};
use serde::{Deserialize, Serialize};
use thiserror::Error;
10
/// Key prefix prepended (e.g. via `create_prefixed_key`) to atrium session store entries.
/// NOTE(review): usage inferred from the name — confirm against the session store impl.
pub const ATRIUM_SESSION_STORE_PREFIX: &str = "atrium_session:";
/// Key prefix for atrium state store entries. Despite the `_KEY` suffix, the value ends in
/// `:` like the other prefixes, so it presumably gets a suffix appended — TODO confirm.
pub const ATRIUM_STATE_STORE_KEY: &str = "atrium_state:";

/// Key prefix for tower session entries (presumably used by a tower-sessions store — verify).
pub const TOWER_SESSION_KEY: &str = "tower_session:";
15
/// Builds a namespaced Redis key by joining `prefix` and `key` with no separator.
///
/// Any delimiter (such as a trailing `:`) must already be part of `prefix`.
pub fn create_prefixed_key(prefix: &str, key: &str) -> String {
    [prefix, key].concat()
}
19
/// Wrapper around one Redis connection checked out of a bb8 pool, exposing
/// JSON read/write helpers.
pub struct Cache<'a> {
    /// The checked-out connection. Despite the field name, this is a single
    /// `PooledConnection`, not the pool itself.
    pub redis_pool: PooledConnection<'a, RedisConnectionManager>,
}
23
/// Errors returned by the `Cache` read helpers (`fetch_redis`, `fetch_redis_json_object`,
/// `get_or_set`).
///
/// No `#[error(...)]` attributes are given, so `Display` is implemented by hand.
#[derive(Debug, Error)]
pub enum RedisFetchErrors {
    /// A Redis command failed. The underlying redis error is not carried along.
    FromDbError,
    /// The stored value could not be deserialized from JSON into the requested type.
    ParseError,
    /// Any other failure, described by the message. Not constructed in this module;
    /// presumably produced by callers, e.g. a `get_or_set` fallback — verify.
    Other(String),
}
30
31impl std::fmt::Display for RedisFetchErrors {
32 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
33 match self {
34 RedisFetchErrors::FromDbError => write!(f, "Error fetching from Redis database"),
35 RedisFetchErrors::ParseError => write!(f, "Error parsing Redis data"),
36 RedisFetchErrors::Other(msg) => write!(f, "Other error: {}", msg),
37 }
38 }
39}
40
41impl<'a> Cache<'a> {
42 pub fn new(redis_pool: PooledConnection<'a, RedisConnectionManager>) -> Self {
43 Self { redis_pool }
44 }
45
46 /// Writes a value to the cache at the given key
47 pub async fn write_to_cache<T: Serialize>(
48 &mut self,
49 redis_key: String,
50 data: T,
51 ) -> RedisResult<String> {
52 self.redis_pool
53 .set_ex(
54 redis_key.clone(),
55 serde_json::to_string(&data).unwrap(),
56 3600,
57 )
58 .await
59 }
60
61 /// Writes to the cache with an expiration in seconds. After x seconds it clears that key
62 pub async fn write_to_cache_with_seconds<T: Serialize>(
63 &mut self,
64 redis_key: &str,
65 data: T,
66 seconds: u64,
67 ) -> RedisResult<()> {
68 self.redis_pool
69 .set_ex(redis_key, serde_json::to_string(&data).unwrap(), seconds)
70 .await
71 }
72
73 /// Fetches a saved JSON object from the cache
74 pub async fn fetch_redis_json_object<T: for<'de> Deserialize<'de>>(
75 &mut self,
76 redis_key: &str,
77 ) -> Result<Option<T>, RedisFetchErrors> {
78 let val: RedisResult<Option<String>> = self.redis_pool.get(redis_key).await;
79
80 match val {
81 Ok(val) => match val {
82 None => Ok(None),
83 Some(val) => Ok(serde_json::from_str(&val).map_err(|err| {
84 log::error!("Error parsing redis data: {}", err);
85 RedisFetchErrors::ParseError
86 }))?,
87 },
88 Err(_) => Err(RedisFetchErrors::FromDbError),
89 }
90 }
91
92 /// Gets a value for a key
93 pub async fn fetch_redis<T: FromRedisValue>(
94 &mut self,
95 redis_key: &str,
96 ) -> Result<T, RedisFetchErrors> {
97 let val = self
98 .redis_pool
99 .get(redis_key)
100 .await
101 .map_err(|_| RedisFetchErrors::FromDbError)?;
102
103 Ok(val)
104 }
105
106 /// Gets or sets a value for a given key
107 pub async fn get_or_set<T, F, Fut>(
108 &mut self,
109 redis_key: &str,
110 seconds: u64,
111 fallback_fn: F,
112 ) -> Result<T, RedisFetchErrors>
113 where
114 T: for<'de> Deserialize<'de> + Serialize,
115 F: FnOnce() -> Fut,
116 Fut: std::future::Future<Output = Result<T, RedisFetchErrors>>,
117 {
118 // Try to get from cache first
119 match self.fetch_redis_json_object::<T>(redis_key).await {
120 Ok(Some(val)) => Ok(val),
121 Ok(None) => {
122 // If not in cache or error, execute the fallback function
123 let result = fallback_fn().await?;
124
125 // Write the result to cache
126 self.write_to_cache_with_seconds(redis_key, &result, seconds)
127 .await
128 .map_err(|err| {
129 log::error!("Error fetching from redis: {}", err);
130 RedisFetchErrors::FromDbError
131 })?;
132
133 Ok(result)
134 }
135 Err(err) => Err(err),
136 }
137 }
138}
139
/// bb8 pool of Redis connections. Axum handlers obtain it from application state via `FromRef`.
pub type ConnectionPool = Pool<RedisConnectionManager>;

/// Axum extractor that checks a connection out of the `ConnectionPool` in state and
/// wraps it in a `Cache`.
pub struct CacheConnection<'a>(pub Cache<'a>);
143
144impl<'a, S> FromRequestParts<S> for CacheConnection<'a>
145where
146 ConnectionPool: FromRef<S>,
147 S: Send + Sync,
148{
149 type Rejection = (StatusCode, String);
150
151 async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
152 let pool = ConnectionPool::from_ref(state);
153
154 let conn = pool.get_owned().await.map_err(internal_error)?;
155 let cache = Cache { redis_pool: conn };
156 Ok(Self(cache))
157 }
158}