OSDN Git Service

Refactor TCP test code into its own file.
authorLorenzo Colitti <lorenzo@google.com>
Mon, 25 Jan 2016 15:32:06 +0000 (00:32 +0900)
committerLorenzo Colitti <lorenzo@google.com>
Mon, 25 Jan 2016 15:34:46 +0000 (00:34 +0900)
Change-Id: Ib1510f3679f9d4eb651e21307b71873190e610fc

tests/net_test/sock_diag.py
tests/net_test/sock_diag_test.py
tests/net_test/tcp_test.py [new file with mode: 0644]

index b3ef72b..b4d9cf6 100755 (executable)
@@ -94,20 +94,7 @@ TcpInfo = cstruct.Struct(
     "rcv_rtt rcv_space "
     "total_retrans")  # As of linux 3.13, at least.
 
-# TCP states. See include/net/tcp_states.h.
-TCP_ESTABLISHED = 1
-TCP_SYN_SENT = 2
-TCP_SYN_RECV = 3
-TCP_FIN_WAIT1 = 4
-TCP_FIN_WAIT2 = 5
 TCP_TIME_WAIT = 6
-TCP_CLOSE = 7
-TCP_CLOSE_WAIT = 8
-TCP_LAST_ACK = 9
-TCP_LISTEN = 10
-TCP_CLOSING = 11
-TCP_NEW_SYN_RECV = 12
-
 ALL_NON_TIME_WAIT = 0xffffffff & ~(1 << TCP_TIME_WAIT)
 
 
index 41644d6..f2cccfe 100755 (executable)
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from errno import *
+from errno import *  # pylint: disable=wildcard-import
 import os
 import random
-from socket import *
+from socket import *  # pylint: disable=wildcard-import
 import time
 import unittest
 
@@ -27,6 +27,7 @@ import multinetwork_base
 import net_test
 import packets
 import sock_diag
+import tcp_test
 import threading
 
 
@@ -318,94 +319,38 @@ class SocketExceptionThread(threading.Thread):
       self.exception = e
 
 
-# TODO: Take a tun fd as input, make this a utility class, and reuse at least
-# in forwarding_test.
-class TcpTest(SockDiagBaseTest):
+class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
 
-  NOT_YET_ACCEPTED = -1
+  def testIpv4MappedSynRecvSocket(self):
+    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.
 
