Live video on the AT Protocol
79
fork

Configure Feed

Select the types of activity you want to include in your feed.

at v0.7.20 164 lines 4.0 kB view raw
1package statedb 2 3import ( 4 "context" 5 "fmt" 6 "net/url" 7 "os" 8 "strings" 9 "time" 10 11 "github.com/lmittmann/tint" 12 slogGorm "github.com/orandin/slog-gorm" 13 "github.com/streamplace/oatproxy/pkg/oatproxy" 14 "gorm.io/driver/postgres" 15 "gorm.io/driver/sqlite" 16 "gorm.io/gorm" 17 "stream.place/streamplace/pkg/config" 18 "stream.place/streamplace/pkg/log" 19 "stream.place/streamplace/pkg/model" 20 notificationpkg "stream.place/streamplace/pkg/notifications" 21) 22 23type DBType string 24 25const ( 26 DBTypeSQLite DBType = "sqlite" 27 DBTypePostgres DBType = "postgres" 28) 29 30type StatefulDB struct { 31 DB *gorm.DB 32 CLI *config.CLI 33 Type DBType 34 locks *NamedLocks 35 noter notificationpkg.FirebaseNotifier 36 model model.Model 37 // pokeQueue is used to wake up the queue processor when a new task is enqueued 38 pokeQueue chan struct{} 39} 40 41// list tables here so we can migrate them 42var StatefulDBModels = []any{ 43 oatproxy.OAuthSession{}, 44 Notification{}, 45 Config{}, 46 XrpcStreamEvent{}, 47 AppTask{}, 48 Repo{}, 49} 50 51var NoPostgresDatabaseCode = "3D000" 52 53// Stateful database for storing private streamplace state 54func MakeDB(cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) { 55 dbURL := cli.DBURL 56 log.Log(context.Background(), "starting stateful database", "dbURL", redactDBURL(dbURL)) 57 var dial gorm.Dialector 58 var dbType DBType 59 if dbURL == ":memory:" { 60 dial = sqlite.Open(":memory:") 61 dbType = DBTypeSQLite 62 } else if strings.HasPrefix(dbURL, "sqlite://") { 63 dial = sqlite.Open(dbURL[len("sqlite://"):]) 64 dbType = DBTypeSQLite 65 } else if strings.HasPrefix(dbURL, "postgres://") { 66 dial = postgres.Open(dbURL) 67 dbType = DBTypePostgres 68 } else { 69 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgres://): %s", redactDBURL(dbURL)) 70 } 71 72 db, err := openDB(dial) 73 74 if err != nil { 75 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) { 76 db, err = makePostgresDB(dbURL) 77 if err != nil { 78 return nil, fmt.Errorf("error creating streamplace database: %w", err) 79 } 80 } else { 81 return nil, fmt.Errorf("error starting database: %w", err) 82 } 83 } 84 if dbType == DBTypeSQLite { 85 err = db.Exec("PRAGMA journal_mode=WAL;").Error 86 if err != nil { 87 return nil, fmt.Errorf("error setting journal mode: %w", err) 88 } 89 sqlDB, err := db.DB() 90 if err != nil { 91 return nil, fmt.Errorf("error getting database: %w", err) 92 } 93 sqlDB.SetMaxOpenConns(1) 94 } 95 for _, model := range StatefulDBModels { 96 err = db.AutoMigrate(model) 97 if err != nil { 98 return nil, err 99 } 100 } 101 return &StatefulDB{ 102 DB: db, 103 CLI: cli, 104 Type: dbType, 105 locks: NewNamedLocks(), 106 model: model, 107 pokeQueue: make(chan struct{}, 1), 108 }, nil 109} 110 111func openDB(dial gorm.Dialector) (*gorm.DB, error) { 112 gormLogger := slogGorm.New( 113 slogGorm.WithHandler(tint.NewHandler(os.Stderr, &tint.Options{ 114 TimeFormat: time.RFC3339, 115 })), 116 // slogGorm.WithTraceAll(), 117 ) 118 119 return gorm.Open(dial, &gorm.Config{ 120 SkipDefaultTransaction: true, 121 TranslateError: true, 122 Logger: gormLogger, 123 }) 124} 125 126// helper function for creating the requested postgres database 127func makePostgresDB(dbURL string) (*gorm.DB, error) { 128 u, err := url.Parse(dbURL) 129 if err != nil { 130 return nil, err 131 } 132 dbName := strings.TrimPrefix(u.Path, "/") 133 u.Path = "/postgres" 134 135 rootDial := postgres.Open(u.String()) 136 137 db, err := openDB(rootDial) 138 if err != nil { 139 return nil, err 140 } 141 142 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself. 143 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error 144 if err != nil { 145 return nil, err 146 } 147 148 log.Warn(context.Background(), "created postgres database", "dbName", dbName) 149 150 realDial := postgres.Open(dbURL) 151 152 return openDB(realDial) 153} 154 155func redactDBURL(dbURL string) string { 156 u, err := url.Parse(dbURL) 157 if err != nil { 158 return "db url is malformed" 159 } 160 if u.User != nil { 161 u.User = url.UserPassword(u.User.Username(), "redacted") 162 } 163 return u.String() 164}