OSDN Git Service

rusty-gd: remaining fixes to get DirectHciTest running consistently
authorZach Johnson <zachoverflow@google.com>
Mon, 8 Feb 2021 06:17:28 +0000 (22:17 -0800)
committerZach Johnson <zachoverflow@google.com>
Mon, 8 Feb 2021 06:22:44 +0000 (22:22 -0800)
Bug: 171749953
Tag: #gd-refactor
Test: gd/cert/run --rhost DirectHciTest
Change-Id: I05c6fca13d9af2705b4ae597f6a458963c0fd53f

15 files changed:
gd/Android.bp
gd/cert/logging_client_interceptor.py
gd/packet/parser/fields/fixed_scalar_field.cc
gd/packet/parser/fields/fixed_scalar_field.h
gd/packet/parser/fields/scalar_field.cc
gd/packet/parser/fields/vector_field.cc
gd/packet/parser/packet_def.cc
gd/packet/parser/parent_def.cc
gd/packet/parser/parent_def.h
gd/packet/parser/struct_def.cc
gd/rust/facade/helpers/lib.rs
gd/rust/hal/src/rootcanal_hal.rs
gd/rust/hal/src/snoop.rs
gd/rust/hci/src/facade.rs
gd/rust/hci/src/lib.rs

index d2daa7d..7459909 100644 (file)
@@ -517,6 +517,7 @@ rust_library {
         "libnum_traits",
         "libthiserror",
         "libbt_hci_custom_types",
+        "liblog_rust",
     ],
 }
 
index b74cdfc..5ab7149 100644 (file)
@@ -23,7 +23,7 @@ from google.protobuf import text_format
 
 def custom_message_formatter(m, ident, as_one_line):
     if m.DESCRIPTOR == common.Data.DESCRIPTOR:
-        return 'payload: (hex) "{}"'.format(m.payload.hex())
+        return 'payload: (hex) "{}"'.format(m.payload.hex(" "))
     return None
 
 
