Microservice to bring 2FA to self hosted PDSes

Compare changes

Choose any two refs to compare.

+4 -4
Cargo.lock
··· 656 656 checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" 657 657 dependencies = [ 658 658 "libc", 659 - "windows-sys 0.52.0", 659 + "windows-sys 0.59.0", 660 660 ] 661 661 662 662 [[package]] ··· 1392 1392 checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" 1393 1393 dependencies = [ 1394 1394 "cfg-if", 1395 - "windows-targets 0.48.5", 1395 + "windows-targets 0.52.6", 1396 1396 ] 1397 1397 1398 1398 [[package]] ··· 1690 1690 1691 1691 [[package]] 1692 1692 name = "pds_gatekeeper" 1693 - version = "0.1.0" 1693 + version = "0.1.2" 1694 1694 dependencies = [ 1695 1695 "anyhow", 1696 1696 "aws-lc-rs", ··· 2136 2136 "errno", 2137 2137 "libc", 2138 2138 "linux-raw-sys", 2139 - "windows-sys 0.52.0", 2139 + "windows-sys 0.59.0", 2140 2140 ] 2141 2141 2142 2142 [[package]]
+5 -5
Cargo.toml
··· 1 1 [package] 2 2 name = "pds_gatekeeper" 3 - version = "0.1.0" 3 + version = "0.1.2" 4 4 edition = "2024" 5 + license = "MIT" 5 6 6 7 [dependencies] 7 8 axum = { version = "0.8.4", features = ["macros", "json"] } ··· 14 15 tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } 15 16 hyper-util = { version = "0.1.16", features = ["client", "client-legacy"] } 16 17 tower-http = { version = "0.6", features = ["cors", "compression-zstd"] } 17 - tower_governor = "0.8.0" 18 + tower_governor = { version = "0.8.0", features = ["axum", "tracing"] } 18 19 hex = "0.4" 19 20 jwt-compact = { version = "0.8.0", features = ["es256k"] } 20 21 scrypt = "0.11" 21 - #lettre = { version = "0.11.18", default-features = false, features = ["pool", "tokio1-rustls", "smtp-transport", "hostname", "builder"] } 22 - #lettre = { version = "0.11", default-features = false, features = ["builder", "webpki-roots", "rustls", "aws-lc-rs", "smtp-transport", "tokio1", "tokio1-rustls"] } 22 + #Leaveing these two cause I think it is needed by the email crate for ssl 23 23 aws-lc-rs = "1.13.0" 24 + rustls = { version = "0.23", default-features = false, features = ["tls12", "std", "logging", "aws_lc_rs"] } 24 25 lettre = { version = "0.11", default-features = false, features = ["builder", "webpki-roots", "rustls", "aws-lc-rs", "smtp-transport", "tokio1", "tokio1-rustls"] } 25 - rustls = { version = "0.23", default-features = false, features = ["tls12", "std", "logging", "aws_lc_rs"] } 26 26 handlebars = { version = "6.3.2", features = ["rust-embed"] } 27 27 rust-embed = "8.7.2" 28 28 axum-template = { version = "3.0.0", features = ["handlebars"] }
+21
LICENSE.md
··· 1 + MIT License 2 + 3 + Copyright (c) 2025 Bailey Townsend 4 + 5 + Permission is hereby granted, free of charge, to any person obtaining a copy 6 + of this software and associated documentation files (the "Software"), to deal 7 + in the Software without restriction, including without limitation the rights 8 + to use, copy, modify, merge, publish, distribute, sublicense, and/or sell 9 + copies of the Software, and to permit persons to whom the Software is 10 + furnished to do so, subject to the following conditions: 11 + 12 + The above copyright notice and this permission notice shall be included in all 13 + copies or substantial portions of the Software. 14 + 15 + THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 16 + IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 17 + FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 18 + AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 19 + LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 20 + OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 21 + SOFTWARE.
+61 -2
README.md
··· 37 37 ```yml 38 38 gatekeeper: 39 39 container_name: gatekeeper 40 - image: fatfingers23/pds_gatekeeper:arm-latest 40 + image: fatfingers23/pds_gatekeeper:latest 41 41 network_mode: host 42 42 restart: unless-stopped 43 43 #This gives the container to the access to the PDS folder. Source is the location on your server of that directory ··· 49 49 - pds 50 50 ``` 51 51 52 + For Coolify, if you're using Traefik as your proxy you'll need to make sure the labels for the container are set up correctly. A full example can be found at [./examples/coolify-compose.yml](./examples/coolify-compose.yml). 53 + 54 + ```yml 55 + gatekeeper: 56 + container_name: gatekeeper 57 + image: 'fatfingers23/pds_gatekeeper:latest' 58 + restart: unless-stopped 59 + volumes: 60 + - '/pds:/pds' 61 + environment: 62 + - 'PDS_DATA_DIRECTORY=${PDS_DATA_DIRECTORY:-/pds}' 63 + - 'PDS_BASE_URL=http://pds:3000' 64 + - GATEKEEPER_HOST=0.0.0.0 65 + depends_on: 66 + - pds 67 + healthcheck: 68 + test: 69 + - CMD 70 + - timeout 71 + - '1' 72 + - bash 73 + - '-c' 74 + - 'cat < /dev/null > /dev/tcp/0.0.0.0/8080' 75 + interval: 10s 76 + timeout: 5s 77 + retries: 3 78 + start_period: 10s 79 + labels: 80 + - traefik.enable=true 81 + - 'traefik.http.routers.pds-gatekeeper.rule=Host(`yourpds.com`) && (Path(`/xrpc/com.atproto.server.getSession`) || Path(`/xrpc/com.atproto.server.updateEmail`) || Path(`/xrpc/com.atproto.server.createSession`) || Path(`/xrpc/com.atproto.server.createAccount`) || Path(`/@atproto/oauth-provider/~api/sign-in`))' 82 + - traefik.http.routers.pds-gatekeeper.entrypoints=https 83 + - traefik.http.routers.pds-gatekeeper.tls=true 84 + - traefik.http.routers.pds-gatekeeper.priority=100 85 + - traefik.http.routers.pds-gatekeeper.middlewares=gatekeeper-cors 86 + - traefik.http.services.pds-gatekeeper.loadbalancer.server.port=8080 87 + - traefik.http.services.pds-gatekeeper.loadbalancer.server.scheme=http 88 + - 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowmethods=GET,POST,PUT,DELETE,OPTIONS,PATCH' 89 + - 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowheaders=*' 90 + - 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolalloworiginlist=*' 91 + - traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolmaxage=100 92 + - traefik.http.middlewares.gatekeeper-cors.headers.addvaryheader=true 93 + - traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowcredentials=true 94 + ``` 95 + 52 96 ## Caddy setup 53 97 54 98 For the reverse proxy I use caddy. This part is what overwrites the endpoints and proxies them to PDS gatekeeper to add ··· 60 104 path /xrpc/com.atproto.server.getSession 61 105 path /xrpc/com.atproto.server.updateEmail 62 106 path /xrpc/com.atproto.server.createSession 107 + path /xrpc/com.atproto.server.createAccount 63 108 path /@atproto/oauth-provider/~api/sign-in 64 109 } 65 110 ··· 79 124 path /xrpc/com.atproto.server.getSession 80 125 path /xrpc/com.atproto.server.updateEmail 81 126 path /xrpc/com.atproto.server.createSession 127 + path /xrpc/com.atproto.server.createAccount 82 128 path /@atproto/oauth-provider/~api/sign-in 83 129 } 84 130 85 131 handle @gatekeeper { 86 - reverse_proxy http://localhost:8080 132 + reverse_proxy http://localhost:8080 { 133 + #Makes sure the cloudflare ip is proxied and able to be picked up by pds gatekeeper 134 + header_up X-Forwarded-For {http.request.header.CF-Connecting-IP} 135 + } 87 136 } 88 137 89 138 reverse_proxy http://localhost:3000 ··· 105 154 in the pds gateekeper container and it will use them in place of the default ones. Just make sure ot keep the names the 106 155 same. 107 156 157 + `GATEKEEPER_TWO_FACTOR_EMAIL_SUBJECT` - Subject of the email sent to the user when they turn on 2FA. Defaults to 158 + `Sign in to Bluesky` 159 + 108 160 `PDS_BASE_URL` - Base url of the PDS. You most likely want `https://localhost:3000` which is also the default 109 161 110 162 `GATEKEEPER_HOST` - Host for pds gatekeeper. Defaults to `127.0.0.1` 111 163 112 164 `GATEKEEPER_PORT` - Port for pds gatekeeper. Defaults to `8080` 165 + 166 + `GATEKEEPER_CREATE_ACCOUNT_PER_SECOND` - Sets how often it takes a count off the limiter. example if you hit the rate 167 + limit of 5 and set to 60, then in 60 seconds you will be able to make one more. Or in 5 minutes be able to make 5 more. 168 + 169 + `GATEKEEPER_CREATE_ACCOUNT_BURST` - Sets how many requests can be made in a burst. In the prior example this is where 170 + the 5 comes from. Example can set this to 10 to allow for 10 requests in a burst, and after 60 seconds it will drop one 171 + off.
+1
examples/Caddyfile
··· 14 14 path /xrpc/com.atproto.server.getSession 15 15 path /xrpc/com.atproto.server.updateEmail 16 16 path /xrpc/com.atproto.server.createSession 17 + path /xrpc/com.atproto.server.createAccount 17 18 path /@atproto/oauth-provider/~api/sign-in 18 19 } 19 20
+1 -1
examples/compose.yml
··· 39 39 WATCHTOWER_SCHEDULE: "@midnight" 40 40 gatekeeper: 41 41 container_name: gatekeeper 42 - image: fatfingers23/pds_gatekeeper:arm-latest 42 + image: fatfingers23/pds_gatekeeper:latest 43 43 network_mode: host 44 44 restart: unless-stopped 45 45 #This gives the container to the access to the PDS folder. Source is the location on your server of that directory
+73
examples/coolify-compose.yml
··· 1 + services: 2 + pds: 3 + image: 'ghcr.io/bluesky-social/pds:0.4.182' 4 + volumes: 5 + - '/pds:/pds' 6 + environment: 7 + - SERVICE_URL_PDS_3000 8 + - 'PDS_HOSTNAME=${SERVICE_FQDN_PDS_3000}' 9 + - 'PDS_JWT_SECRET=${SERVICE_HEX_32_JWTSECRET}' 10 + - 'PDS_ADMIN_PASSWORD=${SERVICE_PASSWORD_ADMIN}' 11 + - 'PDS_ADMIN_EMAIL=${PDS_ADMIN_EMAIL}' 12 + - 'PDS_PLC_ROTATION_KEY_K256_PRIVATE_KEY_HEX=${SERVICE_HEX_32_ROTATIONKEY}' 13 + - 'PDS_DATA_DIRECTORY=${PDS_DATA_DIRECTORY:-/pds}' 14 + - 'PDS_BLOBSTORE_DISK_LOCATION=${PDS_DATA_DIRECTORY:-/pds}/blocks' 15 + - 'PDS_BLOB_UPLOAD_LIMIT=${PDS_BLOB_UPLOAD_LIMIT:-104857600}' 16 + - 'PDS_DID_PLC_URL=${PDS_DID_PLC_URL:-https://plc.directory}' 17 + - 'PDS_EMAIL_FROM_ADDRESS=${PDS_EMAIL_FROM_ADDRESS}' 18 + - 'PDS_EMAIL_SMTP_URL=${PDS_EMAIL_SMTP_URL}' 19 + - 'PDS_BSKY_APP_VIEW_URL=${PDS_BSKY_APP_VIEW_URL:-https://api.bsky.app}' 20 + - 'PDS_BSKY_APP_VIEW_DID=${PDS_BSKY_APP_VIEW_DID:-did:web:api.bsky.app}' 21 + - 'PDS_REPORT_SERVICE_URL=${PDS_REPORT_SERVICE_URL:-https://mod.bsky.app/xrpc/com.atproto.moderation.createReport}' 22 + - 'PDS_REPORT_SERVICE_DID=${PDS_REPORT_SERVICE_DID:-did:plc:ar7c4by46qjdydhdevvrndac}' 23 + - 'PDS_CRAWLERS=${PDS_CRAWLERS:-https://bsky.network}' 24 + - 'LOG_ENABLED=${LOG_ENABLED:-true}' 25 + command: "sh -c '\n set -euo pipefail\n echo \"Installing required packages and pdsadmin...\"\n apk add --no-cache openssl curl bash jq coreutils gnupg util-linux-misc >/dev/null\n curl -o /usr/local/bin/pdsadmin.sh https://raw.githubusercontent.com/bluesky-social/pds/main/pdsadmin.sh\n chmod 700 /usr/local/bin/pdsadmin.sh\n ln -sf /usr/local/bin/pdsadmin.sh /usr/local/bin/pdsadmin\n echo \"Creating an empty pds.env file so pdsadmin works...\"\n touch ${PDS_DATA_DIRECTORY}/pds.env\n echo \"Launching PDS, enjoy!...\"\n exec node --enable-source-maps index.js\n'\n" 26 + healthcheck: 27 + test: 28 + - CMD 29 + - wget 30 + - '--spider' 31 + - 'http://127.0.0.1:3000/xrpc/_health' 32 + interval: 5s 33 + timeout: 10s 34 + retries: 10 35 + gatekeeper: 36 + container_name: gatekeeper 37 + image: 'fatfingers23/pds_gatekeeper:latest' 38 + restart: unless-stopped 39 + volumes: 40 + - '/pds:/pds' 41 + environment: 42 + - 'PDS_DATA_DIRECTORY=${PDS_DATA_DIRECTORY:-/pds}' 43 + - 'PDS_BASE_URL=http://pds:3000' 44 + - GATEKEEPER_HOST=0.0.0.0 45 + depends_on: 46 + - pds 47 + healthcheck: 48 + test: 49 + - CMD 50 + - timeout 51 + - '1' 52 + - bash 53 + - '-c' 54 + - 'cat < /dev/null > /dev/tcp/0.0.0.0/8080' 55 + interval: 10s 56 + timeout: 5s 57 + retries: 3 58 + start_period: 10s 59 + labels: 60 + - traefik.enable=true 61 + - 'traefik.http.routers.pds-gatekeeper.rule=Host(`yourpds.com`) && (Path(`/xrpc/com.atproto.server.getSession`) || Path(`/xrpc/com.atproto.server.updateEmail`) || Path(`/xrpc/com.atproto.server.createSession`) || Path(`/xrpc/com.atproto.server.createAccount`) || Path(`/@atproto/oauth-provider/~api/sign-in`))' 62 + - traefik.http.routers.pds-gatekeeper.entrypoints=https 63 + - traefik.http.routers.pds-gatekeeper.tls=true 64 + - traefik.http.routers.pds-gatekeeper.priority=100 65 + - traefik.http.routers.pds-gatekeeper.middlewares=gatekeeper-cors 66 + - traefik.http.services.pds-gatekeeper.loadbalancer.server.port=8080 67 + - traefik.http.services.pds-gatekeeper.loadbalancer.server.scheme=http 68 + - 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowmethods=GET,POST,PUT,DELETE,OPTIONS,PATCH' 69 + - 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowheaders=*' 70 + - 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolalloworiginlist=*' 71 + - traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolmaxage=100 72 + - traefik.http.middlewares.gatekeeper-cors.headers.addvaryheader=true 73 + - traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowcredentials=true
+1 -1
justfile
··· 2 2 docker buildx build \ 3 3 --platform linux/arm64,linux/amd64 \ 4 4 --tag fatfingers23/pds_gatekeeper:latest \ 5 - --tag fatfingers23/pds_gatekeeper:0.1 \ 5 + --tag fatfingers23/pds_gatekeeper:0.1.0.3 \ 6 6 --push .
+6 -5
src/helpers.rs
··· 15 15 use serde_json::{Map, Value}; 16 16 use sha2::{Digest, Sha256}; 17 17 use sqlx::SqlitePool; 18 + use std::env; 18 19 use tracing::{error, log}; 19 20 20 21 ///Used to generate the email 2fa code ··· 134 135 full_code.push(UPPERCASE_BASE32_CHARS[idx] as char); 135 136 } 136 137 137 - //The PDS implementation creates in lowercase, then converts to uppercase. 138 - //Just going a head and doing uppercase here. 139 - let slice_one = &full_code[0..5].to_ascii_uppercase(); 140 - let slice_two = &full_code[5..10].to_ascii_uppercase(); 138 + let slice_one = &full_code[0..5]; 139 + let slice_two = &full_code[5..10]; 141 140 format!("{slice_one}-{slice_two}") 142 141 } 143 142 ··· 334 333 let email_body = state 335 334 .template_engine 336 335 .render("two_factor_code.hbs", email_data)?; 336 + let email_subject = env::var("GATEKEEPER_TWO_FACTOR_EMAIL_SUBJECT") 337 + .unwrap_or("Sign in to Bluesky".to_string()); 337 338 338 339 let email_message = Message::builder() 339 340 //TODO prob get the proper type in the state 340 341 .from(state.mailer_from.parse()?) 341 342 .to(email.parse()?) 342 - .subject("Sign in to Bluesky") 343 + .subject(email_subject) 343 344 .multipart( 344 345 MultiPart::alternative() // This is composed of two parts. 345 346 .singlepart(
+54 -9
src/main.rs
··· 1 1 #![warn(clippy::unwrap_used)] 2 2 use crate::oauth_provider::sign_in; 3 - use crate::xrpc::com_atproto_server::{create_session, get_session, update_email}; 3 + use crate::xrpc::com_atproto_server::{create_account, create_session, get_session, update_email}; 4 4 use axum::body::Body; 5 5 use axum::handler::Handler; 6 6 use axum::http::{Method, header}; ··· 20 20 use std::{env, net::SocketAddr}; 21 21 use tower_governor::GovernorLayer; 22 22 use tower_governor::governor::GovernorConfigBuilder; 23 + use tower_governor::key_extractor::SmartIpKeyExtractor; 23 24 use tower_http::compression::CompressionLayer; 24 25 use tower_http::cors::{Any, CorsLayer}; 25 26 use tracing::log; ··· 91 92 let pds_env_location = 92 93 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); 93 94 94 - dotenvy::from_path(Path::new(&pds_env_location))?; 95 - let pds_root = env::var("PDS_DATA_DIRECTORY")?; 95 + let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location)); 96 + if let Err(e) = result_of_finding_pds_env { 97 + log::error!( 98 + "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}" 99 + ); 100 + } 101 + 102 + let pds_root = 103 + env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file"); 96 104 let account_db_url = format!("{pds_root}/account.sqlite"); 97 105 98 106 let account_options = SqliteConnectOptions::new() ··· 165 173 let create_session_governor_conf = GovernorConfigBuilder::default() 166 174 .per_second(60) 167 175 .burst_size(5) 176 + .key_extractor(SmartIpKeyExtractor) 168 177 .finish() 169 - .expect("failed to create governor config. this should not happen and is a bug"); 178 + .expect("failed to create governor config for create session. this should not happen and is a bug"); 170 179 171 180 // Create a second config with the same settings for the other endpoint 172 181 let sign_in_governor_conf = GovernorConfigBuilder::default() 173 182 .per_second(60) 174 183 .burst_size(5) 184 + .key_extractor(SmartIpKeyExtractor) 175 185 .finish() 176 - .expect("failed to create governor config. this should not happen and is a bug"); 186 + .expect( 187 + "failed to create governor config for sign in. this should not happen and is a bug", 188 + ); 189 + 190 + let create_account_limiter_time: Option<String> = 191 + env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); 192 + let create_account_limiter_burst: Option<String> = 193 + env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); 194 + 195 + //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally 196 + let mut create_account_governor_conf = GovernorConfigBuilder::default(); 197 + if create_account_limiter_time.is_some() { 198 + let time = create_account_limiter_time 199 + .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") 200 + .parse::<u64>() 201 + .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); 202 + create_account_governor_conf.per_second(time); 203 + } 204 + 205 + if create_account_limiter_burst.is_some() { 206 + let burst = create_account_limiter_burst 207 + .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") 208 + .parse::<u32>() 209 + .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); 210 + create_account_governor_conf.burst_size(burst); 211 + } 212 + 213 + let create_account_governor_conf = create_account_governor_conf 214 + .key_extractor(SmartIpKeyExtractor) 215 + .finish().expect( 216 + "failed to create governor config for create account. this should not happen and is a bug", 217 + ); 177 218 178 219 let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); 179 220 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 221 + let create_account_governor_limiter = create_account_governor_conf.limiter().clone(); 222 + 180 223 let interval = Duration::from_secs(60); 181 224 // a separate background task to clean up 182 225 std::thread::spawn(move || { ··· 184 227 std::thread::sleep(interval); 185 228 create_session_governor_limiter.retain_recent(); 186 229 sign_in_governor_limiter.retain_recent(); 230 + create_account_governor_limiter.retain_recent(); 187 231 } 188 232 }); 189 233 ··· 194 238 195 239 let app = Router::new() 196 240 .route("/", get(root_handler)) 197 - .route( 198 - "/xrpc/com.atproto.server.getSession", 199 - get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)), 200 - ) 241 + .route("/xrpc/com.atproto.server.getSession", get(get_session)) 201 242 .route( 202 243 "/xrpc/com.atproto.server.updateEmail", 203 244 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), ··· 209 250 .route( 210 251 "/xrpc/com.atproto.server.createSession", 211 252 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), 253 + ) 254 + .route( 255 + "/xrpc/com.atproto.server.createAccount", 256 + post(create_account).layer(GovernorLayer::new(create_account_governor_conf)), 212 257 ) 213 258 .layer(CompressionLayer::new()) 214 259 .layer(cors)
+73 -39
src/middleware.rs
··· 12 12 #[derive(Clone, Debug)] 13 13 pub struct Did(pub Option<String>); 14 14 15 + #[derive(Clone, Copy, Debug, PartialEq, Eq)] 16 + pub enum AuthScheme { 17 + Bearer, 18 + DPoP, 19 + } 20 + 15 21 #[derive(Serialize, Deserialize)] 16 22 pub struct TokenClaims { 17 23 pub sub: String, 18 24 } 19 25 20 26 pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse { 21 - let token = extract_bearer(req.headers()); 27 + let auth = extract_auth(req.headers()); 22 28 23 - match token { 24 - Ok(token) => { 25 - match token { 29 + match auth { 30 + Ok(auth_opt) => { 31 + match auth_opt { 26 32 None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 27 33 .expect("Error creating an error response"), 28 - Some(token) => { 29 - let token = UntrustedToken::new(&token); 30 - if token.is_err() { 31 - return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 32 - .expect("Error creating an error response"); 33 - } 34 - let parsed_token = token.expect("Already checked for error"); 35 - let claims: Result<Claims<TokenClaims>, ValidationError> = 36 - parsed_token.deserialize_claims_unchecked(); 37 - if claims.is_err() { 38 - return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 39 - .expect("Error creating an error response"); 40 - } 34 + Some((scheme, token_str)) => { 35 + // For Bearer, validate JWT and extract DID from `sub`. 36 + // For DPoP, we currently only pass through and do not validate here; insert None DID. 37 + match scheme { 38 + AuthScheme::Bearer => { 39 + let token = UntrustedToken::new(&token_str); 40 + if token.is_err() { 41 + return json_error_response( 42 + StatusCode::BAD_REQUEST, 43 + "TokenRequired", 44 + "", 45 + ) 46 + .expect("Error creating an error response"); 47 + } 48 + let parsed_token = token.expect("Already checked for error"); 49 + let claims: Result<Claims<TokenClaims>, ValidationError> = 50 + parsed_token.deserialize_claims_unchecked(); 51 + if claims.is_err() { 52 + return json_error_response( 53 + StatusCode::BAD_REQUEST, 54 + "TokenRequired", 55 + "", 56 + ) 57 + .expect("Error creating an error response"); 58 + } 41 59 42 - let key = Hs256Key::new( 43 - env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"), 44 - ); 45 - let token: Result<Token<TokenClaims>, ValidationError> = 46 - Hs256.validator(&key).validate(&parsed_token); 47 - if token.is_err() { 48 - return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") 49 - .expect("Error creating an error response"); 60 + let key = Hs256Key::new( 61 + env::var("PDS_JWT_SECRET") 62 + .expect("PDS_JWT_SECRET not set in the pds.env"), 63 + ); 64 + let token: Result<Token<TokenClaims>, ValidationError> = 65 + Hs256.validator(&key).validate(&parsed_token); 66 + if token.is_err() { 67 + return json_error_response( 68 + StatusCode::BAD_REQUEST, 69 + "InvalidToken", 70 + "", 71 + ) 72 + .expect("Error creating an error response"); 73 + } 74 + let token = token.expect("Already checked for error,"); 75 + req.extensions_mut() 76 + .insert(Did(Some(token.claims().custom.sub.clone()))); 77 + } 78 + AuthScheme::DPoP => { 79 + //Not going to worry about oauth email update for now, just always forward to the PDS 80 + req.extensions_mut().insert(Did(None)); 81 + } 50 82 } 51 - let token = token.expect("Already checked for error,"); 52 - //Not going to worry about expiration since it still goes to the PDS 53 - req.extensions_mut() 54 - .insert(Did(Some(token.claims().custom.sub.clone()))); 83 + 55 84 next.run(req).await 56 85 } 57 86 } ··· 64 93 } 65 94 } 66 95 67 - fn extract_bearer(headers: &HeaderMap) -> Result<Option<String>, String> { 96 + fn extract_auth(headers: &HeaderMap) -> Result<Option<(AuthScheme, String)>, String> { 68 97 match headers.get(axum::http::header::AUTHORIZATION) { 69 98 None => Ok(None), 70 - Some(hv) => match hv.to_str() { 71 - Err(_) => Err("Authorization header is not valid".into()), 72 - Ok(s) => { 73 - // Accept forms like: "Bearer <token>" (case-sensitive for the scheme here) 74 - let mut parts = s.splitn(2, ' '); 75 - match (parts.next(), parts.next()) { 76 - (Some("Bearer"), Some(tok)) if !tok.is_empty() => Ok(Some(tok.to_string())), 77 - _ => Err("Authorization header must be in format 'Bearer <token>'".into()), 99 + Some(hv) => { 100 + match hv.to_str() { 101 + Err(_) => Err("Authorization header is not valid".into()), 102 + Ok(s) => { 103 + // Accept forms like: "Bearer <token>" or "DPoP <token>" (case-sensitive for the scheme here) 104 + let mut parts = s.splitn(2, ' '); 105 + match (parts.next(), parts.next()) { 106 + (Some("Bearer"), Some(tok)) if !tok.is_empty() => 107 + Ok(Some((AuthScheme::Bearer, tok.to_string()))), 108 + (Some("DPoP"), Some(tok)) if !tok.is_empty() => 109 + Ok(Some((AuthScheme::DPoP, tok.to_string()))), 110 + _ => Err("Authorization header must be in format 'Bearer <token>' or 'DPoP <token>'".into()), 111 + } 78 112 } 79 113 } 80 - }, 114 + } 81 115 } 82 116 }
+2 -1
src/oauth_provider.rs
··· 13 13 pub struct SignInRequest { 14 14 pub username: String, 15 15 pub password: String, 16 - pub remember: bool, 16 + #[serde(skip_serializing_if = "Option::is_none")] 17 + pub remember: Option<bool>, 17 18 pub locale: String, 18 19 #[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")] 19 20 pub email_otp: Option<String>,
+94 -48
src/xrpc/com_atproto_server.rs
··· 87 87 ) 88 88 } 89 89 AuthResult::ProxyThrough => { 90 - log::info!("Proxying through"); 91 90 //No 2FA or already passed 92 91 let uri = format!( 93 92 "{}{}", ··· 148 147 //If email auth is set it is to either turn on or off 2fa 149 148 let email_auth_update = payload.email_auth_factor.unwrap_or(false); 150 149 151 - // Email update asked for 152 - if email_auth_update { 153 - let email = payload.email.clone(); 154 - let email_confirmed = sqlx::query_as::<_, (String,)>( 155 - "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?", 156 - ) 157 - .bind(&email) 158 - .fetch_optional(&state.account_pool) 159 - .await 160 - .map_err(|_| StatusCode::BAD_REQUEST)?; 150 + //This means the middleware successfully extracted a did from the request, if not it just needs to be forward to the PDS 151 + //This is also empty if it is an oauth request, which is not supported by gatekeeper turning on 2fa since the dpop stuff needs to be implemented 152 + let did_is_not_empty = did.0.is_some(); 161 153 162 - //Since the email is already confirmed we can enable 2fa 163 - return match email_confirmed { 164 - None => Err(StatusCode::BAD_REQUEST), 165 - Some(did_row) => { 166 - let _ = sqlx::query( 167 - "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1", 168 - ) 169 - .bind(&did_row.0) 170 - .execute(&state.pds_gatekeeper_pool) 171 - .await 172 - .map_err(|_| StatusCode::BAD_REQUEST)?; 173 - 174 - Ok(StatusCode::OK.into_response()) 175 - } 176 - }; 177 - } 178 - 179 - // User wants auth turned off 180 - if !email_auth_update && !email_auth_not_set { 181 - //User wants auth turned off and has a token 182 - if let Some(token) = &payload.token { 183 - let token_found = sqlx::query_as::<_, (String,)>( 184 - "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'", 154 + if did_is_not_empty { 155 + // Email update asked for 156 + if email_auth_update { 157 + let email = payload.email.clone(); 158 + let email_confirmed = match sqlx::query_as::<_, (String,)>( 159 + "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?", 185 160 ) 186 - .bind(token) 187 - .bind(&did.0) 161 + .bind(&email) 188 162 .fetch_optional(&state.account_pool) 189 163 .await 190 - .map_err(|_| StatusCode::BAD_REQUEST)?; 164 + { 165 + Ok(row) => row, 166 + Err(err) => { 167 + log::error!("Error checking if email is confirmed: {err}"); 168 + return Err(StatusCode::BAD_REQUEST); 169 + } 170 + }; 171 + 172 + //Since the email is already confirmed we can enable 2fa 173 + return match email_confirmed { 174 + None => Err(StatusCode::BAD_REQUEST), 175 + Some(did_row) => { 176 + let _ = sqlx::query( 177 + "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1", 178 + ) 179 + .bind(&did_row.0) 180 + .execute(&state.pds_gatekeeper_pool) 181 + .await 182 + .map_err(|_| StatusCode::BAD_REQUEST)?; 191 183 192 - if token_found.is_some() { 193 - let _ = sqlx::query( 194 - "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0", 184 + Ok(StatusCode::OK.into_response()) 185 + } 186 + }; 187 + } 188 + 189 + // User wants auth turned off 190 + if !email_auth_update && !email_auth_not_set { 191 + //User wants auth turned off and has a token 192 + if let Some(token) = &payload.token { 193 + let token_found = match sqlx::query_as::<_, (String,)>( 194 + "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'", 195 195 ) 196 - .bind(&did.0) 197 - .execute(&state.pds_gatekeeper_pool) 198 - .await 199 - .map_err(|_| StatusCode::BAD_REQUEST)?; 196 + .bind(token) 197 + .bind(&did.0) 198 + .fetch_optional(&state.account_pool) 199 + .await{ 200 + Ok(token) => token, 201 + Err(err) => { 202 + log::error!("Error checking if token is valid: {err}"); 203 + return Err(StatusCode::BAD_REQUEST); 204 + } 205 + }; 206 + 207 + return if token_found.is_some() { 208 + //TODO I think there may be a bug here and need to do some retry logic 209 + // First try was erroring, seconds was allowing 210 + match sqlx::query( 211 + "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0", 212 + ) 213 + .bind(&did.0) 214 + .execute(&state.pds_gatekeeper_pool) 215 + .await { 216 + Ok(_) => {} 217 + Err(err) => { 218 + log::error!("Error updating email auth: {err}"); 219 + return Err(StatusCode::BAD_REQUEST); 220 + } 221 + } 200 222 201 - return Ok(StatusCode::OK.into_response()); 202 - } else { 203 - return Err(StatusCode::BAD_REQUEST); 223 + Ok(StatusCode::OK.into_response()) 224 + } else { 225 + Err(StatusCode::BAD_REQUEST) 226 + }; 204 227 } 205 228 } 206 229 } 207 - 208 230 // Updating the actual email address by sending it on to the PDS 209 231 let uri = format!( 210 232 "{}{}", ··· 260 282 ProxiedResult::Passthrough(resp) => Ok(resp), 261 283 } 262 284 } 285 + 286 + pub async fn create_account( 287 + State(state): State<AppState>, 288 + mut req: Request, 289 + ) -> Result<Response<Body>, StatusCode> { 290 + //TODO if I add the block of only accounts authenticated just take the body as json here and grab the lxm token. No middle ware is needed 291 + 292 + let uri = format!( 293 + "{}{}", 294 + state.pds_base_url, "/xrpc/com.atproto.server.createAccount" 295 + ); 296 + 297 + // Rewrite the URI to point at the upstream PDS; keep headers, method, and body intact 298 + *req.uri_mut() = uri.parse().map_err(|_| StatusCode::BAD_REQUEST)?; 299 + 300 + let proxied = state 301 + .reverse_proxy_client 302 + .request(req) 303 + .await 304 + .map_err(|_| StatusCode::BAD_REQUEST)? 305 + .into_response(); 306 + 307 + Ok(proxied) 308 + }