/*
 * Copyright 2023 Signal Messenger, LLC
 * SPDX-License-Identifier: AGPL-3.0-only
 */

package org.tm.archive.database

import junit.framework.TestCase.assertEquals
import junit.framework.TestCase.assertNotNull
import junit.framework.TestCase.assertNull
import org.junit.Test
import org.signal.core.util.readToSingleObject
import org.signal.core.util.requireLongOrNull
import org.signal.core.util.select
import org.signal.core.util.update
import org.signal.libsignal.protocol.ecc.Curve
import org.signal.libsignal.protocol.state.PreKeyRecord
import org.whispersystems.signalservice.api.push.ServiceId
import org.whispersystems.signalservice.api.push.ServiceId.ACI
import org.whispersystems.signalservice.api.push.ServiceId.PNI
import java.util.UUID

/**
 * Exercises the stale-key bookkeeping exposed by `SignalDatabase.oneTimePreKeys`:
 * `markAllStaleIfNecessary` and `deleteAllStaleBefore`.
 *
 * Every test seeds rows with [insertTestRecord] and reads them back with [getStaleTime].
 * A `null` from [getStaleTime] is treated as "this row was deleted".
 */
class OneTimePreKeyTableTest {

  // Random identities, generated when this test instance is constructed.
  private val aci: ACI = ACI.from(UUID.randomUUID())
  private val pni: PNI = PNI.from(UUID.randomUUID())

  @Test
  fun markAllStaleIfNecessary_onlyUpdatesMatchingAccountAndZeroValues() {
    // ACI keys 1-2 keep the default stale time of 0, ACI key 3 already has 42, PNI key 4 keeps 0.
    for (keyId in 1..2) {
      insertTestRecord(aci, id = keyId)
    }
    insertTestRecord(aci, id = 3, staleTime = 42)
    insertTestRecord(pni, id = 4)

    val markedAt = System.currentTimeMillis()
    SignalDatabase.oneTimePreKeys.markAllStaleIfNecessary(aci, markedAt)

    // Only the ACI rows that were still at 0 are expected to pick up the new timestamp.
    for (keyId in 1..2) {
      assertEquals(markedAt, getStaleTime(aci, keyId))
    }
    assertEquals(42L, getStaleTime(aci, 3))
    assertEquals(0L, getStaleTime(pni, 4))
  }

  @Test
  fun deleteAllStaleBefore_deleteOldBeforeThreshold() {
    for (keyId in 1..3) {
      insertTestRecord(aci, id = keyId, staleTime = 10)
    }
    insertTestRecord(aci, id = 4, staleTime = 15)
    insertTestRecord(aci, id = 5, staleTime = 0)

    SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 0)

    // Rows stamped 10 (below the threshold of 11) are expected gone; the rows at 15 and at 0 survive.
    for (keyId in 1..3) {
      assertNull(getStaleTime(aci, keyId))
    }
    assertNotNull(getStaleTime(aci, 4))
    assertNotNull(getStaleTime(aci, 5))
  }

  @Test
  fun deleteAllStaleBefore_neverDeleteStaleOfZero() {
    for (keyId in 1..5) {
      insertTestRecord(aci, id = keyId, staleTime = 0)
    }

    SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 10, minCount = 0)

    // 0 is numerically below the threshold, yet every row is expected to remain.
    for (keyId in 1..5) {
      assertNotNull(getStaleTime(aci, keyId))
    }
  }

  @Test
  fun deleteAllStaleBefore_respectMinCount() {
    for (keyId in 1..5) {
      insertTestRecord(aci, id = keyId, staleTime = 10)
    }

    SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 3)

    // All five rows are below the threshold, but only ids 1-2 are expected to go, leaving three.
    for (keyId in 1..2) {
      assertNull(getStaleTime(aci, keyId))
    }
    for (keyId in 3..5) {
      assertNotNull(getStaleTime(aci, keyId))
    }
  }

  @Test
  fun deleteAllStaleBefore_respectAccount() {
    for (keyId in 1..3) {
      insertTestRecord(aci, id = keyId, staleTime = 10)
    }
    for (keyId in 4..5) {
      insertTestRecord(pni, id = keyId, staleTime = 10)
    }

    SignalDatabase.oneTimePreKeys.deleteAllStaleBefore(aci, threshold = 11, minCount = 2)

    // The expectations imply minCount is applied to the ACI rows alone: one ACI row goes,
    // two ACI rows stay, and the PNI rows are untouched.
    assertNull(getStaleTime(aci, 1))
    for (keyId in 2..3) {
      assertNotNull(getStaleTime(aci, keyId))
    }
    for (keyId in 4..5) {
      assertNotNull(getStaleTime(pni, keyId))
    }
  }

  /**
   * Inserts a freshly generated one-time prekey for [account] under key id [id].
   * It then overwrites that row's stale timestamp with [staleTime] through a raw update.
   * The update is asserted to touch exactly one row.
   */
  private fun insertTestRecord(account: ServiceId, id: Int, staleTime: Long = 0) {
    val preKey = PreKeyRecord(id, Curve.generateKeyPair())
    SignalDatabase.oneTimePreKeys.insert(serviceId = account, keyId = id, record = preKey)

    val updatedRows = SignalDatabase.rawDatabase
      .update(OneTimePreKeyTable.TABLE_NAME)
      .values(OneTimePreKeyTable.STALE_TIMESTAMP to staleTime)
      .where(accountAndKeyClause(id), account)
      .run()

    assertEquals(1, updatedRows)
  }

  /**
   * Reads back the stale timestamp stored for [account] / [id].
   * The deletion tests treat a `null` result as "row no longer exists".
   */
  private fun getStaleTime(account: ServiceId, id: Int): Long? =
    SignalDatabase.rawDatabase
      .select(OneTimePreKeyTable.STALE_TIMESTAMP)
      .from(OneTimePreKeyTable.TABLE_NAME)
      .where(accountAndKeyClause(id), account)
      .run()
      .readToSingleObject { cursor -> cursor.requireLongOrNull(OneTimePreKeyTable.STALE_TIMESTAMP) }

  /**
   * Builds the WHERE clause that selects a single key of a single account.
   * The account is supplied as the lone `?` argument; the key id is embedded directly.
   */
  private fun accountAndKeyClause(id: Int): String =
    "${OneTimePreKeyTable.ACCOUNT_ID} = ? AND ${OneTimePreKeyTable.KEY_ID} = $id"
}