+3
-3
Cargo.lock
+3
-3
Cargo.lock
···
656
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
657
dependencies = [
658
"libc",
659
-
"windows-sys 0.52.0",
660
]
661
662
[[package]]
···
1392
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
1393
dependencies = [
1394
"cfg-if",
1395
-
"windows-targets 0.48.5",
1396
]
1397
1398
[[package]]
···
2136
"errno",
2137
"libc",
2138
"linux-raw-sys",
2139
-
"windows-sys 0.52.0",
2140
]
2141
2142
[[package]]
···
656
checksum = "778e2ac28f6c47af28e4907f13ffd1e1ddbd400980a9abd7c8df189bf578a5ad"
657
dependencies = [
658
"libc",
659
+
"windows-sys 0.59.0",
660
]
661
662
[[package]]
···
1392
checksum = "07033963ba89ebaf1584d767badaa2e8fcec21aedea6b8c0346d487d49c28667"
1393
dependencies = [
1394
"cfg-if",
1395
+
"windows-targets 0.52.6",
1396
]
1397
1398
[[package]]
···
2136
"errno",
2137
"libc",
2138
"linux-raw-sys",
2139
+
"windows-sys 0.59.0",
2140
]
2141
2142
[[package]]
+40
-9
src/main.rs
+40
-9
src/main.rs
···
19
use std::time::Duration;
20
use std::{env, net::SocketAddr};
21
use tower_governor::GovernorLayer;
22
-
use tower_governor::governor::GovernorConfigBuilder;
23
use tower_http::compression::CompressionLayer;
24
use tower_http::cors::{Any, CorsLayer};
25
use tracing::log;
···
166
.per_second(60)
167
.burst_size(5)
168
.finish()
169
-
.expect("failed to create governor config. this should not happen and is a bug");
170
171
// Create a second config with the same settings for the other endpoint
172
let sign_in_governor_conf = GovernorConfigBuilder::default()
173
.per_second(60)
174
.burst_size(5)
175
.finish()
176
-
.expect("failed to create governor config. this should not happen and is a bug");
177
178
-
// let create_account_limiter_time: Option<String> =
179
-
// env::var("GATEKEEPER_CREATE_ACCOUNT_LIMITER_WINDOW").unwrap_or_else(|_| None);
180
181
let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
182
let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
183
let interval = Duration::from_secs(60);
184
// a separate background task to clean up
185
std::thread::spawn(move || {
···
187
std::thread::sleep(interval);
188
create_session_governor_limiter.retain_recent();
189
sign_in_governor_limiter.retain_recent();
190
}
191
});
192
···
197
198
let app = Router::new()
199
.route("/", get(root_handler))
200
-
.route(
201
-
"/xrpc/com.atproto.server.getSession",
202
-
get(get_session).layer(ax_middleware::from_fn(middleware::extract_did)),
203
-
)
204
.route(
205
"/xrpc/com.atproto.server.updateEmail",
206
post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
···
213
"/xrpc/com.atproto.server.createSession",
214
post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
215
)
216
.layer(CompressionLayer::new())
217
.layer(cors)
218
.with_state(state);
···
19
use std::time::Duration;
20
use std::{env, net::SocketAddr};
21
use tower_governor::GovernorLayer;
22
+
use tower_governor::governor::{GovernorConfig, GovernorConfigBuilder};
23
+
use tower_governor::key_extractor::PeerIpKeyExtractor;
24
use tower_http::compression::CompressionLayer;
25
use tower_http::cors::{Any, CorsLayer};
26
use tracing::log;
···
167
.per_second(60)
168
.burst_size(5)
169
.finish()
170
+
.expect("failed to create governor config for create session. this should not happen and is a bug");
171
172
// Create a second config with the same settings for the other endpoint
173
let sign_in_governor_conf = GovernorConfigBuilder::default()
174
.per_second(60)
175
.burst_size(5)
176
.finish()
177
+
.expect(
178
+
"failed to create governor config for sign in. this should not happen and is a bug",
179
+
);
180
+
181
+
let create_account_limiter_time: Option<String> =
182
+
env::var("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND").ok();
183
+
let create_account_limiter_burst: Option<String> =
184
+
env::var("GATEKEEPER_CREATE_ACCOUNT_BURST").ok();
185
+
let mut create_account_governor_conf = None;
186
187
+
if create_account_governor_conf.is_some() && create_account_limiter_time.is_some() {
188
+
let time = create_account_limiter_time
189
+
.expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND not set")
190
+
.parse::<u64>()
191
+
.expect("GATEKEEPER_CREATE_ACCOUNT_PER_SECOND must be a valid integer");
192
+
let burst = create_account_limiter_burst
193
+
.expect("GATEKEEPER_CREATE_ACCOUNT_BURST not set")
194
+
.parse::<u32>()
195
+
.expect("GATEKEEPER_CREATE_ACCOUNT_BURST must be a valid integer");
196
+
197
+
create_account_governor_conf = Some(
198
+
GovernorConfigBuilder::default()
199
+
.per_second(time)
200
+
.burst_size(burst)
201
+
.finish()
202
+
.expect("failed to create governor config for create account. this should not happen and is a bug"),
203
+
)
204
+
}
205
206
let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
207
let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
208
+
let create_account_governor_limiter = match create_account_governor_conf {
209
+
None => None,
210
+
Some(conf) => Some(conf.limiter().clone()),
211
+
};
212
+
213
let interval = Duration::from_secs(60);
214
// a separate background task to clean up
215
std::thread::spawn(move || {
···
217
std::thread::sleep(interval);
218
create_session_governor_limiter.retain_recent();
219
sign_in_governor_limiter.retain_recent();
220
+
if let Some(ref limiter) = create_account_governor_limiter {
221
+
limiter.retain_recent();
222
+
}
223
}
224
});
225
···
230
231
let app = Router::new()
232
.route("/", get(root_handler))
233
+
.route("/xrpc/com.atproto.server.getSession", get(get_session))
234
.route(
235
"/xrpc/com.atproto.server.updateEmail",
236
post(update_email).layer(ax_middleware::from_fn(middleware::extract_did)),
···
243
"/xrpc/com.atproto.server.createSession",
244
post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
245
)
246
+
.route("/xrpc/com.atproto.server.createAccount")
247
.layer(CompressionLayer::new())
248
.layer(cors)
249
.with_state(state);
+45
-32
src/middleware.rs
+45
-32
src/middleware.rs
···
35
Some((scheme, token_str)) => {
36
// For Bearer, validate JWT and extract DID from `sub`.
37
// For DPoP, we currently only pass through and do not validate here; insert None DID.
38
-
// match scheme {
39
-
// AuthScheme::Bearer => {
40
-
let token = UntrustedToken::new(&token_str);
41
-
if token.is_err() {
42
-
return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
43
-
.expect("Error creating an error response");
44
-
}
45
-
let parsed_token = token.expect("Already checked for error");
46
-
let claims: Result<Claims<TokenClaims>, ValidationError> =
47
-
parsed_token.deserialize_claims_unchecked();
48
-
if claims.is_err() {
49
-
return json_error_response(StatusCode::BAD_REQUEST, "TokenRequired", "")
50
-
.expect("Error creating an error response");
51
-
}
52
53
-
let key = Hs256Key::new(
54
-
env::var("PDS_JWT_SECRET").expect("PDS_JWT_SECRET not set in the pds.env"),
55
-
);
56
-
let token: Result<Token<TokenClaims>, ValidationError> =
57
-
Hs256.validator(&key).validate(&parsed_token);
58
-
if token.is_err() {
59
-
return json_error_response(StatusCode::BAD_REQUEST, "InvalidToken", "")
60
-
.expect("Error creating an error response");
61
}
62
-
let token = token.expect("Already checked for error,");
63
-
// Not going to worry about expiration since it still goes to the PDS
64
-
req.extensions_mut()
65
-
.insert(Did(Some(token.claims().custom.sub.clone())));
66
-
// }
67
-
// AuthScheme::DPoP => {
68
-
// // No DID extraction from DPoP here; leave None
69
-
// req.extensions_mut().insert(Did(None));
70
-
// }
71
-
// }
72
73
next.run(req).await
74
}
···
35
Some((scheme, token_str)) => {
36
// For Bearer, validate JWT and extract DID from `sub`.
37
// For DPoP, we currently only pass through and do not validate here; insert None DID.
38
+
match scheme {
39
+
AuthScheme::Bearer => {
40
+
let token = UntrustedToken::new(&token_str);
41
+
if token.is_err() {
42
+
return json_error_response(
43
+
StatusCode::BAD_REQUEST,
44
+
"TokenRequired",
45
+
"",
46
+
)
47
+
.expect("Error creating an error response");
48
+
}
49
+
let parsed_token = token.expect("Already checked for error");
50
+
let claims: Result<Claims<TokenClaims>, ValidationError> =
51
+
parsed_token.deserialize_claims_unchecked();
52
+
if claims.is_err() {
53
+
return json_error_response(
54
+
StatusCode::BAD_REQUEST,
55
+
"TokenRequired",
56
+
"",
57
+
)
58
+
.expect("Error creating an error response");
59
+
}
60
61
+
let key = Hs256Key::new(
62
+
env::var("PDS_JWT_SECRET")
63
+
.expect("PDS_JWT_SECRET not set in the pds.env"),
64
+
);
65
+
let token: Result<Token<TokenClaims>, ValidationError> =
66
+
Hs256.validator(&key).validate(&parsed_token);
67
+
if token.is_err() {
68
+
return json_error_response(
69
+
StatusCode::BAD_REQUEST,
70
+
"InvalidToken",
71
+
"",
72
+
)
73
+
.expect("Error creating an error response");
74
+
}
75
+
let token = token.expect("Already checked for error,");
76
+
// Not going to worry about expiration since it still goes to the PDS
77
+
req.extensions_mut()
78
+
.insert(Did(Some(token.claims().custom.sub.clone())));
79
+
}
80
+
AuthScheme::DPoP => {
81
+
//Not going to worry about oauth email update for now, just always forward to the PDS
82
+
req.extensions_mut().insert(Did(None));
83
+
}
84
}
85
86
next.run(req).await
87
}
+51
-46
src/xrpc/com_atproto_server.rs
+51
-46
src/xrpc/com_atproto_server.rs
···
147
//If email auth is set it is to either turn on or off 2fa
148
let email_auth_update = payload.email_auth_factor.unwrap_or(false);
149
150
-
// Email update asked for
151
-
if email_auth_update {
152
-
let email = payload.email.clone();
153
-
let email_confirmed = sqlx::query_as::<_, (String,)>(
154
-
"SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?",
155
-
)
156
-
.bind(&email)
157
-
.fetch_optional(&state.account_pool)
158
-
.await
159
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
160
161
-
//Since the email is already confirmed we can enable 2fa
162
-
return match email_confirmed {
163
-
None => Err(StatusCode::BAD_REQUEST),
164
-
Some(did_row) => {
165
-
let _ = sqlx::query(
166
-
"INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1",
167
-
)
168
-
.bind(&did_row.0)
169
-
.execute(&state.pds_gatekeeper_pool)
170
-
.await
171
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
172
-
173
-
Ok(StatusCode::OK.into_response())
174
-
}
175
-
};
176
-
}
177
-
178
-
// User wants auth turned off
179
-
if !email_auth_update && !email_auth_not_set {
180
-
//User wants auth turned off and has a token
181
-
if let Some(token) = &payload.token {
182
-
let token_found = sqlx::query_as::<_, (String,)>(
183
-
"SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'",
184
)
185
-
.bind(token)
186
-
.bind(&did.0)
187
.fetch_optional(&state.account_pool)
188
.await
189
.map_err(|_| StatusCode::BAD_REQUEST)?;
190
191
-
if token_found.is_some() {
192
-
let _ = sqlx::query(
193
-
"INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0",
194
)
195
-
.bind(&did.0)
196
-
.execute(&state.pds_gatekeeper_pool)
197
-
.await
198
-
.map_err(|_| StatusCode::BAD_REQUEST)?;
199
200
-
return Ok(StatusCode::OK.into_response());
201
-
} else {
202
-
return Err(StatusCode::BAD_REQUEST);
203
}
204
}
205
}
206
-
207
// Updating the actual email address by sending it on to the PDS
208
let uri = format!(
209
"{}{}",
···
147
//If email auth is set it is to either turn on or off 2fa
148
let email_auth_update = payload.email_auth_factor.unwrap_or(false);
149
150
+
//This means the middleware successfully extracted a did from the request, if not it just needs to be forward to the PDS
151
+
//This is also empty if it is an oauth request, which is not supported by gatekeeper turning on 2fa since the dpop stuff needs to be implemented
152
+
let did_is_not_empty = did.0.is_some();
153
154
+
if did_is_not_empty {
155
+
// Email update asked for
156
+
if email_auth_update {
157
+
let email = payload.email.clone();
158
+
let email_confirmed = sqlx::query_as::<_, (String,)>(
159
+
"SELECT did FROM account WHERE emailConfirmedAt IS NOT NULL AND email = ?",
160
)
161
+
.bind(&email)
162
.fetch_optional(&state.account_pool)
163
.await
164
.map_err(|_| StatusCode::BAD_REQUEST)?;
165
166
+
//Since the email is already confirmed we can enable 2fa
167
+
return match email_confirmed {
168
+
None => Err(StatusCode::BAD_REQUEST),
169
+
Some(did_row) => {
170
+
let _ = sqlx::query(
171
+
"INSERT INTO two_factor_accounts (did, required) VALUES (?, 1) ON CONFLICT(did) DO UPDATE SET required = 1",
172
+
)
173
+
.bind(&did_row.0)
174
+
.execute(&state.pds_gatekeeper_pool)
175
+
.await
176
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
177
+
178
+
Ok(StatusCode::OK.into_response())
179
+
}
180
+
};
181
+
}
182
+
183
+
// User wants auth turned off
184
+
if !email_auth_update && !email_auth_not_set {
185
+
//User wants auth turned off and has a token
186
+
if let Some(token) = &payload.token {
187
+
let token_found = sqlx::query_as::<_, (String,)>(
188
+
"SELECT token FROM email_token WHERE token = ? AND did = ? AND purpose = 'update_email'",
189
)
190
+
.bind(token)
191
+
.bind(&did.0)
192
+
.fetch_optional(&state.account_pool)
193
+
.await
194
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
195
196
+
return if token_found.is_some() {
197
+
let _ = sqlx::query(
198
+
"INSERT INTO two_factor_accounts (did, required) VALUES (?, 0) ON CONFLICT(did) DO UPDATE SET required = 0",
199
+
)
200
+
.bind(&did.0)
201
+
.execute(&state.pds_gatekeeper_pool)
202
+
.await
203
+
.map_err(|_| StatusCode::BAD_REQUEST)?;
204
+
205
+
Ok(StatusCode::OK.into_response())
206
+
} else {
207
+
Err(StatusCode::BAD_REQUEST)
208
+
};
209
}
210
}
211
}
212
// Updating the actual email address by sending it on to the PDS
213
let uri = format!(
214
"{}{}",