sctp: fix an use-after-free issue in sctp_sock_dump

Commit 86fdb3448cc1 ("sctp: ensure ep is not destroyed before doing the
dump") tried to fix an use-after-free issue by checking !sctp_sk(sk)->ep
with holding sock and sock lock.

But Paolo noticed that endpoint could be destroyed in sctp_rcv without
sock lock protection. It means the use-after-free issue still could be
triggered when sctp_rcv put and destroy ep after sctp_sock_dump checks
!ep, although it's pretty hard to reproduce.

I could reproduce it by mdelay in sctp_rcv while msleep in sctp_close
and sctp_sock_dump long time.

This patch is to add another param cb_done to sctp_for_each_transport
and dump ep->assocs with holding tsp after jumping out of transport's
traversal in it to avoid this issue.

It can also improve sctp diag dump to make it run faster, as no need
to save sk into cb->args[5] and keep calling sctp_for_each_transport
any more.

This patch is also to use int * instead of int for the pos argument
in sctp_for_each_transport, which could make postion increment only
in sctp_for_each_transport and no need to keep changing cb->args[2]
in sctp_sock_filter and sctp_sock_dump any more.

Fixes: 86fdb3448cc1 ("sctp: ensure ep is not destroyed before doing the dump")
Reported-by: Paolo Abeni <pabeni@redhat.com>
Signed-off-by: Xin Long <lucien.xin@gmail.com>
Acked-by: Marcelo Ricardo Leitner <marcelo.leitner@gmail.com>
Acked-by: Neil Horman <nhorman@tuxdriver.com>
Signed-off-by: David S. Miller <davem@davemloft.net>

authored by Xin Long and committed by David S. Miller d25adbeb 5023a6db

+36 -39
+2 -1
include/net/sctp/sctp.h
··· 127 const union sctp_addr *laddr, 128 const union sctp_addr *paddr, void *p); 129 int sctp_for_each_transport(int (*cb)(struct sctp_transport *, void *), 130 - struct net *net, int pos, void *p); 131 int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *), void *p); 132 int sctp_get_sctp_info(struct sock *sk, struct sctp_association *asoc, 133 struct sctp_info *info);
··· 127 const union sctp_addr *laddr, 128 const union sctp_addr *paddr, void *p); 129 int sctp_for_each_transport(int (*cb)(struct sctp_transport *, void *), 130 + int (*cb_done)(struct sctp_transport *, void *), 131 + struct net *net, int *pos, void *p); 132 int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *), void *p); 133 int sctp_get_sctp_info(struct sock *sk, struct sctp_association *asoc, 134 struct sctp_info *info);
+9 -23
net/sctp/sctp_diag.c
··· 279 return err; 280 } 281 282 - static int sctp_sock_dump(struct sock *sk, void *p) 283 { 284 struct sctp_comm_param *commp = p; 285 struct sk_buff *skb = commp->skb; 286 struct netlink_callback *cb = commp->cb; 287 const struct inet_diag_req_v2 *r = commp->r; ··· 291 int err = 0; 292 293 lock_sock(sk); 294 - if (!sctp_sk(sk)->ep) 295 - goto release; 296 - list_for_each_entry(assoc, &sctp_sk(sk)->ep->asocs, asocs) { 297 if (cb->args[4] < cb->args[1]) 298 goto next; 299 ··· 327 cb->args[4]++; 328 } 329 cb->args[1] = 0; 330 - cb->args[2]++; 331 cb->args[3] = 0; 332 cb->args[4] = 0; 333 release: 334 release_sock(sk); 335 - sock_put(sk); 336 return err; 337 } 338 339 - static int sctp_get_sock(struct sctp_transport *tsp, void *p) 340 { 341 struct sctp_endpoint *ep = tsp->asoc->ep; 342 struct sctp_comm_param *commp = p; 343 struct sock *sk = ep->base.sk; 344 - struct netlink_callback *cb = commp->cb; 345 const struct inet_diag_req_v2 *r = commp->r; 346 struct sctp_association *assoc = 347 list_entry(ep->asocs.next, struct sctp_association, asocs); 348 349 /* find the ep only once through the transports by this condition */ 350 if (tsp->asoc != assoc) 351 - goto out; 352 353 if (r->sdiag_family != AF_UNSPEC && sk->sk_family != r->sdiag_family) 354 - goto out; 355 - 356 - sock_hold(sk); 357 - cb->args[5] = (long)sk; 358 359 return 1; 360 - 361 - out: 362 - cb->args[2]++; 363 - return 0; 364 } 365 366 static int sctp_ep_dump(struct sctp_endpoint *ep, void *p) ··· 493 if (!(idiag_states & ~(TCPF_LISTEN | TCPF_CLOSE))) 494 goto done; 495 496 - next: 497 - cb->args[5] = 0; 498 - sctp_for_each_transport(sctp_get_sock, net, cb->args[2], &commp); 499 - 500 - if (cb->args[5] && !sctp_sock_dump((struct sock *)cb->args[5], &commp)) 501 - goto next; 502 503 done: 504 cb->args[1] = cb->args[4];
··· 279 return err; 280 } 281 282 + static int sctp_sock_dump(struct sctp_transport *tsp, void *p) 283 { 284 + struct sctp_endpoint *ep = tsp->asoc->ep; 285 struct sctp_comm_param *commp = p; 286 + struct sock *sk = ep->base.sk; 287 struct sk_buff *skb = commp->skb; 288 struct netlink_callback *cb = commp->cb; 289 const struct inet_diag_req_v2 *r = commp->r; ··· 289 int err = 0; 290 291 lock_sock(sk); 292 + list_for_each_entry(assoc, &ep->asocs, asocs) { 293 if (cb->args[4] < cb->args[1]) 294 goto next; 295 ··· 327 cb->args[4]++; 328 } 329 cb->args[1] = 0; 330 cb->args[3] = 0; 331 cb->args[4] = 0; 332 release: 333 release_sock(sk); 334 return err; 335 } 336 337 + static int sctp_sock_filter(struct sctp_transport *tsp, void *p) 338 { 339 struct sctp_endpoint *ep = tsp->asoc->ep; 340 struct sctp_comm_param *commp = p; 341 struct sock *sk = ep->base.sk; 342 const struct inet_diag_req_v2 *r = commp->r; 343 struct sctp_association *assoc = 344 list_entry(ep->asocs.next, struct sctp_association, asocs); 345 346 /* find the ep only once through the transports by this condition */ 347 if (tsp->asoc != assoc) 348 + return 0; 349 350 if (r->sdiag_family != AF_UNSPEC && sk->sk_family != r->sdiag_family) 351 + return 0; 352 353 return 1; 354 } 355 356 static int sctp_ep_dump(struct sctp_endpoint *ep, void *p) ··· 503 if (!(idiag_states & ~(TCPF_LISTEN | TCPF_CLOSE))) 504 goto done; 505 506 + sctp_for_each_transport(sctp_sock_filter, sctp_sock_dump, 507 + net, (int *)&cb->args[2], &commp); 508 509 done: 510 cb->args[1] = cb->args[4];
+25 -15
net/sctp/socket.c
··· 4658 EXPORT_SYMBOL_GPL(sctp_transport_lookup_process); 4659 4660 int sctp_for_each_transport(int (*cb)(struct sctp_transport *, void *), 4661 - struct net *net, int pos, void *p) { 4662 struct rhashtable_iter hti; 4663 - void *obj; 4664 - int err; 4665 4666 - err = sctp_transport_walk_start(&hti); 4667 - if (err) 4668 - return err; 4669 4670 - obj = sctp_transport_get_idx(net, &hti, pos + 1); 4671 - for (; !IS_ERR_OR_NULL(obj); obj = sctp_transport_get_next(net, &hti)) { 4672 - struct sctp_transport *transport = obj; 4673 - 4674 - if (!sctp_transport_hold(transport)) 4675 continue; 4676 - err = cb(transport, p); 4677 - sctp_transport_put(transport); 4678 - if (err) 4679 break; 4680 } 4681 sctp_transport_walk_stop(&hti); 4682 4683 - return err; 4684 } 4685 EXPORT_SYMBOL_GPL(sctp_for_each_transport); 4686
··· 4658 EXPORT_SYMBOL_GPL(sctp_transport_lookup_process); 4659 4660 int sctp_for_each_transport(int (*cb)(struct sctp_transport *, void *), 4661 + int (*cb_done)(struct sctp_transport *, void *), 4662 + struct net *net, int *pos, void *p) { 4663 struct rhashtable_iter hti; 4664 + struct sctp_transport *tsp; 4665 + int ret; 4666 4667 + again: 4668 + ret = sctp_transport_walk_start(&hti); 4669 + if (ret) 4670 + return ret; 4671 4672 + tsp = sctp_transport_get_idx(net, &hti, *pos + 1); 4673 + for (; !IS_ERR_OR_NULL(tsp); tsp = sctp_transport_get_next(net, &hti)) { 4674 + if (!sctp_transport_hold(tsp)) 4675 continue; 4676 + ret = cb(tsp, p); 4677 + if (ret) 4678 break; 4679 + (*pos)++; 4680 + sctp_transport_put(tsp); 4681 } 4682 sctp_transport_walk_stop(&hti); 4683 4684 + if (ret) { 4685 + if (cb_done && !cb_done(tsp, p)) { 4686 + (*pos)++; 4687 + sctp_transport_put(tsp); 4688 + goto again; 4689 + } 4690 + sctp_transport_put(tsp); 4691 + } 4692 + 4693 + return ret; 4694 } 4695 EXPORT_SYMBOL_GPL(sctp_for_each_transport); 4696