That fuck shit the fascists are using
1/*
2 * Copyright 2023 Signal Messenger, LLC
3 * SPDX-License-Identifier: AGPL-3.0-only
4 */
5
6package org.tm.archive.database
7
8import junit.framework.TestCase.assertEquals
9import junit.framework.TestCase.assertNotNull
10import junit.framework.TestCase.assertNull
11import org.junit.Test
12import org.signal.core.util.readToSingleObject
13import org.signal.core.util.requireLongOrNull
14import org.signal.core.util.select
15import org.signal.core.util.update
16import org.signal.libsignal.protocol.ecc.Curve
17import org.signal.libsignal.protocol.state.PreKeyRecord
18import org.whispersystems.signalservice.api.push.ServiceId
19import org.whispersystems.signalservice.api.push.ServiceId.ACI
20import org.whispersystems.signalservice.api.push.ServiceId.PNI
21import java.util.UUID
22
/**
 * Tests for the stale-timestamp handling exposed through [SignalDatabase.oneTimePreKeys]:
 * `markAllStaleIfNecessary` and `deleteAllStaleBefore`.
 *
 * Each test seeds rows through [insertTestRecord] and then inspects the stored stale timestamp
 * with [getStaleTime].
 */
class OneTimePreKeyTableTest {

  private val aci: ACI = ACI.from(UUID.randomUUID())
  private val pni: PNI = PNI.from(UUID.randomUUID())

  /**
   * Expects only the ACI rows whose stale timestamp is still 0 to receive the new timestamp;
   * the row with an existing value (42) and the PNI row must be left alone.
   */
  @Test
  fun markAllStaleIfNecessary_onlyUpdatesMatchingAccountAndZeroValues() {
    listOf(1, 2).forEach { insertTestRecord(aci, id = it) }
    insertTestRecord(aci, id = 3, staleTime = 42)
    insertTestRecord(pni, id = 4)

    val markedAt = System.currentTimeMillis()
    SignalDatabase.oneTimePreKeys.markAllStaleIfNecessary(aci, markedAt)

    listOf(1, 2).forEach { assertEquals(markedAt, getStaleTime(aci, it)) }
    assertEquals(42L, getStaleTime(aci, 3))
    assertEquals(0L, getStaleTime(pni, 4))
  }

  /**
   * Expects rows with a stale timestamp below the threshold to be gone, while the row above
   * the threshold and the row with a stale timestamp of 0 remain.
   */
  @Test
  fun deleteAllStaleBefore_deleteOldBeforeThreshold() {
    (1..3).forEach { insertTestRecord(aci, id = it, staleTime = 10) }
    insertTestRecord(aci, id = 4, staleTime = 15)
    insertTestRecord(aci, id = 5, staleTime = 0)

    SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)

    assertRemoved(aci, 1..3)
    assertRetained(aci, 4..5)
  }

  /** Expects rows with a stale timestamp of 0 to remain even though 0 is below the threshold. */
  @Test
  fun deleteAllStaleBefore_neverDeleteStaleOfZero() {
    (1..5).forEach { insertTestRecord(aci, id = it, staleTime = 0) }

    SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 10, minCount = 0)

    assertRetained(aci, 1..5)
  }

  /** Expects three of the five stale rows to remain when `minCount` is 3. */
  @Test
  fun deleteAllStaleBefore_respectMinCount() {
    (1..5).forEach { insertTestRecord(aci, id = it, staleTime = 10) }

    SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)

    assertRemoved(aci, 1..2)
    assertRetained(aci, 3..5)
  }

  /** Expects the cleanup for the ACI account to leave the PNI account's rows in place. */
  @Test
  fun deleteAllStaleBefore_respectAccount() {
    (1..3).forEach { insertTestRecord(aci, id = it, staleTime = 10) }
    (4..5).forEach { insertTestRecord(pni, id = it, staleTime = 10) }

    SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 2)

    assertRemoved(aci, 1..1)
    assertRetained(aci, 2..3)
    assertRetained(pni, 4..5)
  }

  /** Asserts, in ascending id order, that [getStaleTime] yields `null` for every id in [ids]. */
  private fun assertRemoved(account: ServiceId, ids: IntRange) {
    for (keyId in ids) {
      assertNull(getStaleTime(account, keyId))
    }
  }

  /** Asserts, in ascending id order, that [getStaleTime] yields a non-null value for every id in [ids]. */
  private fun assertRetained(account: ServiceId, ids: IntRange) {
    for (keyId in ids) {
      assertNotNull(getStaleTime(account, keyId))
    }
  }

  /**
   * Builds the WHERE clause shared by the helpers below: the account id is left as a bind
   * placeholder and the key id is written into the clause directly.
   */
  private fun rowFilter(id: Int): String {
    return "${OneTimePreKeyTable.ACCOUNT_ID} = ? AND ${OneTimePreKeyTable.KEY_ID} = $id"
  }

  /**
   * Inserts a freshly generated pre-key for [account] under [id], then writes [staleTime] into
   * the row's stale-timestamp column with a raw update, asserting exactly one row was touched.
   */
  private fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0) {
    val preKey = PreKeyRecord(id, Curve.generateKeyPair())
    SignalDatabase.oneTimePreKeys.insert(serviceId = account, keyId = id, record = preKey)

    val updatedRows = SignalDatabase.rawDatabase
      .update(OneTimePreKeyTable.TABLE_NAME)
      .values(OneTimePreKeyTable.STALE_TIMESTAMP to staleTime)
      .where(rowFilter(id), account)
      .run()

    assertEquals(1, updatedRows)
  }

  /**
   * Reads the stale-timestamp column for the row identified by [account] and [id].
   *
   * The tests in this class treat a `null` result as the row no longer being present.
   */
  private fun getStaleTime(account: ServiceId, id: Int): Long? {
    val cursor = SignalDatabase.rawDatabase
      .select(OneTimePreKeyTable.STALE_TIMESTAMP)
      .from(OneTimePreKeyTable.TABLE_NAME)
      .where(rowFilter(id), account)
      .run()

    return cursor.readToSingleObject { row -> row.requireLongOrNull(OneTimePreKeyTable.STALE_TIMESTAMP) }
  }
}