+4
-4
Cargo.lock
+4
-4
Cargo.lock
···
656
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
657
dependencies = [
658
"libc",
659
-
"windows-sys 0.52.0",
660
]
661
662
[[package]]
···
1392
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
1393
dependencies = [
1394
"cfg-if",
1395
-
"windows-targets 0.48.5",
1396
]
1397
1398
[[package]]
···
1690
1691
[[package]]
1692
name = "pds_gatekeeper"
1693
-
version = "0.1.0"
1694
dependencies = [
1695
"anyhow",
1696
"aws-lc-rs",
···
2136
"errno",
2137
"libc",
2138
"linux-raw-sys",
2139
-
"windows-sys 0.52.0",
2140
]
2141
2142
[[package]]
···
656
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
657
dependencies = [
658
"libc",
659
+
"windows-sys 0.59.0",
660
]
661
662
[[package]]
···
1392
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
1393
dependencies = [
1394
"cfg-if",
1395
+
"windows-targets 0.52.6",
1396
]
1397
1398
[[package]]
···
1690
1691
[[package]]
1692
name = "pds_gatekeeper"
1693
+
version = "0.1.2"
1694
dependencies = [
1695
"anyhow",
1696
"aws-lc-rs",
···
2136
"errno",
2137
"libc",
2138
"linux-raw-sys",
2139
+
"windows-sys 0.59.0",
2140
]
2141
2142
[[package]]
+4
-5
Cargo.toml
+4
-5
Cargo.toml
···
1
[package]
2
name = "pds_gatekeeper"
3
-
version = "0.1.0"
4
edition = "2024"
5
license = "MIT"
6
···
15
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
16
hyper-util = { version = "0.1.16", features = ["client", "client-legacy"] }
17
tower-http = { version = "0.6", features = ["cors", "compression-zstd"] }
18
-
tower_governor = "0.8.0"
19
hex = "0.4"
20
jwt-compact = { version = "0.8.0", features = ["es256k"] }
21
scrypt = "0.11"
22
-
#lettre = { version = "0.11.18", default-features = false, features = ["pool", "tokio1-rustls", "smtp-transport", "hostname", "builder"] }
23
-
#lettre = { version = "0.11", default-features = false, features = ["builder", "webpki-roots", "rustls", "aws-lc-rs", "smtp-transport", "tokio1", "tokio1-rustls"] }
24
aws-lc-rs = "1.13.0"
25
lettre = { version = "0.11", default-features = false, features = ["builder", "webpki-roots", "rustls", "aws-lc-rs", "smtp-transport", "tokio1", "tokio1-rustls"] }
26
-
rustls = { version = "0.23", default-features = false, features = ["tls12", "std", "logging", "aws_lc_rs"] }
27
handlebars = { version = "6.3.2", features = ["rust-embed"] }
28
rust-embed = "8.7.2"
29
axum-template = { version = "3.0.0", features = ["handlebars"] }
···
1
[package]
2
name = "pds_gatekeeper"
3
+
version = "0.1.2"
4
edition = "2024"
5
license = "MIT"
6
···
15
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
16
hyper-util = { version = "0.1.16", features = ["client", "client-legacy"] }
17
tower-http = { version = "0.6", features = ["cors", "compression-zstd"] }
18
+
tower_governor = { version = "0.8.0", features = ["axum", "tracing"] }
19
hex = "0.4"
20
jwt-compact = { version = "0.8.0", features = ["es256k"] }
21
scrypt = "0.11"
22
+
#Leaveing these two cause I think it is needed by the email crate for ssl
23
aws-lc-rs = "1.13.0"
24
+
rustls = { version = "0.23", default-features = false, features = ["tls12", "std", "logging", "aws_lc_rs"] }
25
lettre = { version = "0.11", default-features = false, features = ["builder", "webpki-roots", "rustls", "aws-lc-rs", "smtp-transport", "tokio1", "tokio1-rustls"] }
26
handlebars = { version = "6.3.2", features = ["rust-embed"] }
27
rust-embed = "8.7.2"
28
axum-template = { version = "3.0.0", features = ["handlebars"] }
+58
-2
README.md
+58
-2
README.md
···
37
```yml
38
gatekeeper:
39
container_name: gatekeeper
40
-
image: fatfingers23/pds_gatekeeper:arm-latest
41
network_mode: host
42
restart: unless-stopped
43
#This gives the container to the access to the PDS folder. Source is the location on your server of that directory
···
49
- pds
50
```
51
52
## Caddy setup
53
54
For the reverse proxy I use caddy. This part is what overwrites the endpoints and proxies them to PDS gatekeeper to add
···
60
path /xrpc/com.atproto.server.getSession
61
path /xrpc/com.atproto.server.updateEmail
62
path /xrpc/com.atproto.server.createSession
63
path /@atproto/oauth-provider/~api/sign-in
64
}
65
···
79
path /xrpc/com.atproto.server.getSession
80
path /xrpc/com.atproto.server.updateEmail
81
path /xrpc/com.atproto.server.createSession
82
path /@atproto/oauth-provider/~api/sign-in
83
}
84
85
handle @gatekeeper {
86
-
reverse_proxy http://localhost:8080
87
}
88
89
reverse_proxy http://localhost:3000
···
113
`GATEKEEPER_HOST` - Host for pds gatekeeper. Defaults to `127.0.0.1`
114
115
`GATEKEEPER_PORT` - Port for pds gatekeeper. Defaults to `8080`
···
37
```yml
38
gatekeeper:
39
container_name: gatekeeper
40
+
image: fatfingers23/pds_gatekeeper:latest
41
network_mode: host
42
restart: unless-stopped
43
#This gives the container to the access to the PDS folder. Source is the location on your server of that directory
···
49
- pds
50
```
51
52
+
For Coolify, if you're using Traefik as your proxy you'll need to make sure the labels for the container are set up correctly. A full example can be found at [./examples/coolify-compose.yml](./examples/coolify-compose.yml).
53
+
54
+
```yml
55
+
gatekeeper:
56
+
container_name: gatekeeper
57
+
image: 'fatfingers23/pds_gatekeeper:latest'
58
+
restart: unless-stopped
59
+
volumes:
60
+
- '/pds:/pds'
61
+
environment:
62
+
- 'PDS_DATA_DIRECTORY=${PDS_DATA_DIRECTORY:-/pds}'
63
+
- 'PDS_BASE_URL=http://pds:3000'
64
+
- GATEKEEPER_HOST=0.0.0.0
65
+
depends_on:
66
+
- pds
67
+
healthcheck:
68
+
test:
69
+
- CMD
70
+
- timeout
71
+
- '1'
72
+
- bash
73
+
- '-c'
74
+
- 'cat < /dev/null > /dev/tcp/0.0.0.0/8080'
75
+
interval: 10s
76
+
timeout: 5s
77
+
retries: 3
78
+
start_period: 10s
79
+
labels:
80
+
- traefik.enable=true
81
+
- 'traefik.http.routers.pds-gatekeeper.rule=Host(`yourpds.com`) && (Path(`/xrpc/com.atproto.server.getSession`) || Path(`/xrpc/com.atproto.server.updateEmail`) || Path(`/xrpc/com.atproto.server.createSession`) || Path(`/xrpc/com.atproto.server.createAccount`) || Path(`/@atproto/oauth-provider/~api/sign-in`))'
82
+
- traefik.http.routers.pds-gatekeeper.entrypoints=https
83
+
- traefik.http.routers.pds-gatekeeper.tls=true
84
+
- traefik.http.routers.pds-gatekeeper.priority=100
85
+
- traefik.http.routers.pds-gatekeeper.middlewares=gatekeeper-cors
86
+
- traefik.http.services.pds-gatekeeper.loadbalancer.server.port=8080
87
+
- traefik.http.services.pds-gatekeeper.loadbalancer.server.scheme=http
88
+
- 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowmethods=GET,POST,PUT,DELETE,OPTIONS,PATCH'
89
+
- 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowheaders=*'
90
+
- 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolalloworiginlist=*'
91
+
- traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolmaxage=100
92
+
- traefik.http.middlewares.gatekeeper-cors.headers.addvaryheader=true
93
+
- traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowcredentials=true
94
+
```
95
+
96
## Caddy setup
97
98
For the reverse proxy I use caddy. This part is what overwrites the endpoints and proxies them to PDS gatekeeper to add
···
104
path /xrpc/com.atproto.server.getSession
105
path /xrpc/com.atproto.server.updateEmail
106
path /xrpc/com.atproto.server.createSession
107
+
path /xrpc/com.atproto.server.createAccount
108
path /@atproto/oauth-provider/~api/sign-in
109
}
110
···
124
path /xrpc/com.atproto.server.getSession
125
path /xrpc/com.atproto.server.updateEmail
126
path /xrpc/com.atproto.server.createSession
127
+
path /xrpc/com.atproto.server.createAccount
128
path /@atproto/oauth-provider/~api/sign-in
129
}
130
131
handle @gatekeeper {
132
+
reverse_proxy http://localhost:8080 {
133
+
#Makes sure the cloudflare ip is proxied and able to be picked up by pds gatekeeper
134
+
header_up X-Forwarded-For {http.request.header.CF-Connecting-IP}
135
+
}
136
}
137
138
reverse_proxy http://localhost:3000
···
162
`GATEKEEPER_HOST` - Host for pds gatekeeper. Defaults to `127.0.0.1`
163
164
`GATEKEEPER_PORT` - Port for pds gatekeeper. Defaults to `8080`
165
+
166
+
`GATEKEEPER_CREATE_ACCOUNT_PER_SECOND` - Sets how often it takes a count off the limiter. example if you hit the rate
167
+
limit of 5 and set to 60, then in 60 seconds you will be able to make one more. Or in 5 minutes be able to make 5 more.
168
+
169
+
`GATEKEEPER_CREATE_ACCOUNT_BURST` - Sets how many requests can be made in a burst. In the prior example this is where
170
+
the 5 comes from. Example can set this to 10 to allow for 10 requests in a burst, and after 60 seconds it will drop one
171
+
off.
+1
examples/Caddyfile
+1
examples/Caddyfile
+1
-1
examples/compose.yml
+1
-1
examples/compose.yml
···
39
WATCHTOWER_SCHEDULE: "@midnight"
40
gatekeeper:
41
container_name: gatekeeper
42
-
image: fatfingers23/pds_gatekeeper:arm-latest
43
network_mode: host
44
restart: unless-stopped
45
#This gives the container to the access to the PDS folder. Source is the location on your server of that directory
+73
examples/coolify-compose.yml
+73
examples/coolify-compose.yml
···
···
1
+
services:
2
+
pds:
3
+
image: 'ghcr.io/bluesky-social/pds:0.4.182'
4
+
volumes:
5
+
- '/pds:/pds'
6
+
environment:
7
+
- SERVICE_URL_PDS_3000
8
+
- 'PDS_HOSTNAME=${SERVICE_FQDN_PDS_3000}'
9
+
- 'PDS_JWT_SECRET=${SERVICE_HEX_32_JWTSECRET}'
10
+
- 'PDS_ADMIN_PASSWORD=${SERVICE_PASSWORD_ADMIN}'
11
+
- 'PDS_ADMIN_EMAIL=${PDS_ADMIN_EMAIL}'
12
+
- 'PDS_PLC_ROTATION_KEY_K256_PRIVATE_KEY_HEX=${SERVICE_HEX_32_ROTATIONKEY}'
13
+
- 'PDS_DATA_DIRECTORY=${PDS_DATA_DIRECTORY:-/pds}'
14
+
- 'PDS_BLOBSTORE_DISK_LOCATION=${PDS_DATA_DIRECTORY:-/pds}/blocks'
15
+
- 'PDS_BLOB_UPLOAD_LIMIT=${PDS_BLOB_UPLOAD_LIMIT:-104857600}'
16
+
- 'PDS_DID_PLC_URL=${PDS_DID_PLC_URL:-https://plc.directory}'
17
+
- 'PDS_EMAIL_FROM_ADDRESS=${PDS_EMAIL_FROM_ADDRESS}'
18
+
- 'PDS_EMAIL_SMTP_URL=${PDS_EMAIL_SMTP_URL}'
19
+
- 'PDS_BSKY_APP_VIEW_URL=${PDS_BSKY_APP_VIEW_URL:-https://api.bsky.app}'
20
+
- 'PDS_BSKY_APP_VIEW_DID=${PDS_BSKY_APP_VIEW_DID:-did:web:api.bsky.app}'
21
+
- 'PDS_REPORT_SERVICE_URL=${PDS_REPORT_SERVICE_URL:-https://mod.bsky.app/xrpc/com.atproto.moderation.createReport}'
22
+
- 'PDS_REPORT_SERVICE_DID=${PDS_REPORT_SERVICE_DID:-did:plc:ar7c4by46qjdydhdevvrndac}'
23
+
- 'PDS_CRAWLERS=${PDS_CRAWLERS:-https://bsky.network}'
24
+
- 'LOG_ENABLED=${LOG_ENABLED:-true}'
25
+
command: "sh -c '\n set -euo pipefail\n echo \"Installing required packages and pdsadmin...\"\n apk add --no-cache openssl curl bash jq coreutils gnupg util-linux-misc >/dev/null\n curl -o /usr/local/bin/pdsadmin.sh https://raw.githubusercontent.com/bluesky-social/pds/main/pdsadmin.sh\n chmod 700 /usr/local/bin/pdsadmin.sh\n ln -sf /usr/local/bin/pdsadmin.sh /usr/local/bin/pdsadmin\n echo \"Creating an empty pds.env file so pdsadmin works...\"\n touch ${PDS_DATA_DIRECTORY}/pds.env\n echo \"Launching PDS, enjoy!...\"\n exec node --enable-source-maps index.js\n'\n"
26
+
healthcheck:
27
+
test:
28
+
- CMD
29
+
- wget
30
+
- '--spider'
31
+
- 'http://127.0.0.1:3000/xrpc/_health'
32
+
interval: 5s
33
+
timeout: 10s
34
+
retries: 10
35
+
gatekeeper:
36
+
container_name: gatekeeper
37
+
image: 'fatfingers23/pds_gatekeeper:latest'
38
+
restart: unless-stopped
39
+
volumes:
40
+
- '/pds:/pds'
41
+
environment:
42
+
- 'PDS_DATA_DIRECTORY=${PDS_DATA_DIRECTORY:-/pds}'
43
+
- 'PDS_BASE_URL=http://pds:3000'
44
+
- GATEKEEPER_HOST=0.0.0.0
45
+
depends_on:
46
+
- pds
47
+
healthcheck:
48
+
test:
49
+
- CMD
50
+
- timeout
51
+
- '1'
52
+
- bash
53
+
- '-c'
54
+
- 'cat < /dev/null > /dev/tcp/0.0.0.0/8080'
55
+
interval: 10s
56
+
timeout: 5s
57
+
retries: 3
58
+
start_period: 10s
59
+
labels:
60
+
- traefik.enable=true
61
+
- 'traefik.http.routers.pds-gatekeeper.rule=Host(`yourpds.com`) && (Path(`/xrpc/com.atproto.server.getSession`) || Path(`/xrpc/com.atproto.server.updateEmail`) || Path(`/xrpc/com.atproto.server.createSession`) || Path(`/xrpc/com.atproto.server.createAccount`) || Path(`/@atproto/oauth-provider/~api/sign-in`))'
62
+
- traefik.http.routers.pds-gatekeeper.entrypoints=https
63
+
- traefik.http.routers.pds-gatekeeper.tls=true
64
+
- traefik.http.routers.pds-gatekeeper.priority=100
65
+
- traefik.http.routers.pds-gatekeeper.middlewares=gatekeeper-cors
66
+
- traefik.http.services.pds-gatekeeper.loadbalancer.server.port=8080
67
+
- traefik.http.services.pds-gatekeeper.loadbalancer.server.scheme=http
68
+
- 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowmethods=GET,POST,PUT,DELETE,OPTIONS,PATCH'
69
+
- 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowheaders=*'
70
+
- 'traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolalloworiginlist=*'
71
+
- traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolmaxage=100
72
+
- traefik.http.middlewares.gatekeeper-cors.headers.addvaryheader=true
73
+
- traefik.http.middlewares.gatekeeper-cors.headers.accesscontrolallowcredentials=true
+1
-1
justfile
+1
-1
justfile
+54
-9
src/main.rs
+54
-9
src/main.rs
···
1
#![warn(clippy::unwrap_used)]
2
use crate::oauth_provider::sign_in;
3
-
use crate::xrpc::com_atproto_server::{create_session, get_session, update_email};
4
use axum::body::Body;
5
use axum::handler::Handler;
6
use axum::http::{Method, header};
···
20
use std::{env, net::SocketAddr};
21
use tower_governor::GovernorLayer;
22
use tower_governor::governor::GovernorConfigBuilder;
23
use tower_http::compression::CompressionLayer;
24
use tower_http::cors::{Any, CorsLayer};
25
use tracing::log;
···
91
let pds_env_location =
92
env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string());
93
94
-
dotenvy::from_path(Path::new(&pds_env_location))?;
95
-
let pds_root = env::var("PDS_DATA_DIRECTORY")?;
96
let account_db_url = format!("{pds_root}/account.sqlite");
97
98
let account_options = SqliteConnectOptions::new()
···
165
let create_session_governor_conf = GovernorConfigBuilder::default()
166
.per_second(60)
167
.burst_size(5)
168
.finish()
169
-
.expect("failed to create governor config. this should not happen and is a bug");
170
171
// Create a second config with the same settings for the other endpoint
172
let sign_in_governor_conf = GovernorConfigBuilder::default()
173
.per_second(60)
174
.burst_size(5)
175
.finish()
176
-
.expect("failed to create governor config. this should not happen and is a bug");
177
178
let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
179
let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
180
let interval = Duration::from_secs(60);
181
// a separate background task to clean up
182
std::thread::spawn(move || {
···
184
std::thread::sleep(interval);
185
create_session_governor_limiter.retain_recent();
186
sign_in_governor_limiter.retain_recent();
187
}
188
});
189
···
194
195
let app = Router::new()
196
.route("/", get(root_handler))
197
-
.route(
198
-
"/xrpc/com.atproto.server.getSession",
199
-
get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)),
200
-
)
201
.route(
202
"/xrpc/com.atproto.server.updateEmail",
203
post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
···
209
.route(
210
"/xrpc/com.atproto.server.createSession",
211
post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
212
)
213
.layer(CompressionLayer::new())
214
.layer(cors)
···
1
#![warn(clippy::unwrap_used)]
2
use crate::oauth_provider::sign_in;
3
+
use crate::xrpc::com_atproto_server::{create_account, create_session, get_session, update_email};
4
use axum::body::Body;
5
use axum::handler::Handler;
6
use axum::http::{Method, header};
···
20
use std::{env, net::SocketAddr};
21
use tower_governor::GovernorLayer;
22
use tower_governor::governor::GovernorConfigBuilder;
23
+
use tower_governor::key_extractor::SmartIpKeyExtractor;
24
use tower_http::compression::CompressionLayer;
25
use tower_http::cors::{Any, CorsLayer};
26
use tracing::log;
···
92
let pds_env_location =
93
env::var("PDS_ENV_LOCATION").unwrap_or_else(|_| "/pds/pds.env".to_string());
94
95
+
let result_of_finding_pds_env = dotenvy::from_path(Path::new(&pds_env_location));
96
+
if let Err(e) = result_of_finding_pds_env {
97
+
log::error!(
98
+
"Error loading pds.env file (ignore if you loaded your variables in the environment somehow else): {e}"
99
+
);
100
+
}
101
+
102
+
let pds_root =
103
+
env::var("PDS_DATA_DIRECTORY").expect("PDS_DATA_DIRECTORY is not set in your pds.env file");
104
let account_db_url = format!("{pds_root}/account.sqlite");
105
106
let account_options = SqliteConnectOptions::new()
···
173
let create_session_governor_conf = GovernorConfigBuilder::default()
174
.per_second(60)
175
.burst_size(5)
176
+
.key_extractor(SmartIpKeyExtractor)
177
.finish()
178
+
.expect("failed to create governor config for create session. this should not happen and is a bug");
179
180
// Create a second config with the same settings for the other endpoint
181
let sign_in_governor_conf = GovernorConfigBuilder::default()
182
.per_second(60)
183
.burst_size(5)
184
+
.key_extractor(SmartIpKeyExtractor)
185
.finish()
186
+
.expect(
187
+
"failed to create governor config for sign in. this should not happen and is a bug",
188
+
);
189
+
190
+
let create_account_limiter_time: Option<String> =
191
+
env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok();
192
+
let create_account_limiter_burst: Option<String> =
193
+
env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok();
194
+
195
+
//Default should be 608 requests per 5 minutes, PDS is 300 per 500 so will never hit it ideally
196
+
let mut create_account_governor_conf = GovernorConfigBuilder::default();
197
+
if create_account_limiter_time.is_some() {
198
+
let time = create_account_limiter_time
199
+
.expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set")
200
+
.parse::<u64>()
201
+
.expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer");
202
+
create_account_governor_conf.per_second(time);
203
+
}
204
+
205
+
if create_account_limiter_burst.is_some() {
206
+
let burst = create_account_limiter_burst
207
+
.expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set")
208
+
.parse::<u32>()
209
+
.expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer");
210
+
create_account_governor_conf.burst_size(burst);
211
+
}
212
+
213
+
let create_account_governor_conf = create_account_governor_conf
214
+
.key_extractor(SmartIpKeyExtractor)
215
+
.finish().expect(
216
+
"failed to create governor config for create account. this should not happen and is a bug",
217
+
);
218
219
let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
220
let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
221
+
let create_account_governor_limiter = create_account_governor_conf.limiter().clone();
222
+
223
let interval = Duration::from_secs(60);
224
// a separate background task to clean up
225
std::thread::spawn(move || {
···
227
std::thread::sleep(interval);
228
create_session_governor_limiter.retain_recent();
229
sign_in_governor_limiter.retain_recent();
230
+
create_account_governor_limiter.retain_recent();
231
}
232
});
233
···
238
239
let app = Router::new()
240
.route("/", get(root_handler))
241
+
.route("/xrpc/com.atproto.server.getSession", get(get_session))
242
.route(
243
"/xrpc/com.atproto.server.updateEmail",
244
post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
···
250
.route(
251
"/xrpc/com.atproto.server.createSession",
252
post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
253
+
)
254
+
.route(
255
+
"/xrpc/com.atproto.server.createAccount",
256
+
post(create_account).layer(GovernorLayer::new(create_account_governor_conf)),
257
)
258
.layer(CompressionLayer::new())
259
.layer(cors)
+73
-39
src/middleware.rs
+73
-39
src/middleware.rs
···
12
#[derive(Clone, Debug)]
13
pub struct Did(pub Option<String>);
14
15
#[derive(Serialize, Deserialize)]
16
pub struct TokenClaims {
17
pub sub: String,
18
}
19
20
pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse {
21
-
let token = extract_bearer(req.headers());
22
23
-
match token {
24
-
Ok(token) => {
25
-
match token {
26
None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
27
.expect("Error creating an error response"),
28
-
Some(token) => {
29
-
let token = UntrustedToken::new(&token);
30
-
if token.is_err() {
31
-
return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
32
-
.expect("Error creating an error response");
33
-
}
34
-
let parsed_token = token.expect("Already checked for error");
35
-
let claims: Result<Claims<TokenClaims>, ValidationError> =
36
-
parsed_token.deserialize_claims_unchecked();
37
-
if claims.is_err() {
38
-
return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
39
-
.expect("Error creating an error response");
40
-
}
41
42
-
let key = Hs256Key::new(
43
-
env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"),
44
-
);
45
-
let token: Result<Token<TokenClaims>, ValidationError> =
46
-
Hs256.validator(&key).validate(&parsed_token);
47
-
if token.is_err() {
48
-
return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
49
-
.expect("Error creating an error response");
50
}
51
-
let token = token.expect("Already checked for error,");
52
-
//Not going to worry about expiration since it still goes to the PDS
53
-
req.extensions_mut()
54
-
.insert(Did(Some(token.claims().custom.sub.clone())));
55
next.run(req).await
56
}
57
}
···
64
}
65
}
66
67
-
fn extract_bearer(headers: &HeaderMap) -> Result<Option<String>, String> {
68
match headers.get(axum::http::header::AUTHORIZATION) {
69
None => Ok(None),
70
-
Some(hv) => match hv.to_str() {
71
-
Err(_) => Err("Authorization header is not valid".into()),
72
-
Ok(s) => {
73
-
// Accept forms like: "Bearer <token>" (case-sensitive for the scheme here)
74
-
let mut parts = s.splitn(2, ' ');
75
-
match (parts.next(), parts.next()) {
76
-
(Some("Bearer"), Some(tok)) if !tok.is_empty() => Ok(Some(tok.to_string())),
77
-
_ => Err("Authorization header must be in format 'Bearer <token>'".into()),
78
}
79
}
80
-
},
81
}
82
}
···
12
#[derive(Clone, Debug)]
13
pub struct Did(pub Option<String>);
14
15
+
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
16
+
pub enum AuthScheme {
17
+
Bearer,
18
+
DPoP,
19
+
}
20
+
21
#[derive(Serialize, Deserialize)]
22
pub struct TokenClaims {
23
pub sub: String,
24
}
25
26
pub async fn extract_did(mut req: Request, next: Next) -> impl IntoResponse {
27
+
let auth = extract_auth(req.headers());
28
29
+
match auth {
30
+
Ok(auth_opt) => {
31
+
match auth_opt {
32
None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
33
.expect("Error creating an error response"),
34
+
Some((scheme, token_str)) => {
35
+
// For Bearer, validate JWT and extract DID from `sub`.
36
+
// For DPoP, we currently only pass through and do not validate here; insert None DID.
37
+
match scheme {
38
+
AuthScheme::Bearer => {
39
+
let token = UntrustedToken::new(&token_str);
40
+
if token.is_err() {
41
+
return json_error_response(
42
+
StatusCode::BAD_REQUEST,
43
+
"TokenRequired",
44
+
"",
45
+
)
46
+
.expect("Error creating an error response");
47
+
}
48
+
let parsed_token = token.expect("Already checked for error");
49
+
let claims: Result<Claims<TokenClaims>, ValidationError> =
50
+
parsed_token.deserialize_claims_unchecked();
51
+
if claims.is_err() {
52
+
return json_error_response(
53
+
StatusCode::BAD_REQUEST,
54
+
"TokenRequired",
55
+
"",
56
+
)
57
+
.expect("Error creating an error response");
58
+
}
59
60
+
let key = Hs256Key::new(
61
+
env::var("PDS_JWT_SECRET")
62
+
.expect("PDS_JWT_SECRET not set in the pds.env"),
63
+
);
64
+
let token: Result<Token<TokenClaims>, ValidationError> =
65
+
Hs256.validator(&key).validate(&parsed_token);
66
+
if token.is_err() {
67
+
return json_error_response(
68
+
StatusCode::BAD_REQUEST,
69
+
"InvalidToken",
70
+
"",
71
+
)
72
+
.expect("Error creating an error response");
73
+
}
74
+
let token = token.expect("Already checked for error,");
75
+
req.extensions_mut()
76
+
.insert(Did(Some(token.claims().custom.sub.clone())));
77
+
}
78
+
AuthScheme::DPoP => {
79
+
//Not going to worry about oauth email update for now, just always forward to the PDS
80
+
req.extensions_mut().insert(Did(None));
81
+
}
82
}
83
+
84
next.run(req).await
85
}
86
}
···
93
}
94
}
95
96
+
fn extract_auth(headers: &HeaderMap) -> Result<Option<(AuthScheme, String)>, String> {
97
match headers.get(axum::http::header::AUTHORIZATION) {
98
None => Ok(None),
99
+
Some(hv) => {
100
+
match hv.to_str() {
101
+
Err(_) => Err("Authorization header is not valid".into()),
102
+
Ok(s) => {
103
+
// Accept forms like: "Bearer <token>" or "DPoP <token>" (case-sensitive for the scheme here)
104
+
let mut parts = s.splitn(2, ' ');
105
+
match (parts.next(), parts.next()) {
106
+
(Some("Bearer"), Some(tok)) if !tok.is_empty() =>
107
+
Ok(Some((AuthScheme::Bearer, tok.to_string()))),
108
+
(Some("DPoP"), Some(tok)) if !tok.is_empty() =>
109
+
Ok(Some((AuthScheme::DPoP, tok.to_string()))),
110
+
_ => Err("Authorization header must be in format 'Bearer <token>' or 'DPoP <token>'".into()),
111
+
}
112
}
113
}
114
+
}
115
}
116
}
+2
-1
src/oauth_provider.rs
+2
-1
src/oauth_provider.rs
···
13
pub struct SignInRequest {
14
pub username: String,
15
pub password: String,
16
+
#[serde(skip_serializing_if = "Option::is_none")]
17
+
pub remember: Option<bool>,
18
pub locale: String,
19
#[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")]
20
pub email_otp: Option<String>,
+94
-47
src/xrpc/com_atproto_server.rs
+94
-47
src/xrpc/com_atproto_server.rs
···
147
//If email auth is set it is to either turn on or off 2fa
148
let email_auth_update = payload.email_auth_factor.unwrap_or(false);
149
150
-
// Email update asked for
151
-
if email_auth_update {
152
-
let email = payload.email.clone();
153
-
let email_confirmed = sqlx::query_as::<_, (String,)>(
154
-
"SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?",
155
-
)
156
-
.bind(&email)
157
-
.fetch_optional(&state.account_pool)
158
-
.await
159
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
160
-
161
-
//Since the email is already confirmed we can enable 2fa
162
-
return match email_confirmed {
163
-
None => Err(StatusCode::BAD_REQUEST),
164
-
Some(did_row) => {
165
-
let _ = sqlx::query(
166
-
"INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1",
167
-
)
168
-
.bind(&did_row.0)
169
-
.execute(&state.pds_gatekeeper_pool)
170
-
.await
171
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
172
-
173
-
Ok(StatusCode::OK.into_response())
174
-
}
175
-
};
176
-
}
177
178
-
// User wants auth turned off
179
-
if !email_auth_update && !email_auth_not_set {
180
-
//User wants auth turned off and has a token
181
-
if let Some(token) = &payload.token {
182
-
let token_found = sqlx::query_as::<_, (String,)>(
183
-
"SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'",
184
)
185
-
.bind(token)
186
-
.bind(&did.0)
187
.fetch_optional(&state.account_pool)
188
.await
189
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
190
191
-
if token_found.is_some() {
192
-
let _ = sqlx::query(
193
-
"INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0",
194
)
195
-
.bind(&did.0)
196
-
.execute(&state.pds_gatekeeper_pool)
197
-
.await
198
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
199
200
-
return Ok(StatusCode::OK.into_response());
201
-
} else {
202
-
return Err(StatusCode::BAD_REQUEST);
203
}
204
}
205
}
206
-
207
// Updating the actual email address by sending it on to the PDS
208
let uri = format!(
209
"{}{}",
···
259
ProxiedResult::Passthrough(resp) => Ok(resp),
260
}
261
}
···
147
//If email auth is set it is to either turn on or off 2fa
148
let email_auth_update = payload.email_auth_factor.unwrap_or(false);
149
150
+
//This means the middleware successfully extracted a did from the request, if not it just needs to be forward to the PDS
151
+
//This is also empty if it is an oauth request, which is not supported by gatekeeper turning on 2fa since the dpop stuff needs to be implemented
152
+
let did_is_not_empty = did.0.is_some();
153
154
+
if did_is_not_empty {
155
+
// Email update asked for
156
+
if email_auth_update {
157
+
let email = payload.email.clone();
158
+
let email_confirmed = match sqlx::query_as::<_, (String,)>(
159
+
"SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?",
160
)
161
+
.bind(&email)
162
.fetch_optional(&state.account_pool)
163
.await
164
+
{
165
+
Ok(row) => row,
166
+
Err(err) => {
167
+
log::error!("Error checking if email is confirmed: {err}");
168
+
return Err(StatusCode::BAD_REQUEST);
169
+
}
170
+
};
171
+
172
+
//Since the email is already confirmed we can enable 2fa
173
+
return match email_confirmed {
174
+
None => Err(StatusCode::BAD_REQUEST),
175
+
Some(did_row) => {
176
+
let _ = sqlx::query(
177
+
"INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1",
178
+
)
179
+
.bind(&did_row.0)
180
+
.execute(&state.pds_gatekeeper_pool)
181
+
.await
182
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
183
184
+
Ok(StatusCode::OK.into_response())
185
+
}
186
+
};
187
+
}
188
+
189
+
// User wants auth turned off
190
+
if !email_auth_update && !email_auth_not_set {
191
+
//User wants auth turned off and has a token
192
+
if let Some(token) = &payload.token {
193
+
let token_found = match sqlx::query_as::<_, (String,)>(
194
+
"SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'",
195
)
196
+
.bind(token)
197
+
.bind(&did.0)
198
+
.fetch_optional(&state.account_pool)
199
+
.await{
200
+
Ok(token) => token,
201
+
Err(err) => {
202
+
log::error!("Error checking if token is valid: {err}");
203
+
return Err(StatusCode::BAD_REQUEST);
204
+
}
205
+
};
206
207
+
return if token_found.is_some() {
208
+
//TODO I think there may be a bug here and need to do some retry logic
209
+
// First try was erroring, seconds was allowing
210
+
match sqlx::query(
211
+
"INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0",
212
+
)
213
+
.bind(&did.0)
214
+
.execute(&state.pds_gatekeeper_pool)
215
+
.await {
216
+
Ok(_) => {}
217
+
Err(err) => {
218
+
log::error!("Error updating email auth: {err}");
219
+
return Err(StatusCode::BAD_REQUEST);
220
+
}
221
+
}
222
+
223
+
Ok(StatusCode::OK.into_response())
224
+
} else {
225
+
Err(StatusCode::BAD_REQUEST)
226
+
};
227
}
228
}
229
}
230
// Updating the actual email address by sending it on to the PDS
231
let uri = format!(
232
"{}{}",
···
282
ProxiedResult::Passthrough(resp) => Ok(resp),
283
}
284
}
285
+
286
+
pub async fn create_account(
287
+
State(state): State<AppState>,
288
+
mut req: Request,
289
+
) -> Result<Response<Body>, StatusCode> {
290
+
//TODO if I add the block of only accounts authenticated just take the body as json here and grab the lxm token. No middle ware is needed
291
+
292
+
let uri = format!(
293
+
"{}{}",
294
+
state.pds_base_url, "/xrpc/com.atproto.server.createAccount"
295
+
);
296
+
297
+
// Rewrite the URI to point at the upstream PDS; keep headers, method, and body intact
298
+
*req.uri_mut() = uri.parse().map_err(|_| StatusCode::BAD_REQUEST)?;
299
+
300
+
let proxied = state
301
+
.reverse_proxy_client
302
+
.request(req)
303
+
.await
304
+
.map_err(|_| StatusCode::BAD_REQUEST)?
305
+
.into_response();
306
+
307
+
Ok(proxied)
308
+
}