forked from hailey.at/cocoon
An atproto PDS written in Go
1package server 2 3import ( 4 "errors" 5 "fmt" 6 "strings" 7 "time" 8 9 "github.com/bluesky-social/indigo/atproto/syntax" 10 "github.com/gorilla/sessions" 11 "github.com/haileyok/cocoon/internal/helpers" 12 "github.com/haileyok/cocoon/models" 13 "github.com/labstack/echo-contrib/session" 14 "github.com/labstack/echo/v4" 15 "golang.org/x/crypto/bcrypt" 16 "gorm.io/gorm" 17) 18 19type OauthSigninInput struct { 20 Username string `form:"username"` 21 Password string `form:"password"` 22 AuthFactorToken string `form:"token"` 23 QueryParams string `form:"query_params"` 24} 25 26func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) { 27 ctx := e.Request().Context() 28 29 sess, err := session.Get("session", e) 30 if err != nil { 31 return nil, nil, err 32 } 33 34 did, ok := sess.Values["did"].(string) 35 if !ok { 36 return nil, sess, errors.New("did was not set in session") 37 } 38 39 repo, err := s.getRepoActorByDid(ctx, did) 40 if err != nil { 41 return nil, sess, err 42 } 43 44 return repo, sess, nil 45} 46 47func getFlashesFromSession(e echo.Context, sess *sessions.Session) map[string]any { 48 defer sess.Save(e.Request(), e.Response()) 49 return map[string]any{ 50 "errors": sess.Flashes("error"), 51 "successes": sess.Flashes("success"), 52 "tokenrequired": sess.Flashes("tokenrequired"), 53 } 54} 55 56func (s *Server) handleAccountSigninGet(e echo.Context) error { 57 _, sess, err := s.getSessionRepoOrErr(e) 58 if err == nil { 59 return e.Redirect(303, "/account") 60 } 61 62 return e.Render(200, "signin.html", map[string]any{ 63 "flashes": getFlashesFromSession(e, sess), 64 "QueryParams": e.QueryParams().Encode(), 65 }) 66} 67 68func (s *Server) handleAccountSigninPost(e echo.Context) error { 69 ctx := e.Request().Context() 70 logger := s.logger.With("name", "handleAccountSigninPost") 71 72 var req OauthSigninInput 73 if err := e.Bind(&req); err != nil { 74 logger.Error("error binding sign in req", "error", err) 75 return helpers.ServerError(e, nil) 76 } 77 78 sess, _ := session.Get("session", e) 79 80 req.Username = strings.ToLower(req.Username) 81 var idtype string 82 if _, err := syntax.ParseDID(req.Username); err == nil { 83 idtype = "did" 84 } else if _, err := syntax.ParseHandle(req.Username); err == nil { 85 idtype = "handle" 86 } else { 87 idtype = "email" 88 } 89 90 queryParams := "" 91 if req.QueryParams != "" { 92 queryParams = fmt.Sprintf("?%s", req.QueryParams) 93 } 94 95 // TODO: we should make this a helper since we do it for the base create_session as well 96 var repo models.RepoActor 97 var err error 98 switch idtype { 99 case "did": 100 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error 101 case "handle": 102 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error 103 case "email": 104 err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error 105 } 106 if err != nil { 107 if err == gorm.ErrRecordNotFound { 108 sess.AddFlash("Handle or password is incorrect", "error") 109 } else { 110 sess.AddFlash("Something went wrong!", "error") 111 } 112 sess.Save(e.Request(), e.Response()) 113 return e.Redirect(303, "/account/signin"+queryParams) 114 } 115 116 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil { 117 if err != bcrypt.ErrMismatchedHashAndPassword { 118 sess.AddFlash("Handle or password is incorrect", "error") 119 } else { 120 sess.AddFlash("Something went wrong!", "error") 121 } 122 sess.Save(e.Request(), e.Response()) 123 return e.Redirect(303, "/account/signin"+queryParams) 124 } 125 126 // if repo requires 2FA token and one hasn't been provided, return error prompting for one 127 if repo.TwoFactorType != models.TwoFactorTypeNone && req.AuthFactorToken == "" { 128 err = s.createAndSendTwoFactorCode(ctx, repo) 129 if err != nil { 130 sess.AddFlash("Something went wrong!", "error") 131 sess.Save(e.Request(), e.Response()) 132 return e.Redirect(303, "/account/signin"+queryParams) 133 } 134 135 sess.AddFlash("requires 2FA token", "tokenrequired") 136 sess.Save(e.Request(), e.Response()) 137 return e.Redirect(303, "/account/signin"+queryParams) 138 } 139 140 // if 2FAis required, now check that the one provided is valid 141 if repo.TwoFactorType != models.TwoFactorTypeNone { 142 if repo.TwoFactorCode == nil || repo.TwoFactorCodeExpiresAt == nil { 143 err = s.createAndSendTwoFactorCode(ctx, repo) 144 if err != nil { 145 sess.AddFlash("Something went wrong!", "error") 146 sess.Save(e.Request(), e.Response()) 147 return e.Redirect(303, "/account/signin"+queryParams) 148 } 149 150 sess.AddFlash("requires 2FA token", "tokenrequired") 151 sess.Save(e.Request(), e.Response()) 152 return e.Redirect(303, "/account/signin"+queryParams) 153 } 154 155 if *repo.TwoFactorCode != req.AuthFactorToken { 156 return helpers.InvalidTokenError(e) 157 } 158 159 if time.Now().UTC().After(*repo.TwoFactorCodeExpiresAt) { 160 return helpers.ExpiredTokenError(e) 161 } 162 } 163 164 sess.Options = &sessions.Options{ 165 Path: "/", 166 MaxAge: int(AccountSessionMaxAge.Seconds()), 167 HttpOnly: true, 168 } 169 170 sess.Values = map[any]any{} 171 sess.Values["did"] = repo.Repo.Did 172 173 if err := sess.Save(e.Request(), e.Response()); err != nil { 174 return err 175 } 176 177 if queryParams != "" { 178 return e.Redirect(303, "/oauth/authorize"+queryParams) 179 } else { 180 return e.Redirect(303, "/account") 181 } 182}