Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

net/tcp: Disable TCP-MD5 static key on tcp_md5sig_info destruction

To do that, separate two scenarios:
- adding the first MD5 key on the system, where enabling the static key
may need to sleep;
- copying an existing key from a listening socket to the request
socket upon receiving a signed TCP segment, where the static key was
already enabled (when the key was added to the listening socket).

Now the static branch for TCP-MD5 stays enabled until both:
- the last tcp_md5sig_info is destroyed;
- the last socket in time-wait state with an MD5 key is closed.

This means that after all sockets with TCP-MD5 keys are gone, the
system regains the performance of a disabled TCP-MD5 static branch.

While here, use the static_key_fast_inc_not_disabled() helper, which
increments the ref counter atomically (without grabbing cpus_read_lock()
on CONFIG_JUMP_LABEL=y). This is needed to add a new user of
a static_key when the caller controls the lifetime of another user.

Signed-off-by: Dmitry Safonov <dima@arista.com>
Acked-by: Jakub Kicinski <kuba@kernel.org>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Signed-off-by: Jakub Kicinski <kuba@kernel.org>

authored by

Dmitry Safonov and committed by
Jakub Kicinski
459837b5 f62c7517

+84 -32
+7 -3
include/net/tcp.h
··· 1675 1675 const struct sock *sk, const struct sk_buff *skb); 1676 1676 int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, 1677 1677 int family, u8 prefixlen, int l3index, u8 flags, 1678 - const u8 *newkey, u8 newkeylen, gfp_t gfp); 1678 + const u8 *newkey, u8 newkeylen); 1679 + int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr, 1680 + int family, u8 prefixlen, int l3index, 1681 + struct tcp_md5sig_key *key); 1682 + 1679 1683 int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, 1680 1684 int family, u8 prefixlen, int l3index, u8 flags); 1681 1685 struct tcp_md5sig_key *tcp_v4_md5_lookup(const struct sock *sk, ··· 1687 1683 1688 1684 #ifdef CONFIG_TCP_MD5SIG 1689 1685 #include <linux/jump_label.h> 1690 - extern struct static_key_false tcp_md5_needed; 1686 + extern struct static_key_false_deferred tcp_md5_needed; 1691 1687 struct tcp_md5sig_key *__tcp_md5_do_lookup(const struct sock *sk, int l3index, 1692 1688 const union tcp_md5_addr *addr, 1693 1689 int family); ··· 1695 1691 tcp_md5_do_lookup(const struct sock *sk, int l3index, 1696 1692 const union tcp_md5_addr *addr, int family) 1697 1693 { 1698 - if (!static_branch_unlikely(&tcp_md5_needed)) 1694 + if (!static_branch_unlikely(&tcp_md5_needed.key)) 1699 1695 return NULL; 1700 1696 return __tcp_md5_do_lookup(sk, l3index, addr, family); 1701 1697 }
+1 -4
net/ipv4/tcp.c
··· 4464 4464 if (unlikely(!READ_ONCE(tcp_md5sig_pool_populated))) { 4465 4465 mutex_lock(&tcp_md5sig_mutex); 4466 4466 4467 - if (!tcp_md5sig_pool_populated) { 4467 + if (!tcp_md5sig_pool_populated) 4468 4468 __tcp_alloc_md5sig_pool(); 4469 - if (tcp_md5sig_pool_populated) 4470 - static_branch_inc(&tcp_md5_needed); 4471 - } 4472 4469 4473 4470 mutex_unlock(&tcp_md5sig_mutex); 4474 4471 }
+58 -13
net/ipv4/tcp_ipv4.c
··· 1053 1053 * We need to maintain these in the sk structure. 1054 1054 */ 1055 1055 1056 - DEFINE_STATIC_KEY_FALSE(tcp_md5_needed); 1056 + DEFINE_STATIC_KEY_DEFERRED_FALSE(tcp_md5_needed, HZ); 1057 1057 EXPORT_SYMBOL(tcp_md5_needed); 1058 1058 1059 1059 static bool better_md5_match(struct tcp_md5sig_key *old, struct tcp_md5sig_key *new) ··· 1166 1166 struct tcp_sock *tp = tcp_sk(sk); 1167 1167 struct tcp_md5sig_info *md5sig; 1168 1168 1169 - if (rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) 1170 - return 0; 1171 - 1172 1169 md5sig = kmalloc(sizeof(*md5sig), gfp); 1173 1170 if (!md5sig) 1174 1171 return -ENOMEM; ··· 1177 1180 } 1178 1181 1179 1182 /* This can be called on a newly created socket, from other files */ 1180 - int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, 1181 - int family, u8 prefixlen, int l3index, u8 flags, 1182 - const u8 *newkey, u8 newkeylen, gfp_t gfp) 1183 + static int __tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, 1184 + int family, u8 prefixlen, int l3index, u8 flags, 1185 + const u8 *newkey, u8 newkeylen, gfp_t gfp) 1183 1186 { 1184 1187 /* Add Key to the list */ 1185 1188 struct tcp_md5sig_key *key; ··· 1206 1209 return 0; 1207 1210 } 1208 1211 1209 - if (tcp_md5sig_info_add(sk, gfp)) 1210 - return -ENOMEM; 1211 - 1212 1212 md5sig = rcu_dereference_protected(tp->md5sig_info, 1213 1213 lockdep_sock_is_held(sk)); 1214 1214 ··· 1229 1235 hlist_add_head_rcu(&key->node, &md5sig->head); 1230 1236 return 0; 1231 1237 } 1238 + 1239 + int tcp_md5_do_add(struct sock *sk, const union tcp_md5_addr *addr, 1240 + int family, u8 prefixlen, int l3index, u8 flags, 1241 + const u8 *newkey, u8 newkeylen) 1242 + { 1243 + struct tcp_sock *tp = tcp_sk(sk); 1244 + 1245 + if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) { 1246 + if (tcp_md5sig_info_add(sk, GFP_KERNEL)) 1247 + return -ENOMEM; 1248 + 1249 + if (!static_branch_inc(&tcp_md5_needed.key)) { 1250 + struct 
tcp_md5sig_info *md5sig; 1251 + 1252 + md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)); 1253 + rcu_assign_pointer(tp->md5sig_info, NULL); 1254 + kfree_rcu(md5sig); 1255 + return -EUSERS; 1256 + } 1257 + } 1258 + 1259 + return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, flags, 1260 + newkey, newkeylen, GFP_KERNEL); 1261 + } 1232 1262 EXPORT_SYMBOL(tcp_md5_do_add); 1263 + 1264 + int tcp_md5_key_copy(struct sock *sk, const union tcp_md5_addr *addr, 1265 + int family, u8 prefixlen, int l3index, 1266 + struct tcp_md5sig_key *key) 1267 + { 1268 + struct tcp_sock *tp = tcp_sk(sk); 1269 + 1270 + if (!rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk))) { 1271 + if (tcp_md5sig_info_add(sk, sk_gfp_mask(sk, GFP_ATOMIC))) 1272 + return -ENOMEM; 1273 + 1274 + if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) { 1275 + struct tcp_md5sig_info *md5sig; 1276 + 1277 + md5sig = rcu_dereference_protected(tp->md5sig_info, lockdep_sock_is_held(sk)); 1278 + net_warn_ratelimited("Too many TCP-MD5 keys in the system\n"); 1279 + rcu_assign_pointer(tp->md5sig_info, NULL); 1280 + kfree_rcu(md5sig); 1281 + return -EUSERS; 1282 + } 1283 + } 1284 + 1285 + return __tcp_md5_do_add(sk, addr, family, prefixlen, l3index, 1286 + key->flags, key->key, key->keylen, 1287 + sk_gfp_mask(sk, GFP_ATOMIC)); 1288 + } 1289 + EXPORT_SYMBOL(tcp_md5_key_copy); 1233 1290 1234 1291 int tcp_md5_do_del(struct sock *sk, const union tcp_md5_addr *addr, int family, 1235 1292 u8 prefixlen, int l3index, u8 flags) ··· 1368 1323 return -EINVAL; 1369 1324 1370 1325 return tcp_md5_do_add(sk, addr, AF_INET, prefixlen, l3index, flags, 1371 - cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL); 1326 + cmd.tcpm_key, cmd.tcpm_keylen); 1372 1327 } 1373 1328 1374 1329 static int tcp_v4_md5_hash_headers(struct tcp_md5sig_pool *hp, ··· 1625 1580 * memory, then we end up not copying the key 1626 1581 * across. Shucks. 
1627 1582 */ 1628 - tcp_md5_do_add(newsk, addr, AF_INET, 32, l3index, key->flags, 1629 - key->key, key->keylen, GFP_ATOMIC); 1583 + tcp_md5_key_copy(newsk, addr, AF_INET, 32, l3index, key); 1630 1584 sk_gso_disable(newsk); 1631 1585 } 1632 1586 #endif ··· 2317 2273 tcp_clear_md5_list(sk); 2318 2274 kfree_rcu(rcu_dereference_protected(tp->md5sig_info, 1), rcu); 2319 2275 tp->md5sig_info = NULL; 2276 + static_branch_slow_dec_deferred(&tcp_md5_needed); 2320 2277 } 2321 2278 #endif 2322 2279
+12 -4
net/ipv4/tcp_minisocks.c
··· 291 291 */ 292 292 do { 293 293 tcptw->tw_md5_key = NULL; 294 - if (static_branch_unlikely(&tcp_md5_needed)) { 294 + if (static_branch_unlikely(&tcp_md5_needed.key)) { 295 295 struct tcp_md5sig_key *key; 296 296 297 297 key = tp->af_specific->md5_lookup(sk, sk); 298 298 if (key) { 299 299 tcptw->tw_md5_key = kmemdup(key, sizeof(*key), GFP_ATOMIC); 300 - BUG_ON(tcptw->tw_md5_key && !tcp_alloc_md5sig_pool()); 300 + if (!tcptw->tw_md5_key) 301 + break; 302 + BUG_ON(!tcp_alloc_md5sig_pool()); 303 + if (!static_key_fast_inc_not_disabled(&tcp_md5_needed.key.key)) { 304 + kfree(tcptw->tw_md5_key); 305 + tcptw->tw_md5_key = NULL; 306 + } 301 307 } 302 308 } 303 309 } while (0); ··· 343 337 void tcp_twsk_destructor(struct sock *sk) 344 338 { 345 339 #ifdef CONFIG_TCP_MD5SIG 346 - if (static_branch_unlikely(&tcp_md5_needed)) { 340 + if (static_branch_unlikely(&tcp_md5_needed.key)) { 347 341 struct tcp_timewait_sock *twsk = tcp_twsk(sk); 348 342 349 - if (twsk->tw_md5_key) 343 + if (twsk->tw_md5_key) { 350 344 kfree_rcu(twsk->tw_md5_key, rcu); 345 + static_branch_slow_dec_deferred(&tcp_md5_needed); 346 + } 351 347 } 352 348 #endif 353 349 }
+2 -2
net/ipv4/tcp_output.c
··· 766 766 767 767 *md5 = NULL; 768 768 #ifdef CONFIG_TCP_MD5SIG 769 - if (static_branch_unlikely(&tcp_md5_needed) && 769 + if (static_branch_unlikely(&tcp_md5_needed.key) && 770 770 rcu_access_pointer(tp->md5sig_info)) { 771 771 *md5 = tp->af_specific->md5_lookup(sk, sk); 772 772 if (*md5) { ··· 922 922 923 923 *md5 = NULL; 924 924 #ifdef CONFIG_TCP_MD5SIG 925 - if (static_branch_unlikely(&tcp_md5_needed) && 925 + if (static_branch_unlikely(&tcp_md5_needed.key) && 926 926 rcu_access_pointer(tp->md5sig_info)) { 927 927 *md5 = tp->af_specific->md5_lookup(sk, sk); 928 928 if (*md5) {
+4 -6
net/ipv6/tcp_ipv6.c
··· 665 665 if (ipv6_addr_v4mapped(&sin6->sin6_addr)) 666 666 return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr.s6_addr32[3], 667 667 AF_INET, prefixlen, l3index, flags, 668 - cmd.tcpm_key, cmd.tcpm_keylen, 669 - GFP_KERNEL); 668 + cmd.tcpm_key, cmd.tcpm_keylen); 670 669 671 670 return tcp_md5_do_add(sk, (union tcp_md5_addr *)&sin6->sin6_addr, 672 671 AF_INET6, prefixlen, l3index, flags, 673 - cmd.tcpm_key, cmd.tcpm_keylen, GFP_KERNEL); 672 + cmd.tcpm_key, cmd.tcpm_keylen); 674 673 } 675 674 676 675 static int tcp_v6_md5_hash_headers(struct tcp_md5sig_pool *hp, ··· 1369 1370 * memory, then we end up not copying the key 1370 1371 * across. Shucks. 1371 1372 */ 1372 - tcp_md5_do_add(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr, 1373 - AF_INET6, 128, l3index, key->flags, key->key, key->keylen, 1374 - sk_gfp_mask(sk, GFP_ATOMIC)); 1373 + tcp_md5_key_copy(newsk, (union tcp_md5_addr *)&newsk->sk_v6_daddr, 1374 + AF_INET6, 128, l3index, key); 1375 1375 } 1376 1376 #endif 1377 1377