#![no_std]

use core::ops::BitXor;

use chacha20poly1305::{AeadInOut, ChaCha20Poly1305, KeyInit, aead};
use dhkem::{
    Encapsulate, Expander, Kem, X25519DecapsulationKey, X25519EncapsulationKey, X25519Kem,
    kem::{Ciphertext, Decapsulate, Key, KeyExport, SharedKey, TryKeyInit},
};
use elliptic_curve::subtle::ConstantTimeEq;

/// Error type.
///
/// This type is deliberately opaque so as to avoid potential side-channel
/// leakage (e.g. padding oracle): every failure in the protocol (malformed key,
/// failed decapsulation, key-derivation failure, AEAD tag mismatch, counter
/// exhaustion) is reported as the same value.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct ProtoError;

impl core::fmt::Display for ProtoError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("ProtoError")
    }
}

impl core::error::Error for ProtoError {}

/// Collapses AEAD failures into the opaque [`ProtoError`], deliberately
/// discarding the underlying error.
impl From<chacha20poly1305::Error> for ProtoError {
    fn from(_value: chacha20poly1305::Error) -> Self {
        Self
    }
}

/// Client side of the handshake.
///
/// Holds the ephemeral X25519 decapsulation key generated by
/// `ClientHandshake::send` until the server's ciphertext is passed to
/// `ClientHandshake::finish`, which consumes it.
pub struct ClientHandshake(X25519DecapsulationKey);

/// The client's ephemeral X25519 encapsulation key, i.e. the value the client
/// serializes and sends to the server to start the handshake.
pub struct EncapsulatedPublicKey(X25519EncapsulationKey);

