Devs vouch for devs on AT Protocol. That's about it.
atvouch.dev
1package main
2
import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"os"
	"path/filepath"

	"github.com/bluesky-social/indigo/atproto/auth/oauth"
	"github.com/bluesky-social/indigo/atproto/syntax"
	_ "github.com/mattn/go-sqlite3"
)
15
16type Store struct {
17 db *sql.DB
18}
19
20var _ oauth.ClientAuthStore = &Store{}
21
// NewStore opens the SQLite state database at
// <os.UserConfigDir()>/atvouch/state.db, creating the directory (mode 0700)
// if needed. It then applies the connection PRAGMAs and runs any pending
// schema migrations.
//
// The connection pool is capped at one connection. Apart from
// journal_mode=WAL, which SQLite stores in the database file, the PRAGMAs
// below are per-connection settings. Executing them through database/sql
// only configures whichever pooled connection ran the statement. A second
// pooled connection would then silently run without busy_timeout or
// foreign_keys. With a single connection, every query sees the configured
// settings.
//
// NOTE(review): a single connection means a caller must not hold a *sql.Rows
// open while issuing another statement on the same Store, or it will block.
// The Store methods visible in this file never do that.
func NewStore() (*Store, error) {
	configDir, err := os.UserConfigDir()
	if err != nil {
		return nil, fmt.Errorf("getting config dir: %w", err)
	}
	dir := filepath.Join(configDir, "atvouch")
	if err := os.MkdirAll(dir, 0700); err != nil {
		return nil, fmt.Errorf("creating state dir %q: %w", dir, err)
	}

	dbPath := filepath.Join(dir, "state.db")
	db, err := sql.Open("sqlite3", dbPath)
	if err != nil {
		return nil, fmt.Errorf("opening %q: %w", dbPath, err)
	}
	// Must be set before the first statement so the PRAGMAs below are applied
	// to the only connection the pool will ever hand out.
	db.SetMaxOpenConns(1)

	for _, pragma := range []string{
		"PRAGMA journal_mode = WAL",
		"PRAGMA busy_timeout = 5000",
		"PRAGMA synchronous = NORMAL",
		"PRAGMA cache_size = -6000",
		"PRAGMA foreign_keys = true",
		"PRAGMA temp_store = memory",
	} {
		if _, err := db.Exec(pragma); err != nil {
			db.Close()
			return nil, fmt.Errorf("setting pragma: %w", err)
		}
	}

	if err := migrate(db); err != nil {
		db.Close()
		return nil, err
	}

	return &Store{db: db}, nil
}
58
// migrations lists the schema changes applied by migrate, in order. Each
// entry's version is recorded in migration_log once its SQL has run, so
// entries must never be edited or reordered after release; append new ones.
var migrations = []struct {
	version int
	sql     string
}{
	{
		version: 1,
		sql: `
CREATE TABLE sessions (
	did TEXT NOT NULL,
	session_id TEXT NOT NULL,
	data TEXT NOT NULL,
	PRIMARY KEY (did, session_id)
) STRICT;

CREATE TABLE auth_requests (
	state TEXT NOT NULL PRIMARY KEY,
	data TEXT NOT NULL
) STRICT;

CREATE TABLE active_session (
	id INTEGER NOT NULL PRIMARY KEY CHECK (id = 1),
	did TEXT NOT NULL,
	session_id TEXT NOT NULL
) STRICT;
`,
	},
}
83
// migrate creates the migration_log table if it does not exist. It then
// applies, in slice order, every entry of migrations whose version is not
// yet recorded there.
func migrate(db *sql.DB) error {
	_, err := db.Exec(`
CREATE TABLE IF NOT EXISTS migration_log (
	version INTEGER NOT NULL PRIMARY KEY,
	applied_at TEXT NOT NULL DEFAULT (datetime('now'))
) STRICT;
`)
	if err != nil {
		return fmt.Errorf("creating migration_log: %w", err)
	}

	for _, m := range migrations {
		if err := applyMigration(db, m.version, m.sql); err != nil {
			return err
		}
	}

	return nil
}

