An atproto PDS written in Go
103
fork

Configure Feed

Select the types of activity you want to include in your feed.

at 0.5.0 169 lines 4.4 kB view raw
1package server 2 3import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "net/http" 10 "strings" 11 "time" 12 13 "github.com/google/uuid" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/models" 16 "github.com/labstack/echo/v4" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18) 19 20func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) { 21 svc := e.Request().Header.Get("atproto-proxy") 22 if svc == "" && s.config.FallbackProxy != "" { 23 svc = s.config.FallbackProxy 24 } 25 26 svcPts := strings.Split(svc, "#") 27 if len(svcPts) != 2 { 28 return "", "", fmt.Errorf("invalid service header") 29 } 30 31 svcDid := svcPts[0] 32 svcId := "#" + svcPts[1] 33 34 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid) 35 if err != nil { 36 return "", "", err 37 } 38 39 var endpoint string 40 for _, s := range doc.Service { 41 if s.Id == svcId { 42 endpoint = s.ServiceEndpoint 43 } 44 } 45 46 return endpoint, svcDid, nil 47} 48 49func (s *Server) handleProxy(e echo.Context) error { 50 lgr := s.logger.With("handler", "handleProxy") 51 52 repo, isAuthed := e.Get("repo").(*models.RepoActor) 53 54 pts := strings.Split(e.Request().URL.Path, "/") 55 if len(pts) != 3 { 56 return fmt.Errorf("incorrect number of parts") 57 } 58 59 endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e) 60 if err != nil { 61 lgr.Error("could not get atproto proxy", "error", err) 62 return helpers.ServerError(e, nil) 63 } 64 65 requrl := e.Request().URL 66 requrl.Host = strings.TrimPrefix(endpoint, "https://") 67 requrl.Scheme = "https" 68 69 body := e.Request().Body 70 if e.Request().Method == "GET" { 71 body = nil 72 } 73 74 req, err := http.NewRequest(e.Request().Method, requrl.String(), body) 75 if err != nil { 76 return err 77 } 78 79 req.Header = e.Request().Header.Clone() 80 81 if isAuthed { 82 // this is a little dumb. i should probably figure out a better way to do this, and use 83 // a single way of creating/signing jwts throughout the pds. kinda limited here because 84 // im using the atproto crypto lib for this though. will come back to it 85 86 header := map[string]string{ 87 "alg": "ES256K", 88 "crv": "secp256k1", 89 "typ": "JWT", 90 } 91 hj, err := json.Marshal(header) 92 if err != nil { 93 lgr.Error("error marshaling header", "error", err) 94 return helpers.ServerError(e, nil) 95 } 96 97 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") 98 99 // When proxying app.bsky.feed.getFeed the token is actually issued for the 100 // underlying feed generator and the app view passes it on. This allows the 101 // getFeed implementation to pass in the desired lxm and aud for the token 102 // and then just delegate to the general proxying logic 103 lxm, proxyTokenLxmExists := e.Get("proxyTokenLxm").(string) 104 if !proxyTokenLxmExists || lxm == "" { 105 lxm = pts[2] 106 } 107 aud, proxyTokenAudExists := e.Get("proxyTokenAud").(string) 108 if !proxyTokenAudExists || aud == "" { 109 aud = svcDid 110 } 111 112 payload := map[string]any{ 113 "iss": repo.Repo.Did, 114 "aud": svcDid, 115 "lxm": lxm, 116 "jti": uuid.NewString(), 117 "exp": time.Now().Add(1 * time.Minute).UTC().Unix(), 118 } 119 pj, err := json.Marshal(payload) 120 if err != nil { 121 lgr.Error("error marashaling payload", "error", err) 122 return helpers.ServerError(e, nil) 123 } 124 125 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") 126 127 input := fmt.Sprintf("%s.%s", encheader, encpayload) 128 hash := sha256.Sum256([]byte(input)) 129 130 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 131 if err != nil { 132 lgr.Error("can't load private key", "error", err) 133 return err 134 } 135 136 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 137 if err != nil { 138 lgr.Error("error signing", "error", err) 139 } 140 141 rBytes := R.Bytes() 142 sBytes := S.Bytes() 143 144 rPadded := make([]byte, 32) 145 sPadded := make([]byte, 32) 146 copy(rPadded[32-len(rBytes):], rBytes) 147 copy(sPadded[32-len(sBytes):], sBytes) 148 149 rawsig := append(rPadded, sPadded...) 150 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=") 151 token := fmt.Sprintf("%s.%s", input, encsig) 152 153 req.Header.Set("authorization", "Bearer "+token) 154 } else { 155 req.Header.Del("authorization") 156 } 157 158 resp, err := http.DefaultClient.Do(req) 159 if err != nil { 160 return err 161 } 162 defer resp.Body.Close() 163 164 for k, v := range resp.Header { 165 e.Response().Header().Set(k, strings.Join(v, ",")) 166 } 167 168 return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body) 169}