+15
-14
internal/db/db.go
+15
-14
internal/db/db.go
···
1
1
package db
2
2
3
3
import (
4
+
"context"
4
5
"sync"
5
6
6
7
"gorm.io/gorm"
···
19
20
}
20
21
}
21
22
22
-
func (db *DB) Create(value any, clauses []clause.Expression) *gorm.DB {
23
+
func (db *DB) Create(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB {
23
24
db.mu.Lock()
24
25
defer db.mu.Unlock()
25
-
return db.cli.Clauses(clauses...).Create(value)
26
+
return db.cli.WithContext(ctx).Clauses(clauses...).Create(value)
26
27
}
27
28
28
-
func (db *DB) Save(value any, clauses []clause.Expression) *gorm.DB {
29
+
func (db *DB) Save(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB {
29
30
db.mu.Lock()
30
31
defer db.mu.Unlock()
31
-
return db.cli.Clauses(clauses...).Save(value)
32
+
return db.cli.WithContext(ctx).Clauses(clauses...).Save(value)
32
33
}
33
34
34
-
func (db *DB) Exec(sql string, clauses []clause.Expression, values ...any) *gorm.DB {
35
+
func (db *DB) Exec(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB {
35
36
db.mu.Lock()
36
37
defer db.mu.Unlock()
37
-
return db.cli.Clauses(clauses...).Exec(sql, values...)
38
+
return db.cli.WithContext(ctx).Clauses(clauses...).Exec(sql, values...)
38
39
}
39
40
40
-
func (db *DB) Raw(sql string, clauses []clause.Expression, values ...any) *gorm.DB {
41
-
return db.cli.Clauses(clauses...).Raw(sql, values...)
41
+
func (db *DB) Raw(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB {
42
+
return db.cli.WithContext(ctx).Clauses(clauses...).Raw(sql, values...)
42
43
}
43
44
44
45
func (db *DB) AutoMigrate(models ...any) error {
45
46
return db.cli.AutoMigrate(models...)
46
47
}
47
48
48
-
func (db *DB) Delete(value any, clauses []clause.Expression) *gorm.DB {
49
+
func (db *DB) Delete(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB {
49
50
db.mu.Lock()
50
51
defer db.mu.Unlock()
51
-
return db.cli.Clauses(clauses...).Delete(value)
52
+
return db.cli.WithContext(ctx).Clauses(clauses...).Delete(value)
52
53
}
53
54
54
-
func (db *DB) First(dest any, conds ...any) *gorm.DB {
55
-
return db.cli.First(dest, conds...)
55
+
func (db *DB) First(ctx context.Context, dest any, conds ...any) *gorm.DB {
56
+
return db.cli.WithContext(ctx).First(dest, conds...)
56
57
}
57
58
58
59
// TODO: this isn't actually good. we can commit even if the db is locked here. this is probably okay for the time being, but need to figure
59
60
// out a better solution. right now we only do this whenever we're importing a repo though so i'm mostly not worried, but it's still bad.
60
61
// e.g. when we do apply writes we should also be using a transcation but we don't right now
61
-
func (db *DB) BeginDangerously() *gorm.DB {
62
-
return db.cli.Begin()
62
+
func (db *DB) BeginDangerously(ctx context.Context) *gorm.DB {
63
+
return db.cli.WithContext(ctx).Begin()
63
64
}
64
65
65
66
func (db *DB) Lock() {
+10
-8
server/common.go
+10
-8
server/common.go
···
1
1
package server
2
2
3
3
import (
4
+
"context"
5
+
4
6
"github.com/haileyok/cocoon/models"
5
7
)
6
8
7
-
func (s *Server) getActorByHandle(handle string) (*models.Actor, error) {
9
+
func (s *Server) getActorByHandle(ctx context.Context, handle string) (*models.Actor, error) {
8
10
var actor models.Actor
9
-
if err := s.db.First(&actor, models.Actor{Handle: handle}).Error; err != nil {
11
+
if err := s.db.First(ctx, &actor, models.Actor{Handle: handle}).Error; err != nil {
10
12
return nil, err
11
13
}
12
14
return &actor, nil
13
15
}
14
16
15
-
func (s *Server) getRepoByEmail(email string) (*models.Repo, error) {
17
+
func (s *Server) getRepoByEmail(ctx context.Context, email string) (*models.Repo, error) {
16
18
var repo models.Repo
17
-
if err := s.db.First(&repo, models.Repo{Email: email}).Error; err != nil {
19
+
if err := s.db.First(ctx, &repo, models.Repo{Email: email}).Error; err != nil {
18
20
return nil, err
19
21
}
20
22
return &repo, nil
21
23
}
22
24
23
-
func (s *Server) getRepoActorByEmail(email string) (*models.RepoActor, error) {
25
+
func (s *Server) getRepoActorByEmail(ctx context.Context, email string) (*models.RepoActor, error) {
24
26
var repo models.RepoActor
25
-
if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil {
27
+
if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil {
26
28
return nil, err
27
29
}
28
30
return &repo, nil
29
31
}
30
32
31
-
func (s *Server) getRepoActorByDid(did string) (*models.RepoActor, error) {
33
+
func (s *Server) getRepoActorByDid(ctx context.Context, did string) (*models.RepoActor, error) {
32
34
var repo models.RepoActor
33
-
if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil {
35
+
if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil {
34
36
return nil, err
35
37
}
36
38
return &repo, nil
+2
-1
server/handle_account.go
+2
-1
server/handle_account.go
···
12
12
13
13
func (s *Server) handleAccount(e echo.Context) error {
14
14
ctx := e.Request().Context()
15
+
15
16
repo, sess, err := s.getSessionRepoOrErr(e)
16
17
if err != nil {
17
18
return e.Redirect(303, "/account/signin")
···
20
21
oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime)
21
22
22
23
var tokens []provider.OauthToken
23
-
if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil {
24
+
if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil {
24
25
s.logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err)
25
26
sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error")
26
27
sess.Save(e.Request(), e.Response())
+5
-3
server/handle_account_revoke.go
+5
-3
server/handle_account_revoke.go
···
5
5
"github.com/labstack/echo/v4"
6
6
)
7
7
8
-
type AccountRevokeRequest struct {
8
+
type AccountRevokeInput struct {
9
9
Token string `form:"token"`
10
10
}
11
11
12
12
func (s *Server) handleAccountRevoke(e echo.Context) error {
13
-
var req AccountRevokeRequest
13
+
ctx := e.Request().Context()
14
+
15
+
var req AccountRevokeInput
14
16
if err := e.Bind(&req); err != nil {
15
17
s.logger.Error("could not bind account revoke request", "error", err)
16
18
return helpers.ServerError(e, nil)
···
21
23
return e.Redirect(303, "/account/signin")
22
24
}
23
25
24
-
if err := s.db.Exec("DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil {
26
+
if err := s.db.Exec(ctx, "DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil {
25
27
s.logger.Error("couldnt delete oauth session for account", "did", repo.Repo.Did, "token", req.Token, "error", err)
26
28
sess.AddFlash("Unable to revoke session. See server logs for more details.", "error")
27
29
sess.Save(e.Request(), e.Response())
+10
-6
server/handle_account_signin.go
+10
-6
server/handle_account_signin.go
···
14
14
"gorm.io/gorm"
15
15
)
16
16
17
-
type OauthSigninRequest struct {
17
+
type OauthSigninInput struct {
18
18
Username string `form:"username"`
19
19
Password string `form:"password"`
20
20
QueryParams string `form:"query_params"`
21
21
}
22
22
23
23
func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) {
24
+
ctx := e.Request().Context()
25
+
24
26
sess, err := session.Get("session", e)
25
27
if err != nil {
26
28
return nil, nil, err
···
31
33
return nil, sess, errors.New("did was not set in session")
32
34
}
33
35
34
-
repo, err := s.getRepoActorByDid(did)
36
+
repo, err := s.getRepoActorByDid(ctx, did)
35
37
if err != nil {
36
38
return nil, sess, err
37
39
}
···
60
62
}
61
63
62
64
func (s *Server) handleAccountSigninPost(e echo.Context) error {
63
-
var req OauthSigninRequest
65
+
ctx := e.Request().Context()
66
+
67
+
var req OauthSigninInput
64
68
if err := e.Bind(&req); err != nil {
65
69
s.logger.Error("error binding sign in req", "error", err)
66
70
return helpers.ServerError(e, nil)
···
83
87
var err error
84
88
switch idtype {
85
89
case "did":
86
-
err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error
90
+
err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error
87
91
case "handle":
88
-
err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error
92
+
err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error
89
93
case "email":
90
-
err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error
94
+
err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error
91
95
}
92
96
if err != nil {
93
97
if err == gorm.ErrRecordNotFound {
+3
-1
server/handle_actor_put_preferences.go
+3
-1
server/handle_actor_put_preferences.go
···
10
10
// This is kinda lame. Not great to implement app.bsky in the pds, but alas
11
11
12
12
func (s *Server) handleActorPutPreferences(e echo.Context) error {
13
+
ctx := e.Request().Context()
14
+
13
15
repo := e.Get("repo").(*models.RepoActor)
14
16
15
17
var prefs map[string]any
···
22
24
return err
23
25
}
24
26
25
-
if err := s.db.Exec("UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil {
27
+
if err := s.db.Exec(ctx, "UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil {
26
28
return err
27
29
}
28
30
+3
-1
server/handle_identity_request_plc_operation.go
+3
-1
server/handle_identity_request_plc_operation.go
···
10
10
)
11
11
12
12
func (s *Server) handleIdentityRequestPlcOperationSignature(e echo.Context) error {
13
+
ctx := e.Request().Context()
14
+
13
15
urepo := e.Get("repo").(*models.RepoActor)
14
16
15
17
code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
16
18
eat := time.Now().Add(10 * time.Minute).UTC()
17
19
18
-
if err := s.db.Exec("UPDATE repos SET plc_operation_code = ?, plc_operation_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
20
+
if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = ?, plc_operation_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
19
21
s.logger.Error("error updating user", "error", err)
20
22
return helpers.ServerError(e, nil)
21
23
}
+1
-1
server/handle_identity_sign_plc_operation.go
+1
-1
server/handle_identity_sign_plc_operation.go
···
92
92
return helpers.ServerError(e, nil)
93
93
}
94
94
95
-
if err := s.db.Exec("UPDATE repos SET plc_operation_code = NULL, plc_operation_code_expires_at = NULL WHERE did = ?", nil, repo.Repo.Did).Error; err != nil {
95
+
if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = NULL, plc_operation_code_expires_at = NULL WHERE did = ?", nil, repo.Repo.Did).Error; err != nil {
96
96
s.logger.Error("error updating repo", "error", err)
97
97
return helpers.ServerError(e, nil)
98
98
}
+1
-1
server/handle_identity_update_handle.go
+1
-1
server/handle_identity_update_handle.go
···
94
94
},
95
95
})
96
96
97
-
if err := s.db.Exec("UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil {
97
+
if err := s.db.Exec(ctx, "UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil {
98
98
s.logger.Error("error updating handle in db", "error", err)
99
99
return helpers.ServerError(e, nil)
100
100
}
+3
-1
server/handle_import_repo.go
+3
-1
server/handle_import_repo.go
···
18
18
)
19
19
20
20
func (s *Server) handleRepoImportRepo(e echo.Context) error {
21
+
ctx := e.Request().Context()
22
+
21
23
urepo := e.Get("repo").(*models.RepoActor)
22
24
23
25
b, err := io.ReadAll(e.Request().Body)
···
63
65
return helpers.ServerError(e, nil)
64
66
}
65
67
66
-
tx := s.db.BeginDangerously()
68
+
tx := s.db.BeginDangerously(ctx)
67
69
68
70
clock := syntax.NewTIDClock(0)
69
71
+3
-1
server/handle_oauth_par.go
+3
-1
server/handle_oauth_par.go
···
19
19
}
20
20
21
21
func (s *Server) handleOauthPar(e echo.Context) error {
22
+
ctx := e.Request().Context()
23
+
22
24
var parRequest provider.ParRequest
23
25
if err := e.Bind(&parRequest); err != nil {
24
26
s.logger.Error("error binding for par request", "error", err)
···
86
88
ExpiresAt: eat,
87
89
}
88
90
89
-
if err := s.db.Create(authRequest, nil).Error; err != nil {
91
+
if err := s.db.Create(ctx, authRequest, nil).Error; err != nil {
90
92
s.logger.Error("error creating auth request in db", "error", err)
91
93
return helpers.ServerError(e, nil)
92
94
}
+7
-5
server/handle_oauth_token.go
+7
-5
server/handle_oauth_token.go
···
38
38
}
39
39
40
40
func (s *Server) handleOauthToken(e echo.Context) error {
41
+
ctx := e.Request().Context()
42
+
41
43
var req OauthTokenRequest
42
44
if err := e.Bind(&req); err != nil {
43
45
s.logger.Error("error binding token request", "error", err)
···
84
86
85
87
var authReq provider.OauthAuthorizationRequest
86
88
// get the lil guy and delete him
87
-
if err := s.db.Raw("DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil {
89
+
if err := s.db.Raw(ctx, "DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil {
88
90
s.logger.Error("error finding authorization request", "error", err)
89
91
return helpers.ServerError(e, nil)
90
92
}
···
128
130
return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided"))
129
131
}
130
132
131
-
repo, err := s.getRepoActorByDid(*authReq.Sub)
133
+
repo, err := s.getRepoActorByDid(ctx, *authReq.Sub)
132
134
if err != nil {
133
135
helpers.InputError(e, to.StringPtr("unable to find actor"))
134
136
}
···
159
161
return err
160
162
}
161
163
162
-
if err := s.db.Create(&provider.OauthToken{
164
+
if err := s.db.Create(ctx, &provider.OauthToken{
163
165
ClientId: authReq.ClientId,
164
166
ClientAuth: *clientAuth,
165
167
Parameters: authReq.Parameters,
···
199
201
}
200
202
201
203
var oauthToken provider.OauthToken
202
-
if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil {
204
+
if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil {
203
205
s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken)
204
206
return helpers.ServerError(e, nil)
205
207
}
···
257
259
return err
258
260
}
259
261
260
-
if err := s.db.Exec("UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil {
262
+
if err := s.db.Exec(ctx, "UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil {
261
263
s.logger.Error("error updating token", "error", err)
262
264
return helpers.ServerError(e, nil)
263
265
}
+4
-2
server/handle_repo_describe_repo.go
+4
-2
server/handle_repo_describe_repo.go
···
20
20
}
21
21
22
22
func (s *Server) handleDescribeRepo(e echo.Context) error {
23
+
ctx := e.Request().Context()
24
+
23
25
did := e.QueryParam("repo")
24
-
repo, err := s.getRepoActorByDid(did)
26
+
repo, err := s.getRepoActorByDid(ctx, did)
25
27
if err != nil {
26
28
if err == gorm.ErrRecordNotFound {
27
29
return helpers.InputError(e, to.StringPtr("RepoNotFound"))
···
64
66
}
65
67
66
68
var records []models.Record
67
-
if err := s.db.Raw("SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil {
69
+
if err := s.db.Raw(ctx, "SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil {
68
70
s.logger.Error("error getting collections", "error", err)
69
71
return helpers.ServerError(e, nil)
70
72
}
+3
-1
server/handle_repo_get_record.go
+3
-1
server/handle_repo_get_record.go
···
14
14
}
15
15
16
16
func (s *Server) handleRepoGetRecord(e echo.Context) error {
17
+
ctx := e.Request().Context()
18
+
17
19
repo := e.QueryParam("repo")
18
20
collection := e.QueryParam("collection")
19
21
rkey := e.QueryParam("rkey")
···
32
34
}
33
35
34
36
var record models.Record
35
-
if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil {
37
+
if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil {
36
38
// TODO: handle error nicely
37
39
return err
38
40
}
+4
-2
server/handle_repo_list_missing_blobs.go
+4
-2
server/handle_repo_list_missing_blobs.go
···
22
22
}
23
23
24
24
func (s *Server) handleListMissingBlobs(e echo.Context) error {
25
+
ctx := e.Request().Context()
26
+
25
27
urepo := e.Get("repo").(*models.RepoActor)
26
28
27
29
limitStr := e.QueryParam("limit")
···
35
37
}
36
38
37
39
var records []models.Record
38
-
if err := s.db.Raw("SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil {
40
+
if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil {
39
41
s.logger.Error("failed to get records for listMissingBlobs", "error", err)
40
42
return helpers.ServerError(e, nil)
41
43
}
···
69
71
}
70
72
71
73
var count int64
72
-
if err := s.db.Raw("SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil {
74
+
if err := s.db.Raw(ctx, "SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil {
73
75
continue
74
76
}
75
77
+4
-2
server/handle_repo_list_records.go
+4
-2
server/handle_repo_list_records.go
···
46
46
}
47
47
48
48
func (s *Server) handleListRecords(e echo.Context) error {
49
+
ctx := e.Request().Context()
50
+
49
51
var req ComAtprotoRepoListRecordsRequest
50
52
if err := e.Bind(&req); err != nil {
51
53
s.logger.Error("could not bind list records request", "error", err)
···
78
80
79
81
did := req.Repo
80
82
if _, err := syntax.ParseDID(did); err != nil {
81
-
actor, err := s.getActorByHandle(req.Repo)
83
+
actor, err := s.getActorByHandle(ctx, req.Repo)
82
84
if err != nil {
83
85
return helpers.InputError(e, to.StringPtr("RepoNotFound"))
84
86
}
···
93
95
params = append(params, limit)
94
96
95
97
var records []models.Record
96
-
if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil {
98
+
if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil {
97
99
s.logger.Error("error getting records", "error", err)
98
100
return helpers.ServerError(e, nil)
99
101
}
+3
-1
server/handle_repo_list_repos.go
+3
-1
server/handle_repo_list_repos.go
···
21
21
22
22
// TODO: paginate this bitch
23
23
func (s *Server) handleListRepos(e echo.Context) error {
24
+
ctx := e.Request().Context()
25
+
24
26
var repos []models.Repo
25
-
if err := s.db.Raw("SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil {
27
+
if err := s.db.Raw(ctx, "SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil {
26
28
return err
27
29
}
28
30
+5
-3
server/handle_repo_upload_blob.go
+5
-3
server/handle_repo_upload_blob.go
···
32
32
}
33
33
34
34
func (s *Server) handleRepoUploadBlob(e echo.Context) error {
35
+
ctx := e.Request().Context()
36
+
35
37
urepo := e.Get("repo").(*models.RepoActor)
36
38
37
39
mime := e.Request().Header.Get("content-type")
···
51
53
Storage: storage,
52
54
}
53
55
54
-
if err := s.db.Create(&blob, nil).Error; err != nil {
56
+
if err := s.db.Create(ctx, &blob, nil).Error; err != nil {
55
57
s.logger.Error("error creating new blob in db", "error", err)
56
58
return helpers.ServerError(e, nil)
57
59
}
···
84
86
Data: data,
85
87
}
86
88
87
-
if err := s.db.Create(&blobPart, nil).Error; err != nil {
89
+
if err := s.db.Create(ctx, &blobPart, nil).Error; err != nil {
88
90
s.logger.Error("error adding blob part to db", "error", err)
89
91
return helpers.ServerError(e, nil)
90
92
}
···
131
133
}
132
134
}
133
135
134
-
if err := s.db.Exec("UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil {
136
+
if err := s.db.Exec(ctx, "UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil {
135
137
// there should probably be somme handling here if this fails...
136
138
s.logger.Error("error updating blob", "error", err)
137
139
return helpers.ServerError(e, nil)
+3
-1
server/handle_server_activate_account.go
+3
-1
server/handle_server_activate_account.go
···
18
18
}
19
19
20
20
func (s *Server) handleServerActivateAccount(e echo.Context) error {
21
+
ctx := e.Request().Context()
22
+
21
23
var req ComAtprotoServerDeactivateAccountRequest
22
24
if err := e.Bind(&req); err != nil {
23
25
s.logger.Error("error binding", "error", err)
···
26
28
27
29
urepo := e.Get("repo").(*models.RepoActor)
28
30
29
-
if err := s.db.Exec("UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil {
31
+
if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil {
30
32
s.logger.Error("error updating account status to deactivated", "error", err)
31
33
return helpers.ServerError(e, nil)
32
34
}
+5
-3
server/handle_server_check_account_status.go
+5
-3
server/handle_server_check_account_status.go
···
20
20
}
21
21
22
22
func (s *Server) handleServerCheckAccountStatus(e echo.Context) error {
23
+
ctx := e.Request().Context()
24
+
23
25
urepo := e.Get("repo").(*models.RepoActor)
24
26
25
27
resp := ComAtprotoServerCheckAccountStatusResponse{
···
41
43
}
42
44
43
45
var blockCtResp CountResp
44
-
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil {
46
+
if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil {
45
47
s.logger.Error("error getting block count", "error", err)
46
48
return helpers.ServerError(e, nil)
47
49
}
48
50
resp.RepoBlocks = blockCtResp.Ct
49
51
50
52
var recCtResp CountResp
51
-
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil {
53
+
if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil {
52
54
s.logger.Error("error getting record count", "error", err)
53
55
return helpers.ServerError(e, nil)
54
56
}
55
57
resp.IndexedRecords = recCtResp.Ct
56
58
57
59
var blobCtResp CountResp
58
-
if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil {
60
+
if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil {
59
61
s.logger.Error("error getting record count", "error", err)
60
62
return helpers.ServerError(e, nil)
61
63
}
+3
-1
server/handle_server_confirm_email.go
+3
-1
server/handle_server_confirm_email.go
···
15
15
}
16
16
17
17
func (s *Server) handleServerConfirmEmail(e echo.Context) error {
18
+
ctx := e.Request().Context()
19
+
18
20
urepo := e.Get("repo").(*models.RepoActor)
19
21
20
22
var req ComAtprotoServerConfirmEmailRequest
···
41
43
42
44
now := time.Now().UTC()
43
45
44
-
if err := s.db.Exec("UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil {
46
+
if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil {
45
47
s.logger.Error("error updating user", "error", err)
46
48
return helpers.ServerError(e, nil)
47
49
}
+16
-14
server/handle_server_create_account.go
+16
-14
server/handle_server_create_account.go
···
36
36
}
37
37
38
38
func (s *Server) handleCreateAccount(e echo.Context) error {
39
+
ctx := e.Request().Context()
40
+
39
41
var request ComAtprotoServerCreateAccountRequest
40
42
41
43
if err := e.Bind(&request); err != nil {
···
68
70
}
69
71
}
70
72
}
71
-
73
+
72
74
var signupDid string
73
75
if request.Did != nil {
74
-
signupDid = *request.Did;
75
-
76
+
signupDid = *request.Did
77
+
76
78
token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1))
77
79
if token == "" {
78
80
return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did"))
···
90
92
}
91
93
92
94
// see if the handle is already taken
93
-
actor, err := s.getActorByHandle(request.Handle)
95
+
actor, err := s.getActorByHandle(ctx, request.Handle)
94
96
if err != nil && err != gorm.ErrRecordNotFound {
95
97
s.logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err)
96
98
return helpers.ServerError(e, nil)
···
109
111
return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
110
112
}
111
113
112
-
if err := s.db.Raw("SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
114
+
if err := s.db.Raw(ctx, "SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
113
115
if err == gorm.ErrRecordNotFound {
114
116
return helpers.InputError(e, to.StringPtr("InvalidInviteCode"))
115
117
}
···
123
125
}
124
126
125
127
// see if the email is already taken
126
-
existingRepo, err := s.getRepoByEmail(request.Email)
128
+
existingRepo, err := s.getRepoByEmail(ctx, request.Email)
127
129
if err != nil && err != gorm.ErrRecordNotFound {
128
130
s.logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err)
129
131
return helpers.ServerError(e, nil)
···
137
139
var k *atcrypto.PrivateKeyK256
138
140
139
141
if signupDid != "" {
140
-
reservedKey, err := s.getReservedKey(signupDid)
142
+
reservedKey, err := s.getReservedKey(ctx, signupDid)
141
143
if err != nil {
142
144
s.logger.Error("error looking up reserved key", "error", err)
143
145
}
···
148
150
k = nil
149
151
} else {
150
152
defer func() {
151
-
if delErr := s.deleteReservedKey(reservedKey.KeyDid, reservedKey.Did); delErr != nil {
153
+
if delErr := s.deleteReservedKey(ctx, reservedKey.KeyDid, reservedKey.Did); delErr != nil {
152
154
s.logger.Error("error deleting reserved key", "error", delErr)
153
155
}
154
156
}()
···
199
201
Handle: request.Handle,
200
202
}
201
203
202
-
if err := s.db.Create(&urepo, nil).Error; err != nil {
204
+
if err := s.db.Create(ctx, &urepo, nil).Error; err != nil {
203
205
s.logger.Error("error inserting new repo", "error", err)
204
206
return helpers.ServerError(e, nil)
205
207
}
206
-
207
-
if err := s.db.Create(&actor, nil).Error; err != nil {
208
+
209
+
if err := s.db.Create(ctx, &actor, nil).Error; err != nil {
208
210
s.logger.Error("error inserting new actor", "error", err)
209
211
return helpers.ServerError(e, nil)
210
212
}
211
213
} else {
212
-
if err := s.db.Save(&actor, nil).Error; err != nil {
214
+
if err := s.db.Save(ctx, &actor, nil).Error; err != nil {
213
215
s.logger.Error("error inserting new actor", "error", err)
214
216
return helpers.ServerError(e, nil)
215
217
}
···
241
243
}
242
244
243
245
if s.config.RequireInvite {
244
-
if err := s.db.Raw("UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
246
+
if err := s.db.Raw(ctx, "UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil {
245
247
s.logger.Error("error decrementing use count", "error", err)
246
248
return helpers.ServerError(e, nil)
247
249
}
248
250
}
249
251
250
-
sess, err := s.createSession(&urepo)
252
+
sess, err := s.createSession(ctx, &urepo)
251
253
if err != nil {
252
254
s.logger.Error("error creating new session", "error", err)
253
255
return helpers.ServerError(e, nil)
+3
-1
server/handle_server_create_invite_code.go
+3
-1
server/handle_server_create_invite_code.go
···
17
17
}
18
18
19
19
func (s *Server) handleCreateInviteCode(e echo.Context) error {
20
+
ctx := e.Request().Context()
21
+
20
22
var req ComAtprotoServerCreateInviteCodeRequest
21
23
if err := e.Bind(&req); err != nil {
22
24
s.logger.Error("error binding", "error", err)
···
37
39
acc = *req.ForAccount
38
40
}
39
41
40
-
if err := s.db.Create(&models.InviteCode{
42
+
if err := s.db.Create(ctx, &models.InviteCode{
41
43
Code: ic,
42
44
Did: acc,
43
45
RemainingUseCount: req.UseCount,
+3
-1
server/handle_server_create_invite_codes.go
+3
-1
server/handle_server_create_invite_codes.go
···
22
22
}
23
23
24
24
func (s *Server) handleCreateInviteCodes(e echo.Context) error {
25
+
ctx := e.Request().Context()
26
+
25
27
var req ComAtprotoServerCreateInviteCodesRequest
26
28
if err := e.Bind(&req); err != nil {
27
29
s.logger.Error("error binding", "error", err)
···
50
52
ic := uuid.NewString()
51
53
ics = append(ics, ic)
52
54
53
-
if err := s.db.Create(&models.InviteCode{
55
+
if err := s.db.Create(ctx, &models.InviteCode{
54
56
Code: ic,
55
57
Did: did,
56
58
RemainingUseCount: req.UseCount,
+6
-4
server/handle_server_create_session.go
+6
-4
server/handle_server_create_session.go
···
32
32
}
33
33
34
34
func (s *Server) handleCreateSession(e echo.Context) error {
35
+
ctx := e.Request().Context()
36
+
35
37
var req ComAtprotoServerCreateSessionRequest
36
38
if err := e.Bind(&req); err != nil {
37
39
s.logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err)
···
65
67
var err error
66
68
switch idtype {
67
69
case "did":
68
-
err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error
70
+
err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error
69
71
case "handle":
70
-
err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error
72
+
err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error
71
73
case "email":
72
-
err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error
74
+
err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error
73
75
}
74
76
75
77
if err != nil {
···
88
90
return helpers.InputError(e, to.StringPtr("InvalidRequest"))
89
91
}
90
92
91
-
sess, err := s.createSession(&repo.Repo)
93
+
sess, err := s.createSession(ctx, &repo.Repo)
92
94
if err != nil {
93
95
s.logger.Error("error creating session", "error", err)
94
96
return helpers.ServerError(e, nil)
+3
-1
server/handle_server_deactivate_account.go
+3
-1
server/handle_server_deactivate_account.go
···
19
19
}
20
20
21
21
func (s *Server) handleServerDeactivateAccount(e echo.Context) error {
22
+
ctx := e.Request().Context()
23
+
22
24
var req ComAtprotoServerDeactivateAccountRequest
23
25
if err := e.Bind(&req); err != nil {
24
26
s.logger.Error("error binding", "error", err)
···
27
29
28
30
urepo := e.Get("repo").(*models.RepoActor)
29
31
30
-
if err := s.db.Exec("UPDATE repos SET deactivated = ? WHERE did = ?", nil, true, urepo.Repo.Did).Error; err != nil {
32
+
if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, true, urepo.Repo.Did).Error; err != nil {
31
33
s.logger.Error("error updating account status to deactivated", "error", err)
32
34
return helpers.ServerError(e, nil)
33
35
}
+4
-2
server/handle_server_delete_account.go
+4
-2
server/handle_server_delete_account.go
···
20
20
}
21
21
22
22
func (s *Server) handleServerDeleteAccount(e echo.Context) error {
23
+
ctx := e.Request().Context()
24
+
23
25
var req ComAtprotoServerDeleteAccountRequest
24
26
if err := e.Bind(&req); err != nil {
25
27
s.logger.Error("error binding", "error", err)
···
31
33
return helpers.ServerError(e, nil)
32
34
}
33
35
34
-
urepo, err := s.getRepoActorByDid(req.Did)
36
+
urepo, err := s.getRepoActorByDid(ctx, req.Did)
35
37
if err != nil {
36
38
s.logger.Error("error getting repo", "error", err)
37
39
return echo.NewHTTPError(400, "account not found")
···
66
68
})
67
69
}
68
70
69
-
tx := s.db.BeginDangerously()
71
+
tx := s.db.BeginDangerously(ctx)
70
72
if tx.Error != nil {
71
73
s.logger.Error("error starting transaction", "error", tx.Error)
72
74
return helpers.ServerError(e, nil)
+4
-2
server/handle_server_delete_session.go
+4
-2
server/handle_server_delete_session.go
···
7
7
)
8
8
9
9
func (s *Server) handleDeleteSession(e echo.Context) error {
10
+
ctx := e.Request().Context()
11
+
10
12
token := e.Get("token").(string)
11
13
12
14
var acctok models.Token
13
-
if err := s.db.Raw("DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil {
15
+
if err := s.db.Raw(ctx, "DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil {
14
16
s.logger.Error("error deleting access token from db", "error", err)
15
17
return helpers.ServerError(e, nil)
16
18
}
17
19
18
-
if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil {
20
+
if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil {
19
21
s.logger.Error("error deleting refresh token from db", "error", err)
20
22
return helpers.ServerError(e, nil)
21
23
}
+5
-3
server/handle_server_refresh_session.go
+5
-3
server/handle_server_refresh_session.go
···
16
16
}
17
17
18
18
func (s *Server) handleRefreshSession(e echo.Context) error {
19
+
ctx := e.Request().Context()
20
+
19
21
token := e.Get("token").(string)
20
22
repo := e.Get("repo").(*models.RepoActor)
21
23
22
-
if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil {
24
+
if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil {
23
25
s.logger.Error("error getting refresh token from db", "error", err)
24
26
return helpers.ServerError(e, nil)
25
27
}
26
28
27
-
if err := s.db.Exec("DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil {
29
+
if err := s.db.Exec(ctx, "DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil {
28
30
s.logger.Error("error deleting access token from db", "error", err)
29
31
return helpers.ServerError(e, nil)
30
32
}
31
33
32
-
sess, err := s.createSession(&repo.Repo)
34
+
sess, err := s.createSession(ctx, &repo.Repo)
33
35
if err != nil {
34
36
s.logger.Error("error creating new session for refresh", "error", err)
35
37
return helpers.ServerError(e, nil)
+3
-1
server/handle_server_request_account_delete.go
+3
-1
server/handle_server_request_account_delete.go
···
10
10
)
11
11
12
12
func (s *Server) handleServerRequestAccountDelete(e echo.Context) error {
13
+
ctx := e.Request().Context()
14
+
13
15
urepo := e.Get("repo").(*models.RepoActor)
14
16
15
17
token := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
16
18
expiresAt := time.Now().UTC().Add(15 * time.Minute)
17
19
18
-
if err := s.db.Exec("UPDATE repos SET account_delete_code = ?, account_delete_code_expires_at = ? WHERE did = ?", nil, token, expiresAt, urepo.Repo.Did).Error; err != nil {
20
+
if err := s.db.Exec(ctx, "UPDATE repos SET account_delete_code = ?, account_delete_code_expires_at = ? WHERE did = ?", nil, token, expiresAt, urepo.Repo.Did).Error; err != nil {
19
21
s.logger.Error("error setting deletion token", "error", err)
20
22
return helpers.ServerError(e, nil)
21
23
}
+3
-1
server/handle_server_request_email_confirmation.go
+3
-1
server/handle_server_request_email_confirmation.go
···
11
11
)
12
12
13
13
func (s *Server) handleServerRequestEmailConfirmation(e echo.Context) error {
14
+
ctx := e.Request().Context()
15
+
14
16
urepo := e.Get("repo").(*models.RepoActor)
15
17
16
18
if urepo.EmailConfirmedAt != nil {
···
20
22
code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
21
23
eat := time.Now().Add(10 * time.Minute).UTC()
22
24
23
-
if err := s.db.Exec("UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
25
+
if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
24
26
s.logger.Error("error updating user", "error", err)
25
27
return helpers.ServerError(e, nil)
26
28
}
+3
-1
server/handle_server_request_email_update.go
+3
-1
server/handle_server_request_email_update.go
···
14
14
}
15
15
16
16
func (s *Server) handleServerRequestEmailUpdate(e echo.Context) error {
17
+
ctx := e.Request().Context()
18
+
17
19
urepo := e.Get("repo").(*models.RepoActor)
18
20
19
21
if urepo.EmailConfirmedAt != nil {
20
22
code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
21
23
eat := time.Now().Add(10 * time.Minute).UTC()
22
24
23
-
if err := s.db.Exec("UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
25
+
if err := s.db.Exec(ctx, "UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
24
26
s.logger.Error("error updating repo", "error", err)
25
27
return helpers.ServerError(e, nil)
26
28
}
+4
-2
server/handle_server_request_password_reset.go
+4
-2
server/handle_server_request_password_reset.go
···
14
14
}
15
15
16
16
func (s *Server) handleServerRequestPasswordReset(e echo.Context) error {
17
+
ctx := e.Request().Context()
18
+
17
19
urepo, ok := e.Get("repo").(*models.RepoActor)
18
20
if !ok {
19
21
var req ComAtprotoServerRequestPasswordResetRequest
···
25
27
return err
26
28
}
27
29
28
-
murepo, err := s.getRepoActorByEmail(req.Email)
30
+
murepo, err := s.getRepoActorByEmail(ctx, req.Email)
29
31
if err != nil {
30
32
return err
31
33
}
···
36
38
code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5))
37
39
eat := time.Now().Add(10 * time.Minute).UTC()
38
40
39
-
if err := s.db.Exec("UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
41
+
if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil {
40
42
s.logger.Error("error updating repo", "error", err)
41
43
return helpers.ServerError(e, nil)
42
44
}
+11
-8
server/handle_server_reserve_signing_key.go
+11
-8
server/handle_server_reserve_signing_key.go
···
1
1
package server
2
2
3
3
import (
4
+
"context"
4
5
"time"
5
6
6
7
"github.com/bluesky-social/indigo/atproto/atcrypto"
···
18
19
}
19
20
20
21
func (s *Server) handleServerReserveSigningKey(e echo.Context) error {
22
+
ctx := e.Request().Context()
23
+
21
24
var req ServerReserveSigningKeyRequest
22
25
if err := e.Bind(&req); err != nil {
23
26
s.logger.Error("could not bind reserve signing key request", "error", err)
···
26
29
27
30
if req.Did != nil && *req.Did != "" {
28
31
var existing models.ReservedKey
29
-
if err := s.db.Raw("SELECT * FROM reserved_keys WHERE did = ?", nil, *req.Did).Scan(&existing).Error; err == nil && existing.KeyDid != "" {
32
+
if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE did = ?", nil, *req.Did).Scan(&existing).Error; err == nil && existing.KeyDid != "" {
30
33
return e.JSON(200, ServerReserveSigningKeyResponse{
31
34
SigningKey: existing.KeyDid,
32
35
})
···
54
57
CreatedAt: time.Now(),
55
58
}
56
59
57
-
if err := s.db.Create(&reservedKey, nil).Error; err != nil {
60
+
if err := s.db.Create(ctx, &reservedKey, nil).Error; err != nil {
58
61
s.logger.Error("error storing reserved key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err)
59
62
return helpers.ServerError(e, nil)
60
63
}
···
66
69
})
67
70
}
68
71
69
-
func (s *Server) getReservedKey(keyDidOrDid string) (*models.ReservedKey, error) {
72
+
func (s *Server) getReservedKey(ctx context.Context, keyDidOrDid string) (*models.ReservedKey, error) {
70
73
var reservedKey models.ReservedKey
71
74
72
-
if err := s.db.Raw("SELECT * FROM reserved_keys WHERE key_did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" {
75
+
if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE key_did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" {
73
76
return &reservedKey, nil
74
77
}
75
78
76
-
if err := s.db.Raw("SELECT * FROM reserved_keys WHERE did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" {
79
+
if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" {
77
80
return &reservedKey, nil
78
81
}
79
82
80
83
return nil, nil
81
84
}
82
85
83
-
func (s *Server) deleteReservedKey(keyDid string, did *string) error {
84
-
if err := s.db.Exec("DELETE FROM reserved_keys WHERE key_did = ?", nil, keyDid).Error; err != nil {
86
+
func (s *Server) deleteReservedKey(ctx context.Context, keyDid string, did *string) error {
87
+
if err := s.db.Exec(ctx, "DELETE FROM reserved_keys WHERE key_did = ?", nil, keyDid).Error; err != nil {
85
88
return err
86
89
}
87
90
88
91
if did != nil && *did != "" {
89
-
if err := s.db.Exec("DELETE FROM reserved_keys WHERE did = ?", nil, *did).Error; err != nil {
92
+
if err := s.db.Exec(ctx, "DELETE FROM reserved_keys WHERE did = ?", nil, *did).Error; err != nil {
90
93
return err
91
94
}
92
95
}
+3
-1
server/handle_server_reset_password.go
+3
-1
server/handle_server_reset_password.go
···
16
16
}
17
17
18
18
func (s *Server) handleServerResetPassword(e echo.Context) error {
19
+
ctx := e.Request().Context()
20
+
19
21
urepo := e.Get("repo").(*models.RepoActor)
20
22
21
23
var req ComAtprotoServerResetPasswordRequest
···
46
48
return helpers.ServerError(e, nil)
47
49
}
48
50
49
-
if err := s.db.Exec("UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil {
51
+
if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil {
50
52
s.logger.Error("error updating repo", "error", err)
51
53
return helpers.ServerError(e, nil)
52
54
}
+3
-1
server/handle_server_update_email.go
+3
-1
server/handle_server_update_email.go
···
15
15
}
16
16
17
17
func (s *Server) handleServerUpdateEmail(e echo.Context) error {
18
+
ctx := e.Request().Context()
19
+
18
20
urepo := e.Get("repo").(*models.RepoActor)
19
21
20
22
var req ComAtprotoServerUpdateEmailRequest
···
39
41
return helpers.ExpiredTokenError(e)
40
42
}
41
43
42
-
if err := s.db.Exec("UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil {
44
+
if err := s.db.Exec(ctx, "UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil {
43
45
s.logger.Error("error updating repo", "error", err)
44
46
return helpers.ServerError(e, nil)
45
47
}
+5
-3
server/handle_sync_get_blob.go
+5
-3
server/handle_sync_get_blob.go
···
17
17
)
18
18
19
19
func (s *Server) handleSyncGetBlob(e echo.Context) error {
20
+
ctx := e.Request().Context()
21
+
20
22
did := e.QueryParam("did")
21
23
if did == "" {
22
24
return helpers.InputError(e, nil)
···
32
34
return helpers.InputError(e, nil)
33
35
}
34
36
35
-
urepo, err := s.getRepoActorByDid(did)
37
+
urepo, err := s.getRepoActorByDid(ctx, did)
36
38
if err != nil {
37
39
s.logger.Error("could not find user for requested blob", "error", err)
38
40
return helpers.InputError(e, nil)
···
46
48
}
47
49
48
50
var blob models.Blob
49
-
if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil {
51
+
if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil {
50
52
s.logger.Error("error looking up blob", "error", err)
51
53
return helpers.ServerError(e, nil)
52
54
}
···
55
57
56
58
if blob.Storage == "sqlite" {
57
59
var parts []models.BlobPart
58
-
if err := s.db.Raw("SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil {
60
+
if err := s.db.Raw(ctx, "SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil {
59
61
s.logger.Error("error getting blob parts", "error", err)
60
62
return helpers.ServerError(e, nil)
61
63
}
+1
-1
server/handle_sync_get_blocks.go
+1
-1
server/handle_sync_get_blocks.go
+3
-1
server/handle_sync_get_latest_commit.go
+3
-1
server/handle_sync_get_latest_commit.go
···
12
12
}
13
13
14
14
func (s *Server) handleSyncGetLatestCommit(e echo.Context) error {
15
+
ctx := e.Request().Context()
16
+
15
17
did := e.QueryParam("did")
16
18
if did == "" {
17
19
return helpers.InputError(e, nil)
18
20
}
19
21
20
-
urepo, err := s.getRepoActorByDid(did)
22
+
urepo, err := s.getRepoActorByDid(ctx, did)
21
23
if err != nil {
22
24
return err
23
25
}
+1
-1
server/handle_sync_get_record.go
+1
-1
server/handle_sync_get_record.go
···
20
20
rkey := e.QueryParam("rkey")
21
21
22
22
var urepo models.Repo
23
-
if err := s.db.Raw("SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil {
23
+
if err := s.db.Raw(ctx, "SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil {
24
24
s.logger.Error("error getting repo", "error", err)
25
25
return helpers.ServerError(e, nil)
26
26
}
+4
-2
server/handle_sync_get_repo.go
+4
-2
server/handle_sync_get_repo.go
···
13
13
)
14
14
15
15
func (s *Server) handleSyncGetRepo(e echo.Context) error {
16
+
ctx := e.Request().Context()
17
+
16
18
did := e.QueryParam("did")
17
19
if did == "" {
18
20
return helpers.InputError(e, nil)
19
21
}
20
22
21
-
urepo, err := s.getRepoActorByDid(did)
23
+
urepo, err := s.getRepoActorByDid(ctx, did)
22
24
if err != nil {
23
25
return err
24
26
}
···
41
43
}
42
44
43
45
var blocks []models.Block
44
-
if err := s.db.Raw("SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil {
46
+
if err := s.db.Raw(ctx, "SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil {
45
47
return err
46
48
}
47
49
+3
-1
server/handle_sync_get_repo_status.go
+3
-1
server/handle_sync_get_repo_status.go
···
14
14
15
15
// TODO: make this actually do the right thing
16
16
func (s *Server) handleSyncGetRepoStatus(e echo.Context) error {
17
+
ctx := e.Request().Context()
18
+
17
19
did := e.QueryParam("did")
18
20
if did == "" {
19
21
return helpers.InputError(e, nil)
20
22
}
21
23
22
-
urepo, err := s.getRepoActorByDid(did)
24
+
urepo, err := s.getRepoActorByDid(ctx, did)
23
25
if err != nil {
24
26
return err
25
27
}
+4
-2
server/handle_sync_list_blobs.go
+4
-2
server/handle_sync_list_blobs.go
···
14
14
}
15
15
16
16
func (s *Server) handleSyncListBlobs(e echo.Context) error {
17
+
ctx := e.Request().Context()
18
+
17
19
did := e.QueryParam("did")
18
20
if did == "" {
19
21
return helpers.InputError(e, nil)
···
35
37
}
36
38
params = append(params, limit)
37
39
38
-
urepo, err := s.getRepoActorByDid(did)
40
+
urepo, err := s.getRepoActorByDid(ctx, did)
39
41
if err != nil {
40
42
s.logger.Error("could not find user for requested blobs", "error", err)
41
43
return helpers.InputError(e, nil)
···
49
51
}
50
52
51
53
var blobs []models.Blob
52
-
if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil {
54
+
if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil {
53
55
s.logger.Error("error getting records", "error", err)
54
56
return helpers.ServerError(e, nil)
55
57
}
+3
-1
server/handle_well_known.go
+3
-1
server/handle_well_known.go
···
67
67
}
68
68
69
69
func (s *Server) handleAtprotoDid(e echo.Context) error {
70
+
ctx := e.Request().Context()
71
+
70
72
host := e.Request().Host
71
73
if host == "" {
72
74
return helpers.InputError(e, to.StringPtr("Invalid handle."))
···
84
86
return e.NoContent(404)
85
87
}
86
88
87
-
actor, err := s.getActorByHandle(host)
89
+
actor, err := s.getActorByHandle(ctx, host)
88
90
if err != nil {
89
91
if err == gorm.ErrRecordNotFound {
90
92
return e.NoContent(404)
+9
-5
server/middleware.go
+9
-5
server/middleware.go
···
37
37
38
38
func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
39
39
return func(e echo.Context) error {
40
+
ctx := e.Request().Context()
41
+
40
42
authheader := e.Request().Header.Get("authorization")
41
43
if authheader == "" {
42
44
return e.JSON(401, map[string]string{"error": "Unauthorized"})
···
78
80
}
79
81
did = maybeDid
80
82
81
-
maybeRepo, err := s.getRepoActorByDid(did)
83
+
maybeRepo, err := s.getRepoActorByDid(ctx, did)
82
84
if err != nil {
83
85
s.logger.Error("error fetching repo", "error", err)
84
86
return helpers.ServerError(e, nil)
···
159
161
Found bool
160
162
}
161
163
var result Result
162
-
if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
164
+
if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil {
163
165
if err == gorm.ErrRecordNotFound {
164
166
return helpers.InvalidTokenError(e)
165
167
}
···
184
186
}
185
187
186
188
if repo == nil {
187
-
maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string))
189
+
maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string))
188
190
if err != nil {
189
191
s.logger.Error("error fetching repo", "error", err)
190
192
return helpers.ServerError(e, nil)
···
207
209
208
210
func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc {
209
211
return func(e echo.Context) error {
212
+
ctx := e.Request().Context()
213
+
210
214
authheader := e.Request().Header.Get("authorization")
211
215
if authheader == "" {
212
216
return e.JSON(401, map[string]string{"error": "Unauthorized"})
···
243
247
}
244
248
245
249
var oauthToken provider.OauthToken
246
-
if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
250
+
if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil {
247
251
s.logger.Error("error finding access token in db", "error", err)
248
252
return helpers.InputError(e, nil)
249
253
}
···
266
270
})
267
271
}
268
272
269
-
repo, err := s.getRepoActorByDid(oauthToken.Sub)
273
+
repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub)
270
274
if err != nil {
271
275
s.logger.Error("could not find actor in db", "error", err)
272
276
return helpers.ServerError(e, nil)
+11
-11
server/repo.go
+11
-11
server/repo.go
···
181
181
case OpTypeDelete:
182
182
// try to find the old record in the database
183
183
var old models.Record
184
-
if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
184
+
if err := rm.db.Raw(ctx, "SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil {
185
185
return nil, err
186
186
}
187
187
···
323
323
var cids []cid.Cid
324
324
// whenever there is cid present, we know it's a create (dumb)
325
325
if entry.Cid != "" {
326
-
if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{
326
+
if err := rm.s.db.Create(ctx, &entry, []clause.Expression{clause.OnConflict{
327
327
Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}},
328
328
UpdateAll: true,
329
329
}}).Error; err != nil {
···
331
331
}
332
332
333
333
// increment the given blob refs, yay
334
-
cids, err = rm.incrementBlobRefs(urepo, entry.Value)
334
+
cids, err = rm.incrementBlobRefs(ctx, urepo, entry.Value)
335
335
if err != nil {
336
336
return nil, err
337
337
}
···
339
339
// as i noted above this is dumb. but we delete whenever the cid is nil. it works solely becaue the pkey
340
340
// is did + collection + rkey. i still really want to separate that out, or use a different type to make
341
341
// this less confusing/easy to read. alas, its 2 am and yea no
342
-
if err := rm.s.db.Delete(&entry, nil).Error; err != nil {
342
+
if err := rm.s.db.Delete(ctx, &entry, nil).Error; err != nil {
343
343
return nil, err
344
344
}
345
345
346
346
// TODO:
347
-
cids, err = rm.decrementBlobRefs(urepo, entry.Value)
347
+
cids, err = rm.decrementBlobRefs(ctx, urepo, entry.Value)
348
348
if err != nil {
349
349
return nil, err
350
350
}
···
411
411
return c, bs.GetReadLog(), nil
412
412
}
413
413
414
-
func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
414
+
func (rm *RepoMan) incrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
415
415
cids, err := getBlobCidsFromCbor(cbor)
416
416
if err != nil {
417
417
return nil, err
418
418
}
419
419
420
420
for _, c := range cids {
421
-
if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil {
421
+
if err := rm.db.Exec(ctx, "UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil {
422
422
return nil, err
423
423
}
424
424
}
···
426
426
return cids, nil
427
427
}
428
428
429
-
func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
429
+
func (rm *RepoMan) decrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) {
430
430
cids, err := getBlobCidsFromCbor(cbor)
431
431
if err != nil {
432
432
return nil, err
···
437
437
ID uint
438
438
Count int
439
439
}
440
-
if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
440
+
if err := rm.db.Raw(ctx, "UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil {
441
441
return nil, err
442
442
}
443
443
444
444
// TODO: this does _not_ handle deletions of blobs that are on s3 storage!!!! we need to get the blob, see what
445
445
// storage it is in, and clean up s3!!!!
446
446
if res.Count == 0 {
447
-
if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil {
447
+
if err := rm.db.Exec(ctx, "DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil {
448
448
return nil, err
449
449
}
450
-
if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil {
450
+
if err := rm.db.Exec(ctx, "DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil {
451
451
return nil, err
452
452
}
453
453
}
+1
-1
server/server.go
+1
-1
server/server.go
···
729
729
}
730
730
731
731
func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
732
-
if err := s.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
732
+
if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
733
733
return err
734
734
}
735
735
+4
-3
server/session.go
+4
-3
server/session.go
···
1
1
package server
2
2
3
3
import (
4
+
"context"
4
5
"time"
5
6
6
7
"github.com/golang-jwt/jwt/v4"
···
13
14
RefreshToken string
14
15
}
15
16
16
-
func (s *Server) createSession(repo *models.Repo) (*Session, error) {
17
+
func (s *Server) createSession(ctx context.Context, repo *models.Repo) (*Session, error) {
17
18
now := time.Now()
18
19
accexp := now.Add(3 * time.Hour)
19
20
refexp := now.Add(7 * 24 * time.Hour)
···
49
50
return nil, err
50
51
}
51
52
52
-
if err := s.db.Create(&models.Token{
53
+
if err := s.db.Create(ctx, &models.Token{
53
54
Token: accessString,
54
55
Did: repo.Did,
55
56
RefreshToken: refreshString,
···
59
60
return nil, err
60
61
}
61
62
62
-
if err := s.db.Create(&models.RefreshToken{
63
+
if err := s.db.Create(ctx, &models.RefreshToken{
63
64
Token: refreshString,
64
65
Did: repo.Did,
65
66
CreatedAt: now,
+3
-3
sqlite_blockstore/sqlite_blockstore.go
+3
-3
sqlite_blockstore/sqlite_blockstore.go
···
45
45
return maybeBlock, nil
46
46
}
47
47
48
-
if err := bs.db.Raw("SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil {
48
+
if err := bs.db.Raw(ctx, "SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil {
49
49
return nil, err
50
50
}
51
51
···
71
71
Value: block.RawData(),
72
72
}
73
73
74
-
if err := bs.db.Create(&b, []clause.Expression{clause.OnConflict{
74
+
if err := bs.db.Create(ctx, &b, []clause.Expression{clause.OnConflict{
75
75
Columns: []clause.Column{{Name: "did"}, {Name: "cid"}},
76
76
UpdateAll: true,
77
77
}}).Error; err != nil {
···
94
94
}
95
95
96
96
func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error {
97
-
tx := bs.db.BeginDangerously()
97
+
tx := bs.db.BeginDangerously(ctx)
98
98
99
99
for _, block := range blocks {
100
100
bs.inserts[block.Cid()] = block