auth dns over atproto
at main 298 lines 8.4 kB view raw
1use std::sync::Arc; 2use std::time::Instant; 3 4use axum::{ 5 Json, Router, 6 extract::{Query, Request, State}, 7 http::StatusCode, 8 middleware::{self, Next}, 9 response::IntoResponse, 10 routing::{get, put}, 11}; 12use serde::{Deserialize, Serialize}; 13 14use crate::materializer::{AppState, domain_ancestors}; 15 16pub fn router(state: Arc<AppState>) -> Router { 17 let metrics_router = onis_common::metrics::router(state.metrics_handle.clone()); 18 19 Router::new() 20 .route("/v1/resolve", get(resolve)) 21 .route("/v1/health", get(health)) 22 .route("/v1/zones", get(zones)) 23 .route("/v1/zones/stale", get(stale_zones)) 24 .route("/v1/verification", put(verification)) 25 .layer(middleware::from_fn(track_request_duration)) 26 .with_state(state) 27 .merge(metrics_router) 28} 29 30async fn track_request_duration(request: Request, next: Next) -> impl IntoResponse { 31 let path = request.uri().path().to_owned(); 32 let start = Instant::now(); 33 let response = next.run(request).await; 34 metrics::histogram!("appview_api_request_duration_seconds", "path" => path) 35 .record(start.elapsed().as_secs_f64()); 36 response 37} 38 39// --------------------------------------------------------------------------- 40// GET /v1/resolve?name={name}&type={type} 41// --------------------------------------------------------------------------- 42 43#[derive(Deserialize)] 44struct ResolveParams { 45 name: String, 46 #[serde(rename = "type")] 47 record_type: Option<String>, 48} 49 50#[derive(Serialize)] 51struct ResolveResponse { 52 zone: Option<String>, 53 verified: bool, 54 records: Vec<serde_json::Value>, 55 name_exists: bool, 56} 57 58/// Walk up the domain tree to find the matching zone in zone_index. 59async fn find_zone( 60 index: &sqlx::SqlitePool, 61 name: &str, 62) -> Result<Option<(String, String, bool)>, sqlx::Error> { 63 for candidate in domain_ancestors(name) { 64 let row: Option<(String, i64)> = 65 sqlx::query_as( 66 "SELECT did, verified FROM zone_index WHERE zone = ? ORDER BY verified DESC, first_seen ASC LIMIT 1" 67 ) 68 .bind(&candidate) 69 .fetch_optional(index) 70 .await?; 71 if let Some((did, verified)) = row { 72 return Ok(Some((candidate, did, verified != 0))); 73 } 74 } 75 Ok(None) 76} 77 78async fn resolve( 79 State(state): State<Arc<AppState>>, 80 Query(params): Query<ResolveParams>, 81) -> Result<Json<ResolveResponse>, StatusCode> { 82 let name = params.name.to_lowercase(); 83 84 let zone_match = find_zone(&state.index, &name) 85 .await 86 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 87 88 let (zone, did, verified) = match zone_match { 89 Some(z) => z, 90 None => { 91 return Ok(Json(ResolveResponse { 92 zone: None, 93 verified: false, 94 records: vec![], 95 name_exists: false, 96 })); 97 } 98 }; 99 100 if !verified { 101 return Ok(Json(ResolveResponse { 102 zone: Some(zone), 103 verified: false, 104 records: vec![], 105 name_exists: false, 106 })); 107 } 108 109 let user_db = state 110 .get_user_db(&did) 111 .await 112 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 113 114 let records: Vec<(String,)> = if let Some(ref rt) = params.record_type { 115 sqlx::query_as( 116 "SELECT data FROM records WHERE domain = ? AND record_type = ?", 117 ) 118 .bind(&name) 119 .bind(rt) 120 .fetch_all(&user_db) 121 .await 122 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? 123 } else { 124 sqlx::query_as("SELECT data FROM records WHERE domain = ?") 125 .bind(&name) 126 .fetch_all(&user_db) 127 .await 128 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)? 129 }; 130 131 let name_exists: Option<(i64,)> = 132 sqlx::query_as("SELECT 1 FROM records WHERE domain = ? LIMIT 1") 133 .bind(&name) 134 .fetch_optional(&user_db) 135 .await 136 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 137 138 let records: Vec<serde_json::Value> = records 139 .into_iter() 140 .filter_map(|(data,)| serde_json::from_str(&data).ok()) 141 .collect(); 142 143 Ok(Json(ResolveResponse { 144 zone: Some(zone), 145 verified: true, 146 records, 147 name_exists: name_exists.is_some(), 148 })) 149} 150 151// --------------------------------------------------------------------------- 152// GET /v1/health 153// --------------------------------------------------------------------------- 154 155#[derive(Serialize)] 156struct HealthResponse { 157 status: String, 158 zones: i64, 159 users: i64, 160} 161 162async fn health( 163 State(state): State<Arc<AppState>>, 164) -> Result<Json<HealthResponse>, StatusCode> { 165 let (zones,): (i64,) = 166 sqlx::query_as("SELECT COUNT(*) FROM zone_index") 167 .fetch_one(&state.index) 168 .await 169 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 170 171 let (users,): (i64,) = 172 sqlx::query_as("SELECT COUNT(DISTINCT did) FROM zone_index") 173 .fetch_one(&state.index) 174 .await 175 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 176 177 Ok(Json(HealthResponse { 178 status: "ok".to_string(), 179 zones, 180 users, 181 })) 182} 183 184// --------------------------------------------------------------------------- 185// GET /v1/zones?did={did} 186// --------------------------------------------------------------------------- 187 188#[derive(Deserialize)] 189struct ZonesParams { 190 did: String, 191} 192 193#[derive(Serialize, sqlx::FromRow)] 194struct ZoneEntry { 195 zone: String, 196 verified: bool, 197} 198 199#[derive(Serialize)] 200struct ZonesResponse { 201 zones: Vec<ZoneEntry>, 202} 203 204async fn zones( 205 State(state): State<Arc<AppState>>, 206 Query(params): Query<ZonesParams>, 207) -> Result<Json<ZonesResponse>, StatusCode> { 208 let zones: Vec<ZoneEntry> = 209 sqlx::query_as("SELECT zone, verified FROM zone_index WHERE did = ?") 210 .bind(&params.did) 211 .fetch_all(&state.index) 212 .await 213 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 214 215 Ok(Json(ZonesResponse { zones })) 216} 217 218// --------------------------------------------------------------------------- 219// GET /v1/zones/stale?checked_before={unix_timestamp} 220// --------------------------------------------------------------------------- 221 222#[derive(Deserialize)] 223struct StaleZonesParams { 224 checked_before: i64, 225} 226 227#[derive(Serialize, sqlx::FromRow)] 228struct StaleZoneEntry { 229 zone: String, 230 did: String, 231 verified: bool, 232 first_seen: i64, 233 last_verified: Option<i64>, 234} 235 236#[derive(Serialize)] 237struct StaleZonesResponse { 238 zones: Vec<StaleZoneEntry>, 239} 240 241async fn stale_zones( 242 State(state): State<Arc<AppState>>, 243 Query(params): Query<StaleZonesParams>, 244) -> Result<Json<StaleZonesResponse>, StatusCode> { 245 let zones: Vec<StaleZoneEntry> = sqlx::query_as( 246 "SELECT zone, did, verified, first_seen, last_verified FROM zone_index \ 247 WHERE last_verified < ? OR last_verified IS NULL", 248 ) 249 .bind(params.checked_before) 250 .fetch_all(&state.index) 251 .await 252 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 253 254 Ok(Json(StaleZonesResponse { zones })) 255} 256 257// --------------------------------------------------------------------------- 258// PUT /v1/verification 259// --------------------------------------------------------------------------- 260 261#[derive(Deserialize)] 262struct VerificationBody { 263 zone: String, 264 did: String, 265 verified: bool, 266} 267 268async fn verification( 269 State(state): State<Arc<AppState>>, 270 Json(body): Json<VerificationBody>, 271) -> Result<StatusCode, StatusCode> { 272 let now = chrono::Utc::now().timestamp(); 273 let verified_int: i64 = if body.verified { 1 } else { 0 }; 274 275 let result = sqlx::query( 276 "UPDATE zone_index SET verified = ?, last_verified = ? WHERE zone = ? AND did = ?", 277 ) 278 .bind(verified_int) 279 .bind(now) 280 .bind(&body.zone) 281 .bind(&body.did) 282 .execute(&state.index) 283 .await 284 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; 285 286 if result.rows_affected() == 0 { 287 return Err(StatusCode::NOT_FOUND); 288 } 289 290 tracing::info!( 291 zone = %body.zone, 292 did = %body.did, 293 verified = body.verified, 294 "verification status updated" 295 ); 296 297 Ok(StatusCode::NO_CONTENT) 298}