/// The role of the participant.
///
/// During the handshake the client acts as `Sender` and the server as
/// `Receiver`; during transport the same values name the outgoing (`Sender`)
/// and incoming (`Receiver`) direction of traffic.
#[derive(Debug, PartialEq, Eq, Clone, Copy)] pub enum Role { /// Participant SENDS data Sender, /// Participant RECEIVES data Receiver, } impl From for u8 { fn from(value: Role) -> Self { match value { Role::Sender => 0, Role::Receiver => 1, } } } impl BitXor for Role { type Output = u8; fn bitxor(self, rhs: Self) -> u8 { u8::from(self) ^ u8::from(rhs) } } impl EncapsulatedPublicKey { pub fn serialize(&self) -> Key { self.0.to_bytes() } pub fn deserialize(buf: &[u8]) -> Result { Ok(Self( X25519EncapsulationKey::new_from_slice(buf).map_err(|_| ProtoError)?, )) } pub fn encapsulate(&self) -> (Ciphertext, SharedKey) { self.0.encapsulate() } } impl ClientHandshake { pub fn send() -> (EncapsulatedPublicKey, Self) { let (decap, encap) = X25519Kem::generate_keypair(); (EncapsulatedPublicKey(encap), Self(decap)) } pub fn finish(self, ciphertext: &[u8], psk: &[u8; 32]) -> Result { let shared = self .0 .decapsulate_slice(ciphertext) .map_err(|_| ProtoError)?; TransportState::init(psk, shared, Role::Sender) } } pub struct ServerHandshake(SharedKey); impl ServerHandshake { pub fn receive(buf: &[u8]) -> Result<(Ciphertext, Self), ProtoError> { let encap = EncapsulatedPublicKey::deserialize(buf)?; let (ciphertext, sk) = encap.encapsulate(); Ok((ciphertext, Self(sk))) } pub fn finish(self, psk: &[u8; 32]) -> Result { TransportState::init(psk, self.0, Role::Receiver) } } pub struct SendingState<'a> { transport: &'a TransportState, counter: u64, } impl SendingState<'_> { pub fn encrypt( &mut self, msg: &mut dyn aead::Buffer, associated_data: &[u8], ) -> Result<(), ProtoError> { if self.counter.ct_eq(&u64::MAX).into() { return Err(ProtoError); } self.transport.aead.encrypt_in_place( &self .transport .mix_nonce(&self.counter.to_be_bytes(), Role::Sender), associated_data, msg, )?; self.counter = self.counter.wrapping_add(1); Ok(()) } } pub struct ReceivingState<'a> { transport: &'a TransportState, counter: u64, } impl ReceivingState<'_> { pub fn decrypt( &mut self, msg: &mut dyn 
aead::Buffer, associated_data: &[u8], ) -> Result<(), ProtoError> { if self.counter.ct_eq(&u64::MAX).into() { return Err(ProtoError); } self.transport.aead.decrypt_in_place( &self .transport .mix_nonce(&self.counter.to_be_bytes(), Role::Receiver), associated_data, msg, )?; self.counter = self.counter.wrapping_add(1); Ok(()) } } #[repr(align(4))] pub struct TransportState { aead: ChaCha20Poly1305, client: aead::Nonce, server: aead::Nonce, role: Role, } impl TransportState { pub fn init( psk: &[u8; 32], shared: SharedKey, role: Role, ) -> Result { let kdf = Expander::::new_labeled_hpke(psk, b"Sachy-Crypto", &shared) .map_err(|_| ProtoError)?; let mut key = [0u8; 32]; let mut client = aead::Nonce::::default(); let mut server = aead::Nonce::::default(); kdf.expand(b"SecretKey012", &mut key) .map_err(|_| ProtoError)?; kdf.expand(b"NonceClient*", &mut client) .map_err(|_| ProtoError)?; kdf.expand(b"NonceServer#", &mut server) .map_err(|_| ProtoError)?; Ok(Self { aead: ChaCha20Poly1305::new(&key.into()), client, server, role, }) } pub fn split(&self) -> (SendingState<'_>, ReceivingState<'_>) { ( SendingState { transport: self, counter: 0, }, ReceivingState { transport: self, counter: 0, }, ) } /// Selects which nonce to use for encrypting/decrypting, which matters for /// ensuring the same nonce is used only for one direction of communication. 
fn select_nonce_context(&self, send: Role) -> &aead::Nonce { let context_select = self.role ^ send; // Handshake ROLE XOR Transport ROLE selects either one or other nonce context, // (0) for first context, (1) for second context // Sending: Client ^ Sender = 0 (select first/client context) // Receiving: Server ^ Receiver = 0 (select first/client context) // Sending: Server ^ Sender = 1 (select second/server context) // Receiving: Client ^ Receiver = 1 (select second/server context) if context_select.ct_eq(&0).into() { &self.client } else { &self.server } } fn mix_nonce(&self, position: &[u8; 8], send: Role) -> aead::Nonce { let mut trump = aead::Nonce::::default(); let epstein = self.select_nonce_context(send); let index = trump.len() - position.len(); // Copy position bytes into BE format onto derived nonce trump[index..].copy_from_slice(position); // XOR the base nonce onto the derived nonce bytes trump .iter_mut() .zip(epstein) .for_each(|(trump, epstein)| *trump ^= *epstein); trump } } #[derive(Debug)] pub struct BufferSlice<'a> { slice: &'a mut [u8], end: usize, } impl<'a> BufferSlice<'a> { pub fn new(slice: &'a mut [u8]) -> Self { Self { end: slice.len(), slice, } } pub fn reset(&mut self) { self.end = self.slice.len(); } } impl AsRef<[u8]> for BufferSlice<'_> { fn as_ref(&self) -> &[u8] { &self.slice[..self.end] } } impl AsMut<[u8]> for BufferSlice<'_> { fn as_mut(&mut self) -> &mut [u8] { &mut self.slice[..self.end] } } impl aead::Buffer for BufferSlice<'_> { fn extend_from_slice(&mut self, other: &[u8]) -> aead::Result<()> { let index = self.end + other.len(); if index > self.slice.len() { return Err(aead::Error); } self.slice[self.end..index].copy_from_slice(other); self.end = index; Ok(()) } fn truncate(&mut self, len: usize) { self.end = len; } } #[cfg(test)] mod tests { use alloc::vec; use chacha20poly1305::aead::Buffer; use dhkem::Generate; use elliptic_curve::array::Array; extern crate alloc; use super::*; #[test] fn buffer_slice_works() { let mut 
buf = vec![0u8; 128]; let mut buf_slice = BufferSlice::new(&mut buf); assert_eq!(buf_slice.len(), 128); assert_eq!(buf_slice.extend_from_slice(&[0, 0, 0]), Err(aead::Error)); buf_slice.truncate(64); assert_eq!(buf_slice.extend_from_slice(&[0, 0, 0, 0, 0, 0]), Ok(())); assert_eq!(buf_slice.len(), 70); buf_slice.reset(); assert_eq!(buf_slice.len(), 128); } #[test] fn handshake_protocol_works() -> Result<(), ProtoError> { let psk: [u8; 32] = [ 31, 48, 29, 177, 88, 236, 186, 84, 65, 51, 214, 243, 174, 24, 45, 101, 229, 129, 62, 132, 45, 174, 183, 65, 89, 73, 107, 177, 77, 90, 164, 251, ]; let (ek, client) = ClientHandshake::send(); // Pretend to send ek across the webz: client -> server let (ciphertext, server) = ServerHandshake::receive(&ek.serialize())?; // Pretend to send ciphertext across the webz: server -> client let alice = client.finish(&ciphertext, &psk)?; let bob = server.finish(&psk)?; let nonce = aead::Nonce::::generate(); let mut buffer1 = vec![0u8; 64]; let mut buffer2 = vec![0u8; 64]; // Using the same nonce to check that the internal AEAD states match. Normally, client/server // would work with unique derived nonces, because nonce reuse is BAD alice.aead.encrypt_in_place(&nonce, &[], &mut buffer1)?; bob.aead.encrypt_in_place(&nonce, &[], &mut buffer2)?; // If the nonces match, then we can assume the rest of the internal state is the same too // so the outputs should match each other assert_eq!(&buffer1, &buffer2); // Both Transports have derived base nonces for each context. // Client context nonces will not match Server context nonces. 
assert_eq!(alice.client, bob.client); assert_eq!(alice.server, bob.server); assert_ne!(alice.client, alice.server); assert_ne!(bob.client, bob.server); Ok(()) } #[test] fn two_way_transport_sync_works() -> Result<(), ProtoError> { let shared_secret = [ 0x80, 0x81, 0x82, 0x83, 0x84, 0x85, 0x86, 0x87, 0x88, 0x89, 0x8a, 0x8b, 0x8c, 0x8d, 0x8e, 0x8f, 0x90, 0x91, 0x92, 0x93, 0x94, 0x95, 0x96, 0x97, 0x98, 0x99, 0x9a, 0x9b, 0x9c, 0x9d, 0x9e, 0x9f, ]; let psk: [u8; 32] = [ 31, 48, 29, 177, 88, 236, 186, 84, 65, 51, 214, 243, 174, 24, 45, 101, 229, 129, 62, 132, 45, 174, 183, 65, 89, 73, 107, 177, 77, 90, 164, 251, ]; let alice = TransportState::init(&psk, Array(shared_secret), Role::Sender)?; let bob = TransportState::init(&psk, Array(shared_secret), Role::Receiver)?; let (mut alice_send, mut alice_recv) = alice.split(); let (mut bob_send, mut bob_recv) = bob.split(); let orig = b"Test Message, Please ignore."; let ad = b"random"; let mut msg = orig.to_vec(); // a -> b alice_send.encrypt(&mut msg, ad)?; assert_ne!(orig.as_slice(), msg.as_slice()); let ct1 = msg.clone(); bob_recv.decrypt(&mut msg, ad)?; // a -> b alice_send.encrypt(&mut msg, b"")?; assert_ne!(msg.as_slice(), ct1.as_slice()); let ct2 = msg.clone(); bob_recv.decrypt(&mut msg, b"")?; // b -> a bob_send.encrypt(&mut msg, ad)?; // None of the ciphertexts should match each other assert_ne!(msg.as_slice(), ct1.as_slice()); assert_ne!(msg.as_slice(), ct2.as_slice()); assert_ne!(ct1.as_slice(), ct2.as_slice()); alice_recv.decrypt(&mut msg, ad)?; assert_eq!(orig.as_slice(), msg.as_slice()); // Counters are tracked from sender to receiver assert_eq!(alice_send.counter, bob_recv.counter); assert_eq!(bob_send.counter, alice_recv.counter); // Counters are not linked on the same side assert_ne!(alice_send.counter, alice_recv.counter); assert_ne!(bob_send.counter, bob_recv.counter); Ok(()) } }