+553
Diff
round #0
+541
crates/tranquil-signal/src/client.rs
+541
crates/tranquil-signal/src/client.rs
···
1
+
use std::fmt;
2
+
use std::panic::AssertUnwindSafe;
3
+
use std::sync::Arc;
4
+
use std::sync::atomic::{AtomicBool, Ordering};
5
+
use std::time::{Duration, SystemTime, UNIX_EPOCH};
6
+
7
+
use presage::libsignal_service::configuration::SignalServers;
8
+
use presage::manager::Registered;
9
+
use presage::proto::DataMessage;
10
+
use sqlx::PgPool;
11
+
use tokio::sync::{RwLock, mpsc, oneshot};
12
+
use tokio_util::sync::CancellationToken;
13
+
use url::Url;
14
+
15
+
use crate::store::PgSignalStore;
16
+
17
+
#[derive(Debug, Clone)]
18
+
pub struct SignalUsername(String);
19
+
20
+
#[derive(Debug, Clone)]
21
+
pub struct InvalidSignalUsername(String);
22
+
23
+
impl fmt::Display for InvalidSignalUsername {
24
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
25
+
write!(f, "invalid signal username: {}", self.0)
26
+
}
27
+
}
28
+
29
+
impl std::error::Error for InvalidSignalUsername {}
30
+
31
+
impl SignalUsername {
32
+
pub fn parse(username: &str) -> Result<Self, InvalidSignalUsername> {
33
+
let reject = || Err(InvalidSignalUsername(username.to_string()));
34
+
35
+
if username.len() < 6 || username.len() > 35 {
36
+
return reject();
37
+
}
38
+
39
+
let Some((base, discriminator)) = username.rsplit_once('.') else {
40
+
return reject();
41
+
};
42
+
43
+
if base.len() < 3 || base.len() > 32 {
44
+
return reject();
45
+
}
46
+
47
+
if !base.chars().next().is_some_and(|c| c.is_ascii_alphabetic()) {
48
+
return reject();
49
+
}
50
+
51
+
if !base.chars().all(|c| c.is_ascii_alphanumeric() || c == '_') {
52
+
return reject();
53
+
}
54
+
55
+
if discriminator.len() != 2 || !discriminator.chars().all(|c| c.is_ascii_digit()) {
56
+
return reject();
57
+
}
58
+
59
+
Ok(Self(username.to_string()))
60
+
}
61
+
62
+
pub fn as_str(&self) -> &str {
63
+
&self.0
64
+
}
65
+
}
66
+
67
+
impl fmt::Display for SignalUsername {
68
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
69
+
f.write_str(&self.0)
70
+
}
71
+
}
72
+
73
+
#[derive(Debug, Clone)]
74
+
pub struct DeviceName(String);
75
+
76
+
#[derive(Debug, Clone)]
77
+
pub struct InvalidDeviceName(String);
78
+
79
+
impl fmt::Display for InvalidDeviceName {
80
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
81
+
write!(f, "invalid device name: {}", self.0)
82
+
}
83
+
}
84
+
85
+
impl std::error::Error for InvalidDeviceName {}
86
+
87
+
impl DeviceName {
88
+
pub fn new(name: String) -> Result<Self, InvalidDeviceName> {
89
+
if name.is_empty() || name.len() > 50 || !name.is_ascii() {
90
+
return Err(InvalidDeviceName(name));
91
+
}
92
+
Ok(Self(name))
93
+
}
94
+
95
+
fn into_inner(self) -> String {
96
+
self.0
97
+
}
98
+
}
99
+
100
+
const LINK_TIMEOUT: Duration = Duration::from_secs(120);
101
+
const SEND_TIMEOUT: Duration = Duration::from_secs(60);
102
+
const MAX_MESSAGE_BYTES: usize = 2000;
103
+
104
+
#[derive(Debug, Clone)]
105
+
pub struct MessageBody(String);
106
+
107
+
#[derive(Debug, Clone)]
108
+
pub struct MessageTooLong {
109
+
pub len: usize,
110
+
pub max: usize,
111
+
}
112
+
113
+
impl fmt::Display for MessageTooLong {
114
+
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
115
+
write!(
116
+
f,
117
+
"message body too long: {} bytes (max {})",
118
+
self.len, self.max
119
+
)
120
+
}
121
+
}
122
+
123
+
impl std::error::Error for MessageTooLong {}
124
+
125
+
impl MessageBody {
126
+
pub fn new(body: String) -> Result<Self, MessageTooLong> {
127
+
let len = body.len();
128
+
if len > MAX_MESSAGE_BYTES {
129
+
return Err(MessageTooLong {
130
+
len,
131
+
max: MAX_MESSAGE_BYTES,
132
+
});
133
+
}
134
+
Ok(Self(body))
135
+
}
136
+
137
+
pub fn as_str(&self) -> &str {
138
+
&self.0
139
+
}
140
+
}
141
+
142
+
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
143
+
pub struct LinkGeneration(u64);
144
+
145
+
impl LinkGeneration {
146
+
fn next(self) -> Self {
147
+
Self(self.0.wrapping_add(1))
148
+
}
149
+
}
150
+
151
+
fn log_panic(thread_name: &str, payload: Box<dyn std::any::Any + Send>) {
152
+
let msg = payload
153
+
.downcast_ref::<&str>()
154
+
.copied()
155
+
.or_else(|| payload.downcast_ref::<String>().map(|s| s.as_str()))
156
+
.unwrap_or("unknown panic");
157
+
tracing::error!(thread = thread_name, panic = msg, "signal thread panicked");
158
+
}
159
+
160
+
fn spawn_signal_thread(
161
+
name: &'static str,
162
+
f: impl FnOnce() + Send + 'static,
163
+
) -> std::io::Result<std::thread::JoinHandle<()>> {
164
+
std::thread::Builder::new()
165
+
.name(name.into())
166
+
.spawn(move || {
167
+
if let Err(e) = std::panic::catch_unwind(AssertUnwindSafe(f)) {
168
+
log_panic(name, e);
169
+
}
170
+
})
171
+
}
172
+
173
+
fn signal_local_block_on(fut: impl std::future::Future<Output = ()>) {
174
+
let rt = tokio::runtime::Builder::new_current_thread()
175
+
.enable_all()
176
+
.build()
177
+
.expect("signal runtime");
178
+
let local = tokio::task::LocalSet::new();
179
+
local.block_on(&rt, fut);
180
+
}
181
+
182
+
struct LinkingGuard(Arc<AtomicBool>);
183
+
184
+
impl Drop for LinkingGuard {
185
+
fn drop(&mut self) {
186
+
self.0.store(false, Ordering::Release);
187
+
}
188
+
}
189
+
190
+
#[derive(Debug, thiserror::Error)]
191
+
pub enum SignalError {
192
+
#[error("store: {0}")]
193
+
Store(#[from] crate::store::PgStoreError),
194
+
#[error("presage: {0}")]
195
+
Presage(String),
196
+
#[error("username lookup failed: {0}")]
197
+
UsernameLookup(String),
198
+
#[error("username not found: {0}")]
199
+
UsernameNotFound(String),
200
+
#[error("linking: {0}")]
201
+
Linking(String),
202
+
#[error("linking timed out")]
203
+
LinkingTimeout,
204
+
#[error("linking cancelled")]
205
+
LinkingCancelled,
206
+
#[error("not linked")]
207
+
NotLinked,
208
+
#[error("runtime: {0}")]
209
+
Runtime(String),
210
+
}
211
+
212
+
type Manager = presage::Manager<PgSignalStore, Registered>;
213
+
214
+
struct SendRequest {
215
+
recipient: SignalUsername,
216
+
message: MessageBody,
217
+
reply: oneshot::Sender<Result<(), SignalError>>,
218
+
}
219
+
220
+
pub struct LinkResult {
221
+
pub url: Url,
222
+
pub completion: oneshot::Receiver<Result<SignalClient, SignalError>>,
223
+
}
224
+
225
+
pub struct SignalSlot {
226
+
state: RwLock<SlotState>,
227
+
linking_in_progress: Arc<AtomicBool>,
228
+
}
229
+
230
+
struct SlotState {
231
+
client: Option<SignalClient>,
232
+
generation: LinkGeneration,
233
+
link_cancel: Option<CancellationToken>,
234
+
}
235
+
236
+
impl Default for SignalSlot {
237
+
fn default() -> Self {
238
+
Self {
239
+
state: RwLock::new(SlotState {
240
+
client: None,
241
+
generation: LinkGeneration(0),
242
+
link_cancel: None,
243
+
}),
244
+
linking_in_progress: Arc::new(AtomicBool::new(false)),
245
+
}
246
+
}
247
+
}
248
+
249
+
impl SignalSlot {
250
+
pub async fn client(&self) -> Option<SignalClient> {
251
+
let client = self.state.read().await.client.clone()?;
252
+
if client.is_alive() {
253
+
Some(client)
254
+
} else {
255
+
tracing::warn!("signal worker exited unexpectedly, clearing client");
256
+
self.state.write().await.client = None;
257
+
None
258
+
}
259
+
}
260
+
261
+
pub async fn is_linked(&self) -> bool {
262
+
self.state
263
+
.read()
264
+
.await
265
+
.client
266
+
.as_ref()
267
+
.is_some_and(SignalClient::is_alive)
268
+
}
269
+
270
+
pub async fn set_client(&self, client: SignalClient) {
271
+
self.state.write().await.client = Some(client);
272
+
}
273
+
274
+
pub fn linking_flag(&self) -> Arc<AtomicBool> {
275
+
self.linking_in_progress.clone()
276
+
}
277
+
278
+
pub async fn begin_link(&self) -> (LinkGeneration, CancellationToken) {
279
+
let mut guard = self.state.write().await;
280
+
if let Some(old) = guard.link_cancel.take() {
281
+
old.cancel();
282
+
}
283
+
let token = CancellationToken::new();
284
+
guard.link_cancel = Some(token.clone());
285
+
(guard.generation, token)
286
+
}
287
+
288
+
pub async fn complete_link(&self, generation: LinkGeneration, client: SignalClient) -> bool {
289
+
let mut guard = self.state.write().await;
290
+
if guard.generation != generation || guard.client.is_some() {
291
+
return false;
292
+
}
293
+
guard.client = Some(client);
294
+
guard.link_cancel = None;
295
+
true
296
+
}
297
+
298
+
pub async fn unlink(&self) {
299
+
let mut guard = self.state.write().await;
300
+
guard.client = None;
301
+
guard.generation = guard.generation.next();
302
+
if let Some(cancel) = guard.link_cancel.take() {
303
+
cancel.cancel();
304
+
}
305
+
}
306
+
}
307
+
308
+
#[derive(Clone)]
309
+
pub struct SignalClient {
310
+
tx: mpsc::Sender<SendRequest>,
311
+
}
312
+
313
+
impl SignalClient {
314
+
fn from_manager(manager: Manager, shutdown: CancellationToken) -> Result<Self, SignalError> {
315
+
let (tx, rx) = mpsc::channel::<SendRequest>(64);
316
+
317
+
spawn_signal_thread("signal-worker", move || {
318
+
signal_local_block_on(Self::worker_loop(manager, rx, shutdown));
319
+
})
320
+
.map_err(|e| SignalError::Runtime(format!("failed to spawn signal worker: {e}")))?;
321
+
322
+
Ok(Self { tx })
323
+
}
324
+
325
+
pub async fn from_pool(db: &PgPool, shutdown: CancellationToken) -> Option<Self> {
326
+
let store = PgSignalStore::new(db.clone());
327
+
let (init_tx, init_rx) = oneshot::channel();
328
+
329
+
spawn_signal_thread("signal-init", move || {
330
+
signal_local_block_on(async {
331
+
let result = presage::Manager::load_registered(store).await;
332
+
init_tx
333
+
.send(result.map_err(|e| SignalError::Presage(e.to_string())))
334
+
.ok();
335
+
});
336
+
})
337
+
.map_err(|e| tracing::error!(error = %e, "failed to spawn signal init thread"))
338
+
.ok()?;
339
+
340
+
let manager = init_rx
341
+
.await
342
+
.ok()?
343
+
.map_err(|e| tracing::error!(error = %e, "failed to load registered signal manager"))
344
+
.ok()?;
345
+
346
+
Self::from_manager(manager, shutdown)
347
+
.map_err(|e| tracing::error!(error = %e, "failed to start signal worker"))
348
+
.ok()
349
+
}
350
+
351
+
async fn worker_loop(
352
+
mut manager: Manager,
353
+
mut rx: mpsc::Receiver<SendRequest>,
354
+
shutdown: CancellationToken,
355
+
) {
356
+
loop {
357
+
let req = tokio::select! {
358
+
biased;
359
+
_ = shutdown.cancelled() => {
360
+
tracing::info!("signal worker shutting down (cancellation)");
361
+
break;
362
+
}
363
+
msg = rx.recv() => match msg {
364
+
Some(r) => r,
365
+
None => {
366
+
tracing::info!("signal worker shutting down (channel closed)");
367
+
break;
368
+
}
369
+
},
370
+
};
371
+
let result = match tokio::time::timeout(
372
+
SEND_TIMEOUT,
373
+
Self::handle_send(&mut manager, &req.recipient, &req.message),
374
+
)
375
+
.await
376
+
{
377
+
Ok(r) => r,
378
+
Err(_) => {
379
+
tracing::error!(
380
+
recipient = %req.recipient,
381
+
"signal send timed out after {}s",
382
+
SEND_TIMEOUT.as_secs()
383
+
);
384
+
Err(SignalError::Runtime(format!(
385
+
"send timed out after {}s",
386
+
SEND_TIMEOUT.as_secs()
387
+
)))
388
+
}
389
+
};
390
+
req.reply.send(result).ok();
391
+
}
392
+
}
393
+
394
+
async fn handle_send(
395
+
manager: &mut Manager,
396
+
recipient: &SignalUsername,
397
+
message: &MessageBody,
398
+
) -> Result<(), SignalError> {
399
+
let aci = manager
400
+
.lookup_username(recipient.as_str())
401
+
.await
402
+
.map_err(|e| SignalError::UsernameLookup(e.to_string()))?
403
+
.ok_or_else(|| SignalError::UsernameNotFound(recipient.to_string()))?;
404
+
405
+
let timestamp = SystemTime::now()
406
+
.duration_since(UNIX_EPOCH)
407
+
.map(|d| u64::try_from(d.as_millis()).unwrap_or(u64::MAX))
408
+
.map_err(|_| SignalError::Runtime("system clock is before unix epoch".into()))?;
409
+
410
+
let data_message = DataMessage {
411
+
body: Some(message.as_str().to_string()),
412
+
timestamp: Some(timestamp),
413
+
..Default::default()
414
+
};
415
+
416
+
manager
417
+
.send_message(aci, data_message, timestamp)
418
+
.await
419
+
.map_err(|e| SignalError::Presage(e.to_string()))
420
+
}
421
+
422
+
pub fn is_alive(&self) -> bool {
423
+
!self.tx.is_closed()
424
+
}
425
+
426
+
pub async fn send(
427
+
&self,
428
+
recipient: &SignalUsername,
429
+
message: MessageBody,
430
+
) -> Result<(), SignalError> {
431
+
let (reply_tx, reply_rx) = oneshot::channel();
432
+
433
+
self.tx
434
+
.send(SendRequest {
435
+
recipient: recipient.clone(),
436
+
message,
437
+
reply: reply_tx,
438
+
})
439
+
.await
440
+
.map_err(|_| SignalError::Runtime("signal worker thread exited".into()))?;
441
+
442
+
reply_rx
443
+
.await
444
+
.map_err(|_| SignalError::Runtime("signal worker dropped request".into()))?
445
+
}
446
+
447
+
pub async fn link_device(
448
+
db: &PgPool,
449
+
device_name: DeviceName,
450
+
shutdown: CancellationToken,
451
+
link_cancel: CancellationToken,
452
+
linking_flag: Arc<AtomicBool>,
453
+
) -> Result<LinkResult, SignalError> {
454
+
if linking_flag.swap(true, Ordering::AcqRel) {
455
+
return Err(SignalError::Linking(
456
+
"device linking already in progress".into(),
457
+
));
458
+
}
459
+
460
+
let store = PgSignalStore::new(db.clone());
461
+
let (url_tx, url_rx) = oneshot::channel::<Result<Url, SignalError>>();
462
+
let (done_tx, done_rx) = oneshot::channel::<Result<SignalClient, SignalError>>();
463
+
464
+
let guard_flag = linking_flag.clone();
465
+
let spawn_result = spawn_signal_thread("signal-link", move || {
466
+
let _guard = LinkingGuard(guard_flag);
467
+
signal_local_block_on(async {
468
+
let (prov_tx, prov_rx) = futures::channel::oneshot::channel();
469
+
470
+
let link_future = presage::Manager::link_secondary_device(
471
+
store,
472
+
SignalServers::Production,
473
+
device_name.into_inner(),
474
+
prov_tx,
475
+
);
476
+
477
+
let url_forward = async {
478
+
match prov_rx.await {
479
+
Ok(url) => {
480
+
url_tx.send(Ok(url)).ok();
481
+
}
482
+
Err(e) => {
483
+
url_tx.send(Err(SignalError::Linking(e.to_string()))).ok();
484
+
}
485
+
}
486
+
};
487
+
488
+
let link_result = tokio::select! {
489
+
biased;
490
+
_ = link_cancel.cancelled() => {
491
+
tracing::info!("signal device linking cancelled");
492
+
done_tx.send(Err(SignalError::LinkingCancelled)).ok();
493
+
return;
494
+
}
495
+
r = tokio::time::timeout(LINK_TIMEOUT, async {
496
+
let (link_res, _) =
497
+
futures::future::join(link_future, url_forward).await;
498
+
link_res
499
+
}) => r,
500
+
};
501
+
502
+
match link_result {
503
+
Ok(Ok(manager)) => {
504
+
let client_result = SignalClient::from_manager(manager, shutdown);
505
+
done_tx.send(client_result).ok();
506
+
}
507
+
Ok(Err(e)) => {
508
+
tracing::error!(error = %e, "signal device linking failed");
509
+
done_tx.send(Err(SignalError::Linking(e.to_string()))).ok();
510
+
}
511
+
Err(_) => {
512
+
tracing::error!(
513
+
"signal device linking timed out after {}s",
514
+
LINK_TIMEOUT.as_secs()
515
+
);
516
+
done_tx.send(Err(SignalError::LinkingTimeout)).ok();
517
+
}
518
+
}
519
+
});
520
+
});
521
+
522
+
match spawn_result {
523
+
Ok(_) => {}
524
+
Err(e) => {
525
+
linking_flag.store(false, Ordering::Release);
526
+
return Err(SignalError::Runtime(format!(
527
+
"failed to spawn link thread: {e}"
528
+
)));
529
+
}
530
+
}
531
+
532
+
let url = url_rx
533
+
.await
534
+
.map_err(|_| SignalError::Runtime("signal link thread exited".into()))??;
535
+
536
+
Ok(LinkResult {
537
+
url,
538
+
completion: done_rx,
539
+
})
540
+
}
541
+
}
+12
crates/tranquil-signal/src/lib.rs
+12
crates/tranquil-signal/src/lib.rs
···
1
+
mod client;
2
+
pub mod store;
3
+
4
+
#[cfg(test)]
5
+
mod tests;
6
+
7
+
pub use client::{
8
+
DeviceName, InvalidDeviceName, InvalidSignalUsername, LinkGeneration, LinkResult, MessageBody,
9
+
MessageTooLong, SignalClient, SignalError, SignalSlot, SignalUsername,
10
+
};
11
+
pub use presage;
12
+
pub use store::PgSignalStore;
History
1 round
0 comments
oyster.cafe
submitted
#0
1 commit
expand
collapse
feat(signal): add presage client, newtypes, and slot management
expand 0 comments
pull request successfully merged