Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

inet: Sanitize inet{,6} protocol demux.

Don't pretend that inet_protos[] and inet6_protos[] are hashes, thay
are just a straight arrays. Remove all unnecessary hash masking.

Document MAX_INET_PROTOS.

Use RAW_HTABLE_SIZE when appropriate.

Reported-by: Ben Hutchings <bhutchings@solarflare.com>
Signed-off-by: David S. Miller <davem@davemloft.net>

+37 -48
+5 -2
include/net/protocol.h
··· 29 29 #include <linux/ipv6.h> 30 30 #endif 31 31 32 - #define MAX_INET_PROTOS 256 /* Must be a power of 2 */ 33 - 32 + /* This is one larger than the largest protocol value that can be 33 + * found in an ipv4 or ipv6 header. Since in both cases the protocol 34 + * value is presented in a __u8, this is defined to be 256. 35 + */ 36 + #define MAX_INET_PROTOS 256 34 37 35 38 /* This is used to register protocols. */ 36 39 struct net_protocol {
+13 -15
net/ipv4/af_inet.c
··· 242 242 } 243 243 EXPORT_SYMBOL(build_ehash_secret); 244 244 245 - static inline int inet_netns_ok(struct net *net, int protocol) 245 + static inline int inet_netns_ok(struct net *net, __u8 protocol) 246 246 { 247 - int hash; 248 247 const struct net_protocol *ipprot; 249 248 250 249 if (net_eq(net, &init_net)) 251 250 return 1; 252 251 253 - hash = protocol & (MAX_INET_PROTOS - 1); 254 - ipprot = rcu_dereference(inet_protos[hash]); 255 - 256 - if (ipprot == NULL) 252 + ipprot = rcu_dereference(inet_protos[protocol]); 253 + if (ipprot == NULL) { 257 254 /* raw IP is OK */ 258 255 return 1; 256 + } 259 257 return ipprot->netns_ok; 260 258 } 261 259 ··· 1214 1216 1215 1217 static int inet_gso_send_check(struct sk_buff *skb) 1216 1218 { 1217 - const struct iphdr *iph; 1218 1219 const struct net_protocol *ops; 1220 + const struct iphdr *iph; 1219 1221 int proto; 1220 1222 int ihl; 1221 1223 int err = -EINVAL; ··· 1234 1236 __skb_pull(skb, ihl); 1235 1237 skb_reset_transport_header(skb); 1236 1238 iph = ip_hdr(skb); 1237 - proto = iph->protocol & (MAX_INET_PROTOS - 1); 1239 + proto = iph->protocol; 1238 1240 err = -EPROTONOSUPPORT; 1239 1241 1240 1242 rcu_read_lock(); ··· 1251 1253 netdev_features_t features) 1252 1254 { 1253 1255 struct sk_buff *segs = ERR_PTR(-EINVAL); 1254 - struct iphdr *iph; 1255 1256 const struct net_protocol *ops; 1257 + struct iphdr *iph; 1256 1258 int proto; 1257 1259 int ihl; 1258 1260 int id; ··· 1284 1286 skb_reset_transport_header(skb); 1285 1287 iph = ip_hdr(skb); 1286 1288 id = ntohs(iph->id); 1287 - proto = iph->protocol & (MAX_INET_PROTOS - 1); 1289 + proto = iph->protocol; 1288 1290 segs = ERR_PTR(-EPROTONOSUPPORT); 1289 1291 1290 1292 rcu_read_lock(); ··· 1338 1340 goto out; 1339 1341 } 1340 1342 1341 - proto = iph->protocol & (MAX_INET_PROTOS - 1); 1343 + proto = iph->protocol; 1342 1344 1343 1345 rcu_read_lock(); 1344 1346 ops = rcu_dereference(inet_protos[proto]); ··· 1396 1398 1397 1399 static int inet_gro_complete(struct sk_buff *skb) 1398 1400 { 1399 - const struct net_protocol *ops; 1400 - struct iphdr *iph = ip_hdr(skb); 1401 - int proto = iph->protocol & (MAX_INET_PROTOS - 1); 1402 - int err = -ENOSYS; 1403 1401 __be16 newlen = htons(skb->len - skb_network_offset(skb)); 1402 + struct iphdr *iph = ip_hdr(skb); 1403 + const struct net_protocol *ops; 1404 + int proto = iph->protocol; 1405 + int err = -ENOSYS; 1404 1406 1405 1407 csum_replace2(&iph->check, iph->tot_len, newlen); 1406 1408 iph->tot_len = newlen;
+4 -5
net/ipv4/icmp.c
··· 637 637 638 638 static void icmp_unreach(struct sk_buff *skb) 639 639 { 640 + const struct net_protocol *ipprot; 640 641 const struct iphdr *iph; 641 642 struct icmphdr *icmph; 642 - int hash, protocol; 643 - const struct net_protocol *ipprot; 644 - u32 info = 0; 645 643 struct net *net; 644 + u32 info = 0; 645 + int protocol; 646 646 647 647 net = dev_net(skb_dst(skb)->dev); 648 648 ··· 731 731 */ 732 732 raw_icmp_error(skb, protocol, info); 733 733 734 - hash = protocol & (MAX_INET_PROTOS - 1); 735 734 rcu_read_lock(); 736 - ipprot = rcu_dereference(inet_protos[hash]); 735 + ipprot = rcu_dereference(inet_protos[protocol]); 737 736 if (ipprot && ipprot->err_handler) 738 737 ipprot->err_handler(skb, info); 739 738 rcu_read_unlock();
+2 -3
net/ipv4/ip_input.c
··· 198 198 rcu_read_lock(); 199 199 { 200 200 int protocol = ip_hdr(skb)->protocol; 201 - int hash, raw; 202 201 const struct net_protocol *ipprot; 202 + int raw; 203 203 204 204 resubmit: 205 205 raw = raw_local_deliver(skb, protocol); 206 206 207 - hash = protocol & (MAX_INET_PROTOS - 1); 208 - ipprot = rcu_dereference(inet_protos[hash]); 207 + ipprot = rcu_dereference(inet_protos[protocol]); 209 208 if (ipprot != NULL) { 210 209 int ret; 211 210
+3 -5
net/ipv4/protocol.c
··· 36 36 37 37 int inet_add_protocol(const struct net_protocol *prot, unsigned char protocol) 38 38 { 39 - int hash = protocol & (MAX_INET_PROTOS - 1); 40 - 41 - return !cmpxchg((const struct net_protocol **)&inet_protos[hash], 39 + return !cmpxchg((const struct net_protocol **)&inet_protos[protocol], 42 40 NULL, prot) ? 0 : -1; 43 41 } 44 42 EXPORT_SYMBOL(inet_add_protocol); ··· 47 49 48 50 int inet_del_protocol(const struct net_protocol *prot, unsigned char protocol) 49 51 { 50 - int ret, hash = protocol & (MAX_INET_PROTOS - 1); 52 + int ret; 51 53 52 - ret = (cmpxchg((const struct net_protocol **)&inet_protos[hash], 54 + ret = (cmpxchg((const struct net_protocol **)&inet_protos[protocol], 53 55 prot, NULL) == prot) ? 0 : -1; 54 56 55 57 synchronize_net();
+2 -5
net/ipv6/icmp.c
··· 600 600 { 601 601 const struct inet6_protocol *ipprot; 602 602 int inner_offset; 603 - int hash; 604 - u8 nexthdr; 605 603 __be16 frag_off; 604 + u8 nexthdr; 606 605 607 606 if (!pskb_may_pull(skb, sizeof(struct ipv6hdr))) 608 607 return; ··· 628 629 --ANK (980726) 629 630 */ 630 631 631 - hash = nexthdr & (MAX_INET_PROTOS - 1); 632 - 633 632 rcu_read_lock(); 634 - ipprot = rcu_dereference(inet6_protos[hash]); 633 + ipprot = rcu_dereference(inet6_protos[nexthdr]); 635 634 if (ipprot && ipprot->err_handler) 636 635 ipprot->err_handler(skb, NULL, type, code, inner_offset, info); 637 636 rcu_read_unlock();
+3 -6
net/ipv6/ip6_input.c
··· 168 168 169 169 static int ip6_input_finish(struct sk_buff *skb) 170 170 { 171 + struct net *net = dev_net(skb_dst(skb)->dev); 171 172 const struct inet6_protocol *ipprot; 173 + struct inet6_dev *idev; 172 174 unsigned int nhoff; 173 175 int nexthdr; 174 176 bool raw; 175 - u8 hash; 176 - struct inet6_dev *idev; 177 - struct net *net = dev_net(skb_dst(skb)->dev); 178 177 179 178 /* 180 179 * Parse extension headers ··· 188 189 nexthdr = skb_network_header(skb)[nhoff]; 189 190 190 191 raw = raw6_local_deliver(skb, nexthdr); 191 - 192 - hash = nexthdr & (MAX_INET_PROTOS - 1); 193 - if ((ipprot = rcu_dereference(inet6_protos[hash])) != NULL) { 192 + if ((ipprot = rcu_dereference(inet6_protos[nexthdr])) != NULL) { 194 193 int ret; 195 194 196 195 if (ipprot->flags & INET6_PROTO_FINAL) {
+3 -5
net/ipv6/protocol.c
··· 29 29 30 30 int inet6_add_protocol(const struct inet6_protocol *prot, unsigned char protocol) 31 31 { 32 - int hash = protocol & (MAX_INET_PROTOS - 1); 33 - 34 - return !cmpxchg((const struct inet6_protocol **)&inet6_protos[hash], 32 + return !cmpxchg((const struct inet6_protocol **)&inet6_protos[protocol], 35 33 NULL, prot) ? 0 : -1; 36 34 } 37 35 EXPORT_SYMBOL(inet6_add_protocol); ··· 40 42 41 43 int inet6_del_protocol(const struct inet6_protocol *prot, unsigned char protocol) 42 44 { 43 - int ret, hash = protocol & (MAX_INET_PROTOS - 1); 45 + int ret; 44 46 45 - ret = (cmpxchg((const struct inet6_protocol **)&inet6_protos[hash], 47 + ret = (cmpxchg((const struct inet6_protocol **)&inet6_protos[protocol], 46 48 prot, NULL) == prot) ? 0 : -1; 47 49 48 50 synchronize_net();
+2 -2
net/ipv6/raw.c
··· 165 165 saddr = &ipv6_hdr(skb)->saddr; 166 166 daddr = saddr + 1; 167 167 168 - hash = nexthdr & (MAX_INET_PROTOS - 1); 168 + hash = nexthdr & (RAW_HTABLE_SIZE - 1); 169 169 170 170 read_lock(&raw_v6_hashinfo.lock); 171 171 sk = sk_head(&raw_v6_hashinfo.ht[hash]); ··· 229 229 { 230 230 struct sock *raw_sk; 231 231 232 - raw_sk = sk_head(&raw_v6_hashinfo.ht[nexthdr & (MAX_INET_PROTOS - 1)]); 232 + raw_sk = sk_head(&raw_v6_hashinfo.ht[nexthdr & (RAW_HTABLE_SIZE - 1)]); 233 233 if (raw_sk && !ipv6_raw_deliver(skb, nexthdr)) 234 234 raw_sk = NULL; 235 235