Live video on the AT Protocol
1package statedb
2
3import (
4 "context"
5 "fmt"
6 "net/url"
7 "strings"
8 "sync"
9
10 "github.com/streamplace/oatproxy/pkg/oatproxy"
11 "gorm.io/driver/postgres"
12 "gorm.io/driver/sqlite"
13 "gorm.io/gorm"
14 "stream.place/streamplace/pkg/config"
15 "stream.place/streamplace/pkg/log"
16 "stream.place/streamplace/pkg/model"
17 notificationpkg "stream.place/streamplace/pkg/notifications"
18)
19
// DBType names the SQL backend a StatefulDB is connected to.
type DBType string

// Backends understood by MakeDB.
const (
	// DBTypeSQLite is selected for ":memory:" and "sqlite://" database URLs.
	DBTypeSQLite = DBType("sqlite")
	// DBTypePostgres is selected for "postgres://" and "postgresql://" database URLs.
	DBTypePostgres = DBType("postgres")
)
26
27type StatefulDB struct {
28 DB *gorm.DB
29 CLI *config.CLI
30 Type DBType
31 locks *NamedLocks
32 noter notificationpkg.FirebaseNotifier
33 model model.Model
34 // pokeQueue is used to wake up the queue processor when a new task is enqueued
35 pokeQueue chan struct{}
36 // pgLockConn is used to hold a connection to the database for locking
37 pgLockConn *gorm.DB
38 pgLockConnMu sync.Mutex
39}
40
41// list tables here so we can migrate them
42var StatefulDBModels = []any{
43 oatproxy.OAuthSession{},
44 Notification{},
45 Config{},
46 XrpcStreamEvent{},
47 AppTask{},
48 Repo{},
49}
50
51var NoPostgresDatabaseCode = "3D000"
52
53// Stateful database for storing private streamplace state
54func MakeDB(ctx context.Context, cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) {
55 dbURL := cli.DBURL
56 log.Log(ctx, "starting stateful database", "dbURL", redactDBURL(dbURL))
57 var dial gorm.Dialector
58 var dbType DBType
59 if dbURL == ":memory:" {
60 dial = sqlite.Open(":memory:")
61 dbType = DBTypeSQLite
62 } else if strings.HasPrefix(dbURL, "sqlite://") {
63 dial = sqlite.Open(dbURL[len("sqlite://"):])
64 dbType = DBTypeSQLite
65 } else if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") {
66 dial = postgres.Open(dbURL)
67 dbType = DBTypePostgres
68 } else {
69 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgresql://): %s", redactDBURL(dbURL))
70 }
71
72 db, err := openDB(dial)
73
74 if err != nil {
75 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) {
76 db, err = makePostgresDB(dbURL)
77 if err != nil {
78 return nil, fmt.Errorf("error creating streamplace database: %w", err)
79 }
80 } else {
81 return nil, fmt.Errorf("error starting database: %w", err)
82 }
83 }
84 if dbType == DBTypeSQLite {
85 err = db.Exec("PRAGMA journal_mode=WAL;").Error
86 if err != nil {
87 return nil, fmt.Errorf("error setting journal mode: %w", err)
88 }
89 sqlDB, err := db.DB()
90 if err != nil {
91 return nil, fmt.Errorf("error getting database: %w", err)
92 }
93 sqlDB.SetMaxOpenConns(1)
94 }
95 for _, model := range StatefulDBModels {
96 err = db.AutoMigrate(model)
97 if err != nil {
98 return nil, err
99 }
100 }
101 state := &StatefulDB{
102 DB: db,
103 CLI: cli,
104 Type: dbType,
105 locks: NewNamedLocks(),
106 model: model,
107 pokeQueue: make(chan struct{}, 1),
108 }
109 if state.Type == DBTypePostgres {
110 err = state.startPostgresLockerConn(ctx)
111 if err != nil {
112 return nil, fmt.Errorf("error starting postgres locker connection: %w", err)
113 }
114 }
115 return state, nil
116}
117
118func openDB(dial gorm.Dialector) (*gorm.DB, error) {
119 return gorm.Open(dial, &gorm.Config{
120 SkipDefaultTransaction: true,
121 TranslateError: true,
122 Logger: config.GormLogger,
123 })
124}
125
126// helper function for creating the requested postgres database
127func makePostgresDB(dbURL string) (*gorm.DB, error) {
128 u, err := url.Parse(dbURL)
129 if err != nil {
130 return nil, err
131 }
132 dbName := strings.TrimPrefix(u.Path, "/")
133 u.Path = "/postgres"
134
135 rootDial := postgres.Open(u.String())
136
137 db, err := openDB(rootDial)
138 if err != nil {
139 return nil, err
140 }
141
142 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself.
143 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error
144 if err != nil {
145 return nil, err
146 }
147
148 log.Warn(context.Background(), "created postgres database", "dbName", dbName)
149
150 realDial := postgres.Open(dbURL)
151
152 return openDB(realDial)
153}
154
// redactDBURL returns dbURL in a form that is safe to log.
//
// Any user info gets its password replaced with "redacted". The "password" and
// "sslpassword" query parameters are redacted the same way. If the query
// cannot be parsed, the whole query is replaced so nothing secret can slip
// through. The special SQLite URL ":memory:" is returned as-is, because it is
// not a parseable URL. Other unparseable inputs yield a fixed message rather
// than the original text.
func redactDBURL(dbURL string) string {
	if dbURL == ":memory:" {
		return dbURL
	}
	u, err := url.Parse(dbURL)
	if err != nil {
		return "db url is malformed"
	}
	if u.User != nil {
		u.User = url.UserPassword(u.User.Username(), "redacted")
	}
	if u.RawQuery != "" {
		q, err := url.ParseQuery(u.RawQuery)
		if err != nil {
			// ParseQuery silently drops pairs it rejects, so a secret could
			// survive in RawQuery; drop the whole query instead.
			u.RawQuery = "redacted"
		} else {
			changed := false
			for _, key := range []string{"password", "sslpassword"} {
				if q.Has(key) {
					q.Set(key, "redacted")
					changed = true
				}
			}
			if changed {
				u.RawQuery = q.Encode()
			}
		}
	}
	return u.String()
}