forked from
hailey.at/cocoon
An atproto PDS written in Go
1package server
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "time"
9
10 "github.com/Azure/go-autorest/autorest/to"
11 "github.com/bluesky-social/indigo/api/atproto"
12 "github.com/bluesky-social/indigo/atproto/atcrypto"
13 atp "github.com/bluesky-social/indigo/atproto/repo"
14 "github.com/bluesky-social/indigo/atproto/repo/mst"
15 "github.com/bluesky-social/indigo/atproto/syntax"
16 "github.com/bluesky-social/indigo/events"
17 "github.com/bluesky-social/indigo/util"
18 "github.com/haileyok/cocoon/internal/helpers"
19 "github.com/haileyok/cocoon/models"
20 "github.com/labstack/echo/v4"
21 "golang.org/x/crypto/bcrypt"
22 "gorm.io/gorm"
23)
24
// ComAtprotoServerCreateAccountRequest is the JSON input of
// com.atproto.server.createAccount. It is bound and then checked with
// e.Validate in handleCreateAccount.
type ComAtprotoServerCreateAccountRequest struct {
	Email  string `json:"email" validate:"required,email"`
	Handle string `json:"handle" validate:"required,atproto-handle"`
	// Did is optional. When set, the caller is signing up an existing DID and
	// must present a service-auth token for it. omitempty makes the
	// atproto-did rule apply only when a non-empty value is supplied; assuming
	// e.Validate is backed by go-playground/validator, a nil pointer carrying a
	// bare rule would otherwise be reported as a failure on this field.
	Did        *string `json:"did" validate:"omitempty,atproto-did"`
	Password   string  `json:"password" validate:"required"`
	InviteCode string  `json:"inviteCode" validate:"omitempty"`
}
32
// ComAtprotoServerCreateAccountResponse is the JSON body handleCreateAccount
// writes on success. Field order is kept as-is because it fixes the key order
// of the encoded JSON.
type ComAtprotoServerCreateAccountResponse struct {
	// Tokens of the session created for the new account.
	AccessJwt  string `json:"accessJwt"`
	RefreshJwt string `json:"refreshJwt"`

	// Identity the account was created with.
	Handle string `json:"handle"`
	Did    string `json:"did"`
}
39
40func (s *Server) handleCreateAccount(e echo.Context) error {
41 ctx := e.Request().Context()
42 logger := s.logger.With("name", "handleServerCreateAccount")
43
44 var request ComAtprotoServerCreateAccountRequest
45
46 if err := e.Bind(&request); err != nil {
47 logger.Error("error receiving request", "endpoint", "com.atproto.server.createAccount", "error", err)
48 return helpers.ServerError(e, nil)
49 }
50
51 request.Handle = strings.ToLower(request.Handle)
52
53 if err := e.Validate(request); err != nil {
54 logger.Error("error validating request", "endpoint", "com.atproto.server.createAccount", "error", err)
55
56 var verr ValidationError
57 if errors.As(err, &verr) {
58 if verr.Field == "Email" {
59 // TODO: what is this supposed to be? `InvalidEmail` isn't listed in doc
60 return helpers.InputError(e, to.StringPtr("InvalidEmail"))
61 }
62
63 if verr.Field == "Handle" {
64 return helpers.InputError(e, to.StringPtr("InvalidHandle"))
65 }
66
67 if verr.Field == "Password" {
68 return helpers.InputError(e, to.StringPtr("InvalidPassword"))
69 }
70
71 if verr.Field == "InviteCode" {
72 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
73 }
74 }
75 }
76
77 var signupDid string
78 if request.Did != nil {
79 signupDid = *request.Did
80
81 token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1))
82 if token == "" {
83 return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did"))
84 }
85 authDid, err := s.validateServiceAuth(e.Request().Context(), token, "com.atproto.server.createAccount")
86
87 if err != nil {
88 logger.Warn("error validating authorization token", "endpoint", "com.atproto.server.createAccount", "error", err)
89 return helpers.UnauthorizedError(e, to.StringPtr("invalid authorization token"))
90 }
91
92 if authDid != signupDid {
93 return helpers.ForbiddenError(e, to.StringPtr("auth did did not match signup did"))
94 }
95 }
96
97 // see if the handle is already taken
98 actor, err := s.getActorByHandle(ctx, request.Handle)
99 if err != nil && err != gorm.ErrRecordNotFound {
100 logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err)
101 return helpers.ServerError(e, nil)
102 }
103 if err == nil && actor.Did != signupDid {
104 return helpers.InputError(e, to.StringPtr("HandleNotAvailable"))
105 }
106
107 if did, err := s.passport.ResolveHandle(e.Request().Context(), request.Handle); err == nil && did != signupDid {
108 return helpers.InputError(e, to.StringPtr("HandleNotAvailable"))
109 }
110
111 var ic models.InviteCode
112 if s.config.RequireInvite {
113 if strings.TrimSpace(request.InviteCode) == "" {
114 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
115 }
116
117 if err := s.db.Raw(ctx, "SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
118 if err == gorm.ErrRecordNotFound {
119 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
120 }
121 logger.Error("error getting invite code from db", "error", err)
122 return helpers.ServerError(e, nil)
123 }
124
125 if ic.RemainingUseCount < 1 {
126 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
127 }
128 }
129
130 // see if the email is already taken
131 existingRepo, err := s.getRepoByEmail(ctx, request.Email)
132 if err != nil && err != gorm.ErrRecordNotFound {
133 logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err)
134 return helpers.ServerError(e, nil)
135 }
136 if err == nil && existingRepo.Did != signupDid {
137 return helpers.InputError(e, to.StringPtr("EmailNotAvailable"))
138 }
139
140 // TODO: unsupported domains
141
142 var k *atcrypto.PrivateKeyK256
143
144 if signupDid != "" {
145 reservedKey, err := s.getReservedKey(ctx, signupDid)
146 if err != nil {
147 logger.Error("error looking up reserved key", "error", err)
148 }
149 if reservedKey != nil {
150 k, err = atcrypto.ParsePrivateBytesK256(reservedKey.PrivateKey)
151 if err != nil {
152 logger.Error("error parsing reserved key", "error", err)
153 k = nil
154 } else {
155 defer func() {
156 if delErr := s.deleteReservedKey(ctx, reservedKey.KeyDid, reservedKey.Did); delErr != nil {
157 logger.Error("error deleting reserved key", "error", delErr)
158 }
159 }()
160 }
161 }
162 }
163
164 if k == nil {
165 k, err = atcrypto.GeneratePrivateKeyK256()
166 if err != nil {
167 logger.Error("error creating signing key", "endpoint", "com.atproto.server.createAccount", "error", err)
168 return helpers.ServerError(e, nil)
169 }
170 }
171
172 if signupDid == "" {
173 did, op, err := s.plcClient.CreateDID(k, "", request.Handle)
174 if err != nil {
175 logger.Error("error creating operation", "endpoint", "com.atproto.server.createAccount", "error", err)
176 return helpers.ServerError(e, nil)
177 }
178
179 if err := s.plcClient.SendOperation(e.Request().Context(), did, op); err != nil {
180 logger.Error("error sending plc op", "endpoint", "com.atproto.server.createAccount", "error", err)
181 return helpers.ServerError(e, nil)
182 }
183 signupDid = did
184 }
185
186 hashed, err := bcrypt.GenerateFromPassword([]byte(request.Password), 10)
187 if err != nil {
188 logger.Error("error hashing password", "error", err)
189 return helpers.ServerError(e, nil)
190 }
191
192 urepo := models.Repo{
193 Did: signupDid,
194 CreatedAt: time.Now(),
195 Email: request.Email,
196 EmailVerificationCode: to.StringPtr(fmt.Sprintf("%s-%s", helpers.RandomVarchar(6), helpers.RandomVarchar(6))),
197 Password: string(hashed),
198 SigningKey: k.Bytes(),
199 }
200
201 if actor == nil {
202 actor = &models.Actor{
203 Did: signupDid,
204 Handle: request.Handle,
205 }
206
207 if err := s.db.Create(ctx, &urepo, nil).Error; err != nil {
208 logger.Error("error inserting new repo", "error", err)
209 return helpers.ServerError(e, nil)
210 }
211
212 if err := s.db.Create(ctx, &actor, nil).Error; err != nil {
213 logger.Error("error inserting new actor", "error", err)
214 return helpers.ServerError(e, nil)
215 }
216 } else {
217 if err := s.db.Save(ctx, &actor, nil).Error; err != nil {
218 logger.Error("error inserting new actor", "error", err)
219 return helpers.ServerError(e, nil)
220 }
221 }
222
223 if request.Did == nil || *request.Did == "" {
224 bs := s.getBlockstore(signupDid)
225
226 clk := syntax.NewTIDClock(0)
227 r := &atp.Repo{
228 DID: syntax.DID(signupDid),
229 Clock: clk,
230 MST: mst.NewEmptyTree(),
231 RecordStore: bs,
232 }
233
234 root, rev, err := commitRepo(context.TODO(), bs, r, urepo.SigningKey)
235 if err != nil {
236 logger.Error("error committing", "error", err)
237 return helpers.ServerError(e, nil)
238 }
239
240 if err := s.UpdateRepo(context.TODO(), urepo.Did, root, rev); err != nil {
241 logger.Error("error updating repo after commit", "error", err)
242 return helpers.ServerError(e, nil)
243 }
244
245 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
246 RepoIdentity: &atproto.SyncSubscribeRepos_Identity{
247 Did: urepo.Did,
248 Handle: to.StringPtr(request.Handle),
249 Seq: time.Now().UnixMicro(), // TODO: no
250 Time: time.Now().Format(util.ISO8601),
251 },
252 })
253 }
254
255 if s.config.RequireInvite {
256 if err := s.db.Raw(ctx, "UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
257 logger.Error("error decrementing use count", "error", err)
258 return helpers.ServerError(e, nil)
259 }
260 }
261
262 sess, err := s.createSession(ctx, &urepo)
263 if err != nil {
264 logger.Error("error creating new session", "error", err)
265 return helpers.ServerError(e, nil)
266 }
267
268 go func() {
269 if err := s.sendEmailVerification(urepo.Email, actor.Handle, *urepo.EmailVerificationCode); err != nil {
270 logger.Error("error sending email verification email", "error", err)
271 }
272 if err := s.sendWelcomeMail(urepo.Email, actor.Handle); err != nil {
273 logger.Error("error sending welcome email", "error", err)
274 }
275 }()
276
277 return e.JSON(200, ComAtprotoServerCreateAccountResponse{
278 AccessJwt: sess.AccessToken,
279 RefreshJwt: sess.RefreshToken,
280 Handle: request.Handle,
281 Did: signupDid,
282 })
283}