Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

bpf: keep a reference to the mm, in case the task is dead.

Fix a system crash that happens when a task iterator travels through
the VMAs of tasks.

In task iterators, we used to access mm by following the pointer on
the task_struct; however, the death of a task will clear the pointer,
even though we still hold the task_struct. That can cause an
unexpected null-pointer crash when an iterator is visiting a task
that dies during the visit. Keeping a reference to the mm on the
iterator ensures we always have a valid pointer to mm.

Co-developed-by: Song Liu <song@kernel.org>
Signed-off-by: Song Liu <song@kernel.org>
Signed-off-by: Kui-Feng Lee <kuifeng@meta.com>
Reported-by: Nathan Slingerland <slinger@meta.com>
Acked-by: Yonghong Song <yhs@fb.com>
Link: https://lore.kernel.org/r/20221216221855.4122288-2-kuifeng@meta.com
Signed-off-by: Alexei Starovoitov <ast@kernel.org>

authored by

Kui-Feng Lee and committed by
Alexei Starovoitov
7ff94f27 8f161ca1

+27 -12
+27 -12
kernel/bpf/task_iter.c
··· 438 438 */ 439 439 struct bpf_iter_seq_task_common common; 440 440 struct task_struct *task; 441 + struct mm_struct *mm; 441 442 struct vm_area_struct *vma; 442 443 u32 tid; 443 444 unsigned long prev_vm_start; ··· 457 456 enum bpf_task_vma_iter_find_op op; 458 457 struct vm_area_struct *curr_vma; 459 458 struct task_struct *curr_task; 459 + struct mm_struct *curr_mm; 460 460 u32 saved_tid = info->tid; 461 461 462 462 /* If this function returns a non-NULL vma, it holds a reference to 463 - * the task_struct, and holds read lock on vma->mm->mmap_lock. 463 + * the task_struct, holds a refcount on mm->mm_users, and holds 464 + * read lock on vma->mm->mmap_lock. 464 465 * If this function returns NULL, it does not hold any reference or 465 466 * lock. 466 467 */ 467 468 if (info->task) { 468 469 curr_task = info->task; 469 470 curr_vma = info->vma; 471 + curr_mm = info->mm; 470 472 /* In case of lock contention, drop mmap_lock to unblock 471 473 * the writer. 472 474 * ··· 508 504 * 4.2) VMA2 and VMA2' covers different ranges, process 509 505 * VMA2'. 
510 506 */ 511 - if (mmap_lock_is_contended(curr_task->mm)) { 507 + if (mmap_lock_is_contended(curr_mm)) { 512 508 info->prev_vm_start = curr_vma->vm_start; 513 509 info->prev_vm_end = curr_vma->vm_end; 514 510 op = task_vma_iter_find_vma; 515 - mmap_read_unlock(curr_task->mm); 516 - if (mmap_read_lock_killable(curr_task->mm)) 511 + mmap_read_unlock(curr_mm); 512 + if (mmap_read_lock_killable(curr_mm)) { 513 + mmput(curr_mm); 517 514 goto finish; 515 + } 518 516 } else { 519 517 op = task_vma_iter_next_vma; 520 518 } ··· 541 535 op = task_vma_iter_find_vma; 542 536 } 543 537 544 - if (!curr_task->mm) 538 + curr_mm = get_task_mm(curr_task); 539 + if (!curr_mm) 545 540 goto next_task; 546 541 547 - if (mmap_read_lock_killable(curr_task->mm)) 542 + if (mmap_read_lock_killable(curr_mm)) { 543 + mmput(curr_mm); 548 544 goto finish; 545 + } 549 546 } 550 547 551 548 switch (op) { 552 549 case task_vma_iter_first_vma: 553 - curr_vma = find_vma(curr_task->mm, 0); 550 + curr_vma = find_vma(curr_mm, 0); 554 551 break; 555 552 case task_vma_iter_next_vma: 556 - curr_vma = find_vma(curr_task->mm, curr_vma->vm_end); 553 + curr_vma = find_vma(curr_mm, curr_vma->vm_end); 557 554 break; 558 555 case task_vma_iter_find_vma: 559 556 /* We dropped mmap_lock so it is necessary to use find_vma 560 557 * to find the next vma. This is similar to the mechanism 561 558 * in show_smaps_rollup(). 
562 559 */ 563 - curr_vma = find_vma(curr_task->mm, info->prev_vm_end - 1); 560 + curr_vma = find_vma(curr_mm, info->prev_vm_end - 1); 564 561 /* case 1) and 4.2) above just use curr_vma */ 565 562 566 563 /* check for case 2) or case 4.1) above */ 567 564 if (curr_vma && 568 565 curr_vma->vm_start == info->prev_vm_start && 569 566 curr_vma->vm_end == info->prev_vm_end) 570 - curr_vma = find_vma(curr_task->mm, curr_vma->vm_end); 567 + curr_vma = find_vma(curr_mm, curr_vma->vm_end); 571 568 break; 572 569 } 573 570 if (!curr_vma) { 574 571 /* case 3) above, or case 2) 4.1) with vma->next == NULL */ 575 - mmap_read_unlock(curr_task->mm); 572 + mmap_read_unlock(curr_mm); 573 + mmput(curr_mm); 576 574 goto next_task; 577 575 } 578 576 info->task = curr_task; 579 577 info->vma = curr_vma; 578 + info->mm = curr_mm; 580 579 return curr_vma; 581 580 582 581 next_task: ··· 590 579 591 580 put_task_struct(curr_task); 592 581 info->task = NULL; 582 + info->mm = NULL; 593 583 info->tid++; 594 584 goto again; 595 585 ··· 599 587 put_task_struct(curr_task); 600 588 info->task = NULL; 601 589 info->vma = NULL; 590 + info->mm = NULL; 602 591 return NULL; 603 592 } 604 593 ··· 671 658 */ 672 659 info->prev_vm_start = ~0UL; 673 660 info->prev_vm_end = info->vma->vm_end; 674 - mmap_read_unlock(info->task->mm); 661 + mmap_read_unlock(info->mm); 662 + mmput(info->mm); 663 + info->mm = NULL; 675 664 put_task_struct(info->task); 676 665 info->task = NULL; 677 666 }