1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "embed"
8 "errors"
9 "fmt"
10 "io"
11 "log/slog"
12 "net/http"
13 "net/smtp"
14 "os"
15 "path/filepath"
16 "sync"
17 "text/template"
18 "time"
19
20 "github.com/aws/aws-sdk-go/aws"
21 "github.com/aws/aws-sdk-go/aws/credentials"
22 "github.com/aws/aws-sdk-go/aws/session"
23 "github.com/aws/aws-sdk-go/service/s3"
24 "github.com/bluesky-social/indigo/api/atproto"
25 "github.com/bluesky-social/indigo/atproto/syntax"
26 "github.com/bluesky-social/indigo/events"
27 "github.com/bluesky-social/indigo/util"
28 "github.com/bluesky-social/indigo/xrpc"
29 "github.com/domodwyer/mailyak/v3"
30 "github.com/go-playground/validator"
31 "github.com/gorilla/sessions"
32 "github.com/haileyok/cocoon/identity"
33 "github.com/haileyok/cocoon/internal/db"
34 "github.com/haileyok/cocoon/internal/helpers"
35 "github.com/haileyok/cocoon/models"
36 "github.com/haileyok/cocoon/oauth/client"
37 "github.com/haileyok/cocoon/oauth/constants"
38 "github.com/haileyok/cocoon/oauth/dpop"
39 "github.com/haileyok/cocoon/oauth/provider"
40 "github.com/haileyok/cocoon/plc"
41 "github.com/ipfs/go-cid"
42 "github.com/labstack/echo-contrib/echoprometheus"
43 echo_session "github.com/labstack/echo-contrib/session"
44 "github.com/labstack/echo/v4"
45 "github.com/labstack/echo/v4/middleware"
46 slogecho "github.com/samber/slog-echo"
47 "gorm.io/driver/postgres"
48 "gorm.io/driver/sqlite"
49 "gorm.io/gorm"
50)
51
const (
	// AccountSessionMaxAge is the maximum lifetime of an account web session:
	// 30 days.
	AccountSessionMaxAge = 30 * 24 * time.Hour // 30 days
)
55
// S3Config describes an S3-compatible bucket and its credentials. Within this
// file it is consumed only by the database backup routines (backupRoutine and
// doBackup).
type S3Config struct {
	// BackupsEnabled turns on the periodic database backups in backupRoutine.
	BackupsEnabled bool
	// BlobstoreEnabled is not read in this file; presumably it switches blob
	// storage to S3 — confirm where it is consumed.
	BlobstoreEnabled bool
	// Endpoint, when non-empty, overrides the AWS endpoint and forces
	// path-style addressing (presumably for non-AWS S3-compatible services).
	Endpoint string
	// Region is required for backups to run.
	Region string
	// Bucket is the destination bucket; required for backups to run.
	Bucket string
	// AccessKey and SecretKey are static credentials; both are required for
	// backups to run.
	AccessKey string
	SecretKey string
	// CDNUrl is not read in this file; presumably a public base URL for
	// objects in Bucket — confirm where it is consumed.
	CDNUrl string
}
66
// Server is a single cocoon instance. It owns the HTTP listener and echo
// router, the database handle, key material, and the clients and managers the
// request handlers use. Build one with New and run it with Serve.
type Server struct {
	// http is the outbound HTTP client (util.RobustHTTPClient) shared with the
	// PLC client and identity resolution.
	http *http.Client
	// httpd is the listener started by Serve; its handler is echo.
	httpd *http.Server
	// mail is the SMTP mailer; nil unless every SMTP argument was provided.
	mail *mailyak.MailYak
	// mailLk is created together with mail (nil when mail is nil);
	// presumably it serializes use of the shared mailer — confirm in callers.
	mailLk *sync.Mutex
	echo   *echo.Echo
	db     *db.DB
	// plcClient talks to https://plc.directory using the rotation key.
	plcClient *plc.Client
	logger    *slog.Logger
	config    *config
	// privateKey is the ECDSA key parsed from the JWK file at Args.JwkPath.
	privateKey    *ecdsa.PrivateKey
	repoman       *RepoMan
	oauthProvider *provider.Provider
	// evtman is created with events.NewMemPersister().
	evtman *events.EventManager
	// passport resolves identities; created with identity.NewMemCache(10_000).
	passport *identity.Passport
	// NOTE(review): New does not assign fallbackProxy; the configured value is
	// stored in config.FallbackProxy instead.
	fallbackProxy string

	// lastRequestCrawl is when requestCrawl last finished contacting relays.
	// It is read and written only while requestCrawlMu is held.
	lastRequestCrawl time.Time
	requestCrawlMu   sync.Mutex

	// dbName is the SQLite file path; doBackup copies this file.
	dbName string
	// dbType is the configured database type ("sqlite" when unset).
	dbType string
	// s3Config may be nil, in which case backups are disabled.
	s3Config *S3Config
}
91
// Args is the configuration accepted by New.
type Args struct {
	// Logger receives all server logs; New substitutes slog.Default() when nil.
	Logger *slog.Logger

	// Addr is the HTTP listen address. Required.
	Addr string
	// DbName is the SQLite database path. New requires it even when DbType is
	// "postgres".
	DbName string
	// DbType selects the database: "postgres", otherwise SQLite (the default
	// when empty).
	DbType string
	// DatabaseURL is the PostgreSQL connection string; required when DbType
	// is "postgres".
	DatabaseURL string
	// Version is copied into config.Version; "dev" makes the server load
	// templates and static files from disk instead of the embedded copies.
	Version string
	// Did is this server's DID. Required; must parse as a DID.
	Did string
	// Hostname is this server's hostname, given to the PLC client and OAuth
	// provider and sent to relays by requestCrawl. Required.
	Hostname string
	// RotationKeyPath is a file whose bytes are passed to the PLC client as
	// its rotation key.
	RotationKeyPath string
	// JwkPath is a JWK file holding the server's ECDSA private key.
	JwkPath string
	// ContactEmail is required.
	ContactEmail string
	// Relays are the hosts requestCrawl asks to crawl this server.
	Relays []string
	// AdminPassword is required.
	AdminPassword string
	// RequireInvite is copied into config.RequireInvite.
	RequireInvite bool

	// SMTP settings. Mail is enabled only when all six are non-empty;
	// otherwise New logs a warning and leaves mailing disabled.
	SmtpUser  string
	SmtpPass  string
	SmtpHost  string
	SmtpPort  string
	SmtpEmail string
	SmtpName  string

	// S3Config configures S3 backups; may be nil.
	S3Config *S3Config

	// SessionSecret keys the cookie session store. New rejects an empty value.
	SessionSecret string

	// BlockstoreVariant and FallbackProxy are copied into config.
	BlockstoreVariant BlockstoreVariant
	FallbackProxy     string
}
123
// config is the subset of Args that the server keeps after New returns.
type config struct {
	Version      string
	Did          string
	Hostname     string
	ContactEmail string
	// EnforcePeering is always set to false by New; it is not read in this file.
	EnforcePeering    bool
	Relays            []string
	AdminPassword     string
	RequireInvite     bool
	SmtpEmail         string
	SmtpName          string
	BlockstoreVariant BlockstoreVariant
	FallbackProxy     string
}
138
// CustomValidator wraps a go-playground validator so it can be installed as
// the echo instance's Validator (see New), which registers atproto-specific
// tags on it.
type CustomValidator struct {
	validator *validator.Validate
}
142
// ValidationError is returned by CustomValidator.Validate when struct
// validation fails with field errors. It embeds the error returned by the
// validator and names the first field that failed and the tag it failed on.
type ValidationError struct {
	error
	// Field is the name of the first field that failed validation.
	Field string
	// Tag is the validation tag that failed (e.g. "atproto-did").
	Tag string
}
148
149func (cv *CustomValidator) Validate(i any) error {
150 if err := cv.validator.Struct(i); err != nil {
151 var validateErrors validator.ValidationErrors
152 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
153 first := validateErrors[0]
154 return ValidationError{
155 error: err,
156 Field: first.Field(),
157 Tag: first.Tag(),
158 }
159 }
160
161 return err
162 }
163
164 return nil
165}
166
// templateFS holds the files under templates/ compiled into the binary.
// loadTemplates parses templates/*.html from it unless Version is "dev".
//
//go:embed templates/*
var templateFS embed.FS

