From 379b84b4e396d1bc9ffcb368d69851d821f652a3 Mon Sep 17 00:00:00 2001 From: Zach Johnson Date: Mon, 1 Feb 2021 15:44:12 -0800 Subject: [PATCH] rusty-gd: move size field to packet gen this way we have full context on the targeted field Bug: 171749953 Tag: #gd-refactor Test: gd/cert/run --rhost Change-Id: I2e4658cac68aaadacf35c31cc58c87ae997ae0f1 --- gd/packet/parser/fields/scalar_field.cc | 7 +++++-- gd/packet/parser/fields/size_field.cc | 35 ------------------------------- gd/packet/parser/fields/size_field.h | 2 -- gd/packet/parser/packet_def.cc | 37 +++++++++++++++++++++++++++++++++ 4 files changed, 42 insertions(+), 39 deletions(-) diff --git a/gd/packet/parser/fields/scalar_field.cc b/gd/packet/parser/fields/scalar_field.cc index c7c330485..1c3c5a618 100644 --- a/gd/packet/parser/fields/scalar_field.cc +++ b/gd/packet/parser/fields/scalar_field.cc @@ -16,6 +16,7 @@ #include "fields/scalar_field.h" +#include "fields/size_field.h" #include "util.h" const std::string ScalarField::kFieldType = "ScalarField"; @@ -194,8 +195,10 @@ void ScalarField::GenRustWriter(std::ostream& s, Size start_offset, Size end_off Size size = GetSize(); int num_leading_bits = GetRustBitOffset(s, start_offset, end_offset, GetSize()); - // needs casting to primitive - if (GetRustParseDataType() != GetRustDataType()) { + if (GetFieldType() == SizeField::kFieldType) { + // Do nothing, the field access has already happened in packet_def + } else if (GetRustParseDataType() != GetRustDataType()) { + // needs casting to primitive s << "let " << GetName() << " = self." << GetName() << ".to_" << GetRustParseDataType() << "().unwrap();"; } else { s << "let " << GetName() << " = self." << GetName() << ";"; diff --git a/gd/packet/parser/fields/size_field.cc b/gd/packet/parser/fields/size_field.cc index 45d0978dc..ce2c899d9 100644 --- a/gd/packet/parser/fields/size_field.cc +++ b/gd/packet/parser/fields/size_field.cc @@ -65,38 +65,3 @@ std::string SizeField::GetSizedFieldName() const { void SizeField::GenStringRepresentation(std::ostream& s, std::string accessor) const { s << accessor; } - -void SizeField::GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const { - Size size = GetSize(); - int num_leading_bits = GetRustBitOffset(s, start_offset, end_offset, GetSize()); - - s << "let mut " << GetName() << ": " << GetRustDataType() << " = self.get_total_size() as "; - s << GetRustDataType() << ";"; - s << GetName() << " -= self.get_size() as " << GetRustDataType() << ";"; - if (util::RoundSizeUp(size.bits()) != size.bits()) { - uint64_t mask = 0; - for (int i = 0; i < size.bits(); i++) { - mask <<= 1; - mask |= 1; - } - s << "let " << GetName() << " = "; - s << GetName() << " & 0x" << std::hex << mask << std::dec << ";"; - } - - int access_offset = 0; - if (num_leading_bits != 0) { - access_offset = -1; - uint64_t mask = 0; - for (int i = 0; i < num_leading_bits; i++) { - mask <<= 1; - mask |= 1; - } - s << "let " << GetName() << " = (" << GetName() << " << " << num_leading_bits << ") | (" - << "(buffer[" << start_offset.bytes() << "] as " << GetRustParseDataType() << ") & 0x" << std::hex << mask - << std::dec << ");"; - } - - s << "buffer[" << start_offset.bytes() + access_offset << ".." - << start_offset.bytes() + GetSize().bytes() + access_offset << "].copy_from_slice(&" << GetName() - << ".to_le_bytes());"; -} diff --git a/gd/packet/parser/fields/size_field.h b/gd/packet/parser/fields/size_field.h index 64e827a11..2d40c42eb 100644 --- a/gd/packet/parser/fields/size_field.h +++ b/gd/packet/parser/fields/size_field.h @@ -48,8 +48,6 @@ class SizeField : public ScalarField { virtual void GenStringRepresentation(std::ostream& s, std::string accessor) const override; - virtual void GenRustWriter(std::ostream& s, Size start_offset, Size end_offset) const override; - private: int size_; std::string sized_field_name_; diff --git a/gd/packet/parser/packet_def.cc b/gd/packet/parser/packet_def.cc index 264496f8b..181ebfd99 100644 --- a/gd/packet/parser/packet_def.cc +++ b/gd/packet/parser/packet_def.cc @@ -1017,6 +1017,43 @@ void PacketDef::GenRustStructImpls(std::ostream& s) const { << "no method exists to determine field location from begin() or end().\n"; } + if (field->GetFieldType() == SizeField::kFieldType) { + const auto& field_name = ((SizeField*)field)->GetSizedFieldName(); + const auto& sized_field = fields_.GetField(field_name); + if (sized_field == nullptr) { + ERROR(field) << __func__ << ": Can't find sized field named " << field_name; + } + if (sized_field->GetFieldType() == PayloadField::kFieldType) { + std::string modifier = ((PayloadField*)sized_field)->size_modifier_; + if (modifier != "") { + ERROR(field) << __func__ << ": size modifiers not implemented yet for " << field_name; + } + + s << "let " << field->GetName() << " = " << field->GetRustDataType() + << "::try_from(self.child.get_total_size()).expect(\"payload size did not fit\");"; + } else if (sized_field->GetFieldType() == BodyField::kFieldType) { + s << "let " << field->GetName() << " = " << field->GetRustDataType() + << "::try_from(self.get_total_size() - self.get_size()).expect(\"payload size did not fit\");"; + } else if (sized_field->GetFieldType() == VectorField::kFieldType) { + const auto& vector_name = field_name + "_bytes"; + const VectorField* vector = (VectorField*)sized_field; + if (vector->element_size_.empty() || vector->element_size_.has_dynamic()) { + s << "let " << vector_name + " = self." << field_name << ".iter().fold(0, |acc, x| acc + x.get_size());"; + } else { + s << "let " << vector_name + " = self." << field_name << ".len() * ((" << vector->element_size_ << ") / 8);"; + } + std::string modifier = vector->GetSizeModifier(); + if (modifier != "") { + ERROR(field) << __func__ << ": size modifiers not implemented yet for " << field_name; + } + + s << "let " << field->GetName() << " = " << field->GetRustDataType() << "::try_from(" << vector_name + << ").expect(\"payload size did not fit\");"; + } else { + ERROR(field) << __func__ << ": Unhandled sized field type for " << field_name; + } + } + field->GenRustWriter(s, start_field_offset, end_field_offset); } -- 2.11.0