Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

af_unix: Add unix_stream_proto for sockmap

Previously, sockmap for the AF_UNIX protocol only supported the
dgram type. This patch adds unix stream type support, which
is similar to unix_dgram_proto. To support sockmap, dgram
and stream cannot share the same unix_proto anymore, because
they have different implementations, such as unhash for the stream
type (which will remove closed or disconnected sockets from the map).
So rename unix_proto to unix_dgram_proto and add a new
unix_stream_proto.

Also implement the stream-related sockmap functions,
and add the dgram keyword to the names of the dgram-specific functions.

Signed-off-by: Jiang Wang <jiang.wang@bytedance.com>
Signed-off-by: Andrii Nakryiko <andrii@kernel.org>
Reviewed-by: Cong Wang <cong.wang@bytedance.com>
Acked-by: Jakub Sitnicki <jakub@cloudflare.com>
Acked-by: John Fastabend <john.fastabend@gmail.com>
Link: https://lore.kernel.org/bpf/20210816190327.2739291-3-jiang.wang@bytedance.com

authored by

Jiang Wang and committed by
Andrii Nakryiko
94531cfc 77462de1

+148 -37
+6 -2
include/net/af_unix.h
··· 87 87 88 88 int __unix_dgram_recvmsg(struct sock *sk, struct msghdr *msg, size_t size, 89 89 int flags); 90 + int __unix_stream_recvmsg(struct sock *sk, struct msghdr *msg, size_t size, 91 + int flags); 90 92 #ifdef CONFIG_SYSCTL 91 93 int unix_sysctl_register(struct net *net); 92 94 void unix_sysctl_unregister(struct net *net); ··· 98 96 #endif 99 97 100 98 #ifdef CONFIG_BPF_SYSCALL 101 - extern struct proto unix_proto; 99 + extern struct proto unix_dgram_proto; 100 + extern struct proto unix_stream_proto; 102 101 103 - int unix_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore); 102 + int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore); 103 + int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore); 104 104 void __init unix_bpf_build_proto(void); 105 105 #else 106 106 static inline void __init unix_bpf_build_proto(void)
+1
net/core/sock_map.c
··· 1494 1494 rcu_read_unlock(); 1495 1495 saved_unhash(sk); 1496 1496 } 1497 + EXPORT_SYMBOL_GPL(sock_map_unhash); 1497 1498 1498 1499 void sock_map_close(struct sock *sk, long timeout) 1499 1500 {
+70 -13
net/unix/af_unix.c
··· 798 798 */ 799 799 } 800 800 801 - struct proto unix_proto = { 802 - .name = "UNIX", 801 + static void unix_unhash(struct sock *sk) 802 + { 803 + /* Nothing to do here, unix socket does not need a ->unhash(). 804 + * This is merely for sockmap. 805 + */ 806 + } 807 + 808 + struct proto unix_dgram_proto = { 809 + .name = "UNIX-DGRAM", 803 810 .owner = THIS_MODULE, 804 811 .obj_size = sizeof(struct unix_sock), 805 812 .close = unix_close, 806 813 #ifdef CONFIG_BPF_SYSCALL 807 - .psock_update_sk_prot = unix_bpf_update_proto, 814 + .psock_update_sk_prot = unix_dgram_bpf_update_proto, 808 815 #endif 809 816 }; 810 817 811 - static struct sock *unix_create1(struct net *net, struct socket *sock, int kern) 818 + struct proto unix_stream_proto = { 819 + .name = "UNIX-STREAM", 820 + .owner = THIS_MODULE, 821 + .obj_size = sizeof(struct unix_sock), 822 + .close = unix_close, 823 + .unhash = unix_unhash, 824 + #ifdef CONFIG_BPF_SYSCALL 825 + .psock_update_sk_prot = unix_stream_bpf_update_proto, 826 + #endif 827 + }; 828 + 829 + static struct sock *unix_create1(struct net *net, struct socket *sock, int kern, int type) 812 830 { 813 831 struct sock *sk = NULL; 814 832 struct unix_sock *u; ··· 835 817 if (atomic_long_read(&unix_nr_socks) > 2 * get_max_files()) 836 818 goto out; 837 819 838 - sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_proto, kern); 820 + if (type == SOCK_STREAM) 821 + sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_stream_proto, kern); 822 + else /*dgram and seqpacket */ 823 + sk = sk_alloc(net, PF_UNIX, GFP_KERNEL, &unix_dgram_proto, kern); 824 + 839 825 if (!sk) 840 826 goto out; 841 827 ··· 901 879 return -ESOCKTNOSUPPORT; 902 880 } 903 881 904 - return unix_create1(net, sock, kern) ? 0 : -ENOMEM; 882 + return unix_create1(net, sock, kern, sock->type) ? 
0 : -ENOMEM; 905 883 } 906 884 907 885 static int unix_release(struct socket *sock) ··· 1315 1293 err = -ENOMEM; 1316 1294 1317 1295 /* create new sock for complete connection */ 1318 - newsk = unix_create1(sock_net(sk), NULL, 0); 1296 + newsk = unix_create1(sock_net(sk), NULL, 0, sock->type); 1319 1297 if (newsk == NULL) 1320 1298 goto out; 1321 1299 ··· 2345 2323 struct sock *sk = sock->sk; 2346 2324 2347 2325 #ifdef CONFIG_BPF_SYSCALL 2348 - if (sk->sk_prot != &unix_proto) 2349 - return sk->sk_prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT, 2326 + const struct proto *prot = READ_ONCE(sk->sk_prot); 2327 + 2328 + if (prot != &unix_dgram_proto) 2329 + return prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT, 2350 2330 flags & ~MSG_DONTWAIT, NULL); 2351 2331 #endif 2352 2332 return __unix_dgram_recvmsg(sk, msg, size, flags); ··· 2752 2728 return ret ?: chunk; 2753 2729 } 2754 2730 2731 + int __unix_stream_recvmsg(struct sock *sk, struct msghdr *msg, 2732 + size_t size, int flags) 2733 + { 2734 + struct unix_stream_read_state state = { 2735 + .recv_actor = unix_stream_read_actor, 2736 + .socket = sk->sk_socket, 2737 + .msg = msg, 2738 + .size = size, 2739 + .flags = flags 2740 + }; 2741 + 2742 + return unix_stream_read_generic(&state, true); 2743 + } 2744 + 2755 2745 static int unix_stream_recvmsg(struct socket *sock, struct msghdr *msg, 2756 2746 size_t size, int flags) 2757 2747 { ··· 2777 2739 .flags = flags 2778 2740 }; 2779 2741 2742 + #ifdef CONFIG_BPF_SYSCALL 2743 + struct sock *sk = sock->sk; 2744 + const struct proto *prot = READ_ONCE(sk->sk_prot); 2745 + 2746 + if (prot != &unix_stream_proto) 2747 + return prot->recvmsg(sk, msg, size, flags & MSG_DONTWAIT, 2748 + flags & ~MSG_DONTWAIT, NULL); 2749 + #endif 2780 2750 return unix_stream_read_generic(&state, true); 2781 2751 } 2782 2752 ··· 2845 2799 (sk->sk_type == SOCK_STREAM || sk->sk_type == SOCK_SEQPACKET)) { 2846 2800 2847 2801 int peer_mode = 0; 2802 + const struct proto *prot = 
READ_ONCE(other->sk_prot); 2848 2803 2804 + prot->unhash(other); 2849 2805 if (mode&RCV_SHUTDOWN) 2850 2806 peer_mode |= SEND_SHUTDOWN; 2851 2807 if (mode&SEND_SHUTDOWN) ··· 2856 2808 other->sk_shutdown |= peer_mode; 2857 2809 unix_state_unlock(other); 2858 2810 other->sk_state_change(other); 2859 - if (peer_mode == SHUTDOWN_MASK) 2811 + if (peer_mode == SHUTDOWN_MASK) { 2860 2812 sk_wake_async(other, SOCK_WAKE_WAITD, POLL_HUP); 2861 - else if (peer_mode & RCV_SHUTDOWN) 2813 + other->sk_state = TCP_CLOSE; 2814 + } else if (peer_mode & RCV_SHUTDOWN) { 2862 2815 sk_wake_async(other, SOCK_WAKE_WAITD, POLL_IN); 2816 + } 2863 2817 } 2864 2818 if (other) 2865 2819 sock_put(other); ··· 3339 3289 3340 3290 BUILD_BUG_ON(sizeof(struct unix_skb_parms) > sizeof_field(struct sk_buff, cb)); 3341 3291 3342 - rc = proto_register(&unix_proto, 1); 3292 + rc = proto_register(&unix_dgram_proto, 1); 3293 + if (rc != 0) { 3294 + pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__); 3295 + goto out; 3296 + } 3297 + 3298 + rc = proto_register(&unix_stream_proto, 1); 3343 3299 if (rc != 0) { 3344 3300 pr_crit("%s: Cannot create unix_sock SLAB cache!\n", __func__); 3345 3301 goto out; ··· 3366 3310 static void __exit af_unix_exit(void) 3367 3311 { 3368 3312 sock_unregister(PF_UNIX); 3369 - proto_unregister(&unix_proto); 3313 + proto_unregister(&unix_dgram_proto); 3314 + proto_unregister(&unix_stream_proto); 3370 3315 unregister_pernet_subsys(&unix_net_ops); 3371 3316 } 3372 3317
+71 -22
net/unix/unix_bpf.c
··· 38 38 return ret; 39 39 } 40 40 41 - static int unix_dgram_bpf_recvmsg(struct sock *sk, struct msghdr *msg, 42 - size_t len, int nonblock, int flags, 43 - int *addr_len) 41 + static int __unix_recvmsg(struct sock *sk, struct msghdr *msg, 42 + size_t len, int flags) 43 + { 44 + if (sk->sk_type == SOCK_DGRAM) 45 + return __unix_dgram_recvmsg(sk, msg, len, flags); 46 + else 47 + return __unix_stream_recvmsg(sk, msg, len, flags); 48 + } 49 + 50 + static int unix_bpf_recvmsg(struct sock *sk, struct msghdr *msg, 51 + size_t len, int nonblock, int flags, 52 + int *addr_len) 44 53 { 45 54 struct unix_sock *u = unix_sk(sk); 46 55 struct sk_psock *psock; ··· 57 48 58 49 psock = sk_psock_get(sk); 59 50 if (unlikely(!psock)) 60 - return __unix_dgram_recvmsg(sk, msg, len, flags); 51 + return __unix_recvmsg(sk, msg, len, flags); 61 52 62 53 mutex_lock(&u->iolock); 63 54 if (!skb_queue_empty(&sk->sk_receive_queue) && 64 55 sk_psock_queue_empty(psock)) { 65 56 mutex_unlock(&u->iolock); 66 57 sk_psock_put(sk, psock); 67 - return __unix_dgram_recvmsg(sk, msg, len, flags); 58 + return __unix_recvmsg(sk, msg, len, flags); 68 59 } 69 60 70 61 msg_bytes_ready: ··· 80 71 goto msg_bytes_ready; 81 72 mutex_unlock(&u->iolock); 82 73 sk_psock_put(sk, psock); 83 - return __unix_dgram_recvmsg(sk, msg, len, flags); 74 + return __unix_recvmsg(sk, msg, len, flags); 84 75 } 85 76 copied = -EAGAIN; 86 77 } ··· 89 80 return copied; 90 81 } 91 82 92 - static struct proto *unix_prot_saved __read_mostly; 93 - static DEFINE_SPINLOCK(unix_prot_lock); 94 - static struct proto unix_bpf_prot; 83 + static struct proto *unix_dgram_prot_saved __read_mostly; 84 + static DEFINE_SPINLOCK(unix_dgram_prot_lock); 85 + static struct proto unix_dgram_bpf_prot; 95 86 96 - static void unix_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 87 + static struct proto *unix_stream_prot_saved __read_mostly; 88 + static DEFINE_SPINLOCK(unix_stream_prot_lock); 89 + static struct proto unix_stream_bpf_prot; 90 
+ 91 + static void unix_dgram_bpf_rebuild_protos(struct proto *prot, const struct proto *base) 97 92 { 98 93 *prot = *base; 99 94 prot->close = sock_map_close; 100 - prot->recvmsg = unix_dgram_bpf_recvmsg; 95 + prot->recvmsg = unix_bpf_recvmsg; 101 96 } 102 97 103 - static void unix_bpf_check_needs_rebuild(struct proto *ops) 98 + static void unix_stream_bpf_rebuild_protos(struct proto *prot, 99 + const struct proto *base) 104 100 { 105 - if (unlikely(ops != smp_load_acquire(&unix_prot_saved))) { 106 - spin_lock_bh(&unix_prot_lock); 107 - if (likely(ops != unix_prot_saved)) { 108 - unix_bpf_rebuild_protos(&unix_bpf_prot, ops); 109 - smp_store_release(&unix_prot_saved, ops); 101 + *prot = *base; 102 + prot->close = sock_map_close; 103 + prot->recvmsg = unix_bpf_recvmsg; 104 + prot->unhash = sock_map_unhash; 105 + } 106 + 107 + static void unix_dgram_bpf_check_needs_rebuild(struct proto *ops) 108 + { 109 + if (unlikely(ops != smp_load_acquire(&unix_dgram_prot_saved))) { 110 + spin_lock_bh(&unix_dgram_prot_lock); 111 + if (likely(ops != unix_dgram_prot_saved)) { 112 + unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, ops); 113 + smp_store_release(&unix_dgram_prot_saved, ops); 110 114 } 111 - spin_unlock_bh(&unix_prot_lock); 115 + spin_unlock_bh(&unix_dgram_prot_lock); 112 116 } 113 117 } 114 118 115 - int unix_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 119 + static void unix_stream_bpf_check_needs_rebuild(struct proto *ops) 120 + { 121 + if (unlikely(ops != smp_load_acquire(&unix_stream_prot_saved))) { 122 + spin_lock_bh(&unix_stream_prot_lock); 123 + if (likely(ops != unix_stream_prot_saved)) { 124 + unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, ops); 125 + smp_store_release(&unix_stream_prot_saved, ops); 126 + } 127 + spin_unlock_bh(&unix_stream_prot_lock); 128 + } 129 + } 130 + 131 + int unix_dgram_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 116 132 { 117 133 if (sk->sk_type != SOCK_DGRAM) 118 134 
return -EOPNOTSUPP; ··· 148 114 return 0; 149 115 } 150 116 151 - unix_bpf_check_needs_rebuild(psock->sk_proto); 152 - WRITE_ONCE(sk->sk_prot, &unix_bpf_prot); 117 + unix_dgram_bpf_check_needs_rebuild(psock->sk_proto); 118 + WRITE_ONCE(sk->sk_prot, &unix_dgram_bpf_prot); 119 + return 0; 120 + } 121 + 122 + int unix_stream_bpf_update_proto(struct sock *sk, struct sk_psock *psock, bool restore) 123 + { 124 + if (restore) { 125 + sk->sk_write_space = psock->saved_write_space; 126 + WRITE_ONCE(sk->sk_prot, psock->sk_proto); 127 + return 0; 128 + } 129 + 130 + unix_stream_bpf_check_needs_rebuild(psock->sk_proto); 131 + WRITE_ONCE(sk->sk_prot, &unix_stream_bpf_prot); 153 132 return 0; 154 133 } 155 134 156 135 void __init unix_bpf_build_proto(void) 157 136 { 158 - unix_bpf_rebuild_protos(&unix_bpf_prot, &unix_proto); 137 + unix_dgram_bpf_rebuild_protos(&unix_dgram_bpf_prot, &unix_dgram_proto); 138 + unix_stream_bpf_rebuild_protos(&unix_stream_bpf_prot, &unix_stream_proto); 139 + 159 140 }