1package server
2
import (
	"context"
	"crypto/subtle"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/Azure/go-autorest/autorest/to"
	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/labstack/echo/v4"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)
18
// ComAtprotoServerCreateSessionRequest is the JSON body accepted by
// handleCreateSession (com.atproto.server.createSession).
type ComAtprotoServerCreateSessionRequest struct {
	// Identifier names the account to log in to. The handler lowercases it and
	// treats it as a DID if it parses as one, otherwise as a handle if it parses
	// as one, otherwise as an email address.
	Identifier string `json:"identifier" validate:"required"`
	// Password is compared against the account's stored bcrypt hash.
	Password string `json:"password" validate:"required"`
	// AuthFactorToken is the two-factor code previously sent to the account.
	// It is only consulted when the account has a two-factor type configured;
	// nil or "" is treated as "not supplied".
	AuthFactorToken *string `json:"authFactorToken,omitempty"`
}
24
// ComAtprotoServerCreateSessionResponse is the JSON body returned by
// handleCreateSession after a successful login.
type ComAtprotoServerCreateSessionResponse struct {
	// AccessJwt and RefreshJwt are the tokens of the newly created session.
	AccessJwt  string `json:"accessJwt"`
	RefreshJwt string `json:"refreshJwt"`
	Handle     string `json:"handle"`
	Did        string `json:"did"`
	Email      string `json:"email"`
	// EmailConfirmed is true when the repo has an email-confirmation timestamp.
	EmailConfirmed bool `json:"emailConfirmed"`
	// EmailAuthFactor is true whenever the account has any two-factor type
	// configured (i.e. its TwoFactorType is not TwoFactorTypeNone).
	EmailAuthFactor bool `json:"emailAuthFactor"`
	// Active and Status are taken from RepoActor.Active() / RepoActor.Status();
	// Status is omitted from the JSON when nil.
	Active bool    `json:"active"`
	Status *string `json:"status,omitempty"`
}
36
37func (s *Server) handleCreateSession(e echo.Context) error {
38 ctx := e.Request().Context()
39 logger := s.logger.With("name", "handleServerCreateSession")
40
41 var req ComAtprotoServerCreateSessionRequest
42 if err := e.Bind(&req); err != nil {
43 logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err)
44 return helpers.ServerError(e, nil)
45 }
46
47 if err := e.Validate(req); err != nil {
48 var verr ValidationError
49 if errors.As(err, &verr) {
50 if verr.Field == "Identifier" {
51 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
52 }
53
54 if verr.Field == "Password" {
55 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
56 }
57 }
58 }
59
60 req.Identifier = strings.ToLower(req.Identifier)
61 var idtype string
62 if _, err := syntax.ParseDID(req.Identifier); err == nil {
63 idtype = "did"
64 } else if _, err := syntax.ParseHandle(req.Identifier); err == nil {
65 idtype = "handle"
66 } else {
67 idtype = "email"
68 }
69
70 var repo models.RepoActor
71 var err error
72 switch idtype {
73 case "did":
74 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error
75 case "handle":
76 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error
77 case "email":
78 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error
79 }
80
81 if err != nil {
82 if err == gorm.ErrRecordNotFound {
83 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
84 }
85
86 logger.Error("erorr looking up repo", "endpoint", "com.atproto.server.createSession", "error", err)
87 return helpers.ServerError(e, nil)
88 }
89
90 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
91 if err != bcrypt.ErrMismatchedHashAndPassword {
92 logger.Error("erorr comparing hash and password", "error", err)
93 }
94 return helpers.InputError(e, to.StringPtr("InvalidRequest"))
95 }
96
97 // if repo requires 2FA token and one hasn't been provided, return error prompting for one
98 if repo.TwoFactorType != models.TwoFactorTypeNone && (req.AuthFactorToken == nil || *req.AuthFactorToken == "") {
99 err = s.createAndSendTwoFactorCode(ctx, repo)
100 if err != nil {
101 logger.Error("sending 2FA code", "error", err)
102 return helpers.ServerError(e, nil)
103 }
104
105 return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired"))
106 }
107
108 // if 2FA is required, now check that the one provided is valid
109 if repo.TwoFactorType != models.TwoFactorTypeNone {
110 if repo.TwoFactorCode == nil || repo.TwoFactorCodeExpiresAt == nil {
111 err = s.createAndSendTwoFactorCode(ctx, repo)
112 if err != nil {
113 logger.Error("sending 2FA code", "error", err)
114 return helpers.ServerError(e, nil)
115 }
116
117 return helpers.InputError(e, to.StringPtr("AuthFactorTokenRequired"))
118 }
119
120 if *repo.TwoFactorCode != *req.AuthFactorToken {
121 return helpers.InvalidTokenError(e)
122 }
123
124 if time.Now().UTC().After(*repo.TwoFactorCodeExpiresAt) {
125 return helpers.ExpiredTokenError(e)
126 }
127 }
128
129 sess, err := s.createSession(ctx, &repo.Repo)
130 if err != nil {
131 logger.Error("error creating session", "error", err)
132 return helpers.ServerError(e, nil)
133 }
134
135 return e.JSON(200, ComAtprotoServerCreateSessionResponse{
136 AccessJwt: sess.AccessToken,
137 RefreshJwt: sess.RefreshToken,
138 Handle: repo.Handle,
139 Did: repo.Repo.Did,
140 Email: repo.Email,
141 EmailConfirmed: repo.EmailConfirmedAt != nil,
142 EmailAuthFactor: repo.TwoFactorType != models.TwoFactorTypeNone,
143 Active: repo.Active(),
144 Status: repo.Status(),
145 })
146}
147
148func (s *Server) createAndSendTwoFactorCode(ctx context.Context, repo models.RepoActor) error {
149 // TODO: when implementing a new type of 2FA there should be some logic in here to send the
150 // right type of code
151
152 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
153 eat := time.Now().Add(10 * time.Minute).UTC()
154
155 if err := s.db.Exec(ctx, "UPDATE repos SET two_factor_code = ?, two_factor_code_expires_at = ? WHERE did = ?", nil, code, eat, repo.Repo.Did).Error; err != nil {
156 return fmt.Errorf("updating repo: %w", err)
157 }
158
159 if err := s.sendTwoFactorCode(repo.Email, repo.Handle, code); err != nil {
160 return fmt.Errorf("sending email: %w", err)
161 }
162
163 return nil
164}