3 # Copyright 2015 The Android Open Source Project
5 # Licensed under the Apache License, Version 2.0 (the "License");
6 # you may not use this file except in compliance with the License.
7 # You may obtain a copy of the License at
9 # http://www.apache.org/licenses/LICENSE-2.0
11 # Unless required by applicable law or agreed to in writing, software
12 # distributed under the License is distributed on an "AS IS" BASIS,
13 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 # See the License for the specific language governing permissions and
15 # limitations under the License.
17 # pylint: disable=g-bad-todo,g-bad-file-header,wildcard-import
18 from errno import * # pylint: disable=wildcard-import
22 from socket import * # pylint: disable=wildcard-import
27 import multinetwork_base
38 class SockDiagBaseTest(multinetwork_base.MultiNetworkBaseTest):
41 def _CreateLotsOfSockets():
42 # Dict mapping (addr, sport, dport) tuples to socketpairs.
44 for _ in xrange(NUM_SOCKETS):
45 family, addr = random.choice([
46 (AF_INET, "127.0.0.1"),
48 (AF_INET6, "::ffff:127.0.0.1")])
49 socketpair = net_test.CreateSocketPair(family, SOCK_STREAM, addr)
50 sport, dport = (socketpair[0].getsockname()[1],
51 socketpair[1].getsockname()[1])
52 socketpairs[(addr, sport, dport)] = socketpair
def assertSocketClosed(self, sock):
  """Asserts that `sock` has no peer: getpeername() must fail with ENOTCONN."""
  expected_errno = ENOTCONN
  self.assertRaisesErrno(expected_errno, sock.getpeername)
def assertSocketConnected(self, sock):
  """Asserts that `sock` still has a peer.

  getpeername() raises if the socket is not connected, which fails the test.
  The returned address itself is not inspected.
  """
  _ = sock.getpeername()
def assertSocketsClosed(self, socketpair):
  """Asserts that every socket in `socketpair` is closed.

  Each member is checked in iteration order via self.assertSocketClosed.
  """
  for member in socketpair:
    self.assertSocketClosed(member)
