From 1773d176f3512b8d546ede6f53ae7531c582ec56 Mon Sep 17 00:00:00 2001 From: rooooooooob Date: Tue, 13 Dec 2022 13:35:27 -0800 Subject: [PATCH] Basic groups inside arrays support. e.g.: ``` foo = (uint, text) bar = [ foos: [* foo], ] ``` which if you had `n` foos would be a `2 * n` long CBOR array as there's no wrapping map/array around the basic group `foo`. This can be done in some CDDL specs as a space optimization. --- src/generation.rs | 110 +++++++++++++++++++++++++++--------------- src/intermediate.rs | 48 +++++++++--------- src/parsing.rs | 5 +- tests/core/input.cddl | 5 ++ tests/core/tests.rs | 13 ++++- 5 files changed, 114 insertions(+), 67 deletions(-) diff --git a/src/generation.rs b/src/generation.rs index ce7bb78..1fc7cb6 100644 --- a/src/generation.rs +++ b/src/generation.rs @@ -163,6 +163,8 @@ struct DeserializeConfig<'a> { final_exprs: Vec, /// Overload for the deserializer's name. Defaults to "raw" deserializer_name_overload: Option<&'a str>, + /// Overload for read_len. This would be a local e.g. for arrays + read_len_overload: Option, } impl<'a> DeserializeConfig<'a> { @@ -173,6 +175,7 @@ impl<'a> DeserializeConfig<'a> { optional_field: false, final_exprs: Vec::new(), deserializer_name_overload: None, + read_len_overload: None, } } @@ -194,6 +197,22 @@ impl<'a> DeserializeConfig<'a> { fn deserializer_name(&self) -> &'a str { self.deserializer_name_overload.unwrap_or("raw") } + + fn overload_read_len(mut self, overload: String) -> Self { + self.read_len_overload = Some(overload); + self + } + + fn pass_read_len(&self) -> String { + if let Some(overload) = &self.read_len_overload { + // the ONLY way to have a name overload is if we have a local variable (e.g. arrays) + format!("&mut {}", overload) + } else if self.in_embedded { + "read_len".to_owned() + } else { + "&mut read_len".to_owned() + } + } } fn concat_files(paths: Vec<&str>) -> std::io::Result { @@ -944,7 +963,23 @@ impl GenerationScope { } }, SerializingRustType::Root(ConceptualRustType::Array(ty)) => { - start_len(body, Representation::Array, serializer_use, &encoding_var, &format!("{}.len() as u64", config.expr)); + let len_expr = match &ty.conceptual_type { + ConceptualRustType::Rust(elem_ident) if types.is_plain_group(elem_ident) => { + // you should not be able to indiscriminately encode a plain group like this as it + // could be multiple elements. This would require special handling if it's even permitted in CDDL. + assert!(ty.encodings.is_empty()); + if let Some(fixed_elem_size) = ty.conceptual_type.expanded_field_count(types) { + format!("{} * {}.len() as u64", fixed_elem_size, config.expr) + } else { + format!( + "{}.iter().map(|e| {}).sum()", + config.expr, + ty.conceptual_type.definite_info("e", types)) + } + }, + _ => format!("{}.len() as u64", config.expr) + }; + start_len(body, Representation::Array, serializer_use, &encoding_var, &len_expr); let elem_var_name = format!("{}_elem", config.var_name); let elem_encs = if CLI_ARGS.preserve_encodings { encoding_fields(&elem_var_name, &ty.clone().resolve_aliases(), false) @@ -1305,17 +1340,12 @@ impl GenerationScope { // a parameter whether it was an optional field, and if so, read_len.read_elems(embedded mandatory fields)?; // since otherwise it'd only length check the optional fields within the type. assert!(!config.optional_field); - let pass_read_len = if config.in_embedded { - "read_len" - } else { - "&mut read_len" - }; deser_code.read_len_used = true; let final_expr_value = format!( "{}::deserialize_as_embedded_group({}, {}, len)", ident, deserializer_name, - pass_read_len); + config.pass_read_len()); deser_code.content.line(&final_result_expr_complete(&mut deser_code.throws, config.final_exprs, &final_expr_value)); } else { @@ -1451,20 +1481,39 @@ impl GenerationScope { if CLI_ARGS.preserve_encodings { deser_code.content .line(&format!("let len = {}.array_sz()?;", deserializer_name)) - .line(&format!("let {}_encoding = len.into();", config.var_name)); + .line(&format!("let mut {}_encoding = len.into();", config.var_name)); if !elem_encs.is_empty() { deser_code.content.line(&format!("let mut {}_elem_encodings = Vec::new();", config.var_name)); } } else { deser_code.content.line(&format!("let len = {}.array()?;", deserializer_name)); } - let mut deser_loop = make_deser_loop("len", &format!("{}.len()", arr_var_name)); + let mut elem_config = DeserializeConfig::new(&elem_var_name); + let (mut deser_loop, plain_len_check) = match &ty.conceptual_type { + ConceptualRustType::Rust(ty_ident) if types.is_plain_group(&*ty_ident) => { + // two things that must be done differently for embedded plain groups: + // 1) We can't directly read the CBOR len's number of items since it could be >1 + // 2) We need a different cbor read len var to pass into embedded deserialize + let read_len_overload = format!("{}_read_len", config.var_name); + deser_code.content.line(&format!("let mut {} = CBORReadLen::new(len);", read_len_overload)); + // inside of deserialize_as_embedded_group we only modify read_len for things we couldn't + // statically know beforehand. This was done for other areas that use plain groups in order + // to be able to do static length checks for statically sized groups that contain plain groups + // at the start of deserialization instead of many checks for every single field. + let plain_len_check = match ty.conceptual_type.expanded_mandatory_field_count(types) { + 0 => None, + n => Some(format!("{}.read_elems({})?;", read_len_overload, n)), + }; + elem_config = elem_config.overload_read_len(read_len_overload); + let deser_loop = make_deser_loop("len", &format!("{}_read_len.read", config.var_name)); + (deser_loop, plain_len_check) + }, + _ => (make_deser_loop("len", &format!("({}.len() as u64)", arr_var_name)), None) + }; deser_loop.push_block(make_deser_loop_break_check()); - if let ConceptualRustType::Rust(ty_ident) = &ty.conceptual_type { - // TODO: properly handle which read_len would be checked here. - assert!(!types.is_plain_group(&*ty_ident)); + if let Some(plain_len_check) = plain_len_check { + deser_loop.line(plain_len_check); } - let mut elem_config = DeserializeConfig::new(&elem_var_name); elem_config.deserializer_name_overload = config.deserializer_name_overload; if !elem_encs.is_empty() { let elem_var_names_str = encoding_var_names_str(&elem_var_name, ty); @@ -1543,7 +1592,7 @@ impl GenerationScope { } else { deser_code.content.line(&format!("let {} = {}.map()?;", len_var, deserializer_name)); } - let mut deser_loop = make_deser_loop(&len_var, &format!("{}.len()", table_var)); + let mut deser_loop = make_deser_loop(&len_var, &format!("({}.len() as u64)", table_var)); deser_loop.push_block(make_deser_loop_break_check()); let mut key_config = DeserializeConfig::new(&key_var_name); key_config.deserializer_name_overload = config.deserializer_name_overload; @@ -2145,15 +2194,13 @@ fn create_base_wasm_wrapper<'a>(gen_scope: &GenerationScope, ident: &'a RustIden // Alway creates directly just Serialize impl. Shortcut for create_serialize_impls when // we know we won't need the SerializeEmbeddedGroup impl. // See comments for create_serialize_impls for usage. -fn create_serialize_impl(ident: &RustIdent, rep: Option, tag: Option, definite_len: Option, use_this_encoding: Option<&str>) -> (codegen::Function, codegen::Impl) { +fn create_serialize_impl(ident: &RustIdent, rep: Option, tag: Option, definite_len: &str, use_this_encoding: Option<&str>) -> (codegen::Function, codegen::Impl) { match create_serialize_impls(ident, rep, tag, definite_len, use_this_encoding, false) { (ser_func, ser_impl, None) => (ser_func, ser_impl), (_ser_func, _ser_impl, Some(_embedded_impl)) => unreachable!(), } } -// If definite_len is provided, it will use that expression as the definite length. -// Otherwise indefinite will be used and the user should remember to write a Special::Break at the end. // Returns (serialize, Serialize, Some(SerializeEmbeddedGroup)) impls for structs that require embedded, in which case // the serialize calls the embedded serialize and you implement the embedded serialize // Otherwise returns (serialize Serialize, None) impls and you implement the serialize. @@ -2164,15 +2211,12 @@ fn create_serialize_impl(ident: &RustIdent, rep: Option, tag: Op // In the second case (no embedded), only the array/map tag + length are written and the user will // want to write the rest of serialize() after that. // * `use_this_encoding` - If present, references a variable (must be bool and in this scope) to toggle definite vs indefinite (e.g. for PRESERVE_ENCODING) -fn create_serialize_impls(ident: &RustIdent, rep: Option, tag: Option, definite_len: Option, use_this_encoding: Option<&str>, generate_serialize_embedded: bool) -> (codegen::Function, codegen::Impl, Option) { +fn create_serialize_impls(ident: &RustIdent, rep: Option, tag: Option, definite_len: &str, use_this_encoding: Option<&str>, generate_serialize_embedded: bool) -> (codegen::Function, codegen::Impl, Option) { if generate_serialize_embedded { // This is not necessarily a problem but we should investigate this case to ensure we're not calling // (de)serialize_as_embedded without (de)serializing the tag assert_eq!(tag, None); } - if use_this_encoding.is_some() && definite_len.is_none() { - panic!("definite_len is required for use_this_encoding or else we'd only be able to serialize indefinite no matter what"); - } let name = &ident.to_string(); let ser_impl = make_serialization_impl(name); let mut ser_func = make_serialization_function("serialize"); @@ -2183,28 +2227,16 @@ fn create_serialize_impls(ident: &RustIdent, rep: Option, tag: O // TODO: do definite length encoding for optional fields too if let Some (rep) = rep { if let Some(definite) = use_this_encoding { - start_len(&mut ser_func, rep, "serializer", definite, definite_len.as_ref().unwrap()); + start_len(&mut ser_func, rep, "serializer", definite, definite_len); } else { - let len = match &definite_len { - Some(fixed_field_count) => cbor_event_len_n(fixed_field_count), - None => { - assert!(!CLI_ARGS.canonical_form); - cbor_event_len_indef().to_owned() - }, - }; + let len = cbor_event_len_n(definite_len); match rep { Representation::Array => ser_func.line(format!("serializer.write_array({})?;", len)), Representation::Map => ser_func.line(format!("serializer.write_map({})?;", len)), }; } if generate_serialize_embedded { - match definite_len { - Some(_) => ser_func.line(format!("self.serialize_as_embedded_group(serializer{})", canonical_param())), - None => { - ser_func.line(format!("self.serialize_as_embedded_group(serializer{})?;", canonical_param())); - ser_func.line("serializer.write_special(CBORSpecial::Break)") - }, - }; + ser_func.line(format!("self.serialize_as_embedded_group(serializer{})", canonical_param())); } } else { // not array or map, generate serialize directly @@ -2387,7 +2419,7 @@ fn make_err_annotate_block(annotation: &str, before: &str, after: &str) -> Block fn make_deser_loop(len_var: &str, len_expr: &str) -> Block { Block::new( &format!( - "while match {} {{ {} => {} < n as usize, {} => true, }}", + "while match {} {{ {} => {} < n, {} => true, }}", len_var, cbor_event_len_n("n"), len_expr, @@ -2857,7 +2889,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na name, Some(record.rep), tag, - record.definite_info(types), + &record.definite_info(types), len_encoding_var.map(|var| format!("self.encodings.as_ref().map(|encs| encs.{}).unwrap_or_default()", var)).as_deref(), types.is_plain_group(name)); let mut ser_func = match ser_embedded_impl { @@ -3210,7 +3242,7 @@ fn codegen_struct(gen_scope: &mut GenerationScope, types: &IntermediateTypes, na ser_func.line(format!( "let deser_order = self.encodings.as_ref().filter(|encs| {}encs.orig_deser_order.len() == {}).map(|encs| encs.orig_deser_order.clone()).unwrap_or_else(|| {});", check_canonical, - record.definite_info(types).expect("cannot fail for maps"), + record.definite_info(types), serialization_order)); let mut ser_loop = codegen::Block::new("for field_index in deser_order"); let mut ser_loop_match = codegen::Block::new("match field_index"); diff --git a/src/intermediate.rs b/src/intermediate.rs index 2681e6e..c567e62 100644 --- a/src/intermediate.rs +++ b/src/intermediate.rs @@ -179,9 +179,7 @@ impl<'a> IntermediateTypes<'a> { match &rust_struct.variant { RustStructType::Table { domain, range } => { // we must provide the keys type to return - if CLI_ARGS.wasm { - self.create_and_register_array_type(parent_visitor, domain.clone(), &domain.conceptual_type.name_as_wasm_array()); - } + self.create_and_register_array_type(parent_visitor, domain.clone(), &domain.conceptual_type.name_as_wasm_array()); let mut map_type: RustType = ConceptualRustType::Map(Box::new(domain.clone()), Box::new(range.clone())).into(); if let Some(tag) = rust_struct.tag { map_type = map_type.tag(tag); @@ -226,9 +224,11 @@ impl<'a> IntermediateTypes<'a> { if let ConceptualRustType::Rust(_) = &element_type.conceptual_type { self.set_rep_if_plain_group(parent_visitor, &array_type_ident, Representation::Array); } - // we don't pass in tags here. If a tag-wrapped array is done I think it generates - // 2 separate types (array wrapper -> tag wrapper struct) - self.register_rust_struct(parent_visitor, RustStruct::new_array(array_type_ident, None, element_type.clone())); + if CLI_ARGS.wasm { + // we don't pass in tags here. If a tag-wrapped array is done I think it generates + // 2 separate types (array wrapper -> tag wrapper struct) + self.register_rust_struct(parent_visitor, RustStruct::new_array(array_type_ident, None, element_type.clone())); + } ConceptualRustType::Array(Box::new(element_type)).into() } @@ -1213,25 +1213,24 @@ impl ConceptualRustType { } // See comment in RustStruct::definite_info(), this is the same, returns a string expression - // which evaluates to the length when possible, or None if not. + // which evaluates to the length. // self_expr is an expresison that evaluates to this RustType (e.g. member, etc) at the point where // the return of this function will be used. - pub fn definite_info(&self, self_expr: &str, types: &IntermediateTypes) -> Option { + pub fn definite_info(&self, self_expr: &str, types: &IntermediateTypes) -> String { match self.expanded_field_count(types) { - Some(count) => Some(count.to_string()), + Some(count) => count.to_string(), None => match self { - Self::Optional(ty) => Some(format!("match {} {{ Some(x) => {}, None => 1 }}", self_expr, ty.conceptual_type.definite_info("x", types)?)), + Self::Optional(ty) => format!("match {} {{ Some(x) => {}, None => 1 }}", self_expr, ty.conceptual_type.definite_info("x", types)), Self::Rust(ident) => if types.is_plain_group(ident) { match types.rust_structs.get(&ident) { Some(rs) => rs.definite_info(types), - // when we split up parsing from codegen instead of multi-passing this should be an error - None => None, + None => panic!("rust struct {} not found but referenced by {:?}", ident, self), } } else { - Some(String::from("1")) + String::from("1") }, Self::Alias(_ident, ty) => ty.definite_info(self_expr, types), - _ => Some(String::from("1")), + _ => String::from("1"), } } } @@ -1564,15 +1563,14 @@ impl RustStruct { } } - // Even if fixed_field_count() == None, this will try and return an expression for + // Even if fixed_field_count() == None, this will return an expression for // a definite length, e.g. with optional field checks in the expression // This is useful for definite-length serialization - pub fn definite_info(&self, types: &IntermediateTypes) -> Option { + pub fn definite_info(&self, types: &IntermediateTypes) -> String { match &self.variant { RustStructType::Record(record) => record.definite_info(types), - RustStructType::Table{ .. } => Some(String::from("self.0.len() as u64")), - RustStructType::Array{ .. } => Some(String::from("self.0.len() as u64")), - //RustStructType::TypeChoice{ .. } => None, + RustStructType::Table{ .. } => String::from("self.0.len() as u64"), + RustStructType::Array{ .. } => String::from("self.0.len() as u64"), RustStructType::TypeChoice{ .. } => unreachable!("I don't think type choices should be using length?"), RustStructType::GroupChoice{ .. } => unreachable!("I don't think group choices should be using length?"), RustStructType::Wrapper{ .. } => unreachable!("wrapper types don't use length"), @@ -1651,9 +1649,9 @@ impl RustRecord { } // This is guaranteed - pub fn definite_info(&self, types: &IntermediateTypes) -> Option { + pub fn definite_info(&self, types: &IntermediateTypes) -> String { match self.fixed_field_count(types) { - Some(count) => Some(count.to_string()), + Some(count) => count.to_string(), None => { let mut fixed_field_count = 0; let mut conditional_field_expr = String::new(); @@ -1663,7 +1661,7 @@ impl RustRecord { conditional_field_expr.push_str(" + "); } let (field_expr, field_contribution) = match self.rep { - Representation::Array => ("x", field.rust_type.conceptual_type.definite_info("x", types)?), + Representation::Array => ("x", field.rust_type.conceptual_type.definite_info("x", types)), // maps are defined by their keys instead (although they shouldn't have multi-length values either...) Representation::Map => ("_", String::from("1")), }; @@ -1693,7 +1691,7 @@ impl RustRecord { if !conditional_field_expr.is_empty() { conditional_field_expr.push_str(" + "); } - let field_len_expr = field.rust_type.conceptual_type.definite_info(&format!("self.{}", field.name), types)?; + let field_len_expr = field.rust_type.conceptual_type.definite_info(&format!("self.{}", field.name), types); conditional_field_expr.push_str(&field_len_expr); }, }, @@ -1704,9 +1702,9 @@ impl RustRecord { } } if conditional_field_expr.is_empty() || fixed_field_count != 0 { - Some(format!("{} + {}", fixed_field_count.to_string(), conditional_field_expr)) + format!("{} + {}", fixed_field_count.to_string(), conditional_field_expr) } else { - Some(conditional_field_expr) + conditional_field_expr } } } diff --git a/src/parsing.rs b/src/parsing.rs index 1e151aa..be21d94 100644 --- a/src/parsing.rs +++ b/src/parsing.rs @@ -695,9 +695,12 @@ fn rust_type_from_type2(types: &mut IntermediateTypes, parent_visitor: &ParentVi // array of elements with choices: enums? _ => unimplemented!("group choices in array type not supported"), }; - + //let array_wrapper_name = element_type.name_as_wasm_array(); //types.create_and_register_array_type(element_type, &array_wrapper_name) + if let ConceptualRustType::Rust(element_ident) = &element_type.conceptual_type { + types.set_rep_if_plain_group(parent_visitor, element_ident, Representation::Array); + } ConceptualRustType::Array(Box::new(element_type)).into() }, Type2::Map { group, .. } => { diff --git a/tests/core/input.cddl b/tests/core/input.cddl index 3ab454e..6e1a852 100644 --- a/tests/core/input.cddl +++ b/tests/core/input.cddl @@ -18,6 +18,11 @@ bar = { plain = (d: #6.23(uint), e: tagged_text) outer = [a: uint, b: plain, c: "some text"] +plain_arrays = [ +; this is not supported right now. When single-element arrays are supported remove this. +; single: [plain], + multi: [*plain], +] table = { * uint => text } diff --git a/tests/core/tests.rs b/tests/core/tests.rs index a058042..5d70996 100644 --- a/tests/core/tests.rs +++ b/tests/core/tests.rs @@ -3,10 +3,13 @@ mod tests { use super::*; fn deser_test(orig: &T) { - print_cbor_types("orig", &orig.to_bytes()); - let deser = T::deserialize(&mut Deserializer::from(std::io::Cursor::new(orig.to_bytes()))).unwrap(); + let orig_bytes = orig.to_bytes(); + print_cbor_types("orig", &orig_bytes); + let mut deserializer = Deserializer::from(std::io::Cursor::new(orig_bytes.clone())); + let deser = T::deserialize(&mut deserializer).unwrap(); print_cbor_types("deser", &deser.to_bytes()); assert_eq!(orig.to_bytes(), deser.to_bytes()); + assert_eq!(deserializer.as_ref().position(), orig_bytes.len() as u64); } #[test] @@ -41,6 +44,12 @@ mod tests { deser_test(&Plain::new(7576, String::from("wiorurri34h").into())); } + #[test] + fn plain_arrays() { + let plain = Plain::new(7576, String::from("wiorurri34h").into()); + deser_test(&PlainArrays::new(vec![plain.clone(), plain.clone()])); + } + #[test] fn outer() { deser_test(&Outer::new(2143254, Plain::new(7576, String::from("wiorurri34h").into())));