at main 8.7 kB view raw
1package sqlite 2 3import ( 4 "database/sql" 5 "embed" 6 "fmt" 7 "io/fs" 8 "log" 9 "strings" 10 "time" 11 12 _ "github.com/glebarez/go-sqlite" 13) 14 15//go:embed migrations/*.sql 16var migrationFiles embed.FS 17 18type DB struct { 19 sql *sql.DB 20} 21 22type SavedItem struct { 23 ArchiveURL string 24 CreatedAt time.Time 25 ItemTitle string 26 ItemURL string 27} 28 29// New opens a sqlite database, populates it with tables, and 30// returns a ready-to-use *sqlite.DB object which is used for 31// abstracting database queries. 32func New(path string) *DB { 33 db, err := sql.Open("sqlite", path) 34 if err != nil { 35 log.Fatal(err) 36 } 37 38 _, err = db.Exec("CREATE TABLE IF NOT EXISTS schema_migrations (version INTEGER PRIMARY KEY)") 39 if err != nil { 40 log.Fatal(err) 41 } 42 43 var latestVersion int 44 row := db.QueryRow("SELECT MAX(version) FROM schema_migrations") 45 err = row.Scan(&latestVersion) 46 if err != nil { 47 if strings.Contains(err.Error(), "converting NULL to int is unsupported") { 48 // assume that we're starting from ground zero 49 latestVersion = 0 50 } else { 51 log.Fatal(err) 52 } 53 } 54 55 files, err := fs.ReadDir(migrationFiles, "migrations") 56 if err != nil { 57 log.Fatal(err) 58 } 59 for _, f := range files { 60 var version int 61 _, err = fmt.Sscanf(f.Name(), "%d_", &version) 62 if err != nil { 63 log.Fatal(err) 64 } 65 66 // Apply migration if not already applied 67 if version > latestVersion { 68 fileData, _ := fs.ReadFile(migrationFiles, "migrations/"+f.Name()) 69 _, err := db.Exec(string(fileData)) 70 if err != nil { 71 log.Fatalf("Failed to apply migration %s: %v", f.Name(), err) 72 } 73 _, err = db.Exec(`INSERT INTO schema_migrations (version) VALUES (?)`, version) 74 if err != nil { 75 log.Fatalf("Failed to record migration version %d: %v", version, err) 76 } 77 fmt.Printf("Applied migration %s\n", f.Name()) 78 } 79 } 80 81 return &DB{sql: db} 82} 83 84func (db *DB) GetUsernameBySessionToken(token string) string { 85 var username string 86 err := db.sql.QueryRow("SELECT username FROM user WHERE session_token=?", token).Scan(&username) 87 if err == sql.ErrNoRows { 88 return "" 89 } 90 if err != nil { 91 log.Fatal(err) 92 } 93 return username 94} 95 96func (db *DB) GetPassword(username string) string { 97 var password string 98 err := db.sql.QueryRow("SELECT password FROM user WHERE username=?", username).Scan(&password) 99 if err == sql.ErrNoRows { 100 return "" 101 } 102 if err != nil { 103 log.Fatal(err) 104 } 105 return password 106} 107 108func (db *DB) GetSessionToken(username string) (string, error) { 109 var result sql.NullString 110 err := db.sql.QueryRow("SELECT session_token FROM user WHERE username=?", username).Scan(&result) 111 if err == sql.ErrNoRows { 112 return "", nil 113 } 114 return result.String, err 115} 116 117func (db *DB) SetSessionToken(username string, token string) error { 118 _, err := db.sql.Exec("UPDATE user SET session_token=? WHERE username=?", token, username) 119 return err 120} 121 122func (db *DB) AddUser(username string, passwordHash string) error { 123 _, err := db.sql.Exec("INSERT INTO user (username, password) VALUES (?, ?)", username, passwordHash) 124 return err 125} 126 127func (db *DB) subscribe(uid int, fid int) { 128 var id int 129 err := db.sql.QueryRow("SELECT id FROM subscribe WHERE user_id=? AND feed_id=?", uid, fid).Scan(&id) 130 if err == sql.ErrNoRows { 131 _, err := db.sql.Exec("INSERT INTO subscribe (user_id, feed_id) VALUES (?, ?)", uid, fid) 132 if err != nil { 133 log.Fatal(err) 134 } 135 return 136 } 137 if err != nil { 138 log.Fatal(err) 139 } 140} 141 142func (db *DB) unsubscribeAll(uid int) { 143 _, err := db.sql.Exec("DELETE FROM subscribe WHERE user_id=?", uid) 144 if err != nil { 145 log.Fatal(err) 146 } 147} 148 149func (db *DB) UserExists(username string) bool { 150 var result string 151 err := db.sql.QueryRow("SELECT username FROM user WHERE username=?", username).Scan(&result) 152 if err == sql.ErrNoRows { 153 return false 154 } 155 if err != nil { 156 log.Fatal(err) 157 } 158 return true 159} 160 161func (db *DB) GetAllFeedURLs() []string { 162 // TODO: BAD SELECT STATEMENT!! SORRY :( --wesley 163 rows, err := db.sql.Query("SELECT url FROM feed") 164 if err != nil { 165 log.Fatal(err) 166 } 167 defer rows.Close() 168 169 var urls []string 170 for rows.Next() { 171 var url string 172 err = rows.Scan(&url) 173 if err != nil { 174 log.Fatal(err) 175 } 176 urls = append(urls, url) 177 } 178 return urls 179} 180 181func (db *DB) GetUserFeedURLs(username string) []string { 182 uid := db.GetUserID(username) 183 184 // this query returns sql rows representing the list of 185 // rss feed urls the user is subscribed to 186 rows, err := db.sql.Query(` 187 SELECT f.url 188 FROM feed f 189 JOIN subscribe s ON f.id = s.feed_id 190 JOIN user u ON s.user_id = u.id 191 WHERE u.id = ?`, uid) 192 if err == sql.ErrNoRows { 193 return []string{} 194 } 195 if err != nil { 196 log.Fatal(err) 197 } 198 defer rows.Close() 199 200 var urls []string 201 for rows.Next() { 202 var url string 203 err = rows.Scan(&url) 204 if err != nil { 205 log.Fatal(err) 206 } 207 urls = append(urls, url) 208 } 209 return urls 210} 211 212func (db *DB) GetUserSavedItems(username string) []SavedItem { 213 uid := db.GetUserID(username) 214 215 rows, err := db.sql.Query(`SELECT item_url, item_title, archive_url, created_at 216 FROM saved_item WHERE user_id = ? 217 ORDER BY created_at DESC`, uid) 218 if err == sql.ErrNoRows { 219 return []SavedItem{} 220 } 221 if err != nil { 222 log.Fatal(err) 223 } 224 defer rows.Close() 225 226 var savedItems []SavedItem 227 for rows.Next() { 228 var si SavedItem 229 err = rows.Scan(&si.ItemURL, &si.ItemTitle, &si.ArchiveURL, &si.CreatedAt) 230 if err != nil { 231 log.Fatal(err) 232 } 233 savedItems = append(savedItems, si) 234 } 235 return savedItems 236} 237 238func (db *DB) GetUserID(username string) int { 239 var uid int 240 err := db.sql.QueryRow("SELECT id FROM user WHERE username=?", username).Scan(&uid) 241 if err != nil { 242 log.Fatal(err) 243 } 244 return uid 245} 246 247func (db *DB) GetFeedID(feedURL string) int { 248 var fid int 249 err := db.sql.QueryRow("SELECT id FROM feed WHERE url=?", feedURL).Scan(&fid) 250 if err != nil { 251 log.Fatal(err) 252 } 253 return fid 254} 255 256// WriteFeed writes an rss feed to the database for permanent storage 257// if the given feed already exists, WriteFeed does nothing. 258func (db *DB) WriteFeed(url string) { 259 _, err := db.sql.Exec(`INSERT INTO feed(url) VALUES(?) 260 ON CONFLICT(url) DO NOTHING`, url) 261 if err != nil { 262 log.Fatal(err) 263 } 264} 265 266func (db *DB) WriteSavedItem(username string, item SavedItem) error { 267 uid := db.GetUserID(username) 268 269 _, err := db.sql.Exec(` 270 INSERT INTO saved_item(user_id, item_url, item_title, archive_url) 271 VALUES(?, ?, ?, ?)`, uid, item.ItemURL, item.ItemTitle, item.ArchiveURL) 272 273 return err 274} 275 276// WriteFeed writes an rss feed to the database for permanent storage 277// if the given feed already exists, WriteFeed does nothing. 278func (db *DB) SetFeedFetchError(url string, fetchErr string) error { 279 _, err := db.sql.Exec("UPDATE feed SET fetch_error=? WHERE url=?", fetchErr, url) 280 if err != nil { 281 return err 282 } 283 return nil 284} 285 286// WriteFeed writes an rss feed to the database for permanent storage 287// if the given feed already exists, WriteFeed does nothing. 288func (db *DB) GetFeedFetchError(url string) (string, error) { 289 var result sql.NullString 290 err := db.sql.QueryRow("SELECT fetch_error FROM feed WHERE url=?", url).Scan(&result) 291 if err != nil { 292 return "", err 293 } 294 if result.Valid { 295 return result.String, nil 296 } 297 return "", nil 298} 299 300func (db *DB) GetSubscriberCount(feedURL string) int { 301 var count int 302 err := db.sql.QueryRow(` 303 SELECT COUNT(s.user_id) 304 FROM subscribe s 305 JOIN feed f ON s.feed_id = f.id 306 WHERE f.url = ? 307 `, feedURL).Scan(&count) 308 if err != nil { 309 log.Fatal(err) 310 } 311 return count 312} 313 314func (db *DB) GetFeedIDAndExists(feedURL string) (int, bool) { 315 var fid int 316 err := db.sql.QueryRow("SELECT id FROM feed WHERE url=?", feedURL).Scan(&fid) 317 if err == sql.ErrNoRows { 318 return 0, false 319 } 320 if err != nil { 321 log.Fatal(err) 322 } 323 return fid, true 324} 325 326func (db *DB) BatchSubscribe(username string, feedURLs []string) error { 327 tx, err := db.sql.Begin() 328 if err != nil { 329 return err 330 } 331 332 defer func() { 333 if err != nil { 334 tx.Rollback() 335 } 336 }() 337 338 // first, unsub from everything 339 uid := db.GetUserID(username) 340 db.unsubscribeAll(uid) 341 342 // Add new subscriptions 343 for _, url := range feedURLs { 344 db.subscribe(uid, db.GetFeedID(url)) 345 } 346 347 return tx.Commit() 348} 349 350func (db *DB) MarkItemRead(username string, itemURL string) error { 351 uid := db.GetUserID(username) 352 _, err := db.sql.Exec(` 353 INSERT INTO read_item(user_id, item_url) 354 VALUES(?, ?) 355 ON CONFLICT(user_id, item_url) DO NOTHING`, uid, itemURL) 356 return err 357} 358 359func (db *DB) GetUserReadItems(username string) map[string]bool { 360 uid := db.GetUserID(username) 361 rows, err := db.sql.Query("SELECT item_url FROM read_item WHERE user_id = ?", uid) 362 if err != nil { 363 log.Fatal(err) 364 } 365 defer rows.Close() 366 367 readItems := make(map[string]bool) 368 for rows.Next() { 369 var itemURL string 370 err = rows.Scan(&itemURL) 371 if err != nil { 372 log.Fatal(err) 373 } 374 readItems[itemURL] = true 375 } 376 return readItems 377}