OSDN Git Service

SUNRPC: Add function rpc_sleep_on_timeout()
[uclinux-h8/linux.git] / net / sunrpc / auth_gss / auth_gss.c
index 1531b02..c055edf 100644 (file)
@@ -1,3 +1,4 @@
+// SPDX-License-Identifier: BSD-3-Clause
 /*
  * linux/net/sunrpc/auth_gss/auth_gss.c
  *
@@ -8,34 +9,8 @@
  *
  *  Dug Song       <dugsong@monkey.org>
  *  Andy Adamson   <andros@umich.edu>
- *
- *  Redistribution and use in source and binary forms, with or without
- *  modification, are permitted provided that the following conditions
- *  are met:
- *
- *  1. Redistributions of source code must retain the above copyright
- *     notice, this list of conditions and the following disclaimer.
- *  2. Redistributions in binary form must reproduce the above copyright
- *     notice, this list of conditions and the following disclaimer in the
- *     documentation and/or other materials provided with the distribution.
- *  3. Neither the name of the University nor the names of its
- *     contributors may be used to endorse or promote products derived
- *     from this software without specific prior written permission.
- *
- *  THIS SOFTWARE IS PROVIDED ``AS IS'' AND ANY EXPRESS OR IMPLIED
- *  WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
- *  MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
- *  DISCLAIMED. IN NO EVENT SHALL THE REGENTS OR CONTRIBUTORS BE LIABLE
- *  FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
- *  CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
- *  SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
- *  BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
- *  LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
- *  NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
- *  SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
  */
 
-
 #include <linux/module.h>
 #include <linux/init.h>
 #include <linux/types.h>
@@ -55,6 +30,8 @@
 
 #include "../netns.h"
 
+#include <trace/events/rpcgss.h>
+
 static const struct rpc_authops authgss_ops;
 
 static const struct rpc_credops gss_credops;
@@ -260,6 +237,7 @@ gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct
        }
        ret = gss_import_sec_context(p, seclen, gm, &ctx->gc_gss_ctx, NULL, GFP_NOFS);
        if (ret < 0) {
+               trace_rpcgss_import_ctx(ret);
                p = ERR_PTR(ret);
                goto err;
        }
@@ -275,12 +253,9 @@ gss_fill_context(const void *p, const void *end, struct gss_cl_ctx *ctx, struct
        if (IS_ERR(p))
                goto err;
 done:
-       dprintk("RPC:       %s Success. gc_expiry %lu now %lu timeout %u acceptor %.*s\n",
-               __func__, ctx->gc_expiry, now, timeout, ctx->gc_acceptor.len,
-               ctx->gc_acceptor.data);
-       return p;
+       trace_rpcgss_context(ctx->gc_expiry, now, timeout,
+                            ctx->gc_acceptor.len, ctx->gc_acceptor.data);
 err:
-       dprintk("RPC:       %s returns error %ld\n", __func__, -PTR_ERR(p));
        return p;
 }
 
@@ -354,10 +329,8 @@ __gss_find_upcall(struct rpc_pipe *pipe, kuid_t uid, const struct gss_auth *auth
                if (auth && pos->auth->service != auth->service)
                        continue;
                refcount_inc(&pos->count);
-               dprintk("RPC:       %s found msg %p\n", __func__, pos);
                return pos;
        }
-       dprintk("RPC:       %s found nothing\n", __func__);
        return NULL;
 }
 
@@ -456,7 +429,7 @@ static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
        size_t buflen = sizeof(gss_msg->databuf);
        int len;
 
