Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

net: Track socket refcounts in skb_steal_sock()

Refactor the UDP/TCP handlers slightly to allow skb_steal_sock() to make
the determination of whether the socket is reference counted in the case
where it is prefetched by earlier logic such as early_demux.

Signed-off-by: Joe Stringer <joe@wand.net.nz>
Signed-off-by: Alexei Starovoitov <ast@kernel.org>
Acked-by: Martin KaFai Lau <kafai@fb.com>
Link: https://lore.kernel.org/bpf/20200329225342.16317-3-joe@wand.net.nz

Authored by Joe Stringer and committed by Alexei Starovoitov.
71489e21 cf7fbe66

+21 -10
+1 -2
include/net/inet6_hashtables.h
··· 85 85 int iif, int sdif, 86 86 bool *refcounted) 87 87 { 88 - struct sock *sk = skb_steal_sock(skb); 88 + struct sock *sk = skb_steal_sock(skb, refcounted); 89 89 90 - *refcounted = true; 91 90 if (sk) 92 91 return sk; 93 92
+1 -2
include/net/inet_hashtables.h
··· 379 379 const int sdif, 380 380 bool *refcounted) 381 381 { 382 - struct sock *sk = skb_steal_sock(skb); 382 + struct sock *sk = skb_steal_sock(skb, refcounted); 383 383 const struct iphdr *iph = ip_hdr(skb); 384 384 385 - *refcounted = true; 386 385 if (sk) 387 386 return sk; 388 387
+9 -1
include/net/sock.h
··· 2537 2537 #endif /* CONFIG_INET */ 2538 2538 } 2539 2539 2540 - static inline struct sock *skb_steal_sock(struct sk_buff *skb) 2540 + /** 2541 + * skb_steal_sock 2542 + * @skb to steal the socket from 2543 + * @refcounted is set to true if the socket is reference-counted 2544 + */ 2545 + static inline struct sock * 2546 + skb_steal_sock(struct sk_buff *skb, bool *refcounted) 2541 2547 { 2542 2548 if (skb->sk) { 2543 2549 struct sock *sk = skb->sk; 2544 2550 2551 + *refcounted = true; 2545 2552 skb->destructor = NULL; 2546 2553 skb->sk = NULL; 2547 2554 return sk; 2548 2555 } 2556 + *refcounted = false; 2549 2557 return NULL; 2550 2558 } 2551 2559
+4 -2
net/ipv4/udp.c
··· 2288 2288 struct rtable *rt = skb_rtable(skb); 2289 2289 __be32 saddr, daddr; 2290 2290 struct net *net = dev_net(skb->dev); 2291 + bool refcounted; 2291 2292 2292 2293 /* 2293 2294 * Validate the packet. ··· 2314 2313 if (udp4_csum_init(skb, uh, proto)) 2315 2314 goto csum_error; 2316 2315 2317 - sk = skb_steal_sock(skb); 2316 + sk = skb_steal_sock(skb, &refcounted); 2318 2317 if (sk) { 2319 2318 struct dst_entry *dst = skb_dst(skb); 2320 2319 int ret; ··· 2323 2322 udp_sk_rx_dst_set(sk, dst); 2324 2323 2325 2324 ret = udp_unicast_rcv_skb(sk, skb, uh); 2326 - sock_put(sk); 2325 + if (refcounted) 2326 + sock_put(sk); 2327 2327 return ret; 2328 2328 } 2329 2329
+6 -3
net/ipv6/udp.c
··· 843 843 struct net *net = dev_net(skb->dev); 844 844 struct udphdr *uh; 845 845 struct sock *sk; 846 + bool refcounted; 846 847 u32 ulen = 0; 847 848 848 849 if (!pskb_may_pull(skb, sizeof(struct udphdr))) ··· 880 879 goto csum_error; 881 880 882 881 /* Check if the socket is already available, e.g. due to early demux */ 883 - sk = skb_steal_sock(skb); 882 + sk = skb_steal_sock(skb, &refcounted); 884 883 if (sk) { 885 884 struct dst_entry *dst = skb_dst(skb); 886 885 int ret; ··· 889 888 udp6_sk_rx_dst_set(sk, dst); 890 889 891 890 if (!uh->check && !udp_sk(sk)->no_check6_rx) { 892 - sock_put(sk); 891 + if (refcounted) 892 + sock_put(sk); 893 893 goto report_csum_error; 894 894 } 895 895 896 896 ret = udp6_unicast_rcv_skb(sk, skb, uh); 897 - sock_put(sk); 897 + if (refcounted) 898 + sock_put(sk); 898 899 return ret; 899 900 } 900 901