forked from tangled.org/core
Monorepo for Tangled

implement transactions

+14 -2
appview/db/db.go
··· 1 1 package db 2 2 3 3 import ( 4 + "context" 4 5 "database/sql" 5 6 6 7 _ "github.com/mattn/go-sqlite3" 7 8 ) 8 9 9 10 type DB struct { 10 - db *sql.DB 11 + *sql.DB 12 + } 13 + 14 + type Execer interface { 15 + Query(query string, args ...any) (*sql.Rows, error) 16 + QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) 17 + QueryRow(query string, args ...any) *sql.Row 18 + QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row 19 + Exec(query string, args ...any) (sql.Result, error) 20 + ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) 21 + Prepare(query string) (*sql.Stmt, error) 22 + PrepareContext(ctx context.Context, query string) (*sql.Stmt, error) 11 23 } 12 24 13 25 func Make(dbPath string) (*DB, error) { ··· 104 116 if err != nil { 105 117 return nil, err 106 118 } 107 - return &DB{db: db}, nil 119 + return &DB{db}, nil 108 120 }
+12 -12
appview/db/follow.go
··· 12 12 RKey string 13 13 } 14 14 15 - func (d *DB) AddFollow(userDid, subjectDid, rkey string) error { 15 + func AddFollow(e Execer, userDid, subjectDid, rkey string) error { 16 16 query := `insert or ignore into follows (user_did, subject_did, rkey) values (?, ?, ?)` 17 - _, err := d.db.Exec(query, userDid, subjectDid, rkey) 17 + _, err := e.Exec(query, userDid, subjectDid, rkey) 18 18 return err 19 19 } 20 20 21 21 // Get a follow record 22 - func (d *DB) GetFollow(userDid, subjectDid string) (*Follow, error) { 22 + func GetFollow(e Execer, userDid, subjectDid string) (*Follow, error) { 23 23 query := `select user_did, subject_did, followed_at, rkey from follows where user_did = ? and subject_did = ?` 24 - row := d.db.QueryRow(query, userDid, subjectDid) 24 + row := e.QueryRow(query, userDid, subjectDid) 25 25 26 26 var follow Follow 27 27 var followedAt string ··· 42 42 } 43 43 44 44 // Get a follow record 45 - func (d *DB) DeleteFollow(userDid, subjectDid string) error { 46 - _, err := d.db.Exec(`delete from follows where user_did = ? and subject_did = ?`, userDid, subjectDid) 45 + func DeleteFollow(e Execer, userDid, subjectDid string) error { 46 + _, err := e.Exec(`delete from follows where user_did = ? and subject_did = ?`, userDid, subjectDid) 47 47 return err 48 48 } 49 49 50 - func (d *DB) GetFollowerFollowing(did string) (int, int, error) { 50 + func GetFollowerFollowing(e Execer, did string) (int, int, error) { 51 51 followers, following := 0, 0 52 - err := d.db.QueryRow( 52 + err := e.QueryRow( 53 53 `SELECT 54 54 COUNT(CASE WHEN subject_did = ? THEN 1 END) AS followers, 55 55 COUNT(CASE WHEN user_did = ? THEN 1 END) AS following ··· 81 81 } 82 82 } 83 83 84 - func (d *DB) GetFollowStatus(userDid, subjectDid string) FollowStatus { 84 + func GetFollowStatus(e Execer, userDid, subjectDid string) FollowStatus { 85 85 if userDid == subjectDid { 86 86 return IsSelf 87 - } else if _, err := d.GetFollow(userDid, subjectDid); err != nil { 87 + } else if _, err := GetFollow(e, userDid, subjectDid); err != nil { 88 88 return IsNotFollowing 89 89 } else { 90 90 return IsFollowing 91 91 } 92 92 } 93 93 94 - func (d *DB) GetAllFollows() ([]Follow, error) { 94 + func GetAllFollows(e Execer) ([]Follow, error) { 95 95 var follows []Follow 96 96 97 - rows, err := d.db.Query(`select user_did, subject_did, followed_at, rkey from follows`) 97 + rows, err := e.Query(`select user_did, subject_did, followed_at, rkey from follows`) 98 98 if err != nil { 99 99 return nil, err 100 100 }
+25 -29
appview/db/issues.go
··· 26 26 Created *time.Time 27 27 } 28 28 29 - func (d *DB) NewIssue(issue *Issue) error { 30 - tx, err := d.db.Begin() 31 - if err != nil { 32 - return err 33 - } 29 + func NewIssue(tx *sql.Tx, issue *Issue) error { 34 30 defer tx.Rollback() 35 31 36 - _, err = tx.Exec(` 32 + _, err := tx.Exec(` 37 33 insert or ignore into repo_issue_seqs (repo_at, next_issue_id) 38 34 values (?, 1) 39 35 `, issue.RepoAt) ··· 69 65 return nil 70 66 } 71 67 72 - func (d *DB) SetIssueAt(repoAt string, issueId int, issueAt string) error { 73 - _, err := d.db.Exec(`update issues set issue_at = ? where repo_at = ? and issue_id = ?`, issueAt, repoAt, issueId) 68 + func SetIssueAt(e Execer, repoAt string, issueId int, issueAt string) error { 69 + _, err := e.Exec(`update issues set issue_at = ? where repo_at = ? and issue_id = ?`, issueAt, repoAt, issueId) 74 70 return err 75 71 } 76 72 77 - func (d *DB) GetIssueAt(repoAt string, issueId int) (string, error) { 73 + func GetIssueAt(e Execer, repoAt string, issueId int) (string, error) { 78 74 var issueAt string 79 - err := d.db.QueryRow(`select issue_at from issues where repo_at = ? and issue_id = ?`, repoAt, issueId).Scan(&issueAt) 75 + err := e.QueryRow(`select issue_at from issues where repo_at = ? and issue_id = ?`, repoAt, issueId).Scan(&issueAt) 80 76 return issueAt, err 81 77 } 82 78 83 - func (d *DB) GetIssueId(repoAt string) (int, error) { 79 + func GetIssueId(e Execer, repoAt string) (int, error) { 84 80 var issueId int 85 - err := d.db.QueryRow(`select next_issue_id from repo_issue_seqs where repo_at = ?`, repoAt).Scan(&issueId) 81 + err := e.QueryRow(`select next_issue_id from repo_issue_seqs where repo_at = ?`, repoAt).Scan(&issueId) 86 82 return issueId - 1, err 87 83 } 88 84 89 - func (d *DB) GetIssueOwnerDid(repoAt string, issueId int) (string, error) { 85 + func GetIssueOwnerDid(e Execer, repoAt string, issueId int) (string, error) { 90 86 var ownerDid string 91 - err := d.db.QueryRow(`select owner_did from issues where repo_at = ? and issue_id = ?`, repoAt, issueId).Scan(&ownerDid) 87 + err := e.QueryRow(`select owner_did from issues where repo_at = ? and issue_id = ?`, repoAt, issueId).Scan(&ownerDid) 92 88 return ownerDid, err 93 89 } 94 90 95 - func (d *DB) GetIssues(repoAt string) ([]Issue, error) { 91 + func GetIssues(e Execer, repoAt string) ([]Issue, error) { 96 92 var issues []Issue 97 93 98 - rows, err := d.db.Query(`select owner_did, issue_id, created, title, body, open from issues where repo_at = ? order by created desc`, repoAt) 94 + rows, err := e.Query(`select owner_did, issue_id, created, title, body, open from issues where repo_at = ? order by created desc`, repoAt) 99 95 if err != nil { 100 96 return nil, err 101 97 } ··· 125 121 return issues, nil 126 122 } 127 123 128 - func (d *DB) GetIssue(repoAt string, issueId int) (*Issue, error) { 124 + func GetIssue(e Execer, repoAt string, issueId int) (*Issue, error) { 129 125 query := `select owner_did, created, title, body, open from issues where repo_at = ? and issue_id = ?` 130 - row := d.db.QueryRow(query, repoAt, issueId) 126 + row := e.QueryRow(query, repoAt, issueId) 131 127 132 128 var issue Issue 133 129 var createdAt string ··· 145 141 return &issue, nil 146 142 } 147 143 148 - func (d *DB) GetIssueWithComments(repoAt string, issueId int) (*Issue, []Comment, error) { 144 + func GetIssueWithComments(e Execer, repoAt string, issueId int) (*Issue, []Comment, error) { 149 145 query := `select owner_did, issue_id, created, title, body, open from issues where repo_at = ? and issue_id = ?` 150 - row := d.db.QueryRow(query, repoAt, issueId) 146 + row := e.QueryRow(query, repoAt, issueId) 151 147 152 148 var issue Issue 153 149 var createdAt string ··· 162 158 } 163 159 issue.Created = &createdTime 164 160 165 - comments, err := d.GetComments(repoAt, issueId) 161 + comments, err := GetComments(e, repoAt, issueId) 166 162 if err != nil { 167 163 return nil, nil, err 168 164 } ··· 170 166 return &issue, comments, nil 171 167 } 172 168 173 - func (d *DB) NewComment(comment *Comment) error { 169 + func NewComment(e Execer, comment *Comment) error { 174 170 query := `insert into comments (owner_did, repo_at, comment_at, issue_id, comment_id, body) values (?, ?, ?, ?, ?, ?)` 175 - _, err := d.db.Exec( 171 + _, err := e.Exec( 176 172 query, 177 173 comment.OwnerDid, 178 174 comment.RepoAt, ··· 184 180 return err 185 181 } 186 182 187 - func (d *DB) GetComments(repoAt string, issueId int) ([]Comment, error) { 183 + func GetComments(e Execer, repoAt string, issueId int) ([]Comment, error) { 188 184 var comments []Comment 189 185 190 - rows, err := d.db.Query(`select owner_did, issue_id, comment_id, comment_at, body, created from comments where repo_at = ? and issue_id = ? order by created asc`, repoAt, issueId) 186 + rows, err := e.Query(`select owner_did, issue_id, comment_id, comment_at, body, created from comments where repo_at = ? and issue_id = ? order by created asc`, repoAt, issueId) 191 187 if err == sql.ErrNoRows { 192 188 return []Comment{}, nil 193 189 } ··· 220 216 return comments, nil 221 217 } 222 218 223 - func (d *DB) CloseIssue(repoAt string, issueId int) error { 224 - _, err := d.db.Exec(`update issues set open = 0 where repo_at = ? and issue_id = ?`, repoAt, issueId) 219 + func CloseIssue(e Execer, repoAt string, issueId int) error { 220 + _, err := e.Exec(`update issues set open = 0 where repo_at = ? and issue_id = ?`, repoAt, issueId) 225 221 return err 226 222 } 227 223 228 - func (d *DB) ReopenIssue(repoAt string, issueId int) error { 229 - _, err := d.db.Exec(`update issues set open = 1 where repo_at = ? and issue_id = ?`, repoAt, issueId) 224 + func ReopenIssue(e Execer, repoAt string, issueId int) error { 225 + _, err := e.Exec(`update issues set open = 1 where repo_at = ? and issue_id = ?`, repoAt, issueId) 230 226 return err 231 227 }
+10 -6
appview/db/jetstream.go
··· 1 1 package db 2 2 3 - func (d *DB) SaveLastTimeUs(lastTimeUs int64) error { 4 - _, err := d.db.Exec(`insert into _jetstream (last_time_us) values (?)`, lastTimeUs) 3 + type DbWrapper struct { 4 + Execer 5 + } 6 + 7 + func (db DbWrapper) SaveLastTimeUs(lastTimeUs int64) error { 8 + _, err := db.Exec(`insert into _jetstream (last_time_us) values (?)`, lastTimeUs) 5 9 return err 6 10 } 7 11 8 - func (d *DB) UpdateLastTimeUs(lastTimeUs int64) error { 9 - _, err := d.db.Exec(`update _jetstream set last_time_us = ? where rowid = 1`, lastTimeUs) 12 + func (db DbWrapper) UpdateLastTimeUs(lastTimeUs int64) error { 13 + _, err := db.Exec(`update _jetstream set last_time_us = ? where rowid = 1`, lastTimeUs) 10 14 if err != nil { 11 15 return err 12 16 } 13 17 return nil 14 18 } 15 19 16 - func (d *DB) GetLastTimeUs() (int64, error) { 20 + func (db DbWrapper) GetLastTimeUs() (int64, error) { 17 21 var lastTimeUs int64 18 - row := d.db.QueryRow(`select last_time_us from _jetstream`) 22 + row := db.QueryRow(`select last_time_us from _jetstream`) 19 23 err := row.Scan(&lastTimeUs) 20 24 return lastTimeUs, err 21 25 }
+8 -8
appview/db/pubkeys.go
··· 5 5 "time" 6 6 ) 7 7 8 - func (d *DB) AddPublicKey(did, name, key string) error { 8 + func AddPublicKey(e Execer, did, name, key string) error { 9 9 query := `insert or ignore into public_keys (did, name, key) values (?, ?, ?)` 10 - _, err := d.db.Exec(query, did, name, key) 10 + _, err := e.Exec(query, did, name, key) 11 11 return err 12 12 } 13 13 14 - func (d *DB) RemovePublicKey(did string) error { 14 + func RemovePublicKey(e Execer, did string) error { 15 15 query := `delete from public_keys where did = ?` 16 - _, err := d.db.Exec(query, did) 16 + _, err := e.Exec(query, did) 17 17 return err 18 18 } 19 19 ··· 35 35 }) 36 36 } 37 37 38 - func (d *DB) GetAllPublicKeys() ([]PublicKey, error) { 38 + func GetAllPublicKeys(e Execer) ([]PublicKey, error) { 39 39 var keys []PublicKey 40 40 41 - rows, err := d.db.Query(`select key, name, did, created from public_keys`) 41 + rows, err := e.Query(`select key, name, did, created from public_keys`) 42 42 if err != nil { 43 43 return nil, err 44 44 } ··· 62 62 return keys, nil 63 63 } 64 64 65 - func (d *DB) GetPublicKeys(did string) ([]PublicKey, error) { 65 + func GetPublicKeys(e Execer, did string) ([]PublicKey, error) { 66 66 var keys []PublicKey 67 67 68 - rows, err := d.db.Query(`select did, key, name, created from public_keys where did = ?`, did) 68 + rows, err := e.Query(`select did, key, name, created from public_keys where did = ?`, did) 69 69 if err != nil { 70 70 return nil, err 71 71 }
+11 -11
appview/db/registration.go
··· 32 32 ) 33 33 34 34 // returns registered status, did of owner, error 35 - func (d *DB) RegistrationsByDid(did string) ([]Registration, error) { 35 + func RegistrationsByDid(e Execer, did string) ([]Registration, error) { 36 36 var registrations []Registration 37 37 38 - rows, err := d.db.Query(` 38 + rows, err := e.Query(` 39 39 select domain, did, created, registered from registrations 40 40 where did = ? 41 41 `, did) ··· 69 69 } 70 70 71 71 // returns registered status, did of owner, error 72 - func (d *DB) RegistrationByDomain(domain string) (*Registration, error) { 72 + func RegistrationByDomain(e Execer, domain string) (*Registration, error) { 73 73 var createdAt *string 74 74 var registeredAt *string 75 75 var registration Registration 76 76 77 - err := d.db.QueryRow(` 77 + err := e.QueryRow(` 78 78 select domain, did, created, registered from registrations 79 79 where domain = ? 80 80 `, domain).Scan(&registration.Domain, &registration.ByDid, &createdAt, &registeredAt) ··· 106 106 return hex.EncodeToString(key) 107 107 } 108 108 109 - func (d *DB) GenerateRegistrationKey(domain, did string) (string, error) { 109 + func GenerateRegistrationKey(e Execer, domain, did string) (string, error) { 110 110 // sanity check: does this domain already have a registration? 111 - reg, err := d.RegistrationByDomain(domain) 111 + reg, err := RegistrationByDomain(e, domain) 112 112 if err != nil { 113 113 return "", err 114 114 } ··· 127 127 128 128 secret := genSecret() 129 129 130 - _, err = d.db.Exec(` 130 + _, err = e.Exec(` 131 131 insert into registrations (domain, did, secret) 132 132 values (?, ?, ?) 133 133 on conflict(domain) do update set did = excluded.did, secret = excluded.secret ··· 140 140 return secret, nil 141 141 } 142 142 143 - func (d *DB) GetRegistrationKey(domain string) (string, error) { 144 - res := d.db.QueryRow(`select secret from registrations where domain = ?`, domain) 143 + func GetRegistrationKey(e Execer, domain string) (string, error) { 144 + res := e.QueryRow(`select secret from registrations where domain = ?`, domain) 145 145 146 146 var secret string 147 147 err := res.Scan(&secret) ··· 152 152 return secret, nil 153 153 } 154 154 155 - func (d *DB) Register(domain string) error { 156 - _, err := d.db.Exec(` 155 + func Register(e Execer, domain string) error { 156 + _, err := e.Exec(` 157 157 update registrations 158 158 set registered = strftime('%Y-%m-%dT%H:%M:%SZ', 'now') 159 159 where domain = ?;
+14 -14
appview/db/repos.go
··· 14 14 AtUri string 15 15 } 16 16 17 - func (d *DB) GetAllRepos() ([]Repo, error) { 17 + func GetAllRepos(e Execer) ([]Repo, error) { 18 18 var repos []Repo 19 19 20 - rows, err := d.db.Query(`select did, name, knot, rkey, created from repos`) 20 + rows, err := e.Query(`select did, name, knot, rkey, created from repos`) 21 21 if err != nil { 22 22 return nil, err 23 23 } ··· 39 39 return repos, nil 40 40 } 41 41 42 - func (d *DB) GetAllReposByDid(did string) ([]Repo, error) { 42 + func GetAllReposByDid(e Execer, did string) ([]Repo, error) { 43 43 var repos []Repo 44 44 45 - rows, err := d.db.Query(`select did, name, knot, rkey, created from repos where did = ?`, did) 45 + rows, err := e.Query(`select did, name, knot, rkey, created from repos where did = ?`, did) 46 46 if err != nil { 47 47 return nil, err 48 48 } ··· 64 64 return repos, nil 65 65 } 66 66 67 - func (d *DB) GetRepo(did, name string) (*Repo, error) { 67 + func GetRepo(e Execer, did, name string) (*Repo, error) { 68 68 var repo Repo 69 69 70 - row := d.db.QueryRow(`select did, name, knot, created, at_uri from repos where did = ? and name = ?`, did, name) 70 + row := e.QueryRow(`select did, name, knot, created, at_uri from repos where did = ? and name = ?`, did, name) 71 71 72 72 var createdAt string 73 73 if err := row.Scan(&repo.Did, &repo.Name, &repo.Knot, &createdAt, &repo.AtUri); err != nil { ··· 79 79 return &repo, nil 80 80 } 81 81 82 - func (d *DB) AddRepo(repo *Repo) error { 83 - _, err := d.db.Exec(`insert into repos (did, name, knot, rkey, at_uri) values (?, ?, ?, ?, ?)`, repo.Did, repo.Name, repo.Knot, repo.Rkey, repo.AtUri) 82 + func AddRepo(e Execer, repo *Repo) error { 83 + _, err := e.Exec(`insert into repos (did, name, knot, rkey, at_uri) values (?, ?, ?, ?, ?)`, repo.Did, repo.Name, repo.Knot, repo.Rkey, repo.AtUri) 84 84 return err 85 85 } 86 86 87 - func (d *DB) RemoveRepo(did, name, rkey string) error { 88 - _, err := d.db.Exec(`delete from repos where did = ? and name = ? and rkey = ?`, did, name, rkey) 87 + func RemoveRepo(e Execer, did, name, rkey string) error { 88 + _, err := e.Exec(`delete from repos where did = ? and name = ? and rkey = ?`, did, name, rkey) 89 89 return err 90 90 } 91 91 92 - func (d *DB) AddCollaborator(collaborator, repoOwnerDid, repoName, repoKnot string) error { 93 - _, err := d.db.Exec( 92 + func AddCollaborator(e Execer, collaborator, repoOwnerDid, repoName, repoKnot string) error { 93 + _, err := e.Exec( 94 94 `insert into collaborators (did, repo) 95 95 values (?, (select id from repos where did = ? and name = ? and knot = ?));`, 96 96 collaborator, repoOwnerDid, repoName, repoKnot) 97 97 return err 98 98 } 99 99 100 - func (d *DB) CollaboratingIn(collaborator string) ([]Repo, error) { 100 + func CollaboratingIn(e Execer, collaborator string) ([]Repo, error) { 101 101 var repos []Repo 102 102 103 - rows, err := d.db.Query(`select r.did, r.name, r.knot, r.rkey, r.created from repos r join collaborators c on r.id = c.repo where c.did = ?;`, collaborator) 103 + rows, err := e.Query(`select r.did, r.name, r.knot, r.rkey, r.created from repos r join collaborators c on r.id = c.repo where c.did = ?;`, collaborator) 104 104 if err != nil { 105 105 return nil, err 106 106 }
+3 -3
appview/db/timeline.go
··· 11 11 EventAt time.Time 12 12 } 13 13 14 - func (d *DB) MakeTimeline() ([]TimelineEvent, error) { 14 + func MakeTimeline(e Execer) ([]TimelineEvent, error) { 15 15 var events []TimelineEvent 16 16 17 - repos, err := d.GetAllRepos() 17 + repos, err := GetAllRepos(e) 18 18 if err != nil { 19 19 return nil, err 20 20 } 21 21 22 - follows, err := d.GetAllFollows() 22 + follows, err := GetAllFollows(e) 23 23 if err != nil { 24 24 return nil, err 25 25 }
+1 -1
appview/pages/templates/repo/new.html
··· 22 22 type="text" 23 23 id="branch" 24 24 name="branch" 25 + value="main" 25 26 required 26 27 class="w-full max-w-md" 27 28 /> 28 - <p class="text-sm text-gray-500">The default branch is <span class="font-bold">main</span></p> 29 29 </div> 30 30 31 31 <fieldset class="space-y-3">
+4 -3
appview/state/follow.go
··· 9 9 comatproto "github.com/bluesky-social/indigo/api/atproto" 10 10 lexutil "github.com/bluesky-social/indigo/lex/util" 11 11 tangled "github.com/sotangled/tangled/api/tangled" 12 + "github.com/sotangled/tangled/appview/db" 12 13 ) 13 14 14 15 func (s *State) Follow(w http.ResponseWriter, r *http.Request) { ··· 51 52 return 52 53 } 53 54 54 - err = s.db.AddFollow(currentUser.Did, subjectIdent.DID.String(), rkey) 55 + err = db.AddFollow(s.db, currentUser.Did, subjectIdent.DID.String(), rkey) 55 56 if err != nil { 56 57 log.Println("failed to follow", err) 57 58 return ··· 73 74 return 74 75 case http.MethodDelete: 75 76 // find the record in the db 76 - follow, err := s.db.GetFollow(currentUser.Did, subjectIdent.DID.String()) 77 + follow, err := db.GetFollow(s.db, currentUser.Did, subjectIdent.DID.String()) 77 78 if err != nil { 78 79 log.Println("failed to get follow relationship") 79 80 return ··· 90 91 return 91 92 } 92 93 93 - err = s.db.DeleteFollow(currentUser.Did, subjectIdent.DID.String()) 94 + err = db.DeleteFollow(s.db, currentUser.Did, subjectIdent.DID.String()) 94 95 if err != nil { 95 96 log.Println("failed to delete follow from DB") 96 97 // this is not an issue, the firehose event might have already done this
+3 -3
appview/state/jetstream.go
··· 13 13 14 14 type Ingester func(ctx context.Context, e *models.Event) error 15 15 16 - func jetstreamIngester(db *db.DB) Ingester { 16 + func jetstreamIngester(d db.DbWrapper) Ingester { 17 17 return func(ctx context.Context, e *models.Event) error { 18 18 var err error 19 19 defer func() { 20 20 eventTime := e.TimeUS 21 21 lastTimeUs := eventTime + 1 22 - if err := db.UpdateLastTimeUs(lastTimeUs); err != nil { 22 + if err := d.UpdateLastTimeUs(lastTimeUs); err != nil { 23 23 err = fmt.Errorf("(deferred) failed to save last time us: %w", err) 24 24 } 25 25 }() ··· 39 39 log.Println("invalid record") 40 40 return err 41 41 } 42 - err = db.AddFollow(did, record.Subject, e.Commit.RKey) 42 + err = db.AddFollow(d, did, record.Subject, e.Commit.RKey) 43 43 if err != nil { 44 44 return fmt.Errorf("failed to add follow to db: %w", err) 45 45 }
+2 -1
appview/state/middleware.go
··· 13 13 "github.com/go-chi/chi/v5" 14 14 "github.com/sotangled/tangled/appview" 15 15 "github.com/sotangled/tangled/appview/auth" 16 + "github.com/sotangled/tangled/appview/db" 16 17 ) 17 18 18 19 type Middleware func(http.Handler) http.Handler ··· 176 177 return 177 178 } 178 179 179 - repo, err := s.db.GetRepo(id.DID.String(), repoName) 180 + repo, err := db.GetRepo(s.db, id.DID.String(), repoName) 180 181 if err != nil { 181 182 // invalid did or handle 182 183 log.Println("failed to resolve repo")
+46 -12
appview/state/repo.go
··· 389 389 390 390 // TODO: create an atproto record for this 391 391 392 - secret, err := s.db.GetRegistrationKey(f.Knot) 392 + secret, err := db.GetRegistrationKey(s.db, f.Knot) 393 393 if err != nil { 394 394 log.Printf("no key found for domain %s: %s\n", f.Knot, err) 395 395 return ··· 412 412 return 413 413 } 414 414 415 + tx, err := s.db.BeginTx(r.Context(), nil) 416 + if err != nil { 417 + log.Println("failed to start tx") 418 + w.Write([]byte(fmt.Sprint("failed to add collaborator: ", err))) 419 + return 420 + } 421 + defer func() { 422 + tx.Rollback() 423 + err = s.enforcer.E.LoadPolicy() 424 + if err != nil { 425 + log.Println("failed to rollback policies") 426 + } 427 + }() 428 + 415 429 err = s.enforcer.AddCollaborator(collaboratorIdent.DID.String(), f.Knot, f.OwnerSlashRepo()) 416 430 if err != nil { 417 431 w.Write([]byte(fmt.Sprint("failed to add collaborator: ", err))) 418 432 return 419 433 } 420 434 421 - err = s.db.AddCollaborator(collaboratorIdent.DID.String(), f.OwnerDid(), f.RepoName, f.Knot) 435 + err = db.AddCollaborator(s.db, collaboratorIdent.DID.String(), f.OwnerDid(), f.RepoName, f.Knot) 422 436 if err != nil { 423 437 w.Write([]byte(fmt.Sprint("failed to add collaborator: ", err))) 424 438 return 425 439 } 426 440 441 + err = tx.Commit() 442 + if err != nil { 443 + log.Println("failed to commit changes", err) 444 + http.Error(w, err.Error(), http.StatusInternalServerError) 445 + return 446 + } 447 + 448 + err = s.enforcer.E.SavePolicy() 449 + if err != nil { 450 + log.Println("failed to update ACLs", err) 451 + http.Error(w, err.Error(), http.StatusInternalServerError) 452 + return 453 + } 454 + 427 455 w.Write([]byte(fmt.Sprint("added collaborator: ", collaboratorIdent.Handle.String()))) 428 456 429 457 } ··· 546 574 return 547 575 } 548 576 549 - issue, comments, err := s.db.GetIssueWithComments(f.RepoAt, issueIdInt) 577 + issue, comments, err := db.GetIssueWithComments(s.db, f.RepoAt, issueIdInt) 550 578 if err != nil { 551 579 log.Println("failed to get issue and comments", err) 552 580 s.pages.Notice(w, "issues", "Failed to load issue. Try again later.") ··· 605 633 return 606 634 } 607 635 608 - issue, err := s.db.GetIssue(f.RepoAt, issueIdInt) 636 + issue, err := db.GetIssue(s.db, f.RepoAt, issueIdInt) 609 637 if err != nil { 610 638 log.Println("failed to get issue", err) 611 639 s.pages.Notice(w, "issue-action", "Failed to close issue. Try again later.") ··· 645 673 return 646 674 } 647 675 648 - err := s.db.CloseIssue(f.RepoAt, issueIdInt) 676 + err := db.CloseIssue(s.db, f.RepoAt, issueIdInt) 649 677 if err != nil { 650 678 log.Println("failed to close issue", err) 651 679 s.pages.Notice(w, "issue-action", "Failed to close issue. Try again later.") ··· 678 706 } 679 707 680 708 if user.Did == f.OwnerDid() { 681 - err := s.db.ReopenIssue(f.RepoAt, issueIdInt) 709 + err := db.ReopenIssue(s.db, f.RepoAt, issueIdInt) 682 710 if err != nil { 683 711 log.Println("failed to reopen issue", err) 684 712 s.pages.Notice(w, "issue-action", "Failed to reopen issue. Try again later.") ··· 719 747 720 748 commentId := rand.IntN(1000000) 721 749 722 - err := s.db.NewComment(&db.Comment{ 750 + err := db.NewComment(s.db, &db.Comment{ 723 751 OwnerDid: user.Did, 724 752 RepoAt: f.RepoAt, 725 753 Issue: issueIdInt, ··· 735 763 createdAt := time.Now().Format(time.RFC3339) 736 764 commentIdInt64 := int64(commentId) 737 765 ownerDid := user.Did 738 - issueAt, err := s.db.GetIssueAt(f.RepoAt, issueIdInt) 766 + issueAt, err := db.GetIssueAt(s.db, f.RepoAt, issueIdInt) 739 767 if err != nil { 740 768 log.Println("failed to get issue at", err) 741 769 s.pages.Notice(w, "issue-comment", "Failed to create comment.") ··· 777 805 return 778 806 } 779 807 780 - issues, err := s.db.GetIssues(f.RepoAt) 808 + issues, err := db.GetIssues(s.db, f.RepoAt) 781 809 if err != nil { 782 810 log.Println("failed to get issues", err) 783 811 s.pages.Notice(w, "issues", "Failed to load issues. Try again later.") ··· 841 869 return 842 870 } 843 871 844 - err = s.db.NewIssue(&db.Issue{ 872 + tx, err := s.db.BeginTx(r.Context(), nil) 873 + if err != nil { 874 + s.pages.Notice(w, "issues", "Failed to create issue, try again later") 875 + return 876 + } 877 + 878 + err = db.NewIssue(tx, &db.Issue{ 845 879 RepoAt: f.RepoAt, 846 880 Title: title, 847 881 Body: body, ··· 853 887 return 854 888 } 855 889 856 - issueId, err := s.db.GetIssueId(f.RepoAt) 890 + issueId, err := db.GetIssueId(s.db, f.RepoAt) 857 891 if err != nil { 858 892 log.Println("failed to get issue id", err) 859 893 s.pages.Notice(w, "issues", "Failed to create issue.") ··· 881 915 return 882 916 } 883 917 884 - err = s.db.SetIssueAt(f.RepoAt, issueId, resp.Uri) 918 + err = db.SetIssueAt(s.db, f.RepoAt, issueId, resp.Uri) 885 919 if err != nil { 886 920 log.Println("failed to set issue at", err) 887 921 s.pages.Notice(w, "issues", "Failed to create issue.")
+3 -2
appview/state/settings.go
··· 10 10 lexutil "github.com/bluesky-social/indigo/lex/util" 11 11 "github.com/gliderlabs/ssh" 12 12 "github.com/sotangled/tangled/api/tangled" 13 + "github.com/sotangled/tangled/appview/db" 13 14 "github.com/sotangled/tangled/appview/pages" 14 15 ) 15 16 16 17 func (s *State) Settings(w http.ResponseWriter, r *http.Request) { 17 18 // for now, this is just pubkeys 18 19 user := s.auth.GetUser(r) 19 - pubKeys, err := s.db.GetPublicKeys(user.Did) 20 + pubKeys, err := db.GetPublicKeys(s.db, user.Did) 20 21 if err != nil { 21 22 log.Println(err) 22 23 } ··· 47 48 return 48 49 } 49 50 50 - if err := s.db.AddPublicKey(did, name, key); err != nil { 51 + if err := db.AddPublicKey(s.db, did, name, key); err != nil { 51 52 log.Printf("adding public key: %s", err) 52 53 s.pages.Notice(w, "settings-keys", "Failed to add public key.") 53 54 return
+99 -36
appview/state/state.go
··· 39 39 } 40 40 41 41 func Make(config *appview.Config) (*State, error) { 42 - db, err := db.Make(config.DbPath) 42 + d, err := db.Make(config.DbPath) 43 43 if err != nil { 44 44 return nil, err 45 45 } ··· 60 60 61 61 resolver := appview.NewResolver() 62 62 63 - jc, err := jetstream.NewJetstreamClient("appview", []string{tangled.GraphFollowNSID}, nil, slog.Default(), db, false) 63 + wrapper := db.DbWrapper{d} 64 + jc, err := jetstream.NewJetstreamClient("appview", []string{tangled.GraphFollowNSID}, nil, slog.Default(), wrapper, false) 64 65 if err != nil { 65 66 return nil, fmt.Errorf("failed to create jetstream client: %w", err) 66 67 } 67 - err = jc.StartJetstream(context.Background(), jetstreamIngester(db)) 68 + err = jc.StartJetstream(context.Background(), jetstreamIngester(wrapper)) 68 69 if err != nil { 69 70 return nil, fmt.Errorf("failed to start jetstream watcher: %w", err) 70 71 } 71 72 72 73 state := &State{ 73 - db, 74 + d, 74 75 auth, 75 76 enforcer, 76 77 clock, ··· 135 136 func (s *State) Timeline(w http.ResponseWriter, r *http.Request) { 136 137 user := s.auth.GetUser(r) 137 138 138 - timeline, err := s.db.MakeTimeline() 139 + timeline, err := db.MakeTimeline(s.db) 139 140 if err != nil { 140 141 log.Println(err) 141 142 s.pages.Notice(w, "timeline", "Uh oh! Failed to load timeline.") ··· 195 196 return 196 197 } 197 198 198 - key, err := s.db.GenerateRegistrationKey(domain, did) 199 + key, err := db.GenerateRegistrationKey(s.db, domain, did) 199 200 200 201 if err != nil { 201 202 log.Println(err) ··· 222 223 return 223 224 } 224 225 225 - pubKeys, err := s.db.GetPublicKeys(id.DID.String()) 226 + pubKeys, err := db.GetPublicKeys(s.db, id.DID.String()) 226 227 if err != nil { 227 228 w.WriteHeader(http.StatusNotFound) 228 229 return ··· 250 251 } 251 252 log.Println("checking ", domain) 252 253 253 - secret, err := s.db.GetRegistrationKey(domain) 254 + secret, err := db.GetRegistrationKey(s.db, domain) 254 255 if err != nil { 255 256 log.Printf("no key found for domain %s: %s\n", domain, err) 256 257 return ··· 295 296 return 296 297 } 297 298 299 + tx, err := s.db.BeginTx(r.Context(), nil) 300 + if err != nil { 301 + log.Println("failed to start tx", err) 302 + http.Error(w, err.Error(), http.StatusInternalServerError) 303 + return 304 + } 305 + defer func() { 306 + tx.Rollback() 307 + err = s.enforcer.E.LoadPolicy() 308 + if err != nil { 309 + log.Println("failed to rollback policies") 310 + } 311 + }() 312 + 298 313 // mark as registered 299 - err = s.db.Register(domain) 314 + err = db.Register(tx, domain) 300 315 if err != nil { 301 316 log.Println("failed to register domain", err) 302 317 http.Error(w, err.Error(), http.StatusInternalServerError) ··· 304 319 } 305 320 306 321 // set permissions for this did as owner 307 - reg, err := s.db.RegistrationByDomain(domain) 322 + reg, err := db.RegistrationByDomain(tx, domain) 308 323 if err != nil { 309 324 log.Println("failed to register domain", err) 310 325 http.Error(w, err.Error(), http.StatusInternalServerError) ··· 327 342 return 328 343 } 329 344 345 + err = tx.Commit() 346 + if err != nil { 347 + log.Println("failed to commit changes", err) 348 + http.Error(w, err.Error(), http.StatusInternalServerError) 349 + return 350 + } 351 + 352 + err = s.enforcer.E.SavePolicy() 353 + if err != nil { 354 + log.Println("failed to update ACLs", err) 355 + http.Error(w, err.Error(), http.StatusInternalServerError) 356 + return 357 + } 358 + 330 359 w.Write([]byte("check success")) 331 360 } 332 361 ··· 338 367 } 339 368 340 369 user := s.auth.GetUser(r) 341 - reg, err := s.db.RegistrationByDomain(domain) 370 + reg, err := db.RegistrationByDomain(s.db, domain) 342 371 if err != nil { 343 372 w.Write([]byte("failed to pull up registration info")) 344 373 return ··· 370 399 func (s *State) Knots(w http.ResponseWriter, r *http.Request) { 371 400 // for now, this is just pubkeys 372 401 user := s.auth.GetUser(r) 373 - registrations, err := s.db.RegistrationsByDid(user.Did) 402 + registrations, err := db.RegistrationsByDid(s.db, user.Did) 374 403 if err != nil { 375 404 log.Println(err) 376 405 } ··· 444 473 } 445 474 log.Println("created atproto record: ", resp.Uri) 446 475 447 - secret, err := s.db.GetRegistrationKey(domain) 476 + secret, err := db.GetRegistrationKey(s.db, domain) 448 477 if err != nil { 449 478 log.Printf("no key found for domain %s: %s\n", domain, err) 450 479 return ··· 520 549 return 521 550 } 522 551 523 - secret, err := s.db.GetRegistrationKey(domain) 552 + existingRepo, err := db.GetRepo(s.db, user.Did, repoName) 553 + if err == nil && existingRepo != nil { 554 + s.pages.Notice(w, "repo", fmt.Sprintf("A repo by this name already exists on %s", existingRepo.Knot)) 555 + return 556 + } 557 + 558 + secret, err := db.GetRegistrationKey(s.db, domain) 524 559 if err != nil { 525 560 s.pages.Notice(w, "repo", fmt.Sprintf("No registration key found for knot %s.", domain)) 526 561 return ··· 532 567 return 533 568 } 534 569 535 - resp, err := client.NewRepo(user.Did, repoName, defaultBranch) 536 - if err != nil { 537 - s.pages.Notice(w, "repo", "Failed to create repository on knot server.") 538 - return 539 - } 540 - 541 - switch resp.StatusCode { 542 - case http.StatusConflict: 543 - s.pages.Notice(w, "repo", "A repository with that name already exists.") 544 - return 545 - case http.StatusInternalServerError: 546 - s.pages.Notice(w, "repo", "Failed to create repository on knot. Try again later.") 547 - case http.StatusNoContent: 548 - // continue 549 - } 550 - 551 570 rkey := s.TID() 552 571 repo := &db.Repo{ 553 572 Did: user.Did, ··· 578 597 } 579 598 log.Println("created repo record: ", atresp.Uri) 580 599 581 - repo.AtUri = atresp.Uri 600 + tx, err := s.db.BeginTx(r.Context(), nil) 601 + if err != nil { 602 + log.Println(err) 603 + s.pages.Notice(w, "repo", "Failed to save repository information.") 604 + return 605 + } 606 + defer func() { 607 + tx.Rollback() 608 + err = s.enforcer.E.LoadPolicy() 609 + if err != nil { 610 + log.Println("failed to rollback policies") 611 + } 612 + }() 582 613 583 - err = s.db.AddRepo(repo) 614 + resp, err := client.NewRepo(user.Did, repoName, defaultBranch) 615 + if err != nil { 616 + s.pages.Notice(w, "repo", "Failed to create repository on knot server.") 617 + return 618 + } 619 + 620 + switch resp.StatusCode { 621 + case http.StatusConflict: 622 + s.pages.Notice(w, "repo", "A repository with that name already exists.") 623 + return 624 + case http.StatusInternalServerError: 625 + s.pages.Notice(w, "repo", "Failed to create repository on knot. Try again later.") 626 + case http.StatusNoContent: 627 + // continue 628 + } 629 + 630 + repo.AtUri = atresp.Uri 631 + err = db.AddRepo(tx, repo) 584 632 if err != nil { 585 633 log.Println(err) 586 634 s.pages.Notice(w, "repo", "Failed to save repository information.") ··· 596 644 return 597 645 } 598 646 647 + err = tx.Commit() 648 + if err != nil { 649 + log.Println("failed to commit changes", err) 650 + http.Error(w, err.Error(), http.StatusInternalServerError) 651 + return 652 + } 653 + 654 + err = s.enforcer.E.SavePolicy() 655 + if err != nil { 656 + log.Println("failed to update ACLs", err) 657 + http.Error(w, err.Error(), http.StatusInternalServerError) 658 + return 659 + } 660 + 599 661 s.pages.HxLocation(w, fmt.Sprintf("/@%s/%s", user.Handle, repoName)) 600 662 return 601 663 } ··· 615 677 return 616 678 } 617 679 618 - repos, err := s.db.GetAllReposByDid(ident.DID.String()) 680 + repos, err := db.GetAllReposByDid(s.db, ident.DID.String()) 619 681 if err != nil { 620 682 log.Printf("getting repos for %s: %s", ident.DID.String(), err) 621 683 } 622 684 623 - collaboratingRepos, err := s.db.CollaboratingIn(ident.DID.String()) 685 + collaboratingRepos, err := db.CollaboratingIn(s.db, ident.DID.String()) 624 686 if err != nil { 625 687 log.Printf("getting collaborating repos for %s: %s", ident.DID.String(), err) 626 688 } ··· 638 700 } 639 701 } 640 702 641 - followers, following, err := s.db.GetFollowerFollowing(ident.DID.String()) 703 + followers, following, err := db.GetFollowerFollowing(s.db, ident.DID.String()) 642 704 if err != nil { 643 705 log.Printf("getting follow stats repos for %s: %s", ident.DID.String(), err) 644 706 } ··· 646 708 loggedInUser := s.auth.GetUser(r) 647 709 followStatus := db.IsNotFollowing 648 710 if loggedInUser != nil { 649 - followStatus = s.db.GetFollowStatus(loggedInUser.Did, ident.DID.String()) 711 + followStatus = db.GetFollowStatus(s.db, loggedInUser.Did, ident.DID.String()) 650 712 } 651 713 652 714 profileAvatarUri, err := GetAvatarUri(ident.DID.String()) ··· 818 880 819 881 r.Route("/repo", func(r chi.Router) { 820 882 r.Route("/new", func(r chi.Router) { 883 + r.Use(AuthMiddleware(s)) 821 884 r.Get("/", s.AddRepo) 822 885 r.Post("/", s.AddRepo) 823 886 })
+8 -7
rbac/rbac.go
··· 6 6 "path" 7 7 "strings" 8 8 9 - sqladapter "github.com/Blank-Xu/sql-adapter" 9 + adapter "github.com/Blank-Xu/sql-adapter" 10 10 "github.com/casbin/casbin/v2" 11 11 "github.com/casbin/casbin/v2/model" 12 12 ) ··· 31 31 ) 32 32 33 33 type Enforcer struct { 34 - E *casbin.SyncedEnforcer 34 + E *casbin.Enforcer 35 35 } 36 36 37 37 func keyMatch2(key1 string, key2 string) bool { ··· 50 50 return nil, err 51 51 } 52 52 53 - a, err := sqladapter.NewAdapter(db, "sqlite3", "acl") 53 + a, err := adapter.NewAdapter(db, "sqlite3", "acl") 54 54 if err != nil { 55 55 return nil, err 56 56 } 57 57 58 - e, err := casbin.NewSyncedEnforcer(m, a) 58 + e, err := casbin.NewEnforcer(m, a) 59 59 if err != nil { 60 60 return nil, err 61 61 } 62 62 63 - e.EnableAutoSave(true) 63 + e.EnableAutoSave(false) 64 + 64 65 e.AddFunction("keyMatch2", keyMatch2Func) 65 66 66 67 return &Enforcer{e}, nil ··· 82 83 } 83 84 84 85 func (e *Enforcer) GetDomainsForUser(did string) ([]string, error) { 85 - return e.E.Enforcer.GetDomainsForUser(did) 86 + return e.E.GetDomainsForUser(did) 86 87 } 87 88 88 89 func (e *Enforcer) AddOwner(domain, owner string) error { ··· 131 132 132 133 // this includes roles too, casbin does not differentiate. 133 134 // the filtering criteria is to remove strings not starting with `did:` 134 - members, err := e.E.Enforcer.GetImplicitUsersForRole(role, domain) 135 + members, err := e.E.GetImplicitUsersForRole(role, domain) 135 136 for _, m := range members { 136 137 if strings.HasPrefix(m, "did:") { 137 138 membersWithoutRoles = append(membersWithoutRoles, m)