That fuck shit the fascists are using
at master 176 lines 6.2 kB view raw
/*
 * Copyright 2023 Signal Messenger, LLC
 * SPDX-License-Identifier: AGPL-3.0-only
 */

package org.tm.archive.database

import junit.framework.TestCase.assertEquals
import junit.framework.TestCase.assertNotNull
import junit.framework.TestCase.assertNull
import org.junit.Test
import org.signal.core.util.readToSingleObject
import org.signal.core.util.requireLongOrNull
import org.signal.core.util.select
import org.signal.core.util.update
import org.signal.libsignal.protocol.ecc.Curve
import org.signal.libsignal.protocol.kem.KEMKeyPair
import org.signal.libsignal.protocol.kem.KEMKeyType
import org.signal.libsignal.protocol.state.KyberPreKeyRecord
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.ServiceId.ACI
import org.whispersystems.signalservice.api.push.ServiceId.PNI
import java.util.UUID

/**
 * Tests the stale-key bookkeeping of the Kyber pre-key table (`SignalDatabase.kyberPreKeys`):
 * `markAllStaleIfNecessary` and `deleteAllStaleBefore`.
 *
 * Each test writes rows through [insertTestRecord] and inspects them through [getStaleTime].
 * Both helpers address rows by (account, key id) in [KyberPreKeyTable.TABLE_NAME].
 */
class KyberPreKeyTableTest {

  // JUnit 4 creates a new instance of this class per test method, so every test gets fresh random
  // accounts. Rows written by other tests therefore never match these accounts, even if the
  // database is shared between tests.
  private val aci: ACI = ACI.from(UUID.randomUUID())
  private val pni: PNI = PNI.from(UUID.randomUUID())

  /**
   * Only rows that belong to the given account AND have a stale time of 0 get stamped with `now`.
   * Rows that are already stale (id 3) and rows of another account (id 4) stay unchanged.
   */
  @Test
  fun markAllStaleIfNecessary_onlyUpdatesMatchingAccountAndZeroValues() {
    insertTestRecord(aci, id = 1)
    insertTestRecord(aci, id = 2)
    insertTestRecord(aci, id = 3, staleTime = 42)
    insertTestRecord(pni, id = 4)

    val now = System.currentTimeMillis()
    SignalDatabase.kyberPreKeys.markAllStaleIfNecessary(aci, now)

    assertEquals(now, getStaleTime(aci, 1))
    assertEquals(now, getStaleTime(aci, 2))
    assertEquals(42L, getStaleTime(aci, 3))
    assertEquals(0L, getStaleTime(pni, 4))
  }

  /**
   * With `minCount = 0`, every row whose stale time is non-zero and below the threshold is deleted.
   * Rows at or above the threshold (id 4), and rows that were never marked stale (id 5), survive.
   */
  @Test
  fun deleteAllStaleBefore_deleteOldBeforeThreshold() {
    insertTestRecord(aci, id = 1, staleTime = 10)
    insertTestRecord(aci, id = 2, staleTime = 10)
    insertTestRecord(aci, id = 3, staleTime = 10)
    insertTestRecord(aci, id = 4, staleTime = 15)
    insertTestRecord(aci, id = 5, staleTime = 0)

    SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)

