Live video on the AT Protocol
1package statedb
2
3import (
4 "context"
5 "fmt"
6 "net/url"
7 "strings"
8 "sync"
9
10 "github.com/streamplace/oatproxy/pkg/oatproxy"
11 "gorm.io/driver/postgres"
12 "gorm.io/driver/sqlite"
13 "gorm.io/gorm"
14 "gorm.io/plugin/prometheus"
15 "stream.place/streamplace/pkg/config"
16 "stream.place/streamplace/pkg/log"
17 "stream.place/streamplace/pkg/model"
18 notificationpkg "stream.place/streamplace/pkg/notifications"
19)
20
// DBType names the database backend a StatefulDB is running on.
type DBType string

const (
	// DBTypeSQLite is the type MakeDB selects for ":memory:" and "sqlite://" URLs.
	DBTypeSQLite DBType = "sqlite"
	// DBTypePostgres is the type MakeDB selects for "postgres://" and "postgresql://" URLs.
	DBTypePostgres DBType = "postgres"
)
27
// StatefulDB bundles the gorm handle for the stateful database with the
// collaborators that MakeDB stores alongside it.
type StatefulDB struct {
	// DB is the gorm handle opened by MakeDB.
	DB *gorm.DB
	// CLI is the configuration MakeDB was called with.
	CLI *config.CLI
	// Type records which backend (SQLite or Postgres) DB is connected to.
	Type DBType
	// locks is created by MakeDB via NewNamedLocks.
	locks *NamedLocks
	// noter is the notifier passed to MakeDB.
	noter notificationpkg.FirebaseNotifier
	// model is the model passed to MakeDB.
	model model.Model
	// pokeQueue is used to wake up the queue processor when a new task is enqueued.
	// MakeDB creates it with a buffer of one.
	pokeQueue chan struct{}
	// pgLockConn is used to hold a connection to the database for locking
	pgLockConn *gorm.DB
	// NOTE(review): pgLockConnMu presumably guards pgLockConn — confirm against its users.
	pgLockConnMu sync.Mutex
	// NOTE(review): op is not assigned by MakeDB; presumably set elsewhere in the package.
	op *oatproxy.OATProxy
}
42
// StatefulDBModels lists the tables of the stateful database. MakeDB runs
// AutoMigrate on every entry, so a new table has to be added here to be migrated.
var StatefulDBModels = []any{
	oatproxy.OAuthSession{},
	Notification{},
	Config{},
	XrpcStreamEvent{},
	AppTask{},
	Repo{},
	Webhook{},
	MultistreamTarget{},
	MultistreamEvent{},
	BrandingBlob{},
	ModerationAuditLog{},
}
57
// NoPostgresDatabaseCode is the error-code text MakeDB searches for in the
// message of a failed Postgres open; on a match it tries to create the database
// instead of giving up. 3D000 is PostgreSQL's SQLSTATE for invalid_catalog_name,
// i.e. the requested database does not exist.
var NoPostgresDatabaseCode = "3D000"
59
60// Stateful database for storing private streamplace state
61func MakeDB(ctx context.Context, cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) {
62 dbURL := cli.DBURL
63 log.Log(ctx, "starting stateful database", "dbURL", redactDBURL(dbURL))
64 var dial gorm.Dialector
65 var dbType DBType
66 if dbURL == ":memory:" {
67 dial = sqlite.Open(":memory:")
68 dbType = DBTypeSQLite
69 } else if strings.HasPrefix(dbURL, "sqlite://") {
70 dial = sqlite.Open(dbURL[len("sqlite://"):])
71 dbType = DBTypeSQLite
72 } else if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") {
73 dial = postgres.Open(dbURL)
74 dbType = DBTypePostgres
75 } else {
76 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgresql://): %s", redactDBURL(dbURL))
77 }
78
79 db, err := openDB(dial)
80
81 if err != nil {
82 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) {
83 db, err = makePostgresDB(dbURL)
84 if err != nil {
85 return nil, fmt.Errorf("error creating streamplace database: %w", err)
86 }
87 } else {
88 return nil, fmt.Errorf("error starting database: %w", err)
89 }
90 }
91 if dbType == DBTypeSQLite {
92 err = db.Exec("PRAGMA journal_mode=WAL;").Error
93 if err != nil {
94 return nil, fmt.Errorf("error setting journal mode: %w", err)
95 }
96 sqlDB, err := db.DB()
97 if err != nil {
98 return nil, fmt.Errorf("error getting database: %w", err)
99 }
100 sqlDB.SetMaxOpenConns(1)
101 }
102 for _, model := range StatefulDBModels {
103 err = db.AutoMigrate(model)
104 if err != nil {
105 return nil, err
106 }
107 }
108
109 err = db.Use(prometheus.New(prometheus.Config{
110 DBName: "state",
111 RefreshInterval: 10,
112 StartServer: false,
113 }))
114 if err != nil {
115 return nil, fmt.Errorf("error using prometheus plugin: %w", err)
116 }
117
118 state := &StatefulDB{
119 DB: db,
120 CLI: cli,
121 Type: dbType,
122 locks: NewNamedLocks(),
123 model: model,
124 pokeQueue: make(chan struct{}, 1),
125 noter: noter,
126 }
127 if state.Type == DBTypePostgres {
128 err = state.startPostgresLockerConn(ctx)
129 if err != nil {
130 return nil, fmt.Errorf("error starting postgres locker connection: %w", err)
131 }
132 }
133 return state, nil
134}
135
136func openDB(dial gorm.Dialector) (*gorm.DB, error) {
137 return gorm.Open(dial, &gorm.Config{
138 SkipDefaultTransaction: true,
139 TranslateError: true,
140 Logger: config.GormLogger,
141 })
142}
143
144// helper function for creating the requested postgres database
145func makePostgresDB(dbURL string) (*gorm.DB, error) {
146 u, err := url.Parse(dbURL)
147 if err != nil {
148 return nil, err
149 }
150 dbName := strings.TrimPrefix(u.Path, "/")
151 u.Path = "/postgres"
152
153 rootDial := postgres.Open(u.String())
154
155 db, err := openDB(rootDial)
156 if err != nil {
157 return nil, err
158 }
159
160 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself.
161 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error
162 if err != nil {
163 return nil, err
164 }
165
166 log.Warn(context.Background(), "created postgres database", "dbName", dbName)
167
168 realDial := postgres.Open(dbURL)
169
170 return openDB(realDial)
171}
172
// redactDBURL returns dbURL in a form that is safe to log.
//
// The userinfo password is replaced with "redacted" (the username is kept), and
// the values of the "password" and "sslpassword" query parameters, which
// libpq-style URIs may use to carry secrets, are replaced the same way. The
// SQLite ":memory:" target is returned unchanged because it is not a URL and
// holds no secret. Any other string that fails to parse yields the fixed text
// "db url is malformed" rather than echoing the input.
func redactDBURL(dbURL string) string {
	if dbURL == ":memory:" {
		return dbURL
	}
	u, err := url.Parse(dbURL)
	if err != nil {
		return "db url is malformed"
	}
	if u.User != nil {
		u.User = url.UserPassword(u.User.Username(), "redacted")
	}
	if u.RawQuery != "" {
		q := u.Query()
		changed := false
		for _, key := range []string{"password", "sslpassword"} {
			if _, ok := q[key]; ok {
				q.Set(key, "redacted")
				changed = true
			}
		}
		// Only rewrite the query when something was redacted, so URLs without
		// secrets are logged exactly as given.
		if changed {
			u.RawQuery = q.Encode()
		}
	}
	return u.String()
}