package db import ( "crypto/rand" "database/sql" "fmt" "log" "os" "path/filepath" "strings" "time" _ "github.com/mattn/go-sqlite3" "github.com/oklog/ulid/v2" "github.com/limeleaf/diffdown/internal/model" ) var migrationsDir = "migrations" func SetMigrationsDir(dir string) { migrationsDir = dir } type DB struct { *sql.DB } func NewID() string { return ulid.MustNew(ulid.Timestamp(time.Now()), rand.Reader).String() } func Open(path string) (*DB, error) { sqlDB, err := sql.Open("sqlite3", path+"?_journal_mode=WAL&_foreign_keys=on&_busy_timeout=5000") if err != nil { return nil, fmt.Errorf("open db: %w", err) } sqlDB.SetMaxOpenConns(1) return &DB{sqlDB}, nil } func (db *DB) Migrate() error { if _, err := db.Exec(`CREATE TABLE IF NOT EXISTS schema_migrations (name TEXT PRIMARY KEY)`); err != nil { return fmt.Errorf("create schema_migrations: %w", err) } entries, err := os.ReadDir(migrationsDir) if err != nil { return fmt.Errorf("read migrations dir %s: %w", migrationsDir, err) } for _, e := range entries { if !strings.HasSuffix(e.Name(), ".sql") { continue } var applied string db.QueryRow(`SELECT name FROM schema_migrations WHERE name = ?`, e.Name()).Scan(&applied) if applied != "" { continue } data, err := os.ReadFile(filepath.Join(migrationsDir, e.Name())) if err != nil { return fmt.Errorf("read migration %s: %w", e.Name(), err) } if _, err := db.Exec(string(data)); err != nil { return fmt.Errorf("exec migration %s: %w", e.Name(), err) } db.Exec(`INSERT INTO schema_migrations (name) VALUES (?)`, e.Name()) log.Printf("Applied migration: %s", e.Name()) } return nil } // --- Users --- func (db *DB) CreateUser(u *model.User) error { if u.ID == "" { u.ID = NewID() } _, err := db.Exec( `INSERT INTO users (id, did) VALUES (?, ?)`, u.ID, u.DID, ) return err } func (db *DB) scanUser(row interface{ Scan(...interface{}) error }) (*model.User, error) { u := &model.User{} err := row.Scan(&u.ID, &u.DID) if err != nil { return nil, err } return u, nil } const userColumns = `id, did` func (db *DB) GetUserByID(id string) (*model.User, error) { return db.scanUser(db.QueryRow(`SELECT `+userColumns+` FROM users WHERE id = ?`, id)) } func (db *DB) GetUserByDID(did string) (*model.User, error) { return db.scanUser(db.QueryRow(`SELECT `+userColumns+` FROM users WHERE did = ?`, did)) } // --- ATProto Sessions --- func (db *DB) UpsertATProtoSession(s *model.ATProtoSession) error { _, err := db.Exec( `INSERT INTO atproto_sessions (user_id, did, pds_url, access_token, refresh_token, dpop_key_jwk, dpop_nonce, token_endpoint, expires_at, updated_at) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?) ON CONFLICT(user_id) DO UPDATE SET did = excluded.did, pds_url = excluded.pds_url, access_token = excluded.access_token, refresh_token = excluded.refresh_token, dpop_key_jwk = excluded.dpop_key_jwk, dpop_nonce = excluded.dpop_nonce, token_endpoint = excluded.token_endpoint, expires_at = excluded.expires_at, updated_at = excluded.updated_at`, s.UserID, s.DID, s.PDSURL, s.AccessToken, s.RefreshToken, s.DPoPKeyJWK, s.DPoPNonce, s.TokenEndpoint, s.ExpiresAt, time.Now(), ) return err } func (db *DB) GetATProtoSession(userID string) (*model.ATProtoSession, error) { s := &model.ATProtoSession{} err := db.QueryRow( `SELECT user_id, did, pds_url, access_token, refresh_token, dpop_key_jwk, dpop_nonce, token_endpoint, expires_at, updated_at FROM atproto_sessions WHERE user_id = ?`, userID, ).Scan(&s.UserID, &s.DID, &s.PDSURL, &s.AccessToken, &s.RefreshToken, &s.DPoPKeyJWK, &s.DPoPNonce, &s.TokenEndpoint, &s.ExpiresAt, &s.UpdatedAt) if err != nil { return nil, err } return s, nil } func (db *DB) UpdateATProtoTokens(userID, accessToken, refreshToken, nonce string, expiresAt time.Time) error { _, err := db.Exec( `UPDATE atproto_sessions SET access_token = ?, refresh_token = ?, dpop_nonce = ?, expires_at = ?, updated_at = ? WHERE user_id = ?`, accessToken, refreshToken, nonce, expiresAt, time.Now(), userID, ) return err } // --- Invites --- func (db *DB) CreateInvite(invite *model.Invite) error { _, err := db.Exec(` INSERT INTO invites (id, document_rkey, token, created_by_did, created_at, expires_at) VALUES (?, ?, ?, ?, ?, ?)`, invite.ID, invite.DocumentRKey, invite.Token, invite.CreatedBy, invite.CreatedAt, invite.ExpiresAt) return err } func (db *DB) GetInviteByToken(token string) (*model.Invite, error) { row := db.QueryRow(`SELECT id, document_rkey, token, created_by_did, created_at, expires_at FROM invites WHERE token = ?`, token) var invite model.Invite err := row.Scan(&invite.ID, &invite.DocumentRKey, &invite.Token, &invite.CreatedBy, &invite.CreatedAt, &invite.ExpiresAt) if err != nil { return nil, err } return &invite, nil } func (db *DB) DeleteInvite(token string) error { _, err := db.Exec(`DELETE FROM invites WHERE token = ?`, token) return err } // --- Document Steps (prosemirror-collab) --- type StepRow struct { Version int JSON string } func (db *DB) GetDocVersion(docRKey string) (int, error) { var v int err := db.QueryRow( `SELECT COALESCE(MAX(version), 0) FROM doc_steps WHERE doc_rkey = ?`, docRKey, ).Scan(&v) if err != nil { return 0, fmt.Errorf("GetDocVersion: %w", err) } return v, nil } func (db *DB) AppendSteps(docRKey string, clientVersion int, stepsJSON []string, clientID string) (int, error) { tx, err := db.Begin() if err != nil { return 0, fmt.Errorf("AppendSteps begin: %w", err) } defer tx.Rollback() var current int tx.QueryRow(`SELECT COALESCE(MAX(version), 0) FROM doc_steps WHERE doc_rkey = ?`, docRKey).Scan(¤t) if current != clientVersion { return 0, fmt.Errorf("version conflict: server=%d client=%d", current, clientVersion) } for i, stepJSON := range stepsJSON { version := clientVersion + i + 1 _, err := tx.Exec( `INSERT INTO doc_steps (doc_rkey, version, step_json, client_id) VALUES (?, ?, ?, ?)`, docRKey, version, stepJSON, clientID, ) if err != nil { return 0, fmt.Errorf("AppendSteps insert v%d: %w", version, err) } } if err := tx.Commit(); err != nil { return 0, fmt.Errorf("AppendSteps commit: %w", err) } return clientVersion + len(stepsJSON), nil } func (db *DB) GetStepsSince(docRKey string, sinceVersion int) ([]StepRow, error) { rows, err := db.Query( `SELECT version, step_json FROM doc_steps WHERE doc_rkey = ? AND version > ? ORDER BY version ASC`, docRKey, sinceVersion, ) if err != nil { return nil, fmt.Errorf("GetStepsSince: %w", err) } defer rows.Close() var result []StepRow for rows.Next() { var r StepRow if err := rows.Scan(&r.Version, &r.JSON); err != nil { return nil, err } result = append(result, r) } return result, rows.Err() } // --- Collaborations --- type CollaborationRow struct { CollaboratorDID string OwnerDID string DocumentRKey string AddedAt time.Time } func (db *DB) AddCollaboration(collabDID, ownerDID, rkey string) error { _, err := db.Exec( `INSERT INTO collaborations (collaborator_did, owner_did, document_rkey, added_at) VALUES (?, ?, ?, ?) ON CONFLICT DO NOTHING`, collabDID, ownerDID, rkey, time.Now(), ) return err } func (db *DB) GetCollaborations(collabDID string) ([]CollaborationRow, error) { rows, err := db.Query( `SELECT collaborator_did, owner_did, document_rkey, added_at FROM collaborations WHERE collaborator_did = ? ORDER BY added_at DESC`, collabDID, ) if err != nil { return nil, err } defer rows.Close() var result []CollaborationRow for rows.Next() { var r CollaborationRow if err := rows.Scan(&r.CollaboratorDID, &r.OwnerDID, &r.DocumentRKey, &r.AddedAt); err != nil { return nil, err } result = append(result, r) } return result, rows.Err() }