The smokesignal.events web application

chore: Adding AIP auth support

Signed-off-by: Nick Gerakines <nick.gerakines@gmail.com>

Changed files
+3193 -4513
migrations
src
atproto
bin
http
storage
templates
+192
CLAUDE.md
··· 1 + # CLAUDE.md 2 + 3 + This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. 4 + 5 + ## Project Overview 6 + 7 + Smokesignal is a Rust-based event and RSVP management application built for the AT Protocol ecosystem. It provides decentralized identity management, OAuth authentication, and event coordination through the AT Protocol. The application supports both standard AT Protocol OAuth and AIP (AT Protocol Improvement Proposal) OAuth flows, with backend selection determined by runtime configuration. 8 + 9 + ## Tech Stack 10 + 11 + - **Language**: Rust (edition 2021, minimum version 1.83) 12 + - **Web Framework**: Axum with async/await 13 + - **Database**: PostgreSQL with SQLx migrations 14 + - **Caching**: Redis/Valkey for sessions and token management 15 + - **Templates**: MiniJinja (server-side rendering with optional reloading) 16 + - **Frontend**: HTMX + Bulma CSS + FontAwesome 17 + - **Authentication**: AT Protocol OAuth with JOSE/JWT (P-256 ECDSA) 18 + - **Internationalization**: Fluent for i18n support 19 + - **Static Assets**: Rust Embed for production builds 20 + - **HTTP Client**: Reqwest with middleware chain, retry logic, and compression 21 + - **Cryptography**: P-256 ECDSA with elliptic-curve and p256 crates 22 + - **Task Management**: Tokio with task tracking and graceful shutdown 23 + - **DNS Resolution**: Hickory resolver with DoH/DoT support 24 + 25 + ## Development Commands 26 + 27 + ### Building & Development 28 + ```bash 29 + # Development build with template reloading (default) 30 + cargo build --bin smokesignal 31 + 32 + # Production build with embedded templates 33 + cargo build --bin smokesignal --no-default-features -F embed 34 + 35 + # Type checking and linting 36 + cargo check 37 + cargo clippy 38 + 39 + # Run tests 40 + cargo test 41 + ``` 42 + 43 + ### Running Services 44 + ```bash 45 + # Start main application server (development mode) 46 + cargo run --bin smokesignal 47 + 48 + # Generate cryptographic keys 49 + cargo run --bin crypto -- key 50 + 51 + ``` 52 + 53 + ### Database Operations 54 + ```bash 55 + # Run database migrations 56 + sqlx migrate run 57 + 58 + # Reset database with migrations 59 + sqlx database reset 60 + ``` 61 + 62 + ## Architecture Overview 63 + 64 + ### Core Structure 65 + ``` 66 + /src/ 67 + ├── atproto/ # AT Protocol integration 68 + │ ├── auth.rs # Authentication utilities 69 + │ └── lexicon/ # AT Protocol lexicon definitions 70 + ├── bin/ # Executables 71 + │ ├── smokesignal.rs # Main application server 72 + │ └── crypto.rs # Cryptographic key utilities 73 + ├── http/ # Web layer 74 + │ ├── errors/ # HTTP error handling 75 + │ ├── handle_*.rs # Route handlers 76 + │ ├── middleware_*.rs # Request middleware 77 + │ └── server.rs # Server configuration 78 + ├── storage/ # Data access layer 79 + │ ├── cache.rs # Redis caching 80 + │ ├── identity_profile.rs # User identity management 81 + │ ├── event.rs # Event data operations 82 + │ ├── oauth.rs # OAuth session management 83 + │ └── denylist.rs # Handle blocking 84 + ├── config.rs # Environment configuration 85 + ├── key_provider.rs # JWT key management 86 + ├── i18n.rs # Internationalization 87 + └── task_refresh_tokens.rs # Background token refresh 88 + ``` 89 + 90 + ### Key Patterns 91 + - **Layered Architecture**: HTTP → Business Logic → Data Access 92 + - **Module-based Organization**: Each handler in separate module with error types 93 + - **Dependency Injection**: Services passed through application context 94 + - **Template-driven UI**: Server-side rendering with optional reloading 95 + 96 + ### Database Schema 97 + - `identity_profiles` - User identity and preferences (formerly `handles`) 98 + - `oauth_requests` - OAuth flow state with PKCE and DPoP support 99 + - `oauth_sessions` - Active sessions with token management 100 + - `events` - Calendar events with location and timezone support 101 + - `rsvps` - Event responses and attendance tracking 102 + - `denylist` - Handle blocking for moderation 103 + 104 + ## Configuration Requirements 105 + 106 + ### Required Environment Variables 107 + - `HTTP_COOKIE_KEY` - 64-character hex key for session encryption 108 + - `DATABASE_URL` - PostgreSQL connection string 109 + - `REDIS_URL` - Redis/Valkey connection string 110 + - `EXTERNAL_BASE` - Public base URL (e.g., https://smokesignal.events) 111 + - `PLC_HOSTNAME` - AT Protocol PLC server hostname 112 + - `ADMIN_DIDS` - Comma-separated admin user DIDs 113 + - `DESTINATION_KEY` - Private key for signing authentication redirect tokens 114 + 115 + ### OAuth Backend Configuration 116 + - `OAUTH_BACKEND` - OAuth backend to use: "atprotocol" (default) or "aip" 117 + 118 + When `OAUTH_BACKEND=atprotocol`, the following variable is required: 119 + - `SIGNING_KEYS` - Path to JWK key set file for OAuth signing 120 + 121 + When `OAUTH_BACKEND=aip`, the following variables are required: 122 + - `AIP_HOSTNAME` - AIP OAuth server hostname 123 + - `AIP_CLIENT_ID` - AIP OAuth client ID 124 + - `AIP_CLIENT_SECRET` - AIP OAuth client secret 125 + 126 + ### Optional Environment Variables 127 + - `RUST_LOG` - Logging configuration (default: info) 128 + - `PORT` - Server port (default: 3000) 129 + - `BIND_ADDR` - Bind address (default: 0.0.0.0) 130 + 131 + ## Development Setup 132 + 133 + ### Using DevContainer (Recommended) 134 + The project includes a complete DevContainer setup with PostgreSQL, Redis, and all Rust dependencies. Use the provided `.devcontainer/` configuration for consistent development environment. 135 + 136 + ### Key Development Features 137 + - Template reloading in development mode (default `reload` feature) 138 + - Embedded templates for production builds (`embed` feature) 139 + - SQLx compile-time query checking with migrations 140 + - Cryptographic key generation and management utilities 141 + - Comprehensive error handling with user-friendly messages 142 + - Internationalization support with Fluent 143 + - Background token refresh for OAuth sessions 144 + - LRU caching for OAuth requests and DID documents 145 + - Graceful shutdown with task tracking and cancellation tokens 146 + 147 + ## Testing 148 + 149 + - Unit tests embedded in source files using `#[cfg(test)]` 150 + - Integration tests for database operations and HTTP handlers 151 + - Run with `cargo test` 152 + - Database tests require running PostgreSQL instance with test database 153 + - OAuth flow tests use fixture data for AT Protocol interactions 154 + - Template rendering tests ensure UI consistency 155 + 156 + ## Security Notes 157 + 158 + - Uses JOSE/JWT with P-256 ECDSA keys for OAuth signing 159 + - PKCE OAuth flow with state parameter for CSRF protection 160 + - DPoP (Demonstration of Proof-of-Possession) support for token binding 161 + - Encrypted session cookies with secure flags 162 + - Forbids unsafe Rust code (`#[forbid(unsafe_code)]`) 163 + - Input validation and HTML sanitization with Ammonia 164 + - Content Security Policy headers 165 + - Rate limiting and request timeouts 166 + 167 + ## Visibility 168 + 169 + Types and methods should have the lowest visibility necessary, defaulting to `private`. If `public` visibility is necessary, attempt to make it public to the crate only. Using completely public visibility should be a last resort. 170 + 171 + ## AT Protocol Integration 172 + 173 + - Full OAuth 2.0 implementation with AT Protocol services using external `atproto-*` crates 174 + - Support for both standard AT Protocol and AIP OAuth flows 175 + - DID-based identity management with PDS discovery 176 + - Integration with PLC (Public Ledger of Credentials) for DID resolution 177 + - Handle resolution and verification with LRU caching 178 + - Decentralized identity workflows with fallback mechanisms 179 + - AT Protocol lexicon support for events and RSVPs: 180 + - `events_smokesignal_calendar_event` 181 + - `events_smokesignal_calendar_rsvp` 182 + - `community_lexicon_calendar_event` 183 + - `community_lexicon_calendar_rsvp` 184 + - `community_lexicon_location` 185 + - Background token refresh for long-lived sessions (AT Protocol backend only) 186 + - Uses external AT Protocol libraries: 187 + - `atproto-identity` - Identity resolution and verification 188 + - `atproto-oauth` - OAuth flow implementation 189 + - `atproto-oauth-axum` - Axum integration for OAuth 190 + - `atproto-oauth-aip` - AIP OAuth implementation 191 + - `atproto-client` - AT Protocol client functionality 192 + - `atproto-record` - Record management
+9 -2
Cargo.toml
··· 1 1 [package] 2 2 name = "smokesignal" 3 3 version = "1.0.2" 4 - edition = "2021" 5 - rust-version = "1.83" 4 + edition = "2024" 5 + rust-version = "1.87" 6 6 authors = ["Nick Gerakines <nick.gerakines@gmail.com>"] 7 7 description = "An event and RSVP management application." 8 8 readme = "README.md" ··· 23 23 minijinja-embed = {version = "2.7"} 24 24 25 25 [dependencies] 26 + atproto-identity = { version = "0.9.3", features = ["lru", "axum", "zeroize"] } 27 + atproto-oauth = { version = "0.9.3", features = ["lru", "axum", "zeroize"] } 28 + atproto-oauth-axum = { version = "0.9.3", features = ["zeroize"] } 29 + atproto-oauth-aip = { version = "0.9.3", features = ["zeroize"] } 30 + atproto-client = { version = "0.9.3" } 31 + atproto-record = { version = "0.9.3" } 32 + 26 33 anyhow = "1.0" 27 34 async-trait = "0.1" 28 35 axum-extra = { version = "0.10", features = ["cookie", "cookie-private", "form", "query", "cookie-key-expansion", "typed-header", "typed-routing"] }
+27 -35
Dockerfile
··· 1 1 # syntax=docker/dockerfile:1.4 2 - FROM rust:latest AS build 2 + FROM rust:1.87-slim AS builder 3 3 4 - RUN cargo install sqlx-cli@0.8.2 --no-default-features --features postgres 5 - RUN cargo install sccache --version ^0.8 6 - ENV RUSTC_WRAPPER=sccache SCCACHE_DIR=/sccache 4 + RUN apt-get update && apt-get install -y \ 5 + pkg-config \ 6 + libssl-dev \ 7 + && rm -rf /var/lib/apt/lists/* 7 8 8 - RUN USER=root cargo new --bin smokesignal 9 - RUN mkdir -p /app/ 10 - WORKDIR /app/ 9 + WORKDIR /app 10 + COPY Cargo.toml build.rs ./ 11 11 12 - RUN --mount=type=bind,source=src,target=src \ 13 - --mount=type=bind,source=migrations,target=migrations \ 14 - --mount=type=bind,source=static,target=static \ 15 - --mount=type=bind,source=i18n,target=i18n \ 16 - --mount=type=bind,source=templates,target=templates \ 17 - --mount=type=bind,source=build.rs,target=build.rs \ 18 - --mount=type=bind,source=Cargo.toml,target=Cargo.toml \ 19 - --mount=type=bind,source=Cargo.lock,target=Cargo.lock \ 20 - --mount=type=cache,id=cargo-target,target=/app/target/ \ 21 - --mount=type=cache,id=sccache,target=$SCCACHE_DIR,sharing=locked \ 22 - --mount=type=cache,id=cargo-registry,target=/usr/local/cargo/registry/ \ 23 - <<EOF 24 - set -e 25 - cargo build --locked --release --bin smokesignal --target-dir . --no-default-features -F embed 26 - EOF 12 + ARG FEATURES=embed 13 + ARG TEMPLATES=./templates 14 + ARG STATIC=./static 15 + ARG GIT_HASH=0 16 + ENV GIT_HASH=$GIT_HASH 27 17 28 - RUN groupadd -g 1500 -r smokesignal && useradd -u 1501 -r -g smokesignal -d /var/lib/smokesignal -m smokesignal 29 - RUN chown -R smokesignal:smokesignal /app/release/smokesignal 18 + COPY src ./src 19 + COPY migrations ./migrations 20 + COPY static ./static 21 + COPY i18n ./i18n 22 + COPY ${TEMPLATES} ./templates 23 + COPY ${STATIC} ./static 30 24 31 - FROM gcr.io/distroless/cc 25 + RUN cargo build --release --bin smokesignal --no-default-features --features ${FEATURES} 26 + 27 + FROM gcr.io/distroless/cc-debian12 32 28 33 29 LABEL org.opencontainers.image.title="Smoke Signal" 34 30 LABEL org.opencontainers.image.description="An event and RSVP management application." ··· 37 33 LABEL org.opencontainers.image.source="https://tangled.sh/@smokesignal.events/smokesignal" 38 34 LABEL org.opencontainers.image.version="1.0.2" 39 35 40 - WORKDIR /var/lib/smokesignal 41 - USER smokesignal:smokesignal 36 + WORKDIR /app 37 + COPY --from=builder /app/target/release/smokesignal /app/smokesignal 38 + COPY --from=builder /app/static ./static 42 39 43 - COPY --from=build /etc/passwd /etc/passwd 44 - COPY --from=build /etc/group /etc/group 45 - COPY --from=build /app/release/smokesignal /var/lib/smokesignal/ 46 - COPY static /var/lib/smokesignal/static 47 - 48 - ENV HTTP_STATIC_PATH=/var/lib/smokesignal/static 49 - 40 + ENV HTTP_STATIC_PATH=/app/static 41 + ENV HTTP_PORT=8080 50 42 ENV RUST_LOG=info 51 43 ENV RUST_BACKTRACE=full 52 44 53 - ENTRYPOINT ["/var/lib/smokesignal/smokesignal"] 45 + ENTRYPOINT ["/app/smokesignal"]
+11 -1
build.rs
··· 1 1 fn main() { 2 2 #[cfg(feature = "embed")] 3 3 { 4 - minijinja_embed::embed_templates!("templates"); 4 + use std::env; 5 + use std::path::PathBuf; 6 + let template_path = if let Ok(value) = env::var("HTTP_TEMPLATE_PATH") { 7 + value.to_string() 8 + } else { 9 + PathBuf::from(env!("CARGO_MANIFEST_DIR")) 10 + .join("templates") 11 + .display() 12 + .to_string() 13 + }; 14 + minijinja_embed::embed_templates!(&template_path); 5 15 } 6 16 }
+64
docker-compose.yml
··· 1 + version: '3.8' 2 + 3 + services: 4 + postgres: 5 + image: postgres:17-alpine 6 + container_name: smokesignal_postgres 7 + environment: 8 + POSTGRES_DB: smokesignal_dev 9 + POSTGRES_USER: smokesignal 10 + POSTGRES_PASSWORD: smokesignal_dev_password 11 + POSTGRES_INITDB_ARGS: "--encoding=UTF8 --locale=C" 12 + ports: 13 + - "5436:5432" # Using 5433 to avoid conflicts with system PostgreSQL 14 + volumes: 15 + - postgres_data:/var/lib/postgresql/data 16 + - ./init-db.sql:/docker-entrypoint-initdb.d/init-db.sql:ro 17 + healthcheck: 18 + test: ["CMD-SHELL", "pg_isready -U smokesignal -d smokesignal_dev"] 19 + interval: 5s 20 + timeout: 5s 21 + retries: 5 22 + restart: unless-stopped 23 + 24 + minio: 25 + image: minio/minio:latest 26 + container_name: smokesignal_minio 27 + command: server /data --console-address ":9001" 28 + environment: 29 + MINIO_ROOT_USER: smokesignal_minio 30 + MINIO_ROOT_PASSWORD: smokesignal_dev_secret 31 + ports: 32 + - "9000:9000" # MinIO API 33 + - "9001:9001" # MinIO Console 34 + volumes: 35 + - minio_data:/data 36 + healthcheck: 37 + test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"] 38 + interval: 30s 39 + timeout: 20s 40 + retries: 3 41 + restart: unless-stopped 42 + 43 + createbuckets: 44 + image: minio/mc 45 + depends_on: 46 + - minio 47 + entrypoint: > 48 + /bin/sh -c " 49 + /usr/bin/mc mc alias set local http://localhost:9000 smokesignal_minio smokesignal_dev_secret; 50 + /usr/bin/mc mb myminio/smokesignal-badges; 51 + /usr/bin/mc policy set public myminio/smokesignal-badges; 52 + exit 0; 53 + " 54 + volumes: 55 + minio_data: 56 + driver: local 57 + postgres_data: 58 + driver: local 59 + pgadmin_data: 60 + driver: local 61 + 62 + networks: 63 + default: 64 + name: smokesignal_network
+40
migrations/20250619135328_convert_dpop_jwk_to_text.sql
··· 1 + -- Convert dpop_jwk from JSON to TEXT for oauth_requests and oauth_sessions tables 2 + -- This migration preserves existing data by converting JSON values to string format 3 + 4 + -- First, convert oauth_requests.dpop_jwk from JSON to TEXT 5 + ALTER TABLE oauth_requests 6 + ADD COLUMN dpop_jwk_text TEXT; 7 + 8 + -- Copy existing JSON data as string 9 + UPDATE oauth_requests 10 + SET dpop_jwk_text = dpop_jwk::text; 11 + 12 + -- Set NOT NULL constraint after data migration 13 + ALTER TABLE oauth_requests 14 + ALTER COLUMN dpop_jwk_text SET NOT NULL; 15 + 16 + -- Drop old JSON column and rename new column 17 + ALTER TABLE oauth_requests 18 + DROP COLUMN dpop_jwk; 19 + 20 + ALTER TABLE oauth_requests 21 + RENAME COLUMN dpop_jwk_text TO dpop_jwk; 22 + 23 + -- Now do the same for oauth_sessions.dpop_jwk 24 + ALTER TABLE oauth_sessions 25 + ADD COLUMN dpop_jwk_text TEXT; 26 + 27 + -- Copy existing JSON data as string 28 + UPDATE oauth_sessions 29 + SET dpop_jwk_text = dpop_jwk::text; 30 + 31 + -- Set NOT NULL constraint after data migration 32 + ALTER TABLE oauth_sessions 33 + ALTER COLUMN dpop_jwk_text SET NOT NULL; 34 + 35 + -- Drop old JSON column and rename new column 36 + ALTER TABLE oauth_sessions 37 + DROP COLUMN dpop_jwk; 38 + 39 + ALTER TABLE oauth_sessions 40 + RENAME COLUMN dpop_jwk_text TO dpop_jwk;
+2
migrations/20250619152115_rename_handles_to_identity_profiles.sql
··· 1 + -- Rename handles table to identity_profiles 2 + ALTER TABLE handles RENAME TO identity_profiles;
+14
migrations/20250620162528_did_documents_storage.sql
··· 1 + -- DID Document storage for atproto_identity::storage::DidDocumentStorage implementation 2 + CREATE TABLE did_documents ( 3 + did varchar(512) PRIMARY KEY, 4 + document_json JSON NOT NULL, 5 + created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), 6 + updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(), 7 + expires_at TIMESTAMP WITH TIME ZONE DEFAULT NULL 8 + ); 9 + 10 + -- Index for expiration cleanup 11 + CREATE INDEX idx_did_documents_expires_at ON did_documents(expires_at) WHERE expires_at IS NOT NULL; 12 + 13 + -- Index for updated_at for LRU-style cleanup 14 + CREATE INDEX idx_did_documents_updated_at ON did_documents(updated_at);
+21
migrations/20250620162624_atproto_oauth_requests_storage.sql
··· 1 + -- AT Protocol OAuth Request storage for atproto_oauth::storage::OAuthRequestStorage implementation 2 + CREATE TABLE atproto_oauth_requests ( 3 + oauth_state varchar(512) PRIMARY KEY, 4 + issuer varchar(512) NOT NULL, 5 + did varchar(512) NOT NULL, 6 + nonce varchar(512) NOT NULL, 7 + pkce_verifier varchar(512) NOT NULL, 8 + signing_public_key varchar(512) NOT NULL, 9 + dpop_private_key TEXT NOT NULL, 10 + created_at TIMESTAMP WITH TIME ZONE NOT NULL, 11 + expires_at TIMESTAMP WITH TIME ZONE NOT NULL 12 + ); 13 + 14 + -- Index for expiration cleanup 15 + CREATE INDEX idx_atproto_oauth_requests_expires_at ON atproto_oauth_requests(expires_at); 16 + 17 + -- Index for created_at for cleanup of old requests 18 + CREATE INDEX idx_atproto_oauth_requests_created_at ON atproto_oauth_requests(created_at); 19 + 20 + -- Index for DID lookups 21 + CREATE INDEX idx_atproto_oauth_requests_did ON atproto_oauth_requests(did);
+27 -20
src/atproto/auth.rs
··· 1 - use p256::SecretKey; 1 + use anyhow::Result; 2 2 3 - pub trait OAuthSessionProvider { 4 - fn oauth_access_token(&self) -> String; 5 - fn oauth_issuer(&self) -> String; 6 - fn dpop_secret(&self) -> SecretKey; 3 + use atproto_client::client::DPoPAuth; 4 + use atproto_identity::key::identify_key; 5 + 6 + use crate::storage::oauth::model::OAuthSession; 7 + use atproto_oauth_aip::{resources::oauth_protected_resource, workflow::session_exchange}; 8 + 9 + /// Create DPoPAuth directly from OAuthSession 10 + pub fn create_dpop_auth_from_oauth_session(oauth_session: &OAuthSession) -> Result<DPoPAuth> { 11 + let dpop_private_key_data = identify_key(&oauth_session.dpop_jwk)?; 12 + 13 + Ok(DPoPAuth { 14 + dpop_private_key_data, 15 + oauth_access_token: oauth_session.access_token.clone(), 16 + }) 7 17 } 8 18 9 - pub struct SimpleOAuthSessionProvider { 10 - pub access_token: String, 11 - pub issuer: String, 12 - pub dpop_secret: SecretKey, 13 - } 19 + pub async fn create_dpop_auth_from_aip_session( 20 + http_client: &reqwest::Client, 21 + aip_server: &str, 22 + access_token: &str, 23 + ) -> Result<DPoPAuth> { 24 + let protected_resource = oauth_protected_resource(http_client, aip_server).await?; 14 25 15 - impl OAuthSessionProvider for SimpleOAuthSessionProvider { 16 - fn oauth_access_token(&self) -> String { 17 - self.access_token.clone() 18 - } 26 + let session_response = session_exchange(http_client, &protected_resource, access_token).await?; 19 27 20 - fn oauth_issuer(&self) -> String { 21 - self.issuer.clone() 22 - } 28 + let dpop_private_key_data = identify_key(&session_response.dpop_key)?; 23 29 24 - fn dpop_secret(&self) -> SecretKey { 25 - self.dpop_secret.clone() 26 - } 30 + Ok(DPoPAuth { 31 + dpop_private_key_data, 32 + oauth_access_token: session_response.access_token.clone(), 33 + }) 27 34 }
-377
src/atproto/client.rs
··· 1 - use std::time::Duration; 2 - 3 - use anyhow::Result; 4 - use reqwest_chain::ChainMiddleware; 5 - use reqwest_middleware::ClientBuilder; 6 - use serde::{de::DeserializeOwned, Deserialize, Serialize}; 7 - use tracing::Instrument; 8 - 9 - // Standard timeout for all HTTP client operations 10 - const HTTP_CLIENT_TIMEOUT_SECS: u64 = 8; 11 - 12 - use crate::atproto::auth::OAuthSessionProvider; 13 - use crate::atproto::errors::ClientError; 14 - use crate::atproto::lexicon::com::atproto::repo::StrongRef; 15 - use crate::atproto::xrpc::SimpleError; 16 - use crate::http::handle_oauth_login::pkce_challenge; 17 - use crate::http::utils::URLBuilder; 18 - use crate::jose::jwt::{Claims, Header, JoseClaims}; 19 - use crate::jose::mint_token; 20 - use crate::oauth::dpop::DpopRetry; 21 - 22 - #[derive(Debug, Serialize, Deserialize, Clone)] 23 - #[serde(bound = "T: Serialize + DeserializeOwned")] 24 - pub struct CreateRecordRequest<T: DeserializeOwned> { 25 - pub repo: String, 26 - pub collection: String, 27 - 28 - #[serde(skip_serializing_if = "Option::is_none", default, rename = "rkey")] 29 - pub record_key: Option<String>, 30 - 31 - pub validate: bool, 32 - 33 - pub record: T, 34 - 35 - #[serde( 36 - skip_serializing_if = "Option::is_none", 37 - default, 38 - rename = "swapCommit" 39 - )] 40 - pub swap_commit: Option<String>, 41 - } 42 - 43 - #[derive(Debug, Serialize, Deserialize, Clone)] 44 - #[serde(bound = "T: Serialize + DeserializeOwned")] 45 - pub struct PutRecordRequest<T: DeserializeOwned> { 46 - pub repo: String, 47 - pub collection: String, 48 - 49 - #[serde(rename = "rkey")] 50 - pub record_key: String, 51 - 52 - pub validate: bool, 53 - 54 - pub record: T, 55 - 56 - #[serde( 57 - skip_serializing_if = "Option::is_none", 58 - default, 59 - rename = "swapCommit" 60 - )] 61 - pub swap_commit: Option<String>, 62 - 63 - #[serde( 64 - skip_serializing_if = "Option::is_none", 65 - default, 66 - rename = "swapRecord" 67 - )] 68 - pub swap_record: Option<String>, 69 - } 70 - 71 - #[derive(Debug, Serialize, Deserialize, Clone)] 72 - #[serde(untagged)] 73 - pub enum CreateRecordResponse { 74 - StrongRef(StrongRef), 75 - Error(SimpleError), 76 - } 77 - 78 - #[derive(Debug, Serialize, Deserialize, Clone)] 79 - #[serde(untagged)] 80 - pub enum PutRecordResponse { 81 - StrongRef(StrongRef), 82 - Error(SimpleError), 83 - } 84 - 85 - #[derive(Debug, Serialize, Deserialize, Clone)] 86 - pub struct ListRecordsParams { 87 - pub repo: String, 88 - pub collection: String, 89 - #[serde(skip_serializing_if = "Option::is_none")] 90 - pub limit: Option<u32>, 91 - #[serde(skip_serializing_if = "Option::is_none")] 92 - pub cursor: Option<String>, 93 - #[serde(skip_serializing_if = "Option::is_none")] 94 - pub reverse: Option<bool>, 95 - } 96 - 97 - #[derive(Debug, Serialize, Deserialize, Clone)] 98 - pub struct ListRecord<T> { 99 - pub uri: String, 100 - pub cid: String, 101 - pub value: T, 102 - } 103 - 104 - #[derive(Debug, Serialize, Deserialize, Clone)] 105 - pub struct ListRecordsResponse<T> { 106 - pub cursor: Option<String>, 107 - pub records: Vec<ListRecord<T>>, 108 - } 109 - 110 - pub struct OAuthPdsClient<'a> { 111 - pub http_client: &'a reqwest::Client, 112 - pub pds: &'a str, 113 - } 114 - 115 - impl OAuthPdsClient<'_> { 116 - pub async fn create_record<T: DeserializeOwned + Serialize>( 117 - &self, 118 - oauth_session: &impl OAuthSessionProvider, 119 - record: CreateRecordRequest<T>, 120 - ) -> Result<StrongRef, anyhow::Error> { 121 - let mut url_builder = URLBuilder::new(self.pds); 122 - url_builder.path("/xrpc/com.atproto.repo.createRecord"); 123 - let url = url_builder.build(); 124 - 125 - let dpop_secret_key = oauth_session.dpop_secret(); 126 - let dpop_public_key = dpop_secret_key.public_key(); 127 - let oauth_issuer = oauth_session.oauth_issuer(); 128 - let oauth_access_token = oauth_session.oauth_access_token(); 129 - 130 - let now = chrono::Utc::now(); 131 - 132 - let dpop_proof_header = Header { 133 - type_: Some("dpop+jwt".to_string()), 134 - algorithm: Some("ES256".to_string()), 135 - json_web_key: Some(dpop_public_key.to_jwk()), 136 - ..Default::default() 137 - }; 138 - 139 - let dpop_proof_claim = Claims::new(JoseClaims { 140 - issuer: Some(oauth_issuer.clone()), 141 - issued_at: Some(now.timestamp() as u64), 142 - expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64), 143 - json_web_token_id: Some(ulid::Ulid::new().to_string()), 144 - http_method: Some("POST".to_string()), 145 - http_uri: Some(url.clone()), 146 - auth: Some(pkce_challenge(&oauth_access_token)), 147 - 148 - ..Default::default() 149 - }); 150 - let dpop_proof_token = mint_token(&dpop_secret_key, &dpop_proof_header, &dpop_proof_claim)?; 151 - 152 - let dpop_retry = DpopRetry::new( 153 - dpop_proof_header.clone(), 154 - dpop_proof_claim.clone(), 155 - dpop_secret_key.clone(), 156 - ); 157 - 158 - let dpop_retry_client = ClientBuilder::new(self.http_client.clone()) 159 - .with(ChainMiddleware::new(dpop_retry.clone())) 160 - .build(); 161 - 162 - let http_response = dpop_retry_client 163 - .post(url) 164 - .header("Authorization", &format!("DPoP {}", oauth_access_token)) 165 - .header("DPoP", dpop_proof_token.as_str()) 166 - .json(&record) 167 - .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS)) 168 - .send() 169 - .instrument(tracing::info_span!("create_record")) 170 - .await?; 171 - 172 - tracing::info!( 173 - "create_record response status: {:?}", 174 - http_response.status() 175 - ); 176 - 177 - let create_record_respoonse = http_response.json::<CreateRecordResponse>().await; 178 - 179 - match create_record_respoonse { 180 - Ok(CreateRecordResponse::StrongRef(strong_ref)) => Ok(strong_ref), 181 - Ok(CreateRecordResponse::Error(err)) => { 182 - Err(ClientError::ServerError(err.error_message()).into()) 183 - } 184 - Err(err) => Err(ClientError::CreateRecordResponseFailure(err).into()), 185 - } 186 - } 187 - 188 - pub async fn put_record<T: DeserializeOwned + Serialize>( 189 - &self, 190 - oauth_session: &impl OAuthSessionProvider, 191 - record: PutRecordRequest<T>, 192 - ) -> Result<StrongRef, anyhow::Error> { 193 - let mut url_builder = URLBuilder::new(self.pds); 194 - url_builder.path("/xrpc/com.atproto.repo.putRecord"); 195 - let url = url_builder.build(); 196 - 197 - let dpop_secret_key = oauth_session.dpop_secret(); 198 - let dpop_public_key = dpop_secret_key.public_key(); 199 - let oauth_issuer = oauth_session.oauth_issuer(); 200 - let oauth_access_token = oauth_session.oauth_access_token(); 201 - 202 - let now = chrono::Utc::now(); 203 - 204 - let dpop_proof_header = Header { 205 - type_: Some("dpop+jwt".to_string()), 206 - algorithm: Some("ES256".to_string()), 207 - json_web_key: Some(dpop_public_key.to_jwk()), 208 - ..Default::default() 209 - }; 210 - 211 - let dpop_proof_claim = Claims::new(JoseClaims { 212 - issuer: Some(oauth_issuer.clone()), 213 - issued_at: Some(now.timestamp() as u64), 214 - expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64), 215 - json_web_token_id: Some(ulid::Ulid::new().to_string()), 216 - http_method: Some("POST".to_string()), 217 - http_uri: Some(url.clone()), 218 - auth: Some(pkce_challenge(&oauth_access_token)), 219 - 220 - ..Default::default() 221 - }); 222 - let dpop_proof_token = mint_token(&dpop_secret_key, &dpop_proof_header, &dpop_proof_claim)?; 223 - 224 - let dpop_retry = DpopRetry::new( 225 - dpop_proof_header.clone(), 226 - dpop_proof_claim.clone(), 227 - dpop_secret_key.clone(), 228 - ); 229 - 230 - let dpop_retry_client = ClientBuilder::new(self.http_client.clone()) 231 - .with(ChainMiddleware::new(dpop_retry.clone())) 232 - .build(); 233 - 234 - let http_response = dpop_retry_client 235 - .post(url) 236 - .header("Authorization", &format!("DPoP {}", oauth_access_token)) 237 - .header("DPoP", dpop_proof_token.as_str()) 238 - .json(&record) 239 - .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS)) 240 - .send() 241 - .instrument(tracing::info_span!("put_record")) 242 - .await?; 243 - 244 - tracing::info!("put_record response status: {:?}", http_response.status()); 245 - 246 - let put_record_respoonse = http_response.json::<PutRecordResponse>().await; 247 - 248 - match put_record_respoonse { 249 - Ok(PutRecordResponse::StrongRef(strong_ref)) => Ok(strong_ref), 250 - Ok(PutRecordResponse::Error(err)) => { 251 - Err(ClientError::ServerError(err.error_message()).into()) 252 - } 253 - Err(err) => Err(ClientError::PutRecordResponseFailure(err).into()), 254 - } 255 - } 256 - 257 - pub async fn list_records<T: DeserializeOwned>( 258 - &self, 259 - oauth_session: &impl OAuthSessionProvider, 260 - params: &ListRecordsParams, 261 - ) -> Result<ListRecordsResponse<T>, anyhow::Error> { 262 - let mut url_builder = URLBuilder::new(self.pds); 263 - url_builder.path("/xrpc/com.atproto.repo.listRecords"); 264 - 265 - // Add query parameters 266 - url_builder.param("repo", &params.repo); 267 - url_builder.param("collection", &params.collection); 268 - 269 - if let Some(limit) = params.limit { 270 - url_builder.param("limit", &limit.to_string()); 271 - } 272 - 273 - if let Some(cursor) = &params.cursor { 274 - url_builder.param("cursor", cursor); 275 - } 276 - 277 - if let Some(reverse) = params.reverse { 278 - url_builder.param("reverse", &reverse.to_string()); 279 - } 280 - 281 - let url = url_builder.build(); 282 - 283 - let dpop_secret_key = oauth_session.dpop_secret(); 284 - let dpop_public_key = dpop_secret_key.public_key(); 285 - let oauth_issuer = oauth_session.oauth_issuer(); 286 - let oauth_access_token = oauth_session.oauth_access_token(); 287 - 288 - let now = chrono::Utc::now(); 289 - 290 - let dpop_proof_header = Header { 291 - type_: Some("dpop+jwt".to_string()), 292 - algorithm: Some("ES256".to_string()), 293 - json_web_key: Some(dpop_public_key.to_jwk()), 294 - ..Default::default() 295 - }; 296 - 297 - let dpop_proof_claim = Claims::new(JoseClaims { 298 - issuer: Some(oauth_issuer.clone()), 299 - issued_at: Some(now.timestamp() as u64), 300 - expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64), 301 - json_web_token_id: Some(ulid::Ulid::new().to_string()), 302 - http_method: Some("GET".to_string()), 303 - http_uri: Some(url.clone()), 304 - auth: Some(pkce_challenge(&oauth_access_token)), 305 - 306 - ..Default::default() 307 - }); 308 - let dpop_proof_token = mint_token(&dpop_secret_key, &dpop_proof_header, &dpop_proof_claim)?; 309 - 310 - let dpop_retry = DpopRetry::new( 311 - dpop_proof_header.clone(), 312 - dpop_proof_claim.clone(), 313 - dpop_secret_key.clone(), 314 - ); 315 - 316 - let dpop_retry_client = ClientBuilder::new(self.http_client.clone()) 317 - .with(ChainMiddleware::new(dpop_retry.clone())) 318 - .build(); 319 - 320 - let http_response = dpop_retry_client 321 - .get(url) 322 - .header("Authorization", &format!("DPoP {}", oauth_access_token)) 323 - .header("DPoP", dpop_proof_token.as_str()) 324 - .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS)) 325 - .send() 326 - .instrument(tracing::span!(tracing::Level::INFO, "list_records")) 327 - .await?; 328 - 329 - let result = http_response.json::<ListRecordsResponse<T>>().await?; 330 - 331 - Ok(result) 332 - } 333 - } 334 - 335 - #[cfg(test)] 336 - mod tests { 337 - use std::collections::HashMap; 338 - 339 - use crate::atproto::lexicon::community::lexicon::calendar::event::Event; 340 - 341 - use super::*; 342 - use anyhow::Result; 343 - 344 - #[test] 345 - fn location_record() -> Result<()> { 346 - let test_json = r#"{"repo":"nick","collection":"stuff","validate":false,"record":{"$type":"community.lexicon.calendar.event","name":"My awesome event","description":"A really cool event.","createdAt":"2024-08-04T09:45:00.000Z"}}"#; 347 - 348 - { 349 - // Serialize bare 350 - assert_eq!( 351 - serde_json::to_string(&CreateRecordRequest { 352 - repo: "nick".to_string(), 353 - collection: "stuff".to_string(), 354 - validate: false, 355 - record_key: None, 356 - record: Event::Current { 357 - name: "My awesome event".to_string(), 358 - description: "A really cool event.".to_string(), 359 - created_at: "2024-08-04T09:45:00.000Z".parse().unwrap(), 360 - starts_at: None, 361 - ends_at: None, 362 - mode: None, 363 - status: None, 364 - locations: vec![], 365 - uris: vec![], 366 - extra: HashMap::default(), 367 - }, 368 - swap_commit: None, 369 - }) 370 - .unwrap(), 371 - test_json 372 - ); 373 - } 374 - 375 - Ok(()) 376 - } 377 - }
-53
src/atproto/datetime.rs
··· 1 - pub mod format { 2 - use chrono::{DateTime, SecondsFormat, Utc}; 3 - use serde::{self, Deserialize, Deserializer, Serializer}; 4 - 5 - pub fn serialize<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error> 6 - where 7 - S: Serializer, 8 - { 9 - let s = date.to_rfc3339_opts(SecondsFormat::Millis, true); 10 - serializer.serialize_str(&s) 11 - } 12 - 13 - pub fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error> 14 - where 15 - D: Deserializer<'de>, 16 - { 17 - let date_value = String::deserialize(deserializer)?; 18 - DateTime::parse_from_rfc3339(&date_value) 19 - .map(|v| v.with_timezone(&Utc)) 20 - .map_err(serde::de::Error::custom) 21 - } 22 - } 23 - 24 - pub mod optional_format { 25 - use chrono::{DateTime, SecondsFormat, Utc}; 26 - use serde::{self, Deserialize, Deserializer, Serializer}; 27 - 28 - pub fn serialize<S>(date: &Option<DateTime<Utc>>, serializer: S) -> Result<S::Ok, S::Error> 29 - where 30 - S: Serializer, 31 - { 32 - if date.is_none() { 33 - return serializer.serialize_none(); 34 - } 35 - let s = date.unwrap().to_rfc3339_opts(SecondsFormat::Millis, true); 36 - serializer.serialize_str(&s) 37 - } 38 - 39 - pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<DateTime<Utc>>, D::Error> 40 - where 41 - D: Deserializer<'de>, 42 - { 43 - let maybe_date_value: Option<String> = Option::deserialize(deserializer)?; 44 - if maybe_date_value.is_none() { 45 - return Ok(None); 46 - } 47 - let date_value = maybe_date_value.unwrap(); 48 - DateTime::parse_from_rfc3339(&date_value) 49 - .map(|v| v.with_timezone(&Utc)) 50 - .map_err(serde::de::Error::custom) 51 - .map(Some) 52 - } 53 - }
-52
src/atproto/errors.rs
··· 1 - use thiserror::Error; 2 - 3 - #[derive(Debug, Error)] 4 - pub enum ClientError { 5 - #[error("error-xrpc-client-1 Malformed PutRecord response: {0:?}")] 6 - PutRecordResponseFailure(reqwest::Error), 7 - 8 - #[error("error-xrpc-client-2 Malformed CreateRecord response: {0:?}")] 9 - CreateRecordResponseFailure(reqwest::Error), 10 - 11 - #[error("error-xrpc-client-3 XRPC error from server: {0}")] 12 - ServerError(String), 13 - 14 - #[error("error-xrpc-client-4 Invalid record format: {0}")] 15 - InvalidRecordFormat(String), 16 - } 17 - 18 - #[derive(Debug, Error)] 19 - pub enum UriError { 20 - #[error("error-uri-1 Invalid AT-URI: repository missing")] 21 - RepositoryMissing, 22 - 23 - #[error("error-uri-2 Invalid AT-URI: collection missing")] 24 - CollectionMissing, 25 - 26 - #[error("error-uri-3 Invalid AT-URI: rkey missing")] 27 - RkeyMissing, 28 - 29 - #[error("error-uri-4 Invalid AT-URI")] 30 - InvalidFormat, 31 - 32 - #[error("error-uri-5 Invalid AT-URI: repository contains invalid characters")] 33 - InvalidRepository, 34 - 35 - #[error("error-uri-6 Invalid AT-URI: collection contains invalid characters")] 36 - InvalidCollection, 37 - 38 - #[error("error-uri-7 Invalid AT-URI: rkey contains invalid characters")] 39 - InvalidRkey, 40 - 41 - #[error("error-uri-8 Invalid AT-URI: path traversal attempt detected")] 42 - PathTraversalAttempt, 43 - 44 - #[error("error-uri-9 Invalid AT-URI: repository too long (max 253 chars)")] 45 - RepositoryTooLong, 46 - 47 - #[error("error-uri-10 Invalid AT-URI: collection too long (max 128 chars)")] 48 - CollectionTooLong, 49 - 50 - #[error("error-uri-11 Invalid AT-URI: rkey too long (max 512 chars)")] 51 - RkeyTooLong, 52 - }
+2 -3
src/atproto/lexicon/community_lexicon_calendar_event.rs
··· 3 3 use std::collections::HashMap; 4 4 5 5 use crate::atproto::lexicon::community::lexicon::location::{Address, Fsq, Geo, Hthree}; 6 - use crate::atproto::{ 7 - datetime::format as datetime_format, datetime::optional_format as optional_datetime_format, 8 - }; 6 + use atproto_record::datetime::format as datetime_format; 7 + use atproto_record::datetime::optional_format as optional_datetime_format; 9 8 10 9 pub const NSID: &str = "community.lexicon.calendar.event"; 11 10
+1 -1
src/atproto/lexicon/community_lexicon_calendar_rsvp.rs
··· 1 1 use chrono::{DateTime, Utc}; 2 2 use serde::{Deserialize, Serialize}; 3 3 4 - use crate::atproto::datetime::format as datetime_format; 5 4 use crate::atproto::lexicon::com::atproto::repo::StrongRef; 5 + use atproto_record::datetime::format as datetime_format; 6 6 7 7 pub const NSID: &str = "community.lexicon.calendar.rsvp"; 8 8
+1 -1
src/atproto/lexicon/events_smokesignal_calendar_rsvp.rs
··· 1 1 use chrono::{DateTime, Utc}; 2 2 use serde::{Deserialize, Serialize}; 3 3 4 - use crate::atproto::datetime::optional_format as optional_datetime_format; 5 4 use crate::atproto::lexicon::com::atproto::repo::StrongRef; 5 + use atproto_record::datetime::optional_format as optional_datetime_format; 6 6 7 7 pub const NSID: &str = "events.smokesignal.calendar.rsvp"; 8 8
-5
src/atproto/mod.rs
··· 1 1 pub mod auth; 2 - pub mod client; 3 - pub mod datetime; 4 - pub mod errors; 5 2 pub mod lexicon; 6 - pub mod uri; 7 - pub mod xrpc;
-160
src/atproto/uri.rs
··· 1 - use anyhow::Result; 2 - 3 - use crate::{atproto::errors::UriError, validation::is_valid_hostname}; 4 - 5 - // Constants for maximum lengths 6 - const MAX_REPOSITORY_LENGTH: usize = 253; // DNS name length limit 7 - const MAX_COLLECTION_LENGTH: usize = 128; 8 - const MAX_RKEY_LENGTH: usize = 512; 9 - 10 - /// Validates a repository name for AT Protocol URIs 11 - /// 12 - /// Repository names should generally follow host name rules: 13 - /// - Alphanumeric characters, hyphens, and periods 14 - /// - No consecutive periods 15 - /// - Cannot start or end with period or hyphen 16 - fn is_valid_repository(repository: &str) -> bool { 17 - if repository.is_empty() || repository.len() > MAX_REPOSITORY_LENGTH { 18 - return false; 19 - } 20 - 21 - // Check for invalid characters 22 - if !repository 23 - .chars() 24 - .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == ':') 25 - { 26 - return false; 27 - } 28 - 29 - // TODO: If starts with "did:plc:" then validate encoded string and length 30 - if repository.starts_with("did:plc:") { 31 - return true; 32 - } 33 - 34 - // TODO: If starts with "did:web:" then validate hostname and parts 35 - if repository.starts_with("did:web:") { 36 - return true; 37 - } 38 - 39 - is_valid_hostname(repository) 40 - } 41 - 42 - /// Validates a collection name for AT Protocol URIs 43 - /// 44 - /// Collections should follow namespace-like naming: 45 - /// - Alphanumeric characters, hyphens, underscores, and periods 46 - /// - No path traversal sequences 47 - fn is_valid_collection(collection: &str) -> bool { 48 - if collection.is_empty() || collection.len() > MAX_COLLECTION_LENGTH { 49 - return false; 50 - } 51 - 52 - // Check for invalid characters 53 - if !collection 54 - .chars() 55 - .all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.') 56 - { 57 - return false; 58 - } 59 - 60 - // Check for path traversal attempts 61 - if collection.contains("../") || collection == ".." { 62 - return false; 63 - } 64 - 65 - true 66 - } 67 - 68 - /// Validates a record key (rkey) for AT Protocol URIs 69 - /// 70 - /// Record keys have more flexible rules but shouldn't contain characters 71 - /// that could cause problems in URLs or filesystems 72 - fn is_valid_rkey(rkey: &str) -> bool { 73 - if rkey.is_empty() || rkey.len() > MAX_RKEY_LENGTH { 74 - return false; 75 - } 76 - 77 - // Block characters that are particularly problematic 78 - if rkey.chars().any(|c| { 79 - c == '<' 80 - || c == '>' 81 - || c == '"' 82 - || c == '\'' 83 - || c == '`' 84 - || c == '\\' 85 - || c == '|' 86 - || c == '*' 87 - || c == '?' 88 - || c == '#' 89 - }) { 90 - return false; 91 - } 92 - 93 - // Check for path traversal attempts 94 - if rkey.contains("../") || rkey == ".." { 95 - return false; 96 - } 97 - 98 - true 99 - } 100 - 101 - /// Parses and validates an AT Protocol URI string into its components 102 - /// 103 - /// AT Protocol URIs follow the format: at://repository/collection/rkey 104 - /// This function validates each component for proper format and security 105 - pub fn parse_aturi(uri: &str) -> Result<(String, String, String)> { 106 - // Validate URI has the correct prefix 107 - if !uri.starts_with("at://") { 108 - return Err(UriError::InvalidFormat.into()); 109 - } 110 - 111 - let value = uri.strip_prefix("at://").unwrap(); // Safe because we checked above 112 - 113 - // Split the URI into components 114 - let mut components = value.split('/'); 115 - 116 - // Extract repository 117 - let repository = components.next().ok_or(UriError::RepositoryMissing)?; 118 - 119 - // Validate repository 120 - if repository.len() > MAX_REPOSITORY_LENGTH { 121 - return Err(UriError::RepositoryTooLong.into()); 122 - } 123 - if !is_valid_repository(repository) { 124 - return Err(UriError::InvalidRepository.into()); 125 - } 126 - 127 - // Extract collection 128 - let collection = components.next().ok_or(UriError::CollectionMissing)?; 129 - 130 - // Validate collection 131 - if collection.len() > MAX_COLLECTION_LENGTH { 132 - return Err(UriError::CollectionTooLong.into()); 133 - } 134 - if !is_valid_collection(collection) { 135 - return Err(UriError::InvalidCollection.into()); 136 - } 137 - 138 - // Extract record key 139 - let rkey = components.next().ok_or(UriError::RkeyMissing)?; 140 - 141 - // Validate record key 142 - if rkey.len() > MAX_RKEY_LENGTH { 143 - return Err(UriError::RkeyTooLong.into()); 144 - } 145 - if !is_valid_rkey(rkey) { 146 - return Err(UriError::InvalidRkey.into()); 147 - } 148 - 149 - // Check for any path traversal attempts 150 - if repository.contains("..") || collection.contains("..") || rkey.contains("..") { 151 - return Err(UriError::PathTraversalAttempt.into()); 152 - } 153 - 154 - // Return validated components 155 - Ok(( 156 - repository.to_string(), 157 - collection.to_string(), 158 - rkey.to_string(), 159 - )) 160 - }
-18
src/atproto/xrpc.rs
··· 1 - use serde::{Deserialize, Serialize}; 2 - 3 - #[derive(Debug, Clone, Serialize, Deserialize)] 4 - pub struct SimpleError { 5 - pub error: Option<String>, 6 - pub error_description: Option<String>, 7 - pub message: Option<String>, 8 - } 9 - 10 - impl SimpleError { 11 - pub fn error_message(&self) -> String { 12 - [&self.error, &self.error_description, &self.message] 13 - .iter() 14 - .filter_map(|v| (*v).clone()) 15 - .collect::<Vec<String>>() 16 - .join(": ") 17 - } 18 - }
+5 -16
src/bin/crypto.rs
··· 3 3 use base64::{engine::general_purpose, Engine as _}; 4 4 use rand::RngCore; 5 5 6 - use smokesignal::jose::jwk; 7 - 8 6 fn main() { 9 7 let mut rng = rand::thread_rng(); 10 8 11 - env::args().for_each(|arg| match arg.as_str() { 12 - "key" => { 13 - let mut key: [u8; 64] = [0; 64]; 14 - rng.fill_bytes(&mut key); 15 - let encoded: String = general_purpose::STANDARD_NO_PAD.encode(key); 16 - println!("{encoded}"); 17 - } 18 - "jwk" => { 19 - let ec_jwk = jwk::generate(); 20 - let serialized_value = 21 - serde_json::to_string_pretty(&ec_jwk).expect("failed to serialize ec jwk"); 22 - println!("{serialized_value}"); 23 - } 24 - _ => {} 9 + env::args().for_each(|arg| if arg.as_str() == "key" { 10 + let mut key: [u8; 64] = [0; 64]; 11 + rng.fill_bytes(&mut key); 12 + let encoded: String = general_purpose::STANDARD_NO_PAD.encode(key); 13 + println!("{encoded}"); 25 14 }); 26 15 }
-42
src/bin/resolve.rs
··· 1 - use std::env; 2 - 3 - use anyhow::Result; 4 - use smokesignal::config::{default_env, optional_env, version, CertificateBundles, DnsNameservers}; 5 - use tracing_subscriber::{layer::SubscriberExt as _, util::SubscriberInitExt as _}; 6 - 7 - #[tokio::main] 8 - async fn main() -> Result<()> { 9 - tracing_subscriber::registry() 10 - .with(tracing_subscriber::EnvFilter::new( 11 - std::env::var("RUST_LOG").unwrap_or_else(|_| "trace".into()), 12 - )) 13 - .with(tracing_subscriber::fmt::layer().pretty()) 14 - .init(); 15 - 16 - let certificate_bundles: CertificateBundles = optional_env("CERTIFICATE_BUNDLES").try_into()?; 17 - let default_user_agent = format!("smokesignal ({}; +https://smokesignal.events/)", version()?); 18 - let user_agent = default_env("USER_AGENT", &default_user_agent); 19 - let dns_nameservers: DnsNameservers = optional_env("DNS_NAMESERVERS").try_into()?; 20 - 21 - let mut client_builder = reqwest::Client::builder(); 22 - for ca_certificate in certificate_bundles.as_ref() { 23 - tracing::info!("Loading CA certificate: {:?}", ca_certificate); 24 - let cert = std::fs::read(ca_certificate)?; 25 - let cert = reqwest::Certificate::from_pem(&cert)?; 26 - client_builder = client_builder.add_root_certificate(cert); 27 - } 28 - 29 - client_builder = client_builder.user_agent(user_agent); 30 - let http_client = client_builder.build()?; 31 - 32 - // Initialize the DNS resolver with configuration from the app config 33 - let dns_resolver = smokesignal::resolve::create_resolver(dns_nameservers); 34 - 35 - for subject in env::args() { 36 - let resolved_did = 37 - smokesignal::resolve::resolve_subject(&http_client, &dns_resolver, &subject).await; 38 - tracing::info!(?resolved_did, ?subject, "resolved subject"); 39 - } 40 - 41 - Ok(()) 42 - }
+136 -13
src/bin/smokesignal.rs
··· 1 1 use anyhow::Result; 2 - use chrono::Duration; 2 + use atproto_identity::key::identify_key; 3 + use atproto_identity::resolve::{create_resolver, IdentityResolver, InnerIdentityResolver}; 4 + use atproto_oauth_axum::state::OAuthClientConfig; 3 5 use smokesignal::{ 4 6 http::{ 5 - context::{AppEngine, I18nContext, WebContext}, 7 + context::{AppEngine, WebContext}, 6 8 server::build_router, 7 9 }, 8 10 i18n::Locales, 9 - resolve::create_resolver, 10 - storage::cache::create_cache_pool, 11 - task_refresh_tokens::{RefreshTokensTask, RefreshTokensTaskConfig}, 11 + key_provider::SimpleKeyProvider, 12 + storage::{ 13 + atproto::{PostgresDidDocumentStorage, PostgresOAuthRequestStorage}, 14 + cache::create_cache_pool, 15 + }, 12 16 }; 17 + 18 + use chrono::Duration; 19 + use smokesignal::config::OAuthBackendConfig; 20 + use smokesignal::task_identity_refresh::{IdentityRefreshTask, IdentityRefreshTaskConfig}; 21 + use smokesignal::task_oauth_requests_cleanup::{ 22 + OAuthRequestsCleanupTask, OAuthRequestsCleanupTaskConfig, 23 + }; 24 + use smokesignal::task_refresh_tokens::{RefreshTokensTask, RefreshTokensTaskConfig}; 25 + 26 + use axum::extract::FromRef; 13 27 use sqlx::PgPool; 14 - use std::{env, str::FromStr}; 28 + use std::{collections::HashMap, env, str::FromStr, sync::Arc}; 15 29 use tokio::net::TcpListener; 16 30 use tokio::signal; 17 31 use tokio_util::{sync::CancellationToken, task::TaskTracker}; ··· 80 94 let jinja = reload_env::build_env(&config.external_base, &config.version); 81 95 82 96 // Initialize the DNS resolver with configuration from the app config 83 - let dns_resolver = create_resolver(config.dns_nameservers.clone()); 97 + let dns_resolver = create_resolver(config.dns_nameservers.as_ref()); 98 + 99 + // Initialize OAuth and identity resolution components 100 + let oauth_storage = PostgresOAuthRequestStorage::new_arc(pool.clone()); 101 + let document_storage = PostgresDidDocumentStorage::new_arc(pool.clone()); 102 + 103 + // Create a key provider populated with signing keys from OAuth backend config 104 + let key_provider_keys = 105 + if let OAuthBackendConfig::ATProtocol { signing_keys } = &config.oauth_backend { 106 + signing_keys 107 + .as_ref() 108 + .iter() 109 + .filter_map(|(key_id, private_key)| match identify_key(private_key) { 110 + Ok(key_data) => { 111 + tracing::info!(?key_id, "loaded signing key for key provider"); 112 + Some((key_id.clone(), key_data)) 113 + } 114 + Err(err) => { 115 + tracing::error!(?err, ?key_id, "failed to identify key for key provider"); 116 + None 117 + } 118 + }) 119 + .collect() 120 + } else { 121 + HashMap::new() // Empty for AIP backend 122 + }; 123 + let key_provider = Arc::new(SimpleKeyProvider::new(key_provider_keys)); 124 + 125 + // Create OAuth client config (only for AT Protocol backend) 126 + let oauth_client_config = OAuthClientConfig { 127 + client_id: config.external_base.clone(), 128 + jwks_uri: None, 129 + signing_keys: if let OAuthBackendConfig::ATProtocol { signing_keys } = &config.oauth_backend 130 + { 131 + signing_keys 132 + .as_ref() 133 + .iter() 134 + .filter_map(|value| { 135 + tracing::info!(?value, "signing key"); 136 + 137 + identify_key(value.1).ok() 138 + }) 139 + .collect() 140 + } else { 141 + Vec::new() // Empty for AIP backend 142 + }, 143 + redirect_uris: format!("{}/oauth/callback", config.external_base), 144 + client_name: Some("Smoke Signal".to_string()), 145 + client_uri: Some(config.external_base.clone()), 146 + logo_uri: Some(format!( 147 + "https://{}/logo-160x160x.png", 148 + config.external_base 149 + )), 150 + tos_uri: Some("https://docs.smokesignal.events/docs/about/terms/".to_string()), 151 + policy_uri: Some("https://docs.smokesignal.events/docs/about/privacy/".to_string()), 152 + scope: Some("atproto transition:generic transition:email".to_string()), 153 + }; 154 + 155 + // Create identity resolver using the provided method 156 + let identity_resolver = IdentityResolver(Arc::new(InnerIdentityResolver { 157 + dns_resolver, 158 + http_client: http_client.clone(), 159 + plc_hostname: config.plc_hostname.clone(), 160 + })); 84 161 85 162 let web_context = WebContext::new( 86 163 pool.clone(), ··· 88 165 AppEngine::from(jinja), 89 166 &http_client, 90 167 config.clone(), 91 - I18nContext::new(supported_languages, locales), 92 - dns_resolver, 168 + oauth_client_config, 169 + identity_resolver, 170 + key_provider, 171 + oauth_storage, 172 + document_storage.clone(), 173 + supported_languages, 174 + locales, 93 175 ); 94 176 95 177 let app = build_router(web_context.clone()); ··· 126 208 }); 127 209 } 128 210 129 - { 211 + // Only spawn refresh tokens task for AT Protocol OAuth backend 212 + if let OAuthBackendConfig::ATProtocol { signing_keys } = &config.oauth_backend { 130 213 let task_config = RefreshTokensTaskConfig { 131 214 sleep_interval: Duration::seconds(10), 132 215 worker_id: "dev".to_string(), 133 216 external_url_base: config.external_base.clone(), 134 - signing_keys: config.signing_keys.clone(), 135 - oauth_active_keys: config.oauth_active_keys.clone(), 217 + signing_keys: signing_keys.clone(), 136 218 }; 137 219 let task = RefreshTokensTask::new( 138 220 task_config, 139 221 http_client.clone(), 140 222 pool.clone(), 141 223 cache_pool.clone(), 224 + document_storage.clone(), 142 225 token.clone(), 143 226 ); 144 227 145 228 let inner_token = token.clone(); 146 229 tracker.spawn(async move { 147 230 if let Err(err) = task.run().await { 148 - tracing::error!("Database task failed: {}", err); 231 + tracing::error!("Refresh tokens task failed: {}", err); 232 + } 233 + inner_token.cancel(); 234 + }); 235 + } 236 + 237 + // Spawn OAuth requests cleanup task if enabled 238 + if config.enable_task_oauth_requests_cleanup { 239 + let cleanup_task_config = OAuthRequestsCleanupTaskConfig { 240 + sleep_interval: Duration::hours(1), // Run once per hour 241 + }; 242 + let cleanup_task = 243 + OAuthRequestsCleanupTask::new(cleanup_task_config, pool.clone(), token.clone()); 244 + 245 + let inner_token = token.clone(); 246 + tracker.spawn(async move { 247 + if let Err(err) = cleanup_task.run().await { 248 + tracing::error!("OAuth requests cleanup task failed: {}", err); 249 + } 250 + inner_token.cancel(); 251 + }); 252 + } 253 + 254 + // Spawn identity refresh task if enabled 255 + if config.enable_task_identity_refresh { 256 + let identity_refresh_config = IdentityRefreshTaskConfig { 257 + sleep_interval: Duration::hours(1), // Run once per hour 258 + worker_id: "dev".to_string(), 259 + }; 260 + let identity_refresh_task = IdentityRefreshTask::new( 261 + identity_refresh_config, 262 + pool.clone(), 263 + document_storage.clone(), 264 + atproto_identity::resolve::IdentityResolver::from_ref(&web_context), 265 + token.clone(), 266 + ); 267 + 268 + let inner_token = token.clone(); 269 + tracker.spawn(async move { 270 + if let Err(err) = identity_refresh_task.run().await { 271 + tracing::error!("Identity refresh task failed: {}", err); 149 272 } 150 273 inner_token.cancel(); 151 274 });
+111 -121
src/config.rs
··· 1 - use anyhow::Result; 1 + use anyhow::{Context, Result}; 2 + use atproto_identity::key::{identify_key, to_public, KeyData, KeyType}; 2 3 use axum_extra::extract::cookie::Key; 3 4 use base64::{engine::general_purpose, Engine as _}; 4 5 use ordermap::OrderMap; 5 - use p256::SecretKey; 6 - use rand::seq::SliceRandom; 7 6 8 7 use crate::config_errors::ConfigError; 9 - use crate::encoding_errors::EncodingError; 10 - use crate::jose::jwk::WrappedJsonWebKeySet; 11 8 12 9 #[derive(Clone)] 13 10 pub struct HttpPort(u16); ··· 18 15 #[derive(Clone)] 19 16 pub struct CertificateBundles(Vec<String>); 20 17 21 - #[derive(Clone)] 22 - pub struct SigningKeys(OrderMap<String, SecretKey>); 23 - 24 - #[derive(Clone)] 25 - pub struct OAuthActiveKeys(Vec<String>); 18 + #[derive(Clone, Debug)] 19 + pub struct SigningKeys(OrderMap<String, String>); 26 20 27 21 #[derive(Clone)] 28 22 pub struct AdminDIDs(Vec<String>); ··· 30 24 #[derive(Clone)] 31 25 pub struct DnsNameservers(Vec<std::net::IpAddr>); 32 26 27 + #[derive(Clone, Debug)] 28 + pub enum OAuthBackendConfig { 29 + ATProtocol { 30 + signing_keys: SigningKeys, 31 + }, 32 + AIP { 33 + hostname: String, 34 + client_id: String, 35 + client_secret: String, 36 + }, 37 + } 38 + 33 39 #[derive(Clone)] 34 40 pub struct Config { 35 41 pub version: String, ··· 41 47 pub user_agent: String, 42 48 pub database_url: String, 43 49 pub plc_hostname: String, 44 - pub signing_keys: SigningKeys, 45 - pub oauth_active_keys: OAuthActiveKeys, 46 - pub destination_key: SecretKey, 47 50 pub redis_url: String, 48 51 pub admin_dids: AdminDIDs, 49 52 pub dns_nameservers: DnsNameservers, 53 + pub oauth_backend: OAuthBackendConfig, 54 + pub destination_key_data: KeyData, 55 + pub enable_task_oauth_requests_cleanup: bool, 56 + pub enable_task_identity_refresh: bool, 50 57 } 51 58 52 59 impl Config { ··· 72 79 73 80 let database_url = default_env("DATABASE_URL", "sqlite://development.db"); 74 81 75 - let signing_keys: SigningKeys = 76 - require_env("SIGNING_KEYS").and_then(|value| value.try_into())?; 77 - 78 - let oauth_active_keys: OAuthActiveKeys = 79 - require_env("OAUTH_ACTIVE_KEYS").and_then(|value| value.try_into())?; 80 - 81 - let destination_key = require_env("DESTINATION_KEY").and_then(|value| { 82 - signing_keys 83 - .0 84 - .get(&value) 85 - .cloned() 86 - .ok_or(ConfigError::InvalidDestinationKey.into()) 87 - })?; 88 - 89 82 let redis_url = default_env("REDIS_URL", "redis://valkey:6379/0"); 90 83 91 84 let admin_dids: AdminDIDs = optional_env("ADMIN_DIDS").try_into()?; 92 85 93 86 let dns_nameservers: DnsNameservers = optional_env("DNS_NAMESERVERS").try_into()?; 94 87 88 + // Create OAuth backend configuration based on environment variables 89 + let oauth_backend_type = default_env("OAUTH_BACKEND", "pds"); 90 + let oauth_backend = match oauth_backend_type.to_lowercase().as_str() { 91 + "pds" => { 92 + let signing_keys: SigningKeys = 93 + require_env("SIGNING_KEYS").and_then(|value| value.try_into())?; 94 + OAuthBackendConfig::ATProtocol { signing_keys } 95 + } 96 + "aip" => { 97 + let hostname = 98 + require_env("AIP_HOSTNAME").map(|value| format!("https://{value}"))?; 99 + let client_id = require_env("AIP_CLIENT_ID")?; 100 + let client_secret = require_env("AIP_CLIENT_SECRET")?; 101 + OAuthBackendConfig::AIP { 102 + hostname, 103 + client_id, 104 + client_secret, 105 + } 106 + } 107 + _ => return Err(ConfigError::InvalidOAuthBackend(oauth_backend_type).into()), 108 + }; 109 + 110 + // Parse destination key for authentication redirects 111 + let destination_key_data = identify_key(&require_env("DESTINATION_KEY")?) 112 + .context("failed to parse DESTINATION_KEY")?; 113 + 114 + // Parse optional task enablement flags 115 + let enable_task_oauth_requests_cleanup = 116 + default_env("ENABLE_TASK_OAUTH_REQUESTS_CLEANUP", "true") 117 + .parse::<bool>() 118 + .unwrap_or(true); 119 + let enable_task_identity_refresh = default_env("ENABLE_TASK_IDENTITY_REFRESH", "true") 120 + .parse::<bool>() 121 + .unwrap_or(true); 122 + 95 123 Ok(Self { 96 124 version: version()?, 97 125 http_port, ··· 101 129 user_agent, 102 130 plc_hostname, 103 131 database_url, 104 - signing_keys, 105 - oauth_active_keys, 106 132 http_cookie_key, 107 - destination_key, 108 133 redis_url, 109 134 admin_dids, 110 135 dns_nameservers, 136 + oauth_backend, 137 + destination_key_data, 138 + enable_task_oauth_requests_cleanup, 139 + enable_task_identity_refresh, 111 140 }) 112 141 } 113 142 114 - pub fn select_oauth_signing_key(&self) -> Result<(String, SecretKey)> { 115 - let key_id = self 116 - .oauth_active_keys 117 - .as_ref() 118 - .choose(&mut rand::thread_rng()) 119 - .ok_or(ConfigError::SigningKeyNotFound)? 120 - .clone(); 121 - let signing_key = self 122 - .signing_keys 123 - .as_ref() 124 - .get(&key_id) 125 - .ok_or(ConfigError::SigningKeyNotFound)? 126 - .clone(); 127 - 128 - Ok((key_id, signing_key)) 143 + pub fn select_oauth_signing_key(&self) -> Result<(String, String)> { 144 + match &self.oauth_backend { 145 + OAuthBackendConfig::ATProtocol { signing_keys } => { 146 + let item = signing_keys.as_ref().first(); 147 + item.map(|(key, value)| (key.clone(), value.clone())) 148 + .ok_or(anyhow::anyhow!("signing keys is empty")) 149 + } 150 + OAuthBackendConfig::AIP { .. } => Err(anyhow::anyhow!( 151 + "signing keys not available for AIP OAuth backend" 152 + )), 153 + } 129 154 } 130 155 131 156 /// Check if a DID is in the admin allow list ··· 216 241 } 217 242 } 218 243 219 - impl AsRef<OrderMap<String, SecretKey>> for SigningKeys { 220 - fn as_ref(&self) -> &OrderMap<String, SecretKey> { 244 + impl AsRef<OrderMap<String, String>> for SigningKeys { 245 + fn as_ref(&self) -> &OrderMap<String, String> { 221 246 &self.0 222 247 } 223 248 } ··· 225 250 impl TryFrom<String> for SigningKeys { 226 251 type Error = anyhow::Error; 227 252 fn try_from(value: String) -> Result<Self, Self::Error> { 228 - let content = { 229 - if value.starts_with("/") { 230 - // Verify file exists before reading 231 - if !std::path::Path::new(&value).exists() { 232 - return Err(ConfigError::SigningKeysFileNotFound(value).into()); 233 - } 234 - std::fs::read(&value).map_err(ConfigError::ReadSigningKeysFailed)? 235 - } else { 236 - general_purpose::STANDARD 237 - .decode(&value) 238 - .map_err(EncodingError::Base64DecodingFailed)? 239 - } 240 - }; 241 - 242 - // Validate content is not empty 243 - if content.is_empty() { 244 - return Err(ConfigError::EmptySigningKeysFile.into()); 253 + // Allow empty signing keys for initial pass with in-memory storage 254 + if value.is_empty() { 255 + return Ok(Self(OrderMap::new())); 245 256 } 246 257 247 - // Parse JSON with proper error handling 248 - let jwks = serde_json::from_slice::<WrappedJsonWebKeySet>(&content) 249 - .map_err(ConfigError::ParseSigningKeysFailed)?; 258 + // Parse semicolon-separated DID method key strings 259 + let private_keys: Vec<&str> = value.split(';').filter(|s| !s.is_empty()).collect(); 250 260 251 - // Validate JWKS contains keys 252 - if jwks.keys.is_empty() { 253 - return Err(ConfigError::MissingKeysInJWKS.into()); 261 + if private_keys.is_empty() { 262 + return Ok(Self(OrderMap::new())); 254 263 } 255 264 256 - // Track keys that failed validation for better error reporting 257 - let mut validation_errors = Vec::new(); 265 + let mut signing_keys = OrderMap::new(); 258 266 259 - let signing_keys = jwks 260 - .keys 261 - .iter() 262 - .filter_map(|key| { 263 - // Validate key has required fields 264 - if key.kid.is_none() { 265 - validation_errors.push("Missing key ID (kid)".to_string()); 266 - return None; 267 + for private_key in private_keys { 268 + // Verify this is a private key using identify_key 269 + let key_data = match identify_key(private_key) { 270 + Ok(key_data) => key_data, 271 + Err(err) => { 272 + tracing::error!(?err, ?private_key, "failed to identify key"); 273 + continue; 267 274 } 275 + }; 268 276 269 - if let (Some(key_id), secret_key) = (key.kid.clone(), key.jwk.clone()) { 270 - // Verify the key_id format (should be a valid ULID) 271 - if ulid::Ulid::from_string(&key_id).is_err() { 272 - validation_errors.push(format!("Invalid key ID format: {}", key_id)); 273 - return None; 274 - } 277 + match key_data.0 { 278 + KeyType::P256Public | KeyType::K256Public => { 279 + tracing::error!(?private_key, "public key found"); 280 + continue; 281 + } 282 + _ => {} 283 + } 275 284 276 - // Validate the secret key 277 - match p256::SecretKey::from_jwk(&secret_key) { 278 - Ok(secret_key) => Some((key_id, secret_key)), 279 - Err(err) => { 280 - validation_errors.push(format!("Invalid key {}: {}", key_id, err)); 281 - None 282 - } 283 - } 284 - } else { 285 - None 285 + // Generate public key using to_public and use it as the identifier 286 + let public_key = match to_public(&key_data) { 287 + Ok(public_key) => public_key, 288 + Err(err) => { 289 + tracing::error!( 290 + ?err, 291 + ?private_key, 292 + "unable to derive public key from signing key" 293 + ); 294 + continue; 286 295 } 287 - }) 288 - .collect::<OrderMap<String, SecretKey>>(); 296 + }; 297 + 298 + // Store the original DID method key string as the value 299 + // This allows us to call identify_key again later when we need the KeyData 300 + signing_keys.insert(public_key.to_string(), private_key.to_string()); 301 + } 289 302 290 303 // Check if we have any valid keys 291 304 if signing_keys.is_empty() { 292 - if !validation_errors.is_empty() { 293 - return Err(ConfigError::SigningKeysValidationFailed(validation_errors).into()); 294 - } 295 305 return Err(ConfigError::EmptySigningKeys.into()); 296 306 } 297 307 298 308 Ok(Self(signing_keys)) 299 - } 300 - } 301 - 302 - impl AsRef<Vec<String>> for OAuthActiveKeys { 303 - fn as_ref(&self) -> &Vec<String> { 304 - &self.0 305 - } 306 - } 307 - 308 - impl TryFrom<String> for OAuthActiveKeys { 309 - type Error = anyhow::Error; 310 - fn try_from(value: String) -> Result<Self, Self::Error> { 311 - let values = value 312 - .split(';') 313 - .map(|s| s.to_string()) 314 - .collect::<Vec<String>>(); 315 - if values.is_empty() { 316 - return Err(ConfigError::EmptyOAuthActiveKeys.into()); 317 - } 318 - Ok(Self(values)) 319 309 } 320 310 } 321 311
+14
src/config_errors.rs
··· 125 125 /// that fail validation checks (such as having invalid format). 126 126 #[error("error-config-17 Signing keys validation failed: {0:?}")] 127 127 SigningKeysValidationFailed(Vec<String>), 128 + 129 + /// Error when AIP OAuth configuration is incomplete. 130 + /// 131 + /// This error occurs when oauth_backend is set to "aip" but 132 + /// required AIP configuration values are missing. 133 + #[error("error-config-18 When oauth_backend is 'aip', AIP_HOSTNAME, AIP_CLIENT_ID, and AIP_CLIENT_SECRET must all be set")] 134 + AipConfigurationIncomplete, 135 + 136 + /// Error when oauth_backend has an invalid value. 137 + /// 138 + /// This error occurs when the OAUTH_BACKEND environment variable 139 + /// contains a value other than "aip" or "pds". 140 + #[error("error-config-19 oauth_backend must be either 'aip' or 'pds', got: {0}")] 141 + InvalidOAuthBackend(String), 128 142 }
-228
src/did.rs
··· 1 - pub mod model { 2 - 3 - use serde::Deserialize; 4 - use serde_json::Value; 5 - use std::collections::HashMap; 6 - 7 - #[derive(Clone, Deserialize, Debug)] 8 - #[serde(rename_all = "camelCase")] 9 - pub struct Service { 10 - pub id: String, 11 - 12 - pub r#type: String, 13 - 14 - pub service_endpoint: String, 15 - } 16 - 17 - #[derive(Clone, Deserialize, Debug)] 18 - #[serde(tag = "type", rename_all = "camelCase")] 19 - pub enum VerificationMethod { 20 - Multikey { 21 - id: String, 22 - controller: String, 23 - public_key_multibase: String, 24 - }, 25 - 26 - #[serde(untagged)] 27 - Other { 28 - #[serde(flatten)] 29 - extra: HashMap<String, Value>, 30 - }, 31 - } 32 - 33 - #[derive(Clone, Deserialize, Debug)] 34 - #[serde(rename_all = "camelCase")] 35 - pub struct Document { 36 - pub id: String, 37 - pub also_known_as: Vec<String>, 38 - pub service: Vec<Service>, 39 - } 40 - 41 - impl Document { 42 - pub fn pds_endpoint(&self) -> Option<&str> { 43 - self.service 44 - .iter() 45 - .find(|service| service.r#type == "AtprotoPersonalDataServer") 46 - .map(|service| service.service_endpoint.as_str()) 47 - } 48 - 49 - pub fn primary_handle(&self) -> Option<&str> { 50 - self.also_known_as.first().map(|handle| { 51 - if let Some(trimmed) = handle.strip_prefix("at://") { 52 - trimmed 53 - } else { 54 - handle.as_str() 55 - } 56 - }) 57 - } 58 - } 59 - 60 - #[cfg(test)] 61 - mod tests { 62 - use crate::did::model::Document; 63 - 64 - #[test] 65 - fn test_deserialize() { 66 - let document = serde_json::from_str::<Document>( 67 - r##"{"@context":["https://www.w3.org/ns/did/v1","https://w3id.org/security/multikey/v1","https://w3id.org/security/suites/secp256k1-2019/v1"],"id":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","alsoKnownAs":["at://ngerakines.me","at://nick.gerakines.net","at://nick.thegem.city","https://github.com/ngerakines","https://ngerakines.me/","dns:ngerakines.me"],"verificationMethod":[{"id":"did:plc:cbkjy5n7bk3ax2wplmtjofq2#atproto","type":"Multikey","controller":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","publicKeyMultibase":"zQ3shXvCK2RyPrSLYQjBEw5CExZkUhJH3n1K2Mb9sC7JbvRMF"}],"service":[{"id":"#atproto_pds","type":"AtprotoPersonalDataServer","serviceEndpoint":"https://pds.cauda.cloud"}]}"##, 68 - ); 69 - assert!(document.is_ok()); 70 - 71 - let document = document.unwrap(); 72 - assert_eq!(document.id, "did:plc:cbkjy5n7bk3ax2wplmtjofq2"); 73 - } 74 - 75 - #[test] 76 - fn test_deserialize_unsupported_verification_method() { 77 - let documents = vec![ 78 - r##"{"@context":["https://www.w3.org/ns/did/v1","https://w3id.org/security/multikey/v1","https://w3id.org/security/suites/secp256k1-2019/v1"],"id":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","alsoKnownAs":["at://ngerakines.me","at://nick.gerakines.net","at://nick.thegem.city","https://github.com/ngerakines","https://ngerakines.me/","dns:ngerakines.me"],"verificationMethod":[{"id":"did:plc:cbkjy5n7bk3ax2wplmtjofq2#atproto","type":"Ed25519VerificationKey2020","controller":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","publicKeyMultibase":"zQ3shXvCK2RyPrSLYQjBEw5CExZkUhJH3n1K2Mb9sC7JbvRMF"}],"service":[{"id":"#atproto_pds","type":"AtprotoPersonalDataServer","serviceEndpoint":"https://pds.cauda.cloud"}]}"##, 79 - r##"{"@context":["https://www.w3.org/ns/did/v1","https://w3id.org/security/multikey/v1","https://w3id.org/security/suites/secp256k1-2019/v1"],"id":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","alsoKnownAs":["at://ngerakines.me","at://nick.gerakines.net","at://nick.thegem.city","https://github.com/ngerakines","https://ngerakines.me/","dns:ngerakines.me"],"verificationMethod":[{"id": "did:example:123#_Qq0UL2Fq651Q0Fjd6TvnYE-faHiOpRlPVQcY_-tA4A","type": "JsonWebKey2020","controller": "did:example:123","publicKeyJwk": {"crv": "Ed25519","x": "VCpo2LMLhn6iWku8MKvSLg2ZAoC-nlOyPVQaO3FxVeQ","kty": "OKP","kid": "_Qq0UL2Fq651Q0Fjd6TvnYE-faHiOpRlPVQcY_-tA4A"}}],"service":[{"id":"#atproto_pds","type":"AtprotoPersonalDataServer","serviceEndpoint":"https://pds.cauda.cloud"}]}"##, 80 - ]; 81 - for document in documents { 82 - let document = serde_json::from_str::<Document>(document); 83 - assert!(document.is_ok()); 84 - 85 - let document = document.unwrap(); 86 - assert_eq!(document.id, "did:plc:cbkjy5n7bk3ax2wplmtjofq2"); 87 - } 88 - } 89 - } 90 - } 91 - 92 - pub mod plc { 93 - use anyhow::Result; 94 - use thiserror::Error; 95 - 96 - use super::model::Document; 97 - 98 - /// Error types that can occur when working with PLC DIDs 99 - #[derive(Debug, Error)] 100 - pub enum PLCDIDError { 101 - /// Occurs when the HTTP request to fetch the DID document fails 102 - #[error("error-did-plc-1 HTTP request failed: {url} {error}")] 103 - HttpRequestFailed { 104 - /// The URL that was requested 105 - url: String, 106 - /// The underlying HTTP error 107 - error: reqwest::Error, 108 - }, 109 - 110 - /// Occurs when the DID document cannot be parsed from the HTTP response 111 - #[error("error-did-plc-2 Failed to parse DID document: {url} {error}")] 112 - DocumentParseFailed { 113 - /// The URL that was requested 114 - url: String, 115 - /// The underlying parse error 116 - error: reqwest::Error, 117 - }, 118 - } 119 - 120 - pub async fn query( 121 - http_client: &reqwest::Client, 122 - plc_hostname: &str, 123 - did: &str, 124 - ) -> Result<Document> { 125 - let url = format!("https://{}/{}", plc_hostname, did); 126 - 127 - http_client 128 - .get(&url) 129 - .send() 130 - .await 131 - .map_err(|error| PLCDIDError::HttpRequestFailed { 132 - url: url.clone(), 133 - error, 134 - })? 135 - .json::<Document>() 136 - .await 137 - .map_err(|error| PLCDIDError::DocumentParseFailed { url, error }) 138 - .map_err(Into::into) 139 - } 140 - } 141 - 142 - pub mod web { 143 - use anyhow::Result; 144 - use thiserror::Error; 145 - 146 - use super::model::Document; 147 - 148 - /// Error types that can occur when working with Web DIDs 149 - #[derive(Debug, Error)] 150 - pub enum WebDIDError { 151 - /// Occurs when the DID is missing the 'did:web:' prefix 152 - #[error("error-did-web-1 Invalid DID format: missing 'did:web:' prefix")] 153 - InvalidDIDPrefix, 154 - 155 - /// Occurs when the DID is missing a hostname component 156 - #[error("error-did-web-2 Invalid DID format: missing hostname component")] 157 - MissingHostname, 158 - 159 - /// Occurs when the HTTP request to fetch the DID document fails 160 - #[error("error-did-web-3 HTTP request failed: {url} {error}")] 161 - HttpRequestFailed { 162 - /// The URL that was requested 163 - url: String, 164 - /// The underlying HTTP error 165 - error: reqwest::Error, 166 - }, 167 - 168 - /// Occurs when the DID document cannot be parsed from the HTTP response 169 - #[error("error-did-web-4 Failed to parse DID document: {url} {error}")] 170 - DocumentParseFailed { 171 - /// The URL that was requested 172 - url: String, 173 - /// The underlying parse error 174 - error: reqwest::Error, 175 - }, 176 - } 177 - 178 - pub async fn query(http_client: &reqwest::Client, did: &str) -> Result<Document> { 179 - // Parse DID and extract hostname and path components 180 - let mut parts = did 181 - .strip_prefix("did:web:") 182 - .ok_or(WebDIDError::InvalidDIDPrefix)? 183 - .split(':') 184 - .collect::<Vec<&str>>(); 185 - 186 - let hostname = parts.pop().ok_or(WebDIDError::MissingHostname)?; 187 - 188 - // Construct URL based on whether path components exist 189 - let url = if parts.is_empty() { 190 - format!("https://{}/.well-known/did.json", hostname) 191 - } else { 192 - format!("https://{}/{}/did.json", hostname, parts.join("/")) 193 - }; 194 - 195 - // Fetch and parse document 196 - http_client 197 - .get(&url) 198 - .send() 199 - .await 200 - .map_err(|error| WebDIDError::HttpRequestFailed { 201 - url: url.clone(), 202 - error, 203 - })? 204 - .json::<Document>() 205 - .await 206 - .map_err(|error| WebDIDError::DocumentParseFailed { url, error }) 207 - .map_err(Into::into) 208 - } 209 - 210 - pub async fn query_hostname(http_client: &reqwest::Client, hostname: &str) -> Result<Document> { 211 - let url = format!("https://{}/.well-known/did.json", hostname); 212 - 213 - tracing::debug!(?url, "query_hostname"); 214 - 215 - http_client 216 - .get(&url) 217 - .send() 218 - .await 219 - .map_err(|error| WebDIDError::HttpRequestFailed { 220 - url: url.clone(), 221 - error, 222 - })? 223 - .json::<Document>() 224 - .await 225 - .map_err(|error| WebDIDError::DocumentParseFailed { url, error }) 226 - .map_err(Into::into) 227 - } 228 - }
-33
src/encoding.rs
··· 1 - use anyhow::Result; 2 - use base64::{engine::general_purpose, Engine as _}; 3 - use serde::{Deserialize, Serialize}; 4 - use std::borrow::Cow; 5 - 6 - use crate::encoding_errors::EncodingError; 7 - 8 - pub trait ToBase64 { 9 - fn to_base64(&self) -> Result<Cow<str>>; 10 - } 11 - 12 - impl<T: Serialize> ToBase64 for T { 13 - fn to_base64(&self) -> Result<Cow<str>> { 14 - let json_bytes = 15 - serde_json::to_vec(&self).map_err(EncodingError::JsonSerializationFailed)?; 16 - let encoded_json_bytes = general_purpose::URL_SAFE_NO_PAD.encode(json_bytes); 17 - Ok(Cow::Owned(encoded_json_bytes)) 18 - } 19 - } 20 - 21 - pub trait FromBase64: Sized { 22 - fn from_base64<Input: ?Sized + AsRef<[u8]>>(raw: &Input) -> Result<Self>; 23 - } 24 - 25 - impl<T: for<'de> Deserialize<'de> + Sized> FromBase64 for T { 26 - fn from_base64<Input: ?Sized + AsRef<[u8]>>(raw: &Input) -> Result<Self> { 27 - let content = general_purpose::URL_SAFE_NO_PAD 28 - .decode(raw) 29 - .map_err(EncodingError::Base64DecodingFailed)?; 30 - serde_json::from_slice(&content) 31 - .map_err(|err| EncodingError::JsonDeserializationFailed(err).into()) 32 - } 33 - }
-31
src/encoding_errors.rs
··· 1 - use thiserror::Error; 2 - 3 - /// Represents errors that can occur during data encoding and decoding operations. 4 - /// 5 - /// These errors relate to serialization, deserialization, encoding and decoding 6 - /// of data across various formats used throughout the application. 7 - #[derive(Debug, Error)] 8 - pub enum EncodingError { 9 - /// Error when JSON serialization fails. 10 - /// 11 - /// This error occurs when attempting to convert a Rust structure to 12 - /// JSON format fails, typically due to data that cannot be properly 13 - /// represented in JSON. 14 - #[error("error-encoding-1 JSON serialization failed: {0:?}")] 15 - JsonSerializationFailed(serde_json::Error), 16 - 17 - /// Error when Base64 decoding fails. 18 - /// 19 - /// This error occurs when attempting to decode Base64-encoded data 20 - /// that is malformed or contains invalid characters. 21 - #[error("error-encoding-2 Base64 decoding failed: {0:?}")] 22 - Base64DecodingFailed(base64::DecodeError), 23 - 24 - /// Error when JSON deserialization fails. 25 - /// 26 - /// This error occurs when attempting to parse JSON data into a Rust 27 - /// structure fails, typically due to missing fields, type mismatches, 28 - /// or malformed JSON. 29 - #[error("error-encoding-3 JSON deserialization failed: {0:?}")] 30 - JsonDeserializationFailed(serde_json::Error), 31 - }
+1 -1
src/http/cache_countries.rs
··· 4 4 5 5 static COUNTRY_CACHE: OnceCell<Arc<BTreeMap<String, String>>> = OnceCell::new(); 6 6 7 - pub fn cached_countries<'a>() -> Result<&'a Arc<BTreeMap<String, String>>> { 7 + pub(crate) fn cached_countries<'a>() -> Result<&'a Arc<BTreeMap<String, String>>> { 8 8 if COUNTRY_CACHE.get().is_none() { 9 9 let all_countries: BTreeMap<String, String> = BTreeMap::from_iter( 10 10 [
+73 -32
src/http/context.rs
··· 1 + use atproto_identity::axum::state::DidDocumentStorageExtractor; 2 + use atproto_identity::key::KeyProvider; 3 + use atproto_identity::resolve::IdentityResolver; 4 + use atproto_identity::storage::DidDocumentStorage; 5 + use atproto_oauth::storage::OAuthRequestStorage; 6 + use atproto_oauth_axum::state::OAuthClientConfig; 1 7 use axum::extract::FromRef; 2 8 use axum::{ 3 9 extract::FromRequestParts, ··· 7 13 use axum_extra::extract::Cached; 8 14 use axum_template::engine::Engine; 9 15 use cookie::Key; 10 - use hickory_resolver::TokioAsyncResolver; 11 16 use minijinja::context as template_context; 12 17 use std::{ops::Deref, sync::Arc}; 13 18 use unic_langid::LanguageIdentifier; ··· 26 31 http::middleware_auth::Auth, 27 32 http::middleware_i18n::Language, 28 33 i18n::Locales, 29 - storage::handle::model::Handle, 34 + storage::identity_profile::model::IdentityProfile, 30 35 storage::{CachePool, StoragePool}, 31 36 }; 32 37 33 38 #[cfg(feature = "embed")] 34 39 pub type AppEngine = Engine<Environment<'static>>; 35 40 36 - pub struct I18nContext { 37 - pub supported_languages: Vec<LanguageIdentifier>, 38 - pub locales: Locales, 41 + pub(crate) struct I18nContext { 42 + pub(crate) supported_languages: Vec<LanguageIdentifier>, 43 + pub(crate) locales: Locales, 39 44 } 40 45 41 46 pub struct InnerWebContext { ··· 44 49 pub pool: StoragePool, 45 50 pub cache_pool: CachePool, 46 51 pub config: Config, 47 - pub i18n_context: I18nContext, 48 - pub dns_resolver: hickory_resolver::TokioAsyncResolver, 52 + pub(crate) i18n_context: I18nContext, 53 + pub(crate) oauth_client_config: atproto_oauth_axum::state::OAuthClientConfig, 54 + pub(crate) identity_resolver: atproto_identity::resolve::IdentityResolver, 55 + pub(crate) key_provider: Arc<dyn atproto_identity::key::KeyProvider>, 56 + pub(crate) oauth_storage: Arc<dyn atproto_oauth::storage::OAuthRequestStorage>, 57 + pub(crate) document_storage: Arc<dyn atproto_identity::storage::DidDocumentStorage>, 49 58 } 50 59 51 60 #[derive(Clone, FromRef)] ··· 60 69 } 61 70 62 71 impl WebContext { 72 + #[allow(clippy::too_many_arguments)] 63 73 pub fn new( 64 74 pool: StoragePool, 65 75 cache_pool: CachePool, 66 76 engine: AppEngine, 67 77 http_client: &reqwest::Client, 68 78 config: Config, 69 - i18n_context: I18nContext, 70 - dns_resolver: TokioAsyncResolver, 79 + oauth_client_config: OAuthClientConfig, 80 + identity_resolver: IdentityResolver, 81 + key_provider: Arc<dyn KeyProvider>, 82 + oauth_storage: Arc<dyn OAuthRequestStorage>, 83 + document_storage: Arc<dyn DidDocumentStorage>, 84 + supported_languages: Vec<LanguageIdentifier>, 85 + locales: Locales, 71 86 ) -> Self { 72 87 Self(Arc::new(InnerWebContext { 73 88 pool, ··· 75 90 engine, 76 91 http_client: http_client.clone(), 77 92 config, 78 - i18n_context, 79 - dns_resolver, 93 + i18n_context: I18nContext { 94 + supported_languages, 95 + locales, 96 + }, 97 + oauth_client_config, 98 + identity_resolver, 99 + key_provider, 100 + oauth_storage, 101 + document_storage, 80 102 })) 81 103 } 82 104 } 83 105 84 - impl I18nContext { 85 - pub fn new(supported_languages: Vec<LanguageIdentifier>, locales: Locales) -> Self { 86 - Self { 87 - supported_languages, 88 - locales, 89 - } 106 + impl FromRef<WebContext> for Key { 107 + fn from_ref(context: &WebContext) -> Self { 108 + context.0.config.http_cookie_key.as_ref().clone() 109 + } 110 + } 111 + 112 + impl FromRef<WebContext> for IdentityResolver { 113 + fn from_ref(context: &WebContext) -> Self { 114 + context.0.identity_resolver.clone() 90 115 } 91 116 } 92 117 93 - impl FromRef<WebContext> for Key { 118 + impl FromRef<WebContext> for Arc<dyn KeyProvider> { 94 119 fn from_ref(context: &WebContext) -> Self { 95 - context.0.config.http_cookie_key.as_ref().clone() 120 + context.0.key_provider.clone() 121 + } 122 + } 123 + 124 + impl FromRef<WebContext> for OAuthClientConfig { 125 + fn from_ref(context: &WebContext) -> Self { 126 + context.0.oauth_client_config.clone() 127 + } 128 + } 129 + 130 + impl FromRef<WebContext> for DidDocumentStorageExtractor { 131 + fn from_ref(context: &WebContext) -> Self { 132 + DidDocumentStorageExtractor(context.0.document_storage.clone()) 133 + } 134 + } 135 + 136 + impl FromRef<WebContext> for Arc<dyn KeyProvider + Send + Sync> { 137 + fn from_ref(context: &WebContext) -> Self { 138 + context.0.key_provider.clone() 96 139 } 97 140 } 98 141 99 142 // New structs for reducing handler function arguments 100 143 101 144 /// A context struct specifically for admin handlers 102 - pub struct AdminRequestContext { 103 - pub web_context: WebContext, 104 - pub language: Language, 105 - pub admin_handle: Handle, 106 - pub auth: Auth, 145 + pub(crate) struct AdminRequestContext { 146 + pub(crate) web_context: WebContext, 147 + pub(crate) language: Language, 148 + pub(crate) admin_handle: IdentityProfile, 107 149 } 108 150 109 151 impl<S> FromRequestParts<S> for AdminRequestContext ··· 129 171 web_context, 130 172 language, 131 173 admin_handle, 132 - auth: cached_auth.0, 133 174 }) 134 175 } 135 176 } 136 177 137 178 /// Helper function to create standard template context for admin views 138 - pub fn admin_template_context( 179 + pub(crate) fn admin_template_context( 139 180 ctx: &AdminRequestContext, 140 181 canonical_url: &str, 141 182 ) -> minijinja::value::Value { ··· 147 188 } 148 189 149 190 /// A context struct for regular authenticated user handlers 150 - pub struct UserRequestContext { 151 - pub web_context: WebContext, 152 - pub language: Language, 153 - pub current_handle: Option<Handle>, 154 - pub auth: Auth, 191 + pub(crate) struct UserRequestContext { 192 + pub(crate) web_context: WebContext, 193 + pub(crate) language: Language, 194 + pub(crate) current_handle: Option<IdentityProfile>, 195 + pub(crate) auth: Auth, 155 196 } 156 197 157 198 impl<S> FromRequestParts<S> for UserRequestContext ··· 170 211 Ok(Self { 171 212 web_context, 172 213 language, 173 - current_handle: cached_auth.0 .0.clone(), 214 + current_handle: cached_auth.0.profile().cloned(), 174 215 auth: cached_auth.0, 175 216 }) 176 217 }
+2 -2
src/http/errors/admin_errors.rs
··· 3 3 /// These errors relate to the process of importing RSVP data into the system 4 4 /// by administrators, typically during data migration or recovery. 5 5 #[derive(Debug, Error)] 6 - pub enum AdminImportRsvpError { 6 + pub(crate) enum AdminImportRsvpError { 7 7 /// Error when an RSVP cannot be inserted during import. 8 8 /// 9 9 /// This error occurs when attempting to insert an imported RSVP into ··· 16 16 /// These errors relate to the process of importing event data into the system 17 17 /// by administrators, typically during data migration or recovery. 18 18 #[derive(Debug, Error)] 19 - pub enum AdminImportEventError { 19 + pub(crate) enum AdminImportEventError { 20 20 /// Error when an event cannot be inserted during import. 21 21 /// 22 22 /// This error occurs when attempting to insert an imported event into
+1 -8
src/http/errors/common_error.rs
··· 2 2 3 3 /// Represents common errors that can occur across various HTTP handlers. 4 4 #[derive(Debug, Error)] 5 - pub enum CommonError { 5 + pub(crate) enum CommonError { 6 6 /// Error when a handle slug is invalid. 7 7 /// 8 8 /// This error occurs when a URL contains a handle slug that doesn't conform ··· 30 30 /// that is needed to complete the operation. 31 31 #[error("error-common-4 Required field not provided")] 32 32 FieldRequired, 33 - 34 - /// Error when data has an invalid format or is corrupted. 35 - /// 36 - /// This error occurs when input data doesn't match the expected format 37 - /// or appears to be corrupted or tampered with. 38 - #[error("error-common-5 Invalid format or corrupted data")] 39 - InvalidFormat, 40 33 41 34 /// Error when an AT Protocol URI has an invalid format. 42 35 ///
+1 -22
src/http/errors/create_event_errors.rs
··· 5 5 /// These errors are typically triggered during validation of user-submitted 6 6 /// event creation forms. 7 7 #[derive(Debug, Error)] 8 - pub enum CreateEventError { 8 + pub(crate) enum CreateEventError { 9 9 /// Error when the event name is not provided. 10 10 /// 11 11 /// This error occurs when a user attempts to create an event without ··· 19 19 /// specifying a description, which is a required field. 20 20 #[error("error-create-event-2 Description not set")] 21 21 DescriptionNotSet, 22 - 23 - /// Error when the event dates are invalid. 24 - /// 25 - /// This error occurs when the provided event dates are invalid, such as when 26 - /// the end date is before the start date, or dates are outside allowed ranges. 27 - #[error("error-create-event-3 Invalid event dates")] 28 - InvalidEventDates, 29 - 30 - /// Error when an invalid event mode is specified. 31 - /// 32 - /// This error occurs when the provided event mode doesn't match one of the 33 - /// supported values (e.g., "in-person", "online", "hybrid"). 34 - #[error("error-create-event-4 Invalid event mode")] 35 - InvalidEventMode, 36 - 37 - /// Error when an invalid event status is specified. 38 - /// 39 - /// This error occurs when the provided event status doesn't match one of the 40 - /// supported values (e.g., "confirmed", "tentative", "cancelled"). 41 - #[error("error-create-event-5 Invalid event status")] 42 - InvalidEventStatus, 43 22 }
+1 -15
src/http/errors/edit_event_error.rs
··· 5 5 /// These errors typically happen when users attempt to modify existing events 6 6 /// and encounter authorization, validation, or type compatibility issues. 7 7 #[derive(Debug, Error)] 8 - pub enum EditEventError { 9 - /// Error when an invalid handle slug is provided. 10 - /// 11 - /// This error occurs when attempting to edit an event with a handle slug 12 - /// that is not properly formatted or does not exist in the system. 13 - #[error("error-edit-event-1 Invalid handle slug")] 14 - InvalidHandleSlug, 15 - 8 + pub(crate) enum EditEventError { 16 9 /// Error when a user is not authorized to edit an event. 17 10 /// 18 11 /// This error occurs when a user attempts to edit an event that they ··· 44 37 /// that has a location type that is not supported for editing through the web interface. 45 38 #[error("error-edit-event-5 Cannot edit locations: Event has unsupported location type")] 46 39 UnsupportedLocationType, 47 - 48 - /// Error when attempting to edit location data on an event that has no locations. 49 - /// 50 - /// This error occurs when a user attempts to add location information to an event 51 - /// that was not created with location information. 52 - #[error("error-edit-event-6 Cannot edit locations: Event has no locations")] 53 - NoLocationsPresent, 54 40 }
+1 -1
src/http/errors/event_view_errors.rs
··· 5 5 /// These errors typically happen when retrieving and displaying event data 6 6 /// to users, including data validation and enhancement issues. 7 7 #[derive(Debug, Error)] 8 - pub enum EventViewError { 8 + pub(crate) enum EventViewError { 9 9 /// Error when an invalid collection is specified. 10 10 /// 11 11 /// This error occurs when an event view request specifies a collection
+1 -1
src/http/errors/import_error.rs
··· 5 5 /// These errors typically happen when attempting to import events and RSVPs 6 6 /// from different sources, including community and Smokesignal systems. 7 7 #[derive(Debug, Error)] 8 - pub enum ImportError { 8 + pub(crate) enum ImportError { 9 9 /// Error when listing community events fails. 10 10 /// 11 11 /// This error occurs when attempting to retrieve a list of community
+2 -6
src/http/errors/login_error.rs
··· 5 5 /// These errors typically happen during the authentication process when users 6 6 /// are logging in to the application, including OAuth flows and DID validation. 7 7 #[derive(Debug, Error)] 8 - pub enum LoginError { 9 - /// Error when a DID document does not contain a handle. 10 - /// 11 - /// This error occurs during authentication when the user's DID document 12 - /// is retrieved but does not contain a required handle identifier. 13 - #[error("error-login-1 DID document does not contain a handle")] 8 + pub(crate) enum LoginError { 9 + #[error("error-login-1 DID document does not contain a handle identifier")] 14 10 NoHandle, 15 11 16 12 /// Error when a DID document does not contain a PDS endpoint.
+4 -4
src/http/errors/middleware_errors.rs
··· 9 9 /// These errors are related to the serialization and deserialization of 10 10 /// web session data used for maintaining user authentication state. 11 11 #[derive(Debug, Error)] 12 - pub enum WebSessionError { 12 + pub(crate) enum WebSessionError { 13 13 /// Error when web session deserialization fails. 14 14 /// 15 15 /// This error occurs when attempting to deserialize a web session from JSON ··· 30 30 /// These errors typically happen in the authentication middleware layer when 31 31 /// processing requests, including cryptographic operations and session validation. 32 32 #[derive(Debug, Error)] 33 - pub enum AuthMiddlewareError { 33 + pub(crate) enum AuthMiddlewareError { 34 34 /// Error when content signing fails. 35 35 /// 36 36 /// This error occurs when the authentication middleware attempts to 37 37 /// cryptographically sign content but the operation fails. 38 38 #[error("error-authmiddleware-1 Unable to sign content: {0:?}")] 39 - SigningFailed(p256::ecdsa::Error), 39 + SigningFailed(anyhow::Error), 40 40 } 41 41 42 42 #[derive(Debug, Error)] 43 - pub enum MiddlewareAuthError { 43 + pub(crate) enum MiddlewareAuthError { 44 44 #[error("error-middleware-auth-1 Access Denied: {0}")] 45 45 AccessDenied(String), 46 46
+1 -8
src/http/errors/migrate_event_error.rs
··· 5 5 /// These errors typically happen when attempting to migrate events between 6 6 /// different formats or systems, such as from smokesignal to community events. 7 7 #[derive(Debug, Error)] 8 - pub enum MigrateEventError { 9 - /// Error when an invalid handle slug is provided. 10 - /// 11 - /// This error occurs when attempting to migrate an event with a handle slug 12 - /// that is not properly formatted or does not exist in the system. 13 - #[error("error-migrate-event-1 Invalid handle slug")] 14 - InvalidHandleSlug, 15 - 8 + pub(crate) enum MigrateEventError { 16 9 /// Error when a user is not authorized to migrate an event. 17 10 /// 18 11 /// This error occurs when a user attempts to migrate an event that they
+1 -1
src/http/errors/migrate_rsvp_error.rs
··· 5 5 /// These errors relate to the migration or conversion of RSVP data 6 6 /// between different systems, formats, or versions. 7 7 #[derive(Debug, Error)] 8 - pub enum MigrateRsvpError { 8 + pub(crate) enum MigrateRsvpError { 9 9 /// Error when an invalid RSVP status is provided during migration. 10 10 /// 11 11 /// This error occurs when attempting to migrate an RSVP with a status
+14 -14
src/http/errors/mod.rs
··· 14 14 pub mod view_event_error; 15 15 pub mod web_error; 16 16 17 - pub use admin_errors::{AdminImportEventError, AdminImportRsvpError}; 18 - pub use common_error::CommonError; 19 - pub use create_event_errors::CreateEventError; 20 - pub use edit_event_error::EditEventError; 21 - pub use event_view_errors::EventViewError; 22 - pub use import_error::ImportError; 23 - pub use login_error::LoginError; 24 - pub use middleware_errors::{AuthMiddlewareError, WebSessionError}; 25 - pub use migrate_event_error::MigrateEventError; 26 - pub use migrate_rsvp_error::MigrateRsvpError; 27 - pub use rsvp_error::RSVPError; 28 - pub use url_error::UrlError; 29 - pub use view_event_error::ViewEventError; 30 - pub use web_error::WebError; 17 + pub(crate) use admin_errors::{AdminImportEventError, AdminImportRsvpError}; 18 + pub(crate) use common_error::CommonError; 19 + pub(crate) use create_event_errors::CreateEventError; 20 + pub(crate) use edit_event_error::EditEventError; 21 + pub(crate) use event_view_errors::EventViewError; 22 + pub(crate) use import_error::ImportError; 23 + pub(crate) use login_error::LoginError; 24 + pub(crate) use middleware_errors::{AuthMiddlewareError, WebSessionError}; 25 + pub(crate) use migrate_event_error::MigrateEventError; 26 + pub(crate) use migrate_rsvp_error::MigrateRsvpError; 27 + pub(crate) use rsvp_error::RSVPError; 28 + pub(crate) use url_error::UrlError; 29 + pub(crate) use view_event_error::ViewEventError; 30 + pub(crate) use web_error::WebError;
+1 -1
src/http/errors/rsvp_error.rs
··· 5 5 /// These errors relate to the handling of event RSVPs, such as 6 6 /// when users attempt to respond to event invitations. 7 7 #[derive(Debug, Error)] 8 - pub enum RSVPError { 8 + pub(crate) enum RSVPError { 9 9 /// Error when an RSVP cannot be found. 10 10 /// 11 11 /// This error occurs when attempting to retrieve or modify an RSVP
+1 -1
src/http/errors/url_error.rs
··· 5 5 /// These errors typically happen when validating or processing URLs in the system, 6 6 /// including checking for supported collection types in URL paths. 7 7 #[derive(Debug, Error)] 8 - pub enum UrlError { 8 + pub(crate) enum UrlError { 9 9 /// Error when an unsupported collection type is specified in a URL. 10 10 /// 11 11 /// This error occurs when a URL contains a collection type that is not
+1 -8
src/http/errors/view_event_error.rs
··· 5 5 /// These errors typically happen when retrieving or displaying specific 6 6 /// event data, including lookup failures and data enhancement issues. 7 7 #[derive(Debug, Error)] 8 - pub enum ViewEventError { 8 + pub(crate) enum ViewEventError { 9 9 /// Error when a requested event cannot be found. 10 10 /// 11 11 /// This error occurs when attempting to view an event that doesn't ··· 19 19 /// and the fallback method also fails to retrieve the event. 20 20 #[error("error-view-event-2 Failed to get event from fallback: {0}")] 21 21 FallbackFailed(String), 22 - 23 - /// Error when fetching event details fails. 24 - /// 25 - /// This error occurs when the system fails to retrieve additional 26 - /// details for an event, such as RSVP counts or related data. 27 - #[error("error-view-event-3 Failed to fetch event details: {0}")] 28 - FetchEventDetailsFailed(String), 29 22 }
+2 -30
src/http/errors/web_error.rs
··· 35 35 /// and error code, while a few web-specific errors have their own error code format: 36 36 /// `error-web-<number> <message>: <details>` 37 37 #[derive(Debug, Error)] 38 - pub enum WebError { 38 + pub(crate) enum WebError { 39 39 /// Error when authentication middleware fails. 40 40 /// 41 41 /// This error occurs when there are issues with verifying a user's identity ··· 112 112 /// This error occurs during RSVP operations such as creation, updating, 113 113 /// or retrieval of RSVPs. 114 114 #[error(transparent)] 115 - RSVP(#[from] RSVPError), 115 + Rsvp(#[from] RSVPError), 116 116 117 117 /// Cache operation errors. 118 118 /// ··· 121 121 #[error(transparent)] 122 122 Cache(#[from] crate::storage::errors::CacheError), 123 123 124 - /// JSON Web Key errors. 125 - /// 126 - /// This error occurs when there are issues with cryptographic keys, 127 - /// such as missing or invalid keys. 128 - #[error(transparent)] 129 - JwkError(#[from] crate::jose_errors::JwkError), 130 - 131 - /// JSON Object Signing and Encryption errors. 132 - /// 133 - /// This error occurs when there are issues with JWT operations, 134 - /// such as signature validation or token creation. 135 - #[error(transparent)] 136 - JoseError(#[from] crate::jose_errors::JoseError), 137 - 138 124 /// Configuration errors. 139 125 /// 140 126 /// This error occurs when there are issues with application configuration, 141 127 /// such as missing environment variables or invalid settings. 142 128 #[error(transparent)] 143 129 ConfigError(#[from] crate::config_errors::ConfigError), 144 - 145 - /// Data encoding/decoding errors. 146 - /// 147 - /// This error occurs when there are issues with data encoding or decoding, 148 - /// such as invalid Base64 or JSON parsing problems. 149 - #[error(transparent)] 150 - EncodingError(#[from] crate::encoding_errors::EncodingError), 151 - 152 - /// AT Protocol URI errors. 153 - /// 154 - /// This error occurs when there are issues with AT Protocol URIs, 155 - /// such as malformed DIDs or invalid handles. 156 - #[error(transparent)] 157 - UriError(#[from] crate::atproto::errors::UriError), 158 130 159 131 /// Database storage errors. 160 132 ///
+6 -6
src/http/event_form.rs
··· 6 6 use super::cache_countries::cached_countries; 7 7 8 8 #[derive(Debug, Error)] 9 - pub enum BuildEventError { 9 + pub(crate) enum BuildEventError { 10 10 #[error("error-event-builder-1 Invalid Name")] 11 11 InvalidName, 12 12 ··· 60 60 } 61 61 62 62 #[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)] 63 - pub enum BuildEventContentState { 63 + pub(crate) enum BuildEventContentState { 64 64 #[default] 65 65 Reset, 66 66 Selecting, ··· 68 68 } 69 69 70 70 #[derive(Serialize, Deserialize, Debug, Clone)] 71 - pub struct BuildStartsForm { 71 + pub(crate) struct BuildStartsForm { 72 72 pub build_state: Option<BuildEventContentState>, 73 73 74 74 pub tz: Option<String>, ··· 99 99 } 100 100 101 101 #[derive(Serialize, Deserialize, Debug, Clone)] 102 - pub struct BuildLocationForm { 102 + pub(crate) struct BuildLocationForm { 103 103 pub build_state: Option<BuildEventContentState>, 104 104 105 105 pub location_country: Option<String>, ··· 122 122 } 123 123 124 124 #[derive(Serialize, Deserialize, Debug, Clone)] 125 - pub struct BuildLinkForm { 125 + pub(crate) struct BuildLinkForm { 126 126 pub build_state: Option<BuildEventContentState>, 127 127 128 128 pub link_name: Option<String>, ··· 133 133 } 134 134 135 135 #[derive(Serialize, Deserialize, Debug, Clone)] 136 - pub struct BuildEventForm { 136 + pub(crate) struct BuildEventForm { 137 137 pub build_state: Option<BuildEventContentState>, 138 138 139 139 pub name: Option<String>,
+12 -12
src/http/event_view.rs
··· 1 1 use std::collections::HashSet; 2 + use std::str::FromStr; 2 3 3 4 use ammonia::Builder; 4 5 use anyhow::Result; 6 + use atproto_record::aturi::ATURI; 5 7 use chrono_tz::Tz; 6 8 use cityhasher::HashMap; 7 9 use serde::Serialize; ··· 9 11 use crate::http::errors::EventViewError; 10 12 11 13 use crate::{ 12 - atproto::{ 13 - lexicon::{ 14 - community::lexicon::calendar::event::NSID as LexiconCommunityEventNSID, 15 - events::smokesignal::calendar::event::NSID as SmokeSignalEventNSID, 16 - }, 17 - uri::parse_aturi, 14 + atproto::lexicon::{ 15 + community::lexicon::calendar::event::NSID as LexiconCommunityEventNSID, 16 + events::smokesignal::calendar::event::NSID as SmokeSignalEventNSID, 18 17 }, 19 18 http::utils::truncate_text, 20 19 storage::{ ··· 23 22 count_event_rsvps, extract_event_details, get_event_rsvp_counts, 24 23 model::{Event, EventWithRole}, 25 24 }, 26 - handle::{handles_by_did, model::Handle}, 25 + identity_profile::{handles_by_did, model::IdentityProfile}, 27 26 StoragePool, 28 27 }, 29 28 }; ··· 58 57 pub links: Vec<(String, Option<String>)>, // (uri, name) 59 58 } 60 59 61 - impl TryFrom<(Option<&Handle>, Option<&Handle>, &Event)> for EventView { 60 + impl TryFrom<(Option<&IdentityProfile>, Option<&IdentityProfile>, &Event)> for EventView { 62 61 type Error = anyhow::Error; 63 62 64 63 fn try_from( 65 - (viewer, organizer, event): (Option<&Handle>, Option<&Handle>, &Event), 64 + (viewer, organizer, event): (Option<&IdentityProfile>, Option<&IdentityProfile>, &Event), 66 65 ) -> Result<Self, Self::Error> { 67 66 // Time zones are used to display date/time values from the perspective 68 67 // of the viewer. The timezone is selected with this priority: ··· 79 78 } 80 79 .unwrap_or(Tz::UTC); 81 80 82 - let (repository, collection, rkey) = parse_aturi(event.aturi.as_str())?; 81 + let aturi = ATURI::from_str(&event.aturi)?; 82 + let (repository, collection, rkey) = (aturi.authority, aturi.collection, aturi.record_key); 83 83 84 84 // We now support both community and smokesignal event formats 85 85 if collection != LexiconCommunityEventNSID && collection != SmokeSignalEventNSID { ··· 221 221 } 222 222 } 223 223 224 - pub async fn hydrate_event_organizers( 224 + pub(crate) async fn hydrate_event_organizers( 225 225 pool: &StoragePool, 226 226 events: &[EventWithRole], 227 - ) -> Result<HashMap<std::string::String, Handle>> { 227 + ) -> Result<HashMap<std::string::String, IdentityProfile>> { 228 228 if events.is_empty() { 229 229 return Ok(HashMap::default()); 230 230 }
+8 -8
src/http/handle_admin_denylist.rs
··· 21 21 }; 22 22 23 23 #[derive(Debug, Deserialize)] 24 - pub struct DenylistAddForm { 25 - pub subject: String, 26 - pub reason: String, 24 + pub(crate) struct DenylistAddForm { 25 + subject: String, 26 + reason: String, 27 27 } 28 28 29 29 #[derive(Debug, Deserialize)] 30 - pub struct DenylistRemoveForm { 31 - pub subject: String, 30 + pub(crate) struct DenylistRemoveForm { 31 + subject: String, 32 32 } 33 33 34 - pub async fn handle_admin_denylist( 34 + pub(crate) async fn handle_admin_denylist( 35 35 admin_ctx: AdminRequestContext, 36 36 pagination: Query<Pagination>, 37 37 ) -> Result<impl IntoResponse, WebError> { ··· 78 78 .into_response()) 79 79 } 80 80 81 - pub async fn handle_admin_denylist_add( 81 + pub(crate) async fn handle_admin_denylist_add( 82 82 admin_ctx: AdminRequestContext, 83 83 Form(form): Form<DenylistAddForm>, 84 84 ) -> Result<impl IntoResponse, WebError> { ··· 103 103 Ok(Redirect::to("/admin/denylist").into_response()) 104 104 } 105 105 106 - pub async fn handle_admin_denylist_remove( 106 + pub(crate) async fn handle_admin_denylist_remove( 107 107 admin_ctx: AdminRequestContext, 108 108 Form(form): Form<DenylistRemoveForm>, 109 109 ) -> Result<impl IntoResponse, WebError> {
+3 -3
src/http/handle_admin_event.rs
··· 12 12 }; 13 13 14 14 #[derive(Deserialize)] 15 - pub struct EventRecordQuery { 16 - pub aturi: String, 15 + pub(crate) struct EventRecordQuery { 16 + pub(crate) aturi: String, 17 17 } 18 18 19 - pub async fn handle_admin_event( 19 + pub(crate) async fn handle_admin_event( 20 20 admin_ctx: AdminRequestContext, 21 21 Query(query): Query<EventRecordQuery>, 22 22 ) -> Result<impl IntoResponse, WebError> {
+1 -1
src/http/handle_admin_events.rs
··· 14 14 storage::event::event_list, 15 15 }; 16 16 17 - pub async fn handle_admin_events( 17 + pub(crate) async fn handle_admin_events( 18 18 admin_ctx: AdminRequestContext, 19 19 pagination: Query<Pagination>, 20 20 ) -> Result<impl IntoResponse, WebError> {
+3 -3
src/http/handle_admin_handles.rs
··· 16 16 pagination::{Pagination, PaginationView}, 17 17 }, 18 18 select_template, 19 - storage::handle::{handle_list, handle_nuke}, 19 + storage::identity_profile::{handle_list, handle_nuke}, 20 20 }; 21 21 22 - pub async fn handle_admin_handles( 22 + pub(crate) async fn handle_admin_handles( 23 23 admin_ctx: AdminRequestContext, 24 24 pagination: Query<Pagination>, 25 25 ) -> Result<impl IntoResponse, WebError> { ··· 66 66 .into_response()) 67 67 } 68 68 69 - pub async fn handle_admin_nuke_identity( 69 + pub(crate) async fn handle_admin_nuke_identity( 70 70 admin_ctx: AdminRequestContext, 71 71 HxRequest(hx_request): HxRequest, 72 72 Path(did): Path<String>,
+62 -69
src/http/handle_admin_import_event.rs
··· 1 + use std::str::FromStr; 2 + 1 3 use anyhow::Result; 4 + use atproto_identity::resolve::IdentityResolver; 5 + use atproto_record::aturi::ATURI; 2 6 use axum::{ 3 7 extract::Form, 4 8 response::{IntoResponse, Redirect}, ··· 6 10 use serde::Deserialize; 7 11 8 12 use crate::{ 9 - atproto::{ 10 - lexicon::{ 11 - community::lexicon::calendar::event::Event as CommunityEventLexicon, 12 - events::smokesignal::calendar::event::{Event as SmokeSignalEvent, EventResponse}, 13 - }, 14 - uri::parse_aturi, 13 + atproto::lexicon::{ 14 + community::lexicon::calendar::event::Event as CommunityEventLexicon, 15 + events::smokesignal::calendar::event::{Event as SmokeSignalEvent, EventResponse}, 15 16 }, 16 17 contextual_error, 17 18 http::{ 18 19 context::{admin_template_context, AdminRequestContext}, 19 20 errors::{AdminImportEventError, CommonError, LoginError, WebError}, 20 21 }, 21 - resolve::{parse_input, resolve_subject, InputType}, 22 22 select_template, 23 - storage::{event::event_insert_with_metadata, handle::handle_warm_up}, 23 + storage::{event::event_insert_with_metadata, identity_profile::handle_warm_up}, 24 24 }; 25 25 26 26 #[derive(Deserialize)] 27 - pub struct ImportEventForm { 28 - pub aturi: String, 27 + pub(crate) struct ImportEventForm { 28 + pub(crate) aturi: String, 29 29 } 30 30 31 - pub async fn handle_admin_import_event( 31 + pub(crate) async fn handle_admin_import_event( 32 32 admin_ctx: AdminRequestContext, 33 + identity_resolver: IdentityResolver, 33 34 Form(form): Form<ImportEventForm>, 34 35 ) -> Result<impl IntoResponse, WebError> { 35 36 // Admin access is already verified by the extractor ··· 43 44 44 45 // Parse the AT-URI 45 46 let aturi = form.aturi.trim(); 46 - let (repository, collection, rkey) = match parse_aturi(aturi) { 47 - Ok(parsed) => parsed, 47 + let (repository, collection, rkey) = match ATURI::from_str(aturi) { 48 + Ok(aturi) => (aturi.authority, aturi.collection, aturi.record_key), 48 49 Err(_err) => { 49 50 return contextual_error!( 50 51 admin_ctx.web_context, ··· 71 72 } 72 73 }; 73 74 74 - // Resolve the DID for the repository 75 - let input_type = match parse_input(&repository) { 76 - Ok(input) => input, 77 - Err(_err) => { 75 + let document = match identity_resolver.resolve(&repository).await { 76 + Ok(value) => value, 77 + Err(err) => { 78 78 return contextual_error!( 79 79 admin_ctx.web_context, 80 80 admin_ctx.language, 81 81 error_template, 82 82 default_context, 83 - CommonError::FailedToParse 83 + err 84 84 ); 85 85 } 86 86 }; 87 87 88 - let did = match input_type { 89 - InputType::Handle(handle) => { 90 - match resolve_subject( 91 - &admin_ctx.web_context.http_client, 92 - &admin_ctx.web_context.dns_resolver, 93 - &handle, 94 - ) 95 - .await 96 - { 97 - Ok(did) => did, 98 - Err(_err) => { 99 - return contextual_error!( 100 - admin_ctx.web_context, 101 - admin_ctx.language, 102 - error_template, 103 - default_context, 104 - CommonError::FailedToParse 105 - ); 106 - } 107 - } 108 - } 109 - InputType::Plc(did) | InputType::Web(did) => did, 110 - }; 111 - 112 - // Get the DID document to find the PDS endpoint 113 - let did_doc = match crate::did::plc::query( 114 - &admin_ctx.web_context.http_client, 115 - &admin_ctx.web_context.config.plc_hostname, 116 - &did, 117 - ) 118 - .await 88 + let handle = match document 89 + .handles() 90 + .ok_or(WebError::Login(LoginError::NoHandle)) 119 91 { 120 - Ok(doc) => doc, 121 - Err(_err) => { 92 + Ok(value) => value, 93 + Err(err) => { 122 94 return contextual_error!( 123 95 admin_ctx.web_context, 124 96 admin_ctx.language, 125 97 error_template, 126 98 default_context, 127 - CommonError::FailedToParse 99 + err 128 100 ); 129 101 } 130 102 }; 131 103 132 - // Insert the handle if it doesn't exist 133 - if let Some(handle) = did_doc.primary_handle() { 134 - if let Some(pds) = did_doc.pds_endpoint() { 135 - if let Err(err) = handle_warm_up(&admin_ctx.web_context.pool, &did, handle, pds).await { 136 - tracing::warn!("Failed to insert handle: {}", err); 137 - } 138 - } 139 - } 140 - 141 - // Get the PDS endpoint 142 - let pds_endpoint = match did_doc.pds_endpoint() { 143 - Some(endpoint) => endpoint, 144 - None => { 104 + let pds = match document 105 + .pds_endpoints() 106 + .first() 107 + .cloned() 108 + .ok_or(WebError::Login(LoginError::NoPDS)) 109 + { 110 + Ok(value) => value, 111 + Err(err) => { 145 112 return contextual_error!( 146 113 admin_ctx.web_context, 147 114 admin_ctx.language, 148 115 error_template, 149 116 default_context, 150 - WebError::Login(LoginError::NoPDS) 117 + err 151 118 ); 152 119 } 153 120 }; 154 121 122 + if let Err(err) = admin_ctx 123 + .web_context 124 + .document_storage 125 + .store_document(document.clone()) 126 + .await 127 + { 128 + return contextual_error!( 129 + admin_ctx.web_context, 130 + admin_ctx.language, 131 + error_template, 132 + default_context, 133 + err 134 + ); 135 + } 136 + 137 + // Insert the handle if it doesn't exist 138 + if let Err(err) = handle_warm_up(&admin_ctx.web_context.pool, &document.id, handle, pds).await { 139 + return contextual_error!( 140 + admin_ctx.web_context, 141 + admin_ctx.language, 142 + error_template, 143 + default_context, 144 + err 145 + ); 146 + } 147 + 155 148 // Construct the XRPC request to get the record 156 149 let url = format!( 157 150 "{}/xrpc/com.atproto.repo.getRecord?repo={}&collection={}&rkey={}", 158 - pds_endpoint, did, collection, rkey 151 + pds, document.id, collection, rkey 159 152 ); 160 153 161 154 let response = match admin_ctx.web_context.http_client.get(&url).send().await { ··· 197 190 &admin_ctx.web_context.pool, 198 191 aturi, 199 192 &record.cid, 200 - &did, 193 + &document.id, 201 194 "events.smokesignal.calendar.event", 202 195 &record.value, 203 196 &name, ··· 246 239 &admin_ctx.web_context.pool, 247 240 aturi, 248 241 &record.cid, 249 - &did, 242 + &document.id, 250 243 "community.lexicon.calendar.event", 251 244 &record.value, 252 245 &name,
+64 -71
src/http/handle_admin_import_rsvp.rs
··· 1 + use std::str::FromStr; 2 + 1 3 use anyhow::Result; 4 + use atproto_identity::resolve::IdentityResolver; 5 + use atproto_record::aturi::ATURI; 2 6 use axum::{ 3 7 extract::Form, 4 8 response::{IntoResponse, Redirect}, ··· 7 11 use urlencoding; 8 12 9 13 use crate::{ 10 - atproto::{ 11 - lexicon::{ 12 - community::lexicon::calendar::rsvp::{ 13 - Rsvp as CommunityRsvpLexicon, RsvpStatus as CommunityRsvpStatusLexicon, 14 - NSID as COMMUNITY_RSVP_NSID, 15 - }, 16 - events::smokesignal::calendar::rsvp::{ 17 - Rsvp as SmokesignalRsvpLexicon, NSID as SMOKESIGNAL_RSVP_NSID, 18 - }, 14 + atproto::lexicon::{ 15 + community::lexicon::calendar::rsvp::{ 16 + Rsvp as CommunityRsvpLexicon, RsvpStatus as CommunityRsvpStatusLexicon, 17 + NSID as COMMUNITY_RSVP_NSID, 19 18 }, 20 - uri::parse_aturi, 19 + events::smokesignal::calendar::rsvp::{ 20 + Rsvp as SmokesignalRsvpLexicon, NSID as SMOKESIGNAL_RSVP_NSID, 21 + }, 21 22 }, 22 23 contextual_error, 23 24 http::{ 24 25 context::{admin_template_context, AdminRequestContext}, 25 26 errors::{AdminImportRsvpError, CommonError, LoginError, WebError}, 26 27 }, 27 - resolve::{parse_input, resolve_subject, InputType}, 28 28 select_template, 29 - storage::{event::rsvp_insert_with_metadata, handle::handle_warm_up}, 29 + storage::{event::rsvp_insert_with_metadata, identity_profile::handle_warm_up}, 30 30 }; 31 31 32 32 #[derive(Deserialize)] ··· 34 34 pub aturi: String, 35 35 } 36 36 37 - pub async fn handle_admin_import_rsvp( 37 + pub(crate) async fn handle_admin_import_rsvp( 38 38 admin_ctx: AdminRequestContext, 39 + identity_resolver: IdentityResolver, 39 40 Form(form): Form<ImportRsvpForm>, 40 41 ) -> Result<impl IntoResponse, WebError> { 41 42 // Admin access is already verified by the extractor ··· 49 50 50 51 // Parse the AT-URI 51 52 let aturi = form.aturi.trim(); 52 - let (repository, collection, rkey) = match parse_aturi(aturi) { 53 - Ok(parsed) => parsed, 53 + let (repository, collection, rkey) = match ATURI::from_str(aturi) { 54 + Ok(aturi) => (aturi.authority, aturi.collection, aturi.record_key), 54 55 Err(_err) => { 55 56 return contextual_error!( 56 57 admin_ctx.web_context, ··· 77 78 } 78 79 }; 79 80 80 - // Resolve the DID for the repository 81 - let input_type = match parse_input(&repository) { 82 - Ok(input) => input, 83 - Err(_err) => { 81 + let document = match identity_resolver.resolve(&repository).await { 82 + Ok(value) => value, 83 + Err(err) => { 84 84 return contextual_error!( 85 85 admin_ctx.web_context, 86 86 admin_ctx.language, 87 87 error_template, 88 88 default_context, 89 - CommonError::FailedToParse 89 + err 90 90 ); 91 91 } 92 92 }; 93 93 94 - let did = match input_type { 95 - InputType::Handle(handle) => { 96 - match resolve_subject( 97 - &admin_ctx.web_context.http_client, 98 - &admin_ctx.web_context.dns_resolver, 99 - &handle, 100 - ) 101 - .await 102 - { 103 - Ok(did) => did, 104 - Err(_err) => { 105 - return contextual_error!( 106 - admin_ctx.web_context, 107 - admin_ctx.language, 108 - error_template, 109 - default_context, 110 - CommonError::FailedToParse 111 - ); 112 - } 113 - } 114 - } 115 - InputType::Plc(did) | InputType::Web(did) => did, 116 - }; 117 - 118 - // Get the DID document to find the PDS endpoint 119 - let did_doc = match crate::did::plc::query( 120 - &admin_ctx.web_context.http_client, 121 - &admin_ctx.web_context.config.plc_hostname, 122 - &did, 123 - ) 124 - .await 94 + let handle = match document 95 + .handles() 96 + .ok_or(WebError::Login(LoginError::NoHandle)) 125 97 { 126 - Ok(doc) => doc, 127 - Err(_err) => { 98 + Ok(value) => value, 99 + Err(err) => { 128 100 return contextual_error!( 129 101 admin_ctx.web_context, 130 102 admin_ctx.language, 131 103 error_template, 132 104 default_context, 133 - CommonError::FailedToParse 105 + err 134 106 ); 135 107 } 136 108 }; 137 109 138 - // Insert the handle if it doesn't exist 139 - if let Some(handle) = did_doc.primary_handle() { 140 - if let Some(pds) = did_doc.pds_endpoint() { 141 - if let Err(err) = handle_warm_up(&admin_ctx.web_context.pool, &did, handle, pds).await { 142 - tracing::warn!("Failed to insert handle: {}", err); 143 - } 144 - } 145 - } 146 - 147 - // Get the PDS endpoint 148 - let pds_endpoint = match did_doc.pds_endpoint() { 149 - Some(endpoint) => endpoint, 150 - None => { 110 + let pds = match document 111 + .pds_endpoints() 112 + .first() 113 + .cloned() 114 + .ok_or(WebError::Login(LoginError::NoPDS)) 115 + { 116 + Ok(value) => value, 117 + Err(err) => { 151 118 return contextual_error!( 152 119 admin_ctx.web_context, 153 120 admin_ctx.language, 154 121 error_template, 155 122 default_context, 156 - WebError::Login(LoginError::NoPDS) 123 + err 157 124 ); 158 125 } 159 126 }; 160 127 128 + if let Err(err) = admin_ctx 129 + .web_context 130 + .document_storage 131 + .store_document(document.clone()) 132 + .await 133 + { 134 + return contextual_error!( 135 + admin_ctx.web_context, 136 + admin_ctx.language, 137 + error_template, 138 + default_context, 139 + err 140 + ); 141 + } 142 + 143 + // Insert the handle if it doesn't exist 144 + if let Err(err) = handle_warm_up(&admin_ctx.web_context.pool, &document.id, handle, pds).await { 145 + return contextual_error!( 146 + admin_ctx.web_context, 147 + admin_ctx.language, 148 + error_template, 149 + default_context, 150 + err 151 + ); 152 + } 153 + 161 154 // Construct the XRPC request to get the record 162 155 let url = format!( 163 156 "{}/xrpc/com.atproto.repo.getRecord?repo={}&collection={}&rkey={}", 164 - pds_endpoint, did, collection, rkey 157 + pds, document.id, collection, rkey 165 158 ); 166 159 167 160 let response = match admin_ctx.web_context.http_client.get(&url).send().await { ··· 237 230 crate::storage::event::RsvpInsertParams { 238 231 aturi, 239 232 cid: &cid, 240 - did: &did, 233 + did: &document.id, 241 234 lexicon: COMMUNITY_RSVP_NSID, 242 235 record: &rsvp_value, 243 236 event_aturi: &event_aturi, ··· 284 277 crate::storage::event::RsvpInsertParams { 285 278 aturi, 286 279 cid: &cid, 287 - did: &did, 280 + did: &document.id, 288 281 lexicon: SMOKESIGNAL_RSVP_NSID, 289 282 record: &rsvp_value, 290 283 event_aturi: &event_aturi,
+1 -1
src/http/handle_admin_index.rs
··· 7 7 8 8 use super::errors::WebError; 9 9 10 - pub async fn handle_admin_index( 10 + pub(crate) async fn handle_admin_index( 11 11 admin_ctx: AdminRequestContext, 12 12 ) -> Result<impl IntoResponse, WebError> { 13 13 // User is already verified as admin by the extractor
+3 -3
src/http/handle_admin_rsvp.rs
··· 18 18 }; 19 19 20 20 #[derive(Deserialize)] 21 - pub struct RsvpRecordQuery { 22 - pub aturi: String, 21 + pub(crate) struct RsvpRecordQuery { 22 + pub(crate) aturi: String, 23 23 } 24 24 25 - pub async fn handle_admin_rsvp( 25 + pub(crate) async fn handle_admin_rsvp( 26 26 State(web_context): State<WebContext>, 27 27 Language(language): Language, 28 28 Cached(auth): Cached<Auth>,
+2 -2
src/http/handle_admin_rsvps.rs
··· 16 16 }; 17 17 18 18 #[derive(Deserialize, Default)] 19 - pub struct AdminRsvpsParams { 19 + pub(crate) struct AdminRsvpsParams { 20 20 #[serde(flatten)] 21 21 pagination: Pagination, 22 22 ··· 25 25 imported_aturi: Option<String>, 26 26 } 27 27 28 - pub async fn handle_admin_rsvps( 28 + pub(crate) async fn handle_admin_rsvps( 29 29 admin_ctx: AdminRequestContext, 30 30 Query(params): Query<AdminRsvpsParams>, 31 31 ) -> Result<impl IntoResponse, WebError> {
+68 -39
src/http/handle_create_event.rs
··· 16 16 use minijinja::context as template_context; 17 17 use serde::Deserialize; 18 18 19 - use crate::atproto::auth::SimpleOAuthSessionProvider; 20 - use crate::atproto::client::CreateRecordRequest; 21 - use crate::atproto::client::OAuthPdsClient; 19 + use crate::atproto::auth::{ 20 + create_dpop_auth_from_aip_session, create_dpop_auth_from_oauth_session, 21 + }; 22 22 use crate::atproto::lexicon::community::lexicon::calendar::event::Event; 23 23 use crate::atproto::lexicon::community::lexicon::calendar::event::EventLink; 24 24 use crate::atproto::lexicon::community::lexicon::calendar::event::EventLocation; ··· 26 26 use crate::atproto::lexicon::community::lexicon::calendar::event::Status; 27 27 use crate::atproto::lexicon::community::lexicon::calendar::event::NSID; 28 28 use crate::atproto::lexicon::community::lexicon::location::Address; 29 + use crate::config::OAuthBackendConfig; 29 30 use crate::contextual_error; 30 31 use crate::http::context::WebContext; 31 32 use crate::http::errors::CommonError; ··· 41 42 use crate::http::utils::url_from_aturi; 42 43 use crate::select_template; 43 44 use crate::storage::event::event_insert; 45 + use atproto_client::com::atproto::repo::{ 46 + create_record, CreateRecordRequest, CreateRecordResponse, 47 + }; 44 48 45 49 use super::cache_countries::cached_countries; 46 50 use super::event_form::BuildLocationForm; 47 51 48 - pub async fn handle_create_event( 52 + pub(crate) async fn handle_create_event( 49 53 method: Method, 50 54 State(web_context): State<WebContext>, 51 55 Language(language): Language, ··· 54 58 HxBoosted(hx_boosted): HxBoosted, 55 59 Form(mut build_event_form): Form<BuildEventForm>, 56 60 ) -> Result<impl IntoResponse, WebError> { 57 - let current_handle = auth.require(&web_context.config.destination_key, "/event")?; 61 + let current_handle = auth.require(&web_context.config, "/event")?; 58 62 59 63 let is_development = cfg!(debug_assertions); 60 64 ··· 72 76 73 77 let error_template = select_template!(hx_boosted, hx_request, language); 74 78 75 - let (default_tz, timezones) = supported_timezones(auth.0.as_ref()); 79 + let (default_tz, timezones) = supported_timezones(auth.profile()); 76 80 77 81 if build_event_form.build_state.is_none() { 78 82 build_event_form.build_state = Some(BuildEventContentState::default()); ··· 217 221 _ => None, 218 222 }); 219 223 220 - // Ensure we have auth data for the API call 221 - let auth_data = auth.1.ok_or(CommonError::NotAuthorized)?; 222 - let client_auth: SimpleOAuthSessionProvider = 223 - SimpleOAuthSessionProvider::try_from(auth_data)?; 224 - 225 - let client = OAuthPdsClient { 226 - http_client: &web_context.http_client, 227 - pds: &current_handle.pds, 224 + // Create DPoP auth based on OAuth backend type 225 + let dpop_auth = match (&auth, &web_context.config.oauth_backend) { 226 + (Auth::Pds { session, .. }, OAuthBackendConfig::ATProtocol { .. }) => { 227 + create_dpop_auth_from_oauth_session(session)? 228 + } 229 + (Auth::Aip { access_token, .. }, OAuthBackendConfig::AIP { hostname, .. }) => { 230 + create_dpop_auth_from_aip_session( 231 + &web_context.http_client, 232 + hostname, 233 + access_token, 234 + ) 235 + .await? 236 + } 237 + _ => return Err(CommonError::NotAuthorized.into()), 228 238 }; 229 239 230 240 let locations = match &build_event_form.location_country { ··· 276 286 swap_commit: None, 277 287 }; 278 288 279 - let create_record_result = client.create_record(&client_auth, event_record).await; 289 + let create_record_result = create_record( 290 + &web_context.http_client, 291 + &dpop_auth, 292 + &current_handle.pds, 293 + event_record, 294 + ) 295 + .await; 280 296 281 - if let Err(err) = create_record_result { 282 - return contextual_error!( 283 - web_context, 284 - language, 285 - error_template, 286 - default_context, 287 - err 288 - ); 289 - } 290 - 291 - // create_record_result is guaranteed to be Ok since we checked for Err above 292 - let create_record_result = create_record_result?; 297 + let create_record_response = match create_record_result { 298 + Ok(CreateRecordResponse::StrongRef { uri, cid, .. }) => { 299 + crate::atproto::lexicon::com::atproto::repo::StrongRef { uri, cid } 300 + } 301 + Ok(CreateRecordResponse::Error(err)) => { 302 + return contextual_error!( 303 + web_context, 304 + language, 305 + error_template, 306 + default_context, 307 + anyhow::anyhow!("Server error: {}", err.error_message()) 308 + ); 309 + } 310 + Err(err) => { 311 + return contextual_error!( 312 + web_context, 313 + language, 314 + error_template, 315 + default_context, 316 + err 317 + ); 318 + } 319 + }; 293 320 294 321 let event_insert_result = event_insert( 295 322 &web_context.pool, 296 - &create_record_result.uri, 297 - &create_record_result.cid, 323 + &create_record_response.uri, 324 + &create_record_response.cid, 298 325 &current_handle.did, 299 326 NSID, 300 327 &the_record, ··· 311 338 ); 312 339 } 313 340 314 - let event_url = 315 - url_from_aturi(&web_context.config.external_base, &create_record_result.uri)?; 341 + let event_url = url_from_aturi( 342 + &web_context.config.external_base, 343 + &create_record_response.uri, 344 + )?; 316 345 317 346 return Ok(RenderHtml( 318 347 &render_template, ··· 346 375 .into_response()) 347 376 } 348 377 349 - pub async fn handle_starts_at_builder( 378 + pub(crate) async fn handle_starts_at_builder( 350 379 method: Method, 351 380 State(web_context): State<WebContext>, 352 381 Language(language): Language, ··· 362 391 return Ok(StatusCode::BAD_REQUEST.into_response()); 363 392 } 364 393 365 - let (default_tz, timezones) = supported_timezones(auth.0.as_ref()); 394 + let (default_tz, timezones) = supported_timezones(auth.profile()); 366 395 367 396 let is_development = cfg!(debug_assertions); 368 397 ··· 436 465 .into_response()) 437 466 } 438 467 439 - pub async fn handle_location_at_builder( 468 + pub(crate) async fn handle_location_at_builder( 440 469 method: Method, 441 470 State(web_context): State<WebContext>, 442 471 Language(language): Language, ··· 510 539 .into_response()) 511 540 } 512 541 513 - pub async fn handle_link_at_builder( 542 + pub(crate) async fn handle_link_at_builder( 514 543 method: Method, 515 544 State(web_context): State<WebContext>, 516 545 Language(language): Language, ··· 585 614 } 586 615 587 616 #[derive(Deserialize, Debug, Clone)] 588 - pub struct LocationDataListHint { 589 - pub location_country: Option<String>, 617 + pub(crate) struct LocationDataListHint { 618 + pub(crate) location_country: Option<String>, 590 619 } 591 620 592 - pub async fn handle_location_datalist( 621 + pub(crate) async fn handle_location_datalist( 593 622 State(web_context): State<WebContext>, 594 623 HxRequest(hx_request): HxRequest, 595 624 Query(location_country_hint): Query<LocationDataListHint>, ··· 653 682 None 654 683 } 655 684 656 - pub fn prefixed(mut set: BTreeMap<String, String>, prefix: &str) -> BTreeMap<String, String> { 685 + fn prefixed(mut set: BTreeMap<String, String>, prefix: &str) -> BTreeMap<String, String> { 657 686 let mut set = set.split_off(prefix); 658 687 659 688 if let Some(not_in_prefix) = upper_bound_from_prefix(prefix) {
+55 -28
src/http/handle_create_rsvp.rs
··· 9 9 use minijinja::context as template_context; 10 10 use std::hash::Hasher; 11 11 12 + use crate::atproto::auth::{ 13 + create_dpop_auth_from_aip_session, create_dpop_auth_from_oauth_session, 14 + }; 15 + use crate::config::OAuthBackendConfig; 12 16 use crate::{ 13 - atproto::{ 14 - auth::SimpleOAuthSessionProvider, 15 - client::{OAuthPdsClient, PutRecordRequest}, 16 - lexicon::{ 17 - com::atproto::repo::StrongRef, 18 - community::lexicon::calendar::rsvp::{Rsvp, RsvpStatus, NSID}, 19 - }, 17 + atproto::lexicon::{ 18 + com::atproto::repo::StrongRef, 19 + community::lexicon::calendar::rsvp::{Rsvp, RsvpStatus, NSID}, 20 20 }, 21 21 contextual_error, 22 22 http::{ 23 23 context::WebContext, 24 - errors::WebError, 24 + errors::{CommonError, WebError}, 25 25 middleware_auth::Auth, 26 26 middleware_i18n::Language, 27 27 rsvp_form::{BuildRSVPForm, BuildRsvpContentState}, ··· 30 30 select_template, 31 31 storage::event::rsvp_insert, 32 32 }; 33 + use atproto_client::com::atproto::repo::{put_record, PutRecordRequest, PutRecordResponse}; 33 34 34 - pub async fn handle_create_rsvp( 35 + pub(crate) async fn handle_create_rsvp( 35 36 method: Method, 36 37 State(web_context): State<WebContext>, 37 38 Language(language): Language, ··· 40 41 HxBoosted(hx_boosted): HxBoosted, 41 42 Form(mut build_rsvp_form): Form<BuildRSVPForm>, 42 43 ) -> Result<impl IntoResponse, WebError> { 43 - let current_handle = auth.require(&web_context.config.destination_key, "/rsvp")?; 44 + let current_handle = auth.require(&web_context.config, "/rsvp")?; 44 45 45 46 let default_context = template_context! { 46 47 current_handle, ··· 115 116 if !found_errors { 116 117 let now = Utc::now(); 117 118 118 - let client_auth: SimpleOAuthSessionProvider = 119 - SimpleOAuthSessionProvider::try_from(auth.1.unwrap())?; 120 - 121 - let client = OAuthPdsClient { 122 - http_client: &web_context.http_client, 123 - pds: &current_handle.pds, 119 + // Create DPoP auth based on OAuth backend type 120 + let dpop_auth = match (&auth, &web_context.config.oauth_backend) { 121 + (Auth::Pds { session, .. }, OAuthBackendConfig::ATProtocol { .. }) => { 122 + create_dpop_auth_from_oauth_session(session)? 123 + } 124 + (Auth::Aip { access_token, .. }, OAuthBackendConfig::AIP { hostname, .. }) => { 125 + create_dpop_auth_from_aip_session( 126 + &web_context.http_client, 127 + hostname, 128 + access_token, 129 + ) 130 + .await? 131 + } 132 + _ => return Err(CommonError::NotAuthorized.into()), 124 133 }; 125 134 126 135 let subject = StrongRef { ··· 156 165 swap_record: None, 157 166 }; 158 167 159 - let put_record_result = client.put_record(&client_auth, rsvp_record).await; 168 + let put_record_result = put_record( 169 + &web_context.http_client, 170 + &dpop_auth, 171 + &current_handle.pds, 172 + rsvp_record, 173 + ) 174 + .await; 160 175 161 - if let Err(err) = put_record_result { 162 - return contextual_error!( 163 - web_context, 164 - language, 165 - error_template, 166 - default_context, 167 - err 168 - ); 169 - } 170 - 171 - let create_record_result = put_record_result.unwrap(); 176 + let create_record_result = match put_record_result { 177 + Ok(PutRecordResponse::StrongRef { uri, cid, .. }) => { 178 + crate::atproto::lexicon::com::atproto::repo::StrongRef { uri, cid } 179 + } 180 + Ok(PutRecordResponse::Error(err)) => { 181 + return contextual_error!( 182 + web_context, 183 + language, 184 + error_template, 185 + default_context, 186 + anyhow::anyhow!("Server error: {}", err.error_message()) 187 + ); 188 + } 189 + Err(err) => { 190 + return contextual_error!( 191 + web_context, 192 + language, 193 + error_template, 194 + default_context, 195 + err 196 + ); 197 + } 198 + }; 172 199 173 200 let rsvp_insert_result = rsvp_insert( 174 201 &web_context.pool,
+214
src/http/handle_delete_event.rs
··· 1 + use anyhow::{Result, anyhow}; 2 + use axum::{extract::Path, response::IntoResponse}; 3 + use axum_extra::extract::Form; 4 + use axum_template::RenderHtml; 5 + use http::StatusCode; 6 + use minijinja::context as template_context; 7 + use serde::{Deserialize, Serialize}; 8 + 9 + use atproto_client::com::atproto::repo::{delete_record, DeleteRecordRequest}; 10 + 11 + use crate::{ 12 + atproto::{ 13 + auth::{create_dpop_auth_from_oauth_session, create_dpop_auth_from_aip_session}, 14 + lexicon::community::lexicon::calendar::event::NSID as LexiconCommunityEventNSID, 15 + }, 16 + contextual_error, 17 + http::{context::UserRequestContext, errors::WebError, middleware_auth::Auth}, 18 + select_template, 19 + storage::{event::{event_delete, event_exists}}, 20 + config::OAuthBackendConfig, 21 + }; 22 + 23 + #[derive(Debug, Deserialize, Serialize)] 24 + pub struct DeleteEventForm { 25 + pub confirm: Option<String>, 26 + } 27 + 28 + pub(crate) async fn handle_delete_event( 29 + ctx: UserRequestContext, 30 + Path((handle_slug, event_rkey)): Path<(String, String)>, 31 + Form(form): Form<DeleteEventForm>, 32 + ) -> Result<impl IntoResponse, WebError> { 33 + let current_handle = match &ctx.auth { 34 + Auth::Pds { profile, .. } | Auth::Aip { profile, .. } => profile.clone(), 35 + Auth::Unauthenticated => { 36 + return Ok(StatusCode::FORBIDDEN.into_response()); 37 + } 38 + }; 39 + 40 + let default_context = template_context! { 41 + current_handle, 42 + language => ctx.language.to_string(), 43 + canonical_url => format!("https://{}/{}/{}/delete", ctx.web_context.config.external_base, handle_slug, event_rkey), 44 + event_url => format!("/{}/{}", handle_slug, event_rkey), 45 + }; 46 + 47 + let error_template = select_template!(false, false, ctx.language); 48 + let render_template = select_template!("delete_event", false, false, ctx.language); 49 + 50 + let lookup_aturi = format!( 51 + "at://{}/{}/{}", 52 + current_handle.did, LexiconCommunityEventNSID, event_rkey 53 + ); 54 + 55 + // Get the event 56 + let event_exists = event_exists(&ctx.web_context.pool, &lookup_aturi).await; 57 + if let Err(err) = event_exists { 58 + return contextual_error!( 59 + ctx.web_context, 60 + ctx.language, 61 + error_template, 62 + default_context, 63 + err 64 + ); 65 + } 66 + 67 + if !event_exists.unwrap() { 68 + return contextual_error!( 69 + ctx.web_context, 70 + ctx.language, 71 + error_template, 72 + default_context, 73 + anyhow!( 74 + "error-delete-event-1 identity does not have event with that AT-URI: {lookup_aturi}" 75 + ) 76 + ); 77 + } 78 + 79 + // If form is submitted with confirmation or it's a non-HTMX POST, proceed with deletion 80 + if form.confirm.as_deref() == Some("true") { 81 + // Create DPoP authentication based on auth type 82 + let dpop_auth = match &ctx.auth { 83 + Auth::Pds { session, .. } => { 84 + match create_dpop_auth_from_oauth_session(session) { 85 + Ok(auth) => auth, 86 + Err(err) => { 87 + tracing::error!("Failed to create DPoP auth from OAuth session: {}", err); 88 + return contextual_error!( 89 + ctx.web_context, 90 + ctx.language, 91 + error_template, 92 + default_context, 93 + err, 94 + StatusCode::INTERNAL_SERVER_ERROR 95 + ); 96 + } 97 + } 98 + } 99 + Auth::Aip { .. } => { 100 + match &ctx.web_context.config.oauth_backend { 101 + OAuthBackendConfig::AIP { hostname, .. } => { 102 + let access_token = match &ctx.auth { 103 + Auth::Aip { access_token, .. } => access_token, 104 + _ => unreachable!("We already matched on Auth::Aip"), 105 + }; 106 + 107 + match create_dpop_auth_from_aip_session( 108 + &ctx.web_context.http_client, 109 + hostname, 110 + access_token, 111 + ).await { 112 + Ok(auth) => auth, 113 + Err(err) => { 114 + tracing::error!("Failed to create DPoP auth from AIP session: {}", err); 115 + return contextual_error!( 116 + ctx.web_context, 117 + ctx.language, 118 + error_template, 119 + default_context, 120 + err, 121 + StatusCode::INTERNAL_SERVER_ERROR 122 + ); 123 + } 124 + } 125 + } 126 + _ => { 127 + tracing::error!("AIP auth found but OAuth backend is not AIP"); 128 + return contextual_error!( 129 + ctx.web_context, 130 + ctx.language, 131 + error_template, 132 + default_context, 133 + anyhow!("Authentication configuration mismatch"), 134 + StatusCode::INTERNAL_SERVER_ERROR 135 + ); 136 + } 137 + } 138 + } 139 + Auth::Unauthenticated => { 140 + // This should not happen due to the check above 141 + return Ok(StatusCode::FORBIDDEN.into_response()); 142 + } 143 + }; 144 + 145 + // Delete from PDS first 146 + let delete_record_request = DeleteRecordRequest { 147 + repo: current_handle.did.clone(), 148 + collection: LexiconCommunityEventNSID.to_string(), 149 + record_key: event_rkey.clone(), 150 + swap_commit: None, 151 + swap_record: None, 152 + }; 153 + 154 + match delete_record( 155 + &ctx.web_context.http_client, 156 + &dpop_auth, 157 + &current_handle.pds, 158 + delete_record_request, 159 + ).await { 160 + Ok(_) => { 161 + tracing::info!("Successfully deleted event from PDS: {}", lookup_aturi); 162 + } 163 + Err(err) => { 164 + tracing::error!("Failed to delete event from PDS: {}", err); 165 + return contextual_error!( 166 + ctx.web_context, 167 + ctx.language, 168 + error_template, 169 + default_context, 170 + err, 171 + StatusCode::INTERNAL_SERVER_ERROR 172 + ); 173 + } 174 + } 175 + 176 + // Delete from local storage 177 + if let Err(err) = event_delete(&ctx.web_context.pool, &lookup_aturi).await { 178 + tracing::error!("Failed to delete event from local storage: {}", err); 179 + return contextual_error!( 180 + ctx.web_context, 181 + ctx.language, 182 + error_template, 183 + default_context, 184 + err, 185 + StatusCode::INTERNAL_SERVER_ERROR 186 + ); 187 + } 188 + 189 + let render_template = select_template!("alert", false, false, ctx.language); 190 + 191 + // TODO: Localize these strings. 192 + 193 + return Ok(RenderHtml( 194 + &render_template, 195 + ctx.web_context.engine.clone(), 196 + template_context! { ..default_context, ..template_context! { 197 + message_type => "info", 198 + message_title => "Event Deleted Successfully", 199 + message => "The event has been deleted from your Personal Data Server (PDS) and removed this Smoke Signal instance.", 200 + }}, 201 + ) 202 + .into_response()); 203 + } 204 + 205 + Ok(RenderHtml( 206 + &render_template, 207 + ctx.web_context.engine.clone(), 208 + template_context! { ..default_context, ..template_context! { 209 + show_confirm => true, 210 + event_url => format!("/{}/{}", handle_slug, event_rkey), 211 + }}, 212 + ) 213 + .into_response()) 214 + }
+69 -38
src/http/handle_edit_event.rs
··· 7 7 use http::{Method, StatusCode}; 8 8 use minijinja::context as template_context; 9 9 10 + use crate::atproto::auth::{ 11 + create_dpop_auth_from_aip_session, create_dpop_auth_from_oauth_session, 12 + }; 13 + use crate::config::OAuthBackendConfig; 14 + use crate::http::middleware_auth::Auth; 10 15 use crate::{ 11 16 atproto::{ 12 - auth::SimpleOAuthSessionProvider, 13 - client::{OAuthPdsClient, PutRecordRequest}, 14 17 lexicon::community::lexicon::calendar::event::{ 15 18 Event as LexiconCommunityEvent, EventLink, EventLocation, Mode, NamedUri, Status, 16 19 NSID as LexiconCommunityEventNSID, ··· 26 29 http::location_edit_status::{check_location_edit_status, LocationEditStatus}, 27 30 http::timezones::supported_timezones, 28 31 http::utils::url_from_aturi, 29 - resolve::{parse_input, InputType}, 30 32 select_template, 31 33 storage::{ 32 34 event::{event_get, event_update_with_metadata}, 33 - handle::{handle_for_did, handle_for_handle}, 35 + identity_profile::{handle_for_did, handle_for_handle}, 34 36 }, 35 37 }; 38 + use atproto_client::com::atproto::repo::{put_record, PutRecordRequest, PutRecordResponse}; 36 39 37 - pub async fn handle_edit_event( 40 + pub(crate) async fn handle_edit_event( 38 41 ctx: UserRequestContext, 39 42 method: Method, 40 43 HxBoosted(hx_boosted): HxBoosted, ··· 42 45 Path((handle_slug, event_rkey)): Path<(String, String)>, 43 46 Form(mut build_event_form): Form<BuildEventForm>, 44 47 ) -> Result<impl IntoResponse, WebError> { 45 - let current_handle = ctx 46 - .auth 47 - .require(&ctx.web_context.config.destination_key, "/")?; 48 + let current_handle = ctx.auth.require(&ctx.web_context.config, "/")?; 48 49 49 50 let default_context = template_context! { 50 51 current_handle, ··· 53 54 create_event => false, 54 55 submit_url => format!("/{}/{}/edit", handle_slug, event_rkey), 55 56 cancel_url => format!("/{}/{}", handle_slug, event_rkey), 57 + delete_event_url => format!("https://{}/{}/{}/delete", ctx.web_context.config.external_base, handle_slug, event_rkey), 56 58 }; 57 59 58 60 let render_template = select_template!("edit_event", hx_boosted, hx_request, ctx.language); 59 61 let error_template = select_template!(hx_boosted, hx_request, ctx.language); 60 62 61 63 // Lookup the event 62 - let profile = match parse_input(&handle_slug) { 63 - Ok(InputType::Handle(handle)) => handle_for_handle(&ctx.web_context.pool, &handle) 64 + let profile = if handle_slug.starts_with("did:") { 65 + handle_for_did(&ctx.web_context.pool, &handle_slug) 66 + .await 67 + .map_err(WebError::from) 68 + } else { 69 + let handle = if let Some(handle) = handle_slug.strip_prefix('@') { 70 + handle 71 + } else { 72 + &handle_slug 73 + }; 74 + handle_for_handle(&ctx.web_context.pool, handle) 64 75 .await 65 - .map_err(WebError::from), 66 - Ok(InputType::Plc(did) | InputType::Web(did)) => { 67 - handle_for_did(&ctx.web_context.pool, &did) 68 - .await 69 - .map_err(WebError::from) 70 - } 71 - _ => Err(WebError::from(EditEventError::InvalidHandleSlug)), 76 + .map_err(WebError::from) 72 77 }?; 73 78 74 79 let lookup_aturi = format!( ··· 484 489 _ => None, 485 490 }); 486 491 487 - let client_auth: SimpleOAuthSessionProvider = 488 - SimpleOAuthSessionProvider::try_from(ctx.auth.1.unwrap())?; 489 - 490 - let client = OAuthPdsClient { 491 - http_client: &ctx.web_context.http_client, 492 - pds: &current_handle.pds, 492 + // Create DPoP auth based on OAuth backend type 493 + let dpop_auth = match (&ctx.auth, &ctx.web_context.config.oauth_backend) { 494 + (Auth::Pds { session, .. }, OAuthBackendConfig::ATProtocol { .. }) => { 495 + create_dpop_auth_from_oauth_session(session)? 496 + } 497 + (Auth::Aip { access_token, .. }, OAuthBackendConfig::AIP { hostname, .. }) => { 498 + create_dpop_auth_from_aip_session( 499 + &ctx.web_context.http_client, 500 + hostname, 501 + access_token, 502 + ) 503 + .await? 504 + } 505 + _ => return Err(CommonError::NotAuthorized.into()), 493 506 }; 494 507 495 508 // Extract existing locations and URIs from the original record ··· 601 614 swap_record: Some(event.cid.clone()), 602 615 }; 603 616 604 - let update_record_result = 605 - client.put_record(&client_auth, update_record_request).await; 606 - 607 - if let Err(err) = update_record_result { 608 - return contextual_error!( 609 - ctx.web_context, 610 - ctx.language, 611 - error_template, 612 - default_context, 613 - err, 614 - StatusCode::OK 615 - ); 616 - } 617 + let update_record_result = put_record( 618 + &ctx.web_context.http_client, 619 + &dpop_auth, 620 + &current_handle.pds, 621 + update_record_request, 622 + ) 623 + .await; 617 624 618 - let update_record_result = update_record_result.unwrap(); 625 + let update_record_response = match update_record_result { 626 + Ok(PutRecordResponse::StrongRef { uri, cid, .. }) => { 627 + crate::atproto::lexicon::com::atproto::repo::StrongRef { uri, cid } 628 + } 629 + Ok(PutRecordResponse::Error(err)) => { 630 + return contextual_error!( 631 + ctx.web_context, 632 + ctx.language, 633 + error_template, 634 + default_context, 635 + anyhow::anyhow!("Server error: {}", err.error_message()), 636 + StatusCode::OK 637 + ); 638 + } 639 + Err(err) => { 640 + return contextual_error!( 641 + ctx.web_context, 642 + ctx.language, 643 + error_template, 644 + default_context, 645 + err, 646 + StatusCode::OK 647 + ); 648 + } 649 + }; 619 650 620 651 let name = match &updated_record { 621 652 LexiconCommunityEvent::Current { name, .. } => name, ··· 625 656 let event_update_result = event_update_with_metadata( 626 657 &ctx.web_context.pool, 627 658 &lookup_aturi, 628 - &update_record_result.cid, 659 + &update_record_response.cid, 629 660 &updated_record, 630 661 name, 631 662 )
+86 -42
src/http/handle_import.rs
··· 9 9 use minijinja::context as template_context; 10 10 use serde::Deserialize; 11 11 12 + use crate::atproto::auth::{ 13 + create_dpop_auth_from_aip_session, create_dpop_auth_from_oauth_session, 14 + }; 15 + use crate::config::OAuthBackendConfig; 12 16 use crate::{ 13 - atproto::{ 14 - auth::SimpleOAuthSessionProvider, 15 - client::{ListRecordsParams, OAuthPdsClient}, 16 - lexicon::{ 17 - community::lexicon::calendar::{ 18 - event::{Event as LexiconCommunityEvent, NSID as LEXICON_COMMUNITY_EVENT_NSID}, 19 - rsvp::{ 20 - Rsvp as LexiconCommunityRsvp, RsvpStatus as LexiconCommunityRsvpStatus, 21 - NSID as LEXICON_COMMUNITY_RSVP_NSID, 22 - }, 17 + atproto::lexicon::{ 18 + community::lexicon::calendar::{ 19 + event::{Event as LexiconCommunityEvent, NSID as LEXICON_COMMUNITY_EVENT_NSID}, 20 + rsvp::{ 21 + Rsvp as LexiconCommunityRsvp, RsvpStatus as LexiconCommunityRsvpStatus, 22 + NSID as LEXICON_COMMUNITY_RSVP_NSID, 23 23 }, 24 - events::smokesignal::calendar::{ 25 - event::{Event as SmokeSignalEvent, NSID as SMOKESIGNAL_EVENT_NSID}, 26 - rsvp::{ 27 - Rsvp as SmokeSignalRsvp, RsvpStatus as SmokeSignalRsvpStatus, 28 - NSID as SMOKESIGNAL_RSVP_NSID, 29 - }, 24 + }, 25 + events::smokesignal::calendar::{ 26 + event::{Event as SmokeSignalEvent, NSID as SMOKESIGNAL_EVENT_NSID}, 27 + rsvp::{ 28 + Rsvp as SmokeSignalRsvp, RsvpStatus as SmokeSignalRsvpStatus, 29 + NSID as SMOKESIGNAL_RSVP_NSID, 30 30 }, 31 31 }, 32 32 }, 33 33 contextual_error, 34 34 http::{ 35 35 context::WebContext, 36 - errors::{ImportError, WebError}, 36 + errors::{CommonError, ImportError, WebError}, 37 37 middleware_auth::Auth, 38 38 middleware_i18n::Language, 39 39 }, 40 40 select_template, 41 41 storage::event::{event_insert_with_metadata, rsvp_insert_with_metadata}, 42 42 }; 43 + use atproto_client::com::atproto::repo::{list_records, ListRecordsParams}; 43 44 44 - pub async fn handle_import( 45 + pub(crate) async fn handle_import( 45 46 State(web_context): State<WebContext>, 46 47 Language(language): Language, 47 48 Cached(auth): Cached<Auth>, 48 49 HxRequest(hx_request): HxRequest, 49 50 HxBoosted(hx_boosted): HxBoosted, 50 51 ) -> Result<impl IntoResponse, WebError> { 51 - let current_handle = auth.require(&web_context.config.destination_key, "/import")?; 52 + let current_handle = auth.require(&web_context.config, "/import")?; 52 53 53 54 let default_context = template_context! { 54 55 current_handle, ··· 67 68 } 68 69 69 70 #[derive(Debug, Deserialize)] 70 - pub struct ImportForm { 71 - pub collection: Option<String>, 72 - pub cursor: Option<String>, 71 + pub(crate) struct ImportForm { 72 + pub(crate) collection: Option<String>, 73 + pub(crate) cursor: Option<String>, 73 74 } 74 75 75 - pub async fn handle_import_submit( 76 + pub(crate) async fn handle_import_submit( 76 77 State(web_context): State<WebContext>, 77 78 Language(language): Language, 78 79 Cached(auth): Cached<Auth>, ··· 98 99 let collection = import_form.collection.unwrap_or(collections[0].to_string()); 99 100 let cursor = import_form.cursor; 100 101 101 - let client_auth: SimpleOAuthSessionProvider = 102 - SimpleOAuthSessionProvider::try_from(auth.1.unwrap())?; 103 - let client = OAuthPdsClient { 104 - http_client: &web_context.http_client, 105 - pds: &current_handle.pds, 102 + // Create DPoP auth based on OAuth backend type 103 + let dpop_auth = match (&auth, &web_context.config.oauth_backend) { 104 + (Auth::Pds { session, .. }, OAuthBackendConfig::ATProtocol { .. }) => { 105 + create_dpop_auth_from_oauth_session(session)? 106 + } 107 + (Auth::Aip { access_token, .. }, OAuthBackendConfig::AIP { hostname, .. }) => { 108 + create_dpop_auth_from_aip_session(&web_context.http_client, hostname, access_token) 109 + .await? 110 + } 111 + _ => return Err(CommonError::NotAuthorized.into()), 106 112 }; 107 113 108 114 const LIMIT: u32 = 20; 109 115 110 116 // Set up list records parameters to fetch records 111 117 let list_params = ListRecordsParams { 112 - repo: current_handle.did.clone(), 113 - collection: collection.clone(), 114 118 limit: Some(LIMIT), 115 119 cursor, 116 120 reverse: None, ··· 118 122 119 123 let render_context = match collection.as_str() { 120 124 LEXICON_COMMUNITY_EVENT_NSID => { 121 - let results = client 122 - .list_records::<LexiconCommunityEvent>(&client_auth, &list_params) 123 - .await; 125 + let results = list_records::<LexiconCommunityEvent>( 126 + &web_context.http_client, 127 + &dpop_auth, 128 + &current_handle.pds, 129 + current_handle.did.clone(), 130 + collection.clone(), 131 + ListRecordsParams { 132 + limit: Some(LIMIT), 133 + cursor: list_params.cursor.clone(), 134 + reverse: None, 135 + }, 136 + ) 137 + .await; 124 138 match results { 125 139 Ok(list_records) => { 126 140 let mut items = vec![]; ··· 176 190 } 177 191 } 178 192 LEXICON_COMMUNITY_RSVP_NSID => { 179 - let results = client 180 - .list_records::<LexiconCommunityRsvp>(&client_auth, &list_params) 181 - .await; 193 + let results = list_records::<LexiconCommunityRsvp>( 194 + &web_context.http_client, 195 + &dpop_auth, 196 + &current_handle.pds, 197 + current_handle.did.clone(), 198 + collection.clone(), 199 + ListRecordsParams { 200 + limit: Some(LIMIT), 201 + cursor: list_params.cursor.clone(), 202 + reverse: None, 203 + }, 204 + ) 205 + .await; 182 206 match results { 183 207 Ok(list_records) => { 184 208 let mut items = vec![]; ··· 248 272 } 249 273 } 250 274 SMOKESIGNAL_EVENT_NSID => { 251 - let results = client 252 - .list_records::<SmokeSignalEvent>(&client_auth, &list_params) 253 - .await; 275 + let results = list_records::<SmokeSignalEvent>( 276 + &web_context.http_client, 277 + &dpop_auth, 278 + &current_handle.pds, 279 + current_handle.did.clone(), 280 + collection.clone(), 281 + ListRecordsParams { 282 + limit: Some(LIMIT), 283 + cursor: list_params.cursor.clone(), 284 + reverse: None, 285 + }, 286 + ) 287 + .await; 254 288 match results { 255 289 Ok(list_records) => { 256 290 let mut items = vec![]; ··· 306 340 } 307 341 } 308 342 SMOKESIGNAL_RSVP_NSID => { 309 - let results = client 310 - .list_records::<SmokeSignalRsvp>(&client_auth, &list_params) 311 - .await; 343 + let results = list_records::<SmokeSignalRsvp>( 344 + &web_context.http_client, 345 + &dpop_auth, 346 + &current_handle.pds, 347 + current_handle.did.clone(), 348 + collection.clone(), 349 + ListRecordsParams { 350 + limit: Some(LIMIT), 351 + cursor: list_params.cursor.clone(), 352 + reverse: None, 353 + }, 354 + ) 355 + .await; 312 356 match results { 313 357 Ok(list_records) => { 314 358 let mut items = vec![];
+3 -3
src/http/handle_index.rs
··· 46 46 } 47 47 } 48 48 49 - pub async fn handle_index( 49 + pub(crate) async fn handle_index( 50 50 State(web_context): State<WebContext>, 51 51 HxBoosted(hx_boosted): HxBoosted, 52 52 Language(language): Language, ··· 88 88 .filter_map(|event_view| { 89 89 let organizer_maybe = organizer_handlers.get(&event_view.event.did); 90 90 let event_view = 91 - EventView::try_from((auth.0.as_ref(), organizer_maybe, &event_view.event)); 91 + EventView::try_from((auth.profile(), organizer_maybe, &event_view.event)); 92 92 93 93 match event_view { 94 94 Ok(event_view) => Some(event_view), ··· 118 118 &render_template, 119 119 web_context.engine.clone(), 120 120 template_context! { 121 - current_handle => auth.0, 121 + current_handle => auth.profile(), 122 122 language => language.to_string(), 123 123 canonical_url => format!("https://{}/", web_context.config.external_base), 124 124 tab => tab.to_string(),
+73 -45
src/http/handle_migrate_event.rs
··· 10 10 use minijinja::context as template_context; 11 11 use std::collections::HashMap; 12 12 13 + use crate::atproto::auth::{ 14 + create_dpop_auth_from_aip_session, create_dpop_auth_from_oauth_session, 15 + }; 16 + use crate::config::OAuthBackendConfig; 13 17 use crate::{ 14 - atproto::{ 15 - auth::SimpleOAuthSessionProvider, 16 - client::{OAuthPdsClient, PutRecordRequest}, 17 - lexicon::{ 18 - community::lexicon::calendar::event::{ 19 - Event as CommunityEvent, EventLink, EventLocation as CommunityLocation, Mode, 20 - Status, NSID as COMMUNITY_NSID, 21 - }, 22 - community::lexicon::location, 23 - events::smokesignal::calendar::event::{ 24 - Event as SmokeSignalEvent, Location as SmokeSignalLocation, PlaceLocation, 25 - NSID as SMOKESIGNAL_NSID, 26 - }, 18 + atproto::lexicon::{ 19 + community::lexicon::calendar::event::{ 20 + Event as CommunityEvent, EventLink, EventLocation as CommunityLocation, Mode, Status, 21 + NSID as COMMUNITY_NSID, 22 + }, 23 + community::lexicon::location, 24 + events::smokesignal::calendar::event::{ 25 + Event as SmokeSignalEvent, Location as SmokeSignalLocation, PlaceLocation, 26 + NSID as SMOKESIGNAL_NSID, 27 27 }, 28 28 }, 29 29 contextual_error, ··· 31 31 context::WebContext, errors::MigrateEventError, errors::WebError, middleware_auth::Auth, 32 32 middleware_i18n::Language, utils::url_from_aturi, 33 33 }, 34 - resolve::{parse_input, InputType}, 35 34 select_template, 36 35 storage::{ 37 36 event::{event_get, event_insert_with_metadata}, 38 - handle::{handle_for_did, handle_for_handle, model::Handle}, 37 + identity_profile::{handle_for_did, handle_for_handle, model::IdentityProfile}, 39 38 }, 40 39 }; 40 + use atproto_client::com::atproto::repo::{put_record, PutRecordRequest, PutRecordResponse}; 41 41 42 - pub async fn handle_migrate_event( 42 + pub(crate) async fn handle_migrate_event( 43 43 State(web_context): State<WebContext>, 44 44 HxBoosted(hx_boosted): HxBoosted, 45 45 Language(language): Language, ··· 47 47 HxRequest(hx_request): HxRequest, 48 48 Path((handle_slug, event_rkey)): Path<(String, String)>, 49 49 ) -> Result<impl IntoResponse, WebError> { 50 - let current_handle = auth.require(&web_context.config.destination_key, "/")?; 50 + let current_handle = auth.require(&web_context.config, "/")?; 51 51 52 52 // Configure templates 53 53 let default_context = template_context! { ··· 59 59 let error_template = select_template!(hx_boosted, hx_request, language); 60 60 61 61 // Lookup the user handle/profile 62 - let profile: Result<Handle> = match parse_input(&handle_slug) { 63 - Ok(InputType::Handle(handle)) => handle_for_handle(&web_context.pool, &handle) 62 + let profile: Result<IdentityProfile> = if handle_slug.starts_with("did:") { 63 + handle_for_did(&web_context.pool, &handle_slug) 64 64 .await 65 - .map_err(|err| err.into()), 66 - Ok(InputType::Plc(did) | InputType::Web(did)) => handle_for_did(&web_context.pool, &did) 65 + .map_err(|err| err.into()) 66 + } else { 67 + let handle = if let Some(handle) = handle_slug.strip_prefix('@') { 68 + handle 69 + } else { 70 + &handle_slug 71 + }; 72 + handle_for_handle(&web_context.pool, handle) 67 73 .await 68 - .map_err(|err| err.into()), 69 - Err(err) => Err(err.into()), 74 + .map_err(|err| err.into()) 70 75 }; 71 76 72 77 if let Err(err) = profile { ··· 261 266 ); 262 267 } 263 268 264 - // Set up XRPC client 265 - // Error if we don't have auth data 266 - let auth_data = auth.1.ok_or(MigrateEventError::NotAuthorized)?; 267 - let client_auth: SimpleOAuthSessionProvider = SimpleOAuthSessionProvider::try_from(auth_data)?; 268 - 269 - let client = OAuthPdsClient { 270 - http_client: &web_context.http_client, 271 - pds: &current_handle.pds, 269 + // Set up authentication 270 + // Create DPoP auth based on OAuth backend type 271 + let dpop_auth = match (&auth, &web_context.config.oauth_backend) { 272 + (Auth::Pds { session, .. }, OAuthBackendConfig::ATProtocol { .. }) => { 273 + create_dpop_auth_from_oauth_session(session)? 274 + } 275 + (Auth::Aip { access_token, .. }, OAuthBackendConfig::AIP { hostname, .. }) => { 276 + create_dpop_auth_from_aip_session(&web_context.http_client, hostname, access_token) 277 + .await? 278 + } 279 + _ => return Err(MigrateEventError::NotAuthorized.into()), 272 280 }; 273 281 274 282 // Create the community event record in the user's PDS using putRecord to retain the same rkey ··· 283 291 }; 284 292 285 293 // Write to the PDS 286 - let update_record_result = client.put_record(&client_auth, update_record_request).await; 287 - if let Err(err) = update_record_result { 288 - return contextual_error!( 289 - web_context, 290 - language, 291 - error_template, 292 - default_context, 293 - err, 294 - StatusCode::OK 295 - ); 296 - } 297 - // update_record_result is guaranteed to be Ok at this point since we checked for Err above 298 - let update_record_result = update_record_result?; 294 + let update_record_result = put_record( 295 + &web_context.http_client, 296 + &dpop_auth, 297 + &current_handle.pds, 298 + update_record_request, 299 + ) 300 + .await; 301 + 302 + let update_record_response = match update_record_result { 303 + Ok(PutRecordResponse::StrongRef { uri, cid, .. }) => { 304 + crate::atproto::lexicon::com::atproto::repo::StrongRef { uri, cid } 305 + } 306 + Ok(PutRecordResponse::Error(err)) => { 307 + return contextual_error!( 308 + web_context, 309 + language, 310 + error_template, 311 + default_context, 312 + anyhow::anyhow!("Server error: {}", err.error_message()), 313 + StatusCode::OK 314 + ); 315 + } 316 + Err(err) => { 317 + return contextual_error!( 318 + web_context, 319 + language, 320 + error_template, 321 + default_context, 322 + err, 323 + StatusCode::OK 324 + ); 325 + } 326 + }; 299 327 300 328 // We already have the migrated AT-URI defined above 301 329 ··· 303 331 let migrated_event_insert_result = event_insert_with_metadata( 304 332 &web_context.pool, 305 333 &migrated_aturi, 306 - &update_record_result.cid, 334 + &update_record_response.cid, 307 335 &current_handle.did, 308 336 COMMUNITY_NSID, 309 337 &new_event,
+57 -32
src/http/handle_migrate_rsvp.rs
··· 10 10 use minijinja::context as template_context; 11 11 use std::hash::Hasher; 12 12 13 + use crate::atproto::auth::{ 14 + create_dpop_auth_from_aip_session, create_dpop_auth_from_oauth_session, 15 + }; 16 + use crate::config::OAuthBackendConfig; 13 17 use crate::{ 14 - atproto::{ 15 - auth::SimpleOAuthSessionProvider, 16 - client::{OAuthPdsClient, PutRecordRequest}, 17 - lexicon::{ 18 - com::atproto::repo::StrongRef, 19 - community::lexicon::calendar::rsvp::{Rsvp, RsvpStatus, NSID as RSVP_COLLECTION}, 20 - events::smokesignal::calendar::event::NSID as EVENT_COLLECTION, 21 - }, 18 + atproto::lexicon::{ 19 + com::atproto::repo::StrongRef, 20 + community::lexicon::calendar::rsvp::{Rsvp, RsvpStatus, NSID as RSVP_COLLECTION}, 21 + events::smokesignal::calendar::event::NSID as EVENT_COLLECTION, 22 22 }, 23 23 contextual_error, 24 24 http::{ ··· 27 27 middleware_auth::Auth, 28 28 middleware_i18n::Language, 29 29 }, 30 - resolve::{parse_input, InputType}, 31 30 select_template, 32 31 storage::{ 33 32 event::{event_get, get_user_rsvp, rsvp_insert}, 34 - handle::{handle_for_did, handle_for_handle, model::Handle}, 33 + identity_profile::{handle_for_did, handle_for_handle, model::IdentityProfile}, 35 34 }, 36 35 }; 36 + use atproto_client::com::atproto::repo::{put_record, PutRecordRequest, PutRecordResponse}; 37 37 38 38 /// Migrates a user's RSVP from a legacy event to a standard event format. 39 39 /// ··· 48 48 /// 3. Ensuring the user doesn't already have an RSVP for the standard event 49 49 /// 4. Creating a new RSVP record in the standard format 50 50 /// 5. Storing the RSVP both on the PDS and in the local database 51 - pub async fn handle_migrate_rsvp( 51 + pub(crate) async fn handle_migrate_rsvp( 52 52 State(web_context): State<WebContext>, 53 53 Language(language): Language, 54 54 Cached(auth): Cached<Auth>, ··· 56 56 ) -> Result<impl IntoResponse, WebError> { 57 57 // Require user to be logged in 58 58 let current_handle = auth.require( 59 - &web_context.config.destination_key, 59 + &web_context.config, 60 60 "/{handle_slug}/{event_rkey}/migrate-rsvp", 61 61 )?; 62 62 ··· 69 69 let error_template = select_template!(false, false, language); 70 70 71 71 // Get handle information from the path parameter 72 - let profile: Result<Handle> = match parse_input(&handle_slug) { 73 - Ok(InputType::Handle(handle)) => handle_for_handle(&web_context.pool, &handle) 72 + let profile: Result<IdentityProfile> = if handle_slug.starts_with("did:") { 73 + handle_for_did(&web_context.pool, &handle_slug) 74 74 .await 75 - .map_err(|err| err.into()), 76 - Ok(InputType::Plc(did) | InputType::Web(did)) => handle_for_did(&web_context.pool, &did) 75 + .map_err(|err| err.into()) 76 + } else { 77 + let handle = if let Some(handle) = handle_slug.strip_prefix('@') { 78 + handle 79 + } else { 80 + &handle_slug 81 + }; 82 + handle_for_handle(&web_context.pool, handle) 77 83 .await 78 - .map_err(|err| err.into()), 79 - Err(err) => Err(err.into()), 84 + .map_err(|err| err.into()) 80 85 }; 81 86 82 87 let profile = match profile { ··· 181 186 } 182 187 183 188 // Create a new RSVP for the standard event 184 - // Error if we don't have auth data 185 - let auth_data = auth.1.ok_or(MigrateRsvpError::NotAuthorized)?; 186 - let client_auth: SimpleOAuthSessionProvider = SimpleOAuthSessionProvider::try_from(auth_data)?; 187 - 188 - let client = OAuthPdsClient { 189 - http_client: &web_context.http_client, 190 - pds: &current_handle.pds, 189 + // Create DPoP auth based on OAuth backend type 190 + let dpop_auth = match (&auth, &web_context.config.oauth_backend) { 191 + (Auth::Pds { session, .. }, OAuthBackendConfig::ATProtocol { .. }) => { 192 + create_dpop_auth_from_oauth_session(session)? 193 + } 194 + (Auth::Aip { access_token, .. }, OAuthBackendConfig::AIP { hostname, .. }) => { 195 + create_dpop_auth_from_aip_session(&web_context.http_client, hostname, access_token) 196 + .await? 197 + } 198 + _ => return Err(MigrateRsvpError::NotAuthorized.into()), 191 199 }; 192 200 193 201 // Create a reference to the standard event that will be the subject of the RSVP ··· 242 250 swap_record: None, 243 251 }; 244 252 245 - let put_record_result = client.put_record(&client_auth, rsvp_record).await; 253 + let put_record_result = put_record( 254 + &web_context.http_client, 255 + &dpop_auth, 256 + &current_handle.pds, 257 + rsvp_record, 258 + ) 259 + .await; 246 260 247 - if let Err(err) = put_record_result { 248 - return contextual_error!(web_context, language, error_template, default_context, err); 249 - } 250 - 251 - // put_record_result is guaranteed to be Ok here since we checked for Err above 252 - let create_record_result = put_record_result?; 261 + let create_record_result = match put_record_result { 262 + Ok(PutRecordResponse::StrongRef { uri, cid, .. }) => { 263 + crate::atproto::lexicon::com::atproto::repo::StrongRef { uri, cid } 264 + } 265 + Ok(PutRecordResponse::Error(err)) => { 266 + return contextual_error!( 267 + web_context, 268 + language, 269 + error_template, 270 + default_context, 271 + anyhow::anyhow!("Server error: {}", err.error_message()) 272 + ); 273 + } 274 + Err(err) => { 275 + return contextual_error!(web_context, language, error_template, default_context, err); 276 + } 277 + }; 253 278 254 279 // Store the new RSVP in the database 255 280 let rsvp_insert_result = rsvp_insert(
+149
src/http/handle_oauth_aip_callback.rs
··· 1 + use crate::{config::OAuthBackendConfig, contextual_error, select_template}; 2 + use anyhow::Result; 3 + use axum::{ 4 + extract::State, 5 + response::{IntoResponse, Redirect}, 6 + }; 7 + use axum_extra::extract::{ 8 + cookie::{Cookie, SameSite}, 9 + Form, PrivateCookieJar, 10 + }; 11 + use minijinja::context as template_context; 12 + use serde::{Deserialize, Serialize}; 13 + 14 + use super::{ 15 + context::WebContext, 16 + errors::{LoginError, WebError}, 17 + middleware_auth::{WebSession, AUTH_COOKIE_NAME}, 18 + middleware_i18n::Language, 19 + }; 20 + 21 + #[derive(Deserialize, Serialize)] 22 + pub(crate) struct OAuthCallbackForm { 23 + pub(crate) state: Option<String>, 24 + pub(crate) code: Option<String>, 25 + } 26 + 27 + pub(crate) async fn handle_oauth_callback( 28 + State(web_context): State<WebContext>, 29 + Language(language): Language, 30 + jar: PrivateCookieJar, 31 + Form(callback_form): Form<OAuthCallbackForm>, 32 + ) -> Result<impl IntoResponse, WebError> { 33 + let default_context = template_context! { 34 + language => language.to_string(), 35 + canonical_url => format!("https://{}/oauth/callback", web_context.config.external_base), 36 + }; 37 + 38 + let error_template = select_template!(false, false, language); 39 + 40 + // Get AIP server configuration - config validation ensures these are set when oauth_backend is AIP 41 + let (hostname, client_id, client_secret) = if let OAuthBackendConfig::AIP { 42 + hostname, 43 + client_id, 44 + client_secret, 45 + } = &web_context.config.oauth_backend 46 + { 47 + (hostname, client_id, client_secret) 48 + } else { 49 + unreachable!("AIP OAuth backend should have AIP configuration") 50 + }; 51 + 52 + let (callback_code, callback_state) = match (callback_form.code, callback_form.state) { 53 + (Some(x), Some(z)) => (x, z), 54 + _ => { 55 + return contextual_error!( 56 + web_context, 57 + language, 58 + error_template, 59 + default_context, 60 + LoginError::OAuthCallbackIncomplete 61 + ); 62 + } 63 + }; 64 + 65 + let oauth_request = match web_context 66 + .oauth_storage 67 + .get_oauth_request_by_state(&callback_state) 68 + .await 69 + { 70 + Err(err) => { 71 + return contextual_error!(web_context, language, error_template, default_context, err); 72 + } 73 + Ok(None) => { 74 + return contextual_error!( 75 + web_context, 76 + language, 77 + error_template, 78 + default_context, 79 + anyhow::anyhow!("oauth request not found in storage") 80 + ); 81 + } 82 + Ok(Some(value)) => value, 83 + }; 84 + 85 + let authorization_server = match atproto_oauth_aip::resources::oauth_authorization_server( 86 + &web_context.http_client, 87 + hostname, 88 + ) 89 + .await 90 + { 91 + Ok(value) => value, 92 + Err(err) => { 93 + return contextual_error!(web_context, language, error_template, default_context, err); 94 + } 95 + }; 96 + 97 + let oauth_client = atproto_oauth_aip::workflow::OAuthClient { 98 + redirect_uri: format!( 99 + "https://{}/oauth/callback", 100 + &web_context.config.external_base 101 + ), 102 + client_id: client_id.clone(), 103 + client_secret: client_secret.clone(), 104 + }; 105 + 106 + let token_response = atproto_oauth_aip::workflow::oauth_complete( 107 + &web_context.http_client, 108 + &oauth_client, 109 + &authorization_server, 110 + &callback_code, 111 + &oauth_request, 112 + ) 113 + .await; 114 + 115 + if let Err(err) = token_response { 116 + tracing::error!( 117 + ?err, 118 + "atproto_oauth_aip::workflow::oautoauth_completeh_init" 119 + ); 120 + return contextual_error!(web_context, language, error_template, default_context, err); 121 + } 122 + 123 + let token_response = token_response.unwrap(); 124 + 125 + let cookie_value: String = WebSession::Aip { 126 + did: oauth_request.did.clone(), 127 + access_token: token_response.access_token.clone(), 128 + } 129 + .try_into()?; 130 + 131 + let mut cookie = Cookie::new(AUTH_COOKIE_NAME, cookie_value); 132 + cookie.set_domain(web_context.config.external_base.clone()); 133 + cookie.set_path("/"); 134 + cookie.set_http_only(true); 135 + cookie.set_secure(true); 136 + cookie.set_max_age(Some(cookie::time::Duration::seconds( 137 + (token_response.expires_in as i64) - 60, 138 + ))); 139 + cookie.set_same_site(Some(SameSite::Lax)); 140 + 141 + let updated_jar = jar.add(cookie); 142 + 143 + // let destination = match oauth_request.destination { 144 + // Some(destination) => destination, 145 + // None => "/".to_string(), 146 + // }; 147 + 148 + Ok((updated_jar, Redirect::to("/")).into_response()) 149 + }
+281
src/http/handle_oauth_aip_login.rs
··· 1 + use anyhow::{anyhow, Result}; 2 + use atproto_identity::resolve::IdentityResolver; 3 + use atproto_oauth::pkce::generate; 4 + use atproto_oauth::workflow::OAuthRequestState as AipOAuthRequestState; 5 + use atproto_oauth_aip::{ 6 + resources::oauth_authorization_server, 7 + workflow::{oauth_init, OAuthClient}, 8 + }; 9 + use axum::response::Redirect; 10 + use axum::{extract::State, response::IntoResponse}; 11 + use axum_extra::extract::{Cached, Form, Query}; 12 + use axum_htmx::{HxBoosted, HxRedirect, HxRequest}; 13 + use axum_template::RenderHtml; 14 + use http::StatusCode; 15 + use minijinja::context as template_context; 16 + use rand::{distributions::Alphanumeric, Rng}; 17 + use serde::Deserialize; 18 + 19 + use crate::{ 20 + config::OAuthBackendConfig, 21 + contextual_error, 22 + http::{ 23 + context::WebContext, 24 + errors::{LoginError, WebError}, 25 + middleware_auth::Auth, 26 + middleware_i18n::Language, 27 + utils::stringify, 28 + }, 29 + select_template, 30 + storage::{denylist::denylist_exists, identity_profile::handle_warm_up}, 31 + }; 32 + 33 + #[derive(Deserialize)] 34 + pub(crate) struct OAuthLoginForm { 35 + handle: Option<String>, 36 + } 37 + 38 + #[derive(Deserialize)] 39 + pub(crate) struct Destination { 40 + destination: Option<String>, 41 + } 42 + 43 + #[allow(clippy::too_many_arguments)] 44 + pub(crate) async fn handle_oauth_aip_login( 45 + State(web_context): State<WebContext>, 46 + Language(language): Language, 47 + Cached(auth): Cached<Auth>, 48 + identity_resolver: IdentityResolver, 49 + HxRequest(hx_request): HxRequest, 50 + HxBoosted(hx_boosted): HxBoosted, 51 + Query(destination): Query<Destination>, 52 + Form(login_form): Form<OAuthLoginForm>, 53 + ) -> Result<impl IntoResponse, WebError> { 54 + let default_context = template_context! { 55 + current_handle => auth.profile(), 56 + language => language.to_string(), 57 + canonical_url => format!("https://{}/oauth/login", web_context.config.external_base), 58 + destination => destination.destination, 59 + }; 60 + 61 + let render_template = select_template!("login", hx_boosted, hx_request, language); 62 + let error_template = select_template!(hx_boosted, hx_request, language); 63 + 64 + if let Some(subject) = login_form.handle { 65 + let handle_denied = denylist_exists(&web_context.pool, &[subject.as_str()]) 66 + .await 67 + .unwrap_or(true); 68 + if handle_denied { 69 + return contextual_error!( 70 + web_context, 71 + language, 72 + error_template, 73 + default_context, 74 + anyhow!("access-denied") 75 + ); 76 + } 77 + 78 + let document = match identity_resolver.resolve(&subject).await { 79 + Ok(value) => value, 80 + Err(err) => { 81 + return contextual_error!( 82 + web_context, 83 + language, 84 + error_template, 85 + default_context, 86 + err 87 + ); 88 + } 89 + }; 90 + 91 + let handle = match document 92 + .handles() 93 + .ok_or(WebError::Login(LoginError::NoHandle)) 94 + { 95 + Ok(value) => value, 96 + Err(err) => { 97 + tracing::error!(?err, "handles"); 98 + return contextual_error!( 99 + web_context, 100 + language, 101 + error_template, 102 + default_context, 103 + err 104 + ); 105 + } 106 + }; 107 + 108 + let pds = match document 109 + .pds_endpoints() 110 + .first() 111 + .cloned() 112 + .ok_or(WebError::Login(LoginError::NoPDS)) 113 + { 114 + Ok(value) => value, 115 + Err(err) => { 116 + tracing::error!(?err, "pds_endpoints first"); 117 + return contextual_error!( 118 + web_context, 119 + language, 120 + error_template, 121 + default_context, 122 + err 123 + ); 124 + } 125 + }; 126 + 127 + if let Err(err) = web_context 128 + .document_storage 129 + .store_document(document.clone()) 130 + .await 131 + { 132 + tracing::error!(?err, "store_document"); 133 + return contextual_error!(web_context, language, error_template, default_context, err); 134 + } 135 + 136 + // Insert the handle if it doesn't exist 137 + if let Err(err) = handle_warm_up(&web_context.pool, &document.id, handle, pds).await { 138 + tracing::error!(?err, "handle_warm_up"); 139 + return contextual_error!(web_context, language, error_template, default_context, err); 140 + } 141 + 142 + // Generate OAuth parameters 143 + let state: String = rand::thread_rng() 144 + .sample_iter(&Alphanumeric) 145 + .take(30) 146 + .map(char::from) 147 + .collect(); 148 + let nonce: String = rand::thread_rng() 149 + .sample_iter(&Alphanumeric) 150 + .take(30) 151 + .map(char::from) 152 + .collect(); 153 + 154 + let (pkce_verifier, code_challenge) = generate(); 155 + 156 + // Create AIP-specific OAuth request state with scope 157 + let aip_oauth_request_state = AipOAuthRequestState { 158 + state: state.clone(), 159 + nonce: nonce.clone(), 160 + code_challenge, 161 + scope: "atproto:atproto atproto:transition:generic".to_string(), 162 + }; 163 + 164 + // Get AIP server configuration - config validation ensures these are set when oauth_backend is AIP 165 + let (hostname, client_id, client_secret) = if let OAuthBackendConfig::AIP { 166 + hostname, 167 + client_id, 168 + client_secret, 169 + } = &web_context.config.oauth_backend 170 + { 171 + (hostname, client_id, client_secret) 172 + } else { 173 + unreachable!("AIP OAuth backend should have AIP configuration") 174 + }; 175 + 176 + // Get AIP authorization server metadata 177 + let authorization_server = 178 + match oauth_authorization_server(&web_context.http_client, hostname).await { 179 + Ok(value) => value, 180 + Err(err) => { 181 + tracing::error!(?err, "oauth_authorization_server"); 182 + return contextual_error!( 183 + web_context, 184 + language, 185 + error_template, 186 + default_context, 187 + err 188 + ); 189 + } 190 + }; 191 + 192 + // Create AIP OAuth client 193 + let oauth_client = OAuthClient { 194 + redirect_uri: format!( 195 + "https://{}/oauth/callback", 196 + web_context.config.external_base 197 + ), 198 + client_id: client_id.clone(), 199 + client_secret: client_secret.clone(), 200 + }; 201 + tracing::info!(oauth_client.redirect_uri, "oauth_client"); 202 + 203 + // Initialize AIP OAuth flow 204 + let par_response = oauth_init( 205 + &web_context.http_client, 206 + &oauth_client, 207 + Some(&subject), 208 + &authorization_server, 209 + &aip_oauth_request_state, 210 + ) 211 + .await; 212 + 213 + if let Err(err) = par_response { 214 + tracing::error!(?err, "oauth_init"); 215 + return contextual_error!(web_context, language, error_template, default_context, err); 216 + } 217 + 218 + let par_response = par_response.unwrap(); 219 + 220 + let created_at = chrono::Utc::now(); 221 + let expires_at = created_at + chrono::Duration::seconds(par_response.expires_in as i64); 222 + 223 + let oauth_request = atproto_oauth::workflow::OAuthRequest { 224 + oauth_state: state.clone(), 225 + nonce: nonce.clone(), 226 + pkce_verifier: pkce_verifier.clone(), 227 + 228 + issuer: authorization_server.issuer.clone(), 229 + did: document.id, 230 + 231 + signing_public_key: "".to_string(), 232 + dpop_private_key: "".to_string(), 233 + created_at, 234 + expires_at, 235 + }; 236 + 237 + if let Err(err) = web_context 238 + .oauth_storage 239 + .insert_oauth_request(oauth_request) 240 + .await 241 + { 242 + tracing::error!(?err, "insert_oauth_request"); 243 + return contextual_error!(web_context, language, error_template, default_context, err); 244 + } 245 + 246 + let oauth_args = [ 247 + ( 248 + "request_uri".to_string(), 249 + urlencoding::encode(&par_response.request_uri).to_string(), 250 + ), 251 + ( 252 + "client_id".to_string(), 253 + urlencoding::encode(client_id).to_string(), 254 + ), 255 + ]; 256 + let oauth_args = oauth_args.iter().map(|(k, v)| (&**k, &**v)).collect(); 257 + 258 + let destination = format!( 259 + "{}?{}", 260 + authorization_server.authorization_endpoint, 261 + stringify(oauth_args) 262 + ); 263 + 264 + if hx_request { 265 + if let Ok(hx_redirect) = HxRedirect::try_from(destination.as_str()) { 266 + return Ok((StatusCode::OK, hx_redirect, "").into_response()); 267 + } 268 + } 269 + 270 + return Ok(Redirect::temporary(destination.as_str()).into_response()); 271 + } 272 + 273 + Ok(RenderHtml( 274 + &render_template, 275 + web_context.engine.clone(), 276 + template_context! { ..default_context, ..template_context! { 277 + destination => destination.destination, 278 + }}, 279 + ) 280 + .into_response()) 281 + }
+98 -100
src/http/handle_oauth_callback.rs
··· 1 - use anyhow::Result; 1 + use anyhow::{anyhow, Result}; 2 + use atproto_identity::{axum::state::KeyProviderExtractor, key::identify_key}; 3 + use atproto_oauth::workflow::{oauth_complete, OAuthClient}; 2 4 use axum::{ 3 5 extract::State, 4 6 response::{IntoResponse, Redirect}, ··· 7 9 cookie::{Cookie, SameSite}, 8 10 Form, PrivateCookieJar, 9 11 }; 10 - use deadpool_redis::redis::AsyncCommands as _; 11 12 use minijinja::context as template_context; 12 - use p256::SecretKey; 13 13 use serde::{Deserialize, Serialize}; 14 - use std::borrow::Cow; 15 14 16 - use crate::jose_errors::JwkError; 17 - use crate::storage::errors::CacheError; 18 - 19 - use crate::{ 20 - contextual_error, 21 - oauth::oauth_complete, 22 - select_template, 23 - storage::{ 24 - cache::OAUTH_REFRESH_QUEUE, 25 - handle::handle_for_did, 26 - oauth::{oauth_request_get, oauth_request_remove, oauth_session_insert}, 27 - }, 28 - }; 15 + use crate::{contextual_error, select_template}; 29 16 30 17 use super::{ 31 18 context::WebContext, ··· 35 22 }; 36 23 37 24 #[derive(Deserialize, Serialize)] 38 - pub struct OAuthCallbackForm { 39 - pub state: Option<String>, 40 - pub iss: Option<String>, 41 - pub code: Option<String>, 25 + pub(crate) struct OAuthCallbackForm { 26 + state: Option<String>, 27 + iss: Option<String>, 28 + code: Option<String>, 42 29 } 43 30 44 - pub async fn handle_oauth_callback( 31 + #[axum::debug_handler] 32 + pub(crate) async fn handle_oauth_callback( 45 33 State(web_context): State<WebContext>, 34 + key_provider: KeyProviderExtractor, 46 35 Language(language): Language, 47 36 jar: PrivateCookieJar, 48 37 Form(callback_form): Form<OAuthCallbackForm>, ··· 68 57 } 69 58 }; 70 59 71 - let oauth_request = oauth_request_get(&web_context.pool, &callback_state).await; 72 - if let Err(err) = oauth_request { 73 - return contextual_error!(web_context, language, error_template, default_context, err); 74 - } 60 + let oauth_request = web_context 61 + .oauth_storage 62 + .get_oauth_request_by_state(&callback_state) 63 + .await; 75 64 76 - let oauth_request = oauth_request.unwrap(); 65 + let oauth_request = match oauth_request { 66 + Err(err) => { 67 + return contextual_error!(web_context, language, error_template, default_context, err); 68 + } 69 + Ok(None) => { 70 + return contextual_error!( 71 + web_context, 72 + language, 73 + error_template, 74 + default_context, 75 + anyhow!("oauth request not found in storage") 76 + ); 77 + } 78 + Ok(Some(value)) => value, 79 + }; 77 80 78 81 if oauth_request.issuer != callback_iss { 79 82 return contextual_error!( ··· 85 88 ); 86 89 } 87 90 88 - let handle = handle_for_did(&web_context.pool, &oauth_request.did).await; 89 - if let Err(err) = handle { 90 - return contextual_error!(web_context, language, error_template, default_context, err); 91 - } 92 - 93 - let handle = handle.unwrap(); 91 + let document = match web_context 92 + .document_storage 93 + .get_document_by_did(&oauth_request.did) 94 + .await 95 + { 96 + Err(err) => { 97 + return contextual_error!(web_context, language, error_template, default_context, err); 98 + } 99 + Ok(None) => { 100 + return contextual_error!( 101 + web_context, 102 + language, 103 + error_template, 104 + default_context, 105 + anyhow!("identity did document not found in storage") 106 + ); 107 + } 108 + Ok(Some(value)) => value, 109 + }; 94 110 95 - let secret_signing_key = web_context 96 - .config 97 - .signing_keys 98 - .as_ref() 99 - .get(&oauth_request.secret_jwk_id) 100 - .cloned() 101 - .ok_or(JwkError::SecretKeyNotFound); 111 + let secret_signing_key = key_provider 112 + .0 113 + .get_private_key_by_id(&oauth_request.signing_public_key) 114 + .await; 102 115 103 - if let Err(err) = secret_signing_key { 104 - return contextual_error!(web_context, language, error_template, default_context, err); 105 - } 106 - let secret_signing_key = secret_signing_key.unwrap(); 116 + let secret_key_data = match secret_signing_key { 117 + Ok(Some(value)) => value, 118 + Ok(None) => { 119 + return contextual_error!( 120 + web_context, 121 + language, 122 + error_template, 123 + default_context, 124 + LoginError::OAuthCallbackIncomplete 125 + ); 126 + } 127 + Err(err) => { 128 + return contextual_error!(web_context, language, error_template, default_context, err); 129 + } 130 + }; 107 131 108 - let dpop_secret_key = SecretKey::from_jwk(&oauth_request.dpop_jwk.jwk); 132 + let dpop_key_data = match identify_key(&oauth_request.dpop_private_key) { 133 + Ok(value) => value, 134 + Err(err) => { 135 + return contextual_error!(web_context, language, error_template, default_context, err); 136 + } 137 + }; 109 138 110 - if let Err(err) = dpop_secret_key { 111 - return contextual_error!(web_context, language, error_template, default_context, err); 112 - } 113 - let dpop_secret_key = dpop_secret_key.unwrap(); 139 + let oauth_client = OAuthClient { 140 + redirect_uri: format!( 141 + "https://{}/oauth/callback", 142 + &web_context.config.external_base 143 + ), 144 + client_id: format!( 145 + "https://{}/oauth/client-metadata.json", 146 + &web_context.config.external_base 147 + ), 148 + private_signing_key_data: secret_key_data, 149 + }; 114 150 115 151 let token_response = oauth_complete( 116 152 &web_context.http_client, 117 - &web_context.config.external_base, 118 - (&oauth_request.secret_jwk_id, secret_signing_key), 153 + &oauth_client, 154 + &dpop_key_data, 119 155 &callback_code, 120 156 &oauth_request, 121 - &handle, 122 - &dpop_secret_key, 157 + &document, 123 158 ) 124 159 .await; 125 160 if let Err(err) = token_response { ··· 128 163 129 164 let token_response = token_response.unwrap(); 130 165 131 - if let Err(err) = oauth_request_remove(&web_context.pool, &oauth_request.oauth_state).await { 166 + if let Err(err) = web_context 167 + .oauth_storage 168 + .delete_oauth_request_by_state(&callback_state) 169 + .await 170 + { 132 171 tracing::error!(error = ?err, "Unable to remove oauth_request"); 133 172 } 134 173 135 - let session_group = ulid::Ulid::new().to_string(); 136 - let now = chrono::Utc::now(); 137 - 138 - if let Err(err) = oauth_session_insert( 139 - &web_context.pool, 140 - crate::storage::oauth::OAuthSessionParams { 141 - session_group: Cow::Owned(session_group.clone()), 142 - access_token: Cow::Owned(token_response.access_token.clone()), 143 - did: Cow::Owned(token_response.sub.clone()), 144 - issuer: Cow::Owned(oauth_request.issuer.clone()), 145 - refresh_token: Cow::Owned(token_response.refresh_token.clone()), 146 - secret_jwk_id: Cow::Owned(oauth_request.secret_jwk_id.clone()), 147 - dpop_jwk: oauth_request.dpop_jwk.0.clone(), 148 - created_at: now, 149 - access_token_expires_at: now 150 - + chrono::Duration::seconds(token_response.expires_in as i64), 151 - }, 152 - ) 153 - .await 154 - { 155 - return contextual_error!(web_context, language, error_template, default_context, err); 156 - } 157 - 158 - { 159 - let mut conn = web_context 160 - .cache_pool 161 - .get() 162 - .await 163 - .map_err(CacheError::FailedToGetConnection)?; 164 - 165 - let modified_expires_at = ((token_response.expires_in as f64) * 0.8).round() as i64; 166 - let refresh_at = (now + chrono::Duration::seconds(modified_expires_at)).timestamp_millis(); 167 - 168 - let _: () = conn 169 - .zadd(OAUTH_REFRESH_QUEUE, &session_group, refresh_at) 170 - .await 171 - .map_err(CacheError::FailedToPlaceInRefreshQueue)?; 172 - } 173 - 174 - let cookie_value: String = WebSession { 175 - did: token_response.sub.clone(), 176 - session_group: session_group.clone(), 174 + // For standard OAuth, create a PDS session 175 + let cookie_value: String = WebSession::Pds { 176 + did: token_response.sub.clone().unwrap(), 177 + session_group: "".to_string(), // Simplified for initial pass 177 178 } 178 179 .try_into()?; 179 180 ··· 187 188 188 189 let updated_jar = jar.add(cookie); 189 190 190 - let destination = match oauth_request.destination { 191 - Some(destination) => destination, 192 - None => "/".to_string(), 193 - }; 191 + let destination = "/".to_string(); // Simplified for initial pass 194 192 195 193 Ok((updated_jar, Redirect::to(&destination)).into_response()) 196 194 }
-94
src/http/handle_oauth_jwks.rs
··· 1 - use anyhow::Result; 2 - use axum::{ 3 - extract::State, 4 - http::{header, HeaderValue}, 5 - response::{IntoResponse, Response}, 6 - }; 7 - use std::sync::Arc; 8 - 9 - use crate::http::{context::WebContext, errors::WebError}; 10 - use crate::jose::jwk::{WrappedJsonWebKey, WrappedJsonWebKeySet}; 11 - 12 - // Function to compute JWKS data and serialize to JSON string 13 - fn compute_jwks_json(web_context: &WebContext) -> Result<String, serde_json::Error> { 14 - let mut keys = vec![]; 15 - let signing_keys = web_context.config.signing_keys.as_ref(); 16 - 17 - for available_signing_key in web_context.config.oauth_active_keys.as_ref() { 18 - let available_signing_key = available_signing_key.clone(); 19 - 20 - let signing_key = match signing_keys.get(&available_signing_key) { 21 - Some(key) => key.clone(), 22 - None => continue, 23 - }; 24 - let public_key = signing_key.public_key(); 25 - 26 - let wrapped_json_web_key = WrappedJsonWebKey { 27 - jwk: public_key.to_jwk(), 28 - kid: Some(available_signing_key.clone()), 29 - alg: Some("ES256".to_string()), 30 - }; 31 - 32 - keys.push(wrapped_json_web_key); 33 - } 34 - 35 - let jwks = WrappedJsonWebKeySet { keys }; 36 - serde_json::to_string(&jwks) 37 - } 38 - 39 - // Global cache for the pre-serialized JSON string 40 - static JWKS_JSON_CACHE: once_cell::sync::OnceCell<Arc<String>> = once_cell::sync::OnceCell::new(); 41 - 42 - #[tracing::instrument(skip_all, err)] 43 - pub async fn handle_oauth_jwks( 44 - State(web_context): State<WebContext>, 45 - ) -> Result<impl IntoResponse, WebError> { 46 - tracing::debug!("handle_oauth_jwks"); 47 - 48 - // Initialize the cache if needed 49 - if JWKS_JSON_CACHE.get().is_none() { 50 - // Compute and serialize the JWKS data 51 - let jwks_json = compute_jwks_json(&web_context) 52 - .map_err(|e| anyhow::anyhow!("error-oauth-jwks-1 Failed to serialize JWKS: {}", e))?; 53 - 54 - // Store in cache - don't worry if another thread beat us to it 55 - let _ = JWKS_JSON_CACHE.set(Arc::new(jwks_json)); 56 - } 57 - 58 - // By this point, the cache should be initialized - either by us or another thread 59 - // In the extremely unlikely event it's still not initialized, we'll create it one more time 60 - let jwks_json = if let Some(json) = JWKS_JSON_CACHE.get() { 61 - json 62 - } else { 63 - // Final attempt to compute and cache 64 - let jwks_json = compute_jwks_json(&web_context) 65 - .map_err(|e| anyhow::anyhow!("error-oauth-jwks-1 Failed to serialize JWKS: {}", e))?; 66 - 67 - // Create a new Arc and set it in the cache 68 - let json_arc = Arc::new(jwks_json); 69 - 70 - // This will either succeed in setting it, or another thread beat us to it 71 - if JWKS_JSON_CACHE.set(json_arc.clone()).is_err() { 72 - // Another thread set it first, so use that value 73 - JWKS_JSON_CACHE.get().ok_or_else(|| { 74 - anyhow::anyhow!("error-oauth-jwks-2 Failed to initialize JWKS cache") 75 - })? 76 - } else { 77 - // We set it, so we can use our local copy 78 - JWKS_JSON_CACHE.get().ok_or_else(|| { 79 - anyhow::anyhow!("error-oauth-jwks-2 Failed to initialize JWKS cache") 80 - })? 81 - } 82 - }; 83 - 84 - // Create response with proper content type 85 - let mut response = Response::new((**jwks_json).clone()); 86 - 87 - // Set content type to application/json 88 - response.headers_mut().insert( 89 - header::CONTENT_TYPE, 90 - HeaderValue::from_static("application/json"), 91 - ); 92 - 93 - Ok(response) 94 - }
+116 -119
src/http/handle_oauth_login.rs
··· 1 1 use anyhow::Result; 2 - use axum::response::Redirect; 2 + use atproto_identity::{ 3 + key::{generate_key, identify_key, KeyType}, 4 + resolve::IdentityResolver, 5 + }; 6 + use atproto_oauth::{ 7 + pkce::generate, 8 + resources::pds_resources, 9 + workflow::{oauth_init, OAuthClient, OAuthRequest, OAuthRequestState}, 10 + }; 3 11 use axum::{extract::State, response::IntoResponse}; 4 12 use axum_extra::extract::{Cached, Form, Query}; 5 13 use axum_htmx::{HxBoosted, HxRedirect, HxRequest}; 6 14 use axum_template::RenderHtml; 7 - use base64::{engine::general_purpose, Engine as _}; 8 15 use http::StatusCode; 9 16 use minijinja::context as template_context; 10 - use p256::SecretKey; 11 17 use rand::{distributions::Alphanumeric, Rng}; 12 18 use serde::Deserialize; 13 - use sha2::{Digest, Sha256}; 14 - use std::borrow::Cow; 15 19 16 20 use crate::{ 17 21 contextual_error, 18 - did::{plc::query as plc_query, web::query as web_query}, 19 22 http::{ 20 23 context::WebContext, errors::LoginError, errors::WebError, middleware_auth::Auth, 21 24 middleware_i18n::Language, utils::stringify, 22 25 }, 23 - jose, 24 - oauth::{oauth_init, pds_resources}, 25 - resolve::{parse_input, resolve_subject, InputType}, 26 26 select_template, 27 - storage::{ 28 - denylist::denylist_exists, 29 - handle::handle_warm_up, 30 - oauth::{model::OAuthRequestState, oauth_request_insert}, 31 - }, 27 + storage::{denylist::denylist_exists, identity_profile::handle_warm_up}, 32 28 }; 33 29 34 30 #[derive(Deserialize)] 35 - pub struct OAuthLoginForm { 36 - pub handle: Option<String>, 37 - pub destination: Option<String>, 31 + pub(crate) struct OAuthLoginForm { 32 + handle: Option<String>, 38 33 } 39 34 40 35 #[derive(Deserialize)] 41 - pub struct Destination { 42 - pub destination: Option<String>, 36 + pub(crate) struct Destination { 37 + destination: Option<String>, 43 38 } 44 39 45 - pub async fn handle_oauth_login( 40 + #[allow(clippy::too_many_arguments)] 41 + pub(crate) async fn handle_oauth_login( 46 42 State(web_context): State<WebContext>, 43 + identity_resolver: IdentityResolver, 47 44 Language(language): Language, 48 45 Cached(auth): Cached<Auth>, 49 46 HxRequest(hx_request): HxRequest, ··· 52 49 Form(login_form): Form<OAuthLoginForm>, 53 50 ) -> Result<impl IntoResponse, WebError> { 54 51 let default_context = template_context! { 55 - current_handle => auth.0, 52 + current_handle => auth.profile(), 56 53 language => language.to_string(), 57 54 canonical_url => format!("https://{}/oauth/login", web_context.config.external_base), 58 55 destination => destination.destination, ··· 62 59 let error_template = select_template!(hx_boosted, hx_request, language); 63 60 64 61 if let Some(subject) = login_form.handle { 65 - let resolved_did = resolve_subject( 66 - &web_context.http_client, 67 - &web_context.dns_resolver, 68 - &subject, 69 - ) 70 - .await; 71 - 72 - if let Err(err) = resolved_did { 73 - return contextual_error!( 74 - web_context, 75 - language, 76 - render_template, 77 - template_context! { ..default_context, ..template_context! { 78 - handle_error => true, 79 - handle_input => subject, 80 - }}, 81 - err 82 - ); 83 - } 84 - 85 - let resolved_did = resolved_did.unwrap(); 86 - 87 - let query_results = match parse_input(&resolved_did) { 88 - Ok(InputType::Plc(did)) => { 89 - plc_query( 90 - &web_context.http_client, 91 - &web_context.config.plc_hostname, 92 - &did, 93 - ) 94 - .await 95 - } 96 - Ok(InputType::Web(did)) => web_query(&web_context.http_client, &did).await, 97 - _ => Err(LoginError::NoHandle.into()), 98 - }; 99 - 100 - let did_document = match query_results { 62 + let did_document = match identity_resolver.resolve(&subject).await { 101 63 Ok(value) => value, 102 64 Err(err) => { 103 65 return contextual_error!( ··· 113 75 } 114 76 }; 115 77 116 - let mut lookup_values: Vec<&str> = vec![&resolved_did, &did_document.id]; 117 - if let Some(pds) = did_document.pds_endpoint() { 78 + let mut lookup_values: Vec<&str> = vec![&did_document.id, &subject]; 79 + if let Some(pds) = did_document.pds_endpoints().first() { 118 80 lookup_values.push(pds); 119 81 } 120 82 ··· 144 106 ); 145 107 } 146 108 147 - let pds = match did_document.pds_endpoint() { 109 + let pds = match did_document.pds_endpoints().first().cloned() { 148 110 Some(value) => value, 149 111 None => { 150 112 return contextual_error!( ··· 160 122 } 161 123 }; 162 124 163 - let primary_handle = match did_document.primary_handle() { 125 + let primary_handle = match did_document.handles() { 164 126 Some(value) => value, 165 127 None => { 166 128 return contextual_error!( ··· 192 154 .take(30) 193 155 .map(char::from) 194 156 .collect(); 195 - let (pkce_verifier, code_challenge) = gen_pkce(); 157 + let (pkce_verifier, code_challenge) = generate(); 196 158 197 159 let oauth_request_state = OAuthRequestState { 198 160 state, 199 161 nonce, 200 162 code_challenge, 163 + scope: "atproto transition:generic transition:email".to_string(), 201 164 }; 202 165 203 - let pds_auth_resources = pds_resources(&web_context.http_client, pds).await; 166 + let authorization_server = match pds_resources(&web_context.http_client, pds).await { 167 + Ok(value) => value.1, 168 + Err(err) => { 169 + return contextual_error!( 170 + web_context, 171 + language, 172 + error_template, 173 + default_context, 174 + err 175 + ); 176 + } 177 + }; 204 178 205 - if let Err(err) = pds_auth_resources { 179 + let signing_key_selection = web_context.config.select_oauth_signing_key(); 180 + if let Err(err) = signing_key_selection { 206 181 return contextual_error!(web_context, language, error_template, default_context, err); 207 182 } 208 183 209 - let (_, authorization_server) = pds_auth_resources.unwrap(); 210 - tracing::info!(authorization_server = ?authorization_server, "resolved authorization server"); 184 + let (public_signing_key, private_signing_key) = signing_key_selection.unwrap(); 211 185 212 - let signing_key = web_context.config.select_oauth_signing_key(); 213 - if let Err(err) = signing_key { 214 - return contextual_error!(web_context, language, error_template, default_context, err); 215 - } 186 + // Convert the DID method key string to KeyData using identify_key 187 + let private_signing_key_data = match identify_key(&private_signing_key) { 188 + Ok(key_data) => key_data, 189 + Err(err) => { 190 + return contextual_error!( 191 + web_context, 192 + language, 193 + error_template, 194 + default_context, 195 + err 196 + ); 197 + } 198 + }; 216 199 217 - let (key_id, signing_key) = signing_key.unwrap(); 200 + let private_dpop_key_data = match generate_key(KeyType::P256Private) { 201 + Ok(value) => value, 202 + Err(err) => { 203 + return contextual_error!( 204 + web_context, 205 + language, 206 + error_template, 207 + default_context, 208 + err 209 + ); 210 + } 211 + }; 218 212 219 - let dpop_jwk = jose::jwk::generate(); 220 - let dpop_secret_key = SecretKey::from_jwk(&dpop_jwk.jwk); 213 + let private_dpop_key = private_dpop_key_data.to_string(); 221 214 222 - if let Err(err) = dpop_secret_key { 223 - return contextual_error!(web_context, language, error_template, default_context, err); 224 - } 225 - 226 - let dpop_secret_key = dpop_secret_key.unwrap(); 215 + let oauth_client = OAuthClient { 216 + redirect_uri: format!( 217 + "https://{}/oauth/callback", 218 + &web_context.config.external_base 219 + ), 220 + client_id: format!( 221 + "https://{}/oauth/client-metadata.json", 222 + &web_context.config.external_base 223 + ), 224 + private_signing_key_data, 225 + }; 227 226 228 227 let par_response = oauth_init( 229 228 &web_context.http_client, 230 - &web_context.config.external_base, 231 - (&key_id, signing_key), 232 - &dpop_secret_key, 229 + &oauth_client, 230 + &private_dpop_key_data, 233 231 primary_handle, 234 232 &authorization_server, 235 233 &oauth_request_state, ··· 242 240 243 241 let par_response = par_response.unwrap(); 244 242 243 + // Store the DID document 244 + if let Err(err) = web_context 245 + .document_storage 246 + .store_document(did_document.clone()) 247 + .await 248 + { 249 + return contextual_error!(web_context, language, error_template, default_context, err); 250 + } 251 + 245 252 let created_at = chrono::Utc::now(); 246 253 let expires_at = created_at + chrono::Duration::seconds(par_response.expires_in as i64); 247 254 248 - if let Err(err) = oauth_request_insert( 249 - &web_context.pool, 250 - crate::storage::oauth::OAuthRequestParams { 251 - oauth_state: Cow::Owned(oauth_request_state.state.clone()), 252 - issuer: Cow::Owned(authorization_server.issuer.clone()), 253 - did: Cow::Owned(did_document.id.clone()), 254 - nonce: Cow::Owned(oauth_request_state.nonce.clone()), 255 - pkce_verifier: Cow::Owned(pkce_verifier.clone()), 256 - secret_jwk_id: Cow::Owned(key_id.clone()), 257 - dpop_jwk: Some(dpop_jwk.clone()), 258 - destination: login_form.destination.clone().map(Cow::Owned), 259 - created_at, 260 - expires_at, 261 - }, 262 - ) 263 - .await 255 + let oauth_request = OAuthRequest { 256 + oauth_state: oauth_request_state.state.clone(), 257 + issuer: authorization_server.issuer.clone(), 258 + did: did_document.id.clone(), 259 + nonce: oauth_request_state.nonce.clone(), 260 + pkce_verifier: pkce_verifier.clone(), 261 + signing_public_key: public_signing_key, 262 + dpop_private_key: private_dpop_key, 263 + created_at, 264 + expires_at, 265 + }; 266 + 267 + if let Err(err) = web_context 268 + .oauth_storage 269 + .insert_oauth_request(oauth_request) 270 + .await 264 271 { 265 272 return contextual_error!(web_context, language, error_template, default_context, err); 266 273 } ··· 287 294 stringify(oauth_args) 288 295 ); 289 296 290 - if hx_request { 291 - if let Ok(hx_redirect) = HxRedirect::try_from(destination.as_str()) { 292 - return Ok((StatusCode::OK, hx_redirect, "").into_response()); 297 + let hx_redirect = match HxRedirect::try_from(destination.as_str()) { 298 + Ok(value) => value, 299 + Err(err) => { 300 + return contextual_error!( 301 + web_context, 302 + language, 303 + error_template, 304 + default_context, 305 + err 306 + ); 293 307 } 294 - } 308 + }; 295 309 296 - return Ok(Redirect::temporary(destination.as_str()).into_response()); 310 + return Ok((StatusCode::OK, hx_redirect, "").into_response()); 297 311 } 298 312 299 313 Ok(RenderHtml( ··· 305 319 ) 306 320 .into_response()) 307 321 } 308 - 309 - pub fn gen_pkce() -> (String, String) { 310 - let token: String = rand::thread_rng() 311 - .sample_iter(&Alphanumeric) 312 - .take(100) 313 - .map(char::from) 314 - .collect(); 315 - (token.clone(), pkce_challenge(&token)) 316 - } 317 - 318 - pub fn pkce_challenge(token: &str) -> String { 319 - let mut hasher = Sha256::new(); 320 - hasher.update(token.as_bytes()); 321 - let result = hasher.finalize(); 322 - 323 - general_purpose::URL_SAFE_NO_PAD.encode(result) 324 - }
+16 -33
src/http/handle_oauth_logout.rs
··· 1 1 use anyhow::Result; 2 - use axum::{ 3 - extract::State, 4 - response::{IntoResponse, Redirect}, 5 - }; 6 - use axum_extra::extract::{cookie::Cookie, PrivateCookieJar}; 7 - use axum_htmx::{HxRedirect, HxRequest}; 8 - use axum_template::RenderHtml; 9 - use http::StatusCode; 10 - use minijinja::context as template_context; 2 + use axum::extract::State; 3 + use axum::response::{IntoResponse, Redirect}; 4 + use axum_extra::extract::PrivateCookieJar; 5 + use cookie::Cookie; 6 + 7 + use crate::http::{errors::WebError, middleware_auth::AUTH_COOKIE_NAME}; 11 8 12 - use crate::http::{ 13 - context::WebContext, errors::WebError, middleware_auth::AUTH_COOKIE_NAME, 14 - middleware_i18n::Language, 15 - }; 9 + use crate::http::context::WebContext; 16 10 17 - pub async fn handle_logout( 11 + pub(crate) async fn handle_logout( 18 12 State(web_context): State<WebContext>, 19 - Language(language): Language, 20 - HxRequest(hx_request): HxRequest, 21 13 jar: PrivateCookieJar, 22 14 ) -> Result<impl IntoResponse, WebError> { 23 - let updated_jar = jar.remove(Cookie::from(AUTH_COOKIE_NAME)); 15 + let mut removal_cookie = Cookie::from(AUTH_COOKIE_NAME); 16 + removal_cookie.set_domain(web_context.config.external_base.clone()); 17 + removal_cookie.set_path("/"); 24 18 25 - if hx_request { 26 - let hx_redirect = HxRedirect::try_from("/"); 27 - if let Err(err) = hx_redirect { 28 - tracing::error!("Failed to create HxLocation: {}", err); 29 - return Ok(RenderHtml( 30 - format!("alert.{}.partial.html", language.to_string().to_lowercase()), 31 - web_context.engine.clone(), 32 - template_context! { message => "Internal Server Error" }, 33 - ) 34 - .into_response()); 35 - } 36 - let hx_redirect = hx_redirect.unwrap(); 37 - Ok((StatusCode::OK, hx_redirect, "").into_response()) 38 - } else { 39 - Ok((updated_jar, Redirect::to("/")).into_response()) 40 - } 19 + let updated_jar = jar.remove(removal_cookie); 20 + 21 + tracing::info!(?updated_jar, "updated cookie jar"); 22 + 23 + Ok((updated_jar, Redirect::to("/")).into_response()) 41 24 }
+1 -1
src/http/handle_oauth_metadata.rs
··· 57 57 policy_uri: "https://docs.smokesignal.events/docs/about/privacy/", 58 58 redirect_uris, 59 59 response_types: vec!["code"], 60 - scope: "atproto transition:generic", 60 + scope: "atproto transition:generic transition:email", 61 61 token_endpoint_auth_method: "private_key_jwt", 62 62 token_endpoint_auth_signing_alg: "ES256", 63 63 subject_type: "public",
+8 -8
src/http/handle_policy.rs
··· 13 13 select_template, 14 14 }; 15 15 16 - pub async fn handle_privacy_policy( 16 + pub(crate) async fn handle_privacy_policy( 17 17 State(web_context): State<WebContext>, 18 18 HxBoosted(hx_boosted): HxBoosted, 19 19 Language(language): Language, ··· 26 26 &render_template, 27 27 web_context.engine.clone(), 28 28 template_context! { 29 - current_handle => auth.0, 29 + current_handle => auth.profile(), 30 30 language => language.to_string(), 31 31 canonical_url => format!("https://{}/privacy-policy", web_context.config.external_base), 32 32 }, ··· 35 35 .into_response()) 36 36 } 37 37 38 - pub async fn handle_terms_of_service( 38 + pub(crate) async fn handle_terms_of_service( 39 39 State(web_context): State<WebContext>, 40 40 HxBoosted(hx_boosted): HxBoosted, 41 41 Language(language): Language, ··· 48 48 &render_template, 49 49 web_context.engine.clone(), 50 50 template_context! { 51 - current_handle => auth.0, 51 + current_handle => auth.profile(), 52 52 language => language.to_string(), 53 53 canonical_url => format!("https://{}/terms-of-service", web_context.config.external_base), 54 54 }, ··· 57 57 .into_response()) 58 58 } 59 59 60 - pub async fn handle_cookie_policy( 60 + pub(crate) async fn handle_cookie_policy( 61 61 State(web_context): State<WebContext>, 62 62 HxBoosted(hx_boosted): HxBoosted, 63 63 Language(language): Language, ··· 70 70 &render_template, 71 71 web_context.engine.clone(), 72 72 template_context! { 73 - current_handle => auth.0, 73 + current_handle => auth.profile(), 74 74 language => language.to_string(), 75 75 canonical_url => format!("https://{}/cookie-policy", web_context.config.external_base), 76 76 }, ··· 79 79 .into_response()) 80 80 } 81 81 82 - pub async fn handle_acknowledgement( 82 + pub(crate) async fn handle_acknowledgement( 83 83 State(web_context): State<WebContext>, 84 84 HxBoosted(hx_boosted): HxBoosted, 85 85 Language(language): Language, ··· 92 92 &render_template, 93 93 web_context.engine.clone(), 94 94 template_context! { 95 - current_handle => auth.0, 95 + current_handle => auth.profile(), 96 96 language => language.to_string(), 97 97 canonical_url => format!("https://{}/acknowledgement", web_context.config.external_base), 98 98 },
+2 -2
src/http/handle_profile.rs
··· 24 24 storage::{ 25 25 errors::StorageError, 26 26 event::{event_list_did_recently_updated, model::EventWithRole}, 27 - handle::{handle_for_did, handle_for_handle}, 27 + identity_profile::{handle_for_did, handle_for_handle}, 28 28 }, 29 29 }; 30 30 ··· 49 49 } 50 50 } 51 51 52 - pub async fn handle_profile_view( 52 + pub(crate) async fn handle_profile_view( 53 53 ctx: UserRequestContext, 54 54 HxRequest(hx_request): HxRequest, 55 55 HxBoosted(hx_boosted): HxBoosted,
+5 -6
src/http/handle_set_language.rs
··· 12 12 use std::{borrow::Cow, str::FromStr}; 13 13 use unic_langid::LanguageIdentifier; 14 14 15 - use crate::storage::handle::{handle_update_field, HandleField}; 15 + use crate::storage::identity_profile::{handle_update_field, HandleField}; 16 16 17 17 use super::{ 18 18 context::WebContext, errors::WebError, middleware_auth::Auth, middleware_i18n::COOKIE_LANG, ··· 20 20 }; 21 21 22 22 #[derive(Deserialize, Clone)] 23 - pub struct LanguageForm { 23 + pub(crate) struct LanguageForm { 24 24 language: String, 25 25 } 26 26 27 - #[tracing::instrument(skip_all, err)] 28 - pub async fn handle_set_language( 27 + pub(crate) async fn handle_set_language( 29 28 State(web_context): State<WebContext>, 30 29 Cached(auth): Cached<Auth>, 31 30 jar: CookieJar, 32 31 Form(language_form): Form<LanguageForm>, 33 32 ) -> Result<impl IntoResponse, WebError> { 34 33 let default_context = template_context! { 35 - current_handle => auth.0, 34 + current_handle => auth.profile(), 36 35 canonical_url => format!("https://{}/language", web_context.config.external_base), 37 36 }; 38 37 ··· 65 64 } 66 65 let found = found.unwrap(); 67 66 68 - if let Some(handle) = auth.0 { 67 + if let Some(handle) = auth.profile() { 69 68 if let Err(err) = handle_update_field( 70 69 &web_context.pool, 71 70 &handle.did,
+7 -7
src/http/handle_settings.rs
··· 16 16 timezones::supported_timezones, 17 17 }, 18 18 select_template, 19 - storage::handle::{handle_for_did, handle_update_field, HandleField}, 19 + storage::identity_profile::{handle_for_did, handle_update_field, HandleField}, 20 20 }; 21 21 22 22 #[derive(Deserialize, Clone, Debug)] 23 - pub struct TimezoneForm { 23 + pub(crate) struct TimezoneForm { 24 24 timezone: String, 25 25 } 26 26 27 27 #[derive(Deserialize, Clone, Debug)] 28 - pub struct LanguageForm { 28 + pub(crate) struct LanguageForm { 29 29 language: String, 30 30 } 31 31 32 - pub async fn handle_settings( 32 + pub(crate) async fn handle_settings( 33 33 State(web_context): State<WebContext>, 34 34 Language(language): Language, 35 35 Cached(auth): Cached<Auth>, 36 36 HxBoosted(hx_boosted): HxBoosted, 37 37 ) -> Result<impl IntoResponse, WebError> { 38 38 // Require authentication 39 - let current_handle = auth.require(&web_context.config.destination_key, "/settings")?; 39 + let current_handle = auth.require(&web_context.config, "/settings")?; 40 40 41 41 let default_context = template_context! { 42 42 current_handle => current_handle.clone(), ··· 74 74 } 75 75 76 76 #[tracing::instrument(skip_all, err)] 77 - pub async fn handle_timezone_update( 77 + pub(crate) async fn handle_timezone_update( 78 78 State(web_context): State<WebContext>, 79 79 Language(language): Language, 80 80 Cached(auth): Cached<Auth>, ··· 139 139 } 140 140 141 141 #[tracing::instrument(skip_all, err)] 142 - pub async fn handle_language_update( 142 + pub(crate) async fn handle_language_update( 143 143 State(web_context): State<WebContext>, 144 144 Language(language): Language, 145 145 Cached(auth): Cached<Auth>,
+17 -17
src/http/handle_view_event.rs
··· 15 15 use crate::atproto::lexicon::events::smokesignal::calendar::event::NSID as SMOKESIGNAL_EVENT_NSID; 16 16 use crate::contextual_error; 17 17 use crate::http::context::UserRequestContext; 18 - use crate::http::errors::CommonError; 19 18 use crate::http::errors::ViewEventError; 20 19 use crate::http::errors::WebError; 21 20 use crate::http::event_view::hydrate_event_rsvp_counts; ··· 23 22 use crate::http::pagination::Pagination; 24 23 use crate::http::tab_selector::TabSelector; 25 24 use crate::http::utils::url_from_aturi; 26 - use crate::resolve::parse_input; 27 - use crate::resolve::InputType; 28 25 use crate::select_template; 29 26 use crate::storage::event::count_event_rsvps; 30 27 use crate::storage::event::event_exists; 31 28 use crate::storage::event::event_get; 32 29 use crate::storage::event::get_event_rsvps; 33 30 use crate::storage::event::get_user_rsvp; 34 - use crate::storage::handle::handle_for_did; 35 - use crate::storage::handle::handle_for_handle; 36 - use crate::storage::handle::model::Handle; 31 + use crate::storage::identity_profile::handle_for_did; 32 + use crate::storage::identity_profile::handle_for_handle; 33 + use crate::storage::identity_profile::model::IdentityProfile; 37 34 use crate::storage::StoragePool; 38 35 39 36 #[derive(Debug, Deserialize, Serialize, PartialEq)] ··· 75 72 76 73 /// Helper function to fetch the organizer's handle (which contains their time zone) 77 74 /// This is used to implement the time zone selection logic. 78 - async fn fetch_organizer_handle(pool: &StoragePool, did: &str) -> Option<Handle> { 75 + async fn fetch_organizer_handle(pool: &StoragePool, did: &str) -> Option<IdentityProfile> { 79 76 match handle_for_did(pool, did).await { 80 77 Ok(handle) => Some(handle), 81 78 Err(err) => { ··· 85 82 } 86 83 } 87 84 88 - pub async fn handle_view_event( 85 + pub(crate) async fn handle_view_event( 89 86 ctx: UserRequestContext, 90 87 HxBoosted(hx_boosted): HxBoosted, 91 88 Path((handle_slug, event_rkey)): Path<(String, String)>, ··· 101 98 let render_template = select_template!("view_event", hx_boosted, false, ctx.language); 102 99 let error_template = select_template!(hx_boosted, false, ctx.language); 103 100 104 - let profile: Result<Handle, WebError> = match parse_input(&handle_slug) { 105 - Ok(InputType::Handle(handle)) => handle_for_handle(&ctx.web_context.pool, &handle) 101 + let profile: Result<IdentityProfile, WebError> = if handle_slug.starts_with("did:") { 102 + handle_for_did(&ctx.web_context.pool, &handle_slug) 106 103 .await 107 - .map_err(|err| err.into()), 108 - Ok(InputType::Plc(did) | InputType::Web(did)) => { 109 - handle_for_did(&ctx.web_context.pool, &did) 110 - .await 111 - .map_err(|err| err.into()) 112 - } 113 - _ => Err(CommonError::InvalidHandleSlug.into()), 104 + .map_err(|err| err.into()) 105 + } else { 106 + let handle = if let Some(handle) = handle_slug.strip_prefix('@') { 107 + handle 108 + } else { 109 + &handle_slug 110 + }; 111 + handle_for_handle(&ctx.web_context.pool, handle) 112 + .await 113 + .map_err(|err| err.into()) 114 114 }; 115 115 116 116 if let Err(err) = profile {
+2 -2
src/http/handle_view_feed.rs
··· 9 9 context::WebContext, errors::WebError, middleware_auth::Auth, middleware_i18n::Language, 10 10 }; 11 11 12 - pub async fn handle_view_feed( 12 + pub(crate) async fn handle_view_feed( 13 13 State(web_context): State<WebContext>, 14 14 HxBoosted(hx_boosted): HxBoosted, 15 15 Language(language): Language, ··· 25 25 &render_template, 26 26 web_context.engine.clone(), 27 27 template_context! { 28 - current_handle => auth.0, 28 + current_handle => auth.profile(), 29 29 language => language.to_string(), 30 30 canonical_url => format!("https://{}/", web_context.config.external_base), 31 31 },
+4 -4
src/http/handle_view_rsvp.rs
··· 22 22 }; 23 23 24 24 #[derive(Deserialize)] 25 - pub struct RsvpQuery { 26 - pub aturi: Option<String>, 25 + pub(crate) struct RsvpQuery { 26 + aturi: Option<String>, 27 27 } 28 28 29 - pub async fn handle_view_rsvp( 29 + pub(crate) async fn handle_view_rsvp( 30 30 State(web_context): State<WebContext>, 31 31 HxBoosted(hx_boosted): HxBoosted, 32 32 HxRequest(hx_request): HxRequest, ··· 34 34 Cached(auth): Cached<Auth>, 35 35 query: Query<RsvpQuery>, 36 36 ) -> Result<impl IntoResponse, WebError> { 37 - let current_handle = auth.0.clone(); 37 + let current_handle = auth.profile().cloned(); 38 38 39 39 let default_context = template_context! { 40 40 current_handle,
+106 -84
src/http/middleware_auth.rs
··· 1 1 use anyhow::Result; 2 + use atproto_oauth::jwt::{mint, Claims, Header, JoseClaims}; 2 3 use axum::{ 3 4 extract::{FromRef, FromRequestParts}, 4 5 http::request::Parts, 5 6 response::Response, 6 7 }; 7 8 use axum_extra::extract::PrivateCookieJar; 8 - use base64::{engine::general_purpose, Engine as _}; 9 - use p256::{ 10 - ecdsa::{signature::Signer, Signature, SigningKey}, 11 - SecretKey, 12 - }; 13 9 use serde::{Deserialize, Serialize}; 14 10 use tracing::{debug, instrument, trace}; 15 11 16 12 use crate::{ 17 13 config::Config, 18 - encoding::ToBase64, 19 14 http::context::WebContext, 20 15 http::errors::{AuthMiddlewareError, WebSessionError}, 21 - storage::handle::model::Handle, 22 - storage::oauth::model::OAuthSession, 23 - storage::oauth::web_session_lookup, 16 + storage::identity_profile::model::IdentityProfile, 24 17 }; 25 18 19 + use crate::{storage::oauth::model::OAuthSession, storage::oauth::web_session_lookup}; 20 + 26 21 use super::errors::middleware_errors::MiddlewareAuthError; 27 22 28 - pub const AUTH_COOKIE_NAME: &str = "session1"; 23 + pub(crate) const AUTH_COOKIE_NAME: &str = "session"; 29 24 30 25 #[derive(Clone, PartialEq, Serialize, Deserialize)] 31 - pub struct WebSession { 32 - pub did: String, 33 - pub session_group: String, 26 + pub(crate) enum WebSession { 27 + Pds { did: String, session_group: String }, 28 + Aip { did: String, access_token: String }, 34 29 } 35 30 36 31 impl TryFrom<String> for WebSession { ··· 53 48 } 54 49 } 55 50 56 - #[derive(Clone, Serialize, Deserialize)] 57 - pub struct DestinationClaims { 58 - #[serde(rename = "d")] 59 - pub destination: String, 60 - 61 - #[serde(rename = "n")] 62 - pub nonce: String, 63 - } 64 - 65 51 #[derive(Clone)] 66 - pub struct Auth(pub Option<Handle>, pub Option<OAuthSession>); 52 + pub(crate) enum Auth { 53 + Unauthenticated, 54 + Pds { 55 + profile: IdentityProfile, 56 + session: OAuthSession, 57 + }, 58 + Aip { 59 + profile: IdentityProfile, 60 + access_token: String, 61 + }, 62 + } 67 63 68 64 impl Auth { 65 + /// Get the profile if authenticated, None otherwise 66 + pub(crate) fn profile(&self) -> Option<&IdentityProfile> { 67 + match self { 68 + Auth::Pds { profile, .. } | Auth::Aip { profile, .. } => Some(profile), 69 + Auth::Unauthenticated => None, 70 + } 71 + } 69 72 /// Requires authentication and redirects to login with a signed token containing the original destination 70 73 /// 71 74 /// This creates a redirect URL with a signed token containing the destination, 72 75 /// which the login handler can verify and redirect back to after successful authentication. 73 - #[instrument(level = "debug", skip(self, secret_key), err)] 74 - pub fn require( 76 + #[instrument(level = "debug", skip(self, config), err)] 77 + pub(crate) fn require( 75 78 &self, 76 - secret_key: &SecretKey, 79 + config: &crate::config::Config, 77 80 location: &str, 78 - ) -> Result<Handle, MiddlewareAuthError> { 79 - if let Some(handle) = &self.0 { 80 - trace!(did = %handle.did, "User authenticated"); 81 - return Ok(handle.clone()); 81 + ) -> Result<IdentityProfile, MiddlewareAuthError> { 82 + match self { 83 + Auth::Pds { profile, .. } | Auth::Aip { profile, .. } => { 84 + trace!(did = %profile.did, "User authenticated"); 85 + return Ok(profile.clone()); 86 + } 87 + Auth::Unauthenticated => {} 82 88 } 83 89 84 90 debug!( ··· 86 92 "Authentication required, creating signed redirect" 87 93 ); 88 94 89 - // Create claims with destination and random nonce 90 - let claims = DestinationClaims { 91 - destination: location.to_string(), 92 - nonce: ulid::Ulid::new().to_string(), 93 - }; 94 - 95 - // Encode claims to base64 96 - let claims = claims.to_base64()?; 97 - let claim_content = claims.to_string(); 98 - let encoded_json_bytes = general_purpose::URL_SAFE_NO_PAD.encode(claims.as_bytes()); 99 - 100 - // Sign the encoded claims 101 - let signing_key = SigningKey::from(secret_key); 102 - let signature: Signature = signing_key 103 - .try_sign(encoded_json_bytes.as_bytes()) 95 + let header: Header = config 96 + .destination_key_data 97 + .clone() 98 + .try_into() 104 99 .map_err(AuthMiddlewareError::SigningFailed)?; 100 + let claims = Claims::new(JoseClaims { 101 + http_uri: Some(location.to_string()), 102 + nonce: Some(ulid::Ulid::new().to_string()), 103 + ..Default::default() 104 + }); 105 105 106 - // Format the final destination with claims and signature 107 - let destination = format!( 108 - "{}.{}", 109 - claim_content, 110 - general_purpose::URL_SAFE_NO_PAD.encode(signature.to_bytes()) 111 - ); 106 + let destination_token = mint(&config.destination_key_data, &header, &claims) 107 + .map_err(AuthMiddlewareError::SigningFailed)?; 112 108 113 - trace!( 114 - destination_length = destination.len(), 115 - "Created signed destination token" 116 - ); 117 - Err(MiddlewareAuthError::AccessDenied(destination)) 109 + Err(MiddlewareAuthError::AccessDenied(destination_token)) 118 110 } 119 111 120 112 /// Simpler authentication check that just redirects to root path 121 113 /// 122 114 /// Use this when you don't need to return to the original page after login 123 115 #[instrument(level = "debug", skip(self), err)] 124 - pub fn require_flat(&self) -> Result<Handle, MiddlewareAuthError> { 125 - if let Some(handle) = &self.0 { 126 - trace!(did = %handle.did, "User authenticated"); 127 - return Ok(handle.clone()); 116 + pub(crate) fn require_flat(&self) -> Result<IdentityProfile, MiddlewareAuthError> { 117 + match self { 118 + Auth::Pds { profile, .. } | Auth::Aip { profile, .. } => { 119 + trace!(did = %profile.did, "User authenticated"); 120 + return Ok(profile.clone()); 121 + } 122 + Auth::Unauthenticated => {} 128 123 } 129 124 130 125 debug!("Authentication required, redirecting to root"); ··· 135 130 /// 136 131 /// Returns NotFound error instead of redirecting to login for security reasons 137 132 #[instrument(level = "debug", skip(self, config), err)] 138 - pub fn require_admin(&self, config: &Config) -> Result<Handle, MiddlewareAuthError> { 139 - if let Some(handle) = &self.0 { 140 - if config.is_admin(&handle.did) { 141 - debug!(did = %handle.did, "Admin authenticated"); 142 - return Ok(handle.clone()); 133 + pub(crate) fn require_admin( 134 + &self, 135 + config: &Config, 136 + ) -> Result<IdentityProfile, MiddlewareAuthError> { 137 + match self { 138 + Auth::Pds { profile, .. } | Auth::Aip { profile, .. } => { 139 + if config.is_admin(&profile.did) { 140 + debug!(did = %profile.did, "Admin authenticated"); 141 + return Ok(profile.clone()); 142 + } 143 + debug!(did = %profile.did, "User not an admin"); 143 144 } 144 - debug!(did = %handle.did, "User not an admin"); 145 - } else { 146 - debug!("No authentication found for admin check"); 145 + Auth::Unauthenticated => { 146 + debug!("No authentication found for admin check"); 147 + } 147 148 } 148 149 149 150 // Return NotFound instead of redirect for security reasons ··· 173 174 .and_then(|inner_value| WebSession::try_from(inner_value).ok()); 174 175 175 176 if let Some(web_session) = session { 176 - trace!(?web_session.session_group, "Found session cookie"); 177 + match web_session { 178 + WebSession::Pds { did, session_group } => { 179 + trace!(?session_group, "Found PDS session cookie"); 177 180 178 - match web_session_lookup( 179 - &web_context.pool, 180 - &web_session.session_group, 181 - Some(&web_session.did), 182 - ) 183 - .await 184 - { 185 - Ok(record) => { 186 - debug!(?web_session.session_group, "Session validated"); 187 - return Ok(Self(Some(record.0), Some(record.1))); 181 + match web_session_lookup(&web_context.pool, &session_group, Some(&did)).await { 182 + Ok(record) => { 183 + debug!(?session_group, "Session validated"); 184 + return Ok(Auth::Pds { 185 + profile: record.0, 186 + session: record.1, 187 + }); 188 + } 189 + Err(err) => { 190 + debug!(?session_group, ?err, "Invalid session"); 191 + return Ok(Auth::Unauthenticated); 192 + } 193 + }; 188 194 } 189 - Err(err) => { 190 - debug!(?web_session.session_group, ?err, "Invalid session"); 191 - return Ok(Self(None, None)); 195 + WebSession::Aip { did, access_token } => { 196 + trace!(?access_token, "Found AIP session cookie with access token"); 197 + // For AIP OAuth, try to fetch the handle from the database 198 + let profile = sqlx::query_as::<_, IdentityProfile>( 199 + "SELECT * FROM identity_profiles WHERE did = $1", 200 + ) 201 + .bind(&did) 202 + .fetch_one(&web_context.pool) 203 + .await 204 + .ok(); 205 + 206 + if let Some(profile) = profile { 207 + return Ok(Auth::Aip { 208 + profile, 209 + access_token, 210 + }); 211 + } else { 212 + return Ok(Auth::Unauthenticated); 213 + } 192 214 } 193 - }; 215 + } 194 216 } 195 217 196 218 trace!("No session cookie found"); 197 - Ok(Self(None, None)) 219 + Ok(Auth::Unauthenticated) 198 220 } 199 221 }
+2 -2
src/http/middleware_i18n.rs
··· 85 85 86 86 /// Wrapper around LanguageIdentifier for the current request's language 87 87 #[derive(Clone, Debug)] 88 - pub struct Language(pub LanguageIdentifier); 88 + pub(crate) struct Language(pub(crate) LanguageIdentifier); 89 89 90 90 impl std::fmt::Display for Language { 91 91 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { ··· 120 120 let auth: Auth = Cached::<Auth>::from_request_parts(parts, context).await?.0; 121 121 122 122 // 1. Try to get language from user's profile settings 123 - if let Some(handle) = &auth.0 { 123 + if let Some(handle) = auth.profile() { 124 124 if let Ok(auth_lang) = handle.language.parse::<LanguageIdentifier>() { 125 125 debug!(language = %auth_lang, "Using language from user profile"); 126 126 return Ok(Self(auth_lang));
+5 -2
src/http/mod.rs
··· 14 14 pub mod handle_admin_rsvps; 15 15 pub mod handle_create_event; 16 16 pub mod handle_create_rsvp; 17 + pub mod handle_delete_event; 17 18 pub mod handle_edit_event; 18 19 pub mod handle_import; 19 20 pub mod handle_index; 20 21 pub mod handle_migrate_event; 21 22 pub mod handle_migrate_rsvp; 23 + 24 + pub mod handle_oauth_aip_callback; 25 + pub mod handle_oauth_aip_login; 22 26 pub mod handle_oauth_callback; 23 - pub mod handle_oauth_jwks; 24 27 pub mod handle_oauth_login; 28 + 25 29 pub mod handle_oauth_logout; 26 - pub mod handle_oauth_metadata; 27 30 pub mod handle_policy; 28 31 pub mod handle_profile; 29 32 pub mod handle_set_language;
+17 -24
src/http/pagination.rs
··· 1 1 use crate::http::utils::stringify; 2 2 use serde::{Deserialize, Serialize}; 3 3 4 - pub const PAGE_DEFAULT: i64 = 1; 5 - pub const PAGE_MIN: i64 = 1; 6 - pub const PAGE_MAX: i64 = 100; 7 - pub const PAGE_SIZE_DEFAULT: i64 = 10; 8 - pub const PAGE_SIZE_MIN: i64 = 5; 9 - pub const PAGE_SIZE_MAX: i64 = 100; 10 - 11 - pub const LIMITED_PAGE_DEFAULT: i64 = 1; 12 - pub const LIMITED_PAGE_MIN: i64 = 1; 13 - pub const LIMITED_PAGE_MAX: i64 = 5; 14 - pub const LIMITED_PAGE_SIZE_DEFAULT: i64 = 5; 15 - pub const LIMITED_PAGE_SIZE_MIN: i64 = 5; 16 - pub const LIMITED_PAGE_SIZE_MAX: i64 = 5; 4 + pub(crate) const PAGE_DEFAULT: i64 = 1; 5 + pub(crate) const PAGE_MIN: i64 = 1; 6 + pub(crate) const PAGE_MAX: i64 = 100; 7 + pub(crate) const PAGE_SIZE_DEFAULT: i64 = 10; 8 + pub(crate) const PAGE_SIZE_MIN: i64 = 5; 9 + pub(crate) const PAGE_SIZE_MAX: i64 = 100; 17 10 18 11 #[derive(Deserialize, Default)] 19 - pub struct Pagination { 20 - pub page: Option<i64>, 21 - pub page_size: Option<i64>, 12 + pub(crate) struct Pagination { 13 + pub(crate) page: Option<i64>, 14 + pub(crate) page_size: Option<i64>, 22 15 } 23 16 24 17 #[derive(Serialize, Debug)] 25 - pub struct PaginationView { 26 - pub previous: Option<i64>, 27 - pub previous_url: Option<String>, 28 - pub next: Option<i64>, 29 - pub next_url: Option<String>, 18 + pub(crate) struct PaginationView { 19 + pub(crate) previous: Option<i64>, 20 + pub(crate) previous_url: Option<String>, 21 + pub(crate) next: Option<i64>, 22 + pub(crate) next_url: Option<String>, 30 23 } 31 24 32 25 impl Pagination { 33 - pub fn admin_clamped(&self) -> (i64, i64) { 26 + pub(crate) fn admin_clamped(&self) -> (i64, i64) { 34 27 let page = self.page.unwrap_or(1).clamp(1, 25000); 35 28 let page_size = self.page_size.unwrap_or(1).clamp(20, 100); 36 29 (page, page_size) 37 30 } 38 31 39 - pub fn clamped(&self) -> (i64, i64) { 32 + pub(crate) fn clamped(&self) -> (i64, i64) { 40 33 let page = self.page.unwrap_or(PAGE_DEFAULT).clamp(PAGE_MIN, PAGE_MAX); 41 34 let page_size = self 42 35 .page_size ··· 47 40 } 48 41 49 42 impl PaginationView { 50 - pub fn new(page_size: i64, total: i64, page: i64, params: Vec<(&str, &str)>) -> Self { 43 + pub(crate) fn new(page_size: i64, total: i64, page: i64, params: Vec<(&str, &str)>) -> Self { 51 44 let (previous, previous_url) = { 52 45 if page > 1 { 53 46 let page_value = (page - 1).to_string();
+11 -21
src/http/rsvp_form.rs
··· 1 1 use serde::{Deserialize, Serialize}; 2 - use thiserror::Error; 3 2 4 3 use crate::{ 5 4 errors::expand_error, ··· 7 6 storage::{event::event_get_cid, StoragePool}, 8 7 }; 9 8 10 - #[derive(Debug, Error)] 11 - pub enum BuildRSVPError { 12 - #[error("error-rsvp-builder-1 Invalid Subject")] 13 - InvalidSubject, 14 - 15 - #[error("error-rsvp-builder-2 Invalid Status")] 16 - InvalidStatus, 17 - } 18 - 19 9 #[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)] 20 - pub enum BuildRsvpContentState { 10 + pub(crate) enum BuildRsvpContentState { 21 11 #[default] 22 12 Reset, 23 13 Selecting, ··· 26 16 } 27 17 28 18 #[derive(Serialize, Deserialize, Debug, Clone)] 29 - pub struct BuildRSVPForm { 30 - pub build_state: Option<BuildRsvpContentState>, 19 + pub(crate) struct BuildRSVPForm { 20 + pub(crate) build_state: Option<BuildRsvpContentState>, 31 21 32 - pub subject_aturi: Option<String>, 33 - pub subject_aturi_error: Option<String>, 22 + pub(crate) subject_aturi: Option<String>, 23 + pub(crate) subject_aturi_error: Option<String>, 34 24 35 - pub subject_cid: Option<String>, 36 - pub subject_cid_error: Option<String>, 25 + pub(crate) subject_cid: Option<String>, 26 + pub(crate) subject_cid_error: Option<String>, 37 27 38 - pub status: Option<String>, 39 - pub status_error: Option<String>, 28 + pub(crate) status: Option<String>, 29 + pub(crate) status_error: Option<String>, 40 30 } 41 31 42 32 impl BuildRSVPForm { 43 - pub async fn hydrate( 33 + pub(crate) async fn hydrate( 44 34 &mut self, 45 35 database_pool: &StoragePool, 46 36 locales: &Locales, ··· 70 60 } 71 61 } 72 62 73 - pub fn validate( 63 + pub(crate) fn validate( 74 64 &mut self, 75 65 _locales: &Locales, 76 66 _language: &unic_langid::LanguageIdentifier,
+33 -11
src/http/server.rs
··· 15 15 use tower_http::{cors::CorsLayer, services::ServeDir}; 16 16 use tracing::Span; 17 17 18 + use crate::config::OAuthBackendConfig; 18 19 use crate::http::{ 19 20 context::WebContext, 20 21 handle_admin_denylist::{ ··· 33 34 handle_location_datalist, handle_starts_at_builder, 34 35 }, 35 36 handle_create_rsvp::handle_create_rsvp, 37 + handle_delete_event::handle_delete_event, 36 38 handle_edit_event::handle_edit_event, 37 39 handle_import::{handle_import, handle_import_submit}, 38 40 handle_index::handle_index, 39 41 handle_migrate_event::handle_migrate_event, 40 42 handle_migrate_rsvp::handle_migrate_rsvp, 41 - handle_oauth_callback::handle_oauth_callback, 42 - handle_oauth_jwks::handle_oauth_jwks, 43 - handle_oauth_login::handle_oauth_login, 44 43 handle_oauth_logout::handle_logout, 45 - handle_oauth_metadata::handle_oauth_metadata, 46 44 handle_policy::{ 47 45 handle_acknowledgement, handle_cookie_policy, handle_privacy_policy, 48 46 handle_terms_of_service, ··· 55 53 handle_view_rsvp::handle_view_rsvp, 56 54 }; 57 55 56 + use crate::http::handle_oauth_aip_callback::handle_oauth_callback as handle_oauth_aip_callback; 57 + use crate::http::handle_oauth_aip_login::handle_oauth_aip_login; 58 + use crate::http::handle_oauth_callback::handle_oauth_callback as handle_oauth_pds_callback; 59 + use crate::http::handle_oauth_login::handle_oauth_login as handle_oauth_pds_login; 60 + use atproto_oauth_axum::handler_metadata::handle_oauth_metadata; 61 + 58 62 pub fn build_router(web_context: WebContext) -> Router { 59 63 let serve_dir = ServeDir::new(web_context.config.http_static_path.clone()); 60 64 61 - Router::new() 65 + let mut router = Router::new() 62 66 .route("/", get(handle_index)) 63 67 .route("/privacy-policy", get(handle_privacy_policy)) 64 68 .route("/terms-of-service", get(handle_terms_of_service)) 65 69 .route("/cookie-policy", get(handle_cookie_policy)) 66 - .route("/acknowledgement", get(handle_acknowledgement)) 70 + .route("/acknowledgement", get(handle_acknowledgement)); 71 + 72 + // Add OAuth metadata route only for AT Protocol backend 73 + if matches!( 74 + web_context.config.oauth_backend, 75 + OAuthBackendConfig::ATProtocol { .. } 76 + ) { 77 + router = router.route("/oauth/client-metadata.json", get(handle_oauth_metadata)); 78 + } 79 + 80 + // Add OAuth routes based on backend configuration 81 + router = match web_context.config.oauth_backend { 82 + OAuthBackendConfig::AIP { .. } => router 83 + .route("/oauth/login", get(handle_oauth_aip_login)) 84 + .route("/oauth/login", post(handle_oauth_aip_login)) 85 + .route("/oauth/callback", get(handle_oauth_aip_callback)), 86 + OAuthBackendConfig::ATProtocol { .. } => router 87 + .route("/oauth/login", get(handle_oauth_pds_login)) 88 + .route("/oauth/login", post(handle_oauth_pds_login)) 89 + .route("/oauth/callback", get(handle_oauth_pds_callback)), 90 + }; 91 + 92 + router 67 93 .route("/admin", get(handle_admin_index)) 68 94 .route("/admin/handles", get(handle_admin_handles)) 69 95 .route( ··· 79 105 .route("/admin/rsvps", get(handle_admin_rsvps)) 80 106 .route("/admin/rsvp", get(handle_admin_rsvp)) 81 107 .route("/admin/rsvps/import", post(handle_admin_import_rsvp)) 82 - .route("/oauth/client-metadata.json", get(handle_oauth_metadata)) 83 - .route("/.well-known/jwks.json", get(handle_oauth_jwks)) 84 - .route("/oauth/login", get(handle_oauth_login)) 85 - .route("/oauth/login", post(handle_oauth_login)) 86 - .route("/oauth/callback", get(handle_oauth_callback)) 87 108 .route("/logout", get(handle_logout)) 88 109 .route("/language", post(handle_set_language)) 89 110 .route("/settings", get(handle_settings)) ··· 105 126 .route("/event/links", post(handle_link_at_builder)) 106 127 .route("/{handle_slug}/{event_rkey}/edit", get(handle_edit_event)) 107 128 .route("/{handle_slug}/{event_rkey}/edit", post(handle_edit_event)) 129 + .route("/{handle_slug}/{event_rkey}/delete", post(handle_delete_event)) 108 130 .route( 109 131 "/{handle_slug}/{event_rkey}/migrate", 110 132 get(handle_migrate_event),
+7 -7
src/http/tab_selector.rs
··· 1 1 use serde::{Deserialize, Serialize}; 2 2 3 3 #[derive(Deserialize, Default)] 4 - pub struct TabSelector { 5 - pub tab: Option<String>, 4 + pub(crate) struct TabSelector { 5 + pub(crate) tab: Option<String>, 6 6 } 7 7 8 8 #[derive(Deserialize, Serialize)] 9 - pub struct TabLink { 10 - pub name: String, 11 - pub label: String, 12 - pub url: String, 13 - pub active: bool, 9 + pub(crate) struct TabLink { 10 + pub(crate) name: String, 11 + pub(crate) label: String, 12 + pub(crate) url: String, 13 + pub(crate) active: bool, 14 14 }
+1 -1
src/http/templates.rs
··· 2 2 use axum_template::{RenderHtml, TemplateEngine}; 3 3 use minijinja::context as template_context; 4 4 5 - pub fn render_alert<E: TemplateEngine, S: Into<String>>( 5 + pub(crate) fn render_alert<E: TemplateEngine, S: Into<String>>( 6 6 engine: E, 7 7 language: &str, 8 8 message: S,
+2 -2
src/http/timezones.rs
··· 2 2 use chrono::{DateTime, NaiveDateTime, Utc}; 3 3 use itertools::Itertools; 4 4 5 - use crate::storage::handle::model::Handle; 5 + use crate::storage::identity_profile::model::IdentityProfile; 6 6 7 - pub fn supported_timezones(handle: Option<&Handle>) -> (&str, Vec<&str>) { 7 + pub(crate) fn supported_timezones(handle: Option<&IdentityProfile>) -> (&str, Vec<&str>) { 8 8 let handle_tz = handle 9 9 .and_then(|handle| handle.tz.parse().ok()) 10 10 .unwrap_or(chrono_tz::UTC);
+11 -11
src/http/utils.rs
··· 6 6 http::errors::UrlError, 7 7 }; 8 8 9 - pub type QueryParam<'a> = (&'a str, &'a str); 10 - pub type QueryParams<'a> = Vec<QueryParam<'a>>; 9 + pub(crate) type QueryParam<'a> = (&'a str, &'a str); 10 + pub(crate) type QueryParams<'a> = Vec<QueryParam<'a>>; 11 11 12 - pub fn stringify(query: QueryParams) -> String { 12 + pub(crate) fn stringify(query: QueryParams) -> String { 13 13 query.iter().fold(String::new(), |acc, &tuple| { 14 14 acc + tuple.0 + "=" + tuple.1 + "&" 15 15 }) 16 16 } 17 17 18 - pub struct URLBuilder { 18 + struct URLBuilder { 19 19 host: String, 20 20 path: String, 21 21 params: Vec<(String, String)>, 22 22 } 23 23 24 - pub fn build_url(host: &str, path: &str, params: Vec<Option<(&str, &str)>>) -> String { 24 + pub(crate) fn build_url(host: &str, path: &str, params: Vec<Option<(&str, &str)>>) -> String { 25 25 let mut url_builder = URLBuilder::new(host); 26 26 url_builder.path(path); 27 27 ··· 33 33 } 34 34 35 35 impl URLBuilder { 36 - pub fn new(host: &str) -> URLBuilder { 36 + fn new(host: &str) -> URLBuilder { 37 37 let host = if host.starts_with("https://") { 38 38 host.to_string() 39 39 } else { ··· 53 53 } 54 54 } 55 55 56 - pub fn param(&mut self, key: &str, value: &str) -> &mut Self { 56 + fn param(&mut self, key: &str, value: &str) -> &mut Self { 57 57 self.params 58 58 .push((key.to_owned(), urlencoding::encode(value).to_string())); 59 59 self 60 60 } 61 61 62 - pub fn path(&mut self, path: &str) -> &mut Self { 62 + fn path(&mut self, path: &str) -> &mut Self { 63 63 path.clone_into(&mut self.path); 64 64 self 65 65 } 66 66 67 - pub fn build(self) -> String { 67 + fn build(self) -> String { 68 68 let mut url_params = String::new(); 69 69 70 70 if !self.params.is_empty() { ··· 78 78 } 79 79 } 80 80 81 - pub fn url_from_aturi(external_base: &str, aturi: &str) -> Result<String, UrlError> { 81 + pub(crate) fn url_from_aturi(external_base: &str, aturi: &str) -> Result<String, UrlError> { 82 82 let aturi = aturi.strip_prefix("at://").unwrap_or(aturi); 83 83 let parts = aturi.split("/").collect::<Vec<_>>(); 84 84 if parts.len() == 3 && parts[1] == NSID { ··· 105 105 clen 106 106 } 107 107 108 - pub fn truncate_text(text: &str, tlen: usize, suffix: Option<String>) -> String { 108 + pub(crate) fn truncate_text(text: &str, tlen: usize, suffix: Option<String>) -> String { 109 109 if text.len() <= tlen { 110 110 return text.to_string(); 111 111 }
-243
src/jose.rs
··· 1 - use base64::{engine::general_purpose, Engine as _}; 2 - use jwt::{Claims, Header}; 3 - use p256::{ 4 - ecdsa::{ 5 - signature::{Signer, Verifier}, 6 - Signature, SigningKey, VerifyingKey, 7 - }, 8 - PublicKey, SecretKey, 9 - }; 10 - use std::time::{SystemTime, UNIX_EPOCH}; 11 - 12 - use crate::encoding::ToBase64; 13 - use crate::jose_errors::JoseError; 14 - 15 - /// Signs a JWT token with the provided secret key, header, and claims 16 - /// 17 - /// Creates a JSON Web Token (JWT) by: 18 - /// 1. Base64URL encoding the header and claims 19 - /// 2. Signing the encoded header and claims with the secret key 20 - /// 3. Returning the complete JWT (header.claims.signature) 21 - pub fn mint_token( 22 - secret_key: &SecretKey, 23 - header: &Header, 24 - claims: &Claims, 25 - ) -> Result<String, JoseError> { 26 - // Encode header and claims to base64url 27 - let header = header 28 - .to_base64() 29 - .map_err(|_| JoseError::SigningKeyNotFound)?; 30 - let claims = claims 31 - .to_base64() 32 - .map_err(|_| JoseError::SigningKeyNotFound)?; 33 - let content = format!("{}.{}", header, claims); 34 - 35 - // Create signature 36 - let signing_key = SigningKey::from(secret_key.clone()); 37 - let signature: Signature = signing_key 38 - .try_sign(content.as_bytes()) 39 - .map_err(JoseError::SigningFailed)?; 40 - 41 - // Return complete JWT 42 - Ok(format!( 43 - "{}.{}", 44 - content, 45 - general_purpose::URL_SAFE_NO_PAD.encode(signature.to_bytes()) 46 - )) 47 - } 48 - 49 - /// Verifies a JWT token's signature and validates its claims 50 - /// 51 - /// Performs the following validations: 52 - /// 1. Checks token format is valid (three parts separated by periods) 53 - /// 2. Decodes header and claims from base64url format 54 - /// 3. Verifies the token signature using the provided public key 55 - /// 4. Validates token expiration (if provided in claims) 56 - /// 5. Validates token not-before time (if provided in claims) 57 - /// 6. Returns the decoded claims if all validation passes 58 - pub fn verify_token(token: &str, public_key: &PublicKey) -> Result<Claims, JoseError> { 59 - // Split token into its parts 60 - let parts: Vec<&str> = token.split('.').collect(); 61 - if parts.len() != 3 { 62 - return Err(JoseError::InvalidTokenFormat); 63 - } 64 - 65 - let encoded_header = parts[0]; 66 - let encoded_claims = parts[1]; 67 - let encoded_signature = parts[2]; 68 - 69 - // Decode header 70 - let header_bytes = general_purpose::URL_SAFE_NO_PAD 71 - .decode(encoded_header) 72 - .map_err(|_| JoseError::InvalidHeader)?; 73 - 74 - let header: Header = 75 - serde_json::from_slice(&header_bytes).map_err(|_| JoseError::InvalidHeader)?; 76 - 77 - // Verify algorithm matches what we expect 78 - // We only support ES256 for now 79 - if header.algorithm.as_deref() != Some("ES256") { 80 - return Err(JoseError::UnsupportedAlgorithm); 81 - } 82 - 83 - // Decode claims 84 - let claims_bytes = general_purpose::URL_SAFE_NO_PAD 85 - .decode(encoded_claims) 86 - .map_err(|_| JoseError::InvalidClaims)?; 87 - 88 - let claims: Claims = 89 - serde_json::from_slice(&claims_bytes).map_err(|_| JoseError::InvalidClaims)?; 90 - 91 - // Decode signature 92 - let signature_bytes = general_purpose::URL_SAFE_NO_PAD 93 - .decode(encoded_signature) 94 - .map_err(|_| JoseError::InvalidSignature)?; 95 - 96 - let signature = 97 - Signature::try_from(signature_bytes.as_slice()).map_err(|_| JoseError::InvalidSignature)?; 98 - 99 - // Verify signature 100 - let verifying_key = VerifyingKey::from(public_key); 101 - let content = format!("{}.{}", encoded_header, encoded_claims); 102 - 103 - verifying_key 104 - .verify(content.as_bytes(), &signature) 105 - .map_err(|_| JoseError::SignatureVerificationFailed)?; 106 - 107 - // Get current timestamp for validation 108 - let now = SystemTime::now() 109 - .duration_since(UNIX_EPOCH) 110 - .map_err(|_| JoseError::SystemTimeError)? 111 - .as_secs(); 112 - 113 - // Validate expiration time if present 114 - if let Some(exp) = claims.jose.expiration { 115 - if now >= exp { 116 - return Err(JoseError::TokenExpired); 117 - } 118 - } 119 - 120 - // Validate not-before time if present 121 - if let Some(nbf) = claims.jose.not_before { 122 - if now < nbf { 123 - return Err(JoseError::TokenNotYetValid); 124 - } 125 - } 126 - 127 - // Return validated claims 128 - Ok(claims) 129 - } 130 - 131 - pub mod jwk { 132 - use elliptic_curve::JwkEcKey; 133 - use p256::SecretKey; 134 - use rand::rngs::OsRng; 135 - use serde::{Deserialize, Serialize}; 136 - 137 - #[derive(Serialize, Deserialize, Clone, PartialEq, Debug)] 138 - pub struct WrappedJsonWebKey { 139 - #[serde(skip_serializing_if = "Option::is_none", default)] 140 - pub kid: Option<String>, 141 - 142 - #[serde(skip_serializing_if = "Option::is_none", default)] 143 - pub alg: Option<String>, 144 - 145 - #[serde(flatten)] 146 - pub jwk: JwkEcKey, 147 - } 148 - 149 - #[derive(Serialize, Deserialize, Clone)] 150 - pub struct WrappedJsonWebKeySet { 151 - pub keys: Vec<WrappedJsonWebKey>, 152 - } 153 - 154 - pub fn generate() -> WrappedJsonWebKey { 155 - let secret_key = SecretKey::random(&mut OsRng); 156 - 157 - let kid = ulid::Ulid::new().to_string(); 158 - 159 - WrappedJsonWebKey { 160 - kid: Some(kid), 161 - alg: Some("ES256".to_string()), 162 - jwk: secret_key.to_jwk(), 163 - } 164 - } 165 - } 166 - 167 - pub mod jwt { 168 - 169 - use std::collections::BTreeMap; 170 - 171 - use elliptic_curve::JwkEcKey; 172 - use serde::{Deserialize, Serialize}; 173 - 174 - #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] 175 - pub struct Header { 176 - #[serde(rename = "alg", skip_serializing_if = "Option::is_none")] 177 - pub algorithm: Option<String>, 178 - 179 - #[serde(rename = "kid", skip_serializing_if = "Option::is_none")] 180 - pub key_id: Option<String>, 181 - 182 - #[serde(rename = "typ", skip_serializing_if = "Option::is_none")] 183 - pub type_: Option<String>, 184 - 185 - #[serde(rename = "jwk", skip_serializing_if = "Option::is_none")] 186 - pub json_web_key: Option<JwkEcKey>, 187 - } 188 - 189 - #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] 190 - pub struct Claims { 191 - #[serde(flatten)] 192 - pub jose: JoseClaims, 193 - #[serde(flatten)] 194 - pub private: BTreeMap<String, serde_json::Value>, 195 - } 196 - 197 - impl Claims { 198 - pub fn new(jose: JoseClaims) -> Self { 199 - Claims { 200 - jose, 201 - private: BTreeMap::new(), 202 - } 203 - } 204 - } 205 - 206 - pub type SecondsSinceEpoch = u64; 207 - 208 - #[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)] 209 - pub struct JoseClaims { 210 - #[serde(rename = "iss", skip_serializing_if = "Option::is_none")] 211 - pub issuer: Option<String>, 212 - 213 - #[serde(rename = "sub", skip_serializing_if = "Option::is_none")] 214 - pub subject: Option<String>, 215 - 216 - #[serde(rename = "aud", skip_serializing_if = "Option::is_none")] 217 - pub audience: Option<String>, 218 - 219 - #[serde(rename = "exp", skip_serializing_if = "Option::is_none")] 220 - pub expiration: Option<SecondsSinceEpoch>, 221 - 222 - #[serde(rename = "nbf", skip_serializing_if = "Option::is_none")] 223 - pub not_before: Option<SecondsSinceEpoch>, 224 - 225 - #[serde(rename = "iat", skip_serializing_if = "Option::is_none")] 226 - pub issued_at: Option<SecondsSinceEpoch>, 227 - 228 - #[serde(rename = "jti", skip_serializing_if = "Option::is_none")] 229 - pub json_web_token_id: Option<String>, 230 - 231 - #[serde(rename = "htm", skip_serializing_if = "Option::is_none")] 232 - pub http_method: Option<String>, 233 - 234 - #[serde(rename = "htu", skip_serializing_if = "Option::is_none")] 235 - pub http_uri: Option<String>, 236 - 237 - #[serde(rename = "nonce", skip_serializing_if = "Option::is_none")] 238 - pub nonce: Option<String>, 239 - 240 - #[serde(rename = "ath", skip_serializing_if = "Option::is_none")] 241 - pub auth: Option<String>, 242 - } 243 - }
-140
src/jose_errors.rs
··· 1 - use thiserror::Error; 2 - 3 - /// Represents errors that can occur during JOSE (JSON Object Signing and Encryption) operations. 4 - /// 5 - /// These errors are related to JSON Web Token (JWT) signing and verification, 6 - /// JSON Web Key (JWK) operations, and DPoP (Demonstrating Proof-of-Possession) functionality. 7 - #[derive(Debug, Error)] 8 - pub enum JoseError { 9 - /// Error when token signing fails. 10 - /// 11 - /// This error occurs when the application tries to sign a JWT token 12 - /// using an ECDSA signing key but the signing operation fails. 13 - #[error("error-jose-1 Failed to sign token: {0:?}")] 14 - SigningFailed(p256::ecdsa::Error), 15 - 16 - /// Error when a required signing key is not found. 17 - /// 18 - /// This error occurs when the application tries to use a signing key 19 - /// that is not available in the loaded configuration. 20 - #[error("error-jose-2 Signing key not found")] 21 - SigningKeyNotFound, 22 - 23 - /// Error when a simple error cannot be parsed. 24 - /// 25 - /// This error occurs when the application fails to parse an error 26 - /// response from an OAuth server. 27 - #[error("error-jose-3 Unable to parse simple error")] 28 - UnableToParseSimpleError, 29 - 30 - /// Error when a required DPoP header is missing. 31 - /// 32 - /// This error occurs when making a request to a protected resource 33 - /// that requires a DPoP header, but the header is not present. 34 - #[error("error-jose-4 Missing DPoP header")] 35 - MissingDpopHeader, 36 - 37 - /// Error when a DPoP header cannot be parsed. 38 - /// 39 - /// This error occurs when the application receives a DPoP header 40 - /// that is malformed or contains invalid data. 41 - #[error("error-jose-5 Unable to parse DPoP header: {0}")] 42 - UnableToParseDpopHeader(String), 43 - 44 - /// Error when a DPoP proof token cannot be created. 45 - /// 46 - /// This error occurs when the application fails to create a valid 47 - /// DPoP proof token required for accessing protected resources. 48 - #[error("error-jose-6 Unable to mint DPoP proof token: {0}")] 49 - UnableToMintDpopProofToken(String), 50 - 51 - /// Error when an unexpected error occurs during JOSE operations. 52 - /// 53 - /// This is a catch-all error for unexpected issues that occur 54 - /// during JOSE-related operations. 55 - #[error("error-jose-7 Unexpected error: {0}")] 56 - UnexpectedError(String), 57 - 58 - /// Error when a JWT token has an invalid format. 59 - /// 60 - /// This error occurs when a JWT token doesn't have three parts 61 - /// separated by periods (header.payload.signature). 62 - #[error("error-jose-8 Invalid token format")] 63 - InvalidTokenFormat, 64 - 65 - /// Error when a JWT header cannot be decoded or parsed. 66 - /// 67 - /// This error occurs when the header part of a JWT token contains 68 - /// invalid base64url-encoded data or invalid JSON. 69 - #[error("error-jose-9 Invalid token header")] 70 - InvalidHeader, 71 - 72 - /// Error when a JWT claims part cannot be decoded or parsed. 73 - /// 74 - /// This error occurs when the claims part of a JWT token contains 75 - /// invalid base64url-encoded data or invalid JSON. 76 - #[error("error-jose-10 Invalid token claims")] 77 - InvalidClaims, 78 - 79 - /// Error when a JWT signature cannot be decoded. 80 - /// 81 - /// This error occurs when the signature part of a JWT token contains 82 - /// invalid base64url-encoded data. 83 - #[error("error-jose-11 Invalid token signature")] 84 - InvalidSignature, 85 - 86 - /// Error when JWT signature verification fails. 87 - /// 88 - /// This error occurs when the signature of a JWT token doesn't match 89 - /// the expected signature computed from the header and claims. 90 - #[error("error-jose-12 Signature verification failed")] 91 - SignatureVerificationFailed, 92 - 93 - /// Error when a JWT token has expired. 94 - /// 95 - /// This error occurs when the current time is past the expiration 96 - /// time (exp) specified in the JWT claims. 97 - #[error("error-jose-13 Token has expired")] 98 - TokenExpired, 99 - 100 - /// Error when a JWT token is not yet valid. 101 - /// 102 - /// This error occurs when the current time is before the not-before 103 - /// time (nbf) specified in the JWT claims. 104 - #[error("error-jose-14 Token is not yet valid")] 105 - TokenNotYetValid, 106 - 107 - /// Error when the system time cannot be determined. 108 - /// 109 - /// This rare error occurs when the system time cannot be retrieved 110 - /// or is invalid. 111 - #[error("error-jose-15 System time error")] 112 - SystemTimeError, 113 - 114 - /// Error when a JWT token uses an unsupported algorithm. 115 - /// 116 - /// This error occurs when the JWT token uses an algorithm (alg) 117 - /// that the application doesn't support or allow. 118 - #[error("error-jose-16 Unsupported algorithm")] 119 - UnsupportedAlgorithm, 120 - 121 - /// Error when a JWT token has invalid key parameters. 122 - /// 123 - /// This error occurs when the JWT token uses key parameters that 124 - /// are invalid or not supported. 125 - #[error("error-jose-17 Invalid key parameters: {0}")] 126 - InvalidKeyParameters(String), 127 - } 128 - 129 - /// Represents errors that can occur during JSON Web Key (JWK) operations. 130 - /// 131 - /// These errors relate to operations with cryptographic keys in JWK format. 132 - #[derive(Debug, Error)] 133 - pub enum JwkError { 134 - /// Error when a secret JWK key is not found. 135 - /// 136 - /// This error occurs when the application tries to use a secret JWK key 137 - /// that is not available in the loaded configuration. 138 - #[error("error-jwk-1 Secret JWK key not found")] 139 - SecretKeyNotFound, 140 - }
+21
src/key_provider.rs
··· 1 + use async_trait::async_trait; 2 + use atproto_identity::key::{KeyData, KeyProvider}; 3 + use std::collections::HashMap; 4 + 5 + #[derive(Clone)] 6 + pub struct SimpleKeyProvider { 7 + keys: HashMap<String, KeyData>, 8 + } 9 + 10 + impl SimpleKeyProvider { 11 + pub fn new(keys: HashMap<String, KeyData>) -> Self { 12 + Self { keys } 13 + } 14 + } 15 + 16 + #[async_trait] 17 + impl KeyProvider for SimpleKeyProvider { 18 + async fn get_private_key_by_id(&self, key_id: &str) -> anyhow::Result<Option<KeyData>> { 19 + Ok(self.keys.get(key_id).cloned()) 20 + } 21 + }
+4 -11
src/lib.rs
··· 1 1 pub mod atproto; 2 2 pub mod config; 3 3 pub mod config_errors; 4 - pub mod did; 5 - pub mod encoding; 6 - pub mod encoding_errors; 7 4 pub mod errors; 8 5 pub mod http; 9 6 pub mod i18n; 10 - pub mod jose; 11 - pub mod jose_errors; 12 - pub mod oauth; 13 - pub mod oauth_client_errors; 14 - pub mod oauth_errors; 7 + pub mod key_provider; 15 8 pub mod refresh_tokens_errors; 16 - pub mod resolve; 17 9 pub mod storage; 18 - // Removing storage_oauth_errors, consolidated with storage/oauth_model_errors 10 + 11 + pub mod task_identity_refresh; 12 + pub mod task_oauth_requests_cleanup; 19 13 pub mod task_refresh_tokens; 20 - pub mod validation;
-664
src/oauth.rs
··· 1 - use dpop::DpopRetry; 2 - use p256::SecretKey; 3 - use rand::distributions::{Alphanumeric, DistString}; 4 - use reqwest_chain::ChainMiddleware; 5 - use reqwest_middleware::ClientBuilder; 6 - use std::time::Duration; 7 - 8 - use crate::oauth_client_errors::OAuthClientError; 9 - use crate::oauth_errors::{AuthServerValidationError, ResourceValidationError}; 10 - use model::{AuthorizationServer, OAuthProtectedResource, ParResponse, TokenResponse}; 11 - 12 - use crate::{ 13 - jose::{ 14 - jwt::{Claims, Header, JoseClaims}, 15 - mint_token, 16 - }, 17 - storage::{ 18 - handle::model::Handle, 19 - oauth::model::{OAuthRequest, OAuthRequestState}, 20 - }, 21 - }; 22 - 23 - const HTTP_CLIENT_TIMEOUT_SECS: u64 = 8; 24 - 25 - pub async fn pds_resources( 26 - http_client: &reqwest::Client, 27 - pds: &str, 28 - ) -> Result<(OAuthProtectedResource, AuthorizationServer), OAuthClientError> { 29 - let protected_resource = oauth_protected_resource(http_client, pds).await?; 30 - 31 - let first_authorization_server = protected_resource 32 - .authorization_servers 33 - .first() 34 - .ok_or(OAuthClientError::InvalidOAuthProtectedResource)?; 35 - 36 - let authorization_server = 37 - oauth_authorization_server(http_client, first_authorization_server).await?; 38 - Ok((protected_resource, authorization_server)) 39 - } 40 - 41 - pub async fn oauth_protected_resource( 42 - http_client: &reqwest::Client, 43 - pds: &str, 44 - ) -> Result<OAuthProtectedResource, OAuthClientError> { 45 - let destination = format!("{}/.well-known/oauth-protected-resource", pds); 46 - 47 - let resource: OAuthProtectedResource = http_client 48 - .get(destination) 49 - .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS)) 50 - .send() 51 - .await 52 - .map_err(OAuthClientError::OAuthProtectedResourceRequestFailed)? 53 - .json() 54 - .await 55 - .map_err(OAuthClientError::MalformedOAuthProtectedResourceResponse)?; 56 - 57 - if resource.resource != pds { 58 - return Err(OAuthClientError::InvalidOAuthProtectedResourceResponse( 59 - ResourceValidationError::ResourceMustMatchPds.into(), 60 - )); 61 - } 62 - 63 - if resource.authorization_servers.is_empty() { 64 - return Err(OAuthClientError::InvalidOAuthProtectedResourceResponse( 65 - ResourceValidationError::AuthorizationServersMustNotBeEmpty.into(), 66 - )); 67 - } 68 - 69 - Ok(resource) 70 - } 71 - 72 - #[tracing::instrument(skip(http_client), err)] 73 - pub async fn oauth_authorization_server( 74 - http_client: &reqwest::Client, 75 - pds: &str, 76 - ) -> Result<AuthorizationServer, OAuthClientError> { 77 - let destination = format!("{}/.well-known/oauth-authorization-server", pds); 78 - 79 - let resource: AuthorizationServer = http_client 80 - .get(destination) 81 - .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS)) 82 - .send() 83 - .await 84 - .map_err(OAuthClientError::AuthorizationServerRequestFailed)? 85 - .json() 86 - .await 87 - .map_err(OAuthClientError::MalformedAuthorizationServerResponse)?; 88 - 89 - // All of this is going to change. 90 - 91 - if resource.issuer != pds { 92 - return Err(OAuthClientError::InvalidAuthorizationServerResponse( 93 - AuthServerValidationError::IssuerMustMatchPds.into(), 94 - )); 95 - } 96 - 97 - resource 98 - .response_types_supported 99 - .iter() 100 - .find(|&x| x == "code") 101 - .ok_or(OAuthClientError::InvalidAuthorizationServerResponse( 102 - AuthServerValidationError::ResponseTypesSupportMustIncludeCode.into(), 103 - ))?; 104 - 105 - resource 106 - .grant_types_supported 107 - .iter() 108 - .find(|&x| x == "authorization_code") 109 - .ok_or(OAuthClientError::InvalidAuthorizationServerResponse( 110 - AuthServerValidationError::GrantTypesSupportMustIncludeAuthorizationCode.into(), 111 - ))?; 112 - resource 113 - .grant_types_supported 114 - .iter() 115 - .find(|&x| x == "refresh_token") 116 - .ok_or(OAuthClientError::InvalidAuthorizationServerResponse( 117 - AuthServerValidationError::GrantTypesSupportMustIncludeRefreshToken.into(), 118 - ))?; 119 - resource 120 - .code_challenge_methods_supported 121 - .iter() 122 - .find(|&x| x == "S256") 123 - .ok_or(OAuthClientError::InvalidAuthorizationServerResponse( 124 - AuthServerValidationError::CodeChallengeMethodsSupportedMustIncludeS256.into(), 125 - ))?; 126 - resource 127 - .token_endpoint_auth_methods_supported 128 - .iter() 129 - .find(|&x| x == "none") 130 - .ok_or(OAuthClientError::InvalidAuthorizationServerResponse( 131 - AuthServerValidationError::TokenEndpointAuthMethodsSupportedMustIncludeNone.into(), 132 - ))?; 133 - resource 134 - .token_endpoint_auth_methods_supported 135 - .iter() 136 - .find(|&x| x == "private_key_jwt") 137 - .ok_or(OAuthClientError::InvalidAuthorizationServerResponse( 138 - AuthServerValidationError::TokenEndpointAuthMethodsSupportedMustIncludePrivateKeyJwt 139 - .into(), 140 - ))?; 141 - resource 142 - .token_endpoint_auth_signing_alg_values_supported 143 - .iter() 144 - .find(|&x| x == "ES256") 145 - .ok_or(OAuthClientError::InvalidAuthorizationServerResponse( 146 - AuthServerValidationError::TokenEndpointAuthSigningAlgValuesMustIncludeES256.into(), 147 - ))?; 148 - resource 149 - .scopes_supported 150 - .iter() 151 - .find(|&x| x == "atproto") 152 - .ok_or(OAuthClientError::InvalidAuthorizationServerResponse( 153 - AuthServerValidationError::ScopesSupportedMustIncludeAtProto.into(), 154 - ))?; 155 - resource 156 - .scopes_supported 157 - .iter() 158 - .find(|&x| x == "transition:generic") 159 - .ok_or(OAuthClientError::InvalidAuthorizationServerResponse( 160 - AuthServerValidationError::ScopesSupportedMustIncludeTransitionGeneric.into(), 161 - ))?; 162 - resource 163 - .dpop_signing_alg_values_supported 164 - .iter() 165 - .find(|&x| x == "ES256") 166 - .ok_or(OAuthClientError::InvalidAuthorizationServerResponse( 167 - AuthServerValidationError::DpopSigningAlgValuesSupportedMustIncludeES256.into(), 168 - ))?; 169 - 170 - if !(resource.authorization_response_iss_parameter_supported 171 - && resource.require_pushed_authorization_requests 172 - && resource.client_id_metadata_document_supported) 173 - { 174 - return Err(OAuthClientError::InvalidAuthorizationServerResponse( 175 - AuthServerValidationError::RequiredServerFeaturesMustBeSupported.into(), 176 - )); 177 - } 178 - 179 - Ok(resource) 180 - } 181 - 182 - pub async fn oauth_init( 183 - http_client: &reqwest::Client, 184 - external_url_base: &str, 185 - (secret_key_id, secret_key): (&str, SecretKey), 186 - dpop_secret_key: &SecretKey, 187 - handle: &str, 188 - authorization_server: &AuthorizationServer, 189 - oauth_request_state: &OAuthRequestState, 190 - ) -> Result<ParResponse, OAuthClientError> { 191 - let par_url = authorization_server 192 - .pushed_authorization_request_endpoint 193 - .clone(); 194 - 195 - let redirect_uri = format!("https://{}/oauth/callback", external_url_base); 196 - let client_id = format!("https://{}/oauth/client-metadata.json", external_url_base); 197 - 198 - let scope = "atproto transition:generic".to_string(); 199 - 200 - let client_assertion_header = Header { 201 - algorithm: Some("ES256".to_string()), 202 - key_id: Some(secret_key_id.to_string()), 203 - ..Default::default() 204 - }; 205 - 206 - let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30); 207 - let client_assertion_claims = Claims::new(JoseClaims { 208 - issuer: Some(client_id.clone()), 209 - subject: Some(client_id.clone()), 210 - audience: Some(authorization_server.issuer.clone()), 211 - json_web_token_id: Some(client_assertion_jti), 212 - issued_at: Some(chrono::Utc::now().timestamp() as u64), 213 - ..Default::default() 214 - }); 215 - tracing::info!("client_assertion_claims: {:?}", client_assertion_claims); 216 - 217 - let client_assertion_token = mint_token( 218 - &secret_key, 219 - &client_assertion_header, 220 - &client_assertion_claims, 221 - ) 222 - .map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?; 223 - 224 - let now = chrono::Utc::now(); 225 - let public_key = dpop_secret_key.public_key(); 226 - 227 - let dpop_proof_header = Header { 228 - type_: Some("dpop+jwt".to_string()), 229 - algorithm: Some("ES256".to_string()), 230 - json_web_key: Some(public_key.to_jwk()), 231 - ..Default::default() 232 - }; 233 - let dpop_proof_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30); 234 - 235 - let dpop_proof_claim = Claims::new(JoseClaims { 236 - json_web_token_id: Some(dpop_proof_jti), 237 - http_method: Some("POST".to_string()), 238 - http_uri: Some(par_url.clone()), 239 - issued_at: Some(now.timestamp() as u64), 240 - expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64), 241 - ..Default::default() 242 - }); 243 - let dpop_proof_token = mint_token(dpop_secret_key, &dpop_proof_header, &dpop_proof_claim) 244 - .map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?; 245 - 246 - let dpop_retry = DpopRetry::new( 247 - dpop_proof_header.clone(), 248 - dpop_proof_claim.clone(), 249 - dpop_secret_key.clone(), 250 - ); 251 - 252 - let dpop_retry_client = ClientBuilder::new(http_client.clone()) 253 - .with(ChainMiddleware::new(dpop_retry.clone())) 254 - .build(); 255 - 256 - let params = [ 257 - ("response_type", "code"), 258 - ("code_challenge", &oauth_request_state.code_challenge), 259 - ("code_challenge_method", "S256"), 260 - ("client_id", client_id.as_str()), 261 - ("state", oauth_request_state.state.as_str()), 262 - ("redirect_uri", redirect_uri.as_str()), 263 - ("scope", scope.as_str()), 264 - ("login_hint", handle), 265 - ( 266 - "client_assertion_type", 267 - "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", 268 - ), 269 - ("client_assertion", client_assertion_token.as_str()), 270 - ]; 271 - 272 - tracing::warn!("params: {:?}", params); 273 - 274 - dpop_retry_client 275 - .post(par_url) 276 - .header("DPoP", dpop_proof_token.as_str()) 277 - .form(&params) 278 - .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS)) 279 - .send() 280 - .await 281 - .map_err(OAuthClientError::PARMiddlewareRequestFailed)? 282 - .json() 283 - .await 284 - .map_err(OAuthClientError::MalformedPARResponse) 285 - } 286 - 287 - pub async fn oauth_complete( 288 - http_client: &reqwest::Client, 289 - external_url_base: &str, 290 - (secret_key_id, secret_key): (&str, SecretKey), 291 - callback_code: &str, 292 - oauth_request: &OAuthRequest, 293 - handle: &Handle, 294 - dpop_secret_key: &SecretKey, 295 - ) -> Result<TokenResponse, OAuthClientError> { 296 - let (_, authorization_server) = pds_resources(http_client, &handle.pds).await?; 297 - 298 - let client_assertion_header = Header { 299 - algorithm: Some("ES256".to_string()), 300 - key_id: Some(secret_key_id.to_string()), 301 - ..Default::default() 302 - }; 303 - 304 - let client_id = format!("https://{}/oauth/client-metadata.json", external_url_base); 305 - let redirect_uri = format!("https://{}/oauth/callback", external_url_base); 306 - 307 - let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30); 308 - let client_assertion_claims = Claims::new(JoseClaims { 309 - issuer: Some(client_id.clone()), 310 - subject: Some(client_id.clone()), 311 - audience: Some(authorization_server.issuer.clone()), 312 - json_web_token_id: Some(client_assertion_jti), 313 - issued_at: Some(chrono::Utc::now().timestamp() as u64), 314 - ..Default::default() 315 - }); 316 - 317 - let client_assertion_token = mint_token( 318 - &secret_key, 319 - &client_assertion_header, 320 - &client_assertion_claims, 321 - ) 322 - .map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?; 323 - 324 - let params = [ 325 - ("client_id", client_id.as_str()), 326 - ("redirect_uri", redirect_uri.as_str()), 327 - ("grant_type", "authorization_code"), 328 - ("code", callback_code), 329 - ("code_verifier", &oauth_request.pkce_verifier), 330 - ( 331 - "client_assertion_type", 332 - "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", 333 - ), 334 - ("client_assertion", client_assertion_token.as_str()), 335 - ]; 336 - 337 - let public_key = dpop_secret_key.public_key(); 338 - 339 - let token_endpoint = authorization_server.token_endpoint.clone(); 340 - 341 - let now = chrono::Utc::now(); 342 - 343 - let dpop_proof_header = Header { 344 - type_: Some("dpop+jwt".to_string()), 345 - algorithm: Some("ES256".to_string()), 346 - json_web_key: Some(public_key.to_jwk()), 347 - ..Default::default() 348 - }; 349 - let dpop_proof_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30); 350 - let dpop_proof_claim = Claims::new(JoseClaims { 351 - json_web_token_id: Some(dpop_proof_jti), 352 - http_method: Some("POST".to_string()), 353 - http_uri: Some(authorization_server.token_endpoint.clone()), 354 - issued_at: Some(now.timestamp() as u64), 355 - expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64), 356 - ..Default::default() 357 - }); 358 - let dpop_proof_token = mint_token(dpop_secret_key, &dpop_proof_header, &dpop_proof_claim) 359 - .map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?; 360 - 361 - let dpop_retry = DpopRetry::new( 362 - dpop_proof_header.clone(), 363 - dpop_proof_claim.clone(), 364 - dpop_secret_key.clone(), 365 - ); 366 - 367 - let dpop_retry_client = ClientBuilder::new(http_client.clone()) 368 - .with(ChainMiddleware::new(dpop_retry.clone())) 369 - .build(); 370 - 371 - dpop_retry_client 372 - .post(token_endpoint) 373 - .header("DPoP", dpop_proof_token.as_str()) 374 - .form(&params) 375 - .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS)) 376 - .send() 377 - .await 378 - .map_err(OAuthClientError::TokenMiddlewareRequestFailed)? 379 - .json() 380 - .await 381 - .map_err(OAuthClientError::MalformedTokenResponse) 382 - } 383 - 384 - pub async fn client_oauth_refresh( 385 - http_client: &reqwest::Client, 386 - external_url_base: &str, 387 - (secret_key_id, secret_key): (&str, SecretKey), 388 - refresh_token: &str, 389 - handle: &Handle, 390 - dpop_secret_key: &SecretKey, 391 - ) -> Result<TokenResponse, OAuthClientError> { 392 - let (_, authorization_server) = pds_resources(http_client, &handle.pds).await?; 393 - 394 - let client_assertion_header = Header { 395 - algorithm: Some("ES256".to_string()), 396 - key_id: Some(secret_key_id.to_string()), 397 - ..Default::default() 398 - }; 399 - 400 - let client_id = format!("https://{}/oauth/client-metadata.json", external_url_base); 401 - let redirect_uri = format!("https://{}/oauth/callback", external_url_base); 402 - 403 - let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30); 404 - let client_assertion_claims = Claims::new(JoseClaims { 405 - issuer: Some(client_id.clone()), 406 - subject: Some(client_id.clone()), 407 - audience: Some(authorization_server.issuer.clone()), 408 - json_web_token_id: Some(client_assertion_jti), 409 - issued_at: Some(chrono::Utc::now().timestamp() as u64), 410 - ..Default::default() 411 - }); 412 - 413 - let client_assertion_token = mint_token( 414 - &secret_key, 415 - &client_assertion_header, 416 - &client_assertion_claims, 417 - ) 418 - .map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?; 419 - 420 - let params = [ 421 - ("client_id", client_id.as_str()), 422 - ("redirect_uri", redirect_uri.as_str()), 423 - ("grant_type", "refresh_token"), 424 - ("refresh_token", refresh_token), 425 - ( 426 - "client_assertion_type", 427 - "urn:ietf:params:oauth:client-assertion-type:jwt-bearer", 428 - ), 429 - ("client_assertion", client_assertion_token.as_str()), 430 - ]; 431 - 432 - tracing::info!("params: {:?}", params); 433 - 434 - let public_key = dpop_secret_key.public_key(); 435 - 436 - let token_endpoint = authorization_server.token_endpoint.clone(); 437 - 438 - let now = chrono::Utc::now(); 439 - 440 - let dpop_proof_header = Header { 441 - type_: Some("dpop+jwt".to_string()), 442 - algorithm: Some("ES256".to_string()), 443 - json_web_key: Some(public_key.to_jwk()), 444 - ..Default::default() 445 - }; 446 - let dpop_proof_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30); 447 - let dpop_proof_claim = Claims::new(JoseClaims { 448 - json_web_token_id: Some(dpop_proof_jti), 449 - http_method: Some("POST".to_string()), 450 - http_uri: Some(authorization_server.token_endpoint.clone()), 451 - issued_at: Some(now.timestamp() as u64), 452 - expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64), 453 - ..Default::default() 454 - }); 455 - let dpop_proof_token = mint_token(dpop_secret_key, &dpop_proof_header, &dpop_proof_claim) 456 - .map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?; 457 - 458 - let dpop_retry = DpopRetry::new( 459 - dpop_proof_header.clone(), 460 - dpop_proof_claim.clone(), 461 - dpop_secret_key.clone(), 462 - ); 463 - 464 - let dpop_retry_client = ClientBuilder::new(http_client.clone()) 465 - .with(ChainMiddleware::new(dpop_retry.clone())) 466 - .build(); 467 - 468 - dpop_retry_client 469 - .post(token_endpoint) 470 - .header("DPoP", dpop_proof_token.as_str()) 471 - .form(&params) 472 - .timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS)) 473 - .send() 474 - .await 475 - .map_err(OAuthClientError::TokenMiddlewareRequestFailed)? 476 - .json() 477 - .await 478 - .map_err(OAuthClientError::MalformedTokenResponse) 479 - } 480 - 481 - pub mod dpop { 482 - use p256::SecretKey; 483 - use reqwest::header::HeaderValue; 484 - use reqwest_chain::Chainer; 485 - use serde::Deserialize; 486 - 487 - use crate::{ 488 - jose::{ 489 - jwt::{Claims, Header}, 490 - mint_token, 491 - }, 492 - jose_errors::JoseError, 493 - }; 494 - 495 - #[derive(Clone, Debug, Deserialize)] 496 - pub struct SimpleError { 497 - pub error: Option<String>, 498 - pub error_description: Option<String>, 499 - pub message: Option<String>, 500 - } 501 - 502 - impl std::fmt::Display for SimpleError { 503 - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { 504 - if let Some(value) = &self.error { 505 - write!(f, "{}", value) 506 - } else if let Some(value) = &self.message { 507 - write!(f, "{}", value) 508 - } else if let Some(value) = &self.error_description { 509 - write!(f, "{}", value) 510 - } else { 511 - write!(f, "unknown") 512 - } 513 - } 514 - } 515 - 516 - #[derive(Clone)] 517 - pub struct DpopRetry { 518 - pub header: Header, 519 - pub claims: Claims, 520 - pub secret: SecretKey, 521 - } 522 - 523 - impl DpopRetry { 524 - pub fn new(header: Header, claims: Claims, secret: SecretKey) -> Self { 525 - DpopRetry { 526 - header, 527 - claims, 528 - secret, 529 - } 530 - } 531 - } 532 - 533 - #[async_trait::async_trait] 534 - impl Chainer for DpopRetry { 535 - type State = (); 536 - 537 - async fn chain( 538 - &self, 539 - result: Result<reqwest::Response, reqwest_middleware::Error>, 540 - _state: &mut Self::State, 541 - request: &mut reqwest::Request, 542 - ) -> Result<Option<reqwest::Response>, reqwest_middleware::Error> { 543 - let response = result?; 544 - 545 - let status_code = response.status(); 546 - 547 - if status_code != 400 && status_code != 401 { 548 - return Ok(Some(response)); 549 - }; 550 - 551 - let headers = response.headers().clone(); 552 - 553 - let simple_error = response.json::<SimpleError>().await; 554 - if simple_error.is_err() { 555 - return Err(reqwest_middleware::Error::Middleware( 556 - JoseError::UnableToParseSimpleError.into(), 557 - )); 558 - } 559 - 560 - let simple_error = simple_error.unwrap(); 561 - 562 - tracing::error!("dpop error: {:?}", simple_error); 563 - 564 - let is_use_dpop_nonce_error = simple_error 565 - .clone() 566 - .error 567 - .is_some_and(|error_value| error_value == "use_dpop_nonce"); 568 - 569 - if !is_use_dpop_nonce_error { 570 - return Err(reqwest_middleware::Error::Middleware( 571 - JoseError::UnexpectedError(simple_error.to_string()).into(), 572 - )); 573 - } 574 - 575 - let dpop_header = headers.get("DPoP-Nonce"); 576 - 577 - if dpop_header.is_none() { 578 - return Err(reqwest_middleware::Error::Middleware( 579 - JoseError::MissingDpopHeader.into(), 580 - )); 581 - } 582 - 583 - let new_dpop_header = dpop_header.unwrap().to_str().map_err(|dpop_header_err| { 584 - reqwest_middleware::Error::Middleware( 585 - JoseError::UnableToParseDpopHeader(dpop_header_err.to_string()).into(), 586 - ) 587 - })?; 588 - 589 - let dpop_proof_header = self.header.clone(); 590 - let mut dpop_proof_claim = self.claims.clone(); 591 - dpop_proof_claim 592 - .private 593 - .insert("nonce".to_string(), new_dpop_header.to_string().into()); 594 - 595 - let dpop_proof_token = mint_token(&self.secret, &dpop_proof_header, &dpop_proof_claim) 596 - .map_err(|dpop_proof_token_err| { 597 - reqwest_middleware::Error::Middleware( 598 - JoseError::UnableToMintDpopProofToken(dpop_proof_token_err.to_string()) 599 - .into(), 600 - ) 601 - })?; 602 - 603 - request.headers_mut().insert( 604 - "DPoP", 605 - HeaderValue::from_str(&dpop_proof_token).expect("invalid header value"), 606 - ); 607 - Ok(None) 608 - } 609 - } 610 - } 611 - 612 - pub mod model { 613 - use serde::Deserialize; 614 - 615 - #[derive(Clone, Deserialize)] 616 - pub struct OAuthProtectedResource { 617 - pub resource: String, 618 - pub authorization_servers: Vec<String>, 619 - pub scopes_supported: Vec<String>, 620 - pub bearer_methods_supported: Vec<String>, 621 - } 622 - 623 - #[derive(Clone, Deserialize, Default, Debug)] 624 - pub struct AuthorizationServer { 625 - pub introspection_endpoint: String, 626 - pub authorization_endpoint: String, 627 - pub authorization_response_iss_parameter_supported: bool, 628 - pub client_id_metadata_document_supported: bool, 629 - pub code_challenge_methods_supported: Vec<String>, 630 - pub dpop_signing_alg_values_supported: Vec<String>, 631 - pub grant_types_supported: Vec<String>, 632 - pub issuer: String, 633 - pub pushed_authorization_request_endpoint: String, 634 - pub request_parameter_supported: bool, 635 - pub require_pushed_authorization_requests: bool, 636 - pub response_types_supported: Vec<String>, 637 - pub scopes_supported: Vec<String>, 638 - pub token_endpoint_auth_methods_supported: Vec<String>, 639 - pub token_endpoint_auth_signing_alg_values_supported: Vec<String>, 640 - pub token_endpoint: String, 641 - } 642 - 643 - #[derive(Clone, Deserialize)] 644 - pub struct ParResponse { 645 - pub request_uri: String, 646 - pub expires_in: u64, 647 - } 648 - 649 - #[derive(Clone, Deserialize)] 650 - pub struct TokenResponse { 651 - pub access_token: String, 652 - pub token_type: String, 653 - pub refresh_token: String, 654 - pub scope: String, 655 - pub expires_in: u32, 656 - pub sub: String, 657 - } 658 - } 659 - 660 - // This errors module is now deprecated. 661 - // Use crate::oauth_client_errors::OAuthClientError instead. 662 - pub mod errors { 663 - pub use crate::oauth_client_errors::OAuthClientError; 664 - }
-99
src/oauth_client_errors.rs
··· 1 - use thiserror::Error; 2 - 3 - /// Represents errors that can occur during OAuth client operations. 4 - /// 5 - /// These errors are related to the OAuth client functionality, including 6 - /// interacting with authorization servers, protected resources, and token management. 7 - #[derive(Debug, Error)] 8 - pub enum OAuthClientError { 9 - /// Error when a request to the authorization server fails. 10 - /// 11 - /// This error occurs when the OAuth client fails to establish a connection 12 - /// or complete a request to the authorization server. 13 - #[error("error-oauth-client-1 Authorization Server Request Failed: {0:?}")] 14 - AuthorizationServerRequestFailed(reqwest::Error), 15 - 16 - /// Error when the authorization server response is malformed. 17 - /// 18 - /// This error occurs when the response from the authorization server 19 - /// cannot be properly parsed or processed. 20 - #[error("error-oauth-client-2 Malformed Authorization Server Response: {0:?}")] 21 - MalformedAuthorizationServerResponse(reqwest::Error), 22 - 23 - /// Error when the authorization server response is invalid. 24 - /// 25 - /// This error occurs when the response from the authorization server 26 - /// is well-formed but contains invalid or unexpected data. 27 - #[error("error-oauth-client-3 Invalid Authorization Server Response: {0:?}")] 28 - InvalidAuthorizationServerResponse(anyhow::Error), 29 - 30 - /// Error when an OAuth protected resource is invalid. 31 - /// 32 - /// This error occurs when trying to access a protected resource that 33 - /// is not properly configured for OAuth access. 34 - #[error("error-oauth-client-4 Invalid OAuth Protected Resource")] 35 - InvalidOAuthProtectedResource, 36 - 37 - /// Error when a request to an OAuth protected resource fails. 38 - /// 39 - /// This error occurs when the OAuth client fails to establish a connection 40 - /// or complete a request to a protected resource. 41 - #[error("error-oauth-client-5 OAuth Protected Resource Request Failed: {0:?}")] 42 - OAuthProtectedResourceRequestFailed(reqwest::Error), 43 - 44 - /// Error when a protected resource response is malformed. 45 - /// 46 - /// This error occurs when the response from a protected resource 47 - /// cannot be properly parsed or processed. 48 - #[error("error-oauth-client-6 Malformed OAuth Protected Resource Response: {0:?}")] 49 - MalformedOAuthProtectedResourceResponse(reqwest::Error), 50 - 51 - /// Error when a protected resource response is invalid. 52 - /// 53 - /// This error occurs when the response from a protected resource 54 - /// is well-formed but contains invalid or unexpected data. 55 - #[error("error-oauth-client-7 Invalid OAuth Protected Resource Response: {0:?}")] 56 - InvalidOAuthProtectedResourceResponse(anyhow::Error), 57 - 58 - /// Error when a PAR middleware request fails. 59 - /// 60 - /// This error occurs when a Pushed Authorization Request (PAR) middleware 61 - /// request fails to complete successfully. 62 - #[error("error-oauth-client-8 PAR Middleware Request Failed: {0:?}")] 63 - PARMiddlewareRequestFailed(reqwest_middleware::Error), 64 - 65 - /// Error when a PAR request fails. 66 - /// 67 - /// This error occurs when a Pushed Authorization Request (PAR) 68 - /// fails to be properly processed by the authorization server. 69 - #[error("error-oauth-client-9 PAR Request Failed: {0:?}")] 70 - PARRequestFailed(reqwest::Error), 71 - 72 - /// Error when a PAR response is malformed. 73 - /// 74 - /// This error occurs when the response from a Pushed Authorization 75 - /// Request (PAR) cannot be properly parsed or processed. 76 - #[error("error-oauth-client-10 Malformed PAR Response: {0:?}")] 77 - MalformedPARResponse(reqwest::Error), 78 - 79 - /// Error when token minting fails. 80 - /// 81 - /// This error occurs when the system fails to mint (create) a new 82 - /// OAuth token, typically due to cryptographic or validation issues. 83 - #[error("error-oauth-client-11 Token minting failed: {0:?}")] 84 - MintTokenFailed(anyhow::Error), 85 - 86 - /// Error when a token response is malformed. 87 - /// 88 - /// This error occurs when the response containing a token cannot 89 - /// be properly parsed or processed. 90 - #[error("error-oauth-client-12 Malformed Token Response: {0:?}")] 91 - MalformedTokenResponse(reqwest::Error), 92 - 93 - /// Error when a token middleware request fails. 94 - /// 95 - /// This error occurs when a token-related middleware request 96 - /// fails to complete successfully. 97 - #[error("error-oauth-client-13 Token Request Failed: {0:?}")] 98 - TokenMiddlewareRequestFailed(reqwest_middleware::Error), 99 - }
-116
src/oauth_errors.rs
··· 1 - use thiserror::Error; 2 - 3 - /// Represents errors that can occur during OAuth resource validation. 4 - /// 5 - /// These errors occur when validating the configuration of an OAuth resource server 6 - /// against the requirements of the AT Protocol. 7 - #[derive(Debug, Error)] 8 - pub enum ResourceValidationError { 9 - /// Error when the resource server URI doesn't match the PDS URI. 10 - /// 11 - /// This error occurs when the resource server URI in the OAuth configuration 12 - /// does not match the expected Personal Data Server (PDS) URI, which is required 13 - /// for proper AT Protocol OAuth integration. 14 - #[error("error-oauth-resource-1 Resource must match PDS")] 15 - ResourceMustMatchPds, 16 - 17 - /// Error when the authorization servers list is empty. 18 - /// 19 - /// This error occurs when the OAuth resource configuration doesn't specify 20 - /// any authorization servers, which is required for AT Protocol OAuth flows. 21 - #[error("error-oauth-resource-2 Authorization servers must not be empty")] 22 - AuthorizationServersMustNotBeEmpty, 23 - } 24 - 25 - /// Represents errors that can occur during OAuth authorization server validation. 26 - /// 27 - /// These errors occur when validating the configuration of an OAuth authorization server 28 - /// against the requirements specified by the AT Protocol. 29 - #[derive(Debug, Error)] 30 - pub enum AuthServerValidationError { 31 - /// Error when the authorization server issuer doesn't match the PDS. 32 - /// 33 - /// This error occurs when the issuer URI in the OAuth authorization server metadata 34 - /// does not match the expected Personal Data Server (PDS) URI. 35 - #[error("error-oauth-auth-server-1 Issuer must match PDS")] 36 - IssuerMustMatchPds, 37 - 38 - /// Error when the 'code' response type is not supported. 39 - /// 40 - /// This error occurs when the authorization server doesn't support the 'code' response type, 41 - /// which is required for the authorization code grant flow in AT Protocol. 42 - #[error("error-oauth-auth-server-2 Response types supported must include 'code'")] 43 - ResponseTypesSupportMustIncludeCode, 44 - 45 - /// Error when the 'authorization_code' grant type is not supported. 46 - /// 47 - /// This error occurs when the authorization server doesn't support the 'authorization_code' 48 - /// grant type, which is required for the AT Protocol OAuth flow. 49 - #[error("error-oauth-auth-server-3 Grant types supported must include 'authorization_code'")] 50 - GrantTypesSupportMustIncludeAuthorizationCode, 51 - 52 - /// Error when the 'refresh_token' grant type is not supported. 53 - /// 54 - /// This error occurs when the authorization server doesn't support the 'refresh_token' 55 - /// grant type, which is required for maintaining long-term access in AT Protocol. 56 - #[error("error-oauth-auth-server-4 Grant types supported must include 'refresh_token'")] 57 - GrantTypesSupportMustIncludeRefreshToken, 58 - 59 - /// Error when the 'S256' code challenge method is not supported. 60 - /// 61 - /// This error occurs when the authorization server doesn't support the 'S256' code 62 - /// challenge method for PKCE, which is required for secure authorization code flow. 63 - #[error("error-oauth-auth-server-5 Code challenge methods supported must include 'S256'")] 64 - CodeChallengeMethodsSupportedMustIncludeS256, 65 - 66 - /// Error when the 'none' token endpoint auth method is not supported. 67 - /// 68 - /// This error occurs when the authorization server doesn't support the 'none' 69 - /// token endpoint authentication method, which is used for public clients. 70 - #[error("error-oauth-auth-server-6 Token endpoint auth methods supported must include 'none'")] 71 - TokenEndpointAuthMethodsSupportedMustIncludeNone, 72 - 73 - /// Error when the 'private_key_jwt' token endpoint auth method is not supported. 74 - /// 75 - /// This error occurs when the authorization server doesn't support the 'private_key_jwt' 76 - /// token endpoint authentication method, which is required for AT Protocol clients. 77 - #[error("error-oauth-auth-server-7 Token endpoint auth methods supported must include 'private_key_jwt'")] 78 - TokenEndpointAuthMethodsSupportedMustIncludePrivateKeyJwt, 79 - 80 - /// Error when the 'ES256' signing algorithm is not supported for token endpoint auth. 81 - /// 82 - /// This error occurs when the authorization server doesn't support the 'ES256' signing 83 - /// algorithm for token endpoint authentication, which is required for AT Protocol. 84 - #[error("error-oauth-auth-server-8 Token endpoint auth signing algorithm values must include 'ES256'")] 85 - TokenEndpointAuthSigningAlgValuesMustIncludeES256, 86 - 87 - /// Error when the 'atproto' scope is not supported. 88 - /// 89 - /// This error occurs when the authorization server doesn't support the 'atproto' 90 - /// scope, which is required for accessing AT Protocol resources. 91 - #[error("error-oauth-auth-server-9 Scopes supported must include 'atproto'")] 92 - ScopesSupportedMustIncludeAtProto, 93 - 94 - /// Error when the 'transition:generic' scope is not supported. 95 - /// 96 - /// This error occurs when the authorization server doesn't support the 'transition:generic' 97 - /// scope, which is required for transitional functionality in AT Protocol. 98 - #[error("error-oauth-auth-server-10 Scopes supported must include 'transition:generic'")] 99 - ScopesSupportedMustIncludeTransitionGeneric, 100 - 101 - /// Error when the 'ES256' DPoP signing algorithm is not supported. 102 - /// 103 - /// This error occurs when the authorization server doesn't support the 'ES256' 104 - /// signing algorithm for DPoP proofs, which is required for AT Protocol security. 105 - #[error( 106 - "error-oauth-auth-server-11 DPoP signing algorithm values supported must include 'ES256'" 107 - )] 108 - DpopSigningAlgValuesSupportedMustIncludeES256, 109 - 110 - /// Error when required server features are not supported. 111 - /// 112 - /// This error occurs when the authorization server doesn't support required features 113 - /// such as pushed authorization requests, client ID metadata, or authorization response parameters. 114 - #[error("error-oauth-auth-server-12 Authorization response parameters, pushed requests, client ID metadata must be supported")] 115 - RequiredServerFeaturesMustBeSupported, 116 - }
+7
src/refresh_tokens_errors.rs
··· 26 26 /// used to manage session refresh operations. 27 27 #[error("error-refresh-3 Failed to place session group into refresh queue: {0:?}")] 28 28 PlaceInRefreshQueueFailed(deadpool_redis::redis::RedisError), 29 + 30 + /// Error when the identity document cannot be found. 31 + /// 32 + /// This error occurs when attempting to refresh a token but the necessary 33 + /// identity document for the user is not available in storage. 34 + #[error("error-refresh-4 Identity document not found")] 35 + IdentityDocumentNotFound, 29 36 }
-205
src/resolve.rs
··· 1 - use anyhow::Result; 2 - use errors::ResolveError; 3 - use futures_util::future::join3; 4 - use hickory_resolver::{ 5 - config::{NameServerConfigGroup, ResolverConfig, ResolverOpts}, 6 - TokioAsyncResolver, 7 - }; 8 - use std::collections::HashSet; 9 - use std::time::Duration; 10 - 11 - use crate::config::DnsNameservers; 12 - use crate::did::web::query_hostname; 13 - 14 - pub enum InputType { 15 - Handle(String), 16 - Plc(String), 17 - Web(String), 18 - } 19 - 20 - pub async fn resolve_handle_dns( 21 - dns_resolver: &TokioAsyncResolver, 22 - lookup_dns: &str, 23 - ) -> Result<String, ResolveError> { 24 - let lookup = dns_resolver 25 - .txt_lookup(&format!("_atproto.{}", lookup_dns)) 26 - .await 27 - .map_err(ResolveError::DNSResolutionFailed)?; 28 - 29 - let dids = lookup 30 - .iter() 31 - .filter_map(|record| { 32 - record 33 - .to_string() 34 - .strip_prefix("did=") 35 - .map(|did| did.to_string()) 36 - }) 37 - .collect::<HashSet<String>>(); 38 - 39 - if dids.len() > 1 { 40 - return Err(ResolveError::MultipleDIDsFound); 41 - } 42 - 43 - dids.iter().next().cloned().ok_or(ResolveError::NoDIDsFound) 44 - } 45 - 46 - pub async fn resolve_handle_http( 47 - http_client: &reqwest::Client, 48 - handle: &str, 49 - ) -> Result<String, ResolveError> { 50 - let lookup_url = format!("https://{}/.well-known/atproto-did", handle); 51 - 52 - http_client 53 - .get(lookup_url.clone()) 54 - .timeout(Duration::from_secs(10)) 55 - .send() 56 - .await 57 - .map_err(ResolveError::HTTPResolutionFailed)? 58 - .text() 59 - .await 60 - .map_err(ResolveError::HTTPResolutionFailed) 61 - .and_then(|body| { 62 - if body.starts_with("did:") { 63 - Ok(body.trim().to_string()) 64 - } else { 65 - Err(ResolveError::InvalidHTTPResolutionResponse) 66 - } 67 - }) 68 - } 69 - 70 - pub fn parse_input(input: &str) -> Result<InputType, ResolveError> { 71 - let trimmed = { 72 - if let Some(value) = input.trim().strip_prefix("at://") { 73 - value.trim() 74 - } else if let Some(value) = input.trim().strip_prefix('@') { 75 - value.trim() 76 - } else { 77 - input.trim() 78 - } 79 - }; 80 - if trimmed.is_empty() { 81 - return Err(ResolveError::InvalidInput); 82 - } 83 - if trimmed.starts_with("did:web:") { 84 - Ok(InputType::Web(trimmed.to_string())) 85 - } else if trimmed.starts_with("did:plc:") { 86 - Ok(InputType::Plc(trimmed.to_string())) 87 - } else { 88 - Ok(InputType::Handle(trimmed.to_string())) 89 - } 90 - } 91 - 92 - pub async fn resolve_handle( 93 - http_client: &reqwest::Client, 94 - dns_resolver: &TokioAsyncResolver, 95 - handle: &str, 96 - ) -> Result<String, ResolveError> { 97 - let trimmed = { 98 - if let Some(value) = handle.trim().strip_prefix("at://") { 99 - value 100 - } else if let Some(value) = handle.trim().strip_prefix('@') { 101 - value 102 - } else { 103 - handle.trim() 104 - } 105 - }; 106 - 107 - let (dns_lookup, http_lookup, did_web_lookup) = join3( 108 - resolve_handle_dns(dns_resolver, trimmed), 109 - resolve_handle_http(http_client, trimmed), 110 - query_hostname(http_client, trimmed), 111 - ) 112 - .await; 113 - 114 - tracing::debug!( 115 - ?handle, 116 - ?dns_lookup, 117 - ?http_lookup, 118 - ?did_web_lookup, 119 - "raw query results" 120 - ); 121 - 122 - let did_web_lookup_did = did_web_lookup 123 - .map(|document| document.id) 124 - .map_err(ResolveError::DIDWebResolutionFailed); 125 - 126 - let results = vec![dns_lookup, http_lookup, did_web_lookup_did] 127 - .into_iter() 128 - .filter_map(|result| result.ok()) 129 - .collect::<Vec<String>>(); 130 - if results.is_empty() { 131 - return Err(ResolveError::NoDIDsFound); 132 - } 133 - 134 - tracing::debug!(?handle, ?results, "query results"); 135 - 136 - let first = results[0].clone(); 137 - if results.iter().all(|result| result == &first) { 138 - return Ok(first); 139 - } 140 - Err(ResolveError::ConflictingDIDsFound) 141 - } 142 - 143 - pub async fn resolve_subject( 144 - http_client: &reqwest::Client, 145 - dns_resolver: &TokioAsyncResolver, 146 - subject: &str, 147 - ) -> Result<String, ResolveError> { 148 - match parse_input(subject)? { 149 - InputType::Handle(handle) => resolve_handle(http_client, dns_resolver, &handle).await, 150 - InputType::Plc(did) | InputType::Web(did) => Ok(did), 151 - } 152 - } 153 - 154 - /// Creates a new DNS resolver with configuration based on app config. 155 - /// 156 - /// If custom nameservers are configured in app config, they will be used. 157 - /// Otherwise, the system default resolver configuration will be used. 158 - pub fn create_resolver(nameservers: DnsNameservers) -> TokioAsyncResolver { 159 - // Initialize the DNS resolver with custom nameservers if configured 160 - let nameservers = nameservers.as_ref(); 161 - let resolver_config = if !nameservers.is_empty() { 162 - // Use custom nameservers 163 - tracing::info!("Using custom DNS nameservers: {:?}", nameservers); 164 - let nameserver_group = NameServerConfigGroup::from_ips_clear(nameservers, 53, true); 165 - ResolverConfig::from_parts(None, vec![], nameserver_group) 166 - } else { 167 - // Use system default 168 - tracing::info!("Using system default DNS nameservers"); 169 - ResolverConfig::default() 170 - }; 171 - 172 - // TokioAsyncResolver::tokio returns an AsyncResolver directly, not a Result 173 - TokioAsyncResolver::tokio(resolver_config, ResolverOpts::default()) 174 - } 175 - 176 - pub mod errors { 177 - use thiserror::Error; 178 - 179 - #[derive(Debug, Error)] 180 - pub enum ResolveError { 181 - #[error("error-resolve-1 Multiple DIDs resolved for method")] 182 - MultipleDIDsFound, 183 - 184 - #[error("error-resolve-2 No DIDs resolved for method")] 185 - NoDIDsFound, 186 - 187 - #[error("error-resolve-3 No DIDs resolved for method")] 188 - ConflictingDIDsFound, 189 - 190 - #[error("error-resolve-4 DNS resolution failed: {0:?}")] 191 - DNSResolutionFailed(hickory_resolver::error::ResolveError), 192 - 193 - #[error("error-resolve-5 HTTP resolution failed: {0:?}")] 194 - HTTPResolutionFailed(reqwest::Error), 195 - 196 - #[error("error-resolve-6 HTTP resolution failed")] 197 - InvalidHTTPResolutionResponse, 198 - 199 - #[error("error-resolve-7 HTTP resolution failed: {0:?}")] 200 - DIDWebResolutionFailed(anyhow::Error), 201 - 202 - #[error("error-resolve-8 Invalid input")] 203 - InvalidInput, 204 - } 205 - }
+246
src/storage/atproto.rs
··· 1 + use async_trait::async_trait; 2 + use atproto_identity::{model::Document, storage::DidDocumentStorage}; 3 + use atproto_oauth::{storage::OAuthRequestStorage, workflow::OAuthRequest}; 4 + use chrono::{DateTime, Utc}; 5 + use serde_json::Value as JsonValue; 6 + use sqlx::FromRow; 7 + use std::sync::Arc; 8 + 9 + use crate::storage::{errors::StorageError, StoragePool}; 10 + 11 + /// Database row representation of OAuthRequest 12 + #[derive(FromRow)] 13 + struct OAuthRequestRow { 14 + pub oauth_state: String, 15 + pub issuer: String, 16 + pub did: String, 17 + pub nonce: String, 18 + pub pkce_verifier: String, 19 + pub signing_public_key: String, 20 + pub dpop_private_key: String, 21 + pub created_at: DateTime<Utc>, 22 + pub expires_at: DateTime<Utc>, 23 + } 24 + 25 + /// Postgres implementation of DidDocumentStorage trait 26 + pub struct PostgresDidDocumentStorage { 27 + pool: StoragePool, 28 + } 29 + 30 + impl PostgresDidDocumentStorage { 31 + pub fn new(pool: StoragePool) -> Self { 32 + Self { pool } 33 + } 34 + 35 + pub fn new_arc(pool: StoragePool) -> Arc<dyn DidDocumentStorage> { 36 + Arc::new(Self::new(pool)) 37 + } 38 + } 39 + 40 + #[async_trait] 41 + impl DidDocumentStorage for PostgresDidDocumentStorage { 42 + async fn get_document_by_did(&self, did: &str) -> Result<Option<Document>, anyhow::Error> { 43 + if did.trim().is_empty() { 44 + return Err(anyhow::anyhow!("DID cannot be empty")); 45 + } 46 + 47 + let mut tx = self.pool.begin().await?; 48 + 49 + let result = sqlx::query_scalar::<_, JsonValue>( 50 + "SELECT document_json FROM did_documents WHERE did = $1 AND (expires_at IS NULL OR expires_at > NOW())" 51 + ) 52 + .bind(did) 53 + .fetch_optional(tx.as_mut()) 54 + .await?; 55 + 56 + if let Some(json_value) = result { 57 + // Update the updated_at timestamp for LRU behavior 58 + sqlx::query("UPDATE did_documents SET updated_at = NOW() WHERE did = $1") 59 + .bind(did) 60 + .execute(tx.as_mut()) 61 + .await?; 62 + 63 + tx.commit().await?; 64 + 65 + // Convert JSON back to Document 66 + let document: Document = serde_json::from_value(json_value)?; 67 + Ok(Some(document)) 68 + } else { 69 + tx.commit().await?; 70 + Ok(None) 71 + } 72 + } 73 + 74 + async fn store_document(&self, document: Document) -> Result<(), anyhow::Error> { 75 + let did = document.id.clone(); 76 + let document_json = serde_json::to_value(&document)?; 77 + 78 + let mut tx = self.pool.begin().await?; 79 + 80 + sqlx::query( 81 + "INSERT INTO did_documents (did, document_json) VALUES ($1, $2) 82 + ON CONFLICT (did) DO UPDATE SET 83 + document_json = EXCLUDED.document_json, 84 + updated_at = NOW()", 85 + ) 86 + .bind(&did) 87 + .bind(&document_json) 88 + .execute(tx.as_mut()) 89 + .await?; 90 + 91 + tx.commit().await?; 92 + Ok(()) 93 + } 94 + 95 + async fn delete_document_by_did(&self, did: &str) -> Result<(), anyhow::Error> { 96 + if did.trim().is_empty() { 97 + return Err(anyhow::anyhow!("DID cannot be empty")); 98 + } 99 + 100 + let mut tx = self.pool.begin().await?; 101 + 102 + sqlx::query("DELETE FROM did_documents WHERE did = $1") 103 + .bind(did) 104 + .execute(tx.as_mut()) 105 + .await?; 106 + 107 + tx.commit().await?; 108 + Ok(()) 109 + } 110 + } 111 + 112 + /// Postgres implementation of OAuthRequestStorage trait 113 + pub struct PostgresOAuthRequestStorage { 114 + pool: StoragePool, 115 + } 116 + 117 + impl PostgresOAuthRequestStorage { 118 + pub fn new(pool: StoragePool) -> Self { 119 + Self { pool } 120 + } 121 + 122 + pub fn new_arc(pool: StoragePool) -> Arc<dyn OAuthRequestStorage> { 123 + Arc::new(Self::new(pool)) 124 + } 125 + } 126 + 127 + #[async_trait] 128 + impl OAuthRequestStorage for PostgresOAuthRequestStorage { 129 + async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<(), anyhow::Error> { 130 + let mut tx = self.pool.begin().await?; 131 + 132 + sqlx::query( 133 + "INSERT INTO atproto_oauth_requests ( 134 + oauth_state, issuer, did, nonce, pkce_verifier, 135 + signing_public_key, dpop_private_key, created_at, expires_at 136 + ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)", 137 + ) 138 + .bind(&request.oauth_state) 139 + .bind(&request.issuer) 140 + .bind(&request.did) 141 + .bind(&request.nonce) 142 + .bind(&request.pkce_verifier) 143 + .bind(&request.signing_public_key) 144 + .bind(&request.dpop_private_key) 145 + .bind(request.created_at) 146 + .bind(request.expires_at) 147 + .execute(tx.as_mut()) 148 + .await?; 149 + 150 + tx.commit().await?; 151 + Ok(()) 152 + } 153 + 154 + async fn get_oauth_request_by_state( 155 + &self, 156 + state: &str, 157 + ) -> Result<Option<OAuthRequest>, anyhow::Error> { 158 + if state.trim().is_empty() { 159 + return Err(anyhow::anyhow!("OAuth state cannot be empty")); 160 + } 161 + 162 + let mut tx = self.pool.begin().await?; 163 + 164 + let result = sqlx::query_as::<_, OAuthRequestRow>( 165 + "SELECT oauth_state, issuer, did, nonce, pkce_verifier, 166 + signing_public_key, dpop_private_key, created_at, expires_at 167 + FROM atproto_oauth_requests 168 + WHERE oauth_state = $1 AND expires_at > NOW()", 169 + ) 170 + .bind(state) 171 + .fetch_optional(tx.as_mut()) 172 + .await?; 173 + 174 + tx.commit().await?; 175 + 176 + if let Some(row) = result { 177 + let oauth_request = OAuthRequest { 178 + oauth_state: row.oauth_state, 179 + issuer: row.issuer, 180 + did: row.did, 181 + nonce: row.nonce, 182 + pkce_verifier: row.pkce_verifier, 183 + signing_public_key: row.signing_public_key, 184 + dpop_private_key: row.dpop_private_key, 185 + created_at: row.created_at, 186 + expires_at: row.expires_at, 187 + }; 188 + Ok(Some(oauth_request)) 189 + } else { 190 + Ok(None) 191 + } 192 + } 193 + 194 + async fn delete_oauth_request_by_state(&self, state: &str) -> Result<(), anyhow::Error> { 195 + if state.trim().is_empty() { 196 + return Err(anyhow::anyhow!("OAuth state cannot be empty")); 197 + } 198 + 199 + let mut tx = self.pool.begin().await?; 200 + 201 + sqlx::query("DELETE FROM atproto_oauth_requests WHERE oauth_state = $1") 202 + .bind(state) 203 + .execute(tx.as_mut()) 204 + .await?; 205 + 206 + tx.commit().await?; 207 + Ok(()) 208 + } 209 + 210 + async fn clear_expired_oauth_requests(&self) -> Result<u64, anyhow::Error> { 211 + let mut tx = self.pool.begin().await?; 212 + 213 + let result = sqlx::query("DELETE FROM atproto_oauth_requests WHERE expires_at <= NOW()") 214 + .execute(tx.as_mut()) 215 + .await?; 216 + 217 + tx.commit().await?; 218 + Ok(result.rows_affected()) 219 + } 220 + } 221 + 222 + /// Cleanup expired DID documents and OAuth requests 223 + pub async fn cleanup_expired_records(pool: &StoragePool) -> Result<(), StorageError> { 224 + let mut tx = pool 225 + .begin() 226 + .await 227 + .map_err(StorageError::CannotBeginDatabaseTransaction)?; 228 + 229 + // Clean up expired DID documents 230 + sqlx::query("DELETE FROM did_documents WHERE expires_at IS NOT NULL AND expires_at <= NOW()") 231 + .execute(tx.as_mut()) 232 + .await 233 + .map_err(StorageError::UnableToExecuteQuery)?; 234 + 235 + // Clean up expired OAuth requests 236 + sqlx::query("DELETE FROM atproto_oauth_requests WHERE expires_at <= NOW()") 237 + .execute(tx.as_mut()) 238 + .await 239 + .map_err(StorageError::UnableToExecuteQuery)?; 240 + 241 + tx.commit() 242 + .await 243 + .map_err(StorageError::CannotCommitDatabaseTransaction)?; 244 + 245 + Ok(()) 246 + }
+8 -36
src/storage/denylist.rs
··· 8 8 9 9 use crate::storage::{errors::StorageError, StoragePool}; 10 10 11 - pub mod model { 11 + pub(crate) mod model { 12 12 use chrono::{DateTime, Utc}; 13 13 use serde::{Deserialize, Serialize}; 14 14 use sqlx::FromRow; ··· 22 22 } 23 23 24 24 // Add a new entry to the denylist or update an existing one 25 - pub async fn denylist_add_or_update( 25 + pub(crate) async fn denylist_add_or_update( 26 26 pool: &StoragePool, 27 27 subject: Cow<'_, str>, 28 28 reason: Cow<'_, str>, ··· 68 68 } 69 69 70 70 // Remove an entry from the denylist 71 - pub async fn denylist_remove(pool: &StoragePool, subject: &str) -> Result<(), StorageError> { 71 + pub(crate) async fn denylist_remove(pool: &StoragePool, subject: &str) -> Result<(), StorageError> { 72 72 // Validate subject before proceeding 73 73 if subject.trim().is_empty() { 74 74 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( ··· 98 98 Ok(()) 99 99 } 100 100 101 - // Check if a subject is in the denylist 102 - pub async fn denylist_check(pool: &StoragePool, subject: &str) -> Result<bool, StorageError> { 103 - // Validate subject before proceeding 104 - if subject.trim().is_empty() { 105 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 106 - "Subject cannot be empty".into(), 107 - ))); 108 - } 109 - 110 - let mut tx = pool 111 - .begin() 112 - .await 113 - .map_err(StorageError::CannotBeginDatabaseTransaction)?; 114 - 115 - let mut h = MetroHash64::default(); 116 - h.write(subject.as_bytes()); 117 - let subject = crockford::encode(h.finish()); 118 - 119 - let count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM denylist WHERE subject = $1") 120 - .bind(subject) 121 - .fetch_one(tx.as_mut()) 122 - .await 123 - .map_err(StorageError::UnableToExecuteQuery)?; 124 - 125 - tx.commit() 126 - .await 127 - .map_err(StorageError::CannotCommitDatabaseTransaction)?; 128 - 129 - Ok(count > 0) 130 - } 131 - 132 101 // Get a list of denylist entries with pagination 133 - pub async fn denylist_list( 102 + pub(crate) async fn denylist_list( 134 103 pool: &StoragePool, 135 104 page: i64, 136 105 page_size: i64, ··· 163 132 Ok((count, entries)) 164 133 } 165 134 166 - pub async fn denylist_exists(pool: &StoragePool, subjects: &[&str]) -> Result<bool, StorageError> { 135 + pub(crate) async fn denylist_exists( 136 + pool: &StoragePool, 137 + subjects: &[&str], 138 + ) -> Result<bool, StorageError> { 167 139 // Validate input - empty array should return false, not error 168 140 if subjects.is_empty() { 169 141 return Ok(false);
+7
src/storage/errors.rs
··· 22 22 #[error("error-oauth-model-2 Invalid OAuth flow state")] 23 23 InvalidOAuthFlowState(), 24 24 25 + /// Error when deserializing DPoP JWK from string fails. 26 + /// 27 + /// This error occurs when attempting to deserialize a string-encoded 28 + /// JSON Web Key (JWK) for DPoP operations, typically due to invalid JSON format. 29 + #[error("error-oauth-model-5 Failed to deserialize DPoP JWK: {0:?}")] 30 + DpopJwkDeserializationFailed(serde_json::Error), 31 + 25 32 /// Error when required OAuth session data is missing. 26 33 /// 27 34 /// This error occurs when attempting to use an OAuth session
+38 -9
src/storage/event.rs
··· 21 21 use sqlx::FromRow; 22 22 23 23 #[derive(Clone, FromRow, Deserialize, Serialize, Debug)] 24 - pub struct Event { 24 + pub(crate) struct Event { 25 25 pub aturi: String, 26 26 pub cid: String, 27 27 ··· 36 36 } 37 37 38 38 #[derive(Clone, FromRow, Debug, Serialize)] 39 - pub struct EventWithRole { 39 + pub(crate) struct EventWithRole { 40 40 #[sqlx(flatten)] 41 - pub event: Event, 41 + pub(crate) event: Event, 42 42 43 - pub role: String, 43 + pub(crate) role: String, 44 44 // pub event_handle: String, 45 45 } 46 46 ··· 250 250 } 251 251 } 252 252 253 - pub fn extract_event_details(event: &Event) -> EventDetails { 253 + pub(crate) fn extract_event_details(event: &Event) -> EventDetails { 254 254 use crate::atproto::lexicon::{ 255 255 community::lexicon::calendar::event::{Event as CommunityEvent, Mode, Status}, 256 256 events::smokesignal::calendar::event::Event as SmokeSignalEvent, ··· 469 469 pub uris: Vec<crate::atproto::lexicon::community::lexicon::calendar::event::EventLink>, 470 470 } 471 471 472 - pub async fn event_get(pool: &StoragePool, aturi: &str) -> Result<Event, StorageError> { 472 + pub(crate) async fn event_get(pool: &StoragePool, aturi: &str) -> Result<Event, StorageError> { 473 473 // Validate aturi is not empty 474 474 if aturi.trim().is_empty() { 475 475 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( ··· 553 553 Ok(record) 554 554 } 555 555 556 - pub async fn event_list_did_recently_updated( 556 + pub(crate) async fn event_list_did_recently_updated( 557 557 pool: &StoragePool, 558 558 did: &str, 559 559 page: i64, ··· 611 611 Ok(event_roles) 612 612 } 613 613 614 - pub async fn event_list_recently_updated( 614 + pub(crate) async fn event_list_recently_updated( 615 615 pool: &StoragePool, 616 616 page: i64, 617 617 page_size: i64, ··· 954 954 ))) 955 955 } 956 956 957 - pub async fn event_list( 957 + pub(crate) async fn event_list( 958 958 pool: &StoragePool, 959 959 page: i64, 960 960 page_size: i64, ··· 993 993 994 994 Ok((total_count, events)) 995 995 } 996 + 997 + pub async fn event_delete(pool: &StoragePool, aturi: &str) -> Result<(), StorageError> { 998 + let mut tx = pool 999 + .begin() 1000 + .await 1001 + .map_err(StorageError::CannotBeginDatabaseTransaction)?; 1002 + 1003 + // Delete only the event record (RSVPs are preserved) 1004 + let result = sqlx::query("DELETE FROM events WHERE aturi = $1") 1005 + .bind(aturi) 1006 + .execute(tx.as_mut()) 1007 + .await 1008 + .map_err(StorageError::UnableToExecuteQuery)?; 1009 + 1010 + if result.rows_affected() == 0 { 1011 + // Rollback the transaction - we don't need to map the error 1012 + let _ = tx.rollback().await; 1013 + return Err(StorageError::RowNotFound( 1014 + "Event not found".to_string(), 1015 + sqlx::Error::RowNotFound, 1016 + )); 1017 + } 1018 + 1019 + tx.commit() 1020 + .await 1021 + .map_err(StorageError::CannotCommitDatabaseTransaction)?; 1022 + 1023 + Ok(()) 1024 + }
+63 -80
src/storage/handle.rs src/storage/identity_profile.rs
··· 7 7 use crate::storage::denylist::denylist_add_or_update; 8 8 use crate::storage::errors::StorageError; 9 9 use crate::storage::StoragePool; 10 - use model::Handle; 10 + use model::IdentityProfile; 11 11 12 12 pub mod model { 13 13 use chrono::{DateTime, Utc}; ··· 15 15 use sqlx::FromRow; 16 16 17 17 #[derive(Clone, FromRow, Deserialize, Serialize, Debug)] 18 - pub struct Handle { 18 + pub struct IdentityProfile { 19 19 pub did: String, 20 20 pub handle: String, 21 21 pub pds: String, ··· 60 60 .map_err(StorageError::CannotBeginDatabaseTransaction)?; 61 61 62 62 let now = Utc::now(); 63 - let insert_result = sqlx::query("INSERT INTO handles (did, handle, pds, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING") 63 + let insert_result = sqlx::query("INSERT INTO identity_profiles (did, handle, pds, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING") 64 64 .bind(did) 65 65 .bind(handle) 66 66 .bind(pds) ··· 71 71 .map_err(StorageError::UnableToExecuteQuery)?; 72 72 73 73 if insert_result.rows_affected() == 0 { 74 - sqlx::query("UPDATE handles SET updated_at = $1, handle = $2, pds = $3 WHERE did = $4") 75 - .bind(now) 76 - .bind(handle) 77 - .bind(pds) 78 - .bind(did) 79 - .execute(tx.as_mut()) 80 - .await 81 - .map_err(StorageError::UnableToExecuteQuery)?; 74 + sqlx::query( 75 + "UPDATE identity_profiles SET updated_at = $1, handle = $2, pds = $3 WHERE did = $4", 76 + ) 77 + .bind(now) 78 + .bind(handle) 79 + .bind(pds) 80 + .bind(did) 81 + .execute(tx.as_mut()) 82 + .await 83 + .map_err(StorageError::UnableToExecuteQuery)?; 82 84 } 83 85 84 86 tx.commit() ··· 106 108 107 109 let query = match &field { 108 110 HandleField::Language(_) => { 109 - "UPDATE handles SET language = $1, updated_at = $2 WHERE did = $3" 111 + "UPDATE identity_profiles SET language = $1, updated_at = $2 WHERE did = $3" 110 112 } 111 - HandleField::Timezone(_) => "UPDATE handles SET tz = $1, updated_at = $2 WHERE did = $3", 113 + HandleField::Timezone(_) => { 114 + "UPDATE identity_profiles SET tz = $1, updated_at = $2 WHERE did = $3" 115 + } 112 116 HandleField::ActiveNow => { 113 - "UPDATE handles SET active_at = $1, updated_at = $2 WHERE did = $3" 117 + "UPDATE identity_profiles SET active_at = $1, updated_at = $2 WHERE did = $3" 114 118 } 115 119 }; 116 120 ··· 140 144 .map_err(StorageError::CannotCommitDatabaseTransaction) 141 145 } 142 146 143 - pub async fn handle_for_did(pool: &StoragePool, did: &str) -> Result<Handle, StorageError> { 147 + pub async fn handle_for_did( 148 + pool: &StoragePool, 149 + did: &str, 150 + ) -> Result<IdentityProfile, StorageError> { 144 151 // Validate DID is not empty 145 152 if did.trim().is_empty() { 146 153 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( ··· 153 160 .await 154 161 .map_err(StorageError::CannotBeginDatabaseTransaction)?; 155 162 156 - let entity = sqlx::query_as::<_, Handle>("SELECT * FROM handles WHERE did = $1") 157 - .bind(did) 158 - .fetch_one(tx.as_mut()) 159 - .await 160 - .map_err(|err| match err { 161 - sqlx::Error::RowNotFound => StorageError::HandleNotFound, 162 - other => StorageError::UnableToExecuteQuery(other), 163 - })?; 163 + let entity = 164 + sqlx::query_as::<_, IdentityProfile>("SELECT * FROM identity_profiles WHERE did = $1") 165 + .bind(did) 166 + .fetch_one(tx.as_mut()) 167 + .await 168 + .map_err(|err| match err { 169 + sqlx::Error::RowNotFound => StorageError::HandleNotFound, 170 + other => StorageError::UnableToExecuteQuery(other), 171 + })?; 164 172 165 173 tx.commit() 166 174 .await ··· 169 177 Ok(entity) 170 178 } 171 179 172 - pub async fn handle_for_handle(pool: &StoragePool, handle: &str) -> Result<Handle, StorageError> { 180 + pub async fn handle_for_handle( 181 + pool: &StoragePool, 182 + handle: &str, 183 + ) -> Result<IdentityProfile, StorageError> { 173 184 // Validate handle is not empty 174 185 if handle.trim().is_empty() { 175 186 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( ··· 182 193 .await 183 194 .map_err(StorageError::CannotBeginDatabaseTransaction)?; 184 195 185 - let entity = sqlx::query_as::<_, Handle>("SELECT * FROM handles WHERE handle = $1") 186 - .bind(handle) 187 - .fetch_one(tx.as_mut()) 188 - .await 189 - .map_err(|err| match err { 190 - sqlx::Error::RowNotFound => StorageError::HandleNotFound, 191 - other => StorageError::UnableToExecuteQuery(other), 192 - })?; 196 + let entity = 197 + sqlx::query_as::<_, IdentityProfile>("SELECT * FROM identity_profiles WHERE handle = $1") 198 + .bind(handle) 199 + .fetch_one(tx.as_mut()) 200 + .await 201 + .map_err(|err| match err { 202 + sqlx::Error::RowNotFound => StorageError::HandleNotFound, 203 + other => StorageError::UnableToExecuteQuery(other), 204 + })?; 193 205 194 206 tx.commit() 195 207 .await ··· 202 214 pool: &StoragePool, 203 215 page: i64, 204 216 page_size: i64, 205 - ) -> Result<(i64, Vec<Handle>), StorageError> { 217 + ) -> Result<(i64, Vec<IdentityProfile>), StorageError> { 206 218 let mut tx = pool 207 219 .begin() 208 220 .await 209 221 .map_err(StorageError::CannotBeginDatabaseTransaction)?; 210 222 211 - let total_count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM handles") 223 + let total_count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM identity_profiles") 212 224 .fetch_one(tx.as_mut()) 213 225 .await 214 226 .map_err(StorageError::UnableToExecuteQuery)?; 215 227 216 228 let offset = (page - 1) * page_size; 217 229 218 - let handles = sqlx::query_as::<_, Handle>( 219 - "SELECT * FROM handles ORDER BY updated_at DESC LIMIT $1 OFFSET $2", 230 + let handles = sqlx::query_as::<_, IdentityProfile>( 231 + "SELECT * FROM identity_profiles ORDER BY updated_at DESC LIMIT $1 OFFSET $2", 220 232 ) 221 233 .bind(page_size + 1) // Fetch one more to know if there are more entries 222 234 .bind(offset) ··· 256 268 .map_err(StorageError::CannotBeginDatabaseTransaction)?; 257 269 258 270 // Get handle information first 259 - let handle = sqlx::query_as::<_, Handle>("SELECT * FROM handles WHERE did = $1") 260 - .bind(did) 261 - .fetch_one(tx.as_mut()) 262 - .await 263 - .map_err(|err| match err { 264 - sqlx::Error::RowNotFound => StorageError::HandleNotFound, 265 - other => StorageError::UnableToExecuteQuery(other), 266 - })?; 271 + let handle = 272 + sqlx::query_as::<_, IdentityProfile>("SELECT * FROM identity_profiles WHERE did = $1") 273 + .bind(did) 274 + .fetch_one(tx.as_mut()) 275 + .await 276 + .map_err(|err| match err { 277 + sqlx::Error::RowNotFound => StorageError::HandleNotFound, 278 + other => StorageError::UnableToExecuteQuery(other), 279 + })?; 267 280 268 281 // Delete RSVPs created by this identity 269 282 sqlx::query("DELETE FROM rsvps WHERE did = $1") ··· 280 293 .map_err(StorageError::UnableToExecuteQuery)?; 281 294 282 295 // Delete the handle entry 283 - sqlx::query("DELETE FROM handles WHERE did = $1") 296 + sqlx::query("DELETE FROM identity_profiles WHERE did = $1") 284 297 .bind(did) 285 298 .execute(tx.as_mut()) 286 299 .await ··· 322 335 pub async fn handles_by_did( 323 336 pool: &StoragePool, 324 337 dids: Vec<String>, 325 - ) -> Result<HashMap<std::string::String, Handle>, StorageError> { 338 + ) -> Result<HashMap<std::string::String, IdentityProfile>, StorageError> { 326 339 if dids.is_empty() { 327 340 return Ok(HashMap::default()); 328 341 } ··· 343 356 344 357 // Build the query with placeholders 345 358 let mut query_builder: QueryBuilder<Postgres> = 346 - QueryBuilder::new("SELECT * FROM handles WHERE did IN ("); 359 + QueryBuilder::new("SELECT * FROM identity_profiles WHERE did IN ("); 347 360 let mut separated = query_builder.separated(", "); 348 361 for did in &dids { 349 362 separated.push_bind(did); ··· 351 364 separated.push_unseparated(") "); 352 365 353 366 // The query_builder.build() already includes the bindings, so we don't need to bind again 354 - let query = query_builder.build_query_as::<Handle>(); 367 + let query = query_builder.build_query_as::<IdentityProfile>(); 355 368 let values = query 356 369 .fetch_all(tx.as_mut()) 357 370 .await ··· 372 385 pub mod test { 373 386 use sqlx::PgPool; 374 387 375 - use crate::storage::handle::handle_for_did; 376 - use crate::storage::handle::handle_for_handle; 377 - use crate::storage::handle::handle_warm_up; 388 + use crate::storage::identity_profile::handle_for_did; 389 + use crate::storage::identity_profile::handle_for_handle; 378 390 379 391 #[sqlx::test(fixtures(path = "../../fixtures/storage", scripts("handles")))] 380 392 async fn test_handle_for_did(pool: PgPool) -> sqlx::Result<()> { ··· 394 406 assert!(!handle.is_err()); 395 407 let handle = handle.unwrap(); 396 408 assert_eq!(handle.did, "did:plc:d5c1ed6d01421a67b96f68fa"); 397 - 398 - Ok(()) 399 - } 400 - 401 - #[sqlx::test(fixtures(path = "../../fixtures/storage", scripts("handles")))] 402 - async fn test_handle_warm_up(pool: PgPool) -> sqlx::Result<()> { 403 - let did = "did:plc:f263c822655b579fc8a79635"; 404 - let handle = "inspiring-bobwhite.examplepds.com"; 405 - let updated_handle = "charming-needlefish.examplepds.com"; 406 - let pds = "https://pds.examplepds.com"; 407 - 408 - let warmup_result = handle_warm_up(&pool, did, handle, pds).await; 409 - assert!(!warmup_result.is_err()); 410 - 411 - { 412 - let handle = handle_for_handle(&pool, handle).await; 413 - assert!(!handle.is_err()); 414 - let handle = handle.unwrap(); 415 - assert_eq!(handle.did, did); 416 - } 417 - 418 - { 419 - let warmup_result = handle_warm_up(&pool, did, updated_handle, pds).await; 420 - assert!(!warmup_result.is_err()); 421 - } 422 - { 423 - let handle = handle_for_handle(&pool, handle).await; 424 - assert!(handle.is_err()); 425 - } 426 409 427 410 Ok(()) 428 411 }
+2 -1
src/storage/mod.rs
··· 1 + pub mod atproto; 1 2 pub mod cache; 2 3 pub mod denylist; 3 4 pub mod errors; 4 5 pub mod event; 5 - pub mod handle; 6 + pub mod identity_profile; 6 7 pub mod oauth; 7 8 pub mod types; 8 9
+14 -358
src/storage/oauth.rs
··· 1 1 use std::borrow::Cow; 2 2 3 3 use chrono::{DateTime, Utc}; 4 - use serde_json::json; 5 - 6 - use crate::{ 7 - jose::jwk::WrappedJsonWebKey, 8 - storage::{errors::StorageError, handle::model::Handle, StoragePool}, 9 - }; 10 - use model::{OAuthRequest, OAuthSession}; 11 - 12 - pub struct OAuthRequestParams { 13 - pub oauth_state: Cow<'static, str>, 14 - pub issuer: Cow<'static, str>, 15 - pub did: Cow<'static, str>, 16 - pub nonce: Cow<'static, str>, 17 - pub pkce_verifier: Cow<'static, str>, 18 - pub secret_jwk_id: Cow<'static, str>, 19 - pub dpop_jwk: Option<WrappedJsonWebKey>, 20 - pub destination: Option<Cow<'static, str>>, 21 - pub created_at: DateTime<Utc>, 22 - pub expires_at: DateTime<Utc>, 23 - } 24 - 25 - pub async fn oauth_request_insert( 26 - pool: &StoragePool, 27 - params: OAuthRequestParams, 28 - ) -> Result<(), StorageError> { 29 - // Validate required input parameters 30 - if params.oauth_state.trim().is_empty() { 31 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 32 - "OAuth state cannot be empty".into(), 33 - ))); 34 - } 35 - 36 - if params.issuer.trim().is_empty() { 37 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 38 - "Issuer cannot be empty".into(), 39 - ))); 40 - } 41 - 42 - if params.did.trim().is_empty() { 43 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 44 - "DID cannot be empty".into(), 45 - ))); 46 - } 47 - 48 - if params.nonce.trim().is_empty() { 49 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 50 - "Nonce cannot be empty".into(), 51 - ))); 52 - } 53 - 54 - if params.pkce_verifier.trim().is_empty() { 55 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 56 - "PKCE verifier cannot be empty".into(), 57 - ))); 58 - } 59 - 60 - if params.secret_jwk_id.trim().is_empty() { 61 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 62 - "Secret JWK ID cannot be empty".into(), 63 - ))); 64 - } 65 - 66 - let mut tx = pool 67 - .begin() 68 - .await 69 - .map_err(StorageError::CannotBeginDatabaseTransaction)?; 70 - 71 - let dpop_jwk_value = params 72 - .dpop_jwk 73 - .map(|jwk| json!(jwk)) 74 - .unwrap_or_else(|| json!({})); 75 - 76 - sqlx::query("INSERT INTO oauth_requests (oauth_state, issuer, did, nonce, pkce_verifier, secret_jwk_id, dpop_jwk, destination, created_at, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)") 77 - .bind(&params.oauth_state) 78 - .bind(&params.issuer) 79 - .bind(&params.did) 80 - .bind(&params.nonce) 81 - .bind(&params.pkce_verifier) 82 - .bind(&params.secret_jwk_id) 83 - .bind(dpop_jwk_value) 84 - .bind(params.destination) 85 - .bind(params.created_at) 86 - .bind(params.expires_at) 87 - .execute(tx.as_mut()) 88 - .await 89 - .map_err(StorageError::UnableToExecuteQuery)?; 90 - 91 - tx.commit() 92 - .await 93 - .map_err(StorageError::CannotCommitDatabaseTransaction) 94 - } 95 - 96 - pub async fn oauth_request_get( 97 - pool: &StoragePool, 98 - oauth_state: &str, 99 - ) -> Result<OAuthRequest, StorageError> { 100 - // Validate oauth_state is not empty 101 - if oauth_state.trim().is_empty() { 102 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 103 - "OAuth state cannot be empty".into(), 104 - ))); 105 - } 106 - 107 - let mut tx = pool 108 - .begin() 109 - .await 110 - .map_err(StorageError::CannotBeginDatabaseTransaction)?; 111 - 112 - let record = 113 - sqlx::query_as::<_, OAuthRequest>("SELECT * FROM oauth_requests WHERE oauth_state = $1") 114 - .bind(oauth_state) 115 - .fetch_one(tx.as_mut()) 116 - .await 117 - .map_err(|err| match err { 118 - sqlx::Error::RowNotFound => StorageError::OAuthRequestNotFound, 119 - other => StorageError::UnableToExecuteQuery(other), 120 - })?; 121 - 122 - tx.commit() 123 - .await 124 - .map_err(StorageError::CannotCommitDatabaseTransaction)?; 125 - 126 - Ok(record) 127 - } 128 - 129 - pub async fn oauth_request_remove( 130 - pool: &StoragePool, 131 - oauth_state: &str, 132 - ) -> Result<(), StorageError> { 133 - // Validate oauth_state is not empty 134 - if oauth_state.trim().is_empty() { 135 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 136 - "OAuth state cannot be empty".into(), 137 - ))); 138 - } 139 - 140 - let mut tx = pool 141 - .begin() 142 - .await 143 - .map_err(StorageError::CannotBeginDatabaseTransaction)?; 144 - 145 - sqlx::query("DELETE FROM oauth_requests WHERE oauth_state = $1") 146 - .bind(oauth_state) 147 - .execute(tx.as_mut()) 148 - .await 149 - .map_err(StorageError::UnableToExecuteQuery)?; 150 - 151 - tx.commit() 152 - .await 153 - .map_err(StorageError::CannotCommitDatabaseTransaction) 154 - } 155 - 156 - pub struct OAuthSessionParams { 157 - pub session_group: Cow<'static, str>, 158 - pub access_token: Cow<'static, str>, 159 - pub did: Cow<'static, str>, 160 - pub issuer: Cow<'static, str>, 161 - pub refresh_token: Cow<'static, str>, 162 - pub secret_jwk_id: Cow<'static, str>, 163 - pub dpop_jwk: WrappedJsonWebKey, 164 - pub created_at: DateTime<Utc>, 165 - pub access_token_expires_at: DateTime<Utc>, 166 - } 167 - 168 - pub async fn oauth_session_insert( 169 - pool: &StoragePool, 170 - params: OAuthSessionParams, 171 - ) -> Result<(), StorageError> { 172 - // Validate required input parameters 173 - if params.session_group.trim().is_empty() { 174 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 175 - "Session group cannot be empty".into(), 176 - ))); 177 - } 178 - 179 - if params.access_token.trim().is_empty() { 180 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 181 - "Access token cannot be empty".into(), 182 - ))); 183 - } 184 - 185 - if params.did.trim().is_empty() { 186 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 187 - "DID cannot be empty".into(), 188 - ))); 189 - } 190 4 191 - if params.issuer.trim().is_empty() { 192 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 193 - "Issuer cannot be empty".into(), 194 - ))); 195 - } 196 - 197 - if params.refresh_token.trim().is_empty() { 198 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 199 - "Refresh token cannot be empty".into(), 200 - ))); 201 - } 202 - 203 - if params.secret_jwk_id.trim().is_empty() { 204 - return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( 205 - "Secret JWK ID cannot be empty".into(), 206 - ))); 207 - } 208 - 209 - let mut tx = pool 210 - .begin() 211 - .await 212 - .map_err(StorageError::CannotBeginDatabaseTransaction)?; 213 - 214 - sqlx::query("INSERT INTO oauth_sessions (session_group, access_token, did, issuer, refresh_token, secret_jwk_id, dpop_jwk, created_at, access_token_expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)") 215 - .bind(&params.session_group) 216 - .bind(&params.access_token) 217 - .bind(&params.did) 218 - .bind(&params.issuer) 219 - .bind(&params.refresh_token) 220 - .bind(&params.secret_jwk_id) 221 - .bind(json!(params.dpop_jwk)) 222 - .bind(params.created_at) 223 - .bind(params.access_token_expires_at) 224 - .execute(tx.as_mut()) 225 - .await 226 - .map_err(StorageError::UnableToExecuteQuery)?; 227 - 228 - tx.commit() 229 - .await 230 - .map_err(StorageError::CannotCommitDatabaseTransaction) 231 - } 5 + use crate::storage::{errors::StorageError, identity_profile::model::IdentityProfile, StoragePool}; 6 + use model::OAuthSession; 232 7 233 8 pub async fn oauth_session_update( 234 9 pool: &StoragePool, ··· 308 83 pool: &StoragePool, 309 84 session_group: &str, 310 85 did: Option<&str>, 311 - ) -> Result<(Handle, OAuthSession), StorageError> { 86 + ) -> Result<(IdentityProfile, OAuthSession), StorageError> { 312 87 // Validate session_group is not empty 313 88 if session_group.trim().is_empty() { 314 89 return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol( ··· 356 131 357 132 let did_for_handle = did.unwrap_or(&oauth_session.did); 358 133 359 - let handle = sqlx::query_as::<_, Handle>("SELECT * FROM handles WHERE did = $1") 360 - .bind(did_for_handle) 361 - .fetch_one(tx.as_mut()) 362 - .await 363 - .map_err(|err| match err { 364 - sqlx::Error::RowNotFound => StorageError::HandleNotFound, 365 - other => StorageError::UnableToExecuteQuery(other), 366 - })?; 134 + let handle = 135 + sqlx::query_as::<_, IdentityProfile>("SELECT * FROM identity_profiles WHERE did = $1") 136 + .bind(did_for_handle) 137 + .fetch_one(tx.as_mut()) 138 + .await 139 + .map_err(|err| match err { 140 + sqlx::Error::RowNotFound => StorageError::HandleNotFound, 141 + other => StorageError::UnableToExecuteQuery(other), 142 + })?; 367 143 368 144 tx.commit() 369 145 .await ··· 373 149 } 374 150 375 151 pub mod model { 376 - use anyhow::Error; 377 152 use chrono::{DateTime, Utc}; 378 - use p256::SecretKey; 379 153 use serde::Deserialize; 380 154 use sqlx::FromRow; 381 155 382 - use crate::{ 383 - atproto::auth::SimpleOAuthSessionProvider, jose::jwk::WrappedJsonWebKey, 384 - storage::errors::OAuthModelError, 385 - }; 386 - 387 156 #[derive(Clone, FromRow, Deserialize)] 388 157 pub struct OAuthRequest { 389 158 pub oauth_state: String, ··· 393 162 pub pkce_verifier: String, 394 163 pub secret_jwk_id: String, 395 164 pub destination: Option<String>, 396 - pub dpop_jwk: sqlx::types::Json<WrappedJsonWebKey>, 165 + pub dpop_jwk: String, 397 166 pub created_at: DateTime<Utc>, 398 167 pub expires_at: DateTime<Utc>, 399 168 } 400 169 401 - pub struct OAuthRequestState { 402 - pub state: String, 403 - pub nonce: String, 404 - pub code_challenge: String, 405 - } 406 - 407 170 #[derive(Clone, FromRow, Deserialize)] 408 171 pub struct OAuthSession { 409 172 pub session_group: String, ··· 412 175 pub issuer: String, 413 176 pub refresh_token: String, 414 177 pub secret_jwk_id: String, 415 - pub dpop_jwk: sqlx::types::Json<WrappedJsonWebKey>, 178 + pub dpop_jwk: String, 416 179 pub created_at: DateTime<Utc>, 417 180 pub access_token_expires_at: DateTime<Utc>, 418 - } 419 - 420 - impl TryFrom<OAuthSession> for SimpleOAuthSessionProvider { 421 - type Error = Error; 422 - 423 - fn try_from(value: OAuthSession) -> Result<Self, Self::Error> { 424 - let dpop_secret = SecretKey::from_jwk(&value.dpop_jwk.jwk) 425 - .map_err(OAuthModelError::DpopSecretFromJwkFailed)?; 426 - 427 - Ok(SimpleOAuthSessionProvider { 428 - access_token: value.access_token, 429 - issuer: value.issuer, 430 - dpop_secret, 431 - }) 432 - } 433 - } 434 - } 435 - 436 - #[cfg(test)] 437 - pub mod test { 438 - use sqlx::PgPool; 439 - 440 - use crate::{ 441 - jose, 442 - storage::oauth::{ 443 - oauth_request_get, oauth_request_insert, oauth_request_remove, oauth_session_insert, 444 - web_session_lookup, OAuthRequestParams, OAuthSessionParams, 445 - }, 446 - }; 447 - 448 - #[sqlx::test(fixtures(path = "../../fixtures/storage", scripts("handles")))] 449 - async fn test_oauth_request(pool: PgPool) -> anyhow::Result<()> { 450 - let dpop_jwk = jose::jwk::generate(); 451 - let created_at = chrono::Utc::now(); 452 - let expires_at = created_at + chrono::Duration::seconds(60 as i64); 453 - 454 - let res = oauth_request_insert( 455 - &pool, 456 - OAuthRequestParams { 457 - oauth_state: "oauth_state".to_string().into(), 458 - issuer: "pds.examplepds.com".to_string().into(), 459 - did: "did:plc:d5c1ed6d01421a67b96f68fa".to_string().into(), 460 - nonce: "nonce".to_string().into(), 461 - pkce_verifier: "pkce_verifier".to_string().into(), 462 - secret_jwk_id: "secret_jwk_id".to_string().into(), 463 - dpop_jwk: Some(dpop_jwk.clone()), 464 - destination: None, 465 - created_at, 466 - expires_at, 467 - }, 468 - ) 469 - .await; 470 - 471 - assert!(!res.is_err()); 472 - 473 - let oauth_request = oauth_request_get(&pool, "oauth_state").await; 474 - assert!(!oauth_request.is_err()); 475 - let oauth_request = oauth_request.unwrap(); 476 - 477 - assert_eq!(oauth_request.did, "did:plc:d5c1ed6d01421a67b96f68fa"); 478 - assert_eq!(oauth_request.dpop_jwk.as_ref(), &dpop_jwk); 479 - 480 - let res = oauth_request_remove(&pool, "oauth_state").await; 481 - assert!(!res.is_err()); 482 - 483 - { 484 - let oauth_request = oauth_request_get(&pool, "oauth_state").await; 485 - assert!(oauth_request.is_err()); 486 - } 487 - 488 - Ok(()) 489 - } 490 - 491 - #[sqlx::test(fixtures(path = "../../fixtures/storage", scripts("handles")))] 492 - async fn test_oauth_session(pool: PgPool) -> anyhow::Result<()> { 493 - let dpop_jwk = jose::jwk::generate(); 494 - 495 - let session_group = ulid::Ulid::new().to_string(); 496 - let now = chrono::Utc::now(); 497 - 498 - let insert_session_res = oauth_session_insert( 499 - &pool, 500 - OAuthSessionParams { 501 - session_group: session_group.clone().into(), 502 - access_token: "access_token".to_string().into(), 503 - did: "did:plc:d5c1ed6d01421a67b96f68fa".to_string().into(), 504 - issuer: "pds.examplepds.com".to_string().into(), 505 - refresh_token: "refresh_token".to_string().into(), 506 - secret_jwk_id: "secret_jwk_id".to_string().into(), 507 - dpop_jwk: dpop_jwk.clone(), 508 - created_at: now, 509 - access_token_expires_at: now + chrono::Duration::seconds(60 as i64), 510 - }, 511 - ) 512 - .await; 513 - 514 - assert!(!insert_session_res.is_err()); 515 - 516 - let web_session = web_session_lookup( 517 - &pool, 518 - &session_group, 519 - Some("did:plc:d5c1ed6d01421a67b96f68fa"), 520 - ) 521 - .await; 522 - assert!(!web_session.is_err()); 523 - 524 - Ok(()) 525 181 } 526 182 }
+142
src/task_identity_refresh.rs
··· 1 + use anyhow::Result; 2 + use atproto_identity::{resolve::IdentityResolver, storage::DidDocumentStorage}; 3 + use chrono::Duration; 4 + use sqlx::FromRow; 5 + use tokio::time::{sleep, Instant}; 6 + use tokio_util::sync::CancellationToken; 7 + 8 + use crate::storage::StoragePool; 9 + 10 + pub struct IdentityRefreshTaskConfig { 11 + pub sleep_interval: Duration, 12 + pub worker_id: String, 13 + } 14 + 15 + pub struct IdentityRefreshTask { 16 + pub config: IdentityRefreshTaskConfig, 17 + pub storage_pool: StoragePool, 18 + pub document_storage: std::sync::Arc<dyn DidDocumentStorage>, 19 + pub identity_resolver: IdentityResolver, 20 + pub cancellation_token: CancellationToken, 21 + } 22 + 23 + #[derive(FromRow)] 24 + struct ExpiredDidDocument { 25 + did: String, 26 + } 27 + 28 + impl IdentityRefreshTask { 29 + #[must_use] 30 + pub fn new( 31 + config: IdentityRefreshTaskConfig, 32 + storage_pool: StoragePool, 33 + document_storage: std::sync::Arc<dyn DidDocumentStorage>, 34 + identity_resolver: IdentityResolver, 35 + cancellation_token: CancellationToken, 36 + ) -> Self { 37 + Self { 38 + config, 39 + storage_pool, 40 + document_storage, 41 + identity_resolver, 42 + cancellation_token, 43 + } 44 + } 45 + 46 + /// Runs the identity refresh task as a long-running process 47 + /// 48 + /// # Errors 49 + /// Returns an error if the sleep interval cannot be converted, or if there's a problem 50 + /// processing the expired DID documents 51 + pub async fn run(&self) -> Result<()> { 52 + tracing::debug!("IdentityRefreshTask started"); 53 + 54 + let interval = self.config.sleep_interval.to_std()?; 55 + 56 + let sleeper = sleep(interval); 57 + tokio::pin!(sleeper); 58 + 59 + loop { 60 + tokio::select! { 61 + () = self.cancellation_token.cancelled() => { 62 + break; 63 + }, 64 + () = &mut sleeper => { 65 + if let Err(err) = self.process_expired_documents().await { 66 + tracing::error!("IdentityRefreshTask failed: {}", err); 67 + } 68 + sleeper.as_mut().reset(Instant::now() + interval); 69 + } 70 + } 71 + } 72 + 73 + tracing::info!("IdentityRefreshTask stopped"); 74 + 75 + Ok(()) 76 + } 77 + 78 + async fn process_expired_documents(&self) -> Result<i32> { 79 + // Find DID documents that have expired in a separate transaction 80 + let expired_docs = { 81 + let mut tx = self.storage_pool.begin().await?; 82 + let docs = sqlx::query_as::<_, ExpiredDidDocument>( 83 + "SELECT did FROM did_documents WHERE expires_at IS NOT NULL AND expires_at <= NOW() LIMIT 50" 84 + ) 85 + .fetch_all(tx.as_mut()) 86 + .await?; 87 + tx.commit().await?; 88 + docs 89 + }; 90 + 91 + let count = expired_docs.len() as i32; 92 + 93 + if count == 0 { 94 + return Ok(0); 95 + } 96 + 97 + tracing::info!(count = count, "processing expired DID documents"); 98 + 99 + for expired_doc in expired_docs { 100 + tracing::debug!(did = expired_doc.did, "refreshing expired DID document"); 101 + 102 + match self.refresh_did_document(&expired_doc.did).await { 103 + Ok(()) => { 104 + tracing::debug!(did = expired_doc.did, "successfully refreshed DID document"); 105 + } 106 + Err(err) => { 107 + tracing::warn!( 108 + did = expired_doc.did, 109 + error = ?err, 110 + "failed to refresh DID document, deleting from storage" 111 + ); 112 + 113 + // If we can't resolve the DID, delete it from storage 114 + if let Err(delete_err) = self 115 + .document_storage 116 + .delete_document_by_did(&expired_doc.did) 117 + .await 118 + { 119 + tracing::error!( 120 + did = expired_doc.did, 121 + error = ?delete_err, 122 + "failed to delete expired DID document" 123 + ); 124 + } 125 + } 126 + } 127 + } 128 + 129 + Ok(count) 130 + } 131 + 132 + async fn refresh_did_document(&self, did: &str) -> Result<()> { 133 + // Use the identity resolver to get the updated DID document 134 + let document = self.identity_resolver.resolve(did).await?; 135 + 136 + // Store the updated document using the DidDocumentStorage trait 137 + // This will reset the expires_at column based on the storage implementation 138 + self.document_storage.store_document(document).await?; 139 + 140 + Ok(()) 141 + } 142 + }
+87
src/task_oauth_requests_cleanup.rs
··· 1 + use anyhow::Result; 2 + use chrono::Duration; 3 + use tokio::time::{sleep, Instant}; 4 + use tokio_util::sync::CancellationToken; 5 + 6 + use crate::storage::StoragePool; 7 + 8 + pub struct OAuthRequestsCleanupTaskConfig { 9 + pub sleep_interval: Duration, 10 + } 11 + 12 + pub struct OAuthRequestsCleanupTask { 13 + pub config: OAuthRequestsCleanupTaskConfig, 14 + pub storage_pool: StoragePool, 15 + pub cancellation_token: CancellationToken, 16 + } 17 + 18 + impl OAuthRequestsCleanupTask { 19 + #[must_use] 20 + pub fn new( 21 + config: OAuthRequestsCleanupTaskConfig, 22 + storage_pool: StoragePool, 23 + cancellation_token: CancellationToken, 24 + ) -> Self { 25 + Self { 26 + config, 27 + storage_pool, 28 + cancellation_token, 29 + } 30 + } 31 + 32 + /// Runs the OAuth requests cleanup task as a long-running process 33 + /// 34 + /// # Errors 35 + /// Returns an error if the sleep interval cannot be converted, or if there's a problem 36 + /// cleaning up expired requests 37 + pub async fn run(&self) -> Result<()> { 38 + tracing::debug!("OAuthRequestsCleanupTask started"); 39 + 40 + let interval = self.config.sleep_interval.to_std()?; 41 + 42 + let sleeper = sleep(interval); 43 + tokio::pin!(sleeper); 44 + 45 + loop { 46 + tokio::select! { 47 + () = self.cancellation_token.cancelled() => { 48 + break; 49 + }, 50 + () = &mut sleeper => { 51 + if let Err(err) = self.cleanup_expired_requests().await { 52 + tracing::error!("OAuthRequestsCleanupTask failed: {}", err); 53 + } 54 + sleeper.as_mut().reset(Instant::now() + interval); 55 + } 56 + } 57 + } 58 + 59 + tracing::info!("OAuthRequestsCleanupTask stopped"); 60 + 61 + Ok(()) 62 + } 63 + 64 + async fn cleanup_expired_requests(&self) -> Result<()> { 65 + let now = chrono::Utc::now(); 66 + 67 + tracing::debug!("Starting cleanup of expired OAuth requests"); 68 + 69 + let result = sqlx::query("DELETE FROM atproto_oauth_requests WHERE expires_at < $1") 70 + .bind(now) 71 + .execute(&self.storage_pool) 72 + .await?; 73 + 74 + let deleted_count = result.rows_affected(); 75 + 76 + if deleted_count > 0 { 77 + tracing::info!( 78 + deleted_count = deleted_count, 79 + "Cleaned up expired OAuth requests" 80 + ); 81 + } else { 82 + tracing::debug!("No expired OAuth requests to clean up"); 83 + } 84 + 85 + Ok(()) 86 + } 87 + }
+34 -17
src/task_refresh_tokens.rs
··· 1 1 use anyhow::Result; 2 + use atproto_identity::key::identify_key; 3 + use atproto_oauth::workflow::{oauth_refresh, OAuthClient}; 2 4 use chrono::{Duration, Utc}; 3 5 use deadpool_redis::redis::{pipe, AsyncCommands}; 4 - use p256::SecretKey; 5 6 use std::borrow::Cow; 6 7 use tokio::time::{sleep, Instant}; 7 8 use tokio_util::sync::CancellationToken; 8 9 9 10 use crate::{ 10 - config::{OAuthActiveKeys, SigningKeys}, 11 - oauth::client_oauth_refresh, 11 + config::SigningKeys, 12 12 refresh_tokens_errors::RefreshError, 13 13 storage::{ 14 14 cache::{build_worker_queue, OAUTH_REFRESH_HEARTBEATS, OAUTH_REFRESH_QUEUE}, ··· 22 22 pub worker_id: String, 23 23 pub external_url_base: String, 24 24 pub signing_keys: SigningKeys, 25 - pub oauth_active_keys: OAuthActiveKeys, 26 25 } 27 26 28 27 pub struct RefreshTokensTask { ··· 30 29 pub http_client: reqwest::Client, 31 30 pub storage_pool: StoragePool, 32 31 pub cache_pool: CachePool, 32 + pub document_storage: std::sync::Arc<dyn atproto_identity::storage::DidDocumentStorage>, 33 33 pub cancellation_token: CancellationToken, 34 34 } 35 35 ··· 40 40 http_client: reqwest::Client, 41 41 storage_pool: StoragePool, 42 42 cache_pool: CachePool, 43 + document_storage: std::sync::Arc<dyn atproto_identity::storage::DidDocumentStorage>, 43 44 cancellation_token: CancellationToken, 44 45 ) -> Self { 45 46 Self { ··· 47 48 http_client, 48 49 storage_pool, 49 50 cache_pool, 51 + document_storage, 50 52 cancellation_token, 51 53 } 52 54 } ··· 180 182 let (handle, oauth_session) = 181 183 web_session_lookup(&self.storage_pool, session_group, None).await?; 182 184 183 - let secret_signing_key = self 185 + let secret_signing_key_string = self 184 186 .config 185 187 .signing_keys 186 188 .as_ref() 187 189 .get(&oauth_session.secret_jwk_id) 188 - .cloned(); 190 + .cloned() 191 + .ok_or_else(|| anyhow::Error::from(RefreshError::SecretSigningKeyNotFound))?; 192 + 193 + let private_signing_key_data = identify_key(&secret_signing_key_string)?; 194 + 195 + let private_dpop_key_data = identify_key(&oauth_session.dpop_jwk)?; 189 196 190 - if secret_signing_key.is_none() { 191 - return Err(RefreshError::SecretSigningKeyNotFound.into()); 192 - } 197 + let document = match self 198 + .document_storage 199 + .get_document_by_did(&handle.did) 200 + .await? 201 + { 202 + Some(doc) => doc, 203 + None => return Err(RefreshError::IdentityDocumentNotFound.into()), 204 + }; 193 205 194 - let dpop_secret_key = SecretKey::from_jwk(&oauth_session.dpop_jwk.jwk) 195 - .map_err(RefreshError::DpopProofCreationFailed)?; 206 + let oauth_client = OAuthClient { 207 + redirect_uri: format!("https://{}/oauth/callback", self.config.external_url_base), 208 + client_id: format!( 209 + "https://{}/oauth/client-metadata.json", 210 + self.config.external_url_base 211 + ), 212 + private_signing_key_data, 213 + }; 196 214 197 - let token_response = client_oauth_refresh( 215 + let token_response = oauth_refresh( 198 216 &self.http_client, 199 - &self.config.external_url_base, 200 - (&oauth_session.secret_jwk_id, secret_signing_key.unwrap()), 201 - oauth_session.refresh_token.as_str(), 202 - &handle, 203 - &dpop_secret_key, 217 + &oauth_client, 218 + &private_dpop_key_data, 219 + &oauth_session.refresh_token, 220 + &document, 204 221 ) 205 222 .await?; 206 223
-199
src/validation.rs
··· 1 - //! Validation module that provides utilities for validating hostnames and AT Protocol handles. 2 - //! 3 - //! This module implements RFC-compliant hostname validation and AT Protocol handle formatting rules. 4 - 5 - /// Maximum length for a valid hostname as defined in RFC 1035 6 - const MAX_HOSTNAME_LENGTH: usize = 253; 7 - 8 - /// Maximum length for a DNS label (component between dots) as defined in RFC 1035 9 - const MAX_LABEL_LENGTH: usize = 63; 10 - 11 - /// List of reserved top-level domains that are not valid for AT Protocol handles 12 - const RESERVED_TLDS: [&str; 4] = [".localhost", ".internal", ".arpa", ".local"]; 13 - 14 - /// Validates if a string is a valid hostname according to RFC standards. 15 - /// 16 - /// A valid hostname must: 17 - /// - Only contain alphanumeric characters, hyphens, and periods 18 - /// - Not start or end labels with hyphens 19 - /// - Have labels (parts between dots) with length between 1-63 characters 20 - /// - Have total length not exceeding 253 characters 21 - /// - Not use reserved top-level domains 22 - /// 23 - /// # Arguments 24 - /// * `hostname` - The hostname string to validate 25 - /// 26 - /// # Returns 27 - /// * `true` if the hostname is valid, `false` otherwise 28 - #[must_use] 29 - pub fn is_valid_hostname(hostname: &str) -> bool { 30 - // Empty hostnames are invalid 31 - if hostname.is_empty() || hostname.len() > MAX_HOSTNAME_LENGTH { 32 - return false; 33 - } 34 - 35 - // Check if hostname uses any reserved TLDs 36 - if RESERVED_TLDS.iter().any(|tld| hostname.ends_with(tld)) { 37 - return false; 38 - } 39 - 40 - // Ensure all characters are valid hostname characters 41 - if hostname.bytes().any(|byte| !is_valid_hostname_char(byte)) { 42 - return false; 43 - } 44 - 45 - // Validate each DNS label in the hostname 46 - if hostname.split('.').any(|label| !is_valid_dns_label(label)) { 47 - return false; 48 - } 49 - 50 - true 51 - } 52 - 53 - /// Checks if a byte is a valid character in a hostname. 54 - /// 55 - /// Valid characters are: a-z, A-Z, 0-9, hyphen (-), and period (.) 56 - /// 57 - /// # Arguments 58 - /// * `byte` - The byte to check 59 - /// 60 - /// # Returns 61 - /// * `true` if the byte is a valid hostname character, `false` otherwise 62 - fn is_valid_hostname_char(byte: u8) -> bool { 63 - byte.is_ascii_lowercase() 64 - || byte.is_ascii_uppercase() 65 - || byte.is_ascii_digit() 66 - || byte == b'-' 67 - || byte == b'.' 68 - } 69 - 70 - /// Validates if a DNS label is valid according to RFC standards. 71 - /// 72 - /// A valid DNS label must: 73 - /// - Not be empty 74 - /// - Not exceed 63 characters 75 - /// - Not start or end with a hyphen 76 - /// 77 - /// # Arguments 78 - /// * `label` - The DNS label to validate 79 - /// 80 - /// # Returns 81 - /// * `true` if the label is valid, `false` otherwise 82 - fn is_valid_dns_label(label: &str) -> bool { 83 - !(label.is_empty() 84 - || label.len() > MAX_LABEL_LENGTH 85 - || label.starts_with('-') 86 - || label.ends_with('-')) 87 - } 88 - 89 - /// Validates and normalizes an AT Protocol handle. 90 - /// 91 - /// A valid AT Protocol handle must: 92 - /// - Be a valid hostname (after stripping any prefixes) 93 - /// - Contain at least one period (.) 94 - /// - Can optionally have "at://" or "@" prefix, which will be removed 95 - /// 96 - /// # Arguments 97 - /// * `handle` - The handle string to validate 98 - /// 99 - /// # Returns 100 - /// * `Some(String)` containing the normalized handle if valid 101 - /// * `None` if the handle is invalid 102 - #[must_use] 103 - pub fn is_valid_handle(handle: &str) -> Option<String> { 104 - // Strip optional prefixes to get the core handle 105 - let trimmed = strip_handle_prefixes(handle); 106 - 107 - // A valid handle must be a valid hostname with at least one period 108 - if is_valid_hostname(trimmed) && trimmed.contains('.') { 109 - Some(trimmed.to_string()) 110 - } else { 111 - None 112 - } 113 - } 114 - 115 - /// Strips common AT Protocol handle prefixes. 116 - /// 117 - /// Removes "at://" or "@" prefix if present. 118 - /// 119 - /// # Arguments 120 - /// * `handle` - The handle to strip prefixes from 121 - /// 122 - /// # Returns 123 - /// * The handle with prefixes removed 124 - fn strip_handle_prefixes(handle: &str) -> &str { 125 - if let Some(value) = handle.strip_prefix("at://") { 126 - value 127 - } else if let Some(value) = handle.strip_prefix('@') { 128 - value 129 - } else { 130 - handle 131 - } 132 - } 133 - 134 - #[cfg(test)] 135 - mod tests { 136 - use super::*; 137 - 138 - #[test] 139 - fn test_valid_hostnames() { 140 - // Valid hostnames 141 - assert!(is_valid_hostname("example.com")); 142 - assert!(is_valid_hostname("subdomain.example.com")); 143 - assert!(is_valid_hostname("with-hyphen.example.com")); 144 - assert!(is_valid_hostname("123numeric.example.com")); 145 - assert!(is_valid_hostname("xn--bcher-kva.example.com")); // IDN 146 - } 147 - 148 - #[test] 149 - fn test_invalid_hostnames() { 150 - // Invalid hostnames 151 - assert!(!is_valid_hostname("")); // Empty 152 - assert!(!is_valid_hostname("a".repeat(254).as_str())); // Too long 153 - assert!(!is_valid_hostname("example.localhost")); // Reserved TLD 154 - assert!(!is_valid_hostname("example.internal")); // Reserved TLD 155 - assert!(!is_valid_hostname("example.arpa")); // Reserved TLD 156 - assert!(!is_valid_hostname("example.local")); // Reserved TLD 157 - assert!(!is_valid_hostname("invalid_char.example.com")); // Invalid underscore 158 - assert!(!is_valid_hostname("-starts-with-hyphen.example.com")); // Label starts with hyphen 159 - assert!(!is_valid_hostname("ends-with-hyphen-.example.com")); // Label ends with hyphen 160 - assert!(!is_valid_hostname(&("a".repeat(64) + ".example.com"))); // Label too long 161 - assert!(!is_valid_hostname(".starts.with.dot")); // Empty label 162 - assert!(!is_valid_hostname("ends.with.dot.")); // Empty label 163 - assert!(!is_valid_hostname("double..dot")); // Empty label 164 - } 165 - 166 - #[test] 167 - fn test_valid_handles() { 168 - // Valid handles 169 - assert_eq!( 170 - is_valid_handle("user.example.com"), 171 - Some("user.example.com".to_string()) 172 - ); 173 - assert_eq!( 174 - is_valid_handle("at://user.example.com"), 175 - Some("user.example.com".to_string()) 176 - ); 177 - assert_eq!( 178 - is_valid_handle("@user.example.com"), 179 - Some("user.example.com".to_string()) 180 - ); 181 - } 182 - 183 - #[test] 184 - fn test_invalid_handles() { 185 - // Invalid handles 186 - assert_eq!(is_valid_handle("nodots"), None); // No dots 187 - assert_eq!(is_valid_handle("at://invalid_char.example.com"), None); // Invalid character 188 - assert_eq!(is_valid_handle("@example.localhost"), None); // Reserved TLD 189 - } 190 - 191 - #[test] 192 - fn test_strip_handle_prefixes() { 193 - assert_eq!(strip_handle_prefixes("example.com"), "example.com"); 194 - assert_eq!(strip_handle_prefixes("at://example.com"), "example.com"); 195 - assert_eq!(strip_handle_prefixes("@example.com"), "example.com"); 196 - // Nested prefixes should only strip the outermost one 197 - assert_eq!(strip_handle_prefixes("at://@example.com"), "@example.com"); 198 - } 199 - }
+13
templates/delete_event.en-us.html
··· 1 + {% extends "base.en-us.html" %} 2 + {% block title %}Smoke Signal - Delete Event{% endblock %} 3 + {% block head %}{% endblock %} 4 + {% block content %} 5 + 6 + {% from "form_include.html" import text_input %} 7 + <section class="section is-fullheight"> 8 + <div class="container "> 9 + {% include 'delete_event.en-us.partial.html' %} 10 + </div> 11 + </section> 12 + 13 + {% endblock %}
+23
templates/delete_event.en-us.partial.html
··· 1 + <div class="box content" style="background-color: #ffe0e6; border-color: #ff1744;"> 2 + <h2>Danger Zone</h2> 3 + <p>This will delete the event from your Personal Data Server (PDS) as well as this Smoke Signal instance.</p> 4 + <ol> 5 + <li>Deleting records cannot be undone.</li> 6 + <li>If you are making changes to the event, please consider just changing the event status, mode, and time.</li> 7 + <li>Existing RSVP records will display "Unknown Event", possibly causing confusion to those who RSVP'd to the event. 8 + </li> 9 + </ol> 10 + <form action="{{ delete_event_url }}" method="post"> 11 + {% if show_confirm %} 12 + <div class="field"> 13 + <div class="control"> 14 + <label class="checkbox"> 15 + <input type="checkbox" name="confirm" value="true" required> 16 + <strong>I understand that deleting this event cannot be undone</strong> 17 + </label> 18 + </div> 19 + </div> 20 + {% endif %} 21 + <button type="submit" class="button is-danger">Delete Event</button> 22 + </form> 23 + </div>
+4
templates/edit_event.en-us.common.html
··· 8 8 {% include 'create_event.en-us.partial.html' %} 9 9 </div> 10 10 11 + {% if delete_event_url %} 12 + {% include 'delete_event.en-us.partial.html' %} 13 + {% endif %} 14 + 11 15 </div> 12 16 </section> 13 17 <script>
+3 -3
templates/footer.en-us.html
··· 1 1 <footer class="footer"> 2 2 <div class="container content has-text-centered"> 3 3 <p> 4 - <strong>Smoke Signal Events</strong> made by <a href="https://ngerakines.me/">Nick Gerakines</a> 5 - (<a href="https://github.com/ngerakines">Source Code</a>) 4 + <strong>Smoke Signal Events</strong> made by <a href="https://bsky.app/profile/ngerakines.me">@ngerakines.me</a> 5 + (<a href="https://github.com/ngerakines">source</a>) 6 6 </p> 7 7 <nav class="level"> 8 8 <div class="level-item has-text-centered"> 9 - <a href="https://docs.smokesignal.events/">Support</a> 9 + <a href="https://discourse.smokesignal.events/">Support</a> 10 10 </div> 11 11 <div class="level-item has-text-centered"> 12 12 <a href="/privacy-policy" hx-boost="true">Privacy Policy</a>
+1 -1
templates/nav.en-us.html
··· 47 47 <span>Your Profile</span> 48 48 </a> 49 49 <a class="button is-danger is-light" 50 - href="/logout">Log out</a> 50 + href="/logout" hx-boost="false">Log out</a> 51 51 {% else %} 52 52 <a class="button is-primary" href="/oauth/login" hx-boost="true">Log in</a> 53 53 {% endif %}