use atrium_api::did_doc::DidDocument;
use atrium_crypto::did::parse_multikey;
use atrium_crypto::verify::Verifier;
use atrium_crypto::Algorithm;
use base64::engine::general_purpose::URL_SAFE_NO_PAD;
use base64::Engine;
use serde::Deserialize;
use std::fmt;

/// Errors produced while decoding or verifying a serviceAuth JWT.
#[derive(Debug)]
pub enum ServiceAuthError {
    /// The token does not consist of exactly three `.`-separated segments.
    InvalidFormat,
    /// A segment is not valid unpadded base64url; carries the decoder's message.
    InvalidBase64(String),
    /// The header or payload is not the expected JSON; carries the parser's message.
    InvalidJson(String),
    /// The header `alg` is unsupported or does not match the signing key's algorithm.
    UnsupportedAlgorithm(String),
    /// The token's `exp` time has passed.
    Expired,
    /// Signature verification failed (this module also uses it when the signing
    /// key itself cannot be parsed); carries the underlying message.
    InvalidSignature(String),
    /// The DID document has no signing key, or the key has no multibase encoding.
    MissingSigningKey,
    /// A claim check failed; the message is displayed verbatim with no prefix.
    ClaimMismatch(String),
}

impl fmt::Display for ServiceAuthError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::InvalidFormat => write!(f, "JWT must have three dot-separated parts"),
            Self::InvalidBase64(msg) => write!(f, "Invalid base64url encoding: {msg}"),
            Self::InvalidJson(msg) => write!(f, "Invalid JSON in JWT: {msg}"),
            Self::UnsupportedAlgorithm(alg) => {
                write!(f, "Unsupported algorithm '{alg}', expected ES256 or ES256K")
            }
            Self::Expired => write!(f, "Token has expired"),
            Self::InvalidSignature(msg) => write!(f, "Invalid signature: {msg}"),
            Self::MissingSigningKey => write!(f, "No signing key found in DID document"),
            Self::ClaimMismatch(msg) => write!(f, "{msg}"),
        }
    }
}

// Lets callers propagate this error with `?` into `Box<dyn Error>`, `anyhow`, etc.
// No variant wraps an inner error, so the default `source()` (None) is correct.
impl std::error::Error for ServiceAuthError {}

/// The subset of the JOSE header this module inspects.
#[derive(Debug, Deserialize)]
struct JwtHeader {
    /// Signature algorithm name; only `ES256` and `ES256K` are accepted by the verifier.
    alg: String,
    /// Token type. Deserialized for completeness but never checked.
    #[allow(dead_code)] // field exists only to document the header shape
    typ: Option<String>,
}

/// Claims carried in a serviceAuth JWT payload.
#[derive(Debug, Deserialize, Clone)]
pub struct ServiceAuthClaims {
    /// Issuer (`iss`); can be read before verification to locate the issuer's DID document.
    pub iss: String,
    /// Audience (`aud`).
    pub aud: String,
    /// Expiry as a Unix timestamp in seconds; `0` when the claim is absent.
    #[serde(default)]
    pub exp: u64,
    /// Issued-at time (`iat`), if present.
    #[serde(default)]
    pub iat: Option<u64>,
    /// `lxm` claim, if present — presumably the lexicon method the token is
    /// scoped to; confirm against callers.
    #[serde(default)]
    pub lxm: Option<String>,
    /// Token identifier (`jti`), if present.
    #[serde(default)]
    pub jti: Option<String>,
}

/// Decode JWT claims without verifying the signature.
///
/// Useful for extracting the `iss` field before resolving the DID document.
/// The returned claims are untrusted until the token has been verified.
///
/// # Errors
///
/// Returns [`ServiceAuthError::InvalidFormat`], [`ServiceAuthError::InvalidBase64`],
/// or [`ServiceAuthError::InvalidJson`] if the token is malformed.
pub fn decode_jwt_claims(jwt: &str) -> Result { let parts: Vec<&str> = jwt.trim().split('.').collect(); if parts.len() != 3 { return Err(ServiceAuthError::InvalidFormat); } let payload_bytes = URL_SAFE_NO_PAD .decode(parts[1]) .map_err(|e| ServiceAuthError::InvalidBase64(e.to_string()))?; serde_json::from_slice(&payload_bytes) .map_err(|e| ServiceAuthError::InvalidJson(e.to_string())) } /// Fully decode and verify a serviceAuth JWT. /// /// `public_key_bytes` should be the decompressed SEC1 public key bytes /// from the issuer's DID document (as returned by `extract_signing_key_bytes`). /// `key_algorithm` is the algorithm associated with the key from the DID document. pub fn decode_and_verify_service_auth( jwt: &str, public_key_bytes: &[u8], key_algorithm: Algorithm, ) -> Result { let jwt = jwt.trim(); let parts: Vec<&str> = jwt.split('.').collect(); if parts.len() != 3 { return Err(ServiceAuthError::InvalidFormat); } let header_bytes = URL_SAFE_NO_PAD .decode(parts[0]) .map_err(|e| ServiceAuthError::InvalidBase64(e.to_string()))?; let payload_bytes = URL_SAFE_NO_PAD .decode(parts[1]) .map_err(|e| ServiceAuthError::InvalidBase64(e.to_string()))?; let signature_bytes = URL_SAFE_NO_PAD .decode(parts[2]) .map_err(|e| ServiceAuthError::InvalidBase64(e.to_string()))?; let header: JwtHeader = serde_json::from_slice(&header_bytes) .map_err(|e| ServiceAuthError::InvalidJson(e.to_string()))?; let claims: ServiceAuthClaims = serde_json::from_slice(&payload_bytes) .map_err(|e| ServiceAuthError::InvalidJson(e.to_string()))?; // Validate algorithm let jwt_algorithm = match header.alg.as_str() { "ES256" => Algorithm::P256, "ES256K" => Algorithm::Secp256k1, other => return Err(ServiceAuthError::UnsupportedAlgorithm(other.to_string())), }; // The JWT algorithm must match the key's algorithm from the DID document if jwt_algorithm != key_algorithm { return Err(ServiceAuthError::UnsupportedAlgorithm(format!( "JWT uses {} but DID document key is {}", header.alg, match key_algorithm { 
Algorithm::P256 => "ES256 (P-256)", Algorithm::Secp256k1 => "ES256K (secp256k1)", } ))); } // Check expiry let now = chrono::Utc::now().timestamp() as u64; if claims.exp > 0 && claims.exp < now { return Err(ServiceAuthError::Expired); } // Verify signature: signing input is the raw "{header}.{payload}" string let signing_input = format!("{}.{}", parts[0], parts[1]); let verifier = Verifier::new(true); verifier .verify( jwt_algorithm, public_key_bytes, signing_input.as_bytes(), &signature_bytes, ) .map_err(|e| ServiceAuthError::InvalidSignature(e.to_string()))?; Ok(claims) } /// Extract the signing key bytes and algorithm from a DID document. /// Returns the decompressed SEC1 public key bytes and the associated algorithm. pub fn extract_signing_key_bytes( did_doc: &DidDocument, ) -> Result<(Algorithm, Vec), ServiceAuthError> { let method = did_doc .get_signing_key() .ok_or(ServiceAuthError::MissingSigningKey)?; let multibase = method .public_key_multibase .as_ref() .ok_or(ServiceAuthError::MissingSigningKey)?; let (alg, key_bytes) = parse_multikey(multibase).map_err(|e| ServiceAuthError::InvalidSignature(e.to_string()))?; Ok((alg, key_bytes)) }