forked from
oppi.li/at-advent
this repo has no description
//! Syntax sugar for working with strongly typed sessions on top of Axum's session store.
2use crate::error_response;
3use axum::extract::FromRequestParts;
4use axum::http::StatusCode;
5use axum::http::request::Parts;
6use axum::response::Response;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fmt;
10use tower_sessions::Session;
11
12#[derive(Debug, Deserialize, Serialize, Clone)]
13pub enum FlashMessage {
14 Success(String),
15 Error(String),
16}
17
18/// THis is the actual session store for axum sessions
19#[derive(Debug, Deserialize, Serialize)]
20struct SessionData {
21 did: Option<String>,
22
23 flash_message: HashMap<String, FlashMessage>,
24}
25
26impl Default for SessionData {
27 fn default() -> Self {
28 Self {
29 did: None,
30 flash_message: HashMap::new(),
31 }
32 }
33}
34
35pub struct AxumSessionStore {
36 session: Session,
37 data: SessionData,
38}
39
40/// How you actually interact with the session store
41impl AxumSessionStore {
42 const SESSION_DATA_KEY: &'static str = "session.data";
43
44 pub fn logged_in(&self) -> bool {
45 self.data.did.is_some()
46 }
47
48 pub async fn set_did(&mut self, did: String) -> Result<(), tower_sessions::session::Error> {
49 self.data.did = Some(did);
50 Self::update_session(&self.session, &self.data).await
51 }
52
53 pub fn get_did(&self) -> Option<String> {
54 self.data.did.clone()
55 }
56
57 /// Clears the session data (logs out the user)
58 pub async fn clear_session(&mut self) -> Result<(), tower_sessions::session::Error> {
59 self.data = SessionData::default();
60 Self::update_session(&self.session, &self.data).await
61 }
62
63 ///Gets the message as well as removes it from the session
64 pub async fn get_flash_message(
65 &mut self,
66 key: &str,
67 ) -> Result<Option<FlashMessage>, tower_sessions::session::Error> {
68 let message = self.data.flash_message.get(key).cloned();
69 if message.is_some() {
70 self.data.flash_message.remove(key);
71 Self::update_session(&self.session, &self.data).await?
72 }
73 Ok(message)
74 }
75
76 pub async fn set_flash_message(
77 &mut self,
78 key: &str,
79 message: FlashMessage,
80 ) -> Result<(), tower_sessions::session::Error> {
81 self.data.flash_message.insert(key.to_string(), message);
82 Self::update_session(&self.session, &self.data).await
83 }
84
85 /// Make sure to call this or your session won't actually be saved
86 async fn update_session(
87 session: &Session,
88 session_data: &SessionData,
89 ) -> Result<(), tower_sessions::session::Error> {
90 session.insert(Self::SESSION_DATA_KEY, session_data).await
91 }
92}
93
94impl fmt::Display for AxumSessionStore {
95 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
96 f.debug_struct("SessionStore")
97 .field("did", &self.data.did)
98 .finish()
99 }
100}
101
102impl<S> FromRequestParts<S> for AxumSessionStore
103where
104 S: Send + Sync,
105{
106 type Rejection = (StatusCode, &'static str);
107
108 async fn from_request_parts(req: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
109 let session = Session::from_request_parts(req, state).await?;
110
111 let data: SessionData = session
112 .get(Self::SESSION_DATA_KEY)
113 .await
114 .unwrap()
115 .unwrap_or_default();
116
117 Ok(Self { session, data })
118 }
119}
120
121/// Helper wrapper for handling http responses if theres an error
122pub async fn set_flash_message(
123 session: &mut AxumSessionStore,
124 key: &str,
125 flash_message: FlashMessage,
126) -> Result<(), Response> {
127 session
128 .set_flash_message(key, flash_message)
129 .await
130 .map_err(|err| {
131 log::error!("Error setting flash message: {err}");
132 error_response(
133 StatusCode::INTERNAL_SERVER_ERROR,
134 "Error setting flash message",
135 )
136 })
137}
138
139/// Helper wrapper for handling http responses if theres an error
140pub async fn get_flash_message(
141 session: &mut AxumSessionStore,
142 key: &str,
143) -> Result<Option<FlashMessage>, Response> {
144 match session.get_flash_message(key).await {
145 Ok(message) => Ok(message),
146 Err(err) => {
147 log::error!("Error getting flash message: {err}");
148 Err(error_response(
149 StatusCode::INTERNAL_SERVER_ERROR,
150 "Error getting flash message",
151 ))
152 }
153 }
154}