Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

tls: rx: read the input skb from ctx->recv_pkt

Callers always pass ctx->recv_pkt into decrypt_skb_update(),
and it propagates it to its callees. This may give someone
the false impression that those functions can accept any valid
skb containing a TLS record. That's not the case, the record
sequence number is read from the context, and they can only
take the next record coming out of the strp.

Let the functions get the skb from the context instead of
passing it in. This will also make it cleaner to return
a different skb than ctx->recv_pkt as the decrypted one
later on.

Since we're touching the definition of decrypt_skb_update()
use this as an opportunity to rename it.

Signed-off-by: Jakub Kicinski <kuba@kernel.org>
Signed-off-by: David S. Miller <davem@davemloft.net>

authored by

Jakub Kicinski and committed by
David S. Miller
541cc48b 8a958732

+42 -34
+8 -6
net/tls/tls.h
··· 118 118 119 119 int tls_process_cmsg(struct sock *sk, struct msghdr *msg, 120 120 unsigned char *record_type); 121 - int decrypt_skb(struct sock *sk, struct sk_buff *skb, 122 - struct scatterlist *sgout); 121 + int decrypt_skb(struct sock *sk, struct scatterlist *sgout); 123 122 124 123 int tls_sw_fallback_init(struct sock *sk, 125 124 struct tls_offload_context_tx *offload_ctx, ··· 131 132 return &scb->tls; 132 133 } 133 134 135 + static inline struct sk_buff *tls_strp_msg(struct tls_sw_context_rx *ctx) 136 + { 137 + return ctx->recv_pkt; 138 + } 139 + 134 140 #ifdef CONFIG_TLS_DEVICE 135 141 int tls_device_init(void); 136 142 void tls_device_cleanup(void); ··· 144 140 int tls_set_device_offload_rx(struct sock *sk, struct tls_context *ctx); 145 141 void tls_device_offload_cleanup_rx(struct sock *sk); 146 142 void tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq); 147 - int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx, 148 - struct sk_buff *skb, struct strp_msg *rxm); 143 + int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx); 149 144 #else 150 145 static inline int tls_device_init(void) { return 0; } 151 146 static inline void tls_device_cleanup(void) {} ··· 168 165 tls_device_rx_resync_new_rec(struct sock *sk, u32 rcd_len, u32 seq) {} 169 166 170 167 static inline int 171 - tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx, 172 - struct sk_buff *skb, struct strp_msg *rxm) 168 + tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx) 173 169 { 174 170 return 0; 175 171 }
+16 -9
net/tls/tls_device.c
··· 889 889 } 890 890 } 891 891 892 - static int tls_device_reencrypt(struct sock *sk, struct sk_buff *skb) 892 + static int 893 + tls_device_reencrypt(struct sock *sk, struct tls_sw_context_rx *sw_ctx) 893 894 { 894 - struct strp_msg *rxm = strp_msg(skb); 895 - int err = 0, offset = rxm->offset, copy, nsg, data_len, pos; 896 - struct sk_buff *skb_iter, *unused; 895 + int err = 0, offset, copy, nsg, data_len, pos; 896 + struct sk_buff *skb, *skb_iter, *unused; 897 897 struct scatterlist sg[1]; 898 + struct strp_msg *rxm; 898 899 char *orig_buf, *buf; 900 + 901 + skb = tls_strp_msg(sw_ctx); 902 + rxm = strp_msg(skb); 903 + offset = rxm->offset; 899 904 900 905 orig_buf = kmalloc(rxm->full_len + TLS_HEADER_SIZE + 901 906 TLS_CIPHER_AES_GCM_128_IV_SIZE, sk->sk_allocation); ··· 924 919 goto free_buf; 925 920 926 921 /* We are interested only in the decrypted data not the auth */ 927 - err = decrypt_skb(sk, skb, sg); 922 + err = decrypt_skb(sk, sg); 928 923 if (err != -EBADMSG) 929 924 goto free_buf; 930 925 else ··· 979 974 return err; 980 975 } 981 976 982 - int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx, 983 - struct sk_buff *skb, struct strp_msg *rxm) 977 + int tls_device_decrypted(struct sock *sk, struct tls_context *tls_ctx) 984 978 { 985 979 struct tls_offload_context_rx *ctx = tls_offload_ctx_rx(tls_ctx); 980 + struct tls_sw_context_rx *sw_ctx = tls_sw_ctx_rx(tls_ctx); 981 + struct sk_buff *skb = tls_strp_msg(sw_ctx); 982 + struct strp_msg *rxm = strp_msg(skb); 986 983 int is_decrypted = skb->decrypted; 987 984 int is_encrypted = !is_decrypted; 988 985 struct sk_buff *skb_iter; ··· 1007 1000 * likely have initial fragments decrypted, and final ones not 1008 1001 * decrypted. We need to reencrypt that single SKB. 1009 1002 */ 1010 - return tls_device_reencrypt(sk, skb); 1003 + return tls_device_reencrypt(sk, sw_ctx); 1011 1004 } 1012 1005 1013 1006 /* Return immediately if the record is either entirely plaintext or ··· 1024 1017 } 1025 1018 1026 1019 ctx->resync_nh_reset = 1; 1027 - return tls_device_reencrypt(sk, skb); 1020 + return tls_device_reencrypt(sk, sw_ctx); 1028 1021 } 1029 1022 1030 1023 static void tls_device_attach(struct tls_context *ctx, struct sock *sk,
+18 -19
net/tls/tls_sw.c
··· 1421 1421 * NULL, then the decryption happens inside skb buffers itself, i.e. 1422 1422 * zero-copy gets disabled and 'darg->zc' is updated. 1423 1423 */ 1424 - static int tls_decrypt_sg(struct sock *sk, struct sk_buff *skb, 1425 - struct iov_iter *out_iov, 1424 + static int tls_decrypt_sg(struct sock *sk, struct iov_iter *out_iov, 1426 1425 struct scatterlist *out_sg, 1427 1426 struct tls_decrypt_arg *darg) 1428 1427 { ··· 1429 1430 struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1430 1431 struct tls_prot_info *prot = &tls_ctx->prot_info; 1431 1432 int n_sgin, n_sgout, aead_size, err, pages = 0; 1433 + struct sk_buff *skb = tls_strp_msg(ctx); 1432 1434 struct strp_msg *rxm = strp_msg(skb); 1433 1435 struct tls_msg *tlm = tls_msg(skb); 1434 1436 struct aead_request *aead_req; ··· 1567 1567 1568 1568 static int 1569 1569 tls_decrypt_device(struct sock *sk, struct tls_context *tls_ctx, 1570 - struct sk_buff *skb, struct tls_decrypt_arg *darg) 1570 + struct tls_decrypt_arg *darg) 1571 1571 { 1572 1572 int err; 1573 1573 1574 1574 if (tls_ctx->rx_conf != TLS_HW) 1575 1575 return 0; 1576 1576 1577 - err = tls_device_decrypted(sk, tls_ctx, skb, strp_msg(skb)); 1577 + err = tls_device_decrypted(sk, tls_ctx); 1578 1578 if (err <= 0) 1579 1579 return err; 1580 1580 ··· 1583 1583 return 1; 1584 1584 } 1585 1585 1586 - static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb, 1587 - struct iov_iter *dest, 1588 - struct tls_decrypt_arg *darg) 1586 + static int tls_rx_one_record(struct sock *sk, struct iov_iter *dest, 1587 + struct tls_decrypt_arg *darg) 1589 1588 { 1590 1589 struct tls_context *tls_ctx = tls_get_ctx(sk); 1590 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 1591 1591 struct tls_prot_info *prot = &tls_ctx->prot_info; 1592 - struct strp_msg *rxm = strp_msg(skb); 1592 + struct strp_msg *rxm; 1593 1593 int pad, err; 1594 1594 1595 - err = tls_decrypt_device(sk, tls_ctx, skb, darg); 1595 + err = tls_decrypt_device(sk, tls_ctx, darg); 1596 1596 if (err < 0) 1597 1597 return err; 1598 1598 if (err) 1599 1599 goto decrypt_done; 1600 1600 1601 - err = tls_decrypt_sg(sk, skb, dest, NULL, darg); 1601 + err = tls_decrypt_sg(sk, dest, NULL, darg); 1602 1602 if (err < 0) { 1603 1603 if (err == -EBADMSG) 1604 1604 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTERROR); ··· 1613 1613 if (!darg->tail) 1614 1614 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSRXNOPADVIOL); 1615 1615 TLS_INC_STATS(sock_net(sk), LINUX_MIB_TLSDECRYPTRETRY); 1616 - return decrypt_skb_update(sk, skb, dest, darg); 1616 + return tls_rx_one_record(sk, dest, darg); 1617 1617 } 1618 1618 1619 1619 decrypt_done: 1620 - pad = tls_padding_length(prot, skb, darg); 1620 + pad = tls_padding_length(prot, ctx->recv_pkt, darg); 1621 1621 if (pad < 0) 1622 1622 return pad; 1623 1623 1624 + rxm = strp_msg(ctx->recv_pkt); 1624 1625 rxm->full_len -= pad; 1625 1626 rxm->offset += prot->prepend_size; 1626 1627 rxm->full_len -= prot->overhead_size; ··· 1631 1630 return 0; 1632 1631 } 1633 1632 1634 - int decrypt_skb(struct sock *sk, struct sk_buff *skb, 1635 - struct scatterlist *sgout) 1633 + int decrypt_skb(struct sock *sk, struct scatterlist *sgout) 1636 1634 { 1637 1635 struct tls_decrypt_arg darg = { .zc = true, }; 1638 1636 1639 - return tls_decrypt_sg(sk, skb, NULL, sgout, &darg); 1637 + return tls_decrypt_sg(sk, NULL, sgout, &darg); 1640 1638 } 1641 1639 1642 1640 static int tls_record_content_type(struct msghdr *msg, struct tls_msg *tlm, ··· 1905 1905 else 1906 1906 darg.async = false; 1907 1907 1908 - err = decrypt_skb_update(sk, skb, &msg->msg_iter, &darg); 1908 + err = tls_rx_one_record(sk, &msg->msg_iter, &darg); 1909 1909 if (err < 0) { 1910 1910 tls_err_abort(sk, -EBADMSG); 1911 1911 goto recv_end; ··· 2058 2058 if (err <= 0) 2059 2059 goto splice_read_end; 2060 2060 2061 - skb = ctx->recv_pkt; 2062 - 2063 - err = decrypt_skb_update(sk, skb, NULL, &darg); 2061 + err = tls_rx_one_record(sk, NULL, &darg); 2064 2062 if (err < 0) { 2065 2063 tls_err_abort(sk, -EBADMSG); 2066 2064 goto splice_read_end; 2067 2065 } 2068 2066 2067 + skb = ctx->recv_pkt; 2069 2068 tls_rx_rec_done(ctx); 2070 2069 } 2071 2070