A collection of Custom Bluesky Feeds, including Fresh Feeds, all under one roof
1package auth
2
3import (
4 "context"
5 "fmt"
6 "net/http"
7 "strings"
8 "time"
9
10 "github.com/bluesky-social/indigo/atproto/identity"
11 "github.com/bluesky-social/indigo/atproto/syntax"
12 es256k "github.com/ericvolp12/jwt-go-secp256k1"
13 "github.com/gin-gonic/gin"
14 "github.com/golang-jwt/jwt"
15 lru "github.com/hashicorp/golang-lru/arc/v2"
16 "github.com/prometheus/client_golang/prometheus"
17 "github.com/prometheus/client_golang/prometheus/promauto"
18 "gitlab.com/yawning/secp256k1-voi/secec"
19 "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
20 "go.opentelemetry.io/otel"
21 "go.opentelemetry.io/otel/attribute"
22 "golang.org/x/time/rate"
23)
24
// KeyCacheEntry is one cached verification key, stored in Auth.KeyCache
// under the DID of the user it belongs to.
type KeyCacheEntry struct {
	// UserDID is the DID that owns Key.
	UserDID string
	// Key is the parsed public key returned to the JWT parser.
	Key any
	// ExpiresAt is the instant after which the entry is no longer trusted.
	ExpiresAt time.Time
}
30
31// Initialize Prometheus Metrics for cache hits and misses
32var cacheHits = promauto.NewCounterVec(prometheus.CounterOpts{
33 Name: "feedgen_auth_cache_hits_total",
34 Help: "The total number of cache hits",
35}, []string{"cache_type"})
36
37var cacheMisses = promauto.NewCounterVec(prometheus.CounterOpts{
38 Name: "feedgen_auth_cache_misses_total",
39 Help: "The total number of cache misses",
40}, []string{"cache_type"})
41
42var cacheSize = promauto.NewGaugeVec(prometheus.GaugeOpts{
43 Name: "feedgen_auth_cache_size_bytes",
44 Help: "The size of the cache in bytes",
45}, []string{"cache_type"})
46
47type Auth struct {
48 KeyCache *lru.ARCCache[string, KeyCacheEntry]
49 KeyCacheTTL time.Duration
50 ServiceDID string
51 Dir *identity.CacheDirectory
52}
53
54// NewAuth creates a new Auth instance with the given key cache size and TTL
55// The PLC Directory URL is also required, as well as the DID of the service
56// for JWT audience validation
57// The key cache is used to cache the public keys of users for a given TTL
58// The PLC Directory URL is used to fetch the public keys of users
59// The service DID is used to validate the audience of JWTs
60// The HTTP client is used to make requests to the PLC Directory
61// A rate limiter is used to limit the number of requests to the PLC Directory
62func NewAuth(
63 keyCacheSize int,
64 keyCacheTTL time.Duration,
65 requestsPerSecond int,
66 serviceDID string,
67) (*Auth, error) {
68 keyCache, err := lru.NewARC[string, KeyCacheEntry](keyCacheSize)
69 if err != nil {
70 return nil, fmt.Errorf("Failed to create key cache: %v", err)
71 }
72
73 // Initialize the HTTP client with OpenTelemetry instrumentation
74 client := http.Client{
75 Transport: otelhttp.NewTransport(http.DefaultTransport),
76 }
77
78 baseDir := identity.BaseDirectory{
79 PLCURL: identity.DefaultPLCURL,
80 PLCLimiter: rate.NewLimiter(rate.Limit(float64(requestsPerSecond)), 1),
81 HTTPClient: client,
82 TryAuthoritativeDNS: true,
83 // primary Bluesky PDS instance only supports HTTP resolution method
84 SkipDNSDomainSuffixes: []string{".bsky.social"},
85 }
86 dir := identity.NewCacheDirectory(&baseDir, keyCacheSize, keyCacheTTL, time.Minute*2, keyCacheTTL)
87
88 return &Auth{
89 KeyCache: keyCache,
90 KeyCacheTTL: keyCacheTTL,
91 ServiceDID: serviceDID,
92 Dir: &dir,
93 }, nil
94}
95
96func (auth *Auth) GetClaimsFromAuthHeader(ctx context.Context, authHeader string, claims jwt.Claims) error {
97 tracer := otel.Tracer("auth")
98 ctx, span := tracer.Start(ctx, "Auth:GetClaimsFromAuthHeader")
99 defer span.End()
100
101 if authHeader == "" {
102 span.End()
103 return fmt.Errorf("No Authorization header provided")
104 }
105
106 authHeaderParts := strings.Split(authHeader, " ")
107 if len(authHeaderParts) != 2 {
108 return fmt.Errorf("Invalid Authorization header")
109 }
110
111 if authHeaderParts[0] != "Bearer" {
112 return fmt.Errorf("Invalid Authorization header (expected Bearer)")
113 }
114
115 accessToken := authHeaderParts[1]
116
117 parser := jwt.Parser{
118 ValidMethods: []string{es256k.SigningMethodES256K.Alg()},
119 }
120
121 token, err := parser.ParseWithClaims(accessToken, claims, func(token *jwt.Token) (interface{}, error) {
122 if claims, ok := token.Claims.(*jwt.StandardClaims); ok {
123 // Get the user's key from PLC Directory
124 userDID := claims.Issuer
125 entry, ok := auth.KeyCache.Get(userDID)
126 if ok && entry.ExpiresAt.After(time.Now()) {
127 cacheHits.WithLabelValues("key").Inc()
128 span.SetAttributes(attribute.Bool("caches.keys.hit", true))
129 return entry.Key, nil
130 }
131
132 cacheMisses.WithLabelValues("key").Inc()
133 span.SetAttributes(attribute.Bool("caches.keys.hit", false))
134
135 did, err := syntax.ParseDID(userDID)
136 if err != nil {
137 return nil, fmt.Errorf("Failed to parse user DID: %v", err)
138 }
139
140 // Get the user's key from PLC Directory
141 id, err := auth.Dir.LookupDID(ctx, did)
142 if err != nil {
143 return nil, fmt.Errorf("Failed to lookup user DID: %v", err)
144 }
145
146 key, err := id.GetPublicKey("atproto")
147 if err != nil {
148 return nil, fmt.Errorf("Failed to get user public key: %v", err)
149 }
150
151 parsedPubkey, err := secec.NewPublicKey(key.UncompressedBytes())
152 if err != nil {
153 return nil, fmt.Errorf("Failed to parse user public key: %v", err)
154 }
155
156 // Add the ECDSA key to the cache
157 auth.KeyCache.Add(userDID, KeyCacheEntry{
158 Key: parsedPubkey,
159 ExpiresAt: time.Now().Add(auth.KeyCacheTTL),
160 })
161
162 return parsedPubkey, nil
163 }
164
165 return nil, fmt.Errorf("Invalid authorization token (failed to parse claims)")
166 })
167
168 if err != nil {
169 return fmt.Errorf("Failed to parse authorization token: %v", err)
170 }
171
172 if !token.Valid {
173 return fmt.Errorf("Invalid authorization token")
174 }
175
176 return nil
177}
178
179func (auth *Auth) AuthenticateGinRequestViaJWT(c *gin.Context) {
180 tracer := otel.Tracer("auth")
181 ctx, span := tracer.Start(c.Request.Context(), "Auth:AuthenticateGinRequestViaJWT")
182
183 authHeader := c.GetHeader("Authorization")
184 if authHeader == "" {
185 span.End()
186 c.Next()
187 return
188 }
189
190 claims := jwt.StandardClaims{}
191
192 err := auth.GetClaimsFromAuthHeader(ctx, authHeader, &claims)
193 if err != nil {
194 c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Errorf("Failed to get claims from auth header: %v", err).Error()})
195 span.End()
196 c.Abort()
197 return
198 }
199
200 if claims.Audience != auth.ServiceDID {
201 c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("Invalid audience (expected %s)", auth.ServiceDID)})
202 c.Abort()
203 return
204 }
205
206 // Set claims Issuer to context as user DID
207 c.Set("user_did", claims.Issuer)
208 span.SetAttributes(attribute.String("user.did", claims.Issuer))
209 span.End()
210 c.Next()
211}