Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

mm: convert mm's rss stats into percpu_counter

Currently mm_struct maintains rss_stats which are updated on page fault
and the unmapping codepaths. For page fault codepath the updates are
cached per thread with the batch of TASK_RSS_EVENTS_THRESH which is 64.
The reason for caching is performance for multithreaded applications;
otherwise the rss_stats updates may become a hotspot for such applications.

However this optimization comes with the cost of error margin in the rss
stats. The rss_stats for applications with a large number of threads can be
very skewed. At worst the error margin is (nr_threads * 64) and we have a
lot of applications with 100s of threads, so the error margin can be very
high. Internally we had to reduce TASK_RSS_EVENTS_THRESH to 32.

Recently we started seeing the unbounded errors for rss_stats for specific
applications that use TCP receive zerocopy (rx0cp). It seems like the vm_insert_pages()
codepath does not sync rss_stats at all.

This patch converts the rss_stats into percpu_counter to convert the error
margin from (nr_threads * 64) to approximately (nr_cpus ^ 2). In addition,
this conversion enables us to get accurate stats for situations where
accuracy is more important than the CPU cost.

This patch does not make such tradeoffs - we can just use
percpu_counter_add_local() for the updates and percpu_counter_sum() (or
percpu_counter_sync() + percpu_counter_read()) for the readers. At the
moment the readers are the procfs interface, the OOM killer, and memory
reclaim, which I think are not performance critical and should be OK with
slow reads. However, I think we can make that change in a separate patch.

Link: https://lkml.kernel.org/r/20221024052841.3291983-1-shakeelb@google.com
Signed-off-by: Shakeel Butt <shakeelb@google.com>
Cc: Marek Szyprowski <m.szyprowski@samsung.com>
Signed-off-by: Andrew Morton <akpm@linux-foundation.org>

authored by

Shakeel Butt and committed by
Andrew Morton
f1a79412 9cd6ffa6

