fork of indigo with slightly nicer lexgen
at main 4.5 kB view raw
1package auth 2 3import ( 4 "context" 5 "crypto/rand" 6 "encoding/base64" 7 "errors" 8 "fmt" 9 "log/slog" 10 "time" 11 12 "github.com/bluesky-social/indigo/atproto/crypto" 13 "github.com/bluesky-social/indigo/atproto/identity" 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 16 "github.com/golang-jwt/jwt/v5" 17) 18 19// TODO: check for uniqueness of JTI (random nonce) to prevent token replay 20 21type ServiceAuthValidator struct { 22 // Service DID reference for this validator: a DID with optional #-separated fragment 23 Audience string 24 Dir identity.Directory 25 TimestampLeeway time.Duration 26} 27 28type serviceAuthClaims struct { 29 jwt.RegisteredClaims 30 31 LexMethod string `json:"lxm,omitempty"` 32} 33 34func (s *ServiceAuthValidator) Validate(ctx context.Context, tokenString string, lexMethod *syntax.NSID) (syntax.DID, error) { 35 36 leeway := s.TimestampLeeway 37 if leeway == 0 { 38 leeway = 5 * time.Second 39 } 40 41 opts := []jwt.ParserOption{ 42 jwt.WithValidMethods(supportedAlgs), 43 jwt.WithAudience(s.Audience), 44 jwt.WithExpirationRequired(), 45 jwt.WithIssuedAt(), 46 jwt.WithLeeway(leeway), 47 } 48 49 token, err := jwt.ParseWithClaims(tokenString, &serviceAuthClaims{}, s.fetchIssuerKeyFunc(ctx), opts...) 50 if err != nil && errors.Is(err, jwt.ErrTokenSignatureInvalid) { 51 // if signature validation fails, purge the directory and try again 52 // TODO: probably need to cache or rate-limit this? 53 54 // do an unvalidated extraction of 'iss' from JWT 55 insecure := jwt.NewParser(jwt.WithoutClaimsValidation()) 56 t, _, err := insecure.ParseUnverified(tokenString, &jwt.MapClaims{}) 57 claims, ok := t.Claims.(*jwt.MapClaims) 58 if !ok { 59 return "", jwt.ErrTokenInvalidClaims 60 } 61 iss, err := claims.GetIssuer() 62 if err != nil { 63 return "", err 64 } 65 did, err := syntax.ParseDID(iss) 66 if err != nil { 67 return "", fmt.Errorf("%w: invalid DID: %w", jwt.ErrTokenInvalidIssuer, err) 68 } 69 70 slog.Info("purging directory and retrying service auth signature validation", "did", did) 71 err = s.Dir.Purge(ctx, did.AtIdentifier()) 72 if err != nil { 73 slog.Error("purging identity directory", "did", did, "err", err) 74 } 75 token, err = jwt.ParseWithClaims(tokenString, &serviceAuthClaims{}, s.fetchIssuerKeyFunc(ctx), opts...) 76 } 77 if err != nil { 78 return "", err 79 } 80 claims, ok := token.Claims.(*serviceAuthClaims) 81 if !ok { 82 // TODO: is the error message returned descriptive enough? 83 return "", jwt.ErrTokenInvalidClaims 84 } 85 86 if lexMethod != nil && claims.LexMethod != lexMethod.String() { 87 return "", fmt.Errorf("%w: Lexicon endpoint (LXM)", jwt.ErrTokenInvalidClaims) 88 } 89 90 // NOTE: KeyFunc has already parsed issuer, so we know it is a valid DID 91 did := syntax.DID(claims.Issuer) 92 return did, nil 93} 94 95// resolves public key from identity directory 96func (s *ServiceAuthValidator) fetchIssuerKeyFunc(ctx context.Context) func(token *jwt.Token) (any, error) { 97 return func(token *jwt.Token) (any, error) { 98 claims, ok := token.Claims.(*serviceAuthClaims) 99 if !ok { 100 return nil, jwt.ErrTokenInvalidClaims 101 } 102 iss, err := claims.GetIssuer() 103 if err != nil { 104 return nil, fmt.Errorf("%w: missing 'iss' claim", jwt.ErrTokenInvalidIssuer) 105 } 106 did, err := syntax.ParseDID(iss) 107 if err != nil { 108 return nil, fmt.Errorf("%w: invalid DID: %w", jwt.ErrTokenInvalidIssuer, err) 109 } 110 // NOTE: this will do handle resolution by default 111 ident, err := s.Dir.LookupDID(ctx, did) 112 if err != nil { 113 return nil, fmt.Errorf("%w: resolving DID (%s): %w", jwt.ErrTokenInvalidIssuer, did, err) 114 } 115 return ident.PublicKey() 116 } 117} 118 119func randomNonce() string { 120 buf := make([]byte, 16) 121 rand.Read(buf) 122 return base64.RawURLEncoding.EncodeToString(buf) 123} 124 125func SignServiceAuth(iss syntax.DID, aud string, ttl time.Duration, lexMethod *syntax.NSID, priv crypto.PrivateKey) (string, error) { 126 claims := serviceAuthClaims{ 127 RegisteredClaims: jwt.RegisteredClaims{ 128 ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)), 129 IssuedAt: jwt.NewNumericDate(time.Now()), 130 Issuer: iss.String(), 131 Audience: []string{aud}, 132 ID: randomNonce(), 133 }, 134 } 135 if lexMethod != nil { 136 claims.LexMethod = lexMethod.String() 137 } 138 139 var sm *signingMethodAtproto 140 141 // NOTE: could also have a crypto.PrivateKey.Alg() method which returns a string 142 switch priv.(type) { 143 case *crypto.PrivateKeyP256: 144 sm = signingMethodES256 145 case *crypto.PrivateKeyK256: 146 sm = signingMethodES256K 147 default: 148 return "", fmt.Errorf("unknown signing key type: %T", priv) 149 } 150 151 token := jwt.NewWithClaims(sm, claims) 152 return token.SignedString(priv) 153}