Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

Merge branch 'bpf-tcp-exactly-once-socket-iteration'

Jordan Rife says:

====================
bpf: tcp: Exactly-once socket iteration

TCP socket iterators use iter->offset to track progress through a
bucket, which is a measure of the number of matching sockets from the
current bucket that have been seen or processed by the iterator. On
subsequent iterations, if the current bucket has unprocessed items, we
skip at least iter->offset matching items in the bucket before adding
any remaining items to the next batch. However, iter->offset isn't
always an accurate measure of "things already seen" when the underlying
bucket changes between reads, which can lead to repeated or skipped
sockets. Instead, this series remembers the cookies of the sockets we
haven't seen yet in the current bucket and resumes from the first cookie
in that list that we can find on the next iteration.

This is a continuation of the work started in [1]. This series largely
replicates the patterns applied to UDP socket iterators, applying them
instead to TCP socket iterators.

CHANGES
=======
v5 -> v6:
* In patch ten ("selftests/bpf: Create established sockets in socket
iterator tests"), use poll() to choose a socket that has a connection
ready to be accept()ed. Before, connect_to_server would set the
O_NONBLOCK flag on all listening sockets so that accept_from_one could
loop through them all and find the one that connect_to_addr_str
connected to. However, this is subtly buggy and could potentially lead
to test flakes, since the 3 way handshake isn't necessarily done when
connect returns, so it's possible none of the accept() calls succeed.
Use poll() instead to guarantee that the socket we accept() from is
ready and eliminate the need for the O_NONBLOCK flag (Martin).

v4 -> v5:
* Move WARN_ON_ONCE before the `done` label in patch two ("bpf: tcp:
Make sure iter->batch always contains a full bucket snapshot"")
(Martin).
* Remove unnecessary kfunc declaration in patch eleven ("selftests/bpf:
Create iter_tcp_destroy test program") (Martin).
* Make sure to close the socket fd at the end of `destroy` in patch
twelve ("selftests/bpf: Add tests for bucket resume logic in
established sockets") (Martin).

v3 -> v4:
* Drop braces around sk_nulls_for_each_from in patch five ("bpf: tcp:
Avoid socket skips and repeats during iteration") (Stanislav).
* Add a break after the TCP_SEQ_STATE_ESTABLISHED case in patch five
(Stanislav).
* Add an `if (sock_type == SOCK_STREAM)` check before assigning
TCP_LISTEN to skel->rodata->ss in patch eight ("selftests/bpf: Allow
for iteration over multiple states") to more clearly express the
intent that the option is only consumed for SOCK_STREAM tests
(Stanislav).
* Move the `i = 0` assignment into the for loop in patch ten
("selftests/bpf: Create established sockets in socket iterator
tests") (Stanislav).

v2 -> v3:
* Unroll the loop inside bpf_iter_tcp_batch to make the logic easier to
follow in patch two ("bpf: tcp: Make sure iter->batch always contains
a full bucket snapshot"). This gets rid of the `resizes` variable from
v2 and eliminates the extra conditional that checks how many batch
resize attempts have occurred so far (Stanislav).
Note: This changes the behavior slightly. Before, in the case that
the second call to tcp_seek_last_pos (and later bpf_iter_tcp_resume)
advances to a new bucket, which may happen if the current bucket is
emptied after releasing its lock, the `resizes` "budget" would be
reset, the net effect being that we would try a batch resize with
GFP_USER at most once per bucket. Now, we try to resize the batch
with GFP_USER at most once per call, so it makes it slightly more
likely that we hit the GFP_NOWAIT scenario. However, this edge case
should be rare in practice anyway, and the new behavior is more or
less consistent with the original retry logic, so avoid the loop and
prefer code clarity.
* Move the call to bpf_iter_tcp_put_batch out of
bpf_iter_tcp_realloc_batch and call it directly before invoking
bpf_iter_tcp_realloc_batch with GFP_USER inside bpf_iter_tcp_batch.
/Don't/ call it before invoking bpf_iter_tcp_realloc_batch the second
time while we hold the lock with GFP_NOWAIT. This avoids a conditional
inside bpf_iter_tcp_realloc_batch from v2 that only calls
bpf_iter_tcp_put_batch if flags != GFP_NOWAIT and is a bit more
explicit (Stanislav).
* Adjust patch five ("bpf: tcp: Avoid socket skips and repeats during
iteration") to fit with the new logic in patch two.

v1 -> v2:
* In patch five ("bpf: tcp: Avoid socket skips and repeats during
iteration"), remove unnecessary bucket bounds checks in
bpf_iter_tcp_resume. In either case, if st->bucket is outside the
current table's range then bpf_iter_tcp_resume_* calls *_get_first
which immediately returns NULL anyway and the logic will fall through.
(Martin)
* Add a check at the top of bpf_iter_tcp_resume_listening and
bpf_iter_tcp_resume_established to see if we're done with the current
bucket and advance it immediately instead of wasting time finding the
first matching socket in that bucket with
(listening|established)_get_first. In v1, we originally discussed
adding logic to advance the bucket in bpf_iter_tcp_seq_next and
bpf_iter_tcp_seq_stop, but after trying this the logic seemed harder
to track. Overall, keeping everything inside bpf_iter_tcp_resume_*
seemed a bit clearer. (Martin)
* Instead of using a timeout in the last patch ("selftests/bpf: Add
tests for bucket resume logic in established sockets") to wait for
sockets to leave the ehash table after calling close(), use
bpf_sock_destroy to deterministically destroy and remove them. This
introduces one more patch ("selftests/bpf: Create iter_tcp_destroy
test program") to create the iterator program that destroys a selected
socket. Drive this through a destroy() function in the last patch
which, just like close(), accepts a socket file descriptor. (Martin)
* Introduce one more patch ("selftests/bpf: Allow for iteration over
multiple states") to fix a latent bug in iter_tcp_soreuse where the
sk->sk_state != TCP_LISTEN check was ignored. Add the "ss" variable to
allow test code to configure which socket states to allow.

[1]: https://lore.kernel.org/bpf/20250502161528.264630-1-jordan@jrife.io/
====================

Link: https://patch.msgid.link/20250714180919.127192-1-jordan@jrife.io
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>

+681 -86
+202 -71
net/ipv4/tcp_ipv4.c
··· 58 58 #include <linux/times.h> 59 59 #include <linux/slab.h> 60 60 #include <linux/sched.h> 61 + #include <linux/sock_diag.h> 61 62 62 63 #include <net/net_namespace.h> 63 64 #include <net/icmp.h> ··· 3015 3014 } 3016 3015 3017 3016 #ifdef CONFIG_BPF_SYSCALL 3017 + union bpf_tcp_iter_batch_item { 3018 + struct sock *sk; 3019 + __u64 cookie; 3020 + }; 3021 + 3018 3022 struct bpf_tcp_iter_state { 3019 3023 struct tcp_iter_state state; 3020 3024 unsigned int cur_sk; 3021 3025 unsigned int end_sk; 3022 3026 unsigned int max_sk; 3023 - struct sock **batch; 3024 - bool st_bucket_done; 3027 + union bpf_tcp_iter_batch_item *batch; 3025 3028 }; 3026 3029 3027 3030 struct bpf_iter__tcp { ··· 3048 3043 3049 3044 static void bpf_iter_tcp_put_batch(struct bpf_tcp_iter_state *iter) 3050 3045 { 3051 - while (iter->cur_sk < iter->end_sk) 3052 - sock_gen_put(iter->batch[iter->cur_sk++]); 3046 + union bpf_tcp_iter_batch_item *item; 3047 + unsigned int cur_sk = iter->cur_sk; 3048 + __u64 cookie; 3049 + 3050 + /* Remember the cookies of the sockets we haven't seen yet, so we can 3051 + * pick up where we left off next time around. 3052 + */ 3053 + while (cur_sk < iter->end_sk) { 3054 + item = &iter->batch[cur_sk++]; 3055 + cookie = sock_gen_cookie(item->sk); 3056 + sock_gen_put(item->sk); 3057 + item->cookie = cookie; 3058 + } 3053 3059 } 3054 3060 3055 3061 static int bpf_iter_tcp_realloc_batch(struct bpf_tcp_iter_state *iter, 3056 - unsigned int new_batch_sz) 3062 + unsigned int new_batch_sz, gfp_t flags) 3057 3063 { 3058 - struct sock **new_batch; 3064 + union bpf_tcp_iter_batch_item *new_batch; 3059 3065 3060 3066 new_batch = kvmalloc(sizeof(*new_batch) * new_batch_sz, 3061 - GFP_USER | __GFP_NOWARN); 3067 + flags | __GFP_NOWARN); 3062 3068 if (!new_batch) 3063 3069 return -ENOMEM; 3064 3070 3065 - bpf_iter_tcp_put_batch(iter); 3071 + memcpy(new_batch, iter->batch, sizeof(*iter->batch) * iter->end_sk); 3066 3072 kvfree(iter->batch); 3067 3073 iter->batch = new_batch; 3068 3074 iter->max_sk = new_batch_sz; ··· 3081 3065 return 0; 3082 3066 } 3083 3067 3084 - static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq, 3085 - struct sock *start_sk) 3068 + static struct sock *bpf_iter_tcp_resume_bucket(struct sock *first_sk, 3069 + union bpf_tcp_iter_batch_item *cookies, 3070 + int n_cookies) 3071 + { 3072 + struct hlist_nulls_node *node; 3073 + struct sock *sk; 3074 + int i; 3075 + 3076 + for (i = 0; i < n_cookies; i++) { 3077 + sk = first_sk; 3078 + sk_nulls_for_each_from(sk, node) 3079 + if (cookies[i].cookie == atomic64_read(&sk->sk_cookie)) 3080 + return sk; 3081 + } 3082 + 3083 + return NULL; 3084 + } 3085 + 3086 + static struct sock *bpf_iter_tcp_resume_listening(struct seq_file *seq) 3086 3087 { 3087 3088 struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; 3088 3089 struct bpf_tcp_iter_state *iter = seq->private; 3089 3090 struct tcp_iter_state *st = &iter->state; 3091 + unsigned int find_cookie = iter->cur_sk; 3092 + unsigned int end_cookie = iter->end_sk; 3093 + int resume_bucket = st->bucket; 3094 + struct sock *sk; 3095 + 3096 + if (end_cookie && find_cookie == end_cookie) 3097 + ++st->bucket; 3098 + 3099 + sk = listening_get_first(seq); 3100 + iter->cur_sk = 0; 3101 + iter->end_sk = 0; 3102 + 3103 + if (sk && st->bucket == resume_bucket && end_cookie) { 3104 + sk = bpf_iter_tcp_resume_bucket(sk, &iter->batch[find_cookie], 3105 + end_cookie - find_cookie); 3106 + if (!sk) { 3107 + spin_unlock(&hinfo->lhash2[st->bucket].lock); 3108 + ++st->bucket; 3109 + sk = listening_get_first(seq); 3110 + } 3111 + } 3112 + 3113 + return sk; 3114 + } 3115 + 3116 + static struct sock *bpf_iter_tcp_resume_established(struct seq_file *seq) 3117 + { 3118 + struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; 3119 + struct bpf_tcp_iter_state *iter = seq->private; 3120 + struct tcp_iter_state *st = &iter->state; 3121 + unsigned int find_cookie = iter->cur_sk; 3122 + unsigned int end_cookie = iter->end_sk; 3123 + int resume_bucket = st->bucket; 3124 + struct sock *sk; 3125 + 3126 + if (end_cookie && find_cookie == end_cookie) 3127 + ++st->bucket; 3128 + 3129 + sk = established_get_first(seq); 3130 + iter->cur_sk = 0; 3131 + iter->end_sk = 0; 3132 + 3133 + if (sk && st->bucket == resume_bucket && end_cookie) { 3134 + sk = bpf_iter_tcp_resume_bucket(sk, &iter->batch[find_cookie], 3135 + end_cookie - find_cookie); 3136 + if (!sk) { 3137 + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); 3138 + ++st->bucket; 3139 + sk = established_get_first(seq); 3140 + } 3141 + } 3142 + 3143 + return sk; 3144 + } 3145 + 3146 + static struct sock *bpf_iter_tcp_resume(struct seq_file *seq) 3147 + { 3148 + struct bpf_tcp_iter_state *iter = seq->private; 3149 + struct tcp_iter_state *st = &iter->state; 3150 + struct sock *sk = NULL; 3151 + 3152 + switch (st->state) { 3153 + case TCP_SEQ_STATE_LISTENING: 3154 + sk = bpf_iter_tcp_resume_listening(seq); 3155 + if (sk) 3156 + break; 3157 + st->bucket = 0; 3158 + st->state = TCP_SEQ_STATE_ESTABLISHED; 3159 + fallthrough; 3160 + case TCP_SEQ_STATE_ESTABLISHED: 3161 + sk = bpf_iter_tcp_resume_established(seq); 3162 + break; 3163 + } 3164 + 3165 + return sk; 3166 + } 3167 + 3168 + static unsigned int bpf_iter_tcp_listening_batch(struct seq_file *seq, 3169 + struct sock **start_sk) 3170 + { 3171 + struct bpf_tcp_iter_state *iter = seq->private; 3090 3172 struct hlist_nulls_node *node; 3091 3173 unsigned int expected = 1; 3092 3174 struct sock *sk; 3093 3175 3094 - sock_hold(start_sk); 3095 - iter->batch[iter->end_sk++] = start_sk; 3176 + sock_hold(*start_sk); 3177 + iter->batch[iter->end_sk++].sk = *start_sk; 3096 3178 3097 - sk = sk_nulls_next(start_sk); 3179 + sk = sk_nulls_next(*start_sk); 3180 + *start_sk = NULL; 3098 3181 sk_nulls_for_each_from(sk, node) { 3099 3182 if (seq_sk_match(seq, sk)) { 3100 3183 if (iter->end_sk < iter->max_sk) { 3101 3184 sock_hold(sk); 3102 - iter->batch[iter->end_sk++] = sk; 3185 + iter->batch[iter->end_sk++].sk = sk; 3186 + } else if (!*start_sk) { 3187 + /* Remember where we left off. */ 3188 + *start_sk = sk; 3103 3189 } 3104 3190 expected++; 3105 3191 } 3106 3192 } 3107 - spin_unlock(&hinfo->lhash2[st->bucket].lock); 3108 3193 3109 3194 return expected; 3110 3195 } 3111 3196 3112 3197 static unsigned int bpf_iter_tcp_established_batch(struct seq_file *seq, 3113 - struct sock *start_sk) 3198 + struct sock **start_sk) 3114 3199 { 3115 - struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; 3116 3200 struct bpf_tcp_iter_state *iter = seq->private; 3117 - struct tcp_iter_state *st = &iter->state; 3118 3201 struct hlist_nulls_node *node; 3119 3202 unsigned int expected = 1; 3120 3203 struct sock *sk; 3121 3204 3122 - sock_hold(start_sk); 3123 - iter->batch[iter->end_sk++] = start_sk; 3205 + sock_hold(*start_sk); 3206 + iter->batch[iter->end_sk++].sk = *start_sk; 3124 3207 3125 - sk = sk_nulls_next(start_sk); 3208 + sk = sk_nulls_next(*start_sk); 3209 + *start_sk = NULL; 3126 3210 sk_nulls_for_each_from(sk, node) { 3127 3211 if (seq_sk_match(seq, sk)) { 3128 3212 if (iter->end_sk < iter->max_sk) { 3129 3213 sock_hold(sk); 3130 - iter->batch[iter->end_sk++] = sk; 3214 + iter->batch[iter->end_sk++].sk = sk; 3215 + } else if (!*start_sk) { 3216 + /* Remember where we left off. */ 3217 + *start_sk = sk; 3131 3218 } 3132 3219 expected++; 3133 3220 } 3134 3221 } 3135 - spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); 3136 3222 3137 3223 return expected; 3138 3224 } 3139 3225 3140 - static struct sock *bpf_iter_tcp_batch(struct seq_file *seq) 3226 + static unsigned int bpf_iter_fill_batch(struct seq_file *seq, 3227 + struct sock **start_sk) 3228 + { 3229 + struct bpf_tcp_iter_state *iter = seq->private; 3230 + struct tcp_iter_state *st = &iter->state; 3231 + 3232 + if (st->state == TCP_SEQ_STATE_LISTENING) 3233 + return bpf_iter_tcp_listening_batch(seq, start_sk); 3234 + else 3235 + return bpf_iter_tcp_established_batch(seq, start_sk); 3236 + } 3237 + 3238 + static void bpf_iter_tcp_unlock_bucket(struct seq_file *seq) 3141 3239 { 3142 3240 struct inet_hashinfo *hinfo = seq_file_net(seq)->ipv4.tcp_death_row.hashinfo; 3143 3241 struct bpf_tcp_iter_state *iter = seq->private; 3144 3242 struct tcp_iter_state *st = &iter->state; 3243 + 3244 + if (st->state == TCP_SEQ_STATE_LISTENING) 3245 + spin_unlock(&hinfo->lhash2[st->bucket].lock); 3246 + else 3247 + spin_unlock_bh(inet_ehash_lockp(hinfo, st->bucket)); 3248 + } 3249 + 3250 + static struct sock *bpf_iter_tcp_batch(struct seq_file *seq) 3251 + { 3252 + struct bpf_tcp_iter_state *iter = seq->private; 3145 3253 unsigned int expected; 3146 - bool resized = false; 3147 3254 struct sock *sk; 3255 + int err; 3148 3256 3149 - /* The st->bucket is done. Directly advance to the next 3150 - * bucket instead of having the tcp_seek_last_pos() to skip 3151 - * one by one in the current bucket and eventually find out 3152 - * it has to advance to the next bucket. 3153 - */ 3154 - if (iter->st_bucket_done) { 3155 - st->offset = 0; 3156 - st->bucket++; 3157 - if (st->state == TCP_SEQ_STATE_LISTENING && 3158 - st->bucket > hinfo->lhash2_mask) { 3159 - st->state = TCP_SEQ_STATE_ESTABLISHED; 3160 - st->bucket = 0; 3161 - } 3162 - } 3163 - 3164 - again: 3165 - /* Get a new batch */ 3166 - iter->cur_sk = 0; 3167 - iter->end_sk = 0; 3168 - iter->st_bucket_done = false; 3169 - 3170 - sk = tcp_seek_last_pos(seq); 3257 + sk = bpf_iter_tcp_resume(seq); 3171 3258 if (!sk) 3172 3259 return NULL; /* Done */ 3173 3260 3174 - if (st->state == TCP_SEQ_STATE_LISTENING) 3175 - expected = bpf_iter_tcp_listening_batch(seq, sk); 3176 - else 3177 - expected = bpf_iter_tcp_established_batch(seq, sk); 3261 + expected = bpf_iter_fill_batch(seq, &sk); 3262 + if (likely(iter->end_sk == expected)) 3263 + goto done; 3178 3264 3179 - if (iter->end_sk == expected) { 3180 - iter->st_bucket_done = true; 3181 - return sk; 3265 + /* Batch size was too small. */ 3266 + bpf_iter_tcp_unlock_bucket(seq); 3267 + bpf_iter_tcp_put_batch(iter); 3268 + err = bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2, 3269 + GFP_USER); 3270 + if (err) 3271 + return ERR_PTR(err); 3272 + 3273 + sk = bpf_iter_tcp_resume(seq); 3274 + if (!sk) 3275 + return NULL; /* Done */ 3276 + 3277 + expected = bpf_iter_fill_batch(seq, &sk); 3278 + if (likely(iter->end_sk == expected)) 3279 + goto done; 3280 + 3281 + /* Batch size was still too small. Hold onto the lock while we try 3282 + * again with a larger batch to make sure the current bucket's size 3283 + * does not change in the meantime. 3284 + */ 3285 + err = bpf_iter_tcp_realloc_batch(iter, expected, GFP_NOWAIT); 3286 + if (err) { 3287 + bpf_iter_tcp_unlock_bucket(seq); 3288 + return ERR_PTR(err); 3182 3289 } 3183 3290 3184 - if (!resized && !bpf_iter_tcp_realloc_batch(iter, expected * 3 / 2)) { 3185 - resized = true; 3186 - goto again; 3187 - } 3188 - 3189 - return sk; 3291 + expected = bpf_iter_fill_batch(seq, &sk); 3292 + WARN_ON_ONCE(iter->end_sk != expected); 3293 + done: 3294 + bpf_iter_tcp_unlock_bucket(seq); 3295 + return iter->batch[0].sk; 3190 3296 } 3191 3297 3192 3298 static void *bpf_iter_tcp_seq_start(struct seq_file *seq, loff_t *pos) ··· 3338 3200 * meta.seq_num is used instead. 3339 3201 */ 3340 3202 st->num++; 3341 - /* Move st->offset to the next sk in the bucket such that 3342 - * the future start() will resume at st->offset in 3343 - * st->bucket. See tcp_seek_last_pos(). 3344 - */ 3345 - st->offset++; 3346 - sock_gen_put(iter->batch[iter->cur_sk++]); 3203 + sock_gen_put(iter->batch[iter->cur_sk++].sk); 3347 3204 } 3348 3205 3349 3206 if (iter->cur_sk < iter->end_sk) 3350 - sk = iter->batch[iter->cur_sk]; 3207 + sk = iter->batch[iter->cur_sk].sk; 3351 3208 else 3352 3209 sk = bpf_iter_tcp_batch(seq); 3353 3210 ··· 3408 3275 (void)tcp_prog_seq_show(prog, &meta, v, 0); 3409 3276 } 3410 3277 3411 - if (iter->cur_sk < iter->end_sk) { 3278 + if (iter->cur_sk < iter->end_sk) 3412 3279 bpf_iter_tcp_put_batch(iter); 3413 - iter->st_bucket_done = false; 3414 - } 3415 3280 } 3416 3281 3417 3282 static const struct seq_operations bpf_iter_tcp_seq_ops = { ··· 3727 3596 if (err) 3728 3597 return err; 3729 3598 3730 - err = bpf_iter_tcp_realloc_batch(iter, INIT_BATCH_SZ); 3599 + err = bpf_iter_tcp_realloc_batch(iter, INIT_BATCH_SZ, GFP_USER); 3731 3600 if (err) { 3732 3601 bpf_iter_fini_seq_net(priv_data); 3733 3602 return err;
+448 -10
tools/testing/selftests/bpf/prog_tests/sock_iter_batch.c
··· 1 1 // SPDX-License-Identifier: GPL-2.0 2 2 // Copyright (c) 2024 Meta 3 3 4 + #include <poll.h> 4 5 #include <test_progs.h> 5 6 #include "network_helpers.h" 6 7 #include "sock_iter_batch.skel.h" 7 8 8 9 #define TEST_NS "sock_iter_batch_netns" 10 + #define TEST_CHILD_NS "sock_iter_batch_child_netns" 9 11 10 12 static const int init_batch_size = 16; 11 13 static const int nr_soreuse = 4; ··· 120 118 return nth_sock_idx; 121 119 } 122 120 121 + static void destroy(int fd) 122 + { 123 + struct sock_iter_batch *skel = NULL; 124 + __u64 cookie = socket_cookie(fd); 125 + struct bpf_link *link = NULL; 126 + int iter_fd = -1; 127 + int nread; 128 + __u64 out; 129 + 130 + skel = sock_iter_batch__open(); 131 + if (!ASSERT_OK_PTR(skel, "sock_iter_batch__open")) 132 + goto done; 133 + 134 + skel->rodata->destroy_cookie = cookie; 135 + 136 + if (!ASSERT_OK(sock_iter_batch__load(skel), "sock_iter_batch__load")) 137 + goto done; 138 + 139 + link = bpf_program__attach_iter(skel->progs.iter_tcp_destroy, NULL); 140 + if (!ASSERT_OK_PTR(link, "bpf_program__attach_iter")) 141 + goto done; 142 + 143 + iter_fd = bpf_iter_create(bpf_link__fd(link)); 144 + if (!ASSERT_OK_FD(iter_fd, "bpf_iter_create")) 145 + goto done; 146 + 147 + /* Delete matching socket. */ 148 + nread = read(iter_fd, &out, sizeof(out)); 149 + ASSERT_GE(nread, 0, "nread"); 150 + if (nread) 151 + ASSERT_EQ(out, cookie, "cookie matches"); 152 + done: 153 + if (iter_fd >= 0) 154 + close(iter_fd); 155 + bpf_link__destroy(link); 156 + sock_iter_batch__destroy(skel); 157 + close(fd); 158 + } 159 + 123 160 static int get_seen_count(int fd, struct sock_count counts[], int n) 124 161 { 125 162 __u64 cookie = socket_cookie(fd); ··· 193 152 ASSERT_EQ(seen_once, n, "seen_once"); 194 153 } 195 154 155 + static int accept_from_one(struct pollfd *server_poll_fds, 156 + int server_poll_fds_len) 157 + { 158 + static const int poll_timeout_ms = 5000; /* 5s */ 159 + int ret; 160 + int i; 161 + 162 + ret = poll(server_poll_fds, server_poll_fds_len, poll_timeout_ms); 163 + if (!ASSERT_EQ(ret, 1, "poll")) 164 + return -1; 165 + 166 + for (i = 0; i < server_poll_fds_len; i++) 167 + if (server_poll_fds[i].revents & POLLIN) 168 + return accept(server_poll_fds[i].fd, NULL, NULL); 169 + 170 + return -1; 171 + } 172 + 173 + static int *connect_to_server(int family, int sock_type, const char *addr, 174 + __u16 port, int nr_connects, int *server_fds, 175 + int server_fds_len) 176 + { 177 + struct pollfd *server_poll_fds = NULL; 178 + int *established_socks = NULL; 179 + int i; 180 + 181 + server_poll_fds = calloc(server_fds_len, sizeof(*server_poll_fds)); 182 + if (!ASSERT_OK_PTR(server_poll_fds, "server_poll_fds")) 183 + return NULL; 184 + 185 + for (i = 0; i < server_fds_len; i++) { 186 + server_poll_fds[i].fd = server_fds[i]; 187 + server_poll_fds[i].events = POLLIN; 188 + } 189 + 190 + i = 0; 191 + 192 + established_socks = malloc(sizeof(*established_socks) * nr_connects*2); 193 + if (!ASSERT_OK_PTR(established_socks, "established_socks")) 194 + goto error; 195 + 196 + while (nr_connects--) { 197 + established_socks[i] = connect_to_addr_str(family, sock_type, 198 + addr, port, NULL); 199 + if (!ASSERT_OK_FD(established_socks[i], "connect_to_addr_str")) 200 + goto error; 201 + i++; 202 + established_socks[i] = accept_from_one(server_poll_fds, 203 + server_fds_len); 204 + if (!ASSERT_OK_FD(established_socks[i], "accept_from_one")) 205 + goto error; 206 + i++; 207 + } 208 + 209 + free(server_poll_fds); 210 + return established_socks; 211 + error: 212 + free_fds(established_socks, i); 213 + free(server_poll_fds); 214 + return NULL; 215 + } 216 + 196 217 static void remove_seen(int family, int sock_type, const char *addr, __u16 port, 197 - int *socks, int socks_len, struct sock_count *counts, 218 + int *socks, int socks_len, int *established_socks, 219 + int established_socks_len, struct sock_count *counts, 198 220 int counts_len, struct bpf_link *link, int iter_fd) 199 221 { 200 222 int close_idx; ··· 286 182 counts_len); 287 183 } 288 184 185 + static void remove_seen_established(int family, int sock_type, const char *addr, 186 + __u16 port, int *listen_socks, 187 + int listen_socks_len, int *established_socks, 188 + int established_socks_len, 189 + struct sock_count *counts, int counts_len, 190 + struct bpf_link *link, int iter_fd) 191 + { 192 + int close_idx; 193 + 194 + /* Iterate through all listening sockets. */ 195 + read_n(iter_fd, listen_socks_len, counts, counts_len); 196 + 197 + /* Make sure we saw all listening sockets exactly once. */ 198 + check_n_were_seen_once(listen_socks, listen_socks_len, listen_socks_len, 199 + counts, counts_len); 200 + 201 + /* Leave one established socket. */ 202 + read_n(iter_fd, established_socks_len - 1, counts, counts_len); 203 + 204 + /* Close a socket we've already seen to remove it from the bucket. */ 205 + close_idx = get_nth_socket(established_socks, established_socks_len, 206 + link, listen_socks_len + 1); 207 + if (!ASSERT_GE(close_idx, 0, "close_idx")) 208 + return; 209 + destroy(established_socks[close_idx]); 210 + established_socks[close_idx] = -1; 211 + 212 + /* Iterate through the rest of the sockets. */ 213 + read_n(iter_fd, -1, counts, counts_len); 214 + 215 + /* Make sure the last socket wasn't skipped and that there were no 216 + * repeats. 217 + */ 218 + check_n_were_seen_once(established_socks, established_socks_len, 219 + established_socks_len - 1, counts, counts_len); 220 + } 221 + 289 222 static void remove_unseen(int family, int sock_type, const char *addr, 290 223 __u16 port, int *socks, int socks_len, 224 + int *established_socks, int established_socks_len, 291 225 struct sock_count *counts, int counts_len, 292 226 struct bpf_link *link, int iter_fd) 293 227 { ··· 356 214 counts_len); 357 215 } 358 216 217 + static void remove_unseen_established(int family, int sock_type, 218 + const char *addr, __u16 port, 219 + int *listen_socks, int listen_socks_len, 220 + int *established_socks, 221 + int established_socks_len, 222 + struct sock_count *counts, int counts_len, 223 + struct bpf_link *link, int iter_fd) 224 + { 225 + int close_idx; 226 + 227 + /* Iterate through all listening sockets. */ 228 + read_n(iter_fd, listen_socks_len, counts, counts_len); 229 + 230 + /* Make sure we saw all listening sockets exactly once. */ 231 + check_n_were_seen_once(listen_socks, listen_socks_len, listen_socks_len, 232 + counts, counts_len); 233 + 234 + /* Iterate through the first established socket. */ 235 + read_n(iter_fd, 1, counts, counts_len); 236 + 237 + /* Make sure we saw one established socks. */ 238 + check_n_were_seen_once(established_socks, established_socks_len, 1, 239 + counts, counts_len); 240 + 241 + /* Close what would be the next socket in the bucket to exercise the 242 + * condition where we need to skip past the first cookie we remembered. 243 + */ 244 + close_idx = get_nth_socket(established_socks, established_socks_len, 245 + link, listen_socks_len + 1); 246 + if (!ASSERT_GE(close_idx, 0, "close_idx")) 247 + return; 248 + 249 + destroy(established_socks[close_idx]); 250 + established_socks[close_idx] = -1; 251 + 252 + /* Iterate through the rest of the sockets. */ 253 + read_n(iter_fd, -1, counts, counts_len); 254 + 255 + /* Make sure the remaining sockets were seen exactly once and that we 256 + * didn't repeat the socket that was already seen. 257 + */ 258 + check_n_were_seen_once(established_socks, established_socks_len, 259 + established_socks_len - 1, counts, counts_len); 260 + } 261 + 359 262 static void remove_all(int family, int sock_type, const char *addr, 360 263 __u16 port, int *socks, int socks_len, 264 + int *established_socks, int established_socks_len, 361 265 struct sock_count *counts, int counts_len, 362 266 struct bpf_link *link, int iter_fd) 363 267 { ··· 430 242 ASSERT_EQ(read_n(iter_fd, -1, counts, counts_len), 0, "read_n"); 431 243 } 432 244 245 + static void remove_all_established(int family, int sock_type, const char *addr, 246 + __u16 port, int *listen_socks, 247 + int listen_socks_len, int *established_socks, 248 + int established_socks_len, 249 + struct sock_count *counts, int counts_len, 250 + struct bpf_link *link, int iter_fd) 251 + { 252 + int *close_idx = NULL; 253 + int i; 254 + 255 + /* Iterate through all listening sockets. */ 256 + read_n(iter_fd, listen_socks_len, counts, counts_len); 257 + 258 + /* Make sure we saw all listening sockets exactly once. */ 259 + check_n_were_seen_once(listen_socks, listen_socks_len, listen_socks_len, 260 + counts, counts_len); 261 + 262 + /* Iterate through the first established socket. */ 263 + read_n(iter_fd, 1, counts, counts_len); 264 + 265 + /* Make sure we saw one established socks. */ 266 + check_n_were_seen_once(established_socks, established_socks_len, 1, 267 + counts, counts_len); 268 + 269 + /* Close all remaining sockets to exhaust the list of saved cookies and 270 + * exit without putting any sockets into the batch on the next read. 271 + */ 272 + close_idx = malloc(sizeof(int) * (established_socks_len - 1)); 273 + if (!ASSERT_OK_PTR(close_idx, "close_idx malloc")) 274 + return; 275 + for (i = 0; i < established_socks_len - 1; i++) { 276 + close_idx[i] = get_nth_socket(established_socks, 277 + established_socks_len, link, 278 + listen_socks_len + i); 279 + if (!ASSERT_GE(close_idx[i], 0, "close_idx")) 280 + return; 281 + } 282 + 283 + for (i = 0; i < established_socks_len - 1; i++) { 284 + destroy(established_socks[close_idx[i]]); 285 + established_socks[close_idx[i]] = -1; 286 + } 287 + 288 + /* Make sure there are no more sockets returned */ 289 + ASSERT_EQ(read_n(iter_fd, -1, counts, counts_len), 0, "read_n"); 290 + free(close_idx); 291 + } 292 + 433 293 static void add_some(int family, int sock_type, const char *addr, __u16 port, 434 - int *socks, int socks_len, struct sock_count *counts, 294 + int *socks, int socks_len, int *established_socks, 295 + int established_socks_len, struct sock_count *counts, 435 296 int counts_len, struct bpf_link *link, int iter_fd) 436 297 { 437 298 int *new_socks = NULL; ··· 508 271 free_fds(new_socks, socks_len); 509 272 } 510 273 274 + static void add_some_established(int family, int sock_type, const char *addr, 275 + __u16 port, int *listen_socks, 276 + int listen_socks_len, int *established_socks, 277 + int established_socks_len, 278 + struct sock_count *counts, 279 + int counts_len, struct bpf_link *link, 280 + int iter_fd) 281 + { 282 + int *new_socks = NULL; 283 + 284 + /* Iterate through all listening sockets. */ 285 + read_n(iter_fd, listen_socks_len, counts, counts_len); 286 + 287 + /* Make sure we saw all listening sockets exactly once. */ 288 + check_n_were_seen_once(listen_socks, listen_socks_len, listen_socks_len, 289 + counts, counts_len); 290 + 291 + /* Iterate through the first established_socks_len - 1 sockets. */ 292 + read_n(iter_fd, established_socks_len - 1, counts, counts_len); 293 + 294 + /* Make sure we saw established_socks_len - 1 sockets exactly once. */ 295 + check_n_were_seen_once(established_socks, established_socks_len, 296 + established_socks_len - 1, counts, counts_len); 297 + 298 + /* Double the number of established sockets in the bucket. */ 299 + new_socks = connect_to_server(family, sock_type, addr, port, 300 + established_socks_len / 2, listen_socks, 301 + listen_socks_len); 302 + if (!ASSERT_OK_PTR(new_socks, "connect_to_server")) 303 + goto done; 304 + 305 + /* Iterate through the rest of the sockets. */ 306 + read_n(iter_fd, -1, counts, counts_len); 307 + 308 + /* Make sure each of the original sockets was seen exactly once. */ 309 + check_n_were_seen_once(listen_socks, listen_socks_len, listen_socks_len, 310 + counts, counts_len); 311 + check_n_were_seen_once(established_socks, established_socks_len, 312 + established_socks_len, counts, counts_len); 313 + done: 314 + free_fds(new_socks, established_socks_len); 315 + } 316 + 511 317 static void force_realloc(int family, int sock_type, const char *addr, 512 318 __u16 port, int *socks, int socks_len, 319 + int *established_socks, int established_socks_len, 513 320 struct sock_count *counts, int counts_len, 514 321 struct bpf_link *link, int iter_fd) 515 322 { ··· 580 299 free_fds(new_socks, socks_len); 581 300 } 582 301 302 + static void force_realloc_established(int family, int sock_type, 303 + const char *addr, __u16 port, 304 + int *listen_socks, int listen_socks_len, 305 + int *established_socks, 306 + int established_socks_len, 307 + struct sock_count *counts, int counts_len, 308 + struct bpf_link *link, int iter_fd) 309 + { 310 + /* Iterate through all sockets to trigger a realloc. */ 311 + read_n(iter_fd, -1, counts, counts_len); 312 + 313 + /* Make sure each socket was seen exactly once. */ 314 + check_n_were_seen_once(listen_socks, listen_socks_len, listen_socks_len, 315 + counts, counts_len); 316 + check_n_were_seen_once(established_socks, established_socks_len, 317 + established_socks_len, counts, counts_len); 318 + } 319 + 583 320 struct test_case { 584 321 void (*test)(int family, int sock_type, const char *addr, __u16 port, 585 - int *socks, int socks_len, struct sock_count *counts, 322 + int *socks, int socks_len, int *established_socks, 323 + int established_socks_len, struct sock_count *counts, 586 324 int counts_len, struct bpf_link *link, int iter_fd); 587 325 const char *description; 326 + int ehash_buckets; 327 + int connections; 588 328 int init_socks; 589 329 int max_socks; 590 330 int sock_type; ··· 660 358 .family = AF_INET6, 661 359 .test = force_realloc, 662 360 }, 361 + { 362 + .description = "tcp: resume after removing a seen socket (listening)", 363 + .init_socks = nr_soreuse, 364 + .max_socks = nr_soreuse, 365 + .sock_type = SOCK_STREAM, 366 + .family = AF_INET6, 367 + .test = remove_seen, 368 + }, 369 + { 370 + .description = "tcp: resume after removing one unseen socket (listening)", 371 + .init_socks = nr_soreuse, 372 + .max_socks = nr_soreuse, 373 + .sock_type = SOCK_STREAM, 374 + .family = AF_INET6, 375 + .test = remove_unseen, 376 + }, 377 + { 378 + .description = "tcp: resume after removing all unseen sockets (listening)", 379 + .init_socks = nr_soreuse, 380 + .max_socks = nr_soreuse, 381 + .sock_type = SOCK_STREAM, 382 + .family = AF_INET6, 383 + .test = remove_all, 384 + }, 385 + { 386 + .description = "tcp: resume after adding a few sockets (listening)", 387 + .init_socks = nr_soreuse, 388 + .max_socks = nr_soreuse, 389 + .sock_type = SOCK_STREAM, 390 + /* Use AF_INET so that new sockets are added to the head of the 391 + * bucket's list. 392 + */ 393 + .family = AF_INET, 394 + .test = add_some, 395 + }, 396 + { 397 + .description = "tcp: force a realloc to occur (listening)", 398 + .init_socks = init_batch_size, 399 + .max_socks = init_batch_size * 2, 400 + .sock_type = SOCK_STREAM, 401 + /* Use AF_INET6 so that new sockets are added to the tail of the 402 + * bucket's list, needing to be added to the next batch to force 403 + * a realloc. 404 + */ 405 + .family = AF_INET6, 406 + .test = force_realloc, 407 + }, 408 + { 409 + .description = "tcp: resume after removing a seen socket (established)", 410 + /* Force all established sockets into one bucket */ 411 + .ehash_buckets = 1, 412 + .connections = nr_soreuse, 413 + .init_socks = nr_soreuse, 414 + /* Room for connect()ed and accept()ed sockets */ 415 + .max_socks = nr_soreuse * 3, 416 + .sock_type = SOCK_STREAM, 417 + .family = AF_INET6, 418 + .test = remove_seen_established, 419 + }, 420 + { 421 + .description = "tcp: resume after removing one unseen socket (established)", 422 + /* Force all established sockets into one bucket */ 423 + .ehash_buckets = 1, 424 + .connections = nr_soreuse, 425 + .init_socks = nr_soreuse, 426 + /* Room for connect()ed and accept()ed sockets */ 427 + .max_socks = nr_soreuse * 3, 428 + .sock_type = SOCK_STREAM, 429 + .family = AF_INET6, 430 + .test = remove_unseen_established, 431 + }, 432 + { 433 + .description = "tcp: resume after removing all unseen sockets (established)", 434 + /* Force all established sockets into one bucket */ 435 + .ehash_buckets = 1, 436 + .connections = nr_soreuse, 437 + .init_socks = nr_soreuse, 438 + /* Room for connect()ed and accept()ed sockets */ 439 + .max_socks = nr_soreuse * 3, 440 + .sock_type = SOCK_STREAM, 441 + .family = AF_INET6, 442 + .test = remove_all_established, 443 + }, 444 + { 445 + .description = "tcp: resume after adding a few sockets (established)", 446 + /* Force all established sockets into one bucket */ 447 + .ehash_buckets = 1, 448 + .connections = nr_soreuse, 449 + .init_socks = nr_soreuse, 450 + /* Room for connect()ed and accept()ed sockets */ 451 + .max_socks = nr_soreuse * 3, 452 + .sock_type = SOCK_STREAM, 453 + .family = AF_INET6, 454 + .test = add_some_established, 455 + }, 456 + { 457 + .description = "tcp: force a realloc to occur (established)", 458 + /* Force all established sockets into one bucket */ 459 + .ehash_buckets = 1, 460 + /* Bucket size will need to double when going from listening to 461 + * established sockets. 462 + */ 463 + .connections = init_batch_size, 464 + .init_socks = nr_soreuse, 465 + /* Room for connect()ed and accept()ed sockets */ 466 + .max_socks = nr_soreuse + (init_batch_size * 2), 467 + .sock_type = SOCK_STREAM, 468 + .family = AF_INET6, 469 + .test = force_realloc_established, 470 + }, 663 471 }; 664 472 665 473 static void do_resume_test(struct test_case *tc) 666 474 { 667 475 struct sock_iter_batch *skel = NULL; 476 + struct sock_count *counts = NULL; 668 477 static const __u16 port = 10001; 478 + struct nstoken *nstoken = NULL; 669 479 struct bpf_link *link = NULL; 670 - struct sock_count *counts; 480 + int *established_fds = NULL; 671 481 int err, iter_fd = -1; 672 482 const char *addr; 673 483 int *fds = NULL; 674 - int local_port; 484 + 485 + if (tc->ehash_buckets) { 486 + SYS_NOFAIL("ip netns del " TEST_CHILD_NS); 487 + SYS(done, "sysctl -wq net.ipv4.tcp_child_ehash_entries=%d", 488 + tc->ehash_buckets); 489 + SYS(done, "ip netns add %s", TEST_CHILD_NS); 490 + SYS(done, "ip -net %s link set dev lo up", TEST_CHILD_NS); 491 + nstoken = open_netns(TEST_CHILD_NS); 492 + if (!ASSERT_OK_PTR(nstoken, "open_child_netns")) 493 + goto done; 494 + } 675 495 676 496 counts = calloc(tc->max_socks, sizeof(*counts)); 677 497 if (!ASSERT_OK_PTR(counts, "counts")) ··· 808 384 tc->init_socks); 809 385 if (!ASSERT_OK_PTR(fds, "start_reuseport_server")) 810 386 goto done; 811 - local_port = get_socket_local_port(*fds); 812 - if (!ASSERT_GE(local_port, 0, "get_socket_local_port")) 813 - goto done; 814 - skel->rodata->ports[0] = ntohs(local_port); 387 + if (tc->connections) { 388 + established_fds = connect_to_server(tc->family, tc->sock_type, 389 + addr, port, 390 + tc->connections, fds, 391 + tc->init_socks); 392 + if (!ASSERT_OK_PTR(established_fds, "connect_to_server")) 393 + goto done; 394 + } 395 + skel->rodata->ports[0] = 0; 396 + skel->rodata->ports[1] = 0; 815 397 skel->rodata->sf = tc->family; 398 + skel->rodata->ss = 0; 816 399 817 400 err = sock_iter_batch__load(skel); 818 401 if (!ASSERT_OK(err, "sock_iter_batch__load")) ··· 837 406 goto done; 838 407 839 408 tc->test(tc->family, tc->sock_type, addr, port, fds, tc->init_socks, 840 - counts, tc->max_socks, link, iter_fd); 409 + established_fds, tc->connections*2, counts, tc->max_socks, 410 + link, iter_fd); 841 411 done: 412 + close_netns(nstoken); 413 + SYS_NOFAIL("ip netns del " TEST_CHILD_NS); 414 + SYS_NOFAIL("sysctl -w net.ipv4.tcp_child_ehash_entries=0"); 842 415 free(counts); 843 416 free_fds(fds, tc->init_socks); 417 + free_fds(established_fds, tc->connections*2); 844 418 if (iter_fd >= 0) 845 419 close(iter_fd); 846 420 bpf_link__destroy(link); ··· 890 454 skel->rodata->ports[i] = ntohs(local_port); 891 455 } 892 456 skel->rodata->sf = AF_INET6; 457 + if (sock_type == SOCK_STREAM) 458 + skel->rodata->ss = TCP_LISTEN; 893 459 894 460 err = sock_iter_batch__load(skel); 895 461 if (!ASSERT_OK(err, "sock_iter_batch__load"))
+31 -5
tools/testing/selftests/bpf/progs/sock_iter_batch.c
··· 23 23 } 24 24 25 25 volatile const unsigned int sf; 26 + volatile const unsigned int ss; 26 27 volatile const __u16 ports[2]; 27 28 unsigned int bucket[2]; 28 29 ··· 43 42 sock_cookie = bpf_get_socket_cookie(sk); 44 43 sk = bpf_core_cast(sk, struct sock); 45 44 if (sk->sk_family != sf || 46 - sk->sk_state != TCP_LISTEN || 47 - sk->sk_family == AF_INET6 ? 45 + (ss && sk->sk_state != ss) || 46 + (sk->sk_family == AF_INET6 ? 48 47 !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr) : 49 - !ipv4_addr_loopback(sk->sk_rcv_saddr)) 48 + !ipv4_addr_loopback(sk->sk_rcv_saddr))) 50 49 return 0; 51 50 52 51 if (sk->sk_num == ports[0]) 53 52 idx = 0; 54 53 else if (sk->sk_num == ports[1]) 55 54 idx = 1; 55 + else if (!ports[0] && !ports[1]) 56 + idx = 0; 56 57 else 57 58 return 0; 58 59 ··· 65 62 hinfo = net->ipv4.tcp_death_row.hashinfo; 66 63 bucket[idx] = hash & hinfo->lhash2_mask; 67 64 bpf_seq_write(ctx->meta->seq, &idx, sizeof(idx)); 65 + bpf_seq_write(ctx->meta->seq, &sock_cookie, sizeof(sock_cookie)); 66 + 67 + return 0; 68 + } 69 + 70 + volatile const __u64 destroy_cookie; 71 + 72 + SEC("iter/tcp") 73 + int iter_tcp_destroy(struct bpf_iter__tcp *ctx) 74 + { 75 + struct sock_common *sk_common = (struct sock_common *)ctx->sk_common; 76 + __u64 sock_cookie; 77 + 78 + if (!sk_common) 79 + return 0; 80 + 81 + sock_cookie = bpf_get_socket_cookie(sk_common); 82 + if (sock_cookie != destroy_cookie) 83 + return 0; 84 + 85 + bpf_sock_destroy(sk_common); 68 86 bpf_seq_write(ctx->meta->seq, &sock_cookie, sizeof(sock_cookie)); 69 87 70 88 return 0; ··· 107 83 sock_cookie = bpf_get_socket_cookie(sk); 108 84 sk = bpf_core_cast(sk, struct sock); 109 85 if (sk->sk_family != sf || 110 - sk->sk_family == AF_INET6 ? 86 + (sk->sk_family == AF_INET6 ? 111 87 !ipv6_addr_loopback(&sk->sk_v6_rcv_saddr) : 112 - !ipv4_addr_loopback(sk->sk_rcv_saddr)) 88 + !ipv4_addr_loopback(sk->sk_rcv_saddr))) 113 89 return 0; 114 90 115 91 if (sk->sk_num == ports[0]) 116 92 idx = 0; 117 93 else if (sk->sk_num == ports[1]) 118 94 idx = 1; 95 + else if (!ports[0] && !ports[1]) 96 + idx = 0; 119 97 else 120 98 return 0; 121 99