Live video on the AT Protocol
1package statedb
2
3import (
4 "context"
5 "fmt"
6 "net/url"
7 "os"
8 "strings"
9 "time"
10
11 "github.com/lmittmann/tint"
12 slogGorm "github.com/orandin/slog-gorm"
13 "github.com/streamplace/oatproxy/pkg/oatproxy"
14 "gorm.io/driver/postgres"
15 "gorm.io/driver/sqlite"
16 "gorm.io/gorm"
17 "stream.place/streamplace/pkg/config"
18 "stream.place/streamplace/pkg/log"
19 "stream.place/streamplace/pkg/model"
20 notificationpkg "stream.place/streamplace/pkg/notifications"
21)
22
// DBType identifies which SQL backend a StatefulDB is connected to.
type DBType string

// Backends that MakeDB knows how to open.
const (
	// DBTypeSQLite is used for ":memory:" and "sqlite://" database URLs.
	DBTypeSQLite = DBType("sqlite")
	// DBTypePostgres is used for "postgres://" database URLs.
	DBTypePostgres = DBType("postgres")
)
29
30type StatefulDB struct {
31 DB *gorm.DB
32 CLI *config.CLI
33 Type DBType
34 locks *NamedLocks
35 noter notificationpkg.FirebaseNotifier
36 model model.Model
37 // pokeQueue is used to wake up the queue processor when a new task is enqueued
38 pokeQueue chan struct{}
39}
40
41// list tables here so we can migrate them
42var StatefulDBModels = []any{
43 oatproxy.OAuthSession{},
44 Notification{},
45 Config{},
46 XrpcStreamEvent{},
47 AppTask{},
48}
49
50var NoPostgresDatabaseCode = "3D000"
51
52// Stateful database for storing private streamplace state
53func MakeDB(cli *config.CLI, noter notificationpkg.FirebaseNotifier, model model.Model) (*StatefulDB, error) {
54 dbURL := cli.DBURL
55 log.Log(context.Background(), "starting stateful database", "dbURL", redactDBURL(dbURL))
56 var dial gorm.Dialector
57 var dbType DBType
58 if dbURL == ":memory:" {
59 dial = sqlite.Open(":memory:")
60 dbType = DBTypeSQLite
61 } else if strings.HasPrefix(dbURL, "sqlite://") {
62 dial = sqlite.Open(dbURL[len("sqlite://"):])
63 dbType = DBTypeSQLite
64 } else if strings.HasPrefix(dbURL, "postgres://") {
65 dial = postgres.Open(dbURL)
66 dbType = DBTypePostgres
67 } else {
68 return nil, fmt.Errorf("unsupported database URL (most start with sqlite:// or postgres://): %s", redactDBURL(dbURL))
69 }
70
71 db, err := openDB(dial)
72
73 if err != nil {
74 if dbType == DBTypePostgres && strings.Contains(err.Error(), NoPostgresDatabaseCode) {
75 db, err = makePostgresDB(dbURL)
76 if err != nil {
77 return nil, fmt.Errorf("error creating streamplace database: %w", err)
78 }
79 } else {
80 return nil, fmt.Errorf("error starting database: %w", err)
81 }
82 }
83 if dbType == DBTypeSQLite {
84 err = db.Exec("PRAGMA journal_mode=WAL;").Error
85 if err != nil {
86 return nil, fmt.Errorf("error setting journal mode: %w", err)
87 }
88 sqlDB, err := db.DB()
89 if err != nil {
90 return nil, fmt.Errorf("error getting database: %w", err)
91 }
92 sqlDB.SetMaxOpenConns(1)
93 }
94 for _, model := range StatefulDBModels {
95 err = db.AutoMigrate(model)
96 if err != nil {
97 return nil, err
98 }
99 }
100 return &StatefulDB{
101 DB: db,
102 CLI: cli,
103 Type: dbType,
104 locks: NewNamedLocks(),
105 model: model,
106 pokeQueue: make(chan struct{}, 1),
107 }, nil
108}
109
110func openDB(dial gorm.Dialector) (*gorm.DB, error) {
111 gormLogger := slogGorm.New(
112 slogGorm.WithHandler(tint.NewHandler(os.Stderr, &tint.Options{
113 TimeFormat: time.RFC3339,
114 })),
115 // slogGorm.WithTraceAll(),
116 )
117
118 return gorm.Open(dial, &gorm.Config{
119 SkipDefaultTransaction: true,
120 TranslateError: true,
121 Logger: gormLogger,
122 })
123}
124
125// helper function for creating the requested postgres database
126func makePostgresDB(dbURL string) (*gorm.DB, error) {
127 u, err := url.Parse(dbURL)
128 if err != nil {
129 return nil, err
130 }
131 dbName := strings.TrimPrefix(u.Path, "/")
132 u.Path = "/postgres"
133
134 rootDial := postgres.Open(u.String())
135
136 db, err := openDB(rootDial)
137 if err != nil {
138 return nil, err
139 }
140
141 // postgres doesn't support prepared statements for CREATE DATABASE. don't SQL inject yourself.
142 err = db.Exec(fmt.Sprintf("CREATE DATABASE %s;", dbName)).Error
143 if err != nil {
144 return nil, err
145 }
146
147 log.Warn(context.Background(), "created postgres database", "dbName", dbName)
148
149 realDial := postgres.Open(dbURL)
150
151 return openDB(realDial)
152}
153
// redactDBURL returns dbURL in a form that is safe to log.
//
// The special value ":memory:" is returned unchanged (it is a valid database
// setting but not a parseable URL). For any other value, the userinfo
// password and the "password" / "sslpassword" query parameters are replaced
// with "redacted". When a query parameter is redacted the query string is
// re-encoded, which sorts its keys. Input that cannot be parsed as a URL
// yields the fixed string "db url is malformed" so that nothing secret is
// echoed back.
func redactDBURL(dbURL string) string {
	if dbURL == ":memory:" {
		return dbURL
	}
	u, err := url.Parse(dbURL)
	if err != nil {
		return "db url is malformed"
	}
	if u.User != nil {
		u.User = url.UserPassword(u.User.Username(), "redacted")
	}
	if u.RawQuery != "" {
		query := u.Query()
		redacted := false
		for _, key := range []string{"password", "sslpassword"} {
			if _, present := query[key]; present {
				query.Set(key, "redacted")
				redacted = true
			}
		}
		// Only rewrite the query when something was hidden, so URLs without
		// secrets are logged exactly as configured.
		if redacted {
			u.RawQuery = query.Encode()
		}
	}
	return u.String()
}