OSDN Git Service

sctp: add dif and sdif check in asoc and ep lookup
authorXin Long <lucien.xin@gmail.com>
Wed, 16 Nov 2022 20:01:20 +0000 (15:01 -0500)
committerDavid S. Miller <davem@davemloft.net>
Fri, 18 Nov 2022 11:42:54 +0000 (11:42 +0000)
This patch at first adds a pernet global l3mdev_accept to decide if it
accepts the packets from a l3mdev when a SCTP socket doesn't bind to
any interface. It's set to 1 to avoid any possible incompatible issue,
and in next patch, a sysctl will be introduced to allow to change it.

Then similar to inet/udp_sk_bound_dev_eq(), sctp_sk_bound_dev_eq() is
added to check either dif or sdif is equal to sk_bound_dev_if, and to
check sid is 0 or l3mdev_accept is 1 if sk_bound_dev_if is not set.
This function is used to match a association or a endpoint, namely
called by sctp_addrs_lookup_transport() and sctp_endpoint_is_match().
All functions that needs updating are:

sctp_rcv():
  asoc:
  __sctp_rcv_lookup()
    __sctp_lookup_association() -> sctp_addrs_lookup_transport()
    __sctp_rcv_lookup_harder()
      __sctp_rcv_init_lookup()
         __sctp_lookup_association() -> sctp_addrs_lookup_transport()
      __sctp_rcv_walk_lookup()
         __sctp_rcv_asconf_lookup()
           __sctp_lookup_association() -> sctp_addrs_lookup_transport()

  ep:
  __sctp_rcv_lookup_endpoint() -> sctp_endpoint_is_match()

sctp_connect():
  sctp_endpoint_is_peeled_off()
    __sctp_lookup_association()
      sctp_has_association()
        sctp_lookup_association()
          __sctp_lookup_association() -> sctp_addrs_lookup_transport()

sctp_diag_dump_one():
  sctp_transport_lookup_process() -> sctp_addrs_lookup_transport()

Signed-off-by: Xin Long <lucien.xin@gmail.com>
Signed-off-by: David S. Miller <davem@davemloft.net>
include/net/netns/sctp.h
include/net/sctp/sctp.h
include/net/sctp/structs.h
net/sctp/diag.c
net/sctp/endpointola.c
net/sctp/input.c
net/sctp/protocol.c
net/sctp/socket.c

index a681147..7eff3d9 100644 (file)
@@ -175,6 +175,10 @@ struct netns_sctp {
 
        /* Threshold for autoclose timeout, in seconds. */
        unsigned long max_autoclose;
+
+#ifdef CONFIG_NET_L3_MASTER_DEV
+       int l3mdev_accept;
+#endif
 };
 
 #endif /* __NETNS_SCTP_H__ */
index 01d904b..c335dd0 100644 (file)
@@ -109,7 +109,7 @@ struct sctp_transport *sctp_transport_get_idx(struct net *net,
                        struct rhashtable_iter *iter, int pos);
 int sctp_transport_lookup_process(sctp_callback_t cb, struct net *net,
                                  const union sctp_addr *laddr,
-                                 const union sctp_addr *paddr, void *p);
+                                 const union sctp_addr *paddr, void *p, int dif);
 int sctp_transport_traverse_process(sctp_callback_t cb, sctp_callback_t cb_done,
                                    struct net *net, int *pos, void *p);
 int sctp_for_each_endpoint(int (*cb)(struct sctp_endpoint *, void *), void *p);
@@ -157,10 +157,12 @@ void sctp_unhash_transport(struct sctp_transport *t);
 struct sctp_transport *sctp_addrs_lookup_transport(
                                struct net *net,
                                const union sctp_addr *laddr,
-                               const union sctp_addr *paddr);
+                               const union sctp_addr *paddr,
+                               int dif, int sdif);
 struct sctp_transport *sctp_epaddr_lookup_transport(
                                const struct sctp_endpoint *ep,
                                const union sctp_addr *paddr);
+bool sctp_sk_bound_dev_eq(struct net *net, int bound_dev_if, int dif, int sdif);
 
 /*
  * sctp/proc.c
index 7b4884c..afa3781 100644 (file)
@@ -1379,10 +1379,12 @@ struct sctp_association *sctp_endpoint_lookup_assoc(
        struct sctp_transport **);
 bool sctp_endpoint_is_peeled_off(struct sctp_endpoint *ep,
                                 const union sctp_addr *paddr);
-struct sctp_endpoint *sctp_endpoint_is_match(struct sctp_endpoint *,
-                                       struct net *, const union sctp_addr *);
+struct sctp_endpoint *sctp_endpoint_is_match(struct sctp_endpoint *ep,
+                                            struct net *net,
+                                            const union sctp_addr *laddr,
+                                            int dif, int sdif);
 bool sctp_has_association(struct net *net, const union sctp_addr *laddr,
-                         const union sctp_addr *paddr);
+                         const union sctp_addr *paddr, int dif, int sdif);
 
 int sctp_verify_init(struct net *net, const struct sctp_endpoint *ep,
                     const struct sctp_association *asoc,
index d9c6d8f..a557009 100644 (file)
@@ -426,6 +426,7 @@ static int sctp_diag_dump_one(struct netlink_callback *cb,
        struct net *net = sock_net(skb->sk);
        const struct nlmsghdr *nlh = cb->nlh;
        union sctp_addr laddr, paddr;
+       int dif = req->id.idiag_if;
        struct sctp_comm_param commp = {
                .skb = skb,
                .r = req,
@@ -454,7 +455,7 @@ static int sctp_diag_dump_one(struct netlink_callback *cb,
        }
 
        return sctp_transport_lookup_process(sctp_sock_dump_one,
-                                            net, &laddr, &paddr, &commp);
+                                            net, &laddr, &paddr, &commp, dif);
 }
 
 static void sctp_diag_dump(struct sk_buff *skb, struct netlink_callback *cb,
index efffde7..7e77b45 100644 (file)
@@ -246,12 +246,15 @@ void sctp_endpoint_put(struct sctp_endpoint *ep)
 /* Is this the endpoint we are looking for?  */
 struct sctp_endpoint *sctp_endpoint_is_match(struct sctp_endpoint *ep,
                                               struct net *net,
-                                              const union sctp_addr *laddr)
+                                              const union sctp_addr *laddr,
+                                              int dif, int sdif)
 {
+       int bound_dev_if = READ_ONCE(ep->base.sk->sk_bound_dev_if);
        struct sctp_endpoint *retval = NULL;
 
-       if ((htons(ep->base.bind_addr.port) == laddr->v4.sin_port) &&
-           net_eq(ep->base.net, net)) {
+       if (net_eq(ep->base.net, net) &&
+           sctp_sk_bound_dev_eq(net, bound_dev_if, dif, sdif) &&
+           (htons(ep->base.bind_addr.port) == laddr->v4.sin_port)) {
                if (sctp_bind_addr_match(&ep->base.bind_addr, laddr,
                                         sctp_sk(ep->base.sk)))
                        retval = ep;
@@ -298,6 +301,7 @@ out:
 bool sctp_endpoint_is_peeled_off(struct sctp_endpoint *ep,
                                 const union sctp_addr *paddr)
 {
+       int bound_dev_if = READ_ONCE(ep->base.sk->sk_bound_dev_if);
        struct sctp_sockaddr_entry *addr;
        struct net *net = ep->base.net;
        struct sctp_bind_addr *bp;
@@ -307,7 +311,8 @@ bool sctp_endpoint_is_peeled_off(struct sctp_endpoint *ep,
         * so the address_list can not change.
         */
        list_for_each_entry(addr, &bp->address_list, list) {
-               if (sctp_has_association(net, &addr->a, paddr))
+               if (sctp_has_association(net, &addr->a, paddr,
+                                        bound_dev_if, bound_dev_if))
                        return true;
        }
 
index 4f43afa..bf70371 100644 (file)
@@ -50,16 +50,19 @@ static struct sctp_association *__sctp_rcv_lookup(struct net *net,
                                      struct sk_buff *skb,
                                      const union sctp_addr *paddr,
                                      const union sctp_addr *laddr,
-                                     struct sctp_transport **transportp);
+                                     struct sctp_transport **transportp,
+                                     int dif, int sdif);
 static struct sctp_endpoint *__sctp_rcv_lookup_endpoint(
                                        struct net *net, struct sk_buff *skb,
                                        const union sctp_addr *laddr,
-                                       const union sctp_addr *daddr);
+                                       const union sctp_addr *daddr,
+                                       int dif, int sdif);
 static struct sctp_association *__sctp_lookup_association(
                                        struct net *net,
                                        const union sctp_addr *local,
                                        const union sctp_addr *peer,
-                                       struct sctp_transport **pt);
+                                       struct sctp_transport **pt,
+                                       int dif, int sdif);
 
 static int sctp_add_backlog(struct sock *sk, struct sk_buff *skb);
 
@@ -92,11 +95,11 @@ int sctp_rcv(struct sk_buff *skb)
        struct sctp_chunk *chunk;
        union sctp_addr src;
        union sctp_addr dest;
-       int bound_dev_if;
        int family;
        struct sctp_af *af;
        struct net *net = dev_net(skb->dev);
        bool is_gso = skb_is_gso(skb) && skb_is_gso_sctp(skb);
+       int dif, sdif;
 
        if (skb->pkt_type != PACKET_HOST)
                goto discard_it;
@@ -141,6 +144,8 @@ int sctp_rcv(struct sk_buff *skb)
        /* Initialize local addresses for lookups. */
        af->from_skb(&src, skb, 1);
        af->from_skb(&dest, skb, 0);
+       dif = af->skb_iif(skb);
+       sdif = af->skb_sdif(skb);
 
        /* If the packet is to or from a non-unicast address,
         * silently discard the packet.
@@ -157,36 +162,16 @@ int sctp_rcv(struct sk_buff *skb)
            !af->addr_valid(&dest, NULL, skb))
                goto discard_it;
 
-       asoc = __sctp_rcv_lookup(net, skb, &src, &dest, &transport);
+       asoc = __sctp_rcv_lookup(net, skb, &src, &dest, &transport, dif, sdif);
 
        if (!asoc)
-               ep = __sctp_rcv_lookup_endpoint(net, skb, &dest, &src);
+               ep = __sctp_rcv_lookup_endpoint(net, skb, &dest, &src, dif, sdif);
 
        /* Retrieve the common input handling substructure. */
        rcvr = asoc ? &asoc->base : &ep->base;
        sk = rcvr->sk;
 
        /*
-        * If a frame arrives on an interface and the receiving socket is
-        * bound to another interface, via SO_BINDTODEVICE, treat it as OOTB
-        */
-       bound_dev_if = READ_ONCE(sk->sk_bound_dev_if);
-       if (bound_dev_if && (bound_dev_if != af->skb_iif(skb))) {
-               if (transport) {
-                       sctp_transport_put(transport);
-                       asoc = NULL;
-                       transport = NULL;
-               } else {
-                       sctp_endpoint_put(ep);
-                       ep = NULL;
-               }
-               sk = net->sctp.ctl_sock;
-               ep = sctp_sk(sk)->ep;
-               sctp_endpoint_hold(ep);
-               rcvr = &ep->base;
-       }
-
-       /*
         * RFC 2960, 8.4 - Handle "Out of the blue" Packets.
         * An SCTP packet is called an "out of the blue" (OOTB)
         * packet if it is correctly formed, i.e., passed the
@@ -485,6 +470,8 @@ struct sock *sctp_err_lookup(struct net *net, int family, struct sk_buff *skb,
        struct sctp_association *asoc;
        struct sctp_transport *transport = NULL;
        __u32 vtag = ntohl(sctphdr->vtag);
+       int sdif = inet_sdif(skb);
+       int dif = inet_iif(skb);
 
        *app = NULL; *tpp = NULL;
 
@@ -500,7 +487,7 @@ struct sock *sctp_err_lookup(struct net *net, int family, struct sk_buff *skb,
        /* Look for an association that matches the incoming ICMP error
         * packet.
         */
-       asoc = __sctp_lookup_association(net, &saddr, &daddr, &transport);
+       asoc = __sctp_lookup_association(net, &saddr, &daddr, &transport, dif, sdif);
        if (!asoc)
                return NULL;
 
@@ -850,7 +837,8 @@ static inline __u32 sctp_hashfn(const struct net *net, __be16 lport,
 static struct sctp_endpoint *__sctp_rcv_lookup_endpoint(
                                        struct net *net, struct sk_buff *skb,
                                        const union sctp_addr *laddr,
-                                       const union sctp_addr *paddr)
+                                       const union sctp_addr *paddr,
+                                       int dif, int sdif)
 {
        struct sctp_hashbucket *head;
        struct sctp_endpoint *ep;
@@ -863,7 +851,7 @@ static struct sctp_endpoint *__sctp_rcv_lookup_endpoint(
        head = &sctp_ep_hashtable[hash];
        read_lock(&head->lock);
        sctp_for_each_hentry(ep, &head->chain) {
-               if (sctp_endpoint_is_match(ep, net, laddr))
+               if (sctp_endpoint_is_match(ep, net, laddr, dif, sdif))
                        goto hit;
        }
 
@@ -990,14 +978,26 @@ void sctp_unhash_transport(struct sctp_transport *t)
                        sctp_hash_params);
 }
 
+bool sctp_sk_bound_dev_eq(struct net *net, int bound_dev_if, int dif, int sdif)
+{
+       bool l3mdev_accept = true;
+
+#if IS_ENABLED(CONFIG_NET_L3_MASTER_DEV)
+       l3mdev_accept = !!READ_ONCE(net->sctp.l3mdev_accept);
+#endif
+       return inet_bound_dev_eq(l3mdev_accept, bound_dev_if, dif, sdif);
+}
+
 /* return a transport with holding it */
 struct sctp_transport *sctp_addrs_lookup_transport(
                                struct net *net,
                                const union sctp_addr *laddr,
-                               const union sctp_addr *paddr)
+                               const union sctp_addr *paddr,
+                               int dif, int sdif)
 {
        struct rhlist_head *tmp, *list;
        struct sctp_transport *t;
+       int bound_dev_if;
        struct sctp_hash_cmp_arg arg = {
                .paddr = paddr,
                .net   = net,
@@ -1011,7 +1011,9 @@ struct sctp_transport *sctp_addrs_lookup_transport(
                if (!sctp_transport_hold(t))
                        continue;
 
-               if (sctp_bind_addr_match(&t->asoc->base.bind_addr,
+               bound_dev_if = READ_ONCE(t->asoc->base.sk->sk_bound_dev_if);
+               if (sctp_sk_bound_dev_eq(net, bound_dev_if, dif, sdif) &&
+                   sctp_bind_addr_match(&t->asoc->base.bind_addr,
                                         laddr, sctp_sk(t->asoc->base.sk)))
                        return t;
                sctp_transport_put(t);
@@ -1048,12 +1050,13 @@ static struct sctp_association *__sctp_lookup_association(
                                        struct net *net,
                                        const union sctp_addr *local,
                                        const union sctp_addr *peer,
-                                       struct sctp_transport **pt)
+                                       struct sctp_transport **pt,
+                                       int dif, int sdif)
 {
        struct sctp_transport *t;
        struct sctp_association *asoc = NULL;
 
-       t = sctp_addrs_lookup_transport(net, local, peer);
+       t = sctp_addrs_lookup_transport(net, local, peer, dif, sdif);
        if (!t)
                goto out;
 
@@ -1069,12 +1072,13 @@ static
 struct sctp_association *sctp_lookup_association(struct net *net,
                                                 const union sctp_addr *laddr,
                                                 const union sctp_addr *paddr,
-                                                struct sctp_transport **transportp)
+                                                struct sctp_transport **transportp,
+                                                int dif, int sdif)
 {
        struct sctp_association *asoc;
 
        rcu_read_lock();
-       asoc = __sctp_lookup_association(net, laddr, paddr, transportp);
+       asoc = __sctp_lookup_association(net, laddr, paddr, transportp, dif, sdif);
        rcu_read_unlock();
 
        return asoc;
@@ -1083,11 +1087,12 @@ struct sctp_association *sctp_lookup_association(struct net *net,
 /* Is there an association matching the given local and peer addresses? */
 bool sctp_has_association(struct net *net,
                          const union sctp_addr *laddr,
-                         const union sctp_addr *paddr)
+                         const union sctp_addr *paddr,
+                         int dif, int sdif)
 {
        struct sctp_transport *transport;
 
-       if (sctp_lookup_association(net, laddr, paddr, &transport)) {
+       if (sctp_lookup_association(net, laddr, paddr, &transport, dif, sdif)) {
                sctp_transport_put(transport);
                return true;
        }
@@ -1115,7 +1120,8 @@ bool sctp_has_association(struct net *net,
  */
 static struct sctp_association *__sctp_rcv_init_lookup(struct net *net,
        struct sk_buff *skb,
-       const union sctp_addr *laddr, struct sctp_transport **transportp)
+       const union sctp_addr *laddr, struct sctp_transport **transportp,
+       int dif, int sdif)
 {
        struct sctp_association *asoc;
        union sctp_addr addr;
@@ -1154,7 +1160,7 @@ static struct sctp_association *__sctp_rcv_init_lookup(struct net *net,
                if (!af->from_addr_param(paddr, params.addr, sh->source, 0))
                        continue;
 
-               asoc = __sctp_lookup_association(net, laddr, paddr, transportp);
+               asoc = __sctp_lookup_association(net, laddr, paddr, transportp, dif, sdif);
                if (asoc)
                        return asoc;
        }
@@ -1181,7 +1187,8 @@ static struct sctp_association *__sctp_rcv_asconf_lookup(
                                        struct sctp_chunkhdr *ch,
                                        const union sctp_addr *laddr,
                                        __be16 peer_port,
-                                       struct sctp_transport **transportp)
+                                       struct sctp_transport **transportp,
+                                       int dif, int sdif)
 {
        struct sctp_addip_chunk *asconf = (struct sctp_addip_chunk *)ch;
        struct sctp_af *af;
@@ -1201,7 +1208,7 @@ static struct sctp_association *__sctp_rcv_asconf_lookup(
        if (!af->from_addr_param(&paddr, param, peer_port, 0))
                return NULL;
 
-       return __sctp_lookup_association(net, laddr, &paddr, transportp);
+       return __sctp_lookup_association(net, laddr, &paddr, transportp, dif, sdif);
 }
 
 
@@ -1217,7 +1224,8 @@ static struct sctp_association *__sctp_rcv_asconf_lookup(
 static struct sctp_association *__sctp_rcv_walk_lookup(struct net *net,
                                      struct sk_buff *skb,
                                      const union sctp_addr *laddr,
-                                     struct sctp_transport **transportp)
+                                     struct sctp_transport **transportp,
+                                     int dif, int sdif)
 {
        struct sctp_association *asoc = NULL;
        struct sctp_chunkhdr *ch;
@@ -1260,7 +1268,7 @@ static struct sctp_association *__sctp_rcv_walk_lookup(struct net *net,
                                asoc = __sctp_rcv_asconf_lookup(
                                                net, ch, laddr,
                                                sctp_hdr(skb)->source,
-                                               transportp);
+                                               transportp, dif, sdif);
                        break;
                default:
                        break;
@@ -1285,7 +1293,8 @@ static struct sctp_association *__sctp_rcv_walk_lookup(struct net *net,
 static struct sctp_association *__sctp_rcv_lookup_harder(struct net *net,
                                      struct sk_buff *skb,
                                      const union sctp_addr *laddr,
-                                     struct sctp_transport **transportp)
+                                     struct sctp_transport **transportp,
+                                     int dif, int sdif)
 {
        struct sctp_chunkhdr *ch;
 
@@ -1309,9 +1318,9 @@ static struct sctp_association *__sctp_rcv_lookup_harder(struct net *net,
 
        /* If this is INIT/INIT-ACK look inside the chunk too. */
        if (ch->type == SCTP_CID_INIT || ch->type == SCTP_CID_INIT_ACK)
-               return __sctp_rcv_init_lookup(net, skb, laddr, transportp);
+               return __sctp_rcv_init_lookup(net, skb, laddr, transportp, dif, sdif);
 
-       return __sctp_rcv_walk_lookup(net, skb, laddr, transportp);
+       return __sctp_rcv_walk_lookup(net, skb, laddr, transportp, dif, sdif);
 }
 
 /* Lookup an association for an inbound skb. */
@@ -1319,11 +1328,12 @@ static struct sctp_association *__sctp_rcv_lookup(struct net *net,
                                      struct sk_buff *skb,
                                      const union sctp_addr *paddr,
                                      const union sctp_addr *laddr,
-                                     struct sctp_transport **transportp)
+                                     struct sctp_transport **transportp,
+                                     int dif, int sdif)
 {
        struct sctp_association *asoc;
 
-       asoc = __sctp_lookup_association(net, laddr, paddr, transportp);
+       asoc = __sctp_lookup_association(net, laddr, paddr, transportp, dif, sdif);
        if (asoc)
                goto out;
 
@@ -1331,7 +1341,7 @@ static struct sctp_association *__sctp_rcv_lookup(struct net *net,
         * SCTP Implementors Guide, 2.18 Handling of address
         * parameters within the INIT or INIT-ACK.
         */
-       asoc = __sctp_rcv_lookup_harder(net, skb, laddr, transportp);
+       asoc = __sctp_rcv_lookup_harder(net, skb, laddr, transportp, dif, sdif);
        if (asoc)
                goto out;
 
index a18cf04..909a89a 100644 (file)
@@ -1394,6 +1394,10 @@ static int __net_init sctp_defaults_init(struct net *net)
        /* Initialize maximum autoclose timeout. */
        net->sctp.max_autoclose         = INT_MAX / HZ;
 
+#ifdef CONFIG_NET_L3_MASTER_DEV
+       net->sctp.l3mdev_accept = 1;
+#endif
+
        status = sctp_sysctl_net_register(net);
        if (status)
                goto err_sysctl_register;
index 4306164..5acbdf0 100644 (file)
@@ -5315,14 +5315,14 @@ EXPORT_SYMBOL_GPL(sctp_for_each_endpoint);
 
 int sctp_transport_lookup_process(sctp_callback_t cb, struct net *net,
                                  const union sctp_addr *laddr,
-                                 const union sctp_addr *paddr, void *p)
+                                 const union sctp_addr *paddr, void *p, int dif)
 {
        struct sctp_transport *transport;
        struct sctp_endpoint *ep;
        int err = -ENOENT;
 
        rcu_read_lock();
-       transport = sctp_addrs_lookup_transport(net, laddr, paddr);
+       transport = sctp_addrs_lookup_transport(net, laddr, paddr, dif, dif);
        if (!transport) {
                rcu_read_unlock();
                return err;