1package auth
2
3import (
4 "context"
5 "crypto/rand"
6 "encoding/base64"
7 "errors"
8 "fmt"
9 "log/slog"
10 "time"
11
12 "github.com/bluesky-social/indigo/atproto/crypto"
13 "github.com/bluesky-social/indigo/atproto/identity"
14 "github.com/bluesky-social/indigo/atproto/syntax"
15
16 "github.com/golang-jwt/jwt/v5"
17)
18
19// TODO: check for uniqueness of JTI (random nonce) to prevent token replay
20
// ServiceAuthValidator verifies inter-service auth JWTs; see Validate.
//
// NOTE(review): Validate passes Audience straight to jwt.WithAudience. Confirm
// how the jwt library treats an empty expected audience before relying on a
// zero-value Audience.
type ServiceAuthValidator struct {
	// Service DID reference for this validator: a DID with optional #-separated fragment
	Audience string
	// Identity directory used to resolve a token issuer's DID to its public
	// key (LookupDID), and purged for that DID when signature checks fail.
	Dir identity.Directory
	// Allowed clock skew when checking the 'exp' and 'iat' claims. Validate
	// uses 5 seconds when this is zero.
	TimestampLeeway time.Duration
}
27
// serviceAuthClaims is the claim set of a service auth JWT: the standard JWT
// registered claims (iss, aud, exp, iat, jti, ...) plus the 'lxm' claim.
type serviceAuthClaims struct {
	jwt.RegisteredClaims

	// Lexicon method (NSID string) the token is scoped to; omitted from the
	// encoded JSON when empty.
	LexMethod string `json:"lxm,omitempty"`
}
33
34func (s *ServiceAuthValidator) Validate(ctx context.Context, tokenString string, lexMethod *syntax.NSID) (syntax.DID, error) {
35
36 leeway := s.TimestampLeeway
37 if leeway == 0 {
38 leeway = 5 * time.Second
39 }
40
41 opts := []jwt.ParserOption{
42 jwt.WithValidMethods(supportedAlgs),
43 jwt.WithAudience(s.Audience),
44 jwt.WithExpirationRequired(),
45 jwt.WithIssuedAt(),
46 jwt.WithLeeway(leeway),
47 }
48
49 token, err := jwt.ParseWithClaims(tokenString, &serviceAuthClaims{}, s.fetchIssuerKeyFunc(ctx), opts...)
50 if err != nil && errors.Is(err, jwt.ErrTokenSignatureInvalid) {
51 // if signature validation fails, purge the directory and try again
52 // TODO: probably need to cache or rate-limit this?
53
54 // do an unvalidated extraction of 'iss' from JWT
55 insecure := jwt.NewParser(jwt.WithoutClaimsValidation())
56 t, _, err := insecure.ParseUnverified(tokenString, &jwt.MapClaims{})
57 claims, ok := t.Claims.(*jwt.MapClaims)
58 if !ok {
59 return "", jwt.ErrTokenInvalidClaims
60 }
61 iss, err := claims.GetIssuer()
62 if err != nil {
63 return "", err
64 }
65 did, err := syntax.ParseDID(iss)
66 if err != nil {
67 return "", fmt.Errorf("%w: invalid DID: %w", jwt.ErrTokenInvalidIssuer, err)
68 }
69
70 slog.Info("purging directory and retrying service auth signature validation", "did", did)
71 err = s.Dir.Purge(ctx, did.AtIdentifier())
72 if err != nil {
73 slog.Error("purging identity directory", "did", did, "err", err)
74 }
75 token, err = jwt.ParseWithClaims(tokenString, &serviceAuthClaims{}, s.fetchIssuerKeyFunc(ctx), opts...)
76 }
77 if err != nil {
78 return "", err
79 }
80 claims, ok := token.Claims.(*serviceAuthClaims)
81 if !ok {
82 // TODO: is the error message returned descriptive enough?
83 return "", jwt.ErrTokenInvalidClaims
84 }
85
86 if lexMethod != nil && claims.LexMethod != lexMethod.String() {
87 return "", fmt.Errorf("%w: Lexicon endpoint (LXM)", jwt.ErrTokenInvalidClaims)
88 }
89
90 // NOTE: KeyFunc has already parsed issuer, so we know it is a valid DID
91 did := syntax.DID(claims.Issuer)
92 return did, nil
93}
94
95// resolves public key from identity directory
96func (s *ServiceAuthValidator) fetchIssuerKeyFunc(ctx context.Context) func(token *jwt.Token) (any, error) {
97 return func(token *jwt.Token) (any, error) {
98 claims, ok := token.Claims.(*serviceAuthClaims)
99 if !ok {
100 return nil, jwt.ErrTokenInvalidClaims
101 }
102 iss, err := claims.GetIssuer()
103 if err != nil {
104 return nil, fmt.Errorf("%w: missing 'iss' claim", jwt.ErrTokenInvalidIssuer)
105 }
106 did, err := syntax.ParseDID(iss)
107 if err != nil {
108 return nil, fmt.Errorf("%w: invalid DID: %w", jwt.ErrTokenInvalidIssuer, err)
109 }
110 // NOTE: this will do handle resolution by default
111 ident, err := s.Dir.LookupDID(ctx, did)
112 if err != nil {
113 return nil, fmt.Errorf("%w: resolving DID (%s): %w", jwt.ErrTokenInvalidIssuer, did, err)
114 }
115 return ident.PublicKey()
116 }
117}
118
// randomNonce returns 16 bytes from crypto/rand encoded as unpadded URL-safe
// base64 (22 characters). SignServiceAuth uses it as the JWT 'jti' claim.
//
// It panics if the system random source fails. Continuing would emit an
// all-zero, predictable nonce, and such a failure is not recoverable (since
// Go 1.24, crypto/rand itself crashes the program in this case).
func randomNonce() string {
	buf := make([]byte, 16)
	if _, err := rand.Read(buf); err != nil {
		panic("auth: reading random nonce: " + err.Error())
	}
	return base64.RawURLEncoding.EncodeToString(buf)
}
124
125func SignServiceAuth(iss syntax.DID, aud string, ttl time.Duration, lexMethod *syntax.NSID, priv crypto.PrivateKey) (string, error) {
126 claims := serviceAuthClaims{
127 RegisteredClaims: jwt.RegisteredClaims{
128 ExpiresAt: jwt.NewNumericDate(time.Now().Add(ttl)),
129 IssuedAt: jwt.NewNumericDate(time.Now()),
130 Issuer: iss.String(),
131 Audience: []string{aud},
132 ID: randomNonce(),
133 },
134 }
135 if lexMethod != nil {
136 claims.LexMethod = lexMethod.String()
137 }
138
139 var sm *signingMethodAtproto
140
141 // NOTE: could also have a crypto.PrivateKey.Alg() method which returns a string
142 switch priv.(type) {
143 case *crypto.PrivateKeyP256:
144 sm = signingMethodES256
145 case *crypto.PrivateKeyK256:
146 sm = signingMethodES256K
147 default:
148 return "", fmt.Errorf("unknown signing key type: %T", priv)
149 }
150
151 token := jwt.NewWithClaims(sm, claims)
152 return token.SignedString(priv)
153}