forked from
tangled.org/core
fork
Configure Feed
Select the types of activity you want to include in your feed.
Monorepo for Tangled
fork
Configure Feed
Select the types of activity you want to include in your feed.
1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.org/core/appview/models"
16)
17
18func NewPull(tx *sql.Tx, pull *models.Pull) error {
19 _, err := tx.Exec(`
20 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
21 values (?, 1)
22 `, pull.RepoAt)
23 if err != nil {
24 return err
25 }
26
27 var nextId int
28 err = tx.QueryRow(`
29 update repo_pull_seqs
30 set next_pull_id = next_pull_id + 1
31 where repo_at = ?
32 returning next_pull_id - 1
33 `, pull.RepoAt).Scan(&nextId)
34 if err != nil {
35 return err
36 }
37
38 pull.PullId = nextId
39 pull.State = models.PullOpen
40
41 var sourceBranch, sourceRepoAt *string
42 if pull.PullSource != nil {
43 sourceBranch = &pull.PullSource.Branch
44 if pull.PullSource.RepoAt != nil {
45 x := pull.PullSource.RepoAt.String()
46 sourceRepoAt = &x
47 }
48 }
49
50 var stackId, changeId, parentChangeId *string
51 if pull.StackId != "" {
52 stackId = &pull.StackId
53 }
54 if pull.ChangeId != "" {
55 changeId = &pull.ChangeId
56 }
57 if pull.ParentChangeId != "" {
58 parentChangeId = &pull.ParentChangeId
59 }
60
61 result, err := tx.Exec(
62 `
63 insert into pulls (
64 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
65 )
66 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
67 pull.RepoAt,
68 pull.OwnerDid,
69 pull.PullId,
70 pull.Title,
71 pull.TargetBranch,
72 pull.Body,
73 pull.Rkey,
74 pull.State,
75 sourceBranch,
76 sourceRepoAt,
77 stackId,
78 changeId,
79 parentChangeId,
80 )
81 if err != nil {
82 return err
83 }
84
85 // Set the database primary key ID
86 id, err := result.LastInsertId()
87 if err != nil {
88 return err
89 }
90 pull.ID = int(id)
91
92 _, err = tx.Exec(`
93 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
94 values (?, ?, ?, ?, ?)
95 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev)
96 return err
97}
98
99func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
100 pull, err := GetPull(e, repoAt, pullId)
101 if err != nil {
102 return "", err
103 }
104 return pull.AtUri(), err
105}
106
107func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
108 var pullId int
109 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
110 return pullId - 1, err
111}
112
113func GetPullsWithLimit(e Execer, limit int, filters ...filter) ([]*models.Pull, error) {
114 pulls := make(map[syntax.ATURI]*models.Pull)
115
116 var conditions []string
117 var args []any
118 for _, filter := range filters {
119 conditions = append(conditions, filter.Condition())
120 args = append(args, filter.Arg()...)
121 }
122
123 whereClause := ""
124 if conditions != nil {
125 whereClause = " where " + strings.Join(conditions, " and ")
126 }
127 limitClause := ""
128 if limit != 0 {
129 limitClause = fmt.Sprintf(" limit %d ", limit)
130 }
131
132 query := fmt.Sprintf(`
133 select
134 id,
135 owner_did,
136 repo_at,
137 pull_id,
138 created,
139 title,
140 state,
141 target_branch,
142 body,
143 rkey,
144 source_branch,
145 source_repo_at,
146 stack_id,
147 change_id,
148 parent_change_id
149 from
150 pulls
151 %s
152 order by
153 created desc
154 %s
155 `, whereClause, limitClause)
156
157 rows, err := e.Query(query, args...)
158 if err != nil {
159 return nil, err
160 }
161 defer rows.Close()
162
163 for rows.Next() {
164 var pull models.Pull
165 var createdAt string
166 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
167 err := rows.Scan(
168 &pull.ID,
169 &pull.OwnerDid,
170 &pull.RepoAt,
171 &pull.PullId,
172 &createdAt,
173 &pull.Title,
174 &pull.State,
175 &pull.TargetBranch,
176 &pull.Body,
177 &pull.Rkey,
178 &sourceBranch,
179 &sourceRepoAt,
180 &stackId,
181 &changeId,
182 &parentChangeId,
183 )
184 if err != nil {
185 return nil, err
186 }
187
188 createdTime, err := time.Parse(time.RFC3339, createdAt)
189 if err != nil {
190 return nil, err
191 }
192 pull.Created = createdTime
193
194 if sourceBranch.Valid {
195 pull.PullSource = &models.PullSource{
196 Branch: sourceBranch.String,
197 }
198 if sourceRepoAt.Valid {
199 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
200 if err != nil {
201 return nil, err
202 }
203 pull.PullSource.RepoAt = &sourceRepoAtParsed
204 }
205 }
206
207 if stackId.Valid {
208 pull.StackId = stackId.String
209 }
210 if changeId.Valid {
211 pull.ChangeId = changeId.String
212 }
213 if parentChangeId.Valid {
214 pull.ParentChangeId = parentChangeId.String
215 }
216
217 pulls[pull.AtUri()] = &pull
218 }
219
220 var pullAts []syntax.ATURI
221 for _, p := range pulls {
222 pullAts = append(pullAts, p.AtUri())
223 }
224 submissionsMap, err := GetPullSubmissions(e, FilterIn("pull_at", pullAts))
225 if err != nil {
226 return nil, fmt.Errorf("failed to get submissions: %w", err)
227 }
228
229 for pullAt, submissions := range submissionsMap {
230 if p, ok := pulls[pullAt]; ok {
231 p.Submissions = submissions
232 }
233 }
234
235 // collect allLabels for each issue
236 allLabels, err := GetLabels(e, FilterIn("subject", pullAts))
237 if err != nil {
238 return nil, fmt.Errorf("failed to query labels: %w", err)
239 }
240 for pullAt, labels := range allLabels {
241 if p, ok := pulls[pullAt]; ok {
242 p.Labels = labels
243 }
244 }
245
246 // collect pull source for all pulls that need it
247 var sourceAts []syntax.ATURI
248 for _, p := range pulls {
249 if p.PullSource != nil && p.PullSource.RepoAt != nil {
250 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
251 }
252 }
253 sourceRepos, err := GetRepos(e, 0, FilterIn("at_uri", sourceAts))
254 if err != nil && !errors.Is(err, sql.ErrNoRows) {
255 return nil, fmt.Errorf("failed to get source repos: %w", err)
256 }
257 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
258 for _, r := range sourceRepos {
259 sourceRepoMap[r.RepoAt()] = &r
260 }
261 for _, p := range pulls {
262 if p.PullSource != nil && p.PullSource.RepoAt != nil {
263 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
264 p.PullSource.Repo = sourceRepo
265 }
266 }
267 }
268
269 orderedByPullId := []*models.Pull{}
270 for _, p := range pulls {
271 orderedByPullId = append(orderedByPullId, p)
272 }
273 sort.Slice(orderedByPullId, func(i, j int) bool {
274 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
275 })
276
277 return orderedByPullId, nil
278}
279
280func GetPulls(e Execer, filters ...filter) ([]*models.Pull, error) {
281 return GetPullsWithLimit(e, 0, filters...)
282}
283
284func GetPullIDs(e Execer, opts models.PullSearchOptions) ([]int64, error) {
285 var ids []int64
286
287 var filters []filter
288 filters = append(filters, FilterEq("state", opts.State))
289 if opts.RepoAt != "" {
290 filters = append(filters, FilterEq("repo_at", opts.RepoAt))
291 }
292
293 var conditions []string
294 var args []any
295
296 for _, filter := range filters {
297 conditions = append(conditions, filter.Condition())
298 args = append(args, filter.Arg()...)
299 }
300
301 whereClause := ""
302 if conditions != nil {
303 whereClause = " where " + strings.Join(conditions, " and ")
304 }
305 pageClause := ""
306 if opts.Page.Limit != 0 {
307 pageClause = fmt.Sprintf(
308 " limit %d offset %d ",
309 opts.Page.Limit,
310 opts.Page.Offset,
311 )
312 }
313
314 query := fmt.Sprintf(
315 `
316 select
317 id
318 from
319 pulls
320 %s
321 %s`,
322 whereClause,
323 pageClause,
324 )
325 args = append(args, opts.Page.Limit, opts.Page.Offset)
326 rows, err := e.Query(query, args...)
327 if err != nil {
328 return nil, err
329 }
330 defer rows.Close()
331
332 for rows.Next() {
333 var id int64
334 err := rows.Scan(&id)
335 if err != nil {
336 return nil, err
337 }
338
339 ids = append(ids, id)
340 }
341
342 return ids, nil
343}
344
345func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
346 pulls, err := GetPullsWithLimit(e, 1, FilterEq("repo_at", repoAt), FilterEq("pull_id", pullId))
347 if err != nil {
348 return nil, err
349 }
350 if len(pulls) == 0 {
351 return nil, sql.ErrNoRows
352 }
353
354 return pulls[0], nil
355}
356
357// mapping from pull -> pull submissions
358func GetPullSubmissions(e Execer, filters ...filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
359 var conditions []string
360 var args []any
361 for _, filter := range filters {
362 conditions = append(conditions, filter.Condition())
363 args = append(args, filter.Arg()...)
364 }
365
366 whereClause := ""
367 if conditions != nil {
368 whereClause = " where " + strings.Join(conditions, " and ")
369 }
370
371 query := fmt.Sprintf(`
372 select
373 id,
374 pull_at,
375 round_number,
376 patch,
377 combined,
378 created,
379 source_rev
380 from
381 pull_submissions
382 %s
383 order by
384 round_number asc
385 `, whereClause)
386
387 rows, err := e.Query(query, args...)
388 if err != nil {
389 return nil, err
390 }
391 defer rows.Close()
392
393 submissionMap := make(map[int]*models.PullSubmission)
394
395 for rows.Next() {
396 var submission models.PullSubmission
397 var submissionCreatedStr string
398 var submissionSourceRev, submissionCombined sql.NullString
399 err := rows.Scan(
400 &submission.ID,
401 &submission.PullAt,
402 &submission.RoundNumber,
403 &submission.Patch,
404 &submissionCombined,
405 &submissionCreatedStr,
406 &submissionSourceRev,
407 )
408 if err != nil {
409 return nil, err
410 }
411
412 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
413 submission.Created = t
414 }
415
416 if submissionSourceRev.Valid {
417 submission.SourceRev = submissionSourceRev.String
418 }
419
420 if submissionCombined.Valid {
421 submission.Combined = submissionCombined.String
422 }
423
424 submissionMap[submission.ID] = &submission
425 }
426
427 if err := rows.Err(); err != nil {
428 return nil, err
429 }
430
431 // Get comments for all submissions using GetPullComments
432 submissionIds := slices.Collect(maps.Keys(submissionMap))
433 comments, err := GetPullComments(e, FilterIn("submission_id", submissionIds))
434 if err != nil {
435 return nil, err
436 }
437 for _, comment := range comments {
438 if submission, ok := submissionMap[comment.SubmissionId]; ok {
439 submission.Comments = append(submission.Comments, comment)
440 }
441 }
442
443 // group the submissions by pull_at
444 m := make(map[syntax.ATURI][]*models.PullSubmission)
445 for _, s := range submissionMap {
446 m[s.PullAt] = append(m[s.PullAt], s)
447 }
448
449 // sort each one by round number
450 for _, s := range m {
451 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
452 return cmp.Compare(a.RoundNumber, b.RoundNumber)
453 })
454 }
455
456 return m, nil
457}
458
459func GetPullComments(e Execer, filters ...filter) ([]models.PullComment, error) {
460 var conditions []string
461 var args []any
462 for _, filter := range filters {
463 conditions = append(conditions, filter.Condition())
464 args = append(args, filter.Arg()...)
465 }
466
467 whereClause := ""
468 if conditions != nil {
469 whereClause = " where " + strings.Join(conditions, " and ")
470 }
471
472 query := fmt.Sprintf(`
473 select
474 id,
475 pull_id,
476 submission_id,
477 repo_at,
478 owner_did,
479 comment_at,
480 body,
481 created
482 from
483 pull_comments
484 %s
485 order by
486 created asc
487 `, whereClause)
488
489 rows, err := e.Query(query, args...)
490 if err != nil {
491 return nil, err
492 }
493 defer rows.Close()
494
495 var comments []models.PullComment
496 for rows.Next() {
497 var comment models.PullComment
498 var createdAt string
499 err := rows.Scan(
500 &comment.ID,
501 &comment.PullId,
502 &comment.SubmissionId,
503 &comment.RepoAt,
504 &comment.OwnerDid,
505 &comment.CommentAt,
506 &comment.Body,
507 &createdAt,
508 )
509 if err != nil {
510 return nil, err
511 }
512
513 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
514 comment.Created = t
515 }
516
517 comments = append(comments, comment)
518 }
519
520 if err := rows.Err(); err != nil {
521 return nil, err
522 }
523
524 return comments, nil
525}
526
527// timeframe here is directly passed into the sql query filter, and any
528// timeframe in the past should be negative; e.g.: "-3 months"
529func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
530 var pulls []models.Pull
531
532 rows, err := e.Query(`
533 select
534 p.owner_did,
535 p.repo_at,
536 p.pull_id,
537 p.created,
538 p.title,
539 p.state,
540 r.did,
541 r.name,
542 r.knot,
543 r.rkey,
544 r.created
545 from
546 pulls p
547 join
548 repos r on p.repo_at = r.at_uri
549 where
550 p.owner_did = ? and p.created >= date ('now', ?)
551 order by
552 p.created desc`, did, timeframe)
553 if err != nil {
554 return nil, err
555 }
556 defer rows.Close()
557
558 for rows.Next() {
559 var pull models.Pull
560 var repo models.Repo
561 var pullCreatedAt, repoCreatedAt string
562 err := rows.Scan(
563 &pull.OwnerDid,
564 &pull.RepoAt,
565 &pull.PullId,
566 &pullCreatedAt,
567 &pull.Title,
568 &pull.State,
569 &repo.Did,
570 &repo.Name,
571 &repo.Knot,
572 &repo.Rkey,
573 &repoCreatedAt,
574 )
575 if err != nil {
576 return nil, err
577 }
578
579 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
580 if err != nil {
581 return nil, err
582 }
583 pull.Created = pullCreatedTime
584
585 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
586 if err != nil {
587 return nil, err
588 }
589 repo.Created = repoCreatedTime
590
591 pull.Repo = &repo
592
593 pulls = append(pulls, pull)
594 }
595
596 if err := rows.Err(); err != nil {
597 return nil, err
598 }
599
600 return pulls, nil
601}
602
603func NewPullComment(e Execer, comment *models.PullComment) (int64, error) {
604 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
605 res, err := e.Exec(
606 query,
607 comment.OwnerDid,
608 comment.RepoAt,
609 comment.SubmissionId,
610 comment.CommentAt,
611 comment.PullId,
612 comment.Body,
613 )
614 if err != nil {
615 return 0, err
616 }
617
618 i, err := res.LastInsertId()
619 if err != nil {
620 return 0, err
621 }
622
623 return i, nil
624}
625
626func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
627 _, err := e.Exec(
628 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
629 pullState,
630 repoAt,
631 pullId,
632 models.PullDeleted, // only update state of non-deleted pulls
633 models.PullMerged, // only update state of non-merged pulls
634 )
635 return err
636}
637
638func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
639 err := SetPullState(e, repoAt, pullId, models.PullClosed)
640 return err
641}
642
643func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
644 err := SetPullState(e, repoAt, pullId, models.PullOpen)
645 return err
646}
647
648func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
649 err := SetPullState(e, repoAt, pullId, models.PullMerged)
650 return err
651}
652
653func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
654 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
655 return err
656}
657
658func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error {
659 _, err := e.Exec(`
660 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
661 values (?, ?, ?, ?, ?)
662 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev)
663
664 return err
665}
666
667func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
668 var conditions []string
669 var args []any
670
671 args = append(args, parentChangeId)
672
673 for _, filter := range filters {
674 conditions = append(conditions, filter.Condition())
675 args = append(args, filter.Arg()...)
676 }
677
678 whereClause := ""
679 if conditions != nil {
680 whereClause = " where " + strings.Join(conditions, " and ")
681 }
682
683 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
684 _, err := e.Exec(query, args...)
685
686 return err
687}
688
689// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
690// otherwise submissions are immutable
691func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
692 var conditions []string
693 var args []any
694
695 args = append(args, sourceRev)
696 args = append(args, newPatch)
697
698 for _, filter := range filters {
699 conditions = append(conditions, filter.Condition())
700 args = append(args, filter.Arg()...)
701 }
702
703 whereClause := ""
704 if conditions != nil {
705 whereClause = " where " + strings.Join(conditions, " and ")
706 }
707
708 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
709 _, err := e.Exec(query, args...)
710
711 return err
712}
713
714func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
715 row := e.QueryRow(`
716 select
717 count(case when state = ? then 1 end) as open_count,
718 count(case when state = ? then 1 end) as merged_count,
719 count(case when state = ? then 1 end) as closed_count,
720 count(case when state = ? then 1 end) as deleted_count
721 from pulls
722 where repo_at = ?`,
723 models.PullOpen,
724 models.PullMerged,
725 models.PullClosed,
726 models.PullDeleted,
727 repoAt,
728 )
729
730 var count models.PullCount
731 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
732 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
733 }
734
735 return count, nil
736}
737
738// change-id parent-change-id
739//
740// 4 w ,-------- z (TOP)
741// 3 z <----',------- y
742// 2 y <-----',------ x
743// 1 x <------' nil (BOT)
744//
745// `w` is parent of none, so it is the top of the stack
746func GetStack(e Execer, stackId string) (models.Stack, error) {
747 unorderedPulls, err := GetPulls(
748 e,
749 FilterEq("stack_id", stackId),
750 FilterNotEq("state", models.PullDeleted),
751 )
752 if err != nil {
753 return nil, err
754 }
755 // map of parent-change-id to pull
756 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
757 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
758 for _, p := range unorderedPulls {
759 changeIdMap[p.ChangeId] = p
760 if p.ParentChangeId != "" {
761 parentMap[p.ParentChangeId] = p
762 }
763 }
764
765 // the top of the stack is the pull that is not a parent of any pull
766 var topPull *models.Pull
767 for _, maybeTop := range unorderedPulls {
768 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
769 topPull = maybeTop
770 break
771 }
772 }
773
774 pulls := []*models.Pull{}
775 for {
776 pulls = append(pulls, topPull)
777 if topPull.ParentChangeId != "" {
778 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
779 topPull = next
780 } else {
781 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
782 }
783 } else {
784 break
785 }
786 }
787
788 return pulls, nil
789}
790
791func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
792 pulls, err := GetPulls(
793 e,
794 FilterEq("stack_id", stackId),
795 FilterEq("state", models.PullDeleted),
796 )
797 if err != nil {
798 return nil, err
799 }
800
801 return pulls, nil
802}