-       len = scnprintf(p, buflen, "mech=%s uid=%d ", mech->gm_name,
+       len = scnprintf(p, buflen, "mech=%s uid=%d", mech->gm_name,
                        from_kuid(&init_user_ns, gss_msg->uid));
        buflen -= len;
        p += len;
@@ -467,7 +440,7 @@ static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
         * identity that we are authenticating to.
         */
        if (target_name) {
-               len = scnprintf(p, buflen, "target=%s ", target_name);
+               len = scnprintf(p, buflen, " target=%s", target_name);
                buflen -= len;
                p += len;
                gss_msg->msg.len += len;
@@ -487,11 +460,11 @@ static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
                char *c = strchr(service_name, '@');
 
                if (!c)
-                       len = scnprintf(p, buflen, "service=%s ",
+                       len = scnprintf(p, buflen, " service=%s",
                                        service_name);
                else
                        len = scnprintf(p, buflen,
-                                       "service=%.*s srchost=%s ",
+                                       " service=%.*s srchost=%s",
                                        (int)(c - service_name),
                                        service_name, c + 1);
                buflen -= len;
@@ -500,17 +473,17 @@ static int gss_encode_v1_msg(struct gss_upcall_msg *gss_msg,
        }
 
        if (mech->gm_upcall_enctypes) {
-               len = scnprintf(p, buflen, "enctypes=%s ",
+               len = scnprintf(p, buflen, " enctypes=%s",
                                mech->gm_upcall_enctypes);
                buflen -= len;
                p += len;
                gss_msg->msg.len += len;
        }
+       trace_rpcgss_upcall_msg(gss_msg->databuf);
        len = scnprintf(p, buflen, "\n");
        if (len == 0)
                goto out_overflow;
        gss_msg->msg.len += len;
-
        gss_msg->msg.data = gss_msg->databuf;
        return 0;
 out_overflow:
@@ -603,16 +576,15 @@ gss_refresh_upcall(struct rpc_task *task)
        struct rpc_pipe *pipe;
        int err = 0;
 
-       dprintk("RPC: %5u %s for uid %u\n",
-               task->tk_pid, __func__, from_kuid(&init_user_ns, cred->cr_cred->fsuid));
        gss_msg = gss_setup_upcall(gss_auth, cred);
        if (PTR_ERR(gss_msg) == -EAGAIN) {
                /* XXX: warning on the first, under the assumption we
                 * shouldn't normally hit this case on a refresh. */
                warn_gssd();
-               task->tk_timeout = 15*HZ;
-               rpc_sleep_on(&pipe_version_rpc_waitqueue, task, NULL);
-               return -EAGAIN;
+               rpc_sleep_on_timeout(&pipe_version_rpc_waitqueue,
+                               task, NULL, jiffies + (15 * HZ));
+               err = -EAGAIN;
+               goto out;
        }
        if (IS_ERR(gss_msg)) {
                err = PTR_ERR(gss_msg);
@@ -623,7 +595,6 @@ gss_refresh_upcall(struct rpc_task *task)
        if (gss_cred->gc_upcall != NULL)
                rpc_sleep_on(&gss_cred->gc_upcall->rpc_waitqueue, task, NULL);
        else if (gss_msg->ctx == NULL && gss_msg->msg.errno >= 0) {
-               task->tk_timeout = 0;
                gss_cred->gc_upcall = gss_msg;
                /* gss_upcall_callback will release the reference to gss_upcall_msg */
                refcount_inc(&gss_msg->count);
@@ -635,9 +606,8 @@ gss_refresh_upcall(struct rpc_task *task)
        spin_unlock(&pipe->lock);
        gss_release_msg(gss_msg);
 out:
-       dprintk("RPC: %5u %s for uid %u result %d\n",
-               task->tk_pid, __func__,
-               from_kuid(&init_user_ns, cred->cr_cred->fsuid), err);
+       trace_rpcgss_upcall_result(from_kuid(&init_user_ns,
+                                            cred->cr_cred->fsuid), err);
        return err;
 }
 
@@ -652,14 +622,13 @@ gss_create_upcall(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
        DEFINE_WAIT(wait);
        int err;
 
-       dprintk("RPC:       %s for uid %u\n",
-               __func__, from_kuid(&init_user_ns, cred->cr_cred->fsuid));
 retry:
        err = 0;
        /* if gssd is down, just skip upcalling altogether */
        if (!gssd_running(net)) {
                warn_gssd();
-               return -EACCES;
+               err = -EACCES;
+               goto out;
        }
        gss_msg = gss_setup_upcall(gss_auth, cred);
        if (PTR_ERR(gss_msg) == -EAGAIN) {
@@ -700,8 +669,8 @@ out_intr:
        finish_wait(&gss_msg->waitqueue, &wait);
        gss_release_msg(gss_msg);
 out:
-       dprintk("RPC:       %s for uid %u result %d\n",
-               __func__, from_kuid(&init_user_ns, cred->cr_cred->fsuid), err);
+       trace_rpcgss_upcall_result(from_kuid(&init_user_ns,
+                                            cred->cr_cred->fsuid), err);
        return err;
 }
 
@@ -794,7 +763,6 @@ err_put_ctx:
 err:
        kfree(buf);
 out:
-       dprintk("RPC:       %s returning %zd\n", __func__, err);
        return err;
 }
 
@@ -863,8 +831,6 @@ gss_pipe_destroy_msg(struct rpc_pipe_msg *msg)
        struct gss_upcall_msg *gss_msg = container_of(msg, struct gss_upcall_msg, msg);
 
        if (msg->errno < 0) {
-               dprintk("RPC:       %s releasing msg %p\n",
-                       __func__, gss_msg);
                refcount_inc(&gss_msg->count);
                gss_unhash_msg(gss_msg);
                if (msg->errno == -ETIMEDOUT)
@@ -1024,8 +990,6 @@ gss_create_new(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
        struct rpc_auth * auth;
        int err = -ENOMEM; /* XXX? */
 
-       dprintk("RPC:       creating GSS authenticator for client %p\n", clnt);
-
        if (!try_module_get(THIS_MODULE))
                return ERR_PTR(err);
        if (!(gss_auth = kmalloc(sizeof(*gss_auth), GFP_KERNEL)))
@@ -1041,10 +1005,8 @@ gss_create_new(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
        gss_auth->net = get_net(rpc_net_ns(clnt));
        err = -EINVAL;
        gss_auth->mech = gss_mech_get_by_pseudoflavor(flavor);
-       if (!gss_auth->mech) {
-               dprintk("RPC:       Pseudoflavor %d not found!\n", flavor);
+       if (!gss_auth->mech)
                goto err_put_net;
-       }
        gss_auth->service = gss_pseudoflavor_to_service(gss_auth->mech, flavor);
        if (gss_auth->service == 0)
                goto err_put_mech;
@@ -1053,6 +1015,8 @@ gss_create_new(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
        auth = &gss_auth->rpc_auth;
        auth->au_cslack = GSS_CRED_SLACK >> 2;
        auth->au_rslack = GSS_VERF_SLACK >> 2;
+       auth->au_verfsize = GSS_VERF_SLACK >> 2;
+       auth->au_ralign = GSS_VERF_SLACK >> 2;
        auth->au_flags = 0;
        auth->au_ops = &authgss_ops;
        auth->au_flavor = flavor;
@@ -1099,6 +1063,7 @@ err_free:
        kfree(gss_auth);
 out_dec:
        module_put(THIS_MODULE);
+       trace_rpcgss_createauth(flavor, err);
        return ERR_PTR(err);
 }
 
@@ -1135,9 +1100,6 @@ gss_destroy(struct rpc_auth *auth)
        struct gss_auth *gss_auth = container_of(auth,
                        struct gss_auth, rpc_auth);
 
-       dprintk("RPC:       destroying GSS authenticator %p flavor %d\n",
-                       auth, auth->au_flavor);
-
        if (hash_hashed(&gss_auth->hash)) {
                spin_lock(&gss_auth_hash_lock);
                hash_del(&gss_auth->hash);
@@ -1245,7 +1207,7 @@ gss_dup_cred(struct gss_auth *gss_auth, struct gss_cred *gss_cred)
        struct gss_cred *new;
 
        /* Make a copy of the cred so that we can reference count it */
-       new = kzalloc(sizeof(*gss_cred), GFP_NOIO);
+       new = kzalloc(sizeof(*gss_cred), GFP_NOFS);
        if (new) {
                struct auth_cred acred = {
                        .cred = gss_cred->gc_base.cr_cred,
@@ -1300,8 +1262,6 @@ gss_send_destroy_context(struct rpc_cred *cred)
 static void
 gss_do_free_ctx(struct gss_cl_ctx *ctx)
 {
-       dprintk("RPC:       %s\n", __func__);
-
        gss_delete_sec_context(&ctx->gc_gss_ctx);
        kfree(ctx->gc_wire_ctx.data);
        kfree(ctx->gc_acceptor.data);
@@ -1324,7 +1284,6 @@ gss_free_ctx(struct gss_cl_ctx *ctx)
 static void
 gss_free_cred(struct gss_cred *gss_cred)
 {
-       dprintk("RPC:       %s cred=%p\n", __func__, gss_cred);
        kfree(gss_cred);
 }
 
@@ -1381,10 +1340,6 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags, gfp_t
        struct gss_cred *cred = NULL;
        int err = -ENOMEM;
 
-       dprintk("RPC:       %s for uid %d, flavor %d\n",
-               __func__, from_kuid(&init_user_ns, acred->cred->fsuid),
-               auth->au_flavor);
-
        if (!(cred = kzalloc(sizeof(*cred), gfp)))
                goto out_err;
 
@@ -1400,7 +1355,6 @@ gss_create_cred(struct rpc_auth *auth, struct auth_cred *acred, int flags, gfp_t
        return &cred->gc_base;
 
 out_err:
-       dprintk("RPC:       %s failed with error %d\n", __func__, err);
        return ERR_PTR(err);
 }
 
@@ -1526,69 +1480,84 @@ out:
 }
 
 /*
-* Marshal credentials.
-* Maybe we should keep a cached credential for performance reasons.
-*/
-static __be32 *
-gss_marshal(struct rpc_task *task, __be32 *p)
+ * Marshal credentials.
+ *
+ * The expensive part is computing the verifier. We can't cache a
+ * pre-computed version of the verifier because the seqno, which
+ * is different every time, is included in the MIC.
+ */
+static int gss_marshal(struct rpc_task *task, struct xdr_stream *xdr)
 {
        struct rpc_rqst *req = task->tk_rqstp;
        struct rpc_cred *cred = req->rq_cred;
        struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
                                                 gc_base);
        struct gss_cl_ctx       *ctx = gss_cred_get_ctx(cred);
-       __be32          *cred_len;
+       __be32          *p, *cred_len;
        u32             maj_stat = 0;
        struct xdr_netobj mic;
        struct kvec     iov;
        struct xdr_buf  verf_buf;
+       int status;
 
-       dprintk("RPC: %5u %s\n", task->tk_pid, __func__);
+       /* Credential */
 
-       *p++ = htonl(RPC_AUTH_GSS);
+       p = xdr_reserve_space(xdr, 7 * sizeof(*p) +
+                             ctx->gc_wire_ctx.len);
+       if (!p)
+               goto marshal_failed;
+       *p++ = rpc_auth_gss;
        cred_len = p++;
 
        spin_lock(&ctx->gc_seq_lock);
        req->rq_seqno = (ctx->gc_seq < MAXSEQ) ? ctx->gc_seq++ : MAXSEQ;
        spin_unlock(&ctx->gc_seq_lock);
        if (req->rq_seqno == MAXSEQ)
-               goto out_expired;
+               goto expired;
+       trace_rpcgss_seqno(task);
 
-       *p++ = htonl((u32) RPC_GSS_VERSION);
-       *p++ = htonl((u32) ctx->gc_proc);
-       *p++ = htonl((u32) req->rq_seqno);
-       *p++ = htonl((u32) gss_cred->gc_service);
+       *p++ = cpu_to_be32(RPC_GSS_VERSION);
+       *p++ = cpu_to_be32(ctx->gc_proc);
+       *p++ = cpu_to_be32(req->rq_seqno);
+       *p++ = cpu_to_be32(gss_cred->gc_service);
        p = xdr_encode_netobj(p, &ctx->gc_wire_ctx);
-       *cred_len = htonl((p - (cred_len + 1)) << 2);
+       *cred_len = cpu_to_be32((p - (cred_len + 1)) << 2);
+
+       /* Verifier */
 
        /* We compute the checksum for the verifier over the xdr-encoded bytes
         * starting with the xid and ending at the end of the credential: */
-       iov.iov_base = xprt_skip_transport_header(req->rq_xprt,
-                                       req->rq_snd_buf.head[0].iov_base);
+       iov.iov_base = req->rq_snd_buf.head[0].iov_base;
        iov.iov_len = (u8 *)p - (u8 *)iov.iov_base;
        xdr_buf_from_iov(&iov, &verf_buf);
 
-       /* set verifier flavor*/
-       *p++ = htonl(RPC_AUTH_GSS);
-
+       p = xdr_reserve_space(xdr, sizeof(*p));
+       if (!p)
+               goto marshal_failed;
+       *p++ = rpc_auth_gss;
        mic.data = (u8 *)(p + 1);
        maj_stat = gss_get_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
-       if (maj_stat == GSS_S_CONTEXT_EXPIRED) {
-               goto out_expired;
-       } else if (maj_stat != 0) {
-               pr_warn("gss_marshal: gss_get_mic FAILED (%d)\n", maj_stat);
-               task->tk_status = -EIO;
-               goto out_put_ctx;
-       }
-       p = xdr_encode_opaque(p, NULL, mic.len);
+       if (maj_stat == GSS_S_CONTEXT_EXPIRED)
+               goto expired;
+       else if (maj_stat != 0)
+               goto bad_mic;
+       if (xdr_stream_encode_opaque_inline(xdr, (void **)&p, mic.len) < 0)
+               goto marshal_failed;
+       status = 0;
+out:
        gss_put_ctx(ctx);
-       return p;
-out_expired:
+       return status;
+expired:
        clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
-       task->tk_status = -EKEYEXPIRED;
-out_put_ctx:
-       gss_put_ctx(ctx);
-       return NULL;
+       status = -EKEYEXPIRED;
+       goto out;
+marshal_failed:
+       status = -EMSGSIZE;
+       goto out;
+bad_mic:
+       trace_rpcgss_get_mic(task, maj_stat);
+       status = -EIO;
+       goto out;
 }
 
 static int gss_renew_cred(struct rpc_task *task)
@@ -1662,116 +1631,105 @@ gss_refresh_null(struct rpc_task *task)
        return 0;
 }
 
-static __be32 *
-gss_validate(struct rpc_task *task, __be32 *p)
+static int
+gss_validate(struct rpc_task *task, struct xdr_stream *xdr)
 {
        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
        struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
-       __be32          *seq = NULL;
+       __be32          *p, *seq = NULL;
        struct kvec     iov;
        struct xdr_buf  verf_buf;
        struct xdr_netobj mic;
-       u32             flav,len;
-       u32             maj_stat;
-       __be32          *ret = ERR_PTR(-EIO);
+       u32             len, maj_stat;
+       int             status;
 
-       dprintk("RPC: %5u %s\n", task->tk_pid, __func__);
+       p = xdr_inline_decode(xdr, 2 * sizeof(*p));
+       if (!p)
+               goto validate_failed;
+       if (*p++ != rpc_auth_gss)
+               goto validate_failed;
+       len = be32_to_cpup(p);
+       if (len > RPC_MAX_AUTH_SIZE)
+               goto validate_failed;
+       p = xdr_inline_decode(xdr, len);
+       if (!p)
+               goto validate_failed;
 
-       flav = ntohl(*p++);
-       if ((len = ntohl(*p++)) > RPC_MAX_AUTH_SIZE)
-               goto out_bad;
-       if (flav != RPC_AUTH_GSS)
-               goto out_bad;
        seq = kmalloc(4, GFP_NOFS);
        if (!seq)
-               goto out_bad;
-       *seq = htonl(task->tk_rqstp->rq_seqno);
+               goto validate_failed;
+       *seq = cpu_to_be32(task->tk_rqstp->rq_seqno);
        iov.iov_base = seq;
        iov.iov_len = 4;
        xdr_buf_from_iov(&iov, &verf_buf);
        mic.data = (u8 *)p;
        mic.len = len;
-
-       ret = ERR_PTR(-EACCES);
        maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &verf_buf, &mic);
        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
-       if (maj_stat) {
-               dprintk("RPC: %5u %s: gss_verify_mic returned error 0x%08x\n",
-                       task->tk_pid, __func__, maj_stat);
-               goto out_bad;
-       }
+       if (maj_stat)
+               goto bad_mic;
+
        /* We leave it to unwrap to calculate au_rslack. For now we just
         * calculate the length of the verifier: */
        cred->cr_auth->au_verfsize = XDR_QUADLEN(len) + 2;
+       status = 0;
+out:
        gss_put_ctx(ctx);
-       dprintk("RPC: %5u %s: gss_verify_mic succeeded.\n",
-                       task->tk_pid, __func__);
-       kfree(seq);
-       return p + XDR_QUADLEN(len);
-out_bad:
-       gss_put_ctx(ctx);
-       dprintk("RPC: %5u %s failed ret %ld.\n", task->tk_pid, __func__,
-               PTR_ERR(ret));
        kfree(seq);
-       return ret;
-}
-
-static void gss_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
-                               __be32 *p, void *obj)
-{
-       struct xdr_stream xdr;
+       return status;
 
-       xdr_init_encode(&xdr, &rqstp->rq_snd_buf, p);
-       encode(rqstp, &xdr, obj);
+validate_failed:
+       status = -EIO;
+       goto out;
+bad_mic:
+       trace_rpcgss_verify_mic(task, maj_stat);
+       status = -EACCES;
+       goto out;
 }
 
-static inline int
-gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
-                  kxdreproc_t encode, struct rpc_rqst *rqstp,
-                  __be32 *p, void *obj)
+static int gss_wrap_req_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
+                             struct rpc_task *task, struct xdr_stream *xdr)
 {
-       struct xdr_buf  *snd_buf = &rqstp->rq_snd_buf;
-       struct xdr_buf  integ_buf;
-       __be32          *integ_len = NULL;
+       struct rpc_rqst *rqstp = task->tk_rqstp;
+       struct xdr_buf integ_buf, *snd_buf = &rqstp->rq_snd_buf;
        struct xdr_netobj mic;
-       u32             offset;
-       __be32          *q;
-       struct kvec     *iov;
-       u32             maj_stat = 0;
-       int             status = -EIO;
+       __be32 *p, *integ_len;
+       u32 offset, maj_stat;
 
+       p = xdr_reserve_space(xdr, 2 * sizeof(*p));
+       if (!p)
+               goto wrap_failed;
        integ_len = p++;
-       offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
-       *p++ = htonl(rqstp->rq_seqno);
+       *p = cpu_to_be32(rqstp->rq_seqno);
 
-       gss_wrap_req_encode(encode, rqstp, p, obj);
+       if (rpcauth_wrap_req_encode(task, xdr))
+               goto wrap_failed;
 
+       offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
        if (xdr_buf_subsegment(snd_buf, &integ_buf,
                                offset, snd_buf->len - offset))
-               return status;
-       *integ_len = htonl(integ_buf.len);
+               goto wrap_failed;
+       *integ_len = cpu_to_be32(integ_buf.len);
 
-       /* guess whether we're in the head or the tail: */
-       if (snd_buf->page_len || snd_buf->tail[0].iov_len)
-               iov = snd_buf->tail;
-       else
-               iov = snd_buf->head;
-       p = iov->iov_base + iov->iov_len;
+       p = xdr_reserve_space(xdr, 0);
+       if (!p)
+               goto wrap_failed;
        mic.data = (u8 *)(p + 1);
-
        maj_stat = gss_get_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
-       status = -EIO; /* XXX? */
        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
        else if (maj_stat)
-               return status;
-       q = xdr_encode_opaque(p, NULL, mic.len);
-
-       offset = (u8 *)q - (u8 *)p;
-       iov->iov_len += offset;
-       snd_buf->len += offset;
+               goto bad_mic;
+       /* Check that the trailing MIC fit in the buffer, after the fact */
+       if (xdr_stream_encode_opaque_inline(xdr, (void **)&p, mic.len) < 0)
+               goto wrap_failed;
        return 0;
+wrap_failed:
+       return -EMSGSIZE;
+bad_mic:
+       trace_rpcgss_get_mic(task, maj_stat);
+       return -EIO;
 }
 
 static void
@@ -1822,61 +1780,62 @@ out:
        return -EAGAIN;
 }
 
-static inline int
-gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
-                 kxdreproc_t encode, struct rpc_rqst *rqstp,
-                 __be32 *p, void *obj)
+static int gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
+                            struct rpc_task *task, struct xdr_stream *xdr)
 {
+       struct rpc_rqst *rqstp = task->tk_rqstp;
        struct xdr_buf  *snd_buf = &rqstp->rq_snd_buf;
-       u32             offset;
-       u32             maj_stat;
+       u32             pad, offset, maj_stat;
        int             status;
-       __be32          *opaque_len;
+       __be32          *p, *opaque_len;
        struct page     **inpages;
        int             first;
-       int             pad;
        struct kvec     *iov;
-       char            *tmp;
 
+       status = -EIO;
+       p = xdr_reserve_space(xdr, 2 * sizeof(*p));
+       if (!p)
+               goto wrap_failed;
        opaque_len = p++;
-       offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
-       *p++ = htonl(rqstp->rq_seqno);
+       *p = cpu_to_be32(rqstp->rq_seqno);
 
-       gss_wrap_req_encode(encode, rqstp, p, obj);
+       if (rpcauth_wrap_req_encode(task, xdr))
+               goto wrap_failed;
 
        status = alloc_enc_pages(rqstp);
-       if (status)
-               return status;
+       if (unlikely(status))
+               goto wrap_failed;
        first = snd_buf->page_base >> PAGE_SHIFT;
        inpages = snd_buf->pages + first;
        snd_buf->pages = rqstp->rq_enc_pages;
        snd_buf->page_base -= first << PAGE_SHIFT;
        /*
-        * Give the tail its own page, in case we need extra space in the
-        * head when wrapping:
+        * Move the tail into its own page, in case gss_wrap needs
+        * more space in the head when wrapping.
         *
-        * call_allocate() allocates twice the slack space required
-        * by the authentication flavor to rq_callsize.
-        * For GSS, slack is GSS_CRED_SLACK.
+        * Still... Why can't gss_wrap just slide the tail down?
         */
        if (snd_buf->page_len || snd_buf->tail[0].iov_len) {
+               char *tmp;
+
                tmp = page_address(rqstp->rq_enc_pages[rqstp->rq_enc_pages_num - 1]);
                memcpy(tmp, snd_buf->tail[0].iov_base, snd_buf->tail[0].iov_len);
                snd_buf->tail[0].iov_base = tmp;
        }
+       offset = (u8 *)p - (u8 *)snd_buf->head[0].iov_base;
        maj_stat = gss_wrap(ctx->gc_gss_ctx, offset, snd_buf, inpages);
        /* slack space should prevent this ever happening: */
-       BUG_ON(snd_buf->len > snd_buf->buflen);
-       status = -EIO;
+       if (unlikely(snd_buf->len > snd_buf->buflen))
+               goto wrap_failed;
        /* We're assuming that when GSS_S_CONTEXT_EXPIRED, the encryption was
         * done anyway, so it's safe to put the request on the wire: */
        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
        else if (maj_stat)
-               return status;
+               goto bad_wrap;
 
-       *opaque_len = htonl(snd_buf->len - offset);
-       /* guess whether we're in the head or the tail: */
+       *opaque_len = cpu_to_be32(snd_buf->len - offset);
+       /* guess whether the pad goes into the head or the tail: */
        if (snd_buf->page_len || snd_buf->tail[0].iov_len)
                iov = snd_buf->tail;
        else
@@ -1888,118 +1847,154 @@ gss_wrap_req_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
        snd_buf->len += pad;
 
        return 0;
+wrap_failed:
+       return status;
+bad_wrap:
+       trace_rpcgss_wrap(task, maj_stat);
+       return -EIO;
 }
 
-static int
-gss_wrap_req(struct rpc_task *task,
-            kxdreproc_t encode, void *rqstp, __be32 *p, void *obj)
+static int gss_wrap_req(struct rpc_task *task, struct xdr_stream *xdr)
 {
        struct rpc_cred *cred = task->tk_rqstp->rq_cred;
        struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
                        gc_base);
        struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
-       int             status = -EIO;
+       int status;
 
-       dprintk("RPC: %5u %s\n", task->tk_pid, __func__);
+       status = -EIO;
        if (ctx->gc_proc != RPC_GSS_PROC_DATA) {
                /* The spec seems a little ambiguous here, but I think that not
                 * wrapping context destruction requests makes the most sense.
                 */
-               gss_wrap_req_encode(encode, rqstp, p, obj);
-               status = 0;
+               status = rpcauth_wrap_req_encode(task, xdr);
                goto out;
        }
        switch (gss_cred->gc_service) {
        case RPC_GSS_SVC_NONE:
-               gss_wrap_req_encode(encode, rqstp, p, obj);
-               status = 0;
+               status = rpcauth_wrap_req_encode(task, xdr);
                break;
        case RPC_GSS_SVC_INTEGRITY:
-               status = gss_wrap_req_integ(cred, ctx, encode, rqstp, p, obj);
+               status = gss_wrap_req_integ(cred, ctx, task, xdr);
                break;
        case RPC_GSS_SVC_PRIVACY:
-               status = gss_wrap_req_priv(cred, ctx, encode, rqstp, p, obj);
+               status = gss_wrap_req_priv(cred, ctx, task, xdr);
                break;
+       default:
+               status = -EIO;
        }
 out:
        gss_put_ctx(ctx);
-       dprintk("RPC: %5u %s returning %d\n", task->tk_pid, __func__, status);
        return status;
 }
 
-static inline int
-gss_unwrap_resp_integ(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
-               struct rpc_rqst *rqstp, __be32 **p)
+static int
+gss_unwrap_resp_auth(struct rpc_cred *cred)
 {
-       struct xdr_buf  *rcv_buf = &rqstp->rq_rcv_buf;
-       struct xdr_buf integ_buf;
+       struct rpc_auth *auth = cred->cr_auth;
+
+       auth->au_rslack = auth->au_verfsize;
+       auth->au_ralign = auth->au_verfsize;
+       return 0;
+}
+
+static int
+gss_unwrap_resp_integ(struct rpc_task *task, struct rpc_cred *cred,
+                     struct gss_cl_ctx *ctx, struct rpc_rqst *rqstp,
+                     struct xdr_stream *xdr)
+{
+       struct xdr_buf integ_buf, *rcv_buf = &rqstp->rq_rcv_buf;
+       u32 data_offset, mic_offset, integ_len, maj_stat;
+       struct rpc_auth *auth = cred->cr_auth;
        struct xdr_netobj mic;
-       u32 data_offset, mic_offset;
-       u32 integ_len;
-       u32 maj_stat;
-       int status = -EIO;
+       __be32 *p;
 
-       integ_len = ntohl(*(*p)++);
+       p = xdr_inline_decode(xdr, 2 * sizeof(*p));
+       if (unlikely(!p))
+               goto unwrap_failed;
+       integ_len = be32_to_cpup(p++);
        if (integ_len & 3)
-               return status;
-       data_offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
+               goto unwrap_failed;
+       data_offset = (u8 *)(p) - (u8 *)rcv_buf->head[0].iov_base;
        mic_offset = integ_len + data_offset;
        if (mic_offset > rcv_buf->len)
-               return status;
-       if (ntohl(*(*p)++) != rqstp->rq_seqno)
-               return status;
-
-       if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset,
-                               mic_offset - data_offset))
-               return status;
+               goto unwrap_failed;
+       if (be32_to_cpup(p) != rqstp->rq_seqno)
+               goto bad_seqno;
 
