OSDN Git Service

swiotlb: do not panic on mapping failures
[uclinux-h8/linux.git] / net / tls / tls_sw.c
1 /*
2  * Copyright (c) 2016-2017, Mellanox Technologies. All rights reserved.
3  * Copyright (c) 2016-2017, Dave Watson <davejwatson@fb.com>. All rights reserved.
4  * Copyright (c) 2016-2017, Lance Chao <lancerchao@fb.com>. All rights reserved.
5  * Copyright (c) 2016, Fridolin Pokorny <fridolin.pokorny@gmail.com>. All rights reserved.
6  * Copyright (c) 2016, Nikos Mavrogiannopoulos <nmav@gnutls.org>. All rights reserved.
7  *
8  * This software is available to you under a choice of one of two
9  * licenses.  You may choose to be licensed under the terms of the GNU
10  * General Public License (GPL) Version 2, available from the file
11  * COPYING in the main directory of this source tree, or the
12  * OpenIB.org BSD license below:
13  *
14  *     Redistribution and use in source and binary forms, with or
15  *     without modification, are permitted provided that the following
16  *     conditions are met:
17  *
18  *      - Redistributions of source code must retain the above
19  *        copyright notice, this list of conditions and the following
20  *        disclaimer.
21  *
22  *      - Redistributions in binary form must reproduce the above
23  *        copyright notice, this list of conditions and the following
24  *        disclaimer in the documentation and/or other materials
25  *        provided with the distribution.
26  *
27  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
28  * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
29  * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
30  * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS
31  * BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN
32  * ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
33  * CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
34  * SOFTWARE.
35  */
36
37 #include <linux/sched/signal.h>
38 #include <linux/module.h>
39 #include <crypto/aead.h>
40
41 #include <net/strparser.h>
42 #include <net/tls.h>
43
44 #define MAX_IV_SIZE     TLS_CIPHER_AES_GCM_128_IV_SIZE
45
46 static int tls_do_decryption(struct sock *sk,
47                              struct scatterlist *sgin,
48                              struct scatterlist *sgout,
49                              char *iv_recv,
50                              size_t data_len,
51                              struct aead_request *aead_req)
52 {
53         struct tls_context *tls_ctx = tls_get_ctx(sk);
54         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
55         int ret;
56
57         aead_request_set_tfm(aead_req, ctx->aead_recv);
58         aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
59         aead_request_set_crypt(aead_req, sgin, sgout,
60                                data_len + tls_ctx->rx.tag_size,
61                                (u8 *)iv_recv);
62         aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
63                                   crypto_req_done, &ctx->async_wait);
64
65         ret = crypto_wait_req(crypto_aead_decrypt(aead_req), &ctx->async_wait);
66         return ret;
67 }
68
69 static void trim_sg(struct sock *sk, struct scatterlist *sg,
70                     int *sg_num_elem, unsigned int *sg_size, int target_size)
71 {
72         int i = *sg_num_elem - 1;
73         int trim = *sg_size - target_size;
74
75         if (trim <= 0) {
76                 WARN_ON(trim < 0);
77                 return;
78         }
79
80         *sg_size = target_size;
81         while (trim >= sg[i].length) {
82                 trim -= sg[i].length;
83                 sk_mem_uncharge(sk, sg[i].length);
84                 put_page(sg_page(&sg[i]));
85                 i--;
86
87                 if (i < 0)
88                         goto out;
89         }
90
91         sg[i].length -= trim;
92         sk_mem_uncharge(sk, trim);
93
94 out:
95         *sg_num_elem = i + 1;
96 }
97
98 static void trim_both_sgl(struct sock *sk, int target_size)
99 {
100         struct tls_context *tls_ctx = tls_get_ctx(sk);
101         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
102
103         trim_sg(sk, ctx->sg_plaintext_data,
104                 &ctx->sg_plaintext_num_elem,
105                 &ctx->sg_plaintext_size,
106                 target_size);
107
108         if (target_size > 0)
109                 target_size += tls_ctx->tx.overhead_size;
110
111         trim_sg(sk, ctx->sg_encrypted_data,
112                 &ctx->sg_encrypted_num_elem,
113                 &ctx->sg_encrypted_size,
114                 target_size);
115 }
116
117 static int alloc_encrypted_sg(struct sock *sk, int len)
118 {
119         struct tls_context *tls_ctx = tls_get_ctx(sk);
120         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
121         int rc = 0;
122
123         rc = sk_alloc_sg(sk, len,
124                          ctx->sg_encrypted_data, 0,
125                          &ctx->sg_encrypted_num_elem,
126                          &ctx->sg_encrypted_size, 0);
127
128         return rc;
129 }
130
131 static int alloc_plaintext_sg(struct sock *sk, int len)
132 {
133         struct tls_context *tls_ctx = tls_get_ctx(sk);
134         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
135         int rc = 0;
136
137         rc = sk_alloc_sg(sk, len, ctx->sg_plaintext_data, 0,
138                          &ctx->sg_plaintext_num_elem, &ctx->sg_plaintext_size,
139                          tls_ctx->pending_open_record_frags);
140
141         return rc;
142 }
143
144 static void free_sg(struct sock *sk, struct scatterlist *sg,
145                     int *sg_num_elem, unsigned int *sg_size)
146 {
147         int i, n = *sg_num_elem;
148
149         for (i = 0; i < n; ++i) {
150                 sk_mem_uncharge(sk, sg[i].length);
151                 put_page(sg_page(&sg[i]));
152         }
153         *sg_num_elem = 0;
154         *sg_size = 0;
155 }
156
157 static void tls_free_both_sg(struct sock *sk)
158 {
159         struct tls_context *tls_ctx = tls_get_ctx(sk);
160         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
161
162         free_sg(sk, ctx->sg_encrypted_data, &ctx->sg_encrypted_num_elem,
163                 &ctx->sg_encrypted_size);
164
165         free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
166                 &ctx->sg_plaintext_size);
167 }
168
169 static int tls_do_encryption(struct tls_context *tls_ctx,
170                              struct tls_sw_context_tx *ctx,
171                              struct aead_request *aead_req,
172                              size_t data_len)
173 {
174         int rc;
175
176         ctx->sg_encrypted_data[0].offset += tls_ctx->tx.prepend_size;
177         ctx->sg_encrypted_data[0].length -= tls_ctx->tx.prepend_size;
178
179         aead_request_set_tfm(aead_req, ctx->aead_send);
180         aead_request_set_ad(aead_req, TLS_AAD_SPACE_SIZE);
181         aead_request_set_crypt(aead_req, ctx->sg_aead_in, ctx->sg_aead_out,
182                                data_len, tls_ctx->tx.iv);
183
184         aead_request_set_callback(aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
185                                   crypto_req_done, &ctx->async_wait);
186
187         rc = crypto_wait_req(crypto_aead_encrypt(aead_req), &ctx->async_wait);
188
189         ctx->sg_encrypted_data[0].offset -= tls_ctx->tx.prepend_size;
190         ctx->sg_encrypted_data[0].length += tls_ctx->tx.prepend_size;
191
192         return rc;
193 }
194
195 static int tls_push_record(struct sock *sk, int flags,
196                            unsigned char record_type)
197 {
198         struct tls_context *tls_ctx = tls_get_ctx(sk);
199         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
200         struct aead_request *req;
201         int rc;
202
203         req = aead_request_alloc(ctx->aead_send, sk->sk_allocation);
204         if (!req)
205                 return -ENOMEM;
206
207         sg_mark_end(ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem - 1);
208         sg_mark_end(ctx->sg_encrypted_data + ctx->sg_encrypted_num_elem - 1);
209
210         tls_make_aad(ctx->aad_space, ctx->sg_plaintext_size,
211                      tls_ctx->tx.rec_seq, tls_ctx->tx.rec_seq_size,
212                      record_type);
213
214         tls_fill_prepend(tls_ctx,
215                          page_address(sg_page(&ctx->sg_encrypted_data[0])) +
216                          ctx->sg_encrypted_data[0].offset,
217                          ctx->sg_plaintext_size, record_type);
218
219         tls_ctx->pending_open_record_frags = 0;
220         set_bit(TLS_PENDING_CLOSED_RECORD, &tls_ctx->flags);
221
222         rc = tls_do_encryption(tls_ctx, ctx, req, ctx->sg_plaintext_size);
223         if (rc < 0) {
224                 /* If we are called from write_space and
225                  * we fail, we need to set this SOCK_NOSPACE
226                  * to trigger another write_space in the future.
227                  */
228                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
229                 goto out_req;
230         }
231
232         free_sg(sk, ctx->sg_plaintext_data, &ctx->sg_plaintext_num_elem,
233                 &ctx->sg_plaintext_size);
234
235         ctx->sg_encrypted_num_elem = 0;
236         ctx->sg_encrypted_size = 0;
237
238         /* Only pass through MSG_DONTWAIT and MSG_NOSIGNAL flags */
239         rc = tls_push_sg(sk, tls_ctx, ctx->sg_encrypted_data, 0, flags);
240         if (rc < 0 && rc != -EAGAIN)
241                 tls_err_abort(sk, EBADMSG);
242
243         tls_advance_record_sn(sk, &tls_ctx->tx);
244 out_req:
245         aead_request_free(req);
246         return rc;
247 }
248
249 static int tls_sw_push_pending_record(struct sock *sk, int flags)
250 {
251         return tls_push_record(sk, flags, TLS_RECORD_TYPE_DATA);
252 }
253
254 static int zerocopy_from_iter(struct sock *sk, struct iov_iter *from,
255                               int length, int *pages_used,
256                               unsigned int *size_used,
257                               struct scatterlist *to, int to_max_pages,
258                               bool charge)
259 {
260         struct page *pages[MAX_SKB_FRAGS];
261
262         size_t offset;
263         ssize_t copied, use;
264         int i = 0;
265         unsigned int size = *size_used;
266         int num_elem = *pages_used;
267         int rc = 0;
268         int maxpages;
269
270         while (length > 0) {
271                 i = 0;
272                 maxpages = to_max_pages - num_elem;
273                 if (maxpages == 0) {
274                         rc = -EFAULT;
275                         goto out;
276                 }
277                 copied = iov_iter_get_pages(from, pages,
278                                             length,
279                                             maxpages, &offset);
280                 if (copied <= 0) {
281                         rc = -EFAULT;
282                         goto out;
283                 }
284
285                 iov_iter_advance(from, copied);
286
287                 length -= copied;
288                 size += copied;
289                 while (copied) {
290                         use = min_t(int, copied, PAGE_SIZE - offset);
291
292                         sg_set_page(&to[num_elem],
293                                     pages[i], use, offset);
294                         sg_unmark_end(&to[num_elem]);
295                         if (charge)
296                                 sk_mem_charge(sk, use);
297
298                         offset = 0;
299                         copied -= use;
300
301                         ++i;
302                         ++num_elem;
303                 }
304         }
305
306         /* Mark the end in the last sg entry if newly added */
307         if (num_elem > *pages_used)
308                 sg_mark_end(&to[num_elem - 1]);
309 out:
310         if (rc)
311                 iov_iter_revert(from, size - *size_used);
312         *size_used = size;
313         *pages_used = num_elem;
314
315         return rc;
316 }
317
318 static int memcopy_from_iter(struct sock *sk, struct iov_iter *from,
319                              int bytes)
320 {
321         struct tls_context *tls_ctx = tls_get_ctx(sk);
322         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
323         struct scatterlist *sg = ctx->sg_plaintext_data;
324         int copy, i, rc = 0;
325
326         for (i = tls_ctx->pending_open_record_frags;
327              i < ctx->sg_plaintext_num_elem; ++i) {
328                 copy = sg[i].length;
329                 if (copy_from_iter(
330                                 page_address(sg_page(&sg[i])) + sg[i].offset,
331                                 copy, from) != copy) {
332                         rc = -EFAULT;
333                         goto out;
334                 }
335                 bytes -= copy;
336
337                 ++tls_ctx->pending_open_record_frags;
338
339                 if (!bytes)
340                         break;
341         }
342
343 out:
344         return rc;
345 }
346
347 int tls_sw_sendmsg(struct sock *sk, struct msghdr *msg, size_t size)
348 {
349         struct tls_context *tls_ctx = tls_get_ctx(sk);
350         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
351         int ret = 0;
352         int required_size;
353         long timeo = sock_sndtimeo(sk, msg->msg_flags & MSG_DONTWAIT);
354         bool eor = !(msg->msg_flags & MSG_MORE);
355         size_t try_to_copy, copied = 0;
356         unsigned char record_type = TLS_RECORD_TYPE_DATA;
357         int record_room;
358         bool full_record;
359         int orig_size;
360         bool is_kvec = msg->msg_iter.type & ITER_KVEC;
361
362         if (msg->msg_flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL))
363                 return -ENOTSUPP;
364
365         lock_sock(sk);
366
367         if (tls_complete_pending_work(sk, tls_ctx, msg->msg_flags, &timeo))
368                 goto send_end;
369
370         if (unlikely(msg->msg_controllen)) {
371                 ret = tls_proccess_cmsg(sk, msg, &record_type);
372                 if (ret)
373                         goto send_end;
374         }
375
376         while (msg_data_left(msg)) {
377                 if (sk->sk_err) {
378                         ret = -sk->sk_err;
379                         goto send_end;
380                 }
381
382                 orig_size = ctx->sg_plaintext_size;
383                 full_record = false;
384                 try_to_copy = msg_data_left(msg);
385                 record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
386                 if (try_to_copy >= record_room) {
387                         try_to_copy = record_room;
388                         full_record = true;
389                 }
390
391                 required_size = ctx->sg_plaintext_size + try_to_copy +
392                                 tls_ctx->tx.overhead_size;
393
394                 if (!sk_stream_memory_free(sk))
395                         goto wait_for_sndbuf;
396 alloc_encrypted:
397                 ret = alloc_encrypted_sg(sk, required_size);
398                 if (ret) {
399                         if (ret != -ENOSPC)
400                                 goto wait_for_memory;
401
402                         /* Adjust try_to_copy according to the amount that was
403                          * actually allocated. The difference is due
404                          * to max sg elements limit
405                          */
406                         try_to_copy -= required_size - ctx->sg_encrypted_size;
407                         full_record = true;
408                 }
409                 if (!is_kvec && (full_record || eor)) {
410                         ret = zerocopy_from_iter(sk, &msg->msg_iter,
411                                 try_to_copy, &ctx->sg_plaintext_num_elem,
412                                 &ctx->sg_plaintext_size,
413                                 ctx->sg_plaintext_data,
414                                 ARRAY_SIZE(ctx->sg_plaintext_data),
415                                 true);
416                         if (ret)
417                                 goto fallback_to_reg_send;
418
419                         copied += try_to_copy;
420                         ret = tls_push_record(sk, msg->msg_flags, record_type);
421                         if (ret)
422                                 goto send_end;
423                         continue;
424
425 fallback_to_reg_send:
426                         trim_sg(sk, ctx->sg_plaintext_data,
427                                 &ctx->sg_plaintext_num_elem,
428                                 &ctx->sg_plaintext_size,
429                                 orig_size);
430                 }
431
432                 required_size = ctx->sg_plaintext_size + try_to_copy;
433 alloc_plaintext:
434                 ret = alloc_plaintext_sg(sk, required_size);
435                 if (ret) {
436                         if (ret != -ENOSPC)
437                                 goto wait_for_memory;
438
439                         /* Adjust try_to_copy according to the amount that was
440                          * actually allocated. The difference is due
441                          * to max sg elements limit
442                          */
443                         try_to_copy -= required_size - ctx->sg_plaintext_size;
444                         full_record = true;
445
446                         trim_sg(sk, ctx->sg_encrypted_data,
447                                 &ctx->sg_encrypted_num_elem,
448                                 &ctx->sg_encrypted_size,
449                                 ctx->sg_plaintext_size +
450                                 tls_ctx->tx.overhead_size);
451                 }
452
453                 ret = memcopy_from_iter(sk, &msg->msg_iter, try_to_copy);
454                 if (ret)
455                         goto trim_sgl;
456
457                 copied += try_to_copy;
458                 if (full_record || eor) {
459 push_record:
460                         ret = tls_push_record(sk, msg->msg_flags, record_type);
461                         if (ret) {
462                                 if (ret == -ENOMEM)
463                                         goto wait_for_memory;
464
465                                 goto send_end;
466                         }
467                 }
468
469                 continue;
470
471 wait_for_sndbuf:
472                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
473 wait_for_memory:
474                 ret = sk_stream_wait_memory(sk, &timeo);
475                 if (ret) {
476 trim_sgl:
477                         trim_both_sgl(sk, orig_size);
478                         goto send_end;
479                 }
480
481                 if (tls_is_pending_closed_record(tls_ctx))
482                         goto push_record;
483
484                 if (ctx->sg_encrypted_size < required_size)
485                         goto alloc_encrypted;
486
487                 goto alloc_plaintext;
488         }
489
490 send_end:
491         ret = sk_stream_error(sk, msg->msg_flags, ret);
492
493         release_sock(sk);
494         return copied ? copied : ret;
495 }
496
497 int tls_sw_sendpage(struct sock *sk, struct page *page,
498                     int offset, size_t size, int flags)
499 {
500         struct tls_context *tls_ctx = tls_get_ctx(sk);
501         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
502         int ret = 0;
503         long timeo = sock_sndtimeo(sk, flags & MSG_DONTWAIT);
504         bool eor;
505         size_t orig_size = size;
506         unsigned char record_type = TLS_RECORD_TYPE_DATA;
507         struct scatterlist *sg;
508         bool full_record;
509         int record_room;
510
511         if (flags & ~(MSG_MORE | MSG_DONTWAIT | MSG_NOSIGNAL |
512                       MSG_SENDPAGE_NOTLAST))
513                 return -ENOTSUPP;
514
515         /* No MSG_EOR from splice, only look at MSG_MORE */
516         eor = !(flags & (MSG_MORE | MSG_SENDPAGE_NOTLAST));
517
518         lock_sock(sk);
519
520         sk_clear_bit(SOCKWQ_ASYNC_NOSPACE, sk);
521
522         if (tls_complete_pending_work(sk, tls_ctx, flags, &timeo))
523                 goto sendpage_end;
524
525         /* Call the sk_stream functions to manage the sndbuf mem. */
526         while (size > 0) {
527                 size_t copy, required_size;
528
529                 if (sk->sk_err) {
530                         ret = -sk->sk_err;
531                         goto sendpage_end;
532                 }
533
534                 full_record = false;
535                 record_room = TLS_MAX_PAYLOAD_SIZE - ctx->sg_plaintext_size;
536                 copy = size;
537                 if (copy >= record_room) {
538                         copy = record_room;
539                         full_record = true;
540                 }
541                 required_size = ctx->sg_plaintext_size + copy +
542                               tls_ctx->tx.overhead_size;
543
544                 if (!sk_stream_memory_free(sk))
545                         goto wait_for_sndbuf;
546 alloc_payload:
547                 ret = alloc_encrypted_sg(sk, required_size);
548                 if (ret) {
549                         if (ret != -ENOSPC)
550                                 goto wait_for_memory;
551
552                         /* Adjust copy according to the amount that was
553                          * actually allocated. The difference is due
554                          * to max sg elements limit
555                          */
556                         copy -= required_size - ctx->sg_plaintext_size;
557                         full_record = true;
558                 }
559
560                 get_page(page);
561                 sg = ctx->sg_plaintext_data + ctx->sg_plaintext_num_elem;
562                 sg_set_page(sg, page, copy, offset);
563                 sg_unmark_end(sg);
564
565                 ctx->sg_plaintext_num_elem++;
566
567                 sk_mem_charge(sk, copy);
568                 offset += copy;
569                 size -= copy;
570                 ctx->sg_plaintext_size += copy;
571                 tls_ctx->pending_open_record_frags = ctx->sg_plaintext_num_elem;
572
573                 if (full_record || eor ||
574                     ctx->sg_plaintext_num_elem ==
575                     ARRAY_SIZE(ctx->sg_plaintext_data)) {
576 push_record:
577                         ret = tls_push_record(sk, flags, record_type);
578                         if (ret) {
579                                 if (ret == -ENOMEM)
580                                         goto wait_for_memory;
581
582                                 goto sendpage_end;
583                         }
584                 }
585                 continue;
586 wait_for_sndbuf:
587                 set_bit(SOCK_NOSPACE, &sk->sk_socket->flags);
588 wait_for_memory:
589                 ret = sk_stream_wait_memory(sk, &timeo);
590                 if (ret) {
591                         trim_both_sgl(sk, ctx->sg_plaintext_size);
592                         goto sendpage_end;
593                 }
594
595                 if (tls_is_pending_closed_record(tls_ctx))
596                         goto push_record;
597
598                 goto alloc_payload;
599         }
600
601 sendpage_end:
602         if (orig_size > size)
603                 ret = orig_size - size;
604         else
605                 ret = sk_stream_error(sk, flags, ret);
606
607         release_sock(sk);
608         return ret;
609 }
610
611 static struct sk_buff *tls_wait_data(struct sock *sk, int flags,
612                                      long timeo, int *err)
613 {
614         struct tls_context *tls_ctx = tls_get_ctx(sk);
615         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
616         struct sk_buff *skb;
617         DEFINE_WAIT_FUNC(wait, woken_wake_function);
618
619         while (!(skb = ctx->recv_pkt)) {
620                 if (sk->sk_err) {
621                         *err = sock_error(sk);
622                         return NULL;
623                 }
624
625                 if (sk->sk_shutdown & RCV_SHUTDOWN)
626                         return NULL;
627
628                 if (sock_flag(sk, SOCK_DONE))
629                         return NULL;
630
631                 if ((flags & MSG_DONTWAIT) || !timeo) {
632                         *err = -EAGAIN;
633                         return NULL;
634                 }
635
636                 add_wait_queue(sk_sleep(sk), &wait);
637                 sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
638                 sk_wait_event(sk, &timeo, ctx->recv_pkt != skb, &wait);
639                 sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
640                 remove_wait_queue(sk_sleep(sk), &wait);
641
642                 /* Handle signals */
643                 if (signal_pending(current)) {
644                         *err = sock_intr_errno(timeo);
645                         return NULL;
646                 }
647         }
648
649         return skb;
650 }
651
652 /* This function decrypts the input skb into either out_iov or in out_sg
653  * or in skb buffers itself. The input parameter 'zc' indicates if
654  * zero-copy mode needs to be tried or not. With zero-copy mode, either
655  * out_iov or out_sg must be non-NULL. In case both out_iov and out_sg are
656  * NULL, then the decryption happens inside skb buffers itself, i.e.
657  * zero-copy gets disabled and 'zc' is updated.
658  */
659
660 static int decrypt_internal(struct sock *sk, struct sk_buff *skb,
661                             struct iov_iter *out_iov,
662                             struct scatterlist *out_sg,
663                             int *chunk, bool *zc)
664 {
665         struct tls_context *tls_ctx = tls_get_ctx(sk);
666         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
667         struct strp_msg *rxm = strp_msg(skb);
668         int n_sgin, n_sgout, nsg, mem_size, aead_size, err, pages = 0;
669         struct aead_request *aead_req;
670         struct sk_buff *unused;
671         u8 *aad, *iv, *mem = NULL;
672         struct scatterlist *sgin = NULL;
673         struct scatterlist *sgout = NULL;
674         const int data_len = rxm->full_len - tls_ctx->rx.overhead_size;
675
676         if (*zc && (out_iov || out_sg)) {
677                 if (out_iov)
678                         n_sgout = iov_iter_npages(out_iov, INT_MAX) + 1;
679                 else
680                         n_sgout = sg_nents(out_sg);
681         } else {
682                 n_sgout = 0;
683                 *zc = false;
684         }
685
686         n_sgin = skb_cow_data(skb, 0, &unused);
687         if (n_sgin < 1)
688                 return -EBADMSG;
689
690         /* Increment to accommodate AAD */
691         n_sgin = n_sgin + 1;
692
693         nsg = n_sgin + n_sgout;
694
695         aead_size = sizeof(*aead_req) + crypto_aead_reqsize(ctx->aead_recv);
696         mem_size = aead_size + (nsg * sizeof(struct scatterlist));
697         mem_size = mem_size + TLS_AAD_SPACE_SIZE;
698         mem_size = mem_size + crypto_aead_ivsize(ctx->aead_recv);
699
700         /* Allocate a single block of memory which contains
701          * aead_req || sgin[] || sgout[] || aad || iv.
702          * This order achieves correct alignment for aead_req, sgin, sgout.
703          */
704         mem = kmalloc(mem_size, sk->sk_allocation);
705         if (!mem)
706                 return -ENOMEM;
707
708         /* Segment the allocated memory */
709         aead_req = (struct aead_request *)mem;
710         sgin = (struct scatterlist *)(mem + aead_size);
711         sgout = sgin + n_sgin;
712         aad = (u8 *)(sgout + n_sgout);
713         iv = aad + TLS_AAD_SPACE_SIZE;
714
715         /* Prepare IV */
716         err = skb_copy_bits(skb, rxm->offset + TLS_HEADER_SIZE,
717                             iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
718                             tls_ctx->rx.iv_size);
719         if (err < 0) {
720                 kfree(mem);
721                 return err;
722         }
723         memcpy(iv, tls_ctx->rx.iv, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
724
725         /* Prepare AAD */
726         tls_make_aad(aad, rxm->full_len - tls_ctx->rx.overhead_size,
727                      tls_ctx->rx.rec_seq, tls_ctx->rx.rec_seq_size,
728                      ctx->control);
729
730         /* Prepare sgin */
731         sg_init_table(sgin, n_sgin);
732         sg_set_buf(&sgin[0], aad, TLS_AAD_SPACE_SIZE);
733         err = skb_to_sgvec(skb, &sgin[1],
734                            rxm->offset + tls_ctx->rx.prepend_size,
735                            rxm->full_len - tls_ctx->rx.prepend_size);
736         if (err < 0) {
737                 kfree(mem);
738                 return err;
739         }
740
741         if (n_sgout) {
742                 if (out_iov) {
743                         sg_init_table(sgout, n_sgout);
744                         sg_set_buf(&sgout[0], aad, TLS_AAD_SPACE_SIZE);
745
746                         *chunk = 0;
747                         err = zerocopy_from_iter(sk, out_iov, data_len, &pages,
748                                                  chunk, &sgout[1],
749                                                  (n_sgout - 1), false);
750                         if (err < 0)
751                                 goto fallback_to_reg_recv;
752                 } else if (out_sg) {
753                         memcpy(sgout, out_sg, n_sgout * sizeof(*sgout));
754                 } else {
755                         goto fallback_to_reg_recv;
756                 }
757         } else {
758 fallback_to_reg_recv:
759                 sgout = sgin;
760                 pages = 0;
761                 *chunk = 0;
762                 *zc = false;
763         }
764
765         /* Prepare and submit AEAD request */
766         err = tls_do_decryption(sk, sgin, sgout, iv, data_len, aead_req);
767
768         /* Release the pages in case iov was mapped to pages */
769         for (; pages > 0; pages--)
770                 put_page(sg_page(&sgout[pages]));
771
772         kfree(mem);
773         return err;
774 }
775
776 static int decrypt_skb_update(struct sock *sk, struct sk_buff *skb,
777                               struct iov_iter *dest, int *chunk, bool *zc)
778 {
779         struct tls_context *tls_ctx = tls_get_ctx(sk);
780         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
781         struct strp_msg *rxm = strp_msg(skb);
782         int err = 0;
783
784 #ifdef CONFIG_TLS_DEVICE
785         err = tls_device_decrypted(sk, skb);
786         if (err < 0)
787                 return err;
788 #endif
789         if (!ctx->decrypted) {
790                 err = decrypt_internal(sk, skb, dest, NULL, chunk, zc);
791                 if (err < 0)
792                         return err;
793         } else {
794                 *zc = false;
795         }
796
797         rxm->offset += tls_ctx->rx.prepend_size;
798         rxm->full_len -= tls_ctx->rx.overhead_size;
799         tls_advance_record_sn(sk, &tls_ctx->rx);
800         ctx->decrypted = true;
801         ctx->saved_data_ready(sk);
802
803         return err;
804 }
805
806 int decrypt_skb(struct sock *sk, struct sk_buff *skb,
807                 struct scatterlist *sgout)
808 {
809         bool zc = true;
810         int chunk;
811
812         return decrypt_internal(sk, skb, NULL, sgout, &chunk, &zc);
813 }
814
815 static bool tls_sw_advance_skb(struct sock *sk, struct sk_buff *skb,
816                                unsigned int len)
817 {
818         struct tls_context *tls_ctx = tls_get_ctx(sk);
819         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
820         struct strp_msg *rxm = strp_msg(skb);
821
822         if (len < rxm->full_len) {
823                 rxm->offset += len;
824                 rxm->full_len -= len;
825
826                 return false;
827         }
828
829         /* Finished with message */
830         ctx->recv_pkt = NULL;
831         kfree_skb(skb);
832         __strp_unpause(&ctx->strp);
833
834         return true;
835 }
836
837 int tls_sw_recvmsg(struct sock *sk,
838                    struct msghdr *msg,
839                    size_t len,
840                    int nonblock,
841                    int flags,
842                    int *addr_len)
843 {
844         struct tls_context *tls_ctx = tls_get_ctx(sk);
845         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
846         unsigned char control;
847         struct strp_msg *rxm;
848         struct sk_buff *skb;
849         ssize_t copied = 0;
850         bool cmsg = false;
851         int target, err = 0;
852         long timeo;
853         bool is_kvec = msg->msg_iter.type & ITER_KVEC;
854
855         flags |= nonblock;
856
857         if (unlikely(flags & MSG_ERRQUEUE))
858                 return sock_recv_errqueue(sk, msg, len, SOL_IP, IP_RECVERR);
859
860         lock_sock(sk);
861
862         target = sock_rcvlowat(sk, flags & MSG_WAITALL, len);
863         timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
864         do {
865                 bool zc = false;
866                 int chunk = 0;
867
868                 skb = tls_wait_data(sk, flags, timeo, &err);
869                 if (!skb)
870                         goto recv_end;
871
872                 rxm = strp_msg(skb);
873                 if (!cmsg) {
874                         int cerr;
875
876                         cerr = put_cmsg(msg, SOL_TLS, TLS_GET_RECORD_TYPE,
877                                         sizeof(ctx->control), &ctx->control);
878                         cmsg = true;
879                         control = ctx->control;
880                         if (ctx->control != TLS_RECORD_TYPE_DATA) {
881                                 if (cerr || msg->msg_flags & MSG_CTRUNC) {
882                                         err = -EIO;
883                                         goto recv_end;
884                                 }
885                         }
886                 } else if (control != ctx->control) {
887                         goto recv_end;
888                 }
889
890                 if (!ctx->decrypted) {
891                         int to_copy = rxm->full_len - tls_ctx->rx.overhead_size;
892
893                         if (!is_kvec && to_copy <= len &&
894                             likely(!(flags & MSG_PEEK)))
895                                 zc = true;
896
897                         err = decrypt_skb_update(sk, skb, &msg->msg_iter,
898                                                  &chunk, &zc);
899                         if (err < 0) {
900                                 tls_err_abort(sk, EBADMSG);
901                                 goto recv_end;
902                         }
903                         ctx->decrypted = true;
904                 }
905
906                 if (!zc) {
907                         chunk = min_t(unsigned int, rxm->full_len, len);
908                         err = skb_copy_datagram_msg(skb, rxm->offset, msg,
909                                                     chunk);
910                         if (err < 0)
911                                 goto recv_end;
912                 }
913
914                 copied += chunk;
915                 len -= chunk;
916                 if (likely(!(flags & MSG_PEEK))) {
917                         u8 control = ctx->control;
918
919                         if (tls_sw_advance_skb(sk, skb, chunk)) {
920                                 /* Return full control message to
921                                  * userspace before trying to parse
922                                  * another message type
923                                  */
924                                 msg->msg_flags |= MSG_EOR;
925                                 if (control != TLS_RECORD_TYPE_DATA)
926                                         goto recv_end;
927                         }
928                 }
929                 /* If we have a new message from strparser, continue now. */
930                 if (copied >= target && !ctx->recv_pkt)
931                         break;
932         } while (len);
933
934 recv_end:
935         release_sock(sk);
936         return copied ? : err;
937 }
938
939 ssize_t tls_sw_splice_read(struct socket *sock,  loff_t *ppos,
940                            struct pipe_inode_info *pipe,
941                            size_t len, unsigned int flags)
942 {
943         struct tls_context *tls_ctx = tls_get_ctx(sock->sk);
944         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
945         struct strp_msg *rxm = NULL;
946         struct sock *sk = sock->sk;
947         struct sk_buff *skb;
948         ssize_t copied = 0;
949         int err = 0;
950         long timeo;
951         int chunk;
952         bool zc = false;
953
954         lock_sock(sk);
955
956         timeo = sock_rcvtimeo(sk, flags & MSG_DONTWAIT);
957
958         skb = tls_wait_data(sk, flags, timeo, &err);
959         if (!skb)
960                 goto splice_read_end;
961
962         /* splice does not support reading control messages */
963         if (ctx->control != TLS_RECORD_TYPE_DATA) {
964                 err = -ENOTSUPP;
965                 goto splice_read_end;
966         }
967
968         if (!ctx->decrypted) {
969                 err = decrypt_skb_update(sk, skb, NULL, &chunk, &zc);
970
971                 if (err < 0) {
972                         tls_err_abort(sk, EBADMSG);
973                         goto splice_read_end;
974                 }
975                 ctx->decrypted = true;
976         }
977         rxm = strp_msg(skb);
978
979         chunk = min_t(unsigned int, rxm->full_len, len);
980         copied = skb_splice_bits(skb, sk, rxm->offset, pipe, chunk, flags);
981         if (copied < 0)
982                 goto splice_read_end;
983
984         if (likely(!(flags & MSG_PEEK)))
985                 tls_sw_advance_skb(sk, skb, copied);
986
987 splice_read_end:
988         release_sock(sk);
989         return copied ? : err;
990 }
991
992 unsigned int tls_sw_poll(struct file *file, struct socket *sock,
993                          struct poll_table_struct *wait)
994 {
995         unsigned int ret;
996         struct sock *sk = sock->sk;
997         struct tls_context *tls_ctx = tls_get_ctx(sk);
998         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
999
1000         /* Grab POLLOUT and POLLHUP from the underlying socket */
1001         ret = ctx->sk_poll(file, sock, wait);
1002
1003         /* Clear POLLIN bits, and set based on recv_pkt */
1004         ret &= ~(POLLIN | POLLRDNORM);
1005         if (ctx->recv_pkt)
1006                 ret |= POLLIN | POLLRDNORM;
1007
1008         return ret;
1009 }
1010
1011 static int tls_read_size(struct strparser *strp, struct sk_buff *skb)
1012 {
1013         struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
1014         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1015         char header[TLS_HEADER_SIZE + MAX_IV_SIZE];
1016         struct strp_msg *rxm = strp_msg(skb);
1017         size_t cipher_overhead;
1018         size_t data_len = 0;
1019         int ret;
1020
1021         /* Verify that we have a full TLS header, or wait for more data */
1022         if (rxm->offset + tls_ctx->rx.prepend_size > skb->len)
1023                 return 0;
1024
1025         /* Sanity-check size of on-stack buffer. */
1026         if (WARN_ON(tls_ctx->rx.prepend_size > sizeof(header))) {
1027                 ret = -EINVAL;
1028                 goto read_failure;
1029         }
1030
1031         /* Linearize header to local buffer */
1032         ret = skb_copy_bits(skb, rxm->offset, header, tls_ctx->rx.prepend_size);
1033
1034         if (ret < 0)
1035                 goto read_failure;
1036
1037         ctx->control = header[0];
1038
1039         data_len = ((header[4] & 0xFF) | (header[3] << 8));
1040
1041         cipher_overhead = tls_ctx->rx.tag_size + tls_ctx->rx.iv_size;
1042
1043         if (data_len > TLS_MAX_PAYLOAD_SIZE + cipher_overhead) {
1044                 ret = -EMSGSIZE;
1045                 goto read_failure;
1046         }
1047         if (data_len < cipher_overhead) {
1048                 ret = -EBADMSG;
1049                 goto read_failure;
1050         }
1051
1052         if (header[1] != TLS_VERSION_MINOR(tls_ctx->crypto_recv.version) ||
1053             header[2] != TLS_VERSION_MAJOR(tls_ctx->crypto_recv.version)) {
1054                 ret = -EINVAL;
1055                 goto read_failure;
1056         }
1057
1058 #ifdef CONFIG_TLS_DEVICE
1059         handle_device_resync(strp->sk, TCP_SKB_CB(skb)->seq + rxm->offset,
1060                              *(u64*)tls_ctx->rx.rec_seq);
1061 #endif
1062         return data_len + TLS_HEADER_SIZE;
1063
1064 read_failure:
1065         tls_err_abort(strp->sk, ret);
1066
1067         return ret;
1068 }
1069
1070 static void tls_queue(struct strparser *strp, struct sk_buff *skb)
1071 {
1072         struct tls_context *tls_ctx = tls_get_ctx(strp->sk);
1073         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1074
1075         ctx->decrypted = false;
1076
1077         ctx->recv_pkt = skb;
1078         strp_pause(strp);
1079
1080         ctx->saved_data_ready(strp->sk);
1081 }
1082
1083 static void tls_data_ready(struct sock *sk)
1084 {
1085         struct tls_context *tls_ctx = tls_get_ctx(sk);
1086         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1087
1088         strp_data_ready(&ctx->strp);
1089 }
1090
1091 void tls_sw_free_resources_tx(struct sock *sk)
1092 {
1093         struct tls_context *tls_ctx = tls_get_ctx(sk);
1094         struct tls_sw_context_tx *ctx = tls_sw_ctx_tx(tls_ctx);
1095
1096         crypto_free_aead(ctx->aead_send);
1097         tls_free_both_sg(sk);
1098
1099         kfree(ctx);
1100 }
1101
1102 void tls_sw_release_resources_rx(struct sock *sk)
1103 {
1104         struct tls_context *tls_ctx = tls_get_ctx(sk);
1105         struct tls_sw_context_rx *ctx = tls_sw_ctx_rx(tls_ctx);
1106
1107         if (ctx->aead_recv) {
1108                 kfree_skb(ctx->recv_pkt);
1109                 ctx->recv_pkt = NULL;
1110                 crypto_free_aead(ctx->aead_recv);
1111                 strp_stop(&ctx->strp);
1112                 write_lock_bh(&sk->sk_callback_lock);
1113                 sk->sk_data_ready = ctx->saved_data_ready;
1114                 write_unlock_bh(&sk->sk_callback_lock);
1115                 release_sock(sk);
1116                 strp_done(&ctx->strp);
1117                 lock_sock(sk);
1118         }
1119 }
1120
/*
 * tls_sw_free_resources_rx - release the software RX state through
 * tls_sw_release_resources_rx(), then free the RX context.
 */
void tls_sw_free_resources_rx(struct sock *sk)
{
	struct tls_sw_context_rx *rx_ctx = tls_sw_ctx_rx(tls_get_ctx(sk));

	tls_sw_release_resources_rx(sk);
	kfree(rx_ctx);
}
1130
1131 int tls_set_sw_offload(struct sock *sk, struct tls_context *ctx, int tx)
1132 {
1133         char keyval[TLS_CIPHER_AES_GCM_128_KEY_SIZE];
1134         struct tls_crypto_info *crypto_info;
1135         struct tls12_crypto_info_aes_gcm_128 *gcm_128_info;
1136         struct tls_sw_context_tx *sw_ctx_tx = NULL;
1137         struct tls_sw_context_rx *sw_ctx_rx = NULL;
1138         struct cipher_context *cctx;
1139         struct crypto_aead **aead;
1140         struct strp_callbacks cb;
1141         u16 nonce_size, tag_size, iv_size, rec_seq_size;
1142         char *iv, *rec_seq;
1143         int rc = 0;
1144
1145         if (!ctx) {
1146                 rc = -EINVAL;
1147                 goto out;
1148         }
1149
1150         if (tx) {
1151                 if (!ctx->priv_ctx_tx) {
1152                         sw_ctx_tx = kzalloc(sizeof(*sw_ctx_tx), GFP_KERNEL);
1153                         if (!sw_ctx_tx) {
1154                                 rc = -ENOMEM;
1155                                 goto out;
1156                         }
1157                         ctx->priv_ctx_tx = sw_ctx_tx;
1158                 } else {
1159                         sw_ctx_tx =
1160                                 (struct tls_sw_context_tx *)ctx->priv_ctx_tx;
1161                 }
1162         } else {
1163                 if (!ctx->priv_ctx_rx) {
1164                         sw_ctx_rx = kzalloc(sizeof(*sw_ctx_rx), GFP_KERNEL);
1165                         if (!sw_ctx_rx) {
1166                                 rc = -ENOMEM;
1167                                 goto out;
1168                         }
1169                         ctx->priv_ctx_rx = sw_ctx_rx;
1170                 } else {
1171                         sw_ctx_rx =
1172                                 (struct tls_sw_context_rx *)ctx->priv_ctx_rx;
1173                 }
1174         }
1175
1176         if (tx) {
1177                 crypto_init_wait(&sw_ctx_tx->async_wait);
1178                 crypto_info = &ctx->crypto_send;
1179                 cctx = &ctx->tx;
1180                 aead = &sw_ctx_tx->aead_send;
1181         } else {
1182                 crypto_init_wait(&sw_ctx_rx->async_wait);
1183                 crypto_info = &ctx->crypto_recv;
1184                 cctx = &ctx->rx;
1185                 aead = &sw_ctx_rx->aead_recv;
1186         }
1187
1188         switch (crypto_info->cipher_type) {
1189         case TLS_CIPHER_AES_GCM_128: {
1190                 nonce_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1191                 tag_size = TLS_CIPHER_AES_GCM_128_TAG_SIZE;
1192                 iv_size = TLS_CIPHER_AES_GCM_128_IV_SIZE;
1193                 iv = ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->iv;
1194                 rec_seq_size = TLS_CIPHER_AES_GCM_128_REC_SEQ_SIZE;
1195                 rec_seq =
1196                  ((struct tls12_crypto_info_aes_gcm_128 *)crypto_info)->rec_seq;
1197                 gcm_128_info =
1198                         (struct tls12_crypto_info_aes_gcm_128 *)crypto_info;
1199                 break;
1200         }
1201         default:
1202                 rc = -EINVAL;
1203                 goto free_priv;
1204         }
1205
1206         /* Sanity-check the IV size for stack allocations. */
1207         if (iv_size > MAX_IV_SIZE || nonce_size > MAX_IV_SIZE) {
1208                 rc = -EINVAL;
1209                 goto free_priv;
1210         }
1211
1212         cctx->prepend_size = TLS_HEADER_SIZE + nonce_size;
1213         cctx->tag_size = tag_size;
1214         cctx->overhead_size = cctx->prepend_size + cctx->tag_size;
1215         cctx->iv_size = iv_size;
1216         cctx->iv = kmalloc(iv_size + TLS_CIPHER_AES_GCM_128_SALT_SIZE,
1217                            GFP_KERNEL);
1218         if (!cctx->iv) {
1219                 rc = -ENOMEM;
1220                 goto free_priv;
1221         }
1222         memcpy(cctx->iv, gcm_128_info->salt, TLS_CIPHER_AES_GCM_128_SALT_SIZE);
1223         memcpy(cctx->iv + TLS_CIPHER_AES_GCM_128_SALT_SIZE, iv, iv_size);
1224         cctx->rec_seq_size = rec_seq_size;
1225         cctx->rec_seq = kmemdup(rec_seq, rec_seq_size, GFP_KERNEL);
1226         if (!cctx->rec_seq) {
1227                 rc = -ENOMEM;
1228                 goto free_iv;
1229         }
1230
1231         if (sw_ctx_tx) {
1232                 sg_init_table(sw_ctx_tx->sg_encrypted_data,
1233                               ARRAY_SIZE(sw_ctx_tx->sg_encrypted_data));
1234                 sg_init_table(sw_ctx_tx->sg_plaintext_data,
1235                               ARRAY_SIZE(sw_ctx_tx->sg_plaintext_data));
1236
1237                 sg_init_table(sw_ctx_tx->sg_aead_in, 2);
1238                 sg_set_buf(&sw_ctx_tx->sg_aead_in[0], sw_ctx_tx->aad_space,
1239                            sizeof(sw_ctx_tx->aad_space));
1240                 sg_unmark_end(&sw_ctx_tx->sg_aead_in[1]);
1241                 sg_chain(sw_ctx_tx->sg_aead_in, 2,
1242                          sw_ctx_tx->sg_plaintext_data);
1243                 sg_init_table(sw_ctx_tx->sg_aead_out, 2);
1244                 sg_set_buf(&sw_ctx_tx->sg_aead_out[0], sw_ctx_tx->aad_space,
1245                            sizeof(sw_ctx_tx->aad_space));
1246                 sg_unmark_end(&sw_ctx_tx->sg_aead_out[1]);
1247                 sg_chain(sw_ctx_tx->sg_aead_out, 2,
1248                          sw_ctx_tx->sg_encrypted_data);
1249         }
1250
1251         if (!*aead) {
1252                 *aead = crypto_alloc_aead("gcm(aes)", 0, 0);
1253                 if (IS_ERR(*aead)) {
1254                         rc = PTR_ERR(*aead);
1255                         *aead = NULL;
1256                         goto free_rec_seq;
1257                 }
1258         }
1259
1260         ctx->push_pending_record = tls_sw_push_pending_record;
1261
1262         memcpy(keyval, gcm_128_info->key, TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1263
1264         rc = crypto_aead_setkey(*aead, keyval,
1265                                 TLS_CIPHER_AES_GCM_128_KEY_SIZE);
1266         if (rc)
1267                 goto free_aead;
1268
1269         rc = crypto_aead_setauthsize(*aead, cctx->tag_size);
1270         if (rc)
1271                 goto free_aead;
1272
1273         if (sw_ctx_rx) {
1274                 /* Set up strparser */
1275                 memset(&cb, 0, sizeof(cb));
1276                 cb.rcv_msg = tls_queue;
1277                 cb.parse_msg = tls_read_size;
1278
1279                 strp_init(&sw_ctx_rx->strp, sk, &cb);
1280
1281                 write_lock_bh(&sk->sk_callback_lock);
1282                 sw_ctx_rx->saved_data_ready = sk->sk_data_ready;
1283                 sk->sk_data_ready = tls_data_ready;
1284                 write_unlock_bh(&sk->sk_callback_lock);
1285
1286                 sw_ctx_rx->sk_poll = sk->sk_socket->ops->poll;
1287
1288                 strp_check_rcv(&sw_ctx_rx->strp);
1289         }
1290
1291         goto out;
1292
1293 free_aead:
1294         crypto_free_aead(*aead);
1295         *aead = NULL;
1296 free_rec_seq:
1297         kfree(cctx->rec_seq);
1298         cctx->rec_seq = NULL;
1299 free_iv:
1300         kfree(cctx->iv);
1301         cctx->iv = NULL;
1302 free_priv:
1303         if (tx) {
1304                 kfree(ctx->priv_ctx_tx);
1305                 ctx->priv_ctx_tx = NULL;
1306         } else {
1307                 kfree(ctx->priv_ctx_rx);
1308                 ctx->priv_ctx_rx = NULL;
1309         }
1310 out:
1311         return rc;
1312 }