Live video on the AT Protocol
1package statedb
2
3import (
4 "context"
5 "fmt"
6 "net/url"
7 "os"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/lmittmann/tint"
13 slogGorm "github.com/orandin/slog-gorm"
14 "github.com/streamplace/oatproxy/pkg/oatproxy"
15 "gorm.io/driver/postgres"
16 "gorm.io/driver/sqlite"
17 "gorm.io/gorm"
18 "stream.place/streamplace/pkg/config"
19 "stream.place/streamplace/pkg/log"
20 "stream.place/streamplace/pkg/model"
21 notificationpkg "stream.place/streamplace/pkg/notifications"
22)
23
// DBType names the SQL engine backing a StatefulDB. MakeDB derives it from the
// scheme of the configured database URL.
type DBType string

// Supported DBType values.
const (
	// DBTypeSQLite is used for ":memory:" and "sqlite://" database URLs.
	DBTypeSQLite DBType = "sqlite"
	// DBTypePostgres is used for "postgres://" and "postgresql://" database URLs.
	DBTypePostgres DBType = "postgres"
)
30
31type StatefulDB struct {
32 DB *gorm.DB
33 CLI *config.CLI
34 Type DBType
35 locks *NamedLocks
36 noter notificationpkg.FirebaseNotifier
37 model model.Model
38 // pokeQueue is used to wake up the queue processor when a new task is enqueued
39 pokeQueue chan struct{}
40 // pgLockConn is used to hold a connection to the database for locking
41 pgLockConn *gorm.DB
42 pgLockConnMu sync.Mutex
43}
44
45// list tables here so we can migrate them
46var StatefulDBModels = []any{
47 oatproxy.OAuthSession{},
48 Notification{},
49 Config{},
50 XrpcStreamEvent{},
51 AppTask{},
52 Repo{},
53}
54
var (
	// NoPostgresDatabaseCode is the PostgreSQL SQLSTATE for
	// invalid_catalog_name, reported when connecting to a database that does
	// not exist. MakeDB searches the connection error text for it to decide
	// whether the database should be created.
	NoPostgresDatabaseCode = "3D000"
)
56
57// Stateful database for storing private streamplace state
58func MakeDB(ctx context.Context, cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) {
59 dbURL := cli.DBURL
60 log.Log(ctx, "starting stateful database", "dbURL", redactDBURL(dbURL))
61 var dial gorm.Dialector
62 var dbType DBType
63 if dbURL == ":memory:" {
64 dial = sqlite.Open(":memory:")
65 dbType = DBTypeSQLite
66 } else if strings.HasPrefix(dbURL, "sqlite://") {
67 dial = sqlite.Open(dbURL[len("sqlite://"):])
68 dbType = DBTypeSQLite
69 } else if strings.HasPrefix(dbURL, "postgres://") || strings.HasPrefix(dbURL, "postgresql://") {
70 dial = postgres.Open(dbURL)
71 dbType = DBTypePostgres
72 } else {
73 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgresql://): %s", redactDBURL(dbURL))
74 }
75
76 db, err := openDB(dial)
77
78 if err != nil {
79 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) {
80 db, err = makePostgresDB(dbURL)
81 if err != nil {
82 return nil, fmt.Errorf("error creating streamplace database: %w", err)
83 }
84 } else {
85 return nil, fmt.Errorf("error starting database: %w", err)
86 }
87 }
88 if dbType == DBTypeSQLite {
89 err = db.Exec("PRAGMA journal_mode=WAL;").Error
90 if err != nil {
91 return nil, fmt.Errorf("error setting journal mode: %w", err)
92 }
93 sqlDB, err := db.DB()
94 if err != nil {
95 return nil, fmt.Errorf("error getting database: %w", err)
96 }
97 sqlDB.SetMaxOpenConns(1)
98 }
99 for _, model := range StatefulDBModels {
100 err = db.AutoMigrate(model)
101 if err != nil {
102 return nil, err
103 }
104 }
105 state := &StatefulDB{
106 DB: db,
107 CLI: cli,
108 Type: dbType,
109 locks: NewNamedLocks(),
110 model: model,
111 pokeQueue: make(chan struct{}, 1),
112 }
113 if state.Type == DBTypePostgres {
114 err = state.startPostgresLockerConn(ctx)
115 if err != nil {
116 return nil, fmt.Errorf("error starting postgres locker connection: %w", err)
117 }
118 }
119 return state, nil
120}
121
122func openDB(dial gorm.Dialector) (*gorm.DB, error) {
123 gormLogger := slogGorm.New(
124 slogGorm.WithHandler(tint.NewHandler(os.Stderr, &tint.Options{
125 TimeFormat: time.RFC3339,
126 })),
127 slogGorm.WithTraceAll(),
128 )
129
130 return gorm.Open(dial, &gorm.Config{
131 SkipDefaultTransaction: true,
132 TranslateError: true,
133 Logger: gormLogger,
134 })
135}
136
137// helper function for creating the requested postgres database
138func makePostgresDB(dbURL string) (*gorm.DB, error) {
139 u, err := url.Parse(dbURL)
140 if err != nil {
141 return nil, err
142 }
143 dbName := strings.TrimPrefix(u.Path, "/")
144 u.Path = "/postgres"
145
146 rootDial := postgres.Open(u.String())
147
148 db, err := openDB(rootDial)
149 if err != nil {
150 return nil, err
151 }
152
153 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself.
154 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error
155 if err != nil {
156 return nil, err
157 }
158
159 log.Warn(context.Background(), "created postgres database", "dbName", dbName)
160
161 realDial := postgres.Open(dbURL)
162
163 return openDB(realDial)
164}
165
// redactDBURL returns dbURL in a form that is safe to log.
//
// A password in the URL's userinfo, or in a "password" query parameter (which
// libpq-style connection URLs accept), is replaced with "redacted". Userinfo
// with no password is left as is. The special ":memory:" URL is returned
// unchanged. Any other string that does not parse as a URL is replaced
// entirely by a fixed placeholder, so no part of it reaches the logs.
func redactDBURL(dbURL string) string {
	if dbURL == ":memory:" {
		return dbURL
	}
	u, err := url.Parse(dbURL)
	if err != nil {
		return "db url is malformed"
	}
	if u.User != nil {
		if _, hasPassword := u.User.Password(); hasPassword {
			u.User = url.UserPassword(u.User.Username(), "redacted")
		}
	}
	// Only re-encode the query when there is something to hide; Encode sorts
	// the parameters.
	if q := u.Query(); q.Has("password") {
		q.Set("password", "redacted")
		u.RawQuery = q.Encode()
	}
	return u.String()
}