Skip to content

Commit

Permalink
Add support for _elementsize_ to rust generator
Browse files Browse the repository at this point in the history
  • Loading branch information
DeltaEvo committed May 15, 2024
1 parent 362e2d2 commit 7b805e1
Show file tree
Hide file tree
Showing 24 changed files with 2,236 additions and 9 deletions.
4 changes: 4 additions & 0 deletions pdl-compiler/scripts/generate_cxx_backend_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,6 +328,10 @@ def run(input: argparse.FileType, output: argparse.FileType, test_vectors: argpa
'Struct_Checksum_Field_FromEnd',
'PartialParent5',
'PartialParent12',
'Packet_Array_Field_VariableElementSize_ConstantSize',
'Packet_Array_Field_VariableElementSize_VariableSize',
'Packet_Array_Field_VariableElementSize_VariableCount',
'Packet_Array_Field_VariableElementSize_UnknownSize',
]

output.write(
Expand Down
8 changes: 8 additions & 0 deletions pdl-compiler/scripts/pdl/ast.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,12 @@ class SizeField(Field):
width: int


@node('elementsize_field')
class ElementSize(Field):
field_id: str
width: int


@node('count_field')
class CountField(Field):
field_id: str
Expand Down Expand Up @@ -276,6 +282,8 @@ def convert_(obj: object) -> object:
loc = obj['loc']
loc = SourceRange(loc['file'], SourceLocation(**loc['start']), SourceLocation(**loc['end']))
constructor = constructors_.get(kind)
if not constructor:
raise Exception(f'Unknown kind {kind}')
members = {'loc': loc, 'kind': kind}
for name, value in obj.items():
if name != 'kind' and name != 'loc':
Expand Down
70 changes: 70 additions & 0 deletions pdl-compiler/src/backends/rust/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1289,6 +1289,76 @@ mod tests {
"
);

test_pdl!(
packet_decl_array_dynamic_element_size,
"
struct Foo {
inner: 8[]
}
packet Bar {
_elementsize_(x): 5,
padding: 3,
x: Foo[]
}
"
);

test_pdl!(
packet_decl_array_dynamic_element_size_dynamic_size,
"
struct Foo {
inner: 8[]
}
packet Bar {
_size_(x): 4,
_elementsize_(x): 4,
x: Foo[]
}
"
);

test_pdl!(
packet_decl_array_dynamic_element_size_dynamic_count,
"
struct Foo {
inner: 8[]
}
packet Bar {
_count_(x): 4,
_elementsize_(x): 4,
x: Foo[]
}
"
);

test_pdl!(
packet_decl_array_dynamic_element_size_static_count,
"
struct Foo {
inner: 8[]
}
packet Bar {
_elementsize_(x): 5,
padding: 3,
x: Foo[4]
}
"
);

test_pdl!(
packet_decl_array_dynamic_element_size_static_count_1,
"
struct Foo {
inner: 8[]
}
packet Bar {
_elementsize_(x): 5,
padding: 3,
x: Foo[1]
}
"
);

test_pdl!(
packet_decl_reserved_field,
"
Expand Down
137 changes: 128 additions & 9 deletions pdl-compiler/src/backends/rust/parser.rs
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,12 @@ impl<'a> FieldParser<'a> {
let #id = #v as usize;
}
}
ast::FieldDesc::ElementSize { field_id, .. } => {
let id = format_ident!("{field_id}_element_size");
quote! {
let #id = #v as usize;
}
}
ast::FieldDesc::Count { field_id, .. } => {
let id = format_ident!("{field_id}_count");
quote! {
Expand Down Expand Up @@ -289,6 +295,15 @@ impl<'a> FieldParser<'a> {
}
}

fn find_element_size_field(&self, id: &str) -> Option<proc_macro2::Ident> {
self.decl.fields().find_map(|field| match &field.desc {
ast::FieldDesc::ElementSize { field_id, .. } if field_id == id => {
Some(format_ident!("{id}_element_size"))
}
_ => None,
})
}

fn payload_field_offset_from_end(&self) -> Option<usize> {
let decl = self.scope.typedef[self.packet_name];
let mut fields = decl.fields();
Expand Down Expand Up @@ -338,17 +353,20 @@ impl<'a> FieldParser<'a> {
decl: Option<&ast::Decl>,
) {
enum ElementWidth {
Static(usize), // Static size in bytes.
Static(usize), // Static size in bytes.
Dynamic(proc_macro2::Ident), // Dynamic size in bytes.
Unknown,
}
let element_width =
match width.or_else(|| self.schema.total_size(decl.unwrap().key).static_()) {
Some(w) => {
assert_eq!(w % 8, 0, "Array element size ({w}) is not a multiple of 8");
ElementWidth::Static(w / 8)
}
None => ElementWidth::Unknown,
};
let element_width = if let Some(w) =
width.or_else(|| self.schema.total_size(decl.unwrap().key).static_())
{
assert_eq!(w % 8, 0, "Array element size ({w}) is not a multiple of 8");
ElementWidth::Static(w / 8)
} else if let Some(element_size_field) = self.find_element_size_field(id) {
ElementWidth::Dynamic(element_size_field)
} else {
ElementWidth::Unknown
};

// The "shape" of the array, i.e., the number of elements
// given via a static count, a count field, a size field, or
Expand Down Expand Up @@ -385,6 +403,8 @@ impl<'a> FieldParser<'a> {
None => self.span.clone(),
};

let field_name = id;
let packet_name = self.packet_name;
let id = id.to_ident();

let parse_element = self.parse_array_element(&span, width, type_id, decl);
Expand Down Expand Up @@ -513,6 +533,105 @@ impl<'a> FieldParser<'a> {
}
});
}
(ElementWidth::Dynamic(element_size_field), ArrayShape::Static(count)) => {
// The element width is known, and the array element
// count is known statically.
let array_size = if *count == 1 {
quote!(#element_size_field)
} else {
quote!(#count * #element_size_field)
};

self.check_size(&span, &array_size);

let parse_element =
self.parse_array_element(&format_ident!("chunk"), width, type_id, decl);

self.code.push(quote! {
// TODO: use
// https://doc.rust-lang.org/std/array/fn.try_from_fn.html
// when stabilized.
let #id = #span.chunks(#element_size_field).take(#count)
.map(|mut chunk| #parse_element.and_then(|value| {
if chunk.is_empty() {
Ok(value)
} else {
Err(DecodeError::TrailingBytesInArray {
obj: #packet_name,
field: #field_name,
})
}
}))
.collect::<Result<Vec<_>, DecodeError>>()?;
#span = &#span[#array_size..];
let #id = #id
.try_into()
.map_err(|_| DecodeError::InvalidPacketError)?;
});
}
(ElementWidth::Dynamic(element_size_field), ArrayShape::CountField(count_field)) => {
// The element width is known, and the array element
// count is known dynamically by the count field.
self.check_size(&span, &quote!(#count_field * #element_size_field));

let parse_element =
self.parse_array_element(&format_ident!("chunk"), width, type_id, decl);

self.code.push(quote! {
let #id = #span.chunks(#element_size_field).take(#count_field)
.map(|mut chunk| #parse_element.and_then(|value| {
if chunk.is_empty() {
Ok(value)
} else {
Err(DecodeError::TrailingBytesInArray {
obj: #packet_name,
field: #field_name,
})
}
}))
.collect::<Result<Vec<_>, DecodeError>>()?;
#span = &#span[(#element_size_field * #count_field)..];
});
}
(ElementWidth::Dynamic(element_size_field), ArrayShape::SizeField(_))
| (ElementWidth::Dynamic(element_size_field), ArrayShape::Unknown) => {
// The element width is known, and the array full size
// is known by size field, or unknown (in which case
// it is the remaining span length).
let array_size = if let ArrayShape::SizeField(size_field) = &array_shape {
self.check_size(&span, &quote!(#size_field));
quote!(#size_field)
} else {
quote!(#span.remaining())
};
self.code.push(quote! {
if #array_size % #element_size_field != 0 {
return Err(DecodeError::InvalidArraySize {
array: #array_size,
element: #element_size_field,
});
}
});

let parse_element =
self.parse_array_element(&format_ident!("chunk"), width, type_id, decl);

self.code.push(quote! {
let #id = #span.chunks(#element_size_field).take(#array_size / #element_size_field)
.map(|mut chunk| #parse_element.and_then(|value| {
if chunk.is_empty() {
Ok(value)
} else {
Err(DecodeError::TrailingBytesInArray {
obj: #packet_name,
field: #field_name,
})
}
}))
.collect::<Result<Vec<_>, DecodeError>>()?;
#span = &#span[#array_size..];
});
}
}
}

