forked from hailey.at/cocoon
An atproto PDS written in Go
1package server 2 3import ( 4 "context" 5 "errors" 6 "fmt" 7 "strings" 8 "time" 9 10 "github.com/Azure/go-autorest/autorest/to" 11 "github.com/bluesky-social/indigo/atproto/syntax" 12 "github.com/haileyok/cocoon/internal/helpers" 13 "github.com/haileyok/cocoon/models" 14 "github.com/labstack/echo/v4" 15 "golang.org/x/crypto/bcrypt" 16 "gorm.io/gorm" 17) 18 19type ComAtprotoServerCreateSessionRequest struct { 20 Identifier string `json:"identifier" validate:"required"` 21 Password string `json:"password" validate:"required"` 22 AuthFactorToken *string `json:"authFactorToken,omitempty"` 23} 24 25type ComAtprotoServerCreateSessionResponse struct { 26 AccessJwt string `json:"accessJwt"` 27 RefreshJwt string `json:"refreshJwt"` 28 Handle string `json:"handle"` 29 Did string `json:"did"` 30 Email string `json:"email"` 31 EmailConfirmed bool `json:"emailConfirmed"` 32 EmailAuthFactor bool `json:"emailAuthFactor"` 33 Active bool `json:"active"` 34 Status *string `json:"status,omitempty"` 35} 36 37func (s *Server) handleCreateSession(e echo.Context) error { 38 ctx := e.Request().Context() 39 logger := s.logger.With("name", "handleServerCreateSession") 40 41 var req ComAtprotoServerCreateSessionRequest 42 if err := e.Bind(&req); err != nil { 43 logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err) 44 return helpers.ServerError(e, nil) 45 } 46 47 if err := e.Validate(req); err != nil { 48 var verr ValidationError 49 if errors.As(err, &verr) { 50 if verr.Field == "Identifier" { 51 return helpers.InputError(e, to.StringPtr("InvalidRequest")) 52 } 53 54 if verr.Field == "Password" { 55 return helpers.InputError(e, to.StringPtr("InvalidRequest")) 56 } 57 } 58 } 59 60 req.Identifier = strings.ToLower(req.Identifier) 61 var idtype string 62 if _, err := syntax.ParseDID(req.Identifier); err == nil { 63 idtype = "did" 64 } else if _, err := syntax.ParseHandle(req.Identifier); err == nil { 65 idtype = "handle" 66 } else { 67 idtype = "email" 68 } 69 70 var repo models.RepoActor 71 var err error 72 switch idtype { 73 case "did": 74 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error 75 case "handle": 76 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error 77 case "email": 78 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error 79 } 80 81 if err != nil { 82 if err == gorm.ErrRecordNotFound { 83 return helpers.InputError(e, to.StringPtr("InvalidRequest")) 84 } 85 86 logger.Error("erorr looking up repo", "endpoint", "com.atproto.server.createSession", "error", err) 87 return helpers.ServerError(e, nil) 88 } 89 90 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil { 91 if err != bcrypt.ErrMismatchedHashAndPassword { 92 logger.Error("erorr comparing hash and password", "error", err) 93 } 94 return helpers.InputError(e, to.StringPtr("InvalidRequest")) 95 } 96 97 // if repo requires 2FA token and one hasn't been provided, return error prompting for one 98 if repo.TwoFactorType != models.TwoFactorTypeNone && (req.AuthFactorToken == nil || *req.AuthFactorToken == "") { 99 err = s.createAndSendTwoFactorCode(ctx, repo) 100 if err != nil { 101 logger.Error("sending 2FA code", "error", err) 102 return helpers.ServerError(e, nil) 103 } 104 105 return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired")) 106 } 107 108 // if 2FA is required, now check that the one provided is valid 109 if repo.TwoFactorType != models.TwoFactorTypeNone { 110 if repo.TwoFactorCode == nil || repo.TwoFactorCodeExpiresAt == nil { 111 err = s.createAndSendTwoFactorCode(ctx, repo) 112 if err != nil { 113 logger.Error("sending 2FA code", "error", err) 114 return helpers.ServerError(e, nil) 115 } 116 117 return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired")) 118 } 119 120 if *repo.TwoFactorCode != *req.AuthFactorToken { 121 return helpers.InvalidTokenError(e) 122 } 123 124 if time.Now().UTC().After(*repo.TwoFactorCodeExpiresAt) { 125 return helpers.ExpiredTokenError(e) 126 } 127 } 128 129 sess, err := s.createSession(ctx, &repo.Repo) 130 if err != nil { 131 logger.Error("error creating session", "error", err) 132 return helpers.ServerError(e, nil) 133 } 134 135 return e.JSON(200, ComAtprotoServerCreateSessionResponse{ 136 AccessJwt: sess.AccessToken, 137 RefreshJwt: sess.RefreshToken, 138 Handle: repo.Handle, 139 Did: repo.Repo.Did, 140 Email: repo.Email, 141 EmailConfirmed: repo.EmailConfirmedAt != nil, 142 EmailAuthFactor: repo.TwoFactorType != models.TwoFactorTypeNone, 143 Active: repo.Active(), 144 Status: repo.Status(), 145 }) 146} 147 148func (s *Server) createAndSendTwoFactorCode(ctx context.Context, repo models.RepoActor) error { 149 // TODO: when implementing a new type of 2FA there should be some logic in here to send the 150 // right type of code 151 152 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 153 eat := time.Now().Add(10 * time.Minute).UTC() 154 155 if err := s.db.Exec(ctx, "UPDATE repos SET two_factor_code = ?, two_factor_code_expires_at = ? WHERE did = ?", nil, code, eat, repo.Repo.Did).Error; err != nil { 156 return fmt.Errorf("updating repo: %w", err) 157 } 158 159 if err := s.sendTwoFactorCode(repo.Email, repo.Handle, code); err != nil { 160 return fmt.Errorf("sending email: %w", err) 161 } 162 163 return nil 164}