+10
Cargo.lock
+10
Cargo.lock
···
287
287
dependencies = [
288
288
"android-tzdata",
289
289
"iana-time-zone",
290
+
"js-sys",
290
291
"num-traits",
292
+
"wasm-bindgen",
291
293
"windows-link",
292
294
]
293
295
···
1652
1654
name = "pds_gatekeeper"
1653
1655
version = "0.1.0"
1654
1656
dependencies = [
1657
+
"anyhow",
1655
1658
"axum",
1656
1659
"axum-template",
1660
+
"chrono",
1657
1661
"dotenvy",
1658
1662
"handlebars",
1659
1663
"hex",
1660
1664
"hyper-util",
1661
1665
"jwt-compact",
1662
1666
"lettre",
1667
+
"rand 0.9.2",
1663
1668
"rust-embed",
1664
1669
"scrypt",
1665
1670
"serde",
1666
1671
"serde_json",
1672
+
"sha2",
1667
1673
"sqlx",
1668
1674
"tokio",
1669
1675
"tower-http",
···
2393
2399
dependencies = [
2394
2400
"base64",
2395
2401
"bytes",
2402
+
"chrono",
2396
2403
"crc",
2397
2404
"crossbeam-queue",
2398
2405
"either",
···
2470
2477
"bitflags",
2471
2478
"byteorder",
2472
2479
"bytes",
2480
+
"chrono",
2473
2481
"crc",
2474
2482
"digest",
2475
2483
"dotenvy",
···
2511
2519
"base64",
2512
2520
"bitflags",
2513
2521
"byteorder",
2522
+
"chrono",
2514
2523
"crc",
2515
2524
"dotenvy",
2516
2525
"etcetera",
···
2545
2554
checksum = "c2d12fe70b2c1b4401038055f90f151b78208de1f9f89a7dbfd41587a10c3eea"
2546
2555
dependencies = [
2547
2556
"atoi",
2557
+
"chrono",
2548
2558
"flume",
2549
2559
"futures-channel",
2550
2560
"futures-core",
+5
-1
Cargo.toml
+5
-1
Cargo.toml
···
6
6
[dependencies]
7
7
axum = { version = "0.8.4", features = ["macros", "json"] }
8
8
tokio = { version = "1.47.1", features = ["rt-multi-thread", "macros", "signal"] }
9
-
sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate"] }
9
+
sqlx = { version = "0.8.6", features = ["runtime-tokio-rustls", "sqlite", "migrate", "chrono"] }
10
10
dotenvy = "0.15.7"
11
11
serde = { version = "1.0", features = ["derive"] }
12
12
serde_json = "1.0"
···
22
22
handlebars = { version = "6.3.2", features = ["rust-embed"] }
23
23
rust-embed = "8.7.2"
24
24
axum-template = { version = "3.0.0", features = ["handlebars"] }
25
+
rand = "0.9.2"
26
+
anyhow = "1.0.99"
27
+
chrono = "0.4.41"
28
+
sha2 = "0.10"
+5
-6
README.md
+5
-6
README.md
···
12
12
13
13
## 2FA
14
14
15
-
- [x] Ability to turn on/off 2FA
16
-
- [x] getSession overwrite to set the `emailAuthFactor` flag if the user has 2FA turned on
17
-
- [x] send an email using the `PDS_EMAIL_SMTP_URL` with a handlebar email template like Bluesky's 2FA sign in email.
18
-
- [ ] generate a 2FA code
19
-
- [ ] createSession gatekeeping (It does stop logins, just eh, doesn't actually send a real code or check it yet)
20
-
- [ ] oauth endpoint gatekeeping
15
+
- Overrides The login endpoint to add 2FA for both Bluesky client logged in and OAuth logins
16
+
- Overrides the settings endpoints as well. As long as you have a confirmed email you can turn on 2FA
21
17
22
18
## Captcha on Create Account
23
19
···
25
21
26
22
# Setup
27
23
24
+
We are getting close! Testing now
25
+
28
26
Nothing here yet! If you are brave enough to try before full release, let me know and I'll help you set it up.
29
27
But I want to run it locally on my own PDS first to test run it a bit.
30
28
···
37
35
path /xrpc/com.atproto.server.getSession
38
36
path /xrpc/com.atproto.server.updateEmail
39
37
path /xrpc/com.atproto.server.createSession
38
+
path /@atproto/oauth-provider/~api/sign-in
40
39
}
41
40
42
41
handle @gatekeeper {
-3
migrations_bells_and_whistles/.keep
-3
migrations_bells_and_whistles/.keep
+524
src/helpers.rs
+524
src/helpers.rs
···
1
+
use crate::AppState;
2
+
use crate::helpers::TokenCheckError::InvalidToken;
3
+
use anyhow::anyhow;
4
+
use axum::body::{Body, to_bytes};
5
+
use axum::extract::Request;
6
+
use axum::http::header::CONTENT_TYPE;
7
+
use axum::http::{HeaderMap, StatusCode, Uri};
8
+
use axum::response::{IntoResponse, Response};
9
+
use axum_template::TemplateEngine;
10
+
use chrono::Utc;
11
+
use lettre::message::{MultiPart, SinglePart, header};
12
+
use lettre::{AsyncTransport, Message};
13
+
use rand::Rng;
14
+
use serde::de::DeserializeOwned;
15
+
use serde_json::{Map, Value};
16
+
use sha2::{Digest, Sha256};
17
+
use sqlx::SqlitePool;
18
+
use tracing::{error, log};
19
+
20
+
///Used to generate the email 2fa code
21
+
const UPPERCASE_BASE32_CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZ234567";
22
+
23
+
/// The result of a proxied call that attempts to parse JSON.
24
+
pub enum ProxiedResult<T> {
25
+
/// Successfully parsed JSON body along with original response headers.
26
+
Parsed { value: T, _headers: HeaderMap },
27
+
/// Could not or should not parse: return the original (or rebuilt) response as-is.
28
+
Passthrough(Response<Body>),
29
+
}
30
+
31
+
/// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse
32
+
/// the successful response body as JSON into `T`.
33
+
///
34
+
pub async fn proxy_get_json<T>(
35
+
state: &AppState,
36
+
mut req: Request,
37
+
path: &str,
38
+
) -> Result<ProxiedResult<T>, StatusCode>
39
+
where
40
+
T: DeserializeOwned,
41
+
{
42
+
let uri = format!("{}{}", state.pds_base_url, path);
43
+
*req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?;
44
+
45
+
let result = state
46
+
.reverse_proxy_client
47
+
.request(req)
48
+
.await
49
+
.map_err(|_| StatusCode::BAD_REQUEST)?
50
+
.into_response();
51
+
52
+
if result.status() != StatusCode::OK {
53
+
return Ok(ProxiedResult::Passthrough(result));
54
+
}
55
+
56
+
let response_headers = result.headers().clone();
57
+
let body = result.into_body();
58
+
let body_bytes = to_bytes(body, usize::MAX)
59
+
.await
60
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
61
+
62
+
match serde_json::from_slice::<T>(&body_bytes) {
63
+
Ok(value) => Ok(ProxiedResult::Parsed {
64
+
value,
65
+
_headers: response_headers,
66
+
}),
67
+
Err(err) => {
68
+
error!(%err, "failed to parse proxied JSON response; returning original body");
69
+
let mut builder = Response::builder().status(StatusCode::OK);
70
+
if let Some(headers) = builder.headers_mut() {
71
+
*headers = response_headers;
72
+
}
73
+
let resp = builder
74
+
.body(Body::from(body_bytes))
75
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
76
+
Ok(ProxiedResult::Passthrough(resp))
77
+
}
78
+
}
79
+
}
80
+
81
+
/// Build a JSON error response with the required Content-Type header
82
+
/// Content-Type: application/json;charset=utf-8
83
+
/// Body shape: { "error": string, "message": string }
84
+
pub fn json_error_response(
85
+
status: StatusCode,
86
+
error: impl Into<String>,
87
+
message: impl Into<String>,
88
+
) -> Result<Response<Body>, StatusCode> {
89
+
let body_str = match serde_json::to_string(&serde_json::json!({
90
+
"error": error.into(),
91
+
"message": message.into(),
92
+
})) {
93
+
Ok(s) => s,
94
+
Err(_) => return Err(StatusCode::BAD_REQUEST),
95
+
};
96
+
97
+
Response::builder()
98
+
.status(status)
99
+
.header(CONTENT_TYPE, "application/json;charset=utf-8")
100
+
.body(Body::from(body_str))
101
+
.map_err(|_| StatusCode::BAD_REQUEST)
102
+
}
103
+
104
+
/// Build a JSON error response with the required Content-Type header
105
+
/// Content-Type: application/json (oauth endpoint does not like utf ending)
106
+
/// Body shape: { "error": string, "error_description": string }
107
+
pub fn oauth_json_error_response(
108
+
status: StatusCode,
109
+
error: impl Into<String>,
110
+
message: impl Into<String>,
111
+
) -> Result<Response<Body>, StatusCode> {
112
+
let body_str = match serde_json::to_string(&serde_json::json!({
113
+
"error": error.into(),
114
+
"error_description": message.into(),
115
+
})) {
116
+
Ok(s) => s,
117
+
Err(_) => return Err(StatusCode::BAD_REQUEST),
118
+
};
119
+
120
+
Response::builder()
121
+
.status(status)
122
+
.header(CONTENT_TYPE, "application/json")
123
+
.body(Body::from(body_str))
124
+
.map_err(|_| StatusCode::BAD_REQUEST)
125
+
}
126
+
127
+
/// Creates a random token of 10 characters for email 2FA
128
+
pub fn get_random_token() -> String {
129
+
let mut rng = rand::rng();
130
+
131
+
let mut full_code = String::with_capacity(10);
132
+
for _ in 0..10 {
133
+
let idx = rng.random_range(0..UPPERCASE_BASE32_CHARS.len());
134
+
full_code.push(UPPERCASE_BASE32_CHARS[idx] as char);
135
+
}
136
+
137
+
//The PDS implementation creates in lowercase, then converts to uppercase.
138
+
//Just going a head and doing uppercase here.
139
+
let slice_one = &full_code[0..5].to_ascii_uppercase();
140
+
let slice_two = &full_code[5..10].to_ascii_uppercase();
141
+
format!("{slice_one}-{slice_two}")
142
+
}
143
+
144
+
pub enum TokenCheckError {
145
+
InvalidToken,
146
+
ExpiredToken,
147
+
}
148
+
149
+
pub enum AuthResult {
150
+
WrongIdentityOrPassword,
151
+
/// The string here is the email address to create a hint for oauth
152
+
TwoFactorRequired(String),
153
+
/// User does not have 2FA enabled, or using an app password, or passes it
154
+
ProxyThrough,
155
+
TokenCheckFailed(TokenCheckError),
156
+
}
157
+
158
+
pub enum IdentifierType {
159
+
Email,
160
+
Did,
161
+
Handle,
162
+
}
163
+
164
+
impl IdentifierType {
165
+
fn what_is_it(identifier: String) -> Self {
166
+
if identifier.contains("@") {
167
+
IdentifierType::Email
168
+
} else if identifier.contains("did:") {
169
+
IdentifierType::Did
170
+
} else {
171
+
IdentifierType::Handle
172
+
}
173
+
}
174
+
}
175
+
176
+
/// Creates a hex string from the password and salt to find app passwords
177
+
fn scrypt_hex(password: &str, salt: &str) -> anyhow::Result<String> {
178
+
let params = scrypt::Params::new(14, 8, 1, 64)?;
179
+
let mut derived = [0u8; 64];
180
+
scrypt::scrypt(password.as_bytes(), salt.as_bytes(), ¶ms, &mut derived)?;
181
+
Ok(hex::encode(derived))
182
+
}
183
+
184
+
/// Hashes the app password. did is used as the salt.
185
+
pub fn hash_app_password(did: &str, password: &str) -> anyhow::Result<String> {
186
+
let mut hasher = Sha256::new();
187
+
hasher.update(did.as_bytes());
188
+
let sha = hasher.finalize();
189
+
let salt = hex::encode(&sha[..16]);
190
+
let hash_hex = scrypt_hex(password, &salt)?;
191
+
Ok(format!("{salt}:{hash_hex}"))
192
+
}
193
+
194
+
async fn verify_password(password: &str, password_scrypt: &str) -> anyhow::Result<bool> {
195
+
// Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes)
196
+
let mut parts = password_scrypt.splitn(2, ':');
197
+
let salt = match parts.next() {
198
+
Some(s) if !s.is_empty() => s,
199
+
_ => return Ok(false),
200
+
};
201
+
let stored_hash_hex = match parts.next() {
202
+
Some(h) if !h.is_empty() => h,
203
+
_ => return Ok(false),
204
+
};
205
+
206
+
// Derive using the shared helper and compare
207
+
let derived_hex = match scrypt_hex(password, salt) {
208
+
Ok(h) => h,
209
+
Err(_) => return Ok(false),
210
+
};
211
+
212
+
Ok(derived_hex.as_str() == stored_hash_hex)
213
+
}
214
+
215
+
/// Handles the auth checks along with sending a 2fa email
216
+
pub async fn preauth_check(
217
+
state: &AppState,
218
+
identifier: &str,
219
+
password: &str,
220
+
two_factor_code: Option<String>,
221
+
oauth: bool,
222
+
) -> anyhow::Result<AuthResult> {
223
+
// Determine identifier type
224
+
let id_type = IdentifierType::what_is_it(identifier.to_string());
225
+
226
+
// Query account DB for did and passwordScrypt based on identifier type
227
+
let account_row: Option<(String, String, String, String)> = match id_type {
228
+
IdentifierType::Email => {
229
+
sqlx::query_as::<_, (String, String, String, String)>(
230
+
"SELECT account.did, account.passwordScrypt, account.email, actor.handle
231
+
FROM actor
232
+
LEFT JOIN account ON actor.did = account.did
233
+
where account.email = ? LIMIT 1",
234
+
)
235
+
.bind(identifier)
236
+
.fetch_optional(&state.account_pool)
237
+
.await?
238
+
}
239
+
IdentifierType::Handle => {
240
+
sqlx::query_as::<_, (String, String, String, String)>(
241
+
"SELECT account.did, account.passwordScrypt, account.email, actor.handle
242
+
FROM actor
243
+
LEFT JOIN account ON actor.did = account.did
244
+
where actor.handle = ? LIMIT 1",
245
+
)
246
+
.bind(identifier)
247
+
.fetch_optional(&state.account_pool)
248
+
.await?
249
+
}
250
+
IdentifierType::Did => {
251
+
sqlx::query_as::<_, (String, String, String, String)>(
252
+
"SELECT account.did, account.passwordScrypt, account.email, actor.handle
253
+
FROM actor
254
+
LEFT JOIN account ON actor.did = account.did
255
+
where account.did = ? LIMIT 1",
256
+
)
257
+
.bind(identifier)
258
+
.fetch_optional(&state.account_pool)
259
+
.await?
260
+
}
261
+
};
262
+
263
+
if let Some((did, password_scrypt, email, handle)) = account_row {
264
+
// Verify password before proceeding to 2FA email step
265
+
let verified = verify_password(password, &password_scrypt).await?;
266
+
if !verified {
267
+
if oauth {
268
+
//OAuth does not allow app password logins so just go ahead and send it along it's way
269
+
return Ok(AuthResult::WrongIdentityOrPassword);
270
+
}
271
+
//Theres a chance it could be an app password so check that as well
272
+
return match verify_app_password(&state.account_pool, &did, password).await {
273
+
Ok(valid) => {
274
+
if valid {
275
+
//Was a valid app password up to the PDS now
276
+
Ok(AuthResult::ProxyThrough)
277
+
} else {
278
+
Ok(AuthResult::WrongIdentityOrPassword)
279
+
}
280
+
}
281
+
Err(err) => {
282
+
log::error!("Error checking the app password: {err}");
283
+
Err(err)
284
+
}
285
+
};
286
+
}
287
+
288
+
// Check two-factor requirement for this DID in the gatekeeper DB
289
+
let required_opt = sqlx::query_as::<_, (u8,)>(
290
+
"SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
291
+
)
292
+
.bind(did.clone())
293
+
.fetch_optional(&state.pds_gatekeeper_pool)
294
+
.await?;
295
+
296
+
let two_factor_required = match required_opt {
297
+
Some(row) => row.0 != 0,
298
+
None => false,
299
+
};
300
+
301
+
if two_factor_required {
302
+
//Two factor is required and a taken was provided
303
+
if let Some(two_factor_code) = two_factor_code {
304
+
//if the two_factor_code is set need to see if we have a valid token
305
+
if !two_factor_code.is_empty() {
306
+
return match assert_valid_token(
307
+
&state.account_pool,
308
+
did.clone(),
309
+
two_factor_code,
310
+
)
311
+
.await
312
+
{
313
+
Ok(_) => {
314
+
let result_of_cleanup =
315
+
delete_all_email_tokens(&state.account_pool, did.clone()).await;
316
+
if result_of_cleanup.is_err() {
317
+
log::error!(
318
+
"There was an error deleting the email tokens after login: {:?}",
319
+
result_of_cleanup.err()
320
+
)
321
+
}
322
+
Ok(AuthResult::ProxyThrough)
323
+
}
324
+
Err(err) => Ok(AuthResult::TokenCheckFailed(err)),
325
+
};
326
+
}
327
+
}
328
+
329
+
return match create_two_factor_token(&state.account_pool, did).await {
330
+
Ok(code) => {
331
+
let mut email_data = Map::new();
332
+
email_data.insert("token".to_string(), Value::from(code.clone()));
333
+
email_data.insert("handle".to_string(), Value::from(handle.clone()));
334
+
let email_body = state
335
+
.template_engine
336
+
.render("two_factor_code.hbs", email_data)?;
337
+
338
+
let email_message = Message::builder()
339
+
//TODO prob get the proper type in the state
340
+
.from(state.mailer_from.parse()?)
341
+
.to(email.parse()?)
342
+
.subject("Sign in to Bluesky")
343
+
.multipart(
344
+
MultiPart::alternative() // This is composed of two parts.
345
+
.singlepart(
346
+
SinglePart::builder()
347
+
.header(header::ContentType::TEXT_PLAIN)
348
+
.body(format!("We received a sign-in request for the account @{handle}. Use the code: {code} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.")), // Every message should have a plain text fallback.
349
+
)
350
+
.singlepart(
351
+
SinglePart::builder()
352
+
.header(header::ContentType::TEXT_HTML)
353
+
.body(email_body),
354
+
),
355
+
)?;
356
+
match state.mailer.send(email_message).await {
357
+
Ok(_) => Ok(AuthResult::TwoFactorRequired(mask_email(email))),
358
+
Err(err) => {
359
+
log::error!("Error sending the 2FA email: {err}");
360
+
Err(anyhow!(err))
361
+
}
362
+
}
363
+
}
364
+
Err(err) => {
365
+
log::error!("error on creating a 2fa token: {err}");
366
+
Err(anyhow!(err))
367
+
}
368
+
};
369
+
}
370
+
}
371
+
372
+
// No local 2FA requirement (or account not found)
373
+
Ok(AuthResult::ProxyThrough)
374
+
}
375
+
376
+
pub async fn create_two_factor_token(
377
+
account_db: &SqlitePool,
378
+
did: String,
379
+
) -> anyhow::Result<String> {
380
+
let purpose = "2fa_code";
381
+
382
+
let token = get_random_token();
383
+
let right_now = Utc::now();
384
+
385
+
let res = sqlx::query(
386
+
"INSERT INTO email_token (purpose, did, token, requestedAt)
387
+
VALUES (?, ?, ?, ?)
388
+
ON CONFLICT(purpose, did) DO UPDATE SET
389
+
token=excluded.token,
390
+
requestedAt=excluded.requestedAt
391
+
WHERE did=excluded.did",
392
+
)
393
+
.bind(purpose)
394
+
.bind(&did)
395
+
.bind(&token)
396
+
.bind(right_now)
397
+
.execute(account_db)
398
+
.await;
399
+
400
+
match res {
401
+
Ok(_) => Ok(token),
402
+
Err(err) => {
403
+
log::error!("Error creating a two factor token: {err}");
404
+
Err(anyhow::anyhow!(err))
405
+
}
406
+
}
407
+
}
408
+
409
+
pub async fn delete_all_email_tokens(account_db: &SqlitePool, did: String) -> anyhow::Result<()> {
410
+
sqlx::query("DELETE FROM email_token WHERE did = ?")
411
+
.bind(did)
412
+
.execute(account_db)
413
+
.await?;
414
+
Ok(())
415
+
}
416
+
417
+
pub async fn assert_valid_token(
418
+
account_db: &SqlitePool,
419
+
did: String,
420
+
token: String,
421
+
) -> Result<(), TokenCheckError> {
422
+
let token_upper = token.to_ascii_uppercase();
423
+
let purpose = "2fa_code";
424
+
425
+
let row: Option<(String,)> = sqlx::query_as(
426
+
"SELECT requestedAt FROM email_token WHERE purpose = ? AND did = ? AND token = ? LIMIT 1",
427
+
)
428
+
.bind(purpose)
429
+
.bind(did)
430
+
.bind(token_upper)
431
+
.fetch_optional(account_db)
432
+
.await
433
+
.map_err(|err| {
434
+
log::error!("Error getting the 2fa token: {err}");
435
+
InvalidToken
436
+
})?;
437
+
438
+
match row {
439
+
None => Err(InvalidToken),
440
+
Some(row) => {
441
+
// Token lives for 15 minutes
442
+
let expiration_ms = 15 * 60_000;
443
+
444
+
let requested_at_utc = match chrono::DateTime::parse_from_rfc3339(&row.0) {
445
+
Ok(dt) => dt.with_timezone(&Utc),
446
+
Err(_) => {
447
+
return Err(TokenCheckError::InvalidToken);
448
+
}
449
+
};
450
+
451
+
let now = Utc::now();
452
+
let age_ms = (now - requested_at_utc).num_milliseconds();
453
+
let expired = age_ms > expiration_ms;
454
+
if expired {
455
+
return Err(TokenCheckError::ExpiredToken);
456
+
}
457
+
458
+
Ok(())
459
+
}
460
+
}
461
+
}
462
+
463
+
/// We just need to confirm if it's there or not. Will let the PDS do the actual figuring of permissions
464
+
pub async fn verify_app_password(
465
+
account_db: &SqlitePool,
466
+
did: &str,
467
+
password: &str,
468
+
) -> anyhow::Result<bool> {
469
+
let password_scrypt = hash_app_password(did, password)?;
470
+
471
+
let row: Option<(i64,)> = sqlx::query_as(
472
+
"SELECT Count(*) FROM app_password WHERE did = ? AND passwordScrypt = ? LIMIT 1",
473
+
)
474
+
.bind(did)
475
+
.bind(password_scrypt)
476
+
.fetch_optional(account_db)
477
+
.await?;
478
+
479
+
Ok(match row {
480
+
None => false,
481
+
Some((count,)) => count > 0,
482
+
})
483
+
}
484
+
485
+
/// Mask an email address into a hint like "2***0@p***m".
486
+
pub fn mask_email(email: String) -> String {
487
+
// Basic split on first '@'
488
+
let mut parts = email.splitn(2, '@');
489
+
let local = match parts.next() {
490
+
Some(l) => l,
491
+
None => return email.to_string(),
492
+
};
493
+
let domain_rest = match parts.next() {
494
+
Some(d) if !d.is_empty() => d,
495
+
_ => return email.to_string(),
496
+
};
497
+
498
+
// Helper to mask a single label (keep first and last, middle becomes ***).
499
+
fn mask_label(s: &str) -> String {
500
+
let chars: Vec<char> = s.chars().collect();
501
+
match chars.len() {
502
+
0 => String::new(),
503
+
1 => format!("{}***", chars[0]),
504
+
2 => format!("{}***{}", chars[0], chars[1]),
505
+
_ => format!("{}***{}", chars[0], chars[chars.len() - 1]),
506
+
}
507
+
}
508
+
509
+
// Mask local
510
+
let masked_local = mask_label(local);
511
+
512
+
// Mask first domain label only, keep the rest of the domain intact
513
+
let mut dom_parts = domain_rest.splitn(2, '.');
514
+
let first_label = dom_parts.next().unwrap_or("");
515
+
let rest = dom_parts.next();
516
+
let masked_first = mask_label(first_label);
517
+
let masked_domain = if let Some(rest) = rest {
518
+
format!("{}.{rest}", masked_first)
519
+
} else {
520
+
masked_first
521
+
};
522
+
523
+
format!("{masked_local}@{masked_domain}")
524
+
}
+53
-26
src/main.rs
+53
-26
src/main.rs
···
1
+
#![warn(clippy::unwrap_used)]
2
+
use crate::oauth_provider::sign_in;
1
3
use crate::xrpc::com_atproto_server::{create_session, get_session, update_email};
2
-
use axum::middleware as ax_middleware;
3
-
mod middleware;
4
4
use axum::body::Body;
5
5
use axum::handler::Handler;
6
6
use axum::http::{Method, header};
7
+
use axum::middleware as ax_middleware;
7
8
use axum::routing::post;
8
9
use axum::{Router, routing::get};
9
10
use axum_template::engine::Engine;
···
21
22
use tower_governor::governor::GovernorConfigBuilder;
22
23
use tower_http::compression::CompressionLayer;
23
24
use tower_http::cors::{Any, CorsLayer};
24
-
use tracing::{error, log};
25
+
use tracing::log;
25
26
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
26
27
28
+
pub mod helpers;
29
+
mod middleware;
30
+
mod oauth_provider;
27
31
mod xrpc;
28
32
29
33
type HyperUtilClient = hyper_util::client::legacy::Client<HttpConnector, Body>;
···
34
38
struct EmailTemplates;
35
39
36
40
#[derive(Clone)]
37
-
struct AppState {
41
+
pub struct AppState {
38
42
account_pool: SqlitePool,
39
43
pds_gatekeeper_pool: SqlitePool,
40
44
reverse_proxy_client: HyperUtilClient,
···
73
77
74
78
let intro = "\n\nThis is a PDS gatekeeper\n\nCode: https://tangled.sh/@baileytownsend.dev/pds-gatekeeper\n";
75
79
76
-
let banner = format!(" {}\n{}", body, intro);
80
+
let banner = format!(" {body}\n{intro}");
77
81
78
82
(
79
83
[(header::CONTENT_TYPE, "text/plain; charset=utf-8")],
···
84
88
#[tokio::main]
85
89
async fn main() -> Result<(), Box<dyn std::error::Error>> {
86
90
setup_tracing();
87
-
//TODO prod
91
+
//TODO may need to change where this reads from? Like an env variable for it's location? Or arg?
88
92
dotenvy::from_path(Path::new("./pds.env"))?;
89
93
let pds_root = env::var("PDS_DATA_DIRECTORY")?;
90
-
// let pds_root = "/home/baileytownsend/Documents/code/docker_compose/pds/pds_data";
91
-
let account_db_url = format!("{}/account.sqlite", pds_root);
92
-
log::info!("accounts_db_url: {}", account_db_url);
94
+
let account_db_url = format!("{pds_root}/account.sqlite");
93
95
94
96
let account_options = SqliteConnectOptions::new()
95
-
.journal_mode(SqliteJournalMode::Wal)
96
-
.filename(account_db_url);
97
+
.filename(account_db_url)
98
+
.busy_timeout(Duration::from_secs(5));
97
99
98
100
let account_pool = SqlitePoolOptions::new()
99
101
.max_connections(5)
100
102
.connect_with(account_options)
101
103
.await?;
102
104
103
-
let bells_db_url = format!("{}/pds_gatekeeper.sqlite", pds_root);
105
+
let bells_db_url = format!("{pds_root}/pds_gatekeeper.sqlite");
104
106
let options = SqliteConnectOptions::new()
105
107
.journal_mode(SqliteJournalMode::Wal)
106
108
.filename(bells_db_url)
107
-
.create_if_missing(true);
109
+
.create_if_missing(true)
110
+
.busy_timeout(Duration::from_secs(5));
108
111
let pds_gatekeeper_pool = SqlitePoolOptions::new()
109
112
.max_connections(5)
110
113
.connect_with(options)
111
114
.await?;
112
115
113
-
// Run migrations for the bells_and_whistles database
116
+
// Run migrations for the extra database
114
117
// Note: the migrations are embedded at compile time from the given directory
115
118
// sqlx
116
119
sqlx::migrate!("./migrations")
···
130
133
AsyncSmtpTransport::<Tokio1Executor>::from_url(smtp_url.as_str())?.build();
131
134
//Email templates setup
132
135
let mut hbs = Handlebars::new();
133
-
let _ = hbs.register_embed_templates::<EmailTemplates>();
136
+
137
+
let users_email_directory = env::var("GATEKEEPER_EMAIL_TEMPLATES_DIRECTORY");
138
+
if let Ok(users_email_directory) = users_email_directory {
139
+
hbs.register_template_file(
140
+
"two_factor_code.hbs",
141
+
format!("{users_email_directory}/two_factor_code.hbs"),
142
+
)?;
143
+
} else {
144
+
let _ = hbs.register_embed_templates::<EmailTemplates>();
145
+
}
146
+
147
+
let pds_base_url =
148
+
env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string());
134
149
135
150
let state = AppState {
136
151
account_pool,
137
152
pds_gatekeeper_pool,
138
153
reverse_proxy_client: client,
139
-
//TODO should be env prob
140
-
pds_base_url: "http://localhost:3000".to_string(),
154
+
pds_base_url,
141
155
mailer,
142
156
mailer_from: sent_from,
143
157
template_engine: Engine::from(hbs),
···
145
159
146
160
// Rate limiting
147
161
//Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
148
-
let governor_conf = GovernorConfigBuilder::default()
162
+
let create_session_governor_conf = GovernorConfigBuilder::default()
149
163
.per_second(60)
150
164
.burst_size(5)
151
165
.finish()
152
-
.unwrap();
153
-
let governor_limiter = governor_conf.limiter().clone();
166
+
.expect("failed to create governor config. this should not happen and is a bug");
167
+
168
+
// Create a second config with the same settings for the other endpoint
169
+
let sign_in_governor_conf = GovernorConfigBuilder::default()
170
+
.per_second(60)
171
+
.burst_size(5)
172
+
.finish()
173
+
.expect("failed to create governor config. this should not happen and is a bug");
174
+
175
+
let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
176
+
let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
154
177
let interval = Duration::from_secs(60);
155
178
// a separate background task to clean up
156
179
std::thread::spawn(move || {
157
180
loop {
158
181
std::thread::sleep(interval);
159
-
tracing::info!("rate limiting storage size: {}", governor_limiter.len());
160
-
governor_limiter.retain_recent();
182
+
create_session_governor_limiter.retain_recent();
183
+
sign_in_governor_limiter.retain_recent();
161
184
}
162
185
});
163
186
···
176
199
"/xrpc/com.atproto.server.updateEmail",
177
200
post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
178
201
)
202
+
.route(
203
+
"/@atproto/oauth-provider/~api/sign-in",
204
+
post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)),
205
+
)
179
206
.route(
180
207
"/xrpc/com.atproto.server.createSession",
181
-
post(create_session.layer(GovernorLayer::new(governor_conf))),
208
+
post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
182
209
)
183
210
.layer(CompressionLayer::new())
184
211
.layer(cors)
185
212
.with_state(state);
186
213
187
-
let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
188
-
let port: u16 = env::var("PORT")
214
+
let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
215
+
let port: u16 = env::var("GATEKEEPER_PORT")
189
216
.ok()
190
217
.and_then(|s| s.parse().ok())
191
218
.unwrap_or(8080);
···
202
229
.with_graceful_shutdown(shutdown_signal());
203
230
204
231
if let Err(err) = server.await {
205
-
error!(error = %err, "server error");
232
+
log::error!("server error:{err}");
206
233
}
207
234
208
235
Ok(())
+19
-34
src/middleware.rs
+19
-34
src/middleware.rs
···
1
-
use crate::xrpc::helpers::json_error_response;
1
+
use crate::helpers::json_error_response;
2
2
use axum::extract::Request;
3
3
use axum::http::{HeaderMap, StatusCode};
4
4
use axum::middleware::Next;
···
7
7
use jwt_compact::{AlgorithmExt, Claims, Token, UntrustedToken, ValidationError};
8
8
use serde::{Deserialize, Serialize};
9
9
use std::env;
10
+
use tracing::log;
10
11
11
12
#[derive(Clone, Debug)]
12
13
pub struct Did(pub Option<String>);
···
22
23
match token {
23
24
Ok(token) => {
24
25
match token {
25
-
None => {
26
-
return json_error_response(
27
-
StatusCode::BAD_REQUEST,
28
-
"TokenRequired",
29
-
"",
30
-
).unwrap();
31
-
}
26
+
None => json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
27
+
.expect("Error creating an error response"),
32
28
Some(token) => {
33
29
let token = UntrustedToken::new(&token);
34
-
//Doing weird unwraps cause I can't do Result for middleware?
35
30
if token.is_err() {
36
-
return json_error_response(
37
-
StatusCode::BAD_REQUEST,
38
-
"TokenRequired",
39
-
"",
40
-
).unwrap();
31
+
return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
32
+
.expect("Error creating an error response");
41
33
}
42
-
let parsed_token = token.unwrap();
34
+
let parsed_token = token.expect("Already checked for error");
43
35
let claims: Result<Claims<TokenClaims>, ValidationError> =
44
36
parsed_token.deserialize_claims_unchecked();
45
37
if claims.is_err() {
46
-
return json_error_response(
47
-
StatusCode::BAD_REQUEST,
48
-
"TokenRequired",
49
-
"",
50
-
).unwrap();
38
+
return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
39
+
.expect("Error creating an error response");
51
40
}
52
41
53
-
let key = Hs256Key::new(env::var("PDS_JWT_SECRET").unwrap());
42
+
let key = Hs256Key::new(
43
+
env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"),
44
+
);
54
45
let token: Result<Token<TokenClaims>, ValidationError> =
55
46
Hs256.validator(&key).validate(&parsed_token);
56
47
if token.is_err() {
57
-
return json_error_response(
58
-
StatusCode::BAD_REQUEST,
59
-
"InvalidToken",
60
-
"",
61
-
).unwrap();
48
+
return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
49
+
.expect("Error creating an error response");
62
50
}
63
-
let token = token.unwrap();
51
+
let token = token.expect("Already checked for error,");
64
52
//Not going to worry about expiration since it still goes to the PDS
65
-
66
53
req.extensions_mut()
67
54
.insert(Did(Some(token.claims().custom.sub.clone())));
68
55
next.run(req).await
69
56
}
70
57
}
71
58
}
72
-
Err(_) => {
73
-
return json_error_response(
74
-
StatusCode::BAD_REQUEST,
75
-
"InvalidToken",
76
-
"",
77
-
).unwrap();
59
+
Err(err) => {
60
+
log::error!("Error extracting token: {err}");
61
+
json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
62
+
.expect("Error creating an error response")
78
63
}
79
64
}
80
65
}
+141
src/oauth_provider.rs
+141
src/oauth_provider.rs
···
1
+
use crate::AppState;
2
+
use crate::helpers::{AuthResult, oauth_json_error_response, preauth_check};
3
+
use axum::body::Body;
4
+
use axum::extract::State;
5
+
use axum::http::header::CONTENT_TYPE;
6
+
use axum::http::{HeaderMap, HeaderName, HeaderValue, StatusCode};
7
+
use axum::response::{IntoResponse, Response};
8
+
use axum::{Json, extract};
9
+
use serde::{Deserialize, Serialize};
10
+
use tracing::log;
11
+
12
+
#[derive(Serialize, Deserialize, Clone)]
13
+
pub struct SignInRequest {
14
+
pub username: String,
15
+
pub password: String,
16
+
pub remember: bool,
17
+
pub locale: String,
18
+
#[serde(skip_serializing_if = "Option::is_none", rename = "emailOtp")]
19
+
pub email_otp: Option<String>,
20
+
}
21
+
22
+
pub async fn sign_in(
23
+
State(state): State<AppState>,
24
+
headers: HeaderMap,
25
+
Json(mut payload): extract::Json<SignInRequest>,
26
+
) -> Result<Response<Body>, StatusCode> {
27
+
let identifier = payload.username.clone();
28
+
let password = payload.password.clone();
29
+
let auth_factor_token = payload.email_otp.clone();
30
+
31
+
match preauth_check(&state, &identifier, &password, auth_factor_token, true).await {
32
+
Ok(result) => match result {
33
+
AuthResult::WrongIdentityOrPassword => oauth_json_error_response(
34
+
StatusCode::BAD_REQUEST,
35
+
"invalid_request",
36
+
"Invalid identifier or password",
37
+
),
38
+
AuthResult::TwoFactorRequired(masked_email) => {
39
+
// Email sending step can be handled here if needed in the future.
40
+
41
+
// {"error":"second_authentication_factor_required","error_description":"emailOtp authentication factor required (hint: 2***0@p***m)","type":"emailOtp","hint":"2***0@p***m"}
42
+
let body_str = match serde_json::to_string(&serde_json::json!({
43
+
"error": "second_authentication_factor_required",
44
+
"error_description": format!("emailOtp authentication factor required (hint: {})", masked_email),
45
+
"type": "emailOtp",
46
+
"hint": masked_email,
47
+
})) {
48
+
Ok(s) => s,
49
+
Err(_) => return Err(StatusCode::BAD_REQUEST),
50
+
};
51
+
52
+
Response::builder()
53
+
.status(StatusCode::BAD_REQUEST)
54
+
.header(CONTENT_TYPE, "application/json")
55
+
.body(Body::from(body_str))
56
+
.map_err(|_| StatusCode::BAD_REQUEST)
57
+
}
58
+
AuthResult::ProxyThrough => {
59
+
//No 2FA or already passed
60
+
let uri = format!(
61
+
"{}{}",
62
+
state.pds_base_url, "/@atproto/oauth-provider/~api/sign-in"
63
+
);
64
+
65
+
let mut req = axum::http::Request::post(uri);
66
+
if let Some(req_headers) = req.headers_mut() {
67
+
// Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers
68
+
copy_filtered_headers(&headers, req_headers);
69
+
//Setting the content type to application/json manually
70
+
req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
71
+
}
72
+
73
+
//Clears the email_otp because the pds will reject a request with it.
74
+
payload.email_otp = None;
75
+
let payload_bytes =
76
+
serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
77
+
78
+
let req = req
79
+
.body(Body::from(payload_bytes))
80
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
81
+
82
+
let proxied = state
83
+
.reverse_proxy_client
84
+
.request(req)
85
+
.await
86
+
.map_err(|_| StatusCode::BAD_REQUEST)?
87
+
.into_response();
88
+
89
+
Ok(proxied)
90
+
}
91
+
//Ignoring the type of token check failure. Looks like oauth on the entry treads them the same.
92
+
AuthResult::TokenCheckFailed(_) => oauth_json_error_response(
93
+
StatusCode::BAD_REQUEST,
94
+
"invalid_request",
95
+
"Unable to sign-in due to an unexpected server error",
96
+
),
97
+
},
98
+
Err(err) => {
99
+
log::error!(
100
+
"Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}"
101
+
);
102
+
oauth_json_error_response(
103
+
StatusCode::BAD_REQUEST,
104
+
"pds_gatekeeper_error",
105
+
"This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
106
+
)
107
+
}
108
+
}
109
+
}
110
+
111
+
fn is_disallowed_header(name: &HeaderName) -> bool {
112
+
// possible problematic headers with proxying
113
+
matches!(
114
+
name.as_str(),
115
+
"connection"
116
+
| "keep-alive"
117
+
| "proxy-authenticate"
118
+
| "proxy-authorization"
119
+
| "te"
120
+
| "trailer"
121
+
| "transfer-encoding"
122
+
| "upgrade"
123
+
| "host"
124
+
| "content-length"
125
+
| "content-encoding"
126
+
| "expect"
127
+
| "accept-encoding"
128
+
)
129
+
}
130
+
131
+
fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) {
132
+
for (name, value) in src.iter() {
133
+
if is_disallowed_header(name) {
134
+
continue;
135
+
}
136
+
// Only copy valid headers
137
+
if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) {
138
+
dst.insert(name.clone(), hv);
139
+
}
140
+
}
141
+
}
+66
-211
src/xrpc/com_atproto_server.rs
+66
-211
src/xrpc/com_atproto_server.rs
···
1
1
use crate::AppState;
2
+
use crate::helpers::{
3
+
AuthResult, ProxiedResult, TokenCheckError, json_error_response, preauth_check, proxy_get_json,
4
+
};
2
5
use crate::middleware::Did;
3
-
use crate::xrpc::helpers::{ProxiedResult, json_error_response, proxy_get_json};
4
6
use axum::body::Body;
5
7
use axum::extract::State;
6
8
use axum::http::{HeaderMap, StatusCode};
7
9
use axum::response::{IntoResponse, Response};
8
10
use axum::{Extension, Json, debug_handler, extract, extract::Request};
9
-
use axum_template::TemplateEngine;
10
-
use lettre::message::{MultiPart, SinglePart, header};
11
-
use lettre::{AsyncTransport, Message};
12
11
use serde::{Deserialize, Serialize};
13
12
use serde_json;
14
-
use serde_json::Value;
15
-
use serde_json::value::Map;
16
13
use tracing::log;
17
14
18
15
#[derive(Serialize, Deserialize, Debug, Clone)]
···
58
55
pub struct CreateSessionRequest {
59
56
identifier: String,
60
57
password: String,
61
-
auth_factor_token: String,
62
-
allow_takendown: bool,
63
-
}
64
-
65
-
pub enum AuthResult {
66
-
WrongIdentityOrPassword,
67
-
TwoFactorRequired,
68
-
TwoFactorFailed,
69
-
/// User does not have 2FA enabled, or passes it
70
-
ProxyThrough,
71
-
}
72
-
73
-
pub enum IdentifierType {
74
-
Email,
75
-
DID,
76
-
Handle,
77
-
}
78
-
79
-
impl IdentifierType {
80
-
fn what_is_it(identifier: String) -> Self {
81
-
if identifier.contains("@") {
82
-
IdentifierType::Email
83
-
} else if identifier.contains("did:") {
84
-
IdentifierType::DID
85
-
} else {
86
-
IdentifierType::Handle
87
-
}
88
-
}
89
-
}
90
-
91
-
async fn verify_password(password: &str, password_scrypt: &str) -> Result<bool, StatusCode> {
92
-
// Expected format: "salt:hash" where hash is hex of scrypt(password, salt, 64 bytes)
93
-
let mut parts = password_scrypt.splitn(2, ':');
94
-
let salt = match parts.next() {
95
-
Some(s) if !s.is_empty() => s,
96
-
_ => return Ok(false),
97
-
};
98
-
let stored_hash_hex = match parts.next() {
99
-
Some(h) if !h.is_empty() => h,
100
-
_ => return Ok(false),
101
-
};
102
-
103
-
//Sets up scrypt to mimic node's scrypt
104
-
let params = match scrypt::Params::new(14, 8, 1, 64) {
105
-
Ok(p) => p,
106
-
Err(_) => return Ok(false),
107
-
};
108
-
let mut derived = [0u8; 64];
109
-
if scrypt::scrypt(password.as_bytes(), salt.as_bytes(), ¶ms, &mut derived).is_err() {
110
-
return Ok(false);
111
-
}
112
-
113
-
let stored_bytes = match hex::decode(stored_hash_hex) {
114
-
Ok(b) => b,
115
-
Err(e) => {
116
-
log::error!("Error decoding stored hash: {}", e);
117
-
return Ok(false);
118
-
}
119
-
};
120
-
121
-
Ok(derived.as_slice() == stored_bytes.as_slice())
122
-
}
123
-
124
-
async fn preauth_check(
125
-
state: &AppState,
126
-
identifier: &str,
127
-
password: &str,
128
-
) -> Result<AuthResult, StatusCode> {
129
-
// Determine identifier type
130
-
let id_type = IdentifierType::what_is_it(identifier.to_string());
131
-
132
-
// Query account DB for did and passwordScrypt based on identifier type
133
-
let account_row: Option<(String, String, String)> = match id_type {
134
-
IdentifierType::Email => sqlx::query_as::<_, (String, String, String)>(
135
-
"SELECT did, passwordScrypt, account.email FROM account WHERE email = ? LIMIT 1",
136
-
)
137
-
.bind(identifier)
138
-
.fetch_optional(&state.account_pool)
139
-
.await
140
-
.map_err(|_| StatusCode::BAD_REQUEST)?,
141
-
IdentifierType::Handle => sqlx::query_as::<_, (String, String, String)>(
142
-
"SELECT account.did, account.passwordScrypt, account.email
143
-
FROM actor
144
-
LEFT JOIN account ON actor.did = account.did
145
-
where actor.handle =? LIMIT 1",
146
-
)
147
-
.bind(identifier)
148
-
.fetch_optional(&state.account_pool)
149
-
.await
150
-
.map_err(|_| StatusCode::BAD_REQUEST)?,
151
-
IdentifierType::DID => sqlx::query_as::<_, (String, String, String)>(
152
-
"SELECT did, passwordScrypt, account.email FROM account WHERE did = ? LIMIT 1",
153
-
)
154
-
.bind(identifier)
155
-
.fetch_optional(&state.account_pool)
156
-
.await
157
-
.map_err(|_| StatusCode::BAD_REQUEST)?,
158
-
};
159
-
160
-
if let Some((did, password_scrypt, email)) = account_row {
161
-
// Check two-factor requirement for this DID in the gatekeeper DB
162
-
let required_opt = sqlx::query_as::<_, (u8,)>(
163
-
"SELECT required FROM two_factor_accounts WHERE did = ? LIMIT 1",
164
-
)
165
-
.bind(&did)
166
-
.fetch_optional(&state.pds_gatekeeper_pool)
167
-
.await
168
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
169
-
170
-
let two_factor_required = match required_opt {
171
-
Some(row) => row.0 != 0,
172
-
None => false,
173
-
};
174
-
175
-
if two_factor_required {
176
-
// Verify password before proceeding to 2FA email step
177
-
let verified = verify_password(password, &password_scrypt).await?;
178
-
if !verified {
179
-
return Ok(AuthResult::WrongIdentityOrPassword);
180
-
}
181
-
let mut email_data = Map::new();
182
-
//TODO these need real values
183
-
let token = "test".to_string();
184
-
let handle = "baileytownsend.dev".to_string();
185
-
email_data.insert("token".to_string(), Value::from(token.clone()));
186
-
email_data.insert("handle".to_string(), Value::from(handle.clone()));
187
-
//TODO bad unwrap
188
-
let email_body = state
189
-
.template_engine
190
-
.render("two_factor_code.hbs", email_data)
191
-
.unwrap();
192
-
193
-
let email = Message::builder()
194
-
//TODO prob get the proper type in the state
195
-
.from(state.mailer_from.parse().unwrap())
196
-
.to(email.parse().unwrap())
197
-
.subject("Sign in to Bluesky")
198
-
.multipart(
199
-
MultiPart::alternative() // This is composed of two parts.
200
-
.singlepart(
201
-
SinglePart::builder()
202
-
.header(header::ContentType::TEXT_PLAIN)
203
-
.body(format!("We received a sign-in request for the account @{}. Use the code: {} to sign in. If this wasn't you, we recommend taking steps to protect your account by changing your password at https://bsky.app/settings.", handle, token)), // Every message should have a plain text fallback.
204
-
)
205
-
.singlepart(
206
-
SinglePart::builder()
207
-
.header(header::ContentType::TEXT_HTML)
208
-
.body(email_body),
209
-
),
210
-
)
211
-
//TODO bad
212
-
.unwrap();
213
-
return match state.mailer.send(email).await {
214
-
Ok(_) => Ok(AuthResult::TwoFactorRequired),
215
-
Err(err) => {
216
-
log::error!("Error sending the 2FA email: {}", err);
217
-
Err(StatusCode::BAD_REQUEST)
218
-
}
219
-
};
220
-
}
221
-
}
222
-
223
-
// No local 2FA requirement (or account not found)
224
-
Ok(AuthResult::ProxyThrough)
58
+
#[serde(skip_serializing_if = "Option::is_none")]
59
+
auth_factor_token: Option<String>,
60
+
#[serde(skip_serializing_if = "Option::is_none")]
61
+
allow_takendown: Option<bool>,
225
62
}
226
63
227
64
pub async fn create_session(
···
231
68
) -> Result<Response<Body>, StatusCode> {
232
69
let identifier = payload.identifier.clone();
233
70
let password = payload.password.clone();
71
+
let auth_factor_token = payload.auth_factor_token.clone();
234
72
235
73
// Run the shared pre-auth logic to validate and check 2FA requirement
236
-
match preauth_check(&state, &identifier, &password).await? {
237
-
AuthResult::WrongIdentityOrPassword => json_error_response(
238
-
StatusCode::UNAUTHORIZED,
239
-
"AuthenticationRequired",
240
-
"Invalid identifier or password",
241
-
),
242
-
AuthResult::TwoFactorRequired => {
243
-
// Email sending step can be handled here if needed in the future.
244
-
json_error_response(
74
+
match preauth_check(&state, &identifier, &password, auth_factor_token, false).await {
75
+
Ok(result) => match result {
76
+
AuthResult::WrongIdentityOrPassword => json_error_response(
245
77
StatusCode::UNAUTHORIZED,
246
-
"AuthFactorTokenRequired",
247
-
"A sign in code has been sent to your email address",
248
-
)
249
-
}
250
-
AuthResult::TwoFactorFailed => {
251
-
//Not sure what the errors are for this response is yet
252
-
json_error_response(StatusCode::UNAUTHORIZED, "PLACEHOLDER", "PLACEHOLDER")
253
-
}
254
-
AuthResult::ProxyThrough => {
255
-
//No 2FA or already passed
256
-
let uri = format!(
257
-
"{}{}",
258
-
state.pds_base_url, "/xrpc/com.atproto.server.createSession"
259
-
);
260
-
261
-
let mut req = axum::http::Request::post(uri);
262
-
if let Some(req_headers) = req.headers_mut() {
263
-
req_headers.extend(headers.clone());
78
+
"AuthenticationRequired",
79
+
"Invalid identifier or password",
80
+
),
81
+
AuthResult::TwoFactorRequired(_) => {
82
+
// Email sending step can be handled here if needed in the future.
83
+
json_error_response(
84
+
StatusCode::UNAUTHORIZED,
85
+
"AuthFactorTokenRequired",
86
+
"A sign in code has been sent to your email address",
87
+
)
264
88
}
89
+
AuthResult::ProxyThrough => {
90
+
log::info!("Proxying through");
91
+
//No 2FA or already passed
92
+
let uri = format!(
93
+
"{}{}",
94
+
state.pds_base_url, "/xrpc/com.atproto.server.createSession"
95
+
);
96
+
97
+
let mut req = axum::http::Request::post(uri);
98
+
if let Some(req_headers) = req.headers_mut() {
99
+
req_headers.extend(headers.clone());
100
+
}
265
101
266
-
let payload_bytes =
267
-
serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
268
-
let req = req
269
-
.body(Body::from(payload_bytes))
270
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
102
+
let payload_bytes =
103
+
serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
104
+
let req = req
105
+
.body(Body::from(payload_bytes))
106
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
271
107
272
-
let proxied = state
273
-
.reverse_proxy_client
274
-
.request(req)
275
-
.await
276
-
.map_err(|_| StatusCode::BAD_REQUEST)?
277
-
.into_response();
108
+
let proxied = state
109
+
.reverse_proxy_client
110
+
.request(req)
111
+
.await
112
+
.map_err(|_| StatusCode::BAD_REQUEST)?
113
+
.into_response();
278
114
279
-
Ok(proxied)
115
+
Ok(proxied)
116
+
}
117
+
AuthResult::TokenCheckFailed(err) => match err {
118
+
TokenCheckError::InvalidToken => {
119
+
json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "Token is invalid")
120
+
}
121
+
TokenCheckError::ExpiredToken => {
122
+
json_error_response(StatusCode::BAD_REQUEST, "ExpiredToken", "Token is expired")
123
+
}
124
+
},
125
+
},
126
+
Err(err) => {
127
+
log::error!(
128
+
"Error during pre-auth check. This happens on the create_session endpoint when trying to decide if the user has access:\n {err}"
129
+
);
130
+
json_error_response(
131
+
StatusCode::INTERNAL_SERVER_ERROR,
132
+
"InternalServerError",
133
+
"This error was not generated by the PDS, but PDS Gatekeeper. Please contact your PDS administrator for help and for them to review the server logs.",
134
+
)
280
135
}
281
136
}
282
137
}
···
290
145
) -> Result<Response<Body>, StatusCode> {
291
146
//If email auth is not set at all it is a update email address
292
147
let email_auth_not_set = payload.email_auth_factor.is_none();
293
-
//If email aurth is set it is to either turn on or off 2fa
148
+
//If email auth is set it is to either turn on or off 2fa
294
149
let email_auth_update = payload.email_auth_factor.unwrap_or(false);
295
150
296
151
// Email update asked for
···
350
205
}
351
206
}
352
207
353
-
// Updating the acutal email address
208
+
// Updating the actual email address by sending it on to the PDS
354
209
let uri = format!(
355
210
"{}{}",
356
211
state.pds_base_url, "/xrpc/com.atproto.server.updateEmail"
-150
src/xrpc/helpers.rs
-150
src/xrpc/helpers.rs
···
1
-
use axum::body::{Body, to_bytes};
2
-
use axum::extract::Request;
3
-
use axum::http::{HeaderMap, Method, StatusCode, Uri};
4
-
use axum::http::header::CONTENT_TYPE;
5
-
use axum::response::{IntoResponse, Response};
6
-
use serde::de::DeserializeOwned;
7
-
use tracing::error;
8
-
9
-
use crate::AppState;
10
-
11
-
/// The result of a proxied call that attempts to parse JSON.
12
-
pub enum ProxiedResult<T> {
13
-
/// Successfully parsed JSON body along with original response headers.
14
-
Parsed { value: T, _headers: HeaderMap },
15
-
/// Could not or should not parse: return the original (or rebuilt) response as-is.
16
-
Passthrough(Response<Body>),
17
-
}
18
-
19
-
/// Proxy the incoming request to the PDS base URL plus the provided path and attempt to parse
20
-
/// the successful response body as JSON into `T`.
21
-
///
22
-
/// Behavior:
23
-
/// - If the proxied response is non-200, returns Passthrough with the original response.
24
-
/// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers.
25
-
/// - If parsing succeeds, returns Parsed { value, headers }.
26
-
pub async fn proxy_get_json<T>(
27
-
state: &AppState,
28
-
mut req: Request,
29
-
path: &str,
30
-
) -> Result<ProxiedResult<T>, StatusCode>
31
-
where
32
-
T: DeserializeOwned,
33
-
{
34
-
let uri = format!("{}{}", state.pds_base_url, path);
35
-
*req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?;
36
-
37
-
let result = state
38
-
.reverse_proxy_client
39
-
.request(req)
40
-
.await
41
-
.map_err(|_| StatusCode::BAD_REQUEST)?
42
-
.into_response();
43
-
44
-
if result.status() != StatusCode::OK {
45
-
return Ok(ProxiedResult::Passthrough(result));
46
-
}
47
-
48
-
let response_headers = result.headers().clone();
49
-
let body = result.into_body();
50
-
let body_bytes = to_bytes(body, usize::MAX)
51
-
.await
52
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
53
-
54
-
match serde_json::from_slice::<T>(&body_bytes) {
55
-
Ok(value) => Ok(ProxiedResult::Parsed {
56
-
value,
57
-
_headers: response_headers,
58
-
}),
59
-
Err(err) => {
60
-
error!(%err, "failed to parse proxied JSON response; returning original body");
61
-
let mut builder = Response::builder().status(StatusCode::OK);
62
-
if let Some(headers) = builder.headers_mut() {
63
-
*headers = response_headers;
64
-
}
65
-
let resp = builder
66
-
.body(Body::from(body_bytes))
67
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
68
-
Ok(ProxiedResult::Passthrough(resp))
69
-
}
70
-
}
71
-
}
72
-
73
-
/// Proxy the incoming request as a POST to the PDS base URL plus the provided path and attempt to parse
74
-
/// the successful response body as JSON into `T`.
75
-
///
76
-
/// Behavior mirrors `proxy_get_json`:
77
-
/// - If the proxied response is non-200, returns Passthrough with the original response.
78
-
/// - If the response is 200 but JSON parsing fails, returns Passthrough with the original body and headers.
79
-
/// - If parsing succeeds, returns Parsed { value, headers }.
80
-
pub async fn _proxy_post_json<T>(
81
-
state: &AppState,
82
-
mut req: Request,
83
-
path: &str,
84
-
) -> Result<ProxiedResult<T>, StatusCode>
85
-
where
86
-
T: DeserializeOwned,
87
-
{
88
-
let uri = format!("{}{}", state.pds_base_url, path);
89
-
*req.uri_mut() = Uri::try_from(uri).map_err(|_| StatusCode::BAD_REQUEST)?;
90
-
*req.method_mut() = Method::POST;
91
-
92
-
let result = state
93
-
.reverse_proxy_client
94
-
.request(req)
95
-
.await
96
-
.map_err(|_| StatusCode::BAD_REQUEST)?
97
-
.into_response();
98
-
99
-
if result.status() != StatusCode::OK {
100
-
return Ok(ProxiedResult::Passthrough(result));
101
-
}
102
-
103
-
let response_headers = result.headers().clone();
104
-
let body = result.into_body();
105
-
let body_bytes = to_bytes(body, usize::MAX)
106
-
.await
107
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
108
-
109
-
match serde_json::from_slice::<T>(&body_bytes) {
110
-
Ok(value) => Ok(ProxiedResult::Parsed {
111
-
value,
112
-
_headers: response_headers,
113
-
}),
114
-
Err(err) => {
115
-
error!(%err, "failed to parse proxied JSON response (POST); returning original body");
116
-
let mut builder = Response::builder().status(StatusCode::OK);
117
-
if let Some(headers) = builder.headers_mut() {
118
-
*headers = response_headers;
119
-
}
120
-
let resp = builder
121
-
.body(Body::from(body_bytes))
122
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
123
-
Ok(ProxiedResult::Passthrough(resp))
124
-
}
125
-
}
126
-
}
127
-
128
-
129
-
/// Build a JSON error response with the required Content-Type header
130
-
/// Content-Type: application/json;charset=utf-8
131
-
/// Body shape: { "error": string, "message": string }
132
-
pub fn json_error_response(
133
-
status: StatusCode,
134
-
error: impl Into<String>,
135
-
message: impl Into<String>,
136
-
) -> Result<Response<Body>, StatusCode> {
137
-
let body_str = match serde_json::to_string(&serde_json::json!({
138
-
"error": error.into(),
139
-
"message": message.into(),
140
-
})) {
141
-
Ok(s) => s,
142
-
Err(_) => return Err(StatusCode::BAD_REQUEST),
143
-
};
144
-
145
-
Response::builder()
146
-
.status(status)
147
-
.header(CONTENT_TYPE, "application/json;charset=utf-8")
148
-
.body(Body::from(body_str))
149
-
.map_err(|_| StatusCode::BAD_REQUEST)
150
-
}