OSDN Git Service

HCI: Use futures to wait for packets in tests
authorMyles Watson <mylesgw@google.com>
Thu, 25 Jul 2019 23:28:35 +0000 (16:28 -0700)
committerMyles Watson <mylesgw@google.com>
Tue, 30 Jul 2019 21:15:03 +0000 (21:15 +0000)
Test: bluetooth_gd_test on a device
Change-Id: I3e61b83df47204380398646350b90c3f63734bca

gd/hci/hci_layer_test.cc

index a5b8f51..51cf3bf 100644 (file)
@@ -48,6 +48,9 @@ const size_t count_size = 0x8;
 namespace bluetooth {
 namespace hci {
 
+constexpr std::chrono::milliseconds kTimeout = HciLayer::kHciTimeoutMs / 2;
+constexpr std::chrono::milliseconds kAclTimeout = std::chrono::milliseconds(1000);
+
 class TestHciHal : public hal::HciHal {
  public:
   TestHciHal() : hal::HciHal() {}
@@ -66,10 +69,20 @@ class TestHciHal : public hal::HciHal {
 
   void sendHciCommand(hal::HciPacket command) override {
     outgoing_commands_.push_back(std::move(command));
+    if (sent_command_promise_ != nullptr) {
+      auto promise = std::move(sent_command_promise_);
+      sent_command_promise_.reset();
+      promise->set_value();
+    }
   }
 
   void sendAclData(hal::HciPacket data) override {
     outgoing_acl_.push_front(std::move(data));
+    if (sent_acl_promise_ != nullptr) {
+      auto promise = std::move(sent_acl_promise_);
+      sent_acl_promise_.reset();
+      promise->set_value();
+    }
   }
 
   void sendScoData(hal::HciPacket data) override {
@@ -87,17 +100,25 @@ class TestHciHal : public hal::HciHal {
     return outgoing_commands_.size();
   }
 
+  std::future<void> GetSentCommandFuture() {
+    ASSERT_LOG(sent_command_promise_ == nullptr, "Promises promises ... Only one at a time");
+    sent_command_promise_ = std::make_unique<std::promise<void>>();
+    return sent_command_promise_->get_future();
+  }
+
   CommandPacketView GetSentCommand() {
-    while (outgoing_commands_.size() == 0)
-      ;
     auto packetview = GetPacketView(std::move(outgoing_commands_.front()));
     outgoing_commands_.pop_front();
     return CommandPacketView::Create(packetview);
   }
 
+  std::future<void> GetSentAclFuture() {
+    ASSERT_LOG(sent_acl_promise_ == nullptr, "Promises promises ... Only one at a time");
+    sent_acl_promise_ = std::make_unique<std::promise<void>>();
+    return sent_acl_promise_->get_future();
+  }
+
   PacketView<kLittleEndian> GetSentAcl() {
-    while (outgoing_acl_.size() == 0)
-      ;
     auto packetview = GetPacketView(std::move(outgoing_acl_.front()));
     outgoing_acl_.pop_front();
     return packetview;
@@ -115,6 +136,8 @@ class TestHciHal : public hal::HciHal {
   std::list<hal::HciPacket> outgoing_commands_;
   std::list<hal::HciPacket> outgoing_acl_;
   std::list<hal::HciPacket> outgoing_sco_;
+  std::unique_ptr<std::promise<void>> sent_command_promise_;
+  std::unique_ptr<std::promise<void>> sent_acl_promise_;
 };
 
 const ModuleFactory TestHciHal::Factory = ModuleFactory([]() { return new TestHciHal(); });
@@ -141,21 +164,31 @@ class DependsOnHci : public Module {
     queue_end->RegisterEnqueue(GetHandler(), common::Bind(&DependsOnHci::handle_enqueue, common::Unretained(this)));
   }
 
+  std::future<void> GetReceivedEventFuture() {
+    ASSERT_LOG(event_promise_ == nullptr, "Promises promises ... Only one at a time");
+    event_promise_ = std::make_unique<std::promise<void>>();
+    return event_promise_->get_future();
+  }
+
   EventPacketView GetReceivedEvent() {
-    while (incoming_events_.size() == 0)
-      ;
     EventPacketView packetview = incoming_events_.front();
     incoming_events_.pop_front();
     return packetview;
   }
 
+  std::future<void> GetReceivedAclFuture() {
+    ASSERT_LOG(event_promise_ == nullptr, "Promises promises ... Only one at a time");
+    acl_promise_ = std::make_unique<std::promise<void>>();
+    return acl_promise_->get_future();
+  }
+
+  size_t GetNumReceivedAclPackets() {
+    return incoming_acl_packets_.size();
+  }
+
   AclPacketView GetReceivedAcl() {
-    auto queue_end = hci_->GetAclQueueEnd();
-    std::unique_ptr<AclPacketView> incoming_acl_ptr;
-    while (incoming_acl_ptr == nullptr) {
-      incoming_acl_ptr = queue_end->TryDequeue();
-    }
-    AclPacketView packetview = *incoming_acl_ptr;
+    AclPacketView packetview = incoming_acl_packets_.front();
+    incoming_acl_packets_.pop_front();
     return packetview;
   }
 
@@ -164,9 +197,13 @@ class DependsOnHci : public Module {
     hci_->RegisterEventHandler(EventCode::CONNECTION_COMPLETE,
                                common::Bind(&DependsOnHci::handle_event<EventPacketView>, common::Unretained(this)),
                                GetHandler());
+    hci_->GetAclQueueEnd()->RegisterDequeue(GetHandler(),
+                                            common::Bind(&DependsOnHci::handle_acl, common::Unretained(this)));
   }
 
-  void Stop() {}
+  void Stop() {
+    hci_->GetAclQueueEnd()->UnregisterDequeue();
+  }
 
   void ListDependencies(ModuleList* list) {
     list->add<HciLayer>();
@@ -177,10 +214,28 @@ class DependsOnHci : public Module {
  private:
   HciLayer* hci_ = nullptr;
   std::list<EventPacketView> incoming_events_;
+  std::list<AclPacketView> incoming_acl_packets_;
+  std::unique_ptr<std::promise<void>> event_promise_;
+  std::unique_ptr<std::promise<void>> acl_promise_;
+
+  void handle_acl() {
+    auto acl_ptr = hci_->GetAclQueueEnd()->TryDequeue();
+    incoming_acl_packets_.push_back(*acl_ptr);
+    if (acl_promise_ != nullptr) {
+      auto promise = std::move(acl_promise_);
+      acl_promise_.reset();
+      promise->set_value();
+    }
+  }
 
   template <typename T>
   void handle_event(T event) {
     incoming_events_.push_back(event);
+    if (event_promise_ != nullptr) {
+      auto promise = std::move(event_promise_);
+      event_promise_.reset();
+      promise->set_value();
+    }
   }
 
   std::queue<std::unique_ptr<AclPacketBuilder>> outgoing_acl_;
@@ -205,14 +260,18 @@ class HciTest : public ::testing::Test {
       counting_down_bytes.push_back(~i);
     }
     hal = new TestHciHal();
+
+    auto command_future = hal->GetSentCommandFuture();
+
     fake_registry_.InjectTestModule(&hal::HciHal::Factory, hal);
     fake_registry_.Start<DependsOnHci>(&fake_registry_.GetTestThread());
     hci = static_cast<HciLayer*>(fake_registry_.GetModuleUnderTest(&HciLayer::Factory));
     upper = static_cast<DependsOnHci*>(fake_registry_.GetModuleUnderTest(&DependsOnHci::Factory));
     ASSERT(fake_registry_.IsStarted<HciLayer>());
-    // Wait for the reset
-    while (hal->GetNumSentCommands() == 0)
-      ;
+
+    auto reset_sent_status = command_future.wait_for(kTimeout);
+    ASSERT_EQ(reset_sent_status, std::future_status::ready);
+
     // Verify that reset was received
     ASSERT_EQ(1, hal->GetNumSentCommands());
 
@@ -256,6 +315,7 @@ TEST_F(HciTest, noOpCredits) {
   uint8_t num_packets = 0;
   hal->callbacks->hciEventReceived(GetPacketBytes(NoCommandCompleteBuilder::Create(num_packets)));
 
+  auto command_future = hal->GetSentCommandFuture();
   upper->SendHciCommandExpectingComplete(ReadLocalVersionInformationBuilder::Create());
 
   // Verify that nothing was sent
@@ -263,11 +323,15 @@ TEST_F(HciTest, noOpCredits) {
 
   num_packets = 1;
   hal->callbacks->hciEventReceived(GetPacketBytes(NoCommandCompleteBuilder::Create(num_packets)));
+
+  auto command_sent_status = command_future.wait_for(kTimeout);
+  ASSERT_EQ(command_sent_status, std::future_status::ready);
+
   // Verify that one was sent
-  while (hal->GetNumSentCommands() == 0)
-    ;
   ASSERT_EQ(1, hal->GetNumSentCommands());
 
+  auto event_future = upper->GetReceivedEventFuture();
+
   // Send the response event
   ErrorCode error_code = ErrorCode::SUCCESS;
   HciVersion hci_version = HciVersion::V_5_0;
@@ -277,6 +341,11 @@ TEST_F(HciTest, noOpCredits) {
   uint16_t lmp_subversion = 0x5678;
   hal->callbacks->hciEventReceived(GetPacketBytes(ReadLocalVersionInformationCompleteBuilder::Create(
       num_packets, error_code, hci_version, hci_subversion, lmp_version, manufacturer_name, lmp_subversion)));
+
+  // Wait for the event
+  auto event_status = event_future.wait_for(kTimeout);
+  ASSERT_EQ(event_status, std::future_status::ready);
+
   auto event = upper->GetReceivedEvent();
   ASSERT(ReadLocalVersionInformationCompleteView::Create(CommandCompleteView::Create(EventPacketView::Create(event)))
              .IsValid());
@@ -285,13 +354,15 @@ TEST_F(HciTest, noOpCredits) {
 TEST_F(HciTest, creditsTest) {
   ASSERT_EQ(0, hal->GetNumSentCommands());
 
+  auto command_future = hal->GetSentCommandFuture();
+
   // Send all three commands
   upper->SendHciCommandExpectingComplete(ReadLocalVersionInformationBuilder::Create());
   upper->SendHciCommandExpectingComplete(ReadLocalSupportedCommandsBuilder::Create());
   upper->SendHciCommandExpectingComplete(ReadLocalSupportedFeaturesBuilder::Create());
 
-  while (hal->GetNumSentCommands() == 0)
-    ;
+  auto command_sent_status = command_future.wait_for(kTimeout);
+  ASSERT_EQ(command_sent_status, std::future_status::ready);
 
   // Verify that the first one is sent
   ASSERT_EQ(1, hal->GetNumSentCommands());
@@ -303,6 +374,9 @@ TEST_F(HciTest, creditsTest) {
   // Verify that only one was sent
   ASSERT_EQ(0, hal->GetNumSentCommands());
 
+  // Get a new future
+  auto event_future = upper->GetReceivedEventFuture();
+
   // Send the response event
   uint8_t num_packets = 1;
   ErrorCode error_code = ErrorCode::SUCCESS;
@@ -313,13 +387,18 @@ TEST_F(HciTest, creditsTest) {
   uint16_t lmp_subversion = 0x5678;
   hal->callbacks->hciEventReceived(GetPacketBytes(ReadLocalVersionInformationCompleteBuilder::Create(
       num_packets, error_code, hci_version, hci_subversion, lmp_version, manufacturer_name, lmp_subversion)));
+
+  // Wait for the event
+  auto event_status = event_future.wait_for(kTimeout);
+  ASSERT_EQ(event_status, std::future_status::ready);
+
   auto event = upper->GetReceivedEvent();
   ASSERT(ReadLocalVersionInformationCompleteView::Create(CommandCompleteView::Create(EventPacketView::Create(event)))
              .IsValid());
 
   // Verify that the second one is sent
-  while (hal->GetNumSentCommands() == 0)
-    ;
+  command_sent_status = command_future.wait_for(kTimeout);
+  ASSERT_EQ(command_sent_status, std::future_status::ready);
   ASSERT_EQ(1, hal->GetNumSentCommands());
 
   sent_command = hal->GetSentCommand();
@@ -328,6 +407,8 @@ TEST_F(HciTest, creditsTest) {
 
   // Verify that only one was sent
   ASSERT_EQ(0, hal->GetNumSentCommands());
+  event_future = upper->GetReceivedEventFuture();
+  command_future = hal->GetSentCommandFuture();
 
   // Send the response event
   std::array<uint8_t, 64> supported_commands;
@@ -336,13 +417,16 @@ TEST_F(HciTest, creditsTest) {
   }
   hal->callbacks->hciEventReceived(
       GetPacketBytes(ReadLocalSupportedCommandsCompleteBuilder::Create(num_packets, error_code, supported_commands)));
+  // Wait for the event
+  event_status = event_future.wait_for(kTimeout);
+  ASSERT_EQ(event_status, std::future_status::ready);
+
   event = upper->GetReceivedEvent();
   ASSERT(ReadLocalSupportedCommandsCompleteView::Create(CommandCompleteView::Create(EventPacketView::Create(event)))
              .IsValid());
-
   // Verify that the third one is sent
-  while (hal->GetNumSentCommands() == 0)
-    ;
+  command_sent_status = command_future.wait_for(kTimeout);
+  ASSERT_EQ(command_sent_status, std::future_status::ready);
   ASSERT_EQ(1, hal->GetNumSentCommands());
 
   sent_command = hal->GetSentCommand();
@@ -351,11 +435,16 @@ TEST_F(HciTest, creditsTest) {
 
   // Verify that only one was sent
   ASSERT_EQ(0, hal->GetNumSentCommands());
+  event_future = upper->GetReceivedEventFuture();
 
   // Send the response event
   uint64_t lmp_features = 0x012345678abcdef;
   hal->callbacks->hciEventReceived(
       GetPacketBytes(ReadLocalSupportedFeaturesCompleteBuilder::Create(num_packets, error_code, lmp_features)));
+
+  // Wait for the event
+  event_status = event_future.wait_for(kTimeout);
+  ASSERT_EQ(event_status, std::future_status::ready);
   event = upper->GetReceivedEvent();
   ASSERT(ReadLocalSupportedFeaturesCompleteView::Create(CommandCompleteView::Create(EventPacketView::Create(event)))
              .IsValid());
@@ -363,6 +452,7 @@ TEST_F(HciTest, creditsTest) {
 
 TEST_F(HciTest, createConnectionTest) {
   // Send CreateConnection to the controller
+  auto command_future = hal->GetSentCommandFuture();
   common::Address bd_addr;
   ASSERT_TRUE(common::Address::FromString("A1:A2:A3:A4:A5:A6", bd_addr));
   uint16_t packet_type = 0x1234;
@@ -373,6 +463,9 @@ TEST_F(HciTest, createConnectionTest) {
   upper->SendHciCommandExpectingStatus(CreateConnectionBuilder::Create(
       bd_addr, packet_type, page_scan_repetition_mode, clock_offset, clock_offset_valid, allow_role_switch));
 
+  auto command_sent_status = command_future.wait_for(kTimeout);
+  ASSERT_EQ(command_sent_status, std::future_status::ready);
+
   // Check the command
   auto sent_command = hal->GetSentCommand();
   ASSERT_LT(0, sent_command.size());
@@ -387,6 +480,7 @@ TEST_F(HciTest, createConnectionTest) {
   ASSERT_EQ(allow_role_switch, view.GetAllowRoleSwitch());
 
   // Send a Command Status to the host
+  auto event_future = upper->GetReceivedEventFuture();
   ErrorCode status = ErrorCode::SUCCESS;
   uint16_t handle = 0x123;
   LinkType link_type = LinkType::ACL;
@@ -394,15 +488,20 @@ TEST_F(HciTest, createConnectionTest) {
   hal->callbacks->hciEventReceived(GetPacketBytes(CreateConnectionStatusBuilder::Create(ErrorCode::SUCCESS, 1)));
 
   // Verify the event
+  auto event_status = event_future.wait_for(kTimeout);
+  ASSERT_EQ(event_status, std::future_status::ready);
   auto event = upper->GetReceivedEvent();
   ASSERT_TRUE(event.IsValid());
   ASSERT_EQ(EventCode::COMMAND_STATUS, event.GetEventCode());
 
   // Send a ConnectionComplete to the host
+  event_future = upper->GetReceivedEventFuture();
   hal->callbacks->hciEventReceived(
       GetPacketBytes(ConnectionCompleteBuilder::Create(status, handle, bd_addr, link_type, encryption_enabled)));
 
   // Verify the event
+  event_status = event_future.wait_for(kTimeout);
+  ASSERT_EQ(event_status, std::future_status::ready);
   event = upper->GetReceivedEvent();
   ASSERT_TRUE(event.IsValid());
   ASSERT_EQ(EventCode::CONNECTION_COMPLETE, event.GetEventCode());
@@ -419,10 +518,13 @@ TEST_F(HciTest, createConnectionTest) {
   auto acl_payload = std::make_unique<RawBuilder>();
   acl_payload->AddAddress(bd_addr);
   acl_payload->AddOctets2(handle);
+  auto incoming_acl_future = upper->GetReceivedAclFuture();
   hal->callbacks->aclDataReceived(
       GetPacketBytes(AclPacketBuilder::Create(handle, packet_boundary_flag, broadcast_flag, std::move(acl_payload))));
 
   // Verify the ACL packet
+  auto incoming_acl_status = incoming_acl_future.wait_for(kAclTimeout);
+  ASSERT_EQ(incoming_acl_status, std::future_status::ready);
   auto acl_view = upper->GetReceivedAcl();
   ASSERT_TRUE(acl_view.IsValid());
   ASSERT_EQ(sizeof(bd_addr) + sizeof(handle), acl_view.GetPayload().size());
@@ -436,9 +538,12 @@ TEST_F(HciTest, createConnectionTest) {
   auto acl_payload2 = std::make_unique<RawBuilder>();
   acl_payload2->AddOctets2(handle);
   acl_payload2->AddAddress(bd_addr);
+  auto sent_acl_future = hal->GetSentAclFuture();
   upper->SendAclData(AclPacketBuilder::Create(handle, packet_boundary_flag2, broadcast_flag2, std::move(acl_payload2)));
 
   // Verify the ACL packet
+  auto sent_acl_status = sent_acl_future.wait_for(kAclTimeout);
+  ASSERT_EQ(sent_acl_status, std::future_status::ready);
   auto sent_acl = hal->GetSentAcl();
   ASSERT_LT(0, sent_acl.size());
   AclPacketView sent_acl_view = AclPacketView::Create(sent_acl);
@@ -449,19 +554,64 @@ TEST_F(HciTest, createConnectionTest) {
   ASSERT_EQ(bd_addr, sent_itr.extract<Address>());
 }
 
-TEST_F(HciTest, receiveMultipleAclPacket) {
+TEST_F(HciTest, receiveMultipleAclPackets) {
   common::Address bd_addr;
   ASSERT_TRUE(common::Address::FromString("A1:A2:A3:A4:A5:A6", bd_addr));
   uint16_t handle = 0x0001;
+  uint16_t num_packets = 100;
   PacketBoundaryFlag packet_boundary_flag = PacketBoundaryFlag::COMPLETE_PDU;
   BroadcastFlag broadcast_flag = BroadcastFlag::POINT_TO_POINT;
-  for (int i = 0; i < 100; i++) {
+  for (uint16_t i = 0; i < num_packets; i++) {
     auto acl_payload = std::make_unique<RawBuilder>();
     acl_payload->AddAddress(bd_addr);
     acl_payload->AddOctets2(handle);
+    acl_payload->AddOctets2(i);
     hal->callbacks->aclDataReceived(
         GetPacketBytes(AclPacketBuilder::Create(handle, packet_boundary_flag, broadcast_flag, std::move(acl_payload))));
   }
+  auto incoming_acl_future = upper->GetReceivedAclFuture();
+  uint16_t received_packets = 0;
+  while (received_packets < num_packets - 1) {
+    auto incoming_acl_status = incoming_acl_future.wait_for(kAclTimeout);
+    // Get the next future.
+    incoming_acl_future = upper->GetReceivedAclFuture();
+    ASSERT_EQ(incoming_acl_status, std::future_status::ready);
+    size_t num_packets = upper->GetNumReceivedAclPackets();
+    for (size_t i = 0; i < num_packets; i++) {
+      auto acl_view = upper->GetReceivedAcl();
+      ASSERT_TRUE(acl_view.IsValid());
+      ASSERT_EQ(sizeof(bd_addr) + sizeof(handle) + sizeof(received_packets), acl_view.GetPayload().size());
+      auto itr = acl_view.GetPayload().begin();
+      ASSERT_EQ(bd_addr, itr.extract<Address>());
+      ASSERT_EQ(handle, itr.extract<uint16_t>());
+      ASSERT_EQ(received_packets, itr.extract<uint16_t>());
+      received_packets += 1;
+    }
+  }
+
+  // Check to see if this future was already fulfilled.
+  auto acl_race_status = incoming_acl_future.wait_for(std::chrono::milliseconds(1));
+  if (acl_race_status == std::future_status::ready) {
+    // Get the next future.
+    incoming_acl_future = upper->GetReceivedAclFuture();
+  }
+
+  // One last packet to make sure they were all sent.  Already got the future.
+  auto acl_payload = std::make_unique<RawBuilder>();
+  acl_payload->AddAddress(bd_addr);
+  acl_payload->AddOctets2(handle);
+  acl_payload->AddOctets2(num_packets);
+  hal->callbacks->aclDataReceived(
+      GetPacketBytes(AclPacketBuilder::Create(handle, packet_boundary_flag, broadcast_flag, std::move(acl_payload))));
+  auto incoming_acl_status = incoming_acl_future.wait_for(kAclTimeout);
+  ASSERT_EQ(incoming_acl_status, std::future_status::ready);
+  auto acl_view = upper->GetReceivedAcl();
+  ASSERT_TRUE(acl_view.IsValid());
+  ASSERT_EQ(sizeof(bd_addr) + sizeof(handle) + sizeof(received_packets), acl_view.GetPayload().size());
+  auto itr = acl_view.GetPayload().begin();
+  ASSERT_EQ(bd_addr, itr.extract<Address>());
+  ASSERT_EQ(handle, itr.extract<uint16_t>());
+  ASSERT_EQ(received_packets, itr.extract<uint16_t>());
 }
 }  // namespace hci
 }  // namespace bluetooth