Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

hv_netvsc: Copy packets sent by Hyper-V out of the receive buffer

Pointers to receive-buffer packets sent by Hyper-V are used within the
guest VM. Hyper-V can send packets with erroneous values or modify
packet fields after they are processed by the guest. To defend against
these scenarios, copy (sections of) the incoming packet after validating
their length and offset fields in netvsc_filter_receive(). In this way,
the packet can no longer be modified by the host.

Reported-by: Juan Vazquez <juvazq@microsoft.com>
Signed-off-by: Andrea Parri (Microsoft) <parri.andrea@gmail.com>
Link: https://lore.kernel.org/r/20210126162907.21056-1-parri.andrea@gmail.com
Signed-off-by: Jakub Kicinski <kuba@kernel.org>

authored by

Andrea Parri (Microsoft) and committed by
Jakub Kicinski
0ba35fe9 46eb3c10

+150 -86
+48 -45
drivers/net/hyperv/hyperv_net.h
··· 105 105 u32 processor_masks_entry_size; 106 106 }; 107 107 108 - /* Fwd declaration */ 109 - struct ndis_tcp_ip_checksum_info; 110 - struct ndis_pkt_8021q_info; 108 + struct ndis_tcp_ip_checksum_info { 109 + union { 110 + struct { 111 + u32 is_ipv4:1; 112 + u32 is_ipv6:1; 113 + u32 tcp_checksum:1; 114 + u32 udp_checksum:1; 115 + u32 ip_header_checksum:1; 116 + u32 reserved:11; 117 + u32 tcp_header_offset:10; 118 + } transmit; 119 + struct { 120 + u32 tcp_checksum_failed:1; 121 + u32 udp_checksum_failed:1; 122 + u32 ip_checksum_failed:1; 123 + u32 tcp_checksum_succeeded:1; 124 + u32 udp_checksum_succeeded:1; 125 + u32 ip_checksum_succeeded:1; 126 + u32 loopback:1; 127 + u32 tcp_checksum_value_invalid:1; 128 + u32 ip_checksum_value_invalid:1; 129 + } receive; 130 + u32 value; 131 + }; 132 + }; 133 + 134 + struct ndis_pkt_8021q_info { 135 + union { 136 + struct { 137 + u32 pri:3; /* User Priority */ 138 + u32 cfi:1; /* Canonical Format ID */ 139 + u32 vlanid:12; /* VLAN ID */ 140 + u32 reserved:16; 141 + }; 142 + u32 value; 143 + }; 144 + }; 111 145 112 146 /* 113 147 * Represent netvsc packet which contains 1 RNDIS and 1 ethernet frame ··· 228 194 struct sk_buff *skb, 229 195 bool xdp_tx); 230 196 void netvsc_linkstatus_callback(struct net_device *net, 231 - struct rndis_message *resp); 197 + struct rndis_message *resp, 198 + void *data); 232 199 int netvsc_recv_callback(struct net_device *net, 233 200 struct netvsc_device *nvdev, 234 201 struct netvsc_channel *nvchan); ··· 919 884 #define NVSP_RSC_MAX 562 /* Max #RSC frags in a vmbus xfer page pkt */ 920 885 921 886 struct nvsc_rsc { 922 - const struct ndis_pkt_8021q_info *vlan; 923 - const struct ndis_tcp_ip_checksum_info *csum_info; 924 - const u32 *hash_info; 887 + struct ndis_pkt_8021q_info vlan; 888 + struct ndis_tcp_ip_checksum_info csum_info; 889 + u32 hash_info; 890 + u8 ppi_flags; /* valid/present bits for the above PPIs */ 925 891 u8 is_last; /* last RNDIS msg in a vmtransfer_page */ 926 892 u32 cnt; /* #fragments in an RSC packet */ 927 893 u32 pktlen; /* Full packet length */ 928 894 void *data[NVSP_RSC_MAX]; 929 895 u32 len[NVSP_RSC_MAX]; 930 896 }; 897 + 898 + #define NVSC_RSC_VLAN BIT(0) /* valid/present bit for 'vlan' */ 899 + #define NVSC_RSC_CSUM_INFO BIT(1) /* valid/present bit for 'csum_info' */ 900 + #define NVSC_RSC_HASH_INFO BIT(2) /* valid/present bit for 'hash_info' */ 931 901 932 902 struct netvsc_stats { 933 903 u64 packets; ··· 1042 1002 struct netvsc_channel { 1043 1003 struct vmbus_channel *channel; 1044 1004 struct netvsc_device *net_device; 1005 + void *recv_buf; /* buffer to copy packets out from the receive buffer */ 1045 1006 const struct vmpacket_descriptor *desc; 1046 1007 struct napi_struct napi; 1047 1008 struct multi_send_data msd; ··· 1275 1234 u16 pkt_id; 1276 1235 }; 1277 1236 1278 - struct ndis_pkt_8021q_info { 1279 - union { 1280 - struct { 1281 - u32 pri:3; /* User Priority */ 1282 - u32 cfi:1; /* Canonical Format ID */ 1283 - u32 vlanid:12; /* VLAN ID */ 1284 - u32 reserved:16; 1285 - }; 1286 - u32 value; 1287 - }; 1288 - }; 1289 - 1290 1237 struct ndis_object_header { 1291 1238 u8 type; 1292 1239 u8 revision; ··· 1462 1433 struct { 1463 1434 u8 encapsulated_packet_task_offload; 1464 1435 u8 encapsulation_types; 1465 - }; 1466 - }; 1467 - 1468 - struct ndis_tcp_ip_checksum_info { 1469 - union { 1470 - struct { 1471 - u32 is_ipv4:1; 1472 - u32 is_ipv6:1; 1473 - u32 tcp_checksum:1; 1474 - u32 udp_checksum:1; 1475 - u32 ip_header_checksum:1; 1476 - u32 reserved:11; 1477 - u32 tcp_header_offset:10; 1478 - } transmit; 1479 - struct { 1480 - u32 tcp_checksum_failed:1; 1481 - u32 udp_checksum_failed:1; 1482 - u32 ip_checksum_failed:1; 1483 - u32 tcp_checksum_succeeded:1; 1484 - u32 udp_checksum_succeeded:1; 1485 - u32 ip_checksum_succeeded:1; 1486 - u32 loopback:1; 1487 - u32 tcp_checksum_value_invalid:1; 1488 - u32 ip_checksum_value_invalid:1; 1489 - } receive; 1490 - u32 value; 1491 1436 }; 1492 1437 }; 1493 1438
+20
drivers/net/hyperv/netvsc.c
··· 131 131 132 132 for (i = 0; i < VRSS_CHANNEL_MAX; i++) { 133 133 xdp_rxq_info_unreg(&nvdev->chan_table[i].xdp_rxq); 134 + kfree(nvdev->chan_table[i].recv_buf); 134 135 vfree(nvdev->chan_table[i].mrc.slots); 135 136 } 136 137 ··· 1285 1284 continue; 1286 1285 } 1287 1286 1287 + /* We're going to copy (sections of) the packet into nvchan->recv_buf; 1288 + * make sure that nvchan->recv_buf is large enough to hold the packet. 1289 + */ 1290 + if (unlikely(buflen > net_device->recv_section_size)) { 1291 + nvchan->rsc.cnt = 0; 1292 + status = NVSP_STAT_FAIL; 1293 + netif_err(net_device_ctx, rx_err, ndev, 1294 + "Packet too big: buflen=%u recv_section_size=%u\n", 1295 + buflen, net_device->recv_section_size); 1296 + 1297 + continue; 1298 + } 1299 + 1288 1300 data = recv_buf + offset; 1289 1301 1290 1302 nvchan->rsc.is_last = (i == count - 1); ··· 1548 1534 1549 1535 for (i = 0; i < VRSS_CHANNEL_MAX; i++) { 1550 1536 struct netvsc_channel *nvchan = &net_device->chan_table[i]; 1537 + 1538 + nvchan->recv_buf = kzalloc(device_info->recv_section_size, GFP_KERNEL); 1539 + if (nvchan->recv_buf == NULL) { 1540 + ret = -ENOMEM; 1541 + goto cleanup2; 1542 + } 1551 1543 1552 1544 nvchan->channel = device->channel; 1553 1545 nvchan->net_device = net_device;
+14 -10
drivers/net/hyperv/netvsc_drv.c
··· 743 743 * netvsc_linkstatus_callback - Link up/down notification 744 744 */ 745 745 void netvsc_linkstatus_callback(struct net_device *net, 746 - struct rndis_message *resp) 746 + struct rndis_message *resp, 747 + void *data) 747 748 { 748 749 struct rndis_indicate_status *indicate = &resp->msg.indicate_status; 749 750 struct net_device_context *ndev_ctx = netdev_priv(net); ··· 757 756 resp->msg_len); 758 757 return; 759 758 } 759 + 760 + /* Copy the RNDIS indicate status into nvchan->recv_buf */ 761 + memcpy(indicate, data + RNDIS_HEADER_SIZE, sizeof(*indicate)); 760 762 761 763 /* Update the physical link speed when changing to another vSwitch */ 762 764 if (indicate->status == RNDIS_STATUS_LINK_SPEED_CHANGE) { ··· 775 771 return; 776 772 } 777 773 778 - speed = *(u32 *)((void *)indicate 779 - + indicate->status_buf_offset) / 10000; 774 + speed = *(u32 *)(data + RNDIS_HEADER_SIZE + indicate->status_buf_offset) / 10000; 780 775 ndev_ctx->speed = speed; 781 776 return; 782 777 } ··· 830 827 struct xdp_buff *xdp) 831 828 { 832 829 struct napi_struct *napi = &nvchan->napi; 833 - const struct ndis_pkt_8021q_info *vlan = nvchan->rsc.vlan; 830 + const struct ndis_pkt_8021q_info *vlan = &nvchan->rsc.vlan; 834 831 const struct ndis_tcp_ip_checksum_info *csum_info = 835 - nvchan->rsc.csum_info; 836 - const u32 *hash_info = nvchan->rsc.hash_info; 832 + &nvchan->rsc.csum_info; 833 + const u32 *hash_info = &nvchan->rsc.hash_info; 834 + u8 ppi_flags = nvchan->rsc.ppi_flags; 837 835 struct sk_buff *skb; 838 836 void *xbuf = xdp->data_hard_start; 839 837 int i; ··· 878 874 * We compute it here if the flags are set, because on Linux, the IP 879 875 * checksum is always checked. 880 876 */ 881 - if (csum_info && csum_info->receive.ip_checksum_value_invalid && 877 + if ((ppi_flags & NVSC_RSC_CSUM_INFO) && csum_info->receive.ip_checksum_value_invalid && 882 878 csum_info->receive.ip_checksum_succeeded && 883 879 skb->protocol == htons(ETH_P_IP)) { 884 880 /* Check that there is enough space to hold the IP header. */ ··· 890 886 } 891 887 892 888 /* Do L4 checksum offload if enabled and present. */ 893 - if (csum_info && (net->features & NETIF_F_RXCSUM)) { 889 + if ((ppi_flags & NVSC_RSC_CSUM_INFO) && (net->features & NETIF_F_RXCSUM)) { 894 890 if (csum_info->receive.tcp_checksum_succeeded || 895 891 csum_info->receive.udp_checksum_succeeded) 896 892 skb->ip_summed = CHECKSUM_UNNECESSARY; 897 893 } 898 894 899 - if (hash_info && (net->features & NETIF_F_RXHASH)) 895 + if ((ppi_flags & NVSC_RSC_HASH_INFO) && (net->features & NETIF_F_RXHASH)) 900 896 skb_set_hash(skb, *hash_info, PKT_HASH_TYPE_L4); 901 897 902 - if (vlan) { 898 + if (ppi_flags & NVSC_RSC_VLAN) { 903 899 u16 vlan_tci = vlan->vlanid | (vlan->pri << VLAN_PRIO_SHIFT) | 904 900 (vlan->cfi ? VLAN_CFI_MASK : 0); 905 901
+68 -31
drivers/net/hyperv/rndis_filter.c
··· 127 127 } 128 128 129 129 static void dump_rndis_message(struct net_device *netdev, 130 - const struct rndis_message *rndis_msg) 130 + const struct rndis_message *rndis_msg, 131 + const void *data) 131 132 { 132 133 switch (rndis_msg->ndis_msg_type) { 133 134 case RNDIS_MSG_PACKET: 134 135 if (rndis_msg->msg_len - RNDIS_HEADER_SIZE >= sizeof(struct rndis_packet)) { 135 - const struct rndis_packet *pkt = &rndis_msg->msg.pkt; 136 + const struct rndis_packet *pkt = data + RNDIS_HEADER_SIZE; 136 137 netdev_dbg(netdev, "RNDIS_MSG_PACKET (len %u, " 137 138 "data offset %u data len %u, # oob %u, " 138 139 "oob offset %u, oob len %u, pkt offset %u, " ··· 153 152 if (rndis_msg->msg_len - RNDIS_HEADER_SIZE >= 154 153 sizeof(struct rndis_initialize_complete)) { 155 154 const struct rndis_initialize_complete *init_complete = 156 - &rndis_msg->msg.init_complete; 155 + data + RNDIS_HEADER_SIZE; 157 156 netdev_dbg(netdev, "RNDIS_MSG_INIT_C " 158 157 "(len %u, id 0x%x, status 0x%x, major %d, minor %d, " 159 158 "device flags %d, max xfer size 0x%x, max pkts %u, " ··· 174 173 if (rndis_msg->msg_len - RNDIS_HEADER_SIZE >= 175 174 sizeof(struct rndis_query_complete)) { 176 175 const struct rndis_query_complete *query_complete = 177 - &rndis_msg->msg.query_complete; 176 + data + RNDIS_HEADER_SIZE; 178 177 netdev_dbg(netdev, "RNDIS_MSG_QUERY_C " 179 178 "(len %u, id 0x%x, status 0x%x, buf len %u, " 180 179 "buf offset %u)\n", ··· 189 188 case RNDIS_MSG_SET_C: 190 189 if (rndis_msg->msg_len - RNDIS_HEADER_SIZE + sizeof(struct rndis_set_complete)) { 191 190 const struct rndis_set_complete *set_complete = 192 - &rndis_msg->msg.set_complete; 191 + data + RNDIS_HEADER_SIZE; 193 192 netdev_dbg(netdev, 194 193 "RNDIS_MSG_SET_C (len %u, id 0x%x, status 0x%x)\n", 195 194 rndis_msg->msg_len, ··· 202 201 if (rndis_msg->msg_len - RNDIS_HEADER_SIZE >= 203 202 sizeof(struct rndis_indicate_status)) { 204 203 const struct rndis_indicate_status *indicate_status = 205 - &rndis_msg->msg.indicate_status; 204 + data + RNDIS_HEADER_SIZE; 206 205 netdev_dbg(netdev, "RNDIS_MSG_INDICATE " 207 206 "(len %u, status 0x%x, buf len %u, buf offset %u)\n", 208 207 rndis_msg->msg_len, ··· 287 286 288 287 static void rndis_filter_receive_response(struct net_device *ndev, 289 288 struct netvsc_device *nvdev, 290 - const struct rndis_message *resp) 289 + struct rndis_message *resp, 290 + void *data) 291 291 { 292 + u32 *req_id = &resp->msg.init_complete.req_id; 292 293 struct rndis_device *dev = nvdev->extension; 293 294 struct rndis_request *request = NULL; 294 295 bool found = false; ··· 315 312 return; 316 313 } 317 314 315 + /* Copy the request ID into nvchan->recv_buf */ 316 + *req_id = *(u32 *)(data + RNDIS_HEADER_SIZE); 317 + 318 318 spin_lock_irqsave(&dev->request_lock, flags); 319 319 list_for_each_entry(request, &dev->req_list, list_ent) { 320 320 /* 321 321 * All request/response message contains RequestId as the 1st 322 322 * field 323 323 */ 324 - if (request->request_msg.msg.init_req.req_id 325 - == resp->msg.init_complete.req_id) { 324 + if (request->request_msg.msg.init_req.req_id == *req_id) { 326 325 found = true; 327 326 break; 328 327 } ··· 334 329 if (found) { 335 330 if (resp->msg_len <= 336 331 sizeof(struct rndis_message) + RNDIS_EXT_LEN) { 337 - memcpy(&request->response_msg, resp, 338 - resp->msg_len); 332 + memcpy(&request->response_msg, resp, RNDIS_HEADER_SIZE + sizeof(*req_id)); 333 + memcpy((void *)&request->response_msg + RNDIS_HEADER_SIZE + sizeof(*req_id), 334 + data + RNDIS_HEADER_SIZE + sizeof(*req_id), 335 + resp->msg_len - RNDIS_HEADER_SIZE - sizeof(*req_id)); 339 336 if (request->request_msg.ndis_msg_type == 340 337 RNDIS_MSG_QUERY && request->request_msg.msg. 341 338 query_req.oid == RNDIS_OID_GEN_MEDIA_CONNECT_STATUS) ··· 366 359 netdev_err(ndev, 367 360 "no rndis request found for this response " 368 361 "(id 0x%x res type 0x%x)\n", 369 - resp->msg.init_complete.req_id, 362 + *req_id, 370 363 resp->ndis_msg_type); 371 364 } 372 365 } ··· 378 371 static inline void *rndis_get_ppi(struct net_device *ndev, 379 372 struct rndis_packet *rpkt, 380 373 u32 rpkt_len, u32 type, u8 internal, 381 - u32 ppi_size) 374 + u32 ppi_size, void *data) 382 375 { 383 376 struct rndis_per_packet_info *ppi; 384 377 int len; ··· 403 396 404 397 ppi = (struct rndis_per_packet_info *)((ulong)rpkt + 405 398 rpkt->per_pkt_info_offset); 399 + /* Copy the PPIs into nvchan->recv_buf */ 400 + memcpy(ppi, data + RNDIS_HEADER_SIZE + rpkt->per_pkt_info_offset, rpkt->per_pkt_info_len); 406 401 len = rpkt->per_pkt_info_len; 407 402 408 403 while (len > 0) { ··· 447 438 if (cnt) { 448 439 nvchan->rsc.pktlen += len; 449 440 } else { 450 - nvchan->rsc.vlan = vlan; 451 - nvchan->rsc.csum_info = csum_info; 441 + /* The data/values pointed by vlan, csum_info and hash_info are shared 442 + * across the different 'fragments' of the RSC packet; store them into 443 + * the packet itself. 444 + */ 445 + if (vlan != NULL) { 446 + memcpy(&nvchan->rsc.vlan, vlan, sizeof(*vlan)); 447 + nvchan->rsc.ppi_flags |= NVSC_RSC_VLAN; 448 + } else { 449 + nvchan->rsc.ppi_flags &= ~NVSC_RSC_VLAN; 450 + } 451 + if (csum_info != NULL) { 452 + memcpy(&nvchan->rsc.csum_info, csum_info, sizeof(*csum_info)); 453 + nvchan->rsc.ppi_flags |= NVSC_RSC_CSUM_INFO; 454 + } else { 455 + nvchan->rsc.ppi_flags &= ~NVSC_RSC_CSUM_INFO; 456 + } 452 457 nvchan->rsc.pktlen = len; 453 - nvchan->rsc.hash_info = hash_info; 458 + if (hash_info != NULL) { 459 + nvchan->rsc.csum_info = *csum_info; 460 + nvchan->rsc.ppi_flags |= NVSC_RSC_HASH_INFO; 461 + } else { 462 + nvchan->rsc.ppi_flags &= ~NVSC_RSC_HASH_INFO; 463 + } 454 464 } 455 465 456 466 nvchan->rsc.data[cnt] = data; ··· 481 453 struct netvsc_device *nvdev, 482 454 struct netvsc_channel *nvchan, 483 455 struct rndis_message *msg, 484 - u32 data_buflen) 456 + void *data, u32 data_buflen) 485 457 { 486 458 struct rndis_packet *rndis_pkt = &msg->msg.pkt; 487 459 const struct ndis_tcp_ip_checksum_info *csum_info; ··· 489 461 const struct rndis_pktinfo_id *pktinfo_id; 490 462 const u32 *hash_info; 491 463 u32 data_offset, rpkt_len; 492 - void *data; 493 464 bool rsc_more = false; 494 465 int ret; 495 466 ··· 498 471 data_buflen); 499 472 return NVSP_STAT_FAIL; 500 473 } 474 + 475 + /* Copy the RNDIS packet into nvchan->recv_buf */ 476 + memcpy(rndis_pkt, data + RNDIS_HEADER_SIZE, sizeof(*rndis_pkt)); 501 477 502 478 /* Validate rndis_pkt offset */ 503 479 if (rndis_pkt->data_offset >= data_buflen - RNDIS_HEADER_SIZE) { ··· 527 497 return NVSP_STAT_FAIL; 528 498 } 529 499 530 - vlan = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, IEEE_8021Q_INFO, 0, sizeof(*vlan)); 500 + vlan = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, IEEE_8021Q_INFO, 0, sizeof(*vlan), 501 + data); 531 502 532 503 csum_info = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, TCPIP_CHKSUM_PKTINFO, 0, 533 - sizeof(*csum_info)); 504 + sizeof(*csum_info), data); 534 505 535 506 hash_info = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, NBL_HASH_VALUE, 0, 536 - sizeof(*hash_info)); 507 + sizeof(*hash_info), data); 537 508 538 509 pktinfo_id = rndis_get_ppi(ndev, rndis_pkt, rpkt_len, RNDIS_PKTINFO_ID, 1, 539 - sizeof(*pktinfo_id)); 540 - 541 - data = (void *)msg + data_offset; 510 + sizeof(*pktinfo_id), data); 542 511 543 512 /* Identify RSC frags, drop erroneous packets */ 544 513 if (pktinfo_id && (pktinfo_id->flag & RNDIS_PKTINFO_SUBALLOC)) { ··· 566 537 * the data packet to the stack, without the rndis trailer padding 567 538 */ 568 539 rsc_add_data(nvchan, vlan, csum_info, hash_info, 569 - data, rndis_pkt->data_len); 540 + data + data_offset, rndis_pkt->data_len); 570 541 571 542 if (rsc_more) 572 543 return NVSP_STAT_SUCCESS; ··· 588 559 void *data, u32 buflen) 589 560 { 590 561 struct net_device_context *net_device_ctx = netdev_priv(ndev); 591 - struct rndis_message *rndis_msg = data; 562 + struct rndis_message *rndis_msg = nvchan->recv_buf; 563 + 564 + if (buflen < RNDIS_HEADER_SIZE) { 565 + netdev_err(ndev, "Invalid rndis_msg (buflen: %u)\n", buflen); 566 + return NVSP_STAT_FAIL; 567 + } 568 + 569 + /* Copy the RNDIS msg header into nvchan->recv_buf */ 570 + memcpy(rndis_msg, data, RNDIS_HEADER_SIZE); 592 571 593 572 /* Validate incoming rndis_message packet */ 594 - if (buflen < RNDIS_HEADER_SIZE || rndis_msg->msg_len < RNDIS_HEADER_SIZE || 573 + if (rndis_msg->msg_len < RNDIS_HEADER_SIZE || 595 574 buflen < rndis_msg->msg_len) { 596 575 netdev_err(ndev, "Invalid rndis_msg (buflen: %u, msg_len: %u)\n", 597 576 buflen, rndis_msg->msg_len); ··· 607 570 } 608 571 609 572 if (netif_msg_rx_status(net_device_ctx)) 610 - dump_rndis_message(ndev, rndis_msg); 573 + dump_rndis_message(ndev, rndis_msg, data); 611 574 612 575 switch (rndis_msg->ndis_msg_type) { 613 576 case RNDIS_MSG_PACKET: 614 577 return rndis_filter_receive_data(ndev, net_dev, nvchan, 615 - rndis_msg, buflen); 578 + rndis_msg, data, buflen); 616 579 case RNDIS_MSG_INIT_C: 617 580 case RNDIS_MSG_QUERY_C: 618 581 case RNDIS_MSG_SET_C: 619 582 /* completion msgs */ 620 - rndis_filter_receive_response(ndev, net_dev, rndis_msg); 583 + rndis_filter_receive_response(ndev, net_dev, rndis_msg, data); 621 584 break; 622 585 623 586 case RNDIS_MSG_INDICATE: 624 587 /* notification msgs */ 625 - netvsc_linkstatus_callback(ndev, rndis_msg); 588 + netvsc_linkstatus_callback(ndev, rndis_msg, data); 626 589 break; 627 590 default: 628 591 netdev_err(ndev,