An atproto PDS written in Go
at turso-db 793 lines 25 kB view raw
1package server 2 3import ( 4 "bytes" 5 "context" 6 "crypto/ecdsa" 7 "database/sql" 8 "embed" 9 "errors" 10 "fmt" 11 "io" 12 "log/slog" 13 "net/http" 14 "net/smtp" 15 "os" 16 "path/filepath" 17 "sync" 18 "text/template" 19 "time" 20 21 "github.com/aws/aws-sdk-go/aws" 22 "github.com/aws/aws-sdk-go/aws/credentials" 23 "github.com/aws/aws-sdk-go/aws/session" 24 "github.com/aws/aws-sdk-go/service/s3" 25 "github.com/bluesky-social/indigo/api/atproto" 26 "github.com/bluesky-social/indigo/atproto/syntax" 27 "github.com/bluesky-social/indigo/events" 28 "github.com/bluesky-social/indigo/util" 29 "github.com/bluesky-social/indigo/xrpc" 30 "github.com/domodwyer/mailyak/v3" 31 "github.com/go-playground/validator" 32 "github.com/gorilla/sessions" 33 "github.com/haileyok/cocoon/identity" 34 "github.com/haileyok/cocoon/internal/db" 35 "github.com/haileyok/cocoon/internal/helpers" 36 "github.com/haileyok/cocoon/models" 37 "github.com/haileyok/cocoon/oauth/client" 38 "github.com/haileyok/cocoon/oauth/constants" 39 "github.com/haileyok/cocoon/oauth/dpop" 40 "github.com/haileyok/cocoon/oauth/provider" 41 "github.com/haileyok/cocoon/plc" 42 "github.com/ipfs/go-cid" 43 "github.com/labstack/echo-contrib/echoprometheus" 44 echo_session "github.com/labstack/echo-contrib/session" 45 "github.com/labstack/echo/v4" 46 "github.com/labstack/echo/v4/middleware" 47 slogecho "github.com/samber/slog-echo" 48 _ "github.com/tursodatabase/libsql-client-go/libsql" 49 "gorm.io/driver/postgres" 50 "gorm.io/driver/sqlite" 51 "gorm.io/gorm" 52) 53 54const ( 55 AccountSessionMaxAge = 30 * 24 * time.Hour // one week 56) 57 58type S3Config struct { 59 BackupsEnabled bool 60 BlobstoreEnabled bool 61 Endpoint string 62 Region string 63 Bucket string 64 AccessKey string 65 SecretKey string 66 CDNUrl string 67} 68 69type Server struct { 70 http *http.Client 71 httpd *http.Server 72 mail *mailyak.MailYak 73 mailLk *sync.Mutex 74 echo *echo.Echo 75 db *db.DB 76 plcClient *plc.Client 77 logger *slog.Logger 78 config *config 79 privateKey *ecdsa.PrivateKey 80 repoman *RepoMan 81 oauthProvider *provider.Provider 82 evtman *events.EventManager 83 passport *identity.Passport 84 fallbackProxy string 85 86 lastRequestCrawl time.Time 87 requestCrawlMu sync.Mutex 88 89 dbName string 90 dbType string 91 s3Config *S3Config 92} 93 94type Args struct { 95 Logger *slog.Logger 96 97 LogLevel slog.Level 98 Addr string 99 DbName string 100 DbType string 101 DatabaseURL string 102 TursoToken string 103 Version string 104 Did string 105 Hostname string 106 RotationKeyPath string 107 JwkPath string 108 ContactEmail string 109 Relays []string 110 AdminPassword string 111 RequireInvite bool 112 113 SmtpUser string 114 SmtpPass string 115 SmtpHost string 116 SmtpPort string 117 SmtpEmail string 118 SmtpName string 119 120 S3Config *S3Config 121 122 SessionSecret string 123 SessionCookieKey string 124 125 BlockstoreVariant BlockstoreVariant 126 FallbackProxy string 127} 128 129type config struct { 130 LogLevel slog.Level 131 Version string 132 Did string 133 Hostname string 134 ContactEmail string 135 EnforcePeering bool 136 Relays []string 137 AdminPassword string 138 RequireInvite bool 139 SmtpEmail string 140 SmtpName string 141 SessionCookieKey string 142 BlockstoreVariant BlockstoreVariant 143 FallbackProxy string 144} 145 146type CustomValidator struct { 147 validator *validator.Validate 148} 149 150type ValidationError struct { 151 error 152 Field string 153 Tag string 154} 155 156func (cv *CustomValidator) Validate(i any) error { 157 if err := cv.validator.Struct(i); err != nil { 158 var validateErrors validator.ValidationErrors 159 if errors.As(err, &validateErrors) && len(validateErrors) > 0 { 160 first := validateErrors[0] 161 return ValidationError{ 162 error: err, 163 Field: first.Field(), 164 Tag: first.Tag(), 165 } 166 } 167 168 return err 169 } 170 171 return nil 172} 173 174//go:embed templates/* 175var templateFS embed.FS 176 177//go:embed static/* 178var staticFS embed.FS 179 180type TemplateRenderer struct { 181 templates *template.Template 182 isDev bool 183 templatePath string 184} 185 186func (s *Server) loadTemplates() { 187 absPath, _ := filepath.Abs("server/templates/*.html") 188 if s.config.Version == "dev" { 189 tmpl := template.Must(template.ParseGlob(absPath)) 190 s.echo.Renderer = &TemplateRenderer{ 191 templates: tmpl, 192 isDev: true, 193 templatePath: absPath, 194 } 195 } else { 196 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html")) 197 s.echo.Renderer = &TemplateRenderer{ 198 templates: tmpl, 199 isDev: false, 200 } 201 } 202} 203 204func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error { 205 if t.isDev { 206 tmpl, err := template.ParseGlob(t.templatePath) 207 if err != nil { 208 return err 209 } 210 t.templates = tmpl 211 } 212 213 if viewContext, isMap := data.(map[string]any); isMap { 214 viewContext["reverse"] = c.Echo().Reverse 215 } 216 217 return t.templates.ExecuteTemplate(w, name, data) 218} 219 220type filteredHandler struct { 221 level slog.Level 222 handler slog.Handler 223} 224 225func (h *filteredHandler) Enabled(ctx context.Context, level slog.Level) bool { 226 return level >= h.level && h.handler.Enabled(ctx, level) 227} 228 229func (h *filteredHandler) Handle(ctx context.Context, r slog.Record) error { 230 return h.handler.Handle(ctx, r) 231} 232 233func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler { 234 return &filteredHandler{level: h.level, handler: h.handler.WithAttrs(attrs)} 235} 236 237func (h *filteredHandler) WithGroup(name string) slog.Handler { 238 return &filteredHandler{level: h.level, handler: h.handler.WithGroup(name)} 239} 240 241func New(args *Args) (*Server, error) { 242 if args.Logger == nil { 243 args.Logger = slog.Default() 244 } 245 246 if args.LogLevel != 0 { 247 args.Logger = slog.New(&filteredHandler{ 248 level: args.LogLevel, 249 handler: args.Logger.Handler(), 250 }) 251 } 252 253 logger := args.Logger.With("name", "New") 254 255 if args.Addr == "" { 256 return nil, fmt.Errorf("addr must be set") 257 } 258 259 if args.DbName == "" { 260 return nil, fmt.Errorf("db name must be set") 261 } 262 263 if args.Did == "" { 264 return nil, fmt.Errorf("cocoon did must be set") 265 } 266 267 if args.ContactEmail == "" { 268 return nil, fmt.Errorf("cocoon contact email is required") 269 } 270 271 if _, err := syntax.ParseDID(args.Did); err != nil { 272 return nil, fmt.Errorf("error parsing cocoon did: %w", err) 273 } 274 275 if args.Hostname == "" { 276 return nil, fmt.Errorf("cocoon hostname must be set") 277 } 278 279 if args.AdminPassword == "" { 280 return nil, fmt.Errorf("admin password must be set") 281 } 282 283 if args.SessionSecret == "" { 284 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ") 285 } 286 287 e := echo.New() 288 289 e.Pre(middleware.RemoveTrailingSlash()) 290 e.Pre(slogecho.New(args.Logger.With("component", "slogecho"))) 291 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret)))) 292 e.Use(echoprometheus.NewMiddleware("cocoon")) 293 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{ 294 AllowOrigins: []string{"*"}, 295 AllowHeaders: []string{"*"}, 296 AllowMethods: []string{"*"}, 297 AllowCredentials: true, 298 MaxAge: 100_000_000, 299 })) 300 301 vdtor := validator.New() 302 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool { 303 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil { 304 return false 305 } 306 return true 307 }) 308 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool { 309 if _, err := syntax.ParseDID(fl.Field().String()); err != nil { 310 return false 311 } 312 return true 313 }) 314 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool { 315 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil { 316 return false 317 } 318 return true 319 }) 320 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool { 321 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil { 322 return false 323 } 324 return true 325 }) 326 327 e.Validator = &CustomValidator{validator: vdtor} 328 329 httpd := &http.Server{ 330 Addr: args.Addr, 331 Handler: e, 332 // shitty defaults but okay for now, needed for import repo 333 ReadTimeout: 5 * time.Minute, 334 WriteTimeout: 5 * time.Minute, 335 IdleTimeout: 5 * time.Minute, 336 } 337 338 dbType := args.DbType 339 if dbType == "" { 340 dbType = "sqlite" 341 } 342 343 var gdb *gorm.DB 344 var err error 345 switch dbType { 346 case "postgres": 347 if args.DatabaseURL == "" { 348 return nil, fmt.Errorf("database-url must be set when using postgres") 349 } 350 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{}) 351 if err != nil { 352 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 353 } 354 logger.Info("connected to PostgreSQL database") 355 case "turso": 356 primaryUrl := args.DatabaseURL 357 authToken := args.TursoToken 358 359 db, err := sql.Open("libsql", fmt.Sprintf("%s?authToken=%s", primaryUrl, authToken)) 360 gdb, err = gorm.Open(sqlite.New(sqlite.Config{ 361 Conn: db, 362 }), &gorm.Config{}) 363 if err != nil { 364 return nil, fmt.Errorf("failed to connect to postgres: %w", err) 365 } 366 logger.Info("connected to PostgreSQL database") 367 368 default: 369 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{}) 370 if err != nil { 371 return nil, fmt.Errorf("failed to open sqlite database: %w", err) 372 } 373 gdb.Exec("PRAGMA journal_mode=WAL") 374 gdb.Exec("PRAGMA synchronous=NORMAL") 375 376 logger.Info("connected to SQLite database", "path", args.DbName) 377 } 378 dbw := db.NewDB(gdb) 379 380 rkbytes, err := os.ReadFile(args.RotationKeyPath) 381 if err != nil { 382 return nil, err 383 } 384 385 h := util.RobustHTTPClient() 386 387 plcClient, err := plc.NewClient(&plc.ClientArgs{ 388 H: h, 389 Service: "https://plc.directory", 390 PdsHostname: args.Hostname, 391 RotationKey: rkbytes, 392 }) 393 if err != nil { 394 return nil, err 395 } 396 397 jwkbytes, err := os.ReadFile(args.JwkPath) 398 if err != nil { 399 return nil, err 400 } 401 402 key, err := helpers.ParseJWKFromBytes(jwkbytes) 403 if err != nil { 404 return nil, err 405 } 406 407 var pkey ecdsa.PrivateKey 408 if err := key.Raw(&pkey); err != nil { 409 return nil, err 410 } 411 412 oauthCli := &http.Client{ 413 Timeout: 10 * time.Second, 414 } 415 416 var nonceSecret []byte 417 maybeSecret, err := os.ReadFile("nonce.secret") 418 if err != nil && !os.IsNotExist(err) { 419 logger.Error("error attempting to read nonce secret", "error", err) 420 } else { 421 nonceSecret = maybeSecret 422 } 423 424 evtPersister, err := NewDbPersister(gdb, 72*time.Hour) 425 if err != nil { 426 return nil, fmt.Errorf("failed to create event persister: %w", err) 427 } 428 429 s := &Server{ 430 http: h, 431 httpd: httpd, 432 echo: e, 433 logger: args.Logger, 434 db: dbw, 435 plcClient: plcClient, 436 privateKey: &pkey, 437 config: &config{ 438 LogLevel: args.LogLevel, 439 Version: args.Version, 440 Did: args.Did, 441 Hostname: args.Hostname, 442 ContactEmail: args.ContactEmail, 443 EnforcePeering: false, 444 Relays: args.Relays, 445 AdminPassword: args.AdminPassword, 446 RequireInvite: args.RequireInvite, 447 SmtpName: args.SmtpName, 448 SmtpEmail: args.SmtpEmail, 449 SessionCookieKey: args.SessionCookieKey, 450 BlockstoreVariant: args.BlockstoreVariant, 451 FallbackProxy: args.FallbackProxy, 452 }, 453 evtman: events.NewEventManager(evtPersister), 454 passport: identity.NewPassport(h, identity.NewMemCache(10_000)), 455 456 dbName: args.DbName, 457 dbType: dbType, 458 s3Config: args.S3Config, 459 460 oauthProvider: provider.NewProvider(provider.Args{ 461 Hostname: args.Hostname, 462 ClientManagerArgs: client.ManagerArgs{ 463 Cli: oauthCli, 464 Logger: args.Logger.With("component", "oauth-client-manager"), 465 }, 466 DpopManagerArgs: dpop.ManagerArgs{ 467 NonceSecret: nonceSecret, 468 NonceRotationInterval: constants.NonceMaxRotationInterval / 3, 469 OnNonceSecretCreated: func(newNonce []byte) { 470 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil { 471 logger.Error("error writing new nonce secret", "error", err) 472 } 473 }, 474 Logger: args.Logger.With("component", "dpop-manager"), 475 Hostname: args.Hostname, 476 }, 477 }), 478 } 479 480 s.loadTemplates() 481 482 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it 483 484 // TODO: should validate these args 485 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" { 486 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.") 487 } else { 488 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost)) 489 mail.From(s.config.SmtpEmail) 490 mail.FromName(s.config.SmtpName) 491 492 s.mail = mail 493 s.mailLk = &sync.Mutex{} 494 } 495 496 return s, nil 497} 498 499func (s *Server) addRoutes() { 500 // static 501 if s.config.Version == "dev" { 502 s.echo.Static("/static", "server/static") 503 } else { 504 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS)))) 505 } 506 507 // random stuff 508 s.echo.GET("/", s.handleRoot) 509 s.echo.GET("/xrpc/_health", s.handleHealth) 510 s.echo.GET("/.well-known/did.json", s.handleWellKnown) 511 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid) 512 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource) 513 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer) 514 s.echo.GET("/robots.txt", s.handleRobots) 515 516 // public 517 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle) 518 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount) 519 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession) 520 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer) 521 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey) 522 523 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo) 524 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos) 525 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords) 526 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord) 527 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord) 528 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks) 529 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit) 530 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus) 531 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo) 532 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos) 533 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs) 534 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob) 535 536 // labels 537 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels) 538 539 // account 540 s.echo.GET("/account", s.handleAccount) 541 s.echo.POST("/account/revoke", s.handleAccountRevoke) 542 s.echo.GET("/account/signin", s.handleAccountSigninGet) 543 s.echo.POST("/account/signin", s.handleAccountSigninPost) 544 s.echo.GET("/account/signout", s.handleAccountSignout) 545 546 // oauth account 547 s.echo.GET("/oauth/jwks", s.handleOauthJwks) 548 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet) 549 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost) 550 551 // oauth authorization 552 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware) 553 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware) 554 555 // authed 556 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 557 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 558 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 559 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 560 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 561 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 562 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 563 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 564 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 565 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 566 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE 567 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 568 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 569 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 570 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 571 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 572 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 573 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 574 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 575 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount) 576 577 // repo 578 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 579 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 580 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 581 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 582 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 583 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 584 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 585 586 // stupid silly endpoints 587 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 588 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 589 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 590 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 591 // admin routes 592 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware) 593 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware) 594 595 // are there any routes that we should be allowing without auth? i dont think so but idk 596 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 597 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware) 598} 599 600func (s *Server) Serve(ctx context.Context) error { 601 logger := s.logger.With("name", "Serve") 602 603 s.addRoutes() 604 605 logger.Info("migrating...") 606 607 s.db.AutoMigrate( 608 &models.Actor{}, 609 &models.Repo{}, 610 &models.InviteCode{}, 611 &models.Token{}, 612 &models.RefreshToken{}, 613 &models.Block{}, 614 &models.Record{}, 615 &models.Blob{}, 616 &models.BlobPart{}, 617 &models.ReservedKey{}, 618 &provider.OauthToken{}, 619 &provider.OauthAuthorizationRequest{}, 620 ) 621 622 logger.Info("starting cocoon") 623 624 go func() { 625 if err := s.httpd.ListenAndServe(); err != nil { 626 panic(err) 627 } 628 }() 629 630 go s.backupRoutine() 631 632 go func() { 633 if err := s.requestCrawl(ctx); err != nil { 634 logger.Error("error requesting crawls", "err", err) 635 } 636 }() 637 638 <-ctx.Done() 639 640 fmt.Println("shut down") 641 642 return nil 643} 644 645func (s *Server) requestCrawl(ctx context.Context) error { 646 logger := s.logger.With("component", "request-crawl") 647 s.requestCrawlMu.Lock() 648 defer s.requestCrawlMu.Unlock() 649 650 logger.Info("requesting crawl with configured relays") 651 652 if time.Since(s.lastRequestCrawl) <= 1*time.Minute { 653 return fmt.Errorf("a crawl request has already been made within the last minute") 654 } 655 656 for _, relay := range s.config.Relays { 657 logger := logger.With("relay", relay) 658 logger.Info("requesting crawl from relay") 659 cli := xrpc.Client{Host: relay} 660 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{ 661 Hostname: s.config.Hostname, 662 }); err != nil { 663 logger.Error("error requesting crawl", "err", err) 664 } else { 665 logger.Info("crawl requested successfully") 666 } 667 } 668 669 s.lastRequestCrawl = time.Now() 670 671 return nil 672} 673 674func (s *Server) doBackup() { 675 logger := s.logger.With("name", "doBackup") 676 677 if s.dbType == "postgres" || s.dbType == "turso" { 678 logger.Info("skipping S3 backup - PostgreSQL or Turso backups should be handled externally (pg_dump, managed database backups, etc.)") 679 return 680 } 681 682 start := time.Now() 683 684 logger.Info("beginning backup to s3...") 685 686 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano)) 687 defer os.Remove(tmpFile) 688 689 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil { 690 logger.Error("error creating tmp backup file", "err", err) 691 return 692 } 693 694 backupData, err := os.ReadFile(tmpFile) 695 if err != nil { 696 logger.Error("error reading tmp backup file", "err", err) 697 return 698 } 699 700 logger.Info("sending to s3...") 701 702 currTime := time.Now().Format("2006-01-02_15-04-05") 703 key := "cocoon-backup-" + currTime + ".db" 704 705 config := &aws.Config{ 706 Region: aws.String(s.s3Config.Region), 707 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""), 708 } 709 710 if s.s3Config.Endpoint != "" { 711 config.Endpoint = aws.String(s.s3Config.Endpoint) 712 config.S3ForcePathStyle = aws.Bool(true) 713 } 714 715 sess, err := session.NewSession(config) 716 if err != nil { 717 logger.Error("error creating s3 session", "err", err) 718 return 719 } 720 721 svc := s3.New(sess) 722 723 if _, err := svc.PutObject(&s3.PutObjectInput{ 724 Bucket: aws.String(s.s3Config.Bucket), 725 Key: aws.String(key), 726 Body: bytes.NewReader(backupData), 727 }); err != nil { 728 logger.Error("error uploading file to s3", "err", err) 729 return 730 } 731 732 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds()) 733 734 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644) 735} 736 737func (s *Server) backupRoutine() { 738 logger := s.logger.With("name", "backupRoutine") 739 740 if s.s3Config == nil || !s.s3Config.BackupsEnabled { 741 return 742 } 743 744 if s.s3Config.Region == "" { 745 logger.Warn("no s3 region configured but backups are enabled. backups will not run.") 746 return 747 } 748 749 if s.s3Config.Bucket == "" { 750 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.") 751 return 752 } 753 754 if s.s3Config.AccessKey == "" { 755 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.") 756 return 757 } 758 759 if s.s3Config.SecretKey == "" { 760 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.") 761 return 762 } 763 764 shouldBackupNow := false 765 lastBackupStr, err := os.ReadFile("last-backup.txt") 766 if err != nil { 767 shouldBackupNow = true 768 } else { 769 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr)) 770 if err != nil { 771 shouldBackupNow = true 772 } else if time.Since(lastBackup).Seconds() > 3600 { 773 shouldBackupNow = true 774 } 775 } 776 777 if shouldBackupNow { 778 go s.doBackup() 779 } 780 781 ticker := time.NewTicker(time.Hour) 782 for range ticker.C { 783 go s.doBackup() 784 } 785} 786 787func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error { 788 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil { 789 return err 790 } 791 792 return nil 793}