An atproto PDS written in Go

add withcontxt to all db calls

authored by hailey.at and committed by Tangled a3703416 d91516f9

+15 -14
internal/db/db.go
··· 1 1 package db 2 2 3 3 import ( 4 + "context" 4 5 "sync" 5 6 6 7 "gorm.io/gorm" ··· 19 20 } 20 21 } 21 22 22 - func (db *DB) Create(value any, clauses []clause.Expression) *gorm.DB { 23 + func (db *DB) Create(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { 23 24 db.mu.Lock() 24 25 defer db.mu.Unlock() 25 - return db.cli.Clauses(clauses...).Create(value) 26 + return db.cli.WithContext(ctx).Clauses(clauses...).Create(value) 26 27 } 27 28 28 - func (db *DB) Save(value any, clauses []clause.Expression) *gorm.DB { 29 + func (db *DB) Save(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { 29 30 db.mu.Lock() 30 31 defer db.mu.Unlock() 31 - return db.cli.Clauses(clauses...).Save(value) 32 + return db.cli.WithContext(ctx).Clauses(clauses...).Save(value) 32 33 } 33 34 34 - func (db *DB) Exec(sql string, clauses []clause.Expression, values ...any) *gorm.DB { 35 + func (db *DB) Exec(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB { 35 36 db.mu.Lock() 36 37 defer db.mu.Unlock() 37 - return db.cli.Clauses(clauses...).Exec(sql, values...) 38 + return db.cli.WithContext(ctx).Clauses(clauses...).Exec(sql, values...) 38 39 } 39 40 40 - func (db *DB) Raw(sql string, clauses []clause.Expression, values ...any) *gorm.DB { 41 - return db.cli.Clauses(clauses...).Raw(sql, values...) 41 + func (db *DB) Raw(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB { 42 + return db.cli.WithContext(ctx).Clauses(clauses...).Raw(sql, values...) 42 43 } 43 44 44 45 func (db *DB) AutoMigrate(models ...any) error { 45 46 return db.cli.AutoMigrate(models...) 46 47 } 47 48 48 - func (db *DB) Delete(value any, clauses []clause.Expression) *gorm.DB { 49 + func (db *DB) Delete(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { 49 50 db.mu.Lock() 50 51 defer db.mu.Unlock() 51 - return db.cli.Clauses(clauses...).Delete(value) 52 + return db.cli.WithContext(ctx).Clauses(clauses...).Delete(value) 52 53 } 53 54 54 - func (db *DB) First(dest any, conds ...any) *gorm.DB { 55 - return db.cli.First(dest, conds...) 55 + func (db *DB) First(ctx context.Context, dest any, conds ...any) *gorm.DB { 56 + return db.cli.WithContext(ctx).First(dest, conds...) 56 57 } 57 58 58 59 // TODO: this isn't actually good. we can commit even if the db is locked here. this is probably okay for the time being, but need to figure 59 60 // out a better solution. right now we only do this whenever we're importing a repo though so i'm mostly not worried, but it's still bad. 60 61 // e.g. when we do apply writes we should also be using a transcation but we don't right now 61 - func (db *DB) BeginDangerously() *gorm.DB { 62 - return db.cli.Begin() 62 + func (db *DB) BeginDangerously(ctx context.Context) *gorm.DB { 63 + return db.cli.WithContext(ctx).Begin() 63 64 } 64 65 65 66 func (db *DB) Lock() {
+10 -8
server/common.go
··· 1 1 package server 2 2 3 3 import ( 4 + "context" 5 + 4 6 "github.com/haileyok/cocoon/models" 5 7 ) 6 8 7 - func (s *Server) getActorByHandle(handle string) (*models.Actor, error) { 9 + func (s *Server) getActorByHandle(ctx context.Context, handle string) (*models.Actor, error) { 8 10 var actor models.Actor 9 - if err := s.db.First(&actor, models.Actor{Handle: handle}).Error; err != nil { 11 + if err := s.db.First(ctx, &actor, models.Actor{Handle: handle}).Error; err != nil { 10 12 return nil, err 11 13 } 12 14 return &actor, nil 13 15 } 14 16 15 - func (s *Server) getRepoByEmail(email string) (*models.Repo, error) { 17 + func (s *Server) getRepoByEmail(ctx context.Context, email string) (*models.Repo, error) { 16 18 var repo models.Repo 17 - if err := s.db.First(&repo, models.Repo{Email: email}).Error; err != nil { 19 + if err := s.db.First(ctx, &repo, models.Repo{Email: email}).Error; err != nil { 18 20 return nil, err 19 21 } 20 22 return &repo, nil 21 23 } 22 24 23 - func (s *Server) getRepoActorByEmail(email string) (*models.RepoActor, error) { 25 + func (s *Server) getRepoActorByEmail(ctx context.Context, email string) (*models.RepoActor, error) { 24 26 var repo models.RepoActor 25 - if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil { 27 + if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil { 26 28 return nil, err 27 29 } 28 30 return &repo, nil 29 31 } 30 32 31 - func (s *Server) getRepoActorByDid(did string) (*models.RepoActor, error) { 33 + func (s *Server) getRepoActorByDid(ctx context.Context, did string) (*models.RepoActor, error) { 32 34 var repo models.RepoActor 33 - if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil { 35 + if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil { 34 36 return nil, err 35 37 } 36 38 return &repo, nil
+2 -1
server/handle_account.go
··· 12 12 13 13 func (s *Server) handleAccount(e echo.Context) error { 14 14 ctx := e.Request().Context() 15 + 15 16 repo, sess, err := s.getSessionRepoOrErr(e) 16 17 if err != nil { 17 18 return e.Redirect(303, "/account/signin") ··· 20 21 oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime) 21 22 22 23 var tokens []provider.OauthToken 23 - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil { 24 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil { 24 25 s.logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err) 25 26 sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error") 26 27 sess.Save(e.Request(), e.Response())
+5 -3
server/handle_account_revoke.go
··· 5 5 "github.com/labstack/echo/v4" 6 6 ) 7 7 8 - type AccountRevokeRequest struct { 8 + type AccountRevokeInput struct { 9 9 Token string `form:"token"` 10 10 } 11 11 12 12 func (s *Server) handleAccountRevoke(e echo.Context) error { 13 - var req AccountRevokeRequest 13 + ctx := e.Request().Context() 14 + 15 + var req AccountRevokeInput 14 16 if err := e.Bind(&req); err != nil { 15 17 s.logger.Error("could not bind account revoke request", "error", err) 16 18 return helpers.ServerError(e, nil) ··· 21 23 return e.Redirect(303, "/account/signin") 22 24 } 23 25 24 - if err := s.db.Exec("DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil { 26 + if err := s.db.Exec(ctx, "DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil { 25 27 s.logger.Error("couldnt delete oauth session for account", "did", repo.Repo.Did, "token", req.Token, "error", err) 26 28 sess.AddFlash("Unable to revoke session. See server logs for more details.", "error") 27 29 sess.Save(e.Request(), e.Response())
+10 -6
server/handle_account_signin.go
··· 14 14 "gorm.io/gorm" 15 15 ) 16 16 17 - type OauthSigninRequest struct { 17 + type OauthSigninInput struct { 18 18 Username string `form:"username"` 19 19 Password string `form:"password"` 20 20 QueryParams string `form:"query_params"` 21 21 } 22 22 23 23 func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) { 24 + ctx := e.Request().Context() 25 + 24 26 sess, err := session.Get("session", e) 25 27 if err != nil { 26 28 return nil, nil, err ··· 31 33 return nil, sess, errors.New("did was not set in session") 32 34 } 33 35 34 - repo, err := s.getRepoActorByDid(did) 36 + repo, err := s.getRepoActorByDid(ctx, did) 35 37 if err != nil { 36 38 return nil, sess, err 37 39 } ··· 60 62 } 61 63 62 64 func (s *Server) handleAccountSigninPost(e echo.Context) error { 63 - var req OauthSigninRequest 65 + ctx := e.Request().Context() 66 + 67 + var req OauthSigninInput 64 68 if err := e.Bind(&req); err != nil { 65 69 s.logger.Error("error binding sign in req", "error", err) 66 70 return helpers.ServerError(e, nil) ··· 83 87 var err error 84 88 switch idtype { 85 89 case "did": 86 - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error 90 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error 87 91 case "handle": 88 - err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error 92 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error 89 93 case "email": 90 - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error 94 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error 91 95 } 92 96 if err != nil { 93 97 if err == gorm.ErrRecordNotFound {
+3 -1
server/handle_actor_put_preferences.go
··· 10 10 // This is kinda lame. Not great to implement app.bsky in the pds, but alas 11 11 12 12 func (s *Server) handleActorPutPreferences(e echo.Context) error { 13 + ctx := e.Request().Context() 14 + 13 15 repo := e.Get("repo").(*models.RepoActor) 14 16 15 17 var prefs map[string]any ··· 22 24 return err 23 25 } 24 26 25 - if err := s.db.Exec("UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil { 27 + if err := s.db.Exec(ctx, "UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil { 26 28 return err 27 29 } 28 30
+3 -1
server/handle_identity_request_plc_operation.go
··· 10 10 ) 11 11 12 12 func (s *Server) handleIdentityRequestPlcOperationSignature(e echo.Context) error { 13 + ctx := e.Request().Context() 14 + 13 15 urepo := e.Get("repo").(*models.RepoActor) 14 16 15 17 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 16 18 eat := time.Now().Add(10 * time.Minute).UTC() 17 19 18 - if err := s.db.Exec("UPDATE repos SET plc_operation_code = ?, plc_operation_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 20 + if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = ?, plc_operation_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 19 21 s.logger.Error("error updating user", "error", err) 20 22 return helpers.ServerError(e, nil) 21 23 }
+1 -1
server/handle_identity_sign_plc_operation.go
··· 92 92 return helpers.ServerError(e, nil) 93 93 } 94 94 95 - if err := s.db.Exec("UPDATE repos SET plc_operation_code = NULL, plc_operation_code_expires_at = NULL WHERE did = ?", nil, repo.Repo.Did).Error; err != nil { 95 + if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = NULL, plc_operation_code_expires_at = NULL WHERE did = ?", nil, repo.Repo.Did).Error; err != nil { 96 96 s.logger.Error("error updating repo", "error", err) 97 97 return helpers.ServerError(e, nil) 98 98 }
+1 -1
server/handle_identity_update_handle.go
··· 94 94 }, 95 95 }) 96 96 97 - if err := s.db.Exec("UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil { 97 + if err := s.db.Exec(ctx, "UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil { 98 98 s.logger.Error("error updating handle in db", "error", err) 99 99 return helpers.ServerError(e, nil) 100 100 }
+3 -1
server/handle_import_repo.go
··· 18 18 ) 19 19 20 20 func (s *Server) handleRepoImportRepo(e echo.Context) error { 21 + ctx := e.Request().Context() 22 + 21 23 urepo := e.Get("repo").(*models.RepoActor) 22 24 23 25 b, err := io.ReadAll(e.Request().Body) ··· 63 65 return helpers.ServerError(e, nil) 64 66 } 65 67 66 - tx := s.db.BeginDangerously() 68 + tx := s.db.BeginDangerously(ctx) 67 69 68 70 clock := syntax.NewTIDClock(0) 69 71
+7 -3
server/handle_oauth_authorize.go
··· 13 13 ) 14 14 15 15 func (s *Server) handleOauthAuthorizeGet(e echo.Context) error { 16 + ctx := e.Request().Context() 17 + 16 18 reqUri := e.QueryParam("request_uri") 17 19 if reqUri == "" { 18 20 // render page for logged out dev ··· 38 40 } 39 41 40 42 var req provider.OauthAuthorizationRequest 41 - if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { 43 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { 42 44 return helpers.ServerError(e, to.StringPtr(err.Error())) 43 45 } 44 46 ··· 72 74 } 73 75 74 76 func (s *Server) handleOauthAuthorizePost(e echo.Context) error { 77 + ctx := e.Request().Context() 78 + 75 79 repo, _, err := s.getSessionRepoOrErr(e) 76 80 if err != nil { 77 81 return e.Redirect(303, "/account/signin") ··· 89 93 } 90 94 91 95 var authReq provider.OauthAuthorizationRequest 92 - if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { 96 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { 93 97 return helpers.ServerError(e, to.StringPtr(err.Error())) 94 98 } 95 99 ··· 113 117 114 118 code := oauth.GenerateCode() 115 119 116 - if err := s.db.Exec("UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { 120 + if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { 117 121 s.logger.Error("error updating authorization request", "error", err) 118 122 return helpers.ServerError(e, nil) 119 123 }
+3 -1
server/handle_oauth_par.go
··· 19 19 } 20 20 21 21 func (s *Server) handleOauthPar(e echo.Context) error { 22 + ctx := e.Request().Context() 23 + 22 24 var parRequest provider.ParRequest 23 25 if err := e.Bind(&parRequest); err != nil { 24 26 s.logger.Error("error binding for par request", "error", err) ··· 86 88 ExpiresAt: eat, 87 89 } 88 90 89 - if err := s.db.Create(authRequest, nil).Error; err != nil { 91 + if err := s.db.Create(ctx, authRequest, nil).Error; err != nil { 90 92 s.logger.Error("error creating auth request in db", "error", err) 91 93 return helpers.ServerError(e, nil) 92 94 }
+7 -5
server/handle_oauth_token.go
··· 38 38 } 39 39 40 40 func (s *Server) handleOauthToken(e echo.Context) error { 41 + ctx := e.Request().Context() 42 + 41 43 var req OauthTokenRequest 42 44 if err := e.Bind(&req); err != nil { 43 45 s.logger.Error("error binding token request", "error", err) ··· 84 86 85 87 var authReq provider.OauthAuthorizationRequest 86 88 // get the lil guy and delete him 87 - if err := s.db.Raw("DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil { 89 + if err := s.db.Raw(ctx, "DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil { 88 90 s.logger.Error("error finding authorization request", "error", err) 89 91 return helpers.ServerError(e, nil) 90 92 } ··· 128 130 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided")) 129 131 } 130 132 131 - repo, err := s.getRepoActorByDid(*authReq.Sub) 133 + repo, err := s.getRepoActorByDid(ctx, *authReq.Sub) 132 134 if err != nil { 133 135 helpers.InputError(e, to.StringPtr("unable to find actor")) 134 136 } ··· 159 161 return err 160 162 } 161 163 162 - if err := s.db.Create(&provider.OauthToken{ 164 + if err := s.db.Create(ctx, &provider.OauthToken{ 163 165 ClientId: authReq.ClientId, 164 166 ClientAuth: *clientAuth, 165 167 Parameters: authReq.Parameters, ··· 199 201 } 200 202 201 203 var oauthToken provider.OauthToken 202 - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil { 204 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil { 203 205 s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken) 204 206 return helpers.ServerError(e, nil) 205 207 } ··· 257 259 return err 258 260 } 259 261 260 - if err := s.db.Exec("UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil { 262 + if err := s.db.Exec(ctx, "UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil { 261 263 s.logger.Error("error updating token", "error", err) 262 264 return helpers.ServerError(e, nil) 263 265 }
+4 -2
server/handle_repo_describe_repo.go
··· 20 20 } 21 21 22 22 func (s *Server) handleDescribeRepo(e echo.Context) error { 23 + ctx := e.Request().Context() 24 + 23 25 did := e.QueryParam("repo") 24 - repo, err := s.getRepoActorByDid(did) 26 + repo, err := s.getRepoActorByDid(ctx, did) 25 27 if err != nil { 26 28 if err == gorm.ErrRecordNotFound { 27 29 return helpers.InputError(e, to.StringPtr("RepoNotFound")) ··· 64 66 } 65 67 66 68 var records []models.Record 67 - if err := s.db.Raw("SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil { 69 + if err := s.db.Raw(ctx, "SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil { 68 70 s.logger.Error("error getting collections", "error", err) 69 71 return helpers.ServerError(e, nil) 70 72 }
+3 -1
server/handle_repo_get_record.go
··· 14 14 } 15 15 16 16 func (s *Server) handleRepoGetRecord(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + 17 19 repo := e.QueryParam("repo") 18 20 collection := e.QueryParam("collection") 19 21 rkey := e.QueryParam("rkey") ··· 32 34 } 33 35 34 36 var record models.Record 35 - if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil { 37 + if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil { 36 38 // TODO: handle error nicely 37 39 return err 38 40 }
+4 -2
server/handle_repo_list_missing_blobs.go
··· 22 22 } 23 23 24 24 func (s *Server) handleListMissingBlobs(e echo.Context) error { 25 + ctx := e.Request().Context() 26 + 25 27 urepo := e.Get("repo").(*models.RepoActor) 26 28 27 29 limitStr := e.QueryParam("limit") ··· 35 37 } 36 38 37 39 var records []models.Record 38 - if err := s.db.Raw("SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil { 40 + if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil { 39 41 s.logger.Error("failed to get records for listMissingBlobs", "error", err) 40 42 return helpers.ServerError(e, nil) 41 43 } ··· 69 71 } 70 72 71 73 var count int64 72 - if err := s.db.Raw("SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil { 74 + if err := s.db.Raw(ctx, "SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil { 73 75 continue 74 76 } 75 77
+4 -2
server/handle_repo_list_records.go
··· 46 46 } 47 47 48 48 func (s *Server) handleListRecords(e echo.Context) error { 49 + ctx := e.Request().Context() 50 + 49 51 var req ComAtprotoRepoListRecordsRequest 50 52 if err := e.Bind(&req); err != nil { 51 53 s.logger.Error("could not bind list records request", "error", err) ··· 78 80 79 81 did := req.Repo 80 82 if _, err := syntax.ParseDID(did); err != nil { 81 - actor, err := s.getActorByHandle(req.Repo) 83 + actor, err := s.getActorByHandle(ctx, req.Repo) 82 84 if err != nil { 83 85 return helpers.InputError(e, to.StringPtr("RepoNotFound")) 84 86 } ··· 93 95 params = append(params, limit) 94 96 95 97 var records []models.Record 96 - if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil { 98 + if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil { 97 99 s.logger.Error("error getting records", "error", err) 98 100 return helpers.ServerError(e, nil) 99 101 }
+3 -1
server/handle_repo_list_repos.go
··· 21 21 22 22 // TODO: paginate this bitch 23 23 func (s *Server) handleListRepos(e echo.Context) error { 24 + ctx := e.Request().Context() 25 + 24 26 var repos []models.Repo 25 - if err := s.db.Raw("SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil { 27 + if err := s.db.Raw(ctx, "SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil { 26 28 return err 27 29 } 28 30
+5 -3
server/handle_repo_upload_blob.go
··· 32 32 } 33 33 34 34 func (s *Server) handleRepoUploadBlob(e echo.Context) error { 35 + ctx := e.Request().Context() 36 + 35 37 urepo := e.Get("repo").(*models.RepoActor) 36 38 37 39 mime := e.Request().Header.Get("content-type") ··· 51 53 Storage: storage, 52 54 } 53 55 54 - if err := s.db.Create(&blob, nil).Error; err != nil { 56 + if err := s.db.Create(ctx, &blob, nil).Error; err != nil { 55 57 s.logger.Error("error creating new blob in db", "error", err) 56 58 return helpers.ServerError(e, nil) 57 59 } ··· 84 86 Data: data, 85 87 } 86 88 87 - if err := s.db.Create(&blobPart, nil).Error; err != nil { 89 + if err := s.db.Create(ctx, &blobPart, nil).Error; err != nil { 88 90 s.logger.Error("error adding blob part to db", "error", err) 89 91 return helpers.ServerError(e, nil) 90 92 } ··· 131 133 } 132 134 } 133 135 134 - if err := s.db.Exec("UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil { 136 + if err := s.db.Exec(ctx, "UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil { 135 137 // there should probably be somme handling here if this fails... 136 138 s.logger.Error("error updating blob", "error", err) 137 139 return helpers.ServerError(e, nil)
+3 -1
server/handle_server_activate_account.go
··· 18 18 } 19 19 20 20 func (s *Server) handleServerActivateAccount(e echo.Context) error { 21 + ctx := e.Request().Context() 22 + 21 23 var req ComAtprotoServerDeactivateAccountRequest 22 24 if err := e.Bind(&req); err != nil { 23 25 s.logger.Error("error binding", "error", err) ··· 26 28 27 29 urepo := e.Get("repo").(*models.RepoActor) 28 30 29 - if err := s.db.Exec("UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil { 31 + if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil { 30 32 s.logger.Error("error updating account status to deactivated", "error", err) 31 33 return helpers.ServerError(e, nil) 32 34 }
+5 -3
server/handle_server_check_account_status.go
··· 20 20 } 21 21 22 22 func (s *Server) handleServerCheckAccountStatus(e echo.Context) error { 23 + ctx := e.Request().Context() 24 + 23 25 urepo := e.Get("repo").(*models.RepoActor) 24 26 25 27 resp := ComAtprotoServerCheckAccountStatusResponse{ ··· 41 43 } 42 44 43 45 var blockCtResp CountResp 44 - if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil { 46 + if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil { 45 47 s.logger.Error("error getting block count", "error", err) 46 48 return helpers.ServerError(e, nil) 47 49 } 48 50 resp.RepoBlocks = blockCtResp.Ct 49 51 50 52 var recCtResp CountResp 51 - if err := s.db.Raw("SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil { 53 + if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil { 52 54 s.logger.Error("error getting record count", "error", err) 53 55 return helpers.ServerError(e, nil) 54 56 } 55 57 resp.IndexedRecords = recCtResp.Ct 56 58 57 59 var blobCtResp CountResp 58 - if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil { 60 + if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil { 59 61 s.logger.Error("error getting record count", "error", err) 60 62 return helpers.ServerError(e, nil) 61 63 }
+3 -1
server/handle_server_confirm_email.go
··· 15 15 } 16 16 17 17 func (s *Server) handleServerConfirmEmail(e echo.Context) error { 18 + ctx := e.Request().Context() 19 + 18 20 urepo := e.Get("repo").(*models.RepoActor) 19 21 20 22 var req ComAtprotoServerConfirmEmailRequest ··· 41 43 42 44 now := time.Now().UTC() 43 45 44 - if err := s.db.Exec("UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil { 46 + if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil { 45 47 s.logger.Error("error updating user", "error", err) 46 48 return helpers.ServerError(e, nil) 47 49 }
+16 -14
server/handle_server_create_account.go
··· 36 36 } 37 37 38 38 func (s *Server) handleCreateAccount(e echo.Context) error { 39 + ctx := e.Request().Context() 40 + 39 41 var request ComAtprotoServerCreateAccountRequest 40 42 41 43 if err := e.Bind(&request); err != nil { ··· 68 70 } 69 71 } 70 72 } 71 - 73 + 72 74 var signupDid string 73 75 if request.Did != nil { 74 - signupDid = *request.Did; 75 - 76 + signupDid = *request.Did 77 + 76 78 token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1)) 77 79 if token == "" { 78 80 return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did")) ··· 90 92 } 91 93 92 94 // see if the handle is already taken 93 - actor, err := s.getActorByHandle(request.Handle) 95 + actor, err := s.getActorByHandle(ctx, request.Handle) 94 96 if err != nil && err != gorm.ErrRecordNotFound { 95 97 s.logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err) 96 98 return helpers.ServerError(e, nil) ··· 109 111 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 110 112 } 111 113 112 - if err := s.db.Raw("SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 114 + if err := s.db.Raw(ctx, "SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 113 115 if err == gorm.ErrRecordNotFound { 114 116 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 115 117 } ··· 123 125 } 124 126 125 127 // see if the email is already taken 126 - existingRepo, err := s.getRepoByEmail(request.Email) 128 + existingRepo, err := s.getRepoByEmail(ctx, request.Email) 127 129 if err != nil && err != gorm.ErrRecordNotFound { 128 130 s.logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err) 129 131 return helpers.ServerError(e, nil) ··· 137 139 var k *atcrypto.PrivateKeyK256 138 140 139 141 if signupDid != "" { 140 - reservedKey, err := s.getReservedKey(signupDid) 142 + reservedKey, err := s.getReservedKey(ctx, signupDid) 141 143 if err != nil { 142 144 s.logger.Error("error looking up reserved key", "error", err) 143 145 } ··· 148 150 k = nil 149 151 } else { 150 152 defer func() { 151 - if delErr := s.deleteReservedKey(reservedKey.KeyDid, reservedKey.Did); delErr != nil { 153 + if delErr := s.deleteReservedKey(ctx, reservedKey.KeyDid, reservedKey.Did); delErr != nil { 152 154 s.logger.Error("error deleting reserved key", "error", delErr) 153 155 } 154 156 }() ··· 199 201 Handle: request.Handle, 200 202 } 201 203 202 - if err := s.db.Create(&urepo, nil).Error; err != nil { 204 + if err := s.db.Create(ctx, &urepo, nil).Error; err != nil { 203 205 s.logger.Error("error inserting new repo", "error", err) 204 206 return helpers.ServerError(e, nil) 205 207 } 206 - 207 - if err := s.db.Create(&actor, nil).Error; err != nil { 208 + 209 + if err := s.db.Create(ctx, &actor, nil).Error; err != nil { 208 210 s.logger.Error("error inserting new actor", "error", err) 209 211 return helpers.ServerError(e, nil) 210 212 } 211 213 } else { 212 - if err := s.db.Save(&actor, nil).Error; err != nil { 214 + if err := s.db.Save(ctx, &actor, nil).Error; err != nil { 213 215 s.logger.Error("error inserting new actor", "error", err) 214 216 return helpers.ServerError(e, nil) 215 217 } ··· 241 243 } 242 244 243 245 if s.config.RequireInvite { 244 - if err := s.db.Raw("UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 246 + if err := s.db.Raw(ctx, "UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 245 247 s.logger.Error("error decrementing use count", "error", err) 246 248 return helpers.ServerError(e, nil) 247 249 } 248 250 } 249 251 250 - sess, err := s.createSession(&urepo) 252 + sess, err := s.createSession(ctx, &urepo) 251 253 if err != nil { 252 254 s.logger.Error("error creating new session", "error", err) 253 255 return helpers.ServerError(e, nil)
+3 -1
server/handle_server_create_invite_code.go
··· 17 17 } 18 18 19 19 func (s *Server) handleCreateInviteCode(e echo.Context) error { 20 + ctx := e.Request().Context() 21 + 20 22 var req ComAtprotoServerCreateInviteCodeRequest 21 23 if err := e.Bind(&req); err != nil { 22 24 s.logger.Error("error binding", "error", err) ··· 37 39 acc = *req.ForAccount 38 40 } 39 41 40 - if err := s.db.Create(&models.InviteCode{ 42 + if err := s.db.Create(ctx, &models.InviteCode{ 41 43 Code: ic, 42 44 Did: acc, 43 45 RemainingUseCount: req.UseCount,
+3 -1
server/handle_server_create_invite_codes.go
··· 22 22 } 23 23 24 24 func (s *Server) handleCreateInviteCodes(e echo.Context) error { 25 + ctx := e.Request().Context() 26 + 25 27 var req ComAtprotoServerCreateInviteCodesRequest 26 28 if err := e.Bind(&req); err != nil { 27 29 s.logger.Error("error binding", "error", err) ··· 50 52 ic := uuid.NewString() 51 53 ics = append(ics, ic) 52 54 53 - if err := s.db.Create(&models.InviteCode{ 55 + if err := s.db.Create(ctx, &models.InviteCode{ 54 56 Code: ic, 55 57 Did: did, 56 58 RemainingUseCount: req.UseCount,
+6 -4
server/handle_server_create_session.go
··· 32 32 } 33 33 34 34 func (s *Server) handleCreateSession(e echo.Context) error { 35 + ctx := e.Request().Context() 36 + 35 37 var req ComAtprotoServerCreateSessionRequest 36 38 if err := e.Bind(&req); err != nil { 37 39 s.logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err) ··· 65 67 var err error 66 68 switch idtype { 67 69 case "did": 68 - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error 70 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error 69 71 case "handle": 70 - err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error 72 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error 71 73 case "email": 72 - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error 74 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error 73 75 } 74 76 75 77 if err != nil { ··· 88 90 return helpers.InputError(e, to.StringPtr("InvalidRequest")) 89 91 } 90 92 91 - sess, err := s.createSession(&repo.Repo) 93 + sess, err := s.createSession(ctx, &repo.Repo) 92 94 if err != nil { 93 95 s.logger.Error("error creating session", "error", err) 94 96 return helpers.ServerError(e, nil)
+3 -1
server/handle_server_deactivate_account.go
··· 19 19 } 20 20 21 21 func (s *Server) handleServerDeactivateAccount(e echo.Context) error { 22 + ctx := e.Request().Context() 23 + 22 24 var req ComAtprotoServerDeactivateAccountRequest 23 25 if err := e.Bind(&req); err != nil { 24 26 s.logger.Error("error binding", "error", err) ··· 27 29 28 30 urepo := e.Get("repo").(*models.RepoActor) 29 31 30 - if err := s.db.Exec("UPDATE repos SET deactivated = ? WHERE did = ?", nil, true, urepo.Repo.Did).Error; err != nil { 32 + if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, true, urepo.Repo.Did).Error; err != nil { 31 33 s.logger.Error("error updating account status to deactivated", "error", err) 32 34 return helpers.ServerError(e, nil) 33 35 }
+4 -2
server/handle_server_delete_account.go
··· 20 20 } 21 21 22 22 func (s *Server) handleServerDeleteAccount(e echo.Context) error { 23 + ctx := e.Request().Context() 24 + 23 25 var req ComAtprotoServerDeleteAccountRequest 24 26 if err := e.Bind(&req); err != nil { 25 27 s.logger.Error("error binding", "error", err) ··· 31 33 return helpers.ServerError(e, nil) 32 34 } 33 35 34 - urepo, err := s.getRepoActorByDid(req.Did) 36 + urepo, err := s.getRepoActorByDid(ctx, req.Did) 35 37 if err != nil { 36 38 s.logger.Error("error getting repo", "error", err) 37 39 return echo.NewHTTPError(400, "account not found") ··· 66 68 }) 67 69 } 68 70 69 - tx := s.db.BeginDangerously() 71 + tx := s.db.BeginDangerously(ctx) 70 72 if tx.Error != nil { 71 73 s.logger.Error("error starting transaction", "error", tx.Error) 72 74 return helpers.ServerError(e, nil)
+4 -2
server/handle_server_delete_session.go
··· 7 7 ) 8 8 9 9 func (s *Server) handleDeleteSession(e echo.Context) error { 10 + ctx := e.Request().Context() 11 + 10 12 token := e.Get("token").(string) 11 13 12 14 var acctok models.Token 13 - if err := s.db.Raw("DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil { 15 + if err := s.db.Raw(ctx, "DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil { 14 16 s.logger.Error("error deleting access token from db", "error", err) 15 17 return helpers.ServerError(e, nil) 16 18 } 17 19 18 - if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil { 20 + if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil { 19 21 s.logger.Error("error deleting refresh token from db", "error", err) 20 22 return helpers.ServerError(e, nil) 21 23 }
+5 -3
server/handle_server_refresh_session.go
··· 16 16 } 17 17 18 18 func (s *Server) handleRefreshSession(e echo.Context) error { 19 + ctx := e.Request().Context() 20 + 19 21 token := e.Get("token").(string) 20 22 repo := e.Get("repo").(*models.RepoActor) 21 23 22 - if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil { 24 + if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil { 23 25 s.logger.Error("error getting refresh token from db", "error", err) 24 26 return helpers.ServerError(e, nil) 25 27 } 26 28 27 - if err := s.db.Exec("DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil { 29 + if err := s.db.Exec(ctx, "DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil { 28 30 s.logger.Error("error deleting access token from db", "error", err) 29 31 return helpers.ServerError(e, nil) 30 32 } 31 33 32 - sess, err := s.createSession(&repo.Repo) 34 + sess, err := s.createSession(ctx, &repo.Repo) 33 35 if err != nil { 34 36 s.logger.Error("error creating new session for refresh", "error", err) 35 37 return helpers.ServerError(e, nil)
+3 -1
server/handle_server_request_account_delete.go
··· 10 10 ) 11 11 12 12 func (s *Server) handleServerRequestAccountDelete(e echo.Context) error { 13 + ctx := e.Request().Context() 14 + 13 15 urepo := e.Get("repo").(*models.RepoActor) 14 16 15 17 token := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 16 18 expiresAt := time.Now().UTC().Add(15 * time.Minute) 17 19 18 - if err := s.db.Exec("UPDATE repos SET account_delete_code = ?, account_delete_code_expires_at = ? WHERE did = ?", nil, token, expiresAt, urepo.Repo.Did).Error; err != nil { 20 + if err := s.db.Exec(ctx, "UPDATE repos SET account_delete_code = ?, account_delete_code_expires_at = ? WHERE did = ?", nil, token, expiresAt, urepo.Repo.Did).Error; err != nil { 19 21 s.logger.Error("error setting deletion token", "error", err) 20 22 return helpers.ServerError(e, nil) 21 23 }
+3 -1
server/handle_server_request_email_confirmation.go
··· 11 11 ) 12 12 13 13 func (s *Server) handleServerRequestEmailConfirmation(e echo.Context) error { 14 + ctx := e.Request().Context() 15 + 14 16 urepo := e.Get("repo").(*models.RepoActor) 15 17 16 18 if urepo.EmailConfirmedAt != nil { ··· 20 22 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 21 23 eat := time.Now().Add(10 * time.Minute).UTC() 22 24 23 - if err := s.db.Exec("UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 25 + if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 24 26 s.logger.Error("error updating user", "error", err) 25 27 return helpers.ServerError(e, nil) 26 28 }
+3 -1
server/handle_server_request_email_update.go
··· 14 14 } 15 15 16 16 func (s *Server) handleServerRequestEmailUpdate(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + 17 19 urepo := e.Get("repo").(*models.RepoActor) 18 20 19 21 if urepo.EmailConfirmedAt != nil { 20 22 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 21 23 eat := time.Now().Add(10 * time.Minute).UTC() 22 24 23 - if err := s.db.Exec("UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 25 + if err := s.db.Exec(ctx, "UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 24 26 s.logger.Error("error updating repo", "error", err) 25 27 return helpers.ServerError(e, nil) 26 28 }
+4 -2
server/handle_server_request_password_reset.go
··· 14 14 } 15 15 16 16 func (s *Server) handleServerRequestPasswordReset(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + 17 19 urepo, ok := e.Get("repo").(*models.RepoActor) 18 20 if !ok { 19 21 var req ComAtprotoServerRequestPasswordResetRequest ··· 25 27 return err 26 28 } 27 29 28 - murepo, err := s.getRepoActorByEmail(req.Email) 30 + murepo, err := s.getRepoActorByEmail(ctx, req.Email) 29 31 if err != nil { 30 32 return err 31 33 } ··· 36 38 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 37 39 eat := time.Now().Add(10 * time.Minute).UTC() 38 40 39 - if err := s.db.Exec("UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 41 + if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 40 42 s.logger.Error("error updating repo", "error", err) 41 43 return helpers.ServerError(e, nil) 42 44 }
+11 -8
server/handle_server_reserve_signing_key.go
··· 1 1 package server 2 2 3 3 import ( 4 + "context" 4 5 "time" 5 6 6 7 "github.com/bluesky-social/indigo/atproto/atcrypto" ··· 18 19 } 19 20 20 21 func (s *Server) handleServerReserveSigningKey(e echo.Context) error { 22 + ctx := e.Request().Context() 23 + 21 24 var req ServerReserveSigningKeyRequest 22 25 if err := e.Bind(&req); err != nil { 23 26 s.logger.Error("could not bind reserve signing key request", "error", err) ··· 26 29 27 30 if req.Did != nil && *req.Did != "" { 28 31 var existing models.ReservedKey 29 - if err := s.db.Raw("SELECT * FROM reserved_keys WHERE did = ?", nil, *req.Did).Scan(&existing).Error; err == nil && existing.KeyDid != "" { 32 + if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE did = ?", nil, *req.Did).Scan(&existing).Error; err == nil && existing.KeyDid != "" { 30 33 return e.JSON(200, ServerReserveSigningKeyResponse{ 31 34 SigningKey: existing.KeyDid, 32 35 }) ··· 54 57 CreatedAt: time.Now(), 55 58 } 56 59 57 - if err := s.db.Create(&reservedKey, nil).Error; err != nil { 60 + if err := s.db.Create(ctx, &reservedKey, nil).Error; err != nil { 58 61 s.logger.Error("error storing reserved key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err) 59 62 return helpers.ServerError(e, nil) 60 63 } ··· 66 69 }) 67 70 } 68 71 69 - func (s *Server) getReservedKey(keyDidOrDid string) (*models.ReservedKey, error) { 72 + func (s *Server) getReservedKey(ctx context.Context, keyDidOrDid string) (*models.ReservedKey, error) { 70 73 var reservedKey models.ReservedKey 71 74 72 - if err := s.db.Raw("SELECT * FROM reserved_keys WHERE key_did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { 75 + if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE key_did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { 73 76 return &reservedKey, nil 74 77 } 75 78 76 - if err := s.db.Raw("SELECT * FROM reserved_keys WHERE did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { 79 + if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { 77 80 return &reservedKey, nil 78 81 } 79 82 80 83 return nil, nil 81 84 } 82 85 83 - func (s *Server) deleteReservedKey(keyDid string, did *string) error { 84 - if err := s.db.Exec("DELETE FROM reserved_keys WHERE key_did = ?", nil, keyDid).Error; err != nil { 86 + func (s *Server) deleteReservedKey(ctx context.Context, keyDid string, did *string) error { 87 + if err := s.db.Exec(ctx, "DELETE FROM reserved_keys WHERE key_did = ?", nil, keyDid).Error; err != nil { 85 88 return err 86 89 } 87 90 88 91 if did != nil && *did != "" { 89 - if err := s.db.Exec("DELETE FROM reserved_keys WHERE did = ?", nil, *did).Error; err != nil { 92 + if err := s.db.Exec(ctx, "DELETE FROM reserved_keys WHERE did = ?", nil, *did).Error; err != nil { 90 93 return err 91 94 } 92 95 }
+3 -1
server/handle_server_reset_password.go
··· 16 16 } 17 17 18 18 func (s *Server) handleServerResetPassword(e echo.Context) error { 19 + ctx := e.Request().Context() 20 + 19 21 urepo := e.Get("repo").(*models.RepoActor) 20 22 21 23 var req ComAtprotoServerResetPasswordRequest ··· 46 48 return helpers.ServerError(e, nil) 47 49 } 48 50 49 - if err := s.db.Exec("UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil { 51 + if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil { 50 52 s.logger.Error("error updating repo", "error", err) 51 53 return helpers.ServerError(e, nil) 52 54 }
+3 -1
server/handle_server_update_email.go
··· 15 15 } 16 16 17 17 func (s *Server) handleServerUpdateEmail(e echo.Context) error { 18 + ctx := e.Request().Context() 19 + 18 20 urepo := e.Get("repo").(*models.RepoActor) 19 21 20 22 var req ComAtprotoServerUpdateEmailRequest ··· 39 41 return helpers.ExpiredTokenError(e) 40 42 } 41 43 42 - if err := s.db.Exec("UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil { 44 + if err := s.db.Exec(ctx, "UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil { 43 45 s.logger.Error("error updating repo", "error", err) 44 46 return helpers.ServerError(e, nil) 45 47 }
+5 -3
server/handle_sync_get_blob.go
··· 17 17 ) 18 18 19 19 func (s *Server) handleSyncGetBlob(e echo.Context) error { 20 + ctx := e.Request().Context() 21 + 20 22 did := e.QueryParam("did") 21 23 if did == "" { 22 24 return helpers.InputError(e, nil) ··· 32 34 return helpers.InputError(e, nil) 33 35 } 34 36 35 - urepo, err := s.getRepoActorByDid(did) 37 + urepo, err := s.getRepoActorByDid(ctx, did) 36 38 if err != nil { 37 39 s.logger.Error("could not find user for requested blob", "error", err) 38 40 return helpers.InputError(e, nil) ··· 46 48 } 47 49 48 50 var blob models.Blob 49 - if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil { 51 + if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil { 50 52 s.logger.Error("error looking up blob", "error", err) 51 53 return helpers.ServerError(e, nil) 52 54 } ··· 55 57 56 58 if blob.Storage == "sqlite" { 57 59 var parts []models.BlobPart 58 - if err := s.db.Raw("SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil { 60 + if err := s.db.Raw(ctx, "SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil { 59 61 s.logger.Error("error getting blob parts", "error", err) 60 62 return helpers.ServerError(e, nil) 61 63 }
+1 -1
server/handle_sync_get_blocks.go
··· 35 35 cids = append(cids, c) 36 36 } 37 37 38 - urepo, err := s.getRepoActorByDid(req.Did) 38 + urepo, err := s.getRepoActorByDid(ctx, req.Did) 39 39 if err != nil { 40 40 return helpers.ServerError(e, nil) 41 41 }
+3 -1
server/handle_sync_get_latest_commit.go
··· 12 12 } 13 13 14 14 func (s *Server) handleSyncGetLatestCommit(e echo.Context) error { 15 + ctx := e.Request().Context() 16 + 15 17 did := e.QueryParam("did") 16 18 if did == "" { 17 19 return helpers.InputError(e, nil) 18 20 } 19 21 20 - urepo, err := s.getRepoActorByDid(did) 22 + urepo, err := s.getRepoActorByDid(ctx, did) 21 23 if err != nil { 22 24 return err 23 25 }
+1 -1
server/handle_sync_get_record.go
··· 20 20 rkey := e.QueryParam("rkey") 21 21 22 22 var urepo models.Repo 23 - if err := s.db.Raw("SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil { 23 + if err := s.db.Raw(ctx, "SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil { 24 24 s.logger.Error("error getting repo", "error", err) 25 25 return helpers.ServerError(e, nil) 26 26 }
+4 -2
server/handle_sync_get_repo.go
··· 13 13 ) 14 14 15 15 func (s *Server) handleSyncGetRepo(e echo.Context) error { 16 + ctx := e.Request().Context() 17 + 16 18 did := e.QueryParam("did") 17 19 if did == "" { 18 20 return helpers.InputError(e, nil) 19 21 } 20 22 21 - urepo, err := s.getRepoActorByDid(did) 23 + urepo, err := s.getRepoActorByDid(ctx, did) 22 24 if err != nil { 23 25 return err 24 26 } ··· 41 43 } 42 44 43 45 var blocks []models.Block 44 - if err := s.db.Raw("SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil { 46 + if err := s.db.Raw(ctx, "SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil { 45 47 return err 46 48 } 47 49
+3 -1
server/handle_sync_get_repo_status.go
··· 14 14 15 15 // TODO: make this actually do the right thing 16 16 func (s *Server) handleSyncGetRepoStatus(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + 17 19 did := e.QueryParam("did") 18 20 if did == "" { 19 21 return helpers.InputError(e, nil) 20 22 } 21 23 22 - urepo, err := s.getRepoActorByDid(did) 24 + urepo, err := s.getRepoActorByDid(ctx, did) 23 25 if err != nil { 24 26 return err 25 27 }
+4 -2
server/handle_sync_list_blobs.go
··· 14 14 } 15 15 16 16 func (s *Server) handleSyncListBlobs(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + 17 19 did := e.QueryParam("did") 18 20 if did == "" { 19 21 return helpers.InputError(e, nil) ··· 35 37 } 36 38 params = append(params, limit) 37 39 38 - urepo, err := s.getRepoActorByDid(did) 40 + urepo, err := s.getRepoActorByDid(ctx, did) 39 41 if err != nil { 40 42 s.logger.Error("could not find user for requested blobs", "error", err) 41 43 return helpers.InputError(e, nil) ··· 49 51 } 50 52 51 53 var blobs []models.Blob 52 - if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil { 54 + if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil { 53 55 s.logger.Error("error getting records", "error", err) 54 56 return helpers.ServerError(e, nil) 55 57 }
+3 -1
server/handle_well_known.go
··· 67 67 } 68 68 69 69 func (s *Server) handleAtprotoDid(e echo.Context) error { 70 + ctx := e.Request().Context() 71 + 70 72 host := e.Request().Host 71 73 if host == "" { 72 74 return helpers.InputError(e, to.StringPtr("Invalid handle.")) ··· 84 86 return e.NoContent(404) 85 87 } 86 88 87 - actor, err := s.getActorByHandle(host) 89 + actor, err := s.getActorByHandle(ctx, host) 88 90 if err != nil { 89 91 if err == gorm.ErrRecordNotFound { 90 92 return e.NoContent(404)
+9 -5
server/middleware.go
··· 37 37 38 38 func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 39 39 return func(e echo.Context) error { 40 + ctx := e.Request().Context() 41 + 40 42 authheader := e.Request().Header.Get("authorization") 41 43 if authheader == "" { 42 44 return e.JSON(401, map[string]string{"error": "Unauthorized"}) ··· 78 80 } 79 81 did = maybeDid 80 82 81 - maybeRepo, err := s.getRepoActorByDid(did) 83 + maybeRepo, err := s.getRepoActorByDid(ctx, did) 82 84 if err != nil { 83 85 s.logger.Error("error fetching repo", "error", err) 84 86 return helpers.ServerError(e, nil) ··· 159 161 Found bool 160 162 } 161 163 var result Result 162 - if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 164 + if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 163 165 if err == gorm.ErrRecordNotFound { 164 166 return helpers.InvalidTokenError(e) 165 167 } ··· 184 186 } 185 187 186 188 if repo == nil { 187 - maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) 189 + maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string)) 188 190 if err != nil { 189 191 s.logger.Error("error fetching repo", "error", err) 190 192 return helpers.ServerError(e, nil) ··· 207 209 208 210 func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 209 211 return func(e echo.Context) error { 212 + ctx := e.Request().Context() 213 + 210 214 authheader := e.Request().Header.Get("authorization") 211 215 if authheader == "" { 212 216 return e.JSON(401, map[string]string{"error": "Unauthorized"}) ··· 243 247 } 244 248 245 249 var oauthToken provider.OauthToken 246 - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 250 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 247 251 s.logger.Error("error finding access token in db", "error", err) 248 252 return helpers.InputError(e, nil) 249 253 } ··· 266 270 }) 267 271 } 268 272 269 - repo, err := s.getRepoActorByDid(oauthToken.Sub) 273 + repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub) 270 274 if err != nil { 271 275 s.logger.Error("could not find actor in db", "error", err) 272 276 return helpers.ServerError(e, nil)
+11 -11
server/repo.go
··· 181 181 case OpTypeDelete: 182 182 // try to find the old record in the database 183 183 var old models.Record 184 - if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { 184 + if err := rm.db.Raw(ctx, "SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { 185 185 return nil, err 186 186 } 187 187 ··· 323 323 var cids []cid.Cid 324 324 // whenever there is cid present, we know it's a create (dumb) 325 325 if entry.Cid != "" { 326 - if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{ 326 + if err := rm.s.db.Create(ctx, &entry, []clause.Expression{clause.OnConflict{ 327 327 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}}, 328 328 UpdateAll: true, 329 329 }}).Error; err != nil { ··· 331 331 } 332 332 333 333 // increment the given blob refs, yay 334 - cids, err = rm.incrementBlobRefs(urepo, entry.Value) 334 + cids, err = rm.incrementBlobRefs(ctx, urepo, entry.Value) 335 335 if err != nil { 336 336 return nil, err 337 337 } ··· 339 339 // as i noted above this is dumb. but we delete whenever the cid is nil. it works solely becaue the pkey 340 340 // is did + collection + rkey. i still really want to separate that out, or use a different type to make 341 341 // this less confusing/easy to read. alas, its 2 am and yea no 342 - if err := rm.s.db.Delete(&entry, nil).Error; err != nil { 342 + if err := rm.s.db.Delete(ctx, &entry, nil).Error; err != nil { 343 343 return nil, err 344 344 } 345 345 346 346 // TODO: 347 - cids, err = rm.decrementBlobRefs(urepo, entry.Value) 347 + cids, err = rm.decrementBlobRefs(ctx, urepo, entry.Value) 348 348 if err != nil { 349 349 return nil, err 350 350 } ··· 411 411 return c, bs.GetReadLog(), nil 412 412 } 413 413 414 - func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 414 + func (rm *RepoMan) incrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 415 415 cids, err := getBlobCidsFromCbor(cbor) 416 416 if err != nil { 417 417 return nil, err 418 418 } 419 419 420 420 for _, c := range cids { 421 - if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil { 421 + if err := rm.db.Exec(ctx, "UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil { 422 422 return nil, err 423 423 } 424 424 } ··· 426 426 return cids, nil 427 427 } 428 428 429 - func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 429 + func (rm *RepoMan) decrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 430 430 cids, err := getBlobCidsFromCbor(cbor) 431 431 if err != nil { 432 432 return nil, err ··· 437 437 ID uint 438 438 Count int 439 439 } 440 - if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { 440 + if err := rm.db.Raw(ctx, "UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { 441 441 return nil, err 442 442 } 443 443 444 444 // TODO: this does _not_ handle deletions of blobs that are on s3 storage!!!! we need to get the blob, see what 445 445 // storage it is in, and clean up s3!!!! 446 446 if res.Count == 0 { 447 - if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil { 447 + if err := rm.db.Exec(ctx, "DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil { 448 448 return nil, err 449 449 } 450 - if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil { 450 + if err := rm.db.Exec(ctx, "DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil { 451 451 return nil, err 452 452 } 453 453 }
+1 -1
server/server.go
··· 729 729 } 730 730 731 731 func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 732 - if err := s.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 732 + if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 733 733 return err 734 734 } 735 735
+4 -3
server/session.go
··· 1 1 package server 2 2 3 3 import ( 4 + "context" 4 5 "time" 5 6 6 7 "github.com/golang-jwt/jwt/v4" ··· 13 14 RefreshToken string 14 15 } 15 16 16 - func (s *Server) createSession(repo *models.Repo) (*Session, error) { 17 + func (s *Server) createSession(ctx context.Context, repo *models.Repo) (*Session, error) { 17 18 now := time.Now() 18 19 accexp := now.Add(3 * time.Hour) 19 20 refexp := now.Add(7 * 24 * time.Hour) ··· 49 50 return nil, err 50 51 } 51 52 52 - if err := s.db.Create(&models.Token{ 53 + if err := s.db.Create(ctx, &models.Token{ 53 54 Token: accessString, 54 55 Did: repo.Did, 55 56 RefreshToken: refreshString, ··· 59 60 return nil, err 60 61 } 61 62 62 - if err := s.db.Create(&models.RefreshToken{ 63 + if err := s.db.Create(ctx, &models.RefreshToken{ 63 64 Token: refreshString, 64 65 Did: repo.Did, 65 66 CreatedAt: now,
+3 -3
sqlite_blockstore/sqlite_blockstore.go
··· 45 45 return maybeBlock, nil 46 46 } 47 47 48 - if err := bs.db.Raw("SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil { 48 + if err := bs.db.Raw(ctx, "SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil { 49 49 return nil, err 50 50 } 51 51 ··· 71 71 Value: block.RawData(), 72 72 } 73 73 74 - if err := bs.db.Create(&b, []clause.Expression{clause.OnConflict{ 74 + if err := bs.db.Create(ctx, &b, []clause.Expression{clause.OnConflict{ 75 75 Columns: []clause.Column{{Name: "did"}, {Name: "cid"}}, 76 76 UpdateAll: true, 77 77 }}).Error; err != nil { ··· 94 94 } 95 95 96 96 func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error { 97 - tx := bs.db.BeginDangerously() 97 + tx := bs.db.BeginDangerously(ctx) 98 98 99 99 for _, block := range blocks { 100 100 bs.inserts[block.Cid()] = block