Our Personal Data Server from scratch! tranquil.farm
atproto pds rust postgresql fun oauth

feat(signal): add presage client, newtypes, and slot management #89

merged opened by oyster.cafe targeting main from feat/signal-client-in-house
Labels

None yet.

assignee

None yet.

Participants 1
AT URI
at://did:plc:3fwecdnvtcscjnrx2p4n7alz/sh.tangled.repo.pull/3mhlxir6en722
+553
Diff #0
+541
crates/tranquil-signal/src/client.rs
··· 1 + use std::fmt; 2 + use std::panic::AssertUnwindSafe; 3 + use std::sync::Arc; 4 + use std::sync::atomic::{AtomicBool, Ordering}; 5 + use std::time::{Duration, SystemTime, UNIX_EPOCH}; 6 + 7 + use presage::libsignal_service::configuration::SignalServers; 8 + use presage::manager::Registered; 9 + use presage::proto::DataMessage; 10 + use sqlx::PgPool; 11 + use tokio::sync::{RwLock, mpsc, oneshot}; 12 + use tokio_util::sync::CancellationToken; 13 + use url::Url; 14 + 15 + use crate::store::PgSignalStore; 16 + 17 + #[derive(Debug, Clone)] 18 + pub struct SignalUsername(String); 19 + 20 + #[derive(Debug, Clone)] 21 + pub struct InvalidSignalUsername(String); 22 + 23 + impl fmt::Display for InvalidSignalUsername { 24 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 25 + write!(f, "invalid signal username: {}", self.0) 26 + } 27 + } 28 + 29 + impl std::error::Error for InvalidSignalUsername {} 30 + 31 + impl SignalUsername { 32 + pub fn parse(username: &str) -> Result<Self, InvalidSignalUsername> { 33 + let reject = || Err(InvalidSignalUsername(username.to_string())); 34 + 35 + if username.len() < 6 || username.len() > 35 { 36 + return reject(); 37 + } 38 + 39 + let Some((base, discriminator)) = username.rsplit_once('.') else { 40 + return reject(); 41 + }; 42 + 43 + if base.len() < 3 || base.len() > 32 { 44 + return reject(); 45 + } 46 + 47 + if !base.chars().next().is_some_and(|c| c.is_ascii_alphabetic()) { 48 + return reject(); 49 + } 50 + 51 + if !base.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') { 52 + return reject(); 53 + } 54 + 55 + if discriminator.len() != 2 || !discriminator.chars().all(|c| c.is_ascii_digit()) { 56 + return reject(); 57 + } 58 + 59 + Ok(Self(username.to_string())) 60 + } 61 + 62 + pub fn as_str(&self) -> &str { 63 + &self.0 64 + } 65 + } 66 + 67 + impl fmt::Display for SignalUsername { 68 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 69 + f.write_str(&self.0) 70 + } 71 + } 72 + 73 + #[derive(Debug, Clone)] 74 + pub struct DeviceName(String); 75 + 76 + #[derive(Debug, Clone)] 77 + pub struct InvalidDeviceName(String); 78 + 79 + impl fmt::Display for InvalidDeviceName { 80 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 81 + write!(f, "invalid device name: {}", self.0) 82 + } 83 + } 84 + 85 + impl std::error::Error for InvalidDeviceName {} 86 + 87 + impl DeviceName { 88 + pub fn new(name: String) -> Result<Self, InvalidDeviceName> { 89 + if name.is_empty() || name.len() > 50 || !name.is_ascii() { 90 + return Err(InvalidDeviceName(name)); 91 + } 92 + Ok(Self(name)) 93 + } 94 + 95 + fn into_inner(self) -> String { 96 + self.0 97 + } 98 + } 99 + 100 + const LINK_TIMEOUT: Duration = Duration::from_secs(120); 101 + const SEND_TIMEOUT: Duration = Duration::from_secs(60); 102 + const MAX_MESSAGE_BYTES: usize = 2000; 103 + 104 + #[derive(Debug, Clone)] 105 + pub struct MessageBody(String); 106 + 107 + #[derive(Debug, Clone)] 108 + pub struct MessageTooLong { 109 + pub len: usize, 110 + pub max: usize, 111 + } 112 + 113 + impl fmt::Display for MessageTooLong { 114 + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 115 + write!( 116 + f, 117 + "message body too long: {} bytes (max {})", 118 + self.len, self.max 119 + ) 120 + } 121 + } 122 + 123 + impl std::error::Error for MessageTooLong {} 124 + 125 + impl MessageBody { 126 + pub fn new(body: String) -> Result<Self, MessageTooLong> { 127 + let len = body.len(); 128 + if len > MAX_MESSAGE_BYTES { 129 + return Err(MessageTooLong { 130 + len, 131 + max: MAX_MESSAGE_BYTES, 132 + }); 133 + } 134 + Ok(Self(body)) 135 + } 136 + 137 + pub fn as_str(&self) -> &str { 138 + &self.0 139 + } 140 + } 141 + 142 + #[derive(Debug, Clone, Copy, PartialEq, Eq)] 143 + pub struct LinkGeneration(u64); 144 + 145 + impl LinkGeneration { 146 + fn next(self) -> Self { 147 + Self(self.0.wrapping_add(1)) 148 + } 149 + } 150 + 151 + fn log_panic(thread_name: &str, payload: Box<dyn std::any::Any + Send>) { 152 + let msg = payload 153 + .downcast_ref::<&str>() 154 + .copied() 155 + .or_else(|| payload.downcast_ref::<String>().map(|s| s.as_str())) 156 + .unwrap_or("unknown panic"); 157 + tracing::error!(thread = thread_name, panic = msg, "signal thread panicked"); 158 + } 159 + 160 + fn spawn_signal_thread( 161 + name: &'static str, 162 + f: impl FnOnce() + Send + 'static, 163 + ) -> std::io::Result<std::thread::JoinHandle<()>> { 164 + std::thread::Builder::new() 165 + .name(name.into()) 166 + .spawn(move || { 167 + if let Err(e) = std::panic::catch_unwind(AssertUnwindSafe(f)) { 168 + log_panic(name, e); 169 + } 170 + }) 171 + } 172 + 173 + fn signal_local_block_on(fut: impl std::future::Future<Output = ()>) { 174 + let rt = tokio::runtime::Builder::new_current_thread() 175 + .enable_all() 176 + .build() 177 + .expect("signal runtime"); 178 + let local = tokio::task::LocalSet::new(); 179 + local.block_on(&rt, fut); 180 + } 181 + 182 + struct LinkingGuard(Arc<AtomicBool>); 183 + 184 + impl Drop for LinkingGuard { 185 + fn drop(&mut self) { 186 + self.0.store(false, Ordering::Release); 187 + } 188 + } 189 + 190 + #[derive(Debug, thiserror::Error)] 191 + pub enum SignalError { 192 + #[error("store: {0}")] 193 + Store(#[from] crate::store::PgStoreError), 194 + #[error("presage: {0}")] 195 + Presage(String), 196 + #[error("username lookup failed: {0}")] 197 + UsernameLookup(String), 198 + #[error("username not found: {0}")] 199 + UsernameNotFound(String), 200 + #[error("linking: {0}")] 201 + Linking(String), 202 + #[error("linking timed out")] 203 + LinkingTimeout, 204 + #[error("linking cancelled")] 205 + LinkingCancelled, 206 + #[error("not linked")] 207 + NotLinked, 208 + #[error("runtime: {0}")] 209 + Runtime(String), 210 + } 211 + 212 + type Manager = presage::Manager<PgSignalStore, Registered>; 213 + 214 + struct SendRequest { 215 + recipient: SignalUsername, 216 + message: MessageBody, 217 + reply: oneshot::Sender<Result<(), SignalError>>, 218 + } 219 + 220 + pub struct LinkResult { 221 + pub url: Url, 222 + pub completion: oneshot::Receiver<Result<SignalClient, SignalError>>, 223 + } 224 + 225 + pub struct SignalSlot { 226 + state: RwLock<SlotState>, 227 + linking_in_progress: Arc<AtomicBool>, 228 + } 229 + 230 + struct SlotState { 231 + client: Option<SignalClient>, 232 + generation: LinkGeneration, 233 + link_cancel: Option<CancellationToken>, 234 + } 235 + 236 + impl Default for SignalSlot { 237 + fn default() -> Self { 238 + Self { 239 + state: RwLock::new(SlotState { 240 + client: None, 241 + generation: LinkGeneration(0), 242 + link_cancel: None, 243 + }), 244 + linking_in_progress: Arc::new(AtomicBool::new(false)), 245 + } 246 + } 247 + } 248 + 249 + impl SignalSlot { 250 + pub async fn client(&self) -> Option<SignalClient> { 251 + let client = self.state.read().await.client.clone()?; 252 + if client.is_alive() { 253 + Some(client) 254 + } else { 255 + tracing::warn!("signal worker exited unexpectedly, clearing client"); 256 + self.state.write().await.client = None; 257 + None 258 + } 259 + } 260 + 261 + pub async fn is_linked(&self) -> bool { 262 + self.state 263 + .read() 264 + .await 265 + .client 266 + .as_ref() 267 + .is_some_and(SignalClient::is_alive) 268 + } 269 + 270 + pub async fn set_client(&self, client: SignalClient) { 271 + self.state.write().await.client = Some(client); 272 + } 273 + 274 + pub fn linking_flag(&self) -> Arc<AtomicBool> { 275 + self.linking_in_progress.clone() 276 + } 277 + 278 + pub async fn begin_link(&self) -> (LinkGeneration, CancellationToken) { 279 + let mut guard = self.state.write().await; 280 + if let Some(old) = guard.link_cancel.take() { 281 + old.cancel(); 282 + } 283 + let token = CancellationToken::new(); 284 + guard.link_cancel = Some(token.clone()); 285 + (guard.generation, token) 286 + } 287 + 288 + pub async fn complete_link(&self, generation: LinkGeneration, client: SignalClient) -> bool { 289 + let mut guard = self.state.write().await; 290 + if guard.generation != generation || guard.client.is_some() { 291 + return false; 292 + } 293 + guard.client = Some(client); 294 + guard.link_cancel = None; 295 + true 296 + } 297 + 298 + pub async fn unlink(&self) { 299 + let mut guard = self.state.write().await; 300 + guard.client = None; 301 + guard.generation = guard.generation.next(); 302 + if let Some(cancel) = guard.link_cancel.take() { 303 + cancel.cancel(); 304 + } 305 + } 306 + } 307 + 308 + #[derive(Clone)] 309 + pub struct SignalClient { 310 + tx: mpsc::Sender<SendRequest>, 311 + } 312 + 313 + impl SignalClient { 314 + fn from_manager(manager: Manager, shutdown: CancellationToken) -> Result<Self, SignalError> { 315 + let (tx, rx) = mpsc::channel::<SendRequest>(64); 316 + 317 + spawn_signal_thread("signal-worker", move || { 318 + signal_local_block_on(Self::worker_loop(manager, rx, shutdown)); 319 + }) 320 + .map_err(|e| SignalError::Runtime(format!("failed to spawn signal worker: {e}")))?; 321 + 322 + Ok(Self { tx }) 323 + } 324 + 325 + pub async fn from_pool(db: &PgPool, shutdown: CancellationToken) -> Option<Self> { 326 + let store = PgSignalStore::new(db.clone()); 327 + let (init_tx, init_rx) = oneshot::channel(); 328 + 329 + spawn_signal_thread("signal-init", move || { 330 + signal_local_block_on(async { 331 + let result = presage::Manager::load_registered(store).await; 332 + init_tx 333 + .send(result.map_err(|e| SignalError::Presage(e.to_string()))) 334 + .ok(); 335 + }); 336 + }) 337 + .map_err(|e| tracing::error!(error = %e, "failed to spawn signal init thread")) 338 + .ok()?; 339 + 340 + let manager = init_rx 341 + .await 342 + .ok()? 343 + .map_err(|e| tracing::error!(error = %e, "failed to load registered signal manager")) 344 + .ok()?; 345 + 346 + Self::from_manager(manager, shutdown) 347 + .map_err(|e| tracing::error!(error = %e, "failed to start signal worker")) 348 + .ok() 349 + } 350 + 351 + async fn worker_loop( 352 + mut manager: Manager, 353 + mut rx: mpsc::Receiver<SendRequest>, 354 + shutdown: CancellationToken, 355 + ) { 356 + loop { 357 + let req = tokio::select! { 358 + biased; 359 + _ = shutdown.cancelled() => { 360 + tracing::info!("signal worker shutting down (cancellation)"); 361 + break; 362 + } 363 + msg = rx.recv() => match msg { 364 + Some(r) => r, 365 + None => { 366 + tracing::info!("signal worker shutting down (channel closed)"); 367 + break; 368 + } 369 + }, 370 + }; 371 + let result = match tokio::time::timeout( 372 + SEND_TIMEOUT, 373 + Self::handle_send(&mut manager, &req.recipient, &req.message), 374 + ) 375 + .await 376 + { 377 + Ok(r) => r, 378 + Err(_) => { 379 + tracing::error!( 380 + recipient = %req.recipient, 381 + "signal send timed out after {}s", 382 + SEND_TIMEOUT.as_secs() 383 + ); 384 + Err(SignalError::Runtime(format!( 385 + "send timed out after {}s", 386 + SEND_TIMEOUT.as_secs() 387 + ))) 388 + } 389 + }; 390 + req.reply.send(result).ok(); 391 + } 392 + } 393 + 394 + async fn handle_send( 395 + manager: &mut Manager, 396 + recipient: &SignalUsername, 397 + message: &MessageBody, 398 + ) -> Result<(), SignalError> { 399 + let aci = manager 400 + .lookup_username(recipient.as_str()) 401 + .await 402 + .map_err(|e| SignalError::UsernameLookup(e.to_string()))? 403 + .ok_or_else(|| SignalError::UsernameNotFound(recipient.to_string()))?; 404 + 405 + let timestamp = SystemTime::now() 406 + .duration_since(UNIX_EPOCH) 407 + .map(|d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX)) 408 + .map_err(|_| SignalError::Runtime("system clock is before unix epoch".into()))?; 409 + 410 + let data_message = DataMessage { 411 + body: Some(message.as_str().to_string()), 412 + timestamp: Some(timestamp), 413 + ..Default::default() 414 + }; 415 + 416 + manager 417 + .send_message(aci, data_message, timestamp) 418 + .await 419 + .map_err(|e| SignalError::Presage(e.to_string())) 420 + } 421 + 422 + pub fn is_alive(&self) -> bool { 423 + !self.tx.is_closed() 424 + } 425 + 426 + pub async fn send( 427 + &self, 428 + recipient: &SignalUsername, 429 + message: MessageBody, 430 + ) -> Result<(), SignalError> { 431 + let (reply_tx, reply_rx) = oneshot::channel(); 432 + 433 + self.tx 434 + .send(SendRequest { 435 + recipient: recipient.clone(), 436 + message, 437 + reply: reply_tx, 438 + }) 439 + .await 440 + .map_err(|_| SignalError::Runtime("signal worker thread exited".into()))?; 441 + 442 + reply_rx 443 + .await 444 + .map_err(|_| SignalError::Runtime("signal worker dropped request".into()))? 445 + } 446 + 447 + pub async fn link_device( 448 + db: &PgPool, 449 + device_name: DeviceName, 450 + shutdown: CancellationToken, 451 + link_cancel: CancellationToken, 452 + linking_flag: Arc<AtomicBool>, 453 + ) -> Result<LinkResult, SignalError> { 454 + if linking_flag.swap(true, Ordering::AcqRel) { 455 + return Err(SignalError::Linking( 456 + "device linking already in progress".into(), 457 + )); 458 + } 459 + 460 + let store = PgSignalStore::new(db.clone()); 461 + let (url_tx, url_rx) = oneshot::channel::<Result<Url, SignalError>>(); 462 + let (done_tx, done_rx) = oneshot::channel::<Result<SignalClient, SignalError>>(); 463 + 464 + let guard_flag = linking_flag.clone(); 465 + let spawn_result = spawn_signal_thread("signal-link", move || { 466 + let _guard = LinkingGuard(guard_flag); 467 + signal_local_block_on(async { 468 + let (prov_tx, prov_rx) = futures::channel::oneshot::channel(); 469 + 470 + let link_future = presage::Manager::link_secondary_device( 471 + store, 472 + SignalServers::Production, 473 + device_name.into_inner(), 474 + prov_tx, 475 + ); 476 + 477 + let url_forward = async { 478 + match prov_rx.await { 479 + Ok(url) => { 480 + url_tx.send(Ok(url)).ok(); 481 + } 482 + Err(e) => { 483 + url_tx.send(Err(SignalError::Linking(e.to_string()))).ok(); 484 + } 485 + } 486 + }; 487 + 488 + let link_result = tokio::select! { 489 + biased; 490 + _ = link_cancel.cancelled() => { 491 + tracing::info!("signal device linking cancelled"); 492 + done_tx.send(Err(SignalError::LinkingCancelled)).ok(); 493 + return; 494 + } 495 + r = tokio::time::timeout(LINK_TIMEOUT, async { 496 + let (link_res, _) = 497 + futures::future::join(link_future, url_forward).await; 498 + link_res 499 + }) => r, 500 + }; 501 + 502 + match link_result { 503 + Ok(Ok(manager)) => { 504 + let client_result = SignalClient::from_manager(manager, shutdown); 505 + done_tx.send(client_result).ok(); 506 + } 507 + Ok(Err(e)) => { 508 + tracing::error!(error = %e, "signal device linking failed"); 509 + done_tx.send(Err(SignalError::Linking(e.to_string()))).ok(); 510 + } 511 + Err(_) => { 512 + tracing::error!( 513 + "signal device linking timed out after {}s", 514 + LINK_TIMEOUT.as_secs() 515 + ); 516 + done_tx.send(Err(SignalError::LinkingTimeout)).ok(); 517 + } 518 + } 519 + }); 520 + }); 521 + 522 + match spawn_result { 523 + Ok(_) => {} 524 + Err(e) => { 525 + linking_flag.store(false, Ordering::Release); 526 + return Err(SignalError::Runtime(format!( 527 + "failed to spawn link thread: {e}" 528 + ))); 529 + } 530 + } 531 + 532 + let url = url_rx 533 + .await 534 + .map_err(|_| SignalError::Runtime("signal link thread exited".into()))??; 535 + 536 + Ok(LinkResult { 537 + url, 538 + completion: done_rx, 539 + }) 540 + } 541 + }
+12
crates/tranquil-signal/src/lib.rs
··· 1 + mod client; 2 + pub mod store; 3 + 4 + #[cfg(test)] 5 + mod tests; 6 + 7 + pub use client::{ 8 + DeviceName, InvalidDeviceName, InvalidSignalUsername, LinkGeneration, LinkResult, MessageBody, 9 + MessageTooLong, SignalClient, SignalError, SignalSlot, SignalUsername, 10 + }; 11 + pub use presage; 12 + pub use store::PgSignalStore;

History

1 round 0 comments
sign up or login to add to the discussion
oyster.cafe submitted #0
1 commit
expand
feat(signal): add presage client, newtypes, and slot management
expand 0 comments
pull request successfully merged