Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

selftests/bpf: Test bpf_sock_destroy

The test cases for destroying sockets mirror the intended usages of the
bpf_sock_destroy kfunc using iterators.

The destroy helpers set `ECONNABORTED` error code that we can validate
in the test code with client sockets. But UDP sockets have an overriding
error code from `disconnect()` called during abort, so the error code
validation is only done for TCP sockets.

The failure test cases validate that the `bpf_sock_destroy` kfunc is not
allowed from program attach types other than BPF trace iterator, and
such programs fail to load.

Signed-off-by: Aditi Ghag <aditi.ghag@isovalent.com>
Link: https://lore.kernel.org/r/20230519225157.760788-10-aditi.ghag@isovalent.com
Signed-off-by: Martin KaFai Lau <martin.lau@kernel.org>

authored by

Aditi Ghag and committed by
Martin KaFai Lau
1a8bc229 176ba657

+388
+221
tools/testing/selftests/bpf/prog_tests/sock_destroy.c
··· 1 + // SPDX-License-Identifier: GPL-2.0 2 + #include <test_progs.h> 3 + #include <bpf/bpf_endian.h> 4 + 5 + #include "sock_destroy_prog.skel.h" 6 + #include "sock_destroy_prog_fail.skel.h" 7 + #include "network_helpers.h" 8 + 9 + #define TEST_NS "sock_destroy_netns" 10 + 11 + static void start_iter_sockets(struct bpf_program *prog) 12 + { 13 + struct bpf_link *link; 14 + char buf[50] = {}; 15 + int iter_fd, len; 16 + 17 + link = bpf_program__attach_iter(prog, NULL); 18 + if (!ASSERT_OK_PTR(link, "attach_iter")) 19 + return; 20 + 21 + iter_fd = bpf_iter_create(bpf_link__fd(link)); 22 + if (!ASSERT_GE(iter_fd, 0, "create_iter")) 23 + goto free_link; 24 + 25 + while ((len = read(iter_fd, buf, sizeof(buf))) > 0) 26 + ; 27 + ASSERT_GE(len, 0, "read"); 28 + 29 + close(iter_fd); 30 + 31 + free_link: 32 + bpf_link__destroy(link); 33 + } 34 + 35 + static void test_tcp_client(struct sock_destroy_prog *skel) 36 + { 37 + int serv = -1, clien = -1, accept_serv = -1, n; 38 + 39 + serv = start_server(AF_INET6, SOCK_STREAM, NULL, 0, 0); 40 + if (!ASSERT_GE(serv, 0, "start_server")) 41 + goto cleanup; 42 + 43 + clien = connect_to_fd(serv, 0); 44 + if (!ASSERT_GE(clien, 0, "connect_to_fd")) 45 + goto cleanup; 46 + 47 + accept_serv = accept(serv, NULL, NULL); 48 + if (!ASSERT_GE(accept_serv, 0, "serv accept")) 49 + goto cleanup; 50 + 51 + n = send(clien, "t", 1, 0); 52 + if (!ASSERT_EQ(n, 1, "client send")) 53 + goto cleanup; 54 + 55 + /* Run iterator program that destroys connected client sockets. */ 56 + start_iter_sockets(skel->progs.iter_tcp6_client); 57 + 58 + n = send(clien, "t", 1, 0); 59 + if (!ASSERT_LT(n, 0, "client_send on destroyed socket")) 60 + goto cleanup; 61 + ASSERT_EQ(errno, ECONNABORTED, "error code on destroyed socket"); 62 + 63 + cleanup: 64 + if (clien != -1) 65 + close(clien); 66 + if (accept_serv != -1) 67 + close(accept_serv); 68 + if (serv != -1) 69 + close(serv); 70 + } 71 + 72 + static void test_tcp_server(struct sock_destroy_prog *skel) 73 + { 74 + int serv = -1, clien = -1, accept_serv = -1, n, serv_port; 75 + 76 + serv = start_server(AF_INET6, SOCK_STREAM, NULL, 0, 0); 77 + if (!ASSERT_GE(serv, 0, "start_server")) 78 + goto cleanup; 79 + serv_port = get_socket_local_port(serv); 80 + if (!ASSERT_GE(serv_port, 0, "get_sock_local_port")) 81 + goto cleanup; 82 + skel->bss->serv_port = (__be16) serv_port; 83 + 84 + clien = connect_to_fd(serv, 0); 85 + if (!ASSERT_GE(clien, 0, "connect_to_fd")) 86 + goto cleanup; 87 + 88 + accept_serv = accept(serv, NULL, NULL); 89 + if (!ASSERT_GE(accept_serv, 0, "serv accept")) 90 + goto cleanup; 91 + 92 + n = send(clien, "t", 1, 0); 93 + if (!ASSERT_EQ(n, 1, "client send")) 94 + goto cleanup; 95 + 96 + /* Run iterator program that destroys server sockets. */ 97 + start_iter_sockets(skel->progs.iter_tcp6_server); 98 + 99 + n = send(clien, "t", 1, 0); 100 + if (!ASSERT_LT(n, 0, "client_send on destroyed socket")) 101 + goto cleanup; 102 + ASSERT_EQ(errno, ECONNRESET, "error code on destroyed socket"); 103 + 104 + cleanup: 105 + if (clien != -1) 106 + close(clien); 107 + if (accept_serv != -1) 108 + close(accept_serv); 109 + if (serv != -1) 110 + close(serv); 111 + } 112 + 113 + static void test_udp_client(struct sock_destroy_prog *skel) 114 + { 115 + int serv = -1, clien = -1, n = 0; 116 + 117 + serv = start_server(AF_INET6, SOCK_DGRAM, NULL, 0, 0); 118 + if (!ASSERT_GE(serv, 0, "start_server")) 119 + goto cleanup; 120 + 121 + clien = connect_to_fd(serv, 0); 122 + if (!ASSERT_GE(clien, 0, "connect_to_fd")) 123 + goto cleanup; 124 + 125 + n = send(clien, "t", 1, 0); 126 + if (!ASSERT_EQ(n, 1, "client send")) 127 + goto cleanup; 128 + 129 + /* Run iterator program that destroys sockets. */ 130 + start_iter_sockets(skel->progs.iter_udp6_client); 131 + 132 + n = send(clien, "t", 1, 0); 133 + if (!ASSERT_LT(n, 0, "client_send on destroyed socket")) 134 + goto cleanup; 135 + /* UDP sockets have an overriding error code after they are disconnected, 136 + * so we don't check for ECONNABORTED error code. 137 + */ 138 + 139 + cleanup: 140 + if (clien != -1) 141 + close(clien); 142 + if (serv != -1) 143 + close(serv); 144 + } 145 + 146 + static void test_udp_server(struct sock_destroy_prog *skel) 147 + { 148 + int *listen_fds = NULL, n, i, serv_port; 149 + unsigned int num_listens = 5; 150 + char buf[1]; 151 + 152 + /* Start reuseport servers. */ 153 + listen_fds = start_reuseport_server(AF_INET6, SOCK_DGRAM, 154 + "::1", 0, 0, num_listens); 155 + if (!ASSERT_OK_PTR(listen_fds, "start_reuseport_server")) 156 + goto cleanup; 157 + serv_port = get_socket_local_port(listen_fds[0]); 158 + if (!ASSERT_GE(serv_port, 0, "get_sock_local_port")) 159 + goto cleanup; 160 + skel->bss->serv_port = (__be16) serv_port; 161 + 162 + /* Run iterator program that destroys server sockets. */ 163 + start_iter_sockets(skel->progs.iter_udp6_server); 164 + 165 + for (i = 0; i < num_listens; ++i) { 166 + n = read(listen_fds[i], buf, sizeof(buf)); 167 + if (!ASSERT_EQ(n, -1, "read") || 168 + !ASSERT_EQ(errno, ECONNABORTED, "error code on destroyed socket")) 169 + break; 170 + } 171 + ASSERT_EQ(i, num_listens, "server socket"); 172 + 173 + cleanup: 174 + free_fds(listen_fds, num_listens); 175 + } 176 + 177 + void test_sock_destroy(void) 178 + { 179 + struct sock_destroy_prog *skel; 180 + struct nstoken *nstoken = NULL; 181 + int cgroup_fd; 182 + 183 + skel = sock_destroy_prog__open_and_load(); 184 + if (!ASSERT_OK_PTR(skel, "skel_open")) 185 + return; 186 + 187 + cgroup_fd = test__join_cgroup("/sock_destroy"); 188 + if (!ASSERT_GE(cgroup_fd, 0, "join_cgroup")) 189 + goto cleanup; 190 + 191 + skel->links.sock_connect = bpf_program__attach_cgroup( 192 + skel->progs.sock_connect, cgroup_fd); 193 + if (!ASSERT_OK_PTR(skel->links.sock_connect, "prog_attach")) 194 + goto cleanup; 195 + 196 + SYS(cleanup, "ip netns add %s", TEST_NS); 197 + SYS(cleanup, "ip -net %s link set dev lo up", TEST_NS); 198 + 199 + nstoken = open_netns(TEST_NS); 200 + if (!ASSERT_OK_PTR(nstoken, "open_netns")) 201 + goto cleanup; 202 + 203 + if (test__start_subtest("tcp_client")) 204 + test_tcp_client(skel); 205 + if (test__start_subtest("tcp_server")) 206 + test_tcp_server(skel); 207 + if (test__start_subtest("udp_client")) 208 + test_udp_client(skel); 209 + if (test__start_subtest("udp_server")) 210 + test_udp_server(skel); 211 + 212 + RUN_TESTS(sock_destroy_prog_fail); 213 + 214 + cleanup: 215 + if (nstoken) 216 + close_netns(nstoken); 217 + SYS_NOFAIL("ip netns del " TEST_NS " &> /dev/null"); 218 + if (cgroup_fd >= 0) 219 + close(cgroup_fd); 220 + sock_destroy_prog__destroy(skel); 221 + }
+145
tools/testing/selftests/bpf/progs/sock_destroy_prog.c
··· 1 + // SPDX-License-Identifier: GPL-2.0 2 + 3 + #include "vmlinux.h" 4 + #include <bpf/bpf_helpers.h> 5 + #include <bpf/bpf_endian.h> 6 + 7 + #include "bpf_tracing_net.h" 8 + 9 + __be16 serv_port = 0; 10 + 11 + int bpf_sock_destroy(struct sock_common *sk) __ksym; 12 + 13 + struct { 14 + __uint(type, BPF_MAP_TYPE_ARRAY); 15 + __uint(max_entries, 1); 16 + __type(key, __u32); 17 + __type(value, __u64); 18 + } tcp_conn_sockets SEC(".maps"); 19 + 20 + struct { 21 + __uint(type, BPF_MAP_TYPE_ARRAY); 22 + __uint(max_entries, 1); 23 + __type(key, __u32); 24 + __type(value, __u64); 25 + } udp_conn_sockets SEC(".maps"); 26 + 27 + SEC("cgroup/connect6") 28 + int sock_connect(struct bpf_sock_addr *ctx) 29 + { 30 + __u64 sock_cookie = 0; 31 + int key = 0; 32 + __u32 keyc = 0; 33 + 34 + if (ctx->family != AF_INET6 || ctx->user_family != AF_INET6) 35 + return 1; 36 + 37 + sock_cookie = bpf_get_socket_cookie(ctx); 38 + if (ctx->protocol == IPPROTO_TCP) 39 + bpf_map_update_elem(&tcp_conn_sockets, &key, &sock_cookie, 0); 40 + else if (ctx->protocol == IPPROTO_UDP) 41 + bpf_map_update_elem(&udp_conn_sockets, &keyc, &sock_cookie, 0); 42 + else 43 + return 1; 44 + 45 + return 1; 46 + } 47 + 48 + SEC("iter/tcp") 49 + int iter_tcp6_client(struct bpf_iter__tcp *ctx) 50 + { 51 + struct sock_common *sk_common = ctx->sk_common; 52 + __u64 sock_cookie = 0; 53 + __u64 *val; 54 + int key = 0; 55 + 56 + if (!sk_common) 57 + return 0; 58 + 59 + if (sk_common->skc_family != AF_INET6) 60 + return 0; 61 + 62 + sock_cookie = bpf_get_socket_cookie(sk_common); 63 + val = bpf_map_lookup_elem(&tcp_conn_sockets, &key); 64 + if (!val) 65 + return 0; 66 + /* Destroy connected client sockets. */ 67 + if (sock_cookie == *val) 68 + bpf_sock_destroy(sk_common); 69 + 70 + return 0; 71 + } 72 + 73 + SEC("iter/tcp") 74 + int iter_tcp6_server(struct bpf_iter__tcp *ctx) 75 + { 76 + struct sock_common *sk_common = ctx->sk_common; 77 + const struct inet_connection_sock *icsk; 78 + const struct inet_sock *inet; 79 + struct tcp6_sock *tcp_sk; 80 + __be16 srcp; 81 + 82 + if (!sk_common) 83 + return 0; 84 + 85 + if (sk_common->skc_family != AF_INET6) 86 + return 0; 87 + 88 + tcp_sk = bpf_skc_to_tcp6_sock(sk_common); 89 + if (!tcp_sk) 90 + return 0; 91 + 92 + icsk = &tcp_sk->tcp.inet_conn; 93 + inet = &icsk->icsk_inet; 94 + srcp = inet->inet_sport; 95 + 96 + /* Destroy server sockets. */ 97 + if (srcp == serv_port) 98 + bpf_sock_destroy(sk_common); 99 + 100 + return 0; 101 + } 102 + 103 + 104 + SEC("iter/udp") 105 + int iter_udp6_client(struct bpf_iter__udp *ctx) 106 + { 107 + struct udp_sock *udp_sk = ctx->udp_sk; 108 + struct sock *sk = (struct sock *) udp_sk; 109 + __u64 sock_cookie = 0, *val; 110 + int key = 0; 111 + 112 + if (!sk) 113 + return 0; 114 + 115 + sock_cookie = bpf_get_socket_cookie(sk); 116 + val = bpf_map_lookup_elem(&udp_conn_sockets, &key); 117 + if (!val) 118 + return 0; 119 + /* Destroy connected client sockets. */ 120 + if (sock_cookie == *val) 121 + bpf_sock_destroy((struct sock_common *)sk); 122 + 123 + return 0; 124 + } 125 + 126 + SEC("iter/udp") 127 + int iter_udp6_server(struct bpf_iter__udp *ctx) 128 + { 129 + struct udp_sock *udp_sk = ctx->udp_sk; 130 + struct sock *sk = (struct sock *) udp_sk; 131 + struct inet_sock *inet; 132 + __be16 srcp; 133 + 134 + if (!sk) 135 + return 0; 136 + 137 + inet = &udp_sk->inet; 138 + srcp = inet->inet_sport; 139 + if (srcp == serv_port) 140 + bpf_sock_destroy((struct sock_common *)sk); 141 + 142 + return 0; 143 + } 144 + 145 + char _license[] SEC("license") = "GPL";
+22
tools/testing/selftests/bpf/progs/sock_destroy_prog_fail.c
··· 1 + // SPDX-License-Identifier: GPL-2.0 2 + 3 + #include "vmlinux.h" 4 + #include <bpf/bpf_tracing.h> 5 + #include <bpf/bpf_helpers.h> 6 + 7 + #include "bpf_misc.h" 8 + 9 + char _license[] SEC("license") = "GPL"; 10 + 11 + int bpf_sock_destroy(struct sock_common *sk) __ksym; 12 + 13 + SEC("tp_btf/tcp_destroy_sock") 14 + __failure __msg("calling kernel function bpf_sock_destroy is not allowed") 15 + int BPF_PROG(trace_tcp_destroy_sock, struct sock *sk) 16 + { 17 + /* should not load */ 18 + bpf_sock_destroy((struct sock_common *)sk); 19 + 20 + return 0; 21 + } 22 +