OSDN Git Service

Test mark-based routing for outgoing packets.
[android-x86/system-extras.git] / tests / net_test / mark_test.py
1 #!/usr/bin/python
2
3 import fcntl
4 import errno
5 import os
6 import posix
7 import struct
8 import time
9 import unittest
10 from scapy import all as scapy
11 from socket import *
12
13 import net_test
14
# When True, each setup command is printed before it is run.
DEBUG = False

# Interface flags and ioctl request number used when configuring the
# /dev/net/tun device in MarkTest._CreateTunInterface.
IFF_TUN = 0x0001
IFF_TAP = 0x0002
IFF_NO_PI = 0x1000
TUNSETIFF = 0x400454ca

# Sysctl holding the routing-table offset for IPv6 autoconf routes. It may not
# exist on every kernel; if writing it fails the tests install routes manually.
AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/route/autoconf_table_offset"
23
class ConfigurationError(AssertionError):
  """Raised when a command that sets up or tears down the test network fails."""
26
27
class UnexpectedPacketError(AssertionError):
  """Raised when a packet is found on an interface that should have none."""
30
31
class Packets(object):
  """Factory for the scapy packets the tests inject and expect to see.

  Every public generator returns a (description, packet) tuple: description
  is a short label used in failure messages, and packet is a scapy packet
  starting at the IP layer.
  """

  @staticmethod
  def _GetIpLayer(version):
    """Returns the scapy IP layer class for IP version 4 or 6."""
    return {4: scapy.IP, 6: scapy.IPv6}[version]

  @staticmethod
  def _GetPacketTos(packet):
    """Returns the IPv4 ToS or IPv6 traffic class of an IP-layer packet.

    Raises:
      ValueError: if packet is neither an IPv4 nor an IPv6 packet.
    """
    if isinstance(packet, scapy.IPv6):
      return packet.tc
    elif isinstance(packet, scapy.IP):
      return packet.tos
    else:
      raise ValueError("Can't find ToS Field")

  @staticmethod
  def _SetPacketTos(packet, tos):
    """Sets the IPv4 ToS or IPv6 traffic class of an IP-layer packet.

    Raises:
      ValueError: if packet is neither an IPv4 nor an IPv6 packet.
    """
    if isinstance(packet, scapy.IPv6):
      packet.tc = tos
    elif isinstance(packet, scapy.IP):
      packet.tos = tos
    else:
      raise ValueError("Can't find ToS Field")

  @classmethod
  def UdpPacket(cls, version, srcaddr, dstaddr):
    """Returns a small UDP datagram from port 999 to port 1234."""
    ip = cls._GetIpLayer(version)
    return ("UDPv%d packet" % version,
            ip(src=srcaddr, dst=dstaddr) /
            scapy.UDP(sport=999, dport=1234) / "hello")

  @classmethod
  def SYN(cls, port, version, srcaddr, dstaddr):
    """Returns a TCP SYN (flags=2) from port 50999 to the given port."""
    ip = cls._GetIpLayer(version)
    return ("TCP SYN",
            ip(src=srcaddr, dst=dstaddr) /
            scapy.TCP(sport=50999, dport=port, seq=1692871236, ack=0,
                      flags=2, window=14400))

  @classmethod
  def RST(cls, version, srcaddr, dstaddr, packet):
    """Returns the RST+ACK expected in response to the SYN in packet.

    seq and window are None because the kernel picks them; the packet
    comparison treats None as "don't care".
    """
    ip = cls._GetIpLayer(version)
    original = packet.getlayer("TCP")
    return ("TCP RST",
            ip(src=srcaddr, dst=dstaddr) /
            scapy.TCP(sport=original.dport, dport=original.sport,
                      ack=original.seq + 1, seq=None,
                      flags=20, window=None))  # 20 = RST (0x04) | ACK (0x10)

  @classmethod
  def SYNACK(cls, version, srcaddr, dstaddr, packet):
    """Returns the SYN+ACK expected in response to the SYN in packet.

    seq and window are None because the kernel picks them.
    """
    ip = cls._GetIpLayer(version)
    original = packet.getlayer("TCP")
    return ("TCP SYN+ACK",
            ip(src=srcaddr, dst=dstaddr) /
            scapy.TCP(sport=original.dport, dport=original.sport,
                      ack=original.seq + 1, seq=None,
                      flags=18, window=None))  # 18 = SYN (0x02) | ACK (0x10)

  @classmethod
  def ICMPPortUnreachable(cls, version, srcaddr, dstaddr, packet):
    """Returns the ICMP port unreachable error that quotes packet."""
    if version == 4:
      # Linux hardcodes the ToS on ICMP errors to 0xc0 or greater because of
      # RFC 1812 4.3.2.5 (!).
      return ("ICMPv4 port unreachable",
              scapy.IP(src=srcaddr, dst=dstaddr, proto=1, tos=0xc0) /
              scapy.ICMPerror(type=3, code=3) / packet)
    else:
      return ("ICMPv6 port unreachable",
              scapy.IPv6(src=srcaddr, dst=dstaddr) /
              scapy.ICMPv6DestUnreach(code=4) / packet)

  @classmethod
  def ICMPEcho(cls, version, srcaddr, dstaddr):
    """Returns an ICMP echo request with a distinctive ToS of 0x83."""
    ip = cls._GetIpLayer(version)
    icmp = {4: scapy.ICMP, 6: scapy.ICMPv6EchoRequest}[version]
    packet = (ip(src=srcaddr, dst=dstaddr) /
              icmp(id=0xff19, seq=3) / "foobarbaz")
    cls._SetPacketTos(packet, 0x83)
    return ("ICMPv%d echo" % version, packet)

  @classmethod
  def ICMPReply(cls, version, srcaddr, dstaddr, packet, tos=None):
    """Returns the echo reply expected in response to an ICMPEcho request.

    Args:
      version: 4 or 6.
      srcaddr: source address of the reply (the request's destination).
      dstaddr: destination address of the reply (the request's source).
      packet: the echo request being answered, as built by ICMPEcho.
      tos: ToS / traffic class expected on the reply. If None, it is copied
        from packet, because replies are expected to reflect the request's
        ToS.

    Returns:
      A (description, packet) tuple.
    """
    ip = cls._GetIpLayer(version)
    if tos is None:
      tos = cls._GetPacketTos(packet)

    # Scapy doesn't provide an ICMPv4 echo reply constructor.
    icmpv4_reply = lambda **kwargs: scapy.ICMP(type=0, **kwargs)
    icmp = {4: icmpv4_reply, 6: scapy.ICMPv6EchoReply}[version]
    # id, seq and payload match the values ICMPEcho always uses.
    reply = (ip(src=srcaddr, dst=dstaddr) /
             icmp(id=0xff19, seq=3) / "foobarbaz")
    cls._SetPacketTos(reply, tos)
    return ("ICMPv%d echo reply" % version, reply)
