A collection of Custom Bluesky Feeds, including Fresh Feeds, all under one roof
at main 6.2 kB view raw
1package auth 2 3import ( 4 "context" 5 "fmt" 6 "net/http" 7 "strings" 8 "time" 9 10 "github.com/bluesky-social/indigo/atproto/identity" 11 "github.com/bluesky-social/indigo/atproto/syntax" 12 es256k "github.com/ericvolp12/jwt-go-secp256k1" 13 "github.com/gin-gonic/gin" 14 "github.com/golang-jwt/jwt" 15 lru "github.com/hashicorp/golang-lru/arc/v2" 16 "github.com/prometheus/client_golang/prometheus" 17 "github.com/prometheus/client_golang/prometheus/promauto" 18 "gitlab.com/yawning/secp256k1-voi/secec" 19 "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" 20 "go.opentelemetry.io/otel" 21 "go.opentelemetry.io/otel/attribute" 22 "golang.org/x/time/rate" 23) 24 25type KeyCacheEntry struct { 26 UserDID string 27 Key any 28 ExpiresAt time.Time 29} 30 31// Initialize Prometheus Metrics for cache hits and misses 32var cacheHits = promauto.NewCounterVec(prometheus.CounterOpts{ 33 Name: "feedgen_auth_cache_hits_total", 34 Help: "The total number of cache hits", 35}, []string{"cache_type"}) 36 37var cacheMisses = promauto.NewCounterVec(prometheus.CounterOpts{ 38 Name: "feedgen_auth_cache_misses_total", 39 Help: "The total number of cache misses", 40}, []string{"cache_type"}) 41 42var cacheSize = promauto.NewGaugeVec(prometheus.GaugeOpts{ 43 Name: "feedgen_auth_cache_size_bytes", 44 Help: "The size of the cache in bytes", 45}, []string{"cache_type"}) 46 47type Auth struct { 48 KeyCache *lru.ARCCache[string, KeyCacheEntry] 49 KeyCacheTTL time.Duration 50 ServiceDID string 51 Dir *identity.CacheDirectory 52} 53 54// NewAuth creates a new Auth instance with the given key cache size and TTL 55// The PLC Directory URL is also required, as well as the DID of the service 56// for JWT audience validation 57// The key cache is used to cache the public keys of users for a given TTL 58// The PLC Directory URL is used to fetch the public keys of users 59// The service DID is used to validate the audience of JWTs 60// The HTTP client is used to make requests to the PLC Directory 61// A rate limiter is used to limit the number of requests to the PLC Directory 62func NewAuth( 63 keyCacheSize int, 64 keyCacheTTL time.Duration, 65 requestsPerSecond int, 66 serviceDID string, 67) (*Auth, error) { 68 keyCache, err := lru.NewARC[string, KeyCacheEntry](keyCacheSize) 69 if err != nil { 70 return nil, fmt.Errorf("Failed to create key cache: %v", err) 71 } 72 73 // Initialize the HTTP client with OpenTelemetry instrumentation 74 client := http.Client{ 75 Transport: otelhttp.NewTransport(http.DefaultTransport), 76 } 77 78 baseDir := identity.BaseDirectory{ 79 PLCURL: identity.DefaultPLCURL, 80 PLCLimiter: rate.NewLimiter(rate.Limit(float64(requestsPerSecond)), 1), 81 HTTPClient: client, 82 TryAuthoritativeDNS: true, 83 // primary Bluesky PDS instance only supports HTTP resolution method 84 SkipDNSDomainSuffixes: []string{".bsky.social"}, 85 } 86 dir := identity.NewCacheDirectory(&baseDir, keyCacheSize, keyCacheTTL, time.Minute*2, keyCacheTTL) 87 88 return &Auth{ 89 KeyCache: keyCache, 90 KeyCacheTTL: keyCacheTTL, 91 ServiceDID: serviceDID, 92 Dir: &dir, 93 }, nil 94} 95 96func (auth *Auth) GetClaimsFromAuthHeader(ctx context.Context, authHeader string, claims jwt.Claims) error { 97 tracer := otel.Tracer("auth") 98 ctx, span := tracer.Start(ctx, "Auth:GetClaimsFromAuthHeader") 99 defer span.End() 100 101 if authHeader == "" { 102 span.End() 103 return fmt.Errorf("No Authorization header provided") 104 } 105 106 authHeaderParts := strings.Split(authHeader, " ") 107 if len(authHeaderParts) != 2 { 108 return fmt.Errorf("Invalid Authorization header") 109 } 110 111 if authHeaderParts[0] != "Bearer" { 112 return fmt.Errorf("Invalid Authorization header (expected Bearer)") 113 } 114 115 accessToken := authHeaderParts[1] 116 117 parser := jwt.Parser{ 118 ValidMethods: []string{es256k.SigningMethodES256K.Alg()}, 119 } 120 121 token, err := parser.ParseWithClaims(accessToken, claims, func(token *jwt.Token) (interface{}, error) { 122 if claims, ok := token.Claims.(*jwt.StandardClaims); ok { 123 // Get the user's key from PLC Directory 124 userDID := claims.Issuer 125 entry, ok := auth.KeyCache.Get(userDID) 126 if ok && entry.ExpiresAt.After(time.Now()) { 127 cacheHits.WithLabelValues("key").Inc() 128 span.SetAttributes(attribute.Bool("caches.keys.hit", true)) 129 return entry.Key, nil 130 } 131 132 cacheMisses.WithLabelValues("key").Inc() 133 span.SetAttributes(attribute.Bool("caches.keys.hit", false)) 134 135 did, err := syntax.ParseDID(userDID) 136 if err != nil { 137 return nil, fmt.Errorf("Failed to parse user DID: %v", err) 138 } 139 140 // Get the user's key from PLC Directory 141 id, err := auth.Dir.LookupDID(ctx, did) 142 if err != nil { 143 return nil, fmt.Errorf("Failed to lookup user DID: %v", err) 144 } 145 146 key, err := id.GetPublicKey("atproto") 147 if err != nil { 148 return nil, fmt.Errorf("Failed to get user public key: %v", err) 149 } 150 151 parsedPubkey, err := secec.NewPublicKey(key.UncompressedBytes()) 152 if err != nil { 153 return nil, fmt.Errorf("Failed to parse user public key: %v", err) 154 } 155 156 // Add the ECDSA key to the cache 157 auth.KeyCache.Add(userDID, KeyCacheEntry{ 158 Key: parsedPubkey, 159 ExpiresAt: time.Now().Add(auth.KeyCacheTTL), 160 }) 161 162 return parsedPubkey, nil 163 } 164 165 return nil, fmt.Errorf("Invalid authorization token (failed to parse claims)") 166 }) 167 168 if err != nil { 169 return fmt.Errorf("Failed to parse authorization token: %v", err) 170 } 171 172 if !token.Valid { 173 return fmt.Errorf("Invalid authorization token") 174 } 175 176 return nil 177} 178 179func (auth *Auth) AuthenticateGinRequestViaJWT(c *gin.Context) { 180 tracer := otel.Tracer("auth") 181 ctx, span := tracer.Start(c.Request.Context(), "Auth:AuthenticateGinRequestViaJWT") 182 183 authHeader := c.GetHeader("Authorization") 184 if authHeader == "" { 185 span.End() 186 c.Next() 187 return 188 } 189 190 claims := jwt.StandardClaims{} 191 192 err := auth.GetClaimsFromAuthHeader(ctx, authHeader, &claims) 193 if err != nil { 194 c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Errorf("Failed to get claims from auth header: %v", err).Error()}) 195 span.End() 196 c.Abort() 197 return 198 } 199 200 if claims.Audience != auth.ServiceDID { 201 c.JSON(http.StatusUnauthorized, gin.H{"error": fmt.Sprintf("Invalid audience (expected %s)", auth.ServiceDID)}) 202 c.Abort() 203 return 204 } 205 206 // Set claims Issuer to context as user DID 207 c.Set("user_did", claims.Issuer) 208 span.SetAttributes(attribute.String("user.did", claims.Issuer)) 209 span.End() 210 c.Next() 211}