1package server
2
3import (
4 "crypto/rand"
5 "crypto/sha256"
6 "encoding/base64"
7 "encoding/json"
8 "fmt"
9 "strings"
10 "time"
11
12 "github.com/Azure/go-autorest/autorest/to"
13 "github.com/google/uuid"
14 "github.com/haileyok/cocoon/internal/helpers"
15 "github.com/haileyok/cocoon/models"
16 "github.com/labstack/echo/v4"
17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
18)
19
// ServerGetServiceAuthRequest holds the query parameters accepted by
// handleServerGetServiceAuth. Fields are populated from the request query
// string via the `query` tags and checked via the `validate` tags.
type ServerGetServiceAuthRequest struct {
	// Aud is the DID the token is issued for; it becomes the "aud" claim.
	// It is required and must satisfy the "atproto-did" validation rule.
	Aud string `query:"aud" validate:"required,atproto-did"`

	// Exp is the requested expiration as a Unix timestamp in seconds. It is a
	// float because some clients send a non-integer expiration; the handler
	// truncates it to whole seconds. A value that truncates to zero means
	// "not provided", in which case the handler defaults to now + 60 seconds.
	Exp float64 `query:"exp"`

	// Lxm optionally names the method the token is scoped to. When non-empty
	// it is copied into the "lxm" claim and the handler allows an expiration
	// of up to one hour instead of one minute.
	Lxm string `query:"lxm"`
}
26
27func (s *Server) handleServerGetServiceAuth(e echo.Context) error {
28 logger := s.logger.With("name", "handleServerGetServiceAuth")
29
30 var req ServerGetServiceAuthRequest
31 if err := e.Bind(&req); err != nil {
32 logger.Error("could not bind service auth request", "error", err)
33 return helpers.ServerError(e, nil)
34 }
35
36 if err := e.Validate(req); err != nil {
37 return helpers.InputError(e, nil)
38 }
39
40 exp := int64(req.Exp)
41 now := time.Now().Unix()
42 if exp == 0 {
43 exp = now + 60 // default
44 }
45
46 if req.Lxm == "com.atproto.server.getServiceAuth" {
47 return helpers.InputError(e, to.StringPtr("may not generate auth tokens recursively"))
48 }
49
50 var maxExp int64
51 if req.Lxm != "" {
52 maxExp = now + (60 * 60)
53 } else {
54 maxExp = now + 60
55 }
56 if exp > maxExp {
57 return helpers.InputError(e, to.StringPtr("expiration too big. smoller please"))
58 }
59
60 repo := e.Get("repo").(*models.RepoActor)
61
62 header := map[string]string{
63 "alg": "ES256K",
64 "crv": "secp256k1",
65 "typ": "JWT",
66 }
67 hj, err := json.Marshal(header)
68 if err != nil {
69 logger.Error("error marshaling header", "error", err)
70 return helpers.ServerError(e, nil)
71 }
72
73 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=")
74
75 payload := map[string]any{
76 "iss": repo.Repo.Did,
77 "aud": req.Aud,
78 "jti": uuid.NewString(),
79 "exp": exp,
80 "iat": now,
81 }
82 if req.Lxm != "" {
83 payload["lxm"] = req.Lxm
84 }
85 pj, err := json.Marshal(payload)
86 if err != nil {
87 logger.Error("error marashaling payload", "error", err)
88 return helpers.ServerError(e, nil)
89 }
90
91 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=")
92
93 input := fmt.Sprintf("%s.%s", encheader, encpayload)
94 hash := sha256.Sum256([]byte(input))
95
96 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
97 if err != nil {
98 logger.Error("can't load private key", "error", err)
99 return err
100 }
101
102 R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
103 if err != nil {
104 logger.Error("error signing", "error", err)
105 return helpers.ServerError(e, nil)
106 }
107
108 rBytes := R.Bytes()
109 sBytes := S.Bytes()
110
111 rPadded := make([]byte, 32)
112 sPadded := make([]byte, 32)
113 copy(rPadded[32-len(rBytes):], rBytes)
114 copy(sPadded[32-len(sBytes):], sBytes)
115
116 rawsig := append(rPadded, sPadded...)
117 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=")
118 token := fmt.Sprintf("%s.%s", input, encsig)
119
120 return e.JSON(200, map[string]string{
121 "token": token,
122 })
123}