OSDN Git Service

crypto: hash - prevent using keyed hashes without setting key
[android-x86/kernel.git] / crypto / algif_aead.c
1 /*
2  * algif_aead: User-space interface for AEAD algorithms
3  *
4  * Copyright (C) 2014, Stephan Mueller <smueller@chronox.de>
5  *
6  * This file provides the user-space API for AEAD ciphers.
7  *
8  * This file is derived from algif_skcipher.c.
9  *
10  * This program is free software; you can redistribute it and/or modify it
11  * under the terms of the GNU General Public License as published by the Free
12  * Software Foundation; either version 2 of the License, or (at your option)
13  * any later version.
14  */
15
16 #include <crypto/internal/aead.h>
17 #include <crypto/scatterwalk.h>
18 #include <crypto/if_alg.h>
19 #include <linux/init.h>
20 #include <linux/list.h>
21 #include <linux/kernel.h>
22 #include <linux/mm.h>
23 #include <linux/module.h>
24 #include <linux/net.h>
25 #include <net/sock.h>
26
27 struct aead_sg_list {
28         unsigned int cur;
29         struct scatterlist sg[ALG_MAX_PAGES];
30 };
31
32 struct aead_async_rsgl {
33         struct af_alg_sgl sgl;
34         struct list_head list;
35 };
36
37 struct aead_async_req {
38         struct scatterlist *tsgl;
39         struct aead_async_rsgl first_rsgl;
40         struct list_head list;
41         struct kiocb *iocb;
42         struct sock *sk;
43         unsigned int tsgls;
44         char iv[];
45 };
46
47 struct aead_tfm {
48         struct crypto_aead *aead;
49         bool has_key;
50 };
51
52 struct aead_ctx {
53         struct aead_sg_list tsgl;
54         struct aead_async_rsgl first_rsgl;
55         struct list_head list;
56
57         void *iv;
58
59         struct af_alg_completion completion;
60
61         unsigned long used;
62
63         unsigned int len;
64         bool more;
65         bool merge;
66         bool enc;
67
68         size_t aead_assoclen;
69         struct aead_request aead_req;
70 };
71
72 static inline int aead_sndbuf(struct sock *sk)
73 {
74         struct alg_sock *ask = alg_sk(sk);
75         struct aead_ctx *ctx = ask->private;
76
77         return max_t(int, max_t(int, sk->sk_sndbuf & PAGE_MASK, PAGE_SIZE) -
78                           ctx->used, 0);
79 }
80
81 static inline bool aead_writable(struct sock *sk)
82 {
83         return PAGE_SIZE <= aead_sndbuf(sk);
84 }
85
86 static inline bool aead_sufficient_data(struct aead_ctx *ctx)
87 {
88         unsigned as = crypto_aead_authsize(crypto_aead_reqtfm(&ctx->aead_req));
89
90         /*
91          * The minimum amount of memory needed for an AEAD cipher is
92          * the AAD and in case of decryption the tag.
93          */
94         return ctx->used >= ctx->aead_assoclen + (ctx->enc ? 0 : as);
95 }
96
97 static void aead_reset_ctx(struct aead_ctx *ctx)
98 {
99         struct aead_sg_list *sgl = &ctx->tsgl;
100
101         sg_init_table(sgl->sg, ALG_MAX_PAGES);
102         sgl->cur = 0;
103         ctx->used = 0;
104         ctx->more = 0;
105         ctx->merge = 0;
106 }
107
108 static void aead_put_sgl(struct sock *sk)
109 {
110         struct alg_sock *ask = alg_sk(sk);
111         struct aead_ctx *ctx = ask->private;
112         struct aead_sg_list *sgl = &ctx->tsgl;
113         struct scatterlist *sg = sgl->sg;
114         unsigned int i;
115
116         for (i = 0; i < sgl->cur; i++) {
117                 if (!sg_page(sg + i))
118                         continue;
119
120                 put_page(sg_page(sg + i));
121                 sg_assign_page(sg + i, NULL);
122         }
123         aead_reset_ctx(ctx);
124 }
125
126 static void aead_wmem_wakeup(struct sock *sk)
127 {
128         struct socket_wq *wq;
129
130         if (!aead_writable(sk))
131                 return;
132
133         rcu_read_lock();
134         wq = rcu_dereference(sk->sk_wq);
135         if (skwq_has_sleeper(wq))
136                 wake_up_interruptible_sync_poll(&wq->wait, POLLIN |
137                                                            POLLRDNORM |
138                                                            POLLRDBAND);
139         sk_wake_async(sk, SOCK_WAKE_WAITD, POLL_IN);
140         rcu_read_unlock();
141 }
142
143 static int aead_wait_for_data(struct sock *sk, unsigned flags)
144 {
145         struct alg_sock *ask = alg_sk(sk);
146         struct aead_ctx *ctx = ask->private;
147         long timeout;
148         DEFINE_WAIT(wait);
149         int err = -ERESTARTSYS;
150
151         if (flags & MSG_DONTWAIT)
152                 return -EAGAIN;
153
154         sk_set_bit(SOCKWQ_ASYNC_WAITDATA, sk);
155
156         for (;;) {
157                 if (signal_pending(current))
158                         break;
159                 prepare_to_wait(sk_sleep(sk), &wait, TASK_INTERRUPTIBLE);
160                 timeout = MAX_SCHEDULE_TIMEOUT;
161                 if (sk_wait_event(sk, &timeout, !ctx->more)) {
162                         err = 0;
163                         break;
164                 }
165         }
166         finish_wait(sk_sleep(sk), &wait);
167
168         sk_clear_bit(SOCKWQ_ASYNC_WAITDATA, sk);
169
170         return err;
171 }
172
173 static void aead_data_wakeup(struct sock *sk)
174 {
175         struct alg_sock *ask = alg_sk(sk);
176         struct aead_ctx *ctx = ask->private;
177         struct socket_wq *wq;
178
179         if (ctx->more)
180                 return;
181         if (!ctx->used)
182                 return;
183
184         rcu_read_lock();
185         wq = rcu_dereference(sk->sk_wq);
186         if (skwq_has_sleeper(wq))
187                 wake_up_interruptible_sync_poll(&wq->wait, POLLOUT |
188                                                            POLLRDNORM |
189                                                            POLLRDBAND);
190         sk_wake_async(sk, SOCK_WAKE_SPACE, POLL_OUT);
191         rcu_read_unlock();
192 }
193
194 static int aead_sendmsg(struct socket *sock, struct msghdr *msg, size_t size)
195 {
196         struct sock *sk = sock->sk;
197         struct alg_sock *ask = alg_sk(sk);
198         struct aead_ctx *ctx = ask->private;
199         unsigned ivsize =
200                 crypto_aead_ivsize(crypto_aead_reqtfm(&ctx->aead_req));
201         struct aead_sg_list *sgl = &ctx->tsgl;
202         struct af_alg_control con = {};
203         long copied = 0;
204         bool enc = 0;
205         bool init = 0;
206         int err = -EINVAL;
207
208         if (msg->msg_controllen) {
209                 err = af_alg_cmsg_send(msg, &con);
210                 if (err)
211                         return err;
212
213                 init = 1;
214                 switch (con.op) {
215                 case ALG_OP_ENCRYPT:
216                         enc = 1;
217                         break;
218                 case ALG_OP_DECRYPT:
219                         enc = 0;
220                         break;
221                 default:
222                         return -EINVAL;
223                 }
224
225                 if (con.iv && con.iv->ivlen != ivsize)
226                         return -EINVAL;
227         }
228
229         lock_sock(sk);
230         if (!ctx->more && ctx->used)
231                 goto unlock;
232
233         if (init) {
234                 ctx->enc = enc;
235                 if (con.iv)
236                         memcpy(ctx->iv, con.iv->iv, ivsize);
237
238                 ctx->aead_assoclen = con.aead_assoclen;
239         }
240
241         while (size) {
242                 size_t len = size;
243                 struct scatterlist *sg = NULL;
244
245                 /* use the existing memory in an allocated page */
246                 if (ctx->merge) {
247                         sg = sgl->sg + sgl->cur - 1;
248                         len = min_t(unsigned long, len,
249                                     PAGE_SIZE - sg->offset - sg->length);
250                         err = memcpy_from_msg(page_address(sg_page(sg)) +
251                                               sg->offset + sg->length,
252                                               msg, len);
253                         if (err)
254                                 goto unlock;
255
256                         sg->length += len;
257                         ctx->merge = (sg->offset + sg->length) &
258                                      (PAGE_SIZE - 1);
259
260                         ctx->used += len;
261                         copied += len;
262                         size -= len;
263                         continue;
264                 }
265
266                 if (!aead_writable(sk)) {
267                         /* user space sent too much data */
268                         aead_put_sgl(sk);
269                         err = -EMSGSIZE;
270                         goto unlock;
271                 }
272
273                 /* allocate a new page */
274                 len = min_t(unsigned long, size, aead_sndbuf(sk));
275                 while (len) {
276                         size_t plen = 0;
277
278                         if (sgl->cur >= ALG_MAX_PAGES) {
279                                 aead_put_sgl(sk);
280                                 err = -E2BIG;
281                                 goto unlock;
282                         }
283
284                         sg = sgl->sg + sgl->cur;
285                         plen = min_t(size_t, len, PAGE_SIZE);
286
287                         sg_assign_page(sg, alloc_page(GFP_KERNEL));
288                         err = -ENOMEM;
289                         if (!sg_page(sg))
290                                 goto unlock;
291
292                         err = memcpy_from_msg(page_address(sg_page(sg)),
293                                               msg, plen);
294                         if (err) {
295                                 __free_page(sg_page(sg));
296                                 sg_assign_page(sg, NULL);
297                                 goto unlock;
298                         }
299
300                         sg->offset = 0;
301                         sg->length = plen;
302                         len -= plen;
303                         ctx->used += plen;
304                         copied += plen;
305                         sgl->cur++;
306                         size -= plen;
307                         ctx->merge = plen & (PAGE_SIZE - 1);
308                 }
309         }
310
311         err = 0;
312
313         ctx->more = msg->msg_flags & MSG_MORE;
314         if (!ctx->more && !aead_sufficient_data(ctx)) {
315                 aead_put_sgl(sk);
316                 err = -EMSGSIZE;
317         }
318
319 unlock:
320         aead_data_wakeup(sk);
321         release_sock(sk);
322
323         return err ?: copied;
324 }
325
326 static ssize_t aead_sendpage(struct socket *sock, struct page *page,
327                              int offset, size_t size, int flags)
328 {
329         struct sock *sk = sock->sk;
330         struct alg_sock *ask = alg_sk(sk);
331         struct aead_ctx *ctx = ask->private;
332         struct aead_sg_list *sgl = &ctx->tsgl;
333         int err = -EINVAL;
334
335         if (flags & MSG_SENDPAGE_NOTLAST)
336                 flags |= MSG_MORE;
337
338         if (sgl->cur >= ALG_MAX_PAGES)
339                 return -E2BIG;
340
341         lock_sock(sk);
342         if (!ctx->more && ctx->used)
343                 goto unlock;
344
345         if (!size)
346                 goto done;
347
348         if (!aead_writable(sk)) {
349                 /* user space sent too much data */
350                 aead_put_sgl(sk);
351                 err = -EMSGSIZE;
352                 goto unlock;
353         }
354
355         ctx->merge = 0;
356
357         get_page(page);
358         sg_set_page(sgl->sg + sgl->cur, page, size, offset);
359         sgl->cur++;
360         ctx->used += size;
361
362         err = 0;
363
364 done:
365         ctx->more = flags & MSG_MORE;
366         if (!ctx->more && !aead_sufficient_data(ctx)) {
367                 aead_put_sgl(sk);
368                 err = -EMSGSIZE;
369         }
370
371 unlock:
372         aead_data_wakeup(sk);
373         release_sock(sk);
374
375         return err ?: size;
376 }
377
378 #define GET_ASYM_REQ(req, tfm) (struct aead_async_req *) \
379                 ((char *)req + sizeof(struct aead_request) + \
380                  crypto_aead_reqsize(tfm))
381
382  #define GET_REQ_SIZE(tfm) sizeof(struct aead_async_req) + \
383         crypto_aead_reqsize(tfm) + crypto_aead_ivsize(tfm) + \
384         sizeof(struct aead_request)
385
386 static void aead_async_cb(struct crypto_async_request *_req, int err)
387 {
388         struct aead_request *req = _req->data;
389         struct crypto_aead *tfm = crypto_aead_reqtfm(req);
390         struct aead_async_req *areq = GET_ASYM_REQ(req, tfm);
391         struct sock *sk = areq->sk;
392         struct scatterlist *sg = areq->tsgl;
393         struct aead_async_rsgl *rsgl;
394         struct kiocb *iocb = areq->iocb;
395         unsigned int i, reqlen = GET_REQ_SIZE(tfm);
396
397         list_for_each_entry(rsgl, &areq->list, list) {
398                 af_alg_free_sg(&rsgl->sgl);
399                 if (rsgl != &areq->first_rsgl)
400                         sock_kfree_s(sk, rsgl, sizeof(*rsgl));
401         }
402
403         for (i = 0; i < areq->tsgls; i++)
404                 put_page(sg_page(sg + i));
405
406         sock_kfree_s(sk, areq->tsgl, sizeof(*areq->tsgl) * areq->tsgls);
407         sock_kfree_s(sk, req, reqlen);
408         __sock_put(sk);
409         iocb->ki_complete(iocb, err, err);
410 }
411
412 static int aead_recvmsg_async(struct socket *sock, struct msghdr *msg,
413                               int flags)
414 {
415         struct sock *sk = sock->sk;
416         struct alg_sock *ask = alg_sk(sk);
417         struct aead_ctx *ctx = ask->private;
418         struct crypto_aead *tfm = crypto_aead_reqtfm(&ctx->aead_req);
419         struct aead_async_req *areq;
420         struct aead_request *req = NULL;
421         struct aead_sg_list *sgl = &ctx->tsgl;
422         struct aead_async_rsgl *last_rsgl = NULL, *rsgl;
423         unsigned int as = crypto_aead_authsize(tfm);
424         unsigned int i, reqlen = GET_REQ_SIZE(tfm);
425         int err = -ENOMEM;
426         unsigned long used;
427         size_t outlen = 0;
428         size_t usedpages = 0;
429
430         lock_sock(sk);
431         if (ctx->more) {
432                 err = aead_wait_for_data(sk, flags);
433                 if (err)
434                         goto unlock;
435         }
436
437         if (!aead_sufficient_data(ctx))
438                 goto unlock;
439
440         used = ctx->used;
441         if (ctx->enc)
442                 outlen = used + as;
443         else
444                 outlen = used - as;
445
446         req = sock_kmalloc(sk, reqlen, GFP_KERNEL);
447         if (unlikely(!req))
448                 goto unlock;
449
450         areq = GET_ASYM_REQ(req, tfm);
451         memset(&areq->first_rsgl, '\0', sizeof(areq->first_rsgl));
452         INIT_LIST_HEAD(&areq->list);
453         areq->iocb = msg->msg_iocb;
454         areq->sk = sk;
455         memcpy(areq->iv, ctx->iv, crypto_aead_ivsize(tfm));
456         aead_request_set_tfm(req, tfm);
457         aead_request_set_ad(req, ctx->aead_assoclen);
458         aead_request_set_callback(req, CRYPTO_TFM_REQ_MAY_BACKLOG,
459                                   aead_async_cb, req);
460         used -= ctx->aead_assoclen;
461
462         /* take over all tx sgls from ctx */
463         areq->tsgl = sock_kmalloc(sk, sizeof(*areq->tsgl) * sgl->cur,
464                                   GFP_KERNEL);
465         if (unlikely(!areq->tsgl))
466                 goto free;
467
468         sg_init_table(areq->tsgl, sgl->cur);
469         for (i = 0; i < sgl->cur; i++)
470                 sg_set_page(&areq->tsgl[i], sg_page(&sgl->sg[i]),
471                             sgl->sg[i].length, sgl->sg[i].offset);
472
473         areq->tsgls = sgl->cur;
474
475         /* create rx sgls */
476         while (outlen > usedpages && iov_iter_count(&msg->msg_iter)) {
477                 size_t seglen = min_t(size_t, iov_iter_count(&msg->msg_iter),
478                                       (outlen - usedpages));
479
480                 if (list_empty(&areq->list)) {
481                         rsgl = &areq->first_rsgl;
482
483                 } else {
484                         rsgl = sock_kmalloc(sk, sizeof(*rsgl), GFP_KERNEL);
485                         if (unlikely(!rsgl)) {
486                                 err = -ENOMEM;
487                                 goto free;
488                         }
489                 }
490                 rsgl->sgl.npages = 0;
491                 list_add_tail(&rsgl->list, &areq->list);
492
493                 /* make one iovec available as scatterlist */
494                 err = af_alg_make_sg(&rsgl->sgl, &msg->msg_iter, seglen);
495                 if (err < 0)
496                         goto free;
497
498                 usedpages += err;
499
500                 /* chain the new scatterlist with previous one */
501                 if (last_rsgl)
502                         af_alg_link_sg(&last_rsgl->sgl, &rsgl->sgl);
503
504                 last_rsgl = rsgl;
505
506                 iov_iter_advance(&msg->msg_iter, err);
507         }
508
509         /* ensure output buffer is sufficiently large */
510         if (usedpages < outlen) {
511                 err = -EINVAL;
512                 goto unlock;
513         }
514
515         aead_request_set_crypt(req, areq->tsgl, areq->first_rsgl.sgl.sg, used,
516                                areq->iv);
517         err = ctx->enc ? crypto_aead_encrypt(req) : crypto_aead_decrypt(req);
518         if (err) {
519                 if (err == -EINPROGRESS) {
520                         sock_hold(sk);
521                         err = -EIOCBQUEUED;
522                         aead_reset_ctx(ctx);
523                         goto unlock;
524                 } else if (err == -EBADMSG) {
525                         aead_put_sgl(sk);
526                 }
527                 goto free;
528         }
529         aead_put_sgl(sk);
530
531 free:
532         list_for_each_entry(rsgl, &areq->list, list) {
533                 af_alg_free_sg(&rsgl->sgl);
534                 if (rsgl != &areq->first_rsgl)
535                         sock_kfree_s(sk, rsgl, sizeof(*rsgl));
536         }
537         if (areq->tsgl)
538                 sock_kfree_s(sk, areq->tsgl, sizeof(*areq->tsgl) * areq->tsgls);
539         if (req)
540                 sock_kfree_s(sk, req, reqlen);
541 unlock:
542         aead_wmem_wakeup(sk);
543         release_sock(sk);
544         return err ? err : outlen;
545 }
546
547 static int aead_recvmsg_sync(struct socket *sock, struct msghdr *msg, int flags)
548 {
549         struct sock *sk = sock->sk;
550         struct alg_sock *ask = alg_sk(sk);
551         struct aead_ctx *ctx = ask->private;
552         unsigned as = crypto_aead_authsize(crypto_aead_reqtfm(&ctx->aead_req));
553         struct aead_sg_list *sgl = &ctx->tsgl;
554         struct aead_async_rsgl *last_rsgl = NULL;
555         struct aead_async_rsgl *rsgl, *tmp;
556         int err = -EINVAL;
557         unsigned long used = 0;
558         size_t outlen = 0;
559         size_t usedpages = 0;
560
561         lock_sock(sk);
562
563         /*
564          * AEAD memory structure: For encryption, the tag is appended to the
565          * ciphertext which implies that the memory allocated for the ciphertext
566          * must be increased by the tag length. For decryption, the tag
567          * is expected to be concatenated to the ciphertext. The plaintext
568          * therefore has a memory size of the ciphertext minus the tag length.
569          *
570          * The memory structure for cipher operation has the following
571          * structure:
572          *      AEAD encryption input:  assoc data || plaintext
573          *      AEAD encryption output: cipherntext || auth tag
574          *      AEAD decryption input:  assoc data || ciphertext || auth tag
575          *      AEAD decryption output: plaintext
576          */
577
578         if (ctx->more) {
579                 err = aead_wait_for_data(sk, flags);
580                 if (err)
581                         goto unlock;
582         }
583
584         /* data length provided by caller via sendmsg/sendpage */
585         used = ctx->used;
586
587         /*
588          * Make sure sufficient data is present -- note, the same check is
589          * is also present in sendmsg/sendpage. The checks in sendpage/sendmsg
590          * shall provide an information to the data sender that something is
591          * wrong, but they are irrelevant to maintain the kernel integrity.
592          * We need this check here too in case user space decides to not honor
593          * the error message in sendmsg/sendpage and still call recvmsg. This
594          * check here protects the kernel integrity.
595          */
596         if (!aead_sufficient_data(ctx))
597                 goto unlock;
598
599         /*
600          * Calculate the minimum output buffer size holding the result of the
601          * cipher operation. When encrypting data, the receiving buffer is
602          * larger by the tag length compared to the input buffer as the
603          * encryption operation generates the tag. For decryption, the input
604          * buffer provides the tag which is consumed resulting in only the
605          * plaintext without a buffer for the tag returned to the caller.
606          */
607         if (ctx->enc)
608                 outlen = used + as;
609         else
610                 outlen = used - as;
611
612         /*
613          * The cipher operation input data is reduced by the associated data
614          * length as this data is processed separately later on.
615          */
616         used -= ctx->aead_assoclen;
617
618         /* convert iovecs of output buffers into scatterlists */
619         while (outlen > usedpages && iov_iter_count(&msg->msg_iter)) {
620                 size_t seglen = min_t(size_t, iov_iter_count(&msg->msg_iter),
621                                       (outlen - usedpages));
622
623                 if (list_empty(&ctx->list)) {
624                         rsgl = &ctx->first_rsgl;
625                 } else {
626                         rsgl = sock_kmalloc(sk, sizeof(*rsgl), GFP_KERNEL);
627                         if (unlikely(!rsgl)) {
628                                 err = -ENOMEM;
629                                 goto unlock;
630                         }
631                 }
632                 rsgl->sgl.npages = 0;
633                 list_add_tail(&rsgl->list, &ctx->list);
634
635                 /* make one iovec available as scatterlist */
636                 err = af_alg_make_sg(&rsgl->sgl, &msg->msg_iter, seglen);
637                 if (err < 0)
638                         goto unlock;
639                 usedpages += err;
640                 /* chain the new scatterlist with previous one */
641                 if (last_rsgl)
642                         af_alg_link_sg(&last_rsgl->sgl, &rsgl->sgl);
643
644                 last_rsgl = rsgl;
645
646                 iov_iter_advance(&msg->msg_iter, err);
647         }
648
649         /* ensure output buffer is sufficiently large */
650         if (usedpages < outlen) {
651                 err = -EINVAL;
652                 goto unlock;
653         }
654
655         sg_mark_end(sgl->sg + sgl->cur - 1);
656         aead_request_set_crypt(&ctx->aead_req, sgl->sg, ctx->first_rsgl.sgl.sg,
657                                used, ctx->iv);
658         aead_request_set_ad(&ctx->aead_req, ctx->aead_assoclen);
659
660         err = af_alg_wait_for_completion(ctx->enc ?
661                                          crypto_aead_encrypt(&ctx->aead_req) :
662                                          crypto_aead_decrypt(&ctx->aead_req),
663                                          &ctx->completion);
664
665         if (err) {
666                 /* EBADMSG implies a valid cipher operation took place */
667                 if (err == -EBADMSG)
668                         aead_put_sgl(sk);
669
670                 goto unlock;
671         }
672
673         aead_put_sgl(sk);
674         err = 0;
675
676 unlock:
677         list_for_each_entry_safe(rsgl, tmp, &ctx->list, list) {
678                 af_alg_free_sg(&rsgl->sgl);
679                 list_del(&rsgl->list);
680                 if (rsgl != &ctx->first_rsgl)
681                         sock_kfree_s(sk, rsgl, sizeof(*rsgl));
682         }
683         INIT_LIST_HEAD(&ctx->list);
684         aead_wmem_wakeup(sk);
685         release_sock(sk);
686
687         return err ? err : outlen;
688 }
689
690 static int aead_recvmsg(struct socket *sock, struct msghdr *msg, size_t ignored,
691                         int flags)
692 {
693         return (msg->msg_iocb && !is_sync_kiocb(msg->msg_iocb)) ?
694                 aead_recvmsg_async(sock, msg, flags) :
695                 aead_recvmsg_sync(sock, msg, flags);
696 }
697
698 static unsigned int aead_poll(struct file *file, struct socket *sock,
699                               poll_table *wait)
700 {
701         struct sock *sk = sock->sk;
702         struct alg_sock *ask = alg_sk(sk);
703         struct aead_ctx *ctx = ask->private;
704         unsigned int mask;
705
706         sock_poll_wait(file, sk_sleep(sk), wait);
707         mask = 0;
708
709         if (!ctx->more)
710                 mask |= POLLIN | POLLRDNORM;
711
712         if (aead_writable(sk))
713                 mask |= POLLOUT | POLLWRNORM | POLLWRBAND;
714
715         return mask;
716 }
717
718 static struct proto_ops algif_aead_ops = {
719         .family         =       PF_ALG,
720
721         .connect        =       sock_no_connect,
722         .socketpair     =       sock_no_socketpair,
723         .getname        =       sock_no_getname,
724         .ioctl          =       sock_no_ioctl,
725         .listen         =       sock_no_listen,
726         .shutdown       =       sock_no_shutdown,
727         .getsockopt     =       sock_no_getsockopt,
728         .mmap           =       sock_no_mmap,
729         .bind           =       sock_no_bind,
730         .accept         =       sock_no_accept,
731         .setsockopt     =       sock_no_setsockopt,
732
733         .release        =       af_alg_release,
734         .sendmsg        =       aead_sendmsg,
735         .sendpage       =       aead_sendpage,
736         .recvmsg        =       aead_recvmsg,
737         .poll           =       aead_poll,
738 };
739
740 static int aead_check_key(struct socket *sock)
741 {
742         int err = 0;
743         struct sock *psk;
744         struct alg_sock *pask;
745         struct aead_tfm *tfm;
746         struct sock *sk = sock->sk;
747         struct alg_sock *ask = alg_sk(sk);
748
749         lock_sock(sk);
750         if (ask->refcnt)
751                 goto unlock_child;
752
753         psk = ask->parent;
754         pask = alg_sk(ask->parent);
755         tfm = pask->private;
756
757         err = -ENOKEY;
758         lock_sock_nested(psk, SINGLE_DEPTH_NESTING);
759         if (!tfm->has_key)
760                 goto unlock;
761
762         if (!pask->refcnt++)
763                 sock_hold(psk);
764
765         ask->refcnt = 1;
766         sock_put(psk);
767
768         err = 0;
769
770 unlock:
771         release_sock(psk);
772 unlock_child:
773         release_sock(sk);
774
775         return err;
776 }
777
778 static int aead_sendmsg_nokey(struct socket *sock, struct msghdr *msg,
779                                   size_t size)
780 {
781         int err;
782
783         err = aead_check_key(sock);
784         if (err)
785                 return err;
786
787         return aead_sendmsg(sock, msg, size);
788 }
789
790 static ssize_t aead_sendpage_nokey(struct socket *sock, struct page *page,
791                                        int offset, size_t size, int flags)
792 {
793         int err;
794
795         err = aead_check_key(sock);
796         if (err)
797                 return err;
798
799         return aead_sendpage(sock, page, offset, size, flags);
800 }
801
802 static int aead_recvmsg_nokey(struct socket *sock, struct msghdr *msg,
803                                   size_t ignored, int flags)
804 {
805         int err;
806
807         err = aead_check_key(sock);
808         if (err)
809                 return err;
810
811         return aead_recvmsg(sock, msg, ignored, flags);
812 }
813
814 static struct proto_ops algif_aead_ops_nokey = {
815         .family         =       PF_ALG,
816
817         .connect        =       sock_no_connect,
818         .socketpair     =       sock_no_socketpair,
819         .getname        =       sock_no_getname,
820         .ioctl          =       sock_no_ioctl,
821         .listen         =       sock_no_listen,
822         .shutdown       =       sock_no_shutdown,
823         .getsockopt     =       sock_no_getsockopt,
824         .mmap           =       sock_no_mmap,
825         .bind           =       sock_no_bind,
826         .accept         =       sock_no_accept,
827         .setsockopt     =       sock_no_setsockopt,
828
829         .release        =       af_alg_release,
830         .sendmsg        =       aead_sendmsg_nokey,
831         .sendpage       =       aead_sendpage_nokey,
832         .recvmsg        =       aead_recvmsg_nokey,
833         .poll           =       aead_poll,
834 };
835
836 static void *aead_bind(const char *name, u32 type, u32 mask)
837 {
838         struct aead_tfm *tfm;
839         struct crypto_aead *aead;
840
841         tfm = kzalloc(sizeof(*tfm), GFP_KERNEL);
842         if (!tfm)
843                 return ERR_PTR(-ENOMEM);
844
845         aead = crypto_alloc_aead(name, type, mask);
846         if (IS_ERR(aead)) {
847                 kfree(tfm);
848                 return ERR_CAST(aead);
849         }
850
851         tfm->aead = aead;
852
853         return tfm;
854 }
855
856 static void aead_release(void *private)
857 {
858         struct aead_tfm *tfm = private;
859
860         crypto_free_aead(tfm->aead);
861         kfree(tfm);
862 }
863
864 static int aead_setauthsize(void *private, unsigned int authsize)
865 {
866         struct aead_tfm *tfm = private;
867
868         return crypto_aead_setauthsize(tfm->aead, authsize);
869 }
870
871 static int aead_setkey(void *private, const u8 *key, unsigned int keylen)
872 {
873         struct aead_tfm *tfm = private;
874         int err;
875
876         err = crypto_aead_setkey(tfm->aead, key, keylen);
877         tfm->has_key = !err;
878
879         return err;
880 }
881
882 static void aead_sock_destruct(struct sock *sk)
883 {
884         struct alg_sock *ask = alg_sk(sk);
885         struct aead_ctx *ctx = ask->private;
886         unsigned int ivlen = crypto_aead_ivsize(
887                                 crypto_aead_reqtfm(&ctx->aead_req));
888
889         WARN_ON(atomic_read(&sk->sk_refcnt) != 0);
890         aead_put_sgl(sk);
891         sock_kzfree_s(sk, ctx->iv, ivlen);
892         sock_kfree_s(sk, ctx, ctx->len);
893         af_alg_release_parent(sk);
894 }
895
896 static int aead_accept_parent_nokey(void *private, struct sock *sk)
897 {
898         struct aead_ctx *ctx;
899         struct alg_sock *ask = alg_sk(sk);
900         struct aead_tfm *tfm = private;
901         struct crypto_aead *aead = tfm->aead;
902         unsigned int len = sizeof(*ctx) + crypto_aead_reqsize(aead);
903         unsigned int ivlen = crypto_aead_ivsize(aead);
904
905         ctx = sock_kmalloc(sk, len, GFP_KERNEL);
906         if (!ctx)
907                 return -ENOMEM;
908         memset(ctx, 0, len);
909
910         ctx->iv = sock_kmalloc(sk, ivlen, GFP_KERNEL);
911         if (!ctx->iv) {
912                 sock_kfree_s(sk, ctx, len);
913                 return -ENOMEM;
914         }
915         memset(ctx->iv, 0, ivlen);
916
917         ctx->len = len;
918         ctx->used = 0;
919         ctx->more = 0;
920         ctx->merge = 0;
921         ctx->enc = 0;
922         ctx->tsgl.cur = 0;
923         ctx->aead_assoclen = 0;
924         af_alg_init_completion(&ctx->completion);
925         sg_init_table(ctx->tsgl.sg, ALG_MAX_PAGES);
926         INIT_LIST_HEAD(&ctx->list);
927
928         ask->private = ctx;
929
930         aead_request_set_tfm(&ctx->aead_req, aead);
931         aead_request_set_callback(&ctx->aead_req, CRYPTO_TFM_REQ_MAY_BACKLOG,
932                                   af_alg_complete, &ctx->completion);
933
934         sk->sk_destruct = aead_sock_destruct;
935
936         return 0;
937 }
938
939 static int aead_accept_parent(void *private, struct sock *sk)
940 {
941         struct aead_tfm *tfm = private;
942
943         if (!tfm->has_key)
944                 return -ENOKEY;
945
946         return aead_accept_parent_nokey(private, sk);
947 }
948
949 static const struct af_alg_type algif_type_aead = {
950         .bind           =       aead_bind,
951         .release        =       aead_release,
952         .setkey         =       aead_setkey,
953         .setauthsize    =       aead_setauthsize,
954         .accept         =       aead_accept_parent,
955         .accept_nokey   =       aead_accept_parent_nokey,
956         .ops            =       &algif_aead_ops,
957         .ops_nokey      =       &algif_aead_ops_nokey,
958         .name           =       "aead",
959         .owner          =       THIS_MODULE
960 };
961
962 static int __init algif_aead_init(void)
963 {
964         return af_alg_register_type(&algif_type_aead);
965 }
966
967 static void __exit algif_aead_exit(void)
968 {
969         int err = af_alg_unregister_type(&algif_type_aead);
970         BUG_ON(err);
971 }
972
973 module_init(algif_aead_init);
974 module_exit(algif_aead_exit);
975 MODULE_LICENSE("GPL");
976 MODULE_AUTHOR("Stephan Mueller <smueller@chronox.de>");
977 MODULE_DESCRIPTION("AEAD kernel crypto API user space interface");