Diffdown is a real-time collaborative Markdown editor and previewer built on the AT Protocol (diffdown.com).
Branch: main · 277 lines · 7.9 kB (view raw)
1package db 2 3import ( 4 "crypto/rand" 5 "database/sql" 6 "fmt" 7 "log" 8 "os" 9 "path/filepath" 10 "strings" 11 "time" 12 13 _ "github.com/mattn/go-sqlite3" 14 "github.com/oklog/ulid/v2" 15 16 "github.com/limeleaf/diffdown/internal/model" 17) 18 19var migrationsDir = "migrations" 20 21func SetMigrationsDir(dir string) { 22 migrationsDir = dir 23} 24 25type DB struct { 26 *sql.DB 27} 28 29func NewID() string { 30 return ulid.MustNew(ulid.Timestamp(time.Now()), rand.Reader).String() 31} 32 33func Open(path string) (*DB, error) { 34 sqlDB, err := sql.Open("sqlite3", path+"?_journal_mode=WAL&_foreign_keys=on&_busy_timeout=5000") 35 if err != nil { 36 return nil, fmt.Errorf("open db: %w", err) 37 } 38 sqlDB.SetMaxOpenConns(1) 39 return &DB{sqlDB}, nil 40} 41 42func (db *DB) Migrate() error { 43 if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (name TEXT PRIMARY KEY)`); err != nil { 44 return fmt.Errorf("create schema_migrations: %w", err) 45 } 46 47 entries, err := os.ReadDir(migrationsDir) 48 if err != nil { 49 return fmt.Errorf("read migrations dir %s: %w", migrationsDir, err) 50 } 51 for _, e := range entries { 52 if !strings.HasSuffix(e.Name(), ".sql") { 53 continue 54 } 55 var applied string 56 db.QueryRow(`SELECT name FROM schema_migrations WHERE name = ?`, e.Name()).Scan(&applied) 57 if applied != "" { 58 continue 59 } 60 data, err := os.ReadFile(filepath.Join(migrationsDir, e.Name())) 61 if err != nil { 62 return fmt.Errorf("read migration %s: %w", e.Name(), err) 63 } 64 if _, err := db.Exec(string(data)); err != nil { 65 return fmt.Errorf("exec migration %s: %w", e.Name(), err) 66 } 67 db.Exec(`INSERT INTO schema_migrations (name) VALUES (?)`, e.Name()) 68 log.Printf("Applied migration: %s", e.Name()) 69 } 70 return nil 71} 72 73// --- Users --- 74 75func (db *DB) CreateUser(u *model.User) error { 76 if u.ID == "" { 77 u.ID = NewID() 78 } 79 _, err := db.Exec( 80 `INSERT INTO users (id, did) VALUES (?, ?)`, 81 u.ID, u.DID, 82 ) 83 return err 
84} 85 86func (db *DB) scanUser(row interface{ Scan(...interface{}) error }) (*model.User, error) { 87 u := &model.User{} 88 err := row.Scan(&u.ID, &u.DID) 89 if err != nil { 90 return nil, err 91 } 92 return u, nil 93} 94 95const userColumns = `id, did` 96 97func (db *DB) GetUserByID(id string) (*model.User, error) { 98 return db.scanUser(db.QueryRow(`SELECT `+userColumns+` FROM users WHERE id = ?`, id)) 99} 100 101func (db *DB) GetUserByDID(did string) (*model.User, error) { 102 return db.scanUser(db.QueryRow(`SELECT `+userColumns+` FROM users WHERE did = ?`, did)) 103} 104 105// --- ATProto Sessions --- 106 107func (db *DB) UpsertATProtoSession(s *model.ATProtoSession) error { 108 _, err := db.Exec( 109 `INSERT INTO atproto_sessions (user_id, did, pds_url, access_token, refresh_token, dpop_key_jwk, dpop_nonce, token_endpoint, expires_at, updated_at) 110 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) 111 ON CONFLICT(user_id) DO UPDATE SET 112 did = excluded.did, 113 pds_url = excluded.pds_url, 114 access_token = excluded.access_token, 115 refresh_token = excluded.refresh_token, 116 dpop_key_jwk = excluded.dpop_key_jwk, 117 dpop_nonce = excluded.dpop_nonce, 118 token_endpoint = excluded.token_endpoint, 119 expires_at = excluded.expires_at, 120 updated_at = excluded.updated_at`, 121 s.UserID, s.DID, s.PDSURL, s.AccessToken, s.RefreshToken, s.DPoPKeyJWK, s.DPoPNonce, s.TokenEndpoint, s.ExpiresAt, time.Now(), 122 ) 123 return err 124} 125 126func (db *DB) GetATProtoSession(userID string) (*model.ATProtoSession, error) { 127 s := &model.ATProtoSession{} 128 err := db.QueryRow( 129 `SELECT user_id, did, pds_url, access_token, refresh_token, dpop_key_jwk, dpop_nonce, token_endpoint, expires_at, updated_at 130 FROM atproto_sessions WHERE user_id = ?`, userID, 131 ).Scan(&s.UserID, &s.DID, &s.PDSURL, &s.AccessToken, &s.RefreshToken, &s.DPoPKeyJWK, &s.DPoPNonce, &s.TokenEndpoint, &s.ExpiresAt, &s.UpdatedAt) 132 if err != nil { 133 return nil, err 134 } 135 return s, nil 136} 137 
138func (db *DB) UpdateATProtoTokens(userID, accessToken, refreshToken, nonce string, expiresAt time.Time) error { 139 _, err := db.Exec( 140 `UPDATE atproto_sessions SET access_token = ?, refresh_token = ?, dpop_nonce = ?, expires_at = ?, updated_at = ? WHERE user_id = ?`, 141 accessToken, refreshToken, nonce, expiresAt, time.Now(), userID, 142 ) 143 return err 144} 145 146// --- Invites --- 147 148func (db *DB) CreateInvite(invite *model.Invite) error { 149 _, err := db.Exec(` 150 INSERT INTO invites (id, document_rkey, token, created_by_did, created_at, expires_at) 151 VALUES (?, ?, ?, ?, ?, ?)`, 152 invite.ID, invite.DocumentRKey, invite.Token, invite.CreatedBy, invite.CreatedAt, invite.ExpiresAt) 153 return err 154} 155 156func (db *DB) GetInviteByToken(token string) (*model.Invite, error) { 157 row := db.QueryRow(`SELECT id, document_rkey, token, created_by_did, created_at, expires_at FROM invites WHERE token = ?`, token) 158 var invite model.Invite 159 err := row.Scan(&invite.ID, &invite.DocumentRKey, &invite.Token, &invite.CreatedBy, &invite.CreatedAt, &invite.ExpiresAt) 160 if err != nil { 161 return nil, err 162 } 163 return &invite, nil 164} 165 166func (db *DB) DeleteInvite(token string) error { 167 _, err := db.Exec(`DELETE FROM invites WHERE token = ?`, token) 168 return err 169} 170 171// --- Document Steps (prosemirror-collab) --- 172 173type StepRow struct { 174 Version int 175 JSON string 176} 177 178func (db *DB) GetDocVersion(docRKey string) (int, error) { 179 var v int 180 err := db.QueryRow( 181 `SELECT COALESCE(MAX(version), 0) FROM doc_steps WHERE doc_rkey = ?`, docRKey, 182 ).Scan(&v) 183 if err != nil { 184 return 0, fmt.Errorf("GetDocVersion: %w", err) 185 } 186 return v, nil 187} 188 189func (db *DB) AppendSteps(docRKey string, clientVersion int, stepsJSON []string, clientID string) (int, error) { 190 tx, err := db.Begin() 191 if err != nil { 192 return 0, fmt.Errorf("AppendSteps begin: %w", err) 193 } 194 defer tx.Rollback() 195 196 var 
current int 197 tx.QueryRow(`SELECT COALESCE(MAX(version), 0) FROM doc_steps WHERE doc_rkey = ?`, docRKey).Scan(&current) 198 if current != clientVersion { 199 return 0, fmt.Errorf("version conflict: server=%d client=%d", current, clientVersion) 200 } 201 202 for i, stepJSON := range stepsJSON { 203 version := clientVersion + i + 1 204 _, err := tx.Exec( 205 `INSERT INTO doc_steps (doc_rkey, version, step_json, client_id) VALUES (?, ?, ?, ?)`, 206 docRKey, version, stepJSON, clientID, 207 ) 208 if err != nil { 209 return 0, fmt.Errorf("AppendSteps insert v%d: %w", version, err) 210 } 211 } 212 213 if err := tx.Commit(); err != nil { 214 return 0, fmt.Errorf("AppendSteps commit: %w", err) 215 } 216 return clientVersion + len(stepsJSON), nil 217} 218 219func (db *DB) GetStepsSince(docRKey string, sinceVersion int) ([]StepRow, error) { 220 rows, err := db.Query( 221 `SELECT version, step_json FROM doc_steps WHERE doc_rkey = ? AND version > ? ORDER BY version ASC`, 222 docRKey, sinceVersion, 223 ) 224 if err != nil { 225 return nil, fmt.Errorf("GetStepsSince: %w", err) 226 } 227 defer rows.Close() 228 var result []StepRow 229 for rows.Next() { 230 var r StepRow 231 if err := rows.Scan(&r.Version, &r.JSON); err != nil { 232 return nil, err 233 } 234 result = append(result, r) 235 } 236 return result, rows.Err() 237} 238 239// --- Collaborations --- 240 241type CollaborationRow struct { 242 CollaboratorDID string 243 OwnerDID string 244 DocumentRKey string 245 AddedAt time.Time 246} 247 248func (db *DB) AddCollaboration(collabDID, ownerDID, rkey string) error { 249 _, err := db.Exec( 250 `INSERT INTO collaborations (collaborator_did, owner_did, document_rkey, added_at) 251 VALUES (?, ?, ?, ?) 
252 ON CONFLICT DO NOTHING`, 253 collabDID, ownerDID, rkey, time.Now(), 254 ) 255 return err 256} 257 258func (db *DB) GetCollaborations(collabDID string) ([]CollaborationRow, error) { 259 rows, err := db.Query( 260 `SELECT collaborator_did, owner_did, document_rkey, added_at 261 FROM collaborations WHERE collaborator_did = ? ORDER BY added_at DESC`, 262 collabDID, 263 ) 264 if err != nil { 265 return nil, err 266 } 267 defer rows.Close() 268 var result []CollaborationRow 269 for rows.Next() { 270 var r CollaborationRow 271 if err := rows.Scan(&r.CollaboratorDID, &r.OwnerDID, &r.DocumentRKey, &r.AddedAt); err != nil { 272 return nil, err 273 } 274 result = append(result, r) 275 } 276 return result, rows.Err() 277}