Microservice to bring 2FA to self hosted PDSes

Compare changes

Choose any two refs to compare.

+4 -4
Cargo.lock
··· 656 checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" 657 dependencies = [ 658 "libc", 659 - "windows-sys 0.52.0", 660 ] 661 662 [[package]] ··· 1392 checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" 1393 dependencies = [ 1394 "cfg-if", 1395 - "windows-targets 0.48.5", 1396 ] 1397 1398 [[package]] ··· 1690 1691 [[package]] 1692 name = "pds_gatekeeper" 1693 - version = "0.1.0" 1694 dependencies = [ 1695 "anyhow", 1696 "aws-lc-rs", ··· 2136 "errno", 2137 "libc", 2138 "linux-raw-sys", 2139 - "windows-sys 0.52.0", 2140 ] 2141 2142 [[package]]
··· 656 checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad" 657 dependencies = [ 658 "libc", 659 + "windows-sys 0.59.0", 660 ] 661 662 [[package]] ··· 1392 checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667" 1393 dependencies = [ 1394 "cfg-if", 1395 + "windows-targets 0.52.6", 1396 ] 1397 1398 [[package]] ··· 1690 1691 [[package]] 1692 name = "pds_gatekeeper" 1693 + version = "0.1.2" 1694 dependencies = [ 1695 "anyhow", 1696 "aws-lc-rs", ··· 2136 "errno", 2137 "libc", 2138 "linux-raw-sys", 2139 + "windows-sys 0.59.0", 2140 ] 2141 2142 [[package]]
+4 -5
Cargo.toml
··· 1 [package] 2 name = "pds_gatekeeper" 3 - version = "0.1.0" 4 edition = "2024" 5 license = "MIT" 6 ··· 15 tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } 16 hyper-util = { version = "0.1.16", features = ["client", "client-legacy"] } 17 tower-http = { version = "0.6", features = ["cors", "compression-zstd"] } 18 - tower_governor = "0.8.0" 19 hex = "0.4" 20 jwt-compact = { version = "0.8.0", features = ["es256k"] } 21 scrypt = "0.11" 22 - #lettre = { version = "0.11.18", default-features = false, features = ["pool", "tokio1-rustls", "smtp-transport", "hostname", "builder"] } 23 - #lettre = { version = "0.11", default-features = false, features = ["builder", "webpki-roots", "rustls", "aws-lc-rs", "smtp-transport", "tokio1", "tokio1-rustls"] } 24 aws-lc-rs = "1.13.0" 25 lettre = { version = "0.11", default-features = false, features = ["builder", "webpki-roots", "rustls", "aws-lc-rs", "smtp-transport", "tokio1", "tokio1-rustls"] } 26 - rustls = { version = "0.23", default-features = false, features = ["tls12", "std", "logging", "aws_lc_rs"] } 27 handlebars = { version = "6.3.2", features = ["rust-embed"] } 28 rust-embed = "8.7.2" 29 axum-template = { version = "3.0.0", features = ["handlebars"] }
··· 1 [package] 2 name = "pds_gatekeeper" 3 + version = "0.1.2" 4 edition = "2024" 5 license = "MIT" 6 ··· 15 tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] } 16 hyper-util = { version = "0.1.16", features = ["client", "client-legacy"] } 17 tower-http = { version = "0.6", features = ["cors", "compression-zstd"] } 18 + tower_governor = { version = "0.8.0", features = ["axum", "tracing"] } 19 hex = "0.4" 20 jwt-compact = { version = "0.8.0", features = ["es256k"] } 21 scrypt = "0.11" 22 + #Leaveing these two cause I think it is needed by the email crate for ssl 23 aws-lc-rs = "1.13.0" 24 + rustls = { version = "0.23", default-features = false, features = ["tls12", "std", "logging", "aws_lc_rs"] } 25 lettre = { version = "0.11", default-features = false, features = ["builder", "webpki-roots", "rustls", "aws-lc-rs", "smtp-transport", "tokio1", "tokio1-rustls"] } 26 handlebars = { version = "6.3.2", features = ["rust-embed"] } 27 rust-embed = "8.7.2" 28 axum-template = { version = "3.0.0", features = ["handlebars"] }
+58 -2
README.md
··· 37 ```yml 38 gatekeeper: 39 container_name: gatekeeper 40 - image: fatfingers23/pds_gatekeeper:arm-latest 41 network_mode: host 42 restart: unless-stopped 43 #This gives the container to the access to the PDS folder. Source is the location on your server of that directory ··· 49 - pds 50 ``` 51 52 ## Caddy setup 53 54 For the reverse proxy I use caddy. This part is what overwrites the endpoints and proxies them to PDS gatekeeper to add ··· 60 path /xrpc/com.atproto.server.getSession 61 path /xrpc/com.atproto.server.updateEmail 62 path /xrpc/com.atproto.server.createSession 63 path /@atproto/oauth-provider/~api/sign-in 64 } 65 ··· 79 path /xrpc/com.atproto.server.getSession 80 path /xrpc/com.atproto.server.updateEmail 81 path /xrpc/com.atproto.server.createSession 82 path /@atproto/oauth-provider/~api/sign-in 83 } 84 85 handle @gatekeeper { 86 - reverse_proxy http://localhost:8080 87 } 88 89 reverse_proxy http://localhost:3000 ··· 113 `GATEKEEPER_HOST` - Host for pds gatekeeper. Defaults to `127.0.0.1` 114 115 `GATEKEEPER_PORT` - Port for pds gatekeeper. Defaults to `8080`
··· 37 ```yml 38 gatekeeper: 39 container_name: gatekeeper 40 + image: fatfingers23/pds_gatekeeper:latest 41 network_mode: host 42 restart: unless-stopped 43 #This gives the container to the access to the PDS folder. Source is the location on your server of that directory ··· 49 - pds 50 ``` 51 52 + For Coolify, if you're using Traefik as your proxy you'll need to make sure the labels for the container are set up correctly. A full example can be found at [./examples/coolify-compose.yml](./examples/coolify-compose.yml). 53 + 54 + ```yml 55 + gatekeeper: 56 + container_name: gatekeeper 57 + image: 'fatfingers23/pds_gatekeeper:latest' 58 + restart: unless-stopped 59 + volumes: 60 + - '/pds:/pds' 61 + environment: 62 + - 'PDS_DATA_DIRECTORY=${PDS_DATA_DIRECTORY:-/pds}' 63 + - 'PDS_BASE_URL=http://pds:3000' 64 + - GATEKEEPER_HOST=0.0.0.0 65 + depends_on: 66 + - pds 67 + healthcheck: 68 + test: 69 + - CMD 70 + - timeout 71 + - '1' 72 + - bash 73 + - '-c' 74 + - 'cat < /dev/null > /dev/tcp/0.0.0.0/8080' 75 + interval: 10s 76 + timeout: 5s 77 + retries: 3 78 + start_period: 10s 79 + labels: 80 + - traefik.enable=true 81 + - 'traefik.http.routers.pds-gatekeeper.rule=Host(`yourpds.com`) && (Path(`/xrpc/com.atproto.server.getSession`) || Path(`/xrpc/com.atproto.server.updateEmail`) || Path(`/xrpc/com.atproto.server.createSession`) || Path(`/xrpc/com.atproto.server.createAccount`) || Path(`/@atproto/oauth-provider/~api/sign-in`))' 82 + - traefik.http.routers.pds-gatekeeper.entrypoints=https 83 + - traefik.http.routers.pds-gatekeeper.tls=true 84 + - traefik.http.routers.pds-gatekeeper.priority=100 85 + - traefik.http.routers.pds-gatekeeper.middlewares=gatekeeper-cors 86 + - traefik.http.services.pds-gatekeeper.loadbalancer.server.port=8080 87 + - traefik.http.services.pds-gatekeeper.loadbalancer.server.scheme=http 88 + - 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowmethods=GET,POST,PUT,DELETE,OPTIONS,PATCH' 89 + - 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowheaders=*' 90 + - 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolalloworiginlist=*' 91 + - traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolmaxage=100 92 + - traefik.http.middlewares.gatekeeper-cors.headers.addvaryheader=true 93 + - traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowcredentials=true 94 + ``` 95 + 96 ## Caddy setup 97 98 For the reverse proxy I use caddy. This part is what overwrites the endpoints and proxies them to PDS gatekeeper to add ··· 104 path /xrpc/com.atproto.server.getSession 105 path /xrpc/com.atproto.server.updateEmail 106 path /xrpc/com.atproto.server.createSession 107 + path /xrpc/com.atproto.server.createAccount 108 path /@atproto/oauth-provider/~api/sign-in 109 } 110 ··· 124 path /xrpc/com.atproto.server.getSession 125 path /xrpc/com.atproto.server.updateEmail 126 path /xrpc/com.atproto.server.createSession 127 + path /xrpc/com.atproto.server.createAccount 128 path /@atproto/oauth-provider/~api/sign-in 129 } 130 131 handle @gatekeeper { 132 + reverse_proxy http://localhost:8080 { 133 + #Makes sure the cloudflare ip is proxied and able to be picked up by pds gatekeeper 134 + header_up X-Forwarded-For {http.request.header.CF-Connecting-IP} 135 + } 136 } 137 138 reverse_proxy http://localhost:3000 ··· 162 `GATEKEEPER_HOST` - Host for pds gatekeeper. Defaults to `127.0.0.1` 163 164 `GATEKEEPER_PORT` - Port for pds gatekeeper. Defaults to `8080` 165 + 166 + `GATEKEEPER_CREATE_ACCOUNT_PER_SECOND` - Sets how often it takes a count off the limiter. example if you hit the rate 167 + limit of 5 and set to 60, then in 60 seconds you will be able to make one more. Or in 5 minutes be able to make 5 more. 168 + 169 + `GATEKEEPER_CREATE_ACCOUNT_BURST` - Sets how many requests can be made in a burst. In the prior example this is where 170 + the 5 comes from. Example can set this to 10 to allow for 10 requests in a burst, and after 60 seconds it will drop one 171 + off.
+1
examples/Caddyfile
··· 14 path /xrpc/com.atproto.server.getSession 15 path /xrpc/com.atproto.server.updateEmail 16 path /xrpc/com.atproto.server.createSession 17 path /@atproto/oauth-provider/~api/sign-in 18 } 19
··· 14 path /xrpc/com.atproto.server.getSession 15 path /xrpc/com.atproto.server.updateEmail 16 path /xrpc/com.atproto.server.createSession 17 + path /xrpc/com.atproto.server.createAccount 18 path /@atproto/oauth-provider/~api/sign-in 19 } 20
+1 -1
examples/compose.yml
··· 39 WATCHTOWER_SCHEDULE: "@midnight" 40 gatekeeper: 41 container_name: gatekeeper 42 - image: fatfingers23/pds_gatekeeper:arm-latest 43 network_mode: host 44 restart: unless-stopped 45 #This gives the container to the access to the PDS folder. Source is the location on your server of that directory
··· 39 WATCHTOWER_SCHEDULE: "@midnight" 40 gatekeeper: 41 container_name: gatekeeper 42 + image: fatfingers23/pds_gatekeeper:latest 43 network_mode: host 44 restart: unless-stopped 45 #This gives the container to the access to the PDS folder. Source is the location on your server of that directory
+73
examples/coolify-compose.yml
···
··· 1 + services: 2 + pds: 3 + image: 'ghcr.io/bluesky-social/pds:0.4.182' 4 + volumes: 5 + - '/pds:/pds' 6 + environment: 7 + - SERVICE_URL_PDS_3000 8 + - 'PDS_HOSTNAME=${SERVICE_FQDN_PDS_3000}' 9 + - 'PDS_JWT_SECRET=${SERVICE_HEX_32_JWTSECRET}' 10 + - 'PDS_ADMIN_PASSWORD=${SERVICE_PASSWORD_ADMIN}' 11 + - 'PDS_ADMIN_EMAIL=${PDS_ADMIN_EMAIL}' 12 + - 'PDS_PLC_ROTATION_KEY_K256_PRIVATE_KEY_HEX=${SERVICE_HEX_32_ROTATIONKEY}' 13 + - 'PDS_DATA_DIRECTORY=${PDS_DATA_DIRECTORY:-/pds}' 14 + - 'PDS_BLOBSTORE_DISK_LOCATION=${PDS_DATA_DIRECTORY:-/pds}/blocks' 15 + - 'PDS_BLOB_UPLOAD_LIMIT=${PDS_BLOB_UPLOAD_LIMIT:-104857600}' 16 + - 'PDS_DID_PLC_URL=${PDS_DID_PLC_URL:-https://plc.directory}' 17 + - 'PDS_EMAIL_FROM_ADDRESS=${PDS_EMAIL_FROM_ADDRESS}' 18 + - 'PDS_EMAIL_SMTP_URL=${PDS_EMAIL_SMTP_URL}' 19 + - 'PDS_BSKY_APP_VIEW_URL=${PDS_BSKY_APP_VIEW_URL:-https://api.bsky.app}' 20 + - 'PDS_BSKY_APP_VIEW_DID=${PDS_BSKY_APP_VIEW_DID:-did:web:api.bsky.app}' 21 + - 'PDS_REPORT_SERVICE_URL=${PDS_REPORT_SERVICE_URL:-https://mod.bsky.app/xrpc/com.atproto.moderation.createReport}' 22 + - 'PDS_REPORT_SERVICE_DID=${PDS_REPORT_SERVICE_DID:-did:plc:ar7c4by46qjdydhdevvrndac}' 23 + - 'PDS_CRAWLERS=${PDS_CRAWLERS:-https://bsky.network}' 24 + - 'LOG_ENABLED=${LOG_ENABLED:-true}' 25 + command: "sh -c '\n set -euo pipefail\n echo \"Installing required packages and pdsadmin...\"\n apk add --no-cache openssl curl bash jq coreutils gnupg util-linux-misc >/dev/null\n curl -o /usr/local/bin/pdsadmin.sh https://raw.githubusercontent.com/bluesky-social/pds/main/pdsadmin.sh\n chmod 700 /usr/local/bin/pdsadmin.sh\n ln -sf /usr/local/bin/pdsadmin.sh /usr/local/bin/pdsadmin\n echo \"Creating an empty pds.env file so pdsadmin works...\"\n touch ${PDS_DATA_DIRECTORY}/pds.env\n echo \"Launching PDS, enjoy!...\"\n exec node --enable-source-maps index.js\n'\n" 26 + healthcheck: 27 + test: 28 + - CMD 29 + - wget 30 + - '--spider' 31 + - 'http://127.0.0.1:3000/xrpc/_health' 32 + interval: 5s 33 + timeout: 10s 34 + retries: 10 35 + gatekeeper: 36 + container_name: gatekeeper 37 + image: 'fatfingers23/pds_gatekeeper:latest' 38 + restart: unless-stopped 39 + volumes: 40 + - '/pds:/pds' 41 + environment: 42 + - 'PDS_DATA_DIRECTORY=${PDS_DATA_DIRECTORY:-/pds}' 43 + - 'PDS_BASE_URL=http://pds:3000' 44 + - GATEKEEPER_HOST=0.0.0.0 45 + depends_on: 46 + - pds 47 + healthcheck: 48 + test: 49 + - CMD 50 + - timeout 51 + - '1' 52 + - bash 53 + - '-c' 54 + - 'cat < /dev/null > /dev/tcp/0.0.0.0/8080' 55 + interval: 10s 56 + timeout: 5s 57 + retries: 3 58 + start_period: 10s 59 + labels: 60 + - traefik.enable=true 61 + - 'traefik.http.routers.pds-gatekeeper.rule=Host(`yourpds.com`) && (Path(`/xrpc/com.atproto.server.getSession`) || Path(`/xrpc/com.atproto.server.updateEmail`) || Path(`/xrpc/com.atproto.server.createSession`) || Path(`/xrpc/com.atproto.server.createAccount`) || Path(`/@atproto/oauth-provider/~api/sign-in`))' 62 + - traefik.http.routers.pds-gatekeeper.entrypoints=https 63 + - traefik.http.routers.pds-gatekeeper.tls=true 64 + - traefik.http.routers.pds-gatekeeper.priority=100 65 + - traefik.http.routers.pds-gatekeeper.middlewares=gatekeeper-cors 66 + - traefik.http.services.pds-gatekeeper.loadbalancer.server.port=8080 67 + - traefik.http.services.pds-gatekeeper.loadbalancer.server.scheme=http 68 + - 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowmethods=GET,POST,PUT,DELETE,OPTIONS,PATCH' 69 + - 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowheaders=*' 70 + - 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolalloworiginlist=*' 71 + - traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolmaxage=100 72 + - traefik.http.middlewares.gatekeeper-cors.headers.addvaryheader=true 73 + - traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowcredentials=true
+1 -1
justfile
··· 2 docker buildx build \ 3 --platform linux/arm64,linux/amd64 \ 4 --tag fatfingers23/pds_gatekeeper:latest \ 5 - --tag fatfingers23/pds_gatekeeper:0.1.0.1 \ 6 --push .
··· 2 docker buildx build \ 3 --platform linux/arm64,linux/amd64 \ 4 --tag fatfingers23/pds_gatekeeper:latest \ 5 + --tag fatfingers23/pds_gatekeeper:0.1.0.3 \ 6 --push .
+54 -9
src/main.rs
··· 1 #![warn(clippy::unwrap_used)] 2 use crate::oauth_provider::sign_in; 3 - use crate::xrpc::com_atproto_server::{create_session, get_session, update_email}; 4 use axum::body::Body; 5 use axum::handler::Handler; 6 use axum::http::{Method, header}; ··· 20 use std::{env, net::SocketAddr}; 21 use tower_governor::GovernorLayer; 22 use tower_governor::governor::GovernorConfigBuilder; 23 use tower_http::compression::CompressionLayer; 24 use tower_http::cors::{Any, CorsLayer}; 25 use tracing::log; ··· 91 let pds_env_location = 92 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); 93 94 - dotenvy::from_path(Path::new(&pds_env_location))?; 95 - let pds_root = env::var("PDS_DATA_DIRECTORY")?; 96 let account_db_url = format!("{pds_root}/account.sqlite"); 97 98 let account_options = SqliteConnectOptions::new() ··· 165 let create_session_governor_conf = GovernorConfigBuilder::default() 166 .per_second(60) 167 .burst_size(5) 168 .finish() 169 - .expect("failed to create governor config. this should not happen and is a bug"); 170 171 // Create a second config with the same settings for the other endpoint 172 let sign_in_governor_conf = GovernorConfigBuilder::default() 173 .per_second(60) 174 .burst_size(5) 175 .finish() 176 - .expect("failed to create governor config. this should not happen and is a bug"); 177 178 let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); 179 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 180 let interval = Duration::from_secs(60); 181 // a separate background task to clean up 182 std::thread::spawn(move || { ··· 184 std::thread::sleep(interval); 185 create_session_governor_limiter.retain_recent(); 186 sign_in_governor_limiter.retain_recent(); 187 } 188 }); 189 ··· 194 195 let app = Router::new() 196 .route("/", get(root_handler)) 197 - .route( 198 - "/xrpc/com.atproto.server.getSession", 199 - get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)), 200 - ) 201 .route( 202 "/xrpc/com.atproto.server.updateEmail", 203 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), ··· 209 .route( 210 "/xrpc/com.atproto.server.createSession", 211 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), 212 ) 213 .layer(CompressionLayer::new()) 214 .layer(cors)
··· 1 #![warn(clippy::unwrap_used)] 2 use crate::oauth_provider::sign_in; 3 + use crate::xrpc::com_atproto_server::{create_account, create_session, get_session, update_email}; 4 use axum::body::Body; 5 use axum::handler::Handler; 6 use axum::http::{Method, header}; ··· 20 use std::{env, net::SocketAddr}; 21 use tower_governor::GovernorLayer; 22 use tower_governor::governor::GovernorConfigBuilder; 23 + use tower_governor::key_extractor::SmartIpKeyExtractor; 24 use tower_http::compression::CompressionLayer; 25 use tower_http::cors::{Any, CorsLayer}; 26 use tracing::log; ··· 92 let pds_env_location = 93 env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string()); 94 95 + let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location)); 96 + if let Err(e) = result_of_finding_pds_env { 97 + log::error!( 98 + "Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}" 99 + ); 100 + } 101 + 102 + let pds_root = 103 + env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file"); 104 let account_db_url = format!("{pds_root}/account.sqlite"); 105 106 let account_options = SqliteConnectOptions::new() ··· 173 let create_session_governor_conf = GovernorConfigBuilder::default() 174 .per_second(60) 175 .burst_size(5) 176 + .key_extractor(SmartIpKeyExtractor) 177 .finish() 178 + .expect("failed to create governor config for create session. this should not happen and is a bug"); 179 180 // Create a second config with the same settings for the other endpoint 181 let sign_in_governor_conf = GovernorConfigBuilder::default() 182 .per_second(60) 183 .burst_size(5) 184 + .key_extractor(SmartIpKeyExtractor) 185 .finish() 186 + .expect( 187 + "failed to create governor config for sign in. this should not happen and is a bug", 188 + ); 189 + 190 + let create_account_limiter_time: Option<String> = 191 + env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok(); 192 + let create_account_limiter_burst: Option<String> = 193 + env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok(); 194 + 195 + //Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally 196 + let mut create_account_governor_conf = GovernorConfigBuilder::default(); 197 + if create_account_limiter_time.is_some() { 198 + let time = create_account_limiter_time 199 + .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set") 200 + .parse::<u64>() 201 + .expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer"); 202 + create_account_governor_conf.per_second(time); 203 + } 204 + 205 + if create_account_limiter_burst.is_some() { 206 + let burst = create_account_limiter_burst 207 + .expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set") 208 + .parse::<u32>() 209 + .expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer"); 210 + create_account_governor_conf.burst_size(burst); 211 + } 212 + 213 + let create_account_governor_conf = create_account_governor_conf 214 + .key_extractor(SmartIpKeyExtractor) 215 + .finish().expect( 216 + "failed to create governor config for create account. this should not happen and is a bug", 217 + ); 218 219 let create_session_governor_limiter = create_session_governor_conf.limiter().clone(); 220 let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone(); 221 + let create_account_governor_limiter = create_account_governor_conf.limiter().clone(); 222 + 223 let interval = Duration::from_secs(60); 224 // a separate background task to clean up 225 std::thread::spawn(move || { ··· 227 std::thread::sleep(interval); 228 create_session_governor_limiter.retain_recent(); 229 sign_in_governor_limiter.retain_recent(); 230 + create_account_governor_limiter.retain_recent(); 231 } 232 }); 233 ··· 238 239 let app = Router::new() 240 .route("/", get(root_handler)) 241 + .route("/xrpc/com.atproto.server.getSession", get(get_session)) 242 .route( 243 "/xrpc/com.atproto.server.updateEmail", 244 post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)), ··· 250 .route( 251 "/xrpc/com.atproto.server.createSession", 252 post(create_session.layer(GovernorLayer::new(create_session_governor_conf))), 253 + ) 254 + .route( 255 + "/xrpc/com.atproto.server.createAccount", 256 + post(create_account).layer(GovernorLayer::new(create_account_governor_conf)), 257 ) 258 .layer(CompressionLayer::new()) 259 .layer(cors)
+73 -39
src/middleware.rs
··· 12 #[derive(Clone, Debug)] 13 pub struct Did(pub Option<String>); 14 15 #[derive(Serialize, Deserialize)] 16 pub struct TokenClaims { 17 pub sub: String, 18 } 19 20 pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse { 21 - let token = extract_bearer(req.headers()); 22 23 - match token { 24 - Ok(token) => { 25 - match token { 26 None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 27 .expect("Error creating an error response"), 28 - Some(token) => { 29 - let token = UntrustedToken::new(&token); 30 - if token.is_err() { 31 - return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 32 - .expect("Error creating an error response"); 33 - } 34 - let parsed_token = token.expect("Already checked for error"); 35 - let claims: Result<Claims<TokenClaims>, ValidationError> = 36 - parsed_token.deserialize_claims_unchecked(); 37 - if claims.is_err() { 38 - return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 39 - .expect("Error creating an error response"); 40 - } 41 42 - let key = Hs256Key::new( 43 - env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"), 44 - ); 45 - let token: Result<Token<TokenClaims>, ValidationError> = 46 - Hs256.validator(&key).validate(&parsed_token); 47 - if token.is_err() { 48 - return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "") 49 - .expect("Error creating an error response"); 50 } 51 - let token = token.expect("Already checked for error,"); 52 - //Not going to worry about expiration since it still goes to the PDS 53 - req.extensions_mut() 54 - .insert(Did(Some(token.claims().custom.sub.clone()))); 55 next.run(req).await 56 } 57 } ··· 64 } 65 } 66 67 - fn extract_bearer(headers: &HeaderMap) -> Result<Option<String>, String> { 68 match headers.get(axum::http::header::AUTHORIZATION) { 69 None => Ok(None), 70 - Some(hv) => match hv.to_str() { 71 - Err(_) => Err("Authorization header is not valid".into()), 72 - Ok(s) => { 73 - // Accept forms like: "Bearer <token>" (case-sensitive for the scheme here) 74 - let mut parts = s.splitn(2, ' '); 75 - match (parts.next(), parts.next()) { 76 - (Some("Bearer"), Some(tok)) if !tok.is_empty() => Ok(Some(tok.to_string())), 77 - _ => Err("Authorization header must be in format 'Bearer <token>'".into()), 78 } 79 } 80 - }, 81 } 82 }
··· 12 #[derive(Clone, Debug)] 13 pub struct Did(pub Option<String>); 14 15 + #[derive(Clone, Copy, Debug, PartialEq, Eq)] 16 + pub enum AuthScheme { 17 + Bearer, 18 + DPoP, 19 + } 20 + 21 #[derive(Serialize, Deserialize)] 22 pub struct TokenClaims { 23 pub sub: String, 24 } 25 26 pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse { 27 + let auth = extract_auth(req.headers()); 28 29 + match auth { 30 + Ok(auth_opt) => { 31 + match auth_opt { 32 None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "") 33 .expect("Error creating an error response"), 34 + Some((scheme, token_str)) => { 35 + // For Bearer, validate JWT and extract DID from `sub`. 36 + // For DPoP, we currently only pass through and do not validate here; insert None DID. 37 + match scheme { 38 + AuthScheme::Bearer => { 39 + let token = UntrustedToken::new(&token_str); 40 + if token.is_err() { 41 + return json_error_response( 42 + StatusCode::BAD_REQUEST, 43 + "TokenRequired", 44 + "", 45 + ) 46 + .expect("Error creating an error response"); 47 + } 48 + let parsed_token = token.expect("Already checked for error"); 49 + let claims: Result<Claims<TokenClaims>, ValidationError> = 50 + parsed_token.deserialize_claims_unchecked(); 51 + if claims.is_err() { 52 + return json_error_response( 53 + StatusCode::BAD_REQUEST, 54 + "TokenRequired", 55 + "", 56 + ) 57 + .expect("Error creating an error response"); 58 + } 59 60 + let key = Hs256Key::new( 61 + env::var("PDS_JWT_SECRET") 62 + .expect("PDS_JWT_SECRET not set in the pds.env"), 63 + ); 64 + let token: Result<Token<TokenClaims>, ValidationError> = 65 + Hs256.validator(&key).validate(&parsed_token); 66 + if token.is_err() { 67 + return json_error_response( 68 + StatusCode::BAD_REQUEST, 69 + "InvalidToken", 70 + "", 71 + ) 72 + .expect("Error creating an error response"); 73 + } 74 + let token = token.expect("Already checked for error,"); 75 + req.extensions_mut() 76 + .insert(Did(Some(token.claims().custom.sub.clone()))); 77 + } 78 + AuthScheme::DPoP => { 79 + //Not going to worry about oauth email update for now, just always forward to the PDS 80 + req.extensions_mut().insert(Did(None)); 81 + } 82 } 83 + 84 next.run(req).await 85 } 86 } ··· 93 } 94 } 95 96 + fn extract_auth(headers: &HeaderMap) -> Result<Option<(AuthScheme, String)>, String> { 97 match headers.get(axum::http::header::AUTHORIZATION) { 98 None => Ok(None), 99 + Some(hv) => { 100 + match hv.to_str() { 101 + Err(_) => Err("Authorization header is not valid".into()), 102 + Ok(s) => { 103 + // Accept forms like: "Bearer <token>" or "DPoP <token>" (case-sensitive for the scheme here) 104 + let mut parts = s.splitn(2, ' '); 105 + match (parts.next(), parts.next()) { 106 + (Some("Bearer"), Some(tok)) if !tok.is_empty() => 107 + Ok(Some((AuthScheme::Bearer, tok.to_string()))), 108 + (Some("DPoP"), Some(tok)) if !tok.is_empty() => 109 + Ok(Some((AuthScheme::DPoP, tok.to_string()))), 110 + _ => Err("Authorization header must be in format 'Bearer <token>' or 'DPoP <token>'".into()), 111 + } 112 } 113 } 114 + } 115 } 116 }
+2 -1
src/oauth_provider.rs
··· 13 pub struct SignInRequest { 14 pub username: String, 15 pub password: String, 16 - pub remember: bool, 17 pub locale: String, 18 #[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")] 19 pub email_otp: Option<String>,
··· 13 pub struct SignInRequest { 14 pub username: String, 15 pub password: String, 16 + #[serde(skip_serializing_if = "Option::is_none")] 17 + pub remember: Option<bool>, 18 pub locale: String, 19 #[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")] 20 pub email_otp: Option<String>,
+94 -47
src/xrpc/com_atproto_server.rs
··· 147 //If email auth is set it is to either turn on or off 2fa 148 let email_auth_update = payload.email_auth_factor.unwrap_or(false); 149 150 - // Email update asked for 151 - if email_auth_update { 152 - let email = payload.email.clone(); 153 - let email_confirmed = sqlx::query_as::<_, (String,)>( 154 - "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?", 155 - ) 156 - .bind(&email) 157 - .fetch_optional(&state.account_pool) 158 - .await 159 - .map_err(|_| StatusCode::BAD_REQUEST)?; 160 - 161 - //Since the email is already confirmed we can enable 2fa 162 - return match email_confirmed { 163 - None => Err(StatusCode::BAD_REQUEST), 164 - Some(did_row) => { 165 - let _ = sqlx::query( 166 - "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1", 167 - ) 168 - .bind(&did_row.0) 169 - .execute(&state.pds_gatekeeper_pool) 170 - .await 171 - .map_err(|_| StatusCode::BAD_REQUEST)?; 172 - 173 - Ok(StatusCode::OK.into_response()) 174 - } 175 - }; 176 - } 177 178 - // User wants auth turned off 179 - if !email_auth_update && !email_auth_not_set { 180 - //User wants auth turned off and has a token 181 - if let Some(token) = &payload.token { 182 - let token_found = sqlx::query_as::<_, (String,)>( 183 - "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'", 184 ) 185 - .bind(token) 186 - .bind(&did.0) 187 .fetch_optional(&state.account_pool) 188 .await 189 - .map_err(|_| StatusCode::BAD_REQUEST)?; 190 191 - if token_found.is_some() { 192 - let _ = sqlx::query( 193 - "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0", 194 ) 195 - .bind(&did.0) 196 - .execute(&state.pds_gatekeeper_pool) 197 - .await 198 - .map_err(|_| StatusCode::BAD_REQUEST)?; 199 200 - return Ok(StatusCode::OK.into_response()); 201 - } else { 202 - return Err(StatusCode::BAD_REQUEST); 203 } 204 } 205 } 206 - 207 // Updating the actual email address by sending it on to the PDS 208 let uri = format!( 209 "{}{}", ··· 259 ProxiedResult::Passthrough(resp) => Ok(resp), 260 } 261 }
··· 147 //If email auth is set it is to either turn on or off 2fa 148 let email_auth_update = payload.email_auth_factor.unwrap_or(false); 149 150 + //This means the middleware successfully extracted a did from the request, if not it just needs to be forward to the PDS 151 + //This is also empty if it is an oauth request, which is not supported by gatekeeper turning on 2fa since the dpop stuff needs to be implemented 152 + let did_is_not_empty = did.0.is_some(); 153 154 + if did_is_not_empty { 155 + // Email update asked for 156 + if email_auth_update { 157 + let email = payload.email.clone(); 158 + let email_confirmed = match sqlx::query_as::<_, (String,)>( 159 + "SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?", 160 ) 161 + .bind(&email) 162 .fetch_optional(&state.account_pool) 163 .await 164 + { 165 + Ok(row) => row, 166 + Err(err) => { 167 + log::error!("Error checking if email is confirmed: {err}"); 168 + return Err(StatusCode::BAD_REQUEST); 169 + } 170 + }; 171 + 172 + //Since the email is already confirmed we can enable 2fa 173 + return match email_confirmed { 174 + None => Err(StatusCode::BAD_REQUEST), 175 + Some(did_row) => { 176 + let _ = sqlx::query( 177 + "INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1", 178 + ) 179 + .bind(&did_row.0) 180 + .execute(&state.pds_gatekeeper_pool) 181 + .await 182 + .map_err(|_| StatusCode::BAD_REQUEST)?; 183 184 + Ok(StatusCode::OK.into_response()) 185 + } 186 + }; 187 + } 188 + 189 + // User wants auth turned off 190 + if !email_auth_update && !email_auth_not_set { 191 + //User wants auth turned off and has a token 192 + if let Some(token) = &payload.token { 193 + let token_found = match sqlx::query_as::<_, (String,)>( 194 + "SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'", 195 ) 196 + .bind(token) 197 + .bind(&did.0) 198 + .fetch_optional(&state.account_pool) 199 + .await{ 200 + Ok(token) => token, 201 + Err(err) => { 202 + log::error!("Error checking if token is valid: {err}"); 203 + return Err(StatusCode::BAD_REQUEST); 204 + } 205 + }; 206 207 + return if token_found.is_some() { 208 + //TODO I think there may be a bug here and need to do some retry logic 209 + // First try was erroring, seconds was allowing 210 + match sqlx::query( 211 + "INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0", 212 + ) 213 + .bind(&did.0) 214 + .execute(&state.pds_gatekeeper_pool) 215 + .await { 216 + Ok(_) => {} 217 + Err(err) => { 218 + log::error!("Error updating email auth: {err}"); 219 + return Err(StatusCode::BAD_REQUEST); 220 + } 221 + } 222 + 223 + Ok(StatusCode::OK.into_response()) 224 + } else { 225 + Err(StatusCode::BAD_REQUEST) 226 + }; 227 } 228 } 229 } 230 // Updating the actual email address by sending it on to the PDS 231 let uri = format!( 232 "{}{}", ··· 282 ProxiedResult::Passthrough(resp) => Ok(resp), 283 } 284 } 285 + 286 + pub async fn create_account( 287 + State(state): State<AppState>, 288 + mut req: Request, 289 + ) -> Result<Response<Body>, StatusCode> { 290 + //TODO if I add the block of only accounts authenticated just take the body as json here and grab the lxm token. No middle ware is needed 291 + 292 + let uri = format!( 293 + "{}{}", 294 + state.pds_base_url, "/xrpc/com.atproto.server.createAccount" 295 + ); 296 + 297 + // Rewrite the URI to point at the upstream PDS; keep headers, method, and body intact 298 + *req.uri_mut() = uri.parse().map_err(|_| StatusCode::BAD_REQUEST)?; 299 + 300 + let proxied = state 301 + .reverse_proxy_client 302 + .request(req) 303 + .await 304 + .map_err(|_| StatusCode::BAD_REQUEST)? 305 + .into_response(); 306 + 307 + Ok(proxied) 308 + }