
Merge branch 'bpf-fix-backward-progress-bug-in-bpf_iter_udp'

Martin KaFai Lau says:

====================
bpf: Fix backward progress bug in bpf_iter_udp

From: Martin KaFai Lau <martin.lau@kernel.org>

This patch set fixes an issue in bpf_iter_udp that causes the iterator
to make backward progress and prevents the user space process from
finishing. There is a test at the end to reproduce the bug.

Please see individual patches for details.

v3:
- Fixed the iter_fd check and local_port check in the
  patch 3 selftest. (Yonghong)
- Moved jhash2 to test_jhash.h in patch 3. (Yonghong)
- Added an explanation of the bucket selection in patch 3. (Yonghong)

v2:
- Added patch 1 to fix another bug that made the iterator go back to
  the previous bucket
- Simplified the fix in patch 2 to always reset iter->offset to 0
- Added a test case that closes all udp_sk in a bucket while
  in the middle of the iteration.
====================

Link: https://lore.kernel.org/r/20240112190530.3751661-1-martin.lau@linux.dev
Signed-off-by: Alexei Starovoitov <ast@kernel.org>

+270 -12
+10 -12
net/ipv4/udp.c
···
 	struct bpf_udp_iter_state *iter = seq->private;
 	struct udp_iter_state *state = &iter->state;
 	struct net *net = seq_file_net(seq);
+	int resume_bucket, resume_offset;
 	struct udp_table *udptable;
 	unsigned int batch_sks = 0;
 	bool resized = false;
 	struct sock *sk;
 
+	resume_bucket = state->bucket;
+	resume_offset = iter->offset;
+
 	/* The current batch is done, so advance the bucket. */
-	if (iter->st_bucket_done) {
+	if (iter->st_bucket_done)
 		state->bucket++;
-		iter->offset = 0;
-	}
 
 	udptable = udp_get_table_seq(seq, net);
···
 	for (; state->bucket <= udptable->mask; state->bucket++) {
 		struct udp_hslot *hslot2 = &udptable->hash2[state->bucket];
 
-		if (hlist_empty(&hslot2->head)) {
-			iter->offset = 0;
+		if (hlist_empty(&hslot2->head))
 			continue;
-		}
 
+		iter->offset = 0;
 		spin_lock_bh(&hslot2->lock);
 		udp_portaddr_for_each_entry(sk, &hslot2->head) {
 			if (seq_sk_match(seq, sk)) {
 				/* Resume from the last iterated socket at the
 				 * offset in the bucket before iterator was stopped.
 				 */
-				if (iter->offset) {
-					--iter->offset;
+				if (state->bucket == resume_bucket &&
+				    iter->offset < resume_offset) {
+					++iter->offset;
 					continue;
 				}
 				if (iter->end_sk < iter->max_sk) {
···
 		if (iter->end_sk)
 			break;
-
-		/* Reset the current bucket's offset before moving to the next bucket. */
-		iter->offset = 0;
 	}
 
 	/* All done: no batch made. */
···
 		/* After allocating a larger batch, retry one more time to grab
 		 * the whole bucket.
 		 */
-		state->bucket--;
 		goto again;
 	}
 done:
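Taken together, the hunks change the resume contract: iter->offset is now only
meaningful within a single bucket, and the saved resume_bucket/resume_offset
pair lets the batch loop skip exactly the sockets this bucket has already
shown, so the iterator can only move forward even when sockets are closed
between two read() batches. The standalone userspace program below is an
illustrative sketch of that fixed logic, not kernel code: buckets[], batch(),
MAX_BATCH and the zero-means-empty convention are invented stand-ins for the
udp hash chains, bpf_iter_udp_batch() and iter->max_sk.

#include <stdio.h>
#include <string.h>

#define NBUCKETS  3
#define BUCKET_SZ 4
#define MAX_BATCH 2

static int buckets[NBUCKETS][BUCKET_SZ] = {	/* 0 marks an empty slot */
	{ 11, 12, 13 },
	{ 21 },
	{ 31, 32 },
};

struct iter {
	int bucket;		/* mirrors state->bucket */
	int offset;		/* sockets already shown from this bucket */
	int bucket_done;	/* mirrors iter->st_bucket_done */
};

/* Fill out[] with up to MAX_BATCH entries, resuming at (bucket, offset). */
static int batch(struct iter *it, int out[MAX_BATCH])
{
	int resume_bucket = it->bucket;
	int resume_offset = it->offset;
	int n = 0, in_bucket = 0;

	if (it->bucket_done)		/* previous bucket fully consumed */
		it->bucket++;

	for (; it->bucket < NBUCKETS; it->bucket++) {
		it->offset = 0;		/* offset only counts within one bucket */
		n = in_bucket = 0;
		for (int i = 0; i < BUCKET_SZ; i++) {
			if (!buckets[it->bucket][i])
				continue;
			/* Skip only entries this same bucket already showed;
			 * if some were closed meanwhile, fewer get skipped
			 * and iteration still moves strictly forward.
			 */
			if (it->bucket == resume_bucket &&
			    it->offset < resume_offset) {
				it->offset++;
				continue;
			}
			if (n < MAX_BATCH)
				out[n++] = buckets[it->bucket][i];
			in_bucket++;
		}
		if (n)
			break;
	}
	it->bucket_done = in_bucket <= MAX_BATCH;	/* whole bucket grabbed? */
	return n;
}

int main(void)
{
	struct iter it = { 0 };
	int out[MAX_BATCH], n, i;

	/* First batch from bucket 0: shows 11 and 12; 13 is left over. */
	n = batch(&it, out);
	for (i = 0; i < n; i++, it.offset++)	/* consumer side, as in seq_next */
		printf("show %d\n", out[i]);

	/* Close every socket in bucket 0 mid-iteration, as the selftest
	 * does. The next batch must come from bucket 1; the old code
	 * decremented the stale offset across buckets and could even step
	 * the bucket back on realloc, re-showing or skipping sockets.
	 */
	memset(buckets[0], 0, sizeof(buckets[0]));
	n = batch(&it, out);
	for (i = 0; i < n; i++, it.offset++)
		printf("show %d\n", out[i]);	/* prints 21 */
	return 0;
}

Compiled with any C99 compiler, this prints 11, 12, then 21: after bucket 0 is
emptied mid-iteration, the model resumes at bucket 1 instead of going backward.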
+135
tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c
··· 
// SPDX-License-Identifier: GPL-2.0
// Copyright (c) 2024 Meta

#include <test_progs.h>
#include "network_helpers.h"
#include "sock_iter_batch.skel.h"

#define TEST_NS "sock_iter_batch_netns"

static const int nr_soreuse = 4;

static void do_test(int sock_type, bool onebyone)
{
	int err, i, nread, to_read, total_read, iter_fd = -1;
	int first_idx, second_idx, indices[nr_soreuse];
	struct bpf_link *link = NULL;
	struct sock_iter_batch *skel;
	int *fds[2] = {};

	skel = sock_iter_batch__open();
	if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open"))
		return;

	/* Prepare 2 buckets of sockets in the kernel hashtable */
	for (i = 0; i < ARRAY_SIZE(fds); i++) {
		int local_port;

		fds[i] = start_reuseport_server(AF_INET6, sock_type, "::1", 0, 0,
						nr_soreuse);
		if (!ASSERT_OK_PTR(fds[i], "start_reuseport_server"))
			goto done;
		local_port = get_socket_local_port(*fds[i]);
		if (!ASSERT_GE(local_port, 0, "get_socket_local_port"))
			goto done;
		skel->rodata->ports[i] = ntohs(local_port);
	}

	err = sock_iter_batch__load(skel);
	if (!ASSERT_OK(err, "sock_iter_batch__load"))
		goto done;

	link = bpf_program__attach_iter(sock_type == SOCK_STREAM ?
					skel->progs.iter_tcp_soreuse :
					skel->progs.iter_udp_soreuse,
					NULL);
	if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter"))
		goto done;

	iter_fd = bpf_iter_create(bpf_link__fd(link));
	if (!ASSERT_GE(iter_fd, 0, "bpf_iter_create"))
		goto done;

	/* Test reading a bucket (either from fds[0] or fds[1]).
	 * Only read "nr_soreuse - 1" sockets from one bucket,
	 * leaving one socket in that bucket unread on purpose.
	 */
	to_read = (nr_soreuse - 1) * sizeof(*indices);
	total_read = 0;
	first_idx = -1;
	do {
		nread = read(iter_fd, indices, onebyone ? sizeof(*indices) : to_read);
		if (nread <= 0 || nread % sizeof(*indices))
			break;
		total_read += nread;

		if (first_idx == -1)
			first_idx = indices[0];
		for (i = 0; i < nread / sizeof(*indices); i++)
			ASSERT_EQ(indices[i], first_idx, "first_idx");
	} while (total_read < to_read);
	ASSERT_EQ(nread, onebyone ? sizeof(*indices) : to_read, "nread");
	ASSERT_EQ(total_read, to_read, "total_read");

	free_fds(fds[first_idx], nr_soreuse);
	fds[first_idx] = NULL;

	/* Read the "whole" second bucket */
	to_read = nr_soreuse * sizeof(*indices);
	total_read = 0;
	second_idx = !first_idx;
	do {
		nread = read(iter_fd, indices, onebyone ? sizeof(*indices) : to_read);
		if (nread <= 0 || nread % sizeof(*indices))
			break;
		total_read += nread;

		for (i = 0; i < nread / sizeof(*indices); i++)
			ASSERT_EQ(indices[i], second_idx, "second_idx");
	} while (total_read <= to_read);
	ASSERT_EQ(nread, 0, "nread");
	/* Both so_reuseport ports should be in different buckets, so
	 * total_read must equal the expected to_read.
	 *
	 * In the very unlikely case that both ports collide in the same
	 * bucket, the socket left at the resume offset (i.e. 3) will have
	 * been skipped, so the full to_read number of bytes cannot be
	 * expected.
	 */
	if (skel->bss->bucket[0] != skel->bss->bucket[1])
		ASSERT_EQ(total_read, to_read, "total_read");

done:
	for (i = 0; i < ARRAY_SIZE(fds); i++)
		free_fds(fds[i], nr_soreuse);
	if (iter_fd >= 0)
		close(iter_fd);
	bpf_link__destroy(link);
	sock_iter_batch__destroy(skel);
}

void test_sock_iter_batch(void)
{
	struct nstoken *nstoken = NULL;

	SYS_NOFAIL("ip netns del " TEST_NS " &> /dev/null");
	SYS(done, "ip netns add %s", TEST_NS);
	SYS(done, "ip -net %s link set dev lo up", TEST_NS);

	nstoken = open_netns(TEST_NS);
	if (!ASSERT_OK_PTR(nstoken, "open_netns"))
		goto done;

	if (test__start_subtest("tcp")) {
		do_test(SOCK_STREAM, true);
		do_test(SOCK_STREAM, false);
	}
	if (test__start_subtest("udp")) {
		do_test(SOCK_DGRAM, true);
		do_test(SOCK_DGRAM, false);
	}
	close_netns(nstoken);

done:
	SYS_NOFAIL("ip netns del " TEST_NS " &> /dev/null");
}
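One detail worth noting: the test stores ntohs(local_port) because
get_socket_local_port() from network_helpers.c returns the bound port still
in network byte order. A sketch of what such a helper plausibly looks like
(an assumption for illustration, not a verbatim copy of the real helper):

#include <sys/socket.h>
#include <netinet/in.h>

/* Hypothetical stand-in for get_socket_local_port(): return the bound
 * port in network byte order, or a negative value on error.
 */
static int get_local_port_sketch(int sock_fd)
{
	struct sockaddr_storage addr;
	socklen_t addrlen = sizeof(addr);

	if (getsockname(sock_fd, (struct sockaddr *)&addr, &addrlen))
		return -1;
	if (addr.ss_family == AF_INET)
		return ((struct sockaddr_in *)&addr)->sin_port;
	if (addr.ss_family == AF_INET6)
		return ((struct sockaddr_in6 *)&addr)->sin6_port;
	return -1;
}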
+3
tools/testing/selftests/bpf/progs/bpf_tracing_net.h
···
 #define inet_rcv_saddr sk.__sk_common.skc_rcv_saddr
 #define inet_dport sk.__sk_common.skc_dport
 
+#define udp_portaddr_hash inet.sk.__sk_common.skc_u16hashes[1]
+
 #define ir_loc_addr req.__req_common.skc_rcv_saddr
 #define ir_num req.__req_common.skc_num
 #define ir_rmt_addr req.__req_common.skc_daddr
···
 #define sk_rmem_alloc sk_backlog.rmem_alloc
 #define sk_refcnt __sk_common.skc_refcnt
 #define sk_state __sk_common.skc_state
+#define sk_net __sk_common.skc_net
 #define sk_v6_daddr __sk_common.skc_v6_daddr
 #define sk_v6_rcv_saddr __sk_common.skc_v6_rcv_saddr
 #define sk_flags __sk_common.skc_flags
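These two defines follow the header's existing pattern: kernel-internal field
names are remapped onto the embedded __sk_common so the vmlinux.h-based
programs in the next file can dereference them naturally (udp_portaddr_hash
mirrors the define the kernel itself uses in include/linux/udp.h). Shown as
comments, the preprocessor expansions are:

/* net  = sk->sk_net.net;
 *        expands to sk->__sk_common.skc_net.net
 * hash = udp_sk(sk)->udp_portaddr_hash;
 *        expands to udp_sk(sk)->inet.sk.__sk_common.skc_u16hashes[1]
 */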
+91
tools/testing/selftests/bpf/progs/sock_iter_batch.c
···
// SPDX-License-Identifier: GPL-2.0
// Copyright (c) 2024 Meta

#include "vmlinux.h"
#include <bpf/bpf_helpers.h>
#include <bpf/bpf_core_read.h>
#include <bpf/bpf_endian.h>
#include "bpf_tracing_net.h"
#include "bpf_kfuncs.h"

#define ATTR __always_inline
#include "test_jhash.h"

static bool ipv6_addr_loopback(const struct in6_addr *a)
{
	return (a->s6_addr32[0] | a->s6_addr32[1] |
		a->s6_addr32[2] | (a->s6_addr32[3] ^ bpf_htonl(1))) == 0;
}

volatile const __u16 ports[2];
unsigned int bucket[2];

SEC("iter/tcp")
int iter_tcp_soreuse(struct bpf_iter__tcp *ctx)
{
	struct sock *sk = (struct sock *)ctx->sk_common;
	struct inet_hashinfo *hinfo;
	unsigned int hash;
	struct net *net;
	int idx;

	if (!sk)
		return 0;

	sk = bpf_rdonly_cast(sk, bpf_core_type_id_kernel(struct sock));
	if (sk->sk_family != AF_INET6 ||
	    sk->sk_state != TCP_LISTEN ||
	    !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr))
		return 0;

	if (sk->sk_num == ports[0])
		idx = 0;
	else if (sk->sk_num == ports[1])
		idx = 1;
	else
		return 0;

	/* bucket selection as in inet_lhash2_bucket_sk() */
	net = sk->sk_net.net;
	hash = jhash2(sk->sk_v6_rcv_saddr.s6_addr32, 4, net->hash_mix);
	hash ^= sk->sk_num;
	hinfo = net->ipv4.tcp_death_row.hashinfo;
	bucket[idx] = hash & hinfo->lhash2_mask;
	bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx));

	return 0;
}

#define udp_sk(ptr) container_of(ptr, struct udp_sock, inet.sk)

SEC("iter/udp")
int iter_udp_soreuse(struct bpf_iter__udp *ctx)
{
	struct sock *sk = (struct sock *)ctx->udp_sk;
	struct udp_table *udptable;
	int idx;

	if (!sk)
		return 0;

	sk = bpf_rdonly_cast(sk, bpf_core_type_id_kernel(struct sock));
	if (sk->sk_family != AF_INET6 ||
	    !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr))
		return 0;

	if (sk->sk_num == ports[0])
		idx = 0;
	else if (sk->sk_num == ports[1])
		idx = 1;
	else
		return 0;

	/* bucket selection as in udp_hashslot2() */
	udptable = sk->sk_net.net->ipv4.udp_table;
	bucket[idx] = udp_sk(sk)->udp_portaddr_hash & udptable->mask;
	bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx));

	return 0;
}

char _license[] SEC("license") = "GPL";
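Both programs recompute, on the BPF side, the bucket index the kernel will
pick for each socket, which is what lets the selftest detect the unlikely
case of the two ports colliding in one bucket. For reference, a simplified
sketch of the kernel-side computations being mirrored (treat the exact
bodies as an assumption based on same-era kernel sources):

/* UDP side, as in udp_hashslot2() (simplified sketch): */
static struct udp_hslot *udp_hashslot2_sketch(struct udp_table *table,
					      unsigned int hash)
{
	return &table->hash2[hash & table->mask];
}

/* TCP listener side: inet_lhash2_bucket_sk() hashes the bound address
 * seeded with the netns hash_mix, XORs in the local port, and masks the
 * result with lhash2_mask, the same three steps iter_tcp_soreuse() does.
 */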
+31
tools/testing/selftests/bpf/progs/test_jhash.h
···
 
 	return c;
 }
+
+static __always_inline u32 jhash2(const u32 *k, u32 length, u32 initval)
+{
+	u32 a, b, c;
+
+	/* Set up the internal state */
+	a = b = c = JHASH_INITVAL + (length<<2) + initval;
+
+	/* Handle most of the key */
+	while (length > 3) {
+		a += k[0];
+		b += k[1];
+		c += k[2];
+		__jhash_mix(a, b, c);
+		length -= 3;
+		k += 3;
+	}
+
+	/* Handle the last 3 u32's */
+	switch (length) {
+	case 3: c += k[2];
+	case 2: b += k[1];
+	case 1: a += k[0];
+		__jhash_final(a, b, c);
+		break;
+	case 0: /* Nothing left to add */
+		break;
+	}
+
+	return c;
+}
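jhash2() is the word-at-a-time variant of the Jenkins hash already in this
header; iter_tcp_soreuse() feeds it the four u32 words of the bound IPv6
address. A minimal usage sketch (the helper name is invented for
illustration):

/* Recompute a TCP listener's lhash2 bucket the way iter_tcp_soreuse()
 * does: hash the 128-bit address as four u32 words seeded with the
 * netns hash_mix, XOR in the port, and mask with lhash2_mask.
 */
static __always_inline u32 lhash2_bucket_sketch(const u32 *addr6, u32 port,
						u32 net_hash_mix, u32 lhash2_mask)
{
	return (jhash2(addr6, 4, net_hash_mix) ^ port) & lhash2_mask;
}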