1package server
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "strings"
8 "time"
9
10 "github.com/Azure/go-autorest/autorest/to"
11 "github.com/bluesky-social/indigo/api/atproto"
12 "github.com/bluesky-social/indigo/atproto/atcrypto"
13 "github.com/bluesky-social/indigo/events"
14 "github.com/bluesky-social/indigo/repo"
15 "github.com/bluesky-social/indigo/util"
16 "github.com/haileyok/cocoon/internal/helpers"
17 "github.com/haileyok/cocoon/models"
18 "github.com/labstack/echo/v4"
19 "golang.org/x/crypto/bcrypt"
20 "gorm.io/gorm"
21)
22
// ComAtprotoServerCreateAccountRequest is the JSON body accepted by
// com.atproto.server.createAccount.
//
// Did is optional. When present, the account is created for that existing DID
// and the handler additionally requires matching service auth. The validate
// tag uses `omitempty` so that an absent did (nil pointer) is not reported as
// an `atproto-did` failure. Such a spurious failure could otherwise be the one
// error surfaced and hide failures on other fields (e.g. Password). A present
// value is still checked by `atproto-did`.
type ComAtprotoServerCreateAccountRequest struct {
	Email      string  `json:"email" validate:"required,email"`
	Handle     string  `json:"handle" validate:"required,atproto-handle"`
	Did        *string `json:"did" validate:"omitempty,atproto-did"`
	Password   string  `json:"password" validate:"required"`
	InviteCode string  `json:"inviteCode" validate:"omitempty"`
}
30
// ComAtprotoServerCreateAccountResponse is the JSON body returned by a
// successful com.atproto.server.createAccount call. It carries the tokens of
// the session opened for the new account together with the account's handle
// and DID.
type ComAtprotoServerCreateAccountResponse struct {
	AccessJwt  string `json:"accessJwt"`  // access token of the new session
	RefreshJwt string `json:"refreshJwt"` // refresh token of the new session
	Handle     string `json:"handle"`     // handle the account was created with
	Did        string `json:"did"`        // DID the account was created for
}
37
38func (s *Server) handleCreateAccount(e echo.Context) error {
39 ctx := e.Request().Context()
40 logger := s.logger.With("name", "handleServerCreateAccount")
41
42 var request ComAtprotoServerCreateAccountRequest
43
44 if err := e.Bind(&request); err != nil {
45 logger.Error("error receiving request", "endpoint", "com.atproto.server.createAccount", "error", err)
46 return helpers.ServerError(e, nil)
47 }
48
49 request.Handle = strings.ToLower(request.Handle)
50
51 if err := e.Validate(request); err != nil {
52 logger.Error("error validating request", "endpoint", "com.atproto.server.createAccount", "error", err)
53
54 var verr ValidationError
55 if errors.As(err, &verr) {
56 if verr.Field == "Email" {
57 // TODO: what is this supposed to be? `InvalidEmail` isn't listed in doc
58 return helpers.InputError(e, to.StringPtr("InvalidEmail"))
59 }
60
61 if verr.Field == "Handle" {
62 return helpers.InputError(e, to.StringPtr("InvalidHandle"))
63 }
64
65 if verr.Field == "Password" {
66 return helpers.InputError(e, to.StringPtr("InvalidPassword"))
67 }
68
69 if verr.Field == "InviteCode" {
70 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
71 }
72 }
73 }
74
75 var signupDid string
76 if request.Did != nil {
77 signupDid = *request.Did
78
79 token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1))
80 if token == "" {
81 return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did"))
82 }
83 authDid, err := s.validateServiceAuth(e.Request().Context(), token, "com.atproto.server.createAccount")
84
85 if err != nil {
86 logger.Warn("error validating authorization token", "endpoint", "com.atproto.server.createAccount", "error", err)
87 return helpers.UnauthorizedError(e, to.StringPtr("invalid authorization token"))
88 }
89
90 if authDid != signupDid {
91 return helpers.ForbiddenError(e, to.StringPtr("auth did did not match signup did"))
92 }
93 }
94
95 // see if the handle is already taken
96 actor, err := s.getActorByHandle(ctx, request.Handle)
97 if err != nil && err != gorm.ErrRecordNotFound {
98 logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err)
99 return helpers.ServerError(e, nil)
100 }
101 if err == nil && actor.Did != signupDid {
102 return helpers.InputError(e, to.StringPtr("HandleNotAvailable"))
103 }
104
105 if did, err := s.passport.ResolveHandle(e.Request().Context(), request.Handle); err == nil && did != signupDid {
106 return helpers.InputError(e, to.StringPtr("HandleNotAvailable"))
107 }
108
109 var ic models.InviteCode
110 if s.config.RequireInvite {
111 if strings.TrimSpace(request.InviteCode) == "" {
112 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
113 }
114
115 if err := s.db.Raw(ctx, "SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
116 if err == gorm.ErrRecordNotFound {
117 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
118 }
119 logger.Error("error getting invite code from db", "error", err)
120 return helpers.ServerError(e, nil)
121 }
122
123 if ic.RemainingUseCount < 1 {
124 return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
125 }
126 }
127
128 // see if the email is already taken
129 existingRepo, err := s.getRepoByEmail(ctx, request.Email)
130 if err != nil && err != gorm.ErrRecordNotFound {
131 logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err)
132 return helpers.ServerError(e, nil)
133 }
134 if err == nil && existingRepo.Did != signupDid {
135 return helpers.InputError(e, to.StringPtr("EmailNotAvailable"))
136 }
137
138 // TODO: unsupported domains
139
140 var k *atcrypto.PrivateKeyK256
141
142 if signupDid != "" {
143 reservedKey, err := s.getReservedKey(ctx, signupDid)
144 if err != nil {
145 logger.Error("error looking up reserved key", "error", err)
146 }
147 if reservedKey != nil {
148 k, err = atcrypto.ParsePrivateBytesK256(reservedKey.PrivateKey)
149 if err != nil {
150 logger.Error("error parsing reserved key", "error", err)
151 k = nil
152 } else {
153 defer func() {
154 if delErr := s.deleteReservedKey(ctx, reservedKey.KeyDid, reservedKey.Did); delErr != nil {
155 logger.Error("error deleting reserved key", "error", delErr)
156 }
157 }()
158 }
159 }
160 }
161
162 if k == nil {
163 k, err = atcrypto.GeneratePrivateKeyK256()
164 if err != nil {
165 logger.Error("error creating signing key", "endpoint", "com.atproto.server.createAccount", "error", err)
166 return helpers.ServerError(e, nil)
167 }
168 }
169
170 if signupDid == "" {
171 did, op, err := s.plcClient.CreateDID(k, "", request.Handle)
172 if err != nil {
173 logger.Error("error creating operation", "endpoint", "com.atproto.server.createAccount", "error", err)
174 return helpers.ServerError(e, nil)
175 }
176
177 if err := s.plcClient.SendOperation(e.Request().Context(), did, op); err != nil {
178 logger.Error("error sending plc op", "endpoint", "com.atproto.server.createAccount", "error", err)
179 return helpers.ServerError(e, nil)
180 }
181 signupDid = did
182 }
183
184 hashed, err := bcrypt.GenerateFromPassword([]byte(request.Password), 10)
185 if err != nil {
186 logger.Error("error hashing password", "error", err)
187 return helpers.ServerError(e, nil)
188 }
189
190 urepo := models.Repo{
191 Did: signupDid,
192 CreatedAt: time.Now(),
193 Email: request.Email,
194 EmailVerificationCode: to.StringPtr(fmt.Sprintf("%s-%s", helpers.RandomVarchar(6), helpers.RandomVarchar(6))),
195 Password: string(hashed),
196 SigningKey: k.Bytes(),
197 }
198
199 if actor == nil {
200 actor = &models.Actor{
201 Did: signupDid,
202 Handle: request.Handle,
203 }
204
205 if err := s.db.Create(ctx, &urepo, nil).Error; err != nil {
206 logger.Error("error inserting new repo", "error", err)
207 return helpers.ServerError(e, nil)
208 }
209
210 if err := s.db.Create(ctx, &actor, nil).Error; err != nil {
211 logger.Error("error inserting new actor", "error", err)
212 return helpers.ServerError(e, nil)
213 }
214 } else {
215 if err := s.db.Save(ctx, &actor, nil).Error; err != nil {
216 logger.Error("error inserting new actor", "error", err)
217 return helpers.ServerError(e, nil)
218 }
219 }
220
221 if request.Did == nil || *request.Did == "" {
222 bs := s.getBlockstore(signupDid)
223 r := repo.NewRepo(context.TODO(), signupDid, bs)
224
225 root, rev, err := r.Commit(context.TODO(), urepo.SignFor)
226 if err != nil {
227 logger.Error("error committing", "error", err)
228 return helpers.ServerError(e, nil)
229 }
230
231 if err := s.UpdateRepo(context.TODO(), urepo.Did, root, rev); err != nil {
232 logger.Error("error updating repo after commit", "error", err)
233 return helpers.ServerError(e, nil)
234 }
235
236 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{
237 RepoIdentity: &atproto.SyncSubscribeRepos_Identity{
238 Did: urepo.Did,
239 Handle: to.StringPtr(request.Handle),
240 Seq: time.Now().UnixMicro(), // TODO: no
241 Time: time.Now().Format(util.ISO8601),
242 },
243 })
244 }
245
246 if s.config.RequireInvite {
247 if err := s.db.Raw(ctx, "UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
248 logger.Error("error decrementing use count", "error", err)
249 return helpers.ServerError(e, nil)
250 }
251 }
252
253 sess, err := s.createSession(ctx, &urepo)
254 if err != nil {
255 logger.Error("error creating new session", "error", err)
256 return helpers.ServerError(e, nil)
257 }
258
259 go func() {
260 if err := s.sendEmailVerification(urepo.Email, actor.Handle, *urepo.EmailVerificationCode); err != nil {
261 logger.Error("error sending email verification email", "error", err)
262 }
263 if err := s.sendWelcomeMail(urepo.Email, actor.Handle); err != nil {
264 logger.Error("error sending welcome email", "error", err)
265 }
266 }()
267
268 return e.JSON(200, ComAtprotoServerCreateAccountResponse{
269 AccessJwt: sess.AccessToken,
270 RefreshJwt: sess.RefreshToken,
271 Handle: request.Handle,
272 Did: signupDid,
273 })
274}