forked from
hailey.at/cocoon
An atproto PDS written in Go
1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "database/sql"
8 "embed"
9 "errors"
10 "fmt"
11 "io"
12 "log/slog"
13 "net/http"
14 "net/smtp"
15 "os"
16 "path/filepath"
17 "sync"
18 "text/template"
19 "time"
20
21 "github.com/aws/aws-sdk-go/aws"
22 "github.com/aws/aws-sdk-go/aws/credentials"
23 "github.com/aws/aws-sdk-go/aws/session"
24 "github.com/aws/aws-sdk-go/service/s3"
25 "github.com/bluesky-social/indigo/api/atproto"
26 "github.com/bluesky-social/indigo/atproto/syntax"
27 "github.com/bluesky-social/indigo/events"
28 "github.com/bluesky-social/indigo/util"
29 "github.com/bluesky-social/indigo/xrpc"
30 "github.com/domodwyer/mailyak/v3"
31 "github.com/go-playground/validator"
32 "github.com/gorilla/sessions"
33 "github.com/haileyok/cocoon/identity"
34 "github.com/haileyok/cocoon/internal/db"
35 "github.com/haileyok/cocoon/internal/helpers"
36 "github.com/haileyok/cocoon/models"
37 "github.com/haileyok/cocoon/oauth/client"
38 "github.com/haileyok/cocoon/oauth/constants"
39 "github.com/haileyok/cocoon/oauth/dpop"
40 "github.com/haileyok/cocoon/oauth/provider"
41 "github.com/haileyok/cocoon/plc"
42 "github.com/ipfs/go-cid"
43 "github.com/labstack/echo-contrib/echoprometheus"
44 echo_session "github.com/labstack/echo-contrib/session"
45 "github.com/labstack/echo/v4"
46 "github.com/labstack/echo/v4/middleware"
47 slogecho "github.com/samber/slog-echo"
48 _ "github.com/tursodatabase/libsql-client-go/libsql"
49 "gorm.io/driver/postgres"
50 "gorm.io/driver/sqlite"
51 "gorm.io/gorm"
52)
53
const (
	// AccountSessionMaxAge is the maximum age for account sessions:
	// 30 days (30 * 24h).
	AccountSessionMaxAge = 30 * 24 * time.Hour
)
57
// S3Config describes the optional S3-compatible storage used for SQLite
// backups and/or blob storage.
type S3Config struct {
	// Feature switches: scheduled database backups and S3-backed blobs.
	BackupsEnabled, BlobstoreEnabled bool

	// Endpoint overrides the AWS endpoint (and enables path-style access)
	// for S3-compatible services; Region, Bucket and the key pair identify
	// the target bucket; CDNUrl is an optional public URL prefix.
	Endpoint, Region, Bucket, AccessKey, SecretKey, CDNUrl string
}
68
69type Server struct {
70 http *http.Client
71 httpd *http.Server
72 mail *mailyak.MailYak
73 mailLk *sync.Mutex
74 echo *echo.Echo
75 db *db.DB
76 plcClient *plc.Client
77 logger *slog.Logger
78 config *config
79 privateKey *ecdsa.PrivateKey
80 repoman *RepoMan
81 oauthProvider *provider.Provider
82 evtman *events.EventManager
83 passport *identity.Passport
84 fallbackProxy string
85
86 lastRequestCrawl time.Time
87 requestCrawlMu sync.Mutex
88
89 dbName string
90 dbType string
91 s3Config *S3Config
92}
93
94type Args struct {
95 Logger *slog.Logger
96
97 LogLevel slog.Level
98 Addr string
99 DbName string
100 DbType string
101 DatabaseURL string
102 TursoToken string
103 Version string
104 Did string
105 Hostname string
106 RotationKeyPath string
107 JwkPath string
108 ContactEmail string
109 Relays []string
110 AdminPassword string
111 RequireInvite bool
112
113 SmtpUser string
114 SmtpPass string
115 SmtpHost string
116 SmtpPort string
117 SmtpEmail string
118 SmtpName string
119
120 S3Config *S3Config
121
122 SessionSecret string
123 SessionCookieKey string
124
125 BlockstoreVariant BlockstoreVariant
126 FallbackProxy string
127}
128
129type config struct {
130 LogLevel slog.Level
131 Version string
132 Did string
133 Hostname string
134 ContactEmail string
135 EnforcePeering bool
136 Relays []string
137 AdminPassword string
138 RequireInvite bool
139 SmtpEmail string
140 SmtpName string
141 SessionCookieKey string
142 BlockstoreVariant BlockstoreVariant
143 FallbackProxy string
144}
145
146type CustomValidator struct {
147 validator *validator.Validate
148}
149
// ValidationError reports the first field that failed struct validation.
// Error() returns the text of the embedded validator error.
type ValidationError struct {
	error

	// Field is the struct field name and Tag the failing validation tag.
	Field, Tag string
}
155
156func (cv *CustomValidator) Validate(i any) error {
157 if err := cv.validator.Struct(i); err != nil {
158 var validateErrors validator.ValidationErrors
159 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
160 first := validateErrors[0]
161 return ValidationError{
162 error: err,
163 Field: first.Field(),
164 Tag: first.Tag(),
165 }
166 }
167
168 return err
169 }
170
171 return nil
172}
173
// templateFS holds everything under templates/, compiled into the binary.
// loadTemplates parses "templates/*.html" from it when Version is not "dev".
//
//go:embed templates/*
var templateFS embed.FS

