+192
CLAUDE.md
+192
CLAUDE.md
···
1
+
# CLAUDE.md
2
+
3
+
This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository.
4
+
5
+
## Project Overview
6
+
7
+
Smokesignal is a Rust-based event and RSVP management application built for the AT Protocol ecosystem. It provides decentralized identity management, OAuth authentication, and event coordination through the AT Protocol. The application supports both standard AT Protocol OAuth and AIP (AT Protocol Improvement Proposal) OAuth flows, with backend selection determined by runtime configuration.
8
+
9
+
## Tech Stack
10
+
11
+
- **Language**: Rust (edition 2021, minimum version 1.83)
12
+
- **Web Framework**: Axum with async/await
13
+
- **Database**: PostgreSQL with SQLx migrations
14
+
- **Caching**: Redis/Valkey for sessions and token management
15
+
- **Templates**: MiniJinja (server-side rendering with optional reloading)
16
+
- **Frontend**: HTMX + Bulma CSS + FontAwesome
17
+
- **Authentication**: AT Protocol OAuth with JOSE/JWT (P-256 ECDSA)
18
+
- **Internationalization**: Fluent for i18n support
19
+
- **Static Assets**: Rust Embed for production builds
20
+
- **HTTP Client**: Reqwest with middleware chain, retry logic, and compression
21
+
- **Cryptography**: P-256 ECDSA with elliptic-curve and p256 crates
22
+
- **Task Management**: Tokio with task tracking and graceful shutdown
23
+
- **DNS Resolution**: Hickory resolver with DoH/DoT support
24
+
25
+
## Development Commands
26
+
27
+
### Building & Development
28
+
```bash
29
+
# Development build with template reloading (default)
30
+
cargo build --bin smokesignal
31
+
32
+
# Production build with embedded templates
33
+
cargo build --bin smokesignal --no-default-features -F embed
34
+
35
+
# Type checking and linting
36
+
cargo check
37
+
cargo clippy
38
+
39
+
# Run tests
40
+
cargo test
41
+
```
42
+
43
+
### Running Services
44
+
```bash
45
+
# Start main application server (development mode)
46
+
cargo run --bin smokesignal
47
+
48
+
# Generate cryptographic keys
49
+
cargo run --bin crypto -- key
50
+
51
+
```
52
+
53
+
### Database Operations
54
+
```bash
55
+
# Run database migrations
56
+
sqlx migrate run
57
+
58
+
# Reset database with migrations
59
+
sqlx database reset
60
+
```
61
+
62
+
## Architecture Overview
63
+
64
+
### Core Structure
65
+
```
66
+
/src/
67
+
├── atproto/ # AT Protocol integration
68
+
│ ├── auth.rs # Authentication utilities
69
+
│ └── lexicon/ # AT Protocol lexicon definitions
70
+
├── bin/ # Executables
71
+
│ ├── smokesignal.rs # Main application server
72
+
│ └── crypto.rs # Cryptographic key utilities
73
+
├── http/ # Web layer
74
+
│ ├── errors/ # HTTP error handling
75
+
│ ├── handle_*.rs # Route handlers
76
+
│ ├── middleware_*.rs # Request middleware
77
+
│ └── server.rs # Server configuration
78
+
├── storage/ # Data access layer
79
+
│ ├── cache.rs # Redis caching
80
+
│ ├── identity_profile.rs # User identity management
81
+
│ ├── event.rs # Event data operations
82
+
│ ├── oauth.rs # OAuth session management
83
+
│ └── denylist.rs # Handle blocking
84
+
├── config.rs # Environment configuration
85
+
├── key_provider.rs # JWT key management
86
+
├── i18n.rs # Internationalization
87
+
└── task_refresh_tokens.rs # Background token refresh
88
+
```
89
+
90
+
### Key Patterns
91
+
- **Layered Architecture**: HTTP → Business Logic → Data Access
92
+
- **Module-based Organization**: Each handler in separate module with error types
93
+
- **Dependency Injection**: Services passed through application context
94
+
- **Template-driven UI**: Server-side rendering with optional reloading
95
+
96
+
### Database Schema
97
+
- `identity_profiles` - User identity and preferences (formerly `handles`)
98
+
- `oauth_requests` - OAuth flow state with PKCE and DPoP support
99
+
- `oauth_sessions` - Active sessions with token management
100
+
- `events` - Calendar events with location and timezone support
101
+
- `rsvps` - Event responses and attendance tracking
102
+
- `denylist` - Handle blocking for moderation
103
+
104
+
## Configuration Requirements
105
+
106
+
### Required Environment Variables
107
+
- `HTTP_COOKIE_KEY` - 64-character hex key for session encryption
108
+
- `DATABASE_URL` - PostgreSQL connection string
109
+
- `REDIS_URL` - Redis/Valkey connection string
110
+
- `EXTERNAL_BASE` - Public base URL (e.g., https://smokesignal.events)
111
+
- `PLC_HOSTNAME` - AT Protocol PLC server hostname
112
+
- `ADMIN_DIDS` - Comma-separated admin user DIDs
113
+
- `DESTINATION_KEY` - Private key for signing authentication redirect tokens
114
+
115
+
### OAuth Backend Configuration
116
+
- `OAUTH_BACKEND` - OAuth backend to use: "atprotocol" (default) or "aip"
117
+
118
+
When `OAUTH_BACKEND=atprotocol`, the following variable is required:
119
+
- `SIGNING_KEYS` - Path to JWK key set file for OAuth signing
120
+
121
+
When `OAUTH_BACKEND=aip`, the following variables are required:
122
+
- `AIP_HOSTNAME` - AIP OAuth server hostname
123
+
- `AIP_CLIENT_ID` - AIP OAuth client ID
124
+
- `AIP_CLIENT_SECRET` - AIP OAuth client secret
125
+
126
+
### Optional Environment Variables
127
+
- `RUST_LOG` - Logging configuration (default: info)
128
+
- `PORT` - Server port (default: 3000)
129
+
- `BIND_ADDR` - Bind address (default: 0.0.0.0)
130
+
131
+
## Development Setup
132
+
133
+
### Using DevContainer (Recommended)
134
+
The project includes a complete DevContainer setup with PostgreSQL, Redis, and all Rust dependencies. Use the provided `.devcontainer/` configuration for consistent development environment.
135
+
136
+
### Key Development Features
137
+
- Template reloading in development mode (default `reload` feature)
138
+
- Embedded templates for production builds (`embed` feature)
139
+
- SQLx compile-time query checking with migrations
140
+
- Cryptographic key generation and management utilities
141
+
- Comprehensive error handling with user-friendly messages
142
+
- Internationalization support with Fluent
143
+
- Background token refresh for OAuth sessions
144
+
- LRU caching for OAuth requests and DID documents
145
+
- Graceful shutdown with task tracking and cancellation tokens
146
+
147
+
## Testing
148
+
149
+
- Unit tests embedded in source files using `#[cfg(test)]`
150
+
- Integration tests for database operations and HTTP handlers
151
+
- Run with `cargo test`
152
+
- Database tests require running PostgreSQL instance with test database
153
+
- OAuth flow tests use fixture data for AT Protocol interactions
154
+
- Template rendering tests ensure UI consistency
155
+
156
+
## Security Notes
157
+
158
+
- Uses JOSE/JWT with P-256 ECDSA keys for OAuth signing
159
+
- PKCE OAuth flow with state parameter for CSRF protection
160
+
- DPoP (Demonstration of Proof-of-Possession) support for token binding
161
+
- Encrypted session cookies with secure flags
162
+
- Forbids unsafe Rust code (`#[forbid(unsafe_code)]`)
163
+
- Input validation and HTML sanitization with Ammonia
164
+
- Content Security Policy headers
165
+
- Rate limiting and request timeouts
166
+
167
+
## Visibility
168
+
169
+
Types and methods should have the lowest visibility necessary, defaulting to `private`. If `public` visibility is necessary, attempt to make it public to the crate only. Using completely public visibility should be a last resort.
170
+
171
+
## AT Protocol Integration
172
+
173
+
- Full OAuth 2.0 implementation with AT Protocol services using external `atproto-*` crates
174
+
- Support for both standard AT Protocol and AIP OAuth flows
175
+
- DID-based identity management with PDS discovery
176
+
- Integration with PLC (Public Ledger of Credentials) for DID resolution
177
+
- Handle resolution and verification with LRU caching
178
+
- Decentralized identity workflows with fallback mechanisms
179
+
- AT Protocol lexicon support for events and RSVPs:
180
+
- `events_smokesignal_calendar_event`
181
+
- `events_smokesignal_calendar_rsvp`
182
+
- `community_lexicon_calendar_event`
183
+
- `community_lexicon_calendar_rsvp`
184
+
- `community_lexicon_location`
185
+
- Background token refresh for long-lived sessions (AT Protocol backend only)
186
+
- Uses external AT Protocol libraries:
187
+
- `atproto-identity` - Identity resolution and verification
188
+
- `atproto-oauth` - OAuth flow implementation
189
+
- `atproto-oauth-axum` - Axum integration for OAuth
190
+
- `atproto-oauth-aip` - AIP OAuth implementation
191
+
- `atproto-client` - AT Protocol client functionality
192
+
- `atproto-record` - Record management
+9
-2
Cargo.toml
+9
-2
Cargo.toml
···
1
1
[package]
2
2
name = "smokesignal"
3
3
version = "1.0.2"
4
-
edition = "2021"
5
-
rust-version = "1.83"
4
+
edition = "2024"
5
+
rust-version = "1.87"
6
6
authors = ["Nick Gerakines <nick.gerakines@gmail.com>"]
7
7
description = "An event and RSVP management application."
8
8
readme = "README.md"
···
23
23
minijinja-embed = {version = "2.7"}
24
24
25
25
[dependencies]
26
+
atproto-identity = { version = "0.9.3", features = ["lru", "axum", "zeroize"] }
27
+
atproto-oauth = { version = "0.9.3", features = ["lru", "axum", "zeroize"] }
28
+
atproto-oauth-axum = { version = "0.9.3", features = ["zeroize"] }
29
+
atproto-oauth-aip = { version = "0.9.3", features = ["zeroize"] }
30
+
atproto-client = { version = "0.9.3" }
31
+
atproto-record = { version = "0.9.3" }
32
+
26
33
anyhow = "1.0"
27
34
async-trait = "0.1"
28
35
axum-extra = { version = "0.10", features = ["cookie", "cookie-private", "form", "query", "cookie-key-expansion", "typed-header", "typed-routing"] }
+27
-35
Dockerfile
+27
-35
Dockerfile
···
1
1
# syntax=docker/dockerfile:1.4
2
-
FROM rust:latest AS build
2
+
FROM rust:1.87-slim AS builder
3
3
4
-
RUN cargo install sqlx-cli@0.8.2 --no-default-features --features postgres
5
-
RUN cargo install sccache --version ^0.8
6
-
ENV RUSTC_WRAPPER=sccache SCCACHE_DIR=/sccache
4
+
RUN apt-get update && apt-get install -y \
5
+
pkg-config \
6
+
libssl-dev \
7
+
&& rm -rf /var/lib/apt/lists/*
7
8
8
-
RUN USER=root cargo new --bin smokesignal
9
-
RUN mkdir -p /app/
10
-
WORKDIR /app/
9
+
WORKDIR /app
10
+
COPY Cargo.toml build.rs ./
11
11
12
-
RUN --mount=type=bind,source=src,target=src \
13
-
--mount=type=bind,source=migrations,target=migrations \
14
-
--mount=type=bind,source=static,target=static \
15
-
--mount=type=bind,source=i18n,target=i18n \
16
-
--mount=type=bind,source=templates,target=templates \
17
-
--mount=type=bind,source=build.rs,target=build.rs \
18
-
--mount=type=bind,source=Cargo.toml,target=Cargo.toml \
19
-
--mount=type=bind,source=Cargo.lock,target=Cargo.lock \
20
-
--mount=type=cache,id=cargo-target,target=/app/target/ \
21
-
--mount=type=cache,id=sccache,target=$SCCACHE_DIR,sharing=locked \
22
-
--mount=type=cache,id=cargo-registry,target=/usr/local/cargo/registry/ \
23
-
<<EOF
24
-
set -e
25
-
cargo build --locked --release --bin smokesignal --target-dir . --no-default-features -F embed
26
-
EOF
12
+
ARG FEATURES=embed
13
+
ARG TEMPLATES=./templates
14
+
ARG STATIC=./static
15
+
ARG GIT_HASH=0
16
+
ENV GIT_HASH=$GIT_HASH
27
17
28
-
RUN groupadd -g 1500 -r smokesignal && useradd -u 1501 -r -g smokesignal -d /var/lib/smokesignal -m smokesignal
29
-
RUN chown -R smokesignal:smokesignal /app/release/smokesignal
18
+
COPY src ./src
19
+
COPY migrations ./migrations
20
+
COPY static ./static
21
+
COPY i18n ./i18n
22
+
COPY ${TEMPLATES} ./templates
23
+
COPY ${STATIC} ./static
30
24
31
-
FROM gcr.io/distroless/cc
25
+
RUN cargo build --release --bin smokesignal --no-default-features --features ${FEATURES}
26
+
27
+
FROM gcr.io/distroless/cc-debian12
32
28
33
29
LABEL org.opencontainers.image.title="Smoke Signal"
34
30
LABEL org.opencontainers.image.description="An event and RSVP management application."
···
37
33
LABEL org.opencontainers.image.source="https://tangled.sh/@smokesignal.events/smokesignal"
38
34
LABEL org.opencontainers.image.version="1.0.2"
39
35
40
-
WORKDIR /var/lib/smokesignal
41
-
USER smokesignal:smokesignal
36
+
WORKDIR /app
37
+
COPY --from=builder /app/target/release/smokesignal /app/smokesignal
38
+
COPY --from=builder /app/static ./static
42
39
43
-
COPY --from=build /etc/passwd /etc/passwd
44
-
COPY --from=build /etc/group /etc/group
45
-
COPY --from=build /app/release/smokesignal /var/lib/smokesignal/
46
-
COPY static /var/lib/smokesignal/static
47
-
48
-
ENV HTTP_STATIC_PATH=/var/lib/smokesignal/static
49
-
40
+
ENV HTTP_STATIC_PATH=/app/static
41
+
ENV HTTP_PORT=8080
50
42
ENV RUST_LOG=info
51
43
ENV RUST_BACKTRACE=full
52
44
53
-
ENTRYPOINT ["/var/lib/smokesignal/smokesignal"]
45
+
ENTRYPOINT ["/app/smokesignal"]
+11
-1
build.rs
+11
-1
build.rs
···
1
1
fn main() {
2
2
#[cfg(feature = "embed")]
3
3
{
4
-
minijinja_embed::embed_templates!("templates");
4
+
use std::env;
5
+
use std::path::PathBuf;
6
+
let template_path = if let Ok(value) = env::var("HTTP_TEMPLATE_PATH") {
7
+
value.to_string()
8
+
} else {
9
+
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
10
+
.join("templates")
11
+
.display()
12
+
.to_string()
13
+
};
14
+
minijinja_embed::embed_templates!(&template_path);
5
15
}
6
16
}
+64
docker-compose.yml
+64
docker-compose.yml
···
1
+
version: '3.8'
2
+
3
+
services:
4
+
postgres:
5
+
image: postgres:17-alpine
6
+
container_name: smokesignal_postgres
7
+
environment:
8
+
POSTGRES_DB: smokesignal_dev
9
+
POSTGRES_USER: smokesignal
10
+
POSTGRES_PASSWORD: smokesignal_dev_password
11
+
POSTGRES_INITDB_ARGS: "--encoding=UTF8 --locale=C"
12
+
ports:
13
+
- "5436:5432" # Using 5433 to avoid conflicts with system PostgreSQL
14
+
volumes:
15
+
- postgres_data:/var/lib/postgresql/data
16
+
- ./init-db.sql:/docker-entrypoint-initdb.d/init-db.sql:ro
17
+
healthcheck:
18
+
test: ["CMD-SHELL", "pg_isready -U smokesignal -d smokesignal_dev"]
19
+
interval: 5s
20
+
timeout: 5s
21
+
retries: 5
22
+
restart: unless-stopped
23
+
24
+
minio:
25
+
image: minio/minio:latest
26
+
container_name: smokesignal_minio
27
+
command: server /data --console-address ":9001"
28
+
environment:
29
+
MINIO_ROOT_USER: smokesignal_minio
30
+
MINIO_ROOT_PASSWORD: smokesignal_dev_secret
31
+
ports:
32
+
- "9000:9000" # MinIO API
33
+
- "9001:9001" # MinIO Console
34
+
volumes:
35
+
- minio_data:/data
36
+
healthcheck:
37
+
test: ["CMD", "curl", "-f", "http://localhost:9000/minio/health/live"]
38
+
interval: 30s
39
+
timeout: 20s
40
+
retries: 3
41
+
restart: unless-stopped
42
+
43
+
createbuckets:
44
+
image: minio/mc
45
+
depends_on:
46
+
- minio
47
+
entrypoint: >
48
+
/bin/sh -c "
49
+
/usr/bin/mc mc alias set local http://localhost:9000 smokesignal_minio smokesignal_dev_secret;
50
+
/usr/bin/mc mb myminio/smokesignal-badges;
51
+
/usr/bin/mc policy set public myminio/smokesignal-badges;
52
+
exit 0;
53
+
"
54
+
volumes:
55
+
minio_data:
56
+
driver: local
57
+
postgres_data:
58
+
driver: local
59
+
pgadmin_data:
60
+
driver: local
61
+
62
+
networks:
63
+
default:
64
+
name: smokesignal_network
+40
migrations/20250619135328_convert_dpop_jwk_to_text.sql
+40
migrations/20250619135328_convert_dpop_jwk_to_text.sql
···
1
+
-- Convert dpop_jwk from JSON to TEXT for oauth_requests and oauth_sessions tables
2
+
-- This migration preserves existing data by converting JSON values to string format
3
+
4
+
-- First, convert oauth_requests.dpop_jwk from JSON to TEXT
5
+
ALTER TABLE oauth_requests
6
+
ADD COLUMN dpop_jwk_text TEXT;
7
+
8
+
-- Copy existing JSON data as string
9
+
UPDATE oauth_requests
10
+
SET dpop_jwk_text = dpop_jwk::text;
11
+
12
+
-- Set NOT NULL constraint after data migration
13
+
ALTER TABLE oauth_requests
14
+
ALTER COLUMN dpop_jwk_text SET NOT NULL;
15
+
16
+
-- Drop old JSON column and rename new column
17
+
ALTER TABLE oauth_requests
18
+
DROP COLUMN dpop_jwk;
19
+
20
+
ALTER TABLE oauth_requests
21
+
RENAME COLUMN dpop_jwk_text TO dpop_jwk;
22
+
23
+
-- Now do the same for oauth_sessions.dpop_jwk
24
+
ALTER TABLE oauth_sessions
25
+
ADD COLUMN dpop_jwk_text TEXT;
26
+
27
+
-- Copy existing JSON data as string
28
+
UPDATE oauth_sessions
29
+
SET dpop_jwk_text = dpop_jwk::text;
30
+
31
+
-- Set NOT NULL constraint after data migration
32
+
ALTER TABLE oauth_sessions
33
+
ALTER COLUMN dpop_jwk_text SET NOT NULL;
34
+
35
+
-- Drop old JSON column and rename new column
36
+
ALTER TABLE oauth_sessions
37
+
DROP COLUMN dpop_jwk;
38
+
39
+
ALTER TABLE oauth_sessions
40
+
RENAME COLUMN dpop_jwk_text TO dpop_jwk;
+2
migrations/20250619152115_rename_handles_to_identity_profiles.sql
+2
migrations/20250619152115_rename_handles_to_identity_profiles.sql
+14
migrations/20250620162528_did_documents_storage.sql
+14
migrations/20250620162528_did_documents_storage.sql
···
1
+
-- DID Document storage for atproto_identity::storage::DidDocumentStorage implementation
2
+
CREATE TABLE did_documents (
3
+
did varchar(512) PRIMARY KEY,
4
+
document_json JSON NOT NULL,
5
+
created_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
6
+
updated_at TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT NOW(),
7
+
expires_at TIMESTAMP WITH TIME ZONE DEFAULT NULL
8
+
);
9
+
10
+
-- Index for expiration cleanup
11
+
CREATE INDEX idx_did_documents_expires_at ON did_documents(expires_at) WHERE expires_at IS NOT NULL;
12
+
13
+
-- Index for updated_at for LRU-style cleanup
14
+
CREATE INDEX idx_did_documents_updated_at ON did_documents(updated_at);
+21
migrations/20250620162624_atproto_oauth_requests_storage.sql
+21
migrations/20250620162624_atproto_oauth_requests_storage.sql
···
1
+
-- AT Protocol OAuth Request storage for atproto_oauth::storage::OAuthRequestStorage implementation
2
+
CREATE TABLE atproto_oauth_requests (
3
+
oauth_state varchar(512) PRIMARY KEY,
4
+
issuer varchar(512) NOT NULL,
5
+
did varchar(512) NOT NULL,
6
+
nonce varchar(512) NOT NULL,
7
+
pkce_verifier varchar(512) NOT NULL,
8
+
signing_public_key varchar(512) NOT NULL,
9
+
dpop_private_key TEXT NOT NULL,
10
+
created_at TIMESTAMP WITH TIME ZONE NOT NULL,
11
+
expires_at TIMESTAMP WITH TIME ZONE NOT NULL
12
+
);
13
+
14
+
-- Index for expiration cleanup
15
+
CREATE INDEX idx_atproto_oauth_requests_expires_at ON atproto_oauth_requests(expires_at);
16
+
17
+
-- Index for created_at for cleanup of old requests
18
+
CREATE INDEX idx_atproto_oauth_requests_created_at ON atproto_oauth_requests(created_at);
19
+
20
+
-- Index for DID lookups
21
+
CREATE INDEX idx_atproto_oauth_requests_did ON atproto_oauth_requests(did);
+27
-20
src/atproto/auth.rs
+27
-20
src/atproto/auth.rs
···
1
-
use p256::SecretKey;
1
+
use anyhow::Result;
2
2
3
-
pub trait OAuthSessionProvider {
4
-
fn oauth_access_token(&self) -> String;
5
-
fn oauth_issuer(&self) -> String;
6
-
fn dpop_secret(&self) -> SecretKey;
3
+
use atproto_client::client::DPoPAuth;
4
+
use atproto_identity::key::identify_key;
5
+
6
+
use crate::storage::oauth::model::OAuthSession;
7
+
use atproto_oauth_aip::{resources::oauth_protected_resource, workflow::session_exchange};
8
+
9
+
/// Create DPoPAuth directly from OAuthSession
10
+
pub fn create_dpop_auth_from_oauth_session(oauth_session: &OAuthSession) -> Result<DPoPAuth> {
11
+
let dpop_private_key_data = identify_key(&oauth_session.dpop_jwk)?;
12
+
13
+
Ok(DPoPAuth {
14
+
dpop_private_key_data,
15
+
oauth_access_token: oauth_session.access_token.clone(),
16
+
})
7
17
}
8
18
9
-
pub struct SimpleOAuthSessionProvider {
10
-
pub access_token: String,
11
-
pub issuer: String,
12
-
pub dpop_secret: SecretKey,
13
-
}
19
+
pub async fn create_dpop_auth_from_aip_session(
20
+
http_client: &reqwest::Client,
21
+
aip_server: &str,
22
+
access_token: &str,
23
+
) -> Result<DPoPAuth> {
24
+
let protected_resource = oauth_protected_resource(http_client, aip_server).await?;
14
25
15
-
impl OAuthSessionProvider for SimpleOAuthSessionProvider {
16
-
fn oauth_access_token(&self) -> String {
17
-
self.access_token.clone()
18
-
}
26
+
let session_response = session_exchange(http_client, &protected_resource, access_token).await?;
19
27
20
-
fn oauth_issuer(&self) -> String {
21
-
self.issuer.clone()
22
-
}
28
+
let dpop_private_key_data = identify_key(&session_response.dpop_key)?;
23
29
24
-
fn dpop_secret(&self) -> SecretKey {
25
-
self.dpop_secret.clone()
26
-
}
30
+
Ok(DPoPAuth {
31
+
dpop_private_key_data,
32
+
oauth_access_token: session_response.access_token.clone(),
33
+
})
27
34
}
-377
src/atproto/client.rs
-377
src/atproto/client.rs
···
1
-
use std::time::Duration;
2
-
3
-
use anyhow::Result;
4
-
use reqwest_chain::ChainMiddleware;
5
-
use reqwest_middleware::ClientBuilder;
6
-
use serde::{de::DeserializeOwned, Deserialize, Serialize};
7
-
use tracing::Instrument;
8
-
9
-
// Standard timeout for all HTTP client operations
10
-
const HTTP_CLIENT_TIMEOUT_SECS: u64 = 8;
11
-
12
-
use crate::atproto::auth::OAuthSessionProvider;
13
-
use crate::atproto::errors::ClientError;
14
-
use crate::atproto::lexicon::com::atproto::repo::StrongRef;
15
-
use crate::atproto::xrpc::SimpleError;
16
-
use crate::http::handle_oauth_login::pkce_challenge;
17
-
use crate::http::utils::URLBuilder;
18
-
use crate::jose::jwt::{Claims, Header, JoseClaims};
19
-
use crate::jose::mint_token;
20
-
use crate::oauth::dpop::DpopRetry;
21
-
22
-
#[derive(Debug, Serialize, Deserialize, Clone)]
23
-
#[serde(bound = "T: Serialize + DeserializeOwned")]
24
-
pub struct CreateRecordRequest<T: DeserializeOwned> {
25
-
pub repo: String,
26
-
pub collection: String,
27
-
28
-
#[serde(skip_serializing_if = "Option::is_none", default, rename = "rkey")]
29
-
pub record_key: Option<String>,
30
-
31
-
pub validate: bool,
32
-
33
-
pub record: T,
34
-
35
-
#[serde(
36
-
skip_serializing_if = "Option::is_none",
37
-
default,
38
-
rename = "swapCommit"
39
-
)]
40
-
pub swap_commit: Option<String>,
41
-
}
42
-
43
-
#[derive(Debug, Serialize, Deserialize, Clone)]
44
-
#[serde(bound = "T: Serialize + DeserializeOwned")]
45
-
pub struct PutRecordRequest<T: DeserializeOwned> {
46
-
pub repo: String,
47
-
pub collection: String,
48
-
49
-
#[serde(rename = "rkey")]
50
-
pub record_key: String,
51
-
52
-
pub validate: bool,
53
-
54
-
pub record: T,
55
-
56
-
#[serde(
57
-
skip_serializing_if = "Option::is_none",
58
-
default,
59
-
rename = "swapCommit"
60
-
)]
61
-
pub swap_commit: Option<String>,
62
-
63
-
#[serde(
64
-
skip_serializing_if = "Option::is_none",
65
-
default,
66
-
rename = "swapRecord"
67
-
)]
68
-
pub swap_record: Option<String>,
69
-
}
70
-
71
-
#[derive(Debug, Serialize, Deserialize, Clone)]
72
-
#[serde(untagged)]
73
-
pub enum CreateRecordResponse {
74
-
StrongRef(StrongRef),
75
-
Error(SimpleError),
76
-
}
77
-
78
-
#[derive(Debug, Serialize, Deserialize, Clone)]
79
-
#[serde(untagged)]
80
-
pub enum PutRecordResponse {
81
-
StrongRef(StrongRef),
82
-
Error(SimpleError),
83
-
}
84
-
85
-
#[derive(Debug, Serialize, Deserialize, Clone)]
86
-
pub struct ListRecordsParams {
87
-
pub repo: String,
88
-
pub collection: String,
89
-
#[serde(skip_serializing_if = "Option::is_none")]
90
-
pub limit: Option<u32>,
91
-
#[serde(skip_serializing_if = "Option::is_none")]
92
-
pub cursor: Option<String>,
93
-
#[serde(skip_serializing_if = "Option::is_none")]
94
-
pub reverse: Option<bool>,
95
-
}
96
-
97
-
#[derive(Debug, Serialize, Deserialize, Clone)]
98
-
pub struct ListRecord<T> {
99
-
pub uri: String,
100
-
pub cid: String,
101
-
pub value: T,
102
-
}
103
-
104
-
#[derive(Debug, Serialize, Deserialize, Clone)]
105
-
pub struct ListRecordsResponse<T> {
106
-
pub cursor: Option<String>,
107
-
pub records: Vec<ListRecord<T>>,
108
-
}
109
-
110
-
pub struct OAuthPdsClient<'a> {
111
-
pub http_client: &'a reqwest::Client,
112
-
pub pds: &'a str,
113
-
}
114
-
115
-
impl OAuthPdsClient<'_> {
116
-
pub async fn create_record<T: DeserializeOwned + Serialize>(
117
-
&self,
118
-
oauth_session: &impl OAuthSessionProvider,
119
-
record: CreateRecordRequest<T>,
120
-
) -> Result<StrongRef, anyhow::Error> {
121
-
let mut url_builder = URLBuilder::new(self.pds);
122
-
url_builder.path("/xrpc/com.atproto.repo.createRecord");
123
-
let url = url_builder.build();
124
-
125
-
let dpop_secret_key = oauth_session.dpop_secret();
126
-
let dpop_public_key = dpop_secret_key.public_key();
127
-
let oauth_issuer = oauth_session.oauth_issuer();
128
-
let oauth_access_token = oauth_session.oauth_access_token();
129
-
130
-
let now = chrono::Utc::now();
131
-
132
-
let dpop_proof_header = Header {
133
-
type_: Some("dpop+jwt".to_string()),
134
-
algorithm: Some("ES256".to_string()),
135
-
json_web_key: Some(dpop_public_key.to_jwk()),
136
-
..Default::default()
137
-
};
138
-
139
-
let dpop_proof_claim = Claims::new(JoseClaims {
140
-
issuer: Some(oauth_issuer.clone()),
141
-
issued_at: Some(now.timestamp() as u64),
142
-
expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64),
143
-
json_web_token_id: Some(ulid::Ulid::new().to_string()),
144
-
http_method: Some("POST".to_string()),
145
-
http_uri: Some(url.clone()),
146
-
auth: Some(pkce_challenge(&oauth_access_token)),
147
-
148
-
..Default::default()
149
-
});
150
-
let dpop_proof_token = mint_token(&dpop_secret_key, &dpop_proof_header, &dpop_proof_claim)?;
151
-
152
-
let dpop_retry = DpopRetry::new(
153
-
dpop_proof_header.clone(),
154
-
dpop_proof_claim.clone(),
155
-
dpop_secret_key.clone(),
156
-
);
157
-
158
-
let dpop_retry_client = ClientBuilder::new(self.http_client.clone())
159
-
.with(ChainMiddleware::new(dpop_retry.clone()))
160
-
.build();
161
-
162
-
let http_response = dpop_retry_client
163
-
.post(url)
164
-
.header("Authorization", &format!("DPoP {}", oauth_access_token))
165
-
.header("DPoP", dpop_proof_token.as_str())
166
-
.json(&record)
167
-
.timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
168
-
.send()
169
-
.instrument(tracing::info_span!("create_record"))
170
-
.await?;
171
-
172
-
tracing::info!(
173
-
"create_record response status: {:?}",
174
-
http_response.status()
175
-
);
176
-
177
-
let create_record_respoonse = http_response.json::<CreateRecordResponse>().await;
178
-
179
-
match create_record_respoonse {
180
-
Ok(CreateRecordResponse::StrongRef(strong_ref)) => Ok(strong_ref),
181
-
Ok(CreateRecordResponse::Error(err)) => {
182
-
Err(ClientError::ServerError(err.error_message()).into())
183
-
}
184
-
Err(err) => Err(ClientError::CreateRecordResponseFailure(err).into()),
185
-
}
186
-
}
187
-
188
-
pub async fn put_record<T: DeserializeOwned + Serialize>(
189
-
&self,
190
-
oauth_session: &impl OAuthSessionProvider,
191
-
record: PutRecordRequest<T>,
192
-
) -> Result<StrongRef, anyhow::Error> {
193
-
let mut url_builder = URLBuilder::new(self.pds);
194
-
url_builder.path("/xrpc/com.atproto.repo.putRecord");
195
-
let url = url_builder.build();
196
-
197
-
let dpop_secret_key = oauth_session.dpop_secret();
198
-
let dpop_public_key = dpop_secret_key.public_key();
199
-
let oauth_issuer = oauth_session.oauth_issuer();
200
-
let oauth_access_token = oauth_session.oauth_access_token();
201
-
202
-
let now = chrono::Utc::now();
203
-
204
-
let dpop_proof_header = Header {
205
-
type_: Some("dpop+jwt".to_string()),
206
-
algorithm: Some("ES256".to_string()),
207
-
json_web_key: Some(dpop_public_key.to_jwk()),
208
-
..Default::default()
209
-
};
210
-
211
-
let dpop_proof_claim = Claims::new(JoseClaims {
212
-
issuer: Some(oauth_issuer.clone()),
213
-
issued_at: Some(now.timestamp() as u64),
214
-
expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64),
215
-
json_web_token_id: Some(ulid::Ulid::new().to_string()),
216
-
http_method: Some("POST".to_string()),
217
-
http_uri: Some(url.clone()),
218
-
auth: Some(pkce_challenge(&oauth_access_token)),
219
-
220
-
..Default::default()
221
-
});
222
-
let dpop_proof_token = mint_token(&dpop_secret_key, &dpop_proof_header, &dpop_proof_claim)?;
223
-
224
-
let dpop_retry = DpopRetry::new(
225
-
dpop_proof_header.clone(),
226
-
dpop_proof_claim.clone(),
227
-
dpop_secret_key.clone(),
228
-
);
229
-
230
-
let dpop_retry_client = ClientBuilder::new(self.http_client.clone())
231
-
.with(ChainMiddleware::new(dpop_retry.clone()))
232
-
.build();
233
-
234
-
let http_response = dpop_retry_client
235
-
.post(url)
236
-
.header("Authorization", &format!("DPoP {}", oauth_access_token))
237
-
.header("DPoP", dpop_proof_token.as_str())
238
-
.json(&record)
239
-
.timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
240
-
.send()
241
-
.instrument(tracing::info_span!("put_record"))
242
-
.await?;
243
-
244
-
tracing::info!("put_record response status: {:?}", http_response.status());
245
-
246
-
let put_record_respoonse = http_response.json::<PutRecordResponse>().await;
247
-
248
-
match put_record_respoonse {
249
-
Ok(PutRecordResponse::StrongRef(strong_ref)) => Ok(strong_ref),
250
-
Ok(PutRecordResponse::Error(err)) => {
251
-
Err(ClientError::ServerError(err.error_message()).into())
252
-
}
253
-
Err(err) => Err(ClientError::PutRecordResponseFailure(err).into()),
254
-
}
255
-
}
256
-
257
-
pub async fn list_records<T: DeserializeOwned>(
258
-
&self,
259
-
oauth_session: &impl OAuthSessionProvider,
260
-
params: &ListRecordsParams,
261
-
) -> Result<ListRecordsResponse<T>, anyhow::Error> {
262
-
let mut url_builder = URLBuilder::new(self.pds);
263
-
url_builder.path("/xrpc/com.atproto.repo.listRecords");
264
-
265
-
// Add query parameters
266
-
url_builder.param("repo", ¶ms.repo);
267
-
url_builder.param("collection", ¶ms.collection);
268
-
269
-
if let Some(limit) = params.limit {
270
-
url_builder.param("limit", &limit.to_string());
271
-
}
272
-
273
-
if let Some(cursor) = ¶ms.cursor {
274
-
url_builder.param("cursor", cursor);
275
-
}
276
-
277
-
if let Some(reverse) = params.reverse {
278
-
url_builder.param("reverse", &reverse.to_string());
279
-
}
280
-
281
-
let url = url_builder.build();
282
-
283
-
let dpop_secret_key = oauth_session.dpop_secret();
284
-
let dpop_public_key = dpop_secret_key.public_key();
285
-
let oauth_issuer = oauth_session.oauth_issuer();
286
-
let oauth_access_token = oauth_session.oauth_access_token();
287
-
288
-
let now = chrono::Utc::now();
289
-
290
-
let dpop_proof_header = Header {
291
-
type_: Some("dpop+jwt".to_string()),
292
-
algorithm: Some("ES256".to_string()),
293
-
json_web_key: Some(dpop_public_key.to_jwk()),
294
-
..Default::default()
295
-
};
296
-
297
-
let dpop_proof_claim = Claims::new(JoseClaims {
298
-
issuer: Some(oauth_issuer.clone()),
299
-
issued_at: Some(now.timestamp() as u64),
300
-
expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64),
301
-
json_web_token_id: Some(ulid::Ulid::new().to_string()),
302
-
http_method: Some("GET".to_string()),
303
-
http_uri: Some(url.clone()),
304
-
auth: Some(pkce_challenge(&oauth_access_token)),
305
-
306
-
..Default::default()
307
-
});
308
-
let dpop_proof_token = mint_token(&dpop_secret_key, &dpop_proof_header, &dpop_proof_claim)?;
309
-
310
-
let dpop_retry = DpopRetry::new(
311
-
dpop_proof_header.clone(),
312
-
dpop_proof_claim.clone(),
313
-
dpop_secret_key.clone(),
314
-
);
315
-
316
-
let dpop_retry_client = ClientBuilder::new(self.http_client.clone())
317
-
.with(ChainMiddleware::new(dpop_retry.clone()))
318
-
.build();
319
-
320
-
let http_response = dpop_retry_client
321
-
.get(url)
322
-
.header("Authorization", &format!("DPoP {}", oauth_access_token))
323
-
.header("DPoP", dpop_proof_token.as_str())
324
-
.timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
325
-
.send()
326
-
.instrument(tracing::span!(tracing::Level::INFO, "list_records"))
327
-
.await?;
328
-
329
-
let result = http_response.json::<ListRecordsResponse<T>>().await?;
330
-
331
-
Ok(result)
332
-
}
333
-
}
334
-
335
-
#[cfg(test)]
336
-
mod tests {
337
-
use std::collections::HashMap;
338
-
339
-
use crate::atproto::lexicon::community::lexicon::calendar::event::Event;
340
-
341
-
use super::*;
342
-
use anyhow::Result;
343
-
344
-
#[test]
345
-
fn location_record() -> Result<()> {
346
-
let test_json = r#"{"repo":"nick","collection":"stuff","validate":false,"record":{"$type":"community.lexicon.calendar.event","name":"My awesome event","description":"A really cool event.","createdAt":"2024-08-04T09:45:00.000Z"}}"#;
347
-
348
-
{
349
-
// Serialize bare
350
-
assert_eq!(
351
-
serde_json::to_string(&CreateRecordRequest {
352
-
repo: "nick".to_string(),
353
-
collection: "stuff".to_string(),
354
-
validate: false,
355
-
record_key: None,
356
-
record: Event::Current {
357
-
name: "My awesome event".to_string(),
358
-
description: "A really cool event.".to_string(),
359
-
created_at: "2024-08-04T09:45:00.000Z".parse().unwrap(),
360
-
starts_at: None,
361
-
ends_at: None,
362
-
mode: None,
363
-
status: None,
364
-
locations: vec![],
365
-
uris: vec![],
366
-
extra: HashMap::default(),
367
-
},
368
-
swap_commit: None,
369
-
})
370
-
.unwrap(),
371
-
test_json
372
-
);
373
-
}
374
-
375
-
Ok(())
376
-
}
377
-
}
-53
src/atproto/datetime.rs
-53
src/atproto/datetime.rs
···
1
-
pub mod format {
2
-
use chrono::{DateTime, SecondsFormat, Utc};
3
-
use serde::{self, Deserialize, Deserializer, Serializer};
4
-
5
-
pub fn serialize<S>(date: &DateTime<Utc>, serializer: S) -> Result<S::Ok, S::Error>
6
-
where
7
-
S: Serializer,
8
-
{
9
-
let s = date.to_rfc3339_opts(SecondsFormat::Millis, true);
10
-
serializer.serialize_str(&s)
11
-
}
12
-
13
-
pub fn deserialize<'de, D>(deserializer: D) -> Result<DateTime<Utc>, D::Error>
14
-
where
15
-
D: Deserializer<'de>,
16
-
{
17
-
let date_value = String::deserialize(deserializer)?;
18
-
DateTime::parse_from_rfc3339(&date_value)
19
-
.map(|v| v.with_timezone(&Utc))
20
-
.map_err(serde::de::Error::custom)
21
-
}
22
-
}
23
-
24
-
pub mod optional_format {
25
-
use chrono::{DateTime, SecondsFormat, Utc};
26
-
use serde::{self, Deserialize, Deserializer, Serializer};
27
-
28
-
pub fn serialize<S>(date: &Option<DateTime<Utc>>, serializer: S) -> Result<S::Ok, S::Error>
29
-
where
30
-
S: Serializer,
31
-
{
32
-
if date.is_none() {
33
-
return serializer.serialize_none();
34
-
}
35
-
let s = date.unwrap().to_rfc3339_opts(SecondsFormat::Millis, true);
36
-
serializer.serialize_str(&s)
37
-
}
38
-
39
-
pub fn deserialize<'de, D>(deserializer: D) -> Result<Option<DateTime<Utc>>, D::Error>
40
-
where
41
-
D: Deserializer<'de>,
42
-
{
43
-
let maybe_date_value: Option<String> = Option::deserialize(deserializer)?;
44
-
if maybe_date_value.is_none() {
45
-
return Ok(None);
46
-
}
47
-
let date_value = maybe_date_value.unwrap();
48
-
DateTime::parse_from_rfc3339(&date_value)
49
-
.map(|v| v.with_timezone(&Utc))
50
-
.map_err(serde::de::Error::custom)
51
-
.map(Some)
52
-
}
53
-
}
-52
src/atproto/errors.rs
-52
src/atproto/errors.rs
···
1
-
use thiserror::Error;
2
-
3
-
#[derive(Debug, Error)]
4
-
pub enum ClientError {
5
-
#[error("error-xrpc-client-1 Malformed PutRecord response: {0:?}")]
6
-
PutRecordResponseFailure(reqwest::Error),
7
-
8
-
#[error("error-xrpc-client-2 Malformed CreateRecord response: {0:?}")]
9
-
CreateRecordResponseFailure(reqwest::Error),
10
-
11
-
#[error("error-xrpc-client-3 XRPC error from server: {0}")]
12
-
ServerError(String),
13
-
14
-
#[error("error-xrpc-client-4 Invalid record format: {0}")]
15
-
InvalidRecordFormat(String),
16
-
}
17
-
18
-
#[derive(Debug, Error)]
19
-
pub enum UriError {
20
-
#[error("error-uri-1 Invalid AT-URI: repository missing")]
21
-
RepositoryMissing,
22
-
23
-
#[error("error-uri-2 Invalid AT-URI: collection missing")]
24
-
CollectionMissing,
25
-
26
-
#[error("error-uri-3 Invalid AT-URI: rkey missing")]
27
-
RkeyMissing,
28
-
29
-
#[error("error-uri-4 Invalid AT-URI")]
30
-
InvalidFormat,
31
-
32
-
#[error("error-uri-5 Invalid AT-URI: repository contains invalid characters")]
33
-
InvalidRepository,
34
-
35
-
#[error("error-uri-6 Invalid AT-URI: collection contains invalid characters")]
36
-
InvalidCollection,
37
-
38
-
#[error("error-uri-7 Invalid AT-URI: rkey contains invalid characters")]
39
-
InvalidRkey,
40
-
41
-
#[error("error-uri-8 Invalid AT-URI: path traversal attempt detected")]
42
-
PathTraversalAttempt,
43
-
44
-
#[error("error-uri-9 Invalid AT-URI: repository too long (max 253 chars)")]
45
-
RepositoryTooLong,
46
-
47
-
#[error("error-uri-10 Invalid AT-URI: collection too long (max 128 chars)")]
48
-
CollectionTooLong,
49
-
50
-
#[error("error-uri-11 Invalid AT-URI: rkey too long (max 512 chars)")]
51
-
RkeyTooLong,
52
-
}
+2
-3
src/atproto/lexicon/community_lexicon_calendar_event.rs
+2
-3
src/atproto/lexicon/community_lexicon_calendar_event.rs
···
3
3
use std::collections::HashMap;
4
4
5
5
use crate::atproto::lexicon::community::lexicon::location::{Address, Fsq, Geo, Hthree};
6
-
use crate::atproto::{
7
-
datetime::format as datetime_format, datetime::optional_format as optional_datetime_format,
8
-
};
6
+
use atproto_record::datetime::format as datetime_format;
7
+
use atproto_record::datetime::optional_format as optional_datetime_format;
9
8
10
9
pub const NSID: &str = "community.lexicon.calendar.event";
11
10
+1
-1
src/atproto/lexicon/community_lexicon_calendar_rsvp.rs
+1
-1
src/atproto/lexicon/community_lexicon_calendar_rsvp.rs
···
1
1
use chrono::{DateTime, Utc};
2
2
use serde::{Deserialize, Serialize};
3
3
4
-
use crate::atproto::datetime::format as datetime_format;
5
4
use crate::atproto::lexicon::com::atproto::repo::StrongRef;
5
+
use atproto_record::datetime::format as datetime_format;
6
6
7
7
pub const NSID: &str = "community.lexicon.calendar.rsvp";
8
8
+1
-1
src/atproto/lexicon/events_smokesignal_calendar_rsvp.rs
+1
-1
src/atproto/lexicon/events_smokesignal_calendar_rsvp.rs
···
1
1
use chrono::{DateTime, Utc};
2
2
use serde::{Deserialize, Serialize};
3
3
4
-
use crate::atproto::datetime::optional_format as optional_datetime_format;
5
4
use crate::atproto::lexicon::com::atproto::repo::StrongRef;
5
+
use atproto_record::datetime::optional_format as optional_datetime_format;
6
6
7
7
pub const NSID: &str = "events.smokesignal.calendar.rsvp";
8
8
-5
src/atproto/mod.rs
-5
src/atproto/mod.rs
-160
src/atproto/uri.rs
-160
src/atproto/uri.rs
···
1
-
use anyhow::Result;
2
-
3
-
use crate::{atproto::errors::UriError, validation::is_valid_hostname};
4
-
5
-
// Constants for maximum lengths
6
-
const MAX_REPOSITORY_LENGTH: usize = 253; // DNS name length limit
7
-
const MAX_COLLECTION_LENGTH: usize = 128;
8
-
const MAX_RKEY_LENGTH: usize = 512;
9
-
10
-
/// Validates a repository name for AT Protocol URIs
11
-
///
12
-
/// Repository names should generally follow host name rules:
13
-
/// - Alphanumeric characters, hyphens, and periods
14
-
/// - No consecutive periods
15
-
/// - Cannot start or end with period or hyphen
16
-
fn is_valid_repository(repository: &str) -> bool {
17
-
if repository.is_empty() || repository.len() > MAX_REPOSITORY_LENGTH {
18
-
return false;
19
-
}
20
-
21
-
// Check for invalid characters
22
-
if !repository
23
-
.chars()
24
-
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '.' || c == ':')
25
-
{
26
-
return false;
27
-
}
28
-
29
-
// TODO: If starts with "did:plc:" then validate encoded string and length
30
-
if repository.starts_with("did:plc:") {
31
-
return true;
32
-
}
33
-
34
-
// TODO: If starts with "did:web:" then validate hostname and parts
35
-
if repository.starts_with("did:web:") {
36
-
return true;
37
-
}
38
-
39
-
is_valid_hostname(repository)
40
-
}
41
-
42
-
/// Validates a collection name for AT Protocol URIs
43
-
///
44
-
/// Collections should follow namespace-like naming:
45
-
/// - Alphanumeric characters, hyphens, underscores, and periods
46
-
/// - No path traversal sequences
47
-
fn is_valid_collection(collection: &str) -> bool {
48
-
if collection.is_empty() || collection.len() > MAX_COLLECTION_LENGTH {
49
-
return false;
50
-
}
51
-
52
-
// Check for invalid characters
53
-
if !collection
54
-
.chars()
55
-
.all(|c| c.is_ascii_alphanumeric() || c == '-' || c == '_' || c == '.')
56
-
{
57
-
return false;
58
-
}
59
-
60
-
// Check for path traversal attempts
61
-
if collection.contains("../") || collection == ".." {
62
-
return false;
63
-
}
64
-
65
-
true
66
-
}
67
-
68
-
/// Validates a record key (rkey) for AT Protocol URIs
69
-
///
70
-
/// Record keys have more flexible rules but shouldn't contain characters
71
-
/// that could cause problems in URLs or filesystems
72
-
fn is_valid_rkey(rkey: &str) -> bool {
73
-
if rkey.is_empty() || rkey.len() > MAX_RKEY_LENGTH {
74
-
return false;
75
-
}
76
-
77
-
// Block characters that are particularly problematic
78
-
if rkey.chars().any(|c| {
79
-
c == '<'
80
-
|| c == '>'
81
-
|| c == '"'
82
-
|| c == '\''
83
-
|| c == '`'
84
-
|| c == '\\'
85
-
|| c == '|'
86
-
|| c == '*'
87
-
|| c == '?'
88
-
|| c == '#'
89
-
}) {
90
-
return false;
91
-
}
92
-
93
-
// Check for path traversal attempts
94
-
if rkey.contains("../") || rkey == ".." {
95
-
return false;
96
-
}
97
-
98
-
true
99
-
}
100
-
101
-
/// Parses and validates an AT Protocol URI string into its components
102
-
///
103
-
/// AT Protocol URIs follow the format: at://repository/collection/rkey
104
-
/// This function validates each component for proper format and security
105
-
pub fn parse_aturi(uri: &str) -> Result<(String, String, String)> {
106
-
// Validate URI has the correct prefix
107
-
if !uri.starts_with("at://") {
108
-
return Err(UriError::InvalidFormat.into());
109
-
}
110
-
111
-
let value = uri.strip_prefix("at://").unwrap(); // Safe because we checked above
112
-
113
-
// Split the URI into components
114
-
let mut components = value.split('/');
115
-
116
-
// Extract repository
117
-
let repository = components.next().ok_or(UriError::RepositoryMissing)?;
118
-
119
-
// Validate repository
120
-
if repository.len() > MAX_REPOSITORY_LENGTH {
121
-
return Err(UriError::RepositoryTooLong.into());
122
-
}
123
-
if !is_valid_repository(repository) {
124
-
return Err(UriError::InvalidRepository.into());
125
-
}
126
-
127
-
// Extract collection
128
-
let collection = components.next().ok_or(UriError::CollectionMissing)?;
129
-
130
-
// Validate collection
131
-
if collection.len() > MAX_COLLECTION_LENGTH {
132
-
return Err(UriError::CollectionTooLong.into());
133
-
}
134
-
if !is_valid_collection(collection) {
135
-
return Err(UriError::InvalidCollection.into());
136
-
}
137
-
138
-
// Extract record key
139
-
let rkey = components.next().ok_or(UriError::RkeyMissing)?;
140
-
141
-
// Validate record key
142
-
if rkey.len() > MAX_RKEY_LENGTH {
143
-
return Err(UriError::RkeyTooLong.into());
144
-
}
145
-
if !is_valid_rkey(rkey) {
146
-
return Err(UriError::InvalidRkey.into());
147
-
}
148
-
149
-
// Check for any path traversal attempts
150
-
if repository.contains("..") || collection.contains("..") || rkey.contains("..") {
151
-
return Err(UriError::PathTraversalAttempt.into());
152
-
}
153
-
154
-
// Return validated components
155
-
Ok((
156
-
repository.to_string(),
157
-
collection.to_string(),
158
-
rkey.to_string(),
159
-
))
160
-
}
-18
src/atproto/xrpc.rs
-18
src/atproto/xrpc.rs
···
1
-
use serde::{Deserialize, Serialize};
2
-
3
-
#[derive(Debug, Clone, Serialize, Deserialize)]
4
-
pub struct SimpleError {
5
-
pub error: Option<String>,
6
-
pub error_description: Option<String>,
7
-
pub message: Option<String>,
8
-
}
9
-
10
-
impl SimpleError {
11
-
pub fn error_message(&self) -> String {
12
-
[&self.error, &self.error_description, &self.message]
13
-
.iter()
14
-
.filter_map(|v| (*v).clone())
15
-
.collect::<Vec<String>>()
16
-
.join(": ")
17
-
}
18
-
}
+5
-16
src/bin/crypto.rs
+5
-16
src/bin/crypto.rs
···
3
3
use base64::{engine::general_purpose, Engine as _};
4
4
use rand::RngCore;
5
5
6
-
use smokesignal::jose::jwk;
7
-
8
6
fn main() {
9
7
let mut rng = rand::thread_rng();
10
8
11
-
env::args().for_each(|arg| match arg.as_str() {
12
-
"key" => {
13
-
let mut key: [u8; 64] = [0; 64];
14
-
rng.fill_bytes(&mut key);
15
-
let encoded: String = general_purpose::STANDARD_NO_PAD.encode(key);
16
-
println!("{encoded}");
17
-
}
18
-
"jwk" => {
19
-
let ec_jwk = jwk::generate();
20
-
let serialized_value =
21
-
serde_json::to_string_pretty(&ec_jwk).expect("failed to serialize ec jwk");
22
-
println!("{serialized_value}");
23
-
}
24
-
_ => {}
9
+
env::args().for_each(|arg| if arg.as_str() == "key" {
10
+
let mut key: [u8; 64] = [0; 64];
11
+
rng.fill_bytes(&mut key);
12
+
let encoded: String = general_purpose::STANDARD_NO_PAD.encode(key);
13
+
println!("{encoded}");
25
14
});
26
15
}
-42
src/bin/resolve.rs
-42
src/bin/resolve.rs
···
1
-
use std::env;
2
-
3
-
use anyhow::Result;
4
-
use smokesignal::config::{default_env, optional_env, version, CertificateBundles, DnsNameservers};
5
-
use tracing_subscriber::{layer::SubscriberExt as _, util::SubscriberInitExt as _};
6
-
7
-
#[tokio::main]
8
-
async fn main() -> Result<()> {
9
-
tracing_subscriber::registry()
10
-
.with(tracing_subscriber::EnvFilter::new(
11
-
std::env::var("RUST_LOG").unwrap_or_else(|_| "trace".into()),
12
-
))
13
-
.with(tracing_subscriber::fmt::layer().pretty())
14
-
.init();
15
-
16
-
let certificate_bundles: CertificateBundles = optional_env("CERTIFICATE_BUNDLES").try_into()?;
17
-
let default_user_agent = format!("smokesignal ({}; +https://smokesignal.events/)", version()?);
18
-
let user_agent = default_env("USER_AGENT", &default_user_agent);
19
-
let dns_nameservers: DnsNameservers = optional_env("DNS_NAMESERVERS").try_into()?;
20
-
21
-
let mut client_builder = reqwest::Client::builder();
22
-
for ca_certificate in certificate_bundles.as_ref() {
23
-
tracing::info!("Loading CA certificate: {:?}", ca_certificate);
24
-
let cert = std::fs::read(ca_certificate)?;
25
-
let cert = reqwest::Certificate::from_pem(&cert)?;
26
-
client_builder = client_builder.add_root_certificate(cert);
27
-
}
28
-
29
-
client_builder = client_builder.user_agent(user_agent);
30
-
let http_client = client_builder.build()?;
31
-
32
-
// Initialize the DNS resolver with configuration from the app config
33
-
let dns_resolver = smokesignal::resolve::create_resolver(dns_nameservers);
34
-
35
-
for subject in env::args() {
36
-
let resolved_did =
37
-
smokesignal::resolve::resolve_subject(&http_client, &dns_resolver, &subject).await;
38
-
tracing::info!(?resolved_did, ?subject, "resolved subject");
39
-
}
40
-
41
-
Ok(())
42
-
}
+136
-13
src/bin/smokesignal.rs
+136
-13
src/bin/smokesignal.rs
···
1
1
use anyhow::Result;
2
-
use chrono::Duration;
2
+
use atproto_identity::key::identify_key;
3
+
use atproto_identity::resolve::{create_resolver, IdentityResolver, InnerIdentityResolver};
4
+
use atproto_oauth_axum::state::OAuthClientConfig;
3
5
use smokesignal::{
4
6
http::{
5
-
context::{AppEngine, I18nContext, WebContext},
7
+
context::{AppEngine, WebContext},
6
8
server::build_router,
7
9
},
8
10
i18n::Locales,
9
-
resolve::create_resolver,
10
-
storage::cache::create_cache_pool,
11
-
task_refresh_tokens::{RefreshTokensTask, RefreshTokensTaskConfig},
11
+
key_provider::SimpleKeyProvider,
12
+
storage::{
13
+
atproto::{PostgresDidDocumentStorage, PostgresOAuthRequestStorage},
14
+
cache::create_cache_pool,
15
+
},
12
16
};
17
+
18
+
use chrono::Duration;
19
+
use smokesignal::config::OAuthBackendConfig;
20
+
use smokesignal::task_identity_refresh::{IdentityRefreshTask, IdentityRefreshTaskConfig};
21
+
use smokesignal::task_oauth_requests_cleanup::{
22
+
OAuthRequestsCleanupTask, OAuthRequestsCleanupTaskConfig,
23
+
};
24
+
use smokesignal::task_refresh_tokens::{RefreshTokensTask, RefreshTokensTaskConfig};
25
+
26
+
use axum::extract::FromRef;
13
27
use sqlx::PgPool;
14
-
use std::{env, str::FromStr};
28
+
use std::{collections::HashMap, env, str::FromStr, sync::Arc};
15
29
use tokio::net::TcpListener;
16
30
use tokio::signal;
17
31
use tokio_util::{sync::CancellationToken, task::TaskTracker};
···
80
94
let jinja = reload_env::build_env(&config.external_base, &config.version);
81
95
82
96
// Initialize the DNS resolver with configuration from the app config
83
-
let dns_resolver = create_resolver(config.dns_nameservers.clone());
97
+
let dns_resolver = create_resolver(config.dns_nameservers.as_ref());
98
+
99
+
// Initialize OAuth and identity resolution components
100
+
let oauth_storage = PostgresOAuthRequestStorage::new_arc(pool.clone());
101
+
let document_storage = PostgresDidDocumentStorage::new_arc(pool.clone());
102
+
103
+
// Create a key provider populated with signing keys from OAuth backend config
104
+
let key_provider_keys =
105
+
if let OAuthBackendConfig::ATProtocol { signing_keys } = &config.oauth_backend {
106
+
signing_keys
107
+
.as_ref()
108
+
.iter()
109
+
.filter_map(|(key_id, private_key)| match identify_key(private_key) {
110
+
Ok(key_data) => {
111
+
tracing::info!(?key_id, "loaded signing key for key provider");
112
+
Some((key_id.clone(), key_data))
113
+
}
114
+
Err(err) => {
115
+
tracing::error!(?err, ?key_id, "failed to identify key for key provider");
116
+
None
117
+
}
118
+
})
119
+
.collect()
120
+
} else {
121
+
HashMap::new() // Empty for AIP backend
122
+
};
123
+
let key_provider = Arc::new(SimpleKeyProvider::new(key_provider_keys));
124
+
125
+
// Create OAuth client config (only for AT Protocol backend)
126
+
let oauth_client_config = OAuthClientConfig {
127
+
client_id: config.external_base.clone(),
128
+
jwks_uri: None,
129
+
signing_keys: if let OAuthBackendConfig::ATProtocol { signing_keys } = &config.oauth_backend
130
+
{
131
+
signing_keys
132
+
.as_ref()
133
+
.iter()
134
+
.filter_map(|value| {
135
+
tracing::info!(?value, "signing key");
136
+
137
+
identify_key(value.1).ok()
138
+
})
139
+
.collect()
140
+
} else {
141
+
Vec::new() // Empty for AIP backend
142
+
},
143
+
redirect_uris: format!("{}/oauth/callback", config.external_base),
144
+
client_name: Some("Smoke Signal".to_string()),
145
+
client_uri: Some(config.external_base.clone()),
146
+
logo_uri: Some(format!(
147
+
"https://{}/logo-160x160x.png",
148
+
config.external_base
149
+
)),
150
+
tos_uri: Some("https://docs.smokesignal.events/docs/about/terms/".to_string()),
151
+
policy_uri: Some("https://docs.smokesignal.events/docs/about/privacy/".to_string()),
152
+
scope: Some("atproto transition:generic transition:email".to_string()),
153
+
};
154
+
155
+
// Create identity resolver using the provided method
156
+
let identity_resolver = IdentityResolver(Arc::new(InnerIdentityResolver {
157
+
dns_resolver,
158
+
http_client: http_client.clone(),
159
+
plc_hostname: config.plc_hostname.clone(),
160
+
}));
84
161
85
162
let web_context = WebContext::new(
86
163
pool.clone(),
···
88
165
AppEngine::from(jinja),
89
166
&http_client,
90
167
config.clone(),
91
-
I18nContext::new(supported_languages, locales),
92
-
dns_resolver,
168
+
oauth_client_config,
169
+
identity_resolver,
170
+
key_provider,
171
+
oauth_storage,
172
+
document_storage.clone(),
173
+
supported_languages,
174
+
locales,
93
175
);
94
176
95
177
let app = build_router(web_context.clone());
···
126
208
});
127
209
}
128
210
129
-
{
211
+
// Only spawn refresh tokens task for AT Protocol OAuth backend
212
+
if let OAuthBackendConfig::ATProtocol { signing_keys } = &config.oauth_backend {
130
213
let task_config = RefreshTokensTaskConfig {
131
214
sleep_interval: Duration::seconds(10),
132
215
worker_id: "dev".to_string(),
133
216
external_url_base: config.external_base.clone(),
134
-
signing_keys: config.signing_keys.clone(),
135
-
oauth_active_keys: config.oauth_active_keys.clone(),
217
+
signing_keys: signing_keys.clone(),
136
218
};
137
219
let task = RefreshTokensTask::new(
138
220
task_config,
139
221
http_client.clone(),
140
222
pool.clone(),
141
223
cache_pool.clone(),
224
+
document_storage.clone(),
142
225
token.clone(),
143
226
);
144
227
145
228
let inner_token = token.clone();
146
229
tracker.spawn(async move {
147
230
if let Err(err) = task.run().await {
148
-
tracing::error!("Database task failed: {}", err);
231
+
tracing::error!("Refresh tokens task failed: {}", err);
232
+
}
233
+
inner_token.cancel();
234
+
});
235
+
}
236
+
237
+
// Spawn OAuth requests cleanup task if enabled
238
+
if config.enable_task_oauth_requests_cleanup {
239
+
let cleanup_task_config = OAuthRequestsCleanupTaskConfig {
240
+
sleep_interval: Duration::hours(1), // Run once per hour
241
+
};
242
+
let cleanup_task =
243
+
OAuthRequestsCleanupTask::new(cleanup_task_config, pool.clone(), token.clone());
244
+
245
+
let inner_token = token.clone();
246
+
tracker.spawn(async move {
247
+
if let Err(err) = cleanup_task.run().await {
248
+
tracing::error!("OAuth requests cleanup task failed: {}", err);
249
+
}
250
+
inner_token.cancel();
251
+
});
252
+
}
253
+
254
+
// Spawn identity refresh task if enabled
255
+
if config.enable_task_identity_refresh {
256
+
let identity_refresh_config = IdentityRefreshTaskConfig {
257
+
sleep_interval: Duration::hours(1), // Run once per hour
258
+
worker_id: "dev".to_string(),
259
+
};
260
+
let identity_refresh_task = IdentityRefreshTask::new(
261
+
identity_refresh_config,
262
+
pool.clone(),
263
+
document_storage.clone(),
264
+
atproto_identity::resolve::IdentityResolver::from_ref(&web_context),
265
+
token.clone(),
266
+
);
267
+
268
+
let inner_token = token.clone();
269
+
tracker.spawn(async move {
270
+
if let Err(err) = identity_refresh_task.run().await {
271
+
tracing::error!("Identity refresh task failed: {}", err);
149
272
}
150
273
inner_token.cancel();
151
274
});
+111
-121
src/config.rs
+111
-121
src/config.rs
···
1
-
use anyhow::Result;
1
+
use anyhow::{Context, Result};
2
+
use atproto_identity::key::{identify_key, to_public, KeyData, KeyType};
2
3
use axum_extra::extract::cookie::Key;
3
4
use base64::{engine::general_purpose, Engine as _};
4
5
use ordermap::OrderMap;
5
-
use p256::SecretKey;
6
-
use rand::seq::SliceRandom;
7
6
8
7
use crate::config_errors::ConfigError;
9
-
use crate::encoding_errors::EncodingError;
10
-
use crate::jose::jwk::WrappedJsonWebKeySet;
11
8
12
9
#[derive(Clone)]
13
10
pub struct HttpPort(u16);
···
18
15
#[derive(Clone)]
19
16
pub struct CertificateBundles(Vec<String>);
20
17
21
-
#[derive(Clone)]
22
-
pub struct SigningKeys(OrderMap<String, SecretKey>);
23
-
24
-
#[derive(Clone)]
25
-
pub struct OAuthActiveKeys(Vec<String>);
18
+
#[derive(Clone, Debug)]
19
+
pub struct SigningKeys(OrderMap<String, String>);
26
20
27
21
#[derive(Clone)]
28
22
pub struct AdminDIDs(Vec<String>);
···
30
24
#[derive(Clone)]
31
25
pub struct DnsNameservers(Vec<std::net::IpAddr>);
32
26
27
+
#[derive(Clone, Debug)]
28
+
pub enum OAuthBackendConfig {
29
+
ATProtocol {
30
+
signing_keys: SigningKeys,
31
+
},
32
+
AIP {
33
+
hostname: String,
34
+
client_id: String,
35
+
client_secret: String,
36
+
},
37
+
}
38
+
33
39
#[derive(Clone)]
34
40
pub struct Config {
35
41
pub version: String,
···
41
47
pub user_agent: String,
42
48
pub database_url: String,
43
49
pub plc_hostname: String,
44
-
pub signing_keys: SigningKeys,
45
-
pub oauth_active_keys: OAuthActiveKeys,
46
-
pub destination_key: SecretKey,
47
50
pub redis_url: String,
48
51
pub admin_dids: AdminDIDs,
49
52
pub dns_nameservers: DnsNameservers,
53
+
pub oauth_backend: OAuthBackendConfig,
54
+
pub destination_key_data: KeyData,
55
+
pub enable_task_oauth_requests_cleanup: bool,
56
+
pub enable_task_identity_refresh: bool,
50
57
}
51
58
52
59
impl Config {
···
72
79
73
80
let database_url = default_env("DATABASE_URL", "sqlite://development.db");
74
81
75
-
let signing_keys: SigningKeys =
76
-
require_env("SIGNING_KEYS").and_then(|value| value.try_into())?;
77
-
78
-
let oauth_active_keys: OAuthActiveKeys =
79
-
require_env("OAUTH_ACTIVE_KEYS").and_then(|value| value.try_into())?;
80
-
81
-
let destination_key = require_env("DESTINATION_KEY").and_then(|value| {
82
-
signing_keys
83
-
.0
84
-
.get(&value)
85
-
.cloned()
86
-
.ok_or(ConfigError::InvalidDestinationKey.into())
87
-
})?;
88
-
89
82
let redis_url = default_env("REDIS_URL", "redis://valkey:6379/0");
90
83
91
84
let admin_dids: AdminDIDs = optional_env("ADMIN_DIDS").try_into()?;
92
85
93
86
let dns_nameservers: DnsNameservers = optional_env("DNS_NAMESERVERS").try_into()?;
94
87
88
+
// Create OAuth backend configuration based on environment variables
89
+
let oauth_backend_type = default_env("OAUTH_BACKEND", "pds");
90
+
let oauth_backend = match oauth_backend_type.to_lowercase().as_str() {
91
+
"pds" => {
92
+
let signing_keys: SigningKeys =
93
+
require_env("SIGNING_KEYS").and_then(|value| value.try_into())?;
94
+
OAuthBackendConfig::ATProtocol { signing_keys }
95
+
}
96
+
"aip" => {
97
+
let hostname =
98
+
require_env("AIP_HOSTNAME").map(|value| format!("https://{value}"))?;
99
+
let client_id = require_env("AIP_CLIENT_ID")?;
100
+
let client_secret = require_env("AIP_CLIENT_SECRET")?;
101
+
OAuthBackendConfig::AIP {
102
+
hostname,
103
+
client_id,
104
+
client_secret,
105
+
}
106
+
}
107
+
_ => return Err(ConfigError::InvalidOAuthBackend(oauth_backend_type).into()),
108
+
};
109
+
110
+
// Parse destination key for authentication redirects
111
+
let destination_key_data = identify_key(&require_env("DESTINATION_KEY")?)
112
+
.context("failed to parse DESTINATION_KEY")?;
113
+
114
+
// Parse optional task enablement flags
115
+
let enable_task_oauth_requests_cleanup =
116
+
default_env("ENABLE_TASK_OAUTH_REQUESTS_CLEANUP", "true")
117
+
.parse::<bool>()
118
+
.unwrap_or(true);
119
+
let enable_task_identity_refresh = default_env("ENABLE_TASK_IDENTITY_REFRESH", "true")
120
+
.parse::<bool>()
121
+
.unwrap_or(true);
122
+
95
123
Ok(Self {
96
124
version: version()?,
97
125
http_port,
···
101
129
user_agent,
102
130
plc_hostname,
103
131
database_url,
104
-
signing_keys,
105
-
oauth_active_keys,
106
132
http_cookie_key,
107
-
destination_key,
108
133
redis_url,
109
134
admin_dids,
110
135
dns_nameservers,
136
+
oauth_backend,
137
+
destination_key_data,
138
+
enable_task_oauth_requests_cleanup,
139
+
enable_task_identity_refresh,
111
140
})
112
141
}
113
142
114
-
pub fn select_oauth_signing_key(&self) -> Result<(String, SecretKey)> {
115
-
let key_id = self
116
-
.oauth_active_keys
117
-
.as_ref()
118
-
.choose(&mut rand::thread_rng())
119
-
.ok_or(ConfigError::SigningKeyNotFound)?
120
-
.clone();
121
-
let signing_key = self
122
-
.signing_keys
123
-
.as_ref()
124
-
.get(&key_id)
125
-
.ok_or(ConfigError::SigningKeyNotFound)?
126
-
.clone();
127
-
128
-
Ok((key_id, signing_key))
143
+
pub fn select_oauth_signing_key(&self) -> Result<(String, String)> {
144
+
match &self.oauth_backend {
145
+
OAuthBackendConfig::ATProtocol { signing_keys } => {
146
+
let item = signing_keys.as_ref().first();
147
+
item.map(|(key, value)| (key.clone(), value.clone()))
148
+
.ok_or(anyhow::anyhow!("signing keys is empty"))
149
+
}
150
+
OAuthBackendConfig::AIP { .. } => Err(anyhow::anyhow!(
151
+
"signing keys not available for AIP OAuth backend"
152
+
)),
153
+
}
129
154
}
130
155
131
156
/// Check if a DID is in the admin allow list
···
216
241
}
217
242
}
218
243
219
-
impl AsRef<OrderMap<String, SecretKey>> for SigningKeys {
220
-
fn as_ref(&self) -> &OrderMap<String, SecretKey> {
244
+
impl AsRef<OrderMap<String, String>> for SigningKeys {
245
+
fn as_ref(&self) -> &OrderMap<String, String> {
221
246
&self.0
222
247
}
223
248
}
···
225
250
impl TryFrom<String> for SigningKeys {
226
251
type Error = anyhow::Error;
227
252
fn try_from(value: String) -> Result<Self, Self::Error> {
228
-
let content = {
229
-
if value.starts_with("/") {
230
-
// Verify file exists before reading
231
-
if !std::path::Path::new(&value).exists() {
232
-
return Err(ConfigError::SigningKeysFileNotFound(value).into());
233
-
}
234
-
std::fs::read(&value).map_err(ConfigError::ReadSigningKeysFailed)?
235
-
} else {
236
-
general_purpose::STANDARD
237
-
.decode(&value)
238
-
.map_err(EncodingError::Base64DecodingFailed)?
239
-
}
240
-
};
241
-
242
-
// Validate content is not empty
243
-
if content.is_empty() {
244
-
return Err(ConfigError::EmptySigningKeysFile.into());
253
+
// Allow empty signing keys for initial pass with in-memory storage
254
+
if value.is_empty() {
255
+
return Ok(Self(OrderMap::new()));
245
256
}
246
257
247
-
// Parse JSON with proper error handling
248
-
let jwks = serde_json::from_slice::<WrappedJsonWebKeySet>(&content)
249
-
.map_err(ConfigError::ParseSigningKeysFailed)?;
258
+
// Parse semicolon-separated DID method key strings
259
+
let private_keys: Vec<&str> = value.split(';').filter(|s| !s.is_empty()).collect();
250
260
251
-
// Validate JWKS contains keys
252
-
if jwks.keys.is_empty() {
253
-
return Err(ConfigError::MissingKeysInJWKS.into());
261
+
if private_keys.is_empty() {
262
+
return Ok(Self(OrderMap::new()));
254
263
}
255
264
256
-
// Track keys that failed validation for better error reporting
257
-
let mut validation_errors = Vec::new();
265
+
let mut signing_keys = OrderMap::new();
258
266
259
-
let signing_keys = jwks
260
-
.keys
261
-
.iter()
262
-
.filter_map(|key| {
263
-
// Validate key has required fields
264
-
if key.kid.is_none() {
265
-
validation_errors.push("Missing key ID (kid)".to_string());
266
-
return None;
267
+
for private_key in private_keys {
268
+
// Verify this is a private key using identify_key
269
+
let key_data = match identify_key(private_key) {
270
+
Ok(key_data) => key_data,
271
+
Err(err) => {
272
+
tracing::error!(?err, ?private_key, "failed to identify key");
273
+
continue;
267
274
}
275
+
};
268
276
269
-
if let (Some(key_id), secret_key) = (key.kid.clone(), key.jwk.clone()) {
270
-
// Verify the key_id format (should be a valid ULID)
271
-
if ulid::Ulid::from_string(&key_id).is_err() {
272
-
validation_errors.push(format!("Invalid key ID format: {}", key_id));
273
-
return None;
274
-
}
277
+
match key_data.0 {
278
+
KeyType::P256Public | KeyType::K256Public => {
279
+
tracing::error!(?private_key, "public key found");
280
+
continue;
281
+
}
282
+
_ => {}
283
+
}
275
284
276
-
// Validate the secret key
277
-
match p256::SecretKey::from_jwk(&secret_key) {
278
-
Ok(secret_key) => Some((key_id, secret_key)),
279
-
Err(err) => {
280
-
validation_errors.push(format!("Invalid key {}: {}", key_id, err));
281
-
None
282
-
}
283
-
}
284
-
} else {
285
-
None
285
+
// Generate public key using to_public and use it as the identifier
286
+
let public_key = match to_public(&key_data) {
287
+
Ok(public_key) => public_key,
288
+
Err(err) => {
289
+
tracing::error!(
290
+
?err,
291
+
?private_key,
292
+
"unable to derive public key from signing key"
293
+
);
294
+
continue;
286
295
}
287
-
})
288
-
.collect::<OrderMap<String, SecretKey>>();
296
+
};
297
+
298
+
// Store the original DID method key string as the value
299
+
// This allows us to call identify_key again later when we need the KeyData
300
+
signing_keys.insert(public_key.to_string(), private_key.to_string());
301
+
}
289
302
290
303
// Check if we have any valid keys
291
304
if signing_keys.is_empty() {
292
-
if !validation_errors.is_empty() {
293
-
return Err(ConfigError::SigningKeysValidationFailed(validation_errors).into());
294
-
}
295
305
return Err(ConfigError::EmptySigningKeys.into());
296
306
}
297
307
298
308
Ok(Self(signing_keys))
299
-
}
300
-
}
301
-
302
-
impl AsRef<Vec<String>> for OAuthActiveKeys {
303
-
fn as_ref(&self) -> &Vec<String> {
304
-
&self.0
305
-
}
306
-
}
307
-
308
-
impl TryFrom<String> for OAuthActiveKeys {
309
-
type Error = anyhow::Error;
310
-
fn try_from(value: String) -> Result<Self, Self::Error> {
311
-
let values = value
312
-
.split(';')
313
-
.map(|s| s.to_string())
314
-
.collect::<Vec<String>>();
315
-
if values.is_empty() {
316
-
return Err(ConfigError::EmptyOAuthActiveKeys.into());
317
-
}
318
-
Ok(Self(values))
319
309
}
320
310
}
321
311
+14
src/config_errors.rs
+14
src/config_errors.rs
···
125
125
/// that fail validation checks (such as having invalid format).
126
126
#[error("error-config-17 Signing keys validation failed: {0:?}")]
127
127
SigningKeysValidationFailed(Vec<String>),
128
+
129
+
/// Error when AIP OAuth configuration is incomplete.
130
+
///
131
+
/// This error occurs when oauth_backend is set to "aip" but
132
+
/// required AIP configuration values are missing.
133
+
#[error("error-config-18 When oauth_backend is 'aip', AIP_HOSTNAME, AIP_CLIENT_ID, and AIP_CLIENT_SECRET must all be set")]
134
+
AipConfigurationIncomplete,
135
+
136
+
/// Error when oauth_backend has an invalid value.
137
+
///
138
+
/// This error occurs when the OAUTH_BACKEND environment variable
139
+
/// contains a value other than "aip" or "pds".
140
+
#[error("error-config-19 oauth_backend must be either 'aip' or 'pds', got: {0}")]
141
+
InvalidOAuthBackend(String),
128
142
}
-228
src/did.rs
-228
src/did.rs
···
1
-
pub mod model {
2
-
3
-
use serde::Deserialize;
4
-
use serde_json::Value;
5
-
use std::collections::HashMap;
6
-
7
-
#[derive(Clone, Deserialize, Debug)]
8
-
#[serde(rename_all = "camelCase")]
9
-
pub struct Service {
10
-
pub id: String,
11
-
12
-
pub r#type: String,
13
-
14
-
pub service_endpoint: String,
15
-
}
16
-
17
-
#[derive(Clone, Deserialize, Debug)]
18
-
#[serde(tag = "type", rename_all = "camelCase")]
19
-
pub enum VerificationMethod {
20
-
Multikey {
21
-
id: String,
22
-
controller: String,
23
-
public_key_multibase: String,
24
-
},
25
-
26
-
#[serde(untagged)]
27
-
Other {
28
-
#[serde(flatten)]
29
-
extra: HashMap<String, Value>,
30
-
},
31
-
}
32
-
33
-
#[derive(Clone, Deserialize, Debug)]
34
-
#[serde(rename_all = "camelCase")]
35
-
pub struct Document {
36
-
pub id: String,
37
-
pub also_known_as: Vec<String>,
38
-
pub service: Vec<Service>,
39
-
}
40
-
41
-
impl Document {
42
-
pub fn pds_endpoint(&self) -> Option<&str> {
43
-
self.service
44
-
.iter()
45
-
.find(|service| service.r#type == "AtprotoPersonalDataServer")
46
-
.map(|service| service.service_endpoint.as_str())
47
-
}
48
-
49
-
pub fn primary_handle(&self) -> Option<&str> {
50
-
self.also_known_as.first().map(|handle| {
51
-
if let Some(trimmed) = handle.strip_prefix("at://") {
52
-
trimmed
53
-
} else {
54
-
handle.as_str()
55
-
}
56
-
})
57
-
}
58
-
}
59
-
60
-
#[cfg(test)]
61
-
mod tests {
62
-
use crate::did::model::Document;
63
-
64
-
#[test]
65
-
fn test_deserialize() {
66
-
let document = serde_json::from_str::<Document>(
67
-
r##"{"@context":["https://www.w3.org/ns/did/v1","https://w3id.org/security/multikey/v1","https://w3id.org/security/suites/secp256k1-2019/v1"],"id":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","alsoKnownAs":["at://ngerakines.me","at://nick.gerakines.net","at://nick.thegem.city","https://github.com/ngerakines","https://ngerakines.me/","dns:ngerakines.me"],"verificationMethod":[{"id":"did:plc:cbkjy5n7bk3ax2wplmtjofq2#atproto","type":"Multikey","controller":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","publicKeyMultibase":"zQ3shXvCK2RyPrSLYQjBEw5CExZkUhJH3n1K2Mb9sC7JbvRMF"}],"service":[{"id":"#atproto_pds","type":"AtprotoPersonalDataServer","serviceEndpoint":"https://pds.cauda.cloud"}]}"##,
68
-
);
69
-
assert!(document.is_ok());
70
-
71
-
let document = document.unwrap();
72
-
assert_eq!(document.id, "did:plc:cbkjy5n7bk3ax2wplmtjofq2");
73
-
}
74
-
75
-
#[test]
76
-
fn test_deserialize_unsupported_verification_method() {
77
-
let documents = vec![
78
-
r##"{"@context":["https://www.w3.org/ns/did/v1","https://w3id.org/security/multikey/v1","https://w3id.org/security/suites/secp256k1-2019/v1"],"id":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","alsoKnownAs":["at://ngerakines.me","at://nick.gerakines.net","at://nick.thegem.city","https://github.com/ngerakines","https://ngerakines.me/","dns:ngerakines.me"],"verificationMethod":[{"id":"did:plc:cbkjy5n7bk3ax2wplmtjofq2#atproto","type":"Ed25519VerificationKey2020","controller":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","publicKeyMultibase":"zQ3shXvCK2RyPrSLYQjBEw5CExZkUhJH3n1K2Mb9sC7JbvRMF"}],"service":[{"id":"#atproto_pds","type":"AtprotoPersonalDataServer","serviceEndpoint":"https://pds.cauda.cloud"}]}"##,
79
-
r##"{"@context":["https://www.w3.org/ns/did/v1","https://w3id.org/security/multikey/v1","https://w3id.org/security/suites/secp256k1-2019/v1"],"id":"did:plc:cbkjy5n7bk3ax2wplmtjofq2","alsoKnownAs":["at://ngerakines.me","at://nick.gerakines.net","at://nick.thegem.city","https://github.com/ngerakines","https://ngerakines.me/","dns:ngerakines.me"],"verificationMethod":[{"id": "did:example:123#_Qq0UL2Fq651Q0Fjd6TvnYE-faHiOpRlPVQcY_-tA4A","type": "JsonWebKey2020","controller": "did:example:123","publicKeyJwk": {"crv": "Ed25519","x": "VCpo2LMLhn6iWku8MKvSLg2ZAoC-nlOyPVQaO3FxVeQ","kty": "OKP","kid": "_Qq0UL2Fq651Q0Fjd6TvnYE-faHiOpRlPVQcY_-tA4A"}}],"service":[{"id":"#atproto_pds","type":"AtprotoPersonalDataServer","serviceEndpoint":"https://pds.cauda.cloud"}]}"##,
80
-
];
81
-
for document in documents {
82
-
let document = serde_json::from_str::<Document>(document);
83
-
assert!(document.is_ok());
84
-
85
-
let document = document.unwrap();
86
-
assert_eq!(document.id, "did:plc:cbkjy5n7bk3ax2wplmtjofq2");
87
-
}
88
-
}
89
-
}
90
-
}
91
-
92
-
pub mod plc {
93
-
use anyhow::Result;
94
-
use thiserror::Error;
95
-
96
-
use super::model::Document;
97
-
98
-
/// Error types that can occur when working with PLC DIDs
99
-
#[derive(Debug, Error)]
100
-
pub enum PLCDIDError {
101
-
/// Occurs when the HTTP request to fetch the DID document fails
102
-
#[error("error-did-plc-1 HTTP request failed: {url} {error}")]
103
-
HttpRequestFailed {
104
-
/// The URL that was requested
105
-
url: String,
106
-
/// The underlying HTTP error
107
-
error: reqwest::Error,
108
-
},
109
-
110
-
/// Occurs when the DID document cannot be parsed from the HTTP response
111
-
#[error("error-did-plc-2 Failed to parse DID document: {url} {error}")]
112
-
DocumentParseFailed {
113
-
/// The URL that was requested
114
-
url: String,
115
-
/// The underlying parse error
116
-
error: reqwest::Error,
117
-
},
118
-
}
119
-
120
-
pub async fn query(
121
-
http_client: &reqwest::Client,
122
-
plc_hostname: &str,
123
-
did: &str,
124
-
) -> Result<Document> {
125
-
let url = format!("https://{}/{}", plc_hostname, did);
126
-
127
-
http_client
128
-
.get(&url)
129
-
.send()
130
-
.await
131
-
.map_err(|error| PLCDIDError::HttpRequestFailed {
132
-
url: url.clone(),
133
-
error,
134
-
})?
135
-
.json::<Document>()
136
-
.await
137
-
.map_err(|error| PLCDIDError::DocumentParseFailed { url, error })
138
-
.map_err(Into::into)
139
-
}
140
-
}
141
-
142
-
pub mod web {
143
-
use anyhow::Result;
144
-
use thiserror::Error;
145
-
146
-
use super::model::Document;
147
-
148
-
/// Error types that can occur when working with Web DIDs
149
-
#[derive(Debug, Error)]
150
-
pub enum WebDIDError {
151
-
/// Occurs when the DID is missing the 'did:web:' prefix
152
-
#[error("error-did-web-1 Invalid DID format: missing 'did:web:' prefix")]
153
-
InvalidDIDPrefix,
154
-
155
-
/// Occurs when the DID is missing a hostname component
156
-
#[error("error-did-web-2 Invalid DID format: missing hostname component")]
157
-
MissingHostname,
158
-
159
-
/// Occurs when the HTTP request to fetch the DID document fails
160
-
#[error("error-did-web-3 HTTP request failed: {url} {error}")]
161
-
HttpRequestFailed {
162
-
/// The URL that was requested
163
-
url: String,
164
-
/// The underlying HTTP error
165
-
error: reqwest::Error,
166
-
},
167
-
168
-
/// Occurs when the DID document cannot be parsed from the HTTP response
169
-
#[error("error-did-web-4 Failed to parse DID document: {url} {error}")]
170
-
DocumentParseFailed {
171
-
/// The URL that was requested
172
-
url: String,
173
-
/// The underlying parse error
174
-
error: reqwest::Error,
175
-
},
176
-
}
177
-
178
-
pub async fn query(http_client: &reqwest::Client, did: &str) -> Result<Document> {
179
-
// Parse DID and extract hostname and path components
180
-
let mut parts = did
181
-
.strip_prefix("did:web:")
182
-
.ok_or(WebDIDError::InvalidDIDPrefix)?
183
-
.split(':')
184
-
.collect::<Vec<&str>>();
185
-
186
-
let hostname = parts.pop().ok_or(WebDIDError::MissingHostname)?;
187
-
188
-
// Construct URL based on whether path components exist
189
-
let url = if parts.is_empty() {
190
-
format!("https://{}/.well-known/did.json", hostname)
191
-
} else {
192
-
format!("https://{}/{}/did.json", hostname, parts.join("/"))
193
-
};
194
-
195
-
// Fetch and parse document
196
-
http_client
197
-
.get(&url)
198
-
.send()
199
-
.await
200
-
.map_err(|error| WebDIDError::HttpRequestFailed {
201
-
url: url.clone(),
202
-
error,
203
-
})?
204
-
.json::<Document>()
205
-
.await
206
-
.map_err(|error| WebDIDError::DocumentParseFailed { url, error })
207
-
.map_err(Into::into)
208
-
}
209
-
210
-
pub async fn query_hostname(http_client: &reqwest::Client, hostname: &str) -> Result<Document> {
211
-
let url = format!("https://{}/.well-known/did.json", hostname);
212
-
213
-
tracing::debug!(?url, "query_hostname");
214
-
215
-
http_client
216
-
.get(&url)
217
-
.send()
218
-
.await
219
-
.map_err(|error| WebDIDError::HttpRequestFailed {
220
-
url: url.clone(),
221
-
error,
222
-
})?
223
-
.json::<Document>()
224
-
.await
225
-
.map_err(|error| WebDIDError::DocumentParseFailed { url, error })
226
-
.map_err(Into::into)
227
-
}
228
-
}
-33
src/encoding.rs
-33
src/encoding.rs
···
1
-
use anyhow::Result;
2
-
use base64::{engine::general_purpose, Engine as _};
3
-
use serde::{Deserialize, Serialize};
4
-
use std::borrow::Cow;
5
-
6
-
use crate::encoding_errors::EncodingError;
7
-
8
-
pub trait ToBase64 {
9
-
fn to_base64(&self) -> Result<Cow<str>>;
10
-
}
11
-
12
-
impl<T: Serialize> ToBase64 for T {
13
-
fn to_base64(&self) -> Result<Cow<str>> {
14
-
let json_bytes =
15
-
serde_json::to_vec(&self).map_err(EncodingError::JsonSerializationFailed)?;
16
-
let encoded_json_bytes = general_purpose::URL_SAFE_NO_PAD.encode(json_bytes);
17
-
Ok(Cow::Owned(encoded_json_bytes))
18
-
}
19
-
}
20
-
21
-
pub trait FromBase64: Sized {
22
-
fn from_base64<Input: ?Sized + AsRef<[u8]>>(raw: &Input) -> Result<Self>;
23
-
}
24
-
25
-
impl<T: for<'de> Deserialize<'de> + Sized> FromBase64 for T {
26
-
fn from_base64<Input: ?Sized + AsRef<[u8]>>(raw: &Input) -> Result<Self> {
27
-
let content = general_purpose::URL_SAFE_NO_PAD
28
-
.decode(raw)
29
-
.map_err(EncodingError::Base64DecodingFailed)?;
30
-
serde_json::from_slice(&content)
31
-
.map_err(|err| EncodingError::JsonDeserializationFailed(err).into())
32
-
}
33
-
}
-31
src/encoding_errors.rs
-31
src/encoding_errors.rs
···
1
-
use thiserror::Error;
2
-
3
-
/// Represents errors that can occur during data encoding and decoding operations.
4
-
///
5
-
/// These errors relate to serialization, deserialization, encoding and decoding
6
-
/// of data across various formats used throughout the application.
7
-
#[derive(Debug, Error)]
8
-
pub enum EncodingError {
9
-
/// Error when JSON serialization fails.
10
-
///
11
-
/// This error occurs when attempting to convert a Rust structure to
12
-
/// JSON format fails, typically due to data that cannot be properly
13
-
/// represented in JSON.
14
-
#[error("error-encoding-1 JSON serialization failed: {0:?}")]
15
-
JsonSerializationFailed(serde_json::Error),
16
-
17
-
/// Error when Base64 decoding fails.
18
-
///
19
-
/// This error occurs when attempting to decode Base64-encoded data
20
-
/// that is malformed or contains invalid characters.
21
-
#[error("error-encoding-2 Base64 decoding failed: {0:?}")]
22
-
Base64DecodingFailed(base64::DecodeError),
23
-
24
-
/// Error when JSON deserialization fails.
25
-
///
26
-
/// This error occurs when attempting to parse JSON data into a Rust
27
-
/// structure fails, typically due to missing fields, type mismatches,
28
-
/// or malformed JSON.
29
-
#[error("error-encoding-3 JSON deserialization failed: {0:?}")]
30
-
JsonDeserializationFailed(serde_json::Error),
31
-
}
+1
-1
src/http/cache_countries.rs
+1
-1
src/http/cache_countries.rs
···
4
4
5
5
static COUNTRY_CACHE: OnceCell<Arc<BTreeMap<String, String>>> = OnceCell::new();
6
6
7
-
pub fn cached_countries<'a>() -> Result<&'a Arc<BTreeMap<String, String>>> {
7
+
pub(crate) fn cached_countries<'a>() -> Result<&'a Arc<BTreeMap<String, String>>> {
8
8
if COUNTRY_CACHE.get().is_none() {
9
9
let all_countries: BTreeMap<String, String> = BTreeMap::from_iter(
10
10
[
+73
-32
src/http/context.rs
+73
-32
src/http/context.rs
···
1
+
use atproto_identity::axum::state::DidDocumentStorageExtractor;
2
+
use atproto_identity::key::KeyProvider;
3
+
use atproto_identity::resolve::IdentityResolver;
4
+
use atproto_identity::storage::DidDocumentStorage;
5
+
use atproto_oauth::storage::OAuthRequestStorage;
6
+
use atproto_oauth_axum::state::OAuthClientConfig;
1
7
use axum::extract::FromRef;
2
8
use axum::{
3
9
extract::FromRequestParts,
···
7
13
use axum_extra::extract::Cached;
8
14
use axum_template::engine::Engine;
9
15
use cookie::Key;
10
-
use hickory_resolver::TokioAsyncResolver;
11
16
use minijinja::context as template_context;
12
17
use std::{ops::Deref, sync::Arc};
13
18
use unic_langid::LanguageIdentifier;
···
26
31
http::middleware_auth::Auth,
27
32
http::middleware_i18n::Language,
28
33
i18n::Locales,
29
-
storage::handle::model::Handle,
34
+
storage::identity_profile::model::IdentityProfile,
30
35
storage::{CachePool, StoragePool},
31
36
};
32
37
33
38
#[cfg(feature = "embed")]
34
39
pub type AppEngine = Engine<Environment<'static>>;
35
40
36
-
pub struct I18nContext {
37
-
pub supported_languages: Vec<LanguageIdentifier>,
38
-
pub locales: Locales,
41
+
pub(crate) struct I18nContext {
42
+
pub(crate) supported_languages: Vec<LanguageIdentifier>,
43
+
pub(crate) locales: Locales,
39
44
}
40
45
41
46
pub struct InnerWebContext {
···
44
49
pub pool: StoragePool,
45
50
pub cache_pool: CachePool,
46
51
pub config: Config,
47
-
pub i18n_context: I18nContext,
48
-
pub dns_resolver: hickory_resolver::TokioAsyncResolver,
52
+
pub(crate) i18n_context: I18nContext,
53
+
pub(crate) oauth_client_config: atproto_oauth_axum::state::OAuthClientConfig,
54
+
pub(crate) identity_resolver: atproto_identity::resolve::IdentityResolver,
55
+
pub(crate) key_provider: Arc<dyn atproto_identity::key::KeyProvider>,
56
+
pub(crate) oauth_storage: Arc<dyn atproto_oauth::storage::OAuthRequestStorage>,
57
+
pub(crate) document_storage: Arc<dyn atproto_identity::storage::DidDocumentStorage>,
49
58
}
50
59
51
60
#[derive(Clone, FromRef)]
···
60
69
}
61
70
62
71
impl WebContext {
72
+
#[allow(clippy::too_many_arguments)]
63
73
pub fn new(
64
74
pool: StoragePool,
65
75
cache_pool: CachePool,
66
76
engine: AppEngine,
67
77
http_client: &reqwest::Client,
68
78
config: Config,
69
-
i18n_context: I18nContext,
70
-
dns_resolver: TokioAsyncResolver,
79
+
oauth_client_config: OAuthClientConfig,
80
+
identity_resolver: IdentityResolver,
81
+
key_provider: Arc<dyn KeyProvider>,
82
+
oauth_storage: Arc<dyn OAuthRequestStorage>,
83
+
document_storage: Arc<dyn DidDocumentStorage>,
84
+
supported_languages: Vec<LanguageIdentifier>,
85
+
locales: Locales,
71
86
) -> Self {
72
87
Self(Arc::new(InnerWebContext {
73
88
pool,
···
75
90
engine,
76
91
http_client: http_client.clone(),
77
92
config,
78
-
i18n_context,
79
-
dns_resolver,
93
+
i18n_context: I18nContext {
94
+
supported_languages,
95
+
locales,
96
+
},
97
+
oauth_client_config,
98
+
identity_resolver,
99
+
key_provider,
100
+
oauth_storage,
101
+
document_storage,
80
102
}))
81
103
}
82
104
}
83
105
84
-
impl I18nContext {
85
-
pub fn new(supported_languages: Vec<LanguageIdentifier>, locales: Locales) -> Self {
86
-
Self {
87
-
supported_languages,
88
-
locales,
89
-
}
106
+
impl FromRef<WebContext> for Key {
107
+
fn from_ref(context: &WebContext) -> Self {
108
+
context.0.config.http_cookie_key.as_ref().clone()
109
+
}
110
+
}
111
+
112
+
impl FromRef<WebContext> for IdentityResolver {
113
+
fn from_ref(context: &WebContext) -> Self {
114
+
context.0.identity_resolver.clone()
90
115
}
91
116
}
92
117
93
-
impl FromRef<WebContext> for Key {
118
+
impl FromRef<WebContext> for Arc<dyn KeyProvider> {
94
119
fn from_ref(context: &WebContext) -> Self {
95
-
context.0.config.http_cookie_key.as_ref().clone()
120
+
context.0.key_provider.clone()
121
+
}
122
+
}
123
+
124
+
impl FromRef<WebContext> for OAuthClientConfig {
125
+
fn from_ref(context: &WebContext) -> Self {
126
+
context.0.oauth_client_config.clone()
127
+
}
128
+
}
129
+
130
+
impl FromRef<WebContext> for DidDocumentStorageExtractor {
131
+
fn from_ref(context: &WebContext) -> Self {
132
+
DidDocumentStorageExtractor(context.0.document_storage.clone())
133
+
}
134
+
}
135
+
136
+
impl FromRef<WebContext> for Arc<dyn KeyProvider + Send + Sync> {
137
+
fn from_ref(context: &WebContext) -> Self {
138
+
context.0.key_provider.clone()
96
139
}
97
140
}
98
141
99
142
// New structs for reducing handler function arguments
100
143
101
144
/// A context struct specifically for admin handlers
102
-
pub struct AdminRequestContext {
103
-
pub web_context: WebContext,
104
-
pub language: Language,
105
-
pub admin_handle: Handle,
106
-
pub auth: Auth,
145
+
pub(crate) struct AdminRequestContext {
146
+
pub(crate) web_context: WebContext,
147
+
pub(crate) language: Language,
148
+
pub(crate) admin_handle: IdentityProfile,
107
149
}
108
150
109
151
impl<S> FromRequestParts<S> for AdminRequestContext
···
129
171
web_context,
130
172
language,
131
173
admin_handle,
132
-
auth: cached_auth.0,
133
174
})
134
175
}
135
176
}
136
177
137
178
/// Helper function to create standard template context for admin views
138
-
pub fn admin_template_context(
179
+
pub(crate) fn admin_template_context(
139
180
ctx: &AdminRequestContext,
140
181
canonical_url: &str,
141
182
) -> minijinja::value::Value {
···
147
188
}
148
189
149
190
/// A context struct for regular authenticated user handlers
150
-
pub struct UserRequestContext {
151
-
pub web_context: WebContext,
152
-
pub language: Language,
153
-
pub current_handle: Option<Handle>,
154
-
pub auth: Auth,
191
+
pub(crate) struct UserRequestContext {
192
+
pub(crate) web_context: WebContext,
193
+
pub(crate) language: Language,
194
+
pub(crate) current_handle: Option<IdentityProfile>,
195
+
pub(crate) auth: Auth,
155
196
}
156
197
157
198
impl<S> FromRequestParts<S> for UserRequestContext
···
170
211
Ok(Self {
171
212
web_context,
172
213
language,
173
-
current_handle: cached_auth.0 .0.clone(),
214
+
current_handle: cached_auth.0.profile().cloned(),
174
215
auth: cached_auth.0,
175
216
})
176
217
}
+2
-2
src/http/errors/admin_errors.rs
+2
-2
src/http/errors/admin_errors.rs
···
3
3
/// These errors relate to the process of importing RSVP data into the system
4
4
/// by administrators, typically during data migration or recovery.
5
5
#[derive(Debug, Error)]
6
-
pub enum AdminImportRsvpError {
6
+
pub(crate) enum AdminImportRsvpError {
7
7
/// Error when an RSVP cannot be inserted during import.
8
8
///
9
9
/// This error occurs when attempting to insert an imported RSVP into
···
16
16
/// These errors relate to the process of importing event data into the system
17
17
/// by administrators, typically during data migration or recovery.
18
18
#[derive(Debug, Error)]
19
-
pub enum AdminImportEventError {
19
+
pub(crate) enum AdminImportEventError {
20
20
/// Error when an event cannot be inserted during import.
21
21
///
22
22
/// This error occurs when attempting to insert an imported event into
+1
-8
src/http/errors/common_error.rs
+1
-8
src/http/errors/common_error.rs
···
2
2
3
3
/// Represents common errors that can occur across various HTTP handlers.
4
4
#[derive(Debug, Error)]
5
-
pub enum CommonError {
5
+
pub(crate) enum CommonError {
6
6
/// Error when a handle slug is invalid.
7
7
///
8
8
/// This error occurs when a URL contains a handle slug that doesn't conform
···
30
30
/// that is needed to complete the operation.
31
31
#[error("error-common-4 Required field not provided")]
32
32
FieldRequired,
33
-
34
-
/// Error when data has an invalid format or is corrupted.
35
-
///
36
-
/// This error occurs when input data doesn't match the expected format
37
-
/// or appears to be corrupted or tampered with.
38
-
#[error("error-common-5 Invalid format or corrupted data")]
39
-
InvalidFormat,
40
33
41
34
/// Error when an AT Protocol URI has an invalid format.
42
35
///
+1
-22
src/http/errors/create_event_errors.rs
+1
-22
src/http/errors/create_event_errors.rs
···
5
5
/// These errors are typically triggered during validation of user-submitted
6
6
/// event creation forms.
7
7
#[derive(Debug, Error)]
8
-
pub enum CreateEventError {
8
+
pub(crate) enum CreateEventError {
9
9
/// Error when the event name is not provided.
10
10
///
11
11
/// This error occurs when a user attempts to create an event without
···
19
19
/// specifying a description, which is a required field.
20
20
#[error("error-create-event-2 Description not set")]
21
21
DescriptionNotSet,
22
-
23
-
/// Error when the event dates are invalid.
24
-
///
25
-
/// This error occurs when the provided event dates are invalid, such as when
26
-
/// the end date is before the start date, or dates are outside allowed ranges.
27
-
#[error("error-create-event-3 Invalid event dates")]
28
-
InvalidEventDates,
29
-
30
-
/// Error when an invalid event mode is specified.
31
-
///
32
-
/// This error occurs when the provided event mode doesn't match one of the
33
-
/// supported values (e.g., "in-person", "online", "hybrid").
34
-
#[error("error-create-event-4 Invalid event mode")]
35
-
InvalidEventMode,
36
-
37
-
/// Error when an invalid event status is specified.
38
-
///
39
-
/// This error occurs when the provided event status doesn't match one of the
40
-
/// supported values (e.g., "confirmed", "tentative", "cancelled").
41
-
#[error("error-create-event-5 Invalid event status")]
42
-
InvalidEventStatus,
43
22
}
+1
-15
src/http/errors/edit_event_error.rs
+1
-15
src/http/errors/edit_event_error.rs
···
5
5
/// These errors typically happen when users attempt to modify existing events
6
6
/// and encounter authorization, validation, or type compatibility issues.
7
7
#[derive(Debug, Error)]
8
-
pub enum EditEventError {
9
-
/// Error when an invalid handle slug is provided.
10
-
///
11
-
/// This error occurs when attempting to edit an event with a handle slug
12
-
/// that is not properly formatted or does not exist in the system.
13
-
#[error("error-edit-event-1 Invalid handle slug")]
14
-
InvalidHandleSlug,
15
-
8
+
pub(crate) enum EditEventError {
16
9
/// Error when a user is not authorized to edit an event.
17
10
///
18
11
/// This error occurs when a user attempts to edit an event that they
···
44
37
/// that has a location type that is not supported for editing through the web interface.
45
38
#[error("error-edit-event-5 Cannot edit locations: Event has unsupported location type")]
46
39
UnsupportedLocationType,
47
-
48
-
/// Error when attempting to edit location data on an event that has no locations.
49
-
///
50
-
/// This error occurs when a user attempts to add location information to an event
51
-
/// that was not created with location information.
52
-
#[error("error-edit-event-6 Cannot edit locations: Event has no locations")]
53
-
NoLocationsPresent,
54
40
}
+1
-1
src/http/errors/event_view_errors.rs
+1
-1
src/http/errors/event_view_errors.rs
···
5
5
/// These errors typically happen when retrieving and displaying event data
6
6
/// to users, including data validation and enhancement issues.
7
7
#[derive(Debug, Error)]
8
-
pub enum EventViewError {
8
+
pub(crate) enum EventViewError {
9
9
/// Error when an invalid collection is specified.
10
10
///
11
11
/// This error occurs when an event view request specifies a collection
+1
-1
src/http/errors/import_error.rs
+1
-1
src/http/errors/import_error.rs
···
5
5
/// These errors typically happen when attempting to import events and RSVPs
6
6
/// from different sources, including community and Smokesignal systems.
7
7
#[derive(Debug, Error)]
8
-
pub enum ImportError {
8
+
pub(crate) enum ImportError {
9
9
/// Error when listing community events fails.
10
10
///
11
11
/// This error occurs when attempting to retrieve a list of community
+2
-6
src/http/errors/login_error.rs
+2
-6
src/http/errors/login_error.rs
···
5
5
/// These errors typically happen during the authentication process when users
6
6
/// are logging in to the application, including OAuth flows and DID validation.
7
7
#[derive(Debug, Error)]
8
-
pub enum LoginError {
9
-
/// Error when a DID document does not contain a handle.
10
-
///
11
-
/// This error occurs during authentication when the user's DID document
12
-
/// is retrieved but does not contain a required handle identifier.
13
-
#[error("error-login-1 DID document does not contain a handle")]
8
+
pub(crate) enum LoginError {
9
+
#[error("error-login-1 DID document does not contain a handle identifier")]
14
10
NoHandle,
15
11
16
12
/// Error when a DID document does not contain a PDS endpoint.
+4
-4
src/http/errors/middleware_errors.rs
+4
-4
src/http/errors/middleware_errors.rs
···
9
9
/// These errors are related to the serialization and deserialization of
10
10
/// web session data used for maintaining user authentication state.
11
11
#[derive(Debug, Error)]
12
-
pub enum WebSessionError {
12
+
pub(crate) enum WebSessionError {
13
13
/// Error when web session deserialization fails.
14
14
///
15
15
/// This error occurs when attempting to deserialize a web session from JSON
···
30
30
/// These errors typically happen in the authentication middleware layer when
31
31
/// processing requests, including cryptographic operations and session validation.
32
32
#[derive(Debug, Error)]
33
-
pub enum AuthMiddlewareError {
33
+
pub(crate) enum AuthMiddlewareError {
34
34
/// Error when content signing fails.
35
35
///
36
36
/// This error occurs when the authentication middleware attempts to
37
37
/// cryptographically sign content but the operation fails.
38
38
#[error("error-authmiddleware-1 Unable to sign content: {0:?}")]
39
-
SigningFailed(p256::ecdsa::Error),
39
+
SigningFailed(anyhow::Error),
40
40
}
41
41
42
42
#[derive(Debug, Error)]
43
-
pub enum MiddlewareAuthError {
43
+
pub(crate) enum MiddlewareAuthError {
44
44
#[error("error-middleware-auth-1 Access Denied: {0}")]
45
45
AccessDenied(String),
46
46
+1
-8
src/http/errors/migrate_event_error.rs
+1
-8
src/http/errors/migrate_event_error.rs
···
5
5
/// These errors typically happen when attempting to migrate events between
6
6
/// different formats or systems, such as from smokesignal to community events.
7
7
#[derive(Debug, Error)]
8
-
pub enum MigrateEventError {
9
-
/// Error when an invalid handle slug is provided.
10
-
///
11
-
/// This error occurs when attempting to migrate an event with a handle slug
12
-
/// that is not properly formatted or does not exist in the system.
13
-
#[error("error-migrate-event-1 Invalid handle slug")]
14
-
InvalidHandleSlug,
15
-
8
+
pub(crate) enum MigrateEventError {
16
9
/// Error when a user is not authorized to migrate an event.
17
10
///
18
11
/// This error occurs when a user attempts to migrate an event that they
+1
-1
src/http/errors/migrate_rsvp_error.rs
+1
-1
src/http/errors/migrate_rsvp_error.rs
···
5
5
/// These errors relate to the migration or conversion of RSVP data
6
6
/// between different systems, formats, or versions.
7
7
#[derive(Debug, Error)]
8
-
pub enum MigrateRsvpError {
8
+
pub(crate) enum MigrateRsvpError {
9
9
/// Error when an invalid RSVP status is provided during migration.
10
10
///
11
11
/// This error occurs when attempting to migrate an RSVP with a status
+14
-14
src/http/errors/mod.rs
+14
-14
src/http/errors/mod.rs
···
14
14
pub mod view_event_error;
15
15
pub mod web_error;
16
16
17
-
pub use admin_errors::{AdminImportEventError, AdminImportRsvpError};
18
-
pub use common_error::CommonError;
19
-
pub use create_event_errors::CreateEventError;
20
-
pub use edit_event_error::EditEventError;
21
-
pub use event_view_errors::EventViewError;
22
-
pub use import_error::ImportError;
23
-
pub use login_error::LoginError;
24
-
pub use middleware_errors::{AuthMiddlewareError, WebSessionError};
25
-
pub use migrate_event_error::MigrateEventError;
26
-
pub use migrate_rsvp_error::MigrateRsvpError;
27
-
pub use rsvp_error::RSVPError;
28
-
pub use url_error::UrlError;
29
-
pub use view_event_error::ViewEventError;
30
-
pub use web_error::WebError;
17
+
pub(crate) use admin_errors::{AdminImportEventError, AdminImportRsvpError};
18
+
pub(crate) use common_error::CommonError;
19
+
pub(crate) use create_event_errors::CreateEventError;
20
+
pub(crate) use edit_event_error::EditEventError;
21
+
pub(crate) use event_view_errors::EventViewError;
22
+
pub(crate) use import_error::ImportError;
23
+
pub(crate) use login_error::LoginError;
24
+
pub(crate) use middleware_errors::{AuthMiddlewareError, WebSessionError};
25
+
pub(crate) use migrate_event_error::MigrateEventError;
26
+
pub(crate) use migrate_rsvp_error::MigrateRsvpError;
27
+
pub(crate) use rsvp_error::RSVPError;
28
+
pub(crate) use url_error::UrlError;
29
+
pub(crate) use view_event_error::ViewEventError;
30
+
pub(crate) use web_error::WebError;
+1
-1
src/http/errors/rsvp_error.rs
+1
-1
src/http/errors/rsvp_error.rs
···
5
5
/// These errors relate to the handling of event RSVPs, such as
6
6
/// when users attempt to respond to event invitations.
7
7
#[derive(Debug, Error)]
8
-
pub enum RSVPError {
8
+
pub(crate) enum RSVPError {
9
9
/// Error when an RSVP cannot be found.
10
10
///
11
11
/// This error occurs when attempting to retrieve or modify an RSVP
+1
-1
src/http/errors/url_error.rs
+1
-1
src/http/errors/url_error.rs
···
5
5
/// These errors typically happen when validating or processing URLs in the system,
6
6
/// including checking for supported collection types in URL paths.
7
7
#[derive(Debug, Error)]
8
-
pub enum UrlError {
8
+
pub(crate) enum UrlError {
9
9
/// Error when an unsupported collection type is specified in a URL.
10
10
///
11
11
/// This error occurs when a URL contains a collection type that is not
+1
-8
src/http/errors/view_event_error.rs
+1
-8
src/http/errors/view_event_error.rs
···
5
5
/// These errors typically happen when retrieving or displaying specific
6
6
/// event data, including lookup failures and data enhancement issues.
7
7
#[derive(Debug, Error)]
8
-
pub enum ViewEventError {
8
+
pub(crate) enum ViewEventError {
9
9
/// Error when a requested event cannot be found.
10
10
///
11
11
/// This error occurs when attempting to view an event that doesn't
···
19
19
/// and the fallback method also fails to retrieve the event.
20
20
#[error("error-view-event-2 Failed to get event from fallback: {0}")]
21
21
FallbackFailed(String),
22
-
23
-
/// Error when fetching event details fails.
24
-
///
25
-
/// This error occurs when the system fails to retrieve additional
26
-
/// details for an event, such as RSVP counts or related data.
27
-
#[error("error-view-event-3 Failed to fetch event details: {0}")]
28
-
FetchEventDetailsFailed(String),
29
22
}
+2
-30
src/http/errors/web_error.rs
+2
-30
src/http/errors/web_error.rs
···
35
35
/// and error code, while a few web-specific errors have their own error code format:
36
36
/// `error-web-<number> <message>: <details>`
37
37
#[derive(Debug, Error)]
38
-
pub enum WebError {
38
+
pub(crate) enum WebError {
39
39
/// Error when authentication middleware fails.
40
40
///
41
41
/// This error occurs when there are issues with verifying a user's identity
···
112
112
/// This error occurs during RSVP operations such as creation, updating,
113
113
/// or retrieval of RSVPs.
114
114
#[error(transparent)]
115
-
RSVP(#[from] RSVPError),
115
+
Rsvp(#[from] RSVPError),
116
116
117
117
/// Cache operation errors.
118
118
///
···
121
121
#[error(transparent)]
122
122
Cache(#[from] crate::storage::errors::CacheError),
123
123
124
-
/// JSON Web Key errors.
125
-
///
126
-
/// This error occurs when there are issues with cryptographic keys,
127
-
/// such as missing or invalid keys.
128
-
#[error(transparent)]
129
-
JwkError(#[from] crate::jose_errors::JwkError),
130
-
131
-
/// JSON Object Signing and Encryption errors.
132
-
///
133
-
/// This error occurs when there are issues with JWT operations,
134
-
/// such as signature validation or token creation.
135
-
#[error(transparent)]
136
-
JoseError(#[from] crate::jose_errors::JoseError),
137
-
138
124
/// Configuration errors.
139
125
///
140
126
/// This error occurs when there are issues with application configuration,
141
127
/// such as missing environment variables or invalid settings.
142
128
#[error(transparent)]
143
129
ConfigError(#[from] crate::config_errors::ConfigError),
144
-
145
-
/// Data encoding/decoding errors.
146
-
///
147
-
/// This error occurs when there are issues with data encoding or decoding,
148
-
/// such as invalid Base64 or JSON parsing problems.
149
-
#[error(transparent)]
150
-
EncodingError(#[from] crate::encoding_errors::EncodingError),
151
-
152
-
/// AT Protocol URI errors.
153
-
///
154
-
/// This error occurs when there are issues with AT Protocol URIs,
155
-
/// such as malformed DIDs or invalid handles.
156
-
#[error(transparent)]
157
-
UriError(#[from] crate::atproto::errors::UriError),
158
130
159
131
/// Database storage errors.
160
132
///
+6
-6
src/http/event_form.rs
+6
-6
src/http/event_form.rs
···
6
6
use super::cache_countries::cached_countries;
7
7
8
8
#[derive(Debug, Error)]
9
-
pub enum BuildEventError {
9
+
pub(crate) enum BuildEventError {
10
10
#[error("error-event-builder-1 Invalid Name")]
11
11
InvalidName,
12
12
···
60
60
}
61
61
62
62
#[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)]
63
-
pub enum BuildEventContentState {
63
+
pub(crate) enum BuildEventContentState {
64
64
#[default]
65
65
Reset,
66
66
Selecting,
···
68
68
}
69
69
70
70
#[derive(Serialize, Deserialize, Debug, Clone)]
71
-
pub struct BuildStartsForm {
71
+
pub(crate) struct BuildStartsForm {
72
72
pub build_state: Option<BuildEventContentState>,
73
73
74
74
pub tz: Option<String>,
···
99
99
}
100
100
101
101
#[derive(Serialize, Deserialize, Debug, Clone)]
102
-
pub struct BuildLocationForm {
102
+
pub(crate) struct BuildLocationForm {
103
103
pub build_state: Option<BuildEventContentState>,
104
104
105
105
pub location_country: Option<String>,
···
122
122
}
123
123
124
124
#[derive(Serialize, Deserialize, Debug, Clone)]
125
-
pub struct BuildLinkForm {
125
+
pub(crate) struct BuildLinkForm {
126
126
pub build_state: Option<BuildEventContentState>,
127
127
128
128
pub link_name: Option<String>,
···
133
133
}
134
134
135
135
#[derive(Serialize, Deserialize, Debug, Clone)]
136
-
pub struct BuildEventForm {
136
+
pub(crate) struct BuildEventForm {
137
137
pub build_state: Option<BuildEventContentState>,
138
138
139
139
pub name: Option<String>,
+12
-12
src/http/event_view.rs
+12
-12
src/http/event_view.rs
···
1
1
use std::collections::HashSet;
2
+
use std::str::FromStr;
2
3
3
4
use ammonia::Builder;
4
5
use anyhow::Result;
6
+
use atproto_record::aturi::ATURI;
5
7
use chrono_tz::Tz;
6
8
use cityhasher::HashMap;
7
9
use serde::Serialize;
···
9
11
use crate::http::errors::EventViewError;
10
12
11
13
use crate::{
12
-
atproto::{
13
-
lexicon::{
14
-
community::lexicon::calendar::event::NSID as LexiconCommunityEventNSID,
15
-
events::smokesignal::calendar::event::NSID as SmokeSignalEventNSID,
16
-
},
17
-
uri::parse_aturi,
14
+
atproto::lexicon::{
15
+
community::lexicon::calendar::event::NSID as LexiconCommunityEventNSID,
16
+
events::smokesignal::calendar::event::NSID as SmokeSignalEventNSID,
18
17
},
19
18
http::utils::truncate_text,
20
19
storage::{
···
23
22
count_event_rsvps, extract_event_details, get_event_rsvp_counts,
24
23
model::{Event, EventWithRole},
25
24
},
26
-
handle::{handles_by_did, model::Handle},
25
+
identity_profile::{handles_by_did, model::IdentityProfile},
27
26
StoragePool,
28
27
},
29
28
};
···
58
57
pub links: Vec<(String, Option<String>)>, // (uri, name)
59
58
}
60
59
61
-
impl TryFrom<(Option<&Handle>, Option<&Handle>, &Event)> for EventView {
60
+
impl TryFrom<(Option<&IdentityProfile>, Option<&IdentityProfile>, &Event)> for EventView {
62
61
type Error = anyhow::Error;
63
62
64
63
fn try_from(
65
-
(viewer, organizer, event): (Option<&Handle>, Option<&Handle>, &Event),
64
+
(viewer, organizer, event): (Option<&IdentityProfile>, Option<&IdentityProfile>, &Event),
66
65
) -> Result<Self, Self::Error> {
67
66
// Time zones are used to display date/time values from the perspective
68
67
// of the viewer. The timezone is selected with this priority:
···
79
78
}
80
79
.unwrap_or(Tz::UTC);
81
80
82
-
let (repository, collection, rkey) = parse_aturi(event.aturi.as_str())?;
81
+
let aturi = ATURI::from_str(&event.aturi)?;
82
+
let (repository, collection, rkey) = (aturi.authority, aturi.collection, aturi.record_key);
83
83
84
84
// We now support both community and smokesignal event formats
85
85
if collection != LexiconCommunityEventNSID && collection != SmokeSignalEventNSID {
···
221
221
}
222
222
}
223
223
224
-
pub async fn hydrate_event_organizers(
224
+
pub(crate) async fn hydrate_event_organizers(
225
225
pool: &StoragePool,
226
226
events: &[EventWithRole],
227
-
) -> Result<HashMap<std::string::String, Handle>> {
227
+
) -> Result<HashMap<std::string::String, IdentityProfile>> {
228
228
if events.is_empty() {
229
229
return Ok(HashMap::default());
230
230
}
+8
-8
src/http/handle_admin_denylist.rs
+8
-8
src/http/handle_admin_denylist.rs
···
21
21
};
22
22
23
23
#[derive(Debug, Deserialize)]
24
-
pub struct DenylistAddForm {
25
-
pub subject: String,
26
-
pub reason: String,
24
+
pub(crate) struct DenylistAddForm {
25
+
subject: String,
26
+
reason: String,
27
27
}
28
28
29
29
#[derive(Debug, Deserialize)]
30
-
pub struct DenylistRemoveForm {
31
-
pub subject: String,
30
+
pub(crate) struct DenylistRemoveForm {
31
+
subject: String,
32
32
}
33
33
34
-
pub async fn handle_admin_denylist(
34
+
pub(crate) async fn handle_admin_denylist(
35
35
admin_ctx: AdminRequestContext,
36
36
pagination: Query<Pagination>,
37
37
) -> Result<impl IntoResponse, WebError> {
···
78
78
.into_response())
79
79
}
80
80
81
-
pub async fn handle_admin_denylist_add(
81
+
pub(crate) async fn handle_admin_denylist_add(
82
82
admin_ctx: AdminRequestContext,
83
83
Form(form): Form<DenylistAddForm>,
84
84
) -> Result<impl IntoResponse, WebError> {
···
103
103
Ok(Redirect::to("/admin/denylist").into_response())
104
104
}
105
105
106
-
pub async fn handle_admin_denylist_remove(
106
+
pub(crate) async fn handle_admin_denylist_remove(
107
107
admin_ctx: AdminRequestContext,
108
108
Form(form): Form<DenylistRemoveForm>,
109
109
) -> Result<impl IntoResponse, WebError> {
+3
-3
src/http/handle_admin_event.rs
+3
-3
src/http/handle_admin_event.rs
···
12
12
};
13
13
14
14
#[derive(Deserialize)]
15
-
pub struct EventRecordQuery {
16
-
pub aturi: String,
15
+
pub(crate) struct EventRecordQuery {
16
+
pub(crate) aturi: String,
17
17
}
18
18
19
-
pub async fn handle_admin_event(
19
+
pub(crate) async fn handle_admin_event(
20
20
admin_ctx: AdminRequestContext,
21
21
Query(query): Query<EventRecordQuery>,
22
22
) -> Result<impl IntoResponse, WebError> {
+1
-1
src/http/handle_admin_events.rs
+1
-1
src/http/handle_admin_events.rs
+3
-3
src/http/handle_admin_handles.rs
+3
-3
src/http/handle_admin_handles.rs
···
16
16
pagination::{Pagination, PaginationView},
17
17
},
18
18
select_template,
19
-
storage::handle::{handle_list, handle_nuke},
19
+
storage::identity_profile::{handle_list, handle_nuke},
20
20
};
21
21
22
-
pub async fn handle_admin_handles(
22
+
pub(crate) async fn handle_admin_handles(
23
23
admin_ctx: AdminRequestContext,
24
24
pagination: Query<Pagination>,
25
25
) -> Result<impl IntoResponse, WebError> {
···
66
66
.into_response())
67
67
}
68
68
69
-
pub async fn handle_admin_nuke_identity(
69
+
pub(crate) async fn handle_admin_nuke_identity(
70
70
admin_ctx: AdminRequestContext,
71
71
HxRequest(hx_request): HxRequest,
72
72
Path(did): Path<String>,
+62
-69
src/http/handle_admin_import_event.rs
+62
-69
src/http/handle_admin_import_event.rs
···
1
+
use std::str::FromStr;
2
+
1
3
use anyhow::Result;
4
+
use atproto_identity::resolve::IdentityResolver;
5
+
use atproto_record::aturi::ATURI;
2
6
use axum::{
3
7
extract::Form,
4
8
response::{IntoResponse, Redirect},
···
6
10
use serde::Deserialize;
7
11
8
12
use crate::{
9
-
atproto::{
10
-
lexicon::{
11
-
community::lexicon::calendar::event::Event as CommunityEventLexicon,
12
-
events::smokesignal::calendar::event::{Event as SmokeSignalEvent, EventResponse},
13
-
},
14
-
uri::parse_aturi,
13
+
atproto::lexicon::{
14
+
community::lexicon::calendar::event::Event as CommunityEventLexicon,
15
+
events::smokesignal::calendar::event::{Event as SmokeSignalEvent, EventResponse},
15
16
},
16
17
contextual_error,
17
18
http::{
18
19
context::{admin_template_context, AdminRequestContext},
19
20
errors::{AdminImportEventError, CommonError, LoginError, WebError},
20
21
},
21
-
resolve::{parse_input, resolve_subject, InputType},
22
22
select_template,
23
-
storage::{event::event_insert_with_metadata, handle::handle_warm_up},
23
+
storage::{event::event_insert_with_metadata, identity_profile::handle_warm_up},
24
24
};
25
25
26
26
#[derive(Deserialize)]
27
-
pub struct ImportEventForm {
28
-
pub aturi: String,
27
+
pub(crate) struct ImportEventForm {
28
+
pub(crate) aturi: String,
29
29
}
30
30
31
-
pub async fn handle_admin_import_event(
31
+
pub(crate) async fn handle_admin_import_event(
32
32
admin_ctx: AdminRequestContext,
33
+
identity_resolver: IdentityResolver,
33
34
Form(form): Form<ImportEventForm>,
34
35
) -> Result<impl IntoResponse, WebError> {
35
36
// Admin access is already verified by the extractor
···
43
44
44
45
// Parse the AT-URI
45
46
let aturi = form.aturi.trim();
46
-
let (repository, collection, rkey) = match parse_aturi(aturi) {
47
-
Ok(parsed) => parsed,
47
+
let (repository, collection, rkey) = match ATURI::from_str(aturi) {
48
+
Ok(aturi) => (aturi.authority, aturi.collection, aturi.record_key),
48
49
Err(_err) => {
49
50
return contextual_error!(
50
51
admin_ctx.web_context,
···
71
72
}
72
73
};
73
74
74
-
// Resolve the DID for the repository
75
-
let input_type = match parse_input(&repository) {
76
-
Ok(input) => input,
77
-
Err(_err) => {
75
+
let document = match identity_resolver.resolve(&repository).await {
76
+
Ok(value) => value,
77
+
Err(err) => {
78
78
return contextual_error!(
79
79
admin_ctx.web_context,
80
80
admin_ctx.language,
81
81
error_template,
82
82
default_context,
83
-
CommonError::FailedToParse
83
+
err
84
84
);
85
85
}
86
86
};
87
87
88
-
let did = match input_type {
89
-
InputType::Handle(handle) => {
90
-
match resolve_subject(
91
-
&admin_ctx.web_context.http_client,
92
-
&admin_ctx.web_context.dns_resolver,
93
-
&handle,
94
-
)
95
-
.await
96
-
{
97
-
Ok(did) => did,
98
-
Err(_err) => {
99
-
return contextual_error!(
100
-
admin_ctx.web_context,
101
-
admin_ctx.language,
102
-
error_template,
103
-
default_context,
104
-
CommonError::FailedToParse
105
-
);
106
-
}
107
-
}
108
-
}
109
-
InputType::Plc(did) | InputType::Web(did) => did,
110
-
};
111
-
112
-
// Get the DID document to find the PDS endpoint
113
-
let did_doc = match crate::did::plc::query(
114
-
&admin_ctx.web_context.http_client,
115
-
&admin_ctx.web_context.config.plc_hostname,
116
-
&did,
117
-
)
118
-
.await
88
+
let handle = match document
89
+
.handles()
90
+
.ok_or(WebError::Login(LoginError::NoHandle))
119
91
{
120
-
Ok(doc) => doc,
121
-
Err(_err) => {
92
+
Ok(value) => value,
93
+
Err(err) => {
122
94
return contextual_error!(
123
95
admin_ctx.web_context,
124
96
admin_ctx.language,
125
97
error_template,
126
98
default_context,
127
-
CommonError::FailedToParse
99
+
err
128
100
);
129
101
}
130
102
};
131
103
132
-
// Insert the handle if it doesn't exist
133
-
if let Some(handle) = did_doc.primary_handle() {
134
-
if let Some(pds) = did_doc.pds_endpoint() {
135
-
if let Err(err) = handle_warm_up(&admin_ctx.web_context.pool, &did, handle, pds).await {
136
-
tracing::warn!("Failed to insert handle: {}", err);
137
-
}
138
-
}
139
-
}
140
-
141
-
// Get the PDS endpoint
142
-
let pds_endpoint = match did_doc.pds_endpoint() {
143
-
Some(endpoint) => endpoint,
144
-
None => {
104
+
let pds = match document
105
+
.pds_endpoints()
106
+
.first()
107
+
.cloned()
108
+
.ok_or(WebError::Login(LoginError::NoPDS))
109
+
{
110
+
Ok(value) => value,
111
+
Err(err) => {
145
112
return contextual_error!(
146
113
admin_ctx.web_context,
147
114
admin_ctx.language,
148
115
error_template,
149
116
default_context,
150
-
WebError::Login(LoginError::NoPDS)
117
+
err
151
118
);
152
119
}
153
120
};
154
121
122
+
if let Err(err) = admin_ctx
123
+
.web_context
124
+
.document_storage
125
+
.store_document(document.clone())
126
+
.await
127
+
{
128
+
return contextual_error!(
129
+
admin_ctx.web_context,
130
+
admin_ctx.language,
131
+
error_template,
132
+
default_context,
133
+
err
134
+
);
135
+
}
136
+
137
+
// Insert the handle if it doesn't exist
138
+
if let Err(err) = handle_warm_up(&admin_ctx.web_context.pool, &document.id, handle, pds).await {
139
+
return contextual_error!(
140
+
admin_ctx.web_context,
141
+
admin_ctx.language,
142
+
error_template,
143
+
default_context,
144
+
err
145
+
);
146
+
}
147
+
155
148
// Construct the XRPC request to get the record
156
149
let url = format!(
157
150
"{}/xrpc/com.atproto.repo.getRecord?repo={}&collection={}&rkey={}",
158
-
pds_endpoint, did, collection, rkey
151
+
pds, document.id, collection, rkey
159
152
);
160
153
161
154
let response = match admin_ctx.web_context.http_client.get(&url).send().await {
···
197
190
&admin_ctx.web_context.pool,
198
191
aturi,
199
192
&record.cid,
200
-
&did,
193
+
&document.id,
201
194
"events.smokesignal.calendar.event",
202
195
&record.value,
203
196
&name,
···
246
239
&admin_ctx.web_context.pool,
247
240
aturi,
248
241
&record.cid,
249
-
&did,
242
+
&document.id,
250
243
"community.lexicon.calendar.event",
251
244
&record.value,
252
245
&name,
+64
-71
src/http/handle_admin_import_rsvp.rs
+64
-71
src/http/handle_admin_import_rsvp.rs
···
1
+
use std::str::FromStr;
2
+
1
3
use anyhow::Result;
4
+
use atproto_identity::resolve::IdentityResolver;
5
+
use atproto_record::aturi::ATURI;
2
6
use axum::{
3
7
extract::Form,
4
8
response::{IntoResponse, Redirect},
···
7
11
use urlencoding;
8
12
9
13
use crate::{
10
-
atproto::{
11
-
lexicon::{
12
-
community::lexicon::calendar::rsvp::{
13
-
Rsvp as CommunityRsvpLexicon, RsvpStatus as CommunityRsvpStatusLexicon,
14
-
NSID as COMMUNITY_RSVP_NSID,
15
-
},
16
-
events::smokesignal::calendar::rsvp::{
17
-
Rsvp as SmokesignalRsvpLexicon, NSID as SMOKESIGNAL_RSVP_NSID,
18
-
},
14
+
atproto::lexicon::{
15
+
community::lexicon::calendar::rsvp::{
16
+
Rsvp as CommunityRsvpLexicon, RsvpStatus as CommunityRsvpStatusLexicon,
17
+
NSID as COMMUNITY_RSVP_NSID,
19
18
},
20
-
uri::parse_aturi,
19
+
events::smokesignal::calendar::rsvp::{
20
+
Rsvp as SmokesignalRsvpLexicon, NSID as SMOKESIGNAL_RSVP_NSID,
21
+
},
21
22
},
22
23
contextual_error,
23
24
http::{
24
25
context::{admin_template_context, AdminRequestContext},
25
26
errors::{AdminImportRsvpError, CommonError, LoginError, WebError},
26
27
},
27
-
resolve::{parse_input, resolve_subject, InputType},
28
28
select_template,
29
-
storage::{event::rsvp_insert_with_metadata, handle::handle_warm_up},
29
+
storage::{event::rsvp_insert_with_metadata, identity_profile::handle_warm_up},
30
30
};
31
31
32
32
#[derive(Deserialize)]
···
34
34
pub aturi: String,
35
35
}
36
36
37
-
pub async fn handle_admin_import_rsvp(
37
+
pub(crate) async fn handle_admin_import_rsvp(
38
38
admin_ctx: AdminRequestContext,
39
+
identity_resolver: IdentityResolver,
39
40
Form(form): Form<ImportRsvpForm>,
40
41
) -> Result<impl IntoResponse, WebError> {
41
42
// Admin access is already verified by the extractor
···
49
50
50
51
// Parse the AT-URI
51
52
let aturi = form.aturi.trim();
52
-
let (repository, collection, rkey) = match parse_aturi(aturi) {
53
-
Ok(parsed) => parsed,
53
+
let (repository, collection, rkey) = match ATURI::from_str(aturi) {
54
+
Ok(aturi) => (aturi.authority, aturi.collection, aturi.record_key),
54
55
Err(_err) => {
55
56
return contextual_error!(
56
57
admin_ctx.web_context,
···
77
78
}
78
79
};
79
80
80
-
// Resolve the DID for the repository
81
-
let input_type = match parse_input(&repository) {
82
-
Ok(input) => input,
83
-
Err(_err) => {
81
+
let document = match identity_resolver.resolve(&repository).await {
82
+
Ok(value) => value,
83
+
Err(err) => {
84
84
return contextual_error!(
85
85
admin_ctx.web_context,
86
86
admin_ctx.language,
87
87
error_template,
88
88
default_context,
89
-
CommonError::FailedToParse
89
+
err
90
90
);
91
91
}
92
92
};
93
93
94
-
let did = match input_type {
95
-
InputType::Handle(handle) => {
96
-
match resolve_subject(
97
-
&admin_ctx.web_context.http_client,
98
-
&admin_ctx.web_context.dns_resolver,
99
-
&handle,
100
-
)
101
-
.await
102
-
{
103
-
Ok(did) => did,
104
-
Err(_err) => {
105
-
return contextual_error!(
106
-
admin_ctx.web_context,
107
-
admin_ctx.language,
108
-
error_template,
109
-
default_context,
110
-
CommonError::FailedToParse
111
-
);
112
-
}
113
-
}
114
-
}
115
-
InputType::Plc(did) | InputType::Web(did) => did,
116
-
};
117
-
118
-
// Get the DID document to find the PDS endpoint
119
-
let did_doc = match crate::did::plc::query(
120
-
&admin_ctx.web_context.http_client,
121
-
&admin_ctx.web_context.config.plc_hostname,
122
-
&did,
123
-
)
124
-
.await
94
+
let handle = match document
95
+
.handles()
96
+
.ok_or(WebError::Login(LoginError::NoHandle))
125
97
{
126
-
Ok(doc) => doc,
127
-
Err(_err) => {
98
+
Ok(value) => value,
99
+
Err(err) => {
128
100
return contextual_error!(
129
101
admin_ctx.web_context,
130
102
admin_ctx.language,
131
103
error_template,
132
104
default_context,
133
-
CommonError::FailedToParse
105
+
err
134
106
);
135
107
}
136
108
};
137
109
138
-
// Insert the handle if it doesn't exist
139
-
if let Some(handle) = did_doc.primary_handle() {
140
-
if let Some(pds) = did_doc.pds_endpoint() {
141
-
if let Err(err) = handle_warm_up(&admin_ctx.web_context.pool, &did, handle, pds).await {
142
-
tracing::warn!("Failed to insert handle: {}", err);
143
-
}
144
-
}
145
-
}
146
-
147
-
// Get the PDS endpoint
148
-
let pds_endpoint = match did_doc.pds_endpoint() {
149
-
Some(endpoint) => endpoint,
150
-
None => {
110
+
let pds = match document
111
+
.pds_endpoints()
112
+
.first()
113
+
.cloned()
114
+
.ok_or(WebError::Login(LoginError::NoPDS))
115
+
{
116
+
Ok(value) => value,
117
+
Err(err) => {
151
118
return contextual_error!(
152
119
admin_ctx.web_context,
153
120
admin_ctx.language,
154
121
error_template,
155
122
default_context,
156
-
WebError::Login(LoginError::NoPDS)
123
+
err
157
124
);
158
125
}
159
126
};
160
127
128
+
if let Err(err) = admin_ctx
129
+
.web_context
130
+
.document_storage
131
+
.store_document(document.clone())
132
+
.await
133
+
{
134
+
return contextual_error!(
135
+
admin_ctx.web_context,
136
+
admin_ctx.language,
137
+
error_template,
138
+
default_context,
139
+
err
140
+
);
141
+
}
142
+
143
+
// Insert the handle if it doesn't exist
144
+
if let Err(err) = handle_warm_up(&admin_ctx.web_context.pool, &document.id, handle, pds).await {
145
+
return contextual_error!(
146
+
admin_ctx.web_context,
147
+
admin_ctx.language,
148
+
error_template,
149
+
default_context,
150
+
err
151
+
);
152
+
}
153
+
161
154
// Construct the XRPC request to get the record
162
155
let url = format!(
163
156
"{}/xrpc/com.atproto.repo.getRecord?repo={}&collection={}&rkey={}",
164
-
pds_endpoint, did, collection, rkey
157
+
pds, document.id, collection, rkey
165
158
);
166
159
167
160
let response = match admin_ctx.web_context.http_client.get(&url).send().await {
···
237
230
crate::storage::event::RsvpInsertParams {
238
231
aturi,
239
232
cid: &cid,
240
-
did: &did,
233
+
did: &document.id,
241
234
lexicon: COMMUNITY_RSVP_NSID,
242
235
record: &rsvp_value,
243
236
event_aturi: &event_aturi,
···
284
277
crate::storage::event::RsvpInsertParams {
285
278
aturi,
286
279
cid: &cid,
287
-
did: &did,
280
+
did: &document.id,
288
281
lexicon: SMOKESIGNAL_RSVP_NSID,
289
282
record: &rsvp_value,
290
283
event_aturi: &event_aturi,
+1
-1
src/http/handle_admin_index.rs
+1
-1
src/http/handle_admin_index.rs
+3
-3
src/http/handle_admin_rsvp.rs
+3
-3
src/http/handle_admin_rsvp.rs
···
18
18
};
19
19
20
20
#[derive(Deserialize)]
21
-
pub struct RsvpRecordQuery {
22
-
pub aturi: String,
21
+
pub(crate) struct RsvpRecordQuery {
22
+
pub(crate) aturi: String,
23
23
}
24
24
25
-
pub async fn handle_admin_rsvp(
25
+
pub(crate) async fn handle_admin_rsvp(
26
26
State(web_context): State<WebContext>,
27
27
Language(language): Language,
28
28
Cached(auth): Cached<Auth>,
+2
-2
src/http/handle_admin_rsvps.rs
+2
-2
src/http/handle_admin_rsvps.rs
···
16
16
};
17
17
18
18
#[derive(Deserialize, Default)]
19
-
pub struct AdminRsvpsParams {
19
+
pub(crate) struct AdminRsvpsParams {
20
20
#[serde(flatten)]
21
21
pagination: Pagination,
22
22
···
25
25
imported_aturi: Option<String>,
26
26
}
27
27
28
-
pub async fn handle_admin_rsvps(
28
+
pub(crate) async fn handle_admin_rsvps(
29
29
admin_ctx: AdminRequestContext,
30
30
Query(params): Query<AdminRsvpsParams>,
31
31
) -> Result<impl IntoResponse, WebError> {
+68
-39
src/http/handle_create_event.rs
+68
-39
src/http/handle_create_event.rs
···
16
16
use minijinja::context as template_context;
17
17
use serde::Deserialize;
18
18
19
-
use crate::atproto::auth::SimpleOAuthSessionProvider;
20
-
use crate::atproto::client::CreateRecordRequest;
21
-
use crate::atproto::client::OAuthPdsClient;
19
+
use crate::atproto::auth::{
20
+
create_dpop_auth_from_aip_session, create_dpop_auth_from_oauth_session,
21
+
};
22
22
use crate::atproto::lexicon::community::lexicon::calendar::event::Event;
23
23
use crate::atproto::lexicon::community::lexicon::calendar::event::EventLink;
24
24
use crate::atproto::lexicon::community::lexicon::calendar::event::EventLocation;
···
26
26
use crate::atproto::lexicon::community::lexicon::calendar::event::Status;
27
27
use crate::atproto::lexicon::community::lexicon::calendar::event::NSID;
28
28
use crate::atproto::lexicon::community::lexicon::location::Address;
29
+
use crate::config::OAuthBackendConfig;
29
30
use crate::contextual_error;
30
31
use crate::http::context::WebContext;
31
32
use crate::http::errors::CommonError;
···
41
42
use crate::http::utils::url_from_aturi;
42
43
use crate::select_template;
43
44
use crate::storage::event::event_insert;
45
+
use atproto_client::com::atproto::repo::{
46
+
create_record, CreateRecordRequest, CreateRecordResponse,
47
+
};
44
48
45
49
use super::cache_countries::cached_countries;
46
50
use super::event_form::BuildLocationForm;
47
51
48
-
pub async fn handle_create_event(
52
+
pub(crate) async fn handle_create_event(
49
53
method: Method,
50
54
State(web_context): State<WebContext>,
51
55
Language(language): Language,
···
54
58
HxBoosted(hx_boosted): HxBoosted,
55
59
Form(mut build_event_form): Form<BuildEventForm>,
56
60
) -> Result<impl IntoResponse, WebError> {
57
-
let current_handle = auth.require(&web_context.config.destination_key, "/event")?;
61
+
let current_handle = auth.require(&web_context.config, "/event")?;
58
62
59
63
let is_development = cfg!(debug_assertions);
60
64
···
72
76
73
77
let error_template = select_template!(hx_boosted, hx_request, language);
74
78
75
-
let (default_tz, timezones) = supported_timezones(auth.0.as_ref());
79
+
let (default_tz, timezones) = supported_timezones(auth.profile());
76
80
77
81
if build_event_form.build_state.is_none() {
78
82
build_event_form.build_state = Some(BuildEventContentState::default());
···
217
221
_ => None,
218
222
});
219
223
220
-
// Ensure we have auth data for the API call
221
-
let auth_data = auth.1.ok_or(CommonError::NotAuthorized)?;
222
-
let client_auth: SimpleOAuthSessionProvider =
223
-
SimpleOAuthSessionProvider::try_from(auth_data)?;
224
-
225
-
let client = OAuthPdsClient {
226
-
http_client: &web_context.http_client,
227
-
pds: ¤t_handle.pds,
224
+
// Create DPoP auth based on OAuth backend type
225
+
let dpop_auth = match (&auth, &web_context.config.oauth_backend) {
226
+
(Auth::Pds { session, .. }, OAuthBackendConfig::ATProtocol { .. }) => {
227
+
create_dpop_auth_from_oauth_session(session)?
228
+
}
229
+
(Auth::Aip { access_token, .. }, OAuthBackendConfig::AIP { hostname, .. }) => {
230
+
create_dpop_auth_from_aip_session(
231
+
&web_context.http_client,
232
+
hostname,
233
+
access_token,
234
+
)
235
+
.await?
236
+
}
237
+
_ => return Err(CommonError::NotAuthorized.into()),
228
238
};
229
239
230
240
let locations = match &build_event_form.location_country {
···
276
286
swap_commit: None,
277
287
};
278
288
279
-
let create_record_result = client.create_record(&client_auth, event_record).await;
289
+
let create_record_result = create_record(
290
+
&web_context.http_client,
291
+
&dpop_auth,
292
+
¤t_handle.pds,
293
+
event_record,
294
+
)
295
+
.await;
280
296
281
-
if let Err(err) = create_record_result {
282
-
return contextual_error!(
283
-
web_context,
284
-
language,
285
-
error_template,
286
-
default_context,
287
-
err
288
-
);
289
-
}
290
-
291
-
// create_record_result is guaranteed to be Ok since we checked for Err above
292
-
let create_record_result = create_record_result?;
297
+
let create_record_response = match create_record_result {
298
+
Ok(CreateRecordResponse::StrongRef { uri, cid, .. }) => {
299
+
crate::atproto::lexicon::com::atproto::repo::StrongRef { uri, cid }
300
+
}
301
+
Ok(CreateRecordResponse::Error(err)) => {
302
+
return contextual_error!(
303
+
web_context,
304
+
language,
305
+
error_template,
306
+
default_context,
307
+
anyhow::anyhow!("Server error: {}", err.error_message())
308
+
);
309
+
}
310
+
Err(err) => {
311
+
return contextual_error!(
312
+
web_context,
313
+
language,
314
+
error_template,
315
+
default_context,
316
+
err
317
+
);
318
+
}
319
+
};
293
320
294
321
let event_insert_result = event_insert(
295
322
&web_context.pool,
296
-
&create_record_result.uri,
297
-
&create_record_result.cid,
323
+
&create_record_response.uri,
324
+
&create_record_response.cid,
298
325
¤t_handle.did,
299
326
NSID,
300
327
&the_record,
···
311
338
);
312
339
}
313
340
314
-
let event_url =
315
-
url_from_aturi(&web_context.config.external_base, &create_record_result.uri)?;
341
+
let event_url = url_from_aturi(
342
+
&web_context.config.external_base,
343
+
&create_record_response.uri,
344
+
)?;
316
345
317
346
return Ok(RenderHtml(
318
347
&render_template,
···
346
375
.into_response())
347
376
}
348
377
349
-
pub async fn handle_starts_at_builder(
378
+
pub(crate) async fn handle_starts_at_builder(
350
379
method: Method,
351
380
State(web_context): State<WebContext>,
352
381
Language(language): Language,
···
362
391
return Ok(StatusCode::BAD_REQUEST.into_response());
363
392
}
364
393
365
-
let (default_tz, timezones) = supported_timezones(auth.0.as_ref());
394
+
let (default_tz, timezones) = supported_timezones(auth.profile());
366
395
367
396
let is_development = cfg!(debug_assertions);
368
397
···
436
465
.into_response())
437
466
}
438
467
439
-
pub async fn handle_location_at_builder(
468
+
pub(crate) async fn handle_location_at_builder(
440
469
method: Method,
441
470
State(web_context): State<WebContext>,
442
471
Language(language): Language,
···
510
539
.into_response())
511
540
}
512
541
513
-
pub async fn handle_link_at_builder(
542
+
pub(crate) async fn handle_link_at_builder(
514
543
method: Method,
515
544
State(web_context): State<WebContext>,
516
545
Language(language): Language,
···
585
614
}
586
615
587
616
#[derive(Deserialize, Debug, Clone)]
588
-
pub struct LocationDataListHint {
589
-
pub location_country: Option<String>,
617
+
pub(crate) struct LocationDataListHint {
618
+
pub(crate) location_country: Option<String>,
590
619
}
591
620
592
-
pub async fn handle_location_datalist(
621
+
pub(crate) async fn handle_location_datalist(
593
622
State(web_context): State<WebContext>,
594
623
HxRequest(hx_request): HxRequest,
595
624
Query(location_country_hint): Query<LocationDataListHint>,
···
653
682
None
654
683
}
655
684
656
-
pub fn prefixed(mut set: BTreeMap<String, String>, prefix: &str) -> BTreeMap<String, String> {
685
+
fn prefixed(mut set: BTreeMap<String, String>, prefix: &str) -> BTreeMap<String, String> {
657
686
let mut set = set.split_off(prefix);
658
687
659
688
if let Some(not_in_prefix) = upper_bound_from_prefix(prefix) {
+55
-28
src/http/handle_create_rsvp.rs
+55
-28
src/http/handle_create_rsvp.rs
···
9
9
use minijinja::context as template_context;
10
10
use std::hash::Hasher;
11
11
12
+
use crate::atproto::auth::{
13
+
create_dpop_auth_from_aip_session, create_dpop_auth_from_oauth_session,
14
+
};
15
+
use crate::config::OAuthBackendConfig;
12
16
use crate::{
13
-
atproto::{
14
-
auth::SimpleOAuthSessionProvider,
15
-
client::{OAuthPdsClient, PutRecordRequest},
16
-
lexicon::{
17
-
com::atproto::repo::StrongRef,
18
-
community::lexicon::calendar::rsvp::{Rsvp, RsvpStatus, NSID},
19
-
},
17
+
atproto::lexicon::{
18
+
com::atproto::repo::StrongRef,
19
+
community::lexicon::calendar::rsvp::{Rsvp, RsvpStatus, NSID},
20
20
},
21
21
contextual_error,
22
22
http::{
23
23
context::WebContext,
24
-
errors::WebError,
24
+
errors::{CommonError, WebError},
25
25
middleware_auth::Auth,
26
26
middleware_i18n::Language,
27
27
rsvp_form::{BuildRSVPForm, BuildRsvpContentState},
···
30
30
select_template,
31
31
storage::event::rsvp_insert,
32
32
};
33
+
use atproto_client::com::atproto::repo::{put_record, PutRecordRequest, PutRecordResponse};
33
34
34
-
pub async fn handle_create_rsvp(
35
+
pub(crate) async fn handle_create_rsvp(
35
36
method: Method,
36
37
State(web_context): State<WebContext>,
37
38
Language(language): Language,
···
40
41
HxBoosted(hx_boosted): HxBoosted,
41
42
Form(mut build_rsvp_form): Form<BuildRSVPForm>,
42
43
) -> Result<impl IntoResponse, WebError> {
43
-
let current_handle = auth.require(&web_context.config.destination_key, "/rsvp")?;
44
+
let current_handle = auth.require(&web_context.config, "/rsvp")?;
44
45
45
46
let default_context = template_context! {
46
47
current_handle,
···
115
116
if !found_errors {
116
117
let now = Utc::now();
117
118
118
-
let client_auth: SimpleOAuthSessionProvider =
119
-
SimpleOAuthSessionProvider::try_from(auth.1.unwrap())?;
120
-
121
-
let client = OAuthPdsClient {
122
-
http_client: &web_context.http_client,
123
-
pds: ¤t_handle.pds,
119
+
// Create DPoP auth based on OAuth backend type
120
+
let dpop_auth = match (&auth, &web_context.config.oauth_backend) {
121
+
(Auth::Pds { session, .. }, OAuthBackendConfig::ATProtocol { .. }) => {
122
+
create_dpop_auth_from_oauth_session(session)?
123
+
}
124
+
(Auth::Aip { access_token, .. }, OAuthBackendConfig::AIP { hostname, .. }) => {
125
+
create_dpop_auth_from_aip_session(
126
+
&web_context.http_client,
127
+
hostname,
128
+
access_token,
129
+
)
130
+
.await?
131
+
}
132
+
_ => return Err(CommonError::NotAuthorized.into()),
124
133
};
125
134
126
135
let subject = StrongRef {
···
156
165
swap_record: None,
157
166
};
158
167
159
-
let put_record_result = client.put_record(&client_auth, rsvp_record).await;
168
+
let put_record_result = put_record(
169
+
&web_context.http_client,
170
+
&dpop_auth,
171
+
¤t_handle.pds,
172
+
rsvp_record,
173
+
)
174
+
.await;
160
175
161
-
if let Err(err) = put_record_result {
162
-
return contextual_error!(
163
-
web_context,
164
-
language,
165
-
error_template,
166
-
default_context,
167
-
err
168
-
);
169
-
}
170
-
171
-
let create_record_result = put_record_result.unwrap();
176
+
let create_record_result = match put_record_result {
177
+
Ok(PutRecordResponse::StrongRef { uri, cid, .. }) => {
178
+
crate::atproto::lexicon::com::atproto::repo::StrongRef { uri, cid }
179
+
}
180
+
Ok(PutRecordResponse::Error(err)) => {
181
+
return contextual_error!(
182
+
web_context,
183
+
language,
184
+
error_template,
185
+
default_context,
186
+
anyhow::anyhow!("Server error: {}", err.error_message())
187
+
);
188
+
}
189
+
Err(err) => {
190
+
return contextual_error!(
191
+
web_context,
192
+
language,
193
+
error_template,
194
+
default_context,
195
+
err
196
+
);
197
+
}
198
+
};
172
199
173
200
let rsvp_insert_result = rsvp_insert(
174
201
&web_context.pool,
+214
src/http/handle_delete_event.rs
+214
src/http/handle_delete_event.rs
···
1
+
use anyhow::{Result, anyhow};
2
+
use axum::{extract::Path, response::IntoResponse};
3
+
use axum_extra::extract::Form;
4
+
use axum_template::RenderHtml;
5
+
use http::StatusCode;
6
+
use minijinja::context as template_context;
7
+
use serde::{Deserialize, Serialize};
8
+
9
+
use atproto_client::com::atproto::repo::{delete_record, DeleteRecordRequest};
10
+
11
+
use crate::{
12
+
atproto::{
13
+
auth::{create_dpop_auth_from_oauth_session, create_dpop_auth_from_aip_session},
14
+
lexicon::community::lexicon::calendar::event::NSID as LexiconCommunityEventNSID,
15
+
},
16
+
contextual_error,
17
+
http::{context::UserRequestContext, errors::WebError, middleware_auth::Auth},
18
+
select_template,
19
+
storage::{event::{event_delete, event_exists}},
20
+
config::OAuthBackendConfig,
21
+
};
22
+
23
+
#[derive(Debug, Deserialize, Serialize)]
24
+
pub struct DeleteEventForm {
25
+
pub confirm: Option<String>,
26
+
}
27
+
28
+
pub(crate) async fn handle_delete_event(
29
+
ctx: UserRequestContext,
30
+
Path((handle_slug, event_rkey)): Path<(String, String)>,
31
+
Form(form): Form<DeleteEventForm>,
32
+
) -> Result<impl IntoResponse, WebError> {
33
+
let current_handle = match &ctx.auth {
34
+
Auth::Pds { profile, .. } | Auth::Aip { profile, .. } => profile.clone(),
35
+
Auth::Unauthenticated => {
36
+
return Ok(StatusCode::FORBIDDEN.into_response());
37
+
}
38
+
};
39
+
40
+
let default_context = template_context! {
41
+
current_handle,
42
+
language => ctx.language.to_string(),
43
+
canonical_url => format!("https://{}/{}/{}/delete", ctx.web_context.config.external_base, handle_slug, event_rkey),
44
+
event_url => format!("/{}/{}", handle_slug, event_rkey),
45
+
};
46
+
47
+
let error_template = select_template!(false, false, ctx.language);
48
+
let render_template = select_template!("delete_event", false, false, ctx.language);
49
+
50
+
let lookup_aturi = format!(
51
+
"at://{}/{}/{}",
52
+
current_handle.did, LexiconCommunityEventNSID, event_rkey
53
+
);
54
+
55
+
// Get the event
56
+
let event_exists = event_exists(&ctx.web_context.pool, &lookup_aturi).await;
57
+
if let Err(err) = event_exists {
58
+
return contextual_error!(
59
+
ctx.web_context,
60
+
ctx.language,
61
+
error_template,
62
+
default_context,
63
+
err
64
+
);
65
+
}
66
+
67
+
if !event_exists.unwrap() {
68
+
return contextual_error!(
69
+
ctx.web_context,
70
+
ctx.language,
71
+
error_template,
72
+
default_context,
73
+
anyhow!(
74
+
"error-delete-event-1 identity does not have event with that AT-URI: {lookup_aturi}"
75
+
)
76
+
);
77
+
}
78
+
79
+
// If form is submitted with confirmation or it's a non-HTMX POST, proceed with deletion
80
+
if form.confirm.as_deref() == Some("true") {
81
+
// Create DPoP authentication based on auth type
82
+
let dpop_auth = match &ctx.auth {
83
+
Auth::Pds { session, .. } => {
84
+
match create_dpop_auth_from_oauth_session(session) {
85
+
Ok(auth) => auth,
86
+
Err(err) => {
87
+
tracing::error!("Failed to create DPoP auth from OAuth session: {}", err);
88
+
return contextual_error!(
89
+
ctx.web_context,
90
+
ctx.language,
91
+
error_template,
92
+
default_context,
93
+
err,
94
+
StatusCode::INTERNAL_SERVER_ERROR
95
+
);
96
+
}
97
+
}
98
+
}
99
+
Auth::Aip { .. } => {
100
+
match &ctx.web_context.config.oauth_backend {
101
+
OAuthBackendConfig::AIP { hostname, .. } => {
102
+
let access_token = match &ctx.auth {
103
+
Auth::Aip { access_token, .. } => access_token,
104
+
_ => unreachable!("We already matched on Auth::Aip"),
105
+
};
106
+
107
+
match create_dpop_auth_from_aip_session(
108
+
&ctx.web_context.http_client,
109
+
hostname,
110
+
access_token,
111
+
).await {
112
+
Ok(auth) => auth,
113
+
Err(err) => {
114
+
tracing::error!("Failed to create DPoP auth from AIP session: {}", err);
115
+
return contextual_error!(
116
+
ctx.web_context,
117
+
ctx.language,
118
+
error_template,
119
+
default_context,
120
+
err,
121
+
StatusCode::INTERNAL_SERVER_ERROR
122
+
);
123
+
}
124
+
}
125
+
}
126
+
_ => {
127
+
tracing::error!("AIP auth found but OAuth backend is not AIP");
128
+
return contextual_error!(
129
+
ctx.web_context,
130
+
ctx.language,
131
+
error_template,
132
+
default_context,
133
+
anyhow!("Authentication configuration mismatch"),
134
+
StatusCode::INTERNAL_SERVER_ERROR
135
+
);
136
+
}
137
+
}
138
+
}
139
+
Auth::Unauthenticated => {
140
+
// This should not happen due to the check above
141
+
return Ok(StatusCode::FORBIDDEN.into_response());
142
+
}
143
+
};
144
+
145
+
// Delete from PDS first
146
+
let delete_record_request = DeleteRecordRequest {
147
+
repo: current_handle.did.clone(),
148
+
collection: LexiconCommunityEventNSID.to_string(),
149
+
record_key: event_rkey.clone(),
150
+
swap_commit: None,
151
+
swap_record: None,
152
+
};
153
+
154
+
match delete_record(
155
+
&ctx.web_context.http_client,
156
+
&dpop_auth,
157
+
¤t_handle.pds,
158
+
delete_record_request,
159
+
).await {
160
+
Ok(_) => {
161
+
tracing::info!("Successfully deleted event from PDS: {}", lookup_aturi);
162
+
}
163
+
Err(err) => {
164
+
tracing::error!("Failed to delete event from PDS: {}", err);
165
+
return contextual_error!(
166
+
ctx.web_context,
167
+
ctx.language,
168
+
error_template,
169
+
default_context,
170
+
err,
171
+
StatusCode::INTERNAL_SERVER_ERROR
172
+
);
173
+
}
174
+
}
175
+
176
+
// Delete from local storage
177
+
if let Err(err) = event_delete(&ctx.web_context.pool, &lookup_aturi).await {
178
+
tracing::error!("Failed to delete event from local storage: {}", err);
179
+
return contextual_error!(
180
+
ctx.web_context,
181
+
ctx.language,
182
+
error_template,
183
+
default_context,
184
+
err,
185
+
StatusCode::INTERNAL_SERVER_ERROR
186
+
);
187
+
}
188
+
189
+
let render_template = select_template!("alert", false, false, ctx.language);
190
+
191
+
// TODO: Localize these strings.
192
+
193
+
return Ok(RenderHtml(
194
+
&render_template,
195
+
ctx.web_context.engine.clone(),
196
+
template_context! { ..default_context, ..template_context! {
197
+
message_type => "info",
198
+
message_title => "Event Deleted Successfully",
199
+
message => "The event has been deleted from your Personal Data Server (PDS) and removed this Smoke Signal instance.",
200
+
}},
201
+
)
202
+
.into_response());
203
+
}
204
+
205
+
Ok(RenderHtml(
206
+
&render_template,
207
+
ctx.web_context.engine.clone(),
208
+
template_context! { ..default_context, ..template_context! {
209
+
show_confirm => true,
210
+
event_url => format!("/{}/{}", handle_slug, event_rkey),
211
+
}},
212
+
)
213
+
.into_response())
214
+
}
+69
-38
src/http/handle_edit_event.rs
+69
-38
src/http/handle_edit_event.rs
···
7
7
use http::{Method, StatusCode};
8
8
use minijinja::context as template_context;
9
9
10
+
use crate::atproto::auth::{
11
+
create_dpop_auth_from_aip_session, create_dpop_auth_from_oauth_session,
12
+
};
13
+
use crate::config::OAuthBackendConfig;
14
+
use crate::http::middleware_auth::Auth;
10
15
use crate::{
11
16
atproto::{
12
-
auth::SimpleOAuthSessionProvider,
13
-
client::{OAuthPdsClient, PutRecordRequest},
14
17
lexicon::community::lexicon::calendar::event::{
15
18
Event as LexiconCommunityEvent, EventLink, EventLocation, Mode, NamedUri, Status,
16
19
NSID as LexiconCommunityEventNSID,
···
26
29
http::location_edit_status::{check_location_edit_status, LocationEditStatus},
27
30
http::timezones::supported_timezones,
28
31
http::utils::url_from_aturi,
29
-
resolve::{parse_input, InputType},
30
32
select_template,
31
33
storage::{
32
34
event::{event_get, event_update_with_metadata},
33
-
handle::{handle_for_did, handle_for_handle},
35
+
identity_profile::{handle_for_did, handle_for_handle},
34
36
},
35
37
};
38
+
use atproto_client::com::atproto::repo::{put_record, PutRecordRequest, PutRecordResponse};
36
39
37
-
pub async fn handle_edit_event(
40
+
pub(crate) async fn handle_edit_event(
38
41
ctx: UserRequestContext,
39
42
method: Method,
40
43
HxBoosted(hx_boosted): HxBoosted,
···
42
45
Path((handle_slug, event_rkey)): Path<(String, String)>,
43
46
Form(mut build_event_form): Form<BuildEventForm>,
44
47
) -> Result<impl IntoResponse, WebError> {
45
-
let current_handle = ctx
46
-
.auth
47
-
.require(&ctx.web_context.config.destination_key, "/")?;
48
+
let current_handle = ctx.auth.require(&ctx.web_context.config, "/")?;
48
49
49
50
let default_context = template_context! {
50
51
current_handle,
···
53
54
create_event => false,
54
55
submit_url => format!("/{}/{}/edit", handle_slug, event_rkey),
55
56
cancel_url => format!("/{}/{}", handle_slug, event_rkey),
57
+
delete_event_url => format!("https://{}/{}/{}/delete", ctx.web_context.config.external_base, handle_slug, event_rkey),
56
58
};
57
59
58
60
let render_template = select_template!("edit_event", hx_boosted, hx_request, ctx.language);
59
61
let error_template = select_template!(hx_boosted, hx_request, ctx.language);
60
62
61
63
// Lookup the event
62
-
let profile = match parse_input(&handle_slug) {
63
-
Ok(InputType::Handle(handle)) => handle_for_handle(&ctx.web_context.pool, &handle)
64
+
let profile = if handle_slug.starts_with("did:") {
65
+
handle_for_did(&ctx.web_context.pool, &handle_slug)
66
+
.await
67
+
.map_err(WebError::from)
68
+
} else {
69
+
let handle = if let Some(handle) = handle_slug.strip_prefix('@') {
70
+
handle
71
+
} else {
72
+
&handle_slug
73
+
};
74
+
handle_for_handle(&ctx.web_context.pool, handle)
64
75
.await
65
-
.map_err(WebError::from),
66
-
Ok(InputType::Plc(did) | InputType::Web(did)) => {
67
-
handle_for_did(&ctx.web_context.pool, &did)
68
-
.await
69
-
.map_err(WebError::from)
70
-
}
71
-
_ => Err(WebError::from(EditEventError::InvalidHandleSlug)),
76
+
.map_err(WebError::from)
72
77
}?;
73
78
74
79
let lookup_aturi = format!(
···
484
489
_ => None,
485
490
});
486
491
487
-
let client_auth: SimpleOAuthSessionProvider =
488
-
SimpleOAuthSessionProvider::try_from(ctx.auth.1.unwrap())?;
489
-
490
-
let client = OAuthPdsClient {
491
-
http_client: &ctx.web_context.http_client,
492
-
pds: ¤t_handle.pds,
492
+
// Create DPoP auth based on OAuth backend type
493
+
let dpop_auth = match (&ctx.auth, &ctx.web_context.config.oauth_backend) {
494
+
(Auth::Pds { session, .. }, OAuthBackendConfig::ATProtocol { .. }) => {
495
+
create_dpop_auth_from_oauth_session(session)?
496
+
}
497
+
(Auth::Aip { access_token, .. }, OAuthBackendConfig::AIP { hostname, .. }) => {
498
+
create_dpop_auth_from_aip_session(
499
+
&ctx.web_context.http_client,
500
+
hostname,
501
+
access_token,
502
+
)
503
+
.await?
504
+
}
505
+
_ => return Err(CommonError::NotAuthorized.into()),
493
506
};
494
507
495
508
// Extract existing locations and URIs from the original record
···
601
614
swap_record: Some(event.cid.clone()),
602
615
};
603
616
604
-
let update_record_result =
605
-
client.put_record(&client_auth, update_record_request).await;
606
-
607
-
if let Err(err) = update_record_result {
608
-
return contextual_error!(
609
-
ctx.web_context,
610
-
ctx.language,
611
-
error_template,
612
-
default_context,
613
-
err,
614
-
StatusCode::OK
615
-
);
616
-
}
617
+
let update_record_result = put_record(
618
+
&ctx.web_context.http_client,
619
+
&dpop_auth,
620
+
¤t_handle.pds,
621
+
update_record_request,
622
+
)
623
+
.await;
617
624
618
-
let update_record_result = update_record_result.unwrap();
625
+
let update_record_response = match update_record_result {
626
+
Ok(PutRecordResponse::StrongRef { uri, cid, .. }) => {
627
+
crate::atproto::lexicon::com::atproto::repo::StrongRef { uri, cid }
628
+
}
629
+
Ok(PutRecordResponse::Error(err)) => {
630
+
return contextual_error!(
631
+
ctx.web_context,
632
+
ctx.language,
633
+
error_template,
634
+
default_context,
635
+
anyhow::anyhow!("Server error: {}", err.error_message()),
636
+
StatusCode::OK
637
+
);
638
+
}
639
+
Err(err) => {
640
+
return contextual_error!(
641
+
ctx.web_context,
642
+
ctx.language,
643
+
error_template,
644
+
default_context,
645
+
err,
646
+
StatusCode::OK
647
+
);
648
+
}
649
+
};
619
650
620
651
let name = match &updated_record {
621
652
LexiconCommunityEvent::Current { name, .. } => name,
···
625
656
let event_update_result = event_update_with_metadata(
626
657
&ctx.web_context.pool,
627
658
&lookup_aturi,
628
-
&update_record_result.cid,
659
+
&update_record_response.cid,
629
660
&updated_record,
630
661
name,
631
662
)
+86
-42
src/http/handle_import.rs
+86
-42
src/http/handle_import.rs
···
9
9
use minijinja::context as template_context;
10
10
use serde::Deserialize;
11
11
12
+
use crate::atproto::auth::{
13
+
create_dpop_auth_from_aip_session, create_dpop_auth_from_oauth_session,
14
+
};
15
+
use crate::config::OAuthBackendConfig;
12
16
use crate::{
13
-
atproto::{
14
-
auth::SimpleOAuthSessionProvider,
15
-
client::{ListRecordsParams, OAuthPdsClient},
16
-
lexicon::{
17
-
community::lexicon::calendar::{
18
-
event::{Event as LexiconCommunityEvent, NSID as LEXICON_COMMUNITY_EVENT_NSID},
19
-
rsvp::{
20
-
Rsvp as LexiconCommunityRsvp, RsvpStatus as LexiconCommunityRsvpStatus,
21
-
NSID as LEXICON_COMMUNITY_RSVP_NSID,
22
-
},
17
+
atproto::lexicon::{
18
+
community::lexicon::calendar::{
19
+
event::{Event as LexiconCommunityEvent, NSID as LEXICON_COMMUNITY_EVENT_NSID},
20
+
rsvp::{
21
+
Rsvp as LexiconCommunityRsvp, RsvpStatus as LexiconCommunityRsvpStatus,
22
+
NSID as LEXICON_COMMUNITY_RSVP_NSID,
23
23
},
24
-
events::smokesignal::calendar::{
25
-
event::{Event as SmokeSignalEvent, NSID as SMOKESIGNAL_EVENT_NSID},
26
-
rsvp::{
27
-
Rsvp as SmokeSignalRsvp, RsvpStatus as SmokeSignalRsvpStatus,
28
-
NSID as SMOKESIGNAL_RSVP_NSID,
29
-
},
24
+
},
25
+
events::smokesignal::calendar::{
26
+
event::{Event as SmokeSignalEvent, NSID as SMOKESIGNAL_EVENT_NSID},
27
+
rsvp::{
28
+
Rsvp as SmokeSignalRsvp, RsvpStatus as SmokeSignalRsvpStatus,
29
+
NSID as SMOKESIGNAL_RSVP_NSID,
30
30
},
31
31
},
32
32
},
33
33
contextual_error,
34
34
http::{
35
35
context::WebContext,
36
-
errors::{ImportError, WebError},
36
+
errors::{CommonError, ImportError, WebError},
37
37
middleware_auth::Auth,
38
38
middleware_i18n::Language,
39
39
},
40
40
select_template,
41
41
storage::event::{event_insert_with_metadata, rsvp_insert_with_metadata},
42
42
};
43
+
use atproto_client::com::atproto::repo::{list_records, ListRecordsParams};
43
44
44
-
pub async fn handle_import(
45
+
pub(crate) async fn handle_import(
45
46
State(web_context): State<WebContext>,
46
47
Language(language): Language,
47
48
Cached(auth): Cached<Auth>,
48
49
HxRequest(hx_request): HxRequest,
49
50
HxBoosted(hx_boosted): HxBoosted,
50
51
) -> Result<impl IntoResponse, WebError> {
51
-
let current_handle = auth.require(&web_context.config.destination_key, "/import")?;
52
+
let current_handle = auth.require(&web_context.config, "/import")?;
52
53
53
54
let default_context = template_context! {
54
55
current_handle,
···
67
68
}
68
69
69
70
#[derive(Debug, Deserialize)]
70
-
pub struct ImportForm {
71
-
pub collection: Option<String>,
72
-
pub cursor: Option<String>,
71
+
pub(crate) struct ImportForm {
72
+
pub(crate) collection: Option<String>,
73
+
pub(crate) cursor: Option<String>,
73
74
}
74
75
75
-
pub async fn handle_import_submit(
76
+
pub(crate) async fn handle_import_submit(
76
77
State(web_context): State<WebContext>,
77
78
Language(language): Language,
78
79
Cached(auth): Cached<Auth>,
···
98
99
let collection = import_form.collection.unwrap_or(collections[0].to_string());
99
100
let cursor = import_form.cursor;
100
101
101
-
let client_auth: SimpleOAuthSessionProvider =
102
-
SimpleOAuthSessionProvider::try_from(auth.1.unwrap())?;
103
-
let client = OAuthPdsClient {
104
-
http_client: &web_context.http_client,
105
-
pds: ¤t_handle.pds,
102
+
// Create DPoP auth based on OAuth backend type
103
+
let dpop_auth = match (&auth, &web_context.config.oauth_backend) {
104
+
(Auth::Pds { session, .. }, OAuthBackendConfig::ATProtocol { .. }) => {
105
+
create_dpop_auth_from_oauth_session(session)?
106
+
}
107
+
(Auth::Aip { access_token, .. }, OAuthBackendConfig::AIP { hostname, .. }) => {
108
+
create_dpop_auth_from_aip_session(&web_context.http_client, hostname, access_token)
109
+
.await?
110
+
}
111
+
_ => return Err(CommonError::NotAuthorized.into()),
106
112
};
107
113
108
114
const LIMIT: u32 = 20;
109
115
110
116
// Set up list records parameters to fetch records
111
117
let list_params = ListRecordsParams {
112
-
repo: current_handle.did.clone(),
113
-
collection: collection.clone(),
114
118
limit: Some(LIMIT),
115
119
cursor,
116
120
reverse: None,
···
118
122
119
123
let render_context = match collection.as_str() {
120
124
LEXICON_COMMUNITY_EVENT_NSID => {
121
-
let results = client
122
-
.list_records::<LexiconCommunityEvent>(&client_auth, &list_params)
123
-
.await;
125
+
let results = list_records::<LexiconCommunityEvent>(
126
+
&web_context.http_client,
127
+
&dpop_auth,
128
+
¤t_handle.pds,
129
+
current_handle.did.clone(),
130
+
collection.clone(),
131
+
ListRecordsParams {
132
+
limit: Some(LIMIT),
133
+
cursor: list_params.cursor.clone(),
134
+
reverse: None,
135
+
},
136
+
)
137
+
.await;
124
138
match results {
125
139
Ok(list_records) => {
126
140
let mut items = vec![];
···
176
190
}
177
191
}
178
192
LEXICON_COMMUNITY_RSVP_NSID => {
179
-
let results = client
180
-
.list_records::<LexiconCommunityRsvp>(&client_auth, &list_params)
181
-
.await;
193
+
let results = list_records::<LexiconCommunityRsvp>(
194
+
&web_context.http_client,
195
+
&dpop_auth,
196
+
¤t_handle.pds,
197
+
current_handle.did.clone(),
198
+
collection.clone(),
199
+
ListRecordsParams {
200
+
limit: Some(LIMIT),
201
+
cursor: list_params.cursor.clone(),
202
+
reverse: None,
203
+
},
204
+
)
205
+
.await;
182
206
match results {
183
207
Ok(list_records) => {
184
208
let mut items = vec![];
···
248
272
}
249
273
}
250
274
SMOKESIGNAL_EVENT_NSID => {
251
-
let results = client
252
-
.list_records::<SmokeSignalEvent>(&client_auth, &list_params)
253
-
.await;
275
+
let results = list_records::<SmokeSignalEvent>(
276
+
&web_context.http_client,
277
+
&dpop_auth,
278
+
¤t_handle.pds,
279
+
current_handle.did.clone(),
280
+
collection.clone(),
281
+
ListRecordsParams {
282
+
limit: Some(LIMIT),
283
+
cursor: list_params.cursor.clone(),
284
+
reverse: None,
285
+
},
286
+
)
287
+
.await;
254
288
match results {
255
289
Ok(list_records) => {
256
290
let mut items = vec![];
···
306
340
}
307
341
}
308
342
SMOKESIGNAL_RSVP_NSID => {
309
-
let results = client
310
-
.list_records::<SmokeSignalRsvp>(&client_auth, &list_params)
311
-
.await;
343
+
let results = list_records::<SmokeSignalRsvp>(
344
+
&web_context.http_client,
345
+
&dpop_auth,
346
+
¤t_handle.pds,
347
+
current_handle.did.clone(),
348
+
collection.clone(),
349
+
ListRecordsParams {
350
+
limit: Some(LIMIT),
351
+
cursor: list_params.cursor.clone(),
352
+
reverse: None,
353
+
},
354
+
)
355
+
.await;
312
356
match results {
313
357
Ok(list_records) => {
314
358
let mut items = vec![];
+3
-3
src/http/handle_index.rs
+3
-3
src/http/handle_index.rs
···
46
46
}
47
47
}
48
48
49
-
pub async fn handle_index(
49
+
pub(crate) async fn handle_index(
50
50
State(web_context): State<WebContext>,
51
51
HxBoosted(hx_boosted): HxBoosted,
52
52
Language(language): Language,
···
88
88
.filter_map(|event_view| {
89
89
let organizer_maybe = organizer_handlers.get(&event_view.event.did);
90
90
let event_view =
91
-
EventView::try_from((auth.0.as_ref(), organizer_maybe, &event_view.event));
91
+
EventView::try_from((auth.profile(), organizer_maybe, &event_view.event));
92
92
93
93
match event_view {
94
94
Ok(event_view) => Some(event_view),
···
118
118
&render_template,
119
119
web_context.engine.clone(),
120
120
template_context! {
121
-
current_handle => auth.0,
121
+
current_handle => auth.profile(),
122
122
language => language.to_string(),
123
123
canonical_url => format!("https://{}/", web_context.config.external_base),
124
124
tab => tab.to_string(),
+73
-45
src/http/handle_migrate_event.rs
+73
-45
src/http/handle_migrate_event.rs
···
10
10
use minijinja::context as template_context;
11
11
use std::collections::HashMap;
12
12
13
+
use crate::atproto::auth::{
14
+
create_dpop_auth_from_aip_session, create_dpop_auth_from_oauth_session,
15
+
};
16
+
use crate::config::OAuthBackendConfig;
13
17
use crate::{
14
-
atproto::{
15
-
auth::SimpleOAuthSessionProvider,
16
-
client::{OAuthPdsClient, PutRecordRequest},
17
-
lexicon::{
18
-
community::lexicon::calendar::event::{
19
-
Event as CommunityEvent, EventLink, EventLocation as CommunityLocation, Mode,
20
-
Status, NSID as COMMUNITY_NSID,
21
-
},
22
-
community::lexicon::location,
23
-
events::smokesignal::calendar::event::{
24
-
Event as SmokeSignalEvent, Location as SmokeSignalLocation, PlaceLocation,
25
-
NSID as SMOKESIGNAL_NSID,
26
-
},
18
+
atproto::lexicon::{
19
+
community::lexicon::calendar::event::{
20
+
Event as CommunityEvent, EventLink, EventLocation as CommunityLocation, Mode, Status,
21
+
NSID as COMMUNITY_NSID,
22
+
},
23
+
community::lexicon::location,
24
+
events::smokesignal::calendar::event::{
25
+
Event as SmokeSignalEvent, Location as SmokeSignalLocation, PlaceLocation,
26
+
NSID as SMOKESIGNAL_NSID,
27
27
},
28
28
},
29
29
contextual_error,
···
31
31
context::WebContext, errors::MigrateEventError, errors::WebError, middleware_auth::Auth,
32
32
middleware_i18n::Language, utils::url_from_aturi,
33
33
},
34
-
resolve::{parse_input, InputType},
35
34
select_template,
36
35
storage::{
37
36
event::{event_get, event_insert_with_metadata},
38
-
handle::{handle_for_did, handle_for_handle, model::Handle},
37
+
identity_profile::{handle_for_did, handle_for_handle, model::IdentityProfile},
39
38
},
40
39
};
40
+
use atproto_client::com::atproto::repo::{put_record, PutRecordRequest, PutRecordResponse};
41
41
42
-
pub async fn handle_migrate_event(
42
+
pub(crate) async fn handle_migrate_event(
43
43
State(web_context): State<WebContext>,
44
44
HxBoosted(hx_boosted): HxBoosted,
45
45
Language(language): Language,
···
47
47
HxRequest(hx_request): HxRequest,
48
48
Path((handle_slug, event_rkey)): Path<(String, String)>,
49
49
) -> Result<impl IntoResponse, WebError> {
50
-
let current_handle = auth.require(&web_context.config.destination_key, "/")?;
50
+
let current_handle = auth.require(&web_context.config, "/")?;
51
51
52
52
// Configure templates
53
53
let default_context = template_context! {
···
59
59
let error_template = select_template!(hx_boosted, hx_request, language);
60
60
61
61
// Lookup the user handle/profile
62
-
let profile: Result<Handle> = match parse_input(&handle_slug) {
63
-
Ok(InputType::Handle(handle)) => handle_for_handle(&web_context.pool, &handle)
62
+
let profile: Result<IdentityProfile> = if handle_slug.starts_with("did:") {
63
+
handle_for_did(&web_context.pool, &handle_slug)
64
64
.await
65
-
.map_err(|err| err.into()),
66
-
Ok(InputType::Plc(did) | InputType::Web(did)) => handle_for_did(&web_context.pool, &did)
65
+
.map_err(|err| err.into())
66
+
} else {
67
+
let handle = if let Some(handle) = handle_slug.strip_prefix('@') {
68
+
handle
69
+
} else {
70
+
&handle_slug
71
+
};
72
+
handle_for_handle(&web_context.pool, handle)
67
73
.await
68
-
.map_err(|err| err.into()),
69
-
Err(err) => Err(err.into()),
74
+
.map_err(|err| err.into())
70
75
};
71
76
72
77
if let Err(err) = profile {
···
261
266
);
262
267
}
263
268
264
-
// Set up XRPC client
265
-
// Error if we don't have auth data
266
-
let auth_data = auth.1.ok_or(MigrateEventError::NotAuthorized)?;
267
-
let client_auth: SimpleOAuthSessionProvider = SimpleOAuthSessionProvider::try_from(auth_data)?;
268
-
269
-
let client = OAuthPdsClient {
270
-
http_client: &web_context.http_client,
271
-
pds: ¤t_handle.pds,
269
+
// Set up authentication
270
+
// Create DPoP auth based on OAuth backend type
271
+
let dpop_auth = match (&auth, &web_context.config.oauth_backend) {
272
+
(Auth::Pds { session, .. }, OAuthBackendConfig::ATProtocol { .. }) => {
273
+
create_dpop_auth_from_oauth_session(session)?
274
+
}
275
+
(Auth::Aip { access_token, .. }, OAuthBackendConfig::AIP { hostname, .. }) => {
276
+
create_dpop_auth_from_aip_session(&web_context.http_client, hostname, access_token)
277
+
.await?
278
+
}
279
+
_ => return Err(MigrateEventError::NotAuthorized.into()),
272
280
};
273
281
274
282
// Create the community event record in the user's PDS using putRecord to retain the same rkey
···
283
291
};
284
292
285
293
// Write to the PDS
286
-
let update_record_result = client.put_record(&client_auth, update_record_request).await;
287
-
if let Err(err) = update_record_result {
288
-
return contextual_error!(
289
-
web_context,
290
-
language,
291
-
error_template,
292
-
default_context,
293
-
err,
294
-
StatusCode::OK
295
-
);
296
-
}
297
-
// update_record_result is guaranteed to be Ok at this point since we checked for Err above
298
-
let update_record_result = update_record_result?;
294
+
let update_record_result = put_record(
295
+
&web_context.http_client,
296
+
&dpop_auth,
297
+
¤t_handle.pds,
298
+
update_record_request,
299
+
)
300
+
.await;
301
+
302
+
let update_record_response = match update_record_result {
303
+
Ok(PutRecordResponse::StrongRef { uri, cid, .. }) => {
304
+
crate::atproto::lexicon::com::atproto::repo::StrongRef { uri, cid }
305
+
}
306
+
Ok(PutRecordResponse::Error(err)) => {
307
+
return contextual_error!(
308
+
web_context,
309
+
language,
310
+
error_template,
311
+
default_context,
312
+
anyhow::anyhow!("Server error: {}", err.error_message()),
313
+
StatusCode::OK
314
+
);
315
+
}
316
+
Err(err) => {
317
+
return contextual_error!(
318
+
web_context,
319
+
language,
320
+
error_template,
321
+
default_context,
322
+
err,
323
+
StatusCode::OK
324
+
);
325
+
}
326
+
};
299
327
300
328
// We already have the migrated AT-URI defined above
301
329
···
303
331
let migrated_event_insert_result = event_insert_with_metadata(
304
332
&web_context.pool,
305
333
&migrated_aturi,
306
-
&update_record_result.cid,
334
+
&update_record_response.cid,
307
335
¤t_handle.did,
308
336
COMMUNITY_NSID,
309
337
&new_event,
+57
-32
src/http/handle_migrate_rsvp.rs
+57
-32
src/http/handle_migrate_rsvp.rs
···
10
10
use minijinja::context as template_context;
11
11
use std::hash::Hasher;
12
12
13
+
use crate::atproto::auth::{
14
+
create_dpop_auth_from_aip_session, create_dpop_auth_from_oauth_session,
15
+
};
16
+
use crate::config::OAuthBackendConfig;
13
17
use crate::{
14
-
atproto::{
15
-
auth::SimpleOAuthSessionProvider,
16
-
client::{OAuthPdsClient, PutRecordRequest},
17
-
lexicon::{
18
-
com::atproto::repo::StrongRef,
19
-
community::lexicon::calendar::rsvp::{Rsvp, RsvpStatus, NSID as RSVP_COLLECTION},
20
-
events::smokesignal::calendar::event::NSID as EVENT_COLLECTION,
21
-
},
18
+
atproto::lexicon::{
19
+
com::atproto::repo::StrongRef,
20
+
community::lexicon::calendar::rsvp::{Rsvp, RsvpStatus, NSID as RSVP_COLLECTION},
21
+
events::smokesignal::calendar::event::NSID as EVENT_COLLECTION,
22
22
},
23
23
contextual_error,
24
24
http::{
···
27
27
middleware_auth::Auth,
28
28
middleware_i18n::Language,
29
29
},
30
-
resolve::{parse_input, InputType},
31
30
select_template,
32
31
storage::{
33
32
event::{event_get, get_user_rsvp, rsvp_insert},
34
-
handle::{handle_for_did, handle_for_handle, model::Handle},
33
+
identity_profile::{handle_for_did, handle_for_handle, model::IdentityProfile},
35
34
},
36
35
};
36
+
use atproto_client::com::atproto::repo::{put_record, PutRecordRequest, PutRecordResponse};
37
37
38
38
/// Migrates a user's RSVP from a legacy event to a standard event format.
39
39
///
···
48
48
/// 3. Ensuring the user doesn't already have an RSVP for the standard event
49
49
/// 4. Creating a new RSVP record in the standard format
50
50
/// 5. Storing the RSVP both on the PDS and in the local database
51
-
pub async fn handle_migrate_rsvp(
51
+
pub(crate) async fn handle_migrate_rsvp(
52
52
State(web_context): State<WebContext>,
53
53
Language(language): Language,
54
54
Cached(auth): Cached<Auth>,
···
56
56
) -> Result<impl IntoResponse, WebError> {
57
57
// Require user to be logged in
58
58
let current_handle = auth.require(
59
-
&web_context.config.destination_key,
59
+
&web_context.config,
60
60
"/{handle_slug}/{event_rkey}/migrate-rsvp",
61
61
)?;
62
62
···
69
69
let error_template = select_template!(false, false, language);
70
70
71
71
// Get handle information from the path parameter
72
-
let profile: Result<Handle> = match parse_input(&handle_slug) {
73
-
Ok(InputType::Handle(handle)) => handle_for_handle(&web_context.pool, &handle)
72
+
let profile: Result<IdentityProfile> = if handle_slug.starts_with("did:") {
73
+
handle_for_did(&web_context.pool, &handle_slug)
74
74
.await
75
-
.map_err(|err| err.into()),
76
-
Ok(InputType::Plc(did) | InputType::Web(did)) => handle_for_did(&web_context.pool, &did)
75
+
.map_err(|err| err.into())
76
+
} else {
77
+
let handle = if let Some(handle) = handle_slug.strip_prefix('@') {
78
+
handle
79
+
} else {
80
+
&handle_slug
81
+
};
82
+
handle_for_handle(&web_context.pool, handle)
77
83
.await
78
-
.map_err(|err| err.into()),
79
-
Err(err) => Err(err.into()),
84
+
.map_err(|err| err.into())
80
85
};
81
86
82
87
let profile = match profile {
···
181
186
}
182
187
183
188
// Create a new RSVP for the standard event
184
-
// Error if we don't have auth data
185
-
let auth_data = auth.1.ok_or(MigrateRsvpError::NotAuthorized)?;
186
-
let client_auth: SimpleOAuthSessionProvider = SimpleOAuthSessionProvider::try_from(auth_data)?;
187
-
188
-
let client = OAuthPdsClient {
189
-
http_client: &web_context.http_client,
190
-
pds: ¤t_handle.pds,
189
+
// Create DPoP auth based on OAuth backend type
190
+
let dpop_auth = match (&auth, &web_context.config.oauth_backend) {
191
+
(Auth::Pds { session, .. }, OAuthBackendConfig::ATProtocol { .. }) => {
192
+
create_dpop_auth_from_oauth_session(session)?
193
+
}
194
+
(Auth::Aip { access_token, .. }, OAuthBackendConfig::AIP { hostname, .. }) => {
195
+
create_dpop_auth_from_aip_session(&web_context.http_client, hostname, access_token)
196
+
.await?
197
+
}
198
+
_ => return Err(MigrateRsvpError::NotAuthorized.into()),
191
199
};
192
200
193
201
// Create a reference to the standard event that will be the subject of the RSVP
···
242
250
swap_record: None,
243
251
};
244
252
245
-
let put_record_result = client.put_record(&client_auth, rsvp_record).await;
253
+
let put_record_result = put_record(
254
+
&web_context.http_client,
255
+
&dpop_auth,
256
+
¤t_handle.pds,
257
+
rsvp_record,
258
+
)
259
+
.await;
246
260
247
-
if let Err(err) = put_record_result {
248
-
return contextual_error!(web_context, language, error_template, default_context, err);
249
-
}
250
-
251
-
// put_record_result is guaranteed to be Ok here since we checked for Err above
252
-
let create_record_result = put_record_result?;
261
+
let create_record_result = match put_record_result {
262
+
Ok(PutRecordResponse::StrongRef { uri, cid, .. }) => {
263
+
crate::atproto::lexicon::com::atproto::repo::StrongRef { uri, cid }
264
+
}
265
+
Ok(PutRecordResponse::Error(err)) => {
266
+
return contextual_error!(
267
+
web_context,
268
+
language,
269
+
error_template,
270
+
default_context,
271
+
anyhow::anyhow!("Server error: {}", err.error_message())
272
+
);
273
+
}
274
+
Err(err) => {
275
+
return contextual_error!(web_context, language, error_template, default_context, err);
276
+
}
277
+
};
253
278
254
279
// Store the new RSVP in the database
255
280
let rsvp_insert_result = rsvp_insert(
+149
src/http/handle_oauth_aip_callback.rs
+149
src/http/handle_oauth_aip_callback.rs
···
1
+
use crate::{config::OAuthBackendConfig, contextual_error, select_template};
2
+
use anyhow::Result;
3
+
use axum::{
4
+
extract::State,
5
+
response::{IntoResponse, Redirect},
6
+
};
7
+
use axum_extra::extract::{
8
+
cookie::{Cookie, SameSite},
9
+
Form, PrivateCookieJar,
10
+
};
11
+
use minijinja::context as template_context;
12
+
use serde::{Deserialize, Serialize};
13
+
14
+
use super::{
15
+
context::WebContext,
16
+
errors::{LoginError, WebError},
17
+
middleware_auth::{WebSession, AUTH_COOKIE_NAME},
18
+
middleware_i18n::Language,
19
+
};
20
+
21
+
#[derive(Deserialize, Serialize)]
22
+
pub(crate) struct OAuthCallbackForm {
23
+
pub(crate) state: Option<String>,
24
+
pub(crate) code: Option<String>,
25
+
}
26
+
27
+
pub(crate) async fn handle_oauth_callback(
28
+
State(web_context): State<WebContext>,
29
+
Language(language): Language,
30
+
jar: PrivateCookieJar,
31
+
Form(callback_form): Form<OAuthCallbackForm>,
32
+
) -> Result<impl IntoResponse, WebError> {
33
+
let default_context = template_context! {
34
+
language => language.to_string(),
35
+
canonical_url => format!("https://{}/oauth/callback", web_context.config.external_base),
36
+
};
37
+
38
+
let error_template = select_template!(false, false, language);
39
+
40
+
// Get AIP server configuration - config validation ensures these are set when oauth_backend is AIP
41
+
let (hostname, client_id, client_secret) = if let OAuthBackendConfig::AIP {
42
+
hostname,
43
+
client_id,
44
+
client_secret,
45
+
} = &web_context.config.oauth_backend
46
+
{
47
+
(hostname, client_id, client_secret)
48
+
} else {
49
+
unreachable!("AIP OAuth backend should have AIP configuration")
50
+
};
51
+
52
+
let (callback_code, callback_state) = match (callback_form.code, callback_form.state) {
53
+
(Some(x), Some(z)) => (x, z),
54
+
_ => {
55
+
return contextual_error!(
56
+
web_context,
57
+
language,
58
+
error_template,
59
+
default_context,
60
+
LoginError::OAuthCallbackIncomplete
61
+
);
62
+
}
63
+
};
64
+
65
+
let oauth_request = match web_context
66
+
.oauth_storage
67
+
.get_oauth_request_by_state(&callback_state)
68
+
.await
69
+
{
70
+
Err(err) => {
71
+
return contextual_error!(web_context, language, error_template, default_context, err);
72
+
}
73
+
Ok(None) => {
74
+
return contextual_error!(
75
+
web_context,
76
+
language,
77
+
error_template,
78
+
default_context,
79
+
anyhow::anyhow!("oauth request not found in storage")
80
+
);
81
+
}
82
+
Ok(Some(value)) => value,
83
+
};
84
+
85
+
let authorization_server = match atproto_oauth_aip::resources::oauth_authorization_server(
86
+
&web_context.http_client,
87
+
hostname,
88
+
)
89
+
.await
90
+
{
91
+
Ok(value) => value,
92
+
Err(err) => {
93
+
return contextual_error!(web_context, language, error_template, default_context, err);
94
+
}
95
+
};
96
+
97
+
let oauth_client = atproto_oauth_aip::workflow::OAuthClient {
98
+
redirect_uri: format!(
99
+
"https://{}/oauth/callback",
100
+
&web_context.config.external_base
101
+
),
102
+
client_id: client_id.clone(),
103
+
client_secret: client_secret.clone(),
104
+
};
105
+
106
+
let token_response = atproto_oauth_aip::workflow::oauth_complete(
107
+
&web_context.http_client,
108
+
&oauth_client,
109
+
&authorization_server,
110
+
&callback_code,
111
+
&oauth_request,
112
+
)
113
+
.await;
114
+
115
+
if let Err(err) = token_response {
116
+
tracing::error!(
117
+
?err,
118
+
"atproto_oauth_aip::workflow::oautoauth_completeh_init"
119
+
);
120
+
return contextual_error!(web_context, language, error_template, default_context, err);
121
+
}
122
+
123
+
let token_response = token_response.unwrap();
124
+
125
+
let cookie_value: String = WebSession::Aip {
126
+
did: oauth_request.did.clone(),
127
+
access_token: token_response.access_token.clone(),
128
+
}
129
+
.try_into()?;
130
+
131
+
let mut cookie = Cookie::new(AUTH_COOKIE_NAME, cookie_value);
132
+
cookie.set_domain(web_context.config.external_base.clone());
133
+
cookie.set_path("/");
134
+
cookie.set_http_only(true);
135
+
cookie.set_secure(true);
136
+
cookie.set_max_age(Some(cookie::time::Duration::seconds(
137
+
(token_response.expires_in as i64) - 60,
138
+
)));
139
+
cookie.set_same_site(Some(SameSite::Lax));
140
+
141
+
let updated_jar = jar.add(cookie);
142
+
143
+
// let destination = match oauth_request.destination {
144
+
// Some(destination) => destination,
145
+
// None => "/".to_string(),
146
+
// };
147
+
148
+
Ok((updated_jar, Redirect::to("/")).into_response())
149
+
}
+281
src/http/handle_oauth_aip_login.rs
+281
src/http/handle_oauth_aip_login.rs
···
1
+
use anyhow::{anyhow, Result};
2
+
use atproto_identity::resolve::IdentityResolver;
3
+
use atproto_oauth::pkce::generate;
4
+
use atproto_oauth::workflow::OAuthRequestState as AipOAuthRequestState;
5
+
use atproto_oauth_aip::{
6
+
resources::oauth_authorization_server,
7
+
workflow::{oauth_init, OAuthClient},
8
+
};
9
+
use axum::response::Redirect;
10
+
use axum::{extract::State, response::IntoResponse};
11
+
use axum_extra::extract::{Cached, Form, Query};
12
+
use axum_htmx::{HxBoosted, HxRedirect, HxRequest};
13
+
use axum_template::RenderHtml;
14
+
use http::StatusCode;
15
+
use minijinja::context as template_context;
16
+
use rand::{distributions::Alphanumeric, Rng};
17
+
use serde::Deserialize;
18
+
19
+
use crate::{
20
+
config::OAuthBackendConfig,
21
+
contextual_error,
22
+
http::{
23
+
context::WebContext,
24
+
errors::{LoginError, WebError},
25
+
middleware_auth::Auth,
26
+
middleware_i18n::Language,
27
+
utils::stringify,
28
+
},
29
+
select_template,
30
+
storage::{denylist::denylist_exists, identity_profile::handle_warm_up},
31
+
};
32
+
33
+
#[derive(Deserialize)]
34
+
pub(crate) struct OAuthLoginForm {
35
+
handle: Option<String>,
36
+
}
37
+
38
+
#[derive(Deserialize)]
39
+
pub(crate) struct Destination {
40
+
destination: Option<String>,
41
+
}
42
+
43
+
#[allow(clippy::too_many_arguments)]
44
+
pub(crate) async fn handle_oauth_aip_login(
45
+
State(web_context): State<WebContext>,
46
+
Language(language): Language,
47
+
Cached(auth): Cached<Auth>,
48
+
identity_resolver: IdentityResolver,
49
+
HxRequest(hx_request): HxRequest,
50
+
HxBoosted(hx_boosted): HxBoosted,
51
+
Query(destination): Query<Destination>,
52
+
Form(login_form): Form<OAuthLoginForm>,
53
+
) -> Result<impl IntoResponse, WebError> {
54
+
let default_context = template_context! {
55
+
current_handle => auth.profile(),
56
+
language => language.to_string(),
57
+
canonical_url => format!("https://{}/oauth/login", web_context.config.external_base),
58
+
destination => destination.destination,
59
+
};
60
+
61
+
let render_template = select_template!("login", hx_boosted, hx_request, language);
62
+
let error_template = select_template!(hx_boosted, hx_request, language);
63
+
64
+
if let Some(subject) = login_form.handle {
65
+
let handle_denied = denylist_exists(&web_context.pool, &[subject.as_str()])
66
+
.await
67
+
.unwrap_or(true);
68
+
if handle_denied {
69
+
return contextual_error!(
70
+
web_context,
71
+
language,
72
+
error_template,
73
+
default_context,
74
+
anyhow!("access-denied")
75
+
);
76
+
}
77
+
78
+
let document = match identity_resolver.resolve(&subject).await {
79
+
Ok(value) => value,
80
+
Err(err) => {
81
+
return contextual_error!(
82
+
web_context,
83
+
language,
84
+
error_template,
85
+
default_context,
86
+
err
87
+
);
88
+
}
89
+
};
90
+
91
+
let handle = match document
92
+
.handles()
93
+
.ok_or(WebError::Login(LoginError::NoHandle))
94
+
{
95
+
Ok(value) => value,
96
+
Err(err) => {
97
+
tracing::error!(?err, "handles");
98
+
return contextual_error!(
99
+
web_context,
100
+
language,
101
+
error_template,
102
+
default_context,
103
+
err
104
+
);
105
+
}
106
+
};
107
+
108
+
let pds = match document
109
+
.pds_endpoints()
110
+
.first()
111
+
.cloned()
112
+
.ok_or(WebError::Login(LoginError::NoPDS))
113
+
{
114
+
Ok(value) => value,
115
+
Err(err) => {
116
+
tracing::error!(?err, "pds_endpoints first");
117
+
return contextual_error!(
118
+
web_context,
119
+
language,
120
+
error_template,
121
+
default_context,
122
+
err
123
+
);
124
+
}
125
+
};
126
+
127
+
if let Err(err) = web_context
128
+
.document_storage
129
+
.store_document(document.clone())
130
+
.await
131
+
{
132
+
tracing::error!(?err, "store_document");
133
+
return contextual_error!(web_context, language, error_template, default_context, err);
134
+
}
135
+
136
+
// Insert the handle if it doesn't exist
137
+
if let Err(err) = handle_warm_up(&web_context.pool, &document.id, handle, pds).await {
138
+
tracing::error!(?err, "handle_warm_up");
139
+
return contextual_error!(web_context, language, error_template, default_context, err);
140
+
}
141
+
142
+
// Generate OAuth parameters
143
+
let state: String = rand::thread_rng()
144
+
.sample_iter(&Alphanumeric)
145
+
.take(30)
146
+
.map(char::from)
147
+
.collect();
148
+
let nonce: String = rand::thread_rng()
149
+
.sample_iter(&Alphanumeric)
150
+
.take(30)
151
+
.map(char::from)
152
+
.collect();
153
+
154
+
let (pkce_verifier, code_challenge) = generate();
155
+
156
+
// Create AIP-specific OAuth request state with scope
157
+
let aip_oauth_request_state = AipOAuthRequestState {
158
+
state: state.clone(),
159
+
nonce: nonce.clone(),
160
+
code_challenge,
161
+
scope: "atproto:atproto atproto:transition:generic".to_string(),
162
+
};
163
+
164
+
// Get AIP server configuration - config validation ensures these are set when oauth_backend is AIP
165
+
let (hostname, client_id, client_secret) = if let OAuthBackendConfig::AIP {
166
+
hostname,
167
+
client_id,
168
+
client_secret,
169
+
} = &web_context.config.oauth_backend
170
+
{
171
+
(hostname, client_id, client_secret)
172
+
} else {
173
+
unreachable!("AIP OAuth backend should have AIP configuration")
174
+
};
175
+
176
+
// Get AIP authorization server metadata
177
+
let authorization_server =
178
+
match oauth_authorization_server(&web_context.http_client, hostname).await {
179
+
Ok(value) => value,
180
+
Err(err) => {
181
+
tracing::error!(?err, "oauth_authorization_server");
182
+
return contextual_error!(
183
+
web_context,
184
+
language,
185
+
error_template,
186
+
default_context,
187
+
err
188
+
);
189
+
}
190
+
};
191
+
192
+
// Create AIP OAuth client
193
+
let oauth_client = OAuthClient {
194
+
redirect_uri: format!(
195
+
"https://{}/oauth/callback",
196
+
web_context.config.external_base
197
+
),
198
+
client_id: client_id.clone(),
199
+
client_secret: client_secret.clone(),
200
+
};
201
+
tracing::info!(oauth_client.redirect_uri, "oauth_client");
202
+
203
+
// Initialize AIP OAuth flow
204
+
let par_response = oauth_init(
205
+
&web_context.http_client,
206
+
&oauth_client,
207
+
Some(&subject),
208
+
&authorization_server,
209
+
&aip_oauth_request_state,
210
+
)
211
+
.await;
212
+
213
+
if let Err(err) = par_response {
214
+
tracing::error!(?err, "oauth_init");
215
+
return contextual_error!(web_context, language, error_template, default_context, err);
216
+
}
217
+
218
+
let par_response = par_response.unwrap();
219
+
220
+
let created_at = chrono::Utc::now();
221
+
let expires_at = created_at + chrono::Duration::seconds(par_response.expires_in as i64);
222
+
223
+
let oauth_request = atproto_oauth::workflow::OAuthRequest {
224
+
oauth_state: state.clone(),
225
+
nonce: nonce.clone(),
226
+
pkce_verifier: pkce_verifier.clone(),
227
+
228
+
issuer: authorization_server.issuer.clone(),
229
+
did: document.id,
230
+
231
+
signing_public_key: "".to_string(),
232
+
dpop_private_key: "".to_string(),
233
+
created_at,
234
+
expires_at,
235
+
};
236
+
237
+
if let Err(err) = web_context
238
+
.oauth_storage
239
+
.insert_oauth_request(oauth_request)
240
+
.await
241
+
{
242
+
tracing::error!(?err, "insert_oauth_request");
243
+
return contextual_error!(web_context, language, error_template, default_context, err);
244
+
}
245
+
246
+
let oauth_args = [
247
+
(
248
+
"request_uri".to_string(),
249
+
urlencoding::encode(&par_response.request_uri).to_string(),
250
+
),
251
+
(
252
+
"client_id".to_string(),
253
+
urlencoding::encode(client_id).to_string(),
254
+
),
255
+
];
256
+
let oauth_args = oauth_args.iter().map(|(k, v)| (&**k, &**v)).collect();
257
+
258
+
let destination = format!(
259
+
"{}?{}",
260
+
authorization_server.authorization_endpoint,
261
+
stringify(oauth_args)
262
+
);
263
+
264
+
if hx_request {
265
+
if let Ok(hx_redirect) = HxRedirect::try_from(destination.as_str()) {
266
+
return Ok((StatusCode::OK, hx_redirect, "").into_response());
267
+
}
268
+
}
269
+
270
+
return Ok(Redirect::temporary(destination.as_str()).into_response());
271
+
}
272
+
273
+
Ok(RenderHtml(
274
+
&render_template,
275
+
web_context.engine.clone(),
276
+
template_context! { ..default_context, ..template_context! {
277
+
destination => destination.destination,
278
+
}},
279
+
)
280
+
.into_response())
281
+
}
+98
-100
src/http/handle_oauth_callback.rs
+98
-100
src/http/handle_oauth_callback.rs
···
1
-
use anyhow::Result;
1
+
use anyhow::{anyhow, Result};
2
+
use atproto_identity::{axum::state::KeyProviderExtractor, key::identify_key};
3
+
use atproto_oauth::workflow::{oauth_complete, OAuthClient};
2
4
use axum::{
3
5
extract::State,
4
6
response::{IntoResponse, Redirect},
···
7
9
cookie::{Cookie, SameSite},
8
10
Form, PrivateCookieJar,
9
11
};
10
-
use deadpool_redis::redis::AsyncCommands as _;
11
12
use minijinja::context as template_context;
12
-
use p256::SecretKey;
13
13
use serde::{Deserialize, Serialize};
14
-
use std::borrow::Cow;
15
14
16
-
use crate::jose_errors::JwkError;
17
-
use crate::storage::errors::CacheError;
18
-
19
-
use crate::{
20
-
contextual_error,
21
-
oauth::oauth_complete,
22
-
select_template,
23
-
storage::{
24
-
cache::OAUTH_REFRESH_QUEUE,
25
-
handle::handle_for_did,
26
-
oauth::{oauth_request_get, oauth_request_remove, oauth_session_insert},
27
-
},
28
-
};
15
+
use crate::{contextual_error, select_template};
29
16
30
17
use super::{
31
18
context::WebContext,
···
35
22
};
36
23
37
24
#[derive(Deserialize, Serialize)]
38
-
pub struct OAuthCallbackForm {
39
-
pub state: Option<String>,
40
-
pub iss: Option<String>,
41
-
pub code: Option<String>,
25
+
pub(crate) struct OAuthCallbackForm {
26
+
state: Option<String>,
27
+
iss: Option<String>,
28
+
code: Option<String>,
42
29
}
43
30
44
-
pub async fn handle_oauth_callback(
31
+
#[axum::debug_handler]
32
+
pub(crate) async fn handle_oauth_callback(
45
33
State(web_context): State<WebContext>,
34
+
key_provider: KeyProviderExtractor,
46
35
Language(language): Language,
47
36
jar: PrivateCookieJar,
48
37
Form(callback_form): Form<OAuthCallbackForm>,
···
68
57
}
69
58
};
70
59
71
-
let oauth_request = oauth_request_get(&web_context.pool, &callback_state).await;
72
-
if let Err(err) = oauth_request {
73
-
return contextual_error!(web_context, language, error_template, default_context, err);
74
-
}
60
+
let oauth_request = web_context
61
+
.oauth_storage
62
+
.get_oauth_request_by_state(&callback_state)
63
+
.await;
75
64
76
-
let oauth_request = oauth_request.unwrap();
65
+
let oauth_request = match oauth_request {
66
+
Err(err) => {
67
+
return contextual_error!(web_context, language, error_template, default_context, err);
68
+
}
69
+
Ok(None) => {
70
+
return contextual_error!(
71
+
web_context,
72
+
language,
73
+
error_template,
74
+
default_context,
75
+
anyhow!("oauth request not found in storage")
76
+
);
77
+
}
78
+
Ok(Some(value)) => value,
79
+
};
77
80
78
81
if oauth_request.issuer != callback_iss {
79
82
return contextual_error!(
···
85
88
);
86
89
}
87
90
88
-
let handle = handle_for_did(&web_context.pool, &oauth_request.did).await;
89
-
if let Err(err) = handle {
90
-
return contextual_error!(web_context, language, error_template, default_context, err);
91
-
}
92
-
93
-
let handle = handle.unwrap();
91
+
let document = match web_context
92
+
.document_storage
93
+
.get_document_by_did(&oauth_request.did)
94
+
.await
95
+
{
96
+
Err(err) => {
97
+
return contextual_error!(web_context, language, error_template, default_context, err);
98
+
}
99
+
Ok(None) => {
100
+
return contextual_error!(
101
+
web_context,
102
+
language,
103
+
error_template,
104
+
default_context,
105
+
anyhow!("identity did document not found in storage")
106
+
);
107
+
}
108
+
Ok(Some(value)) => value,
109
+
};
94
110
95
-
let secret_signing_key = web_context
96
-
.config
97
-
.signing_keys
98
-
.as_ref()
99
-
.get(&oauth_request.secret_jwk_id)
100
-
.cloned()
101
-
.ok_or(JwkError::SecretKeyNotFound);
111
+
let secret_signing_key = key_provider
112
+
.0
113
+
.get_private_key_by_id(&oauth_request.signing_public_key)
114
+
.await;
102
115
103
-
if let Err(err) = secret_signing_key {
104
-
return contextual_error!(web_context, language, error_template, default_context, err);
105
-
}
106
-
let secret_signing_key = secret_signing_key.unwrap();
116
+
let secret_key_data = match secret_signing_key {
117
+
Ok(Some(value)) => value,
118
+
Ok(None) => {
119
+
return contextual_error!(
120
+
web_context,
121
+
language,
122
+
error_template,
123
+
default_context,
124
+
LoginError::OAuthCallbackIncomplete
125
+
);
126
+
}
127
+
Err(err) => {
128
+
return contextual_error!(web_context, language, error_template, default_context, err);
129
+
}
130
+
};
107
131
108
-
let dpop_secret_key = SecretKey::from_jwk(&oauth_request.dpop_jwk.jwk);
132
+
let dpop_key_data = match identify_key(&oauth_request.dpop_private_key) {
133
+
Ok(value) => value,
134
+
Err(err) => {
135
+
return contextual_error!(web_context, language, error_template, default_context, err);
136
+
}
137
+
};
109
138
110
-
if let Err(err) = dpop_secret_key {
111
-
return contextual_error!(web_context, language, error_template, default_context, err);
112
-
}
113
-
let dpop_secret_key = dpop_secret_key.unwrap();
139
+
let oauth_client = OAuthClient {
140
+
redirect_uri: format!(
141
+
"https://{}/oauth/callback",
142
+
&web_context.config.external_base
143
+
),
144
+
client_id: format!(
145
+
"https://{}/oauth/client-metadata.json",
146
+
&web_context.config.external_base
147
+
),
148
+
private_signing_key_data: secret_key_data,
149
+
};
114
150
115
151
let token_response = oauth_complete(
116
152
&web_context.http_client,
117
-
&web_context.config.external_base,
118
-
(&oauth_request.secret_jwk_id, secret_signing_key),
153
+
&oauth_client,
154
+
&dpop_key_data,
119
155
&callback_code,
120
156
&oauth_request,
121
-
&handle,
122
-
&dpop_secret_key,
157
+
&document,
123
158
)
124
159
.await;
125
160
if let Err(err) = token_response {
···
128
163
129
164
let token_response = token_response.unwrap();
130
165
131
-
if let Err(err) = oauth_request_remove(&web_context.pool, &oauth_request.oauth_state).await {
166
+
if let Err(err) = web_context
167
+
.oauth_storage
168
+
.delete_oauth_request_by_state(&callback_state)
169
+
.await
170
+
{
132
171
tracing::error!(error = ?err, "Unable to remove oauth_request");
133
172
}
134
173
135
-
let session_group = ulid::Ulid::new().to_string();
136
-
let now = chrono::Utc::now();
137
-
138
-
if let Err(err) = oauth_session_insert(
139
-
&web_context.pool,
140
-
crate::storage::oauth::OAuthSessionParams {
141
-
session_group: Cow::Owned(session_group.clone()),
142
-
access_token: Cow::Owned(token_response.access_token.clone()),
143
-
did: Cow::Owned(token_response.sub.clone()),
144
-
issuer: Cow::Owned(oauth_request.issuer.clone()),
145
-
refresh_token: Cow::Owned(token_response.refresh_token.clone()),
146
-
secret_jwk_id: Cow::Owned(oauth_request.secret_jwk_id.clone()),
147
-
dpop_jwk: oauth_request.dpop_jwk.0.clone(),
148
-
created_at: now,
149
-
access_token_expires_at: now
150
-
+ chrono::Duration::seconds(token_response.expires_in as i64),
151
-
},
152
-
)
153
-
.await
154
-
{
155
-
return contextual_error!(web_context, language, error_template, default_context, err);
156
-
}
157
-
158
-
{
159
-
let mut conn = web_context
160
-
.cache_pool
161
-
.get()
162
-
.await
163
-
.map_err(CacheError::FailedToGetConnection)?;
164
-
165
-
let modified_expires_at = ((token_response.expires_in as f64) * 0.8).round() as i64;
166
-
let refresh_at = (now + chrono::Duration::seconds(modified_expires_at)).timestamp_millis();
167
-
168
-
let _: () = conn
169
-
.zadd(OAUTH_REFRESH_QUEUE, &session_group, refresh_at)
170
-
.await
171
-
.map_err(CacheError::FailedToPlaceInRefreshQueue)?;
172
-
}
173
-
174
-
let cookie_value: String = WebSession {
175
-
did: token_response.sub.clone(),
176
-
session_group: session_group.clone(),
174
+
// For standard OAuth, create a PDS session
175
+
let cookie_value: String = WebSession::Pds {
176
+
did: token_response.sub.clone().unwrap(),
177
+
session_group: "".to_string(), // Simplified for initial pass
177
178
}
178
179
.try_into()?;
179
180
···
187
188
188
189
let updated_jar = jar.add(cookie);
189
190
190
-
let destination = match oauth_request.destination {
191
-
Some(destination) => destination,
192
-
None => "/".to_string(),
193
-
};
191
+
let destination = "/".to_string(); // Simplified for initial pass
194
192
195
193
Ok((updated_jar, Redirect::to(&destination)).into_response())
196
194
}
-94
src/http/handle_oauth_jwks.rs
-94
src/http/handle_oauth_jwks.rs
···
1
-
use anyhow::Result;
2
-
use axum::{
3
-
extract::State,
4
-
http::{header, HeaderValue},
5
-
response::{IntoResponse, Response},
6
-
};
7
-
use std::sync::Arc;
8
-
9
-
use crate::http::{context::WebContext, errors::WebError};
10
-
use crate::jose::jwk::{WrappedJsonWebKey, WrappedJsonWebKeySet};
11
-
12
-
// Function to compute JWKS data and serialize to JSON string
13
-
fn compute_jwks_json(web_context: &WebContext) -> Result<String, serde_json::Error> {
14
-
let mut keys = vec![];
15
-
let signing_keys = web_context.config.signing_keys.as_ref();
16
-
17
-
for available_signing_key in web_context.config.oauth_active_keys.as_ref() {
18
-
let available_signing_key = available_signing_key.clone();
19
-
20
-
let signing_key = match signing_keys.get(&available_signing_key) {
21
-
Some(key) => key.clone(),
22
-
None => continue,
23
-
};
24
-
let public_key = signing_key.public_key();
25
-
26
-
let wrapped_json_web_key = WrappedJsonWebKey {
27
-
jwk: public_key.to_jwk(),
28
-
kid: Some(available_signing_key.clone()),
29
-
alg: Some("ES256".to_string()),
30
-
};
31
-
32
-
keys.push(wrapped_json_web_key);
33
-
}
34
-
35
-
let jwks = WrappedJsonWebKeySet { keys };
36
-
serde_json::to_string(&jwks)
37
-
}
38
-
39
-
// Global cache for the pre-serialized JSON string
40
-
static JWKS_JSON_CACHE: once_cell::sync::OnceCell<Arc<String>> = once_cell::sync::OnceCell::new();
41
-
42
-
#[tracing::instrument(skip_all, err)]
43
-
pub async fn handle_oauth_jwks(
44
-
State(web_context): State<WebContext>,
45
-
) -> Result<impl IntoResponse, WebError> {
46
-
tracing::debug!("handle_oauth_jwks");
47
-
48
-
// Initialize the cache if needed
49
-
if JWKS_JSON_CACHE.get().is_none() {
50
-
// Compute and serialize the JWKS data
51
-
let jwks_json = compute_jwks_json(&web_context)
52
-
.map_err(|e| anyhow::anyhow!("error-oauth-jwks-1 Failed to serialize JWKS: {}", e))?;
53
-
54
-
// Store in cache - don't worry if another thread beat us to it
55
-
let _ = JWKS_JSON_CACHE.set(Arc::new(jwks_json));
56
-
}
57
-
58
-
// By this point, the cache should be initialized - either by us or another thread
59
-
// In the extremely unlikely event it's still not initialized, we'll create it one more time
60
-
let jwks_json = if let Some(json) = JWKS_JSON_CACHE.get() {
61
-
json
62
-
} else {
63
-
// Final attempt to compute and cache
64
-
let jwks_json = compute_jwks_json(&web_context)
65
-
.map_err(|e| anyhow::anyhow!("error-oauth-jwks-1 Failed to serialize JWKS: {}", e))?;
66
-
67
-
// Create a new Arc and set it in the cache
68
-
let json_arc = Arc::new(jwks_json);
69
-
70
-
// This will either succeed in setting it, or another thread beat us to it
71
-
if JWKS_JSON_CACHE.set(json_arc.clone()).is_err() {
72
-
// Another thread set it first, so use that value
73
-
JWKS_JSON_CACHE.get().ok_or_else(|| {
74
-
anyhow::anyhow!("error-oauth-jwks-2 Failed to initialize JWKS cache")
75
-
})?
76
-
} else {
77
-
// We set it, so we can use our local copy
78
-
JWKS_JSON_CACHE.get().ok_or_else(|| {
79
-
anyhow::anyhow!("error-oauth-jwks-2 Failed to initialize JWKS cache")
80
-
})?
81
-
}
82
-
};
83
-
84
-
// Create response with proper content type
85
-
let mut response = Response::new((**jwks_json).clone());
86
-
87
-
// Set content type to application/json
88
-
response.headers_mut().insert(
89
-
header::CONTENT_TYPE,
90
-
HeaderValue::from_static("application/json"),
91
-
);
92
-
93
-
Ok(response)
94
-
}
+116
-119
src/http/handle_oauth_login.rs
+116
-119
src/http/handle_oauth_login.rs
···
1
1
use anyhow::Result;
2
-
use axum::response::Redirect;
2
+
use atproto_identity::{
3
+
key::{generate_key, identify_key, KeyType},
4
+
resolve::IdentityResolver,
5
+
};
6
+
use atproto_oauth::{
7
+
pkce::generate,
8
+
resources::pds_resources,
9
+
workflow::{oauth_init, OAuthClient, OAuthRequest, OAuthRequestState},
10
+
};
3
11
use axum::{extract::State, response::IntoResponse};
4
12
use axum_extra::extract::{Cached, Form, Query};
5
13
use axum_htmx::{HxBoosted, HxRedirect, HxRequest};
6
14
use axum_template::RenderHtml;
7
-
use base64::{engine::general_purpose, Engine as _};
8
15
use http::StatusCode;
9
16
use minijinja::context as template_context;
10
-
use p256::SecretKey;
11
17
use rand::{distributions::Alphanumeric, Rng};
12
18
use serde::Deserialize;
13
-
use sha2::{Digest, Sha256};
14
-
use std::borrow::Cow;
15
19
16
20
use crate::{
17
21
contextual_error,
18
-
did::{plc::query as plc_query, web::query as web_query},
19
22
http::{
20
23
context::WebContext, errors::LoginError, errors::WebError, middleware_auth::Auth,
21
24
middleware_i18n::Language, utils::stringify,
22
25
},
23
-
jose,
24
-
oauth::{oauth_init, pds_resources},
25
-
resolve::{parse_input, resolve_subject, InputType},
26
26
select_template,
27
-
storage::{
28
-
denylist::denylist_exists,
29
-
handle::handle_warm_up,
30
-
oauth::{model::OAuthRequestState, oauth_request_insert},
31
-
},
27
+
storage::{denylist::denylist_exists, identity_profile::handle_warm_up},
32
28
};
33
29
34
30
#[derive(Deserialize)]
35
-
pub struct OAuthLoginForm {
36
-
pub handle: Option<String>,
37
-
pub destination: Option<String>,
31
+
pub(crate) struct OAuthLoginForm {
32
+
handle: Option<String>,
38
33
}
39
34
40
35
#[derive(Deserialize)]
41
-
pub struct Destination {
42
-
pub destination: Option<String>,
36
+
pub(crate) struct Destination {
37
+
destination: Option<String>,
43
38
}
44
39
45
-
pub async fn handle_oauth_login(
40
+
#[allow(clippy::too_many_arguments)]
41
+
pub(crate) async fn handle_oauth_login(
46
42
State(web_context): State<WebContext>,
43
+
identity_resolver: IdentityResolver,
47
44
Language(language): Language,
48
45
Cached(auth): Cached<Auth>,
49
46
HxRequest(hx_request): HxRequest,
···
52
49
Form(login_form): Form<OAuthLoginForm>,
53
50
) -> Result<impl IntoResponse, WebError> {
54
51
let default_context = template_context! {
55
-
current_handle => auth.0,
52
+
current_handle => auth.profile(),
56
53
language => language.to_string(),
57
54
canonical_url => format!("https://{}/oauth/login", web_context.config.external_base),
58
55
destination => destination.destination,
···
62
59
let error_template = select_template!(hx_boosted, hx_request, language);
63
60
64
61
if let Some(subject) = login_form.handle {
65
-
let resolved_did = resolve_subject(
66
-
&web_context.http_client,
67
-
&web_context.dns_resolver,
68
-
&subject,
69
-
)
70
-
.await;
71
-
72
-
if let Err(err) = resolved_did {
73
-
return contextual_error!(
74
-
web_context,
75
-
language,
76
-
render_template,
77
-
template_context! { ..default_context, ..template_context! {
78
-
handle_error => true,
79
-
handle_input => subject,
80
-
}},
81
-
err
82
-
);
83
-
}
84
-
85
-
let resolved_did = resolved_did.unwrap();
86
-
87
-
let query_results = match parse_input(&resolved_did) {
88
-
Ok(InputType::Plc(did)) => {
89
-
plc_query(
90
-
&web_context.http_client,
91
-
&web_context.config.plc_hostname,
92
-
&did,
93
-
)
94
-
.await
95
-
}
96
-
Ok(InputType::Web(did)) => web_query(&web_context.http_client, &did).await,
97
-
_ => Err(LoginError::NoHandle.into()),
98
-
};
99
-
100
-
let did_document = match query_results {
62
+
let did_document = match identity_resolver.resolve(&subject).await {
101
63
Ok(value) => value,
102
64
Err(err) => {
103
65
return contextual_error!(
···
113
75
}
114
76
};
115
77
116
-
let mut lookup_values: Vec<&str> = vec![&resolved_did, &did_document.id];
117
-
if let Some(pds) = did_document.pds_endpoint() {
78
+
let mut lookup_values: Vec<&str> = vec![&did_document.id, &subject];
79
+
if let Some(pds) = did_document.pds_endpoints().first() {
118
80
lookup_values.push(pds);
119
81
}
120
82
···
144
106
);
145
107
}
146
108
147
-
let pds = match did_document.pds_endpoint() {
109
+
let pds = match did_document.pds_endpoints().first().cloned() {
148
110
Some(value) => value,
149
111
None => {
150
112
return contextual_error!(
···
160
122
}
161
123
};
162
124
163
-
let primary_handle = match did_document.primary_handle() {
125
+
let primary_handle = match did_document.handles() {
164
126
Some(value) => value,
165
127
None => {
166
128
return contextual_error!(
···
192
154
.take(30)
193
155
.map(char::from)
194
156
.collect();
195
-
let (pkce_verifier, code_challenge) = gen_pkce();
157
+
let (pkce_verifier, code_challenge) = generate();
196
158
197
159
let oauth_request_state = OAuthRequestState {
198
160
state,
199
161
nonce,
200
162
code_challenge,
163
+
scope: "atproto transition:generic transition:email".to_string(),
201
164
};
202
165
203
-
let pds_auth_resources = pds_resources(&web_context.http_client, pds).await;
166
+
let authorization_server = match pds_resources(&web_context.http_client, pds).await {
167
+
Ok(value) => value.1,
168
+
Err(err) => {
169
+
return contextual_error!(
170
+
web_context,
171
+
language,
172
+
error_template,
173
+
default_context,
174
+
err
175
+
);
176
+
}
177
+
};
204
178
205
-
if let Err(err) = pds_auth_resources {
179
+
let signing_key_selection = web_context.config.select_oauth_signing_key();
180
+
if let Err(err) = signing_key_selection {
206
181
return contextual_error!(web_context, language, error_template, default_context, err);
207
182
}
208
183
209
-
let (_, authorization_server) = pds_auth_resources.unwrap();
210
-
tracing::info!(authorization_server = ?authorization_server, "resolved authorization server");
184
+
let (public_signing_key, private_signing_key) = signing_key_selection.unwrap();
211
185
212
-
let signing_key = web_context.config.select_oauth_signing_key();
213
-
if let Err(err) = signing_key {
214
-
return contextual_error!(web_context, language, error_template, default_context, err);
215
-
}
186
+
// Convert the DID method key string to KeyData using identify_key
187
+
let private_signing_key_data = match identify_key(&private_signing_key) {
188
+
Ok(key_data) => key_data,
189
+
Err(err) => {
190
+
return contextual_error!(
191
+
web_context,
192
+
language,
193
+
error_template,
194
+
default_context,
195
+
err
196
+
);
197
+
}
198
+
};
216
199
217
-
let (key_id, signing_key) = signing_key.unwrap();
200
+
let private_dpop_key_data = match generate_key(KeyType::P256Private) {
201
+
Ok(value) => value,
202
+
Err(err) => {
203
+
return contextual_error!(
204
+
web_context,
205
+
language,
206
+
error_template,
207
+
default_context,
208
+
err
209
+
);
210
+
}
211
+
};
218
212
219
-
let dpop_jwk = jose::jwk::generate();
220
-
let dpop_secret_key = SecretKey::from_jwk(&dpop_jwk.jwk);
213
+
let private_dpop_key = private_dpop_key_data.to_string();
221
214
222
-
if let Err(err) = dpop_secret_key {
223
-
return contextual_error!(web_context, language, error_template, default_context, err);
224
-
}
225
-
226
-
let dpop_secret_key = dpop_secret_key.unwrap();
215
+
let oauth_client = OAuthClient {
216
+
redirect_uri: format!(
217
+
"https://{}/oauth/callback",
218
+
&web_context.config.external_base
219
+
),
220
+
client_id: format!(
221
+
"https://{}/oauth/client-metadata.json",
222
+
&web_context.config.external_base
223
+
),
224
+
private_signing_key_data,
225
+
};
227
226
228
227
let par_response = oauth_init(
229
228
&web_context.http_client,
230
-
&web_context.config.external_base,
231
-
(&key_id, signing_key),
232
-
&dpop_secret_key,
229
+
&oauth_client,
230
+
&private_dpop_key_data,
233
231
primary_handle,
234
232
&authorization_server,
235
233
&oauth_request_state,
···
242
240
243
241
let par_response = par_response.unwrap();
244
242
243
+
// Store the DID document
244
+
if let Err(err) = web_context
245
+
.document_storage
246
+
.store_document(did_document.clone())
247
+
.await
248
+
{
249
+
return contextual_error!(web_context, language, error_template, default_context, err);
250
+
}
251
+
245
252
let created_at = chrono::Utc::now();
246
253
let expires_at = created_at + chrono::Duration::seconds(par_response.expires_in as i64);
247
254
248
-
if let Err(err) = oauth_request_insert(
249
-
&web_context.pool,
250
-
crate::storage::oauth::OAuthRequestParams {
251
-
oauth_state: Cow::Owned(oauth_request_state.state.clone()),
252
-
issuer: Cow::Owned(authorization_server.issuer.clone()),
253
-
did: Cow::Owned(did_document.id.clone()),
254
-
nonce: Cow::Owned(oauth_request_state.nonce.clone()),
255
-
pkce_verifier: Cow::Owned(pkce_verifier.clone()),
256
-
secret_jwk_id: Cow::Owned(key_id.clone()),
257
-
dpop_jwk: Some(dpop_jwk.clone()),
258
-
destination: login_form.destination.clone().map(Cow::Owned),
259
-
created_at,
260
-
expires_at,
261
-
},
262
-
)
263
-
.await
255
+
let oauth_request = OAuthRequest {
256
+
oauth_state: oauth_request_state.state.clone(),
257
+
issuer: authorization_server.issuer.clone(),
258
+
did: did_document.id.clone(),
259
+
nonce: oauth_request_state.nonce.clone(),
260
+
pkce_verifier: pkce_verifier.clone(),
261
+
signing_public_key: public_signing_key,
262
+
dpop_private_key: private_dpop_key,
263
+
created_at,
264
+
expires_at,
265
+
};
266
+
267
+
if let Err(err) = web_context
268
+
.oauth_storage
269
+
.insert_oauth_request(oauth_request)
270
+
.await
264
271
{
265
272
return contextual_error!(web_context, language, error_template, default_context, err);
266
273
}
···
287
294
stringify(oauth_args)
288
295
);
289
296
290
-
if hx_request {
291
-
if let Ok(hx_redirect) = HxRedirect::try_from(destination.as_str()) {
292
-
return Ok((StatusCode::OK, hx_redirect, "").into_response());
297
+
let hx_redirect = match HxRedirect::try_from(destination.as_str()) {
298
+
Ok(value) => value,
299
+
Err(err) => {
300
+
return contextual_error!(
301
+
web_context,
302
+
language,
303
+
error_template,
304
+
default_context,
305
+
err
306
+
);
293
307
}
294
-
}
308
+
};
295
309
296
-
return Ok(Redirect::temporary(destination.as_str()).into_response());
310
+
return Ok((StatusCode::OK, hx_redirect, "").into_response());
297
311
}
298
312
299
313
Ok(RenderHtml(
···
305
319
)
306
320
.into_response())
307
321
}
308
-
309
-
pub fn gen_pkce() -> (String, String) {
310
-
let token: String = rand::thread_rng()
311
-
.sample_iter(&Alphanumeric)
312
-
.take(100)
313
-
.map(char::from)
314
-
.collect();
315
-
(token.clone(), pkce_challenge(&token))
316
-
}
317
-
318
-
pub fn pkce_challenge(token: &str) -> String {
319
-
let mut hasher = Sha256::new();
320
-
hasher.update(token.as_bytes());
321
-
let result = hasher.finalize();
322
-
323
-
general_purpose::URL_SAFE_NO_PAD.encode(result)
324
-
}
+16
-33
src/http/handle_oauth_logout.rs
+16
-33
src/http/handle_oauth_logout.rs
···
1
1
use anyhow::Result;
2
-
use axum::{
3
-
extract::State,
4
-
response::{IntoResponse, Redirect},
5
-
};
6
-
use axum_extra::extract::{cookie::Cookie, PrivateCookieJar};
7
-
use axum_htmx::{HxRedirect, HxRequest};
8
-
use axum_template::RenderHtml;
9
-
use http::StatusCode;
10
-
use minijinja::context as template_context;
2
+
use axum::extract::State;
3
+
use axum::response::{IntoResponse, Redirect};
4
+
use axum_extra::extract::PrivateCookieJar;
5
+
use cookie::Cookie;
6
+
7
+
use crate::http::{errors::WebError, middleware_auth::AUTH_COOKIE_NAME};
11
8
12
-
use crate::http::{
13
-
context::WebContext, errors::WebError, middleware_auth::AUTH_COOKIE_NAME,
14
-
middleware_i18n::Language,
15
-
};
9
+
use crate::http::context::WebContext;
16
10
17
-
pub async fn handle_logout(
11
+
pub(crate) async fn handle_logout(
18
12
State(web_context): State<WebContext>,
19
-
Language(language): Language,
20
-
HxRequest(hx_request): HxRequest,
21
13
jar: PrivateCookieJar,
22
14
) -> Result<impl IntoResponse, WebError> {
23
-
let updated_jar = jar.remove(Cookie::from(AUTH_COOKIE_NAME));
15
+
let mut removal_cookie = Cookie::from(AUTH_COOKIE_NAME);
16
+
removal_cookie.set_domain(web_context.config.external_base.clone());
17
+
removal_cookie.set_path("/");
24
18
25
-
if hx_request {
26
-
let hx_redirect = HxRedirect::try_from("/");
27
-
if let Err(err) = hx_redirect {
28
-
tracing::error!("Failed to create HxLocation: {}", err);
29
-
return Ok(RenderHtml(
30
-
format!("alert.{}.partial.html", language.to_string().to_lowercase()),
31
-
web_context.engine.clone(),
32
-
template_context! { message => "Internal Server Error" },
33
-
)
34
-
.into_response());
35
-
}
36
-
let hx_redirect = hx_redirect.unwrap();
37
-
Ok((StatusCode::OK, hx_redirect, "").into_response())
38
-
} else {
39
-
Ok((updated_jar, Redirect::to("/")).into_response())
40
-
}
19
+
let updated_jar = jar.remove(removal_cookie);
20
+
21
+
tracing::info!(?updated_jar, "updated cookie jar");
22
+
23
+
Ok((updated_jar, Redirect::to("/")).into_response())
41
24
}
+1
-1
src/http/handle_oauth_metadata.rs
+1
-1
src/http/handle_oauth_metadata.rs
···
57
57
policy_uri: "https://docs.smokesignal.events/docs/about/privacy/",
58
58
redirect_uris,
59
59
response_types: vec!["code"],
60
-
scope: "atproto transition:generic",
60
+
scope: "atproto transition:generic transition:email",
61
61
token_endpoint_auth_method: "private_key_jwt",
62
62
token_endpoint_auth_signing_alg: "ES256",
63
63
subject_type: "public",
+8
-8
src/http/handle_policy.rs
+8
-8
src/http/handle_policy.rs
···
13
13
select_template,
14
14
};
15
15
16
-
pub async fn handle_privacy_policy(
16
+
pub(crate) async fn handle_privacy_policy(
17
17
State(web_context): State<WebContext>,
18
18
HxBoosted(hx_boosted): HxBoosted,
19
19
Language(language): Language,
···
26
26
&render_template,
27
27
web_context.engine.clone(),
28
28
template_context! {
29
-
current_handle => auth.0,
29
+
current_handle => auth.profile(),
30
30
language => language.to_string(),
31
31
canonical_url => format!("https://{}/privacy-policy", web_context.config.external_base),
32
32
},
···
35
35
.into_response())
36
36
}
37
37
38
-
pub async fn handle_terms_of_service(
38
+
pub(crate) async fn handle_terms_of_service(
39
39
State(web_context): State<WebContext>,
40
40
HxBoosted(hx_boosted): HxBoosted,
41
41
Language(language): Language,
···
48
48
&render_template,
49
49
web_context.engine.clone(),
50
50
template_context! {
51
-
current_handle => auth.0,
51
+
current_handle => auth.profile(),
52
52
language => language.to_string(),
53
53
canonical_url => format!("https://{}/terms-of-service", web_context.config.external_base),
54
54
},
···
57
57
.into_response())
58
58
}
59
59
60
-
pub async fn handle_cookie_policy(
60
+
pub(crate) async fn handle_cookie_policy(
61
61
State(web_context): State<WebContext>,
62
62
HxBoosted(hx_boosted): HxBoosted,
63
63
Language(language): Language,
···
70
70
&render_template,
71
71
web_context.engine.clone(),
72
72
template_context! {
73
-
current_handle => auth.0,
73
+
current_handle => auth.profile(),
74
74
language => language.to_string(),
75
75
canonical_url => format!("https://{}/cookie-policy", web_context.config.external_base),
76
76
},
···
79
79
.into_response())
80
80
}
81
81
82
-
pub async fn handle_acknowledgement(
82
+
pub(crate) async fn handle_acknowledgement(
83
83
State(web_context): State<WebContext>,
84
84
HxBoosted(hx_boosted): HxBoosted,
85
85
Language(language): Language,
···
92
92
&render_template,
93
93
web_context.engine.clone(),
94
94
template_context! {
95
-
current_handle => auth.0,
95
+
current_handle => auth.profile(),
96
96
language => language.to_string(),
97
97
canonical_url => format!("https://{}/acknowledgement", web_context.config.external_base),
98
98
},
+2
-2
src/http/handle_profile.rs
+2
-2
src/http/handle_profile.rs
···
24
24
storage::{
25
25
errors::StorageError,
26
26
event::{event_list_did_recently_updated, model::EventWithRole},
27
-
handle::{handle_for_did, handle_for_handle},
27
+
identity_profile::{handle_for_did, handle_for_handle},
28
28
},
29
29
};
30
30
···
49
49
}
50
50
}
51
51
52
-
pub async fn handle_profile_view(
52
+
pub(crate) async fn handle_profile_view(
53
53
ctx: UserRequestContext,
54
54
HxRequest(hx_request): HxRequest,
55
55
HxBoosted(hx_boosted): HxBoosted,
+5
-6
src/http/handle_set_language.rs
+5
-6
src/http/handle_set_language.rs
···
12
12
use std::{borrow::Cow, str::FromStr};
13
13
use unic_langid::LanguageIdentifier;
14
14
15
-
use crate::storage::handle::{handle_update_field, HandleField};
15
+
use crate::storage::identity_profile::{handle_update_field, HandleField};
16
16
17
17
use super::{
18
18
context::WebContext, errors::WebError, middleware_auth::Auth, middleware_i18n::COOKIE_LANG,
···
20
20
};
21
21
22
22
#[derive(Deserialize, Clone)]
23
-
pub struct LanguageForm {
23
+
pub(crate) struct LanguageForm {
24
24
language: String,
25
25
}
26
26
27
-
#[tracing::instrument(skip_all, err)]
28
-
pub async fn handle_set_language(
27
+
pub(crate) async fn handle_set_language(
29
28
State(web_context): State<WebContext>,
30
29
Cached(auth): Cached<Auth>,
31
30
jar: CookieJar,
32
31
Form(language_form): Form<LanguageForm>,
33
32
) -> Result<impl IntoResponse, WebError> {
34
33
let default_context = template_context! {
35
-
current_handle => auth.0,
34
+
current_handle => auth.profile(),
36
35
canonical_url => format!("https://{}/language", web_context.config.external_base),
37
36
};
38
37
···
65
64
}
66
65
let found = found.unwrap();
67
66
68
-
if let Some(handle) = auth.0 {
67
+
if let Some(handle) = auth.profile() {
69
68
if let Err(err) = handle_update_field(
70
69
&web_context.pool,
71
70
&handle.did,
+7
-7
src/http/handle_settings.rs
+7
-7
src/http/handle_settings.rs
···
16
16
timezones::supported_timezones,
17
17
},
18
18
select_template,
19
-
storage::handle::{handle_for_did, handle_update_field, HandleField},
19
+
storage::identity_profile::{handle_for_did, handle_update_field, HandleField},
20
20
};
21
21
22
22
#[derive(Deserialize, Clone, Debug)]
23
-
pub struct TimezoneForm {
23
+
pub(crate) struct TimezoneForm {
24
24
timezone: String,
25
25
}
26
26
27
27
#[derive(Deserialize, Clone, Debug)]
28
-
pub struct LanguageForm {
28
+
pub(crate) struct LanguageForm {
29
29
language: String,
30
30
}
31
31
32
-
pub async fn handle_settings(
32
+
pub(crate) async fn handle_settings(
33
33
State(web_context): State<WebContext>,
34
34
Language(language): Language,
35
35
Cached(auth): Cached<Auth>,
36
36
HxBoosted(hx_boosted): HxBoosted,
37
37
) -> Result<impl IntoResponse, WebError> {
38
38
// Require authentication
39
-
let current_handle = auth.require(&web_context.config.destination_key, "/settings")?;
39
+
let current_handle = auth.require(&web_context.config, "/settings")?;
40
40
41
41
let default_context = template_context! {
42
42
current_handle => current_handle.clone(),
···
74
74
}
75
75
76
76
#[tracing::instrument(skip_all, err)]
77
-
pub async fn handle_timezone_update(
77
+
pub(crate) async fn handle_timezone_update(
78
78
State(web_context): State<WebContext>,
79
79
Language(language): Language,
80
80
Cached(auth): Cached<Auth>,
···
139
139
}
140
140
141
141
#[tracing::instrument(skip_all, err)]
142
-
pub async fn handle_language_update(
142
+
pub(crate) async fn handle_language_update(
143
143
State(web_context): State<WebContext>,
144
144
Language(language): Language,
145
145
Cached(auth): Cached<Auth>,
+17
-17
src/http/handle_view_event.rs
+17
-17
src/http/handle_view_event.rs
···
15
15
use crate::atproto::lexicon::events::smokesignal::calendar::event::NSID as SMOKESIGNAL_EVENT_NSID;
16
16
use crate::contextual_error;
17
17
use crate::http::context::UserRequestContext;
18
-
use crate::http::errors::CommonError;
19
18
use crate::http::errors::ViewEventError;
20
19
use crate::http::errors::WebError;
21
20
use crate::http::event_view::hydrate_event_rsvp_counts;
···
23
22
use crate::http::pagination::Pagination;
24
23
use crate::http::tab_selector::TabSelector;
25
24
use crate::http::utils::url_from_aturi;
26
-
use crate::resolve::parse_input;
27
-
use crate::resolve::InputType;
28
25
use crate::select_template;
29
26
use crate::storage::event::count_event_rsvps;
30
27
use crate::storage::event::event_exists;
31
28
use crate::storage::event::event_get;
32
29
use crate::storage::event::get_event_rsvps;
33
30
use crate::storage::event::get_user_rsvp;
34
-
use crate::storage::handle::handle_for_did;
35
-
use crate::storage::handle::handle_for_handle;
36
-
use crate::storage::handle::model::Handle;
31
+
use crate::storage::identity_profile::handle_for_did;
32
+
use crate::storage::identity_profile::handle_for_handle;
33
+
use crate::storage::identity_profile::model::IdentityProfile;
37
34
use crate::storage::StoragePool;
38
35
39
36
#[derive(Debug, Deserialize, Serialize, PartialEq)]
···
75
72
76
73
/// Helper function to fetch the organizer's handle (which contains their time zone)
77
74
/// This is used to implement the time zone selection logic.
78
-
async fn fetch_organizer_handle(pool: &StoragePool, did: &str) -> Option<Handle> {
75
+
async fn fetch_organizer_handle(pool: &StoragePool, did: &str) -> Option<IdentityProfile> {
79
76
match handle_for_did(pool, did).await {
80
77
Ok(handle) => Some(handle),
81
78
Err(err) => {
···
85
82
}
86
83
}
87
84
88
-
pub async fn handle_view_event(
85
+
pub(crate) async fn handle_view_event(
89
86
ctx: UserRequestContext,
90
87
HxBoosted(hx_boosted): HxBoosted,
91
88
Path((handle_slug, event_rkey)): Path<(String, String)>,
···
101
98
let render_template = select_template!("view_event", hx_boosted, false, ctx.language);
102
99
let error_template = select_template!(hx_boosted, false, ctx.language);
103
100
104
-
let profile: Result<Handle, WebError> = match parse_input(&handle_slug) {
105
-
Ok(InputType::Handle(handle)) => handle_for_handle(&ctx.web_context.pool, &handle)
101
+
let profile: Result<IdentityProfile, WebError> = if handle_slug.starts_with("did:") {
102
+
handle_for_did(&ctx.web_context.pool, &handle_slug)
106
103
.await
107
-
.map_err(|err| err.into()),
108
-
Ok(InputType::Plc(did) | InputType::Web(did)) => {
109
-
handle_for_did(&ctx.web_context.pool, &did)
110
-
.await
111
-
.map_err(|err| err.into())
112
-
}
113
-
_ => Err(CommonError::InvalidHandleSlug.into()),
104
+
.map_err(|err| err.into())
105
+
} else {
106
+
let handle = if let Some(handle) = handle_slug.strip_prefix('@') {
107
+
handle
108
+
} else {
109
+
&handle_slug
110
+
};
111
+
handle_for_handle(&ctx.web_context.pool, handle)
112
+
.await
113
+
.map_err(|err| err.into())
114
114
};
115
115
116
116
if let Err(err) = profile {
+2
-2
src/http/handle_view_feed.rs
+2
-2
src/http/handle_view_feed.rs
···
9
9
context::WebContext, errors::WebError, middleware_auth::Auth, middleware_i18n::Language,
10
10
};
11
11
12
-
pub async fn handle_view_feed(
12
+
pub(crate) async fn handle_view_feed(
13
13
State(web_context): State<WebContext>,
14
14
HxBoosted(hx_boosted): HxBoosted,
15
15
Language(language): Language,
···
25
25
&render_template,
26
26
web_context.engine.clone(),
27
27
template_context! {
28
-
current_handle => auth.0,
28
+
current_handle => auth.profile(),
29
29
language => language.to_string(),
30
30
canonical_url => format!("https://{}/", web_context.config.external_base),
31
31
},
+4
-4
src/http/handle_view_rsvp.rs
+4
-4
src/http/handle_view_rsvp.rs
···
22
22
};
23
23
24
24
#[derive(Deserialize)]
25
-
pub struct RsvpQuery {
26
-
pub aturi: Option<String>,
25
+
pub(crate) struct RsvpQuery {
26
+
aturi: Option<String>,
27
27
}
28
28
29
-
pub async fn handle_view_rsvp(
29
+
pub(crate) async fn handle_view_rsvp(
30
30
State(web_context): State<WebContext>,
31
31
HxBoosted(hx_boosted): HxBoosted,
32
32
HxRequest(hx_request): HxRequest,
···
34
34
Cached(auth): Cached<Auth>,
35
35
query: Query<RsvpQuery>,
36
36
) -> Result<impl IntoResponse, WebError> {
37
-
let current_handle = auth.0.clone();
37
+
let current_handle = auth.profile().cloned();
38
38
39
39
let default_context = template_context! {
40
40
current_handle,
+106
-84
src/http/middleware_auth.rs
+106
-84
src/http/middleware_auth.rs
···
1
1
use anyhow::Result;
2
+
use atproto_oauth::jwt::{mint, Claims, Header, JoseClaims};
2
3
use axum::{
3
4
extract::{FromRef, FromRequestParts},
4
5
http::request::Parts,
5
6
response::Response,
6
7
};
7
8
use axum_extra::extract::PrivateCookieJar;
8
-
use base64::{engine::general_purpose, Engine as _};
9
-
use p256::{
10
-
ecdsa::{signature::Signer, Signature, SigningKey},
11
-
SecretKey,
12
-
};
13
9
use serde::{Deserialize, Serialize};
14
10
use tracing::{debug, instrument, trace};
15
11
16
12
use crate::{
17
13
config::Config,
18
-
encoding::ToBase64,
19
14
http::context::WebContext,
20
15
http::errors::{AuthMiddlewareError, WebSessionError},
21
-
storage::handle::model::Handle,
22
-
storage::oauth::model::OAuthSession,
23
-
storage::oauth::web_session_lookup,
16
+
storage::identity_profile::model::IdentityProfile,
24
17
};
25
18
19
+
use crate::{storage::oauth::model::OAuthSession, storage::oauth::web_session_lookup};
20
+
26
21
use super::errors::middleware_errors::MiddlewareAuthError;
27
22
28
-
pub const AUTH_COOKIE_NAME: &str = "session1";
23
+
pub(crate) const AUTH_COOKIE_NAME: &str = "session";
29
24
30
25
#[derive(Clone, PartialEq, Serialize, Deserialize)]
31
-
pub struct WebSession {
32
-
pub did: String,
33
-
pub session_group: String,
26
+
pub(crate) enum WebSession {
27
+
Pds { did: String, session_group: String },
28
+
Aip { did: String, access_token: String },
34
29
}
35
30
36
31
impl TryFrom<String> for WebSession {
···
53
48
}
54
49
}
55
50
56
-
#[derive(Clone, Serialize, Deserialize)]
57
-
pub struct DestinationClaims {
58
-
#[serde(rename = "d")]
59
-
pub destination: String,
60
-
61
-
#[serde(rename = "n")]
62
-
pub nonce: String,
63
-
}
64
-
65
51
#[derive(Clone)]
66
-
pub struct Auth(pub Option<Handle>, pub Option<OAuthSession>);
52
+
pub(crate) enum Auth {
53
+
Unauthenticated,
54
+
Pds {
55
+
profile: IdentityProfile,
56
+
session: OAuthSession,
57
+
},
58
+
Aip {
59
+
profile: IdentityProfile,
60
+
access_token: String,
61
+
},
62
+
}
67
63
68
64
impl Auth {
65
+
/// Get the profile if authenticated, None otherwise
66
+
pub(crate) fn profile(&self) -> Option<&IdentityProfile> {
67
+
match self {
68
+
Auth::Pds { profile, .. } | Auth::Aip { profile, .. } => Some(profile),
69
+
Auth::Unauthenticated => None,
70
+
}
71
+
}
69
72
/// Requires authentication and redirects to login with a signed token containing the original destination
70
73
///
71
74
/// This creates a redirect URL with a signed token containing the destination,
72
75
/// which the login handler can verify and redirect back to after successful authentication.
73
-
#[instrument(level = "debug", skip(self, secret_key), err)]
74
-
pub fn require(
76
+
#[instrument(level = "debug", skip(self, config), err)]
77
+
pub(crate) fn require(
75
78
&self,
76
-
secret_key: &SecretKey,
79
+
config: &crate::config::Config,
77
80
location: &str,
78
-
) -> Result<Handle, MiddlewareAuthError> {
79
-
if let Some(handle) = &self.0 {
80
-
trace!(did = %handle.did, "User authenticated");
81
-
return Ok(handle.clone());
81
+
) -> Result<IdentityProfile, MiddlewareAuthError> {
82
+
match self {
83
+
Auth::Pds { profile, .. } | Auth::Aip { profile, .. } => {
84
+
trace!(did = %profile.did, "User authenticated");
85
+
return Ok(profile.clone());
86
+
}
87
+
Auth::Unauthenticated => {}
82
88
}
83
89
84
90
debug!(
···
86
92
"Authentication required, creating signed redirect"
87
93
);
88
94
89
-
// Create claims with destination and random nonce
90
-
let claims = DestinationClaims {
91
-
destination: location.to_string(),
92
-
nonce: ulid::Ulid::new().to_string(),
93
-
};
94
-
95
-
// Encode claims to base64
96
-
let claims = claims.to_base64()?;
97
-
let claim_content = claims.to_string();
98
-
let encoded_json_bytes = general_purpose::URL_SAFE_NO_PAD.encode(claims.as_bytes());
99
-
100
-
// Sign the encoded claims
101
-
let signing_key = SigningKey::from(secret_key);
102
-
let signature: Signature = signing_key
103
-
.try_sign(encoded_json_bytes.as_bytes())
95
+
let header: Header = config
96
+
.destination_key_data
97
+
.clone()
98
+
.try_into()
104
99
.map_err(AuthMiddlewareError::SigningFailed)?;
100
+
let claims = Claims::new(JoseClaims {
101
+
http_uri: Some(location.to_string()),
102
+
nonce: Some(ulid::Ulid::new().to_string()),
103
+
..Default::default()
104
+
});
105
105
106
-
// Format the final destination with claims and signature
107
-
let destination = format!(
108
-
"{}.{}",
109
-
claim_content,
110
-
general_purpose::URL_SAFE_NO_PAD.encode(signature.to_bytes())
111
-
);
106
+
let destination_token = mint(&config.destination_key_data, &header, &claims)
107
+
.map_err(AuthMiddlewareError::SigningFailed)?;
112
108
113
-
trace!(
114
-
destination_length = destination.len(),
115
-
"Created signed destination token"
116
-
);
117
-
Err(MiddlewareAuthError::AccessDenied(destination))
109
+
Err(MiddlewareAuthError::AccessDenied(destination_token))
118
110
}
119
111
120
112
/// Simpler authentication check that just redirects to root path
121
113
///
122
114
/// Use this when you don't need to return to the original page after login
123
115
#[instrument(level = "debug", skip(self), err)]
124
-
pub fn require_flat(&self) -> Result<Handle, MiddlewareAuthError> {
125
-
if let Some(handle) = &self.0 {
126
-
trace!(did = %handle.did, "User authenticated");
127
-
return Ok(handle.clone());
116
+
pub(crate) fn require_flat(&self) -> Result<IdentityProfile, MiddlewareAuthError> {
117
+
match self {
118
+
Auth::Pds { profile, .. } | Auth::Aip { profile, .. } => {
119
+
trace!(did = %profile.did, "User authenticated");
120
+
return Ok(profile.clone());
121
+
}
122
+
Auth::Unauthenticated => {}
128
123
}
129
124
130
125
debug!("Authentication required, redirecting to root");
···
135
130
///
136
131
/// Returns NotFound error instead of redirecting to login for security reasons
137
132
#[instrument(level = "debug", skip(self, config), err)]
138
-
pub fn require_admin(&self, config: &Config) -> Result<Handle, MiddlewareAuthError> {
139
-
if let Some(handle) = &self.0 {
140
-
if config.is_admin(&handle.did) {
141
-
debug!(did = %handle.did, "Admin authenticated");
142
-
return Ok(handle.clone());
133
+
pub(crate) fn require_admin(
134
+
&self,
135
+
config: &Config,
136
+
) -> Result<IdentityProfile, MiddlewareAuthError> {
137
+
match self {
138
+
Auth::Pds { profile, .. } | Auth::Aip { profile, .. } => {
139
+
if config.is_admin(&profile.did) {
140
+
debug!(did = %profile.did, "Admin authenticated");
141
+
return Ok(profile.clone());
142
+
}
143
+
debug!(did = %profile.did, "User not an admin");
143
144
}
144
-
debug!(did = %handle.did, "User not an admin");
145
-
} else {
146
-
debug!("No authentication found for admin check");
145
+
Auth::Unauthenticated => {
146
+
debug!("No authentication found for admin check");
147
+
}
147
148
}
148
149
149
150
// Return NotFound instead of redirect for security reasons
···
173
174
.and_then(|inner_value| WebSession::try_from(inner_value).ok());
174
175
175
176
if let Some(web_session) = session {
176
-
trace!(?web_session.session_group, "Found session cookie");
177
+
match web_session {
178
+
WebSession::Pds { did, session_group } => {
179
+
trace!(?session_group, "Found PDS session cookie");
177
180
178
-
match web_session_lookup(
179
-
&web_context.pool,
180
-
&web_session.session_group,
181
-
Some(&web_session.did),
182
-
)
183
-
.await
184
-
{
185
-
Ok(record) => {
186
-
debug!(?web_session.session_group, "Session validated");
187
-
return Ok(Self(Some(record.0), Some(record.1)));
181
+
match web_session_lookup(&web_context.pool, &session_group, Some(&did)).await {
182
+
Ok(record) => {
183
+
debug!(?session_group, "Session validated");
184
+
return Ok(Auth::Pds {
185
+
profile: record.0,
186
+
session: record.1,
187
+
});
188
+
}
189
+
Err(err) => {
190
+
debug!(?session_group, ?err, "Invalid session");
191
+
return Ok(Auth::Unauthenticated);
192
+
}
193
+
};
188
194
}
189
-
Err(err) => {
190
-
debug!(?web_session.session_group, ?err, "Invalid session");
191
-
return Ok(Self(None, None));
195
+
WebSession::Aip { did, access_token } => {
196
+
trace!(?access_token, "Found AIP session cookie with access token");
197
+
// For AIP OAuth, try to fetch the handle from the database
198
+
let profile = sqlx::query_as::<_, IdentityProfile>(
199
+
"SELECT * FROM identity_profiles WHERE did = $1",
200
+
)
201
+
.bind(&did)
202
+
.fetch_one(&web_context.pool)
203
+
.await
204
+
.ok();
205
+
206
+
if let Some(profile) = profile {
207
+
return Ok(Auth::Aip {
208
+
profile,
209
+
access_token,
210
+
});
211
+
} else {
212
+
return Ok(Auth::Unauthenticated);
213
+
}
192
214
}
193
-
};
215
+
}
194
216
}
195
217
196
218
trace!("No session cookie found");
197
-
Ok(Self(None, None))
219
+
Ok(Auth::Unauthenticated)
198
220
}
199
221
}
+2
-2
src/http/middleware_i18n.rs
+2
-2
src/http/middleware_i18n.rs
···
85
85
86
86
/// Wrapper around LanguageIdentifier for the current request's language
87
87
#[derive(Clone, Debug)]
88
-
pub struct Language(pub LanguageIdentifier);
88
+
pub(crate) struct Language(pub(crate) LanguageIdentifier);
89
89
90
90
impl std::fmt::Display for Language {
91
91
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
···
120
120
let auth: Auth = Cached::<Auth>::from_request_parts(parts, context).await?.0;
121
121
122
122
// 1. Try to get language from user's profile settings
123
-
if let Some(handle) = &auth.0 {
123
+
if let Some(handle) = auth.profile() {
124
124
if let Ok(auth_lang) = handle.language.parse::<LanguageIdentifier>() {
125
125
debug!(language = %auth_lang, "Using language from user profile");
126
126
return Ok(Self(auth_lang));
+5
-2
src/http/mod.rs
+5
-2
src/http/mod.rs
···
14
14
pub mod handle_admin_rsvps;
15
15
pub mod handle_create_event;
16
16
pub mod handle_create_rsvp;
17
+
pub mod handle_delete_event;
17
18
pub mod handle_edit_event;
18
19
pub mod handle_import;
19
20
pub mod handle_index;
20
21
pub mod handle_migrate_event;
21
22
pub mod handle_migrate_rsvp;
23
+
24
+
pub mod handle_oauth_aip_callback;
25
+
pub mod handle_oauth_aip_login;
22
26
pub mod handle_oauth_callback;
23
-
pub mod handle_oauth_jwks;
24
27
pub mod handle_oauth_login;
28
+
25
29
pub mod handle_oauth_logout;
26
-
pub mod handle_oauth_metadata;
27
30
pub mod handle_policy;
28
31
pub mod handle_profile;
29
32
pub mod handle_set_language;
+17
-24
src/http/pagination.rs
+17
-24
src/http/pagination.rs
···
1
1
use crate::http::utils::stringify;
2
2
use serde::{Deserialize, Serialize};
3
3
4
-
pub const PAGE_DEFAULT: i64 = 1;
5
-
pub const PAGE_MIN: i64 = 1;
6
-
pub const PAGE_MAX: i64 = 100;
7
-
pub const PAGE_SIZE_DEFAULT: i64 = 10;
8
-
pub const PAGE_SIZE_MIN: i64 = 5;
9
-
pub const PAGE_SIZE_MAX: i64 = 100;
10
-
11
-
pub const LIMITED_PAGE_DEFAULT: i64 = 1;
12
-
pub const LIMITED_PAGE_MIN: i64 = 1;
13
-
pub const LIMITED_PAGE_MAX: i64 = 5;
14
-
pub const LIMITED_PAGE_SIZE_DEFAULT: i64 = 5;
15
-
pub const LIMITED_PAGE_SIZE_MIN: i64 = 5;
16
-
pub const LIMITED_PAGE_SIZE_MAX: i64 = 5;
4
+
pub(crate) const PAGE_DEFAULT: i64 = 1;
5
+
pub(crate) const PAGE_MIN: i64 = 1;
6
+
pub(crate) const PAGE_MAX: i64 = 100;
7
+
pub(crate) const PAGE_SIZE_DEFAULT: i64 = 10;
8
+
pub(crate) const PAGE_SIZE_MIN: i64 = 5;
9
+
pub(crate) const PAGE_SIZE_MAX: i64 = 100;
17
10
18
11
#[derive(Deserialize, Default)]
19
-
pub struct Pagination {
20
-
pub page: Option<i64>,
21
-
pub page_size: Option<i64>,
12
+
pub(crate) struct Pagination {
13
+
pub(crate) page: Option<i64>,
14
+
pub(crate) page_size: Option<i64>,
22
15
}
23
16
24
17
#[derive(Serialize, Debug)]
25
-
pub struct PaginationView {
26
-
pub previous: Option<i64>,
27
-
pub previous_url: Option<String>,
28
-
pub next: Option<i64>,
29
-
pub next_url: Option<String>,
18
+
pub(crate) struct PaginationView {
19
+
pub(crate) previous: Option<i64>,
20
+
pub(crate) previous_url: Option<String>,
21
+
pub(crate) next: Option<i64>,
22
+
pub(crate) next_url: Option<String>,
30
23
}
31
24
32
25
impl Pagination {
33
-
pub fn admin_clamped(&self) -> (i64, i64) {
26
+
pub(crate) fn admin_clamped(&self) -> (i64, i64) {
34
27
let page = self.page.unwrap_or(1).clamp(1, 25000);
35
28
let page_size = self.page_size.unwrap_or(1).clamp(20, 100);
36
29
(page, page_size)
37
30
}
38
31
39
-
pub fn clamped(&self) -> (i64, i64) {
32
+
pub(crate) fn clamped(&self) -> (i64, i64) {
40
33
let page = self.page.unwrap_or(PAGE_DEFAULT).clamp(PAGE_MIN, PAGE_MAX);
41
34
let page_size = self
42
35
.page_size
···
47
40
}
48
41
49
42
impl PaginationView {
50
-
pub fn new(page_size: i64, total: i64, page: i64, params: Vec<(&str, &str)>) -> Self {
43
+
pub(crate) fn new(page_size: i64, total: i64, page: i64, params: Vec<(&str, &str)>) -> Self {
51
44
let (previous, previous_url) = {
52
45
if page > 1 {
53
46
let page_value = (page - 1).to_string();
+11
-21
src/http/rsvp_form.rs
+11
-21
src/http/rsvp_form.rs
···
1
1
use serde::{Deserialize, Serialize};
2
-
use thiserror::Error;
3
2
4
3
use crate::{
5
4
errors::expand_error,
···
7
6
storage::{event::event_get_cid, StoragePool},
8
7
};
9
8
10
-
#[derive(Debug, Error)]
11
-
pub enum BuildRSVPError {
12
-
#[error("error-rsvp-builder-1 Invalid Subject")]
13
-
InvalidSubject,
14
-
15
-
#[error("error-rsvp-builder-2 Invalid Status")]
16
-
InvalidStatus,
17
-
}
18
-
19
9
#[derive(Serialize, Deserialize, Debug, Default, PartialEq, Clone)]
20
-
pub enum BuildRsvpContentState {
10
+
pub(crate) enum BuildRsvpContentState {
21
11
#[default]
22
12
Reset,
23
13
Selecting,
···
26
16
}
27
17
28
18
#[derive(Serialize, Deserialize, Debug, Clone)]
29
-
pub struct BuildRSVPForm {
30
-
pub build_state: Option<BuildRsvpContentState>,
19
+
pub(crate) struct BuildRSVPForm {
20
+
pub(crate) build_state: Option<BuildRsvpContentState>,
31
21
32
-
pub subject_aturi: Option<String>,
33
-
pub subject_aturi_error: Option<String>,
22
+
pub(crate) subject_aturi: Option<String>,
23
+
pub(crate) subject_aturi_error: Option<String>,
34
24
35
-
pub subject_cid: Option<String>,
36
-
pub subject_cid_error: Option<String>,
25
+
pub(crate) subject_cid: Option<String>,
26
+
pub(crate) subject_cid_error: Option<String>,
37
27
38
-
pub status: Option<String>,
39
-
pub status_error: Option<String>,
28
+
pub(crate) status: Option<String>,
29
+
pub(crate) status_error: Option<String>,
40
30
}
41
31
42
32
impl BuildRSVPForm {
43
-
pub async fn hydrate(
33
+
pub(crate) async fn hydrate(
44
34
&mut self,
45
35
database_pool: &StoragePool,
46
36
locales: &Locales,
···
70
60
}
71
61
}
72
62
73
-
pub fn validate(
63
+
pub(crate) fn validate(
74
64
&mut self,
75
65
_locales: &Locales,
76
66
_language: &unic_langid::LanguageIdentifier,
+33
-11
src/http/server.rs
+33
-11
src/http/server.rs
···
15
15
use tower_http::{cors::CorsLayer, services::ServeDir};
16
16
use tracing::Span;
17
17
18
+
use crate::config::OAuthBackendConfig;
18
19
use crate::http::{
19
20
context::WebContext,
20
21
handle_admin_denylist::{
···
33
34
handle_location_datalist, handle_starts_at_builder,
34
35
},
35
36
handle_create_rsvp::handle_create_rsvp,
37
+
handle_delete_event::handle_delete_event,
36
38
handle_edit_event::handle_edit_event,
37
39
handle_import::{handle_import, handle_import_submit},
38
40
handle_index::handle_index,
39
41
handle_migrate_event::handle_migrate_event,
40
42
handle_migrate_rsvp::handle_migrate_rsvp,
41
-
handle_oauth_callback::handle_oauth_callback,
42
-
handle_oauth_jwks::handle_oauth_jwks,
43
-
handle_oauth_login::handle_oauth_login,
44
43
handle_oauth_logout::handle_logout,
45
-
handle_oauth_metadata::handle_oauth_metadata,
46
44
handle_policy::{
47
45
handle_acknowledgement, handle_cookie_policy, handle_privacy_policy,
48
46
handle_terms_of_service,
···
55
53
handle_view_rsvp::handle_view_rsvp,
56
54
};
57
55
56
+
use crate::http::handle_oauth_aip_callback::handle_oauth_callback as handle_oauth_aip_callback;
57
+
use crate::http::handle_oauth_aip_login::handle_oauth_aip_login;
58
+
use crate::http::handle_oauth_callback::handle_oauth_callback as handle_oauth_pds_callback;
59
+
use crate::http::handle_oauth_login::handle_oauth_login as handle_oauth_pds_login;
60
+
use atproto_oauth_axum::handler_metadata::handle_oauth_metadata;
61
+
58
62
pub fn build_router(web_context: WebContext) -> Router {
59
63
let serve_dir = ServeDir::new(web_context.config.http_static_path.clone());
60
64
61
-
Router::new()
65
+
let mut router = Router::new()
62
66
.route("/", get(handle_index))
63
67
.route("/privacy-policy", get(handle_privacy_policy))
64
68
.route("/terms-of-service", get(handle_terms_of_service))
65
69
.route("/cookie-policy", get(handle_cookie_policy))
66
-
.route("/acknowledgement", get(handle_acknowledgement))
70
+
.route("/acknowledgement", get(handle_acknowledgement));
71
+
72
+
// Add OAuth metadata route only for AT Protocol backend
73
+
if matches!(
74
+
web_context.config.oauth_backend,
75
+
OAuthBackendConfig::ATProtocol { .. }
76
+
) {
77
+
router = router.route("/oauth/client-metadata.json", get(handle_oauth_metadata));
78
+
}
79
+
80
+
// Add OAuth routes based on backend configuration
81
+
router = match web_context.config.oauth_backend {
82
+
OAuthBackendConfig::AIP { .. } => router
83
+
.route("/oauth/login", get(handle_oauth_aip_login))
84
+
.route("/oauth/login", post(handle_oauth_aip_login))
85
+
.route("/oauth/callback", get(handle_oauth_aip_callback)),
86
+
OAuthBackendConfig::ATProtocol { .. } => router
87
+
.route("/oauth/login", get(handle_oauth_pds_login))
88
+
.route("/oauth/login", post(handle_oauth_pds_login))
89
+
.route("/oauth/callback", get(handle_oauth_pds_callback)),
90
+
};
91
+
92
+
router
67
93
.route("/admin", get(handle_admin_index))
68
94
.route("/admin/handles", get(handle_admin_handles))
69
95
.route(
···
79
105
.route("/admin/rsvps", get(handle_admin_rsvps))
80
106
.route("/admin/rsvp", get(handle_admin_rsvp))
81
107
.route("/admin/rsvps/import", post(handle_admin_import_rsvp))
82
-
.route("/oauth/client-metadata.json", get(handle_oauth_metadata))
83
-
.route("/.well-known/jwks.json", get(handle_oauth_jwks))
84
-
.route("/oauth/login", get(handle_oauth_login))
85
-
.route("/oauth/login", post(handle_oauth_login))
86
-
.route("/oauth/callback", get(handle_oauth_callback))
87
108
.route("/logout", get(handle_logout))
88
109
.route("/language", post(handle_set_language))
89
110
.route("/settings", get(handle_settings))
···
105
126
.route("/event/links", post(handle_link_at_builder))
106
127
.route("/{handle_slug}/{event_rkey}/edit", get(handle_edit_event))
107
128
.route("/{handle_slug}/{event_rkey}/edit", post(handle_edit_event))
129
+
.route("/{handle_slug}/{event_rkey}/delete", post(handle_delete_event))
108
130
.route(
109
131
"/{handle_slug}/{event_rkey}/migrate",
110
132
get(handle_migrate_event),
+7
-7
src/http/tab_selector.rs
+7
-7
src/http/tab_selector.rs
···
1
1
use serde::{Deserialize, Serialize};
2
2
3
3
#[derive(Deserialize, Default)]
4
-
pub struct TabSelector {
5
-
pub tab: Option<String>,
4
+
pub(crate) struct TabSelector {
5
+
pub(crate) tab: Option<String>,
6
6
}
7
7
8
8
#[derive(Deserialize, Serialize)]
9
-
pub struct TabLink {
10
-
pub name: String,
11
-
pub label: String,
12
-
pub url: String,
13
-
pub active: bool,
9
+
pub(crate) struct TabLink {
10
+
pub(crate) name: String,
11
+
pub(crate) label: String,
12
+
pub(crate) url: String,
13
+
pub(crate) active: bool,
14
14
}
+1
-1
src/http/templates.rs
+1
-1
src/http/templates.rs
+2
-2
src/http/timezones.rs
+2
-2
src/http/timezones.rs
···
2
2
use chrono::{DateTime, NaiveDateTime, Utc};
3
3
use itertools::Itertools;
4
4
5
-
use crate::storage::handle::model::Handle;
5
+
use crate::storage::identity_profile::model::IdentityProfile;
6
6
7
-
pub fn supported_timezones(handle: Option<&Handle>) -> (&str, Vec<&str>) {
7
+
pub(crate) fn supported_timezones(handle: Option<&IdentityProfile>) -> (&str, Vec<&str>) {
8
8
let handle_tz = handle
9
9
.and_then(|handle| handle.tz.parse().ok())
10
10
.unwrap_or(chrono_tz::UTC);
+11
-11
src/http/utils.rs
+11
-11
src/http/utils.rs
···
6
6
http::errors::UrlError,
7
7
};
8
8
9
-
pub type QueryParam<'a> = (&'a str, &'a str);
10
-
pub type QueryParams<'a> = Vec<QueryParam<'a>>;
9
+
pub(crate) type QueryParam<'a> = (&'a str, &'a str);
10
+
pub(crate) type QueryParams<'a> = Vec<QueryParam<'a>>;
11
11
12
-
pub fn stringify(query: QueryParams) -> String {
12
+
pub(crate) fn stringify(query: QueryParams) -> String {
13
13
query.iter().fold(String::new(), |acc, &tuple| {
14
14
acc + tuple.0 + "=" + tuple.1 + "&"
15
15
})
16
16
}
17
17
18
-
pub struct URLBuilder {
18
+
struct URLBuilder {
19
19
host: String,
20
20
path: String,
21
21
params: Vec<(String, String)>,
22
22
}
23
23
24
-
pub fn build_url(host: &str, path: &str, params: Vec<Option<(&str, &str)>>) -> String {
24
+
pub(crate) fn build_url(host: &str, path: &str, params: Vec<Option<(&str, &str)>>) -> String {
25
25
let mut url_builder = URLBuilder::new(host);
26
26
url_builder.path(path);
27
27
···
33
33
}
34
34
35
35
impl URLBuilder {
36
-
pub fn new(host: &str) -> URLBuilder {
36
+
fn new(host: &str) -> URLBuilder {
37
37
let host = if host.starts_with("https://") {
38
38
host.to_string()
39
39
} else {
···
53
53
}
54
54
}
55
55
56
-
pub fn param(&mut self, key: &str, value: &str) -> &mut Self {
56
+
fn param(&mut self, key: &str, value: &str) -> &mut Self {
57
57
self.params
58
58
.push((key.to_owned(), urlencoding::encode(value).to_string()));
59
59
self
60
60
}
61
61
62
-
pub fn path(&mut self, path: &str) -> &mut Self {
62
+
fn path(&mut self, path: &str) -> &mut Self {
63
63
path.clone_into(&mut self.path);
64
64
self
65
65
}
66
66
67
-
pub fn build(self) -> String {
67
+
fn build(self) -> String {
68
68
let mut url_params = String::new();
69
69
70
70
if !self.params.is_empty() {
···
78
78
}
79
79
}
80
80
81
-
pub fn url_from_aturi(external_base: &str, aturi: &str) -> Result<String, UrlError> {
81
+
pub(crate) fn url_from_aturi(external_base: &str, aturi: &str) -> Result<String, UrlError> {
82
82
let aturi = aturi.strip_prefix("at://").unwrap_or(aturi);
83
83
let parts = aturi.split("/").collect::<Vec<_>>();
84
84
if parts.len() == 3 && parts[1] == NSID {
···
105
105
clen
106
106
}
107
107
108
-
pub fn truncate_text(text: &str, tlen: usize, suffix: Option<String>) -> String {
108
+
pub(crate) fn truncate_text(text: &str, tlen: usize, suffix: Option<String>) -> String {
109
109
if text.len() <= tlen {
110
110
return text.to_string();
111
111
}
-243
src/jose.rs
-243
src/jose.rs
···
1
-
use base64::{engine::general_purpose, Engine as _};
2
-
use jwt::{Claims, Header};
3
-
use p256::{
4
-
ecdsa::{
5
-
signature::{Signer, Verifier},
6
-
Signature, SigningKey, VerifyingKey,
7
-
},
8
-
PublicKey, SecretKey,
9
-
};
10
-
use std::time::{SystemTime, UNIX_EPOCH};
11
-
12
-
use crate::encoding::ToBase64;
13
-
use crate::jose_errors::JoseError;
14
-
15
-
/// Signs a JWT token with the provided secret key, header, and claims
16
-
///
17
-
/// Creates a JSON Web Token (JWT) by:
18
-
/// 1. Base64URL encoding the header and claims
19
-
/// 2. Signing the encoded header and claims with the secret key
20
-
/// 3. Returning the complete JWT (header.claims.signature)
21
-
pub fn mint_token(
22
-
secret_key: &SecretKey,
23
-
header: &Header,
24
-
claims: &Claims,
25
-
) -> Result<String, JoseError> {
26
-
// Encode header and claims to base64url
27
-
let header = header
28
-
.to_base64()
29
-
.map_err(|_| JoseError::SigningKeyNotFound)?;
30
-
let claims = claims
31
-
.to_base64()
32
-
.map_err(|_| JoseError::SigningKeyNotFound)?;
33
-
let content = format!("{}.{}", header, claims);
34
-
35
-
// Create signature
36
-
let signing_key = SigningKey::from(secret_key.clone());
37
-
let signature: Signature = signing_key
38
-
.try_sign(content.as_bytes())
39
-
.map_err(JoseError::SigningFailed)?;
40
-
41
-
// Return complete JWT
42
-
Ok(format!(
43
-
"{}.{}",
44
-
content,
45
-
general_purpose::URL_SAFE_NO_PAD.encode(signature.to_bytes())
46
-
))
47
-
}
48
-
49
-
/// Verifies a JWT token's signature and validates its claims
50
-
///
51
-
/// Performs the following validations:
52
-
/// 1. Checks token format is valid (three parts separated by periods)
53
-
/// 2. Decodes header and claims from base64url format
54
-
/// 3. Verifies the token signature using the provided public key
55
-
/// 4. Validates token expiration (if provided in claims)
56
-
/// 5. Validates token not-before time (if provided in claims)
57
-
/// 6. Returns the decoded claims if all validation passes
58
-
pub fn verify_token(token: &str, public_key: &PublicKey) -> Result<Claims, JoseError> {
59
-
// Split token into its parts
60
-
let parts: Vec<&str> = token.split('.').collect();
61
-
if parts.len() != 3 {
62
-
return Err(JoseError::InvalidTokenFormat);
63
-
}
64
-
65
-
let encoded_header = parts[0];
66
-
let encoded_claims = parts[1];
67
-
let encoded_signature = parts[2];
68
-
69
-
// Decode header
70
-
let header_bytes = general_purpose::URL_SAFE_NO_PAD
71
-
.decode(encoded_header)
72
-
.map_err(|_| JoseError::InvalidHeader)?;
73
-
74
-
let header: Header =
75
-
serde_json::from_slice(&header_bytes).map_err(|_| JoseError::InvalidHeader)?;
76
-
77
-
// Verify algorithm matches what we expect
78
-
// We only support ES256 for now
79
-
if header.algorithm.as_deref() != Some("ES256") {
80
-
return Err(JoseError::UnsupportedAlgorithm);
81
-
}
82
-
83
-
// Decode claims
84
-
let claims_bytes = general_purpose::URL_SAFE_NO_PAD
85
-
.decode(encoded_claims)
86
-
.map_err(|_| JoseError::InvalidClaims)?;
87
-
88
-
let claims: Claims =
89
-
serde_json::from_slice(&claims_bytes).map_err(|_| JoseError::InvalidClaims)?;
90
-
91
-
// Decode signature
92
-
let signature_bytes = general_purpose::URL_SAFE_NO_PAD
93
-
.decode(encoded_signature)
94
-
.map_err(|_| JoseError::InvalidSignature)?;
95
-
96
-
let signature =
97
-
Signature::try_from(signature_bytes.as_slice()).map_err(|_| JoseError::InvalidSignature)?;
98
-
99
-
// Verify signature
100
-
let verifying_key = VerifyingKey::from(public_key);
101
-
let content = format!("{}.{}", encoded_header, encoded_claims);
102
-
103
-
verifying_key
104
-
.verify(content.as_bytes(), &signature)
105
-
.map_err(|_| JoseError::SignatureVerificationFailed)?;
106
-
107
-
// Get current timestamp for validation
108
-
let now = SystemTime::now()
109
-
.duration_since(UNIX_EPOCH)
110
-
.map_err(|_| JoseError::SystemTimeError)?
111
-
.as_secs();
112
-
113
-
// Validate expiration time if present
114
-
if let Some(exp) = claims.jose.expiration {
115
-
if now >= exp {
116
-
return Err(JoseError::TokenExpired);
117
-
}
118
-
}
119
-
120
-
// Validate not-before time if present
121
-
if let Some(nbf) = claims.jose.not_before {
122
-
if now < nbf {
123
-
return Err(JoseError::TokenNotYetValid);
124
-
}
125
-
}
126
-
127
-
// Return validated claims
128
-
Ok(claims)
129
-
}
130
-
131
-
pub mod jwk {
132
-
use elliptic_curve::JwkEcKey;
133
-
use p256::SecretKey;
134
-
use rand::rngs::OsRng;
135
-
use serde::{Deserialize, Serialize};
136
-
137
-
#[derive(Serialize, Deserialize, Clone, PartialEq, Debug)]
138
-
pub struct WrappedJsonWebKey {
139
-
#[serde(skip_serializing_if = "Option::is_none", default)]
140
-
pub kid: Option<String>,
141
-
142
-
#[serde(skip_serializing_if = "Option::is_none", default)]
143
-
pub alg: Option<String>,
144
-
145
-
#[serde(flatten)]
146
-
pub jwk: JwkEcKey,
147
-
}
148
-
149
-
#[derive(Serialize, Deserialize, Clone)]
150
-
pub struct WrappedJsonWebKeySet {
151
-
pub keys: Vec<WrappedJsonWebKey>,
152
-
}
153
-
154
-
pub fn generate() -> WrappedJsonWebKey {
155
-
let secret_key = SecretKey::random(&mut OsRng);
156
-
157
-
let kid = ulid::Ulid::new().to_string();
158
-
159
-
WrappedJsonWebKey {
160
-
kid: Some(kid),
161
-
alg: Some("ES256".to_string()),
162
-
jwk: secret_key.to_jwk(),
163
-
}
164
-
}
165
-
}
166
-
167
-
pub mod jwt {
168
-
169
-
use std::collections::BTreeMap;
170
-
171
-
use elliptic_curve::JwkEcKey;
172
-
use serde::{Deserialize, Serialize};
173
-
174
-
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
175
-
pub struct Header {
176
-
#[serde(rename = "alg", skip_serializing_if = "Option::is_none")]
177
-
pub algorithm: Option<String>,
178
-
179
-
#[serde(rename = "kid", skip_serializing_if = "Option::is_none")]
180
-
pub key_id: Option<String>,
181
-
182
-
#[serde(rename = "typ", skip_serializing_if = "Option::is_none")]
183
-
pub type_: Option<String>,
184
-
185
-
#[serde(rename = "jwk", skip_serializing_if = "Option::is_none")]
186
-
pub json_web_key: Option<JwkEcKey>,
187
-
}
188
-
189
-
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
190
-
pub struct Claims {
191
-
#[serde(flatten)]
192
-
pub jose: JoseClaims,
193
-
#[serde(flatten)]
194
-
pub private: BTreeMap<String, serde_json::Value>,
195
-
}
196
-
197
-
impl Claims {
198
-
pub fn new(jose: JoseClaims) -> Self {
199
-
Claims {
200
-
jose,
201
-
private: BTreeMap::new(),
202
-
}
203
-
}
204
-
}
205
-
206
-
pub type SecondsSinceEpoch = u64;
207
-
208
-
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize)]
209
-
pub struct JoseClaims {
210
-
#[serde(rename = "iss", skip_serializing_if = "Option::is_none")]
211
-
pub issuer: Option<String>,
212
-
213
-
#[serde(rename = "sub", skip_serializing_if = "Option::is_none")]
214
-
pub subject: Option<String>,
215
-
216
-
#[serde(rename = "aud", skip_serializing_if = "Option::is_none")]
217
-
pub audience: Option<String>,
218
-
219
-
#[serde(rename = "exp", skip_serializing_if = "Option::is_none")]
220
-
pub expiration: Option<SecondsSinceEpoch>,
221
-
222
-
#[serde(rename = "nbf", skip_serializing_if = "Option::is_none")]
223
-
pub not_before: Option<SecondsSinceEpoch>,
224
-
225
-
#[serde(rename = "iat", skip_serializing_if = "Option::is_none")]
226
-
pub issued_at: Option<SecondsSinceEpoch>,
227
-
228
-
#[serde(rename = "jti", skip_serializing_if = "Option::is_none")]
229
-
pub json_web_token_id: Option<String>,
230
-
231
-
#[serde(rename = "htm", skip_serializing_if = "Option::is_none")]
232
-
pub http_method: Option<String>,
233
-
234
-
#[serde(rename = "htu", skip_serializing_if = "Option::is_none")]
235
-
pub http_uri: Option<String>,
236
-
237
-
#[serde(rename = "nonce", skip_serializing_if = "Option::is_none")]
238
-
pub nonce: Option<String>,
239
-
240
-
#[serde(rename = "ath", skip_serializing_if = "Option::is_none")]
241
-
pub auth: Option<String>,
242
-
}
243
-
}
-140
src/jose_errors.rs
-140
src/jose_errors.rs
···
1
-
use thiserror::Error;
2
-
3
-
/// Represents errors that can occur during JOSE (JSON Object Signing and Encryption) operations.
4
-
///
5
-
/// These errors are related to JSON Web Token (JWT) signing and verification,
6
-
/// JSON Web Key (JWK) operations, and DPoP (Demonstrating Proof-of-Possession) functionality.
7
-
#[derive(Debug, Error)]
8
-
pub enum JoseError {
9
-
/// Error when token signing fails.
10
-
///
11
-
/// This error occurs when the application tries to sign a JWT token
12
-
/// using an ECDSA signing key but the signing operation fails.
13
-
#[error("error-jose-1 Failed to sign token: {0:?}")]
14
-
SigningFailed(p256::ecdsa::Error),
15
-
16
-
/// Error when a required signing key is not found.
17
-
///
18
-
/// This error occurs when the application tries to use a signing key
19
-
/// that is not available in the loaded configuration.
20
-
#[error("error-jose-2 Signing key not found")]
21
-
SigningKeyNotFound,
22
-
23
-
/// Error when a simple error cannot be parsed.
24
-
///
25
-
/// This error occurs when the application fails to parse an error
26
-
/// response from an OAuth server.
27
-
#[error("error-jose-3 Unable to parse simple error")]
28
-
UnableToParseSimpleError,
29
-
30
-
/// Error when a required DPoP header is missing.
31
-
///
32
-
/// This error occurs when making a request to a protected resource
33
-
/// that requires a DPoP header, but the header is not present.
34
-
#[error("error-jose-4 Missing DPoP header")]
35
-
MissingDpopHeader,
36
-
37
-
/// Error when a DPoP header cannot be parsed.
38
-
///
39
-
/// This error occurs when the application receives a DPoP header
40
-
/// that is malformed or contains invalid data.
41
-
#[error("error-jose-5 Unable to parse DPoP header: {0}")]
42
-
UnableToParseDpopHeader(String),
43
-
44
-
/// Error when a DPoP proof token cannot be created.
45
-
///
46
-
/// This error occurs when the application fails to create a valid
47
-
/// DPoP proof token required for accessing protected resources.
48
-
#[error("error-jose-6 Unable to mint DPoP proof token: {0}")]
49
-
UnableToMintDpopProofToken(String),
50
-
51
-
/// Error when an unexpected error occurs during JOSE operations.
52
-
///
53
-
/// This is a catch-all error for unexpected issues that occur
54
-
/// during JOSE-related operations.
55
-
#[error("error-jose-7 Unexpected error: {0}")]
56
-
UnexpectedError(String),
57
-
58
-
/// Error when a JWT token has an invalid format.
59
-
///
60
-
/// This error occurs when a JWT token doesn't have three parts
61
-
/// separated by periods (header.payload.signature).
62
-
#[error("error-jose-8 Invalid token format")]
63
-
InvalidTokenFormat,
64
-
65
-
/// Error when a JWT header cannot be decoded or parsed.
66
-
///
67
-
/// This error occurs when the header part of a JWT token contains
68
-
/// invalid base64url-encoded data or invalid JSON.
69
-
#[error("error-jose-9 Invalid token header")]
70
-
InvalidHeader,
71
-
72
-
/// Error when a JWT claims part cannot be decoded or parsed.
73
-
///
74
-
/// This error occurs when the claims part of a JWT token contains
75
-
/// invalid base64url-encoded data or invalid JSON.
76
-
#[error("error-jose-10 Invalid token claims")]
77
-
InvalidClaims,
78
-
79
-
/// Error when a JWT signature cannot be decoded.
80
-
///
81
-
/// This error occurs when the signature part of a JWT token contains
82
-
/// invalid base64url-encoded data.
83
-
#[error("error-jose-11 Invalid token signature")]
84
-
InvalidSignature,
85
-
86
-
/// Error when JWT signature verification fails.
87
-
///
88
-
/// This error occurs when the signature of a JWT token doesn't match
89
-
/// the expected signature computed from the header and claims.
90
-
#[error("error-jose-12 Signature verification failed")]
91
-
SignatureVerificationFailed,
92
-
93
-
/// Error when a JWT token has expired.
94
-
///
95
-
/// This error occurs when the current time is past the expiration
96
-
/// time (exp) specified in the JWT claims.
97
-
#[error("error-jose-13 Token has expired")]
98
-
TokenExpired,
99
-
100
-
/// Error when a JWT token is not yet valid.
101
-
///
102
-
/// This error occurs when the current time is before the not-before
103
-
/// time (nbf) specified in the JWT claims.
104
-
#[error("error-jose-14 Token is not yet valid")]
105
-
TokenNotYetValid,
106
-
107
-
/// Error when the system time cannot be determined.
108
-
///
109
-
/// This rare error occurs when the system time cannot be retrieved
110
-
/// or is invalid.
111
-
#[error("error-jose-15 System time error")]
112
-
SystemTimeError,
113
-
114
-
/// Error when a JWT token uses an unsupported algorithm.
115
-
///
116
-
/// This error occurs when the JWT token uses an algorithm (alg)
117
-
/// that the application doesn't support or allow.
118
-
#[error("error-jose-16 Unsupported algorithm")]
119
-
UnsupportedAlgorithm,
120
-
121
-
/// Error when a JWT token has invalid key parameters.
122
-
///
123
-
/// This error occurs when the JWT token uses key parameters that
124
-
/// are invalid or not supported.
125
-
#[error("error-jose-17 Invalid key parameters: {0}")]
126
-
InvalidKeyParameters(String),
127
-
}
128
-
129
-
/// Represents errors that can occur during JSON Web Key (JWK) operations.
130
-
///
131
-
/// These errors relate to operations with cryptographic keys in JWK format.
132
-
#[derive(Debug, Error)]
133
-
pub enum JwkError {
134
-
/// Error when a secret JWK key is not found.
135
-
///
136
-
/// This error occurs when the application tries to use a secret JWK key
137
-
/// that is not available in the loaded configuration.
138
-
#[error("error-jwk-1 Secret JWK key not found")]
139
-
SecretKeyNotFound,
140
-
}
+21
src/key_provider.rs
+21
src/key_provider.rs
···
1
+
use async_trait::async_trait;
2
+
use atproto_identity::key::{KeyData, KeyProvider};
3
+
use std::collections::HashMap;
4
+
5
+
#[derive(Clone)]
6
+
pub struct SimpleKeyProvider {
7
+
keys: HashMap<String, KeyData>,
8
+
}
9
+
10
+
impl SimpleKeyProvider {
11
+
pub fn new(keys: HashMap<String, KeyData>) -> Self {
12
+
Self { keys }
13
+
}
14
+
}
15
+
16
+
#[async_trait]
17
+
impl KeyProvider for SimpleKeyProvider {
18
+
async fn get_private_key_by_id(&self, key_id: &str) -> anyhow::Result<Option<KeyData>> {
19
+
Ok(self.keys.get(key_id).cloned())
20
+
}
21
+
}
+4
-11
src/lib.rs
+4
-11
src/lib.rs
···
1
1
pub mod atproto;
2
2
pub mod config;
3
3
pub mod config_errors;
4
-
pub mod did;
5
-
pub mod encoding;
6
-
pub mod encoding_errors;
7
4
pub mod errors;
8
5
pub mod http;
9
6
pub mod i18n;
10
-
pub mod jose;
11
-
pub mod jose_errors;
12
-
pub mod oauth;
13
-
pub mod oauth_client_errors;
14
-
pub mod oauth_errors;
7
+
pub mod key_provider;
15
8
pub mod refresh_tokens_errors;
16
-
pub mod resolve;
17
9
pub mod storage;
18
-
// Removing storage_oauth_errors, consolidated with storage/oauth_model_errors
10
+
11
+
pub mod task_identity_refresh;
12
+
pub mod task_oauth_requests_cleanup;
19
13
pub mod task_refresh_tokens;
20
-
pub mod validation;
-664
src/oauth.rs
-664
src/oauth.rs
···
1
-
use dpop::DpopRetry;
2
-
use p256::SecretKey;
3
-
use rand::distributions::{Alphanumeric, DistString};
4
-
use reqwest_chain::ChainMiddleware;
5
-
use reqwest_middleware::ClientBuilder;
6
-
use std::time::Duration;
7
-
8
-
use crate::oauth_client_errors::OAuthClientError;
9
-
use crate::oauth_errors::{AuthServerValidationError, ResourceValidationError};
10
-
use model::{AuthorizationServer, OAuthProtectedResource, ParResponse, TokenResponse};
11
-
12
-
use crate::{
13
-
jose::{
14
-
jwt::{Claims, Header, JoseClaims},
15
-
mint_token,
16
-
},
17
-
storage::{
18
-
handle::model::Handle,
19
-
oauth::model::{OAuthRequest, OAuthRequestState},
20
-
},
21
-
};
22
-
23
-
const HTTP_CLIENT_TIMEOUT_SECS: u64 = 8;
24
-
25
-
pub async fn pds_resources(
26
-
http_client: &reqwest::Client,
27
-
pds: &str,
28
-
) -> Result<(OAuthProtectedResource, AuthorizationServer), OAuthClientError> {
29
-
let protected_resource = oauth_protected_resource(http_client, pds).await?;
30
-
31
-
let first_authorization_server = protected_resource
32
-
.authorization_servers
33
-
.first()
34
-
.ok_or(OAuthClientError::InvalidOAuthProtectedResource)?;
35
-
36
-
let authorization_server =
37
-
oauth_authorization_server(http_client, first_authorization_server).await?;
38
-
Ok((protected_resource, authorization_server))
39
-
}
40
-
41
-
pub async fn oauth_protected_resource(
42
-
http_client: &reqwest::Client,
43
-
pds: &str,
44
-
) -> Result<OAuthProtectedResource, OAuthClientError> {
45
-
let destination = format!("{}/.well-known/oauth-protected-resource", pds);
46
-
47
-
let resource: OAuthProtectedResource = http_client
48
-
.get(destination)
49
-
.timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
50
-
.send()
51
-
.await
52
-
.map_err(OAuthClientError::OAuthProtectedResourceRequestFailed)?
53
-
.json()
54
-
.await
55
-
.map_err(OAuthClientError::MalformedOAuthProtectedResourceResponse)?;
56
-
57
-
if resource.resource != pds {
58
-
return Err(OAuthClientError::InvalidOAuthProtectedResourceResponse(
59
-
ResourceValidationError::ResourceMustMatchPds.into(),
60
-
));
61
-
}
62
-
63
-
if resource.authorization_servers.is_empty() {
64
-
return Err(OAuthClientError::InvalidOAuthProtectedResourceResponse(
65
-
ResourceValidationError::AuthorizationServersMustNotBeEmpty.into(),
66
-
));
67
-
}
68
-
69
-
Ok(resource)
70
-
}
71
-
72
-
#[tracing::instrument(skip(http_client), err)]
73
-
pub async fn oauth_authorization_server(
74
-
http_client: &reqwest::Client,
75
-
pds: &str,
76
-
) -> Result<AuthorizationServer, OAuthClientError> {
77
-
let destination = format!("{}/.well-known/oauth-authorization-server", pds);
78
-
79
-
let resource: AuthorizationServer = http_client
80
-
.get(destination)
81
-
.timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
82
-
.send()
83
-
.await
84
-
.map_err(OAuthClientError::AuthorizationServerRequestFailed)?
85
-
.json()
86
-
.await
87
-
.map_err(OAuthClientError::MalformedAuthorizationServerResponse)?;
88
-
89
-
// All of this is going to change.
90
-
91
-
if resource.issuer != pds {
92
-
return Err(OAuthClientError::InvalidAuthorizationServerResponse(
93
-
AuthServerValidationError::IssuerMustMatchPds.into(),
94
-
));
95
-
}
96
-
97
-
resource
98
-
.response_types_supported
99
-
.iter()
100
-
.find(|&x| x == "code")
101
-
.ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
102
-
AuthServerValidationError::ResponseTypesSupportMustIncludeCode.into(),
103
-
))?;
104
-
105
-
resource
106
-
.grant_types_supported
107
-
.iter()
108
-
.find(|&x| x == "authorization_code")
109
-
.ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
110
-
AuthServerValidationError::GrantTypesSupportMustIncludeAuthorizationCode.into(),
111
-
))?;
112
-
resource
113
-
.grant_types_supported
114
-
.iter()
115
-
.find(|&x| x == "refresh_token")
116
-
.ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
117
-
AuthServerValidationError::GrantTypesSupportMustIncludeRefreshToken.into(),
118
-
))?;
119
-
resource
120
-
.code_challenge_methods_supported
121
-
.iter()
122
-
.find(|&x| x == "S256")
123
-
.ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
124
-
AuthServerValidationError::CodeChallengeMethodsSupportedMustIncludeS256.into(),
125
-
))?;
126
-
resource
127
-
.token_endpoint_auth_methods_supported
128
-
.iter()
129
-
.find(|&x| x == "none")
130
-
.ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
131
-
AuthServerValidationError::TokenEndpointAuthMethodsSupportedMustIncludeNone.into(),
132
-
))?;
133
-
resource
134
-
.token_endpoint_auth_methods_supported
135
-
.iter()
136
-
.find(|&x| x == "private_key_jwt")
137
-
.ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
138
-
AuthServerValidationError::TokenEndpointAuthMethodsSupportedMustIncludePrivateKeyJwt
139
-
.into(),
140
-
))?;
141
-
resource
142
-
.token_endpoint_auth_signing_alg_values_supported
143
-
.iter()
144
-
.find(|&x| x == "ES256")
145
-
.ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
146
-
AuthServerValidationError::TokenEndpointAuthSigningAlgValuesMustIncludeES256.into(),
147
-
))?;
148
-
resource
149
-
.scopes_supported
150
-
.iter()
151
-
.find(|&x| x == "atproto")
152
-
.ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
153
-
AuthServerValidationError::ScopesSupportedMustIncludeAtProto.into(),
154
-
))?;
155
-
resource
156
-
.scopes_supported
157
-
.iter()
158
-
.find(|&x| x == "transition:generic")
159
-
.ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
160
-
AuthServerValidationError::ScopesSupportedMustIncludeTransitionGeneric.into(),
161
-
))?;
162
-
resource
163
-
.dpop_signing_alg_values_supported
164
-
.iter()
165
-
.find(|&x| x == "ES256")
166
-
.ok_or(OAuthClientError::InvalidAuthorizationServerResponse(
167
-
AuthServerValidationError::DpopSigningAlgValuesSupportedMustIncludeES256.into(),
168
-
))?;
169
-
170
-
if !(resource.authorization_response_iss_parameter_supported
171
-
&& resource.require_pushed_authorization_requests
172
-
&& resource.client_id_metadata_document_supported)
173
-
{
174
-
return Err(OAuthClientError::InvalidAuthorizationServerResponse(
175
-
AuthServerValidationError::RequiredServerFeaturesMustBeSupported.into(),
176
-
));
177
-
}
178
-
179
-
Ok(resource)
180
-
}
181
-
182
-
pub async fn oauth_init(
183
-
http_client: &reqwest::Client,
184
-
external_url_base: &str,
185
-
(secret_key_id, secret_key): (&str, SecretKey),
186
-
dpop_secret_key: &SecretKey,
187
-
handle: &str,
188
-
authorization_server: &AuthorizationServer,
189
-
oauth_request_state: &OAuthRequestState,
190
-
) -> Result<ParResponse, OAuthClientError> {
191
-
let par_url = authorization_server
192
-
.pushed_authorization_request_endpoint
193
-
.clone();
194
-
195
-
let redirect_uri = format!("https://{}/oauth/callback", external_url_base);
196
-
let client_id = format!("https://{}/oauth/client-metadata.json", external_url_base);
197
-
198
-
let scope = "atproto transition:generic".to_string();
199
-
200
-
let client_assertion_header = Header {
201
-
algorithm: Some("ES256".to_string()),
202
-
key_id: Some(secret_key_id.to_string()),
203
-
..Default::default()
204
-
};
205
-
206
-
let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
207
-
let client_assertion_claims = Claims::new(JoseClaims {
208
-
issuer: Some(client_id.clone()),
209
-
subject: Some(client_id.clone()),
210
-
audience: Some(authorization_server.issuer.clone()),
211
-
json_web_token_id: Some(client_assertion_jti),
212
-
issued_at: Some(chrono::Utc::now().timestamp() as u64),
213
-
..Default::default()
214
-
});
215
-
tracing::info!("client_assertion_claims: {:?}", client_assertion_claims);
216
-
217
-
let client_assertion_token = mint_token(
218
-
&secret_key,
219
-
&client_assertion_header,
220
-
&client_assertion_claims,
221
-
)
222
-
.map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?;
223
-
224
-
let now = chrono::Utc::now();
225
-
let public_key = dpop_secret_key.public_key();
226
-
227
-
let dpop_proof_header = Header {
228
-
type_: Some("dpop+jwt".to_string()),
229
-
algorithm: Some("ES256".to_string()),
230
-
json_web_key: Some(public_key.to_jwk()),
231
-
..Default::default()
232
-
};
233
-
let dpop_proof_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
234
-
235
-
let dpop_proof_claim = Claims::new(JoseClaims {
236
-
json_web_token_id: Some(dpop_proof_jti),
237
-
http_method: Some("POST".to_string()),
238
-
http_uri: Some(par_url.clone()),
239
-
issued_at: Some(now.timestamp() as u64),
240
-
expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64),
241
-
..Default::default()
242
-
});
243
-
let dpop_proof_token = mint_token(dpop_secret_key, &dpop_proof_header, &dpop_proof_claim)
244
-
.map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?;
245
-
246
-
let dpop_retry = DpopRetry::new(
247
-
dpop_proof_header.clone(),
248
-
dpop_proof_claim.clone(),
249
-
dpop_secret_key.clone(),
250
-
);
251
-
252
-
let dpop_retry_client = ClientBuilder::new(http_client.clone())
253
-
.with(ChainMiddleware::new(dpop_retry.clone()))
254
-
.build();
255
-
256
-
let params = [
257
-
("response_type", "code"),
258
-
("code_challenge", &oauth_request_state.code_challenge),
259
-
("code_challenge_method", "S256"),
260
-
("client_id", client_id.as_str()),
261
-
("state", oauth_request_state.state.as_str()),
262
-
("redirect_uri", redirect_uri.as_str()),
263
-
("scope", scope.as_str()),
264
-
("login_hint", handle),
265
-
(
266
-
"client_assertion_type",
267
-
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
268
-
),
269
-
("client_assertion", client_assertion_token.as_str()),
270
-
];
271
-
272
-
tracing::warn!("params: {:?}", params);
273
-
274
-
dpop_retry_client
275
-
.post(par_url)
276
-
.header("DPoP", dpop_proof_token.as_str())
277
-
.form(¶ms)
278
-
.timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
279
-
.send()
280
-
.await
281
-
.map_err(OAuthClientError::PARMiddlewareRequestFailed)?
282
-
.json()
283
-
.await
284
-
.map_err(OAuthClientError::MalformedPARResponse)
285
-
}
286
-
287
-
pub async fn oauth_complete(
288
-
http_client: &reqwest::Client,
289
-
external_url_base: &str,
290
-
(secret_key_id, secret_key): (&str, SecretKey),
291
-
callback_code: &str,
292
-
oauth_request: &OAuthRequest,
293
-
handle: &Handle,
294
-
dpop_secret_key: &SecretKey,
295
-
) -> Result<TokenResponse, OAuthClientError> {
296
-
let (_, authorization_server) = pds_resources(http_client, &handle.pds).await?;
297
-
298
-
let client_assertion_header = Header {
299
-
algorithm: Some("ES256".to_string()),
300
-
key_id: Some(secret_key_id.to_string()),
301
-
..Default::default()
302
-
};
303
-
304
-
let client_id = format!("https://{}/oauth/client-metadata.json", external_url_base);
305
-
let redirect_uri = format!("https://{}/oauth/callback", external_url_base);
306
-
307
-
let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
308
-
let client_assertion_claims = Claims::new(JoseClaims {
309
-
issuer: Some(client_id.clone()),
310
-
subject: Some(client_id.clone()),
311
-
audience: Some(authorization_server.issuer.clone()),
312
-
json_web_token_id: Some(client_assertion_jti),
313
-
issued_at: Some(chrono::Utc::now().timestamp() as u64),
314
-
..Default::default()
315
-
});
316
-
317
-
let client_assertion_token = mint_token(
318
-
&secret_key,
319
-
&client_assertion_header,
320
-
&client_assertion_claims,
321
-
)
322
-
.map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?;
323
-
324
-
let params = [
325
-
("client_id", client_id.as_str()),
326
-
("redirect_uri", redirect_uri.as_str()),
327
-
("grant_type", "authorization_code"),
328
-
("code", callback_code),
329
-
("code_verifier", &oauth_request.pkce_verifier),
330
-
(
331
-
"client_assertion_type",
332
-
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
333
-
),
334
-
("client_assertion", client_assertion_token.as_str()),
335
-
];
336
-
337
-
let public_key = dpop_secret_key.public_key();
338
-
339
-
let token_endpoint = authorization_server.token_endpoint.clone();
340
-
341
-
let now = chrono::Utc::now();
342
-
343
-
let dpop_proof_header = Header {
344
-
type_: Some("dpop+jwt".to_string()),
345
-
algorithm: Some("ES256".to_string()),
346
-
json_web_key: Some(public_key.to_jwk()),
347
-
..Default::default()
348
-
};
349
-
let dpop_proof_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
350
-
let dpop_proof_claim = Claims::new(JoseClaims {
351
-
json_web_token_id: Some(dpop_proof_jti),
352
-
http_method: Some("POST".to_string()),
353
-
http_uri: Some(authorization_server.token_endpoint.clone()),
354
-
issued_at: Some(now.timestamp() as u64),
355
-
expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64),
356
-
..Default::default()
357
-
});
358
-
let dpop_proof_token = mint_token(dpop_secret_key, &dpop_proof_header, &dpop_proof_claim)
359
-
.map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?;
360
-
361
-
let dpop_retry = DpopRetry::new(
362
-
dpop_proof_header.clone(),
363
-
dpop_proof_claim.clone(),
364
-
dpop_secret_key.clone(),
365
-
);
366
-
367
-
let dpop_retry_client = ClientBuilder::new(http_client.clone())
368
-
.with(ChainMiddleware::new(dpop_retry.clone()))
369
-
.build();
370
-
371
-
dpop_retry_client
372
-
.post(token_endpoint)
373
-
.header("DPoP", dpop_proof_token.as_str())
374
-
.form(¶ms)
375
-
.timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
376
-
.send()
377
-
.await
378
-
.map_err(OAuthClientError::TokenMiddlewareRequestFailed)?
379
-
.json()
380
-
.await
381
-
.map_err(OAuthClientError::MalformedTokenResponse)
382
-
}
383
-
384
-
pub async fn client_oauth_refresh(
385
-
http_client: &reqwest::Client,
386
-
external_url_base: &str,
387
-
(secret_key_id, secret_key): (&str, SecretKey),
388
-
refresh_token: &str,
389
-
handle: &Handle,
390
-
dpop_secret_key: &SecretKey,
391
-
) -> Result<TokenResponse, OAuthClientError> {
392
-
let (_, authorization_server) = pds_resources(http_client, &handle.pds).await?;
393
-
394
-
let client_assertion_header = Header {
395
-
algorithm: Some("ES256".to_string()),
396
-
key_id: Some(secret_key_id.to_string()),
397
-
..Default::default()
398
-
};
399
-
400
-
let client_id = format!("https://{}/oauth/client-metadata.json", external_url_base);
401
-
let redirect_uri = format!("https://{}/oauth/callback", external_url_base);
402
-
403
-
let client_assertion_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
404
-
let client_assertion_claims = Claims::new(JoseClaims {
405
-
issuer: Some(client_id.clone()),
406
-
subject: Some(client_id.clone()),
407
-
audience: Some(authorization_server.issuer.clone()),
408
-
json_web_token_id: Some(client_assertion_jti),
409
-
issued_at: Some(chrono::Utc::now().timestamp() as u64),
410
-
..Default::default()
411
-
});
412
-
413
-
let client_assertion_token = mint_token(
414
-
&secret_key,
415
-
&client_assertion_header,
416
-
&client_assertion_claims,
417
-
)
418
-
.map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?;
419
-
420
-
let params = [
421
-
("client_id", client_id.as_str()),
422
-
("redirect_uri", redirect_uri.as_str()),
423
-
("grant_type", "refresh_token"),
424
-
("refresh_token", refresh_token),
425
-
(
426
-
"client_assertion_type",
427
-
"urn:ietf:params:oauth:client-assertion-type:jwt-bearer",
428
-
),
429
-
("client_assertion", client_assertion_token.as_str()),
430
-
];
431
-
432
-
tracing::info!("params: {:?}", params);
433
-
434
-
let public_key = dpop_secret_key.public_key();
435
-
436
-
let token_endpoint = authorization_server.token_endpoint.clone();
437
-
438
-
let now = chrono::Utc::now();
439
-
440
-
let dpop_proof_header = Header {
441
-
type_: Some("dpop+jwt".to_string()),
442
-
algorithm: Some("ES256".to_string()),
443
-
json_web_key: Some(public_key.to_jwk()),
444
-
..Default::default()
445
-
};
446
-
let dpop_proof_jti = Alphanumeric.sample_string(&mut rand::thread_rng(), 30);
447
-
let dpop_proof_claim = Claims::new(JoseClaims {
448
-
json_web_token_id: Some(dpop_proof_jti),
449
-
http_method: Some("POST".to_string()),
450
-
http_uri: Some(authorization_server.token_endpoint.clone()),
451
-
issued_at: Some(now.timestamp() as u64),
452
-
expiration: Some((now + chrono::Duration::seconds(30)).timestamp() as u64),
453
-
..Default::default()
454
-
});
455
-
let dpop_proof_token = mint_token(dpop_secret_key, &dpop_proof_header, &dpop_proof_claim)
456
-
.map_err(|jose_err| OAuthClientError::MintTokenFailed(jose_err.into()))?;
457
-
458
-
let dpop_retry = DpopRetry::new(
459
-
dpop_proof_header.clone(),
460
-
dpop_proof_claim.clone(),
461
-
dpop_secret_key.clone(),
462
-
);
463
-
464
-
let dpop_retry_client = ClientBuilder::new(http_client.clone())
465
-
.with(ChainMiddleware::new(dpop_retry.clone()))
466
-
.build();
467
-
468
-
dpop_retry_client
469
-
.post(token_endpoint)
470
-
.header("DPoP", dpop_proof_token.as_str())
471
-
.form(¶ms)
472
-
.timeout(Duration::from_secs(HTTP_CLIENT_TIMEOUT_SECS))
473
-
.send()
474
-
.await
475
-
.map_err(OAuthClientError::TokenMiddlewareRequestFailed)?
476
-
.json()
477
-
.await
478
-
.map_err(OAuthClientError::MalformedTokenResponse)
479
-
}
480
-
481
-
pub mod dpop {
482
-
use p256::SecretKey;
483
-
use reqwest::header::HeaderValue;
484
-
use reqwest_chain::Chainer;
485
-
use serde::Deserialize;
486
-
487
-
use crate::{
488
-
jose::{
489
-
jwt::{Claims, Header},
490
-
mint_token,
491
-
},
492
-
jose_errors::JoseError,
493
-
};
494
-
495
-
#[derive(Clone, Debug, Deserialize)]
496
-
pub struct SimpleError {
497
-
pub error: Option<String>,
498
-
pub error_description: Option<String>,
499
-
pub message: Option<String>,
500
-
}
501
-
502
-
impl std::fmt::Display for SimpleError {
503
-
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
504
-
if let Some(value) = &self.error {
505
-
write!(f, "{}", value)
506
-
} else if let Some(value) = &self.message {
507
-
write!(f, "{}", value)
508
-
} else if let Some(value) = &self.error_description {
509
-
write!(f, "{}", value)
510
-
} else {
511
-
write!(f, "unknown")
512
-
}
513
-
}
514
-
}
515
-
516
-
#[derive(Clone)]
517
-
pub struct DpopRetry {
518
-
pub header: Header,
519
-
pub claims: Claims,
520
-
pub secret: SecretKey,
521
-
}
522
-
523
-
impl DpopRetry {
524
-
pub fn new(header: Header, claims: Claims, secret: SecretKey) -> Self {
525
-
DpopRetry {
526
-
header,
527
-
claims,
528
-
secret,
529
-
}
530
-
}
531
-
}
532
-
533
-
#[async_trait::async_trait]
534
-
impl Chainer for DpopRetry {
535
-
type State = ();
536
-
537
-
async fn chain(
538
-
&self,
539
-
result: Result<reqwest::Response, reqwest_middleware::Error>,
540
-
_state: &mut Self::State,
541
-
request: &mut reqwest::Request,
542
-
) -> Result<Option<reqwest::Response>, reqwest_middleware::Error> {
543
-
let response = result?;
544
-
545
-
let status_code = response.status();
546
-
547
-
if status_code != 400 && status_code != 401 {
548
-
return Ok(Some(response));
549
-
};
550
-
551
-
let headers = response.headers().clone();
552
-
553
-
let simple_error = response.json::<SimpleError>().await;
554
-
if simple_error.is_err() {
555
-
return Err(reqwest_middleware::Error::Middleware(
556
-
JoseError::UnableToParseSimpleError.into(),
557
-
));
558
-
}
559
-
560
-
let simple_error = simple_error.unwrap();
561
-
562
-
tracing::error!("dpop error: {:?}", simple_error);
563
-
564
-
let is_use_dpop_nonce_error = simple_error
565
-
.clone()
566
-
.error
567
-
.is_some_and(|error_value| error_value == "use_dpop_nonce");
568
-
569
-
if !is_use_dpop_nonce_error {
570
-
return Err(reqwest_middleware::Error::Middleware(
571
-
JoseError::UnexpectedError(simple_error.to_string()).into(),
572
-
));
573
-
}
574
-
575
-
let dpop_header = headers.get("DPoP-Nonce");
576
-
577
-
if dpop_header.is_none() {
578
-
return Err(reqwest_middleware::Error::Middleware(
579
-
JoseError::MissingDpopHeader.into(),
580
-
));
581
-
}
582
-
583
-
let new_dpop_header = dpop_header.unwrap().to_str().map_err(|dpop_header_err| {
584
-
reqwest_middleware::Error::Middleware(
585
-
JoseError::UnableToParseDpopHeader(dpop_header_err.to_string()).into(),
586
-
)
587
-
})?;
588
-
589
-
let dpop_proof_header = self.header.clone();
590
-
let mut dpop_proof_claim = self.claims.clone();
591
-
dpop_proof_claim
592
-
.private
593
-
.insert("nonce".to_string(), new_dpop_header.to_string().into());
594
-
595
-
let dpop_proof_token = mint_token(&self.secret, &dpop_proof_header, &dpop_proof_claim)
596
-
.map_err(|dpop_proof_token_err| {
597
-
reqwest_middleware::Error::Middleware(
598
-
JoseError::UnableToMintDpopProofToken(dpop_proof_token_err.to_string())
599
-
.into(),
600
-
)
601
-
})?;
602
-
603
-
request.headers_mut().insert(
604
-
"DPoP",
605
-
HeaderValue::from_str(&dpop_proof_token).expect("invalid header value"),
606
-
);
607
-
Ok(None)
608
-
}
609
-
}
610
-
}
611
-
612
-
pub mod model {
613
-
use serde::Deserialize;
614
-
615
-
#[derive(Clone, Deserialize)]
616
-
pub struct OAuthProtectedResource {
617
-
pub resource: String,
618
-
pub authorization_servers: Vec<String>,
619
-
pub scopes_supported: Vec<String>,
620
-
pub bearer_methods_supported: Vec<String>,
621
-
}
622
-
623
-
#[derive(Clone, Deserialize, Default, Debug)]
624
-
pub struct AuthorizationServer {
625
-
pub introspection_endpoint: String,
626
-
pub authorization_endpoint: String,
627
-
pub authorization_response_iss_parameter_supported: bool,
628
-
pub client_id_metadata_document_supported: bool,
629
-
pub code_challenge_methods_supported: Vec<String>,
630
-
pub dpop_signing_alg_values_supported: Vec<String>,
631
-
pub grant_types_supported: Vec<String>,
632
-
pub issuer: String,
633
-
pub pushed_authorization_request_endpoint: String,
634
-
pub request_parameter_supported: bool,
635
-
pub require_pushed_authorization_requests: bool,
636
-
pub response_types_supported: Vec<String>,
637
-
pub scopes_supported: Vec<String>,
638
-
pub token_endpoint_auth_methods_supported: Vec<String>,
639
-
pub token_endpoint_auth_signing_alg_values_supported: Vec<String>,
640
-
pub token_endpoint: String,
641
-
}
642
-
643
-
#[derive(Clone, Deserialize)]
644
-
pub struct ParResponse {
645
-
pub request_uri: String,
646
-
pub expires_in: u64,
647
-
}
648
-
649
-
#[derive(Clone, Deserialize)]
650
-
pub struct TokenResponse {
651
-
pub access_token: String,
652
-
pub token_type: String,
653
-
pub refresh_token: String,
654
-
pub scope: String,
655
-
pub expires_in: u32,
656
-
pub sub: String,
657
-
}
658
-
}
659
-
660
-
// This errors module is now deprecated.
661
-
// Use crate::oauth_client_errors::OAuthClientError instead.
662
-
pub mod errors {
663
-
pub use crate::oauth_client_errors::OAuthClientError;
664
-
}
-99
src/oauth_client_errors.rs
-99
src/oauth_client_errors.rs
···
1
-
use thiserror::Error;
2
-
3
-
/// Represents errors that can occur during OAuth client operations.
4
-
///
5
-
/// These errors are related to the OAuth client functionality, including
6
-
/// interacting with authorization servers, protected resources, and token management.
7
-
#[derive(Debug, Error)]
8
-
pub enum OAuthClientError {
9
-
/// Error when a request to the authorization server fails.
10
-
///
11
-
/// This error occurs when the OAuth client fails to establish a connection
12
-
/// or complete a request to the authorization server.
13
-
#[error("error-oauth-client-1 Authorization Server Request Failed: {0:?}")]
14
-
AuthorizationServerRequestFailed(reqwest::Error),
15
-
16
-
/// Error when the authorization server response is malformed.
17
-
///
18
-
/// This error occurs when the response from the authorization server
19
-
/// cannot be properly parsed or processed.
20
-
#[error("error-oauth-client-2 Malformed Authorization Server Response: {0:?}")]
21
-
MalformedAuthorizationServerResponse(reqwest::Error),
22
-
23
-
/// Error when the authorization server response is invalid.
24
-
///
25
-
/// This error occurs when the response from the authorization server
26
-
/// is well-formed but contains invalid or unexpected data.
27
-
#[error("error-oauth-client-3 Invalid Authorization Server Response: {0:?}")]
28
-
InvalidAuthorizationServerResponse(anyhow::Error),
29
-
30
-
/// Error when an OAuth protected resource is invalid.
31
-
///
32
-
/// This error occurs when trying to access a protected resource that
33
-
/// is not properly configured for OAuth access.
34
-
#[error("error-oauth-client-4 Invalid OAuth Protected Resource")]
35
-
InvalidOAuthProtectedResource,
36
-
37
-
/// Error when a request to an OAuth protected resource fails.
38
-
///
39
-
/// This error occurs when the OAuth client fails to establish a connection
40
-
/// or complete a request to a protected resource.
41
-
#[error("error-oauth-client-5 OAuth Protected Resource Request Failed: {0:?}")]
42
-
OAuthProtectedResourceRequestFailed(reqwest::Error),
43
-
44
-
/// Error when a protected resource response is malformed.
45
-
///
46
-
/// This error occurs when the response from a protected resource
47
-
/// cannot be properly parsed or processed.
48
-
#[error("error-oauth-client-6 Malformed OAuth Protected Resource Response: {0:?}")]
49
-
MalformedOAuthProtectedResourceResponse(reqwest::Error),
50
-
51
-
/// Error when a protected resource response is invalid.
52
-
///
53
-
/// This error occurs when the response from a protected resource
54
-
/// is well-formed but contains invalid or unexpected data.
55
-
#[error("error-oauth-client-7 Invalid OAuth Protected Resource Response: {0:?}")]
56
-
InvalidOAuthProtectedResourceResponse(anyhow::Error),
57
-
58
-
/// Error when a PAR middleware request fails.
59
-
///
60
-
/// This error occurs when a Pushed Authorization Request (PAR) middleware
61
-
/// request fails to complete successfully.
62
-
#[error("error-oauth-client-8 PAR Middleware Request Failed: {0:?}")]
63
-
PARMiddlewareRequestFailed(reqwest_middleware::Error),
64
-
65
-
/// Error when a PAR request fails.
66
-
///
67
-
/// This error occurs when a Pushed Authorization Request (PAR)
68
-
/// fails to be properly processed by the authorization server.
69
-
#[error("error-oauth-client-9 PAR Request Failed: {0:?}")]
70
-
PARRequestFailed(reqwest::Error),
71
-
72
-
/// Error when a PAR response is malformed.
73
-
///
74
-
/// This error occurs when the response from a Pushed Authorization
75
-
/// Request (PAR) cannot be properly parsed or processed.
76
-
#[error("error-oauth-client-10 Malformed PAR Response: {0:?}")]
77
-
MalformedPARResponse(reqwest::Error),
78
-
79
-
/// Error when token minting fails.
80
-
///
81
-
/// This error occurs when the system fails to mint (create) a new
82
-
/// OAuth token, typically due to cryptographic or validation issues.
83
-
#[error("error-oauth-client-11 Token minting failed: {0:?}")]
84
-
MintTokenFailed(anyhow::Error),
85
-
86
-
/// Error when a token response is malformed.
87
-
///
88
-
/// This error occurs when the response containing a token cannot
89
-
/// be properly parsed or processed.
90
-
#[error("error-oauth-client-12 Malformed Token Response: {0:?}")]
91
-
MalformedTokenResponse(reqwest::Error),
92
-
93
-
/// Error when a token middleware request fails.
94
-
///
95
-
/// This error occurs when a token-related middleware request
96
-
/// fails to complete successfully.
97
-
#[error("error-oauth-client-13 Token Request Failed: {0:?}")]
98
-
TokenMiddlewareRequestFailed(reqwest_middleware::Error),
99
-
}
-116
src/oauth_errors.rs
-116
src/oauth_errors.rs
···
1
-
use thiserror::Error;
2
-
3
-
/// Represents errors that can occur during OAuth resource validation.
4
-
///
5
-
/// These errors occur when validating the configuration of an OAuth resource server
6
-
/// against the requirements of the AT Protocol.
7
-
#[derive(Debug, Error)]
8
-
pub enum ResourceValidationError {
9
-
/// Error when the resource server URI doesn't match the PDS URI.
10
-
///
11
-
/// This error occurs when the resource server URI in the OAuth configuration
12
-
/// does not match the expected Personal Data Server (PDS) URI, which is required
13
-
/// for proper AT Protocol OAuth integration.
14
-
#[error("error-oauth-resource-1 Resource must match PDS")]
15
-
ResourceMustMatchPds,
16
-
17
-
/// Error when the authorization servers list is empty.
18
-
///
19
-
/// This error occurs when the OAuth resource configuration doesn't specify
20
-
/// any authorization servers, which is required for AT Protocol OAuth flows.
21
-
#[error("error-oauth-resource-2 Authorization servers must not be empty")]
22
-
AuthorizationServersMustNotBeEmpty,
23
-
}
24
-
25
-
/// Represents errors that can occur during OAuth authorization server validation.
26
-
///
27
-
/// These errors occur when validating the configuration of an OAuth authorization server
28
-
/// against the requirements specified by the AT Protocol.
29
-
#[derive(Debug, Error)]
30
-
pub enum AuthServerValidationError {
31
-
/// Error when the authorization server issuer doesn't match the PDS.
32
-
///
33
-
/// This error occurs when the issuer URI in the OAuth authorization server metadata
34
-
/// does not match the expected Personal Data Server (PDS) URI.
35
-
#[error("error-oauth-auth-server-1 Issuer must match PDS")]
36
-
IssuerMustMatchPds,
37
-
38
-
/// Error when the 'code' response type is not supported.
39
-
///
40
-
/// This error occurs when the authorization server doesn't support the 'code' response type,
41
-
/// which is required for the authorization code grant flow in AT Protocol.
42
-
#[error("error-oauth-auth-server-2 Response types supported must include 'code'")]
43
-
ResponseTypesSupportMustIncludeCode,
44
-
45
-
/// Error when the 'authorization_code' grant type is not supported.
46
-
///
47
-
/// This error occurs when the authorization server doesn't support the 'authorization_code'
48
-
/// grant type, which is required for the AT Protocol OAuth flow.
49
-
#[error("error-oauth-auth-server-3 Grant types supported must include 'authorization_code'")]
50
-
GrantTypesSupportMustIncludeAuthorizationCode,
51
-
52
-
/// Error when the 'refresh_token' grant type is not supported.
53
-
///
54
-
/// This error occurs when the authorization server doesn't support the 'refresh_token'
55
-
/// grant type, which is required for maintaining long-term access in AT Protocol.
56
-
#[error("error-oauth-auth-server-4 Grant types supported must include 'refresh_token'")]
57
-
GrantTypesSupportMustIncludeRefreshToken,
58
-
59
-
/// Error when the 'S256' code challenge method is not supported.
60
-
///
61
-
/// This error occurs when the authorization server doesn't support the 'S256' code
62
-
/// challenge method for PKCE, which is required for secure authorization code flow.
63
-
#[error("error-oauth-auth-server-5 Code challenge methods supported must include 'S256'")]
64
-
CodeChallengeMethodsSupportedMustIncludeS256,
65
-
66
-
/// Error when the 'none' token endpoint auth method is not supported.
67
-
///
68
-
/// This error occurs when the authorization server doesn't support the 'none'
69
-
/// token endpoint authentication method, which is used for public clients.
70
-
#[error("error-oauth-auth-server-6 Token endpoint auth methods supported must include 'none'")]
71
-
TokenEndpointAuthMethodsSupportedMustIncludeNone,
72
-
73
-
/// Error when the 'private_key_jwt' token endpoint auth method is not supported.
74
-
///
75
-
/// This error occurs when the authorization server doesn't support the 'private_key_jwt'
76
-
/// token endpoint authentication method, which is required for AT Protocol clients.
77
-
#[error("error-oauth-auth-server-7 Token endpoint auth methods supported must include 'private_key_jwt'")]
78
-
TokenEndpointAuthMethodsSupportedMustIncludePrivateKeyJwt,
79
-
80
-
/// Error when the 'ES256' signing algorithm is not supported for token endpoint auth.
81
-
///
82
-
/// This error occurs when the authorization server doesn't support the 'ES256' signing
83
-
/// algorithm for token endpoint authentication, which is required for AT Protocol.
84
-
#[error("error-oauth-auth-server-8 Token endpoint auth signing algorithm values must include 'ES256'")]
85
-
TokenEndpointAuthSigningAlgValuesMustIncludeES256,
86
-
87
-
/// Error when the 'atproto' scope is not supported.
88
-
///
89
-
/// This error occurs when the authorization server doesn't support the 'atproto'
90
-
/// scope, which is required for accessing AT Protocol resources.
91
-
#[error("error-oauth-auth-server-9 Scopes supported must include 'atproto'")]
92
-
ScopesSupportedMustIncludeAtProto,
93
-
94
-
/// Error when the 'transition:generic' scope is not supported.
95
-
///
96
-
/// This error occurs when the authorization server doesn't support the 'transition:generic'
97
-
/// scope, which is required for transitional functionality in AT Protocol.
98
-
#[error("error-oauth-auth-server-10 Scopes supported must include 'transition:generic'")]
99
-
ScopesSupportedMustIncludeTransitionGeneric,
100
-
101
-
/// Error when the 'ES256' DPoP signing algorithm is not supported.
102
-
///
103
-
/// This error occurs when the authorization server doesn't support the 'ES256'
104
-
/// signing algorithm for DPoP proofs, which is required for AT Protocol security.
105
-
#[error(
106
-
"error-oauth-auth-server-11 DPoP signing algorithm values supported must include 'ES256'"
107
-
)]
108
-
DpopSigningAlgValuesSupportedMustIncludeES256,
109
-
110
-
/// Error when required server features are not supported.
111
-
///
112
-
/// This error occurs when the authorization server doesn't support required features
113
-
/// such as pushed authorization requests, client ID metadata, or authorization response parameters.
114
-
#[error("error-oauth-auth-server-12 Authorization response parameters, pushed requests, client ID metadata must be supported")]
115
-
RequiredServerFeaturesMustBeSupported,
116
-
}
+7
src/refresh_tokens_errors.rs
+7
src/refresh_tokens_errors.rs
···
26
26
/// used to manage session refresh operations.
27
27
#[error("error-refresh-3 Failed to place session group into refresh queue: {0:?}")]
28
28
PlaceInRefreshQueueFailed(deadpool_redis::redis::RedisError),
29
+
30
+
/// Error when the identity document cannot be found.
31
+
///
32
+
/// This error occurs when attempting to refresh a token but the necessary
33
+
/// identity document for the user is not available in storage.
34
+
#[error("error-refresh-4 Identity document not found")]
35
+
IdentityDocumentNotFound,
29
36
}
-205
src/resolve.rs
-205
src/resolve.rs
···
1
-
use anyhow::Result;
2
-
use errors::ResolveError;
3
-
use futures_util::future::join3;
4
-
use hickory_resolver::{
5
-
config::{NameServerConfigGroup, ResolverConfig, ResolverOpts},
6
-
TokioAsyncResolver,
7
-
};
8
-
use std::collections::HashSet;
9
-
use std::time::Duration;
10
-
11
-
use crate::config::DnsNameservers;
12
-
use crate::did::web::query_hostname;
13
-
14
-
pub enum InputType {
15
-
Handle(String),
16
-
Plc(String),
17
-
Web(String),
18
-
}
19
-
20
-
pub async fn resolve_handle_dns(
21
-
dns_resolver: &TokioAsyncResolver,
22
-
lookup_dns: &str,
23
-
) -> Result<String, ResolveError> {
24
-
let lookup = dns_resolver
25
-
.txt_lookup(&format!("_atproto.{}", lookup_dns))
26
-
.await
27
-
.map_err(ResolveError::DNSResolutionFailed)?;
28
-
29
-
let dids = lookup
30
-
.iter()
31
-
.filter_map(|record| {
32
-
record
33
-
.to_string()
34
-
.strip_prefix("did=")
35
-
.map(|did| did.to_string())
36
-
})
37
-
.collect::<HashSet<String>>();
38
-
39
-
if dids.len() > 1 {
40
-
return Err(ResolveError::MultipleDIDsFound);
41
-
}
42
-
43
-
dids.iter().next().cloned().ok_or(ResolveError::NoDIDsFound)
44
-
}
45
-
46
-
pub async fn resolve_handle_http(
47
-
http_client: &reqwest::Client,
48
-
handle: &str,
49
-
) -> Result<String, ResolveError> {
50
-
let lookup_url = format!("https://{}/.well-known/atproto-did", handle);
51
-
52
-
http_client
53
-
.get(lookup_url.clone())
54
-
.timeout(Duration::from_secs(10))
55
-
.send()
56
-
.await
57
-
.map_err(ResolveError::HTTPResolutionFailed)?
58
-
.text()
59
-
.await
60
-
.map_err(ResolveError::HTTPResolutionFailed)
61
-
.and_then(|body| {
62
-
if body.starts_with("did:") {
63
-
Ok(body.trim().to_string())
64
-
} else {
65
-
Err(ResolveError::InvalidHTTPResolutionResponse)
66
-
}
67
-
})
68
-
}
69
-
70
-
pub fn parse_input(input: &str) -> Result<InputType, ResolveError> {
71
-
let trimmed = {
72
-
if let Some(value) = input.trim().strip_prefix("at://") {
73
-
value.trim()
74
-
} else if let Some(value) = input.trim().strip_prefix('@') {
75
-
value.trim()
76
-
} else {
77
-
input.trim()
78
-
}
79
-
};
80
-
if trimmed.is_empty() {
81
-
return Err(ResolveError::InvalidInput);
82
-
}
83
-
if trimmed.starts_with("did:web:") {
84
-
Ok(InputType::Web(trimmed.to_string()))
85
-
} else if trimmed.starts_with("did:plc:") {
86
-
Ok(InputType::Plc(trimmed.to_string()))
87
-
} else {
88
-
Ok(InputType::Handle(trimmed.to_string()))
89
-
}
90
-
}
91
-
92
-
pub async fn resolve_handle(
93
-
http_client: &reqwest::Client,
94
-
dns_resolver: &TokioAsyncResolver,
95
-
handle: &str,
96
-
) -> Result<String, ResolveError> {
97
-
let trimmed = {
98
-
if let Some(value) = handle.trim().strip_prefix("at://") {
99
-
value
100
-
} else if let Some(value) = handle.trim().strip_prefix('@') {
101
-
value
102
-
} else {
103
-
handle.trim()
104
-
}
105
-
};
106
-
107
-
let (dns_lookup, http_lookup, did_web_lookup) = join3(
108
-
resolve_handle_dns(dns_resolver, trimmed),
109
-
resolve_handle_http(http_client, trimmed),
110
-
query_hostname(http_client, trimmed),
111
-
)
112
-
.await;
113
-
114
-
tracing::debug!(
115
-
?handle,
116
-
?dns_lookup,
117
-
?http_lookup,
118
-
?did_web_lookup,
119
-
"raw query results"
120
-
);
121
-
122
-
let did_web_lookup_did = did_web_lookup
123
-
.map(|document| document.id)
124
-
.map_err(ResolveError::DIDWebResolutionFailed);
125
-
126
-
let results = vec![dns_lookup, http_lookup, did_web_lookup_did]
127
-
.into_iter()
128
-
.filter_map(|result| result.ok())
129
-
.collect::<Vec<String>>();
130
-
if results.is_empty() {
131
-
return Err(ResolveError::NoDIDsFound);
132
-
}
133
-
134
-
tracing::debug!(?handle, ?results, "query results");
135
-
136
-
let first = results[0].clone();
137
-
if results.iter().all(|result| result == &first) {
138
-
return Ok(first);
139
-
}
140
-
Err(ResolveError::ConflictingDIDsFound)
141
-
}
142
-
143
-
pub async fn resolve_subject(
144
-
http_client: &reqwest::Client,
145
-
dns_resolver: &TokioAsyncResolver,
146
-
subject: &str,
147
-
) -> Result<String, ResolveError> {
148
-
match parse_input(subject)? {
149
-
InputType::Handle(handle) => resolve_handle(http_client, dns_resolver, &handle).await,
150
-
InputType::Plc(did) | InputType::Web(did) => Ok(did),
151
-
}
152
-
}
153
-
154
-
/// Creates a new DNS resolver with configuration based on app config.
155
-
///
156
-
/// If custom nameservers are configured in app config, they will be used.
157
-
/// Otherwise, the system default resolver configuration will be used.
158
-
pub fn create_resolver(nameservers: DnsNameservers) -> TokioAsyncResolver {
159
-
// Initialize the DNS resolver with custom nameservers if configured
160
-
let nameservers = nameservers.as_ref();
161
-
let resolver_config = if !nameservers.is_empty() {
162
-
// Use custom nameservers
163
-
tracing::info!("Using custom DNS nameservers: {:?}", nameservers);
164
-
let nameserver_group = NameServerConfigGroup::from_ips_clear(nameservers, 53, true);
165
-
ResolverConfig::from_parts(None, vec![], nameserver_group)
166
-
} else {
167
-
// Use system default
168
-
tracing::info!("Using system default DNS nameservers");
169
-
ResolverConfig::default()
170
-
};
171
-
172
-
// TokioAsyncResolver::tokio returns an AsyncResolver directly, not a Result
173
-
TokioAsyncResolver::tokio(resolver_config, ResolverOpts::default())
174
-
}
175
-
176
-
pub mod errors {
177
-
use thiserror::Error;
178
-
179
-
#[derive(Debug, Error)]
180
-
pub enum ResolveError {
181
-
#[error("error-resolve-1 Multiple DIDs resolved for method")]
182
-
MultipleDIDsFound,
183
-
184
-
#[error("error-resolve-2 No DIDs resolved for method")]
185
-
NoDIDsFound,
186
-
187
-
#[error("error-resolve-3 No DIDs resolved for method")]
188
-
ConflictingDIDsFound,
189
-
190
-
#[error("error-resolve-4 DNS resolution failed: {0:?}")]
191
-
DNSResolutionFailed(hickory_resolver::error::ResolveError),
192
-
193
-
#[error("error-resolve-5 HTTP resolution failed: {0:?}")]
194
-
HTTPResolutionFailed(reqwest::Error),
195
-
196
-
#[error("error-resolve-6 HTTP resolution failed")]
197
-
InvalidHTTPResolutionResponse,
198
-
199
-
#[error("error-resolve-7 HTTP resolution failed: {0:?}")]
200
-
DIDWebResolutionFailed(anyhow::Error),
201
-
202
-
#[error("error-resolve-8 Invalid input")]
203
-
InvalidInput,
204
-
}
205
-
}
+246
src/storage/atproto.rs
+246
src/storage/atproto.rs
···
1
+
use async_trait::async_trait;
2
+
use atproto_identity::{model::Document, storage::DidDocumentStorage};
3
+
use atproto_oauth::{storage::OAuthRequestStorage, workflow::OAuthRequest};
4
+
use chrono::{DateTime, Utc};
5
+
use serde_json::Value as JsonValue;
6
+
use sqlx::FromRow;
7
+
use std::sync::Arc;
8
+
9
+
use crate::storage::{errors::StorageError, StoragePool};
10
+
11
+
/// Database row representation of OAuthRequest
12
+
#[derive(FromRow)]
13
+
struct OAuthRequestRow {
14
+
pub oauth_state: String,
15
+
pub issuer: String,
16
+
pub did: String,
17
+
pub nonce: String,
18
+
pub pkce_verifier: String,
19
+
pub signing_public_key: String,
20
+
pub dpop_private_key: String,
21
+
pub created_at: DateTime<Utc>,
22
+
pub expires_at: DateTime<Utc>,
23
+
}
24
+
25
+
/// Postgres implementation of DidDocumentStorage trait
26
+
pub struct PostgresDidDocumentStorage {
27
+
pool: StoragePool,
28
+
}
29
+
30
+
impl PostgresDidDocumentStorage {
31
+
pub fn new(pool: StoragePool) -> Self {
32
+
Self { pool }
33
+
}
34
+
35
+
pub fn new_arc(pool: StoragePool) -> Arc<dyn DidDocumentStorage> {
36
+
Arc::new(Self::new(pool))
37
+
}
38
+
}
39
+
40
+
#[async_trait]
41
+
impl DidDocumentStorage for PostgresDidDocumentStorage {
42
+
async fn get_document_by_did(&self, did: &str) -> Result<Option<Document>, anyhow::Error> {
43
+
if did.trim().is_empty() {
44
+
return Err(anyhow::anyhow!("DID cannot be empty"));
45
+
}
46
+
47
+
let mut tx = self.pool.begin().await?;
48
+
49
+
let result = sqlx::query_scalar::<_, JsonValue>(
50
+
"SELECT document_json FROM did_documents WHERE did = $1 AND (expires_at IS NULL OR expires_at > NOW())"
51
+
)
52
+
.bind(did)
53
+
.fetch_optional(tx.as_mut())
54
+
.await?;
55
+
56
+
if let Some(json_value) = result {
57
+
// Update the updated_at timestamp for LRU behavior
58
+
sqlx::query("UPDATE did_documents SET updated_at = NOW() WHERE did = $1")
59
+
.bind(did)
60
+
.execute(tx.as_mut())
61
+
.await?;
62
+
63
+
tx.commit().await?;
64
+
65
+
// Convert JSON back to Document
66
+
let document: Document = serde_json::from_value(json_value)?;
67
+
Ok(Some(document))
68
+
} else {
69
+
tx.commit().await?;
70
+
Ok(None)
71
+
}
72
+
}
73
+
74
+
async fn store_document(&self, document: Document) -> Result<(), anyhow::Error> {
75
+
let did = document.id.clone();
76
+
let document_json = serde_json::to_value(&document)?;
77
+
78
+
let mut tx = self.pool.begin().await?;
79
+
80
+
sqlx::query(
81
+
"INSERT INTO did_documents (did, document_json) VALUES ($1, $2)
82
+
ON CONFLICT (did) DO UPDATE SET
83
+
document_json = EXCLUDED.document_json,
84
+
updated_at = NOW()",
85
+
)
86
+
.bind(&did)
87
+
.bind(&document_json)
88
+
.execute(tx.as_mut())
89
+
.await?;
90
+
91
+
tx.commit().await?;
92
+
Ok(())
93
+
}
94
+
95
+
async fn delete_document_by_did(&self, did: &str) -> Result<(), anyhow::Error> {
96
+
if did.trim().is_empty() {
97
+
return Err(anyhow::anyhow!("DID cannot be empty"));
98
+
}
99
+
100
+
let mut tx = self.pool.begin().await?;
101
+
102
+
sqlx::query("DELETE FROM did_documents WHERE did = $1")
103
+
.bind(did)
104
+
.execute(tx.as_mut())
105
+
.await?;
106
+
107
+
tx.commit().await?;
108
+
Ok(())
109
+
}
110
+
}
111
+
112
+
/// Postgres implementation of OAuthRequestStorage trait
113
+
pub struct PostgresOAuthRequestStorage {
114
+
pool: StoragePool,
115
+
}
116
+
117
+
impl PostgresOAuthRequestStorage {
118
+
pub fn new(pool: StoragePool) -> Self {
119
+
Self { pool }
120
+
}
121
+
122
+
pub fn new_arc(pool: StoragePool) -> Arc<dyn OAuthRequestStorage> {
123
+
Arc::new(Self::new(pool))
124
+
}
125
+
}
126
+
127
+
#[async_trait]
128
+
impl OAuthRequestStorage for PostgresOAuthRequestStorage {
129
+
async fn insert_oauth_request(&self, request: OAuthRequest) -> Result<(), anyhow::Error> {
130
+
let mut tx = self.pool.begin().await?;
131
+
132
+
sqlx::query(
133
+
"INSERT INTO atproto_oauth_requests (
134
+
oauth_state, issuer, did, nonce, pkce_verifier,
135
+
signing_public_key, dpop_private_key, created_at, expires_at
136
+
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)",
137
+
)
138
+
.bind(&request.oauth_state)
139
+
.bind(&request.issuer)
140
+
.bind(&request.did)
141
+
.bind(&request.nonce)
142
+
.bind(&request.pkce_verifier)
143
+
.bind(&request.signing_public_key)
144
+
.bind(&request.dpop_private_key)
145
+
.bind(request.created_at)
146
+
.bind(request.expires_at)
147
+
.execute(tx.as_mut())
148
+
.await?;
149
+
150
+
tx.commit().await?;
151
+
Ok(())
152
+
}
153
+
154
+
async fn get_oauth_request_by_state(
155
+
&self,
156
+
state: &str,
157
+
) -> Result<Option<OAuthRequest>, anyhow::Error> {
158
+
if state.trim().is_empty() {
159
+
return Err(anyhow::anyhow!("OAuth state cannot be empty"));
160
+
}
161
+
162
+
let mut tx = self.pool.begin().await?;
163
+
164
+
let result = sqlx::query_as::<_, OAuthRequestRow>(
165
+
"SELECT oauth_state, issuer, did, nonce, pkce_verifier,
166
+
signing_public_key, dpop_private_key, created_at, expires_at
167
+
FROM atproto_oauth_requests
168
+
WHERE oauth_state = $1 AND expires_at > NOW()",
169
+
)
170
+
.bind(state)
171
+
.fetch_optional(tx.as_mut())
172
+
.await?;
173
+
174
+
tx.commit().await?;
175
+
176
+
if let Some(row) = result {
177
+
let oauth_request = OAuthRequest {
178
+
oauth_state: row.oauth_state,
179
+
issuer: row.issuer,
180
+
did: row.did,
181
+
nonce: row.nonce,
182
+
pkce_verifier: row.pkce_verifier,
183
+
signing_public_key: row.signing_public_key,
184
+
dpop_private_key: row.dpop_private_key,
185
+
created_at: row.created_at,
186
+
expires_at: row.expires_at,
187
+
};
188
+
Ok(Some(oauth_request))
189
+
} else {
190
+
Ok(None)
191
+
}
192
+
}
193
+
194
+
async fn delete_oauth_request_by_state(&self, state: &str) -> Result<(), anyhow::Error> {
195
+
if state.trim().is_empty() {
196
+
return Err(anyhow::anyhow!("OAuth state cannot be empty"));
197
+
}
198
+
199
+
let mut tx = self.pool.begin().await?;
200
+
201
+
sqlx::query("DELETE FROM atproto_oauth_requests WHERE oauth_state = $1")
202
+
.bind(state)
203
+
.execute(tx.as_mut())
204
+
.await?;
205
+
206
+
tx.commit().await?;
207
+
Ok(())
208
+
}
209
+
210
+
async fn clear_expired_oauth_requests(&self) -> Result<u64, anyhow::Error> {
211
+
let mut tx = self.pool.begin().await?;
212
+
213
+
let result = sqlx::query("DELETE FROM atproto_oauth_requests WHERE expires_at <= NOW()")
214
+
.execute(tx.as_mut())
215
+
.await?;
216
+
217
+
tx.commit().await?;
218
+
Ok(result.rows_affected())
219
+
}
220
+
}
221
+
222
+
/// Cleanup expired DID documents and OAuth requests
223
+
pub async fn cleanup_expired_records(pool: &StoragePool) -> Result<(), StorageError> {
224
+
let mut tx = pool
225
+
.begin()
226
+
.await
227
+
.map_err(StorageError::CannotBeginDatabaseTransaction)?;
228
+
229
+
// Clean up expired DID documents
230
+
sqlx::query("DELETE FROM did_documents WHERE expires_at IS NOT NULL AND expires_at <= NOW()")
231
+
.execute(tx.as_mut())
232
+
.await
233
+
.map_err(StorageError::UnableToExecuteQuery)?;
234
+
235
+
// Clean up expired OAuth requests
236
+
sqlx::query("DELETE FROM atproto_oauth_requests WHERE expires_at <= NOW()")
237
+
.execute(tx.as_mut())
238
+
.await
239
+
.map_err(StorageError::UnableToExecuteQuery)?;
240
+
241
+
tx.commit()
242
+
.await
243
+
.map_err(StorageError::CannotCommitDatabaseTransaction)?;
244
+
245
+
Ok(())
246
+
}
+8
-36
src/storage/denylist.rs
+8
-36
src/storage/denylist.rs
···
8
8
9
9
use crate::storage::{errors::StorageError, StoragePool};
10
10
11
-
pub mod model {
11
+
pub(crate) mod model {
12
12
use chrono::{DateTime, Utc};
13
13
use serde::{Deserialize, Serialize};
14
14
use sqlx::FromRow;
···
22
22
}
23
23
24
24
// Add a new entry to the denylist or update an existing one
25
-
pub async fn denylist_add_or_update(
25
+
pub(crate) async fn denylist_add_or_update(
26
26
pool: &StoragePool,
27
27
subject: Cow<'_, str>,
28
28
reason: Cow<'_, str>,
···
68
68
}
69
69
70
70
// Remove an entry from the denylist
71
-
pub async fn denylist_remove(pool: &StoragePool, subject: &str) -> Result<(), StorageError> {
71
+
pub(crate) async fn denylist_remove(pool: &StoragePool, subject: &str) -> Result<(), StorageError> {
72
72
// Validate subject before proceeding
73
73
if subject.trim().is_empty() {
74
74
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
···
98
98
Ok(())
99
99
}
100
100
101
-
// Check if a subject is in the denylist
102
-
pub async fn denylist_check(pool: &StoragePool, subject: &str) -> Result<bool, StorageError> {
103
-
// Validate subject before proceeding
104
-
if subject.trim().is_empty() {
105
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
106
-
"Subject cannot be empty".into(),
107
-
)));
108
-
}
109
-
110
-
let mut tx = pool
111
-
.begin()
112
-
.await
113
-
.map_err(StorageError::CannotBeginDatabaseTransaction)?;
114
-
115
-
let mut h = MetroHash64::default();
116
-
h.write(subject.as_bytes());
117
-
let subject = crockford::encode(h.finish());
118
-
119
-
let count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM denylist WHERE subject = $1")
120
-
.bind(subject)
121
-
.fetch_one(tx.as_mut())
122
-
.await
123
-
.map_err(StorageError::UnableToExecuteQuery)?;
124
-
125
-
tx.commit()
126
-
.await
127
-
.map_err(StorageError::CannotCommitDatabaseTransaction)?;
128
-
129
-
Ok(count > 0)
130
-
}
131
-
132
101
// Get a list of denylist entries with pagination
133
-
pub async fn denylist_list(
102
+
pub(crate) async fn denylist_list(
134
103
pool: &StoragePool,
135
104
page: i64,
136
105
page_size: i64,
···
163
132
Ok((count, entries))
164
133
}
165
134
166
-
pub async fn denylist_exists(pool: &StoragePool, subjects: &[&str]) -> Result<bool, StorageError> {
135
+
pub(crate) async fn denylist_exists(
136
+
pool: &StoragePool,
137
+
subjects: &[&str],
138
+
) -> Result<bool, StorageError> {
167
139
// Validate input - empty array should return false, not error
168
140
if subjects.is_empty() {
169
141
return Ok(false);
+7
src/storage/errors.rs
+7
src/storage/errors.rs
···
22
22
#[error("error-oauth-model-2 Invalid OAuth flow state")]
23
23
InvalidOAuthFlowState(),
24
24
25
+
/// Error when deserializing DPoP JWK from string fails.
26
+
///
27
+
/// This error occurs when attempting to deserialize a string-encoded
28
+
/// JSON Web Key (JWK) for DPoP operations, typically due to invalid JSON format.
29
+
#[error("error-oauth-model-5 Failed to deserialize DPoP JWK: {0:?}")]
30
+
DpopJwkDeserializationFailed(serde_json::Error),
31
+
25
32
/// Error when required OAuth session data is missing.
26
33
///
27
34
/// This error occurs when attempting to use an OAuth session
+38
-9
src/storage/event.rs
+38
-9
src/storage/event.rs
···
21
21
use sqlx::FromRow;
22
22
23
23
#[derive(Clone, FromRow, Deserialize, Serialize, Debug)]
24
-
pub struct Event {
24
+
pub(crate) struct Event {
25
25
pub aturi: String,
26
26
pub cid: String,
27
27
···
36
36
}
37
37
38
38
#[derive(Clone, FromRow, Debug, Serialize)]
39
-
pub struct EventWithRole {
39
+
pub(crate) struct EventWithRole {
40
40
#[sqlx(flatten)]
41
-
pub event: Event,
41
+
pub(crate) event: Event,
42
42
43
-
pub role: String,
43
+
pub(crate) role: String,
44
44
// pub event_handle: String,
45
45
}
46
46
···
250
250
}
251
251
}
252
252
253
-
pub fn extract_event_details(event: &Event) -> EventDetails {
253
+
pub(crate) fn extract_event_details(event: &Event) -> EventDetails {
254
254
use crate::atproto::lexicon::{
255
255
community::lexicon::calendar::event::{Event as CommunityEvent, Mode, Status},
256
256
events::smokesignal::calendar::event::Event as SmokeSignalEvent,
···
469
469
pub uris: Vec<crate::atproto::lexicon::community::lexicon::calendar::event::EventLink>,
470
470
}
471
471
472
-
pub async fn event_get(pool: &StoragePool, aturi: &str) -> Result<Event, StorageError> {
472
+
pub(crate) async fn event_get(pool: &StoragePool, aturi: &str) -> Result<Event, StorageError> {
473
473
// Validate aturi is not empty
474
474
if aturi.trim().is_empty() {
475
475
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
···
553
553
Ok(record)
554
554
}
555
555
556
-
pub async fn event_list_did_recently_updated(
556
+
pub(crate) async fn event_list_did_recently_updated(
557
557
pool: &StoragePool,
558
558
did: &str,
559
559
page: i64,
···
611
611
Ok(event_roles)
612
612
}
613
613
614
-
pub async fn event_list_recently_updated(
614
+
pub(crate) async fn event_list_recently_updated(
615
615
pool: &StoragePool,
616
616
page: i64,
617
617
page_size: i64,
···
954
954
)))
955
955
}
956
956
957
-
pub async fn event_list(
957
+
pub(crate) async fn event_list(
958
958
pool: &StoragePool,
959
959
page: i64,
960
960
page_size: i64,
···
993
993
994
994
Ok((total_count, events))
995
995
}
996
+
997
+
pub async fn event_delete(pool: &StoragePool, aturi: &str) -> Result<(), StorageError> {
998
+
let mut tx = pool
999
+
.begin()
1000
+
.await
1001
+
.map_err(StorageError::CannotBeginDatabaseTransaction)?;
1002
+
1003
+
// Delete only the event record (RSVPs are preserved)
1004
+
let result = sqlx::query("DELETE FROM events WHERE aturi = $1")
1005
+
.bind(aturi)
1006
+
.execute(tx.as_mut())
1007
+
.await
1008
+
.map_err(StorageError::UnableToExecuteQuery)?;
1009
+
1010
+
if result.rows_affected() == 0 {
1011
+
// Rollback the transaction - we don't need to map the error
1012
+
let _ = tx.rollback().await;
1013
+
return Err(StorageError::RowNotFound(
1014
+
"Event not found".to_string(),
1015
+
sqlx::Error::RowNotFound,
1016
+
));
1017
+
}
1018
+
1019
+
tx.commit()
1020
+
.await
1021
+
.map_err(StorageError::CannotCommitDatabaseTransaction)?;
1022
+
1023
+
Ok(())
1024
+
}
+63
-80
src/storage/handle.rs
src/storage/identity_profile.rs
+63
-80
src/storage/handle.rs
src/storage/identity_profile.rs
···
7
7
use crate::storage::denylist::denylist_add_or_update;
8
8
use crate::storage::errors::StorageError;
9
9
use crate::storage::StoragePool;
10
-
use model::Handle;
10
+
use model::IdentityProfile;
11
11
12
12
pub mod model {
13
13
use chrono::{DateTime, Utc};
···
15
15
use sqlx::FromRow;
16
16
17
17
#[derive(Clone, FromRow, Deserialize, Serialize, Debug)]
18
-
pub struct Handle {
18
+
pub struct IdentityProfile {
19
19
pub did: String,
20
20
pub handle: String,
21
21
pub pds: String,
···
60
60
.map_err(StorageError::CannotBeginDatabaseTransaction)?;
61
61
62
62
let now = Utc::now();
63
-
let insert_result = sqlx::query("INSERT INTO handles (did, handle, pds, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING")
63
+
let insert_result = sqlx::query("INSERT INTO identity_profiles (did, handle, pds, created_at, updated_at) VALUES ($1, $2, $3, $4, $5) ON CONFLICT DO NOTHING")
64
64
.bind(did)
65
65
.bind(handle)
66
66
.bind(pds)
···
71
71
.map_err(StorageError::UnableToExecuteQuery)?;
72
72
73
73
if insert_result.rows_affected() == 0 {
74
-
sqlx::query("UPDATE handles SET updated_at = $1, handle = $2, pds = $3 WHERE did = $4")
75
-
.bind(now)
76
-
.bind(handle)
77
-
.bind(pds)
78
-
.bind(did)
79
-
.execute(tx.as_mut())
80
-
.await
81
-
.map_err(StorageError::UnableToExecuteQuery)?;
74
+
sqlx::query(
75
+
"UPDATE identity_profiles SET updated_at = $1, handle = $2, pds = $3 WHERE did = $4",
76
+
)
77
+
.bind(now)
78
+
.bind(handle)
79
+
.bind(pds)
80
+
.bind(did)
81
+
.execute(tx.as_mut())
82
+
.await
83
+
.map_err(StorageError::UnableToExecuteQuery)?;
82
84
}
83
85
84
86
tx.commit()
···
106
108
107
109
let query = match &field {
108
110
HandleField::Language(_) => {
109
-
"UPDATE handles SET language = $1, updated_at = $2 WHERE did = $3"
111
+
"UPDATE identity_profiles SET language = $1, updated_at = $2 WHERE did = $3"
110
112
}
111
-
HandleField::Timezone(_) => "UPDATE handles SET tz = $1, updated_at = $2 WHERE did = $3",
113
+
HandleField::Timezone(_) => {
114
+
"UPDATE identity_profiles SET tz = $1, updated_at = $2 WHERE did = $3"
115
+
}
112
116
HandleField::ActiveNow => {
113
-
"UPDATE handles SET active_at = $1, updated_at = $2 WHERE did = $3"
117
+
"UPDATE identity_profiles SET active_at = $1, updated_at = $2 WHERE did = $3"
114
118
}
115
119
};
116
120
···
140
144
.map_err(StorageError::CannotCommitDatabaseTransaction)
141
145
}
142
146
143
-
pub async fn handle_for_did(pool: &StoragePool, did: &str) -> Result<Handle, StorageError> {
147
+
pub async fn handle_for_did(
148
+
pool: &StoragePool,
149
+
did: &str,
150
+
) -> Result<IdentityProfile, StorageError> {
144
151
// Validate DID is not empty
145
152
if did.trim().is_empty() {
146
153
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
···
153
160
.await
154
161
.map_err(StorageError::CannotBeginDatabaseTransaction)?;
155
162
156
-
let entity = sqlx::query_as::<_, Handle>("SELECT * FROM handles WHERE did = $1")
157
-
.bind(did)
158
-
.fetch_one(tx.as_mut())
159
-
.await
160
-
.map_err(|err| match err {
161
-
sqlx::Error::RowNotFound => StorageError::HandleNotFound,
162
-
other => StorageError::UnableToExecuteQuery(other),
163
-
})?;
163
+
let entity =
164
+
sqlx::query_as::<_, IdentityProfile>("SELECT * FROM identity_profiles WHERE did = $1")
165
+
.bind(did)
166
+
.fetch_one(tx.as_mut())
167
+
.await
168
+
.map_err(|err| match err {
169
+
sqlx::Error::RowNotFound => StorageError::HandleNotFound,
170
+
other => StorageError::UnableToExecuteQuery(other),
171
+
})?;
164
172
165
173
tx.commit()
166
174
.await
···
169
177
Ok(entity)
170
178
}
171
179
172
-
pub async fn handle_for_handle(pool: &StoragePool, handle: &str) -> Result<Handle, StorageError> {
180
+
pub async fn handle_for_handle(
181
+
pool: &StoragePool,
182
+
handle: &str,
183
+
) -> Result<IdentityProfile, StorageError> {
173
184
// Validate handle is not empty
174
185
if handle.trim().is_empty() {
175
186
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
···
182
193
.await
183
194
.map_err(StorageError::CannotBeginDatabaseTransaction)?;
184
195
185
-
let entity = sqlx::query_as::<_, Handle>("SELECT * FROM handles WHERE handle = $1")
186
-
.bind(handle)
187
-
.fetch_one(tx.as_mut())
188
-
.await
189
-
.map_err(|err| match err {
190
-
sqlx::Error::RowNotFound => StorageError::HandleNotFound,
191
-
other => StorageError::UnableToExecuteQuery(other),
192
-
})?;
196
+
let entity =
197
+
sqlx::query_as::<_, IdentityProfile>("SELECT * FROM identity_profiles WHERE handle = $1")
198
+
.bind(handle)
199
+
.fetch_one(tx.as_mut())
200
+
.await
201
+
.map_err(|err| match err {
202
+
sqlx::Error::RowNotFound => StorageError::HandleNotFound,
203
+
other => StorageError::UnableToExecuteQuery(other),
204
+
})?;
193
205
194
206
tx.commit()
195
207
.await
···
202
214
pool: &StoragePool,
203
215
page: i64,
204
216
page_size: i64,
205
-
) -> Result<(i64, Vec<Handle>), StorageError> {
217
+
) -> Result<(i64, Vec<IdentityProfile>), StorageError> {
206
218
let mut tx = pool
207
219
.begin()
208
220
.await
209
221
.map_err(StorageError::CannotBeginDatabaseTransaction)?;
210
222
211
-
let total_count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM handles")
223
+
let total_count = sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM identity_profiles")
212
224
.fetch_one(tx.as_mut())
213
225
.await
214
226
.map_err(StorageError::UnableToExecuteQuery)?;
215
227
216
228
let offset = (page - 1) * page_size;
217
229
218
-
let handles = sqlx::query_as::<_, Handle>(
219
-
"SELECT * FROM handles ORDER BY updated_at DESC LIMIT $1 OFFSET $2",
230
+
let handles = sqlx::query_as::<_, IdentityProfile>(
231
+
"SELECT * FROM identity_profiles ORDER BY updated_at DESC LIMIT $1 OFFSET $2",
220
232
)
221
233
.bind(page_size + 1) // Fetch one more to know if there are more entries
222
234
.bind(offset)
···
256
268
.map_err(StorageError::CannotBeginDatabaseTransaction)?;
257
269
258
270
// Get handle information first
259
-
let handle = sqlx::query_as::<_, Handle>("SELECT * FROM handles WHERE did = $1")
260
-
.bind(did)
261
-
.fetch_one(tx.as_mut())
262
-
.await
263
-
.map_err(|err| match err {
264
-
sqlx::Error::RowNotFound => StorageError::HandleNotFound,
265
-
other => StorageError::UnableToExecuteQuery(other),
266
-
})?;
271
+
let handle =
272
+
sqlx::query_as::<_, IdentityProfile>("SELECT * FROM identity_profiles WHERE did = $1")
273
+
.bind(did)
274
+
.fetch_one(tx.as_mut())
275
+
.await
276
+
.map_err(|err| match err {
277
+
sqlx::Error::RowNotFound => StorageError::HandleNotFound,
278
+
other => StorageError::UnableToExecuteQuery(other),
279
+
})?;
267
280
268
281
// Delete RSVPs created by this identity
269
282
sqlx::query("DELETE FROM rsvps WHERE did = $1")
···
280
293
.map_err(StorageError::UnableToExecuteQuery)?;
281
294
282
295
// Delete the handle entry
283
-
sqlx::query("DELETE FROM handles WHERE did = $1")
296
+
sqlx::query("DELETE FROM identity_profiles WHERE did = $1")
284
297
.bind(did)
285
298
.execute(tx.as_mut())
286
299
.await
···
322
335
pub async fn handles_by_did(
323
336
pool: &StoragePool,
324
337
dids: Vec<String>,
325
-
) -> Result<HashMap<std::string::String, Handle>, StorageError> {
338
+
) -> Result<HashMap<std::string::String, IdentityProfile>, StorageError> {
326
339
if dids.is_empty() {
327
340
return Ok(HashMap::default());
328
341
}
···
343
356
344
357
// Build the query with placeholders
345
358
let mut query_builder: QueryBuilder<Postgres> =
346
-
QueryBuilder::new("SELECT * FROM handles WHERE did IN (");
359
+
QueryBuilder::new("SELECT * FROM identity_profiles WHERE did IN (");
347
360
let mut separated = query_builder.separated(", ");
348
361
for did in &dids {
349
362
separated.push_bind(did);
···
351
364
separated.push_unseparated(") ");
352
365
353
366
// The query_builder.build() already includes the bindings, so we don't need to bind again
354
-
let query = query_builder.build_query_as::<Handle>();
367
+
let query = query_builder.build_query_as::<IdentityProfile>();
355
368
let values = query
356
369
.fetch_all(tx.as_mut())
357
370
.await
···
372
385
pub mod test {
373
386
use sqlx::PgPool;
374
387
375
-
use crate::storage::handle::handle_for_did;
376
-
use crate::storage::handle::handle_for_handle;
377
-
use crate::storage::handle::handle_warm_up;
388
+
use crate::storage::identity_profile::handle_for_did;
389
+
use crate::storage::identity_profile::handle_for_handle;
378
390
379
391
#[sqlx::test(fixtures(path = "../../fixtures/storage", scripts("handles")))]
380
392
async fn test_handle_for_did(pool: PgPool) -> sqlx::Result<()> {
···
394
406
assert!(!handle.is_err());
395
407
let handle = handle.unwrap();
396
408
assert_eq!(handle.did, "did:plc:d5c1ed6d01421a67b96f68fa");
397
-
398
-
Ok(())
399
-
}
400
-
401
-
#[sqlx::test(fixtures(path = "../../fixtures/storage", scripts("handles")))]
402
-
async fn test_handle_warm_up(pool: PgPool) -> sqlx::Result<()> {
403
-
let did = "did:plc:f263c822655b579fc8a79635";
404
-
let handle = "inspiring-bobwhite.examplepds.com";
405
-
let updated_handle = "charming-needlefish.examplepds.com";
406
-
let pds = "https://pds.examplepds.com";
407
-
408
-
let warmup_result = handle_warm_up(&pool, did, handle, pds).await;
409
-
assert!(!warmup_result.is_err());
410
-
411
-
{
412
-
let handle = handle_for_handle(&pool, handle).await;
413
-
assert!(!handle.is_err());
414
-
let handle = handle.unwrap();
415
-
assert_eq!(handle.did, did);
416
-
}
417
-
418
-
{
419
-
let warmup_result = handle_warm_up(&pool, did, updated_handle, pds).await;
420
-
assert!(!warmup_result.is_err());
421
-
}
422
-
{
423
-
let handle = handle_for_handle(&pool, handle).await;
424
-
assert!(handle.is_err());
425
-
}
426
409
427
410
Ok(())
428
411
}
+2
-1
src/storage/mod.rs
+2
-1
src/storage/mod.rs
+14
-358
src/storage/oauth.rs
+14
-358
src/storage/oauth.rs
···
1
1
use std::borrow::Cow;
2
2
3
3
use chrono::{DateTime, Utc};
4
-
use serde_json::json;
5
-
6
-
use crate::{
7
-
jose::jwk::WrappedJsonWebKey,
8
-
storage::{errors::StorageError, handle::model::Handle, StoragePool},
9
-
};
10
-
use model::{OAuthRequest, OAuthSession};
11
-
12
-
pub struct OAuthRequestParams {
13
-
pub oauth_state: Cow<'static, str>,
14
-
pub issuer: Cow<'static, str>,
15
-
pub did: Cow<'static, str>,
16
-
pub nonce: Cow<'static, str>,
17
-
pub pkce_verifier: Cow<'static, str>,
18
-
pub secret_jwk_id: Cow<'static, str>,
19
-
pub dpop_jwk: Option<WrappedJsonWebKey>,
20
-
pub destination: Option<Cow<'static, str>>,
21
-
pub created_at: DateTime<Utc>,
22
-
pub expires_at: DateTime<Utc>,
23
-
}
24
-
25
-
pub async fn oauth_request_insert(
26
-
pool: &StoragePool,
27
-
params: OAuthRequestParams,
28
-
) -> Result<(), StorageError> {
29
-
// Validate required input parameters
30
-
if params.oauth_state.trim().is_empty() {
31
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
32
-
"OAuth state cannot be empty".into(),
33
-
)));
34
-
}
35
-
36
-
if params.issuer.trim().is_empty() {
37
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
38
-
"Issuer cannot be empty".into(),
39
-
)));
40
-
}
41
-
42
-
if params.did.trim().is_empty() {
43
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
44
-
"DID cannot be empty".into(),
45
-
)));
46
-
}
47
-
48
-
if params.nonce.trim().is_empty() {
49
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
50
-
"Nonce cannot be empty".into(),
51
-
)));
52
-
}
53
-
54
-
if params.pkce_verifier.trim().is_empty() {
55
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
56
-
"PKCE verifier cannot be empty".into(),
57
-
)));
58
-
}
59
-
60
-
if params.secret_jwk_id.trim().is_empty() {
61
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
62
-
"Secret JWK ID cannot be empty".into(),
63
-
)));
64
-
}
65
-
66
-
let mut tx = pool
67
-
.begin()
68
-
.await
69
-
.map_err(StorageError::CannotBeginDatabaseTransaction)?;
70
-
71
-
let dpop_jwk_value = params
72
-
.dpop_jwk
73
-
.map(|jwk| json!(jwk))
74
-
.unwrap_or_else(|| json!({}));
75
-
76
-
sqlx::query("INSERT INTO oauth_requests (oauth_state, issuer, did, nonce, pkce_verifier, secret_jwk_id, dpop_jwk, destination, created_at, expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)")
77
-
.bind(¶ms.oauth_state)
78
-
.bind(¶ms.issuer)
79
-
.bind(¶ms.did)
80
-
.bind(¶ms.nonce)
81
-
.bind(¶ms.pkce_verifier)
82
-
.bind(¶ms.secret_jwk_id)
83
-
.bind(dpop_jwk_value)
84
-
.bind(params.destination)
85
-
.bind(params.created_at)
86
-
.bind(params.expires_at)
87
-
.execute(tx.as_mut())
88
-
.await
89
-
.map_err(StorageError::UnableToExecuteQuery)?;
90
-
91
-
tx.commit()
92
-
.await
93
-
.map_err(StorageError::CannotCommitDatabaseTransaction)
94
-
}
95
-
96
-
pub async fn oauth_request_get(
97
-
pool: &StoragePool,
98
-
oauth_state: &str,
99
-
) -> Result<OAuthRequest, StorageError> {
100
-
// Validate oauth_state is not empty
101
-
if oauth_state.trim().is_empty() {
102
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
103
-
"OAuth state cannot be empty".into(),
104
-
)));
105
-
}
106
-
107
-
let mut tx = pool
108
-
.begin()
109
-
.await
110
-
.map_err(StorageError::CannotBeginDatabaseTransaction)?;
111
-
112
-
let record =
113
-
sqlx::query_as::<_, OAuthRequest>("SELECT * FROM oauth_requests WHERE oauth_state = $1")
114
-
.bind(oauth_state)
115
-
.fetch_one(tx.as_mut())
116
-
.await
117
-
.map_err(|err| match err {
118
-
sqlx::Error::RowNotFound => StorageError::OAuthRequestNotFound,
119
-
other => StorageError::UnableToExecuteQuery(other),
120
-
})?;
121
-
122
-
tx.commit()
123
-
.await
124
-
.map_err(StorageError::CannotCommitDatabaseTransaction)?;
125
-
126
-
Ok(record)
127
-
}
128
-
129
-
pub async fn oauth_request_remove(
130
-
pool: &StoragePool,
131
-
oauth_state: &str,
132
-
) -> Result<(), StorageError> {
133
-
// Validate oauth_state is not empty
134
-
if oauth_state.trim().is_empty() {
135
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
136
-
"OAuth state cannot be empty".into(),
137
-
)));
138
-
}
139
-
140
-
let mut tx = pool
141
-
.begin()
142
-
.await
143
-
.map_err(StorageError::CannotBeginDatabaseTransaction)?;
144
-
145
-
sqlx::query("DELETE FROM oauth_requests WHERE oauth_state = $1")
146
-
.bind(oauth_state)
147
-
.execute(tx.as_mut())
148
-
.await
149
-
.map_err(StorageError::UnableToExecuteQuery)?;
150
-
151
-
tx.commit()
152
-
.await
153
-
.map_err(StorageError::CannotCommitDatabaseTransaction)
154
-
}
155
-
156
-
pub struct OAuthSessionParams {
157
-
pub session_group: Cow<'static, str>,
158
-
pub access_token: Cow<'static, str>,
159
-
pub did: Cow<'static, str>,
160
-
pub issuer: Cow<'static, str>,
161
-
pub refresh_token: Cow<'static, str>,
162
-
pub secret_jwk_id: Cow<'static, str>,
163
-
pub dpop_jwk: WrappedJsonWebKey,
164
-
pub created_at: DateTime<Utc>,
165
-
pub access_token_expires_at: DateTime<Utc>,
166
-
}
167
-
168
-
pub async fn oauth_session_insert(
169
-
pool: &StoragePool,
170
-
params: OAuthSessionParams,
171
-
) -> Result<(), StorageError> {
172
-
// Validate required input parameters
173
-
if params.session_group.trim().is_empty() {
174
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
175
-
"Session group cannot be empty".into(),
176
-
)));
177
-
}
178
-
179
-
if params.access_token.trim().is_empty() {
180
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
181
-
"Access token cannot be empty".into(),
182
-
)));
183
-
}
184
-
185
-
if params.did.trim().is_empty() {
186
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
187
-
"DID cannot be empty".into(),
188
-
)));
189
-
}
190
4
191
-
if params.issuer.trim().is_empty() {
192
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
193
-
"Issuer cannot be empty".into(),
194
-
)));
195
-
}
196
-
197
-
if params.refresh_token.trim().is_empty() {
198
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
199
-
"Refresh token cannot be empty".into(),
200
-
)));
201
-
}
202
-
203
-
if params.secret_jwk_id.trim().is_empty() {
204
-
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
205
-
"Secret JWK ID cannot be empty".into(),
206
-
)));
207
-
}
208
-
209
-
let mut tx = pool
210
-
.begin()
211
-
.await
212
-
.map_err(StorageError::CannotBeginDatabaseTransaction)?;
213
-
214
-
sqlx::query("INSERT INTO oauth_sessions (session_group, access_token, did, issuer, refresh_token, secret_jwk_id, dpop_jwk, created_at, access_token_expires_at) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9)")
215
-
.bind(¶ms.session_group)
216
-
.bind(¶ms.access_token)
217
-
.bind(¶ms.did)
218
-
.bind(¶ms.issuer)
219
-
.bind(¶ms.refresh_token)
220
-
.bind(¶ms.secret_jwk_id)
221
-
.bind(json!(params.dpop_jwk))
222
-
.bind(params.created_at)
223
-
.bind(params.access_token_expires_at)
224
-
.execute(tx.as_mut())
225
-
.await
226
-
.map_err(StorageError::UnableToExecuteQuery)?;
227
-
228
-
tx.commit()
229
-
.await
230
-
.map_err(StorageError::CannotCommitDatabaseTransaction)
231
-
}
5
+
use crate::storage::{errors::StorageError, identity_profile::model::IdentityProfile, StoragePool};
6
+
use model::OAuthSession;
232
7
233
8
pub async fn oauth_session_update(
234
9
pool: &StoragePool,
···
308
83
pool: &StoragePool,
309
84
session_group: &str,
310
85
did: Option<&str>,
311
-
) -> Result<(Handle, OAuthSession), StorageError> {
86
+
) -> Result<(IdentityProfile, OAuthSession), StorageError> {
312
87
// Validate session_group is not empty
313
88
if session_group.trim().is_empty() {
314
89
return Err(StorageError::UnableToExecuteQuery(sqlx::Error::Protocol(
···
356
131
357
132
let did_for_handle = did.unwrap_or(&oauth_session.did);
358
133
359
-
let handle = sqlx::query_as::<_, Handle>("SELECT * FROM handles WHERE did = $1")
360
-
.bind(did_for_handle)
361
-
.fetch_one(tx.as_mut())
362
-
.await
363
-
.map_err(|err| match err {
364
-
sqlx::Error::RowNotFound => StorageError::HandleNotFound,
365
-
other => StorageError::UnableToExecuteQuery(other),
366
-
})?;
134
+
let handle =
135
+
sqlx::query_as::<_, IdentityProfile>("SELECT * FROM identity_profiles WHERE did = $1")
136
+
.bind(did_for_handle)
137
+
.fetch_one(tx.as_mut())
138
+
.await
139
+
.map_err(|err| match err {
140
+
sqlx::Error::RowNotFound => StorageError::HandleNotFound,
141
+
other => StorageError::UnableToExecuteQuery(other),
142
+
})?;
367
143
368
144
tx.commit()
369
145
.await
···
373
149
}
374
150
375
151
pub mod model {
376
-
use anyhow::Error;
377
152
use chrono::{DateTime, Utc};
378
-
use p256::SecretKey;
379
153
use serde::Deserialize;
380
154
use sqlx::FromRow;
381
155
382
-
use crate::{
383
-
atproto::auth::SimpleOAuthSessionProvider, jose::jwk::WrappedJsonWebKey,
384
-
storage::errors::OAuthModelError,
385
-
};
386
-
387
156
#[derive(Clone, FromRow, Deserialize)]
388
157
pub struct OAuthRequest {
389
158
pub oauth_state: String,
···
393
162
pub pkce_verifier: String,
394
163
pub secret_jwk_id: String,
395
164
pub destination: Option<String>,
396
-
pub dpop_jwk: sqlx::types::Json<WrappedJsonWebKey>,
165
+
pub dpop_jwk: String,
397
166
pub created_at: DateTime<Utc>,
398
167
pub expires_at: DateTime<Utc>,
399
168
}
400
169
401
-
pub struct OAuthRequestState {
402
-
pub state: String,
403
-
pub nonce: String,
404
-
pub code_challenge: String,
405
-
}
406
-
407
170
#[derive(Clone, FromRow, Deserialize)]
408
171
pub struct OAuthSession {
409
172
pub session_group: String,
···
412
175
pub issuer: String,
413
176
pub refresh_token: String,
414
177
pub secret_jwk_id: String,
415
-
pub dpop_jwk: sqlx::types::Json<WrappedJsonWebKey>,
178
+
pub dpop_jwk: String,
416
179
pub created_at: DateTime<Utc>,
417
180
pub access_token_expires_at: DateTime<Utc>,
418
-
}
419
-
420
-
impl TryFrom<OAuthSession> for SimpleOAuthSessionProvider {
421
-
type Error = Error;
422
-
423
-
fn try_from(value: OAuthSession) -> Result<Self, Self::Error> {
424
-
let dpop_secret = SecretKey::from_jwk(&value.dpop_jwk.jwk)
425
-
.map_err(OAuthModelError::DpopSecretFromJwkFailed)?;
426
-
427
-
Ok(SimpleOAuthSessionProvider {
428
-
access_token: value.access_token,
429
-
issuer: value.issuer,
430
-
dpop_secret,
431
-
})
432
-
}
433
-
}
434
-
}
435
-
436
-
#[cfg(test)]
437
-
pub mod test {
438
-
use sqlx::PgPool;
439
-
440
-
use crate::{
441
-
jose,
442
-
storage::oauth::{
443
-
oauth_request_get, oauth_request_insert, oauth_request_remove, oauth_session_insert,
444
-
web_session_lookup, OAuthRequestParams, OAuthSessionParams,
445
-
},
446
-
};
447
-
448
-
#[sqlx::test(fixtures(path = "../../fixtures/storage", scripts("handles")))]
449
-
async fn test_oauth_request(pool: PgPool) -> anyhow::Result<()> {
450
-
let dpop_jwk = jose::jwk::generate();
451
-
let created_at = chrono::Utc::now();
452
-
let expires_at = created_at + chrono::Duration::seconds(60 as i64);
453
-
454
-
let res = oauth_request_insert(
455
-
&pool,
456
-
OAuthRequestParams {
457
-
oauth_state: "oauth_state".to_string().into(),
458
-
issuer: "pds.examplepds.com".to_string().into(),
459
-
did: "did:plc:d5c1ed6d01421a67b96f68fa".to_string().into(),
460
-
nonce: "nonce".to_string().into(),
461
-
pkce_verifier: "pkce_verifier".to_string().into(),
462
-
secret_jwk_id: "secret_jwk_id".to_string().into(),
463
-
dpop_jwk: Some(dpop_jwk.clone()),
464
-
destination: None,
465
-
created_at,
466
-
expires_at,
467
-
},
468
-
)
469
-
.await;
470
-
471
-
assert!(!res.is_err());
472
-
473
-
let oauth_request = oauth_request_get(&pool, "oauth_state").await;
474
-
assert!(!oauth_request.is_err());
475
-
let oauth_request = oauth_request.unwrap();
476
-
477
-
assert_eq!(oauth_request.did, "did:plc:d5c1ed6d01421a67b96f68fa");
478
-
assert_eq!(oauth_request.dpop_jwk.as_ref(), &dpop_jwk);
479
-
480
-
let res = oauth_request_remove(&pool, "oauth_state").await;
481
-
assert!(!res.is_err());
482
-
483
-
{
484
-
let oauth_request = oauth_request_get(&pool, "oauth_state").await;
485
-
assert!(oauth_request.is_err());
486
-
}
487
-
488
-
Ok(())
489
-
}
490
-
491
-
#[sqlx::test(fixtures(path = "../../fixtures/storage", scripts("handles")))]
492
-
async fn test_oauth_session(pool: PgPool) -> anyhow::Result<()> {
493
-
let dpop_jwk = jose::jwk::generate();
494
-
495
-
let session_group = ulid::Ulid::new().to_string();
496
-
let now = chrono::Utc::now();
497
-
498
-
let insert_session_res = oauth_session_insert(
499
-
&pool,
500
-
OAuthSessionParams {
501
-
session_group: session_group.clone().into(),
502
-
access_token: "access_token".to_string().into(),
503
-
did: "did:plc:d5c1ed6d01421a67b96f68fa".to_string().into(),
504
-
issuer: "pds.examplepds.com".to_string().into(),
505
-
refresh_token: "refresh_token".to_string().into(),
506
-
secret_jwk_id: "secret_jwk_id".to_string().into(),
507
-
dpop_jwk: dpop_jwk.clone(),
508
-
created_at: now,
509
-
access_token_expires_at: now + chrono::Duration::seconds(60 as i64),
510
-
},
511
-
)
512
-
.await;
513
-
514
-
assert!(!insert_session_res.is_err());
515
-
516
-
let web_session = web_session_lookup(
517
-
&pool,
518
-
&session_group,
519
-
Some("did:plc:d5c1ed6d01421a67b96f68fa"),
520
-
)
521
-
.await;
522
-
assert!(!web_session.is_err());
523
-
524
-
Ok(())
525
181
}
526
182
}
+142
src/task_identity_refresh.rs
+142
src/task_identity_refresh.rs
···
1
+
use anyhow::Result;
2
+
use atproto_identity::{resolve::IdentityResolver, storage::DidDocumentStorage};
3
+
use chrono::Duration;
4
+
use sqlx::FromRow;
5
+
use tokio::time::{sleep, Instant};
6
+
use tokio_util::sync::CancellationToken;
7
+
8
+
use crate::storage::StoragePool;
9
+
10
+
pub struct IdentityRefreshTaskConfig {
11
+
pub sleep_interval: Duration,
12
+
pub worker_id: String,
13
+
}
14
+
15
+
pub struct IdentityRefreshTask {
16
+
pub config: IdentityRefreshTaskConfig,
17
+
pub storage_pool: StoragePool,
18
+
pub document_storage: std::sync::Arc<dyn DidDocumentStorage>,
19
+
pub identity_resolver: IdentityResolver,
20
+
pub cancellation_token: CancellationToken,
21
+
}
22
+
23
+
#[derive(FromRow)]
24
+
struct ExpiredDidDocument {
25
+
did: String,
26
+
}
27
+
28
+
impl IdentityRefreshTask {
29
+
#[must_use]
30
+
pub fn new(
31
+
config: IdentityRefreshTaskConfig,
32
+
storage_pool: StoragePool,
33
+
document_storage: std::sync::Arc<dyn DidDocumentStorage>,
34
+
identity_resolver: IdentityResolver,
35
+
cancellation_token: CancellationToken,
36
+
) -> Self {
37
+
Self {
38
+
config,
39
+
storage_pool,
40
+
document_storage,
41
+
identity_resolver,
42
+
cancellation_token,
43
+
}
44
+
}
45
+
46
+
/// Runs the identity refresh task as a long-running process
47
+
///
48
+
/// # Errors
49
+
/// Returns an error if the sleep interval cannot be converted, or if there's a problem
50
+
/// processing the expired DID documents
51
+
pub async fn run(&self) -> Result<()> {
52
+
tracing::debug!("IdentityRefreshTask started");
53
+
54
+
let interval = self.config.sleep_interval.to_std()?;
55
+
56
+
let sleeper = sleep(interval);
57
+
tokio::pin!(sleeper);
58
+
59
+
loop {
60
+
tokio::select! {
61
+
() = self.cancellation_token.cancelled() => {
62
+
break;
63
+
},
64
+
() = &mut sleeper => {
65
+
if let Err(err) = self.process_expired_documents().await {
66
+
tracing::error!("IdentityRefreshTask failed: {}", err);
67
+
}
68
+
sleeper.as_mut().reset(Instant::now() + interval);
69
+
}
70
+
}
71
+
}
72
+
73
+
tracing::info!("IdentityRefreshTask stopped");
74
+
75
+
Ok(())
76
+
}
77
+
78
+
async fn process_expired_documents(&self) -> Result<i32> {
79
+
// Find DID documents that have expired in a separate transaction
80
+
let expired_docs = {
81
+
let mut tx = self.storage_pool.begin().await?;
82
+
let docs = sqlx::query_as::<_, ExpiredDidDocument>(
83
+
"SELECT did FROM did_documents WHERE expires_at IS NOT NULL AND expires_at <= NOW() LIMIT 50"
84
+
)
85
+
.fetch_all(tx.as_mut())
86
+
.await?;
87
+
tx.commit().await?;
88
+
docs
89
+
};
90
+
91
+
let count = expired_docs.len() as i32;
92
+
93
+
if count == 0 {
94
+
return Ok(0);
95
+
}
96
+
97
+
tracing::info!(count = count, "processing expired DID documents");
98
+
99
+
for expired_doc in expired_docs {
100
+
tracing::debug!(did = expired_doc.did, "refreshing expired DID document");
101
+
102
+
match self.refresh_did_document(&expired_doc.did).await {
103
+
Ok(()) => {
104
+
tracing::debug!(did = expired_doc.did, "successfully refreshed DID document");
105
+
}
106
+
Err(err) => {
107
+
tracing::warn!(
108
+
did = expired_doc.did,
109
+
error = ?err,
110
+
"failed to refresh DID document, deleting from storage"
111
+
);
112
+
113
+
// If we can't resolve the DID, delete it from storage
114
+
if let Err(delete_err) = self
115
+
.document_storage
116
+
.delete_document_by_did(&expired_doc.did)
117
+
.await
118
+
{
119
+
tracing::error!(
120
+
did = expired_doc.did,
121
+
error = ?delete_err,
122
+
"failed to delete expired DID document"
123
+
);
124
+
}
125
+
}
126
+
}
127
+
}
128
+
129
+
Ok(count)
130
+
}
131
+
132
+
async fn refresh_did_document(&self, did: &str) -> Result<()> {
133
+
// Use the identity resolver to get the updated DID document
134
+
let document = self.identity_resolver.resolve(did).await?;
135
+
136
+
// Store the updated document using the DidDocumentStorage trait
137
+
// This will reset the expires_at column based on the storage implementation
138
+
self.document_storage.store_document(document).await?;
139
+
140
+
Ok(())
141
+
}
142
+
}
+87
src/task_oauth_requests_cleanup.rs
+87
src/task_oauth_requests_cleanup.rs
···
1
+
use anyhow::Result;
2
+
use chrono::Duration;
3
+
use tokio::time::{sleep, Instant};
4
+
use tokio_util::sync::CancellationToken;
5
+
6
+
use crate::storage::StoragePool;
7
+
8
+
pub struct OAuthRequestsCleanupTaskConfig {
9
+
pub sleep_interval: Duration,
10
+
}
11
+
12
+
pub struct OAuthRequestsCleanupTask {
13
+
pub config: OAuthRequestsCleanupTaskConfig,
14
+
pub storage_pool: StoragePool,
15
+
pub cancellation_token: CancellationToken,
16
+
}
17
+
18
+
impl OAuthRequestsCleanupTask {
19
+
#[must_use]
20
+
pub fn new(
21
+
config: OAuthRequestsCleanupTaskConfig,
22
+
storage_pool: StoragePool,
23
+
cancellation_token: CancellationToken,
24
+
) -> Self {
25
+
Self {
26
+
config,
27
+
storage_pool,
28
+
cancellation_token,
29
+
}
30
+
}
31
+
32
+
/// Runs the OAuth requests cleanup task as a long-running process
33
+
///
34
+
/// # Errors
35
+
/// Returns an error if the sleep interval cannot be converted, or if there's a problem
36
+
/// cleaning up expired requests
37
+
pub async fn run(&self) -> Result<()> {
38
+
tracing::debug!("OAuthRequestsCleanupTask started");
39
+
40
+
let interval = self.config.sleep_interval.to_std()?;
41
+
42
+
let sleeper = sleep(interval);
43
+
tokio::pin!(sleeper);
44
+
45
+
loop {
46
+
tokio::select! {
47
+
() = self.cancellation_token.cancelled() => {
48
+
break;
49
+
},
50
+
() = &mut sleeper => {
51
+
if let Err(err) = self.cleanup_expired_requests().await {
52
+
tracing::error!("OAuthRequestsCleanupTask failed: {}", err);
53
+
}
54
+
sleeper.as_mut().reset(Instant::now() + interval);
55
+
}
56
+
}
57
+
}
58
+
59
+
tracing::info!("OAuthRequestsCleanupTask stopped");
60
+
61
+
Ok(())
62
+
}
63
+
64
+
async fn cleanup_expired_requests(&self) -> Result<()> {
65
+
let now = chrono::Utc::now();
66
+
67
+
tracing::debug!("Starting cleanup of expired OAuth requests");
68
+
69
+
let result = sqlx::query("DELETE FROM atproto_oauth_requests WHERE expires_at < $1")
70
+
.bind(now)
71
+
.execute(&self.storage_pool)
72
+
.await?;
73
+
74
+
let deleted_count = result.rows_affected();
75
+
76
+
if deleted_count > 0 {
77
+
tracing::info!(
78
+
deleted_count = deleted_count,
79
+
"Cleaned up expired OAuth requests"
80
+
);
81
+
} else {
82
+
tracing::debug!("No expired OAuth requests to clean up");
83
+
}
84
+
85
+
Ok(())
86
+
}
87
+
}
+34
-17
src/task_refresh_tokens.rs
+34
-17
src/task_refresh_tokens.rs
···
1
1
use anyhow::Result;
2
+
use atproto_identity::key::identify_key;
3
+
use atproto_oauth::workflow::{oauth_refresh, OAuthClient};
2
4
use chrono::{Duration, Utc};
3
5
use deadpool_redis::redis::{pipe, AsyncCommands};
4
-
use p256::SecretKey;
5
6
use std::borrow::Cow;
6
7
use tokio::time::{sleep, Instant};
7
8
use tokio_util::sync::CancellationToken;
8
9
9
10
use crate::{
10
-
config::{OAuthActiveKeys, SigningKeys},
11
-
oauth::client_oauth_refresh,
11
+
config::SigningKeys,
12
12
refresh_tokens_errors::RefreshError,
13
13
storage::{
14
14
cache::{build_worker_queue, OAUTH_REFRESH_HEARTBEATS, OAUTH_REFRESH_QUEUE},
···
22
22
pub worker_id: String,
23
23
pub external_url_base: String,
24
24
pub signing_keys: SigningKeys,
25
-
pub oauth_active_keys: OAuthActiveKeys,
26
25
}
27
26
28
27
pub struct RefreshTokensTask {
···
30
29
pub http_client: reqwest::Client,
31
30
pub storage_pool: StoragePool,
32
31
pub cache_pool: CachePool,
32
+
pub document_storage: std::sync::Arc<dyn atproto_identity::storage::DidDocumentStorage>,
33
33
pub cancellation_token: CancellationToken,
34
34
}
35
35
···
40
40
http_client: reqwest::Client,
41
41
storage_pool: StoragePool,
42
42
cache_pool: CachePool,
43
+
document_storage: std::sync::Arc<dyn atproto_identity::storage::DidDocumentStorage>,
43
44
cancellation_token: CancellationToken,
44
45
) -> Self {
45
46
Self {
···
47
48
http_client,
48
49
storage_pool,
49
50
cache_pool,
51
+
document_storage,
50
52
cancellation_token,
51
53
}
52
54
}
···
180
182
let (handle, oauth_session) =
181
183
web_session_lookup(&self.storage_pool, session_group, None).await?;
182
184
183
-
let secret_signing_key = self
185
+
let secret_signing_key_string = self
184
186
.config
185
187
.signing_keys
186
188
.as_ref()
187
189
.get(&oauth_session.secret_jwk_id)
188
-
.cloned();
190
+
.cloned()
191
+
.ok_or_else(|| anyhow::Error::from(RefreshError::SecretSigningKeyNotFound))?;
192
+
193
+
let private_signing_key_data = identify_key(&secret_signing_key_string)?;
194
+
195
+
let private_dpop_key_data = identify_key(&oauth_session.dpop_jwk)?;
189
196
190
-
if secret_signing_key.is_none() {
191
-
return Err(RefreshError::SecretSigningKeyNotFound.into());
192
-
}
197
+
let document = match self
198
+
.document_storage
199
+
.get_document_by_did(&handle.did)
200
+
.await?
201
+
{
202
+
Some(doc) => doc,
203
+
None => return Err(RefreshError::IdentityDocumentNotFound.into()),
204
+
};
193
205
194
-
let dpop_secret_key = SecretKey::from_jwk(&oauth_session.dpop_jwk.jwk)
195
-
.map_err(RefreshError::DpopProofCreationFailed)?;
206
+
let oauth_client = OAuthClient {
207
+
redirect_uri: format!("https://{}/oauth/callback", self.config.external_url_base),
208
+
client_id: format!(
209
+
"https://{}/oauth/client-metadata.json",
210
+
self.config.external_url_base
211
+
),
212
+
private_signing_key_data,
213
+
};
196
214
197
-
let token_response = client_oauth_refresh(
215
+
let token_response = oauth_refresh(
198
216
&self.http_client,
199
-
&self.config.external_url_base,
200
-
(&oauth_session.secret_jwk_id, secret_signing_key.unwrap()),
201
-
oauth_session.refresh_token.as_str(),
202
-
&handle,
203
-
&dpop_secret_key,
217
+
&oauth_client,
218
+
&private_dpop_key_data,
219
+
&oauth_session.refresh_token,
220
+
&document,
204
221
)
205
222
.await?;
206
223
-199
src/validation.rs
-199
src/validation.rs
···
1
-
//! Validation module that provides utilities for validating hostnames and AT Protocol handles.
2
-
//!
3
-
//! This module implements RFC-compliant hostname validation and AT Protocol handle formatting rules.
4
-
5
-
/// Maximum length for a valid hostname as defined in RFC 1035
6
-
const MAX_HOSTNAME_LENGTH: usize = 253;
7
-
8
-
/// Maximum length for a DNS label (component between dots) as defined in RFC 1035
9
-
const MAX_LABEL_LENGTH: usize = 63;
10
-
11
-
/// List of reserved top-level domains that are not valid for AT Protocol handles
12
-
const RESERVED_TLDS: [&str; 4] = [".localhost", ".internal", ".arpa", ".local"];
13
-
14
-
/// Validates if a string is a valid hostname according to RFC standards.
15
-
///
16
-
/// A valid hostname must:
17
-
/// - Only contain alphanumeric characters, hyphens, and periods
18
-
/// - Not start or end labels with hyphens
19
-
/// - Have labels (parts between dots) with length between 1-63 characters
20
-
/// - Have total length not exceeding 253 characters
21
-
/// - Not use reserved top-level domains
22
-
///
23
-
/// # Arguments
24
-
/// * `hostname` - The hostname string to validate
25
-
///
26
-
/// # Returns
27
-
/// * `true` if the hostname is valid, `false` otherwise
28
-
#[must_use]
29
-
pub fn is_valid_hostname(hostname: &str) -> bool {
30
-
// Empty hostnames are invalid
31
-
if hostname.is_empty() || hostname.len() > MAX_HOSTNAME_LENGTH {
32
-
return false;
33
-
}
34
-
35
-
// Check if hostname uses any reserved TLDs
36
-
if RESERVED_TLDS.iter().any(|tld| hostname.ends_with(tld)) {
37
-
return false;
38
-
}
39
-
40
-
// Ensure all characters are valid hostname characters
41
-
if hostname.bytes().any(|byte| !is_valid_hostname_char(byte)) {
42
-
return false;
43
-
}
44
-
45
-
// Validate each DNS label in the hostname
46
-
if hostname.split('.').any(|label| !is_valid_dns_label(label)) {
47
-
return false;
48
-
}
49
-
50
-
true
51
-
}
52
-
53
-
/// Checks if a byte is a valid character in a hostname.
54
-
///
55
-
/// Valid characters are: a-z, A-Z, 0-9, hyphen (-), and period (.)
56
-
///
57
-
/// # Arguments
58
-
/// * `byte` - The byte to check
59
-
///
60
-
/// # Returns
61
-
/// * `true` if the byte is a valid hostname character, `false` otherwise
62
-
fn is_valid_hostname_char(byte: u8) -> bool {
63
-
byte.is_ascii_lowercase()
64
-
|| byte.is_ascii_uppercase()
65
-
|| byte.is_ascii_digit()
66
-
|| byte == b'-'
67
-
|| byte == b'.'
68
-
}
69
-
70
-
/// Validates if a DNS label is valid according to RFC standards.
71
-
///
72
-
/// A valid DNS label must:
73
-
/// - Not be empty
74
-
/// - Not exceed 63 characters
75
-
/// - Not start or end with a hyphen
76
-
///
77
-
/// # Arguments
78
-
/// * `label` - The DNS label to validate
79
-
///
80
-
/// # Returns
81
-
/// * `true` if the label is valid, `false` otherwise
82
-
fn is_valid_dns_label(label: &str) -> bool {
83
-
!(label.is_empty()
84
-
|| label.len() > MAX_LABEL_LENGTH
85
-
|| label.starts_with('-')
86
-
|| label.ends_with('-'))
87
-
}
88
-
89
-
/// Validates and normalizes an AT Protocol handle.
90
-
///
91
-
/// A valid AT Protocol handle must:
92
-
/// - Be a valid hostname (after stripping any prefixes)
93
-
/// - Contain at least one period (.)
94
-
/// - Can optionally have "at://" or "@" prefix, which will be removed
95
-
///
96
-
/// # Arguments
97
-
/// * `handle` - The handle string to validate
98
-
///
99
-
/// # Returns
100
-
/// * `Some(String)` containing the normalized handle if valid
101
-
/// * `None` if the handle is invalid
102
-
#[must_use]
103
-
pub fn is_valid_handle(handle: &str) -> Option<String> {
104
-
// Strip optional prefixes to get the core handle
105
-
let trimmed = strip_handle_prefixes(handle);
106
-
107
-
// A valid handle must be a valid hostname with at least one period
108
-
if is_valid_hostname(trimmed) && trimmed.contains('.') {
109
-
Some(trimmed.to_string())
110
-
} else {
111
-
None
112
-
}
113
-
}
114
-
115
-
/// Strips common AT Protocol handle prefixes.
116
-
///
117
-
/// Removes "at://" or "@" prefix if present.
118
-
///
119
-
/// # Arguments
120
-
/// * `handle` - The handle to strip prefixes from
121
-
///
122
-
/// # Returns
123
-
/// * The handle with prefixes removed
124
-
fn strip_handle_prefixes(handle: &str) -> &str {
125
-
if let Some(value) = handle.strip_prefix("at://") {
126
-
value
127
-
} else if let Some(value) = handle.strip_prefix('@') {
128
-
value
129
-
} else {
130
-
handle
131
-
}
132
-
}
133
-
134
-
#[cfg(test)]
135
-
mod tests {
136
-
use super::*;
137
-
138
-
#[test]
139
-
fn test_valid_hostnames() {
140
-
// Valid hostnames
141
-
assert!(is_valid_hostname("example.com"));
142
-
assert!(is_valid_hostname("subdomain.example.com"));
143
-
assert!(is_valid_hostname("with-hyphen.example.com"));
144
-
assert!(is_valid_hostname("123numeric.example.com"));
145
-
assert!(is_valid_hostname("xn--bcher-kva.example.com")); // IDN
146
-
}
147
-
148
-
#[test]
149
-
fn test_invalid_hostnames() {
150
-
// Invalid hostnames
151
-
assert!(!is_valid_hostname("")); // Empty
152
-
assert!(!is_valid_hostname("a".repeat(254).as_str())); // Too long
153
-
assert!(!is_valid_hostname("example.localhost")); // Reserved TLD
154
-
assert!(!is_valid_hostname("example.internal")); // Reserved TLD
155
-
assert!(!is_valid_hostname("example.arpa")); // Reserved TLD
156
-
assert!(!is_valid_hostname("example.local")); // Reserved TLD
157
-
assert!(!is_valid_hostname("invalid_char.example.com")); // Invalid underscore
158
-
assert!(!is_valid_hostname("-starts-with-hyphen.example.com")); // Label starts with hyphen
159
-
assert!(!is_valid_hostname("ends-with-hyphen-.example.com")); // Label ends with hyphen
160
-
assert!(!is_valid_hostname(&("a".repeat(64) + ".example.com"))); // Label too long
161
-
assert!(!is_valid_hostname(".starts.with.dot")); // Empty label
162
-
assert!(!is_valid_hostname("ends.with.dot.")); // Empty label
163
-
assert!(!is_valid_hostname("double..dot")); // Empty label
164
-
}
165
-
166
-
#[test]
167
-
fn test_valid_handles() {
168
-
// Valid handles
169
-
assert_eq!(
170
-
is_valid_handle("user.example.com"),
171
-
Some("user.example.com".to_string())
172
-
);
173
-
assert_eq!(
174
-
is_valid_handle("at://user.example.com"),
175
-
Some("user.example.com".to_string())
176
-
);
177
-
assert_eq!(
178
-
is_valid_handle("@user.example.com"),
179
-
Some("user.example.com".to_string())
180
-
);
181
-
}
182
-
183
-
#[test]
184
-
fn test_invalid_handles() {
185
-
// Invalid handles
186
-
assert_eq!(is_valid_handle("nodots"), None); // No dots
187
-
assert_eq!(is_valid_handle("at://invalid_char.example.com"), None); // Invalid character
188
-
assert_eq!(is_valid_handle("@example.localhost"), None); // Reserved TLD
189
-
}
190
-
191
-
#[test]
192
-
fn test_strip_handle_prefixes() {
193
-
assert_eq!(strip_handle_prefixes("example.com"), "example.com");
194
-
assert_eq!(strip_handle_prefixes("at://example.com"), "example.com");
195
-
assert_eq!(strip_handle_prefixes("@example.com"), "example.com");
196
-
// Nested prefixes should only strip the outermost one
197
-
assert_eq!(strip_handle_prefixes("at://@example.com"), "@example.com");
198
-
}
199
-
}
+13
templates/delete_event.en-us.html
+13
templates/delete_event.en-us.html
···
1
+
{% extends "base.en-us.html" %}
2
+
{% block title %}Smoke Signal - Delete Event{% endblock %}
3
+
{% block head %}{% endblock %}
4
+
{% block content %}
5
+
6
+
{% from "form_include.html" import text_input %}
7
+
<section class="section is-fullheight">
8
+
<div class="container ">
9
+
{% include 'delete_event.en-us.partial.html' %}
10
+
</div>
11
+
</section>
12
+
13
+
{% endblock %}
+23
templates/delete_event.en-us.partial.html
+23
templates/delete_event.en-us.partial.html
···
1
+
<div class="box content" style="background-color: #ffe0e6; border-color: #ff1744;">
2
+
<h2>Danger Zone</h2>
3
+
<p>This will delete the event from your Personal Data Server (PDS) as well as this Smoke Signal instance.</p>
4
+
<ol>
5
+
<li>Deleting records cannot be undone.</li>
6
+
<li>If you are making changes to the event, please consider just changing the event status, mode, and time.</li>
7
+
<li>Existing RSVP records will display "Unknown Event", possibly causing confusion to those who RSVP'd to the event.
8
+
</li>
9
+
</ol>
10
+
<form action="{{ delete_event_url }}" method="post">
11
+
{% if show_confirm %}
12
+
<div class="field">
13
+
<div class="control">
14
+
<label class="checkbox">
15
+
<input type="checkbox" name="confirm" value="true" required>
16
+
<strong>I understand that deleting this event cannot be undone</strong>
17
+
</label>
18
+
</div>
19
+
</div>
20
+
{% endif %}
21
+
<button type="submit" class="button is-danger">Delete Event</button>
22
+
</form>
23
+
</div>