Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

Merge branch 'in-kernel-support-for-the-tls-alert-protocol'

Chuck Lever says:

====================
In-kernel support for the TLS Alert protocol

IMO the kernel doesn't need user space (ie, tlshd) to handle the TLS
Alert protocol. Instead, a set of small helper functions can be used
to handle sending and receiving TLS Alerts for in-kernel TLS
consumers.
====================

Merged on top of a tag in case it's needed in the NFS tree.

Link: https://lore.kernel.org/r/169047923706.5241.1181144206068116926.stgit@oracle-102.nfsv4bat.org
Signed-off-by: Jakub Kicinski <kuba@kernel.org>

+429 -44
+1
drivers/net/ethernet/chelsio/inline_crypto/chtls/chtls.h
··· 22 22 #include <crypto/internal/hash.h> 23 23 #include <linux/tls.h> 24 24 #include <net/tls.h> 25 + #include <net/tls_prot.h> 25 26 #include <net/tls_toe.h> 26 27 27 28 #include "t4fw_api.h"
+5
include/net/handshake.h
··· 40 40 int tls_server_hello_psk(const struct tls_handshake_args *args, gfp_t flags); 41 41 42 42 bool tls_handshake_cancel(struct sock *sk); 43 + void tls_handshake_close(struct socket *sock); 44 + 45 + u8 tls_get_record_type(const struct sock *sk, const struct cmsghdr *msg); 46 + void tls_alert_recv(const struct sock *sk, const struct msghdr *msg, 47 + u8 *level, u8 *description); 43 48 44 49 #endif /* _NET_HANDSHAKE_H */
-4
include/net/tls.h
··· 69 69 70 70 #define TLS_CRYPTO_INFO_READY(info) ((info)->cipher_type) 71 71 72 - #define TLS_RECORD_TYPE_ALERT 0x15 73 - #define TLS_RECORD_TYPE_HANDSHAKE 0x16 74 - #define TLS_RECORD_TYPE_DATA 0x17 75 - 76 72 #define TLS_AAD_SPACE_SIZE 13 77 73 78 74 #define MAX_IV_SIZE 16
+68
include/net/tls_prot.h
··· 1 + /* SPDX-License-Identifier: GPL-2.0 OR BSD-3-Clause */ 2 + /* 3 + * Copyright (c) 2023, Oracle and/or its affiliates. 4 + * 5 + * TLS Protocol definitions 6 + * 7 + * From https://www.iana.org/assignments/tls-parameters/tls-parameters.xhtml 8 + */ 9 + 10 + #ifndef _TLS_PROT_H 11 + #define _TLS_PROT_H 12 + 13 + /* 14 + * TLS Record protocol: ContentType 15 + */ 16 + enum { 17 + TLS_RECORD_TYPE_CHANGE_CIPHER_SPEC = 20, 18 + TLS_RECORD_TYPE_ALERT = 21, 19 + TLS_RECORD_TYPE_HANDSHAKE = 22, 20 + TLS_RECORD_TYPE_DATA = 23, 21 + TLS_RECORD_TYPE_HEARTBEAT = 24, 22 + TLS_RECORD_TYPE_TLS12_CID = 25, 23 + TLS_RECORD_TYPE_ACK = 26, 24 + }; 25 + 26 + /* 27 + * TLS Alert protocol: AlertLevel 28 + */ 29 + enum { 30 + TLS_ALERT_LEVEL_WARNING = 1, 31 + TLS_ALERT_LEVEL_FATAL = 2, 32 + }; 33 + 34 + /* 35 + * TLS Alert protocol: AlertDescription 36 + */ 37 + enum { 38 + TLS_ALERT_DESC_CLOSE_NOTIFY = 0, 39 + TLS_ALERT_DESC_UNEXPECTED_MESSAGE = 10, 40 + TLS_ALERT_DESC_BAD_RECORD_MAC = 20, 41 + TLS_ALERT_DESC_RECORD_OVERFLOW = 22, 42 + TLS_ALERT_DESC_HANDSHAKE_FAILURE = 40, 43 + TLS_ALERT_DESC_BAD_CERTIFICATE = 42, 44 + TLS_ALERT_DESC_UNSUPPORTED_CERTIFICATE = 43, 45 + TLS_ALERT_DESC_CERTIFICATE_REVOKED = 44, 46 + TLS_ALERT_DESC_CERTIFICATE_EXPIRED = 45, 47 + TLS_ALERT_DESC_CERTIFICATE_UNKNOWN = 46, 48 + TLS_ALERT_DESC_ILLEGAL_PARAMETER = 47, 49 + TLS_ALERT_DESC_UNKNOWN_CA = 48, 50 + TLS_ALERT_DESC_ACCESS_DENIED = 49, 51 + TLS_ALERT_DESC_DECODE_ERROR = 50, 52 + TLS_ALERT_DESC_DECRYPT_ERROR = 51, 53 + TLS_ALERT_DESC_TOO_MANY_CIDS_REQUESTED = 52, 54 + TLS_ALERT_DESC_PROTOCOL_VERSION = 70, 55 + TLS_ALERT_DESC_INSUFFICIENT_SECURITY = 71, 56 + TLS_ALERT_DESC_INTERNAL_ERROR = 80, 57 + TLS_ALERT_DESC_INAPPROPRIATE_FALLBACK = 86, 58 + TLS_ALERT_DESC_USER_CANCELED = 90, 59 + TLS_ALERT_DESC_MISSING_EXTENSION = 109, 60 + TLS_ALERT_DESC_UNSUPPORTED_EXTENSION = 110, 61 + TLS_ALERT_DESC_UNRECOGNIZED_NAME = 112, 62 + TLS_ALERT_DESC_BAD_CERTIFICATE_STATUS_RESPONSE = 113, 63 + TLS_ALERT_DESC_UNKNOWN_PSK_IDENTITY = 115, 64 + TLS_ALERT_DESC_CERTIFICATE_REQUIRED = 116, 65 + TLS_ALERT_DESC_NO_APPLICATION_PROTOCOL = 120, 66 + }; 67 + 68 + #endif /* _TLS_PROT_H */
+160
include/trace/events/handshake.h
··· 6 6 #define _TRACE_HANDSHAKE_H 7 7 8 8 #include <linux/net.h> 9 + #include <net/tls_prot.h> 9 10 #include <linux/tracepoint.h> 11 + #include <trace/events/net_probe_common.h> 12 + 13 + #define TLS_RECORD_TYPE_LIST \ 14 + record_type(CHANGE_CIPHER_SPEC) \ 15 + record_type(ALERT) \ 16 + record_type(HANDSHAKE) \ 17 + record_type(DATA) \ 18 + record_type(HEARTBEAT) \ 19 + record_type(TLS12_CID) \ 20 + record_type_end(ACK) 21 + 22 + #undef record_type 23 + #undef record_type_end 24 + #define record_type(x) TRACE_DEFINE_ENUM(TLS_RECORD_TYPE_##x); 25 + #define record_type_end(x) TRACE_DEFINE_ENUM(TLS_RECORD_TYPE_##x); 26 + 27 + TLS_RECORD_TYPE_LIST 28 + 29 + #undef record_type 30 + #undef record_type_end 31 + #define record_type(x) { TLS_RECORD_TYPE_##x, #x }, 32 + #define record_type_end(x) { TLS_RECORD_TYPE_##x, #x } 33 + 34 + #define show_tls_content_type(type) \ 35 + __print_symbolic(type, TLS_RECORD_TYPE_LIST) 36 + 37 + TRACE_DEFINE_ENUM(TLS_ALERT_LEVEL_WARNING); 38 + TRACE_DEFINE_ENUM(TLS_ALERT_LEVEL_FATAL); 39 + 40 + #define show_tls_alert_level(level) \ 41 + __print_symbolic(level, \ 42 + { TLS_ALERT_LEVEL_WARNING, "Warning" }, \ 43 + { TLS_ALERT_LEVEL_FATAL, "Fatal" }) 44 + 45 + #define TLS_ALERT_DESCRIPTION_LIST \ 46 + alert_description(CLOSE_NOTIFY) \ 47 + alert_description(UNEXPECTED_MESSAGE) \ 48 + alert_description(BAD_RECORD_MAC) \ 49 + alert_description(RECORD_OVERFLOW) \ 50 + alert_description(HANDSHAKE_FAILURE) \ 51 + alert_description(BAD_CERTIFICATE) \ 52 + alert_description(UNSUPPORTED_CERTIFICATE) \ 53 + alert_description(CERTIFICATE_REVOKED) \ 54 + alert_description(CERTIFICATE_EXPIRED) \ 55 + alert_description(CERTIFICATE_UNKNOWN) \ 56 + alert_description(ILLEGAL_PARAMETER) \ 57 + alert_description(UNKNOWN_CA) \ 58 + alert_description(ACCESS_DENIED) \ 59 + alert_description(DECODE_ERROR) \ 60 + alert_description(DECRYPT_ERROR) \ 61 + alert_description(TOO_MANY_CIDS_REQUESTED) \ 62 + alert_description(PROTOCOL_VERSION) \ 63 + alert_description(INSUFFICIENT_SECURITY) \ 64 + alert_description(INTERNAL_ERROR) \ 65 + alert_description(INAPPROPRIATE_FALLBACK) \ 66 + alert_description(USER_CANCELED) \ 67 + alert_description(MISSING_EXTENSION) \ 68 + alert_description(UNSUPPORTED_EXTENSION) \ 69 + alert_description(UNRECOGNIZED_NAME) \ 70 + alert_description(BAD_CERTIFICATE_STATUS_RESPONSE) \ 71 + alert_description(UNKNOWN_PSK_IDENTITY) \ 72 + alert_description(CERTIFICATE_REQUIRED) \ 73 + alert_description_end(NO_APPLICATION_PROTOCOL) 74 + 75 + #undef alert_description 76 + #undef alert_description_end 77 + #define alert_description(x) TRACE_DEFINE_ENUM(TLS_ALERT_DESC_##x); 78 + #define alert_description_end(x) TRACE_DEFINE_ENUM(TLS_ALERT_DESC_##x); 79 + 80 + TLS_ALERT_DESCRIPTION_LIST 81 + 82 + #undef alert_description 83 + #undef alert_description_end 84 + #define alert_description(x) { TLS_ALERT_DESC_##x, #x }, 85 + #define alert_description_end(x) { TLS_ALERT_DESC_##x, #x } 86 + 87 + #define show_tls_alert_description(desc) \ 88 + __print_symbolic(desc, TLS_ALERT_DESCRIPTION_LIST) 10 89 11 90 DECLARE_EVENT_CLASS(handshake_event_class, 12 91 TP_PROTO( ··· 185 106 ), \ 186 107 TP_ARGS(net, req, sk, err)) 187 108 109 + DECLARE_EVENT_CLASS(handshake_alert_class, 110 + TP_PROTO( 111 + const struct sock *sk, 112 + unsigned char level, 113 + unsigned char description 114 + ), 115 + TP_ARGS(sk, level, description), 116 + TP_STRUCT__entry( 117 + /* sockaddr_in6 is always bigger than sockaddr_in */ 118 + __array(__u8, saddr, sizeof(struct sockaddr_in6)) 119 + __array(__u8, daddr, sizeof(struct sockaddr_in6)) 120 + __field(unsigned int, netns_ino) 121 + __field(unsigned long, level) 122 + __field(unsigned long, description) 123 + ), 124 + TP_fast_assign( 125 + const struct inet_sock *inet = inet_sk(sk); 126 + 127 + memset(__entry->saddr, 0, sizeof(struct sockaddr_in6)); 128 + memset(__entry->daddr, 0, sizeof(struct sockaddr_in6)); 129 + TP_STORE_ADDR_PORTS(__entry, inet, sk); 130 + 131 + __entry->netns_ino = sock_net(sk)->ns.inum; 132 + __entry->level = level; 133 + __entry->description = description; 134 + ), 135 + TP_printk("src=%pISpc dest=%pISpc %s: %s", 136 + __entry->saddr, __entry->daddr, 137 + show_tls_alert_level(__entry->level), 138 + show_tls_alert_description(__entry->description) 139 + ) 140 + ); 141 + #define DEFINE_HANDSHAKE_ALERT(name) \ 142 + DEFINE_EVENT(handshake_alert_class, name, \ 143 + TP_PROTO( \ 144 + const struct sock *sk, \ 145 + unsigned char level, \ 146 + unsigned char description \ 147 + ), \ 148 + TP_ARGS(sk, level, description)) 149 + 188 150 189 151 /* 190 152 * Request lifetime events ··· 273 153 DEFINE_HANDSHAKE_ERROR(handshake_cmd_accept_err); 274 154 DEFINE_HANDSHAKE_FD_EVENT(handshake_cmd_done); 275 155 DEFINE_HANDSHAKE_ERROR(handshake_cmd_done_err); 156 + 157 + /* 158 + * TLS Record events 159 + */ 160 + 161 + TRACE_EVENT(tls_contenttype, 162 + TP_PROTO( 163 + const struct sock *sk, 164 + unsigned char type 165 + ), 166 + TP_ARGS(sk, type), 167 + TP_STRUCT__entry( 168 + /* sockaddr_in6 is always bigger than sockaddr_in */ 169 + __array(__u8, saddr, sizeof(struct sockaddr_in6)) 170 + __array(__u8, daddr, sizeof(struct sockaddr_in6)) 171 + __field(unsigned int, netns_ino) 172 + __field(unsigned long, type) 173 + ), 174 + TP_fast_assign( 175 + const struct inet_sock *inet = inet_sk(sk); 176 + 177 + memset(__entry->saddr, 0, sizeof(struct sockaddr_in6)); 178 + memset(__entry->daddr, 0, sizeof(struct sockaddr_in6)); 179 + TP_STORE_ADDR_PORTS(__entry, inet, sk); 180 + 181 + __entry->netns_ino = sock_net(sk)->ns.inum; 182 + __entry->type = type; 183 + ), 184 + TP_printk("src=%pISpc dest=%pISpc %s", 185 + __entry->saddr, __entry->daddr, 186 + show_tls_content_type(__entry->type) 187 + ) 188 + ); 189 + 190 + /* 191 + * TLS Alert events 192 + */ 193 + 194 + DEFINE_HANDSHAKE_ALERT(tls_alert_send); 195 + DEFINE_HANDSHAKE_ALERT(tls_alert_recv); 276 196 277 197 #endif /* _TRACE_HANDSHAKE_H */ 278 198
+1 -1
net/handshake/Makefile
··· 8 8 # 9 9 10 10 obj-y += handshake.o 11 - handshake-y := genl.o netlink.o request.o tlshd.o trace.o 11 + handshake-y := alert.o genl.o netlink.o request.o tlshd.o trace.o 12 12 13 13 obj-$(CONFIG_NET_HANDSHAKE_KUNIT_TEST) += handshake-test.o
+110
net/handshake/alert.c
··· 1 + // SPDX-License-Identifier: GPL-2.0-only 2 + /* 3 + * Handle the TLS Alert protocol 4 + * 5 + * Author: Chuck Lever <chuck.lever@oracle.com> 6 + * 7 + * Copyright (c) 2023, Oracle and/or its affiliates. 8 + */ 9 + 10 + #include <linux/types.h> 11 + #include <linux/socket.h> 12 + #include <linux/kernel.h> 13 + #include <linux/module.h> 14 + #include <linux/skbuff.h> 15 + #include <linux/inet.h> 16 + 17 + #include <net/sock.h> 18 + #include <net/handshake.h> 19 + #include <net/tls.h> 20 + #include <net/tls_prot.h> 21 + 22 + #include "handshake.h" 23 + 24 + #include <trace/events/handshake.h> 25 + 26 + /** 27 + * tls_alert_send - send a TLS Alert on a kTLS socket 28 + * @sock: open kTLS socket to send on 29 + * @level: TLS Alert level 30 + * @description: TLS Alert description 31 + * 32 + * Returns zero on success or a negative errno. 33 + */ 34 + int tls_alert_send(struct socket *sock, u8 level, u8 description) 35 + { 36 + u8 record_type = TLS_RECORD_TYPE_ALERT; 37 + u8 buf[CMSG_SPACE(sizeof(record_type))]; 38 + struct msghdr msg = { 0 }; 39 + struct cmsghdr *cmsg; 40 + struct kvec iov; 41 + u8 alert[2]; 42 + int ret; 43 + 44 + trace_tls_alert_send(sock->sk, level, description); 45 + 46 + alert[0] = level; 47 + alert[1] = description; 48 + iov.iov_base = alert; 49 + iov.iov_len = sizeof(alert); 50 + 51 + memset(buf, 0, sizeof(buf)); 52 + msg.msg_control = buf; 53 + msg.msg_controllen = sizeof(buf); 54 + msg.msg_flags = MSG_DONTWAIT; 55 + 56 + cmsg = CMSG_FIRSTHDR(&msg); 57 + cmsg->cmsg_level = SOL_TLS; 58 + cmsg->cmsg_type = TLS_SET_RECORD_TYPE; 59 + cmsg->cmsg_len = CMSG_LEN(sizeof(record_type)); 60 + memcpy(CMSG_DATA(cmsg), &record_type, sizeof(record_type)); 61 + 62 + iov_iter_kvec(&msg.msg_iter, ITER_SOURCE, &iov, 1, iov.iov_len); 63 + ret = sock_sendmsg(sock, &msg); 64 + return ret < 0 ? ret : 0; 65 + } 66 + 67 + /** 68 + * tls_get_record_type - Look for TLS RECORD_TYPE information 69 + * @sk: socket (for IP address information) 70 + * @cmsg: incoming message to be parsed 71 + * 72 + * Returns zero or a TLS_RECORD_TYPE value. 73 + */ 74 + u8 tls_get_record_type(const struct sock *sk, const struct cmsghdr *cmsg) 75 + { 76 + u8 record_type; 77 + 78 + if (cmsg->cmsg_level != SOL_TLS) 79 + return 0; 80 + if (cmsg->cmsg_type != TLS_GET_RECORD_TYPE) 81 + return 0; 82 + 83 + record_type = *((u8 *)CMSG_DATA(cmsg)); 84 + trace_tls_contenttype(sk, record_type); 85 + return record_type; 86 + } 87 + EXPORT_SYMBOL(tls_get_record_type); 88 + 89 + /** 90 + * tls_alert_recv - Parse TLS Alert messages 91 + * @sk: socket (for IP address information) 92 + * @msg: incoming message to be parsed 93 + * @level: OUT - TLS AlertLevel value 94 + * @description: OUT - TLS AlertDescription value 95 + * 96 + */ 97 + void tls_alert_recv(const struct sock *sk, const struct msghdr *msg, 98 + u8 *level, u8 *description) 99 + { 100 + const struct kvec *iov; 101 + u8 *data; 102 + 103 + iov = msg->msg_iter.kvec; 104 + data = iov->iov_base; 105 + *level = data[0]; 106 + *description = data[1]; 107 + 108 + trace_tls_alert_recv(sk, *level, *description); 109 + } 110 + EXPORT_SYMBOL(tls_alert_recv);
+6
net/handshake/handshake.h
··· 41 41 42 42 enum hr_flags_bits { 43 43 HANDSHAKE_F_REQ_COMPLETED, 44 + HANDSHAKE_F_REQ_SESSION, 44 45 }; 46 + 47 + struct genl_info; 45 48 46 49 /* Invariants for all handshake requests for one transport layer 47 50 * security protocol ··· 65 62 enum hp_flags_bits { 66 63 HANDSHAKE_F_PROTO_NOTIFY, 67 64 }; 65 + 66 + /* alert.c */ 67 + int tls_alert_send(struct socket *sock, u8 level, u8 description); 68 68 69 69 /* netlink.c */ 70 70 int handshake_genl_notify(struct net *net, const struct handshake_proto *proto,
+23
net/handshake/tlshd.c
··· 18 18 #include <net/sock.h> 19 19 #include <net/handshake.h> 20 20 #include <net/genetlink.h> 21 + #include <net/tls_prot.h> 21 22 22 23 #include <uapi/linux/keyctl.h> 23 24 #include <uapi/linux/handshake.h> ··· 100 99 treq->th_peerid[0] = TLS_NO_PEERID; 101 100 if (info) 102 101 tls_handshake_remote_peerids(treq, info); 102 + 103 + if (!status) 104 + set_bit(HANDSHAKE_F_REQ_SESSION, &req->hr_flags); 103 105 104 106 treq->th_consumer_done(treq->th_consumer_data, -status, 105 107 treq->th_peerid[0]); ··· 428 424 return handshake_req_cancel(sk); 429 425 } 430 426 EXPORT_SYMBOL(tls_handshake_cancel); 427 + 428 + /** 429 + * tls_handshake_close - send a Closure alert 430 + * @sock: an open socket 431 + * 432 + */ 433 + void tls_handshake_close(struct socket *sock) 434 + { 435 + struct handshake_req *req; 436 + 437 + req = handshake_req_hash_lookup(sock->sk); 438 + if (!req) 439 + return; 440 + if (!test_and_clear_bit(HANDSHAKE_F_REQ_SESSION, &req->hr_flags)) 441 + return; 442 + tls_alert_send(sock, TLS_ALERT_LEVEL_WARNING, 443 + TLS_ALERT_DESC_CLOSE_NOTIFY); 444 + } 445 + EXPORT_SYMBOL(tls_handshake_close);
+2
net/handshake/trace.c
··· 8 8 */ 9 9 10 10 #include <linux/types.h> 11 + #include <linux/ipv6.h> 11 12 12 13 #include <net/sock.h> 14 + #include <net/inet_sock.h> 13 15 #include <net/netlink.h> 14 16 #include <net/genetlink.h> 15 17
+27 -21
net/sunrpc/svcsock.c
··· 43 43 #include <net/udp.h> 44 44 #include <net/tcp.h> 45 45 #include <net/tcp_states.h> 46 - #include <net/tls.h> 46 + #include <net/tls_prot.h> 47 47 #include <net/handshake.h> 48 48 #include <linux/uaccess.h> 49 49 #include <linux/highmem.h> ··· 226 226 } 227 227 228 228 static int 229 - svc_tcp_sock_process_cmsg(struct svc_sock *svsk, struct msghdr *msg, 229 + svc_tcp_sock_process_cmsg(struct socket *sock, struct msghdr *msg, 230 230 struct cmsghdr *cmsg, int ret) 231 231 { 232 - if (cmsg->cmsg_level == SOL_TLS && 233 - cmsg->cmsg_type == TLS_GET_RECORD_TYPE) { 234 - u8 content_type = *((u8 *)CMSG_DATA(cmsg)); 232 + u8 content_type = tls_get_record_type(sock->sk, cmsg); 233 + u8 level, description; 235 234 236 - switch (content_type) { 237 - case TLS_RECORD_TYPE_DATA: 238 - /* TLS sets EOR at the end of each application data 239 - * record, even though there might be more frames 240 - * waiting to be decrypted. 241 - */ 242 - msg->msg_flags &= ~MSG_EOR; 243 - break; 244 - case TLS_RECORD_TYPE_ALERT: 245 - ret = -ENOTCONN; 246 - break; 247 - default: 248 - ret = -EAGAIN; 249 - } 235 + switch (content_type) { 236 + case 0: 237 + break; 238 + case TLS_RECORD_TYPE_DATA: 239 + /* TLS sets EOR at the end of each application data 240 + * record, even though there might be more frames 241 + * waiting to be decrypted. 242 + */ 243 + msg->msg_flags &= ~MSG_EOR; 244 + break; 245 + case TLS_RECORD_TYPE_ALERT: 246 + tls_alert_recv(sock->sk, msg, &level, &description); 247 + ret = (level == TLS_ALERT_LEVEL_FATAL) ? 248 + -ENOTCONN : -EAGAIN; 249 + break; 250 + default: 251 + /* discard this record type */ 252 + ret = -EAGAIN; 250 253 } 251 254 return ret; 252 255 } ··· 261 258 struct cmsghdr cmsg; 262 259 u8 buf[CMSG_SPACE(sizeof(u8))]; 263 260 } u; 261 + struct socket *sock = svsk->sk_sock; 264 262 int ret; 265 263 266 264 msg->msg_control = &u; 267 265 msg->msg_controllen = sizeof(u); 268 - ret = sock_recvmsg(svsk->sk_sock, msg, MSG_DONTWAIT); 266 + ret = sock_recvmsg(sock, msg, MSG_DONTWAIT); 269 267 if (unlikely(msg->msg_controllen != sizeof(u))) 270 - ret = svc_tcp_sock_process_cmsg(svsk, msg, &u.cmsg, ret); 268 + ret = svc_tcp_sock_process_cmsg(sock, msg, &u.cmsg, ret); 271 269 return ret; 272 270 } 273 271 ··· 1624 1620 static void svc_tcp_sock_detach(struct svc_xprt *xprt) 1625 1621 { 1626 1622 struct svc_sock *svsk = container_of(xprt, struct svc_sock, sk_xprt); 1623 + 1624 + tls_handshake_close(svsk->sk_sock); 1627 1625 1628 1626 svc_sock_detach(xprt); 1629 1627
+25 -18
net/sunrpc/xprtsock.c
··· 47 47 #include <net/checksum.h> 48 48 #include <net/udp.h> 49 49 #include <net/tcp.h> 50 - #include <net/tls.h> 50 + #include <net/tls_prot.h> 51 51 #include <net/handshake.h> 52 52 53 53 #include <linux/bvec.h> ··· 360 360 xs_sock_process_cmsg(struct socket *sock, struct msghdr *msg, 361 361 struct cmsghdr *cmsg, int ret) 362 362 { 363 - if (cmsg->cmsg_level == SOL_TLS && 364 - cmsg->cmsg_type == TLS_GET_RECORD_TYPE) { 365 - u8 content_type = *((u8 *)CMSG_DATA(cmsg)); 363 + u8 content_type = tls_get_record_type(sock->sk, cmsg); 364 + u8 level, description; 366 365 367 - switch (content_type) { 368 - case TLS_RECORD_TYPE_DATA: 369 - /* TLS sets EOR at the end of each application data 370 - * record, even though there might be more frames 371 - * waiting to be decrypted. 372 - */ 373 - msg->msg_flags &= ~MSG_EOR; 374 - break; 375 - case TLS_RECORD_TYPE_ALERT: 376 - ret = -ENOTCONN; 377 - break; 378 - default: 379 - ret = -EAGAIN; 380 - } 366 + switch (content_type) { 367 + case 0: 368 + break; 369 + case TLS_RECORD_TYPE_DATA: 370 + /* TLS sets EOR at the end of each application data 371 + * record, even though there might be more frames 372 + * waiting to be decrypted. 373 + */ 374 + msg->msg_flags &= ~MSG_EOR; 375 + break; 376 + case TLS_RECORD_TYPE_ALERT: 377 + tls_alert_recv(sock->sk, msg, &level, &description); 378 + ret = (level == TLS_ALERT_LEVEL_FATAL) ? 379 + -EACCES : -EAGAIN; 380 + break; 381 + default: 382 + /* discard this record type */ 383 + ret = -EAGAIN; 381 384 } 382 385 return ret; 383 386 } ··· 780 777 } 781 778 if (ret == -ESHUTDOWN) 782 779 kernel_sock_shutdown(transport->sock, SHUT_RDWR); 780 + else if (ret == -EACCES) 781 + xprt_wake_pending_tasks(&transport->xprt, -EACCES); 783 782 else 784 783 xs_poll_check_readable(transport); 785 784 out: ··· 1297 1292 1298 1293 dprintk("RPC: xs_close xprt %p\n", xprt); 1299 1294 1295 + if (transport->sock) 1296 + tls_handshake_close(transport->sock); 1300 1297 xs_reset_transport(transport); 1301 1298 xprt->reestablish_timeout = 0; 1302 1299 }
+1
net/tls/tls.h
··· 39 39 #include <linux/types.h> 40 40 #include <linux/skmsg.h> 41 41 #include <net/tls.h> 42 + #include <net/tls_prot.h> 42 43 43 44 #define TLS_PAGE_ORDER (min_t(unsigned int, PAGE_ALLOC_COSTLY_ORDER, \ 44 45 TLS_MAX_PAYLOAD_SIZE >> PAGE_SHIFT))