forked from tangled.org/core
Monorepo for Tangled
at master 791 lines 18 kB view raw
1package db 2 3import ( 4 "cmp" 5 "database/sql" 6 "errors" 7 "fmt" 8 "maps" 9 "slices" 10 "sort" 11 "strings" 12 "time" 13 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "tangled.org/core/appview/models" 16 "tangled.org/core/appview/pagination" 17 "tangled.org/core/orm" 18) 19 20func NewPull(tx *sql.Tx, pull *models.Pull) error { 21 _, err := tx.Exec(` 22 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 23 values (?, 1) 24 `, pull.RepoAt) 25 if err != nil { 26 return err 27 } 28 29 var nextId int 30 err = tx.QueryRow(` 31 update repo_pull_seqs 32 set next_pull_id = next_pull_id + 1 33 where repo_at = ? 34 returning next_pull_id - 1 35 `, pull.RepoAt).Scan(&nextId) 36 if err != nil { 37 return err 38 } 39 40 pull.PullId = nextId 41 pull.State = models.PullOpen 42 43 var sourceBranch, sourceRepoAt *string 44 if pull.PullSource != nil { 45 sourceBranch = &pull.PullSource.Branch 46 if pull.PullSource.RepoAt != nil { 47 x := pull.PullSource.RepoAt.String() 48 sourceRepoAt = &x 49 } 50 } 51 52 var stackId, changeId, parentChangeId *string 53 if pull.StackId != "" { 54 stackId = &pull.StackId 55 } 56 if pull.ChangeId != "" { 57 changeId = &pull.ChangeId 58 } 59 if pull.ParentChangeId != "" { 60 parentChangeId = &pull.ParentChangeId 61 } 62 63 result, err := tx.Exec( 64 ` 65 insert into pulls ( 66 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 67 ) 68 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 69 pull.RepoAt, 70 pull.OwnerDid, 71 pull.PullId, 72 pull.Title, 73 pull.TargetBranch, 74 pull.Body, 75 pull.Rkey, 76 pull.State, 77 sourceBranch, 78 sourceRepoAt, 79 stackId, 80 changeId, 81 parentChangeId, 82 ) 83 if err != nil { 84 return err 85 } 86 87 // Set the database primary key ID 88 id, err := result.LastInsertId() 89 if err != nil { 90 return err 91 } 92 pull.ID = int(id) 93 94 _, err = tx.Exec(` 95 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 96 values (?, ?, ?, ?, ?) 97 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev) 98 if err != nil { 99 return err 100 } 101 102 if err := putReferences(tx, pull.AtUri(), pull.References); err != nil { 103 return fmt.Errorf("put reference_links: %w", err) 104 } 105 106 return nil 107} 108 109func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 110 pull, err := GetPull(e, repoAt, pullId) 111 if err != nil { 112 return "", err 113 } 114 return pull.AtUri(), err 115} 116 117func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 118 var pullId int 119 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 120 return pullId - 1, err 121} 122 123func GetPullsPaginated(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.Pull, error) { 124 pulls := make(map[syntax.ATURI]*models.Pull) 125 126 var conditions []string 127 var args []any 128 for _, filter := range filters { 129 conditions = append(conditions, filter.Condition()) 130 args = append(args, filter.Arg()...) 131 } 132 133 whereClause := "" 134 if conditions != nil { 135 whereClause = " where " + strings.Join(conditions, " and ") 136 } 137 pageClause := "" 138 if page.Limit != 0 { 139 pageClause = fmt.Sprintf( 140 " limit %d offset %d ", 141 page.Limit, 142 page.Offset, 143 ) 144 } 145 146 query := fmt.Sprintf(` 147 select 148 id, 149 owner_did, 150 repo_at, 151 pull_id, 152 created, 153 title, 154 state, 155 target_branch, 156 body, 157 rkey, 158 source_branch, 159 source_repo_at, 160 stack_id, 161 change_id, 162 parent_change_id 163 from 164 pulls 165 %s 166 order by 167 created desc 168 %s 169 `, whereClause, pageClause) 170 171 rows, err := e.Query(query, args...) 172 if err != nil { 173 return nil, err 174 } 175 defer rows.Close() 176 177 for rows.Next() { 178 var pull models.Pull 179 var createdAt string 180 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 181 err := rows.Scan( 182 &pull.ID, 183 &pull.OwnerDid, 184 &pull.RepoAt, 185 &pull.PullId, 186 &createdAt, 187 &pull.Title, 188 &pull.State, 189 &pull.TargetBranch, 190 &pull.Body, 191 &pull.Rkey, 192 &sourceBranch, 193 &sourceRepoAt, 194 &stackId, 195 &changeId, 196 &parentChangeId, 197 ) 198 if err != nil { 199 return nil, err 200 } 201 202 createdTime, err := time.Parse(time.RFC3339, createdAt) 203 if err != nil { 204 return nil, err 205 } 206 pull.Created = createdTime 207 208 if sourceBranch.Valid { 209 pull.PullSource = &models.PullSource{ 210 Branch: sourceBranch.String, 211 } 212 if sourceRepoAt.Valid { 213 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 214 if err != nil { 215 return nil, err 216 } 217 pull.PullSource.RepoAt = &sourceRepoAtParsed 218 } 219 } 220 221 if stackId.Valid { 222 pull.StackId = stackId.String 223 } 224 if changeId.Valid { 225 pull.ChangeId = changeId.String 226 } 227 if parentChangeId.Valid { 228 pull.ParentChangeId = parentChangeId.String 229 } 230 231 pulls[pull.AtUri()] = &pull 232 } 233 234 var pullAts []syntax.ATURI 235 for _, p := range pulls { 236 pullAts = append(pullAts, p.AtUri()) 237 } 238 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts)) 239 if err != nil { 240 return nil, fmt.Errorf("failed to get submissions: %w", err) 241 } 242 243 for pullAt, submissions := range submissionsMap { 244 if p, ok := pulls[pullAt]; ok { 245 p.Submissions = submissions 246 } 247 } 248 249 // collect allLabels for each issue 250 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts)) 251 if err != nil { 252 return nil, fmt.Errorf("failed to query labels: %w", err) 253 } 254 for pullAt, labels := range allLabels { 255 if p, ok := pulls[pullAt]; ok { 256 p.Labels = labels 257 } 258 } 259 260 // collect pull source for all pulls that need it 261 var sourceAts []syntax.ATURI 262 for _, p := range pulls { 263 if p.PullSource != nil && p.PullSource.RepoAt != nil { 264 sourceAts = append(sourceAts, *p.PullSource.RepoAt) 265 } 266 } 267 sourceRepos, err := GetRepos(e, 0, orm.FilterIn("at_uri", sourceAts)) 268 if err != nil && !errors.Is(err, sql.ErrNoRows) { 269 return nil, fmt.Errorf("failed to get source repos: %w", err) 270 } 271 sourceRepoMap := make(map[syntax.ATURI]*models.Repo) 272 for _, r := range sourceRepos { 273 sourceRepoMap[r.RepoAt()] = &r 274 } 275 for _, p := range pulls { 276 if p.PullSource != nil && p.PullSource.RepoAt != nil { 277 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok { 278 p.PullSource.Repo = sourceRepo 279 } 280 } 281 } 282 283 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts)) 284 if err != nil { 285 return nil, fmt.Errorf("failed to query reference_links: %w", err) 286 } 287 for pullAt, references := range allReferences { 288 if pull, ok := pulls[pullAt]; ok { 289 pull.References = references 290 } 291 } 292 293 orderedByPullId := []*models.Pull{} 294 for _, p := range pulls { 295 orderedByPullId = append(orderedByPullId, p) 296 } 297 sort.Slice(orderedByPullId, func(i, j int) bool { 298 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 299 }) 300 301 return orderedByPullId, nil 302} 303 304func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) { 305 return GetPullsPaginated(e, pagination.Page{}, filters...) 306} 307 308func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 309 pulls, err := GetPullsPaginated(e, pagination.Page{Limit: 1}, orm.FilterEq("repo_at", repoAt), orm.FilterEq("pull_id", pullId)) 310 if err != nil { 311 return nil, err 312 } 313 if len(pulls) == 0 { 314 return nil, sql.ErrNoRows 315 } 316 317 return pulls[0], nil 318} 319 320// mapping from pull -> pull submissions 321func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) { 322 var conditions []string 323 var args []any 324 for _, filter := range filters { 325 conditions = append(conditions, filter.Condition()) 326 args = append(args, filter.Arg()...) 327 } 328 329 whereClause := "" 330 if conditions != nil { 331 whereClause = " where " + strings.Join(conditions, " and ") 332 } 333 334 query := fmt.Sprintf(` 335 select 336 id, 337 pull_at, 338 round_number, 339 patch, 340 combined, 341 created, 342 source_rev 343 from 344 pull_submissions 345 %s 346 order by 347 round_number asc 348 `, whereClause) 349 350 rows, err := e.Query(query, args...) 351 if err != nil { 352 return nil, err 353 } 354 defer rows.Close() 355 356 submissionMap := make(map[int]*models.PullSubmission) 357 358 for rows.Next() { 359 var submission models.PullSubmission 360 var submissionCreatedStr string 361 var submissionSourceRev, submissionCombined sql.NullString 362 err := rows.Scan( 363 &submission.ID, 364 &submission.PullAt, 365 &submission.RoundNumber, 366 &submission.Patch, 367 &submissionCombined, 368 &submissionCreatedStr, 369 &submissionSourceRev, 370 ) 371 if err != nil { 372 return nil, err 373 } 374 375 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil { 376 submission.Created = t 377 } 378 379 if submissionSourceRev.Valid { 380 submission.SourceRev = submissionSourceRev.String 381 } 382 383 if submissionCombined.Valid { 384 submission.Combined = submissionCombined.String 385 } 386 387 submissionMap[submission.ID] = &submission 388 } 389 390 if err := rows.Err(); err != nil { 391 return nil, err 392 } 393 394 // Get comments for all submissions using GetPullComments 395 submissionIds := slices.Collect(maps.Keys(submissionMap)) 396 comments, err := GetPullComments(e, orm.FilterIn("submission_id", submissionIds)) 397 if err != nil { 398 return nil, fmt.Errorf("failed to get pull comments: %w", err) 399 } 400 for _, comment := range comments { 401 if submission, ok := submissionMap[comment.SubmissionId]; ok { 402 submission.Comments = append(submission.Comments, comment) 403 } 404 } 405 406 // group the submissions by pull_at 407 m := make(map[syntax.ATURI][]*models.PullSubmission) 408 for _, s := range submissionMap { 409 m[s.PullAt] = append(m[s.PullAt], s) 410 } 411 412 // sort each one by round number 413 for _, s := range m { 414 slices.SortFunc(s, func(a, b *models.PullSubmission) int { 415 return cmp.Compare(a.RoundNumber, b.RoundNumber) 416 }) 417 } 418 419 return m, nil 420} 421 422func GetPullComments(e Execer, filters ...orm.Filter) ([]models.PullComment, error) { 423 var conditions []string 424 var args []any 425 for _, filter := range filters { 426 conditions = append(conditions, filter.Condition()) 427 args = append(args, filter.Arg()...) 428 } 429 430 whereClause := "" 431 if conditions != nil { 432 whereClause = " where " + strings.Join(conditions, " and ") 433 } 434 435 query := fmt.Sprintf(` 436 select 437 id, 438 pull_id, 439 submission_id, 440 repo_at, 441 owner_did, 442 comment_at, 443 body, 444 created 445 from 446 pull_comments 447 %s 448 order by 449 created asc 450 `, whereClause) 451 452 rows, err := e.Query(query, args...) 453 if err != nil { 454 return nil, err 455 } 456 defer rows.Close() 457 458 commentMap := make(map[string]*models.PullComment) 459 for rows.Next() { 460 var comment models.PullComment 461 var createdAt string 462 err := rows.Scan( 463 &comment.ID, 464 &comment.PullId, 465 &comment.SubmissionId, 466 &comment.RepoAt, 467 &comment.OwnerDid, 468 &comment.CommentAt, 469 &comment.Body, 470 &createdAt, 471 ) 472 if err != nil { 473 return nil, err 474 } 475 476 if t, err := time.Parse(time.RFC3339, createdAt); err == nil { 477 comment.Created = t 478 } 479 480 atUri := comment.AtUri().String() 481 commentMap[atUri] = &comment 482 } 483 484 if err := rows.Err(); err != nil { 485 return nil, err 486 } 487 488 // collect references for each comments 489 commentAts := slices.Collect(maps.Keys(commentMap)) 490 allReferencs, err := GetReferencesAll(e, orm.FilterIn("from_at", commentAts)) 491 if err != nil { 492 return nil, fmt.Errorf("failed to query reference_links: %w", err) 493 } 494 for commentAt, references := range allReferencs { 495 if comment, ok := commentMap[commentAt.String()]; ok { 496 comment.References = references 497 } 498 } 499 500 var comments []models.PullComment 501 for _, c := range commentMap { 502 comments = append(comments, *c) 503 } 504 505 sort.Slice(comments, func(i, j int) bool { 506 return comments[i].Created.Before(comments[j].Created) 507 }) 508 509 return comments, nil 510} 511 512// timeframe here is directly passed into the sql query filter, and any 513// timeframe in the past should be negative; e.g.: "-3 months" 514func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 515 var pulls []models.Pull 516 517 rows, err := e.Query(` 518 select 519 p.owner_did, 520 p.repo_at, 521 p.pull_id, 522 p.created, 523 p.title, 524 p.state, 525 r.did, 526 r.name, 527 r.knot, 528 r.rkey, 529 r.created 530 from 531 pulls p 532 join 533 repos r on p.repo_at = r.at_uri 534 where 535 p.owner_did = ? and p.created >= date ('now', ?) 536 order by 537 p.created desc`, did, timeframe) 538 if err != nil { 539 return nil, err 540 } 541 defer rows.Close() 542 543 for rows.Next() { 544 var pull models.Pull 545 var repo models.Repo 546 var pullCreatedAt, repoCreatedAt string 547 err := rows.Scan( 548 &pull.OwnerDid, 549 &pull.RepoAt, 550 &pull.PullId, 551 &pullCreatedAt, 552 &pull.Title, 553 &pull.State, 554 &repo.Did, 555 &repo.Name, 556 &repo.Knot, 557 &repo.Rkey, 558 &repoCreatedAt, 559 ) 560 if err != nil { 561 return nil, err 562 } 563 564 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 565 if err != nil { 566 return nil, err 567 } 568 pull.Created = pullCreatedTime 569 570 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 571 if err != nil { 572 return nil, err 573 } 574 repo.Created = repoCreatedTime 575 576 pull.Repo = &repo 577 578 pulls = append(pulls, pull) 579 } 580 581 if err := rows.Err(); err != nil { 582 return nil, err 583 } 584 585 return pulls, nil 586} 587 588func NewPullComment(tx *sql.Tx, comment *models.PullComment) (int64, error) { 589 query := `insert into pull_comments (owner_did, repo_at, submission_id, comment_at, pull_id, body) values (?, ?, ?, ?, ?, ?)` 590 res, err := tx.Exec( 591 query, 592 comment.OwnerDid, 593 comment.RepoAt, 594 comment.SubmissionId, 595 comment.CommentAt, 596 comment.PullId, 597 comment.Body, 598 ) 599 if err != nil { 600 return 0, err 601 } 602 603 i, err := res.LastInsertId() 604 if err != nil { 605 return 0, err 606 } 607 608 if err := putReferences(tx, comment.AtUri(), comment.References); err != nil { 609 return 0, fmt.Errorf("put reference_links: %w", err) 610 } 611 612 return i, nil 613} 614 615func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 616 _, err := e.Exec( 617 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 618 pullState, 619 repoAt, 620 pullId, 621 models.PullDeleted, // only update state of non-deleted pulls 622 models.PullMerged, // only update state of non-merged pulls 623 ) 624 return err 625} 626 627func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 628 err := SetPullState(e, repoAt, pullId, models.PullClosed) 629 return err 630} 631 632func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 633 err := SetPullState(e, repoAt, pullId, models.PullOpen) 634 return err 635} 636 637func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 638 err := SetPullState(e, repoAt, pullId, models.PullMerged) 639 return err 640} 641 642func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 643 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 644 return err 645} 646 647func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error { 648 _, err := e.Exec(` 649 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 650 values (?, ?, ?, ?, ?) 651 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev) 652 653 return err 654} 655 656func SetPullParentChangeId(e Execer, parentChangeId string, filters ...orm.Filter) error { 657 var conditions []string 658 var args []any 659 660 args = append(args, parentChangeId) 661 662 for _, filter := range filters { 663 conditions = append(conditions, filter.Condition()) 664 args = append(args, filter.Arg()...) 665 } 666 667 whereClause := "" 668 if conditions != nil { 669 whereClause = " where " + strings.Join(conditions, " and ") 670 } 671 672 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 673 _, err := e.Exec(query, args...) 674 675 return err 676} 677 678// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 679// otherwise submissions are immutable 680func UpdatePull(e Execer, newPatch, sourceRev string, filters ...orm.Filter) error { 681 var conditions []string 682 var args []any 683 684 args = append(args, sourceRev) 685 args = append(args, newPatch) 686 687 for _, filter := range filters { 688 conditions = append(conditions, filter.Condition()) 689 args = append(args, filter.Arg()...) 690 } 691 692 whereClause := "" 693 if conditions != nil { 694 whereClause = " where " + strings.Join(conditions, " and ") 695 } 696 697 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 698 _, err := e.Exec(query, args...) 699 700 return err 701} 702 703func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 704 row := e.QueryRow(` 705 select 706 count(case when state = ? then 1 end) as open_count, 707 count(case when state = ? then 1 end) as merged_count, 708 count(case when state = ? then 1 end) as closed_count, 709 count(case when state = ? then 1 end) as deleted_count 710 from pulls 711 where repo_at = ?`, 712 models.PullOpen, 713 models.PullMerged, 714 models.PullClosed, 715 models.PullDeleted, 716 repoAt, 717 ) 718 719 var count models.PullCount 720 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 721 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 722 } 723 724 return count, nil 725} 726 727// change-id parent-change-id 728// 729// 4 w ,-------- z (TOP) 730// 3 z <----',------- y 731// 2 y <-----',------ x 732// 1 x <------' nil (BOT) 733// 734// `w` is parent of none, so it is the top of the stack 735func GetStack(e Execer, stackId string) (models.Stack, error) { 736 unorderedPulls, err := GetPulls( 737 e, 738 orm.FilterEq("stack_id", stackId), 739 orm.FilterNotEq("state", models.PullDeleted), 740 ) 741 if err != nil { 742 return nil, err 743 } 744 // map of parent-change-id to pull 745 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 746 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 747 for _, p := range unorderedPulls { 748 changeIdMap[p.ChangeId] = p 749 if p.ParentChangeId != "" { 750 parentMap[p.ParentChangeId] = p 751 } 752 } 753 754 // the top of the stack is the pull that is not a parent of any pull 755 var topPull *models.Pull 756 for _, maybeTop := range unorderedPulls { 757 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 758 topPull = maybeTop 759 break 760 } 761 } 762 763 pulls := []*models.Pull{} 764 for { 765 pulls = append(pulls, topPull) 766 if topPull.ParentChangeId != "" { 767 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 768 topPull = next 769 } else { 770 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 771 } 772 } else { 773 break 774 } 775 } 776 777 return pulls, nil 778} 779 780func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 781 pulls, err := GetPulls( 782 e, 783 orm.FilterEq("stack_id", stackId), 784 orm.FilterEq("state", models.PullDeleted), 785 ) 786 if err != nil { 787 return nil, err 788 } 789 790 return pulls, nil 791}