Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

net/tls: Split conf to rx + tx

In TLS inline crypto, we can have one direction in software
and another in hardware. Thus, we split the TLS configuration to separate
structures for receive and transmit.

Signed-off-by: Boris Pismenny <borisp@mellanox.com>
Signed-off-by: David S. Miller <davem@davemloft.net>

authored by

Boris Pismenny and committed by
David S. Miller
f66de3ee 2342a851

+158 -124
+33 -18
include/net/tls.h
··· 83 83 void (*unhash)(struct tls_device *device, struct sock *sk); 84 84 }; 85 85 86 - struct tls_sw_context { 86 + struct tls_sw_context_tx { 87 87 struct crypto_aead *aead_send; 88 - struct crypto_aead *aead_recv; 89 88 struct crypto_wait async_wait; 90 89 91 - /* Receive context */ 92 - struct strparser strp; 93 - void (*saved_data_ready)(struct sock *sk); 94 - unsigned int (*sk_poll)(struct file *file, struct socket *sock, 95 - struct poll_table_struct *wait); 96 - struct sk_buff *recv_pkt; 97 - u8 control; 98 - bool decrypted; 99 - 100 - /* Sending context */ 101 90 char aad_space[TLS_AAD_SPACE_SIZE]; 102 91 103 92 unsigned int sg_plaintext_size; ··· 101 112 struct scatterlist sg_aead_in[2]; 102 113 /* AAD | sg_encrypted_data (data contain overhead for hdr&iv&tag) */ 103 114 struct scatterlist sg_aead_out[2]; 115 + }; 116 + 117 + struct tls_sw_context_rx { 118 + struct crypto_aead *aead_recv; 119 + struct crypto_wait async_wait; 120 + 121 + struct strparser strp; 122 + void (*saved_data_ready)(struct sock *sk); 123 + unsigned int (*sk_poll)(struct file *file, struct socket *sock, 124 + struct poll_table_struct *wait); 125 + struct sk_buff *recv_pkt; 126 + u8 control; 127 + bool decrypted; 104 128 }; 105 129 106 130 enum { ··· 140 138 struct tls12_crypto_info_aes_gcm_128 crypto_recv_aes_gcm_128; 141 139 }; 142 140 143 - void *priv_ctx; 141 + struct list_head list; 142 + struct net_device *netdev; 143 + refcount_t refcount; 144 144 145 - u8 conf:3; 145 + void *priv_ctx_tx; 146 + void *priv_ctx_rx; 147 + 148 + u8 tx_conf:3; 149 + u8 rx_conf:3; 146 150 147 151 struct cipher_context tx; 148 152 struct cipher_context rx; ··· 185 177 int tls_sw_sendpage(struct sock *sk, struct page *page, 186 178 int offset, size_t size, int flags); 187 179 void tls_sw_close(struct sock *sk, long timeout); 188 - void tls_sw_free_resources(struct sock *sk); 180 + void tls_sw_free_resources_tx(struct sock *sk); 181 + void tls_sw_free_resources_rx(struct sock *sk); 189 182 int tls_sw_recvmsg(struct sock *sk, struct msghdr *msg, size_t len, 190 183 int nonblock, int flags, int *addr_len); 191 184 unsigned int tls_sw_poll(struct file *file, struct socket *sock, ··· 306 297 return icsk->icsk_ulp_data; 307 298 } 308 299 309 - static inline struct tls_sw_context *tls_sw_ctx( 300 + static inline struct tls_sw_context_rx *tls_sw_ctx_rx( 310 301 const struct tls_context *tls_ctx) 311 302 { 312 - return (struct tls_sw_context *)tls_ctx->priv_ctx; 303 + return (struct tls_sw_context_rx *)tls_ctx->priv_ctx_rx; 304 + } 305 + 306 + static inline struct tls_sw_context_tx *tls_sw_ctx_tx( 307 + const struct tls_context *tls_ctx) 308 + { 309 + return (struct tls_sw_context_tx *)tls_ctx->priv_ctx_tx; 313 310 } 314 311 315 312 static inline struct tls_offload_context *tls_offload_ctx( 316 313 const struct tls_context *tls_ctx) 317 314 { 318 - return (struct tls_offload_context *)tls_ctx->priv_ctx; 315 + return (struct tls_offload_context *)tls_ctx->priv_ctx_tx; 319 316 } 320 317 321 318 int tls_proccess_cmsg(struct sock *sk, struct msghdr *msg,
+47 -48
net/tls/tls_main.c
··· 51 51 TLSV6, 52 52 TLS_NUM_PROTS, 53 53 }; 54 - 55 54 enum { 56 55 TLS_BASE, 57 - TLS_SW_TX, 58 - TLS_SW_RX, 59 - TLS_SW_RXTX, 56 + TLS_SW, 60 57 TLS_HW_RECORD, 61 58 TLS_NUM_CONFIG, 62 59 }; ··· 62 65 static DEFINE_MUTEX(tcpv6_prot_mutex); 63 66 static LIST_HEAD(device_list); 64 67 static DEFINE_MUTEX(device_mutex); 65 - static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG]; 68 + static struct proto tls_prots[TLS_NUM_PROTS][TLS_NUM_CONFIG][TLS_NUM_CONFIG]; 66 69 static struct proto_ops tls_sw_proto_ops; 67 70 68 - static inline void update_sk_prot(struct sock *sk, struct tls_context *ctx) 71 + static void update_sk_prot(struct sock *sk, struct tls_context *ctx) 69 72 { 70 73 int ip_ver = sk->sk_family == AF_INET6 ? TLSV6 : TLSV4; 71 74 72 - sk->sk_prot = &tls_prots[ip_ver][ctx->conf]; 75 + sk->sk_prot = &tls_prots[ip_ver][ctx->tx_conf][ctx->rx_conf]; 73 76 } 74 77 75 78 int wait_on_pending_writer(struct sock *sk, long *timeo) ··· 242 245 lock_sock(sk); 243 246 sk_proto_close = ctx->sk_proto_close; 244 247 245 - if (ctx->conf == TLS_HW_RECORD) 248 + if (ctx->tx_conf == TLS_HW_RECORD && ctx->rx_conf == TLS_HW_RECORD) 246 249 goto skip_tx_cleanup; 247 250 248 - if (ctx->conf == TLS_BASE) { 251 + if (ctx->tx_conf == TLS_BASE && ctx->rx_conf == TLS_BASE) { 249 252 kfree(ctx); 250 253 ctx = NULL; 251 254 goto skip_tx_cleanup; ··· 267 270 } 268 271 } 269 272 270 - kfree(ctx->tx.rec_seq); 271 - kfree(ctx->tx.iv); 272 - kfree(ctx->rx.rec_seq); 273 - kfree(ctx->rx.iv); 273 + /* We need these for tls_sw_fallback handling of other packets */ 274 + if (ctx->tx_conf == TLS_SW) { 275 + kfree(ctx->tx.rec_seq); 276 + kfree(ctx->tx.iv); 277 + tls_sw_free_resources_tx(sk); 278 + } 274 279 275 - if (ctx->conf == TLS_SW_TX || 276 - ctx->conf == TLS_SW_RX || 277 - ctx->conf == TLS_SW_RXTX) { 278 - tls_sw_free_resources(sk); 280 + if (ctx->rx_conf == TLS_SW) { 281 + kfree(ctx->rx.rec_seq); 282 + kfree(ctx->rx.iv); 283 + tls_sw_free_resources_rx(sk); 279 284 } 280 285 281 286 skip_tx_cleanup: ··· 286 287 /* free ctx for TLS_HW_RECORD, used by tcp_set_state 287 288 * for sk->sk_prot->unhash [tls_hw_unhash] 288 289 */ 289 - if (ctx && ctx->conf == TLS_HW_RECORD) 290 + if (ctx && ctx->tx_conf == TLS_HW_RECORD && 291 + ctx->rx_conf == TLS_HW_RECORD) 290 292 kfree(ctx); 291 293 } 292 294 ··· 441 441 goto err_crypto_info; 442 442 } 443 443 444 - /* currently SW is default, we will have ethtool in future */ 445 444 if (tx) { 446 445 rc = tls_set_sw_offload(sk, ctx, 1); 447 - if (ctx->conf == TLS_SW_RX) 448 - conf = TLS_SW_RXTX; 449 - else 450 - conf = TLS_SW_TX; 446 + conf = TLS_SW; 451 447 } else { 452 448 rc = tls_set_sw_offload(sk, ctx, 0); 453 - if (ctx->conf == TLS_SW_TX) 454 - conf = TLS_SW_RXTX; 455 - else 456 - conf = TLS_SW_RX; 449 + conf = TLS_SW; 457 450 } 458 451 459 452 if (rc) 460 453 goto err_crypto_info; 461 454 462 - ctx->conf = conf; 455 + if (tx) 456 + ctx->tx_conf = conf; 457 + else 458 + ctx->rx_conf = conf; 463 459 update_sk_prot(sk, ctx); 464 460 if (tx) { 465 461 ctx->sk_write_space = sk->sk_write_space; ··· 531 535 ctx->hash = sk->sk_prot->hash; 532 536 ctx->unhash = sk->sk_prot->unhash; 533 537 ctx->sk_proto_close = sk->sk_prot->close; 534 - ctx->conf = TLS_HW_RECORD; 538 + ctx->rx_conf = TLS_HW_RECORD; 539 + ctx->tx_conf = TLS_HW_RECORD; 535 540 update_sk_prot(sk, ctx); 536 541 rc = 1; 537 542 break; ··· 576 579 return err; 577 580 } 578 581 579 - static void build_protos(struct proto *prot, struct proto *base) 582 + static void build_protos(struct proto prot[TLS_NUM_CONFIG][TLS_NUM_CONFIG], 583 + struct proto *base) 580 584 { 581 - prot[TLS_BASE] = *base; 582 - prot[TLS_BASE].setsockopt = tls_setsockopt; 583 - prot[TLS_BASE].getsockopt = tls_getsockopt; 584 - prot[TLS_BASE].close = tls_sk_proto_close; 585 + prot[TLS_BASE][TLS_BASE] = *base; 586 + prot[TLS_BASE][TLS_BASE].setsockopt = tls_setsockopt; 587 + prot[TLS_BASE][TLS_BASE].getsockopt = tls_getsockopt; 588 + prot[TLS_BASE][TLS_BASE].close = tls_sk_proto_close; 585 589 586 - prot[TLS_SW_TX] = prot[TLS_BASE]; 587 - prot[TLS_SW_TX].sendmsg = tls_sw_sendmsg; 588 - prot[TLS_SW_TX].sendpage = tls_sw_sendpage; 590 + prot[TLS_SW][TLS_BASE] = prot[TLS_BASE][TLS_BASE]; 591 + prot[TLS_SW][TLS_BASE].sendmsg = tls_sw_sendmsg; 592 + prot[TLS_SW][TLS_BASE].sendpage = tls_sw_sendpage; 589 593 590 - prot[TLS_SW_RX] = prot[TLS_BASE]; 591 - prot[TLS_SW_RX].recvmsg = tls_sw_recvmsg; 592 - prot[TLS_SW_RX].close = tls_sk_proto_close; 594 + prot[TLS_BASE][TLS_SW] = prot[TLS_BASE][TLS_BASE]; 595 + prot[TLS_BASE][TLS_SW].recvmsg = tls_sw_recvmsg; 596 + prot[TLS_BASE][TLS_SW].close = tls_sk_proto_close; 593 597 594 - prot[TLS_SW_RXTX] = prot[TLS_SW_TX]; 595 - prot[TLS_SW_RXTX].recvmsg = tls_sw_recvmsg; 596 - prot[TLS_SW_RXTX].close = tls_sk_proto_close; 598 + prot[TLS_SW][TLS_SW] = prot[TLS_SW][TLS_BASE]; 599 + prot[TLS_SW][TLS_SW].recvmsg = tls_sw_recvmsg; 600 + prot[TLS_SW][TLS_SW].close = tls_sk_proto_close; 597 601 598 - prot[TLS_HW_RECORD] = *base; 599 - prot[TLS_HW_RECORD].hash = tls_hw_hash; 600 - prot[TLS_HW_RECORD].unhash = tls_hw_unhash; 601 - prot[TLS_HW_RECORD].close = tls_sk_proto_close; 602 + prot[TLS_HW_RECORD][TLS_HW_RECORD] = *base; 603 + prot[TLS_HW_RECORD][TLS_HW_RECORD].hash = tls_hw_hash; 604 + prot[TLS_HW_RECORD][TLS_HW_RECORD].unhash = tls_hw_unhash; 605 + prot[TLS_HW_RECORD][TLS_HW_RECORD].close = tls_sk_proto_close; 602 606 } 603 607 604 608 static int tls_init(struct sock *sk) ··· 641 643 mutex_unlock(&tcpv6_prot_mutex); 642 644 } 643 645 644 - ctx->conf = TLS_BASE; 646 + ctx->tx_conf = TLS_BASE; 647 + ctx->rx_conf = TLS_BASE; 645 648 update_sk_prot(sk, ctx); 646 649 out: 647 650 return rc;
+78 -58
net/tls/tls_sw.c
··· 52 52 gfp_t flags) 53 53 { 54 54 struct tls_context *tls_ctx = tls_get_ctx(sk); 55 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 55 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 56 56 struct strp_msg *rxm = strp_msg(skb); 57 57 struct aead_request *aead_req; 58 58 ··· 122 122 static void trim_both_sgl(struct sock *sk, int target_size) 123 123 { 124 124 struct tls_context *tls_ctx = tls_get_ctx(sk); 125 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 125 + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 126 126 127 127 trim_sg(sk, ctx->sg_plaintext_data, 128 128 &ctx->sg_plaintext_num_elem, ··· 141 141 static int alloc_encrypted_sg(struct sock *sk, int len) 142 142 { 143 143 struct tls_context *tls_ctx = tls_get_ctx(sk); 144 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 144 + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 145 145 int rc = 0; 146 146 147 147 rc = sk_alloc_sg(sk, len, ··· 155 155 static int alloc_plaintext_sg(struct sock *sk, int len) 156 156 { 157 157 struct tls_context *tls_ctx = tls_get_ctx(sk); 158 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 158 + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 159 159 int rc = 0; 160 160 161 161 rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0, ··· 181 181 static void tls_free_both_sg(struct sock *sk) 182 182 { 183 183 struct tls_context *tls_ctx = tls_get_ctx(sk); 184 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 184 + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 185 185 186 186 free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem, 187 187 &ctx->sg_encrypted_size); ··· 191 191 } 192 192 193 193 static int tls_do_encryption(struct tls_context *tls_ctx, 194 - struct tls_sw_context *ctx, size_t data_len, 194 + struct tls_sw_context_tx *ctx, size_t data_len, 195 195 gfp_t flags) 196 196 { 197 197 unsigned int req_size = sizeof(struct aead_request) + ··· 227 227 unsigned char record_type) 228 228 { 229 229 struct tls_context *tls_ctx = tls_get_ctx(sk); 230 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 230 + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 231 231 int rc; 232 232 233 233 sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1); ··· 339 339 int bytes) 340 340 { 341 341 struct tls_context *tls_ctx = tls_get_ctx(sk); 342 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 342 + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 343 343 struct scatterlist *sg = ctx->sg_plaintext_data; 344 344 int copy, i, rc = 0; 345 345 ··· 367 367 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size) 368 368 { 369 369 struct tls_context *tls_ctx = tls_get_ctx(sk); 370 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 370 + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 371 371 int ret = 0; 372 372 int required_size; 373 373 long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT); ··· 522 522 int offset, size_t size, int flags) 523 523 { 524 524 struct tls_context *tls_ctx = tls_get_ctx(sk); 525 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 525 + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 526 526 int ret = 0; 527 527 long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT); 528 528 bool eor; ··· 636 636 long timeo, int *err) 637 637 { 638 638 struct tls_context *tls_ctx = tls_get_ctx(sk); 639 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 639 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 640 640 struct sk_buff *skb; 641 641 DEFINE_WAIT_FUNC(wait, woken_wake_function); 642 642 ··· 674 674 struct scatterlist *sgout) 675 675 { 676 676 struct tls_context *tls_ctx = tls_get_ctx(sk); 677 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 677 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 678 678 char iv[TLS_CIPHER_AES_GCM_128_SALT_SIZE + MAX_IV_SIZE]; 679 679 struct scatterlist sgin_arr[MAX_SKB_FRAGS + 2]; 680 680 struct scatterlist *sgin = &sgin_arr[0]; ··· 723 723 unsigned int len) 724 724 { 725 725 struct tls_context *tls_ctx = tls_get_ctx(sk); 726 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 726 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 727 727 struct strp_msg *rxm = strp_msg(skb); 728 728 729 729 if (len < rxm->full_len) { ··· 749 749 int *addr_len) 750 750 { 751 751 struct tls_context *tls_ctx = tls_get_ctx(sk); 752 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 752 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 753 753 unsigned char control; 754 754 struct strp_msg *rxm; 755 755 struct sk_buff *skb; ··· 869 869 size_t len, unsigned int flags) 870 870 { 871 871 struct tls_context *tls_ctx = tls_get_ctx(sock->sk); 872 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 872 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 873 873 struct strp_msg *rxm = NULL; 874 874 struct sock *sk = sock->sk; 875 875 struct sk_buff *skb; ··· 922 922 unsigned int ret; 923 923 struct sock *sk = sock->sk; 924 924 struct tls_context *tls_ctx = tls_get_ctx(sk); 925 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 925 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 926 926 927 927 /* Grab POLLOUT and POLLHUP from the underlying socket */ 928 928 ret = ctx->sk_poll(file, sock, wait); ··· 938 938 static int tls_read_size(struct strparser *strp, struct sk_buff *skb) 939 939 { 940 940 struct tls_context *tls_ctx = tls_get_ctx(strp->sk); 941 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 941 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 942 942 char header[tls_ctx->rx.prepend_size]; 943 943 struct strp_msg *rxm = strp_msg(skb); 944 944 size_t cipher_overhead; ··· 987 987 static void tls_queue(struct strparser *strp, struct sk_buff *skb) 988 988 { 989 989 struct tls_context *tls_ctx = tls_get_ctx(strp->sk); 990 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 990 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 991 991 struct strp_msg *rxm; 992 992 993 993 rxm = strp_msg(skb); ··· 1003 1003 static void tls_data_ready(struct sock *sk) 1004 1004 { 1005 1005 struct tls_context *tls_ctx = tls_get_ctx(sk); 1006 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 1006 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1007 1007 1008 1008 strp_data_ready(&ctx->strp); 1009 1009 } 1010 1010 1011 - void tls_sw_free_resources(struct sock *sk) 1011 + void tls_sw_free_resources_tx(struct sock *sk) 1012 1012 { 1013 1013 struct tls_context *tls_ctx = tls_get_ctx(sk); 1014 - struct tls_sw_context *ctx = tls_sw_ctx(tls_ctx); 1014 + struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx); 1015 1015 1016 1016 if (ctx->aead_send) 1017 1017 crypto_free_aead(ctx->aead_send); 1018 + tls_free_both_sg(sk); 1019 + 1020 + kfree(ctx); 1021 + } 1022 + 1023 + void tls_sw_free_resources_rx(struct sock *sk) 1024 + { 1025 + struct tls_context *tls_ctx = tls_get_ctx(sk); 1026 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1027 + 1018 1028 if (ctx->aead_recv) { 1019 1029 if (ctx->recv_pkt) { 1020 1030 kfree_skb(ctx->recv_pkt); ··· 1040 1030 lock_sock(sk); 1041 1031 } 1042 1032 1043 - tls_free_both_sg(sk); 1044 - 1045 1033 kfree(ctx); 1046 - kfree(tls_ctx); 1047 1034 } 1048 1035 1049 1036 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx) ··· 1048 1041 char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE]; 1049 1042 struct tls_crypto_info *crypto_info; 1050 1043 struct tls12_crypto_info_aes_gcm_128 *gcm_128_info; 1051 - struct tls_sw_context *sw_ctx; 1044 + struct tls_sw_context_tx *sw_ctx_tx = NULL; 1045 + struct tls_sw_context_rx *sw_ctx_rx = NULL; 1052 1046 struct cipher_context *cctx; 1053 1047 struct crypto_aead **aead; 1054 1048 struct strp_callbacks cb; ··· 1062 1054 goto out; 1063 1055 } 1064 1056 1065 - if (!ctx->priv_ctx) { 1066 - sw_ctx = kzalloc(sizeof(*sw_ctx), GFP_KERNEL); 1067 - if (!sw_ctx) { 1057 + if (tx) { 1058 + sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL); 1059 + if (!sw_ctx_tx) { 1068 1060 rc = -ENOMEM; 1069 1061 goto out; 1070 1062 } 1071 - crypto_init_wait(&sw_ctx->async_wait); 1063 + crypto_init_wait(&sw_ctx_tx->async_wait); 1064 + ctx->priv_ctx_tx = sw_ctx_tx; 1072 1065 } else { 1073 - sw_ctx = ctx->priv_ctx; 1066 + sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL); 1067 + if (!sw_ctx_rx) { 1068 + rc = -ENOMEM; 1069 + goto out; 1070 + } 1071 + crypto_init_wait(&sw_ctx_rx->async_wait); 1072 + ctx->priv_ctx_rx = sw_ctx_rx; 1074 1073 } 1075 - 1076 - ctx->priv_ctx = (struct tls_offload_context *)sw_ctx; 1077 1074 1078 1075 if (tx) { 1079 1076 crypto_info = &ctx->crypto_send; 1080 1077 cctx = &ctx->tx; 1081 - aead = &sw_ctx->aead_send; 1078 + aead = &sw_ctx_tx->aead_send; 1082 1079 } else { 1083 1080 crypto_info = &ctx->crypto_recv; 1084 1081 cctx = &ctx->rx; 1085 - aead = &sw_ctx->aead_recv; 1082 + aead = &sw_ctx_rx->aead_recv; 1086 1083 } 1087 1084 1088 1085 switch (crypto_info->cipher_type) { ··· 1134 1121 } 1135 1122 memcpy(cctx->rec_seq, rec_seq, rec_seq_size); 1136 1123 1137 - if (tx) { 1138 - sg_init_table(sw_ctx->sg_encrypted_data, 1139 - ARRAY_SIZE(sw_ctx->sg_encrypted_data)); 1140 - sg_init_table(sw_ctx->sg_plaintext_data, 1141 - ARRAY_SIZE(sw_ctx->sg_plaintext_data)); 1124 + if (sw_ctx_tx) { 1125 + sg_init_table(sw_ctx_tx->sg_encrypted_data, 1126 + ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data)); 1127 + sg_init_table(sw_ctx_tx->sg_plaintext_data, 1128 + ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data)); 1142 1129 1143 - sg_init_table(sw_ctx->sg_aead_in, 2); 1144 - sg_set_buf(&sw_ctx->sg_aead_in[0], sw_ctx->aad_space, 1145 - sizeof(sw_ctx->aad_space)); 1146 - sg_unmark_end(&sw_ctx->sg_aead_in[1]); 1147 - sg_chain(sw_ctx->sg_aead_in, 2, sw_ctx->sg_plaintext_data); 1148 - sg_init_table(sw_ctx->sg_aead_out, 2); 1149 - sg_set_buf(&sw_ctx->sg_aead_out[0], sw_ctx->aad_space, 1150 - sizeof(sw_ctx->aad_space)); 1151 - sg_unmark_end(&sw_ctx->sg_aead_out[1]); 1152 - sg_chain(sw_ctx->sg_aead_out, 2, sw_ctx->sg_encrypted_data); 1130 + sg_init_table(sw_ctx_tx->sg_aead_in, 2); 1131 + sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space, 1132 + sizeof(sw_ctx_tx->aad_space)); 1133 + sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]); 1134 + sg_chain(sw_ctx_tx->sg_aead_in, 2, 1135 + sw_ctx_tx->sg_plaintext_data); 1136 + sg_init_table(sw_ctx_tx->sg_aead_out, 2); 1137 + sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space, 1138 + sizeof(sw_ctx_tx->aad_space)); 1139 + sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]); 1140 + sg_chain(sw_ctx_tx->sg_aead_out, 2, 1141 + sw_ctx_tx->sg_encrypted_data); 1153 1142 } 1154 1143 1155 1144 if (!*aead) { ··· 1176 1161 if (rc) 1177 1162 goto free_aead; 1178 1163 1179 - if (!tx) { 1164 + if (sw_ctx_rx) { 1180 1165 /* Set up strparser */ 1181 1166 memset(&cb, 0, sizeof(cb)); 1182 1167 cb.rcv_msg = tls_queue; 1183 1168 cb.parse_msg = tls_read_size; 1184 1169 1185 - strp_init(&sw_ctx->strp, sk, &cb); 1170 + strp_init(&sw_ctx_rx->strp, sk, &cb); 1186 1171 1187 1172 write_lock_bh(&sk->sk_callback_lock); 1188 - sw_ctx->saved_data_ready = sk->sk_data_ready; 1173 + sw_ctx_rx->saved_data_ready = sk->sk_data_ready; 1189 1174 sk->sk_data_ready = tls_data_ready; 1190 1175 write_unlock_bh(&sk->sk_callback_lock); 1191 1176 1192 - sw_ctx->sk_poll = sk->sk_socket->ops->poll; 1177 + sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll; 1193 1178 1194 - strp_check_rcv(&sw_ctx->strp); 1179 + strp_check_rcv(&sw_ctx_rx->strp); 1195 1180 } 1196 1181 1197 1182 goto out; ··· 1203 1188 kfree(cctx->rec_seq); 1204 1189 cctx->rec_seq = NULL; 1205 1190 free_iv: 1206 - kfree(ctx->tx.iv); 1207 - ctx->tx.iv = NULL; 1191 + kfree(cctx->iv); 1192 + cctx->iv = NULL; 1208 1193 free_priv: 1209 - kfree(ctx->priv_ctx); 1210 - ctx->priv_ctx = NULL; 1194 + if (tx) { 1195 + kfree(ctx->priv_ctx_tx); 1196 + ctx->priv_ctx_tx = NULL; 1197 + } else { 1198 + kfree(ctx->priv_ctx_rx); 1199 + ctx->priv_ctx_rx = NULL; 1200 + } 1211 1201 out: 1212 1202 return rc; 1213 1203 }