beebo
1package sqlite
2
3import (
4 "database/sql"
5 "embed"
6 "fmt"
7 "io/fs"
8 "log"
9 "strings"
10 "time"
11
12 _ "github.com/glebarez/go-sqlite"
13)
14
15//go:embed migrations/*.sql
16var migrationFiles embed.FS
17
// DB wraps a *sql.DB handle and exposes the application's database queries
// as methods. Values are created by New.
type DB struct {
	sql *sql.DB // underlying database/sql handle every method runs against
}
21
// SavedItem mirrors one row of the saved_item table. GetUserSavedItems fills
// every field; WriteSavedItem stores all of them except CreatedAt.
type SavedItem struct {
	ArchiveURL string    // archive_url column
	CreatedAt  time.Time // created_at column
	ItemTitle  string    // item_title column
	ItemURL    string    // item_url column
}
28
29// New opens a sqlite database, populates it with tables, and
30// returns a ready-to-use *sqlite.DB object which is used for
31// abstracting database queries.
32func New(path string) *DB {
33 db, err := sql.Open("sqlite", path)
34 if err != nil {
35 log.Fatal(err)
36 }
37
38 _, err = db.Exec("CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY)")
39 if err != nil {
40 log.Fatal(err)
41 }
42
43 var latestVersion int
44 row := db.QueryRow("SELECT MAX(version) FROM schema_migrations")
45 err = row.Scan(&latestVersion)
46 if err != nil {
47 if strings.Contains(err.Error(), "converting NULL to int is unsupported") {
48 // assume that we're starting from ground zero
49 latestVersion = 0
50 } else {
51 log.Fatal(err)
52 }
53 }
54
55 files, err := fs.ReadDir(migrationFiles, "migrations")
56 if err != nil {
57 log.Fatal(err)
58 }
59 for _, f := range files {
60 var version int
61 _, err = fmt.Sscanf(f.Name(), "%d_", &version)
62 if err != nil {
63 log.Fatal(err)
64 }
65
66 // Apply migration if not already applied
67 if version > latestVersion {
68 fileData, _ := fs.ReadFile(migrationFiles, "migrations/"+f.Name())
69 _, err := db.Exec(string(fileData))
70 if err != nil {
71 log.Fatalf("Failed to apply migration %s: %v", f.Name(), err)
72 }
73 _, err = db.Exec(`INSERT INTO schema_migrations (version) VALUES (?)`, version)
74 if err != nil {
75 log.Fatalf("Failed to record migration version %d: %v", version, err)
76 }
77 fmt.Printf("Applied migration %s\n", f.Name())
78 }
79 }
80
81 return &DB{sql: db}
82}
83
84func (db *DB) GetUsernameBySessionToken(token string) string {
85 var username string
86 err := db.sql.QueryRow("SELECT username FROM user WHERE session_token=?", token).Scan(&username)
87 if err == sql.ErrNoRows {
88 return ""
89 }
90 if err != nil {
91 log.Fatal(err)
92 }
93 return username
94}
95
96func (db *DB) GetPassword(username string) string {
97 var password string
98 err := db.sql.QueryRow("SELECT password FROM user WHERE username=?", username).Scan(&password)
99 if err == sql.ErrNoRows {
100 return ""
101 }
102 if err != nil {
103 log.Fatal(err)
104 }
105 return password
106}
107
108func (db *DB) GetSessionToken(username string) (string, error) {
109 var result sql.NullString
110 err := db.sql.QueryRow("SELECT session_token FROM user WHERE username=?", username).Scan(&result)
111 if err == sql.ErrNoRows {
112 return "", nil
113 }
114 return result.String, err
115}
116
117func (db *DB) SetSessionToken(username string, token string) error {
118 _, err := db.sql.Exec("UPDATE user SET session_token=? WHERE username=?", token, username)
119 return err
120}
121
122func (db *DB) AddUser(username string, passwordHash string) error {
123 _, err := db.sql.Exec("INSERT INTO user (username, password) VALUES (?, ?)", username, passwordHash)
124 return err
125}
126
127func (db *DB) subscribe(uid int, fid int) {
128 var id int
129 err := db.sql.QueryRow("SELECT id FROM subscribe WHERE user_id=? AND feed_id=?", uid, fid).Scan(&id)
130 if err == sql.ErrNoRows {
131 _, err := db.sql.Exec("INSERT INTO subscribe (user_id, feed_id) VALUES (?, ?)", uid, fid)
132 if err != nil {
133 log.Fatal(err)
134 }
135 return
136 }
137 if err != nil {
138 log.Fatal(err)
139 }
140}
141
142func (db *DB) unsubscribeAll(uid int) {
143 _, err := db.sql.Exec("DELETE FROM subscribe WHERE user_id=?", uid)
144 if err != nil {
145 log.Fatal(err)
146 }
147}
148
149func (db *DB) UserExists(username string) bool {
150 var result string
151 err := db.sql.QueryRow("SELECT username FROM user WHERE username=?", username).Scan(&result)
152 if err == sql.ErrNoRows {
153 return false
154 }
155 if err != nil {
156 log.Fatal(err)
157 }
158 return true
159}
160
161func (db *DB) GetAllFeedURLs() []string {
162 // TODO: BAD SELECT STATEMENT!! SORRY :( --wesley
163 rows, err := db.sql.Query("SELECT url FROM feed")
164 if err != nil {
165 log.Fatal(err)
166 }
167 defer rows.Close()
168
169 var urls []string
170 for rows.Next() {
171 var url string
172 err = rows.Scan(&url)
173 if err != nil {
174 log.Fatal(err)
175 }
176 urls = append(urls, url)
177 }
178 return urls
179}
180
181func (db *DB) GetUserFeedURLs(username string) []string {
182 uid := db.GetUserID(username)
183
184 // this query returns sql rows representing the list of
185 // rss feed urls the user is subscribed to
186 rows, err := db.sql.Query(`
187 SELECT f.url
188 FROM feed f
189 JOIN subscribe s ON f.id = s.feed_id
190 JOIN user u ON s.user_id = u.id
191 WHERE u.id = ?`, uid)
192 if err == sql.ErrNoRows {
193 return []string{}
194 }
195 if err != nil {
196 log.Fatal(err)
197 }
198 defer rows.Close()
199
200 var urls []string
201 for rows.Next() {
202 var url string
203 err = rows.Scan(&url)
204 if err != nil {
205 log.Fatal(err)
206 }
207 urls = append(urls, url)
208 }
209 return urls
210}
211
212func (db *DB) GetUserSavedItems(username string) []SavedItem {
213 uid := db.GetUserID(username)
214
215 rows, err := db.sql.Query(`SELECT item_url, item_title, archive_url, created_at
216 FROM saved_item WHERE user_id = ?
217 ORDER BY created_at DESC`, uid)
218 if err == sql.ErrNoRows {
219 return []SavedItem{}
220 }
221 if err != nil {
222 log.Fatal(err)
223 }
224 defer rows.Close()
225
226 var savedItems []SavedItem
227 for rows.Next() {
228 var si SavedItem
229 err = rows.Scan(&si.ItemURL, &si.ItemTitle, &si.ArchiveURL, &si.CreatedAt)
230 if err != nil {
231 log.Fatal(err)
232 }
233 savedItems = append(savedItems, si)
234 }
235 return savedItems
236}
237
238func (db *DB) GetUserID(username string) int {
239 var uid int
240 err := db.sql.QueryRow("SELECT id FROM user WHERE username=?", username).Scan(&uid)
241 if err != nil {
242 log.Fatal(err)
243 }
244 return uid
245}
246
247func (db *DB) GetFeedID(feedURL string) int {
248 var fid int
249 err := db.sql.QueryRow("SELECT id FROM feed WHERE url=?", feedURL).Scan(&fid)
250 if err != nil {
251 log.Fatal(err)
252 }
253 return fid
254}
255
256// WriteFeed writes an rss feed to the database for permanent storage
257// if the given feed already exists, WriteFeed does nothing.
258func (db *DB) WriteFeed(url string) {
259 _, err := db.sql.Exec(`INSERT INTO feed(url) VALUES(?)
260 ON CONFLICT(url) DO NOTHING`, url)
261 if err != nil {
262 log.Fatal(err)
263 }
264}
265
266func (db *DB) WriteSavedItem(username string, item SavedItem) error {
267 uid := db.GetUserID(username)
268
269 _, err := db.sql.Exec(`
270 INSERT INTO saved_item(user_id, item_url, item_title, archive_url)
271 VALUES(?, ?, ?, ?)`, uid, item.ItemURL, item.ItemTitle, item.ArchiveURL)
272
273 return err
274}
275
276// WriteFeed writes an rss feed to the database for permanent storage
277// if the given feed already exists, WriteFeed does nothing.
278func (db *DB) SetFeedFetchError(url string, fetchErr string) error {
279 _, err := db.sql.Exec("UPDATE feed SET fetch_error=? WHERE url=?", fetchErr, url)
280 if err != nil {
281 return err
282 }
283 return nil
284}
285
286// WriteFeed writes an rss feed to the database for permanent storage
287// if the given feed already exists, WriteFeed does nothing.
288func (db *DB) GetFeedFetchError(url string) (string, error) {
289 var result sql.NullString
290 err := db.sql.QueryRow("SELECT fetch_error FROM feed WHERE url=?", url).Scan(&result)
291 if err != nil {
292 return "", err
293 }
294 if result.Valid {
295 return result.String, nil
296 }
297 return "", nil
298}
299
300func (db *DB) GetSubscriberCount(feedURL string) int {
301 var count int
302 err := db.sql.QueryRow(`
303 SELECT COUNT(s.user_id)
304 FROM subscribe s
305 JOIN feed f ON s.feed_id = f.id
306 WHERE f.url = ?
307 `, feedURL).Scan(&count)
308 if err != nil {
309 log.Fatal(err)
310 }
311 return count
312}
313
314func (db *DB) GetFeedIDAndExists(feedURL string) (int, bool) {
315 var fid int
316 err := db.sql.QueryRow("SELECT id FROM feed WHERE url=?", feedURL).Scan(&fid)
317 if err == sql.ErrNoRows {
318 return 0, false
319 }
320 if err != nil {
321 log.Fatal(err)
322 }
323 return fid, true
324}
325
326func (db *DB) BatchSubscribe(username string, feedURLs []string) error {
327 tx, err := db.sql.Begin()
328 if err != nil {
329 return err
330 }
331
332 defer func() {
333 if err != nil {
334 tx.Rollback()
335 }
336 }()
337
338 // first, unsub from everything
339 uid := db.GetUserID(username)
340 db.unsubscribeAll(uid)
341
342 // Add new subscriptions
343 for _, url := range feedURLs {
344 db.subscribe(uid, db.GetFeedID(url))
345 }
346
347 return tx.Commit()
348}
349
350func (db *DB) MarkItemRead(username string, itemURL string) error {
351 uid := db.GetUserID(username)
352 _, err := db.sql.Exec(`
353 INSERT INTO read_item(user_id, item_url)
354 VALUES(?, ?)
355 ON CONFLICT(user_id, item_url) DO NOTHING`, uid, itemURL)
356 return err
357}
358
359func (db *DB) GetUserReadItems(username string) map[string]bool {
360 uid := db.GetUserID(username)
361 rows, err := db.sql.Query("SELECT item_url FROM read_item WHERE user_id = ?", uid)
362 if err != nil {
363 log.Fatal(err)
364 }
365 defer rows.Close()
366
367 readItems := make(map[string]bool)
368 for rows.Next() {
369 var itemURL string
370 err = rows.Scan(&itemURL)
371 if err != nil {
372 log.Fatal(err)
373 }
374 readItems[itemURL] = true
375 }
376 return readItems
377}