Learn how to use Rust to build ATProto-powered applications.
//! SQLite-backed storage implementations that persist OAuth sessions and OAuth state,
//! for use when you are not using the in-memory stores.
//! See also the TypeScript version:
//! https://github.com/bluesky-social/statusphere-example-app/blob/main/src/auth/storage.ts
3use crate::db::{AuthSession, AuthState};
4use async_sqlite::Pool;
5use atrium_api::types::string::Did;
6use atrium_common::store::Store;
7use atrium_oauth::store::session::SessionStore;
8use atrium_oauth::store::state::StateStore;
9use serde::Serialize;
10use serde::de::DeserializeOwned;
11use std::fmt::Debug;
12use std::hash::Hash;
13use thiserror::Error;
14
15#[derive(Error, Debug)]
16pub enum SqliteStoreError {
17 #[error("Invalid session")]
18 InvalidSession,
19 #[error("No session found")]
20 NoSessionFound,
21 #[error("Database error: {0}")]
22 DatabaseError(async_sqlite::Error),
23}
24
25///Persistent session store in sqlite
26impl SessionStore for SqliteSessionStore {}
27
28pub struct SqliteSessionStore {
29 db_pool: Pool,
30}
31
32impl SqliteSessionStore {
33 pub fn new(db: Pool) -> Self {
34 Self { db_pool: db }
35 }
36}
37
38impl<K, V> Store<K, V> for SqliteSessionStore
39where
40 K: Debug + Eq + Hash + Send + Sync + 'static + From<Did> + AsRef<str>,
41 V: Debug + Clone + Send + Sync + 'static + Serialize + DeserializeOwned,
42{
43 type Error = SqliteStoreError;
44 async fn get(&self, key: &K) -> Result<Option<V>, Self::Error> {
45 let did = key.as_ref().to_string();
46 match AuthSession::get_by_did(&self.db_pool, did).await {
47 Ok(Some(auth_session)) => {
48 let deserialized_session: V = serde_json::from_str(&auth_session.session)
49 .map_err(|_| SqliteStoreError::InvalidSession)?;
50 Ok(Some(deserialized_session))
51 }
52 Ok(None) => Err(SqliteStoreError::NoSessionFound),
53 Err(db_error) => {
54 log::error!("Database error: {db_error}");
55 Err(SqliteStoreError::DatabaseError(db_error))
56 }
57 }
58 }
59
60 async fn set(&self, key: K, value: V) -> Result<(), Self::Error> {
61 let did = key.as_ref().to_string();
62 let auth_session = AuthSession::new(did, value);
63 auth_session
64 .save_or_update(&self.db_pool)
65 .await
66 .map_err(SqliteStoreError::DatabaseError)?;
67 Ok(())
68 }
69
70 async fn del(&self, _key: &K) -> Result<(), Self::Error> {
71 let did = _key.as_ref().to_string();
72 AuthSession::delete_by_did(&self.db_pool, did)
73 .await
74 .map_err(SqliteStoreError::DatabaseError)?;
75 Ok(())
76 }
77
78 async fn clear(&self) -> Result<(), Self::Error> {
79 AuthSession::delete_all(&self.db_pool)
80 .await
81 .map_err(SqliteStoreError::DatabaseError)?;
82 Ok(())
83 }
84}
85
86///Persistent session state in sqlite
87impl StateStore for SqliteStateStore {}
88
89pub struct SqliteStateStore {
90 db_pool: Pool,
91}
92
93impl SqliteStateStore {
94 pub fn new(db: Pool) -> Self {
95 Self { db_pool: db }
96 }
97}
98
99impl<K, V> Store<K, V> for SqliteStateStore
100where
101 K: Debug + Eq + Hash + Send + Sync + 'static + From<Did> + AsRef<str>,
102 V: Debug + Clone + Send + Sync + 'static + Serialize + DeserializeOwned,
103{
104 type Error = SqliteStoreError;
105 async fn get(&self, key: &K) -> Result<Option<V>, Self::Error> {
106 let key = key.as_ref().to_string();
107 match AuthState::get_by_key(&self.db_pool, key).await {
108 Ok(Some(auth_state)) => {
109 let deserialized_state: V = serde_json::from_str(&auth_state.state)
110 .map_err(|_| SqliteStoreError::InvalidSession)?;
111 Ok(Some(deserialized_state))
112 }
113 Ok(None) => Err(SqliteStoreError::NoSessionFound),
114 Err(db_error) => {
115 log::error!("Database error: {db_error}");
116 Err(SqliteStoreError::DatabaseError(db_error))
117 }
118 }
119 }
120
121 async fn set(&self, key: K, value: V) -> Result<(), Self::Error> {
122 let did = key.as_ref().to_string();
123 let auth_state = AuthState::new(did, value);
124 auth_state
125 .save_or_update(&self.db_pool)
126 .await
127 .map_err(SqliteStoreError::DatabaseError)?;
128 Ok(())
129 }
130
131 async fn del(&self, _key: &K) -> Result<(), Self::Error> {
132 let key = _key.as_ref().to_string();
133 AuthState::delete_by_key(&self.db_pool, key)
134 .await
135 .map_err(SqliteStoreError::DatabaseError)?;
136 Ok(())
137 }
138
139 async fn clear(&self) -> Result<(), Self::Error> {
140 AuthState::delete_all(&self.db_pool)
141 .await
142 .map_err(SqliteStoreError::DatabaseError)?;
143 Ok(())
144 }
145}