OSDN Git Service

Test outgoing oif routing.
[android-x86/system-extras.git] / tests / net_test / mark_test.py
1 #!/usr/bin/python
2
3 import errno
4 import fcntl
5 import os
6 import posix
7 import random
8 import re
9 from socket import *  # pylint: disable=wildcard-import
10 import struct
11 import time
12 import unittest
13
14 from scapy import all as scapy
15
16 import iproute
17 import net_test
18
19 IFF_TUN = 1
20 IFF_TAP = 2
21 IFF_NO_PI = 0x1000
22 TUNSETIFF = 0x400454ca
23
24 PING_IDENT = 0xff19
25 PING_PAYLOAD = "foobarbaz"
26 PING_SEQ = 3
27 PING_TOS = 0x83
28
29 SO_BINDTODEVICE = 25
30
31 UDP_PAYLOAD = "hello"
32
33
34 # Check to see if the kernel supports UID routing.
35 def HaveUidRouting():
36   result = False
37
38   # Create a rule with the UID selector. If the kernel doesn't understand the
39   # UID selector, it will create a rule with no selectors.
40   iproute.IPRoute().UidRule(6, True, 100, 100)
41
42   # Dump dump all the rules. If we find a rule using the UID selector, then the
43   # kernel supports UID routing.
44   rules = iproute.IPRoute().DumpRules(6)
45   for unused_rtmsg, attributes in rules:
46     for (nla, unused_nla_data) in attributes:
47       if nla.nla_type == iproute.EXPERIMENTAL_FRA_UID:
48         result = True
49         break
50
51   # Delete the rule.
52   iproute.IPRoute().UidRule(6, False, 100, 100)
53   return result
54
55
56 AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
57 IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
58 IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"
59 SYNCOOKIES_SYSCTL = "/proc/sys/net/ipv4/tcp_syncookies"
60 TCP_MARK_ACCEPT_SYSCTL = "/proc/sys/net/ipv4/tcp_fwmark_accept"
61
62 HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)
63 HAVE_MARK_REFLECT = os.path.isfile(IPV4_MARK_REFLECT_SYSCTL)
64 HAVE_TCP_MARK_ACCEPT = os.path.isfile(TCP_MARK_ACCEPT_SYSCTL)
65
66 HAVE_EXPERIMENTAL_UID_ROUTING = HaveUidRouting()
67
68
69 class ConfigurationError(AssertionError):
70   pass
71
72
73 class UnexpectedPacketError(AssertionError):
74   pass
75
76
77 class Packets(object):
78
79   TCP_FIN = 1
80   TCP_SYN = 2
81   TCP_RST = 4
82   TCP_ACK = 16
83
84   TCP_SEQ = 1692871236
85   TCP_WINDOW = 14400
86
87   @staticmethod
88   def RandomPort():
89     return random.randint(1025, 65535)
90
91   @staticmethod
92   def _GetIpLayer(version):
93     return {4: scapy.IP, 6: scapy.IPv6}[version]
94
95   @staticmethod
96   def _SetPacketTos(packet, tos):
97     if isinstance(packet, scapy.IPv6):
98       packet.tc = tos
99     elif isinstance(packet, scapy.IP):
100       packet.tos = tos
101     else:
102       raise ValueError("Can't find ToS Field")
103
104   @classmethod
105   def UDP(cls, version, srcaddr, dstaddr, sport=0):
106     ip = cls._GetIpLayer(version)
107     # Can't just use "if sport" because None has meaning (it means unspecified).
108     if sport == 0:
109       sport = cls.RandomPort()
110     return ("UDPv%d packet" % version,
111             ip(src=srcaddr, dst=dstaddr) /
112             scapy.UDP(sport=sport, dport=53) / UDP_PAYLOAD)
113
114   @classmethod
115   def SYN(cls, dport, version, srcaddr, dstaddr, sport=0, seq=TCP_SEQ):
116     ip = cls._GetIpLayer(version)
117     if sport == 0:
118       sport = cls.RandomPort()
119     return ("TCP SYN",
120             ip(src=srcaddr, dst=dstaddr) /
121             scapy.TCP(sport=sport, dport=dport,
122                       seq=seq, ack=0,
123                       flags=cls.TCP_SYN, window=cls.TCP_WINDOW))
124
125   @classmethod
126   def RST(cls, version, srcaddr, dstaddr, packet):
127     ip = cls._GetIpLayer(version)
128     original = packet.getlayer("TCP")
129     return ("TCP RST",
130             ip(src=srcaddr, dst=dstaddr) /
131             scapy.TCP(sport=original.dport, dport=original.sport,
132                       ack=original.seq + 1, seq=None,
133                       flags=cls.TCP_RST | cls.TCP_ACK, window=cls.TCP_WINDOW))
134
135   @classmethod
136   def SYNACK(cls, version, srcaddr, dstaddr, packet):
137     ip = cls._GetIpLayer(version)
138     original = packet.getlayer("TCP")
139     return ("TCP SYN+ACK",
140             ip(src=srcaddr, dst=dstaddr) /
141             scapy.TCP(sport=original.dport, dport=original.sport,
142                       ack=original.seq + 1, seq=None,
143                       flags=cls.TCP_SYN | cls.TCP_ACK, window=None))
144
145   @classmethod
146   def ACK(cls, version, srcaddr, dstaddr, packet):
147     ip = cls._GetIpLayer(version)
148     original = packet.getlayer("TCP")
149     was_syn_or_fin = (original.flags & (cls.TCP_SYN | cls.TCP_FIN)) != 0
150     return ("TCP ACK",
151             ip(src=srcaddr, dst=dstaddr) /
152             scapy.TCP(sport=original.dport, dport=original.sport,
153                       ack=original.seq + was_syn_or_fin, seq=original.ack,
154                       flags=cls.TCP_ACK, window=cls.TCP_WINDOW))
155
156   @classmethod
157   def FIN(cls, version, srcaddr, dstaddr, packet):
158     ip = cls._GetIpLayer(version)
159     original = packet.getlayer("TCP")
160     was_fin = (original.flags & cls.TCP_FIN) != 0
161     return ("TCP FIN",
162             ip(src=srcaddr, dst=dstaddr) /
163             scapy.TCP(sport=original.dport, dport=original.sport,
164                       ack=original.seq + was_fin, seq=original.ack,
165                       flags=cls.TCP_ACK | cls.TCP_FIN, window=cls.TCP_WINDOW))
166
167   @classmethod
168   def ICMPPortUnreachable(cls, version, srcaddr, dstaddr, packet):
169     if version == 4:
170       # Linux hardcodes the ToS on ICMP errors to 0xc0 or greater because of
171       # RFC 1812 4.3.2.5 (!).
172       return ("ICMPv4 port unreachable",
173               scapy.IP(src=srcaddr, dst=dstaddr, proto=1, tos=0xc0) /
174               scapy.ICMPerror(type=3, code=3) / packet)
175     else:
176       return ("ICMPv6 port unreachable",
177               scapy.IPv6(src=srcaddr, dst=dstaddr) /
178               scapy.ICMPv6DestUnreach(code=4) / packet)
179
180   @classmethod
181   def ICMPPacketTooBig(cls, version, srcaddr, dstaddr, packet):
182     if version == 4:
183       # Linux hardcodes the ToS on ICMP errors to 0xc0 or greater because of
184       # RFC 1812 4.3.2.5 (!).
185       raise NotImplementedError
186     else:
187       udp = packet.getlayer("UDP")
188       udp.payload = str(udp.payload)[:1280-40-8]
189       return ("ICMPv6 Packet Too Big",
190               scapy.IPv6(src=srcaddr, dst=dstaddr) /
191               scapy.ICMPv6PacketTooBig() / str(packet)[:1232])
192
193   @classmethod
194   def ICMPEcho(cls, version, srcaddr, dstaddr):
195     ip = cls._GetIpLayer(version)
196     icmp = {4: scapy.ICMP, 6: scapy.ICMPv6EchoRequest}[version]
197     packet = (ip(src=srcaddr, dst=dstaddr) /
198               icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
199     cls._SetPacketTos(packet, PING_TOS)
200     return ("ICMPv%d echo" % version, packet)
201
202   @classmethod
203   def ICMPReply(cls, version, srcaddr, dstaddr, packet):
204     ip = cls._GetIpLayer(version)
205     # Scapy doesn't provide an ICMP echo reply constructor.
206     icmpv4_reply = lambda **kwargs: scapy.ICMP(type=0, **kwargs)
207     icmp = {4: icmpv4_reply, 6: scapy.ICMPv6EchoReply}[version]
208     packet = (ip(src=srcaddr, dst=dstaddr) /
209               icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
210     cls._SetPacketTos(packet, PING_TOS)
211     return ("ICMPv%d echo reply" % version, packet)
212
213
214 class RunAsUid(object):
215
216   """Context guard to run a code block as a given UID."""
217
218   def __init__(self, uid):
219     self.uid = uid
220
221   def __enter__(self):
222     if self.uid:
223       self.saved_uid = os.geteuid()
224       if self.uid:
225         os.seteuid(self.uid)
226
227   def __exit__(self, unused_type, unused_value, unused_traceback):
228     if self.uid:
229       os.seteuid(self.saved_uid)
230
231
232 class MultiNetworkTest(net_test.NetworkTest):
233
234   # Must be between 1 and 256, since we put them in MAC addresses and IIDs.
235   NETIDS = [100, 150, 200, 250]
236
237   # Stores sysctl values to write back when the test completes.
238   saved_sysctls = {}
239
240   # Wether to output setup commands.
241   DEBUG = False
242
243   @staticmethod
244   def UidForNetid(netid):
245     return 2000 + netid
246
247   @classmethod
248   def _TableForNetid(cls, netid):
249     if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
250       return cls.ifindices[netid] + (-cls.AUTOCONF_TABLE_OFFSET)
251     else:
252       return netid
253
254   @staticmethod
255   def GetInterfaceName(netid):
256     return "nettest%d" % netid
257
258   @staticmethod
259   def RouterMacAddress(netid):
260     return "02:00:00:00:%02x:00" % netid
261
262   @staticmethod
263   def MyMacAddress(netid):
264     return "02:00:00:00:%02x:01" % netid
265
266   @staticmethod
267   def _RouterAddress(netid, version):
268     if version == 6:
269       return "fe80::%02x00" % netid
270     elif version == 4:
271       return "10.0.%d.1" % netid
272     else:
273       raise ValueError("Don't support IPv%s" % version)
274
275   @classmethod
276   def _MyIPv4Address(cls, netid):
277     return "10.0.%d.2" % netid
278
279   @classmethod
280   def _MyIPv6Address(cls, netid):
281     return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)
282
283   @classmethod
284   def MyAddress(cls, version, netid):
285     return {4: cls._MyIPv4Address(netid),
286             6: cls._MyIPv6Address(netid)}[version]
287
288   @classmethod
289   def CreateTunInterface(cls, netid):
290     iface = cls.GetInterfaceName(netid)
291     f = open("/dev/net/tun", "r+b")
292     ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
293     ifr += "\x00" * (40 - len(ifr))
294     fcntl.ioctl(f, TUNSETIFF, ifr)
295     # Give ourselves a predictable MAC address.
296     net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
297     # Disable DAD so we don't have to wait for it.
298     cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
299     net_test.SetInterfaceUp(iface)
300     net_test.SetNonBlocking(f)
301     return f
302
303   @classmethod
304   def SendRA(cls, netid):
305     validity = 300                 # seconds
306     validity_ms = validity * 1000  # milliseconds
307     macaddr = cls.RouterMacAddress(netid)
308     lladdr = cls._RouterAddress(netid, 6)
309
310     # We don't want any routes in the main table. If the kernel doesn't support
311     # putting RA routes into per-interface tables, configure routing manually.
312     routerlifetime = validity if HAVE_AUTOCONF_TABLE else 0
313
314     ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
315           scapy.IPv6(src=lladdr, hlim=255) /
316           scapy.ICMPv6ND_RA(retranstimer=validity_ms,
317                             routerlifetime=routerlifetime) /
318           scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
319           scapy.ICMPv6NDOptPrefixInfo(prefix="2001:db8:%d::" % netid,
320                                       prefixlen=64,
321                                       L=1, A=1,
322                                       validlifetime=validity,
323                                       preferredlifetime=validity))
324     posix.write(cls.tuns[netid].fileno(), str(ra))
325
326   @classmethod
327   def _RunSetupCommands(cls, netid, is_add):
328     iptables_commands = [
329         "/sbin/%(iptables)s %(append_delete)s INPUT -t mangle -i %(iface)s"
330         " -j MARK --set-mark %(mark)d",
331     ]
332     route_commands = [
333         "ip -%(version)d route %(add_del)s table %(table)s"
334         " default dev %(iface)s via %(router)s",
335     ]
336     ipv4_commands = [
337         "ip -4 nei %(add_del)s %(router)s dev %(iface)s"
338         " lladdr %(macaddr)s nud permanent",
339         "ip -4 addr %(add_del)s %(ipv4addr)s/24 dev %(iface)s",
340     ]
341
342     for version, iptables in zip([4, 6], ["iptables", "ip6tables"]):
343       table = cls._TableForNetid(netid)
344       uid = cls.UidForNetid(netid)
345       iface = cls.GetInterfaceName(netid)
346       if HAVE_EXPERIMENTAL_UID_ROUTING:
347         cls.iproute.UidRule(version, is_add, uid, table, priority=100)
348       cls.iproute.OifRule(version, is_add, iface, table, priority=200)
349       cls.iproute.FwmarkRule(version, is_add, netid, table, priority=300)
350
351       if cls.DEBUG:
352         os.spawnvp(os.P_WAIT, "/sbin/ip", ["ip", "-6", "rule", "list"])
353
354       if version == 6:
355         if cls.AUTOCONF_TABLE_OFFSET is None:
356           # Set up routing manually.
357           cmds = iptables_commands + route_commands
358         else:
359           cmds = iptables_commands
360
361       if version == 4:
362         # Deleting addresses also causes routes to be deleted, so watch the
363         # order or the test will output lots of ENOENT errors.
364         if is_add:
365           cmds = iptables_commands + ipv4_commands + route_commands
366         else:
367           cmds = iptables_commands + route_commands + ipv4_commands
368
369       cmds = str("\n".join(cmds) % {
370           "add_del": "add" if is_add else "del",
371           "append_delete": "-A" if is_add else "-D",
372           "iface": iface,
373           "iptables": iptables,
374           "ipv4addr": cls._MyIPv4Address(netid),
375           "macaddr": cls.RouterMacAddress(netid),
376           "mark": netid,
377           "router": cls._RouterAddress(netid, version),
378           "table": table,
379           "version": version,
380       }).split("\n")
381       for cmd in cmds:
382         cmd = cmd.split(" ")
383         if cls.DEBUG: print " ".join(cmd)
384         ret = os.spawnvp(os.P_WAIT, cmd[0], cmd)
385         if ret:
386           raise ConfigurationError("Setup command failed: %s" % " ".join(cmd))
387
388   @classmethod
389   def GetSysctl(cls, sysctl):
390     return open(sysctl, "r").read()
391
392   @classmethod
393   def SetSysctl(cls, sysctl, value):
394     # Only save each sysctl value the first time we set it. This is so we can
395     # set it to arbitrary values multiple times and still write it back
396     # correctly at the end.
397     if sysctl not in cls.saved_sysctls:
398       cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
399     open(sysctl, "w").write(str(value) + "\n")
400
401   @classmethod
402   def _RestoreSysctls(cls):
403     for sysctl, value in cls.saved_sysctls.iteritems():
404       try:
405         open(sysctl, "w").write(value)
406       except IOError:
407         pass
408
409   @classmethod
410   def _ICMPRatelimitFilename(cls, version):
411     return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
412                                6: "ipv6/icmp/ratelimit"}[version]
413
414   @classmethod
415   def _SetICMPRatelimit(cls, version, limit):
416     cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)
417
418   @classmethod
419   def setUpClass(cls):
420     # This is per-class setup instead of per-testcase setup because shelling out
421     # to ip and iptables is slow, and because routing configuration doesn't
422     # change during the test.
423     cls.iproute = iproute.IPRoute()
424     cls.tuns = {}
425     cls.ifindices = {}
426     if HAVE_AUTOCONF_TABLE:
427       cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
428       cls.AUTOCONF_TABLE_OFFSET = -1000
429     else:
430       cls.AUTOCONF_TABLE_OFFSET = None
431
432     # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
433     for version in [4, 6]:
434       cls._SetICMPRatelimit(version, 0)
435
436     for netid in cls.NETIDS:
437       cls.tuns[netid] = cls.CreateTunInterface(netid)
438       iface = cls.GetInterfaceName(netid)
439       cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)
440
441       cls.SendRA(netid)
442       cls._RunSetupCommands(netid, True)
443
444     # Uncomment to look around at interface and rule configuration while
445     # running in the background. (Once the test finishes running, all the
446     # interfaces and rules are gone.)
447     # time.sleep(30)
448
449   @classmethod
450   def tearDownClass(cls):
451     for netid in cls.tuns:
452       cls._RunSetupCommands(netid, False)
453       cls.tuns[netid].close()
454     cls._RestoreSysctls()
455
456   def SetSocketMark(self, s, netid):
457     s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
458
459   def GetSocketMark(self, s):
460     return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
461
462   def ClearSocketMark(self, s):
463     self.SetSocketMark(s, 0)
464
465   def BindToDevice(self, s, iface):
466     if not iface:
467       iface = ""
468     s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface)
469
470   def ReceivePacketOn(self, netid, ip_packet):
471     routermac = self.RouterMacAddress(netid)
472     mymac = self.MyMacAddress(netid)
473     packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
474     posix.write(self.tuns[netid].fileno(), str(packet))
475
476   def ReadAllPacketsOn(self, netid):
477     packets = []
478     while True:
479       try:
480         packet = posix.read(self.tuns[netid].fileno(), 4096)
481         ether = scapy.Ether(packet)
482         # Skip multicast frames, i.e., frames where the first byte of the
483         # destination MAC address has 1 in the least-significant bit.
484         if not int(ether.dst.split(":")[0], 16) & 0x1:
485           packets.append(ether.payload)
486       except OSError, e:
487         # EAGAIN means there are no more packets waiting.
488         if re.match(e.message, os.strerror(errno.EAGAIN)):
489           break
490         # Anything else is unexpected.
491         else:
492           raise e
493     return packets
494
495   def ClearTunQueues(self):
496     # Keep reading packets on all netids until we get no packets on any of them.
497     waiting = None
498     while waiting != 0:
499       waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)
500
501
502 class MarkTest(MultiNetworkTest):
503
504   # How many times to run packet reflection tests.
505   ITERATIONS = 5
506
507   # For convenience.
508   IPV4_ADDR = net_test.IPV4_ADDR
509   IPV6_ADDR = net_test.IPV6_ADDR
510   IPV4_PING = net_test.IPV4_PING
511   IPV6_PING = net_test.IPV6_PING
512
513   @classmethod
514   def setUpClass(cls):
515     super(MarkTest, cls).setUpClass()
516
517     # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
518     # will accept both IPv4 and IPv6 connections. We do this here instead of in
519     # each test so we can use the same socket every time. That way, if a kernel
520     # bug causes incoming packets to mark the listening socket instead of the
521     # accepted socket, the test will fail as soon as the next address/interface
522     # combination is tried.
523     cls.listenport = 1234
524     cls.listensocket = net_test.IPv6TCPSocket()
525     cls.listensocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
526     cls.listensocket.bind(("::", cls.listenport))
527     cls.listensocket.listen(100)
528
529   @classmethod
530   def _SetMarkReflectSysctls(cls, value):
531     cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
532     try:
533       cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
534     except IOError:
535       # This does not exist if we use the version of the patch that uses a
536       # common sysctl for IPv4 and IPv6.
537       pass
538
539   @classmethod
540   def _SetTCPMarkAcceptSysctl(cls, value):
541     cls.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)
542
543   def assertPacketMatches(self, expected, actual):
544     # The expected packet is just a rough sketch of the packet we expect to
545     # receive. For example, it doesn't contain fields we can't predict, such as
546     # initial TCP sequence numbers, or that depend on the host implementation
547     # and settings, such as TCP options. To check whether the packet matches
548     # what we expect, instead of just checking all the known fields one by one,
549     # we blank out fields in the actual packet and then compare the whole
550     # packets to each other as strings. Because we modify the actual packet,
551     # make a copy here.
552     actual = actual.copy()
553
554     # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
555     actualip = actual.getlayer("IP")
556     expectedip = expected.getlayer("IP")
557     if actualip and expectedip:
558       actualip.id = expectedip.id
559       actualip.flags &= 5
560       actualip.chksum = None  # Change the header, recalculate the checksum.
561
562     # Blank out UDP fields that we can't predict (e.g., the source port for
563     # kernel-originated packets).
564     actualudp = actual.getlayer("UDP")
565     expectedudp = expected.getlayer("UDP")
566     if actualudp and expectedudp:
567       if expectedudp.sport is None:
568         actualudp.sport = None
569         actualudp.chksum = None
570
571     # Since the TCP code below messes with options, recalculate the length.
572     if actualip:
573       actualip.len = None
574     actualipv6 = actual.getlayer("IPv6")
575     if actualipv6:
576       actualipv6.plen = None
577
578     # Blank out TCP fields that we can't predict.
579     actualtcp = actual.getlayer("TCP")
580     expectedtcp = expected.getlayer("TCP")
581     if actualtcp and expectedtcp:
582       actualtcp.dataofs = expectedtcp.dataofs
583       actualtcp.options = expectedtcp.options
584       actualtcp.window = expectedtcp.window
585       if expectedtcp.sport is None:
586         actualtcp.sport = None
587       if expectedtcp.seq is None:
588         actualtcp.seq = None
589       if expectedtcp.ack is None:
590         actualtcp.ack = None
591       actualtcp.chksum = None
592
593     # Serialize the packet so that expected packet fields that are only set when
594     # a packet is serialized e.g., the checksum) are filled in.
595     expected_real = expected.__class__(str(expected))
596     actual_real = actual.__class__(str(actual))
597     # repr() can be expensive. Call it only if the test is going to fail and we
598     # want to see the error.
599     if expected_real != actual_real:
600       self.assertEquals(repr(expected_real), repr(actual_real))
601
602   def PacketMatches(self, expected, actual):
603     try:
604       self.assertPacketMatches(expected, actual)
605       return True
606     except AssertionError:
607       return False
608
609   def ExpectNoPacketsOn(self, netid, msg, expected):
610     packets = self.ReadAllPacketsOn(netid)
611     if packets:
612       firstpacket = str(packets[0]).encode("hex")
613     else:
614       firstpacket = ""
615     self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)
616
617   def ExpectPacketOn(self, netid, msg, expected):
618     packets = self.ReadAllPacketsOn(netid)
619     self.assertTrue(packets, msg + ": received no packets")
620
621     # If we receive a packet that matches what we expected, return it.
622     for packet in packets:
623       if self.PacketMatches(expected, packet):
624         return packet
625
626     # None of the packets matched. Call assertPacketMatches to output a diff
627     # between the expected packet and the last packet we received. In theory,
628     # we'd output a diff to the packet that's the best match for what we
629     # expected, but this is good enough for now.
630     try:
631       self.assertPacketMatches(expected, packets[-1])
632     except Exception, e:
633       raise UnexpectedPacketError(
634           "%s: diff with last packet:\n%s" % (msg, e.message))
635
636   def setUp(self):
637     self.ClearTunQueues()
638
639   def tearDown(self):
640     # In case there was an exception in one of the tests and we didn't clean up.
641     self.BindToDevice(self.listensocket, None)
642
643   def _GetRemoteAddress(self, version):
644     return {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
645
646   def _GetProtocolFamily(self, version):
647     return {4: AF_INET, 6: AF_INET6}[version]
648
649   def BuildSocket(self, version, constructor, mark, uid, oif):
650     with RunAsUid(uid):
651       family = self._GetProtocolFamily(version)
652       s = constructor(family)
653     if mark:
654       self.SetSocketMark(s, mark)
655     if oif:
656       self.BindToDevice(s, oif)
657     return s
658
659   def CheckPingPacket(self, version, mark, uid, oif, dstaddr, packet,
660                       expected_netid):
661     s = self.BuildSocket(version, net_test.PingSocket, mark, uid, oif)
662
663     myaddr = self.MyAddress(version, expected_netid)
664     s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
665     s.bind((myaddr, PING_IDENT))
666     net_test.SetSocketTos(s, PING_TOS)
667
668     desc, expected = Packets.ICMPEcho(version, myaddr, dstaddr)
669
670     s.sendto(packet + PING_PAYLOAD, (dstaddr, 19321))
671     msg = "IPv%d ping: expected %s on %s" % (
672         version, desc, self.GetInterfaceName(expected_netid))
673     self.ExpectPacketOn(expected_netid, msg, expected)
674
675
676   def CheckTCPSYNPacket(self, version, mark, uid, oif, dstaddr, expected_netid):
677     s = self.BuildSocket(version, net_test.TCPSocket, mark, uid, oif)
678
679     if version == 6 and dstaddr.startswith("::ffff"):
680       version = 4
681     myaddr = self.MyAddress(version, expected_netid)
682     desc, expected = Packets.SYN(53, version, myaddr, dstaddr,
683                                  sport=None, seq=None)
684
685     # Non-blocking TCP connects always return EINPROGRESS.
686     self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
687     msg = "IPv%s TCP connect: expected %s on %s" % (
688         version, desc, self.GetInterfaceName(expected_netid))
689     self.ExpectPacketOn(expected_netid, msg, expected)
690     s.close()
691
692   def CheckUDPPacket(self, version, mark, uid, oif, dstaddr, expected_netid):
693     s = self.BuildSocket(version, net_test.UDPSocket, mark, uid, oif)
694
695     if version == 6 and dstaddr.startswith("::ffff"):
696       version = 4
697     myaddr = self.MyAddress(version, expected_netid)
698     desc, expected = Packets.UDP(version, myaddr, dstaddr, sport=None)
699     msg = "IPv%s UDP %%s: expected %s on %s" % (
700         version, desc, self.GetInterfaceName(expected_netid))
701
702     s.sendto(UDP_PAYLOAD, (dstaddr, 53))
703     self.ExpectPacketOn(expected_netid, msg % "sendto", expected)
704
705     s.connect((dstaddr, 53))
706     s.send(UDP_PAYLOAD)
707     self.ExpectPacketOn(expected_netid, msg % "connect/send", expected)
708     s.close()
709
710   def CheckOutgoingPackets(self, mode):
711     v4addr = self.IPV4_ADDR
712     v6addr = self.IPV6_ADDR
713
714     for _ in xrange(self.ITERATIONS):
715       for netid in self.tuns:
716
717         if mode == "mark":
718           mark, uid, oif = (netid, 0, 0)
719         elif mode == "uid":
720           mark, uid, oif = (0, self.UidForNetid(netid), 0)
721         elif mode == "oif":
722           mark, uid, oif = (0, 0, self.GetInterfaceName(netid))
723         else:
724           raise ValueError("Unkown routing mode %s" % mode)
725
726         self.CheckPingPacket(4, mark, uid, oif, v4addr, self.IPV4_PING, netid)
727         # Kernel bug.
728         if mode != "oif":
729           self.CheckPingPacket(6, mark, uid, oif, v6addr, self.IPV6_PING, netid)
730
731         self.CheckTCPSYNPacket(4, mark, uid, oif, v4addr, netid)
732         self.CheckTCPSYNPacket(6, mark, uid, oif, v6addr, netid)
733         self.CheckTCPSYNPacket(6, mark, uid, oif, "::ffff:" + v4addr, netid)
734
735         self.CheckUDPPacket(4, mark, uid, oif, v4addr, netid)
736         self.CheckUDPPacket(6, mark, uid, oif, v6addr, netid)
737         self.CheckUDPPacket(6, mark, uid, oif, "::ffff:" + v4addr, netid)
738
739   def testMarkRouting(self):
740     """Checks that socket marking selects the right outgoing interface."""
741     self.CheckOutgoingPackets("mark")
742
743   @unittest.skipUnless(HAVE_EXPERIMENTAL_UID_ROUTING, "no UID routing")
744   def testUidRouting(self):
745     """Checks that UID routing selects the right outgoing interface."""
746     self.CheckOutgoingPackets("uid")
747
748   def testOifRouting(self):
749     """Checks that oif routing selects the right outgoing interface."""
750     self.CheckOutgoingPackets("oif")
751
752   def CheckReflection(self, version, packet_generator, reply_generator,
753                       mark_behaviour, callback=None):
754     """Checks that replies go out on the same interface as the original.
755
756     Iterates through all the combinations of the interfaces in self.tuns and the
757     IP addresses assigned to them. For each combination:
758      - Calls packet_generator to generate a packet to that IP address.
759      - Writes the packet generated by packet_generator on the given tun
760        interface, causing the kernel to receive it.
761      - Checks that the kernel's reply matches the packet generated by
762        reply_generator.
763      - Calls the given callback function.
764
765     Args:
766       version: An integer, 4 or 6.
767       packet_generator: A function taking an IP version (an integer), a source
768         address and a destination address (strings), and returning a scapy
769         packet.
770       reply_generator: A function taking the same arguments as packet_generator,
771         plus a scapy packet, and returning a scapy packet.
772       mark_behaviour: A string describing the mark behaviour to test. Tests are
773         performed with the corresponding sysctl set to both 0 and 1.
774       callback: A function to call to perform extra checks if the packet
775         matches. Takes netid, version, local address, remote address, original
776         packet, kernel reply, and a message.
777     """
778     # What are we testing?
779     sysctl_function = {"accept": self._SetTCPMarkAcceptSysctl,
780                        "reflect": self._SetMarkReflectSysctls}[mark_behaviour]
781
782     # Check packets addressed to the IP addresses of all our interfaces...
783     for dest_ip_netid in self.tuns:
784       dest_ip_iface = self.GetInterfaceName(dest_ip_netid)
785
786       myaddr = self.MyAddress(version, dest_ip_netid)
787       remote_addr = self._GetRemoteAddress(version)
788
789       # ... coming in on all our interfaces...
790       for iif_netid in self.tuns:
791         iif = self.GetInterfaceName(iif_netid)
792
793         # ... with inbound mark sysctl enabled and disabled.
794         for sysctl_value in [0, 1]:
795
796           # If we're testing accepting TCP connections, also check that
797           # SO_BINDTODEVICE correctly sets the interface the SYN+ACK is sent on.
798           # Since SO_BINDTODEVICE and the sysctl do the same thing, it doesn't
799           # really make sense to test with sysctl_value=1 and SO_BINDTODEVICE
800           # turned on at the same time.
801           if mark_behaviour == "accept" and not sysctl_value:
802             bind_devices = [None, iif]
803           else:
804             bind_devices = [None]
805
806           for bound_dev in bind_devices:
807             # The socket is unbound in tearDown.
808             self.BindToDevice(self.listensocket, bound_dev)
809
810             # Generate the packet here instead of in the outer loop, so
811             # subsequent TCP connections use different source ports and
812             # retransmissions from old connections don't confuse subsequent
813             # tests.
814             desc, packet = packet_generator(version, remote_addr, myaddr)
815             reply_desc, reply = reply_generator(version, myaddr, remote_addr,
816                                                 packet)
817
818             msg = "Receiving %s on %s to %s IP, %s=%d, bound_dev=%s" % (
819                 desc, iif, dest_ip_iface, mark_behaviour, sysctl_value,
820                 bound_dev)
821             sysctl_function(sysctl_value)
822
823             # Cause the kernel to receive packet on iif_netid.
824             self.ReceivePacketOn(iif_netid, packet)
825
826             # Expect the kernel to send out reply on the same interface.
827             #
828             # HACK: IPv6 ping replies always do a routing lookup with the
829             # interface the ping came in on. So even if mark reflection is not
830             # working, IPv6 ping replies will be properly reflected. Don't
831             # fail when that happens.
832             if bound_dev or sysctl_value or reply_desc == "ICMPv6 echo reply":
833               msg += ": Expecting %s on %s" % (reply_desc, iif)
834               reply = self.ExpectPacketOn(iif_netid, msg, reply)
835               # If a callback was set, call it.
836               if callback:
837                 callback(sysctl_value, iif_netid, version, myaddr, remote_addr,
838                          packet, reply, msg)
839             else:
840               msg += ": Expecting no packets on %s" % iif
841               self.ExpectNoPacketsOn(iif_netid, msg, reply)
842
843   def SYNToClosedPort(self, *args):
844     return Packets.SYN(999, *args)
845
846   def SYNToOpenPort(self, *args):
847     return Packets.SYN(self.listenport, *args)
848
849   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
850   def testIPv4ICMPErrorsReflectMark(self):
851     self.CheckReflection(4, Packets.UDP, Packets.ICMPPortUnreachable, "reflect")
852
853   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
854   def testIPv6ICMPErrorsReflectMark(self):
855     self.CheckReflection(6, Packets.UDP, Packets.ICMPPortUnreachable, "reflect")
856
857   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
858   def testIPv4PingRepliesReflectMarkAndTos(self):
859     self.CheckReflection(4, Packets.ICMPEcho, Packets.ICMPReply, "reflect")
860
861   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
862   def testIPv6PingRepliesReflectMarkAndTos(self):
863     self.CheckReflection(6, Packets.ICMPEcho, Packets.ICMPReply, "reflect")
864
865   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
866   def testIPv4RSTsReflectMark(self):
867     self.CheckReflection(4, self.SYNToClosedPort, Packets.RST, "reflect")
868
869   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
870   def testIPv6RSTsReflectMark(self):
871     self.CheckReflection(6, self.SYNToClosedPort, Packets.RST, "reflect")
872
873   def CheckTCPConnection(self, sysctl_value, netid, version,
874                          myaddr, remote_addr, packet, reply, msg):
875     establishing_ack = Packets.ACK(version, remote_addr, myaddr, reply)[1]
876     self.ReceivePacketOn(netid, establishing_ack)
877     s, unused_peer = self.listensocket.accept()
878     try:
879       mark = self.GetSocketMark(s)
880     finally:
881       s.close()
882     if sysctl_value:
883       self.assertEquals(netid, mark,
884                         msg + ": Accepted socket: Expected mark %d, got %d" % (
885                             netid, mark))
886
887     # Check the FIN was sent on the right interface, and ack it. We don't expect
888     # this to fail because by the time the connection is established things are
889     # likely working, but a) extra tests are always good and b) extra packets
890     # like the FIN (and retransmitted FINs) could cause later tests that expect
891     # no packets to fail.
892     desc, fin = Packets.FIN(version, myaddr, remote_addr, establishing_ack)
893     self.ExpectPacketOn(netid, msg + ": expecting %s after close" % desc, fin)
894
895     desc, finack = Packets.FIN(version, remote_addr, myaddr, fin)
896     self.ReceivePacketOn(netid, finack)
897
898     desc, finackack = Packets.ACK(version, myaddr, remote_addr, finack)
899     self.ExpectPacketOn(netid, msg + ": expecting final ack", finackack)
900
901   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
902   def testIPv4TCPConnections(self):
903     self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK, "accept",
904                          self.CheckTCPConnection)
905
906   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
907   def testIPv6TCPConnections(self):
908     self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK, "accept",
909                          self.CheckTCPConnection)
910
911   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
912   def testTCPConnectionsWithSynCookies(self):
913     # Force SYN cookies on all connections.
914     self.SetSysctl(SYNCOOKIES_SYSCTL, 2)
915     try:
916       self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK, "accept",
917                            self.CheckTCPConnection)
918       self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK, "accept",
919                            self.CheckTCPConnection)
920     finally:
921       # Stop forcing SYN cookies on all connections.
922       self.SetSysctl(SYNCOOKIES_SYSCTL, 1)
923
924
925 class RATest(MultiNetworkTest):
926
927   def testDoesNotHaveObsoleteSysctl(self):
928     self.assertFalse(os.path.isfile(
929         "/proc/sys/net/ipv6/route/autoconf_table_offset"))
930
931   @unittest.skipUnless(HAVE_AUTOCONF_TABLE, "no support for per-table autoconf")
932   def testPurgeDefaultRouters(self):
933
934     def CheckIPv6Connectivity(expect_connectivity):
935       for netid in self.NETIDS:
936         s = net_test.UDPSocket(AF_INET6)
937         self.SetSocketMark(s, netid)
938         if expect_connectivity:
939           self.assertEquals(5, s.sendto("hello", (net_test.IPV6_ADDR, 1234)))
940         else:
941           self.assertRaisesErrno(errno.ENETUNREACH,
942                                  s.sendto, "hello", (net_test.IPV6_ADDR, 1234))
943
944     try:
945       CheckIPv6Connectivity(True)
946       self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 1)
947       CheckIPv6Connectivity(False)
948     finally:
949       self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
950       for netid in self.NETIDS:
951         self.SendRA(netid)
952       CheckIPv6Connectivity(True)
953
954   @unittest.skipUnless(False, "Known bug: routing tables are never deleted")
955   def testNoLeftoverRoutes(self):
956     def GetNumRoutes():
957       return len(open("/proc/net/ipv6_route").readlines())
958
959     num_routes = GetNumRoutes()
960     for i in xrange(10, 20):
961       try:
962         self.tuns[i] = self.CreateTunInterface(i)
963         self.SendRA(i)
964         self.tuns[i].close()
965       finally:
966         del self.tuns[i]
967     self.assertEquals(num_routes, GetNumRoutes())
968
969
970 class PMTUTest(MultiNetworkTest):
971
972   IPV6_PATHMTU = 61
973   IPV6_DONTFRAG = 62
974
975   def GetRandomDestination(self, version):
976     if version == 4:
977       return "172.16.%d.%d" % (random.randint(0, 31), random.randint(0, 255))
978     else:
979       return "2001:db8::%x:%x" % (random.randint(0, 65535),
980                                   random.randint(0, 65535))
981
982   def GetSocketMTU(self, s):
983     ip6_mtuinfo = s.getsockopt(net_test.SOL_IPV6, self.IPV6_PATHMTU, 32)
984     mtu = struct.unpack("=28sI", ip6_mtuinfo)
985     return mtu[1]
986
987   def testIPv6PMTU(self):
988     s = net_test.Socket(AF_INET6, SOCK_DGRAM, IPPROTO_UDP)
989     s.setsockopt(net_test.SOL_IPV6, self.IPV6_DONTFRAG, 1)
990     s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)
991     netid = self.NETIDS[2]  # Just pick an arbitrary one.
992
993     srcaddr = self.MyAddress(6, netid)
994     dstaddr = self.GetRandomDestination(6)
995     intermediate = "2001:db8::1"
996
997     self.SetSocketMark(s, netid)  # So the packet has somewhere to go.
998     s.connect((dstaddr, 1234))
999     self.assertEquals(1500, self.GetSocketMTU(s))
1000
1001     s.send(1400 * "a")
1002     packets = self.ReadAllPacketsOn(netid)
1003     self.assertEquals(1, len(packets))
1004     toobig = Packets.ICMPPacketTooBig(6, intermediate, srcaddr, packets[0])[1]
1005     self.ReceivePacketOn(netid, toobig)
1006     self.assertEquals(1280, self.GetSocketMTU(s))
1007
1008
1009 if __name__ == "__main__":
1010   unittest.main()