OSDN Git Service

Add IEventStream, so other subjects can use emit, etc
authorZach Johnson <zachoverflow@google.com>
Sat, 29 Feb 2020 06:34:08 +0000 (22:34 -0800)
committerZach Johnson <zachoverflow@google.com>
Sat, 29 Feb 2020 06:38:30 +0000 (22:38 -0800)
And use in PyHci & PyAclManager. Can't quite turn on filtering
on the ACL messages, due to the handles being wrong.

Test: cert/run --host --test_filter=AclManagerTest
Change-Id: I30ef332b06ae553553757337cd66da91f8debe3a

gd/cert/event_stream.py
gd/cert/truth.py
gd/hci/cert/acl_manager_test.py
gd/hci/cert/py_acl_manager.py
gd/hci/cert/py_hci.py

index a793299..4bbc5ea 100644 (file)
@@ -24,14 +24,44 @@ from google.protobuf import text_format
 from concurrent.futures import ThreadPoolExecutor
 from grpc import RpcError
 
+from abc import ABC, abstractmethod
 
-class EventStream(object):
+
+class IEventStream(ABC):
+
+    @abstractmethod
+    def get_event_queue(self):
+        pass
+
+
+class FilteringEventStream(IEventStream):
+
+    def __init__(self, stream, filter_fn):
+        self.filter_fn = filter_fn
+        self.event_queue = SimpleQueue()
+        self.stream = stream
+
+        self.stream.register_callback(self.__event_callback, self.filter_fn)
+
+    def __event_callback(self, event):
+        self.event_queue.put(event)
+
+    def get_event_queue(self):
+        return self.event_queue
+
+    def unregister(self):
+        self.stream.unregister(self.__event_callback)
+
+
+DEFAULT_TIMEOUT_SECONDS = 3
+
+
+class EventStream(IEventStream):
     """
     A class that streams events from a gRPC stream, which you can assert on.
 
     Don't use these asserts directly, use the ones from cert.truth.
     """
-    DEFAULT_TIMEOUT_SECONDS = 3
 
     def __init__(self, server_stream_call):
         if server_stream_call is None:
@@ -53,11 +83,8 @@ class EventStream(object):
     def __del__(self):
         self.shutdown()
 
-    def remaining_time_delta(self, end_time):
-        remaining = end_time - datetime.now()
-        if remaining < timedelta(milliseconds=0):
-            remaining = timedelta(milliseconds=0)
-        return remaining
+    def get_event_queue(self):
+        return self.event_queue
 
     def shutdown(self):
         """
@@ -170,7 +197,7 @@ class EventStream(object):
         event = None
         end_time = datetime.now() + timeout
         while event is None and datetime.now() < end_time:
-            remaining = self.remaining_time_delta(end_time)
+            remaining = static_remaining_time_delta(end_time)
             logging.debug("Waiting for event (%fs remaining)" %
                           (remaining.total_seconds()))
             try:
@@ -202,26 +229,7 @@ class EventStream(object):
                                happen
         :return:
         """
