That fuck shit the fascists are using
Branch: master · 137 lines · 4.8 kB (raw view)
/*
 * Copyright 2023 Signal Messenger, LLC
 * SPDX-License-Identifier: AGPL-3.0-only
 */

package org.tm.archive.database

import junit.framework.TestCase.assertEquals
import junit.framework.TestCase.assertNotNull
import junit.framework.TestCase.assertNull
import org.junit.Test
import org.signal.core.util.readToSingleObject
import org.signal.core.util.requireLongOrNull
import org.signal.core.util.select
import org.signal.core.util.update
import org.signal.libsignal.protocol.ecc.Curve
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.ServiceId.ACI
import org.whispersystems.signalservice.api.push.ServiceId.PNI
import java.util.UUID

/**
 * Tests for the stale-key handling of [OneTimePreKeyTable]:
 * [OneTimePreKeyTable.markAllStaleIfNecessary] and [OneTimePreKeyTable.deleteAllStaleBefore].
 *
 * Every test instance generates its own random [aci] and [pni], and every helper below filters by
 * account, so rows written by one test are not expected to interfere with another even though all
 * tests share [SignalDatabase].
 */
class OneTimePreKeyTableTest {

  private val aci: ACI = ACI.from(UUID.randomUUID())
  private val pni: PNI = PNI.from(UUID.randomUUID())

  @Test
  fun markAllStaleIfNecessary_onlyUpdatesMatchingAccountAndZeroValues() {
    insertTestRecord(aci, id = 1)
    insertTestRecord(aci, id = 2)
    insertTestRecord(aci, id = 3, staleTime = 42)
    insertTestRecord(pni, id = 4)

    val now = System.currentTimeMillis()
    SignalDatabase.oneTimePreKeys.markAllStaleIfNecessary(aci, now)

    // Zero stale times for the ACI are overwritten with `now`...
    assertEquals(now, getStaleTime(aci, 1))
    assertEquals(now, getStaleTime(aci, 2))
    // ...an already-set stale time is preserved...
    assertEquals(42L, getStaleTime(aci, 3))
    // ...and the other account's keys are not touched.
    assertEquals(0L, getStaleTime(pni, 4))
  }

  @Test
  fun deleteAllStaleBefore_deleteOldBeforeThreshold() {
    insertTestRecord(aci, id = 1, staleTime = 10)
    insertTestRecord(aci, id = 2, staleTime = 10)
    insertTestRecord(aci, id = 3, staleTime = 10)
    insertTestRecord(aci, id = 4, staleTime = 15)
    insertTestRecord(aci, id = 5, staleTime = 0)

    SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)

    // Keys that went stale before the threshold are removed.
    assertNull(getStaleTime(aci, 1))
    assertNull(getStaleTime(aci, 2))
    assertNull(getStaleTime(aci, 3))
    // A key that went stale after the threshold, and a never-stale (0) key, survive.
    assertNotNull(getStaleTime(aci, 4))
    assertNotNull(getStaleTime(aci, 5))
  }

  @Test
  fun deleteAllStaleBefore_neverDeleteStaleOfZero() {
    for (id in 1..5) {
      insertTestRecord(aci, id = id, staleTime = 0)
    }

    SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 10, minCount = 0)

    // 0 is below the threshold numerically, but it means "not stale", so nothing is deleted.
    for (id in 1..5) {
      assertNotNull(getStaleTime(aci, id))
    }
  }

  @Test
  fun deleteAllStaleBefore_respectMinCount() {
    for (id in 1..5) {
      insertTestRecord(aci, id = id, staleTime = 10)
    }

    SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)

    // All five keys qualify for deletion, but at least three must remain:
    // only the two lowest key ids are expected to be removed.
    assertNull(getStaleTime(aci, 1))
    assertNull(getStaleTime(aci, 2))
    assertNotNull(getStaleTime(aci, 3))
    assertNotNull(getStaleTime(aci, 4))
    assertNotNull(getStaleTime(aci, 5))
  }

  @Test
  fun deleteAllStaleBefore_respectAccount() {
    insertTestRecord(aci, id = 1, staleTime = 10)
    insertTestRecord(aci, id = 2, staleTime = 10)
    insertTestRecord(aci, id = 3, staleTime = 10)

    insertTestRecord(pni, id = 4, staleTime = 10)
    insertTestRecord(pni, id = 5, staleTime = 10)

    SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 2)

    // minCount is applied to the ACI's keys only; the PNI's stale keys are left alone.
    assertNull(getStaleTime(aci, 1))
    assertNotNull(getStaleTime(aci, 2))
    assertNotNull(getStaleTime(aci, 3))
    assertNotNull(getStaleTime(pni, 4))
    assertNotNull(getStaleTime(pni, 5))
  }

  /**
   * Inserts a freshly generated one-time prekey with key id [id] for [account], then overwrites
   * its stale timestamp with [staleTime] (default 0, which these tests treat as "not stale").
   *
   * Fails the calling test unless exactly one row was updated.
   */
  private fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0) {
    SignalDatabase.oneTimePreKeys.insert(
      serviceId = account,
      keyId = id,
      record = PreKeyRecord(id, Curve.generateKeyPair())
    )

    val count = SignalDatabase.rawDatabase
      .update(OneTimePreKeyTable.TABLE_NAME)
      .values(OneTimePreKeyTable.STALE_TIMESTAMP to staleTime)
      .where(ACCOUNT_AND_KEY_ID_CLAUSE, account, id)
      .run()

    assertEquals(1, count)
  }

  /**
   * Returns the stale timestamp stored for key [id] of [account].
   *
   * The delete tests rely on this returning `null` once the row has been removed.
   */
  private fun getStaleTime(account: ServiceId, id: Int): Long? {
    return SignalDatabase.rawDatabase
      .select(OneTimePreKeyTable.STALE_TIMESTAMP)
      .from(OneTimePreKeyTable.TABLE_NAME)
      .where(ACCOUNT_AND_KEY_ID_CLAUSE, account, id)
      .run()
      .readToSingleObject { it.requireLongOrNull(OneTimePreKeyTable.STALE_TIMESTAMP) }
  }

  companion object {
    /**
     * Selects a single prekey row by account and key id. Both values are bound as arguments
     * instead of being interpolated into the SQL text.
     *
     * NOTE(review): the key id is bound as a string; this assumes KEY_ID is declared with
     * INTEGER affinity, so SQLite converts the argument numerically before comparing.
     */
    private val ACCOUNT_AND_KEY_ID_CLAUSE =
      "${OneTimePreKeyTable.ACCOUNT_ID} = ? AND ${OneTimePreKeyTable.KEY_ID} = ?"
  }
}