OSDN Git Service

Merge tag 'for-upstream' of https://gitlab.com/bonzini/qemu into staging
[qmiga/qemu.git] / python / tests / protocol.py
1 import asyncio
2 from contextlib import contextmanager
3 import os
4 import socket
5 from tempfile import TemporaryDirectory
6
7 import avocado
8
9 from qemu.qmp import ConnectError, Runstate
10 from qemu.qmp.protocol import AsyncProtocol, StateError
11 from qemu.qmp.util import asyncio_run, create_task
12
13
class NullProtocol(AsyncProtocol[None]):
    """
    NullProtocol is a test mockup of an AsyncProtocol implementation.

    It adds a fake_session instance variable that enables a code path
    that bypasses the actual connection logic, but still allows the
    reader/writers to start.

    Because the message type is defined as None, an asyncio.Event named
    'trigger_input' is created that keeps the reader from yielding None
    in a tight loop; this event can be poked to simulate an incoming
    message.

    For testing symmetry with do_recv, an interface is added to "send" a
    Null message.

    For testing purposes, a "simulate_disconnect" method is also
    added which allows us to trigger a bottom half disconnect without
    injecting any real errors into the reader/writer loops; in essence
    it performs exactly half of what disconnect() normally does.
    """
    def __init__(self, name=None):
        # When True, the _do_* connection hooks below skip real socket I/O.
        self.fake_session = False
        # Annotation only; the Event itself is created in _establish_session().
        self.trigger_input: asyncio.Event
        super().__init__(name)

    async def _establish_session(self):
        # Create a fresh input trigger before handing off to the base
        # class; presumably so _do_recv() always finds it -- confirm.
        self.trigger_input = asyncio.Event()
        await super()._establish_session()

    async def _do_start_server(self, address, ssl=None):
        if self.fake_session:
            # Fake the server-start phase: create the accept event, enter
            # CONNECTING, and yield once to the event loop.
            self._accepted = asyncio.Event()
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)
        else:
            await super()._do_start_server(address, ssl)

    async def _do_accept(self):
        if self.fake_session:
            # There is no real peer to wait for; just drop the accept event.
            self._accepted = None
        else:
            await super()._do_accept()

    async def _do_connect(self, address, ssl=None):
        if self.fake_session:
            # Fake the connect phase: enter CONNECTING and yield once.
            self._set_state(Runstate.CONNECTING)
            await asyncio.sleep(0)
        else:
            await super()._do_connect(address, ssl)

    async def _do_recv(self) -> None:
        # Block until a test sets trigger_input, then re-arm it so the
        # next receive blocks again.
        await self.trigger_input.wait()
        self.trigger_input.clear()

    def _do_send(self, msg: None) -> None:
        # Nothing to transmit; null messages are simply discarded.
        pass

    async def send_msg(self) -> None:
        """Put a null message on the outgoing queue."""
        await self._outgoing.put(None)

    async def simulate_disconnect(self) -> None:
        """
        Simulates a bottom-half disconnect.

        This method schedules a disconnection but does not wait for it
        to complete. This is used to put the loop into the DISCONNECTING
        state without fully quiescing it back to IDLE. This is normally
        something you cannot coax AsyncProtocol to do on purpose, but it
        will be similar to what happens with an unhandled Exception in
        the reader/writer.

        Under normal circumstances, the library design requires you to
        await on disconnect(), which awaits the disconnect task and
        returns bottom half errors as a pre-condition to allowing the
        loop to return back to IDLE.
        """
        self._schedule_disconnect()
92
93
class LineProtocol(AsyncProtocol[str]):
    """
    A newline-delimited text protocol used as a real peer in tests.

    Outgoing messages are encoded and terminated with a newline; every
    incoming line is decoded and also recorded in ``rx_history``.
    """
    def __init__(self, name=None):
        super().__init__(name)
        # Decoded incoming lines, in the order they were received.
        self.rx_history = []

    async def _do_recv(self) -> str:
        line = (await self._readline()).decode()
        self.rx_history.append(line)
        return line

    def _do_send(self, msg: str) -> None:
        assert self._writer is not None
        payload = msg.encode() + b'\n'
        self._writer.write(payload)

    async def send_msg(self, msg: str) -> None:
        """Put *msg* on the outgoing queue."""
        await self._outgoing.put(msg)
111
112
def run_as_task(coro, allow_cancellation=False):
    """
    Schedule *coro* as a task and return that task.

    When allow_cancellation is true, a CancelledError raised while the
    coroutine is awaited is swallowed, so awaiting a cancelled task
    returns None instead of raising. Otherwise the cancellation
    propagates unchanged.
    """
    async def _guarded():
        try:
            await coro
        except asyncio.CancelledError:
            if not allow_cancellation:
                raise

    return create_task(_guarded())
