//! TLS 1.3 record layer (RFC 8446 §5). //! //! Handles framing, encryption, and decryption of TLS records over a TCP stream. use std::io::{self, Read, Write}; // --------------------------------------------------------------------------- // Constants // --------------------------------------------------------------------------- /// TLS 1.2 legacy version used in record headers (RFC 8446 §5.1). const LEGACY_VERSION: [u8; 2] = [0x03, 0x03]; /// Maximum plaintext fragment size: 2^14 = 16384 bytes. const MAX_PLAINTEXT_LENGTH: usize = 16384; /// Maximum ciphertext overhead: 256 bytes (tag + inner content type + padding). const MAX_CIPHERTEXT_OVERHEAD: usize = 256; /// Maximum ciphertext fragment size: 2^14 + 256. const MAX_CIPHERTEXT_LENGTH: usize = MAX_PLAINTEXT_LENGTH + MAX_CIPHERTEXT_OVERHEAD; /// AEAD tag size for all TLS 1.3 cipher suites. const TAG_SIZE: usize = 16; /// Record header size: content type (1) + version (2) + length (2). const RECORD_HEADER_SIZE: usize = 5; // --------------------------------------------------------------------------- // Content types (RFC 8446 §5.1) // --------------------------------------------------------------------------- /// TLS record content types. #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum ContentType { ChangeCipherSpec = 20, Alert = 21, Handshake = 22, ApplicationData = 23, } impl ContentType { fn from_u8(v: u8) -> Result { match v { 20 => Ok(ContentType::ChangeCipherSpec), 21 => Ok(ContentType::Alert), 22 => Ok(ContentType::Handshake), 23 => Ok(ContentType::ApplicationData), _ => Err(TlsError::UnknownContentType(v)), } } } // --------------------------------------------------------------------------- // Alert protocol (RFC 8446 §6) // --------------------------------------------------------------------------- /// TLS alert severity level. #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum AlertLevel { Warning = 1, Fatal = 2, } impl AlertLevel { fn from_u8(v: u8) -> Result { match v { 1 => Ok(AlertLevel::Warning), 2 => Ok(AlertLevel::Fatal), _ => Err(TlsError::MalformedAlert), } } } /// TLS alert descriptions (subset covering TLS 1.3). #[derive(Debug, Clone, Copy, PartialEq, Eq)] #[repr(u8)] pub enum AlertDescription { CloseNotify = 0, UnexpectedMessage = 10, BadRecordMac = 20, RecordOverflow = 22, HandshakeFailure = 40, BadCertificate = 42, CertificateRevoked = 44, CertificateExpired = 45, CertificateUnknown = 46, IllegalParameter = 47, UnknownCa = 48, AccessDenied = 49, DecodeError = 50, DecryptError = 51, ProtocolVersion = 70, InsufficientSecurity = 71, InternalError = 80, MissingExtension = 109, UnsupportedExtension = 110, UnrecognizedName = 112, BadCertificateStatusResponse = 113, NoApplicationProtocol = 120, } impl AlertDescription { fn from_u8(v: u8) -> Result { match v { 0 => Ok(AlertDescription::CloseNotify), 10 => Ok(AlertDescription::UnexpectedMessage), 20 => Ok(AlertDescription::BadRecordMac), 22 => Ok(AlertDescription::RecordOverflow), 40 => Ok(AlertDescription::HandshakeFailure), 42 => Ok(AlertDescription::BadCertificate), 44 => Ok(AlertDescription::CertificateRevoked), 45 => Ok(AlertDescription::CertificateExpired), 46 => Ok(AlertDescription::CertificateUnknown), 47 => Ok(AlertDescription::IllegalParameter), 48 => Ok(AlertDescription::UnknownCa), 49 => Ok(AlertDescription::AccessDenied), 50 => Ok(AlertDescription::DecodeError), 51 => Ok(AlertDescription::DecryptError), 70 => Ok(AlertDescription::ProtocolVersion), 71 => Ok(AlertDescription::InsufficientSecurity), 80 => Ok(AlertDescription::InternalError), 109 => Ok(AlertDescription::MissingExtension), 110 => Ok(AlertDescription::UnsupportedExtension), 112 => Ok(AlertDescription::UnrecognizedName), 113 => Ok(AlertDescription::BadCertificateStatusResponse), 120 => Ok(AlertDescription::NoApplicationProtocol), _ => Err(TlsError::MalformedAlert), } } } /// A parsed TLS alert message. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub struct Alert { pub level: AlertLevel, pub description: AlertDescription, } impl Alert { /// Create a new alert. pub fn new(level: AlertLevel, description: AlertDescription) -> Self { Self { level, description } } /// Create a close_notify alert. pub fn close_notify() -> Self { Self::new(AlertLevel::Warning, AlertDescription::CloseNotify) } /// Encode the alert to bytes. pub fn encode(&self) -> [u8; 2] { [self.level as u8, self.description as u8] } /// Parse an alert from bytes. pub fn parse(data: &[u8]) -> Result { if data.len() < 2 { return Err(TlsError::MalformedAlert); } Ok(Self { level: AlertLevel::from_u8(data[0])?, description: AlertDescription::from_u8(data[1])?, }) } /// Returns true if this is a fatal alert. pub fn is_fatal(&self) -> bool { self.level == AlertLevel::Fatal } } // --------------------------------------------------------------------------- // Error types // --------------------------------------------------------------------------- /// TLS record layer errors. #[derive(Debug)] pub enum TlsError { /// Unknown content type byte. UnknownContentType(u8), /// Record exceeds maximum allowed size. RecordOverflow, /// AEAD decryption failed (bad MAC). DecryptionFailed, /// Malformed alert message. MalformedAlert, /// Received a fatal alert from the peer. AlertReceived(Alert), /// An I/O error occurred. Io(io::Error), } impl std::fmt::Display for TlsError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::UnknownContentType(v) => write!(f, "unknown TLS content type: {v}"), Self::RecordOverflow => write!(f, "TLS record overflow"), Self::DecryptionFailed => write!(f, "TLS AEAD decryption failed"), Self::MalformedAlert => write!(f, "malformed TLS alert"), Self::AlertReceived(alert) => { write!(f, "TLS alert received: {:?}", alert.description) } Self::Io(e) => write!(f, "TLS I/O error: {e}"), } } } impl From for TlsError { fn from(err: io::Error) -> Self { TlsError::Io(err) } } pub type Result = std::result::Result; // --------------------------------------------------------------------------- // Cipher suite abstraction // --------------------------------------------------------------------------- /// Supported TLS 1.3 cipher suites. #[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum CipherSuite { Aes128Gcm, Aes256Gcm, Chacha20Poly1305, } impl CipherSuite { /// Key length in bytes. pub fn key_len(&self) -> usize { match self { CipherSuite::Aes128Gcm => 16, CipherSuite::Aes256Gcm => 32, CipherSuite::Chacha20Poly1305 => 32, } } /// IV (nonce base) length in bytes. pub fn iv_len(&self) -> usize { 12 // All TLS 1.3 suites use 12-byte nonces } } // --------------------------------------------------------------------------- // TLS record // --------------------------------------------------------------------------- /// A plaintext TLS record. #[derive(Debug, Clone)] pub struct TlsRecord { pub content_type: ContentType, pub data: Vec, } impl TlsRecord { /// Create a new TLS record. pub fn new(content_type: ContentType, data: Vec) -> Self { Self { content_type, data } } } // --------------------------------------------------------------------------- // Plaintext record I/O // --------------------------------------------------------------------------- /// Read a single plaintext TLS record from a stream. pub fn read_record(stream: &mut R) -> Result { // Read 5-byte header let mut header = [0u8; RECORD_HEADER_SIZE]; stream.read_exact(&mut header)?; let content_type = ContentType::from_u8(header[0])?; // We accept any legacy version in the header (RFC 8446 §5.1) let length = u16::from_be_bytes([header[3], header[4]]) as usize; if length > MAX_CIPHERTEXT_LENGTH { return Err(TlsError::RecordOverflow); } let mut data = vec![0u8; length]; stream.read_exact(&mut data)?; Ok(TlsRecord { content_type, data }) } /// Write a single plaintext TLS record to a stream. pub fn write_record(stream: &mut W, record: &TlsRecord) -> Result<()> { if record.data.len() > MAX_PLAINTEXT_LENGTH { return Err(TlsError::RecordOverflow); } let mut header = [0u8; RECORD_HEADER_SIZE]; header[0] = record.content_type as u8; header[1..3].copy_from_slice(&LEGACY_VERSION); header[3..5].copy_from_slice(&(record.data.len() as u16).to_be_bytes()); stream.write_all(&header)?; stream.write_all(&record.data)?; stream.flush()?; Ok(()) } // --------------------------------------------------------------------------- // Nonce construction (RFC 8446 §5.3) // --------------------------------------------------------------------------- /// Construct a per-record nonce by XORing the base IV with the sequence number. /// /// The sequence number is left-padded with zeros to match the IV length, /// then XORed with the base IV. fn make_nonce(iv: &[u8; 12], seq: u64) -> [u8; 12] { let mut nonce = *iv; let seq_bytes = seq.to_be_bytes(); // 8 bytes // XOR seq into the last 8 bytes of the 12-byte nonce for i in 0..8 { nonce[4 + i] ^= seq_bytes[i]; } nonce } // --------------------------------------------------------------------------- // AEAD encrypt/decrypt wrappers // --------------------------------------------------------------------------- fn aead_encrypt( suite: CipherSuite, key: &[u8], nonce: &[u8; 12], plaintext: &[u8], aad: &[u8], ) -> (Vec, [u8; 16]) { match suite { CipherSuite::Aes128Gcm => { let key: [u8; 16] = key.try_into().expect("AES-128 key must be 16 bytes"); we_crypto::aes_gcm::aes128_gcm_encrypt(&key, nonce, plaintext, aad) } CipherSuite::Aes256Gcm => { let key: [u8; 32] = key.try_into().expect("AES-256 key must be 32 bytes"); we_crypto::aes_gcm::aes256_gcm_encrypt(&key, nonce, plaintext, aad) } CipherSuite::Chacha20Poly1305 => { let key: [u8; 32] = key.try_into().expect("ChaCha20 key must be 32 bytes"); we_crypto::chacha20_poly1305::chacha20_poly1305_encrypt(&key, nonce, plaintext, aad) } } } fn aead_decrypt( suite: CipherSuite, key: &[u8], nonce: &[u8; 12], ciphertext: &[u8], aad: &[u8], tag: &[u8; 16], ) -> Option> { match suite { CipherSuite::Aes128Gcm => { let key: [u8; 16] = key.try_into().expect("AES-128 key must be 16 bytes"); we_crypto::aes_gcm::aes128_gcm_decrypt(&key, nonce, ciphertext, aad, tag) } CipherSuite::Aes256Gcm => { let key: [u8; 32] = key.try_into().expect("AES-256 key must be 32 bytes"); we_crypto::aes_gcm::aes256_gcm_decrypt(&key, nonce, ciphertext, aad, tag) } CipherSuite::Chacha20Poly1305 => { let key: [u8; 32] = key.try_into().expect("ChaCha20 key must be 32 bytes"); we_crypto::chacha20_poly1305::chacha20_poly1305_decrypt( &key, nonce, ciphertext, aad, tag, ) } } } // --------------------------------------------------------------------------- // Encrypted record layer (RFC 8446 §5.2) // --------------------------------------------------------------------------- /// State for encrypting/decrypting TLS 1.3 records. /// /// Each direction (read/write) needs its own `RecordCryptoState` with /// independent keys, IVs, and sequence numbers. pub struct RecordCryptoState { suite: CipherSuite, key: Vec, iv: [u8; 12], seq: u64, } impl RecordCryptoState { /// Create a new record crypto state. pub fn new(suite: CipherSuite, key: Vec, iv: [u8; 12]) -> Self { assert_eq!(key.len(), suite.key_len()); Self { suite, key, iv, seq: 0, } } /// Encrypt a TLS record (RFC 8446 §5.2). /// /// Builds TLSInnerPlaintext (content ∥ content_type ∥ zeros), /// then encrypts with AEAD using the record header as AAD. /// Returns the encrypted TLS record with outer type ApplicationData. pub fn encrypt(&mut self, record: &TlsRecord) -> Result { if record.data.len() > MAX_PLAINTEXT_LENGTH { return Err(TlsError::RecordOverflow); } // Build TLSInnerPlaintext: content ∥ content_type (no padding) let mut inner = Vec::with_capacity(record.data.len() + 1); inner.extend_from_slice(&record.data); inner.push(record.content_type as u8); // Construct nonce let nonce = make_nonce(&self.iv, self.seq); // The AAD is the record header of the *outer* record: // type(23) ∥ legacy_version(0x0303) ∥ length(encrypted_len) // encrypted_len = inner.len() + tag_size let encrypted_len = inner.len() + TAG_SIZE; if encrypted_len > MAX_CIPHERTEXT_LENGTH { return Err(TlsError::RecordOverflow); } let mut aad = [0u8; RECORD_HEADER_SIZE]; aad[0] = ContentType::ApplicationData as u8; aad[1..3].copy_from_slice(&LEGACY_VERSION); aad[3..5].copy_from_slice(&(encrypted_len as u16).to_be_bytes()); let (ciphertext, tag) = aead_encrypt(self.suite, &self.key, &nonce, &inner, &aad); // Combine ciphertext + tag let mut encrypted = ciphertext; encrypted.extend_from_slice(&tag); self.seq = self.seq.wrapping_add(1); Ok(TlsRecord { content_type: ContentType::ApplicationData, data: encrypted, }) } /// Decrypt a TLS record (RFC 8446 §5.2). /// /// Expects the outer content type to be ApplicationData. /// Decrypts, strips padding and inner content type. /// Returns the decrypted record with the real content type. pub fn decrypt(&mut self, record: &TlsRecord) -> Result { if record.content_type != ContentType::ApplicationData { return Err(TlsError::UnknownContentType(record.content_type as u8)); } if record.data.len() < TAG_SIZE + 1 { // Need at least tag + 1 byte for inner content type return Err(TlsError::DecryptionFailed); } if record.data.len() > MAX_CIPHERTEXT_LENGTH { return Err(TlsError::RecordOverflow); } // Split ciphertext and tag let ct_len = record.data.len() - TAG_SIZE; let ciphertext = &record.data[..ct_len]; let tag: [u8; 16] = record.data[ct_len..].try_into().expect("tag is 16 bytes"); // Construct nonce let nonce = make_nonce(&self.iv, self.seq); // Build AAD: the record header as received let mut aad = [0u8; RECORD_HEADER_SIZE]; aad[0] = ContentType::ApplicationData as u8; aad[1..3].copy_from_slice(&LEGACY_VERSION); aad[3..5].copy_from_slice(&(record.data.len() as u16).to_be_bytes()); let inner = aead_decrypt(self.suite, &self.key, &nonce, ciphertext, &aad, &tag) .ok_or(TlsError::DecryptionFailed)?; self.seq = self.seq.wrapping_add(1); // Strip padding zeros and extract inner content type // TLSInnerPlaintext = content ∥ ContentType ∥ zeros // Find the last non-zero byte (the content type) let ct_pos = inner .iter() .rposition(|&b| b != 0) .ok_or(TlsError::DecryptionFailed)?; let content_type = ContentType::from_u8(inner[ct_pos])?; let data = inner[..ct_pos].to_vec(); if data.len() > MAX_PLAINTEXT_LENGTH { return Err(TlsError::RecordOverflow); } Ok(TlsRecord { content_type, data }) } } // --------------------------------------------------------------------------- // RecordLayer: combines plaintext and encrypted record I/O // --------------------------------------------------------------------------- /// The TLS record layer, handling both plaintext and encrypted records. pub struct RecordLayer { stream: S, write_state: Option, read_state: Option, } impl RecordLayer { /// Create a new record layer in plaintext mode. pub fn new(stream: S) -> Self { Self { stream, write_state: None, read_state: None, } } /// Enable encryption for writing. pub fn set_write_crypto(&mut self, state: RecordCryptoState) { self.write_state = Some(state); } /// Enable encryption for reading. pub fn set_read_crypto(&mut self, state: RecordCryptoState) { self.read_state = Some(state); } /// Get a reference to the underlying stream. pub fn stream(&self) -> &S { &self.stream } /// Get a mutable reference to the underlying stream. pub fn stream_mut(&mut self) -> &mut S { &mut self.stream } /// Write a TLS record, encrypting if crypto state is set. pub fn write_record(&mut self, record: &TlsRecord) -> Result<()> { let record_to_write = match &mut self.write_state { Some(state) => state.encrypt(record)?, None => record.clone(), }; write_record(&mut self.stream, &record_to_write) } /// Read a TLS record, decrypting if crypto state is set. pub fn read_record(&mut self) -> Result { let raw = read_record(&mut self.stream)?; match &mut self.read_state { Some(state) if raw.content_type == ContentType::ApplicationData => state.decrypt(&raw), _ => Ok(raw), } } /// Send an alert. pub fn send_alert(&mut self, alert: Alert) -> Result<()> { let record = TlsRecord::new(ContentType::Alert, alert.encode().to_vec()); self.write_record(&record) } /// Send a close_notify alert. pub fn send_close_notify(&mut self) -> Result<()> { self.send_alert(Alert::close_notify()) } /// Consume the record layer and return the underlying stream. pub fn into_inner(self) -> S { self.stream } } // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; use std::io::Cursor; // -- ContentType tests -- #[test] fn content_type_from_u8_valid() { assert_eq!( ContentType::from_u8(20).unwrap(), ContentType::ChangeCipherSpec ); assert_eq!(ContentType::from_u8(21).unwrap(), ContentType::Alert); assert_eq!(ContentType::from_u8(22).unwrap(), ContentType::Handshake); assert_eq!( ContentType::from_u8(23).unwrap(), ContentType::ApplicationData ); } #[test] fn content_type_from_u8_invalid() { assert!(ContentType::from_u8(0).is_err()); assert!(ContentType::from_u8(19).is_err()); assert!(ContentType::from_u8(24).is_err()); assert!(ContentType::from_u8(255).is_err()); } // -- Alert tests -- #[test] fn alert_encode_decode() { let alert = Alert::new(AlertLevel::Fatal, AlertDescription::HandshakeFailure); let bytes = alert.encode(); assert_eq!(bytes, [2, 40]); let parsed = Alert::parse(&bytes).unwrap(); assert_eq!(parsed, alert); } #[test] fn alert_close_notify() { let alert = Alert::close_notify(); assert_eq!(alert.level, AlertLevel::Warning); assert_eq!(alert.description, AlertDescription::CloseNotify); assert!(!alert.is_fatal()); assert_eq!(alert.encode(), [1, 0]); } #[test] fn alert_fatal() { let alert = Alert::new(AlertLevel::Fatal, AlertDescription::InternalError); assert!(alert.is_fatal()); } #[test] fn alert_parse_too_short() { assert!(Alert::parse(&[1]).is_err()); assert!(Alert::parse(&[]).is_err()); } #[test] fn alert_parse_invalid_level() { assert!(Alert::parse(&[0, 0]).is_err()); assert!(Alert::parse(&[3, 0]).is_err()); } // -- Nonce construction tests -- #[test] fn nonce_seq_zero() { let iv = [0x01; 12]; let nonce = make_nonce(&iv, 0); assert_eq!(nonce, iv); } #[test] fn nonce_seq_one() { let iv = [0u8; 12]; let nonce = make_nonce(&iv, 1); let mut expected = [0u8; 12]; expected[11] = 1; assert_eq!(nonce, expected); } #[test] fn nonce_xor_correctness() { let iv = [0xff; 12]; let nonce = make_nonce(&iv, 0x0102030405060708); // Last 8 bytes: 0xff ^ seq_bytes assert_eq!(nonce[0..4], [0xff, 0xff, 0xff, 0xff]); // untouched first 4 assert_eq!(nonce[4], 0xff ^ 0x01); assert_eq!(nonce[5], 0xff ^ 0x02); assert_eq!(nonce[6], 0xff ^ 0x03); assert_eq!(nonce[7], 0xff ^ 0x04); assert_eq!(nonce[8], 0xff ^ 0x05); assert_eq!(nonce[9], 0xff ^ 0x06); assert_eq!(nonce[10], 0xff ^ 0x07); assert_eq!(nonce[11], 0xff ^ 0x08); } // -- Plaintext record I/O tests -- #[test] fn write_read_plaintext_record() { let record = TlsRecord::new(ContentType::Handshake, vec![0x01, 0x02, 0x03]); let mut buf = Vec::new(); write_record(&mut buf, &record).unwrap(); // Verify wire format assert_eq!(buf[0], 22); // Handshake assert_eq!(buf[1..3], LEGACY_VERSION); assert_eq!(buf[3..5], [0x00, 0x03]); // length = 3 assert_eq!(buf[5..], [0x01, 0x02, 0x03]); // Read it back let mut cursor = Cursor::new(&buf); let read_back = read_record(&mut cursor).unwrap(); assert_eq!(read_back.content_type, ContentType::Handshake); assert_eq!(read_back.data, vec![0x01, 0x02, 0x03]); } #[test] fn write_read_empty_record() { let record = TlsRecord::new(ContentType::ApplicationData, vec![]); let mut buf = Vec::new(); write_record(&mut buf, &record).unwrap(); let mut cursor = Cursor::new(&buf); let read_back = read_record(&mut cursor).unwrap(); assert_eq!(read_back.content_type, ContentType::ApplicationData); assert!(read_back.data.is_empty()); } #[test] fn write_record_overflow() { let record = TlsRecord::new( ContentType::ApplicationData, vec![0u8; MAX_PLAINTEXT_LENGTH + 1], ); let mut buf = Vec::new(); let result = write_record(&mut buf, &record); assert!(result.is_err()); } #[test] fn read_record_overflow() { // Craft a header claiming more than MAX_CIPHERTEXT_LENGTH let mut buf = vec![23u8]; // ApplicationData buf.extend_from_slice(&LEGACY_VERSION); let bad_len = (MAX_CIPHERTEXT_LENGTH + 1) as u16; buf.extend_from_slice(&bad_len.to_be_bytes()); let mut cursor = Cursor::new(&buf); let result = read_record(&mut cursor); assert!(result.is_err()); } #[test] fn read_record_truncated_header() { let buf = vec![22u8, 0x03]; // Only 2 bytes of header let mut cursor = Cursor::new(&buf); let result = read_record(&mut cursor); assert!(result.is_err()); } #[test] fn read_record_truncated_body() { // Header says 10 bytes but only 5 are available let mut buf = vec![22u8]; // Handshake buf.extend_from_slice(&LEGACY_VERSION); buf.extend_from_slice(&10u16.to_be_bytes()); buf.extend_from_slice(&[0u8; 5]); // Only 5 of 10 bytes let mut cursor = Cursor::new(&buf); let result = read_record(&mut cursor); assert!(result.is_err()); } #[test] fn write_read_multiple_records() { let records = vec![ TlsRecord::new(ContentType::Handshake, vec![1, 2, 3]), TlsRecord::new(ContentType::ApplicationData, vec![4, 5]), TlsRecord::new(ContentType::Alert, vec![1, 0]), ]; let mut buf = Vec::new(); for r in &records { write_record(&mut buf, r).unwrap(); } let mut cursor = Cursor::new(&buf); for expected in &records { let read_back = read_record(&mut cursor).unwrap(); assert_eq!(read_back.content_type, expected.content_type); assert_eq!(read_back.data, expected.data); } } // -- Encrypted record tests -- fn test_keys(suite: CipherSuite) -> (Vec, [u8; 12]) { let key = vec![0x42u8; suite.key_len()]; let iv = [0x01u8; 12]; (key, iv) } #[test] fn encrypt_decrypt_roundtrip_aes128() { let (key, iv) = test_keys(CipherSuite::Aes128Gcm); let mut enc = RecordCryptoState::new(CipherSuite::Aes128Gcm, key.clone(), iv); let mut dec = RecordCryptoState::new(CipherSuite::Aes128Gcm, key, iv); let original = TlsRecord::new(ContentType::Handshake, b"hello TLS".to_vec()); let encrypted = enc.encrypt(&original).unwrap(); assert_eq!(encrypted.content_type, ContentType::ApplicationData); assert_ne!(encrypted.data, original.data); let decrypted = dec.decrypt(&encrypted).unwrap(); assert_eq!(decrypted.content_type, ContentType::Handshake); assert_eq!(decrypted.data, b"hello TLS"); } #[test] fn encrypt_decrypt_roundtrip_aes256() { let (key, iv) = test_keys(CipherSuite::Aes256Gcm); let mut enc = RecordCryptoState::new(CipherSuite::Aes256Gcm, key.clone(), iv); let mut dec = RecordCryptoState::new(CipherSuite::Aes256Gcm, key, iv); let original = TlsRecord::new(ContentType::ApplicationData, b"data".to_vec()); let encrypted = enc.encrypt(&original).unwrap(); let decrypted = dec.decrypt(&encrypted).unwrap(); assert_eq!(decrypted.content_type, ContentType::ApplicationData); assert_eq!(decrypted.data, b"data"); } #[test] fn encrypt_decrypt_roundtrip_chacha20() { let (key, iv) = test_keys(CipherSuite::Chacha20Poly1305); let mut enc = RecordCryptoState::new(CipherSuite::Chacha20Poly1305, key.clone(), iv); let mut dec = RecordCryptoState::new(CipherSuite::Chacha20Poly1305, key, iv); let original = TlsRecord::new(ContentType::Handshake, b"chacha test".to_vec()); let encrypted = enc.encrypt(&original).unwrap(); let decrypted = dec.decrypt(&encrypted).unwrap(); assert_eq!(decrypted.content_type, ContentType::Handshake); assert_eq!(decrypted.data, b"chacha test"); } #[test] fn sequence_number_increments() { let (key, iv) = test_keys(CipherSuite::Aes128Gcm); let mut enc = RecordCryptoState::new(CipherSuite::Aes128Gcm, key.clone(), iv); let mut dec = RecordCryptoState::new(CipherSuite::Aes128Gcm, key, iv); for i in 0..5u8 { let original = TlsRecord::new(ContentType::ApplicationData, vec![i]); let encrypted = enc.encrypt(&original).unwrap(); let decrypted = dec.decrypt(&encrypted).unwrap(); assert_eq!(decrypted.data, vec![i]); } } #[test] fn out_of_order_decryption_fails() { let (key, iv) = test_keys(CipherSuite::Aes128Gcm); let mut enc = RecordCryptoState::new(CipherSuite::Aes128Gcm, key.clone(), iv); let mut dec = RecordCryptoState::new(CipherSuite::Aes128Gcm, key, iv); let r1 = enc .encrypt(&TlsRecord::new(ContentType::ApplicationData, vec![1])) .unwrap(); let r2 = enc .encrypt(&TlsRecord::new(ContentType::ApplicationData, vec![2])) .unwrap(); // Decrypt r2 first (out of order) should fail assert!(dec.decrypt(&r2).is_err()); // r1 should still work let d1 = dec.decrypt(&r1).unwrap(); assert_eq!(d1.data, vec![1]); } #[test] fn tampered_ciphertext_rejected() { let (key, iv) = test_keys(CipherSuite::Aes128Gcm); let mut enc = RecordCryptoState::new(CipherSuite::Aes128Gcm, key.clone(), iv); let mut dec = RecordCryptoState::new(CipherSuite::Aes128Gcm, key, iv); let original = TlsRecord::new(ContentType::ApplicationData, b"sensitive".to_vec()); let mut encrypted = enc.encrypt(&original).unwrap(); encrypted.data[0] ^= 0xff; // Tamper with ciphertext assert!(dec.decrypt(&encrypted).is_err()); } #[test] fn content_type_hidden_in_encrypted_record() { let (key, iv) = test_keys(CipherSuite::Aes128Gcm); let mut enc = RecordCryptoState::new(CipherSuite::Aes128Gcm, key, iv); // Encrypt a Handshake record let original = TlsRecord::new(ContentType::Handshake, vec![0x01]); let encrypted = enc.encrypt(&original).unwrap(); // Outer type must be ApplicationData (content type hiding) assert_eq!(encrypted.content_type, ContentType::ApplicationData); } #[test] fn encrypt_empty_record() { let (key, iv) = test_keys(CipherSuite::Aes128Gcm); let mut enc = RecordCryptoState::new(CipherSuite::Aes128Gcm, key.clone(), iv); let mut dec = RecordCryptoState::new(CipherSuite::Aes128Gcm, key, iv); let original = TlsRecord::new(ContentType::ApplicationData, vec![]); let encrypted = enc.encrypt(&original).unwrap(); let decrypted = dec.decrypt(&encrypted).unwrap(); assert_eq!(decrypted.content_type, ContentType::ApplicationData); assert!(decrypted.data.is_empty()); } #[test] fn encrypt_record_overflow() { let (key, iv) = test_keys(CipherSuite::Aes128Gcm); let mut enc = RecordCryptoState::new(CipherSuite::Aes128Gcm, key, iv); let big = TlsRecord::new( ContentType::ApplicationData, vec![0u8; MAX_PLAINTEXT_LENGTH + 1], ); assert!(enc.encrypt(&big).is_err()); } #[test] fn decrypt_non_appdata_rejected() { let (key, iv) = test_keys(CipherSuite::Aes128Gcm); let mut dec = RecordCryptoState::new(CipherSuite::Aes128Gcm, key, iv); let bad = TlsRecord::new(ContentType::Handshake, vec![0, 1, 2]); assert!(dec.decrypt(&bad).is_err()); } #[test] fn decrypt_too_short_rejected() { let (key, iv) = test_keys(CipherSuite::Aes128Gcm); let mut dec = RecordCryptoState::new(CipherSuite::Aes128Gcm, key, iv); // Less than TAG_SIZE + 1 let bad = TlsRecord::new(ContentType::ApplicationData, vec![0u8; TAG_SIZE]); assert!(dec.decrypt(&bad).is_err()); } // -- RecordLayer integration tests -- #[test] fn record_layer_plaintext_roundtrip() { let mut buf = Vec::new(); // Write { let cursor = Cursor::new(&mut buf); let mut layer = RecordLayer::new(cursor); layer .write_record(&TlsRecord::new( ContentType::Handshake, b"client hello".to_vec(), )) .unwrap(); } // Read { let cursor = Cursor::new(buf.clone()); let mut layer = RecordLayer::new(cursor); let record = layer.read_record().unwrap(); assert_eq!(record.content_type, ContentType::Handshake); assert_eq!(record.data, b"client hello"); } } #[test] fn record_layer_encrypted_roundtrip() { let (key, iv) = test_keys(CipherSuite::Aes128Gcm); let mut buf = Vec::new(); // Write encrypted { let cursor = Cursor::new(&mut buf); let mut layer = RecordLayer::new(cursor); layer.set_write_crypto(RecordCryptoState::new( CipherSuite::Aes128Gcm, key.clone(), iv, )); layer .write_record(&TlsRecord::new( ContentType::ApplicationData, b"encrypted data".to_vec(), )) .unwrap(); } // Read encrypted { let cursor = Cursor::new(buf.clone()); let mut layer = RecordLayer::new(cursor); layer.set_read_crypto(RecordCryptoState::new(CipherSuite::Aes128Gcm, key, iv)); let record = layer.read_record().unwrap(); assert_eq!(record.content_type, ContentType::ApplicationData); assert_eq!(record.data, b"encrypted data"); } } #[test] fn record_layer_send_close_notify() { let mut buf = Vec::new(); { let cursor = Cursor::new(&mut buf); let mut layer = RecordLayer::new(cursor); layer.send_close_notify().unwrap(); } let cursor = Cursor::new(buf.clone()); let mut layer = RecordLayer::new(cursor); let record = layer.read_record().unwrap(); assert_eq!(record.content_type, ContentType::Alert); let alert = Alert::parse(&record.data).unwrap(); assert_eq!(alert.description, AlertDescription::CloseNotify); assert!(!alert.is_fatal()); } #[test] fn record_layer_plaintext_then_encrypted() { let (key, iv) = test_keys(CipherSuite::Aes128Gcm); let mut buf = Vec::new(); // Write: one plaintext, then switch to encrypted { let cursor = Cursor::new(&mut buf); let mut layer = RecordLayer::new(cursor); // Plaintext handshake layer .write_record(&TlsRecord::new(ContentType::Handshake, b"hello".to_vec())) .unwrap(); // Switch to encrypted layer.set_write_crypto(RecordCryptoState::new( CipherSuite::Aes128Gcm, key.clone(), iv, )); // Encrypted application data layer .write_record(&TlsRecord::new( ContentType::ApplicationData, b"secret".to_vec(), )) .unwrap(); } // Read: one plaintext, then switch to encrypted { let cursor = Cursor::new(buf.clone()); let mut layer = RecordLayer::new(cursor); // Read plaintext let r1 = layer.read_record().unwrap(); assert_eq!(r1.content_type, ContentType::Handshake); assert_eq!(r1.data, b"hello"); // Switch to encrypted layer.set_read_crypto(RecordCryptoState::new(CipherSuite::Aes128Gcm, key, iv)); // Read encrypted let r2 = layer.read_record().unwrap(); assert_eq!(r2.content_type, ContentType::ApplicationData); assert_eq!(r2.data, b"secret"); } } // -- CipherSuite tests -- #[test] fn cipher_suite_key_lengths() { assert_eq!(CipherSuite::Aes128Gcm.key_len(), 16); assert_eq!(CipherSuite::Aes256Gcm.key_len(), 32); assert_eq!(CipherSuite::Chacha20Poly1305.key_len(), 32); } #[test] fn cipher_suite_iv_lengths() { assert_eq!(CipherSuite::Aes128Gcm.iv_len(), 12); assert_eq!(CipherSuite::Aes256Gcm.iv_len(), 12); assert_eq!(CipherSuite::Chacha20Poly1305.iv_len(), 12); } // -- Error Display tests -- #[test] fn tls_error_display() { assert_eq!( TlsError::UnknownContentType(99).to_string(), "unknown TLS content type: 99" ); assert_eq!(TlsError::RecordOverflow.to_string(), "TLS record overflow"); assert_eq!( TlsError::DecryptionFailed.to_string(), "TLS AEAD decryption failed" ); assert_eq!(TlsError::MalformedAlert.to_string(), "malformed TLS alert"); } #[test] fn tls_error_from_io() { let io_err = io::Error::new(io::ErrorKind::BrokenPipe, "broken"); let tls_err = TlsError::from(io_err); assert!(matches!(tls_err, TlsError::Io(_))); } }