forked from
tangled.org/core
fork
Configure Feed
Select the types of activity you want to include in your feed.
this repo has no description
fork
Configure Feed
Select the types of activity you want to include in your feed.
1package db
2
3import (
4 "database/sql"
5 "fmt"
6 "log"
7 "slices"
8 "sort"
9 "strings"
10 "time"
11
12 "github.com/bluesky-social/indigo/atproto/syntax"
13 "tangled.sh/tangled.sh/core/api/tangled"
14 "tangled.sh/tangled.sh/core/patchutil"
15 "tangled.sh/tangled.sh/core/types"
16)
17
// PullState is the lifecycle state of a pull request. The numeric values are
// written to and compared against the pulls.state column, so the order of the
// constants below must not change.
type PullState int

const (
	PullClosed PullState = iota
	PullOpen
	PullMerged
	PullDeleted
)

// String returns the lowercase name of the state. Any value outside the
// declared constants is reported as "closed".
func (p PullState) String() string {
	name := "closed"
	switch p {
	case PullOpen:
		name = "open"
	case PullMerged:
		name = "merged"
	case PullDeleted:
		name = "deleted"
	}
	return name
}

// IsOpen reports whether the state is PullOpen.
func (p PullState) IsOpen() bool { return PullOpen == p }

// IsMerged reports whether the state is PullMerged.
func (p PullState) IsMerged() bool { return PullMerged == p }

// IsClosed reports whether the state is PullClosed.
func (p PullState) IsClosed() bool { return PullClosed == p }