Expand Down
44 changes: 44 additions & 0 deletions pdl-compiler/src/backends/rust/serializer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -377,6 +377,50 @@ impl Encoder {
shift,
});
}
ast::FieldDesc::ElementSize { field_id, width, .. } => {
let field_name = field_id.to_ident();
let field_type = types::Integer::new(*width);
let field_element_size_name = format_ident!("{field_id}_element_size");
let packet_name = &self.packet_name;
self.tokens.extend(quote! {
let #field_element_size_name = self.#field_name
.get(0)
.map_or(0, Packet::encoded_len);

for (element_index, element) in self.#field_name.iter().enumerate() {
if element.encoded_len() != #field_element_size_name {
return Err(EncodeError::InvalidArrayElementSize {
packet: #packet_name,
field: #field_id,
size: element.encoded_len(),
expected_size: #field_element_size_name,
element_index,
})
}
}
});
if field_type.width > *width {
let max_value = mask_bits(*width, "usize");
self.tokens.extend(quote! {
if #field_element_size_name > #max_value {
return Err(EncodeError::SizeOverflow {
packet: #packet_name,
field: #field_id,
size: #field_element_size_name,
maximum_size: #max_value,
})
}
});
}
self.tokens.extend(quote! {
let #field_element_size_name = #field_element_size_name as #field_type;
});
self.bit_fields.push(BitField {
value: quote!(#field_element_size_name),
field_type,
shift,
});
}
ast::FieldDesc::Count { field_id, width, .. } => {
let field_name = field_id.to_ident();
let field_type = types::Integer::new(*width);
Expand Down
4 changes: 4 additions & 0 deletions pdl-compiler/src/backends/rust/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -172,6 +172,10 @@ pub fn generate_tests(input_file: &str) -> Result<String, String> {
"Packet_Array_Field_UnsizedElement_VariableSize",
"Packet_Array_Field_SizedElement_VariableSize_Padded",
"Packet_Array_Field_UnsizedElement_VariableCount_Padded",
"Packet_Array_Field_VariableElementSize_ConstantSize",
"Packet_Array_Field_VariableElementSize_VariableSize",
"Packet_Array_Field_VariableElementSize_VariableCount",
"Packet_Array_Field_VariableElementSize_UnknownSize",
"Packet_Optional_Scalar_Field",
"Packet_Optional_Enum_Field",
"Packet_Optional_Struct_Field",
Expand Down
34 changes: 34 additions & 0 deletions pdl-compiler/tests/canonical/le_test_file.pdl
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ struct UnsizedStruct {
array: 8[],
}

struct UnknownSizeStruct {
array: 8[],
}

group ScalarGroup {
a: 16
}
Expand Down Expand Up @@ -362,6 +366,36 @@ packet Packet_Array_Field_UnsizedElement_VariableCount_Padded {
_padding_ [16],
}

packet Packet_Array_Field_VariableElementSize_ConstantSize {
_elementsize_(array): 4,
_reserved_: 4,
array: UnknownSizeStruct[4],
}

packet Packet_Array_Field_VariableElementSize_VariableSize {
_size_(array) : 4,
_reserved_: 4,
_elementsize_(array): 4,
_reserved_: 4,
array: UnknownSizeStruct[],
tail: UnknownSizeStruct[]
}

packet Packet_Array_Field_VariableElementSize_VariableCount {
_count_(array) : 4,
_reserved_: 4,
_elementsize_(array): 4,
_reserved_: 4,
array: UnknownSizeStruct[],
tail: UnknownSizeStruct[],
}

packet Packet_Array_Field_VariableElementSize_UnknownSize {
_elementsize_(array): 4,
_reserved_: 4,
array: UnknownSizeStruct[],
}

packet Packet_Optional_Scalar_Field {
c0: 1,
c1: 1,
Expand Down
Loading

0 comments on commit 7b805e1

Please sign in to comment.