66 super(SockDiagBaseTest, self).setUp()
67 self.sock_diag = sock_diag.SockDiag()
71 for socketpair in self.socketpairs.values():
74 super(SockDiagBaseTest, self).tearDown()
77 class SockDiagTest(SockDiagBaseTest):
79 def assertSockDiagMatchesSocket(self, s, diag_msg):
80 family = s.getsockopt(net_test.SOL_SOCKET, net_test.SO_DOMAIN)
81 self.assertEqual(diag_msg.family, family)
83 src, sport = s.getsockname()[0:2]
84 self.assertEqual(diag_msg.id.src, self.sock_diag.PaddedAddress(src))
85 self.assertEqual(diag_msg.id.sport, sport)
87 if self.sock_diag.GetDestinationAddress(diag_msg) not in ["0.0.0.0", "::"]:
88 dst, dport = s.getpeername()[0:2]
89 self.assertEqual(diag_msg.id.dst, self.sock_diag.PaddedAddress(dst))
90 self.assertEqual(diag_msg.id.dport, dport)
92 self.assertRaisesErrno(ENOTCONN, s.getpeername)
94 def testFindsMappedSockets(self):
95 """Tests that inet_diag_find_one_icsk can find mapped sockets.
97 Relevant kernel commits:
99 f77e059 net: diag: support v4mapped sockets in inet_diag_find_one_icsk()
101 socketpair = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
103 for sock in socketpair:
104 diag_msg = self.sock_diag.FindSockDiagFromFd(sock)
105 diag_req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
106 self.sock_diag.GetSockDiag(diag_req)
109 def testFindsAllMySockets(self):
110 """Tests that basic socket dumping works.
114 ab4a727 net: inet_diag: zero out uninitialized idiag_{src,dst} fields
116 3eb409b net: inet_diag: zero out uninitialized idiag_{src,dst} fields
118 self.socketpairs = self._CreateLotsOfSockets()
119 sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
120 self.assertGreaterEqual(len(sockets), NUM_SOCKETS)
122 # Find the cookies for all of our sockets.
124 for diag_msg, unused_attrs in sockets:
125 addr = self.sock_diag.GetSourceAddress(diag_msg)
126 sport = diag_msg.id.sport
127 dport = diag_msg.id.dport
128 if (addr, sport, dport) in self.socketpairs:
129 cookies[(addr, sport, dport)] = diag_msg.id.cookie
130 elif (addr, dport, sport) in self.socketpairs:
131 cookies[(addr, sport, dport)] = diag_msg.id.cookie
133 # Did we find all the cookies?
134 self.assertEquals(2 * NUM_SOCKETS, len(cookies))
136 socketpairs = self.socketpairs.values()
137 random.shuffle(socketpairs)
138 for socketpair in socketpairs:
139 for sock in socketpair:
140 # Check that we can find a diag_msg by scanning a dump.
141 self.assertSockDiagMatchesSocket(
143 self.sock_diag.FindSockDiagFromFd(sock))
144 cookie = self.sock_diag.FindSockDiagFromFd(sock).id.cookie
146 # Check that we can find a diag_msg once we know the cookie.
147 req = self.sock_diag.DiagReqFromSocket(sock)
148 req.id.cookie = cookie
149 diag_msg = self.sock_diag.GetSockDiag(req)
150 req.states = 1 << diag_msg.state
151 self.assertSockDiagMatchesSocket(sock, diag_msg)
153 def testBytecodeCompilation(self):
154 # pylint: disable=bad-whitespace
156 (sock_diag.INET_DIAG_BC_S_GE, 1, 8, 0), # 0
157 (sock_diag.INET_DIAG_BC_D_LE, 1, 7, 0xffff), # 8
158 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::1", 128, -1)), # 16
159 (sock_diag.INET_DIAG_BC_JMP, 1, 3, None), # 44
160 (sock_diag.INET_DIAG_BC_S_COND, 2, 4, ("127.0.0.1", 32, -1)), # 48
161 (sock_diag.INET_DIAG_BC_D_LE, 1, 3, 0x6665), # not used # 64
162 (sock_diag.INET_DIAG_BC_NOP, 1, 1, None), # 72
166 # pylint: enable=bad-whitespace
167 bytecode = self.sock_diag.PackBytecode(instructions)
171 "071c20000a800000ffffffff00000000000000000000000000000001"
173 "0718200002200000ffffffff7f000001"
177 self.assertMultiLineEqual(expected, bytecode.encode("hex"))
178 self.assertEquals(76, len(bytecode))
179 self.socketpairs = self._CreateLotsOfSockets()
180 filteredsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
181 allsockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, NO_BYTECODE)
182 self.assertItemsEqual(allsockets, filteredsockets)
184 # Pick a few sockets in hash table order, and check that the bytecode we
185 # compiled selects them properly.
186 for socketpair in self.socketpairs.values()[:20]:
188 diag_msg = self.sock_diag.FindSockDiagFromFd(s)
190 (sock_diag.INET_DIAG_BC_S_GE, 1, 5, diag_msg.id.sport),
191 (sock_diag.INET_DIAG_BC_S_LE, 1, 4, diag_msg.id.sport),
192 (sock_diag.INET_DIAG_BC_D_GE, 1, 3, diag_msg.id.dport),
193 (sock_diag.INET_DIAG_BC_D_LE, 1, 2, diag_msg.id.dport),
195 bytecode = self.sock_diag.PackBytecode(instructions)
196 self.assertEquals(32, len(bytecode))
197 sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode)
198 self.assertEquals(1, len(sockets))
200 # TODO: why doesn't comparing the cstructs work?
201 self.assertEquals(diag_msg.Pack(), sockets[0][0].Pack())
203 def testCrossFamilyBytecode(self):
204 """Checks for a cross-family bug in inet_diag_hostcond matching.
206 Relevant kernel commits:
208 f67caec inet_diag: avoid unsafe and nonsensical prefix matches in inet_diag_bc_run()
210 # TODO: this is only here because the test fails if there are any open
211 # sockets other than the ones it creates itself. Make the bytecode more
212 # specific and remove it.
213 self.assertFalse(self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, ""))
215 unused_pair4 = net_test.CreateSocketPair(AF_INET, SOCK_STREAM, "127.0.0.1")
216 unused_pair6 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM, "::1")
218 bytecode4 = self.sock_diag.PackBytecode([
219 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("0.0.0.0", 0, -1))])
220 bytecode6 = self.sock_diag.PackBytecode([
221 (sock_diag.INET_DIAG_BC_S_COND, 1, 2, ("::", 0, -1))])
223 # IPv4/v6 filters must never match IPv6/IPv4 sockets...
224 v4sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode4)
225 self.assertTrue(v4sockets)
226 self.assertTrue(all(d.family == AF_INET for d, _ in v4sockets))
228 v6sockets = self.sock_diag.DumpAllInetSockets(IPPROTO_TCP, bytecode6)
229 self.assertTrue(v6sockets)
230 self.assertTrue(all(d.family == AF_INET6 for d, _ in v6sockets))
232 # Except for mapped addresses, which match both IPv4 and IPv6.
233 pair5 = net_test.CreateSocketPair(AF_INET6, SOCK_STREAM,
235 diag_msgs = [self.sock_diag.FindSockDiagFromFd(s) for s in pair5]
236 v4sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
238 v6sockets = [d for d, _ in self.sock_diag.DumpAllInetSockets(IPPROTO_TCP,
240 self.assertTrue(all(d in v4sockets for d in diag_msgs))
241 self.assertTrue(all(d in v6sockets for d in diag_msgs))
243 def testPortComparisonValidation(self):
244 """Checks for a bug in validating port comparison bytecode.
246 Relevant kernel commits:
248 5e1f542 inet_diag: validate port comparison byte code to prevent unsafe reads
250 bytecode = sock_diag.InetDiagBcOp((sock_diag.INET_DIAG_BC_D_GE, 4, 8))
251 self.assertRaisesErrno(
253 self.sock_diag.DumpAllInetSockets, IPPROTO_TCP, bytecode.Pack())
255 def testNonSockDiagCommand(self):
257 sock_id = self.sock_diag._EmptyInetDiagSockId()
258 req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, 0xffffffff,
260 self.sock_diag._Dump(code, req, sock_diag.InetDiagMsg, "")
262 op = sock_diag.SOCK_DIAG_BY_FAMILY
263 DiagDump(op) # No errors? Good.
264 self.assertRaisesErrno(EINVAL, DiagDump, op + 17)
267 class SockDestroyTest(SockDiagBaseTest):
268 """Tests that SOCK_DESTROY works correctly.
270 Relevant kernel commits:
272 b613f56 net: diag: split inet_diag_dump_one_icsk into two
273 64be0ae net: diag: Add the ability to destroy a socket.
274 6eb5d2e net: diag: Support SOCK_DESTROY for inet sockets.
275 c1e64e2 net: diag: Support destroying TCP sockets.
276 2010b93 net: tcp: deal with listen sockets properly in tcp_abort.
279 d48ec88 net: diag: split inet_diag_dump_one_icsk into two
280 2438189 net: diag: Add the ability to destroy a socket.
281 7a2ddbc net: diag: Support SOCK_DESTROY for inet sockets.
282 44047b2 net: diag: Support destroying TCP sockets.
283 200dae7 net: tcp: deal with listen sockets properly in tcp_abort.
286 9eaff90 net: diag: split inet_diag_dump_one_icsk into two
287 d60326c net: diag: Add the ability to destroy a socket.
288 3d4ce85 net: diag: Support SOCK_DESTROY for inet sockets.
289 529dfc6 net: diag: Support destroying TCP sockets.
290 9c712fe net: tcp: deal with listen sockets properly in tcp_abort.
293 100263d net: diag: split inet_diag_dump_one_icsk into two
294 194c5f3 net: diag: Add the ability to destroy a socket.
295 8387ea2 net: diag: Support SOCK_DESTROY for inet sockets.
296 b80585a net: diag: Support destroying TCP sockets.
297 476c6ce net: tcp: deal with listen sockets properly in tcp_abort.
300 def testClosesSockets(self):
301 self.socketpairs = self._CreateLotsOfSockets()
302 for _, socketpair in self.socketpairs.iteritems():
303 # Close one of the sockets.
304 # This will send a RST that will close the other side as well.
305 s = random.choice(socketpair)
306 if random.randrange(0, 2) == 1:
307 self.sock_diag.CloseSocketFromFd(s)
309 diag_msg = self.sock_diag.FindSockDiagFromFd(s)
311 # Get the cookie wrong and ensure that we get an error and the socket
313 real_cookie = diag_msg.id.cookie
314 diag_msg.id.cookie = os.urandom(len(real_cookie))
315 req = self.sock_diag.DiagReqFromDiagMsg(diag_msg, IPPROTO_TCP)
316 self.assertRaisesErrno(ENOENT, self.sock_diag.CloseSocket, req)
317 self.assertSocketConnected(s)
319 # Now close it with the correct cookie.
320 req.id.cookie = real_cookie
321 self.sock_diag.CloseSocket(req)
323 # Check that both sockets in the pair are closed.
324 self.assertSocketsClosed(socketpair)
326 def testNonTcpSockets(self):
327 s = socket(AF_INET6, SOCK_DGRAM, 0)
328 s.connect(("::1", 53))
329 self.sock_diag.FindSockDiagFromFd(s) # No exceptions? Good.
330 self.assertRaisesErrno(EOPNOTSUPP, self.sock_diag.CloseSocketFromFd, s)
# TODO: Add a test that destroying (SOCK_DESTROY) unix sockets returns EOPNOTSUPP.
336 class SocketExceptionThread(threading.Thread):
338 def __init__(self, sock, operation):
339 self.exception = None
340 super(SocketExceptionThread, self).__init__()
343 self.operation = operation
347 self.operation(self.sock)
352 class SockDiagTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
354 def testIpv4MappedSynRecvSocket(self):
355 """Tests for the absence of a bug with AF_INET6 TCP SYN-RECV sockets.
357 Relevant kernel commits:
359 457a04b inet_diag: fix oops for IPv4 AF_INET6 TCP SYN-RECV state
361 netid = random.choice(self.tuns.keys())
362 self.IncomingConnection(5, tcp_test.TCP_SYN_RECV, netid)
363 sock_id = self.sock_diag._EmptyInetDiagSockId()
364 sock_id.sport = self.port
365 states = 1 << tcp_test.TCP_SYN_RECV
366 req = sock_diag.InetDiagReqV2((AF_INET6, IPPROTO_TCP, 0, states, sock_id))
367 children = self.sock_diag.Dump(req, NO_BYTECODE)
369 self.assertTrue(children)
370 for child, unused_args in children:
371 self.assertEqual(tcp_test.TCP_SYN_RECV, child.state)
372 self.assertEqual(self.sock_diag.PaddedAddress(self.remoteaddr),
374 self.assertEqual(self.sock_diag.PaddedAddress(self.myaddr),
378 class SockDestroyTcpTest(tcp_test.TcpBaseTest, SockDiagBaseTest):
381 super(SockDestroyTcpTest, self).setUp()
382 self.netid = random.choice(self.tuns.keys())
384 def CheckRstOnClose(self, sock, req, expect_reset, msg, do_close=True):
385 """Closes the socket and checks whether a RST is sent or not."""
387 self.assertIsNone(req, "Must specify sock or req, not both")
388 self.sock_diag.CloseSocketFromFd(sock)
389 self.assertRaisesErrno(EINVAL, sock.accept)
391 self.assertIsNone(sock, "Must specify sock or req, not both")
392 self.sock_diag.CloseSocket(req)
395 desc, rst = self.RstPacket()
396 msg = "%s: expecting %s: " % (msg, desc)
397 self.ExpectPacketOn(self.netid, msg, rst)
400 self.ExpectNoPacketsOn(self.netid, msg)
402 if sock is not None and do_close:
def CheckTcpReset(self, state, statename):
  """Checks RST behaviour when destroying sockets set up in a given TCP state.

  For each test IP version (4, 5, 6), sets up an incoming connection in
  `state`. It then destroys self.s and expects no RST. Unless `state` is
  TCP_LISTEN, it also destroys self.accepted and expects a RST.

  Args:
    state: the tcp_test state constant to set up.
    statename: human-readable name of `state`, used in failure messages.
  """
  for version in [4, 5, 6]:
    self.IncomingConnection(version, state, self.netid)
    incoming_msg = "Closing incoming IPv%d %s socket" % (version, statename)
    self.CheckRstOnClose(self.s, None, False, incoming_msg)
    if state == tcp_test.TCP_LISTEN:
      # Presumably no connection is accepted in the listen state, so there is
      # no accepted socket to check.
      continue
    accepted_msg = "Closing accepted IPv%d %s socket" % (version, statename)
    self.CheckRstOnClose(self.accepted, None, True, accepted_msg)
def testTcpResets(self):
  """Destroys sockets in LISTEN, ESTABLISHED and CLOSE_WAIT and checks for RSTs."""
  # Look each state constant up by name so the name used in messages always
  # matches the state being tested.
  for statename in ("TCP_LISTEN", "TCP_ESTABLISHED", "TCP_CLOSE_WAIT"):
    self.CheckTcpReset(getattr(tcp_test, statename), statename)
def FindChildSockets(self, s):
  """Finds the child sockets of a given listening socket.

  Looks up the sock_diag entry for `s`. It then dumps TCP sockets in the
  SYN_RECV or ESTABLISHED state using that entry's socket ID with the cookie
  zeroed.

  Args:
    s: the listening socket whose children should be found.

  Returns:
    A list of diag requests, one per child found, suitable for passing to
    GetSockDiag or CloseSocket.
  """
  # Use the socket we were given; previously this always used self.s and
  # silently ignored the `s` argument.
  parent_msg = self.sock_diag.FindSockDiagFromFd(s)
  req = self.sock_diag.DiagReqFromDiagMsg(parent_msg, IPPROTO_TCP)
  req.states = 1 << tcp_test.TCP_SYN_RECV | 1 << tcp_test.TCP_ESTABLISHED
  # NOTE(review): presumably a zero cookie means "don't match on cookie", so
  # the dump is not restricted to the listener itself -- confirm in sock_diag.
  req.id.cookie = "\x00" * 8
  children = self.sock_diag.Dump(req, NO_BYTECODE)
  return [self.sock_diag.DiagReqFromDiagMsg(child_msg, IPPROTO_TCP)
          for child_msg, _ in children]
430 def CheckChildSocket(self, version, statename, parent_first):
431 state = getattr(tcp_test, statename)
433 self.IncomingConnection(version, state, self.netid)
435 d = self.sock_diag.FindSockDiagFromFd(self.s)
436 parent = self.sock_diag.DiagReqFromDiagMsg(d, IPPROTO_TCP)
437 children = self.FindChildSockets(self.s)
438 self.assertEquals(1, len(children))
440 is_established = (state == tcp_test.TCP_NOT_YET_ACCEPTED)
442 # The new TCP listener code in 4.4 makes SYN_RECV sockets live in the
443 # regular TCP hash tables, and inet_diag_find_one_icsk can find them.
444 # Before 4.4, we can see those sockets in dumps, but we can't fetch
446 can_close_children = is_established or net_test.LINUX_VERSION >= (4, 4)
448 for child in children:
449 if can_close_children:
450 self.sock_diag.GetSockDiag(child) # No errors? Good, child found.
452 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
454 def CloseParent(expect_reset):
455 msg = "Closing parent IPv%d %s socket %s child" % (
456 version, statename, "before" if parent_first else "after")
457 self.CheckRstOnClose(self.s, None, expect_reset, msg)
458 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, parent)
460 def CheckChildrenClosed():
461 for child in children:
462 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
465 for child in children:
466 msg = "Closing child IPv%d %s socket %s parent" % (
467 version, statename, "after" if parent_first else "before")
468 self.sock_diag.GetSockDiag(child)
469 self.CheckRstOnClose(None, child, is_established, msg)
470 self.assertRaisesErrno(ENOENT, self.sock_diag.GetSockDiag, child)
471 CheckChildrenClosed()
474 # Closing the parent will close child sockets, which will send a RST,
475 # iff they are already established.
476 CloseParent(is_established)
478 CheckChildrenClosed()
479 elif can_close_children:
481 CheckChildrenClosed()
484 if can_close_children:
489 def testChildSockets(self):
490 for version in [4, 5, 6]:
491 self.CheckChildSocket(version, "TCP_SYN_RECV", False)
492 self.CheckChildSocket(version, "TCP_SYN_RECV", True)
493 self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", False)
494 self.CheckChildSocket(version, "TCP_NOT_YET_ACCEPTED", True)
496 def CloseDuringBlockingCall(self, sock, call, expected_errno):
497 thread = SocketExceptionThread(sock, call)
500 self.sock_diag.CloseSocketFromFd(sock)
502 self.assertFalse(thread.is_alive())
503 self.assertIsNotNone(thread.exception)
504 self.assertTrue(isinstance(thread.exception, IOError),
505 "Expected IOError, got %s" % thread.exception)
506 self.assertEqual(expected_errno, thread.exception.errno)
507 self.assertSocketClosed(sock)
def testAcceptInterrupted(self):
  """Checks that SOCK_DESTROY wakes up a thread blocked in accept().

  The blocked accept() must fail with EINVAL. Afterwards send() on the
  destroyed listener fails with ECONNABORTED and accept() with EINVAL.
  """
  def BlockingAccept(sock):
    return sock.accept()

  for version in [4, 5, 6]:
    self.IncomingConnection(version, tcp_test.TCP_LISTEN, self.netid)
    self.CloseDuringBlockingCall(self.s, BlockingAccept, EINVAL)
    self.assertRaisesErrno(ECONNABORTED, self.s.send, "foo")
    self.assertRaisesErrno(EINVAL, self.s.accept)
517 def testReadInterrupted(self):
518 """Tests that read() is interrupted by SOCK_DESTROY."""
519 for version in [4, 5, 6]:
520 self.IncomingConnection(version, tcp_test.TCP_ESTABLISHED, self.netid)
521 self.CloseDuringBlockingCall(self.accepted, lambda sock: sock.recv(4096),
523 self.assertRaisesErrno(EPIPE, self.accepted.send, "foo")
525 def testConnectInterrupted(self):
526 """Tests that connect() is interrupted by SOCK_DESTROY."""
527 for version in [4, 5, 6]:
528 family = {4: AF_INET, 5: AF_INET6, 6: AF_INET6}[version]
529 s = net_test.Socket(family, SOCK_STREAM, IPPROTO_TCP)
530 self.SelectInterface(s, self.netid, "mark")
532 remoteaddr = "::ffff:" + self.GetRemoteAddress(4)
535 remoteaddr = self.GetRemoteAddress(version)
537 _, sport = s.getsockname()[:2]
538 self.CloseDuringBlockingCall(
539 s, lambda sock: sock.connect((remoteaddr, 53)), ECONNABORTED)
540 desc, syn = packets.SYN(53, version, self.MyAddress(version, self.netid),
541 remoteaddr, sport=sport, seq=None)
542 self.ExpectPacketOn(self.netid, desc, syn)
543 msg = "SOCK_DESTROY of socket in connect, expected no RST"
544 self.ExpectNoPacketsOn(self.netid, msg)
547 if __name__ == "__main__":