OSDN Git Service

rusty-gd: generate bounds checks
authorZach Johnson <zachoverflow@google.com>
Wed, 10 Feb 2021 04:42:38 +0000 (20:42 -0800)
committerZach Johnson <zachoverflow@google.com>
Wed, 10 Feb 2021 05:25:47 +0000 (21:25 -0800)
Bug: 171749953
Tag: #gd-refactor
Test: gd/cert/run --rhost DirectHciTest
Change-Id: I2d349518e466853a65fbbb296a0ea69e34a4a34d

13 files changed:
gd/packet/parser/fields/body_field.h
gd/packet/parser/fields/packet_field.cc
gd/packet/parser/fields/packet_field.h
gd/packet/parser/fields/padding_field.h
gd/packet/parser/fields/payload_field.h
gd/packet/parser/fields/reserved_field.h
gd/packet/parser/fields/struct_field.cc
gd/packet/parser/fields/struct_field.h
gd/packet/parser/fields/vector_field.cc
gd/packet/parser/fields/vector_field.h
gd/packet/parser/gen_rust.cc
gd/packet/parser/packet_def.cc
gd/packet/parser/struct_def.cc

index cd0e1c7..c31af74 100644 (file)
@@ -58,6 +58,7 @@ class BodyField : public PacketField {
 
   void GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const override;
 
+  void GenBoundsCheck(std::ostream&, Size, Size, std::string) const override{};
   // Body fields can only be dynamically sized.
   const SizeField* size_field_{nullptr};
 };
index 8b1bbbb..f42320a 100644 (file)
@@ -130,3 +130,14 @@ bool PacketField::GenRustNameAndType(std::ostream& s) const {
   s << GetName() << ": " << param_type;
   return true;
 }
