// Forked from smokesignal.events/smokesignal
// i18n+filtering fork - fluent-templates v2
1use anyhow::Result;
2use chrono::{Duration, Utc};
3use deadpool_redis::redis::{pipe, AsyncCommands};
4use p256::SecretKey;
5use std::borrow::Cow;
6use tokio::time::{sleep, Instant};
7use tokio_util::sync::CancellationToken;
8
9use crate::{
10 config::{OAuthActiveKeys, SigningKeys},
11 oauth::client_oauth_refresh,
12 refresh_tokens_errors::RefreshError,
13 storage::{
14 cache::{build_worker_queue, OAUTH_REFRESH_HEARTBEATS, OAUTH_REFRESH_QUEUE},
15 oauth::{oauth_session_delete, oauth_session_update, web_session_lookup},
16 CachePool, StoragePool,
17 },
18};
19
20pub struct RefreshTokensTaskConfig {
21 pub sleep_interval: Duration,
22 pub worker_id: String,
23 pub external_url_base: String,
24 pub signing_keys: SigningKeys,
25 pub oauth_active_keys: OAuthActiveKeys,
26}
27
28pub struct RefreshTokensTask {
29 pub config: RefreshTokensTaskConfig,
30 pub http_client: reqwest::Client,
31 pub storage_pool: StoragePool,
32 pub cache_pool: CachePool,
33 pub cancellation_token: CancellationToken,
34}
35
36impl RefreshTokensTask {
37 #[must_use]
38 pub fn new(
39 config: RefreshTokensTaskConfig,
40 http_client: reqwest::Client,
41 storage_pool: StoragePool,
42 cache_pool: CachePool,
43 cancellation_token: CancellationToken,
44 ) -> Self {
45 Self {
46 config,
47 http_client,
48 storage_pool,
49 cache_pool,
50 cancellation_token,
51 }
52 }
53
54 /// Runs the refresh tokens task as a long-running process
55 ///
56 /// # Errors
57 /// Returns an error if the sleep interval cannot be converted, or if there's a problem
58 /// processing the work items
59 pub async fn run(&self) -> Result<()> {
60 tracing::debug!("RefreshTokensTask started");
61
62 let interval = self.config.sleep_interval.to_std()?;
63
64 let sleeper = sleep(interval);
65 tokio::pin!(sleeper);
66
67 loop {
68 tokio::select! {
69 () = self.cancellation_token.cancelled() => {
70 break;
71 },
72 () = &mut sleeper => {
73 if let Err(err) = self.process_work().await {
74 tracing::error!("RefreshTokensTask failed: {}", err);
75 }
76 sleeper.as_mut().reset(Instant::now() + interval);
77 }
78 }
79 }
80
81 tracing::info!("RefreshTokensTask stopped");
82
83 Ok(())
84 }
85
86 async fn process_work(&self) -> Result<i32> {
87 let worker_queue = build_worker_queue(&self.config.worker_id);
88
89 let mut conn = self.cache_pool.get().await?;
90
91 let now = chrono::Utc::now();
92 let epoch_millis = now.timestamp_millis();
93
94 let _: () = conn
95 .hset(
96 OAUTH_REFRESH_HEARTBEATS,
97 &self.config.worker_id,
98 now.to_string(),
99 )
100 .await?;
101
102 let global_queue_count: i32 = conn
103 .zcount(OAUTH_REFRESH_QUEUE, 0, epoch_millis + 1)
104 .await?;
105 let worker_queue_count: i32 = conn.zcount(&worker_queue, 0, epoch_millis + 1).await?;
106
107 tracing::trace!(
108 global_queue_count = global_queue_count,
109 worker_queue_count = worker_queue_count,
110 "queue counts"
111 );
112
113 let mut process_work = worker_queue_count > 0;
114
115 if global_queue_count > 0 && worker_queue_count == 0 {
116 let (moved, new_count): (i64, i64) = pipe()
117 .atomic()
118 // Take some work from the global queue and put it in the worker queue
119 // ZRANGESTORE dst src min max [BYSCORE | BYLEX] [REV] [LIMIT offset count]
120 .cmd("ZRANGESTORE")
121 .arg(&worker_queue)
122 .arg(OAUTH_REFRESH_QUEUE)
123 .arg(0)
124 .arg(epoch_millis)
125 .arg("BYSCORE")
126 .arg("LIMIT")
127 .arg(0)
128 .arg(5)
129 // Update the global queue to remove the items that were moved
130 .cmd("ZDIFFSTORE")
131 .arg(OAUTH_REFRESH_QUEUE)
132 .arg(2)
133 .arg(OAUTH_REFRESH_QUEUE)
134 .arg(&worker_queue)
135 .query_async(&mut conn)
136 .await?;
137 process_work = true;
138
139 tracing::debug!(
140 moved = moved,
141 new_count = new_count,
142 "moved work from global queue to worker queue"
143 );
144 }
145
146 if !process_work {
147 return Ok(0);
148 }
149
150 let count = 0;
151 let results: Vec<(String, i64)> = conn
152 .zrangebyscore_limit_withscores(&worker_queue, 0, epoch_millis, 0, 5)
153 .await?;
154
155 for (session_group, deadline) in results {
156 tracing::info!(session_group, deadline, "processing work");
157 let _: () = conn.zrem(&worker_queue, &session_group).await?;
158
159 if let Err(err) = self
160 .refresh_oauth_session(&mut conn, &session_group, deadline)
161 .await
162 {
163 tracing::error!(session_group, deadline, err = ?err, "failed to refresh oauth session: {}", err);
164
165 if let Err(err) = oauth_session_delete(&self.storage_pool, &session_group).await {
166 tracing::error!(session_group, err = ?err, "failed to delete oauth session: {}", err);
167 }
168 }
169 }
170
171 Ok(count)
172 }
173
174 async fn refresh_oauth_session(
175 &self,
176 conn: &mut deadpool_redis::Connection,
177 session_group: &str,
178 _deadline: i64,
179 ) -> Result<()> {
180 let (handle, oauth_session) =
181 web_session_lookup(&self.storage_pool, session_group, None).await?;
182
183 let secret_signing_key = self
184 .config
185 .signing_keys
186 .as_ref()
187 .get(&oauth_session.secret_jwk_id)
188 .cloned();
189
190 if secret_signing_key.is_none() {
191 return Err(RefreshError::SecretSigningKeyNotFound.into());
192 }
193
194 let dpop_secret_key = SecretKey::from_jwk(&oauth_session.dpop_jwk.jwk)
195 .map_err(RefreshError::DpopProofCreationFailed)?;
196
197 let token_response = client_oauth_refresh(
198 &self.http_client,
199 &self.config.external_url_base,
200 (&oauth_session.secret_jwk_id, secret_signing_key.unwrap()),
201 oauth_session.refresh_token.as_str(),
202 &handle,
203 &dpop_secret_key,
204 )
205 .await?;
206
207 let now = Utc::now();
208
209 oauth_session_update(
210 &self.storage_pool,
211 Cow::Borrowed(session_group),
212 Cow::Borrowed(&token_response.access_token),
213 Cow::Borrowed(&token_response.refresh_token),
214 now + chrono::Duration::seconds(i64::from(token_response.expires_in)),
215 )
216 .await?;
217
218 let modified_expires_at = ((f64::from(token_response.expires_in)) * 0.8).round() as i64;
219 let refresh_at = (now + chrono::Duration::seconds(modified_expires_at)).timestamp_millis();
220
221 let _: () = conn
222 .zadd(OAUTH_REFRESH_QUEUE, session_group, refresh_at)
223 .await
224 .map_err(RefreshError::PlaceInRefreshQueueFailed)?;
225
226 Ok(())
227 }
228}