forked from
hailey.at/cocoon
fork
Configure Feed
Select the types of activity you want to include in your feed.
An atproto PDS written in Go
fork
Configure Feed
Select the types of activity you want to include in your feed.
1package server
2
import (
	"crypto/rand"
	"crypto/sha256"
	"encoding/base64"
	"encoding/json"
	"fmt"
	"net/http"
	"net/url"
	"strings"
	"time"

	"github.com/google/uuid"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/labstack/echo/v4"
	secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec"
)
19
20func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) {
21 svc := e.Request().Header.Get("atproto-proxy")
22 if svc == "" {
23 svc = s.config.DefaultAtprotoProxy
24 }
25
26 svcPts := strings.Split(svc, "#")
27 if len(svcPts) != 2 {
28 return "", "", fmt.Errorf("invalid service header")
29 }
30
31 svcDid := svcPts[0]
32 svcId := "#" + svcPts[1]
33
34 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid)
35 if err != nil {
36 return "", "", err
37 }
38
39 var endpoint string
40 for _, s := range doc.Service {
41 if s.Id == svcId {
42 endpoint = s.ServiceEndpoint
43 }
44 }
45
46 return endpoint, svcDid, nil
47}
48
49func (s *Server) handleProxy(e echo.Context) error {
50 lgr := s.logger.With("handler", "handleProxy")
51
52 repo, isAuthed := e.Get("repo").(*models.RepoActor)
53
54 pts := strings.Split(e.Request().URL.Path, "/")
55 if len(pts) != 3 {
56 return fmt.Errorf("incorrect number of parts")
57 }
58
59 endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e)
60 if err != nil {
61 lgr.Error("could not get atproto proxy", "error", err)
62 return helpers.ServerError(e, nil)
63 }
64
65 requrl := e.Request().URL
66 requrl.Host = strings.TrimPrefix(endpoint, "https://")
67 requrl.Scheme = "https"
68
69 body := e.Request().Body
70 if e.Request().Method == "GET" {
71 body = nil
72 }
73
74 req, err := http.NewRequest(e.Request().Method, requrl.String(), body)
75 if err != nil {
76 return err
77 }
78
79 req.Header = e.Request().Header.Clone()
80
81 if isAuthed {
82 // this is a little dumb. i should probably figure out a better way to do this, and use
83 // a single way of creating/signing jwts throughout the pds. kinda limited here because
84 // im using the atproto crypto lib for this though. will come back to it
85
86 header := map[string]string{
87 "alg": "ES256K",
88 "crv": "secp256k1",
89 "typ": "JWT",
90 }
91 hj, err := json.Marshal(header)
92 if err != nil {
93 lgr.Error("error marshaling header", "error", err)
94 return helpers.ServerError(e, nil)
95 }
96
97 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=")
98
99 payload := map[string]any{
100 "iss": repo.Repo.Did,
101 "aud": svcDid,
102 "lxm": pts[2],
103 "jti": uuid.NewString(),
104 "exp": time.Now().Add(1 * time.Minute).UTC().Unix(),
105 }
106 pj, err := json.Marshal(payload)
107 if err != nil {
108 lgr.Error("error marashaling payload", "error", err)
109 return helpers.ServerError(e, nil)
110 }
111
112 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=")
113
114 input := fmt.Sprintf("%s.%s", encheader, encpayload)
115 hash := sha256.Sum256([]byte(input))
116
117 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey)
118 if err != nil {
119 lgr.Error("can't load private key", "error", err)
120 return err
121 }
122
123 R, S, _, err := sk.SignRaw(rand.Reader, hash[:])
124 if err != nil {
125 lgr.Error("error signing", "error", err)
126 }
127
128 rBytes := R.Bytes()
129 sBytes := S.Bytes()
130
131 rPadded := make([]byte, 32)
132 sPadded := make([]byte, 32)
133 copy(rPadded[32-len(rBytes):], rBytes)
134 copy(sPadded[32-len(sBytes):], sBytes)
135
136 rawsig := append(rPadded, sPadded...)
137 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=")
138 token := fmt.Sprintf("%s.%s", input, encsig)
139
140 req.Header.Set("authorization", "Bearer "+token)
141 } else {
142 req.Header.Del("authorization")
143 }
144
145 resp, err := http.DefaultClient.Do(req)
146 if err != nil {
147 return err
148 }
149 defer resp.Body.Close()
150
151 for k, v := range resp.Header {
152 e.Response().Header().Set(k, strings.Join(v, ","))
153 }
154
155 return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body)
156}