+1
-1
src/helpers.rs
+1
-1
src/helpers.rs
···
3
3
use anyhow::anyhow;
4
4
use axum::body::{Body, to_bytes};
5
5
use axum::extract::Request;
6
-
use axum::http::header::{CONTENT_LENGTH, CONTENT_TYPE};
6
+
use axum::http::header::CONTENT_TYPE;
7
7
use axum::http::{HeaderMap, StatusCode, Uri};
8
8
use axum::response::{IntoResponse, Response};
9
9
use axum_template::TemplateEngine;
+27
-16
src/main.rs
+27
-16
src/main.rs
···
22
22
use tower_governor::governor::GovernorConfigBuilder;
23
23
use tower_http::compression::CompressionLayer;
24
24
use tower_http::cors::{Any, CorsLayer};
25
-
use tracing::error;
25
+
use tracing::log;
26
26
use tracing_subscriber::{EnvFilter, fmt, prelude::*};
27
27
28
28
pub mod helpers;
···
88
88
#[tokio::main]
89
89
async fn main() -> Result<(), Box<dyn std::error::Error>> {
90
90
setup_tracing();
91
-
//TODO may need to change where this reads from? Like an env variable for it's location?
91
+
//TODO may need to change where this reads from? Like an env variable for it's location? Or arg?
92
92
dotenvy::from_path(Path::new("./pds.env"))?;
93
93
let pds_root = env::var("PDS_DATA_DIRECTORY")?;
94
94
let account_db_url = format!("{pds_root}/account.sqlite");
95
95
96
96
let account_options = SqliteConnectOptions::new()
97
-
.journal_mode(SqliteJournalMode::Wal)
98
-
.filename(account_db_url);
97
+
.filename(account_db_url)
98
+
.busy_timeout(Duration::from_secs(5));
99
99
100
100
let account_pool = SqlitePoolOptions::new()
101
101
.max_connections(5)
···
106
106
let options = SqliteConnectOptions::new()
107
107
.journal_mode(SqliteJournalMode::Wal)
108
108
.filename(bells_db_url)
109
-
.create_if_missing(true);
109
+
.create_if_missing(true)
110
+
.busy_timeout(Duration::from_secs(5));
110
111
let pds_gatekeeper_pool = SqlitePoolOptions::new()
111
112
.max_connections(5)
112
113
.connect_with(options)
···
135
136
//TODO add an override to manually load in the hbs templates
136
137
let _ = hbs.register_embed_templates::<EmailTemplates>();
137
138
139
+
let pds_base_url =
140
+
env::var("PDS_BASE_URL").unwrap_or_else(|_| "http://localhost:3000".to_string());
141
+
138
142
let state = AppState {
139
143
account_pool,
140
144
pds_gatekeeper_pool,
141
145
reverse_proxy_client: client,
142
-
//TODO should be env prob
143
-
pds_base_url: "http://localhost:3000".to_string(),
146
+
pds_base_url,
144
147
mailer,
145
148
mailer_from: sent_from,
146
149
template_engine: Engine::from(hbs),
···
148
151
149
152
// Rate limiting
150
153
//Allows 5 within 60 seconds, and after 60 should drop one off? So hit 5, then goes to 4 after 60 seconds.
151
-
let governor_conf = GovernorConfigBuilder::default()
154
+
let create_session_governor_conf = GovernorConfigBuilder::default()
155
+
.per_second(60)
156
+
.burst_size(5)
157
+
.finish()
158
+
.expect("failed to create governor config. this should not happen and is a bug");
159
+
160
+
// Create a second config with the same settings for the other endpoint
161
+
let sign_in_governor_conf = GovernorConfigBuilder::default()
152
162
.per_second(60)
153
163
.burst_size(5)
154
164
.finish()
155
165
.expect("failed to create governor config. this should not happen and is a bug");
156
166
157
-
let governor_limiter = governor_conf.limiter().clone();
167
+
let create_session_governor_limiter = create_session_governor_conf.limiter().clone();
168
+
let sign_in_governor_limiter = sign_in_governor_conf.limiter().clone();
158
169
let interval = Duration::from_secs(60);
159
170
// a separate background task to clean up
160
171
std::thread::spawn(move || {
161
172
loop {
162
173
std::thread::sleep(interval);
163
-
tracing::info!("rate limiting storage size: {}", governor_limiter.len());
164
-
governor_limiter.retain_recent();
174
+
create_session_governor_limiter.retain_recent();
175
+
sign_in_governor_limiter.retain_recent();
165
176
}
166
177
});
167
178
···
182
193
)
183
194
.route(
184
195
"/@atproto/oauth-provider/~api/sign-in",
185
-
post(sign_in), // .layer(GovernorLayer::new(governor_conf.clone()))),
196
+
post(sign_in).layer(GovernorLayer::new(sign_in_governor_conf)),
186
197
)
187
198
.route(
188
199
"/xrpc/com.atproto.server.createSession",
189
-
post(create_session.layer(GovernorLayer::new(governor_conf))),
200
+
post(create_session.layer(GovernorLayer::new(create_session_governor_conf))),
190
201
)
191
202
.layer(CompressionLayer::new())
192
203
.layer(cors)
193
204
.with_state(state);
194
205
195
-
let host = env::var("HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
196
-
let port: u16 = env::var("PORT")
206
+
let host = env::var("GATEKEEPER_HOST").unwrap_or_else(|_| "127.0.0.1".to_string());
207
+
let port: u16 = env::var("GATEKEEPER_PORT")
197
208
.ok()
198
209
.and_then(|s| s.parse().ok())
199
210
.unwrap_or(8080);
···
210
221
.with_graceful_shutdown(shutdown_signal());
211
222
212
223
if let Err(err) = server.await {
213
-
error!(error = %err, "server error");
224
+
log::error!("server error:{err}");
214
225
}
215
226
216
227
Ok(())
+37
-45
src/oauth_provider.rs
+37
-45
src/oauth_provider.rs
···
11
11
use serde::{Deserialize, Serialize};
12
12
use tracing::log;
13
13
14
-
fn is_disallowed_header(name: &HeaderName) -> bool {
15
-
// RFC 7230 hop-by-hop headers and other problematic ones for proxying a new body
16
-
// Use lowercase comparison; HeaderName equality is case-insensitive but we compare by string for a set
17
-
match name.as_str() {
18
-
// Hop-by-hop
19
-
"connection"
20
-
| "keep-alive"
21
-
| "proxy-authenticate"
22
-
| "proxy-authorization"
23
-
| "te"
24
-
| "trailer"
25
-
| "transfer-encoding"
26
-
| "upgrade" => true,
27
-
// Payload or routing related we should not forward
28
-
"host" | "content-length" | "content-encoding" | "expect" => true,
29
-
// Compression negotiation can interfere; let upstream decide defaults
30
-
// We can drop Accept-Encoding to avoid getting compressed payloads if not needed
31
-
"accept-encoding" => true,
32
-
_ => false,
33
-
}
34
-
}
35
-
36
-
fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) {
37
-
for (name, value) in src.iter() {
38
-
if is_disallowed_header(name) {
39
-
continue;
40
-
}
41
-
// Only copy valid headers
42
-
if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) {
43
-
dst.insert(name.clone(), hv);
44
-
}
45
-
}
46
-
}
47
-
48
14
#[derive(Serialize, Deserialize, Clone)]
49
15
pub struct SignInRequest {
50
16
pub username: String,
···
64
30
let password = payload.password.clone();
65
31
let auth_factor_token = payload.email_otp.clone();
66
32
67
-
//TODO need to pass in a flag to ignore app passwords for Oauth
68
-
// Run the shared pre-auth logic to validate and check 2FA requirement
69
33
match preauth_check(&state, &identifier, &password, auth_factor_token, true).await {
70
34
Ok(result) => match result {
71
35
AuthResult::WrongIdentityOrPassword => oauth_json_error_response(
···
102
66
);
103
67
104
68
let mut req = axum::http::Request::post(uri);
105
-
// if let Some(cookie) = headers.get("Cookie") {
106
-
// req = req.header("Cookie", cookie.to_str().unwrap());
107
-
// }
108
69
if let Some(req_headers) = req.headers_mut() {
109
-
// Copy headers but remove hop-by-hop and problematic ones
70
+
// Copy headers but remove problematic ones. There was an issue with the PDS not parsing the body fully if i forwarded all headers
110
71
copy_filtered_headers(&headers, req_headers);
111
-
// Ensure JSON content type is set explicitly
72
+
//Setting the content type to application/json manually
112
73
req_headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
113
74
}
114
75
115
76
payload.email_otp = None;
116
-
// let payload_bytes =
117
-
// serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
118
-
let body = serde_json::to_string(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
77
+
let payload_bytes =
78
+
serde_json::to_vec(&payload).map_err(|_| StatusCode::BAD_REQUEST)?;
119
79
120
80
let req = req
121
-
.body(Body::from(body))
81
+
.body(Body::from(payload_bytes))
122
82
.map_err(|_| StatusCode::BAD_REQUEST)?;
123
83
124
84
let proxied = state
···
155
115
}
156
116
}
157
117
}
118
+
119
+
fn is_disallowed_header(name: &HeaderName) -> bool {
120
+
// possible problematic headers with proxying
121
+
matches!(
122
+
name.as_str(),
123
+
"connection"
124
+
| "keep-alive"
125
+
| "proxy-authenticate"
126
+
| "proxy-authorization"
127
+
| "te"
128
+
| "trailer"
129
+
| "transfer-encoding"
130
+
| "upgrade"
131
+
| "host"
132
+
| "content-length"
133
+
| "content-encoding"
134
+
| "expect"
135
+
| "accept-encoding"
136
+
)
137
+
}
138
+
139
+
fn copy_filtered_headers(src: &HeaderMap, dst: &mut HeaderMap) {
140
+
for (name, value) in src.iter() {
141
+
if is_disallowed_header(name) {
142
+
continue;
143
+
}
144
+
// Only copy valid headers
145
+
if let Ok(hv) = HeaderValue::from_bytes(value.as_bytes()) {
146
+
dst.insert(name.clone(), hv);
147
+
}
148
+
}
149
+
}
+1
-2
src/xrpc/com_atproto_server.rs
+1
-2
src/xrpc/com_atproto_server.rs
···
11
11
use serde::{Deserialize, Serialize};
12
12
use serde_json;
13
13
use tracing::log;
14
-
use tracing::log::log;
15
14
16
15
#[derive(Serialize, Deserialize, Debug, Clone)]
17
16
#[serde(rename_all = "camelCase")]
···
65
64
pub async fn create_session(
66
65
State(state): State<AppState>,
67
66
headers: HeaderMap,
68
-
Json(mut payload): extract::Json<CreateSessionRequest>,
67
+
Json(payload): extract::Json<CreateSessionRequest>,
69
68
) -> Result<Response<Body>, StatusCode> {
70
69
let identifier = payload.identifier.clone();
71
70
let password = payload.password.clone();