Skip to content

Commit

Permalink
Merge pull request #120 from dcSpark/plain-groups-in-arrays
Browse files Browse the repository at this point in the history
Basic groups inside arrays support.
  • Loading branch information
rooooooooob authored Dec 20, 2022
2 parents 8815b0d + 1773d17 commit 3f8acf9
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 67 deletions.
110 changes: 71 additions & 39 deletions src/generation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,8 @@ struct DeserializeConfig<'a> {
final_exprs: Vec<String>,
/// Overload for the deserializer's name. Defaults to "raw"
deserializer_name_overload: Option<&'a str>,
/// Overload for read_len. This would be a local e.g. for arrays
read_len_overload: Option<String>,
}

impl<'a> DeserializeConfig<'a> {
Expand All @@ -173,6 +175,7 @@ impl<'a> DeserializeConfig<'a> {
optional_field: false,
final_exprs: Vec::new(),
deserializer_name_overload: None,
read_len_overload: None,
}
}

Expand All @@ -194,6 +197,22 @@ impl<'a> DeserializeConfig<'a> {
/// Name of the deserializer to reference in generated code.
///
/// Returns the overload when one has been set on this config, otherwise
/// the default name `"raw"`.
fn deserializer_name(&self) -> &'a str {
    match self.deserializer_name_overload {
        Some(name) => name,
        None => "raw",
    }
}

fn overload_read_len(mut self, overload: String) -> Self {
self.read_len_overload = Some(overload);
self
}

/// Builds the expression text used to pass the read-length tracker as an argument.
///
/// * With an overload set: `&mut <overload>`.
/// * Without an overload, when `in_embedded` is set: `read_len`.
/// * Otherwise: `&mut read_len`.
fn pass_read_len(&self) -> String {
    match &self.read_len_overload {
        // An overload is only ever set when the tracker is a local variable
        // (e.g. for arrays), so it has to be borrowed mutably at the call site.
        Some(local_name) => format!("&mut {}", local_name),
        // NOTE(review): presumably `read_len` is already a `&mut` inside an
        // embedded context, hence no extra borrow — confirm against callers.
        None if self.in_embedded => "read_len".to_owned(),
        None => "&mut read_len".to_owned(),
    }
}
}