// IsDeleted reports whether the state is PullDeleted.
func (p PullState) IsDeleted() bool { return PullDeleted == p }
54
// Pull is a pull request against a repository, as stored in the pulls table.
type Pull struct {
	// ids

	// ID is the database row id. NOTE(review): neither GetPulls nor GetPull
	// selects it, so it is zero on pulls loaded by this file.
	ID int
	// PullId is the per-repository pull number, allocated by NewPull from
	// repo_pull_seqs.
	PullId int

	// at ids
	RepoAt   syntax.ATURI // target repository (TargetRepo in AsRecord)
	OwnerDid string       // first path segment of PullAt
	Rkey     string       // record key; last path segment of PullAt

	// content
	Title        string
	Body         string
	TargetBranch string
	State        PullState
	// Submissions holds the submission rounds, indexed by round number.
	Submissions []*PullSubmission

	// stacking
	StackId        string // nullable string; "" is stored as NULL by NewPull
	ChangeId       string // nullable string
	ParentChangeId string // nullable string; ChangeId of the pull directly below this one in its stack

	// meta
	Created    time.Time
	PullSource *PullSource // nil for patch-based pulls

	// optionally, populate this when querying for reverse mappings
	Repo *Repo
}
84
85func (p Pull) AsRecord() tangled.RepoPull {
86 var source *tangled.RepoPull_Source
87 if p.PullSource != nil {
88 s := p.PullSource.AsRecord()
89 source = &s
90 source.Sha = p.LatestSha()
91 }
92
93 record := tangled.RepoPull{
94 Title: p.Title,
95 Body: &p.Body,
96 CreatedAt: p.Created.Format(time.RFC3339),
97 PullId: int64(p.PullId),
98 TargetRepo: p.RepoAt.String(),
99 TargetBranch: p.TargetBranch,
100 Patch: p.LatestPatch(),
101 Source: source,
102 }
103 return record
104}
105
// PullSource describes where a branch- or fork-based pull's changes come
// from. Pulls without a source are patch-based (see IsPatchBased).
type PullSource struct {
	Branch string
	// RepoAt is the source repository; nil when none is recorded (NewPull
	// stores NULL in pulls.source_repo_at in that case).
	RepoAt *syntax.ATURI

	// optionally populate this for reverse mappings
	Repo *Repo
}
113
114func (p PullSource) AsRecord() tangled.RepoPull_Source {
115 var repoAt *string
116 if p.RepoAt != nil {
117 s := p.RepoAt.String()
118 repoAt = &s
119 }
120 record := tangled.RepoPull_Source{
121 Branch: p.Branch,
122 Repo: repoAt,
123 }
124 return record
125}
126
// PullSubmission is one round of a pull: the patch submitted in that round and
// the comments made on it.
type PullSubmission struct {
	// ids
	ID     int // pull_submissions row id; pull_comments.submission_id refers to it
	PullId int // per-repository pull number (Pull.PullId), not Pull.ID

	// at ids
	RepoAt syntax.ATURI

	// content
	RoundNumber int // zero-based: NewPull inserts round 0, ResubmitPull appends the next
	Patch       string
	Comments    []PullComment
	SourceRev   string // include the rev that was used to create this submission: only for branch/fork PRs

	// meta
	Created time.Time
}
144
// PullComment is a comment left on a specific submission round of a pull.
type PullComment struct {
	// ids
	ID           int
	PullId       int // per-repository pull number
	SubmissionId int // ID of the PullSubmission this comment belongs to

	// at ids
	RepoAt    string
	OwnerDid  string
	CommentAt string // presumably the AT URI of the comment record — confirm

	// content
	Body string

	// meta
	Created time.Time
}
162
163func (p *Pull) LatestPatch() string {
164 latestSubmission := p.Submissions[p.LastRoundNumber()]
165 return latestSubmission.Patch
166}
167
168func (p *Pull) LatestSha() string {
169 latestSubmission := p.Submissions[p.LastRoundNumber()]
170 return latestSubmission.SourceRev
171}
172
173func (p *Pull) PullAt() syntax.ATURI {
174 return syntax.ATURI(fmt.Sprintf("at://%s/%s/%s", p.OwnerDid, tangled.RepoPullNSID, p.Rkey))
175}
176
177func (p *Pull) LastRoundNumber() int {
178 return len(p.Submissions) - 1
179}
180
181func (p *Pull) IsPatchBased() bool {
182 return p.PullSource == nil
183}
184
185func (p *Pull) IsBranchBased() bool {
186 if p.PullSource != nil {
187 if p.PullSource.RepoAt != nil {
188 return p.PullSource.RepoAt == &p.RepoAt
189 } else {
190 // no repo specified
191 return true
192 }
193 }
194 return false
195}
196
197func (p *Pull) IsForkBased() bool {
198 if p.PullSource != nil {
199 if p.PullSource.RepoAt != nil {
200 // make sure repos are different
201 return p.PullSource.RepoAt != &p.RepoAt
202 }
203 }
204 return false
205}
206
207func (p *Pull) IsStacked() bool {
208 return p.StackId != ""
209}
210
211func (s PullSubmission) IsFormatPatch() bool {
212 return patchutil.IsFormatPatch(s.Patch)
213}
214
215func (s PullSubmission) AsFormatPatch() []types.FormatPatch {
216 patches, err := patchutil.ExtractPatches(s.Patch)
217 if err != nil {
218 log.Println("error extracting patches from submission:", err)
219 return []types.FormatPatch{}
220 }
221
222 return patches
223}
224
225func NewPull(tx *sql.Tx, pull *Pull) error {
226 _, err := tx.Exec(`
227 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
228 values (?, 1)
229 `, pull.RepoAt)
230 if err != nil {
231 return err
232 }
233
234 var nextId int
235 err = tx.QueryRow(`
236 update repo_pull_seqs
237 set next_pull_id = next_pull_id + 1
238 where repo_at = ?
239 returning next_pull_id - 1
240 `, pull.RepoAt).Scan(&nextId)
241 if err != nil {
242 return err
243 }
244
245 pull.PullId = nextId
246 pull.State = PullOpen
247
248 var sourceBranch, sourceRepoAt *string
249 if pull.PullSource != nil {
250 sourceBranch = &pull.PullSource.Branch
251 if pull.PullSource.RepoAt != nil {
252 x := pull.PullSource.RepoAt.String()
253 sourceRepoAt = &x
254 }
255 }
256
257 var stackId, changeId, parentChangeId *string
258 if pull.StackId != "" {
259 stackId = &pull.StackId
260 }
261 if pull.ChangeId != "" {
262 changeId = &pull.ChangeId
263 }
264 if pull.ParentChangeId != "" {
265 parentChangeId = &pull.ParentChangeId
266 }
267
268 _, err = tx.Exec(
269 `
270 insert into pulls (
271 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
272 )
273 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
274 pull.RepoAt,
275 pull.OwnerDid,
276 pull.PullId,
277 pull.Title,
278 pull.TargetBranch,
279 pull.Body,
280 pull.Rkey,
281 pull.State,
282 sourceBranch,
283 sourceRepoAt,
284 stackId,
285 changeId,
286 parentChangeId,
287 )
288 if err != nil {
289 return err
290 }
291
292 _, err = tx.Exec(`
293 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
294 values (?, ?, ?, ?, ?)
295 `, pull.PullId, pull.RepoAt, 0, pull.Submissions[0].Patch, pull.Submissions[0].SourceRev)
296 return err
297}
298
299func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
300 pull, err := GetPull(e, repoAt, pullId)
301 if err != nil {
302 return "", err
303 }
304 return pull.PullAt(), err
305}
306
307func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
308 var pullId int
309 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
310 return pullId - 1, err
311}
312
313func GetPulls(e Execer, filters ...filter) ([]*Pull, error) {
314 pulls := make(map[int]*Pull)
315
316 var conditions []string
317 var args []any
318 for _, filter := range filters {
319 conditions = append(conditions, filter.Condition())
320 args = append(args, filter.Arg()...)
321 }
322
323 whereClause := ""
324 if conditions != nil {
325 whereClause = " where " + strings.Join(conditions, " and ")
326 }
327
328 query := fmt.Sprintf(`
329 select
330 owner_did,
331 repo_at,
332 pull_id,
333 created,
334 title,
335 state,
336 target_branch,
337 body,
338 rkey,
339 source_branch,
340 source_repo_at,
341 stack_id,
342 change_id,
343 parent_change_id
344 from
345 pulls
346 %s
347 `, whereClause)
348
349 rows, err := e.Query(query, args...)
350 if err != nil {
351 return nil, err
352 }
353 defer rows.Close()
354
355 for rows.Next() {
356 var pull Pull
357 var createdAt string
358 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
359 err := rows.Scan(
360 &pull.OwnerDid,
361 &pull.RepoAt,
362 &pull.PullId,
363 &createdAt,
364 &pull.Title,
365 &pull.State,
366 &pull.TargetBranch,
367 &pull.Body,
368 &pull.Rkey,
369 &sourceBranch,
370 &sourceRepoAt,
371 &stackId,
372 &changeId,
373 &parentChangeId,
374 )
375 if err != nil {
376 return nil, err
377 }
378
379 createdTime, err := time.Parse(time.RFC3339, createdAt)
380 if err != nil {
381 return nil, err
382 }
383 pull.Created = createdTime
384
385 if sourceBranch.Valid {
386 pull.PullSource = &PullSource{
387 Branch: sourceBranch.String,
388 }
389 if sourceRepoAt.Valid {
390 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
391 if err != nil {
392 return nil, err
393 }
394 pull.PullSource.RepoAt = &sourceRepoAtParsed
395 }
396 }
397
398 if stackId.Valid {
399 pull.StackId = stackId.String
400 }
401 if changeId.Valid {
402 pull.ChangeId = changeId.String
403 }
404 if parentChangeId.Valid {
405 pull.ParentChangeId = parentChangeId.String
406 }
407
408 pulls[pull.PullId] = &pull
409 }
410
411 // get latest round no. for each pull
412 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
413 submissionsQuery := fmt.Sprintf(`
414 select
415 id, pull_id, round_number, patch, source_rev
416 from
417 pull_submissions
418 where
419 repo_at in (%s) and pull_id in (%s)
420 `, inClause, inClause)
421
422 args = make([]any, len(pulls)*2)
423 idx := 0
424 for _, p := range pulls {
425 args[idx] = p.RepoAt
426 idx += 1
427 }
428 for _, p := range pulls {
429 args[idx] = p.PullId
430 idx += 1
431 }
432 submissionsRows, err := e.Query(submissionsQuery, args...)
433 if err != nil {
434 return nil, err
435 }
436 defer submissionsRows.Close()
437
438 for submissionsRows.Next() {
439 var s PullSubmission
440 var sourceRev sql.NullString
441 err := submissionsRows.Scan(
442 &s.ID,
443 &s.PullId,
444 &s.RoundNumber,
445 &s.Patch,
446 &sourceRev,
447 )
448 if err != nil {
449 return nil, err
450 }
451
452 if sourceRev.Valid {
453 s.SourceRev = sourceRev.String
454 }
455
456 if p, ok := pulls[s.PullId]; ok {
457 p.Submissions = make([]*PullSubmission, s.RoundNumber+1)
458 p.Submissions[s.RoundNumber] = &s
459 }
460 }
461 if err := rows.Err(); err != nil {
462 return nil, err
463 }
464
465 // get comment count on latest submission on each pull
466 inClause = strings.TrimSuffix(strings.Repeat("?, ", len(pulls)), ", ")
467 commentsQuery := fmt.Sprintf(`
468 select
469 count(id), pull_id
470 from
471 pull_comments
472 where
473 submission_id in (%s)
474 group by
475 submission_id
476 `, inClause)
477
478 args = []any{}
479 for _, p := range pulls {
480 args = append(args, p.Submissions[p.LastRoundNumber()].ID)
481 }
482 commentsRows, err := e.Query(commentsQuery, args...)
483 if err != nil {
484 return nil, err
485 }
486 defer commentsRows.Close()
487
488 for commentsRows.Next() {
489 var commentCount, pullId int
490 err := commentsRows.Scan(
491 &commentCount,
492 &pullId,
493 )
494 if err != nil {
495 return nil, err
496 }
497 if p, ok := pulls[pullId]; ok {
498 p.Submissions[p.LastRoundNumber()].Comments = make([]PullComment, commentCount)
499 }
500 }
501 if err := rows.Err(); err != nil {
502 return nil, err
503 }
504
505 orderedByPullId := []*Pull{}
506 for _, p := range pulls {
507 orderedByPullId = append(orderedByPullId, p)
508 }
509 sort.Slice(orderedByPullId, func(i, j int) bool {
510 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
511 })
512
513 return orderedByPullId, nil
514}
515
516func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*Pull, error) {
517 query := `
518 select
519 owner_did,
520 pull_id,
521 created,
522 title,
523 state,
524 target_branch,
525 repo_at,
526 body,
527 rkey,
528 source_branch,
529 source_repo_at,
530 stack_id,
531 change_id,
532 parent_change_id
533 from
534 pulls
535 where
536 repo_at = ? and pull_id = ?
537 `
538 row := e.QueryRow(query, repoAt, pullId)
539
540 var pull Pull
541 var createdAt string
542 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
543 err := row.Scan(
544 &pull.OwnerDid,
545 &pull.PullId,
546 &createdAt,
547 &pull.Title,
548 &pull.State,
549 &pull.TargetBranch,
550 &pull.RepoAt,
551 &pull.Body,
552 &pull.Rkey,
553 &sourceBranch,
554 &sourceRepoAt,
555 &stackId,
556 &changeId,
557 &parentChangeId,
558 )
559 if err != nil {
560 return nil, err
561 }
562
563 createdTime, err := time.Parse(time.RFC3339, createdAt)
564 if err != nil {
565 return nil, err
566 }
567 pull.Created = createdTime
568
569 // populate source
570 if sourceBranch.Valid {
571 pull.PullSource = &PullSource{
572 Branch: sourceBranch.String,
573 }
574 if sourceRepoAt.Valid {
575 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
576 if err != nil {
577 return nil, err
578 }
579 pull.PullSource.RepoAt = &sourceRepoAtParsed
580 }
581 }
582
583 if stackId.Valid {
584 pull.StackId = stackId.String
585 }
586 if changeId.Valid {
587 pull.ChangeId = changeId.String
588 }
589 if parentChangeId.Valid {
590 pull.ParentChangeId = parentChangeId.String
591 }
592
593 submissionsQuery := `
594 select
595 id, pull_id, repo_at, round_number, patch, created, source_rev
596 from
597 pull_submissions
598 where
599 repo_at = ? and pull_id = ?
600 `
601 submissionsRows, err := e.Query(submissionsQuery, repoAt, pullId)
602 if err != nil {
603 return nil, err
604 }
605 defer submissionsRows.Close()
606
607 submissionsMap := make(map[int]*PullSubmission)
608
609 for submissionsRows.Next() {
610 var submission PullSubmission
611 var submissionCreatedStr string
612 var submissionSourceRev sql.NullString
613 err := submissionsRows.Scan(
614 &submission.ID,
615 &submission.PullId,
616 &submission.RepoAt,
617 &submission.RoundNumber,
618 &submission.Patch,
619 &submissionCreatedStr,
620 &submissionSourceRev,
621 )
622 if err != nil {
623 return nil, err
624 }
625
626 submissionCreatedTime, err := time.Parse(time.RFC3339, submissionCreatedStr)
627 if err != nil {
628 return nil, err
629 }
630 submission.Created = submissionCreatedTime
631
632 if submissionSourceRev.Valid {
633 submission.SourceRev = submissionSourceRev.String
634 }
635
636 submissionsMap[submission.ID] = &submission
637 }
638 if err = submissionsRows.Close(); err != nil {
639 return nil, err
640 }
641 if len(submissionsMap) == 0 {
642 return &pull, nil
643 }
644
645 var args []any
646 for k := range submissionsMap {
647 args = append(args, k)
648 }
649 inClause := strings.TrimSuffix(strings.Repeat("?, ", len(submissionsMap)), ", ")
650 commentsQuery := fmt.Sprintf(`
651 select
652 id,
653 pull_id,
654 submission_id,
655 repo_at,
656 owner_did,
657 comment_at,
658 body,
659 created
660 from
661 pull_comments
662 where
663 submission_id IN (%s)
664 order by
665 created asc
666 `, inClause)
667 commentsRows, err := e.Query(commentsQuery, args...)
668 if err != nil {
669 return nil, err
670 }
671 defer commentsRows.Close()
672
673 for commentsRows.Next() {
674 var comment PullComment
675 var commentCreatedStr string
676 err := commentsRows.Scan(
677 &comment.ID,
678 &comment.PullId,
679 &comment.SubmissionId,
680 &comment.RepoAt,
681 &comment.OwnerDid,
682 &comment.CommentAt,
683 &comment.Body,
684 &commentCreatedStr,
685 )
686 if err != nil {
687 return nil, err
688 }
689
690 commentCreatedTime, err := time.Parse(time.RFC3339, commentCreatedStr)
691 if err != nil {
692 return nil, err
693 }
694 comment.Created = commentCreatedTime
695
696 // Add the comment to its submission
697 if submission, ok := submissionsMap[comment.SubmissionId]; ok {
698 submission.Comments = append(submission.Comments, comment)
699 }
700
701 }
702 if err = commentsRows.Err(); err != nil {
703 return nil, err
704 }
705
706 var pullSourceRepo *Repo
707 if pull.PullSource != nil {
708 if pull.PullSource.RepoAt != nil {
709 pullSourceRepo, err = GetRepoByAtUri(e, pull.PullSource.RepoAt.String())
710 if err != nil {
711 log.Printf("failed to get repo by at uri: %v", err)
712 } else {
713 pull.PullSource.Repo = pullSourceRepo
714 }
715 }
716 }
717
718 pull.Submissions = make([]*PullSubmission, len(submissionsMap))
719 for _, submission := range submissionsMap {
720 pull.Submissions[submission.RoundNumber] = submission
721 }
722
723 return &pull, nil
724}
725
726// timeframe here is directly passed into the sql query filter, and any
727// timeframe in the past should be negative; e.g.: "-3 months"
728func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]Pull, error) {
729 var pulls []Pull
730
731 rows, err := e.Query(`
732 select
733 p.owner_did,
734 p.repo_at,
735 p.pull_id,
736 p.created,
737 p.title,
738 p.state,
739 r.did,
740 r.name,
741 r.knot,
742 r.rkey,
743 r.created
744 from
745 pulls p
746 join
747 repos r on p.repo_at = r.at_uri
748 where
749 p.owner_did = ? and p.created >= date ('now', ?)
750 order by
751 p.created desc`, did, timeframe)
752 if err != nil {
753 return nil, err
754 }
755 defer rows.Close()
756
757 for rows.Next() {
758 var pull Pull
759 var repo Repo
760 var pullCreatedAt, repoCreatedAt string
761 err := rows.Scan(
762 &pull.OwnerDid,
763 &pull.RepoAt,
764 &pull.PullId,
765 &pullCreatedAt,
766 &pull.Title,
767 &pull.State,
768 &repo.Did,
769 &repo.Name,
770 &repo.Knot,
771 &repo.Rkey,
772 &repoCreatedAt,
773 )
774 if err != nil {
775 return nil, err
776 }
777
778 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
779 if err != nil {
780 return nil, err
781 }
782 pull.Created = pullCreatedTime
783
784 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
785 if err != nil {
786 return nil, err
787 }
788 repo.Created = repoCreatedTime
789
790 pull.Repo = &repo
791
792 pulls = append(pulls, pull)
793 }
794
795 if err := rows.Err(); err != nil {
796 return nil, err
797 }
798
799 return pulls, nil
800}
801
802func NewPullComment(e Execer, comment *PullComment) (int64, error) {
803 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
804 res, err := e.Exec(
805 query,
806 comment.OwnerDid,
807 comment.RepoAt,
808 comment.SubmissionId,
809 comment.CommentAt,
810 comment.PullId,
811 comment.Body,
812 )
813 if err != nil {
814 return 0, err
815 }
816
817 i, err := res.LastInsertId()
818 if err != nil {
819 return 0, err
820 }
821
822 return i, nil
823}
824
825func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState PullState) error {
826 _, err := e.Exec(
827 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
828 pullState,
829 repoAt,
830 pullId,
831 PullDeleted, // only update state of non-deleted pulls
832 PullMerged, // only update state of non-merged pulls
833 )
834 return err
835}
836
837func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
838 err := SetPullState(e, repoAt, pullId, PullClosed)
839 return err
840}
841
842func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
843 err := SetPullState(e, repoAt, pullId, PullOpen)
844 return err
845}
846
847func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
848 err := SetPullState(e, repoAt, pullId, PullMerged)
849 return err
850}
851
852func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
853 err := SetPullState(e, repoAt, pullId, PullDeleted)
854 return err
855}
856
857func ResubmitPull(e Execer, pull *Pull, newPatch, sourceRev string) error {
858 newRoundNumber := len(pull.Submissions)
859 _, err := e.Exec(`
860 insert into pull_submissions (pull_id, repo_at, round_number, patch, source_rev)
861 values (?, ?, ?, ?, ?)
862 `, pull.PullId, pull.RepoAt, newRoundNumber, newPatch, sourceRev)
863
864 return err
865}
866
867func SetPullParentChangeId(e Execer, parentChangeId string, filters ...filter) error {
868 var conditions []string
869 var args []any
870
871 args = append(args, parentChangeId)
872
873 for _, filter := range filters {
874 conditions = append(conditions, filter.Condition())
875 args = append(args, filter.Arg()...)
876 }
877
878 whereClause := ""
879 if conditions != nil {
880 whereClause = " where " + strings.Join(conditions, " and ")
881 }
882
883 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
884 _, err := e.Exec(query, args...)
885
886 return err
887}
888
889// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
890// otherwise submissions are immutable
891func UpdatePull(e Execer, newPatch, sourceRev string, filters ...filter) error {
892 var conditions []string
893 var args []any
894
895 args = append(args, sourceRev)
896 args = append(args, newPatch)
897
898 for _, filter := range filters {
899 conditions = append(conditions, filter.Condition())
900 args = append(args, filter.Arg()...)
901 }
902
903 whereClause := ""
904 if conditions != nil {
905 whereClause = " where " + strings.Join(conditions, " and ")
906 }
907
908 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
909 _, err := e.Exec(query, args...)
910
911 return err
912}
913
// PullCount holds the number of pulls in each state for one repository, as
// computed by GetPullCount.
type PullCount struct {
	Open    int
	Merged  int
	Closed  int
	Deleted int
}
920
921func GetPullCount(e Execer, repoAt syntax.ATURI) (PullCount, error) {
922 row := e.QueryRow(`
923 select
924 count(case when state = ? then 1 end) as open_count,
925 count(case when state = ? then 1 end) as merged_count,
926 count(case when state = ? then 1 end) as closed_count,
927 count(case when state = ? then 1 end) as deleted_count
928 from pulls
929 where repo_at = ?`,
930 PullOpen,
931 PullMerged,
932 PullClosed,
933 PullDeleted,
934 repoAt,
935 )
936
937 var count PullCount
938 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
939 return PullCount{0, 0, 0, 0}, err
940 }
941
942 return count, nil
943}
944
// Stack is an ordered list of stacked pulls. As built by GetStack, index 0 is
// the top of the stack and the last element is the bottom.
type Stack []*Pull
946
947// change-id parent-change-id
948//
949// 4 w ,-------- z (TOP)
950// 3 z <----',------- y
951// 2 y <-----',------ x
952// 1 x <------' nil (BOT)
953//
954// `w` is parent of none, so it is the top of the stack
955func GetStack(e Execer, stackId string) (Stack, error) {
956 unorderedPulls, err := GetPulls(
957 e,
958 FilterEq("stack_id", stackId),
959 FilterNotEq("state", PullDeleted),
960 )
961 if err != nil {
962 return nil, err
963 }
964 // map of parent-change-id to pull
965 changeIdMap := make(map[string]*Pull, len(unorderedPulls))
966 parentMap := make(map[string]*Pull, len(unorderedPulls))
967 for _, p := range unorderedPulls {
968 changeIdMap[p.ChangeId] = p
969 if p.ParentChangeId != "" {
970 parentMap[p.ParentChangeId] = p
971 }
972 }
973
974 // the top of the stack is the pull that is not a parent of any pull
975 var topPull *Pull
976 for _, maybeTop := range unorderedPulls {
977 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
978 topPull = maybeTop
979 break
980 }
981 }
982
983 pulls := []*Pull{}
984 for {
985 pulls = append(pulls, topPull)
986 if topPull.ParentChangeId != "" {
987 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
988 topPull = next
989 } else {
990 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
991 }
992 } else {
993 break
994 }
995 }
996
997 return pulls, nil
998}
999
1000func GetAbandonedPulls(e Execer, stackId string) ([]*Pull, error) {
1001 pulls, err := GetPulls(
1002 e,
1003 FilterEq("stack_id", stackId),
1004 FilterEq("state", PullDeleted),
1005 )
1006 if err != nil {
1007 return nil, err
1008 }
1009
1010 return pulls, nil
1011}
1012
1013// position of this pull in the stack
1014func (stack Stack) Position(pull *Pull) int {
1015 return slices.IndexFunc(stack, func(p *Pull) bool {
1016 return p.ChangeId == pull.ChangeId
1017 })
1018}
1019
1020// all pulls below this pull (including self) in this stack
1021//
1022// nil if this pull does not belong to this stack
1023func (stack Stack) Below(pull *Pull) Stack {
1024 position := stack.Position(pull)
1025
1026 if position < 0 {
1027 return nil
1028 }
1029
1030 return stack[position:]
1031}
1032
1033// all pulls below this pull (excluding self) in this stack
1034func (stack Stack) StrictlyBelow(pull *Pull) Stack {
1035 below := stack.Below(pull)
1036
1037 if len(below) > 0 {
1038 return below[1:]
1039 }
1040
1041 return nil
1042}
1043
1044// all pulls above this pull (including self) in this stack
1045func (stack Stack) Above(pull *Pull) Stack {
1046 position := stack.Position(pull)
1047
1048 if position < 0 {
1049 return nil
1050 }
1051
1052 return stack[:position+1]
1053}
1054
1055// all pulls below this pull (excluding self) in this stack
1056func (stack Stack) StrictlyAbove(pull *Pull) Stack {
1057 above := stack.Above(pull)
1058
1059 if len(above) > 0 {
1060 return above[:len(above)-1]
1061 }
1062
1063 return nil
1064}
1065
1066// the combined format-patches of all the newest submissions in this stack
1067func (stack Stack) CombinedPatch() string {
1068 // go in reverse order because the bottom of the stack is the last element in the slice
1069 var combined strings.Builder
1070 for idx := range stack {
1071 pull := stack[len(stack)-1-idx]
1072 combined.WriteString(pull.LatestPatch())
1073 combined.WriteString("\n")
1074 }
1075 return combined.String()
1076}
1077
1078// filter out PRs that are "active"
1079//
1080// PRs that are still open are active
1081func (stack Stack) Mergeable() Stack {
1082 var mergeable Stack
1083
1084 for _, p := range stack {
1085 // stop at the first merged PR
1086 if p.State == PullMerged || p.State == PullClosed {
1087 break
1088 }
1089
1090 // skip over deleted PRs
1091 if p.State != PullDeleted {
1092 mergeable = append(mergeable, p)
1093 }
1094 }
1095
1096 return mergeable
1097}