// staticFS holds the files under static/ compiled into the binary.
// addRoutes serves it at /static/* unless Version is "dev".
//
//go:embed static/*
var staticFS embed.FS
172
// TemplateRenderer is the echo Renderer installed by loadTemplates.
//
// NOTE(review): templates are parsed with text/template (see imports), which
// performs no HTML escaping; if any rendered data is user-controlled,
// consider html/template.
type TemplateRenderer struct {
	// templates is the template set parsed by loadTemplates.
	templates *template.Template
	// isDev makes Render re-parse templatePath on every call.
	isDev bool
	// templatePath is the glob re-parsed in dev mode; empty otherwise.
	templatePath string
}
178
179func (s *Server) loadTemplates() {
180 absPath, _ := filepath.Abs("server/templates/*.html")
181 if s.config.Version == "dev" {
182 tmpl := template.Must(template.ParseGlob(absPath))
183 s.echo.Renderer = &TemplateRenderer{
184 templates: tmpl,
185 isDev: true,
186 templatePath: absPath,
187 }
188 } else {
189 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
190 s.echo.Renderer = &TemplateRenderer{
191 templates: tmpl,
192 isDev: false,
193 }
194 }
195}
196
197func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
198 if t.isDev {
199 tmpl, err := template.ParseGlob(t.templatePath)
200 if err != nil {
201 return err
202 }
203 t.templates = tmpl
204 }
205
206 if viewContext, isMap := data.(map[string]any); isMap {
207 viewContext["reverse"] = c.Echo().Reverse
208 }
209
210 return t.templates.ExecuteTemplate(w, name, data)
211}
212
213func New(args *Args) (*Server, error) {
214 if args.Logger == nil {
215 args.Logger = slog.Default()
216 }
217
218 logger := args.Logger.With("name", "New")
219
220 if args.Addr == "" {
221 return nil, fmt.Errorf("addr must be set")
222 }
223
224 if args.DbName == "" {
225 return nil, fmt.Errorf("db name must be set")
226 }
227
228 if args.Did == "" {
229 return nil, fmt.Errorf("cocoon did must be set")
230 }
231
232 if args.ContactEmail == "" {
233 return nil, fmt.Errorf("cocoon contact email is required")
234 }
235
236 if _, err := syntax.ParseDID(args.Did); err != nil {
237 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
238 }
239
240 if args.Hostname == "" {
241 return nil, fmt.Errorf("cocoon hostname must be set")
242 }
243
244 if args.AdminPassword == "" {
245 return nil, fmt.Errorf("admin password must be set")
246 }
247
248 if args.SessionSecret == "" {
249 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
250 }
251
252 e := echo.New()
253
254 e.Pre(middleware.RemoveTrailingSlash())
255 e.Pre(slogecho.New(args.Logger.With("component", "slogecho")))
256 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
257 e.Use(echoprometheus.NewMiddleware("cocoon"))
258 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
259 AllowOrigins: []string{"*"},
260 AllowHeaders: []string{"*"},
261 AllowMethods: []string{"*"},
262 AllowCredentials: true,
263 MaxAge: 100_000_000,
264 }))
265
266 vdtor := validator.New()
267 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
268 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
269 return false
270 }
271 return true
272 })
273 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
274 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
275 return false
276 }
277 return true
278 })
279 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
280 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
281 return false
282 }
283 return true
284 })
285 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
286 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
287 return false
288 }
289 return true
290 })
291
292 e.Validator = &CustomValidator{validator: vdtor}
293
294 httpd := &http.Server{
295 Addr: args.Addr,
296 Handler: e,
297 // shitty defaults but okay for now, needed for import repo
298 ReadTimeout: 5 * time.Minute,
299 WriteTimeout: 5 * time.Minute,
300 IdleTimeout: 5 * time.Minute,
301 }
302
303 dbType := args.DbType
304 if dbType == "" {
305 dbType = "sqlite"
306 }
307
308 var gdb *gorm.DB
309 var err error
310 switch dbType {
311 case "postgres":
312 if args.DatabaseURL == "" {
313 return nil, fmt.Errorf("database-url must be set when using postgres")
314 }
315 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{})
316 if err != nil {
317 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
318 }
319 logger.Info("connected to PostgreSQL database")
320 default:
321 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
322 if err != nil {
323 return nil, fmt.Errorf("failed to open sqlite database: %w", err)
324 }
325 logger.Info("connected to SQLite database", "path", args.DbName)
326 }
327 dbw := db.NewDB(gdb)
328
329 rkbytes, err := os.ReadFile(args.RotationKeyPath)
330 if err != nil {
331 return nil, err
332 }
333
334 h := util.RobustHTTPClient()
335
336 plcClient, err := plc.NewClient(&plc.ClientArgs{
337 H: h,
338 Service: "https://plc.directory",
339 PdsHostname: args.Hostname,
340 RotationKey: rkbytes,
341 })
342 if err != nil {
343 return nil, err
344 }
345
346 jwkbytes, err := os.ReadFile(args.JwkPath)
347 if err != nil {
348 return nil, err
349 }
350
351 key, err := helpers.ParseJWKFromBytes(jwkbytes)
352 if err != nil {
353 return nil, err
354 }
355
356 var pkey ecdsa.PrivateKey
357 if err := key.Raw(&pkey); err != nil {
358 return nil, err
359 }
360
361 oauthCli := &http.Client{
362 Timeout: 10 * time.Second,
363 }
364
365 var nonceSecret []byte
366 maybeSecret, err := os.ReadFile("nonce.secret")
367 if err != nil && !os.IsNotExist(err) {
368 logger.Error("error attempting to read nonce secret", "error", err)
369 } else {
370 nonceSecret = maybeSecret
371 }
372
373 s := &Server{
374 http: h,
375 httpd: httpd,
376 echo: e,
377 logger: args.Logger,
378 db: dbw,
379 plcClient: plcClient,
380 privateKey: &pkey,
381 config: &config{
382 Version: args.Version,
383 Did: args.Did,
384 Hostname: args.Hostname,
385 ContactEmail: args.ContactEmail,
386 EnforcePeering: false,
387 Relays: args.Relays,
388 AdminPassword: args.AdminPassword,
389 RequireInvite: args.RequireInvite,
390 SmtpName: args.SmtpName,
391 SmtpEmail: args.SmtpEmail,
392 BlockstoreVariant: args.BlockstoreVariant,
393 FallbackProxy: args.FallbackProxy,
394 },
395 evtman: events.NewEventManager(events.NewMemPersister()),
396 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
397
398 dbName: args.DbName,
399 dbType: dbType,
400 s3Config: args.S3Config,
401
402 oauthProvider: provider.NewProvider(provider.Args{
403 Hostname: args.Hostname,
404 ClientManagerArgs: client.ManagerArgs{
405 Cli: oauthCli,
406 Logger: args.Logger.With("component", "oauth-client-manager"),
407 },
408 DpopManagerArgs: dpop.ManagerArgs{
409 NonceSecret: nonceSecret,
410 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
411 OnNonceSecretCreated: func(newNonce []byte) {
412 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
413 logger.Error("error writing new nonce secret", "error", err)
414 }
415 },
416 Logger: args.Logger.With("component", "dpop-manager"),
417 Hostname: args.Hostname,
418 },
419 }),
420 }
421
422 s.loadTemplates()
423
424 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
425
426 // TODO: should validate these args
427 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
428 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
429 } else {
430 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
431 mail.From(s.config.SmtpEmail)
432 mail.FromName(s.config.SmtpName)
433
434 s.mail = mail
435 s.mailLk = &sync.Mutex{}
436 }
437
438 return s, nil
439}
440
441func (s *Server) addRoutes() {
442 // static
443 if s.config.Version == "dev" {
444 s.echo.Static("/static", "server/static")
445 } else {
446 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
447 }
448
449 // random stuff
450 s.echo.GET("/", s.handleRoot)
451 s.echo.GET("/xrpc/_health", s.handleHealth)
452 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
453 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid)
454 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
455 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
456 s.echo.GET("/robots.txt", s.handleRobots)
457
458 // public
459 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
460 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
461 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
462 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
463 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)
464
465 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
466 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
467 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
468 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
469 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
470 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
471 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
472 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
473 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
474 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
475 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
476 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
477
478 // labels
479 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)
480
481 // account
482 s.echo.GET("/account", s.handleAccount)
483 s.echo.POST("/account/revoke", s.handleAccountRevoke)
484 s.echo.GET("/account/signin", s.handleAccountSigninGet)
485 s.echo.POST("/account/signin", s.handleAccountSigninPost)
486 s.echo.GET("/account/signout", s.handleAccountSignout)
487
488 // oauth account
489 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
490 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
491 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
492
493 // oauth authorization
494 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
495 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
496
497 // authed
498 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
499 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
500 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
501 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
502 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
503 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
504 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
505 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
506 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
507 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
508 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
509 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
510 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
511 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
512 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
513 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
514 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
515 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
516 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
517 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)
518
519 // repo
520 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
521 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
522 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
523 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
524 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
525 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
526 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
527
528 // stupid silly endpoints
529 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
530 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
531 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
532
533 // admin routes
534 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
535 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
536
537 // are there any routes that we should be allowing without auth? i dont think so but idk
538 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
539 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
540}
541
542func (s *Server) Serve(ctx context.Context) error {
543 logger := s.logger.With("name", "Serve")
544
545 s.addRoutes()
546
547 logger.Info("migrating...")
548
549 s.db.AutoMigrate(
550 &models.Actor{},
551 &models.Repo{},
552 &models.InviteCode{},
553 &models.Token{},
554 &models.RefreshToken{},
555 &models.Block{},
556 &models.Record{},
557 &models.Blob{},
558 &models.BlobPart{},
559 &models.ReservedKey{},
560 &provider.OauthToken{},
561 &provider.OauthAuthorizationRequest{},
562 )
563
564 logger.Info("starting cocoon")
565
566 go func() {
567 if err := s.httpd.ListenAndServe(); err != nil {
568 panic(err)
569 }
570 }()
571
572 go s.backupRoutine()
573
574 go func() {
575 if err := s.requestCrawl(ctx); err != nil {
576 logger.Error("error requesting crawls", "err", err)
577 }
578 }()
579
580 <-ctx.Done()
581
582 fmt.Println("shut down")
583
584 return nil
585}
586
587func (s *Server) requestCrawl(ctx context.Context) error {
588 logger := s.logger.With("component", "request-crawl")
589 s.requestCrawlMu.Lock()
590 defer s.requestCrawlMu.Unlock()
591
592 logger.Info("requesting crawl with configured relays")
593
594 if time.Since(s.lastRequestCrawl) <= 1*time.Minute {
595 return fmt.Errorf("a crawl request has already been made within the last minute")
596 }
597
598 for _, relay := range s.config.Relays {
599 logger := logger.With("relay", relay)
600 logger.Info("requesting crawl from relay")
601 cli := xrpc.Client{Host: relay}
602 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
603 Hostname: s.config.Hostname,
604 }); err != nil {
605 logger.Error("error requesting crawl", "err", err)
606 } else {
607 logger.Info("crawl requested successfully")
608 }
609 }
610
611 s.lastRequestCrawl = time.Now()
612
613 return nil
614}
615
616func (s *Server) doBackup() {
617 logger := s.logger.With("name", "doBackup")
618
619 if s.dbType == "postgres" {
620 logger.Info("skipping S3 backup - PostgreSQL backups should be handled externally (pg_dump, managed database backups, etc.)")
621 return
622 }
623
624 start := time.Now()
625
626 logger.Info("beginning backup to s3...")
627
628 var buf bytes.Buffer
629 if err := func() error {
630 logger.Info("reading database bytes...")
631 s.db.Lock()
632 defer s.db.Unlock()
633
634 sf, err := os.Open(s.dbName)
635 if err != nil {
636 return fmt.Errorf("error opening database for backup: %w", err)
637 }
638 defer sf.Close()
639
640 if _, err := io.Copy(&buf, sf); err != nil {
641 return fmt.Errorf("error reading bytes of backup db: %w", err)
642 }
643
644 return nil
645 }(); err != nil {
646 logger.Error("error backing up database", "error", err)
647 return
648 }
649
650 if err := func() error {
651 logger.Info("sending to s3...")
652
653 currTime := time.Now().Format("2006-01-02_15-04-05")
654 key := "cocoon-backup-" + currTime + ".db"
655
656 config := &aws.Config{
657 Region: aws.String(s.s3Config.Region),
658 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
659 }
660
661 if s.s3Config.Endpoint != "" {
662 config.Endpoint = aws.String(s.s3Config.Endpoint)
663 config.S3ForcePathStyle = aws.Bool(true)
664 }
665
666 sess, err := session.NewSession(config)
667 if err != nil {
668 return err
669 }
670
671 svc := s3.New(sess)
672
673 if _, err := svc.PutObject(&s3.PutObjectInput{
674 Bucket: aws.String(s.s3Config.Bucket),
675 Key: aws.String(key),
676 Body: bytes.NewReader(buf.Bytes()),
677 }); err != nil {
678 return fmt.Errorf("error uploading file to s3: %w", err)
679 }
680
681 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Now().Sub(start).Seconds())
682
683 return nil
684 }(); err != nil {
685 logger.Error("error uploading database backup", "error", err)
686 return
687 }
688
689 os.WriteFile("last-backup.txt", []byte(time.Now().String()), 0644)
690}
691
692func (s *Server) backupRoutine() {
693 logger := s.logger.With("name", "backupRoutine")
694
695 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
696 return
697 }
698
699 if s.s3Config.Region == "" {
700 logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
701 return
702 }
703
704 if s.s3Config.Bucket == "" {
705 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
706 return
707 }
708
709 if s.s3Config.AccessKey == "" {
710 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
711 return
712 }
713
714 if s.s3Config.SecretKey == "" {
715 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
716 return
717 }
718
719 shouldBackupNow := false
720 lastBackupStr, err := os.ReadFile("last-backup.txt")
721 if err != nil {
722 shouldBackupNow = true
723 } else {
724 lastBackup, err := time.Parse("2006-01-02 15:04:05.999999999 -0700 MST", string(lastBackupStr))
725 if err != nil {
726 shouldBackupNow = true
727 } else if time.Now().Sub(lastBackup).Seconds() > 3600 {
728 shouldBackupNow = true
729 }
730 }
731
732 if shouldBackupNow {
733 go s.doBackup()
734 }
735
736 ticker := time.NewTicker(time.Hour)
737 for range ticker.C {
738 go s.doBackup()
739 }
740}
741
742func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
743 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
744 return err
745 }
746
747 return nil
748}