forked from hailey.at/cocoon
An atproto PDS written in Go
at main 8.9 kB view raw
1package server 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "strings" 8 "time" 9 10 "github.com/Azure/go-autorest/autorest/to" 11 "github.com/bluesky-social/indigo/api/atproto" 12 "github.com/bluesky-social/indigo/atproto/atcrypto" 13 "github.com/bluesky-social/indigo/events" 14 "github.com/bluesky-social/indigo/repo" 15 "github.com/bluesky-social/indigo/util" 16 "github.com/haileyok/cocoon/internal/helpers" 17 "github.com/haileyok/cocoon/models" 18 "github.com/labstack/echo/v4" 19 "golang.org/x/crypto/bcrypt" 20 "gorm.io/gorm" 21) 22 23type ComAtprotoServerCreateAccountRequest struct { 24 Email string `json:"email" validate:"required,email"` 25 Handle string `json:"handle" validate:"required,atproto-handle"` 26 Did *string `json:"did" validate:"atproto-did"` 27 Password string `json:"password" validate:"required"` 28 InviteCode string `json:"inviteCode" validate:"omitempty"` 29} 30 31type ComAtprotoServerCreateAccountResponse struct { 32 AccessJwt string `json:"accessJwt"` 33 RefreshJwt string `json:"refreshJwt"` 34 Handle string `json:"handle"` 35 Did string `json:"did"` 36} 37 38func (s *Server) handleCreateAccount(e echo.Context) error { 39 ctx := e.Request().Context() 40 logger := s.logger.With("name", "handleServerCreateAccount") 41 42 var request ComAtprotoServerCreateAccountRequest 43 44 if err := e.Bind(&request); err != nil { 45 logger.Error("error receiving request", "endpoint", "com.atproto.server.createAccount", "error", err) 46 return helpers.ServerError(e, nil) 47 } 48 49 request.Handle = strings.ToLower(request.Handle) 50 51 if err := e.Validate(request); err != nil { 52 logger.Error("error validating request", "endpoint", "com.atproto.server.createAccount", "error", err) 53 54 var verr ValidationError 55 if errors.As(err, &verr) { 56 if verr.Field == "Email" { 57 // TODO: what is this supposed to be? `InvalidEmail` isn't listed in doc 58 return helpers.InputError(e, to.StringPtr("InvalidEmail")) 59 } 60 61 if verr.Field == "Handle" { 62 return helpers.InputError(e, to.StringPtr("InvalidHandle")) 63 } 64 65 if verr.Field == "Password" { 66 return helpers.InputError(e, to.StringPtr("InvalidPassword")) 67 } 68 69 if verr.Field == "InviteCode" { 70 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 71 } 72 } 73 } 74 75 var signupDid string 76 if request.Did != nil { 77 signupDid = *request.Did 78 79 token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1)) 80 if token == "" { 81 return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did")) 82 } 83 authDid, err := s.validateServiceAuth(e.Request().Context(), token, "com.atproto.server.createAccount") 84 85 if err != nil { 86 logger.Warn("error validating authorization token", "endpoint", "com.atproto.server.createAccount", "error", err) 87 return helpers.UnauthorizedError(e, to.StringPtr("invalid authorization token")) 88 } 89 90 if authDid != signupDid { 91 return helpers.ForbiddenError(e, to.StringPtr("auth did did not match signup did")) 92 } 93 } 94 95 // see if the handle is already taken 96 actor, err := s.getActorByHandle(ctx, request.Handle) 97 if err != nil && err != gorm.ErrRecordNotFound { 98 logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err) 99 return helpers.ServerError(e, nil) 100 } 101 if err == nil && actor.Did != signupDid { 102 return helpers.InputError(e, to.StringPtr("HandleNotAvailable")) 103 } 104 105 if did, err := s.passport.ResolveHandle(e.Request().Context(), request.Handle); err == nil && did != signupDid { 106 return helpers.InputError(e, to.StringPtr("HandleNotAvailable")) 107 } 108 109 var ic models.InviteCode 110 if s.config.RequireInvite { 111 if strings.TrimSpace(request.InviteCode) == "" { 112 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 113 } 114 115 if err := s.db.Raw(ctx, "SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 116 if err == gorm.ErrRecordNotFound { 117 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 118 } 119 logger.Error("error getting invite code from db", "error", err) 120 return helpers.ServerError(e, nil) 121 } 122 123 if ic.RemainingUseCount < 1 { 124 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 125 } 126 } 127 128 // see if the email is already taken 129 existingRepo, err := s.getRepoByEmail(ctx, request.Email) 130 if err != nil && err != gorm.ErrRecordNotFound { 131 logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err) 132 return helpers.ServerError(e, nil) 133 } 134 if err == nil && existingRepo.Did != signupDid { 135 return helpers.InputError(e, to.StringPtr("EmailNotAvailable")) 136 } 137 138 // TODO: unsupported domains 139 140 var k *atcrypto.PrivateKeyK256 141 142 if signupDid != "" { 143 reservedKey, err := s.getReservedKey(ctx, signupDid) 144 if err != nil { 145 logger.Error("error looking up reserved key", "error", err) 146 } 147 if reservedKey != nil { 148 k, err = atcrypto.ParsePrivateBytesK256(reservedKey.PrivateKey) 149 if err != nil { 150 logger.Error("error parsing reserved key", "error", err) 151 k = nil 152 } else { 153 defer func() { 154 if delErr := s.deleteReservedKey(ctx, reservedKey.KeyDid, reservedKey.Did); delErr != nil { 155 logger.Error("error deleting reserved key", "error", delErr) 156 } 157 }() 158 } 159 } 160 } 161 162 if k == nil { 163 k, err = atcrypto.GeneratePrivateKeyK256() 164 if err != nil { 165 logger.Error("error creating signing key", "endpoint", "com.atproto.server.createAccount", "error", err) 166 return helpers.ServerError(e, nil) 167 } 168 } 169 170 if signupDid == "" { 171 did, op, err := s.plcClient.CreateDID(k, "", request.Handle) 172 if err != nil { 173 logger.Error("error creating operation", "endpoint", "com.atproto.server.createAccount", "error", err) 174 return helpers.ServerError(e, nil) 175 } 176 177 if err := s.plcClient.SendOperation(e.Request().Context(), did, op); err != nil { 178 logger.Error("error sending plc op", "endpoint", "com.atproto.server.createAccount", "error", err) 179 return helpers.ServerError(e, nil) 180 } 181 signupDid = did 182 } 183 184 hashed, err := bcrypt.GenerateFromPassword([]byte(request.Password), 10) 185 if err != nil { 186 logger.Error("error hashing password", "error", err) 187 return helpers.ServerError(e, nil) 188 } 189 190 urepo := models.Repo{ 191 Did: signupDid, 192 CreatedAt: time.Now(), 193 Email: request.Email, 194 EmailVerificationCode: to.StringPtr(fmt.Sprintf("%s-%s", helpers.RandomVarchar(6), helpers.RandomVarchar(6))), 195 Password: string(hashed), 196 SigningKey: k.Bytes(), 197 } 198 199 if actor == nil { 200 actor = &models.Actor{ 201 Did: signupDid, 202 Handle: request.Handle, 203 } 204 205 if err := s.db.Create(ctx, &urepo, nil).Error; err != nil { 206 logger.Error("error inserting new repo", "error", err) 207 return helpers.ServerError(e, nil) 208 } 209 210 if err := s.db.Create(ctx, &actor, nil).Error; err != nil { 211 logger.Error("error inserting new actor", "error", err) 212 return helpers.ServerError(e, nil) 213 } 214 } else { 215 if err := s.db.Save(ctx, &actor, nil).Error; err != nil { 216 logger.Error("error inserting new actor", "error", err) 217 return helpers.ServerError(e, nil) 218 } 219 } 220 221 if request.Did == nil || *request.Did == "" { 222 bs := s.getBlockstore(signupDid) 223 r := repo.NewRepo(context.TODO(), signupDid, bs) 224 225 root, rev, err := r.Commit(context.TODO(), urepo.SignFor) 226 if err != nil { 227 logger.Error("error committing", "error", err) 228 return helpers.ServerError(e, nil) 229 } 230 231 if err := s.UpdateRepo(context.TODO(), urepo.Did, root, rev); err != nil { 232 logger.Error("error updating repo after commit", "error", err) 233 return helpers.ServerError(e, nil) 234 } 235 236 s.evtman.AddEvent(context.TODO(), &events.XRPCStreamEvent{ 237 RepoIdentity: &atproto.SyncSubscribeRepos_Identity{ 238 Did: urepo.Did, 239 Handle: to.StringPtr(request.Handle), 240 Seq: time.Now().UnixMicro(), // TODO: no 241 Time: time.Now().Format(util.ISO8601), 242 }, 243 }) 244 } 245 246 if s.config.RequireInvite { 247 if err := s.db.Raw(ctx, "UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 248 logger.Error("error decrementing use count", "error", err) 249 return helpers.ServerError(e, nil) 250 } 251 } 252 253 sess, err := s.createSession(ctx, &urepo) 254 if err != nil { 255 logger.Error("error creating new session", "error", err) 256 return helpers.ServerError(e, nil) 257 } 258 259 go func() { 260 if err := s.sendEmailVerification(urepo.Email, actor.Handle, *urepo.EmailVerificationCode); err != nil { 261 logger.Error("error sending email verification email", "error", err) 262 } 263 if err := s.sendWelcomeMail(urepo.Email, actor.Handle); err != nil { 264 logger.Error("error sending welcome email", "error", err) 265 } 266 }() 267 268 return e.JSON(200, ComAtprotoServerCreateAccountResponse{ 269 AccessJwt: sess.AccessToken, 270 RefreshJwt: sess.RefreshToken, 271 Handle: request.Handle, 272 Did: signupDid, 273 }) 274}