From 7773c9996453d6d43c306ee2026126ff6e067380 Mon Sep 17 00:00:00 2001 From: Ronald Holshausen Date: Tue, 23 Aug 2022 16:22:32 +1000 Subject: [PATCH] feat: support decoding packed repeated fields --- src/message_builder.rs | 13 +-- src/message_decoder.rs | 190 ++++++++++++++++++++++++++++++++++------- src/utils.rs | 28 +++++- 3 files changed, 186 insertions(+), 45 deletions(-) diff --git a/src/message_builder.rs b/src/message_builder.rs index 196fdb2..d31ca33 100644 --- a/src/message_builder.rs +++ b/src/message_builder.rs @@ -13,7 +13,7 @@ use prost_types::{DescriptorProto, FieldDescriptorProto, FileDescriptorProto}; use prost_types::field_descriptor_proto::Type; use tracing::trace; -use crate::utils::last_name; +use crate::utils::{last_name, should_be_packed_type}; /// Enum to set what type of field the value is for #[derive(Clone, Copy, Debug, PartialEq)] @@ -293,7 +293,7 @@ impl MessageBuilder { fn encode_repeated_field(&self, buffer: &mut BytesMut, field_value: &FieldValueInner) -> anyhow::Result<()> { trace!(">> encode_repeated_field({:?})", field_value); if !field_value.values.is_empty() { - if should_be_packed(field_value) { + if should_be_packed_type(field_value.proto_type) { self.encode_packed_field(buffer, field_value)?; } else { for value in &field_value.values { @@ -425,15 +425,6 @@ impl MessageBuilder { } } -fn should_be_packed(field: &FieldValueInner) -> bool { - match field.proto_type { - Type::Double | Type::Float | Type::Int64 | Type::Uint64 | Type::Int32 | Type::Fixed64 | - Type::Fixed32 | Type::Uint32 | Type::Sfixed32 | Type::Sfixed64 | Type::Sint32 | - Type::Sint64 => true, - _ => false - } -} - fn field_type_name(field: &FieldValueInner) -> anyhow::Result { Ok(match field.proto_type { Type::Double => "double".to_string(), diff --git a/src/message_decoder.rs b/src/message_decoder.rs index 60f8ec8..84cff7a 100644 --- a/src/message_decoder.rs +++ b/src/message_decoder.rs @@ -4,14 +4,14 @@ use std::fmt::{Display, Formatter}; use std::str::from_utf8; use anyhow::anyhow; -use bytes::{Buf, BytesMut}; +use bytes::{Buf, Bytes, BytesMut}; use itertools::Itertools; use prost::encoding::{decode_key, decode_varint, encode_varint, WireType}; use prost_types::{DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorSet}; use prost_types::field_descriptor_proto::Type; -use tracing::{error, trace, warn}; +use tracing::{debug, error, trace, warn}; -use crate::utils::{as_hex, find_message_type_by_name, last_name}; +use crate::utils::{as_hex, find_message_type_by_name, is_repeated_field, last_name, should_be_packed_type}; /// Decoded Protobuf field #[derive(Clone, Debug, PartialEq)] @@ -211,51 +211,51 @@ pub fn decode_message( let varint = decode_varint(buffer)?; let t: Type = field_descriptor.r#type(); match t { - Type::Int64 => ProtobufFieldData::Integer64(varint as i64), - Type::Uint64 => ProtobufFieldData::UInteger64(varint), - Type::Int32 => ProtobufFieldData::Integer32(varint as i32), - Type::Bool => ProtobufFieldData::Boolean(varint > 0), - Type::Uint32 => ProtobufFieldData::UInteger32(varint as u32), + Type::Int64 => vec![ (ProtobufFieldData::Integer64(varint as i64), wire_type) ], + Type::Uint64 => vec![ (ProtobufFieldData::UInteger64(varint), wire_type) ], + Type::Int32 => vec![ (ProtobufFieldData::Integer32(varint as i32), wire_type) ], + Type::Bool => vec![ (ProtobufFieldData::Boolean(varint > 0), wire_type) ], + Type::Uint32 => vec![ (ProtobufFieldData::UInteger32(varint as u32), wire_type) ], Type::Enum => { let enum_proto = descriptor.enum_type.iter() .find(|enum_type| enum_type.name.clone().unwrap_or_default() == last_name(field_descriptor.type_name.clone().unwrap_or_default().as_str())) .ok_or_else(|| anyhow!("Did not find the enum {:?} for the field {} in the Protobuf descriptor", field_descriptor.type_name, field_num))?; - ProtobufFieldData::Enum(varint as i32, enum_proto.clone()) + vec![ (ProtobufFieldData::Enum(varint as i32, enum_proto.clone()), wire_type) ] }, Type::Sint32 => { let value = varint as u32; - ProtobufFieldData::Integer32(((value >> 1) as i32) ^ (-((value & 1) as i32))) + vec![ (ProtobufFieldData::Integer32(((value >> 1) as i32) ^ (-((value & 1) as i32))), wire_type) ] }, - Type::Sint64 => ProtobufFieldData::Integer64(((varint >> 1) as i64) ^ (-((varint & 1) as i64))), + Type::Sint64 => vec![ (ProtobufFieldData::Integer64(((varint >> 1) as i64) ^ (-((varint & 1) as i64))), wire_type) ], _ => { error!("Was expecting {:?} but received an unknown varint type", t); - ProtobufFieldData::Unknown(varint.to_le_bytes().to_vec()) + vec![ (ProtobufFieldData::Unknown(varint.to_le_bytes().to_vec()), wire_type) ] } } } WireType::SixtyFourBit => { let t: Type = field_descriptor.r#type(); match t { - Type::Double => ProtobufFieldData::Double(buffer.get_f64_le()), - Type::Fixed64 => ProtobufFieldData::UInteger64(buffer.get_u64_le()), - Type::Sfixed64 => ProtobufFieldData::Integer64(buffer.get_i64_le()), + Type::Double => vec![ (ProtobufFieldData::Double(buffer.get_f64_le()), wire_type) ], + Type::Fixed64 => vec![ (ProtobufFieldData::UInteger64(buffer.get_u64_le()), wire_type) ], + Type::Sfixed64 => vec![ (ProtobufFieldData::Integer64(buffer.get_i64_le()), wire_type) ], _ => { error!("Was expecting {:?} but received an unknown 64 bit type", t); let value = buffer.get_u64_le(); - ProtobufFieldData::Unknown(value.to_le_bytes().to_vec()) + vec![ (ProtobufFieldData::Unknown(value.to_le_bytes().to_vec()), wire_type) ] } } } WireType::LengthDelimited => { let data_length = decode_varint(buffer)?; - let data_buffer = if buffer.remaining() >= data_length as usize { + let mut data_buffer = if buffer.remaining() >= data_length as usize { buffer.copy_to_bytes(data_length as usize) } else { return Err(anyhow!("Insufficient data remaining ({} bytes) to read {} bytes for field {}", buffer.remaining(), data_length, field_num)); }; let t: Type = field_descriptor.r#type(); match t { - Type::String => ProtobufFieldData::String(from_utf8(&data_buffer)?.to_string()), + Type::String => vec![ (ProtobufFieldData::String(from_utf8(&data_buffer)?.to_string()), wire_type) ], Type::Message => { let type_name = field_descriptor.type_name.as_ref().map(|v| last_name(v.as_str()).to_string()); let message_proto = descriptor.nested_type.iter() @@ -263,28 +263,31 @@ pub fn decode_message( .cloned() .or_else(|| find_message_type_by_name(&type_name.unwrap_or_default(), descriptors).ok()) .ok_or_else(|| anyhow!("Did not find the embedded message {:?} for the field {} in the Protobuf descriptor", field_descriptor.type_name, field_num))?; - ProtobufFieldData::Message(data_buffer.to_vec(), message_proto) + vec![ (ProtobufFieldData::Message(data_buffer.to_vec(), message_proto), wire_type) ] } - Type::Bytes => ProtobufFieldData::Bytes(data_buffer.to_vec()), - _ => { + Type::Bytes => vec![ (ProtobufFieldData::Bytes(data_buffer.to_vec()), wire_type) ], + _ => if should_be_packed_type(t) && is_repeated_field(&field_descriptor) { + debug!("Reading length delimited field as a packed repeated field"); + decode_packed_field(field_descriptor, &mut data_buffer)? + } else { error!("Was expecting {:?} but received an unknown length-delimited type", t); let mut buf = BytesMut::with_capacity((data_length + 8) as usize); encode_varint(data_length, &mut buf); buf.extend_from_slice(&*data_buffer); - ProtobufFieldData::Unknown(buf.freeze().to_vec()) + vec![ (ProtobufFieldData::Unknown(buf.freeze().to_vec()), wire_type) ] } } } WireType::ThirtyTwoBit => { let t: Type = field_descriptor.r#type(); match t { - Type::Float => ProtobufFieldData::Float(buffer.get_f32_le()), - Type::Fixed32 => ProtobufFieldData::UInteger32(buffer.get_u32_le()), - Type::Sfixed32 => ProtobufFieldData::Integer32(buffer.get_i32_le()), + Type::Float => vec![ (ProtobufFieldData::Float(buffer.get_f32_le()), wire_type) ], + Type::Fixed32 => vec![ (ProtobufFieldData::UInteger32(buffer.get_u32_le()), wire_type) ], + Type::Sfixed32 => vec![ (ProtobufFieldData::Integer32(buffer.get_i32_le()), wire_type) ], _ => { error!("Was expecting {:?} but received an unknown fixed 32 bit type", t); let value = buffer.get_u32_le(); - ProtobufFieldData::Unknown(value.to_le_bytes().to_vec()) + vec![ (ProtobufFieldData::Unknown(value.to_le_bytes().to_vec()), wire_type) ] } } } @@ -292,11 +295,13 @@ pub fn decode_message( }; trace!(field_num, ?wire_type, ?data, "read field, bytes remaining = {}", buffer.remaining()); - fields.push(ProtobufField { - field_num, - wire_type, - data - }); + for (data, wire_type) in data { + fields.push(ProtobufField { + field_num, + wire_type, + data + }); + } } Err(err) => { warn!("Was not able to decode field: {}", err); @@ -325,6 +330,82 @@ pub fn decode_message( Ok(fields.iter().sorted_by(|a, b| Ord::cmp(&a.field_num, &b.field_num)).cloned().collect()) } +fn decode_packed_field(field: FieldDescriptorProto, data: &mut Bytes) -> anyhow::Result> { + let mut values = vec![]; + let t: Type = field.r#type(); + match t { + Type::Double => { + while data.has_remaining() { + values.push((ProtobufFieldData::Double(data.get_f64_le()), WireType::SixtyFourBit)); + } + } + Type::Float => { + while data.has_remaining() { + values.push((ProtobufFieldData::Float(data.get_f32_le()), WireType::ThirtyTwoBit)); + } + } + Type::Int64 => { + while data.has_remaining() { + let varint = decode_varint(data)?; + values.push((ProtobufFieldData::Integer64(varint as i64), WireType::Varint)); + } + } + Type::Uint64 => { + while data.has_remaining() { + let varint = decode_varint(data)?; + values.push((ProtobufFieldData::UInteger64(varint), WireType::Varint)); + } + } + Type::Int32 => { + while data.has_remaining() { + let varint = decode_varint(data)?; + values.push((ProtobufFieldData::Integer32(varint as i32), WireType::Varint)); + } + } + Type::Fixed64 => { + while data.has_remaining() { + values.push((ProtobufFieldData::UInteger64(data.get_u64_le()), WireType::SixtyFourBit)); + } + } + Type::Fixed32 => { + while data.has_remaining() { + values.push((ProtobufFieldData::UInteger32(data.get_u32_le()), WireType::ThirtyTwoBit)); + } + } + Type::Uint32 => { + while data.has_remaining() { + let varint = decode_varint(data)?; + values.push((ProtobufFieldData::UInteger32(varint as u32), WireType::Varint)); + } + } + Type::Sfixed32 => { + while data.has_remaining() { + values.push((ProtobufFieldData::Integer32(data.get_i32_le()), WireType::ThirtyTwoBit)); + } + } + Type::Sfixed64 => { + while data.has_remaining() { + values.push((ProtobufFieldData::Integer64(data.get_i64_le()), WireType::SixtyFourBit)); + } + } + Type::Sint32 => { + while data.has_remaining() { + let varint = decode_varint(data)?; + let value = varint as u32; + values.push((ProtobufFieldData::Integer32(((value >> 1) as i32) ^ (-((value & 1) as i32))), WireType::Varint)); + } + } + Type::Sint64 => { + while data.has_remaining() { + let varint = decode_varint(data)?; + values.push((ProtobufFieldData::Integer64(((varint >> 1) as i64) ^ (-((varint & 1) as i64))), WireType::Varint)); + } + } + _ => return Err(anyhow!("Field type {:?} can not be packed", t)) + }; + Ok(values) +} + fn find_field_descriptor(field_num: i32, descriptor: &DescriptorProto) -> anyhow::Result { descriptor.field.iter().find(|field| { if let Some(num) = field.number { @@ -982,4 +1063,51 @@ mod tests { }; expect!(ProtobufFieldData::Unknown(vec![1, 2, 3, 4]).default_field_value(&descriptor)).to(be_equal_to(ProtobufFieldData::Unknown(vec![]))); } + + #[test] + fn decode_packed_field() { + let f_value: f32 = 12.0; + let f_value2: f32 = 9.0; + let mut buffer = BytesMut::new(); + buffer.put_u8(10); + buffer.put_u8(8); + buffer.put_f32_le(f_value); + buffer.put_f32_le(f_value2); + + let descriptor = DescriptorProto { + name: Some("PackedFieldMessage".to_string()), + field: vec![ + prost_types::FieldDescriptorProto { + name: Some("field_1".to_string()), + number: Some(1), + label: Some(prost_types::field_descriptor_proto::Label::Repeated as i32), + r#type: Some(prost_types::field_descriptor_proto::Type::Float as i32), + type_name: Some("Float".to_string()), + extendee: None, + default_value: None, + oneof_index: None, + json_name: None, + options: None, + proto3_optional: None + } + ], + extension: vec![], + nested_type: vec![], + enum_type: vec![], + extension_range: vec![], + oneof_decl: vec![], + options: None, + reserved_range: vec![], + reserved_name: vec![] + }; + + let result = decode_message(&mut buffer, &descriptor, &FileDescriptorSet{ file: vec![] }).unwrap(); + expect!(result.len()).to(be_equal_to(2)); + + let field_result = result.first().unwrap(); + + expect!(field_result.field_num).to(be_equal_to(1)); + expect!(field_result.wire_type).to(be_equal_to(WireType::ThirtyTwoBit)); + expect!(&field_result.data).to(be_equal_to(&ProtobufFieldData::Float(12.0))); + } } diff --git a/src/utils.rs b/src/utils.rs index a3d24aa..fc7c494 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -5,13 +5,24 @@ use std::fmt::Write; use anyhow::anyhow; use bytes::{Bytes, BytesMut}; +use field_descriptor_proto::Type; use maplit::hashmap; use pact_models::json_utils::json_to_string; use pact_models::pact::load_pact_from_json; use pact_models::prelude::v4::V4Pact; use pact_models::v4::interaction::V4Interaction; use prost::Message; -use prost_types::{DescriptorProto, EnumDescriptorProto, field_descriptor_proto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet, MethodDescriptorProto, ServiceDescriptorProto, Value}; +use prost_types::{ + DescriptorProto, + EnumDescriptorProto, + field_descriptor_proto, + FieldDescriptorProto, + FileDescriptorProto, + FileDescriptorSet, + MethodDescriptorProto, + ServiceDescriptorProto, + Value +}; use prost_types::field_descriptor_proto::Label; use prost_types::value::Kind; use serde_json::json; @@ -50,7 +61,7 @@ pub fn find_message_type_in_file_descriptor(message_name: &str, descriptor: &Fil /// If the field is a map field. A field will be a map field if it is a repeated field, the field /// type is a message and the nested type has the map flag set on the message options. pub fn is_map_field(message_descriptor: &DescriptorProto, field: &FieldDescriptorProto) -> bool { - if field.label() == Label::Repeated && field.r#type() == field_descriptor_proto::Type::Message { + if field.label() == Label::Repeated && field.r#type() == Type::Message { match find_nested_type(message_descriptor, field) { Some(nested) => match nested.options { None => false, @@ -66,7 +77,7 @@ pub fn is_map_field(message_descriptor: &DescriptorProto, field: &FieldDescripto /// Returns the nested descriptor for this field. pub fn find_nested_type(message_descriptor: &DescriptorProto, field: &FieldDescriptorProto) -> Option { trace!(">> find_nested_type({:?}, {:?}, {:?}, {:?})", message_descriptor.name, field.name, field.r#type(), field.type_name); - if field.r#type() == field_descriptor_proto::Type::Message { + if field.r#type() == Type::Message { let type_name = field.type_name.clone().unwrap_or_default(); let message_type = last_name(type_name.as_str()); trace!("find_nested_type: Looking for nested type '{}'", message_type); @@ -343,6 +354,17 @@ pub(crate) fn find_service_descriptor<'a>( .ok_or_else(|| anyhow!("Did not find a descriptor for service '{}'", service_name)) } +/// If a field type should be packed. These are repeated fields of primitive numeric types +/// (types which use the varint, 32-bit, or 64-bit wire types) +pub fn should_be_packed_type(field_type: Type) -> bool { + match field_type { + Type::Double | Type::Float | Type::Int64 | Type::Uint64 | Type::Int32 | Type::Fixed64 | + Type::Fixed32 | Type::Uint32 | Type::Sfixed32 | Type::Sfixed64 | Type::Sint32 | + Type::Sint64 => true, + _ => false + } +} + #[cfg(test)] pub(crate) mod tests { use bytes::Bytes;