-        logging.debug("assert_event_occurs %d %fs" % (at_least_times,
-                                                      timeout.total_seconds()))
-        event_list = []
-        end_time = datetime.now() + timeout
-        while len(event_list) < at_least_times and datetime.now() < end_time:
-            remaining = self.remaining_time_delta(end_time)
-            logging.debug("Waiting for event (%fs remaining)" %
-                          (remaining.total_seconds()))
-            try:
-                current_event = self.event_queue.get(
-                    timeout=remaining.total_seconds())
-                if match_fn(current_event):
-                    event_list.append(current_event)
-            except Empty:
-                continue
-        logging.debug("Done waiting for event")
-        asserts.assert_true(
-            len(event_list) >= at_least_times,
-            msg=("Expected at least %d events, but got %d" % (at_least_times,
-                                                              len(event_list))))
+        NOT_FOR_YOU_assert_event_occurs(self, match_fn, at_least_times, timeout)
 
     def assert_event_occurs_at_most(
             self,
@@ -242,7 +250,7 @@ class EventStream(object):
         event_list = []
         end_time = datetime.now() + timeout
         while len(event_list) <= at_most_times and datetime.now() < end_time:
-            remaining = self.remaining_time_delta(end_time)
+            remaining = static_remaining_time_delta(end_time)
             logging.debug("Waiting for event iteration (%fs remaining)" %
                           (remaining.total_seconds()))
             try:
@@ -263,36 +271,79 @@ class EventStream(object):
             match_fns,
             order_matters,
             timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
-        logging.debug("assert_all_events_occur %fs" % timeout.total_seconds())
-        pending_matches = list(match_fns)
-        matched_order = []
-        end_time = datetime.now() + timeout
-        while len(pending_matches) > 0 and datetime.now() < end_time:
-            remaining = self.remaining_time_delta(end_time)
-            logging.debug("Waiting for event (%fs remaining)" %
-                          (remaining.total_seconds()))
-            try:
-                current_event = self.event_queue.get(
-                    timeout=remaining.total_seconds())
-                for match_fn in pending_matches:
-                    if match_fn(current_event):
-                        pending_matches.remove(match_fn)
-                        matched_order.append(match_fn)
-            except Empty:
-                continue
-        logging.debug("Done waiting for event")
+        NOT_FOR_YOU_assert_all_events_occur(self, match_fns, order_matters,
+                                            timeout)
+
+
+def static_remaining_time_delta(end_time):
+    remaining = end_time - datetime.now()
+    if remaining < timedelta(milliseconds=0):
+        remaining = timedelta(milliseconds=0)
+    return remaining
+
+
+def NOT_FOR_YOU_assert_event_occurs(
+        istream,
+        match_fn,
+        at_least_times=1,
+        timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
+    logging.debug("assert_event_occurs %d %fs" % (at_least_times,
+                                                  timeout.total_seconds()))
+    event_list = []
+    end_time = datetime.now() + timeout
+    while len(event_list) < at_least_times and datetime.now() < end_time:
+        remaining = static_remaining_time_delta(end_time)
+        logging.debug(
+            "Waiting for event (%fs remaining)" % (remaining.total_seconds()))
+        try:
+            current_event = istream.get_event_queue().get(
+                timeout=remaining.total_seconds())
+            if match_fn(current_event):
+                event_list.append(current_event)
+        except Empty:
+            continue
+    logging.debug("Done waiting for event")
+    asserts.assert_true(
+        len(event_list) >= at_least_times,
+        msg=("Expected at least %d events, but got %d" % (at_least_times,
+                                                          len(event_list))))
+
+
+def NOT_FOR_YOU_assert_all_events_occur(
+        istream,
+        match_fns,
+        order_matters,
+        timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
+    logging.debug("assert_all_events_occur %fs" % timeout.total_seconds())
+    pending_matches = list(match_fns)
+    matched_order = []
+    end_time = datetime.now() + timeout
+    while len(pending_matches) > 0 and datetime.now() < end_time:
+        remaining = static_remaining_time_delta(end_time)
+        logging.debug(
+            "Waiting for event (%fs remaining)" % (remaining.total_seconds()))
+        try:
+            current_event = istream.get_event_queue().get(
+                timeout=remaining.total_seconds())
+            for match_fn in pending_matches:
+                if match_fn(current_event):
+                    pending_matches.remove(match_fn)
+                    matched_order.append(match_fn)
+        except Empty:
+            continue
+    logging.debug("Done waiting for event")
+    asserts.assert_true(
+        len(matched_order) == len(match_fns),
+        msg=("Expected at least %d events, but got %d" % (len(match_fns),
+                                                          len(matched_order))))
+    if order_matters:
+        correct_order = True
+        i = 0
+        while i < len(match_fns):
+            if match_fns[i] is not matched_order[i]:
+                correct_order = False
+                break
+            i += 1
         asserts.assert_true(
-            len(matched_order) == len(match_fns),
-            msg=("Expected at least %d events, but got %d" %
-                 (len(match_fns), len(matched_order))))
-        if order_matters:
-            correct_order = True
-            i = 0
-            while i < len(match_fns):
-                if match_fns[i] is not matched_order[i]:
-                    correct_order = False
-                    break
-                i += 1
-            asserts.assert_true(
-                correct_order, "Events not received in correct order %s %s" %
-                (match_fns, matched_order))
+            correct_order, "Events not received in correct order %s %s" %
+            (match_fns, matched_order))
index 047bffe..50a14c0 100644 (file)
@@ -20,7 +20,9 @@ from mobly.asserts import assert_true
 from mobly.asserts import assert_false
 
 from mobly import signals
-from cert.event_stream import EventStream
+from cert.event_stream import IEventStream
+from cert.event_stream import NOT_FOR_YOU_assert_event_occurs
+from cert.event_stream import NOT_FOR_YOU_assert_all_events_occur
 
 import sys, traceback
 
@@ -63,7 +65,7 @@ class EventStreamSubject(ObjectSubject):
         if len(match_fns) == 0:
             raise signals.TestFailure("Must specify a match function")
         elif len(match_fns) == 1:
-            self._value.assert_event_occurs(match_fns[0])
+            NOT_FOR_YOU_assert_event_occurs(self._value, match_fns[0])
             return EventStreamContinuationSubject(self._value)
         else:
             return MultiMatchStreamSubject(self._value, match_fns)
@@ -76,13 +78,13 @@ class MultiMatchStreamSubject(object):
         self._match_fns = match_fns
 
     def inAnyOrder(self):
-        self._stream.assert_all_events_occur(
-            self._match_fns, order_matters=False)
+        NOT_FOR_YOU_assert_all_events_occur(
+            self._stream, self._match_fns, order_matters=False)
         return EventStreamContinuationSubject(self._stream)
 
     def inOrder(self):
-        self._stream.assert_all_events_occur(
-            self._match_fns, order_matters=True)
+        NOT_FOR_YOU_assert_all_events_occur(
+            self._stream, self._match_fns, order_matters=True)
         return EventStreamContinuationSubject(self._stream)
 
 
@@ -95,7 +97,7 @@ class EventStreamContinuationSubject(ObjectSubject):
         if len(match_fns) == 0:
             raise signals.TestFailure("Must specify a match function")
         elif len(match_fns) == 1:
-            self._value.assert_event_occurs(match_fns[0])
+            NOT_FOR_YOU_assert_event_occurs(self._value, match_fns[0])
             return EventStreamContinuationSubject(self._value)
         else:
             return MultiMatchStreamSubject(self._value, match_fns)
@@ -116,7 +118,7 @@ class BooleanSubject(ObjectSubject):
 def assertThat(subject):
     if type(subject) is bool:
         return BooleanSubject(subject)
-    elif isinstance(subject, EventStream):
+    elif isinstance(subject, IEventStream):
         return EventStreamSubject(subject)
     else:
         return ObjectSubject(subject)
index 8d4dd33..bff5962 100644 (file)
@@ -58,9 +58,9 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
                     b'\x29\x00\x07\x00This is just SomeMoreAclData from the DUT'
                 )
 
-                assertThat(cert_hci.get_acl_stream()).emits(
+                assertThat(cert_acl).emits(
                     lambda packet: b'SomeMoreAclData' in packet.data)
-                assertThat(dut_acl_manager.get_acl_stream()).emits(
+                assertThat(dut_acl).emits(
                     lambda packet: b'SomeAclData' in packet.payload)
 
     def test_cert_connects(self):
@@ -85,9 +85,9 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
             cert_acl.send_first(
                 b'\x26\x00\x07\x00This is just SomeAclData from the Cert')
 
-            assertThat(cert_hci.get_acl_stream()).emits(
+            assertThat(cert_acl).emits(
                 lambda packet: b'SomeMoreAclData' in packet.data)
-            assertThat(dut_acl_manager.get_acl_stream()).emits(
+            assertThat(dut_acl).emits(
                 lambda packet: b'SomeAclData' in packet.payload)
 
     def test_recombination_l2cap_packet(self):
@@ -105,6 +105,6 @@ class AclManagerTest(GdFacadeOnlyBaseTestClass):
 
                 dut_acl.wait_for_connection_complete()
 
-                assertThat(dut_acl_manager.get_acl_stream()).emits(
+                assertThat(dut_acl).emits(
                     lambda packet: b'Hello!' in packet.payload,
                     lambda packet: b'Hello' * 200 in packet.payload).inOrder()
index 7766433..b7a4028 100644 (file)
@@ -16,6 +16,8 @@
 
 from google.protobuf import empty_pb2 as empty_proto
 from cert.event_stream import EventStream
+from cert.event_stream import FilteringEventStream
+from cert.event_stream import IEventStream
 from captures import ReadBdAddrCompleteCapture
 from captures import ConnectionCompleteCapture
 from captures import ConnectionRequestCapture
@@ -25,11 +27,13 @@ from hci.facade import facade_pb2 as hci_facade
 from hci.facade import acl_manager_facade_pb2 as acl_manager_facade
 
 
-class PyAclManagerAclConnection(object):
+class PyAclManagerAclConnection(IEventStream):
 
-    def __init__(self, device, remote_addr, handle):
+    def __init__(self, device, acl_stream, remote_addr, handle):
         self.device = device
         self.handle = handle
+        # todo enable filtering after sorting out handles
+        self.our_acl_stream = FilteringEventStream(acl_stream, None)
 
         if remote_addr:
             self.connection_event_stream = EventStream(
@@ -64,6 +68,9 @@ class PyAclManagerAclConnection(object):
         self.device.hci_acl_manager.SendAclData(
             acl_manager_facade.AclData(handle=self.handle, payload=bytes(data)))
 
+    def get_event_queue(self):
+        return self.our_acl_stream.get_event_queue()
+
 
 class PyAclManager(object):
 
@@ -94,14 +101,13 @@ class PyAclManager(object):
             self.device.hci_acl_manager.FetchIncomingConnection(
                 empty_proto.Empty()))
 
-    def get_acl_stream(self):
-        return self.acl_stream
-
     def initiate_connection(self, remote_addr):
-        return PyAclManagerAclConnection(self.device, remote_addr, None)
+        return PyAclManagerAclConnection(self.device, self.acl_stream,
+                                         remote_addr, None)
 
     def accept_connection(self):
         connection_complete = ConnectionCompleteCapture()
         assertThat(self.incoming_connection_stream).emits(connection_complete)
         handle = connection_complete.get().GetConnectionHandle()
-        return PyAclManagerAclConnection(self.device, None, handle)
+        return PyAclManagerAclConnection(self.device, self.acl_stream, None,
+                                         handle)
index 548a12a..47b84a0 100644 (file)
@@ -16,6 +16,8 @@
 
 from google.protobuf import empty_pb2 as empty_proto
 from cert.event_stream import EventStream
+from cert.event_stream import FilteringEventStream
+from cert.event_stream import IEventStream
 from captures import ReadBdAddrCompleteCapture
 from captures import ConnectionCompleteCapture
 from captures import ConnectionRequestCapture
@@ -24,16 +26,17 @@ from cert.truth import assertThat
 from hci.facade import facade_pb2 as hci_facade
 
 
-class PyHciAclConnection(object):
+class PyHciAclConnection(IEventStream):
 
     def __init__(self, handle, acl_stream, device):
-        self.handle = handle
-        self.acl_stream = acl_stream
+        self.handle = int(handle)
         self.device = device
+        # todo, handle we got is 0, so doesn't match - fix before enabling filtering
+        self.our_acl_stream = FilteringEventStream(acl_stream, None)
 
     def send(self, pb_flag, b_flag, data):
         acl_msg = hci_facade.AclMsg(
-            handle=int(self.handle),
+            handle=self.handle,
             packet_boundary_flag=int(pb_flag),
             broadcast_flag=int(b_flag),
             data=data)
@@ -47,6 +50,9 @@ class PyHciAclConnection(object):
         self.send(hci_packets.PacketBoundaryFlag.CONTINUING_FRAGMENT,
                   hci_packets.BroadcastFlag.POINT_TO_POINT, bytes(data))
 
+    def get_event_queue(self):
+        return self.our_acl_stream.get_event_queue()
+
 
 class PyHci(object):
 
@@ -80,9 +86,6 @@ class PyHci(object):
     def get_event_stream(self):
         return self.event_stream
 
-    def get_acl_stream(self):
-        return self.acl_stream
-
     def send_command_with_complete(self, command):
         self.device.hci.send_command_with_complete(command)