Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

mctp: make __mctp_dev_get() take a refcount hold

Previously there was a race that could allow the mctp_dev refcount
to hit zero after __mctp_dev_get() returned but before the caller
took its own hold:

    rcu_read_lock();
    mdev = __mctp_dev_get(dev);
    // mctp_unregister() happens here, mdev->refs hits zero
    mctp_dev_hold(mdev);
    rcu_read_unlock();

Now we make __mctp_dev_get() take the hold itself. It is safe to test
against the zero refcount because __mctp_dev_get() is called holding
rcu_read_lock and mctp_dev uses kfree_rcu().

Reported-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: Matt Johnston <matt@codeconstruct.com.au>
Signed-off-by: David S. Miller <davem@davemloft.net>

authored by

Matt Johnston and committed by
David S. Miller
dc121c00 4767b7e2

+22 -5
+18 -3
net/mctp/device.c
··· 25 25 size_t a_idx; 26 26 }; 27 27 28 - /* unlocked: caller must hold rcu_read_lock */ 28 + /* unlocked: caller must hold rcu_read_lock. 29 + * Returned mctp_dev has its refcount incremented, or NULL if unset. 30 + */ 29 31 struct mctp_dev *__mctp_dev_get(const struct net_device *dev) 30 32 { 31 - return rcu_dereference(dev->mctp_ptr); 33 + struct mctp_dev *mdev = rcu_dereference(dev->mctp_ptr); 34 + 35 + /* RCU guarantees that any mdev is still live. 36 + * Zero refcount implies a pending free, return NULL. 37 + */ 38 + if (mdev) 39 + if (!refcount_inc_not_zero(&mdev->refs)) 40 + return NULL; 41 + return mdev; 32 42 } 33 43 44 + /* Returned mctp_dev does not have refcount incremented. The returned pointer 45 + * remains live while rtnl_lock is held, as that prevents mctp_unregister() 46 + */ 34 47 struct mctp_dev *mctp_dev_get_rtnl(const struct net_device *dev) 35 48 { 36 49 return rtnl_dereference(dev->mctp_ptr); ··· 137 124 if (mdev) { 138 125 rc = mctp_dump_dev_addrinfo(mdev, 139 126 skb, cb); 127 + mctp_dev_put(mdev); 140 128 // Error indicates full buffer, this 141 129 // callback will get retried. 142 130 if (rc < 0) ··· 312 298 313 299 void mctp_dev_put(struct mctp_dev *mdev) 314 300 { 315 - if (refcount_dec_and_test(&mdev->refs)) { 301 + if (mdev && refcount_dec_and_test(&mdev->refs)) { 316 302 dev_put(mdev->dev); 317 303 kfree_rcu(mdev, rcu); 318 304 } ··· 384 370 if (!mdev) 385 371 return 0; 386 372 ret = nla_total_size(4); /* IFLA_MCTP_NET */ 373 + mctp_dev_put(mdev); 387 374 return ret; 388 375 } 389 376
+4 -1
net/mctp/route.c
··· 836 836 { 837 837 struct mctp_sock *msk = container_of(sk, struct mctp_sock, sk); 838 838 struct mctp_skb_cb *cb = mctp_cb(skb); 839 - struct mctp_route tmp_rt; 839 + struct mctp_route tmp_rt = {0}; 840 840 struct mctp_sk_key *key; 841 841 struct net_device *dev; 842 842 struct mctp_hdr *hdr; ··· 948 948 mctp_route_release(rt); 949 949 950 950 dev_put(dev); 951 + mctp_dev_put(tmp_rt.dev); 951 952 952 953 return rc; 953 954 ··· 1125 1124 1126 1125 rt->output(rt, skb); 1127 1126 mctp_route_release(rt); 1127 + mctp_dev_put(mdev); 1128 1128 1129 1129 return NET_RX_SUCCESS; 1130 1130 1131 1131 err_drop: 1132 1132 kfree_skb(skb); 1133 + mctp_dev_put(mdev); 1133 1134 return NET_RX_DROP; 1134 1135 } 1135 1136
-1
net/mctp/test/utils.c
··· 54 54 55 55 rcu_read_lock(); 56 56 dev->mdev = __mctp_dev_get(ndev); 57 - mctp_dev_hold(dev->mdev); 58 57 rcu_read_unlock(); 59 58 60 59 return dev;