Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

[NETLINK]: Fix module refcounting problems

Use-after-free: the struct proto_ops containing the module pointer
is freed when a socket with pid=0 is released, which is the case not
only for kernel sockets but also for all unbound sockets.

Module refcount leak: when the kernel socket is closed before all user
sockets have been closed, the proto_ops struct for this family is
replaced by the generic one and the module refcount can no longer be
dropped.

The second problem can't be solved cleanly using module refcounting in the
generic socket code, so this patch adds explicit refcounting to
netlink_create/netlink_release.

Signed-off-by: Patrick McHardy <kaber@trash.net>
Signed-off-by: David S. Miller <davem@davemloft.net>

authored by

Patrick McHardy and committed by
David S. Miller
77247bbb db080529

+36 -66
+36 -66
net/netlink/af_netlink.c
··· 73 73 struct netlink_callback *cb; 74 74 spinlock_t cb_lock; 75 75 void (*data_ready)(struct sock *sk, int bytes); 76 + struct module *module; 77 + u32 flags; 76 78 }; 79 + 80 + #define NETLINK_KERNEL_SOCKET 0x1 77 81 78 82 static inline struct netlink_sock *nlk_sk(struct sock *sk) 79 83 { ··· 101 97 struct nl_pid_hash hash; 102 98 struct hlist_head mc_list; 103 99 unsigned int nl_nonroot; 104 - struct proto_ops *p_ops; 100 + struct module *module; 105 101 }; 106 102 107 103 static struct netlink_table *nl_table; ··· 342 338 { 343 339 struct sock *sk; 344 340 struct netlink_sock *nlk; 341 + struct module *module; 345 342 346 343 sock->state = SS_UNCONNECTED; 347 344 ··· 352 347 if (protocol<0 || protocol >= MAX_LINKS) 353 348 return -EPROTONOSUPPORT; 354 349 355 - netlink_table_grab(); 350 + netlink_lock_table(); 356 351 if (!nl_table[protocol].hash.entries) { 357 352 #ifdef CONFIG_KMOD 358 353 /* We do 'best effort'. If we find a matching module, 359 354 * it is loaded. If not, we don't return an error to 360 355 * allow pure userspace<->userspace communication. 
-HW 361 356 */ 362 - netlink_table_ungrab(); 357 + netlink_unlock_table(); 363 358 request_module("net-pf-%d-proto-%d", PF_NETLINK, protocol); 364 - netlink_table_grab(); 359 + netlink_lock_table(); 365 360 #endif 366 361 } 367 - netlink_table_ungrab(); 362 + module = nl_table[protocol].module; 363 + if (!try_module_get(module)) 364 + module = NULL; 365 + netlink_unlock_table(); 368 366 369 - sock->ops = nl_table[protocol].p_ops; 367 + sock->ops = &netlink_ops; 370 368 371 369 sk = sk_alloc(PF_NETLINK, GFP_KERNEL, &netlink_proto, 1); 372 - if (!sk) 370 + if (!sk) { 371 + module_put(module); 373 372 return -ENOMEM; 373 + } 374 374 375 375 sock_init_data(sock, sk); 376 376 377 377 nlk = nlk_sk(sk); 378 378 379 + nlk->module = module; 379 380 spin_lock_init(&nlk->cb_lock); 380 381 init_waitqueue_head(&nlk->wait); 381 382 sk->sk_destruct = netlink_sock_destruct; ··· 426 415 notifier_call_chain(&netlink_chain, NETLINK_URELEASE, &n); 427 416 } 428 417 429 - /* When this is a kernel socket, we need to remove the owner pointer, 430 - * since we don't know whether the module will be dying at any given 431 - * point - HW 432 - */ 433 - if (!nlk->pid) { 434 - struct proto_ops *p_tmp; 418 + if (nlk->module) 419 + module_put(nlk->module); 435 420 421 + if (nlk->flags & NETLINK_KERNEL_SOCKET) { 436 422 netlink_table_grab(); 437 - p_tmp = nl_table[sk->sk_protocol].p_ops; 438 - if (p_tmp != &netlink_ops) { 439 - nl_table[sk->sk_protocol].p_ops = &netlink_ops; 440 - kfree(p_tmp); 441 - } 423 + nl_table[sk->sk_protocol].module = NULL; 442 424 netlink_table_ungrab(); 443 425 } 444 - 426 + 445 427 sock_put(sk); 446 428 return 0; 447 429 } ··· 1064 1060 struct sock * 1065 1061 netlink_kernel_create(int unit, void (*input)(struct sock *sk, int len), struct module *module) 1066 1062 { 1067 - struct proto_ops *p_ops; 1068 1063 struct socket *sock; 1069 1064 struct sock *sk; 1065 + struct netlink_sock *nlk; 1070 1066 1071 1067 if (!nl_table) 1072 1068 return NULL; ··· 1074 1070 if (unit<0 
|| unit>=MAX_LINKS) 1075 1071 return NULL; 1076 1072 1077 - /* Do a quick check, to make us not go down to netlink_insert() 1078 - * if protocol already has kernel socket. 1079 - */ 1080 - sk = netlink_lookup(unit, 0); 1081 - if (unlikely(sk)) { 1082 - sock_put(sk); 1083 - return NULL; 1084 - } 1085 - 1086 1073 if (sock_create_lite(PF_NETLINK, SOCK_DGRAM, unit, &sock)) 1087 1074 return NULL; 1088 1075 1089 - sk = NULL; 1090 - if (module) { 1091 - /* Every registering protocol implemented in a module needs 1092 - * it's own p_ops, since the socket code cannot deal with 1093 - * module refcounting otherwise. -HW 1094 - */ 1095 - p_ops = kmalloc(sizeof(*p_ops), GFP_KERNEL); 1096 - if (!p_ops) 1097 - goto out_sock_release; 1098 - 1099 - memcpy(p_ops, &netlink_ops, sizeof(*p_ops)); 1100 - p_ops->owner = module; 1101 - } else 1102 - p_ops = &netlink_ops; 1103 - 1104 - netlink_table_grab(); 1105 - nl_table[unit].p_ops = p_ops; 1106 - netlink_table_ungrab(); 1107 - 1108 - if (netlink_create(sock, unit) < 0) { 1109 - sk = NULL; 1110 - goto out_kfree_p_ops; 1111 - } 1076 + if (netlink_create(sock, unit) < 0) 1077 + goto out_sock_release; 1112 1078 1113 1079 sk = sock->sk; 1114 1080 sk->sk_data_ready = netlink_data_ready; 1115 1081 if (input) 1116 1082 nlk_sk(sk)->data_ready = input; 1117 1083 1118 - if (netlink_insert(sk, 0)) { 1119 - sk = NULL; 1120 - goto out_kfree_p_ops; 1121 - } 1084 + if (netlink_insert(sk, 0)) 1085 + goto out_sock_release; 1086 + 1087 + nlk = nlk_sk(sk); 1088 + nlk->flags |= NETLINK_KERNEL_SOCKET; 1089 + 1090 + netlink_table_grab(); 1091 + nl_table[unit].module = module; 1092 + netlink_table_ungrab(); 1122 1093 1123 1094 return sk; 1124 1095 1125 - out_kfree_p_ops: 1126 - netlink_table_grab(); 1127 - if (nl_table[unit].p_ops != &netlink_ops) { 1128 - kfree(nl_table[unit].p_ops); 1129 - nl_table[unit].p_ops = &netlink_ops; 1130 - } 1131 - netlink_table_ungrab(); 1132 1096 out_sock_release: 1133 1097 sock_release(sock); 1134 - return sk; 1098 + return 
NULL; 1135 1099 } 1136 1100 1137 1101 void netlink_set_nonroot(int protocol, unsigned int flags) ··· 1461 1489 1462 1490 for (i = 0; i < MAX_LINKS; i++) { 1463 1491 struct nl_pid_hash *hash = &nl_table[i].hash; 1464 - 1465 - nl_table[i].p_ops = &netlink_ops; 1466 1492 1467 1493 hash->table = nl_pid_hash_alloc(1 * sizeof(*hash->table)); 1468 1494 if (!hash->table) {