Skip to content

Commit

Permalink
feat: support decoding packed repeated fields
Browse files Browse the repository at this point in the history
  • Loading branch information
uglyog committed Aug 23, 2022
1 parent c65deb4 commit 7773c99
Show file tree
Hide file tree
Showing 3 changed files with 186 additions and 45 deletions.
13 changes: 2 additions & 11 deletions src/message_builder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ use prost_types::{DescriptorProto, FieldDescriptorProto, FileDescriptorProto};
use prost_types::field_descriptor_proto::Type;
use tracing::trace;

use crate::utils::last_name;
use crate::utils::{last_name, should_be_packed_type};

/// Enum to set what type of field the value is for
#[derive(Clone, Copy, Debug, PartialEq)]
Expand Down Expand Up @@ -293,7 +293,7 @@ impl MessageBuilder {
fn encode_repeated_field(&self, buffer: &mut BytesMut, field_value: &FieldValueInner) -> anyhow::Result<()> {
trace!(">> encode_repeated_field({:?})", field_value);
if !field_value.values.is_empty() {
if should_be_packed(field_value) {
if should_be_packed_type(field_value.proto_type) {
self.encode_packed_field(buffer, field_value)?;
} else {
for value in &field_value.values {
Expand Down Expand Up @@ -425,15 +425,6 @@ impl MessageBuilder {
}
}

fn should_be_packed(field: &FieldValueInner) -> bool {
match field.proto_type {
Type::Double | Type::Float | Type::Int64 | Type::Uint64 | Type::Int32 | Type::Fixed64 |
Type::Fixed32 | Type::Uint32 | Type::Sfixed32 | Type::Sfixed64 | Type::Sint32 |
Type::Sint64 => true,
_ => false
}
}

fn field_type_name(field: &FieldValueInner) -> anyhow::Result<String> {
Ok(match field.proto_type {
Type::Double => "double".to_string(),
Expand Down
190 changes: 159 additions & 31 deletions src/message_decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,14 @@ use std::fmt::{Display, Formatter};
use std::str::from_utf8;

use anyhow::anyhow;
use bytes::{Buf, BytesMut};
use bytes::{Buf, Bytes, BytesMut};
use itertools::Itertools;
use prost::encoding::{decode_key, decode_varint, encode_varint, WireType};
use prost_types::{DescriptorProto, EnumDescriptorProto, FieldDescriptorProto, FileDescriptorSet};
use prost_types::field_descriptor_proto::Type;
use tracing::{error, trace, warn};
use tracing::{debug, error, trace, warn};

use crate::utils::{as_hex, find_message_type_by_name, last_name};
use crate::utils::{as_hex, find_message_type_by_name, is_repeated_field, last_name, should_be_packed_type};

/// Decoded Protobuf field
#[derive(Clone, Debug, PartialEq)]
Expand Down Expand Up @@ -211,92 +211,97 @@ pub fn decode_message<B>(
let varint = decode_varint(buffer)?;
let t: Type = field_descriptor.r#type();
match t {
Type::Int64 => ProtobufFieldData::Integer64(varint as i64),
Type::Uint64 => ProtobufFieldData::UInteger64(varint),
Type::Int32 => ProtobufFieldData::Integer32(varint as i32),
Type::Bool => ProtobufFieldData::Boolean(varint > 0),
Type::Uint32 => ProtobufFieldData::UInteger32(varint as u32),
Type::Int64 => vec![ (ProtobufFieldData::Integer64(varint as i64), wire_type) ],
Type::Uint64 => vec![ (ProtobufFieldData::UInteger64(varint), wire_type) ],
Type::Int32 => vec![ (ProtobufFieldData::Integer32(varint as i32), wire_type) ],
Type::Bool => vec![ (ProtobufFieldData::Boolean(varint > 0), wire_type) ],
Type::Uint32 => vec![ (ProtobufFieldData::UInteger32(varint as u32), wire_type) ],
Type::Enum => {
let enum_proto = descriptor.enum_type.iter()
.find(|enum_type| enum_type.name.clone().unwrap_or_default() == last_name(field_descriptor.type_name.clone().unwrap_or_default().as_str()))
.ok_or_else(|| anyhow!("Did not find the enum {:?} for the field {} in the Protobuf descriptor", field_descriptor.type_name, field_num))?;
ProtobufFieldData::Enum(varint as i32, enum_proto.clone())
vec![ (ProtobufFieldData::Enum(varint as i32, enum_proto.clone()), wire_type) ]
},
Type::Sint32 => {
let value = varint as u32;
ProtobufFieldData::Integer32(((value >> 1) as i32) ^ (-((value & 1) as i32)))
vec![ (ProtobufFieldData::Integer32(((value >> 1) as i32) ^ (-((value & 1) as i32))), wire_type) ]
},
Type::Sint64 => ProtobufFieldData::Integer64(((varint >> 1) as i64) ^ (-((varint & 1) as i64))),
Type::Sint64 => vec![ (ProtobufFieldData::Integer64(((varint >> 1) as i64) ^ (-((varint & 1) as i64))), wire_type) ],
_ => {
error!("Was expecting {:?} but received an unknown varint type", t);
ProtobufFieldData::Unknown(varint.to_le_bytes().to_vec())
vec![ (ProtobufFieldData::Unknown(varint.to_le_bytes().to_vec()), wire_type) ]
}
}
}
WireType::SixtyFourBit => {
let t: Type = field_descriptor.r#type();
match t {
Type::Double => ProtobufFieldData::Double(buffer.get_f64_le()),
Type::Fixed64 => ProtobufFieldData::UInteger64(buffer.get_u64_le()),
Type::Sfixed64 => ProtobufFieldData::Integer64(buffer.get_i64_le()),
Type::Double => vec![ (ProtobufFieldData::Double(buffer.get_f64_le()), wire_type) ],
Type::Fixed64 => vec![ (ProtobufFieldData::UInteger64(buffer.get_u64_le()), wire_type) ],
Type::Sfixed64 => vec![ (ProtobufFieldData::Integer64(buffer.get_i64_le()), wire_type) ],
_ => {
error!("Was expecting {:?} but received an unknown 64 bit type", t);
let value = buffer.get_u64_le();
ProtobufFieldData::Unknown(value.to_le_bytes().to_vec())
vec![ (ProtobufFieldData::Unknown(value.to_le_bytes().to_vec()), wire_type) ]
}
}
}
WireType::LengthDelimited => {
let data_length = decode_varint(buffer)?;
let data_buffer = if buffer.remaining() >= data_length as usize {
let mut data_buffer = if buffer.remaining() >= data_length as usize {
buffer.copy_to_bytes(data_length as usize)
} else {
return Err(anyhow!("Insufficient data remaining ({} bytes) to read {} bytes for field {}", buffer.remaining(), data_length, field_num));
};
let t: Type = field_descriptor.r#type();
match t {
Type::String => ProtobufFieldData::String(from_utf8(&data_buffer)?.to_string()),
Type::String => vec![ (ProtobufFieldData::String(from_utf8(&data_buffer)?.to_string()), wire_type) ],
Type::Message => {
let type_name = field_descriptor.type_name.as_ref().map(|v| last_name(v.as_str()).to_string());
let message_proto = descriptor.nested_type.iter()
.find(|message_descriptor| message_descriptor.name == type_name)
.cloned()
.or_else(|| find_message_type_by_name(&type_name.unwrap_or_default(), descriptors).ok())
.ok_or_else(|| anyhow!("Did not find the embedded message {:?} for the field {} in the Protobuf descriptor", field_descriptor.type_name, field_num))?;
ProtobufFieldData::Message(data_buffer.to_vec(), message_proto)
vec![ (ProtobufFieldData::Message(data_buffer.to_vec(), message_proto), wire_type) ]
}
Type::Bytes => ProtobufFieldData::Bytes(data_buffer.to_vec()),
_ => {
Type::Bytes => vec![ (ProtobufFieldData::Bytes(data_buffer.to_vec()), wire_type) ],
_ => if should_be_packed_type(t) && is_repeated_field(&field_descriptor) {
debug!("Reading length delimited field as a packed repeated field");
decode_packed_field(field_descriptor, &mut data_buffer)?
} else {
error!("Was expecting {:?} but received an unknown length-delimited type", t);
let mut buf = BytesMut::with_capacity((data_length + 8) as usize);
encode_varint(data_length, &mut buf);
buf.extend_from_slice(&*data_buffer);
ProtobufFieldData::Unknown(buf.freeze().to_vec())
vec![ (ProtobufFieldData::Unknown(buf.freeze().to_vec()), wire_type) ]
}
}
}
WireType::ThirtyTwoBit => {
let t: Type = field_descriptor.r#type();
match t {
Type::Float => ProtobufFieldData::Float(buffer.get_f32_le()),
Type::Fixed32 => ProtobufFieldData::UInteger32(buffer.get_u32_le()),
Type::Sfixed32 => ProtobufFieldData::Integer32(buffer.get_i32_le()),
Type::Float => vec![ (ProtobufFieldData::Float(buffer.get_f32_le()), wire_type) ],
Type::Fixed32 => vec![ (ProtobufFieldData::UInteger32(buffer.get_u32_le()), wire_type) ],
Type::Sfixed32 => vec![ (ProtobufFieldData::Integer32(buffer.get_i32_le()), wire_type) ],
_ => {
error!("Was expecting {:?} but received an unknown fixed 32 bit type", t);
let value = buffer.get_u32_le();
ProtobufFieldData::Unknown(value.to_le_bytes().to_vec())
vec![ (ProtobufFieldData::Unknown(value.to_le_bytes().to_vec()), wire_type) ]
}
}
}
_ => return Err(anyhow!("Messages with {:?} wire type fields are not supported", wire_type))
};

trace!(field_num, ?wire_type, ?data, "read field, bytes remaining = {}", buffer.remaining());
fields.push(ProtobufField {
field_num,
wire_type,
data
});
for (data, wire_type) in data {
fields.push(ProtobufField {
field_num,
wire_type,
data
});
}
}
Err(err) => {
warn!("Was not able to decode field: {}", err);
Expand Down Expand Up @@ -325,6 +330,82 @@ pub fn decode_message<B>(
Ok(fields.iter().sorted_by(|a, b| Ord::cmp(&a.field_num, &b.field_num)).cloned().collect())
}

fn decode_packed_field(field: FieldDescriptorProto, data: &mut Bytes) -> anyhow::Result<Vec<(ProtobufFieldData, WireType)>> {
let mut values = vec![];
let t: Type = field.r#type();
match t {
Type::Double => {
while data.has_remaining() {
values.push((ProtobufFieldData::Double(data.get_f64_le()), WireType::SixtyFourBit));
}
}
Type::Float => {
while data.has_remaining() {
values.push((ProtobufFieldData::Float(data.get_f32_le()), WireType::ThirtyTwoBit));
}
}
Type::Int64 => {
while data.has_remaining() {
let varint = decode_varint(data)?;
values.push((ProtobufFieldData::Integer64(varint as i64), WireType::Varint));
}
}
Type::Uint64 => {
while data.has_remaining() {
let varint = decode_varint(data)?;
values.push((ProtobufFieldData::UInteger64(varint), WireType::Varint));
}
}
Type::Int32 => {
while data.has_remaining() {
let varint = decode_varint(data)?;
values.push((ProtobufFieldData::Integer32(varint as i32), WireType::Varint));
}
}
Type::Fixed64 => {
while data.has_remaining() {
values.push((ProtobufFieldData::UInteger64(data.get_u64_le()), WireType::SixtyFourBit));
}
}
Type::Fixed32 => {
while data.has_remaining() {
values.push((ProtobufFieldData::UInteger32(data.get_u32_le()), WireType::ThirtyTwoBit));
}
}
Type::Uint32 => {
while data.has_remaining() {
let varint = decode_varint(data)?;
values.push((ProtobufFieldData::UInteger32(varint as u32), WireType::Varint));
}
}
Type::Sfixed32 => {
while data.has_remaining() {
values.push((ProtobufFieldData::Integer32(data.get_i32_le()), WireType::ThirtyTwoBit));
}
}
Type::Sfixed64 => {
while data.has_remaining() {
values.push((ProtobufFieldData::Integer64(data.get_i64_le()), WireType::SixtyFourBit));
}
}
Type::Sint32 => {
while data.has_remaining() {
let varint = decode_varint(data)?;
let value = varint as u32;
values.push((ProtobufFieldData::Integer32(((value >> 1) as i32) ^ (-((value & 1) as i32))), WireType::Varint));
}
}
Type::Sint64 => {
while data.has_remaining() {
let varint = decode_varint(data)?;
values.push((ProtobufFieldData::Integer64(((varint >> 1) as i64) ^ (-((varint & 1) as i64))), WireType::Varint));
}
}
_ => return Err(anyhow!("Field type {:?} can not be packed", t))
};
Ok(values)
}

fn find_field_descriptor(field_num: i32, descriptor: &DescriptorProto) -> anyhow::Result<FieldDescriptorProto> {
descriptor.field.iter().find(|field| {
if let Some(num) = field.number {
Expand Down Expand Up @@ -982,4 +1063,51 @@ mod tests {
};
expect!(ProtobufFieldData::Unknown(vec![1, 2, 3, 4]).default_field_value(&descriptor)).to(be_equal_to(ProtobufFieldData::Unknown(vec![])));
}

#[test]
fn decode_packed_field() {
let f_value: f32 = 12.0;
let f_value2: f32 = 9.0;
let mut buffer = BytesMut::new();
buffer.put_u8(10);
buffer.put_u8(8);
buffer.put_f32_le(f_value);
buffer.put_f32_le(f_value2);

let descriptor = DescriptorProto {
name: Some("PackedFieldMessage".to_string()),
field: vec![
prost_types::FieldDescriptorProto {
name: Some("field_1".to_string()),
number: Some(1),
label: Some(prost_types::field_descriptor_proto::Label::Repeated as i32),
r#type: Some(prost_types::field_descriptor_proto::Type::Float as i32),
type_name: Some("Float".to_string()),
extendee: None,
default_value: None,
oneof_index: None,
json_name: None,
options: None,
proto3_optional: None
}
],
extension: vec![],
nested_type: vec![],
enum_type: vec![],
extension_range: vec![],
oneof_decl: vec![],
options: None,
reserved_range: vec![],
reserved_name: vec![]
};

let result = decode_message(&mut buffer, &descriptor, &FileDescriptorSet{ file: vec![] }).unwrap();
expect!(result.len()).to(be_equal_to(2));

let field_result = result.first().unwrap();

expect!(field_result.field_num).to(be_equal_to(1));
expect!(field_result.wire_type).to(be_equal_to(WireType::ThirtyTwoBit));
expect!(&field_result.data).to(be_equal_to(&ProtobufFieldData::Float(12.0)));
}
}
28 changes: 25 additions & 3 deletions src/utils.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,24 @@ use std::fmt::Write;

use anyhow::anyhow;
use bytes::{Bytes, BytesMut};
use field_descriptor_proto::Type;
use maplit::hashmap;
use pact_models::json_utils::json_to_string;
use pact_models::pact::load_pact_from_json;
use pact_models::prelude::v4::V4Pact;
use pact_models::v4::interaction::V4Interaction;
use prost::Message;
use prost_types::{DescriptorProto, EnumDescriptorProto, field_descriptor_proto, FieldDescriptorProto, FileDescriptorProto, FileDescriptorSet, MethodDescriptorProto, ServiceDescriptorProto, Value};
use prost_types::{
DescriptorProto,
EnumDescriptorProto,
field_descriptor_proto,
FieldDescriptorProto,
FileDescriptorProto,
FileDescriptorSet,
MethodDescriptorProto,
ServiceDescriptorProto,
Value
};
use prost_types::field_descriptor_proto::Label;
use prost_types::value::Kind;
use serde_json::json;
Expand Down Expand Up @@ -50,7 +61,7 @@ pub fn find_message_type_in_file_descriptor(message_name: &str, descriptor: &Fil
/// If the field is a map field. A field will be a map field if it is a repeated field, the field
/// type is a message and the nested type has the map flag set on the message options.
pub fn is_map_field(message_descriptor: &DescriptorProto, field: &FieldDescriptorProto) -> bool {
if field.label() == Label::Repeated && field.r#type() == field_descriptor_proto::Type::Message {
if field.label() == Label::Repeated && field.r#type() == Type::Message {
match find_nested_type(message_descriptor, field) {
Some(nested) => match nested.options {
None => false,
Expand All @@ -66,7 +77,7 @@ pub fn is_map_field(message_descriptor: &DescriptorProto, field: &FieldDescripto
/// Returns the nested descriptor for this field.
pub fn find_nested_type(message_descriptor: &DescriptorProto, field: &FieldDescriptorProto) -> Option<DescriptorProto> {
trace!(">> find_nested_type({:?}, {:?}, {:?}, {:?})", message_descriptor.name, field.name, field.r#type(), field.type_name);
if field.r#type() == field_descriptor_proto::Type::Message {
if field.r#type() == Type::Message {
let type_name = field.type_name.clone().unwrap_or_default();
let message_type = last_name(type_name.as_str());
trace!("find_nested_type: Looking for nested type '{}'", message_type);
Expand Down Expand Up @@ -343,6 +354,17 @@ pub(crate) fn find_service_descriptor<'a>(
.ok_or_else(|| anyhow!("Did not find a descriptor for service '{}'", service_name))
}

/// If a field type should be packed. These are repeated fields of primitive numeric types
/// (types which use the varint, 32-bit, or 64-bit wire types)
pub fn should_be_packed_type(field_type: Type) -> bool {
match field_type {
Type::Double | Type::Float | Type::Int64 | Type::Uint64 | Type::Int32 | Type::Fixed64 |
Type::Fixed32 | Type::Uint32 | Type::Sfixed32 | Type::Sfixed64 | Type::Sint32 |
Type::Sint64 => true,
_ => false
}
}

#[cfg(test)]
pub(crate) mod tests {
use bytes::Bytes;
Expand Down

0 comments on commit 7773c99

Please sign in to comment.