That fuck shit the fascists are using
1/*
2 * Copyright 2023 Signal Messenger, LLC
3 * SPDX-License-Identifier: AGPL-3.0-only
4 */
5
6package org.tm.archive.database
7
8import junit.framework.TestCase.assertEquals
9import junit.framework.TestCase.assertNotNull
10import junit.framework.TestCase.assertNull
11import org.junit.Test
12import org.signal.core.util.readToSingleObject
13import org.signal.core.util.requireLongOrNull
14import org.signal.core.util.select
15import org.signal.core.util.update
16import org.signal.libsignal.protocol.ecc.Curve
17import org.signal.libsignal.protocol.kem.KEMKeyPair
18import org.signal.libsignal.protocol.kem.KEMKeyType
19import org.signal.libsignal.protocol.state.KyberPreKeyRecord
20import org.whispersystems.signalservice.api.push.ServiceId
21import org.whispersystems.signalservice.api.push.ServiceId.ACI
22import org.whispersystems.signalservice.api.push.ServiceId.PNI
23import java.util.UUID
24
/**
 * Tests for the stale-key bookkeeping of [KyberPreKeyTable]:
 * [KyberPreKeyTable.markAllStaleIfNecessary] and [KyberPreKeyTable.deleteAllStaleBefore].
 *
 * The [aci] and [pni] accounts are random UUIDs created per test-class instance, so each test
 * works against its own account ids. Every test in this class must exercise
 * `SignalDatabase.kyberPreKeys`, since that is the table [insertTestRecord] writes to.
 */
class KyberPreKeyTableTest {

  private val aci: ACI = ACI.from(UUID.randomUUID())
  private val pni: PNI = PNI.from(UUID.randomUUID())

  /** Only rows belonging to the given account whose stale timestamp is 0 get stamped with `now`. */
  @Test
  fun markAllStaleIfNecessary_onlyUpdatesMatchingAccountAndZeroValues() {
    insertTestRecord(aci, id = 1)
    insertTestRecord(aci, id = 2)
    insertTestRecord(aci, id = 3, staleTime = 42)
    insertTestRecord(pni, id = 4)

    val now = System.currentTimeMillis()
    SignalDatabase.kyberPreKeys.markAllStaleIfNecessary(aci, now)

    assertEquals(now, getStaleTime(aci, 1))
    assertEquals(now, getStaleTime(aci, 2))
    // Already had a non-zero stale timestamp: it must be preserved, not overwritten.
    assertEquals(42L, getStaleTime(aci, 3))
    // Belongs to a different account: must be untouched.
    assertEquals(0L, getStaleTime(pni, 4))
  }

  /** Rows stale strictly before the threshold are deleted; later-stale and never-stale rows remain. */
  @Test
  fun deleteAllStaleBefore_deleteOldBeforeThreshold() {
    insertTestRecord(aci, id = 1, staleTime = 10)
    insertTestRecord(aci, id = 2, staleTime = 10)
    insertTestRecord(aci, id = 3, staleTime = 10)
    insertTestRecord(aci, id = 4, staleTime = 15)
    insertTestRecord(aci, id = 5, staleTime = 0)

    SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)