+       if (xdr_buf_subsegment(rcv_buf, &integ_buf, data_offset, integ_len))
+               goto unwrap_failed;
        if (xdr_buf_read_netobj(rcv_buf, &mic, mic_offset))
-               return status;
-
+               goto unwrap_failed;
        maj_stat = gss_verify_mic(ctx->gc_gss_ctx, &integ_buf, &mic);
        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
        if (maj_stat != GSS_S_COMPLETE)
-               return status;
+               goto bad_mic;
+
+       auth->au_rslack = auth->au_verfsize + 2 + 1 + XDR_QUADLEN(mic.len);
+       auth->au_ralign = auth->au_verfsize + 2;
        return 0;
+unwrap_failed:
+       trace_rpcgss_unwrap_failed(task);
+       return -EIO;
+bad_seqno:
+       trace_rpcgss_bad_seqno(task, rqstp->rq_seqno, be32_to_cpup(p));
+       return -EIO;
+bad_mic:
+       trace_rpcgss_verify_mic(task, maj_stat);
+       return -EIO;
 }
 
-static inline int
-gss_unwrap_resp_priv(struct rpc_cred *cred, struct gss_cl_ctx *ctx,
-               struct rpc_rqst *rqstp, __be32 **p)
-{
-       struct xdr_buf  *rcv_buf = &rqstp->rq_rcv_buf;
-       u32 offset;
-       u32 opaque_len;
-       u32 maj_stat;
-       int status = -EIO;
-
-       opaque_len = ntohl(*(*p)++);
-       offset = (u8 *)(*p) - (u8 *)rcv_buf->head[0].iov_base;
+static int
+gss_unwrap_resp_priv(struct rpc_task *task, struct rpc_cred *cred,
+                    struct gss_cl_ctx *ctx, struct rpc_rqst *rqstp,
+                    struct xdr_stream *xdr)
+{
+       struct xdr_buf *rcv_buf = &rqstp->rq_rcv_buf;
+       struct kvec *head = rqstp->rq_rcv_buf.head;
+       struct rpc_auth *auth = cred->cr_auth;
+       unsigned int savedlen = rcv_buf->len;
+       u32 offset, opaque_len, maj_stat;
+       __be32 *p;
+
+       p = xdr_inline_decode(xdr, 2 * sizeof(*p));
+       if (unlikely(!p))
+               goto unwrap_failed;
+       opaque_len = be32_to_cpup(p++);
+       offset = (u8 *)(p) - (u8 *)head->iov_base;
        if (offset + opaque_len > rcv_buf->len)
-               return status;
-       /* remove padding: */
+               goto unwrap_failed;
        rcv_buf->len = offset + opaque_len;
 
        maj_stat = gss_unwrap(ctx->gc_gss_ctx, offset, rcv_buf);
        if (maj_stat == GSS_S_CONTEXT_EXPIRED)
                clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
        if (maj_stat != GSS_S_COMPLETE)
-               return status;
-       if (ntohl(*(*p)++) != rqstp->rq_seqno)
-               return status;
+               goto bad_unwrap;
+       /* gss_unwrap decrypted the sequence number */
+       if (be32_to_cpup(p++) != rqstp->rq_seqno)
+               goto bad_seqno;
 
-       return 0;
-}
-
-static int
-gss_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
-                     __be32 *p, void *obj)
-{
-       struct xdr_stream xdr;
+       /* gss_unwrap redacts the opaque blob from the head iovec.
+        * rcv_buf has changed, thus the stream needs to be reset.
+        */
+       xdr_init_decode(xdr, rcv_buf, p, rqstp);
 
-       xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, p);
-       return decode(rqstp, &xdr, obj);
+       auth->au_rslack = auth->au_verfsize + 2 +
+                         XDR_QUADLEN(savedlen - rcv_buf->len);
+       auth->au_ralign = auth->au_verfsize + 2 +
+                         XDR_QUADLEN(savedlen - rcv_buf->len);
+       return 0;
+unwrap_failed:
+       trace_rpcgss_unwrap_failed(task);
+       return -EIO;
+bad_seqno:
+       trace_rpcgss_bad_seqno(task, rqstp->rq_seqno, be32_to_cpup(--p));
+       return -EIO;
+bad_unwrap:
+       trace_rpcgss_unwrap(task, maj_stat);
+       return -EIO;
 }
 
 static bool
