Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

bridge: simplify ip_mc_check_igmp() and ipv6_mc_check_mld() internals

With this patch the internal use of the skb_trimmed is reduced to
the ICMPv6/IGMP checksum verification. And for the length checks
the newly introduced helper functions are used instead of calculating
and checking with skb->len directly.

These changes should hopefully make it easier to verify that length
checks are performed properly.

Signed-off-by: Linus Lüssing <linus.luessing@c0d3.blue>
Signed-off-by: David S. Miller <davem@davemloft.net>

authored by

Linus Lüssing and committed by
David S. Miller
a2e2ca3b ba5ea614

+52 -61
+22 -29
net/ipv4/igmp.c
··· 1493 1493 1494 1494 len += sizeof(struct igmpv3_report); 1495 1495 1496 - return pskb_may_pull(skb, len) ? 0 : -EINVAL; 1496 + return ip_mc_may_pull(skb, len) ? 0 : -EINVAL; 1497 1497 } 1498 1498 1499 1499 static int ip_mc_check_igmp_query(struct sk_buff *skb) 1500 1500 { 1501 - unsigned int len = skb_transport_offset(skb); 1502 - 1503 - len += sizeof(struct igmphdr); 1504 - if (skb->len < len) 1505 - return -EINVAL; 1501 + unsigned int transport_len = ip_transport_len(skb); 1502 + unsigned int len; 1506 1503 1507 1504 /* IGMPv{1,2}? */ 1508 - if (skb->len != len) { 1505 + if (transport_len != sizeof(struct igmphdr)) { 1509 1506 /* or IGMPv3? */ 1510 - len += sizeof(struct igmpv3_query) - sizeof(struct igmphdr); 1511 - if (skb->len < len || !pskb_may_pull(skb, len)) 1507 + if (transport_len < sizeof(struct igmpv3_query)) 1508 + return -EINVAL; 1509 + 1510 + len = skb_transport_offset(skb) + sizeof(struct igmpv3_query); 1511 + if (!ip_mc_may_pull(skb, len)) 1512 1512 return -EINVAL; 1513 1513 } 1514 1514 ··· 1544 1544 return skb_checksum_simple_validate(skb); 1545 1545 } 1546 1546 1547 - static int __ip_mc_check_igmp(struct sk_buff *skb) 1548 - 1547 + static int ip_mc_check_igmp_csum(struct sk_buff *skb) 1549 1548 { 1550 - struct sk_buff *skb_chk; 1551 - unsigned int transport_len; 1552 1549 unsigned int len = skb_transport_offset(skb) + sizeof(struct igmphdr); 1553 - int ret = -EINVAL; 1550 + unsigned int transport_len = ip_transport_len(skb); 1551 + struct sk_buff *skb_chk; 1554 1552 1555 - transport_len = ntohs(ip_hdr(skb)->tot_len) - ip_hdrlen(skb); 1553 + if (!ip_mc_may_pull(skb, len)) 1554 + return -EINVAL; 1556 1555 1557 1556 skb_chk = skb_checksum_trimmed(skb, transport_len, 1558 1557 ip_mc_validate_checksum); 1559 1558 if (!skb_chk) 1560 - goto err; 1559 + return -EINVAL; 1561 1560 1562 - if (!pskb_may_pull(skb_chk, len)) 1563 - goto err; 1564 - 1565 - ret = ip_mc_check_igmp_msg(skb_chk); 1566 - if (ret) 1567 - goto err; 1568 - 1569 - ret = 0; 1570 - 1571 - err: 1572 - if (skb_chk && skb_chk != skb) 1561 + if (skb_chk != skb) 1573 1562 kfree_skb(skb_chk); 1574 1563 1575 - return ret; 1564 + return 0; 1576 1565 } 1577 1566 1578 1567 /** ··· 1589 1600 if (ip_hdr(skb)->protocol != IPPROTO_IGMP) 1590 1601 return -ENOMSG; 1591 1602 1592 - return __ip_mc_check_igmp(skb); 1603 + ret = ip_mc_check_igmp_csum(skb); 1604 + if (ret < 0) 1605 + return ret; 1606 + 1607 + return ip_mc_check_igmp_msg(skb); 1593 1608 } 1594 1609 EXPORT_SYMBOL(ip_mc_check_igmp); 1595 1610
+30 -32
net/ipv6/mcast_snoop.c
··· 77 77 78 78 len += sizeof(struct mld2_report); 79 79 80 - return pskb_may_pull(skb, len) ? 0 : -EINVAL; 80 + return ipv6_mc_may_pull(skb, len) ? 0 : -EINVAL; 81 81 } 82 82 83 83 static int ipv6_mc_check_mld_query(struct sk_buff *skb) 84 84 { 85 + unsigned int transport_len = ipv6_transport_len(skb); 85 86 struct mld_msg *mld; 86 - unsigned int len = skb_transport_offset(skb); 87 + unsigned int len; 87 88 88 89 /* RFC2710+RFC3810 (MLDv1+MLDv2) require link-local source addresses */ 89 90 if (!(ipv6_addr_type(&ipv6_hdr(skb)->saddr) & IPV6_ADDR_LINKLOCAL)) 90 91 return -EINVAL; 91 92 92 - len += sizeof(struct mld_msg); 93 - if (skb->len < len) 94 - return -EINVAL; 95 - 96 93 /* MLDv1? */ 97 - if (skb->len != len) { 94 + if (transport_len != sizeof(struct mld_msg)) { 98 95 /* or MLDv2? */ 99 - len += sizeof(struct mld2_query) - sizeof(struct mld_msg); 100 - if (skb->len < len || !pskb_may_pull(skb, len)) 96 + if (transport_len < sizeof(struct mld2_query)) 97 + return -EINVAL; 98 + 99 + len = skb_transport_offset(skb) + sizeof(struct mld2_query); 100 + if (!ipv6_mc_may_pull(skb, len)) 101 101 return -EINVAL; 102 102 } 103 103 ··· 115 115 116 116 static int ipv6_mc_check_mld_msg(struct sk_buff *skb) 117 117 { 118 - struct mld_msg *mld = (struct mld_msg *)skb_transport_header(skb); 118 + unsigned int len = skb_transport_offset(skb) + sizeof(struct mld_msg); 119 + struct mld_msg *mld; 120 + 121 + if (!ipv6_mc_may_pull(skb, len)) 122 + return -EINVAL; 123 + 124 + mld = (struct mld_msg *)skb_transport_header(skb); 119 125 120 126 switch (mld->mld_type) { 121 127 case ICMPV6_MGM_REDUCTION: ··· 142 136 return skb_checksum_validate(skb, IPPROTO_ICMPV6, ip6_compute_pseudo); 143 137 } 144 138 145 - static int __ipv6_mc_check_mld(struct sk_buff *skb) 146 - 139 + static int ipv6_mc_check_icmpv6(struct sk_buff *skb) 147 140 { 148 - struct sk_buff *skb_chk = NULL; 149 - unsigned int transport_len; 150 - unsigned int len = skb_transport_offset(skb) + sizeof(struct mld_msg); 151 - int ret = -EINVAL; 141 + unsigned int len = skb_transport_offset(skb) + sizeof(struct icmp6hdr); 142 + unsigned int transport_len = ipv6_transport_len(skb); 143 + struct sk_buff *skb_chk; 152 144 153 - transport_len = ntohs(ipv6_hdr(skb)->payload_len); 154 - transport_len -= skb_transport_offset(skb) - sizeof(struct ipv6hdr); 145 + if (!ipv6_mc_may_pull(skb, len)) 146 + return -EINVAL; 155 147 156 148 skb_chk = skb_checksum_trimmed(skb, transport_len, 157 149 ipv6_mc_validate_checksum); 158 150 if (!skb_chk) 159 - goto err; 151 + return -EINVAL; 160 152 161 - if (!pskb_may_pull(skb_chk, len)) 162 - goto err; 163 - 164 - ret = ipv6_mc_check_mld_msg(skb_chk); 165 - if (ret) 166 - goto err; 167 - 168 - ret = 0; 169 - 170 - err: 171 - if (skb_chk && skb_chk != skb) 153 + if (skb_chk != skb) 172 154 kfree_skb(skb_chk); 173 155 174 - return ret; 156 + return 0; 175 157 } 176 158 177 159 /** ··· 189 195 if (ret < 0) 190 196 return ret; 191 197 192 - return __ipv6_mc_check_mld(skb); 198 + ret = ipv6_mc_check_icmpv6(skb); 199 + if (ret < 0) 200 + return ret; 201 + 202 + return ipv6_mc_check_mld_msg(skb); 193 203 } 194 204 EXPORT_SYMBOL(ipv6_mc_check_mld);