An atproto PDS written in Go

add withcontxt to all db calls

+15 -14
internal/db/db.go
··· 1 package db 2 3 import ( 4 "sync" 5 6 "gorm.io/gorm" ··· 19 } 20 } 21 22 - func (db *DB) Create(value any, clauses []clause.Expression) *gorm.DB { 23 db.mu.Lock() 24 defer db.mu.Unlock() 25 - return db.cli.Clauses(clauses...).Create(value) 26 } 27 28 - func (db *DB) Save(value any, clauses []clause.Expression) *gorm.DB { 29 db.mu.Lock() 30 defer db.mu.Unlock() 31 - return db.cli.Clauses(clauses...).Save(value) 32 } 33 34 - func (db *DB) Exec(sql string, clauses []clause.Expression, values ...any) *gorm.DB { 35 db.mu.Lock() 36 defer db.mu.Unlock() 37 - return db.cli.Clauses(clauses...).Exec(sql, values...) 38 } 39 40 - func (db *DB) Raw(sql string, clauses []clause.Expression, values ...any) *gorm.DB { 41 - return db.cli.Clauses(clauses...).Raw(sql, values...) 42 } 43 44 func (db *DB) AutoMigrate(models ...any) error { 45 return db.cli.AutoMigrate(models...) 46 } 47 48 - func (db *DB) Delete(value any, clauses []clause.Expression) *gorm.DB { 49 db.mu.Lock() 50 defer db.mu.Unlock() 51 - return db.cli.Clauses(clauses...).Delete(value) 52 } 53 54 - func (db *DB) First(dest any, conds ...any) *gorm.DB { 55 - return db.cli.First(dest, conds...) 56 } 57 58 // TODO: this isn't actually good. we can commit even if the db is locked here. this is probably okay for the time being, but need to figure 59 // out a better solution. right now we only do this whenever we're importing a repo though so i'm mostly not worried, but it's still bad. 60 // e.g. when we do apply writes we should also be using a transcation but we don't right now 61 - func (db *DB) BeginDangerously() *gorm.DB { 62 - return db.cli.Begin() 63 } 64 65 func (db *DB) Lock() {
··· 1 package db 2 3 import ( 4 + "context" 5 "sync" 6 7 "gorm.io/gorm" ··· 20 } 21 } 22 23 + func (db *DB) Create(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { 24 db.mu.Lock() 25 defer db.mu.Unlock() 26 + return db.cli.WithContext(ctx).Clauses(clauses...).Create(value) 27 } 28 29 + func (db *DB) Save(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { 30 db.mu.Lock() 31 defer db.mu.Unlock() 32 + return db.cli.WithContext(ctx).Clauses(clauses...).Save(value) 33 } 34 35 + func (db *DB) Exec(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB { 36 db.mu.Lock() 37 defer db.mu.Unlock() 38 + return db.cli.WithContext(ctx).Clauses(clauses...).Exec(sql, values...) 39 } 40 41 + func (db *DB) Raw(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB { 42 + return db.cli.WithContext(ctx).Clauses(clauses...).Raw(sql, values...) 43 } 44 45 func (db *DB) AutoMigrate(models ...any) error { 46 return db.cli.AutoMigrate(models...) 47 } 48 49 + func (db *DB) Delete(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { 50 db.mu.Lock() 51 defer db.mu.Unlock() 52 + return db.cli.WithContext(ctx).Clauses(clauses...).Delete(value) 53 } 54 55 + func (db *DB) First(ctx context.Context, dest any, conds ...any) *gorm.DB { 56 + return db.cli.WithContext(ctx).First(dest, conds...) 57 } 58 59 // TODO: this isn't actually good. we can commit even if the db is locked here. this is probably okay for the time being, but need to figure 60 // out a better solution. right now we only do this whenever we're importing a repo though so i'm mostly not worried, but it's still bad. 61 // e.g. when we do apply writes we should also be using a transcation but we don't right now 62 + func (db *DB) BeginDangerously(ctx context.Context) *gorm.DB { 63 + return db.cli.WithContext(ctx).Begin() 64 } 65 66 func (db *DB) Lock() {
+10 -8
server/common.go
··· 1 package server 2 3 import ( 4 "github.com/haileyok/cocoon/models" 5 ) 6 7 - func (s *Server) getActorByHandle(handle string) (*models.Actor, error) { 8 var actor models.Actor 9 - if err := s.db.First(&actor, models.Actor{Handle: handle}).Error; err != nil { 10 return nil, err 11 } 12 return &actor, nil 13 } 14 15 - func (s *Server) getRepoByEmail(email string) (*models.Repo, error) { 16 var repo models.Repo 17 - if err := s.db.First(&repo, models.Repo{Email: email}).Error; err != nil { 18 return nil, err 19 } 20 return &repo, nil 21 } 22 23 - func (s *Server) getRepoActorByEmail(email string) (*models.RepoActor, error) { 24 var repo models.RepoActor 25 - if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil { 26 return nil, err 27 } 28 return &repo, nil 29 } 30 31 - func (s *Server) getRepoActorByDid(did string) (*models.RepoActor, error) { 32 var repo models.RepoActor 33 - if err := s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil { 34 return nil, err 35 } 36 return &repo, nil
··· 1 package server 2 3 import ( 4 + "context" 5 + 6 "github.com/haileyok/cocoon/models" 7 ) 8 9 + func (s *Server) getActorByHandle(ctx context.Context, handle string) (*models.Actor, error) { 10 var actor models.Actor 11 + if err := s.db.First(ctx, &actor, models.Actor{Handle: handle}).Error; err != nil { 12 return nil, err 13 } 14 return &actor, nil 15 } 16 17 + func (s *Server) getRepoByEmail(ctx context.Context, email string) (*models.Repo, error) { 18 var repo models.Repo 19 + if err := s.db.First(ctx, &repo, models.Repo{Email: email}).Error; err != nil { 20 return nil, err 21 } 22 return &repo, nil 23 } 24 25 + func (s *Server) getRepoActorByEmail(ctx context.Context, email string) (*models.RepoActor, error) { 26 var repo models.RepoActor 27 + if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email= ?", nil, email).Scan(&repo).Error; err != nil { 28 return nil, err 29 } 30 return &repo, nil 31 } 32 33 + func (s *Server) getRepoActorByDid(ctx context.Context, did string) (*models.RepoActor, error) { 34 var repo models.RepoActor 35 + if err := s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, did).Scan(&repo).Error; err != nil { 36 return nil, err 37 } 38 return &repo, nil
+2 -1
server/handle_account.go
··· 12 13 func (s *Server) handleAccount(e echo.Context) error { 14 ctx := e.Request().Context() 15 repo, sess, err := s.getSessionRepoOrErr(e) 16 if err != nil { 17 return e.Redirect(303, "/account/signin") ··· 20 oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime) 21 22 var tokens []provider.OauthToken 23 - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil { 24 s.logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err) 25 sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error") 26 sess.Save(e.Request(), e.Response())
··· 12 13 func (s *Server) handleAccount(e echo.Context) error { 14 ctx := e.Request().Context() 15 + 16 repo, sess, err := s.getSessionRepoOrErr(e) 17 if err != nil { 18 return e.Redirect(303, "/account/signin") ··· 21 oldestPossibleSession := time.Now().Add(constants.ConfidentialClientSessionLifetime) 22 23 var tokens []provider.OauthToken 24 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE sub = ? AND created_at < ? ORDER BY created_at ASC", nil, repo.Repo.Did, oldestPossibleSession).Scan(&tokens).Error; err != nil { 25 s.logger.Error("couldnt fetch oauth sessions for account", "did", repo.Repo.Did, "error", err) 26 sess.AddFlash("Unable to fetch sessions. See server logs for more details.", "error") 27 sess.Save(e.Request(), e.Response())
+5 -3
server/handle_account_revoke.go
··· 5 "github.com/labstack/echo/v4" 6 ) 7 8 - type AccountRevokeRequest struct { 9 Token string `form:"token"` 10 } 11 12 func (s *Server) handleAccountRevoke(e echo.Context) error { 13 - var req AccountRevokeRequest 14 if err := e.Bind(&req); err != nil { 15 s.logger.Error("could not bind account revoke request", "error", err) 16 return helpers.ServerError(e, nil) ··· 21 return e.Redirect(303, "/account/signin") 22 } 23 24 - if err := s.db.Exec("DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil { 25 s.logger.Error("couldnt delete oauth session for account", "did", repo.Repo.Did, "token", req.Token, "error", err) 26 sess.AddFlash("Unable to revoke session. See server logs for more details.", "error") 27 sess.Save(e.Request(), e.Response())
··· 5 "github.com/labstack/echo/v4" 6 ) 7 8 + type AccountRevokeInput struct { 9 Token string `form:"token"` 10 } 11 12 func (s *Server) handleAccountRevoke(e echo.Context) error { 13 + ctx := e.Request().Context() 14 + 15 + var req AccountRevokeInput 16 if err := e.Bind(&req); err != nil { 17 s.logger.Error("could not bind account revoke request", "error", err) 18 return helpers.ServerError(e, nil) ··· 23 return e.Redirect(303, "/account/signin") 24 } 25 26 + if err := s.db.Exec(ctx, "DELETE FROM oauth_tokens WHERE sub = ? AND token = ?", nil, repo.Repo.Did, req.Token).Error; err != nil { 27 s.logger.Error("couldnt delete oauth session for account", "did", repo.Repo.Did, "token", req.Token, "error", err) 28 sess.AddFlash("Unable to revoke session. See server logs for more details.", "error") 29 sess.Save(e.Request(), e.Response())
+10 -6
server/handle_account_signin.go
··· 14 "gorm.io/gorm" 15 ) 16 17 - type OauthSigninRequest struct { 18 Username string `form:"username"` 19 Password string `form:"password"` 20 QueryParams string `form:"query_params"` 21 } 22 23 func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) { 24 sess, err := session.Get("session", e) 25 if err != nil { 26 return nil, nil, err ··· 31 return nil, sess, errors.New("did was not set in session") 32 } 33 34 - repo, err := s.getRepoActorByDid(did) 35 if err != nil { 36 return nil, sess, err 37 } ··· 60 } 61 62 func (s *Server) handleAccountSigninPost(e echo.Context) error { 63 - var req OauthSigninRequest 64 if err := e.Bind(&req); err != nil { 65 s.logger.Error("error binding sign in req", "error", err) 66 return helpers.ServerError(e, nil) ··· 83 var err error 84 switch idtype { 85 case "did": 86 - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error 87 case "handle": 88 - err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error 89 case "email": 90 - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error 91 } 92 if err != nil { 93 if err == gorm.ErrRecordNotFound {
··· 14 "gorm.io/gorm" 15 ) 16 17 + type OauthSigninInput struct { 18 Username string `form:"username"` 19 Password string `form:"password"` 20 QueryParams string `form:"query_params"` 21 } 22 23 func (s *Server) getSessionRepoOrErr(e echo.Context) (*models.RepoActor, *sessions.Session, error) { 24 + ctx := e.Request().Context() 25 + 26 sess, err := session.Get("session", e) 27 if err != nil { 28 return nil, nil, err ··· 33 return nil, sess, errors.New("did was not set in session") 34 } 35 36 + repo, err := s.getRepoActorByDid(ctx, did) 37 if err != nil { 38 return nil, sess, err 39 } ··· 62 } 63 64 func (s *Server) handleAccountSigninPost(e echo.Context) error { 65 + ctx := e.Request().Context() 66 + 67 + var req OauthSigninInput 68 if err := e.Bind(&req); err != nil { 69 s.logger.Error("error binding sign in req", "error", err) 70 return helpers.ServerError(e, nil) ··· 87 var err error 88 switch idtype { 89 case "did": 90 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Username).Scan(&repo).Error 91 case "handle": 92 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Username).Scan(&repo).Error 93 case "email": 94 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Username).Scan(&repo).Error 95 } 96 if err != nil { 97 if err == gorm.ErrRecordNotFound {
+3 -1
server/handle_actor_put_preferences.go
··· 10 // This is kinda lame. Not great to implement app.bsky in the pds, but alas 11 12 func (s *Server) handleActorPutPreferences(e echo.Context) error { 13 repo := e.Get("repo").(*models.RepoActor) 14 15 var prefs map[string]any ··· 22 return err 23 } 24 25 - if err := s.db.Exec("UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil { 26 return err 27 } 28
··· 10 // This is kinda lame. Not great to implement app.bsky in the pds, but alas 11 12 func (s *Server) handleActorPutPreferences(e echo.Context) error { 13 + ctx := e.Request().Context() 14 + 15 repo := e.Get("repo").(*models.RepoActor) 16 17 var prefs map[string]any ··· 24 return err 25 } 26 27 + if err := s.db.Exec(ctx, "UPDATE repos SET preferences = ? WHERE did = ?", nil, b, repo.Repo.Did).Error; err != nil { 28 return err 29 } 30
+3 -1
server/handle_identity_request_plc_operation.go
··· 10 ) 11 12 func (s *Server) handleIdentityRequestPlcOperationSignature(e echo.Context) error { 13 urepo := e.Get("repo").(*models.RepoActor) 14 15 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 16 eat := time.Now().Add(10 * time.Minute).UTC() 17 18 - if err := s.db.Exec("UPDATE repos SET plc_operation_code = ?, plc_operation_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 19 s.logger.Error("error updating user", "error", err) 20 return helpers.ServerError(e, nil) 21 }
··· 10 ) 11 12 func (s *Server) handleIdentityRequestPlcOperationSignature(e echo.Context) error { 13 + ctx := e.Request().Context() 14 + 15 urepo := e.Get("repo").(*models.RepoActor) 16 17 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 18 eat := time.Now().Add(10 * time.Minute).UTC() 19 20 + if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = ?, plc_operation_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 21 s.logger.Error("error updating user", "error", err) 22 return helpers.ServerError(e, nil) 23 }
+1 -1
server/handle_identity_sign_plc_operation.go
··· 92 return helpers.ServerError(e, nil) 93 } 94 95 - if err := s.db.Exec("UPDATE repos SET plc_operation_code = NULL, plc_operation_code_expires_at = NULL WHERE did = ?", nil, repo.Repo.Did).Error; err != nil { 96 s.logger.Error("error updating repo", "error", err) 97 return helpers.ServerError(e, nil) 98 }
··· 92 return helpers.ServerError(e, nil) 93 } 94 95 + if err := s.db.Exec(ctx, "UPDATE repos SET plc_operation_code = NULL, plc_operation_code_expires_at = NULL WHERE did = ?", nil, repo.Repo.Did).Error; err != nil { 96 s.logger.Error("error updating repo", "error", err) 97 return helpers.ServerError(e, nil) 98 }
+1 -1
server/handle_identity_update_handle.go
··· 94 }, 95 }) 96 97 - if err := s.db.Exec("UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil { 98 s.logger.Error("error updating handle in db", "error", err) 99 return helpers.ServerError(e, nil) 100 }
··· 94 }, 95 }) 96 97 + if err := s.db.Exec(ctx, "UPDATE actors SET handle = ? WHERE did = ?", nil, req.Handle, repo.Repo.Did).Error; err != nil { 98 s.logger.Error("error updating handle in db", "error", err) 99 return helpers.ServerError(e, nil) 100 }
+3 -1
server/handle_import_repo.go
··· 18 ) 19 20 func (s *Server) handleRepoImportRepo(e echo.Context) error { 21 urepo := e.Get("repo").(*models.RepoActor) 22 23 b, err := io.ReadAll(e.Request().Body) ··· 63 return helpers.ServerError(e, nil) 64 } 65 66 - tx := s.db.BeginDangerously() 67 68 clock := syntax.NewTIDClock(0) 69
··· 18 ) 19 20 func (s *Server) handleRepoImportRepo(e echo.Context) error { 21 + ctx := e.Request().Context() 22 + 23 urepo := e.Get("repo").(*models.RepoActor) 24 25 b, err := io.ReadAll(e.Request().Body) ··· 65 return helpers.ServerError(e, nil) 66 } 67 68 + tx := s.db.BeginDangerously(ctx) 69 70 clock := syntax.NewTIDClock(0) 71
+7 -3
server/handle_oauth_authorize.go
··· 13 ) 14 15 func (s *Server) handleOauthAuthorizeGet(e echo.Context) error { 16 reqUri := e.QueryParam("request_uri") 17 if reqUri == "" { 18 // render page for logged out dev ··· 38 } 39 40 var req provider.OauthAuthorizationRequest 41 - if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { 42 return helpers.ServerError(e, to.StringPtr(err.Error())) 43 } 44 ··· 72 } 73 74 func (s *Server) handleOauthAuthorizePost(e echo.Context) error { 75 repo, _, err := s.getSessionRepoOrErr(e) 76 if err != nil { 77 return e.Redirect(303, "/account/signin") ··· 89 } 90 91 var authReq provider.OauthAuthorizationRequest 92 - if err := s.db.Raw("SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { 93 return helpers.ServerError(e, to.StringPtr(err.Error())) 94 } 95 ··· 113 114 code := oauth.GenerateCode() 115 116 - if err := s.db.Exec("UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { 117 s.logger.Error("error updating authorization request", "error", err) 118 return helpers.ServerError(e, nil) 119 }
··· 13 ) 14 15 func (s *Server) handleOauthAuthorizeGet(e echo.Context) error { 16 + ctx := e.Request().Context() 17 + 18 reqUri := e.QueryParam("request_uri") 19 if reqUri == "" { 20 // render page for logged out dev ··· 40 } 41 42 var req provider.OauthAuthorizationRequest 43 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&req).Error; err != nil { 44 return helpers.ServerError(e, to.StringPtr(err.Error())) 45 } 46 ··· 74 } 75 76 func (s *Server) handleOauthAuthorizePost(e echo.Context) error { 77 + ctx := e.Request().Context() 78 + 79 repo, _, err := s.getSessionRepoOrErr(e) 80 if err != nil { 81 return e.Redirect(303, "/account/signin") ··· 93 } 94 95 var authReq provider.OauthAuthorizationRequest 96 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_authorization_requests WHERE request_id = ?", nil, reqId).Scan(&authReq).Error; err != nil { 97 return helpers.ServerError(e, to.StringPtr(err.Error())) 98 } 99 ··· 117 118 code := oauth.GenerateCode() 119 120 + if err := s.db.Exec(ctx, "UPDATE oauth_authorization_requests SET sub = ?, code = ?, accepted = ?, ip = ? WHERE request_id = ?", nil, repo.Repo.Did, code, true, e.RealIP(), reqId).Error; err != nil { 121 s.logger.Error("error updating authorization request", "error", err) 122 return helpers.ServerError(e, nil) 123 }
+3 -1
server/handle_oauth_par.go
··· 19 } 20 21 func (s *Server) handleOauthPar(e echo.Context) error { 22 var parRequest provider.ParRequest 23 if err := e.Bind(&parRequest); err != nil { 24 s.logger.Error("error binding for par request", "error", err) ··· 86 ExpiresAt: eat, 87 } 88 89 - if err := s.db.Create(authRequest, nil).Error; err != nil { 90 s.logger.Error("error creating auth request in db", "error", err) 91 return helpers.ServerError(e, nil) 92 }
··· 19 } 20 21 func (s *Server) handleOauthPar(e echo.Context) error { 22 + ctx := e.Request().Context() 23 + 24 var parRequest provider.ParRequest 25 if err := e.Bind(&parRequest); err != nil { 26 s.logger.Error("error binding for par request", "error", err) ··· 88 ExpiresAt: eat, 89 } 90 91 + if err := s.db.Create(ctx, authRequest, nil).Error; err != nil { 92 s.logger.Error("error creating auth request in db", "error", err) 93 return helpers.ServerError(e, nil) 94 }
+7 -5
server/handle_oauth_token.go
··· 38 } 39 40 func (s *Server) handleOauthToken(e echo.Context) error { 41 var req OauthTokenRequest 42 if err := e.Bind(&req); err != nil { 43 s.logger.Error("error binding token request", "error", err) ··· 84 85 var authReq provider.OauthAuthorizationRequest 86 // get the lil guy and delete him 87 - if err := s.db.Raw("DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil { 88 s.logger.Error("error finding authorization request", "error", err) 89 return helpers.ServerError(e, nil) 90 } ··· 128 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided")) 129 } 130 131 - repo, err := s.getRepoActorByDid(*authReq.Sub) 132 if err != nil { 133 helpers.InputError(e, to.StringPtr("unable to find actor")) 134 } ··· 159 return err 160 } 161 162 - if err := s.db.Create(&provider.OauthToken{ 163 ClientId: authReq.ClientId, 164 ClientAuth: *clientAuth, 165 Parameters: authReq.Parameters, ··· 199 } 200 201 var oauthToken provider.OauthToken 202 - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil { 203 s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken) 204 return helpers.ServerError(e, nil) 205 } ··· 257 return err 258 } 259 260 - if err := s.db.Exec("UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil { 261 s.logger.Error("error updating token", "error", err) 262 return helpers.ServerError(e, nil) 263 }
··· 38 } 39 40 func (s *Server) handleOauthToken(e echo.Context) error { 41 + ctx := e.Request().Context() 42 + 43 var req OauthTokenRequest 44 if err := e.Bind(&req); err != nil { 45 s.logger.Error("error binding token request", "error", err) ··· 86 87 var authReq provider.OauthAuthorizationRequest 88 // get the lil guy and delete him 89 + if err := s.db.Raw(ctx, "DELETE FROM oauth_authorization_requests WHERE code = ? RETURNING *", nil, *req.Code).Scan(&authReq).Error; err != nil { 90 s.logger.Error("error finding authorization request", "error", err) 91 return helpers.ServerError(e, nil) 92 } ··· 130 return helpers.InputError(e, to.StringPtr("code_challenge parameter wasn't provided")) 131 } 132 133 + repo, err := s.getRepoActorByDid(ctx, *authReq.Sub) 134 if err != nil { 135 helpers.InputError(e, to.StringPtr("unable to find actor")) 136 } ··· 161 return err 162 } 163 164 + if err := s.db.Create(ctx, &provider.OauthToken{ 165 ClientId: authReq.ClientId, 166 ClientAuth: *clientAuth, 167 Parameters: authReq.Parameters, ··· 201 } 202 203 var oauthToken provider.OauthToken 204 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE refresh_token = ?", nil, req.RefreshToken).Scan(&oauthToken).Error; err != nil { 205 s.logger.Error("error finding oauth token by refresh token", "error", err, "refresh_token", req.RefreshToken) 206 return helpers.ServerError(e, nil) 207 } ··· 259 return err 260 } 261 262 + if err := s.db.Exec(ctx, "UPDATE oauth_tokens SET token = ?, refresh_token = ?, expires_at = ?, updated_at = ? WHERE refresh_token = ?", nil, accessString, nextRefreshToken, eat, now, *req.RefreshToken).Error; err != nil { 263 s.logger.Error("error updating token", "error", err) 264 return helpers.ServerError(e, nil) 265 }
+4 -2
server/handle_repo_describe_repo.go
··· 20 } 21 22 func (s *Server) handleDescribeRepo(e echo.Context) error { 23 did := e.QueryParam("repo") 24 - repo, err := s.getRepoActorByDid(did) 25 if err != nil { 26 if err == gorm.ErrRecordNotFound { 27 return helpers.InputError(e, to.StringPtr("RepoNotFound")) ··· 64 } 65 66 var records []models.Record 67 - if err := s.db.Raw("SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil { 68 s.logger.Error("error getting collections", "error", err) 69 return helpers.ServerError(e, nil) 70 }
··· 20 } 21 22 func (s *Server) handleDescribeRepo(e echo.Context) error { 23 + ctx := e.Request().Context() 24 + 25 did := e.QueryParam("repo") 26 + repo, err := s.getRepoActorByDid(ctx, did) 27 if err != nil { 28 if err == gorm.ErrRecordNotFound { 29 return helpers.InputError(e, to.StringPtr("RepoNotFound")) ··· 66 } 67 68 var records []models.Record 69 + if err := s.db.Raw(ctx, "SELECT DISTINCT(nsid) FROM records WHERE did = ?", nil, repo.Repo.Did).Scan(&records).Error; err != nil { 70 s.logger.Error("error getting collections", "error", err) 71 return helpers.ServerError(e, nil) 72 }
+3 -1
server/handle_repo_get_record.go
··· 14 } 15 16 func (s *Server) handleRepoGetRecord(e echo.Context) error { 17 repo := e.QueryParam("repo") 18 collection := e.QueryParam("collection") 19 rkey := e.QueryParam("rkey") ··· 32 } 33 34 var record models.Record 35 - if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil { 36 // TODO: handle error nicely 37 return err 38 }
··· 14 } 15 16 func (s *Server) handleRepoGetRecord(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + 19 repo := e.QueryParam("repo") 20 collection := e.QueryParam("collection") 21 rkey := e.QueryParam("rkey") ··· 34 } 35 36 var record models.Record 37 + if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? AND rkey = ?"+cidquery, nil, params...).Scan(&record).Error; err != nil { 38 // TODO: handle error nicely 39 return err 40 }
+4 -2
server/handle_repo_list_missing_blobs.go
··· 22 } 23 24 func (s *Server) handleListMissingBlobs(e echo.Context) error { 25 urepo := e.Get("repo").(*models.RepoActor) 26 27 limitStr := e.QueryParam("limit") ··· 35 } 36 37 var records []models.Record 38 - if err := s.db.Raw("SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil { 39 s.logger.Error("failed to get records for listMissingBlobs", "error", err) 40 return helpers.ServerError(e, nil) 41 } ··· 69 } 70 71 var count int64 72 - if err := s.db.Raw("SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil { 73 continue 74 } 75
··· 22 } 23 24 func (s *Server) handleListMissingBlobs(e echo.Context) error { 25 + ctx := e.Request().Context() 26 + 27 urepo := e.Get("repo").(*models.RepoActor) 28 29 limitStr := e.QueryParam("limit") ··· 37 } 38 39 var records []models.Record 40 + if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&records).Error; err != nil { 41 s.logger.Error("failed to get records for listMissingBlobs", "error", err) 42 return helpers.ServerError(e, nil) 43 } ··· 71 } 72 73 var count int64 74 + if err := s.db.Raw(ctx, "SELECT COUNT(*) FROM blobs WHERE did = ? AND cid = ?", nil, urepo.Repo.Did, ref.cid.Bytes()).Scan(&count).Error; err != nil { 75 continue 76 } 77
+4 -2
server/handle_repo_list_records.go
··· 46 } 47 48 func (s *Server) handleListRecords(e echo.Context) error { 49 var req ComAtprotoRepoListRecordsRequest 50 if err := e.Bind(&req); err != nil { 51 s.logger.Error("could not bind list records request", "error", err) ··· 78 79 did := req.Repo 80 if _, err := syntax.ParseDID(did); err != nil { 81 - actor, err := s.getActorByHandle(req.Repo) 82 if err != nil { 83 return helpers.InputError(e, to.StringPtr("RepoNotFound")) 84 } ··· 93 params = append(params, limit) 94 95 var records []models.Record 96 - if err := s.db.Raw("SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil { 97 s.logger.Error("error getting records", "error", err) 98 return helpers.ServerError(e, nil) 99 }
··· 46 } 47 48 func (s *Server) handleListRecords(e echo.Context) error { 49 + ctx := e.Request().Context() 50 + 51 var req ComAtprotoRepoListRecordsRequest 52 if err := e.Bind(&req); err != nil { 53 s.logger.Error("could not bind list records request", "error", err) ··· 80 81 did := req.Repo 82 if _, err := syntax.ParseDID(did); err != nil { 83 + actor, err := s.getActorByHandle(ctx, req.Repo) 84 if err != nil { 85 return helpers.InputError(e, to.StringPtr("RepoNotFound")) 86 } ··· 95 params = append(params, limit) 96 97 var records []models.Record 98 + if err := s.db.Raw(ctx, "SELECT * FROM records WHERE did = ? AND nsid = ? "+cursorquery+" ORDER BY created_at "+sort+" limit ?", nil, params...).Scan(&records).Error; err != nil { 99 s.logger.Error("error getting records", "error", err) 100 return helpers.ServerError(e, nil) 101 }
+3 -1
server/handle_repo_list_repos.go
··· 21 22 // TODO: paginate this bitch 23 func (s *Server) handleListRepos(e echo.Context) error { 24 var repos []models.Repo 25 - if err := s.db.Raw("SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil { 26 return err 27 } 28
··· 21 22 // TODO: paginate this bitch 23 func (s *Server) handleListRepos(e echo.Context) error { 24 + ctx := e.Request().Context() 25 + 26 var repos []models.Repo 27 + if err := s.db.Raw(ctx, "SELECT * FROM repos ORDER BY created_at DESC LIMIT 500", nil).Scan(&repos).Error; err != nil { 28 return err 29 } 30
+5 -3
server/handle_repo_upload_blob.go
··· 32 } 33 34 func (s *Server) handleRepoUploadBlob(e echo.Context) error { 35 urepo := e.Get("repo").(*models.RepoActor) 36 37 mime := e.Request().Header.Get("content-type") ··· 51 Storage: storage, 52 } 53 54 - if err := s.db.Create(&blob, nil).Error; err != nil { 55 s.logger.Error("error creating new blob in db", "error", err) 56 return helpers.ServerError(e, nil) 57 } ··· 84 Data: data, 85 } 86 87 - if err := s.db.Create(&blobPart, nil).Error; err != nil { 88 s.logger.Error("error adding blob part to db", "error", err) 89 return helpers.ServerError(e, nil) 90 } ··· 131 } 132 } 133 134 - if err := s.db.Exec("UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil { 135 // there should probably be somme handling here if this fails... 136 s.logger.Error("error updating blob", "error", err) 137 return helpers.ServerError(e, nil)
··· 32 } 33 34 func (s *Server) handleRepoUploadBlob(e echo.Context) error { 35 + ctx := e.Request().Context() 36 + 37 urepo := e.Get("repo").(*models.RepoActor) 38 39 mime := e.Request().Header.Get("content-type") ··· 53 Storage: storage, 54 } 55 56 + if err := s.db.Create(ctx, &blob, nil).Error; err != nil { 57 s.logger.Error("error creating new blob in db", "error", err) 58 return helpers.ServerError(e, nil) 59 } ··· 86 Data: data, 87 } 88 89 + if err := s.db.Create(ctx, &blobPart, nil).Error; err != nil { 90 s.logger.Error("error adding blob part to db", "error", err) 91 return helpers.ServerError(e, nil) 92 } ··· 133 } 134 } 135 136 + if err := s.db.Exec(ctx, "UPDATE blobs SET cid = ? WHERE id = ?", nil, c.Bytes(), blob.ID).Error; err != nil { 137 // there should probably be somme handling here if this fails... 138 s.logger.Error("error updating blob", "error", err) 139 return helpers.ServerError(e, nil)
+3 -1
server/handle_server_activate_account.go
··· 18 } 19 20 func (s *Server) handleServerActivateAccount(e echo.Context) error { 21 var req ComAtprotoServerDeactivateAccountRequest 22 if err := e.Bind(&req); err != nil { 23 s.logger.Error("error binding", "error", err) ··· 26 27 urepo := e.Get("repo").(*models.RepoActor) 28 29 - if err := s.db.Exec("UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil { 30 s.logger.Error("error updating account status to deactivated", "error", err) 31 return helpers.ServerError(e, nil) 32 }
··· 18 } 19 20 func (s *Server) handleServerActivateAccount(e echo.Context) error { 21 + ctx := e.Request().Context() 22 + 23 var req ComAtprotoServerDeactivateAccountRequest 24 if err := e.Bind(&req); err != nil { 25 s.logger.Error("error binding", "error", err) ··· 28 29 urepo := e.Get("repo").(*models.RepoActor) 30 31 + if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, false, urepo.Repo.Did).Error; err != nil { 32 s.logger.Error("error updating account status to deactivated", "error", err) 33 return helpers.ServerError(e, nil) 34 }
+5 -3
server/handle_server_check_account_status.go
··· 20 } 21 22 func (s *Server) handleServerCheckAccountStatus(e echo.Context) error { 23 urepo := e.Get("repo").(*models.RepoActor) 24 25 resp := ComAtprotoServerCheckAccountStatusResponse{ ··· 41 } 42 43 var blockCtResp CountResp 44 - if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil { 45 s.logger.Error("error getting block count", "error", err) 46 return helpers.ServerError(e, nil) 47 } 48 resp.RepoBlocks = blockCtResp.Ct 49 50 var recCtResp CountResp 51 - if err := s.db.Raw("SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil { 52 s.logger.Error("error getting record count", "error", err) 53 return helpers.ServerError(e, nil) 54 } 55 resp.IndexedRecords = recCtResp.Ct 56 57 var blobCtResp CountResp 58 - if err := s.db.Raw("SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil { 59 s.logger.Error("error getting record count", "error", err) 60 return helpers.ServerError(e, nil) 61 }
··· 20 } 21 22 func (s *Server) handleServerCheckAccountStatus(e echo.Context) error { 23 + ctx := e.Request().Context() 24 + 25 urepo := e.Get("repo").(*models.RepoActor) 26 27 resp := ComAtprotoServerCheckAccountStatusResponse{ ··· 43 } 44 45 var blockCtResp CountResp 46 + if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blocks WHERE did = ?", nil, urepo.Repo.Did).Scan(&blockCtResp).Error; err != nil { 47 s.logger.Error("error getting block count", "error", err) 48 return helpers.ServerError(e, nil) 49 } 50 resp.RepoBlocks = blockCtResp.Ct 51 52 var recCtResp CountResp 53 + if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM records WHERE did = ?", nil, urepo.Repo.Did).Scan(&recCtResp).Error; err != nil { 54 s.logger.Error("error getting record count", "error", err) 55 return helpers.ServerError(e, nil) 56 } 57 resp.IndexedRecords = recCtResp.Ct 58 59 var blobCtResp CountResp 60 + if err := s.db.Raw(ctx, "SELECT COUNT(*) AS ct FROM blobs WHERE did = ?", nil, urepo.Repo.Did).Scan(&blobCtResp).Error; err != nil { 61 s.logger.Error("error getting record count", "error", err) 62 return helpers.ServerError(e, nil) 63 }
+3 -1
server/handle_server_confirm_email.go
··· 15 } 16 17 func (s *Server) handleServerConfirmEmail(e echo.Context) error { 18 urepo := e.Get("repo").(*models.RepoActor) 19 20 var req ComAtprotoServerConfirmEmailRequest ··· 41 42 now := time.Now().UTC() 43 44 - if err := s.db.Exec("UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil { 45 s.logger.Error("error updating user", "error", err) 46 return helpers.ServerError(e, nil) 47 }
··· 15 } 16 17 func (s *Server) handleServerConfirmEmail(e echo.Context) error { 18 + ctx := e.Request().Context() 19 + 20 urepo := e.Get("repo").(*models.RepoActor) 21 22 var req ComAtprotoServerConfirmEmailRequest ··· 43 44 now := time.Now().UTC() 45 46 + if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = NULL, email_verification_code_expires_at = NULL, email_confirmed_at = ? WHERE did = ?", nil, now, urepo.Repo.Did).Error; err != nil { 47 s.logger.Error("error updating user", "error", err) 48 return helpers.ServerError(e, nil) 49 }
+16 -14
server/handle_server_create_account.go
··· 36 } 37 38 func (s *Server) handleCreateAccount(e echo.Context) error { 39 var request ComAtprotoServerCreateAccountRequest 40 41 if err := e.Bind(&request); err != nil { ··· 68 } 69 } 70 } 71 - 72 var signupDid string 73 if request.Did != nil { 74 - signupDid = *request.Did; 75 - 76 token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1)) 77 if token == "" { 78 return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did")) ··· 90 } 91 92 // see if the handle is already taken 93 - actor, err := s.getActorByHandle(request.Handle) 94 if err != nil && err != gorm.ErrRecordNotFound { 95 s.logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err) 96 return helpers.ServerError(e, nil) ··· 109 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 110 } 111 112 - if err := s.db.Raw("SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 113 if err == gorm.ErrRecordNotFound { 114 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 115 } ··· 123 } 124 125 // see if the email is already taken 126 - existingRepo, err := s.getRepoByEmail(request.Email) 127 if err != nil && err != gorm.ErrRecordNotFound { 128 s.logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err) 129 return helpers.ServerError(e, nil) ··· 137 var k *atcrypto.PrivateKeyK256 138 139 if signupDid != "" { 140 - reservedKey, err := s.getReservedKey(signupDid) 141 if err != nil { 142 s.logger.Error("error looking up reserved key", "error", err) 143 } ··· 148 k = nil 149 } else { 150 defer func() { 151 - if delErr := s.deleteReservedKey(reservedKey.KeyDid, reservedKey.Did); delErr != nil { 152 s.logger.Error("error deleting reserved key", "error", delErr) 153 } 154 }() ··· 199 Handle: request.Handle, 200 } 201 202 - if err := s.db.Create(&urepo, nil).Error; err != nil { 203 s.logger.Error("error inserting new repo", "error", err) 204 return helpers.ServerError(e, nil) 205 } 206 - 207 - if err := s.db.Create(&actor, nil).Error; err != nil { 208 s.logger.Error("error inserting new actor", "error", err) 209 return helpers.ServerError(e, nil) 210 } 211 } else { 212 - if err := s.db.Save(&actor, nil).Error; err != nil { 213 s.logger.Error("error inserting new actor", "error", err) 214 return helpers.ServerError(e, nil) 215 } ··· 241 } 242 243 if s.config.RequireInvite { 244 - if err := s.db.Raw("UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 245 s.logger.Error("error decrementing use count", "error", err) 246 return helpers.ServerError(e, nil) 247 } 248 } 249 250 - sess, err := s.createSession(&urepo) 251 if err != nil { 252 s.logger.Error("error creating new session", "error", err) 253 return helpers.ServerError(e, nil)
··· 36 } 37 38 func (s *Server) handleCreateAccount(e echo.Context) error { 39 + ctx := e.Request().Context() 40 + 41 var request ComAtprotoServerCreateAccountRequest 42 43 if err := e.Bind(&request); err != nil { ··· 70 } 71 } 72 } 73 + 74 var signupDid string 75 if request.Did != nil { 76 + signupDid = *request.Did 77 + 78 token := strings.TrimSpace(strings.Replace(e.Request().Header.Get("authorization"), "Bearer ", "", 1)) 79 if token == "" { 80 return helpers.UnauthorizedError(e, to.StringPtr("must authenticate to use an existing did")) ··· 92 } 93 94 // see if the handle is already taken 95 + actor, err := s.getActorByHandle(ctx, request.Handle) 96 if err != nil && err != gorm.ErrRecordNotFound { 97 s.logger.Error("error looking up handle in db", "endpoint", "com.atproto.server.createAccount", "error", err) 98 return helpers.ServerError(e, nil) ··· 111 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 112 } 113 114 + if err := s.db.Raw(ctx, "SELECT * FROM invite_codes WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 115 if err == gorm.ErrRecordNotFound { 116 return helpers.InputError(e, to.StringPtr("InvalidInviteCode")) 117 } ··· 125 } 126 127 // see if the email is already taken 128 + existingRepo, err := s.getRepoByEmail(ctx, request.Email) 129 if err != nil && err != gorm.ErrRecordNotFound { 130 s.logger.Error("error looking up email in db", "endpoint", "com.atproto.server.createAccount", "error", err) 131 return helpers.ServerError(e, nil) ··· 139 var k *atcrypto.PrivateKeyK256 140 141 if signupDid != "" { 142 + reservedKey, err := s.getReservedKey(ctx, signupDid) 143 if err != nil { 144 s.logger.Error("error looking up reserved key", "error", err) 145 } ··· 150 k = nil 151 } else { 152 defer func() { 153 + if delErr := s.deleteReservedKey(ctx, reservedKey.KeyDid, reservedKey.Did); delErr != nil { 154 s.logger.Error("error deleting reserved key", "error", delErr) 155 } 156 }() ··· 201 Handle: request.Handle, 202 } 203 204 + if err := s.db.Create(ctx, &urepo, nil).Error; err != nil { 205 s.logger.Error("error inserting new repo", "error", err) 206 return helpers.ServerError(e, nil) 207 } 208 + 209 + if err := s.db.Create(ctx, &actor, nil).Error; err != nil { 210 s.logger.Error("error inserting new actor", "error", err) 211 return helpers.ServerError(e, nil) 212 } 213 } else { 214 + if err := s.db.Save(ctx, &actor, nil).Error; err != nil { 215 s.logger.Error("error inserting new actor", "error", err) 216 return helpers.ServerError(e, nil) 217 } ··· 243 } 244 245 if s.config.RequireInvite { 246 + if err := s.db.Raw(ctx, "UPDATE invite_codes SET remaining_use_count = remaining_use_count - 1 WHERE code = ?", nil, request.InviteCode).Scan(&ic).Error; err != nil { 247 s.logger.Error("error decrementing use count", "error", err) 248 return helpers.ServerError(e, nil) 249 } 250 } 251 252 + sess, err := s.createSession(ctx, &urepo) 253 if err != nil { 254 s.logger.Error("error creating new session", "error", err) 255 return helpers.ServerError(e, nil)
+3 -1
server/handle_server_create_invite_code.go
··· 17 } 18 19 func (s *Server) handleCreateInviteCode(e echo.Context) error { 20 var req ComAtprotoServerCreateInviteCodeRequest 21 if err := e.Bind(&req); err != nil { 22 s.logger.Error("error binding", "error", err) ··· 37 acc = *req.ForAccount 38 } 39 40 - if err := s.db.Create(&models.InviteCode{ 41 Code: ic, 42 Did: acc, 43 RemainingUseCount: req.UseCount,
··· 17 } 18 19 func (s *Server) handleCreateInviteCode(e echo.Context) error { 20 + ctx := e.Request().Context() 21 + 22 var req ComAtprotoServerCreateInviteCodeRequest 23 if err := e.Bind(&req); err != nil { 24 s.logger.Error("error binding", "error", err) ··· 39 acc = *req.ForAccount 40 } 41 42 + if err := s.db.Create(ctx, &models.InviteCode{ 43 Code: ic, 44 Did: acc, 45 RemainingUseCount: req.UseCount,
+3 -1
server/handle_server_create_invite_codes.go
··· 22 } 23 24 func (s *Server) handleCreateInviteCodes(e echo.Context) error { 25 var req ComAtprotoServerCreateInviteCodesRequest 26 if err := e.Bind(&req); err != nil { 27 s.logger.Error("error binding", "error", err) ··· 50 ic := uuid.NewString() 51 ics = append(ics, ic) 52 53 - if err := s.db.Create(&models.InviteCode{ 54 Code: ic, 55 Did: did, 56 RemainingUseCount: req.UseCount,
··· 22 } 23 24 func (s *Server) handleCreateInviteCodes(e echo.Context) error { 25 + ctx := e.Request().Context() 26 + 27 var req ComAtprotoServerCreateInviteCodesRequest 28 if err := e.Bind(&req); err != nil { 29 s.logger.Error("error binding", "error", err) ··· 52 ic := uuid.NewString() 53 ics = append(ics, ic) 54 55 + if err := s.db.Create(ctx, &models.InviteCode{ 56 Code: ic, 57 Did: did, 58 RemainingUseCount: req.UseCount,
+6 -4
server/handle_server_create_session.go
··· 32 } 33 34 func (s *Server) handleCreateSession(e echo.Context) error { 35 var req ComAtprotoServerCreateSessionRequest 36 if err := e.Bind(&req); err != nil { 37 s.logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err) ··· 65 var err error 66 switch idtype { 67 case "did": 68 - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error 69 case "handle": 70 - err = s.db.Raw("SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error 71 case "email": 72 - err = s.db.Raw("SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error 73 } 74 75 if err != nil { ··· 88 return helpers.InputError(e, to.StringPtr("InvalidRequest")) 89 } 90 91 - sess, err := s.createSession(&repo.Repo) 92 if err != nil { 93 s.logger.Error("error creating session", "error", err) 94 return helpers.ServerError(e, nil)
··· 32 } 33 34 func (s *Server) handleCreateSession(e echo.Context) error { 35 + ctx := e.Request().Context() 36 + 37 var req ComAtprotoServerCreateSessionRequest 38 if err := e.Bind(&req); err != nil { 39 s.logger.Error("error binding request", "endpoint", "com.atproto.server.serverCreateSession", "error", err) ··· 67 var err error 68 switch idtype { 69 case "did": 70 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.did = ?", nil, req.Identifier).Scan(&repo).Error 71 case "handle": 72 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM actors a LEFT JOIN repos r ON a.did = r.did WHERE a.handle = ?", nil, req.Identifier).Scan(&repo).Error 73 case "email": 74 + err = s.db.Raw(ctx, "SELECT r.*, a.* FROM repos r LEFT JOIN actors a ON r.did = a.did WHERE r.email = ?", nil, req.Identifier).Scan(&repo).Error 75 } 76 77 if err != nil { ··· 90 return helpers.InputError(e, to.StringPtr("InvalidRequest")) 91 } 92 93 + sess, err := s.createSession(ctx, &repo.Repo) 94 if err != nil { 95 s.logger.Error("error creating session", "error", err) 96 return helpers.ServerError(e, nil)
+3 -1
server/handle_server_deactivate_account.go
··· 19 } 20 21 func (s *Server) handleServerDeactivateAccount(e echo.Context) error { 22 var req ComAtprotoServerDeactivateAccountRequest 23 if err := e.Bind(&req); err != nil { 24 s.logger.Error("error binding", "error", err) ··· 27 28 urepo := e.Get("repo").(*models.RepoActor) 29 30 - if err := s.db.Exec("UPDATE repos SET deactivated = ? WHERE did = ?", nil, true, urepo.Repo.Did).Error; err != nil { 31 s.logger.Error("error updating account status to deactivated", "error", err) 32 return helpers.ServerError(e, nil) 33 }
··· 19 } 20 21 func (s *Server) handleServerDeactivateAccount(e echo.Context) error { 22 + ctx := e.Request().Context() 23 + 24 var req ComAtprotoServerDeactivateAccountRequest 25 if err := e.Bind(&req); err != nil { 26 s.logger.Error("error binding", "error", err) ··· 29 30 urepo := e.Get("repo").(*models.RepoActor) 31 32 + if err := s.db.Exec(ctx, "UPDATE repos SET deactivated = ? WHERE did = ?", nil, true, urepo.Repo.Did).Error; err != nil { 33 s.logger.Error("error updating account status to deactivated", "error", err) 34 return helpers.ServerError(e, nil) 35 }
+4 -2
server/handle_server_delete_account.go
··· 20 } 21 22 func (s *Server) handleServerDeleteAccount(e echo.Context) error { 23 var req ComAtprotoServerDeleteAccountRequest 24 if err := e.Bind(&req); err != nil { 25 s.logger.Error("error binding", "error", err) ··· 31 return helpers.ServerError(e, nil) 32 } 33 34 - urepo, err := s.getRepoActorByDid(req.Did) 35 if err != nil { 36 s.logger.Error("error getting repo", "error", err) 37 return echo.NewHTTPError(400, "account not found") ··· 66 }) 67 } 68 69 - tx := s.db.BeginDangerously() 70 if tx.Error != nil { 71 s.logger.Error("error starting transaction", "error", tx.Error) 72 return helpers.ServerError(e, nil)
··· 20 } 21 22 func (s *Server) handleServerDeleteAccount(e echo.Context) error { 23 + ctx := e.Request().Context() 24 + 25 var req ComAtprotoServerDeleteAccountRequest 26 if err := e.Bind(&req); err != nil { 27 s.logger.Error("error binding", "error", err) ··· 33 return helpers.ServerError(e, nil) 34 } 35 36 + urepo, err := s.getRepoActorByDid(ctx, req.Did) 37 if err != nil { 38 s.logger.Error("error getting repo", "error", err) 39 return echo.NewHTTPError(400, "account not found") ··· 68 }) 69 } 70 71 + tx := s.db.BeginDangerously(ctx) 72 if tx.Error != nil { 73 s.logger.Error("error starting transaction", "error", tx.Error) 74 return helpers.ServerError(e, nil)
+4 -2
server/handle_server_delete_session.go
··· 7 ) 8 9 func (s *Server) handleDeleteSession(e echo.Context) error { 10 token := e.Get("token").(string) 11 12 var acctok models.Token 13 - if err := s.db.Raw("DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil { 14 s.logger.Error("error deleting access token from db", "error", err) 15 return helpers.ServerError(e, nil) 16 } 17 18 - if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil { 19 s.logger.Error("error deleting refresh token from db", "error", err) 20 return helpers.ServerError(e, nil) 21 }
··· 7 ) 8 9 func (s *Server) handleDeleteSession(e echo.Context) error { 10 + ctx := e.Request().Context() 11 + 12 token := e.Get("token").(string) 13 14 var acctok models.Token 15 + if err := s.db.Raw(ctx, "DELETE FROM tokens WHERE token = ? RETURNING *", nil, token).Scan(&acctok).Error; err != nil { 16 s.logger.Error("error deleting access token from db", "error", err) 17 return helpers.ServerError(e, nil) 18 } 19 20 + if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, acctok.RefreshToken).Error; err != nil { 21 s.logger.Error("error deleting refresh token from db", "error", err) 22 return helpers.ServerError(e, nil) 23 }
+5 -3
server/handle_server_refresh_session.go
··· 16 } 17 18 func (s *Server) handleRefreshSession(e echo.Context) error { 19 token := e.Get("token").(string) 20 repo := e.Get("repo").(*models.RepoActor) 21 22 - if err := s.db.Exec("DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil { 23 s.logger.Error("error getting refresh token from db", "error", err) 24 return helpers.ServerError(e, nil) 25 } 26 27 - if err := s.db.Exec("DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil { 28 s.logger.Error("error deleting access token from db", "error", err) 29 return helpers.ServerError(e, nil) 30 } 31 32 - sess, err := s.createSession(&repo.Repo) 33 if err != nil { 34 s.logger.Error("error creating new session for refresh", "error", err) 35 return helpers.ServerError(e, nil)
··· 16 } 17 18 func (s *Server) handleRefreshSession(e echo.Context) error { 19 + ctx := e.Request().Context() 20 + 21 token := e.Get("token").(string) 22 repo := e.Get("repo").(*models.RepoActor) 23 24 + if err := s.db.Exec(ctx, "DELETE FROM refresh_tokens WHERE token = ?", nil, token).Error; err != nil { 25 s.logger.Error("error getting refresh token from db", "error", err) 26 return helpers.ServerError(e, nil) 27 } 28 29 + if err := s.db.Exec(ctx, "DELETE FROM tokens WHERE refresh_token = ?", nil, token).Error; err != nil { 30 s.logger.Error("error deleting access token from db", "error", err) 31 return helpers.ServerError(e, nil) 32 } 33 34 + sess, err := s.createSession(ctx, &repo.Repo) 35 if err != nil { 36 s.logger.Error("error creating new session for refresh", "error", err) 37 return helpers.ServerError(e, nil)
+3 -1
server/handle_server_request_account_delete.go
··· 10 ) 11 12 func (s *Server) handleServerRequestAccountDelete(e echo.Context) error { 13 urepo := e.Get("repo").(*models.RepoActor) 14 15 token := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 16 expiresAt := time.Now().UTC().Add(15 * time.Minute) 17 18 - if err := s.db.Exec("UPDATE repos SET account_delete_code = ?, account_delete_code_expires_at = ? WHERE did = ?", nil, token, expiresAt, urepo.Repo.Did).Error; err != nil { 19 s.logger.Error("error setting deletion token", "error", err) 20 return helpers.ServerError(e, nil) 21 }
··· 10 ) 11 12 func (s *Server) handleServerRequestAccountDelete(e echo.Context) error { 13 + ctx := e.Request().Context() 14 + 15 urepo := e.Get("repo").(*models.RepoActor) 16 17 token := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 18 expiresAt := time.Now().UTC().Add(15 * time.Minute) 19 20 + if err := s.db.Exec(ctx, "UPDATE repos SET account_delete_code = ?, account_delete_code_expires_at = ? WHERE did = ?", nil, token, expiresAt, urepo.Repo.Did).Error; err != nil { 21 s.logger.Error("error setting deletion token", "error", err) 22 return helpers.ServerError(e, nil) 23 }
+3 -1
server/handle_server_request_email_confirmation.go
··· 11 ) 12 13 func (s *Server) handleServerRequestEmailConfirmation(e echo.Context) error { 14 urepo := e.Get("repo").(*models.RepoActor) 15 16 if urepo.EmailConfirmedAt != nil { ··· 20 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 21 eat := time.Now().Add(10 * time.Minute).UTC() 22 23 - if err := s.db.Exec("UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 24 s.logger.Error("error updating user", "error", err) 25 return helpers.ServerError(e, nil) 26 }
··· 11 ) 12 13 func (s *Server) handleServerRequestEmailConfirmation(e echo.Context) error { 14 + ctx := e.Request().Context() 15 + 16 urepo := e.Get("repo").(*models.RepoActor) 17 18 if urepo.EmailConfirmedAt != nil { ··· 22 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 23 eat := time.Now().Add(10 * time.Minute).UTC() 24 25 + if err := s.db.Exec(ctx, "UPDATE repos SET email_verification_code = ?, email_verification_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 26 s.logger.Error("error updating user", "error", err) 27 return helpers.ServerError(e, nil) 28 }
+3 -1
server/handle_server_request_email_update.go
··· 14 } 15 16 func (s *Server) handleServerRequestEmailUpdate(e echo.Context) error { 17 urepo := e.Get("repo").(*models.RepoActor) 18 19 if urepo.EmailConfirmedAt != nil { 20 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 21 eat := time.Now().Add(10 * time.Minute).UTC() 22 23 - if err := s.db.Exec("UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 24 s.logger.Error("error updating repo", "error", err) 25 return helpers.ServerError(e, nil) 26 }
··· 14 } 15 16 func (s *Server) handleServerRequestEmailUpdate(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + 19 urepo := e.Get("repo").(*models.RepoActor) 20 21 if urepo.EmailConfirmedAt != nil { 22 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 23 eat := time.Now().Add(10 * time.Minute).UTC() 24 25 + if err := s.db.Exec(ctx, "UPDATE repos SET email_update_code = ?, email_update_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 26 s.logger.Error("error updating repo", "error", err) 27 return helpers.ServerError(e, nil) 28 }
+4 -2
server/handle_server_request_password_reset.go
··· 14 } 15 16 func (s *Server) handleServerRequestPasswordReset(e echo.Context) error { 17 urepo, ok := e.Get("repo").(*models.RepoActor) 18 if !ok { 19 var req ComAtprotoServerRequestPasswordResetRequest ··· 25 return err 26 } 27 28 - murepo, err := s.getRepoActorByEmail(req.Email) 29 if err != nil { 30 return err 31 } ··· 36 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 37 eat := time.Now().Add(10 * time.Minute).UTC() 38 39 - if err := s.db.Exec("UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 40 s.logger.Error("error updating repo", "error", err) 41 return helpers.ServerError(e, nil) 42 }
··· 14 } 15 16 func (s *Server) handleServerRequestPasswordReset(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + 19 urepo, ok := e.Get("repo").(*models.RepoActor) 20 if !ok { 21 var req ComAtprotoServerRequestPasswordResetRequest ··· 27 return err 28 } 29 30 + murepo, err := s.getRepoActorByEmail(ctx, req.Email) 31 if err != nil { 32 return err 33 } ··· 38 code := fmt.Sprintf("%s-%s", helpers.RandomVarchar(5), helpers.RandomVarchar(5)) 39 eat := time.Now().Add(10 * time.Minute).UTC() 40 41 + if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = ?, password_reset_code_expires_at = ? WHERE did = ?", nil, code, eat, urepo.Repo.Did).Error; err != nil { 42 s.logger.Error("error updating repo", "error", err) 43 return helpers.ServerError(e, nil) 44 }
+11 -8
server/handle_server_reserve_signing_key.go
··· 1 package server 2 3 import ( 4 "time" 5 6 "github.com/bluesky-social/indigo/atproto/atcrypto" ··· 18 } 19 20 func (s *Server) handleServerReserveSigningKey(e echo.Context) error { 21 var req ServerReserveSigningKeyRequest 22 if err := e.Bind(&req); err != nil { 23 s.logger.Error("could not bind reserve signing key request", "error", err) ··· 26 27 if req.Did != nil && *req.Did != "" { 28 var existing models.ReservedKey 29 - if err := s.db.Raw("SELECT * FROM reserved_keys WHERE did = ?", nil, *req.Did).Scan(&existing).Error; err == nil && existing.KeyDid != "" { 30 return e.JSON(200, ServerReserveSigningKeyResponse{ 31 SigningKey: existing.KeyDid, 32 }) ··· 54 CreatedAt: time.Now(), 55 } 56 57 - if err := s.db.Create(&reservedKey, nil).Error; err != nil { 58 s.logger.Error("error storing reserved key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err) 59 return helpers.ServerError(e, nil) 60 } ··· 66 }) 67 } 68 69 - func (s *Server) getReservedKey(keyDidOrDid string) (*models.ReservedKey, error) { 70 var reservedKey models.ReservedKey 71 72 - if err := s.db.Raw("SELECT * FROM reserved_keys WHERE key_did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { 73 return &reservedKey, nil 74 } 75 76 - if err := s.db.Raw("SELECT * FROM reserved_keys WHERE did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { 77 return &reservedKey, nil 78 } 79 80 return nil, nil 81 } 82 83 - func (s *Server) deleteReservedKey(keyDid string, did *string) error { 84 - if err := s.db.Exec("DELETE FROM reserved_keys WHERE key_did = ?", nil, keyDid).Error; err != nil { 85 return err 86 } 87 88 if did != nil && *did != "" { 89 - if err := s.db.Exec("DELETE FROM reserved_keys WHERE did = ?", nil, *did).Error; err != nil { 90 return err 91 } 92 }
··· 1 package server 2 3 import ( 4 + "context" 5 "time" 6 7 "github.com/bluesky-social/indigo/atproto/atcrypto" ··· 19 } 20 21 func (s *Server) handleServerReserveSigningKey(e echo.Context) error { 22 + ctx := e.Request().Context() 23 + 24 var req ServerReserveSigningKeyRequest 25 if err := e.Bind(&req); err != nil { 26 s.logger.Error("could not bind reserve signing key request", "error", err) ··· 29 30 if req.Did != nil && *req.Did != "" { 31 var existing models.ReservedKey 32 + if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE did = ?", nil, *req.Did).Scan(&existing).Error; err == nil && existing.KeyDid != "" { 33 return e.JSON(200, ServerReserveSigningKeyResponse{ 34 SigningKey: existing.KeyDid, 35 }) ··· 57 CreatedAt: time.Now(), 58 } 59 60 + if err := s.db.Create(ctx, &reservedKey, nil).Error; err != nil { 61 s.logger.Error("error storing reserved key", "endpoint", "com.atproto.server.reserveSigningKey", "error", err) 62 return helpers.ServerError(e, nil) 63 } ··· 69 }) 70 } 71 72 + func (s *Server) getReservedKey(ctx context.Context, keyDidOrDid string) (*models.ReservedKey, error) { 73 var reservedKey models.ReservedKey 74 75 + if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE key_did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { 76 return &reservedKey, nil 77 } 78 79 + if err := s.db.Raw(ctx, "SELECT * FROM reserved_keys WHERE did = ?", nil, keyDidOrDid).Scan(&reservedKey).Error; err == nil && reservedKey.KeyDid != "" { 80 return &reservedKey, nil 81 } 82 83 return nil, nil 84 } 85 86 + func (s *Server) deleteReservedKey(ctx context.Context, keyDid string, did *string) error { 87 + if err := s.db.Exec(ctx, "DELETE FROM reserved_keys WHERE key_did = ?", nil, keyDid).Error; err != nil { 88 return err 89 } 90 91 if did != nil && *did != "" { 92 + if err := s.db.Exec(ctx, "DELETE FROM reserved_keys WHERE did = ?", nil, *did).Error; err != nil { 93 return err 94 } 95 }
+3 -1
server/handle_server_reset_password.go
··· 16 } 17 18 func (s *Server) handleServerResetPassword(e echo.Context) error { 19 urepo := e.Get("repo").(*models.RepoActor) 20 21 var req ComAtprotoServerResetPasswordRequest ··· 46 return helpers.ServerError(e, nil) 47 } 48 49 - if err := s.db.Exec("UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil { 50 s.logger.Error("error updating repo", "error", err) 51 return helpers.ServerError(e, nil) 52 }
··· 16 } 17 18 func (s *Server) handleServerResetPassword(e echo.Context) error { 19 + ctx := e.Request().Context() 20 + 21 urepo := e.Get("repo").(*models.RepoActor) 22 23 var req ComAtprotoServerResetPasswordRequest ··· 48 return helpers.ServerError(e, nil) 49 } 50 51 + if err := s.db.Exec(ctx, "UPDATE repos SET password_reset_code = NULL, password_reset_code_expires_at = NULL, password = ? WHERE did = ?", nil, hash, urepo.Repo.Did).Error; err != nil { 52 s.logger.Error("error updating repo", "error", err) 53 return helpers.ServerError(e, nil) 54 }
+3 -1
server/handle_server_update_email.go
··· 15 } 16 17 func (s *Server) handleServerUpdateEmail(e echo.Context) error { 18 urepo := e.Get("repo").(*models.RepoActor) 19 20 var req ComAtprotoServerUpdateEmailRequest ··· 39 return helpers.ExpiredTokenError(e) 40 } 41 42 - if err := s.db.Exec("UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil { 43 s.logger.Error("error updating repo", "error", err) 44 return helpers.ServerError(e, nil) 45 }
··· 15 } 16 17 func (s *Server) handleServerUpdateEmail(e echo.Context) error { 18 + ctx := e.Request().Context() 19 + 20 urepo := e.Get("repo").(*models.RepoActor) 21 22 var req ComAtprotoServerUpdateEmailRequest ··· 41 return helpers.ExpiredTokenError(e) 42 } 43 44 + if err := s.db.Exec(ctx, "UPDATE repos SET email_update_code = NULL, email_update_code_expires_at = NULL, email_confirmed_at = NULL, email = ? WHERE did = ?", nil, req.Email, urepo.Repo.Did).Error; err != nil { 45 s.logger.Error("error updating repo", "error", err) 46 return helpers.ServerError(e, nil) 47 }
+5 -3
server/handle_sync_get_blob.go
··· 17 ) 18 19 func (s *Server) handleSyncGetBlob(e echo.Context) error { 20 did := e.QueryParam("did") 21 if did == "" { 22 return helpers.InputError(e, nil) ··· 32 return helpers.InputError(e, nil) 33 } 34 35 - urepo, err := s.getRepoActorByDid(did) 36 if err != nil { 37 s.logger.Error("could not find user for requested blob", "error", err) 38 return helpers.InputError(e, nil) ··· 46 } 47 48 var blob models.Blob 49 - if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil { 50 s.logger.Error("error looking up blob", "error", err) 51 return helpers.ServerError(e, nil) 52 } ··· 55 56 if blob.Storage == "sqlite" { 57 var parts []models.BlobPart 58 - if err := s.db.Raw("SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil { 59 s.logger.Error("error getting blob parts", "error", err) 60 return helpers.ServerError(e, nil) 61 }
··· 17 ) 18 19 func (s *Server) handleSyncGetBlob(e echo.Context) error { 20 + ctx := e.Request().Context() 21 + 22 did := e.QueryParam("did") 23 if did == "" { 24 return helpers.InputError(e, nil) ··· 34 return helpers.InputError(e, nil) 35 } 36 37 + urepo, err := s.getRepoActorByDid(ctx, did) 38 if err != nil { 39 s.logger.Error("could not find user for requested blob", "error", err) 40 return helpers.InputError(e, nil) ··· 48 } 49 50 var blob models.Blob 51 + if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? AND cid = ?", nil, did, c.Bytes()).Scan(&blob).Error; err != nil { 52 s.logger.Error("error looking up blob", "error", err) 53 return helpers.ServerError(e, nil) 54 } ··· 57 58 if blob.Storage == "sqlite" { 59 var parts []models.BlobPart 60 + if err := s.db.Raw(ctx, "SELECT * FROM blob_parts WHERE blob_id = ? ORDER BY idx", nil, blob.ID).Scan(&parts).Error; err != nil { 61 s.logger.Error("error getting blob parts", "error", err) 62 return helpers.ServerError(e, nil) 63 }
+1 -1
server/handle_sync_get_blocks.go
··· 35 cids = append(cids, c) 36 } 37 38 - urepo, err := s.getRepoActorByDid(req.Did) 39 if err != nil { 40 return helpers.ServerError(e, nil) 41 }
··· 35 cids = append(cids, c) 36 } 37 38 + urepo, err := s.getRepoActorByDid(ctx, req.Did) 39 if err != nil { 40 return helpers.ServerError(e, nil) 41 }
+3 -1
server/handle_sync_get_latest_commit.go
··· 12 } 13 14 func (s *Server) handleSyncGetLatestCommit(e echo.Context) error { 15 did := e.QueryParam("did") 16 if did == "" { 17 return helpers.InputError(e, nil) 18 } 19 20 - urepo, err := s.getRepoActorByDid(did) 21 if err != nil { 22 return err 23 }
··· 12 } 13 14 func (s *Server) handleSyncGetLatestCommit(e echo.Context) error { 15 + ctx := e.Request().Context() 16 + 17 did := e.QueryParam("did") 18 if did == "" { 19 return helpers.InputError(e, nil) 20 } 21 22 + urepo, err := s.getRepoActorByDid(ctx, did) 23 if err != nil { 24 return err 25 }
+1 -1
server/handle_sync_get_record.go
··· 20 rkey := e.QueryParam("rkey") 21 22 var urepo models.Repo 23 - if err := s.db.Raw("SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil { 24 s.logger.Error("error getting repo", "error", err) 25 return helpers.ServerError(e, nil) 26 }
··· 20 rkey := e.QueryParam("rkey") 21 22 var urepo models.Repo 23 + if err := s.db.Raw(ctx, "SELECT * FROM repos WHERE did = ?", nil, did).Scan(&urepo).Error; err != nil { 24 s.logger.Error("error getting repo", "error", err) 25 return helpers.ServerError(e, nil) 26 }
+4 -2
server/handle_sync_get_repo.go
··· 13 ) 14 15 func (s *Server) handleSyncGetRepo(e echo.Context) error { 16 did := e.QueryParam("did") 17 if did == "" { 18 return helpers.InputError(e, nil) 19 } 20 21 - urepo, err := s.getRepoActorByDid(did) 22 if err != nil { 23 return err 24 } ··· 41 } 42 43 var blocks []models.Block 44 - if err := s.db.Raw("SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil { 45 return err 46 } 47
··· 13 ) 14 15 func (s *Server) handleSyncGetRepo(e echo.Context) error { 16 + ctx := e.Request().Context() 17 + 18 did := e.QueryParam("did") 19 if did == "" { 20 return helpers.InputError(e, nil) 21 } 22 23 + urepo, err := s.getRepoActorByDid(ctx, did) 24 if err != nil { 25 return err 26 } ··· 43 } 44 45 var blocks []models.Block 46 + if err := s.db.Raw(ctx, "SELECT * FROM blocks WHERE did = ? ORDER BY rev ASC", nil, urepo.Repo.Did).Scan(&blocks).Error; err != nil { 47 return err 48 } 49
+3 -1
server/handle_sync_get_repo_status.go
··· 14 15 // TODO: make this actually do the right thing 16 func (s *Server) handleSyncGetRepoStatus(e echo.Context) error { 17 did := e.QueryParam("did") 18 if did == "" { 19 return helpers.InputError(e, nil) 20 } 21 22 - urepo, err := s.getRepoActorByDid(did) 23 if err != nil { 24 return err 25 }
··· 14 15 // TODO: make this actually do the right thing 16 func (s *Server) handleSyncGetRepoStatus(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + 19 did := e.QueryParam("did") 20 if did == "" { 21 return helpers.InputError(e, nil) 22 } 23 24 + urepo, err := s.getRepoActorByDid(ctx, did) 25 if err != nil { 26 return err 27 }
+4 -2
server/handle_sync_list_blobs.go
··· 14 } 15 16 func (s *Server) handleSyncListBlobs(e echo.Context) error { 17 did := e.QueryParam("did") 18 if did == "" { 19 return helpers.InputError(e, nil) ··· 35 } 36 params = append(params, limit) 37 38 - urepo, err := s.getRepoActorByDid(did) 39 if err != nil { 40 s.logger.Error("could not find user for requested blobs", "error", err) 41 return helpers.InputError(e, nil) ··· 49 } 50 51 var blobs []models.Blob 52 - if err := s.db.Raw("SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil { 53 s.logger.Error("error getting records", "error", err) 54 return helpers.ServerError(e, nil) 55 }
··· 14 } 15 16 func (s *Server) handleSyncListBlobs(e echo.Context) error { 17 + ctx := e.Request().Context() 18 + 19 did := e.QueryParam("did") 20 if did == "" { 21 return helpers.InputError(e, nil) ··· 37 } 38 params = append(params, limit) 39 40 + urepo, err := s.getRepoActorByDid(ctx, did) 41 if err != nil { 42 s.logger.Error("could not find user for requested blobs", "error", err) 43 return helpers.InputError(e, nil) ··· 51 } 52 53 var blobs []models.Blob 54 + if err := s.db.Raw(ctx, "SELECT * FROM blobs WHERE did = ? "+cursorquery+" ORDER BY created_at DESC LIMIT ?", nil, params...).Scan(&blobs).Error; err != nil { 55 s.logger.Error("error getting records", "error", err) 56 return helpers.ServerError(e, nil) 57 }
+3 -1
server/handle_well_known.go
··· 67 } 68 69 func (s *Server) handleAtprotoDid(e echo.Context) error { 70 host := e.Request().Host 71 if host == "" { 72 return helpers.InputError(e, to.StringPtr("Invalid handle.")) ··· 84 return e.NoContent(404) 85 } 86 87 - actor, err := s.getActorByHandle(host) 88 if err != nil { 89 if err == gorm.ErrRecordNotFound { 90 return e.NoContent(404)
··· 67 } 68 69 func (s *Server) handleAtprotoDid(e echo.Context) error { 70 + ctx := e.Request().Context() 71 + 72 host := e.Request().Host 73 if host == "" { 74 return helpers.InputError(e, to.StringPtr("Invalid handle.")) ··· 86 return e.NoContent(404) 87 } 88 89 + actor, err := s.getActorByHandle(ctx, host) 90 if err != nil { 91 if err == gorm.ErrRecordNotFound { 92 return e.NoContent(404)
+9 -5
server/middleware.go
··· 37 38 func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 39 return func(e echo.Context) error { 40 authheader := e.Request().Header.Get("authorization") 41 if authheader == "" { 42 return e.JSON(401, map[string]string{"error": "Unauthorized"}) ··· 78 } 79 did = maybeDid 80 81 - maybeRepo, err := s.getRepoActorByDid(did) 82 if err != nil { 83 s.logger.Error("error fetching repo", "error", err) 84 return helpers.ServerError(e, nil) ··· 159 Found bool 160 } 161 var result Result 162 - if err := s.db.Raw("SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 163 if err == gorm.ErrRecordNotFound { 164 return helpers.InvalidTokenError(e) 165 } ··· 184 } 185 186 if repo == nil { 187 - maybeRepo, err := s.getRepoActorByDid(claims["sub"].(string)) 188 if err != nil { 189 s.logger.Error("error fetching repo", "error", err) 190 return helpers.ServerError(e, nil) ··· 207 208 func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 209 return func(e echo.Context) error { 210 authheader := e.Request().Header.Get("authorization") 211 if authheader == "" { 212 return e.JSON(401, map[string]string{"error": "Unauthorized"}) ··· 243 } 244 245 var oauthToken provider.OauthToken 246 - if err := s.db.Raw("SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 247 s.logger.Error("error finding access token in db", "error", err) 248 return helpers.InputError(e, nil) 249 } ··· 266 }) 267 } 268 269 - repo, err := s.getRepoActorByDid(oauthToken.Sub) 270 if err != nil { 271 s.logger.Error("could not find actor in db", "error", err) 272 return helpers.ServerError(e, nil)
··· 37 38 func (s *Server) handleLegacySessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 39 return func(e echo.Context) error { 40 + ctx := e.Request().Context() 41 + 42 authheader := e.Request().Header.Get("authorization") 43 if authheader == "" { 44 return e.JSON(401, map[string]string{"error": "Unauthorized"}) ··· 80 } 81 did = maybeDid 82 83 + maybeRepo, err := s.getRepoActorByDid(ctx, did) 84 if err != nil { 85 s.logger.Error("error fetching repo", "error", err) 86 return helpers.ServerError(e, nil) ··· 161 Found bool 162 } 163 var result Result 164 + if err := s.db.Raw(ctx, "SELECT EXISTS(SELECT 1 FROM "+table+" WHERE token = ?) AS found", nil, tokenstr).Scan(&result).Error; err != nil { 165 if err == gorm.ErrRecordNotFound { 166 return helpers.InvalidTokenError(e) 167 } ··· 186 } 187 188 if repo == nil { 189 + maybeRepo, err := s.getRepoActorByDid(ctx, claims["sub"].(string)) 190 if err != nil { 191 s.logger.Error("error fetching repo", "error", err) 192 return helpers.ServerError(e, nil) ··· 209 210 func (s *Server) handleOauthSessionMiddleware(next echo.HandlerFunc) echo.HandlerFunc { 211 return func(e echo.Context) error { 212 + ctx := e.Request().Context() 213 + 214 authheader := e.Request().Header.Get("authorization") 215 if authheader == "" { 216 return e.JSON(401, map[string]string{"error": "Unauthorized"}) ··· 247 } 248 249 var oauthToken provider.OauthToken 250 + if err := s.db.Raw(ctx, "SELECT * FROM oauth_tokens WHERE token = ?", nil, accessToken).Scan(&oauthToken).Error; err != nil { 251 s.logger.Error("error finding access token in db", "error", err) 252 return helpers.InputError(e, nil) 253 } ··· 270 }) 271 } 272 273 + repo, err := s.getRepoActorByDid(ctx, oauthToken.Sub) 274 if err != nil { 275 s.logger.Error("could not find actor in db", "error", err) 276 return helpers.ServerError(e, nil)
+11 -11
server/repo.go
··· 181 case OpTypeDelete: 182 // try to find the old record in the database 183 var old models.Record 184 - if err := rm.db.Raw("SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { 185 return nil, err 186 } 187 ··· 323 var cids []cid.Cid 324 // whenever there is cid present, we know it's a create (dumb) 325 if entry.Cid != "" { 326 - if err := rm.s.db.Create(&entry, []clause.Expression{clause.OnConflict{ 327 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}}, 328 UpdateAll: true, 329 }}).Error; err != nil { ··· 331 } 332 333 // increment the given blob refs, yay 334 - cids, err = rm.incrementBlobRefs(urepo, entry.Value) 335 if err != nil { 336 return nil, err 337 } ··· 339 // as i noted above this is dumb. but we delete whenever the cid is nil. it works solely becaue the pkey 340 // is did + collection + rkey. i still really want to separate that out, or use a different type to make 341 // this less confusing/easy to read. alas, its 2 am and yea no 342 - if err := rm.s.db.Delete(&entry, nil).Error; err != nil { 343 return nil, err 344 } 345 346 // TODO: 347 - cids, err = rm.decrementBlobRefs(urepo, entry.Value) 348 if err != nil { 349 return nil, err 350 } ··· 411 return c, bs.GetReadLog(), nil 412 } 413 414 - func (rm *RepoMan) incrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 415 cids, err := getBlobCidsFromCbor(cbor) 416 if err != nil { 417 return nil, err 418 } 419 420 for _, c := range cids { 421 - if err := rm.db.Exec("UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil { 422 return nil, err 423 } 424 } ··· 426 return cids, nil 427 } 428 429 - func (rm *RepoMan) decrementBlobRefs(urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 430 cids, err := getBlobCidsFromCbor(cbor) 431 if err != nil { 432 return nil, err ··· 437 ID uint 438 Count int 439 } 440 - if err := rm.db.Raw("UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { 441 return nil, err 442 } 443 444 // TODO: this does _not_ handle deletions of blobs that are on s3 storage!!!! we need to get the blob, see what 445 // storage it is in, and clean up s3!!!! 446 if res.Count == 0 { 447 - if err := rm.db.Exec("DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil { 448 return nil, err 449 } 450 - if err := rm.db.Exec("DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil { 451 return nil, err 452 } 453 }
··· 181 case OpTypeDelete: 182 // try to find the old record in the database 183 var old models.Record 184 + if err := rm.db.Raw(ctx, "SELECT value FROM records WHERE did = ? AND nsid = ? AND rkey = ?", nil, urepo.Did, op.Collection, op.Rkey).Scan(&old).Error; err != nil { 185 return nil, err 186 } 187 ··· 323 var cids []cid.Cid 324 // whenever there is cid present, we know it's a create (dumb) 325 if entry.Cid != "" { 326 + if err := rm.s.db.Create(ctx, &entry, []clause.Expression{clause.OnConflict{ 327 Columns: []clause.Column{{Name: "did"}, {Name: "nsid"}, {Name: "rkey"}}, 328 UpdateAll: true, 329 }}).Error; err != nil { ··· 331 } 332 333 // increment the given blob refs, yay 334 + cids, err = rm.incrementBlobRefs(ctx, urepo, entry.Value) 335 if err != nil { 336 return nil, err 337 } ··· 339 // as i noted above this is dumb. but we delete whenever the cid is nil. it works solely becaue the pkey 340 // is did + collection + rkey. i still really want to separate that out, or use a different type to make 341 // this less confusing/easy to read. alas, its 2 am and yea no 342 + if err := rm.s.db.Delete(ctx, &entry, nil).Error; err != nil { 343 return nil, err 344 } 345 346 // TODO: 347 + cids, err = rm.decrementBlobRefs(ctx, urepo, entry.Value) 348 if err != nil { 349 return nil, err 350 } ··· 411 return c, bs.GetReadLog(), nil 412 } 413 414 + func (rm *RepoMan) incrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 415 cids, err := getBlobCidsFromCbor(cbor) 416 if err != nil { 417 return nil, err 418 } 419 420 for _, c := range cids { 421 + if err := rm.db.Exec(ctx, "UPDATE blobs SET ref_count = ref_count + 1 WHERE did = ? AND cid = ?", nil, urepo.Did, c.Bytes()).Error; err != nil { 422 return nil, err 423 } 424 } ··· 426 return cids, nil 427 } 428 429 + func (rm *RepoMan) decrementBlobRefs(ctx context.Context, urepo models.Repo, cbor []byte) ([]cid.Cid, error) { 430 cids, err := getBlobCidsFromCbor(cbor) 431 if err != nil { 432 return nil, err ··· 437 ID uint 438 Count int 439 } 440 + if err := rm.db.Raw(ctx, "UPDATE blobs SET ref_count = ref_count - 1 WHERE did = ? AND cid = ? RETURNING id, ref_count", nil, urepo.Did, c.Bytes()).Scan(&res).Error; err != nil { 441 return nil, err 442 } 443 444 // TODO: this does _not_ handle deletions of blobs that are on s3 storage!!!! we need to get the blob, see what 445 // storage it is in, and clean up s3!!!! 446 if res.Count == 0 { 447 + if err := rm.db.Exec(ctx, "DELETE FROM blobs WHERE id = ?", nil, res.ID).Error; err != nil { 448 return nil, err 449 } 450 + if err := rm.db.Exec(ctx, "DELETE FROM blob_parts WHERE blob_id = ?", nil, res.ID).Error; err != nil { 451 return nil, err 452 } 453 }
+1 -1
server/server.go
··· 729 } 730 731 func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 732 - if err := s.db.Exec("UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 733 return err 734 } 735
··· 729 } 730 731 func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 732 + if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 733 return err 734 } 735
+4 -3
server/session.go
··· 1 package server 2 3 import ( 4 "time" 5 6 "github.com/golang-jwt/jwt/v4" ··· 13 RefreshToken string 14 } 15 16 - func (s *Server) createSession(repo *models.Repo) (*Session, error) { 17 now := time.Now() 18 accexp := now.Add(3 * time.Hour) 19 refexp := now.Add(7 * 24 * time.Hour) ··· 49 return nil, err 50 } 51 52 - if err := s.db.Create(&models.Token{ 53 Token: accessString, 54 Did: repo.Did, 55 RefreshToken: refreshString, ··· 59 return nil, err 60 } 61 62 - if err := s.db.Create(&models.RefreshToken{ 63 Token: refreshString, 64 Did: repo.Did, 65 CreatedAt: now,
··· 1 package server 2 3 import ( 4 + "context" 5 "time" 6 7 "github.com/golang-jwt/jwt/v4" ··· 14 RefreshToken string 15 } 16 17 + func (s *Server) createSession(ctx context.Context, repo *models.Repo) (*Session, error) { 18 now := time.Now() 19 accexp := now.Add(3 * time.Hour) 20 refexp := now.Add(7 * 24 * time.Hour) ··· 50 return nil, err 51 } 52 53 + if err := s.db.Create(ctx, &models.Token{ 54 Token: accessString, 55 Did: repo.Did, 56 RefreshToken: refreshString, ··· 60 return nil, err 61 } 62 63 + if err := s.db.Create(ctx, &models.RefreshToken{ 64 Token: refreshString, 65 Did: repo.Did, 66 CreatedAt: now,
+3 -3
sqlite_blockstore/sqlite_blockstore.go
··· 45 return maybeBlock, nil 46 } 47 48 - if err := bs.db.Raw("SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil { 49 return nil, err 50 } 51 ··· 71 Value: block.RawData(), 72 } 73 74 - if err := bs.db.Create(&b, []clause.Expression{clause.OnConflict{ 75 Columns: []clause.Column{{Name: "did"}, {Name: "cid"}}, 76 UpdateAll: true, 77 }}).Error; err != nil { ··· 94 } 95 96 func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error { 97 - tx := bs.db.BeginDangerously() 98 99 for _, block := range blocks { 100 bs.inserts[block.Cid()] = block
··· 45 return maybeBlock, nil 46 } 47 48 + if err := bs.db.Raw(ctx, "SELECT * FROM blocks WHERE did = ? AND cid = ?", nil, bs.did, cid.Bytes()).Scan(&block).Error; err != nil { 49 return nil, err 50 } 51 ··· 71 Value: block.RawData(), 72 } 73 74 + if err := bs.db.Create(ctx, &b, []clause.Expression{clause.OnConflict{ 75 Columns: []clause.Column{{Name: "did"}, {Name: "cid"}}, 76 UpdateAll: true, 77 }}).Error; err != nil { ··· 94 } 95 96 func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error { 97 + tx := bs.db.BeginDangerously(ctx) 98 99 for _, block := range blocks { 100 bs.inserts[block.Cid()] = block