···11package server
2233import (
44+ "context"
55+46 "github.com/haileyok/cocoon/models"
57)
6877-func (s *Server) getActorByHandle(handle string) (*models.Actor, error) {
99+func (s *Server) getActorByHandle(ctx context.Context, handle string) (*models.Actor, error) {
810 var actor models.Actor
99- if err := s.db.First(&actor, models.Actor{Handle: handle}).Error; err != nil {
1111+ if err := s.db.First(ctx, &actor, models.Actor{Handle: handle}).Error; err != nil {
1012 return nil, err
1113 }
1214 return &actor, nil
1315}
14161515-func (s *Server) getRepoByEmail(email string) (*models.Repo, error) {
1717+func (s *Server) getRepoByEmail(ctx context.Context, email string) (*models.Repo, error) {
1618 var repo models.Repo
1717- if err := s.db.First(&repo, models.Repo{Email: email}).Error; err != nil {
1919+ if err := s.db.First(ctx, &repo, models.Repo{Email: email}).Error; err != nil {
1820 return nil, err
1921 }
2022 return &repo, nil
2123}
22242323-func (s *Server) getRepoActorByEmail(email string) (*models.RepoActor, error) {
2525+func (s *Server) getRepoActorByEmail(ctx context.Context, email string) (*models.RepoActor, error) {
2426 var repo models.RepoActor
2525- if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil {
2727+ if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil {
2628 return nil, err
2729 }
2830 return &repo, nil
2931}
30323131-func (s *Server) getRepoActorByDid(did string) (*models.RepoActor, error) {
3333+func (s *Server) getRepoActorByDid(ctx context.Context, did string) (*models.RepoActor, error) {
3234 var repo models.RepoActor
3333- if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil {
3535+ if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil {
3436 return nil, err
3537 }
3638 return &repo, nil
+4-2
server/handle_account.go
···12121313func (s *Server) handleAccount(e echo.Context) error {
1414 ctx := e.Request().Context()
1515+ logger := s.logger.With("name", "handleAuth")
1616+1517 repo, sess, err := s.getSessionRepoOrErr(e)
1618 if err != nil {
1719 return e.Redirect(303, "/account/signin")
···2022 oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime)
21232224 var tokens []provider.OauthToken
2323- if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil {
2424- s.logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err)
2525+ if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil {
2626+ logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err)
2527 sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error")
2628 sess.Save(e.Request(), e.Response())
2729 return e.Render(200, "account.html", map[string]any{
+8-5
server/handle_account_revoke.go
···55 "github.com/labstack/echo/v4"
66)
7788-type AccountRevokeRequest struct {
88+type AccountRevokeInput struct {
99 Token string `form:"token"`
1010}
11111212func (s *Server) handleAccountRevoke(e echo.Context) error {
1313- var req AccountRevokeRequest
1313+ ctx := e.Request().Context()
1414+ logger := s.logger.With("name", "handleAcocuntRevoke")
1515+1616+ var req AccountRevokeInput
1417 if err := e.Bind(&req); err != nil {
1515- s.logger.Error("could not bind account revoke request", "error", err)
1818+ logger.Error("could not bind account revoke request", "error", err)
1619 return helpers.ServerError(e, nil)
1720 }
1821···2124 return e.Redirect(303, "/account/signin")
2225 }
23262424- if err := s.db.Exec("DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil {
2525- s.logger.Error("couldnt delete oauth session for account", "did", repo.Repo.Did, "token", req.Token, "error", err)
2727+ if err := s.db.Exec(ctx, "DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil {
2828+ logger.Error("couldnt delete oauth session for account", "did", repo.Repo.Did, "token", req.Token, "error", err)
2629 sess.AddFlash("Unable to revoke session. See server logs for more details.", "error")
2730 sess.Save(e.Request(), e.Response())
2831 return e.Redirect(303, "/account")
+68-16
server/handle_account_signin.go
···2233import (
44 "errors"
55+ "fmt"
56 "strings"
77+ "time"
6879 "github.com/bluesky-social/indigo/atproto/syntax"
810 "github.com/gorilla/sessions"
···1416 "gorm.io/gorm"
1517)
16181717-type OauthSigninRequest struct {
1818- Username string `form:"username"`
1919- Password string `form:"password"`
2020- QueryParams string `form:"query_params"`
1919+type OauthSigninInput struct {
2020+ Username string `form:"username"`
2121+ Password string `form:"password"`
2222+ AuthFactorToken string `form:"token"`
2323+ QueryParams string `form:"query_params"`
2124}
22252326func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) {
2727+ ctx := e.Request().Context()
2828+2429 sess, err := session.Get("session", e)
2530 if err != nil {
2631 return nil, nil, err
···3136 return nil, sess, errors.New("did was not set in session")
3237 }
33383434- repo, err := s.getRepoActorByDid(did)
3939+ repo, err := s.getRepoActorByDid(ctx, did)
3540 if err != nil {
3641 return nil, sess, err
3742 }
···4247func getFlashesFromSession(e echo.Context, sess *sessions.Session) map[string]any {
4348 defer sess.Save(e.Request(), e.Response())
4449 return map[string]any{
4545- "errors": sess.Flashes("error"),
4646- "successes": sess.Flashes("success"),
5050+ "errors": sess.Flashes("error"),
5151+ "successes": sess.Flashes("success"),
5252+ "tokenrequired": sess.Flashes("tokenrequired"),
4753 }
4854}
4955···6066}
61676268func (s *Server) handleAccountSigninPost(e echo.Context) error {
6363- var req OauthSigninRequest
6969+ ctx := e.Request().Context()
7070+ logger := s.logger.With("name", "handleAccountSigninPost")
7171+7272+ var req OauthSigninInput
6473 if err := e.Bind(&req); err != nil {
6565- s.logger.Error("error binding sign in req", "error", err)
7474+ logger.Error("error binding sign in req", "error", err)
6675 return helpers.ServerError(e, nil)
6776 }
6877···7685 idtype = "handle"
7786 } else {
7887 idtype = "email"
8888+ }
8989+9090+ queryParams := ""
9191+ if req.QueryParams != "" {
9292+ queryParams = fmt.Sprintf("?%s", req.QueryParams)
7993 }
80948195 // TODO: we should make this a helper since we do it for the base create_session as well
···8397 var err error
8498 switch idtype {
8599 case "did":
8686- err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error
100100+ err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error
87101 case "handle":
8888- err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error
102102+ err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error
89103 case "email":
9090- err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error
104104+ err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error
91105 }
92106 if err != nil {
93107 if err == gorm.ErrRecordNotFound {
···96110 sess.AddFlash("Something went wrong!", "error")
97111 }
98112 sess.Save(e.Request(), e.Response())
9999- return e.Redirect(303, "/account/signin")
113113+ return e.Redirect(303, "/account/signin"+queryParams)
100114 }
101115102116 if err := bcrypt.CompareHashAndPassword([]byte(repo.Password), []byte(req.Password)); err != nil {
···106120 sess.AddFlash("Something went wrong!", "error")
107121 }
108122 sess.Save(e.Request(), e.Response())
109109- return e.Redirect(303, "/account/signin")
123123+ return e.Redirect(303, "/account/signin"+queryParams)
124124+ }
125125+126126+ // if repo requires 2FA token and one hasn't been provided, return error prompting for one
127127+ if repo.TwoFactorType != models.TwoFactorTypeNone && req.AuthFactorToken == "" {
128128+ err = s.createAndSendTwoFactorCode(ctx, repo)
129129+ if err != nil {
130130+ sess.AddFlash("Something went wrong!", "error")
131131+ sess.Save(e.Request(), e.Response())
132132+ return e.Redirect(303, "/account/signin"+queryParams)
133133+ }
134134+135135+ sess.AddFlash("requires 2FA token", "tokenrequired")
136136+ sess.Save(e.Request(), e.Response())
137137+ return e.Redirect(303, "/account/signin"+queryParams)
138138+ }
139139+140140+ // if 2FAis required, now check that the one provided is valid
141141+ if repo.TwoFactorType != models.TwoFactorTypeNone {
142142+ if repo.TwoFactorCode == nil || repo.TwoFactorCodeExpiresAt == nil {
143143+ err = s.createAndSendTwoFactorCode(ctx, repo)
144144+ if err != nil {
145145+ sess.AddFlash("Something went wrong!", "error")
146146+ sess.Save(e.Request(), e.Response())
147147+ return e.Redirect(303, "/account/signin"+queryParams)
148148+ }
149149+150150+ sess.AddFlash("requires 2FA token", "tokenrequired")
151151+ sess.Save(e.Request(), e.Response())
152152+ return e.Redirect(303, "/account/signin"+queryParams)
153153+ }
154154+155155+ if *repo.TwoFactorCode != req.AuthFactorToken {
156156+ return helpers.InvalidTokenError(e)
157157+ }
158158+159159+ if time.Now().UTC().After(*repo.TwoFactorCodeExpiresAt) {
160160+ return helpers.ExpiredTokenError(e)
161161+ }
110162 }
111163112164 sess.Options = &sessions.Options{
···122174 return err
123175 }
124176125125- if req.QueryParams != "" {
126126- return e.Redirect(303, "/oauth/authorize?"+req.QueryParams)
177177+ if queryParams != "" {
178178+ return e.Redirect(303, "/oauth/authorize"+queryParams)
127179 } else {
128180 return e.Redirect(303, "/account")
129181 }
+3-1
server/handle_actor_put_preferences.go
···1010// This is kinda lame. Not great to implement app.bsky in the pds, but alas
11111212func (s *Server) handleActorPutPreferences(e echo.Context) error {
1313+ ctx := e.Request().Context()
1414+1315 repo := e.Get("repo").(*models.RepoActor)
14161517 var prefs map[string]any
···2224 return err
2325 }
24262525- if err := s.db.Exec("UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil {
2727+ if err := s.db.Exec(ctx, "UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil {
2628 return err
2729 }
2830
···2020}
21212222func (s *Server) handleDescribeRepo(e echo.Context) error {
2323+ ctx := e.Request().Context()
2424+ logger := s.logger.With("name", "handleDescribeRepo")
2525+2326 did := e.QueryParam("repo")
2424- repo, err := s.getRepoActorByDid(did)
2727+ repo, err := s.getRepoActorByDid(ctx, did)
2528 if err != nil {
2629 if err == gorm.ErrRecordNotFound {
2730 return helpers.InputError(e, to.StringPtr("RepoNotFound"))
2831 }
29323030- s.logger.Error("error looking up repo", "error", err)
3333+ logger.Error("error looking up repo", "error", err)
3134 return helpers.ServerError(e, nil)
3235 }
3336···35383639 diddoc, err := s.passport.FetchDoc(e.Request().Context(), repo.Repo.Did)
3740 if err != nil {
3838- s.logger.Error("error fetching diddoc", "error", err)
4141+ logger.Error("error fetching diddoc", "error", err)
3942 return helpers.ServerError(e, nil)
4043 }
4144···6467 }
65686669 var records []models.Record
6767- if err := s.db.Raw("SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil {
6868- s.logger.Error("error getting collections", "error", err)
7070+ if err := s.db.Raw(ctx, "SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil {
7171+ logger.Error("error getting collections", "error", err)
6972 return helpers.ServerError(e, nil)
7073 }
7174
+3-1
server/handle_repo_get_record.go
···1414}
15151616func (s *Server) handleRepoGetRecord(e echo.Context) error {
1717+ ctx := e.Request().Context()
1818+1719 repo := e.QueryParam("repo")
1820 collection := e.QueryParam("collection")
1921 rkey := e.QueryParam("rkey")
···3234 }
33353436 var record models.Record
3535- if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil {
3737+ if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil {
3638 // TODO: handle error nicely
3739 return err
3840 }
+6-3
server/handle_repo_list_missing_blobs.go
···2222}
23232424func (s *Server) handleListMissingBlobs(e echo.Context) error {
2525+ ctx := e.Request().Context()
2626+ logger := s.logger.With("name", "handleListMissingBlos")
2727+2528 urepo := e.Get("repo").(*models.RepoActor)
26292730 limitStr := e.QueryParam("limit")
···3538 }
36393740 var records []models.Record
3838- if err := s.db.Raw("SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil {
3939- s.logger.Error("failed to get records for listMissingBlobs", "error", err)
4141+ if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil {
4242+ logger.Error("failed to get records for listMissingBlobs", "error", err)
4043 return helpers.ServerError(e, nil)
4144 }
4245···6972 }
70737174 var count int64
7272- if err := s.db.Raw("SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil {
7575+ if err := s.db.Raw(ctx, "SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil {
7376 continue
7477 }
7578
+7-4
server/handle_repo_list_records.go
···4646}
47474848func (s *Server) handleListRecords(e echo.Context) error {
4949+ ctx := e.Request().Context()
5050+ logger := s.logger.With("name", "handleListRecords")
5151+4952 var req ComAtprotoRepoListRecordsRequest
5053 if err := e.Bind(&req); err != nil {
5151- s.logger.Error("could not bind list records request", "error", err)
5454+ logger.Error("could not bind list records request", "error", err)
5255 return helpers.ServerError(e, nil)
5356 }
5457···78817982 did := req.Repo
8083 if _, err := syntax.ParseDID(did); err != nil {
8181- actor, err := s.getActorByHandle(req.Repo)
8484+ actor, err := s.getActorByHandle(ctx, req.Repo)
8285 if err != nil {
8386 return helpers.InputError(e, to.StringPtr("RepoNotFound"))
8487 }
···9396 params = append(params, limit)
94979598 var records []models.Record
9696- if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil {
9797- s.logger.Error("error getting records", "error", err)
9999+ if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil {
100100+ logger.Error("error getting records", "error", err)
98101 return helpers.ServerError(e, nil)
99102 }
100103
+3-1
server/handle_repo_list_repos.go
···21212222// TODO: paginate this bitch
2323func (s *Server) handleListRepos(e echo.Context) error {
2424+ ctx := e.Request().Context()
2525+2426 var repos []models.Repo
2525- if err := s.db.Raw("SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil {
2727+ if err := s.db.Raw(ctx, "SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil {
2628 return err
2729 }
2830