An atproto PDS written in Go

Compare changes

Choose any two refs to compare.

+3 -21
internal/db/db.go
··· 2 2 3 3 import ( 4 4 "context" 5 - "sync" 6 5 7 6 "gorm.io/gorm" 8 7 "gorm.io/gorm/clause" ··· 10 9 11 10 type DB struct { 12 11 cli *gorm.DB 13 - mu sync.Mutex 14 12 } 15 13 16 14 func NewDB(cli *gorm.DB) *DB { 17 15 return &DB{ 18 16 cli: cli, 19 - mu: sync.Mutex{}, 20 17 } 21 18 } 22 19 23 20 func (db *DB) Create(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { 24 - db.mu.Lock() 25 - defer db.mu.Unlock() 26 21 return db.cli.WithContext(ctx).Clauses(clauses...).Create(value) 27 22 } 28 23 29 24 func (db *DB) Save(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { 30 - db.mu.Lock() 31 - defer db.mu.Unlock() 32 25 return db.cli.WithContext(ctx).Clauses(clauses...).Save(value) 33 26 } 34 27 35 28 func (db *DB) Exec(ctx context.Context, sql string, clauses []clause.Expression, values ...any) *gorm.DB { 36 - db.mu.Lock() 37 - defer db.mu.Unlock() 38 29 return db.cli.WithContext(ctx).Clauses(clauses...).Exec(sql, values...) 39 30 } 40 31 ··· 47 38 } 48 39 49 40 func (db *DB) Delete(ctx context.Context, value any, clauses []clause.Expression) *gorm.DB { 50 - db.mu.Lock() 51 - defer db.mu.Unlock() 52 41 return db.cli.WithContext(ctx).Clauses(clauses...).Delete(value) 53 42 } 54 43 ··· 56 45 return db.cli.WithContext(ctx).First(dest, conds...) 57 46 } 58 47 59 - // TODO: this isn't actually good. we can commit even if the db is locked here. this is probably okay for the time being, but need to figure 60 - // out a better solution. right now we only do this whenever we're importing a repo though so i'm mostly not worried, but it's still bad. 61 - // e.g. when we do apply writes we should also be using a transcation but we don't right now 62 - func (db *DB) BeginDangerously(ctx context.Context) *gorm.DB { 48 + func (db *DB) Begin(ctx context.Context) *gorm.DB { 63 49 return db.cli.WithContext(ctx).Begin() 64 50 } 65 51 66 - func (db *DB) Lock() { 67 - db.mu.Lock() 68 - } 69 - 70 - func (db *DB) Unlock() { 71 - db.mu.Unlock() 52 + func (db *DB) Client() *gorm.DB { 53 + return db.cli 72 54 }
+4 -4
oauth/dpop/manager.go
··· 75 75 } 76 76 77 77 proof := extractProof(headers) 78 - 79 78 if proof == "" { 80 79 return nil, nil 81 80 } ··· 197 196 198 197 nonce, _ := claims["nonce"].(string) 199 198 if nonce == "" { 200 - // WARN: this _must_ be `use_dpop_nonce` for clients know they should make another request 199 + // reference impl checks if self.nonce is not null before returning an error, but we always have a 200 + // nonce so we do not bother checking 201 201 return nil, ErrUseDpopNonce 202 202 } 203 203 204 204 if nonce != "" && !dm.nonce.Check(nonce) { 205 - // WARN: this _must_ be `use_dpop_nonce` so that clients will fetch a new nonce 205 + // dpop nonce mismatch 206 206 return nil, ErrUseDpopNonce 207 207 } 208 208 ··· 237 237 } 238 238 239 239 func extractProof(headers http.Header) string { 240 - dpopHeaders := headers["Dpop"] 240 + dpopHeaders := headers.Values("dpop") 241 241 switch len(dpopHeaders) { 242 242 case 0: 243 243 return ""
+3 -3
oauth/provider/client_auth.go
··· 19 19 } 20 20 21 21 type AuthenticateClientRequestBase struct { 22 - ClientID string `form:"client_id" json:"client_id" validate:"required"` 23 - ClientAssertionType *string `form:"client_assertion_type" json:"client_assertion_type,omitempty"` 24 - ClientAssertion *string `form:"client_assertion" json:"client_assertion,omitempty"` 22 + ClientID string `form:"client_id" json:"client_id" query:"client_id" validate:"required"` 23 + ClientAssertionType *string `form:"client_assertion_type" json:"client_assertion_type,omitempty" query:"client_assertion_type"` 24 + ClientAssertion *string `form:"client_assertion" json:"client_assertion,omitempty" query:"client_assertion"` 25 25 } 26 26 27 27 func (p *Provider) AuthenticateClient(ctx context.Context, req AuthenticateClientRequestBase, proof *dpop.Proof, opts *AuthenticateClientOptions) (*client.Client, *ClientAuth, error) {
+9 -8
oauth/provider/models.go
··· 32 32 33 33 type ParRequest struct { 34 34 AuthenticateClientRequestBase 35 - ResponseType string `form:"response_type" json:"response_type" validate:"required"` 36 - CodeChallenge *string `form:"code_challenge" json:"code_challenge" validate:"required"` 37 - CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" validate:"required"` 38 - State string `form:"state" json:"state" validate:"required"` 39 - RedirectURI string `form:"redirect_uri" json:"redirect_uri" validate:"required"` 40 - Scope string `form:"scope" json:"scope" validate:"required"` 41 - LoginHint *string `form:"login_hint" json:"login_hint,omitempty"` 42 - DpopJkt *string `form:"dpop_jkt" json:"dpop_jkt,omitempty"` 35 + ResponseType string `form:"response_type" json:"response_type" query:"response_type" validate:"required"` 36 + CodeChallenge *string `form:"code_challenge" json:"code_challenge" query:"code_challenge" validate:"required"` 37 + CodeChallengeMethod string `form:"code_challenge_method" json:"code_challenge_method" query:"code_challenge_method" validate:"required"` 38 + State string `form:"state" json:"state" query:"state" validate:"required"` 39 + RedirectURI string `form:"redirect_uri" json:"redirect_uri" query:"redirect_uri" validate:"required"` 40 + Scope string `form:"scope" json:"scope" query:"scope" validate:"required"` 41 + LoginHint *string `form:"login_hint" query:"login_hint" json:"login_hint,omitempty"` 42 + DpopJkt *string `form:"dpop_jkt" query:"dpop_jkt" json:"dpop_jkt,omitempty"` 43 + ResponseMode *string `form:"response_mode" json:"response_mode,omitempty" query:"response_mode"` 43 44 } 44 45 45 46 func (opr *ParRequest) Scan(value any) error {
+1 -1
server/handle_import_repo.go
··· 66 66 return helpers.ServerError(e, nil) 67 67 } 68 68 69 - tx := s.db.BeginDangerously(ctx) 69 + tx := s.db.Begin(ctx) 70 70 71 71 clock := syntax.NewTIDClock(0) 72 72
+95 -19
server/handle_oauth_authorize.go
··· 1 1 package server 2 2 3 3 import ( 4 + "fmt" 4 5 "net/url" 5 6 "strings" 6 7 "time" ··· 8 9 "github.com/Azure/go-autorest/autorest/to" 9 10 "github.com/haileyok/cocoon/internal/helpers" 10 11 "github.com/haileyok/cocoon/oauth" 12 + "github.com/haileyok/cocoon/oauth/constants" 11 13 "github.com/haileyok/cocoon/oauth/provider" 12 14 "github.com/labstack/echo/v4" 13 15 ) 14 16 17 + type HandleOauthAuthorizeGetInput struct { 18 + RequestUri string `query:"request_uri"` 19 + } 20 + 15 21 func (s *Server) handleOauthAuthorizeGet(e echo.Context) error { 16 22 ctx := e.Request().Context() 17 23 18 - reqUri := e.QueryParam("request_uri") 19 - if reqUri == "" { 20 - // render page for logged out dev 21 - if s.config.Version == "dev" { 22 - return e.Render(200, "authorize.html", map[string]any{ 23 - "Scopes": []string{"atproto", "transition:generic"}, 24 - "AppName": "DEV MODE AUTHORIZATION PAGE", 25 - "Handle": "paula.cocoon.social", 26 - "RequestUri": "", 27 - }) 24 + logger := s.logger.With("name", "handleOauthAuthorizeGet") 25 + 26 + var input HandleOauthAuthorizeGetInput 27 + if err := e.Bind(&input); err != nil { 28 + logger.Error("error binding request", "err", err) 29 + return fmt.Errorf("error binding request") 30 + } 31 + 32 + var reqId string 33 + if input.RequestUri != "" { 34 + id, err := oauth.DecodeRequestUri(input.RequestUri) 35 + if err != nil { 36 + logger.Error("no request uri found in input", "url", e.Request().URL.String()) 37 + return helpers.InputError(e, to.StringPtr("no request uri")) 38 + } 39 + reqId = id 40 + } else { 41 + var parRequest provider.ParRequest 42 + if err := e.Bind(&parRequest); err != nil { 43 + s.logger.Error("error binding for standard auth request", "error", err) 44 + return helpers.InputError(e, to.StringPtr("InvalidRequest")) 45 + } 46 + 47 + if err := e.Validate(parRequest); err != nil { 48 + // render page for logged out dev 49 + if s.config.Version == "dev" && parRequest.ClientID == "" { 50 + return e.Render(200, "authorize.html", map[string]any{ 51 + "Scopes": []string{"atproto", "transition:generic"}, 52 + "AppName": "DEV MODE AUTHORIZATION PAGE", 53 + "Handle": "paula.cocoon.social", 54 + "RequestUri": "", 55 + }) 56 + } 57 + return helpers.InputError(e, to.StringPtr("no request uri and invalid parameters")) 28 58 } 29 - return helpers.InputError(e, to.StringPtr("no request uri")) 59 + 60 + client, clientAuth, err := s.oauthProvider.AuthenticateClient(ctx, parRequest.AuthenticateClientRequestBase, nil, &provider.AuthenticateClientOptions{ 61 + AllowMissingDpopProof: true, 62 + }) 63 + if err != nil { 64 + s.logger.Error("error authenticating client in standard request", "client_id", parRequest.ClientID, "error", err) 65 + return helpers.ServerError(e, to.StringPtr(err.Error())) 66 + } 67 + 68 + if parRequest.DpopJkt == nil { 69 + if client.Metadata.DpopBoundAccessTokens { 70 + } 71 + } else { 72 + if !client.Metadata.DpopBoundAccessTokens { 73 + msg := "dpop bound access tokens are not enabled for this client" 74 + return helpers.InputError(e, &msg) 75 + } 76 + } 77 + 78 + eat := time.Now().Add(constants.ParExpiresIn) 79 + id := oauth.GenerateRequestId() 80 + 81 + authRequest := &provider.OauthAuthorizationRequest{ 82 + RequestId: id, 83 + ClientId: client.Metadata.ClientID, 84 + ClientAuth: *clientAuth, 85 + Parameters: parRequest, 86 + ExpiresAt: eat, 87 + } 88 + 89 + if err := s.db.Create(ctx, authRequest, nil).Error; err != nil { 90 + s.logger.Error("error creating auth request in db", "error", err) 91 + return helpers.ServerError(e, nil) 92 + } 93 + 94 + input.RequestUri = oauth.EncodeRequestUri(id) 95 + reqId = id 96 + 30 97 } 31 98 32 99 repo, _, err := s.getSessionRepoOrErr(e) 33 100 if err != nil { 34 101 return e.Redirect(303, "/account/signin?"+e.QueryParams().Encode()) 35 - } 36 - 37 - reqId, err := oauth.DecodeRequestUri(reqUri) 38 - if err != nil { 39 - return helpers.InputError(e, to.StringPtr(err.Error())) 40 102 } 41 103 42 104 var req provider.OauthAuthorizationRequest ··· 60 122 data := map[string]any{ 61 123 "Scopes": scopes, 62 124 "AppName": appName, 63 - "RequestUri": reqUri, 125 + "RequestUri": input.RequestUri, 64 126 "QueryParams": e.QueryParams().Encode(), 65 127 "Handle": repo.Actor.Handle, 66 128 } ··· 129 191 q.Set("code", code) 130 192 131 193 hashOrQuestion := "?" 132 - if authReq.ClientAuth.Method != "private_key_jwt" { 133 - hashOrQuestion = "#" 194 + if authReq.Parameters.ResponseMode != nil { 195 + switch *authReq.Parameters.ResponseMode { 196 + case "fragment": 197 + hashOrQuestion = "#" 198 + case "query": 199 + // do nothing 200 + break 201 + default: 202 + if authReq.Parameters.ResponseType != "code" { 203 + hashOrQuestion = "#" 204 + } 205 + } 206 + } else { 207 + if authReq.Parameters.ResponseType != "code" { 208 + hashOrQuestion = "#" 209 + } 134 210 } 135 211 136 212 return e.Redirect(303, authReq.Parameters.RedirectURI+hashOrQuestion+q.Encode())
+1
server/handle_oauth_par.go
··· 42 42 e.Response().Header().Set("DPoP-Nonce", nonce) 43 43 e.Response().Header().Add("access-control-expose-headers", "DPoP-Nonce") 44 44 } 45 + logger.Error("nonce error: use_dpop_nonce", "headers", e.Request().Header) 45 46 return e.JSON(400, map[string]string{ 46 47 "error": "use_dpop_nonce", 47 48 })
+1 -1
server/handle_server_delete_account.go
··· 69 69 }) 70 70 } 71 71 72 - tx := s.db.BeginDangerously(ctx) 72 + tx := s.db.Begin(ctx) 73 73 if tx.Error != nil { 74 74 logger.Error("error starting transaction", "error", tx.Error) 75 75 return helpers.ServerError(e, nil)
+40 -50
server/server.go
··· 322 322 if err != nil { 323 323 return nil, fmt.Errorf("failed to open sqlite database: %w", err) 324 324 } 325 + gdb.Exec("PRAGMA journal_mode=WAL") 326 + gdb.Exec("PRAGMA synchronous=NORMAL") 327 + 325 328 logger.Info("connected to SQLite database", "path", args.DbName) 326 329 } 327 330 dbw := db.NewDB(gdb) ··· 625 628 626 629 logger.Info("beginning backup to s3...") 627 630 628 - var buf bytes.Buffer 629 - if err := func() error { 630 - logger.Info("reading database bytes...") 631 - s.db.Lock() 632 - defer s.db.Unlock() 633 - 634 - sf, err := os.Open(s.dbName) 635 - if err != nil { 636 - return fmt.Errorf("error opening database for backup: %w", err) 637 - } 638 - defer sf.Close() 631 + tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano)) 632 + defer os.Remove(tmpFile) 639 633 640 - if _, err := io.Copy(&buf, sf); err != nil { 641 - return fmt.Errorf("error reading bytes of backup db: %w", err) 642 - } 634 + if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil { 635 + logger.Error("error creating tmp backup file", "err", err) 636 + return 637 + } 643 638 644 - return nil 645 - }(); err != nil { 646 - logger.Error("error backing up database", "error", err) 639 + backupData, err := os.ReadFile(tmpFile) 640 + if err != nil { 641 + logger.Error("error reading tmp backup file", "err", err) 647 642 return 648 643 } 649 644 650 - if err := func() error { 651 - logger.Info("sending to s3...") 645 + logger.Info("sending to s3...") 652 646 653 - currTime := time.Now().Format("2006-01-02_15-04-05") 654 - key := "cocoon-backup-" + currTime + ".db" 647 + currTime := time.Now().Format("2006-01-02_15-04-05") 648 + key := "cocoon-backup-" + currTime + ".db" 655 649 656 - config := &aws.Config{ 657 - Region: aws.String(s.s3Config.Region), 658 - Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 659 - } 650 + config := &aws.Config{ 651 + Region: aws.String(s.s3Config.Region), 652 + Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 653 + } 660 654 661 - if s.s3Config.Endpoint != "" { 662 - config.Endpoint = aws.String(s.s3Config.Endpoint) 663 - config.S3ForcePathStyle = aws.Bool(true) 664 - } 655 + if s.s3Config.Endpoint != "" { 656 + config.Endpoint = aws.String(s.s3Config.Endpoint) 657 + config.S3ForcePathStyle = aws.Bool(true) 658 + } 665 659 666 - sess, err := session.NewSession(config) 667 - if err != nil { 668 - return err 669 - } 660 + sess, err := session.NewSession(config) 661 + if err != nil { 662 + logger.Error("error creating s3 session", "err", err) 663 + return 664 + } 670 665 671 - svc := s3.New(sess) 666 + svc := s3.New(sess) 672 667 673 - if _, err := svc.PutObject(&s3.PutObjectInput{ 674 - Bucket: aws.String(s.s3Config.Bucket), 675 - Key: aws.String(key), 676 - Body: bytes.NewReader(buf.Bytes()), 677 - }); err != nil { 678 - return fmt.Errorf("error uploading file to s3: %w", err) 679 - } 680 - 681 - logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds()) 682 - 683 - return nil 684 - }(); err != nil { 685 - logger.Error("error uploading database backup", "error", err) 668 + if _, err := svc.PutObject(&s3.PutObjectInput{ 669 + Bucket: aws.String(s.s3Config.Bucket), 670 + Key: aws.String(key), 671 + Body: bytes.NewReader(backupData), 672 + }); err != nil { 673 + logger.Error("error uploading file to s3", "err", err) 686 674 return 687 675 } 688 676 689 - os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644) 677 + logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds()) 678 + 679 + os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644) 690 680 } 691 681 692 682 func (s *Server) backupRoutine() { ··· 721 711 if err != nil { 722 712 shouldBackupNow = true 723 713 } else { 724 - lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr)) 714 + lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr)) 725 715 if err != nil { 726 716 shouldBackupNow = true 727 - } else if time.Now().Sub(lastBackup).Seconds() > 3600 { 717 + } else if time.Since(lastBackup).Seconds() > 3600 { 728 718 shouldBackupNow = true 729 719 } 730 720 }
+1 -1
sqlite_blockstore/sqlite_blockstore.go
··· 94 94 } 95 95 96 96 func (bs *SqliteBlockstore) PutMany(ctx context.Context, blocks []blocks.Block) error { 97 - tx := bs.db.BeginDangerously(ctx) 97 + tx := bs.db.Begin(ctx) 98 98 99 99 for _, block := range blocks { 100 100 bs.inserts[block.Cid()] = block