from concurrent.futures import ThreadPoolExecutor
from grpc import RpcError
+from abc import ABC, abstractmethod
-class EventStream(object):
+
+class IEventStream(ABC):
+
+ @abstractmethod
+ def get_event_queue(self):
+ pass
+
+
+class FilteringEventStream(IEventStream):
+
+ def __init__(self, stream, filter_fn):
+ self.filter_fn = filter_fn
+ self.event_queue = SimpleQueue()
+ self.stream = stream
+
+ self.stream.register_callback(self.__event_callback, self.filter_fn)
+
+ def __event_callback(self, event):
+ self.event_queue.put(event)
+
+ def get_event_queue(self):
+ return self.event_queue
+
+ def unregister(self):
+ self.stream.unregister(self.__event_callback)
+
+
+DEFAULT_TIMEOUT_SECONDS = 3
+
+
+class EventStream(IEventStream):
"""
A class that streams events from a gRPC stream, which you can assert on.
Don't use these asserts directly, use the ones from cert.truth.
"""
- DEFAULT_TIMEOUT_SECONDS = 3
def __init__(self, server_stream_call):
if server_stream_call is None:
def __del__(self):
self.shutdown()
- def remaining_time_delta(self, end_time):
- remaining = end_time - datetime.now()
- if remaining < timedelta(milliseconds=0):
- remaining = timedelta(milliseconds=0)
- return remaining
+ def get_event_queue(self):
+ return self.event_queue
def shutdown(self):
"""
event = None
end_time = datetime.now() + timeout
while event is None and datetime.now() < end_time:
- remaining = self.remaining_time_delta(end_time)
+ remaining = static_remaining_time_delta(end_time)
logging.debug("Waiting for event (%fs remaining)" %
(remaining.total_seconds()))
try:
happen
:return:
"""
- logging.debug("assert_event_occurs %d %fs" % (at_least_times,
- timeout.total_seconds()))
- event_list = []
- end_time = datetime.now() + timeout
- while len(event_list) < at_least_times and datetime.now() < end_time:
- remaining = self.remaining_time_delta(end_time)
- logging.debug("Waiting for event (%fs remaining)" %
- (remaining.total_seconds()))
- try:
- current_event = self.event_queue.get(
- timeout=remaining.total_seconds())
- if match_fn(current_event):
- event_list.append(current_event)
- except Empty:
- continue
- logging.debug("Done waiting for event")
- asserts.assert_true(
- len(event_list) >= at_least_times,
- msg=("Expected at least %d events, but got %d" % (at_least_times,
- len(event_list))))
+ NOT_FOR_YOU_assert_event_occurs(self, match_fn, at_least_times, timeout)
def assert_event_occurs_at_most(
self,
event_list = []
end_time = datetime.now() + timeout
while len(event_list) <= at_most_times and datetime.now() < end_time:
- remaining = self.remaining_time_delta(end_time)
+ remaining = static_remaining_time_delta(end_time)
logging.debug("Waiting for event iteration (%fs remaining)" %
(remaining.total_seconds()))
try:
match_fns,
order_matters,
timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
- logging.debug("assert_all_events_occur %fs" % timeout.total_seconds())
- pending_matches = list(match_fns)
- matched_order = []
- end_time = datetime.now() + timeout
- while len(pending_matches) > 0 and datetime.now() < end_time:
- remaining = self.remaining_time_delta(end_time)
- logging.debug("Waiting for event (%fs remaining)" %
- (remaining.total_seconds()))
- try:
- current_event = self.event_queue.get(
- timeout=remaining.total_seconds())
- for match_fn in pending_matches:
- if match_fn(current_event):
- pending_matches.remove(match_fn)
- matched_order.append(match_fn)
- except Empty:
- continue
- logging.debug("Done waiting for event")
+ NOT_FOR_YOU_assert_all_events_occur(self, match_fns, order_matters,
+ timeout)
+
+
+def static_remaining_time_delta(end_time):
+ remaining = end_time - datetime.now()
+ if remaining < timedelta(milliseconds=0):
+ remaining = timedelta(milliseconds=0)
+ return remaining
+
+
+def NOT_FOR_YOU_assert_event_occurs(
+ istream,
+ match_fn,
+ at_least_times=1,
+ timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
+ logging.debug("assert_event_occurs %d %fs" % (at_least_times,
+ timeout.total_seconds()))
+ event_list = []
+ end_time = datetime.now() + timeout
+ while len(event_list) < at_least_times and datetime.now() < end_time:
+ remaining = static_remaining_time_delta(end_time)
+ logging.debug(
+ "Waiting for event (%fs remaining)" % (remaining.total_seconds()))
+ try:
+ current_event = istream.get_event_queue().get(
+ timeout=remaining.total_seconds())
+ if match_fn(current_event):
+ event_list.append(current_event)
+ except Empty:
+ continue
+ logging.debug("Done waiting for event")
+ asserts.assert_true(
+ len(event_list) >= at_least_times,
+ msg=("Expected at least %d events, but got %d" % (at_least_times,
+ len(event_list))))
+
+
+def NOT_FOR_YOU_assert_all_events_occur(
+ istream,
+ match_fns,
+ order_matters,
+ timeout=timedelta(seconds=DEFAULT_TIMEOUT_SECONDS)):
+ logging.debug("assert_all_events_occur %fs" % timeout.total_seconds())
+ pending_matches = list(match_fns)
+ matched_order = []
+ end_time = datetime.now() + timeout
+ while len(pending_matches) > 0 and datetime.now() < end_time:
+ remaining = static_remaining_time_delta(end_time)
+ logging.debug(
+ "Waiting for event (%fs remaining)" % (remaining.total_seconds()))
+ try:
+ current_event = istream.get_event_queue().get(
+ timeout=remaining.total_seconds())
+ for match_fn in pending_matches:
+ if match_fn(current_event):
+ pending_matches.remove(match_fn)
+ matched_order.append(match_fn)
+ except Empty:
+ continue
+ logging.debug("Done waiting for event")
+ asserts.assert_true(
+ len(matched_order) == len(match_fns),
+ msg=("Expected at least %d events, but got %d" % (len(match_fns),
+ len(matched_order))))
+ if order_matters:
+ correct_order = True
+ i = 0
+ while i < len(match_fns):
+ if match_fns[i] is not matched_order[i]:
+ correct_order = False
+ break
+ i += 1
asserts.assert_true(
- len(matched_order) == len(match_fns),
- msg=("Expected at least %d events, but got %d" %
- (len(match_fns), len(matched_order))))
- if order_matters:
- correct_order = True
- i = 0
- while i < len(match_fns):
- if match_fns[i] is not matched_order[i]:
- correct_order = False
- break
- i += 1
- asserts.assert_true(
- correct_order, "Events not received in correct order %s %s" %
- (match_fns, matched_order))
+ correct_order, "Events not received in correct order %s %s" %
+ (match_fns, matched_order))
from mobly.asserts import assert_false
from mobly import signals
-from cert.event_stream import EventStream
+from cert.event_stream import IEventStream
+from cert.event_stream import NOT_FOR_YOU_assert_event_occurs
+from cert.event_stream import NOT_FOR_YOU_assert_all_events_occur
import sys, traceback
if len(match_fns) == 0:
raise signals.TestFailure("Must specify a match function")
elif len(match_fns) == 1:
- self._value.assert_event_occurs(match_fns[0])
+ NOT_FOR_YOU_assert_event_occurs(self._value, match_fns[0])
return EventStreamContinuationSubject(self._value)
else:
return MultiMatchStreamSubject(self._value, match_fns)
self._match_fns = match_fns
def inAnyOrder(self):
- self._stream.assert_all_events_occur(
- self._match_fns, order_matters=False)
+ NOT_FOR_YOU_assert_all_events_occur(
+ self._stream, self._match_fns, order_matters=False)
return EventStreamContinuationSubject(self._stream)
def inOrder(self):
- self._stream.assert_all_events_occur(
- self._match_fns, order_matters=True)
+ NOT_FOR_YOU_assert_all_events_occur(
+ self._stream, self._match_fns, order_matters=True)
return EventStreamContinuationSubject(self._stream)
if len(match_fns) == 0:
raise signals.TestFailure("Must specify a match function")
elif len(match_fns) == 1:
- self._value.assert_event_occurs(match_fns[0])
+ NOT_FOR_YOU_assert_event_occurs(self._value, match_fns[0])
return EventStreamContinuationSubject(self._value)
else:
return MultiMatchStreamSubject(self._value, match_fns)
def assertThat(subject):
if type(subject) is bool:
return BooleanSubject(subject)
- elif isinstance(subject, EventStream):
+ elif isinstance(subject, IEventStream):
return EventStreamSubject(subject)
else:
return ObjectSubject(subject)
from google.protobuf import empty_pb2 as empty_proto
from cert.event_stream import EventStream
+from cert.event_stream import FilteringEventStream
+from cert.event_stream import IEventStream
from captures import ReadBdAddrCompleteCapture
from captures import ConnectionCompleteCapture
from captures import ConnectionRequestCapture
from hci.facade import acl_manager_facade_pb2 as acl_manager_facade
-class PyAclManagerAclConnection(object):
+class PyAclManagerAclConnection(IEventStream):
- def __init__(self, device, remote_addr, handle):
+ def __init__(self, device, acl_stream, remote_addr, handle):
self.device = device
self.handle = handle
+ # todo enable filtering after sorting out handles
+ self.our_acl_stream = FilteringEventStream(acl_stream, None)
if remote_addr:
self.connection_event_stream = EventStream(
self.device.hci_acl_manager.SendAclData(
acl_manager_facade.AclData(handle=self.handle, payload=bytes(data)))
+ def get_event_queue(self):
+ return self.our_acl_stream.get_event_queue()
+
class PyAclManager(object):
self.device.hci_acl_manager.FetchIncomingConnection(
empty_proto.Empty()))
- def get_acl_stream(self):
- return self.acl_stream
-
def initiate_connection(self, remote_addr):
- return PyAclManagerAclConnection(self.device, remote_addr, None)
+ return PyAclManagerAclConnection(self.device, self.acl_stream,
+ remote_addr, None)
def accept_connection(self):
connection_complete = ConnectionCompleteCapture()
assertThat(self.incoming_connection_stream).emits(connection_complete)
handle = connection_complete.get().GetConnectionHandle()
- return PyAclManagerAclConnection(self.device, None, handle)
+ return PyAclManagerAclConnection(self.device, self.acl_stream, None,
+ handle)
from google.protobuf import empty_pb2 as empty_proto
from cert.event_stream import EventStream
+from cert.event_stream import FilteringEventStream
+from cert.event_stream import IEventStream
from captures import ReadBdAddrCompleteCapture
from captures import ConnectionCompleteCapture
from captures import ConnectionRequestCapture
from hci.facade import facade_pb2 as hci_facade
-class PyHciAclConnection(object):
+class PyHciAclConnection(IEventStream):
def __init__(self, handle, acl_stream, device):
- self.handle = handle
- self.acl_stream = acl_stream
+ self.handle = int(handle)
self.device = device
+ # todo, handle we got is 0, so doesn't match - fix before enabling filtering
+ self.our_acl_stream = FilteringEventStream(acl_stream, None)
def send(self, pb_flag, b_flag, data):
acl_msg = hci_facade.AclMsg(
- handle=int(self.handle),
+ handle=self.handle,
packet_boundary_flag=int(pb_flag),
broadcast_flag=int(b_flag),
data=data)
self.send(hci_packets.PacketBoundaryFlag.CONTINUING_FRAGMENT,
hci_packets.BroadcastFlag.POINT_TO_POINT, bytes(data))
+ def get_event_queue(self):
+ return self.our_acl_stream.get_event_queue()
+
class PyHci(object):
def get_event_stream(self):
return self.event_stream
- def get_acl_stream(self):
- return self.acl_stream
-
def send_command_with_complete(self, command):
self.device.hci.send_command_with_complete(command)