Live video on the AT Protocol
at next 182 lines 4.5 kB view raw
1package statedb 2 3import ( 4 "context" 5 "fmt" 6 "net/url" 7 "strings" 8 "sync" 9 10 "github.com/streamplace/oatproxy/pkg/oatproxy" 11 "gorm.io/driver/postgres" 12 "gorm.io/driver/sqlite" 13 "gorm.io/gorm" 14 "gorm.io/plugin/prometheus" 15 "stream.place/streamplace/pkg/config" 16 "stream.place/streamplace/pkg/log" 17 "stream.place/streamplace/pkg/model" 18 notificationpkg "stream.place/streamplace/pkg/notifications" 19) 20 21type DBType string 22 23const ( 24 DBTypeSQLite DBType = "sqlite" 25 DBTypePostgres DBType = "postgres" 26) 27 28type StatefulDB struct { 29 DB *gorm.DB 30 CLI *config.CLI 31 Type DBType 32 locks *NamedLocks 33 noter notificationpkg.FirebaseNotifier 34 model model.Model 35 // pokeQueue is used to wake up the queue processor when a new task is enqueued 36 pokeQueue chan struct{} 37 // pgLockConn is used to hold a connection to the database for locking 38 pgLockConn *gorm.DB 39 pgLockConnMu sync.Mutex 40 op *oatproxy.OATProxy 41} 42 43// list tables here so we can migrate them 44var StatefulDBModels = []any{ 45 oatproxy.OAuthSession{}, 46 Notification{}, 47 Config{}, 48 XrpcStreamEvent{}, 49 AppTask{}, 50 Repo{}, 51 Webhook{}, 52 MultistreamTarget{}, 53 MultistreamEvent{}, 54 BrandingBlob{}, 55 ModerationAuditLog{}, 56} 57 58var NoPostgresDatabaseCode = "3D000" 59 60// Stateful database for storing private streamplace state 61func MakeDB(ctx context.Context, cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) { 62 dbURL := cli.DBURL 63 log.Log(ctx, "starting stateful database", "dbURL", redactDBURL(dbURL)) 64 var dial gorm.Dialector 65 var dbType DBType 66 if dbURL == ":memory:" { 67 dial = sqlite.Open(":memory:") 68 dbType = DBTypeSQLite 69 } else if strings.HasPrefix(dbURL, "sqlite://") { 70 dial = sqlite.Open(dbURL[len("sqlite://"):]) 71 dbType = DBTypeSQLite 72 } else if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") { 73 dial = postgres.Open(dbURL) 74 dbType = DBTypePostgres 75 } else { 76 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgresql://): %s", redactDBURL(dbURL)) 77 } 78 79 db, err := openDB(dial) 80 81 if err != nil { 82 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) { 83 db, err = makePostgresDB(dbURL) 84 if err != nil { 85 return nil, fmt.Errorf("error creating streamplace database: %w", err) 86 } 87 } else { 88 return nil, fmt.Errorf("error starting database: %w", err) 89 } 90 } 91 if dbType == DBTypeSQLite { 92 err = db.Exec("PRAGMA journal_mode=WAL;").Error 93 if err != nil { 94 return nil, fmt.Errorf("error setting journal mode: %w", err) 95 } 96 sqlDB, err := db.DB() 97 if err != nil { 98 return nil, fmt.Errorf("error getting database: %w", err) 99 } 100 sqlDB.SetMaxOpenConns(1) 101 } 102 for _, model := range StatefulDBModels { 103 err = db.AutoMigrate(model) 104 if err != nil { 105 return nil, err 106 } 107 } 108 109 err = db.Use(prometheus.New(prometheus.Config{ 110 DBName: "state", 111 RefreshInterval: 10, 112 StartServer: false, 113 })) 114 if err != nil { 115 return nil, fmt.Errorf("error using prometheus plugin: %w", err) 116 } 117 118 state := &StatefulDB{ 119 DB: db, 120 CLI: cli, 121 Type: dbType, 122 locks: NewNamedLocks(), 123 model: model, 124 pokeQueue: make(chan struct{}, 1), 125 noter: noter, 126 } 127 if state.Type == DBTypePostgres { 128 err = state.startPostgresLockerConn(ctx) 129 if err != nil { 130 return nil, fmt.Errorf("error starting postgres locker connection: %w", err) 131 } 132 } 133 return state, nil 134} 135 136func openDB(dial gorm.Dialector) (*gorm.DB, error) { 137 return gorm.Open(dial, &gorm.Config{ 138 SkipDefaultTransaction: true, 139 TranslateError: true, 140 Logger: config.GormLogger, 141 }) 142} 143 144// helper function for creating the requested postgres database 145func makePostgresDB(dbURL string) (*gorm.DB, error) { 146 u, err := url.Parse(dbURL) 147 if err != nil { 148 return nil, err 149 } 150 dbName := strings.TrimPrefix(u.Path, "/") 151 u.Path = "/postgres" 152 153 rootDial := postgres.Open(u.String()) 154 155 db, err := openDB(rootDial) 156 if err != nil { 157 return nil, err 158 } 159 160 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself. 161 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error 162 if err != nil { 163 return nil, err 164 } 165 166 log.Warn(context.Background(), "created postgres database", "dbName", dbName) 167 168 realDial := postgres.Open(dbURL) 169 170 return openDB(realDial) 171} 172 173func redactDBURL(dbURL string) string { 174 u, err := url.Parse(dbURL) 175 if err != nil { 176 return "db url is malformed" 177 } 178 if u.User != nil { 179 u.User = url.UserPassword(u.User.Username(), "redacted") 180 } 181 return u.String() 182}