//! TLS 1.3 handshake client (RFC 8446). //! //! Implements the full TLS 1.3 client handshake including: //! - ClientHello with required extensions //! - ServerHello processing and ECDHE key exchange //! - Encrypted handshake message processing //! - Certificate verification //! - Finished message exchange //! - TlsStream for application data use std::io::{self, Read, Write}; use we_crypto::sha2::{sha256, sha384}; use we_crypto::x25519::{x25519, x25519_base}; use we_crypto::x509::{self, Certificate, DateTime}; use super::key_schedule::{KeySchedule, TranscriptHash}; use super::record::{ CipherSuite, ContentType, RecordCryptoState, RecordLayer, TlsError, TlsRecord, }; // --------------------------------------------------------------------------- // Constants // --------------------------------------------------------------------------- /// TLS 1.2 legacy version (used in ClientHello). const LEGACY_VERSION: [u8; 2] = [0x03, 0x03]; /// TLS 1.3 version identifier for supported_versions extension. const TLS13_VERSION: [u8; 2] = [0x03, 0x04]; // Handshake message types (RFC 8446 §4) const HANDSHAKE_CLIENT_HELLO: u8 = 1; const HANDSHAKE_SERVER_HELLO: u8 = 2; const HANDSHAKE_ENCRYPTED_EXTENSIONS: u8 = 8; const HANDSHAKE_CERTIFICATE: u8 = 11; const HANDSHAKE_CERTIFICATE_VERIFY: u8 = 15; const HANDSHAKE_FINISHED: u8 = 20; // Extension types (RFC 8446 §4.2) const EXT_SERVER_NAME: u16 = 0; const EXT_SUPPORTED_GROUPS: u16 = 10; const EXT_SIGNATURE_ALGORITHMS: u16 = 13; const EXT_SUPPORTED_VERSIONS: u16 = 43; const EXT_KEY_SHARE: u16 = 51; // Named groups const GROUP_X25519: u16 = 0x001d; // Signature schemes (RFC 8446 §4.2.3) const SIG_RSA_PKCS1_SHA256: u16 = 0x0401; const SIG_RSA_PKCS1_SHA384: u16 = 0x0501; const SIG_RSA_PKCS1_SHA512: u16 = 0x0601; const SIG_ECDSA_SECP256R1_SHA256: u16 = 0x0403; const SIG_ECDSA_SECP384R1_SHA384: u16 = 0x0503; const SIG_RSA_PSS_RSAE_SHA256: u16 = 0x0804; const SIG_RSA_PSS_RSAE_SHA384: u16 = 0x0805; // Cipher suite wire values (RFC 8446 §B.4) const CS_AES_128_GCM_SHA256: [u8; 2] = [0x13, 0x01]; const CS_AES_256_GCM_SHA384: [u8; 2] = [0x13, 0x02]; const CS_CHACHA20_POLY1305_SHA256: [u8; 2] = [0x13, 0x03]; // CertificateVerify context string (RFC 8446 §4.4.3) const CV_SERVER_CONTEXT: &[u8] = b"TLS 1.3, server CertificateVerify"; // --------------------------------------------------------------------------- // Error types // --------------------------------------------------------------------------- /// Handshake-specific errors. #[derive(Debug)] pub enum HandshakeError { /// TLS record layer error. Tls(TlsError), /// Unexpected handshake message type. UnexpectedMessage(u8), /// Server selected unsupported cipher suite. UnsupportedCipherSuite, /// Server selected unsupported version. UnsupportedVersion, /// Server did not provide a key share. MissingKeyShare, /// Server key share uses unsupported group. UnsupportedGroup, /// Missing required extension. MissingExtension(&'static str), /// Certificate chain is empty. EmptyCertificateChain, /// Certificate verification failed. CertificateError(String), /// CertificateVerify signature verification failed. SignatureVerificationFailed, /// Finished verify data mismatch. FinishedMismatch, /// Message too short or malformed. Malformed(&'static str), /// I/O error. Io(io::Error), } impl std::fmt::Display for HandshakeError { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { Self::Tls(e) => write!(f, "TLS error: {e}"), Self::UnexpectedMessage(t) => write!(f, "unexpected handshake message type: {t}"), Self::UnsupportedCipherSuite => write!(f, "unsupported cipher suite"), Self::UnsupportedVersion => write!(f, "unsupported TLS version"), Self::MissingKeyShare => write!(f, "server did not provide key share"), Self::UnsupportedGroup => write!(f, "unsupported key exchange group"), Self::MissingExtension(ext) => write!(f, "missing extension: {ext}"), Self::EmptyCertificateChain => write!(f, "empty certificate chain"), Self::CertificateError(e) => write!(f, "certificate error: {e}"), Self::SignatureVerificationFailed => write!(f, "signature verification failed"), Self::FinishedMismatch => write!(f, "finished verify data mismatch"), Self::Malformed(msg) => write!(f, "malformed message: {msg}"), Self::Io(e) => write!(f, "I/O error: {e}"), } } } impl From for HandshakeError { fn from(err: TlsError) -> Self { HandshakeError::Tls(err) } } impl From for HandshakeError { fn from(err: io::Error) -> Self { HandshakeError::Io(err) } } type Result = std::result::Result; // --------------------------------------------------------------------------- // Encoding helpers // --------------------------------------------------------------------------- fn push_u8(buf: &mut Vec, val: u8) { buf.push(val); } fn push_u16(buf: &mut Vec, val: u16) { buf.extend_from_slice(&val.to_be_bytes()); } fn push_u24(buf: &mut Vec, val: u32) { buf.push((val >> 16) as u8); buf.push((val >> 8) as u8); buf.push(val as u8); } fn push_bytes(buf: &mut Vec, data: &[u8]) { buf.extend_from_slice(data); } fn read_u8(data: &[u8], offset: &mut usize) -> Result { if *offset >= data.len() { return Err(HandshakeError::Malformed("unexpected end of data")); } let val = data[*offset]; *offset += 1; Ok(val) } fn read_u16(data: &[u8], offset: &mut usize) -> Result { if *offset + 2 > data.len() { return Err(HandshakeError::Malformed("unexpected end of data")); } let val = u16::from_be_bytes([data[*offset], data[*offset + 1]]); *offset += 2; Ok(val) } fn read_u24(data: &[u8], offset: &mut usize) -> Result { if *offset + 3 > data.len() { return Err(HandshakeError::Malformed("unexpected end of data")); } let val = (data[*offset] as u32) << 16 | (data[*offset + 1] as u32) << 8 | data[*offset + 2] as u32; *offset += 3; Ok(val) } fn read_bytes<'a>(data: &'a [u8], offset: &mut usize, len: usize) -> Result<&'a [u8]> { if *offset + len > data.len() { return Err(HandshakeError::Malformed("unexpected end of data")); } let slice = &data[*offset..*offset + len]; *offset += len; Ok(slice) } // --------------------------------------------------------------------------- // Random bytes generation (using std for now) // --------------------------------------------------------------------------- fn random_bytes(buf: &mut [u8]) { // Read from /dev/urandom for random bytes. // This is available on macOS (which is our only target). let mut f = std::fs::File::open("/dev/urandom").expect("failed to open /dev/urandom"); f.read_exact(buf).expect("failed to read /dev/urandom"); } // --------------------------------------------------------------------------- // ClientHello construction // --------------------------------------------------------------------------- /// Build a ClientHello handshake message. /// /// Returns (handshake_message, x25519_private_key). fn build_client_hello(server_name: &str) -> (Vec, [u8; 32]) { // Generate X25519 ephemeral keypair let mut private_key = [0u8; 32]; random_bytes(&mut private_key); let public_key = x25519_base(&private_key); // Generate random let mut client_random = [0u8; 32]; random_bytes(&mut client_random); // Generate legacy session ID (32 random bytes) let mut session_id = [0u8; 32]; random_bytes(&mut session_id); // Build ClientHello body let mut body = Vec::with_capacity(512); // Protocol version: TLS 1.2 (legacy) push_bytes(&mut body, &LEGACY_VERSION); // Random (32 bytes) push_bytes(&mut body, &client_random); // Legacy session ID (length-prefixed) push_u8(&mut body, 32); push_bytes(&mut body, &session_id); // Cipher suites push_u16(&mut body, 6); // 3 suites * 2 bytes push_bytes(&mut body, &CS_AES_128_GCM_SHA256); push_bytes(&mut body, &CS_AES_256_GCM_SHA384); push_bytes(&mut body, &CS_CHACHA20_POLY1305_SHA256); // Compression methods: only null push_u8(&mut body, 1); // length push_u8(&mut body, 0); // null // Extensions let extensions = build_extensions(server_name, &public_key); push_u16(&mut body, extensions.len() as u16); push_bytes(&mut body, &extensions); // Wrap in handshake header let mut msg = Vec::with_capacity(4 + body.len()); push_u8(&mut msg, HANDSHAKE_CLIENT_HELLO); push_u24(&mut msg, body.len() as u32); push_bytes(&mut msg, &body); (msg, private_key) } fn build_extensions(server_name: &str, x25519_public: &[u8; 32]) -> Vec { let mut exts = Vec::with_capacity(256); // SNI extension (server_name) { let name_bytes = server_name.as_bytes(); // ServerNameList: list_length(2) + type(1) + name_length(2) + name let sni_data_len = 2 + 1 + 2 + name_bytes.len(); push_u16(&mut exts, EXT_SERVER_NAME); push_u16(&mut exts, sni_data_len as u16); // ServerNameList length push_u16(&mut exts, (1 + 2 + name_bytes.len()) as u16); push_u8(&mut exts, 0); // host_name type push_u16(&mut exts, name_bytes.len() as u16); push_bytes(&mut exts, name_bytes); } // supported_versions extension { push_u16(&mut exts, EXT_SUPPORTED_VERSIONS); push_u16(&mut exts, 3); // extension data length push_u8(&mut exts, 2); // list length push_bytes(&mut exts, &TLS13_VERSION); } // supported_groups extension { push_u16(&mut exts, EXT_SUPPORTED_GROUPS); push_u16(&mut exts, 4); // extension data length push_u16(&mut exts, 2); // list length push_u16(&mut exts, GROUP_X25519); } // key_share extension { // KeyShareEntry: group(2) + key_length(2) + key(32) let entry_len = 2 + 2 + 32; push_u16(&mut exts, EXT_KEY_SHARE); push_u16(&mut exts, (2 + entry_len) as u16); // extension data length push_u16(&mut exts, entry_len as u16); // client_shares length push_u16(&mut exts, GROUP_X25519); push_u16(&mut exts, 32); push_bytes(&mut exts, x25519_public); } // signature_algorithms extension { let sig_algs = [ SIG_ECDSA_SECP256R1_SHA256, SIG_ECDSA_SECP384R1_SHA384, SIG_RSA_PSS_RSAE_SHA256, SIG_RSA_PSS_RSAE_SHA384, SIG_RSA_PKCS1_SHA256, SIG_RSA_PKCS1_SHA384, SIG_RSA_PKCS1_SHA512, ]; let list_len = sig_algs.len() * 2; push_u16(&mut exts, EXT_SIGNATURE_ALGORITHMS); push_u16(&mut exts, (2 + list_len) as u16); push_u16(&mut exts, list_len as u16); for alg in sig_algs { push_u16(&mut exts, alg); } } exts } // --------------------------------------------------------------------------- // ServerHello parsing // --------------------------------------------------------------------------- struct ServerHelloResult { cipher_suite: CipherSuite, server_x25519_public: [u8; 32], } fn parse_server_hello(data: &[u8]) -> Result { let mut offset = 0; // legacy_version (2) let _legacy_version = read_bytes(data, &mut offset, 2)?; // random (32) let _random = read_bytes(data, &mut offset, 32)?; // legacy_session_id_echo let session_id_len = read_u8(data, &mut offset)? as usize; let _session_id = read_bytes(data, &mut offset, session_id_len)?; // cipher_suite (2) let cs_bytes = read_bytes(data, &mut offset, 2)?; let cipher_suite = match [cs_bytes[0], cs_bytes[1]] { CS_AES_128_GCM_SHA256 => CipherSuite::Aes128Gcm, CS_AES_256_GCM_SHA384 => CipherSuite::Aes256Gcm, CS_CHACHA20_POLY1305_SHA256 => CipherSuite::Chacha20Poly1305, _ => return Err(HandshakeError::UnsupportedCipherSuite), }; // legacy_compression_method (1) let _compression = read_u8(data, &mut offset)?; // Extensions let extensions_len = read_u16(data, &mut offset)? as usize; let extensions_end = offset + extensions_len; let mut server_key: Option<[u8; 32]> = None; let mut has_supported_versions = false; while offset < extensions_end { let ext_type = read_u16(data, &mut offset)?; let ext_len = read_u16(data, &mut offset)? as usize; let ext_data = read_bytes(data, &mut offset, ext_len)?; match ext_type { EXT_KEY_SHARE => { let mut eoff = 0; let group = read_u16(ext_data, &mut eoff)?; if group != GROUP_X25519 { return Err(HandshakeError::UnsupportedGroup); } let key_len = read_u16(ext_data, &mut eoff)? as usize; if key_len != 32 { return Err(HandshakeError::Malformed("invalid x25519 key length")); } let key_data = read_bytes(ext_data, &mut eoff, 32)?; let mut key = [0u8; 32]; key.copy_from_slice(key_data); server_key = Some(key); } EXT_SUPPORTED_VERSIONS => { if ext_data.len() >= 2 && ext_data[0] == 0x03 && ext_data[1] == 0x04 { has_supported_versions = true; } } _ => {} // Ignore unknown extensions } } if !has_supported_versions { return Err(HandshakeError::UnsupportedVersion); } let server_x25519_public = server_key.ok_or(HandshakeError::MissingKeyShare)?; Ok(ServerHelloResult { cipher_suite, server_x25519_public, }) } // --------------------------------------------------------------------------- // Encrypted handshake message parsing // --------------------------------------------------------------------------- fn parse_encrypted_extensions(data: &[u8]) -> Result<()> { let mut offset = 0; let _extensions_len = read_u16(data, &mut offset)?; // We don't require any specific encrypted extensions for now. // Just validate the format is parseable. Ok(()) } /// Parse a Certificate handshake message (RFC 8446 §4.4.2). /// Returns the list of DER-encoded certificates. fn parse_certificate_message(data: &[u8]) -> Result>> { let mut offset = 0; // certificate_request_context (opaque <0..255>) let ctx_len = read_u8(data, &mut offset)? as usize; let _ctx = read_bytes(data, &mut offset, ctx_len)?; // certificate_list length (3 bytes) let list_len = read_u24(data, &mut offset)? as usize; let list_end = offset + list_len; let mut certs = Vec::new(); while offset < list_end { // cert_data length (3 bytes) let cert_len = read_u24(data, &mut offset)? as usize; let cert_data = read_bytes(data, &mut offset, cert_len)?; certs.push(cert_data.to_vec()); // extensions length (2 bytes per cert entry) let ext_len = read_u16(data, &mut offset)? as usize; let _ext = read_bytes(data, &mut offset, ext_len)?; } if certs.is_empty() { return Err(HandshakeError::EmptyCertificateChain); } Ok(certs) } /// Parse and verify a CertificateVerify message (RFC 8446 §4.4.3). fn verify_certificate_verify( data: &[u8], cert: &Certificate, transcript_hash: &[u8], ) -> Result<()> { let mut offset = 0; let scheme = read_u16(data, &mut offset)?; let sig_len = read_u16(data, &mut offset)? as usize; let signature = read_bytes(data, &mut offset, sig_len)?; // Build the content to verify: // 64 spaces + context string + 0x00 + transcript hash let mut content = Vec::with_capacity(64 + CV_SERVER_CONTEXT.len() + 1 + transcript_hash.len()); content.extend_from_slice(&[0x20u8; 64]); content.extend_from_slice(CV_SERVER_CONTEXT); content.push(0x00); content.extend_from_slice(transcript_hash); match scheme { SIG_RSA_PKCS1_SHA256 => { let pubkey = we_crypto::rsa::RsaPublicKey::from_der(&cert.subject_public_key_info) .map_err(|e| HandshakeError::CertificateError(format!("RSA key: {e:?}")))?; pubkey .verify_pkcs1v15(we_crypto::rsa::HashAlgorithm::Sha256, &content, signature) .map_err(|_| HandshakeError::SignatureVerificationFailed)?; } SIG_RSA_PKCS1_SHA384 => { let pubkey = we_crypto::rsa::RsaPublicKey::from_der(&cert.subject_public_key_info) .map_err(|e| HandshakeError::CertificateError(format!("RSA key: {e:?}")))?; pubkey .verify_pkcs1v15(we_crypto::rsa::HashAlgorithm::Sha384, &content, signature) .map_err(|_| HandshakeError::SignatureVerificationFailed)?; } SIG_RSA_PKCS1_SHA512 => { let pubkey = we_crypto::rsa::RsaPublicKey::from_der(&cert.subject_public_key_info) .map_err(|e| HandshakeError::CertificateError(format!("RSA key: {e:?}")))?; pubkey .verify_pkcs1v15(we_crypto::rsa::HashAlgorithm::Sha512, &content, signature) .map_err(|_| HandshakeError::SignatureVerificationFailed)?; } SIG_ECDSA_SECP256R1_SHA256 => { let pubkey = we_crypto::ecdsa::EcdsaPublicKey::from_spki_der(&cert.subject_public_key_info) .map_err(|e| HandshakeError::CertificateError(format!("ECDSA key: {e:?}")))?; let sig = we_crypto::ecdsa::EcdsaSignature::from_der(signature) .map_err(|e| HandshakeError::CertificateError(format!("ECDSA sig: {e:?}")))?; // Hash the content with SHA-256 and verify let hash = sha256(&content); pubkey .verify_prehashed(&hash, &sig) .map_err(|_| HandshakeError::SignatureVerificationFailed)?; } SIG_ECDSA_SECP384R1_SHA384 => { let pubkey = we_crypto::ecdsa::EcdsaPublicKey::from_spki_der(&cert.subject_public_key_info) .map_err(|e| HandshakeError::CertificateError(format!("ECDSA key: {e:?}")))?; let sig = we_crypto::ecdsa::EcdsaSignature::from_der(signature) .map_err(|e| HandshakeError::CertificateError(format!("ECDSA sig: {e:?}")))?; let hash = sha384(&content); pubkey .verify_prehashed(&hash, &sig) .map_err(|_| HandshakeError::SignatureVerificationFailed)?; } SIG_RSA_PSS_RSAE_SHA256 | SIG_RSA_PSS_RSAE_SHA384 => { // RSA-PSS is common in TLS 1.3. For now, we accept the connection // if we can't verify PSS signatures, but we should implement it. // TODO: Implement RSA-PSS signature verification // For now, skip verification for PSS schemes. // This is a known limitation. return Err(HandshakeError::SignatureVerificationFailed); } _ => { return Err(HandshakeError::SignatureVerificationFailed); } } Ok(()) } // --------------------------------------------------------------------------- // Handshake message reading // --------------------------------------------------------------------------- /// Read the next handshake message from the record layer. /// Returns (handshake_type, body_bytes, full_handshake_message). fn read_handshake_message( record_layer: &mut RecordLayer, ) -> Result<(u8, Vec, Vec)> { let record = record_layer.read_record()?; if record.content_type != ContentType::Handshake { return Err(HandshakeError::Malformed("expected handshake record")); } let data = &record.data; if data.len() < 4 { return Err(HandshakeError::Malformed("handshake message too short")); } let msg_type = data[0]; let body_len = (data[1] as usize) << 16 | (data[2] as usize) << 8 | data[3] as usize; if data.len() < 4 + body_len { return Err(HandshakeError::Malformed( "handshake message body truncated", )); } let body = data[4..4 + body_len].to_vec(); let full_msg = data[..4 + body_len].to_vec(); Ok((msg_type, body, full_msg)) } // --------------------------------------------------------------------------- // TlsStream // --------------------------------------------------------------------------- /// A TLS-encrypted stream over a transport. /// /// Provides `Read` and `Write` for application data after a successful /// TLS 1.3 handshake. pub struct TlsStream { record_layer: RecordLayer, read_buffer: Vec, read_offset: usize, } impl TlsStream { /// Read decrypted application data. pub fn read(&mut self, buf: &mut [u8]) -> Result { // If we have buffered data, return from that first if self.read_offset < self.read_buffer.len() { let available = &self.read_buffer[self.read_offset..]; let to_copy = available.len().min(buf.len()); buf[..to_copy].copy_from_slice(&available[..to_copy]); self.read_offset += to_copy; if self.read_offset >= self.read_buffer.len() { self.read_buffer.clear(); self.read_offset = 0; } return Ok(to_copy); } // Read next record let record = self.record_layer.read_record()?; match record.content_type { ContentType::ApplicationData => { let to_copy = record.data.len().min(buf.len()); buf[..to_copy].copy_from_slice(&record.data[..to_copy]); if to_copy < record.data.len() { self.read_buffer = record.data; self.read_offset = to_copy; } Ok(to_copy) } ContentType::Alert => { if record.data.len() >= 2 && record.data[1] == 0 { // close_notify Ok(0) } else { Err(HandshakeError::Tls(TlsError::DecryptionFailed)) } } _ => Err(HandshakeError::Malformed("unexpected record type")), } } /// Write application data. pub fn write(&mut self, data: &[u8]) -> Result { let record = TlsRecord::new(ContentType::ApplicationData, data.to_vec()); self.record_layer.write_record(&record)?; Ok(data.len()) } /// Write all application data. pub fn write_all(&mut self, data: &[u8]) -> Result<()> { let record = TlsRecord::new(ContentType::ApplicationData, data.to_vec()); self.record_layer.write_record(&record)?; Ok(()) } /// Send a close_notify alert and shut down the TLS connection. pub fn close(&mut self) -> Result<()> { self.record_layer.send_close_notify()?; Ok(()) } /// Get a reference to the underlying stream. pub fn stream(&self) -> &S { self.record_layer.stream() } } // --------------------------------------------------------------------------- // Handshake state machine // --------------------------------------------------------------------------- /// Perform a TLS 1.3 handshake over the given stream. /// /// Returns a `TlsStream` ready for application data. pub fn connect(stream: S, server_name: &str) -> Result> { let mut record_layer = RecordLayer::new(stream); // Step 1: Build and send ClientHello let (client_hello_msg, x25519_private) = build_client_hello(server_name); let ch_record = TlsRecord::new(ContentType::Handshake, client_hello_msg.clone()); record_layer.write_record(&ch_record)?; // Step 2: Read ServerHello let (sh_type, sh_body, sh_full) = read_handshake_message(&mut record_layer)?; if sh_type != HANDSHAKE_SERVER_HELLO { return Err(HandshakeError::UnexpectedMessage(sh_type)); } let sh = parse_server_hello(&sh_body)?; // Step 3: Set up transcript hash and key schedule let mut transcript = TranscriptHash::new(sh.cipher_suite); transcript.update(&client_hello_msg); transcript.update(&sh_full); // Compute ECDHE shared secret let shared_secret = x25519(&x25519_private, &sh.server_x25519_public); // Derive handshake secrets let mut key_schedule = KeySchedule::new(sh.cipher_suite, None); key_schedule.derive_handshake_secrets(&shared_secret, &transcript.current_hash()); // Step 4: Switch to handshake encryption for reading let server_hs_keys = key_schedule.server_handshake_keys(); let server_iv: [u8; 12] = server_hs_keys.iv.try_into().expect("IV must be 12 bytes"); record_layer.set_read_crypto(RecordCryptoState::new( sh.cipher_suite, server_hs_keys.key, server_iv, )); // Step 5: Read EncryptedExtensions let (ee_type, ee_body, ee_full) = read_handshake_message(&mut record_layer)?; if ee_type != HANDSHAKE_ENCRYPTED_EXTENSIONS { return Err(HandshakeError::UnexpectedMessage(ee_type)); } parse_encrypted_extensions(&ee_body)?; transcript.update(&ee_full); // Step 6: Read Certificate let (cert_type, cert_body, cert_full) = read_handshake_message(&mut record_layer)?; if cert_type != HANDSHAKE_CERTIFICATE { return Err(HandshakeError::UnexpectedMessage(cert_type)); } let cert_ders = parse_certificate_message(&cert_body)?; transcript.update(&cert_full); // Parse certificates let mut certs = Vec::with_capacity(cert_ders.len()); for der in &cert_ders { let cert = Certificate::from_der(der) .map_err(|e| HandshakeError::CertificateError(format!("{e:?}")))?; certs.push(cert); } // Validate certificate chain let now = current_datetime(); let root_store = x509::root_ca_store() .map_err(|e| HandshakeError::CertificateError(format!("root CA store: {e:?}")))?; x509::validate_chain(&certs, &root_store, &now) .map_err(|e| HandshakeError::CertificateError(format!("{e:?}")))?; // Verify server name matches certificate verify_server_name(&certs[0], server_name)?; // Step 7: Read CertificateVerify let (cv_type, cv_body, cv_full) = read_handshake_message(&mut record_layer)?; if cv_type != HANDSHAKE_CERTIFICATE_VERIFY { return Err(HandshakeError::UnexpectedMessage(cv_type)); } // Verify CertificateVerify against transcript hash up to Certificate let cv_transcript_hash = transcript.current_hash(); verify_certificate_verify(&cv_body, &certs[0], &cv_transcript_hash)?; transcript.update(&cv_full); // Step 8: Read server Finished let (fin_type, fin_body, fin_full) = read_handshake_message(&mut record_layer)?; if fin_type != HANDSHAKE_FINISHED { return Err(HandshakeError::UnexpectedMessage(fin_type)); } // Verify server Finished let server_finished_hash = transcript.current_hash(); let expected_verify_data = key_schedule.compute_finished_verify_data( key_schedule.server_handshake_traffic_secret().unwrap(), &server_finished_hash, ); if fin_body != expected_verify_data { return Err(HandshakeError::FinishedMismatch); } transcript.update(&fin_full); // Step 9: Switch write to handshake encryption and send client Finished let client_hs_keys = key_schedule.client_handshake_keys(); let client_hs_iv: [u8; 12] = client_hs_keys.iv.try_into().expect("IV must be 12 bytes"); record_layer.set_write_crypto(RecordCryptoState::new( sh.cipher_suite, client_hs_keys.key, client_hs_iv, )); // Compute and send client Finished let client_finished_hash = transcript.current_hash(); let client_verify_data = key_schedule.compute_finished_verify_data( key_schedule.client_handshake_traffic_secret().unwrap(), &client_finished_hash, ); let mut client_finished_msg = Vec::with_capacity(4 + client_verify_data.len()); push_u8(&mut client_finished_msg, HANDSHAKE_FINISHED); push_u24(&mut client_finished_msg, client_verify_data.len() as u32); push_bytes(&mut client_finished_msg, &client_verify_data); let finished_record = TlsRecord::new(ContentType::Handshake, client_finished_msg.clone()); record_layer.write_record(&finished_record)?; transcript.update(&client_finished_msg); // Step 10: Derive application keys // RFC 8446 §7.1: application traffic secrets use the transcript hash through // server Finished (client_finished_hash), NOT including client Finished. key_schedule.derive_app_secrets(&client_finished_hash); let client_app_keys = key_schedule.client_app_keys(); let client_app_iv: [u8; 12] = client_app_keys.iv.try_into().expect("IV must be 12 bytes"); record_layer.set_write_crypto(RecordCryptoState::new( sh.cipher_suite, client_app_keys.key, client_app_iv, )); let server_app_keys = key_schedule.server_app_keys(); let server_app_iv: [u8; 12] = server_app_keys.iv.try_into().expect("IV must be 12 bytes"); record_layer.set_read_crypto(RecordCryptoState::new( sh.cipher_suite, server_app_keys.key, server_app_iv, )); Ok(TlsStream { record_layer, read_buffer: Vec::new(), read_offset: 0, }) } // --------------------------------------------------------------------------- // Server name verification // --------------------------------------------------------------------------- fn verify_server_name(cert: &Certificate, server_name: &str) -> Result<()> { // Check Subject Alternative Names first for san in &cert.extensions.subject_alt_names { if let x509::SubjectAltName::DnsName(dns) = san { if matches_hostname(dns, server_name) { return Ok(()); } } } // Fall back to Common Name if let Some(cn) = &cert.subject.common_name { if matches_hostname(cn, server_name) { return Ok(()); } } Err(HandshakeError::CertificateError(format!( "server name '{server_name}' does not match certificate" ))) } /// Match a hostname against a pattern (supports leading wildcard). fn matches_hostname(pattern: &str, hostname: &str) -> bool { let pattern = pattern.to_ascii_lowercase(); let hostname = hostname.to_ascii_lowercase(); if pattern == hostname { return true; } // Wildcard matching: *.example.com matches foo.example.com if let Some(suffix) = pattern.strip_prefix("*.") { if let Some(rest) = hostname.strip_suffix(suffix) { // The wildcard must match exactly one non-empty label if rest.ends_with('.') && rest.len() > 1 && !rest[..rest.len() - 1].contains('.') { return true; } } } false } // --------------------------------------------------------------------------- // Utility // --------------------------------------------------------------------------- fn current_datetime() -> DateTime { // We use a fixed recent time since we don't have access to system time // in a portable way without external crates. In practice this would // use std::time::SystemTime. // // For now, use SystemTime to compute the actual date. use std::time::{SystemTime, UNIX_EPOCH}; let duration = SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap_or_default(); let secs = duration.as_secs(); // Simple conversion from unix timestamp to date components let days = secs / 86400; let time_of_day = secs % 86400; let hour = (time_of_day / 3600) as u8; let minute = ((time_of_day % 3600) / 60) as u8; let second = (time_of_day % 60) as u8; // Days since 1970-01-01 let (year, month, day) = days_to_date(days); DateTime::new(year, month, day, hour, minute, second) } fn days_to_date(days_since_epoch: u64) -> (u16, u8, u8) { // Algorithm from Howard Hinnant's chrono-compatible date algorithms let z = days_since_epoch + 719468; let era = z / 146097; let doe = z - era * 146097; let yoe = (doe - doe / 1460 + doe / 36524 - doe / 146096) / 365; let y = yoe + era * 400; let doy = doe - (365 * yoe + yoe / 4 - yoe / 100); let mp = (5 * doy + 2) / 153; let d = doy - (153 * mp + 2) / 5 + 1; let m = if mp < 10 { mp + 3 } else { mp - 9 }; let y = if m <= 2 { y + 1 } else { y }; (y as u16, m as u8, d as u8) } // --------------------------------------------------------------------------- // Tests // --------------------------------------------------------------------------- #[cfg(test)] mod tests { use super::*; // -- Encoding helpers -- #[test] fn push_u8_works() { let mut buf = Vec::new(); push_u8(&mut buf, 0x42); assert_eq!(buf, vec![0x42]); } #[test] fn push_u16_works() { let mut buf = Vec::new(); push_u16(&mut buf, 0x1234); assert_eq!(buf, vec![0x12, 0x34]); } #[test] fn push_u24_works() { let mut buf = Vec::new(); push_u24(&mut buf, 0x123456); assert_eq!(buf, vec![0x12, 0x34, 0x56]); } #[test] fn read_u8_works() { let data = [0x42, 0x43]; let mut offset = 0; assert_eq!(read_u8(&data, &mut offset).unwrap(), 0x42); assert_eq!(offset, 1); assert_eq!(read_u8(&data, &mut offset).unwrap(), 0x43); assert_eq!(offset, 2); } #[test] fn read_u16_works() { let data = [0x12, 0x34]; let mut offset = 0; assert_eq!(read_u16(&data, &mut offset).unwrap(), 0x1234); assert_eq!(offset, 2); } #[test] fn read_u24_works() { let data = [0x12, 0x34, 0x56]; let mut offset = 0; assert_eq!(read_u24(&data, &mut offset).unwrap(), 0x123456); assert_eq!(offset, 3); } #[test] fn read_past_end_fails() { let data = [0x42]; let mut offset = 0; assert!(read_u16(&data, &mut offset).is_err()); } #[test] fn read_bytes_works() { let data = [1, 2, 3, 4, 5]; let mut offset = 1; let slice = read_bytes(&data, &mut offset, 3).unwrap(); assert_eq!(slice, &[2, 3, 4]); assert_eq!(offset, 4); } // -- ClientHello construction -- #[test] fn client_hello_has_correct_type() { let (msg, _) = build_client_hello("example.com"); assert_eq!(msg[0], HANDSHAKE_CLIENT_HELLO); } #[test] fn client_hello_has_valid_length() { let (msg, _) = build_client_hello("example.com"); let body_len = (msg[1] as usize) << 16 | (msg[2] as usize) << 8 | msg[3] as usize; assert_eq!(msg.len(), 4 + body_len); } #[test] fn client_hello_starts_with_legacy_version() { let (msg, _) = build_client_hello("example.com"); assert_eq!(msg[4], 0x03); assert_eq!(msg[5], 0x03); } #[test] fn client_hello_has_32_byte_random() { let (msg1, _) = build_client_hello("example.com"); let (msg2, _) = build_client_hello("example.com"); // Random bytes start at offset 6 (after type(1) + length(3) + version(2)) let random1 = &msg1[6..38]; let random2 = &msg2[6..38]; // They should almost certainly differ (random) // But in tests /dev/urandom might give same bytes... just check length assert_eq!(random1.len(), 32); assert_eq!(random2.len(), 32); } #[test] fn client_hello_has_session_id() { let (msg, _) = build_client_hello("example.com"); // session_id_len at offset 38 assert_eq!(msg[38], 32); // 32-byte session ID } #[test] fn client_hello_has_cipher_suites() { let (msg, _) = build_client_hello("example.com"); // After version(2) + random(32) + session_id_len(1) + session_id(32) let cs_offset = 4 + 2 + 32 + 1 + 32; let cs_len = u16::from_be_bytes([msg[cs_offset], msg[cs_offset + 1]]); assert_eq!(cs_len, 6); // 3 suites * 2 bytes } #[test] fn client_hello_returns_private_key() { let (_, private_key) = build_client_hello("example.com"); assert_eq!(private_key.len(), 32); } // -- ServerHello parsing -- fn build_test_server_hello(suite: [u8; 2], x25519_key: &[u8; 32]) -> Vec { let mut body = Vec::new(); // legacy version push_bytes(&mut body, &LEGACY_VERSION); // random push_bytes(&mut body, &[0u8; 32]); // session_id echo (empty) push_u8(&mut body, 0); // cipher suite push_bytes(&mut body, &suite); // compression push_u8(&mut body, 0); // Extensions let mut exts = Vec::new(); // supported_versions push_u16(&mut exts, EXT_SUPPORTED_VERSIONS); push_u16(&mut exts, 2); push_bytes(&mut exts, &TLS13_VERSION); // key_share push_u16(&mut exts, EXT_KEY_SHARE); push_u16(&mut exts, 36); // group(2) + len(2) + key(32) push_u16(&mut exts, GROUP_X25519); push_u16(&mut exts, 32); push_bytes(&mut exts, x25519_key); push_u16(&mut body, exts.len() as u16); push_bytes(&mut body, &exts); body } #[test] fn parse_server_hello_aes128() { let key = [0x42u8; 32]; let body = build_test_server_hello(CS_AES_128_GCM_SHA256, &key); let result = parse_server_hello(&body).unwrap(); assert_eq!(result.cipher_suite, CipherSuite::Aes128Gcm); assert_eq!(result.server_x25519_public, key); } #[test] fn parse_server_hello_aes256() { let key = [0x43u8; 32]; let body = build_test_server_hello(CS_AES_256_GCM_SHA384, &key); let result = parse_server_hello(&body).unwrap(); assert_eq!(result.cipher_suite, CipherSuite::Aes256Gcm); assert_eq!(result.server_x25519_public, key); } #[test] fn parse_server_hello_chacha() { let key = [0x44u8; 32]; let body = build_test_server_hello(CS_CHACHA20_POLY1305_SHA256, &key); let result = parse_server_hello(&body).unwrap(); assert_eq!(result.cipher_suite, CipherSuite::Chacha20Poly1305); } #[test] fn parse_server_hello_unsupported_suite() { let key = [0x42u8; 32]; let mut body = build_test_server_hello(CS_AES_128_GCM_SHA256, &key); // Corrupt cipher suite let cs_offset = 2 + 32 + 1; body[cs_offset] = 0xFF; body[cs_offset + 1] = 0xFF; assert!(parse_server_hello(&body).is_err()); } #[test] fn parse_server_hello_missing_version() { let key = [0x42u8; 32]; let mut body = Vec::new(); push_bytes(&mut body, &LEGACY_VERSION); push_bytes(&mut body, &[0u8; 32]); push_u8(&mut body, 0); push_bytes(&mut body, &CS_AES_128_GCM_SHA256); push_u8(&mut body, 0); // Only key_share, no supported_versions let mut exts = Vec::new(); push_u16(&mut exts, EXT_KEY_SHARE); push_u16(&mut exts, 36); push_u16(&mut exts, GROUP_X25519); push_u16(&mut exts, 32); push_bytes(&mut exts, &key); push_u16(&mut body, exts.len() as u16); push_bytes(&mut body, &exts); assert!(parse_server_hello(&body).is_err()); } #[test] fn parse_server_hello_missing_key_share() { let mut body = Vec::new(); push_bytes(&mut body, &LEGACY_VERSION); push_bytes(&mut body, &[0u8; 32]); push_u8(&mut body, 0); push_bytes(&mut body, &CS_AES_128_GCM_SHA256); push_u8(&mut body, 0); // Only supported_versions, no key_share let mut exts = Vec::new(); push_u16(&mut exts, EXT_SUPPORTED_VERSIONS); push_u16(&mut exts, 2); push_bytes(&mut exts, &TLS13_VERSION); push_u16(&mut body, exts.len() as u16); push_bytes(&mut body, &exts); assert!(parse_server_hello(&body).is_err()); } // -- Certificate message parsing -- #[test] fn parse_empty_certificate_fails() { // certificate_request_context = empty, certificate_list = empty let mut data = Vec::new(); push_u8(&mut data, 0); // empty context push_u24(&mut data, 0); // empty list assert!(parse_certificate_message(&data).is_err()); } #[test] fn parse_certificate_message_single_cert() { let fake_cert = vec![0x30, 0x82, 0x01, 0x00]; // fake DER let mut data = Vec::new(); push_u8(&mut data, 0); // empty context // certificate list let entry_len = 3 + fake_cert.len() + 2; // cert_len(3) + cert + ext_len(2) push_u24(&mut data, entry_len as u32); push_u24(&mut data, fake_cert.len() as u32); push_bytes(&mut data, &fake_cert); push_u16(&mut data, 0); // no extensions let certs = parse_certificate_message(&data).unwrap(); assert_eq!(certs.len(), 1); assert_eq!(certs[0], fake_cert); } #[test] fn parse_certificate_message_two_certs() { let cert1 = vec![0x30, 0x01]; let cert2 = vec![0x30, 0x02, 0x03]; let mut data = Vec::new(); push_u8(&mut data, 0); // empty context let entry1_len = 3 + cert1.len() + 2; let entry2_len = 3 + cert2.len() + 2; push_u24(&mut data, (entry1_len + entry2_len) as u32); push_u24(&mut data, cert1.len() as u32); push_bytes(&mut data, &cert1); push_u16(&mut data, 0); push_u24(&mut data, cert2.len() as u32); push_bytes(&mut data, &cert2); push_u16(&mut data, 0); let certs = parse_certificate_message(&data).unwrap(); assert_eq!(certs.len(), 2); assert_eq!(certs[0], cert1); assert_eq!(certs[1], cert2); } // -- EncryptedExtensions parsing -- #[test] fn parse_encrypted_extensions_empty() { let mut data = Vec::new(); push_u16(&mut data, 0); // empty extensions assert!(parse_encrypted_extensions(&data).is_ok()); } // -- Hostname matching -- #[test] fn exact_hostname_match() { assert!(matches_hostname("example.com", "example.com")); assert!(matches_hostname("Example.COM", "example.com")); } #[test] fn hostname_mismatch() { assert!(!matches_hostname("example.com", "other.com")); assert!(!matches_hostname("example.com", "sub.example.com")); } #[test] fn wildcard_hostname_match() { assert!(matches_hostname("*.example.com", "www.example.com")); assert!(matches_hostname("*.example.com", "mail.example.com")); } #[test] fn wildcard_no_deep_match() { // Wildcard should match only one label assert!(!matches_hostname("*.example.com", "a.b.example.com")); } #[test] fn wildcard_no_bare_match() { assert!(!matches_hostname("*.example.com", "example.com")); } // -- Date conversion -- #[test] fn days_to_date_epoch() { let (y, m, d) = days_to_date(0); assert_eq!((y, m, d), (1970, 1, 1)); } #[test] fn days_to_date_known() { // 2026-03-12 = 20524 days since epoch let (y, m, d) = days_to_date(20524); assert_eq!((y, m, d), (2026, 3, 12)); } #[test] fn days_to_date_2000() { // 2000-01-01 = 10957 days since epoch let (y, m, d) = days_to_date(10957); assert_eq!((y, m, d), (2000, 1, 1)); } // -- Error display -- #[test] fn handshake_error_display() { let err = HandshakeError::UnsupportedCipherSuite; assert_eq!(err.to_string(), "unsupported cipher suite"); } #[test] fn handshake_error_from_tls() { let tls_err = TlsError::RecordOverflow; let hs_err = HandshakeError::from(tls_err); assert!(matches!(hs_err, HandshakeError::Tls(_))); } #[test] fn handshake_error_from_io() { let io_err = io::Error::new(io::ErrorKind::BrokenPipe, "broken"); let hs_err = HandshakeError::from(io_err); assert!(matches!(hs_err, HandshakeError::Io(_))); } // -- Extensions building -- #[test] fn extensions_contain_sni() { let key = [0u8; 32]; let exts = build_extensions("example.com", &key); // First extension should be SNI (type 0x0000) assert_eq!(exts[0], 0x00); assert_eq!(exts[1], 0x00); } #[test] fn extensions_contain_supported_versions() { let key = [0u8; 32]; let exts = build_extensions("example.com", &key); // Search for supported_versions extension type (0x002b = 43) let mut found = false; let mut i = 0; while i + 4 <= exts.len() { let ext_type = u16::from_be_bytes([exts[i], exts[i + 1]]); let ext_len = u16::from_be_bytes([exts[i + 2], exts[i + 3]]) as usize; if ext_type == EXT_SUPPORTED_VERSIONS { found = true; break; } i += 4 + ext_len; } assert!(found, "supported_versions extension not found"); } #[test] fn extensions_contain_key_share() { let key = [0x42u8; 32]; let exts = build_extensions("example.com", &key); let mut found = false; let mut i = 0; while i + 4 <= exts.len() { let ext_type = u16::from_be_bytes([exts[i], exts[i + 1]]); let ext_len = u16::from_be_bytes([exts[i + 2], exts[i + 3]]) as usize; if ext_type == EXT_KEY_SHARE { found = true; // Verify the key is in there // ext_data: list_len(2) + group(2) + key_len(2) + key(32) let key_start = i + 4 + 2 + 2 + 2; assert_eq!(&exts[key_start..key_start + 32], &key); break; } i += 4 + ext_len; } assert!(found, "key_share extension not found"); } #[test] fn extensions_contain_sig_algorithms() { let key = [0u8; 32]; let exts = build_extensions("example.com", &key); let mut found = false; let mut i = 0; while i + 4 <= exts.len() { let ext_type = u16::from_be_bytes([exts[i], exts[i + 1]]); let ext_len = u16::from_be_bytes([exts[i + 2], exts[i + 3]]) as usize; if ext_type == EXT_SIGNATURE_ALGORITHMS { found = true; break; } i += 4 + ext_len; } assert!(found, "signature_algorithms extension not found"); } // -- Build and parse roundtrip -- #[test] fn server_hello_roundtrip_with_session_id() { let key = [0x55u8; 32]; let mut body = Vec::new(); push_bytes(&mut body, &LEGACY_VERSION); push_bytes(&mut body, &[0xAAu8; 32]); // random // 32-byte session ID echo push_u8(&mut body, 32); push_bytes(&mut body, &[0xBBu8; 32]); push_bytes(&mut body, &CS_AES_128_GCM_SHA256); push_u8(&mut body, 0); // compression // Extensions let mut exts = Vec::new(); push_u16(&mut exts, EXT_SUPPORTED_VERSIONS); push_u16(&mut exts, 2); push_bytes(&mut exts, &TLS13_VERSION); push_u16(&mut exts, EXT_KEY_SHARE); push_u16(&mut exts, 36); push_u16(&mut exts, GROUP_X25519); push_u16(&mut exts, 32); push_bytes(&mut exts, &key); push_u16(&mut body, exts.len() as u16); push_bytes(&mut body, &exts); let result = parse_server_hello(&body).unwrap(); assert_eq!(result.cipher_suite, CipherSuite::Aes128Gcm); assert_eq!(result.server_x25519_public, key); } // -- current_datetime sanity check -- #[test] fn current_datetime_reasonable() { let dt = current_datetime(); assert!(dt.year >= 2024); assert!((1..=12).contains(&dt.month)); assert!((1..=31).contains(&dt.day)); } }