1package server
2
import (
	"crypto/subtle"
	"errors"
	"fmt"
	"strings"
	"time"

	"github.com/bluesky-social/indigo/atproto/syntax"
	"github.com/gorilla/sessions"
	"github.com/haileyok/cocoon/internal/helpers"
	"github.com/haileyok/cocoon/models"
	"github.com/labstack/echo-contrib/session"
	"github.com/labstack/echo/v4"
	"golang.org/x/crypto/bcrypt"
	"gorm.io/gorm"
)
18
// OauthSigninInput is the form payload posted by the account sign-in page.
type OauthSigninInput struct {
	// Username identifies the account to sign in to. The sign-in handler
	// accepts a DID, a handle, or (as a fallback) an email address here.
	Username string `form:"username"`
	// Password is the plaintext password, checked against the stored bcrypt hash.
	Password string `form:"password"`
	// AuthFactorToken is the two-factor code, submitted under the "token"
	// form field; empty when the user has not been prompted for one yet.
	AuthFactorToken string `form:"token"`
	// QueryParams is an already-encoded query string (without the leading
	// "?") that is re-attached to redirects, e.g. to resume /oauth/authorize
	// after a successful sign-in. Presumably round-tripped from the sign-in
	// page's own query string — verify against the signin.html template.
	QueryParams string `form:"query_params"`
}
25
26func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) {
27 ctx := e.Request().Context()
28
29 sess, err := session.Get("session", e)
30 if err != nil {
31 return nil, nil, err
32 }
33
34 did, ok := sess.Values["did"].(string)
35 if !ok {
36 return nil, sess, errors.New("did was not set in session")
37 }
38
39 repo, err := s.getRepoActorByDid(ctx, did)
40 if err != nil {
41 return nil, sess, err
42 }
43
44 return repo, sess, nil
45}
46
47func getFlashesFromSession(e echo.Context, sess *sessions.Session) map[string]any {
48 defer sess.Save(e.Request(), e.Response())
49 return map[string]any{
50 "errors": sess.Flashes("error"),
51 "successes": sess.Flashes("success"),
52 "tokenrequired": sess.Flashes("tokenrequired"),
53 }
54}
55
56func (s *Server) handleAccountSigninGet(e echo.Context) error {
57 _, sess, err := s.getSessionRepoOrErr(e)
58 if err == nil {
59 return e.Redirect(303, "/account")
60 }
61
62 return e.Render(200, "signin.html", map[string]any{
63 "flashes": getFlashesFromSession(e, sess),
64 "QueryParams": e.QueryParams().Encode(),
65 })
66}
67
68func (s *Server) handleAccountSigninPost(e echo.Context) error {
69 ctx := e.Request().Context()
70 logger := s.logger.With("name", "handleAccountSigninPost")
71
72 var req OauthSigninInput
73 if err := e.Bind(&req); err != nil {
74 logger.Error("error binding sign in req", "error", err)
75 return helpers.ServerError(e, nil)
76 }
77
78 sess, _ := session.Get("session", e)
79
80 req.Username = strings.ToLower(req.Username)
81 var idtype string
82 if _, err := syntax.ParseDID(req.Username); err == nil {
83 idtype = "did"
84 } else if _, err := syntax.ParseHandle(req.Username); err == nil {
85 idtype = "handle"
86 } else {
87 idtype = "email"
88 }
89
90 queryParams := ""
91 if req.QueryParams != "" {
92 queryParams = fmt.Sprintf("?%s", req.QueryParams)
93 }
94
95 // TODO: we should make this a helper since we do it for the base create_session as well
96 var repo models.RepoActor
97 var err error
98 switch idtype {
99 case "did":
100 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error
101 case "handle":
102 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error
103 case "email":
104 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error
105 }
106 if err != nil {
107 if err == gorm.ErrRecordNotFound {
108 sess.AddFlash("Handle or password is incorrect", "error")
109 } else {
110 sess.AddFlash("Something went wrong!", "error")
111 }
112 sess.Save(e.Request(), e.Response())
113 return e.Redirect(303, "/account/signin"+queryParams)
114 }
115
116 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
117 if err != bcrypt.ErrMismatchedHashAndPassword {
118 sess.AddFlash("Handle or password is incorrect", "error")
119 } else {
120 sess.AddFlash("Something went wrong!", "error")
121 }
122 sess.Save(e.Request(), e.Response())
123 return e.Redirect(303, "/account/signin"+queryParams)
124 }
125
126 // if repo requires 2FA token and one hasn't been provided, return error prompting for one
127 if repo.TwoFactorType != models.TwoFactorTypeNone && req.AuthFactorToken == "" {
128 err = s.createAndSendTwoFactorCode(ctx, repo)
129 if err != nil {
130 sess.AddFlash("Something went wrong!", "error")
131 sess.Save(e.Request(), e.Response())
132 return e.Redirect(303, "/account/signin"+queryParams)
133 }
134
135 sess.AddFlash("requires 2FA token", "tokenrequired")
136 sess.Save(e.Request(), e.Response())
137 return e.Redirect(303, "/account/signin"+queryParams)
138 }
139
140 // if 2FAis required, now check that the one provided is valid
141 if repo.TwoFactorType != models.TwoFactorTypeNone {
142 if repo.TwoFactorCode == nil || repo.TwoFactorCodeExpiresAt == nil {
143 err = s.createAndSendTwoFactorCode(ctx, repo)
144 if err != nil {
145 sess.AddFlash("Something went wrong!", "error")
146 sess.Save(e.Request(), e.Response())
147 return e.Redirect(303, "/account/signin"+queryParams)
148 }
149
150 sess.AddFlash("requires 2FA token", "tokenrequired")
151 sess.Save(e.Request(), e.Response())
152 return e.Redirect(303, "/account/signin"+queryParams)
153 }
154
155 if *repo.TwoFactorCode != req.AuthFactorToken {
156 return helpers.InvalidTokenError(e)
157 }
158
159 if time.Now().UTC().After(*repo.TwoFactorCodeExpiresAt) {
160 return helpers.ExpiredTokenError(e)
161 }
162 }
163
164 sess.Options = &sessions.Options{
165 Path: "/",
166 MaxAge: int(AccountSessionMaxAge.Seconds()),
167 HttpOnly: true,
168 }
169
170 sess.Values = map[any]any{}
171 sess.Values["did"] = repo.Repo.Did
172
173 if err := sess.Save(e.Request(), e.Response()); err != nil {
174 return err
175 }
176
177 if queryParams != "" {
178 return e.Redirect(303, "/oauth/authorize"+queryParams)
179 } else {
180 return e.Redirect(303, "/account")
181 }
182}