Linux kernel mirror (for testing)
git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel
os
linux
1// SPDX-License-Identifier: GPL-2.0
2
3#define _GNU_SOURCE
4
5#include <assert.h>
6#include <errno.h>
7#include <fcntl.h>
8#include <limits.h>
9#include <string.h>
10#include <stdarg.h>
11#include <stdbool.h>
12#include <stdint.h>
13#include <inttypes.h>
14#include <stdio.h>
15#include <stdlib.h>
16#include <strings.h>
17#include <unistd.h>
18#include <time.h>
19
20#include <sys/ioctl.h>
21#include <sys/random.h>
22#include <sys/socket.h>
23#include <sys/types.h>
24#include <sys/wait.h>
25
26#include <netdb.h>
27#include <netinet/in.h>
28
29#include <linux/tcp.h>
30#include <linux/sockios.h>
31
32#ifndef IPPROTO_MPTCP
33#define IPPROTO_MPTCP 262
34#endif
35#ifndef SOL_MPTCP
36#define SOL_MPTCP 284
37#endif
38
/* Address family for both endpoints; parse_opts() switches it to AF_INET6 for -6. */
static int pf = AF_INET;
/* Protocol the client passes to socket(); selected with -t. */
static int proto_tx = IPPROTO_MPTCP;
/* Protocol used for the listening socket; selected with -r. */
static int proto_rx = IPPROTO_MPTCP;
42
/* Print @msg followed by the current errno description, then exit(1). */
static void die_perror(const char *msg)
{
	perror(msg);

	exit(1);
}
48
/* Write the command-line synopsis to stderr and exit with status @r. */
static void die_usage(int r)
{
	fputs("Usage: mptcp_inq [-6] [ -t tcp|mptcp ] [ -r tcp|mptcp]\n", stderr);
	exit(r);
}
54
/*
 * printf-style fatal error: format the message to stderr, append a
 * newline and exit with status 1. Never returns.
 */
static void xerror(const char *fmt, ...)
{
	va_list args;

	va_start(args, fmt);
	vfprintf(stderr, fmt, args);
	va_end(args);

	fputs("\n", stderr);
	exit(1);
}
65
/*
 * Map a getaddrinfo() error code to text: EAI_SYSTEM is described via
 * the current errno, every other code via gai_strerror().
 */
static const char *getxinfo_strerr(int err)
{
	return err == EAI_SYSTEM ? strerror(errno) : gai_strerror(err);
}
73
/*
 * getaddrinfo() wrapper that only returns on success.
 *
 * @node, @service, @hints and @res are handed to getaddrinfo() unchanged.
 * When the lookup fails with EAI_SOCKTYPE and the requested protocol is
 * not already IPPROTO_TCP, hints->ai_protocol is overwritten with
 * IPPROTO_TCP and the lookup is repeated once. Any remaining failure is
 * reported on stderr and the process exits with status 1.
 *
 * NOTE(review): the TCP retry presumably covers resolvers that do not
 * know IPPROTO_MPTCP - confirm against the targeted libc versions.
 */
static void xgetaddrinfo(const char *node, const char *service,
			 struct addrinfo *hints,
			 struct addrinfo **res)
{
	int err = getaddrinfo(node, service, hints, res);

	/*
	 * Retry only while the protocol can still change; a resolver that
	 * keeps answering EAI_SOCKTYPE must not make us spin forever.
	 */
	if (err == EAI_SOCKTYPE && hints->ai_protocol != IPPROTO_TCP) {
		hints->ai_protocol = IPPROTO_TCP;
		err = getaddrinfo(node, service, hints, res);
	}

	if (err) {
		const char *errstr = getxinfo_strerr(err);

		fprintf(stderr, "Fatal: getaddrinfo(%s:%s): %s\n",
			node ? node : "", service ? service : "", errstr);
		exit(1);
	}
}
96
/*
 * Create a listening stream socket bound to @listenaddr:@port.
 *
 * The address is resolved with AI_PASSIVE | AI_NUMERICHOST for address
 * family pf; each result is tried in turn with a socket created for
 * protocol proto_rx until bind() succeeds. Exits the process if no
 * address could be bound or if listen() fails.
 *
 * Returns the listening socket descriptor.
 */
