[NET]: Fix module reference counts for loadable protocol modules

I have been experimenting with loadable protocol modules, and ran into
several issues with module reference counting.

The first issue was that __module_get failed at the BUG_ON check at
the top of the routine (checking that my module reference count was
not zero) when I created the first socket. When sk_alloc() is called,
my module reference count was still 0. When I looked at why sctp
didn't have this problem, I discovered that sctp creates a control
socket during module init (when the module ref count is not 0), which
keeps the reference count non-zero. This section has been updated to
address the point Stephen raised about checking the return value of
try_module_get().

The next problem arose when my socket init routine returned an error.
This resulted in my module reference count being decremented below 0.
My socket ops->release routine was also being called. The issue here
is that sock_release() calls the ops->release routine and decrements
the ref count if sock->ops is not NULL. Since the socket probably
didn't get correctly initialized, this should not be done, so we will
set sock->ops to NULL because we will not call try_module_get().

While searching for another bug, I also noticed that sys_accept() has
a possibility of doing a module_put() when it did not do an
__module_get so I re-ordered the call to security_socket_accept().

Signed-off-by: Frank Filz <ffilzlnx@us.ibm.com>
Signed-off-by: David S. Miller <davem@davemloft.net>

authored by Frank Filz and committed by David S. Miller a79af59e 9356b8fc

+20 -13
+12 -8
net/core/sock.c
··· 660 660 sock_lock_init(sk); 661 661 } 662 662 663 - if (security_sk_alloc(sk, family, priority)) { 664 - if (slab != NULL) 665 - kmem_cache_free(slab, sk); 666 - else 667 - kfree(sk); 668 - sk = NULL; 669 - } else 670 - __module_get(prot->owner); 663 + if (security_sk_alloc(sk, family, priority)) 664 + goto out_free; 665 + 666 + if (!try_module_get(prot->owner)) 667 + goto out_free; 671 668 } 672 669 return sk; 670 + 671 + out_free: 672 + if (slab != NULL) 673 + kmem_cache_free(slab, sk); 674 + else 675 + kfree(sk); 676 + return NULL; 673 677 } 674 678 675 679 void sk_free(struct sock *sk)
+8 -5
net/socket.c
··· 1145 1145 if (!try_module_get(net_families[family]->owner)) 1146 1146 goto out_release; 1147 1147 1148 - if ((err = net_families[family]->create(sock, protocol)) < 0) 1148 + if ((err = net_families[family]->create(sock, protocol)) < 0) { 1149 + sock->ops = NULL; 1149 1150 goto out_module_put; 1151 + } 1152 + 1150 1153 /* 1151 1154 * Now to bump the refcnt of the [loadable] module that owns this 1152 1155 * socket at sock_release time we decrement its refcnt. ··· 1363 1360 newsock->type = sock->type; 1364 1361 newsock->ops = sock->ops; 1365 1362 1366 - err = security_socket_accept(sock, newsock); 1367 - if (err) 1368 - goto out_release; 1369 - 1370 1363 /* 1371 1364 * We don't need try_module_get here, as the listening socket (sock) 1372 1365 * has the protocol module (sock->ops->owner) held. 1373 1366 */ 1374 1367 __module_get(newsock->ops->owner); 1368 + 1369 + err = security_socket_accept(sock, newsock); 1370 + if (err) 1371 + goto out_release; 1375 1372 1376 1373 err = sock->ops->accept(sock, newsock, sock->file->f_flags); 1377 1374 if (err < 0)