forked from hailey.at/cocoon
An atproto PDS written in Go
at main 23 kB view raw
1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "embed" 8 "errors" 9 "fmt" 10 "io" 11 "log/slog" 12 "net/http" 13 "net/smtp" 14 "os" 15 "path/filepath" 16 "sync" 17 "text/template" 18 "time" 19 20 "github.com/aws/aws-sdk-go/aws" 21 "github.com/aws/aws-sdk-go/aws/credentials" 22 "github.com/aws/aws-sdk-go/aws/session" 23 "github.com/aws/aws-sdk-go/service/s3" 24 "github.com/bluesky-social/indigo/api/atproto" 25 "github.com/bluesky-social/indigo/atproto/syntax" 26 "github.com/bluesky-social/indigo/events" 27 "github.com/bluesky-social/indigo/util" 28 "github.com/bluesky-social/indigo/xrpc" 29 "github.com/domodwyer/mailyak/v3" 30 "github.com/go-playground/validator" 31 "github.com/gorilla/sessions" 32 "github.com/haileyok/cocoon/identity" 33 "github.com/haileyok/cocoon/internal/db" 34 "github.com/haileyok/cocoon/internal/helpers" 35 "github.com/haileyok/cocoon/models" 36 "github.com/haileyok/cocoon/oauth/client" 37 "github.com/haileyok/cocoon/oauth/constants" 38 "github.com/haileyok/cocoon/oauth/dpop" 39 "github.com/haileyok/cocoon/oauth/provider" 40 "github.com/haileyok/cocoon/plc" 41 "github.com/ipfs/go-cid" 42 "github.com/labstack/echo-contrib/echoprometheus" 43 echo_session "github.com/labstack/echo-contrib/session" 44 "github.com/labstack/echo/v4" 45 "github.com/labstack/echo/v4/middleware" 46 slogecho "github.com/samber/slog-echo" 47 "gorm.io/driver/postgres" 48 "gorm.io/driver/sqlite" 49 "gorm.io/gorm" 50) 51 52const ( 53 AccountSessionMaxAge = 30 * 24 * time.Hour // one week 54) 55 56type S3Config struct { 57 BackupsEnabled bool 58 BlobstoreEnabled bool 59 Endpoint string 60 Region string 61 Bucket string 62 AccessKey string 63 SecretKey string 64 CDNUrl string 65} 66 67type Server struct { 68 http *http.Client 69 httpd *http.Server 70 mail *mailyak.MailYak 71 mailLk *sync.Mutex 72 echo *echo.Echo 73 db *db.DB 74 plcClient *plc.Client 75 logger *slog.Logger 76 config *config 77 privateKey *ecdsa.PrivateKey 78 repoman *RepoMan 79 oauthProvider *provider.Provider 80 evtman *events.EventManager 81 passport *identity.Passport 82 fallbackProxy string 83 84 lastRequestCrawl time.Time 85 requestCrawlMu sync.Mutex 86 87 dbName string 88 dbType string 89 s3Config *S3Config 90} 91 92type Args struct { 93 Logger *slog.Logger 94 95 Addr string 96 DbName string 97 DbType string 98 DatabaseURL string 99 Version string 100 Did string 101 Hostname string 102 RotationKeyPath string 103 JwkPath string 104 ContactEmail string 105 Relays []string 106 AdminPassword string 107 RequireInvite bool 108 109 SmtpUser string 110 SmtpPass string 111 SmtpHost string 112 SmtpPort string 113 SmtpEmail string 114 SmtpName string 115 116 S3Config *S3Config 117 118 SessionSecret string 119 120 BlockstoreVariant BlockstoreVariant 121 FallbackProxy string 122} 123 124type config struct { 125 Version string 126 Did string 127 Hostname string 128 ContactEmail string 129 EnforcePeering bool 130 Relays []string 131 AdminPassword string 132 RequireInvite bool 133 SmtpEmail string 134 SmtpName string 135 BlockstoreVariant BlockstoreVariant 136 FallbackProxy string 137} 138 139type CustomValidator struct { 140 validator *validator.Validate 141} 142 143type ValidationError struct { 144 error 145 Field string 146 Tag string 147} 148 149func (cv *CustomValidator) Validate(i any) error { 150 if err := cv.validator.Struct(i); err != nil { 151 var validateErrors validator.ValidationErrors 152 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 153 first := validateErrors[0] 154 return ValidationError{ 155 error: err, 156 Field: first.Field(), 157 Tag: first.Tag(), 158 } 159 } 160 161 return err 162 } 163 164 return nil 165} 166 167//go:embed templates/* 168var templateFS embed.FS 169 170//go:embed static/* 171var staticFS embed.FS 172 173type TemplateRenderer struct { 174 templates *template.Template 175 isDev bool 176 templatePath string 177} 178 179func (s *Server) loadTemplates() { 180 absPath, _ := filepath.Abs("server/templates/*.html") 181 if s.config.Version == "dev" { 182 tmpl := template.Must(template.ParseGlob(absPath)) 183 s.echo.Renderer = &TemplateRenderer{ 184 templates: tmpl, 185 isDev: true, 186 templatePath: absPath, 187 } 188 } else { 189 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 190 s.echo.Renderer = &TemplateRenderer{ 191 templates: tmpl, 192 isDev: false, 193 } 194 } 195} 196 197func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 198 if t.isDev { 199 tmpl, err := template.ParseGlob(t.templatePath) 200 if err != nil { 201 return err 202 } 203 t.templates = tmpl 204 } 205 206 if viewContext, isMap := data.(map[string]any); isMap { 207 viewContext["reverse"] = c.Echo().Reverse 208 } 209 210 return t.templates.ExecuteTemplate(w, name, data) 211} 212 213func New(args *Args) (*Server, error) { 214 if args.Logger == nil { 215 args.Logger = slog.Default() 216 } 217 218 logger := args.Logger.With("name", "New") 219 220 if args.Addr == "" { 221 return nil, fmt.Errorf("addr must be set") 222 } 223 224 if args.DbName == "" { 225 return nil, fmt.Errorf("db name must be set") 226 } 227 228 if args.Did == "" { 229 return nil, fmt.Errorf("cocoon did must be set") 230 } 231 232 if args.ContactEmail == "" { 233 return nil, fmt.Errorf("cocoon contact email is required") 234 } 235 236 if _, err := syntax.ParseDID(args.Did); err != nil { 237 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 238 } 239 240 if args.Hostname == "" { 241 return nil, fmt.Errorf("cocoon hostname must be set") 242 } 243 244 if args.AdminPassword == "" { 245 return nil, fmt.Errorf("admin password must be set") 246 } 247 248 if args.SessionSecret == "" { 249 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 250 } 251 252 e := echo.New() 253 254 e.Pre(middleware.RemoveTrailingSlash()) 255 e.Pre(slogecho.New(args.Logger.With("component", "slogecho"))) 256 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 257 e.Use(echoprometheus.NewMiddleware("cocoon")) 258 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 259 AllowOrigins: []string{"*"}, 260 AllowHeaders: []string{"*"}, 261 AllowMethods: []string{"*"}, 262 AllowCredentials: true, 263 MaxAge: 100_000_000, 264 })) 265 266 vdtor := validator.New() 267 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 268 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 269 return false 270 } 271 return true 272 }) 273 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 274 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 275 return false 276 } 277 return true 278 }) 279 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 280 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 281 return false 282 } 283 return true 284 }) 285 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 286 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 287 return false 288 } 289 return true 290 }) 291 292 e.Validator = &CustomValidator{validator: vdtor} 293 294 httpd := &http.Server{ 295 Addr: args.Addr, 296 Handler: e, 297 // shitty defaults but okay for now, needed for import repo 298 ReadTimeout: 5 * time.Minute, 299 WriteTimeout: 5 * time.Minute, 300 IdleTimeout: 5 * time.Minute, 301 } 302 303 dbType := args.DbType 304 if dbType == "" { 305 dbType = "sqlite" 306 } 307 308 var gdb *gorm.DB 309 var err error 310 switch dbType { 311 case "postgres": 312 if args.DatabaseURL == "" { 313 return nil, fmt.Errorf("database-url must be set when using postgres") 314 } 315 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{}) 316 if err != nil { 317 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 318 } 319 logger.Info("connected to PostgreSQL database") 320 default: 321 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{}) 322 if err != nil { 323 return nil, fmt.Errorf("failed to open sqlite database: %w", err) 324 } 325 logger.Info("connected to SQLite database", "path", args.DbName) 326 } 327 dbw := db.NewDB(gdb) 328 329 rkbytes, err := os.ReadFile(args.RotationKeyPath) 330 if err != nil { 331 return nil, err 332 } 333 334 h := util.RobustHTTPClient() 335 336 plcClient, err := plc.NewClient(&plc.ClientArgs{ 337 H: h, 338 Service: "https://plc.directory", 339 PdsHostname: args.Hostname, 340 RotationKey: rkbytes, 341 }) 342 if err != nil { 343 return nil, err 344 } 345 346 jwkbytes, err := os.ReadFile(args.JwkPath) 347 if err != nil { 348 return nil, err 349 } 350 351 key, err := helpers.ParseJWKFromBytes(jwkbytes) 352 if err != nil { 353 return nil, err 354 } 355 356 var pkey ecdsa.PrivateKey 357 if err := key.Raw(&pkey); err != nil { 358 return nil, err 359 } 360 361 oauthCli := &http.Client{ 362 Timeout: 10 * time.Second, 363 } 364 365 var nonceSecret []byte 366 maybeSecret, err := os.ReadFile("nonce.secret") 367 if err != nil && !os.IsNotExist(err) { 368 logger.Error("error attempting to read nonce secret", "error", err) 369 } else { 370 nonceSecret = maybeSecret 371 } 372 373 s := &Server{ 374 http: h, 375 httpd: httpd, 376 echo: e, 377 logger: args.Logger, 378 db: dbw, 379 plcClient: plcClient, 380 privateKey: &pkey, 381 config: &config{ 382 Version: args.Version, 383 Did: args.Did, 384 Hostname: args.Hostname, 385 ContactEmail: args.ContactEmail, 386 EnforcePeering: false, 387 Relays: args.Relays, 388 AdminPassword: args.AdminPassword, 389 RequireInvite: args.RequireInvite, 390 SmtpName: args.SmtpName, 391 SmtpEmail: args.SmtpEmail, 392 BlockstoreVariant: args.BlockstoreVariant, 393 FallbackProxy: args.FallbackProxy, 394 }, 395 evtman: events.NewEventManager(events.NewMemPersister()), 396 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 397 398 dbName: args.DbName, 399 dbType: dbType, 400 s3Config: args.S3Config, 401 402 oauthProvider: provider.NewProvider(provider.Args{ 403 Hostname: args.Hostname, 404 ClientManagerArgs: client.ManagerArgs{ 405 Cli: oauthCli, 406 Logger: args.Logger.With("component", "oauth-client-manager"), 407 }, 408 DpopManagerArgs: dpop.ManagerArgs{ 409 NonceSecret: nonceSecret, 410 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 411 OnNonceSecretCreated: func(newNonce []byte) { 412 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 413 logger.Error("error writing new nonce secret", "error", err) 414 } 415 }, 416 Logger: args.Logger.With("component", "dpop-manager"), 417 Hostname: args.Hostname, 418 }, 419 }), 420 } 421 422 s.loadTemplates() 423 424 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 425 426 // TODO: should validate these args 427 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 428 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.") 429 } else { 430 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 431 mail.From(s.config.SmtpEmail) 432 mail.FromName(s.config.SmtpName) 433 434 s.mail = mail 435 s.mailLk = &sync.Mutex{} 436 } 437 438 return s, nil 439} 440 441func (s *Server) addRoutes() { 442 // static 443 if s.config.Version == "dev" { 444 s.echo.Static("/static", "server/static") 445 } else { 446 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 447 } 448 449 // random stuff 450 s.echo.GET("/", s.handleRoot) 451 s.echo.GET("/xrpc/_health", s.handleHealth) 452 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 453 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid) 454 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 455 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 456 s.echo.GET("/robots.txt", s.handleRobots) 457 458 // public 459 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 460 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 461 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 462 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 463 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey) 464 465 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 466 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 467 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 468 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 469 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 470 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 471 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 472 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 473 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 474 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 475 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 476 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 477 478 // labels 479 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels) 480 481 // account 482 s.echo.GET("/account", s.handleAccount) 483 s.echo.POST("/account/revoke", s.handleAccountRevoke) 484 s.echo.GET("/account/signin", s.handleAccountSigninGet) 485 s.echo.POST("/account/signin", s.handleAccountSigninPost) 486 s.echo.GET("/account/signout", s.handleAccountSignout) 487 488 // oauth account 489 s.echo.GET("/oauth/jwks", s.handleOauthJwks) 490 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 491 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 492 493 // oauth authorization 494 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 495 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 496 497 // authed 498 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 499 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 500 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 501 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 502 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 503 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 504 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 505 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 506 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 507 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 508 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 509 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 510 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 511 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 512 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 513 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 514 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 515 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 516 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 517 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount) 518 519 // repo 520 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 521 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 522 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 523 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 524 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 525 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 526 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 527 528 // stupid silly endpoints 529 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 530 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 531 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 532 533 // admin routes 534 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 535 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 536 537 // are there any routes that we should be allowing without auth? i dont think so but idk 538 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 539 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 540} 541 542func (s *Server) Serve(ctx context.Context) error { 543 logger := s.logger.With("name", "Serve") 544 545 s.addRoutes() 546 547 logger.Info("migrating...") 548 549 s.db.AutoMigrate( 550 &models.Actor{}, 551 &models.Repo{}, 552 &models.InviteCode{}, 553 &models.Token{}, 554 &models.RefreshToken{}, 555 &models.Block{}, 556 &models.Record{}, 557 &models.Blob{}, 558 &models.BlobPart{}, 559 &models.ReservedKey{}, 560 &provider.OauthToken{}, 561 &provider.OauthAuthorizationRequest{}, 562 ) 563 564 logger.Info("starting cocoon") 565 566 go func() { 567 if err := s.httpd.ListenAndServe(); err != nil { 568 panic(err) 569 } 570 }() 571 572 go s.backupRoutine() 573 574 go func() { 575 if err := s.requestCrawl(ctx); err != nil { 576 logger.Error("error requesting crawls", "err", err) 577 } 578 }() 579 580 <-ctx.Done() 581 582 fmt.Println("shut down") 583 584 return nil 585} 586 587func (s *Server) requestCrawl(ctx context.Context) error { 588 logger := s.logger.With("component", "request-crawl") 589 s.requestCrawlMu.Lock() 590 defer s.requestCrawlMu.Unlock() 591 592 logger.Info("requesting crawl with configured relays") 593 594 if time.Since(s.lastRequestCrawl) <= 1*time.Minute { 595 return fmt.Errorf("a crawl request has already been made within the last minute") 596 } 597 598 for _, relay := range s.config.Relays { 599 logger := logger.With("relay", relay) 600 logger.Info("requesting crawl from relay") 601 cli := xrpc.Client{Host: relay} 602 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 603 Hostname: s.config.Hostname, 604 }); err != nil { 605 logger.Error("error requesting crawl", "err", err) 606 } else { 607 logger.Info("crawl requested successfully") 608 } 609 } 610 611 s.lastRequestCrawl = time.Now() 612 613 return nil 614} 615 616func (s *Server) doBackup() { 617 logger := s.logger.With("name", "doBackup") 618 619 if s.dbType == "postgres" { 620 logger.Info("skipping S3 backup - PostgreSQL backups should be handled externally (pg_dump, managed database backups, etc.)") 621 return 622 } 623 624 start := time.Now() 625 626 logger.Info("beginning backup to s3...") 627 628 var buf bytes.Buffer 629 if err := func() error { 630 logger.Info("reading database bytes...") 631 s.db.Lock() 632 defer s.db.Unlock() 633 634 sf, err := os.Open(s.dbName) 635 if err != nil { 636 return fmt.Errorf("error opening database for backup: %w", err) 637 } 638 defer sf.Close() 639 640 if _, err := io.Copy(&buf, sf); err != nil { 641 return fmt.Errorf("error reading bytes of backup db: %w", err) 642 } 643 644 return nil 645 }(); err != nil { 646 logger.Error("error backing up database", "error", err) 647 return 648 } 649 650 if err := func() error { 651 logger.Info("sending to s3...") 652 653 currTime := time.Now().Format("2006-01-02_15-04-05") 654 key := "cocoon-backup-" + currTime + ".db" 655 656 config := &aws.Config{ 657 Region: aws.String(s.s3Config.Region), 658 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 659 } 660 661 if s.s3Config.Endpoint != "" { 662 config.Endpoint = aws.String(s.s3Config.Endpoint) 663 config.S3ForcePathStyle = aws.Bool(true) 664 } 665 666 sess, err := session.NewSession(config) 667 if err != nil { 668 return err 669 } 670 671 svc := s3.New(sess) 672 673 if _, err := svc.PutObject(&s3.PutObjectInput{ 674 Bucket: aws.String(s.s3Config.Bucket), 675 Key: aws.String(key), 676 Body: bytes.NewReader(buf.Bytes()), 677 }); err != nil { 678 return fmt.Errorf("error uploading file to s3: %w", err) 679 } 680 681 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds()) 682 683 return nil 684 }(); err != nil { 685 logger.Error("error uploading database backup", "error", err) 686 return 687 } 688 689 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644) 690} 691 692func (s *Server) backupRoutine() { 693 logger := s.logger.With("name", "backupRoutine") 694 695 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 696 return 697 } 698 699 if s.s3Config.Region == "" { 700 logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 701 return 702 } 703 704 if s.s3Config.Bucket == "" { 705 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 706 return 707 } 708 709 if s.s3Config.AccessKey == "" { 710 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 711 return 712 } 713 714 if s.s3Config.SecretKey == "" { 715 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 716 return 717 } 718 719 shouldBackupNow := false 720 lastBackupStr, err := os.ReadFile("last-backup.txt") 721 if err != nil { 722 shouldBackupNow = true 723 } else { 724 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr)) 725 if err != nil { 726 shouldBackupNow = true 727 } else if time.Now().Sub(lastBackup).Seconds() > 3600 { 728 shouldBackupNow = true 729 } 730 } 731 732 if shouldBackupNow { 733 go s.doBackup() 734 } 735 736 ticker := time.NewTicker(time.Hour) 737 for range ticker.C { 738 go s.doBackup() 739 } 740} 741 742func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 743 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 744 return err 745 } 746 747 return nil 748}