Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

RDMA/core: Use refcount_t instead of atomic_t on refcount of mcast_member

The refcount_t API will WARN on underflow and overflow of a reference
counter, and avoid use-after-free risks.

Link: https://lore.kernel.org/r/1622194663-2383-5-git-send-email-liweihang@huawei.com
Signed-off-by: Weihang Li <liweihang@huawei.com>
Signed-off-by: Jason Gunthorpe <jgg@nvidia.com>

authored by

Weihang Li and committed by
Jason Gunthorpe
cd74db6c 64485080

+6 -6
+6 -6
drivers/infiniband/core/multicast.c
··· 117 117 struct mcast_group *group; 118 118 struct list_head list; 119 119 enum mcast_state state; 120 - atomic_t refcount; 120 + refcount_t refcount; 121 121 struct completion comp; 122 122 }; 123 123 ··· 199 199 200 200 static void deref_member(struct mcast_member *member) 201 201 { 202 - if (atomic_dec_and_test(&member->refcount)) 202 + if (refcount_dec_and_test(&member->refcount)) 203 203 complete(&member->comp); 204 204 } 205 205 ··· 401 401 while (!list_empty(&group->active_list)) { 402 402 member = list_entry(group->active_list.next, 403 403 struct mcast_member, list); 404 - atomic_inc(&member->refcount); 404 + refcount_inc(&member->refcount); 405 405 list_del_init(&member->list); 406 406 adjust_membership(group, member->multicast.rec.join_state, -1); 407 407 member->state = MCAST_ERROR; ··· 445 445 struct mcast_member, list); 446 446 multicast = &member->multicast; 447 447 join_state = multicast->rec.join_state; 448 - atomic_inc(&member->refcount); 448 + refcount_inc(&member->refcount); 449 449 450 450 if (join_state == (group->rec.join_state & join_state)) { 451 451 status = cmp_rec(&group->rec, &multicast->rec, ··· 497 497 member = list_entry(group->pending_list.next, 498 498 struct mcast_member, list); 499 499 if (group->last_join == member) { 500 - atomic_inc(&member->refcount); 500 + refcount_inc(&member->refcount); 501 501 list_del_init(&member->list); 502 502 spin_unlock_irq(&group->lock); 503 503 ret = member->multicast.callback(status, &member->multicast); ··· 632 632 member->multicast.callback = callback; 633 633 member->multicast.context = context; 634 634 init_completion(&member->comp); 635 - atomic_set(&member->refcount, 1); 635 + refcount_set(&member->refcount, 1); 636 636 member->state = MCAST_JOINING; 637 637 638 638 member->group = acquire_group(&dev->port[port_num - dev->start_port],