Linux kernel mirror (for testing) git.kernel.org/pub/scm/linux/kernel/git/torvalds/linux.git
kernel os linux

Merge branch 'net-tls-fixes-for-nvme-over-tls'

Hannes Reinecke says:

====================
net/tls: fixes for NVMe-over-TLS

Here are some small fixes to get NVMe-over-TLS up and running.
The first set is just minor modifications to have MSG_EOR handled
for TLS, while the second set implements the ->read_sock() callback
for tls_sw.
The ->read_sock() callbacks return -EIO when encountering any TLS
Alert message, but as that's the default behaviour anyway I guess
we can get away with it.
====================

Applied on top of the tag in case Sagi gets convinced to pull it.

Link: https://lore.kernel.org/r/20230726191556.41714-1-hare@suse.de
Signed-off-by: Jakub Kicinski <kuba@kernel.org>

+146 -20
+2
net/tls/tls.h
··· 110 110 ssize_t tls_sw_splice_read(struct socket *sock, loff_t *ppos, 111 111 struct pipe_inode_info *pipe, 112 112 size_t len, unsigned int flags); 113 + int tls_sw_read_sock(struct sock *sk, read_descriptor_t *desc, 114 + sk_read_actor_t read_actor); 113 115 114 116 int tls_device_sendmsg(struct sock *sk, struct msghdr *msg, size_t size); 115 117 void tls_device_splice_eof(struct socket *sock);
+5 -1
net/tls/tls_device.c
··· 441 441 long timeo; 442 442 443 443 if (flags & 444 - ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | MSG_SPLICE_PAGES)) 444 + ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 445 + MSG_SPLICE_PAGES | MSG_EOR)) 445 446 return -EOPNOTSUPP; 447 + 448 + if ((flags & (MSG_MORE | MSG_EOR)) == (MSG_MORE | MSG_EOR)) 449 + return -EINVAL; 446 450 447 451 if (unlikely(sk->sk_err)) 448 452 return -sk->sk_err;
+2
net/tls/tls_main.c
··· 962 962 ops[TLS_BASE][TLS_SW ] = ops[TLS_BASE][TLS_BASE]; 963 963 ops[TLS_BASE][TLS_SW ].splice_read = tls_sw_splice_read; 964 964 ops[TLS_BASE][TLS_SW ].poll = tls_sk_poll; 965 + ops[TLS_BASE][TLS_SW ].read_sock = tls_sw_read_sock; 965 966 966 967 ops[TLS_SW ][TLS_SW ] = ops[TLS_SW ][TLS_BASE]; 967 968 ops[TLS_SW ][TLS_SW ].splice_read = tls_sw_splice_read; 968 969 ops[TLS_SW ][TLS_SW ].poll = tls_sk_poll; 970 + ops[TLS_SW ][TLS_SW ].read_sock = tls_sw_read_sock; 969 971 970 972 #ifdef CONFIG_TLS_DEVICE 971 973 ops[TLS_HW ][TLS_BASE] = ops[TLS_BASE][TLS_BASE];
+1 -2
net/tls/tls_strp.c
··· 369 369 370 370 static int tls_strp_read_copyin(struct tls_strparser *strp) 371 371 { 372 - struct socket *sock = strp->sk->sk_socket; 373 372 read_descriptor_t desc; 374 373 375 374 desc.arg.data = strp; ··· 376 377 desc.count = 1; /* give more than one skb per call */ 377 378 378 379 /* sk should be locked here, so okay to do read_sock */ 379 - sock->ops->read_sock(strp->sk, &desc, tls_strp_copyin); 380 + tcp_read_sock(strp->sk, &desc, tls_strp_copyin); 380 381 381 382 return desc.error; 382 383 }
+125 -17
net/tls/tls_sw.c
··· 984 984 int ret = 0; 985 985 int pending; 986 986 987 + if (!eor && (msg->msg_flags & MSG_EOR)) 988 + return -EINVAL; 989 + 987 990 if (unlikely(msg->msg_controllen)) { 988 991 ret = tls_process_cmsg(sk, msg, &record_type); 989 992 if (ret) { ··· 1196 1193 int ret; 1197 1194 1198 1195 if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL | 1199 - MSG_CMSG_COMPAT | MSG_SPLICE_PAGES | 1196 + MSG_CMSG_COMPAT | MSG_SPLICE_PAGES | MSG_EOR | 1200 1197 MSG_SENDPAGE_NOPOLICY)) 1201 1198 return -EOPNOTSUPP; 1202 1199 ··· 1848 1845 return sk_flush_backlog(sk); 1849 1846 } 1850 1847 1851 - static int tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx, 1852 - bool nonblock) 1848 + static int tls_rx_reader_acquire(struct sock *sk, struct tls_sw_context_rx *ctx, 1849 + bool nonblock) 1853 1850 { 1854 1851 long timeo; 1855 - int err; 1856 - 1857 - lock_sock(sk); 1858 1852 1859 1853 timeo = sock_rcvtimeo(sk, nonblock); 1860 1854 ··· 1865 1865 !READ_ONCE(ctx->reader_present), &wait); 1866 1866 remove_wait_queue(&ctx->wq, &wait); 1867 1867 1868 - if (timeo <= 0) { 1869 - err = -EAGAIN; 1870 - goto err_unlock; 1871 - } 1872 - if (signal_pending(current)) { 1873 - err = sock_intr_errno(timeo); 1874 - goto err_unlock; 1875 - } 1868 + if (timeo <= 0) 1869 + return -EAGAIN; 1870 + if (signal_pending(current)) 1871 + return sock_intr_errno(timeo); 1876 1872 } 1877 1873 1878 1874 WRITE_ONCE(ctx->reader_present, 1); 1879 1875 1880 1876 return 0; 1877 + } 1881 1878 1882 - err_unlock: 1883 - release_sock(sk); 1879 + static int tls_rx_reader_lock(struct sock *sk, struct tls_sw_context_rx *ctx, 1880 + bool nonblock) 1881 + { 1882 + int err; 1883 + 1884 + lock_sock(sk); 1885 + err = tls_rx_reader_acquire(sk, ctx, nonblock); 1886 + if (err) 1887 + release_sock(sk); 1884 1888 return err; 1885 1889 } 1886 1890 1887 - static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx) 1891 + static void tls_rx_reader_release(struct sock *sk, struct 
tls_sw_context_rx *ctx) 1888 1892 { 1889 1893 if (unlikely(ctx->reader_contended)) { 1890 1894 if (wq_has_sleeper(&ctx->wq)) ··· 1900 1896 } 1901 1897 1902 1898 WRITE_ONCE(ctx->reader_present, 0); 1899 + } 1900 + 1901 + static void tls_rx_reader_unlock(struct sock *sk, struct tls_sw_context_rx *ctx) 1902 + { 1903 + tls_rx_reader_release(sk, ctx); 1903 1904 release_sock(sk); 1904 1905 } 1905 1906 ··· 2200 2191 splice_requeue: 2201 2192 __skb_queue_head(&ctx->rx_list, skb); 2202 2193 goto splice_read_end; 2194 + } 2195 + 2196 + int tls_sw_read_sock(struct sock *sk, read_descriptor_t *desc, 2197 + sk_read_actor_t read_actor) 2198 + { 2199 + struct tls_context *tls_ctx = tls_get_ctx(sk); 2200 + struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx); 2201 + struct tls_prot_info *prot = &tls_ctx->prot_info; 2202 + struct strp_msg *rxm = NULL; 2203 + struct sk_buff *skb = NULL; 2204 + struct sk_psock *psock; 2205 + size_t flushed_at = 0; 2206 + bool released = true; 2207 + struct tls_msg *tlm; 2208 + ssize_t copied = 0; 2209 + ssize_t decrypted; 2210 + int err, used; 2211 + 2212 + psock = sk_psock_get(sk); 2213 + if (psock) { 2214 + sk_psock_put(sk, psock); 2215 + return -EINVAL; 2216 + } 2217 + err = tls_rx_reader_acquire(sk, ctx, true); 2218 + if (err < 0) 2219 + return err; 2220 + 2221 + /* If crypto failed the connection is broken */ 2222 + err = ctx->async_wait.err; 2223 + if (err) 2224 + goto read_sock_end; 2225 + 2226 + decrypted = 0; 2227 + do { 2228 + if (!skb_queue_empty(&ctx->rx_list)) { 2229 + skb = __skb_dequeue(&ctx->rx_list); 2230 + rxm = strp_msg(skb); 2231 + tlm = tls_msg(skb); 2232 + } else { 2233 + struct tls_decrypt_arg darg; 2234 + int to_decrypt; 2235 + 2236 + err = tls_rx_rec_wait(sk, NULL, true, released); 2237 + if (err <= 0) 2238 + goto read_sock_end; 2239 + 2240 + memset(&darg.inargs, 0, sizeof(darg.inargs)); 2241 + 2242 + rxm = strp_msg(tls_strp_msg(ctx)); 2243 + tlm = tls_msg(tls_strp_msg(ctx)); 2244 + 2245 + to_decrypt = rxm->full_len - 
prot->overhead_size; 2246 + 2247 + err = tls_rx_one_record(sk, NULL, &darg); 2248 + if (err < 0) { 2249 + tls_err_abort(sk, -EBADMSG); 2250 + goto read_sock_end; 2251 + } 2252 + 2253 + released = tls_read_flush_backlog(sk, prot, rxm->full_len, to_decrypt, 2254 + decrypted, &flushed_at); 2255 + skb = darg.skb; 2256 + decrypted += rxm->full_len; 2257 + 2258 + tls_rx_rec_done(ctx); 2259 + } 2260 + 2261 + /* read_sock does not support reading control messages */ 2262 + if (tlm->control != TLS_RECORD_TYPE_DATA) { 2263 + err = -EINVAL; 2264 + goto read_sock_requeue; 2265 + } 2266 + 2267 + used = read_actor(desc, skb, rxm->offset, rxm->full_len); 2268 + if (used <= 0) { 2269 + if (!copied) 2270 + err = used; 2271 + goto read_sock_requeue; 2272 + } 2273 + copied += used; 2274 + if (used < rxm->full_len) { 2275 + rxm->offset += used; 2276 + rxm->full_len -= used; 2277 + if (!desc->count) 2278 + goto read_sock_requeue; 2279 + } else { 2280 + consume_skb(skb); 2281 + if (!desc->count) 2282 + skb = NULL; 2283 + } 2284 + } while (skb); 2285 + 2286 + read_sock_end: 2287 + tls_rx_reader_release(sk, ctx); 2288 + return copied ? : err; 2289 + 2290 + read_sock_requeue: 2291 + __skb_queue_head(&ctx->rx_list, skb); 2292 + goto read_sock_end; 2203 2293 } 2204 2294 2205 2295 bool tls_sw_sock_is_readable(struct sock *sk)
+11
tools/testing/selftests/net/tls.c
··· 486 486 EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1); 487 487 } 488 488 489 + TEST_F(tls, msg_eor) 490 + { 491 + char const *test_str = "test_read"; 492 + int send_len = 10; 493 + char buf[10]; 494 + 495 + EXPECT_EQ(send(self->fd, test_str, send_len, MSG_EOR), send_len); 496 + EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len); 497 + EXPECT_EQ(memcmp(buf, test_str, send_len), 0); 498 + } 499 + 489 500 TEST_F(tls, sendmsg_single) 490 501 { 491 502 struct msghdr msg;