//! A bunch of syntax sugar to make strongly typed sessions on top of Axum's session store.
2use crate::error_response;
3use axum::extract::FromRequestParts;
4use axum::http::StatusCode;
5use axum::http::request::Parts;
6use axum::response::Response;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10use tower_sessions::Session;
11
12#[derive(Debug, Deserialize, Serialize, Clone)]
13pub enum FlashMessage {
14 Success(String),
15 Error(String),
16}
17
18/// THis is the actual session store for axum sessions
19#[derive(Debug, Deserialize, Serialize)]
20struct SessionData {
21 did: Option<String>,
22
23 flash_message: HashMap<String, FlashMessage>,
24}
25
26impl Default for SessionData {
27 fn default() -> Self {
28 Self {
29 did: None,
30 flash_message: HashMap::new(),
31 }
32 }
33}
34
/// Strongly typed view over a `tower_sessions` [`Session`].
///
/// Holds the raw session handle together with a deserialized copy of the stored
/// session data. Obtain one through its `FromRequestParts` extractor.
pub struct AxumSessionStore {
    // Underlying session handle that every change is written back to.
    session: Session,
    // In-memory copy of the data stored under the session data key.
    data: SessionData,
}
39
40/// How you actually interact with the session store
41impl AxumSessionStore {
42 const SESSION_DATA_KEY: &'static str = "session.data";
43
44 pub fn _logged_in(&self) -> bool {
45 self.data.did.is_some()
46 }
47
48 pub async fn set_did(&mut self, did: String) -> Result<(), tower_sessions::session::Error> {
49 self.data.did = Some(did);
50 Self::update_session(&self.session, &self.data).await
51 }
52
53 pub fn get_did(&self) -> Option<String> {
54 self.data.did.clone()
55 }
56
57 ///Gets the message as well as removes it from the session
58 pub async fn get_flash_message(
59 &mut self,
60 key: &str,
61 ) -> Result<Option<FlashMessage>, tower_sessions::session::Error> {
62 let message = self.data.flash_message.get(key).cloned();
63 if message.is_some() {
64 self.data.flash_message.remove(key);
65 Self::update_session(&self.session, &self.data).await?
66 }
67 Ok(message)
68 }
69
70 pub async fn set_flash_message(
71 &mut self,
72 key: &str,
73 message: FlashMessage,
74 ) -> Result<(), tower_sessions::session::Error> {
75 self.data.flash_message.insert(key.to_string(), message);
76 Self::update_session(&self.session, &self.data).await
77 }
78
79 /// Make sure to call this or your session won't actually be saved
80 async fn update_session(
81 session: &Session,
82 session_data: &SessionData,
83 ) -> Result<(), tower_sessions::session::Error> {
84 session.insert(Self::SESSION_DATA_KEY, session_data).await
85 }
86}
87
88impl fmt::Display for AxumSessionStore {
89 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
90 f.debug_struct("SessionStore")
91 .field("did", &self.data.did)
92 .finish()
93 }
94}
95
96impl<S> FromRequestParts<S> for AxumSessionStore
97where
98 S: Send + Sync,
99{
100 type Rejection = (StatusCode, &'static str);
101
102 async fn from_request_parts(req: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
103 let session = Session::from_request_parts(req, state).await?;
104
105 let data: SessionData = session
106 .get(Self::SESSION_DATA_KEY)
107 .await
108 .unwrap()
109 .unwrap_or_default();
110
111 Ok(Self { session, data })
112 }
113}
114
115/// Helper wrapper for handling http responses if theres an error
116pub async fn set_flash_message(
117 session: &mut AxumSessionStore,
118 key: &str,
119 flash_message: FlashMessage,
120) -> Result<(), Response> {
121 session
122 .set_flash_message(key, flash_message)
123 .await
124 .map_err(|err| {
125 log::error!("Error setting flash message: {err}");
126 error_response(
127 StatusCode::INTERNAL_SERVER_ERROR,
128 "Error setting flash message",
129 )
130 })
131}
132
133/// Helper wrapper for handling http responses if theres an error
134pub async fn get_flash_message(
135 session: &mut AxumSessionStore,
136 key: &str,
137) -> Result<Option<FlashMessage>, Response> {
138 match session.get_flash_message(key).await {
139 Ok(message) => Ok(message),
140 Err(err) => {
141 log::error!("Error getting flash message: {err}");
142 Err(error_response(
143 StatusCode::INTERNAL_SERVER_ERROR,
144 "Error getting flash message",
145 ))
146 }
147 }
148}