+44
-28
oauth/client/manager.go
+44
-28
oauth/client/manager.go
···
22
22
cli *http.Client
23
23
logger *slog.Logger
24
24
jwksCache cache.Cache[string, jwk.Key]
25
-
metadataCache cache.Cache[string, Metadata]
25
+
metadataCache cache.Cache[string, *Metadata]
26
26
}
27
27
28
28
type ManagerArgs struct {
···
40
40
}
41
41
42
42
jwksCache := cache.NewCache[string, jwk.Key]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
43
-
metadataCache := cache.NewCache[string, Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
43
+
metadataCache := cache.NewCache[string, *Metadata]().WithLRU().WithMaxKeys(500).WithTTL(5 * time.Minute)
44
44
45
45
return &Manager{
46
46
cli: args.Cli,
···
57
57
}
58
58
59
59
var jwks jwk.Key
60
-
if metadata.JWKS != nil && len(metadata.JWKS.Keys) > 0 {
61
-
// TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to
62
-
// make sure we use the right one
63
-
b, err := json.Marshal(metadata.JWKS.Keys[0])
64
-
if err != nil {
65
-
return nil, err
66
-
}
60
+
if metadata.TokenEndpointAuthMethod == "private_key_jwt" {
61
+
if metadata.JWKS != nil && len(metadata.JWKS.Keys) > 0 {
62
+
// TODO: this is kinda bad but whatever for now. there could obviously be more than one jwk, and we need to
63
+
// make sure we use the right one
64
+
b, err := json.Marshal(metadata.JWKS.Keys[0])
65
+
if err != nil {
66
+
return nil, err
67
+
}
67
68
68
-
k, err := helpers.ParseJWKFromBytes(b)
69
-
if err != nil {
70
-
return nil, err
71
-
}
69
+
k, err := helpers.ParseJWKFromBytes(b)
70
+
if err != nil {
71
+
return nil, err
72
+
}
72
73
73
-
jwks = k
74
-
} else if metadata.JWKSURI != nil {
75
-
maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI)
76
-
if err != nil {
77
-
return nil, err
78
-
}
74
+
jwks = k
75
+
} else if metadata.JWKS != nil {
76
+
} else if metadata.JWKSURI != nil {
77
+
maybeJwks, err := cm.getClientJwks(ctx, clientId, *metadata.JWKSURI)
78
+
if err != nil {
79
+
return nil, err
80
+
}
79
81
80
-
jwks = maybeJwks
81
-
} else {
82
-
return nil, fmt.Errorf("no valid jwks found in oauth client metadata")
82
+
jwks = maybeJwks
83
+
} else {
84
+
return nil, fmt.Errorf("no valid jwks found in oauth client metadata")
85
+
}
83
86
}
84
87
85
88
return &Client{
···
89
92
}
90
93
91
94
func (cm *Manager) getClientMetadata(ctx context.Context, clientId string) (*Metadata, error) {
92
-
metadataCached, ok := cm.metadataCache.Get(clientId)
95
+
cached, ok := cm.metadataCache.Get(clientId)
93
96
if !ok {
94
97
req, err := http.NewRequestWithContext(ctx, "GET", clientId, nil)
95
98
if err != nil {
···
117
120
return nil, err
118
121
}
119
122
123
+
cm.metadataCache.Set(clientId, validated, 10*time.Minute)
124
+
120
125
return validated, nil
121
126
} else {
122
-
return &metadataCached, nil
127
+
return cached, nil
123
128
}
124
129
}
125
130
···
204
209
return nil, fmt.Errorf("error unmarshaling metadata: %w", err)
205
210
}
206
211
212
+
if metadata.ClientURI == "" {
213
+
u, err := url.Parse(metadata.ClientID)
214
+
if err != nil {
215
+
return nil, fmt.Errorf("unable to parse client id: %w", err)
216
+
}
217
+
u.RawPath = ""
218
+
u.RawQuery = ""
219
+
metadata.ClientURI = u.String()
220
+
}
221
+
207
222
u, err := url.Parse(metadata.ClientURI)
208
223
if err != nil {
209
224
return nil, fmt.Errorf("unable to parse client uri: %w", err)
210
225
}
211
226
227
+
if metadata.ClientName == "" {
228
+
metadata.ClientName = metadata.ClientURI
229
+
}
230
+
212
231
if isLocalHostname(u.Hostname()) {
213
-
return nil, errors.New("`client_uri` hostname is invalid")
232
+
return nil, fmt.Errorf("`client_uri` hostname is invalid: %s", u.Hostname())
214
233
}
215
234
216
235
if metadata.Scope == "" {
···
349
368
if u.Scheme != "http" {
350
369
return nil, fmt.Errorf("loopback redirect uri %s must use http", ruri)
351
370
}
352
-
353
-
break
354
371
case u.Scheme == "http":
355
372
return nil, errors.New("only loopbvack redirect uris are allowed to use the `http` scheme")
356
373
case u.Scheme == "https":
357
374
if isLocalHostname(u.Hostname()) {
358
375
return nil, fmt.Errorf("redirect uri %s's domain must not be a local hostname", ruri)
359
376
}
360
-
break
361
377
case strings.Contains(u.Scheme, "."):
362
378
if metadata.ApplicationType != "native" {
363
379
return nil, errors.New("private-use uri scheme redirect uris are only allowed for native apps")