forked from hailey.at/cocoon
An atproto PDS written in Go
1package server 2 3import ( 4 "crypto/sha256" 5 "encoding/base64" 6 "errors" 7 "fmt" 8 "strings" 9 "time" 10 11 "github.com/Azure/go-autorest/autorest/to" 12 "github.com/golang-jwt/jwt/v4" 13 "github.com/haileyok/cocoon/internal/helpers" 14 "github.com/haileyok/cocoon/models" 15 "github.com/haileyok/cocoon/oauth/dpop" 16 "github.com/haileyok/cocoon/oauth/provider" 17 "github.com/labstack/echo/v4" 18 "gitlab.com/yawning/secp256k1-voi" 19 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 20 "gorm.io/gorm" 21) 22 23func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 24 return func(e echo.Context) error { 25 username, password, ok := e.Request().BasicAuth() 26 if !ok || username != "admin" || password != s.config.AdminPassword { 27 return helpers.InputError(e, to.StringPtr("Unauthorized")) 28 } 29 30 if err := next(e); err != nil { 31 e.Error(err) 32 } 33 34 return nil 35 } 36} 37 38func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 39 return func(e echo.Context) error { 40 ctx := e.Request().Context() 41 logger := s.logger.With("name", "handleLegacySessionMiddleware") 42 43 authheader := e.Request().Header.Get("authorization") 44 if authheader == "" { 45 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 46 } 47 48 pts := strings.Split(authheader, " ") 49 if len(pts) != 2 { 50 return helpers.ServerError(e, nil) 51 } 52 53 // move on to oauth session middleware if this is a dpop token 54 if pts[0] == "DPoP" { 55 return next(e) 56 } 57 58 tokenstr := pts[1] 59 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 60 claims, ok := token.Claims.(jwt.MapClaims) 61 if !ok { 62 return helpers.InvalidTokenError(e) 63 } 64 65 var did string 66 var repo *models.RepoActor 67 68 // service auth tokens 69 lxm, hasLxm := claims["lxm"] 70 if hasLxm { 71 pts := strings.Split(e.Request().URL.String(), "/") 72 if lxm != pts[len(pts)-1] { 73 logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 74 return helpers.InputError(e, nil) 75 } 76 77 maybeDid, ok := claims["iss"].(string) 78 if !ok { 79 logger.Error("no iss in service auth token", "error", err) 80 return helpers.InputError(e, nil) 81 } 82 did = maybeDid 83 84 maybeRepo, err := s.getRepoActorByDid(ctx, did) 85 if err != nil { 86 logger.Error("error fetching repo", "error", err) 87 return helpers.ServerError(e, nil) 88 } 89 repo = maybeRepo 90 } 91 92 if token.Header["alg"] != "ES256K" { 93 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 94 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 95 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 96 } 97 return s.privateKey.Public(), nil 98 }) 99 if err != nil { 100 logger.Error("error parsing jwt", "error", err) 101 return helpers.ExpiredTokenError(e) 102 } 103 104 if !token.Valid { 105 return helpers.InvalidTokenError(e) 106 } 107 } else { 108 kpts := strings.Split(tokenstr, ".") 109 signingInput := kpts[0] + "." + kpts[1] 110 hash := sha256.Sum256([]byte(signingInput)) 111 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 112 if err != nil { 113 logger.Error("error decoding signature bytes", "error", err) 114 return helpers.ServerError(e, nil) 115 } 116 117 if len(sigBytes) != 64 { 118 logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 119 return helpers.ServerError(e, nil) 120 } 121 122 rBytes := sigBytes[:32] 123 sBytes := sigBytes[32:] 124 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 125 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 126 127 if repo == nil { 128 sub, ok := claims["sub"].(string) 129 if !ok { 130 s.logger.Error("no sub claim in ES256K token and repo not set") 131 return helpers.InvalidTokenError(e) 132 } 133 maybeRepo, err := s.getRepoActorByDid(ctx, sub) 134 if err != nil { 135 s.logger.Error("error fetching repo for ES256K verification", "error", err) 136 return helpers.ServerError(e, nil) 137 } 138 repo = maybeRepo 139 did = sub 140 } 141 142 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 143 if err != nil { 144 logger.Error("can't load private key", "error", err) 145 return err 146 } 147 148 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 149 if !ok { 150 logger.Error("error getting public key from sk") 151 return helpers.ServerError(e, nil) 152 } 153 154 verified := pubKey.VerifyRaw(hash[:], rr, ss) 155 if !verified { 156 logger.Error("error verifying", "error", err) 157 return helpers.ServerError(e, nil) 158 } 159 } 160 161 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 162 scope, _ := claims["scope"].(string) 163 164 if isRefresh && scope != "com.atproto.refresh" { 165 return helpers.InvalidTokenError(e) 166 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 167 return helpers.InvalidTokenError(e) 168 } 169 170 table := "tokens" 171 if isRefresh { 172 table = "refresh_tokens" 173 } 174 175 if isRefresh { 176 type Result struct { 177 Found bool 178 } 179 var result Result 180 if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 181 if err == gorm.ErrRecordNotFound { 182 return helpers.InvalidTokenError(e) 183 } 184 185 logger.Error("error getting token from db", "error", err) 186 return helpers.ServerError(e, nil) 187 } 188 189 if !result.Found { 190 return helpers.InvalidTokenError(e) 191 } 192 } 193 194 exp, ok := claims["exp"].(float64) 195 if !ok { 196 logger.Error("error getting iat from token") 197 return helpers.ServerError(e, nil) 198 } 199 200 if exp < float64(time.Now().UTC().Unix()) { 201 return helpers.ExpiredTokenError(e) 202 } 203 204 if repo == nil { 205 maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string)) 206 if err != nil { 207 logger.Error("error fetching repo", "error", err) 208 return helpers.ServerError(e, nil) 209 } 210 repo = maybeRepo 211 did = repo.Repo.Did 212 } 213 214 e.Set("repo", repo) 215 e.Set("did", did) 216 e.Set("token", tokenstr) 217 218 if err := next(e); err != nil { 219 return helpers.InvalidTokenError(e) 220 } 221 222 return nil 223 } 224} 225 226func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 227 return func(e echo.Context) error { 228 ctx := e.Request().Context() 229 logger := s.logger.With("name", "handleOauthSessionMiddleware") 230 231 authheader := e.Request().Header.Get("authorization") 232 if authheader == "" { 233 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 234 } 235 236 pts := strings.Split(authheader, " ") 237 if len(pts) != 2 { 238 return helpers.ServerError(e, nil) 239 } 240 241 if pts[0] != "DPoP" { 242 return next(e) 243 } 244 245 accessToken := pts[1] 246 247 nonce := s.oauthProvider.NextNonce() 248 if nonce != "" { 249 e.Response().Header().Set("DPoP-Nonce", nonce) 250 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 251 } 252 253 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 254 if err != nil { 255 if errors.Is(err, dpop.ErrUseDpopNonce) { 256 e.Response().Header().Set("WWW-Authenticate", `DPoP error="use_dpop_nonce"`) 257 e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate") 258 return e.JSON(401, map[string]string{ 259 "error": "use_dpop_nonce", 260 }) 261 } 262 logger.Error("invalid dpop proof", "error", err) 263 return helpers.InputError(e, nil) 264 } 265 266 var oauthToken provider.OauthToken 267 if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 268 logger.Error("error finding access token in db", "error", err) 269 return helpers.InputError(e, nil) 270 } 271 272 if oauthToken.Token == "" { 273 return helpers.InvalidTokenError(e) 274 } 275 276 if *oauthToken.Parameters.DpopJkt != proof.JKT { 277 logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 278 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 279 } 280 281 if time.Now().After(oauthToken.ExpiresAt) { 282 e.Response().Header().Set("WWW-Authenticate", `DPoP error="invalid_token", error_description="Token expired"`) 283 e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate") 284 return e.JSON(401, map[string]string{ 285 "error": "invalid_token", 286 "error_description": "Token expired", 287 }) 288 } 289 290 repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub) 291 if err != nil { 292 logger.Error("could not find actor in db", "error", err) 293 return helpers.ServerError(e, nil) 294 } 295 296 e.Set("repo", repo) 297 e.Set("did", repo.Repo.Did) 298 e.Set("token", accessToken) 299 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 300 301 return next(e) 302 } 303}