Live video on the AT Protocol
at eli/revert-dev-env 164 lines 4.0 kB view raw
1package statedb 2 3import ( 4 "context" 5 "fmt" 6 "net/url" 7 "os" 8 "strings" 9 "time" 10 11 "github.com/lmittmann/tint" 12 slogGorm "github.com/orandin/slog-gorm" 13 "github.com/streamplace/oatproxy/pkg/oatproxy" 14 "gorm.io/driver/postgres" 15 "gorm.io/driver/sqlite" 16 "gorm.io/gorm" 17 "stream.place/streamplace/pkg/config" 18 "stream.place/streamplace/pkg/log" 19 "stream.place/streamplace/pkg/model" 20 notificationpkg "stream.place/streamplace/pkg/notifications" 21) 22 23type DBType string 24 25const ( 26 DBTypeSQLite DBType = "sqlite" 27 DBTypePostgres DBType = "postgres" 28) 29 30type StatefulDB struct { 31 DB *gorm.DB 32 CLI *config.CLI 33 Type DBType 34 locks *NamedLocks 35 noter notificationpkg.FirebaseNotifier 36 model model.Model 37 // pokeQueue is used to wake up the queue processor when a new task is enqueued 38 pokeQueue chan struct{} 39} 40 41// list tables here so we can migrate them 42var StatefulDBModels = []any{ 43 oatproxy.OAuthSession{}, 44 Notification{}, 45 Config{}, 46 XrpcStreamEvent{}, 47 AppTask{}, 48 Repo{}, 49} 50 51var NoPostgresDatabaseCode = "3D000" 52 53// Stateful database for storing private streamplace state 54func MakeDB(cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) { 55 dbURL := cli.DBURL 56 log.Log(context.Background(), "starting stateful database", "dbURL", redactDBURL(dbURL)) 57 var dial gorm.Dialector 58 var dbType DBType 59 if dbURL == ":memory:" { 60 dial = sqlite.Open(":memory:") 61 dbType = DBTypeSQLite 62 } else if strings.HasPrefix(dbURL, "sqlite://") { 63 dial = sqlite.Open(dbURL[len("sqlite://"):]) 64 dbType = DBTypeSQLite 65 } else if strings.HasPrefix(dbURL, "postgres://") { 66 dial = postgres.Open(dbURL) 67 dbType = DBTypePostgres 68 } else { 69 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgres://): %s", redactDBURL(dbURL)) 70 } 71 72 db, err := openDB(dial) 73 74 if err != nil { 75 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) { 76 db, err = makePostgresDB(dbURL) 77 if err != nil { 78 return nil, fmt.Errorf("error creating streamplace database: %w", err) 79 } 80 } else { 81 return nil, fmt.Errorf("error starting database: %w", err) 82 } 83 } 84 if dbType == DBTypeSQLite { 85 err = db.Exec("PRAGMA journal_mode=WAL;").Error 86 if err != nil { 87 return nil, fmt.Errorf("error setting journal mode: %w", err) 88 } 89 sqlDB, err := db.DB() 90 if err != nil { 91 return nil, fmt.Errorf("error getting database: %w", err) 92 } 93 sqlDB.SetMaxOpenConns(1) 94 } 95 for _, model := range StatefulDBModels { 96 err = db.AutoMigrate(model) 97 if err != nil { 98 return nil, err 99 } 100 } 101 return &StatefulDB{ 102 DB: db, 103 CLI: cli, 104 Type: dbType, 105 locks: NewNamedLocks(), 106 model: model, 107 pokeQueue: make(chan struct{}, 1), 108 }, nil 109} 110 111func openDB(dial gorm.Dialector) (*gorm.DB, error) { 112 gormLogger := slogGorm.New( 113 slogGorm.WithHandler(tint.NewHandler(os.Stderr, &tint.Options{ 114 TimeFormat: time.RFC3339, 115 })), 116 // slogGorm.WithTraceAll(), 117 ) 118 119 return gorm.Open(dial, &gorm.Config{ 120 SkipDefaultTransaction: true, 121 TranslateError: true, 122 Logger: gormLogger, 123 }) 124} 125 126// helper function for creating the requested postgres database 127func makePostgresDB(dbURL string) (*gorm.DB, error) { 128 u, err := url.Parse(dbURL) 129 if err != nil { 130 return nil, err 131 } 132 dbName := strings.TrimPrefix(u.Path, "/") 133 u.Path = "/postgres" 134 135 rootDial := postgres.Open(u.String()) 136 137 db, err := openDB(rootDial) 138 if err != nil { 139 return nil, err 140 } 141 142 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself. 143 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error 144 if err != nil { 145 return nil, err 146 } 147 148 log.Warn(context.Background(), "created postgres database", "dbName", dbName) 149 150 realDial := postgres.Open(dbURL) 151 152 return openDB(realDial) 153} 154 155func redactDBURL(dbURL string) string { 156 u, err := url.Parse(dbURL) 157 if err != nil { 158 return "db url is malformed" 159 } 160 if u.User != nil { 161 u.User = url.UserPassword(u.User.Username(), "redacted") 162 } 163 return u.String() 164}