Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

mpls: Protect net->mpls.platform_label with a per-netns mutex.

MPLS (re)uses RTNL to protect net->mpls.platform_label,
but the lock does not need to be RTNL at all.

Let's protect net->mpls.platform_label with a dedicated
per-netns mutex.

Signed-off-by: Kuniyuki Iwashima <kuniyu@google.com>
Reviewed-by: Guillaume Nault <gnault@redhat.com>
Link: https://patch.msgid.link/20251029173344.2934622-13-kuniyu@google.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>

authored by

Kuniyuki Iwashima and committed by
Jakub Kicinski
e833eb25 fb2b77b9

+43 -20
+1
include/net/netns/mpls.h
··· 16 16 int default_ttl; 17 17 size_t platform_labels; 18 18 struct mpls_route __rcu * __rcu *platform_label; 19 + struct mutex platform_mutex; 19 20 20 21 struct ctl_table_header *ctl; 21 22 };
+36 -19
net/mpls/af_mpls.c
··· 79 79 { 80 80 struct mpls_route __rcu **platform_label; 81 81 82 - platform_label = rtnl_dereference(net->mpls.platform_label); 83 - return rtnl_dereference(platform_label[index]); 82 + platform_label = mpls_dereference(net, net->mpls.platform_label); 83 + return mpls_dereference(net, platform_label[index]); 84 84 } 85 85 86 86 static struct mpls_route *mpls_route_input_rcu(struct net *net, unsigned int index) ··· 578 578 struct mpls_route __rcu **platform_label; 579 579 struct mpls_route *rt; 580 580 581 - ASSERT_RTNL(); 582 - 583 - platform_label = rtnl_dereference(net->mpls.platform_label); 584 - rt = rtnl_dereference(platform_label[index]); 581 + platform_label = mpls_dereference(net, net->mpls.platform_label); 582 + rt = mpls_dereference(net, platform_label[index]); 585 583 rcu_assign_pointer(platform_label[index], new); 586 584 587 585 mpls_notify_route(net, index, rt, new, info); ··· 1470 1472 int err = -ENOMEM; 1471 1473 int i; 1472 1474 1473 - ASSERT_RTNL(); 1474 - 1475 1475 mdev = kzalloc(sizeof(*mdev), GFP_KERNEL); 1476 1476 if (!mdev) 1477 1477 return ERR_PTR(err); ··· 1629 1633 unsigned int flags; 1630 1634 int err; 1631 1635 1636 + mutex_lock(&net->mpls.platform_mutex); 1637 + 1632 1638 if (event == NETDEV_REGISTER) { 1633 1639 mdev = mpls_add_dev(dev); 1634 1640 if (IS_ERR(mdev)) { ··· 1693 1695 } 1694 1696 1695 1697 out: 1698 + mutex_unlock(&net->mpls.platform_mutex); 1696 1699 return NOTIFY_OK; 1697 1700 1698 1701 err: 1702 + mutex_unlock(&net->mpls.platform_mutex); 1699 1703 return notifier_from_errno(err); 1700 1704 } 1701 1705 ··· 1973 1973 static int mpls_rtm_delroute(struct sk_buff *skb, struct nlmsghdr *nlh, 1974 1974 struct netlink_ext_ack *extack) 1975 1975 { 1976 + struct net *net = sock_net(skb->sk); 1976 1977 struct mpls_route_config *cfg; 1977 1978 int err; 1978 1979 ··· 1985 1984 if (err < 0) 1986 1985 goto out; 1987 1986 1987 + mutex_lock(&net->mpls.platform_mutex); 1988 1988 err = mpls_route_del(cfg, extack); 1989 + mutex_unlock(&net->mpls.platform_mutex); 1989 1990 out: 1990 1991 kfree(cfg); 1991 1992 ··· 1998 1995 static int mpls_rtm_newroute(struct sk_buff *skb, struct nlmsghdr *nlh, 1999 1996 struct netlink_ext_ack *extack) 2000 1997 { 1998 + struct net *net = sock_net(skb->sk); 2001 1999 struct mpls_route_config *cfg; 2002 2000 int err; 2003 2001 ··· 2010 2006 if (err < 0) 2011 2007 goto out; 2012 2008 2009 + mutex_lock(&net->mpls.platform_mutex); 2013 2010 err = mpls_route_add(cfg, extack); 2011 + mutex_unlock(&net->mpls.platform_mutex); 2014 2012 out: 2015 2013 kfree(cfg); 2016 2014 ··· 2413 2407 u8 n_labels; 2414 2408 int err; 2415 2409 2410 + mutex_lock(&net->mpls.platform_mutex); 2411 + 2416 2412 err = mpls_valid_getroute_req(in_skb, in_nlh, tb, extack); 2417 2413 if (err < 0) 2418 2414 goto errout; ··· 2458 2450 goto errout_free; 2459 2451 } 2460 2452 2461 - return rtnl_unicast(skb, net, portid); 2453 + err = rtnl_unicast(skb, net, portid); 2454 + goto errout; 2462 2455 } 2463 2456 2464 2457 if (tb[RTA_NEWDST]) { ··· 2551 2542 2552 2543 err = rtnl_unicast(skb, net, portid); 2553 2544 errout: 2545 + mutex_unlock(&net->mpls.platform_mutex); 2554 2546 return err; 2555 2547 2556 2548 nla_put_failure: 2557 2549 nlmsg_cancel(skb, nlh); 2558 2550 err = -EMSGSIZE; 2559 2551 errout_free: 2552 + mutex_unlock(&net->mpls.platform_mutex); 2560 2553 kfree_skb(skb); 2561 2554 return err; 2562 2555 } ··· 2614 2603 lo->addr_len); 2615 2604 } 2616 2605 2617 - rtnl_lock(); 2606 + mutex_lock(&net->mpls.platform_mutex); 2607 + 2618 2608 /* Remember the original table */ 2619 - old = rtnl_dereference(net->mpls.platform_label); 2609 + old = mpls_dereference(net, net->mpls.platform_label); 2620 2610 old_limit = net->mpls.platform_labels; 2621 2611 2622 2612 /* Free any labels beyond the new table */ ··· 2648 2636 net->mpls.platform_labels = limit; 2649 2637 rcu_assign_pointer(net->mpls.platform_label, labels); 2650 2638 2651 - rtnl_unlock(); 2639 + mutex_unlock(&net->mpls.platform_mutex); 2652 2640 2653 2641 mpls_rt_free(rt2); 2654 2642 mpls_rt_free(rt0); ··· 2721 2709 }, 2722 2710 }; 2723 2711 2724 - static int mpls_net_init(struct net *net) 2712 + static __net_init int mpls_net_init(struct net *net) 2725 2713 { 2726 2714 size_t table_size = ARRAY_SIZE(mpls_table); 2727 2715 struct ctl_table *table; 2728 2716 int i; 2729 2717 2718 + mutex_init(&net->mpls.platform_mutex); 2730 2719 net->mpls.platform_labels = 0; 2731 2720 net->mpls.platform_label = NULL; 2732 2721 net->mpls.ip_ttl_propagate = 1; ··· 2753 2740 return 0; 2754 2741 } 2755 2742 2756 - static void mpls_net_exit(struct net *net) 2743 + static __net_exit void mpls_net_exit(struct net *net) 2757 2744 { 2758 2745 struct mpls_route __rcu **platform_label; 2759 2746 size_t platform_labels; ··· 2773 2760 * As such no additional rcu synchronization is necessary when 2774 2761 * freeing the platform_label table. 2775 2762 */ 2776 - rtnl_lock(); 2777 - platform_label = rtnl_dereference(net->mpls.platform_label); 2763 + mutex_lock(&net->mpls.platform_mutex); 2764 + 2765 + platform_label = mpls_dereference(net, net->mpls.platform_label); 2778 2766 platform_labels = net->mpls.platform_labels; 2767 + 2779 2768 for (index = 0; index < platform_labels; index++) { 2780 - struct mpls_route *rt = rtnl_dereference(platform_label[index]); 2781 - RCU_INIT_POINTER(platform_label[index], NULL); 2769 + struct mpls_route *rt; 2770 + 2771 + rt = mpls_dereference(net, platform_label[index]); 2782 2772 mpls_notify_route(net, index, rt, NULL, NULL); 2783 2773 mpls_rt_free(rt); 2784 2774 } 2785 - rtnl_unlock(); 2775 + 2776 + mutex_unlock(&net->mpls.platform_mutex); 2786 2777 2787 2778 kvfree(platform_label); 2788 2779 }
+6 -1
net/mpls/internal.h
··· 185 185 return result; 186 186 } 187 187 188 + #define mpls_dereference(net, p) \ 189 + rcu_dereference_protected( \ 190 + (p), \ 191 + lockdep_is_held(&(net)->mpls.platform_mutex)) 192 + 188 193 static inline struct mpls_dev *mpls_dev_rcu(const struct net_device *dev) 189 194 { 190 195 return rcu_dereference(dev->mpls_ptr); ··· 198 193 static inline struct mpls_dev *mpls_dev_get(const struct net *net, 199 194 const struct net_device *dev) 200 195 { 201 - return rcu_dereference_rtnl(dev->mpls_ptr); 196 + return mpls_dereference(net, dev->mpls_ptr); 202 197 } 203 198 204 199 int nla_put_labels(struct sk_buff *skb, int attrtype, u8 labels,