@@ -2014,14 +2009,14 @@ gss_xmit_need_reencode(struct rpc_task *task)
        struct rpc_rqst *req = task->tk_rqstp;
        struct rpc_cred *cred = req->rq_cred;
        struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
-       u32 win, seq_xmit;
+       u32 win, seq_xmit = 0;
        bool ret = true;
 
        if (!ctx)
-               return true;
+               goto out;
 
        if (gss_seq_is_newer(req->rq_seqno, READ_ONCE(ctx->gc_seq)))
-               goto out;
+               goto out_ctx;
 
        seq_xmit = READ_ONCE(ctx->gc_seq_xmit);
        while (gss_seq_is_newer(req->rq_seqno, seq_xmit)) {
@@ -2030,56 +2025,51 @@ gss_xmit_need_reencode(struct rpc_task *task)
                seq_xmit = cmpxchg(&ctx->gc_seq_xmit, tmp, req->rq_seqno);
                if (seq_xmit == tmp) {
                        ret = false;
-                       goto out;
+                       goto out_ctx;
                }
        }
 
        win = ctx->gc_win;
        if (win > 0)
                ret = !gss_seq_is_newer(req->rq_seqno, seq_xmit - win);
-out:
+
+out_ctx:
        gss_put_ctx(ctx);
