Live video on the AT Protocol
at eli/buffer-thumbnails 166 lines 4.2 kB view raw
1package statedb 2 3import ( 4 "context" 5 "fmt" 6 "net/url" 7 "strings" 8 "sync" 9 10 "github.com/streamplace/oatproxy/pkg/oatproxy" 11 "gorm.io/driver/postgres" 12 "gorm.io/driver/sqlite" 13 "gorm.io/gorm" 14 "stream.place/streamplace/pkg/config" 15 "stream.place/streamplace/pkg/log" 16 "stream.place/streamplace/pkg/model" 17 notificationpkg "stream.place/streamplace/pkg/notifications" 18) 19 20type DBType string 21 22const ( 23 DBTypeSQLite DBType = "sqlite" 24 DBTypePostgres DBType = "postgres" 25) 26 27type StatefulDB struct { 28 DB *gorm.DB 29 CLI *config.CLI 30 Type DBType 31 locks *NamedLocks 32 noter notificationpkg.FirebaseNotifier 33 model model.Model 34 // pokeQueue is used to wake up the queue processor when a new task is enqueued 35 pokeQueue chan struct{} 36 // pgLockConn is used to hold a connection to the database for locking 37 pgLockConn *gorm.DB 38 pgLockConnMu sync.Mutex 39} 40 41// list tables here so we can migrate them 42var StatefulDBModels = []any{ 43 oatproxy.OAuthSession{}, 44 Notification{}, 45 Config{}, 46 XrpcStreamEvent{}, 47 AppTask{}, 48 Repo{}, 49 Webhook{}, 50} 51 52var NoPostgresDatabaseCode = "3D000" 53 54// Stateful database for storing private streamplace state 55func MakeDB(ctx context.Context, cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) { 56 dbURL := cli.DBURL 57 log.Log(ctx, "starting stateful database", "dbURL", redactDBURL(dbURL)) 58 var dial gorm.Dialector 59 var dbType DBType 60 if dbURL == ":memory:" { 61 dial = sqlite.Open(":memory:") 62 dbType = DBTypeSQLite 63 } else if strings.HasPrefix(dbURL, "sqlite://") { 64 dial = sqlite.Open(dbURL[len("sqlite://"):]) 65 dbType = DBTypeSQLite 66 } else if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") { 67 dial = postgres.Open(dbURL) 68 dbType = DBTypePostgres 69 } else { 70 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgresql://): %s", redactDBURL(dbURL)) 71 } 72 73 db, err := openDB(dial) 74 75 if err != nil { 76 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) { 77 db, err = makePostgresDB(dbURL) 78 if err != nil { 79 return nil, fmt.Errorf("error creating streamplace database: %w", err) 80 } 81 } else { 82 return nil, fmt.Errorf("error starting database: %w", err) 83 } 84 } 85 if dbType == DBTypeSQLite { 86 err = db.Exec("PRAGMA journal_mode=WAL;").Error 87 if err != nil { 88 return nil, fmt.Errorf("error setting journal mode: %w", err) 89 } 90 sqlDB, err := db.DB() 91 if err != nil { 92 return nil, fmt.Errorf("error getting database: %w", err) 93 } 94 sqlDB.SetMaxOpenConns(1) 95 } 96 for _, model := range StatefulDBModels { 97 err = db.AutoMigrate(model) 98 if err != nil { 99 return nil, err 100 } 101 } 102 state := &StatefulDB{ 103 DB: db, 104 CLI: cli, 105 Type: dbType, 106 locks: NewNamedLocks(), 107 model: model, 108 pokeQueue: make(chan struct{}, 1), 109 noter: noter, 110 } 111 if state.Type == DBTypePostgres { 112 err = state.startPostgresLockerConn(ctx) 113 if err != nil { 114 return nil, fmt.Errorf("error starting postgres locker connection: %w", err) 115 } 116 } 117 return state, nil 118} 119 120func openDB(dial gorm.Dialector) (*gorm.DB, error) { 121 return gorm.Open(dial, &gorm.Config{ 122 SkipDefaultTransaction: true, 123 TranslateError: true, 124 Logger: config.GormLogger, 125 }) 126} 127 128// helper function for creating the requested postgres database 129func makePostgresDB(dbURL string) (*gorm.DB, error) { 130 u, err := url.Parse(dbURL) 131 if err != nil { 132 return nil, err 133 } 134 dbName := strings.TrimPrefix(u.Path, "/") 135 u.Path = "/postgres" 136 137 rootDial := postgres.Open(u.String()) 138 139 db, err := openDB(rootDial) 140 if err != nil { 141 return nil, err 142 } 143 144 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself. 145 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error 146 if err != nil { 147 return nil, err 148 } 149 150 log.Warn(context.Background(), "created postgres database", "dbName", dbName) 151 152 realDial := postgres.Open(dbURL) 153 154 return openDB(realDial) 155} 156 157func redactDBURL(dbURL string) string { 158 u, err := url.Parse(dbURL) 159 if err != nil { 160 return "db url is malformed" 161 } 162 if u.User != nil { 163 u.User = url.UserPassword(u.User.Username(), "redacted") 164 } 165 return u.String() 166}