Monorepo for Tangled tangled.org
at sl/shared-stacks 676 lines 16 kB view raw
1package db 2 3import ( 4 "cmp" 5 "database/sql" 6 "errors" 7 "fmt" 8 "maps" 9 "slices" 10 "sort" 11 "strings" 12 "time" 13 14 "github.com/bluesky-social/indigo/atproto/syntax" 15 "tangled.org/core/appview/models" 16 "tangled.org/core/appview/pagination" 17 "tangled.org/core/orm" 18) 19 20func NewPull(tx *sql.Tx, pull *models.Pull) error { 21 _, err := tx.Exec(` 22 insert or ignore into repo_pull_seqs (repo_at, next_pull_id) 23 values (?, 1) 24 `, pull.RepoAt) 25 if err != nil { 26 return err 27 } 28 29 var nextId int 30 err = tx.QueryRow(` 31 update repo_pull_seqs 32 set next_pull_id = next_pull_id + 1 33 where repo_at = ? 34 returning next_pull_id - 1 35 `, pull.RepoAt).Scan(&nextId) 36 if err != nil { 37 return err 38 } 39 40 pull.PullId = nextId 41 pull.State = models.PullOpen 42 43 var sourceBranch, sourceRepoAt *string 44 if pull.PullSource != nil { 45 sourceBranch = &pull.PullSource.Branch 46 if pull.PullSource.RepoAt != nil { 47 x := pull.PullSource.RepoAt.String() 48 sourceRepoAt = &x 49 } 50 } 51 52 var stackId, changeId, parentChangeId *string 53 if pull.StackId != "" { 54 stackId = &pull.StackId 55 } 56 if pull.ChangeId != "" { 57 changeId = &pull.ChangeId 58 } 59 if pull.ParentChangeId != "" { 60 parentChangeId = &pull.ParentChangeId 61 } 62 63 result, err := tx.Exec( 64 ` 65 insert into pulls ( 66 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id 67 ) 68 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`, 69 pull.RepoAt, 70 pull.OwnerDid, 71 pull.PullId, 72 pull.Title, 73 pull.TargetBranch, 74 pull.Body, 75 pull.Rkey, 76 pull.State, 77 sourceBranch, 78 sourceRepoAt, 79 stackId, 80 changeId, 81 parentChangeId, 82 ) 83 if err != nil { 84 return err 85 } 86 87 // Set the database primary key ID 88 id, err := result.LastInsertId() 89 if err != nil { 90 return err 91 } 92 pull.ID = int(id) 93 94 _, err = tx.Exec(` 95 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 96 values (?, ?, ?, ?, ?) 97 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev) 98 if err != nil { 99 return err 100 } 101 102 if err := putReferences(tx, pull.AtUri(), nil, pull.References); err != nil { 103 return fmt.Errorf("put reference_links: %w", err) 104 } 105 106 return nil 107} 108 109func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) { 110 pull, err := GetPull(e, repoAt, pullId) 111 if err != nil { 112 return "", err 113 } 114 return pull.AtUri(), err 115} 116 117func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) { 118 var pullId int 119 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId) 120 return pullId - 1, err 121} 122 123func GetPullsPaginated(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.Pull, error) { 124 pulls := make(map[syntax.ATURI]*models.Pull) 125 126 var conditions []string 127 var args []any 128 for _, filter := range filters { 129 conditions = append(conditions, filter.Condition()) 130 args = append(args, filter.Arg()...) 131 } 132 133 whereClause := "" 134 if conditions != nil { 135 whereClause = " where " + strings.Join(conditions, " and ") 136 } 137 pageClause := "" 138 if page.Limit != 0 { 139 pageClause = fmt.Sprintf( 140 " limit %d offset %d ", 141 page.Limit, 142 page.Offset, 143 ) 144 } 145 146 query := fmt.Sprintf(` 147 select 148 id, 149 owner_did, 150 repo_at, 151 pull_id, 152 created, 153 title, 154 state, 155 target_branch, 156 body, 157 rkey, 158 source_branch, 159 source_repo_at, 160 stack_id, 161 change_id, 162 parent_change_id 163 from 164 pulls 165 %s 166 order by 167 created desc 168 %s 169 `, whereClause, pageClause) 170 171 rows, err := e.Query(query, args...) 172 if err != nil { 173 return nil, err 174 } 175 defer rows.Close() 176 177 for rows.Next() { 178 var pull models.Pull 179 var createdAt string 180 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString 181 err := rows.Scan( 182 &pull.ID, 183 &pull.OwnerDid, 184 &pull.RepoAt, 185 &pull.PullId, 186 &createdAt, 187 &pull.Title, 188 &pull.State, 189 &pull.TargetBranch, 190 &pull.Body, 191 &pull.Rkey, 192 &sourceBranch, 193 &sourceRepoAt, 194 &stackId, 195 &changeId, 196 &parentChangeId, 197 ) 198 if err != nil { 199 return nil, err 200 } 201 202 createdTime, err := time.Parse(time.RFC3339, createdAt) 203 if err != nil { 204 return nil, err 205 } 206 pull.Created = createdTime 207 208 if sourceBranch.Valid { 209 pull.PullSource = &models.PullSource{ 210 Branch: sourceBranch.String, 211 } 212 if sourceRepoAt.Valid { 213 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String) 214 if err != nil { 215 return nil, err 216 } 217 pull.PullSource.RepoAt = &sourceRepoAtParsed 218 } 219 } 220 221 if stackId.Valid { 222 pull.StackId = stackId.String 223 } 224 if changeId.Valid { 225 pull.ChangeId = changeId.String 226 } 227 if parentChangeId.Valid { 228 pull.ParentChangeId = parentChangeId.String 229 } 230 231 pulls[pull.AtUri()] = &pull 232 } 233 234 var pullAts []syntax.ATURI 235 for _, p := range pulls { 236 pullAts = append(pullAts, p.AtUri()) 237 } 238 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts)) 239 if err != nil { 240 return nil, fmt.Errorf("failed to get submissions: %w", err) 241 } 242 243 for pullAt, submissions := range submissionsMap { 244 if p, ok := pulls[pullAt]; ok { 245 p.Submissions = submissions 246 } 247 } 248 249 // collect allLabels for each issue 250 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts)) 251 if err != nil { 252 return nil, fmt.Errorf("failed to query labels: %w", err) 253 } 254 for pullAt, labels := range allLabels { 255 if p, ok := pulls[pullAt]; ok { 256 p.Labels = labels 257 } 258 } 259 260 // collect pull source for all pulls that need it 261 var sourceAts []syntax.ATURI 262 for _, p := range pulls { 263 if p.PullSource != nil && p.PullSource.RepoAt != nil { 264 sourceAts = append(sourceAts, *p.PullSource.RepoAt) 265 } 266 } 267 sourceRepos, err := GetRepos(e, 0, orm.FilterIn("at_uri", sourceAts)) 268 if err != nil && !errors.Is(err, sql.ErrNoRows) { 269 return nil, fmt.Errorf("failed to get source repos: %w", err) 270 } 271 sourceRepoMap := make(map[syntax.ATURI]*models.Repo) 272 for _, r := range sourceRepos { 273 sourceRepoMap[r.RepoAt()] = &r 274 } 275 for _, p := range pulls { 276 if p.PullSource != nil && p.PullSource.RepoAt != nil { 277 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok { 278 p.PullSource.Repo = sourceRepo 279 } 280 } 281 } 282 283 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts)) 284 if err != nil { 285 return nil, fmt.Errorf("failed to query reference_links: %w", err) 286 } 287 for pullAt, references := range allReferences { 288 if pull, ok := pulls[pullAt]; ok { 289 pull.References = references 290 } 291 } 292 293 orderedByPullId := []*models.Pull{} 294 for _, p := range pulls { 295 orderedByPullId = append(orderedByPullId, p) 296 } 297 sort.Slice(orderedByPullId, func(i, j int) bool { 298 return orderedByPullId[i].PullId > orderedByPullId[j].PullId 299 }) 300 301 return orderedByPullId, nil 302} 303 304func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) { 305 return GetPullsPaginated(e, pagination.Page{}, filters...) 306} 307 308func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) { 309 pulls, err := GetPullsPaginated(e, pagination.Page{Limit: 1}, orm.FilterEq("repo_at", repoAt), orm.FilterEq("pull_id", pullId)) 310 if err != nil { 311 return nil, err 312 } 313 if len(pulls) == 0 { 314 return nil, sql.ErrNoRows 315 } 316 317 return pulls[0], nil 318} 319 320// mapping from pull -> pull submissions 321func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) { 322 var conditions []string 323 var args []any 324 for _, filter := range filters { 325 conditions = append(conditions, filter.Condition()) 326 args = append(args, filter.Arg()...) 327 } 328 329 whereClause := "" 330 if conditions != nil { 331 whereClause = " where " + strings.Join(conditions, " and ") 332 } 333 334 query := fmt.Sprintf(` 335 select 336 id, 337 pull_at, 338 round_number, 339 patch, 340 combined, 341 created, 342 source_rev 343 from 344 pull_submissions 345 %s 346 order by 347 round_number asc 348 `, whereClause) 349 350 rows, err := e.Query(query, args...) 351 if err != nil { 352 return nil, err 353 } 354 defer rows.Close() 355 356 submissionMap := make(map[int]*models.PullSubmission) 357 358 for rows.Next() { 359 var submission models.PullSubmission 360 var submissionCreatedStr string 361 var submissionSourceRev, submissionCombined sql.NullString 362 err := rows.Scan( 363 &submission.ID, 364 &submission.PullAt, 365 &submission.RoundNumber, 366 &submission.Patch, 367 &submissionCombined, 368 &submissionCreatedStr, 369 &submissionSourceRev, 370 ) 371 if err != nil { 372 return nil, err 373 } 374 375 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil { 376 submission.Created = t 377 } 378 379 if submissionSourceRev.Valid { 380 submission.SourceRev = submissionSourceRev.String 381 } 382 383 if submissionCombined.Valid { 384 submission.Combined = submissionCombined.String 385 } 386 387 submissionMap[submission.ID] = &submission 388 } 389 390 if err := rows.Err(); err != nil { 391 return nil, err 392 } 393 394 // Get comments for all submissions using GetComments 395 submissionIds := slices.Collect(maps.Keys(submissionMap)) 396 comments, err := GetComments(e, orm.FilterIn("pull_submission_id", submissionIds)) 397 if err != nil { 398 return nil, fmt.Errorf("failed to get pull comments: %w", err) 399 } 400 for _, comment := range comments { 401 if comment.PullSubmissionId != nil { 402 if submission, ok := submissionMap[*comment.PullSubmissionId]; ok { 403 submission.Comments = append(submission.Comments, comment) 404 } 405 } 406 } 407 408 // group the submissions by pull_at 409 m := make(map[syntax.ATURI][]*models.PullSubmission) 410 for _, s := range submissionMap { 411 m[s.PullAt] = append(m[s.PullAt], s) 412 } 413 414 // sort each one by round number 415 for _, s := range m { 416 slices.SortFunc(s, func(a, b *models.PullSubmission) int { 417 return cmp.Compare(a.RoundNumber, b.RoundNumber) 418 }) 419 } 420 421 return m, nil 422} 423 424// timeframe here is directly passed into the sql query filter, and any 425// timeframe in the past should be negative; e.g.: "-3 months" 426func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) { 427 var pulls []models.Pull 428 429 rows, err := e.Query(` 430 select 431 p.owner_did, 432 p.repo_at, 433 p.pull_id, 434 p.created, 435 p.title, 436 p.state, 437 r.did, 438 r.name, 439 r.knot, 440 r.rkey, 441 r.created 442 from 443 pulls p 444 join 445 repos r on p.repo_at = r.at_uri 446 where 447 p.owner_did = ? and p.created >= date ('now', ?) 448 order by 449 p.created desc`, did, timeframe) 450 if err != nil { 451 return nil, err 452 } 453 defer rows.Close() 454 455 for rows.Next() { 456 var pull models.Pull 457 var repo models.Repo 458 var pullCreatedAt, repoCreatedAt string 459 err := rows.Scan( 460 &pull.OwnerDid, 461 &pull.RepoAt, 462 &pull.PullId, 463 &pullCreatedAt, 464 &pull.Title, 465 &pull.State, 466 &repo.Did, 467 &repo.Name, 468 &repo.Knot, 469 &repo.Rkey, 470 &repoCreatedAt, 471 ) 472 if err != nil { 473 return nil, err 474 } 475 476 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt) 477 if err != nil { 478 return nil, err 479 } 480 pull.Created = pullCreatedTime 481 482 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt) 483 if err != nil { 484 return nil, err 485 } 486 repo.Created = repoCreatedTime 487 488 pull.Repo = &repo 489 490 pulls = append(pulls, pull) 491 } 492 493 if err := rows.Err(); err != nil { 494 return nil, err 495 } 496 497 return pulls, nil 498} 499 500func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error { 501 _, err := e.Exec( 502 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`, 503 pullState, 504 repoAt, 505 pullId, 506 models.PullDeleted, // only update state of non-deleted pulls 507 models.PullMerged, // only update state of non-merged pulls 508 ) 509 return err 510} 511 512func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error { 513 err := SetPullState(e, repoAt, pullId, models.PullClosed) 514 return err 515} 516 517func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error { 518 err := SetPullState(e, repoAt, pullId, models.PullOpen) 519 return err 520} 521 522func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error { 523 err := SetPullState(e, repoAt, pullId, models.PullMerged) 524 return err 525} 526 527func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error { 528 err := SetPullState(e, repoAt, pullId, models.PullDeleted) 529 return err 530} 531 532func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error { 533 _, err := e.Exec(` 534 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev) 535 values (?, ?, ?, ?, ?) 536 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev) 537 538 return err 539} 540 541func SetPullParentChangeId(e Execer, parentChangeId string, filters ...orm.Filter) error { 542 var conditions []string 543 var args []any 544 545 args = append(args, parentChangeId) 546 547 for _, filter := range filters { 548 conditions = append(conditions, filter.Condition()) 549 args = append(args, filter.Arg()...) 550 } 551 552 whereClause := "" 553 if conditions != nil { 554 whereClause = " where " + strings.Join(conditions, " and ") 555 } 556 557 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause) 558 _, err := e.Exec(query, args...) 559 560 return err 561} 562 563// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty). 564// otherwise submissions are immutable 565func UpdatePull(e Execer, newPatch, sourceRev string, filters ...orm.Filter) error { 566 var conditions []string 567 var args []any 568 569 args = append(args, sourceRev) 570 args = append(args, newPatch) 571 572 for _, filter := range filters { 573 conditions = append(conditions, filter.Condition()) 574 args = append(args, filter.Arg()...) 575 } 576 577 whereClause := "" 578 if conditions != nil { 579 whereClause = " where " + strings.Join(conditions, " and ") 580 } 581 582 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause) 583 _, err := e.Exec(query, args...) 584 585 return err 586} 587 588func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) { 589 row := e.QueryRow(` 590 select 591 count(case when state = ? then 1 end) as open_count, 592 count(case when state = ? then 1 end) as merged_count, 593 count(case when state = ? then 1 end) as closed_count, 594 count(case when state = ? then 1 end) as deleted_count 595 from pulls 596 where repo_at = ?`, 597 models.PullOpen, 598 models.PullMerged, 599 models.PullClosed, 600 models.PullDeleted, 601 repoAt, 602 ) 603 604 var count models.PullCount 605 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil { 606 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err 607 } 608 609 return count, nil 610} 611 612// change-id parent-change-id 613// 614// 4 w ,-------- z (TOP) 615// 3 z <----',------- y 616// 2 y <-----',------ x 617// 1 x <------' nil (BOT) 618// 619// `w` is parent of none, so it is the top of the stack 620func GetStack(e Execer, stackId string) (models.Stack, error) { 621 unorderedPulls, err := GetPulls( 622 e, 623 orm.FilterEq("stack_id", stackId), 624 orm.FilterNotEq("state", models.PullDeleted), 625 ) 626 if err != nil { 627 return nil, err 628 } 629 // map of parent-change-id to pull 630 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls)) 631 parentMap := make(map[string]*models.Pull, len(unorderedPulls)) 632 for _, p := range unorderedPulls { 633 changeIdMap[p.ChangeId] = p 634 if p.ParentChangeId != "" { 635 parentMap[p.ParentChangeId] = p 636 } 637 } 638 639 // the top of the stack is the pull that is not a parent of any pull 640 var topPull *models.Pull 641 for _, maybeTop := range unorderedPulls { 642 if _, ok := parentMap[maybeTop.ChangeId]; !ok { 643 topPull = maybeTop 644 break 645 } 646 } 647 648 pulls := []*models.Pull{} 649 for { 650 pulls = append(pulls, topPull) 651 if topPull.ParentChangeId != "" { 652 if next, ok := changeIdMap[topPull.ParentChangeId]; ok { 653 topPull = next 654 } else { 655 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed") 656 } 657 } else { 658 break 659 } 660 } 661 662 return pulls, nil 663} 664 665func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) { 666 pulls, err := GetPulls( 667 e, 668 orm.FilterEq("stack_id", stackId), 669 orm.FilterEq("state", models.PullDeleted), 670 ) 671 if err != nil { 672 return nil, err 673 } 674 675 return pulls, nil 676}