at v5.7 4.0 kB view raw
1// SPDX-License-Identifier: GPL-2.0 2// Copyright (c) 2018 Facebook 3 4#include <string.h> 5#include <unistd.h> 6 7#include <arpa/inet.h> 8#include <netinet/in.h> 9#include <sys/types.h> 10#include <sys/socket.h> 11 12#include <bpf/bpf.h> 13#include <bpf/libbpf.h> 14 15#include "bpf_rlimit.h" 16#include "cgroup_helpers.h" 17 18#define CG_PATH "/foo" 19#define SOCKET_COOKIE_PROG "./socket_cookie_prog.o" 20 21struct socket_cookie { 22 __u64 cookie_key; 23 __u32 cookie_value; 24}; 25 26static int start_server(void) 27{ 28 struct sockaddr_in6 addr; 29 int fd; 30 31 fd = socket(AF_INET6, SOCK_STREAM, 0); 32 if (fd == -1) { 33 log_err("Failed to create server socket"); 34 goto out; 35 } 36 37 memset(&addr, 0, sizeof(addr)); 38 addr.sin6_family = AF_INET6; 39 addr.sin6_addr = in6addr_loopback; 40 addr.sin6_port = 0; 41 42 if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) == -1) { 43 log_err("Failed to bind server socket"); 44 goto close_out; 45 } 46 47 if (listen(fd, 128) == -1) { 48 log_err("Failed to listen on server socket"); 49 goto close_out; 50 } 51 52 goto out; 53 54close_out: 55 close(fd); 56 fd = -1; 57out: 58 return fd; 59} 60 61static int connect_to_server(int server_fd) 62{ 63 struct sockaddr_storage addr; 64 socklen_t len = sizeof(addr); 65 int fd; 66 67 fd = socket(AF_INET6, SOCK_STREAM, 0); 68 if (fd == -1) { 69 log_err("Failed to create client socket"); 70 goto out; 71 } 72 73 if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) { 74 log_err("Failed to get server addr"); 75 goto close_out; 76 } 77 78 if (connect(fd, (const struct sockaddr *)&addr, len) == -1) { 79 log_err("Fail to connect to server"); 80 goto close_out; 81 } 82 83 goto out; 84 85close_out: 86 close(fd); 87 fd = -1; 88out: 89 return fd; 90} 91 92static int validate_map(struct bpf_map *map, int client_fd) 93{ 94 __u32 cookie_expected_value; 95 struct sockaddr_in6 addr; 96 socklen_t len = sizeof(addr); 97 struct socket_cookie val; 98 int err = 0; 99 int map_fd; 100 101 if (!map) { 102 log_err("Map not found in BPF object"); 103 goto err; 104 } 105 106 map_fd = bpf_map__fd(map); 107 108 err = bpf_map_lookup_elem(map_fd, &client_fd, &val); 109 110 err = getsockname(client_fd, (struct sockaddr *)&addr, &len); 111 if (err) { 112 log_err("Can't get client local addr"); 113 goto out; 114 } 115 116 cookie_expected_value = (ntohs(addr.sin6_port) << 8) | 0xFF; 117 if (val.cookie_value != cookie_expected_value) { 118 log_err("Unexpected value in map: %x != %x", val.cookie_value, 119 cookie_expected_value); 120 goto err; 121 } 122 123 goto out; 124err: 125 err = -1; 126out: 127 return err; 128} 129 130static int run_test(int cgfd) 131{ 132 enum bpf_attach_type attach_type; 133 struct bpf_prog_load_attr attr; 134 struct bpf_program *prog; 135 struct bpf_object *pobj; 136 const char *prog_name; 137 int server_fd = -1; 138 int client_fd = -1; 139 int prog_fd = -1; 140 int err = 0; 141 142 memset(&attr, 0, sizeof(attr)); 143 attr.file = SOCKET_COOKIE_PROG; 144 attr.prog_type = BPF_PROG_TYPE_UNSPEC; 145 attr.prog_flags = BPF_F_TEST_RND_HI32; 146 147 err = bpf_prog_load_xattr(&attr, &pobj, &prog_fd); 148 if (err) { 149 log_err("Failed to load %s", attr.file); 150 goto out; 151 } 152 153 bpf_object__for_each_program(prog, pobj) { 154 prog_name = bpf_program__title(prog, /*needs_copy*/ false); 155 156 if (libbpf_attach_type_by_name(prog_name, &attach_type)) 157 goto err; 158 159 err = bpf_prog_attach(bpf_program__fd(prog), cgfd, attach_type, 160 BPF_F_ALLOW_OVERRIDE); 161 if (err) { 162 log_err("Failed to attach prog %s", prog_name); 163 goto out; 164 } 165 } 166 167 server_fd = start_server(); 168 if (server_fd == -1) 169 goto err; 170 171 client_fd = connect_to_server(server_fd); 172 if (client_fd == -1) 173 goto err; 174 175 if (validate_map(bpf_map__next(NULL, pobj), client_fd)) 176 goto err; 177 178 goto out; 179err: 180 err = -1; 181out: 182 close(client_fd); 183 close(server_fd); 184 bpf_object__close(pobj); 185 printf("%s\n", err ? "FAILED" : "PASSED"); 186 return err; 187} 188 189int main(int argc, char **argv) 190{ 191 int cgfd = -1; 192 int err = 0; 193 194 if (setup_cgroup_environment()) 195 goto err; 196 197 cgfd = create_and_get_cgroup(CG_PATH); 198 if (cgfd < 0) 199 goto err; 200 201 if (join_cgroup(CG_PATH)) 202 goto err; 203 204 if (run_test(cgfd)) 205 goto err; 206 207 goto out; 208err: 209 err = -1; 210out: 211 close(cgfd); 212 cleanup_cgroup_environment(); 213 return err; 214}