An atproto PDS written in Go
at push-based 810 lines 26 kB view raw
1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "database/sql" 8 "embed" 9 "errors" 10 "fmt" 11 "io" 12 "log/slog" 13 "net/http" 14 "net/smtp" 15 "os" 16 "path/filepath" 17 "sync" 18 "text/template" 19 "time" 20 21 "github.com/aws/aws-sdk-go/aws" 22 "github.com/aws/aws-sdk-go/aws/credentials" 23 "github.com/aws/aws-sdk-go/aws/session" 24 "github.com/aws/aws-sdk-go/service/s3" 25 "github.com/bluesky-social/indigo/api/atproto" 26 "github.com/bluesky-social/indigo/atproto/syntax" 27 "github.com/bluesky-social/indigo/events" 28 "github.com/bluesky-social/indigo/util" 29 "github.com/bluesky-social/indigo/xrpc" 30 "github.com/domodwyer/mailyak/v3" 31 "github.com/go-playground/validator" 32 "github.com/gorilla/sessions" 33 "github.com/haileyok/cocoon/identity" 34 "github.com/haileyok/cocoon/internal/db" 35 "github.com/haileyok/cocoon/internal/helpers" 36 "github.com/haileyok/cocoon/models" 37 "github.com/haileyok/cocoon/oauth/client" 38 "github.com/haileyok/cocoon/oauth/constants" 39 "github.com/haileyok/cocoon/oauth/dpop" 40 "github.com/haileyok/cocoon/oauth/provider" 41 "github.com/haileyok/cocoon/plc" 42 "github.com/ipfs/go-cid" 43 "github.com/labstack/echo-contrib/echoprometheus" 44 echo_session "github.com/labstack/echo-contrib/session" 45 "github.com/labstack/echo/v4" 46 "github.com/labstack/echo/v4/middleware" 47 slogecho "github.com/samber/slog-echo" 48 _ "github.com/tursodatabase/libsql-client-go/libsql" 49 "gorm.io/driver/postgres" 50 "gorm.io/driver/sqlite" 51 "gorm.io/gorm" 52) 53 54const ( 55 AccountSessionMaxAge = 30 * 24 * time.Hour // one week 56) 57 58type S3Config struct { 59 BackupsEnabled bool 60 BlobstoreEnabled bool 61 Endpoint string 62 Region string 63 Bucket string 64 AccessKey string 65 SecretKey string 66 CDNUrl string 67} 68 69type Server struct { 70 http *http.Client 71 httpd *http.Server 72 mail *mailyak.MailYak 73 mailLk *sync.Mutex 74 echo *echo.Echo 75 db *db.DB 76 plcClient *plc.Client 77 logger *slog.Logger 78 config *config 79 privateKey *ecdsa.PrivateKey 80 repoman *RepoMan 81 oauthProvider *provider.Provider 82 evtman *events.EventManager 83 passport *identity.Passport 84 fallbackProxy string 85 86 lastRequestCrawl time.Time 87 requestCrawlMu sync.Mutex 88 89 dbName string 90 dbType string 91 s3Config *S3Config 92} 93 94type Args struct { 95 Logger *slog.Logger 96 97 LogLevel slog.Level 98 Addr string 99 DbName string 100 DbType string 101 DatabaseURL string 102 TursoToken string 103 Version string 104 Did string 105 Hostname string 106 RotationKeyPath string 107 JwkPath string 108 ContactEmail string 109 Relays []string 110 AdminPassword string 111 RequireInvite bool 112 113 SmtpUser string 114 SmtpPass string 115 SmtpHost string 116 SmtpPort string 117 SmtpEmail string 118 SmtpName string 119 120 S3Config *S3Config 121 122 SessionSecret string 123 SessionCookieKey string 124 125 BlockstoreVariant BlockstoreVariant 126 FallbackProxy string 127 128 PushBasedEvents bool 129 SubscribeReposServiceURL string 130} 131 132type config struct { 133 LogLevel slog.Level 134 Version string 135 Did string 136 Hostname string 137 ContactEmail string 138 EnforcePeering bool 139 Relays []string 140 AdminPassword string 141 RequireInvite bool 142 SmtpEmail string 143 SmtpName string 144 SessionCookieKey string 145 BlockstoreVariant BlockstoreVariant 146 FallbackProxy string 147 148 PushBasedEvents bool 149 SubscribeReposServiceURL string 150} 151 152type CustomValidator struct { 153 validator *validator.Validate 154} 155 156type ValidationError struct { 157 error 158 Field string 159 Tag string 160} 161 162func (cv *CustomValidator) Validate(i any) error { 163 if err := cv.validator.Struct(i); err != nil { 164 var validateErrors validator.ValidationErrors 165 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 166 first := validateErrors[0] 167 return ValidationError{ 168 error: err, 169 Field: first.Field(), 170 Tag: first.Tag(), 171 } 172 } 173 174 return err 175 } 176 177 return nil 178} 179 180//go:embed templates/* 181var templateFS embed.FS 182 183//go:embed static/* 184var staticFS embed.FS 185 186type TemplateRenderer struct { 187 templates *template.Template 188 isDev bool 189 templatePath string 190} 191 192func (s *Server) loadTemplates() { 193 absPath, _ := filepath.Abs("server/templates/*.html") 194 if s.config.Version == "dev" { 195 tmpl := template.Must(template.ParseGlob(absPath)) 196 s.echo.Renderer = &TemplateRenderer{ 197 templates: tmpl, 198 isDev: true, 199 templatePath: absPath, 200 } 201 } else { 202 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 203 s.echo.Renderer = &TemplateRenderer{ 204 templates: tmpl, 205 isDev: false, 206 } 207 } 208} 209 210func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 211 if t.isDev { 212 tmpl, err := template.ParseGlob(t.templatePath) 213 if err != nil { 214 return err 215 } 216 t.templates = tmpl 217 } 218 219 if viewContext, isMap := data.(map[string]any); isMap { 220 viewContext["reverse"] = c.Echo().Reverse 221 } 222 223 return t.templates.ExecuteTemplate(w, name, data) 224} 225 226type filteredHandler struct { 227 level slog.Level 228 handler slog.Handler 229} 230 231func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool { 232 return level >= h.level && h.handler.Enabled(ctx, level) 233} 234 235func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error { 236 return h.handler.Handle(ctx, r) 237} 238 239func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler { 240 return &filteredHandler{level: h.level, handler: h.handler.WithAttrs(attrs)} 241} 242 243func (h *filteredHandler) WithGroup(name string) slog.Handler { 244 return &filteredHandler{level: h.level, handler: h.handler.WithGroup(name)} 245} 246 247func New(args *Args) (*Server, error) { 248 if args.Logger == nil { 249 args.Logger = slog.Default() 250 } 251 252 if args.LogLevel != 0 { 253 args.Logger = slog.New(&filteredHandler{ 254 level: args.LogLevel, 255 handler: args.Logger.Handler(), 256 }) 257 } 258 259 logger := args.Logger.With("name", "New") 260 261 if args.Addr == "" { 262 return nil, fmt.Errorf("addr must be set") 263 } 264 265 if args.DbName == "" { 266 return nil, fmt.Errorf("db name must be set") 267 } 268 269 if args.Did == "" { 270 return nil, fmt.Errorf("cocoon did must be set") 271 } 272 273 if args.ContactEmail == "" { 274 return nil, fmt.Errorf("cocoon contact email is required") 275 } 276 277 if _, err := syntax.ParseDID(args.Did); err != nil { 278 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 279 } 280 281 if args.Hostname == "" { 282 return nil, fmt.Errorf("cocoon hostname must be set") 283 } 284 285 if args.AdminPassword == "" { 286 return nil, fmt.Errorf("admin password must be set") 287 } 288 289 if args.SessionSecret == "" { 290 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 291 } 292 293 e := echo.New() 294 295 e.Pre(middleware.RemoveTrailingSlash()) 296 e.Pre(slogecho.New(args.Logger.With("component", "slogecho"))) 297 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 298 e.Use(echoprometheus.NewMiddleware("cocoon")) 299 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 300 AllowOrigins: []string{"*"}, 301 AllowHeaders: []string{"*"}, 302 AllowMethods: []string{"*"}, 303 AllowCredentials: true, 304 MaxAge: 100_000_000, 305 })) 306 307 vdtor := validator.New() 308 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 309 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 310 return false 311 } 312 return true 313 }) 314 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 315 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 316 return false 317 } 318 return true 319 }) 320 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 321 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 322 return false 323 } 324 return true 325 }) 326 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 327 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 328 return false 329 } 330 return true 331 }) 332 333 e.Validator = &CustomValidator{validator: vdtor} 334 335 httpd := &http.Server{ 336 Addr: args.Addr, 337 Handler: e, 338 // shitty defaults but okay for now, needed for import repo 339 ReadTimeout: 5 * time.Minute, 340 WriteTimeout: 5 * time.Minute, 341 IdleTimeout: 5 * time.Minute, 342 } 343 344 dbType := args.DbType 345 if dbType == "" { 346 dbType = "sqlite" 347 } 348 349 var gdb *gorm.DB 350 var err error 351 switch dbType { 352 case "postgres": 353 if args.DatabaseURL == "" { 354 return nil, fmt.Errorf("database-url must be set when using postgres") 355 } 356 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{}) 357 if err != nil { 358 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 359 } 360 logger.Info("connected to PostgreSQL database") 361 case "turso": 362 primaryUrl := args.DatabaseURL 363 authToken := args.TursoToken 364 365 db, err := sql.Open("libsql", fmt.Sprintf("%s?authToken=%s", primaryUrl, authToken)) 366 gdb, err = gorm.Open(sqlite.New(sqlite.Config{ 367 Conn: db, 368 }), &gorm.Config{}) 369 if err != nil { 370 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 371 } 372 logger.Info("connected to PostgreSQL database") 373 374 default: 375 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{}) 376 if err != nil { 377 return nil, fmt.Errorf("failed to open sqlite database: %w", err) 378 } 379 gdb.Exec("PRAGMA journal_mode=WAL") 380 gdb.Exec("PRAGMA synchronous=NORMAL") 381 382 logger.Info("connected to SQLite database", "path", args.DbName) 383 } 384 dbw := db.NewDB(gdb) 385 386 rkbytes, err := os.ReadFile(args.RotationKeyPath) 387 if err != nil { 388 return nil, err 389 } 390 391 h := util.RobustHTTPClient() 392 393 plcClient, err := plc.NewClient(&plc.ClientArgs{ 394 H: h, 395 Service: "https://plc.directory", 396 PdsHostname: args.Hostname, 397 RotationKey: rkbytes, 398 }) 399 if err != nil { 400 return nil, err 401 } 402 403 jwkbytes, err := os.ReadFile(args.JwkPath) 404 if err != nil { 405 return nil, err 406 } 407 408 key, err := helpers.ParseJWKFromBytes(jwkbytes) 409 if err != nil { 410 return nil, err 411 } 412 413 var pkey ecdsa.PrivateKey 414 if err := key.Raw(&pkey); err != nil { 415 return nil, err 416 } 417 418 oauthCli := &http.Client{ 419 Timeout: 10 * time.Second, 420 } 421 422 var nonceSecret []byte 423 maybeSecret, err := os.ReadFile("nonce.secret") 424 if err != nil && !os.IsNotExist(err) { 425 logger.Error("error attempting to read nonce secret", "error", err) 426 } else { 427 nonceSecret = maybeSecret 428 } 429 430 evtPersister, err := NewDbPersister(gdb, 72*time.Hour) 431 if err != nil { 432 return nil, fmt.Errorf("failed to create event persister: %w", err) 433 } 434 435 s := &Server{ 436 http: h, 437 httpd: httpd, 438 echo: e, 439 logger: args.Logger, 440 db: dbw, 441 plcClient: plcClient, 442 privateKey: &pkey, 443 config: &config{ 444 LogLevel: args.LogLevel, 445 Version: args.Version, 446 Did: args.Did, 447 Hostname: args.Hostname, 448 ContactEmail: args.ContactEmail, 449 EnforcePeering: false, 450 Relays: args.Relays, 451 AdminPassword: args.AdminPassword, 452 RequireInvite: args.RequireInvite, 453 SmtpName: args.SmtpName, 454 SmtpEmail: args.SmtpEmail, 455 SessionCookieKey: args.SessionCookieKey, 456 BlockstoreVariant: args.BlockstoreVariant, 457 FallbackProxy: args.FallbackProxy, 458 PushBasedEvents: args.PushBasedEvents, 459 SubscribeReposServiceURL: args.SubscribeReposServiceURL, 460 }, 461 evtman: events.NewEventManager(evtPersister), 462 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 463 464 dbName: args.DbName, 465 dbType: dbType, 466 s3Config: args.S3Config, 467 468 oauthProvider: provider.NewProvider(provider.Args{ 469 Hostname: args.Hostname, 470 ClientManagerArgs: client.ManagerArgs{ 471 Cli: oauthCli, 472 Logger: args.Logger.With("component", "oauth-client-manager"), 473 }, 474 DpopManagerArgs: dpop.ManagerArgs{ 475 NonceSecret: nonceSecret, 476 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 477 OnNonceSecretCreated: func(newNonce []byte) { 478 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 479 logger.Error("error writing new nonce secret", "error", err) 480 } 481 }, 482 Logger: args.Logger.With("component", "dpop-manager"), 483 Hostname: args.Hostname, 484 }, 485 }), 486 } 487 488 s.loadTemplates() 489 490 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 491 492 // TODO: should validate these args 493 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 494 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.") 495 } else { 496 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 497 mail.From(s.config.SmtpEmail) 498 mail.FromName(s.config.SmtpName) 499 500 s.mail = mail 501 s.mailLk = &sync.Mutex{} 502 } 503 504 return s, nil 505} 506 507func (s *Server) addRoutes() { 508 // static 509 if s.config.Version == "dev" { 510 s.echo.Static("/static", "server/static") 511 } else { 512 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 513 } 514 515 // random stuff 516 s.echo.GET("/", s.handleRoot) 517 s.echo.GET("/xrpc/_health", s.handleHealth) 518 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 519 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid) 520 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 521 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 522 s.echo.GET("/robots.txt", s.handleRobots) 523 524 // public 525 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 526 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 527 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 528 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 529 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey) 530 531 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 532 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 533 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 534 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 535 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 536 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 537 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 538 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 539 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 540 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 541 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 542 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 543 544 // labels 545 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels) 546 547 // account 548 s.echo.GET("/account", s.handleAccount) 549 s.echo.POST("/account/revoke", s.handleAccountRevoke) 550 s.echo.GET("/account/signin", s.handleAccountSigninGet) 551 s.echo.POST("/account/signin", s.handleAccountSigninPost) 552 s.echo.GET("/account/signout", s.handleAccountSignout) 553 554 // oauth account 555 s.echo.GET("/oauth/jwks", s.handleOauthJwks) 556 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 557 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 558 559 // oauth authorization 560 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 561 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 562 563 // authed 564 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 565 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 566 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 567 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 568 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 569 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 570 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 571 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 572 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 573 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 574 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 575 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 576 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 577 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 578 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 579 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 580 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 581 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 582 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 583 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount) 584 585 // repo 586 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 587 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 588 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 589 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 590 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 591 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 592 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 593 594 // stupid silly endpoints 595 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 596 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 597 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 598 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 599 // admin routes 600 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 601 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 602 603 // are there any routes that we should be allowing without auth? i dont think so but idk 604 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 605 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 606} 607 608func (s *Server) Serve(ctx context.Context) error { 609 logger := s.logger.With("name", "Serve") 610 611 s.addRoutes() 612 613 logger.Info("migrating...") 614 615 s.db.AutoMigrate( 616 &models.Actor{}, 617 &models.Repo{}, 618 &models.InviteCode{}, 619 &models.Token{}, 620 &models.RefreshToken{}, 621 &models.Block{}, 622 &models.Record{}, 623 &models.Blob{}, 624 &models.BlobPart{}, 625 &models.ReservedKey{}, 626 &provider.OauthToken{}, 627 &provider.OauthAuthorizationRequest{}, 628 ) 629 630 logger.Info("starting cocoon") 631 632 go func() { 633 if err := s.httpd.ListenAndServe(); err != nil { 634 panic(err) 635 } 636 }() 637 638 go s.backupRoutine() 639 640 go func() { 641 if err := s.requestCrawl(ctx); err != nil { 642 logger.Error("error requesting crawls", "err", err) 643 } 644 }() 645 646 if s.config.PushBasedEvents { 647 slog.Info("pushed based events enabled") 648 go func() { 649 if err := s.emmitEvents(ctx); err != nil { 650 logger.Error("error emitting events", "err", err) 651 } 652 }() 653 } 654 655 <-ctx.Done() 656 657 fmt.Println("shut down") 658 659 return nil 660} 661 662func (s *Server) requestCrawl(ctx context.Context) error { 663 logger := s.logger.With("component", "request-crawl") 664 s.requestCrawlMu.Lock() 665 defer s.requestCrawlMu.Unlock() 666 667 logger.Info("requesting crawl with configured relays") 668 669 if time.Since(s.lastRequestCrawl) <= 1*time.Minute { 670 return fmt.Errorf("a crawl request has already been made within the last minute") 671 } 672 673 for _, relay := range s.config.Relays { 674 logger := logger.With("relay", relay) 675 logger.Info("requesting crawl from relay") 676 cli := xrpc.Client{Host: relay} 677 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 678 Hostname: s.config.Hostname, 679 }); err != nil { 680 logger.Error("error requesting crawl", "err", err) 681 } else { 682 logger.Info("crawl requested successfully") 683 } 684 } 685 686 s.lastRequestCrawl = time.Now() 687 688 return nil 689} 690 691func (s *Server) doBackup() { 692 logger := s.logger.With("name", "doBackup") 693 694 if s.dbType == "postgres" || s.dbType == "turso" { 695 logger.Info("skipping S3 backup - PostgreSQL or Turso backups should be handled externally (pg_dump, managed database backups, etc.)") 696 return 697 } 698 699 start := time.Now() 700 701 logger.Info("beginning backup to s3...") 702 703 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano)) 704 defer os.Remove(tmpFile) 705 706 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil { 707 logger.Error("error creating tmp backup file", "err", err) 708 return 709 } 710 711 backupData, err := os.ReadFile(tmpFile) 712 if err != nil { 713 logger.Error("error reading tmp backup file", "err", err) 714 return 715 } 716 717 logger.Info("sending to s3...") 718 719 currTime := time.Now().Format("2006-01-02_15-04-05") 720 key := "cocoon-backup-" + currTime + ".db" 721 722 config := &aws.Config{ 723 Region: aws.String(s.s3Config.Region), 724 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 725 } 726 727 if s.s3Config.Endpoint != "" { 728 config.Endpoint = aws.String(s.s3Config.Endpoint) 729 config.S3ForcePathStyle = aws.Bool(true) 730 } 731 732 sess, err := session.NewSession(config) 733 if err != nil { 734 logger.Error("error creating s3 session", "err", err) 735 return 736 } 737 738 svc := s3.New(sess) 739 740 if _, err := svc.PutObject(&s3.PutObjectInput{ 741 Bucket: aws.String(s.s3Config.Bucket), 742 Key: aws.String(key), 743 Body: bytes.NewReader(backupData), 744 }); err != nil { 745 logger.Error("error uploading file to s3", "err", err) 746 return 747 } 748 749 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds()) 750 751 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644) 752} 753 754func (s *Server) backupRoutine() { 755 logger := s.logger.With("name", "backupRoutine") 756 757 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 758 return 759 } 760 761 if s.s3Config.Region == "" { 762 logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 763 return 764 } 765 766 if s.s3Config.Bucket == "" { 767 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 768 return 769 } 770 771 if s.s3Config.AccessKey == "" { 772 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 773 return 774 } 775 776 if s.s3Config.SecretKey == "" { 777 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 778 return 779 } 780 781 shouldBackupNow := false 782 lastBackupStr, err := os.ReadFile("last-backup.txt") 783 if err != nil { 784 shouldBackupNow = true 785 } else { 786 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr)) 787 if err != nil { 788 shouldBackupNow = true 789 } else if time.Since(lastBackup).Seconds() > 3600 { 790 shouldBackupNow = true 791 } 792 } 793 794 if shouldBackupNow { 795 go s.doBackup() 796 } 797 798 ticker := time.NewTicker(time.Hour) 799 for range ticker.C { 800 go s.doBackup() 801 } 802} 803 804func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 805 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 806 return err 807 } 808 809 return nil 810}