Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

net: psp: add socket security association code

Add the ability to install PSP Rx and Tx crypto keys on TCP
connections. Netlink ops are provided for both operations.
Rx side combines allocating a new Rx key and installing it
on the socket. Theoretically these are separate actions,
but in practice they will always be used one after the
other. We can add distinct "alloc" and "install" ops later.

Reviewed-by: Willem de Bruijn <willemb@google.com>
Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Co-developed-by: Daniel Zahka <daniel.zahka@gmail.com>
Signed-off-by: Daniel Zahka <daniel.zahka@gmail.com>
Reviewed-by: Eric Dumazet <edumazet@google.com>
Link: https://patch.msgid.link/20250917000954.859376-9-daniel.zahka@gmail.com
Signed-off-by: Paolo Abeni <pabeni@redhat.com>

authored by

Jakub Kicinski and committed by
Paolo Abeni
6b46ca26 0917bb13

+854 -11
+70
Documentation/netlink/specs/psp.yaml
··· 38 38 type: u32 39 39 enum: version 40 40 enum-as-flags: true 41 + - 42 + name: assoc 43 + attributes: 44 + - 45 + name: dev-id 46 + doc: PSP device ID. 47 + type: u32 48 + checks: 49 + min: 1 50 + - 51 + name: version 52 + doc: | 53 + PSP versions (AEAD and protocol version) used by this association, 54 + dictates the size of the key. 55 + type: u32 56 + enum: version 57 + - 58 + name: rx-key 59 + type: nest 60 + nested-attributes: keys 61 + - 62 + name: tx-key 63 + type: nest 64 + nested-attributes: keys 65 + - 66 + name: sock-fd 67 + doc: Sockets which should be bound to the association immediately. 68 + type: u32 69 + - 70 + name: keys 71 + attributes: 72 + - 73 + name: key 74 + type: binary 75 + - 76 + name: spi 77 + doc: Security Parameters Index (SPI) of the association. 78 + type: u32 41 79 42 80 operations: 43 81 list: ··· 144 106 doc: Notification about device key getting rotated. 145 107 notify: key-rotate 146 108 mcgrp: use 109 + 110 + - 111 + name: rx-assoc 112 + doc: Allocate a new Rx key + SPI pair, associate it with a socket. 113 + attribute-set: assoc 114 + do: 115 + request: 116 + attributes: 117 + - dev-id 118 + - version 119 + - sock-fd 120 + reply: 121 + attributes: 122 + - dev-id 123 + - rx-key 124 + pre: psp-assoc-device-get-locked 125 + post: psp-device-unlock 126 + - 127 + name: tx-assoc 128 + doc: Add a PSP Tx association. 129 + attribute-set: assoc 130 + do: 131 + request: 132 + attributes: 133 + - dev-id 134 + - version 135 + - tx-key 136 + - sock-fd 137 + reply: 138 + attributes: [] 139 + pre: psp-assoc-device-get-locked 140 + post: psp-device-unlock 147 141 148 142 mcast-groups: 149 143 list:
+105 -9
include/net/psp/functions.h
··· 4 4 #define __NET_PSP_HELPERS_H 5 5 6 6 #include <linux/skbuff.h> 7 + #include <linux/rcupdate.h> 7 8 #include <net/sock.h> 9 + #include <net/tcp.h> 8 10 #include <net/psp/types.h> 9 11 10 12 struct inet_timewait_sock; ··· 18 16 void psp_dev_unregister(struct psp_dev *psd); 19 17 20 18 /* Kernel-facing API */ 19 + void psp_assoc_put(struct psp_assoc *pas); 20 + 21 + static inline void *psp_assoc_drv_data(struct psp_assoc *pas) 22 + { 23 + return pas->drv_data; 24 + } 25 + 21 26 #if IS_ENABLED(CONFIG_INET_PSP) 22 - static inline void psp_sk_assoc_free(struct sock *sk) { } 23 - static inline void 24 - psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { } 25 - static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { } 26 - static inline void 27 - psp_reply_set_decrypted(struct sk_buff *skb) { } 27 + unsigned int psp_key_size(u32 version); 28 + void psp_sk_assoc_free(struct sock *sk); 29 + void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk); 30 + void psp_twsk_assoc_free(struct inet_timewait_sock *tw); 31 + void psp_reply_set_decrypted(struct sk_buff *skb); 32 + 33 + static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 34 + { 35 + return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk)); 36 + } 28 37 29 38 static inline void 30 39 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) 31 40 { 41 + struct psp_assoc *pas; 42 + 43 + pas = psp_sk_assoc(sk); 44 + if (pas && pas->tx.spi) 45 + skb->decrypted = 1; 32 46 } 33 47 34 48 static inline unsigned long 35 49 __psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 36 50 unsigned long diffs) 37 51 { 52 + struct psp_skb_ext *a, *b; 53 + 54 + a = skb_ext_find(one, SKB_EXT_PSP); 55 + b = skb_ext_find(two, SKB_EXT_PSP); 56 + 57 + diffs |= (!!a) ^ (!!b); 58 + if (!diffs && unlikely(a)) 59 + diffs |= memcmp(a, b, sizeof(*a)); 38 60 return diffs; 61 + } 62 + 63 + static inline bool 64 + psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas) 65 + { 66 + bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN); 67 + u32 end_seq = TCP_SKB_CB(skb)->end_seq; 68 + u32 seq = TCP_SKB_CB(skb)->seq; 69 + bool pure_fin; 70 + 71 + pure_fin = fin && end_seq - seq == 1; 72 + 73 + return seq == end_seq || (pure_fin && seq == pas->upgrade_seq); 74 + } 75 + 76 + static inline bool 77 + psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas) 78 + { 79 + return pse && pas->rx.spi == pse->spi && 80 + pas->generation == pse->generation && 81 + pas->version == pse->version && 82 + pas->dev_id == pse->dev_id; 83 + } 84 + 85 + static inline enum skb_drop_reason 86 + __psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas) 87 + { 88 + struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP); 89 + 90 + if (!pas) 91 + return pse ? SKB_DROP_REASON_PSP_INPUT : 0; 92 + 93 + if (likely(psp_pse_matches_pas(pse, pas))) { 94 + if (unlikely(!pas->peer_tx)) 95 + pas->peer_tx = 1; 96 + 97 + return 0; 98 + } 99 + 100 + if (!pse) { 101 + if (!pas->tx.spi || 102 + (!pas->peer_tx && psp_is_allowed_nondata(skb, pas))) 103 + return 0; 104 + } 105 + 106 + return SKB_DROP_REASON_PSP_INPUT; 39 107 } 40 108 41 109 static inline enum skb_drop_reason 42 110 psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 43 111 { 44 - return 0; 112 + return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk)); 45 113 } 46 114 47 115 static inline enum skb_drop_reason 48 116 psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 49 117 { 50 - return 0; 118 + return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc)); 119 + } 120 + 121 + static inline struct psp_assoc *psp_sk_get_assoc_rcu(struct sock *sk) 122 + { 123 + struct inet_timewait_sock *tw; 124 + struct psp_assoc *pas; 125 + int state; 126 + 127 + state = 1 << READ_ONCE(sk->sk_state); 128 + if (!sk_is_inet(sk) || state & TCPF_NEW_SYN_RECV) 129 + return NULL; 130 + 131 + tw = inet_twsk(sk); 132 + pas = state & TCPF_TIME_WAIT ? rcu_dereference(tw->psp_assoc) : 133 + rcu_dereference(sk->psp_assoc); 134 + return pas; 51 135 } 52 136 53 137 static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 54 138 { 55 - return NULL; 139 + if (!skb->decrypted || !skb->sk) 140 + return NULL; 141 + 142 + return psp_sk_get_assoc_rcu(skb->sk); 56 143 } 57 144 #else 58 145 static inline void psp_sk_assoc_free(struct sock *sk) { } ··· 150 59 static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { } 151 60 static inline void 152 61 psp_reply_set_decrypted(struct sk_buff *skb) { } 62 + 63 + static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 64 + { 65 + return NULL; 66 + } 153 67 154 68 static inline void 155 69 psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { }
+57
include/net/psp/types.h
··· 51 51 * @refcnt: reference count for the instance 52 52 * @id: instance id 53 53 * @config: current device configuration 54 + * @active_assocs: list of registered associations 54 55 * 55 56 * @rcu: RCU head for freeing the structure 56 57 */ ··· 69 68 70 69 struct psp_dev_config config; 71 70 71 + struct list_head active_assocs; 72 + 72 73 struct rcu_head rcu; 73 74 }; 74 75 ··· 83 80 * Set this field to 0 to indicate PSP is not supported at all. 84 81 */ 85 82 u32 versions; 83 + 84 + /** 85 + * @assoc_drv_spc: size of driver-specific state in Tx assoc 86 + * Determines the size of struct psp_assoc::drv_spc 87 + */ 88 + u32 assoc_drv_spc; 86 89 }; 87 90 88 91 #define PSP_MAX_KEY 32 ··· 98 89 u16 dev_id; 99 90 u8 generation; 100 91 u8 version; 92 + }; 93 + 94 + struct psp_key_parsed { 95 + __be32 spi; 96 + u8 key[PSP_MAX_KEY]; 97 + }; 98 + 99 + struct psp_assoc { 100 + struct psp_dev *psd; 101 + 102 + u16 dev_id; 103 + u8 generation; 104 + u8 version; 105 + u8 peer_tx; 106 + 107 + u32 upgrade_seq; 108 + 109 + struct psp_key_parsed tx; 110 + struct psp_key_parsed rx; 111 + 112 + refcount_t refcnt; 113 + struct rcu_head rcu; 114 + struct work_struct work; 115 + struct list_head assocs_list; 116 + 117 + u8 drv_data[] __aligned(8); 101 118 }; 102 119 103 120 /** ··· 142 107 * @key_rotate: rotate the device key 143 108 */ 144 109 int (*key_rotate)(struct psp_dev *psd, struct netlink_ext_ack *extack); 110 + 111 + /** 112 + * @rx_spi_alloc: allocate an Rx SPI+key pair 113 + * Allocate an Rx SPI and resulting derived key. 114 + * This key should remain valid until key rotation. 115 + */ 116 + int (*rx_spi_alloc)(struct psp_dev *psd, u32 version, 117 + struct psp_key_parsed *assoc, 118 + struct netlink_ext_ack *extack); 119 + 120 + /** 121 + * @tx_key_add: add a Tx key to the device 122 + * Install an association in the device. Core will allocate space 123 + * for the driver to use at drv_data. 124 + */ 125 + int (*tx_key_add)(struct psp_dev *psd, struct psp_assoc *pas, 126 + struct netlink_ext_ack *extack); 127 + /** 128 + * @tx_key_del: remove a Tx key from the device 129 + * Remove an association from the device. 130 + */ 131 + void (*tx_key_del)(struct psp_dev *psd, struct psp_assoc *pas); 145 132 }; 146 133 147 134 #endif /* __NET_PSP_H */
+21
include/uapi/linux/psp.h
··· 27 27 }; 28 28 29 29 enum { 30 + PSP_A_ASSOC_DEV_ID = 1, 31 + PSP_A_ASSOC_VERSION, 32 + PSP_A_ASSOC_RX_KEY, 33 + PSP_A_ASSOC_TX_KEY, 34 + PSP_A_ASSOC_SOCK_FD, 35 + 36 + __PSP_A_ASSOC_MAX, 37 + PSP_A_ASSOC_MAX = (__PSP_A_ASSOC_MAX - 1) 38 + }; 39 + 40 + enum { 41 + PSP_A_KEYS_KEY = 1, 42 + PSP_A_KEYS_SPI, 43 + 44 + __PSP_A_KEYS_MAX, 45 + PSP_A_KEYS_MAX = (__PSP_A_KEYS_MAX - 1) 46 + }; 47 + 48 + enum { 30 49 PSP_CMD_DEV_GET = 1, 31 50 PSP_CMD_DEV_ADD_NTF, 32 51 PSP_CMD_DEV_DEL_NTF, ··· 53 34 PSP_CMD_DEV_CHANGE_NTF, 54 35 PSP_CMD_KEY_ROTATE, 55 36 PSP_CMD_KEY_ROTATE_NTF, 37 + PSP_CMD_RX_ASSOC, 38 + PSP_CMD_TX_ASSOC, 56 39 57 40 __PSP_CMD_MAX, 58 41 PSP_CMD_MAX = (__PSP_CMD_MAX - 1)
+1
net/psp/Kconfig
··· 6 6 bool "PSP Security Protocol support" 7 7 depends on INET 8 8 select SKB_DECRYPTED 9 + select SOCK_VALIDATE_XMIT 9 10 help 10 11 Enable kernel support for the PSP protocol. 11 12 For more information see:
+1 -1
net/psp/Makefile
··· 2 2 3 3 obj-$(CONFIG_INET_PSP) += psp.o 4 4 5 - psp-y := psp_main.o psp_nl.o psp-nl-gen.o 5 + psp-y := psp_main.o psp_nl.o psp_sock.o psp-nl-gen.o
+39
net/psp/psp-nl-gen.c
··· 10 10 11 11 #include <uapi/linux/psp.h> 12 12 13 + /* Common nested types */ 14 + const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1] = { 15 + [PSP_A_KEYS_KEY] = { .type = NLA_BINARY, }, 16 + [PSP_A_KEYS_SPI] = { .type = NLA_U32, }, 17 + }; 18 + 13 19 /* PSP_CMD_DEV_GET - do */ 14 20 static const struct nla_policy psp_dev_get_nl_policy[PSP_A_DEV_ID + 1] = { 15 21 [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), ··· 30 24 /* PSP_CMD_KEY_ROTATE - do */ 31 25 static const struct nla_policy psp_key_rotate_nl_policy[PSP_A_DEV_ID + 1] = { 32 26 [PSP_A_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), 27 + }; 28 + 29 + /* PSP_CMD_RX_ASSOC - do */ 30 + static const struct nla_policy psp_rx_assoc_nl_policy[PSP_A_ASSOC_SOCK_FD + 1] = { 31 + [PSP_A_ASSOC_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), 32 + [PSP_A_ASSOC_VERSION] = NLA_POLICY_MAX(NLA_U32, 3), 33 + [PSP_A_ASSOC_SOCK_FD] = { .type = NLA_U32, }, 34 + }; 35 + 36 + /* PSP_CMD_TX_ASSOC - do */ 37 + static const struct nla_policy psp_tx_assoc_nl_policy[PSP_A_ASSOC_SOCK_FD + 1] = { 38 + [PSP_A_ASSOC_DEV_ID] = NLA_POLICY_MIN(NLA_U32, 1), 39 + [PSP_A_ASSOC_VERSION] = NLA_POLICY_MAX(NLA_U32, 3), 40 + [PSP_A_ASSOC_TX_KEY] = NLA_POLICY_NESTED(psp_keys_nl_policy), 41 + [PSP_A_ASSOC_SOCK_FD] = { .type = NLA_U32, }, 33 42 }; 34 43 35 44 /* Ops table for psp */ ··· 79 58 .post_doit = psp_device_unlock, 80 59 .policy = psp_key_rotate_nl_policy, 81 60 .maxattr = PSP_A_DEV_ID, 61 + .flags = GENL_CMD_CAP_DO, 62 + }, 63 + { 64 + .cmd = PSP_CMD_RX_ASSOC, 65 + .pre_doit = psp_assoc_device_get_locked, 66 + .doit = psp_nl_rx_assoc_doit, 67 + .post_doit = psp_device_unlock, 68 + .policy = psp_rx_assoc_nl_policy, 69 + .maxattr = PSP_A_ASSOC_SOCK_FD, 70 + .flags = GENL_CMD_CAP_DO, 71 + }, 72 + { 73 + .cmd = PSP_CMD_TX_ASSOC, 74 + .pre_doit = psp_assoc_device_get_locked, 75 + .doit = psp_nl_tx_assoc_doit, 76 + .post_doit = psp_device_unlock, 77 + .policy = psp_tx_assoc_nl_policy, 78 + .maxattr = PSP_A_ASSOC_SOCK_FD, 82 79 .flags = GENL_CMD_CAP_DO, 83 80 }, 84 81 };
+7
net/psp/psp-nl-gen.h
··· 11 11 12 12 #include <uapi/linux/psp.h> 13 13 14 + /* Common nested types */ 15 + extern const struct nla_policy psp_keys_nl_policy[PSP_A_KEYS_SPI + 1]; 16 + 14 17 int psp_device_get_locked(const struct genl_split_ops *ops, 15 18 struct sk_buff *skb, struct genl_info *info); 19 + int psp_assoc_device_get_locked(const struct genl_split_ops *ops, 20 + struct sk_buff *skb, struct genl_info *info); 16 21 void 17 22 psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb, 18 23 struct genl_info *info); ··· 26 21 int psp_nl_dev_get_dumpit(struct sk_buff *skb, struct netlink_callback *cb); 27 22 int psp_nl_dev_set_doit(struct sk_buff *skb, struct genl_info *info); 28 23 int psp_nl_key_rotate_doit(struct sk_buff *skb, struct genl_info *info); 24 + int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info); 25 + int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info); 29 26 30 27 enum { 31 28 PSP_NLGRP_MGMT,
+22
net/psp/psp.h
··· 4 4 #define __PSP_PSP_H 5 5 6 6 #include <linux/list.h> 7 + #include <linux/lockdep.h> 7 8 #include <linux/mutex.h> 8 9 #include <net/netns/generic.h> 9 10 #include <net/psp.h> ··· 18 17 19 18 void psp_nl_notify_dev(struct psp_dev *psd, u32 cmd); 20 19 20 + struct psp_assoc *psp_assoc_create(struct psp_dev *psd); 21 + struct psp_dev *psp_dev_get_for_sock(struct sock *sk); 22 + void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas); 23 + int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas, 24 + struct psp_key_parsed *key, 25 + struct netlink_ext_ack *extack); 26 + int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd, 27 + u32 version, struct psp_key_parsed *key, 28 + struct netlink_ext_ack *extack); 29 + 21 30 static inline void psp_dev_get(struct psp_dev *psd) 22 31 { 23 32 refcount_inc(&psd->refcnt); 33 + } 34 + 35 + static inline bool psp_dev_tryget(struct psp_dev *psd) 36 + { 37 + return refcount_inc_not_zero(&psd->refcnt); 24 38 } 25 39 26 40 static inline void psp_dev_put(struct psp_dev *psd) 27 41 { 28 42 if (refcount_dec_and_test(&psd->refcnt)) 29 43 psp_dev_destroy(psd); 44 + } 45 + 46 + static inline bool psp_dev_is_registered(struct psp_dev *psd) 47 + { 48 + lockdep_assert_held(&psd->lock); 49 + return !!psd->ops; 30 50 } 31 51 32 52 #endif /* __PSP_PSP_H */
+25 -1
net/psp/psp_main.c
··· 55 55 56 56 if (WARN_ON(!psd_caps->versions || 57 57 !psd_ops->set_config || 58 - !psd_ops->key_rotate)) 58 + !psd_ops->key_rotate || 59 + !psd_ops->rx_spi_alloc || 60 + !psd_ops->tx_key_add || 61 + !psd_ops->tx_key_del)) 59 62 return ERR_PTR(-EINVAL); 60 63 61 64 psd = kzalloc(sizeof(*psd), GFP_KERNEL); ··· 71 68 psd->drv_priv = priv_ptr; 72 69 73 70 mutex_init(&psd->lock); 71 + INIT_LIST_HEAD(&psd->active_assocs); 74 72 refcount_set(&psd->refcnt, 1); 75 73 76 74 mutex_lock(&psp_devs_lock); ··· 111 107 */ 112 108 void psp_dev_unregister(struct psp_dev *psd) 113 109 { 110 + struct psp_assoc *pas, *next; 111 + 114 112 mutex_lock(&psp_devs_lock); 115 113 mutex_lock(&psd->lock); 116 114 ··· 125 119 xa_store(&psp_devs, psd->id, NULL, GFP_KERNEL); 126 120 mutex_unlock(&psp_devs_lock); 127 121 122 + list_for_each_entry_safe(pas, next, &psd->active_assocs, assocs_list) 123 + psp_dev_tx_key_del(psd, pas); 124 + 128 125 rcu_assign_pointer(psd->main_netdev->psp_dev, NULL); 129 126 130 127 psd->ops = NULL; ··· 138 129 psp_dev_put(psd); 139 130 } 140 131 EXPORT_SYMBOL(psp_dev_unregister); 132 + 133 + unsigned int psp_key_size(u32 version) 134 + { 135 + switch (version) { 136 + case PSP_VERSION_HDR0_AES_GCM_128: 137 + case PSP_VERSION_HDR0_AES_GMAC_128: 138 + return 16; 139 + case PSP_VERSION_HDR0_AES_GCM_256: 140 + case PSP_VERSION_HDR0_AES_GMAC_256: 141 + return 32; 142 + default: 143 + return 0; 144 + } 145 + } 146 + EXPORT_SYMBOL(psp_key_size); 141 147 142 148 static int __init psp_init(void) 143 149 {
+232
net/psp/psp_nl.c
··· 79 79 psp_device_unlock(const struct genl_split_ops *ops, struct sk_buff *skb, 80 80 struct genl_info *info) 81 81 { 82 + struct socket *socket = info->user_ptr[1]; 82 83 struct psp_dev *psd = info->user_ptr[0]; 83 84 84 85 mutex_unlock(&psd->lock); 86 + if (socket) 87 + sockfd_put(socket); 85 88 } 86 89 87 90 static int ··· 261 258 err_free_ntf: 262 259 nlmsg_free(ntf); 263 260 err_free_rsp: 261 + nlmsg_free(rsp); 262 + return err; 263 + } 264 + 265 + /* Key etc. */ 266 + 267 + int psp_assoc_device_get_locked(const struct genl_split_ops *ops, 268 + struct sk_buff *skb, struct genl_info *info) 269 + { 270 + struct socket *socket; 271 + struct psp_dev *psd; 272 + struct nlattr *id; 273 + int fd, err; 274 + 275 + if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_SOCK_FD)) 276 + return -EINVAL; 277 + 278 + fd = nla_get_u32(info->attrs[PSP_A_ASSOC_SOCK_FD]); 279 + socket = sockfd_lookup(fd, &err); 280 + if (!socket) 281 + return err; 282 + 283 + if (!sk_is_tcp(socket->sk)) { 284 + NL_SET_ERR_MSG_ATTR(info->extack, 285 + info->attrs[PSP_A_ASSOC_SOCK_FD], 286 + "Unsupported socket family and type"); 287 + err = -EOPNOTSUPP; 288 + goto err_sock_put; 289 + } 290 + 291 + psd = psp_dev_get_for_sock(socket->sk); 292 + if (psd) { 293 + err = psp_dev_check_access(psd, genl_info_net(info)); 294 + if (err) { 295 + psp_dev_put(psd); 296 + psd = NULL; 297 + } 298 + } 299 + 300 + if (!psd && GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_DEV_ID)) { 301 + err = -EINVAL; 302 + goto err_sock_put; 303 + } 304 + 305 + id = info->attrs[PSP_A_ASSOC_DEV_ID]; 306 + if (psd) { 307 + mutex_lock(&psd->lock); 308 + if (id && psd->id != nla_get_u32(id)) { 309 + mutex_unlock(&psd->lock); 310 + NL_SET_ERR_MSG_ATTR(info->extack, id, 311 + "Device id vs socket mismatch"); 312 + err = -EINVAL; 313 + goto err_psd_put; 314 + } 315 + 316 + psp_dev_put(psd); 317 + } else { 318 + psd = psp_device_get_and_lock(genl_info_net(info), id); 319 + if (IS_ERR(psd)) { 320 + err = PTR_ERR(psd); 321 + goto err_sock_put; 322 + } 323 + } 324 + 325 + info->user_ptr[0] = psd; 326 + info->user_ptr[1] = socket; 327 + 328 + return 0; 329 + 330 + err_psd_put: 331 + psp_dev_put(psd); 332 + err_sock_put: 333 + sockfd_put(socket); 334 + return err; 335 + } 336 + 337 + static int 338 + psp_nl_parse_key(struct genl_info *info, u32 attr, struct psp_key_parsed *key, 339 + unsigned int key_sz) 340 + { 341 + struct nlattr *nest = info->attrs[attr]; 342 + struct nlattr *tb[PSP_A_KEYS_SPI + 1]; 343 + u32 spi; 344 + int err; 345 + 346 + err = nla_parse_nested(tb, ARRAY_SIZE(tb) - 1, nest, 347 + psp_keys_nl_policy, info->extack); 348 + if (err) 349 + return err; 350 + 351 + if (NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_KEY) || 352 + NL_REQ_ATTR_CHECK(info->extack, nest, tb, PSP_A_KEYS_SPI)) 353 + return -EINVAL; 354 + 355 + if (nla_len(tb[PSP_A_KEYS_KEY]) != key_sz) { 356 + NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY], 357 + "incorrect key length"); 358 + return -EINVAL; 359 + } 360 + 361 + spi = nla_get_u32(tb[PSP_A_KEYS_SPI]); 362 + if (!(spi & PSP_SPI_KEY_ID)) { 363 + NL_SET_ERR_MSG_ATTR(info->extack, tb[PSP_A_KEYS_KEY], 364 + "invalid SPI: lower 31b must be non-zero"); 365 + return -EINVAL; 366 + } 367 + 368 + key->spi = cpu_to_be32(spi); 369 + memcpy(key->key, nla_data(tb[PSP_A_KEYS_KEY]), key_sz); 370 + 371 + return 0; 372 + } 373 + 374 + static int 375 + psp_nl_put_key(struct sk_buff *skb, u32 attr, u32 version, 376 + struct psp_key_parsed *key) 377 + { 378 + int key_sz = psp_key_size(version); 379 + void *nest; 380 + 381 + nest = nla_nest_start(skb, attr); 382 + 383 + if (nla_put_u32(skb, PSP_A_KEYS_SPI, be32_to_cpu(key->spi)) || 384 + nla_put(skb, PSP_A_KEYS_KEY, key_sz, key->key)) { 385 + nla_nest_cancel(skb, nest); 386 + return -EMSGSIZE; 387 + } 388 + 389 + nla_nest_end(skb, nest); 390 + 391 + return 0; 392 + } 393 + 394 + int psp_nl_rx_assoc_doit(struct sk_buff *skb, struct genl_info *info) 395 + { 396 + struct socket *socket = info->user_ptr[1]; 397 + struct psp_dev *psd = info->user_ptr[0]; 398 + struct psp_key_parsed key; 399 + struct psp_assoc *pas; 400 + struct sk_buff *rsp; 401 + u32 version; 402 + int err; 403 + 404 + if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION)) 405 + return -EINVAL; 406 + 407 + version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]); 408 + if (!(psd->caps->versions & (1 << version))) { 409 + NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]); 410 + return -EOPNOTSUPP; 411 + } 412 + 413 + rsp = psp_nl_reply_new(info); 414 + if (!rsp) 415 + return -ENOMEM; 416 + 417 + pas = psp_assoc_create(psd); 418 + if (!pas) { 419 + err = -ENOMEM; 420 + goto err_free_rsp; 421 + } 422 + pas->version = version; 423 + 424 + err = psd->ops->rx_spi_alloc(psd, version, &key, info->extack); 425 + if (err) 426 + goto err_free_pas; 427 + 428 + if (nla_put_u32(rsp, PSP_A_ASSOC_DEV_ID, psd->id) || 429 + psp_nl_put_key(rsp, PSP_A_ASSOC_RX_KEY, version, &key)) { 430 + err = -EMSGSIZE; 431 + goto err_free_pas; 432 + } 433 + 434 + err = psp_sock_assoc_set_rx(socket->sk, pas, &key, info->extack); 435 + if (err) { 436 + NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_SOCK_FD]); 437 + goto err_free_pas; 438 + } 439 + psp_assoc_put(pas); 440 + 441 + return psp_nl_reply_send(rsp, info); 442 + 443 + err_free_pas: 444 + psp_assoc_put(pas); 445 + err_free_rsp: 446 + nlmsg_free(rsp); 447 + return err; 448 + } 449 + 450 + int psp_nl_tx_assoc_doit(struct sk_buff *skb, struct genl_info *info) 451 + { 452 + struct socket *socket = info->user_ptr[1]; 453 + struct psp_dev *psd = info->user_ptr[0]; 454 + struct psp_key_parsed key; 455 + struct sk_buff *rsp; 456 + unsigned int key_sz; 457 + u32 version; 458 + int err; 459 + 460 + if (GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_VERSION) || 461 + GENL_REQ_ATTR_CHECK(info, PSP_A_ASSOC_TX_KEY)) 462 + return -EINVAL; 463 + 464 + version = nla_get_u32(info->attrs[PSP_A_ASSOC_VERSION]); 465 + if (!(psd->caps->versions & (1 << version))) { 466 + NL_SET_BAD_ATTR(info->extack, info->attrs[PSP_A_ASSOC_VERSION]); 467 + return -EOPNOTSUPP; 468 + } 469 + 470 + key_sz = psp_key_size(version); 471 + if (!key_sz) 472 + return -EINVAL; 473 + 474 + err = psp_nl_parse_key(info, PSP_A_ASSOC_TX_KEY, &key, key_sz); 475 + if (err < 0) 476 + return err; 477 + 478 + rsp = psp_nl_reply_new(info); 479 + if (!rsp) 480 + return -ENOMEM; 481 + 482 + err = psp_sock_assoc_set_tx(socket->sk, psd, version, &key, 483 + info->extack); 484 + if (err) 485 + goto err_free_msg; 486 + 487 + return psp_nl_reply_send(rsp, info); 488 + 489 + err_free_msg: 264 490 nlmsg_free(rsp); 265 491 return err; 266 492 }
+274
net/psp/psp_sock.c
··· 1 + // SPDX-License-Identifier: GPL-2.0-only 2 + 3 + #include <linux/file.h> 4 + #include <linux/net.h> 5 + #include <linux/rcupdate.h> 6 + #include <linux/tcp.h> 7 + 8 + #include <net/ip.h> 9 + #include <net/psp.h> 10 + #include "psp.h" 11 + 12 + struct psp_dev *psp_dev_get_for_sock(struct sock *sk) 13 + { 14 + struct dst_entry *dst; 15 + struct psp_dev *psd; 16 + 17 + dst = sk_dst_get(sk); 18 + if (!dst) 19 + return NULL; 20 + 21 + rcu_read_lock(); 22 + psd = rcu_dereference(dst->dev->psp_dev); 23 + if (psd && !psp_dev_tryget(psd)) 24 + psd = NULL; 25 + rcu_read_unlock(); 26 + 27 + dst_release(dst); 28 + 29 + return psd; 30 + } 31 + 32 + static struct sk_buff * 33 + psp_validate_xmit(struct sock *sk, struct net_device *dev, struct sk_buff *skb) 34 + { 35 + struct psp_assoc *pas; 36 + bool good; 37 + 38 + rcu_read_lock(); 39 + pas = psp_skb_get_assoc_rcu(skb); 40 + good = !pas || rcu_access_pointer(dev->psp_dev) == pas->psd; 41 + rcu_read_unlock(); 42 + if (!good) { 43 + kfree_skb_reason(skb, SKB_DROP_REASON_PSP_OUTPUT); 44 + return NULL; 45 + } 46 + 47 + return skb; 48 + } 49 + 50 + struct psp_assoc *psp_assoc_create(struct psp_dev *psd) 51 + { 52 + struct psp_assoc *pas; 53 + 54 + lockdep_assert_held(&psd->lock); 55 + 56 + pas = kzalloc(struct_size(pas, drv_data, psd->caps->assoc_drv_spc), 57 + GFP_KERNEL_ACCOUNT); 58 + if (!pas) 59 + return NULL; 60 + 61 + pas->psd = psd; 62 + pas->dev_id = psd->id; 63 + psp_dev_get(psd); 64 + refcount_set(&pas->refcnt, 1); 65 + 66 + list_add_tail(&pas->assocs_list, &psd->active_assocs); 67 + 68 + return pas; 69 + } 70 + 71 + static struct psp_assoc *psp_assoc_dummy(struct psp_assoc *pas) 72 + { 73 + struct psp_dev *psd = pas->psd; 74 + size_t sz; 75 + 76 + lockdep_assert_held(&psd->lock); 77 + 78 + sz = struct_size(pas, drv_data, psd->caps->assoc_drv_spc); 79 + return kmemdup(pas, sz, GFP_KERNEL); 80 + } 81 + 82 + static int psp_dev_tx_key_add(struct psp_dev *psd, struct psp_assoc *pas, 83 + struct netlink_ext_ack *extack) 84 + { 85 + return psd->ops->tx_key_add(psd, pas, extack); 86 + } 87 + 88 + void psp_dev_tx_key_del(struct psp_dev *psd, struct psp_assoc *pas) 89 + { 90 + if (pas->tx.spi) 91 + psd->ops->tx_key_del(psd, pas); 92 + list_del(&pas->assocs_list); 93 + } 94 + 95 + static void psp_assoc_free(struct work_struct *work) 96 + { 97 + struct psp_assoc *pas = container_of(work, struct psp_assoc, work); 98 + struct psp_dev *psd = pas->psd; 99 + 100 + mutex_lock(&psd->lock); 101 + if (psd->ops) 102 + psp_dev_tx_key_del(psd, pas); 103 + mutex_unlock(&psd->lock); 104 + psp_dev_put(psd); 105 + kfree(pas); 106 + } 107 + 108 + static void psp_assoc_free_queue(struct rcu_head *head) 109 + { 110 + struct psp_assoc *pas = container_of(head, struct psp_assoc, rcu); 111 + 112 + INIT_WORK(&pas->work, psp_assoc_free); 113 + schedule_work(&pas->work); 114 + } 115 + 116 + /** 117 + * psp_assoc_put() - release a reference on a PSP association 118 + * @pas: association to release 119 + */ 120 + void psp_assoc_put(struct psp_assoc *pas) 121 + { 122 + if (pas && refcount_dec_and_test(&pas->refcnt)) 123 + call_rcu(&pas->rcu, psp_assoc_free_queue); 124 + } 125 + 126 + void psp_sk_assoc_free(struct sock *sk) 127 + { 128 + struct psp_assoc *pas = rcu_dereference_protected(sk->psp_assoc, 1); 129 + 130 + rcu_assign_pointer(sk->psp_assoc, NULL); 131 + psp_assoc_put(pas); 132 + } 133 + 134 + int psp_sock_assoc_set_rx(struct sock *sk, struct psp_assoc *pas, 135 + struct psp_key_parsed *key, 136 + struct netlink_ext_ack *extack) 137 + { 138 + int err; 139 + 140 + memcpy(&pas->rx, key, sizeof(*key)); 141 + 142 + lock_sock(sk); 143 + 144 + if (psp_sk_assoc(sk)) { 145 + NL_SET_ERR_MSG(extack, "Socket already has PSP state"); 146 + err = -EBUSY; 147 + goto exit_unlock; 148 + } 149 + 150 + refcount_inc(&pas->refcnt); 151 + rcu_assign_pointer(sk->psp_assoc, pas); 152 + err = 0; 153 + 154 + exit_unlock: 155 + release_sock(sk); 156 + 157 + return err; 158 + } 159 + 160 + static int psp_sock_recv_queue_check(struct sock *sk, struct psp_assoc *pas) 161 + { 162 + struct psp_skb_ext *pse; 163 + struct sk_buff *skb; 164 + 165 + skb_rbtree_walk(skb, &tcp_sk(sk)->out_of_order_queue) { 166 + pse = skb_ext_find(skb, SKB_EXT_PSP); 167 + if (!psp_pse_matches_pas(pse, pas)) 168 + return -EBUSY; 169 + } 170 + 171 + skb_queue_walk(&sk->sk_receive_queue, skb) { 172 + pse = skb_ext_find(skb, SKB_EXT_PSP); 173 + if (!psp_pse_matches_pas(pse, pas)) 174 + return -EBUSY; 175 + } 176 + return 0; 177 + } 178 + 179 + int psp_sock_assoc_set_tx(struct sock *sk, struct psp_dev *psd, 180 + u32 version, struct psp_key_parsed *key, 181 + struct netlink_ext_ack *extack) 182 + { 183 + struct psp_assoc *pas, *dummy; 184 + int err; 185 + 186 + lock_sock(sk); 187 + 188 + pas = psp_sk_assoc(sk); 189 + if (!pas) { 190 + NL_SET_ERR_MSG(extack, "Socket has no Rx key"); 191 + err = -EINVAL; 192 + goto exit_unlock; 193 + } 194 + if (pas->psd != psd) { 195 + NL_SET_ERR_MSG(extack, "Rx key from different device"); 196 + err = -EINVAL; 197 + goto exit_unlock; 198 + } 199 + if (pas->version != version) { 200 + NL_SET_ERR_MSG(extack, 201 + "PSP version mismatch with existing state"); 202 + err = -EINVAL; 203 + goto exit_unlock; 204 + } 205 + if (pas->tx.spi) { 206 + NL_SET_ERR_MSG(extack, "Tx key already set"); 207 + err = -EBUSY; 208 + goto exit_unlock; 209 + } 210 + 211 + err = psp_sock_recv_queue_check(sk, pas); 212 + if (err) { 213 + NL_SET_ERR_MSG(extack, "Socket has incompatible segments already in the recv queue"); 214 + goto exit_unlock; 215 + } 216 + 217 + /* Pass a fake association to drivers to make sure they don't 218 + * try to store pointers to it. For re-keying we'll need to 219 + * re-allocate the assoc structures. 220 + */ 221 + dummy = psp_assoc_dummy(pas); 222 + if (!dummy) { 223 + err = -ENOMEM; 224 + goto exit_unlock; 225 + } 226 + 227 + memcpy(&dummy->tx, key, sizeof(*key)); 228 + err = psp_dev_tx_key_add(psd, dummy, extack); 229 + if (err) 230 + goto exit_free_dummy; 231 + 232 + memcpy(pas->drv_data, dummy->drv_data, psd->caps->assoc_drv_spc); 233 + memcpy(&pas->tx, key, sizeof(*key)); 234 + 235 + WRITE_ONCE(sk->sk_validate_xmit_skb, psp_validate_xmit); 236 + tcp_write_collapse_fence(sk); 237 + pas->upgrade_seq = tcp_sk(sk)->rcv_nxt; 238 + 239 + exit_free_dummy: 240 + kfree(dummy); 241 + exit_unlock: 242 + release_sock(sk); 243 + return err; 244 + } 245 + 246 + void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) 247 + { 248 + struct psp_assoc *pas = psp_sk_assoc(sk); 249 + 250 + if (pas) 251 + refcount_inc(&pas->refcnt); 252 + rcu_assign_pointer(tw->psp_assoc, pas); 253 + tw->tw_validate_xmit_skb = psp_validate_xmit; 254 + } 255 + 256 + void psp_twsk_assoc_free(struct inet_timewait_sock *tw) 257 + { 258 + struct psp_assoc *pas = rcu_dereference_protected(tw->psp_assoc, 1); 259 + 260 + rcu_assign_pointer(tw->psp_assoc, NULL); 261 + psp_assoc_put(pas); 262 + } 263 + 264 + void psp_reply_set_decrypted(struct sk_buff *skb) 265 + { 266 + struct psp_assoc *pas; 267 + 268 + rcu_read_lock(); 269 + pas = psp_sk_get_assoc_rcu(skb->sk); 270 + if (pas && pas->tx.spi) 271 + skb->decrypted = 1; 272 + rcu_read_unlock(); 273 + } 274 + EXPORT_IPV6_MOD_GPL(psp_reply_set_decrypted);