An atproto PDS written in Go
103
fork

Configure Feed

Select the types of activity you want to include in your feed.

at v0.0.2 144 lines 3.5 kB view raw
1package server 2 3import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "net/http" 10 "strings" 11 "time" 12 13 "github.com/google/uuid" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/models" 16 "github.com/labstack/echo/v4" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18) 19 20func (s *Server) handleProxy(e echo.Context) error { 21 repo, isAuthed := e.Get("repo").(*models.RepoActor) 22 23 pts := strings.Split(e.Request().URL.Path, "/") 24 if len(pts) != 3 { 25 return fmt.Errorf("incorrect number of parts") 26 } 27 28 svc := e.Request().Header.Get("atproto-proxy") 29 if svc == "" { 30 svc = "did:web:api.bsky.app#bsky_appview" // TODO: should be a config var probably 31 } 32 33 svcPts := strings.Split(svc, "#") 34 if len(svcPts) != 2 { 35 return fmt.Errorf("invalid service header") 36 } 37 38 svcDid := svcPts[0] 39 svcId := "#" + svcPts[1] 40 41 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid) 42 if err != nil { 43 return err 44 } 45 46 var endpoint string 47 for _, s := range doc.Service { 48 if s.Id == svcId { 49 endpoint = s.ServiceEndpoint 50 } 51 } 52 53 requrl := e.Request().URL 54 requrl.Host = strings.TrimPrefix(endpoint, "https://") 55 requrl.Scheme = "https" 56 57 body := e.Request().Body 58 if e.Request().Method == "GET" { 59 body = nil 60 } 61 62 req, err := http.NewRequest(e.Request().Method, requrl.String(), body) 63 if err != nil { 64 return err 65 } 66 67 req.Header = e.Request().Header.Clone() 68 69 if isAuthed { 70 // this is a little dumb. i should probably figure out a better way to do this, and use 71 // a single way of creating/signing jwts throughout the pds. kinda limited here because 72 // im using the atproto crypto lib for this though. will come back to it 73 74 header := map[string]string{ 75 "alg": "ES256K", 76 "crv": "secp256k1", 77 "typ": "JWT", 78 } 79 hj, err := json.Marshal(header) 80 if err != nil { 81 s.logger.Error("error marshaling header", "error", err) 82 return helpers.ServerError(e, nil) 83 } 84 85 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") 86 87 payload := map[string]any{ 88 "iss": repo.Repo.Did, 89 "aud": svcDid, 90 "lxm": pts[2], 91 "jti": uuid.NewString(), 92 "exp": time.Now().Add(1 * time.Minute).UTC().Unix(), 93 } 94 pj, err := json.Marshal(payload) 95 if err != nil { 96 s.logger.Error("error marashaling payload", "error", err) 97 return helpers.ServerError(e, nil) 98 } 99 100 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") 101 102 input := fmt.Sprintf("%s.%s", encheader, encpayload) 103 hash := sha256.Sum256([]byte(input)) 104 105 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 106 if err != nil { 107 s.logger.Error("can't load private key", "error", err) 108 return err 109 } 110 111 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 112 if err != nil { 113 s.logger.Error("error signing", "error", err) 114 } 115 116 rBytes := R.Bytes() 117 sBytes := S.Bytes() 118 119 rPadded := make([]byte, 32) 120 sPadded := make([]byte, 32) 121 copy(rPadded[32-len(rBytes):], rBytes) 122 copy(sPadded[32-len(sBytes):], sBytes) 123 124 rawsig := append(rPadded, sPadded...) 125 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=") 126 token := fmt.Sprintf("%s.%s", input, encsig) 127 128 req.Header.Set("authorization", "Bearer "+token) 129 } else { 130 req.Header.Del("authorization") 131 } 132 133 resp, err := http.DefaultClient.Do(req) 134 if err != nil { 135 return err 136 } 137 defer resp.Body.Close() 138 139 for k, v := range resp.Header { 140 e.Response().Header().Set(k, strings.Join(v, ",")) 141 } 142 143 return e.Stream(resp.StatusCode, e.Response().Header().Get("content-type"), resp.Body) 144}