at main 305 lines 8.9 kB view raw
1use std::collections::HashMap; 2 3use anyhow::anyhow; 4use axum::{ 5 extract::{self, Request}, 6 middleware::Next, 7 response::{self, IntoResponse, Response}, 8}; 9use http::StatusCode; 10use oauth2::{ 11 AuthUrl, EndpointNotSet, EndpointSet, RedirectUrl, TokenResponse as _, TokenUrl, basic::*, *, 12}; 13use serde::{Deserialize, Serialize}; 14use tracing::{debug, error, info, warn}; 15 16use crate::{ 17 AppState, AthleteId, 18 api::{self, UserSession}, 19 config::{self}, 20 strava, 21}; 22 23#[derive(Deserialize, Debug, Serialize, Clone)] 24pub struct ExtraTokenFields { 25 pub athlete: Option<strava::Athlete>, 26} 27 28impl oauth2::ExtraTokenFields for ExtraTokenFields {} 29 30type TokenResponse = oauth2::StandardTokenResponse<ExtraTokenFields, oauth2::basic::BasicTokenType>; 31 32pub type OAuthClient = Client< 33 BasicErrorResponse, 34 TokenResponse, 35 BasicTokenIntrospectionResponse, 36 StandardRevocableToken, 37 BasicRevocationErrorResponse, 38 EndpointSet, 39 EndpointNotSet, 40 EndpointNotSet, 41 EndpointNotSet, 42 EndpointSet, 43>; 44 45pub fn oauth_client(config: &config::Config) -> Result<OAuthClient, anyhow::Error> { 46 let config::StravaConfig { 47 client_id, 48 client_secret, 49 auth_url, 50 token_url, 51 redirect_url, 52 } = &config.strava; 53 54 let client = oauth2::Client::new(client_id.clone()) 55 .set_client_secret(client_secret.clone()) 56 .set_auth_uri(AuthUrl::new(auth_url.to_string())?) 57 .set_token_uri(TokenUrl::new(token_url.to_string())?) 58 .set_auth_type(oauth2::AuthType::RequestBody) 59 // Set the URL the user will be redirected to after the authorization process. 60 .set_redirect_uri(RedirectUrl::new(redirect_url.to_string())?); 61 62 Ok(client) 63} 64 65#[derive(Deserialize)] 66pub struct Auth { 67 code: oauth2::AuthorizationCode, 68 state: oauth2::CsrfToken, 69} 70 71// #[axum::debug_handler] 72#[tracing::instrument(level = "debug", skip_all)] 73pub async fn callback( 74 state: extract::State<AppState<'_>>, 75 auth: extract::Query<Auth>, 76 session: tower_sessions::Session, 77) -> Result<impl IntoResponse, impl IntoResponse> { 78 let session_csrf_token: Option<oauth2::CsrfToken> = session.get("csrf_token").await.unwrap(); 79 80 match session_csrf_token { 81 None => { 82 error!("Missing csrf_token in Session"); 83 return Err((StatusCode::BAD_REQUEST, "Missing csrf_token in Session").into_response()); 84 } 85 Some(session_code) if session_code.secret() == auth.state.secret() => { 86 debug!(code = ?session_code, "Valid csrf_token"); 87 } 88 Some(session_code) => { 89 error!( 90 expected = auth.state.secret(), 91 actual = session_code.secret(), 92 "csrf_token mismatch" 93 ); 94 return Err((StatusCode::BAD_REQUEST, "Invalid csrf_token in Session").into_response()); 95 } 96 }; 97 98 debug!("exchanging token"); 99 100 let http_client = reqwest::ClientBuilder::new() 101 // Following redirects opens the client up to SSRF vulnerabilities. 102 .redirect(reqwest::redirect::Policy::none()) 103 .connection_verbose(true) 104 .build() 105 .expect("Client should build"); 106 107 // Exchange the code with a token. 108 let token_response = state 109 .oauth_client 110 .exchange_code(auth.0.code) 111 .request_async(&http_client) 112 .await 113 .expect("Failed to exchange token"); 114 115 debug!( 116 ?token_response, 117 "authentication successful, fetching user-info" 118 ); 119 120 let Some(athlete) = &token_response.extra_fields().athlete else { 121 error!("Didn't get TokenResponse.athlete"); 122 return Err(( 123 StatusCode::INTERNAL_SERVER_ERROR, 124 "Missing 'athlete' in TokenResponse", 125 ) 126 .into_response()); 127 }; 128 129 debug!(?athlete); 130 131 let db = state.acquire_db().await; 132 let athlete_entity = db 133 .find(athlete.id) 134 .next() 135 .unwrap_or_else(|| db.new_entity().attach(athlete.id)); 136 137 athlete_entity.attach(athlete.to_owned()); 138 139 store_tokens(athlete_entity, &token_response); 140 141 let user_session = api::UserSession { 142 athlete_id: athlete.id, 143 }; 144 145 session 146 .insert(api::USER_SESSION_KEY, user_session.clone()) 147 .await 148 .unwrap(); 149 150 info!( 151 %user_session.athlete_id, 152 "User authenticated" 153 ); 154 155 Ok(axum::response::Redirect::to("/")) 156} 157 158#[tracing::instrument(level = "debug")] 159pub fn store_tokens(athlete: ecsdb::Entity, token_response: &TokenResponse) -> strava::AccessToken { 160 let access_token = strava::AccessToken { 161 token: token_response.access_token().to_owned(), 162 expires_at: chrono::Utc::now() 163 + token_response 164 .expires_in() 165 .expect("TokenResponse.expires_in"), 166 }; 167 let refresh_token = strava::RefreshToken( 168 token_response 169 .refresh_token() 170 .expect("refreshToken") 171 .to_owned(), 172 ); 173 if let Some(ref token_response_athlete) = token_response.extra_fields().athlete { 174 assert_eq!( 175 athlete.component::<AthleteId>().expect("AthleteId"), 176 token_response_athlete.id 177 ); 178 } 179 180 athlete.attach((refresh_token, access_token.clone())); 181 182 debug!("tokens stored"); 183 184 access_token 185} 186 187#[tracing::instrument(level = "debug", err, skip(client, state))] 188pub async fn refresh_access_token( 189 client: &reqwest::Client, 190 state: &AppState<'_>, 191 athlete: ecsdb::Entity<'_>, 192) -> Result<strava::AccessToken, anyhow::Error> { 193 debug!("Refreshing AccessToken via RefreshToken"); 194 195 let strava::RefreshToken(refresh_token) = athlete 196 .component() 197 .ok_or(anyhow!("No RefreshToken on {}", athlete.id()))?; 198 199 let token_response = state 200 .oauth_client 201 .exchange_refresh_token(&refresh_token) 202 .request_async(client) 203 .await?; 204 205 Ok(store_tokens(athlete, &token_response)) 206} 207 208#[tracing::instrument(level = "debug", skip_all)] 209pub async fn redirect( 210 state: extract::State<AppState<'_>>, 211 session: tower_sessions::Session, 212) -> impl IntoResponse { 213 use oauth2::*; 214 215 let scopes = [ 216 // Scope::new("read".to_string()), 217 Scope::new("activity:read".to_string()), 218 ]; 219 220 let (authorize_url, csrf_state) = state 221 .oauth_client 222 .authorize_url(CsrfToken::new_random) 223 .add_scopes(scopes) 224 .url(); 225 226 session.insert("csrf_token", csrf_state).await.unwrap(); 227 228 axum::response::Redirect::to(authorize_url.as_str()).into_response() 229} 230 231// #[axum::debug_middleware] 232// #[tracing::instrument(level = "debug", skip_all)] 233pub async fn user_session_middleware( 234 app_state: extract::State<crate::AppState<'_>>, 235 http_session: tower_sessions::Session, 236 _query_params: extract::Query<HashMap<String, String>>, 237 mut request: Request, 238 next: Next, 239) -> Response { 240 let user_session = http_session 241 .get::<api::UserSession>(api::USER_SESSION_KEY) 242 .await; 243 244 match user_session { 245 Ok(Some(session)) => { 246 debug!(?session, "Found UserSession"); 247 248 // Check if the OAuth2 Session has an `AccessToken` component 249 let db = app_state.acquire_db().await; 250 if db 251 .find(session.athlete_id) 252 .next() 253 .is_some_and(|e| e.has::<strava::AccessToken>()) 254 { 255 request.extensions_mut().insert(session); 256 } else { 257 warn!( 258 ?session, 259 reason = "Missing AccessToken", 260 "Invalidating Session" 261 ); 262 263 http_session.clear().await; 264 } 265 } 266 Ok(None) => debug!("No session found"), 267 Err(e) => { 268 error!(?e, "Failed to extract UserSession from Session"); 269 http_session.clear().await; 270 } 271 }; 272 273 next.run(request).await 274} 275 276// #[tracing::instrument(level = "debug", skip_all)] 277pub async fn enforce_user_session( 278 user_session: Option<UserSession>, 279 request: Request, 280 next: Next, 281) -> Response { 282 match user_session { 283 Some(_) => next.run(request).await, 284 None => { 285 warn!("Unauthorized"); 286 287 let is_htmx_request = request 288 .headers() 289 .get("HX-Request") 290 .is_some_and(|h| h.to_str().is_ok_and(|v| v == "true")); 291 292 if is_htmx_request { 293 Response::builder() 294 .status(StatusCode::UNAUTHORIZED) 295 .header("hx-refresh", "true") 296 .body(axum::body::Body::empty()) 297 .unwrap() 298 .into_response() 299 } else { 300 info!("Redirecting to /login"); 301 response::Redirect::to("/login").into_response() 302 } 303 } 304 } 305}