static int sock_listen_mptcp(const char * const listenaddr,
			     const char * const port)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_MPTCP,
		.ai_socktype = SOCK_STREAM,
		.ai_flags = AI_PASSIVE | AI_NUMERICHOST
	};
	struct addrinfo *a, *addr;
	int sock = -1;
	int one = 1;

	xgetaddrinfo(listenaddr, port, &hints, &addr);

	for (a = addr; a; a = a->ai_next) {
		sock = socket(a->ai_family, a->ai_socktype, proto_rx);
		if (sock < 0)
			continue;

		/* best effort: a failure here is reported but not fatal */
		if (setsockopt(sock, SOL_SOCKET, SO_REUSEADDR, &one,
			       sizeof(one)) == -1)
			perror("setsockopt");

		if (bind(sock, a->ai_addr, a->ai_addrlen) == 0)
			break; /* success */

		perror("bind");
		close(sock);
		sock = -1;
	}

	freeaddrinfo(addr);

	if (sock < 0)
		xerror("could not create listen socket");

	if (listen(sock, 20))
		die_perror("listen");

	return sock;
}
142
/*
 * Resolve @remoteaddr:@port for address family pf and connect a stream
 * socket created with protocol @proto.
 *
 * Results for which socket() fails are skipped; a failing connect() is
 * fatal. Exits if no socket could be created at all.
 *
 * Returns the connected socket descriptor.
 */
static int sock_connect_mptcp(const char * const remoteaddr,
			      const char * const port, int proto)
{
	struct addrinfo hints = {
		.ai_family = pf,
		.ai_protocol = IPPROTO_MPTCP,
		.ai_socktype = SOCK_STREAM,
	};
	struct addrinfo *results, *cur;
	int fd = -1;

	xgetaddrinfo(remoteaddr, port, &hints, &results);

	for (cur = results; cur != NULL; cur = cur->ai_next) {
		fd = socket(cur->ai_family, cur->ai_socktype, proto);
		if (fd < 0)
			continue;

		if (connect(fd, cur->ai_addr, cur->ai_addrlen) != 0)
			die_perror("connect");

		break; /* connected */
	}

	if (fd < 0)
		xerror("could not create connect socket");

	freeaddrinfo(results);
	return fd;
}
173
/*
 * Translate a protocol name ("tcp" or "mptcp", compared without regard
 * to case) into its IPPROTO_* value. Any other name ends the program
 * through die_usage(1).
 */
static int protostr_to_num(const char *s)
{
	static const struct {
		const char *name;
		int proto;
	} known[] = {
		{ "tcp", IPPROTO_TCP },
		{ "mptcp", IPPROTO_MPTCP },
	};
	size_t idx;

	for (idx = 0; idx < sizeof(known) / sizeof(known[0]); idx++) {
		if (strcasecmp(s, known[idx].name) == 0)
			return known[idx].proto;
	}

	die_usage(1);
	return 0;
}
184
185static void parse_opts(int argc, char **argv)
186{
187 int c;
188
189 while ((c = getopt(argc, argv, "h6t:r:")) != -1) {
190 switch (c) {
191 case 'h':
192 die_usage(0);
193 break;
194 case '6':
195 pf = AF_INET6;
196 break;
197 case 't':
198 proto_tx = protostr_to_num(optarg);
199 break;
200 case 'r':
201 proto_rx = protostr_to_num(optarg);
202 break;
203 default:
204 die_usage(1);
205 break;
206 }
207 }
208}
209
/*
 * Wait until the TIOCOUTQ value of @fd drops to zero.
 *
 * The queue is polled at most @timeout times with a 1 ms sleep between
 * polls, i.e. roughly @timeout milliseconds. @total is the largest
 * TIOCOUTQ value the caller considers plausible; a larger value is a
 * fatal error, as is a queue that has not drained after the last poll.
 * The SIOCOUTQNSD value is asserted never to exceed the TIOCOUTQ value.
 */
