From 0d9fcc5a3825db9b681594efc1af38ea6032f03b Mon Sep 17 00:00:00 2001 From: Hailey Date: Tue, 30 Dec 2025 02:53:39 -0800 Subject: [PATCH] add withcontxt to all db calls --- internal/db/db.go | 29 +++++++++--------- server/common.go | 18 ++++++----- server/handle_account.go | 3 +- server/handle_account_revoke.go | 8 +++-- server/handle_account_signin.go | 16 ++++++---- server/handle_actor_put_preferences.go | 4 ++- .../handle_identity_request_plc_operation.go | 4 ++- server/handle_identity_sign_plc_operation.go | 2 +- server/handle_identity_update_handle.go | 2 +- server/handle_import_repo.go | 4 ++- server/handle_oauth_authorize.go | 10 +++++-- server/handle_oauth_par.go | 4 ++- server/handle_oauth_token.go | 12 ++++---- server/handle_repo_describe_repo.go | 6 ++-- server/handle_repo_get_record.go | 4 ++- server/handle_repo_list_missing_blobs.go | 6 ++-- server/handle_repo_list_records.go | 6 ++-- server/handle_repo_list_repos.go | 4 ++- server/handle_repo_upload_blob.go | 8 +++-- server/handle_server_activate_account.go | 4 ++- server/handle_server_check_account_status.go | 8 +++-- server/handle_server_confirm_email.go | 4 ++- server/handle_server_create_account.go | 30 ++++++++++--------- server/handle_server_create_invite_code.go | 4 ++- server/handle_server_create_invite_codes.go | 4 ++- server/handle_server_create_session.go | 10 ++++--- server/handle_server_deactivate_account.go | 4 ++- server/handle_server_delete_account.go | 6 ++-- server/handle_server_delete_session.go | 6 ++-- server/handle_server_refresh_session.go | 8 +++-- .../handle_server_request_account_delete.go | 4 ++- ...andle_server_request_email_confirmation.go | 4 ++- server/handle_server_request_email_update.go | 4 ++- .../handle_server_request_password_reset.go | 6 ++-- server/handle_server_reserve_signing_key.go | 19 +++++++----- server/handle_server_reset_password.go | 4 ++- server/handle_server_update_email.go | 4 ++- server/handle_sync_get_blob.go | 8 +++-- server/handle_sync_get_blocks.go | 2 +- server/handle_sync_get_latest_commit.go | 4 ++- server/handle_sync_get_record.go | 2 +- server/handle_sync_get_repo.go | 6 ++-- server/handle_sync_get_repo_status.go | 4 ++- server/handle_sync_list_blobs.go | 6 ++-- server/handle_well_known.go | 4 ++- server/middleware.go | 14 +++++---- server/repo.go | 22 +++++++------- server/server.go | 2 +- server/session.go | 7 +++-- sqlite_blockstore/sqlite_blockstore.go | 6 ++-- 50 files changed, 230 insertions(+), 140 deletions(-) diff --git a/internal/db/db.go b/internal/db/db.go index 73c3577..a431332 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,6 +1,7 @@ package db import ( + "context" "sync" "gorm.io/gorm" @@ -19,47 +20,47 @@ func NewDB(cli *gorm.DB) *DB { } } -func (db *DB) Create(value any, clauses []clause.Expression) *gorm.DB { +func (db *DB) Create(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { db.mu.Lock() defer db.mu.Unlock() - return db.cli.Clauses(clauses...).Create(value) + return db.cli.WithContext(ctx).Clauses(clauses...).Create(value) } -func (db *DB) Save(value any, clauses []clause.Expression) *gorm.DB { +func (db *DB) Save(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { db.mu.Lock() defer db.mu.Unlock() - return db.cli.Clauses(clauses...).Save(value) + return db.cli.WithContext(ctx).Clauses(clauses...).Save(value) } -func (db *DB) Exec(sql string, clauses []clause.Expression, values ...any) *gorm.DB { +func (db *DB) Exec(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB { db.mu.Lock() defer db.mu.Unlock() - return db.cli.Clauses(clauses...).Exec(sql, values...) + return db.cli.WithContext(ctx).Clauses(clauses...).Exec(sql, values...) } -func (db *DB) Raw(sql string, clauses []clause.Expression, values ...any) *gorm.DB { - return db.cli.Clauses(clauses...).Raw(sql, values...) +func (db *DB) Raw(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB { + return db.cli.WithContext(ctx).Clauses(clauses...).Raw(sql, values...) } func (db *DB) AutoMigrate(models ...any) error { return db.cli.AutoMigrate(models...) } -func (db *DB) Delete(value any, clauses []clause.Expression) *gorm.DB { +func (db *DB) Delete(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { db.mu.Lock() defer db.mu.Unlock() - return db.cli.Clauses(clauses...).Delete(value) + return db.cli.WithContext(ctx).Clauses(clauses...).Delete(value) } -func (db *DB) First(dest any, conds ...any) *gorm.DB { - return db.cli.First(dest, conds...) +func (db *DB) First(ctx context.Context, dest any, conds ...any) *gorm.DB { + return db.cli.WithContext(ctx).First(dest, conds...) } // TODO: this isn't actually good. we can commit even if the db is locked here. this is probably okay for the time being, but need to figure // out a better solution. right now we only do this whenever we're importing a repo though so i'm mostly not worried, but it's still bad. // e.g. when we do apply writes we should also be using a transcation but we don't right now -func (db *DB) BeginDangerously() *gorm.DB { - return db.cli.Begin() +func (db *DB) BeginDangerously(ctx context.Context) *gorm.DB { + return db.cli.WithContext(ctx).Begin() } func (db *DB) Lock() { diff --git a/server/common.go b/server/common.go index 3905767..60d07e3 100644 --- a/server/common.go +++ b/server/common.go @@ -1,36 +1,38 @@ package server import ( + "context" + "github.com/haileyok/cocoon/models" ) -func (s *Server) getActorByHandle(handle string) (*models.Actor, error) { +func (s *Server) getActorByHandle(ctx context.Context, handle string) (*models.Actor, error) { var actor models.Actor - if err := s.db.First(&actor, models.Actor{Handle: handle}).Error; err != nil { + if err := s.db.First(ctx, &actor, models.Actor{Handle: handle}).Error; err != nil { return nil, err } return &actor, nil } -func (s *Server) getRepoByEmail(email string) (*models.Repo, error) { +func (s *Server) getRepoByEmail(ctx context.Context, email string) (*models.Repo, error) { var repo models.Repo - if err := s.db.First(&repo, models.Repo{Email: email}).Error; err != nil { + if err := s.db.First(ctx, &repo, models.Repo{Email: email}).Error; err != nil { return nil, err } return &repo, nil } -func (s *Server) getRepoActorByEmail(email string) (*models.RepoActor, error) { +func (s *Server) getRepoActorByEmail(ctx context.Context, email string) (*models.RepoActor, error) { var repo models.RepoActor - if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil { return nil, err } return &repo, nil } -func (s *Server) getRepoActorByDid(did string) (*models.RepoActor, error) { +func (s *Server) getRepoActorByDid(ctx context.Context, did string) (*models.RepoActor, error) { var repo models.RepoActor - if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil { return nil, err } return &repo, nil diff --git a/server/handle_account.go b/server/handle_account.go index a3a6f97..87cfbc0 100644 --- a/server/handle_account.go +++ b/server/handle_account.go @@ -12,6 +12,7 @@ import ( func (s *Server) handleAccount(e echo.Context) error { ctx := e.Request().Context() + repo, sess, err := s.getSessionRepoOrErr(e) if err != nil { return e.Redirect(303, "/account/signin") @@ -20,7 +21,7 @@ func (s *Server) handleAccount(e echo.Context) error { oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime) var tokens []provider.OauthToken - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil { s.logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err) sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error") sess.Save(e.Request(), e.Response()) diff --git a/server/handle_account_revoke.go b/server/handle_account_revoke.go index 5ab00e7..febb66a 100644 --- a/server/handle_account_revoke.go +++ b/server/handle_account_revoke.go @@ -5,12 +5,14 @@ import ( "github.com/labstack/echo/v4" ) -type AccountRevokeRequest struct { +type AccountRevokeInput struct { Token string `form:"token"` } func (s *Server) handleAccountRevoke(e echo.Context) error { - var req AccountRevokeRequest + ctx := e.Request().Context() + + var req AccountRevokeInput if err := e.Bind(&req); err != nil { s.logger.Error("could not bind account revoke request", "error", err) return helpers.ServerError(e, nil) @@ -21,7 +23,7 @@ func (s *Server) handleAccountRevoke(e echo.Context) error { return e.Redirect(303, "/account/signin") } - if err := s.db.Exec("DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil { + if err := s.db.Exec(ctx, "DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil { s.logger.Error("couldnt delete oauth session for account", "did", repo.Repo.Did, "token", req.Token, "error", err) sess.AddFlash("Unable to revoke session. See server logs for more details.", "error") sess.Save(e.Request(), e.Response()) diff --git a/server/handle_account_signin.go b/server/handle_account_signin.go index a27feda..57b082a 100644 --- a/server/handle_account_signin.go +++ b/server/handle_account_signin.go @@ -14,13 +14,15 @@ import ( "gorm.io/gorm" ) -type OauthSigninRequest struct { +type OauthSigninInput struct { Username string `form:"username"` Password string `form:"password"` QueryParams string `form:"query_params"` } func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) { + ctx := e.Request().Context() + sess, err := session.Get("session", e) if err != nil { return nil, nil, err @@ -31,7 +33,7 @@ func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessio return nil, sess, errors.New("did was not set in session") } - repo, err := s.getRepoActorByDid(did) + repo, err := s.getRepoActorByDid(ctx, did) if err != nil { return nil, sess, err } @@ -60,7 +62,9 @@ func (s *Server) handleAccountSigninGet(e echo.Context) error { } func (s *Server) handleAccountSigninPost(e echo.Context) error { - var req OauthSigninRequest + ctx := e.Request().Context() + + var req OauthSigninInput if err := e.Bind(&req); err != nil { s.logger.Error("error binding sign in req", "error", err) return helpers.ServerError(e, nil) @@ -83,11 +87,11 @@ func (s *Server) handleAccountSigninPost(e echo.Context) error { var err error switch idtype { case "did": - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error case "handle": - err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error case "email": - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error } if err != nil { if err == gorm.ErrRecordNotFound { diff --git a/server/handle_actor_put_preferences.go b/server/handle_actor_put_preferences.go index 2cb1423..2d2ed0e 100644 --- a/server/handle_actor_put_preferences.go +++ b/server/handle_actor_put_preferences.go @@ -10,6 +10,8 @@ import ( // This is kinda lame. Not great to implement app.bsky in the pds, but alas func (s *Server) handleActorPutPreferences(e echo.Context) error { + ctx := e.Request().Context() + repo := e.Get("repo").(*models.RepoActor) var prefs map[string]any @@ -22,7 +24,7 @@ func (s *Server) handleActorPutPreferences(e echo.Context) error { return err } - if err := s.db.Exec("UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil { return err } diff --git a/server/handle_identity_request_plc_operation.go b/server/handle_identity_request_plc_operation.go index 70eb6d9..028e805 100644 --- a/server/handle_identity_request_plc_operation.go +++ b/server/handle_identity_request_plc_operation.go @@ -10,12 +10,14 @@ import ( ) func (s *Server) handleIdentityRequestPlcOperationSignature(e echo.Context) error { + ctx := e.Request().Context() + urepo := e.Get("repo").(*models.RepoActor) code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) eat := time.Now().Add(10 * time.Minute).UTC() - if err := s.db.Exec("UPDATE repos SET plc_operation_code = ?, plc_operation_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = ?, plc_operation_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { s.logger.Error("error updating user", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_identity_sign_plc_operation.go b/server/handle_identity_sign_plc_operation.go index 859e714..feb2bbd 100644 --- a/server/handle_identity_sign_plc_operation.go +++ b/server/handle_identity_sign_plc_operation.go @@ -92,7 +92,7 @@ func (s *Server) handleSignPlcOperation(e echo.Context) error { return helpers.ServerError(e, nil) } - if err := s.db.Exec("UPDATE repos SET plc_operation_code = NULL, plc_operation_code_expires_at = NULL WHERE did = ?", nil, repo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = NULL, plc_operation_code_expires_at = NULL WHERE did = ?", nil, repo.Repo.Did).Error; err != nil { s.logger.Error("error updating repo", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_identity_update_handle.go b/server/handle_identity_update_handle.go index fa87be6..93a41c1 100644 --- a/server/handle_identity_update_handle.go +++ b/server/handle_identity_update_handle.go @@ -94,7 +94,7 @@ func (s *Server) handleIdentityUpdateHandle(e echo.Context) error { }, }) - if err := s.db.Exec("UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil { s.logger.Error("error updating handle in db", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_import_repo.go b/server/handle_import_repo.go index e762846..ea1b0c5 100644 --- a/server/handle_import_repo.go +++ b/server/handle_import_repo.go @@ -18,6 +18,8 @@ import ( ) func (s *Server) handleRepoImportRepo(e echo.Context) error { + ctx := e.Request().Context() + urepo := e.Get("repo").(*models.RepoActor) b, err := io.ReadAll(e.Request().Body) @@ -63,7 +65,7 @@ func (s *Server) handleRepoImportRepo(e echo.Context) error { return helpers.ServerError(e, nil) } - tx := s.db.BeginDangerously() + tx := s.db.BeginDangerously(ctx) clock := syntax.NewTIDClock(0) diff --git a/server/handle_oauth_authorize.go b/server/handle_oauth_authorize.go index 05495db..19f5dee 100644 --- a/server/handle_oauth_authorize.go +++ b/server/handle_oauth_authorize.go @@ -13,6 +13,8 @@ import ( ) func (s *Server) handleOauthAuthorizeGet(e echo.Context) error { + ctx := e.Request().Context() + reqUri := e.QueryParam("request_uri") if reqUri == "" { // render page for logged out dev @@ -38,7 +40,7 @@ func (s *Server) handleOauthAuthorizeGet(e echo.Context) error { } var req provider.OauthAuthorizationRequest - if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { return helpers.ServerError(e, to.StringPtr(err.Error())) } @@ -72,6 +74,8 @@ type OauthAuthorizePostRequest struct { } func (s *Server) handleOauthAuthorizePost(e echo.Context) error { + ctx := e.Request().Context() + repo, _, err := s.getSessionRepoOrErr(e) if err != nil { return e.Redirect(303, "/account/signin") @@ -89,7 +93,7 @@ func (s *Server) handleOauthAuthorizePost(e echo.Context) error { } var authReq provider.OauthAuthorizationRequest - if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { return helpers.ServerError(e, to.StringPtr(err.Error())) } @@ -113,7 +117,7 @@ func (s *Server) handleOauthAuthorizePost(e echo.Context) error { code := oauth.GenerateCode() - if err := s.db.Exec("UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { s.logger.Error("error updating authorization request", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_oauth_par.go b/server/handle_oauth_par.go index e3d01c6..37682cd 100644 --- a/server/handle_oauth_par.go +++ b/server/handle_oauth_par.go @@ -19,6 +19,8 @@ type OauthParResponse struct { } func (s *Server) handleOauthPar(e echo.Context) error { + ctx := e.Request().Context() + var parRequest provider.ParRequest if err := e.Bind(&parRequest); err != nil { s.logger.Error("error binding for par request", "error", err) @@ -86,7 +88,7 @@ func (s *Server) handleOauthPar(e echo.Context) error { ExpiresAt: eat, } - if err := s.db.Create(authRequest, nil).Error; err != nil { + if err := s.db.Create(ctx, authRequest, nil).Error; err != nil { s.logger.Error("error creating auth request in db", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_oauth_token.go b/server/handle_oauth_token.go index acb9b57..905454b 100644 --- a/server/handle_oauth_token.go +++ b/server/handle_oauth_token.go @@ -38,6 +38,8 @@ type OauthTokenResponse struct { } func (s *Server) handleOauthToken(e echo.Context) error { + ctx := e.Request().Context() + var req OauthTokenRequest if err := e.Bind(&req); err != nil { s.logger.Error("error binding token request", "error", err) @@ -84,7 +86,7 @@ func (s *Server) handleOauthToken(e echo.Context) error { var authReq provider.OauthAuthorizationRequest // get the lil guy and delete him - if err := s.db.Raw("DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil { + if err := s.db.Raw(ctx, "DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil { s.logger.Error("error finding authorization request", "error", err) return helpers.ServerError(e, nil) } @@ -128,7 +130,7 @@ func (s *Server) handleOauthToken(e echo.Context) error { return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided")) } - repo, err := s.getRepoActorByDid(*authReq.Sub) + repo, err := s.getRepoActorByDid(ctx, *authReq.Sub) if err != nil { helpers.InputError(e, to.StringPtr("unable to find actor")) } @@ -159,7 +161,7 @@ func (s *Server) handleOauthToken(e echo.Context) error { return err } - if err := s.db.Create(&provider.OauthToken{ + if err := s.db.Create(ctx, &provider.OauthToken{ ClientId: authReq.ClientId, ClientAuth: *clientAuth, Parameters: authReq.Parameters, @@ -199,7 +201,7 @@ func (s *Server) handleOauthToken(e echo.Context) error { } var oauthToken provider.OauthToken - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil { s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken) return helpers.ServerError(e, nil) } @@ -257,7 +259,7 @@ func (s *Server) handleOauthToken(e echo.Context) error { return err } - if err := s.db.Exec("UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil { s.logger.Error("error updating token", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_repo_describe_repo.go b/server/handle_repo_describe_repo.go index ffa530f..73b0f98 100644 --- a/server/handle_repo_describe_repo.go +++ b/server/handle_repo_describe_repo.go @@ -20,8 +20,10 @@ type ComAtprotoRepoDescribeRepoResponse struct { } func (s *Server) handleDescribeRepo(e echo.Context) error { + ctx := e.Request().Context() + did := e.QueryParam("repo") - repo, err := s.getRepoActorByDid(did) + repo, err := s.getRepoActorByDid(ctx, did) if err != nil { if err == gorm.ErrRecordNotFound { return helpers.InputError(e, to.StringPtr("RepoNotFound")) @@ -64,7 +66,7 @@ func (s *Server) handleDescribeRepo(e echo.Context) error { } var records []models.Record - if err := s.db.Raw("SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil { s.logger.Error("error getting collections", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_repo_get_record.go b/server/handle_repo_get_record.go index 6b75857..5aae9f5 100644 --- a/server/handle_repo_get_record.go +++ b/server/handle_repo_get_record.go @@ -14,6 +14,8 @@ type ComAtprotoRepoGetRecordResponse struct { } func (s *Server) handleRepoGetRecord(e echo.Context) error { + ctx := e.Request().Context() + repo := e.QueryParam("repo") collection := e.QueryParam("collection") rkey := e.QueryParam("rkey") @@ -32,7 +34,7 @@ func (s *Server) handleRepoGetRecord(e echo.Context) error { } var record models.Record - if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil { // TODO: handle error nicely return err } diff --git a/server/handle_repo_list_missing_blobs.go b/server/handle_repo_list_missing_blobs.go index 2ab9098..ff468e7 100644 --- a/server/handle_repo_list_missing_blobs.go +++ b/server/handle_repo_list_missing_blobs.go @@ -22,6 +22,8 @@ type ComAtprotoRepoListMissingBlobsRecordBlob struct { } func (s *Server) handleListMissingBlobs(e echo.Context) error { + ctx := e.Request().Context() + urepo := e.Get("repo").(*models.RepoActor) limitStr := e.QueryParam("limit") @@ -35,7 +37,7 @@ func (s *Server) handleListMissingBlobs(e echo.Context) error { } var records []models.Record - if err := s.db.Raw("SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil { s.logger.Error("failed to get records for listMissingBlobs", "error", err) return helpers.ServerError(e, nil) } @@ -69,7 +71,7 @@ func (s *Server) handleListMissingBlobs(e echo.Context) error { } var count int64 - if err := s.db.Raw("SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil { continue } diff --git a/server/handle_repo_list_records.go b/server/handle_repo_list_records.go index aa1063c..a2d7a61 100644 --- a/server/handle_repo_list_records.go +++ b/server/handle_repo_list_records.go @@ -46,6 +46,8 @@ func getLimitFromContext(e echo.Context, def int) (int, error) { } func (s *Server) handleListRecords(e echo.Context) error { + ctx := e.Request().Context() + var req ComAtprotoRepoListRecordsRequest if err := e.Bind(&req); err != nil { s.logger.Error("could not bind list records request", "error", err) @@ -78,7 +80,7 @@ func (s *Server) handleListRecords(e echo.Context) error { did := req.Repo if _, err := syntax.ParseDID(did); err != nil { - actor, err := s.getActorByHandle(req.Repo) + actor, err := s.getActorByHandle(ctx, req.Repo) if err != nil { return helpers.InputError(e, to.StringPtr("RepoNotFound")) } @@ -93,7 +95,7 @@ func (s *Server) handleListRecords(e echo.Context) error { params = append(params, limit) var records []models.Record - if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil { s.logger.Error("error getting records", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_repo_list_repos.go b/server/handle_repo_list_repos.go index f74dbc3..0a33022 100644 --- a/server/handle_repo_list_repos.go +++ b/server/handle_repo_list_repos.go @@ -21,8 +21,10 @@ type ComAtprotoSyncListReposRepoItem struct { // TODO: paginate this bitch func (s *Server) handleListRepos(e echo.Context) error { + ctx := e.Request().Context() + var repos []models.Repo - if err := s.db.Raw("SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil { return err } diff --git a/server/handle_repo_upload_blob.go b/server/handle_repo_upload_blob.go index 28d56b4..3d863a7 100644 --- a/server/handle_repo_upload_blob.go +++ b/server/handle_repo_upload_blob.go @@ -32,6 +32,8 @@ type ComAtprotoRepoUploadBlobResponse struct { } func (s *Server) handleRepoUploadBlob(e echo.Context) error { + ctx := e.Request().Context() + urepo := e.Get("repo").(*models.RepoActor) mime := e.Request().Header.Get("content-type") @@ -51,7 +53,7 @@ func (s *Server) handleRepoUploadBlob(e echo.Context) error { Storage: storage, } - if err := s.db.Create(&blob, nil).Error; err != nil { + if err := s.db.Create(ctx, &blob, nil).Error; err != nil { s.logger.Error("error creating new blob in db", "error", err) return helpers.ServerError(e, nil) } @@ -84,7 +86,7 @@ func (s *Server) handleRepoUploadBlob(e echo.Context) error { Data: data, } - if err := s.db.Create(&blobPart, nil).Error; err != nil { + if err := s.db.Create(ctx, &blobPart, nil).Error; err != nil { s.logger.Error("error adding blob part to db", "error", err) return helpers.ServerError(e, nil) } @@ -131,7 +133,7 @@ func (s *Server) handleRepoUploadBlob(e echo.Context) error { } } - if err := s.db.Exec("UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil { // there should probably be somme handling here if this fails... s.logger.Error("error updating blob", "error", err) return helpers.ServerError(e, nil) diff --git a/server/handle_server_activate_account.go b/server/handle_server_activate_account.go index 8207a53..c5d2106 100644 --- a/server/handle_server_activate_account.go +++ b/server/handle_server_activate_account.go @@ -18,6 +18,8 @@ type ComAtprotoServerActivateAccountRequest struct { } func (s *Server) handleServerActivateAccount(e echo.Context) error { + ctx := e.Request().Context() + var req ComAtprotoServerDeactivateAccountRequest if err := e.Bind(&req); err != nil { s.logger.Error("error binding", "error", err) @@ -26,7 +28,7 @@ func (s *Server) handleServerActivateAccount(e echo.Context) error { urepo := e.Get("repo").(*models.RepoActor) - if err := s.db.Exec("UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil { s.logger.Error("error updating account status to deactivated", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_server_check_account_status.go b/server/handle_server_check_account_status.go index 6b815bf..e9a76f2 100644 --- a/server/handle_server_check_account_status.go +++ b/server/handle_server_check_account_status.go @@ -20,6 +20,8 @@ type ComAtprotoServerCheckAccountStatusResponse struct { } func (s *Server) handleServerCheckAccountStatus(e echo.Context) error { + ctx := e.Request().Context() + urepo := e.Get("repo").(*models.RepoActor) resp := ComAtprotoServerCheckAccountStatusResponse{ @@ -41,21 +43,21 @@ func (s *Server) handleServerCheckAccountStatus(e echo.Context) error { } var blockCtResp CountResp - if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil { s.logger.Error("error getting block count", "error", err) return helpers.ServerError(e, nil) } resp.RepoBlocks = blockCtResp.Ct var recCtResp CountResp - if err := s.db.Raw("SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil { s.logger.Error("error getting record count", "error", err) return helpers.ServerError(e, nil) } resp.IndexedRecords = recCtResp.Ct var blobCtResp CountResp - if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil { s.logger.Error("error getting record count", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_server_confirm_email.go b/server/handle_server_confirm_email.go index fe34cdc..5ab241a 100644 --- a/server/handle_server_confirm_email.go +++ b/server/handle_server_confirm_email.go @@ -15,6 +15,8 @@ type ComAtprotoServerConfirmEmailRequest struct { } func (s *Server) handleServerConfirmEmail(e echo.Context) error { + ctx := e.Request().Context() + urepo := e.Get("repo").(*models.RepoActor) var req ComAtprotoServerConfirmEmailRequest @@ -41,7 +43,7 @@ func (s *Server) handleServerConfirmEmail(e echo.Context) error { now := time.Now().UTC() - if err := s.db.Exec("UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil { s.logger.Error("error updating user", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_server_create_account.go b/server/handle_server_create_account.go index b534acc..c0916df 100644 --- a/server/handle_server_create_account.go +++ b/server/handle_server_create_account.go @@ -36,6 +36,8 @@ type ComAtprotoServerCreateAccountResponse struct { } func (s *Server) handleCreateAccount(e echo.Context) error { + ctx := e.Request().Context() + var request ComAtprotoServerCreateAccountRequest if err := e.Bind(&request); err != nil { @@ -68,11 +70,11 @@ func (s *Server) handleCreateAccount(e echo.Context) error { } } } - + var signupDid string if request.Did != nil { - signupDid = *request.Did; - + signupDid = *request.Did + token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1)) if token == "" { return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did")) @@ -90,7 +92,7 @@ func (s *Server) handleCreateAccount(e echo.Context) error { } // see if the handle is already taken - actor, err := s.getActorByHandle(request.Handle) + actor, err := s.getActorByHandle(ctx, request.Handle) if err != nil && err != gorm.ErrRecordNotFound { s.logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err) return helpers.ServerError(e, nil) @@ -109,7 +111,7 @@ func (s *Server) handleCreateAccount(e echo.Context) error { return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) } - if err := s.db.Raw("SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { if err == gorm.ErrRecordNotFound { return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) } @@ -123,7 +125,7 @@ func (s *Server) handleCreateAccount(e echo.Context) error { } // see if the email is already taken - existingRepo, err := s.getRepoByEmail(request.Email) + existingRepo, err := s.getRepoByEmail(ctx, request.Email) if err != nil && err != gorm.ErrRecordNotFound { s.logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err) return helpers.ServerError(e, nil) @@ -137,7 +139,7 @@ func (s *Server) handleCreateAccount(e echo.Context) error { var k *atcrypto.PrivateKeyK256 if signupDid != "" { - reservedKey, err := s.getReservedKey(signupDid) + reservedKey, err := s.getReservedKey(ctx, signupDid) if err != nil { s.logger.Error("error looking up reserved key", "error", err) } @@ -148,7 +150,7 @@ func (s *Server) handleCreateAccount(e echo.Context) error { k = nil } else { defer func() { - if delErr := s.deleteReservedKey(reservedKey.KeyDid, reservedKey.Did); delErr != nil { + if delErr := s.deleteReservedKey(ctx, reservedKey.KeyDid, reservedKey.Did); delErr != nil { s.logger.Error("error deleting reserved key", "error", delErr) } }() @@ -199,17 +201,17 @@ func (s *Server) handleCreateAccount(e echo.Context) error { Handle: request.Handle, } - if err := s.db.Create(&urepo, nil).Error; err != nil { + if err := s.db.Create(ctx, &urepo, nil).Error; err != nil { s.logger.Error("error inserting new repo", "error", err) return helpers.ServerError(e, nil) } - - if err := s.db.Create(&actor, nil).Error; err != nil { + + if err := s.db.Create(ctx, &actor, nil).Error; err != nil { s.logger.Error("error inserting new actor", "error", err) return helpers.ServerError(e, nil) } } else { - if err := s.db.Save(&actor, nil).Error; err != nil { + if err := s.db.Save(ctx, &actor, nil).Error; err != nil { s.logger.Error("error inserting new actor", "error", err) return helpers.ServerError(e, nil) } @@ -241,13 +243,13 @@ func (s *Server) handleCreateAccount(e echo.Context) error { } if s.config.RequireInvite { - if err := s.db.Raw("UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { + if err := s.db.Raw(ctx, "UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { s.logger.Error("error decrementing use count", "error", err) return helpers.ServerError(e, nil) } } - sess, err := s.createSession(&urepo) + sess, err := s.createSession(ctx, &urepo) if err != nil { s.logger.Error("error creating new session", "error", err) return helpers.ServerError(e, nil) diff --git a/server/handle_server_create_invite_code.go b/server/handle_server_create_invite_code.go index 53c2570..5a77e83 100644 --- a/server/handle_server_create_invite_code.go +++ b/server/handle_server_create_invite_code.go @@ -17,6 +17,8 @@ type ComAtprotoServerCreateInviteCodeResponse struct { } func (s *Server) handleCreateInviteCode(e echo.Context) error { + ctx := e.Request().Context() + var req ComAtprotoServerCreateInviteCodeRequest if err := e.Bind(&req); err != nil { s.logger.Error("error binding", "error", err) @@ -37,7 +39,7 @@ func (s *Server) handleCreateInviteCode(e echo.Context) error { acc = *req.ForAccount } - if err := s.db.Create(&models.InviteCode{ + if err := s.db.Create(ctx, &models.InviteCode{ Code: ic, Did: acc, RemainingUseCount: req.UseCount, diff --git a/server/handle_server_create_invite_codes.go b/server/handle_server_create_invite_codes.go index 7d13fc5..e75e704 100644 --- a/server/handle_server_create_invite_codes.go +++ b/server/handle_server_create_invite_codes.go @@ -22,6 +22,8 @@ type ComAtprotoServerCreateInviteCodesItem struct { } func (s *Server) handleCreateInviteCodes(e echo.Context) error { + ctx := e.Request().Context() + var req ComAtprotoServerCreateInviteCodesRequest if err := e.Bind(&req); err != nil { s.logger.Error("error binding", "error", err) @@ -50,7 +52,7 @@ func (s *Server) handleCreateInviteCodes(e echo.Context) error { ic := uuid.NewString() ics = append(ics, ic) - if err := s.db.Create(&models.InviteCode{ + if err := s.db.Create(ctx, &models.InviteCode{ Code: ic, Did: did, RemainingUseCount: req.UseCount, diff --git a/server/handle_server_create_session.go b/server/handle_server_create_session.go index b09c705..645192d 100644 --- a/server/handle_server_create_session.go +++ b/server/handle_server_create_session.go @@ -32,6 +32,8 @@ type ComAtprotoServerCreateSessionResponse struct { } func (s *Server) handleCreateSession(e echo.Context) error { + ctx := e.Request().Context() + var req ComAtprotoServerCreateSessionRequest if err := e.Bind(&req); err != nil { s.logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err) @@ -65,11 +67,11 @@ func (s *Server) handleCreateSession(e echo.Context) error { var err error switch idtype { case "did": - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error case "handle": - err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error case "email": - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error } if err != nil { @@ -88,7 +90,7 @@ func (s *Server) handleCreateSession(e echo.Context) error { return helpers.InputError(e, to.StringPtr("InvalidRequest")) } - sess, err := s.createSession(&repo.Repo) + sess, err := s.createSession(ctx, &repo.Repo) if err != nil { s.logger.Error("error creating session", "error", err) return helpers.ServerError(e, nil) diff --git a/server/handle_server_deactivate_account.go b/server/handle_server_deactivate_account.go index 58972dd..02db068 100644 --- a/server/handle_server_deactivate_account.go +++ b/server/handle_server_deactivate_account.go @@ -19,6 +19,8 @@ type ComAtprotoServerDeactivateAccountRequest struct { } func (s *Server) handleServerDeactivateAccount(e echo.Context) error { + ctx := e.Request().Context() + var req ComAtprotoServerDeactivateAccountRequest if err := e.Bind(&req); err != nil { s.logger.Error("error binding", "error", err) @@ -27,7 +29,7 @@ func (s *Server) handleServerDeactivateAccount(e echo.Context) error { urepo := e.Get("repo").(*models.RepoActor) - if err := s.db.Exec("UPDATE repos SET deactivated = ? WHERE did = ?", nil, true, urepo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, true, urepo.Repo.Did).Error; err != nil { s.logger.Error("error updating account status to deactivated", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_server_delete_account.go b/server/handle_server_delete_account.go index 2cfb929..52c4e0a 100644 --- a/server/handle_server_delete_account.go +++ b/server/handle_server_delete_account.go @@ -20,6 +20,8 @@ type ComAtprotoServerDeleteAccountRequest struct { } func (s *Server) handleServerDeleteAccount(e echo.Context) error { + ctx := e.Request().Context() + var req ComAtprotoServerDeleteAccountRequest if err := e.Bind(&req); err != nil { s.logger.Error("error binding", "error", err) @@ -31,7 +33,7 @@ func (s *Server) handleServerDeleteAccount(e echo.Context) error { return helpers.ServerError(e, nil) } - urepo, err := s.getRepoActorByDid(req.Did) + urepo, err := s.getRepoActorByDid(ctx, req.Did) if err != nil { s.logger.Error("error getting repo", "error", err) return echo.NewHTTPError(400, "account not found") @@ -66,7 +68,7 @@ func (s *Server) handleServerDeleteAccount(e echo.Context) error { }) } - tx := s.db.BeginDangerously() + tx := s.db.BeginDangerously(ctx) if tx.Error != nil { s.logger.Error("error starting transaction", "error", tx.Error) return helpers.ServerError(e, nil) diff --git a/server/handle_server_delete_session.go b/server/handle_server_delete_session.go index bd054fb..7f2d5ba 100644 --- a/server/handle_server_delete_session.go +++ b/server/handle_server_delete_session.go @@ -7,15 +7,17 @@ import ( ) func (s *Server) handleDeleteSession(e echo.Context) error { + ctx := e.Request().Context() + token := e.Get("token").(string) var acctok models.Token - if err := s.db.Raw("DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil { + if err := s.db.Raw(ctx, "DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil { s.logger.Error("error deleting access token from db", "error", err) return helpers.ServerError(e, nil) } - if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil { + if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil { s.logger.Error("error deleting refresh token from db", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_server_refresh_session.go b/server/handle_server_refresh_session.go index f259cf7..34110b6 100644 --- a/server/handle_server_refresh_session.go +++ b/server/handle_server_refresh_session.go @@ -16,20 +16,22 @@ type ComAtprotoServerRefreshSessionResponse struct { } func (s *Server) handleRefreshSession(e echo.Context) error { + ctx := e.Request().Context() + token := e.Get("token").(string) repo := e.Get("repo").(*models.RepoActor) - if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil { + if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil { s.logger.Error("error getting refresh token from db", "error", err) return helpers.ServerError(e, nil) } - if err := s.db.Exec("DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil { + if err := s.db.Exec(ctx, "DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil { s.logger.Error("error deleting access token from db", "error", err) return helpers.ServerError(e, nil) } - sess, err := s.createSession(&repo.Repo) + sess, err := s.createSession(ctx, &repo.Repo) if err != nil { s.logger.Error("error creating new session for refresh", "error", err) return helpers.ServerError(e, nil) diff --git a/server/handle_server_request_account_delete.go b/server/handle_server_request_account_delete.go index b0a942e..ea8f8aa 100644 --- a/server/handle_server_request_account_delete.go +++ b/server/handle_server_request_account_delete.go @@ -10,12 +10,14 @@ import ( ) func (s *Server) handleServerRequestAccountDelete(e echo.Context) error { + ctx := e.Request().Context() + urepo := e.Get("repo").(*models.RepoActor) token := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) expiresAt := time.Now().UTC().Add(15 * time.Minute) - if err := s.db.Exec("UPDATE repos SET account_delete_code = ?, account_delete_code_expires_at = ? WHERE did = ?", nil, token, expiresAt, urepo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET account_delete_code = ?, account_delete_code_expires_at = ? WHERE did = ?", nil, token, expiresAt, urepo.Repo.Did).Error; err != nil { s.logger.Error("error setting deletion token", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_server_request_email_confirmation.go b/server/handle_server_request_email_confirmation.go index d412f86..dce9981 100644 --- a/server/handle_server_request_email_confirmation.go +++ b/server/handle_server_request_email_confirmation.go @@ -11,6 +11,8 @@ import ( ) func (s *Server) handleServerRequestEmailConfirmation(e echo.Context) error { + ctx := e.Request().Context() + urepo := e.Get("repo").(*models.RepoActor) if urepo.EmailConfirmedAt != nil { @@ -20,7 +22,7 @@ func (s *Server) handleServerRequestEmailConfirmation(e echo.Context) error { code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) eat := time.Now().Add(10 * time.Minute).UTC() - if err := s.db.Exec("UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { s.logger.Error("error updating user", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_server_request_email_update.go b/server/handle_server_request_email_update.go index 3af87e1..7ab4b48 100644 --- a/server/handle_server_request_email_update.go +++ b/server/handle_server_request_email_update.go @@ -14,13 +14,15 @@ type ComAtprotoRequestEmailUpdateResponse struct { } func (s *Server) handleServerRequestEmailUpdate(e echo.Context) error { + ctx := e.Request().Context() + urepo := e.Get("repo").(*models.RepoActor) if urepo.EmailConfirmedAt != nil { code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) eat := time.Now().Add(10 * time.Minute).UTC() - if err := s.db.Exec("UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { s.logger.Error("error updating repo", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_server_request_password_reset.go b/server/handle_server_request_password_reset.go index 6339a18..2d90451 100644 --- a/server/handle_server_request_password_reset.go +++ b/server/handle_server_request_password_reset.go @@ -14,6 +14,8 @@ type ComAtprotoServerRequestPasswordResetRequest struct { } func (s *Server) handleServerRequestPasswordReset(e echo.Context) error { + ctx := e.Request().Context() + urepo, ok := e.Get("repo").(*models.RepoActor) if !ok { var req ComAtprotoServerRequestPasswordResetRequest @@ -25,7 +27,7 @@ func (s *Server) handleServerRequestPasswordReset(e echo.Context) error { return err } - murepo, err := s.getRepoActorByEmail(req.Email) + murepo, err := s.getRepoActorByEmail(ctx, req.Email) if err != nil { return err } @@ -36,7 +38,7 @@ func (s *Server) handleServerRequestPasswordReset(e echo.Context) error { code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) eat := time.Now().Add(10 * time.Minute).UTC() - if err := s.db.Exec("UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { s.logger.Error("error updating repo", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_server_reserve_signing_key.go b/server/handle_server_reserve_signing_key.go index 1c9d55d..036be81 100644 --- a/server/handle_server_reserve_signing_key.go +++ b/server/handle_server_reserve_signing_key.go @@ -1,6 +1,7 @@ package server import ( + "context" "time" "github.com/bluesky-social/indigo/atproto/atcrypto" @@ -18,6 +19,8 @@ type ServerReserveSigningKeyResponse struct { } func (s *Server) handleServerReserveSigningKey(e echo.Context) error { + ctx := e.Request().Context() + var req ServerReserveSigningKeyRequest if err := e.Bind(&req); err != nil { s.logger.Error("could not bind reserve signing key request", "error", err) @@ -26,7 +29,7 @@ func (s *Server) handleServerReserveSigningKey(e echo.Context) error { if req.Did != nil && *req.Did != "" { var existing models.ReservedKey - if err := s.db.Raw("SELECT * FROM reserved_keys WHERE did = ?", nil, *req.Did).Scan(&existing).Error; err == nil && existing.KeyDid != "" { + if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE did = ?", nil, *req.Did).Scan(&existing).Error; err == nil && existing.KeyDid != "" { return e.JSON(200, ServerReserveSigningKeyResponse{ SigningKey: existing.KeyDid, }) @@ -54,7 +57,7 @@ func (s *Server) handleServerReserveSigningKey(e echo.Context) error { CreatedAt: time.Now(), } - if err := s.db.Create(&reservedKey, nil).Error; err != nil { + if err := s.db.Create(ctx, &reservedKey, nil).Error; err != nil { s.logger.Error("error storing reserved key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err) return helpers.ServerError(e, nil) } @@ -66,27 +69,27 @@ func (s *Server) handleServerReserveSigningKey(e echo.Context) error { }) } -func (s *Server) getReservedKey(keyDidOrDid string) (*models.ReservedKey, error) { +func (s *Server) getReservedKey(ctx context.Context, keyDidOrDid string) (*models.ReservedKey, error) { var reservedKey models.ReservedKey - if err := s.db.Raw("SELECT * FROM reserved_keys WHERE key_did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { + if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE key_did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { return &reservedKey, nil } - if err := s.db.Raw("SELECT * FROM reserved_keys WHERE did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { + if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { return &reservedKey, nil } return nil, nil } -func (s *Server) deleteReservedKey(keyDid string, did *string) error { - if err := s.db.Exec("DELETE FROM reserved_keys WHERE key_did = ?", nil, keyDid).Error; err != nil { +func (s *Server) deleteReservedKey(ctx context.Context, keyDid string, did *string) error { + if err := s.db.Exec(ctx, "DELETE FROM reserved_keys WHERE key_did = ?", nil, keyDid).Error; err != nil { return err } if did != nil && *did != "" { - if err := s.db.Exec("DELETE FROM reserved_keys WHERE did = ?", nil, *did).Error; err != nil { + if err := s.db.Exec(ctx, "DELETE FROM reserved_keys WHERE did = ?", nil, *did).Error; err != nil { return err } } diff --git a/server/handle_server_reset_password.go b/server/handle_server_reset_password.go index 941e455..473b5de 100644 --- a/server/handle_server_reset_password.go +++ b/server/handle_server_reset_password.go @@ -16,6 +16,8 @@ type ComAtprotoServerResetPasswordRequest struct { } func (s *Server) handleServerResetPassword(e echo.Context) error { + ctx := e.Request().Context() + urepo := e.Get("repo").(*models.RepoActor) var req ComAtprotoServerResetPasswordRequest @@ -46,7 +48,7 @@ func (s *Server) handleServerResetPassword(e echo.Context) error { return helpers.ServerError(e, nil) } - if err := s.db.Exec("UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil { s.logger.Error("error updating repo", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_server_update_email.go b/server/handle_server_update_email.go index cdb7ba7..c94e86a 100644 --- a/server/handle_server_update_email.go +++ b/server/handle_server_update_email.go @@ -15,6 +15,8 @@ type ComAtprotoServerUpdateEmailRequest struct { } func (s *Server) handleServerUpdateEmail(e echo.Context) error { + ctx := e.Request().Context() + urepo := e.Get("repo").(*models.RepoActor) var req ComAtprotoServerUpdateEmailRequest @@ -39,7 +41,7 @@ func (s *Server) handleServerUpdateEmail(e echo.Context) error { return helpers.ExpiredTokenError(e) } - if err := s.db.Exec("UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil { s.logger.Error("error updating repo", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_sync_get_blob.go b/server/handle_sync_get_blob.go index 5297cbe..e136d5a 100644 --- a/server/handle_sync_get_blob.go +++ b/server/handle_sync_get_blob.go @@ -17,6 +17,8 @@ import ( ) func (s *Server) handleSyncGetBlob(e echo.Context) error { + ctx := e.Request().Context() + did := e.QueryParam("did") if did == "" { return helpers.InputError(e, nil) @@ -32,7 +34,7 @@ func (s *Server) handleSyncGetBlob(e echo.Context) error { return helpers.InputError(e, nil) } - urepo, err := s.getRepoActorByDid(did) + urepo, err := s.getRepoActorByDid(ctx, did) if err != nil { s.logger.Error("could not find user for requested blob", "error", err) return helpers.InputError(e, nil) @@ -46,7 +48,7 @@ func (s *Server) handleSyncGetBlob(e echo.Context) error { } var blob models.Blob - if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil { s.logger.Error("error looking up blob", "error", err) return helpers.ServerError(e, nil) } @@ -55,7 +57,7 @@ func (s *Server) handleSyncGetBlob(e echo.Context) error { if blob.Storage == "sqlite" { var parts []models.BlobPart - if err := s.db.Raw("SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil { s.logger.Error("error getting blob parts", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_sync_get_blocks.go b/server/handle_sync_get_blocks.go index 7fe5d48..e9e0157 100644 --- a/server/handle_sync_get_blocks.go +++ b/server/handle_sync_get_blocks.go @@ -35,7 +35,7 @@ func (s *Server) handleGetBlocks(e echo.Context) error { cids = append(cids, c) } - urepo, err := s.getRepoActorByDid(req.Did) + urepo, err := s.getRepoActorByDid(ctx, req.Did) if err != nil { return helpers.ServerError(e, nil) } diff --git a/server/handle_sync_get_latest_commit.go b/server/handle_sync_get_latest_commit.go index d756c9a..176be68 100644 --- a/server/handle_sync_get_latest_commit.go +++ b/server/handle_sync_get_latest_commit.go @@ -12,12 +12,14 @@ type ComAtprotoSyncGetLatestCommitResponse struct { } func (s *Server) handleSyncGetLatestCommit(e echo.Context) error { + ctx := e.Request().Context() + did := e.QueryParam("did") if did == "" { return helpers.InputError(e, nil) } - urepo, err := s.getRepoActorByDid(did) + urepo, err := s.getRepoActorByDid(ctx, did) if err != nil { return err } diff --git a/server/handle_sync_get_record.go b/server/handle_sync_get_record.go index c64e62b..cc97595 100644 --- a/server/handle_sync_get_record.go +++ b/server/handle_sync_get_record.go @@ -20,7 +20,7 @@ func (s *Server) handleSyncGetRecord(e echo.Context) error { rkey := e.QueryParam("rkey") var urepo models.Repo - if err := s.db.Raw("SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil { s.logger.Error("error getting repo", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_sync_get_repo.go b/server/handle_sync_get_repo.go index f862233..dacc02b 100644 --- a/server/handle_sync_get_repo.go +++ b/server/handle_sync_get_repo.go @@ -13,12 +13,14 @@ import ( ) func (s *Server) handleSyncGetRepo(e echo.Context) error { + ctx := e.Request().Context() + did := e.QueryParam("did") if did == "" { return helpers.InputError(e, nil) } - urepo, err := s.getRepoActorByDid(did) + urepo, err := s.getRepoActorByDid(ctx, did) if err != nil { return err } @@ -41,7 +43,7 @@ func (s *Server) handleSyncGetRepo(e echo.Context) error { } var blocks []models.Block - if err := s.db.Raw("SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil { return err } diff --git a/server/handle_sync_get_repo_status.go b/server/handle_sync_get_repo_status.go index 9d701e3..442f83e 100644 --- a/server/handle_sync_get_repo_status.go +++ b/server/handle_sync_get_repo_status.go @@ -14,12 +14,14 @@ type ComAtprotoSyncGetRepoStatusResponse struct { // TODO: make this actually do the right thing func (s *Server) handleSyncGetRepoStatus(e echo.Context) error { + ctx := e.Request().Context() + did := e.QueryParam("did") if did == "" { return helpers.InputError(e, nil) } - urepo, err := s.getRepoActorByDid(did) + urepo, err := s.getRepoActorByDid(ctx, did) if err != nil { return err } diff --git a/server/handle_sync_list_blobs.go b/server/handle_sync_list_blobs.go index 2261a6b..a5a0aca 100644 --- a/server/handle_sync_list_blobs.go +++ b/server/handle_sync_list_blobs.go @@ -14,6 +14,8 @@ type ComAtprotoSyncListBlobsResponse struct { } func (s *Server) handleSyncListBlobs(e echo.Context) error { + ctx := e.Request().Context() + did := e.QueryParam("did") if did == "" { return helpers.InputError(e, nil) @@ -35,7 +37,7 @@ func (s *Server) handleSyncListBlobs(e echo.Context) error { } params = append(params, limit) - urepo, err := s.getRepoActorByDid(did) + urepo, err := s.getRepoActorByDid(ctx, did) if err != nil { s.logger.Error("could not find user for requested blobs", "error", err) return helpers.InputError(e, nil) @@ -49,7 +51,7 @@ func (s *Server) handleSyncListBlobs(e echo.Context) error { } var blobs []models.Blob - if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil { s.logger.Error("error getting records", "error", err) return helpers.ServerError(e, nil) } diff --git a/server/handle_well_known.go b/server/handle_well_known.go index 565bbeb..cd4091a 100644 --- a/server/handle_well_known.go +++ b/server/handle_well_known.go @@ -67,6 +67,8 @@ func (s *Server) handleWellKnown(e echo.Context) error { } func (s *Server) handleAtprotoDid(e echo.Context) error { + ctx := e.Request().Context() + host := e.Request().Host if host == "" { return helpers.InputError(e, to.StringPtr("Invalid handle.")) @@ -84,7 +86,7 @@ func (s *Server) handleAtprotoDid(e echo.Context) error { return e.NoContent(404) } - actor, err := s.getActorByHandle(host) + actor, err := s.getActorByHandle(ctx, host) if err != nil { if err == gorm.ErrRecordNotFound { return e.NoContent(404) diff --git a/server/middleware.go b/server/middleware.go index 58b6039..f379738 100644 --- a/server/middleware.go +++ b/server/middleware.go @@ -37,6 +37,8 @@ func (s *Server) handleAdminMiddleware(next echo.HandlerFunc) echo.HandlerFunc { func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(e echo.Context) error { + ctx := e.Request().Context() + authheader := e.Request().Header.Get("authorization") if authheader == "" { return e.JSON(401, map[string]string{"error": "Unauthorized"}) @@ -78,7 +80,7 @@ func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.Handl } did = maybeDid - maybeRepo, err := s.getRepoActorByDid(did) + maybeRepo, err := s.getRepoActorByDid(ctx, did) if err != nil { s.logger.Error("error fetching repo", "error", err) return helpers.ServerError(e, nil) @@ -159,7 +161,7 @@ func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.Handl Found bool } var result Result - if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { if err == gorm.ErrRecordNotFound { return helpers.InvalidTokenError(e) } @@ -184,7 +186,7 @@ func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.Handl } if repo == nil { - maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) + maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string)) if err != nil { s.logger.Error("error fetching repo", "error", err) return helpers.ServerError(e, nil) @@ -207,6 +209,8 @@ func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.Handl func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { return func(e echo.Context) error { + ctx := e.Request().Context() + authheader := e.Request().Header.Get("authorization") if authheader == "" { return e.JSON(401, map[string]string{"error": "Unauthorized"}) @@ -243,7 +247,7 @@ func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.Handle } var oauthToken provider.OauthToken - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { + if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { s.logger.Error("error finding access token in db", "error", err) return helpers.InputError(e, nil) } @@ -266,7 +270,7 @@ func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.Handle }) } - repo, err := s.getRepoActorByDid(oauthToken.Sub) + repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub) if err != nil { s.logger.Error("could not find actor in db", "error", err) return helpers.ServerError(e, nil) diff --git a/server/repo.go b/server/repo.go index 5d8f7a4..6d32e48 100644 --- a/server/repo.go +++ b/server/repo.go @@ -181,7 +181,7 @@ func (rm *RepoMan) applyWrites(ctx context.Context, urepo models.Repo, writes [] case OpTypeDelete: // try to find the old record in the database var old models.Record - if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { + if err := rm.db.Raw(ctx, "SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { return nil, err } @@ -323,7 +323,7 @@ func (rm *RepoMan) applyWrites(ctx context.Context, urepo models.Repo, writes [] var cids []cid.Cid // whenever there is cid present, we know it's a create (dumb) if entry.Cid != "" { - if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{ + if err := rm.s.db.Create(ctx, &entry, []clause.Expression{clause.OnConflict{ Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}}, UpdateAll: true, }}).Error; err != nil { @@ -331,7 +331,7 @@ func (rm *RepoMan) applyWrites(ctx context.Context, urepo models.Repo, writes [] } // increment the given blob refs, yay - cids, err = rm.incrementBlobRefs(urepo, entry.Value) + cids, err = rm.incrementBlobRefs(ctx, urepo, entry.Value) if err != nil { return nil, err } @@ -339,12 +339,12 @@ func (rm *RepoMan) applyWrites(ctx context.Context, urepo models.Repo, writes [] // as i noted above this is dumb. but we delete whenever the cid is nil. it works solely becaue the pkey // is did + collection + rkey. i still really want to separate that out, or use a different type to make // this less confusing/easy to read. alas, its 2 am and yea no - if err := rm.s.db.Delete(&entry, nil).Error; err != nil { + if err := rm.s.db.Delete(ctx, &entry, nil).Error; err != nil { return nil, err } // TODO: - cids, err = rm.decrementBlobRefs(urepo, entry.Value) + cids, err = rm.decrementBlobRefs(ctx, urepo, entry.Value) if err != nil { return nil, err } @@ -411,14 +411,14 @@ func (rm *RepoMan) getRecordProof(ctx context.Context, urepo models.Repo, collec return c, bs.GetReadLog(), nil } -func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { +func (rm *RepoMan) incrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) { cids, err := getBlobCidsFromCbor(cbor) if err != nil { return nil, err } for _, c := range cids { - if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil { + if err := rm.db.Exec(ctx, "UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil { return nil, err } } @@ -426,7 +426,7 @@ func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, return cids, nil } -func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { +func (rm *RepoMan) decrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) { cids, err := getBlobCidsFromCbor(cbor) if err != nil { return nil, err @@ -437,17 +437,17 @@ func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, ID uint Count int } - if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { + if err := rm.db.Raw(ctx, "UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { return nil, err } // TODO: this does _not_ handle deletions of blobs that are on s3 storage!!!! we need to get the blob, see what // storage it is in, and clean up s3!!!! if res.Count == 0 { - if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil { + if err := rm.db.Exec(ctx, "DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil { return nil, err } - if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil { + if err := rm.db.Exec(ctx, "DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil { return nil, err } } diff --git a/server/server.go b/server/server.go index 49340f9..59ddc35 100644 --- a/server/server.go +++ b/server/server.go @@ -729,7 +729,7 @@ func (s *Server) backupRoutine() { } func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { - if err := s.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { + if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { return err } diff --git a/server/session.go b/server/session.go index 2304cf1..5db2677 100644 --- a/server/session.go +++ b/server/session.go @@ -1,6 +1,7 @@ package server import ( + "context" "time" "github.com/golang-jwt/jwt/v4" @@ -13,7 +14,7 @@ type Session struct { RefreshToken string } -func (s *Server) createSession(repo *models.Repo) (*Session, error) { +func (s *Server) createSession(ctx context.Context, repo *models.Repo) (*Session, error) { now := time.Now() accexp := now.Add(3 * time.Hour) refexp := now.Add(7 * 24 * time.Hour) @@ -49,7 +50,7 @@ func (s *Server) createSession(repo *models.Repo) (*Session, error) { return nil, err } - if err := s.db.Create(&models.Token{ + if err := s.db.Create(ctx, &models.Token{ Token: accessString, Did: repo.Did, RefreshToken: refreshString, @@ -59,7 +60,7 @@ func (s *Server) createSession(repo *models.Repo) (*Session, error) { return nil, err } - if err := s.db.Create(&models.RefreshToken{ + if err := s.db.Create(ctx, &models.RefreshToken{ Token: refreshString, Did: repo.Did, CreatedAt: now, diff --git a/sqlite_blockstore/sqlite_blockstore.go b/sqlite_blockstore/sqlite_blockstore.go index c670280..13a66cf 100644 --- a/sqlite_blockstore/sqlite_blockstore.go +++ b/sqlite_blockstore/sqlite_blockstore.go @@ -45,7 +45,7 @@ func (bs *SqliteBlockstore) Get(ctx context.Context, cid cid.Cid) (blocks.Block, return maybeBlock, nil } - if err := bs.db.Raw("SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil { + if err := bs.db.Raw(ctx, "SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil { return nil, err } @@ -71,7 +71,7 @@ func (bs *SqliteBlockstore) Put(ctx context.Context, block blocks.Block) error { Value: block.RawData(), } - if err := bs.db.Create(&b, []clause.Expression{clause.OnConflict{ + if err := bs.db.Create(ctx, &b, []clause.Expression{clause.OnConflict{ Columns: []clause.Column{{Name: "did"}, {Name: "cid"}}, UpdateAll: true, }}).Error; err != nil { @@ -94,7 +94,7 @@ func (bs *SqliteBlockstore) GetSize(context.Context, cid.Cid) (int, error) { } func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error { - tx := bs.db.BeginDangerously() + tx := bs.db.BeginDangerously(ctx) for _, block := range blocks { bs.inserts[block.Cid()] = block -- 2.43.0