+40 -107
+8 -18
include/linux/mm.h
··· 2052 2052 */ 2053 2053 static inline unsigned long get_mm_counter(struct mm_struct *mm, int member) 2054 2054 { 2055 - long val = atomic_long_read(&mm->rss_stat.count[member]); 2056 - 2057 - #ifdef SPLIT_RSS_COUNTING 2058 - /* 2059 - * counter is updated in asynchronous manner and may go to minus. 2060 - * But it's never be expected number for users. 2061 - */ 2062 - if (val < 0) 2063 - val = 0; 2064 - #endif 2065 - return (unsigned long)val; 2055 + return percpu_counter_read_positive(&mm->rss_stat[member]); 2066 2056 } 2067 2057 2068 - void mm_trace_rss_stat(struct mm_struct *mm, int member, long count); 2058 + void mm_trace_rss_stat(struct mm_struct *mm, int member); 2069 2059 2070 2060 static inline void add_mm_counter(struct mm_struct *mm, int member, long value) 2071 2061 { 2072 - long count = atomic_long_add_return(value, &mm->rss_stat.count[member]); 2062 + percpu_counter_add(&mm->rss_stat[member], value); 2073 2063 2074 - mm_trace_rss_stat(mm, member, count); 2064 + mm_trace_rss_stat(mm, member); 2075 2065 } 2076 2066 2077 2067 static inline void inc_mm_counter(struct mm_struct *mm, int member) 2078 2068 { 2079 - long count = atomic_long_inc_return(&mm->rss_stat.count[member]); 2069 + percpu_counter_inc(&mm->rss_stat[member]); 2080 2070 2081 - mm_trace_rss_stat(mm, member, count); 2071 + mm_trace_rss_stat(mm, member); 2082 2072 } 2083 2073 2084 2074 static inline void dec_mm_counter(struct mm_struct *mm, int member) 2085 2075 { 2086 - long count = atomic_long_dec_return(&mm->rss_stat.count[member]); 2076 + percpu_counter_dec(&mm->rss_stat[member]); 2087 2077 2088 - mm_trace_rss_stat(mm, member, count); 2078 + mm_trace_rss_stat(mm, member); 2089 2079 } 2090 2080 2091 2081 /* Optimized variant when page is already known not to be PageAnon */
+2 -5
include/linux/mm_types.h
··· 18 18 #include <linux/page-flags-layout.h> 19 19 #include <linux/workqueue.h> 20 20 #include <linux/seqlock.h> 21 + #include <linux/percpu_counter.h> 21 22 22 23 #include <asm/mmu.h> 23 24 ··· 627 626 628 627 unsigned long saved_auxv[AT_VECTOR_SIZE]; /* for /proc/PID/auxv */ 629 628 630 - /* 631 - * Special counters, in some configurations protected by the 632 - * page_table_lock, in other configurations by being atomic. 633 - */ 634 - struct mm_rss_stat rss_stat; 629 + struct percpu_counter rss_stat[NR_MM_COUNTERS]; 635 630 636 631 struct linux_binfmt *binfmt; 637 632
-13
include/linux/mm_types_task.h
··· 36 36 NR_MM_COUNTERS 37 37 }; 38 38 39 - #if USE_SPLIT_PTE_PTLOCKS && defined(CONFIG_MMU) 40 - #define SPLIT_RSS_COUNTING 41 - /* per-thread cached information, */ 42 - struct task_rss_stat { 43 - int events; /* for synchronization threshold */ 44 - int count[NR_MM_COUNTERS]; 45 - }; 46 - #endif /* USE_SPLIT_PTE_PTLOCKS */ 47 - 48 - struct mm_rss_stat { 49 - atomic_long_t count[NR_MM_COUNTERS]; 50 - }; 51 - 52 39 struct page_frag { 53 40 struct page *page; 54 41 #if (BITS_PER_LONG > 32) || (PAGE_SIZE >= 65536)
-1
include/linux/percpu_counter.h
··· 13 13 #include <linux/threads.h> 14 14 #include <linux/percpu.h> 15 15 #include <linux/types.h> 16 - #include <linux/gfp.h> 17 16 18 17 /* percpu_counter batch for local add or sub */ 19 18 #define PERCPU_COUNTER_LOCAL_BATCH INT_MAX
-3
include/linux/sched.h
··· 870 870 struct mm_struct *mm; 871 871 struct mm_struct *active_mm; 872 872 873 - #ifdef SPLIT_RSS_COUNTING 874 - struct task_rss_stat rss_stat; 875 - #endif 876 873 int exit_state; 877 874 int exit_code; 878 875 int exit_signal;
+4 -4
include/trace/events/kmem.h
··· 346 346 TRACE_EVENT(rss_stat, 347 347 348 348 TP_PROTO(struct mm_struct *mm, 349 - int member, 350 - long count), 349 + int member), 351 350 352 - TP_ARGS(mm, member, count), 351 + TP_ARGS(mm, member), 353 352 354 353 TP_STRUCT__entry( 355 354 __field(unsigned int, mm_id) ··· 361 362 __entry->mm_id = mm_ptr_to_hash(mm); 362 363 __entry->curr = !!(current->mm == mm); 363 364 __entry->member = member; 364 - __entry->size = (count << PAGE_SHIFT); 365 + __entry->size = (percpu_counter_sum_positive(&mm->rss_stat[member]) 366 + << PAGE_SHIFT); 365 367 ), 366 368 367 369 TP_printk("mm_id=%u curr=%d type=%s size=%ldB",
+15 -1
kernel/fork.c
··· 753 753 "Please make sure 'struct resident_page_types[]' is updated as well"); 754 754 755 755 for (i = 0; i < NR_MM_COUNTERS; i++) { 756 - long x = atomic_long_read(&mm->rss_stat.count[i]); 756 + long x = percpu_counter_sum(&mm->rss_stat[i]); 757 757 758 758 if (unlikely(x)) 759 759 pr_alert("BUG: Bad rss-counter state mm:%p type:%s val:%ld\n", ··· 779 779 */ 780 780 void __mmdrop(struct mm_struct *mm) 781 781 { 782 + int i; 783 + 782 784 BUG_ON(mm == &init_mm); 783 785 WARN_ON_ONCE(mm == current->mm); 784 786 WARN_ON_ONCE(mm == current->active_mm); ··· 790 788 check_mm(mm); 791 789 put_user_ns(mm->user_ns); 792 790 mm_pasid_drop(mm); 791 + 792 + for (i = 0; i < NR_MM_COUNTERS; i++) 793 + percpu_counter_destroy(&mm->rss_stat[i]); 793 794 free_mm(mm); 794 795 } 795 796 EXPORT_SYMBOL_GPL(__mmdrop); ··· 1112 1107 static struct mm_struct *mm_init(struct mm_struct *mm, struct task_struct *p, 1113 1108 struct user_namespace *user_ns) 1114 1109 { 1110 + int i; 1111 + 1115 1112 mt_init_flags(&mm->mm_mt, MM_MT_FLAGS); 1116 1113 mt_set_external_lock(&mm->mm_mt, &mm->mmap_lock); 1117 1114 atomic_set(&mm->mm_users, 1); ··· 1155 1148 if (init_new_context(p, mm)) 1156 1149 goto fail_nocontext; 1157 1150 1151 + for (i = 0; i < NR_MM_COUNTERS; i++) 1152 + if (percpu_counter_init(&mm->rss_stat[i], 0, GFP_KERNEL_ACCOUNT)) 1153 + goto fail_pcpu; 1154 + 1158 1155 mm->user_ns = get_user_ns(user_ns); 1159 1156 lru_gen_init_mm(mm); 1160 1157 return mm; 1161 1158 1159 + fail_pcpu: 1160 + while (i > 0) 1161 + percpu_counter_destroy(&mm->rss_stat[--i]); 1162 1162 fail_nocontext: 1163 1163 mm_free_pgd(mm); 1164 1164 fail_nopgd:
+11 -62
mm/memory.c
··· 162 162 } 163 163 early_initcall(init_zero_pfn); 164 164 165 - void mm_trace_rss_stat(struct mm_struct *mm, int member, long count) 165 + void mm_trace_rss_stat(struct mm_struct *mm, int member) 166 166 { 167 - trace_rss_stat(mm, member, count); 167 + trace_rss_stat(mm, member); 168 168 } 169 - 170 - #if defined(SPLIT_RSS_COUNTING) 171 - 172 - void sync_mm_rss(struct mm_struct *mm) 173 - { 174 - int i; 175 - 176 - for (i = 0; i < NR_MM_COUNTERS; i++) { 177 - if (current->rss_stat.count[i]) { 178 - add_mm_counter(mm, i, current->rss_stat.count[i]); 179 - current->rss_stat.count[i] = 0; 180 - } 181 - } 182 - current->rss_stat.events = 0; 183 - } 184 - 185 - static void add_mm_counter_fast(struct mm_struct *mm, int member, int val) 186 - { 187 - struct task_struct *task = current; 188 - 189 - if (likely(task->mm == mm)) 190 - task->rss_stat.count[member] += val; 191 - else 192 - add_mm_counter(mm, member, val); 193 - } 194 - #define inc_mm_counter_fast(mm, member) add_mm_counter_fast(mm, member, 1) 195 - #define dec_mm_counter_fast(mm, member) add_mm_counter_fast(mm, member, -1) 196 - 197 - /* sync counter once per 64 page faults */ 198 - #define TASK_RSS_EVENTS_THRESH (64) 199 - static void check_sync_rss_stat(struct task_struct *task) 200 - { 201 - if (unlikely(task != current)) 202 - return; 203 - if (unlikely(task->rss_stat.events++ > TASK_RSS_EVENTS_THRESH)) 204 - sync_mm_rss(task->mm); 205 - } 206 - #else /* SPLIT_RSS_COUNTING */ 207 - 208 - #define inc_mm_counter_fast(mm, member) inc_mm_counter(mm, member) 209 - #define dec_mm_counter_fast(mm, member) dec_mm_counter(mm, member) 210 - 211 - static void check_sync_rss_stat(struct task_struct *task) 212 - { 213 - } 214 - 215 - #endif /* SPLIT_RSS_COUNTING */ 216 169 217 170 /* 218 171 * Note: this doesn't free the actual pages themselves. That ··· 1810 1857 return -EBUSY; 1811 1858 /* Ok, finally just insert the thing.. 
*/ 1812 1859 get_page(page); 1813 - inc_mm_counter_fast(vma->vm_mm, mm_counter_file(page)); 1860 + inc_mm_counter(vma->vm_mm, mm_counter_file(page)); 1814 1861 page_add_file_rmap(page, vma, false); 1815 1862 set_pte_at(vma->vm_mm, addr, pte, mk_pte(page, prot)); 1816 1863 return 0; ··· 3106 3153 if (likely(pte_same(*vmf->pte, vmf->orig_pte))) { 3107 3154 if (old_page) { 3108 3155 if (!PageAnon(old_page)) { 3109 - dec_mm_counter_fast(mm, 3110 - mm_counter_file(old_page)); 3111 - inc_mm_counter_fast(mm, MM_ANONPAGES); 3156 + dec_mm_counter(mm, mm_counter_file(old_page)); 3157 + inc_mm_counter(mm, MM_ANONPAGES); 3112 3158 } 3113 3159 } else { 3114 - inc_mm_counter_fast(mm, MM_ANONPAGES); 3160 + inc_mm_counter(mm, MM_ANONPAGES); 3115 3161 } 3116 3162 flush_cache_page(vma, vmf->address, pte_pfn(vmf->orig_pte)); 3117 3163 entry = mk_pte(new_page, vma->vm_page_prot); ··· 3917 3965 if (should_try_to_free_swap(folio, vma, vmf->flags)) 3918 3966 folio_free_swap(folio); 3919 3967 3920 - inc_mm_counter_fast(vma->vm_mm, MM_ANONPAGES); 3921 - dec_mm_counter_fast(vma->vm_mm, MM_SWAPENTS); 3968 + inc_mm_counter(vma->vm_mm, MM_ANONPAGES); 3969 + dec_mm_counter(vma->vm_mm, MM_SWAPENTS); 3922 3970 pte = mk_pte(page, vma->vm_page_prot); 3923 3971 3924 3972 /* ··· 4098 4146 return handle_userfault(vmf, VM_UFFD_MISSING); 4099 4147 } 4100 4148 4101 - inc_mm_counter_fast(vma->vm_mm, MM_ANONPAGES); 4149 + inc_mm_counter(vma->vm_mm, MM_ANONPAGES); 4102 4150 page_add_new_anon_rmap(page, vma, vmf->address); 4103 4151 lru_cache_add_inactive_or_unevictable(page, vma); 4104 4152 setpte: ··· 4288 4336 entry = pte_mkuffd_wp(pte_wrprotect(entry)); 4289 4337 /* copy-on-write page */ 4290 4338 if (write && !(vma->vm_flags & VM_SHARED)) { 4291 - inc_mm_counter_fast(vma->vm_mm, MM_ANONPAGES); 4339 + inc_mm_counter(vma->vm_mm, MM_ANONPAGES); 4292 4340 page_add_new_anon_rmap(page, vma, addr); 4293 4341 lru_cache_add_inactive_or_unevictable(page, vma); 4294 4342 } else { 4295 - 
inc_mm_counter_fast(vma->vm_mm, mm_counter_file(page)); 4343 + inc_mm_counter(vma->vm_mm, mm_counter_file(page)); 4296 4344 page_add_file_rmap(page, vma, false); 4297 4345 } 4298 4346 set_pte_at(vma->vm_mm, addr, vmf->pte, entry); ··· 5143 5191 5144 5192 count_vm_event(PGFAULT); 5145 5193 count_memcg_event_mm(vma->vm_mm, PGFAULT); 5146 - 5147 - /* do counter updates before entering really critical section. */ 5148 - check_sync_rss_stat(current); 5149 5194 5150 5195 if (!arch_vma_access_permitted(vma, flags & FAULT_FLAG_WRITE, 5151 5196 flags & FAULT_FLAG_INSTRUCTION,