128
129
@contextmanager
def jammed_socket():
    """
    Opens up a random unused TCP port on localhost, then jams it.

    A listening socket is bound to an ephemeral port on 127.0.0.1 with a
    backlog of 1 and never accept()s. Two client connections are then
    made to it to fill its pending-connection queue, so further
    connection attempts to the yielded address are expected to stall
    rather than be refused.

    Yields:
        The (host, port) address of the jammed listening socket.

    Every socket created here, including one created just before a
    failure, is closed when the context exits.
    """
    socks = []

    try:
        listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Track the socket right away so it is closed even if
        # setsockopt(), bind() or listen() raises.
        socks.append(listener)
        listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        listener.bind(('127.0.0.1', 0))
        listener.listen(1)
        address = listener.getsockname()

        # It takes *two* un-accepted connections to start jamming the
        # socket; presumably because Linux admits backlog + 1 pending
        # connections before treating the queue as full.
        for _ in range(2):
            client = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            socks.append(client)  # tracked before connect() can fail
            client.connect(address)

        yield address

    finally:
        for sock in socks:
            sock.close()
158
159
class Smoke(avocado.Test):
    """Sanity checks on a freshly constructed, never-connected NullProtocol."""

    def setUp(self):
        self.proto = NullProtocol()

    def test__repr__(self):
        expected = "<NullProtocol runstate=IDLE>"
        self.assertEqual(repr(self.proto), expected)

    def testRunstate(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    def testDefaultName(self):
        self.assertEqual(self.proto.name, None)

    def testLogger(self):
        self.assertEqual(self.proto.logger.name, 'qemu.qmp.protocol')

    def testName(self):
        # A named instance should surface the name in its attribute,
        # its logger name and its repr.
        self.proto = NullProtocol('Steve')
        named = self.proto

        self.assertEqual(named.name, 'Steve')
        self.assertEqual(named.logger.name, 'qemu.qmp.protocol.Steve')
        self.assertEqual(
            repr(named),
            "<NullProtocol name='Steve' runstate=IDLE>",
        )
206
207
class TestBase(avocado.Test):
    """
    Common fixture for tests that drive a NullProtocol instance.

    Every test must begin and end with the protocol in the IDLE state.
    Coroutine tests decorated with async_test() additionally have
    _asyncSetUp() and _asyncTearDown() run around them.
    """

    def setUp(self):
        self.proto = NullProtocol(type(self).__name__)
        self.assertEqual(self.proto.runstate, Runstate.IDLE)
        # Task created by _watch_runstates(); awaited in _asyncTearDown().
        self.runstate_watcher = None

    def tearDown(self):
        self.assertEqual(self.proto.runstate, Runstate.IDLE)

    async def _asyncSetUp(self):
        pass

    async def _asyncTearDown(self):
        # Surface any runstate-sequence assertion failure from the watcher.
        if self.runstate_watcher is not None:
            await self.runstate_watcher

    @staticmethod
    def async_test(async_test_method):
        """
        Decorator; adds SetUp and TearDown to async tests.

        The returned coroutine function keeps the decorated method's name
        and docstring (via functools.wraps).
        """
        from functools import wraps

        @wraps(async_test_method)
        async def _wrapper(self, *args, **kwargs):
            loop = asyncio.get_event_loop()
            loop.set_debug(True)

            await self._asyncSetUp()
            await async_test_method(self, *args, **kwargs)
            await self._asyncTearDown()

        return _wrapper

    # Definitions

    # The states we expect a "bad" connect/accept attempt to transition through
    BAD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # The states we expect a "good" session to transition through
    GOOD_CONNECTION_STATES = (
        Runstate.CONNECTING,
        Runstate.RUNNING,
        Runstate.DISCONNECTING,
        Runstate.IDLE,
    )

    # Helpers

    async def _watch_runstates(self, *states):
        """
        This launches a task alongside (most) tests below to confirm that
        the sequence of runstate changes that occur is exactly as
        anticipated.
        """
        async def _watcher():
            for state in states:
                new_state = await self.proto.runstate_changed()
                self.assertEqual(
                    new_state,
                    state,
                    msg=f"Expected state '{state.name}'",
                )

        self.runstate_watcher = create_task(_watcher())
        # Kick the loop and force the task to block on the event.
        await asyncio.sleep(0)
277
278
class State(TestBase):
    """Runstate behaviour that does not require any connection."""

    @TestBase.async_test
    async def testSuperfluousDisconnect(self):
        """
        Test calling disconnect() while already disconnected.
        """
        expected = (Runstate.DISCONNECTING, Runstate.IDLE)
        await self._watch_runstates(*expected)
        await self.proto.disconnect()
291
292
class Connect(TestBase):
    """
    Tests primarily related to calling connect().
    """
    async def _bad_connection(self, family: str):
        """Attempt a connection of *family* that should fail immediately."""
        assert family in ('INET', 'UNIX')

        if family == 'INET':
            await self.proto.connect(('127.0.0.1', 0))
        elif family == 'UNIX':
            await self.proto.connect('/dev/null')

    async def _hanging_connection(self):
        """Attempt a connection to a jammed socket; expected to hang."""
        with jammed_socket() as addr:
            await self.proto.connect(addr)

    async def _bad_connection_test(self, family: str):
        """
        Check that a bad connection of *family* raises ConnectError
        wrapping an OSError, while passing through BAD_CONNECTION_STATES.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)

        with self.assertRaises(ConnectError) as context:
            await self._bad_connection(family)

        self.assertIsInstance(context.exception.exc, OSError)
        self.assertEqual(
            context.exception.error_message,
            "Failed to establish connection"
        )

    @TestBase.async_test
    async def testBadINET(self):
        """
        Test an immediately rejected call to an IP target.
        """
        await self._bad_connection_test('INET')

    @TestBase.async_test
    async def testBadUNIX(self):
        """
        Test an immediately rejected call to a UNIX socket target.
        """
        await self._bad_connection_test('UNIX')

    @TestBase.async_test
    async def testCancellation(self):
        """
        Test what happens when a connection attempt is aborted.
        """
        # The connection attempt (connect() here, accept() in the Accept
        # subclass) is a plain coroutine, not a task, so it cannot be
        # cancelled outright. Instead, wrap it in a task and cancel *that*.
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        # This is insider baseball, but the connection attempt has
        # yielded *just* before the actual connection attempt, so kick
        # the loop to make sure it's truly wedged.
        await asyncio.sleep(0)

        task.cancel()
        await task

    @TestBase.async_test
    async def testTimeout(self):
        """
        Test what happens when a connection attempt times out.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection())

        # More insider baseball: to improve the speed of this test while
        # guaranteeing that the connection even gets a chance to start,
        # verify that the connection hangs *first*, then await the
        # result of the task with a nearly-zero timeout.

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)
        await asyncio.sleep(0)

        with self.assertRaises(asyncio.TimeoutError):
            await asyncio.wait_for(task, timeout=0)

    @TestBase.async_test
    async def testRequire(self):
        """
        Test what happens when a connection attempt is made while CONNECTING.
        """
        await self._watch_runstates(*self.BAD_CONNECTION_STATES)
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        state = await self.proto.runstate_changed()
        self.assertEqual(state, Runstate.CONNECTING)

        with self.assertRaises(StateError) as context:
            await self._bad_connection('UNIX')

        self.assertEqual(
            context.exception.error_message,
            "NullProtocol is currently connecting."
        )
        self.assertEqual(context.exception.state, Runstate.CONNECTING)
        self.assertEqual(context.exception.required, Runstate.IDLE)

        task.cancel()
        await task

    @TestBase.async_test
    async def testImplicitRunstateInit(self):
        """
        Test what happens if we do not wait on the runstate event until
        AFTER a connection is made, i.e., connect()/accept() themselves
        initialize the runstate event. All of the above tests force the
        initialization by waiting on the runstate *first*.
        """
        task = run_as_task(self._hanging_connection(), allow_cancellation=True)

        # Kick the loop to coerce the state change
        await asyncio.sleep(0)
        # Use a TestCase assertion (not a bare assert) so the check survives
        # `python -O` and reports both values on failure.
        self.assertEqual(self.proto.runstate, Runstate.CONNECTING)

        # We already missed the transition to CONNECTING
        await self._watch_runstates(Runstate.DISCONNECTING, Runstate.IDLE)

        task.cancel()
        await task
