fork of indigo with slightly nicer lexgen

MVP atproto inter-service auth token support

Changed files
+391
atproto
+147
atproto/auth/jwt.go
··· 1 + package auth 2 + 3 + import ( 4 + "context" 5 + "encoding/base64" 6 + "errors" 7 + "fmt" 8 + "log/slog" 9 + "math/rand" 10 + "time" 11 + 12 + "github.com/bluesky-social/indigo/atproto/crypto" 13 + "github.com/bluesky-social/indigo/atproto/identity" 14 + "github.com/bluesky-social/indigo/atproto/syntax" 15 + 16 + "github.com/golang-jwt/jwt/v5" 17 + ) 18 + 19 + // TODO: check for uniqueness of JTI (random nonce) to prevent token replay 20 + 21 + type ServiceAuthValidator struct { 22 + // Service DID reference for this validator: a DID with optional #-separated fragment 23 + Audience string 24 + Dir identity.Directory 25 + } 26 + 27 + type serviceAuthClaims struct { 28 + jwt.RegisteredClaims 29 + 30 + LexMethod string `json:"lxm,omitempty"` 31 + } 32 + 33 + func (s *ServiceAuthValidator) Validate(ctx context.Context, tokenString string, lexMethod *syntax.NSID) (syntax.DID, error) { 34 + 35 + opts := []jwt.ParserOption{ 36 + jwt.WithValidMethods(supportedAlgs), 37 + jwt.WithAudience(s.Audience), 38 + jwt.WithExpirationRequired(), 39 + jwt.WithIssuedAt(), 40 + jwt.WithLeeway(5 * time.Second), // TODO: configurable? better default? 41 + } 42 + 43 + token, err := jwt.ParseWithClaims(tokenString, &serviceAuthClaims{}, s.fetchIssuerKeyFunc(ctx), opts...) 44 + if err != nil && errors.Is(err, jwt.ErrTokenSignatureInvalid) { 45 + // if signature validation fails, purge the directory and try again 46 + // TODO: probably need to cache or rate-limit this? 47 + 48 + // do an unvalidated extraction of 'iss' from JWT 49 + insecure := jwt.NewParser(jwt.WithoutClaimsValidation()) 50 + t, _, err := insecure.ParseUnverified(tokenString, &jwt.MapClaims{}) 51 + claims, ok := t.Claims.(*jwt.MapClaims) 52 + if !ok { 53 + return "", jwt.ErrTokenInvalidClaims 54 + } 55 + iss, err := claims.GetIssuer() 56 + if err != nil { 57 + return "", err 58 + } 59 + did, err := syntax.ParseDID(iss) 60 + if err != nil { 61 + return "", fmt.Errorf("%w: invalid DID: %w", jwt.ErrTokenInvalidIssuer, err) 62 + } 63 + 64 + slog.Info("purging directory and retrying service auth signature validation", "did", did) 65 + err = s.Dir.Purge(ctx, did.AtIdentifier()) 66 + if err != nil { 67 + slog.Error("purging identity directory", "did", did, "err", err) 68 + } 69 + token, err = jwt.ParseWithClaims(tokenString, &serviceAuthClaims{}, s.fetchIssuerKeyFunc(ctx), opts...) 70 + } 71 + if err != nil { 72 + return "", err 73 + } 74 + claims, ok := token.Claims.(*serviceAuthClaims) 75 + if !ok { 76 + // TODO: is this the best error here? 77 + return "", jwt.ErrTokenInvalidClaims 78 + } 79 + 80 + if lexMethod != nil && claims.LexMethod != lexMethod.String() { 81 + return "", fmt.Errorf("%w: Lexicon endpoint (LXM)", jwt.ErrTokenInvalidClaims) 82 + } 83 + 84 + // NOTE: KeyFunc has already parsed issuer, so we know it is a valid DID 85 + did := syntax.DID(claims.Issuer) 86 + return did, nil 87 + } 88 + 89 + // resolves public key from identity directory 90 + func (s *ServiceAuthValidator) fetchIssuerKeyFunc(ctx context.Context) func(token *jwt.Token) (any, error) { 91 + return func(token *jwt.Token) (any, error) { 92 + claims, ok := token.Claims.(*serviceAuthClaims) 93 + if !ok { 94 + return nil, fmt.Errorf("%w: missing 'iss'", jwt.ErrTokenInvalidClaims) 95 + } 96 + iss, err := claims.GetIssuer() 97 + if err != nil { 98 + return nil, fmt.Errorf("%w: missing 'iss'", jwt.ErrTokenInvalidClaims) 99 + } 100 + did, err := syntax.ParseDID(iss) 101 + if err != nil { 102 + return nil, fmt.Errorf("%w: invalid DID: %w", jwt.ErrTokenInvalidIssuer, err) 103 + } 104 + // NOTE: this will do handle resolution by default 105 + ident, err := s.Dir.LookupDID(ctx, did) 106 + if err != nil { 107 + return nil, fmt.Errorf("%w: resolving DID (%s): %w", jwt.ErrTokenInvalidIssuer, did, err) 108 + } 109 + return ident.PublicKey() 110 + } 111 + } 112 + 113 + func randomNonce() string { 114 + buf := make([]byte, 16) 115 + rand.Read(buf) 116 + return base64.RawURLEncoding.EncodeToString(buf) 117 + } 118 + 119 + func SignServiceAuth(iss syntax.DID, aud string, ttl time.Duration, lexMethod *syntax.NSID, priv crypto.PrivateKey) (string, error) { 120 + claims := serviceAuthClaims{ 121 + RegisteredClaims: jwt.RegisteredClaims{ 122 + ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)), 123 + IssuedAt: jwt.NewNumericDate(time.Now()), 124 + Issuer: iss.String(), 125 + Audience: []string{aud}, 126 + ID: randomNonce(), 127 + }, 128 + } 129 + if lexMethod != nil { 130 + claims.LexMethod = lexMethod.String() 131 + } 132 + 133 + var sm *signingMethodAtproto 134 + 135 + // NOTE: could also have a crypto.PrivateKey.Alg() method which returns a string 136 + switch priv.(type) { 137 + case *crypto.PrivateKeyP256: 138 + sm = signingMethodES256 139 + case *crypto.PrivateKeyK256: 140 + sm = signingMethodES256K 141 + default: 142 + return "", fmt.Errorf("unknown signing key type") 143 + } 144 + 145 + token := jwt.NewWithClaims(sm, claims) 146 + return token.SignedString(priv) 147 + }
+88
atproto/auth/jwt_signing.go
··· 1 + package auth 2 + 3 + import ( 4 + "crypto" 5 + 6 + atcrypto "github.com/bluesky-social/indigo/atproto/crypto" 7 + "github.com/golang-jwt/jwt/v5" 8 + ) 9 + 10 + var ( 11 + signingMethodES256K *signingMethodAtproto 12 + signingMethodES256 *signingMethodAtproto 13 + supportedAlgs []string 14 + ) 15 + 16 + // Implementation of jwt.SigningMethod for the `atproto/crypto` types. 17 + type signingMethodAtproto struct { 18 + alg string 19 + hash crypto.Hash 20 + toOutSig toOutSig 21 + sigLen int 22 + } 23 + 24 + type toOutSig func(sig []byte) []byte 25 + 26 + func init() { 27 + // tells JWT library to serialize 'aud' as regular string, not array of strings (when signing) 28 + jwt.MarshalSingleStringAsArray = false 29 + 30 + signingMethodES256K = &signingMethodAtproto{ 31 + alg: "ES256K", 32 + hash: crypto.SHA256, 33 + toOutSig: toES256K, 34 + sigLen: 64, 35 + } 36 + jwt.RegisterSigningMethod(signingMethodES256K.Alg(), func() jwt.SigningMethod { 37 + return signingMethodES256K 38 + }) 39 + signingMethodES256 = &signingMethodAtproto{ 40 + alg: "ES256", 41 + hash: crypto.SHA256, 42 + toOutSig: toES256, 43 + sigLen: 64, 44 + } 45 + jwt.RegisterSigningMethod(signingMethodES256.Alg(), func() jwt.SigningMethod { 46 + return signingMethodES256 47 + }) 48 + supportedAlgs = []string{signingMethodES256K.Alg(), signingMethodES256.Alg()} 49 + } 50 + 51 + func (sm *signingMethodAtproto) Verify(signingString string, sig []byte, key interface{}) error { 52 + pub, ok := key.(atcrypto.PublicKey) 53 + if !ok { 54 + return jwt.ErrInvalidKeyType 55 + } 56 + 57 + if !sm.hash.Available() { 58 + return jwt.ErrHashUnavailable 59 + } 60 + 61 + if len(sig) != sm.sigLen { 62 + return jwt.ErrTokenSignatureInvalid 63 + } 64 + 65 + // NOTE: important to use using "lenient" variant here 66 + return pub.HashAndVerifyLenient([]byte(signingString), sig) 67 + } 68 + 69 + func (sm *signingMethodAtproto) Sign(signingString string, key interface{}) ([]byte, error) { 70 + priv, ok := key.(atcrypto.PrivateKey) 71 + if !ok { 72 + return nil, jwt.ErrInvalidKeyType 73 + } 74 + 75 + return priv.HashAndSign([]byte(signingString)) 76 + } 77 + 78 + func (sm *signingMethodAtproto) Alg() string { 79 + return sm.alg 80 + } 81 + 82 + func toES256K(sig []byte) []byte { 83 + return sig[:64] 84 + } 85 + 86 + func toES256(sig []byte) []byte { 87 + return sig[:64] 88 + }
+156
atproto/auth/jwt_test.go
··· 1 + package auth 2 + 3 + import ( 4 + "context" 5 + "fmt" 6 + "testing" 7 + "time" 8 + 9 + "github.com/bluesky-social/indigo/atproto/crypto" 10 + "github.com/bluesky-social/indigo/atproto/identity" 11 + "github.com/bluesky-social/indigo/atproto/syntax" 12 + 13 + "github.com/golang-jwt/jwt/v5" 14 + "github.com/stretchr/testify/assert" 15 + ) 16 + 17 + // Returns an early-2024 timestamp as a point in time for validating known JWTs (which contain expires-at) 18 + func testTime() time.Time { 19 + return time.Date(2024, 1, 1, 0, 0, 0, 0, time.UTC) 20 + } 21 + 22 + func validateMinimal(token string, iss, aud string, pub crypto.PublicKey) error { 23 + 24 + p := jwt.NewParser( 25 + jwt.WithValidMethods(supportedAlgs), 26 + jwt.WithTimeFunc(testTime), 27 + jwt.WithIssuer(iss), 28 + jwt.WithAudience(aud), 29 + ) 30 + _, err := p.Parse(token, func(tok *jwt.Token) (any, error) { 31 + return pub, nil 32 + }) 33 + if err != nil { 34 + return fmt.Errorf("failed to parse auth header JWT: %w", err) 35 + } 36 + return nil 37 + } 38 + 39 + func TestSignatureMethods(t *testing.T) { 40 + assert := assert.New(t) 41 + 42 + jwtTestFixtures := []struct { 43 + name string 44 + pubkey string 45 + iss string 46 + aud string 47 + jwt string 48 + }{ 49 + { 50 + name: "secp256k1 (K-256)", 51 + pubkey: "did:key:zQ3shscXNYZQZSPwegiv7uQZZV5kzATLBRtgJhs7uRY7pfSk4", 52 + iss: "did:example:iss", 53 + aud: "did:example:aud", 54 + jwt: "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NksifQ.eyJpc3MiOiJkaWQ6ZXhhbXBsZTppc3MiLCJhdWQiOiJkaWQ6ZXhhbXBsZTphdWQiLCJleHAiOjE3MTM1NzEwMTJ9.J_In_PQCMjygeeoIKyjybORD89ZnEy1bZTd--sdq_78qv3KCO9181ZAh-2Pl0qlXZjfUlxgIa6wiak2NtsT98g", 55 + }, 56 + { 57 + name: "secp256k1 (K-256)", 58 + pubkey: "did:key:zQ3shqKrpHzQ5HDfhgcYMWaFcpBK3SS39wZLdTjA5GeakX8G5", 59 + iss: "did:example:iss", 60 + aud: "did:example:aud", 61 + jwt: "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NksifQ.eyJhdWQiOiJkaWQ6ZXhhbXBsZTphdWQiLCJpc3MiOiJkaWQ6ZXhhbXBsZTppc3MiLCJleHAiOjE3MTM1NzExMzJ9.itNeYcF5oFMZIGxtnbJhE4McSniv_aR-Yk1Wj8uWk1K8YjlS2fzuJMo0-fILV3payETxn6r45f0FfpTaqY0EZQ", 62 + }, 63 + { 64 + name: "P-256", 65 + pubkey: "did:key:zDnaeXRDKRCEUoYxi8ZJS2pDsgfxUh3pZiu3SES9nbY4DoART", 66 + iss: "did:example:iss", 67 + aud: "did:example:aud", 68 + jwt: "eyJ0eXAiOiJKV1QiLCJhbGciOiJFUzI1NiJ9.eyJpc3MiOiJkaWQ6ZXhhbXBsZTppc3MiLCJhdWQiOiJkaWQ6ZXhhbXBsZTphdWQiLCJleHAiOjE3MTM1NzE1NTR9.FFRLm7SGbDUp6cL0WoCs0L5oqNkjCXB963TqbgI-KxIjbiqMQATVCalcMJx17JGTjMmfVHJP6Op_V4Z0TTjqog", 69 + }, 70 + } 71 + 72 + for _, fix := range jwtTestFixtures { 73 + 74 + pubk, err := crypto.ParsePublicDIDKey(fix.pubkey) 75 + if err != nil { 76 + t.Fatal(err) 77 + } 78 + 79 + assert.NoError(validateMinimal(fix.jwt, fix.iss, fix.aud, pubk)) 80 + } 81 + } 82 + 83 + func testSigningValidation(t *testing.T, priv crypto.PrivateKey) { 84 + assert := assert.New(t) 85 + ctx := context.Background() 86 + 87 + iss := syntax.DID("did:example:iss") 88 + aud := "did:example:aud#svc" 89 + lxm := syntax.NSID("com.example.api") 90 + 91 + priv, err := crypto.GeneratePrivateKeyP256() 92 + if err != nil { 93 + t.Fatal(err) 94 + } 95 + pub, err := priv.PublicKey() 96 + if err != nil { 97 + t.Fatal(err) 98 + } 99 + 100 + dir := identity.NewMockDirectory() 101 + dir.Insert(identity.Identity{ 102 + DID: iss, 103 + Keys: map[string]identity.Key{ 104 + "atproto": identity.Key{ 105 + Type: "Multikey", 106 + PublicKeyMultibase: pub.Multibase(), 107 + }, 108 + }, 109 + }) 110 + 111 + v := ServiceAuthValidator{ 112 + Audience: aud, 113 + Dir: &dir, 114 + } 115 + 116 + t1, err := SignServiceAuth(iss, aud, time.Minute, nil, priv) 117 + if err != nil { 118 + t.Fatal(err) 119 + } 120 + d1, err := v.Validate(ctx, t1, nil) 121 + assert.NoError(err) 122 + assert.Equal(d1, iss) 123 + _, err = v.Validate(ctx, t1, &lxm) 124 + assert.Error(err) 125 + 126 + t2, err := SignServiceAuth(iss, aud, time.Minute, &lxm, priv) 127 + if err != nil { 128 + t.Fatal(err) 129 + } 130 + d2, err := v.Validate(ctx, t2, nil) 131 + assert.NoError(err) 132 + assert.Equal(d2, iss) 133 + _, err = v.Validate(ctx, t2, &lxm) 134 + assert.NoError(err) 135 + 136 + _, err = v.Validate(ctx, t2, nil) 137 + assert.NoError(err) 138 + _, err = v.Validate(ctx, t2, &lxm) 139 + assert.NoError(err) 140 + } 141 + 142 + func TestP256SigningValidation(t *testing.T) { 143 + priv, err := crypto.GeneratePrivateKeyP256() 144 + if err != nil { 145 + t.Fatal(err) 146 + } 147 + testSigningValidation(t, priv) 148 + } 149 + 150 + func TestK256SigningValidation(t *testing.T) { 151 + priv, err := crypto.GeneratePrivateKeyK256() 152 + if err != nil { 153 + t.Fatal(err) 154 + } 155 + testSigningValidation(t, priv) 156 + }