+
+void PacketField::GenBoundsCheck(std::ostream& s, Size start_offset, Size, std::string context) const {
+  Size size = GetSize();
+  s << "if bytes.len() < " << start_offset.bytes() + size.bytes() << " {";
+  s << " return Err(Error::InvalidLengthError{";
+  s << "    obj: \"" << context << "\".to_string(),";
+  s << "    field: \"" << GetName() << "\".to_string(),";
+  s << "    wanted: " << start_offset.bytes() + size.bytes() << ",";
+  s << "    got: bytes.len()});";
+  s << "}";
+}
index 75caa76..1285362 100644 (file)
@@ -128,6 +128,8 @@ class PacketField : public Loggable {
 
   virtual void GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const = 0;
 
+  virtual void GenBoundsCheck(std::ostream& s, Size start_offset, Size, std::string) const;
+
   virtual bool GetterIsByRef() const {
     return true;
   }
index 81903f5..ea40ccd 100644 (file)
@@ -57,6 +57,8 @@ class PaddingField : public PacketField {
 
   void GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const override;
 
+  void GenBoundsCheck(std::ostream&, Size, Size, std::string) const override{};
+
  private:
   Size size_;
 };
index 770708f..f86f275 100644 (file)
@@ -62,6 +62,8 @@ class PayloadField : public PacketField {
 
   void GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const override;
 
+  void GenBoundsCheck(std::ostream&, Size, Size, std::string) const override{};
+
   // Payload fields can only be dynamically sized.
   const SizeField* size_field_;
   // Only used if the size of the payload is based on another field.
index 19b5041..7fabe49 100644 (file)
@@ -53,6 +53,8 @@ class ReservedField : public PacketField {
 
   void GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const override;
 
+  void GenBoundsCheck(std::ostream&, Size, Size, std::string) const override{};
+
  private:
   std::string name_;
   int size_;
index 47b457a..75ad1ab 100644 (file)
@@ -92,6 +92,10 @@ std::string StructField::GetRustDataType() const {
   return GetDataType();
 }
 
+void StructField::GenBoundsCheck(std::ostream&, Size, Size, std::string) const {
+  // implicitly checked by the struct parser
+}
+
 void StructField::GenRustGetter(std::ostream& s, Size start_offset, Size) const {
   s << "let " << GetName() << " = ";
   s << GetRustDataType() << "::parse(&bytes[" << start_offset.bytes() << "..";
index cfa13ef..acc5fb4 100644 (file)
@@ -57,6 +57,8 @@ class StructField : public PacketField {
 
   void GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const override;
 
+  void GenBoundsCheck(std::ostream& s, Size start_offset, Size end_offset, std::string) const override;
+
  private:
   std::string type_name_;
 
index ffa0f57..9c8ca07 100644 (file)
@@ -254,6 +254,53 @@ std::string VectorField::GetRustDataType() const {
   return "Vec::<" + element_field_->GetRustDataType() + ">";
 }
 
+void VectorField::GenBoundsCheck(std::ostream& s, Size start_offset, Size, std::string context) const {
+  auto element_field_type = GetElementField()->GetFieldType();
+  auto element_field = GetElementField();
+  auto element_size = element_field->GetSize().bytes();
+
+  if (element_field_type == ScalarField::kFieldType) {
+    if (size_field_ == nullptr) {
+      s << "let rem_ = (bytes.len() - " << start_offset.bytes() << ") % " << element_size << ";";
+      s << "if rem_ != 0 {";
+      s << " return Err(Error::InvalidLengthError{";
+      s << "    obj: \"" << context << "\".to_string(),";
+      s << "    field: \"" << GetName() << "\".to_string(),";
+      s << "    wanted: bytes.len() + rem_,";
+      s << "    got: bytes.len()});";
+      s << "}";
+    } else if (size_field_->GetFieldType() == CountField::kFieldType) {
+      s << "let want_ = " << start_offset.bytes() << " + ((" << size_field_->GetName() << " as usize) * "
+        << element_size << ");";
+      s << "if bytes.len() < want_ {";
+      s << " return Err(Error::InvalidLengthError{";
+      s << "    obj: \"" << context << "\".to_string(),";
+      s << "    field: \"" << GetName() << "\".to_string(),";
+      s << "    wanted: want_,";
+      s << "    got: bytes.len()});";
+      s << "}";
+    } else {
+      s << "let want_ = " << start_offset.bytes() << " + (" << size_field_->GetName() << " as usize)";
+      if (GetSizeModifier() != "") {
+        s << " - ((" << GetSizeModifier().substr(1) << ") / 8)";
+      }
+      s << ";";
+      s << "if bytes.len() < want_ {";
+      s << " return Err(Error::InvalidLengthError{";
+      s << "    obj: \"" << context << "\".to_string(),";
+      s << "    field: \"" << GetName() << "\".to_string(),";
+      s << "    wanted: want_,";
+      s << "    got: bytes.len()});";
+      s << "}";
+      if (GetSizeModifier() != "") {
+        s << "if ((" << size_field_->GetName() << " as usize) < ((" << GetSizeModifier().substr(1) << ") / 8)) {";
+        s << " return Err(Error::ImpossibleStructError);";
+        s << "}";
+      }
+    }
+  }
+}
+
 void VectorField::GenRustGetter(std::ostream& s, Size start_offset, Size) const {
   auto element_field_type = GetElementField()->GetFieldType();
   auto element_field = GetElementField();
@@ -304,9 +351,14 @@ void VectorField::GenRustGetter(std::ostream& s, Size start_offset, Size) const
       s << "];";
       s << "while parsable_.len() > 0 {";
     }
-    s << " let parsed = " << element_field->GetRustDataType() << "::parse(&parsable_)?;";
-    s << " parsable_ = &parsable_[parsed.get_total_size()..];";
+    s << " match " << element_field->GetRustDataType() << "::parse(&parsable_) {";
+    s << "  Ok(parsed) => {";
+    s << "   parsable_ = &parsable_[parsed.get_total_size()..];";
     s << GetName() << ".push(parsed);";
+    s << "  },";
+    s << "  Err(Error::ImpossibleStructError) => break,";
+    s << "  Err(e) => return Err(e),";
+    s << " }";
     s << "}";
   }
 }
index a97b2e1..e6b5330 100644 (file)
@@ -75,6 +75,8 @@ class VectorField : public PacketField {
 
   void GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const override;
 
+  void GenBoundsCheck(std::ostream& s, Size start_offset, Size end_offset, std::string context) const override;
+
   const std::string name_;
 
   const PacketField* element_field_{nullptr};
index 2773e9b..047adb0 100644 (file)
@@ -40,6 +40,15 @@ pub enum Error {
     field: String,
     value: u64,
   },
+  #[error("when parsing {obj}.{field} needed length of {wanted} but got {got}")]
+  InvalidLengthError {
+    obj: String,
+    field: String,
+    wanted: usize,
+    got: usize,
+  },
+  #[error("Due to size restrictions a struct could not be parsed.")]
+  ImpossibleStructError,
 }
 
 pub trait Packet {
index b51c2db..6959233 100644 (file)
@@ -911,6 +911,7 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const {
                    << "no method exists to determine field location from begin() or end().\n";
     }
 
+    field->GenBoundsCheck(s, start_field_offset, end_field_offset, name_);
     field->GenRustGetter(s, start_field_offset, end_field_offset);
   }
 
index 3b93a7b..a3337c4 100644 (file)
@@ -384,6 +384,7 @@ void StructDef::GenRustImpls(std::ostream& s) const {
                    << "no method exists to determine field location from begin() or end().\n";
     }
 
+    field->GenBoundsCheck(s, start_field_offset, end_field_offset, name_);
     field->GenRustGetter(s, start_field_offset, end_field_offset);
   }