419
420
class Accept(Connect):
    """
    Re-runs every Connect test, but through start_server_and_accept().
    """
    async def _bad_connection(self, family: str):
        assert family in ('INET', 'UNIX')

        if family == 'UNIX':
            # Listening on an existing non-socket path is expected to fail.
            await self.proto.start_server_and_accept('/dev/null')
        elif family == 'INET':
            # Presumably fails because this host cannot bind that address.
            await self.proto.start_server_and_accept(('example.com', 1))

    async def _hanging_connection(self):
        # Listen on a fresh UNIX socket that no client will ever dial.
        with TemporaryDirectory(suffix='.qmp') as scratch_dir:
            sock_name = type(self.proto).__name__ + ".sock"
            sock_path = os.path.join(scratch_dir, sock_name)
            await self.proto.start_server_and_accept(sock_path)
437
438
class FakeSession(TestBase):
    """
    Session-lifecycle tests that run NullProtocol with fake_session set,
    so connect()/accept() never touch a real socket.
    """

    def setUp(self):
        super().setUp()
        self.proto.fake_session = True

    async def _asyncSetUp(self):
        await super()._asyncSetUp()
        # Every test here is expected to walk the full "good" lifecycle.
        await self._watch_runstates(*self.GOOD_CONNECTION_STATES)

    async def _asyncTearDown(self):
        await self.proto.disconnect()
        await super()._asyncTearDown()

    ####

    @TestBase.async_test
    async def testFakeConnect(self):
        """Test the full state lifecycle (via connect) with a no-op session."""
        await self.proto.connect('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeAccept(self):
        """Test the full state lifecycle (via accept) with a no-op session."""
        await self.proto.start_server_and_accept('/not/a/real/path')
        self.assertEqual(self.proto.runstate, Runstate.RUNNING)

    @TestBase.async_test
    async def testFakeRecv(self):
        """Test receiving a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        log_name = self.proto.logger.name
        with self.assertLogs(log_name, level='DEBUG') as captured:
            # Pulse the trigger: wake the pending reader, then re-arm it.
            trigger = self.proto.trigger_input
            trigger.set()
            trigger.clear()
            await asyncio.sleep(0)  # Let the reader run.

        self.assertEqual(captured.output, [f"DEBUG:{log_name}:<-- None"])

    @TestBase.async_test
    async def testFakeSend(self):
        """Test sending a fake/null message."""
        await self.proto.start_server_and_accept('/not/a/real/path')

        log_name = self.proto.logger.name
        with self.assertLogs(log_name, level='DEBUG') as captured:
            # Cheat: Send a Null message to nobody.
            await self.proto.send_msg()
            # Awaiting queue.put() alone does not yield; let the writer run.
            await asyncio.sleep(0)

        self.assertEqual(captured.output, [f"DEBUG:{log_name}:--> None"])

    async def _prod_session_api(
            self,
            current_state: Runstate,
            error_message: str,
            accept: bool = True
    ):
        """
        Assert that opening another session is refused with StateError.

        :param current_state: The runstate the error should report.
        :param error_message: The exact error message expected.
        :param accept: Use start_server_and_accept() if true, else connect().
        """
        if accept:
            attempt = self.proto.start_server_and_accept
        else:
            attempt = self.proto.connect

        with self.assertRaises(StateError) as context:
            await attempt('/not/a/real/path')

        err = context.exception
        self.assertEqual(err.error_message, error_message)
        self.assertEqual(err.state, current_state)
        self.assertEqual(err.required, Runstate.IDLE)

    @TestBase.async_test
    async def testAcceptRequireRunning(self):
        """Test that accept() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireRunning(self):
        """Test that connect() cannot be called when Runstate=RUNNING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        await self._prod_session_api(
            Runstate.RUNNING,
            "NullProtocol is already connected and running.",
            accept=False,
        )

    @TestBase.async_test
    async def testAcceptRequireDisconnecting(self):
        """Test that accept() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force the loop into DISCONNECTING without finishing.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=True,
        )

    @TestBase.async_test
    async def testConnectRequireDisconnecting(self):
        """Test that connect() cannot be called when Runstate=DISCONNECTING"""
        await self.proto.start_server_and_accept('/not/a/real/path')

        # Cheat: force the loop into DISCONNECTING without finishing.
        await self.proto.simulate_disconnect()

        await self._prod_session_api(
            Runstate.DISCONNECTING,
            ("NullProtocol is disconnecting."
             " Call disconnect() to return to IDLE state."),
            accept=False,
        )
568
569
570 class SimpleSession(TestBase):
571
572     def setUp(self):
573         super().setUp()
574         self.server = LineProtocol(type(self).__name__ + '-server')
575
576     async def _asyncSetUp(self):
577         await super()._asyncSetUp()
578         await self._watch_runstates(*self.GOOD_CONNECTION_STATES)
579
580     async def _asyncTearDown(self):
581         await self.proto.disconnect()
582         try:
583             await self.server.disconnect()
584         except EOFError:
585             pass
586         await super()._asyncTearDown()
587
588     @TestBase.async_test
589     async def testSmoke(self):
590         with TemporaryDirectory(suffix='.qmp') as tmpdir:
591             sock = os.path.join(tmpdir, type(self.proto).__name__ + ".sock")
592             server_task = create_task(self.server.start_server_and_accept(sock))
593
594             # give the server a chance to start listening [...]
595             await asyncio.sleep(0)
596             await self.proto.connect(sock)