OSDN Git Service

More SOCK_DESTROY test work
authorLorenzo Colitti <lorenzo@google.com>
Sat, 19 Dec 2015 17:36:07 +0000 (02:36 +0900)
committerLorenzo Colitti <lorenzo@google.com>
Wed, 6 Jan 2016 08:53:16 +0000 (17:53 +0900)
- Test that killing a socket kills established but not accepted
  children.
- Make tests a bit more readable.

Change-Id: I1133480233baf09d3f7bf73db612a82917ab0ca9

tests/net_test/sock_diag_test.py

index e122c9c..5975931 100755 (executable)
@@ -218,6 +218,8 @@ class SocketExceptionThread(threading.Thread):
 # in forwarding_test.
 class TcpTest(SockDiagTest):
 
+  NOT_YET_ACCEPTED = -1
+
   def setUp(self):
     super(TcpTest, self).setUp()
     self.sock_diag = sock_diag.SockDiag()
@@ -270,13 +272,15 @@ class TcpTest(SockDiagTest):
     establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
     self.ReceivePacketOn(netid, establishing_ack)
 
-    self.accepted, _ = self.s.accept()
-    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
-                             payload=net_test.UDP_PAYLOAD)
+    if end_state == self.NOT_YET_ACCEPTED:
+      return
 
+    self.accepted, _ = self.s.accept()
     if end_state == sock_diag.TCP_ESTABLISHED:
       return
 
+    desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
+                             payload=net_test.UDP_PAYLOAD)
     self.accepted.send(net_test.UDP_PAYLOAD)
     self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
 
@@ -313,56 +317,99 @@ class TcpTest(SockDiagTest):
       msg = "%s: " % msg
       self.ExpectNoPacketsOn(self.netid, msg)
 
-    if sock is not None:
+    if sock is not None and do_close:
       sock.close()
 
+  def CheckTcpReset(self, state, statename):
+    for version in [4, 6]:
+      msg = "Closing incoming IPv%d %s socket" % (version, statename)
+      self.IncomingConnection(version, state, self.netid)
+      self.CheckRstOnClose(self.s, None, False, msg)
+      if state != sock_diag.TCP_LISTEN:
+        msg = "Closing accepted IPv%d %s socket" % (version, statename)
+        self.CheckRstOnClose(self.accepted, None, True, msg)
+
+  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+  def testTcpResets(self):
+    """Checks that closing sockets in appropriate states sends a RST."""
+    self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN")
+    self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED")
+    self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
+
   def FindChildSockets(self, s):
     """Finds the SYN_RECV child sockets of a given listening socket."""
     d = self.sock_diag.FindSockDiagFromFd(self.s)
     req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
-    req.states = 1 << sock_diag.TCP_SYN_RECV
+    req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
     req.id.cookie = "\x00" * 8
-    sockets = self.sock_diag.Dump(req)
-    sockets = [diag_msg for diag_msg, attrs in sockets]
-    return sockets
+    children = self.sock_diag.Dump(req)
+    return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
+            for d, _ in children]
 
-  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
-  def testTcpResets(self):
-    """Checks that closing sockets in appropriate states sends a RST."""
+  def CheckChildSocket(self, state, statename, parent_first):
     for version in [4, 6]:
-      msg = "Closing incoming IPv%d TCP_LISTEN socket" % version
-      self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
-      self.CheckRstOnClose(self.s, None, False, msg)
+      self.IncomingConnection(version, state, self.netid)
 
-      msg = "Closing incoming IPv%d TCP_SYN_RECV socket" % version
-      self.IncomingConnection(version, sock_diag.TCP_SYN_RECV, self.netid)
+      d = self.sock_diag.FindSockDiagFromFd(self.s)
+      parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
       children = self.FindChildSockets(self.s)
       self.assertEquals(1, len(children))
-      for child in children:
-        req = self.sock_diag.DiagReqFromDiagMsg(child, IPPROTO_TCP)
-        if net_test.LINUX_VERSION >= (4, 4):
-          # The new TCP listener code in 4.4 makes request sockets live in the
-          # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
-          self.sock_diag.GetSockDiag(req)  # No errors? Good, child found.
-          self.CheckRstOnClose(None, req, False, msg + " child")
-          self.assertFalse(self.FindChildSockets(self.s))
-        else:
-          # Before 4.4, we can't see or kill SYN_RECV sockets.
-          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, req)
 
-      self.CheckRstOnClose(self.s, None, False, msg)
+      is_established = (state == self.NOT_YET_ACCEPTED)
 
-      msg = "Closing incoming IPv%d TCP_ESTABLISHED socket" % version
-      self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
-      self.CheckRstOnClose(self.s, None, False, msg)
-      msg = "Closing accepted IPv%d TCP_ESTABLISHED socket" % version
-      self.CheckRstOnClose(self.accepted, None, True, msg)
+      # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
+      # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
+      # Before 4.4, we can see those sockets in dumps, but we can't fetch
+      # or close them.
+      can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)
 
-      msg = "Closing incoming IPv%d TCP_CLOSE_WAIT socket" % version
-      self.IncomingConnection(version, sock_diag.TCP_CLOSE_WAIT, self.netid)
-      self.CheckRstOnClose(self.s, None, False, msg)
-      msg = "Closing accepted IPv%d TCP_ESTABLISHED socket" % version
-      self.CheckRstOnClose(self.accepted, None, True, msg)
+      for child in children:
+        if can_close_children:
+          self.sock_diag.GetSockDiag(child)  # No errors? Good, child found.
+        else:
+          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
+
+      def CloseParent(expect_reset):
+        msg = "Closing parent IPv%d %s socket %s child" % (
+            version, statename, "before" if parent_first else "after")
+        self.CheckRstOnClose(self.s, None, expect_reset, msg)
+        self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent)
+
+      def CheckChildrenClosed():
+        for child in children:
+          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
+
+      def CloseChildren():
+        for child in children:
+          msg = "Closing child IPv%d %s socket %s parent" % (
+              version, statename, "after" if parent_first else "before")
+          self.sock_diag.GetSockDiag(child)
+          self.CheckRstOnClose(None, child, is_established, msg)
+          self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
+        CheckChildrenClosed()
+
+      if parent_first:
+        # Closing the parent will close child sockets, which will send a RST,
+        # iff they are already established.
+        CloseParent(is_established)
+        if is_established:
+          CheckChildrenClosed()
+        elif can_close_children:
+          CloseChildren()
+          CheckChildrenClosed()
+        self.s.close()
+      else:
+        if can_close_children:
+          CloseChildren()
+        CloseParent(False)
+        self.s.close()
+
+  @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+  def testChildSockets(self):
+    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False)
+    self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True)
+    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False)
+    self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True)
 
   def CloseDuringBlockingCall(self, sock, call, expected_errno):
     thread = SocketExceptionThread(sock, call)