use async_trait::async_trait; use bb8::Pool; use bb8_redis::{RedisConnectionManager, redis::cmd}; use shared::cache::{Cache, TOWER_SESSION_KEY, create_prefixed_key}; use std::fmt::Display; use std::fmt::{Debug, Formatter}; use tower_sessions::SessionStore; use tower_sessions::session::{Id, Record}; use tower_sessions::session_store::{Error, Result as StoreResult}; #[derive(Clone)] pub struct RedisSessionStore { cache_pool: Pool, } impl RedisSessionStore { pub fn new(cache_pool: Pool) -> Self { Self { cache_pool } } } impl Debug for RedisSessionStore { fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { f.debug_struct("RedisSessionStore").finish() } } // Small helper to convert any error into tower-sessions Backend error consistently fn backend_map(context: &'static str) -> impl FnOnce(E) -> Error { move |err| { log::error!("{}: {}", context, err); Error::Backend(err.to_string()) } } #[async_trait] impl SessionStore for RedisSessionStore { async fn create(&self, session_record: &mut Record) -> StoreResult<()> { //TODO i don't think there is an issue with overwriting the session here since it's redis and should be no collision //The default create throws a warning about this so, added this to get rid of it and adding a note in case it does cause a problem self.save(session_record).await } async fn save(&self, session_record: &Record) -> StoreResult<()> { let id_as_str: String = session_record.id.0.to_string(); let key = create_prefixed_key(TOWER_SESSION_KEY, id_as_str.as_str()); // Get a redis connection let conn = self .cache_pool .get() .await .map_err(backend_map("There was an error connecting to the cache"))?; // Set value with TTL based on expiry_date let expiry = session_record.expiry_date; let now = std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) .unwrap_or_default() .as_secs() as i64; let ttl_secs = expiry.unix_timestamp().saturating_sub(now).max(0) as usize; //Helper for some cache functions let mut cache = Cache { redis_pool: conn }; cache .write_to_cache_with_seconds(&key, &session_record, ttl_secs as u64) .await .map_err(backend_map("There was an error saving the session"))?; Ok(()) } async fn load(&self, session_id: &Id) -> StoreResult> { let id_as_str: String = session_id.0.to_string(); let key = create_prefixed_key(TOWER_SESSION_KEY, id_as_str.as_str()); let conn = self .cache_pool .get() .await .map_err(backend_map("There was an error connecting to the cache"))?; let mut cache = Cache { redis_pool: conn }; let val = match cache.fetch_redis_json_object::>(&key).await { Ok(Some(record)) => Ok(record), Ok(None) => Ok(None), Err(err) => Err(err), } .map_err(backend_map("There was an error loading the session"))?; Ok(val) } async fn delete(&self, session_id: &Id) -> StoreResult<()> { let id_as_str: String = session_id.0.to_string(); let key = create_prefixed_key(TOWER_SESSION_KEY, id_as_str.as_str()); let mut conn = self .cache_pool .get() .await .map_err(backend_map("There was an error connecting to the cache"))?; let _: usize = cmd("DEL") .arg(&key) .query_async::(&mut *conn) .await .map_err(backend_map("There was an error deleting the session"))?; Ok(()) } }