//! OAuth policy checks used by HTTP handlers. use std::collections::HashSet; #[derive(Debug)] /// Policy errors surfaced during OAuth validation. pub enum OAuthPolicyError { MissingAtprotoScope, ScopeNotSubset, DidPdsMismatch, IssuerMismatch { expected: String, actual: String }, } /// Checks whether a scope string includes the atproto scope. pub fn scope_contains_atproto(scope: &str) -> bool { scope.split_whitespace().any(|value| value == "atproto") } fn scope_set(scope: &str) -> HashSet { scope .split_whitespace() .filter(|value| !value.is_empty()) .map(|value| value.to_string()) .collect() } /// Ensures granted scopes are a subset of the requested scopes and include atproto. pub fn ensure_granted_scopes_valid( requested_scope: &str, granted_scope: &str, ) -> Result<(), OAuthPolicyError> { if !scope_contains_atproto(granted_scope) { return Err(OAuthPolicyError::MissingAtprotoScope); } let requested = scope_set(requested_scope); let granted = scope_set(granted_scope); if !granted.is_subset(&requested) { return Err(OAuthPolicyError::ScopeNotSubset); } Ok(()) } fn normalize_endpoint(endpoint: &str) -> String { endpoint.trim_end_matches('/').to_string() } /// Validates that the token DID matches the PDS used for authorization. 
///
/// `expected_pds` and `did_pds` are compared after stripping trailing slashes, so
/// `https://pds.example/` matches `https://pds.example`. Issuers are compared as exact
/// strings, with no normalization.
///
/// # Errors
///
/// - [`OAuthPolicyError::DidPdsMismatch`] if the PDS endpoints differ. This is checked first.
/// - [`OAuthPolicyError::IssuerMismatch`] if `did_issuer` differs from `expected_issuer`.
///   The error carries both values.
pub fn ensure_did_matches_authorization_server(
    expected_issuer: &str,
    expected_pds: &str,
    did_pds: &str,
    did_issuer: &str,
) -> Result<(), OAuthPolicyError> {
    if normalize_endpoint(expected_pds) != normalize_endpoint(did_pds) {
        return Err(OAuthPolicyError::DidPdsMismatch);
    }

    // Issuer identifiers are compared by simple string comparison (RFC 8414 §3.3),
    // so unlike the PDS endpoint no trailing-slash normalization is applied.
    if expected_issuer != did_issuer {
        return Err(OAuthPolicyError::IssuerMismatch {
            expected: expected_issuer.to_string(),
            actual: did_issuer.to_string(),
        });
    }

    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn ensure_granted_scopes_rejects_missing_atproto() {
        let result = ensure_granted_scopes_valid("atproto write", "write");
        assert!(matches!(result, Err(OAuthPolicyError::MissingAtprotoScope)));
    }

    #[test]
    fn ensure_granted_scopes_rejects_extra_scopes() {
        let result = ensure_granted_scopes_valid("atproto", "atproto extra");
        assert!(matches!(result, Err(OAuthPolicyError::ScopeNotSubset)));
    }

    #[test]
    fn ensure_granted_scopes_accepts_subset() {
        let result = ensure_granted_scopes_valid("atproto write", "atproto");
        assert!(result.is_ok());
    }

    #[test]
    fn ensure_granted_scopes_ignores_whitespace_and_duplicates() {
        let result =
            ensure_granted_scopes_valid("atproto\ttransition:generic", "  atproto atproto ");
        assert!(result.is_ok());
    }

    #[test]
    fn ensure_granted_scopes_requires_exact_atproto_token() {
        let result = ensure_granted_scopes_valid("atproto atprotox", "atprotox");
        assert!(matches!(result, Err(OAuthPolicyError::MissingAtprotoScope)));
    }

    #[test]
    fn ensure_did_matches_authorization_server_detects_mismatch() {
        let result = ensure_did_matches_authorization_server(
            "https://issuer.example",
            "https://pds.example",
            "https://other-pds.example",
            "https://issuer.example",
        );
        assert!(matches!(result, Err(OAuthPolicyError::DidPdsMismatch)));
    }

    #[test]
    fn ensure_did_matches_authorization_server_accepts_match() {
        let result = ensure_did_matches_authorization_server(
            "https://issuer.example",
            "https://pds.example/",
            "https://pds.example",
            "https://issuer.example",
        );
        assert!(result.is_ok());
    }

    #[test]
    fn ensure_did_matches_authorization_server_reports_issuer_mismatch() {
        let result = ensure_did_matches_authorization_server(
            "https://issuer.example",
            "https://pds.example",
            "https://pds.example",
            "https://other-issuer.example",
        );
        match result {
            Err(OAuthPolicyError::IssuerMismatch { expected, actual }) => {
                assert_eq!(expected, "https://issuer.example");
                assert_eq!(actual, "https://other-issuer.example");
            }
            other => panic!("expected IssuerMismatch, got {other:?}"),
        }
    }

    #[test]
    fn ensure_did_matches_authorization_server_compares_issuer_exactly() {
        let result = ensure_did_matches_authorization_server(
            "https://issuer.example",
            "https://pds.example",
            "https://pds.example",
            "https://issuer.example/",
        );
        assert!(matches!(
            result,
            Err(OAuthPolicyError::IssuerMismatch { .. })
        ));
    }

    #[test]
    fn ensure_did_matches_authorization_server_checks_pds_before_issuer() {
        let result = ensure_did_matches_authorization_server(
            "https://issuer.example",
            "https://pds.example",
            "https://other-pds.example",
            "https://other-issuer.example",
        );
        assert!(matches!(result, Err(OAuthPolicyError::DidPdsMismatch)));
    }

    proptest! {
        #[test]
        fn scope_contains_atproto_detects_token(
            prefix in "[a-z]{0,6}",
            suffix in "[a-z]{0,6}"
        ) {
            let scope = format!("{prefix} atproto {suffix}");
            prop_assert!(scope_contains_atproto(&scope));
        }

        #[test]
        fn scope_contains_atproto_rejects_missing_token(
            tokens in proptest::collection::vec("[a-z]{1,8}", 1..6)
        ) {
            // Drop any generated literal `atproto` token; tokens contain no whitespace,
            // so the joined scope is guaranteed not to contain the token.
            let scope = tokens
                .iter()
                .filter(|token| token.as_str() != "atproto")
                .cloned()
                .collect::<Vec<_>>()
                .join(" ");
            prop_assert!(!scope_contains_atproto(&scope));
        }
    }
}