Live video on the AT Protocol
at eli/fix-context-recursion 175 lines 4.4 kB view raw
1package statedb 2 3import ( 4 "context" 5 "fmt" 6 "net/url" 7 "os" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/lmittmann/tint" 13 slogGorm "github.com/orandin/slog-gorm" 14 "github.com/streamplace/oatproxy/pkg/oatproxy" 15 "gorm.io/driver/postgres" 16 "gorm.io/driver/sqlite" 17 "gorm.io/gorm" 18 "stream.place/streamplace/pkg/config" 19 "stream.place/streamplace/pkg/log" 20 "stream.place/streamplace/pkg/model" 21 notificationpkg "stream.place/streamplace/pkg/notifications" 22) 23 24type DBType string 25 26const ( 27 DBTypeSQLite DBType = "sqlite" 28 DBTypePostgres DBType = "postgres" 29) 30 31type StatefulDB struct { 32 DB *gorm.DB 33 CLI *config.CLI 34 Type DBType 35 locks *NamedLocks 36 noter notificationpkg.FirebaseNotifier 37 model model.Model 38 // pokeQueue is used to wake up the queue processor when a new task is enqueued 39 pokeQueue chan struct{} 40 // pgLockConn is used to hold a connection to the database for locking 41 pgLockConn *gorm.DB 42 pgLockConnMu sync.Mutex 43} 44 45// list tables here so we can migrate them 46var StatefulDBModels = []any{ 47 oatproxy.OAuthSession{}, 48 Notification{}, 49 Config{}, 50 XrpcStreamEvent{}, 51 AppTask{}, 52 Repo{}, 53} 54 55var NoPostgresDatabaseCode = "3D000" 56 57// Stateful database for storing private streamplace state 58func MakeDB(ctx context.Context, cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) { 59 dbURL := cli.DBURL 60 log.Log(ctx, "starting stateful database", "dbURL", redactDBURL(dbURL)) 61 var dial gorm.Dialector 62 var dbType DBType 63 if dbURL == ":memory:" { 64 dial = sqlite.Open(":memory:") 65 dbType = DBTypeSQLite 66 } else if strings.HasPrefix(dbURL, "sqlite://") { 67 dial = sqlite.Open(dbURL[len("sqlite://"):]) 68 dbType = DBTypeSQLite 69 } else if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") { 70 dial = postgres.Open(dbURL) 71 dbType = DBTypePostgres 72 } else { 73 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgresql://): %s", redactDBURL(dbURL)) 74 } 75 76 db, err := openDB(dial) 77 78 if err != nil { 79 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) { 80 db, err = makePostgresDB(dbURL) 81 if err != nil { 82 return nil, fmt.Errorf("error creating streamplace database: %w", err) 83 } 84 } else { 85 return nil, fmt.Errorf("error starting database: %w", err) 86 } 87 } 88 if dbType == DBTypeSQLite { 89 err = db.Exec("PRAGMA journal_mode=WAL;").Error 90 if err != nil { 91 return nil, fmt.Errorf("error setting journal mode: %w", err) 92 } 93 sqlDB, err := db.DB() 94 if err != nil { 95 return nil, fmt.Errorf("error getting database: %w", err) 96 } 97 sqlDB.SetMaxOpenConns(1) 98 } 99 for _, model := range StatefulDBModels { 100 err = db.AutoMigrate(model) 101 if err != nil { 102 return nil, err 103 } 104 } 105 state := &StatefulDB{ 106 DB: db, 107 CLI: cli, 108 Type: dbType, 109 locks: NewNamedLocks(), 110 model: model, 111 pokeQueue: make(chan struct{}, 1), 112 } 113 if state.Type == DBTypePostgres { 114 err = state.startPostgresLockerConn(ctx) 115 if err != nil { 116 return nil, fmt.Errorf("error starting postgres locker connection: %w", err) 117 } 118 } 119 return state, nil 120} 121 122func openDB(dial gorm.Dialector) (*gorm.DB, error) { 123 gormLogger := slogGorm.New( 124 slogGorm.WithHandler(tint.NewHandler(os.Stderr, &tint.Options{ 125 TimeFormat: time.RFC3339, 126 })), 127 slogGorm.WithTraceAll(), 128 ) 129 130 return gorm.Open(dial, &gorm.Config{ 131 SkipDefaultTransaction: true, 132 TranslateError: true, 133 Logger: gormLogger, 134 }) 135} 136 137// helper function for creating the requested postgres database 138func makePostgresDB(dbURL string) (*gorm.DB, error) { 139 u, err := url.Parse(dbURL) 140 if err != nil { 141 return nil, err 142 } 143 dbName := strings.TrimPrefix(u.Path, "/") 144 u.Path = "/postgres" 145 146 rootDial := postgres.Open(u.String()) 147 148 db, err := openDB(rootDial) 149 if err != nil { 150 return nil, err 151 } 152 153 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself. 154 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error 155 if err != nil { 156 return nil, err 157 } 158 159 log.Warn(context.Background(), "created postgres database", "dbName", dbName) 160 161 realDial := postgres.Open(dbURL) 162 163 return openDB(realDial) 164} 165 166func redactDBURL(dbURL string) string { 167 u, err := url.Parse(dbURL) 168 if err != nil { 169 return "db url is malformed" 170 } 171 if u.User != nil { 172 u.User = url.UserPassword(u.User.Username(), "redacted") 173 } 174 return u.String() 175}