1use crate::web_helpers::internal_error;
2use atrium_api::xrpc::http::StatusCode;
3use atrium_api::xrpc::http::request::Parts;
4use axum::extract::{FromRef, FromRequestParts};
5use bb8::{Pool, PooledConnection};
6use bb8_redis::RedisConnectionManager;
7use bb8_redis::redis::{AsyncCommands, FromRedisValue, RedisResult};
8use serde::{Deserialize, Serialize};
9use thiserror::Error;
10
/// Key prefix for stored atrium session entries.
pub const ATRIUM_SESSION_STORE_PREFIX: &str = "atrium_session:";
/// Key prefix for stored atrium state entries.
///
/// Despite the `_KEY` suffix, its value has the same `name:` prefix shape as
/// [`ATRIUM_SESSION_STORE_PREFIX`].
pub const ATRIUM_STATE_STORE_KEY: &str = "atrium_state:";

/// Builds a namespaced key by appending `key` directly after `prefix`.
///
/// No separator is inserted, so the prefix must carry its own (the constants above end in `:`).
pub fn create_prefixed_key(prefix: &str, key: &str) -> String {
    [prefix, key].concat()
}
17
18pub struct Cache<'a> {
19 redis_pool: PooledConnection<'a, RedisConnectionManager>,
20}
21
22#[derive(Debug, Error)]
23pub enum RedisFetchErrors {
24 FromDbError,
25 ParseError,
26 Other(String),
27}
28
29impl std::fmt::Display for RedisFetchErrors {
30 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
31 match self {
32 RedisFetchErrors::FromDbError => write!(f, "Error fetching from Redis database"),
33 RedisFetchErrors::ParseError => write!(f, "Error parsing Redis data"),
34 RedisFetchErrors::Other(msg) => write!(f, "Other error: {}", msg),
35 }
36 }
37}
38
39impl<'a> Cache<'a> {
40 pub fn new(redis_pool: PooledConnection<'a, RedisConnectionManager>) -> Self {
41 Self { redis_pool }
42 }
43
44 /// Writes a value to the cache at the given key
45 pub async fn write_to_cache<T: Serialize>(
46 &mut self,
47 redis_key: String,
48 data: T,
49 ) -> RedisResult<String> {
50 self.redis_pool
51 .set_ex(
52 redis_key.clone(),
53 serde_json::to_string(&data).unwrap(),
54 3600,
55 )
56 .await
57 }
58
59 /// Writes to the cache with an expiration in seconds. After x seconds it clears that key
60 pub async fn write_to_cache_with_seconds<T: Serialize>(
61 &mut self,
62 redis_key: &str,
63 data: T,
64 seconds: u64,
65 ) -> RedisResult<()> {
66 self.redis_pool
67 .set_ex(redis_key, serde_json::to_string(&data).unwrap(), seconds)
68 .await
69 }
70
71 /// Fetches a saved JSON object from the cache
72 pub async fn fetch_redis_json_object<T: for<'de> Deserialize<'de>>(
73 &mut self,
74 redis_key: &str,
75 ) -> Result<Option<T>, RedisFetchErrors> {
76 let val: RedisResult<Option<String>> = self.redis_pool.get(redis_key).await;
77
78 match val {
79 Ok(val) => match val {
80 None => Ok(None),
81 Some(val) => Ok(serde_json::from_str(&val).map_err(|err| {
82 log::error!("Error parsing redis data: {}", err);
83 RedisFetchErrors::ParseError
84 }))?,
85 },
86 Err(_) => Err(RedisFetchErrors::FromDbError),
87 }
88 }
89
90 /// Gets a value for a key
91 pub async fn fetch_redis<T: FromRedisValue>(
92 &mut self,
93 redis_key: &str,
94 ) -> Result<T, RedisFetchErrors> {
95 let val = self
96 .redis_pool
97 .get(redis_key)
98 .await
99 .map_err(|_| RedisFetchErrors::FromDbError)?;
100
101 Ok(val)
102 }
103
104 /// Gets or sets a value for a given key
105 pub async fn get_or_set<T, F, Fut>(
106 &mut self,
107 redis_key: &str,
108 seconds: u64,
109 fallback_fn: F,
110 ) -> Result<T, RedisFetchErrors>
111 where
112 T: for<'de> Deserialize<'de> + Serialize,
113 F: FnOnce() -> Fut,
114 Fut: std::future::Future<Output = Result<T, RedisFetchErrors>>,
115 {
116 // Try to get from cache first
117 match self.fetch_redis_json_object::<T>(redis_key).await {
118 Ok(Some(val)) => Ok(val),
119 Ok(None) => {
120 // If not in cache or error, execute the fallback function
121 let result = fallback_fn().await?;
122
123 // Write the result to cache
124 self.write_to_cache_with_seconds(redis_key, &result, seconds)
125 .await
126 .map_err(|err| {
127 log::error!("Error fetching from redis: {}", err);
128 RedisFetchErrors::FromDbError
129 })?;
130
131 Ok(result)
132 }
133 Err(err) => Err(err),
134 }
135 }
136}
137
138pub type ConnectionPool = Pool<RedisConnectionManager>;
139
140pub struct CacheConnection<'a>(pub Cache<'a>);
141
142impl<'a, S> FromRequestParts<S> for CacheConnection<'a>
143where
144 ConnectionPool: FromRef<S>,
145 S: Send + Sync,
146{
147 type Rejection = (StatusCode, String);
148
149 async fn from_request_parts(_parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
150 let pool = ConnectionPool::from_ref(state);
151
152 let conn = pool.get_owned().await.map_err(internal_error)?;
153 let cache = Cache { redis_pool: conn };
154 Ok(Self(cache))
155 }
156}