// SPDX-License-Identifier: GPL-2.0
/* Copyright (c) 2017 - 2018 Covalent IO, Inc. http://covalent.io */

#include <linux/bpf.h>
#include <linux/filter.h>
#include <linux/errno.h>
#include <linux/file.h>
#include <linux/net.h>
#include <linux/workqueue.h>
#include <linux/skmsg.h>
#include <linux/list.h>
#include <linux/jhash.h>
#include <linux/sock_diag.h>
#include <net/udp.h>

struct bpf_stab {
	struct bpf_map map;
	struct sock **sks;
	struct sk_psock_progs progs;
	raw_spinlock_t lock;
};

#define SOCK_CREATE_FLAG_MASK				\
	(BPF_F_NUMA_NODE | BPF_F_RDONLY | BPF_F_WRONLY)
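
/* A sockmap is a flat array of struct sock pointers. Allocation charges
 * the full array cost up front and fails with the usual bpf(2) error
 * codes on bad attributes.
 */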
static struct bpf_map *sock_map_alloc(union bpf_attr *attr)
{
	struct bpf_stab *stab;
	u64 cost;
	int err;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    != 4 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);

	stab = kzalloc(sizeof(*stab), GFP_USER);
	if (!stab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&stab->map, attr);
	raw_spin_lock_init(&stab->lock);

	/* Make sure page count doesn't overflow. */
	cost = (u64) stab->map.max_entries * sizeof(struct sock *);
	err = bpf_map_charge_init(&stab->map.memory, cost);
	if (err)
		goto free_stab;

	stab->sks = bpf_map_area_alloc(stab->map.max_entries *
				       sizeof(struct sock *),
				       stab->map.numa_node);
	if (stab->sks)
		return &stab->map;
	err = -ENOMEM;
	bpf_map_charge_finish(&stab->map.memory);
free_stab:
	kfree(stab);
	return ERR_PTR(err);
}
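
/* sock_map_get_from_fd and sock_map_prog_detach implement the
 * BPF_PROG_ATTACH/DETACH path for sockmap: resolve the target map fd
 * from the attach attributes and update the map's program slots.
 */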
int sock_map_get_from_fd(const union bpf_attr *attr, struct bpf_prog *prog)
{
	u32 ufd = attr->target_fd;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);
	ret = sock_map_prog_update(map, prog, NULL, attr->attach_type);
	fdput(f);
	return ret;
}

int sock_map_prog_detach(const union bpf_attr *attr, enum bpf_prog_type ptype)
{
	u32 ufd = attr->target_fd;
	struct bpf_prog *prog;
	struct bpf_map *map;
	struct fd f;
	int ret;

	if (attr->attach_flags || attr->replace_bpf_fd)
		return -EINVAL;

	f = fdget(ufd);
	map = __bpf_map_get(f);
	if (IS_ERR(map))
		return PTR_ERR(map);

	prog = bpf_prog_get(attr->attach_bpf_fd);
	if (IS_ERR(prog)) {
		ret = PTR_ERR(prog);
		goto put_map;
	}

	if (prog->type != ptype) {
		ret = -EINVAL;
		goto put_prog;
	}

	ret = sock_map_prog_update(map, NULL, prog, attr->attach_type);
put_prog:
	bpf_prog_put(prog);
put_map:
	fdput(f);
	return ret;
}
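
/* Lock ordering for updates from the syscall path: take the socket lock
 * first, then disable preemption and enter an RCU read-side section so
 * the map internals can be walked safely.
 */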
static void sock_map_sk_acquire(struct sock *sk)
	__acquires(&sk->sk_lock.slock)
{
	lock_sock(sk);
	preempt_disable();
	rcu_read_lock();
}

static void sock_map_sk_release(struct sock *sk)
	__releases(&sk->sk_lock.slock)
{
	rcu_read_unlock();
	preempt_enable();
	release_sock(sk);
}

static void sock_map_add_link(struct sk_psock *psock,
			      struct sk_psock_link *link,
			      struct bpf_map *map, void *link_raw)
{
	link->link_raw = link_raw;
	link->map = map;
	spin_lock_bh(&psock->link_lock);
	list_add_tail(&link->list, &psock->link);
	spin_unlock_bh(&psock->link_lock);
}

static void sock_map_del_link(struct sock *sk,
			      struct sk_psock *psock, void *link_raw)
{
	struct sk_psock_link *link, *tmp;
	bool strp_stop = false;

	spin_lock_bh(&psock->link_lock);
	list_for_each_entry_safe(link, tmp, &psock->link, list) {
		if (link->link_raw == link_raw) {
			struct bpf_map *map = link->map;
			struct bpf_stab *stab = container_of(map, struct bpf_stab,
							     map);
			if (psock->parser.enabled && stab->progs.skb_parser)
				strp_stop = true;
			list_del(&link->list);
			sk_psock_free_link(link);
		}
	}
	spin_unlock_bh(&psock->link_lock);
	if (strp_stop) {
		write_lock_bh(&sk->sk_callback_lock);
		sk_psock_stop_strp(sk, psock);
		write_unlock_bh(&sk->sk_callback_lock);
	}
}

static void sock_map_unref(struct sock *sk, void *link_raw)
{
	struct sk_psock *psock = sk_psock(sk);

	if (likely(psock)) {
		sock_map_del_link(sk, psock, link_raw);
		sk_psock_put(sk, psock);
	}
}
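
/* Switch the socket over to the sockmap proto ops (tcp_bpf or udp_bpf),
 * which hook close()/unhash() and the data paths for this socket.
 */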
static int sock_map_init_proto(struct sock *sk, struct sk_psock *psock)
{
	struct proto *prot;

	switch (sk->sk_type) {
	case SOCK_STREAM:
		prot = tcp_bpf_get_proto(sk, psock);
		break;

	case SOCK_DGRAM:
		prot = udp_bpf_get_proto(sk, psock);
		break;

	default:
		return -EINVAL;
	}

	if (IS_ERR(prot))
		return PTR_ERR(prot);

	sk_psock_update_proto(sk, psock, prot);
	return 0;
}

static struct sk_psock *sock_map_psock_get_checked(struct sock *sk)
{
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (psock) {
		if (sk->sk_prot->close != sock_map_close) {
			psock = ERR_PTR(-EBUSY);
			goto out;
		}

		if (!refcount_inc_not_zero(&psock->refcnt))
			psock = ERR_PTR(-EBUSY);
	}
out:
	rcu_read_unlock();
	return psock;
}
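
/* sock_map_link attaches the map's parser/verdict/msg programs to the
 * socket's psock, creating the psock on first use. Program refcounts
 * are taken up front and dropped again on every failure path.
 */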
static int sock_map_link(struct bpf_map *map, struct sk_psock_progs *progs,
			 struct sock *sk)
{
	struct bpf_prog *msg_parser, *skb_parser, *skb_verdict;
	struct sk_psock *psock;
	bool skb_progs;
	int ret;

	skb_verdict = READ_ONCE(progs->skb_verdict);
	skb_parser = READ_ONCE(progs->skb_parser);
	skb_progs = skb_parser && skb_verdict;
	if (skb_progs) {
		skb_verdict = bpf_prog_inc_not_zero(skb_verdict);
		if (IS_ERR(skb_verdict))
			return PTR_ERR(skb_verdict);
		skb_parser = bpf_prog_inc_not_zero(skb_parser);
		if (IS_ERR(skb_parser)) {
			bpf_prog_put(skb_verdict);
			return PTR_ERR(skb_parser);
		}
	}

	msg_parser = READ_ONCE(progs->msg_parser);
	if (msg_parser) {
		msg_parser = bpf_prog_inc_not_zero(msg_parser);
		if (IS_ERR(msg_parser)) {
			ret = PTR_ERR(msg_parser);
			goto out;
		}
	}

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock)) {
		ret = PTR_ERR(psock);
		goto out_progs;
	}

	if (psock) {
		if ((msg_parser && READ_ONCE(psock->progs.msg_parser)) ||
		    (skb_progs && READ_ONCE(psock->progs.skb_parser))) {
			sk_psock_put(sk, psock);
			ret = -EBUSY;
			goto out_progs;
		}
	} else {
		psock = sk_psock_init(sk, map->numa_node);
		if (IS_ERR(psock)) {
			ret = PTR_ERR(psock);
			goto out_progs;
		}
	}

	if (msg_parser)
		psock_set_prog(&psock->progs.msg_parser, msg_parser);

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		goto out_drop;

	write_lock_bh(&sk->sk_callback_lock);
	if (skb_progs && !psock->parser.enabled) {
		ret = sk_psock_init_strp(sk, psock);
		if (ret) {
			write_unlock_bh(&sk->sk_callback_lock);
			goto out_drop;
		}
		psock_set_prog(&psock->progs.skb_verdict, skb_verdict);
		psock_set_prog(&psock->progs.skb_parser, skb_parser);
		sk_psock_start_strp(sk, psock);
	}
	write_unlock_bh(&sk->sk_callback_lock);
	return 0;
out_drop:
	sk_psock_put(sk, psock);
out_progs:
	if (msg_parser)
		bpf_prog_put(msg_parser);
out:
	if (skb_progs) {
		bpf_prog_put(skb_verdict);
		bpf_prog_put(skb_parser);
	}
	return ret;
}
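
/* Variant of sock_map_link for sockets that cannot be a redirect
 * target (listening TCP and UDP sockets): only the psock and proto ops
 * are set up, no programs are attached.
 */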
static int sock_map_link_no_progs(struct bpf_map *map, struct sock *sk)
{
	struct sk_psock *psock;
	int ret;

	psock = sock_map_psock_get_checked(sk);
	if (IS_ERR(psock))
		return PTR_ERR(psock);

	if (!psock) {
		psock = sk_psock_init(sk, map->numa_node);
		if (IS_ERR(psock))
			return PTR_ERR(psock);
	}

	ret = sock_map_init_proto(sk, psock);
	if (ret < 0)
		sk_psock_put(sk, psock);
	return ret;
}

static void sock_map_free(struct bpf_map *map)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < stab->map.max_entries; i++) {
		struct sock **psk = &stab->sks[i];
		struct sock *sk;

		sk = xchg(psk, NULL);
		if (sk) {
			lock_sock(sk);
			rcu_read_lock();
			sock_map_unref(sk, psk);
			rcu_read_unlock();
			release_sock(sk);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(stab->sks);
	kfree(stab);
}

static void sock_map_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_stab, map)->progs);
}

static struct sock *__sock_map_lookup_elem(struct bpf_map *map, u32 key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	WARN_ON_ONCE(!rcu_read_lock_held());

	if (unlikely(key >= map->max_entries))
		return NULL;
	return READ_ONCE(stab->sks[key]);
}

static void *sock_map_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk || !sk_fullsock(sk))
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void *sock_map_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_map_lookup_elem(map, *(u32 *)key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	sock_gen_cookie(sk);
	return &sk->sk_cookie;
}
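
/* Common delete path: clear the slot under the stab lock and drop the
 * psock's link to it. sk_test narrows the delete to a specific socket
 * so that link removal cannot race with a slot that was since reused.
 */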
static int __sock_map_delete(struct bpf_stab *stab, struct sock *sk_test,
			     struct sock **psk)
{
	struct sock *sk;
	int err = 0;

	raw_spin_lock_bh(&stab->lock);
	sk = *psk;
	if (!sk_test || sk_test == sk)
		sk = xchg(psk, NULL);

	if (likely(sk))
		sock_map_unref(sk, psk);
	else
		err = -EINVAL;

	raw_spin_unlock_bh(&stab->lock);
	return err;
}

static void sock_map_delete_from_link(struct bpf_map *map, struct sock *sk,
				      void *link_raw)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);

	__sock_map_delete(stab, sk, link_raw);
}

static int sock_map_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = *(u32 *)key;
	struct sock **psk;

	if (unlikely(i >= map->max_entries))
		return -EINVAL;

	psk = &stab->sks[i];
	return __sock_map_delete(stab, NULL, psk);
}

static int sock_map_get_next_key(struct bpf_map *map, void *key, void *next)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	u32 i = key ? *(u32 *)key : U32_MAX;
	u32 *key_next = next;

	if (i == stab->map.max_entries - 1)
		return -ENOENT;
	if (i >= stab->map.max_entries)
		*key_next = 0;
	else
		*key_next = i + 1;
	return 0;
}

static bool sock_map_redirect_allowed(const struct sock *sk);

static int sock_map_update_common(struct bpf_map *map, u32 idx,
				  struct sock *sk, u64 flags)
{
	struct bpf_stab *stab = container_of(map, struct bpf_stab, map);
	struct sk_psock_link *link;
	struct sk_psock *psock;
	struct sock *osk;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;
	if (unlikely(idx >= map->max_entries))
		return -E2BIG;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &stab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	raw_spin_lock_bh(&stab->lock);
	osk = stab->sks[idx];
	if (osk && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!osk && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, &stab->sks[idx]);
	stab->sks[idx] = sk;
	if (osk)
		sock_map_unref(osk, &stab->sks[idx]);
	raw_spin_unlock_bh(&stab->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&stab->lock);
	if (psock)
		sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static bool sock_map_op_okay(const struct bpf_sock_ops_kern *ops)
{
	return ops->op == BPF_SOCK_OPS_PASSIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_ACTIVE_ESTABLISHED_CB ||
	       ops->op == BPF_SOCK_OPS_TCP_LISTEN_CB;
}

static bool sk_is_tcp(const struct sock *sk)
{
	return sk->sk_type == SOCK_STREAM &&
	       sk->sk_protocol == IPPROTO_TCP;
}

static bool sk_is_udp(const struct sock *sk)
{
	return sk->sk_type == SOCK_DGRAM &&
	       sk->sk_protocol == IPPROTO_UDP;
}

static bool sock_map_redirect_allowed(const struct sock *sk)
{
	return sk_is_tcp(sk) && sk->sk_state != TCP_LISTEN;
}

static bool sock_map_sk_is_suitable(const struct sock *sk)
{
	return sk_is_tcp(sk) || sk_is_udp(sk);
}

static bool sock_map_sk_state_allowed(const struct sock *sk)
{
	if (sk_is_tcp(sk))
		return (1 << sk->sk_state) & (TCPF_ESTABLISHED | TCPF_LISTEN);
	else if (sk_is_udp(sk))
		return sk_hashed(sk);

	return false;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags);

static int sock_map_update_elem(struct bpf_map *map, void *key,
				void *value, u64 flags)
{
	struct socket *sock;
	struct sock *sk;
	int ret;
	u64 ufd;

	if (map->value_size == sizeof(u64))
		ufd = *(u64 *)value;
	else
		ufd = *(u32 *)value;
	if (ufd > S32_MAX)
		return -EINVAL;

	sock = sockfd_lookup(ufd, &ret);
	if (!sock)
		return ret;
	sk = sock->sk;
	if (!sk) {
		ret = -EINVAL;
		goto out;
	}
	if (!sock_map_sk_is_suitable(sk)) {
		ret = -EOPNOTSUPP;
		goto out;
	}

	sock_map_sk_acquire(sk);
	if (!sock_map_sk_state_allowed(sk))
		ret = -EOPNOTSUPP;
	else if (map->map_type == BPF_MAP_TYPE_SOCKMAP)
		ret = sock_map_update_common(map, *(u32 *)key, sk, flags);
	else
		ret = sock_hash_update_common(map, key, sk, flags);
	sock_map_sk_release(sk);
out:
	fput(sock->file);
	return ret;
}
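
/* Example (illustrative sketch, not part of this file): from user space
 * an entry is written through the bpf(2) syscall with the socket's fd
 * as value, e.g. via libbpf:
 *
 *	__u32 key = 0;
 *	__u64 sock_fd = socket(AF_INET, SOCK_STREAM, 0);
 *
 *	// connect()/accept() so the socket reaches ESTABLISHED, then:
 *	err = bpf_map_update_elem(map_fd, &key, &sock_fd, BPF_ANY);
 *
 * The value is read as a u32 or u64 fd depending on the map's
 * value_size, which is why both sizes are accepted at creation above.
 */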
BPF_CALL_4(bpf_sock_map_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_map_update_common(map, *(u32 *)key, sops->sk,
					      flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_map_update_proto = {
	.func		= bpf_sock_map_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_map, struct sk_buff *, skb,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_map_proto = {
	.func		= bpf_sk_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};
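
/* Example (illustrative sketch): a minimal sk_skb verdict program that
 * forwards every parsed record to the socket at index 0 of an assumed
 * map named sock_map:
 *
 *	SEC("sk_skb/stream_verdict")
 *	int prog_verdict(struct __sk_buff *skb)
 *	{
 *		return bpf_sk_redirect_map(skb, &sock_map, 0, 0);
 *	}
 *
 * Any flags value other than BPF_F_INGRESS (or 0) is rejected above.
 */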
BPF_CALL_4(bpf_msg_redirect_map, struct sk_msg *, msg,
	   struct bpf_map *, map, u32, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_map_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_map_proto = {
	.func		= bpf_msg_redirect_map,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_ANYTHING,
	.arg4_type	= ARG_ANYTHING,
};

static int sock_map_btf_id;
const struct bpf_map_ops sock_map_ops = {
	.map_alloc		= sock_map_alloc,
	.map_free		= sock_map_free,
	.map_get_next_key	= sock_map_get_next_key,
	.map_lookup_elem_sys_only = sock_map_lookup_sys,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_map_delete_elem,
	.map_lookup_elem	= sock_map_lookup,
	.map_release_uref	= sock_map_release_progs,
	.map_check_btf		= map_check_no_btf,
	.map_btf_name		= "bpf_stab",
	.map_btf_id		= &sock_map_btf_id,
};
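
/* The remainder of the file implements sockhash, a hash-of-sockets map
 * that shares the psock linking logic above but keys entries by an
 * arbitrary user-defined key instead of an array index.
 */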
struct bpf_shtab_elem {
	struct rcu_head rcu;
	u32 hash;
	struct sock *sk;
	struct hlist_node node;
	u8 key[];
};

struct bpf_shtab_bucket {
	struct hlist_head head;
	raw_spinlock_t lock;
};

struct bpf_shtab {
	struct bpf_map map;
	struct bpf_shtab_bucket *buckets;
	u32 buckets_num;
	u32 elem_size;
	struct sk_psock_progs progs;
	atomic_t count;
};

static inline u32 sock_hash_bucket_hash(const void *key, u32 len)
{
	return jhash(key, len, 0);
}

static struct bpf_shtab_bucket *sock_hash_select_bucket(struct bpf_shtab *htab,
							u32 hash)
{
	return &htab->buckets[hash & (htab->buckets_num - 1)];
}

static struct bpf_shtab_elem *
sock_hash_lookup_elem_raw(struct hlist_head *head, u32 hash, void *key,
			  u32 key_size)
{
	struct bpf_shtab_elem *elem;

	hlist_for_each_entry_rcu(elem, head, node) {
		if (elem->hash == hash &&
		    !memcmp(&elem->key, key, key_size))
			return elem;
	}

	return NULL;
}

static struct sock *__sock_hash_lookup_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;

	WARN_ON_ONCE(!rcu_read_lock_held());

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);

	return elem ? elem->sk : NULL;
}

static void sock_hash_free_elem(struct bpf_shtab *htab,
				struct bpf_shtab_elem *elem)
{
	atomic_dec(&htab->count);
	kfree_rcu(elem, rcu);
}

static void sock_hash_delete_from_link(struct bpf_map *map, struct sock *sk,
				       void *link_raw)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem_probe, *elem = link_raw;
	struct bpf_shtab_bucket *bucket;

	WARN_ON_ONCE(!rcu_read_lock_held());
	bucket = sock_hash_select_bucket(htab, elem->hash);

	/* elem may be deleted in parallel from the map, but access here
	 * is okay since it's going away only after RCU grace period.
	 * However, we need to check whether it's still present.
	 */
	raw_spin_lock_bh(&bucket->lock);
	elem_probe = sock_hash_lookup_elem_raw(&bucket->head, elem->hash,
					       elem->key, map->key_size);
	if (elem_probe && elem_probe == elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
}

static int sock_hash_delete_elem(struct bpf_map *map, void *key)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 hash, key_size = map->key_size;
	struct bpf_shtab_bucket *bucket;
	struct bpf_shtab_elem *elem;
	int ret = -ENOENT;

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
		ret = 0;
	}
	raw_spin_unlock_bh(&bucket->lock);
	return ret;
}

static struct bpf_shtab_elem *sock_hash_alloc_elem(struct bpf_shtab *htab,
						   void *key, u32 key_size,
						   u32 hash, struct sock *sk,
						   struct bpf_shtab_elem *old)
{
	struct bpf_shtab_elem *new;

	if (atomic_inc_return(&htab->count) > htab->map.max_entries) {
		if (!old) {
			atomic_dec(&htab->count);
			return ERR_PTR(-E2BIG);
		}
	}

	new = kmalloc_node(htab->elem_size, GFP_ATOMIC | __GFP_NOWARN,
			   htab->map.numa_node);
	if (!new) {
		atomic_dec(&htab->count);
		return ERR_PTR(-ENOMEM);
	}
	memcpy(new->key, key, key_size);
	new->sk = sk;
	new->hash = hash;
	return new;
}

static int sock_hash_update_common(struct bpf_map *map, void *key,
				   struct sock *sk, u64 flags)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	u32 key_size = map->key_size, hash;
	struct bpf_shtab_elem *elem, *elem_new;
	struct bpf_shtab_bucket *bucket;
	struct sk_psock_link *link;
	struct sk_psock *psock;
	int ret;

	WARN_ON_ONCE(!rcu_read_lock_held());
	if (unlikely(flags > BPF_EXIST))
		return -EINVAL;

	link = sk_psock_init_link();
	if (!link)
		return -ENOMEM;

	/* Only sockets we can redirect into/from in BPF need to hold
	 * refs to parser/verdict progs and have their sk_data_ready
	 * and sk_write_space callbacks overridden.
	 */
	if (sock_map_redirect_allowed(sk))
		ret = sock_map_link(map, &htab->progs, sk);
	else
		ret = sock_map_link_no_progs(map, sk);
	if (ret < 0)
		goto out_free;

	psock = sk_psock(sk);
	WARN_ON_ONCE(!psock);

	hash = sock_hash_bucket_hash(key, key_size);
	bucket = sock_hash_select_bucket(htab, hash);

	raw_spin_lock_bh(&bucket->lock);
	elem = sock_hash_lookup_elem_raw(&bucket->head, hash, key, key_size);
	if (elem && flags == BPF_NOEXIST) {
		ret = -EEXIST;
		goto out_unlock;
	} else if (!elem && flags == BPF_EXIST) {
		ret = -ENOENT;
		goto out_unlock;
	}

	elem_new = sock_hash_alloc_elem(htab, key, key_size, hash, sk, elem);
	if (IS_ERR(elem_new)) {
		ret = PTR_ERR(elem_new);
		goto out_unlock;
	}

	sock_map_add_link(psock, link, map, elem_new);
	/* Add new element to the head of the list, so that
	 * concurrent search will find it before old elem.
	 */
	hlist_add_head_rcu(&elem_new->node, &bucket->head);
	if (elem) {
		hlist_del_rcu(&elem->node);
		sock_map_unref(elem->sk, elem);
		sock_hash_free_elem(htab, elem);
	}
	raw_spin_unlock_bh(&bucket->lock);
	return 0;
out_unlock:
	raw_spin_unlock_bh(&bucket->lock);
	sk_psock_put(sk, psock);
out_free:
	sk_psock_free_link(link);
	return ret;
}

static int sock_hash_get_next_key(struct bpf_map *map, void *key,
				  void *key_next)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_elem *elem, *elem_next;
	u32 hash, key_size = map->key_size;
	struct hlist_head *head;
	int i = 0;

	if (!key)
		goto find_first_elem;
	hash = sock_hash_bucket_hash(key, key_size);
	head = &sock_hash_select_bucket(htab, hash)->head;
	elem = sock_hash_lookup_elem_raw(head, hash, key, key_size);
	if (!elem)
		goto find_first_elem;

	elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_next_rcu(&elem->node)),
				     struct bpf_shtab_elem, node);
	if (elem_next) {
		memcpy(key_next, elem_next->key, key_size);
		return 0;
	}

	i = hash & (htab->buckets_num - 1);
	i++;
find_first_elem:
	for (; i < htab->buckets_num; i++) {
		head = &sock_hash_select_bucket(htab, i)->head;
		elem_next = hlist_entry_safe(rcu_dereference_raw(hlist_first_rcu(head)),
					     struct bpf_shtab_elem, node);
		if (elem_next) {
			memcpy(key_next, elem_next->key, key_size);
			return 0;
		}
	}

	return -ENOENT;
}

static struct bpf_map *sock_hash_alloc(union bpf_attr *attr)
{
	struct bpf_shtab *htab;
	int i, err;
	u64 cost;

	if (!capable(CAP_NET_ADMIN))
		return ERR_PTR(-EPERM);
	if (attr->max_entries == 0 ||
	    attr->key_size    == 0 ||
	    (attr->value_size != sizeof(u32) &&
	     attr->value_size != sizeof(u64)) ||
	    attr->map_flags & ~SOCK_CREATE_FLAG_MASK)
		return ERR_PTR(-EINVAL);
	if (attr->key_size > MAX_BPF_STACK)
		return ERR_PTR(-E2BIG);

	htab = kzalloc(sizeof(*htab), GFP_USER);
	if (!htab)
		return ERR_PTR(-ENOMEM);

	bpf_map_init_from_attr(&htab->map, attr);

	htab->buckets_num = roundup_pow_of_two(htab->map.max_entries);
	htab->elem_size = sizeof(struct bpf_shtab_elem) +
			  round_up(htab->map.key_size, 8);
	if (htab->buckets_num == 0 ||
	    htab->buckets_num > U32_MAX / sizeof(struct bpf_shtab_bucket)) {
		err = -EINVAL;
		goto free_htab;
	}

	cost = (u64) htab->buckets_num * sizeof(struct bpf_shtab_bucket) +
	       (u64) htab->elem_size * htab->map.max_entries;
	if (cost >= U32_MAX - PAGE_SIZE) {
		err = -EINVAL;
		goto free_htab;
	}
	err = bpf_map_charge_init(&htab->map.memory, cost);
	if (err)
		goto free_htab;

	htab->buckets = bpf_map_area_alloc(htab->buckets_num *
					   sizeof(struct bpf_shtab_bucket),
					   htab->map.numa_node);
	if (!htab->buckets) {
		bpf_map_charge_finish(&htab->map.memory);
		err = -ENOMEM;
		goto free_htab;
	}

	for (i = 0; i < htab->buckets_num; i++) {
		INIT_HLIST_HEAD(&htab->buckets[i].head);
		raw_spin_lock_init(&htab->buckets[i].lock);
	}

	return &htab->map;
free_htab:
	kfree(htab);
	return ERR_PTR(err);
}
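
/* Teardown mirrors sock_map_free: after synchronize_rcu() no updates or
 * deletes race with us, but link deletion from the psock side still
 * can, so each bucket is drained under its lock before its elements are
 * freed outside of atomic context.
 */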
static void sock_hash_free(struct bpf_map *map)
{
	struct bpf_shtab *htab = container_of(map, struct bpf_shtab, map);
	struct bpf_shtab_bucket *bucket;
	struct hlist_head unlink_list;
	struct bpf_shtab_elem *elem;
	struct hlist_node *node;
	int i;

	/* After the sync no updates or deletes will be in-flight so it
	 * is safe to walk map and remove entries without risking a race
	 * in EEXIST update case.
	 */
	synchronize_rcu();
	for (i = 0; i < htab->buckets_num; i++) {
		bucket = sock_hash_select_bucket(htab, i);

		/* We are racing with sock_hash_delete_from_link to
		 * enter the spin-lock critical section. Every socket on
		 * the list is still linked to sockhash. Since link
		 * exists, psock exists and holds a ref to socket. That
		 * lets us grab a socket ref too.
		 */
		raw_spin_lock_bh(&bucket->lock);
		hlist_for_each_entry(elem, &bucket->head, node)
			sock_hold(elem->sk);
		hlist_move_list(&bucket->head, &unlink_list);
		raw_spin_unlock_bh(&bucket->lock);

		/* Process removed entries out of atomic context to
		 * block for socket lock before deleting the psock's
		 * link to sockhash.
		 */
		hlist_for_each_entry_safe(elem, node, &unlink_list, node) {
			hlist_del(&elem->node);
			lock_sock(elem->sk);
			rcu_read_lock();
			sock_map_unref(elem->sk, elem);
			rcu_read_unlock();
			release_sock(elem->sk);
			sock_put(elem->sk);
			sock_hash_free_elem(htab, elem);
		}
	}

	/* wait for psock readers accessing its map link */
	synchronize_rcu();

	bpf_map_area_free(htab->buckets);
	kfree(htab);
}

static void *sock_hash_lookup_sys(struct bpf_map *map, void *key)
{
	struct sock *sk;

	if (map->value_size != sizeof(u64))
		return ERR_PTR(-ENOSPC);

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk)
		return ERR_PTR(-ENOENT);

	sock_gen_cookie(sk);
	return &sk->sk_cookie;
}

static void *sock_hash_lookup(struct bpf_map *map, void *key)
{
	struct sock *sk;

	sk = __sock_hash_lookup_elem(map, key);
	if (!sk || !sk_fullsock(sk))
		return NULL;
	if (sk_is_refcounted(sk) && !refcount_inc_not_zero(&sk->sk_refcnt))
		return NULL;
	return sk;
}

static void sock_hash_release_progs(struct bpf_map *map)
{
	psock_progs_drop(&container_of(map, struct bpf_shtab, map)->progs);
}

BPF_CALL_4(bpf_sock_hash_update, struct bpf_sock_ops_kern *, sops,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	WARN_ON_ONCE(!rcu_read_lock_held());

	if (likely(sock_map_sk_is_suitable(sops->sk) &&
		   sock_map_op_okay(sops)))
		return sock_hash_update_common(map, key, sops->sk, flags);
	return -EOPNOTSUPP;
}

const struct bpf_func_proto bpf_sock_hash_update_proto = {
	.func		= bpf_sock_hash_update,
	.gpl_only	= false,
	.pkt_access	= true,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_sk_redirect_hash, struct sk_buff *, skb,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct tcp_skb_cb *tcb = TCP_SKB_CB(skb);
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	tcb->bpf.flags = flags;
	tcb->bpf.sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_sk_redirect_hash_proto = {
	.func		= bpf_sk_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

BPF_CALL_4(bpf_msg_redirect_hash, struct sk_msg *, msg,
	   struct bpf_map *, map, void *, key, u64, flags)
{
	struct sock *sk;

	if (unlikely(flags & ~(BPF_F_INGRESS)))
		return SK_DROP;

	sk = __sock_hash_lookup_elem(map, key);
	if (unlikely(!sk || !sock_map_redirect_allowed(sk)))
		return SK_DROP;

	msg->flags = flags;
	msg->sk_redir = sk;
	return SK_PASS;
}

const struct bpf_func_proto bpf_msg_redirect_hash_proto = {
	.func		= bpf_msg_redirect_hash,
	.gpl_only	= false,
	.ret_type	= RET_INTEGER,
	.arg1_type	= ARG_PTR_TO_CTX,
	.arg2_type	= ARG_CONST_MAP_PTR,
	.arg3_type	= ARG_PTR_TO_MAP_KEY,
	.arg4_type	= ARG_ANYTHING,
};

static int sock_hash_map_btf_id;
const struct bpf_map_ops sock_hash_ops = {
	.map_alloc		= sock_hash_alloc,
	.map_free		= sock_hash_free,
	.map_get_next_key	= sock_hash_get_next_key,
	.map_update_elem	= sock_map_update_elem,
	.map_delete_elem	= sock_hash_delete_elem,
	.map_lookup_elem	= sock_hash_lookup,
	.map_lookup_elem_sys_only = sock_hash_lookup_sys,
	.map_release_uref	= sock_hash_release_progs,
	.map_check_btf		= map_check_no_btf,
	.map_btf_name		= "bpf_shtab",
	.map_btf_id		= &sock_hash_map_btf_id,
};
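
/* Attach-point plumbing shared by sockmap and sockhash: a bpf(2) attach
 * type is translated to the corresponding program slot in the map's
 * sk_psock_progs.
 */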
static struct sk_psock_progs *sock_map_progs(struct bpf_map *map)
{
	switch (map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return &container_of(map, struct bpf_stab, map)->progs;
	case BPF_MAP_TYPE_SOCKHASH:
		return &container_of(map, struct bpf_shtab, map)->progs;
	default:
		break;
	}

	return NULL;
}

int sock_map_prog_update(struct bpf_map *map, struct bpf_prog *prog,
			 struct bpf_prog *old, u32 which)
{
	struct sk_psock_progs *progs = sock_map_progs(map);
	struct bpf_prog **pprog;

	if (!progs)
		return -EOPNOTSUPP;

	switch (which) {
	case BPF_SK_MSG_VERDICT:
		pprog = &progs->msg_parser;
		break;
	case BPF_SK_SKB_STREAM_PARSER:
		pprog = &progs->skb_parser;
		break;
	case BPF_SK_SKB_STREAM_VERDICT:
		pprog = &progs->skb_verdict;
		break;
	default:
		return -EOPNOTSUPP;
	}

	if (old)
		return psock_replace_prog(pprog, prog, old);

	psock_set_prog(pprog, prog);
	return 0;
}
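
/* Unlink a psock from a single map entry; called for every link when a
 * socket leaves all maps it was added to.
 */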
static void sock_map_unlink(struct sock *sk, struct sk_psock_link *link)
{
	switch (link->map->map_type) {
	case BPF_MAP_TYPE_SOCKMAP:
		return sock_map_delete_from_link(link->map, sk,
						 link->link_raw);
	case BPF_MAP_TYPE_SOCKHASH:
		return sock_hash_delete_from_link(link->map, sk,
						  link->link_raw);
	default:
		break;
	}
}

static void sock_map_remove_links(struct sock *sk, struct sk_psock *psock)
{
	struct sk_psock_link *link;

	while ((link = sk_psock_link_pop(psock))) {
		sock_map_unlink(sk, link);
		sk_psock_free_link(link);
	}
}

void sock_map_unhash(struct sock *sk)
{
	void (*saved_unhash)(struct sock *sk);
	struct sk_psock *psock;

	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		if (sk->sk_prot->unhash)
			sk->sk_prot->unhash(sk);
		return;
	}

	saved_unhash = psock->saved_unhash;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	saved_unhash(sk);
}
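
/* sock_map_close replaces sk_prot->close for sockets held in a map:
 * drop all of the socket's map links first, then hand off to the
 * protocol's saved close routine.
 */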
void sock_map_close(struct sock *sk, long timeout)
{
	void (*saved_close)(struct sock *sk, long timeout);
	struct sk_psock *psock;

	lock_sock(sk);
	rcu_read_lock();
	psock = sk_psock(sk);
	if (unlikely(!psock)) {
		rcu_read_unlock();
		release_sock(sk);
		return sk->sk_prot->close(sk, timeout);
	}

	saved_close = psock->saved_close;
	sock_map_remove_links(sk, psock);
	rcu_read_unlock();
	release_sock(sk);
	saved_close(sk, timeout);
}