Forked from hailey.at/cocoon —
an atproto PDS written in Go.
1package server
2
3import (
4 "bytes"
5 "context"
6 "crypto/ecdsa"
7 "database/sql"
8 "embed"
9 "errors"
10 "fmt"
11 "io"
12 "log/slog"
13 "net/http"
14 "net/smtp"
15 "os"
16 "path/filepath"
17 "sync"
18 "text/template"
19 "time"
20
21 "github.com/aws/aws-sdk-go/aws"
22 "github.com/aws/aws-sdk-go/aws/credentials"
23 "github.com/aws/aws-sdk-go/aws/session"
24 "github.com/aws/aws-sdk-go/service/s3"
25 "github.com/bluesky-social/indigo/api/atproto"
26 "github.com/bluesky-social/indigo/atproto/syntax"
27 "github.com/bluesky-social/indigo/events"
28 "github.com/bluesky-social/indigo/util"
29 "github.com/bluesky-social/indigo/xrpc"
30 "github.com/domodwyer/mailyak/v3"
31 "github.com/go-playground/validator"
32 "github.com/gorilla/sessions"
33 "github.com/haileyok/cocoon/identity"
34 "github.com/haileyok/cocoon/internal/db"
35 "github.com/haileyok/cocoon/internal/helpers"
36 "github.com/haileyok/cocoon/models"
37 "github.com/haileyok/cocoon/oauth/client"
38 "github.com/haileyok/cocoon/oauth/constants"
39 "github.com/haileyok/cocoon/oauth/dpop"
40 "github.com/haileyok/cocoon/oauth/provider"
41 "github.com/haileyok/cocoon/plc"
42 "github.com/ipfs/go-cid"
43 "github.com/labstack/echo-contrib/echoprometheus"
44 echo_session "github.com/labstack/echo-contrib/session"
45 "github.com/labstack/echo/v4"
46 "github.com/labstack/echo/v4/middleware"
47 slogecho "github.com/samber/slog-echo"
48 _ "github.com/tursodatabase/libsql-client-go/libsql"
49 "gorm.io/driver/postgres"
50 "gorm.io/driver/sqlite"
51 "gorm.io/gorm"
52)
53
const (
	// AccountSessionMaxAge is the maximum lifetime of an account session:
	// 30 days (30 * 24 hours).
	AccountSessionMaxAge = 30 * 24 * time.Hour
)
57
// S3Config describes an S3-compatible object store.
type S3Config struct {
	// BackupsEnabled turns on the periodic SQLite backup upload.
	BackupsEnabled bool
	// BlobstoreEnabled is not read in this file; presumably it switches blob
	// storage to S3 — confirm against the blob handlers.
	BlobstoreEnabled bool

	// Endpoint optionally overrides the AWS endpoint for S3-compatible
	// services; when it is set, backups force path-style addressing.
	Endpoint string
	// Region, Bucket and the static credential pair must all be set for
	// backups to run.
	Region, Bucket       string
	AccessKey, SecretKey string
	// CDNUrl is not read in this file; presumably a public base URL for
	// stored objects — confirm against its users.
	CDNUrl string
}
68
69type Server struct {
70 http *http.Client
71 httpd *http.Server
72 mail *mailyak.MailYak
73 mailLk *sync.Mutex
74 echo *echo.Echo
75 db *db.DB
76 plcClient *plc.Client
77 logger *slog.Logger
78 config *config
79 privateKey *ecdsa.PrivateKey
80 repoman *RepoMan
81 oauthProvider *provider.Provider
82 evtman *events.EventManager
83 passport *identity.Passport
84 fallbackProxy string
85
86 lastRequestCrawl time.Time
87 requestCrawlMu sync.Mutex
88
89 dbName string
90 dbType string
91 s3Config *S3Config
92}
93
94type Args struct {
95 Logger *slog.Logger
96
97 LogLevel slog.Level
98 Addr string
99 DbName string
100 DbType string
101 DatabaseURL string
102 TursoToken string
103 Version string
104 Did string
105 Hostname string
106 RotationKeyPath string
107 JwkPath string
108 ContactEmail string
109 Relays []string
110 AdminPassword string
111 RequireInvite bool
112
113 SmtpUser string
114 SmtpPass string
115 SmtpHost string
116 SmtpPort string
117 SmtpEmail string
118 SmtpName string
119
120 S3Config *S3Config
121
122 SessionSecret string
123 SessionCookieKey string
124
125 BlockstoreVariant BlockstoreVariant
126 FallbackProxy string
127
128 PushBasedEvents bool
129 SubscribeReposServiceURL string
130}
131
132type config struct {
133 LogLevel slog.Level
134 Version string
135 Did string
136 Hostname string
137 ContactEmail string
138 EnforcePeering bool
139 Relays []string
140 AdminPassword string
141 RequireInvite bool
142 SmtpEmail string
143 SmtpName string
144 SessionCookieKey string
145 BlockstoreVariant BlockstoreVariant
146 FallbackProxy string
147
148 PushBasedEvents bool
149 SubscribeReposServiceURL string
150}
151
152type CustomValidator struct {
153 validator *validator.Validate
154}
155
156type ValidationError struct {
157 error
158 Field string
159 Tag string
160}
161
162func (cv *CustomValidator) Validate(i any) error {
163 if err := cv.validator.Struct(i); err != nil {
164 var validateErrors validator.ValidationErrors
165 if errors.As(err, &validateErrors) && len(validateErrors) > 0 {
166 first := validateErrors[0]
167 return ValidationError{
168 error: err,
169 Field: first.Field(),
170 Tag: first.Tag(),
171 }
172 }
173
174 return err
175 }
176
177 return nil
178}
179
// templateFS holds the HTML templates compiled into the binary.
// loadTemplates parses them whenever the version is not "dev".
//
//go:embed templates/*
var templateFS embed.FS

