+111
-31
server/server.go
+111
-31
server/server.go
···
4
4
"bytes"
5
5
"context"
6
6
"crypto/ecdsa"
7
+
"crypto/sha256"
7
8
"embed"
9
+
"encoding/base64"
8
10
"errors"
9
11
"fmt"
10
12
"io"
···
45
47
"github.com/labstack/echo/v4"
46
48
"github.com/labstack/echo/v4/middleware"
47
49
slogecho "github.com/samber/slog-echo"
50
+
"gitlab.com/yawning/secp256k1-voi"
51
+
secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
48
52
"gorm.io/driver/sqlite"
49
53
"gorm.io/gorm"
50
54
)
···
220
224
return helpers.ServerError(e, nil)
221
225
}
222
226
227
+
// move on to oauth session middleware if this is a dpop token
223
228
if pts[0] == "DPoP" {
224
229
return next(e)
225
230
}
226
231
227
232
tokenstr := pts[1]
233
+
token, _, err := new(jwt.Parser).ParseUnverified(tokenstr, jwt.MapClaims{})
234
+
claims, ok := token.Claims.(jwt.MapClaims)
235
+
if !ok {
236
+
return helpers.InputError(e, to.StringPtr("InvalidToken"))
237
+
}
238
+
239
+
var did string
240
+
var repo *models.RepoActor
228
241
229
-
token, err := new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
230
-
if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
231
-
return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
242
+
// service auth tokens
243
+
lxm, hasLxm := claims["lxm"]
244
+
if hasLxm {
245
+
pts := strings.Split(e.Request().URL.String(), "/")
246
+
if lxm != pts[len(pts)-1] {
247
+
s.logger.Error("service auth lxm incorrect", "lxm", lxm, "expected", pts[len(pts)-1], "error", err)
248
+
return helpers.InputError(e, nil)
232
249
}
233
250
234
-
return s.privateKey.Public(), nil
235
-
})
236
-
if err != nil {
237
-
s.logger.Error("error parsing jwt", "error", err)
238
-
// NOTE: https://github.com/bluesky-social/atproto/discussions/3319
239
-
return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
251
+
maybeDid, ok := claims["iss"].(string)
252
+
if !ok {
253
+
s.logger.Error("no iss in service auth token", "error", err)
254
+
return helpers.InputError(e, nil)
255
+
}
256
+
did = maybeDid
257
+
258
+
maybeRepo, err := s.getRepoActorByDid(did)
259
+
if err != nil {
260
+
s.logger.Error("error fetching repo", "error", err)
261
+
return helpers.ServerError(e, nil)
262
+
}
263
+
repo = maybeRepo
240
264
}
241
265
242
-
claims, ok := token.Claims.(jwt.MapClaims)
243
-
if !ok || !token.Valid {
244
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
266
+
if token.Header["alg"] != "ES256K" {
267
+
token, err = new(jwt.Parser).Parse(tokenstr, func(t *jwt.Token) (any, error) {
268
+
if _, ok := t.Method.(*jwt.SigningMethodECDSA); !ok {
269
+
return nil, fmt.Errorf("unsupported signing method: %v", t.Header["alg"])
270
+
}
271
+
return s.privateKey.Public(), nil
272
+
})
273
+
if err != nil {
274
+
s.logger.Error("error parsing jwt", "error", err)
275
+
// NOTE: https://github.com/bluesky-social/atproto/discussions/3319
276
+
return e.JSON(400, map[string]string{"error": "ExpiredToken", "message": "token has expired"})
277
+
}
278
+
279
+
if !token.Valid {
280
+
return helpers.InputError(e, to.StringPtr("InvalidToken"))
281
+
}
282
+
} else {
283
+
kpts := strings.Split(tokenstr, ".")
284
+
signingInput := kpts[0] + "." + kpts[1]
285
+
hash := sha256.Sum256([]byte(signingInput))
286
+
sigBytes, err := base64.RawURLEncoding.DecodeString(kpts[2])
287
+
if err != nil {
288
+
s.logger.Error("error decoding signature bytes", "error", err)
289
+
return helpers.ServerError(e, nil)
290
+
}
291
+
292
+
if len(sigBytes) != 64 {
293
+
s.logger.Error("incorrect sigbytes length", "length", len(sigBytes))
294
+
return helpers.ServerError(e, nil)
295
+
}
296
+
297
+
rBytes := sigBytes[:32]
298
+
sBytes := sigBytes[32:]
299
+
rr, _ := secp256k1.NewScalarFromBytes((*[32]byte)(rBytes))
300
+
ss, _ := secp256k1.NewScalarFromBytes((*[32]byte)(sBytes))
301
+
302
+
sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
303
+
if err != nil {
304
+
s.logger.Error("can't load private key", "error", err)
305
+
return err
306
+
}
307
+
308
+
pubKey, ok := sk.Public().(*secp256k1secec.PublicKey)
309
+
if !ok {
310
+
s.logger.Error("error getting public key from sk")
311
+
return helpers.ServerError(e, nil)
312
+
}
313
+
314
+
verified := pubKey.VerifyRaw(hash[:], rr, ss)
315
+
if !verified {
316
+
s.logger.Error("error verifying", "error", err)
317
+
return helpers.ServerError(e, nil)
318
+
}
245
319
}
246
320
247
321
isRefresh := e.Request().URL.Path == "/xrpc/com.atproto.server.refreshSession"
248
-
scope := claims["scope"].(string)
322
+
scope, _ := claims["scope"].(string)
249
323
250
324
if isRefresh && scope != "com.atproto.refresh" {
251
325
return helpers.InputError(e, to.StringPtr("InvalidToken"))
252
-
} else if !isRefresh && scope != "com.atproto.access" {
326
+
} else if !hasLxm && !isRefresh && scope != "com.atproto.access" {
253
327
return helpers.InputError(e, to.StringPtr("InvalidToken"))
254
328
}
255
329
···
258
332
table = "refresh_tokens"
259
333
}
260
334
261
-
type Result struct {
262
-
Found bool
263
-
}
264
-
var result Result
265
-
if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
266
-
if err == gorm.ErrRecordNotFound {
267
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
335
+
if isRefresh {
336
+
type Result struct {
337
+
Found bool
268
338
}
339
+
var result Result
340
+
if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
341
+
if err == gorm.ErrRecordNotFound {
342
+
return helpers.InputError(e, to.StringPtr("InvalidToken"))
343
+
}
269
344
270
-
s.logger.Error("error getting token from db", "error", err)
271
-
return helpers.ServerError(e, nil)
272
-
}
345
+
s.logger.Error("error getting token from db", "error", err)
346
+
return helpers.ServerError(e, nil)
347
+
}
273
348
274
-
if !result.Found {
275
-
return helpers.InputError(e, to.StringPtr("InvalidToken"))
349
+
if !result.Found {
350
+
return helpers.InputError(e, to.StringPtr("InvalidToken"))
351
+
}
276
352
}
277
353
278
354
exp, ok := claims["exp"].(float64)
···
285
361
return helpers.InputError(e, to.StringPtr("ExpiredToken"))
286
362
}
287
363
288
-
repo, err := s.getRepoActorByDid(claims["sub"].(string))
289
-
if err != nil {
290
-
s.logger.Error("error fetching repo", "error", err)
291
-
return helpers.ServerError(e, nil)
364
+
if repo == nil {
365
+
maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string))
366
+
if err != nil {
367
+
s.logger.Error("error fetching repo", "error", err)
368
+
return helpers.ServerError(e, nil)
369
+
}
370
+
repo = maybeRepo
371
+
did = repo.Repo.Did
292
372
}
293
373
294
374
e.Set("repo", repo)
295
-
e.Set("did", claims["sub"])
375
+
e.Set("did", did)
296
376
e.Set("token", tokenstr)
297
377
298
378
if err := next(e); err != nil {