forked from hailey.at/cocoon
An atproto PDS written in Go
at main 3.0 kB view raw
1package server 2 3import ( 4 "crypto/rand" 5 "crypto/sha256" 6 "encoding/base64" 7 "encoding/json" 8 "fmt" 9 "strings" 10 "time" 11 12 "github.com/Azure/go-autorest/autorest/to" 13 "github.com/google/uuid" 14 "github.com/haileyok/cocoon/internal/helpers" 15 "github.com/haileyok/cocoon/models" 16 "github.com/labstack/echo/v4" 17 secp256k1secec "gitlab.com/yawning/secp256k1-voi/secec" 18) 19 20type ServerGetServiceAuthRequest struct { 21 Aud string `query:"aud" validate:"required,atproto-did"` 22 // exp should be a float, as some clients will send a non-integer expiration 23 Exp float64 `query:"exp"` 24 Lxm string `query:"lxm"` 25} 26 27func (s *Server) handleServerGetServiceAuth(e echo.Context) error { 28 logger := s.logger.With("name", "handleServerGetServiceAuth") 29 30 var req ServerGetServiceAuthRequest 31 if err := e.Bind(&req); err != nil { 32 logger.Error("could not bind service auth request", "error", err) 33 return helpers.ServerError(e, nil) 34 } 35 36 if err := e.Validate(req); err != nil { 37 return helpers.InputError(e, nil) 38 } 39 40 exp := int64(req.Exp) 41 now := time.Now().Unix() 42 if exp == 0 { 43 exp = now + 60 // default 44 } 45 46 if req.Lxm == "com.atproto.server.getServiceAuth" { 47 return helpers.InputError(e, to.StringPtr("may not generate auth tokens recursively")) 48 } 49 50 var maxExp int64 51 if req.Lxm != "" { 52 maxExp = now + (60 * 60) 53 } else { 54 maxExp = now + 60 55 } 56 if exp > maxExp { 57 return helpers.InputError(e, to.StringPtr("expiration too big. smoller please")) 58 } 59 60 repo := e.Get("repo").(*models.RepoActor) 61 62 header := map[string]string{ 63 "alg": "ES256K", 64 "crv": "secp256k1", 65 "typ": "JWT", 66 } 67 hj, err := json.Marshal(header) 68 if err != nil { 69 logger.Error("error marshaling header", "error", err) 70 return helpers.ServerError(e, nil) 71 } 72 73 encheader := strings.TrimRight(base64.RawURLEncoding.EncodeToString(hj), "=") 74 75 payload := map[string]any{ 76 "iss": repo.Repo.Did, 77 "aud": req.Aud, 78 "jti": uuid.NewString(), 79 "exp": exp, 80 "iat": now, 81 } 82 if req.Lxm != "" { 83 payload["lxm"] = req.Lxm 84 } 85 pj, err := json.Marshal(payload) 86 if err != nil { 87 logger.Error("error marashaling payload", "error", err) 88 return helpers.ServerError(e, nil) 89 } 90 91 encpayload := strings.TrimRight(base64.RawURLEncoding.EncodeToString(pj), "=") 92 93 input := fmt.Sprintf("%s.%s", encheader, encpayload) 94 hash := sha256.Sum256([]byte(input)) 95 96 sk, err := secp256k1secec.NewPrivateKey(repo.SigningKey) 97 if err != nil { 98 logger.Error("can't load private key", "error", err) 99 return err 100 } 101 102 R, S, _, err := sk.SignRaw(rand.Reader, hash[:]) 103 if err != nil { 104 logger.Error("error signing", "error", err) 105 return helpers.ServerError(e, nil) 106 } 107 108 rBytes := R.Bytes() 109 sBytes := S.Bytes() 110 111 rPadded := make([]byte, 32) 112 sPadded := make([]byte, 32) 113 copy(rPadded[32-len(rBytes):], rBytes) 114 copy(sPadded[32-len(sBytes):], sBytes) 115 116 rawsig := append(rPadded, sPadded...) 117 encsig := strings.TrimRight(base64.RawURLEncoding.EncodeToString(rawsig), "=") 118 token := fmt.Sprintf("%s.%s", input, encsig) 119 120 return e.JSON(200, map[string]string{ 121 "token": token, 122 }) 123}