OSDN Git Service

Test IPv4 PMTU as well.
[android-x86/system-extras.git] / tests / net_test / mark_test.py
1 #!/usr/bin/python
2
3 import errno
4 import fcntl
5 import os
6 import posix
7 import random
8 import re
9 from socket import *  # pylint: disable=wildcard-import
10 import struct
11 import time           # pylint: disable=unused-import
12 import unittest
13
14 from scapy import all as scapy
15
16 import iproute
17 import net_test
18
19 IFF_TUN = 1
20 IFF_TAP = 2
21 IFF_NO_PI = 0x1000
22 TUNSETIFF = 0x400454ca
23
24 PING_IDENT = 0xff19
25 PING_PAYLOAD = "foobarbaz"
26 PING_SEQ = 3
27 PING_TOS = 0x83
28
29 SO_BINDTODEVICE = 25
30
31 IP_UNICAST_IF = 50
32 IPV6_UNICAST_IF = 76
33
34 UDP_PAYLOAD = "hello"
35
36
37 # Check to see if the kernel supports UID routing.
38 def HaveUidRouting():
39   # Create a rule with the UID range selector. If the kernel doesn't understand
40   # the selector, it will create a rule with no selectors.
41   iproute.IPRoute().UidRangeRule(6, True, 1000, 2000, 100)
42
43   # Dump all the rules. If we find a rule using the UID range selector, then the
44   # kernel supports UID range routing.
45   rules = iproute.IPRoute().DumpRules(6)
46   result = any(iproute.EXPERIMENTAL_FRA_UID_START in attrs
47                for rule, attrs in rules)
48
49   # Delete the rule.
50   iproute.IPRoute().UidRangeRule(6, False, 1000, 2000, 100)
51   return result
52
53
54 AUTOCONF_TABLE_SYSCTL = "/proc/sys/net/ipv6/conf/default/accept_ra_rt_table"
55 IPV4_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv4/fwmark_reflect"
56 IPV6_MARK_REFLECT_SYSCTL = "/proc/sys/net/ipv6/fwmark_reflect"
57 SYNCOOKIES_SYSCTL = "/proc/sys/net/ipv4/tcp_syncookies"
58 TCP_MARK_ACCEPT_SYSCTL = "/proc/sys/net/ipv4/tcp_fwmark_accept"
59
60 HAVE_AUTOCONF_TABLE = os.path.isfile(AUTOCONF_TABLE_SYSCTL)
61 HAVE_MARK_REFLECT = os.path.isfile(IPV4_MARK_REFLECT_SYSCTL)
62 HAVE_TCP_MARK_ACCEPT = os.path.isfile(TCP_MARK_ACCEPT_SYSCTL)
63
64 HAVE_EXPERIMENTAL_UID_ROUTING = HaveUidRouting()
65
66
67 class ConfigurationError(AssertionError):
68   pass
69
70
71 class UnexpectedPacketError(AssertionError):
72   pass
73
74
75 class Packets(object):
76
77   TCP_FIN = 1
78   TCP_SYN = 2
79   TCP_RST = 4
80   TCP_ACK = 16
81
82   TCP_SEQ = 1692871236
83   TCP_WINDOW = 14400
84
85   @staticmethod
86   def RandomPort():
87     return random.randint(1025, 65535)
88
89   @staticmethod
90   def _GetIpLayer(version):
91     return {4: scapy.IP, 6: scapy.IPv6}[version]
92
93   @staticmethod
94   def _SetPacketTos(packet, tos):
95     if isinstance(packet, scapy.IPv6):
96       packet.tc = tos
97     elif isinstance(packet, scapy.IP):
98       packet.tos = tos
99     else:
100       raise ValueError("Can't find ToS Field")
101
102   @classmethod
103   def UDP(cls, version, srcaddr, dstaddr, sport=0):
104     ip = cls._GetIpLayer(version)
105     # Can't just use "if sport" because None has meaning (it means unspecified).
106     if sport == 0:
107       sport = cls.RandomPort()
108     return ("UDPv%d packet" % version,
109             ip(src=srcaddr, dst=dstaddr) /
110             scapy.UDP(sport=sport, dport=53) / UDP_PAYLOAD)
111
112   @classmethod
113   def SYN(cls, dport, version, srcaddr, dstaddr, sport=0, seq=TCP_SEQ):
114     ip = cls._GetIpLayer(version)
115     if sport == 0:
116       sport = cls.RandomPort()
117     return ("TCP SYN",
118             ip(src=srcaddr, dst=dstaddr) /
119             scapy.TCP(sport=sport, dport=dport,
120                       seq=seq, ack=0,
121                       flags=cls.TCP_SYN, window=cls.TCP_WINDOW))
122
123   @classmethod
124   def RST(cls, version, srcaddr, dstaddr, packet):
125     ip = cls._GetIpLayer(version)
126     original = packet.getlayer("TCP")
127     return ("TCP RST",
128             ip(src=srcaddr, dst=dstaddr) /
129             scapy.TCP(sport=original.dport, dport=original.sport,
130                       ack=original.seq + 1, seq=None,
131                       flags=cls.TCP_RST | cls.TCP_ACK, window=cls.TCP_WINDOW))
132
133   @classmethod
134   def SYNACK(cls, version, srcaddr, dstaddr, packet):
135     ip = cls._GetIpLayer(version)
136     original = packet.getlayer("TCP")
137     return ("TCP SYN+ACK",
138             ip(src=srcaddr, dst=dstaddr) /
139             scapy.TCP(sport=original.dport, dport=original.sport,
140                       ack=original.seq + 1, seq=None,
141                       flags=cls.TCP_SYN | cls.TCP_ACK, window=None))
142
143   @classmethod
144   def ACK(cls, version, srcaddr, dstaddr, packet):
145     ip = cls._GetIpLayer(version)
146     original = packet.getlayer("TCP")
147     was_syn_or_fin = (original.flags & (cls.TCP_SYN | cls.TCP_FIN)) != 0
148     return ("TCP ACK",
149             ip(src=srcaddr, dst=dstaddr) /
150             scapy.TCP(sport=original.dport, dport=original.sport,
151                       ack=original.seq + was_syn_or_fin, seq=original.ack,
152                       flags=cls.TCP_ACK, window=cls.TCP_WINDOW))
153
154   @classmethod
155   def FIN(cls, version, srcaddr, dstaddr, packet):
156     ip = cls._GetIpLayer(version)
157     original = packet.getlayer("TCP")
158     was_fin = (original.flags & cls.TCP_FIN) != 0
159     return ("TCP FIN",
160             ip(src=srcaddr, dst=dstaddr) /
161             scapy.TCP(sport=original.dport, dport=original.sport,
162                       ack=original.seq + was_fin, seq=original.ack,
163                       flags=cls.TCP_ACK | cls.TCP_FIN, window=cls.TCP_WINDOW))
164
165   @classmethod
166   def ICMPPortUnreachable(cls, version, srcaddr, dstaddr, packet):
167     if version == 4:
168       # Linux hardcodes the ToS on ICMP errors to 0xc0 or greater because of
169       # RFC 1812 4.3.2.5 (!).
170       return ("ICMPv4 port unreachable",
171               scapy.IP(src=srcaddr, dst=dstaddr, proto=1, tos=0xc0) /
172               scapy.ICMPerror(type=3, code=3) / packet)
173     else:
174       return ("ICMPv6 port unreachable",
175               scapy.IPv6(src=srcaddr, dst=dstaddr) /
176               scapy.ICMPv6DestUnreach(code=4) / packet)
177
178   @classmethod
179   def ICMPPacketTooBig(cls, version, srcaddr, dstaddr, packet):
180     if version == 4:
181       return ("ICMPv4 fragmentation needed",
182               scapy.IP(src=srcaddr, dst=dstaddr, proto=1) /
183               scapy.ICMPerror(type=3, code=4, unused=1280) / str(packet)[:64])
184     else:
185       udp = packet.getlayer("UDP")
186       udp.payload = str(udp.payload)[:1280-40-8]
187       return ("ICMPv6 Packet Too Big",
188               scapy.IPv6(src=srcaddr, dst=dstaddr) /
189               scapy.ICMPv6PacketTooBig() / str(packet)[:1232])
190
191   @classmethod
192   def ICMPEcho(cls, version, srcaddr, dstaddr):
193     ip = cls._GetIpLayer(version)
194     icmp = {4: scapy.ICMP, 6: scapy.ICMPv6EchoRequest}[version]
195     packet = (ip(src=srcaddr, dst=dstaddr) /
196               icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
197     cls._SetPacketTos(packet, PING_TOS)
198     return ("ICMPv%d echo" % version, packet)
199
200   @classmethod
201   def ICMPReply(cls, version, srcaddr, dstaddr, packet):
202     ip = cls._GetIpLayer(version)
203     # Scapy doesn't provide an ICMP echo reply constructor.
204     icmpv4_reply = lambda **kwargs: scapy.ICMP(type=0, **kwargs)
205     icmp = {4: icmpv4_reply, 6: scapy.ICMPv6EchoReply}[version]
206     packet = (ip(src=srcaddr, dst=dstaddr) /
207               icmp(id=PING_IDENT, seq=PING_SEQ) / PING_PAYLOAD)
208     cls._SetPacketTos(packet, PING_TOS)
209     return ("ICMPv%d echo reply" % version, packet)
210
211   @classmethod
212   def NS(cls, srcaddr, tgtaddr, srcmac):
213     solicited = inet_pton(AF_INET6, tgtaddr)
214     last3bytes = tuple([ord(b) for b in solicited[-3:]])
215     solicited = "ff02::1:ff%02x:%02x%02x" % last3bytes
216     packet = (scapy.IPv6(src=srcaddr, dst=solicited) /
217               scapy.ICMPv6ND_NS(tgt=tgtaddr) /
218               scapy.ICMPv6NDOptSrcLLAddr(lladdr=srcmac))
219     return ("ICMPv6 NS", packet)
220
221   @classmethod
222   def NA(cls, srcaddr, dstaddr, srcmac):
223     packet = (scapy.IPv6(src=srcaddr, dst=dstaddr) /
224               scapy.ICMPv6ND_NA(tgt=srcaddr, R=0, S=1, O=1) /
225               scapy.ICMPv6NDOptDstLLAddr(lladdr=srcmac))
226     return ("ICMPv6 NA", packet)
227
228
229 class RunAsUid(object):
230
231   """Context guard to run a code block as a given UID."""
232
233   def __init__(self, uid):
234     self.uid = uid
235
236   def __enter__(self):
237     if self.uid:
238       self.saved_uid = os.geteuid()
239       if self.uid:
240         os.seteuid(self.uid)
241
242   def __exit__(self, unused_type, unused_value, unused_traceback):
243     if self.uid:
244       os.seteuid(self.saved_uid)
245
246
247 class MultiNetworkTest(net_test.NetworkTest):
248
249   # Must be between 1 and 256, since we put them in MAC addresses and IIDs.
250   NETIDS = [100, 150, 200, 250]
251
252   # Stores sysctl values to write back when the test completes.
253   saved_sysctls = {}
254
255   # Wether to output setup commands.
256   DEBUG = False
257
258   # The size of our UID ranges.
259   UID_RANGE_SIZE = 1000
260
261   @classmethod
262   def UidRangeForNetid(cls, netid):
263     return (
264         cls.UID_RANGE_SIZE * netid,
265         cls.UID_RANGE_SIZE * (netid + 1) - 1
266     )
267
268   @classmethod
269   def UidForNetid(cls, netid):
270     return random.randint(*cls.UidRangeForNetid(netid))
271
272   @classmethod
273   def _TableForNetid(cls, netid):
274     if cls.AUTOCONF_TABLE_OFFSET and netid in cls.ifindices:
275       return cls.ifindices[netid] + (-cls.AUTOCONF_TABLE_OFFSET)
276     else:
277       return netid
278
279   @staticmethod
280   def GetInterfaceName(netid):
281     return "nettest%d" % netid
282
283   @staticmethod
284   def RouterMacAddress(netid):
285     return "02:00:00:00:%02x:00" % netid
286
287   @staticmethod
288   def MyMacAddress(netid):
289     return "02:00:00:00:%02x:01" % netid
290
291   @staticmethod
292   def _RouterAddress(netid, version):
293     if version == 6:
294       return "fe80::%02x00" % netid
295     elif version == 4:
296       return "10.0.%d.1" % netid
297     else:
298       raise ValueError("Don't support IPv%s" % version)
299
300   @classmethod
301   def _MyIPv4Address(cls, netid):
302     return "10.0.%d.2" % netid
303
304   @classmethod
305   def _MyIPv6Address(cls, netid):
306     return net_test.GetLinkAddress(cls.GetInterfaceName(netid), False)
307
308   @classmethod
309   def MyAddress(cls, version, netid):
310     return {4: cls._MyIPv4Address(netid),
311             6: cls._MyIPv6Address(netid)}[version]
312
313   @staticmethod
314   def IPv6Prefix(netid):
315     return "2001:db8:%02x::" % netid
316
317   @staticmethod
318   def GetRandomDestination(prefix):
319     if "." in prefix:
320       return prefix + "%d.%d" % (random.randint(0, 31), random.randint(0, 255))
321     else:
322       return prefix + "%x:%x" % (random.randint(0, 65535),
323                                  random.randint(0, 65535))
324
325   def GetProtocolFamily(self, version):
326     return {4: AF_INET, 6: AF_INET6}[version]
327
328   @classmethod
329   def CreateTunInterface(cls, netid):
330     iface = cls.GetInterfaceName(netid)
331     f = open("/dev/net/tun", "r+b")
332     ifr = struct.pack("16sH", iface, IFF_TAP | IFF_NO_PI)
333     ifr += "\x00" * (40 - len(ifr))
334     fcntl.ioctl(f, TUNSETIFF, ifr)
335     # Give ourselves a predictable MAC address.
336     net_test.SetInterfaceHWAddr(iface, cls.MyMacAddress(netid))
337     # Disable DAD so we don't have to wait for it.
338     cls.SetSysctl("/proc/sys/net/ipv6/conf/%s/accept_dad" % iface, 0)
339     net_test.SetInterfaceUp(iface)
340     net_test.SetNonBlocking(f)
341     return f
342
343   @classmethod
344   def SendRA(cls, netid):
345     validity = 300                 # seconds
346     validity_ms = validity * 1000  # milliseconds
347     macaddr = cls.RouterMacAddress(netid)
348     lladdr = cls._RouterAddress(netid, 6)
349
350     # We don't want any routes in the main table. If the kernel doesn't support
351     # putting RA routes into per-interface tables, configure routing manually.
352     routerlifetime = validity if HAVE_AUTOCONF_TABLE else 0
353
354     ra = (scapy.Ether(src=macaddr, dst="33:33:00:00:00:01") /
355           scapy.IPv6(src=lladdr, hlim=255) /
356           scapy.ICMPv6ND_RA(retranstimer=validity_ms,
357                             routerlifetime=routerlifetime) /
358           scapy.ICMPv6NDOptSrcLLAddr(lladdr=macaddr) /
359           scapy.ICMPv6NDOptPrefixInfo(prefix=cls.IPv6Prefix(netid),
360                                       prefixlen=64,
361                                       L=1, A=1,
362                                       validlifetime=validity,
363                                       preferredlifetime=validity))
364     posix.write(cls.tuns[netid].fileno(), str(ra))
365
366   @classmethod
367   def _RunSetupCommands(cls, netid, is_add):
368     for version in [4, 6]:
369       # Find out how to configure things.
370       iface = cls.GetInterfaceName(netid)
371       ifindex = cls.ifindices[netid]
372       macaddr = cls.RouterMacAddress(netid)
373       router = cls._RouterAddress(netid, version)
374       table = cls._TableForNetid(netid)
375
376       # Run iptables to set up incoming packet marking.
377       add_del = "-A" if is_add else "-D"
378       iptables = {4: "iptables", 6: "ip6tables"}[version]
379       args = "%s %s INPUT -t mangle -i %s -j MARK --set-mark %d" % (
380           iptables, add_del, iface, netid)
381       iptables = "/sbin/" + iptables
382       ret = os.spawnvp(os.P_WAIT, iptables, args.split(" "))
383       if ret:
384         raise ConfigurationError("Setup command failed: %s" % args)
385
386       # Set up routing rules.
387       if HAVE_EXPERIMENTAL_UID_ROUTING:
388         start, end = cls.UidRangeForNetid(netid)
389         cls.iproute.UidRangeRule(version, is_add, start, end, table,
390                                  priority=100)
391       cls.iproute.OifRule(version, is_add, iface, table, priority=200)
392       cls.iproute.FwmarkRule(version, is_add, netid, table, priority=300)
393
394       # Configure routing and addressing.
395       #
396       # IPv6 uses autoconf for everything, except if per-device autoconf routing
397       # tables are not supported, in which case the default route (only) is
398       # configured manually. For IPv4 we have to manualy configure addresses,
399       # routes, and neighbour cache entries (since we don't reply to ARP or ND).
400       #
401       # Since deleting addresses also causes routes to be deleted, we need to
402       # be careful with ordering or the delete commands will fail with ENOENT.
403       do_routing = (version == 4 or cls.AUTOCONF_TABLE_OFFSET is None)
404       if is_add:
405         if version == 4:
406           cls.iproute.AddAddress(cls._MyIPv4Address(netid), 24, ifindex)
407           cls.iproute.AddNeighbour(version, router, macaddr, ifindex)
408         if do_routing:
409           cls.iproute.AddRoute(version, table, "default", 0, router, ifindex)
410       else:
411         if do_routing:
412           cls.iproute.DelRoute(version, table, "default", 0, router, ifindex)
413         if version == 4:
414           cls.iproute.DelNeighbour(version, router, macaddr, ifindex)
415           cls.iproute.DelAddress(cls._MyIPv4Address(netid), 24, ifindex)
416
417   @classmethod
418   def GetSysctl(cls, sysctl):
419     return open(sysctl, "r").read()
420
421   @classmethod
422   def SetSysctl(cls, sysctl, value):
423     # Only save each sysctl value the first time we set it. This is so we can
424     # set it to arbitrary values multiple times and still write it back
425     # correctly at the end.
426     if sysctl not in cls.saved_sysctls:
427       cls.saved_sysctls[sysctl] = cls.GetSysctl(sysctl)
428     open(sysctl, "w").write(str(value) + "\n")
429
430   @classmethod
431   def _RestoreSysctls(cls):
432     for sysctl, value in cls.saved_sysctls.iteritems():
433       try:
434         open(sysctl, "w").write(value)
435       except IOError:
436         pass
437
438   @classmethod
439   def _ICMPRatelimitFilename(cls, version):
440     return "/proc/sys/net/" + {4: "ipv4/icmp_ratelimit",
441                                6: "ipv6/icmp/ratelimit"}[version]
442
443   @classmethod
444   def _SetICMPRatelimit(cls, version, limit):
445     cls.SetSysctl(cls._ICMPRatelimitFilename(version), limit)
446
447   @classmethod
448   def setUpClass(cls):
449     # This is per-class setup instead of per-testcase setup because shelling out
450     # to ip and iptables is slow, and because routing configuration doesn't
451     # change during the test.
452     cls.iproute = iproute.IPRoute()
453     cls.tuns = {}
454     cls.ifindices = {}
455     if HAVE_AUTOCONF_TABLE:
456       cls.SetSysctl(AUTOCONF_TABLE_SYSCTL, -1000)
457       cls.AUTOCONF_TABLE_OFFSET = -1000
458     else:
459       cls.AUTOCONF_TABLE_OFFSET = None
460
461     # Disable ICMP rate limits. These will be restored by _RestoreSysctls.
462     for version in [4, 6]:
463       cls._SetICMPRatelimit(version, 0)
464
465     for netid in cls.NETIDS:
466       cls.tuns[netid] = cls.CreateTunInterface(netid)
467       iface = cls.GetInterfaceName(netid)
468       cls.ifindices[netid] = net_test.GetInterfaceIndex(iface)
469
470       cls.SendRA(netid)
471       cls._RunSetupCommands(netid, True)
472
473     # Uncomment to look around at interface and rule configuration while
474     # running in the background. (Once the test finishes running, all the
475     # interfaces and rules are gone.)
476     # time.sleep(30)
477
478   @classmethod
479   def tearDownClass(cls):
480     for netid in cls.tuns:
481       cls._RunSetupCommands(netid, False)
482       cls.tuns[netid].close()
483     cls._RestoreSysctls()
484
485   def SetSocketMark(self, s, netid):
486     s.setsockopt(SOL_SOCKET, net_test.SO_MARK, netid)
487
488   def GetSocketMark(self, s):
489     return s.getsockopt(SOL_SOCKET, net_test.SO_MARK)
490
491   def ClearSocketMark(self, s):
492     self.SetSocketMark(s, 0)
493
494   def BindToDevice(self, s, iface):
495     if not iface:
496       iface = ""
497     s.setsockopt(SOL_SOCKET, SO_BINDTODEVICE, iface)
498
499   def SetUnicastInterface(self, version, s, iface):
500     if iface:
501       ifindex = net_test.GetInterfaceIndex(iface)
502     else:
503       ifindex = 0
504     # Otherwise, Python apparently thinks it's a 1-byte option.
505     ifindex = struct.pack("!I", ifindex)
506
507     layer, opt = {
508         4: (net_test.SOL_IP, IP_UNICAST_IF),
509         6: (net_test.SOL_IPV6, IPV6_UNICAST_IF),
510     }[version]
511     s.setsockopt(layer, opt, ifindex)
512
513   def ReceiveEtherPacketOn(self, netid, packet):
514     posix.write(self.tuns[netid].fileno(), str(packet))
515
516   def ReceivePacketOn(self, netid, ip_packet):
517     routermac = self.RouterMacAddress(netid)
518     mymac = self.MyMacAddress(netid)
519     packet = scapy.Ether(src=routermac, dst=mymac) / ip_packet
520     self.ReceiveEtherPacketOn(netid, packet)
521
522   def ReadAllPacketsOn(self, netid, include_multicast=False):
523     packets = []
524     while True:
525       try:
526         packet = posix.read(self.tuns[netid].fileno(), 4096)
527         ether = scapy.Ether(packet)
528         # Multicast frames are frames where the first byte of the destination
529         # MAC address has 1 in the least-significant bit.
530         if include_multicast or not int(ether.dst.split(":")[0], 16) & 0x1:
531           packets.append(ether.payload)
532       except OSError, e:
533         # EAGAIN means there are no more packets waiting.
534         if re.match(e.message, os.strerror(errno.EAGAIN)):
535           break
536         # Anything else is unexpected.
537         else:
538           raise e
539     return packets
540
541   def ClearTunQueues(self):
542     # Keep reading packets on all netids until we get no packets on any of them.
543     waiting = None
544     while waiting != 0:
545       waiting = sum(len(self.ReadAllPacketsOn(netid)) for netid in self.NETIDS)
546
547   def assertPacketMatches(self, expected, actual):
548     # The expected packet is just a rough sketch of the packet we expect to
549     # receive. For example, it doesn't contain fields we can't predict, such as
550     # initial TCP sequence numbers, or that depend on the host implementation
551     # and settings, such as TCP options. To check whether the packet matches
552     # what we expect, instead of just checking all the known fields one by one,
553     # we blank out fields in the actual packet and then compare the whole
554     # packets to each other as strings. Because we modify the actual packet,
555     # make a copy here.
556     actual = actual.copy()
557
558     # Blank out IPv4 fields that we can't predict, like ID and the DF bit.
559     actualip = actual.getlayer("IP")
560     expectedip = expected.getlayer("IP")
561     if actualip and expectedip:
562       actualip.id = expectedip.id
563       actualip.flags &= 5
564       actualip.chksum = None  # Change the header, recalculate the checksum.
565
566     # Blank out UDP fields that we can't predict (e.g., the source port for
567     # kernel-originated packets).
568     actualudp = actual.getlayer("UDP")
569     expectedudp = expected.getlayer("UDP")
570     if actualudp and expectedudp:
571       if expectedudp.sport is None:
572         actualudp.sport = None
573         actualudp.chksum = None
574
575     # Since the TCP code below messes with options, recalculate the length.
576     if actualip:
577       actualip.len = None
578     actualipv6 = actual.getlayer("IPv6")
579     if actualipv6:
580       actualipv6.plen = None
581
582     # Blank out TCP fields that we can't predict.
583     actualtcp = actual.getlayer("TCP")
584     expectedtcp = expected.getlayer("TCP")
585     if actualtcp and expectedtcp:
586       actualtcp.dataofs = expectedtcp.dataofs
587       actualtcp.options = expectedtcp.options
588       actualtcp.window = expectedtcp.window
589       if expectedtcp.sport is None:
590         actualtcp.sport = None
591       if expectedtcp.seq is None:
592         actualtcp.seq = None
593       if expectedtcp.ack is None:
594         actualtcp.ack = None
595       actualtcp.chksum = None
596
597     # Serialize the packet so that expected packet fields that are only set when
598     # a packet is serialized e.g., the checksum) are filled in.
599     expected_real = expected.__class__(str(expected))
600     actual_real = actual.__class__(str(actual))
601     # repr() can be expensive. Call it only if the test is going to fail and we
602     # want to see the error.
603     if expected_real != actual_real:
604       self.assertEquals(repr(expected_real), repr(actual_real))
605
606   def PacketMatches(self, expected, actual):
607     try:
608       self.assertPacketMatches(expected, actual)
609       return True
610     except AssertionError:
611       return False
612
613   def ExpectNoPacketsOn(self, netid, msg):
614     packets = self.ReadAllPacketsOn(netid)
615     if packets:
616       firstpacket = str(packets[0]).encode("hex")
617     else:
618       firstpacket = ""
619     self.assertFalse(packets, msg + ": unexpected packet: " + firstpacket)
620
621   def ExpectPacketOn(self, netid, msg, expected):
622     # To avoid confusion due to lots of ICMPv6 ND going on all the time, drop
623     # multicast packets unless the packet we expect to see is a multicast
624     # packet. For now the only tests that use this are IPv6.
625     ipv6 = expected.getlayer("IPv6")
626     if ipv6 and ipv6.dst.startswith("ff"):
627       include_multicast = True
628     else:
629       include_multicast = False
630
631     packets = self.ReadAllPacketsOn(netid, include_multicast=include_multicast)
632     self.assertTrue(packets, msg + ": received no packets")
633
634     # If we receive a packet that matches what we expected, return it.
635     for packet in packets:
636       if self.PacketMatches(expected, packet):
637         return packet
638
639     # None of the packets matched. Call assertPacketMatches to output a diff
640     # between the expected packet and the last packet we received. In theory,
641     # we'd output a diff to the packet that's the best match for what we
642     # expected, but this is good enough for now.
643     try:
644       self.assertPacketMatches(expected, packets[-1])
645     except Exception, e:
646       raise UnexpectedPacketError(
647           "%s: diff with last packet:\n%s" % (msg, e.message))
648
649
650 class MarkTest(MultiNetworkTest):
651
652   # How many times to run packet reflection tests.
653   ITERATIONS = 5
654
655   # For convenience.
656   IPV4_ADDR = net_test.IPV4_ADDR
657   IPV6_ADDR = net_test.IPV6_ADDR
658   IPV4_PING = net_test.IPV4_PING
659   IPV6_PING = net_test.IPV6_PING
660
661   @classmethod
662   def setUpClass(cls):
663     super(MarkTest, cls).setUpClass()
664
665     # Open a port so we can observe SYN+ACKs. Since it's a dual-stack socket it
666     # will accept both IPv4 and IPv6 connections. We do this here instead of in
667     # each test so we can use the same socket every time. That way, if a kernel
668     # bug causes incoming packets to mark the listening socket instead of the
669     # accepted socket, the test will fail as soon as the next address/interface
670     # combination is tried.
671     cls.listenport = 1234
672     cls.listensocket = net_test.IPv6TCPSocket()
673     cls.listensocket.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
674     cls.listensocket.bind(("::", cls.listenport))
675     cls.listensocket.listen(100)
676
677   @classmethod
678   def _SetMarkReflectSysctls(cls, value):
679     cls.SetSysctl(IPV4_MARK_REFLECT_SYSCTL, value)
680     try:
681       cls.SetSysctl(IPV6_MARK_REFLECT_SYSCTL, value)
682     except IOError:
683       # This does not exist if we use the version of the patch that uses a
684       # common sysctl for IPv4 and IPv6.
685       pass
686
687   @classmethod
688   def _SetTCPMarkAcceptSysctl(cls, value):
689     cls.SetSysctl(TCP_MARK_ACCEPT_SYSCTL, value)
690
691   def setUp(self):
692     self.ClearTunQueues()
693
694   def tearDown(self):
695     # In case there was an exception in one of the tests and we didn't clean up.
696     self.BindToDevice(self.listensocket, None)
697
698   def _GetRemoteAddress(self, version):
699     return {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
700
701   def BuildSocket(self, version, constructor, mark, uid, oif, ucast_oif):
702     with RunAsUid(uid):
703       family = self.GetProtocolFamily(version)
704       s = constructor(family)
705     if mark:
706       self.SetSocketMark(s, mark)
707     if oif:
708       self.BindToDevice(s, oif)
709     if ucast_oif:
710       self.SetUnicastInterface(version, s, ucast_oif)
711     return s
712
713   def CheckPingPacket(self, version, mark, uid, oif, ucast_oif, dstaddr, packet,
714                       expected_netid):
715     s = self.BuildSocket(version, net_test.PingSocket, mark, uid, oif,
716                          ucast_oif)
717
718     myaddr = self.MyAddress(version, expected_netid)
719     s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
720     s.bind((myaddr, PING_IDENT))
721     net_test.SetSocketTos(s, PING_TOS)
722
723     desc, expected = Packets.ICMPEcho(version, myaddr, dstaddr)
724     msg = "IPv%d ping: expected %s on %s" % (
725         version, desc, self.GetInterfaceName(expected_netid))
726     s.sendto(packet + PING_PAYLOAD, (dstaddr, 19321))
727     self.ExpectPacketOn(expected_netid, msg, expected)
728
729   def CheckTCPSYNPacket(self, version, mark, uid, oif, ucast_oif, dstaddr,
730                         expected_netid):
731     s = self.BuildSocket(version, net_test.TCPSocket, mark, uid, oif, ucast_oif)
732
733     if version == 6 and dstaddr.startswith("::ffff"):
734       version = 4
735     myaddr = self.MyAddress(version, expected_netid)
736     desc, expected = Packets.SYN(53, version, myaddr, dstaddr,
737                                  sport=None, seq=None)
738
739     # Non-blocking TCP connects always return EINPROGRESS.
740     self.assertRaisesErrno(errno.EINPROGRESS, s.connect, (dstaddr, 53))
741     msg = "IPv%s TCP connect: expected %s on %s" % (
742         version, desc, self.GetInterfaceName(expected_netid))
743     self.ExpectPacketOn(expected_netid, msg, expected)
744     s.close()
745
746   def CheckUDPPacket(self, version, mark, uid, oif, ucast_oif,
747                      dstaddr, expected_netid):
748     s = self.BuildSocket(version, net_test.UDPSocket, mark, uid, oif, ucast_oif)
749
750     if version == 6 and dstaddr.startswith("::ffff"):
751       version = 4
752     myaddr = self.MyAddress(version, expected_netid)
753     desc, expected = Packets.UDP(version, myaddr, dstaddr, sport=None)
754     msg = "IPv%s UDP %%s: expected %s on %s" % (
755         version, desc, self.GetInterfaceName(expected_netid))
756
757     s.sendto(UDP_PAYLOAD, (dstaddr, 53))
758     self.ExpectPacketOn(expected_netid, msg % "sendto", expected)
759
760     s.connect((dstaddr, 53))
761     s.send(UDP_PAYLOAD)
762     self.ExpectPacketOn(expected_netid, msg % "connect/send", expected)
763     s.close()
764
765   def CheckOutgoingPackets(self, mode):
766     v4addr = self.IPV4_ADDR
767     v6addr = self.IPV6_ADDR
768     v4mapped = "::ffff:" + v4addr
769
770     for _ in xrange(self.ITERATIONS):
771       for netid in self.tuns:
772
773         mark = uid = oif = ucast_oif = None
774         if mode == "mark":
775           mark = netid
776         elif mode == "uid":
777           uid = self.UidForNetid(netid)
778         elif mode == "oif":
779           oif = self.GetInterfaceName(netid)
780         elif mode == "ucast_oif":
781           ucast_oif = self.GetInterfaceName(netid)
782         else:
783           raise ValueError("Unkown routing mode %s" % mode)
784
785         self.CheckPingPacket(4, mark, uid, oif, ucast_oif, v4addr,
786                              self.IPV4_PING, netid)
787         self.CheckPingPacket(6, mark, uid, oif, ucast_oif, v6addr,
788                              self.IPV6_PING, netid)
789
790         # TCP doesn't seem to honour IP_UNICAST_IF.
791         if mode != "ucast_oif":
792           self.CheckTCPSYNPacket(4, mark, uid, oif, ucast_oif, v4addr, netid)
793           self.CheckTCPSYNPacket(6, mark, uid, oif, ucast_oif, v6addr, netid)
794           self.CheckTCPSYNPacket(6, mark, uid, oif, ucast_oif, v4mapped, netid)
795
796         if mode != "ucast_oif":
797           # This doesn't work.
798           self.CheckUDPPacket(4, mark, uid, oif, ucast_oif, v4addr, netid)
799           # These work, but the source addresses are incorrect.
800           self.CheckUDPPacket(6, mark, uid, oif, ucast_oif, v6addr, netid)
801           self.CheckUDPPacket(6, mark, uid, oif, ucast_oif, v4mapped, netid)
802
803   def testMarkRouting(self):
804     """Checks that socket marking selects the right outgoing interface."""
805     self.CheckOutgoingPackets("mark")
806
807   @unittest.skipUnless(HAVE_EXPERIMENTAL_UID_ROUTING, "no UID routing")
808   def testUidRouting(self):
809     """Checks that UID routing selects the right outgoing interface."""
810     self.CheckOutgoingPackets("uid")
811
812   def testOifRouting(self):
813     """Checks that oif routing selects the right outgoing interface."""
814     self.CheckOutgoingPackets("oif")
815
816   def testUcastOifRouting(self):
817     """Checks that ucast oif routing selects the right outgoing interface."""
818     self.CheckOutgoingPackets("ucast_oif")
819
820   def CheckRemarking(self, version):
821     s = net_test.UDPSocket(self.GetProtocolFamily(version))
822
823     # Figure out what packets to expect.
824     unspec = {4: "0.0.0.0", 6: "::"}[version]
825     sport = Packets.RandomPort()
826     s.bind((unspec, sport))
827     dstaddr = {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
828     desc, expected = Packets.UDP(version, unspec, dstaddr, sport)
829
830     # For each netid, set that netid's mark on the socket without closing it,
831     # and check that the packets sent on that socket go out on the right
832     # network.
833     for netid in self.tuns:
834       self.SetSocketMark(s, netid)
835       expected.src = self.MyAddress(version, netid)
836       s.sendto("hello", (dstaddr, 53))
837       msg = "Remarked UDPv%d socket: expecting %s on %s" % (
838           version, desc, self.GetInterfaceName(netid))
839       self.ExpectPacketOn(netid, msg, expected)
840
841   def testIPv4Remarking(self):
842     """Checks that updating the mark on an IPv4 socket changes routing."""
843     self.CheckRemarking(4)
844
845   def testIPv6Remarking(self):
846     """Checks that updating the mark on an IPv6 socket changes routing."""
847     self.CheckRemarking(6)
848
849   def CheckReflection(self, version, packet_generator, reply_generator,
850                       mark_behaviour, callback=None):
851     """Checks that replies go out on the same interface as the original.
852
853     Iterates through all the combinations of the interfaces in self.tuns and the
854     IP addresses assigned to them. For each combination:
855      - Calls packet_generator to generate a packet to that IP address.
856      - Writes the packet generated by packet_generator on the given tun
857        interface, causing the kernel to receive it.
858      - Checks that the kernel's reply matches the packet generated by
859        reply_generator.
860      - Calls the given callback function.
861
862     Args:
863       version: An integer, 4 or 6.
864       packet_generator: A function taking an IP version (an integer), a source
865         address and a destination address (strings), and returning a scapy
866         packet.
867       reply_generator: A function taking the same arguments as packet_generator,
868         plus a scapy packet, and returning a scapy packet.
869       mark_behaviour: A string describing the mark behaviour to test. Tests are
870         performed with the corresponding sysctl set to both 0 and 1.
871       callback: A function to call to perform extra checks if the packet
872         matches. Takes netid, version, local address, remote address, original
873         packet, kernel reply, and a message.
874     """
875     # What are we testing?
876     sysctl_function = {"accept": self._SetTCPMarkAcceptSysctl,
877                        "reflect": self._SetMarkReflectSysctls}[mark_behaviour]
878
879     # Check packets addressed to the IP addresses of all our interfaces...
880     for dest_ip_netid in self.tuns:
881       dest_ip_iface = self.GetInterfaceName(dest_ip_netid)
882
883       myaddr = self.MyAddress(version, dest_ip_netid)
884       remote_addr = self._GetRemoteAddress(version)
885
886       # ... coming in on all our interfaces...
887       for iif_netid in self.tuns:
888         iif = self.GetInterfaceName(iif_netid)
889
890         # ... with inbound mark sysctl enabled and disabled.
891         for sysctl_value in [0, 1]:
892
893           # If we're testing accepting TCP connections, also check that
894           # SO_BINDTODEVICE correctly sets the interface the SYN+ACK is sent on.
895           # Since SO_BINDTODEVICE and the sysctl do the same thing, it doesn't
896           # really make sense to test with sysctl_value=1 and SO_BINDTODEVICE
897           # turned on at the same time.
898           if mark_behaviour == "accept" and not sysctl_value:
899             bind_devices = [None, iif]
900           else:
901             bind_devices = [None]
902
903           for bound_dev in bind_devices:
904             # The socket is unbound in tearDown.
905             self.BindToDevice(self.listensocket, bound_dev)
906
907             # Generate the packet here instead of in the outer loop, so
908             # subsequent TCP connections use different source ports and
909             # retransmissions from old connections don't confuse subsequent
910             # tests.
911             desc, packet = packet_generator(version, remote_addr, myaddr)
912             reply_desc, reply = reply_generator(version, myaddr, remote_addr,
913                                                 packet)
914
915             msg = "Receiving %s on %s to %s IP, %s=%d, bound_dev=%s" % (
916                 desc, iif, dest_ip_iface, mark_behaviour, sysctl_value,
917                 bound_dev)
918             sysctl_function(sysctl_value)
919
920             # Cause the kernel to receive packet on iif_netid.
921             self.ReceivePacketOn(iif_netid, packet)
922
923             # Expect the kernel to send out reply on the same interface.
924             #
925             # HACK: IPv6 ping replies always do a routing lookup with the
926             # interface the ping came in on. So even if mark reflection is not
927             # working, IPv6 ping replies will be properly reflected. Don't
928             # fail when that happens.
929             if bound_dev or sysctl_value or reply_desc == "ICMPv6 echo reply":
930               msg += ": Expecting %s on %s" % (reply_desc, iif)
931               reply = self.ExpectPacketOn(iif_netid, msg, reply)
932               # If a callback was set, call it.
933               if callback:
934                 callback(sysctl_value, iif_netid, version, myaddr, remote_addr,
935                          packet, reply, msg)
936             else:
937               msg += ": Expecting no packets on %s" % iif
938               self.ExpectNoPacketsOn(iif_netid, msg)
939
940   def SYNToClosedPort(self, *args):
941     return Packets.SYN(999, *args)
942
943   def SYNToOpenPort(self, *args):
944     return Packets.SYN(self.listenport, *args)
945
946   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
947   def testIPv4ICMPErrorsReflectMark(self):
948     self.CheckReflection(4, Packets.UDP, Packets.ICMPPortUnreachable, "reflect")
949
950   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
951   def testIPv6ICMPErrorsReflectMark(self):
952     self.CheckReflection(6, Packets.UDP, Packets.ICMPPortUnreachable, "reflect")
953
954   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
955   def testIPv4PingRepliesReflectMarkAndTos(self):
956     self.CheckReflection(4, Packets.ICMPEcho, Packets.ICMPReply, "reflect")
957
958   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
959   def testIPv6PingRepliesReflectMarkAndTos(self):
960     self.CheckReflection(6, Packets.ICMPEcho, Packets.ICMPReply, "reflect")
961
962   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
963   def testIPv4RSTsReflectMark(self):
964     self.CheckReflection(4, self.SYNToClosedPort, Packets.RST, "reflect")
965
966   @unittest.skipUnless(HAVE_MARK_REFLECT, "no mark reflection")
967   def testIPv6RSTsReflectMark(self):
968     self.CheckReflection(6, self.SYNToClosedPort, Packets.RST, "reflect")
969
970   def CheckTCPConnection(self, sysctl_value, netid, version,
971                          myaddr, remote_addr, packet, reply, msg):
972     establishing_ack = Packets.ACK(version, remote_addr, myaddr, reply)[1]
973     self.ReceivePacketOn(netid, establishing_ack)
974     s, unused_peer = self.listensocket.accept()
975     try:
976       mark = self.GetSocketMark(s)
977     finally:
978       s.close()
979     if sysctl_value:
980       self.assertEquals(netid, mark,
981                         msg + ": Accepted socket: Expected mark %d, got %d" % (
982                             netid, mark))
983
984     # Check the FIN was sent on the right interface, and ack it. We don't expect
985     # this to fail because by the time the connection is established things are
986     # likely working, but a) extra tests are always good and b) extra packets
987     # like the FIN (and retransmitted FINs) could cause later tests that expect
988     # no packets to fail.
989     desc, fin = Packets.FIN(version, myaddr, remote_addr, establishing_ack)
990     self.ExpectPacketOn(netid, msg + ": expecting %s after close" % desc, fin)
991
992     desc, finack = Packets.FIN(version, remote_addr, myaddr, fin)
993     self.ReceivePacketOn(netid, finack)
994
995     desc, finackack = Packets.ACK(version, myaddr, remote_addr, finack)
996     self.ExpectPacketOn(netid, msg + ": expecting final ack", finackack)
997
998   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
999   def testIPv4TCPConnections(self):
1000     self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK, "accept",
1001                          self.CheckTCPConnection)
1002
1003   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
1004   def testIPv6TCPConnections(self):
1005     self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK, "accept",
1006                          self.CheckTCPConnection)
1007
1008   @unittest.skipUnless(HAVE_TCP_MARK_ACCEPT, "fwmark writeback not supported")
1009   def testTCPConnectionsWithSynCookies(self):
1010     # Force SYN cookies on all connections.
1011     self.SetSysctl(SYNCOOKIES_SYSCTL, 2)
1012     try:
1013       self.CheckReflection(4, self.SYNToOpenPort, Packets.SYNACK, "accept",
1014                            self.CheckTCPConnection)
1015       self.CheckReflection(6, self.SYNToOpenPort, Packets.SYNACK, "accept",
1016                            self.CheckTCPConnection)
1017     finally:
1018       # Stop forcing SYN cookies on all connections.
1019       self.SetSysctl(SYNCOOKIES_SYSCTL, 1)
1020
1021
1022 class RATest(MultiNetworkTest):
1023
1024   def testDoesNotHaveObsoleteSysctl(self):
1025     self.assertFalse(os.path.isfile(
1026         "/proc/sys/net/ipv6/route/autoconf_table_offset"))
1027
1028   @unittest.skipUnless(HAVE_AUTOCONF_TABLE, "no support for per-table autoconf")
1029   def testPurgeDefaultRouters(self):
1030
1031     def CheckIPv6Connectivity(expect_connectivity):
1032       for netid in self.NETIDS:
1033         s = net_test.UDPSocket(AF_INET6)
1034         self.SetSocketMark(s, netid)
1035         if expect_connectivity:
1036           self.assertEquals(5, s.sendto("hello", (net_test.IPV6_ADDR, 1234)))
1037         else:
1038           self.assertRaisesErrno(errno.ENETUNREACH,
1039                                  s.sendto, "hello", (net_test.IPV6_ADDR, 1234))
1040
1041     try:
1042       CheckIPv6Connectivity(True)
1043       self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 1)
1044       CheckIPv6Connectivity(False)
1045     finally:
1046       self.SetSysctl("/proc/sys/net/ipv6/conf/all/forwarding", 0)
1047       for netid in self.NETIDS:
1048         self.SendRA(netid)
1049       CheckIPv6Connectivity(True)
1050
1051   @unittest.skipUnless(HAVE_AUTOCONF_TABLE, "our manual routing doesn't do PIO")
1052   def testOnlinkCommunication(self):
1053     """Checks that on-link communication goes direct and not through routers."""
1054     for netid in self.tuns:
1055       # Send a UDP packet to a random on-link destination.
1056       s = net_test.UDPSocket(AF_INET6)
1057       iface = self.GetInterfaceName(netid)
1058       self.BindToDevice(s, iface)
1059       # dstaddr can never be our address because GetRandomDestination only fills
1060       # in the lower 32 bits, but our address has 0xff in the byte before that
1061       # (since it's constructed from the EUI-64 and so has ff:fe in the middle).
1062       dstaddr = self.GetRandomDestination(self.IPv6Prefix(netid))
1063       s.sendto("hello", (dstaddr, 53))
1064
1065       # Expect an NS for that destination on the interface.
1066       myaddr = self.MyAddress(6, netid)
1067       mymac = self.MyMacAddress(netid)
1068       desc, expected = Packets.NS(myaddr, dstaddr, mymac)
1069       msg = "Sending UDP packet to on-link destination: expecting %s" % desc
1070       self.ExpectPacketOn(netid, msg, expected)
1071
1072       # Send an NA.
1073       tgtmac = "02:00:00:00:%02x:99" % netid
1074       _, reply = Packets.NA(dstaddr, myaddr, tgtmac)
1075       # Don't use ReceivePacketOn, since that uses the router's MAC address as
1076       # the source. Instead, construct our own Ethernet header with source
1077       # MAC of tgtmac.
1078       reply = scapy.Ether(src=tgtmac, dst=mymac) / reply
1079       self.ReceiveEtherPacketOn(netid, reply)
1080
1081       # Expect the kernel to send the original UDP packet now that the ND cache
1082       # entry has been populated.
1083       sport = s.getsockname()[1]
1084       desc, expected = Packets.UDP(6, myaddr, dstaddr, sport=sport)
1085       msg = "After NA response, expecting %s" % desc
1086       self.ExpectPacketOn(netid, msg, expected)
1087
1088   @unittest.skipUnless(False, "Known bug: routing tables are never deleted")
1089   def testNoLeftoverRoutes(self):
1090     def GetNumRoutes():
1091       return len(open("/proc/net/ipv6_route").readlines())
1092
1093     num_routes = GetNumRoutes()
1094     for i in xrange(10, 20):
1095       try:
1096         self.tuns[i] = self.CreateTunInterface(i)
1097         self.SendRA(i)
1098         self.tuns[i].close()
1099       finally:
1100         del self.tuns[i]
1101     self.assertEquals(num_routes, GetNumRoutes())
1102
1103
1104 class PMTUTest(MultiNetworkTest):
1105
1106   PAYLOAD_SIZE = 1400
1107   IP_MTU_DISCOVER = 10
1108   IP_MTU = 14
1109   IP_PMTUDISC_DO = 1
1110   IPV6_PATHMTU = 61
1111   IPV6_DONTFRAG = 62
1112
1113   def GetSocketMTU(self, version, s):
1114     if version == 6:
1115       ip6_mtuinfo = s.getsockopt(net_test.SOL_IPV6, self.IPV6_PATHMTU, 32)
1116       mtu = struct.unpack("=28sI", ip6_mtuinfo)
1117       return mtu[1]
1118     else:
1119       return s.getsockopt(net_test.SOL_IP, self.IP_MTU)
1120
1121   def SetDontFragment(self, version, s):
1122     if version == 6:
1123       s.setsockopt(net_test.SOL_IPV6, self.IPV6_DONTFRAG, 1)
1124       s.setsockopt(net_test.SOL_IPV6, net_test.IPV6_RECVERR, 1)
1125     else:
1126       s.setsockopt(net_test.SOL_IP, self.IP_MTU_DISCOVER, self.IP_PMTUDISC_DO)
1127       s.setsockopt(net_test.SOL_IP, net_test.IP_RECVERR, 1)
1128
1129   def CheckPMTU(self, version):
1130     for netid in self.tuns:
1131       s = net_test.UDPSocket(self.GetProtocolFamily(version))
1132       self.SetDontFragment(version, s)
1133
1134       srcaddr = self.MyAddress(version, netid)
1135       dst_prefix, intermediate = {
1136           4: ("172.19.", "172.16.9.12"),
1137           6: ("2001:db8::", "2001:db8::1")
1138       }[version]
1139       dstaddr = self.GetRandomDestination(dst_prefix)
1140
1141       # So the packet has somewhere to go.
1142       self.SetSocketMark(s, netid)
1143       s.connect((dstaddr, 1234))
1144       self.assertEquals(1500, self.GetSocketMTU(version, s))
1145
1146       s.send(self.PAYLOAD_SIZE * "a")
1147       packets = self.ReadAllPacketsOn(netid)
1148       self.assertEquals(1, len(packets))
1149       _, toobig = Packets.ICMPPacketTooBig(version, intermediate, srcaddr,
1150                                            packets[0])
1151       self.ReceivePacketOn(netid, toobig)
1152       self.assertEquals(1280, self.GetSocketMTU(version, s))
1153       s.close()
1154
1155       # Open another socket to ensure the path MTU is cached.
1156       s2 = net_test.UDPSocket(self.GetProtocolFamily(version))
1157       self.BindToDevice(s2, self.GetInterfaceName(netid))
1158       s2.connect((dstaddr, 1234))
1159       self.assertEquals(1280, self.GetSocketMTU(version, s2))
1160
1161   def testIPv4PMTU(self):
1162     self.CheckPMTU(4)
1163
1164   def testIPv6PMTU(self):
1165     self.CheckPMTU(6)
1166
1167
1168 @unittest.skipUnless(HAVE_EXPERIMENTAL_UID_ROUTING, "no UID routing")
1169 class UidRoutingTest(MultiNetworkTest):
1170
1171   def setUp(self):
1172     self.iproute = iproute.IPRoute()
1173
1174   def GetRulesAtPriority(self, version, priority):
1175     rules = self.iproute.DumpRules(version)
1176     out = [(rule, attributes) for rule, attributes in rules
1177            if attributes.get(iproute.FRA_PRIORITY, 0) == priority]
1178     return out
1179
1180   def CheckInitialTablesHaveNoUIDs(self, version):
1181     rules = []
1182     for priority in [0, 32766, 32767]:
1183       rules.extend(self.GetRulesAtPriority(version, priority))
1184     for _, attributes in rules:
1185       self.assertNotIn(iproute.EXPERIMENTAL_FRA_UID_START, attributes)
1186       self.assertNotIn(iproute.EXPERIMENTAL_FRA_UID_END, attributes)
1187
1188   def testIPv4InitialTablesHaveNoUIDs(self):
1189     self.CheckInitialTablesHaveNoUIDs(4)
1190
1191   def testIPv6InitialTablesHaveNoUIDs(self):
1192     self.CheckInitialTablesHaveNoUIDs(6)
1193
1194   def CheckGetAndSetRules(self, version):
1195     def Random():
1196       return random.randint(1000000, 2000000)
1197
1198     start, end = tuple(sorted([Random(), Random()]))
1199     table = Random()
1200     priority = Random()
1201
1202     try:
1203       self.iproute.UidRangeRule(version, True, start, end, table,
1204                                 priority=priority)
1205
1206       rules = self.GetRulesAtPriority(version, priority)
1207       self.assertTrue(rules)
1208       _, attributes = rules[-1]
1209       self.assertEquals(priority, attributes[iproute.FRA_PRIORITY])
1210       self.assertEquals(start, attributes[iproute.EXPERIMENTAL_FRA_UID_START])
1211       self.assertEquals(end, attributes[iproute.EXPERIMENTAL_FRA_UID_END])
1212       self.assertEquals(table, attributes[iproute.FRA_TABLE])
1213     finally:
1214       self.iproute.UidRangeRule(version, False, start, end, table,
1215                                 priority=priority)
1216
1217   def testIPv4GetAndSetRules(self):
1218     self.CheckGetAndSetRules(4)
1219
1220   def testIPv6GetAndSetRules(self):
1221     self.CheckGetAndSetRules(6)
1222
1223   def ExpectNoRoute(self, addr, oif, mark, uid):
1224     # The lack of a route may be either an error, or an unreachable route.
1225     try:
1226       routes = self.iproute.GetRoutes(addr, oif, mark, uid)
1227       rtmsg, _ = routes[0]
1228       self.assertEquals(iproute.RTN_UNREACHABLE, rtmsg.type)
1229     except IOError, e:
1230       if int(e.errno) != -int(errno.ENETUNREACH):
1231         raise e
1232
1233   def ExpectRoute(self, addr, oif, mark, uid):
1234     routes = self.iproute.GetRoutes(addr, oif, mark, uid)
1235     rtmsg, _ = routes[0]
1236     self.assertEquals(iproute.RTN_UNICAST, rtmsg.type)
1237
1238   def CheckGetRoute(self, version, addr):
1239     self.ExpectNoRoute(addr, 0, 0, 0)
1240     for netid in self.NETIDS:
1241       uid = self.UidForNetid(netid)
1242       self.ExpectRoute(addr, 0, 0, uid)
1243     self.ExpectNoRoute(addr, 0, 0, 0)
1244
1245   def testIPv4RouteGet(self):
1246     self.CheckGetRoute(4, net_test.IPV4_ADDR)
1247
1248   def testIPv6RouteGet(self):
1249     self.CheckGetRoute(6, net_test.IPV6_ADDR)
1250
1251
1252 if __name__ == "__main__":
1253   unittest.main()