    assertNull(getStaleTime(aci, 1))
    assertNull(getStaleTime(aci, 2))
    assertNull(getStaleTime(aci, 3))
    assertNotNull(getStaleTime(aci, 4))
    assertNotNull(getStaleTime(aci, 5))
  }

  /** A stale time of 0 means "not stale", so such rows are never deleted, even though 0 < threshold. */
  @Test
  fun deleteAllStaleBefore_neverDeleteStaleOfZero() {
    insertTestRecord(aci, id = 1, staleTime = 0)
    insertTestRecord(aci, id = 2, staleTime = 0)
    insertTestRecord(aci, id = 3, staleTime = 0)
    insertTestRecord(aci, id = 4, staleTime = 0)
    insertTestRecord(aci, id = 5, staleTime = 0)

    SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 10, minCount = 1)

    assertNotNull(getStaleTime(aci, 1))
    assertNotNull(getStaleTime(aci, 2))
    assertNotNull(getStaleTime(aci, 3))
    assertNotNull(getStaleTime(aci, 4))
    assertNotNull(getStaleTime(aci, 5))
  }

  /**
   * All five rows are eligible for deletion, but `minCount = 3` requires three to remain.
   * Only ids 1 and 2 are expected to be removed.
   */
  @Test
  fun deleteAllStaleBefore_respectMinCount() {
    insertTestRecord(aci, id = 1, staleTime = 10)
    insertTestRecord(aci, id = 2, staleTime = 10)
    insertTestRecord(aci, id = 3, staleTime = 10)
    insertTestRecord(aci, id = 4, staleTime = 10)
    insertTestRecord(aci, id = 5, staleTime = 10)

    SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)

    assertNull(getStaleTime(aci, 1))
    assertNull(getStaleTime(aci, 2))
    assertNotNull(getStaleTime(aci, 3))
    assertNotNull(getStaleTime(aci, 4))
    assertNotNull(getStaleTime(aci, 5))
  }

  /**
   * Deletion and the `minCount` floor are scoped to the given account.
   * Of three stale ACI rows, only one may go (`minCount = 2`), and the PNI rows are untouched.
   */
  @Test
  fun deleteAllStaleBefore_respectAccount() {
    insertTestRecord(aci, id = 1, staleTime = 10)
    insertTestRecord(aci, id = 2, staleTime = 10)
    insertTestRecord(aci, id = 3, staleTime = 10)

    insertTestRecord(pni, id = 4, staleTime = 10)
    insertTestRecord(pni, id = 5, staleTime = 10)

    SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 2)

    assertNull(getStaleTime(aci, 1))
    assertNotNull(getStaleTime(aci, 2))
    assertNotNull(getStaleTime(aci, 3))
    assertNotNull(getStaleTime(pni, 4))
    assertNotNull(getStaleTime(pni, 5))
  }

  /**
   * The last-resort key (id 5) does not count toward `minCount`.
   *
   * If it did, four survivors would already satisfy `minCount = 3` and two rows could go.
   * Instead, three non-last-resort rows must remain, so only id 1 is deleted.
   */
  @Test
  fun deleteAllStaleBefore_ignoreLastResortForMinCount() {
    insertTestRecord(aci, id = 1, staleTime = 10)
    insertTestRecord(aci, id = 2, staleTime = 10)
    insertTestRecord(aci, id = 3, staleTime = 10)
    insertTestRecord(aci, id = 4, staleTime = 10)
    insertTestRecord(aci, id = 5, staleTime = 10, lastResort = true)

    SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)

    assertNull(getStaleTime(aci, 1))
    assertNotNull(getStaleTime(aci, 2))
    assertNotNull(getStaleTime(aci, 3))
    assertNotNull(getStaleTime(aci, 4))
    assertNotNull(getStaleTime(aci, 5))
  }

  /** Last-resort Kyber keys are never deleted, even when stale, past the threshold, and `minCount = 0`. */
  @Test
  fun deleteAllStaleBefore_neverDeleteLastResort() {
    insertTestRecord(aci, id = 1, staleTime = 10, lastResort = true)
    insertTestRecord(aci, id = 2, staleTime = 10, lastResort = true)
    insertTestRecord(aci, id = 3, staleTime = 10, lastResort = true)

    // Must target the Kyber table: insertTestRecord writes to kyberPreKeys. Calling
    // oneTimePreKeys here (as before) never touched these rows, so the test passed vacuously.
    SignalDatabase.kyberPreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)

    assertNotNull(getStaleTime(aci, 1))
    assertNotNull(getStaleTime(aci, 2))
    assertNotNull(getStaleTime(aci, 3))
  }

  /**
   * Inserts a freshly generated Kyber-1024 pre-key for [account] under key id [id].
   * It then overwrites the row's stale timestamp with [staleTime].
   *
   * The record's signature comes from a throwaway Curve key pair. These tests only exercise the
   * stale-timestamp handling; presumably the table does not verify signatures on insert (confirm).
   *
   * Asserts that exactly one row was updated, which catches a failed insert or a duplicate key.
   */
  private fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0, lastResort: Boolean = false) {
    val kemKeyPair = KEMKeyPair.generate(KEMKeyType.KYBER_1024)
    SignalDatabase.kyberPreKeys.insert(
      serviceId = account,
      keyId = id,
      record = KyberPreKeyRecord(
        id,
        System.currentTimeMillis(),
        kemKeyPair,
        Curve.generateKeyPair().privateKey.calculateSignature(kemKeyPair.publicKey.serialize())
      ),
      lastResort = lastResort
    )

    // `id` is a test-controlled Int, so interpolating it into the SQL is safe.
    val count = SignalDatabase.rawDatabase
      .update(KyberPreKeyTable.TABLE_NAME)
      .values(KyberPreKeyTable.STALE_TIMESTAMP to staleTime)
      .where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account)
      .run()

    assertEquals(1, count)
  }

  /**
   * Reads the stale timestamp of the row identified by ([account], [id]).
   *
   * Returns null when no such row exists. The delete tests use this to check whether a row survived.
   * A NULL column value would also read as null; presumably the column is never NULL (confirm
   * against the schema).
   */
  private fun getStaleTime(account: ServiceId, id: Int): Long? {
    return SignalDatabase.rawDatabase
      .select(KyberPreKeyTable.STALE_TIMESTAMP)
      .from(KyberPreKeyTable.TABLE_NAME)
      .where("${KyberPreKeyTable.ACCOUNT_ID} = ? AND ${KyberPreKeyTable.KEY_ID} = $id", account)
      .run()
      .readToSingleObject { it.requireLongOrNull(KyberPreKeyTable.STALE_TIMESTAMP) }
  }
}