OSDN Git Service

MAINTAINERS: add entry for redpine wireless driver
[uclinux-h8/linux.git] / net / sunrpc / auth.c
1 /*
2  * linux/net/sunrpc/auth.c
3  *
4  * Generic RPC client authentication API.
5  *
6  * Copyright (C) 1996, Olaf Kirch <okir@monad.swb.de>
7  */
8
9 #include <linux/types.h>
10 #include <linux/sched.h>
11 #include <linux/cred.h>
12 #include <linux/module.h>
13 #include <linux/slab.h>
14 #include <linux/errno.h>
15 #include <linux/hash.h>
16 #include <linux/sunrpc/clnt.h>
17 #include <linux/sunrpc/gss_api.h>
18 #include <linux/spinlock.h>
19
20 #if IS_ENABLED(CONFIG_SUNRPC_DEBUG)
21 # define RPCDBG_FACILITY        RPCDBG_AUTH
22 #endif
23
24 #define RPC_CREDCACHE_DEFAULT_HASHBITS  (4)
25 struct rpc_cred_cache {
26         struct hlist_head       *hashtable;
27         unsigned int            hashbits;
28         spinlock_t              lock;
29 };
30
31 static unsigned int auth_hashbits = RPC_CREDCACHE_DEFAULT_HASHBITS;
32
33 static const struct rpc_authops __rcu *auth_flavors[RPC_AUTH_MAXFLAVOR] = {
34         [RPC_AUTH_NULL] = (const struct rpc_authops __force __rcu *)&authnull_ops,
35         [RPC_AUTH_UNIX] = (const struct rpc_authops __force __rcu *)&authunix_ops,
36         NULL,                   /* others can be loadable modules */
37 };
38
39 static LIST_HEAD(cred_unused);
40 static unsigned long number_cred_unused;
41
42 static struct cred machine_cred = {
43         .usage = ATOMIC_INIT(1),
44 };
45
46 /*
47  * Return the machine_cred pointer to be used whenever
48  * the a generic machine credential is needed.
49  */
50 const struct cred *rpc_machine_cred(void)
51 {
52         return &machine_cred;
53 }
54 EXPORT_SYMBOL_GPL(rpc_machine_cred);
55
56 #define MAX_HASHTABLE_BITS (14)
57 static int param_set_hashtbl_sz(const char *val, const struct kernel_param *kp)
58 {
59         unsigned long num;
60         unsigned int nbits;
61         int ret;
62
63         if (!val)
64                 goto out_inval;
65         ret = kstrtoul(val, 0, &num);
66         if (ret)
67                 goto out_inval;
68         nbits = fls(num - 1);
69         if (nbits > MAX_HASHTABLE_BITS || nbits < 2)
70                 goto out_inval;
71         *(unsigned int *)kp->arg = nbits;
72         return 0;
73 out_inval:
74         return -EINVAL;
75 }
76
77 static int param_get_hashtbl_sz(char *buffer, const struct kernel_param *kp)
78 {
79         unsigned int nbits;
80
81         nbits = *(unsigned int *)kp->arg;
82         return sprintf(buffer, "%u", 1U << nbits);
83 }
84
85 #define param_check_hashtbl_sz(name, p) __param_check(name, p, unsigned int);
86
87 static const struct kernel_param_ops param_ops_hashtbl_sz = {
88         .set = param_set_hashtbl_sz,
89         .get = param_get_hashtbl_sz,
90 };
91
92 module_param_named(auth_hashtable_size, auth_hashbits, hashtbl_sz, 0644);
93 MODULE_PARM_DESC(auth_hashtable_size, "RPC credential cache hashtable size");
94
95 static unsigned long auth_max_cred_cachesize = ULONG_MAX;
96 module_param(auth_max_cred_cachesize, ulong, 0644);
97 MODULE_PARM_DESC(auth_max_cred_cachesize, "RPC credential maximum total cache size");
98
99 static u32
100 pseudoflavor_to_flavor(u32 flavor) {
101         if (flavor > RPC_AUTH_MAXFLAVOR)
102                 return RPC_AUTH_GSS;
103         return flavor;
104 }
105
106 int
107 rpcauth_register(const struct rpc_authops *ops)
108 {
109         const struct rpc_authops *old;
110         rpc_authflavor_t flavor;
111
112         if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
113                 return -EINVAL;
114         old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], NULL, ops);
115         if (old == NULL || old == ops)
116                 return 0;
117         return -EPERM;
118 }
119 EXPORT_SYMBOL_GPL(rpcauth_register);
120
121 int
122 rpcauth_unregister(const struct rpc_authops *ops)
123 {
124         const struct rpc_authops *old;
125         rpc_authflavor_t flavor;
126
127         if ((flavor = ops->au_flavor) >= RPC_AUTH_MAXFLAVOR)
128                 return -EINVAL;
129
130         old = cmpxchg((const struct rpc_authops ** __force)&auth_flavors[flavor], ops, NULL);
131         if (old == ops || old == NULL)
132                 return 0;
133         return -EPERM;
134 }
135 EXPORT_SYMBOL_GPL(rpcauth_unregister);
136
137 static const struct rpc_authops *
138 rpcauth_get_authops(rpc_authflavor_t flavor)
139 {
140         const struct rpc_authops *ops;
141
142         if (flavor >= RPC_AUTH_MAXFLAVOR)
143                 return NULL;
144
145         rcu_read_lock();
146         ops = rcu_dereference(auth_flavors[flavor]);
147         if (ops == NULL) {
148                 rcu_read_unlock();
149                 request_module("rpc-auth-%u", flavor);
150                 rcu_read_lock();
151                 ops = rcu_dereference(auth_flavors[flavor]);
152                 if (ops == NULL)
153                         goto out;
154         }
155         if (!try_module_get(ops->owner))
156                 ops = NULL;
157 out:
158         rcu_read_unlock();
159         return ops;
160 }
161
162 static void
163 rpcauth_put_authops(const struct rpc_authops *ops)
164 {
165         module_put(ops->owner);
166 }
167
168 /**
169  * rpcauth_get_pseudoflavor - check if security flavor is supported
170  * @flavor: a security flavor
171  * @info: a GSS mech OID, quality of protection, and service value
172  *
173  * Verifies that an appropriate kernel module is available or already loaded.
174  * Returns an equivalent pseudoflavor, or RPC_AUTH_MAXFLAVOR if "flavor" is
175  * not supported locally.
176  */
177 rpc_authflavor_t
178 rpcauth_get_pseudoflavor(rpc_authflavor_t flavor, struct rpcsec_gss_info *info)
179 {
180         const struct rpc_authops *ops = rpcauth_get_authops(flavor);
181         rpc_authflavor_t pseudoflavor;
182
183         if (!ops)
184                 return RPC_AUTH_MAXFLAVOR;
185         pseudoflavor = flavor;
186         if (ops->info2flavor != NULL)
187                 pseudoflavor = ops->info2flavor(info);
188
189         rpcauth_put_authops(ops);
190         return pseudoflavor;
191 }
192 EXPORT_SYMBOL_GPL(rpcauth_get_pseudoflavor);
193
194 /**
195  * rpcauth_get_gssinfo - find GSS tuple matching a GSS pseudoflavor
196  * @pseudoflavor: GSS pseudoflavor to match
197  * @info: rpcsec_gss_info structure to fill in
198  *
199  * Returns zero and fills in "info" if pseudoflavor matches a
200  * supported mechanism.
201  */
202 int
203 rpcauth_get_gssinfo(rpc_authflavor_t pseudoflavor, struct rpcsec_gss_info *info)
204 {
205         rpc_authflavor_t flavor = pseudoflavor_to_flavor(pseudoflavor);
206         const struct rpc_authops *ops;
207         int result;
208
209         ops = rpcauth_get_authops(flavor);
210         if (ops == NULL)
211                 return -ENOENT;
212
213         result = -ENOENT;
214         if (ops->flavor2info != NULL)
215                 result = ops->flavor2info(pseudoflavor, info);
216
217         rpcauth_put_authops(ops);
218         return result;
219 }
220 EXPORT_SYMBOL_GPL(rpcauth_get_gssinfo);
221
222 /**
223  * rpcauth_list_flavors - discover registered flavors and pseudoflavors
224  * @array: array to fill in
225  * @size: size of "array"
226  *
227  * Returns the number of array items filled in, or a negative errno.
228  *
229  * The returned array is not sorted by any policy.  Callers should not
230  * rely on the order of the items in the returned array.
231  */
232 int
233 rpcauth_list_flavors(rpc_authflavor_t *array, int size)
234 {
235         const struct rpc_authops *ops;
236         rpc_authflavor_t flavor, pseudos[4];
237         int i, len, result = 0;
238
239         rcu_read_lock();
240         for (flavor = 0; flavor < RPC_AUTH_MAXFLAVOR; flavor++) {
241                 ops = rcu_dereference(auth_flavors[flavor]);
242                 if (result >= size) {
243                         result = -ENOMEM;
244                         break;
245                 }
246
247                 if (ops == NULL)
248                         continue;
249                 if (ops->list_pseudoflavors == NULL) {
250                         array[result++] = ops->au_flavor;
251                         continue;
252                 }
253                 len = ops->list_pseudoflavors(pseudos, ARRAY_SIZE(pseudos));
254                 if (len < 0) {
255                         result = len;
256                         break;
257                 }
258                 for (i = 0; i < len; i++) {
259                         if (result >= size) {
260                                 result = -ENOMEM;
261                                 break;
262                         }
263                         array[result++] = pseudos[i];
264                 }
265         }
266         rcu_read_unlock();
267
268         dprintk("RPC:       %s returns %d\n", __func__, result);
269         return result;
270 }
271 EXPORT_SYMBOL_GPL(rpcauth_list_flavors);
272
273 struct rpc_auth *
274 rpcauth_create(const struct rpc_auth_create_args *args, struct rpc_clnt *clnt)
275 {
276         struct rpc_auth *auth = ERR_PTR(-EINVAL);
277         const struct rpc_authops *ops;
278         u32 flavor = pseudoflavor_to_flavor(args->pseudoflavor);
279
280         ops = rpcauth_get_authops(flavor);
281         if (ops == NULL)
282                 goto out;
283
284         auth = ops->create(args, clnt);
285
286         rpcauth_put_authops(ops);
287         if (IS_ERR(auth))
288                 return auth;
289         if (clnt->cl_auth)
290                 rpcauth_release(clnt->cl_auth);
291         clnt->cl_auth = auth;
292
293 out:
294         return auth;
295 }
296 EXPORT_SYMBOL_GPL(rpcauth_create);
297
298 void
299 rpcauth_release(struct rpc_auth *auth)
300 {
301         if (!refcount_dec_and_test(&auth->au_count))
302                 return;
303         auth->au_ops->destroy(auth);
304 }
305
306 static DEFINE_SPINLOCK(rpc_credcache_lock);
307
308 /*
309  * On success, the caller is responsible for freeing the reference
310  * held by the hashtable
311  */
312 static bool
313 rpcauth_unhash_cred_locked(struct rpc_cred *cred)
314 {
315         if (!test_and_clear_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
316                 return false;
317         hlist_del_rcu(&cred->cr_hash);
318         return true;
319 }
320
321 static bool
322 rpcauth_unhash_cred(struct rpc_cred *cred)
323 {
324         spinlock_t *cache_lock;
325         bool ret;
326
327         if (!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
328                 return false;
329         cache_lock = &cred->cr_auth->au_credcache->lock;
330         spin_lock(cache_lock);
331         ret = rpcauth_unhash_cred_locked(cred);
332         spin_unlock(cache_lock);
333         return ret;
334 }
335
336 /*
337  * Initialize RPC credential cache
338  */
339 int
340 rpcauth_init_credcache(struct rpc_auth *auth)
341 {
342         struct rpc_cred_cache *new;
343         unsigned int hashsize;
344
345         new = kmalloc(sizeof(*new), GFP_KERNEL);
346         if (!new)
347                 goto out_nocache;
348         new->hashbits = auth_hashbits;
349         hashsize = 1U << new->hashbits;
350         new->hashtable = kcalloc(hashsize, sizeof(new->hashtable[0]), GFP_KERNEL);
351         if (!new->hashtable)
352                 goto out_nohashtbl;
353         spin_lock_init(&new->lock);
354         auth->au_credcache = new;
355         return 0;
356 out_nohashtbl:
357         kfree(new);
358 out_nocache:
359         return -ENOMEM;
360 }
361 EXPORT_SYMBOL_GPL(rpcauth_init_credcache);
362
363 char *
364 rpcauth_stringify_acceptor(struct rpc_cred *cred)
365 {
366         if (!cred->cr_ops->crstringify_acceptor)
367                 return NULL;
368         return cred->cr_ops->crstringify_acceptor(cred);
369 }
370 EXPORT_SYMBOL_GPL(rpcauth_stringify_acceptor);
371
372 /*
373  * Destroy a list of credentials
374  */
375 static inline
376 void rpcauth_destroy_credlist(struct list_head *head)
377 {
378         struct rpc_cred *cred;
379
380         while (!list_empty(head)) {
381                 cred = list_entry(head->next, struct rpc_cred, cr_lru);
382                 list_del_init(&cred->cr_lru);
383                 put_rpccred(cred);
384         }
385 }
386
387 static void
388 rpcauth_lru_add_locked(struct rpc_cred *cred)
389 {
390         if (!list_empty(&cred->cr_lru))
391                 return;
392         number_cred_unused++;
393         list_add_tail(&cred->cr_lru, &cred_unused);
394 }
395
396 static void
397 rpcauth_lru_add(struct rpc_cred *cred)
398 {
399         if (!list_empty(&cred->cr_lru))
400                 return;
401         spin_lock(&rpc_credcache_lock);
402         rpcauth_lru_add_locked(cred);
403         spin_unlock(&rpc_credcache_lock);
404 }
405
406 static void
407 rpcauth_lru_remove_locked(struct rpc_cred *cred)
408 {
409         if (list_empty(&cred->cr_lru))
410                 return;
411         number_cred_unused--;
412         list_del_init(&cred->cr_lru);
413 }
414
415 static void
416 rpcauth_lru_remove(struct rpc_cred *cred)
417 {
418         if (list_empty(&cred->cr_lru))
419                 return;
420         spin_lock(&rpc_credcache_lock);
421         rpcauth_lru_remove_locked(cred);
422         spin_unlock(&rpc_credcache_lock);
423 }
424
425 /*
426  * Clear the RPC credential cache, and delete those credentials
427  * that are not referenced.
428  */
429 void
430 rpcauth_clear_credcache(struct rpc_cred_cache *cache)
431 {
432         LIST_HEAD(free);
433         struct hlist_head *head;
434         struct rpc_cred *cred;
435         unsigned int hashsize = 1U << cache->hashbits;
436         int             i;
437
438         spin_lock(&rpc_credcache_lock);
439         spin_lock(&cache->lock);
440         for (i = 0; i < hashsize; i++) {
441                 head = &cache->hashtable[i];
442                 while (!hlist_empty(head)) {
443                         cred = hlist_entry(head->first, struct rpc_cred, cr_hash);
444                         rpcauth_unhash_cred_locked(cred);
445                         /* Note: We now hold a reference to cred */
446                         rpcauth_lru_remove_locked(cred);
447                         list_add_tail(&cred->cr_lru, &free);
448                 }
449         }
450         spin_unlock(&cache->lock);
451         spin_unlock(&rpc_credcache_lock);
452         rpcauth_destroy_credlist(&free);
453 }
454
455 /*
456  * Destroy the RPC credential cache
457  */
458 void
459 rpcauth_destroy_credcache(struct rpc_auth *auth)
460 {
461         struct rpc_cred_cache *cache = auth->au_credcache;
462
463         if (cache) {
464                 auth->au_credcache = NULL;
465                 rpcauth_clear_credcache(cache);
466                 kfree(cache->hashtable);
467                 kfree(cache);
468         }
469 }
470 EXPORT_SYMBOL_GPL(rpcauth_destroy_credcache);
471
472
473 #define RPC_AUTH_EXPIRY_MORATORIUM (60 * HZ)
474
475 /*
476  * Remove stale credentials. Avoid sleeping inside the loop.
477  */
478 static long
479 rpcauth_prune_expired(struct list_head *free, int nr_to_scan)
480 {
481         struct rpc_cred *cred, *next;
482         unsigned long expired = jiffies - RPC_AUTH_EXPIRY_MORATORIUM;
483         long freed = 0;
484
485         list_for_each_entry_safe(cred, next, &cred_unused, cr_lru) {
486
487                 if (nr_to_scan-- == 0)
488                         break;
489                 if (refcount_read(&cred->cr_count) > 1) {
490                         rpcauth_lru_remove_locked(cred);
491                         continue;
492                 }
493                 /*
494                  * Enforce a 60 second garbage collection moratorium
495                  * Note that the cred_unused list must be time-ordered.
496                  */
497                 if (!time_in_range(cred->cr_expire, expired, jiffies))
498                         continue;
499                 if (!rpcauth_unhash_cred(cred))
500                         continue;
501
502                 rpcauth_lru_remove_locked(cred);
503                 freed++;
504                 list_add_tail(&cred->cr_lru, free);
505         }
506         return freed ? freed : SHRINK_STOP;
507 }
508
509 static unsigned long
510 rpcauth_cache_do_shrink(int nr_to_scan)
511 {
512         LIST_HEAD(free);
513         unsigned long freed;
514
515         spin_lock(&rpc_credcache_lock);
516         freed = rpcauth_prune_expired(&free, nr_to_scan);
517         spin_unlock(&rpc_credcache_lock);
518         rpcauth_destroy_credlist(&free);
519
520         return freed;
521 }
522
523 /*
524  * Run memory cache shrinker.
525  */
526 static unsigned long
527 rpcauth_cache_shrink_scan(struct shrinker *shrink, struct shrink_control *sc)
528
529 {
530         if ((sc->gfp_mask & GFP_KERNEL) != GFP_KERNEL)
531                 return SHRINK_STOP;
532
533         /* nothing left, don't come back */
534         if (list_empty(&cred_unused))
535                 return SHRINK_STOP;
536
537         return rpcauth_cache_do_shrink(sc->nr_to_scan);
538 }
539
540 static unsigned long
541 rpcauth_cache_shrink_count(struct shrinker *shrink, struct shrink_control *sc)
542
543 {
544         return number_cred_unused * sysctl_vfs_cache_pressure / 100;
545 }
546
547 static void
548 rpcauth_cache_enforce_limit(void)
549 {
550         unsigned long diff;
551         unsigned int nr_to_scan;
552
553         if (number_cred_unused <= auth_max_cred_cachesize)
554                 return;
555         diff = number_cred_unused - auth_max_cred_cachesize;
556         nr_to_scan = 100;
557         if (diff < nr_to_scan)
558                 nr_to_scan = diff;
559         rpcauth_cache_do_shrink(nr_to_scan);
560 }
561
562 /*
563  * Look up a process' credentials in the authentication cache
564  */
565 struct rpc_cred *
566 rpcauth_lookup_credcache(struct rpc_auth *auth, struct auth_cred * acred,
567                 int flags, gfp_t gfp)
568 {
569         LIST_HEAD(free);
570         struct rpc_cred_cache *cache = auth->au_credcache;
571         struct rpc_cred *cred = NULL,
572                         *entry, *new;
573         unsigned int nr;
574
575         nr = auth->au_ops->hash_cred(acred, cache->hashbits);
576
577         rcu_read_lock();
578         hlist_for_each_entry_rcu(entry, &cache->hashtable[nr], cr_hash) {
579                 if (!entry->cr_ops->crmatch(acred, entry, flags))
580                         continue;
581                 cred = get_rpccred(entry);
582                 if (cred)
583                         break;
584         }
585         rcu_read_unlock();
586
587         if (cred != NULL)
588                 goto found;
589
590         new = auth->au_ops->crcreate(auth, acred, flags, gfp);
591         if (IS_ERR(new)) {
592                 cred = new;
593                 goto out;
594         }
595
596         spin_lock(&cache->lock);
597         hlist_for_each_entry(entry, &cache->hashtable[nr], cr_hash) {
598                 if (!entry->cr_ops->crmatch(acred, entry, flags))
599                         continue;
600                 cred = get_rpccred(entry);
601                 if (cred)
602                         break;
603         }
604         if (cred == NULL) {
605                 cred = new;
606                 set_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags);
607                 refcount_inc(&cred->cr_count);
608                 hlist_add_head_rcu(&cred->cr_hash, &cache->hashtable[nr]);
609         } else
610                 list_add_tail(&new->cr_lru, &free);
611         spin_unlock(&cache->lock);
612         rpcauth_cache_enforce_limit();
613 found:
614         if (test_bit(RPCAUTH_CRED_NEW, &cred->cr_flags) &&
615             cred->cr_ops->cr_init != NULL &&
616             !(flags & RPCAUTH_LOOKUP_NEW)) {
617                 int res = cred->cr_ops->cr_init(auth, cred);
618                 if (res < 0) {
619                         put_rpccred(cred);
620                         cred = ERR_PTR(res);
621                 }
622         }
623         rpcauth_destroy_credlist(&free);
624 out:
625         return cred;
626 }
627 EXPORT_SYMBOL_GPL(rpcauth_lookup_credcache);
628
629 struct rpc_cred *
630 rpcauth_lookupcred(struct rpc_auth *auth, int flags)
631 {
632         struct auth_cred acred;
633         struct rpc_cred *ret;
634         const struct cred *cred = current_cred();
635
636         dprintk("RPC:       looking up %s cred\n",
637                 auth->au_ops->au_name);
638
639         memset(&acred, 0, sizeof(acred));
640         acred.cred = cred;
641         ret = auth->au_ops->lookup_cred(auth, &acred, flags);
642         return ret;
643 }
644 EXPORT_SYMBOL_GPL(rpcauth_lookupcred);
645
646 void
647 rpcauth_init_cred(struct rpc_cred *cred, const struct auth_cred *acred,
648                   struct rpc_auth *auth, const struct rpc_credops *ops)
649 {
650         INIT_HLIST_NODE(&cred->cr_hash);
651         INIT_LIST_HEAD(&cred->cr_lru);
652         refcount_set(&cred->cr_count, 1);
653         cred->cr_auth = auth;
654         cred->cr_flags = 0;
655         cred->cr_ops = ops;
656         cred->cr_expire = jiffies;
657         cred->cr_cred = get_cred(acred->cred);
658 }
659 EXPORT_SYMBOL_GPL(rpcauth_init_cred);
660
661 static struct rpc_cred *
662 rpcauth_bind_root_cred(struct rpc_task *task, int lookupflags)
663 {
664         struct rpc_auth *auth = task->tk_client->cl_auth;
665         struct auth_cred acred = {
666                 .cred = get_task_cred(&init_task),
667         };
668         struct rpc_cred *ret;
669
670         dprintk("RPC: %5u looking up %s cred\n",
671                 task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
672         ret = auth->au_ops->lookup_cred(auth, &acred, lookupflags);
673         put_cred(acred.cred);
674         return ret;
675 }
676
677 static struct rpc_cred *
678 rpcauth_bind_machine_cred(struct rpc_task *task, int lookupflags)
679 {
680         struct rpc_auth *auth = task->tk_client->cl_auth;
681         struct auth_cred acred = {
682                 .principal = task->tk_client->cl_principal,
683                 .cred = init_task.cred,
684         };
685
686         if (!acred.principal)
687                 return NULL;
688         dprintk("RPC: %5u looking up %s machine cred\n",
689                 task->tk_pid, task->tk_client->cl_auth->au_ops->au_name);
690         return auth->au_ops->lookup_cred(auth, &acred, lookupflags);
691 }
692
693 static struct rpc_cred *
694 rpcauth_bind_new_cred(struct rpc_task *task, int lookupflags)
695 {
696         struct rpc_auth *auth = task->tk_client->cl_auth;
697
698         dprintk("RPC: %5u looking up %s cred\n",
699                 task->tk_pid, auth->au_ops->au_name);
700         return rpcauth_lookupcred(auth, lookupflags);
701 }
702
703 static int
704 rpcauth_bindcred(struct rpc_task *task, const struct cred *cred, int flags)
705 {
706         struct rpc_rqst *req = task->tk_rqstp;
707         struct rpc_cred *new = NULL;
708         int lookupflags = 0;
709         struct rpc_auth *auth = task->tk_client->cl_auth;
710         struct auth_cred acred = {
711                 .cred = cred,
712         };
713
714         if (flags & RPC_TASK_ASYNC)
715                 lookupflags |= RPCAUTH_LOOKUP_NEW;
716         if (task->tk_op_cred)
717                 /* Task must use exactly this rpc_cred */
718                 new = get_rpccred(task->tk_op_cred);
719         else if (cred != NULL && cred != &machine_cred)
720                 new = auth->au_ops->lookup_cred(auth, &acred, lookupflags);
721         else if (cred == &machine_cred)
722                 new = rpcauth_bind_machine_cred(task, lookupflags);
723
724         /* If machine cred couldn't be bound, try a root cred */
725         if (new)
726                 ;
727         else if (cred == &machine_cred || (flags & RPC_TASK_ROOTCREDS))
728                 new = rpcauth_bind_root_cred(task, lookupflags);
729         else if (flags & RPC_TASK_NULLCREDS)
730                 new = authnull_ops.lookup_cred(NULL, NULL, 0);
731         else
732                 new = rpcauth_bind_new_cred(task, lookupflags);
733         if (IS_ERR(new))
734                 return PTR_ERR(new);
735         put_rpccred(req->rq_cred);
736         req->rq_cred = new;
737         return 0;
738 }
739
740 void
741 put_rpccred(struct rpc_cred *cred)
742 {
743         if (cred == NULL)
744                 return;
745         rcu_read_lock();
746         if (refcount_dec_and_test(&cred->cr_count))
747                 goto destroy;
748         if (refcount_read(&cred->cr_count) != 1 ||
749             !test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags))
750                 goto out;
751         if (test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0) {
752                 cred->cr_expire = jiffies;
753                 rpcauth_lru_add(cred);
754                 /* Race breaker */
755                 if (unlikely(!test_bit(RPCAUTH_CRED_HASHED, &cred->cr_flags)))
756                         rpcauth_lru_remove(cred);
757         } else if (rpcauth_unhash_cred(cred)) {
758                 rpcauth_lru_remove(cred);
759                 if (refcount_dec_and_test(&cred->cr_count))
760                         goto destroy;
761         }
762 out:
763         rcu_read_unlock();
764         return;
765 destroy:
766         rcu_read_unlock();
767         cred->cr_ops->crdestroy(cred);
768 }
769 EXPORT_SYMBOL_GPL(put_rpccred);
770
771 __be32 *
772 rpcauth_marshcred(struct rpc_task *task, __be32 *p)
773 {
774         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
775
776         dprintk("RPC: %5u marshaling %s cred %p\n",
777                 task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
778
779         return cred->cr_ops->crmarshal(task, p);
780 }
781
782 __be32 *
783 rpcauth_checkverf(struct rpc_task *task, __be32 *p)
784 {
785         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
786
787         dprintk("RPC: %5u validating %s cred %p\n",
788                 task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
789
790         return cred->cr_ops->crvalidate(task, p);
791 }
792
793 static void rpcauth_wrap_req_encode(kxdreproc_t encode, struct rpc_rqst *rqstp,
794                                    __be32 *data, void *obj)
795 {
796         struct xdr_stream xdr;
797
798         xdr_init_encode(&xdr, &rqstp->rq_snd_buf, data);
799         encode(rqstp, &xdr, obj);
800 }
801
802 int
803 rpcauth_wrap_req(struct rpc_task *task, kxdreproc_t encode, void *rqstp,
804                 __be32 *data, void *obj)
805 {
806         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
807
808         dprintk("RPC: %5u using %s cred %p to wrap rpc data\n",
809                         task->tk_pid, cred->cr_ops->cr_name, cred);
810         if (cred->cr_ops->crwrap_req)
811                 return cred->cr_ops->crwrap_req(task, encode, rqstp, data, obj);
812         /* By default, we encode the arguments normally. */
813         rpcauth_wrap_req_encode(encode, rqstp, data, obj);
814         return 0;
815 }
816
817 static int
818 rpcauth_unwrap_req_decode(kxdrdproc_t decode, struct rpc_rqst *rqstp,
819                           __be32 *data, void *obj)
820 {
821         struct xdr_stream xdr;
822
823         xdr_init_decode(&xdr, &rqstp->rq_rcv_buf, data);
824         return decode(rqstp, &xdr, obj);
825 }
826
827 int
828 rpcauth_unwrap_resp(struct rpc_task *task, kxdrdproc_t decode, void *rqstp,
829                 __be32 *data, void *obj)
830 {
831         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
832
833         dprintk("RPC: %5u using %s cred %p to unwrap rpc data\n",
834                         task->tk_pid, cred->cr_ops->cr_name, cred);
835         if (cred->cr_ops->crunwrap_resp)
836                 return cred->cr_ops->crunwrap_resp(task, decode, rqstp,
837                                                    data, obj);
838         /* By default, we decode the arguments normally. */
839         return rpcauth_unwrap_req_decode(decode, rqstp, data, obj);
840 }
841
842 bool
843 rpcauth_xmit_need_reencode(struct rpc_task *task)
844 {
845         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
846
847         if (!cred || !cred->cr_ops->crneed_reencode)
848                 return false;
849         return cred->cr_ops->crneed_reencode(task);
850 }
851
852 int
853 rpcauth_refreshcred(struct rpc_task *task)
854 {
855         struct rpc_cred *cred;
856         int err;
857
858         cred = task->tk_rqstp->rq_cred;
859         if (cred == NULL) {
860                 err = rpcauth_bindcred(task, task->tk_msg.rpc_cred, task->tk_flags);
861                 if (err < 0)
862                         goto out;
863                 cred = task->tk_rqstp->rq_cred;
864         }
865         dprintk("RPC: %5u refreshing %s cred %p\n",
866                 task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
867
868         err = cred->cr_ops->crrefresh(task);
869 out:
870         if (err < 0)
871                 task->tk_status = err;
872         return err;
873 }
874
875 void
876 rpcauth_invalcred(struct rpc_task *task)
877 {
878         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
879
880         dprintk("RPC: %5u invalidating %s cred %p\n",
881                 task->tk_pid, cred->cr_auth->au_ops->au_name, cred);
882         if (cred)
883                 clear_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags);
884 }
885
886 int
887 rpcauth_uptodatecred(struct rpc_task *task)
888 {
889         struct rpc_cred *cred = task->tk_rqstp->rq_cred;
890
891         return cred == NULL ||
892                 test_bit(RPCAUTH_CRED_UPTODATE, &cred->cr_flags) != 0;
893 }
894
895 static struct shrinker rpc_cred_shrinker = {
896         .count_objects = rpcauth_cache_shrink_count,
897         .scan_objects = rpcauth_cache_shrink_scan,
898         .seeks = DEFAULT_SEEKS,
899 };
900
901 int __init rpcauth_init_module(void)
902 {
903         int err;
904
905         err = rpc_init_authunix();
906         if (err < 0)
907                 goto out1;
908         err = register_shrinker(&rpc_cred_shrinker);
909         if (err < 0)
910                 goto out2;
911         return 0;
912 out2:
913         rpc_destroy_authunix();
914 out1:
915         return err;
916 }
917
918 void rpcauth_remove_module(void)
919 {
920         rpc_destroy_authunix();
921         unregister_shrinker(&rpc_cred_shrinker);
922 }