feat(codecs): add support for protobuf decoding (#18019)
* feat(codecs): add support for protobuf decoding; WIP, still have some TODOs to resolve

* feat(codecs): add support for protobuf.
code-review fixes: handle unwraps and fix support for an empty buffer as a message, which is allowed in protobuf.

* feat(codecs): add support for protobuf.
code-review fixes: use `kind::any()` instead of `kind::json()`.
use `unimplemented!()` instead of `todo!()`.
in tests, add checks for List and Map.
in `ProtobufDeserializer::new`, refactor out creation of MessageDescriptor.
run `cargo fmt`.

* feat(codecs): add support for protobuf.
code-review fixes: apply the suggested refactor to `to_vrl`; it's slightly slower and might be improved in a follow-up PR.

* clippy fixes and minor refactoring

* address Bruce's comments

* update test code to use new log schema interface

* generate docs

---------

Co-authored-by: Pavlos Rontidis <pavlos.rontidis@gmail.com>
Daniel599 and pront authored Jul 28, 2023
1 parent 8a2f8f6 commit a06c711
Showing 32 changed files with 742 additions and 1 deletion.
4 changes: 4 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default.

1 change: 1 addition & 0 deletions lib/codecs/Cargo.toml
@@ -17,6 +17,7 @@ memchr = { version = "2", default-features = false }
once_cell = { version = "1.18", default-features = false }
ordered-float = { version = "3.7.0", default-features = false }
prost = { version = "0.11.8", default-features = false, features = ["std"] }
prost-reflect = { version = "0.11", default-features = false, features = ["serde"] }
regex = { version = "1.9.1", default-features = false, features = ["std", "perf"] }
serde = { version = "1", default-features = false, features = ["derive"] }
serde_json = { version = "1", default-features = false }
2 changes: 2 additions & 0 deletions lib/codecs/src/decoding/format/mod.rs
@@ -8,6 +8,7 @@ mod gelf;
mod json;
mod native;
mod native_json;
mod protobuf;
#[cfg(feature = "syslog")]
mod syslog;

@@ -19,6 +20,7 @@ pub use native::{NativeDeserializer, NativeDeserializerConfig};
pub use native_json::{
NativeJsonDeserializer, NativeJsonDeserializerConfig, NativeJsonDeserializerOptions,
};
pub use protobuf::{ProtobufDeserializer, ProtobufDeserializerConfig};
use smallvec::SmallVec;
#[cfg(feature = "syslog")]
pub use syslog::{SyslogDeserializer, SyslogDeserializerConfig, SyslogDeserializerOptions};
353 changes: 353 additions & 0 deletions lib/codecs/src/decoding/format/protobuf.rs
@@ -0,0 +1,353 @@
use std::collections::BTreeMap;
use std::fs;
use std::path::PathBuf;

use bytes::Bytes;
use chrono::Utc;
use ordered_float::NotNan;
use prost_reflect::{DescriptorPool, DynamicMessage, MessageDescriptor, ReflectMessage};
use smallvec::{smallvec, SmallVec};
use vector_config::configurable_component;
use vector_core::event::LogEvent;
use vector_core::{
config::{log_schema, DataType, LogNamespace},
event::Event,
schema,
};
use vrl::value::Kind;

use super::Deserializer;

/// Config used to build a `ProtobufDeserializer`.
#[configurable_component]
#[derive(Debug, Clone, Default)]
pub struct ProtobufDeserializerConfig {
/// Path to the protobuf descriptor set file.
desc_file: PathBuf,

/// The message type to decode, e.g. `package.message`.
message_type: String,
}

impl ProtobufDeserializerConfig {
/// Build the `ProtobufDeserializer` from this configuration.
pub fn build(&self) -> ProtobufDeserializer {
// TODO return a Result instead.
ProtobufDeserializer::try_from(self).unwrap()
}

/// Return the type of event built by this deserializer.
pub fn output_type(&self) -> DataType {
DataType::Log
}

/// The schema produced by the deserializer.
pub fn schema_definition(&self, log_namespace: LogNamespace) -> schema::Definition {
match log_namespace {
LogNamespace::Legacy => {
let mut definition =
schema::Definition::empty_legacy_namespace().unknown_fields(Kind::any());

if let Some(timestamp_key) = log_schema().timestamp_key() {
definition = definition.try_with_field(
timestamp_key,
// The protobuf decoder will try to insert a new `timestamp`-type value into the
// "timestamp_key" field, but only if that field doesn't already exist.
Kind::any().or_timestamp(),
Some("timestamp"),
);
}
definition
}
LogNamespace::Vector => {
schema::Definition::new_with_default_metadata(Kind::any(), [log_namespace])
}
}
}
}

/// Deserializer that builds `Event`s from a byte frame containing protobuf.
#[derive(Debug, Clone)]
pub struct ProtobufDeserializer {
message_descriptor: MessageDescriptor,
}

impl ProtobufDeserializer {
/// Creates a new `ProtobufDeserializer`.
pub fn new(message_descriptor: MessageDescriptor) -> Self {
Self { message_descriptor }
}

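/// Load a `MessageDescriptor` for `message_type` from the compiled protobuf
/// descriptor set at `desc_file` (as produced by `protoc --descriptor_set_out`).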
fn get_message_descriptor(
desc_file: &PathBuf,
message_type: String,
) -> vector_common::Result<MessageDescriptor> {
let b = fs::read(desc_file)
.map_err(|e| format!("Failed to open protobuf desc file '{desc_file:?}': {e}",))?;
let pool = DescriptorPool::decode(b.as_slice())
.map_err(|e| format!("Failed to parse protobuf desc file '{desc_file:?}': {e}"))?;
Ok(pool.get_message_by_name(&message_type).unwrap_or_else(|| {
panic!("The message type '{message_type}' could not be found in '{desc_file:?}'")
}))
}
}

impl Deserializer for ProtobufDeserializer {
fn parse(
&self,
bytes: Bytes,
log_namespace: LogNamespace,
) -> vector_common::Result<SmallVec<[Event; 1]>> {
let dynamic_message = DynamicMessage::decode(self.message_descriptor.clone(), bytes)
.map_err(|error| format!("Error parsing protobuf: {:?}", error))?;

let proto_vrl = to_vrl(&prost_reflect::Value::Message(dynamic_message), None)?;
let mut event = Event::Log(LogEvent::from(proto_vrl));
let event = match log_namespace {
LogNamespace::Vector => event,
LogNamespace::Legacy => {
let timestamp = Utc::now();
if let Some(timestamp_key) = log_schema().timestamp_key_target_path() {
let log = event.as_mut_log();
if !log.contains(timestamp_key) {
log.insert(timestamp_key, timestamp);
}
}
event
}
};

Ok(smallvec![event])
}
}

impl TryFrom<&ProtobufDeserializerConfig> for ProtobufDeserializer {
type Error = vector_common::Error;
fn try_from(config: &ProtobufDeserializerConfig) -> vector_common::Result<Self> {
let message_descriptor = ProtobufDeserializer::get_message_descriptor(
&config.desc_file,
config.message_type.clone(),
)?;
Ok(Self::new(message_descriptor))
}
}

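/// Recursively convert a decoded `prost_reflect::Value` into a `vrl::value::Value`.
/// `field_descriptor` provides the schema context needed to resolve enum names and
/// map entry value types; it is `None` only for the top-level message.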
fn to_vrl(
prost_reflect_value: &prost_reflect::Value,
field_descriptor: Option<&prost_reflect::FieldDescriptor>,
) -> vector_common::Result<vrl::value::Value> {
let vrl_value = match prost_reflect_value {
prost_reflect::Value::Bool(v) => vrl::value::Value::from(*v),
prost_reflect::Value::I32(v) => vrl::value::Value::from(*v),
prost_reflect::Value::I64(v) => vrl::value::Value::from(*v),
prost_reflect::Value::U32(v) => vrl::value::Value::from(*v),
prost_reflect::Value::U64(v) => vrl::value::Value::from(*v),
prost_reflect::Value::F32(v) => vrl::value::Value::Float(
NotNan::new(f64::from(*v)).map_err(|_e| "Float number cannot be NaN")?,
),
prost_reflect::Value::F64(v) => {
vrl::value::Value::Float(NotNan::new(*v).map_err(|_e| "F64 number cannot be NaN")?)
}
prost_reflect::Value::String(v) => vrl::value::Value::from(v.as_str()),
prost_reflect::Value::Bytes(v) => vrl::value::Value::from(v.clone()),
prost_reflect::Value::EnumNumber(v) => {
if let Some(field_descriptor) = field_descriptor {
let kind = field_descriptor.kind();
let enum_desc = kind.as_enum().ok_or_else(|| {
format!(
"Internal error while parsing protobuf enum. Field descriptor: {:?}",
field_descriptor
)
})?;
vrl::value::Value::from(
enum_desc
.get_value(*v)
.ok_or_else(|| {
format!("The number {} cannot be in '{}'", v, enum_desc.name())
})?
.name(),
)
} else {
Err("Expected valid field descriptor")?
}
}
prost_reflect::Value::Message(v) => {
let mut obj_map = BTreeMap::new();
for field_desc in v.descriptor().fields() {
let field_value = v.get_field(&field_desc);
let out = to_vrl(field_value.as_ref(), Some(&field_desc))?;
obj_map.insert(field_desc.name().to_string(), out);
}
vrl::value::Value::from(obj_map)
}
prost_reflect::Value::List(v) => {
let vec = v
.iter()
.map(|o| to_vrl(o, field_descriptor))
.collect::<Result<Vec<_>, vector_common::Error>>()?;
vrl::value::Value::from(vec)
}
prost_reflect::Value::Map(v) => {
if let Some(field_descriptor) = field_descriptor {
let kind = field_descriptor.kind();
let message_desc = kind.as_message().ok_or_else(|| {
format!(
"Internal error while parsing protobuf field descriptor: {:?}",
field_descriptor
)
})?;
vrl::value::Value::from(
v.iter()
.map(|kv| {
Ok((
kv.0.as_str()
.ok_or_else(|| {
format!(
"Internal error while parsing protobuf map. Field descriptor: {:?}",
field_descriptor
)
})?
.to_string(),
to_vrl(kv.1, Some(&message_desc.map_entry_value_field()))?,
))
})
.collect::<vector_common::Result<BTreeMap<String, _>>>()?,
)
} else {
Err("Expected valid field descriptor")?
}
}
};
Ok(vrl_value)
}

#[cfg(test)]
mod tests {
// TODO: add test for bad file path & invalid message_type

use std::path::PathBuf;
use std::{env, fs};
use vector_core::config::log_schema;

use super::*;

fn test_data_dir() -> PathBuf {
PathBuf::from(env::var_os("CARGO_MANIFEST_DIR").unwrap())
.join("tests/data/decoding/protobuf")
}

fn parse_and_validate(
protobuf_bin_message: String,
protobuf_desc_path: PathBuf,
message_type: &str,
validate_log: fn(&LogEvent),
) {
let input = Bytes::from(protobuf_bin_message);
let message_descriptor = ProtobufDeserializer::get_message_descriptor(
&protobuf_desc_path,
message_type.to_string(),
)
.unwrap();
let deserializer = ProtobufDeserializer::new(message_descriptor);

for namespace in [LogNamespace::Legacy, LogNamespace::Vector] {
let events = deserializer.parse(input.clone(), namespace).unwrap();
let mut events = events.into_iter();

{
let event = events.next().unwrap();
let log = event.as_log();
validate_log(log);
assert_eq!(
log.get(log_schema().timestamp_key_target_path().unwrap())
.is_some(),
namespace == LogNamespace::Legacy
);
}

assert_eq!(events.next(), None);
}
}

#[test]
fn deserialize_protobuf() {
let protobuf_bin_message_path = test_data_dir().join("person_someone.pb");
let protobuf_desc_path = test_data_dir().join("test_protobuf.desc");
let message_type = "test_protobuf.Person";
let validate_log = |log: &LogEvent| {
assert_eq!(log["name"], "someone".into());
assert_eq!(
log["phones"].as_array().unwrap()[0].as_object().unwrap()["number"]
.as_str()
.unwrap(),
"123456"
);
};

parse_and_validate(
fs::read_to_string(protobuf_bin_message_path).unwrap(),
protobuf_desc_path,
message_type,
validate_log,
);
}

#[test]
fn deserialize_protobuf3() {
let protobuf_bin_message_path = test_data_dir().join("person_someone3.pb");
let protobuf_desc_path = test_data_dir().join("test_protobuf3.desc");
let message_type = "test_protobuf3.Person";
let validate_log = |log: &LogEvent| {
assert_eq!(log["name"], "someone".into());
assert_eq!(
log["phones"].as_array().unwrap()[0].as_object().unwrap()["number"]
.as_str()
.unwrap(),
"1234"
);
assert_eq!(
log["data"].as_object().unwrap()["data_phone"],
"HOME".into()
);
};

parse_and_validate(
fs::read_to_string(protobuf_bin_message_path).unwrap(),
protobuf_desc_path,
message_type,
validate_log,
);
}

#[test]
fn deserialize_empty_buffer() {
let protobuf_bin_message = "".to_string();
let protobuf_desc_path = test_data_dir().join("test_protobuf.desc");
let message_type = "test_protobuf.Person";
let validate_log = |log: &LogEvent| {
assert_eq!(log["name"], "".into());
};

parse_and_validate(
protobuf_bin_message,
protobuf_desc_path,
message_type,
validate_log,
);
}

#[test]
fn deserialize_error_invalid_protobuf() {
let input = Bytes::from("{ foo");
let message_descriptor = ProtobufDeserializer::get_message_descriptor(
&test_data_dir().join("test_protobuf.desc"),
"test_protobuf.Person".to_string(),
)
.unwrap();
let deserializer = ProtobufDeserializer::new(message_descriptor);

for namespace in [LogNamespace::Legacy, LogNamespace::Vector] {
assert!(deserializer.parse(input.clone(), namespace).is_err());
}
}
}
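For reference, a minimal sketch (not part of this commit) of an additional test that could sit in the same `mod tests`, exercising the decoder end to end under the Vector log namespace and reusing the `test_protobuf.desc` and `person_someone.pb` fixtures referenced by the tests above:

#[test]
fn deserialize_protobuf_vector_namespace_sketch() {
    // Load the message descriptor from the descriptor set used by the other tests.
    let message_descriptor = ProtobufDeserializer::get_message_descriptor(
        &test_data_dir().join("test_protobuf.desc"),
        "test_protobuf.Person".to_string(),
    )
    .unwrap();
    let deserializer = ProtobufDeserializer::new(message_descriptor);

    // Read the encoded message as raw bytes rather than as a UTF-8 string.
    let input = Bytes::from(fs::read(test_data_dir().join("person_someone.pb")).unwrap());

    // Under the Vector namespace no timestamp is injected; the decoded fields
    // appear directly on the log event.
    let events = deserializer.parse(input, LogNamespace::Vector).unwrap();
    assert_eq!(events.len(), 1);
    assert_eq!(events[0].as_log()["name"], "someone".into());
}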