An atproto PDS written in Go
103
fork

Configure Feed

Select the types of activity you want to include in your feed.

at 3bff74e02ff12b3451194961757a3bc6c8a3afbb 114 lines 2.9 kB view raw
1package server 2 3import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "strings" 10 "time" 11 12 "github.com/Azure/go-autorest/autorest/to" 13 "github.com/google/uuid" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/models" 16 "github.com/labstack/echo/v4" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18) 19 20type ServerGetServiceAuthRequest struct { 21 Aud string `query:"aud" validate:"required,atproto-did"` 22 // exp should be a float, as some clients will send a non-integer expiration 23 Exp float64 `query:"exp"` 24 Lxm string `query:"lxm" validate:"required,atproto-nsid"` 25} 26 27func (s *Server) handleServerGetServiceAuth(e echo.Context) error { 28 var req ServerGetServiceAuthRequest 29 if err := e.Bind(&req); err != nil { 30 s.logger.Error("could not bind service auth request", "error", err) 31 return helpers.ServerError(e, nil) 32 } 33 34 if err := e.Validate(req); err != nil { 35 return helpers.InputError(e, nil) 36 } 37 38 exp := int64(req.Exp) 39 now := time.Now().Unix() 40 if exp == 0 { 41 exp = now + 60 // default 42 } 43 44 if req.Lxm == "com.atproto.server.getServiceAuth" { 45 return helpers.InputError(e, to.StringPtr("may not generate auth tokens recursively")) 46 } 47 48 maxExp := now + (60 * 30) 49 if exp > maxExp { 50 return helpers.InputError(e, to.StringPtr("expiration too big. smoller please")) 51 } 52 53 repo := e.Get("repo").(*models.RepoActor) 54 55 header := map[string]string{ 56 "alg": "ES256K", 57 "crv": "secp256k1", 58 "typ": "JWT", 59 } 60 hj, err := json.Marshal(header) 61 if err != nil { 62 s.logger.Error("error marshaling header", "error", err) 63 return helpers.ServerError(e, nil) 64 } 65 66 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") 67 68 payload := map[string]any{ 69 "iss": repo.Repo.Did, 70 "aud": req.Aud, 71 "lxm": req.Lxm, 72 "jti": uuid.NewString(), 73 "exp": exp, 74 "iat": now, 75 } 76 pj, err := json.Marshal(payload) 77 if err != nil { 78 s.logger.Error("error marashaling payload", "error", err) 79 return helpers.ServerError(e, nil) 80 } 81 82 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") 83 84 input := fmt.Sprintf("%s.%s", encheader, encpayload) 85 hash := sha256.Sum256([]byte(input)) 86 87 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 88 if err != nil { 89 s.logger.Error("can't load private key", "error", err) 90 return err 91 } 92 93 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 94 if err != nil { 95 s.logger.Error("error signing", "error", err) 96 return helpers.ServerError(e, nil) 97 } 98 99 rBytes := R.Bytes() 100 sBytes := S.Bytes() 101 102 rPadded := make([]byte, 32) 103 sPadded := make([]byte, 32) 104 copy(rPadded[32-len(rBytes):], rBytes) 105 copy(sPadded[32-len(sBytes):], sBytes) 106 107 rawsig := append(rPadded, sPadded...) 108 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=") 109 token := fmt.Sprintf("%s.%s", input, encsig) 110 111 return e.JSON(200, map[string]string{ 112 "token": token, 113 }) 114}