Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

Merge branch 'mpls-remove-rtnl-dependency'

Kuniyuki Iwashima says:

====================
mpls: Remove RTNL dependency.

MPLS uses RTNL for two purposes:

1) to guarantee the lifetime of struct mpls_nh.nh_dev
2) to protect net->mpls.platform_label

However, neither actually requires RTNL.

If struct mpls_nh holds a refcnt for nh_dev, RTNL is no longer needed
for 1), and for 2) RTNL can be replaced with a dedicated mutex.

The series removes RTNL from net/mpls/.

Overview:

Patch 1 is misc cleanup.

Patches 2 - 9 are preparation for dropping RTNL from the
RTM_{NEW,DEL,GET}ROUTE handlers.

Patches 10 & 11 convert mpls_dump_routes() and RTM_GETNETCONF to RCU.

Patch 12 replaces RTNL with a new per-netns mutex.

Patch 13 drops RTNL from RTM_{NEW,DEL,GET}ROUTE.
====================

Link: https://patch.msgid.link/20251029173344.2934622-1-kuniyu@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>

+224 -128
+5
include/net/addrconf.h
··· 347 347 return rcu_dereference_rtnl(dev->ip6_ptr); 348 348 } 349 349 350 + static inline struct inet6_dev *in6_dev_rcu(const struct net_device *dev) 351 + { 352 + return rcu_dereference(dev->ip6_ptr); 353 + } 354 + 350 355 static inline struct inet6_dev *__in6_dev_get_rtnl_net(const struct net_device *dev) 351 356 { 352 357 return rtnl_net_dereference(dev_net(dev), dev->ip6_ptr);
+1
include/net/netns/mpls.h
··· 16 16 int default_ttl; 17 17 size_t platform_labels; 18 18 struct mpls_route __rcu * __rcu *platform_label; 19 + struct mutex platform_mutex; 19 20 20 21 struct ctl_table_header *ctl; 21 22 };
+199 -122
net/mpls/af_mpls.c
··· 75 75 struct nlmsghdr *nlh, struct net *net, u32 portid, 76 76 unsigned int nlm_flags); 77 77 78 - static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned index) 78 + static struct mpls_route *mpls_route_input(struct net *net, unsigned int index) 79 79 { 80 - struct mpls_route *rt = NULL; 80 + struct mpls_route __rcu **platform_label; 81 81 82 - if (index < net->mpls.platform_labels) { 83 - struct mpls_route __rcu **platform_label = 84 - rcu_dereference_rtnl(net->mpls.platform_label); 85 - rt = rcu_dereference_rtnl(platform_label[index]); 86 - } 87 - return rt; 82 + platform_label = mpls_dereference(net, net->mpls.platform_label); 83 + return mpls_dereference(net, platform_label[index]); 84 + } 85 + 86 + static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned int index) 87 + { 88 + struct mpls_route __rcu **platform_label; 89 + 90 + if (index >= net->mpls.platform_labels) 91 + return NULL; 92 + 93 + platform_label = rcu_dereference(net->mpls.platform_label); 94 + return rcu_dereference(platform_label[index]); 88 95 } 89 96 90 97 bool mpls_output_possible(const struct net_device *dev) ··· 136 129 } 137 130 EXPORT_SYMBOL_GPL(mpls_pkt_too_big); 138 131 139 - void mpls_stats_inc_outucastpkts(struct net_device *dev, 132 + void mpls_stats_inc_outucastpkts(struct net *net, 133 + struct net_device *dev, 140 134 const struct sk_buff *skb) 141 135 { 142 136 struct mpls_dev *mdev; 143 137 144 138 if (skb->protocol == htons(ETH_P_MPLS_UC)) { 145 - mdev = mpls_dev_get(dev); 139 + mdev = mpls_dev_rcu(dev); 146 140 if (mdev) 147 141 MPLS_INC_STATS_LEN(mdev, skb->len, 148 142 tx_packets, 149 143 tx_bytes); 150 144 } else if (skb->protocol == htons(ETH_P_IP)) { 151 - IP_UPD_PO_STATS(dev_net(dev), IPSTATS_MIB_OUT, skb->len); 145 + IP_UPD_PO_STATS(net, IPSTATS_MIB_OUT, skb->len); 152 146 #if IS_ENABLED(CONFIG_IPV6) 153 147 } else if (skb->protocol == htons(ETH_P_IPV6)) { 154 - struct inet6_dev *in6dev = __in6_dev_get(dev); 148 + struct inet6_dev 
*in6dev = in6_dev_rcu(dev); 155 149 156 150 if (in6dev) 157 - IP6_UPD_PO_STATS(dev_net(dev), in6dev, 151 + IP6_UPD_PO_STATS(net, in6dev, 158 152 IPSTATS_MIB_OUT, skb->len); 159 153 #endif 160 154 } ··· 350 342 static int mpls_forward(struct sk_buff *skb, struct net_device *dev, 351 343 struct packet_type *pt, struct net_device *orig_dev) 352 344 { 353 - struct net *net = dev_net(dev); 345 + struct net *net = dev_net_rcu(dev); 354 346 struct mpls_shim_hdr *hdr; 355 347 const struct mpls_nh *nh; 356 348 struct mpls_route *rt; ··· 365 357 366 358 /* Careful this entire function runs inside of an rcu critical section */ 367 359 368 - mdev = mpls_dev_get(dev); 360 + mdev = mpls_dev_rcu(dev); 369 361 if (!mdev) 370 362 goto drop; 371 363 ··· 442 434 dec.ttl -= 1; 443 435 if (unlikely(!new_header_size && dec.bos)) { 444 436 /* Penultimate hop popping */ 445 - if (!mpls_egress(dev_net(out_dev), rt, skb, dec)) 437 + if (!mpls_egress(net, rt, skb, dec)) 446 438 goto err; 447 439 } else { 448 440 bool bos; ··· 459 451 } 460 452 } 461 453 462 - mpls_stats_inc_outucastpkts(out_dev, skb); 454 + mpls_stats_inc_outucastpkts(net, out_dev, skb); 463 455 464 456 /* If via wasn't specified then send out using device address */ 465 457 if (nh->nh_via_table == MPLS_NEIGH_TABLE_UNSPEC) ··· 474 466 return 0; 475 467 476 468 tx_err: 477 - out_mdev = out_dev ? mpls_dev_get(out_dev) : NULL; 469 + out_mdev = out_dev ? 
mpls_dev_rcu(out_dev) : NULL; 478 470 if (out_mdev) 479 471 MPLS_INC_STATS(out_mdev, tx_errors); 480 472 goto drop; ··· 538 530 return rt; 539 531 } 540 532 533 + static void mpls_rt_free_rcu(struct rcu_head *head) 534 + { 535 + struct mpls_route *rt; 536 + 537 + rt = container_of(head, struct mpls_route, rt_rcu); 538 + 539 + change_nexthops(rt) { 540 + netdev_put(nh->nh_dev, &nh->nh_dev_tracker); 541 + } endfor_nexthops(rt); 542 + 543 + kfree(rt); 544 + } 545 + 541 546 static void mpls_rt_free(struct mpls_route *rt) 542 547 { 543 548 if (rt) 544 - kfree_rcu(rt, rt_rcu); 549 + call_rcu(&rt->rt_rcu, mpls_rt_free_rcu); 545 550 } 546 551 547 552 static void mpls_notify_route(struct net *net, unsigned index, ··· 578 557 struct mpls_route __rcu **platform_label; 579 558 struct mpls_route *rt; 580 559 581 - ASSERT_RTNL(); 582 - 583 - platform_label = rtnl_dereference(net->mpls.platform_label); 584 - rt = rtnl_dereference(platform_label[index]); 560 + platform_label = mpls_dereference(net, net->mpls.platform_label); 561 + rt = mpls_dereference(net, platform_label[index]); 585 562 rcu_assign_pointer(platform_label[index], new); 586 563 587 564 mpls_notify_route(net, index, rt, new, info); ··· 588 569 mpls_rt_free(rt); 589 570 } 590 571 591 - static unsigned find_free_label(struct net *net) 572 + static unsigned int find_free_label(struct net *net) 592 573 { 593 - struct mpls_route __rcu **platform_label; 594 - size_t platform_labels; 595 - unsigned index; 574 + unsigned int index; 596 575 597 - platform_label = rtnl_dereference(net->mpls.platform_label); 598 - platform_labels = net->mpls.platform_labels; 599 - for (index = MPLS_LABEL_FIRST_UNRESERVED; index < platform_labels; 576 + for (index = MPLS_LABEL_FIRST_UNRESERVED; 577 + index < net->mpls.platform_labels; 600 578 index++) { 601 - if (!rtnl_dereference(platform_label[index])) 579 + if (!mpls_route_input(net, index)) 602 580 return index; 603 581 } 582 + 604 583 return LABEL_NOT_SPECIFIED; 605 584 } 606 585 607 586 
#if IS_ENABLED(CONFIG_INET) 608 587 static struct net_device *inet_fib_lookup_dev(struct net *net, 588 + struct mpls_nh *nh, 609 589 const void *addr) 610 590 { 611 591 struct net_device *dev; ··· 617 599 return ERR_CAST(rt); 618 600 619 601 dev = rt->dst.dev; 620 - dev_hold(dev); 621 - 602 + netdev_hold(dev, &nh->nh_dev_tracker, GFP_KERNEL); 622 603 ip_rt_put(rt); 623 604 624 605 return dev; 625 606 } 626 607 #else 627 608 static struct net_device *inet_fib_lookup_dev(struct net *net, 609 + struct mpls_nh *nh, 628 610 const void *addr) 629 611 { 630 612 return ERR_PTR(-EAFNOSUPPORT); ··· 633 615 634 616 #if IS_ENABLED(CONFIG_IPV6) 635 617 static struct net_device *inet6_fib_lookup_dev(struct net *net, 618 + struct mpls_nh *nh, 636 619 const void *addr) 637 620 { 638 621 struct net_device *dev; ··· 650 631 return ERR_CAST(dst); 651 632 652 633 dev = dst->dev; 653 - dev_hold(dev); 634 + netdev_hold(dev, &nh->nh_dev_tracker, GFP_KERNEL); 654 635 dst_release(dst); 655 636 656 637 return dev; 657 638 } 658 639 #else 659 640 static struct net_device *inet6_fib_lookup_dev(struct net *net, 641 + struct mpls_nh *nh, 660 642 const void *addr) 661 643 { 662 644 return ERR_PTR(-EAFNOSUPPORT); ··· 673 653 if (!oif) { 674 654 switch (nh->nh_via_table) { 675 655 case NEIGH_ARP_TABLE: 676 - dev = inet_fib_lookup_dev(net, mpls_nh_via(rt, nh)); 656 + dev = inet_fib_lookup_dev(net, nh, mpls_nh_via(rt, nh)); 677 657 break; 678 658 case NEIGH_ND_TABLE: 679 - dev = inet6_fib_lookup_dev(net, mpls_nh_via(rt, nh)); 659 + dev = inet6_fib_lookup_dev(net, nh, mpls_nh_via(rt, nh)); 680 660 break; 681 661 case NEIGH_LINK_TABLE: 682 662 break; 683 663 } 684 664 } else { 685 - dev = dev_get_by_index(net, oif); 665 + dev = netdev_get_by_index(net, oif, 666 + &nh->nh_dev_tracker, GFP_KERNEL); 686 667 } 687 668 688 669 if (!dev) ··· 692 671 if (IS_ERR(dev)) 693 672 return dev; 694 673 695 - /* The caller is holding rtnl anyways, so release the dev reference */ 696 - dev_put(dev); 674 + nh->nh_dev = 
dev; 697 675 698 676 return dev; 699 677 } ··· 706 686 dev = find_outdev(net, rt, nh, oif); 707 687 if (IS_ERR(dev)) { 708 688 err = PTR_ERR(dev); 709 - dev = NULL; 710 689 goto errout; 711 690 } 712 691 713 692 /* Ensure this is a supported device */ 714 693 err = -EINVAL; 715 - if (!mpls_dev_get(dev)) 716 - goto errout; 694 + if (!mpls_dev_get(net, dev)) 695 + goto errout_put; 717 696 718 697 if ((nh->nh_via_table == NEIGH_LINK_TABLE) && 719 698 (dev->addr_len != nh->nh_via_alen)) 720 - goto errout; 721 - 722 - nh->nh_dev = dev; 699 + goto errout_put; 723 700 724 701 if (!(dev->flags & IFF_UP)) { 725 702 nh->nh_flags |= RTNH_F_DEAD; ··· 730 713 731 714 return 0; 732 715 716 + errout_put: 717 + netdev_put(nh->nh_dev, &nh->nh_dev_tracker); 718 + nh->nh_dev = NULL; 733 719 errout: 734 720 return err; 735 721 } ··· 910 890 struct nlattr *nla_via, *nla_newdst; 911 891 int remaining = cfg->rc_mp_len; 912 892 int err = 0; 913 - u8 nhs = 0; 893 + 894 + rt->rt_nhn = 0; 914 895 915 896 change_nexthops(rt) { 916 897 int attrlen; ··· 947 926 rt->rt_nhn_alive--; 948 927 949 928 rtnh = rtnh_next(rtnh, &remaining); 950 - nhs++; 929 + rt->rt_nhn++; 951 930 } endfor_nexthops(rt); 952 - 953 - rt->rt_nhn = nhs; 954 931 955 932 return 0; 956 933 ··· 959 940 static bool mpls_label_ok(struct net *net, unsigned int *index, 960 941 struct netlink_ext_ack *extack) 961 942 { 962 - bool is_ok = true; 963 - 964 943 /* Reserved labels may not be set */ 965 944 if (*index < MPLS_LABEL_FIRST_UNRESERVED) { 966 945 NL_SET_ERR_MSG(extack, 967 946 "Invalid label - must be MPLS_LABEL_FIRST_UNRESERVED or higher"); 968 - is_ok = false; 947 + return false; 969 948 } 970 949 971 950 /* The full 20 bit range may not be supported. 
*/ 972 - if (is_ok && *index >= net->mpls.platform_labels) { 951 + if (*index >= net->mpls.platform_labels) { 973 952 NL_SET_ERR_MSG(extack, 974 953 "Label >= configured maximum in platform_labels"); 975 - is_ok = false; 954 + return false; 976 955 } 977 956 978 957 *index = array_index_nospec(*index, net->mpls.platform_labels); 979 - return is_ok; 958 + 959 + return true; 980 960 } 981 961 982 962 static int mpls_route_add(struct mpls_route_config *cfg, 983 963 struct netlink_ext_ack *extack) 984 964 { 985 - struct mpls_route __rcu **platform_label; 986 965 struct net *net = cfg->rc_nlinfo.nl_net; 987 966 struct mpls_route *rt, *old; 988 967 int err = -EINVAL; ··· 1008 991 } 1009 992 1010 993 err = -EEXIST; 1011 - platform_label = rtnl_dereference(net->mpls.platform_label); 1012 - old = rtnl_dereference(platform_label[index]); 994 + old = mpls_route_input(net, index); 1013 995 if ((cfg->rc_nlflags & NLM_F_EXCL) && old) 1014 996 goto errout; 1015 997 ··· 1119 1103 struct mpls_dev *mdev; 1120 1104 struct nlattr *nla; 1121 1105 1122 - mdev = mpls_dev_get(dev); 1106 + mdev = mpls_dev_rcu(dev); 1123 1107 if (!mdev) 1124 1108 return -ENODATA; 1125 1109 ··· 1139 1123 { 1140 1124 struct mpls_dev *mdev; 1141 1125 1142 - mdev = mpls_dev_get(dev); 1126 + mdev = mpls_dev_rcu(dev); 1143 1127 if (!mdev) 1144 1128 return 0; 1145 1129 ··· 1280 1264 if (err < 0) 1281 1265 goto errout; 1282 1266 1283 - err = -EINVAL; 1284 - if (!tb[NETCONFA_IFINDEX]) 1267 + if (!tb[NETCONFA_IFINDEX]) { 1268 + err = -EINVAL; 1285 1269 goto errout; 1270 + } 1286 1271 1287 1272 ifindex = nla_get_s32(tb[NETCONFA_IFINDEX]); 1288 - dev = __dev_get_by_index(net, ifindex); 1289 - if (!dev) 1290 - goto errout; 1291 1273 1292 - mdev = mpls_dev_get(dev); 1293 - if (!mdev) 1294 - goto errout; 1295 - 1296 - err = -ENOBUFS; 1297 1274 skb = nlmsg_new(mpls_netconf_msgsize_devconf(NETCONFA_ALL), GFP_KERNEL); 1298 - if (!skb) 1275 + if (!skb) { 1276 + err = -ENOBUFS; 1299 1277 goto errout; 1278 + } 1279 + 1280 + 
rcu_read_lock(); 1281 + 1282 + dev = dev_get_by_index_rcu(net, ifindex); 1283 + if (!dev) { 1284 + err = -EINVAL; 1285 + goto errout_unlock; 1286 + } 1287 + 1288 + mdev = mpls_dev_rcu(dev); 1289 + if (!mdev) { 1290 + err = -EINVAL; 1291 + goto errout_unlock; 1292 + } 1300 1293 1301 1294 err = mpls_netconf_fill_devconf(skb, mdev, 1302 1295 NETLINK_CB(in_skb).portid, ··· 1314 1289 if (err < 0) { 1315 1290 /* -EMSGSIZE implies BUG in mpls_netconf_msgsize_devconf() */ 1316 1291 WARN_ON(err == -EMSGSIZE); 1317 - kfree_skb(skb); 1318 - goto errout; 1292 + goto errout_unlock; 1319 1293 } 1294 + 1320 1295 err = rtnl_unicast(skb, net, NETLINK_CB(in_skb).portid); 1296 + 1297 + rcu_read_unlock(); 1321 1298 errout: 1322 1299 return err; 1300 + 1301 + errout_unlock: 1302 + rcu_read_unlock(); 1303 + kfree_skb(skb); 1304 + goto errout; 1323 1305 } 1324 1306 1325 1307 static int mpls_netconf_dump_devconf(struct sk_buff *skb, ··· 1358 1326 1359 1327 rcu_read_lock(); 1360 1328 for_each_netdev_dump(net, dev, ctx->ifindex) { 1361 - mdev = mpls_dev_get(dev); 1329 + mdev = mpls_dev_rcu(dev); 1362 1330 if (!mdev) 1363 1331 continue; 1364 1332 err = mpls_netconf_fill_devconf(skb, mdev, ··· 1470 1438 int err = -ENOMEM; 1471 1439 int i; 1472 1440 1473 - ASSERT_RTNL(); 1474 - 1475 1441 mdev = kzalloc(sizeof(*mdev), GFP_KERNEL); 1476 1442 if (!mdev) 1477 1443 return ERR_PTR(err); ··· 1511 1481 1512 1482 static int mpls_ifdown(struct net_device *dev, int event) 1513 1483 { 1514 - struct mpls_route __rcu **platform_label; 1515 1484 struct net *net = dev_net(dev); 1516 - unsigned index; 1485 + unsigned int index; 1517 1486 1518 - platform_label = rtnl_dereference(net->mpls.platform_label); 1519 1487 for (index = 0; index < net->mpls.platform_labels; index++) { 1520 - struct mpls_route *rt = rtnl_dereference(platform_label[index]); 1488 + struct mpls_route *rt; 1521 1489 bool nh_del = false; 1522 1490 u8 alive = 0; 1523 1491 1492 + rt = mpls_route_input(net, index); 1524 1493 if (!rt) 1525 1494 
continue; 1526 1495 ··· 1553 1524 change_nexthops(rt) { 1554 1525 unsigned int nh_flags = nh->nh_flags; 1555 1526 1556 - if (nh->nh_dev != dev) 1527 + if (nh->nh_dev != dev) { 1528 + if (nh_del) 1529 + netdev_hold(nh->nh_dev, &nh->nh_dev_tracker, 1530 + GFP_KERNEL); 1557 1531 goto next; 1532 + } 1558 1533 1559 1534 switch (event) { 1560 1535 case NETDEV_DOWN: ··· 1590 1557 1591 1558 static void mpls_ifup(struct net_device *dev, unsigned int flags) 1592 1559 { 1593 - struct mpls_route __rcu **platform_label; 1594 1560 struct net *net = dev_net(dev); 1595 - unsigned index; 1561 + unsigned int index; 1596 1562 u8 alive; 1597 1563 1598 - platform_label = rtnl_dereference(net->mpls.platform_label); 1599 1564 for (index = 0; index < net->mpls.platform_labels; index++) { 1600 - struct mpls_route *rt = rtnl_dereference(platform_label[index]); 1565 + struct mpls_route *rt; 1601 1566 1567 + rt = mpls_route_input(net, index); 1602 1568 if (!rt) 1603 1569 continue; 1604 1570 ··· 1624 1592 void *ptr) 1625 1593 { 1626 1594 struct net_device *dev = netdev_notifier_info_to_dev(ptr); 1595 + struct net *net = dev_net(dev); 1627 1596 struct mpls_dev *mdev; 1628 1597 unsigned int flags; 1629 1598 int err; 1630 1599 1600 + mutex_lock(&net->mpls.platform_mutex); 1601 + 1631 1602 if (event == NETDEV_REGISTER) { 1632 1603 mdev = mpls_add_dev(dev); 1633 - if (IS_ERR(mdev)) 1634 - return notifier_from_errno(PTR_ERR(mdev)); 1604 + if (IS_ERR(mdev)) { 1605 + err = PTR_ERR(mdev); 1606 + goto err; 1607 + } 1635 1608 1636 - return NOTIFY_OK; 1609 + goto out; 1637 1610 } 1638 1611 1639 - mdev = mpls_dev_get(dev); 1612 + mdev = mpls_dev_get(net, dev); 1640 1613 if (!mdev) 1641 - return NOTIFY_OK; 1614 + goto out; 1642 1615 1643 1616 switch (event) { 1644 1617 1645 1618 case NETDEV_DOWN: 1646 1619 err = mpls_ifdown(dev, event); 1647 1620 if (err) 1648 - return notifier_from_errno(err); 1621 + goto err; 1649 1622 break; 1650 1623 case NETDEV_UP: 1651 1624 flags = netif_get_flags(dev); ··· 1666 1629 
} else { 1667 1630 err = mpls_ifdown(dev, event); 1668 1631 if (err) 1669 - return notifier_from_errno(err); 1632 + goto err; 1670 1633 } 1671 1634 break; 1672 1635 case NETDEV_UNREGISTER: 1673 1636 err = mpls_ifdown(dev, event); 1674 1637 if (err) 1675 - return notifier_from_errno(err); 1676 - mdev = mpls_dev_get(dev); 1638 + goto err; 1639 + 1640 + mdev = mpls_dev_get(net, dev); 1677 1641 if (mdev) { 1678 1642 mpls_dev_sysctl_unregister(dev, mdev); 1679 1643 RCU_INIT_POINTER(dev->mpls_ptr, NULL); ··· 1682 1644 } 1683 1645 break; 1684 1646 case NETDEV_CHANGENAME: 1685 - mdev = mpls_dev_get(dev); 1647 + mdev = mpls_dev_get(net, dev); 1686 1648 if (mdev) { 1687 1649 mpls_dev_sysctl_unregister(dev, mdev); 1688 1650 err = mpls_dev_sysctl_register(dev, mdev); 1689 1651 if (err) 1690 - return notifier_from_errno(err); 1652 + goto err; 1691 1653 } 1692 1654 break; 1693 1655 } 1656 + 1657 + out: 1658 + mutex_unlock(&net->mpls.platform_mutex); 1694 1659 return NOTIFY_OK; 1660 + 1661 + err: 1662 + mutex_unlock(&net->mpls.platform_mutex); 1663 + return notifier_from_errno(err); 1695 1664 } 1696 1665 1697 1666 static struct notifier_block mpls_dev_notifier = { ··· 1973 1928 static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh, 1974 1929 struct netlink_ext_ack *extack) 1975 1930 { 1931 + struct net *net = sock_net(skb->sk); 1976 1932 struct mpls_route_config *cfg; 1977 1933 int err; 1978 1934 ··· 1985 1939 if (err < 0) 1986 1940 goto out; 1987 1941 1942 + mutex_lock(&net->mpls.platform_mutex); 1988 1943 err = mpls_route_del(cfg, extack); 1944 + mutex_unlock(&net->mpls.platform_mutex); 1989 1945 out: 1990 1946 kfree(cfg); 1991 1947 ··· 1998 1950 static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh, 1999 1951 struct netlink_ext_ack *extack) 2000 1952 { 1953 + struct net *net = sock_net(skb->sk); 2001 1954 struct mpls_route_config *cfg; 2002 1955 int err; 2003 1956 ··· 2010 1961 if (err < 0) 2011 1962 goto out; 2012 1963 1964 + 
mutex_lock(&net->mpls.platform_mutex); 2013 1965 err = mpls_route_add(cfg, extack); 1966 + mutex_unlock(&net->mpls.platform_mutex); 2014 1967 out: 2015 1968 kfree(cfg); 2016 1969 ··· 2175 2124 2176 2125 if (i == RTA_OIF) { 2177 2126 ifindex = nla_get_u32(tb[i]); 2178 - filter->dev = __dev_get_by_index(net, ifindex); 2127 + filter->dev = dev_get_by_index_rcu(net, ifindex); 2179 2128 if (!filter->dev) 2180 2129 return -ENODEV; 2181 2130 filter->filter_set = 1; ··· 2213 2162 struct net *net = sock_net(skb->sk); 2214 2163 struct mpls_route __rcu **platform_label; 2215 2164 struct fib_dump_filter filter = { 2216 - .rtnl_held = true, 2165 + .rtnl_held = false, 2217 2166 }; 2218 2167 unsigned int flags = NLM_F_MULTI; 2219 2168 size_t platform_labels; 2220 2169 unsigned int index; 2170 + int err; 2221 2171 2222 - ASSERT_RTNL(); 2172 + rcu_read_lock(); 2223 2173 2224 2174 if (cb->strict_check) { 2225 - int err; 2226 - 2227 2175 err = mpls_valid_fib_dump_req(net, nlh, &filter, cb); 2228 2176 if (err < 0) 2229 - return err; 2177 + goto err; 2230 2178 2231 2179 /* for MPLS, there is only 1 table with fixed type and flags. 2232 2180 * If either are set in the filter then return nothing. 
··· 2233 2183 if ((filter.table_id && filter.table_id != RT_TABLE_MAIN) || 2234 2184 (filter.rt_type && filter.rt_type != RTN_UNICAST) || 2235 2185 filter.flags) 2236 - return skb->len; 2186 + goto unlock; 2237 2187 } 2238 2188 2239 2189 index = cb->args[0]; 2240 2190 if (index < MPLS_LABEL_FIRST_UNRESERVED) 2241 2191 index = MPLS_LABEL_FIRST_UNRESERVED; 2242 2192 2243 - platform_label = rtnl_dereference(net->mpls.platform_label); 2193 + platform_label = rcu_dereference(net->mpls.platform_label); 2244 2194 platform_labels = net->mpls.platform_labels; 2245 2195 2246 2196 if (filter.filter_set) ··· 2249 2199 for (; index < platform_labels; index++) { 2250 2200 struct mpls_route *rt; 2251 2201 2252 - rt = rtnl_dereference(platform_label[index]); 2202 + rt = rcu_dereference(platform_label[index]); 2253 2203 if (!rt) 2254 2204 continue; 2255 2205 ··· 2264 2214 } 2265 2215 cb->args[0] = index; 2266 2216 2217 + unlock: 2218 + rcu_read_unlock(); 2267 2219 return skb->len; 2220 + 2221 + err: 2222 + rcu_read_unlock(); 2223 + return err; 2268 2224 } 2269 2225 2270 2226 static inline size_t lfib_nlmsg_size(struct mpls_route *rt) ··· 2401 2345 u32 portid = NETLINK_CB(in_skb).portid; 2402 2346 u32 in_label = LABEL_NOT_SPECIFIED; 2403 2347 struct nlattr *tb[RTA_MAX + 1]; 2348 + struct mpls_route *rt = NULL; 2404 2349 u32 labels[MAX_NEW_LABELS]; 2405 2350 struct mpls_shim_hdr *hdr; 2406 2351 unsigned int hdr_size = 0; 2407 2352 const struct mpls_nh *nh; 2408 2353 struct net_device *dev; 2409 - struct mpls_route *rt; 2410 2354 struct rtmsg *rtm, *r; 2411 2355 struct nlmsghdr *nlh; 2412 2356 struct sk_buff *skb; 2413 2357 u8 n_labels; 2414 2358 int err; 2359 + 2360 + mutex_lock(&net->mpls.platform_mutex); 2415 2361 2416 2362 err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack); 2417 2363 if (err < 0) ··· 2436 2378 } 2437 2379 } 2438 2380 2439 - rt = mpls_route_input_rcu(net, in_label); 2381 + if (in_label < net->mpls.platform_labels) 2382 + rt = mpls_route_input(net, in_label); 
2440 2383 if (!rt) { 2441 2384 err = -ENETUNREACH; 2442 2385 goto errout; ··· 2458 2399 goto errout_free; 2459 2400 } 2460 2401 2461 - return rtnl_unicast(skb, net, portid); 2402 + err = rtnl_unicast(skb, net, portid); 2403 + goto errout; 2462 2404 } 2463 2405 2464 2406 if (tb[RTA_NEWDST]) { ··· 2551 2491 2552 2492 err = rtnl_unicast(skb, net, portid); 2553 2493 errout: 2494 + mutex_unlock(&net->mpls.platform_mutex); 2554 2495 return err; 2555 2496 2556 2497 nla_put_failure: 2557 2498 nlmsg_cancel(skb, nlh); 2558 2499 err = -EMSGSIZE; 2559 2500 errout_free: 2501 + mutex_unlock(&net->mpls.platform_mutex); 2560 2502 kfree_skb(skb); 2561 2503 return err; 2562 2504 } ··· 2581 2519 /* In case the predefined labels need to be populated */ 2582 2520 if (limit > MPLS_LABEL_IPV4NULL) { 2583 2521 struct net_device *lo = net->loopback_dev; 2522 + 2584 2523 rt0 = mpls_rt_alloc(1, lo->addr_len, 0); 2585 2524 if (IS_ERR(rt0)) 2586 2525 goto nort0; 2526 + 2587 2527 rt0->rt_nh->nh_dev = lo; 2528 + netdev_hold(lo, &rt0->rt_nh->nh_dev_tracker, GFP_KERNEL); 2588 2529 rt0->rt_protocol = RTPROT_KERNEL; 2589 2530 rt0->rt_payload_type = MPT_IPV4; 2590 2531 rt0->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT; ··· 2598 2533 } 2599 2534 if (limit > MPLS_LABEL_IPV6NULL) { 2600 2535 struct net_device *lo = net->loopback_dev; 2536 + 2601 2537 rt2 = mpls_rt_alloc(1, lo->addr_len, 0); 2602 2538 if (IS_ERR(rt2)) 2603 2539 goto nort2; 2540 + 2604 2541 rt2->rt_nh->nh_dev = lo; 2542 + netdev_hold(lo, &rt2->rt_nh->nh_dev_tracker, GFP_KERNEL); 2605 2543 rt2->rt_protocol = RTPROT_KERNEL; 2606 2544 rt2->rt_payload_type = MPT_IPV6; 2607 2545 rt2->rt_ttl_propagate = MPLS_TTL_PROP_DEFAULT; ··· 2614 2546 lo->addr_len); 2615 2547 } 2616 2548 2617 - rtnl_lock(); 2549 + mutex_lock(&net->mpls.platform_mutex); 2550 + 2618 2551 /* Remember the original table */ 2619 - old = rtnl_dereference(net->mpls.platform_label); 2552 + old = mpls_dereference(net, net->mpls.platform_label); 2620 2553 old_limit = 
net->mpls.platform_labels; 2621 2554 2622 2555 /* Free any labels beyond the new table */ ··· 2648 2579 net->mpls.platform_labels = limit; 2649 2580 rcu_assign_pointer(net->mpls.platform_label, labels); 2650 2581 2651 - rtnl_unlock(); 2582 + mutex_unlock(&net->mpls.platform_mutex); 2652 2583 2653 2584 mpls_rt_free(rt2); 2654 2585 mpls_rt_free(rt0); ··· 2721 2652 }, 2722 2653 }; 2723 2654 2724 - static int mpls_net_init(struct net *net) 2655 + static __net_init int mpls_net_init(struct net *net) 2725 2656 { 2726 2657 size_t table_size = ARRAY_SIZE(mpls_table); 2727 2658 struct ctl_table *table; 2728 2659 int i; 2729 2660 2661 + mutex_init(&net->mpls.platform_mutex); 2730 2662 net->mpls.platform_labels = 0; 2731 2663 net->mpls.platform_label = NULL; 2732 2664 net->mpls.ip_ttl_propagate = 1; ··· 2753 2683 return 0; 2754 2684 } 2755 2685 2756 - static void mpls_net_exit(struct net *net) 2686 + static __net_exit void mpls_net_exit(struct net *net) 2757 2687 { 2758 2688 struct mpls_route __rcu **platform_label; 2759 2689 size_t platform_labels; ··· 2773 2703 * As such no additional rcu synchronization is necessary when 2774 2704 * freeing the platform_label table. 
2775 2705 */ 2776 - rtnl_lock(); 2777 - platform_label = rtnl_dereference(net->mpls.platform_label); 2706 + mutex_lock(&net->mpls.platform_mutex); 2707 + 2708 + platform_label = mpls_dereference(net, net->mpls.platform_label); 2778 2709 platform_labels = net->mpls.platform_labels; 2710 + 2779 2711 for (index = 0; index < platform_labels; index++) { 2780 - struct mpls_route *rt = rtnl_dereference(platform_label[index]); 2781 - RCU_INIT_POINTER(platform_label[index], NULL); 2712 + struct mpls_route *rt; 2713 + 2714 + rt = mpls_dereference(net, platform_label[index]); 2782 2715 mpls_notify_route(net, index, rt, NULL, NULL); 2783 2716 mpls_rt_free(rt); 2784 2717 } 2785 - rtnl_unlock(); 2718 + 2719 + mutex_unlock(&net->mpls.platform_mutex); 2786 2720 2787 2721 kvfree(platform_label); 2788 2722 } ··· 2803 2729 }; 2804 2730 2805 2731 static const struct rtnl_msg_handler mpls_rtnl_msg_handlers[] __initdata_or_module = { 2806 - {THIS_MODULE, PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL, 0}, 2807 - {THIS_MODULE, PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL, 0}, 2808 - {THIS_MODULE, PF_MPLS, RTM_GETROUTE, mpls_getroute, mpls_dump_routes, 0}, 2732 + {THIS_MODULE, PF_MPLS, RTM_NEWROUTE, mpls_rtm_newroute, NULL, 2733 + RTNL_FLAG_DOIT_UNLOCKED}, 2734 + {THIS_MODULE, PF_MPLS, RTM_DELROUTE, mpls_rtm_delroute, NULL, 2735 + RTNL_FLAG_DOIT_UNLOCKED}, 2736 + {THIS_MODULE, PF_MPLS, RTM_GETROUTE, mpls_getroute, mpls_dump_routes, 2737 + RTNL_FLAG_DOIT_UNLOCKED | RTNL_FLAG_DUMP_UNLOCKED}, 2809 2738 {THIS_MODULE, PF_MPLS, RTM_GETNETCONF, 2810 2739 mpls_netconf_get_devconf, mpls_netconf_dump_devconf, 2811 - RTNL_FLAG_DUMP_UNLOCKED}, 2740 + RTNL_FLAG_DOIT_UNLOCKED | RTNL_FLAG_DUMP_UNLOCKED}, 2812 2741 }; 2813 2742 2814 2743 static int __init mpls_init(void)
+16 -3
net/mpls/internal.h
··· 88 88 89 89 struct mpls_nh { /* next hop label forwarding entry */ 90 90 struct net_device *nh_dev; 91 + netdevice_tracker nh_dev_tracker; 91 92 92 93 /* nh_flags is accessed under RCU in the packet path; it is 93 94 * modified handling netdev events with rtnl lock held ··· 185 184 return result; 186 185 } 187 186 188 - static inline struct mpls_dev *mpls_dev_get(const struct net_device *dev) 187 + #define mpls_dereference(net, p) \ 188 + rcu_dereference_protected( \ 189 + (p), \ 190 + lockdep_is_held(&(net)->mpls.platform_mutex)) 191 + 192 + static inline struct mpls_dev *mpls_dev_rcu(const struct net_device *dev) 189 193 { 190 - return rcu_dereference_rtnl(dev->mpls_ptr); 194 + return rcu_dereference(dev->mpls_ptr); 195 + } 196 + 197 + static inline struct mpls_dev *mpls_dev_get(const struct net *net, 198 + const struct net_device *dev) 199 + { 200 + return mpls_dereference(net, dev->mpls_ptr); 191 201 } 192 202 193 203 int nla_put_labels(struct sk_buff *skb, int attrtype, u8 labels, ··· 208 196 bool mpls_output_possible(const struct net_device *dev); 209 197 unsigned int mpls_dev_mtu(const struct net_device *dev); 210 198 bool mpls_pkt_too_big(const struct sk_buff *skb, unsigned int mtu); 211 - void mpls_stats_inc_outucastpkts(struct net_device *dev, 199 + void mpls_stats_inc_outucastpkts(struct net *net, 200 + struct net_device *dev, 212 201 const struct sk_buff *skb); 213 202 214 203 #endif /* MPLS_INTERNAL_H */
+3 -3
net/mpls/mpls_iptunnel.c
··· 53 53 54 54 /* Find the output device */ 55 55 out_dev = dst->dev; 56 - net = dev_net(out_dev); 56 + net = dev_net_rcu(out_dev); 57 57 58 58 if (!mpls_output_possible(out_dev) || 59 59 !dst->lwtstate || skb_warn_if_lro(skb)) ··· 128 128 bos = false; 129 129 } 130 130 131 - mpls_stats_inc_outucastpkts(out_dev, skb); 131 + mpls_stats_inc_outucastpkts(net, out_dev, skb); 132 132 133 133 if (rt) { 134 134 if (rt->rt_gw_family == AF_INET6) ··· 153 153 return LWTUNNEL_XMIT_DONE; 154 154 155 155 drop: 156 - out_mdev = out_dev ? mpls_dev_get(out_dev) : NULL; 156 + out_mdev = out_dev ? mpls_dev_rcu(out_dev) : NULL; 157 157 if (out_mdev) 158 158 MPLS_INC_STATS(out_mdev, tx_errors); 159 159 kfree_skb(skb);