Live video on the AT Protocol
1package statedb
2
3import (
4 "context"
5 "fmt"
6 "net/url"
7 "strings"
8 "sync"
9
10 "github.com/streamplace/oatproxy/pkg/oatproxy"
11 "gorm.io/driver/postgres"
12 "gorm.io/driver/sqlite"
13 "gorm.io/gorm"
14 "gorm.io/plugin/prometheus"
15 "stream.place/streamplace/pkg/config"
16 "stream.place/streamplace/pkg/log"
17 "stream.place/streamplace/pkg/model"
18 notificationpkg "stream.place/streamplace/pkg/notifications"
19)
20
// DBType names the SQL backend a StatefulDB is connected to.
type DBType string

// Backends recognised by MakeDB.
const (
	// DBTypeSQLite is selected for ":memory:" and "sqlite://" database URLs.
	DBTypeSQLite = DBType("sqlite")
	// DBTypePostgres is selected for "postgres://" and "postgresql://" database URLs.
	DBTypePostgres = DBType("postgres")
)
27
// StatefulDB wraps the gorm connection that holds private streamplace state,
// together with the dependencies and coordination state its methods use.
// MakeDB is the constructor in this file; it fills in every field except
// pgLockConn, which is set up separately for Postgres.
type StatefulDB struct {
	// DB is the gorm handle for the stateful database.
	DB *gorm.DB
	// CLI is the configuration MakeDB was called with.
	CLI *config.CLI
	// Type records which backend (SQLite or Postgres) DB is connected to.
	Type DBType
	// locks is created by NewNamedLocks in MakeDB; presumably per-name
	// in-process locks — confirm against NamedLocks.
	locks *NamedLocks
	// noter is the notifier passed to MakeDB.
	noter notificationpkg.FirebaseNotifier
	// model is the model.Model passed to MakeDB.
	model model.Model
	// pokeQueue is used to wake up the queue processor when a new task is enqueued.
	// MakeDB creates it with a buffer of one.
	pokeQueue chan struct{}
	// pgLockConn is used to hold a connection to the database for locking
	// (Postgres only).
	pgLockConn *gorm.DB
	// pgLockConnMu presumably guards pgLockConn — confirm at its use sites.
	pgLockConnMu sync.Mutex
}
41
42// list tables here so we can migrate them
43var StatefulDBModels = []any{
44 oatproxy.OAuthSession{},
45 Notification{},
46 Config{},
47 XrpcStreamEvent{},
48 AppTask{},
49 Repo{},
50 Webhook{},
51}
52
53var NoPostgresDatabaseCode = "3D000"
54
55// Stateful database for storing private streamplace state
56func MakeDB(ctx context.Context, cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) {
57 dbURL := cli.DBURL
58 log.Log(ctx, "starting stateful database", "dbURL", redactDBURL(dbURL))
59 var dial gorm.Dialector
60 var dbType DBType
61 if dbURL == ":memory:" {
62 dial = sqlite.Open(":memory:")
63 dbType = DBTypeSQLite
64 } else if strings.HasPrefix(dbURL, "sqlite://") {
65 dial = sqlite.Open(dbURL[len("sqlite://"):])
66 dbType = DBTypeSQLite
67 } else if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") {
68 dial = postgres.Open(dbURL)
69 dbType = DBTypePostgres
70 } else {
71 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgresql://): %s", redactDBURL(dbURL))
72 }
73
74 db, err := openDB(dial)
75
76 if err != nil {
77 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) {
78 db, err = makePostgresDB(dbURL)
79 if err != nil {
80 return nil, fmt.Errorf("error creating streamplace database: %w", err)
81 }
82 } else {
83 return nil, fmt.Errorf("error starting database: %w", err)
84 }
85 }
86 if dbType == DBTypeSQLite {
87 err = db.Exec("PRAGMA journal_mode=WAL;").Error
88 if err != nil {
89 return nil, fmt.Errorf("error setting journal mode: %w", err)
90 }
91 sqlDB, err := db.DB()
92 if err != nil {
93 return nil, fmt.Errorf("error getting database: %w", err)
94 }
95 sqlDB.SetMaxOpenConns(1)
96 }
97 for _, model := range StatefulDBModels {
98 err = db.AutoMigrate(model)
99 if err != nil {
100 return nil, err
101 }
102 }
103
104 err = db.Use(prometheus.New(prometheus.Config{
105 DBName: "state",
106 RefreshInterval: 10,
107 StartServer: false,
108 }))
109 if err != nil {
110 return nil, fmt.Errorf("error using prometheus plugin: %w", err)
111 }
112
113 state := &StatefulDB{
114 DB: db,
115 CLI: cli,
116 Type: dbType,
117 locks: NewNamedLocks(),
118 model: model,
119 pokeQueue: make(chan struct{}, 1),
120 noter: noter,
121 }
122 if state.Type == DBTypePostgres {
123 err = state.startPostgresLockerConn(ctx)
124 if err != nil {
125 return nil, fmt.Errorf("error starting postgres locker connection: %w", err)
126 }
127 }
128 return state, nil
129}
130
131func openDB(dial gorm.Dialector) (*gorm.DB, error) {
132 return gorm.Open(dial, &gorm.Config{
133 SkipDefaultTransaction: true,
134 TranslateError: true,
135 Logger: config.GormLogger,
136 })
137}
138
139// helper function for creating the requested postgres database
140func makePostgresDB(dbURL string) (*gorm.DB, error) {
141 u, err := url.Parse(dbURL)
142 if err != nil {
143 return nil, err
144 }
145 dbName := strings.TrimPrefix(u.Path, "/")
146 u.Path = "/postgres"
147
148 rootDial := postgres.Open(u.String())
149
150 db, err := openDB(rootDial)
151 if err != nil {
152 return nil, err
153 }
154
155 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself.
156 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error
157 if err != nil {
158 return nil, err
159 }
160
161 log.Warn(context.Background(), "created postgres database", "dbName", dbName)
162
163 realDial := postgres.Open(dbURL)
164
165 return openDB(realDial)
166}
167
// redactDBURL returns dbURL in a form that is safe to log.
//
// A password in the userinfo section and any "password" query parameter are
// replaced with "redacted". A username without a password is left as-is. The
// SQLite in-memory URL ":memory:" carries no credentials and is returned
// unchanged. Any other string that fails to parse as a URL yields a fixed
// placeholder, so a malformed URL that may contain credentials is never
// echoed back.
func redactDBURL(dbURL string) string {
	if dbURL == ":memory:" {
		return dbURL
	}
	u, err := url.Parse(dbURL)
	if err != nil {
		return "db url is malformed"
	}
	if u.User != nil {
		if _, hasPassword := u.User.Password(); hasPassword {
			u.User = url.UserPassword(u.User.Username(), "redacted")
		}
	}
	// Postgres URLs also accept credentials as query parameters.
	if q := u.Query(); q.Has("password") {
		q.Set("password", "redacted")
		u.RawQuery = q.Encode()
	}
	return u.String()
}