use std::collections::HashMap; use anyhow::anyhow; use axum::{ extract::{self, Request}, middleware::Next, response::{self, IntoResponse, Response}, }; use http::StatusCode; use oauth2::{ AuthUrl, EndpointNotSet, EndpointSet, RedirectUrl, TokenResponse as _, TokenUrl, basic::*, *, }; use serde::{Deserialize, Serialize}; use tracing::{debug, error, info, warn}; use crate::{ AppState, AthleteId, api::{self, UserSession}, config::{self}, strava, }; #[derive(Deserialize, Debug, Serialize, Clone)] pub struct ExtraTokenFields { pub athlete: Option, } impl oauth2::ExtraTokenFields for ExtraTokenFields {} type TokenResponse = oauth2::StandardTokenResponse; pub type OAuthClient = Client< BasicErrorResponse, TokenResponse, BasicTokenIntrospectionResponse, StandardRevocableToken, BasicRevocationErrorResponse, EndpointSet, EndpointNotSet, EndpointNotSet, EndpointNotSet, EndpointSet, >; pub fn oauth_client(config: &config::Config) -> Result { let config::StravaConfig { client_id, client_secret, auth_url, token_url, redirect_url, } = &config.strava; let client = oauth2::Client::new(client_id.clone()) .set_client_secret(client_secret.clone()) .set_auth_uri(AuthUrl::new(auth_url.to_string())?) .set_token_uri(TokenUrl::new(token_url.to_string())?) .set_auth_type(oauth2::AuthType::RequestBody) // Set the URL the user will be redirected to after the authorization process. .set_redirect_uri(RedirectUrl::new(redirect_url.to_string())?); Ok(client) } #[derive(Deserialize)] pub struct Auth { code: oauth2::AuthorizationCode, state: oauth2::CsrfToken, } // #[axum::debug_handler] #[tracing::instrument(level = "debug", skip_all)] pub async fn callback( state: extract::State>, auth: extract::Query, session: tower_sessions::Session, ) -> Result { let session_csrf_token: Option = session.get("csrf_token").await.unwrap(); match session_csrf_token { None => { error!("Missing csrf_token in Session"); return Err((StatusCode::BAD_REQUEST, "Missing csrf_token in Session").into_response()); } Some(session_code) if session_code.secret() == auth.state.secret() => { debug!(code = ?session_code, "Valid csrf_token"); } Some(session_code) => { error!( expected = auth.state.secret(), actual = session_code.secret(), "csrf_token mismatch" ); return Err((StatusCode::BAD_REQUEST, "Invalid csrf_token in Session").into_response()); } }; debug!("exchanging token"); let http_client = reqwest::ClientBuilder::new() // Following redirects opens the client up to SSRF vulnerabilities. .redirect(reqwest::redirect::Policy::none()) .connection_verbose(true) .build() .expect("Client should build"); // Exchange the code with a token. let token_response = state .oauth_client .exchange_code(auth.0.code) .request_async(&http_client) .await .expect("Failed to exchange token"); debug!( ?token_response, "authentication successful, fetching user-info" ); let Some(athlete) = &token_response.extra_fields().athlete else { error!("Didn't get TokenResponse.athlete"); return Err(( StatusCode::INTERNAL_SERVER_ERROR, "Missing 'athlete' in TokenResponse", ) .into_response()); }; debug!(?athlete); let db = state.acquire_db().await; let athlete_entity = db .find(athlete.id) .next() .unwrap_or_else(|| db.new_entity().attach(athlete.id)); athlete_entity.attach(athlete.to_owned()); store_tokens(athlete_entity, &token_response); let user_session = api::UserSession { athlete_id: athlete.id, }; session .insert(api::USER_SESSION_KEY, user_session.clone()) .await .unwrap(); info!( %user_session.athlete_id, "User authenticated" ); Ok(axum::response::Redirect::to("/")) } #[tracing::instrument(level = "debug")] pub fn store_tokens(athlete: ecsdb::Entity, token_response: &TokenResponse) -> strava::AccessToken { let access_token = strava::AccessToken { token: token_response.access_token().to_owned(), expires_at: chrono::Utc::now() + token_response .expires_in() .expect("TokenResponse.expires_in"), }; let refresh_token = strava::RefreshToken( token_response .refresh_token() .expect("refreshToken") .to_owned(), ); if let Some(ref token_response_athlete) = token_response.extra_fields().athlete { assert_eq!( athlete.component::().expect("AthleteId"), token_response_athlete.id ); } athlete.attach((refresh_token, access_token.clone())); debug!("tokens stored"); access_token } #[tracing::instrument(level = "debug", err, skip(client, state))] pub async fn refresh_access_token( client: &reqwest::Client, state: &AppState<'_>, athlete: ecsdb::Entity<'_>, ) -> Result { debug!("Refreshing AccessToken via RefreshToken"); let strava::RefreshToken(refresh_token) = athlete .component() .ok_or(anyhow!("No RefreshToken on {}", athlete.id()))?; let token_response = state .oauth_client .exchange_refresh_token(&refresh_token) .request_async(client) .await?; Ok(store_tokens(athlete, &token_response)) } #[tracing::instrument(level = "debug", skip_all)] pub async fn redirect( state: extract::State>, session: tower_sessions::Session, ) -> impl IntoResponse { use oauth2::*; let scopes = [ // Scope::new("read".to_string()), Scope::new("activity:read".to_string()), ]; let (authorize_url, csrf_state) = state .oauth_client .authorize_url(CsrfToken::new_random) .add_scopes(scopes) .url(); session.insert("csrf_token", csrf_state).await.unwrap(); axum::response::Redirect::to(authorize_url.as_str()).into_response() } // #[axum::debug_middleware] // #[tracing::instrument(level = "debug", skip_all)] pub async fn user_session_middleware( app_state: extract::State>, http_session: tower_sessions::Session, _query_params: extract::Query>, mut request: Request, next: Next, ) -> Response { let user_session = http_session .get::(api::USER_SESSION_KEY) .await; match user_session { Ok(Some(session)) => { debug!(?session, "Found UserSession"); // Check if the OAuth2 Session has an `AccessToken` component let db = app_state.acquire_db().await; if db .find(session.athlete_id) .next() .is_some_and(|e| e.has::()) { request.extensions_mut().insert(session); } else { warn!( ?session, reason = "Missing AccessToken", "Invalidating Session" ); http_session.clear().await; } } Ok(None) => debug!("No session found"), Err(e) => { error!(?e, "Failed to extract UserSession from Session"); http_session.clear().await; } }; next.run(request).await } // #[tracing::instrument(level = "debug", skip_all)] pub async fn enforce_user_session( user_session: Option, request: Request, next: Next, ) -> Response { match user_session { Some(_) => next.run(request).await, None => { warn!("Unauthorized"); let is_htmx_request = request .headers() .get("HX-Request") .is_some_and(|h| h.to_str().is_ok_and(|v| v == "true")); if is_htmx_request { Response::builder() .status(StatusCode::UNAUTHORIZED) .header("hx-refresh", "true") .body(axum::body::Body::empty()) .unwrap() .into_response() } else { info!("Redirecting to /login"); response::Redirect::to("/login").into_response() } } } }