Our Personal Data Server from scratch! tranquil.farm
atproto pds rust postgresql fun oauth

test(signal): add protocol store integration tests #90

merged opened by oyster.cafe targeting main from feat/signal-client-in-house
Labels

None yet.

assignee

None yet.

Participants 1
AT URI
at://did:plc:3fwecdnvtcscjnrx2p4n7alz/sh.tangled.repo.pull/3mhlxir6epv22
+466
Diff #0
+466
crates/tranquil-signal/src/tests.rs
··· 1 + use presage::libsignal_service::{ 2 + pre_keys::{KyberPreKeyStoreExt, PreKeysStore}, 3 + prelude::{ProfileKey, SessionStoreExt}, 4 + protocol::{ 5 + DeviceId, Direction, GenericSignedPreKey, IdentityKeyPair, IdentityKeyStore, KeyPair, 6 + KyberPreKeyId, KyberPreKeyRecord, KyberPreKeyStore, PreKeyId, PreKeyRecord, PreKeyStore, 7 + ProtocolAddress, SenderKeyStore, ServiceId, SessionRecord, SessionStore, SignedPreKeyId, 8 + SignedPreKeyRecord, SignedPreKeyStore, Timestamp, 9 + }, 10 + }; 11 + use presage::store::{ContentsStore, StateStore, Store}; 12 + use sqlx::postgres::PgPoolOptions; 13 + use uuid::Uuid; 14 + 15 + use crate::store::{IdentityType, PgProtocolStore, PgSignalStore}; 16 + 17 + async fn test_store() -> PgSignalStore { 18 + let url = std::env::var("DATABASE_URL") 19 + .unwrap_or_else(|_| "postgres://postgres:postgres@127.0.0.1:5432/postgres".into()); 20 + 21 + let pool = PgPoolOptions::new() 22 + .max_connections(5) 23 + .connect(&url) 24 + .await 25 + .unwrap(); 26 + 27 + sqlx::query("DELETE FROM signal_kv") 28 + .execute(&pool) 29 + .await 30 + .ok(); 31 + sqlx::query("DELETE FROM signal_base_keys_seen") 32 + .execute(&pool) 33 + .await 34 + .ok(); 35 + sqlx::query("DELETE FROM signal_sender_keys") 36 + .execute(&pool) 37 + .await 38 + .ok(); 39 + sqlx::query("DELETE FROM signal_sessions") 40 + .execute(&pool) 41 + .await 42 + .ok(); 43 + sqlx::query("DELETE FROM signal_identities") 44 + .execute(&pool) 45 + .await 46 + .ok(); 47 + sqlx::query("DELETE FROM signal_kyber_pre_keys") 48 + .execute(&pool) 49 + .await 50 + .ok(); 51 + sqlx::query("DELETE FROM signal_signed_pre_keys") 52 + .execute(&pool) 53 + .await 54 + .ok(); 55 + sqlx::query("DELETE FROM signal_pre_keys") 56 + .execute(&pool) 57 + .await 58 + .ok(); 59 + sqlx::query("DELETE FROM signal_profile_keys") 60 + .execute(&pool) 61 + .await 62 + .ok(); 63 + 64 + PgSignalStore::new(pool) 65 + } 66 + 67 + fn protocol_store(store: &PgSignalStore, identity: IdentityType) -> PgProtocolStore { 68 + PgProtocolStore::new(store.clone(), identity) 69 + } 70 + 71 + #[tokio::test] 72 + async fn state_store_registration_empty() { 73 + let store = test_store().await; 74 + 75 + assert!(store.load_registration_data().await.unwrap().is_none()); 76 + assert!(!store.is_registered().await); 77 + } 78 + 79 + #[tokio::test] 80 + async fn state_store_kv_roundtrip() { 81 + let store = test_store().await; 82 + 83 + let value = b"test-data".to_vec(); 84 + sqlx::query("INSERT INTO signal_kv (key, value) VALUES ('test_key', $1)") 85 + .bind(&value) 86 + .execute(&store.db) 87 + .await 88 + .unwrap(); 89 + 90 + let loaded: Vec<u8> = sqlx::query_scalar("SELECT value FROM signal_kv WHERE key = 'test_key'") 91 + .fetch_one(&store.db) 92 + .await 93 + .unwrap(); 94 + assert_eq!(loaded, value); 95 + } 96 + 97 + #[tokio::test] 98 + async fn state_store_identity_keypairs() { 99 + let store = test_store().await; 100 + 101 + let aci_pair = IdentityKeyPair::generate(&mut rand::rng()); 102 + let pni_pair = IdentityKeyPair::generate(&mut rand::rng()); 103 + 104 + store.set_aci_identity_key_pair(aci_pair).await.unwrap(); 105 + store.set_pni_identity_key_pair(pni_pair).await.unwrap(); 106 + 107 + let aci_store = protocol_store(&store, IdentityType::Aci); 108 + let pni_store = protocol_store(&store, IdentityType::Pni); 109 + 110 + let loaded_aci = aci_store.get_identity_key_pair().await.unwrap(); 111 + let loaded_pni = pni_store.get_identity_key_pair().await.unwrap(); 112 + 113 + assert_eq!(loaded_aci.serialize(), aci_pair.serialize()); 114 + assert_eq!(loaded_pni.serialize(), pni_pair.serialize()); 115 + } 116 + 117 + #[tokio::test] 118 + async fn state_store_sender_certificate_roundtrip() { 119 + let store = test_store().await; 120 + assert!(store.sender_certificate().await.unwrap().is_none()); 121 + } 122 + 123 + #[tokio::test] 124 + async fn state_store_clear_registration() { 125 + let mut store = test_store().await; 126 + 127 + sqlx::query("INSERT INTO signal_kv (key, value) VALUES ('registration', $1)") 128 + .bind(b"dummy-data".as_slice()) 129 + .execute(&store.db) 130 + .await 131 + .unwrap(); 132 + 133 + let mut ps = protocol_store(&store, IdentityType::Aci); 134 + let keypair = KeyPair::generate(&mut rand::rng()); 135 + let record = PreKeyRecord::new(PreKeyId::from(1u32), &keypair); 136 + ps.save_pre_key(PreKeyId::from(1u32), &record) 137 + .await 138 + .unwrap(); 139 + 140 + store.clear_registration().await.unwrap(); 141 + 142 + let remaining: Option<Vec<u8>> = 143 + sqlx::query_scalar("SELECT value FROM signal_kv WHERE key = 'registration'") 144 + .fetch_optional(&store.db) 145 + .await 146 + .unwrap(); 147 + assert!(remaining.is_none()); 148 + 149 + assert!(ps.get_pre_key(PreKeyId::from(1u32)).await.is_err()); 150 + } 151 + 152 + #[tokio::test] 153 + async fn session_store_crud() { 154 + let store = test_store().await; 155 + let mut ps = protocol_store(&store, IdentityType::Aci); 156 + 157 + let addr = ProtocolAddress::new("test-uuid".into(), DeviceId::new(1).unwrap()); 158 + assert!(ps.load_session(&addr).await.unwrap().is_none()); 159 + 160 + let record = SessionRecord::new_fresh(); 161 + ps.store_session(&addr, &record).await.unwrap(); 162 + 163 + let loaded = ps.load_session(&addr).await.unwrap(); 164 + assert!(loaded.is_some()); 165 + 166 + ps.store_session(&addr, &record).await.unwrap(); 167 + let loaded2 = ps.load_session(&addr).await.unwrap(); 168 + assert!(loaded2.is_some()); 169 + } 170 + 171 + #[tokio::test] 172 + async fn session_store_sub_devices() { 173 + let store = test_store().await; 174 + let mut ps = protocol_store(&store, IdentityType::Aci); 175 + 176 + let uuid = Uuid::new_v4(); 177 + let service_id: ServiceId = presage::libsignal_service::protocol::Aci::from(uuid).into(); 178 + let addr1 = ProtocolAddress::new(uuid.to_string(), DeviceId::new(1).unwrap()); 179 + let addr2 = ProtocolAddress::new(uuid.to_string(), DeviceId::new(2).unwrap()); 180 + let addr3 = ProtocolAddress::new(uuid.to_string(), DeviceId::new(3).unwrap()); 181 + 182 + let record = SessionRecord::new_fresh(); 183 + ps.store_session(&addr1, &record).await.unwrap(); 184 + ps.store_session(&addr2, &record).await.unwrap(); 185 + ps.store_session(&addr3, &record).await.unwrap(); 186 + 187 + let sub_devices = ps.get_sub_device_sessions(&service_id).await.unwrap(); 188 + assert_eq!(sub_devices.len(), 2); 189 + 190 + let deleted = ps.delete_all_sessions(&service_id).await.unwrap(); 191 + assert_eq!(deleted, 3); 192 + 193 + let sub_devices = ps.get_sub_device_sessions(&service_id).await.unwrap(); 194 + assert!(sub_devices.is_empty()); 195 + } 196 + 197 + #[tokio::test] 198 + async fn pre_key_store_crud() { 199 + let store = test_store().await; 200 + let mut ps = protocol_store(&store, IdentityType::Aci); 201 + 202 + let keypair = KeyPair::generate(&mut rand::rng()); 203 + let id = PreKeyId::from(42u32); 204 + let record = PreKeyRecord::new(id, &keypair); 205 + 206 + ps.save_pre_key(id, &record).await.unwrap(); 207 + let loaded = ps.get_pre_key(id).await.unwrap(); 208 + assert_eq!(loaded.serialize().unwrap(), record.serialize().unwrap()); 209 + 210 + ps.remove_pre_key(id).await.unwrap(); 211 + assert!(ps.get_pre_key(id).await.is_err()); 212 + } 213 + 214 + #[tokio::test] 215 + async fn pre_key_store_next_ids() { 216 + let store = test_store().await; 217 + let mut ps = protocol_store(&store, IdentityType::Aci); 218 + 219 + assert_eq!(ps.next_pre_key_id().await.unwrap(), 1); 220 + 221 + let keypair = KeyPair::generate(&mut rand::rng()); 222 + let record = PreKeyRecord::new(PreKeyId::from(5u32), &keypair); 223 + ps.save_pre_key(PreKeyId::from(5u32), &record) 224 + .await 225 + .unwrap(); 226 + 227 + assert_eq!(ps.next_pre_key_id().await.unwrap(), 6); 228 + } 229 + 230 + #[tokio::test] 231 + async fn signed_pre_key_store_crud() { 232 + let store = test_store().await; 233 + let mut ps = protocol_store(&store, IdentityType::Aci); 234 + 235 + let keypair = KeyPair::generate(&mut rand::rng()); 236 + let id = SignedPreKeyId::from(1u32); 237 + let signature = keypair 238 + .private_key 239 + .calculate_signature(&keypair.public_key.serialize(), &mut rand::rng()) 240 + .unwrap(); 241 + let record = 242 + SignedPreKeyRecord::new(id, Timestamp::from_epoch_millis(1000), &keypair, &signature); 243 + 244 + ps.save_signed_pre_key(id, &record).await.unwrap(); 245 + let loaded = ps.get_signed_pre_key(id).await.unwrap(); 246 + assert_eq!(loaded.serialize().unwrap(), record.serialize().unwrap()); 247 + 248 + assert_eq!(ps.signed_pre_keys_count().await.unwrap(), 1); 249 + assert_eq!(ps.next_signed_pre_key_id().await.unwrap(), 2); 250 + } 251 + 252 + #[tokio::test] 253 + async fn kyber_pre_key_one_time_mark_used_deletes() { 254 + let store = test_store().await; 255 + let mut ps = protocol_store(&store, IdentityType::Aci); 256 + 257 + let keypair = KeyPair::generate(&mut rand::rng()); 258 + let id = KyberPreKeyId::from(1u32); 259 + let record = KyberPreKeyRecord::generate( 260 + presage::libsignal_service::protocol::kem::KeyType::Kyber1024, 261 + id, 262 + &keypair.private_key, 263 + ) 264 + .unwrap(); 265 + 266 + ps.save_kyber_pre_key(id, &record).await.unwrap(); 267 + assert!(ps.get_kyber_pre_key(id).await.is_ok()); 268 + 269 + let ec_prekey_id = SignedPreKeyId::from(1u32); 270 + ps.mark_kyber_pre_key_used(id, ec_prekey_id, &keypair.public_key) 271 + .await 272 + .unwrap(); 273 + 274 + assert!(ps.get_kyber_pre_key(id).await.is_err()); 275 + } 276 + 277 + #[tokio::test] 278 + async fn kyber_pre_key_last_resort_survives_mark_used() { 279 + let store = test_store().await; 280 + let mut ps = protocol_store(&store, IdentityType::Aci); 281 + 282 + let keypair = KeyPair::generate(&mut rand::rng()); 283 + let id = KyberPreKeyId::from(1u32); 284 + let record = KyberPreKeyRecord::generate( 285 + presage::libsignal_service::protocol::kem::KeyType::Kyber1024, 286 + id, 287 + &keypair.private_key, 288 + ) 289 + .unwrap(); 290 + 291 + ps.store_last_resort_kyber_pre_key(id, &record) 292 + .await 293 + .unwrap(); 294 + assert!(ps.get_kyber_pre_key(id).await.is_ok()); 295 + 296 + let ec_prekey_id = SignedPreKeyId::from(1u32); 297 + ps.mark_kyber_pre_key_used(id, ec_prekey_id, &keypair.public_key) 298 + .await 299 + .unwrap(); 300 + 301 + assert!(ps.get_kyber_pre_key(id).await.is_ok()); 302 + } 303 + 304 + #[tokio::test] 305 + async fn kyber_pre_key_last_resort_rejects_replayed_base_key() { 306 + let store = test_store().await; 307 + let mut ps = protocol_store(&store, IdentityType::Aci); 308 + 309 + let keypair = KeyPair::generate(&mut rand::rng()); 310 + let id = KyberPreKeyId::from(1u32); 311 + let record = KyberPreKeyRecord::generate( 312 + presage::libsignal_service::protocol::kem::KeyType::Kyber1024, 313 + id, 314 + &keypair.private_key, 315 + ) 316 + .unwrap(); 317 + 318 + ps.store_last_resort_kyber_pre_key(id, &record) 319 + .await 320 + .unwrap(); 321 + 322 + let ec_prekey_id = SignedPreKeyId::from(1u32); 323 + ps.mark_kyber_pre_key_used(id, ec_prekey_id, &keypair.public_key) 324 + .await 325 + .unwrap(); 326 + 327 + let replay_result = ps 328 + .mark_kyber_pre_key_used(id, ec_prekey_id, &keypair.public_key) 329 + .await; 330 + assert!(replay_result.is_err()); 331 + } 332 + 333 + #[tokio::test] 334 + async fn kyber_pre_key_last_resort_list() { 335 + let store = test_store().await; 336 + let mut ps = protocol_store(&store, IdentityType::Aci); 337 + 338 + let keypair = KeyPair::generate(&mut rand::rng()); 339 + let id = KyberPreKeyId::from(1u32); 340 + let record = KyberPreKeyRecord::generate( 341 + presage::libsignal_service::protocol::kem::KeyType::Kyber1024, 342 + id, 343 + &keypair.private_key, 344 + ) 345 + .unwrap(); 346 + 347 + assert!( 348 + ps.load_last_resort_kyber_pre_keys() 349 + .await 350 + .unwrap() 351 + .is_empty() 352 + ); 353 + 354 + ps.store_last_resort_kyber_pre_key(id, &record) 355 + .await 356 + .unwrap(); 357 + 358 + let last_resorts = ps.load_last_resort_kyber_pre_keys().await.unwrap(); 359 + assert_eq!(last_resorts.len(), 1); 360 + } 361 + 362 + #[tokio::test] 363 + async fn identity_store_crud() { 364 + let store = test_store().await; 365 + let mut ps = protocol_store(&store, IdentityType::Aci); 366 + 367 + let addr = ProtocolAddress::new("test-addr".into(), DeviceId::new(1).unwrap()); 368 + let keypair = IdentityKeyPair::generate(&mut rand::rng()); 369 + let identity_key = keypair.identity_key(); 370 + 371 + assert!(ps.get_identity(&addr).await.unwrap().is_none()); 372 + 373 + ps.save_identity(&addr, identity_key).await.unwrap(); 374 + let loaded = ps.get_identity(&addr).await.unwrap().unwrap(); 375 + assert_eq!(loaded.serialize(), identity_key.serialize()); 376 + 377 + assert!( 378 + ps.is_trusted_identity(&addr, identity_key, Direction::Receiving) 379 + .await 380 + .unwrap() 381 + ); 382 + } 383 + 384 + #[tokio::test] 385 + async fn identity_store_aci_pni_isolation() { 386 + let store = test_store().await; 387 + let mut aci_store = protocol_store(&store, IdentityType::Aci); 388 + let pni_store = protocol_store(&store, IdentityType::Pni); 389 + 390 + let addr = ProtocolAddress::new("same-addr".into(), DeviceId::new(1).unwrap()); 391 + let keypair = IdentityKeyPair::generate(&mut rand::rng()); 392 + 393 + aci_store 394 + .save_identity(&addr, keypair.identity_key()) 395 + .await 396 + .unwrap(); 397 + 398 + assert!(aci_store.get_identity(&addr).await.unwrap().is_some()); 399 + assert!(pni_store.get_identity(&addr).await.unwrap().is_none()); 400 + } 401 + 402 + #[tokio::test] 403 + async fn sender_key_store_load_missing() { 404 + let store = test_store().await; 405 + let mut ps = protocol_store(&store, IdentityType::Aci); 406 + 407 + let sender = ProtocolAddress::new("sender-uuid".into(), DeviceId::new(1).unwrap()); 408 + let dist_id = Uuid::new_v4(); 409 + 410 + assert!( 411 + ps.load_sender_key(&sender, dist_id) 412 + .await 413 + .unwrap() 414 + .is_none() 415 + ); 416 + } 417 + 418 + #[tokio::test] 419 + async fn profile_key_store_roundtrip() { 420 + let mut store = test_store().await; 421 + 422 + let uuid = Uuid::new_v4(); 423 + let service_id: ServiceId = presage::libsignal_service::protocol::Aci::from(uuid).into(); 424 + let key = ProfileKey { bytes: [42u8; 32] }; 425 + 426 + assert!(store.profile_key(&service_id).await.unwrap().is_none()); 427 + 428 + store.upsert_profile_key(&uuid, key).await.unwrap(); 429 + 430 + let loaded = store.profile_key(&service_id).await.unwrap().unwrap(); 431 + assert_eq!(loaded.bytes, key.bytes); 432 + } 433 + 434 + #[tokio::test] 435 + async fn client_from_pool_returns_none_without_registration() { 436 + let store = test_store().await; 437 + let pool = store.db.clone(); 438 + 439 + let client = 440 + crate::SignalClient::from_pool(&pool, tokio_util::sync::CancellationToken::new()).await; 441 + assert!(client.is_none()); 442 + } 443 + 444 + #[tokio::test] 445 + async fn store_clear_removes_kv() { 446 + let mut store = test_store().await; 447 + 448 + store 449 + .set_aci_identity_key_pair(IdentityKeyPair::generate(&mut rand::rng())) 450 + .await 451 + .unwrap(); 452 + 453 + sqlx::query("INSERT INTO signal_kv (key, value) VALUES ('registration', $1)") 454 + .bind(b"dummy".as_slice()) 455 + .execute(&store.db) 456 + .await 457 + .unwrap(); 458 + 459 + store.clear().await.unwrap(); 460 + 461 + let count: i64 = sqlx::query_scalar("SELECT COUNT(*) FROM signal_kv") 462 + .fetch_one(&store.db) 463 + .await 464 + .unwrap(); 465 + assert_eq!(count, 0); 466 + }

History

1 round 0 comments
sign up or login to add to the discussion
oyster.cafe submitted #0
1 commit
expand
test(signal): add protocol store integration tests
expand 0 comments
pull request successfully merged