Live video on the AT Protocol
at eli/handle-changes 164 lines 4.2 kB view raw
1package statedb 2 3import ( 4 "context" 5 "fmt" 6 "net/url" 7 "strings" 8 "sync" 9 10 "github.com/streamplace/oatproxy/pkg/oatproxy" 11 "gorm.io/driver/postgres" 12 "gorm.io/driver/sqlite" 13 "gorm.io/gorm" 14 "stream.place/streamplace/pkg/config" 15 "stream.place/streamplace/pkg/log" 16 "stream.place/streamplace/pkg/model" 17 notificationpkg "stream.place/streamplace/pkg/notifications" 18) 19 20type DBType string 21 22const ( 23 DBTypeSQLite DBType = "sqlite" 24 DBTypePostgres DBType = "postgres" 25) 26 27type StatefulDB struct { 28 DB *gorm.DB 29 CLI *config.CLI 30 Type DBType 31 locks *NamedLocks 32 noter notificationpkg.FirebaseNotifier 33 model model.Model 34 // pokeQueue is used to wake up the queue processor when a new task is enqueued 35 pokeQueue chan struct{} 36 // pgLockConn is used to hold a connection to the database for locking 37 pgLockConn *gorm.DB 38 pgLockConnMu sync.Mutex 39} 40 41// list tables here so we can migrate them 42var StatefulDBModels = []any{ 43 oatproxy.OAuthSession{}, 44 Notification{}, 45 Config{}, 46 XrpcStreamEvent{}, 47 AppTask{}, 48 Repo{}, 49} 50 51var NoPostgresDatabaseCode = "3D000" 52 53// Stateful database for storing private streamplace state 54func MakeDB(ctx context.Context, cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) { 55 dbURL := cli.DBURL 56 log.Log(ctx, "starting stateful database", "dbURL", redactDBURL(dbURL)) 57 var dial gorm.Dialector 58 var dbType DBType 59 if dbURL == ":memory:" { 60 dial = sqlite.Open(":memory:") 61 dbType = DBTypeSQLite 62 } else if strings.HasPrefix(dbURL, "sqlite://") { 63 dial = sqlite.Open(dbURL[len("sqlite://"):]) 64 dbType = DBTypeSQLite 65 } else if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") { 66 dial = postgres.Open(dbURL) 67 dbType = DBTypePostgres 68 } else { 69 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgresql://): %s", redactDBURL(dbURL)) 70 } 71 72 db, err := openDB(dial) 73 74 if err != nil { 75 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) { 76 db, err = makePostgresDB(dbURL) 77 if err != nil { 78 return nil, fmt.Errorf("error creating streamplace database: %w", err) 79 } 80 } else { 81 return nil, fmt.Errorf("error starting database: %w", err) 82 } 83 } 84 if dbType == DBTypeSQLite { 85 err = db.Exec("PRAGMA journal_mode=WAL;").Error 86 if err != nil { 87 return nil, fmt.Errorf("error setting journal mode: %w", err) 88 } 89 sqlDB, err := db.DB() 90 if err != nil { 91 return nil, fmt.Errorf("error getting database: %w", err) 92 } 93 sqlDB.SetMaxOpenConns(1) 94 } 95 for _, model := range StatefulDBModels { 96 err = db.AutoMigrate(model) 97 if err != nil { 98 return nil, err 99 } 100 } 101 state := &StatefulDB{ 102 DB: db, 103 CLI: cli, 104 Type: dbType, 105 locks: NewNamedLocks(), 106 model: model, 107 pokeQueue: make(chan struct{}, 1), 108 } 109 if state.Type == DBTypePostgres { 110 err = state.startPostgresLockerConn(ctx) 111 if err != nil { 112 return nil, fmt.Errorf("error starting postgres locker connection: %w", err) 113 } 114 } 115 return state, nil 116} 117 118func openDB(dial gorm.Dialector) (*gorm.DB, error) { 119 return gorm.Open(dial, &gorm.Config{ 120 SkipDefaultTransaction: true, 121 TranslateError: true, 122 Logger: config.GormLogger, 123 }) 124} 125 126// helper function for creating the requested postgres database 127func makePostgresDB(dbURL string) (*gorm.DB, error) { 128 u, err := url.Parse(dbURL) 129 if err != nil { 130 return nil, err 131 } 132 dbName := strings.TrimPrefix(u.Path, "/") 133 u.Path = "/postgres" 134 135 rootDial := postgres.Open(u.String()) 136 137 db, err := openDB(rootDial) 138 if err != nil { 139 return nil, err 140 } 141 142 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself. 143 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error 144 if err != nil { 145 return nil, err 146 } 147 148 log.Warn(context.Background(), "created postgres database", "dbName", dbName) 149 150 realDial := postgres.Open(dbURL) 151 152 return openDB(realDial) 153} 154 155func redactDBURL(dbURL string) string { 156 u, err := url.Parse(dbURL) 157 if err != nil { 158 return "db url is malformed" 159 } 160 if u.User != nil { 161 u.User = url.UserPassword(u.User.Username(), "redacted") 162 } 163 return u.String() 164}