From: Zach Johnson Date: Sat, 29 Feb 2020 06:34:08 +0000 (-0800) Subject: Add IEventStream, so other subjects can use emit, etc X-Git-Url: http://git.osdn.net/view?a=commitdiff_plain;h=2d949caf09870a6d6e1364e2d8223638ba5fa387;p=android-x86%2Fsystem-bt.git Add IEventStream, so other subjects can use emit, etc And use in PyHci & PyAclManager. Can't quite turn on filtering on the ACL messages, due to the handles being wrong. Test: cert/run --host --test_filter=AclManagerTest Change-Id: I30ef332b06ae553553757337cd66da91f8debe3a --- diff --git a/gd/cert/event_stream.py b/gd/cert/event_stream.py index a79329958..4bbc5ea3b 100644 --- a/gd/cert/event_stream.py +++ b/gd/cert/event_stream.py @@ -24,14 +24,44 @@ from google.protobuf import text_format from concurrent.futures import ThreadPoolExecutor from grpc import RpcError +from abc import ABC, abstractmethod -class EventStream(object): + +class IEventStream(ABC): + + @abstractmethod + def get_event_queue(self): + pass + + +class FilteringEventStream(IEventStream): + + def __init__(self, stream, filter_fn): + self.filter_fn = filter_fn + self.event_queue = SimpleQueue() + self.stream = stream + + self.stream.register_callback(self.__event_callback, self.filter_fn) + + def __event_callback(self, event): + self.event_queue.put(event) + + def get_event_queue(self): + return self.event_queue + + def unregister(self): + self.stream.unregister(self.__event_callback) + + +DEFAULT_TIMEOUT_SECONDS = 3 + + +class EventStream(IEventStream): """ A class that streams events from a gRPC stream, which you can assert on. Don't use these asserts directly, use the ones from cert.truth. """ - DEFAULT_TIMEOUT_SECONDS = 3 def __init__(self, server_stream_call): if server_stream_call is None: @@ -53,11 +83,8 @@ class EventStream(object): def __del__(self): self.shutdown() - def remaining_time_delta(self, end_time): - remaining = end_time - datetime.now() - if remaining < timedelta(milliseconds=0): - remaining = timedelta(milliseconds=0) - return remaining + def get_event_queue(self): + return self.event_queue def shutdown(self): """ @@ -170,7 +197,7 @@ class EventStream(object): event = None end_time = datetime.now() + timeout while event is None and datetime.now() < end_time: - remaining = self.remaining_time_delta(end_time) + remaining = static_remaining_time_delta(end_time) logging.debug("Waiting for event (%fs remaining)" % (remaining.total_seconds())) try: @@ -202,26 +229,7 @@ class EventStream(object): happen :return: """ - logging.debug("assert_event_occurs %d %fs" % (at_least_times, - timeout.total_seconds())) - event_list = [] - end_time = datetime.now() + timeout - while len(event_list) < at_least_times and datetime.now() < end_time: - remaining = self.remaining_time_delta(end_time) - logging.debug("Waiting for event (%fs remaining)" % - (remaining.total_seconds())) - try: - current_event = self.event_queue.get( - timeout=remaining.total_seconds()) - if match_fn(current_event): - event_list.append(current_event) - except Empty: - continue - logging.debug("Done waiting for event") - asserts.assert_true( - len(event_list) >= at_least_times, - msg=("Expected at least %d events, but got %d" % (at_least_times, - len(event_list)))) + NOT_FOR_YOU_assert_event_occurs(self, match_fn, at_least_times, timeout) def assert_event_occurs_at_most( self, @@ -242,7 +250,7 @@ class EventStream(object): event_list = [] end_time = datetime.now() + timeout while len(event_list) <= at_most_times and datetime.now() < end_time: - remaining = self.remaining_time_delta(end_time) + remaining = static_remaining_time_delta(end_time) logging.debug("Waiting for event iteration (%fs remaining)" % (remaining.total_seconds())) try: @@ -263,36 +271,79 @@ class EventStream(object): match_fns, order_matters, timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)): - logging.debug("assert_all_events_occur %fs" % timeout.total_seconds()) - pending_matches = list(match_fns) - matched_order = [] - end_time = datetime.now() + timeout - while len(pending_matches) > 0 and datetime.now() < end_time: - remaining = self.remaining_time_delta(end_time) - logging.debug("Waiting for event (%fs remaining)" % - (remaining.total_seconds())) - try: - current_event = self.event_queue.get( - timeout=remaining.total_seconds()) - for match_fn in pending_matches: - if match_fn(current_event): - pending_matches.remove(match_fn) - matched_order.append(match_fn) - except Empty: - continue - logging.debug("Done waiting for event") + NOT_FOR_YOU_assert_all_events_occur(self, match_fns, order_matters, + timeout) + + +def static_remaining_time_delta(end_time): + remaining = end_time - datetime.now() + if remaining < timedelta(milliseconds=0): + remaining = timedelta(milliseconds=0) + return remaining + + +def NOT_FOR_YOU_assert_event_occurs( + istream, + match_fn, + at_least_times=1, + timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)): + logging.debug("assert_event_occurs %d %fs" % (at_least_times, + timeout.total_seconds())) + event_list = [] + end_time = datetime.now() + timeout + while len(event_list) < at_least_times and datetime.now() < end_time: + remaining = static_remaining_time_delta(end_time) + logging.debug( + "Waiting for event (%fs remaining)" % (remaining.total_seconds())) + try: + current_event = istream.get_event_queue().get( + timeout=remaining.total_seconds()) + if match_fn(current_event): + event_list.append(current_event) + except Empty: + continue + logging.debug("Done waiting for event") + asserts.assert_true( + len(event_list) >= at_least_times, + msg=("Expected at least %d events, but got %d" % (at_least_times, + len(event_list)))) + + +def NOT_FOR_YOU_assert_all_events_occur( + istream, + match_fns, + order_matters, + timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)): + logging.debug("assert_all_events_occur %fs" % timeout.total_seconds()) + pending_matches = list(match_fns) + matched_order = [] + end_time = datetime.now() + timeout + while len(pending_matches) > 0 and datetime.now() < end_time: + remaining = static_remaining_time_delta(end_time) + logging.debug( + "Waiting for event (%fs remaining)" % (remaining.total_seconds())) + try: + current_event = istream.get_event_queue().get( + timeout=remaining.total_seconds()) + for match_fn in pending_matches: + if match_fn(current_event): + pending_matches.remove(match_fn) + matched_order.append(match_fn) + except Empty: + continue + logging.debug("Done waiting for event") + asserts.assert_true( + len(matched_order) == len(match_fns), + msg=("Expected at least %d events, but got %d" % (len(match_fns), + len(matched_order)))) + if order_matters: + correct_order = True + i = 0 + while i < len(match_fns): + if match_fns[i] is not matched_order[i]: + correct_order = False + break + i += 1 asserts.assert_true( - len(matched_order) == len(match_fns), - msg=("Expected at least %d events, but got %d" % - (len(match_fns), len(matched_order)))) - if order_matters: - correct_order = True - i = 0 - while i < len(match_fns): - if match_fns[i] is not matched_order[i]: - correct_order = False - break - i += 1 - asserts.assert_true( - correct_order, "Events not received in correct order %s %s" % - (match_fns, matched_order)) + correct_order, "Events not received in correct order %s %s" % + (match_fns, matched_order)) diff --git a/gd/cert/truth.py b/gd/cert/truth.py index 047bffe29..50a14c028 100644 --- a/gd/cert/truth.py +++ b/gd/cert/truth.py @@ -20,7 +20,9 @@ from mobly.asserts import assert_true from mobly.asserts import assert_false from mobly import signals -from cert.event_stream import EventStream +from cert.event_stream import IEventStream +from cert.event_stream import NOT_FOR_YOU_assert_event_occurs +from cert.event_stream import NOT_FOR_YOU_assert_all_events_occur import sys, traceback @@ -63,7 +65,7 @@ class EventStreamSubject(ObjectSubject): if len(match_fns) == 0: raise signals.TestFailure("Must specify a match function") elif len(match_fns) == 1: - self._value.assert_event_occurs(match_fns[0]) + NOT_FOR_YOU_assert_event_occurs(self._value, match_fns[0]) return EventStreamContinuationSubject(self._value) else: return MultiMatchStreamSubject(self._value, match_fns) @@ -76,13 +78,13 @@ class MultiMatchStreamSubject(object): self._match_fns = match_fns def inAnyOrder(self): - self._stream.assert_all_events_occur( - self._match_fns, order_matters=False) + NOT_FOR_YOU_assert_all_events_occur( + self._stream, self._match_fns, order_matters=False) return EventStreamContinuationSubject(self._stream) def inOrder(self): - self._stream.assert_all_events_occur( - self._match_fns, order_matters=True) + NOT_FOR_YOU_assert_all_events_occur( + self._stream, self._match_fns, order_matters=True) return EventStreamContinuationSubject(self._stream) @@ -95,7 +97,7 @@ class EventStreamContinuationSubject(ObjectSubject): if len(match_fns) == 0: raise signals.TestFailure("Must specify a match function") elif len(match_fns) == 1: - self._value.assert_event_occurs(match_fns[0]) + NOT_FOR_YOU_assert_event_occurs(self._value, match_fns[0]) return EventStreamContinuationSubject(self._value) else: return MultiMatchStreamSubject(self._value, match_fns) @@ -116,7 +118,7 @@ class BooleanSubject(ObjectSubject): def assertThat(subject): if type(subject) is bool: return BooleanSubject(subject) - elif isinstance(subject, EventStream): + elif isinstance(subject, IEventStream): return EventStreamSubject(subject) else: return ObjectSubject(subject) diff --git a/gd/hci/cert/acl_manager_test.py b/gd/hci/cert/acl_manager_test.py index 8d4dd335b..bff5962bb 100644 --- a/gd/hci/cert/acl_manager_test.py +++ b/gd/hci/cert/acl_manager_test.py @@ -58,9 +58,9 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass): b'\x29\x00\x07\x00This is just SomeMoreAclData from the DUT' ) - assertThat(cert_hci.get_acl_stream()).emits( + assertThat(cert_acl).emits( lambda packet: b'SomeMoreAclData' in packet.data) - assertThat(dut_acl_manager.get_acl_stream()).emits( + assertThat(dut_acl).emits( lambda packet: b'SomeAclData' in packet.payload) def test_cert_connects(self): @@ -85,9 +85,9 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass): cert_acl.send_first( b'\x26\x00\x07\x00This is just SomeAclData from the Cert') - assertThat(cert_hci.get_acl_stream()).emits( + assertThat(cert_acl).emits( lambda packet: b'SomeMoreAclData' in packet.data) - assertThat(dut_acl_manager.get_acl_stream()).emits( + assertThat(dut_acl).emits( lambda packet: b'SomeAclData' in packet.payload) def test_recombination_l2cap_packet(self): @@ -105,6 +105,6 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass): dut_acl.wait_for_connection_complete() - assertThat(dut_acl_manager.get_acl_stream()).emits( + assertThat(dut_acl).emits( lambda packet: b'Hello!' in packet.payload, lambda packet: b'Hello' * 200 in packet.payload).inOrder() diff --git a/gd/hci/cert/py_acl_manager.py b/gd/hci/cert/py_acl_manager.py index 7766433ea..b7a4028d7 100644 --- a/gd/hci/cert/py_acl_manager.py +++ b/gd/hci/cert/py_acl_manager.py @@ -16,6 +16,8 @@ from google.protobuf import empty_pb2 as empty_proto from cert.event_stream import EventStream +from cert.event_stream import FilteringEventStream +from cert.event_stream import IEventStream from captures import ReadBdAddrCompleteCapture from captures import ConnectionCompleteCapture from captures import ConnectionRequestCapture @@ -25,11 +27,13 @@ from hci.facade import facade_pb2 as hci_facade from hci.facade import acl_manager_facade_pb2 as acl_manager_facade -class PyAclManagerAclConnection(object): +class PyAclManagerAclConnection(IEventStream): - def __init__(self, device, remote_addr, handle): + def __init__(self, device, acl_stream, remote_addr, handle): self.device = device self.handle = handle + # todo enable filtering after sorting out handles + self.our_acl_stream = FilteringEventStream(acl_stream, None) if remote_addr: self.connection_event_stream = EventStream( @@ -64,6 +68,9 @@ class PyAclManagerAclConnection(object): self.device.hci_acl_manager.SendAclData( acl_manager_facade.AclData(handle=self.handle, payload=bytes(data))) + def get_event_queue(self): + return self.our_acl_stream.get_event_queue() + class PyAclManager(object): @@ -94,14 +101,13 @@ class PyAclManager(object): self.device.hci_acl_manager.FetchIncomingConnection( empty_proto.Empty())) - def get_acl_stream(self): - return self.acl_stream - def initiate_connection(self, remote_addr): - return PyAclManagerAclConnection(self.device, remote_addr, None) + return PyAclManagerAclConnection(self.device, self.acl_stream, + remote_addr, None) def accept_connection(self): connection_complete = ConnectionCompleteCapture() assertThat(self.incoming_connection_stream).emits(connection_complete) handle = connection_complete.get().GetConnectionHandle() - return PyAclManagerAclConnection(self.device, None, handle) + return PyAclManagerAclConnection(self.device, self.acl_stream, None, + handle) diff --git a/gd/hci/cert/py_hci.py b/gd/hci/cert/py_hci.py index 548a12aa8..47b84a055 100644 --- a/gd/hci/cert/py_hci.py +++ b/gd/hci/cert/py_hci.py @@ -16,6 +16,8 @@ from google.protobuf import empty_pb2 as empty_proto from cert.event_stream import EventStream +from cert.event_stream import FilteringEventStream +from cert.event_stream import IEventStream from captures import ReadBdAddrCompleteCapture from captures import ConnectionCompleteCapture from captures import ConnectionRequestCapture @@ -24,16 +26,17 @@ from cert.truth import assertThat from hci.facade import facade_pb2 as hci_facade -class PyHciAclConnection(object): +class PyHciAclConnection(IEventStream): def __init__(self, handle, acl_stream, device): - self.handle = handle - self.acl_stream = acl_stream + self.handle = int(handle) self.device = device + # todo, handle we got is 0, so doesn't match - fix before enabling filtering + self.our_acl_stream = FilteringEventStream(acl_stream, None) def send(self, pb_flag, b_flag, data): acl_msg = hci_facade.AclMsg( - handle=int(self.handle), + handle=self.handle, packet_boundary_flag=int(pb_flag), broadcast_flag=int(b_flag), data=data) @@ -47,6 +50,9 @@ class PyHciAclConnection(object): self.send(hci_packets.PacketBoundaryFlag.CONTINUING_FRAGMENT, hci_packets.BroadcastFlag.POINT_TO_POINT, bytes(data)) + def get_event_queue(self): + return self.our_acl_stream.get_event_queue() + class PyHci(object): @@ -80,9 +86,6 @@ class PyHci(object): def get_event_stream(self): return self.event_stream - def get_acl_stream(self): - return self.acl_stream - def send_command_with_complete(self, command): self.device.hci.send_command_with_complete(command)