Linux kernel mirror (for testing)
git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel
os
linux
1// SPDX-License-Identifier: GPL-2.0
2
3#include <linux/ip.h>
4#include <linux/skbuff.h>
5#include <net/ip6_checksum.h>
6#include <net/psp.h>
7#include <net/sock.h>
8
9#include "netdevsim.h"
10
11void nsim_psp_handle_ext(struct sk_buff *skb, struct skb_ext *psp_ext)
12{
13 if (psp_ext)
14 __skb_ext_set(skb, SKB_EXT_PSP, psp_ext);
15}
16
17enum skb_drop_reason
18nsim_do_psp(struct sk_buff *skb, struct netdevsim *ns,
19 struct netdevsim *peer_ns, struct skb_ext **psp_ext)
20{
21 enum skb_drop_reason rc = 0;
22 struct psp_assoc *pas;
23 struct net *net;
24 void **ptr;
25
26 rcu_read_lock();
27 pas = psp_skb_get_assoc_rcu(skb);
28 if (!pas) {
29 rc = SKB_NOT_DROPPED_YET;
30 goto out_unlock;
31 }
32
33 if (!skb_transport_header_was_set(skb)) {
34 rc = SKB_DROP_REASON_PSP_OUTPUT;
35 goto out_unlock;
36 }
37
38 ptr = psp_assoc_drv_data(pas);
39 if (*ptr != ns) {
40 rc = SKB_DROP_REASON_PSP_OUTPUT;
41 goto out_unlock;
42 }
43
44 net = sock_net(skb->sk);
45 if (!psp_dev_encapsulate(net, skb, pas->tx.spi, pas->version, 0)) {
46 rc = SKB_DROP_REASON_PSP_OUTPUT;
47 goto out_unlock;
48 }
49
50 /* Now pretend we just received this frame */
51 if (peer_ns->psp.dev->config.versions & (1 << pas->version)) {
52 bool strip_icv = false;
53 u8 generation;
54
55 /* We cheat a bit and put the generation in the key.
56 * In real life if generation was too old, then decryption would
57 * fail. Here, we just make it so a bad key causes a bad
58 * generation too, and psp_sk_rx_policy_check() will fail.
59 */
60 generation = pas->tx.key[0];
61
62 skb_ext_reset(skb);
63 skb->mac_len = ETH_HLEN;
64 if (psp_dev_rcv(skb, peer_ns->psp.dev->id, generation,
65 strip_icv)) {
66 rc = SKB_DROP_REASON_PSP_OUTPUT;
67 goto out_unlock;
68 }
69
70 *psp_ext = skb->extensions;
71 refcount_inc(&(*psp_ext)->refcnt);
72 skb->decrypted = 1;
73
74 u64_stats_update_begin(&ns->psp.syncp);
75 ns->psp.tx_packets++;
76 ns->psp.rx_packets++;
77 ns->psp.tx_bytes += skb->len - skb_inner_transport_offset(skb);
78 ns->psp.rx_bytes += skb->len - skb_inner_transport_offset(skb);
79 u64_stats_update_end(&ns->psp.syncp);
80 } else {
81 struct ipv6hdr *ip6h __maybe_unused;
82 struct iphdr *iph;
83 struct udphdr *uh;
84 __wsum csum;
85
86 /* Do not decapsulate. Receive the skb with the udp and psp
87 * headers still there as if this is a normal udp packet.
88 * psp_dev_encapsulate() sets udp checksum to 0, so we need to
89 * provide a valid checksum here, so the skb isn't dropped.
90 */
91 uh = udp_hdr(skb);
92 csum = skb_checksum(skb, skb_transport_offset(skb),
93 ntohs(uh->len), 0);
94
95 switch (skb->protocol) {
96 case htons(ETH_P_IP):
97 iph = ip_hdr(skb);
98 uh->check = udp_v4_check(ntohs(uh->len), iph->saddr,
99 iph->daddr, csum);
100 break;
101#if IS_ENABLED(CONFIG_IPV6)
102 case htons(ETH_P_IPV6):
103 ip6h = ipv6_hdr(skb);
104 uh->check = udp_v6_check(ntohs(uh->len), &ip6h->saddr,
105 &ip6h->daddr, csum);
106 break;
107#endif
108 }
109
110 uh->check = uh->check ?: CSUM_MANGLED_0;
111 skb->ip_summed = CHECKSUM_NONE;
112 }
113
114out_unlock:
115 rcu_read_unlock();
116 return rc;
117}
118
/* There is nothing to program, so every configuration request succeeds. */
static int nsim_psp_set_config(struct psp_dev *psd, struct psp_dev_config *conf,
			       struct netlink_ext_ack *extack)
{
	return 0;
}
125
126static int
127nsim_rx_spi_alloc(struct psp_dev *psd, u32 version,
128 struct psp_key_parsed *assoc,
129 struct netlink_ext_ack *extack)
130{
131 struct netdevsim *ns = psd->drv_priv;
132 unsigned int new;
133 int i;
134
135 new = ++ns->psp.spi & PSP_SPI_KEY_ID;
136 if (psd->generation & 1)
137 new |= PSP_SPI_KEY_PHASE;
138
139 assoc->spi = cpu_to_be32(new);
140 assoc->key[0] = psd->generation;
141 for (i = 1; i < PSP_MAX_KEY; i++)
142 assoc->key[i] = ns->psp.spi + i;
143
144 return 0;
145}
146
147static int nsim_assoc_add(struct psp_dev *psd, struct psp_assoc *pas,
148 struct netlink_ext_ack *extack)
149{
150 struct netdevsim *ns = psd->drv_priv;
151 void **ptr = psp_assoc_drv_data(pas);
152
153 /* Copy drv_priv from psd to assoc */
154 *ptr = psd->drv_priv;
155 ns->psp.assoc_cnt++;
156
157 return 0;
158}
159
/* Key rotation needs no driver-side work here; always report success. */
static int
nsim_key_rotate(struct psp_dev *psd, struct netlink_ext_ack *extack)
{
	return 0;
}
164
165static void nsim_assoc_del(struct psp_dev *psd, struct psp_assoc *pas)
166{
167 struct netdevsim *ns = psd->drv_priv;
168 void **ptr = psp_assoc_drv_data(pas);
169
170 *ptr = NULL;
171 ns->psp.assoc_cnt--;
172}
173
174static void nsim_get_stats(struct psp_dev *psd, struct psp_dev_stats *stats)
175{
176 struct netdevsim *ns = psd->drv_priv;
177 unsigned int start;
178
179 /* WARNING: do *not* blindly zero stats in real drivers!
180 * All required stats must be reported by the device!
181 */
182 memset(stats, 0, sizeof(struct psp_dev_stats));
183
184 do {
185 start = u64_stats_fetch_begin(&ns->psp.syncp);
186 stats->rx_bytes = ns->psp.rx_bytes;
187 stats->rx_packets = ns->psp.rx_packets;
188 stats->tx_bytes = ns->psp.tx_bytes;
189 stats->tx_packets = ns->psp.tx_packets;
190 } while (u64_stats_fetch_retry(&ns->psp.syncp, start));
191}
192
193static struct psp_dev_ops nsim_psp_ops = {
194 .set_config = nsim_psp_set_config,
195 .rx_spi_alloc = nsim_rx_spi_alloc,
196 .tx_key_add = nsim_assoc_add,
197 .tx_key_del = nsim_assoc_del,
198 .key_rotate = nsim_key_rotate,
199 .get_stats = nsim_get_stats,
200};
201
202static struct psp_dev_caps nsim_psp_caps = {
203 .versions = 1 << PSP_VERSION_HDR0_AES_GCM_128 |
204 1 << PSP_VERSION_HDR0_AES_GMAC_128 |
205 1 << PSP_VERSION_HDR0_AES_GCM_256 |
206 1 << PSP_VERSION_HDR0_AES_GMAC_256,
207 .assoc_drv_spc = sizeof(void *),
208};
209
210void nsim_psp_uninit(struct netdevsim *ns)
211{
212 if (!IS_ERR(ns->psp.dev))
213 psp_dev_unregister(ns->psp.dev);
214 WARN_ON(ns->psp.assoc_cnt);
215}
216
217static ssize_t
218nsim_psp_rereg_write(struct file *file, const char __user *data, size_t count,
219 loff_t *ppos)
220{
221 struct netdevsim *ns = file->private_data;
222 int err;
223
224 nsim_psp_uninit(ns);
225
226 ns->psp.dev = psp_dev_create(ns->netdev, &nsim_psp_ops,
227 &nsim_psp_caps, ns);
228 err = PTR_ERR_OR_ZERO(ns->psp.dev);
229 return err ?: count;
230}
231
232static const struct file_operations nsim_psp_rereg_fops = {
233 .open = simple_open,
234 .write = nsim_psp_rereg_write,
235 .llseek = generic_file_llseek,
236 .owner = THIS_MODULE,
237};
238
239int nsim_psp_init(struct netdevsim *ns)
240{
241 struct dentry *ddir = ns->nsim_dev_port->ddir;
242 int err;
243
244 ns->psp.dev = psp_dev_create(ns->netdev, &nsim_psp_ops,
245 &nsim_psp_caps, ns);
246 err = PTR_ERR_OR_ZERO(ns->psp.dev);
247 if (err)
248 return err;
249
250 debugfs_create_file("psp_rereg", 0200, ddir, ns, &nsim_psp_rereg_fops);
251 return 0;
252}