// staticFS holds everything under static/, compiled into the binary.
// addRoutes serves it at /static/* when Version is not "dev".
//
//go:embed static/*
var staticFS embed.FS
179
// TemplateRenderer is the echo renderer for the server's HTML pages.
type TemplateRenderer struct {
	// templates is the parsed template set used for rendering.
	templates *template.Template

	// In dev mode, templates are re-parsed from the templatePath glob on
	// every render. templatePath is unused otherwise.
	isDev        bool
	templatePath string
}
185
186func (s *Server) loadTemplates() {
187 absPath, _ := filepath.Abs("server/templates/*.html")
188 if s.config.Version == "dev" {
189 tmpl := template.Must(template.ParseGlob(absPath))
190 s.echo.Renderer = &TemplateRenderer{
191 templates: tmpl,
192 isDev: true,
193 templatePath: absPath,
194 }
195 } else {
196 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
197 s.echo.Renderer = &TemplateRenderer{
198 templates: tmpl,
199 isDev: false,
200 }
201 }
202}
203
204func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
205 if t.isDev {
206 tmpl, err := template.ParseGlob(t.templatePath)
207 if err != nil {
208 return err
209 }
210 t.templates = tmpl
211 }
212
213 if viewContext, isMap := data.(map[string]any); isMap {
214 viewContext["reverse"] = c.Echo().Reverse
215 }
216
217 return t.templates.ExecuteTemplate(w, name, data)
218}
219
// filteredHandler wraps another slog.Handler and drops every record below
// level, in addition to any filtering the wrapped handler does itself.
type filteredHandler struct {
	level   slog.Level
	handler slog.Handler
}

// Enabled reports whether lvl clears this handler's minimum level and is
// also accepted by the wrapped handler.
func (f *filteredHandler) Enabled(ctx context.Context, lvl slog.Level) bool {
	if lvl < f.level {
		return false
	}
	return f.handler.Enabled(ctx, lvl)
}

// Handle forwards r unchanged. Level filtering happens in Enabled, which
// slog.Logger consults before building a record.
func (f *filteredHandler) Handle(ctx context.Context, r slog.Record) error {
	return f.handler.Handle(ctx, r)
}

// WithAttrs keeps the minimum level while adding attrs to the wrapped handler.
func (f *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	return f.rewrap(f.handler.WithAttrs(attrs))
}

// WithGroup keeps the minimum level while opening a group on the wrapped
// handler.
func (f *filteredHandler) WithGroup(name string) slog.Handler {
	return f.rewrap(f.handler.WithGroup(name))
}

