+1
Cargo.lock
+1
Cargo.lock
+1
Cargo.toml
+1
Cargo.toml
+9
-30
src/api/status.rs
+9
-30
src/api/status.rs
···
1129
1129
let did_for_event = did_string.clone();
1130
1130
let uri = current_status.uri.clone();
1131
1131
tokio::spawn(async move {
1132
-
let event = StatusEvent {
1133
-
event: "status.deleted",
1134
-
did: &did_for_event,
1135
-
handle: None,
1136
-
status: None,
1137
-
text: None,
1138
-
uri: Some(&uri),
1139
-
since: None,
1140
-
expires: None,
1141
-
};
1142
-
send_status_event(pool, &did_for_event, event)
1143
-
.await;
1132
+
crate::webhooks::emit_deleted(
1133
+
pool,
1134
+
&did_for_event,
1135
+
&uri,
1136
+
)
1137
+
.await;
1144
1138
});
1145
1139
}
1146
1140
···
1445
1439
1446
1440
let _ = status.save(db_pool.clone()).await;
1447
1441
1448
-
// Fire webhooks asynchronously
1442
+
// Fire webhooks asynchronously (helper keeps this file lean)
1449
1443
{
1450
1444
let pool = db_pool.get_ref().clone();
1451
-
let did_for_event = status.author_did.clone();
1452
-
let emoji = status.status.clone();
1453
-
let text = status.text.clone();
1454
-
let uri = status.uri.clone();
1455
-
let since = status.started_at.to_rfc3339();
1456
-
let expires = status.expires_at.map(|e| e.to_rfc3339());
1445
+
let s = status.clone();
1457
1446
tokio::spawn(async move {
1458
-
let event = StatusEvent {
1459
-
event: "status.created",
1460
-
did: &did_for_event,
1461
-
handle: None,
1462
-
status: Some(&emoji),
1463
-
text: text.as_deref(),
1464
-
uri: Some(&uri),
1465
-
since: Some(&since),
1466
-
expires: expires.as_deref(),
1467
-
};
1468
-
send_status_event(pool, &did_for_event, event).await;
1447
+
crate::webhooks::emit_created(pool, &s).await;
1469
1448
});
1470
1449
}
1471
1450
Ok(Redirect::to("/")
+77
-6
src/api/webhooks.rs
+77
-6
src/api/webhooks.rs
···
1
-
use crate::{db, error_handler::AppError};
1
+
use crate::{config::Config, db, error_handler::AppError};
2
2
use actix_session::Session;
3
3
use actix_web::{HttpResponse, Responder, Result, delete, get, post, put, web};
4
4
use async_sqlite::Pool;
5
5
use atrium_api::types::string::Did;
6
6
use serde::Deserialize;
7
7
use std::sync::Arc;
8
+
use url::Url;
8
9
9
10
#[derive(Deserialize)]
10
11
pub struct CreateWebhookRequest {
···
56
57
pub async fn create_webhook(
57
58
session: Session,
58
59
db_pool: web::Data<Arc<Pool>>,
60
+
app_config: web::Data<Config>,
59
61
payload: web::Json<CreateWebhookRequest>,
60
62
) -> Result<impl Responder> {
61
63
let did = session.get::<Did>("did")?;
62
64
if let Some(did) = did {
63
-
// Basic URL validation
64
-
if !(payload.url.starts_with("https://") || payload.url.starts_with("http://")) {
65
-
return Ok(web::Json(serde_json::json!({
66
-
"error": "URL must start with http:// or https://"
67
-
})));
65
+
// Robust URL + SSRF validation
66
+
if let Err(msg) = validate_url(&payload.url, &app_config) {
67
+
return Ok(web::Json(serde_json::json!({ "error": msg })));
68
+
}
69
+
// Events validation
70
+
if let Some(events_str) = &payload.events {
71
+
if let Err(msg) = validate_events(events_str) {
72
+
return Ok(web::Json(serde_json::json!({ "error": msg })));
73
+
}
68
74
}
69
75
let (id, secret) = db::create_webhook(
70
76
&db_pool,
···
97
103
match session.get::<Did>("did").unwrap_or(None) {
98
104
Some(did) => {
99
105
let id = path.into_inner();
106
+
if let Some(url) = &payload.url {
107
+
if Url::parse(url).is_err() {
108
+
return HttpResponse::BadRequest()
109
+
.json(serde_json::json!({ "error": "Invalid URL" }));
110
+
}
111
+
}
112
+
if let Some(events_str) = &payload.events {
113
+
if let Err(msg) = validate_events(events_str) {
114
+
return HttpResponse::BadRequest().json(serde_json::json!({ "error": msg }));
115
+
}
116
+
}
100
117
let res = db::update_webhook(
101
118
&db_pool,
102
119
did.as_str(),
···
116
133
HttpResponse::Unauthorized().json(serde_json::json!({ "error": "Not authenticated" }))
117
134
}
118
135
}
136
+
}
137
+
138
+
fn validate_events(s: &str) -> Result<(), &'static str> {
139
+
if s.trim().is_empty() {
140
+
return Ok(());
141
+
}
142
+
const ALLOWED: &[&str] = &["status.created", "status.deleted"];
143
+
for ev in s.split(',').map(|e| e.trim()) {
144
+
if !ALLOWED.contains(&ev) {
145
+
return Err("Unsupported event type");
146
+
}
147
+
}
148
+
Ok(())
149
+
}
150
+
151
+
fn validate_url(raw: &str, cfg: &Config) -> Result<(), &'static str> {
152
+
let url = Url::parse(raw).map_err(|_| "Invalid URL")?;
153
+
let scheme = url.scheme();
154
+
let host = url.host_str().ok_or("Missing host")?.to_ascii_lowercase();
155
+
156
+
// Treat localhost explicitly
157
+
let host_is_localname = host == "localhost";
158
+
159
+
// If host is an IP literal, apply standard library checks
160
+
let ip_check_blocks = if let Ok(ip) = host.parse::<std::net::IpAddr>() {
161
+
match ip {
162
+
std::net::IpAddr::V4(v4) => {
163
+
v4.is_private()
164
+
|| v4.is_loopback()
165
+
|| v4.is_link_local()
166
+
|| v4.is_multicast()
167
+
|| v4.is_unspecified()
168
+
}
169
+
std::net::IpAddr::V6(v6) => {
170
+
v6.is_unique_local() || v6.is_loopback() || v6.is_multicast() || v6.is_unspecified()
171
+
}
172
+
}
173
+
} else {
174
+
false
175
+
};
176
+
177
+
// Enforce HTTPS in production
178
+
let is_production = !cfg.oauth_redirect_base.starts_with("http://localhost")
179
+
&& !cfg.oauth_redirect_base.starts_with("http://127.0.0.1");
180
+
if is_production && scheme != "https" {
181
+
return Err("HTTPS required in production");
182
+
}
183
+
184
+
// Basic SSRF protection in production
185
+
if (host_is_localname || ip_check_blocks) && is_production {
186
+
return Err("Private/local hosts not allowed");
187
+
}
188
+
189
+
Ok(())
119
190
}
120
191
121
192
#[post("/api/webhooks/{id}/rotate")]
+122
-28
src/webhooks.rs
+122
-28
src/webhooks.rs
···
4
4
use serde::Serialize;
5
5
use sha2::Sha256;
6
6
7
-
use crate::db::{Webhook, get_user_webhooks};
7
+
use crate::db::{StatusFromDb, Webhook, get_user_webhooks};
8
+
use futures_util::future;
8
9
9
10
#[derive(Serialize)]
10
11
pub struct StatusEvent<'a> {
···
32
33
.any(|e| e.eq_ignore_ascii_case(event))
33
34
}
34
35
36
+
fn hmac_sig_hex(secret: &str, ts: &str, payload: &[u8]) -> String {
37
+
let mut mac =
38
+
Hmac::<Sha256>::new_from_slice(secret.as_bytes()).expect("HMAC can take key of any size");
39
+
mac.update(ts.as_bytes());
40
+
mac.update(b".");
41
+
mac.update(payload);
42
+
hex::encode(mac.finalize().into_bytes())
43
+
}
44
+
35
45
pub async fn send_status_event(pool: std::sync::Arc<Pool>, did: &str, event: StatusEvent<'_>) {
36
46
let client = Client::new();
37
47
let hooks = match get_user_webhooks(&pool, did).await {
···
50
60
};
51
61
let ts = chrono::Utc::now().timestamp().to_string();
52
62
53
-
for h in hooks.into_iter().filter(|h| should_send(h, event.event)) {
54
-
let mut mac = Hmac::<Sha256>::new_from_slice(h.secret.as_bytes())
55
-
.expect("HMAC can take key of any size");
56
-
mac.update(ts.as_bytes());
57
-
mac.update(b".");
58
-
mac.update(&payload);
59
-
let sig = hex::encode(mac.finalize().into_bytes());
63
+
let futures = hooks
64
+
.into_iter()
65
+
.filter(|h| should_send(h, event.event))
66
+
.map(|h| {
67
+
let payload = payload.clone();
68
+
let ts = ts.clone();
69
+
let client = client.clone();
70
+
async move {
71
+
let sig = hmac_sig_hex(&h.secret, &ts, &payload);
72
+
let res = client
73
+
.post(&h.url)
74
+
.header("User-Agent", "status-webhooks/1.0")
75
+
.header("Content-Type", "application/json")
76
+
.header("X-Status-Webhook-Timestamp", &ts)
77
+
.header("X-Status-Webhook-Signature", format!("sha256={}", sig))
78
+
.timeout(std::time::Duration::from_secs(5))
79
+
.body(payload)
80
+
.send()
81
+
.await;
60
82
61
-
let res = client
62
-
.post(&h.url)
63
-
.header("User-Agent", "status-webhooks/1.0")
64
-
.header("Content-Type", "application/json")
65
-
.header("X-Status-Webhook-Timestamp", &ts)
66
-
.header("X-Status-Webhook-Signature", format!("sha256={}", sig))
67
-
.body(payload.clone())
68
-
.send()
69
-
.await;
70
-
71
-
match res {
72
-
Ok(resp) => {
73
-
if !resp.status().is_success() {
74
-
log::warn!(
75
-
"webhook delivery failed: {} -> status {}",
76
-
&h.url,
77
-
resp.status()
78
-
);
83
+
match res {
84
+
Ok(resp) => {
85
+
if !resp.status().is_success() {
86
+
log::warn!(
87
+
"webhook delivery failed: {} -> status {}",
88
+
&h.url,
89
+
resp.status()
90
+
);
91
+
}
92
+
}
93
+
Err(e) => log::warn!("webhook delivery error to {}: {}", &h.url, e),
79
94
}
80
95
}
81
-
Err(e) => log::warn!("webhook delivery error to {}: {}", &h.url, e),
82
-
}
96
+
});
97
+
98
+
future::join_all(futures).await;
99
+
}
100
+
101
+
pub async fn emit_created(pool: std::sync::Arc<Pool>, s: &StatusFromDb) {
102
+
let did = s.author_did.clone();
103
+
let emoji = s.status.clone();
104
+
let text = s.text.clone();
105
+
let uri = s.uri.clone();
106
+
let since = s.started_at.to_rfc3339();
107
+
let expires = s.expires_at.map(|e| e.to_rfc3339());
108
+
let event = StatusEvent {
109
+
event: "status.created",
110
+
did: &did,
111
+
handle: None,
112
+
status: Some(&emoji),
113
+
text: text.as_deref(),
114
+
uri: Some(&uri),
115
+
since: Some(&since),
116
+
expires: expires.as_deref(),
117
+
};
118
+
send_status_event(pool, &did, event).await;
119
+
}
120
+
121
+
pub async fn emit_deleted(pool: std::sync::Arc<Pool>, did: &str, uri: &str) {
122
+
let did_owned = did.to_string();
123
+
let uri_owned = uri.to_string();
124
+
let event = StatusEvent {
125
+
event: "status.deleted",
126
+
did: &did_owned,
127
+
handle: None,
128
+
status: None,
129
+
text: None,
130
+
uri: Some(&uri_owned),
131
+
since: None,
132
+
expires: None,
133
+
};
134
+
send_status_event(pool, &did_owned, event).await;
135
+
}
136
+
137
+
#[cfg(test)]
138
+
mod tests {
139
+
use super::*;
140
+
141
+
#[test]
142
+
fn test_should_send_wildcard() {
143
+
let h = Webhook {
144
+
id: 1,
145
+
did: "d".into(),
146
+
url: "u".into(),
147
+
secret: "s".into(),
148
+
events: "*".into(),
149
+
active: true,
150
+
created_at: 0,
151
+
updated_at: 0,
152
+
};
153
+
assert!(should_send(&h, "status.created"));
154
+
}
155
+
156
+
#[test]
157
+
fn test_should_send_specific() {
158
+
let h = Webhook {
159
+
id: 1,
160
+
did: "d".into(),
161
+
url: "u".into(),
162
+
secret: "s".into(),
163
+
events: "status.deleted".into(),
164
+
active: true,
165
+
created_at: 0,
166
+
updated_at: 0,
167
+
};
168
+
assert!(should_send(&h, "status.deleted"));
169
+
assert!(!should_send(&h, "status.created"));
170
+
}
171
+
172
+
#[test]
173
+
fn test_hmac_sig_hex() {
174
+
let sig = hmac_sig_hex("secret", "1234567890", b"{\"a\":1}");
175
+
// Deterministic expected if inputs fixed
176
+
assert_eq!(sig.len(), 64);
83
177
}
84
178
}