Diffdown is a real-time collaborative Markdown editor/previewer built on the AT Protocol
diffdown.com
1package db
2
3import (
4 "crypto/rand"
5 "database/sql"
6 "fmt"
7 "log"
8 "os"
9 "path/filepath"
10 "strings"
11 "time"
12
13 _ "github.com/mattn/go-sqlite3"
14 "github.com/oklog/ulid/v2"
15
16 "github.com/limeleaf/diffdown/internal/model"
17)
18
// migrationsDir is the directory Migrate scans for *.sql files. A relative
// value is resolved against the process working directory by os.ReadDir.
var migrationsDir = "migrations"

// SetMigrationsDir overrides the directory that Migrate reads migration files
// from. The assignment is not synchronized, so call it before Migrate runs.
func SetMigrationsDir(dir string) {
	migrationsDir = dir
}
24
// DB is the package's database handle. Every persistence method in this
// package is defined on *DB. Because *sql.DB is embedded, callers can also
// use Exec, Query, QueryRow, Begin, Close, etc. directly.
type DB struct {
	*sql.DB
}
28
29func NewID() string {
30 return ulid.MustNew(ulid.Timestamp(time.Now()), rand.Reader).String()
31}
32
33func Open(path string) (*DB, error) {
34 sqlDB, err := sql.Open("sqlite3", path+"?_journal_mode=WAL&_foreign_keys=on&_busy_timeout=5000")
35 if err != nil {
36 return nil, fmt.Errorf("open db: %w", err)
37 }
38 sqlDB.SetMaxOpenConns(1)
39 return &DB{sqlDB}, nil
40}
41
42func (db *DB) Migrate() error {
43 if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (name TEXT PRIMARY KEY)`); err != nil {
44 return fmt.Errorf("create schema_migrations: %w", err)
45 }
46
47 entries, err := os.ReadDir(migrationsDir)
48 if err != nil {
49 return fmt.Errorf("read migrations dir %s: %w", migrationsDir, err)
50 }
51 for _, e := range entries {
52 if !strings.HasSuffix(e.Name(), ".sql") {
53 continue
54 }
55 var applied string
56 db.QueryRow(`SELECT name FROM schema_migrations WHERE name = ?`, e.Name()).Scan(&applied)
57 if applied != "" {
58 continue
59 }
60 data, err := os.ReadFile(filepath.Join(migrationsDir, e.Name()))
61 if err != nil {
62 return fmt.Errorf("read migration %s: %w", e.Name(), err)
63 }
64 if _, err := db.Exec(string(data)); err != nil {
65 return fmt.Errorf("exec migration %s: %w", e.Name(), err)
66 }
67 db.Exec(`INSERT INTO schema_migrations (name) VALUES (?)`, e.Name())
68 log.Printf("Applied migration: %s", e.Name())
69 }
70 return nil
71}
72
73// --- Users ---
74
75func (db *DB) CreateUser(u *model.User) error {
76 if u.ID == "" {
77 u.ID = NewID()
78 }
79 _, err := db.Exec(
80 `INSERT INTO users (id, did) VALUES (?, ?)`,
81 u.ID, u.DID,
82 )
83 return err
84}
85
86func (db *DB) scanUser(row interface{ Scan(...interface{}) error }) (*model.User, error) {
87 u := &model.User{}
88 err := row.Scan(&u.ID, &u.DID)
89 if err != nil {
90 return nil, err
91 }
92 return u, nil
93}
94
// userColumns is the users column list, in the order scanUser expects.
const userColumns = "id, did"
96
97func (db *DB) GetUserByID(id string) (*model.User, error) {
98 return db.scanUser(db.QueryRow(`SELECT `+userColumns+` FROM users WHERE id = ?`, id))
99}
100
101func (db *DB) GetUserByDID(did string) (*model.User, error) {
102 return db.scanUser(db.QueryRow(`SELECT `+userColumns+` FROM users WHERE did = ?`, did))
103}
104
105// --- ATProto Sessions ---
106
107func (db *DB) UpsertATProtoSession(s *model.ATProtoSession) error {
108 _, err := db.Exec(
109 `INSERT INTO atproto_sessions (user_id, did, pds_url, access_token, refresh_token, dpop_key_jwk, dpop_nonce, token_endpoint, expires_at, updated_at)
110 VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)
111 ON CONFLICT(user_id) DO UPDATE SET
112 did = excluded.did,
113 pds_url = excluded.pds_url,
114 access_token = excluded.access_token,
115 refresh_token = excluded.refresh_token,
116 dpop_key_jwk = excluded.dpop_key_jwk,
117 dpop_nonce = excluded.dpop_nonce,
118 token_endpoint = excluded.token_endpoint,
119 expires_at = excluded.expires_at,
120 updated_at = excluded.updated_at`,
121 s.UserID, s.DID, s.PDSURL, s.AccessToken, s.RefreshToken, s.DPoPKeyJWK, s.DPoPNonce, s.TokenEndpoint, s.ExpiresAt, time.Now(),
122 )
123 return err
124}
125
126func (db *DB) GetATProtoSession(userID string) (*model.ATProtoSession, error) {
127 s := &model.ATProtoSession{}
128 err := db.QueryRow(
129 `SELECT user_id, did, pds_url, access_token, refresh_token, dpop_key_jwk, dpop_nonce, token_endpoint, expires_at, updated_at
130 FROM atproto_sessions WHERE user_id = ?`, userID,
131 ).Scan(&s.UserID, &s.DID, &s.PDSURL, &s.AccessToken, &s.RefreshToken, &s.DPoPKeyJWK, &s.DPoPNonce, &s.TokenEndpoint, &s.ExpiresAt, &s.UpdatedAt)
132 if err != nil {
133 return nil, err
134 }
135 return s, nil
136}
137
138func (db *DB) UpdateATProtoTokens(userID, accessToken, refreshToken, nonce string, expiresAt time.Time) error {
139 _, err := db.Exec(
140 `UPDATE atproto_sessions SET access_token = ?, refresh_token = ?, dpop_nonce = ?, expires_at = ?, updated_at = ? WHERE user_id = ?`,
141 accessToken, refreshToken, nonce, expiresAt, time.Now(), userID,
142 )
143 return err
144}
145
146// --- Invites ---
147
148func (db *DB) CreateInvite(invite *model.Invite) error {
149 _, err := db.Exec(`
150 INSERT INTO invites (id, document_rkey, token, created_by_did, created_at, expires_at)
151 VALUES (?, ?, ?, ?, ?, ?)`,
152 invite.ID, invite.DocumentRKey, invite.Token, invite.CreatedBy, invite.CreatedAt, invite.ExpiresAt)
153 return err
154}
155
156func (db *DB) GetInviteByToken(token string) (*model.Invite, error) {
157 row := db.QueryRow(`SELECT id, document_rkey, token, created_by_did, created_at, expires_at FROM invites WHERE token = ?`, token)
158 var invite model.Invite
159 err := row.Scan(&invite.ID, &invite.DocumentRKey, &invite.Token, &invite.CreatedBy, &invite.CreatedAt, &invite.ExpiresAt)
160 if err != nil {
161 return nil, err
162 }
163 return &invite, nil
164}
165
166func (db *DB) DeleteInvite(token string) error {
167 _, err := db.Exec(`DELETE FROM invites WHERE token = ?`, token)
168 return err
169}
170
171// --- Document Steps (prosemirror-collab) ---
172
// StepRow is one step read back from the doc_steps table.
type StepRow struct {
	// Version is the document version this step produced (version column).
	Version int
	// JSON is the serialized step exactly as stored (step_json column).
	JSON string
}
177
178func (db *DB) GetDocVersion(docRKey string) (int, error) {
179 var v int
180 err := db.QueryRow(
181 `SELECT COALESCE(MAX(version), 0) FROM doc_steps WHERE doc_rkey = ?`, docRKey,
182 ).Scan(&v)
183 if err != nil {
184 return 0, fmt.Errorf("GetDocVersion: %w", err)
185 }
186 return v, nil
187}
188
189func (db *DB) AppendSteps(docRKey string, clientVersion int, stepsJSON []string, clientID string) (int, error) {
190 tx, err := db.Begin()
191 if err != nil {
192 return 0, fmt.Errorf("AppendSteps begin: %w", err)
193 }
194 defer tx.Rollback()
195
196 var current int
197 tx.QueryRow(`SELECT COALESCE(MAX(version), 0) FROM doc_steps WHERE doc_rkey = ?`, docRKey).Scan(¤t)
198 if current != clientVersion {
199 return 0, fmt.Errorf("version conflict: server=%d client=%d", current, clientVersion)
200 }
201
202 for i, stepJSON := range stepsJSON {
203 version := clientVersion + i + 1
204 _, err := tx.Exec(
205 `INSERT INTO doc_steps (doc_rkey, version, step_json, client_id) VALUES (?, ?, ?, ?)`,
206 docRKey, version, stepJSON, clientID,
207 )
208 if err != nil {
209 return 0, fmt.Errorf("AppendSteps insert v%d: %w", version, err)
210 }
211 }
212
213 if err := tx.Commit(); err != nil {
214 return 0, fmt.Errorf("AppendSteps commit: %w", err)
215 }
216 return clientVersion + len(stepsJSON), nil
217}
218
219func (db *DB) GetStepsSince(docRKey string, sinceVersion int) ([]StepRow, error) {
220 rows, err := db.Query(
221 `SELECT version, step_json FROM doc_steps WHERE doc_rkey = ? AND version > ? ORDER BY version ASC`,
222 docRKey, sinceVersion,
223 )
224 if err != nil {
225 return nil, fmt.Errorf("GetStepsSince: %w", err)
226 }
227 defer rows.Close()
228 var result []StepRow
229 for rows.Next() {
230 var r StepRow
231 if err := rows.Scan(&r.Version, &r.JSON); err != nil {
232 return nil, err
233 }
234 result = append(result, r)
235 }
236 return result, rows.Err()
237}
238
239// --- Collaborations ---
240
// CollaborationRow is one row of the collaborations table. It records that
// CollaboratorDID was given access to OwnerDID's document DocumentRKey at
// AddedAt.
type CollaborationRow struct {
	CollaboratorDID string    // collaborator_did column
	OwnerDID        string    // owner_did column
	DocumentRKey    string    // document_rkey column
	AddedAt         time.Time // added_at column
}
247
248func (db *DB) AddCollaboration(collabDID, ownerDID, rkey string) error {
249 _, err := db.Exec(
250 `INSERT INTO collaborations (collaborator_did, owner_did, document_rkey, added_at)
251 VALUES (?, ?, ?, ?)
252 ON CONFLICT DO NOTHING`,
253 collabDID, ownerDID, rkey, time.Now(),
254 )
255 return err
256}
257
258func (db *DB) GetCollaborations(collabDID string) ([]CollaborationRow, error) {
259 rows, err := db.Query(
260 `SELECT collaborator_did, owner_did, document_rkey, added_at
261 FROM collaborations WHERE collaborator_did = ? ORDER BY added_at DESC`,
262 collabDID,
263 )
264 if err != nil {
265 return nil, err
266 }
267 defer rows.Close()
268 var result []CollaborationRow
269 for rows.Next() {
270 var r CollaborationRow
271 if err := rows.Scan(&r.CollaboratorDID, &r.OwnerDID, &r.DocumentRKey, &r.AddedAt); err != nil {
272 return nil, err
273 }
274 result = append(result, r)
275 }
276 return result, rows.Err()
277}