# See the License for the specific language governing permissions and
# limitations under the License.
-from errno import *
+from errno import * # pylint: disable=wildcard-import
import os
import random
-from socket import *
+from socket import * # pylint: disable=wildcard-import
import time
import unittest
import net_test
import packets
import sock_diag
+import tcp_test
import threading
-NUM_SOCKETS = 100
+NUM_SOCKETS = 30
NO_BYTECODE = ""
# TODO: Backport SOCK_DESTROY and delete this.
# Dict mapping (addr, sport, dport) tuples to socketpairs.
socketpairs = {}
for i in xrange(NUM_SOCKETS):
- family, addr = random.choice([(AF_INET, "127.0.0.1"), (AF_INET6, "::1")])
+ family, addr = random.choice([
+ (AF_INET, "127.0.0.1"),
+ (AF_INET6, "::1"),
+ (AF_INET6, "::ffff:127.0.0.1")])
socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
sport, dport = (socketpair[0].getsockname()[1],
socketpair[1].getsockname()[1])
for sock in socketpair:
self.assertSocketClosed(sock)
-
-class SockDiagTest(SockDiagBaseTest):
-
def setUp(self):
- super(SockDiagTest, self).setUp()
+ super(SockDiagBaseTest, self).setUp()
self.sock_diag = sock_diag.SockDiag()
self.socketpairs = {}
def tearDown(self):
[s.close() for socketpair in self.socketpairs.values() for s in socketpair]
- super(SockDiagTest, self).tearDown()
+ super(SockDiagBaseTest, self).tearDown()
+
+
+class SockDiagTest(SockDiagBaseTest):
def assertSockDiagMatchesSocket(self, s, diag_msg):
family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
else:
assertRaisesErrno(ENOTCONN, s.getpeername)
+ def testFindsMappedSockets(self):
+ """Tests that inet_diag_find_one_icsk can find mapped sockets.
+
+ Relevant kernel commits:
+ android-3.10:
+ f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
+ """
+ socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
+ "::ffff:127.0.0.1")
+ for sock in socketpair:
+ diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
+ diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
+ self.sock_diag.GetSockDiag(diag_req)
+ # No errors? Good.
+
def testFindsAllMySockets(self):
"""Tests that basic socket dumping works.
EINVAL,
self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())
- @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+ def testNonSockDiagCommand(self):
+ def DiagDump(code):
+ sock_id = self.sock_diag._EmptyInetDiagSockId()
+ req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
+ sock_id))
+ self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")
+
+ op = sock_diag.SOCK_DIAG_BY_FAMILY
+ DiagDump(op) # No errors? Good.
+ self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
+
+
+@unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+class SockDestroyTest(SockDiagBaseTest):
+
def testClosesSockets(self):
self.socketpairs = self._CreateLotsOfSockets()
for (addr, _, _), socketpair in self.socketpairs.iteritems():
# Check that both sockets in the pair are closed.
self.assertSocketsClosed(socketpair)
- @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
def testNonTcpSockets(self):
s = socket(AF_INET6, SOCK_DGRAM, 0)
s.connect(("::1", 53))
diag_msg = self.sock_diag.FindSockDiagFromFd(s)
self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
- def testNonSockDiagCommand(self):
- def DiagDump(code):
- sock_id = self.sock_diag._EmptyInetDiagSockId()
- req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
- sock_id))
- self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")
-
- op = sock_diag.SOCK_DIAG_BY_FAMILY
- DiagDump(op) # No errors? Good.
- self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
-
# TODO:
# Test that killing unix sockets returns EOPNOTSUPP.
self.exception = e
-# TODO: Take a tun fd as input, make this a utility class, and reuse at least
-# in forwarding_test.
-class TcpTest(SockDiagBaseTest):
-
- NOT_YET_ACCEPTED = -1
-
- def setUp(self):
- super(TcpTest, self).setUp()
- self.sock_diag = sock_diag.SockDiag()
- self.netid = random.choice(self.tuns.keys())
-
- def OpenListenSocket(self, version):
- self.port = packets.RandomPort()
- family = {4: AF_INET, 6: AF_INET6}[version]
- address = {4: "0.0.0.0", 6: "::"}[version]
- s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
- s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
- s.bind((address, self.port))
- # We haven't configured inbound iptables marking, so bind explicitly.
- self.SelectInterface(s, self.netid, "mark")
- s.listen(100)
- return s
-
- def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
- pkt = super(TcpTest, self)._ReceiveAndExpectResponse(netid, packet,
- reply, msg)
- self.last_packet = pkt
- return pkt
-
- def ReceivePacketOn(self, netid, packet):
- super(TcpTest, self).ReceivePacketOn(netid, packet)
- self.last_packet = packet
-
- def RstPacket(self):
- return packets.RST(self.version, self.myaddr, self.remoteaddr,
- self.last_packet)
-
- def IncomingConnection(self, version, end_state, netid):
- if version == 5:
- mapped = True
- socket_version = 6
- version = 4
- else:
- socket_version = version
- mapped = False
-
- self.version = version
- self.s = self.OpenListenSocket(socket_version)
- self.end_state = end_state
-
- def MaybeMappedAddress(addr):
- return "::ffff:%s" % addr if mapped else addr
-
- remoteaddr = self.remoteaddr = MaybeMappedAddress(
- self.GetRemoteAddress(version))
- myaddr = self.myaddr = MaybeMappedAddress(
- self.MyAddress(version, netid))
+class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
- if end_state == sock_diag.TCP_LISTEN:
- return
-
- desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
- synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
- msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
- reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
- if end_state == sock_diag.TCP_SYN_RECV:
- return
-
- establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
- self.ReceivePacketOn(netid, establishing_ack)
-
- if end_state == self.NOT_YET_ACCEPTED:
- return
-
- self.accepted, _ = self.s.accept()
- net_test.DisableLinger(self.accepted)
+ def testIpv4MappedSynRecvSocket(self):
+ """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.
- if end_state == sock_diag.TCP_ESTABLISHED:
- return
+ Relevant kernel commits:
+ android-3.4:
+ 457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
+ """
+ netid = random.choice(self.tuns.keys())
+ self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)
+ sock_id = self.sock_diag._EmptyInetDiagSockId()
+ sock_id.sport = self.port
+ states = 1 << tcp_test.TCP_SYN_RECV
+ req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
+ children = self.sock_diag.Dump(req, NO_BYTECODE)
- desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
- payload=net_test.UDP_PAYLOAD)
- self.accepted.send(net_test.UDP_PAYLOAD)
- self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
+ self.assertTrue(children)
+ for child, unused_args in children:
+ self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
+ self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
+ child.id.dst)
+ self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
+ child.id.src)
- desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
- fin = packets._GetIpLayer(version)(str(fin))
- ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
- msg = "Received %s, expected to see reply %s" % (desc, ack_desc)
- # TODO: Why can't we use this?
- # self._ReceiveAndExpectResponse(netid, fin, ack, msg)
- self.ReceivePacketOn(netid, fin)
- time.sleep(0.1)
- self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
- if end_state == sock_diag.TCP_CLOSE_WAIT:
- return
+@unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
+class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
- raise ValueError("Invalid TCP state %d specified" % end_state)
+ def setUp(self):
+ super(SockDestroyTcpTest, self).setUp()
+ self.netid = random.choice(self.tuns.keys())
def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
"""Closes the socket and checks whether a RST is sent or not."""
sock.close()
def CheckTcpReset(self, state, statename):
- for version in [4, 6]:
+ for version in [4, 5, 6]:
msg = "Closing incoming IPv%d %s socket" % (version, statename)
self.IncomingConnection(version, state, self.netid)
self.CheckRstOnClose(self.s, None, False, msg)
- if state != sock_diag.TCP_LISTEN:
+ if state != tcp_test.TCP_LISTEN:
msg = "Closing accepted IPv%d %s socket" % (version, statename)
self.CheckRstOnClose(self.accepted, None, True, msg)
- @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
def testTcpResets(self):
"""Checks that closing sockets in appropriate states sends a RST."""
- self.CheckTcpReset(sock_diag.TCP_LISTEN, "TCP_LISTEN")
- self.CheckTcpReset(sock_diag.TCP_ESTABLISHED, "TCP_ESTABLISHED")
- self.CheckTcpReset(sock_diag.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
+ self.CheckTcpReset(tcp_test.TCP_LISTEN, "TCP_LISTEN")
+ self.CheckTcpReset(tcp_test.TCP_ESTABLISHED, "TCP_ESTABLISHED")
+ self.CheckTcpReset(tcp_test.TCP_CLOSE_WAIT, "TCP_CLOSE_WAIT")
def FindChildSockets(self, s):
"""Finds the SYN_RECV child sockets of a given listening socket."""
d = self.sock_diag.FindSockDiagFromFd(self.s)
req = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
- req.states = 1 << sock_diag.TCP_SYN_RECV | 1 << sock_diag.TCP_ESTABLISHED
+ req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
req.id.cookie = "\x00" * 8
children = self.sock_diag.Dump(req, NO_BYTECODE)
return [self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
for d, _ in children]
def CheckChildSocket(self, state, statename, parent_first):
- for version in [4, 6]:
+ for version in [4, 5, 6]:
self.IncomingConnection(version, state, self.netid)
d = self.sock_diag.FindSockDiagFromFd(self.s)
children = self.FindChildSockets(self.s)
self.assertEquals(1, len(children))
- is_established = (state == self.NOT_YET_ACCEPTED)
+ is_established = (state == tcp_test.NOT_YET_ACCEPTED)
# The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
# regular TCP hash tables, and inet_diag_find_one_icsk can find them.
CloseParent(False)
self.s.close()
- @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
def testChildSockets(self):
- self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", False)
- self.CheckChildSocket(sock_diag.TCP_SYN_RECV, "TCP_SYN_RECV", True)
- self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", False)
- self.CheckChildSocket(self.NOT_YET_ACCEPTED, "not yet accepted", True)
+ self.CheckChildSocket(tcp_test.TCP_SYN_RECV, "TCP_SYN_RECV", False)
+ self.CheckChildSocket(tcp_test.TCP_SYN_RECV, "TCP_SYN_RECV", True)
+ self.CheckChildSocket(tcp_test.NOT_YET_ACCEPTED, "not yet accepted", False)
+ self.CheckChildSocket(tcp_test.NOT_YET_ACCEPTED, "not yet accepted", True)
def CloseDuringBlockingCall(self, sock, call, expected_errno):
thread = SocketExceptionThread(sock, call)
self.assertEqual(expected_errno, thread.exception.errno)
self.assertSocketClosed(sock)
- @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
def testAcceptInterrupted(self):
"""Tests that accept() is interrupted by SOCK_DESTROY."""
for version in [4, 5, 6]:
- self.IncomingConnection(version, sock_diag.TCP_LISTEN, self.netid)
+ self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
self.CloseDuringBlockingCall(self.s, lambda sock: sock.accept(), EINVAL)
self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
self.assertRaisesErrno(EINVAL, self.s.accept)
- @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
def testReadInterrupted(self):
"""Tests that read() is interrupted by SOCK_DESTROY."""
for version in [4, 5, 6]:
- self.IncomingConnection(version, sock_diag.TCP_ESTABLISHED, self.netid)
+ self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
ECONNABORTED)
self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
- @unittest.skipUnless(HAVE_SOCK_DESTROY, "SOCK_DESTROY not supported")
def testConnectInterrupted(self):
"""Tests that connect() is interrupted by SOCK_DESTROY."""
for version in [4, 5, 6]:
msg = "SOCK_DESTROY of socket in connect, expected no RST"
self.ExpectNoPacketsOn(self.netid, msg)
- def testIpv4MappedSynRecvSocket(self):
- """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.
-
- Relevant kernel commits:
- android-3.4:
- 457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
- """
- self.IncomingConnection(5, sock_diag.TCP_SYN_RECV, self.netid)
- sock_id = self.sock_diag._EmptyInetDiagSockId()
- sock_id.sport = self.port
- states = 1 << sock_diag.TCP_SYN_RECV
- req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
- children = self.sock_diag.Dump(req, NO_BYTECODE)
-
- self.assertTrue(children)
- for child, unused_args in children:
- self.assertEqual(sock_diag.TCP_SYN_RECV, child.state)
- self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
- child.id.dst)
- self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
- child.id.src)
-
if __name__ == "__main__":
unittest.main()
--- /dev/null
+#!/usr/bin/python
+#
+# Copyright 2015 The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import time
+from socket import * # pylint: disable=wildcard-import
+
+import net_test
+import multinetwork_base
+import packets
+
+# TCP states. See include/net/tcp_states.h.
+TCP_ESTABLISHED = 1
+TCP_SYN_SENT = 2
+TCP_SYN_RECV = 3
+TCP_FIN_WAIT1 = 4
+TCP_FIN_WAIT2 = 5
+TCP_TIME_WAIT = 6
+TCP_CLOSE = 7
+TCP_CLOSE_WAIT = 8
+TCP_LAST_ACK = 9
+TCP_LISTEN = 10
+TCP_CLOSING = 11
+TCP_NEW_SYN_RECV = 12
+
+NOT_YET_ACCEPTED = -1
+
+
+class TcpBaseTest(multinetwork_base.MultiNetworkBaseTest):
+
+ def tearDown(self):
+ if hasattr(self, "s"):
+ self.s.close()
+ super(TcpBaseTest, self).tearDown()
+
+ def OpenListenSocket(self, version, netid):
+ self.port = packets.RandomPort()
+ family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
+ address = {4: "0.0.0.0", 5: "::", 6: "::"}[version]
+ s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
+ s.setsockopt(SOL_SOCKET, SO_REUSEADDR, 1)
+ s.bind((address, self.port))
+ # We haven't configured inbound iptables marking, so bind explicitly.
+ self.SelectInterface(s, netid, "mark")
+ s.listen(100)
+ return s
+
+ def _ReceiveAndExpectResponse(self, netid, packet, reply, msg):
+ pkt = super(TcpBaseTest, self)._ReceiveAndExpectResponse(netid, packet,
+ reply, msg)
+ self.last_packet = pkt
+ return pkt
+
+ def ReceivePacketOn(self, netid, packet):
+ super(TcpBaseTest, self).ReceivePacketOn(netid, packet)
+ self.last_packet = packet
+
+ def RstPacket(self):
+ return packets.RST(self.version, self.myaddr, self.remoteaddr,
+ self.last_packet)
+
+ def IncomingConnection(self, version, end_state, netid):
+ self.s = self.OpenListenSocket(version, netid)
+ self.end_state = end_state
+
+ remoteaddr = self.remoteaddr = self.GetRemoteAddress(version)
+ myaddr = self.myaddr = self.MyAddress(version, netid)
+
+ if version == 5: version = 4
+ self.version = version
+
+ if end_state == TCP_LISTEN:
+ return
+
+ desc, syn = packets.SYN(self.port, version, remoteaddr, myaddr)
+ synack_desc, synack = packets.SYNACK(version, myaddr, remoteaddr, syn)
+ msg = "Received %s, expected to see reply %s" % (desc, synack_desc)
+ reply = self._ReceiveAndExpectResponse(netid, syn, synack, msg)
+ if end_state == TCP_SYN_RECV:
+ return
+
+ establishing_ack = packets.ACK(version, remoteaddr, myaddr, reply)[1]
+ self.ReceivePacketOn(netid, establishing_ack)
+
+ if end_state == NOT_YET_ACCEPTED:
+ return
+
+ self.accepted, _ = self.s.accept()
+ net_test.DisableLinger(self.accepted)
+
+ if end_state == TCP_ESTABLISHED:
+ return
+
+ desc, data = packets.ACK(version, myaddr, remoteaddr, establishing_ack,
+ payload=net_test.UDP_PAYLOAD)
+ self.accepted.send(net_test.UDP_PAYLOAD)
+ self.ExpectPacketOn(netid, msg + ": expecting %s" % desc, data)
+
+ desc, fin = packets.FIN(version, remoteaddr, myaddr, data)
+ fin = packets._GetIpLayer(version)(str(fin))
+ ack_desc, ack = packets.ACK(version, myaddr, remoteaddr, fin)
+ msg = "Received %s, expected to see reply %s" % (desc, ack_desc)
+
+ # TODO: Why can't we use this?
+ # self._ReceiveAndExpectResponse(netid, fin, ack, msg)
+ self.ReceivePacketOn(netid, fin)
+ time.sleep(0.1)
+ self.ExpectPacketOn(netid, msg + ": expecting %s" % ack_desc, ack)
+ if end_state == TCP_CLOSE_WAIT:
+ return
+
+ raise ValueError("Invalid TCP state %d specified" % end_state)