OSDN Git Service

Merge changes Ib1510f36,I8a2f50da,Ife74d135,Idfb19903
authorLorenzo Colitti <lorenzo@google.com>
Tue, 26 Jan 2016 07:01:17 +0000 (07:01 +0000)
committerandroid-build-merger <android-build-merger@google.com>
Tue, 26 Jan 2016 07:01:17 +0000 (07:01 +0000)
am: 9427861d27

* commit '9427861d2746735378aa3dc628c917fac2acec83':
  Refactor TCP test code into its own file.
  Move some sock_diag tests around.
  Teach more sock_diag code and tests about mapped sockets.
  Delete the hack that finds mapped sockets.

tests/net_test/multinetwork_base.py
tests/net_test/sock_diag.py
tests/net_test/sock_diag_test.py
tests/net_test/tcp_test.py [new file with mode: 0644]

index c178bdd..31fcc4c 100644 (file)
@@ -188,6 +188,7 @@ class MultiNetworkBaseTest(net_test.NetworkTest):
   @classmethod
   def MyAddress(cls, version, netid):
     return {4: cls._MyIPv4Address(netid),
+            5: "::ffff:" + cls._MyIPv4Address(netid),
             6: cls._MyIPv6Address(netid)}[version]
 
   @classmethod
@@ -428,7 +429,9 @@ class MultiNetworkBaseTest(net_test.NetworkTest):
       s.setsockopt(net_test.SOL_IPV6, IPV6_UNICAST_IF, ifindex)
 
   def GetRemoteAddress(self, version):
-    return {4: self.IPV4_ADDR, 6: self.IPV6_ADDR}[version]
+    return {4: self.IPV4_ADDR,
+            5: "::ffff:" + self.IPV4_ADDR,
+            6: self.IPV6_ADDR}[version]
 
   def SelectInterface(self, s, netid, mode):
     if mode == "uid":
index 6979877..b4d9cf6 100755 (executable)
@@ -94,20 +94,7 @@ TcpInfo = cstruct.Struct(
     "rcv_rtt rcv_space "
     "total_retrans")  # As of linux 3.13, at least.
 
-# TCP states. See include/net/tcp_states.h.
-TCP_ESTABLISHED = 1
-TCP_SYN_SENT = 2
-TCP_SYN_RECV = 3
-TCP_FIN_WAIT1 = 4
-TCP_FIN_WAIT2 = 5
 TCP_TIME_WAIT = 6
-TCP_CLOSE = 7
-TCP_CLOSE_WAIT = 8
-TCP_LAST_ACK = 9
-TCP_LISTEN = 10
-TCP_CLOSING = 11
-TCP_NEW_SYN_RECV = 12
-
 ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << TCP_TIME_WAIT)
 
 
@@ -329,27 +316,7 @@ class SockDiag(netlink.NetlinkSocket):
   @staticmethod
   def DiagReqFromDiagMsg(d, protocol):
     """Constructs a diag_req from a diag_msg the kernel has given us."""
-    # For a dual-stack socket connected to a mapped address, the diag_msg
-    # returned by the kernel has family AF_INET6 and mapped addresses. But if
-    # we ask the kernel to find a socket based on that data, we'll get ENOENT.
-    # This is because inet_diag_find_one_icsk sees diag_req.family == AF_INET6
-    # and looks in the IPv6 TCP hash table, but mapped sockets are in the IPv4
-    # hash tables. So fix up the diag_req to specify AF_INET.
-    #
-    # TODO: Should the kernel do this for us in inet_diag_find_one_icsk?
-    mapped_prefix = inet_pton(AF_INET6, "::ffff:0.0.0.0")[:12]
-    if (d.family == AF_INET6 and
-        (d.id.src.startswith(mapped_prefix) or
-         d.id.dst.startswith(mapped_prefix))):
-      family = AF_INET
-      sock_id = InetDiagSockId((d.id.sport, d.id.dport,
-                                d.id.src[12:16] + "\x00" * 12,
-                                d.id.dst[12:16] + "\x00" * 12,
-                                d.id.iface, d.id.cookie))
-    else:
-      family = d.family
-      sock_id = d.id
-    return InetDiagReqV2((family, protocol, 0, 1 << d.state, sock_id))
+    return InetDiagReqV2((d.family, protocol, 0, 1 << d.state, d.id))
 
   def CloseSocket(self, req):
     self._SendNlRequest(SOCK_DESTROY, req.Pack(),
index 35301a7..f2cccfe 100755 (executable)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from errno import *
+from errno import *  # pylint: disable=wildcard-import
 import os
 import random
-from socket import *
+from socket import *  # pylint: disable=wildcard-import
 import time
 import unittest
 
@@ -27,10 +27,11 @@ import multinetwork_base
 import net_test
 import packets
 import sock_diag
+import tcp_test
 import threading
 
 
-NUM_SOCKETS = 100
+NUM_SOCKETS = 30
 NO_BYTECODE = ""
 
 # TODO: Backport SOCK_DESTROY and delete this.
@@ -44,7 +45,10 @@ class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
     # Dict mapping (addr, sport, dport) tuples to socketpairs.
     socketpairs = {}
     for i in xrange(NUM_SOCKETS):
-      family, addr = random.choice([(AF_INET, "127.0.0.1"), (AF_INET6, "::1")])
+      family, addr = random.choice([
+          (AF_INET, "127.0.0.1"),
+          (AF_INET6, "::1"),
+          (AF_INET6, "::ffff:127.0.0.1")])
       socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
       sport, dport = (socketpair[0].getsockname()[1],
                       socketpair[1].getsockname()[1])
@@ -61,17 +65,17 @@ class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
     for sock in socketpair:
       self.assertSocketClosed(sock)
 
-
-class SockDiagTest(SockDiagBaseTest):
-
   def setUp(self):
-    super(SockDiagTest, self).setUp()
+    super(SockDiagBaseTest, self).setUp()
     self.sock_diag = sock_diag.SockDiag()
     self.socketpairs = {}
 
   def tearDown(self):
     [s.close() for socketpair in self.socketpairs.values() for s in socketpair]
-    super(SockDiagTest, self).tearDown()
+    super(SockDiagBaseTest, self).tearDown()
+
+
+class SockDiagTest(SockDiagBaseTest):
 
   def assertSockDiagMatchesSocket(self, s, diag_msg):
     family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
@@ -88,6 +92,21 @@ class SockDiagTest(SockDiagBaseTest):
     else:
       assertRaisesErrno(ENOTCONN, s.getpeername)
 
+  def testFindsMappedSockets(self):
+    """Tests that inet_diag_find_one_icsk can find mapped sockets.
+
+    Relevant kernel commits:
+      android-3.10:
+        f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
+    """
+    socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
+                                           "::ffff:127.0.0.1")
+    for sock in socketpair:
+      diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
+      diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
+      self.sock_diag.GetSockDiag(diag_req)
+      # No errors? Good.
+
   def testFindsAllMySockets(self):
     """Tests that basic socket dumping works.
 
@@ -232,7 +251,21 @@ class SockDiagTest(SockDiagBaseTest):
         EINVAL,
         self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+  def testNonSockDiagCommand(self):
+    def DiagDump(code):
+      sock_id = self.sock_diag._EmptyInetDiagSockId()
+      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
+                                     sock_id))
+      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")
+
+    op = sock_diag.SOCK_DIAG_BY_FAMILY
+    DiagDump(op)  # No errors? Good.
+    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
+
+
+@unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+class SockDestroyTest(SockDiagBaseTest):
+
   def testClosesSockets(self):
     self.socketpairs = self._CreateLotsOfSockets()
     for (addr, _, _), socketpair in self.socketpairs.iteritems():
@@ -260,24 +293,12 @@ class SockDiagTest(SockDiagBaseTest):
       # Check that both sockets in the pair are closed.
       self.assertSocketsClosed(socketpair)
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testNonTcpSockets(self):
     s = socket(AF_INET6, SOCK_DGRAM, 0)
     s.connect(("::1", 53))
     diag_msg = self.sock_diag.FindSockDiagFromFd(s)
     self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
 
-  def testNonSockDiagCommand(self):
-    def DiagDump(code):
-      sock_id = self.sock_diag._EmptyInetDiagSockId()
-      req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
-                                     sock_id))
-      self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")
-
-    op = sock_diag.SOCK_DIAG_BY_FAMILY
-    DiagDump(op)  # No errors? Good.
-    self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
-
   # TODO:
   # Test that killing unix sockets returns EOPNOTSUPP.
 
@@ -298,105 +319,38 @@ class SocketExceptionThread(threading.Thread):
       self.exception = e
 
 
-# TODO: Take a tun fd as input, make this a utility class, and reuse at least
-# in forwarding_test.
-class TcpTest(SockDiagBaseTest):
-
-  NOT_YET_ACCEPTED = -1
-
-  def setUp(self):
-    super(TcpTest, self).setUp()
-    self.sock_diag = sock_diag.SockDiag()
-    self.netid = random.choice(self.tuns.keys())
-
-  def OpenListenSocket(self, version):
-    self.port = packets.RandomPort()
-    family = {4: AF_INET, 6: AF_INET6}[version]
-    address = {4: "0.0.0.0", 6: "::"}[version]
-    s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
-    s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
-    s.bind((address, self.port))
-    # We haven't configured inbound iptables marking, so bind explicitly.
-    self.SelectInterface(s, self.netid, "mark")
-    s.listen(100)
-    return s
-
-  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
-    pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet,
-                                                         reply, msg)
-    self.last_packet = pkt
-    return pkt
-
-  def ReceivePacketOn(self, netid, packet):
-    super(TcpTest, self).ReceivePacketOn(netid, packet)
-    self.last_packet = packet
-
-  def RstPacket(self):
-    return packets.RST(self.version, self.myaddr, self.remoteaddr,
-                       self.last_packet)
-
-  def IncomingConnection(self, version, end_state, netid):
-    if version == 5:
-      mapped = True
-      socket_version = 6
-      version = 4
-    else:
-      socket_version = version
-      mapped = False
-
-    self.version = version
-    self.s = self.OpenListenSocket(socket_version)
-    self.end_state = end_state
-
-    def MaybeMappedAddress(addr):
-      return "::ffff:%s" % addr if mapped else addr
-
-    remoteaddr = self.remoteaddr = MaybeMappedAddress(
-        self.GetRemoteAddress(version))
-    myaddr = self.myaddr = MaybeMappedAddress(
-        self.MyAddress(version, netid))
+class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
 
-    if end_state == sock_diag.TCP_LISTEN:
-      return
-
-    desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
-    synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
-    msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
-    reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
-    if end_state == sock_diag.TCP_SYN_RECV:
-      return
-
-    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
-    self.ReceivePacketOn(netid, establishing_ack)
-
-    if end_state == self.NOT_YET_ACCEPTED:
-      return
-
-    self.accepted, _ = self.s.accept()
-    net_test.DisableLinger(self.accepted)
+  def testIpv4MappedSynRecvSocket(self):
+    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.
 
-    if end_state == sock_diag.TCP_ESTABLISHED:
-      return
+    Relevant kernel commits:
+         android-3.4:
+           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
+    """
+    netid = random.choice(self.tuns.keys())
+    self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)
+    sock_id = self.sock_diag._EmptyInetDiagSockId()
+    sock_id.sport = self.port
+    states = 1 << tcp_test.TCP_SYN_RECV
+    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
+    children = self.sock_diag.Dump(req, NO_BYTECODE)
 
-    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
-                             payload=net_test.UDP_PAYLOAD)
-    self.accepted.send(net_test.UDP_PAYLOAD)
-    self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
+    self.assertTrue(children)
+    for child, unused_args in children:
+      self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
+      self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
+                       child.id.dst)
+      self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
+                       child.id.src)
 
-    desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
-    fin = packets._GetIpLayer(version)(str(fin))
-    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
-    msg = "Received %s, expected to see reply %s" % (desc, ack_desc)
 
-    # TODO: Why can't we use this?
-    #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
-    self.ReceivePacketOn(netid, fin)
-    time.sleep(0.1)
-    self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
-    if end_state == sock_diag.TCP_CLOSE_WAIT:
-      return
+@unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
 
-    raise ValueError("Invalid TCP state %d specified" % end_state)
+  def setUp(self):
+    super(SockDestroyTcpTest, self).setUp()
+    self.netid = random.choice(self.tuns.keys())
 
   def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
     """Closes the socket and checks whether a RST is sent or not."""
@@ -420,33 +374,32 @@ class TcpTest(SockDiagBaseTest):
       sock.close()
 
   def CheckTcpReset(self, state, statename):
-    for version in [4, 6]:
+    for version in [4, 5, 6]:
       msg = "Closing incoming IPv%d %s socket" % (version, statename)
       self.IncomingConnection(version, state, self.netid)
       self.CheckRstOnClose(self.s, None, False, msg)
-      if state != sock_diag.TCP_LISTEN:
+      if state != tcp_test.TCP_LISTEN:
         msg = "Closing accepted IPv%d %s socket" % (version, statename)
         self.CheckRstOnClose(self.accepted, None, True, msg)
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testTcpResets(self):
     """Checks that closing sockets in appropriate states sends a RST."""
-    self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN")
-    self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED")
-    self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
+    self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
+    self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
+    self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
 
   def FindChildSockets(self, s):
     """Finds the SYN_RECV child sockets of a given listening socket."""
     d = self.sock_diag.FindSockDiagFromFd(self.s)
     req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
-    req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
+    req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
     req.id.cookie = "\x00" * 8
     children = self.sock_diag.Dump(req, NO_BYTECODE)
     return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
             for d, _ in children]
 
   def CheckChildSocket(self, state, statename, parent_first):
-    for version in [4, 6]:
+    for version in [4, 5, 6]:
       self.IncomingConnection(version, state, self.netid)
 
       d = self.sock_diag.FindSockDiagFromFd(self.s)
@@ -454,7 +407,7 @@ class TcpTest(SockDiagBaseTest):
       children = self.FindChildSockets(self.s)
       self.assertEquals(1, len(children))
 
-      is_established = (state == self.NOT_YET_ACCEPTED)
+      is_established = (state == tcp_test.NOT_YET_ACCEPTED)
 
       # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
       # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
@@ -503,12 +456,11 @@ class TcpTest(SockDiagBaseTest):
         CloseParent(False)
         self.s.close()
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testChildSockets(self):
-    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False)
-    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True)
-    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False)
-    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True)
+    self.CheckChildSocket(tcp_test.TCP_SYN_RECV, "TCP_SYN_RECV", False)
+    self.CheckChildSocket(tcp_test.TCP_SYN_RECV, "TCP_SYN_RECV", True)
+    self.CheckChildSocket(tcp_test.NOT_YET_ACCEPTED, "not yet accepted", False)
+    self.CheckChildSocket(tcp_test.NOT_YET_ACCEPTED, "not yet accepted", True)
 
   def CloseDuringBlockingCall(self, sock, call, expected_errno):
     thread = SocketExceptionThread(sock, call)
@@ -523,25 +475,22 @@ class TcpTest(SockDiagBaseTest):
     self.assertEqual(expected_errno, thread.exception.errno)
     self.assertSocketClosed(sock)
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testAcceptInterrupted(self):
     """Tests that accept() is interrupted by SOCK_DESTROY."""
     for version in [4, 5, 6]:
-      self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
+      self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
       self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
       self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
       self.assertRaisesErrno(EINVAL, self.s.accept)
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testReadInterrupted(self):
     """Tests that read() is interrupted by SOCK_DESTROY."""
     for version in [4, 5, 6]:
-      self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
+      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
       self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
                                    ECONNABORTED)
       self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testConnectInterrupted(self):
     """Tests that connect() is interrupted by SOCK_DESTROY."""
     for version in [4, 5, 6]:
@@ -563,28 +512,6 @@ class TcpTest(SockDiagBaseTest):
       msg = "SOCK_DESTROY of socket in connect, expected no RST"
       self.ExpectNoPacketsOn(self.netid, msg)
 
