forked from hailey.at/cocoon
An atproto PDS written in Go

fix service auth check (#17)

authored by hailey.at and committed by GitHub eb5580e9 5967b43b

Changed files
+111 -31
server
+111 -31
server/server.go
··· 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "embed" 8 "errors" 9 "fmt" 10 "io" ··· 45 "github.com/labstack/echo/v4" 46 "github.com/labstack/echo/v4/middleware" 47 slogecho "github.com/samber/slog-echo" 48 "gorm.io/driver/sqlite" 49 "gorm.io/gorm" 50 ) ··· 220 return helpers.ServerError(e, nil) 221 } 222 223 if pts[0] == "DPoP" { 224 return next(e) 225 } 226 227 tokenstr := pts[1] 228 229 - token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 230 - if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 231 - return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 232 } 233 234 - return s.privateKey.Public(), nil 235 - }) 236 - if err != nil { 237 - s.logger.Error("error parsing jwt", "error", err) 238 - // NOTE: https://github.com/bluesky-social/atproto/discussions/3319 239 - return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 240 } 241 242 - claims, ok := token.Claims.(jwt.MapClaims) 243 - if !ok || !token.Valid { 244 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 245 } 246 247 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 248 - scope := claims["scope"].(string) 249 250 if isRefresh && scope != "com.atproto.refresh" { 251 return helpers.InputError(e, to.StringPtr("InvalidToken")) 252 - } else if !isRefresh && scope != "com.atproto.access" { 253 return helpers.InputError(e, to.StringPtr("InvalidToken")) 254 } 255 ··· 258 table = "refresh_tokens" 259 } 260 261 - type Result struct { 262 - Found bool 263 - } 264 - var result Result 265 - if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 266 - if err == gorm.ErrRecordNotFound { 267 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 268 } 269 270 - s.logger.Error("error getting token from db", "error", err) 271 - return helpers.ServerError(e, nil) 272 - } 273 274 - if !result.Found { 275 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 276 } 277 278 exp, ok := claims["exp"].(float64) ··· 285 return helpers.InputError(e, to.StringPtr("ExpiredToken")) 286 } 287 288 - repo, err := s.getRepoActorByDid(claims["sub"].(string)) 289 - if err != nil { 290 - s.logger.Error("error fetching repo", "error", err) 291 - return helpers.ServerError(e, nil) 292 } 293 294 e.Set("repo", repo) 295 - e.Set("did", claims["sub"]) 296 e.Set("token", tokenstr) 297 298 if err := next(e); err != nil {
··· 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 + "crypto/sha256" 8 "embed" 9 + "encoding/base64" 10 "errors" 11 "fmt" 12 "io" ··· 47 "github.com/labstack/echo/v4" 48 "github.com/labstack/echo/v4/middleware" 49 slogecho "github.com/samber/slog-echo" 50 + "gitlab.com/yawning/secp256k1-voi" 51 + secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 52 "gorm.io/driver/sqlite" 53 "gorm.io/gorm" 54 ) ··· 224 return helpers.ServerError(e, nil) 225 } 226 227 + // move on to oauth session middleware if this is a dpop token 228 if pts[0] == "DPoP" { 229 return next(e) 230 } 231 232 tokenstr := pts[1] 233 + token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 234 + claims, ok := token.Claims.(jwt.MapClaims) 235 + if !ok { 236 + return helpers.InputError(e, to.StringPtr("InvalidToken")) 237 + } 238 + 239 + var did string 240 + var repo *models.RepoActor 241 242 + // service auth tokens 243 + lxm, hasLxm := claims["lxm"] 244 + if hasLxm { 245 + pts := strings.Split(e.Request().URL.String(), "/") 246 + if lxm != pts[len(pts)-1] { 247 + s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 248 + return helpers.InputError(e, nil) 249 } 250 251 + maybeDid, ok := claims["iss"].(string) 252 + if !ok { 253 + s.logger.Error("no iss in service auth token", "error", err) 254 + return helpers.InputError(e, nil) 255 + } 256 + did = maybeDid 257 + 258 + maybeRepo, err := s.getRepoActorByDid(did) 259 + if err != nil { 260 + s.logger.Error("error fetching repo", "error", err) 261 + return helpers.ServerError(e, nil) 262 + } 263 + repo = maybeRepo 264 } 265 266 + if token.Header["alg"] != "ES256K" { 267 + token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 268 + if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 269 + return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 270 + } 271 + return s.privateKey.Public(), nil 272 + }) 273 + if err != nil { 274 + s.logger.Error("error parsing jwt", "error", err) 275 + // NOTE: https://github.com/bluesky-social/atproto/discussions/3319 276 + return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 277 + } 278 + 279 + if !token.Valid { 280 + return helpers.InputError(e, to.StringPtr("InvalidToken")) 281 + } 282 + } else { 283 + kpts := strings.Split(tokenstr, ".") 284 + signingInput := kpts[0] + "." + kpts[1] 285 + hash := sha256.Sum256([]byte(signingInput)) 286 + sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 287 + if err != nil { 288 + s.logger.Error("error decoding signature bytes", "error", err) 289 + return helpers.ServerError(e, nil) 290 + } 291 + 292 + if len(sigBytes) != 64 { 293 + s.logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 294 + return helpers.ServerError(e, nil) 295 + } 296 + 297 + rBytes := sigBytes[:32] 298 + sBytes := sigBytes[32:] 299 + rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 300 + ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 301 + 302 + sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 303 + if err != nil { 304 + s.logger.Error("can't load private key", "error", err) 305 + return err 306 + } 307 + 308 + pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 309 + if !ok { 310 + s.logger.Error("error getting public key from sk") 311 + return helpers.ServerError(e, nil) 312 + } 313 + 314 + verified := pubKey.VerifyRaw(hash[:], rr, ss) 315 + if !verified { 316 + s.logger.Error("error verifying", "error", err) 317 + return helpers.ServerError(e, nil) 318 + } 319 } 320 321 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 322 + scope, _ := claims["scope"].(string) 323 324 if isRefresh && scope != "com.atproto.refresh" { 325 return helpers.InputError(e, to.StringPtr("InvalidToken")) 326 + } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 327 return helpers.InputError(e, to.StringPtr("InvalidToken")) 328 } 329 ··· 332 table = "refresh_tokens" 333 } 334 335 + if isRefresh { 336 + type Result struct { 337 + Found bool 338 } 339 + var result Result 340 + if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 341 + if err == gorm.ErrRecordNotFound { 342 + return helpers.InputError(e, to.StringPtr("InvalidToken")) 343 + } 344 345 + s.logger.Error("error getting token from db", "error", err) 346 + return helpers.ServerError(e, nil) 347 + } 348 349 + if !result.Found { 350 + return helpers.InputError(e, to.StringPtr("InvalidToken")) 351 + } 352 } 353 354 exp, ok := claims["exp"].(float64) ··· 361 return helpers.InputError(e, to.StringPtr("ExpiredToken")) 362 } 363 364 + if repo == nil { 365 + maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) 366 + if err != nil { 367 + s.logger.Error("error fetching repo", "error", err) 368 + return helpers.ServerError(e, nil) 369 + } 370 + repo = maybeRepo 371 + did = repo.Repo.Did 372 } 373 374 e.Set("repo", repo) 375 + e.Set("did", did) 376 e.Set("token", tokenstr) 377 378 if err := next(e); err != nil {