+out:
+       trace_rpcgss_need_reencode(task, seq_xmit, ret);
        return ret;
 }
 
 static int
-gss_unwrap_resp(struct rpc_task *task,
-               kxdrdproc_t decode, void *rqstp, __be32 *p, void *obj)
+gss_unwrap_resp(struct rpc_task *task, struct xdr_stream *xdr)
 {
-       struct rpc_cred *cred = task->tk_rqstp->rq_cred;
+       struct rpc_rqst *rqstp = task->tk_rqstp;
+       struct rpc_cred *cred = rqstp->rq_cred;
        struct gss_cred *gss_cred = container_of(cred, struct gss_cred,
                        gc_base);
        struct gss_cl_ctx *ctx = gss_cred_get_ctx(cred);
-       __be32          *savedp = p;
-       struct kvec     *head = ((struct rpc_rqst *)rqstp)->rq_rcv_buf.head;
-       int             savedlen = head->iov_len;
-       int             status = -EIO;
+       int status = -EIO;
 
        if (ctx->gc_proc != RPC_GSS_PROC_DATA)
                goto out_decode;
        switch (gss_cred->gc_service) {
        case RPC_GSS_SVC_NONE:
+               status = gss_unwrap_resp_auth(cred);
                break;
        case RPC_GSS_SVC_INTEGRITY:
-               status = gss_unwrap_resp_integ(cred, ctx, rqstp, &p);
-               if (status)
-                       goto out;
+               status = gss_unwrap_resp_integ(task, cred, ctx, rqstp, xdr);
                break;
        case RPC_GSS_SVC_PRIVACY:
-               status = gss_unwrap_resp_priv(cred, ctx, rqstp, &p);
-               if (status)
-                       goto out;
+               status = gss_unwrap_resp_priv(task, cred, ctx, rqstp, xdr);
                break;
        }
-       /* take into account extra slack for integrity and privacy cases: */
-       cred->cr_auth->au_rslack = cred->cr_auth->au_verfsize + (p - savedp)
-                                               + (savedlen - head->iov_len);
+       if (status)
+               goto out;
+
 out_decode:
-       status = gss_unwrap_req_decode(decode, rqstp, p, obj);
+       status = rpcauth_unwrap_resp_decode(task, xdr);
 out:
        gss_put_ctx(ctx);
-       dprintk("RPC: %5u %s returning %d\n",
-               task->tk_pid, __func__, status);
        return status;
 }