-  def testIpv4MappedSynRecvSocket(self):
-    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.
-
-    Relevant kernel commits:
-         android-3.4:
-           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
-    """
-    self.IncomingConnection(5, sock_diag.TCP_SYN_RECV, self.netid)
-    sock_id = self.sock_diag._EmptyInetDiagSockId()
-    sock_id.sport = self.port
-    states = 1 << sock_diag.TCP_SYN_RECV
-    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
-    children = self.sock_diag.Dump(req, NO_BYTECODE)
-
-    self.assertTrue(children)
-    for child, unused_args in children:
-      self.assertEqual(sock_diag.TCP_SYN_RECV, child.state)
-      self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
-                       child.id.dst)
-      self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
-                       child.id.src)
-
 
 if __name__ == "__main__":
   unittest.main()
diff --git a/tests/net_test/tcp_test.py b/tests/net_test/tcp_test.py
new file mode 100644 (file)
index 0000000..2c97baf
--- /dev/null
@@ -0,0 +1,124 @@
+#!/usr/bin/python
+#
+# Copyright 2015 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from socket import *  # pylint: disable=wildcard-import
+
+import net_test
+import multinetwork_base
+import packets
+
+# TCP states. See include/net/tcp_states.h.
+TCP_ESTABLISHED = 1
+TCP_SYN_SENT = 2
+TCP_SYN_RECV = 3
+TCP_FIN_WAIT1 = 4
+TCP_FIN_WAIT2 = 5
+TCP_TIME_WAIT = 6
+TCP_CLOSE = 7
+TCP_CLOSE_WAIT = 8
+TCP_LAST_ACK = 9
+TCP_LISTEN = 10
+TCP_CLOSING = 11
+TCP_NEW_SYN_RECV = 12
+
+NOT_YET_ACCEPTED = -1
+
+
+class TcpBaseTest(multinetwork_base.MultiNetworkBaseTest):
+
+  def tearDown(self):
+    if hasattr(self, "s"):
+      self.s.close()
+    super(TcpBaseTest, self).tearDown()
+
+  def OpenListenSocket(self, version, netid):
+    self.port = packets.RandomPort()
+    family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
+    address = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
+    s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
+    s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+    s.bind((address, self.port))
+    # We haven't configured inbound iptables marking, so bind explicitly.
+    self.SelectInterface(s, netid, "mark")
+    s.listen(100)
+    return s
+
+  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
+    pkt = super(TcpBaseTest, self)._ReceiveAndExpectResponse(netid, packet,
+                                                             reply, msg)
+    self.last_packet = pkt
+    return pkt
+
+  def ReceivePacketOn(self, netid, packet):
+    super(TcpBaseTest, self).ReceivePacketOn(netid, packet)
+    self.last_packet = packet
+
+  def RstPacket(self):
+    return packets.RST(self.version, self.myaddr, self.remoteaddr,
+                       self.last_packet)
+
+  def IncomingConnection(self, version, end_state, netid):
+    self.s = self.OpenListenSocket(version, netid)
+    self.end_state = end_state
+
+    remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
+    myaddr = self.myaddr = self.MyAddress(version, netid)
+
+    if version == 5: version = 4
+    self.version = version
+
+    if end_state == TCP_LISTEN:
+      return
+
+    desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
+    synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
+    msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
+    reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
+    if end_state == TCP_SYN_RECV:
+      return
+
+    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
+    self.ReceivePacketOn(netid, establishing_ack)
+
+    if end_state == NOT_YET_ACCEPTED:
+      return
+
+    self.accepted, _ = self.s.accept()
+    net_test.DisableLinger(self.accepted)
+
+    if end_state == TCP_ESTABLISHED:
+      return
+
+    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
+                             payload=net_test.UDP_PAYLOAD)
+    self.accepted.send(net_test.UDP_PAYLOAD)
+    self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
+
+    desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
+    fin = packets._GetIpLayer(version)(str(fin))
+    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
+    msg = "Received %s, expected to see reply %s" % (desc, ack_desc)
+
+    # TODO: Why can't we use this?
+    #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
+    self.ReceivePacketOn(netid, fin)
+    time.sleep(0.1)
+    self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
+    if end_state == TCP_CLOSE_WAIT:
+      return
+
+    raise ValueError("Invalid TCP state %d specified" % end_state)