Built for people who think better out loud.
at main 140 lines 4.2 kB view raw
//! OAuth policy checks used by HTTP handlers.
use std::collections::HashSet;
use std::fmt;

/// Policy errors surfaced during OAuth validation.
///
/// Implements [`fmt::Display`] and [`std::error::Error`] so handlers can render a
/// human-readable message and propagate the error with `?` into dynamic error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OAuthPolicyError {
    /// The granted scope string does not contain the `atproto` token.
    MissingAtprotoScope,
    /// The granted scope contains at least one token that was not requested.
    ScopeNotSubset,
    /// The PDS associated with the DID does not match the PDS used for authorization.
    DidPdsMismatch,
    /// The issuer associated with the DID does not match the expected issuer.
    IssuerMismatch {
        /// Issuer the caller expected.
        expected: String,
        /// Issuer that was actually observed for the DID.
        actual: String,
    },
}

impl fmt::Display for OAuthPolicyError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MissingAtprotoScope => f.write_str("granted scope does not include `atproto`"),
            Self::ScopeNotSubset => {
                f.write_str("granted scopes are not a subset of the requested scopes")
            }
            Self::DidPdsMismatch => {
                f.write_str("DID's PDS does not match the PDS used for authorization")
            }
            Self::IssuerMismatch { expected, actual } => write!(
                f,
                "authorization server issuer mismatch: expected {expected}, got {actual}"
            ),
        }
    }
}

impl std::error::Error for OAuthPolicyError {}

/// Checks whether a scope string includes the `atproto` scope.
///
/// The scope is split on whitespace and a token must equal `atproto` exactly;
/// substrings such as `atprotox` do not count.
pub fn scope_contains_atproto(scope: &str) -> bool {
    scope.split_whitespace().any(|value| value == "atproto")
}

/// Splits a whitespace-delimited scope string into its set of distinct tokens.
fn scope_set(scope: &str) -> HashSet<&str> {
    // `split_whitespace` never yields empty tokens, so no extra filtering is needed,
    // and borrowing the tokens avoids one allocation per scope.
    scope.split_whitespace().collect()
}

/// Ensures granted scopes are a subset of the requested scopes and include `atproto`.
///
/// Both scope strings are split on whitespace, so token order and duplicates are ignored.
///
/// # Errors
///
/// - [`OAuthPolicyError::MissingAtprotoScope`] if `granted_scope` has no `atproto` token
///   (this is checked first).
/// - [`OAuthPolicyError::ScopeNotSubset`] if `granted_scope` contains any token that does
///   not appear in `requested_scope`.
pub fn ensure_granted_scopes_valid(
    requested_scope: &str,
    granted_scope: &str,
) -> Result<(), OAuthPolicyError> {
    if !scope_contains_atproto(granted_scope) {
        return Err(OAuthPolicyError::MissingAtprotoScope);
    }

    let requested = scope_set(requested_scope);
    let granted = scope_set(granted_scope);
    if !granted.is_subset(&requested) {
        return Err(OAuthPolicyError::ScopeNotSubset);
    }

    Ok(())
}

/// Strips trailing `/` characters so `https://pds.example/` and `https://pds.example`
/// compare equal.
///
/// No other URL normalization (letter case, default ports, path segments) is performed.
fn normalize_endpoint(endpoint: &str) -> &str {
    endpoint.trim_end_matches('/')
}
49pub fn ensure_did_matches_authorization_server( 50 expected_issuer: &str, 51 expected_pds: &str, 52 did_pds: &str, 53 did_issuer: &str, 54) -> Result<(), OAuthPolicyError> { 55 if normalize_endpoint(expected_pds) != normalize_endpoint(did_pds) { 56 return Err(OAuthPolicyError::DidPdsMismatch); 57 } 58 59 if expected_issuer != did_issuer { 60 return Err(OAuthPolicyError::IssuerMismatch { 61 expected: expected_issuer.to_string(), 62 actual: did_issuer.to_string(), 63 }); 64 } 65 66 Ok(()) 67} 68 69#[cfg(test)] 70mod tests { 71 use super::*; 72 use proptest::prelude::*; 73 74 #[test] 75 fn ensure_granted_scopes_rejects_missing_atproto() { 76 let result = ensure_granted_scopes_valid("atproto write", "write"); 77 assert!(matches!(result, Err(OAuthPolicyError::MissingAtprotoScope))); 78 } 79 80 #[test] 81 fn ensure_granted_scopes_rejects_extra_scopes() { 82 let result = ensure_granted_scopes_valid("atproto", "atproto extra"); 83 assert!(matches!(result, Err(OAuthPolicyError::ScopeNotSubset))); 84 } 85 86 #[test] 87 fn ensure_granted_scopes_accepts_subset() { 88 let result = ensure_granted_scopes_valid("atproto write", "atproto"); 89 assert!(result.is_ok()); 90 } 91 92 #[test] 93 fn ensure_did_matches_authorization_server_detects_mismatch() { 94 let result = ensure_did_matches_authorization_server( 95 "https://issuer.example", 96 "https://pds.example", 97 "https://other-pds.example", 98 "https://issuer.example", 99 ); 100 assert!(matches!(result, Err(OAuthPolicyError::DidPdsMismatch))); 101 } 102 103 #[test] 104 fn ensure_did_matches_authorization_server_accepts_match() { 105 let result = ensure_did_matches_authorization_server( 106 "https://issuer.example", 107 "https://pds.example/", 108 "https://pds.example", 109 "https://issuer.example", 110 ); 111 assert!(result.is_ok()); 112 } 113 114 proptest! 
{ 115 #[test] 116 fn scope_contains_atproto_detects_token( 117 prefix in "[a-z]{0,6}", 118 suffix in "[a-z]{0,6}" 119 ) { 120 let scope = format!("{prefix} atproto {suffix}"); 121 prop_assert!(scope_contains_atproto(&scope)); 122 } 123 } 124 125 proptest! { 126 #[test] 127 fn scope_contains_atproto_rejects_missing_token( 128 tokens in proptest::collection::vec("[a-z]{1,8}", 1..6) 129 ) { 130 let scope = tokens 131 .iter() 132 .filter(|token| *token != "atproto") 133 .cloned() 134 .collect::<Vec<_>>() 135 .join(" "); 136 prop_assume!(!scope.split_whitespace().any(|value| value == "atproto")); 137 prop_assert!(!scope_contains_atproto(&scope)); 138 } 139 } 140}