index e0ad2dd..c745b12 100644 (file)
@@ -37,3 +37,12 @@ void FixedScalarField::GenValue(std::ostream& s) const {
 void FixedScalarField::GenStringRepresentation(std::ostream& s, std::string) const {
   s << "+" << value_;
 }
+
+void FixedScalarField::GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const {
+  s << "let " << GetName() << ": " << GetRustDataType() << " = " << value_ << ";";
+  FixedField::GenRustWriter(s, start_offset, end_offset);
+}
+
+void FixedScalarField::GenRustGetter(std::ostream& s, Size start_offset, Size end_offset) const {
+  FixedField::GenRustGetter(s, start_offset, end_offset);
+}
index 44a0fa5..2304ba4 100644 (file)
@@ -39,6 +39,10 @@ class FixedScalarField : public FixedField {
 
   static const std::string field_type;
 
+  void GenRustGetter(std::ostream& s, Size start_offset, Size end_offset) const override;
+
+  void GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const override;
+
  private:
 
   const int64_t value_;
index 1c3c5a6..de5e31c 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "fields/scalar_field.h"
 
+#include "fields/fixed_scalar_field.h"
 #include "fields/size_field.h"
 #include "util.h"
 
@@ -165,12 +166,31 @@ void ScalarField::GenRustGetter(std::ostream& s, Size start_offset, Size end_off
   int num_leading_bits = GetRustBitOffset(s, start_offset, end_offset, GetSize());
 
   s << "let " << GetName() << " = ";
-  if (num_leading_bits == 0) {
-    s << GetRustParseDataType() << "::from_le_bytes(bytes[" << start_offset.bytes() << "..";
-    s << start_offset.bytes() + size.bytes() << "].try_into().unwrap());";
+  auto offset = num_leading_bits == 0 ? 0 : -1;
+  s << GetRustParseDataType() << "::from_le_bytes([";
+  int total_bytes;
+  if (size_ <= 8) {
+    total_bytes = 1;
+  } else if (size_ <= 16) {
+    total_bytes = 2;
+  } else if (size_ <= 32) {
+    total_bytes = 4;
   } else {
-    s << GetRustParseDataType() << "::from_le_bytes(bytes[" << start_offset.bytes() - 1 << "..";
-    s << start_offset.bytes() + size.bytes() - 1 << "].try_into().unwrap());";
+    total_bytes = 8;
+  }
+  for (int i = 0; i < total_bytes; i++) {
+    if (i > 0) {
+      s << ",";
+    }
+    if (i < size.bytes()) {
+      s << "bytes[" << start_offset.bytes() + i + offset << "]";
+    } else {
+      s << 0;
+    }
+  }
+  s << "]);";
+
+  if (num_leading_bits != 0) {
     s << "let " << GetName() << " = " << GetName() << " >> " << num_leading_bits << ";";
   }
 
@@ -195,8 +215,8 @@ void ScalarField::GenRustWriter(std::ostream& s, Size start_offset, Size end_off
   Size size = GetSize();
   int num_leading_bits = GetRustBitOffset(s, start_offset, end_offset, GetSize());
 
-  if (GetFieldType() == SizeField::kFieldType) {
-    // Do nothing, the field access has already happened in packet_def
+  if (GetFieldType() == SizeField::kFieldType || GetFieldType() == FixedScalarField::kFieldType) {
+    // Do nothing, the field access has already happened
   } else if (GetRustParseDataType() != GetRustDataType()) {
     // needs casting to primitive
     s << "let " << GetName() << " = self." << GetName() << ".to_" << GetRustParseDataType() << "().unwrap();";
@@ -222,11 +242,11 @@ void ScalarField::GenRustWriter(std::ostream& s, Size start_offset, Size end_off
       mask |= 1;
     }
     s << "let " << GetName() << " = (" << GetName() << " << " << num_leading_bits << ") | ("
-      << "(buffer[" << start_offset.bytes() << "] as " << GetRustParseDataType() << ") & 0x" << std::hex << mask
-      << std::dec << ");";
+      << "(buffer[" << start_offset.bytes() + access_offset << "] as " << GetRustParseDataType() << ") & 0x" << std::hex
+      << mask << std::dec << ");";
   }
 
   s << "buffer[" << start_offset.bytes() + access_offset << ".."
     << start_offset.bytes() + GetSize().bytes() + access_offset << "].copy_from_slice(&" << GetName()
-    << ".to_le_bytes());";
+    << ".to_le_bytes()[0.." << size.bytes() << "]);";
 }
index 9870eae..ffa0f57 100644 (file)
@@ -262,14 +262,21 @@ void VectorField::GenRustGetter(std::ostream& s, Size start_offset, Size) const
   if (element_field_type == ScalarField::kFieldType) {
     s << "let " << GetName() << ": " << GetRustDataType() << " = ";
     if (size_field_ == nullptr) {
-      s << "bytes[" << start_offset.bytes() << "..].to_vec().chunks_exact(";
+      s << "bytes[" << start_offset.bytes() << "..]";
+    } else if (size_field_->GetFieldType() == CountField::kFieldType) {
+      s << "bytes[" << start_offset.bytes() << ".." << start_offset.bytes() << " + ((";
+      s << size_field_->GetName() << " as usize) * " << element_size << ")]";
     } else {
       s << "bytes[" << start_offset.bytes() << "..(";
       s << start_offset.bytes() << " + " << size_field_->GetName();
-      s << " as usize)].to_vec().chunks_exact(";
+      s << " as usize)";
+      if (GetSizeModifier() != "") {
+        s << " - ((" << GetSizeModifier().substr(1) << ") / 8)";
+      }
+      s << "]";
     }
 
-    s << element_size << ").into_iter().map(|i| ";
+    s << ".to_vec().chunks_exact(" << element_size << ").into_iter().map(|i| ";
     s << element_field->GetRustDataType() << "::from_le_bytes([";
 
     for (int j=0; j < element_size; j++) {
@@ -280,37 +287,42 @@ void VectorField::GenRustGetter(std::ostream& s, Size start_offset, Size) const
     }
     s << "])).collect();";
   } else {
-        s << "let " << GetName() << ": " << GetRustDataType() << " = ";
-        if (size_field_ == nullptr) {
-         s << "bytes[" << start_offset.bytes() << "..].to_vec().chunks_exact(";
-       } else {
-          s << "bytes[" << start_offset.bytes() << "..(";
-          s << start_offset.bytes() << " + (" << size_field_->GetName() << " as usize * ";
-          s << GetElementField()->GetSize().bytes() << "))].to_vec().chunks_exact(";
-       }
-
-        s << element_size << ").into_iter().map(|i| ";
-        s << element_field->GetRustDataType() << "::parse(&[";
-
-        for (int j=0; j < element_size; j++) {
-          s << "i[" << j << "]";
-         if (j != element_size - 1) {
-            s << ", ";
-         }
-        }
-        s << "]).unwrap()).collect();";
+    s << "let mut " << GetName() << ": " << GetRustDataType() << " = Vec::new();";
+    if (size_field_ == nullptr) {
+      s << "let mut parsable_ = &bytes[" << start_offset.bytes() << "..];";
+      s << "while parsable_.len() > 0 {";
+    } else if (size_field_->GetFieldType() == CountField::kFieldType) {
+      s << "let mut parsable_ = &bytes[" << start_offset.bytes() << "..];";
+      s << "let count_ = " << size_field_->GetName() << " as usize;";
+      s << "for _ in 0..count_ {";
+    } else {
+      s << "let mut parsable_ = &bytes[" << start_offset.bytes() << ".." << start_offset.bytes() << " + ("
+        << size_field_->GetName() << " as usize)";
+      if (GetSizeModifier() != "") {
+        s << " - ((" << GetSizeModifier().substr(1) << ") / 8)";
+      }
+      s << "];";
+      s << "while parsable_.len() > 0 {";
+    }
+    s << " let parsed = " << element_field->GetRustDataType() << "::parse(&parsable_)?;";
+    s << " parsable_ = &parsable_[parsed.get_total_size()..];";
+    s << GetName() << ".push(parsed);";
+    s << "}";
   }
 }
 
 void VectorField::GenRustWriter(std::ostream& s, Size start_offset, Size) const {
-  s << "for (i, e) in self." << GetName() << ".iter().enumerate() {";
   if (GetElementField()->GetFieldType() == ScalarField::kFieldType) {
+    s << "for (i, e) in self." << GetName() << ".iter().enumerate() {";
     s << "buffer[" << start_offset.bytes() << "+i..";
     s << start_offset.bytes() << "+i+" << GetElementField()->GetSize().bytes() << "]";
     s << ".copy_from_slice(&e.to_le_bytes())";
+    s << "}";
   } else {
-    s << "self." << GetName() << "[i].write_to(&mut buffer[" << start_offset.bytes() << "+i..";
-    s << start_offset.bytes() << "+i+" << GetElementField()->GetSize().bytes() << "]);";
+    s << "let mut vec_buffer_ = &mut buffer[" << start_offset.bytes() << "..];";
+    s << "for e_ in &self." << GetName() << " {";
+    s << " e_.write_to(&mut vec_buffer_[0..e_.get_total_size()]);";
+    s << " vec_buffer_ = &mut vec_buffer_[e_.get_total_size()..];";
+    s << "}";
   }
-  s << "}";
 }
index 32b7be6..3d97961 100644 (file)
@@ -869,6 +869,25 @@ void PacketDef::GenRustStructFieldNames(std::ostream& s) const {
 void PacketDef::GenRustStructImpls(std::ostream& s) const {
   s << "impl " << name_ << "Data {";
 
+  // conforms function
+  s << "fn conforms(bytes: &[u8]) -> bool {";
+  GenRustConformanceCheck(s);
+
+  auto fields = fields_.GetFieldsWithTypes({
+      StructField::kFieldType,
+  });
+
+  for (auto const& field : fields) {
+    auto start_offset = GetOffsetForField(field->GetName(), false);
+    auto end_offset = GetOffsetForField(field->GetName(), true);
+
+    s << "if !" << field->GetRustDataType() << "::conforms(&bytes[" << start_offset.bytes();
+    s << ".." << start_offset.bytes() + field->GetSize().bytes() << "]) { return false; }";
+  }
+
+  s << " true";
+  s << "}";
+
   // parse function
   if (parent_constraints_.empty() && !children_.empty() && parent_ != nullptr) {
       auto constraint = FindConstraintField();
@@ -879,9 +898,8 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
   } else {
     s << "fn parse(bytes: &[u8]) -> Result<Self> {";
   }
-  auto fields = fields_.GetFieldsWithoutTypes({
+  fields = fields_.GetFieldsWithoutTypes({
       BodyField::kFieldType,
-      FixedScalarField::kFieldType,
   });
 
   for (auto const& field : fields) {
@@ -928,6 +946,7 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
         auto enum_variant = enum_type + "::"
             + util::UnderscoreToCamelCase(util::ToLowerCase(variant_name));
         s << enum_variant;
+        s << " if " << desc_path[0]->name_ << "Data::conforms(&bytes[..])";
         s << " => {";
         s << name_ << "DataChild::";
         s << desc_path[0]->name_ << "(Arc::new(";
index 682ea4b..a628e03 100644 (file)
@@ -570,12 +570,28 @@ bool ParentDef::HasChildEnums() const {
   return !children_.empty() || fields_.HasPayload();
 }
 
+void ParentDef::GenRustConformanceCheck(std::ostream& s) const {
+  auto fields = fields_.GetFieldsWithTypes({
+      FixedScalarField::kFieldType,
+  });
+
+  for (auto const& field : fields) {
+    auto start_offset = GetOffsetForField(field->GetName(), false);
+    auto end_offset = GetOffsetForField(field->GetName(), true);
+
+    auto f = (FixedScalarField*)field;
+    f->GenRustGetter(s, start_offset, end_offset);
+    s << "if " << f->GetName() << " != ";
+    f->GenValue(s);
+    s << " { return false; } ";
+  }
+}
+
 void ParentDef::GenRustWriteToFields(std::ostream& s) const {
   auto fields = fields_.GetFieldsWithoutTypes({
       BodyField::kFieldType,
       PaddingField::kFieldType,
       ReservedField::kFieldType,
-      FixedScalarField::kFieldType,
   });
 
   for (auto const& field : fields) {
index 2e923a4..1acca00 100644 (file)
@@ -91,4 +91,6 @@ class ParentDef : public TypeDef {
   void GenRustWriteToFields(std::ostream& s) const;
 
   void GenSizeRetVal(std::ostream& s) const;
+
+  void GenRustConformanceCheck(std::ostream& s) const;
 };
index 929f3a6..3b93a7b 100644 (file)
@@ -364,6 +364,12 @@ void StructDef::GenRustDeclarations(std::ostream& s) const {
 
 void StructDef::GenRustImpls(std::ostream& s) const {
   s << "impl " << name_ << "{";
+
+  s << "fn conforms(bytes: &[u8]) -> bool {";
+  GenRustConformanceCheck(s);
+  s << " true";
+  s << "}";
+
   s << "pub fn parse(bytes: &[u8]) -> Result<Self> {";
   auto fields = fields_.GetFieldsWithoutTypes({
       BodyField::kFieldType,
index fb8b8c8..2d98776 100644 (file)
@@ -43,7 +43,9 @@ impl<T: 'static + Packet + Send> RxAdapter<T> {
             while let Some(payload) = clone_rx.lock().await.recv().await {
                 let mut data = Data::default();
                 data.set_payload(payload.to_vec());
-                sink.send((data, WriteFlags::default())).await.unwrap();
+                if let Err(e) = sink.send((data, WriteFlags::default())).await {
+                    log::error!("failure sending data: {:?}", e);
+                }
             }
         });
     }
index f6cddfc..19c5835 100644 (file)
@@ -93,16 +93,24 @@ where
             payload.resize(len, 0);
             reader.read_exact(&mut payload).await?;
             buffer.unsplit(payload);
-            evt_tx.send(EventPacket::parse(&buffer.freeze()).unwrap()).unwrap();
+            let frozen = buffer.freeze();
+            match EventPacket::parse(&frozen) {
+                Ok(p) => evt_tx.send(p).unwrap(),
+                Err(e) => log::error!("dropping invalid event packet: {}: {:02x}", e, frozen),
+            }
         } else if buffer[0] == HciPacketType::Acl as u8 {
             buffer.resize(HciPacketHeaderSize::Acl as usize, 0);
             reader.read_exact(&mut buffer).await?;
             let len: usize = (buffer[2] as u16 + ((buffer[3] as u16) << 8)).into();
-            let mut payload = buffer.split_off(HciPacketHeaderSize::Event as usize);
+            let mut payload = buffer.split_off(HciPacketHeaderSize::Acl as usize);
             payload.resize(len, 0);
             reader.read_exact(&mut payload).await?;
             buffer.unsplit(payload);
-            acl_tx.send(AclPacket::parse(&buffer.freeze()).unwrap()).unwrap();
+            let frozen = buffer.freeze();
+            match AclPacket::parse(&frozen) {
+                Ok(p) => acl_tx.send(p).unwrap(),
+                Err(e) => log::error!("dropping invalid ACL packet: {}: {:02x}", e, frozen),
+            }
         }
     }
 }
index 88b2d54..7fd89b3 100644 (file)
@@ -131,19 +131,31 @@ async fn provide_snooped_hal(config: SnoopConfig, raw_hal: RawHal, rt: Arc<Runti
         loop {
             select! {
                 Some(evt) = consume(&raw_hal.evt_rx) => {
-                    evt_up_tx.send(evt.clone()).await.unwrap();
+                    if let Err(e) = evt_up_tx.send(evt.clone()).await {
+                        error!("evt channel closed {:?}", e);
+                        break;
+                    }
                     logger.log(Type::Evt, Direction::Up, evt.to_bytes()).await;
                 },
                 Some(cmd) = cmd_down_rx.recv() => {
-                    raw_hal.cmd_tx.send(cmd.clone()).unwrap();
+                    if let Err(e) = raw_hal.cmd_tx.send(cmd.clone())  {
+                        error!("cmd channel closed {:?}", e);
+                        break;
+                    }
                     logger.log(Type::Cmd, Direction::Down, cmd.to_bytes()).await;
                 },
                 Some(acl) = acl_down_rx.recv() => {
-                    raw_hal.acl_tx.send(acl.clone()).unwrap();
+                    if let Err(e) = raw_hal.acl_tx.send(acl.clone()) {
+                        error!("acl down channel closed {:?}", e);
+                        break;
+                    }
                     logger.log(Type::Acl, Direction::Down, acl.to_bytes()).await;
                 },
                 Some(acl) = consume(&raw_hal.acl_rx) => {
-                    acl_up_tx.send(acl.clone()).await.unwrap();
+                    if let Err(e) = acl_up_tx.send(acl.clone()).await {
+                        error!("acl up channel closed {:?}", e);
+                        break;
+                    }
                     logger.log(Type::Acl, Direction::Up, acl.to_bytes()).await;
                 },
                 else => break,
index 141e2d1..6d1ae44 100644 (file)
@@ -81,9 +81,11 @@ impl HciFacade for HciFacadeService {
     fn send_command(&mut self, ctx: RpcContext<'_>, mut data: Data, sink: UnarySink<Empty>) {
         let packet = CommandPacket::parse(&data.take_payload()).unwrap();
         let mut commands = self.commands.clone();
+        let evt_tx = self.evt_tx.clone();
         ctx.spawn(async move {
             sink.success(Empty::default()).await.unwrap();
-            commands.send(packet).await.unwrap();
+            let response = commands.send(packet).await.unwrap();
+            evt_tx.send(response).await.unwrap();
         });
     }
 
index 98d7d36..a8af09e 100644 (file)
@@ -22,6 +22,7 @@ use bt_packets::hci::{
 };
 use error::Result;
 use gddi::{module, part_out, provides, Stoppable};
+use log::error;
 use std::collections::HashMap;
 use std::sync::Arc;
 use std::time::Duration;
@@ -178,7 +179,11 @@ async fn dispatch(
                         hci_timeout.cancel();
                         let this_opcode = evt.get_command_op_code();
                         match pending.take() {
-                            Some(QueuedCommand{cmd, fut}) if cmd.get_op_code() == this_opcode  => fut.send(evt.into()).unwrap(),
+                            Some(QueuedCommand{cmd, fut}) if cmd.get_op_code() == this_opcode => {
+                                if let Err(e) = fut.send(evt.into()) {
+                                    error!("failure dispatching command status {:?}", e);
+                                }
+                            },
                             Some(QueuedCommand{cmd, ..}) => panic!("Waiting for {:?}, got {:?}", cmd.get_op_code(), this_opcode),
                             None => panic!("Unexpected status event with opcode {:?}", this_opcode),
                         }
@@ -187,7 +192,11 @@ async fn dispatch(
                         hci_timeout.cancel();
                         let this_opcode = evt.get_command_op_code();
                         match pending.take() {
-                            Some(QueuedCommand{cmd, fut}) if cmd.get_op_code() == this_opcode  => fut.send(evt.into()).unwrap(),
+                            Some(QueuedCommand{cmd, fut}) if cmd.get_op_code() == this_opcode => {
+                                if let Err(e) = fut.send(evt.into()) {
+                                    error!("failure dispatching command complete {:?}", e);
+                                }
+                            },
                             Some(QueuedCommand{cmd, ..}) => panic!("Waiting for {:?}, got {:?}", cmd.get_op_code(), this_opcode),
                             None => panic!("Unexpected complete event with opcode {:?}", this_opcode),
                         }
@@ -195,7 +204,11 @@ async fn dispatch(
                     LeMetaEvent(evt) => {
                         let code = evt.get_subevent_code();
                         match le_evt_handlers.lock().await.get(&code) {
-                            Some(sender) => sender.send(evt).await.unwrap(),
+                            Some(sender) => {
+                                if let Err(e) = sender.send(evt).await {
+                                    error!("le meta event channel closed {:?}", e);
+                                }
+                            },
                             None => panic!("Unhandled le subevent {:?}", code),
                         }
                     },
@@ -205,14 +218,21 @@ async fn dispatch(
                     _ => {
                         let code = evt.get_event_code();
                         match evt_handlers.lock().await.get(&code) {
-                            Some(sender) => sender.send(evt).await.unwrap(),
-                            None => panic!("Unhandled le subevent {:?}", code),
+                            Some(sender) => {
+                                if let Err(e) = sender.send(evt).await {
+                                    error!("hci event channel closed {:?}", e);
+                                }
+                            },
+                            None if code == EventCode::NumberOfCompletedPackets =>{},
+                            None => panic!("Unhandled subevent {:?}", code),
                         }
                     },
                 }
             },
             Some(queued) = cmd_rx.recv(), if pending.is_none() => {
-                cmd_tx.send(queued.cmd.clone()).await.unwrap();
+                if let Err(e) = cmd_tx.send(queued.cmd.clone()).await {
+                    error!("command queue closed: {:?}", e);
+                }
                 hci_timeout.reset(Duration::from_secs(2));
                 pending = Some(queued);
             },