//! GraphQL HTTP handler for Axum use async_graphql::dynamic::Schema; use async_graphql::http::{WebSocket as GraphQLWebSocket, WebSocketProtocols, WsMessage}; use async_graphql_axum::{GraphQLRequest, GraphQLResponse}; use axum::{ extract::{ Query, State, WebSocketUpgrade, ws::{Message, WebSocket}, }, http::{HeaderMap, StatusCode}, response::{Html, Response}, }; use futures_util::{SinkExt, StreamExt}; use serde::Deserialize; use std::sync::Arc; use tokio::sync::RwLock; use crate::AppState; use crate::errors::AppError; use crate::graphql::GraphQLContext; /// Global schema cache (one schema per slice) /// This prevents rebuilding the schema on every request type SchemaCache = Arc>>; lazy_static::lazy_static! { static ref SCHEMA_CACHE: SchemaCache = Arc::new(RwLock::new(std::collections::HashMap::new())); } #[derive(Deserialize, Default)] pub struct GraphQLParams { pub slice: Option, } /// GraphQL query handler /// Accepts slice URI from either query parameter (?slice=...) or HTTP header (X-Slice-Uri) pub async fn graphql_handler( State(state): State, Query(params): Query, headers: HeaderMap, req: GraphQLRequest, ) -> Result { // Get slice URI from query param or header let slice_uri = params .slice .or_else(|| { headers .get("x-slice-uri") .and_then(|h| h.to_str().ok()) .map(|s| s.to_string()) }) .ok_or_else(|| { ( StatusCode::BAD_REQUEST, "Missing slice parameter. Provide either ?slice=... query parameter or X-Slice-Uri header".to_string(), ) })?; let schema = match get_or_build_schema(&state, &slice_uri).await { Ok(s) => s, Err(e) => { tracing::error!("Failed to get GraphQL schema: {:?}", e); return Ok(async_graphql::Response::from_errors(vec![ async_graphql::ServerError::new(format!("Schema error: {:?}", e), None), ]) .into()); } }; // Extract optional bearer token for mutations let auth_token = headers .get("authorization") .and_then(|h| h.to_str().ok()) .and_then(|s| s.strip_prefix("Bearer ")) .map(|s| s.to_string()); // Create GraphQL context with DataLoader and auth let gql_context = GraphQLContext::with_auth( state.database.clone(), auth_token.clone(), state.config.auth_base_url.clone(), Some(state.auth_cache.clone()), ); // Verify auth token and get user DID for mutations let mut request = req.into_inner().data(gql_context).data(state.database_pool.clone()); if let Some(token) = auth_token { // Verify token and add user DID to context match crate::auth::verify_oauth_token_cached( &token, &state.config.auth_base_url, Some(state.auth_cache.clone()), ) .await { Ok(user_info) => { request = request.data(user_info.sub); } Err(_) => { // Invalid token - let the mutation handle the error } } } // Execute query with context Ok(schema.execute(request).await.into()) } /// GraphiQL UI handler /// Configures GraphiQL with the slice URI in headers pub async fn graphql_playground( Query(params): Query, ) -> Result, (StatusCode, String)> { let slice_uri = params.slice.ok_or_else(|| { ( StatusCode::BAD_REQUEST, "Missing slice parameter. Provide ?slice=... query parameter".to_string(), ) })?; // Create GraphiQL with pre-configured headers using React 19 and modern ESM let graphiql_html = format!( r#" Slices GraphiQL
Loading…
"#, slice_uri.replace("'", "\\'").replace("\"", "\\\""), slice_uri.replace("'", "\\'").replace("\"", "\\\"") ); Ok(Html(graphiql_html)) } /// GraphQL WebSocket handler for subscriptions /// Accepts slice URI from query parameter (?slice=...) pub async fn graphql_subscription_handler( State(state): State, Query(params): Query, ws: WebSocketUpgrade, ) -> Result { let slice_uri = params.slice.ok_or_else(|| { ( StatusCode::BAD_REQUEST, "Missing slice parameter. Provide ?slice=... query parameter".to_string(), ) })?; let schema = match get_or_build_schema(&state, &slice_uri).await { Ok(s) => s, Err(e) => { tracing::error!("Failed to get GraphQL schema: {:?}", e); return Err(( StatusCode::INTERNAL_SERVER_ERROR, format!("Schema error: {:?}", e), )); } }; // Create GraphQL context with DataLoader (subscriptions don't need auth typically) let gql_context = GraphQLContext::with_auth( state.database.clone(), None, state.config.auth_base_url.clone(), Some(state.auth_cache.clone()), ); // Upgrade to WebSocket and handle GraphQL subscriptions manually let db_pool = state.database_pool.clone(); Ok(ws .protocols(["graphql-transport-ws", "graphql-ws"]) .on_upgrade(move |socket| handle_graphql_ws(socket, schema, gql_context, db_pool))) } /// Handle GraphQL WebSocket connection async fn handle_graphql_ws(socket: WebSocket, schema: Schema, gql_context: GraphQLContext, state_pool: sqlx::PgPool) { let (ws_sender, ws_receiver) = socket.split(); // Convert axum WebSocket messages to strings for async-graphql let input = ws_receiver.filter_map(|msg| { futures_util::future::ready(match msg { Ok(Message::Text(text)) => Some(text.to_string()), _ => None, // Ignore other message types }) }); // Create GraphQL WebSocket handler with context and database pool let mut stream = GraphQLWebSocket::new(schema.clone(), input, WebSocketProtocols::GraphQLWS) .on_connection_init(move |_| { let gql_ctx = gql_context.clone(); let pool = state_pool.clone(); async move { let mut data = async_graphql::Data::default(); data.insert(gql_ctx); data.insert(pool); Ok(data) } }); // Send GraphQL messages back through WebSocket let mut ws_sender = ws_sender; while let Some(msg) = stream.next().await { let axum_msg = match msg { WsMessage::Text(text) => Message::Text(text.into()), WsMessage::Close(code, reason) => Message::Close(Some(axum::extract::ws::CloseFrame { code, reason: reason.into(), })), }; if ws_sender.send(axum_msg).await.is_err() { break; } } } /// Gets schema from cache or builds it if not cached async fn get_or_build_schema(state: &AppState, slice_uri: &str) -> Result { // Check cache first { let cache = SCHEMA_CACHE.read().await; if let Some(schema) = cache.get(slice_uri) { return Ok(schema.clone()); } } // Build schema let schema = crate::graphql::build_graphql_schema( state.database.clone(), slice_uri.to_string(), state.config.auth_base_url.clone(), ) .await .map_err(|e| AppError::Internal(format!("Failed to build GraphQL schema: {}", e)))?; // Cache it { let mut cache = SCHEMA_CACHE.write().await; cache.insert(slice_uri.to_string(), schema.clone()); } Ok(schema) } /// Invalidates the cached GraphQL schema for a given slice /// /// This should be called when lexicon records are created, updated, or deleted /// to ensure the schema is rebuilt with the new lexicon definitions. pub async fn invalidate_schema_cache(slice_uri: &str) { let mut cache = SCHEMA_CACHE.write().await; cache.remove(slice_uri); tracing::debug!("Invalidated GraphQL schema cache for slice: {}", slice_uri); }