An atproto PDS written in Go
103
fork

Configure Feed

Select the types of activity you want to include in your feed.

at 0.3.6 268 lines 7.4 kB view raw
1package server 2 3import ( 4 "crypto/sha256" 5 "encoding/base64" 6 "fmt" 7 "strings" 8 "time" 9 10 "github.com/Azure/go-autorest/autorest/to" 11 "github.com/golang-jwt/jwt/v4" 12 "github.com/haileyok/cocoon/internal/helpers" 13 "github.com/haileyok/cocoon/models" 14 "github.com/haileyok/cocoon/oauth/provider" 15 "github.com/labstack/echo/v4" 16 "gitlab.com/yawning/secp256k1-voi" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18 "gorm.io/gorm" 19) 20 21func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 22 return func(e echo.Context) error { 23 username, password, ok := e.Request().BasicAuth() 24 if !ok || username != "admin" || password != s.config.AdminPassword { 25 return helpers.InputError(e, to.StringPtr("Unauthorized")) 26 } 27 28 if err := next(e); err != nil { 29 e.Error(err) 30 } 31 32 return nil 33 } 34} 35 36func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 37 return func(e echo.Context) error { 38 authheader := e.Request().Header.Get("authorization") 39 if authheader == "" { 40 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 41 } 42 43 pts := strings.Split(authheader, " ") 44 if len(pts) != 2 { 45 return helpers.ServerError(e, nil) 46 } 47 48 // move on to oauth session middleware if this is a dpop token 49 if pts[0] == "DPoP" { 50 return next(e) 51 } 52 53 tokenstr := pts[1] 54 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{}) 55 claims, ok := token.Claims.(jwt.MapClaims) 56 if !ok { 57 return helpers.InvalidTokenError(e) 58 } 59 60 var did string 61 var repo *models.RepoActor 62 63 // service auth tokens 64 lxm, hasLxm := claims["lxm"] 65 if hasLxm { 66 pts := strings.Split(e.Request().URL.String(), "/") 67 if lxm != pts[len(pts)-1] { 68 s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err) 69 return helpers.InputError(e, nil) 70 } 71 72 maybeDid, ok := claims["iss"].(string) 73 if !ok { 74 s.logger.Error("no iss in service auth token", "error", err) 75 return helpers.InputError(e, nil) 76 } 77 did = maybeDid 78 79 maybeRepo, err := s.getRepoActorByDid(did) 80 if err != nil { 81 s.logger.Error("error fetching repo", "error", err) 82 return helpers.ServerError(e, nil) 83 } 84 repo = maybeRepo 85 } 86 87 if token.Header["alg"] != "ES256K" { 88 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) { 89 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok { 90 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"]) 91 } 92 return s.privateKey.Public(), nil 93 }) 94 if err != nil { 95 s.logger.Error("error parsing jwt", "error", err) 96 return helpers.ExpiredTokenError(e) 97 } 98 99 if !token.Valid { 100 return helpers.InvalidTokenError(e) 101 } 102 } else { 103 kpts := strings.Split(tokenstr, ".") 104 signingInput := kpts[0] + "." + kpts[1] 105 hash := sha256.Sum256([]byte(signingInput)) 106 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2]) 107 if err != nil { 108 s.logger.Error("error decoding signature bytes", "error", err) 109 return helpers.ServerError(e, nil) 110 } 111 112 if len(sigBytes) != 64 { 113 s.logger.Error("incorrect sigbytes length", "length", len(sigBytes)) 114 return helpers.ServerError(e, nil) 115 } 116 117 rBytes := sigBytes[:32] 118 sBytes := sigBytes[32:] 119 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes)) 120 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes)) 121 122 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 123 if err != nil { 124 s.logger.Error("can't load private key", "error", err) 125 return err 126 } 127 128 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey) 129 if !ok { 130 s.logger.Error("error getting public key from sk") 131 return helpers.ServerError(e, nil) 132 } 133 134 verified := pubKey.VerifyRaw(hash[:], rr, ss) 135 if !verified { 136 s.logger.Error("error verifying", "error", err) 137 return helpers.ServerError(e, nil) 138 } 139 } 140 141 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession" 142 scope, _ := claims["scope"].(string) 143 144 if isRefresh && scope != "com.atproto.refresh" { 145 return helpers.InvalidTokenError(e) 146 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" { 147 return helpers.InvalidTokenError(e) 148 } 149 150 table := "tokens" 151 if isRefresh { 152 table = "refresh_tokens" 153 } 154 155 if isRefresh { 156 type Result struct { 157 Found bool 158 } 159 var result Result 160 if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 161 if err == gorm.ErrRecordNotFound { 162 return helpers.InvalidTokenError(e) 163 } 164 165 s.logger.Error("error getting token from db", "error", err) 166 return helpers.ServerError(e, nil) 167 } 168 169 if !result.Found { 170 return helpers.InvalidTokenError(e) 171 } 172 } 173 174 exp, ok := claims["exp"].(float64) 175 if !ok { 176 s.logger.Error("error getting iat from token") 177 return helpers.ServerError(e, nil) 178 } 179 180 if exp < float64(time.Now().UTC().Unix()) { 181 return helpers.ExpiredTokenError(e) 182 } 183 184 if repo == nil { 185 maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) 186 if err != nil { 187 s.logger.Error("error fetching repo", "error", err) 188 return helpers.ServerError(e, nil) 189 } 190 repo = maybeRepo 191 did = repo.Repo.Did 192 } 193 194 e.Set("repo", repo) 195 e.Set("did", did) 196 e.Set("token", tokenstr) 197 198 if err := next(e); err != nil { 199 return helpers.InvalidTokenError(e) 200 } 201 202 return nil 203 } 204} 205 206func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 207 return func(e echo.Context) error { 208 authheader := e.Request().Header.Get("authorization") 209 if authheader == "" { 210 return e.JSON(401, map[string]string{"error": "Unauthorized"}) 211 } 212 213 pts := strings.Split(authheader, " ") 214 if len(pts) != 2 { 215 return helpers.ServerError(e, nil) 216 } 217 218 if pts[0] != "DPoP" { 219 return next(e) 220 } 221 222 accessToken := pts[1] 223 224 nonce := s.oauthProvider.NextNonce() 225 if nonce != "" { 226 e.Response().Header().Set("DPoP-Nonce", nonce) 227 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 228 } 229 230 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken)) 231 if err != nil { 232 s.logger.Error("invalid dpop proof", "error", err) 233 return helpers.InputError(e, to.StringPtr(err.Error())) 234 } 235 236 var oauthToken provider.OauthToken 237 if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 238 s.logger.Error("error finding access token in db", "error", err) 239 return helpers.InputError(e, nil) 240 } 241 242 if oauthToken.Token == "" { 243 return helpers.InvalidTokenError(e) 244 } 245 246 if *oauthToken.Parameters.DpopJkt != proof.JKT { 247 s.logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT) 248 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch")) 249 } 250 251 if time.Now().After(oauthToken.ExpiresAt) { 252 return helpers.ExpiredTokenError(e) 253 } 254 255 repo, err := s.getRepoActorByDid(oauthToken.Sub) 256 if err != nil { 257 s.logger.Error("could not find actor in db", "error", err) 258 return helpers.ServerError(e, nil) 259 } 260 261 e.Set("repo", repo) 262 e.Set("did", repo.Repo.Did) 263 e.Set("token", accessToken) 264 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " ")) 265 266 return next(e) 267 } 268}