static void wait_for_ack(int fd, int timeout, size_t total)
{
	int i;

	for (i = 0; i < timeout; i++) {
		int nsd, ret, queued = -1;
		struct timespec req;

		ret = ioctl(fd, TIOCOUTQ, &queued);
		if (ret < 0)
			die_perror("TIOCOUTQ");

		ret = ioctl(fd, SIOCOUTQNSD, &nsd);
		if (ret < 0)
			die_perror("SIOCOUTQNSD");

		/* the cast makes a negative value fail this check as well */
		if ((size_t)queued > total)
			xerror("TIOCOUTQ %d, but only %zu expected\n", queued, total);
		assert(nsd <= queued);

		if (queued == 0)
			return;

		/* wait for peer to ack rx of all data */
		req.tv_sec = 0;
		req.tv_nsec = 1 * 1000 * 1000ul; /* 1ms */
		nanosleep(&req, NULL);
	}

	xerror("still tx data queued after %d ms\n", timeout);
}
242
/*
 * Client side of the test. @fd is the connected data socket, @unixfd
 * the control channel on which 4-byte tokens are received and sizes are
 * sent back.
 *
 *  1. Build a payload of 128..4094 random upper-case letters.
 *  2. On "xmit": report the payload length on @unixfd and write the
 *     payload to @fd.
 *  3. On "huge": pick a size in [1 MiB, 17 MiB), report it, wait until
 *     the first payload has left the transmit queue, then write that
 *     many bytes.
 *  4. On "shut": wait until the transmit queue is empty again, write one
 *     final byte, close @fd, send "closed" and close @unixfd.
 */
static void connect_one_server(int fd, int unixfd)
{
	size_t len, i, total, sent;
	char buf[4096], buf2[4096];
	ssize_t ret;

	len = rand() % (sizeof(buf) - 1);

	if (len < 128)
		len = 128;

	for (i = 0; i < len ; i++) {
		buf[i] = rand() % 26;
		buf[i] += 'A';
	}

	/* stored right behind the payload; not part of the len bytes sent */
	buf[i] = '\n';

	/* un-block server */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);

	assert(strncmp(buf2, "xmit", 4) == 0);

	ret = write(unixfd, &len, sizeof(len));
	assert(ret == (ssize_t)sizeof(len));

	ret = write(fd, buf, len);
	if (ret < 0)
		die_perror("write");

	if (ret != (ssize_t)len)
		xerror("short write");

	/* checked like every other control read on this channel */
	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "huge", 4) == 0);

	total = rand() % (16 * 1024 * 1024);
	total += (1 * 1024 * 1024);
	sent = total;

	ret = write(unixfd, &total, sizeof(total));
	assert(ret == (ssize_t)sizeof(total));

	wait_for_ack(fd, 5000, len);

	while (total > 0) {
		if (total > sizeof(buf))
			len = sizeof(buf);
		else
			len = total;

		ret = write(fd, buf, len);
		if (ret < 0)
			die_perror("write");
		total -= ret;

		/* we don't have to care about buf content, only
		 * number of total bytes sent
		 */
	}

	ret = read(unixfd, buf2, 4);
	assert(ret == 4);
	assert(strncmp(buf2, "shut", 4) == 0);

	wait_for_ack(fd, 5000, sent);

	ret = write(fd, buf, 1);
	assert(ret == 1);
	close(fd);
	ret = write(unixfd, "closed", 6);
	assert(ret == 6);

	close(unixfd);
}
319
/*
 * Scan the control messages attached to @msgh for an IPPROTO_TCP /
 * TCP_CM_INQ entry and copy its payload into @inqv. A message without
 * such an entry is a fatal error.
 */
static void get_tcp_inq(struct msghdr *msgh, unsigned int *inqv)
{
	struct cmsghdr *hdr = CMSG_FIRSTHDR(msgh);

	while (hdr) {
		if (hdr->cmsg_level == IPPROTO_TCP &&
		    hdr->cmsg_type == TCP_CM_INQ) {
			memcpy(inqv, CMSG_DATA(hdr), sizeof(*inqv));
			return;
		}

		hdr = CMSG_NXTHDR(msgh, hdr);
	}

	xerror("could not find TCP_CM_INQ cmsg type");
}
333
/*
 * Server side of the test. @fd is the accepted data socket, @unixfd the
 * control channel to the client. Every recvmsg() on @fd is expected to
 * carry a TCP_CM_INQ control message, whose value is checked:
 *
 *  1. "xmit": the peer reports a length; wait until FIONREAD shows
 *     exactly that many bytes, read one byte (inq must be length - 1),
 *     then read the remainder (inq must be 0).
 *  2. "huge": the peer reports a length larger than the local buffer;
 *     while reading it, inq must never exceed the bytes still expected.
 *  3. "shut": wait for "closed", then read the single trailing byte
 *     (inq must be 1) and finally EOF (inq must still be 1).
 */
