Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

tcp: allow again tcp_disconnect() when threads are waiting

As reported by Tom, .NET and applications built on top of it rely
on connect(AF_UNSPEC) to asynchronously cancel pending I/O operations
on TCP sockets.

The blamed commit below caused a regression, as such cancellation
can now fail.

As suggested by Eric, this change addresses the problem by explicitly
causing blocking I/O operations to terminate immediately (with an error)
when a concurrent disconnect() is executed.

Instead of tracking the number of threads blocked on a given socket,
track the number of disconnect() calls issued on that socket. If the
counter changes after a blocking operation releases and re-acquires the
socket lock, error out the current operation.

Fixes: 4faeee0cf8a5 ("tcp: deny tcp_disconnect() when threads are waiting")
Reported-by: Tom Deseyn <tdeseyn@redhat.com>
Closes: https://bugzilla.redhat.com/show_bug.cgi?id=1886305
Suggested-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: Paolo Abeni <pabeni@redhat.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Link: https://lore.kernel.org/r/f3b95e47e3dbed840960548aebaa8d954372db41.1697008693.git.pabeni@redhat.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>

authored by

Paolo Abeni and committed by
Jakub Kicinski
419ce133 242e3450

+80 -45
+29 -7
drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls_io.c
··· 911 911 struct sock *sk, long *timeo_p) 912 912 { 913 913 DEFINE_WAIT_FUNC(wait, woken_wake_function); 914 - int err = 0; 914 + int ret, err = 0; 915 915 long current_timeo; 916 916 long vm_wait = 0; 917 917 bool noblock; ··· 942 942 943 943 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 944 944 sk->sk_write_pending++; 945 - sk_wait_event(sk, &current_timeo, sk->sk_err || 946 - (sk->sk_shutdown & SEND_SHUTDOWN) || 947 - (csk_mem_free(cdev, sk) && !vm_wait), &wait); 945 + ret = sk_wait_event(sk, &current_timeo, sk->sk_err || 946 + (sk->sk_shutdown & SEND_SHUTDOWN) || 947 + (csk_mem_free(cdev, sk) && !vm_wait), 948 + &wait); 948 949 sk->sk_write_pending--; 950 + if (ret < 0) 951 + goto do_error; 949 952 950 953 if (vm_wait) { 951 954 vm_wait -= current_timeo; ··· 1351 1348 int copied = 0; 1352 1349 int target; 1353 1350 long timeo; 1351 + int ret; 1354 1352 1355 1353 buffers_freed = 0; 1356 1354 ··· 1427 1423 if (copied >= target) 1428 1424 break; 1429 1425 chtls_cleanup_rbuf(sk, copied); 1430 - sk_wait_data(sk, &timeo, NULL); 1426 + ret = sk_wait_data(sk, &timeo, NULL); 1427 + if (ret < 0) { 1428 + copied = copied ? 
: ret; 1429 + goto unlock; 1430 + } 1431 1431 continue; 1432 1432 found_ok_skb: 1433 1433 if (!skb->len) { ··· 1526 1518 1527 1519 if (buffers_freed) 1528 1520 chtls_cleanup_rbuf(sk, copied); 1521 + 1522 + unlock: 1529 1523 release_sock(sk); 1530 1524 return copied; 1531 1525 } ··· 1544 1534 int copied = 0; 1545 1535 size_t avail; /* amount of available data in current skb */ 1546 1536 long timeo; 1537 + int ret; 1547 1538 1548 1539 lock_sock(sk); 1549 1540 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); ··· 1596 1585 release_sock(sk); 1597 1586 lock_sock(sk); 1598 1587 } else { 1599 - sk_wait_data(sk, &timeo, NULL); 1588 + ret = sk_wait_data(sk, &timeo, NULL); 1589 + if (ret < 0) { 1590 + /* here 'copied' is 0 due to previous checks */ 1591 + copied = ret; 1592 + break; 1593 + } 1600 1594 } 1601 1595 1602 1596 if (unlikely(peek_seq != tp->copied_seq)) { ··· 1672 1656 int copied = 0; 1673 1657 long timeo; 1674 1658 int target; /* Read at least this many bytes */ 1659 + int ret; 1675 1660 1676 1661 buffers_freed = 0; 1677 1662 ··· 1764 1747 if (copied >= target) 1765 1748 break; 1766 1749 chtls_cleanup_rbuf(sk, copied); 1767 - sk_wait_data(sk, &timeo, NULL); 1750 + ret = sk_wait_data(sk, &timeo, NULL); 1751 + if (ret < 0) { 1752 + copied = copied ? : ret; 1753 + goto unlock; 1754 + } 1768 1755 continue; 1769 1756 1770 1757 found_ok_skb: ··· 1837 1816 if (buffers_freed) 1838 1817 chtls_cleanup_rbuf(sk, copied); 1839 1818 1819 + unlock: 1840 1820 release_sock(sk); 1841 1821 return copied; 1842 1822 }
+4 -6
include/net/sock.h
··· 336 336 * @sk_cgrp_data: cgroup data for this cgroup 337 337 * @sk_memcg: this socket's memory cgroup association 338 338 * @sk_write_pending: a write to stream socket waits to start 339 - * @sk_wait_pending: number of threads blocked on this socket 339 + * @sk_disconnects: number of disconnect operations performed on this sock 340 340 * @sk_state_change: callback to indicate change in the state of the sock 341 341 * @sk_data_ready: callback to indicate there is data to be processed 342 342 * @sk_write_space: callback to indicate there is bf sending space available ··· 429 429 unsigned int sk_napi_id; 430 430 #endif 431 431 int sk_rcvbuf; 432 - int sk_wait_pending; 432 + int sk_disconnects; 433 433 434 434 struct sk_filter __rcu *sk_filter; 435 435 union { ··· 1189 1189 } 1190 1190 1191 1191 #define sk_wait_event(__sk, __timeo, __condition, __wait) \ 1192 - ({ int __rc; \ 1193 - __sk->sk_wait_pending++; \ 1192 + ({ int __rc, __dis = __sk->sk_disconnects; \ 1194 1193 release_sock(__sk); \ 1195 1194 __rc = __condition; \ 1196 1195 if (!__rc) { \ ··· 1199 1200 } \ 1200 1201 sched_annotate_sleep(); \ 1201 1202 lock_sock(__sk); \ 1202 - __sk->sk_wait_pending--; \ 1203 - __rc = __condition; \ 1203 + __rc = __dis == __sk->sk_disconnects ? __condition : -EPIPE; \ 1204 1204 __rc; \ 1205 1205 }) 1206 1206
+7 -5
net/core/stream.c
··· 117 117 */ 118 118 int sk_stream_wait_memory(struct sock *sk, long *timeo_p) 119 119 { 120 - int err = 0; 120 + int ret, err = 0; 121 121 long vm_wait = 0; 122 122 long current_timeo = *timeo_p; 123 123 DEFINE_WAIT_FUNC(wait, woken_wake_function); ··· 142 142 143 143 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags); 144 144 sk->sk_write_pending++; 145 - sk_wait_event(sk, &current_timeo, READ_ONCE(sk->sk_err) || 146 - (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) || 147 - (sk_stream_memory_free(sk) && 148 - !vm_wait), &wait); 145 + ret = sk_wait_event(sk, &current_timeo, READ_ONCE(sk->sk_err) || 146 + (READ_ONCE(sk->sk_shutdown) & SEND_SHUTDOWN) || 147 + (sk_stream_memory_free(sk) && !vm_wait), 148 + &wait); 149 149 sk->sk_write_pending--; 150 + if (ret < 0) 151 + goto do_error; 150 152 151 153 if (vm_wait) { 152 154 vm_wait -= current_timeo;
+8 -2
net/ipv4/af_inet.c
··· 597 597 598 598 add_wait_queue(sk_sleep(sk), &wait); 599 599 sk->sk_write_pending += writebias; 600 - sk->sk_wait_pending++; 601 600 602 601 /* Basic assumption: if someone sets sk->sk_err, he _must_ 603 602 * change state of the socket from TCP_SYN_*. ··· 612 613 } 613 614 remove_wait_queue(sk_sleep(sk), &wait); 614 615 sk->sk_write_pending -= writebias; 615 - sk->sk_wait_pending--; 616 616 return timeo; 617 617 } 618 618 ··· 640 642 return -EINVAL; 641 643 642 644 if (uaddr->sa_family == AF_UNSPEC) { 645 + sk->sk_disconnects++; 643 646 err = sk->sk_prot->disconnect(sk, flags); 644 647 sock->state = err ? SS_DISCONNECTING : SS_UNCONNECTED; 645 648 goto out; ··· 695 696 int writebias = (sk->sk_protocol == IPPROTO_TCP) && 696 697 tcp_sk(sk)->fastopen_req && 697 698 tcp_sk(sk)->fastopen_req->data ? 1 : 0; 699 + int dis = sk->sk_disconnects; 698 700 699 701 /* Error code is set above */ 700 702 if (!timeo || !inet_wait_for_connect(sk, timeo, writebias)) ··· 704 704 err = sock_intr_errno(timeo); 705 705 if (signal_pending(current)) 706 706 goto out; 707 + 708 + if (dis != sk->sk_disconnects) { 709 + err = -EPIPE; 710 + goto out; 711 + } 707 712 } 708 713 709 714 /* Connection was closed by RST, timeout, ICMP error ··· 730 725 sock_error: 731 726 err = sock_error(sk) ? : -ECONNABORTED; 732 727 sock->state = SS_UNCONNECTED; 728 + sk->sk_disconnects++; 733 729 if (sk->sk_prot->disconnect(sk, flags)) 734 730 sock->state = SS_DISCONNECTING; 735 731 goto out;
-1
net/ipv4/inet_connection_sock.c
··· 1145 1145 if (newsk) { 1146 1146 struct inet_connection_sock *newicsk = inet_csk(newsk); 1147 1147 1148 - newsk->sk_wait_pending = 0; 1149 1148 inet_sk_set_state(newsk, TCP_SYN_RECV); 1150 1149 newicsk->icsk_bind_hash = NULL; 1151 1150 newicsk->icsk_bind2_hash = NULL;
+8 -8
net/ipv4/tcp.c
··· 831 831 */ 832 832 if (!skb_queue_empty(&sk->sk_receive_queue)) 833 833 break; 834 - sk_wait_data(sk, &timeo, NULL); 834 + ret = sk_wait_data(sk, &timeo, NULL); 835 + if (ret < 0) 836 + break; 835 837 if (signal_pending(current)) { 836 838 ret = sock_intr_errno(timeo); 837 839 break; ··· 2444 2442 __sk_flush_backlog(sk); 2445 2443 } else { 2446 2444 tcp_cleanup_rbuf(sk, copied); 2447 - sk_wait_data(sk, &timeo, last); 2445 + err = sk_wait_data(sk, &timeo, last); 2446 + if (err < 0) { 2447 + err = copied ? : err; 2448 + goto out; 2449 + } 2448 2450 } 2449 2451 2450 2452 if ((flags & MSG_PEEK) && ··· 2971 2965 struct tcp_sock *tp = tcp_sk(sk); 2972 2966 int old_state = sk->sk_state; 2973 2967 u32 seq; 2974 - 2975 - /* Deny disconnect if other threads are blocked in sk_wait_event() 2976 - * or inet_wait_for_connect(). 2977 - */ 2978 - if (sk->sk_wait_pending) 2979 - return -EBUSY; 2980 2968 2981 2969 if (old_state != TCP_CLOSE) 2982 2970 tcp_set_state(sk, TCP_CLOSE);
+4
net/ipv4/tcp_bpf.c
··· 307 307 } 308 308 309 309 data = tcp_msg_wait_data(sk, psock, timeo); 310 + if (data < 0) 311 + return data; 310 312 if (data && !sk_psock_queue_empty(psock)) 311 313 goto msg_bytes_ready; 312 314 copied = -EAGAIN; ··· 353 351 354 352 timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 355 353 data = tcp_msg_wait_data(sk, psock, timeo); 354 + if (data < 0) 355 + return data; 356 356 if (data) { 357 357 if (!sk_psock_queue_empty(psock)) 358 358 goto msg_bytes_ready;
-7
net/mptcp/protocol.c
··· 3098 3098 { 3099 3099 struct mptcp_sock *msk = mptcp_sk(sk); 3100 3100 3101 - /* Deny disconnect if other threads are blocked in sk_wait_event() 3102 - * or inet_wait_for_connect(). 3103 - */ 3104 - if (sk->sk_wait_pending) 3105 - return -EBUSY; 3106 - 3107 3101 /* We are on the fastopen error path. We can't call straight into the 3108 3102 * subflows cleanup code due to lock nesting (we are already under 3109 3103 * msk->firstsocket lock). ··· 3167 3173 inet_sk(nsk)->pinet6 = mptcp_inet6_sk(nsk); 3168 3174 #endif 3169 3175 3170 - nsk->sk_wait_pending = 0; 3171 3176 __mptcp_init_sock(nsk); 3172 3177 3173 3178 msk = mptcp_sk(nsk);
+7 -3
net/tls/tls_main.c
··· 139 139 140 140 int wait_on_pending_writer(struct sock *sk, long *timeo) 141 141 { 142 - int rc = 0; 143 142 DEFINE_WAIT_FUNC(wait, woken_wake_function); 143 + int ret, rc = 0; 144 144 145 145 add_wait_queue(sk_sleep(sk), &wait); 146 146 while (1) { ··· 154 154 break; 155 155 } 156 156 157 - if (sk_wait_event(sk, timeo, 158 - !READ_ONCE(sk->sk_write_pending), &wait)) 157 + ret = sk_wait_event(sk, timeo, 158 + !READ_ONCE(sk->sk_write_pending), &wait); 159 + if (ret) { 160 + if (ret < 0) 161 + rc = ret; 159 162 break; 163 + } 160 164 } 161 165 remove_wait_queue(sk_sleep(sk), &wait); 162 166 return rc;
+13 -6
net/tls/tls_sw.c
··· 1291 1291 struct tls_context *tls_ctx = tls_get_ctx(sk); 1292 1292 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1293 1293 DEFINE_WAIT_FUNC(wait, woken_wake_function); 1294 + int ret = 0; 1294 1295 long timeo; 1295 1296 1296 1297 timeo = sock_rcvtimeo(sk, nonblock); ··· 1302 1301 1303 1302 if (sk->sk_err) 1304 1303 return sock_error(sk); 1304 + 1305 + if (ret < 0) 1306 + return ret; 1305 1307 1306 1308 if (!skb_queue_empty(&sk->sk_receive_queue)) { 1307 1309 tls_strp_check_rcv(&ctx->strp); ··· 1324 1320 released = true; 1325 1321 add_wait_queue(sk_sleep(sk), &wait); 1326 1322 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 1327 - sk_wait_event(sk, &timeo, 1328 - tls_strp_msg_ready(ctx) || 1329 - !sk_psock_queue_empty(psock), 1330 - &wait); 1323 + ret = sk_wait_event(sk, &timeo, 1324 + tls_strp_msg_ready(ctx) || 1325 + !sk_psock_queue_empty(psock), 1326 + &wait); 1331 1327 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 1332 1328 remove_wait_queue(sk_sleep(sk), &wait); 1333 1329 ··· 1856 1852 bool nonblock) 1857 1853 { 1858 1854 long timeo; 1855 + int ret; 1859 1856 1860 1857 timeo = sock_rcvtimeo(sk, nonblock); 1861 1858 ··· 1866 1861 ctx->reader_contended = 1; 1867 1862 1868 1863 add_wait_queue(&ctx->wq, &wait); 1869 - sk_wait_event(sk, &timeo, 1870 - !READ_ONCE(ctx->reader_present), &wait); 1864 + ret = sk_wait_event(sk, &timeo, 1865 + !READ_ONCE(ctx->reader_present), &wait); 1871 1866 remove_wait_queue(&ctx->wq, &wait); 1872 1867 1873 1868 if (timeo <= 0) 1874 1869 return -EAGAIN; 1875 1870 if (signal_pending(current)) 1876 1871 return sock_intr_errno(timeo); 1872 + if (ret < 0) 1873 + return ret; 1877 1874 } 1878 1875 1879 1876 WRITE_ONCE(ctx->reader_present, 1);