Live video on the AT Protocol
1package statedb
2
3import (
4 "context"
5 "fmt"
6 "net/url"
7 "strings"
8 "sync"
9
10 "github.com/streamplace/oatproxy/pkg/oatproxy"
11 "gorm.io/driver/postgres"
12 "gorm.io/driver/sqlite"
13 "gorm.io/gorm"
14 "gorm.io/plugin/prometheus"
15 "stream.place/streamplace/pkg/config"
16 "stream.place/streamplace/pkg/log"
17 "stream.place/streamplace/pkg/model"
18 notificationpkg "stream.place/streamplace/pkg/notifications"
19)
20
// DBType identifies which database backend a StatefulDB is connected to.
type DBType string

// Backends understood by MakeDB.
const (
	// DBTypeSQLite is an SQLite database, opened from a file or in memory.
	DBTypeSQLite DBType = "sqlite"
	// DBTypePostgres is a PostgreSQL database.
	DBTypePostgres DBType = "postgres"
)
27
// StatefulDB is the handle to the database holding private streamplace state
// (SQLite or Postgres), together with the in-process helpers that operate on
// it. It is constructed by MakeDB.
type StatefulDB struct {
	// DB is the gorm connection to the state database.
	DB           *gorm.DB
	// CLI is the configuration MakeDB was called with.
	CLI          *config.CLI
	// Type records which backend DB is connected to.
	Type         DBType
	// locks is created by NewNamedLocks in MakeDB; presumably per-name
	// in-process locking — TODO confirm.
	locks        *NamedLocks
	// noter is the Firebase notifier passed to MakeDB.
	noter        notificationpkg.FirebaseNotifier
	// model is the model.Model passed to MakeDB.
	model        model.Model
	// pokeQueue is used to wake up the queue processor when a new task is enqueued
	pokeQueue    chan struct{}
	// pgLockConn is used to hold a connection to the database for locking
	pgLockConn   *gorm.DB
	// pgLockConnMu presumably guards pgLockConn — confirm against its users.
	pgLockConnMu sync.Mutex
	// op is not set by MakeDB; NOTE(review): presumably assigned after
	// construction — confirm where.
	op           *oatproxy.OATProxy
}
42
// StatefulDBModels lists every gorm model stored in the stateful database.
// MakeDB calls AutoMigrate on each entry, in this order, at startup, so a new
// table must be added here for its schema to be created and kept up to date.
var StatefulDBModels = []any{
	oatproxy.OAuthSession{},
	Notification{},
	Config{},
	XrpcStreamEvent{},
	AppTask{},
	Repo{},
	Webhook{},
}
53
var (
	// NoPostgresDatabaseCode is the SQLSTATE (3D000, invalid_catalog_name)
	// that Postgres reports when the database named in the connection string
	// does not exist. MakeDB looks for it in connection errors to decide
	// whether to create the database itself.
	NoPostgresDatabaseCode = "3D000"
)
55
56// Stateful database for storing private streamplace state
57func MakeDB(ctx context.Context, cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) {
58 dbURL := cli.DBURL
59 log.Log(ctx, "starting stateful database", "dbURL", redactDBURL(dbURL))
60 var dial gorm.Dialector
61 var dbType DBType
62 if dbURL == ":memory:" {
63 dial = sqlite.Open(":memory:")
64 dbType = DBTypeSQLite
65 } else if strings.HasPrefix(dbURL, "sqlite://") {
66 dial = sqlite.Open(dbURL[len("sqlite://"):])
67 dbType = DBTypeSQLite
68 } else if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") {
69 dial = postgres.Open(dbURL)
70 dbType = DBTypePostgres
71 } else {
72 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgresql://): %s", redactDBURL(dbURL))
73 }
74
75 db, err := openDB(dial)
76
77 if err != nil {
78 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) {
79 db, err = makePostgresDB(dbURL)
80 if err != nil {
81 return nil, fmt.Errorf("error creating streamplace database: %w", err)
82 }
83 } else {
84 return nil, fmt.Errorf("error starting database: %w", err)
85 }
86 }
87 if dbType == DBTypeSQLite {
88 err = db.Exec("PRAGMA journal_mode=WAL;").Error
89 if err != nil {
90 return nil, fmt.Errorf("error setting journal mode: %w", err)
91 }
92 sqlDB, err := db.DB()
93 if err != nil {
94 return nil, fmt.Errorf("error getting database: %w", err)
95 }
96 sqlDB.SetMaxOpenConns(1)
97 }
98 for _, model := range StatefulDBModels {
99 err = db.AutoMigrate(model)
100 if err != nil {
101 return nil, err
102 }
103 }
104
105 err = db.Use(prometheus.New(prometheus.Config{
106 DBName: "state",
107 RefreshInterval: 10,
108 StartServer: false,
109 }))
110 if err != nil {
111 return nil, fmt.Errorf("error using prometheus plugin: %w", err)
112 }
113
114 state := &StatefulDB{
115 DB: db,
116 CLI: cli,
117 Type: dbType,
118 locks: NewNamedLocks(),
119 model: model,
120 pokeQueue: make(chan struct{}, 1),
121 noter: noter,
122 }
123 if state.Type == DBTypePostgres {
124 err = state.startPostgresLockerConn(ctx)
125 if err != nil {
126 return nil, fmt.Errorf("error starting postgres locker connection: %w", err)
127 }
128 }
129 return state, nil
130}
131
132func openDB(dial gorm.Dialector) (*gorm.DB, error) {
133 return gorm.Open(dial, &gorm.Config{
134 SkipDefaultTransaction: true,
135 TranslateError: true,
136 Logger: config.GormLogger,
137 })
138}
139
140// helper function for creating the requested postgres database
141func makePostgresDB(dbURL string) (*gorm.DB, error) {
142 u, err := url.Parse(dbURL)
143 if err != nil {
144 return nil, err
145 }
146 dbName := strings.TrimPrefix(u.Path, "/")
147 u.Path = "/postgres"
148
149 rootDial := postgres.Open(u.String())
150
151 db, err := openDB(rootDial)
152 if err != nil {
153 return nil, err
154 }
155
156 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself.
157 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error
158 if err != nil {
159 return nil, err
160 }
161
162 log.Warn(context.Background(), "created postgres database", "dbName", dbName)
163
164 realDial := postgres.Open(dbURL)
165
166 return openDB(realDial)
167}
168
// redactDBURL returns dbURL in a form that is safe to log.
//
// A password in the userinfo section, or in a "password" query parameter
// (which Postgres connection URLs accept), is replaced with "redacted". A
// user without a password is left as-is, so the log does not suggest that a
// password was configured. If the query string cannot be parsed, it is
// replaced wholesale, because we cannot tell which parts are sensitive.
// ":memory:" is not a URL but is a valid database location, so it is returned
// unchanged. Any other unparsable input is replaced by a fixed placeholder, so
// that no part of it leaks.
func redactDBURL(dbURL string) string {
	if dbURL == ":memory:" {
		return dbURL
	}
	u, err := url.Parse(dbURL)
	if err != nil {
		return "db url is malformed"
	}
	if u.User != nil {
		if _, hasPassword := u.User.Password(); hasPassword {
			u.User = url.UserPassword(u.User.Username(), "redacted")
		}
	}
	if u.RawQuery != "" {
		q, qerr := url.ParseQuery(u.RawQuery)
		switch {
		case qerr != nil:
			u.RawQuery = "redacted"
		case q.Has("password"):
			q.Set("password", "redacted")
			u.RawQuery = q.Encode()
		}
	}
	return u.String()
}