+466
Diff
round #0
+466
crates/tranquil-signal/src/tests.rs
+466
crates/tranquil-signal/src/tests.rs
···
1
+
use presage::libsignal_service::{
2
+
pre_keys::{KyberPreKeyStoreExt, PreKeysStore},
3
+
prelude::{ProfileKey, SessionStoreExt},
4
+
protocol::{
5
+
DeviceId, Direction, GenericSignedPreKey, IdentityKeyPair, IdentityKeyStore, KeyPair,
6
+
KyberPreKeyId, KyberPreKeyRecord, KyberPreKeyStore, PreKeyId, PreKeyRecord, PreKeyStore,
7
+
ProtocolAddress, SenderKeyStore, ServiceId, SessionRecord, SessionStore, SignedPreKeyId,
8
+
SignedPreKeyRecord, SignedPreKeyStore, Timestamp,
9
+
},
10
+
};
11
+
use presage::store::{ContentsStore, StateStore, Store};
12
+
use sqlx::postgres::PgPoolOptions;
13
+
use uuid::Uuid;
14
+
15
+
use crate::store::{IdentityType, PgProtocolStore, PgSignalStore};
16
+
17
+
async fn test_store() -> PgSignalStore {
18
+
let url = std::env::var("DATABASE_URL")
19
+
.unwrap_or_else(|_| "postgres://postgres:postgres@127.0.0.1:5432/postgres".into());
20
+
21
+
let pool = PgPoolOptions::new()
22
+
.max_connections(5)
23
+
.connect(&url)
24
+
.await
25
+
.unwrap();
26
+
27
+
sqlx::query("DELETE FROM signal_kv")
28
+
.execute(&pool)
29
+
.await
30
+
.ok();
31
+
sqlx::query("DELETE FROM signal_base_keys_seen")
32
+
.execute(&pool)
33
+
.await
34
+
.ok();
35
+
sqlx::query("DELETE FROM signal_sender_keys")
36
+
.execute(&pool)
37
+
.await
38
+
.ok();
39
+
sqlx::query("DELETE FROM signal_sessions")
40
+
.execute(&pool)
41
+
.await
42
+
.ok();
43
+
sqlx::query("DELETE FROM signal_identities")
44
+
.execute(&pool)
45
+
.await
46
+
.ok();
47
+
sqlx::query("DELETE FROM signal_kyber_pre_keys")
48
+
.execute(&pool)
49
+
.await
50
+
.ok();
51
+
sqlx::query("DELETE FROM signal_signed_pre_keys")
52
+
.execute(&pool)
53
+
.await
54
+
.ok();
55
+
sqlx::query("DELETE FROM signal_pre_keys")
56
+
.execute(&pool)
57
+
.await
58
+
.ok();
59
+
sqlx::query("DELETE FROM signal_profile_keys")
60
+
.execute(&pool)
61
+
.await
62
+
.ok();
63
+
64
+
PgSignalStore::new(pool)
65
+
}
66
+
67
+
fn protocol_store(store: &PgSignalStore, identity: IdentityType) -> PgProtocolStore {
68
+
PgProtocolStore::new(store.clone(), identity)
69
+
}
70
+
71
+
#[tokio::test]
72
+
async fn state_store_registration_empty() {
73
+
let store = test_store().await;
74
+
75
+
assert!(store.load_registration_data().await.unwrap().is_none());
76
+
assert!(!store.is_registered().await);
77
+
}
78
+
79
+
#[tokio::test]
80
+
async fn state_store_kv_roundtrip() {
81
+
let store = test_store().await;
82
+
83
+
let value = b"test-data".to_vec();
84
+
sqlx::query("INSERT INTO signal_kv (key, value) VALUES ('test_key', $1)")
85
+
.bind(&value)
86
+
.execute(&store.db)
87
+
.await
88
+
.unwrap();
89
+
90
+
let loaded: Vec<u8> = sqlx::query_scalar("SELECT value FROM signal_kv WHERE key = 'test_key'")
91
+
.fetch_one(&store.db)
92
+
.await
93
+
.unwrap();
94
+
assert_eq!(loaded, value);
95
+
}
96
+
97
+
#[tokio::test]
98
+
async fn state_store_identity_keypairs() {
99
+
let store = test_store().await;
100
+
101
+
let aci_pair = IdentityKeyPair::generate(&mut rand::rng());
102
+
let pni_pair = IdentityKeyPair::generate(&mut rand::rng());
103
+
104
+
store.set_aci_identity_key_pair(aci_pair).await.unwrap();
105
+
store.set_pni_identity_key_pair(pni_pair).await.unwrap();
106
+
107
+
let aci_store = protocol_store(&store, IdentityType::Aci);
108
+
let pni_store = protocol_store(&store, IdentityType::Pni);
109
+
110
+
let loaded_aci = aci_store.get_identity_key_pair().await.unwrap();
111
+
let loaded_pni = pni_store.get_identity_key_pair().await.unwrap();
112
+
113
+
assert_eq!(loaded_aci.serialize(), aci_pair.serialize());
114
+
assert_eq!(loaded_pni.serialize(), pni_pair.serialize());
115
+
}
116
+
117
+
#[tokio::test]
118
+
async fn state_store_sender_certificate_roundtrip() {
119
+
let store = test_store().await;
120
+
assert!(store.sender_certificate().await.unwrap().is_none());
121
+
}
122
+
123
+
#[tokio::test]
124
+
async fn state_store_clear_registration() {
125
+
let mut store = test_store().await;
126
+
127
+
sqlx::query("INSERT INTO signal_kv (key, value) VALUES ('registration', $1)")
128
+
.bind(b"dummy-data".as_slice())
129
+
.execute(&store.db)
130
+
.await
131
+
.unwrap();
132
+
133
+
let mut ps = protocol_store(&store, IdentityType::Aci);
134
+
let keypair = KeyPair::generate(&mut rand::rng());
135
+
let record = PreKeyRecord::new(PreKeyId::from(1u32), &keypair);
136
+
ps.save_pre_key(PreKeyId::from(1u32), &record)
137
+
.await
138
+
.unwrap();
139
+
140
+
store.clear_registration().await.unwrap();
141
+
142
+
let remaining: Option<Vec<u8>> =
143
+
sqlx::query_scalar("SELECT value FROM signal_kv WHERE key = 'registration'")
144
+
.fetch_optional(&store.db)
145
+
.await
146
+
.unwrap();
147
+
assert!(remaining.is_none());
148
+
149
+
assert!(ps.get_pre_key(PreKeyId::from(1u32)).await.is_err());
150
+
}
151
+
152
+
#[tokio::test]
153
+
async fn session_store_crud() {
154
+
let store = test_store().await;
155
+
let mut ps = protocol_store(&store, IdentityType::Aci);
156
+
157
+
let addr = ProtocolAddress::new("test-uuid".into(), DeviceId::new(1).unwrap());
158
+
assert!(ps.load_session(&addr).await.unwrap().is_none());
159
+
160
+
let record = SessionRecord::new_fresh();
161
+
ps.store_session(&addr, &record).await.unwrap();
162
+
163
+
let loaded = ps.load_session(&addr).await.unwrap();
164
+
assert!(loaded.is_some());
165
+
166
+
ps.store_session(&addr, &record).await.unwrap();
167
+
let loaded2 = ps.load_session(&addr).await.unwrap();
168
+
assert!(loaded2.is_some());
169
+
}
170
+
171
+
#[tokio::test]
172
+
async fn session_store_sub_devices() {
173
+
let store = test_store().await;
174
+
let mut ps = protocol_store(&store, IdentityType::Aci);
175
+
176
+
let uuid = Uuid::new_v4();
177
+
let service_id: ServiceId = presage::libsignal_service::protocol::Aci::from(uuid).into();
178
+
let addr1 = ProtocolAddress::new(uuid.to_string(), DeviceId::new(1).unwrap());
179
+
let addr2 = ProtocolAddress::new(uuid.to_string(), DeviceId::new(2).unwrap());
180
+
let addr3 = ProtocolAddress::new(uuid.to_string(), DeviceId::new(3).unwrap());
181
+
182
+
let record = SessionRecord::new_fresh();
183
+
ps.store_session(&addr1, &record).await.unwrap();
184
+
ps.store_session(&addr2, &record).await.unwrap();
185
+
ps.store_session(&addr3, &record).await.unwrap();
186
+
187
+
let sub_devices = ps.get_sub_device_sessions(&service_id).await.unwrap();
188
+
assert_eq!(sub_devices.len(), 2);
189
+
190
+
let deleted = ps.delete_all_sessions(&service_id).await.unwrap();
191
+
assert_eq!(deleted, 3);
192
+
193
+
let sub_devices = ps.get_sub_device_sessions(&service_id).await.unwrap();
194
+
assert!(sub_devices.is_empty());
195
+
}
196
+
197
+
#[tokio::test]
198
+
async fn pre_key_store_crud() {
199
+
let store = test_store().await;
200
+
let mut ps = protocol_store(&store, IdentityType::Aci);
201
+
202
+
let keypair = KeyPair::generate(&mut rand::rng());
203
+
let id = PreKeyId::from(42u32);
204
+
let record = PreKeyRecord::new(id, &keypair);
205
+
206
+
ps.save_pre_key(id, &record).await.unwrap();
207
+
let loaded = ps.get_pre_key(id).await.unwrap();
208
+
assert_eq!(loaded.serialize().unwrap(), record.serialize().unwrap());
209
+
210
+
ps.remove_pre_key(id).await.unwrap();
211
+
assert!(ps.get_pre_key(id).await.is_err());
212
+
}
213
+
214
+
#[tokio::test]
215
+
async fn pre_key_store_next_ids() {
216
+
let store = test_store().await;
217
+
let mut ps = protocol_store(&store, IdentityType::Aci);
218
+
219
+
assert_eq!(ps.next_pre_key_id().await.unwrap(), 1);
220
+
221
+
let keypair = KeyPair::generate(&mut rand::rng());
222
+
let record = PreKeyRecord::new(PreKeyId::from(5u32), &keypair);
223
+
ps.save_pre_key(PreKeyId::from(5u32), &record)
224
+
.await
225
+
.unwrap();
226
+
227
+
assert_eq!(ps.next_pre_key_id().await.unwrap(), 6);
228
+
}
229
+
230
+
#[tokio::test]
231
+
async fn signed_pre_key_store_crud() {
232
+
let store = test_store().await;
233
+
let mut ps = protocol_store(&store, IdentityType::Aci);
234
+
235
+
let keypair = KeyPair::generate(&mut rand::rng());
236
+
let id = SignedPreKeyId::from(1u32);
237
+
let signature = keypair
238
+
.private_key
239
+
.calculate_signature(&keypair.public_key.serialize(), &mut rand::rng())
240
+
.unwrap();
241
+
let record =
242
+
SignedPreKeyRecord::new(id, Timestamp::from_epoch_millis(1000), &keypair, &signature);
243
+
244
+
ps.save_signed_pre_key(id, &record).await.unwrap();
245
+
let loaded = ps.get_signed_pre_key(id).await.unwrap();
246
+
assert_eq!(loaded.serialize().unwrap(), record.serialize().unwrap());
247
+
248
+
assert_eq!(ps.signed_pre_keys_count().await.unwrap(), 1);
249
+
assert_eq!(ps.next_signed_pre_key_id().await.unwrap(), 2);
250
+
}
251
+
252
+
#[tokio::test]
253
+
async fn kyber_pre_key_one_time_mark_used_deletes() {
254
+
let store = test_store().await;
255
+
let mut ps = protocol_store(&store, IdentityType::Aci);
256
+
257
+
let keypair = KeyPair::generate(&mut rand::rng());
258
+
let id = KyberPreKeyId::from(1u32);
259
+
let record = KyberPreKeyRecord::generate(
260
+
presage::libsignal_service::protocol::kem::KeyType::Kyber1024,
261
+
id,
262
+
&keypair.private_key,
263
+
)
264
+
.unwrap();
265
+
266
+
ps.save_kyber_pre_key(id, &record).await.unwrap();
267
+
assert!(ps.get_kyber_pre_key(id).await.is_ok());
268
+
269
+
let ec_prekey_id = SignedPreKeyId::from(1u32);
270
+
ps.mark_kyber_pre_key_used(id, ec_prekey_id, &keypair.public_key)
271
+
.await
272
+
.unwrap();
273
+
274
+
assert!(ps.get_kyber_pre_key(id).await.is_err());
275
+
}
276
+
277
+
#[tokio::test]
278
+
async fn kyber_pre_key_last_resort_survives_mark_used() {
279
+
let store = test_store().await;
280
+
let mut ps = protocol_store(&store, IdentityType::Aci);
281
+
282
+
let keypair = KeyPair::generate(&mut rand::rng());
283
+
let id = KyberPreKeyId::from(1u32);
284
+
let record = KyberPreKeyRecord::generate(
285
+
presage::libsignal_service::protocol::kem::KeyType::Kyber1024,
286
+
id,
287
+
&keypair.private_key,
288
+
)
289
+
.unwrap();
290
+
291
+
ps.store_last_resort_kyber_pre_key(id, &record)
292
+
.await
293
+
.unwrap();
294
+
assert!(ps.get_kyber_pre_key(id).await.is_ok());
295
+
296
+
let ec_prekey_id = SignedPreKeyId::from(1u32);
297
+
ps.mark_kyber_pre_key_used(id, ec_prekey_id, &keypair.public_key)
298
+
.await
299
+
.unwrap();
300
+
301
+
assert!(ps.get_kyber_pre_key(id).await.is_ok());
302
+
}
303
+
304
+
#[tokio::test]
305
+
async fn kyber_pre_key_last_resort_rejects_replayed_base_key() {
306
+
let store = test_store().await;
307
+
let mut ps = protocol_store(&store, IdentityType::Aci);
308
+
309
+
let keypair = KeyPair::generate(&mut rand::rng());
310
+
let id = KyberPreKeyId::from(1u32);
311
+
let record = KyberPreKeyRecord::generate(
312
+
presage::libsignal_service::protocol::kem::KeyType::Kyber1024,
313
+
id,
314
+
&keypair.private_key,
315
+
)
316
+
.unwrap();
317
+
318
+
ps.store_last_resort_kyber_pre_key(id, &record)
319
+
.await
320
+
.unwrap();
321
+
322
+
let ec_prekey_id = SignedPreKeyId::from(1u32);
323
+
ps.mark_kyber_pre_key_used(id, ec_prekey_id, &keypair.public_key)
324
+
.await
325
+
.unwrap();
326
+
327
+
let replay_result = ps
328
+
.mark_kyber_pre_key_used(id, ec_prekey_id, &keypair.public_key)
329
+
.await;
330
+
assert!(replay_result.is_err());
331
+
}
332
+
333
+
#[tokio::test]
334
+
async fn kyber_pre_key_last_resort_list() {
335
+
let store = test_store().await;
336
+
let mut ps = protocol_store(&store, IdentityType::Aci);
337
+
338
+
let keypair = KeyPair::generate(&mut rand::rng());
339
+
let id = KyberPreKeyId::from(1u32);
340
+
let record = KyberPreKeyRecord::generate(
341
+
presage::libsignal_service::protocol::kem::KeyType::Kyber1024,
342
+
id,
343
+
&keypair.private_key,
344
+
)
345
+
.unwrap();
346
+
347
+
assert!(
348
+
ps.load_last_resort_kyber_pre_keys()
349
+
.await
350
+
.unwrap()
351
+
.is_empty()
352
+
);
353
+
354
+
ps.store_last_resort_kyber_pre_key(id, &record)
355
+
.await
356
+
.unwrap();
357
+
358
+
let last_resorts = ps.load_last_resort_kyber_pre_keys().await.unwrap();
359
+
assert_eq!(last_resorts.len(), 1);
360
+
}
361
+
362
+
#[tokio::test]
363
+
async fn identity_store_crud() {
364
+
let store = test_store().await;
365
+
let mut ps = protocol_store(&store, IdentityType::Aci);
366
+
367
+
let addr = ProtocolAddress::new("test-addr".into(), DeviceId::new(1).unwrap());
368
+
let keypair = IdentityKeyPair::generate(&mut rand::rng());
369
+
let identity_key = keypair.identity_key();
370
+
371
+
assert!(ps.get_identity(&addr).await.unwrap().is_none());
372
+
373
+
ps.save_identity(&addr, identity_key).await.unwrap();
374
+
let loaded = ps.get_identity(&addr).await.unwrap().unwrap();
375
+
assert_eq!(loaded.serialize(), identity_key.serialize());
376
+
377
+
assert!(
378
+
ps.is_trusted_identity(&addr, identity_key, Direction::Receiving)
379
+
.await
380
+
.unwrap()
381
+
);
382
+
}
383
+
384
+
#[tokio::test]
385
+
async fn identity_store_aci_pni_isolation() {
386
+
let store = test_store().await;
387
+
let mut aci_store = protocol_store(&store, IdentityType::Aci);
388
+
let pni_store = protocol_store(&store, IdentityType::Pni);
389
+
390
+
let addr = ProtocolAddress::new("same-addr".into(), DeviceId::new(1).unwrap());
391
+
let keypair = IdentityKeyPair::generate(&mut rand::rng());
392
+
393
+
aci_store
394
+
.save_identity(&addr, keypair.identity_key())
395
+
.await
396
+
.unwrap();
397
+
398
+
assert!(aci_store.get_identity(&addr).await.unwrap().is_some());
399
+
assert!(pni_store.get_identity(&addr).await.unwrap().is_none());
400
+
}
401
+
402
+
#[tokio::test]
403
+
async fn sender_key_store_load_missing() {
404
+
let store = test_store().await;
405
+
let mut ps = protocol_store(&store, IdentityType::Aci);
406
+
407
+
let sender = ProtocolAddress::new("sender-uuid".into(), DeviceId::new(1).unwrap());
408
+
let dist_id = Uuid::new_v4();
409
+
410
+
assert!(
411
+
ps.load_sender_key(&sender, dist_id)
412
+
.await
413
+
.unwrap()
414
+
.is_none()
415
+
);
416
+
}
417
+
418
+
#[tokio::test]
419
+
async fn profile_key_store_roundtrip() {
420
+
let mut store = test_store().await;
421
+
422
+
let uuid = Uuid::new_v4();
423
+
let service_id: ServiceId = presage::libsignal_service::protocol::Aci::from(uuid).into();
424
+
let key = ProfileKey { bytes: [42u8; 32] };
425
+
426
+
assert!(store.profile_key(&service_id).await.unwrap().is_none());
427
+
428
+
store.upsert_profile_key(&uuid, key).await.unwrap();
429
+
430
+
let loaded = store.profile_key(&service_id).await.unwrap().unwrap();
431
+
assert_eq!(loaded.bytes, key.bytes);
432
+
}
433
+
434
+
#[tokio::test]
435
+
async fn client_from_pool_returns_none_without_registration() {
436
+
let store = test_store().await;
437
+
let pool = store.db.clone();
438
+
439
+
let client =
440
+
crate::SignalClient::from_pool(&pool, tokio_util::sync::CancellationToken::new()).await;
441
+
assert!(client.is_none());
442
+
}
443
+
444
+
#[tokio::test]
445
+
async fn store_clear_removes_kv() {
446
+
let mut store = test_store().await;
447
+
448
+
store
449
+
.set_aci_identity_key_pair(IdentityKeyPair::generate(&mut rand::rng()))
450
+
.await
451
+
.unwrap();
452
+
453
+
sqlx::query("INSERT INTO signal_kv (key, value) VALUES ('registration', $1)")
454
+
.bind(b"dummy".as_slice())
455
+
.execute(&store.db)
456
+
.await
457
+
.unwrap();
458
+
459
+
store.clear().await.unwrap();
460
+
461
+
let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM signal_kv")
462
+
.fetch_one(&store.db)
463
+
.await
464
+
.unwrap();
465
+
assert_eq!(count, 0);
466
+
}
History
1 round
0 comments
oyster.cafe
submitted
#0
1 commit
expand
collapse
test(signal): add protocol store integration tests
expand 0 comments
pull request successfully merged