-  def setUp(self):
-    super(TcpTest, self).setUp()
-    self.sock_diag = sock_diag.SockDiag()
-    self.netid = random.choice(self.tuns.keys())
+    Relevant kernel commits:
+         android-3.4:
+           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
+    """
+    netid = random.choice(self.tuns.keys())
+    self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)
+    sock_id = self.sock_diag._EmptyInetDiagSockId()
+    sock_id.sport = self.port
+    states = 1 << tcp_test.TCP_SYN_RECV
+    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
+    children = self.sock_diag.Dump(req, NO_BYTECODE)
 
-  def OpenListenSocket(self, version):
-    self.port = packets.RandomPort()
-    family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
-    address = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
-    s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
-    s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
-    s.bind((address, self.port))
-    # We haven't configured inbound iptables marking, so bind explicitly.
-    self.SelectInterface(s, self.netid, "mark")
-    s.listen(100)
-    return s
-
-  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
-    pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet,
-                                                         reply, msg)
-    self.last_packet = pkt
-    return pkt
-
-  def ReceivePacketOn(self, netid, packet):
-    super(TcpTest, self).ReceivePacketOn(netid, packet)
-    self.last_packet = packet
-
-  def RstPacket(self):
-    return packets.RST(self.version, self.myaddr, self.remoteaddr,
-                       self.last_packet)
-
-  def IncomingConnection(self, version, end_state, netid):
-    self.s = self.OpenListenSocket(version)
-    self.end_state = end_state
-
-    remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
-    myaddr = self.myaddr = self.MyAddress(version, netid)
-
-    if version == 5: version = 4
-    self.version = version
-
-    if end_state == sock_diag.TCP_LISTEN:
-      return
-
-    desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
-    synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
-    msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
-    reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
-    if end_state == sock_diag.TCP_SYN_RECV:
-      return
-
-    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
-    self.ReceivePacketOn(netid, establishing_ack)
-
-    if end_state == self.NOT_YET_ACCEPTED:
-      return
-
-    self.accepted, _ = self.s.accept()
-    net_test.DisableLinger(self.accepted)
-
-    if end_state == sock_diag.TCP_ESTABLISHED:
-      return
-
-    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
-                             payload=net_test.UDP_PAYLOAD)
-    self.accepted.send(net_test.UDP_PAYLOAD)
-    self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
-
-    desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
-    fin = packets._GetIpLayer(version)(str(fin))
-    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
-    msg = "Received %s, expected to see reply %s" % (desc, ack_desc)
-
-    # TODO: Why can't we use this?
-    #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
-    self.ReceivePacketOn(netid, fin)
-    time.sleep(0.1)
-    self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
-    if end_state == sock_diag.TCP_CLOSE_WAIT:
-      return
+    self.assertTrue(children)
+    for child, unused_args in children:
+      self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
+      self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
+                       child.id.dst)
+      self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
+                       child.id.src)
+
+
+@unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
 
-    raise ValueError("Invalid TCP state %d specified" % end_state)
+  def setUp(self):
+    super(SockDestroyTcpTest, self).setUp()
+    self.netid = random.choice(self.tuns.keys())
 
   def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
     """Closes the socket and checks whether a RST is sent or not."""
@@ -433,22 +378,21 @@ class TcpTest(SockDiagBaseTest):
       msg = "Closing incoming IPv%d %s socket" % (version, statename)
       self.IncomingConnection(version, state, self.netid)
       self.CheckRstOnClose(self.s, None, False, msg)
-      if state != sock_diag.TCP_LISTEN:
+      if state != tcp_test.TCP_LISTEN:
         msg = "Closing accepted IPv%d %s socket" % (version, statename)
         self.CheckRstOnClose(self.accepted, None, True, msg)
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testTcpResets(self):
     """Checks that closing sockets in appropriate states sends a RST."""
-    self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN")
-    self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED")
-    self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
+    self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
+    self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
+    self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
 
   def FindChildSockets(self, s):
     """Finds the SYN_RECV child sockets of a given listening socket."""
     d = self.sock_diag.FindSockDiagFromFd(self.s)
     req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
-    req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
+    req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
     req.id.cookie = "\x00" * 8
     children = self.sock_diag.Dump(req, NO_BYTECODE)
     return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
@@ -463,7 +407,7 @@ class TcpTest(SockDiagBaseTest):
       children = self.FindChildSockets(self.s)
       self.assertEquals(1, len(children))
 
-      is_established = (state == self.NOT_YET_ACCEPTED)
+      is_established = (state == tcp_test.NOT_YET_ACCEPTED)
 
       # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
       # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
@@ -512,12 +456,11 @@ class TcpTest(SockDiagBaseTest):
         CloseParent(False)
         self.s.close()
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testChildSockets(self):
-    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False)
-    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True)
-    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False)
-    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True)
+    self.CheckChildSocket(tcp_test.TCP_SYN_RECV, "TCP_SYN_RECV", False)
+    self.CheckChildSocket(tcp_test.TCP_SYN_RECV, "TCP_SYN_RECV", True)
+    self.CheckChildSocket(tcp_test.NOT_YET_ACCEPTED, "not yet accepted", False)
+    self.CheckChildSocket(tcp_test.NOT_YET_ACCEPTED, "not yet accepted", True)
 
   def CloseDuringBlockingCall(self, sock, call, expected_errno):
     thread = SocketExceptionThread(sock, call)
@@ -532,25 +475,22 @@ class TcpTest(SockDiagBaseTest):
     self.assertEqual(expected_errno, thread.exception.errno)
     self.assertSocketClosed(sock)
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testAcceptInterrupted(self):
     """Tests that accept() is interrupted by SOCK_DESTROY."""
     for version in [4, 5, 6]:
-      self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
+      self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
       self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
       self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
       self.assertRaisesErrno(EINVAL, self.s.accept)
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testReadInterrupted(self):
     """Tests that read() is interrupted by SOCK_DESTROY."""
     for version in [4, 5, 6]:
-      self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
+      self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
       self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
                                    ECONNABORTED)
       self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
   def testConnectInterrupted(self):
     """Tests that connect() is interrupted by SOCK_DESTROY."""
     for version in [4, 5, 6]:
@@ -572,28 +512,6 @@ class TcpTest(SockDiagBaseTest):
       msg = "SOCK_DESTROY of socket in connect, expected no RST"
       self.ExpectNoPacketsOn(self.netid, msg)
 
-  def testIpv4MappedSynRecvSocket(self):
-    """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.
-
-    Relevant kernel commits:
-         android-3.4:
-           457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
-    """
-    self.IncomingConnection(5, sock_diag.TCP_SYN_RECV, self.netid)
-    sock_id = self.sock_diag._EmptyInetDiagSockId()
-    sock_id.sport = self.port
-    states = 1 << sock_diag.TCP_SYN_RECV
-    req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
-    children = self.sock_diag.Dump(req, NO_BYTECODE)
-
-    self.assertTrue(children)
-    for child, unused_args in children:
-      self.assertEqual(sock_diag.TCP_SYN_RECV, child.state)
-      self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
-                       child.id.dst)
-      self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
-                       child.id.src)
-
 
 if __name__ == "__main__":
   unittest.main()
diff --git a/tests/net_test/tcp_test.py b/tests/net_test/tcp_test.py
new file mode 100644 (file)
index 0000000..2c97baf
--- /dev/null
@@ -0,0 +1,124 @@
+#!/usr/bin/python
+#
+# Copyright 2015 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from socket import *  # pylint: disable=wildcard-import
+
+import net_test
+import multinetwork_base
+import packets
+
+# TCP states. See include/net/tcp_states.h.
+TCP_ESTABLISHED = 1
+TCP_SYN_SENT = 2
+TCP_SYN_RECV = 3
+TCP_FIN_WAIT1 = 4
+TCP_FIN_WAIT2 = 5
+TCP_TIME_WAIT = 6
+TCP_CLOSE = 7
+TCP_CLOSE_WAIT = 8
+TCP_LAST_ACK = 9
+TCP_LISTEN = 10
+TCP_CLOSING = 11
+TCP_NEW_SYN_RECV = 12
+
+NOT_YET_ACCEPTED = -1
+
+
+class TcpBaseTest(multinetwork_base.MultiNetworkBaseTest):
+
+  def tearDown(self):
+    if hasattr(self, "s"):
+      self.s.close()
+    super(TcpBaseTest, self).tearDown()
+
+  def OpenListenSocket(self, version, netid):
+    self.port = packets.RandomPort()
+    family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
+    address = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
+    s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
+    s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+    s.bind((address, self.port))
+    # We haven't configured inbound iptables marking, so bind explicitly.
+    self.SelectInterface(s, netid, "mark")
+    s.listen(100)
+    return s
+
+  def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
+    pkt = super(TcpBaseTest, self)._ReceiveAndExpectResponse(netid, packet,
+                                                             reply, msg)
+    self.last_packet = pkt
+    return pkt
+
+  def ReceivePacketOn(self, netid, packet):
+    super(TcpBaseTest, self).ReceivePacketOn(netid, packet)
+    self.last_packet = packet
+
+  def RstPacket(self):
+    return packets.RST(self.version, self.myaddr, self.remoteaddr,
+                       self.last_packet)
+
+  def IncomingConnection(self, version, end_state, netid):
+    self.s = self.OpenListenSocket(version, netid)
+    self.end_state = end_state
+
+    remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
+    myaddr = self.myaddr = self.MyAddress(version, netid)
+
+    if version == 5: version = 4
+    self.version = version
+
+    if end_state == TCP_LISTEN:
+      return
+
+    desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
+    synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
+    msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
+    reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
+    if end_state == TCP_SYN_RECV:
+      return
+
+    establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
+    self.ReceivePacketOn(netid, establishing_ack)
+
+    if end_state == NOT_YET_ACCEPTED:
+      return
+
+    self.accepted, _ = self.s.accept()
+    net_test.DisableLinger(self.accepted)
+
+    if end_state == TCP_ESTABLISHED:
+      return
+
+    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
+                             payload=net_test.UDP_PAYLOAD)
+    self.accepted.send(net_test.UDP_PAYLOAD)
+    self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
+
+    desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
+    fin = packets._GetIpLayer(version)(str(fin))
+    ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
+    msg = "Received %s, expected to see reply %s" % (desc, ack_desc)
+
+    # TODO: Why can't we use this?
+    #   self._ReceiveAndExpectResponse(netid, fin, ack, msg)
+    self.ReceivePacketOn(netid, fin)
+    time.sleep(0.1)
+    self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
+    if end_state == TCP_CLOSE_WAIT:
+      return
+
+    raise ValueError("Invalid TCP state %d specified" % end_state)