1package server
2
import (
	"crypto/sha256"
	"crypto/subtle"
	"encoding/base64"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/golang-jwt/jwt/v4"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/haileyok/cocoon/oauth/dpop"
	"github.com/haileyok/cocoon/oauth/provider"
	"github.com/labstack/echo/v4"
	"gitlab.com/yawning/secp256k1-voi"
	secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
	"gorm.io/gorm"
)
22
23func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
24 return func(e echo.Context) error {
25 username, password, ok := e.Request().BasicAuth()
26 if !ok || username != "admin" || password != s.config.AdminPassword {
27 return helpers.InputError(e, to.StringPtr("Unauthorized"))
28 }
29
30 if err := next(e); err != nil {
31 e.Error(err)
32 }
33
34 return nil
35 }
36}
37
38func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
39 return func(e echo.Context) error {
40 ctx := e.Request().Context()
41 logger := s.logger.With("name", "handleLegacySessionMiddleware")
42
43 authheader := e.Request().Header.Get("authorization")
44 if authheader == "" {
45 return e.JSON(401, map[string]string{"error": "Unauthorized"})
46 }
47
48 pts := strings.Split(authheader, " ")
49 if len(pts) != 2 {
50 return helpers.ServerError(e, nil)
51 }
52
53 // move on to oauth session middleware if this is a dpop token
54 if pts[0] == "DPoP" {
55 return next(e)
56 }
57
58 tokenstr := pts[1]
59 token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{})
60 claims, ok := token.Claims.(jwt.MapClaims)
61 if !ok {
62 return helpers.InvalidTokenError(e)
63 }
64
65 var did string
66 var repo *models.RepoActor
67
68 // service auth tokens
69 lxm, hasLxm := claims["lxm"]
70 if hasLxm {
71 pts := strings.Split(e.Request().URL.String(), "/")
72 if lxm != pts[len(pts)-1] {
73 logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err)
74 return helpers.InputError(e, nil)
75 }
76
77 maybeDid, ok := claims["iss"].(string)
78 if !ok {
79 logger.Error("no iss in service auth token", "error", err)
80 return helpers.InputError(e, nil)
81 }
82 did = maybeDid
83
84 maybeRepo, err := s.getRepoActorByDid(ctx, did)
85 if err != nil {
86 logger.Error("error fetching repo", "error", err)
87 return helpers.ServerError(e, nil)
88 }
89 repo = maybeRepo
90 }
91
92 if token.Header["alg"] != "ES256K" {
93 token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
94 if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
95 return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
96 }
97 return s.privateKey.Public(), nil
98 })
99 if err != nil {
100 logger.Error("error parsing jwt", "error", err)
101 return helpers.ExpiredTokenError(e)
102 }
103
104 if !token.Valid {
105 return helpers.InvalidTokenError(e)
106 }
107 } else {
108 kpts := strings.Split(tokenstr, ".")
109 signingInput := kpts[0] + "." + kpts[1]
110 hash := sha256.Sum256([]byte(signingInput))
111 sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2])
112 if err != nil {
113 logger.Error("error decoding signature bytes", "error", err)
114 return helpers.ServerError(e, nil)
115 }
116
117 if len(sigBytes) != 64 {
118 logger.Error("incorrect sigbytes length", "length", len(sigBytes))
119 return helpers.ServerError(e, nil)
120 }
121
122 rBytes := sigBytes[:32]
123 sBytes := sigBytes[32:]
124 rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes))
125 ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes))
126
127 if repo == nil {
128 sub, ok := claims["sub"].(string)
129 if !ok {
130 s.logger.Error("no sub claim in ES256K token and repo not set")
131 return helpers.InvalidTokenError(e)
132 }
133 maybeRepo, err := s.getRepoActorByDid(ctx, sub)
134 if err != nil {
135 s.logger.Error("error fetching repo for ES256K verification", "error", err)
136 return helpers.ServerError(e, nil)
137 }
138 repo = maybeRepo
139 did = sub
140 }
141
142 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
143 if err != nil {
144 logger.Error("can't load private key", "error", err)
145 return err
146 }
147
148 pubKey, ok := sk.Public().(*secp256k1secec.PublicKey)
149 if !ok {
150 logger.Error("error getting public key from sk")
151 return helpers.ServerError(e, nil)
152 }
153
154 verified := pubKey.VerifyRaw(hash[:], rr, ss)
155 if !verified {
156 logger.Error("error verifying", "error", err)
157 return helpers.ServerError(e, nil)
158 }
159 }
160
161 isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
162 scope, _ := claims["scope"].(string)
163
164 if isRefresh && scope != "com.atproto.refresh" {
165 return helpers.InvalidTokenError(e)
166 } else if !hasLxm && !isRefresh && scope != "com.atproto.access" {
167 return helpers.InvalidTokenError(e)
168 }
169
170 table := "tokens"
171 if isRefresh {
172 table = "refresh_tokens"
173 }
174
175 if isRefresh {
176 type Result struct {
177 Found bool
178 }
179 var result Result
180 if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
181 if err == gorm.ErrRecordNotFound {
182 return helpers.InvalidTokenError(e)
183 }
184
185 logger.Error("error getting token from db", "error", err)
186 return helpers.ServerError(e, nil)
187 }
188
189 if !result.Found {
190 return helpers.InvalidTokenError(e)
191 }
192 }
193
194 exp, ok := claims["exp"].(float64)
195 if !ok {
196 logger.Error("error getting iat from token")
197 return helpers.ServerError(e, nil)
198 }
199
200 if exp < float64(time.Now().UTC().Unix()) {
201 return helpers.ExpiredTokenError(e)
202 }
203
204 if repo == nil {
205 maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string))
206 if err != nil {
207 logger.Error("error fetching repo", "error", err)
208 return helpers.ServerError(e, nil)
209 }
210 repo = maybeRepo
211 did = repo.Repo.Did
212 }
213
214 e.Set("repo", repo)
215 e.Set("did", did)
216 e.Set("token", tokenstr)
217
218 if err := next(e); err != nil {
219 return helpers.InvalidTokenError(e)
220 }
221
222 return nil
223 }
224}
225
226func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
227 return func(e echo.Context) error {
228 ctx := e.Request().Context()
229 logger := s.logger.With("name", "handleOauthSessionMiddleware")
230
231 authheader := e.Request().Header.Get("authorization")
232 if authheader == "" {
233 return e.JSON(401, map[string]string{"error": "Unauthorized"})
234 }
235
236 pts := strings.Split(authheader, " ")
237 if len(pts) != 2 {
238 return helpers.ServerError(e, nil)
239 }
240
241 if pts[0] != "DPoP" {
242 return next(e)
243 }
244
245 accessToken := pts[1]
246
247 nonce := s.oauthProvider.NextNonce()
248 if nonce != "" {
249 e.Response().Header().Set("DPoP-Nonce", nonce)
250 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce")
251 }
252
253 proof, err := s.oauthProvider.DpopManager.CheckProof(e.Request().Method, "https://"+s.config.Hostname+e.Request().URL.String(), e.Request().Header, to.StringPtr(accessToken))
254 if err != nil {
255 if errors.Is(err, dpop.ErrUseDpopNonce) {
256 e.Response().Header().Set("WWW-Authenticate", `DPoP error="use_dpop_nonce"`)
257 e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate")
258 return e.JSON(401, map[string]string{
259 "error": "use_dpop_nonce",
260 })
261 }
262 logger.Error("invalid dpop proof", "error", err)
263 return helpers.InputError(e, nil)
264 }
265
266 var oauthToken provider.OauthToken
267 if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
268 logger.Error("error finding access token in db", "error", err)
269 return helpers.InputError(e, nil)
270 }
271
272 if oauthToken.Token == "" {
273 return helpers.InvalidTokenError(e)
274 }
275
276 if *oauthToken.Parameters.DpopJkt != proof.JKT {
277 logger.Error("jkt mismatch", "token", oauthToken.Parameters.DpopJkt, "proof", proof.JKT)
278 return helpers.InputError(e, to.StringPtr("dpop jkt mismatch"))
279 }
280
281 if time.Now().After(oauthToken.ExpiresAt) {
282 e.Response().Header().Set("WWW-Authenticate", `DPoP error="invalid_token", error_description="Token expired"`)
283 e.Response().Header().Add("access-control-expose-headers", "WWW-Authenticate")
284 return e.JSON(401, map[string]string{
285 "error": "invalid_token",
286 "error_description": "Token expired",
287 })
288 }
289
290 repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub)
291 if err != nil {
292 logger.Error("could not find actor in db", "error", err)
293 return helpers.ServerError(e, nil)
294 }
295
296 e.Set("repo", repo)
297 e.Set("did", repo.Repo.Did)
298 e.Set("token", accessToken)
299 e.Set("scopes", strings.Split(oauthToken.Parameters.Scope, " "))
300
301 return next(e)
302 }
303}