static void process_one_client(int fd, int unixfd)
{
	unsigned int tcp_inq;
	size_t expect_len;
	char msg_buf[4096];
	char buf[4096];
	char tmp[16];
	struct iovec iov = {
		.iov_base = buf,
		.iov_len = 1,
	};
	struct msghdr msg = {
		.msg_iov = &iov,
		.msg_iovlen = 1,
		.msg_control = msg_buf,
		.msg_controllen = sizeof(msg_buf),
	};
	ssize_t ret, tot;

	ret = write(unixfd, "xmit", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	if (expect_len > sizeof(buf))
		xerror("expect len %zu exceeds buffer size", expect_len);

	/* poll once per millisecond until the whole payload is queued */
	for (;;) {
		struct timespec req;
		unsigned int queued;

		ret = ioctl(fd, FIONREAD, &queued);
		if (ret < 0)
			die_perror("FIONREAD");
		if (queued > expect_len)
			xerror("FIONREAD returned %u, but only %zu expected\n",
			       queued, expect_len);
		if (queued == expect_len)
			break;

		req.tv_sec = 0;
		req.tv_nsec = 1000 * 1000ul;
		nanosleep(&req, NULL);
	}

	/* read one byte, expect cmsg to return expected - 1 */
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	if (msg.msg_controllen == 0)
		xerror("msg_controllen is 0");

	get_tcp_inq(&msg, &tcp_inq);

	assert((size_t)tcp_inq == (expect_len - 1));

	/*
	 * msg_controllen is value-result: recvmsg() shrinks it to the
	 * length actually used, so offer the full buffer again each time.
	 */
	iov.iov_len = sizeof(buf);
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* should have gotten exact remainder of all pending data */
	assert(ret == (ssize_t)tcp_inq);

	/* should be 0, all drained */
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 0);

	/* request a large swath of data. */
	ret = write(unixfd, "huge", 4);
	assert(ret == 4);

	ret = read(unixfd, &expect_len, sizeof(expect_len));
	assert(ret == (ssize_t)sizeof(expect_len));

	/* peer should send us a few mb of data */
	if (expect_len <= sizeof(buf))
		xerror("expect len %zu too small\n", expect_len);

	tot = 0;
	do {
		iov.iov_len = sizeof(buf);
		msg.msg_controllen = sizeof(msg_buf);
		ret = recvmsg(fd, &msg, 0);
		if (ret < 0)
			die_perror("recvmsg");

		tot += ret;

		get_tcp_inq(&msg, &tcp_inq);

		/* conversions match the argument types: unsigned, size_t, size_t */
		if (tcp_inq > expect_len - tot)
			xerror("inq %u, remaining %zu total_len %zu\n",
			       tcp_inq, expect_len - (size_t)tot, expect_len);

		assert(tcp_inq <= expect_len - tot);
	} while ((size_t)tot < expect_len);

	ret = write(unixfd, "shut", 4);
	assert(ret == 4);

	/* wait for hangup. Should have received one more byte of data. */
	ret = read(unixfd, tmp, sizeof(tmp));
	assert(ret == 6);
	assert(strncmp(tmp, "closed", 6) == 0);

	sleep(1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");
	assert(ret == 1);

	get_tcp_inq(&msg, &tcp_inq);

	/* tcp_inq should be 1 due to received fin. */
	assert(tcp_inq == 1);

	iov.iov_len = 1;
	msg.msg_controllen = sizeof(msg_buf);
	ret = recvmsg(fd, &msg, 0);
	if (ret < 0)
		die_perror("recvmsg");

	/* expect EOF */
	assert(ret == 0);
	get_tcp_inq(&msg, &tcp_inq);
	assert(tcp_inq == 1);

	close(fd);
}
466
/*
 * accept() on listening socket @s without retrieving the peer address.
 * A failing accept() is fatal. Returns the new connection descriptor.
 */
static int xaccept(int s)
{
	int conn = accept(s, NULL, NULL);

	if (conn >= 0)
		return conn;

	die_perror("accept");
	return conn;
}
476
477static int server(int unixfd)
478{
479 int fd = -1, r, on = 1;
480
481 switch (pf) {
482 case AF_INET:
483 fd = sock_listen_mptcp("127.0.0.1", "15432");
484 break;
485 case AF_INET6:
486 fd = sock_listen_mptcp("::1", "15432");
487 break;
488 default:
489 xerror("Unknown pf %d\n", pf);
490 break;
491 }
492
493 r = write(unixfd, "conn", 4);
494 assert(r == 4);
495
496 alarm(15);
497 r = xaccept(fd);
498
499 if (-1 == setsockopt(r, IPPROTO_TCP, TCP_INQ, &on, sizeof(on)))
500 die_perror("setsockopt");
501
502 process_one_client(r, unixfd);
503
504 return 0;
505}
506
507static int client(int unixfd)
508{
509 int fd = -1;
510
511 alarm(15);
512
513 switch (pf) {
514 case AF_INET:
515 fd = sock_connect_mptcp("127.0.0.1", "15432", proto_tx);
516 break;
517 case AF_INET6:
518 fd = sock_connect_mptcp("::1", "15432", proto_tx);
519 break;
520 default:
521 xerror("Unknown pf %d\n", pf);
522 }
523
524 connect_one_server(fd, unixfd);
525
526 return 0;
527}
528
/*
 * Seed rand() with an unsigned int obtained from getrandom().
 * Exits with status 1 if getrandom() fails.
 */
static void init_rng(void)
{
	unsigned int seed;
	ssize_t got = getrandom(&seed, sizeof(seed), 0);

	if (got == -1) {
		perror("getrandom");
		exit(1);
	}

	srand(seed);
}
540
/*
 * fork() that is fatal on failure. The child re-seeds its random number
 * generator before returning. Returns fork()'s result.
 */
static pid_t xfork(void)
{
	pid_t pid = fork();

	if (pid < 0)
		die_perror("fork");

	if (pid == 0)
		init_rng();

	return pid;
}
552
/*
 * Turn a waitpid() status for the child named @what into an exit code.
 *
 * A normally exited child yields its own exit status (non-zero ones are
 * also logged to stderr). A child killed or stopped by a signal is a
 * fatal error. Any other status yields 111.
 */
static int rcheck(int wstatus, const char *what)
{
	if (WIFEXITED(wstatus)) {
		int code = WEXITSTATUS(wstatus);

		if (code != 0)
			fprintf(stderr, "%s exited, status=%d\n", what, code);

		return code;
	}

	if (WIFSIGNALED(wstatus))
		xerror("%s killed by signal %d\n", what, WTERMSIG(wstatus));
	else if (WIFSTOPPED(wstatus))
		xerror("%s stopped by signal %d\n", what, WSTOPSIG(wstatus));

	return 111;
}
568
/*
 * Test driver: parse options, create an AF_UNIX datagram socketpair as
 * control channel, fork the server (which gets unixfds[1]), wait for
 * its 4-byte readiness message, fork the client (which gets
 * unixfds[0]) and reap both children.
 *
 * Returns the server's result if it is non-zero, otherwise the
 * client's.
 */
int main(int argc, char *argv[])
{
	int e1, e2, wstatus;
	pid_t s, c, ret;
	int unixfds[2];
	char ready[4];
	ssize_t n;

	parse_opts(argc, argv);

	if (socketpair(AF_UNIX, SOCK_DGRAM, 0, unixfds) < 0)
		die_perror("socketpair");

	s = xfork();
	if (s == 0)
		return server(unixfds[1]);

	close(unixfds[1]);

	/* wait until server bound a socket */
	n = read(unixfds[0], ready, sizeof(ready));
	assert(n == (ssize_t)sizeof(ready));

	c = xfork();
	if (c == 0)
		return client(unixfds[0]);

	close(unixfds[0]);

	ret = waitpid(s, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e1 = rcheck(wstatus, "server");
	ret = waitpid(c, &wstatus, 0);
	if (ret == -1)
		die_perror("waitpid");
	e2 = rcheck(wstatus, "client");

	return e1 ? e1 : e2;
}