Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

Merge branch 'net-fix-uaf-of-sk_dst_get-sk-dev'

Kuniyuki Iwashima says:

====================
net: Fix UAF of sk_dst_get(sk)->dev.

syzbot caught use-after-free of sk_dst_get(sk)->dev,
which was not fetched under RCU nor RTNL. [0]

Patch 1 ~ 5, 7 fix UAF in smc, tcp, ktls, mptcp
Patch 6 fixes dst ref leak in mptcp

[0]: https://lore.kernel.org/68c237c7.050a0220.3c6139.0036.GAE@google.com

v1: https://lore.kernel.org/20250911030620.1284754-1-kuniyu@google.com
====================

Link: https://patch.msgid.link/20250916214758.650211-1-kuniyu@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>

+86 -78
+7 -2
net/mptcp/ctrl.c
··· 501 501 struct mptcp_pernet *pernet = mptcp_get_pernet(sock_net(sk)); 502 502 503 503 if (atomic_read(&pernet->active_disable_times)) { 504 - struct dst_entry *dst = sk_dst_get(sk); 504 + struct net_device *dev; 505 + struct dst_entry *dst; 505 506 506 - if (dst && dst->dev && (dst->dev->flags & IFF_LOOPBACK)) 507 + rcu_read_lock(); 508 + dst = __sk_dst_get(sk); 509 + dev = dst ? dst_dev_rcu(dst) : NULL; 510 + if (dev && (dev->flags & IFF_LOOPBACK)) 507 511 atomic_set(&pernet->active_disable_times, 0); 512 + rcu_read_unlock(); 508 513 } 509 514 } 510 515
+35 -32
net/smc/smc_clc.c
··· 509 509 } 510 510 511 511 /* find ipv4 addr on device and get the prefix len, fill CLC proposal msg */ 512 - static int smc_clc_prfx_set4_rcu(struct dst_entry *dst, __be32 ipv4, 512 + static int smc_clc_prfx_set4_rcu(struct net_device *dev, __be32 ipv4, 513 513 struct smc_clc_msg_proposal_prefix *prop) 514 514 { 515 - struct in_device *in_dev = __in_dev_get_rcu(dst->dev); 515 + struct in_device *in_dev = __in_dev_get_rcu(dev); 516 516 const struct in_ifaddr *ifa; 517 517 518 518 if (!in_dev) ··· 530 530 } 531 531 532 532 /* fill CLC proposal msg with ipv6 prefixes from device */ 533 - static int smc_clc_prfx_set6_rcu(struct dst_entry *dst, 533 + static int smc_clc_prfx_set6_rcu(struct net_device *dev, 534 534 struct smc_clc_msg_proposal_prefix *prop, 535 535 struct smc_clc_ipv6_prefix *ipv6_prfx) 536 536 { 537 537 #if IS_ENABLED(CONFIG_IPV6) 538 - struct inet6_dev *in6_dev = __in6_dev_get(dst->dev); 538 + struct inet6_dev *in6_dev = __in6_dev_get(dev); 539 539 struct inet6_ifaddr *ifa; 540 540 int cnt = 0; 541 541 ··· 564 564 struct smc_clc_msg_proposal_prefix *prop, 565 565 struct smc_clc_ipv6_prefix *ipv6_prfx) 566 566 { 567 - struct dst_entry *dst = sk_dst_get(clcsock->sk); 568 567 struct sockaddr_storage addrs; 569 568 struct sockaddr_in6 *addr6; 570 569 struct sockaddr_in *addr; 570 + struct net_device *dev; 571 + struct dst_entry *dst; 571 572 int rc = -ENOENT; 572 573 573 - if (!dst) { 574 - rc = -ENOTCONN; 575 - goto out; 576 - } 577 - if (!dst->dev) { 578 - rc = -ENODEV; 579 - goto out_rel; 580 - } 581 574 /* get address to which the internal TCP socket is bound */ 582 575 if (kernel_getsockname(clcsock, (struct sockaddr *)&addrs) < 0) 583 - goto out_rel; 576 + goto out; 577 + 584 578 /* analyze IP specific data of net_device belonging to TCP socket */ 585 579 addr6 = (struct sockaddr_in6 *)&addrs; 580 + 586 581 rcu_read_lock(); 582 + 583 + dst = __sk_dst_get(clcsock->sk); 584 + dev = dst ? dst_dev_rcu(dst) : NULL; 585 + if (!dev) { 586 + rc = -ENODEV; 587 + goto out_unlock; 588 + } 589 + 587 590 if (addrs.ss_family == PF_INET) { 588 591 /* IPv4 */ 589 592 addr = (struct sockaddr_in *)&addrs; 590 - rc = smc_clc_prfx_set4_rcu(dst, addr->sin_addr.s_addr, prop); 593 + rc = smc_clc_prfx_set4_rcu(dev, addr->sin_addr.s_addr, prop); 591 594 } else if (ipv6_addr_v4mapped(&addr6->sin6_addr)) { 592 595 /* mapped IPv4 address - peer is IPv4 only */ 593 - rc = smc_clc_prfx_set4_rcu(dst, addr6->sin6_addr.s6_addr32[3], 596 + rc = smc_clc_prfx_set4_rcu(dev, addr6->sin6_addr.s6_addr32[3], 594 597 prop); 595 598 } else { 596 599 /* IPv6 */ 597 - rc = smc_clc_prfx_set6_rcu(dst, prop, ipv6_prfx); 600 + rc = smc_clc_prfx_set6_rcu(dev, prop, ipv6_prfx); 598 601 } 602 + 603 + out_unlock: 599 604 rcu_read_unlock(); 600 - out_rel: 601 - dst_release(dst); 602 605 out: 603 606 return rc; 604 607 } ··· 657 654 int smc_clc_prfx_match(struct socket *clcsock, 658 655 struct smc_clc_msg_proposal_prefix *prop) 659 656 { 660 - struct dst_entry *dst = sk_dst_get(clcsock->sk); 657 + struct net_device *dev; 658 + struct dst_entry *dst; 661 659 int rc; 662 660 663 - if (!dst) { 664 - rc = -ENOTCONN; 661 + rcu_read_lock(); 662 + 663 + dst = __sk_dst_get(clcsock->sk); 664 + dev = dst ? dst_dev_rcu(dst) : NULL; 665 + if (!dev) { 666 + rc = -ENODEV; 665 667 goto out; 666 668 } 667 - if (!dst->dev) { 668 - rc = -ENODEV; 669 - goto out_rel; 670 - } 671 - rcu_read_lock(); 669 + 672 670 if (!prop->ipv6_prefixes_cnt) 673 - rc = smc_clc_prfx_match4_rcu(dst->dev, prop); 671 + rc = smc_clc_prfx_match4_rcu(dev, prop); 674 672 else 675 - rc = smc_clc_prfx_match6_rcu(dst->dev, prop); 676 - rcu_read_unlock(); 677 - out_rel: 678 - dst_release(dst); 673 + rc = smc_clc_prfx_match6_rcu(dev, prop); 679 674 out: 675 + rcu_read_unlock(); 676 + 680 677 return rc; 681 678 } 682 679
+12 -15
net/smc/smc_core.c
··· 1883 1883 /* Determine vlan of internal TCP socket. */ 1884 1884 int smc_vlan_by_tcpsk(struct socket *clcsock, struct smc_init_info *ini) 1885 1885 { 1886 - struct dst_entry *dst = sk_dst_get(clcsock->sk); 1887 1886 struct netdev_nested_priv priv; 1888 1887 struct net_device *ndev; 1888 + struct dst_entry *dst; 1889 1889 int rc = 0; 1890 1890 1891 1891 ini->vlan_id = 0; 1892 - if (!dst) { 1893 - rc = -ENOTCONN; 1892 + 1893 + rcu_read_lock(); 1894 + 1895 + dst = __sk_dst_get(clcsock->sk); 1896 + ndev = dst ? dst_dev_rcu(dst) : NULL; 1897 + if (!ndev) { 1898 + rc = -ENODEV; 1894 1899 goto out; 1895 1900 } 1896 - if (!dst->dev) { 1897 - rc = -ENODEV; 1898 - goto out_rel; 1899 - } 1900 1901 1901 - ndev = dst->dev; 1902 1902 if (is_vlan_dev(ndev)) { 1903 1903 ini->vlan_id = vlan_dev_vlan_id(ndev); 1904 - goto out_rel; 1904 + goto out; 1905 1905 } 1906 1906 1907 1907 priv.data = (void *)&ini->vlan_id; 1908 - rtnl_lock(); 1909 - netdev_walk_all_lower_dev(ndev, smc_vlan_by_tcpsk_walk, &priv); 1910 - rtnl_unlock(); 1911 - 1912 - out_rel: 1913 - dst_release(dst); 1908 + netdev_walk_all_lower_dev_rcu(ndev, smc_vlan_by_tcpsk_walk, &priv); 1914 1909 out: 1910 + rcu_read_unlock(); 1911 + 1915 1912 return rc; 1916 1913 } 1917 1914
+22 -21
net/smc/smc_pnet.c
··· 1126 1126 */ 1127 1127 void smc_pnet_find_roce_resource(struct sock *sk, struct smc_init_info *ini) 1128 1128 { 1129 - struct dst_entry *dst = sk_dst_get(sk); 1129 + struct net_device *dev; 1130 + struct dst_entry *dst; 1130 1131 1131 - if (!dst) 1132 - goto out; 1133 - if (!dst->dev) 1134 - goto out_rel; 1132 + rcu_read_lock(); 1133 + dst = __sk_dst_get(sk); 1134 + dev = dst ? dst_dev_rcu(dst) : NULL; 1135 + dev_hold(dev); 1136 + rcu_read_unlock(); 1135 1137 1136 - smc_pnet_find_roce_by_pnetid(dst->dev, ini); 1137 - 1138 - out_rel: 1139 - dst_release(dst); 1140 - out: 1141 - return; 1138 + if (dev) { 1139 + smc_pnet_find_roce_by_pnetid(dev, ini); 1140 + dev_put(dev); 1141 + } 1142 1142 } 1143 1143 1144 1144 void smc_pnet_find_ism_resource(struct sock *sk, struct smc_init_info *ini) 1145 1145 { 1146 - struct dst_entry *dst = sk_dst_get(sk); 1146 + struct net_device *dev; 1147 + struct dst_entry *dst; 1147 1148 1148 1149 ini->ism_dev[0] = NULL; 1149 - if (!dst) 1150 - goto out; 1151 - if (!dst->dev) 1152 - goto out_rel; 1153 1150 1154 - smc_pnet_find_ism_by_pnetid(dst->dev, ini); 1151 + rcu_read_lock(); 1152 + dst = __sk_dst_get(sk); 1153 + dev = dst ? dst_dev_rcu(dst) : NULL; 1154 + dev_hold(dev); 1155 + rcu_read_unlock(); 1155 1156 1156 - out_rel: 1157 - dst_release(dst); 1158 - out: 1159 - return; 1157 + if (dev) { 1158 + smc_pnet_find_ism_by_pnetid(dev, ini); 1159 + dev_put(dev); 1160 + } 1160 1161 } 1161 1162 1162 1163 /* Lookup and apply a pnet table entry to the given ib device.
+10 -8
net/tls/tls_device.c
··· 123 123 /* We assume that the socket is already connected */ 124 124 static struct net_device *get_netdev_for_sock(struct sock *sk) 125 125 { 126 - struct dst_entry *dst = sk_dst_get(sk); 127 - struct net_device *netdev = NULL; 126 + struct net_device *dev, *lowest_dev = NULL; 127 + struct dst_entry *dst; 128 128 129 - if (likely(dst)) { 130 - netdev = netdev_sk_get_lowest_dev(dst->dev, sk); 131 - dev_hold(netdev); 129 + rcu_read_lock(); 130 + dst = __sk_dst_get(sk); 131 + dev = dst ? dst_dev_rcu(dst) : NULL; 132 + if (likely(dev)) { 133 + lowest_dev = netdev_sk_get_lowest_dev(dev, sk); 134 + dev_hold(lowest_dev); 132 135 } 136 + rcu_read_unlock(); 133 137 134 - dst_release(dst); 135 - 136 - return netdev; 138 + return lowest_dev; 137 139 } 138 140 139 141 static void destroy_record(struct tls_record_info *record)