Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

vsock: support sockmap

This patch adds sockmap support for vsock sockets. It is intended to be
usable by all transports, but only the virtio and loopback transports
are implemented.

SOCK_STREAM, SOCK_DGRAM, and SOCK_SEQPACKET are all supported.

Signed-off-by: Bobby Eshleman <bobby.eshleman@bytedance.com>
Acked-by: Michael S. Tsirkin <mst@redhat.com>
Reviewed-by: Stefano Garzarella <sgarzare@redhat.com>
Signed-off-by: David S. Miller <davem@davemloft.net>

authored by

Bobby Eshleman and committed by
David S. Miller
634f1a71 24265c2c

+281 -6
+1
drivers/vhost/vsock.c
··· 439 439 .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, 440 440 .notify_buffer_size = virtio_transport_notify_buffer_size, 441 441 442 + .read_skb = virtio_transport_read_skb, 442 443 }, 443 444 444 445 .send_pkt = vhost_transport_send_pkt,
+1
include/linux/virtio_vsock.h
··· 245 245 void virtio_transport_put_credit(struct virtio_vsock_sock *vvs, u32 credit); 246 246 void virtio_transport_deliver_tap_pkt(struct sk_buff *skb); 247 247 int virtio_transport_purge_skbs(void *vsk, struct sk_buff_head *list); 248 + int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t read_actor); 248 249 #endif /* _LINUX_VIRTIO_VSOCK_H */
+17
include/net/af_vsock.h
··· 75 75 void *trans; 76 76 }; 77 77 78 + s64 vsock_connectible_has_data(struct vsock_sock *vsk); 78 79 s64 vsock_stream_has_data(struct vsock_sock *vsk); 79 80 s64 vsock_stream_has_space(struct vsock_sock *vsk); 80 81 struct sock *vsock_create_connected(struct sock *parent); ··· 174 173 175 174 /* Addressing. */ 176 175 u32 (*get_local_cid)(void); 176 + 177 + /* Read a single skb */ 178 + int (*read_skb)(struct vsock_sock *, skb_read_actor_t); 177 179 }; 178 180 179 181 /**** CORE ****/ ··· 229 225 int vsock_add_tap(struct vsock_tap *vt); 230 226 int vsock_remove_tap(struct vsock_tap *vt); 231 227 void vsock_deliver_tap(struct sk_buff *build_skb(void *opaque), void *opaque); 228 + int vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, 229 + int flags); 230 + int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, 231 + size_t len, int flags); 232 + 233 + #ifdef CONFIG_BPF_SYSCALL 234 + extern struct proto vsock_proto; 235 + int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore); 236 + void __init vsock_bpf_build_proto(void); 237 + #else 238 + static inline void __init vsock_bpf_build_proto(void) 239 + {} 240 + #endif 232 241 233 242 #endif /* __AF_VSOCK_H__ */
+1
net/vmw_vsock/Makefile
··· 8 8 obj-$(CONFIG_VSOCKETS_LOOPBACK) += vsock_loopback.o 9 9 10 10 vsock-y += af_vsock.o af_vsock_tap.o vsock_addr.o 11 + vsock-$(CONFIG_BPF_SYSCALL) += vsock_bpf.o 11 12 12 13 vsock_diag-y += diag.o 13 14
+58 -6
net/vmw_vsock/af_vsock.c
··· 116 116 static int vsock_queue_rcv_skb(struct sock *sk, struct sk_buff *skb); 117 117 118 118 /* Protocol family. */ 119 - static struct proto vsock_proto = { 119 + struct proto vsock_proto = { 120 120 .name = "AF_VSOCK", 121 121 .owner = THIS_MODULE, 122 122 .obj_size = sizeof(struct vsock_sock), 123 + #ifdef CONFIG_BPF_SYSCALL 124 + .psock_update_sk_prot = vsock_bpf_update_proto, 125 + #endif 123 126 }; 124 127 125 128 /* The default peer timeout indicates how long we will wait for a peer response ··· 868 865 } 869 866 EXPORT_SYMBOL_GPL(vsock_stream_has_data); 870 867 871 - static s64 vsock_connectible_has_data(struct vsock_sock *vsk) 868 + s64 vsock_connectible_has_data(struct vsock_sock *vsk) 872 869 { 873 870 struct sock *sk = sk_vsock(vsk); 874 871 ··· 877 874 else 878 875 return vsock_stream_has_data(vsk); 879 876 } 877 + EXPORT_SYMBOL_GPL(vsock_connectible_has_data); 880 878 881 879 s64 vsock_stream_has_space(struct vsock_sock *vsk) 882 880 { ··· 1135 1131 return mask; 1136 1132 } 1137 1133 1134 + static int vsock_read_skb(struct sock *sk, skb_read_actor_t read_actor) 1135 + { 1136 + struct vsock_sock *vsk = vsock_sk(sk); 1137 + 1138 + return vsk->transport->read_skb(vsk, read_actor); 1139 + } 1140 + 1138 1141 static int vsock_dgram_sendmsg(struct socket *sock, struct msghdr *msg, 1139 1142 size_t len) 1140 1143 { ··· 1253 1242 memcpy(&vsk->remote_addr, remote_addr, sizeof(vsk->remote_addr)); 1254 1243 sock->state = SS_CONNECTED; 1255 1244 1245 + /* sock map disallows redirection of non-TCP sockets with sk_state != 1246 + * TCP_ESTABLISHED (see sock_map_redirect_allowed()), so we set 1247 + * TCP_ESTABLISHED here to allow redirection of connected vsock dgrams. 1248 + * 1249 + * This doesn't seem to be abnormal state for datagram sockets, as the 1250 + * same approach can be seen in other datagram socket types as well 1251 + * (such as unix sockets). 
1252 + */ 1253 + sk->sk_state = TCP_ESTABLISHED; 1254 + 1256 1255 out: 1257 1256 release_sock(sk); 1258 1257 return err; 1259 1258 } 1260 1259 1261 - static int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, 1262 - size_t len, int flags) 1260 + int vsock_dgram_recvmsg(struct socket *sock, struct msghdr *msg, 1261 + size_t len, int flags) 1263 1262 { 1264 - struct vsock_sock *vsk = vsock_sk(sock->sk); 1263 + #ifdef CONFIG_BPF_SYSCALL 1264 + const struct proto *prot; 1265 + #endif 1266 + struct vsock_sock *vsk; 1267 + struct sock *sk; 1268 + 1269 + sk = sock->sk; 1270 + vsk = vsock_sk(sk); 1271 + 1272 + #ifdef CONFIG_BPF_SYSCALL 1273 + prot = READ_ONCE(sk->sk_prot); 1274 + if (prot != &vsock_proto) 1275 + return prot->recvmsg(sk, msg, len, flags, NULL); 1276 + #endif 1265 1277 1266 1278 return vsk->transport->dgram_dequeue(vsk, msg, len, flags); 1267 1279 } 1280 + EXPORT_SYMBOL_GPL(vsock_dgram_recvmsg); 1268 1281 1269 1282 static const struct proto_ops vsock_dgram_ops = { 1270 1283 .family = PF_VSOCK, ··· 1307 1272 .recvmsg = vsock_dgram_recvmsg, 1308 1273 .mmap = sock_no_mmap, 1309 1274 .sendpage = sock_no_sendpage, 1275 + .read_skb = vsock_read_skb, 1310 1276 }; 1311 1277 1312 1278 static int vsock_transport_cancel_pkt(struct vsock_sock *vsk) ··· 2122 2086 return err; 2123 2087 } 2124 2088 2125 - static int 2089 + int 2126 2090 vsock_connectible_recvmsg(struct socket *sock, struct msghdr *msg, size_t len, 2127 2091 int flags) 2128 2092 { 2129 2093 struct sock *sk; 2130 2094 struct vsock_sock *vsk; 2131 2095 const struct vsock_transport *transport; 2096 + #ifdef CONFIG_BPF_SYSCALL 2097 + const struct proto *prot; 2098 + #endif 2132 2099 int err; 2133 2100 2134 2101 sk = sock->sk; ··· 2178 2139 goto out; 2179 2140 } 2180 2141 2142 + #ifdef CONFIG_BPF_SYSCALL 2143 + prot = READ_ONCE(sk->sk_prot); 2144 + if (prot != &vsock_proto) { 2145 + release_sock(sk); 2146 + return prot->recvmsg(sk, msg, len, flags, NULL); 2147 + } 2148 + #endif 2149 + 2181 2150 if 
(sk->sk_type == SOCK_STREAM) 2182 2151 err = __vsock_stream_recvmsg(sk, msg, len, flags); 2183 2152 else ··· 2195 2148 release_sock(sk); 2196 2149 return err; 2197 2150 } 2151 + EXPORT_SYMBOL_GPL(vsock_connectible_recvmsg); 2198 2152 2199 2153 static int vsock_set_rcvlowat(struct sock *sk, int val) 2200 2154 { ··· 2236 2188 .mmap = sock_no_mmap, 2237 2189 .sendpage = sock_no_sendpage, 2238 2190 .set_rcvlowat = vsock_set_rcvlowat, 2191 + .read_skb = vsock_read_skb, 2239 2192 }; 2240 2193 2241 2194 static const struct proto_ops vsock_seqpacket_ops = { ··· 2258 2209 .recvmsg = vsock_connectible_recvmsg, 2259 2210 .mmap = sock_no_mmap, 2260 2211 .sendpage = sock_no_sendpage, 2212 + .read_skb = vsock_read_skb, 2261 2213 }; 2262 2214 2263 2215 static int vsock_create(struct net *net, struct socket *sock, ··· 2397 2347 AF_VSOCK, err); 2398 2348 goto err_unregister_proto; 2399 2349 } 2350 + 2351 + vsock_bpf_build_proto(); 2400 2352 2401 2353 return 0; 2402 2354
+2
net/vmw_vsock/virtio_transport.c
··· 457 457 .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, 458 458 .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, 459 459 .notify_buffer_size = virtio_transport_notify_buffer_size, 460 + 461 + .read_skb = virtio_transport_read_skb, 460 462 }, 461 463 462 464 .send_pkt = virtio_transport_send_pkt,
+25
net/vmw_vsock/virtio_transport_common.c
··· 1418 1418 } 1419 1419 EXPORT_SYMBOL_GPL(virtio_transport_purge_skbs); 1420 1420 1421 + int virtio_transport_read_skb(struct vsock_sock *vsk, skb_read_actor_t recv_actor) 1422 + { 1423 + struct virtio_vsock_sock *vvs = vsk->trans; 1424 + struct sock *sk = sk_vsock(vsk); 1425 + struct sk_buff *skb; 1426 + int off = 0; 1427 + int copied; 1428 + int err; 1429 + 1430 + spin_lock_bh(&vvs->rx_lock); 1431 + /* Use __skb_recv_datagram() for race-free handling of the receive. It 1432 + * works for types other than dgrams. 1433 + */ 1434 + skb = __skb_recv_datagram(sk, &vvs->rx_queue, MSG_DONTWAIT, &off, &err); 1435 + spin_unlock_bh(&vvs->rx_lock); 1436 + 1437 + if (!skb) 1438 + return err; 1439 + 1440 + copied = recv_actor(sk, skb); 1441 + kfree_skb(skb); 1442 + return copied; 1443 + } 1444 + EXPORT_SYMBOL_GPL(virtio_transport_read_skb); 1445 + 1421 1446 MODULE_LICENSE("GPL v2"); 1422 1447 MODULE_AUTHOR("Asias He"); 1423 1448 MODULE_DESCRIPTION("common code for virtio vsock");
+174
net/vmw_vsock/vsock_bpf.c
··· 1 + // SPDX-License-Identifier: GPL-2.0 2 + /* Copyright (c) 2022 Bobby Eshleman <bobby.eshleman@bytedance.com> 3 + * 4 + * Based off of net/unix/unix_bpf.c 5 + */ 6 + 7 + #include <linux/bpf.h> 8 + #include <linux/module.h> 9 + #include <linux/skmsg.h> 10 + #include <linux/socket.h> 11 + #include <linux/wait.h> 12 + #include <net/af_vsock.h> 13 + #include <net/sock.h> 14 + 15 + #define vsock_sk_has_data(__sk, __psock) \ 16 + ({ !skb_queue_empty(&(__sk)->sk_receive_queue) || \ 17 + !skb_queue_empty(&(__psock)->ingress_skb) || \ 18 + !list_empty(&(__psock)->ingress_msg); \ 19 + }) 20 + 21 + static struct proto *vsock_prot_saved __read_mostly; 22 + static DEFINE_SPINLOCK(vsock_prot_lock); 23 + static struct proto vsock_bpf_prot; 24 + 25 + static bool vsock_has_data(struct sock *sk, struct sk_psock *psock) 26 + { 27 + struct vsock_sock *vsk = vsock_sk(sk); 28 + s64 ret; 29 + 30 + ret = vsock_connectible_has_data(vsk); 31 + if (ret > 0) 32 + return true; 33 + 34 + return vsock_sk_has_data(sk, psock); 35 + } 36 + 37 + static bool vsock_msg_wait_data(struct sock *sk, struct sk_psock *psock, long timeo) 38 + { 39 + bool ret; 40 + 41 + DEFINE_WAIT_FUNC(wait, woken_wake_function); 42 + 43 + if (sk->sk_shutdown & RCV_SHUTDOWN) 44 + return true; 45 + 46 + if (!timeo) 47 + return false; 48 + 49 + add_wait_queue(sk_sleep(sk), &wait); 50 + sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk); 51 + ret = vsock_has_data(sk, psock); 52 + if (!ret) { 53 + wait_woken(&wait, TASK_INTERRUPTIBLE, timeo); 54 + ret = vsock_has_data(sk, psock); 55 + } 56 + sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk); 57 + remove_wait_queue(sk_sleep(sk), &wait); 58 + return ret; 59 + } 60 + 61 + static int __vsock_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, int flags) 62 + { 63 + struct socket *sock = sk->sk_socket; 64 + int err; 65 + 66 + if (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET) 67 + err = vsock_connectible_recvmsg(sock, msg, len, flags); 68 + else if (sk->sk_type == SOCK_DGRAM) 69 + 
err = vsock_dgram_recvmsg(sock, msg, len, flags); 70 + else 71 + err = -EPROTOTYPE; 72 + 73 + return err; 74 + } 75 + 76 + static int vsock_bpf_recvmsg(struct sock *sk, struct msghdr *msg, 77 + size_t len, int flags, int *addr_len) 78 + { 79 + struct sk_psock *psock; 80 + int copied; 81 + 82 + psock = sk_psock_get(sk); 83 + if (unlikely(!psock)) 84 + return __vsock_recvmsg(sk, msg, len, flags); 85 + 86 + lock_sock(sk); 87 + if (vsock_has_data(sk, psock) && sk_psock_queue_empty(psock)) { 88 + release_sock(sk); 89 + sk_psock_put(sk, psock); 90 + return __vsock_recvmsg(sk, msg, len, flags); 91 + } 92 + 93 + copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 94 + while (copied == 0) { 95 + long timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT); 96 + 97 + if (!vsock_msg_wait_data(sk, psock, timeo)) { 98 + copied = -EAGAIN; 99 + break; 100 + } 101 + 102 + if (sk_psock_queue_empty(psock)) { 103 + release_sock(sk); 104 + sk_psock_put(sk, psock); 105 + return __vsock_recvmsg(sk, msg, len, flags); 106 + } 107 + 108 + copied = sk_msg_recvmsg(sk, psock, msg, len, flags); 109 + } 110 + 111 + release_sock(sk); 112 + sk_psock_put(sk, psock); 113 + 114 + return copied; 115 + } 116 + 117 + /* Copy of original proto with updated sock_map methods */ 118 + static struct proto vsock_bpf_prot = { 119 + .close = sock_map_close, 120 + .recvmsg = vsock_bpf_recvmsg, 121 + .sock_is_readable = sk_msg_is_readable, 122 + .unhash = sock_map_unhash, 123 + }; 124 + 125 + static void vsock_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 126 + { 127 + *prot = *base; 128 + prot->close = sock_map_close; 129 + prot->recvmsg = vsock_bpf_recvmsg; 130 + prot->sock_is_readable = sk_msg_is_readable; 131 + } 132 + 133 + static void vsock_bpf_check_needs_rebuild(struct proto *ops) 134 + { 135 + /* Paired with the smp_store_release() below. 
*/ 136 + if (unlikely(ops != smp_load_acquire(&vsock_prot_saved))) { 137 + spin_lock_bh(&vsock_prot_lock); 138 + if (likely(ops != vsock_prot_saved)) { 139 + vsock_bpf_rebuild_protos(&vsock_bpf_prot, ops); 140 + /* Make sure proto function pointers are updated before publishing the 141 + * pointer to the struct. 142 + */ 143 + smp_store_release(&vsock_prot_saved, ops); 144 + } 145 + spin_unlock_bh(&vsock_prot_lock); 146 + } 147 + } 148 + 149 + int vsock_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 150 + { 151 + struct vsock_sock *vsk; 152 + 153 + if (restore) { 154 + sk->sk_write_space = psock->saved_write_space; 155 + sock_replace_proto(sk, psock->sk_proto); 156 + return 0; 157 + } 158 + 159 + vsk = vsock_sk(sk); 160 + if (!vsk->transport) 161 + return -ENODEV; 162 + 163 + if (!vsk->transport->read_skb) 164 + return -EOPNOTSUPP; 165 + 166 + vsock_bpf_check_needs_rebuild(psock->sk_proto); 167 + sock_replace_proto(sk, &vsock_bpf_prot); 168 + return 0; 169 + } 170 + 171 + void __init vsock_bpf_build_proto(void) 172 + { 173 + vsock_bpf_rebuild_protos(&vsock_bpf_prot, &vsock_proto); 174 + }
+2
net/vmw_vsock/vsock_loopback.c
··· 94 94 .notify_send_pre_enqueue = virtio_transport_notify_send_pre_enqueue, 95 95 .notify_send_post_enqueue = virtio_transport_notify_send_post_enqueue, 96 96 .notify_buffer_size = virtio_transport_notify_buffer_size, 97 + 98 + .read_skb = virtio_transport_read_skb, 97 99 }, 98 100 99 101 .send_pkt = vsock_loopback_send_pkt,