forked from hailey.at/cocoon
An atproto PDS written in Go

cleanup proxying a bit

Changed files
+26 -14
server
+26 -14
server/handle_proxy.go
··· 17 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18 18 ) 19 19 20 - func (s *Server) handleProxy(e echo.Context) error { 21 - repo, isAuthed := e.Get("repo").(*models.RepoActor) 22 - 23 - pts := strings.Split(e.Request().URL.Path, "/") 24 - if len(pts) != 3 { 25 - return fmt.Errorf("incorrect number of parts") 26 - } 27 - 20 + func (s *Server) getAtprotoProxyEndpointFromRequest(e echo.Context) (string, string, error) { 28 21 svc := e.Request().Header.Get("atproto-proxy") 29 22 if svc == "" { 30 23 svc = s.config.DefaultAtprotoProxy ··· 32 25 33 26 svcPts := strings.Split(svc, "#") 34 27 if len(svcPts) != 2 { 35 - return fmt.Errorf("invalid service header") 28 + return "", "", fmt.Errorf("invalid service header") 36 29 } 37 30 38 31 svcDid := svcPts[0] ··· 40 33 41 34 doc, err := s.passport.FetchDoc(e.Request().Context(), svcDid) 42 35 if err != nil { 43 - return err 36 + return "", "", err 44 37 } 45 38 46 39 var endpoint string ··· 50 43 } 51 44 } 52 45 46 + return endpoint, "", nil 47 + } 48 + 49 + func (s *Server) handleProxy(e echo.Context) error { 50 + lgr := s.logger.With("handler", "handleProxy") 51 + 52 + repo, isAuthed := e.Get("repo").(*models.RepoActor) 53 + 54 + pts := strings.Split(e.Request().URL.Path, "/") 55 + if len(pts) != 3 { 56 + return fmt.Errorf("incorrect number of parts") 57 + } 58 + 59 + endpoint, svcDid, err := s.getAtprotoProxyEndpointFromRequest(e) 60 + if err != nil { 61 + lgr.Error("could not get atproto proxy", "error", err) 62 + return helpers.ServerError(e, nil) 63 + } 64 + 53 65 requrl := e.Request().URL 54 66 requrl.Host = strings.TrimPrefix(endpoint, "https://") 55 67 requrl.Scheme = "https" ··· 78 90 } 79 91 hj, err := json.Marshal(header) 80 92 if err != nil { 81 - s.logger.Error("error marshaling header", "error", err) 93 + lgr.Error("error marshaling header", "error", err) 82 94 return helpers.ServerError(e, nil) 83 95 } 84 96 ··· 93 105 } 94 106 pj, err := json.Marshal(payload) 95 107 if err != nil { 96 - s.logger.Error("error marashaling payload", "error", err) 108 + lgr.Error("error marashaling payload", "error", err) 97 109 return helpers.ServerError(e, nil) 98 110 } 99 111 ··· 104 116 105 117 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 106 118 if err != nil { 107 - s.logger.Error("can't load private key", "error", err) 119 + lgr.Error("can't load private key", "error", err) 108 120 return err 109 121 } 110 122 111 123 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 112 124 if err != nil { 113 - s.logger.Error("error signing", "error", err) 125 + lgr.Error("error signing", "error", err) 114 126 } 115 127 116 128 rBytes := R.Bytes()