OSDN Git Service

Test SIOCKILLADDR on non-empty addresses too.
authorLorenzo Colitti <lorenzo@google.com>
Tue, 27 Oct 2015 08:37:00 +0000 (17:37 +0900)
committerLorenzo Colitti <lorenzo@google.com>
Wed, 28 Oct 2015 04:43:19 +0000 (13:43 +0900)
Change-Id: Ifddc4e16c4cd59f96048dd517186d0a1f440f02c

tests/net_test/tcp_nuke_addr_test.py

index b4c1882..fd3cded 100755 (executable)
@@ -96,25 +96,27 @@ def ExchangeMessage(addr_family, ip_addr):
         client_socket.send(DEFAULT_TEST_MESSAGE)
 
 
-def KillAddrIoctl(addr_family):
-  """Calls the SIOCKILLADDR on IPv6 address family.
+def KillAddrIoctl(addr):
+  """Calls the SIOCKILLADDR ioctl on the provided IP address.
 
   Args:
-    addr_family: The address family (e.g. AF_INET6).
+    addr The IP address to pass to the ioctl.
 
   Raises:
-    ValueError: If the address family is invalid for the ioctl.
+    ValueError: If addr is of an unsupported address family.
   """
-  if addr_family == AF_INET6:
-    addr = inet_pton(AF_INET6, IPV6_LOOPBACK_ADDR)
+  family, _, _, _, _ = getaddrinfo(addr, None, AF_UNSPEC, SOCK_DGRAM, 0,
+                                   AI_NUMERICHOST)[0]
+  if family == AF_INET6:
+    addr = inet_pton(AF_INET6, addr)
     ifreq = In6Ifreq((addr, 128, LOOPBACK_IFINDEX)).Pack()
-  elif addr_family == AF_INET:
-    addr = inet_pton(AF_INET, IPV4_LOOPBACK_ADDR)
+  elif family == AF_INET:
+    addr = inet_pton(AF_INET, addr)
     sockaddr = csocket.SockaddrIn((AF_INET, 0, addr)).Pack()
     ifreq = Ifreq((LOOPBACK_DEV, sockaddr)).Pack()
   else:
-    raise ValueError('Address family %r not supported.' % addr_family)
-  datagram_socket = socket(addr_family, SOCK_DGRAM)
+    raise ValueError('Address family %r not supported.' % family)
+  datagram_socket = socket(family, SOCK_DGRAM)
   fcntl.ioctl(datagram_socket.fileno(), SIOCKILLADDR, ifreq)
   datagram_socket.close()
 
@@ -155,25 +157,25 @@ class TcpNukeAddrTest(net_test.NetworkTest):
     """
     for i in xrange(DEFAULT_TEST_RUNS):
       ExchangeMessage(AF_INET6, IPV6_LOOPBACK_ADDR)
-      KillAddrIoctl(AF_INET6)
+      KillAddrIoctl(IPV6_LOOPBACK_ADDR)
       ExchangeMessage(AF_INET, IPV4_LOOPBACK_ADDR)
-      KillAddrIoctl(AF_INET)
+      KillAddrIoctl(IPV4_LOOPBACK_ADDR)
       # Test passes if kernel does not crash.
 
-  def testClosesSockets(self):
-    """Tests that SIOCKILLADDR closes IPv6 sockets."""
+  def testClosesIPv6Sockets(self):
+    """Tests that SIOCKILLADDR closes IPv6 sockets and unblocks threads."""
 
     threadpairs = []
 
     for i in xrange(DEFAULT_TEST_RUNS):
-      clientsock, acceptedsock = CreateSocketPair(AF_INET6, "::1")
+      clientsock, acceptedsock = CreateSocketPair(AF_INET6, IPV6_LOOPBACK_ADDR)
       clientthread = ExceptionalReadThread(clientsock)
       clientthread.start()
       serverthread = ExceptionalReadThread(acceptedsock)
       serverthread.start()
       threadpairs.append((clientthread, serverthread))
 
-    KillAddrIoctl(AF_INET6)
+    KillAddrIoctl(IPV6_LOOPBACK_ADDR)
 
     def CheckThreadException(thread):
       thread.join(100)
@@ -182,13 +184,47 @@ class TcpNukeAddrTest(net_test.NetworkTest):
       self.assertTrue(isinstance(thread.exception, IOError))
       self.assertEquals(errno.ETIMEDOUT, thread.exception.errno)
       self.assertRaisesErrno(errno.ENOTCONN, thread.sock.getpeername)
-      self.assertRaisesErrno(errno.EISCONN, thread.sock.connect, ("::1", 53))
+      self.assertRaisesErrno(errno.EISCONN, thread.sock.connect,
+                             (IPV6_LOOPBACK_ADDR, 53))
       self.assertRaisesErrno(errno.EPIPE, thread.sock.send, "foo")
 
     for clientthread, serverthread in threadpairs:
       CheckThreadException(clientthread)
       CheckThreadException(serverthread)
 
+  def assertSocketsClosed(self, socketpair):
+    for sock in socketpair:
+      self.assertRaisesErrno(errno.ENOTCONN, sock.getpeername)
+
+  def assertSocketsNotClosed(self, socketpair):
+    for sock in socketpair:
+      self.assertTrue(sock.getpeername())
+
+  def testAddresses(self):
+    socketpair = CreateSocketPair(AF_INET, IPV4_LOOPBACK_ADDR)
+    KillAddrIoctl("::")
+    self.assertSocketsNotClosed(socketpair)
+    KillAddrIoctl("::1")
+    self.assertSocketsNotClosed(socketpair)
+    KillAddrIoctl("127.0.0.3")
+    self.assertSocketsNotClosed(socketpair)
+    KillAddrIoctl("0.0.0.0")
+    self.assertSocketsNotClosed(socketpair)
+    KillAddrIoctl("127.0.0.1")
+    self.assertSocketsClosed(socketpair)
+
+    socketpair = CreateSocketPair(AF_INET6, IPV6_LOOPBACK_ADDR)
+    KillAddrIoctl("0.0.0.0")
+    self.assertSocketsNotClosed(socketpair)
+    KillAddrIoctl("127.0.0.1")
+    self.assertSocketsNotClosed(socketpair)
+    KillAddrIoctl("::2")
+    self.assertSocketsNotClosed(socketpair)
+    KillAddrIoctl("::")
+    self.assertSocketsNotClosed(socketpair)
+    KillAddrIoctl("::1")
+    self.assertSocketsClosed(socketpair)
+
 
 class TcpNukeAddrHashTest(net_test.NetworkTest):
 
@@ -206,8 +242,8 @@ class TcpNukeAddrHashTest(net_test.NetworkTest):
       socketpairs.append(CreateSocketPair(AF_INET, IPV4_LOOPBACK_ADDR))
       socketpairs.append(CreateSocketPair(AF_INET6, IPV6_LOOPBACK_ADDR))
 
-    KillAddrIoctl(AF_INET)
-    KillAddrIoctl(AF_INET6)
+    KillAddrIoctl(IPV4_LOOPBACK_ADDR)
+    KillAddrIoctl(IPV6_LOOPBACK_ADDR)
 
     for socketpair in socketpairs:
       for sock in socketpair: