From 90d3fc65dc3160a269a4d31312db90b0c33b5705 Mon Sep 17 00:00:00 2001 From: Lorenzo Colitti Date: Wed, 23 Apr 2014 17:36:05 +0900 Subject: [PATCH] Simplify putting sockets onto networks. Change-Id: Ibc82cdf3c8dd80f8bcab84b5a76f1e4d36069c89 --- tests/net_test/mark_test.py | 192 +++++++++++++++++++++++--------------------- 1 file changed, 101 insertions(+), 91 deletions(-) diff --git a/tests/net_test/mark_test.py b/tests/net_test/mark_test.py index 594ff049..18baf348 100755 --- a/tests/net_test/mark_test.py +++ b/tests/net_test/mark_test.py @@ -483,6 +483,8 @@ class MultiNetworkTest(net_test.NetworkTest): cls._RestoreSysctls() def SetSocketMark(self, s, netid): + if netid is None: + netid = 0 s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid) def GetSocketMark(self, s): @@ -496,19 +498,15 @@ class MultiNetworkTest(net_test.NetworkTest): iface = "" s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface) - def SetUnicastInterface(self, version, s, iface): - if iface: - ifindex = net_test.GetInterfaceIndex(iface) - else: - ifindex = 0 - # Otherwise, Python apparently thinks it's a 1-byte option. + def SetUnicastInterface(self, s, ifindex): + # Otherwise, Python thinks it's a 1-byte option. ifindex = struct.pack("!I", ifindex) - layer, opt = { - 4: (net_test.SOL_IP, IP_UNICAST_IF), - 6: (net_test.SOL_IPV6, IPV6_UNICAST_IF), - }[version] - s.setsockopt(layer, opt, ifindex) + # Always set the IPv4 interface, because it will be used even on IPv6 + # sockets if the destination address is a mapped address. + s.setsockopt(net_test.SOL_IP, IP_UNICAST_IF, ifindex) + if s.family == AF_INET6: + s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex) def ReceiveEtherPacketOn(self, netid, packet): posix.write(self.tuns[netid].fileno(), str(packet)) @@ -698,71 +696,81 @@ class MarkTest(MultiNetworkTest): def _GetRemoteAddress(self, version): return {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version] - def BuildSocket(self, version, constructor, mark, uid, oif, ucast_oif): + def SelectInterface(self, s, netid, mode): + if mode == "uid": + raise ValueError("Can't change UID on an existing socket") + elif mode == "mark": + self.SetSocketMark(s, netid) + elif mode == "oif": + iface = self.GetInterfaceName(netid) if netid else "" + self.BindToDevice(s, iface) + elif mode == "ucast_oif": + self.SetUnicastInterface(s, self.ifindices.get(netid, 0)) + else: + raise ValueError("Unkown interface selection mode %s" % mode) + + def BuildSocket(self, version, constructor, netid, routing_mode): + uid = self.UidForNetid(netid) if routing_mode == "uid" else None with RunAsUid(uid): family = self.GetProtocolFamily(version) s = constructor(family) - if mark: - self.SetSocketMark(s, mark) - if oif: - self.BindToDevice(s, oif) - if ucast_oif: - self.SetUnicastInterface(version, s, ucast_oif) + + if routing_mode not in [None, "uid"]: + self.SelectInterface(s, netid, routing_mode) + return s - def CheckPingPacket(self, version, mark, uid, oif, ucast_oif, dstaddr, packet, - expected_netid): - s = self.BuildSocket(version, net_test.PingSocket, mark, uid, oif, - ucast_oif) + def CheckPingPacket(self, version, netid, routing_mode, dstaddr, packet): + s = self.BuildSocket(version, net_test.PingSocket, netid, routing_mode) - myaddr = self.MyAddress(version, expected_netid) + myaddr = self.MyAddress(version, netid) s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) s.bind((myaddr, PING_IDENT)) net_test.SetSocketTos(s, PING_TOS) desc, expected = Packets.ICMPEcho(version, myaddr, dstaddr) msg = "IPv%d ping: expected %s on %s" % ( - version, desc, self.GetInterfaceName(expected_netid)) + version, desc, self.GetInterfaceName(netid)) s.sendto(packet + PING_PAYLOAD, (dstaddr, 19321)) - self.ExpectPacketOn(expected_netid, msg, expected) + self.ExpectPacketOn(netid, msg, expected) - def CheckTCPSYNPacket(self, version, mark, uid, oif, ucast_oif, dstaddr, - expected_netid): - s = self.BuildSocket(version, net_test.TCPSocket, mark, uid, oif, ucast_oif) + def CheckTCPSYNPacket(self, version, netid, routing_mode, dstaddr): + s = self.BuildSocket(version, net_test.TCPSocket, netid, routing_mode) if version == 6 and dstaddr.startswith("::ffff"): version = 4 - myaddr = self.MyAddress(version, expected_netid) + myaddr = self.MyAddress(version, netid) desc, expected = Packets.SYN(53, version, myaddr, dstaddr, sport=None, seq=None) # Non-blocking TCP connects always return EINPROGRESS. self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53)) msg = "IPv%s TCP connect: expected %s on %s" % ( - version, desc, self.GetInterfaceName(expected_netid)) - self.ExpectPacketOn(expected_netid, msg, expected) + version, desc, self.GetInterfaceName(netid)) + self.ExpectPacketOn(netid, msg, expected) s.close() - def CheckUDPPacket(self, version, mark, uid, oif, ucast_oif, - dstaddr, expected_netid): - s = self.BuildSocket(version, net_test.UDPSocket, mark, uid, oif, ucast_oif) + def CheckUDPPacket(self, version, netid, routing_mode, dstaddr): + s = self.BuildSocket(version, net_test.UDPSocket, netid, routing_mode) if version == 6 and dstaddr.startswith("::ffff"): version = 4 - myaddr = self.MyAddress(version, expected_netid) + myaddr = self.MyAddress(version, netid) desc, expected = Packets.UDP(version, myaddr, dstaddr, sport=None) msg = "IPv%s UDP %%s: expected %s on %s" % ( - version, desc, self.GetInterfaceName(expected_netid)) + version, desc, self.GetInterfaceName(netid)) s.sendto(UDP_PAYLOAD, (dstaddr, 53)) - self.ExpectPacketOn(expected_netid, msg % "sendto", expected) + self.ExpectPacketOn(netid, msg % "sendto", expected) - s.connect((dstaddr, 53)) - s.send(UDP_PAYLOAD) - self.ExpectPacketOn(expected_netid, msg % "connect/send", expected) - s.close() + # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP. + if routing_mode != "ucast_oif": + s.connect((dstaddr, 53)) + s.send(UDP_PAYLOAD) + self.ExpectPacketOn(netid, msg % "connect/send", expected) + s.close() - def CheckOutgoingPackets(self, mode): + def CheckOutgoingPackets(self, routing_mode): v4addr = self.IPV4_ADDR v6addr = self.IPV6_ADDR v4mapped = "::ffff:" + v4addr @@ -770,35 +778,18 @@ class MarkTest(MultiNetworkTest): for _ in xrange(self.ITERATIONS): for netid in self.tuns: - mark = uid = oif = ucast_oif = None - if mode == "mark": - mark = netid - elif mode == "uid": - uid = self.UidForNetid(netid) - elif mode == "oif": - oif = self.GetInterfaceName(netid) - elif mode == "ucast_oif": - ucast_oif = self.GetInterfaceName(netid) - else: - raise ValueError("Unkown routing mode %s" % mode) - - self.CheckPingPacket(4, mark, uid, oif, ucast_oif, v4addr, - self.IPV4_PING, netid) - self.CheckPingPacket(6, mark, uid, oif, ucast_oif, v6addr, - self.IPV6_PING, netid) - - # TCP doesn't seem to honour IP_UNICAST_IF. - if mode != "ucast_oif": - self.CheckTCPSYNPacket(4, mark, uid, oif, ucast_oif, v4addr, netid) - self.CheckTCPSYNPacket(6, mark, uid, oif, ucast_oif, v6addr, netid) - self.CheckTCPSYNPacket(6, mark, uid, oif, ucast_oif, v4mapped, netid) - - if mode != "ucast_oif": - # This doesn't work. - self.CheckUDPPacket(4, mark, uid, oif, ucast_oif, v4addr, netid) - # These work, but the source addresses are incorrect. - self.CheckUDPPacket(6, mark, uid, oif, ucast_oif, v6addr, netid) - self.CheckUDPPacket(6, mark, uid, oif, ucast_oif, v4mapped, netid) + self.CheckPingPacket(4, netid, routing_mode, v4addr, self.IPV4_PING) + self.CheckPingPacket(6, netid, routing_mode, v6addr, self.IPV6_PING) + + # IP_UNICAST_IF doesn't seem to work on connected sockets, so no TCP. + if routing_mode != "ucast_oif": + self.CheckTCPSYNPacket(4, netid, routing_mode, v4addr) + self.CheckTCPSYNPacket(6, netid, routing_mode, v6addr) + self.CheckTCPSYNPacket(6, netid, routing_mode, v4mapped) + + self.CheckUDPPacket(4, netid, routing_mode, v4addr) + self.CheckUDPPacket(6, netid, routing_mode, v6addr) + self.CheckUDPPacket(6, netid, routing_mode, v4mapped) def testMarkRouting(self): """Checks that socket marking selects the right outgoing interface.""" @@ -817,34 +808,53 @@ class MarkTest(MultiNetworkTest): """Checks that ucast oif routing selects the right outgoing interface.""" self.CheckOutgoingPackets("ucast_oif") - def CheckRemarking(self, version): - s = net_test.UDPSocket(self.GetProtocolFamily(version)) + def CheckRemarking(self, version, use_connect): + # Remarking or resetting UNICAST_IF on connected sockets does not work. + if use_connect: + modes = ["oif"] + else: + modes = ["mark", "oif", "ucast_oif"] - # Figure out what packets to expect. - unspec = {4: "0.0.0.0", 6: "::"}[version] - sport = Packets.RandomPort() - s.bind((unspec, sport)) - dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version] - desc, expected = Packets.UDP(version, unspec, dstaddr, sport) + for mode in modes: + s = net_test.UDPSocket(self.GetProtocolFamily(version)) - # For each netid, set that netid's mark on the socket without closing it, - # and check that the packets sent on that socket go out on the right - # network. - for netid in self.tuns: - self.SetSocketMark(s, netid) - expected.src = self.MyAddress(version, netid) - s.sendto("hello", (dstaddr, 53)) - msg = "Remarked UDPv%d socket: expecting %s on %s" % ( - version, desc, self.GetInterfaceName(netid)) - self.ExpectPacketOn(netid, msg, expected) + # Figure out what packets to expect. + unspec = {4: "0.0.0.0", 6: "::"}[version] + sport = Packets.RandomPort() + s.bind((unspec, sport)) + dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version] + desc, expected = Packets.UDP(version, unspec, dstaddr, sport) + + # If we're testing connected sockets, connect the socket on the first + # netid now. + if use_connect: + netid = self.tuns.keys()[0] + self.SelectInterface(s, netid, mode) + s.connect((dstaddr, 53)) + expected.src = self.MyAddress(version, netid) + + # For each netid, select that network without closing the socket, and + # check that the packets sent on that socket go out on the right network. + for netid in self.tuns: + self.SelectInterface(s, netid, mode) + if not use_connect: + expected.src = self.MyAddress(version, netid) + s.sendto("hello", (dstaddr, 53)) + connected_str = "Connected" if use_connect else "Unconnected" + msg = "%s UDPv%d socket remarked using %s: expecting %s on %s" % ( + connected_str, version, mode, desc, self.GetInterfaceName(netid)) + self.ExpectPacketOn(netid, msg, expected) + self.SelectInterface(s, None, mode) def testIPv4Remarking(self): """Checks that updating the mark on an IPv4 socket changes routing.""" - self.CheckRemarking(4) + self.CheckRemarking(4, False) + self.CheckRemarking(4, True) def testIPv6Remarking(self): """Checks that updating the mark on an IPv6 socket changes routing.""" - self.CheckRemarking(6) + self.CheckRemarking(6, False) + self.CheckRemarking(6, True) def CheckReflection(self, version, packet_generator, reply_generator, mark_behaviour, callback=None): -- 2.11.0