forked from
tangled.org/core
fork
Configure Feed
Select the types of activity you want to include in your feed.
Monorepo for Tangled
fork
Configure Feed
Select the types of activity you want to include in your feed.
1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "slices"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/api/tangled"
14 "tangled.sh/tangled.sh/core/patchutil"
15 "tangled.sh/tangled.sh/core/types"
16)
17
// PullState is the lifecycle state of a pull request. The numeric value is
// what gets written to and compared against the pulls.state column, so the
// order of the constants below must never change.
type PullState int

const (
	PullClosed PullState = iota // 0
	PullOpen                    // 1
	PullMerged                  // 2
	PullDeleted                 // 3
)

// String returns the lowercase name of the state. PullClosed and any value
// outside the known set both render as "closed".
func (p PullState) String() string {
	switch {
	case p.IsOpen():
		return "open"
	case p.IsMerged():
		return "merged"
	case p.IsDeleted():
		return "deleted"
	}
	return "closed"
}

// IsOpen reports whether p is PullOpen.
func (p PullState) IsOpen() bool { return p == PullOpen }

// IsMerged reports whether p is PullMerged.
func (p PullState) IsMerged() bool { return p == PullMerged }

// IsClosed reports whether p is exactly PullClosed.
func (p PullState) IsClosed() bool { return p == PullClosed }