// rewrap wraps inner in a new filteredHandler with f's minimum level.
func (f *filteredHandler) rewrap(inner slog.Handler) slog.Handler {
	return &filteredHandler{level: f.level, handler: inner}
}
240
241func New(args *Args) (*Server, error) {
242 if args.Logger == nil {
243 args.Logger = slog.Default()
244 }
245
246 if args.LogLevel != 0 {
247 args.Logger = slog.New(&filteredHandler{
248 level: args.LogLevel,
249 handler: args.Logger.Handler(),
250 })
251 }
252
253 logger := args.Logger.With("name", "New")
254
255 if args.Addr == "" {
256 return nil, fmt.Errorf("addr must be set")
257 }
258
259 if args.DbName == "" {
260 return nil, fmt.Errorf("db name must be set")
261 }
262
263 if args.Did == "" {
264 return nil, fmt.Errorf("cocoon did must be set")
265 }
266
267 if args.ContactEmail == "" {
268 return nil, fmt.Errorf("cocoon contact email is required")
269 }
270
271 if _, err := syntax.ParseDID(args.Did); err != nil {
272 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
273 }
274
275 if args.Hostname == "" {
276 return nil, fmt.Errorf("cocoon hostname must be set")
277 }
278
279 if args.AdminPassword == "" {
280 return nil, fmt.Errorf("admin password must be set")
281 }
282
283 if args.SessionSecret == "" {
284 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
285 }
286
287 e := echo.New()
288
289 e.Pre(middleware.RemoveTrailingSlash())
290 e.Pre(slogecho.New(args.Logger.With("component", "slogecho")))
291 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
292 e.Use(echoprometheus.NewMiddleware("cocoon"))
293 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
294 AllowOrigins: []string{"*"},
295 AllowHeaders: []string{"*"},
296 AllowMethods: []string{"*"},
297 AllowCredentials: true,
298 MaxAge: 100_000_000,
299 }))
300
301 vdtor := validator.New()
302 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
303 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
304 return false
305 }
306 return true
307 })
308 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
309 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
310 return false
311 }
312 return true
313 })
314 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
315 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
316 return false
317 }
318 return true
319 })
320 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
321 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
322 return false
323 }
324 return true
325 })
326
327 e.Validator = &CustomValidator{validator: vdtor}
328
329 httpd := &http.Server{
330 Addr: args.Addr,
331 Handler: e,
332 // shitty defaults but okay for now, needed for import repo
333 ReadTimeout: 5 * time.Minute,
334 WriteTimeout: 5 * time.Minute,
335 IdleTimeout: 5 * time.Minute,
336 }
337
338 dbType := args.DbType
339 if dbType == "" {
340 dbType = "sqlite"
341 }
342
343 var gdb *gorm.DB
344 var err error
345 switch dbType {
346 case "postgres":
347 if args.DatabaseURL == "" {
348 return nil, fmt.Errorf("database-url must be set when using postgres")
349 }
350 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{})
351 if err != nil {
352 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
353 }
354 logger.Info("connected to PostgreSQL database")
355 case "turso":
356 primaryUrl := args.DatabaseURL
357 authToken := args.TursoToken
358
359 db, err := sql.Open("libsql", fmt.Sprintf("%s?authToken=%s", primaryUrl, authToken))
360 gdb, err = gorm.Open(sqlite.New(sqlite.Config{
361 Conn: db,
362 }), &gorm.Config{})
363 if err != nil {
364 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
365 }
366 logger.Info("connected to PostgreSQL database")
367
368 default:
369 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
370 if err != nil {
371 return nil, fmt.Errorf("failed to open sqlite database: %w", err)
372 }
373 gdb.Exec("PRAGMA journal_mode=WAL")
374 gdb.Exec("PRAGMA synchronous=NORMAL")
375
376 logger.Info("connected to SQLite database", "path", args.DbName)
377 }
378 dbw := db.NewDB(gdb)
379
380 rkbytes, err := os.ReadFile(args.RotationKeyPath)
381 if err != nil {
382 return nil, err
383 }
384
385 h := util.RobustHTTPClient()
386
387 plcClient, err := plc.NewClient(&plc.ClientArgs{
388 H: h,
389 Service: "https://plc.directory",
390 PdsHostname: args.Hostname,
391 RotationKey: rkbytes,
392 })
393 if err != nil {
394 return nil, err
395 }
396
397 jwkbytes, err := os.ReadFile(args.JwkPath)
398 if err != nil {
399 return nil, err
400 }
401
402 key, err := helpers.ParseJWKFromBytes(jwkbytes)
403 if err != nil {
404 return nil, err
405 }
406
407 var pkey ecdsa.PrivateKey
408 if err := key.Raw(&pkey); err != nil {
409 return nil, err
410 }
411
412 oauthCli := &http.Client{
413 Timeout: 10 * time.Second,
414 }
415
416 var nonceSecret []byte
417 maybeSecret, err := os.ReadFile("nonce.secret")
418 if err != nil && !os.IsNotExist(err) {
419 logger.Error("error attempting to read nonce secret", "error", err)
420 } else {
421 nonceSecret = maybeSecret
422 }
423
424 evtPersister, err := NewDbPersister(gdb, 72*time.Hour)
425 if err != nil {
426 return nil, fmt.Errorf("failed to create event persister: %w", err)
427 }
428
429 s := &Server{
430 http: h,
431 httpd: httpd,
432 echo: e,
433 logger: args.Logger,
434 db: dbw,
435 plcClient: plcClient,
436 privateKey: &pkey,
437 config: &config{
438 LogLevel: args.LogLevel,
439 Version: args.Version,
440 Did: args.Did,
441 Hostname: args.Hostname,
442 ContactEmail: args.ContactEmail,
443 EnforcePeering: false,
444 Relays: args.Relays,
445 AdminPassword: args.AdminPassword,
446 RequireInvite: args.RequireInvite,
447 SmtpName: args.SmtpName,
448 SmtpEmail: args.SmtpEmail,
449 SessionCookieKey: args.SessionCookieKey,
450 BlockstoreVariant: args.BlockstoreVariant,
451 FallbackProxy: args.FallbackProxy,
452 },
453 evtman: events.NewEventManager(evtPersister),
454 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
455
456 dbName: args.DbName,
457 dbType: dbType,
458 s3Config: args.S3Config,
459
460 oauthProvider: provider.NewProvider(provider.Args{
461 Hostname: args.Hostname,
462 ClientManagerArgs: client.ManagerArgs{
463 Cli: oauthCli,
464 Logger: args.Logger.With("component", "oauth-client-manager"),
465 },
466 DpopManagerArgs: dpop.ManagerArgs{
467 NonceSecret: nonceSecret,
468 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
469 OnNonceSecretCreated: func(newNonce []byte) {
470 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
471 logger.Error("error writing new nonce secret", "error", err)
472 }
473 },
474 Logger: args.Logger.With("component", "dpop-manager"),
475 Hostname: args.Hostname,
476 },
477 }),
478 }
479
480 s.loadTemplates()
481
482 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
483
484 // TODO: should validate these args
485 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
486 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
487 } else {
488 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
489 mail.From(s.config.SmtpEmail)
490 mail.FromName(s.config.SmtpName)
491
492 s.mail = mail
493 s.mailLk = &sync.Mutex{}
494 }
495
496 return s, nil
497}
498
499func (s *Server) addRoutes() {
500 // static
501 if s.config.Version == "dev" {
502 s.echo.Static("/static", "server/static")
503 } else {
504 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
505 }
506
507 // random stuff
508 s.echo.GET("/", s.handleRoot)
509 s.echo.GET("/xrpc/_health", s.handleHealth)
510 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
511 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid)
512 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
513 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
514 s.echo.GET("/robots.txt", s.handleRobots)
515
516 // public
517 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
518 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
519 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
520 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
521 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)
522
523 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
524 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
525 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
526 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
527 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
528 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
529 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
530 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
531 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
532 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
533 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
534 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
535
536 // labels
537 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)
538
539 // account
540 s.echo.GET("/account", s.handleAccount)
541 s.echo.POST("/account/revoke", s.handleAccountRevoke)
542 s.echo.GET("/account/signin", s.handleAccountSigninGet)
543 s.echo.POST("/account/signin", s.handleAccountSigninPost)
544 s.echo.GET("/account/signout", s.handleAccountSignout)
545
546 // oauth account
547 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
548 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
549 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
550
551 // oauth authorization
552 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
553 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
554
555 // authed
556 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
557 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
558 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
559 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
560 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
561 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
562 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
563 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
564 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
565 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
566 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
567 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
568 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
569 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
570 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
571 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
572 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
573 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
574 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
575 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)
576
577 // repo
578 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
579 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
580 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
581 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
582 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
583 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
584 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
585
586 // stupid silly endpoints
587 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
588 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
589 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
590 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
591 // admin routes
592 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
593 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
594
595 // are there any routes that we should be allowing without auth? i dont think so but idk
596 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
597 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
598}
599
600func (s *Server) Serve(ctx context.Context) error {
601 logger := s.logger.With("name", "Serve")
602
603 s.addRoutes()
604
605 logger.Info("migrating...")
606
607 s.db.AutoMigrate(
608 &models.Actor{},
609 &models.Repo{},
610 &models.InviteCode{},
611 &models.Token{},
612 &models.RefreshToken{},
613 &models.Block{},
614 &models.Record{},
615 &models.Blob{},
616 &models.BlobPart{},
617 &models.ReservedKey{},
618 &provider.OauthToken{},
619 &provider.OauthAuthorizationRequest{},
620 )
621
622 logger.Info("starting cocoon")
623
624 go func() {
625 if err := s.httpd.ListenAndServe(); err != nil {
626 panic(err)
627 }
628 }()
629
630 go s.backupRoutine()
631
632 go func() {
633 if err := s.requestCrawl(ctx); err != nil {
634 logger.Error("error requesting crawls", "err", err)
635 }
636 }()
637
638 <-ctx.Done()
639
640 fmt.Println("shut down")
641
642 return nil
643}
644
645func (s *Server) requestCrawl(ctx context.Context) error {
646 logger := s.logger.With("component", "request-crawl")
647 s.requestCrawlMu.Lock()
648 defer s.requestCrawlMu.Unlock()
649
650 logger.Info("requesting crawl with configured relays")
651
652 if time.Since(s.lastRequestCrawl) <= 1*time.Minute {
653 return fmt.Errorf("a crawl request has already been made within the last minute")
654 }
655
656 for _, relay := range s.config.Relays {
657 logger := logger.With("relay", relay)
658 logger.Info("requesting crawl from relay")
659 cli := xrpc.Client{Host: relay}
660 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
661 Hostname: s.config.Hostname,
662 }); err != nil {
663 logger.Error("error requesting crawl", "err", err)
664 } else {
665 logger.Info("crawl requested successfully")
666 }
667 }
668
669 s.lastRequestCrawl = time.Now()
670
671 return nil
672}
673
674func (s *Server) doBackup() {
675 logger := s.logger.With("name", "doBackup")
676
677 if s.dbType == "postgres" || s.dbType == "turso" {
678 logger.Info("skipping S3 backup - PostgreSQL or Turso backups should be handled externally (pg_dump, managed database backups, etc.)")
679 return
680 }
681
682 start := time.Now()
683
684 logger.Info("beginning backup to s3...")
685
686 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano))
687 defer os.Remove(tmpFile)
688
689 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil {
690 logger.Error("error creating tmp backup file", "err", err)
691 return
692 }
693
694 backupData, err := os.ReadFile(tmpFile)
695 if err != nil {
696 logger.Error("error reading tmp backup file", "err", err)
697 return
698 }
699
700 logger.Info("sending to s3...")
701
702 currTime := time.Now().Format("2006-01-02_15-04-05")
703 key := "cocoon-backup-" + currTime + ".db"
704
705 config := &aws.Config{
706 Region: aws.String(s.s3Config.Region),
707 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
708 }
709
710 if s.s3Config.Endpoint != "" {
711 config.Endpoint = aws.String(s.s3Config.Endpoint)
712 config.S3ForcePathStyle = aws.Bool(true)
713 }
714
715 sess, err := session.NewSession(config)
716 if err != nil {
717 logger.Error("error creating s3 session", "err", err)
718 return
719 }
720
721 svc := s3.New(sess)
722
723 if _, err := svc.PutObject(&s3.PutObjectInput{
724 Bucket: aws.String(s.s3Config.Bucket),
725 Key: aws.String(key),
726 Body: bytes.NewReader(backupData),
727 }); err != nil {
728 logger.Error("error uploading file to s3", "err", err)
729 return
730 }
731
732 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds())
733
734 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644)
735}
736
737func (s *Server) backupRoutine() {
738 logger := s.logger.With("name", "backupRoutine")
739
740 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
741 return
742 }
743
744 if s.s3Config.Region == "" {
745 logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
746 return
747 }
748
749 if s.s3Config.Bucket == "" {
750 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
751 return
752 }
753
754 if s.s3Config.AccessKey == "" {
755 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
756 return
757 }
758
759 if s.s3Config.SecretKey == "" {
760 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
761 return
762 }
763
764 shouldBackupNow := false
765 lastBackupStr, err := os.ReadFile("last-backup.txt")
766 if err != nil {
767 shouldBackupNow = true
768 } else {
769 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr))
770 if err != nil {
771 shouldBackupNow = true
772 } else if time.Since(lastBackup).Seconds() > 3600 {
773 shouldBackupNow = true
774 }
775 }
776
777 if shouldBackupNow {
778 go s.doBackup()
779 }
780
781 ticker := time.NewTicker(time.Hour)
782 for range ticker.C {
783 go s.doBackup()
784 }
785}
786
787func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
788 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
789 return err
790 }
791
792 return nil
793}