// staticFS holds the static assets compiled into the binary.
// addRoutes serves them under /static whenever the version is not "dev".
//
//go:embed static/*
var staticFS embed.FS
185
// TemplateRenderer implements echo's Renderer over a parsed template set.
//
// When isDev is true, Render re-parses the templates from templatePath
// (a glob) on each call, so on-disk edits show up without a restart.
//
// NOTE(review): templates come from text/template, which performs no HTML
// escaping. Any user- or client-supplied value rendered into these pages is
// emitted verbatim. Consider moving to html/template.
type TemplateRenderer struct {
	templates    *template.Template
	isDev        bool
	templatePath string
}
191
192func (s *Server) loadTemplates() {
193 absPath, _ := filepath.Abs("server/templates/*.html")
194 if s.config.Version == "dev" {
195 tmpl := template.Must(template.ParseGlob(absPath))
196 s.echo.Renderer = &TemplateRenderer{
197 templates: tmpl,
198 isDev: true,
199 templatePath: absPath,
200 }
201 } else {
202 tmpl := template.Must(template.ParseFS(templateFS, "templates/*.html"))
203 s.echo.Renderer = &TemplateRenderer{
204 templates: tmpl,
205 isDev: false,
206 }
207 }
208}
209
210func (t *TemplateRenderer) Render(w io.Writer, name string, data any, c echo.Context) error {
211 if t.isDev {
212 tmpl, err := template.ParseGlob(t.templatePath)
213 if err != nil {
214 return err
215 }
216 t.templates = tmpl
217 }
218
219 if viewContext, isMap := data.(map[string]any); isMap {
220 viewContext["reverse"] = c.Echo().Reverse
221 }
222
223 return t.templates.ExecuteTemplate(w, name, data)
224}
225
// filteredHandler is a slog.Handler decorator that drops records below a
// minimum level before they reach the wrapped handler. It can only raise the
// effective threshold: a record must pass both this level and the wrapped
// handler's own Enabled check.
type filteredHandler struct {
	level   slog.Level
	handler slog.Handler
}

// Enabled reports whether a record at lvl should be handled.
func (h *filteredHandler) Enabled(ctx context.Context, lvl slog.Level) bool {
	if lvl < h.level {
		return false
	}
	return h.handler.Enabled(ctx, lvl)
}

