1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.org/core/appview/models"
16 "tangled.org/core/appview/pagination"
17 "tangled.org/core/orm"
18)
19
20func NewPull(tx *sql.Tx, pull *models.Pull) error {
21 _, err := tx.Exec(`
22 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
23 values (?, 1)
24 `, pull.RepoAt)
25 if err != nil {
26 return err
27 }
28
29 var nextId int
30 err = tx.QueryRow(`
31 update repo_pull_seqs
32 set next_pull_id = next_pull_id + 1
33 where repo_at = ?
34 returning next_pull_id - 1
35 `, pull.RepoAt).Scan(&nextId)
36 if err != nil {
37 return err
38 }
39
40 pull.PullId = nextId
41 pull.State = models.PullOpen
42
43 var sourceBranch, sourceRepoAt *string
44 if pull.PullSource != nil {
45 sourceBranch = &pull.PullSource.Branch
46 if pull.PullSource.RepoAt != nil {
47 x := pull.PullSource.RepoAt.String()
48 sourceRepoAt = &x
49 }
50 }
51
52 var stackId, changeId, parentChangeId *string
53 if pull.StackId != "" {
54 stackId = &pull.StackId
55 }
56 if pull.ChangeId != "" {
57 changeId = &pull.ChangeId
58 }
59 if pull.ParentChangeId != "" {
60 parentChangeId = &pull.ParentChangeId
61 }
62
63 result, err := tx.Exec(
64 `
65 insert into pulls (
66 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
67 )
68 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
69 pull.RepoAt,
70 pull.OwnerDid,
71 pull.PullId,
72 pull.Title,
73 pull.TargetBranch,
74 pull.Body,
75 pull.Rkey,
76 pull.State,
77 sourceBranch,
78 sourceRepoAt,
79 stackId,
80 changeId,
81 parentChangeId,
82 )
83 if err != nil {
84 return err
85 }
86
87 // Set the database primary key ID
88 id, err := result.LastInsertId()
89 if err != nil {
90 return err
91 }
92 pull.ID = int(id)
93
94 _, err = tx.Exec(`
95 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
96 values (?, ?, ?, ?, ?)
97 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev)
98 if err != nil {
99 return err
100 }
101
102 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil {
103 return fmt.Errorf("put reference_links: %w", err)
104 }
105
106 return nil
107}
108
109func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
110 pull, err := GetPull(e, repoAt, pullId)
111 if err != nil {
112 return "", err
113 }
114 return pull.AtUri(), err
115}
116
117func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
118 var pullId int
119 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
120 return pullId - 1, err
121}
122
123func GetPullsPaginated(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.Pull, error) {
124 pulls := make(map[syntax.ATURI]*models.Pull)
125
126 var conditions []string
127 var args []any
128 for _, filter := range filters {
129 conditions = append(conditions, filter.Condition())
130 args = append(args, filter.Arg()...)
131 }
132
133 whereClause := ""
134 if conditions != nil {
135 whereClause = " where " + strings.Join(conditions, " and ")
136 }
137 pageClause := ""
138 if page.Limit != 0 {
139 pageClause = fmt.Sprintf(
140 " limit %d offset %d ",
141 page.Limit,
142 page.Offset,
143 )
144 }
145
146 query := fmt.Sprintf(`
147 select
148 id,
149 owner_did,
150 repo_at,
151 pull_id,
152 created,
153 title,
154 state,
155 target_branch,
156 body,
157 rkey,
158 source_branch,
159 source_repo_at,
160 stack_id,
161 change_id,
162 parent_change_id
163 from
164 pulls
165 %s
166 order by
167 created desc
168 %s
169 `, whereClause, pageClause)
170
171 rows, err := e.Query(query, args...)
172 if err != nil {
173 return nil, err
174 }
175 defer rows.Close()
176
177 for rows.Next() {
178 var pull models.Pull
179 var createdAt string
180 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
181 err := rows.Scan(
182 &pull.ID,
183 &pull.OwnerDid,
184 &pull.RepoAt,
185 &pull.PullId,
186 &createdAt,
187 &pull.Title,
188 &pull.State,
189 &pull.TargetBranch,
190 &pull.Body,
191 &pull.Rkey,
192 &sourceBranch,
193 &sourceRepoAt,
194 &stackId,
195 &changeId,
196 &parentChangeId,
197 )
198 if err != nil {
199 return nil, err
200 }
201
202 createdTime, err := time.Parse(time.RFC3339, createdAt)
203 if err != nil {
204 return nil, err
205 }
206 pull.Created = createdTime
207
208 if sourceBranch.Valid {
209 pull.PullSource = &models.PullSource{
210 Branch: sourceBranch.String,
211 }
212 if sourceRepoAt.Valid {
213 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
214 if err != nil {
215 return nil, err
216 }
217 pull.PullSource.RepoAt = &sourceRepoAtParsed
218 }
219 }
220
221 if stackId.Valid {
222 pull.StackId = stackId.String
223 }
224 if changeId.Valid {
225 pull.ChangeId = changeId.String
226 }
227 if parentChangeId.Valid {
228 pull.ParentChangeId = parentChangeId.String
229 }
230
231 pulls[pull.AtUri()] = &pull
232 }
233
234 var pullAts []syntax.ATURI
235 for _, p := range pulls {
236 pullAts = append(pullAts, p.AtUri())
237 }
238 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts))
239 if err != nil {
240 return nil, fmt.Errorf("failed to get submissions: %w", err)
241 }
242
243 for pullAt, submissions := range submissionsMap {
244 if p, ok := pulls[pullAt]; ok {
245 p.Submissions = submissions
246 }
247 }
248
249 // collect allLabels for each issue
250 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts))
251 if err != nil {
252 return nil, fmt.Errorf("failed to query labels: %w", err)
253 }
254 for pullAt, labels := range allLabels {
255 if p, ok := pulls[pullAt]; ok {
256 p.Labels = labels
257 }
258 }
259
260 // collect pull source for all pulls that need it
261 var sourceAts []syntax.ATURI
262 for _, p := range pulls {
263 if p.PullSource != nil && p.PullSource.RepoAt != nil {
264 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
265 }
266 }
267 sourceRepos, err := GetRepos(e, 0, orm.FilterIn("at_uri", sourceAts))
268 if err != nil && !errors.Is(err, sql.ErrNoRows) {
269 return nil, fmt.Errorf("failed to get source repos: %w", err)
270 }
271 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
272 for _, r := range sourceRepos {
273 sourceRepoMap[r.RepoAt()] = &r
274 }
275 for _, p := range pulls {
276 if p.PullSource != nil && p.PullSource.RepoAt != nil {
277 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
278 p.PullSource.Repo = sourceRepo
279 }
280 }
281 }
282
283 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts))
284 if err != nil {
285 return nil, fmt.Errorf("failed to query reference_links: %w", err)
286 }
287 for pullAt, references := range allReferences {
288 if pull, ok := pulls[pullAt]; ok {
289 pull.References = references
290 }
291 }
292
293 orderedByPullId := []*models.Pull{}
294 for _, p := range pulls {
295 orderedByPullId = append(orderedByPullId, p)
296 }
297 sort.Slice(orderedByPullId, func(i, j int) bool {
298 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
299 })
300
301 return orderedByPullId, nil
302}
303
304func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) {
305 return GetPullsPaginated(e, pagination.Page{}, filters...)
306}
307
308func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
309 pulls, err := GetPullsPaginated(e, pagination.Page{Limit: 1}, orm.FilterEq("repo_at", repoAt), orm.FilterEq("pull_id", pullId))
310 if err != nil {
311 return nil, err
312 }
313 if len(pulls) == 0 {
314 return nil, sql.ErrNoRows
315 }
316
317 return pulls[0], nil
318}
319
320// mapping from pull -> pull submissions
321func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
322 var conditions []string
323 var args []any
324 for _, filter := range filters {
325 conditions = append(conditions, filter.Condition())
326 args = append(args, filter.Arg()...)
327 }
328
329 whereClause := ""
330 if conditions != nil {
331 whereClause = " where " + strings.Join(conditions, " and ")
332 }
333
334 query := fmt.Sprintf(`
335 select
336 id,
337 pull_at,
338 round_number,
339 patch,
340 combined,
341 created,
342 source_rev
343 from
344 pull_submissions
345 %s
346 order by
347 round_number asc
348 `, whereClause)
349
350 rows, err := e.Query(query, args...)
351 if err != nil {
352 return nil, err
353 }
354 defer rows.Close()
355
356 submissionMap := make(map[int]*models.PullSubmission)
357
358 for rows.Next() {
359 var submission models.PullSubmission
360 var submissionCreatedStr string
361 var submissionSourceRev, submissionCombined sql.NullString
362 err := rows.Scan(
363 &submission.ID,
364 &submission.PullAt,
365 &submission.RoundNumber,
366 &submission.Patch,
367 &submissionCombined,
368 &submissionCreatedStr,
369 &submissionSourceRev,
370 )
371 if err != nil {
372 return nil, err
373 }
374
375 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
376 submission.Created = t
377 }
378
379 if submissionSourceRev.Valid {
380 submission.SourceRev = submissionSourceRev.String
381 }
382
383 if submissionCombined.Valid {
384 submission.Combined = submissionCombined.String
385 }
386
387 submissionMap[submission.ID] = &submission
388 }
389
390 if err := rows.Err(); err != nil {
391 return nil, err
392 }
393
394 // Get comments for all submissions using GetPullComments
395 submissionIds := slices.Collect(maps.Keys(submissionMap))
396 comments, err := GetPullComments(e, orm.FilterIn("submission_id", submissionIds))
397 if err != nil {
398 return nil, fmt.Errorf("failed to get pull comments: %w", err)
399 }
400 for _, comment := range comments {
401 if submission, ok := submissionMap[comment.SubmissionId]; ok {
402 submission.Comments = append(submission.Comments, comment)
403 }
404 }
405
406 // group the submissions by pull_at
407 m := make(map[syntax.ATURI][]*models.PullSubmission)
408 for _, s := range submissionMap {
409 m[s.PullAt] = append(m[s.PullAt], s)
410 }
411
412 // sort each one by round number
413 for _, s := range m {
414 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
415 return cmp.Compare(a.RoundNumber, b.RoundNumber)
416 })
417 }
418
419 return m, nil
420}
421
422func GetPullComments(e Execer, filters ...orm.Filter) ([]models.PullComment, error) {
423 var conditions []string
424 var args []any
425 for _, filter := range filters {
426 conditions = append(conditions, filter.Condition())
427 args = append(args, filter.Arg()...)
428 }
429
430 whereClause := ""
431 if conditions != nil {
432 whereClause = " where " + strings.Join(conditions, " and ")
433 }
434
435 query := fmt.Sprintf(`
436 select
437 id,
438 pull_id,
439 submission_id,
440 repo_at,
441 owner_did,
442 comment_at,
443 body,
444 created
445 from
446 pull_comments
447 %s
448 order by
449 created asc
450 `, whereClause)
451
452 rows, err := e.Query(query, args...)
453 if err != nil {
454 return nil, err
455 }
456 defer rows.Close()
457
458 commentMap := make(map[string]*models.PullComment)
459 for rows.Next() {
460 var comment models.PullComment
461 var createdAt string
462 err := rows.Scan(
463 &comment.ID,
464 &comment.PullId,
465 &comment.SubmissionId,
466 &comment.RepoAt,
467 &comment.OwnerDid,
468 &comment.CommentAt,
469 &comment.Body,
470 &createdAt,
471 )
472 if err != nil {
473 return nil, err
474 }
475
476 if t, err := time.Parse(time.RFC3339, createdAt); err == nil {
477 comment.Created = t
478 }
479
480 atUri := comment.AtUri().String()
481 commentMap[atUri] = &comment
482 }
483
484 if err := rows.Err(); err != nil {
485 return nil, err
486 }
487
488 // collect references for each comments
489 commentAts := slices.Collect(maps.Keys(commentMap))
490 allReferencs, err := GetReferencesAll(e, orm.FilterIn("from_at", commentAts))
491 if err != nil {
492 return nil, fmt.Errorf("failed to query reference_links: %w", err)
493 }
494 for commentAt, references := range allReferencs {
495 if comment, ok := commentMap[commentAt.String()]; ok {
496 comment.References = references
497 }
498 }
499
500 var comments []models.PullComment
501 for _, c := range commentMap {
502 comments = append(comments, *c)
503 }
504
505 sort.Slice(comments, func(i, j int) bool {
506 return comments[i].Created.Before(comments[j].Created)
507 })
508
509 return comments, nil
510}
511
512// timeframe here is directly passed into the sql query filter, and any
513// timeframe in the past should be negative; e.g.: "-3 months"
514func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
515 var pulls []models.Pull
516
517 rows, err := e.Query(`
518 select
519 p.owner_did,
520 p.repo_at,
521 p.pull_id,
522 p.created,
523 p.title,
524 p.state,
525 r.did,
526 r.name,
527 r.knot,
528 r.rkey,
529 r.created
530 from
531 pulls p
532 join
533 repos r on p.repo_at = r.at_uri
534 where
535 p.owner_did = ? and p.created >= date ('now', ?)
536 order by
537 p.created desc`, did, timeframe)
538 if err != nil {
539 return nil, err
540 }
541 defer rows.Close()
542
543 for rows.Next() {
544 var pull models.Pull
545 var repo models.Repo
546 var pullCreatedAt, repoCreatedAt string
547 err := rows.Scan(
548 &pull.OwnerDid,
549 &pull.RepoAt,
550 &pull.PullId,
551 &pullCreatedAt,
552 &pull.Title,
553 &pull.State,
554 &repo.Did,
555 &repo.Name,
556 &repo.Knot,
557 &repo.Rkey,
558 &repoCreatedAt,
559 )
560 if err != nil {
561 return nil, err
562 }
563
564 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
565 if err != nil {
566 return nil, err
567 }
568 pull.Created = pullCreatedTime
569
570 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
571 if err != nil {
572 return nil, err
573 }
574 repo.Created = repoCreatedTime
575
576 pull.Repo = &repo
577
578 pulls = append(pulls, pull)
579 }
580
581 if err := rows.Err(); err != nil {
582 return nil, err
583 }
584
585 return pulls, nil
586}
587
588func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) {
589 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)`
590 res, err := tx.Exec(
591 query,
592 comment.OwnerDid,
593 comment.RepoAt,
594 comment.SubmissionId,
595 comment.CommentAt,
596 comment.PullId,
597 comment.Body,
598 )
599 if err != nil {
600 return 0, err
601 }
602
603 i, err := res.LastInsertId()
604 if err != nil {
605 return 0, err
606 }
607
608 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil {
609 return 0, fmt.Errorf("put reference_links: %w", err)
610 }
611
612 return i, nil
613}
614
615func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
616 _, err := e.Exec(
617 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
618 pullState,
619 repoAt,
620 pullId,
621 models.PullDeleted, // only update state of non-deleted pulls
622 models.PullMerged, // only update state of non-merged pulls
623 )
624 return err
625}
626
627func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
628 err := SetPullState(e, repoAt, pullId, models.PullClosed)
629 return err
630}
631
632func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
633 err := SetPullState(e, repoAt, pullId, models.PullOpen)
634 return err
635}
636
637func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
638 err := SetPullState(e, repoAt, pullId, models.PullMerged)
639 return err
640}
641
642func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
643 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
644 return err
645}
646
647func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error {
648 _, err := e.Exec(`
649 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
650 values (?, ?, ?, ?, ?)
651 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev)
652
653 return err
654}
655
656func SetPullParentChangeId(e Execer, parentChangeId string, filters ...orm.Filter) error {
657 var conditions []string
658 var args []any
659
660 args = append(args, parentChangeId)
661
662 for _, filter := range filters {
663 conditions = append(conditions, filter.Condition())
664 args = append(args, filter.Arg()...)
665 }
666
667 whereClause := ""
668 if conditions != nil {
669 whereClause = " where " + strings.Join(conditions, " and ")
670 }
671
672 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
673 _, err := e.Exec(query, args...)
674
675 return err
676}
677
678// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
679// otherwise submissions are immutable
680func UpdatePull(e Execer, newPatch, sourceRev string, filters ...orm.Filter) error {
681 var conditions []string
682 var args []any
683
684 args = append(args, sourceRev)
685 args = append(args, newPatch)
686
687 for _, filter := range filters {
688 conditions = append(conditions, filter.Condition())
689 args = append(args, filter.Arg()...)
690 }
691
692 whereClause := ""
693 if conditions != nil {
694 whereClause = " where " + strings.Join(conditions, " and ")
695 }
696
697 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
698 _, err := e.Exec(query, args...)
699
700 return err
701}
702
703func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
704 row := e.QueryRow(`
705 select
706 count(case when state = ? then 1 end) as open_count,
707 count(case when state = ? then 1 end) as merged_count,
708 count(case when state = ? then 1 end) as closed_count,
709 count(case when state = ? then 1 end) as deleted_count
710 from pulls
711 where repo_at = ?`,
712 models.PullOpen,
713 models.PullMerged,
714 models.PullClosed,
715 models.PullDeleted,
716 repoAt,
717 )
718
719 var count models.PullCount
720 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
721 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
722 }
723
724 return count, nil
725}
726
727// change-id parent-change-id
728//
729// 4 w ,-------- z (TOP)
730// 3 z <----',------- y
731// 2 y <-----',------ x
732// 1 x <------' nil (BOT)
733//
734// `w` is parent of none, so it is the top of the stack
735func GetStack(e Execer, stackId string) (models.Stack, error) {
736 unorderedPulls, err := GetPulls(
737 e,
738 orm.FilterEq("stack_id", stackId),
739 orm.FilterNotEq("state", models.PullDeleted),
740 )
741 if err != nil {
742 return nil, err
743 }
744 // map of parent-change-id to pull
745 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
746 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
747 for _, p := range unorderedPulls {
748 changeIdMap[p.ChangeId] = p
749 if p.ParentChangeId != "" {
750 parentMap[p.ParentChangeId] = p
751 }
752 }
753
754 // the top of the stack is the pull that is not a parent of any pull
755 var topPull *models.Pull
756 for _, maybeTop := range unorderedPulls {
757 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
758 topPull = maybeTop
759 break
760 }
761 }
762
763 pulls := []*models.Pull{}
764 for {
765 pulls = append(pulls, topPull)
766 if topPull.ParentChangeId != "" {
767 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
768 topPull = next
769 } else {
770 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
771 }
772 } else {
773 break
774 }
775 }
776
777 return pulls, nil
778}
779
780func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
781 pulls, err := GetPulls(
782 e,
783 orm.FilterEq("stack_id", stackId),
784 orm.FilterEq("state", models.PullDeleted),
785 )
786 if err != nil {
787 return nil, err
788 }
789
790 return pulls, nil
791}