Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

wifi: mt76: add a wrapper for wcid access with validation

Several places use rcu_dereference to get a wcid entry without validating
if the index exceeds the array boundary. Fix this by using a helper function,
which handles validation.

Link: https://patch.msgid.link/20250707154702.1726-1-nbd@nbd.name
Signed-off-by: Felix Fietkau <nbd@nbd.name>

+41 -68
+10
drivers/net/wireless/mediatek/mt76/mt76.h
··· 1224 1224 #define mt76_dereference(p, dev) \ 1225 1225 rcu_dereference_protected(p, lockdep_is_held(&(dev)->mutex)) 1226 1226 1227 + static inline struct mt76_wcid * 1228 + __mt76_wcid_ptr(struct mt76_dev *dev, u16 idx) 1229 + { 1230 + if (idx >= ARRAY_SIZE(dev->wcid)) 1231 + return NULL; 1232 + return rcu_dereference(dev->wcid[idx]); 1233 + } 1234 + 1235 + #define mt76_wcid_ptr(dev, idx) __mt76_wcid_ptr(&(dev)->mt76, idx) 1236 + 1227 1237 struct mt76_dev *mt76_alloc_device(struct device *pdev, unsigned int size, 1228 1238 const struct ieee80211_ops *ops, 1229 1239 const struct mt76_driver_ops *drv_ops);
+1 -1
drivers/net/wireless/mediatek/mt76/mt7603/dma.c
··· 44 44 if (idx >= MT7603_WTBL_STA - 1) 45 45 goto free; 46 46 47 - wcid = rcu_dereference(dev->mt76.wcid[idx]); 47 + wcid = mt76_wcid_ptr(dev, idx); 48 48 if (!wcid) 49 49 goto free; 50 50
+2 -8
drivers/net/wireless/mediatek/mt76/mt7603/mac.c
··· 487 487 struct mt7603_sta *sta; 488 488 struct mt76_wcid *wcid; 489 489 490 - if (idx >= MT7603_WTBL_SIZE) 491 - return NULL; 492 - 493 - wcid = rcu_dereference(dev->mt76.wcid[idx]); 490 + wcid = mt76_wcid_ptr(dev, idx); 494 491 if (unicast || !wcid) 495 492 return wcid; 496 493 ··· 1263 1266 if (pid == MT_PACKET_ID_NO_ACK) 1264 1267 return; 1265 1268 1266 - if (wcidx >= MT7603_WTBL_SIZE) 1267 - return; 1268 - 1269 1269 rcu_read_lock(); 1270 1270 1271 - wcid = rcu_dereference(dev->mt76.wcid[wcidx]); 1271 + wcid = mt76_wcid_ptr(dev, wcidx); 1272 1272 if (!wcid) 1273 1273 goto out; 1274 1274
+2 -5
drivers/net/wireless/mediatek/mt76/mt7615/mac.c
··· 90 90 struct mt7615_sta *sta; 91 91 struct mt76_wcid *wcid; 92 92 93 - if (idx >= MT7615_WTBL_SIZE) 94 - return NULL; 95 - 96 - wcid = rcu_dereference(dev->mt76.wcid[idx]); 93 + wcid = mt76_wcid_ptr(dev, idx); 97 94 if (unicast || !wcid) 98 95 return wcid; 99 96 ··· 1501 1504 1502 1505 rcu_read_lock(); 1503 1506 1504 - wcid = rcu_dereference(dev->mt76.wcid[wcidx]); 1507 + wcid = mt76_wcid_ptr(dev, wcidx); 1505 1508 if (!wcid) 1506 1509 goto out; 1507 1510
+1 -1
drivers/net/wireless/mediatek/mt76/mt76_connac_mac.c
··· 1172 1172 wcid_idx = wcid->idx; 1173 1173 } else { 1174 1174 wcid_idx = le32_get_bits(txwi[1], MT_TXD1_WLAN_IDX); 1175 - wcid = rcu_dereference(dev->wcid[wcid_idx]); 1175 + wcid = __mt76_wcid_ptr(dev, wcid_idx); 1176 1176 1177 1177 if (wcid && wcid->sta) { 1178 1178 sta = container_of((void *)wcid, struct ieee80211_sta,
+1 -4
drivers/net/wireless/mediatek/mt76/mt76x02.h
··· 262 262 { 263 263 struct mt76_wcid *wcid; 264 264 265 - if (idx >= MT76x02_N_WCIDS) 266 - return NULL; 267 - 268 - wcid = rcu_dereference(dev->wcid[idx]); 265 + wcid = __mt76_wcid_ptr(dev, idx); 269 266 if (!wcid) 270 267 return NULL; 271 268
+1 -3
drivers/net/wireless/mediatek/mt76/mt76x02_mac.c
··· 564 564 565 565 rcu_read_lock(); 566 566 567 - if (stat->wcid < MT76x02_N_WCIDS) 568 - wcid = rcu_dereference(dev->mt76.wcid[stat->wcid]); 569 - 567 + wcid = mt76_wcid_ptr(dev, stat->wcid); 570 568 if (wcid && wcid->sta) { 571 569 void *priv; 572 570
+3 -9
drivers/net/wireless/mediatek/mt76/mt7915/mac.c
··· 56 56 struct mt7915_sta *sta; 57 57 struct mt76_wcid *wcid; 58 58 59 - if (idx >= ARRAY_SIZE(dev->mt76.wcid)) 60 - return NULL; 61 - 62 - wcid = rcu_dereference(dev->mt76.wcid[idx]); 59 + wcid = mt76_wcid_ptr(dev, idx); 63 60 if (unicast || !wcid) 64 61 return wcid; 65 62 ··· 914 917 u16 idx; 915 918 916 919 idx = FIELD_GET(MT_TX_FREE_WLAN_ID, info); 917 - wcid = rcu_dereference(dev->mt76.wcid[idx]); 920 + wcid = mt76_wcid_ptr(dev, idx); 918 921 sta = wcid_to_sta(wcid); 919 922 if (!sta) 920 923 continue; ··· 1010 1013 if (pid < MT_PACKET_ID_WED) 1011 1014 return; 1012 1015 1013 - if (wcidx >= mt7915_wtbl_size(dev)) 1014 - return; 1015 - 1016 1016 rcu_read_lock(); 1017 1017 1018 - wcid = rcu_dereference(dev->mt76.wcid[wcidx]); 1018 + wcid = mt76_wcid_ptr(dev, wcidx); 1019 1019 if (!wcid) 1020 1020 goto out; 1021 1021
+1 -1
drivers/net/wireless/mediatek/mt76/mt7915/mcu.c
··· 3986 3986 3987 3987 rcu_read_lock(); 3988 3988 3989 - wcid = rcu_dereference(dev->mt76.wcid[wlan_idx]); 3989 + wcid = mt76_wcid_ptr(dev, wlan_idx); 3990 3990 if (wcid) 3991 3991 wcid->stats.tx_packets += le32_to_cpu(res->tx_packets); 3992 3992 else
+1 -4
drivers/net/wireless/mediatek/mt76/mt7915/mmio.c
··· 587 587 588 588 dev = container_of(wed, struct mt7915_dev, mt76.mmio.wed); 589 589 590 - if (idx >= mt7915_wtbl_size(dev)) 591 - return; 592 - 593 590 rcu_read_lock(); 594 591 595 - wcid = rcu_dereference(dev->mt76.wcid[idx]); 592 + wcid = mt76_wcid_ptr(dev, idx); 596 593 if (wcid) { 597 594 wcid->stats.rx_bytes += le32_to_cpu(stats->rx_byte_cnt); 598 595 wcid->stats.rx_packets += le32_to_cpu(stats->rx_pkt_cnt);
+3 -3
drivers/net/wireless/mediatek/mt76/mt7921/mac.c
··· 465 465 466 466 rcu_read_lock(); 467 467 468 - wcid = rcu_dereference(dev->mt76.wcid[wcidx]); 468 + wcid = mt76_wcid_ptr(dev, wcidx); 469 469 if (!wcid) 470 470 goto out; 471 471 ··· 516 516 517 517 count++; 518 518 idx = FIELD_GET(MT_TX_FREE_WLAN_ID, info); 519 - wcid = rcu_dereference(dev->mt76.wcid[idx]); 519 + wcid = mt76_wcid_ptr(dev, idx); 520 520 sta = wcid_to_sta(wcid); 521 521 if (!sta) 522 522 continue; ··· 816 816 u16 idx; 817 817 818 818 idx = le32_get_bits(txwi[1], MT_TXD1_WLAN_IDX); 819 - wcid = rcu_dereference(mdev->wcid[idx]); 819 + wcid = __mt76_wcid_ptr(mdev, idx); 820 820 sta = wcid_to_sta(wcid); 821 821 822 822 if (sta && likely(e->skb->protocol != cpu_to_be16(ETH_P_PAE)))
+3 -3
drivers/net/wireless/mediatek/mt76/mt7925/mac.c
··· 1040 1040 1041 1041 rcu_read_lock(); 1042 1042 1043 - wcid = rcu_dereference(dev->mt76.wcid[wcidx]); 1043 + wcid = mt76_wcid_ptr(dev, wcidx); 1044 1044 if (!wcid) 1045 1045 goto out; 1046 1046 ··· 1122 1122 u16 idx; 1123 1123 1124 1124 idx = FIELD_GET(MT_TXFREE_INFO_WLAN_ID, info); 1125 - wcid = rcu_dereference(dev->mt76.wcid[idx]); 1125 + wcid = mt76_wcid_ptr(dev, idx); 1126 1126 sta = wcid_to_sta(wcid); 1127 1127 if (!sta) 1128 1128 continue; ··· 1445 1445 u16 idx; 1446 1446 1447 1447 idx = le32_get_bits(txwi[1], MT_TXD1_WLAN_IDX); 1448 - wcid = rcu_dereference(mdev->wcid[idx]); 1448 + wcid = __mt76_wcid_ptr(mdev, idx); 1449 1449 sta = wcid_to_sta(wcid); 1450 1450 1451 1451 if (sta && likely(e->skb->protocol != cpu_to_be16(ETH_P_PAE)))
+1 -4
drivers/net/wireless/mediatek/mt76/mt792x_mac.c
··· 142 142 struct mt792x_sta *sta; 143 143 struct mt76_wcid *wcid; 144 144 145 - if (idx >= ARRAY_SIZE(dev->mt76.wcid)) 146 - return NULL; 147 - 148 - wcid = rcu_dereference(dev->mt76.wcid[idx]); 145 + wcid = mt76_wcid_ptr(dev, idx); 149 146 if (unicast || !wcid) 150 147 return wcid; 151 148
+3 -9
drivers/net/wireless/mediatek/mt76/mt7996/mac.c
··· 61 61 struct mt76_wcid *wcid; 62 62 int i; 63 63 64 - if (idx >= ARRAY_SIZE(dev->mt76.wcid)) 65 - return NULL; 66 - 67 - wcid = rcu_dereference(dev->mt76.wcid[idx]); 64 + wcid = mt76_wcid_ptr(dev, idx); 68 65 if (!wcid) 69 66 return NULL; 70 67 ··· 1246 1249 u16 idx; 1247 1250 1248 1251 idx = FIELD_GET(MT_TXFREE_INFO_WLAN_ID, info); 1249 - wcid = rcu_dereference(dev->mt76.wcid[idx]); 1252 + wcid = mt76_wcid_ptr(dev, idx); 1250 1253 sta = wcid_to_sta(wcid); 1251 1254 if (!sta) 1252 1255 goto next; ··· 1468 1471 if (pid < MT_PACKET_ID_NO_SKB) 1469 1472 return; 1470 1473 1471 - if (wcidx >= mt7996_wtbl_size(dev)) 1472 - return; 1473 - 1474 1474 rcu_read_lock(); 1475 1475 1476 - wcid = rcu_dereference(dev->mt76.wcid[wcidx]); 1476 + wcid = mt76_wcid_ptr(dev, wcidx); 1477 1477 if (!wcid) 1478 1478 goto out; 1479 1479
+4 -7
drivers/net/wireless/mediatek/mt76/mt7996/mcu.c
··· 555 555 switch (le16_to_cpu(res->tag)) { 556 556 case UNI_ALL_STA_TXRX_RATE: 557 557 wlan_idx = le16_to_cpu(res->rate[i].wlan_idx); 558 - wcid = rcu_dereference(dev->mt76.wcid[wlan_idx]); 558 + wcid = mt76_wcid_ptr(dev, wlan_idx); 559 559 560 560 if (!wcid) 561 561 break; ··· 565 565 break; 566 566 case UNI_ALL_STA_TXRX_ADM_STAT: 567 567 wlan_idx = le16_to_cpu(res->adm_stat[i].wlan_idx); 568 - wcid = rcu_dereference(dev->mt76.wcid[wlan_idx]); 568 + wcid = mt76_wcid_ptr(dev, wlan_idx); 569 569 570 570 if (!wcid) 571 571 break; ··· 579 579 break; 580 580 case UNI_ALL_STA_TXRX_MSDU_COUNT: 581 581 wlan_idx = le16_to_cpu(res->msdu_cnt[i].wlan_idx); 582 - wcid = rcu_dereference(dev->mt76.wcid[wlan_idx]); 582 + wcid = mt76_wcid_ptr(dev, wlan_idx); 583 583 584 584 if (!wcid) 585 585 break; ··· 676 676 677 677 e = (void *)skb->data; 678 678 idx = le16_to_cpu(e->wlan_id); 679 - if (idx >= ARRAY_SIZE(dev->mt76.wcid)) 680 - break; 681 - 682 - wcid = rcu_dereference(dev->mt76.wcid[idx]); 679 + wcid = mt76_wcid_ptr(dev, idx); 683 680 if (!wcid || !wcid->sta) 684 681 break; 685 682
+3 -5
drivers/net/wireless/mediatek/mt76/tx.c
··· 64 64 struct mt76_tx_cb *cb = mt76_tx_skb_cb(skb); 65 65 struct mt76_wcid *wcid; 66 66 67 - wcid = rcu_dereference(dev->wcid[cb->wcid]); 67 + wcid = __mt76_wcid_ptr(dev, cb->wcid); 68 68 if (wcid) { 69 69 status.sta = wcid_to_sta(wcid); 70 70 if (status.sta && (wcid->rate.flags || wcid->rate.legacy)) { ··· 251 251 252 252 rcu_read_lock(); 253 253 254 - if (wcid_idx < ARRAY_SIZE(dev->wcid)) 255 - wcid = rcu_dereference(dev->wcid[wcid_idx]); 256 - 254 + wcid = __mt76_wcid_ptr(dev, wcid_idx); 257 255 mt76_tx_check_non_aql(dev, wcid, skb); 258 256 259 257 #ifdef CONFIG_NL80211_TESTMODE ··· 536 538 break; 537 539 538 540 mtxq = (struct mt76_txq *)txq->drv_priv; 539 - wcid = rcu_dereference(dev->wcid[mtxq->wcid]); 541 + wcid = __mt76_wcid_ptr(dev, mtxq->wcid); 540 542 if (!wcid || test_bit(MT_WCID_FLAG_PS, &wcid->flags)) 541 543 continue; 542 544
+1 -1
drivers/net/wireless/mediatek/mt76/util.c
··· 83 83 if (!(mask & 1)) 84 84 continue; 85 85 86 - wcid = rcu_dereference(dev->wcid[j]); 86 + wcid = __mt76_wcid_ptr(dev, j); 87 87 if (!wcid || wcid->phy_idx != phy_idx) 88 88 continue; 89 89