// Handle forwards rec unchanged. Level filtering lives in Enabled, which
// slog.Logger consults before calling Handle.
func (h *filteredHandler) Handle(ctx context.Context, rec slog.Record) error {
	return h.handler.Handle(ctx, rec)
}

// WithAttrs returns a filtered handler, with the same threshold, around the
// wrapped handler extended by attrs.
func (h *filteredHandler) WithAttrs(attrs []slog.Attr) slog.Handler {
	next := h.handler.WithAttrs(attrs)
	return &filteredHandler{handler: next, level: h.level}
}

// WithGroup returns a filtered handler, with the same threshold, around the
// wrapped handler opened on group name.
func (h *filteredHandler) WithGroup(name string) slog.Handler {
	next := h.handler.WithGroup(name)
	return &filteredHandler{handler: next, level: h.level}
}
246
247func New(args *Args) (*Server, error) {
248 if args.Logger == nil {
249 args.Logger = slog.Default()
250 }
251
252 if args.LogLevel != 0 {
253 args.Logger = slog.New(&filteredHandler{
254 level: args.LogLevel,
255 handler: args.Logger.Handler(),
256 })
257 }
258
259 logger := args.Logger.With("name", "New")
260
261 if args.Addr == "" {
262 return nil, fmt.Errorf("addr must be set")
263 }
264
265 if args.DbName == "" {
266 return nil, fmt.Errorf("db name must be set")
267 }
268
269 if args.Did == "" {
270 return nil, fmt.Errorf("cocoon did must be set")
271 }
272
273 if args.ContactEmail == "" {
274 return nil, fmt.Errorf("cocoon contact email is required")
275 }
276
277 if _, err := syntax.ParseDID(args.Did); err != nil {
278 return nil, fmt.Errorf("error parsing cocoon did: %w", err)
279 }
280
281 if args.Hostname == "" {
282 return nil, fmt.Errorf("cocoon hostname must be set")
283 }
284
285 if args.AdminPassword == "" {
286 return nil, fmt.Errorf("admin password must be set")
287 }
288
289 if args.SessionSecret == "" {
290 panic("SESSION SECRET WAS NOT SET. THIS IS REQUIRED. ")
291 }
292
293 e := echo.New()
294
295 e.Pre(middleware.RemoveTrailingSlash())
296 e.Pre(slogecho.New(args.Logger.With("component", "slogecho")))
297 e.Use(echo_session.Middleware(sessions.NewCookieStore([]byte(args.SessionSecret))))
298 e.Use(echoprometheus.NewMiddleware("cocoon"))
299 e.Use(middleware.CORSWithConfig(middleware.CORSConfig{
300 AllowOrigins: []string{"*"},
301 AllowHeaders: []string{"*"},
302 AllowMethods: []string{"*"},
303 AllowCredentials: true,
304 MaxAge: 100_000_000,
305 }))
306
307 vdtor := validator.New()
308 vdtor.RegisterValidation("atproto-handle", func(fl validator.FieldLevel) bool {
309 if _, err := syntax.ParseHandle(fl.Field().String()); err != nil {
310 return false
311 }
312 return true
313 })
314 vdtor.RegisterValidation("atproto-did", func(fl validator.FieldLevel) bool {
315 if _, err := syntax.ParseDID(fl.Field().String()); err != nil {
316 return false
317 }
318 return true
319 })
320 vdtor.RegisterValidation("atproto-rkey", func(fl validator.FieldLevel) bool {
321 if _, err := syntax.ParseRecordKey(fl.Field().String()); err != nil {
322 return false
323 }
324 return true
325 })
326 vdtor.RegisterValidation("atproto-nsid", func(fl validator.FieldLevel) bool {
327 if _, err := syntax.ParseNSID(fl.Field().String()); err != nil {
328 return false
329 }
330 return true
331 })
332
333 e.Validator = &CustomValidator{validator: vdtor}
334
335 httpd := &http.Server{
336 Addr: args.Addr,
337 Handler: e,
338 // shitty defaults but okay for now, needed for import repo
339 ReadTimeout: 5 * time.Minute,
340 WriteTimeout: 5 * time.Minute,
341 IdleTimeout: 5 * time.Minute,
342 }
343
344 dbType := args.DbType
345 if dbType == "" {
346 dbType = "sqlite"
347 }
348
349 var gdb *gorm.DB
350 var err error
351 switch dbType {
352 case "postgres":
353 if args.DatabaseURL == "" {
354 return nil, fmt.Errorf("database-url must be set when using postgres")
355 }
356 gdb, err = gorm.Open(postgres.Open(args.DatabaseURL), &gorm.Config{})
357 if err != nil {
358 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
359 }
360 logger.Info("connected to PostgreSQL database")
361 case "turso":
362 primaryUrl := args.DatabaseURL
363 authToken := args.TursoToken
364
365 db, err := sql.Open("libsql", fmt.Sprintf("%s?authToken=%s", primaryUrl, authToken))
366 gdb, err = gorm.Open(sqlite.New(sqlite.Config{
367 Conn: db,
368 }), &gorm.Config{})
369 if err != nil {
370 return nil, fmt.Errorf("failed to connect to postgres: %w", err)
371 }
372 logger.Info("connected to PostgreSQL database")
373
374 default:
375 gdb, err = gorm.Open(sqlite.Open(args.DbName), &gorm.Config{})
376 if err != nil {
377 return nil, fmt.Errorf("failed to open sqlite database: %w", err)
378 }
379 gdb.Exec("PRAGMA journal_mode=WAL")
380 gdb.Exec("PRAGMA synchronous=NORMAL")
381
382 logger.Info("connected to SQLite database", "path", args.DbName)
383 }
384 dbw := db.NewDB(gdb)
385
386 rkbytes, err := os.ReadFile(args.RotationKeyPath)
387 if err != nil {
388 return nil, err
389 }
390
391 h := util.RobustHTTPClient()
392
393 plcClient, err := plc.NewClient(&plc.ClientArgs{
394 H: h,
395 Service: "https://plc.directory",
396 PdsHostname: args.Hostname,
397 RotationKey: rkbytes,
398 })
399 if err != nil {
400 return nil, err
401 }
402
403 jwkbytes, err := os.ReadFile(args.JwkPath)
404 if err != nil {
405 return nil, err
406 }
407
408 key, err := helpers.ParseJWKFromBytes(jwkbytes)
409 if err != nil {
410 return nil, err
411 }
412
413 var pkey ecdsa.PrivateKey
414 if err := key.Raw(&pkey); err != nil {
415 return nil, err
416 }
417
418 oauthCli := &http.Client{
419 Timeout: 10 * time.Second,
420 }
421
422 var nonceSecret []byte
423 maybeSecret, err := os.ReadFile("nonce.secret")
424 if err != nil && !os.IsNotExist(err) {
425 logger.Error("error attempting to read nonce secret", "error", err)
426 } else {
427 nonceSecret = maybeSecret
428 }
429
430 evtPersister, err := NewDbPersister(gdb, 72*time.Hour)
431 if err != nil {
432 return nil, fmt.Errorf("failed to create event persister: %w", err)
433 }
434
435 s := &Server{
436 http: h,
437 httpd: httpd,
438 echo: e,
439 logger: args.Logger,
440 db: dbw,
441 plcClient: plcClient,
442 privateKey: &pkey,
443 config: &config{
444 LogLevel: args.LogLevel,
445 Version: args.Version,
446 Did: args.Did,
447 Hostname: args.Hostname,
448 ContactEmail: args.ContactEmail,
449 EnforcePeering: false,
450 Relays: args.Relays,
451 AdminPassword: args.AdminPassword,
452 RequireInvite: args.RequireInvite,
453 SmtpName: args.SmtpName,
454 SmtpEmail: args.SmtpEmail,
455 SessionCookieKey: args.SessionCookieKey,
456 BlockstoreVariant: args.BlockstoreVariant,
457 FallbackProxy: args.FallbackProxy,
458 PushBasedEvents: args.PushBasedEvents,
459 SubscribeReposServiceURL: args.SubscribeReposServiceURL,
460 },
461 evtman: events.NewEventManager(evtPersister),
462 passport: identity.NewPassport(h, identity.NewMemCache(10_000)),
463
464 dbName: args.DbName,
465 dbType: dbType,
466 s3Config: args.S3Config,
467
468 oauthProvider: provider.NewProvider(provider.Args{
469 Hostname: args.Hostname,
470 ClientManagerArgs: client.ManagerArgs{
471 Cli: oauthCli,
472 Logger: args.Logger.With("component", "oauth-client-manager"),
473 },
474 DpopManagerArgs: dpop.ManagerArgs{
475 NonceSecret: nonceSecret,
476 NonceRotationInterval: constants.NonceMaxRotationInterval / 3,
477 OnNonceSecretCreated: func(newNonce []byte) {
478 if err := os.WriteFile("nonce.secret", newNonce, 0644); err != nil {
479 logger.Error("error writing new nonce secret", "error", err)
480 }
481 },
482 Logger: args.Logger.With("component", "dpop-manager"),
483 Hostname: args.Hostname,
484 },
485 }),
486 }
487
488 s.loadTemplates()
489
490 s.repoman = NewRepoMan(s) // TODO: this is way too lazy, stop it
491
492 // TODO: should validate these args
493 if args.SmtpUser == "" || args.SmtpPass == "" || args.SmtpHost == "" || args.SmtpPort == "" || args.SmtpEmail == "" || args.SmtpName == "" {
494 args.Logger.Warn("not enough smtp args were provided. mailing will not work for your server.")
495 } else {
496 mail := mailyak.New(args.SmtpHost+":"+args.SmtpPort, smtp.PlainAuth("", args.SmtpUser, args.SmtpPass, args.SmtpHost))
497 mail.From(s.config.SmtpEmail)
498 mail.FromName(s.config.SmtpName)
499
500 s.mail = mail
501 s.mailLk = &sync.Mutex{}
502 }
503
504 return s, nil
505}
506
507func (s *Server) addRoutes() {
508 // static
509 if s.config.Version == "dev" {
510 s.echo.Static("/static", "server/static")
511 } else {
512 s.echo.GET("/static/*", echo.WrapHandler(http.FileServer(http.FS(staticFS))))
513 }
514
515 // random stuff
516 s.echo.GET("/", s.handleRoot)
517 s.echo.GET("/xrpc/_health", s.handleHealth)
518 s.echo.GET("/.well-known/did.json", s.handleWellKnown)
519 s.echo.GET("/.well-known/atproto-did", s.handleAtprotoDid)
520 s.echo.GET("/.well-known/oauth-protected-resource", s.handleOauthProtectedResource)
521 s.echo.GET("/.well-known/oauth-authorization-server", s.handleOauthAuthorizationServer)
522 s.echo.GET("/robots.txt", s.handleRobots)
523
524 // public
525 s.echo.GET("/xrpc/com.atproto.identity.resolveHandle", s.handleResolveHandle)
526 s.echo.POST("/xrpc/com.atproto.server.createAccount", s.handleCreateAccount)
527 s.echo.POST("/xrpc/com.atproto.server.createSession", s.handleCreateSession)
528 s.echo.GET("/xrpc/com.atproto.server.describeServer", s.handleDescribeServer)
529 s.echo.POST("/xrpc/com.atproto.server.reserveSigningKey", s.handleServerReserveSigningKey)
530
531 s.echo.GET("/xrpc/com.atproto.repo.describeRepo", s.handleDescribeRepo)
532 s.echo.GET("/xrpc/com.atproto.sync.listRepos", s.handleListRepos)
533 s.echo.GET("/xrpc/com.atproto.repo.listRecords", s.handleListRecords)
534 s.echo.GET("/xrpc/com.atproto.repo.getRecord", s.handleRepoGetRecord)
535 s.echo.GET("/xrpc/com.atproto.sync.getRecord", s.handleSyncGetRecord)
536 s.echo.GET("/xrpc/com.atproto.sync.getBlocks", s.handleGetBlocks)
537 s.echo.GET("/xrpc/com.atproto.sync.getLatestCommit", s.handleSyncGetLatestCommit)
538 s.echo.GET("/xrpc/com.atproto.sync.getRepoStatus", s.handleSyncGetRepoStatus)
539 s.echo.GET("/xrpc/com.atproto.sync.getRepo", s.handleSyncGetRepo)
540 s.echo.GET("/xrpc/com.atproto.sync.subscribeRepos", s.handleSyncSubscribeRepos)
541 s.echo.GET("/xrpc/com.atproto.sync.listBlobs", s.handleSyncListBlobs)
542 s.echo.GET("/xrpc/com.atproto.sync.getBlob", s.handleSyncGetBlob)
543
544 // labels
545 s.echo.GET("/xrpc/com.atproto.label.queryLabels", s.handleLabelQueryLabels)
546
547 // account
548 s.echo.GET("/account", s.handleAccount)
549 s.echo.POST("/account/revoke", s.handleAccountRevoke)
550 s.echo.GET("/account/signin", s.handleAccountSigninGet)
551 s.echo.POST("/account/signin", s.handleAccountSigninPost)
552 s.echo.GET("/account/signout", s.handleAccountSignout)
553
554 // oauth account
555 s.echo.GET("/oauth/jwks", s.handleOauthJwks)
556 s.echo.GET("/oauth/authorize", s.handleOauthAuthorizeGet)
557 s.echo.POST("/oauth/authorize", s.handleOauthAuthorizePost)
558
559 // oauth authorization
560 s.echo.POST("/oauth/par", s.handleOauthPar, s.oauthProvider.BaseMiddleware)
561 s.echo.POST("/oauth/token", s.handleOauthToken, s.oauthProvider.BaseMiddleware)
562
563 // authed
564 s.echo.GET("/xrpc/com.atproto.server.getSession", s.handleGetSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
565 s.echo.POST("/xrpc/com.atproto.server.refreshSession", s.handleRefreshSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
566 s.echo.POST("/xrpc/com.atproto.server.deleteSession", s.handleDeleteSession, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
567 s.echo.GET("/xrpc/com.atproto.identity.getRecommendedDidCredentials", s.handleGetRecommendedDidCredentials, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
568 s.echo.POST("/xrpc/com.atproto.identity.updateHandle", s.handleIdentityUpdateHandle, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
569 s.echo.POST("/xrpc/com.atproto.identity.requestPlcOperationSignature", s.handleIdentityRequestPlcOperationSignature, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
570 s.echo.POST("/xrpc/com.atproto.identity.signPlcOperation", s.handleSignPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
571 s.echo.POST("/xrpc/com.atproto.identity.submitPlcOperation", s.handleSubmitPlcOperation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
572 s.echo.POST("/xrpc/com.atproto.server.confirmEmail", s.handleServerConfirmEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
573 s.echo.POST("/xrpc/com.atproto.server.requestEmailConfirmation", s.handleServerRequestEmailConfirmation, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
574 s.echo.POST("/xrpc/com.atproto.server.requestPasswordReset", s.handleServerRequestPasswordReset) // AUTH NOT REQUIRED FOR THIS ONE
575 s.echo.POST("/xrpc/com.atproto.server.requestEmailUpdate", s.handleServerRequestEmailUpdate, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
576 s.echo.POST("/xrpc/com.atproto.server.resetPassword", s.handleServerResetPassword, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
577 s.echo.POST("/xrpc/com.atproto.server.updateEmail", s.handleServerUpdateEmail, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
578 s.echo.GET("/xrpc/com.atproto.server.getServiceAuth", s.handleServerGetServiceAuth, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
579 s.echo.GET("/xrpc/com.atproto.server.checkAccountStatus", s.handleServerCheckAccountStatus, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
580 s.echo.POST("/xrpc/com.atproto.server.deactivateAccount", s.handleServerDeactivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
581 s.echo.POST("/xrpc/com.atproto.server.activateAccount", s.handleServerActivateAccount, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
582 s.echo.POST("/xrpc/com.atproto.server.requestAccountDelete", s.handleServerRequestAccountDelete, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
583 s.echo.POST("/xrpc/com.atproto.server.deleteAccount", s.handleServerDeleteAccount)
584
585 // repo
586 s.echo.GET("/xrpc/com.atproto.repo.listMissingBlobs", s.handleListMissingBlobs, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
587 s.echo.POST("/xrpc/com.atproto.repo.createRecord", s.handleCreateRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
588 s.echo.POST("/xrpc/com.atproto.repo.putRecord", s.handlePutRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
589 s.echo.POST("/xrpc/com.atproto.repo.deleteRecord", s.handleDeleteRecord, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
590 s.echo.POST("/xrpc/com.atproto.repo.applyWrites", s.handleApplyWrites, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
591 s.echo.POST("/xrpc/com.atproto.repo.uploadBlob", s.handleRepoUploadBlob, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
592 s.echo.POST("/xrpc/com.atproto.repo.importRepo", s.handleRepoImportRepo, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
593
594 // stupid silly endpoints
595 s.echo.GET("/xrpc/app.bsky.actor.getPreferences", s.handleActorGetPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
596 s.echo.POST("/xrpc/app.bsky.actor.putPreferences", s.handleActorPutPreferences, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
597 s.echo.GET("/xrpc/app.bsky.feed.getFeed", s.handleProxyBskyFeedGetFeed, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
598 s.echo.GET("/xrpc/app.bsky.ageassurance.getState", s.handleAgeAssurance, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
599 // admin routes
600 s.echo.POST("/xrpc/com.atproto.server.createInviteCode", s.handleCreateInviteCode, s.handleAdminMiddleware)
601 s.echo.POST("/xrpc/com.atproto.server.createInviteCodes", s.handleCreateInviteCodes, s.handleAdminMiddleware)
602
603 // are there any routes that we should be allowing without auth? i dont think so but idk
604 s.echo.GET("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
605 s.echo.POST("/xrpc/*", s.handleProxy, s.handleLegacySessionMiddleware, s.handleOauthSessionMiddleware)
606}
607
608func (s *Server) Serve(ctx context.Context) error {
609 logger := s.logger.With("name", "Serve")
610
611 s.addRoutes()
612
613 logger.Info("migrating...")
614
615 s.db.AutoMigrate(
616 &models.Actor{},
617 &models.Repo{},
618 &models.InviteCode{},
619 &models.Token{},
620 &models.RefreshToken{},
621 &models.Block{},
622 &models.Record{},
623 &models.Blob{},
624 &models.BlobPart{},
625 &models.ReservedKey{},
626 &provider.OauthToken{},
627 &provider.OauthAuthorizationRequest{},
628 )
629
630 logger.Info("starting cocoon")
631
632 go func() {
633 if err := s.httpd.ListenAndServe(); err != nil {
634 panic(err)
635 }
636 }()
637
638 go s.backupRoutine()
639
640 go func() {
641 if err := s.requestCrawl(ctx); err != nil {
642 logger.Error("error requesting crawls", "err", err)
643 }
644 }()
645
646 if s.config.PushBasedEvents {
647 slog.Info("pushed based events enabled")
648 go func() {
649 if err := s.emmitEvents(ctx); err != nil {
650 logger.Error("error emitting events", "err", err)
651 }
652 }()
653 }
654
655 <-ctx.Done()
656
657 fmt.Println("shut down")
658
659 return nil
660}
661
662func (s *Server) requestCrawl(ctx context.Context) error {
663 logger := s.logger.With("component", "request-crawl")
664 s.requestCrawlMu.Lock()
665 defer s.requestCrawlMu.Unlock()
666
667 logger.Info("requesting crawl with configured relays")
668
669 if time.Since(s.lastRequestCrawl) <= 1*time.Minute {
670 return fmt.Errorf("a crawl request has already been made within the last minute")
671 }
672
673 for _, relay := range s.config.Relays {
674 logger := logger.With("relay", relay)
675 logger.Info("requesting crawl from relay")
676 cli := xrpc.Client{Host: relay}
677 if err := atproto.SyncRequestCrawl(ctx, &cli, &atproto.SyncRequestCrawl_Input{
678 Hostname: s.config.Hostname,
679 }); err != nil {
680 logger.Error("error requesting crawl", "err", err)
681 } else {
682 logger.Info("crawl requested successfully")
683 }
684 }
685
686 s.lastRequestCrawl = time.Now()
687
688 return nil
689}
690
691func (s *Server) doBackup() {
692 logger := s.logger.With("name", "doBackup")
693
694 if s.dbType == "postgres" || s.dbType == "turso" {
695 logger.Info("skipping S3 backup - PostgreSQL or Turso backups should be handled externally (pg_dump, managed database backups, etc.)")
696 return
697 }
698
699 start := time.Now()
700
701 logger.Info("beginning backup to s3...")
702
703 tmpFile := fmt.Sprintf("/tmp/cocoon-backup-%s.db", time.Now().Format(time.RFC3339Nano))
704 defer os.Remove(tmpFile)
705
706 if err := s.db.Client().Exec(fmt.Sprintf("VACUUM INTO '%s'", tmpFile)).Error; err != nil {
707 logger.Error("error creating tmp backup file", "err", err)
708 return
709 }
710
711 backupData, err := os.ReadFile(tmpFile)
712 if err != nil {
713 logger.Error("error reading tmp backup file", "err", err)
714 return
715 }
716
717 logger.Info("sending to s3...")
718
719 currTime := time.Now().Format("2006-01-02_15-04-05")
720 key := "cocoon-backup-" + currTime + ".db"
721
722 config := &aws.Config{
723 Region: aws.String(s.s3Config.Region),
724 Credentials: credentials.NewStaticCredentials(s.s3Config.AccessKey, s.s3Config.SecretKey, ""),
725 }
726
727 if s.s3Config.Endpoint != "" {
728 config.Endpoint = aws.String(s.s3Config.Endpoint)
729 config.S3ForcePathStyle = aws.Bool(true)
730 }
731
732 sess, err := session.NewSession(config)
733 if err != nil {
734 logger.Error("error creating s3 session", "err", err)
735 return
736 }
737
738 svc := s3.New(sess)
739
740 if _, err := svc.PutObject(&s3.PutObjectInput{
741 Bucket: aws.String(s.s3Config.Bucket),
742 Key: aws.String(key),
743 Body: bytes.NewReader(backupData),
744 }); err != nil {
745 logger.Error("error uploading file to s3", "err", err)
746 return
747 }
748
749 logger.Info("finished uploading backup to s3", "key", key, "duration", time.Since(start).Seconds())
750
751 os.WriteFile("last-backup.txt", []byte(time.Now().Format(time.RFC3339Nano)), 0644)
752}
753
754func (s *Server) backupRoutine() {
755 logger := s.logger.With("name", "backupRoutine")
756
757 if s.s3Config == nil || !s.s3Config.BackupsEnabled {
758 return
759 }
760
761 if s.s3Config.Region == "" {
762 logger.Warn("no s3 region configured but backups are enabled. backups will not run.")
763 return
764 }
765
766 if s.s3Config.Bucket == "" {
767 logger.Warn("no s3 bucket configured but backups are enabled. backups will not run.")
768 return
769 }
770
771 if s.s3Config.AccessKey == "" {
772 logger.Warn("no s3 access key configured but backups are enabled. backups will not run.")
773 return
774 }
775
776 if s.s3Config.SecretKey == "" {
777 logger.Warn("no s3 secret key configured but backups are enabled. backups will not run.")
778 return
779 }
780
781 shouldBackupNow := false
782 lastBackupStr, err := os.ReadFile("last-backup.txt")
783 if err != nil {
784 shouldBackupNow = true
785 } else {
786 lastBackup, err := time.Parse(time.RFC3339Nano, string(lastBackupStr))
787 if err != nil {
788 shouldBackupNow = true
789 } else if time.Since(lastBackup).Seconds() > 3600 {
790 shouldBackupNow = true
791 }
792 }
793
794 if shouldBackupNow {
795 go s.doBackup()
796 }
797
798 ticker := time.NewTicker(time.Hour)
799 for range ticker.C {
800 go s.doBackup()
801 }
802}
803
804func (s *Server) UpdateRepo(ctx context.Context, did string, root cid.Cid, rev string) error {
805 if err := s.db.Exec(ctx, "UPDATE repos SET root = ?, rev = ? WHERE did = ?", nil, root.Bytes(), rev, did).Error; err != nil {
806 return err
807 }
808
809 return nil
810}