OSDN Git Service

SecurityRecordDatabase
authorJakub Pawlowski <jpawlowski@google.com>
Tue, 11 Feb 2020 17:42:08 +0000 (18:42 +0100)
committerJakub Pawlowski <jpawlowski@google.com>
Tue, 11 Feb 2020 19:54:10 +0000 (20:54 +0100)
Store SecurityRecords directly in vector, rather than shared_ptr
Move management of SecurityRecord into separate unit -
SecurityRecordDatabase.

Bug: 142341141
Change-Id: I0cc2dd8a7ddcf5a01117f0ebf7bd68111a93a2c5

gd/security/internal/security_manager_impl.cc
gd/security/internal/security_manager_impl.h
gd/security/record/security_record.h
gd/security/security_record_database.h [new file with mode: 0644]

index a951911..9b8d458 100644 (file)
@@ -28,46 +28,29 @@ namespace bluetooth {
 namespace security {
 namespace internal {
 
-std::shared_ptr<bluetooth::security::record::SecurityRecord> SecurityManagerImpl::CreateSecurityRecord(
-    hci::Address address) {
-  hci::AddressWithType device(address, hci::AddressType::PUBLIC_DEVICE_ADDRESS);
-  // Security record check
-  auto entry = security_record_map_.find(device.GetAddress());
-  if (entry == security_record_map_.end()) {
-    LOG_INFO("No security record for device: %s ", device.ToString().c_str());
-    // Create one
-    std::shared_ptr<security::record::SecurityRecord> record =
-        std::make_shared<security::record::SecurityRecord>(device);
-    auto new_entry = std::pair<hci::Address, std::shared_ptr<security::record::SecurityRecord>>(
-        record->GetPseudoAddress().GetAddress(), record);
-    // Keep track of it
-    security_record_map_.insert(new_entry);
-    return record;
-  }
-  return entry->second;
-}
-
-void SecurityManagerImpl::DispatchPairingHandler(std::shared_ptr<security::record::SecurityRecord> record,
-                                                 bool locally_initiated) {
+void SecurityManagerImpl::DispatchPairingHandler(record::SecurityRecord& record, bool locally_initiated) {
   common::OnceCallback<void(hci::Address, PairingResultOrFailure)> callback =
       common::BindOnce(&SecurityManagerImpl::OnPairingHandlerComplete, common::Unretained(this));
-  auto entry = pairing_handler_map_.find(record->GetPseudoAddress().GetAddress());
+  auto entry = pairing_handler_map_.find(record.GetPseudoAddress().GetAddress());
   if (entry != pairing_handler_map_.end()) {
     LOG_WARN("Device already has a pairing handler, and is in the middle of pairing!");
     return;
   }
   std::shared_ptr<pairing::PairingHandler> pairing_handler = nullptr;
-  switch (record->GetPseudoAddress().GetAddressType()) {
-    case hci::AddressType::PUBLIC_DEVICE_ADDRESS:
+  switch (record.GetPseudoAddress().GetAddressType()) {
+    case hci::AddressType::PUBLIC_DEVICE_ADDRESS: {
+      std::shared_ptr<record::SecurityRecord> record_copy =
+          std::make_shared<record::SecurityRecord>(record.GetPseudoAddress());
       pairing_handler = std::make_shared<security::pairing::ClassicPairingHandler>(
-          l2cap_classic_module_->GetFixedChannelManager(), security_manager_channel_, record, security_handler_,
+          l2cap_classic_module_->GetFixedChannelManager(), security_manager_channel_, record_copy, security_handler_,
           std::move(callback), listeners_);
       break;
+    }
     default:
-      ASSERT_LOG(false, "Pairing type %hhu not implemented!", record->GetPseudoAddress().GetAddressType());
+      ASSERT_LOG(false, "Pairing type %hhu not implemented!", record.GetPseudoAddress().GetAddressType());
   }
   auto new_entry = std::pair<hci::Address, std::shared_ptr<pairing::PairingHandler>>(
-      record->GetPseudoAddress().GetAddress(), pairing_handler);
+      record.GetPseudoAddress().GetAddress(), pairing_handler);
   pairing_handler_map_.insert(std::move(new_entry));
   pairing_handler->Initiate(locally_initiated, pairing::kDefaultIoCapability, pairing::kDefaultOobDataPresent,
                             pairing::kDefaultAuthenticationRequirements);
@@ -81,8 +64,8 @@ void SecurityManagerImpl::Init() {
 }
 
 void SecurityManagerImpl::CreateBond(hci::AddressWithType device) {
-  auto record = CreateSecurityRecord(device.GetAddress());
-  if (record->IsBonded()) {
+  record::SecurityRecord& record = security_database_.FindOrCreate(device);
+  if (record.IsBonded()) {
     NotifyDeviceBonded(device);
   } else {
     // Dispatch pairing handler, if we are calling create we are the initiator
@@ -106,10 +89,7 @@ void SecurityManagerImpl::CancelBond(hci::AddressWithType device) {
 
 void SecurityManagerImpl::RemoveBond(hci::AddressWithType device) {
   CancelBond(device);
-  auto entry = security_record_map_.find(device.GetAddress());
-  if (entry != security_record_map_.end()) {
-    security_record_map_.erase(entry);
-  }
+  security_database_.Remove(device);
   // Signal disconnect
   // Remove security record
   // Signal Remove from database
@@ -168,7 +148,8 @@ void SecurityManagerImpl::HandleEvent(T packet) {
     auto event = hci::EventPacketView::Create(std::move(packet));
     ASSERT_LOG(event.IsValid(), "Received invalid packet");
     const hci::EventCode code = event.GetEventCode();
-    auto record = CreateSecurityRecord(bd_addr);
+    auto record =
+        security_database_.FindOrCreate(hci::AddressWithType{bd_addr, hci::AddressType::PUBLIC_DEVICE_ADDRESS});
     switch (code) {
       case hci::EventCode::LINK_KEY_REQUEST:
         DispatchPairingHandler(record, true);
index 1139e44..0385644 100644 (file)
@@ -26,6 +26,7 @@
 #include "security/channel/security_manager_channel.h"
 #include "security/pairing/classic_pairing_handler.h"
 #include "security/record/security_record.h"
+#include "security/security_record_database.h"
 
 namespace bluetooth {
 namespace security {
@@ -122,8 +123,7 @@ class SecurityManagerImpl : public channel::ISecurityManagerChannelListener {
   template <class T>
   void HandleEvent(T packet);
 
-  std::shared_ptr<record::SecurityRecord> CreateSecurityRecord(hci::Address address);
-  void DispatchPairingHandler(std::shared_ptr<record::SecurityRecord> record, bool locally_initiated);
+  void DispatchPairingHandler(record::SecurityRecord& record, bool locally_initiated);
   void OnL2capRegistrationCompleteLe(l2cap::le::FixedChannelManager::RegistrationResult result,
                                      std::unique_ptr<l2cap::le::FixedChannelService> le_smp_service);
   void OnConnectionOpenLe(std::unique_ptr<l2cap::le::FixedChannel> channel);
@@ -137,7 +137,7 @@ class SecurityManagerImpl : public channel::ISecurityManagerChannelListener {
   std::unique_ptr<l2cap::le::FixedChannelManager> l2cap_manager_le_;
   hci::LeSecurityInterface* hci_security_interface_le_ __attribute__((unused));
   channel::SecurityManagerChannel* security_manager_channel_;
-  std::unordered_map<hci::Address, std::shared_ptr<record::SecurityRecord>> security_record_map_;
+  SecurityRecordDatabase security_database_;
   std::unordered_map<hci::Address, std::shared_ptr<pairing::PairingHandler>> pairing_handler_map_;
 };
 }  // namespace internal
index 137c6d7..efe55f8 100644 (file)
@@ -41,6 +41,8 @@ class SecurityRecord {
  public:
   explicit SecurityRecord(hci::AddressWithType address) : pseudo_address_(address), state_(PAIRING) {}
 
+  SecurityRecord& operator=(const SecurityRecord& other) = default;
+
   /**
    * Returns true if Link Keys are stored persistently
    */
@@ -72,15 +74,16 @@ class SecurityRecord {
 
  private:
   /* First address we have ever seen this device with, that we used to create bond */
-  const hci::AddressWithType pseudo_address_;
-
-  /* Identity Address */
-  std::optional<hci::AddressWithType> identity_address_;
+  hci::AddressWithType pseudo_address_;
 
   BondState state_;
   std::array<uint8_t, 16> link_key_ = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0};
   hci::KeyType key_type_ = hci::KeyType::DEBUG_COMBINATION;
 
+ public:
+  /* Identity Address */
+  std::optional<hci::AddressWithType> identity_address_;
+
   std::optional<crypto_toolbox::Octet16> ltk;
   std::optional<uint16_t> ediv;
   std::optional<std::array<uint8_t, 8>> rand;
diff --git a/gd/security/security_record_database.h b/gd/security/security_record_database.h
new file mode 100644 (file)
index 0000000..2360475
--- /dev/null
@@ -0,0 +1,63 @@
+/*
+ * Copyright 2019 The Android Open Source Project
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *      http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+#include "security/record/security_record.h"
+
+namespace bluetooth {
+namespace security {
+
+class SecurityRecordDatabase {
+ public:
+  using iterator = std::vector<record::SecurityRecord>::iterator;
+
+  record::SecurityRecord& FindOrCreate(hci::AddressWithType address) {
+    auto it = Find(address);
+    // Security record check
+    if (it != records_.end()) return *it;
+
+    // No security record, create one
+    records_.emplace_back(address);
+    return records_.back();
+  }
+
+  void Remove(const hci::AddressWithType& address) {
+    auto it = Find(address);
+
+    // No record exists
+    if (it == records_.end()) return;
+
+    record::SecurityRecord& last = records_.back();
+    *it = std::move(last);
+    records_.pop_back();
+  }
+
+  iterator Find(hci::AddressWithType address) {
+    for (auto it = records_.begin(); it != records_.end(); ++it) {
+      record::SecurityRecord& record = *it;
+      if (record.identity_address_.has_value() && record.identity_address_.value() == address) return it;
+      if (record.GetPseudoAddress() == address) return it;
+      if (record.irk.has_value() && address.IsRpaThatMatchesIrk(record.irk.value())) return it;
+    }
+    return records_.end();
+  }
+
+  std::vector<record::SecurityRecord> records_;
+};
+
+}  // namespace security
+}  // namespace bluetooth
\ No newline at end of file