OSDN Git Service

Merge tag 'devicetree-fixes-for-5.17-1' of git://git.kernel.org/pub/scm/linux/kernel...
[uclinux-h8/linux.git] / tools / testing / selftests / net / tls.c
1 // SPDX-License-Identifier: GPL-2.0
2
3 #define _GNU_SOURCE
4
5 #include <arpa/inet.h>
6 #include <errno.h>
7 #include <error.h>
8 #include <fcntl.h>
9 #include <poll.h>
10 #include <stdio.h>
11 #include <stdlib.h>
12 #include <unistd.h>
13
14 #include <linux/tls.h>
15 #include <linux/tcp.h>
16 #include <linux/socket.h>
17
18 #include <sys/types.h>
19 #include <sys/sendfile.h>
20 #include <sys/socket.h>
21 #include <sys/stat.h>
22
23 #include "../kselftest_harness.h"
24
25 #define TLS_PAYLOAD_MAX_LEN 16384
26 #define SOL_TLS 282
27
28 struct tls_crypto_info_keys {
29         union {
30                 struct tls12_crypto_info_aes_gcm_128 aes128;
31                 struct tls12_crypto_info_chacha20_poly1305 chacha20;
32                 struct tls12_crypto_info_sm4_gcm sm4gcm;
33                 struct tls12_crypto_info_sm4_ccm sm4ccm;
34                 struct tls12_crypto_info_aes_ccm_128 aesccm128;
35                 struct tls12_crypto_info_aes_gcm_256 aesgcm256;
36         };
37         size_t len;
38 };
39
40 static void tls_crypto_info_init(uint16_t tls_version, uint16_t cipher_type,
41                                  struct tls_crypto_info_keys *tls12)
42 {
43         memset(tls12, 0, sizeof(*tls12));
44
45         switch (cipher_type) {
46         case TLS_CIPHER_CHACHA20_POLY1305:
47                 tls12->len = sizeof(struct tls12_crypto_info_chacha20_poly1305);
48                 tls12->chacha20.info.version = tls_version;
49                 tls12->chacha20.info.cipher_type = cipher_type;
50                 break;
51         case TLS_CIPHER_AES_GCM_128:
52                 tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_128);
53                 tls12->aes128.info.version = tls_version;
54                 tls12->aes128.info.cipher_type = cipher_type;
55                 break;
56         case TLS_CIPHER_SM4_GCM:
57                 tls12->len = sizeof(struct tls12_crypto_info_sm4_gcm);
58                 tls12->sm4gcm.info.version = tls_version;
59                 tls12->sm4gcm.info.cipher_type = cipher_type;
60                 break;
61         case TLS_CIPHER_SM4_CCM:
62                 tls12->len = sizeof(struct tls12_crypto_info_sm4_ccm);
63                 tls12->sm4ccm.info.version = tls_version;
64                 tls12->sm4ccm.info.cipher_type = cipher_type;
65                 break;
66         case TLS_CIPHER_AES_CCM_128:
67                 tls12->len = sizeof(struct tls12_crypto_info_aes_ccm_128);
68                 tls12->aesccm128.info.version = tls_version;
69                 tls12->aesccm128.info.cipher_type = cipher_type;
70                 break;
71         case TLS_CIPHER_AES_GCM_256:
72                 tls12->len = sizeof(struct tls12_crypto_info_aes_gcm_256);
73                 tls12->aesgcm256.info.version = tls_version;
74                 tls12->aesgcm256.info.cipher_type = cipher_type;
75                 break;
76         default:
77                 break;
78         }
79 }
80
81 static void memrnd(void *s, size_t n)
82 {
83         int *dword = s;
84         char *byte;
85
86         for (; n >= 4; n -= 4)
87                 *dword++ = rand();
88         byte = (void *)dword;
89         while (n--)
90                 *byte++ = rand();
91 }
92
93 static void ulp_sock_pair(struct __test_metadata *_metadata,
94                           int *fd, int *cfd, bool *notls)
95 {
96         struct sockaddr_in addr;
97         socklen_t len;
98         int sfd, ret;
99
100         *notls = false;
101         len = sizeof(addr);
102
103         addr.sin_family = AF_INET;
104         addr.sin_addr.s_addr = htonl(INADDR_ANY);
105         addr.sin_port = 0;
106
107         *fd = socket(AF_INET, SOCK_STREAM, 0);
108         sfd = socket(AF_INET, SOCK_STREAM, 0);
109
110         ret = bind(sfd, &addr, sizeof(addr));
111         ASSERT_EQ(ret, 0);
112         ret = listen(sfd, 10);
113         ASSERT_EQ(ret, 0);
114
115         ret = getsockname(sfd, &addr, &len);
116         ASSERT_EQ(ret, 0);
117
118         ret = connect(*fd, &addr, sizeof(addr));
119         ASSERT_EQ(ret, 0);
120
121         *cfd = accept(sfd, &addr, &len);
122         ASSERT_GE(*cfd, 0);
123
124         close(sfd);
125
126         ret = setsockopt(*fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
127         if (ret != 0) {
128                 ASSERT_EQ(errno, ENOENT);
129                 *notls = true;
130                 printf("Failure setting TCP_ULP, testing without tls\n");
131                 return;
132         }
133
134         ret = setsockopt(*cfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
135         ASSERT_EQ(ret, 0);
136 }
137
138 /* Produce a basic cmsg */
139 static int tls_send_cmsg(int fd, unsigned char record_type,
140                          void *data, size_t len, int flags)
141 {
142         char cbuf[CMSG_SPACE(sizeof(char))];
143         int cmsg_len = sizeof(char);
144         struct cmsghdr *cmsg;
145         struct msghdr msg;
146         struct iovec vec;
147
148         vec.iov_base = data;
149         vec.iov_len = len;
150         memset(&msg, 0, sizeof(struct msghdr));
151         msg.msg_iov = &vec;
152         msg.msg_iovlen = 1;
153         msg.msg_control = cbuf;
154         msg.msg_controllen = sizeof(cbuf);
155         cmsg = CMSG_FIRSTHDR(&msg);
156         cmsg->cmsg_level = SOL_TLS;
157         /* test sending non-record types. */
158         cmsg->cmsg_type = TLS_SET_RECORD_TYPE;
159         cmsg->cmsg_len = CMSG_LEN(cmsg_len);
160         *CMSG_DATA(cmsg) = record_type;
161         msg.msg_controllen = cmsg->cmsg_len;
162
163         return sendmsg(fd, &msg, flags);
164 }
165
166 static int tls_recv_cmsg(struct __test_metadata *_metadata,
167                          int fd, unsigned char record_type,
168                          void *data, size_t len, int flags)
169 {
170         char cbuf[CMSG_SPACE(sizeof(char))];
171         struct cmsghdr *cmsg;
172         unsigned char ctype;
173         struct msghdr msg;
174         struct iovec vec;
175         int n;
176
177         vec.iov_base = data;
178         vec.iov_len = len;
179         memset(&msg, 0, sizeof(struct msghdr));
180         msg.msg_iov = &vec;
181         msg.msg_iovlen = 1;
182         msg.msg_control = cbuf;
183         msg.msg_controllen = sizeof(cbuf);
184
185         n = recvmsg(fd, &msg, flags);
186
187         cmsg = CMSG_FIRSTHDR(&msg);
188         EXPECT_NE(cmsg, NULL);
189         EXPECT_EQ(cmsg->cmsg_level, SOL_TLS);
190         EXPECT_EQ(cmsg->cmsg_type, TLS_GET_RECORD_TYPE);
191         ctype = *((unsigned char *)CMSG_DATA(cmsg));
192         EXPECT_EQ(ctype, record_type);
193
194         return n;
195 }
196
197 FIXTURE(tls_basic)
198 {
199         int fd, cfd;
200         bool notls;
201 };
202
203 FIXTURE_SETUP(tls_basic)
204 {
205         ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
206 }
207
208 FIXTURE_TEARDOWN(tls_basic)
209 {
210         close(self->fd);
211         close(self->cfd);
212 }
213
214 /* Send some data through with ULP but no keys */
215 TEST_F(tls_basic, base_base)
216 {
217         char const *test_str = "test_read";
218         int send_len = 10;
219         char buf[10];
220
221         ASSERT_EQ(strlen(test_str) + 1, send_len);
222
223         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
224         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
225         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
226 };
227
228 FIXTURE(tls)
229 {
230         int fd, cfd;
231         bool notls;
232 };
233
234 FIXTURE_VARIANT(tls)
235 {
236         uint16_t tls_version;
237         uint16_t cipher_type;
238 };
239
240 FIXTURE_VARIANT_ADD(tls, 12_aes_gcm)
241 {
242         .tls_version = TLS_1_2_VERSION,
243         .cipher_type = TLS_CIPHER_AES_GCM_128,
244 };
245
246 FIXTURE_VARIANT_ADD(tls, 13_aes_gcm)
247 {
248         .tls_version = TLS_1_3_VERSION,
249         .cipher_type = TLS_CIPHER_AES_GCM_128,
250 };
251
252 FIXTURE_VARIANT_ADD(tls, 12_chacha)
253 {
254         .tls_version = TLS_1_2_VERSION,
255         .cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
256 };
257
258 FIXTURE_VARIANT_ADD(tls, 13_chacha)
259 {
260         .tls_version = TLS_1_3_VERSION,
261         .cipher_type = TLS_CIPHER_CHACHA20_POLY1305,
262 };
263
264 FIXTURE_VARIANT_ADD(tls, 13_sm4_gcm)
265 {
266         .tls_version = TLS_1_3_VERSION,
267         .cipher_type = TLS_CIPHER_SM4_GCM,
268 };
269
270 FIXTURE_VARIANT_ADD(tls, 13_sm4_ccm)
271 {
272         .tls_version = TLS_1_3_VERSION,
273         .cipher_type = TLS_CIPHER_SM4_CCM,
274 };
275
276 FIXTURE_VARIANT_ADD(tls, 12_aes_ccm)
277 {
278         .tls_version = TLS_1_2_VERSION,
279         .cipher_type = TLS_CIPHER_AES_CCM_128,
280 };
281
282 FIXTURE_VARIANT_ADD(tls, 13_aes_ccm)
283 {
284         .tls_version = TLS_1_3_VERSION,
285         .cipher_type = TLS_CIPHER_AES_CCM_128,
286 };
287
288 FIXTURE_VARIANT_ADD(tls, 12_aes_gcm_256)
289 {
290         .tls_version = TLS_1_2_VERSION,
291         .cipher_type = TLS_CIPHER_AES_GCM_256,
292 };
293
294 FIXTURE_VARIANT_ADD(tls, 13_aes_gcm_256)
295 {
296         .tls_version = TLS_1_3_VERSION,
297         .cipher_type = TLS_CIPHER_AES_GCM_256,
298 };
299
300 FIXTURE_SETUP(tls)
301 {
302         struct tls_crypto_info_keys tls12;
303         int ret;
304
305         tls_crypto_info_init(variant->tls_version, variant->cipher_type,
306                              &tls12);
307
308         ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
309
310         if (self->notls)
311                 return;
312
313         ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
314         ASSERT_EQ(ret, 0);
315
316         ret = setsockopt(self->cfd, SOL_TLS, TLS_RX, &tls12, tls12.len);
317         ASSERT_EQ(ret, 0);
318 }
319
320 FIXTURE_TEARDOWN(tls)
321 {
322         close(self->fd);
323         close(self->cfd);
324 }
325
326 TEST_F(tls, sendfile)
327 {
328         int filefd = open("/proc/self/exe", O_RDONLY);
329         struct stat st;
330
331         EXPECT_GE(filefd, 0);
332         fstat(filefd, &st);
333         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
334 }
335
336 TEST_F(tls, send_then_sendfile)
337 {
338         int filefd = open("/proc/self/exe", O_RDONLY);
339         char const *test_str = "test_send";
340         int to_send = strlen(test_str) + 1;
341         char recv_buf[10];
342         struct stat st;
343         char *buf;
344
345         EXPECT_GE(filefd, 0);
346         fstat(filefd, &st);
347         buf = (char *)malloc(st.st_size);
348
349         EXPECT_EQ(send(self->fd, test_str, to_send, 0), to_send);
350         EXPECT_EQ(recv(self->cfd, recv_buf, to_send, MSG_WAITALL), to_send);
351         EXPECT_EQ(memcmp(test_str, recv_buf, to_send), 0);
352
353         EXPECT_GE(sendfile(self->fd, filefd, 0, st.st_size), 0);
354         EXPECT_EQ(recv(self->cfd, buf, st.st_size, MSG_WAITALL), st.st_size);
355 }
356
357 static void chunked_sendfile(struct __test_metadata *_metadata,
358                              struct _test_data_tls *self,
359                              uint16_t chunk_size,
360                              uint16_t extra_payload_size)
361 {
362         char buf[TLS_PAYLOAD_MAX_LEN];
363         uint16_t test_payload_size;
364         int size = 0;
365         int ret;
366         char filename[] = "/tmp/mytemp.XXXXXX";
367         int fd = mkstemp(filename);
368         off_t offset = 0;
369
370         unlink(filename);
371         ASSERT_GE(fd, 0);
372         EXPECT_GE(chunk_size, 1);
373         test_payload_size = chunk_size + extra_payload_size;
374         ASSERT_GE(TLS_PAYLOAD_MAX_LEN, test_payload_size);
375         memset(buf, 1, test_payload_size);
376         size = write(fd, buf, test_payload_size);
377         EXPECT_EQ(size, test_payload_size);
378         fsync(fd);
379
380         while (size > 0) {
381                 ret = sendfile(self->fd, fd, &offset, chunk_size);
382                 EXPECT_GE(ret, 0);
383                 size -= ret;
384         }
385
386         EXPECT_EQ(recv(self->cfd, buf, test_payload_size, MSG_WAITALL),
387                   test_payload_size);
388
389         close(fd);
390 }
391
392 TEST_F(tls, multi_chunk_sendfile)
393 {
394         chunked_sendfile(_metadata, self, 4096, 4096);
395         chunked_sendfile(_metadata, self, 4096, 0);
396         chunked_sendfile(_metadata, self, 4096, 1);
397         chunked_sendfile(_metadata, self, 4096, 2048);
398         chunked_sendfile(_metadata, self, 8192, 2048);
399         chunked_sendfile(_metadata, self, 4096, 8192);
400         chunked_sendfile(_metadata, self, 8192, 4096);
401         chunked_sendfile(_metadata, self, 12288, 1024);
402         chunked_sendfile(_metadata, self, 12288, 2000);
403         chunked_sendfile(_metadata, self, 15360, 100);
404         chunked_sendfile(_metadata, self, 15360, 300);
405         chunked_sendfile(_metadata, self, 1, 4096);
406         chunked_sendfile(_metadata, self, 2048, 4096);
407         chunked_sendfile(_metadata, self, 2048, 8192);
408         chunked_sendfile(_metadata, self, 4096, 8192);
409         chunked_sendfile(_metadata, self, 1024, 12288);
410         chunked_sendfile(_metadata, self, 2000, 12288);
411         chunked_sendfile(_metadata, self, 100, 15360);
412         chunked_sendfile(_metadata, self, 300, 15360);
413 }
414
415 TEST_F(tls, recv_max)
416 {
417         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
418         char recv_mem[TLS_PAYLOAD_MAX_LEN];
419         char buf[TLS_PAYLOAD_MAX_LEN];
420
421         memrnd(buf, sizeof(buf));
422
423         EXPECT_GE(send(self->fd, buf, send_len, 0), 0);
424         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
425         EXPECT_EQ(memcmp(buf, recv_mem, send_len), 0);
426 }
427
428 TEST_F(tls, recv_small)
429 {
430         char const *test_str = "test_read";
431         int send_len = 10;
432         char buf[10];
433
434         send_len = strlen(test_str) + 1;
435         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
436         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
437         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
438 }
439
440 TEST_F(tls, msg_more)
441 {
442         char const *test_str = "test_read";
443         int send_len = 10;
444         char buf[10 * 2];
445
446         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
447         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
448         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
449         EXPECT_EQ(recv(self->cfd, buf, send_len * 2, MSG_WAITALL),
450                   send_len * 2);
451         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
452 }
453
454 TEST_F(tls, msg_more_unsent)
455 {
456         char const *test_str = "test_read";
457         int send_len = 10;
458         char buf[10];
459
460         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
461         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_DONTWAIT), -1);
462 }
463
464 TEST_F(tls, sendmsg_single)
465 {
466         struct msghdr msg;
467
468         char const *test_str = "test_sendmsg";
469         size_t send_len = 13;
470         struct iovec vec;
471         char buf[13];
472
473         vec.iov_base = (char *)test_str;
474         vec.iov_len = send_len;
475         memset(&msg, 0, sizeof(struct msghdr));
476         msg.msg_iov = &vec;
477         msg.msg_iovlen = 1;
478         EXPECT_EQ(sendmsg(self->fd, &msg, 0), send_len);
479         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
480         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
481 }
482
483 #define MAX_FRAGS       64
484 #define SEND_LEN        13
485 TEST_F(tls, sendmsg_fragmented)
486 {
487         char const *test_str = "test_sendmsg";
488         char buf[SEND_LEN * MAX_FRAGS];
489         struct iovec vec[MAX_FRAGS];
490         struct msghdr msg;
491         int i, frags;
492
493         for (frags = 1; frags <= MAX_FRAGS; frags++) {
494                 for (i = 0; i < frags; i++) {
495                         vec[i].iov_base = (char *)test_str;
496                         vec[i].iov_len = SEND_LEN;
497                 }
498
499                 memset(&msg, 0, sizeof(struct msghdr));
500                 msg.msg_iov = vec;
501                 msg.msg_iovlen = frags;
502
503                 EXPECT_EQ(sendmsg(self->fd, &msg, 0), SEND_LEN * frags);
504                 EXPECT_EQ(recv(self->cfd, buf, SEND_LEN * frags, MSG_WAITALL),
505                           SEND_LEN * frags);
506
507                 for (i = 0; i < frags; i++)
508                         EXPECT_EQ(memcmp(buf + SEND_LEN * i,
509                                          test_str, SEND_LEN), 0);
510         }
511 }
512 #undef MAX_FRAGS
513 #undef SEND_LEN
514
515 TEST_F(tls, sendmsg_large)
516 {
517         void *mem = malloc(16384);
518         size_t send_len = 16384;
519         size_t sends = 128;
520         struct msghdr msg;
521         size_t recvs = 0;
522         size_t sent = 0;
523
524         memset(&msg, 0, sizeof(struct msghdr));
525         while (sent++ < sends) {
526                 struct iovec vec = { (void *)mem, send_len };
527
528                 msg.msg_iov = &vec;
529                 msg.msg_iovlen = 1;
530                 EXPECT_EQ(sendmsg(self->cfd, &msg, 0), send_len);
531         }
532
533         while (recvs++ < sends) {
534                 EXPECT_NE(recv(self->fd, mem, send_len, 0), -1);
535         }
536
537         free(mem);
538 }
539
540 TEST_F(tls, sendmsg_multiple)
541 {
542         char const *test_str = "test_sendmsg_multiple";
543         struct iovec vec[5];
544         char *test_strs[5];
545         struct msghdr msg;
546         int total_len = 0;
547         int len_cmp = 0;
548         int iov_len = 5;
549         char *buf;
550         int i;
551
552         memset(&msg, 0, sizeof(struct msghdr));
553         for (i = 0; i < iov_len; i++) {
554                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
555                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
556                 vec[i].iov_base = (void *)test_strs[i];
557                 vec[i].iov_len = strlen(test_strs[i]) + 1;
558                 total_len += vec[i].iov_len;
559         }
560         msg.msg_iov = vec;
561         msg.msg_iovlen = iov_len;
562
563         EXPECT_EQ(sendmsg(self->cfd, &msg, 0), total_len);
564         buf = malloc(total_len);
565         EXPECT_NE(recv(self->fd, buf, total_len, 0), -1);
566         for (i = 0; i < iov_len; i++) {
567                 EXPECT_EQ(memcmp(test_strs[i], buf + len_cmp,
568                                  strlen(test_strs[i])),
569                           0);
570                 len_cmp += strlen(buf + len_cmp) + 1;
571         }
572         for (i = 0; i < iov_len; i++)
573                 free(test_strs[i]);
574         free(buf);
575 }
576
577 TEST_F(tls, sendmsg_multiple_stress)
578 {
579         char const *test_str = "abcdefghijklmno";
580         struct iovec vec[1024];
581         char *test_strs[1024];
582         int iov_len = 1024;
583         int total_len = 0;
584         char buf[1 << 14];
585         struct msghdr msg;
586         int len_cmp = 0;
587         int i;
588
589         memset(&msg, 0, sizeof(struct msghdr));
590         for (i = 0; i < iov_len; i++) {
591                 test_strs[i] = (char *)malloc(strlen(test_str) + 1);
592                 snprintf(test_strs[i], strlen(test_str) + 1, "%s", test_str);
593                 vec[i].iov_base = (void *)test_strs[i];
594                 vec[i].iov_len = strlen(test_strs[i]) + 1;
595                 total_len += vec[i].iov_len;
596         }
597         msg.msg_iov = vec;
598         msg.msg_iovlen = iov_len;
599
600         EXPECT_EQ(sendmsg(self->fd, &msg, 0), total_len);
601         EXPECT_NE(recv(self->cfd, buf, total_len, 0), -1);
602
603         for (i = 0; i < iov_len; i++)
604                 len_cmp += strlen(buf + len_cmp) + 1;
605
606         for (i = 0; i < iov_len; i++)
607                 free(test_strs[i]);
608 }
609
610 TEST_F(tls, splice_from_pipe)
611 {
612         int send_len = TLS_PAYLOAD_MAX_LEN;
613         char mem_send[TLS_PAYLOAD_MAX_LEN];
614         char mem_recv[TLS_PAYLOAD_MAX_LEN];
615         int p[2];
616
617         ASSERT_GE(pipe(p), 0);
618         EXPECT_GE(write(p[1], mem_send, send_len), 0);
619         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), 0);
620         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
621         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
622 }
623
624 TEST_F(tls, splice_from_pipe2)
625 {
626         int send_len = 16000;
627         char mem_send[16000];
628         char mem_recv[16000];
629         int p2[2];
630         int p[2];
631
632         ASSERT_GE(pipe(p), 0);
633         ASSERT_GE(pipe(p2), 0);
634         EXPECT_GE(write(p[1], mem_send, 8000), 0);
635         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, 8000, 0), 0);
636         EXPECT_GE(write(p2[1], mem_send + 8000, 8000), 0);
637         EXPECT_GE(splice(p2[0], NULL, self->fd, NULL, 8000, 0), 0);
638         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
639         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
640 }
641
642 TEST_F(tls, send_and_splice)
643 {
644         int send_len = TLS_PAYLOAD_MAX_LEN;
645         char mem_send[TLS_PAYLOAD_MAX_LEN];
646         char mem_recv[TLS_PAYLOAD_MAX_LEN];
647         char const *test_str = "test_read";
648         int send_len2 = 10;
649         char buf[10];
650         int p[2];
651
652         ASSERT_GE(pipe(p), 0);
653         EXPECT_EQ(send(self->fd, test_str, send_len2, 0), send_len2);
654         EXPECT_EQ(recv(self->cfd, buf, send_len2, MSG_WAITALL), send_len2);
655         EXPECT_EQ(memcmp(test_str, buf, send_len2), 0);
656
657         EXPECT_GE(write(p[1], mem_send, send_len), send_len);
658         EXPECT_GE(splice(p[0], NULL, self->fd, NULL, send_len, 0), send_len);
659
660         EXPECT_EQ(recv(self->cfd, mem_recv, send_len, MSG_WAITALL), send_len);
661         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
662 }
663
664 TEST_F(tls, splice_to_pipe)
665 {
666         int send_len = TLS_PAYLOAD_MAX_LEN;
667         char mem_send[TLS_PAYLOAD_MAX_LEN];
668         char mem_recv[TLS_PAYLOAD_MAX_LEN];
669         int p[2];
670
671         ASSERT_GE(pipe(p), 0);
672         EXPECT_GE(send(self->fd, mem_send, send_len, 0), 0);
673         EXPECT_GE(splice(self->cfd, NULL, p[1], NULL, send_len, 0), 0);
674         EXPECT_GE(read(p[0], mem_recv, send_len), 0);
675         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
676 }
677
678 TEST_F(tls, splice_cmsg_to_pipe)
679 {
680         char *test_str = "test_read";
681         char record_type = 100;
682         int send_len = 10;
683         char buf[10];
684         int p[2];
685
686         ASSERT_GE(pipe(p), 0);
687         EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
688         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
689         EXPECT_EQ(errno, EINVAL);
690         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
691         EXPECT_EQ(errno, EIO);
692         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
693                                 buf, sizeof(buf), MSG_WAITALL),
694                   send_len);
695         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
696 }
697
698 TEST_F(tls, splice_dec_cmsg_to_pipe)
699 {
700         char *test_str = "test_read";
701         char record_type = 100;
702         int send_len = 10;
703         char buf[10];
704         int p[2];
705
706         ASSERT_GE(pipe(p), 0);
707         EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
708         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
709         EXPECT_EQ(errno, EIO);
710         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, send_len, 0), -1);
711         EXPECT_EQ(errno, EINVAL);
712         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
713                                 buf, sizeof(buf), MSG_WAITALL),
714                   send_len);
715         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
716 }
717
718 TEST_F(tls, recv_and_splice)
719 {
720         int send_len = TLS_PAYLOAD_MAX_LEN;
721         char mem_send[TLS_PAYLOAD_MAX_LEN];
722         char mem_recv[TLS_PAYLOAD_MAX_LEN];
723         int half = send_len / 2;
724         int p[2];
725
726         ASSERT_GE(pipe(p), 0);
727         EXPECT_EQ(send(self->fd, mem_send, send_len, 0), send_len);
728         /* Recv hald of the record, splice the other half */
729         EXPECT_EQ(recv(self->cfd, mem_recv, half, MSG_WAITALL), half);
730         EXPECT_EQ(splice(self->cfd, NULL, p[1], NULL, half, SPLICE_F_NONBLOCK),
731                   half);
732         EXPECT_EQ(read(p[0], &mem_recv[half], half), half);
733         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
734 }
735
736 TEST_F(tls, peek_and_splice)
737 {
738         int send_len = TLS_PAYLOAD_MAX_LEN;
739         char mem_send[TLS_PAYLOAD_MAX_LEN];
740         char mem_recv[TLS_PAYLOAD_MAX_LEN];
741         int chunk = TLS_PAYLOAD_MAX_LEN / 4;
742         int n, i, p[2];
743
744         memrnd(mem_send, sizeof(mem_send));
745
746         ASSERT_GE(pipe(p), 0);
747         for (i = 0; i < 4; i++)
748                 EXPECT_EQ(send(self->fd, &mem_send[chunk * i], chunk, 0),
749                           chunk);
750
751         EXPECT_EQ(recv(self->cfd, mem_recv, chunk * 5 / 2,
752                        MSG_WAITALL | MSG_PEEK),
753                   chunk * 5 / 2);
754         EXPECT_EQ(memcmp(mem_send, mem_recv, chunk * 5 / 2), 0);
755
756         n = 0;
757         while (n < send_len) {
758                 i = splice(self->cfd, NULL, p[1], NULL, send_len - n, 0);
759                 EXPECT_GT(i, 0);
760                 n += i;
761         }
762         EXPECT_EQ(n, send_len);
763         EXPECT_EQ(read(p[0], mem_recv, send_len), send_len);
764         EXPECT_EQ(memcmp(mem_send, mem_recv, send_len), 0);
765 }
766
767 TEST_F(tls, recvmsg_single)
768 {
769         char const *test_str = "test_recvmsg_single";
770         int send_len = strlen(test_str) + 1;
771         char buf[20];
772         struct msghdr hdr;
773         struct iovec vec;
774
775         memset(&hdr, 0, sizeof(hdr));
776         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
777         vec.iov_base = (char *)buf;
778         vec.iov_len = send_len;
779         hdr.msg_iovlen = 1;
780         hdr.msg_iov = &vec;
781         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
782         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
783 }
784
785 TEST_F(tls, recvmsg_single_max)
786 {
787         int send_len = TLS_PAYLOAD_MAX_LEN;
788         char send_mem[TLS_PAYLOAD_MAX_LEN];
789         char recv_mem[TLS_PAYLOAD_MAX_LEN];
790         struct iovec vec;
791         struct msghdr hdr;
792
793         memrnd(send_mem, sizeof(send_mem));
794
795         EXPECT_EQ(send(self->fd, send_mem, send_len, 0), send_len);
796         vec.iov_base = (char *)recv_mem;
797         vec.iov_len = TLS_PAYLOAD_MAX_LEN;
798
799         hdr.msg_iovlen = 1;
800         hdr.msg_iov = &vec;
801         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
802         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
803 }
804
805 TEST_F(tls, recvmsg_multiple)
806 {
807         unsigned int msg_iovlen = 1024;
808         struct iovec vec[1024];
809         char *iov_base[1024];
810         unsigned int iov_len = 16;
811         int send_len = 1 << 14;
812         char buf[1 << 14];
813         struct msghdr hdr;
814         int i;
815
816         memrnd(buf, sizeof(buf));
817
818         EXPECT_EQ(send(self->fd, buf, send_len, 0), send_len);
819         for (i = 0; i < msg_iovlen; i++) {
820                 iov_base[i] = (char *)malloc(iov_len);
821                 vec[i].iov_base = iov_base[i];
822                 vec[i].iov_len = iov_len;
823         }
824
825         hdr.msg_iovlen = msg_iovlen;
826         hdr.msg_iov = vec;
827         EXPECT_NE(recvmsg(self->cfd, &hdr, 0), -1);
828
829         for (i = 0; i < msg_iovlen; i++)
830                 free(iov_base[i]);
831 }
832
833 TEST_F(tls, single_send_multiple_recv)
834 {
835         unsigned int total_len = TLS_PAYLOAD_MAX_LEN * 2;
836         unsigned int send_len = TLS_PAYLOAD_MAX_LEN;
837         char send_mem[TLS_PAYLOAD_MAX_LEN * 2];
838         char recv_mem[TLS_PAYLOAD_MAX_LEN * 2];
839
840         memrnd(send_mem, sizeof(send_mem));
841
842         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
843         memset(recv_mem, 0, total_len);
844
845         EXPECT_NE(recv(self->cfd, recv_mem, send_len, 0), -1);
846         EXPECT_NE(recv(self->cfd, recv_mem + send_len, send_len, 0), -1);
847         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
848 }
849
850 TEST_F(tls, multiple_send_single_recv)
851 {
852         unsigned int total_len = 2 * 10;
853         unsigned int send_len = 10;
854         char recv_mem[2 * 10];
855         char send_mem[10];
856
857         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
858         EXPECT_GE(send(self->fd, send_mem, send_len, 0), 0);
859         memset(recv_mem, 0, total_len);
860         EXPECT_EQ(recv(self->cfd, recv_mem, total_len, MSG_WAITALL), total_len);
861
862         EXPECT_EQ(memcmp(send_mem, recv_mem, send_len), 0);
863         EXPECT_EQ(memcmp(send_mem, recv_mem + send_len, send_len), 0);
864 }
865
866 TEST_F(tls, single_send_multiple_recv_non_align)
867 {
868         const unsigned int total_len = 15;
869         const unsigned int recv_len = 10;
870         char recv_mem[recv_len * 2];
871         char send_mem[total_len];
872
873         EXPECT_GE(send(self->fd, send_mem, total_len, 0), 0);
874         memset(recv_mem, 0, total_len);
875
876         EXPECT_EQ(recv(self->cfd, recv_mem, recv_len, 0), recv_len);
877         EXPECT_EQ(recv(self->cfd, recv_mem + recv_len, recv_len, 0), 5);
878         EXPECT_EQ(memcmp(send_mem, recv_mem, total_len), 0);
879 }
880
881 TEST_F(tls, recv_partial)
882 {
883         char const *test_str = "test_read_partial";
884         char const *test_str_first = "test_read";
885         char const *test_str_second = "_partial";
886         int send_len = strlen(test_str) + 1;
887         char recv_mem[18];
888
889         memset(recv_mem, 0, sizeof(recv_mem));
890         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
891         EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_first),
892                        MSG_WAITALL), -1);
893         EXPECT_EQ(memcmp(test_str_first, recv_mem, strlen(test_str_first)), 0);
894         memset(recv_mem, 0, sizeof(recv_mem));
895         EXPECT_NE(recv(self->cfd, recv_mem, strlen(test_str_second),
896                        MSG_WAITALL), -1);
897         EXPECT_EQ(memcmp(test_str_second, recv_mem, strlen(test_str_second)),
898                   0);
899 }
900
901 TEST_F(tls, recv_nonblock)
902 {
903         char buf[4096];
904         bool err;
905
906         EXPECT_EQ(recv(self->cfd, buf, sizeof(buf), MSG_DONTWAIT), -1);
907         err = (errno == EAGAIN || errno == EWOULDBLOCK);
908         EXPECT_EQ(err, true);
909 }
910
911 TEST_F(tls, recv_peek)
912 {
913         char const *test_str = "test_read_peek";
914         int send_len = strlen(test_str) + 1;
915         char buf[15];
916
917         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
918         EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
919         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
920         memset(buf, 0, sizeof(buf));
921         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
922         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
923 }
924
925 TEST_F(tls, recv_peek_multiple)
926 {
927         char const *test_str = "test_read_peek";
928         int send_len = strlen(test_str) + 1;
929         unsigned int num_peeks = 100;
930         char buf[15];
931         int i;
932
933         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
934         for (i = 0; i < num_peeks; i++) {
935                 EXPECT_NE(recv(self->cfd, buf, send_len, MSG_PEEK), -1);
936                 EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
937                 memset(buf, 0, sizeof(buf));
938         }
939         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
940         EXPECT_EQ(memcmp(test_str, buf, send_len), 0);
941 }
942
943 TEST_F(tls, recv_peek_multiple_records)
944 {
945         char const *test_str = "test_read_peek_mult_recs";
946         char const *test_str_first = "test_read_peek";
947         char const *test_str_second = "_mult_recs";
948         int len;
949         char buf[64];
950
951         len = strlen(test_str_first);
952         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
953
954         len = strlen(test_str_second) + 1;
955         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
956
957         len = strlen(test_str_first);
958         memset(buf, 0, len);
959         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
960
961         /* MSG_PEEK can only peek into the current record. */
962         len = strlen(test_str_first);
963         EXPECT_EQ(memcmp(test_str_first, buf, len), 0);
964
965         len = strlen(test_str) + 1;
966         memset(buf, 0, len);
967         EXPECT_EQ(recv(self->cfd, buf, len, MSG_WAITALL), len);
968
969         /* Non-MSG_PEEK will advance strparser (and therefore record)
970          * however.
971          */
972         len = strlen(test_str) + 1;
973         EXPECT_EQ(memcmp(test_str, buf, len), 0);
974
975         /* MSG_MORE will hold current record open, so later MSG_PEEK
976          * will see everything.
977          */
978         len = strlen(test_str_first);
979         EXPECT_EQ(send(self->fd, test_str_first, len, MSG_MORE), len);
980
981         len = strlen(test_str_second) + 1;
982         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
983
984         len = strlen(test_str) + 1;
985         memset(buf, 0, len);
986         EXPECT_EQ(recv(self->cfd, buf, len, MSG_PEEK | MSG_WAITALL), len);
987
988         len = strlen(test_str) + 1;
989         EXPECT_EQ(memcmp(test_str, buf, len), 0);
990 }
991
992 TEST_F(tls, recv_peek_large_buf_mult_recs)
993 {
994         char const *test_str = "test_read_peek_mult_recs";
995         char const *test_str_first = "test_read_peek";
996         char const *test_str_second = "_mult_recs";
997         int len;
998         char buf[64];
999
1000         len = strlen(test_str_first);
1001         EXPECT_EQ(send(self->fd, test_str_first, len, 0), len);
1002
1003         len = strlen(test_str_second) + 1;
1004         EXPECT_EQ(send(self->fd, test_str_second, len, 0), len);
1005
1006         len = strlen(test_str) + 1;
1007         memset(buf, 0, len);
1008         EXPECT_NE((len = recv(self->cfd, buf, len,
1009                               MSG_PEEK | MSG_WAITALL)), -1);
1010         len = strlen(test_str) + 1;
1011         EXPECT_EQ(memcmp(test_str, buf, len), 0);
1012 }
1013
1014 TEST_F(tls, recv_lowat)
1015 {
1016         char send_mem[10] = { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9 };
1017         char recv_mem[20];
1018         int lowat = 8;
1019
1020         EXPECT_EQ(send(self->fd, send_mem, 10, 0), 10);
1021         EXPECT_EQ(send(self->fd, send_mem, 5, 0), 5);
1022
1023         memset(recv_mem, 0, 20);
1024         EXPECT_EQ(setsockopt(self->cfd, SOL_SOCKET, SO_RCVLOWAT,
1025                              &lowat, sizeof(lowat)), 0);
1026         EXPECT_EQ(recv(self->cfd, recv_mem, 1, MSG_WAITALL), 1);
1027         EXPECT_EQ(recv(self->cfd, recv_mem + 1, 6, MSG_WAITALL), 6);
1028         EXPECT_EQ(recv(self->cfd, recv_mem + 7, 10, 0), 8);
1029
1030         EXPECT_EQ(memcmp(send_mem, recv_mem, 10), 0);
1031         EXPECT_EQ(memcmp(send_mem, recv_mem + 10, 5), 0);
1032 }
1033
1034 TEST_F(tls, bidir)
1035 {
1036         char const *test_str = "test_read";
1037         int send_len = 10;
1038         char buf[10];
1039         int ret;
1040
1041         if (!self->notls) {
1042                 struct tls_crypto_info_keys tls12;
1043
1044                 tls_crypto_info_init(variant->tls_version, variant->cipher_type,
1045                                      &tls12);
1046
1047                 ret = setsockopt(self->fd, SOL_TLS, TLS_RX, &tls12,
1048                                  tls12.len);
1049                 ASSERT_EQ(ret, 0);
1050
1051                 ret = setsockopt(self->cfd, SOL_TLS, TLS_TX, &tls12,
1052                                  tls12.len);
1053                 ASSERT_EQ(ret, 0);
1054         }
1055
1056         ASSERT_EQ(strlen(test_str) + 1, send_len);
1057
1058         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1059         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1060         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1061
1062         memset(buf, 0, sizeof(buf));
1063
1064         EXPECT_EQ(send(self->cfd, test_str, send_len, 0), send_len);
1065         EXPECT_NE(recv(self->fd, buf, send_len, 0), -1);
1066         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1067 };
1068
1069 TEST_F(tls, pollin)
1070 {
1071         char const *test_str = "test_poll";
1072         struct pollfd fd = { 0, 0, 0 };
1073         char buf[10];
1074         int send_len = 10;
1075
1076         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1077         fd.fd = self->cfd;
1078         fd.events = POLLIN;
1079
1080         EXPECT_EQ(poll(&fd, 1, 20), 1);
1081         EXPECT_EQ(fd.revents & POLLIN, 1);
1082         EXPECT_EQ(recv(self->cfd, buf, send_len, MSG_WAITALL), send_len);
1083         /* Test timing out */
1084         EXPECT_EQ(poll(&fd, 1, 20), 0);
1085 }
1086
1087 TEST_F(tls, poll_wait)
1088 {
1089         char const *test_str = "test_poll_wait";
1090         int send_len = strlen(test_str) + 1;
1091         struct pollfd fd = { 0, 0, 0 };
1092         char recv_mem[15];
1093
1094         fd.fd = self->cfd;
1095         fd.events = POLLIN;
1096         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1097         /* Set timeout to inf. secs */
1098         EXPECT_EQ(poll(&fd, 1, -1), 1);
1099         EXPECT_EQ(fd.revents & POLLIN, 1);
1100         EXPECT_EQ(recv(self->cfd, recv_mem, send_len, MSG_WAITALL), send_len);
1101 }
1102
1103 TEST_F(tls, poll_wait_split)
1104 {
1105         struct pollfd fd = { 0, 0, 0 };
1106         char send_mem[20] = {};
1107         char recv_mem[15];
1108
1109         fd.fd = self->cfd;
1110         fd.events = POLLIN;
1111         /* Send 20 bytes */
1112         EXPECT_EQ(send(self->fd, send_mem, sizeof(send_mem), 0),
1113                   sizeof(send_mem));
1114         /* Poll with inf. timeout */
1115         EXPECT_EQ(poll(&fd, 1, -1), 1);
1116         EXPECT_EQ(fd.revents & POLLIN, 1);
1117         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), MSG_WAITALL),
1118                   sizeof(recv_mem));
1119
1120         /* Now the remaining 5 bytes of record data are in TLS ULP */
1121         fd.fd = self->cfd;
1122         fd.events = POLLIN;
1123         EXPECT_EQ(poll(&fd, 1, -1), 1);
1124         EXPECT_EQ(fd.revents & POLLIN, 1);
1125         EXPECT_EQ(recv(self->cfd, recv_mem, sizeof(recv_mem), 0),
1126                   sizeof(send_mem) - sizeof(recv_mem));
1127 }
1128
1129 TEST_F(tls, blocking)
1130 {
1131         size_t data = 100000;
1132         int res = fork();
1133
1134         EXPECT_NE(res, -1);
1135
1136         if (res) {
1137                 /* parent */
1138                 size_t left = data;
1139                 char buf[16384];
1140                 int status;
1141                 int pid2;
1142
1143                 while (left) {
1144                         int res = send(self->fd, buf,
1145                                        left > 16384 ? 16384 : left, 0);
1146
1147                         EXPECT_GE(res, 0);
1148                         left -= res;
1149                 }
1150
1151                 pid2 = wait(&status);
1152                 EXPECT_EQ(status, 0);
1153                 EXPECT_EQ(res, pid2);
1154         } else {
1155                 /* child */
1156                 size_t left = data;
1157                 char buf[16384];
1158
1159                 while (left) {
1160                         int res = recv(self->cfd, buf,
1161                                        left > 16384 ? 16384 : left, 0);
1162
1163                         EXPECT_GE(res, 0);
1164                         left -= res;
1165                 }
1166         }
1167 }
1168
1169 TEST_F(tls, nonblocking)
1170 {
1171         size_t data = 100000;
1172         int sendbuf = 100;
1173         int flags;
1174         int res;
1175
1176         flags = fcntl(self->fd, F_GETFL, 0);
1177         fcntl(self->fd, F_SETFL, flags | O_NONBLOCK);
1178         fcntl(self->cfd, F_SETFL, flags | O_NONBLOCK);
1179
1180         /* Ensure nonblocking behavior by imposing a small send
1181          * buffer.
1182          */
1183         EXPECT_EQ(setsockopt(self->fd, SOL_SOCKET, SO_SNDBUF,
1184                              &sendbuf, sizeof(sendbuf)), 0);
1185
1186         res = fork();
1187         EXPECT_NE(res, -1);
1188
1189         if (res) {
1190                 /* parent */
1191                 bool eagain = false;
1192                 size_t left = data;
1193                 char buf[16384];
1194                 int status;
1195                 int pid2;
1196
1197                 while (left) {
1198                         int res = send(self->fd, buf,
1199                                        left > 16384 ? 16384 : left, 0);
1200
1201                         if (res == -1 && errno == EAGAIN) {
1202                                 eagain = true;
1203                                 usleep(10000);
1204                                 continue;
1205                         }
1206                         EXPECT_GE(res, 0);
1207                         left -= res;
1208                 }
1209
1210                 EXPECT_TRUE(eagain);
1211                 pid2 = wait(&status);
1212
1213                 EXPECT_EQ(status, 0);
1214                 EXPECT_EQ(res, pid2);
1215         } else {
1216                 /* child */
1217                 bool eagain = false;
1218                 size_t left = data;
1219                 char buf[16384];
1220
1221                 while (left) {
1222                         int res = recv(self->cfd, buf,
1223                                        left > 16384 ? 16384 : left, 0);
1224
1225                         if (res == -1 && errno == EAGAIN) {
1226                                 eagain = true;
1227                                 usleep(10000);
1228                                 continue;
1229                         }
1230                         EXPECT_GE(res, 0);
1231                         left -= res;
1232                 }
1233                 EXPECT_TRUE(eagain);
1234         }
1235 }
1236
1237 static void
1238 test_mutliproc(struct __test_metadata *_metadata, struct _test_data_tls *self,
1239                bool sendpg, unsigned int n_readers, unsigned int n_writers)
1240 {
1241         const unsigned int n_children = n_readers + n_writers;
1242         const size_t data = 6 * 1000 * 1000;
1243         const size_t file_sz = data / 100;
1244         size_t read_bias, write_bias;
1245         int i, fd, child_id;
1246         char buf[file_sz];
1247         pid_t pid;
1248
1249         /* Only allow multiples for simplicity */
1250         ASSERT_EQ(!(n_readers % n_writers) || !(n_writers % n_readers), true);
1251         read_bias = n_writers / n_readers ?: 1;
1252         write_bias = n_readers / n_writers ?: 1;
1253
1254         /* prep a file to send */
1255         fd = open("/tmp/", O_TMPFILE | O_RDWR, 0600);
1256         ASSERT_GE(fd, 0);
1257
1258         memset(buf, 0xac, file_sz);
1259         ASSERT_EQ(write(fd, buf, file_sz), file_sz);
1260
1261         /* spawn children */
1262         for (child_id = 0; child_id < n_children; child_id++) {
1263                 pid = fork();
1264                 ASSERT_NE(pid, -1);
1265                 if (!pid)
1266                         break;
1267         }
1268
1269         /* parent waits for all children */
1270         if (pid) {
1271                 for (i = 0; i < n_children; i++) {
1272                         int status;
1273
1274                         wait(&status);
1275                         EXPECT_EQ(status, 0);
1276                 }
1277
1278                 return;
1279         }
1280
1281         /* Split threads for reading and writing */
1282         if (child_id < n_readers) {
1283                 size_t left = data * read_bias;
1284                 char rb[8001];
1285
1286                 while (left) {
1287                         int res;
1288
1289                         res = recv(self->cfd, rb,
1290                                    left > sizeof(rb) ? sizeof(rb) : left, 0);
1291
1292                         EXPECT_GE(res, 0);
1293                         left -= res;
1294                 }
1295         } else {
1296                 size_t left = data * write_bias;
1297
1298                 while (left) {
1299                         int res;
1300
1301                         ASSERT_EQ(lseek(fd, 0, SEEK_SET), 0);
1302                         if (sendpg)
1303                                 res = sendfile(self->fd, fd, NULL,
1304                                                left > file_sz ? file_sz : left);
1305                         else
1306                                 res = send(self->fd, buf,
1307                                            left > file_sz ? file_sz : left, 0);
1308
1309                         EXPECT_GE(res, 0);
1310                         left -= res;
1311                 }
1312         }
1313 }
1314
1315 TEST_F(tls, mutliproc_even)
1316 {
1317         test_mutliproc(_metadata, self, false, 6, 6);
1318 }
1319
1320 TEST_F(tls, mutliproc_readers)
1321 {
1322         test_mutliproc(_metadata, self, false, 4, 12);
1323 }
1324
1325 TEST_F(tls, mutliproc_writers)
1326 {
1327         test_mutliproc(_metadata, self, false, 10, 2);
1328 }
1329
1330 TEST_F(tls, mutliproc_sendpage_even)
1331 {
1332         test_mutliproc(_metadata, self, true, 6, 6);
1333 }
1334
1335 TEST_F(tls, mutliproc_sendpage_readers)
1336 {
1337         test_mutliproc(_metadata, self, true, 4, 12);
1338 }
1339
1340 TEST_F(tls, mutliproc_sendpage_writers)
1341 {
1342         test_mutliproc(_metadata, self, true, 10, 2);
1343 }
1344
1345 TEST_F(tls, control_msg)
1346 {
1347         char *test_str = "test_read";
1348         char record_type = 100;
1349         int send_len = 10;
1350         char buf[10];
1351
1352         if (self->notls)
1353                 SKIP(return, "no TLS support");
1354
1355         EXPECT_EQ(tls_send_cmsg(self->fd, record_type, test_str, send_len, 0),
1356                   send_len);
1357         /* Should fail because we didn't provide a control message */
1358         EXPECT_EQ(recv(self->cfd, buf, send_len, 0), -1);
1359
1360         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1361                                 buf, sizeof(buf), MSG_WAITALL | MSG_PEEK),
1362                   send_len);
1363         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1364
1365         /* Recv the message again without MSG_PEEK */
1366         memset(buf, 0, sizeof(buf));
1367
1368         EXPECT_EQ(tls_recv_cmsg(_metadata, self->cfd, record_type,
1369                                 buf, sizeof(buf), MSG_WAITALL),
1370                   send_len);
1371         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1372 }
1373
1374 TEST_F(tls, shutdown)
1375 {
1376         char const *test_str = "test_read";
1377         int send_len = 10;
1378         char buf[10];
1379
1380         ASSERT_EQ(strlen(test_str) + 1, send_len);
1381
1382         EXPECT_EQ(send(self->fd, test_str, send_len, 0), send_len);
1383         EXPECT_NE(recv(self->cfd, buf, send_len, 0), -1);
1384         EXPECT_EQ(memcmp(buf, test_str, send_len), 0);
1385
1386         shutdown(self->fd, SHUT_RDWR);
1387         shutdown(self->cfd, SHUT_RDWR);
1388 }
1389
1390 TEST_F(tls, shutdown_unsent)
1391 {
1392         char const *test_str = "test_read";
1393         int send_len = 10;
1394
1395         EXPECT_EQ(send(self->fd, test_str, send_len, MSG_MORE), send_len);
1396
1397         shutdown(self->fd, SHUT_RDWR);
1398         shutdown(self->cfd, SHUT_RDWR);
1399 }
1400
1401 TEST_F(tls, shutdown_reuse)
1402 {
1403         struct sockaddr_in addr;
1404         int ret;
1405
1406         shutdown(self->fd, SHUT_RDWR);
1407         shutdown(self->cfd, SHUT_RDWR);
1408         close(self->cfd);
1409
1410         addr.sin_family = AF_INET;
1411         addr.sin_addr.s_addr = htonl(INADDR_ANY);
1412         addr.sin_port = 0;
1413
1414         ret = bind(self->fd, &addr, sizeof(addr));
1415         EXPECT_EQ(ret, 0);
1416         ret = listen(self->fd, 10);
1417         EXPECT_EQ(ret, -1);
1418         EXPECT_EQ(errno, EINVAL);
1419
1420         ret = connect(self->fd, &addr, sizeof(addr));
1421         EXPECT_EQ(ret, -1);
1422         EXPECT_EQ(errno, EISCONN);
1423 }
1424
1425 FIXTURE(tls_err)
1426 {
1427         int fd, cfd;
1428         int fd2, cfd2;
1429         bool notls;
1430 };
1431
1432 FIXTURE_VARIANT(tls_err)
1433 {
1434         uint16_t tls_version;
1435 };
1436
1437 FIXTURE_VARIANT_ADD(tls_err, 12_aes_gcm)
1438 {
1439         .tls_version = TLS_1_2_VERSION,
1440 };
1441
1442 FIXTURE_VARIANT_ADD(tls_err, 13_aes_gcm)
1443 {
1444         .tls_version = TLS_1_3_VERSION,
1445 };
1446
1447 FIXTURE_SETUP(tls_err)
1448 {
1449         struct tls_crypto_info_keys tls12;
1450         int ret;
1451
1452         tls_crypto_info_init(variant->tls_version, TLS_CIPHER_AES_GCM_128,
1453                              &tls12);
1454
1455         ulp_sock_pair(_metadata, &self->fd, &self->cfd, &self->notls);
1456         ulp_sock_pair(_metadata, &self->fd2, &self->cfd2, &self->notls);
1457         if (self->notls)
1458                 return;
1459
1460         ret = setsockopt(self->fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
1461         ASSERT_EQ(ret, 0);
1462
1463         ret = setsockopt(self->cfd2, SOL_TLS, TLS_RX, &tls12, tls12.len);
1464         ASSERT_EQ(ret, 0);
1465 }
1466
1467 FIXTURE_TEARDOWN(tls_err)
1468 {
1469         close(self->fd);
1470         close(self->cfd);
1471         close(self->fd2);
1472         close(self->cfd2);
1473 }
1474
1475 TEST_F(tls_err, bad_rec)
1476 {
1477         char buf[64];
1478
1479         if (self->notls)
1480                 SKIP(return, "no TLS support");
1481
1482         memset(buf, 0x55, sizeof(buf));
1483         EXPECT_EQ(send(self->fd2, buf, sizeof(buf), 0), sizeof(buf));
1484         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1485         EXPECT_EQ(errno, EMSGSIZE);
1486         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), MSG_DONTWAIT), -1);
1487         EXPECT_EQ(errno, EAGAIN);
1488 }
1489
1490 TEST_F(tls_err, bad_auth)
1491 {
1492         char buf[128];
1493         int n;
1494
1495         if (self->notls)
1496                 SKIP(return, "no TLS support");
1497
1498         memrnd(buf, sizeof(buf) / 2);
1499         EXPECT_EQ(send(self->fd, buf, sizeof(buf) / 2, 0), sizeof(buf) / 2);
1500         n = recv(self->cfd, buf, sizeof(buf), 0);
1501         EXPECT_GT(n, sizeof(buf) / 2);
1502
1503         buf[n - 1]++;
1504
1505         EXPECT_EQ(send(self->fd2, buf, n, 0), n);
1506         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1507         EXPECT_EQ(errno, EBADMSG);
1508         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1509         EXPECT_EQ(errno, EBADMSG);
1510 }
1511
1512 TEST_F(tls_err, bad_in_large_read)
1513 {
1514         char txt[3][64];
1515         char cip[3][128];
1516         char buf[3 * 128];
1517         int i, n;
1518
1519         if (self->notls)
1520                 SKIP(return, "no TLS support");
1521
1522         /* Put 3 records in the sockets */
1523         for (i = 0; i < 3; i++) {
1524                 memrnd(txt[i], sizeof(txt[i]));
1525                 EXPECT_EQ(send(self->fd, txt[i], sizeof(txt[i]), 0),
1526                           sizeof(txt[i]));
1527                 n = recv(self->cfd, cip[i], sizeof(cip[i]), 0);
1528                 EXPECT_GT(n, sizeof(txt[i]));
1529                 /* Break the third message */
1530                 if (i == 2)
1531                         cip[2][n - 1]++;
1532                 EXPECT_EQ(send(self->fd2, cip[i], n, 0), n);
1533         }
1534
1535         /* We should be able to receive the first two messages */
1536         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt[0]) * 2);
1537         EXPECT_EQ(memcmp(buf, txt[0], sizeof(txt[0])), 0);
1538         EXPECT_EQ(memcmp(buf + sizeof(txt[0]), txt[1], sizeof(txt[1])), 0);
1539         /* Third mesasge is bad */
1540         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1541         EXPECT_EQ(errno, EBADMSG);
1542         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1543         EXPECT_EQ(errno, EBADMSG);
1544 }
1545
1546 TEST_F(tls_err, bad_cmsg)
1547 {
1548         char *test_str = "test_read";
1549         int send_len = 10;
1550         char cip[128];
1551         char buf[128];
1552         char txt[64];
1553         int n;
1554
1555         if (self->notls)
1556                 SKIP(return, "no TLS support");
1557
1558         /* Queue up one data record */
1559         memrnd(txt, sizeof(txt));
1560         EXPECT_EQ(send(self->fd, txt, sizeof(txt), 0), sizeof(txt));
1561         n = recv(self->cfd, cip, sizeof(cip), 0);
1562         EXPECT_GT(n, sizeof(txt));
1563         EXPECT_EQ(send(self->fd2, cip, n, 0), n);
1564
1565         EXPECT_EQ(tls_send_cmsg(self->fd, 100, test_str, send_len, 0), 10);
1566         n = recv(self->cfd, cip, sizeof(cip), 0);
1567         cip[n - 1]++; /* Break it */
1568         EXPECT_GT(n, send_len);
1569         EXPECT_EQ(send(self->fd2, cip, n, 0), n);
1570
1571         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), sizeof(txt));
1572         EXPECT_EQ(memcmp(buf, txt, sizeof(txt)), 0);
1573         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1574         EXPECT_EQ(errno, EBADMSG);
1575         EXPECT_EQ(recv(self->cfd2, buf, sizeof(buf), 0), -1);
1576         EXPECT_EQ(errno, EBADMSG);
1577 }
1578
1579 TEST(non_established) {
1580         struct tls12_crypto_info_aes_gcm_256 tls12;
1581         struct sockaddr_in addr;
1582         int sfd, ret, fd;
1583         socklen_t len;
1584
1585         len = sizeof(addr);
1586
1587         memset(&tls12, 0, sizeof(tls12));
1588         tls12.info.version = TLS_1_2_VERSION;
1589         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1590
1591         addr.sin_family = AF_INET;
1592         addr.sin_addr.s_addr = htonl(INADDR_ANY);
1593         addr.sin_port = 0;
1594
1595         fd = socket(AF_INET, SOCK_STREAM, 0);
1596         sfd = socket(AF_INET, SOCK_STREAM, 0);
1597
1598         ret = bind(sfd, &addr, sizeof(addr));
1599         ASSERT_EQ(ret, 0);
1600         ret = listen(sfd, 10);
1601         ASSERT_EQ(ret, 0);
1602
1603         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1604         EXPECT_EQ(ret, -1);
1605         /* TLS ULP not supported */
1606         if (errno == ENOENT)
1607                 return;
1608         EXPECT_EQ(errno, ENOTCONN);
1609
1610         ret = setsockopt(sfd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1611         EXPECT_EQ(ret, -1);
1612         EXPECT_EQ(errno, ENOTCONN);
1613
1614         ret = getsockname(sfd, &addr, &len);
1615         ASSERT_EQ(ret, 0);
1616
1617         ret = connect(fd, &addr, sizeof(addr));
1618         ASSERT_EQ(ret, 0);
1619
1620         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1621         ASSERT_EQ(ret, 0);
1622
1623         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1624         EXPECT_EQ(ret, -1);
1625         EXPECT_EQ(errno, EEXIST);
1626
1627         close(fd);
1628         close(sfd);
1629 }
1630
1631 TEST(keysizes) {
1632         struct tls12_crypto_info_aes_gcm_256 tls12;
1633         int ret, fd, cfd;
1634         bool notls;
1635
1636         memset(&tls12, 0, sizeof(tls12));
1637         tls12.info.version = TLS_1_2_VERSION;
1638         tls12.info.cipher_type = TLS_CIPHER_AES_GCM_256;
1639
1640         ulp_sock_pair(_metadata, &fd, &cfd, &notls);
1641
1642         if (!notls) {
1643                 ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12,
1644                                  sizeof(tls12));
1645                 EXPECT_EQ(ret, 0);
1646
1647                 ret = setsockopt(cfd, SOL_TLS, TLS_RX, &tls12,
1648                                  sizeof(tls12));
1649                 EXPECT_EQ(ret, 0);
1650         }
1651
1652         close(fd);
1653         close(cfd);
1654 }
1655
1656 TEST(tls_v6ops) {
1657         struct tls_crypto_info_keys tls12;
1658         struct sockaddr_in6 addr, addr2;
1659         int sfd, ret, fd;
1660         socklen_t len, len2;
1661
1662         tls_crypto_info_init(TLS_1_2_VERSION, TLS_CIPHER_AES_GCM_128, &tls12);
1663
1664         addr.sin6_family = AF_INET6;
1665         addr.sin6_addr = in6addr_any;
1666         addr.sin6_port = 0;
1667
1668         fd = socket(AF_INET6, SOCK_STREAM, 0);
1669         sfd = socket(AF_INET6, SOCK_STREAM, 0);
1670
1671         ret = bind(sfd, &addr, sizeof(addr));
1672         ASSERT_EQ(ret, 0);
1673         ret = listen(sfd, 10);
1674         ASSERT_EQ(ret, 0);
1675
1676         len = sizeof(addr);
1677         ret = getsockname(sfd, &addr, &len);
1678         ASSERT_EQ(ret, 0);
1679
1680         ret = connect(fd, &addr, sizeof(addr));
1681         ASSERT_EQ(ret, 0);
1682
1683         len = sizeof(addr);
1684         ret = getsockname(fd, &addr, &len);
1685         ASSERT_EQ(ret, 0);
1686
1687         ret = setsockopt(fd, IPPROTO_TCP, TCP_ULP, "tls", sizeof("tls"));
1688         if (ret) {
1689                 ASSERT_EQ(errno, ENOENT);
1690                 SKIP(return, "no TLS support");
1691         }
1692         ASSERT_EQ(ret, 0);
1693
1694         ret = setsockopt(fd, SOL_TLS, TLS_TX, &tls12, tls12.len);
1695         ASSERT_EQ(ret, 0);
1696
1697         ret = setsockopt(fd, SOL_TLS, TLS_RX, &tls12, tls12.len);
1698         ASSERT_EQ(ret, 0);
1699
1700         len2 = sizeof(addr2);
1701         ret = getsockname(fd, &addr2, &len2);
1702         ASSERT_EQ(ret, 0);
1703
1704         EXPECT_EQ(len2, len);
1705         EXPECT_EQ(memcmp(&addr, &addr2, len), 0);
1706
1707         close(fd);
1708         close(sfd);
1709 }
1710
1711 TEST_HARNESS_MAIN