This repository has no description.
Branch: main · 162 lines · 5.6 kB · View raw
1use atrium_api::did_doc::DidDocument; 2use atrium_crypto::did::parse_multikey; 3use atrium_crypto::verify::Verifier; 4use atrium_crypto::Algorithm; 5use base64::engine::general_purpose::URL_SAFE_NO_PAD; 6use base64::Engine; 7use serde::Deserialize; 8use std::fmt; 9 10#[derive(Debug)] 11pub enum ServiceAuthError { 12 InvalidFormat, 13 InvalidBase64(String), 14 InvalidJson(String), 15 UnsupportedAlgorithm(String), 16 Expired, 17 InvalidSignature(String), 18 MissingSigningKey, 19 ClaimMismatch(String), 20} 21 22impl fmt::Display for ServiceAuthError { 23 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 24 match self { 25 Self::InvalidFormat => write!(f, "JWT must have three dot-separated parts"), 26 Self::InvalidBase64(msg) => write!(f, "Invalid base64url encoding: {msg}"), 27 Self::InvalidJson(msg) => write!(f, "Invalid JSON in JWT: {msg}"), 28 Self::UnsupportedAlgorithm(alg) => { 29 write!(f, "Unsupported algorithm '{alg}', expected ES256 or ES256K") 30 } 31 Self::Expired => write!(f, "Token has expired"), 32 Self::InvalidSignature(msg) => write!(f, "Invalid signature: {msg}"), 33 Self::MissingSigningKey => write!(f, "No signing key found in DID document"), 34 Self::ClaimMismatch(msg) => write!(f, "{msg}"), 35 } 36 } 37} 38 39#[derive(Debug, Deserialize)] 40struct JwtHeader { 41 alg: String, 42 #[allow(dead_code)] 43 typ: Option<String>, 44} 45 46#[derive(Debug, Deserialize, Clone)] 47pub struct ServiceAuthClaims { 48 pub iss: String, 49 pub aud: String, 50 #[serde(default)] 51 pub exp: u64, 52 #[serde(default)] 53 pub iat: Option<u64>, 54 #[serde(default)] 55 pub lxm: Option<String>, 56 #[serde(default)] 57 pub jti: Option<String>, 58} 59 60/// Decode JWT claims without verifying the signature. 61/// Useful for extracting the `iss` field before resolving the DID document. 
62pub fn decode_jwt_claims(jwt: &str) -> Result<ServiceAuthClaims, ServiceAuthError> { 63 let parts: Vec<&str> = jwt.trim().split('.').collect(); 64 if parts.len() != 3 { 65 return Err(ServiceAuthError::InvalidFormat); 66 } 67 68 let payload_bytes = URL_SAFE_NO_PAD 69 .decode(parts[1]) 70 .map_err(|e| ServiceAuthError::InvalidBase64(e.to_string()))?; 71 72 serde_json::from_slice(&payload_bytes) 73 .map_err(|e| ServiceAuthError::InvalidJson(e.to_string())) 74} 75 76/// Fully decode and verify a serviceAuth JWT. 77/// 78/// `public_key_bytes` should be the decompressed SEC1 public key bytes 79/// from the issuer's DID document (as returned by `extract_signing_key_bytes`). 80/// `key_algorithm` is the algorithm associated with the key from the DID document. 81pub fn decode_and_verify_service_auth( 82 jwt: &str, 83 public_key_bytes: &[u8], 84 key_algorithm: Algorithm, 85) -> Result<ServiceAuthClaims, ServiceAuthError> { 86 let jwt = jwt.trim(); 87 let parts: Vec<&str> = jwt.split('.').collect(); 88 if parts.len() != 3 { 89 return Err(ServiceAuthError::InvalidFormat); 90 } 91 92 let header_bytes = URL_SAFE_NO_PAD 93 .decode(parts[0]) 94 .map_err(|e| ServiceAuthError::InvalidBase64(e.to_string()))?; 95 let payload_bytes = URL_SAFE_NO_PAD 96 .decode(parts[1]) 97 .map_err(|e| ServiceAuthError::InvalidBase64(e.to_string()))?; 98 let signature_bytes = URL_SAFE_NO_PAD 99 .decode(parts[2]) 100 .map_err(|e| ServiceAuthError::InvalidBase64(e.to_string()))?; 101 102 let header: JwtHeader = serde_json::from_slice(&header_bytes) 103 .map_err(|e| ServiceAuthError::InvalidJson(e.to_string()))?; 104 let claims: ServiceAuthClaims = serde_json::from_slice(&payload_bytes) 105 .map_err(|e| ServiceAuthError::InvalidJson(e.to_string()))?; 106 107 // Validate algorithm 108 let jwt_algorithm = match header.alg.as_str() { 109 "ES256" => Algorithm::P256, 110 "ES256K" => Algorithm::Secp256k1, 111 other => return Err(ServiceAuthError::UnsupportedAlgorithm(other.to_string())), 112 }; 113 114 // 
The JWT algorithm must match the key's algorithm from the DID document 115 if jwt_algorithm != key_algorithm { 116 return Err(ServiceAuthError::UnsupportedAlgorithm(format!( 117 "JWT uses {} but DID document key is {}", 118 header.alg, 119 match key_algorithm { 120 Algorithm::P256 => "ES256 (P-256)", 121 Algorithm::Secp256k1 => "ES256K (secp256k1)", 122 } 123 ))); 124 } 125 126 // Check expiry 127 let now = chrono::Utc::now().timestamp() as u64; 128 if claims.exp > 0 && claims.exp < now { 129 return Err(ServiceAuthError::Expired); 130 } 131 132 // Verify signature: signing input is the raw "{header}.{payload}" string 133 let signing_input = format!("{}.{}", parts[0], parts[1]); 134 let verifier = Verifier::new(true); 135 verifier 136 .verify( 137 jwt_algorithm, 138 public_key_bytes, 139 signing_input.as_bytes(), 140 &signature_bytes, 141 ) 142 .map_err(|e| ServiceAuthError::InvalidSignature(e.to_string()))?; 143 144 Ok(claims) 145} 146 147/// Extract the signing key bytes and algorithm from a DID document. 148/// Returns the decompressed SEC1 public key bytes and the associated algorithm. 149pub fn extract_signing_key_bytes( 150 did_doc: &DidDocument, 151) -> Result<(Algorithm, Vec<u8>), ServiceAuthError> { 152 let method = did_doc 153 .get_signing_key() 154 .ok_or(ServiceAuthError::MissingSigningKey)?; 155 let multibase = method 156 .public_key_multibase 157 .as_ref() 158 .ok_or(ServiceAuthError::MissingSigningKey)?; 159 let (alg, key_bytes) = 160 parse_multikey(multibase).map_err(|e| ServiceAuthError::InvalidSignature(e.to_string()))?; 161 Ok((alg, key_bytes)) 162}