115
116
class MarkTest(net_test.NetworkTest):
  """Checks that packet marks select the interface packets are routed on.

  setUpClass creates one tap interface per entry in NETIDS and installs:
    - an iptables mangle INPUT rule that marks packets arriving on that
      interface with its netid, and
    - an "ip rule" that looks up packets carrying that mark in the netid's
      routing table.
  Tests inject frames with ReceivePacketOn() and read what the kernel sends
  back out with ExpectPacketOn().
  """

  NETIDS = [100, 200]

  # Append our context message to unittest's standard failure text (e.g. the
  # hex diff in CheckExpectedPacket) instead of replacing it. Python 2.7
  # defaults this to False.
  longMessage = True

  @staticmethod
  def _RouterMacAddress(netid):
    """Returns the MAC address of the simulated router on netid."""
    return "02:00:00:00:%02x:00" % netid

  @staticmethod
  def _MyMacAddress(netid):
    """Returns the MAC address given to our own interface on netid."""
    return "02:00:00:00:%02x:01" % netid

  @staticmethod
  def _RouterAddress(netid, version):
    """Returns the simulated router's IP address on netid.

    IPv6 uses a link-local address; IPv4 uses 10.0.<netid>.1.

    Raises:
      ValueError: if version is neither 4 nor 6.
    """
    if version == 6:
      return "fe80::%02x00" % netid
    elif version == 4:
      return "10.0.%d.1" % netid
    else:
      raise ValueError("Don't support IPv%s" % version)

  @staticmethod
  def _MyIPv4Address(netid):
    """Returns our own IPv4 address on netid."""
    return "10.0.%d.2" % netid

  @classmethod
  def _CreateTunInterface(cls, netid):
    """Creates a tap interface for netid, configures it and brings it up.

    Args:
      netid: the netid whose interface (see _GetInterfaceName) to create.

    Returns:
      The non-blocking file object for the tap device. Writing to it injects
      Ethernet frames into the interface; reading it returns frames the
      kernel sent out of the interface.
    """
    iface = cls._GetInterfaceName(netid)
    f = open("/dev/net/tun", "r+b")
    try:
      # ifreq: interface name and flags, zero-padded to 40 bytes.
      ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
      ifr += "\x00" * (40 - len(ifr))
      fcntl.ioctl(f, TUNSETIFF, ifr)
      # Give ourselves a predictable MAC address.
      net_test.SetInterfaceHWAddr(iface, cls._MyMacAddress(netid))
      # Disable DAD so we don't have to wait for it. Closing the file
      # explicitly reports a failed write here instead of at GC time.
      dad_sysctl = "/proc/sys/net/ipv6/conf/%s/dad_transmits" % iface
      with open(dad_sysctl, "w") as sysctl:
        sysctl.write("0")
      net_test.SetInterfaceUp(iface)
      net_test.SetNonBlocking(f)
    except Exception:
      # Don't leak the tun fd if configuration fails.
      f.close()
      raise
    return f

  @staticmethod
  def _GetInterfaceName(netid):
    """Returns the name of the tap interface for netid."""
    return "nettest%d" % netid

  @classmethod
  def _SendRA(cls, netid):
    """Injects a router advertisement from the simulated router on netid.

    The RA advertises the router as a default router and 2001:db8:<netid>::/64
    as an on-link (L) prefix usable for autoconfiguration (A), all valid for
    300 seconds.
    """
    validity = 300                 # seconds
    validity_ms = validity * 1000  # milliseconds
    macaddr = cls._RouterMacAddress(netid)
    lladdr = cls._RouterAddress(netid, 6)
    ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
          scapy.IPv6(src=lladdr, hlim=255) /
          scapy.ICMPv6ND_RA(retranstimer=validity_ms,
                            routerlifetime=validity) /
          scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
          scapy.ICMPv6NDOptPrefixInfo(prefix="2001:db8:%d::" % netid,
                                      prefixlen=64,
                                      L=1, A=1,
                                      validlifetime=validity,
                                      preferredlifetime=validity))
    posix.write(cls.tuns[netid].fileno(), str(ra))

  # Commands run for both IP versions: mark incoming packets with the netid
  # and route packets with that mark via the netid's table.
  COMMANDS = [
      "/sbin/%(iptables)s %(append_delete)s INPUT -t mangle -i %(iface)s"
      " -j MARK --set-mark %(netid)d",
      "ip -%(version)d rule %(add_del)s fwmark %(netid)s lookup %(table)s",
  ]
  # Default route via the simulated router. Always used for IPv4, and used
  # for IPv6 only when autoconf routes can't be placed in the netid's table.
  ROUTE_COMMANDS = [
      "ip -%(version)d route %(add_del)s table %(table)s"
      " default dev %(iface)s via %(router)s",
  ]
  # Static IPv4 configuration: router neighbour entry and our own address.
  IPV4_COMMANDS = [
      "ip -4 nei %(add_del)s %(router)s dev %(iface)s"
      " lladdr %(macaddr)s nud permanent",
      "ip -4 addr %(add_del)s %(ipv4addr)s/24 dev %(iface)s",
  ]

  @classmethod
  def _RunSetupCommands(cls, netid, is_add):
    """Adds or removes the iptables, rule, route and address config for netid.

    Args:
      netid: the netid whose interface is being (de)configured.
      is_add: True to add the configuration, False to delete it.

    Raises:
      ConfigurationError: if any command exits with a non-zero status.
    """
    iface = cls._GetInterfaceName(netid)
    for version, iptables in zip([4, 6], ["iptables", "ip6tables"]):
      if version == 6:
        cmds = cls.COMMANDS
        if cls.AUTOCONF_TABLE_OFFSET < 0:
          # Set up routing manually.
          # Don't do cmds += cls.ROUTE_COMMANDS as this modifies cls.COMMANDS.
          cmds = cls.COMMANDS + cls.ROUTE_COMMANDS
      else:
        # Deleting addresses also causes routes to be deleted, so watch the
        # order or the test will output lots of ENOENT errors.
        if is_add:
          cmds = cls.COMMANDS + cls.IPV4_COMMANDS + cls.ROUTE_COMMANDS
        else:
          cmds = cls.COMMANDS + cls.ROUTE_COMMANDS + cls.IPV4_COMMANDS

      cmds = str("\n".join(cmds) % {
          "add_del": "add" if is_add else "del",
          "append_delete": "-A" if is_add else "-D",
          "iface": iface,
          "iptables": iptables,
          "ipv4addr": cls._MyIPv4Address(netid),
          "macaddr": cls._RouterMacAddress(netid),
          "netid": netid,
          "router": cls._RouterAddress(netid, version),
          "table": cls._TableForNetid(netid),
          "version": version,
      }).split("\n")
      for cmd in cmds:
        argv = cmd.split(" ")
        if DEBUG:
          print(" ".join(argv))
        ret = os.spawnvp(os.P_WAIT, argv[0], argv)
        if ret:
          raise ConfigurationError("Setup command failed: %s" % " ".join(argv))

  @classmethod
  def _SetAutoconfTableSysctl(cls, offset):
    """Tries to set the IPv6 autoconf routing-table offset sysctl.

    Sets cls.AUTOCONF_TABLE_OFFSET to offset on success, or to -1 if the
    sysctl can't be written (for example, it doesn't exist on this kernel).
    """
    try:
      # Use "with" so that an error reported when the write is flushed on
      # close is caught here, not silently dropped at garbage collection.
      with open(AUTOCONF_TABLE_SYSCTL, "w") as sysctl:
        sysctl.write(str(offset))
      cls.AUTOCONF_TABLE_OFFSET = offset
    except IOError:
      cls.AUTOCONF_TABLE_OFFSET = -1

  @classmethod
  def _TableForNetid(cls, netid):
    """Returns the routing table number used for netid.

    With the autoconf sysctl available this is ifindex + offset (presumably
    where the kernel puts that interface's autoconf routes; confirm against
    the sysctl's implementation). Otherwise the netid itself is used.
    """
    if cls.AUTOCONF_TABLE_OFFSET >= 0:
      return cls.ifindices[netid] + cls.AUTOCONF_TABLE_OFFSET
    else:
      return netid

  @classmethod
  def setUpClass(cls):
    """Creates and configures one tap per netid, plus a listening socket."""
    cls.tuns = {}
    cls.ifindices = {}
    cls._SetAutoconfTableSysctl(1000)
    for netid in cls.NETIDS:
      cls.tuns[netid] = cls._CreateTunInterface(netid)

      iface = cls._GetInterfaceName(netid)
      cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)

      cls._SendRA(netid)
      cls._RunSetupCommands(netid, True)

    # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
    # will accept both IPv4 and IPv6 connections. We do this here instead of in
    # each test so we can use the same socket every time. That way, if a kernel
    # bug causes incoming packets to mark the listening socket instead of the
    # accepted socket, the test will fail as soon as the next address/interface
    # combination is tried.
    cls.listenport = 1234
    cls.listensocket = net_test.IPv6TCPSocket()
    cls.listensocket.bind(("::", cls.listenport))
    cls.listensocket.listen(100)

    # Give time for unknown things to settle down.
    time.sleep(0.5)
    # Uncomment to look around at interface and rule configuration while
    # running in the background. (Once the test finishes running, all the
    # interfaces and rules are gone.)
    #time.sleep(30)

  @classmethod
  def tearDownClass(cls):
    """Removes the configuration and closes every tun and the listener.

    Each tun is closed even if its cleanup commands fail, and the listening
    socket is closed even if a cleanup command raises.
    """
    try:
      for netid in cls.tuns:
        try:
          cls._RunSetupCommands(netid, False)
        finally:
          cls.tuns[netid].close()
    finally:
      cls.listensocket.close()

  def CheckExpectedPacket(self, expected, actual, msg):
    """Asserts that the frame read from a tun matches an expected packet.

    Args:
      expected: scapy packet starting at the IP layer.
      actual: raw Ethernet frame bytes read from the tun.
      msg: context included in the failure message.
    """
    # Remove the Ethernet header from the incoming packet.
    actual = scapy.Ether(actual).payload

    # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
    actualip = actual.getlayer("IP")
    expectedip = expected.getlayer("IP")
    if actualip and expectedip:
      actualip.id = expectedip.id
      actualip.flags &= 5  # Clears DF (scapy flag value 2).
      actualip.chksum = None  # Change the header, recalculate the checksum.

    # Blank out TCP fields that we can't predict.
    actualtcp = actual.getlayer("TCP")
    expectedtcp = expected.getlayer("TCP")
    if actualtcp and expectedtcp:
      actualtcp.dataofs = expectedtcp.dataofs
      actualtcp.options = expectedtcp.options
      actualtcp.window = expectedtcp.window
      if expectedtcp.seq is None:
        actualtcp.seq = None
      if expectedtcp.ack is None:
        actualtcp.ack = None
      actualtcp.chksum = None

    # Serialize the packet so:
    # - Expected packet fields that are only set when a packet is serialized
    #   (e.g., the checksum) are filled in.
    # - The packet is readable. Scapy has detailed dissection capabilities,
    #   but they only seem to be usable to print the packet, not return its
    #   dissection as a string.
    #   TODO: Check if this is true.
    self.assertMultiLineEqual(str(expected).encode("hex"),
                              str(actual).encode("hex"), msg)

  def assertNoPacketsOn(self, netids, msg):
    """Asserts that reading each of netids' tuns fails with EAGAIN.

    Any packet that is found is consumed by the read.

    Raises:
      UnexpectedPacketError: if a packet was available on one of netids.
    """
    for netid in netids:
      try:
        self.assertRaisesErrno(errno.EAGAIN, self.tuns[netid].read, 4096)
      except AssertionError:
        raise UnexpectedPacketError("%s: Unexpected packet on %s" % (
            msg, self._GetInterfaceName(netid)))

  def assertNoOtherPackets(self, msg):
    """Asserts that no tun has a packet waiting."""
    self.assertNoPacketsOn(list(self.tuns), msg)

  def assertNoPacketsExceptOn(self, netid, msg):
    """Asserts that no tun other than netid's has a packet waiting."""
    self.assertNoPacketsOn([n for n in self.tuns if n != netid], msg)

  def ExpectPacketOn(self, netid, msg, expected=None):
    """Asserts that a packet was sent on netid's interface and no other.

    Args:
      netid: the netid whose tun should have a packet waiting.
      msg: context included in failure messages.
      expected: optional scapy packet (IP layer and up) the packet read must
        match. If None, any packet is accepted.
    """
    # Check no packets were sent on any other netid.
    self.assertNoPacketsExceptOn(netid, msg)

    # Check that a packet was sent on netid.
    try:
      actual = self.tuns[netid].read(4096)
    except IOError as e:
      raise AssertionError(msg + ": " + str(e))
    self.assertTrue(actual, msg)

    # If we know what sort of packet we expect, check that here.
    if expected is not None:
      self.CheckExpectedPacket(expected, actual, msg)

  def ReceivePacketOn(self, netid, ip_packet):
    """Injects ip_packet on netid's interface as if sent by its router."""
    routermac = self._RouterMacAddress(netid)
    mymac = self._MyMacAddress(netid)
    packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
    posix.write(self.tuns[netid].fileno(), str(packet))

  def ClearTunQueues(self):
    """Discards every queued packet on every tun, then checks they're empty."""
    for f in self.tuns.values():
      # One read may not empty the queue, so keep reading until the
      # non-blocking read fails (normally EAGAIN, i.e. nothing left).
      while True:
        try:
          if not f.read(4096):
            break
        except IOError:
          break
    self.assertNoOtherPackets("Unexpected packets after clearing queues")

  def setUp(self):
    """Starts every test with empty tun queues."""
    self.ClearTunQueues()

  @staticmethod
  def _GetRemoteAddress(version):
    """Returns net_test's remote test address for IP version 4 or 6."""
    return {4: net_test.IPV4_ADDR, 6: net_test.IPV6_ADDR}[version]

  def MarkSocket(self, s, netid):
    """Sets socket s's SO_MARK to netid."""
    s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)

  def GetProtocolFamily(self, version):
    """Returns AF_INET or AF_INET6 for IP version 4 or 6."""
    return {4: AF_INET, 6: AF_INET6}[version]

  def testOutgoingPackets(self):
    """Checks that socket marking selects the right outgoing interface."""

    def CheckPingPacket(version, netid, packet):
      s = net_test.PingSocket(self.GetProtocolFamily(version))
      try:
        dstaddr = self._GetRemoteAddress(version)
        self.MarkSocket(s, netid)
        s.sendto(packet, (dstaddr, 19321))
        self.ExpectPacketOn(netid, "IPv%d ping: mark %d" % (version, netid))
      finally:
        s.close()

    for netid in self.tuns:
      CheckPingPacket(4, netid, net_test.IPV4_PING)
      CheckPingPacket(6, netid, net_test.IPV6_PING)

    def CheckTCPSYNPacket(version, netid, dstaddr):
      s = net_test.TCPSocket(self.GetProtocolFamily(version))
      try:
        self.MarkSocket(s, netid)
        # Non-blocking TCP connects always return EINPROGRESS.
        self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
        self.ExpectPacketOn(netid, "IPv%d TCP connect: mark %d" % (version,
                                                                   netid))
      finally:
        s.close()

    for netid in self.tuns:
      CheckTCPSYNPacket(4, netid, net_test.IPV4_ADDR)
      CheckTCPSYNPacket(6, netid, net_test.IPV6_ADDR)
      CheckTCPSYNPacket(6, netid, "::ffff:" + net_test.IPV4_ADDR)

    def CheckUDPPacket(version, netid, dstaddr):
      s = net_test.UDPSocket(self.GetProtocolFamily(version))
      try:
        self.MarkSocket(s, netid)
        s.sendto("hello", (dstaddr, 53))
        self.ExpectPacketOn(netid, "IPv%d UDP sendto: mark %d" % (version,
                                                                  netid))
        s.connect((dstaddr, 53))
        s.send("hello")
        self.ExpectPacketOn(netid, "IPv%d UDP connect/send: mark %d" % (
            version, netid))
      finally:
        s.close()

    for netid in self.tuns:
      CheckUDPPacket(4, netid, net_test.IPV4_ADDR)
      CheckUDPPacket(6, netid, net_test.IPV6_ADDR)
      CheckUDPPacket(6, netid, "::ffff:" + net_test.IPV4_ADDR)

  def CheckReflection(self, version, packet_generator, reply_generator):
    """Checks that replies go out on the same interface as the original.

    Every combination of destination address (each interface's address) and
    incoming interface is tried.

    Args:
      version: 4 or 6.
      packet_generator: callable(version, srcaddr, dstaddr) returning a
        (description, packet) tuple for the packet to inject.
      reply_generator: callable(version, srcaddr, dstaddr, packet) returning
        a (description, packet) tuple for the expected reply, or None to
        accept any reply.
    """
    # Check packets addressed to the IP addresses of all our interfaces...
    for dest_ip_netid in self.tuns:
      dest_ip_iface = self._GetInterfaceName(dest_ip_netid)

      if version == 4:
        myaddr = self._MyIPv4Address(dest_ip_netid)
      else:
        myaddr = net_test.GetLinkAddress(dest_ip_iface, False)
      remote_addr = self._GetRemoteAddress(version)

      # ... coming in on all our interfaces.
      for iif_netid in self.tuns:
        iif = self._GetInterfaceName(iif_netid)
        desc, packet = packet_generator(version, remote_addr, myaddr)
        if reply_generator:
          # We know what we want a reply to.
          reply_desc, reply = reply_generator(version, myaddr, remote_addr,
                                              packet)
        else:
          # Expect any reply.
          reply_desc, reply = "any packet", None
        msg = "Receiving %s on %s to %s IP: Expecting %s on %s" % (
            desc, iif, dest_ip_iface, reply_desc, iif)

        # Expect a reply on the interface the original packet came in on.
        self.ClearTunQueues()
        self.ReceivePacketOn(iif_netid, packet)
        self.ExpectPacketOn(iif_netid, msg, reply)

  def SYNToClosedPort(self, *args):
    """Returns a SYN to port 999 (nothing listens there); args as Packets.SYN."""
    return Packets.SYN(999, *args)

  def SYNToOpenPort(self, *args):
    """Returns a SYN to the class's listening port; args as Packets.SYN."""
    return Packets.SYN(self.listenport, *args)

  def testIPv4ICMPErrorsReflectMark(self):
    self.CheckReflection(4, Packets.UdpPacket, Packets.ICMPPortUnreachable)

  def testIPv6ICMPErrorsReflectMark(self):
    self.CheckReflection(6, Packets.UdpPacket, Packets.ICMPPortUnreachable)

  def testIPv4PingRepliesReflectMarkAndTos(self):
    self.CheckReflection(4, Packets.ICMPEcho, Packets.ICMPReply)

  def testIPv6PingRepliesReflectMarkAndTos(self):
    self.CheckReflection(6, Packets.ICMPEcho, Packets.ICMPReply)

  def testIPv4RSTsReflectMark(self):
    self.CheckReflection(4, self.SYNToClosedPort, Packets.RST)

  def testIPv6RSTsReflectMark(self):
    self.CheckReflection(6, self.SYNToClosedPort, Packets.RST)

  @unittest.skip("skipping: doesn't work yet")
  def testIPv4SYNACKsReflectMark(self):
    # SYNToOpenPort is a MarkTest method (it needs self.listenport), not a
    # Packets method.
    self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK)

  @unittest.skip("skipping: doesn't work yet")
  def testIPv6SYNACKsReflectMark(self):
    self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK)
486
487
if __name__ == "__main__":
  # Discover and run every test in this module with the stock runner.
  unittest.main(module="__main__")