// applyMigration runs one migration's SQL and records its version in
// migration_log within a single transaction. It does nothing if
// migration_log already contains that version.
func applyMigration(db *sql.DB, version int, stmt string) error {
	var exists int
	err := db.QueryRow("SELECT 1 FROM migration_log WHERE version = ?", version).Scan(&exists)
	if err == nil {
		return nil // already applied
	}
	if !errors.Is(err, sql.ErrNoRows) {
		return fmt.Errorf("checking migration %d: %w", version, err)
	}

	tx, err := db.Begin()
	if err != nil {
		return fmt.Errorf("beginning migration %d: %w", version, err)
	}
	// Undo everything on any early return. After a successful Commit this is
	// a no-op that returns sql.ErrTxDone, which is deliberately ignored.
	defer tx.Rollback()

	if _, err := tx.Exec(stmt); err != nil {
		return fmt.Errorf("migration %d: %w", version, err)
	}
	if _, err := tx.Exec("INSERT INTO migration_log (version) VALUES (?)", version); err != nil {
		return fmt.Errorf("recording migration %d: %w", version, err)
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("committing migration %d: %w", version, err)
	}
	return nil
}
124
125func (s *Store) Close() error {
126 return s.db.Close()
127}
128
129func (s *Store) SetActive(did syntax.DID, sessionID string) error {
130 _, err := s.db.Exec(
131 `INSERT INTO active_session (id, did, session_id) VALUES (1, ?, ?)
132 ON CONFLICT (id) DO UPDATE SET did = excluded.did, session_id = excluded.session_id`,
133 did.String(), sessionID,
134 )
135 return err
136}
137
138type activeSession struct {
139 DID syntax.DID
140 SessionID string
141}
142
// GetActive returns the session most recently recorded by SetActive.
//
// If no active session has been recorded, it returns an error telling the
// user to run 'atvouch login'. Query failures and an unparseable stored DID
// are returned wrapped with context.
func (s *Store) GetActive() (*activeSession, error) {
	var didStr, sessionID string
	err := s.db.QueryRow("SELECT did, session_id FROM active_session WHERE id = 1").Scan(&didStr, &sessionID)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, errors.New("no active session (run 'atvouch login' first)")
	}
	if err != nil {
		return nil, fmt.Errorf("reading active session: %w", err)
	}
	did, err := syntax.ParseDID(didStr)
	if err != nil {
		return nil, fmt.Errorf("parsing stored active DID %q: %w", didStr, err)
	}
	return &activeSession{DID: did, SessionID: sessionID}, nil
}
158
// GetSession loads the session stored for (did, sessionID) and decodes it
// from JSON.
//
// A missing row yields a "session not found" error. Query and decode
// failures are returned wrapped with the DID for context; the %w wrapping
// keeps errors such as context cancellation detectable with errors.Is.
func (s *Store) GetSession(ctx context.Context, did syntax.DID, sessionID string) (*oauth.ClientSessionData, error) {
	var data string
	err := s.db.QueryRowContext(ctx,
		"SELECT data FROM sessions WHERE did = ? AND session_id = ?",
		did.String(), sessionID,
	).Scan(&data)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, fmt.Errorf("session not found for %s", did)
	}
	if err != nil {
		return nil, fmt.Errorf("loading session for %s: %w", did, err)
	}
	var sess oauth.ClientSessionData
	if err := json.Unmarshal([]byte(data), &sess); err != nil {
		return nil, fmt.Errorf("decoding session for %s: %w", did, err)
	}
	return &sess, nil
}
177
178func (s *Store) SaveSession(ctx context.Context, sess oauth.ClientSessionData) error {
179 data, err := json.Marshal(sess)
180 if err != nil {
181 return err
182 }
183 _, err = s.db.ExecContext(ctx,
184 `INSERT INTO sessions (did, session_id, data) VALUES (?, ?, ?)
185 ON CONFLICT (did, session_id) DO UPDATE SET data = excluded.data`,
186 sess.AccountDID.String(), sess.SessionID, string(data),
187 )
188 return err
189}
190
191func (s *Store) DeleteSession(ctx context.Context, did syntax.DID, sessionID string) error {
192 _, err := s.db.ExecContext(ctx,
193 "DELETE FROM sessions WHERE did = ? AND session_id = ?",
194 did.String(), sessionID,
195 )
196 return err
197}
198
// GetAuthRequestInfo loads the pending authorization request stored under
// state and decodes it from JSON.
//
// A missing row yields a "request info not found" error. Query and decode
// failures are returned wrapped with context.
func (s *Store) GetAuthRequestInfo(ctx context.Context, state string) (*oauth.AuthRequestData, error) {
	var data string
	err := s.db.QueryRowContext(ctx,
		"SELECT data FROM auth_requests WHERE state = ?",
		state,
	).Scan(&data)
	if errors.Is(err, sql.ErrNoRows) {
		return nil, fmt.Errorf("request info not found: %s", state)
	}
	if err != nil {
		return nil, fmt.Errorf("loading auth request: %w", err)
	}
	var req oauth.AuthRequestData
	if err := json.Unmarshal([]byte(data), &req); err != nil {
		return nil, fmt.Errorf("decoding auth request: %w", err)
	}
	return &req, nil
}
217
218func (s *Store) SaveAuthRequestInfo(ctx context.Context, info oauth.AuthRequestData) error {
219 data, err := json.Marshal(info)
220 if err != nil {
221 return err
222 }
223 _, err = s.db.ExecContext(ctx,
224 "INSERT INTO auth_requests (state, data) VALUES (?, ?)",
225 info.State, string(data),
226 )
227 return err
228}
229
230func (s *Store) DeleteAuthRequestInfo(ctx context.Context, state string) error {
231 _, err := s.db.ExecContext(ctx,
232 "DELETE FROM auth_requests WHERE state = ?",
233 state,
234 )
235 return err
236}