+147
atproto/auth/jwt.go
+147
atproto/auth/jwt.go
···
1
+
package auth
2
+
3
+
import (
4
+
"context"
5
+
"encoding/base64"
6
+
"errors"
7
+
"fmt"
8
+
"log/slog"
9
+
"math/rand"
10
+
"time"
11
+
12
+
"github.com/bluesky-social/indigo/atproto/crypto"
13
+
"github.com/bluesky-social/indigo/atproto/identity"
14
+
"github.com/bluesky-social/indigo/atproto/syntax"
15
+
16
+
"github.com/golang-jwt/jwt/v5"
17
+
)
18
+
19
+
// TODO: check for uniqueness of JTI (random nonce) to prevent token replay
20
+
21
+
type ServiceAuthValidator struct {
22
+
// Service DID reference for this validator: a DID with optional #-separated fragment
23
+
Audience string
24
+
Dir identity.Directory
25
+
}
26
+
27
+
type serviceAuthClaims struct {
28
+
jwt.RegisteredClaims
29
+
30
+
LexMethod string `json:"lxm,omitempty"`
31
+
}
32
+
33
+
func (s *ServiceAuthValidator) Validate(ctx context.Context, tokenString string, lexMethod *syntax.NSID) (syntax.DID, error) {
34
+
35
+
opts := []jwt.ParserOption{
36
+
jwt.WithValidMethods(supportedAlgs),
37
+
jwt.WithAudience(s.Audience),
38
+
jwt.WithExpirationRequired(),
39
+
jwt.WithIssuedAt(),
40
+
jwt.WithLeeway(5 * time.Second), // TODO: configurable? better default?
41
+
}
42
+
43
+
token, err := jwt.ParseWithClaims(tokenString, &serviceAuthClaims{}, s.fetchIssuerKeyFunc(ctx), opts...)
44
+
if err != nil && errors.Is(err, jwt.ErrTokenSignatureInvalid) {
45
+
// if signature validation fails, purge the directory and try again
46
+
// TODO: probably need to cache or rate-limit this?
47
+
48
+
// do an unvalidated extraction of 'iss' from JWT
49
+
insecure := jwt.NewParser(jwt.WithoutClaimsValidation())
50
+
t, _, err := insecure.ParseUnverified(tokenString, &jwt.MapClaims{})
51
+
claims, ok := t.Claims.(*jwt.MapClaims)
52
+
if !ok {
53
+
return "", jwt.ErrTokenInvalidClaims
54
+
}
55
+
iss, err := claims.GetIssuer()
56
+
if err != nil {
57
+
return "", err
58
+
}
59
+
did, err := syntax.ParseDID(iss)
60
+
if err != nil {
61
+
return "", fmt.Errorf("%w: invalid DID: %w", jwt.ErrTokenInvalidIssuer, err)
62
+
}
63
+
64
+
slog.Info("purging directory and retrying service auth signature validation", "did", did)
65
+
err = s.Dir.Purge(ctx, did.AtIdentifier())
66
+
if err != nil {
67
+
slog.Error("purging identity directory", "did", did, "err", err)
68
+
}
69
+
token, err = jwt.ParseWithClaims(tokenString, &serviceAuthClaims{}, s.fetchIssuerKeyFunc(ctx), opts...)
70
+
}
71
+
if err != nil {
72
+
return "", err
73
+
}
74
+
claims, ok := token.Claims.(*serviceAuthClaims)
75
+
if !ok {
76
+
// TODO: is this the best error here?
77
+
return "", jwt.ErrTokenInvalidClaims
78
+
}
79
+
80
+
if lexMethod != nil && claims.LexMethod != lexMethod.String() {
81
+
return "", fmt.Errorf("%w: Lexicon endpoint (LXM)", jwt.ErrTokenInvalidClaims)
82
+
}
83
+
84
+
// NOTE: KeyFunc has already parsed issuer, so we know it is a valid DID
85
+
did := syntax.DID(claims.Issuer)
86
+
return did, nil
87
+
}
88
+
89
+
// resolves public key from identity directory
90
+
func (s *ServiceAuthValidator) fetchIssuerKeyFunc(ctx context.Context) func(token *jwt.Token) (any, error) {
91
+
return func(token *jwt.Token) (any, error) {
92
+
claims, ok := token.Claims.(*serviceAuthClaims)
93
+
if !ok {
94
+
return nil, fmt.Errorf("%w: missing 'iss'", jwt.ErrTokenInvalidClaims)
95
+
}
96
+
iss, err := claims.GetIssuer()
97
+
if err != nil {
98
+
return nil, fmt.Errorf("%w: missing 'iss'", jwt.ErrTokenInvalidClaims)
99
+
}
100
+
did, err := syntax.ParseDID(iss)
101
+
if err != nil {
102
+
return nil, fmt.Errorf("%w: invalid DID: %w", jwt.ErrTokenInvalidIssuer, err)
103
+
}
104
+
// NOTE: this will do handle resolution by default
105
+
ident, err := s.Dir.LookupDID(ctx, did)
106
+
if err != nil {
107
+
return nil, fmt.Errorf("%w: resolving DID (%s): %w", jwt.ErrTokenInvalidIssuer, did, err)
108
+
}
109
+
return ident.PublicKey()
110
+
}
111
+
}
112
+
113
+
func randomNonce() string {
114
+
buf := make([]byte, 16)
115
+
rand.Read(buf)
116
+
return base64.RawURLEncoding.EncodeToString(buf)
117
+
}
118
+
119
+
func SignServiceAuth(iss syntax.DID, aud string, ttl time.Duration, lexMethod *syntax.NSID, priv crypto.PrivateKey) (string, error) {
120
+
claims := serviceAuthClaims{
121
+
RegisteredClaims: jwt.RegisteredClaims{
122
+
ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)),
123
+
IssuedAt: jwt.NewNumericDate(time.Now()),
124
+
Issuer: iss.String(),
125
+
Audience: []string{aud},
126
+
ID: randomNonce(),
127
+
},
128
+
}
129
+
if lexMethod != nil {
130
+
claims.LexMethod = lexMethod.String()
131
+
}
132
+
133
+
var sm *signingMethodAtproto
134
+
135
+
// NOTE: could also have a crypto.PrivateKey.Alg() method which returns a string
136
+
switch priv.(type) {
137
+
case *crypto.PrivateKeyP256:
138
+
sm = signingMethodES256
139
+
case *crypto.PrivateKeyK256:
140
+
sm = signingMethodES256K
141
+
default:
142
+
return "", fmt.Errorf("unknown signing key type")
143
+
}
144
+
145
+
token := jwt.NewWithClaims(sm, claims)
146
+
return token.SignedString(priv)
147
+
}
+88
atproto/auth/jwt_signing.go
+88
atproto/auth/jwt_signing.go
···
1
+
package auth
2
+
3
+
import (
4
+
"crypto"
5
+
6
+
atcrypto "github.com/bluesky-social/indigo/atproto/crypto"
7
+
"github.com/golang-jwt/jwt/v5"
8
+
)
9
+
10
+
var (
11
+
signingMethodES256K *signingMethodAtproto
12
+
signingMethodES256 *signingMethodAtproto
13
+
supportedAlgs []string
14
+
)
15
+
16
+
// Implementation of jwt.SigningMethod for the `atproto/crypto` types.
17
+
type signingMethodAtproto struct {
18
+
alg string
19
+
hash crypto.Hash
20
+
toOutSig toOutSig
21
+
sigLen int
22
+
}
23
+
24
+
type toOutSig func(sig []byte) []byte
25
+
26
+
func init() {
27
+
// tells JWT library to serialize 'aud' as regular string, not array of strings (when signing)
28
+
jwt.MarshalSingleStringAsArray = false
29
+
30
+
signingMethodES256K = &signingMethodAtproto{
31
+
alg: "ES256K",
32
+
hash: crypto.SHA256,
33
+
toOutSig: toES256K,
34
+
sigLen: 64,
35
+
}
36
+
jwt.RegisterSigningMethod(signingMethodES256K.Alg(), func() jwt.SigningMethod {
37
+
return signingMethodES256K
38
+
})
39
+
signingMethodES256 = &signingMethodAtproto{
40
+
alg: "ES256",
41
+
hash: crypto.SHA256,
42
+
toOutSig: toES256,
43
+
sigLen: 64,
44
+
}
45
+
jwt.RegisterSigningMethod(signingMethodES256.Alg(), func() jwt.SigningMethod {
46
+
return signingMethodES256
47
+
})
48
+
supportedAlgs = []string{signingMethodES256K.Alg(), signingMethodES256.Alg()}
49
+
}
50
+
51
+
func (sm *signingMethodAtproto) Verify(signingString string, sig []byte, key interface{}) error {
52
+
pub, ok := key.(atcrypto.PublicKey)
53
+
if !ok {
54
+
return jwt.ErrInvalidKeyType
55
+
}
56
+
57
+
if !sm.hash.Available() {
58
+
return jwt.ErrHashUnavailable
59
+
}
60
+
61
+
if len(sig) != sm.sigLen {
62
+
return jwt.ErrTokenSignatureInvalid
63
+
}
64
+
65
+
// NOTE: important to use using "lenient" variant here
66
+
return pub.HashAndVerifyLenient([]byte(signingString), sig)
67
+
}
68
+
69
+
func (sm *signingMethodAtproto) Sign(signingString string, key interface{}) ([]byte, error) {
70
+
priv, ok := key.(atcrypto.PrivateKey)
71
+
if !ok {
72
+
return nil, jwt.ErrInvalidKeyType
73
+
}
74
+
75
+
return priv.HashAndSign([]byte(signingString))
76
+
}
77
+
78
+
func (sm *signingMethodAtproto) Alg() string {
79
+
return sm.alg
80
+
}
81
+
82
+
func toES256K(sig []byte) []byte {
83
+
return sig[:64]
84
+
}
85
+
86
+
func toES256(sig []byte) []byte {
87
+
return sig[:64]
88
+
}
+156
atproto/auth/jwt_test.go
+156
atproto/auth/jwt_test.go
···
1
+
package auth
2
+
3
+
import (
4
+
"context"
5
+
"fmt"
6
+
"testing"
7
+
"time"
8
+
9
+
"github.com/bluesky-social/indigo/atproto/crypto"
10
+
"github.com/bluesky-social/indigo/atproto/identity"
11
+
"github.com/bluesky-social/indigo/atproto/syntax"
12
+
13
+
"github.com/golang-jwt/jwt/v5"
14
+
"github.com/stretchr/testify/assert"
15
+
)
16
+
17
+
// Returns an early-2024 timestamp as a point in time for validating known JWTs (which contain expires-at)
18
+
func testTime() time.Time {
19
+
return time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC)
20
+
}
21
+
22
+
func validateMinimal(token string, iss, aud string, pub crypto.PublicKey) error {
23
+
24
+
p := jwt.NewParser(
25
+
jwt.WithValidMethods(supportedAlgs),
26
+
jwt.WithTimeFunc(testTime),
27
+
jwt.WithIssuer(iss),
28
+
jwt.WithAudience(aud),
29
+
)
30
+
_, err := p.Parse(token, func(tok *jwt.Token) (any, error) {
31
+
return pub, nil
32
+
})
33
+
if err != nil {
34
+
return fmt.Errorf("failed to parse auth header JWT: %w", err)
35
+
}
36
+
return nil
37
+
}
38
+
39
+
func TestSignatureMethods(t *testing.T) {
40
+
assert := assert.New(t)
41
+
42
+
jwtTestFixtures := []struct {
43
+
name string
44
+
pubkey string
45
+
iss string
46
+
aud string
47
+
jwt string
48
+
}{
49
+
{
50
+
name: "secp256k1 (K-256)",
51
+
pubkey: "did:key:zQ3shscXNYZQZSPwegiv7uQZZV5kzATLBRtgJhs7uRY7pfSk4",
52
+
iss: "did:example:iss",
53
+
aud: "did:example:aud",
54
+
jwt: "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NksifQ.eyJpc3MiOiJkaWQ6ZXhhbXBsZTppc3MiLCJhdWQiOiJkaWQ6ZXhhbXBsZTphdWQiLCJleHAiOjE3MTM1NzEwMTJ9.J_In_PQCMjygeeoIKyjybORD89ZnEy1bZTd--sdq_78qv3KCO9181ZAh-2Pl0qlXZjfUlxgIa6wiak2NtsT98g",
55
+
},
56
+
{
57
+
name: "secp256k1 (K-256)",
58
+
pubkey: "did:key:zQ3shqKrpHzQ5HDfhgcYMWaFcpBK3SS39wZLdTjA5GeakX8G5",
59
+
iss: "did:example:iss",
60
+
aud: "did:example:aud",
61
+
jwt: "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NksifQ.eyJhdWQiOiJkaWQ6ZXhhbXBsZTphdWQiLCJpc3MiOiJkaWQ6ZXhhbXBsZTppc3MiLCJleHAiOjE3MTM1NzExMzJ9.itNeYcF5oFMZIGxtnbJhE4McSniv_aR-Yk1Wj8uWk1K8YjlS2fzuJMo0-fILV3payETxn6r45f0FfpTaqY0EZQ",
62
+
},
63
+
{
64
+
name: "P-256",
65
+
pubkey: "did:key:zDnaeXRDKRCEUoYxi8ZJS2pDsgfxUh3pZiu3SES9nbY4DoART",
66
+
iss: "did:example:iss",
67
+
aud: "did:example:aud",
68
+
jwt: "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJpc3MiOiJkaWQ6ZXhhbXBsZTppc3MiLCJhdWQiOiJkaWQ6ZXhhbXBsZTphdWQiLCJleHAiOjE3MTM1NzE1NTR9.FFRLm7SGbDUp6cL0WoCs0L5oqNkjCXB963TqbgI-KxIjbiqMQATVCalcMJx17JGTjMmfVHJP6Op_V4Z0TTjqog",
69
+
},
70
+
}
71
+
72
+
for _, fix := range jwtTestFixtures {
73
+
74
+
pubk, err := crypto.ParsePublicDIDKey(fix.pubkey)
75
+
if err != nil {
76
+
t.Fatal(err)
77
+
}
78
+
79
+
assert.NoError(validateMinimal(fix.jwt, fix.iss, fix.aud, pubk))
80
+
}
81
+
}
82
+
83
+
func testSigningValidation(t *testing.T, priv crypto.PrivateKey) {
84
+
assert := assert.New(t)
85
+
ctx := context.Background()
86
+
87
+
iss := syntax.DID("did:example:iss")
88
+
aud := "did:example:aud#svc"
89
+
lxm := syntax.NSID("com.example.api")
90
+
91
+
priv, err := crypto.GeneratePrivateKeyP256()
92
+
if err != nil {
93
+
t.Fatal(err)
94
+
}
95
+
pub, err := priv.PublicKey()
96
+
if err != nil {
97
+
t.Fatal(err)
98
+
}
99
+
100
+
dir := identity.NewMockDirectory()
101
+
dir.Insert(identity.Identity{
102
+
DID: iss,
103
+
Keys: map[string]identity.Key{
104
+
"atproto": identity.Key{
105
+
Type: "Multikey",
106
+
PublicKeyMultibase: pub.Multibase(),
107
+
},
108
+
},
109
+
})
110
+
111
+
v := ServiceAuthValidator{
112
+
Audience: aud,
113
+
Dir: &dir,
114
+
}
115
+
116
+
t1, err := SignServiceAuth(iss, aud, time.Minute, nil, priv)
117
+
if err != nil {
118
+
t.Fatal(err)
119
+
}
120
+
d1, err := v.Validate(ctx, t1, nil)
121
+
assert.NoError(err)
122
+
assert.Equal(d1, iss)
123
+
_, err = v.Validate(ctx, t1, &lxm)
124
+
assert.Error(err)
125
+
126
+
t2, err := SignServiceAuth(iss, aud, time.Minute, &lxm, priv)
127
+
if err != nil {
128
+
t.Fatal(err)
129
+
}
130
+
d2, err := v.Validate(ctx, t2, nil)
131
+
assert.NoError(err)
132
+
assert.Equal(d2, iss)
133
+
_, err = v.Validate(ctx, t2, &lxm)
134
+
assert.NoError(err)
135
+
136
+
_, err = v.Validate(ctx, t2, nil)
137
+
assert.NoError(err)
138
+
_, err = v.Validate(ctx, t2, &lxm)
139
+
assert.NoError(err)
140
+
}
141
+
142
+
func TestP256SigningValidation(t *testing.T) {
143
+
priv, err := crypto.GeneratePrivateKeyP256()
144
+
if err != nil {
145
+
t.Fatal(err)
146
+
}
147
+
testSigningValidation(t, priv)
148
+
}
149
+
150
+
func TestK256SigningValidation(t *testing.T) {
151
+
priv, err := crypto.GeneratePrivateKeyK256()
152
+
if err != nil {
153
+
t.Fatal(err)
154
+
}
155
+
testSigningValidation(t, priv)
156
+
}