Live video on the AT Protocol
1package statedb
2
3import (
4 "context"
5 "fmt"
6 "net/url"
7 "strings"
8 "sync"
9
10 "github.com/streamplace/oatproxy/pkg/oatproxy"
11 "gorm.io/driver/postgres"
12 "gorm.io/driver/sqlite"
13 "gorm.io/gorm"
14 "stream.place/streamplace/pkg/config"
15 "stream.place/streamplace/pkg/log"
16 "stream.place/streamplace/pkg/model"
17 notificationpkg "stream.place/streamplace/pkg/notifications"
18)
19
20type DBType string
21
22const (
23 DBTypeSQLite DBType = "sqlite"
24 DBTypePostgres DBType = "postgres"
25)
26
27type StatefulDB struct {
28 DB *gorm.DB
29 CLI *config.CLI
30 Type DBType
31 locks *NamedLocks
32 noter notificationpkg.FirebaseNotifier
33 model model.Model
34 // pokeQueue is used to wake up the queue processor when a new task is enqueued
35 pokeQueue chan struct{}
36 // pgLockConn is used to hold a connection to the database for locking
37 pgLockConn *gorm.DB
38 pgLockConnMu sync.Mutex
39}
40
41// list tables here so we can migrate them
42var StatefulDBModels = []any{
43 oatproxy.OAuthSession{},
44 Notification{},
45 Config{},
46 XrpcStreamEvent{},
47 AppTask{},
48 Repo{},
49 Webhook{},
50}
51
52var NoPostgresDatabaseCode = "3D000"
53
54// Stateful database for storing private streamplace state
55func MakeDB(ctx context.Context, cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) {
56 dbURL := cli.DBURL
57 log.Log(ctx, "starting stateful database", "dbURL", redactDBURL(dbURL))
58 var dial gorm.Dialector
59 var dbType DBType
60 if dbURL == ":memory:" {
61 dial = sqlite.Open(":memory:")
62 dbType = DBTypeSQLite
63 } else if strings.HasPrefix(dbURL, "sqlite://") {
64 dial = sqlite.Open(dbURL[len("sqlite://"):])
65 dbType = DBTypeSQLite
66 } else if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") {
67 dial = postgres.Open(dbURL)
68 dbType = DBTypePostgres
69 } else {
70 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgresql://): %s", redactDBURL(dbURL))
71 }
72
73 db, err := openDB(dial)
74
75 if err != nil {
76 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) {
77 db, err = makePostgresDB(dbURL)
78 if err != nil {
79 return nil, fmt.Errorf("error creating streamplace database: %w", err)
80 }
81 } else {
82 return nil, fmt.Errorf("error starting database: %w", err)
83 }
84 }
85 if dbType == DBTypeSQLite {
86 err = db.Exec("PRAGMA journal_mode=WAL;").Error
87 if err != nil {
88 return nil, fmt.Errorf("error setting journal mode: %w", err)
89 }
90 sqlDB, err := db.DB()
91 if err != nil {
92 return nil, fmt.Errorf("error getting database: %w", err)
93 }
94 sqlDB.SetMaxOpenConns(1)
95 }
96 for _, model := range StatefulDBModels {
97 err = db.AutoMigrate(model)
98 if err != nil {
99 return nil, err
100 }
101 }
102 state := &StatefulDB{
103 DB: db,
104 CLI: cli,
105 Type: dbType,
106 locks: NewNamedLocks(),
107 model: model,
108 pokeQueue: make(chan struct{}, 1),
109 noter: noter,
110 }
111 if state.Type == DBTypePostgres {
112 err = state.startPostgresLockerConn(ctx)
113 if err != nil {
114 return nil, fmt.Errorf("error starting postgres locker connection: %w", err)
115 }
116 }
117 return state, nil
118}
119
120func openDB(dial gorm.Dialector) (*gorm.DB, error) {
121 return gorm.Open(dial, &gorm.Config{
122 SkipDefaultTransaction: true,
123 TranslateError: true,
124 Logger: config.GormLogger,
125 })
126}
127
128// helper function for creating the requested postgres database
129func makePostgresDB(dbURL string) (*gorm.DB, error) {
130 u, err := url.Parse(dbURL)
131 if err != nil {
132 return nil, err
133 }
134 dbName := strings.TrimPrefix(u.Path, "/")
135 u.Path = "/postgres"
136
137 rootDial := postgres.Open(u.String())
138
139 db, err := openDB(rootDial)
140 if err != nil {
141 return nil, err
142 }
143
144 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself.
145 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error
146 if err != nil {
147 return nil, err
148 }
149
150 log.Warn(context.Background(), "created postgres database", "dbName", dbName)
151
152 realDial := postgres.Open(dbURL)
153
154 return openDB(realDial)
155}
156
157func redactDBURL(dbURL string) string {
158 u, err := url.Parse(dbURL)
159 if err != nil {
160 return "db url is malformed"
161 }
162 if u.User != nil {
163 u.User = url.UserPassword(u.User.Username(), "redacted")
164 }
165 return u.String()
166}