forked from hailey.at/cocoon
An atproto PDS written in Go

fix service auth check (#17)

authored by hailey.at and committed by GitHub eb5580e9 5967b43b

Changed files
+111 -31
server
+111 -31
server/server.go
··· 4 4 "bytes" 5 5 "context" 6 6 "crypto/ecdsa" 7 + "crypto/sha256" 7 8 "embed" 9 + "encoding/base64" 8 10 "errors" 9 11 "fmt" 10 12 "io" ··· 45 47 "github.com/labstack/echo/v4" 46 48 "github.com/labstack/echo/v4/middleware" 47 49 slogecho "github.com/samber/slog-echo" 50 + "gitlab.com/yawning/secp256k1-voi" 51 + secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 48 52 "gorm.io/driver/sqlite" 49 53 "gorm.io/gorm" 50 54 ) ··· 220 224 return helpers.ServerError(e, nil) 221 225 } 222 226 227 + // move on to oauth session middleware if this is a dpop token 223 228 if pts[0] == "DPoP" { 224 229 return next(e) 225 230 } 226 231 227 232 tokenstr := pts[1] 233 + token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 234 + claims, ok := token.Claims.(jwt.MapClaims) 235 + if !ok { 236 + return helpers.InputError(e, to.StringPtr("InvalidToken")) 237 + } 238 + 239 + var did string 240 + var repo *models.RepoActor 228 241 229 - token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 230 - if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 231 - return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 242 + // service auth tokens 243 + lxm, hasLxm := claims["lxm"] 244 + if hasLxm { 245 + pts := strings.Split(e.Request().URL.String(), "/") 246 + if lxm != pts[len(pts)-1] { 247 + s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 248 + return helpers.InputError(e, nil) 232 249 } 233 250 234 - return s.privateKey.Public(), nil 235 - }) 236 - if err != nil { 237 - s.logger.Error("error parsing jwt", "error", err) 238 - // NOTE: https://github.com/bluesky-social/atproto/discussions/3319 239 - return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 251 + maybeDid, ok := claims["iss"].(string) 252 + if !ok { 253 + s.logger.Error("no iss in service auth token", "error", err) 254 + return helpers.InputError(e, nil) 255 + } 256 + did = maybeDid 257 + 258 + maybeRepo, err := s.getRepoActorByDid(did) 259 + if err != nil { 260 + s.logger.Error("error fetching repo", "error", err) 261 + return helpers.ServerError(e, nil) 262 + } 263 + repo = maybeRepo 240 264 } 241 265 242 - claims, ok := token.Claims.(jwt.MapClaims) 243 - if !ok || !token.Valid { 244 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 266 + if token.Header["alg"] != "ES256K" { 267 + token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 268 + if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 269 + return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 270 + } 271 + return s.privateKey.Public(), nil 272 + }) 273 + if err != nil { 274 + s.logger.Error("error parsing jwt", "error", err) 275 + // NOTE: https://github.com/bluesky-social/atproto/discussions/3319 276 + return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"}) 277 + } 278 + 279 + if !token.Valid { 280 + return helpers.InputError(e, to.StringPtr("InvalidToken")) 281 + } 282 + } else { 283 + kpts := strings.Split(tokenstr, ".") 284 + signingInput := kpts[0] + "." + kpts[1] 285 + hash := sha256.Sum256([]byte(signingInput)) 286 + sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 287 + if err != nil { 288 + s.logger.Error("error decoding signature bytes", "error", err) 289 + return helpers.ServerError(e, nil) 290 + } 291 + 292 + if len(sigBytes) != 64 { 293 + s.logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 294 + return helpers.ServerError(e, nil) 295 + } 296 + 297 + rBytes := sigBytes[:32] 298 + sBytes := sigBytes[32:] 299 + rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 300 + ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 301 + 302 + sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 303 + if err != nil { 304 + s.logger.Error("can't load private key", "error", err) 305 + return err 306 + } 307 + 308 + pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 309 + if !ok { 310 + s.logger.Error("error getting public key from sk") 311 + return helpers.ServerError(e, nil) 312 + } 313 + 314 + verified := pubKey.VerifyRaw(hash[:], rr, ss) 315 + if !verified { 316 + s.logger.Error("error verifying", "error", err) 317 + return helpers.ServerError(e, nil) 318 + } 245 319 } 246 320 247 321 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 248 - scope := claims["scope"].(string) 322 + scope, _ := claims["scope"].(string) 249 323 250 324 if isRefresh && scope != "com.atproto.refresh" { 251 325 return helpers.InputError(e, to.StringPtr("InvalidToken")) 252 - } else if !isRefresh && scope != "com.atproto.access" { 326 + } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 253 327 return helpers.InputError(e, to.StringPtr("InvalidToken")) 254 328 } 255 329 ··· 258 332 table = "refresh_tokens" 259 333 } 260 334 261 - type Result struct { 262 - Found bool 263 - } 264 - var result Result 265 - if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 266 - if err == gorm.ErrRecordNotFound { 267 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 335 + if isRefresh { 336 + type Result struct { 337 + Found bool 268 338 } 339 + var result Result 340 + if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 341 + if err == gorm.ErrRecordNotFound { 342 + return helpers.InputError(e, to.StringPtr("InvalidToken")) 343 + } 269 344 270 - s.logger.Error("error getting token from db", "error", err) 271 - return helpers.ServerError(e, nil) 272 - } 345 + s.logger.Error("error getting token from db", "error", err) 346 + return helpers.ServerError(e, nil) 347 + } 273 348 274 - if !result.Found { 275 - return helpers.InputError(e, to.StringPtr("InvalidToken")) 349 + if !result.Found { 350 + return helpers.InputError(e, to.StringPtr("InvalidToken")) 351 + } 276 352 } 277 353 278 354 exp, ok := claims["exp"].(float64) ··· 285 361 return helpers.InputError(e, to.StringPtr("ExpiredToken")) 286 362 } 287 363 288 - repo, err := s.getRepoActorByDid(claims["sub"].(string)) 289 - if err != nil { 290 - s.logger.Error("error fetching repo", "error", err) 291 - return helpers.ServerError(e, nil) 364 + if repo == nil { 365 + maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) 366 + if err != nil { 367 + s.logger.Error("error fetching repo", "error", err) 368 + return helpers.ServerError(e, nil) 369 + } 370 + repo = maybeRepo 371 + did = repo.Repo.Did 292 372 } 293 373 294 374 e.Set("repo", repo) 295 - e.Set("did", claims["sub"]) 375 + e.Set("did", did) 296 376 e.Set("token", tokenstr) 297 377 298 378 if err := next(e); err != nil {