this repo has no description
1use std::collections::HashMap;
2
3use anyhow::anyhow;
4use axum::{
5 extract::{self, Request},
6 middleware::Next,
7 response::{self, IntoResponse, Response},
8};
9use http::StatusCode;
10use oauth2::{
11 AuthUrl, EndpointNotSet, EndpointSet, RedirectUrl, TokenResponse as _, TokenUrl, basic::*, *,
12};
13use serde::{Deserialize, Serialize};
14use tracing::{debug, error, info, warn};
15
16use crate::{
17 AppState, AthleteId,
18 api::{self, UserSession},
19 config::{self},
20 strava,
21};
22
23#[derive(Deserialize, Debug, Serialize, Clone)]
24pub struct ExtraTokenFields {
25 pub athlete: Option<strava::Athlete>,
26}
27
28impl oauth2::ExtraTokenFields for ExtraTokenFields {}
29
30type TokenResponse = oauth2::StandardTokenResponse<ExtraTokenFields, oauth2::basic::BasicTokenType>;
31
32pub type OAuthClient = Client<
33 BasicErrorResponse,
34 TokenResponse,
35 BasicTokenIntrospectionResponse,
36 StandardRevocableToken,
37 BasicRevocationErrorResponse,
38 EndpointSet,
39 EndpointNotSet,
40 EndpointNotSet,
41 EndpointNotSet,
42 EndpointSet,
43>;
44
45pub fn oauth_client(config: &config::Config) -> Result<OAuthClient, anyhow::Error> {
46 let config::StravaConfig {
47 client_id,
48 client_secret,
49 auth_url,
50 token_url,
51 redirect_url,
52 } = &config.strava;
53
54 let client = oauth2::Client::new(client_id.clone())
55 .set_client_secret(client_secret.clone())
56 .set_auth_uri(AuthUrl::new(auth_url.to_string())?)
57 .set_token_uri(TokenUrl::new(token_url.to_string())?)
58 .set_auth_type(oauth2::AuthType::RequestBody)
59 // Set the URL the user will be redirected to after the authorization process.
60 .set_redirect_uri(RedirectUrl::new(redirect_url.to_string())?);
61
62 Ok(client)
63}
64
65#[derive(Deserialize)]
66pub struct Auth {
67 code: oauth2::AuthorizationCode,
68 state: oauth2::CsrfToken,
69}
70
71// #[axum::debug_handler]
72#[tracing::instrument(level = "debug", skip_all)]
73pub async fn callback(
74 state: extract::State<AppState<'_>>,
75 auth: extract::Query<Auth>,
76 session: tower_sessions::Session,
77) -> Result<impl IntoResponse, impl IntoResponse> {
78 let session_csrf_token: Option<oauth2::CsrfToken> = session.get("csrf_token").await.unwrap();
79
80 match session_csrf_token {
81 None => {
82 error!("Missing csrf_token in Session");
83 return Err((StatusCode::BAD_REQUEST, "Missing csrf_token in Session").into_response());
84 }
85 Some(session_code) if session_code.secret() == auth.state.secret() => {
86 debug!(code = ?session_code, "Valid csrf_token");
87 }
88 Some(session_code) => {
89 error!(
90 expected = auth.state.secret(),
91 actual = session_code.secret(),
92 "csrf_token mismatch"
93 );
94 return Err((StatusCode::BAD_REQUEST, "Invalid csrf_token in Session").into_response());
95 }
96 };
97
98 debug!("exchanging token");
99
100 let http_client = reqwest::ClientBuilder::new()
101 // Following redirects opens the client up to SSRF vulnerabilities.
102 .redirect(reqwest::redirect::Policy::none())
103 .connection_verbose(true)
104 .build()
105 .expect("Client should build");
106
107 // Exchange the code with a token.
108 let token_response = state
109 .oauth_client
110 .exchange_code(auth.0.code)
111 .request_async(&http_client)
112 .await
113 .expect("Failed to exchange token");
114
115 debug!(
116 ?token_response,
117 "authentication successful, fetching user-info"
118 );
119
120 let Some(athlete) = &token_response.extra_fields().athlete else {
121 error!("Didn't get TokenResponse.athlete");
122 return Err((
123 StatusCode::INTERNAL_SERVER_ERROR,
124 "Missing 'athlete' in TokenResponse",
125 )
126 .into_response());
127 };
128
129 debug!(?athlete);
130
131 let db = state.acquire_db().await;
132 let athlete_entity = db
133 .find(athlete.id)
134 .next()
135 .unwrap_or_else(|| db.new_entity().attach(athlete.id));
136
137 athlete_entity.attach(athlete.to_owned());
138
139 store_tokens(athlete_entity, &token_response);
140
141 let user_session = api::UserSession {
142 athlete_id: athlete.id,
143 };
144
145 session
146 .insert(api::USER_SESSION_KEY, user_session.clone())
147 .await
148 .unwrap();
149
150 info!(
151 %user_session.athlete_id,
152 "User authenticated"
153 );
154
155 Ok(axum::response::Redirect::to("/"))
156}
157
158#[tracing::instrument(level = "debug")]
159pub fn store_tokens(athlete: ecsdb::Entity, token_response: &TokenResponse) -> strava::AccessToken {
160 let access_token = strava::AccessToken {
161 token: token_response.access_token().to_owned(),
162 expires_at: chrono::Utc::now()
163 + token_response
164 .expires_in()
165 .expect("TokenResponse.expires_in"),
166 };
167 let refresh_token = strava::RefreshToken(
168 token_response
169 .refresh_token()
170 .expect("refreshToken")
171 .to_owned(),
172 );
173 if let Some(ref token_response_athlete) = token_response.extra_fields().athlete {
174 assert_eq!(
175 athlete.component::<AthleteId>().expect("AthleteId"),
176 token_response_athlete.id
177 );
178 }
179
180 athlete.attach((refresh_token, access_token.clone()));
181
182 debug!("tokens stored");
183
184 access_token
185}
186
187#[tracing::instrument(level = "debug", err, skip(client, state))]
188pub async fn refresh_access_token(
189 client: &reqwest::Client,
190 state: &AppState<'_>,
191 athlete: ecsdb::Entity<'_>,
192) -> Result<strava::AccessToken, anyhow::Error> {
193 debug!("Refreshing AccessToken via RefreshToken");
194
195 let strava::RefreshToken(refresh_token) = athlete
196 .component()
197 .ok_or(anyhow!("No RefreshToken on {}", athlete.id()))?;
198
199 let token_response = state
200 .oauth_client
201 .exchange_refresh_token(&refresh_token)
202 .request_async(client)
203 .await?;
204
205 Ok(store_tokens(athlete, &token_response))
206}
207
208#[tracing::instrument(level = "debug", skip_all)]
209pub async fn redirect(
210 state: extract::State<AppState<'_>>,
211 session: tower_sessions::Session,
212) -> impl IntoResponse {
213 use oauth2::*;
214
215 let scopes = [
216 // Scope::new("read".to_string()),
217 Scope::new("activity:read".to_string()),
218 ];
219
220 let (authorize_url, csrf_state) = state
221 .oauth_client
222 .authorize_url(CsrfToken::new_random)
223 .add_scopes(scopes)
224 .url();
225
226 session.insert("csrf_token", csrf_state).await.unwrap();
227
228 axum::response::Redirect::to(authorize_url.as_str()).into_response()
229}
230
231// #[axum::debug_middleware]
232// #[tracing::instrument(level = "debug", skip_all)]
233pub async fn user_session_middleware(
234 app_state: extract::State<crate::AppState<'_>>,
235 http_session: tower_sessions::Session,
236 _query_params: extract::Query<HashMap<String, String>>,
237 mut request: Request,
238 next: Next,
239) -> Response {
240 let user_session = http_session
241 .get::<api::UserSession>(api::USER_SESSION_KEY)
242 .await;
243
244 match user_session {
245 Ok(Some(session)) => {
246 debug!(?session, "Found UserSession");
247
248 // Check if the OAuth2 Session has an `AccessToken` component
249 let db = app_state.acquire_db().await;
250 if db
251 .find(session.athlete_id)
252 .next()
253 .is_some_and(|e| e.has::<strava::AccessToken>())
254 {
255 request.extensions_mut().insert(session);
256 } else {
257 warn!(
258 ?session,
259 reason = "Missing AccessToken",
260 "Invalidating Session"
261 );
262
263 http_session.clear().await;
264 }
265 }
266 Ok(None) => debug!("No session found"),
267 Err(e) => {
268 error!(?e, "Failed to extract UserSession from Session");
269 http_session.clear().await;
270 }
271 };
272
273 next.run(request).await
274}
275
276// #[tracing::instrument(level = "debug", skip_all)]
277pub async fn enforce_user_session(
278 user_session: Option<UserSession>,
279 request: Request,
280 next: Next,
281) -> Response {
282 match user_session {
283 Some(_) => next.run(request).await,
284 None => {
285 warn!("Unauthorized");
286
287 let is_htmx_request = request
288 .headers()
289 .get("HX-Request")
290 .is_some_and(|h| h.to_str().is_ok_and(|v| v == "true"));
291
292 if is_htmx_request {
293 Response::builder()
294 .status(StatusCode::UNAUTHORIZED)
295 .header("hx-refresh", "true")
296 .body(axum::body::Body::empty())
297 .unwrap()
298 .into_response()
299 } else {
300 info!("Redirecting to /login");
301 response::Redirect::to("/login").into_response()
302 }
303 }
304 }
305}