Built for people who think better out loud.
1//! OAuth policy checks used by HTTP handlers.
2use std::collections::HashSet;
3
/// Policy errors surfaced during OAuth validation.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so handlers can log it or
/// propagate it through `?` into boxed/dynamic error types.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum OAuthPolicyError {
    /// The granted scope string does not contain the `atproto` scope token.
    MissingAtprotoScope,
    /// The granted scopes contain at least one token that was not requested.
    ScopeNotSubset,
    /// The PDS endpoint associated with the DID differs from the PDS used for authorization.
    DidPdsMismatch,
    /// The issuer associated with the DID differs from the expected issuer.
    IssuerMismatch {
        /// Issuer identifier that the authorization flow expected.
        expected: String,
        /// Issuer identifier that was actually observed for the DID.
        actual: String,
    },
}

impl std::fmt::Display for OAuthPolicyError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::MissingAtprotoScope => f.write_str("granted scope does not include `atproto`"),
            Self::ScopeNotSubset => {
                f.write_str("granted scopes are not a subset of the requested scopes")
            }
            Self::DidPdsMismatch => {
                f.write_str("DID PDS does not match the PDS used for authorization")
            }
            Self::IssuerMismatch { expected, actual } => {
                write!(f, "issuer mismatch: expected `{expected}`, got `{actual}`")
            }
        }
    }
}

impl std::error::Error for OAuthPolicyError {}
12
/// Returns `true` when the whitespace-delimited `scope` string contains the exact token
/// `atproto`.
///
/// Matching is per token and case-sensitive, so values such as `xatproto` or `ATPROTO`
/// do not count.
pub fn scope_contains_atproto(scope: &str) -> bool {
    const ATPROTO: &str = "atproto";

    for token in scope.split_whitespace() {
        if token == ATPROTO {
            return true;
        }
    }
    false
}
17
/// Splits a whitespace-delimited OAuth scope string into its set of distinct tokens.
///
/// `split_whitespace` never yields empty items, so repeated, leading, or trailing
/// whitespace needs no extra filtering. Tokens borrow from `scope`, so no per-token
/// allocation is needed just to compare scope sets.
fn scope_set(scope: &str) -> HashSet<&str> {
    scope.split_whitespace().collect()
}
25
26/// Ensures granted scopes are a subset of the requested scopes and include atproto.
27pub fn ensure_granted_scopes_valid(
28 requested_scope: &str,
29 granted_scope: &str,
30) -> Result<(), OAuthPolicyError> {
31 if !scope_contains_atproto(granted_scope) {
32 return Err(OAuthPolicyError::MissingAtprotoScope);
33 }
34
35 let requested = scope_set(requested_scope);
36 let granted = scope_set(granted_scope);
37 if !granted.is_subset(&requested) {
38 return Err(OAuthPolicyError::ScopeNotSubset);
39 }
40
41 Ok(())
42}
43
/// Strips every trailing `/` from `endpoint`, so that `https://pds.example` and
/// `https://pds.example/` compare equal.
///
/// Returns a borrowed slice of the input. The result is only used for comparison, so
/// allocating an owned `String` is unnecessary.
fn normalize_endpoint(endpoint: &str) -> &str {
    endpoint.trim_end_matches('/')
}
47
48/// Validates that the token DID matches the PDS used for authorization.
49pub fn ensure_did_matches_authorization_server(
50 expected_issuer: &str,
51 expected_pds: &str,
52 did_pds: &str,
53 did_issuer: &str,
54) -> Result<(), OAuthPolicyError> {
55 if normalize_endpoint(expected_pds) != normalize_endpoint(did_pds) {
56 return Err(OAuthPolicyError::DidPdsMismatch);
57 }
58
59 if expected_issuer != did_issuer {
60 return Err(OAuthPolicyError::IssuerMismatch {
61 expected: expected_issuer.to_string(),
62 actual: did_issuer.to_string(),
63 });
64 }
65
66 Ok(())
67}
68
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn ensure_granted_scopes_rejects_missing_atproto() {
        let result = ensure_granted_scopes_valid("atproto write", "write");
        assert!(matches!(result, Err(OAuthPolicyError::MissingAtprotoScope)));
    }

    #[test]
    fn ensure_granted_scopes_rejects_extra_scopes() {
        let result = ensure_granted_scopes_valid("atproto", "atproto extra");
        assert!(matches!(result, Err(OAuthPolicyError::ScopeNotSubset)));
    }

    #[test]
    fn ensure_granted_scopes_rejects_atproto_not_requested() {
        // Granting `atproto` when it was never requested is still an unrequested scope.
        let result = ensure_granted_scopes_valid("write", "atproto");
        assert!(matches!(result, Err(OAuthPolicyError::ScopeNotSubset)));
    }

    #[test]
    fn ensure_granted_scopes_accepts_subset() {
        let result = ensure_granted_scopes_valid("atproto write", "atproto");
        assert!(result.is_ok());
    }

    #[test]
    fn ensure_granted_scopes_ignores_whitespace_and_duplicates() {
        let result = ensure_granted_scopes_valid("atproto   write", "\tatproto atproto\nwrite ");
        assert!(result.is_ok());
    }

    #[test]
    fn ensure_did_matches_authorization_server_detects_mismatch() {
        let result = ensure_did_matches_authorization_server(
            "https://issuer.example",
            "https://pds.example",
            "https://other-pds.example",
            "https://issuer.example",
        );
        assert!(matches!(result, Err(OAuthPolicyError::DidPdsMismatch)));
    }

    #[test]
    fn ensure_did_matches_authorization_server_accepts_match() {
        let result = ensure_did_matches_authorization_server(
            "https://issuer.example",
            "https://pds.example/",
            "https://pds.example",
            "https://issuer.example",
        );
        assert!(result.is_ok());
    }

    #[test]
    fn ensure_did_matches_authorization_server_detects_issuer_mismatch() {
        let result = ensure_did_matches_authorization_server(
            "https://issuer.example",
            "https://pds.example",
            "https://pds.example",
            "https://other-issuer.example",
        );
        match result {
            Err(OAuthPolicyError::IssuerMismatch { expected, actual }) => {
                assert_eq!(expected, "https://issuer.example");
                assert_eq!(actual, "https://other-issuer.example");
            }
            other => panic!("expected IssuerMismatch, got {other:?}"),
        }
    }

    #[test]
    fn ensure_did_matches_authorization_server_compares_issuer_exactly() {
        // Unlike PDS endpoints, issuers are not normalized: a trailing slash is a mismatch.
        let result = ensure_did_matches_authorization_server(
            "https://issuer.example",
            "https://pds.example",
            "https://pds.example",
            "https://issuer.example/",
        );
        assert!(matches!(result, Err(OAuthPolicyError::IssuerMismatch { .. })));
    }

    proptest! {
        #[test]
        fn scope_contains_atproto_detects_token(
            prefix in "[a-z]{0,6}",
            suffix in "[a-z]{0,6}"
        ) {
            let scope = format!("{prefix} atproto {suffix}");
            prop_assert!(scope_contains_atproto(&scope));
        }

        #[test]
        fn scope_contains_atproto_rejects_missing_token(
            tokens in proptest::collection::vec("[a-z]{1,8}", 1..6)
        ) {
            // Tokens contain no whitespace, so filtering out `atproto` here guarantees the
            // joined scope has no `atproto` token; no extra `prop_assume!` is needed.
            let scope = tokens
                .iter()
                .filter(|token| *token != "atproto")
                .cloned()
                .collect::<Vec<_>>()
                .join(" ");
            prop_assert!(!scope_contains_atproto(&scope));
        }
    }
}