at master 5.3 kB view raw
1/* SPDX-License-Identifier: GPL-2.0-only */ 2 3#ifndef __NET_PSP_HELPERS_H 4#define __NET_PSP_HELPERS_H 5 6#include <linux/skbuff.h> 7#include <linux/rcupdate.h> 8#include <linux/udp.h> 9#include <net/sock.h> 10#include <net/tcp.h> 11#include <net/psp/types.h> 12 13struct inet_timewait_sock; 14 15/* Driver-facing API */ 16struct psp_dev * 17psp_dev_create(struct net_device *netdev, struct psp_dev_ops *psd_ops, 18 struct psp_dev_caps *psd_caps, void *priv_ptr); 19void psp_dev_unregister(struct psp_dev *psd); 20bool psp_dev_encapsulate(struct net *net, struct sk_buff *skb, __be32 spi, 21 u8 ver, __be16 sport); 22int psp_dev_rcv(struct sk_buff *skb, u16 dev_id, u8 generation, bool strip_icv); 23 24/* Kernel-facing API */ 25void psp_assoc_put(struct psp_assoc *pas); 26 27static inline void *psp_assoc_drv_data(struct psp_assoc *pas) 28{ 29 return pas->drv_data; 30} 31 32#if IS_ENABLED(CONFIG_INET_PSP) 33unsigned int psp_key_size(u32 version); 34void psp_sk_assoc_free(struct sock *sk); 35void psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk); 36void psp_twsk_assoc_free(struct inet_timewait_sock *tw); 37void psp_reply_set_decrypted(const struct sock *sk, struct sk_buff *skb); 38 39static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 40{ 41 return rcu_dereference_check(sk->psp_assoc, lockdep_sock_is_held(sk)); 42} 43 44static inline void 45psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) 46{ 47 struct psp_assoc *pas; 48 49 pas = psp_sk_assoc(sk); 50 if (pas && pas->tx.spi) 51 skb->decrypted = 1; 52} 53 54static inline unsigned long 55__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 56 unsigned long diffs) 57{ 58 struct psp_skb_ext *a, *b; 59 60 a = skb_ext_find(one, SKB_EXT_PSP); 61 b = skb_ext_find(two, SKB_EXT_PSP); 62 63 diffs |= (!!a) ^ (!!b); 64 if (!diffs && unlikely(a)) 65 diffs |= memcmp(a, b, sizeof(*a)); 66 return diffs; 67} 68 69static inline bool 70psp_is_allowed_nondata(struct sk_buff *skb, struct psp_assoc *pas) 71{ 72 bool fin = !!(TCP_SKB_CB(skb)->tcp_flags & TCPHDR_FIN); 73 u32 end_seq = TCP_SKB_CB(skb)->end_seq; 74 u32 seq = TCP_SKB_CB(skb)->seq; 75 bool pure_fin; 76 77 pure_fin = fin && end_seq - seq == 1; 78 79 return seq == end_seq || (pure_fin && seq == pas->upgrade_seq); 80} 81 82static inline bool 83psp_pse_matches_pas(struct psp_skb_ext *pse, struct psp_assoc *pas) 84{ 85 return pse && pas->rx.spi == pse->spi && 86 pas->generation == pse->generation && 87 pas->version == pse->version && 88 pas->dev_id == pse->dev_id; 89} 90 91static inline enum skb_drop_reason 92__psp_sk_rx_policy_check(struct sk_buff *skb, struct psp_assoc *pas) 93{ 94 struct psp_skb_ext *pse = skb_ext_find(skb, SKB_EXT_PSP); 95 96 if (!pas) 97 return pse ? SKB_DROP_REASON_PSP_INPUT : 0; 98 99 if (likely(psp_pse_matches_pas(pse, pas))) { 100 if (unlikely(!pas->peer_tx)) 101 pas->peer_tx = 1; 102 103 return 0; 104 } 105 106 if (!pse) { 107 if (!pas->tx.spi || 108 (!pas->peer_tx && psp_is_allowed_nondata(skb, pas))) 109 return 0; 110 } 111 112 return SKB_DROP_REASON_PSP_INPUT; 113} 114 115static inline enum skb_drop_reason 116psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 117{ 118 return __psp_sk_rx_policy_check(skb, psp_sk_assoc(sk)); 119} 120 121static inline enum skb_drop_reason 122psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 123{ 124 return __psp_sk_rx_policy_check(skb, rcu_dereference(tw->psp_assoc)); 125} 126 127static inline struct psp_assoc *psp_sk_get_assoc_rcu(const struct sock *sk) 128{ 129 struct psp_assoc *pas; 130 int state; 131 132 state = READ_ONCE(sk->sk_state); 133 if (!sk_is_inet(sk) || state == TCP_NEW_SYN_RECV) 134 return NULL; 135 136 pas = state == TCP_TIME_WAIT ? 137 rcu_dereference(inet_twsk(sk)->psp_assoc) : 138 rcu_dereference(sk->psp_assoc); 139 return pas; 140} 141 142static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 143{ 144 if (!skb->decrypted || !skb->sk) 145 return NULL; 146 147 return psp_sk_get_assoc_rcu(skb->sk); 148} 149 150static inline unsigned int psp_sk_overhead(const struct sock *sk) 151{ 152 int psp_encap = sizeof(struct udphdr) + PSP_HDR_SIZE + PSP_TRL_SIZE; 153 bool has_psp = rcu_access_pointer(sk->psp_assoc); 154 155 return has_psp ? psp_encap : 0; 156} 157#else 158static inline void psp_sk_assoc_free(struct sock *sk) { } 159static inline void 160psp_twsk_init(struct inet_timewait_sock *tw, const struct sock *sk) { } 161static inline void psp_twsk_assoc_free(struct inet_timewait_sock *tw) { } 162static inline void 163psp_reply_set_decrypted(const struct sock *sk, struct sk_buff *skb) { } 164 165static inline struct psp_assoc *psp_sk_assoc(const struct sock *sk) 166{ 167 return NULL; 168} 169 170static inline void 171psp_enqueue_set_decrypted(struct sock *sk, struct sk_buff *skb) { } 172 173static inline unsigned long 174__psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two, 175 unsigned long diffs) 176{ 177 return diffs; 178} 179 180static inline enum skb_drop_reason 181psp_sk_rx_policy_check(struct sock *sk, struct sk_buff *skb) 182{ 183 return 0; 184} 185 186static inline enum skb_drop_reason 187psp_twsk_rx_policy_check(struct inet_timewait_sock *tw, struct sk_buff *skb) 188{ 189 return 0; 190} 191 192static inline struct psp_assoc *psp_skb_get_assoc_rcu(struct sk_buff *skb) 193{ 194 return NULL; 195} 196 197static inline unsigned int psp_sk_overhead(const struct sock *sk) 198{ 199 return 0; 200} 201#endif 202 203static inline unsigned long 204psp_skb_coalesce_diff(const struct sk_buff *one, const struct sk_buff *two) 205{ 206 return __psp_skb_coalesce_diff(one, two, 0); 207} 208 209#endif /* __NET_PSP_HELPERS_H */