/// A bunch of syntax sugar too make strongly typed sessions for Axum's sessions store use crate::error_response; use axum::extract::FromRequestParts; use axum::http::StatusCode; use axum::http::request::Parts; use axum::response::Response; use serde::{Deserialize, Serialize}; use std::collections::HashMap; use std::fmt; use tower_sessions::Session; #[derive(Debug, Deserialize, Serialize, Clone)] pub enum FlashMessage { Success(String), Error(String), } /// THis is the actual session store for axum sessions #[derive(Debug, Deserialize, Serialize)] struct SessionData { did: Option, flash_message: HashMap, } impl Default for SessionData { fn default() -> Self { Self { did: None, flash_message: HashMap::new(), } } } pub struct AxumSessionStore { session: Session, data: SessionData, } /// How you actually interact with the session store impl AxumSessionStore { const SESSION_DATA_KEY: &'static str = "session.data"; pub fn _logged_in(&self) -> bool { self.data.did.is_some() } pub async fn set_did(&mut self, did: String) -> Result<(), tower_sessions::session::Error> { self.data.did = Some(did); Self::update_session(&self.session, &self.data).await } pub fn get_did(&self) -> Option { self.data.did.clone() } ///Gets the message as well as removes it from the session pub async fn get_flash_message( &mut self, key: &str, ) -> Result, tower_sessions::session::Error> { let message = self.data.flash_message.get(key).cloned(); if message.is_some() { self.data.flash_message.remove(key); Self::update_session(&self.session, &self.data).await? } Ok(message) } pub async fn set_flash_message( &mut self, key: &str, message: FlashMessage, ) -> Result<(), tower_sessions::session::Error> { self.data.flash_message.insert(key.to_string(), message); Self::update_session(&self.session, &self.data).await } /// Make sure to call this or your session won't actually be saved async fn update_session( session: &Session, session_data: &SessionData, ) -> Result<(), tower_sessions::session::Error> { session.insert(Self::SESSION_DATA_KEY, session_data).await } } impl fmt::Display for AxumSessionStore { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { f.debug_struct("SessionStore") .field("did", &self.data.did) .finish() } } impl FromRequestParts for AxumSessionStore where S: Send + Sync, { type Rejection = (StatusCode, &'static str); async fn from_request_parts(req: &mut Parts, state: &S) -> Result { let session = Session::from_request_parts(req, state).await?; let data: SessionData = session .get(Self::SESSION_DATA_KEY) .await .unwrap() .unwrap_or_default(); Ok(Self { session, data }) } } /// Helper wrapper for handling http responses if theres an error pub async fn set_flash_message( session: &mut AxumSessionStore, key: &str, flash_message: FlashMessage, ) -> Result<(), Response> { session .set_flash_message(key, flash_message) .await .map_err(|err| { log::error!("Error setting flash message: {err}"); error_response( StatusCode::INTERNAL_SERVER_ERROR, "Error setting flash message", ) }) } /// Helper wrapper for handling http responses if theres an error pub async fn get_flash_message( session: &mut AxumSessionStore, key: &str, ) -> Result, Response> { match session.get_flash_message(key).await { Ok(message) => Ok(message), Err(err) => { log::error!("Error getting flash message: {err}"); Err(error_response( StatusCode::INTERNAL_SERVER_ERROR, "Error getting flash message", )) } } }