forked from hailey.at/cocoon
An atproto PDS written in Go
at main 11 kB view raw
1package client 2 3import ( 4 "context" 5 "encoding/json" 6 "errors" 7 "fmt" 8 "io" 9 "log/slog" 10 "net/http" 11 "net/url" 12 "slices" 13 "strings" 14 "time" 15 16 cache "github.com/go-pkgz/expirable-cache/v3" 17 "github.com/haileyok/cocoon/internal/helpers" 18 "github.com/lestrrat-go/jwx/v2/jwk" 19) 20 21type Manager struct { 22 cli *http.Client 23 logger *slog.Logger 24 jwksCache cache.Cache[string, jwk.Key] 25 metadataCache cache.Cache[string, *Metadata] 26} 27 28type ManagerArgs struct { 29 Cli *http.Client 30 Logger *slog.Logger 31} 32 33func NewManager(args ManagerArgs) *Manager { 34 if args.Logger == nil { 35 args.Logger = slog.Default() 36 } 37 38 if args.Cli == nil { 39 args.Cli = http.DefaultClient 40 } 41 42 jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 43 metadataCache := cache.NewCache[string, *Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute) 44 45 return &Manager{ 46 cli: args.Cli, 47 logger: args.Logger, 48 jwksCache: jwksCache, 49 metadataCache: metadataCache, 50 } 51} 52 53func (cm *Manager) GetClient(ctx context.Context, clientId string) (*Client, error) { 54 metadata, err := cm.getClientMetadata(ctx, clientId) 55 if err != nil { 56 return nil, err 57 } 58 59 var jwks jwk.Key 60 if metadata.TokenEndpointAuthMethod == "private_key_jwt" { 61 if metadata.JWKS != nil && len(metadata.JWKS.Keys) > 0 { 62 // TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to 63 // make sure we use the right one 64 b, err := json.Marshal(metadata.JWKS.Keys[0]) 65 if err != nil { 66 return nil, err 67 } 68 69 k, err := helpers.ParseJWKFromBytes(b) 70 if err != nil { 71 return nil, err 72 } 73 74 jwks = k 75 } else if metadata.JWKSURI != nil { 76 maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI) 77 if err != nil { 78 return nil, err 79 } 80 81 jwks = maybeJwks 82 } else { 83 return nil, fmt.Errorf("no valid jwks found in oauth client metadata") 84 } 85 } 86 87 return &Client{ 88 Metadata: metadata, 89 JWKS: jwks, 90 }, nil 91} 92 93func (cm *Manager) getClientMetadata(ctx context.Context, clientId string) (*Metadata, error) { 94 cached, ok := cm.metadataCache.Get(clientId) 95 if !ok { 96 req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil) 97 if err != nil { 98 return nil, err 99 } 100 101 resp, err := cm.cli.Do(req) 102 if err != nil { 103 return nil, err 104 } 105 defer resp.Body.Close() 106 107 if resp.StatusCode != http.StatusOK { 108 io.Copy(io.Discard, resp.Body) 109 return nil, fmt.Errorf("fetching client metadata returned response code %d", resp.StatusCode) 110 } 111 112 b, err := io.ReadAll(resp.Body) 113 if err != nil { 114 return nil, fmt.Errorf("error reading bytes from client response: %w", err) 115 } 116 117 validated, err := validateAndParseMetadata(clientId, b) 118 if err != nil { 119 return nil, err 120 } 121 122 cm.metadataCache.Set(clientId, validated, 10*time.Minute) 123 124 return validated, nil 125 } else { 126 return cached, nil 127 } 128} 129 130func (cm *Manager) getClientJwks(ctx context.Context, clientId, jwksUri string) (jwk.Key, error) { 131 jwks, ok := cm.jwksCache.Get(clientId) 132 if !ok { 133 req, err := http.NewRequestWithContext(ctx, "GET", jwksUri, nil) 134 if err != nil { 135 return nil, err 136 } 137 138 resp, err := cm.cli.Do(req) 139 if err != nil { 140 return nil, err 141 } 142 defer resp.Body.Close() 143 144 if resp.StatusCode != http.StatusOK { 145 io.Copy(io.Discard, resp.Body) 146 return nil, fmt.Errorf("fetching client jwks returned response code %d", resp.StatusCode) 147 } 148 149 type Keys struct { 150 Keys []map[string]any `json:"keys"` 151 } 152 153 var keys Keys 154 if err := json.NewDecoder(resp.Body).Decode(&keys); err != nil { 155 return nil, fmt.Errorf("error unmarshaling keys response: %w", err) 156 } 157 158 if len(keys.Keys) == 0 { 159 return nil, errors.New("no keys in jwks response") 160 } 161 162 // TODO: this is again bad, we should be figuring out which one we need to use... 163 b, err := json.Marshal(keys.Keys[0]) 164 if err != nil { 165 return nil, fmt.Errorf("could not marshal key: %w", err) 166 } 167 168 k, err := helpers.ParseJWKFromBytes(b) 169 if err != nil { 170 return nil, err 171 } 172 173 jwks = k 174 } 175 176 return jwks, nil 177} 178 179func validateAndParseMetadata(clientId string, b []byte) (*Metadata, error) { 180 var metadataMap map[string]any 181 if err := json.Unmarshal(b, &metadataMap); err != nil { 182 return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 183 } 184 185 _, jwksOk := metadataMap["jwks"].(string) 186 _, jwksUriOk := metadataMap["jwks_uri"].(string) 187 if jwksOk && jwksUriOk { 188 return nil, errors.New("jwks_uri and jwks are mutually exclusive") 189 } 190 191 for _, k := range []string{ 192 "default_max_age", 193 "userinfo_signed_response_alg", 194 "id_token_signed_response_alg", 195 "userinfo_encryhpted_response_alg", 196 "authorization_encrypted_response_enc", 197 "authorization_encrypted_response_alg", 198 "tls_client_certificate_bound_access_tokens", 199 } { 200 _, kOk := metadataMap[k] 201 if kOk { 202 return nil, fmt.Errorf("unsupported `%s` parameter", k) 203 } 204 } 205 206 var metadata Metadata 207 if err := json.Unmarshal(b, &metadata); err != nil { 208 return nil, fmt.Errorf("error unmarshaling metadata: %w", err) 209 } 210 211 if metadata.ClientURI == "" { 212 u, err := url.Parse(metadata.ClientID) 213 if err != nil { 214 return nil, fmt.Errorf("unable to parse client id: %w", err) 215 } 216 u.RawPath = "" 217 u.RawQuery = "" 218 metadata.ClientURI = u.String() 219 } 220 221 u, err := url.Parse(metadata.ClientURI) 222 if err != nil { 223 return nil, fmt.Errorf("unable to parse client uri: %w", err) 224 } 225 226 if metadata.ClientName == "" { 227 metadata.ClientName = metadata.ClientURI 228 } 229 230 if isLocalHostname(u.Hostname()) { 231 return nil, fmt.Errorf("`client_uri` hostname is invalid: %s", u.Hostname()) 232 } 233 234 if metadata.Scope == "" { 235 return nil, errors.New("missing `scopes` scope") 236 } 237 238 scopes := strings.Split(metadata.Scope, " ") 239 if !slices.Contains(scopes, "atproto") { 240 return nil, errors.New("missing `atproto` scope") 241 } 242 243 scopesMap := map[string]bool{} 244 for _, scope := range scopes { 245 if scopesMap[scope] { 246 return nil, fmt.Errorf("duplicate scope `%s`", scope) 247 } 248 249 // TODO: check for unsupported scopes 250 251 scopesMap[scope] = true 252 } 253 254 grantTypesMap := map[string]bool{} 255 for _, gt := range metadata.GrantTypes { 256 if grantTypesMap[gt] { 257 return nil, fmt.Errorf("duplicate grant type `%s`", gt) 258 } 259 260 switch gt { 261 case "implicit": 262 return nil, errors.New("grantg type `implicit` is not allowed") 263 case "authorization_code", "refresh_token": 264 // TODO check if this grant type is supported 265 default: 266 return nil, fmt.Errorf("grant tyhpe `%s` is not supported", gt) 267 } 268 269 grantTypesMap[gt] = true 270 } 271 272 if metadata.ClientID != clientId { 273 return nil, errors.New("`client_id` does not match") 274 } 275 276 subjectType, subjectTypeOk := metadataMap["subject_type"].(string) 277 if subjectTypeOk && subjectType != "public" { 278 return nil, errors.New("only public `subject_type` is supported") 279 } 280 281 switch metadata.TokenEndpointAuthMethod { 282 case "none": 283 if metadata.TokenEndpointAuthSigningAlg != "" { 284 return nil, errors.New("token_endpoint_auth_method `none` must not have token_endpoint_auth_signing_alg") 285 } 286 case "private_key_jwt": 287 if metadata.JWKS == nil && metadata.JWKSURI == nil { 288 return nil, errors.New("private_key_jwt auth method requires jwks or jwks_uri") 289 } 290 291 if metadata.JWKS != nil && len(metadata.JWKS.Keys) == 0 { 292 return nil, errors.New("private_key_jwt auth method requires atleast one key in jwks") 293 } 294 295 if metadata.TokenEndpointAuthSigningAlg == "" { 296 return nil, errors.New("missing token_endpoint_auth_signing_alg in client metadata") 297 } 298 default: 299 return nil, fmt.Errorf("unsupported client authentication method `%s`", metadata.TokenEndpointAuthMethod) 300 } 301 302 if !metadata.DpopBoundAccessTokens { 303 return nil, errors.New("dpop_bound_access_tokens must be true") 304 } 305 306 if !slices.Contains(metadata.ResponseTypes, "code") { 307 return nil, errors.New("response_types must inclue `code`") 308 } 309 310 if !slices.Contains(metadata.GrantTypes, "authorization_code") { 311 return nil, errors.New("the `code` response type requires that `grant_types` contains `authorization_code`") 312 } 313 314 if len(metadata.RedirectURIs) == 0 { 315 return nil, errors.New("at least one `redirect_uri` is required") 316 } 317 318 if metadata.ApplicationType == "native" && metadata.TokenEndpointAuthMethod != "none" { 319 return nil, errors.New("native clients must authenticate using `none` method") 320 } 321 322 if metadata.ApplicationType == "web" && slices.Contains(metadata.GrantTypes, "implicit") { 323 for _, ruri := range metadata.RedirectURIs { 324 u, err := url.Parse(ruri) 325 if err != nil { 326 return nil, fmt.Errorf("error parsing redirect uri: %w", err) 327 } 328 329 if u.Scheme != "https" { 330 return nil, errors.New("web clients must use https redirect uris") 331 } 332 333 if u.Hostname() == "localhost" { 334 return nil, errors.New("web clients must not use localhost as the hostname") 335 } 336 } 337 } 338 339 for _, ruri := range metadata.RedirectURIs { 340 u, err := url.Parse(ruri) 341 if err != nil { 342 return nil, fmt.Errorf("error parsing redirect uri: %w", err) 343 } 344 345 if u.User != nil { 346 if u.User.Username() != "" { 347 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 348 } 349 350 if _, hasPass := u.User.Password(); hasPass { 351 return nil, fmt.Errorf("redirect uri %s must not contain credentials", ruri) 352 } 353 } 354 355 switch true { 356 case u.Hostname() == "localhost": 357 return nil, errors.New("loopback redirect uri is not allowed (use explicit ips instead)") 358 case u.Hostname() == "127.0.0.1", u.Hostname() == "[::1]": 359 if metadata.ApplicationType != "native" { 360 return nil, errors.New("loopback redirect uris are only allowed for native apps") 361 } 362 363 if u.Port() != "" { 364 // reference impl doesn't do anything with this? 365 } 366 367 if u.Scheme != "http" { 368 return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri) 369 } 370 case u.Scheme == "http": 371 return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme") 372 case u.Scheme == "https": 373 if isLocalHostname(u.Hostname()) { 374 return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri) 375 } 376 case strings.Contains(u.Scheme, "."): 377 if metadata.ApplicationType != "native" { 378 return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps") 379 } 380 381 revdomain := reverseDomain(u.Scheme) 382 383 if isLocalHostname(revdomain) { 384 return nil, errors.New("private use uri scheme redirect uris must not be local hostnames") 385 } 386 387 if strings.HasPrefix(u.String(), fmt.Sprintf("%s://", u.Scheme)) || u.Hostname() != "" || u.Port() != "" { 388 return nil, fmt.Errorf("private use uri scheme must be in the form ") 389 } 390 default: 391 return nil, fmt.Errorf("invalid redirect uri scheme `%s`", u.Scheme) 392 } 393 } 394 395 return &metadata, nil 396} 397 398func isLocalHostname(hostname string) bool { 399 pts := strings.Split(hostname, ".") 400 if len(pts) < 2 { 401 return true 402 } 403 404 tld := strings.ToLower(pts[len(pts)-1]) 405 return tld == "test" || tld == "local" || tld == "localhost" || tld == "invalid" || tld == "example" 406} 407 408func reverseDomain(domain string) string { 409 pts := strings.Split(domain, ".") 410 slices.Reverse(pts) 411 return strings.Join(pts, ".") 412}