fn concat_files(paths: Vec<&str>) -> std::io::Result<String> {
Expand Down Expand Up @@ -944,7 +963,23 @@ impl GenerationScope {
}
},
SerializingRustType::Root(ConceptualRustType::Array(ty)) => {
start_len(body, Representation::Array, serializer_use, &encoding_var, &format!("{}.len() as u64", config.expr));
let len_expr = match &ty.conceptual_type {
ConceptualRustType::Rust(elem_ident) if types.is_plain_group(elem_ident) => {
// you should not be able to indiscriminately encode a plain group like this as it
// could be multiple elements. This would require special handling if it's even permitted in CDDL.
assert!(ty.encodings.is_empty());
if let Some(fixed_elem_size) = ty.conceptual_type.expanded_field_count(types) {
format!("{} * {}.len() as u64", fixed_elem_size, config.expr)
} else {
format!(
"{}.iter().map(|e| {}).sum()",
config.expr,
ty.conceptual_type.definite_info("e", types))
}
},
_ => format!("{}.len() as u64", config.expr)
};
start_len(body, Representation::Array, serializer_use, &encoding_var, &len_expr);
let elem_var_name = format!("{}_elem", config.var_name);
let elem_encs = if CLI_ARGS.preserve_encodings {
encoding_fields(&elem_var_name, &ty.clone().resolve_aliases(), false)
Expand Down Expand Up @@ -1305,17 +1340,12 @@ impl GenerationScope {
// a parameter whether it was an optional field, and if so, read_len.read_elems(embedded mandatory fields)?;
// since otherwise it'd only length check the optional fields within the type.
assert!(!config.optional_field);
let pass_read_len = if config.in_embedded {
"read_len"
} else {
"&mut read_len"
};
deser_code.read_len_used = true;
let final_expr_value = format!(
"{}::deserialize_as_embedded_group({}, {}, len)",
ident,
deserializer_name,
pass_read_len);
config.pass_read_len());

deser_code.content.line(&final_result_expr_complete(&mut deser_code.throws, config.final_exprs, &final_expr_value));
} else {
Expand Down Expand Up @@ -1451,20 +1481,39 @@ impl GenerationScope {
if CLI_ARGS.preserve_encodings {
deser_code.content
.line(&format!("let len = {}.array_sz()?;", deserializer_name))
.line(&format!("let {}_encoding = len.into();", config.var_name));
.line(&format!("let mut {}_encoding = len.into();", config.var_name));
if !elem_encs.is_empty() {
deser_code.content.line(&format!("let mut {}_elem_encodings = Vec::new();", config.var_name));
}
} else {
deser_code.content.line(&format!("let len = {}.array()?;", deserializer_name));
}
let mut deser_loop = make_deser_loop("len", &format!("{}.len()", arr_var_name));
let mut elem_config = DeserializeConfig::new(&elem_var_name);
let (mut deser_loop, plain_len_check) = match &ty.conceptual_type {
ConceptualRustType::Rust(ty_ident) if types.is_plain_group(&*ty_ident) => {
// two things that must be done differently for embedded plain groups:
// 1) We can't directly read the CBOR len's number of items since it could be >1
// 2) We need a different cbor read len var to pass into embedded deserialize
let read_len_overload = format!("{}_read_len", config.var_name);
deser_code.content.line(&format!("let mut {} = CBORReadLen::new(len);", read_len_overload));
// inside of deserialize_as_embedded_group we only modify read_len for things we couldn't
// statically know beforehand. This was done for other areas that use plain groups in order
// to be able to do static length checks for statically sized groups that contain plain groups
// at the start of deserialization instead of many checks for every single field.
let plain_len_check = match ty.conceptual_type.expanded_mandatory_field_count(types) {
0 => None,
n => Some(format!("{}.read_elems({})?;", read_len_overload, n)),
};
elem_config = elem_config.overload_read_len(read_len_overload);
let deser_loop = make_deser_loop("len", &format!("{}_read_len.read", config.var_name));
(deser_loop, plain_len_check)
},
_ => (make_deser_loop("len", &format!("({}.len() as u64)", arr_var_name)), None)
};
deser_loop.push_block(make_deser_loop_break_check());
if let ConceptualRustType::Rust(ty_ident) = &ty.conceptual_type {
// TODO: properly handle which read_len would be checked here.
assert!(!types.is_plain_group(&*ty_ident));
if let Some(plain_len_check) = plain_len_check {
deser_loop.line(plain_len_check);
}
let mut elem_config = DeserializeConfig::new(&elem_var_name);
elem_config.deserializer_name_overload = config.deserializer_name_overload;
if !elem_encs.is_empty() {
let elem_var_names_str = encoding_var_names_str(&elem_var_name, ty);
Expand Down Expand Up @@ -1543,7 +1592,7 @@ impl GenerationScope {
} else {
deser_code.content.line(&format!("let {} = {}.map()?;", len_var, deserializer_name));
}
let mut deser_loop = make_deser_loop(&len_var, &format!("{}.len()", table_var));
let mut deser_loop = make_deser_loop(&len_var, &format!("({}.len() as u64)", table_var));
deser_loop.push_block(make_deser_loop_break_check());
let mut key_config = DeserializeConfig::new(&key_var_name);
key_config.deserializer_name_overload = config.deserializer_name_overload;
Expand Down Expand Up @@ -2145,15 +2194,13 @@ fn create_base_wasm_wrapper<'a>(gen_scope: &GenerationScope, ident: &'a RustIden
// Always creates directly just the Serialize impl. Shortcut for create_serialize_impls when
// we know we won't need the SerializeEmbeddedGroup impl.
// See comments for create_serialize_impls for usage.
fn create_serialize_impl(ident: &RustIdent, rep: Option<Representation>, tag: Option<usize>, definite_len: Option<String>, use_this_encoding: Option<&str>) -> (codegen::Function, codegen::Impl) {
fn create_serialize_impl(ident: &RustIdent, rep: Option<Representation>, tag: Option<usize>, definite_len: &str, use_this_encoding: Option<&str>) -> (codegen::Function, codegen::Impl) {
match create_serialize_impls(ident, rep, tag, definite_len, use_this_encoding, false) {
(ser_func, ser_impl, None) => (ser_func, ser_impl),
(_ser_func, _ser_impl, Some(_embedded_impl)) => unreachable!(),
}
}

// If definite_len is provided, it will use that expression as the definite length.
// Otherwise indefinite will be used and the user should remember to write a Special::Break at the end.
// Returns (serialize, Serialize, Some(SerializeEmbeddedGroup)) impls for structs that require embedded, in which case
// the serialize calls the embedded serialize and you implement the embedded serialize
// Otherwise returns (serialize, Serialize, None) impls and you implement the serialize.
Expand All @@ -2164,15 +2211,12 @@ fn create_serialize_impl(ident: &RustIdent, rep: Option<Representation>, tag: Op
// In the second case (no embedded), only the array/map tag + length are written and the user will
// want to write the rest of serialize() after that.
// * `use_this_encoding` - If present, references a variable (must be bool and in this scope) to toggle definite vs indefinite (e.g. for PRESERVE_ENCODING)
fn create_serialize_impls(ident: &RustIdent, rep: Option<Representation>, tag: Option<usize>, definite_len: Option<String>, use_this_encoding: Option<&str>, generate_serialize_embedded: bool) -> (codegen::Function, codegen::Impl, Option<codegen::Impl>) {
fn create_serialize_impls(ident: &RustIdent, rep: Option<Representation>, tag: Option<usize>, definite_len: &str, use_this_encoding: Option<&str>, generate_serialize_embedded: bool) -> (codegen::Function, codegen::Impl, Option<codegen::Impl>) {
if generate_serialize_embedded {
// This is not necessarily a problem but we should investigate this case to ensure we're not calling
// (de)serialize_as_embedded without (de)serializing the tag
assert_eq!(tag, None);
}
if use_this_encoding.is_some() && definite_len.is_none() {
panic!("definite_len is required for use_this_encoding or else we'd only be able to serialize indefinite no matter what");
}
let name = &ident.to_string();
let ser_impl = make_serialization_impl(name);
let mut ser_func = make_serialization_function("serialize");
Expand All @@ -2183,28 +2227,16 @@ fn create_serialize_impls(ident: &RustIdent, rep: Option<Representation>, tag: O
// TODO: do definite length encoding for optional fields too
if let Some (rep) = rep {
if let Some(definite) = use_this_encoding {
start_len(&mut ser_func, rep, "serializer", definite, definite_len.as_ref().unwrap());
start_len(&mut ser_func, rep, "serializer", definite, definite_len);
} else {
let len = match &definite_len {
Some(fixed_field_count) => cbor_event_len_n(fixed_field_count),
None => {
assert!(!CLI_ARGS.canonical_form);
cbor_event_len_indef().to_owned()
},
};
let len = cbor_event_len_n(definite_len);
match rep {
Representation::Array => ser_func.line(format!("serializer.write_array({})?;", len)),
Representation::Map => ser_func.line(format!("serializer.write_map({})?;", len)),
};
}
if generate_serialize_embedded {
match definite_len {
Some(_) => ser_func.line(format!("self.serialize_as_embedded_group(serializer{})", canonical_param())),
None => {
ser_func.line(format!("self.serialize_as_embedded_group(serializer{})?;", canonical_param()));
ser_func.line("serializer.write_special(CBORSpecial::Break)")
},
};
ser_func.line(format!("self.serialize_as_embedded_group(serializer{})", canonical_param()));
}
} else {
// not array or map, generate serialize directly
Expand Down Expand Up @@ -2387,7 +2419,7 @@ fn make_err_annotate_block(annotation: &str, before: &str, after: &str) -> Block
fn make_deser_loop(len_var: &str, len_expr: &str) -> Block {
Block::new(
&format!(
"while match {} {{ {} => {} < n as usize, {} => true, }}",
"while match {} {{ {} => {} < n, {} => true, }}",
len_var,
cbor_event_len_n("n"),
len_expr,
Expand Down Expand Up @@ -2857,7 +2889,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na
name,
Some(record.rep),
tag,
record.definite_info(types),
&record.definite_info(types),
len_encoding_var.map(|var| format!("self.encodings.as_ref().map(|encs| encs.{}).unwrap_or_default()", var)).as_deref(),
types.is_plain_group(name));
let mut ser_func = match ser_embedded_impl {
Expand Down Expand Up @@ -3210,7 +3242,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na
ser_func.line(format!(
"let deser_order = self.encodings.as_ref().filter(|encs| {}encs.orig_deser_order.len() == {}).map(|encs| encs.orig_deser_order.clone()).unwrap_or_else(|| {});",
check_canonical,
record.definite_info(types).expect("cannot fail for maps"),
record.definite_info(types),
serialization_order));
let mut ser_loop = codegen::Block::new("for field_index in deser_order");
let mut ser_loop_match = codegen::Block::new("match field_index");
Expand Down
Loading

0 comments on commit 3f8acf9

Please sign in to comment.