// IsDeleted reports whether p is PullDeleted.
func (p PullState) IsDeleted() bool { return p == PullDeleted }
54
55type Pull struct {
56 // ids
57 ID int
58 PullId int
59
60 // at ids
61 RepoAt syntax.ATURI
62 OwnerDid string
63 Rkey string
64
65 // content
66 Title string
67 Body string
68 TargetBranch string
69 State PullState
70 Submissions []*PullSubmission
71
72 // stacking
73 StackId string // nullable string
74 ChangeId string // nullable string
75 ParentChangeId string // nullable string
76
77 // meta
78 Created time.Time
79 PullSource *PullSource
80
81 // optionally, populate this when querying for reverse mappings
82 Repo *Repo
83}
84
85func (p Pull) AsRecord() tangled.RepoPull {
86 var source *tangled.RepoPull_Source
87 if p.PullSource != nil {
88 s := p.PullSource.AsRecord()
89 source = &s
90 }
91
92 record := tangled.RepoPull{
93 Title: p.Title,
94 Body: &p.Body,
95 CreatedAt: p.Created.Format(time.RFC3339),
96 PullId: int64(p.PullId),
97 TargetRepo: p.RepoAt.String(),
98 TargetBranch: p.TargetBranch,
99 Patch: p.LatestPatch(),
100 Source: source,
101 }
102 return record
103}
104
105type PullSource struct {
106 Branch string
107 RepoAt *syntax.ATURI
108
109 // optionally populate this for reverse mappings
110 Repo *Repo
111}
112
113func (p PullSource) AsRecord() tangled.RepoPull_Source {
114 var repoAt *string
115 if p.RepoAt != nil {
116 s := p.RepoAt.String()
117 repoAt = &s
118 }
119 record := tangled.RepoPull_Source{
120 Branch: p.Branch,
121 Repo: repoAt,
122 }
123 return record
124}
125
126type PullSubmission struct {
127 // ids
128 ID int
129 PullId int
130
131 // at ids
132 RepoAt syntax.ATURI
133
134 // content
135 RoundNumber int
136 Patch string
137 Comments []PullComment
138 SourceRev string // include the rev that was used to create this submission: only for branch/fork PRs
139
140 // meta
141 Created time.Time
142}
143
144type PullComment struct {
145 // ids
146 ID int
147 PullId int
148 SubmissionId int
149
150 // at ids
151 RepoAt string
152 OwnerDid string
153 CommentAt string
154
155 // content
156 Body string
157
158 // meta
159 Created time.Time
160}
161
162func (p *Pull) LatestPatch() string {
163 latestSubmission := p.Submissions[p.LastRoundNumber()]
164 return latestSubmission.Patch
165}
166
167func (p *Pull) PullAt() syntax.ATURI {
168 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
169}
170
171func (p *Pull) LastRoundNumber() int {
172 return len(p.Submissions) - 1
173}
174
175func (p *Pull) IsPatchBased() bool {
176 return p.PullSource == nil
177}
178
179func (p *Pull) IsBranchBased() bool {
180 if p.PullSource != nil {
181 if p.PullSource.RepoAt != nil {
182 return p.PullSource.RepoAt == &p.RepoAt
183 } else {
184 // no repo specified
185 return true
186 }
187 }
188 return false
189}
190
191func (p *Pull) IsForkBased() bool {
192 if p.PullSource != nil {
193 if p.PullSource.RepoAt != nil {
194 // make sure repos are different
195 return p.PullSource.RepoAt != &p.RepoAt
196 }
197 }
198 return false
199}
200
201func (p *Pull) IsStacked() bool {
202 return p.StackId != ""
203}
204
205func (s PullSubmission) IsFormatPatch() bool {
206 return patchutil.IsFormatPatch(s.Patch)
207}
208
209func (s PullSubmission) AsFormatPatch() []types.FormatPatch {
210 patches, err := patchutil.ExtractPatches(s.Patch)
211 if err != nil {
212 log.Println("error extracting patches from submission:", err)
213 return []types.FormatPatch{}
214 }
215
216 return patches
217}
218
219func NewPull(tx *sql.Tx, pull *Pull) error {
220 _, err := tx.Exec(`
221 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
222 values (?, 1)
223 `, pull.RepoAt)
224 if err != nil {
225 return err
226 }
227
228 var nextId int
229 err = tx.QueryRow(`
230 update repo_pull_seqs
231 set next_pull_id = next_pull_id + 1
232 where repo_at = ?
233 returning next_pull_id - 1
234 `, pull.RepoAt).Scan(&nextId)
235 if err != nil {
236 return err
237 }
238
239 pull.PullId = nextId
240 pull.State = PullOpen
241
242 var sourceBranch, sourceRepoAt *string
243 if pull.PullSource != nil {
244 sourceBranch = &pull.PullSource.Branch
245 if pull.PullSource.RepoAt != nil {
246 x := pull.PullSource.RepoAt.String()
247 sourceRepoAt = &x
248 }
249 }
250
251 var stackId, changeId, parentChangeId *string
252 if pull.StackId != "" {
253 stackId = &pull.StackId
254 }
255 if pull.ChangeId != "" {
256 changeId = &pull.ChangeId
257 }
258 if pull.ParentChangeId != "" {
259 parentChangeId = &pull.ParentChangeId
260 }
261
262 _, err = tx.Exec(
263 `
264 insert into pulls (
265 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
266 )
267 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
268 pull.RepoAt,
269 pull.OwnerDid,
270 pull.PullId,
271 pull.Title,
272 pull.TargetBranch,
273 pull.Body,
274 pull.Rkey,
275 pull.State,
276 sourceBranch,
277 sourceRepoAt,
278 stackId,
279 changeId,
280 parentChangeId,
281 )
282 if err != nil {
283 return err
284 }
285
286 _, err = tx.Exec(`
287 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
288 values (?, ?, ?, ?, ?)
289 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
290 return err
291}
292
293func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
294 pull, err := GetPull(e, repoAt, pullId)
295 if err != nil {
296 return "", err
297 }
298 return pull.PullAt(), err
299}
300
301func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
302 var pullId int
303 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
304 return pullId - 1, err
305}
306
307func GetPulls(e Execer, filters ...filter) ([]*Pull, error) {
308 pulls := make(map[int]*Pull)
309
310 var conditions []string
311 var args []any
312 for _, filter := range filters {
313 conditions = append(conditions, filter.Condition())
314 args = append(args, filter.arg)
315 }
316
317 whereClause := ""
318 if conditions != nil {
319 whereClause = " where " + strings.Join(conditions, " and ")
320 }
321
322 query := fmt.Sprintf(`
323 select
324 owner_did,
325 repo_at,
326 pull_id,
327 created,
328 title,
329 state,
330 target_branch,
331 body,
332 rkey,
333 source_branch,
334 source_repo_at,
335 stack_id,
336 change_id,
337 parent_change_id
338 from
339 pulls
340 %s
341 `, whereClause)
342
343 rows, err := e.Query(query, args...)
344 if err != nil {
345 return nil, err
346 }
347 defer rows.Close()
348
349 for rows.Next() {
350 var pull Pull
351 var createdAt string
352 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
353 err := rows.Scan(
354 &pull.OwnerDid,
355 &pull.RepoAt,
356 &pull.PullId,
357 &createdAt,
358 &pull.Title,
359 &pull.State,
360 &pull.TargetBranch,
361 &pull.Body,
362 &pull.Rkey,
363 &sourceBranch,
364 &sourceRepoAt,
365 &stackId,
366 &changeId,
367 &parentChangeId,
368 )
369 if err != nil {
370 return nil, err
371 }
372
373 createdTime, err := time.Parse(time.RFC3339, createdAt)
374 if err != nil {
375 return nil, err
376 }
377 pull.Created = createdTime
378
379 if sourceBranch.Valid {
380 pull.PullSource = &PullSource{
381 Branch: sourceBranch.String,
382 }
383 if sourceRepoAt.Valid {
384 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
385 if err != nil {
386 return nil, err
387 }
388 pull.PullSource.RepoAt = &sourceRepoAtParsed
389 }
390 }
391
392 if stackId.Valid {
393 pull.StackId = stackId.String
394 }
395 if changeId.Valid {
396 pull.ChangeId = changeId.String
397 }
398 if parentChangeId.Valid {
399 pull.ParentChangeId = parentChangeId.String
400 }
401
402 pulls[pull.PullId] = &pull
403 }
404
405 // get latest round no. for each pull
406 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
407 submissionsQuery := fmt.Sprintf(`
408 select
409 id, pull_id, round_number, patch, source_rev
410 from
411 pull_submissions
412 where
413 repo_at in (%s) and pull_id in (%s)
414 `, inClause, inClause)
415
416 args = make([]any, len(pulls)*2)
417 idx := 0
418 for _, p := range pulls {
419 args[idx] = p.RepoAt
420 idx += 1
421 }
422 for _, p := range pulls {
423 args[idx] = p.PullId
424 idx += 1
425 }
426 submissionsRows, err := e.Query(submissionsQuery, args...)
427 if err != nil {
428 return nil, err
429 }
430 defer submissionsRows.Close()
431
432 for submissionsRows.Next() {
433 var s PullSubmission
434 var sourceRev sql.NullString
435 err := submissionsRows.Scan(
436 &s.ID,
437 &s.PullId,
438 &s.RoundNumber,
439 &s.Patch,
440 &sourceRev,
441 )
442 if err != nil {
443 return nil, err
444 }
445
446 if sourceRev.Valid {
447 s.SourceRev = sourceRev.String
448 }
449
450 if p, ok := pulls[s.PullId]; ok {
451 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
452 p.Submissions[s.RoundNumber] = &s
453 }
454 }
455 if err := rows.Err(); err != nil {
456 return nil, err
457 }
458
459 // get comment count on latest submission on each pull
460 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
461 commentsQuery := fmt.Sprintf(`
462 select
463 count(id), pull_id
464 from
465 pull_comments
466 where
467 submission_id in (%s)
468 group by
469 submission_id
470 `, inClause)
471
472 args = []any{}
473 for _, p := range pulls {
474 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
475 }
476 commentsRows, err := e.Query(commentsQuery, args...)
477 if err != nil {
478 return nil, err
479 }
480 defer commentsRows.Close()
481
482 for commentsRows.Next() {
483 var commentCount, pullId int
484 err := commentsRows.Scan(
485 &commentCount,
486 &pullId,
487 )
488 if err != nil {
489 return nil, err
490 }
491 if p, ok := pulls[pullId]; ok {
492 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
493 }
494 }
495 if err := rows.Err(); err != nil {
496 return nil, err
497 }
498
499 orderedByPullId := []*Pull{}
500 for _, p := range pulls {
501 orderedByPullId = append(orderedByPullId, p)
502 }
503 sort.Slice(orderedByPullId, func(i, j int) bool {
504 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
505 })
506
507 return orderedByPullId, nil
508}
509
510func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
511 query := `
512 select
513 owner_did,
514 pull_id,
515 created,
516 title,
517 state,
518 target_branch,
519 repo_at,
520 body,
521 rkey,
522 source_branch,
523 source_repo_at,
524 stack_id,
525 change_id,
526 parent_change_id
527 from
528 pulls
529 where
530 repo_at = ? and pull_id = ?
531 `
532 row := e.QueryRow(query, repoAt, pullId)
533
534 var pull Pull
535 var createdAt string
536 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
537 err := row.Scan(
538 &pull.OwnerDid,
539 &pull.PullId,
540 &createdAt,
541 &pull.Title,
542 &pull.State,
543 &pull.TargetBranch,
544 &pull.RepoAt,
545 &pull.Body,
546 &pull.Rkey,
547 &sourceBranch,
548 &sourceRepoAt,
549 &stackId,
550 &changeId,
551 &parentChangeId,
552 )
553 if err != nil {
554 return nil, err
555 }
556
557 createdTime, err := time.Parse(time.RFC3339, createdAt)
558 if err != nil {
559 return nil, err
560 }
561 pull.Created = createdTime
562
563 // populate source
564 if sourceBranch.Valid {
565 pull.PullSource = &PullSource{
566 Branch: sourceBranch.String,
567 }
568 if sourceRepoAt.Valid {
569 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
570 if err != nil {
571 return nil, err
572 }
573 pull.PullSource.RepoAt = &sourceRepoAtParsed
574 }
575 }
576
577 if stackId.Valid {
578 pull.StackId = stackId.String
579 }
580 if changeId.Valid {
581 pull.ChangeId = changeId.String
582 }
583 if parentChangeId.Valid {
584 pull.ParentChangeId = parentChangeId.String
585 }
586
587 submissionsQuery := `
588 select
589 id, pull_id, repo_at, round_number, patch, created, source_rev
590 from
591 pull_submissions
592 where
593 repo_at = ? and pull_id = ?
594 `
595 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
596 if err != nil {
597 return nil, err
598 }
599 defer submissionsRows.Close()
600
601 submissionsMap := make(map[int]*PullSubmission)
602
603 for submissionsRows.Next() {
604 var submission PullSubmission
605 var submissionCreatedStr string
606 var submissionSourceRev sql.NullString
607 err := submissionsRows.Scan(
608 &submission.ID,
609 &submission.PullId,
610 &submission.RepoAt,
611 &submission.RoundNumber,
612 &submission.Patch,
613 &submissionCreatedStr,
614 &submissionSourceRev,
615 )
616 if err != nil {
617 return nil, err
618 }
619
620 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
621 if err != nil {
622 return nil, err
623 }
624 submission.Created = submissionCreatedTime
625
626 if submissionSourceRev.Valid {
627 submission.SourceRev = submissionSourceRev.String
628 }
629
630 submissionsMap[submission.ID] = &submission
631 }
632 if err = submissionsRows.Close(); err != nil {
633 return nil, err
634 }
635 if len(submissionsMap) == 0 {
636 return &pull, nil
637 }
638
639 var args []any
640 for k := range submissionsMap {
641 args = append(args, k)
642 }
643 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
644 commentsQuery := fmt.Sprintf(`
645 select
646 id,
647 pull_id,
648 submission_id,
649 repo_at,
650 owner_did,
651 comment_at,
652 body,
653 created
654 from
655 pull_comments
656 where
657 submission_id IN (%s)
658 order by
659 created asc
660 `, inClause)
661 commentsRows, err := e.Query(commentsQuery, args...)
662 if err != nil {
663 return nil, err
664 }
665 defer commentsRows.Close()
666
667 for commentsRows.Next() {
668 var comment PullComment
669 var commentCreatedStr string
670 err := commentsRows.Scan(
671 &comment.ID,
672 &comment.PullId,
673 &comment.SubmissionId,
674 &comment.RepoAt,
675 &comment.OwnerDid,
676 &comment.CommentAt,
677 &comment.Body,
678 &commentCreatedStr,
679 )
680 if err != nil {
681 return nil, err
682 }
683
684 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
685 if err != nil {
686 return nil, err
687 }
688 comment.Created = commentCreatedTime
689
690 // Add the comment to its submission
691 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
692 submission.Comments = append(submission.Comments, comment)
693 }
694
695 }
696 if err = commentsRows.Err(); err != nil {
697 return nil, err
698 }
699
700 var pullSourceRepo *Repo
701 if pull.PullSource != nil {
702 if pull.PullSource.RepoAt != nil {
703 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
704 if err != nil {
705 log.Printf("failed to get repo by at uri: %v", err)
706 } else {
707 pull.PullSource.Repo = pullSourceRepo
708 }
709 }
710 }
711
712 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
713 for _, submission := range submissionsMap {
714 pull.Submissions[submission.RoundNumber] = submission
715 }
716
717 return &pull, nil
718}
719
720// timeframe here is directly passed into the sql query filter, and any
721// timeframe in the past should be negative; e.g.: "-3 months"
722func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
723 var pulls []Pull
724
725 rows, err := e.Query(`
726 select
727 p.owner_did,
728 p.repo_at,
729 p.pull_id,
730 p.created,
731 p.title,
732 p.state,
733 r.did,
734 r.name,
735 r.knot,
736 r.rkey,
737 r.created
738 from
739 pulls p
740 join
741 repos r on p.repo_at = r.at_uri
742 where
743 p.owner_did = ? and p.created >= date ('now', ?)
744 order by
745 p.created desc`, did, timeframe)
746 if err != nil {
747 return nil, err
748 }
749 defer rows.Close()
750
751 for rows.Next() {
752 var pull Pull
753 var repo Repo
754 var pullCreatedAt, repoCreatedAt string
755 err := rows.Scan(
756 &pull.OwnerDid,
757 &pull.RepoAt,
758 &pull.PullId,
759 &pullCreatedAt,
760 &pull.Title,
761 &pull.State,
762 &repo.Did,
763 &repo.Name,
764 &repo.Knot,
765 &repo.Rkey,
766 &repoCreatedAt,
767 )
768 if err != nil {
769 return nil, err
770 }
771
772 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
773 if err != nil {
774 return nil, err
775 }
776 pull.Created = pullCreatedTime
777
778 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
779 if err != nil {
780 return nil, err
781 }
782 repo.Created = repoCreatedTime
783
784 pull.Repo = &repo
785
786 pulls = append(pulls, pull)
787 }
788
789 if err := rows.Err(); err != nil {
790 return nil, err
791 }
792
793 return pulls, nil
794}
795
796func NewPullComment(e Execer, comment *PullComment) (int64, error) {
797 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
798 res, err := e.Exec(
799 query,
800 comment.OwnerDid,
801 comment.RepoAt,
802 comment.SubmissionId,
803 comment.CommentAt,
804 comment.PullId,
805 comment.Body,
806 )
807 if err != nil {
808 return 0, err
809 }
810
811 i, err := res.LastInsertId()
812 if err != nil {
813 return 0, err
814 }
815
816 return i, nil
817}
818
819func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
820 _, err := e.Exec(
821 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
822 pullState,
823 repoAt,
824 pullId,
825 PullDeleted, // only update state of non-deleted pulls
826 PullMerged, // only update state of non-merged pulls
827 )
828 return err
829}
830
831func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
832 err := SetPullState(e, repoAt, pullId, PullClosed)
833 return err
834}
835
836func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
837 err := SetPullState(e, repoAt, pullId, PullOpen)
838 return err
839}
840
841func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
842 err := SetPullState(e, repoAt, pullId, PullMerged)
843 return err
844}
845
846func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
847 err := SetPullState(e, repoAt, pullId, PullDeleted)
848 return err
849}
850
851func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
852 newRoundNumber := len(pull.Submissions)
853 _, err := e.Exec(`
854 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
855 values (?, ?, ?, ?, ?)
856 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
857
858 return err
859}
860
861func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
862 var conditions []string
863 var args []any
864
865 args = append(args, parentChangeId)
866
867 for _, filter := range filters {
868 conditions = append(conditions, filter.Condition())
869 args = append(args, filter.arg)
870 }
871
872 whereClause := ""
873 if conditions != nil {
874 whereClause = " where " + strings.Join(conditions, " and ")
875 }
876
877 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
878 _, err := e.Exec(query, args...)
879
880 return err
881}
882
883// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
884// otherwise submissions are immutable
885func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
886 var conditions []string
887 var args []any
888
889 args = append(args, sourceRev)
890 args = append(args, newPatch)
891
892 for _, filter := range filters {
893 conditions = append(conditions, filter.Condition())
894 args = append(args, filter.arg)
895 }
896
897 whereClause := ""
898 if conditions != nil {
899 whereClause = " where " + strings.Join(conditions, " and ")
900 }
901
902 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
903 _, err := e.Exec(query, args...)
904
905 return err
906}
907
// PullCount holds the number of pulls in a repository for each PullState.
type PullCount struct {
	Open, Merged, Closed, Deleted int
}
914
915func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
916 row := e.QueryRow(`
917 select
918 count(case when state = ? then 1 end) as open_count,
919 count(case when state = ? then 1 end) as merged_count,
920 count(case when state = ? then 1 end) as closed_count,
921 count(case when state = ? then 1 end) as deleted_count
922 from pulls
923 where repo_at = ?`,
924 PullOpen,
925 PullMerged,
926 PullClosed,
927 PullDeleted,
928 repoAt,
929 )
930
931 var count PullCount
932 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
933 return PullCount{0, 0, 0, 0}, err
934 }
935
936 return count, nil
937}
938
939type Stack []*Pull
940
941// change-id parent-change-id
942//
943// 4 w ,-------- z (TOP)
944// 3 z <----',------- y
945// 2 y <-----',------ x
946// 1 x <------' nil (BOT)
947//
948// `w` is parent of none, so it is the top of the stack
949func GetStack(e Execer, stackId string) (Stack, error) {
950 unorderedPulls, err := GetPulls(
951 e,
952 FilterEq("stack_id", stackId),
953 FilterNotEq("state", PullDeleted),
954 )
955 if err != nil {
956 return nil, err
957 }
958 // map of parent-change-id to pull
959 changeIdMap := make(map[string]*Pull, len(unorderedPulls))
960 parentMap := make(map[string]*Pull, len(unorderedPulls))
961 for _, p := range unorderedPulls {
962 changeIdMap[p.ChangeId] = p
963 if p.ParentChangeId != "" {
964 parentMap[p.ParentChangeId] = p
965 }
966 }
967
968 // the top of the stack is the pull that is not a parent of any pull
969 var topPull *Pull
970 for _, maybeTop := range unorderedPulls {
971 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
972 topPull = maybeTop
973 break
974 }
975 }
976
977 pulls := []*Pull{}
978 for {
979 pulls = append(pulls, topPull)
980 if topPull.ParentChangeId != "" {
981 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
982 topPull = next
983 } else {
984 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
985 }
986 } else {
987 break
988 }
989 }
990
991 return pulls, nil
992}
993
994func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) {
995 pulls, err := GetPulls(
996 e,
997 FilterEq("stack_id", stackId),
998 FilterEq("state", PullDeleted),
999 )
1000 if err != nil {
1001 return nil, err
1002 }
1003
1004 return pulls, nil
1005}
1006
1007// position of this pull in the stack
1008func (stack Stack) Position(pull *Pull) int {
1009 return slices.IndexFunc(stack, func(p *Pull) bool {
1010 return p.ChangeId == pull.ChangeId
1011 })
1012}
1013
1014// all pulls below this pull (including self) in this stack
1015//
1016// nil if this pull does not belong to this stack
1017func (stack Stack) Below(pull *Pull) Stack {
1018 position := stack.Position(pull)
1019
1020 if position < 0 {
1021 return nil
1022 }
1023
1024 return stack[position:]
1025}
1026
1027// all pulls below this pull (excluding self) in this stack
1028func (stack Stack) StrictlyBelow(pull *Pull) Stack {
1029 below := stack.Below(pull)
1030
1031 if len(below) > 0 {
1032 return below[1:]
1033 }
1034
1035 return nil
1036}
1037
1038// all pulls above this pull (including self) in this stack
1039func (stack Stack) Above(pull *Pull) Stack {
1040 position := stack.Position(pull)
1041
1042 if position < 0 {
1043 return nil
1044 }
1045
1046 return stack[:position+1]
1047}
1048
1049// all pulls below this pull (excluding self) in this stack
1050func (stack Stack) StrictlyAbove(pull *Pull) Stack {
1051 above := stack.Above(pull)
1052
1053 if len(above) > 0 {
1054 return above[:len(above)-1]
1055 }
1056
1057 return nil
1058}
1059
1060// the combined format-patches of all the newest submissions in this stack
1061func (stack Stack) CombinedPatch() string {
1062 // go in reverse order because the bottom of the stack is the last element in the slice
1063 var combined strings.Builder
1064 for idx := range stack {
1065 pull := stack[len(stack)-1-idx]
1066 combined.WriteString(pull.LatestPatch())
1067 combined.WriteString("\n")
1068 }
1069 return combined.String()
1070}
1071
1072// filter out PRs that are "active"
1073//
1074// PRs that are still open are active
1075func (stack Stack) Mergeable() Stack {
1076 var mergeable Stack
1077
1078 for _, p := range stack {
1079 // stop at the first merged PR
1080 if p.State == PullMerged || p.State == PullClosed {
1081 break
1082 }
1083
1084 // skip over deleted PRs
1085 if p.State != PullDeleted {
1086 mergeable = append(mergeable, p)
1087 }
1088 }
1089
1090 return mergeable
1091}