at v5.0 221 lines 4.1 kB view raw
1// SPDX-License-Identifier: GPL-2.0 2// Copyright (c) 2018 Facebook 3 4#include <string.h> 5#include <unistd.h> 6 7#include <arpa/inet.h> 8#include <netinet/in.h> 9#include <sys/types.h> 10#include <sys/socket.h> 11 12#include <bpf/bpf.h> 13#include <bpf/libbpf.h> 14 15#include "bpf_rlimit.h" 16#include "cgroup_helpers.h" 17 18#define CG_PATH "/foo" 19#define SOCKET_COOKIE_PROG "./socket_cookie_prog.o" 20 21static int start_server(void) 22{ 23 struct sockaddr_in6 addr; 24 int fd; 25 26 fd = socket(AF_INET6, SOCK_STREAM, 0); 27 if (fd == -1) { 28 log_err("Failed to create server socket"); 29 goto out; 30 } 31 32 memset(&addr, 0, sizeof(addr)); 33 addr.sin6_family = AF_INET6; 34 addr.sin6_addr = in6addr_loopback; 35 addr.sin6_port = 0; 36 37 if (bind(fd, (const struct sockaddr *)&addr, sizeof(addr)) == -1) { 38 log_err("Failed to bind server socket"); 39 goto close_out; 40 } 41 42 if (listen(fd, 128) == -1) { 43 log_err("Failed to listen on server socket"); 44 goto close_out; 45 } 46 47 goto out; 48 49close_out: 50 close(fd); 51 fd = -1; 52out: 53 return fd; 54} 55 56static int connect_to_server(int server_fd) 57{ 58 struct sockaddr_storage addr; 59 socklen_t len = sizeof(addr); 60 int fd; 61 62 fd = socket(AF_INET6, SOCK_STREAM, 0); 63 if (fd == -1) { 64 log_err("Failed to create client socket"); 65 goto out; 66 } 67 68 if (getsockname(server_fd, (struct sockaddr *)&addr, &len)) { 69 log_err("Failed to get server addr"); 70 goto close_out; 71 } 72 73 if (connect(fd, (const struct sockaddr *)&addr, len) == -1) { 74 log_err("Fail to connect to server"); 75 goto close_out; 76 } 77 78 goto out; 79 80close_out: 81 close(fd); 82 fd = -1; 83out: 84 return fd; 85} 86 87static int validate_map(struct bpf_map *map, int client_fd) 88{ 89 __u32 cookie_expected_value; 90 struct sockaddr_in6 addr; 91 socklen_t len = sizeof(addr); 92 __u32 cookie_value; 93 __u64 cookie_key; 94 int err = 0; 95 int map_fd; 96 97 if (!map) { 98 log_err("Map not found in BPF object"); 99 goto err; 100 } 101 102 map_fd = bpf_map__fd(map); 103 104 err = bpf_map_get_next_key(map_fd, NULL, &cookie_key); 105 if (err) { 106 log_err("Can't get cookie key from map"); 107 goto out; 108 } 109 110 err = bpf_map_lookup_elem(map_fd, &cookie_key, &cookie_value); 111 if (err) { 112 log_err("Can't get cookie value from map"); 113 goto out; 114 } 115 116 err = getsockname(client_fd, (struct sockaddr *)&addr, &len); 117 if (err) { 118 log_err("Can't get client local addr"); 119 goto out; 120 } 121 122 cookie_expected_value = (ntohs(addr.sin6_port) << 8) | 0xFF; 123 if (cookie_value != cookie_expected_value) { 124 log_err("Unexpected value in map: %x != %x", cookie_value, 125 cookie_expected_value); 126 goto err; 127 } 128 129 goto out; 130err: 131 err = -1; 132out: 133 return err; 134} 135 136static int run_test(int cgfd) 137{ 138 enum bpf_attach_type attach_type; 139 struct bpf_prog_load_attr attr; 140 struct bpf_program *prog; 141 struct bpf_object *pobj; 142 const char *prog_name; 143 int server_fd = -1; 144 int client_fd = -1; 145 int prog_fd = -1; 146 int err = 0; 147 148 memset(&attr, 0, sizeof(attr)); 149 attr.file = SOCKET_COOKIE_PROG; 150 attr.prog_type = BPF_PROG_TYPE_UNSPEC; 151 152 err = bpf_prog_load_xattr(&attr, &pobj, &prog_fd); 153 if (err) { 154 log_err("Failed to load %s", attr.file); 155 goto out; 156 } 157 158 bpf_object__for_each_program(prog, pobj) { 159 prog_name = bpf_program__title(prog, /*needs_copy*/ false); 160 161 if (libbpf_attach_type_by_name(prog_name, &attach_type)) { 162 log_err("Unexpected prog: %s", prog_name); 163 goto err; 164 } 165 166 err = bpf_prog_attach(bpf_program__fd(prog), cgfd, attach_type, 167 BPF_F_ALLOW_OVERRIDE); 168 if (err) { 169 log_err("Failed to attach prog %s", prog_name); 170 goto out; 171 } 172 } 173 174 server_fd = start_server(); 175 if (server_fd == -1) 176 goto err; 177 178 client_fd = connect_to_server(server_fd); 179 if (client_fd == -1) 180 goto err; 181 182 if (validate_map(bpf_map__next(NULL, pobj), client_fd)) 183 goto err; 184 185 goto out; 186err: 187 err = -1; 188out: 189 close(client_fd); 190 close(server_fd); 191 bpf_object__close(pobj); 192 printf("%s\n", err ? "FAILED" : "PASSED"); 193 return err; 194} 195 196int main(int argc, char **argv) 197{ 198 int cgfd = -1; 199 int err = 0; 200 201 if (setup_cgroup_environment()) 202 goto err; 203 204 cgfd = create_and_get_cgroup(CG_PATH); 205 if (cgfd < 0) 206 goto err; 207 208 if (join_cgroup(CG_PATH)) 209 goto err; 210 211 if (run_test(cgfd)) 212 goto err; 213 214 goto out; 215err: 216 err = -1; 217out: 218 close(cgfd); 219 cleanup_cgroup_environment(); 220 return err; 221}