Live video on the AT Protocol
at eli/postgres 163 lines 4.0 kB view raw
1package statedb 2 3import ( 4 "context" 5 "fmt" 6 "net/url" 7 "os" 8 "strings" 9 "time" 10 11 "github.com/lmittmann/tint" 12 slogGorm "github.com/orandin/slog-gorm" 13 "github.com/streamplace/oatproxy/pkg/oatproxy" 14 "gorm.io/driver/postgres" 15 "gorm.io/driver/sqlite" 16 "gorm.io/gorm" 17 "stream.place/streamplace/pkg/config" 18 "stream.place/streamplace/pkg/log" 19 "stream.place/streamplace/pkg/model" 20 notificationpkg "stream.place/streamplace/pkg/notifications" 21) 22 23type DBType string 24 25const ( 26 DBTypeSQLite DBType = "sqlite" 27 DBTypePostgres DBType = "postgres" 28) 29 30type StatefulDB struct { 31 DB *gorm.DB 32 CLI *config.CLI 33 Type DBType 34 locks *NamedLocks 35 noter notificationpkg.FirebaseNotifier 36 model model.Model 37 // pokeQueue is used to wake up the queue processor when a new task is enqueued 38 pokeQueue chan struct{} 39} 40 41// list tables here so we can migrate them 42var StatefulDBModels = []any{ 43 oatproxy.OAuthSession{}, 44 Notification{}, 45 Config{}, 46 XrpcStreamEvent{}, 47 AppTask{}, 48} 49 50var NoPostgresDatabaseCode = "3D000" 51 52// Stateful database for storing private streamplace state 53func MakeDB(cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) { 54 dbURL := cli.DBURL 55 log.Log(context.Background(), "starting stateful database", "dbURL", redactDBURL(dbURL)) 56 var dial gorm.Dialector 57 var dbType DBType 58 if dbURL == ":memory:" { 59 dial = sqlite.Open(":memory:") 60 dbType = DBTypeSQLite 61 } else if strings.HasPrefix(dbURL, "sqlite://") { 62 dial = sqlite.Open(dbURL[len("sqlite://"):]) 63 dbType = DBTypeSQLite 64 } else if strings.HasPrefix(dbURL, "postgres://") { 65 dial = postgres.Open(dbURL) 66 dbType = DBTypePostgres 67 } else { 68 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgres://): %s", redactDBURL(dbURL)) 69 } 70 71 db, err := openDB(dial) 72 73 if err != nil { 74 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) { 75 db, err = makePostgresDB(dbURL) 76 if err != nil { 77 return nil, fmt.Errorf("error creating streamplace database: %w", err) 78 } 79 } else { 80 return nil, fmt.Errorf("error starting database: %w", err) 81 } 82 } 83 if dbType == DBTypeSQLite { 84 err = db.Exec("PRAGMA journal_mode=WAL;").Error 85 if err != nil { 86 return nil, fmt.Errorf("error setting journal mode: %w", err) 87 } 88 sqlDB, err := db.DB() 89 if err != nil { 90 return nil, fmt.Errorf("error getting database: %w", err) 91 } 92 sqlDB.SetMaxOpenConns(1) 93 } 94 for _, model := range StatefulDBModels { 95 err = db.AutoMigrate(model) 96 if err != nil { 97 return nil, err 98 } 99 } 100 return &StatefulDB{ 101 DB: db, 102 CLI: cli, 103 Type: dbType, 104 locks: NewNamedLocks(), 105 model: model, 106 pokeQueue: make(chan struct{}, 1), 107 }, nil 108} 109 110func openDB(dial gorm.Dialector) (*gorm.DB, error) { 111 gormLogger := slogGorm.New( 112 slogGorm.WithHandler(tint.NewHandler(os.Stderr, &tint.Options{ 113 TimeFormat: time.RFC3339, 114 })), 115 // slogGorm.WithTraceAll(), 116 ) 117 118 return gorm.Open(dial, &gorm.Config{ 119 SkipDefaultTransaction: true, 120 TranslateError: true, 121 Logger: gormLogger, 122 }) 123} 124 125// helper function for creating the requested postgres database 126func makePostgresDB(dbURL string) (*gorm.DB, error) { 127 u, err := url.Parse(dbURL) 128 if err != nil { 129 return nil, err 130 } 131 dbName := strings.TrimPrefix(u.Path, "/") 132 u.Path = "/postgres" 133 134 rootDial := postgres.Open(u.String()) 135 136 db, err := openDB(rootDial) 137 if err != nil { 138 return nil, err 139 } 140 141 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself. 142 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error 143 if err != nil { 144 return nil, err 145 } 146 147 log.Warn(context.Background(), "created postgres database", "dbName", dbName) 148 149 realDial := postgres.Open(dbURL) 150 151 return openDB(realDial) 152} 153 154func redactDBURL(dbURL string) string { 155 u, err := url.Parse(dbURL) 156 if err != nil { 157 return "db url is malformed" 158 } 159 if u.User != nil { 160 u.User = url.UserPassword(u.User.Username(), "redacted") 161 } 162 return u.String() 163}