From 460fefc3d5298acaec30377aef26b9f06472d53d Mon Sep 17 00:00:00 2001 From: Lorenzo Colitti Date: Tue, 8 Apr 2014 15:45:08 +0900 Subject: [PATCH] Refactor setup and common code into a superclass. Change-Id: Iee489954175de6eec12b711d6c3ebb9a64cfd6c3 --- tests/net_test/mark_test.py | 198 +++++++++++++++++++++++--------------------- 1 file changed, 102 insertions(+), 96 deletions(-) diff --git a/tests/net_test/mark_test.py b/tests/net_test/mark_test.py index 48c9e983..a5cd120d 100755 --- a/tests/net_test/mark_test.py +++ b/tests/net_test/mark_test.py @@ -213,10 +213,7 @@ class RunAsUid(object): os.seteuid(self.saved_uid) -class MarkTest(net_test.NetworkTest): - - # How many times to run packet reflection tests. - ITERATIONS = 5 +class MultiNetworkTest(net_test.NetworkTest): # Must be between 1 and 256, since we put them in MAC addresses and IIDs. NETIDS = [100, 150, 200, 250] @@ -224,18 +221,12 @@ class MarkTest(net_test.NetworkTest): # Stores sysctl values to write back when the test completes. saved_sysctls = {} - # For convenience. - IPV4_ADDR = net_test.IPV4_ADDR - IPV6_ADDR = net_test.IPV6_ADDR - IPV4_PING = net_test.IPV4_PING - IPV6_PING = net_test.IPV6_PING - # Wether to output setup commands. DEBUG = False @staticmethod - def _GetInterfaceName(netid): - return "nettest%d" % netid + def UidForNetid(netid): + return 2000 + netid @classmethod def _TableForNetid(cls, netid): @@ -245,15 +236,15 @@ class MarkTest(net_test.NetworkTest): return netid @staticmethod - def _UidForNetid(netid): - return 2000 + netid + def GetInterfaceName(netid): + return "nettest%d" % netid @staticmethod - def _RouterMacAddress(netid): + def RouterMacAddress(netid): return "02:00:00:00:%02x:00" % netid @staticmethod - def _MyMacAddress(netid): + def MyMacAddress(netid): return "02:00:00:00:%02x:01" % netid @staticmethod @@ -271,22 +262,22 @@ class MarkTest(net_test.NetworkTest): @classmethod def _MyIPv6Address(cls, netid): - return net_test.GetLinkAddress(cls._GetInterfaceName(netid), False) + return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False) @classmethod - def _MyAddress(cls, version, netid): + def MyAddress(cls, version, netid): return {4: cls._MyIPv4Address(netid), 6: cls._MyIPv6Address(netid)}[version] @classmethod def _CreateTunInterface(cls, netid): - iface = cls._GetInterfaceName(netid) + iface = cls.GetInterfaceName(netid) f = open("/dev/net/tun", "r+b") ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI) ifr += "\x00" * (40 - len(ifr)) fcntl.ioctl(f, TUNSETIFF, ifr) # Give ourselves a predictable MAC address. - net_test.SetInterfaceHWAddr(iface, cls._MyMacAddress(netid)) + net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid)) # Disable DAD so we don't have to wait for it. open("/proc/sys/net/ipv6/conf/%s/dad_transmits" % iface, "w").write("0") net_test.SetInterfaceUp(iface) @@ -297,7 +288,7 @@ class MarkTest(net_test.NetworkTest): def _SendRA(cls, netid): validity = 300 # seconds validity_ms = validity * 1000 # milliseconds - macaddr = cls._RouterMacAddress(netid) + macaddr = cls.RouterMacAddress(netid) lladdr = cls._RouterAddress(netid, 6) # We don't want any routes in the main table. If the kernel doesn't support @@ -335,7 +326,7 @@ class MarkTest(net_test.NetworkTest): for version, iptables in zip([4, 6], ["iptables", "ip6tables"]): table = cls._TableForNetid(netid) - uid = cls._UidForNetid(netid) + uid = cls.UidForNetid(netid) if HAVE_EXPERIMENTAL_UID_ROUTING: cls.iproute.UidRule(version, is_add, uid, table, priority=100) cls.iproute.FwmarkRule(version, is_add, netid, table, priority=200) @@ -358,10 +349,10 @@ class MarkTest(net_test.NetworkTest): cmds = str("\n".join(cmds) % { "add_del": "add" if is_add else "del", "append_delete": "-A" if is_add else "-D", - "iface": cls._GetInterfaceName(netid), + "iface": cls.GetInterfaceName(netid), "iptables": iptables, "ipv4addr": cls._MyIPv4Address(netid), - "macaddr": cls._RouterMacAddress(netid), + "macaddr": cls.RouterMacAddress(netid), "mark": netid, "router": cls._RouterAddress(netid, version), "table": table, @@ -375,16 +366,16 @@ class MarkTest(net_test.NetworkTest): raise ConfigurationError("Setup command failed: %s" % " ".join(cmd)) @classmethod - def _GetSysctl(cls, sysctl): + def GetSysctl(cls, sysctl): return open(sysctl, "r").read() @classmethod - def _SetSysctl(cls, sysctl, value): + def SetSysctl(cls, sysctl, value): # Only save each sysctl value the first time we set it. This is so we can # set it to arbitrary values multiple times and still write it back # correctly at the end. if sysctl not in cls.saved_sysctls: - cls.saved_sysctls[sysctl] = cls._GetSysctl(sysctl) + cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl) open(sysctl, "w").write(str(value) + "\n") @classmethod @@ -399,21 +390,7 @@ class MarkTest(net_test.NetworkTest): @classmethod def _SetICMPRatelimit(cls, version, limit): - cls._SetSysctl(cls._ICMPRatelimitFilename(version), limit) - - @classmethod - def _SetMarkReflectSysctls(cls, value): - cls._SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value) - try: - cls._SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value) - except IOError: - # This does not exist if we use the version of the patch that uses a - # common sysctl for IPv4 and IPv6. - pass - - @classmethod - def _SetTCPMarkAcceptSysctl(cls, value): - cls._SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value) + cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit) @classmethod def setUpClass(cls): @@ -424,7 +401,7 @@ class MarkTest(net_test.NetworkTest): cls.tuns = {} cls.ifindices = {} if HAVE_AUTOCONF_TABLE: - cls._SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000) + cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000) cls.AUTOCONF_TABLE_OFFSET = -1000 else: cls.AUTOCONF_TABLE_OFFSET = None @@ -436,12 +413,58 @@ class MarkTest(net_test.NetworkTest): for netid in cls.NETIDS: cls.tuns[netid] = cls._CreateTunInterface(netid) - iface = cls._GetInterfaceName(netid) + iface = cls.GetInterfaceName(netid) cls.ifindices[netid] = net_test.GetInterfaceIndex(iface) cls._SendRA(netid) cls._RunSetupCommands(netid, True) + # Uncomment to look around at interface and rule configuration while + # running in the background. (Once the test finishes running, all the + # interfaces and rules are gone.) + # time.sleep(30) + + @classmethod + def tearDownClass(cls): + for netid in cls.tuns: + cls._RunSetupCommands(netid, False) + cls.tuns[netid].close() + cls._RestoreSysctls() + + def SetSocketMark(self, s, netid): + s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid) + + def GetSocketMark(self, s): + return s.getsockopt(SOL_SOCKET, net_test.SO_MARK) + + def ReceivePacketOn(self, netid, ip_packet): + routermac = self.RouterMacAddress(netid) + mymac = self.MyMacAddress(netid) + packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet + posix.write(self.tuns[netid].fileno(), str(packet)) + + def ClearTunQueues(self): + # Keep reading packets on all netids until we get no packets on any of them. + waiting = None + while waiting != 0: + waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS) + + +class MarkTest(MultiNetworkTest): + + # How many times to run packet reflection tests. + ITERATIONS = 5 + + # For convenience. + IPV4_ADDR = net_test.IPV4_ADDR + IPV6_ADDR = net_test.IPV6_ADDR + IPV4_PING = net_test.IPV4_PING + IPV6_PING = net_test.IPV6_PING + + @classmethod + def setUpClass(cls): + super(MarkTest, cls).setUpClass() + # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it # will accept both IPv4 and IPv6 connections. We do this here instead of in # each test so we can use the same socket every time. That way, if a kernel @@ -454,17 +477,19 @@ class MarkTest(net_test.NetworkTest): cls.listensocket.bind(("::", cls.listenport)) cls.listensocket.listen(100) - # Uncomment to look around at interface and rule configuration while - # running in the background. (Once the test finishes running, all the - # interfaces and rules are gone.) - # time.sleep(30) + @classmethod + def _SetMarkReflectSysctls(cls, value): + cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value) + try: + cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value) + except IOError: + # This does not exist if we use the version of the patch that uses a + # common sysctl for IPv4 and IPv6. + pass @classmethod - def tearDownClass(cls): - for netid in cls.tuns: - cls._RunSetupCommands(netid, False) - cls.tuns[netid].close() - cls._RestoreSysctls() + def _SetTCPMarkAcceptSysctl(cls, value): + cls.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value) def assertPacketMatches(self, expected, actual): # The expected packet is just a rough sketch of the packet we expect to @@ -578,37 +603,18 @@ class MarkTest(net_test.NetworkTest): raise UnexpectedPacketError( "%s: diff with last packet:\n%s" % (msg, e.message)) - def ReceivePacketOn(self, netid, ip_packet): - routermac = self._RouterMacAddress(netid) - mymac = self._MyMacAddress(netid) - packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet - posix.write(self.tuns[netid].fileno(), str(packet)) - - def ClearTunQueues(self): - # Keep reading packets on all netids until we get no packets on any of them. - waiting = None - while waiting != 0: - waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS) - def setUp(self): self.ClearTunQueues() - @classmethod - def _GetRemoteAddress(cls, version): - return {4: cls.IPV4_ADDR, 6: cls.IPV6_ADDR}[version] - - def SetSocketMark(self, s, netid): - s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid) - - def GetSocketMark(self, s): - return s.getsockopt(SOL_SOCKET, net_test.SO_MARK) + def _GetRemoteAddress(self, version): + return {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version] - def GetProtocolFamily(self, version): + def _GetProtocolFamily(self, version): return {4: AF_INET, 6: AF_INET6}[version] def BuildSocket(self, version, constructor, mark, uid): with RunAsUid(uid): - family = self.GetProtocolFamily(version) + family = self._GetProtocolFamily(version) s = constructor(family) if mark: self.SetSocketMark(s, mark) @@ -618,7 +624,7 @@ class MarkTest(net_test.NetworkTest): expected_netid): s = self.BuildSocket(version, net_test.PingSocket, mark, uid) - myaddr = self._MyAddress(version, expected_netid) + myaddr = self.MyAddress(version, expected_netid) s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1) s.bind((myaddr, PING_IDENT)) net_test.SetSocketTos(s, PING_TOS) @@ -628,7 +634,7 @@ class MarkTest(net_test.NetworkTest): self.ClearTunQueues() s.sendto(packet + PING_PAYLOAD, (dstaddr, 19321)) msg = "IPv%d ping: expected %s on %s" % ( - version, desc, self._GetInterfaceName(expected_netid)) + version, desc, self.GetInterfaceName(expected_netid)) self.ExpectPacketOn(expected_netid, msg, expected) def CheckTCPSYNPacket(self, version, mark, uid, dstaddr, expected_netid): @@ -636,7 +642,7 @@ class MarkTest(net_test.NetworkTest): if version == 6 and dstaddr.startswith("::ffff"): version = 4 - myaddr = self._MyAddress(version, expected_netid) + myaddr = self.MyAddress(version, expected_netid) desc, expected = Packets.SYN(53, version, myaddr, dstaddr, sport=None, seq=None) @@ -644,7 +650,7 @@ class MarkTest(net_test.NetworkTest): # Non-blocking TCP connects always return EINPROGRESS. self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53)) msg = "IPv%s TCP connect: expected %s on %s" % ( - version, desc, self._GetInterfaceName(expected_netid)) + version, desc, self.GetInterfaceName(expected_netid)) self.ExpectPacketOn(expected_netid, msg, expected) s.close() @@ -653,10 +659,10 @@ class MarkTest(net_test.NetworkTest): if version == 6 and dstaddr.startswith("::ffff"): version = 4 - myaddr = self._MyAddress(version, expected_netid) + myaddr = self.MyAddress(version, expected_netid) desc, expected = Packets.UDP(version, myaddr, dstaddr, sport=None) msg = "IPv%s UDP %%s: expected %s on %s" % ( - version, desc, self._GetInterfaceName(expected_netid)) + version, desc, self.GetInterfaceName(expected_netid)) self.ClearTunQueues() s.sendto(UDP_PAYLOAD, (dstaddr, 53)) @@ -690,18 +696,18 @@ class MarkTest(net_test.NetworkTest): """Checks that UID routing selects the right outgoing interface.""" for _ in xrange(self.ITERATIONS): for netid in self.tuns: - uid = self._UidForNetid(netid) + uid = self.UidForNetid(netid) self.CheckPingPacket(4, 0, uid, self.IPV4_ADDR, self.IPV4_PING, netid) self.CheckPingPacket(6, 0, uid, self.IPV6_ADDR, self.IPV6_PING, netid) for netid in self.tuns: - uid = self._UidForNetid(netid) + uid = self.UidForNetid(netid) self.CheckTCPSYNPacket(4, 0, uid, self.IPV4_ADDR, netid) self.CheckTCPSYNPacket(6, 0, uid, self.IPV6_ADDR, netid) self.CheckTCPSYNPacket(6, 0, uid, "::ffff:" + self.IPV4_ADDR, netid) for netid in self.tuns: - uid = self._UidForNetid(netid) + uid = self.UidForNetid(netid) self.CheckUDPPacket(4, 0, uid, self.IPV4_ADDR, netid) self.CheckUDPPacket(6, 0, uid, self.IPV6_ADDR, netid) self.CheckUDPPacket(6, 0, uid, "::ffff:" + self.IPV4_ADDR, netid) @@ -738,14 +744,14 @@ class MarkTest(net_test.NetworkTest): # Check packets addressed to the IP addresses of all our interfaces... for dest_ip_netid in self.tuns: - dest_ip_iface = self._GetInterfaceName(dest_ip_netid) + dest_ip_iface = self.GetInterfaceName(dest_ip_netid) - myaddr = self._MyAddress(version, dest_ip_netid) + myaddr = self.MyAddress(version, dest_ip_netid) remote_addr = self._GetRemoteAddress(version) # ... coming in on all our interfaces... for iif_netid in self.tuns: - iif = self._GetInterfaceName(iif_netid) + iif = self.GetInterfaceName(iif_netid) desc, packet = packet_generator(version, remote_addr, myaddr) reply_desc, reply = reply_generator(version, myaddr, remote_addr, packet) @@ -799,8 +805,8 @@ class MarkTest(net_test.NetworkTest): def testIPv6RSTsReflectMark(self): self.CheckReflection(6, self.SYNToClosedPort, Packets.RST, "reflect") - def CheckAcceptedSocketMarkCallback(self, netid, version, myaddr, - remote_addr, packet, reply, msg): + def CheckTCPConnection(self, netid, version, myaddr, remote_addr, + packet, reply, msg): establishing_ack = Packets.ACK(version, remote_addr, myaddr, reply)[1] self.ReceivePacketOn(netid, establishing_ack) s, unused_peer = self.listensocket.accept() @@ -829,25 +835,25 @@ class MarkTest(net_test.NetworkTest): @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported") def testIPv4TCPConnections(self): self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK, "accept", - self.CheckAcceptedSocketMarkCallback) + self.CheckTCPConnection) @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported") def testIPv6TCPConnections(self): self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK, "accept", - self.CheckAcceptedSocketMarkCallback) + self.CheckTCPConnection) @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported") def testTCPConnectionsWithSynCookies(self): # Force SYN cookies on all connections. - self._SetSysctl(SYNCOOKIES_SYSCTL, 2) + self.SetSysctl(SYNCOOKIES_SYSCTL, 2) try: self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK, "accept", - self.CheckAcceptedSocketMarkCallback) + self.CheckTCPConnection) self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK, "accept", - self.CheckAcceptedSocketMarkCallback) + self.CheckTCPConnection) finally: # Stop forcing SYN cookies on all connections. - self._SetSysctl(SYNCOOKIES_SYSCTL, 1) + self.SetSysctl(SYNCOOKIES_SYSCTL, 1) if __name__ == "__main__": -- 2.11.0