Monorepo for Tangled
tangled.org
1package db
2
3import (
4 "cmp"
5 "database/sql"
6 "errors"
7 "fmt"
8 "maps"
9 "slices"
10 "sort"
11 "strings"
12 "time"
13
14 "github.com/bluesky-social/indigo/atproto/syntax"
15 "tangled.org/core/appview/models"
16 "tangled.org/core/appview/pagination"
17 "tangled.org/core/orm"
18)
19
20func NewPull(tx *sql.Tx, pull *models.Pull) error {
21 _, err := tx.Exec(`
22 insert or ignore into repo_pull_seqs (repo_at, next_pull_id)
23 values (?, 1)
24 `, pull.RepoAt)
25 if err != nil {
26 return err
27 }
28
29 var nextId int
30 err = tx.QueryRow(`
31 update repo_pull_seqs
32 set next_pull_id = next_pull_id + 1
33 where repo_at = ?
34 returning next_pull_id - 1
35 `, pull.RepoAt).Scan(&nextId)
36 if err != nil {
37 return err
38 }
39
40 pull.PullId = nextId
41 pull.State = models.PullOpen
42
43 var sourceBranch, sourceRepoAt *string
44 if pull.PullSource != nil {
45 sourceBranch = &pull.PullSource.Branch
46 if pull.PullSource.RepoAt != nil {
47 x := pull.PullSource.RepoAt.String()
48 sourceRepoAt = &x
49 }
50 }
51
52 var stackId, changeId, parentChangeId *string
53 if pull.StackId != "" {
54 stackId = &pull.StackId
55 }
56 if pull.ChangeId != "" {
57 changeId = &pull.ChangeId
58 }
59 if pull.ParentChangeId != "" {
60 parentChangeId = &pull.ParentChangeId
61 }
62
63 result, err := tx.Exec(
64 `
65 insert into pulls (
66 repo_at, owner_did, pull_id, title, target_branch, body, rkey, state, source_branch, source_repo_at, stack_id, change_id, parent_change_id
67 )
68 values (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)`,
69 pull.RepoAt,
70 pull.OwnerDid,
71 pull.PullId,
72 pull.Title,
73 pull.TargetBranch,
74 pull.Body,
75 pull.Rkey,
76 pull.State,
77 sourceBranch,
78 sourceRepoAt,
79 stackId,
80 changeId,
81 parentChangeId,
82 )
83 if err != nil {
84 return err
85 }
86
87 // Set the database primary key ID
88 id, err := result.LastInsertId()
89 if err != nil {
90 return err
91 }
92 pull.ID = int(id)
93
94 _, err = tx.Exec(`
95 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
96 values (?, ?, ?, ?, ?)
97 `, pull.AtUri(), 0, pull.Submissions[0].Patch, pull.Submissions[0].Combined, pull.Submissions[0].SourceRev)
98 if err != nil {
99 return err
100 }
101
102 if err := putReferences(tx, pull.AtUri(), nil, pull.References); err != nil {
103 return fmt.Errorf("put reference_links: %w", err)
104 }
105
106 return nil
107}
108
109func GetPullAt(e Execer, repoAt syntax.ATURI, pullId int) (syntax.ATURI, error) {
110 pull, err := GetPull(e, repoAt, pullId)
111 if err != nil {
112 return "", err
113 }
114 return pull.AtUri(), err
115}
116
117func NextPullId(e Execer, repoAt syntax.ATURI) (int, error) {
118 var pullId int
119 err := e.QueryRow(`select next_pull_id from repo_pull_seqs where repo_at = ?`, repoAt).Scan(&pullId)
120 return pullId - 1, err
121}
122
123func GetPullsPaginated(e Execer, page pagination.Page, filters ...orm.Filter) ([]*models.Pull, error) {
124 pulls := make(map[syntax.ATURI]*models.Pull)
125
126 var conditions []string
127 var args []any
128 for _, filter := range filters {
129 conditions = append(conditions, filter.Condition())
130 args = append(args, filter.Arg()...)
131 }
132
133 whereClause := ""
134 if conditions != nil {
135 whereClause = " where " + strings.Join(conditions, " and ")
136 }
137 pageClause := ""
138 if page.Limit != 0 {
139 pageClause = fmt.Sprintf(
140 " limit %d offset %d ",
141 page.Limit,
142 page.Offset,
143 )
144 }
145
146 query := fmt.Sprintf(`
147 select
148 id,
149 owner_did,
150 repo_at,
151 pull_id,
152 created,
153 title,
154 state,
155 target_branch,
156 body,
157 rkey,
158 source_branch,
159 source_repo_at,
160 stack_id,
161 change_id,
162 parent_change_id
163 from
164 pulls
165 %s
166 order by
167 created desc
168 %s
169 `, whereClause, pageClause)
170
171 rows, err := e.Query(query, args...)
172 if err != nil {
173 return nil, err
174 }
175 defer rows.Close()
176
177 for rows.Next() {
178 var pull models.Pull
179 var createdAt string
180 var sourceBranch, sourceRepoAt, stackId, changeId, parentChangeId sql.NullString
181 err := rows.Scan(
182 &pull.ID,
183 &pull.OwnerDid,
184 &pull.RepoAt,
185 &pull.PullId,
186 &createdAt,
187 &pull.Title,
188 &pull.State,
189 &pull.TargetBranch,
190 &pull.Body,
191 &pull.Rkey,
192 &sourceBranch,
193 &sourceRepoAt,
194 &stackId,
195 &changeId,
196 &parentChangeId,
197 )
198 if err != nil {
199 return nil, err
200 }
201
202 createdTime, err := time.Parse(time.RFC3339, createdAt)
203 if err != nil {
204 return nil, err
205 }
206 pull.Created = createdTime
207
208 if sourceBranch.Valid {
209 pull.PullSource = &models.PullSource{
210 Branch: sourceBranch.String,
211 }
212 if sourceRepoAt.Valid {
213 sourceRepoAtParsed, err := syntax.ParseATURI(sourceRepoAt.String)
214 if err != nil {
215 return nil, err
216 }
217 pull.PullSource.RepoAt = &sourceRepoAtParsed
218 }
219 }
220
221 if stackId.Valid {
222 pull.StackId = stackId.String
223 }
224 if changeId.Valid {
225 pull.ChangeId = changeId.String
226 }
227 if parentChangeId.Valid {
228 pull.ParentChangeId = parentChangeId.String
229 }
230
231 pulls[pull.AtUri()] = &pull
232 }
233
234 var pullAts []syntax.ATURI
235 for _, p := range pulls {
236 pullAts = append(pullAts, p.AtUri())
237 }
238 submissionsMap, err := GetPullSubmissions(e, orm.FilterIn("pull_at", pullAts))
239 if err != nil {
240 return nil, fmt.Errorf("failed to get submissions: %w", err)
241 }
242
243 for pullAt, submissions := range submissionsMap {
244 if p, ok := pulls[pullAt]; ok {
245 p.Submissions = submissions
246 }
247 }
248
249 // collect allLabels for each issue
250 allLabels, err := GetLabels(e, orm.FilterIn("subject", pullAts))
251 if err != nil {
252 return nil, fmt.Errorf("failed to query labels: %w", err)
253 }
254 for pullAt, labels := range allLabels {
255 if p, ok := pulls[pullAt]; ok {
256 p.Labels = labels
257 }
258 }
259
260 // collect pull source for all pulls that need it
261 var sourceAts []syntax.ATURI
262 for _, p := range pulls {
263 if p.PullSource != nil && p.PullSource.RepoAt != nil {
264 sourceAts = append(sourceAts, *p.PullSource.RepoAt)
265 }
266 }
267 sourceRepos, err := GetRepos(e, 0, orm.FilterIn("at_uri", sourceAts))
268 if err != nil && !errors.Is(err, sql.ErrNoRows) {
269 return nil, fmt.Errorf("failed to get source repos: %w", err)
270 }
271 sourceRepoMap := make(map[syntax.ATURI]*models.Repo)
272 for _, r := range sourceRepos {
273 sourceRepoMap[r.RepoAt()] = &r
274 }
275 for _, p := range pulls {
276 if p.PullSource != nil && p.PullSource.RepoAt != nil {
277 if sourceRepo, ok := sourceRepoMap[*p.PullSource.RepoAt]; ok {
278 p.PullSource.Repo = sourceRepo
279 }
280 }
281 }
282
283 allReferences, err := GetReferencesAll(e, orm.FilterIn("from_at", pullAts))
284 if err != nil {
285 return nil, fmt.Errorf("failed to query reference_links: %w", err)
286 }
287 for pullAt, references := range allReferences {
288 if pull, ok := pulls[pullAt]; ok {
289 pull.References = references
290 }
291 }
292
293 orderedByPullId := []*models.Pull{}
294 for _, p := range pulls {
295 orderedByPullId = append(orderedByPullId, p)
296 }
297 sort.Slice(orderedByPullId, func(i, j int) bool {
298 return orderedByPullId[i].PullId > orderedByPullId[j].PullId
299 })
300
301 return orderedByPullId, nil
302}
303
304func GetPulls(e Execer, filters ...orm.Filter) ([]*models.Pull, error) {
305 return GetPullsPaginated(e, pagination.Page{}, filters...)
306}
307
308func GetPull(e Execer, repoAt syntax.ATURI, pullId int) (*models.Pull, error) {
309 pulls, err := GetPullsPaginated(e, pagination.Page{Limit: 1}, orm.FilterEq("repo_at", repoAt), orm.FilterEq("pull_id", pullId))
310 if err != nil {
311 return nil, err
312 }
313 if len(pulls) == 0 {
314 return nil, sql.ErrNoRows
315 }
316
317 return pulls[0], nil
318}
319
320// mapping from pull -> pull submissions
321func GetPullSubmissions(e Execer, filters ...orm.Filter) (map[syntax.ATURI][]*models.PullSubmission, error) {
322 var conditions []string
323 var args []any
324 for _, filter := range filters {
325 conditions = append(conditions, filter.Condition())
326 args = append(args, filter.Arg()...)
327 }
328
329 whereClause := ""
330 if conditions != nil {
331 whereClause = " where " + strings.Join(conditions, " and ")
332 }
333
334 query := fmt.Sprintf(`
335 select
336 id,
337 pull_at,
338 round_number,
339 patch,
340 combined,
341 created,
342 source_rev
343 from
344 pull_submissions
345 %s
346 order by
347 round_number asc
348 `, whereClause)
349
350 rows, err := e.Query(query, args...)
351 if err != nil {
352 return nil, err
353 }
354 defer rows.Close()
355
356 submissionMap := make(map[int]*models.PullSubmission)
357
358 for rows.Next() {
359 var submission models.PullSubmission
360 var submissionCreatedStr string
361 var submissionSourceRev, submissionCombined sql.NullString
362 err := rows.Scan(
363 &submission.ID,
364 &submission.PullAt,
365 &submission.RoundNumber,
366 &submission.Patch,
367 &submissionCombined,
368 &submissionCreatedStr,
369 &submissionSourceRev,
370 )
371 if err != nil {
372 return nil, err
373 }
374
375 if t, err := time.Parse(time.RFC3339, submissionCreatedStr); err == nil {
376 submission.Created = t
377 }
378
379 if submissionSourceRev.Valid {
380 submission.SourceRev = submissionSourceRev.String
381 }
382
383 if submissionCombined.Valid {
384 submission.Combined = submissionCombined.String
385 }
386
387 submissionMap[submission.ID] = &submission
388 }
389
390 if err := rows.Err(); err != nil {
391 return nil, err
392 }
393
394 // Get comments for all submissions using GetComments
395 submissionIds := slices.Collect(maps.Keys(submissionMap))
396 comments, err := GetComments(e, orm.FilterIn("pull_submission_id", submissionIds))
397 if err != nil {
398 return nil, fmt.Errorf("failed to get pull comments: %w", err)
399 }
400 for _, comment := range comments {
401 if comment.PullSubmissionId != nil {
402 if submission, ok := submissionMap[*comment.PullSubmissionId]; ok {
403 submission.Comments = append(submission.Comments, comment)
404 }
405 }
406 }
407
408 // group the submissions by pull_at
409 m := make(map[syntax.ATURI][]*models.PullSubmission)
410 for _, s := range submissionMap {
411 m[s.PullAt] = append(m[s.PullAt], s)
412 }
413
414 // sort each one by round number
415 for _, s := range m {
416 slices.SortFunc(s, func(a, b *models.PullSubmission) int {
417 return cmp.Compare(a.RoundNumber, b.RoundNumber)
418 })
419 }
420
421 return m, nil
422}
423
424// timeframe here is directly passed into the sql query filter, and any
425// timeframe in the past should be negative; e.g.: "-3 months"
426func GetPullsByOwnerDid(e Execer, did, timeframe string) ([]models.Pull, error) {
427 var pulls []models.Pull
428
429 rows, err := e.Query(`
430 select
431 p.owner_did,
432 p.repo_at,
433 p.pull_id,
434 p.created,
435 p.title,
436 p.state,
437 r.did,
438 r.name,
439 r.knot,
440 r.rkey,
441 r.created
442 from
443 pulls p
444 join
445 repos r on p.repo_at = r.at_uri
446 where
447 p.owner_did = ? and p.created >= date ('now', ?)
448 order by
449 p.created desc`, did, timeframe)
450 if err != nil {
451 return nil, err
452 }
453 defer rows.Close()
454
455 for rows.Next() {
456 var pull models.Pull
457 var repo models.Repo
458 var pullCreatedAt, repoCreatedAt string
459 err := rows.Scan(
460 &pull.OwnerDid,
461 &pull.RepoAt,
462 &pull.PullId,
463 &pullCreatedAt,
464 &pull.Title,
465 &pull.State,
466 &repo.Did,
467 &repo.Name,
468 &repo.Knot,
469 &repo.Rkey,
470 &repoCreatedAt,
471 )
472 if err != nil {
473 return nil, err
474 }
475
476 pullCreatedTime, err := time.Parse(time.RFC3339, pullCreatedAt)
477 if err != nil {
478 return nil, err
479 }
480 pull.Created = pullCreatedTime
481
482 repoCreatedTime, err := time.Parse(time.RFC3339, repoCreatedAt)
483 if err != nil {
484 return nil, err
485 }
486 repo.Created = repoCreatedTime
487
488 pull.Repo = &repo
489
490 pulls = append(pulls, pull)
491 }
492
493 if err := rows.Err(); err != nil {
494 return nil, err
495 }
496
497 return pulls, nil
498}
499
500func SetPullState(e Execer, repoAt syntax.ATURI, pullId int, pullState models.PullState) error {
501 _, err := e.Exec(
502 `update pulls set state = ? where repo_at = ? and pull_id = ? and (state <> ? or state <> ?)`,
503 pullState,
504 repoAt,
505 pullId,
506 models.PullDeleted, // only update state of non-deleted pulls
507 models.PullMerged, // only update state of non-merged pulls
508 )
509 return err
510}
511
512func ClosePull(e Execer, repoAt syntax.ATURI, pullId int) error {
513 err := SetPullState(e, repoAt, pullId, models.PullClosed)
514 return err
515}
516
517func ReopenPull(e Execer, repoAt syntax.ATURI, pullId int) error {
518 err := SetPullState(e, repoAt, pullId, models.PullOpen)
519 return err
520}
521
522func MergePull(e Execer, repoAt syntax.ATURI, pullId int) error {
523 err := SetPullState(e, repoAt, pullId, models.PullMerged)
524 return err
525}
526
527func DeletePull(e Execer, repoAt syntax.ATURI, pullId int) error {
528 err := SetPullState(e, repoAt, pullId, models.PullDeleted)
529 return err
530}
531
532func ResubmitPull(e Execer, pullAt syntax.ATURI, newRoundNumber int, newPatch string, combinedPatch string, newSourceRev string) error {
533 _, err := e.Exec(`
534 insert into pull_submissions (pull_at, round_number, patch, combined, source_rev)
535 values (?, ?, ?, ?, ?)
536 `, pullAt, newRoundNumber, newPatch, combinedPatch, newSourceRev)
537
538 return err
539}
540
541func SetPullParentChangeId(e Execer, parentChangeId string, filters ...orm.Filter) error {
542 var conditions []string
543 var args []any
544
545 args = append(args, parentChangeId)
546
547 for _, filter := range filters {
548 conditions = append(conditions, filter.Condition())
549 args = append(args, filter.Arg()...)
550 }
551
552 whereClause := ""
553 if conditions != nil {
554 whereClause = " where " + strings.Join(conditions, " and ")
555 }
556
557 query := fmt.Sprintf("update pulls set parent_change_id = ? %s", whereClause)
558 _, err := e.Exec(query, args...)
559
560 return err
561}
562
563// Only used when stacking to update contents in the event of a rebase (the interdiff should be empty).
564// otherwise submissions are immutable
565func UpdatePull(e Execer, newPatch, sourceRev string, filters ...orm.Filter) error {
566 var conditions []string
567 var args []any
568
569 args = append(args, sourceRev)
570 args = append(args, newPatch)
571
572 for _, filter := range filters {
573 conditions = append(conditions, filter.Condition())
574 args = append(args, filter.Arg()...)
575 }
576
577 whereClause := ""
578 if conditions != nil {
579 whereClause = " where " + strings.Join(conditions, " and ")
580 }
581
582 query := fmt.Sprintf("update pull_submissions set source_rev = ?, patch = ? %s", whereClause)
583 _, err := e.Exec(query, args...)
584
585 return err
586}
587
588func GetPullCount(e Execer, repoAt syntax.ATURI) (models.PullCount, error) {
589 row := e.QueryRow(`
590 select
591 count(case when state = ? then 1 end) as open_count,
592 count(case when state = ? then 1 end) as merged_count,
593 count(case when state = ? then 1 end) as closed_count,
594 count(case when state = ? then 1 end) as deleted_count
595 from pulls
596 where repo_at = ?`,
597 models.PullOpen,
598 models.PullMerged,
599 models.PullClosed,
600 models.PullDeleted,
601 repoAt,
602 )
603
604 var count models.PullCount
605 if err := row.Scan(&count.Open, &count.Merged, &count.Closed, &count.Deleted); err != nil {
606 return models.PullCount{Open: 0, Merged: 0, Closed: 0, Deleted: 0}, err
607 }
608
609 return count, nil
610}
611
612// change-id parent-change-id
613//
614// 4 w ,-------- z (TOP)
615// 3 z <----',------- y
616// 2 y <-----',------ x
617// 1 x <------' nil (BOT)
618//
619// `w` is parent of none, so it is the top of the stack
620func GetStack(e Execer, stackId string) (models.Stack, error) {
621 unorderedPulls, err := GetPulls(
622 e,
623 orm.FilterEq("stack_id", stackId),
624 orm.FilterNotEq("state", models.PullDeleted),
625 )
626 if err != nil {
627 return nil, err
628 }
629 // map of parent-change-id to pull
630 changeIdMap := make(map[string]*models.Pull, len(unorderedPulls))
631 parentMap := make(map[string]*models.Pull, len(unorderedPulls))
632 for _, p := range unorderedPulls {
633 changeIdMap[p.ChangeId] = p
634 if p.ParentChangeId != "" {
635 parentMap[p.ParentChangeId] = p
636 }
637 }
638
639 // the top of the stack is the pull that is not a parent of any pull
640 var topPull *models.Pull
641 for _, maybeTop := range unorderedPulls {
642 if _, ok := parentMap[maybeTop.ChangeId]; !ok {
643 topPull = maybeTop
644 break
645 }
646 }
647
648 pulls := []*models.Pull{}
649 for {
650 pulls = append(pulls, topPull)
651 if topPull.ParentChangeId != "" {
652 if next, ok := changeIdMap[topPull.ParentChangeId]; ok {
653 topPull = next
654 } else {
655 return nil, fmt.Errorf("failed to find parent pull request, stack is malformed")
656 }
657 } else {
658 break
659 }
660 }
661
662 return pulls, nil
663}
664
665func GetAbandonedPulls(e Execer, stackId string) ([]*models.Pull, error) {
666 pulls, err := GetPulls(
667 e,
668 orm.FilterEq("stack_id", stackId),
669 orm.FilterEq("state", models.PullDeleted),
670 )
671 if err != nil {
672 return nil, err
673 }
674
675 return pulls, nil
676}