    assertNull(getStaleTime(aci, 1))
    assertNull(getStaleTime(aci, 2))
    assertNull(getStaleTime(aci, 3))
    assertNotNull(getStaleTime(aci, 4))
    assertNotNull(getStaleTime(aci, 5))
  }

  /** A stale timestamp of 0 means "not stale", so such rows are never deleted even though 0 < threshold. */
  @Test
  fun deleteAllStaleBefore_neverDeleteStaleOfZero() {
    insertTestRecord(aci, id = 1, staleTime = 0)
    insertTestRecord(aci, id = 2, staleTime = 0)
    insertTestRecord(aci, id = 3, staleTime = 0)
    insertTestRecord(aci, id = 4, staleTime = 0)
    insertTestRecord(aci, id = 5, staleTime = 0)

    SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 10, minCount = 1)

    assertNotNull(getStaleTime(aci, 1))
    assertNotNull(getStaleTime(aci, 2))
    assertNotNull(getStaleTime(aci, 3))
    assertNotNull(getStaleTime(aci, 4))
    assertNotNull(getStaleTime(aci, 5))
  }

  /** Even when every row qualifies for deletion, at least `minCount` rows are kept. */
  @Test
  fun deleteAllStaleBefore_respectMinCount() {
    insertTestRecord(aci, id = 1, staleTime = 10)
    insertTestRecord(aci, id = 2, staleTime = 10)
    insertTestRecord(aci, id = 3, staleTime = 10)
    insertTestRecord(aci, id = 4, staleTime = 10)
    insertTestRecord(aci, id = 5, staleTime = 10)

    SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)

    // 5 candidates with minCount = 3: two are deleted, and this test expects them to be the lowest ids.
    assertNull(getStaleTime(aci, 1))
    assertNull(getStaleTime(aci, 2))
    assertNotNull(getStaleTime(aci, 3))
    assertNotNull(getStaleTime(aci, 4))
    assertNotNull(getStaleTime(aci, 5))
  }

  /** Deletion and the `minCount` budget are scoped to the given account only. */
  @Test
  fun deleteAllStaleBefore_respectAccount() {
    insertTestRecord(aci, id = 1, staleTime = 10)
    insertTestRecord(aci, id = 2, staleTime = 10)
    insertTestRecord(aci, id = 3, staleTime = 10)

    insertTestRecord(pni, id = 4, staleTime = 10)
    insertTestRecord(pni, id = 5, staleTime = 10)

    SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 2)

    // Only the three ACI rows count toward minCount = 2, so exactly one ACI row goes away.
    assertNull(getStaleTime(aci, 1))
    assertNotNull(getStaleTime(aci, 2))
    assertNotNull(getStaleTime(aci, 3))
    assertNotNull(getStaleTime(pni, 4))
    assertNotNull(getStaleTime(pni, 5))
  }

  /** Last-resort keys do not count toward the `minCount` rows that must be retained. */
  @Test
  fun deleteAllStaleBefore_ignoreLastResortForMinCount() {
    insertTestRecord(aci, id = 1, staleTime = 10)
    insertTestRecord(aci, id = 2, staleTime = 10)
    insertTestRecord(aci, id = 3, staleTime = 10)
    insertTestRecord(aci, id = 4, staleTime = 10)
    insertTestRecord(aci, id = 5, staleTime = 10, lastResort = true)

    SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)

    // Four non-last-resort candidates with minCount = 3: only one is deleted; the last-resort key survives.
    assertNull(getStaleTime(aci, 1))
    assertNotNull(getStaleTime(aci, 2))
    assertNotNull(getStaleTime(aci, 3))
    assertNotNull(getStaleTime(aci, 4))
    assertNotNull(getStaleTime(aci, 5))
  }

  /** Last-resort keys are never deleted, even when stale and `minCount` is 0. */
  @Test
  fun deleteAllStaleBefore_neverDeleteLastResort() {
    insertTestRecord(aci, id = 1, staleTime = 10, lastResort = true)
    insertTestRecord(aci, id = 2, staleTime = 10, lastResort = true)
    insertTestRecord(aci, id = 3, staleTime = 10, lastResort = true)

    // Must target the Kyber table: insertTestRecord writes to kyberPreKeys, so deleting from any
    // other table would leave these rows untouched and make the assertions below pass vacuously.
    SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)

    assertNotNull(getStaleTime(aci, 1))
    assertNotNull(getStaleTime(aci, 2))
    assertNotNull(getStaleTime(aci, 3))
  }

  /**
   * Inserts a freshly generated Kyber-1024 pre-key for [account] with key id [id] into
   * `SignalDatabase.kyberPreKeys`. It then overwrites the row's stale timestamp with [staleTime].
   *
   * Fails the test unless exactly one row was updated by the stale-timestamp write.
   */
  private fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0, lastResort: Boolean = false) {
    val kemKeyPair = KEMKeyPair.generate(KEMKeyType.KYBER_1024)
    SignalDatabase.kyberPreKeys.insert(
      serviceId = account,
      keyId = id,
      record = KyberPreKeyRecord(
        id,
        System.currentTimeMillis(),
        kemKeyPair,
        Curve.generateKeyPair().privateKey.calculateSignature(kemKeyPair.publicKey.serialize())
      ),
      lastResort = lastResort
    )

    // insert() has no stale-timestamp parameter, so set it directly on the row just written.
    val count = SignalDatabase.rawDatabase
      .update(KyberPreKeyTable.TABLE_NAME)
      .values(KyberPreKeyTable.STALE_TIMESTAMP to staleTime)
      .where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account)
      .run()

    assertEquals(1, count)
  }

  /**
   * Reads the stale timestamp of the Kyber pre-key row for ([account], [id]).
   *
   * @return the stored timestamp, or `null` when no such row exists (the tests rely on this to
   *   detect deleted rows).
   */
  private fun getStaleTime(account: ServiceId, id: Int): Long? {
    return SignalDatabase.rawDatabase
      .select(KyberPreKeyTable.STALE_TIMESTAMP)
      .from(KyberPreKeyTable.TABLE_NAME)
      .where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account)
      .run()
      .readToSingleObject { it.requireLongOrNull(KyberPreKeyTable.STALE_TIMESTAMP) }
  }
}