From 1e57845af63f89de0e17a1c52b0187c8f6e2c366 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Tue, 23 Mar 2021 10:41:23 -0700 Subject: [PATCH 01/30] conformance + initial impl --- conformance/src/main.rs | 39 +++++++++++++++++++++++++------ prost-build/src/code_generator.rs | 12 ++++++++++ prost-build/src/lib.rs | 17 ++++++++++++++ protobuf/Cargo.toml | 1 + protobuf/build.rs | 1 + tests/Cargo.toml | 2 ++ tests/src/lib.rs | 37 +++++++++++++++++++++++++++++ 7 files changed, 102 insertions(+), 7 deletions(-) diff --git a/conformance/src/main.rs b/conformance/src/main.rs index db404c322..6b2c823b0 100644 --- a/conformance/src/main.rs +++ b/conformance/src/main.rs @@ -8,7 +8,7 @@ use protobuf::conformance::{ }; use protobuf::test_messages::proto2::TestAllTypesProto2; use protobuf::test_messages::proto3::TestAllTypesProto3; -use tests::{roundtrip, RoundtripResult}; +use tests::{roundtrip, roundtrip_json, RoundtripResult}; fn main() -> io::Result<()> { env_logger::init(); @@ -55,11 +55,6 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { "output format unspecified".to_string(), ); } - WireFormat::Json => { - return conformance_response::Result::Skipped( - "JSON output is not supported".to_string(), - ); - } WireFormat::Jspb => { return conformance_response::Result::Skipped( "JSPB output is not supported".to_string(), @@ -70,9 +65,39 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { "TEXT_FORMAT output is not supported".to_string(), ); } - WireFormat::Protobuf => (), + WireFormat::Protobuf | WireFormat::Json => (), }; + if let WireFormat::Json = request.requested_output_format() { + if let Some(conformance_request::Payload::JsonPayload(json_str)) = request.payload { + let roundtrip = match &*request.message_type { + "protobuf_test_messages.proto2.TestAllTypesProto2" => roundtrip_json::(json_str), + "protobuf_test_messages.proto3.TestAllTypesProto3" => roundtrip_json::(json_str), + _ => { + return conformance_response::Result::ParseError(format!( + "unknown message type: {}", + request.message_type + )); + } + }; + + return match roundtrip { + RoundtripResult::Ok(buf) => conformance_response::Result::JsonPayload(match std::str::from_utf8(&buf) { + Ok(str) => str.to_string(), + Err(error) => return conformance_response::Result::ParseError(error.to_string()) + }), + RoundtripResult::DecodeError(error) => { + conformance_response::Result::ParseError(error.to_string()) + } + RoundtripResult::Error(error) => { + conformance_response::Result::RuntimeError(error.to_string()) + } + } + + } + unreachable!() + } + let buf = match request.payload { None => return conformance_response::Result::ParseError("no payload".to_string()), Some(conformance_request::Payload::JsonPayload(_)) => { diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index af50c28c4..56c24c36f 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -257,6 +257,18 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str(&attributes); self.buf.push('\n'); } + if let Some(_) = self + .config + .json_mapping + .get(fq_message_name) { + self.push_indent(); + self.buf.push_str("#[derive(serde::Deserialize, serde::Serialize)]"); + self.buf.push('\n'); + self.push_indent(); + self.buf.push_str(r#"#[serde(rename_all = "camelCase")]"#); + self.buf.push('\n'); + } + } fn append_field_attributes(&mut self, fq_message_name: &str, field_name: &str) { diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index 6353b7d0d..d1feb1372 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -226,6 +226,7 @@ pub struct Config { bytes_type: PathMap, type_attributes: PathMap, field_attributes: PathMap, + json_mapping: PathMap<()>, prost_types: bool, strip_enum_prefix: bool, out_dir: Option, @@ -444,6 +445,21 @@ impl Config { self } + /// Generates serde attributes in order to conform to the proto to json spec. + // TODO MORE COMMENTS + pub fn json_mapping(&mut self, paths: I) -> &mut Self + where + I: IntoIterator, + S: AsRef, + { + self.map_type.clear(); + for matcher in paths { + self.json_mapping + .insert(matcher.as_ref().to_string(), ()); + } + self + } + /// Configures the code generator to use the provided service generator. pub fn service_generator(&mut self, service_generator: Box) -> &mut Self { self.service_generator = Some(service_generator); @@ -841,6 +857,7 @@ impl default::Default for Config { bytes_type: PathMap::default(), type_attributes: PathMap::default(), field_attributes: PathMap::default(), + json_mapping: PathMap::default(), prost_types: true, strip_enum_prefix: true, out_dir: None, diff --git a/protobuf/Cargo.toml b/protobuf/Cargo.toml index 0885b6d8b..b58977121 100644 --- a/protobuf/Cargo.toml +++ b/protobuf/Cargo.toml @@ -9,6 +9,7 @@ edition = "2018" bytes = { version = "1", default-features = false } prost = { path = ".." } prost-types = { path = "../prost-types" } +serde = { version = "1", features = ["derive"] } [build-dependencies] anyhow = "1" diff --git a/protobuf/build.rs b/protobuf/build.rs index b0c5694d2..153e13d23 100644 --- a/protobuf/build.rs +++ b/protobuf/build.rs @@ -85,6 +85,7 @@ fn main() -> Result<()> { // values. prost_build::Config::new() .btree_map(&["."]) + .json_mapping(&["."]) .compile_protos( &[ test_includes.join("test_messages_proto2.proto"), diff --git a/tests/Cargo.toml b/tests/Cargo.toml index e369e9966..ed45ad64c 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -18,6 +18,8 @@ cfg-if = "0.1" prost = { path = ".." } prost-types = { path = "../prost-types" } protobuf = { path = "../protobuf" } +serde = { version="1.0", features=["derive"] } +serde_json = { version="1.0" } [dev-dependencies] diff = "0.1" diff --git a/tests/src/lib.rs b/tests/src/lib.rs index b20bc49bc..aae5eda97 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -109,6 +109,8 @@ use bytes::Buf; use prost::Message; +use serde::{Serialize, Deserialize}; + pub enum RoundtripResult { /// The roundtrip succeeded. Ok(Vec), @@ -195,6 +197,41 @@ where RoundtripResult::Ok(buf1) } +/// Tests round-tripping a message type. The message should be compiled with `BTreeMap` fields, +/// otherwise the comparison may fail due to inconsistent `HashMap` entry encoding ordering. +pub fn roundtrip_json<'de, M>(data: &'de str) -> RoundtripResult +where + M: Message + Default + Serialize + Deserialize<'de>, +{ + // Try to decode a message from the data. If decoding fails, continue. + let all_types: M = match serde_json::from_str(data) { + Ok(all_types) => all_types, + Err(error) => return RoundtripResult::Error(anyhow::Error::new(error)), + }; + + let str1 = match serde_json::to_string(&all_types) { + Ok(str) => str, + Err(error) => return RoundtripResult::Error(anyhow::Error::new(error)), + }; + + let roundtrip = match serde_json::from_str(&str1) { + Ok(roundtrip) => roundtrip, + Err(error) => return RoundtripResult::Error(anyhow::Error::new(error)), + }; + + let str2 = match serde_json::to_string(&roundtrip) { + Ok(str) => str, + Err(error) => return RoundtripResult::Error(anyhow::Error::new(error)), + }; + + if str1 != str2 { + return RoundtripResult::Error(anyhow!("roundtripped JSON encoded strings do not match")); + } + + RoundtripResult::Ok(str1.into_bytes()) +} + + /// Generic rountrip serialization check for messages. pub fn check_message(msg: &M) where From ca91e1a1e71f934390f5e4875c28c0eb2f27bdf6 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Wed, 24 Mar 2021 17:55:24 -0700 Subject: [PATCH 02/30] initial try, need to get to WK --- conformance/src/main.rs | 10 ++++++---- prost-build/src/code_generator.rs | 13 +++++++++++++ prost-types/Cargo.toml | 1 + prost-types/src/protobuf.rs | 20 ++++++++++++-------- 4 files changed, 32 insertions(+), 12 deletions(-) diff --git a/conformance/src/main.rs b/conformance/src/main.rs index 6b2c823b0..9d9985cdb 100644 --- a/conformance/src/main.rs +++ b/conformance/src/main.rs @@ -71,8 +71,8 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { if let WireFormat::Json = request.requested_output_format() { if let Some(conformance_request::Payload::JsonPayload(json_str)) = request.payload { let roundtrip = match &*request.message_type { - "protobuf_test_messages.proto2.TestAllTypesProto2" => roundtrip_json::(json_str), - "protobuf_test_messages.proto3.TestAllTypesProto3" => roundtrip_json::(json_str), + "protobuf_test_messages.proto2.TestAllTypesProto2" => roundtrip_json::(&json_str), + "protobuf_test_messages.proto3.TestAllTypesProto3" => roundtrip_json::(&json_str), _ => { return conformance_response::Result::ParseError(format!( "unknown message type: {}", @@ -93,9 +93,11 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { conformance_response::Result::RuntimeError(error.to_string()) } } - } - unreachable!() + // TODO(konradjniemiec) support proto -> json and json -> proto conformance + return conformance_response::Result::Skipped( + "only json <-> json is supported".to_string(), + ); } let buf = match request.payload { diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 56c24c36f..8aebde1b0 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -271,6 +271,18 @@ impl<'a> CodeGenerator<'a> { } + fn append_json_oneof_field_attributes(&mut self, fq_message_name: &str) { + assert_eq!(b'.', fq_message_name.as_bytes()[0]); + if let Some(_) = self + .config + .json_mapping + .get(fq_message_name) { + self.push_indent(); + self.buf.push_str("#[serde(flatten)]"); + self.buf.push('\n'); + } + } + fn append_field_attributes(&mut self, fq_message_name: &str, field_name: &str) { assert_eq!(b'.', fq_message_name.as_bytes()[0]); // TODO: this clone is dirty, but expedious. @@ -484,6 +496,7 @@ impl<'a> CodeGenerator<'a> { .join(", ") )); self.append_field_attributes(fq_message_name, oneof.name()); + self.append_json_oneof_field_attributes(fq_message_name); self.push_indent(); self.buf.push_str(&format!( "pub {}: ::core::option::Option<{}>,\n", diff --git a/prost-types/Cargo.toml b/prost-types/Cargo.toml index 22d6b0642..970a6f35f 100644 --- a/prost-types/Cargo.toml +++ b/prost-types/Cargo.toml @@ -19,6 +19,7 @@ std = ["prost/std"] [dependencies] bytes = { version = "1", default-features = false } prost = { version = "0.7.0", path = "..", default-features = false, features = ["prost-derive"] } +serde = { version = "1", features = ["derive"] } [dev-dependencies] proptest = "0.9" diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index 7530b7827..e307dfdb9 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -996,7 +996,8 @@ pub mod generated_code_info { /// "value": "1.212s" /// } /// -#[derive(Clone, PartialEq, ::prost::Message)] +// TODO(konradjniemiec) proper serialization +#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] pub struct Any { /// A URL/resource name that uniquely identifies the type of the serialized /// protocol buffer message. This string must contain at least @@ -1444,7 +1445,8 @@ pub struct Mixin { /// microsecond should be expressed in JSON format as "3.000001s". /// /// -#[derive(Clone, PartialEq, ::prost::Message)] +// TODO(konradjniemiec) proper serialization +#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] pub struct Duration { /// Signed seconds of the span of time. Must be from -315,576,000,000 /// to +315,576,000,000 inclusive. Note: these bounds are computed from: @@ -1659,7 +1661,8 @@ pub struct Duration { /// The implementation of any API method which has a FieldMask type field in the /// request should verify the included field paths, and return an /// `INVALID_ARGUMENT` error if any path is unmappable. -#[derive(Clone, PartialEq, ::prost::Message)] +// TODO(konradjniemiec) proper serialization +#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] pub struct FieldMask { /// The set of field mask paths. #[prost(string, repeated, tag="1")] @@ -1673,7 +1676,7 @@ pub struct FieldMask { /// with the proto support for the language. /// /// The JSON representation for `Struct` is JSON object. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] pub struct Struct { /// Unordered map of dynamically typed values. #[prost(btree_map="string, message", tag="1")] @@ -1685,7 +1688,7 @@ pub struct Struct { /// variants, absence of any variant indicates an error. /// /// The JSON representation for `Value` is JSON value. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] pub struct Value { /// The kind of value. #[prost(oneof="value::Kind", tags="1, 2, 3, 4, 5, 6")] @@ -1694,7 +1697,7 @@ pub struct Value { /// Nested message and enum types in `Value`. pub mod value { /// The kind of value. - #[derive(Clone, PartialEq, ::prost::Oneof)] + #[derive(Clone, PartialEq, ::prost::Oneof, serde::Serialize, serde::Deserialize)] pub enum Kind { /// Represents a null value. #[prost(enumeration="super::NullValue", tag="1")] @@ -1719,7 +1722,7 @@ pub mod value { /// `ListValue` is a wrapper around a repeated field of values. /// /// The JSON representation for `ListValue` is JSON array. -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] pub struct ListValue { /// Repeated field of dynamically typed values. #[prost(message, repeated, tag="1")] @@ -1828,7 +1831,8 @@ pub enum NullValue { /// ) to obtain a formatter capable of generating timestamps in this format. /// /// -#[derive(Clone, PartialEq, ::prost::Message)] +#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] +// TODO(konradjniemiec) proper serialization pub struct Timestamp { /// Represents seconds of UTC time since Unix epoch /// 1970-01-01T00:00:00Z. Must be from 0001-01-01T00:00:00Z to From b9408673332a596f7c86baf044ace3477f8e046e Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Mon, 18 Oct 2021 16:41:30 -0700 Subject: [PATCH 03/30] saving old work --- conformance/failing_tests.txt | 1 - conformance/src/main.rs | 72 ++++++++++++++++--------------- prost-build/src/code_generator.rs | 59 +++++++++++++++---------- prost-build/src/lib.rs | 11 +++-- prost-types/Cargo.toml | 1 + prost-types/src/lib.rs | 40 +++++++++++++++++ prost-types/src/protobuf.rs | 2 +- tests/src/lib.rs | 60 +++++++++++++++----------- 8 files changed, 156 insertions(+), 90 deletions(-) diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index 73f00caa6..1bb832783 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -1,3 +1,2 @@ -# TODO(danburkert/prost#2): prost doesn't preserve unknown fields. Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput diff --git a/conformance/src/main.rs b/conformance/src/main.rs index 9d9985cdb..7e435886d 100644 --- a/conformance/src/main.rs +++ b/conformance/src/main.rs @@ -69,37 +69,43 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { }; if let WireFormat::Json = request.requested_output_format() { - if let Some(conformance_request::Payload::JsonPayload(json_str)) = request.payload { - let roundtrip = match &*request.message_type { - "protobuf_test_messages.proto2.TestAllTypesProto2" => roundtrip_json::(&json_str), - "protobuf_test_messages.proto3.TestAllTypesProto3" => roundtrip_json::(&json_str), - _ => { - return conformance_response::Result::ParseError(format!( - "unknown message type: {}", - request.message_type - )); - } - }; - - return match roundtrip { - RoundtripResult::Ok(buf) => conformance_response::Result::JsonPayload(match std::str::from_utf8(&buf) { - Ok(str) => str.to_string(), - Err(error) => return conformance_response::Result::ParseError(error.to_string()) - }), - RoundtripResult::DecodeError(error) => { - conformance_response::Result::ParseError(error.to_string()) - } - RoundtripResult::Error(error) => { - conformance_response::Result::RuntimeError(error.to_string()) - } - } - } - // TODO(konradjniemiec) support proto -> json and json -> proto conformance - return conformance_response::Result::Skipped( - "only json <-> json is supported".to_string(), + if let Some(conformance_request::Payload::JsonPayload(json_str)) = request.payload { + let roundtrip = match &*request.message_type { + "protobuf_test_messages.proto2.TestAllTypesProto2" => { + roundtrip_json::(&json_str) + } + "protobuf_test_messages.proto3.TestAllTypesProto3" => { + roundtrip_json::(&json_str) + } + _ => { + return conformance_response::Result::ParseError(format!( + "unknown message type: {}", + request.message_type + )); + } + }; + + return match roundtrip { + RoundtripResult::Ok(buf) => { + conformance_response::Result::JsonPayload(match std::str::from_utf8(&buf) { + Ok(str) => str.to_string(), + Err(error) => { + return conformance_response::Result::ParseError(error.to_string()) + } + }) + } + RoundtripResult::DecodeError(error) => { + conformance_response::Result::ParseError(error) + } + RoundtripResult::Error(error) => conformance_response::Result::RuntimeError(error), + }; + } + // TODO(konradjniemiec) support proto -> json and json -> proto conformance + return conformance_response::Result::Skipped( + "only json <-> json is supported".to_string(), ); } - + let buf = match request.payload { None => return conformance_response::Result::ParseError("no payload".to_string()), Some(conformance_request::Payload::JsonPayload(_)) => { @@ -133,11 +139,7 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { match roundtrip { RoundtripResult::Ok(buf) => conformance_response::Result::ProtobufPayload(buf), - RoundtripResult::DecodeError(error) => { - conformance_response::Result::ParseError(error.to_string()) - } - RoundtripResult::Error(error) => { - conformance_response::Result::RuntimeError(error.to_string()) - } + RoundtripResult::DecodeError(error) => conformance_response::Result::ParseError(error), + RoundtripResult::Error(error) => conformance_response::Result::RuntimeError(error), } } diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 8aebde1b0..232ee0fbf 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -172,6 +172,7 @@ impl<'a> CodeGenerator<'a> { self.append_doc(&fq_message_name, None); self.append_type_attributes(&fq_message_name); + self.append_json_message_attributes(&fq_message_name); self.push_indent(); self.buf .push_str("#[derive(Clone, PartialEq, ::prost::Message)]\n"); @@ -257,32 +258,44 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str(&attributes); self.buf.push('\n'); } - if let Some(_) = self - .config - .json_mapping - .get(fq_message_name) { - self.push_indent(); - self.buf.push_str("#[derive(serde::Deserialize, serde::Serialize)]"); - self.buf.push('\n'); - self.push_indent(); - self.buf.push_str(r#"#[serde(rename_all = "camelCase")]"#); - self.buf.push('\n'); - } + } + + fn append_json_message_attributes(&mut self, fq_message_name: &str) { + if let Some(_) = self.config.json_mapping.get(fq_message_name) { + self.push_indent(); + self.buf + .push_str("#[derive(serde::Deserialize, serde::Serialize)]"); + self.buf.push('\n'); + self.push_indent(); + self.buf.push_str(r#"#[serde(rename_all = "camelCase")]"#); + self.buf.push('\n'); + self.push_indent(); + self.buf.push_str("#[serde(default)]"); + self.buf.push('\n'); + } + } + fn append_json_oneof_enum_attributes(&mut self, fq_message_name: &str) { + if let Some(_) = self.config.json_mapping.get(fq_message_name) { + self.push_indent(); + self.buf + .push_str("#[derive(serde::Deserialize, serde::Serialize)]"); + self.buf.push('\n'); + self.push_indent(); + self.buf.push_str(r#"#[serde(rename_all = "camelCase")]"#); + self.buf.push('\n'); + } } fn append_json_oneof_field_attributes(&mut self, fq_message_name: &str) { - assert_eq!(b'.', fq_message_name.as_bytes()[0]); - if let Some(_) = self - .config - .json_mapping - .get(fq_message_name) { - self.push_indent(); - self.buf.push_str("#[serde(flatten)]"); - self.buf.push('\n'); - } + assert_eq!(b'.', fq_message_name.as_bytes()[0]); + if let Some(_) = self.config.json_mapping.get(fq_message_name) { + self.push_indent(); + self.buf.push_str("#[serde(flatten)]"); + self.buf.push('\n'); + } } - + fn append_field_attributes(&mut self, fq_message_name: &str, field_name: &str) { assert_eq!(b'.', fq_message_name.as_bytes()[0]); // TODO: this clone is dirty, but expedious. @@ -496,7 +509,7 @@ impl<'a> CodeGenerator<'a> { .join(", ") )); self.append_field_attributes(fq_message_name, oneof.name()); - self.append_json_oneof_field_attributes(fq_message_name); + self.append_json_oneof_field_attributes(fq_message_name); self.push_indent(); self.buf.push_str(&format!( "pub {}: ::core::option::Option<{}>,\n", @@ -520,6 +533,7 @@ impl<'a> CodeGenerator<'a> { let oneof_name = format!("{}.{}", fq_message_name, oneof.name()); self.append_type_attributes(&oneof_name); + self.append_json_oneof_enum_attributes(&oneof_name); self.push_indent(); self.buf .push_str("#[derive(Clone, PartialEq, ::prost::Oneof)]\n"); @@ -616,6 +630,7 @@ impl<'a> CodeGenerator<'a> { self.append_doc(&fq_enum_name, None); self.append_type_attributes(&fq_enum_name); + self.append_json_oneof_enum_attributes(&fq_enum_name); self.push_indent(); self.buf.push_str( "#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)]\n", diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index d1feb1372..97e78aaef 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -449,14 +449,13 @@ impl Config { // TODO MORE COMMENTS pub fn json_mapping(&mut self, paths: I) -> &mut Self where - I: IntoIterator, + I: IntoIterator, S: AsRef, { - self.map_type.clear(); + self.map_type.clear(); for matcher in paths { - self.json_mapping - .insert(matcher.as_ref().to_string(), ()); - } + self.json_mapping.insert(matcher.as_ref().to_string(), ()); + } self } @@ -857,7 +856,7 @@ impl default::Default for Config { bytes_type: PathMap::default(), type_attributes: PathMap::default(), field_attributes: PathMap::default(), - json_mapping: PathMap::default(), + json_mapping: PathMap::default(), prost_types: true, strip_enum_prefix: true, out_dir: None, diff --git a/prost-types/Cargo.toml b/prost-types/Cargo.toml index 970a6f35f..c00ad0513 100644 --- a/prost-types/Cargo.toml +++ b/prost-types/Cargo.toml @@ -20,6 +20,7 @@ std = ["prost/std"] bytes = { version = "1", default-features = false } prost = { version = "0.7.0", path = "..", default-features = false, features = ["prost-derive"] } serde = { version = "1", features = ["derive"] } +humantime = { version = "2.1" } [dev-dependencies] proptest = "0.9" diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 89ba1eb42..eb1ca45ec 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -157,6 +157,46 @@ impl From for std::time::SystemTime { } } +impl serde::Serialize for Timestamp { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str( + &humantime::format_rfc3339(std::time::SystemTime::from(self.clone())).to_string(), + ) + } +} + +struct TimestampVisitor; + +#[cfg(feature = "std")] +impl<'de> serde::de::Visitor<'de> for TimestampVisitor { + type Value = Timestamp; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid RFC 3339 timestamp string") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + Ok(Timestamp::from( + humantime::parse_rfc3339(value).map_err(serde::de::Error::custom)?, + )) + } +} + +impl<'de> serde::Deserialize<'de> for Timestamp { + fn deserialize(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(TimestampVisitor) + } +} + #[cfg(test)] mod tests { use std::time::{Duration, SystemTime, UNIX_EPOCH}; diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index e307dfdb9..f0d2b7ea5 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -1831,7 +1831,7 @@ pub enum NullValue { /// ) to obtain a formatter capable of generating timestamps in this format. /// /// -#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] +#[derive(Clone, PartialEq, ::prost::Message)] // TODO(konradjniemiec) proper serialization pub struct Timestamp { /// Represents seconds of UTC time since Unix epoch diff --git a/tests/src/lib.rs b/tests/src/lib.rs index aae5eda97..c2b9537ad 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -109,16 +109,16 @@ use bytes::Buf; use prost::Message; -use serde::{Serialize, Deserialize}; +use serde::{Deserialize, Serialize}; pub enum RoundtripResult { /// The roundtrip succeeded. Ok(Vec), /// The data could not be decoded. This could indicate a bug in prost, /// or it could indicate that the input was bogus. - DecodeError(prost::DecodeError), + DecodeError(String), /// Re-encoding or validating the data failed. This indicates a bug in `prost`. - Error(anyhow::Error), + Error(String), } impl RoundtripResult { @@ -132,15 +132,16 @@ impl RoundtripResult { RoundtripResult::Error(error) => panic!("failed roundtrip: {}", error), } } - - /// Unwrap the roundtrip result. Panics if the result was a validation or re-encoding error. - pub fn unwrap_error(self) -> Result, prost::DecodeError> { - match self { - RoundtripResult::Ok(buf) => Ok(buf), - RoundtripResult::DecodeError(error) => Err(error), - RoundtripResult::Error(error) => panic!("failed roundtrip: {}", error), + /* + /// Unwrap the roundtrip result. Panics if the result was a validation or re-encoding error. + pub fn unwrap_error(self) -> Result, prost::DecodeError> { + match self { + RoundtripResult::Ok(buf) => Ok(buf), + RoundtripResult::DecodeError(error) => Err(DecodeError(error.to_string())), + RoundtripResult::Error(error) => panic!("failed roundtrip: {}", error), + } } - } + */ } /// Tests round-tripping a message type. The message should be compiled with `BTreeMap` fields, @@ -152,7 +153,7 @@ where // Try to decode a message from the data. If decoding fails, continue. let all_types = match M::decode(data) { Ok(all_types) => all_types, - Err(error) => return RoundtripResult::DecodeError(error), + Err(error) => return RoundtripResult::DecodeError(error.to_string()), }; let encoded_len = all_types.encoded_len(); @@ -163,10 +164,10 @@ where let mut buf1 = Vec::new(); if let Err(error) = all_types.encode(&mut buf1) { - return RoundtripResult::Error(error.into()); + return RoundtripResult::Error(error.to_string()); } if encoded_len != buf1.len() { - return RoundtripResult::Error(anyhow!( + return RoundtripResult::Error(format!( "expected encoded len ({}) did not match actual encoded len ({})", encoded_len, buf1.len() @@ -175,12 +176,12 @@ where let roundtrip = match M::decode(&*buf1) { Ok(roundtrip) => roundtrip, - Err(error) => return RoundtripResult::Error(anyhow::Error::new(error)), + Err(error) => return RoundtripResult::Error(error.to_string()), }; let mut buf2 = Vec::new(); if let Err(error) = roundtrip.encode(&mut buf2) { - return RoundtripResult::Error(error.into()); + return RoundtripResult::Error(error.to_string()); } /* @@ -191,7 +192,7 @@ where */ if buf1 != buf2 { - return RoundtripResult::Error(anyhow!("roundtripped encoded buffers do not match")); + return RoundtripResult::Error("roundtripped encoded buffers do not match".to_string()); } RoundtripResult::Ok(buf1) @@ -206,32 +207,41 @@ where // Try to decode a message from the data. If decoding fails, continue. let all_types: M = match serde_json::from_str(data) { Ok(all_types) => all_types, - Err(error) => return RoundtripResult::Error(anyhow::Error::new(error)), + Err(error) => return RoundtripResult::DecodeError(format!("step 1 {}", error.to_string())), }; let str1 = match serde_json::to_string(&all_types) { - Ok(str) => str, - Err(error) => return RoundtripResult::Error(anyhow::Error::new(error)), + Ok(str) => str, + Err(error) => return RoundtripResult::Error(format!("step 2 {}", error.to_string())), }; + if str1 != data { + return RoundtripResult::Error(format!( + "halftripped JSON encoded strings do not match {} {}", + str1, data + )); + } + let roundtrip = match serde_json::from_str(&str1) { Ok(roundtrip) => roundtrip, - Err(error) => return RoundtripResult::Error(anyhow::Error::new(error)), + Err(error) => return RoundtripResult::Error(format!("step 3 {}", error.to_string())), }; let str2 = match serde_json::to_string(&roundtrip) { - Ok(str) => str, - Err(error) => return RoundtripResult::Error(anyhow::Error::new(error)), + Ok(str) => str, + Err(error) => return RoundtripResult::Error(format!("step 4 {}", error.to_string())), }; if str1 != str2 { - return RoundtripResult::Error(anyhow!("roundtripped JSON encoded strings do not match")); + return RoundtripResult::Error(format!( + "roundtripped JSON encoded strings do not match {} {}", + str1, str2 + )); } RoundtripResult::Ok(str1.into_bytes()) } - /// Generic rountrip serialization check for messages. pub fn check_message(msg: &M) where From 4a13f3c5dadfa53b750e4c09c63a7b2e26bebb9e Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Mon, 18 Oct 2021 17:45:50 -0700 Subject: [PATCH 04/30] get back in working state --- conformance/failing_tests.txt | 167 +++++++++++++++++++++++++++++- prost-build/src/code_generator.rs | 6 +- prost-types/src/lib.rs | 3 +- tests/src/lib.rs | 5 +- 4 files changed, 170 insertions(+), 11 deletions(-) diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index b41904761..c1535b62d 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -1,3 +1,164 @@ -# TODO(tokio-rs/prost#2): prost doesn't preserve unknown fields. -Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput -Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput +Recommended.Proto2.JsonInput.FieldNameExtension.Validator +Recommended.Proto3.JsonInput.BytesFieldBase64Url.JsonOutput +Recommended.Proto3.JsonInput.DurationHas3FractionalDigits.Validator +Recommended.Proto3.JsonInput.DurationHas6FractionalDigits.Validator +Recommended.Proto3.JsonInput.DurationHas9FractionalDigits.Validator +Recommended.Proto3.JsonInput.DurationHasZeroFractionalDigit.Validator +Recommended.Proto3.JsonInput.FieldNameWithDoubleUnderscores.JsonOutput +Recommended.Proto3.JsonInput.FieldNameWithDoubleUnderscores.Validator +Recommended.Proto3.JsonInput.Int64FieldBeString.Validator +Recommended.Proto3.JsonInput.MultilineNoSpaces.JsonOutput +Recommended.Proto3.JsonInput.MultilineWithSpaces.JsonOutput +Recommended.Proto3.JsonInput.NullValueInNormalMessage.Validator +Recommended.Proto3.JsonInput.NullValueInOtherOneofNewFormat.Validator +Recommended.Proto3.JsonInput.NullValueInOtherOneofOldFormat.Validator +Recommended.Proto3.JsonInput.OneLineNoSpaces.JsonOutput +Recommended.Proto3.JsonInput.OneLineWithSpaces.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroBool.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroBytes.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroDouble.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroEnum.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroFloat.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroMessage.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroString.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroUint32.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroUint64.JsonOutput +Recommended.Proto3.JsonInput.TimestampHas3FractionalDigits.Validator +Recommended.Proto3.JsonInput.TimestampHas6FractionalDigits.Validator +Recommended.Proto3.JsonInput.TimestampHas9FractionalDigits.Validator +Recommended.Proto3.JsonInput.TimestampHasZeroFractionalDigit.Validator +Recommended.Proto3.JsonInput.TimestampZeroNormalized.Validator +Recommended.Proto3.JsonInput.Uint64FieldBeString.Validator +Required.Proto2.JsonInput.StoresDefaultPrimitive.Validator +Required.Proto3.JsonInput.AllFieldAcceptNull.JsonOutput +Required.Proto3.JsonInput.Any.JsonOutput +Required.Proto3.JsonInput.AnyNested.JsonOutput +Required.Proto3.JsonInput.AnyUnorderedTypeTag.JsonOutput +Required.Proto3.JsonInput.AnyWithDuration.JsonOutput +Required.Proto3.JsonInput.AnyWithFieldMask.JsonOutput +Required.Proto3.JsonInput.AnyWithInt32ValueWrapper.JsonOutput +Required.Proto3.JsonInput.AnyWithStruct.JsonOutput +Required.Proto3.JsonInput.AnyWithTimestamp.JsonOutput +Required.Proto3.JsonInput.AnyWithValueForInteger.JsonOutput +Required.Proto3.JsonInput.AnyWithValueForJsonObject.JsonOutput +Required.Proto3.JsonInput.BoolFieldFalse.JsonOutput +Required.Proto3.JsonInput.BoolFieldTrue.JsonOutput +Required.Proto3.JsonInput.BoolMapEscapedKey.JsonOutput +Required.Proto3.JsonInput.BoolMapField.JsonOutput +Required.Proto3.JsonInput.BytesField.JsonOutput +Required.Proto3.JsonInput.BytesRepeatedField.JsonOutput +Required.Proto3.JsonInput.DoubleFieldInfinity.JsonOutput +Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.JsonOutput +Required.Proto3.JsonInput.DoubleFieldMaxPositiveValue.JsonOutput +Required.Proto3.JsonInput.DoubleFieldMinNegativeValue.JsonOutput +Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.JsonOutput +Required.Proto3.JsonInput.DoubleFieldNan.JsonOutput +Required.Proto3.JsonInput.DoubleFieldNegativeInfinity.JsonOutput +Required.Proto3.JsonInput.DoubleFieldQuotedValue.JsonOutput +Required.Proto3.JsonInput.DurationMaxValue.JsonOutput +Required.Proto3.JsonInput.DurationMinValue.JsonOutput +Required.Proto3.JsonInput.DurationNull.JsonOutput +Required.Proto3.JsonInput.DurationRepeatedValue.JsonOutput +Required.Proto3.JsonInput.EmptyFieldMask.JsonOutput +Required.Proto3.JsonInput.EnumField.JsonOutput +Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.JsonOutput +Required.Proto3.JsonInput.EnumFieldNumericValueZero.JsonOutput +Required.Proto3.JsonInput.EnumFieldUnknownValue.Validator +Required.Proto3.JsonInput.EnumFieldWithAlias.JsonOutput +Required.Proto3.JsonInput.EnumFieldWithAliasDifferentCase.JsonOutput +Required.Proto3.JsonInput.EnumFieldWithAliasLowerCase.JsonOutput +Required.Proto3.JsonInput.EnumFieldWithAliasUseAlias.JsonOutput +Required.Proto3.JsonInput.EnumRepeatedField.JsonOutput +Required.Proto3.JsonInput.FieldMask.JsonOutput +Required.Proto3.JsonInput.FieldNameEscaped.JsonOutput +Required.Proto3.JsonInput.FieldNameInLowerCamelCase.Validator +Required.Proto3.JsonInput.FieldNameInSnakeCase.JsonOutput +Required.Proto3.JsonInput.FieldNameWithMixedCases.JsonOutput +Required.Proto3.JsonInput.FieldNameWithMixedCases.Validator +Required.Proto3.JsonInput.FieldNameWithNumbers.JsonOutput +Required.Proto3.JsonInput.FieldNameWithNumbers.Validator +Required.Proto3.JsonInput.FloatFieldInfinity.JsonOutput +Required.Proto3.JsonInput.FloatFieldMaxNegativeValue.JsonOutput +Required.Proto3.JsonInput.FloatFieldMaxPositiveValue.JsonOutput +Required.Proto3.JsonInput.FloatFieldMinNegativeValue.JsonOutput +Required.Proto3.JsonInput.FloatFieldMinPositiveValue.JsonOutput +Required.Proto3.JsonInput.FloatFieldNan.JsonOutput +Required.Proto3.JsonInput.FloatFieldNegativeInfinity.JsonOutput +Required.Proto3.JsonInput.FloatFieldQuotedValue.JsonOutput +Required.Proto3.JsonInput.FloatFieldTooLarge +Required.Proto3.JsonInput.FloatFieldTooSmall +Required.Proto3.JsonInput.HelloWorld.JsonOutput +Required.Proto3.JsonInput.Int32FieldExponentialFormat.JsonOutput +Required.Proto3.JsonInput.Int32FieldFloatTrailingZero.JsonOutput +Required.Proto3.JsonInput.Int32FieldMaxFloatValue.JsonOutput +Required.Proto3.JsonInput.Int32FieldMaxValue.JsonOutput +Required.Proto3.JsonInput.Int32FieldMinFloatValue.JsonOutput +Required.Proto3.JsonInput.Int32FieldMinValue.JsonOutput +Required.Proto3.JsonInput.Int32FieldStringValue.JsonOutput +Required.Proto3.JsonInput.Int32FieldStringValueEscaped.JsonOutput +Required.Proto3.JsonInput.Int32MapEscapedKey.JsonOutput +Required.Proto3.JsonInput.Int32MapField.JsonOutput +Required.Proto3.JsonInput.Int64FieldMaxValue.JsonOutput +Required.Proto3.JsonInput.Int64FieldMaxValueNotQuoted.JsonOutput +Required.Proto3.JsonInput.Int64FieldMinValue.JsonOutput +Required.Proto3.JsonInput.Int64FieldMinValueNotQuoted.JsonOutput +Required.Proto3.JsonInput.Int64MapEscapedKey.JsonOutput +Required.Proto3.JsonInput.Int64MapField.JsonOutput +Required.Proto3.JsonInput.MessageField.JsonOutput +Required.Proto3.JsonInput.MessageMapField.JsonOutput +Required.Proto3.JsonInput.MessageRepeatedField.JsonOutput +Required.Proto3.JsonInput.OneofFieldDuplicate +Required.Proto3.JsonInput.OptionalBoolWrapper.JsonOutput +Required.Proto3.JsonInput.OptionalBytesWrapper.JsonOutput +Required.Proto3.JsonInput.OptionalDoubleWrapper.JsonOutput +Required.Proto3.JsonInput.OptionalFloatWrapper.JsonOutput +Required.Proto3.JsonInput.OptionalInt32Wrapper.JsonOutput +Required.Proto3.JsonInput.OptionalInt64Wrapper.JsonOutput +Required.Proto3.JsonInput.OptionalStringWrapper.JsonOutput +Required.Proto3.JsonInput.OptionalUint32Wrapper.JsonOutput +Required.Proto3.JsonInput.OptionalUint64Wrapper.JsonOutput +Required.Proto3.JsonInput.OptionalWrapperTypesWithNonDefaultValue.JsonOutput +Required.Proto3.JsonInput.OriginalProtoFieldName.JsonOutput +Required.Proto3.JsonInput.PrimitiveRepeatedField.JsonOutput +Required.Proto3.JsonInput.RepeatedBoolWrapper.JsonOutput +Required.Proto3.JsonInput.RepeatedBytesWrapper.JsonOutput +Required.Proto3.JsonInput.RepeatedDoubleWrapper.JsonOutput +Required.Proto3.JsonInput.RepeatedFloatWrapper.JsonOutput +Required.Proto3.JsonInput.RepeatedInt32Wrapper.JsonOutput +Required.Proto3.JsonInput.RepeatedInt64Wrapper.JsonOutput +Required.Proto3.JsonInput.RepeatedListValue.JsonOutput +Required.Proto3.JsonInput.RepeatedStringWrapper.JsonOutput +Required.Proto3.JsonInput.RepeatedUint32Wrapper.JsonOutput +Required.Proto3.JsonInput.RepeatedUint64Wrapper.JsonOutput +Required.Proto3.JsonInput.RepeatedValue.JsonOutput +Required.Proto3.JsonInput.SkipsDefaultPrimitive.Validator +Required.Proto3.JsonInput.StringField.JsonOutput +Required.Proto3.JsonInput.StringFieldEscape.JsonOutput +Required.Proto3.JsonInput.StringFieldSurrogatePair.JsonOutput +Required.Proto3.JsonInput.StringFieldUnicode.JsonOutput +Required.Proto3.JsonInput.StringFieldUnicodeEscape.JsonOutput +Required.Proto3.JsonInput.StringFieldUnicodeEscapeWithLowercaseHexLetters.JsonOutput +Required.Proto3.JsonInput.StringRepeatedField.JsonOutput +Required.Proto3.JsonInput.Struct.JsonOutput +Required.Proto3.JsonInput.StructWithEmptyListValue.JsonOutput +Required.Proto3.JsonInput.TimestampLeap.JsonOutput +Required.Proto3.JsonInput.TimestampMaxValue.JsonOutput +Required.Proto3.JsonInput.TimestampMinValue.JsonOutput +Required.Proto3.JsonInput.TimestampNull.JsonOutput +Required.Proto3.JsonInput.TimestampRepeatedValue.JsonOutput +Required.Proto3.JsonInput.TimestampWithNegativeOffset.JsonOutput +Required.Proto3.JsonInput.TimestampWithPositiveOffset.JsonOutput +Required.Proto3.JsonInput.Uint32FieldMaxFloatValue.JsonOutput +Required.Proto3.JsonInput.Uint32FieldMaxValue.JsonOutput +Required.Proto3.JsonInput.Uint32MapField.JsonOutput +Required.Proto3.JsonInput.Uint64FieldMaxValue.JsonOutput +Required.Proto3.JsonInput.Uint64FieldMaxValueNotQuoted.JsonOutput +Required.Proto3.JsonInput.Uint64MapField.JsonOutput +Required.Proto3.JsonInput.ValueAcceptBool.JsonOutput +Required.Proto3.JsonInput.ValueAcceptFloat.JsonOutput +Required.Proto3.JsonInput.ValueAcceptInteger.JsonOutput +Required.Proto3.JsonInput.ValueAcceptList.JsonOutput +Required.Proto3.JsonInput.ValueAcceptNull.JsonOutput +Required.Proto3.JsonInput.ValueAcceptObject.JsonOutput +Required.Proto3.JsonInput.ValueAcceptString.JsonOutput +Required.Proto3.JsonInput.WrapperTypesWithNullValue.JsonOutput diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index a6b3dfe05..32feacb13 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -271,7 +271,7 @@ impl<'a> CodeGenerator<'a> { } fn append_json_message_attributes(&mut self, fq_message_name: &str) { - if let Some(_) = self.config.json_mapping.get(fq_message_name) { + if let Some(_) = self.config.json_mapping.get_first(fq_message_name) { self.push_indent(); self.buf .push_str("#[derive(serde::Deserialize, serde::Serialize)]"); @@ -286,7 +286,7 @@ impl<'a> CodeGenerator<'a> { } fn append_json_oneof_enum_attributes(&mut self, fq_message_name: &str) { - if let Some(_) = self.config.json_mapping.get(fq_message_name) { + if let Some(_) = self.config.json_mapping.get_first(fq_message_name) { self.push_indent(); self.buf .push_str("#[derive(serde::Deserialize, serde::Serialize)]"); @@ -299,7 +299,7 @@ impl<'a> CodeGenerator<'a> { fn append_json_oneof_field_attributes(&mut self, fq_message_name: &str) { assert_eq!(b'.', fq_message_name.as_bytes()[0]); - if let Some(_) = self.config.json_mapping.get(fq_message_name) { + if let Some(_) = self.config.json_mapping.get_first(fq_message_name) { self.push_indent(); self.buf.push_str("#[serde(flatten)]"); self.buf.push('\n'); diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 4d9ed9639..1b4e05855 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -255,13 +255,14 @@ impl TryFrom for std::time::SystemTime { } } +#[cfg(feature = "std")] impl serde::Serialize for Timestamp { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer, { serializer.serialize_str( - &humantime::format_rfc3339(std::time::SystemTime::from(self.clone())).to_string(), + &humantime::format_rfc3339(std::time::SystemTime::try_from(self.clone()).unwrap()).to_string(), ) } } diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 8e17e8572..c013786ee 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -109,7 +109,6 @@ pub mod default_string_escape { use alloc::format; use alloc::vec::Vec; -use anyhow::anyhow; use bytes::Buf; use prost::Message; @@ -202,9 +201,7 @@ where } if buf1 != buf3 { - return RoundtripResult::Error(anyhow!( - "roundtripped encoded buffers do not match with `encode_to_vec`" - )); + return RoundtripResult::Error("roundtripped encoded buffers do not match with `encode_to_vec`".to_string()); } RoundtripResult::Ok(buf1) From de960ffedd9b5293ec3a15d880a3ed2272d32165 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Mon, 18 Oct 2021 17:46:58 -0700 Subject: [PATCH 05/30] uh? --- conformance/failing_tests.txt | 166 +--------------------------------- 1 file changed, 2 insertions(+), 164 deletions(-) diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index c1535b62d..1bb832783 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -1,164 +1,2 @@ -Recommended.Proto2.JsonInput.FieldNameExtension.Validator -Recommended.Proto3.JsonInput.BytesFieldBase64Url.JsonOutput -Recommended.Proto3.JsonInput.DurationHas3FractionalDigits.Validator -Recommended.Proto3.JsonInput.DurationHas6FractionalDigits.Validator -Recommended.Proto3.JsonInput.DurationHas9FractionalDigits.Validator -Recommended.Proto3.JsonInput.DurationHasZeroFractionalDigit.Validator -Recommended.Proto3.JsonInput.FieldNameWithDoubleUnderscores.JsonOutput -Recommended.Proto3.JsonInput.FieldNameWithDoubleUnderscores.Validator -Recommended.Proto3.JsonInput.Int64FieldBeString.Validator -Recommended.Proto3.JsonInput.MultilineNoSpaces.JsonOutput -Recommended.Proto3.JsonInput.MultilineWithSpaces.JsonOutput -Recommended.Proto3.JsonInput.NullValueInNormalMessage.Validator -Recommended.Proto3.JsonInput.NullValueInOtherOneofNewFormat.Validator -Recommended.Proto3.JsonInput.NullValueInOtherOneofOldFormat.Validator -Recommended.Proto3.JsonInput.OneLineNoSpaces.JsonOutput -Recommended.Proto3.JsonInput.OneLineWithSpaces.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroBool.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroBytes.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroDouble.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroEnum.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroFloat.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroMessage.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroString.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroUint32.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroUint64.JsonOutput -Recommended.Proto3.JsonInput.TimestampHas3FractionalDigits.Validator -Recommended.Proto3.JsonInput.TimestampHas6FractionalDigits.Validator -Recommended.Proto3.JsonInput.TimestampHas9FractionalDigits.Validator -Recommended.Proto3.JsonInput.TimestampHasZeroFractionalDigit.Validator -Recommended.Proto3.JsonInput.TimestampZeroNormalized.Validator -Recommended.Proto3.JsonInput.Uint64FieldBeString.Validator -Required.Proto2.JsonInput.StoresDefaultPrimitive.Validator -Required.Proto3.JsonInput.AllFieldAcceptNull.JsonOutput -Required.Proto3.JsonInput.Any.JsonOutput -Required.Proto3.JsonInput.AnyNested.JsonOutput -Required.Proto3.JsonInput.AnyUnorderedTypeTag.JsonOutput -Required.Proto3.JsonInput.AnyWithDuration.JsonOutput -Required.Proto3.JsonInput.AnyWithFieldMask.JsonOutput -Required.Proto3.JsonInput.AnyWithInt32ValueWrapper.JsonOutput -Required.Proto3.JsonInput.AnyWithStruct.JsonOutput -Required.Proto3.JsonInput.AnyWithTimestamp.JsonOutput -Required.Proto3.JsonInput.AnyWithValueForInteger.JsonOutput -Required.Proto3.JsonInput.AnyWithValueForJsonObject.JsonOutput -Required.Proto3.JsonInput.BoolFieldFalse.JsonOutput -Required.Proto3.JsonInput.BoolFieldTrue.JsonOutput -Required.Proto3.JsonInput.BoolMapEscapedKey.JsonOutput -Required.Proto3.JsonInput.BoolMapField.JsonOutput -Required.Proto3.JsonInput.BytesField.JsonOutput -Required.Proto3.JsonInput.BytesRepeatedField.JsonOutput -Required.Proto3.JsonInput.DoubleFieldInfinity.JsonOutput -Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.JsonOutput -Required.Proto3.JsonInput.DoubleFieldMaxPositiveValue.JsonOutput -Required.Proto3.JsonInput.DoubleFieldMinNegativeValue.JsonOutput -Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.JsonOutput -Required.Proto3.JsonInput.DoubleFieldNan.JsonOutput -Required.Proto3.JsonInput.DoubleFieldNegativeInfinity.JsonOutput -Required.Proto3.JsonInput.DoubleFieldQuotedValue.JsonOutput -Required.Proto3.JsonInput.DurationMaxValue.JsonOutput -Required.Proto3.JsonInput.DurationMinValue.JsonOutput -Required.Proto3.JsonInput.DurationNull.JsonOutput -Required.Proto3.JsonInput.DurationRepeatedValue.JsonOutput -Required.Proto3.JsonInput.EmptyFieldMask.JsonOutput -Required.Proto3.JsonInput.EnumField.JsonOutput -Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.JsonOutput -Required.Proto3.JsonInput.EnumFieldNumericValueZero.JsonOutput -Required.Proto3.JsonInput.EnumFieldUnknownValue.Validator -Required.Proto3.JsonInput.EnumFieldWithAlias.JsonOutput -Required.Proto3.JsonInput.EnumFieldWithAliasDifferentCase.JsonOutput -Required.Proto3.JsonInput.EnumFieldWithAliasLowerCase.JsonOutput -Required.Proto3.JsonInput.EnumFieldWithAliasUseAlias.JsonOutput -Required.Proto3.JsonInput.EnumRepeatedField.JsonOutput -Required.Proto3.JsonInput.FieldMask.JsonOutput -Required.Proto3.JsonInput.FieldNameEscaped.JsonOutput -Required.Proto3.JsonInput.FieldNameInLowerCamelCase.Validator -Required.Proto3.JsonInput.FieldNameInSnakeCase.JsonOutput -Required.Proto3.JsonInput.FieldNameWithMixedCases.JsonOutput -Required.Proto3.JsonInput.FieldNameWithMixedCases.Validator -Required.Proto3.JsonInput.FieldNameWithNumbers.JsonOutput -Required.Proto3.JsonInput.FieldNameWithNumbers.Validator -Required.Proto3.JsonInput.FloatFieldInfinity.JsonOutput -Required.Proto3.JsonInput.FloatFieldMaxNegativeValue.JsonOutput -Required.Proto3.JsonInput.FloatFieldMaxPositiveValue.JsonOutput -Required.Proto3.JsonInput.FloatFieldMinNegativeValue.JsonOutput -Required.Proto3.JsonInput.FloatFieldMinPositiveValue.JsonOutput -Required.Proto3.JsonInput.FloatFieldNan.JsonOutput -Required.Proto3.JsonInput.FloatFieldNegativeInfinity.JsonOutput -Required.Proto3.JsonInput.FloatFieldQuotedValue.JsonOutput -Required.Proto3.JsonInput.FloatFieldTooLarge -Required.Proto3.JsonInput.FloatFieldTooSmall -Required.Proto3.JsonInput.HelloWorld.JsonOutput -Required.Proto3.JsonInput.Int32FieldExponentialFormat.JsonOutput -Required.Proto3.JsonInput.Int32FieldFloatTrailingZero.JsonOutput -Required.Proto3.JsonInput.Int32FieldMaxFloatValue.JsonOutput -Required.Proto3.JsonInput.Int32FieldMaxValue.JsonOutput -Required.Proto3.JsonInput.Int32FieldMinFloatValue.JsonOutput -Required.Proto3.JsonInput.Int32FieldMinValue.JsonOutput -Required.Proto3.JsonInput.Int32FieldStringValue.JsonOutput -Required.Proto3.JsonInput.Int32FieldStringValueEscaped.JsonOutput -Required.Proto3.JsonInput.Int32MapEscapedKey.JsonOutput -Required.Proto3.JsonInput.Int32MapField.JsonOutput -Required.Proto3.JsonInput.Int64FieldMaxValue.JsonOutput -Required.Proto3.JsonInput.Int64FieldMaxValueNotQuoted.JsonOutput -Required.Proto3.JsonInput.Int64FieldMinValue.JsonOutput -Required.Proto3.JsonInput.Int64FieldMinValueNotQuoted.JsonOutput -Required.Proto3.JsonInput.Int64MapEscapedKey.JsonOutput -Required.Proto3.JsonInput.Int64MapField.JsonOutput -Required.Proto3.JsonInput.MessageField.JsonOutput -Required.Proto3.JsonInput.MessageMapField.JsonOutput -Required.Proto3.JsonInput.MessageRepeatedField.JsonOutput -Required.Proto3.JsonInput.OneofFieldDuplicate -Required.Proto3.JsonInput.OptionalBoolWrapper.JsonOutput -Required.Proto3.JsonInput.OptionalBytesWrapper.JsonOutput -Required.Proto3.JsonInput.OptionalDoubleWrapper.JsonOutput -Required.Proto3.JsonInput.OptionalFloatWrapper.JsonOutput -Required.Proto3.JsonInput.OptionalInt32Wrapper.JsonOutput -Required.Proto3.JsonInput.OptionalInt64Wrapper.JsonOutput -Required.Proto3.JsonInput.OptionalStringWrapper.JsonOutput -Required.Proto3.JsonInput.OptionalUint32Wrapper.JsonOutput -Required.Proto3.JsonInput.OptionalUint64Wrapper.JsonOutput -Required.Proto3.JsonInput.OptionalWrapperTypesWithNonDefaultValue.JsonOutput -Required.Proto3.JsonInput.OriginalProtoFieldName.JsonOutput -Required.Proto3.JsonInput.PrimitiveRepeatedField.JsonOutput -Required.Proto3.JsonInput.RepeatedBoolWrapper.JsonOutput -Required.Proto3.JsonInput.RepeatedBytesWrapper.JsonOutput -Required.Proto3.JsonInput.RepeatedDoubleWrapper.JsonOutput -Required.Proto3.JsonInput.RepeatedFloatWrapper.JsonOutput -Required.Proto3.JsonInput.RepeatedInt32Wrapper.JsonOutput -Required.Proto3.JsonInput.RepeatedInt64Wrapper.JsonOutput -Required.Proto3.JsonInput.RepeatedListValue.JsonOutput -Required.Proto3.JsonInput.RepeatedStringWrapper.JsonOutput -Required.Proto3.JsonInput.RepeatedUint32Wrapper.JsonOutput -Required.Proto3.JsonInput.RepeatedUint64Wrapper.JsonOutput -Required.Proto3.JsonInput.RepeatedValue.JsonOutput -Required.Proto3.JsonInput.SkipsDefaultPrimitive.Validator -Required.Proto3.JsonInput.StringField.JsonOutput -Required.Proto3.JsonInput.StringFieldEscape.JsonOutput -Required.Proto3.JsonInput.StringFieldSurrogatePair.JsonOutput -Required.Proto3.JsonInput.StringFieldUnicode.JsonOutput -Required.Proto3.JsonInput.StringFieldUnicodeEscape.JsonOutput -Required.Proto3.JsonInput.StringFieldUnicodeEscapeWithLowercaseHexLetters.JsonOutput -Required.Proto3.JsonInput.StringRepeatedField.JsonOutput -Required.Proto3.JsonInput.Struct.JsonOutput -Required.Proto3.JsonInput.StructWithEmptyListValue.JsonOutput -Required.Proto3.JsonInput.TimestampLeap.JsonOutput -Required.Proto3.JsonInput.TimestampMaxValue.JsonOutput -Required.Proto3.JsonInput.TimestampMinValue.JsonOutput -Required.Proto3.JsonInput.TimestampNull.JsonOutput -Required.Proto3.JsonInput.TimestampRepeatedValue.JsonOutput -Required.Proto3.JsonInput.TimestampWithNegativeOffset.JsonOutput -Required.Proto3.JsonInput.TimestampWithPositiveOffset.JsonOutput -Required.Proto3.JsonInput.Uint32FieldMaxFloatValue.JsonOutput -Required.Proto3.JsonInput.Uint32FieldMaxValue.JsonOutput -Required.Proto3.JsonInput.Uint32MapField.JsonOutput -Required.Proto3.JsonInput.Uint64FieldMaxValue.JsonOutput -Required.Proto3.JsonInput.Uint64FieldMaxValueNotQuoted.JsonOutput -Required.Proto3.JsonInput.Uint64MapField.JsonOutput -Required.Proto3.JsonInput.ValueAcceptBool.JsonOutput -Required.Proto3.JsonInput.ValueAcceptFloat.JsonOutput -Required.Proto3.JsonInput.ValueAcceptInteger.JsonOutput -Required.Proto3.JsonInput.ValueAcceptList.JsonOutput -Required.Proto3.JsonInput.ValueAcceptNull.JsonOutput -Required.Proto3.JsonInput.ValueAcceptObject.JsonOutput -Required.Proto3.JsonInput.ValueAcceptString.JsonOutput -Required.Proto3.JsonInput.WrapperTypesWithNullValue.JsonOutput +Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput +Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput From 6270fff8cd79f59a9d03bff01d497ca5c4a6970b Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Tue, 19 Oct 2021 16:55:10 -0700 Subject: [PATCH 06/30] lifetime error --- prost-build/src/code_generator.rs | 20 +++++++++++++++++++- prost-types/src/lib.rs | 7 ++++++- tests/src/lib.rs | 11 ++++++----- 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 32feacb13..f01d5cfa5 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -280,7 +280,7 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str(r#"#[serde(rename_all = "camelCase")]"#); self.buf.push('\n'); self.push_indent(); - self.buf.push_str("#[serde(default)]"); + self.buf.push_str(r#"#[serde(default)]"#); self.buf.push('\n'); } } @@ -319,6 +319,22 @@ impl<'a> CodeGenerator<'a> { } } + fn append_json_field_attributes(&mut self, fq_message_name: &str, map_type: Option<&str>) { + if let Some(_) = self.config.json_mapping.get_first(fq_message_name) { + push_indent(&mut self.buf, self.depth); + if let Some(map_type) = map_type { + self.buf.push_str(&format!( + r#"#[serde(skip_serializing_if = "{}::is_empty")]"#, + map_type + )); + } else { + self.buf + .push_str(r#"#[serde(skip_serializing_if = "::prost_types::is_default")]"#); + } + self.buf.push('\n'); + } + } + fn append_field(&mut self, fq_message_name: &str, field: FieldDescriptorProto) { let type_ = field.r#type(); let repeated = field.label == Some(Label::Repeated as i32); @@ -422,6 +438,7 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str("\")]\n"); self.append_field_attributes(fq_message_name, field.name()); + self.append_json_field_attributes(fq_message_name, None); self.push_indent(); self.buf.push_str("pub "); self.buf.push_str(&to_snake(field.name())); @@ -481,6 +498,7 @@ impl<'a> CodeGenerator<'a> { field.number() )); self.append_field_attributes(fq_message_name, field.name()); + self.append_json_field_attributes(fq_message_name, Some(map_type.rust_type())); self.push_indent(); self.buf.push_str(&format!( "pub {}: {}<{}, {}>,\n", diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 1b4e05855..3173656a7 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -262,7 +262,8 @@ impl serde::Serialize for Timestamp { S: serde::Serializer, { serializer.serialize_str( - &humantime::format_rfc3339(std::time::SystemTime::try_from(self.clone()).unwrap()).to_string(), + &humantime::format_rfc3339(std::time::SystemTime::try_from(self.clone()).unwrap()) + .to_string(), ) } } @@ -296,6 +297,10 @@ impl<'de> serde::Deserialize<'de> for Timestamp { } } +pub fn is_default(t: &T) -> bool { + t == &T::default() +} + #[cfg(test)] mod tests { use std::time::{Duration, SystemTime, UNIX_EPOCH}; diff --git a/tests/src/lib.rs b/tests/src/lib.rs index c013786ee..4dfcac9b6 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -201,7 +201,9 @@ where } if buf1 != buf3 { - return RoundtripResult::Error("roundtripped encoded buffers do not match with `encode_to_vec`".to_string()); + return RoundtripResult::Error( + "roundtripped encoded buffers do not match with `encode_to_vec`".to_string(), + ); } RoundtripResult::Ok(buf1) @@ -226,12 +228,12 @@ where if str1 != data { return RoundtripResult::Error(format!( - "halftripped JSON encoded strings do not match {} {}", + "halftripped JSON encoded strings do not match\nstring: {}\noriginal provided data: {}", str1, data )); } - let roundtrip = match serde_json::from_str(&str1) { + let roundtrip = match serde_json::from_str::<'de, M>(&str1) { Ok(roundtrip) => roundtrip, Err(error) => return RoundtripResult::Error(format!("step 3 {}", error.to_string())), }; @@ -240,14 +242,13 @@ where Ok(str) => str, Err(error) => return RoundtripResult::Error(format!("step 4 {}", error.to_string())), }; - + if str1 != str2 { return RoundtripResult::Error(format!( "roundtripped JSON encoded strings do not match {} {}", str1, str2 )); } - RoundtripResult::Ok(str1.into_bytes()) } From 6dd713c7bb241b675e369062529ddce975277bf7 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Mon, 25 Oct 2021 15:52:27 -0700 Subject: [PATCH 07/30] stopping point --- prost-build/src/code_generator.rs | 120 ++++- prost-types/src/lib.rs | 854 ++++++++++++++++++++++++++++++ tests/src/lib.rs | 8 +- 3 files changed, 975 insertions(+), 7 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index f01d5cfa5..b17fa9dc2 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -319,8 +319,28 @@ impl<'a> CodeGenerator<'a> { } } - fn append_json_field_attributes(&mut self, fq_message_name: &str, map_type: Option<&str>) { + fn append_json_field_attributes( + &mut self, + fq_message_name: &str, + ty: &str, + field_name: &str, + optional: bool, + repeated: bool, + map_type: Option<&str>, + ) { if let Some(_) = self.config.json_mapping.get_first(fq_message_name) { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(&format!(r#"#[serde(alias = "{}")]"#, field_name,)); + self.buf.push('\n'); + /* if field_name.starts_with('_') { + push_indent(&mut self.buf, self.depth); + self.buf.push_str(&format!( + r#"#[serde(alias = "{}")]"#, + field_name, + )); + self.buf.push('\n'); + }*/ push_indent(&mut self.buf, self.depth); if let Some(map_type) = map_type { self.buf.push_str(&format!( @@ -332,6 +352,86 @@ impl<'a> CodeGenerator<'a> { .push_str(r#"#[serde(skip_serializing_if = "::prost_types::is_default")]"#); } self.buf.push('\n'); + + match (ty, optional, repeated) { + ("i32", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::i32_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("i32", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(deserialize_with = "::prost_types::i32_opt_visitor::deserialize")]"#); + self.buf.push('\n'); + } + ("i64", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::i64_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("i64", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(deserialize_with = "::prost_types::i64_opt_visitor::deserialize")]"#); + self.buf.push('\n'); + } + ("u32", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::u32_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("u32", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(deserialize_with = "::prost_types::u32_opt_visitor::deserialize")]"#); + self.buf.push('\n'); + } + ("u64", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::u64_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("u64", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(deserialize_with = "::prost_types::u64_opt_visitor::deserialize")]"#); + self.buf.push('\n'); + } + ("f64", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(with = "::prost_types::f64_visitor")]"#); + self.buf.push('\n'); + } + ("f64", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(with = "::prost_types::f64_opt_visitor")]"#); + self.buf.push('\n'); + } + ("f32", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(with = "::prost_types::f32_visitor")]"#); + self.buf.push('\n'); + } + ("f32", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(with = "::prost_types::f32_opt_visitor")]"#); + self.buf.push('\n'); + } + _ => {} + } } } @@ -438,7 +538,14 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str("\")]\n"); self.append_field_attributes(fq_message_name, field.name()); - self.append_json_field_attributes(fq_message_name, None); + self.append_json_field_attributes( + fq_message_name, + &ty, + field.name(), + optional, + repeated, + None, + ); self.push_indent(); self.buf.push_str("pub "); self.buf.push_str(&to_snake(field.name())); @@ -498,7 +605,14 @@ impl<'a> CodeGenerator<'a> { field.number() )); self.append_field_attributes(fq_message_name, field.name()); - self.append_json_field_attributes(fq_message_name, Some(map_type.rust_type())); + self.append_json_field_attributes( + fq_message_name, + map_type.rust_type(), + field.name(), + false, + false, + Some(map_type.rust_type()), + ); self.push_indent(); self.buf.push_str(&format!( "pub {}: {}<{}, {}>,\n", diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 3173656a7..60ec75ff5 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -301,6 +301,860 @@ pub fn is_default(t: &T) -> bool { t == &T::default() } +pub mod i32_visitor { + struct I32Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for I32Visitor { + type Value = i32; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid I32 string or integer") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + i32::try_from(value).map_err(E::custom) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON { + return Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )); + } else { + // This is a round number, we can cast just fine. + Ok(value as i32) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + i32::try_from(value).map_err(E::custom) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value.parse::().map(|x| x as i32).map_err(E::custom) + } else { + value.parse::().map_err(E::custom) + } + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(I32Visitor) + } +} + +pub mod i32_opt_visitor { + struct I32Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for I32Visitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid I32 string or integer") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + i32::try_from(value).map(|x| Some(x)).map_err(E::custom) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON { + return Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )); + } else { + // This is a round number, we can cast just fine. + Ok(Some(value as i32)) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + i32::try_from(value).map(|x| Some(x)).map_err(E::custom) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value + .parse::() + .map(|x| Some(x as i32)) + .map_err(E::custom) + } else { + value.parse::().map(|x| Some(x)).map_err(E::custom) + } + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(I32Visitor) + } +} + +pub mod i64_visitor { + struct I64Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for I64Visitor { + type Value = i64; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid I64 string or integer") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(value as i64) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON { + return Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )); + } else { + // This is a round number, we can cast just fine. + Ok(value as i64) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + i64::try_from(value).map_err(E::custom) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value.parse::().map(|x| x as i64).map_err(E::custom) + } else { + value.parse::().map_err(E::custom) + } + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(I64Visitor) + } +} + +pub mod i64_opt_visitor { + struct I64Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for I64Visitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid I64 string or integer") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as i64)) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON { + return Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )); + } else { + // This is a round number, we can cast just fine. + Ok(Some(value as i64)) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + i64::try_from(value).map(|x| Some(x)).map_err(E::custom) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value + .parse::() + .map(|x| Some(x as i64)) + .map_err(E::custom) + } else { + value.parse::().map(|x| Some(x)).map_err(E::custom) + } + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(I64Visitor) + } +} + +pub mod u32_visitor { + struct U32Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for U32Visitor { + type Value = u32; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid U32 string or integer") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + u32::try_from(value).map_err(E::custom) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON { + return Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )); + } else { + // This is a round number, we can cast just fine. + Ok(value as u32) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + u32::try_from(value).map_err(E::custom) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value.parse::().map(|x| x as u32).map_err(E::custom) + } else { + value.parse::().map_err(E::custom) + } + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(U32Visitor) + } +} + +pub mod u32_opt_visitor { + struct U32Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for U32Visitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid U32 string or integer") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + u32::try_from(value).map(|x| Some(x)).map_err(E::custom) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON { + return Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )); + } else { + // This is a round number, we can cast just fine. + Ok(Some(value as u32)) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + u32::try_from(value).map(|x| Some(x)).map_err(E::custom) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value + .parse::() + .map(|x| Some(x as u32)) + .map_err(E::custom) + } else { + value.parse::().map(|x| Some(x)).map_err(E::custom) + } + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(U32Visitor) + } +} + +pub mod u64_visitor { + struct U64Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for U64Visitor { + type Value = u64; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid U64 string or integer") + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(value as u64) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON { + return Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )); + } else { + // This is a round number, we can cast just fine. + Ok(value as u64) + } + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value.parse::().map(|x| x as u64).map_err(E::custom) + } else { + value.parse::().map_err(E::custom) + } + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(U64Visitor) + } +} + +pub mod u64_opt_visitor { + struct U64Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for U64Visitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid U64 string or integer") + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as u64)) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON { + return Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )); + } else { + // This is a round number, we can cast just fine. + Ok(Some(value as u64)) + } + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value + .parse::() + .map(|x| Some(x as u64)) + .map_err(E::custom) + } else { + value.parse::().map(|x| Some(x)).map_err(E::custom) + } + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(U64Visitor) + } +} + +pub mod f64_visitor { + struct F64Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for F64Visitor { + type Value = f64; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid F64 string or integer") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(value as f64) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + Ok(value as f64) + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(value as f64) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + "NaN" => Ok(f64::NAN), + "Infinity" => Ok(f64::INFINITY), + "-Infinity" => Ok(f64::NEG_INFINITY), + _ => value.parse::().map_err(E::custom), + } + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(F64Visitor) + } + + #[cfg(feature = "std")] + pub fn serialize(value: &f64, serializer: S) -> Result + where + S: serde::Serializer, + { + if value.is_nan() { + serializer.serialize_str("NaN") + } else if value.is_infinite() && value.is_sign_negative() { + serializer.serialize_str("-Infinity") + } else if value.is_infinite() { + serializer.serialize_str("Infinity") + } else { + serializer.serialize_f64(*value) + } + } +} + +pub mod f64_opt_visitor { + struct F64Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for F64Visitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid F64 string or integer") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as f64)) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as f64)) + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as f64)) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + "NaN" => Ok(Some(f64::NAN)), + "Infinity" => Ok(Some(f64::INFINITY)), + "-Infinity" => Ok(Some(f64::NEG_INFINITY)), + _ => value.parse::().map(|x| Some(x)).map_err(E::custom), + } + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(F64Visitor) + } + + #[cfg(feature = "std")] + pub fn serialize(value: &std::option::Option, serializer: S) -> Result + where + S: serde::Serializer, + { + match value { + None => serializer.serialize_none(), + Some(double) => crate::f64_visitor::serialize(double, serializer), + } + } +} + +pub mod f32_visitor { + struct F32Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for F32Visitor { + type Value = f32; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid F32 string or integer") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(value as f32) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + // TODO figure out min/max bug. + Ok(value as f32) + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(value as f32) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + "NaN" => Ok(f32::NAN), + "Infinity" => Ok(f32::INFINITY), + "-Infinity" => Ok(f32::NEG_INFINITY), + _ => value.parse::().map_err(E::custom), + } + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(F32Visitor) + } + + #[cfg(feature = "std")] + pub fn serialize(value: &f32, serializer: S) -> Result + where + S: serde::Serializer, + { + if value.is_nan() { + serializer.serialize_str("NaN") + } else if value.is_infinite() && value.is_sign_negative() { + serializer.serialize_str("-Infinity") + } else if value.is_infinite() { + serializer.serialize_str("Infinity") + } else { + serializer.serialize_f32(*value) + } + } +} + +pub mod f32_opt_visitor { + struct F32Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for F32Visitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid F32 string or integer") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as f32)) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + // TODO figure out min/max bug. + Ok(Some(value as f32)) + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as f32)) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + "NaN" => Ok(Some(f32::NAN)), + "Infinity" => Ok(Some(f32::INFINITY)), + "-Infinity" => Ok(Some(f32::NEG_INFINITY)), + _ => value.parse::().map(|x| Some(x)).map_err(E::custom), + } + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(F32Visitor) + } + + #[cfg(feature = "std")] + pub fn serialize(value: &std::option::Option, serializer: S) -> Result + where + S: serde::Serializer, + { + match value { + None => serializer.serialize_none(), + Some(float) => crate::f32_visitor::serialize(float, serializer), + } + } +} + #[cfg(test)] mod tests { use std::time::{Duration, SystemTime, UNIX_EPOCH}; diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 4dfcac9b6..b6dfbd33b 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -226,14 +226,14 @@ where Err(error) => return RoundtripResult::Error(format!("step 2 {}", error.to_string())), }; - if str1 != data { + /* if str1 != data { return RoundtripResult::Error(format!( "halftripped JSON encoded strings do not match\nstring: {}\noriginal provided data: {}", str1, data )); - } + }*/ - let roundtrip = match serde_json::from_str::<'de, M>(&str1) { + let roundtrip: M = match serde_json::from_str::(data) { Ok(roundtrip) => roundtrip, Err(error) => return RoundtripResult::Error(format!("step 3 {}", error.to_string())), }; @@ -242,7 +242,7 @@ where Ok(str) => str, Err(error) => return RoundtripResult::Error(format!("step 4 {}", error.to_string())), }; - + if str1 != str2 { return RoundtripResult::Error(format!( "roundtripped JSON encoded strings do not match {} {}", From 680f8b78e8bba512473d8f59e2fdd5e737ae2724 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Mon, 1 Nov 2021 15:28:05 -0700 Subject: [PATCH 08/30] working --- prost-build/src/code_generator.rs | 83 +++- prost-types/Cargo.toml | 1 + prost-types/src/lib.rs | 649 +++++++++++++++++++++++++++--- tests/Cargo.toml | 1 + tests/src/lib.rs | 5 +- 5 files changed, 678 insertions(+), 61 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index b17fa9dc2..569efc887 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -326,27 +326,42 @@ impl<'a> CodeGenerator<'a> { field_name: &str, optional: bool, repeated: bool, + json_name: &str, map_type: Option<&str>, ) { if let Some(_) = self.config.json_mapping.get_first(fq_message_name) { + if json_name.len() > 0 { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(&format!(r#"#[serde(rename = "{}")]"#, json_name,)); + self.buf.push('\n'); + } push_indent(&mut self.buf, self.depth); self.buf .push_str(&format!(r#"#[serde(alias = "{}")]"#, field_name,)); self.buf.push('\n'); - /* if field_name.starts_with('_') { - push_indent(&mut self.buf, self.depth); - self.buf.push_str(&format!( - r#"#[serde(alias = "{}")]"#, - field_name, - )); - self.buf.push('\n'); - }*/ push_indent(&mut self.buf, self.depth); if let Some(map_type) = map_type { self.buf.push_str(&format!( r#"#[serde(skip_serializing_if = "{}::is_empty")]"#, map_type )); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + match map_type { + "::std::collections::HashMap" => + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::map_visitor::deserialize")]"#, + ), + "::prost::alloc::collections::BTreeMap" => + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::btree_map_visitor::deserialize")]"#, + ), + + _ => (), + } + self.buf.push('\n'); + return; } else { self.buf .push_str(r#"#[serde(skip_serializing_if = "::prost_types::is_default")]"#); @@ -361,12 +376,32 @@ impl<'a> CodeGenerator<'a> { ); self.buf.push('\n'); } + ("i32", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::i32_visitor::I32Visitor>")]"#, + ); + self.buf.push('\n'); + } ("i32", true, false) => { push_indent(&mut self.buf, self.depth); self.buf .push_str(r#"#[serde(deserialize_with = "::prost_types::i32_opt_visitor::deserialize")]"#); self.buf.push('\n'); } + ("bool", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::bool_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("bool", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(deserialize_with = "::prost_types::bool_opt_visitor::deserialize")]"#); + self.buf.push('\n'); + } ("i64", false, false) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( @@ -430,6 +465,36 @@ impl<'a> CodeGenerator<'a> { .push_str(r#"#[serde(with = "::prost_types::f32_opt_visitor")]"#); self.buf.push('\n'); } + ("::prost::alloc::string::String", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(deserialize_with = "::prost_types::string_visitor::deserialize")]"#); + self.buf.push('\n'); + } + ("::prost::alloc::string::String", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(deserialize_with = "::prost_types::string_opt_visitor::deserialize")]"#); + self.buf.push('\n'); + } + ("::prost::alloc::vec::Vec", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(with = "::prost_types::vec_u8_visitor")]"#); + self.buf.push('\n'); + } + ("::prost::alloc::vec::Vec", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(with = "::prost_types::vec_u8_opt_visitor")]"#); + self.buf.push('\n'); + } + (_, _, true) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(deserialize_with = "::prost_types::vec_visitor::deserialize")]"#); + self.buf.push('\n'); + } _ => {} } } @@ -544,6 +609,7 @@ impl<'a> CodeGenerator<'a> { field.name(), optional, repeated, + field.json_name(), None, ); self.push_indent(); @@ -611,6 +677,7 @@ impl<'a> CodeGenerator<'a> { field.name(), false, false, + field.json_name(), Some(map_type.rust_type()), ); self.push_indent(); diff --git a/prost-types/Cargo.toml b/prost-types/Cargo.toml index 371a76287..4f7903b6c 100644 --- a/prost-types/Cargo.toml +++ b/prost-types/Cargo.toml @@ -20,6 +20,7 @@ default = ["std"] std = ["prost/std"] [dependencies] +base64 = "0.13" bytes = { version = "1", default-features = false } serde = { version = "1", features = ["derive"] } humantime = { version = "2.1" } diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 60ec75ff5..4b4e79c98 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -16,6 +16,7 @@ use core::i32; use core::i64; use core::time; + include!("protobuf.rs"); pub mod compiler { include!("compiler.rs"); @@ -297,13 +298,407 @@ impl<'de> serde::Deserialize<'de> for Timestamp { } } +pub trait HasConstructor { + fn new() -> Self; +} + +pub struct MyType<'de, T: serde::de::Visitor<'de> + HasConstructor>(>::Value); + +impl<'de, T> serde::Deserialize<'de> for MyType<'de, T> where T: serde::de::Visitor<'de> + HasConstructor { + fn deserialize(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(T::new()).map(|x| MyType{0: x}) + } +} + pub fn is_default(t: &T) -> bool { t == &T::default() } +pub mod vec_visitor { + struct VecVisitor<'de, T> where T: serde::Deserialize<'de> { + _vec_type: &'de std::marker::PhantomData, + } + + #[cfg(feature = "std")] + impl<'de, T: serde::Deserialize<'de>> serde::de::Visitor<'de> for VecVisitor<'de, T> { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid String string or integer") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de> { + let mut res = Self::Value::with_capacity(seq.size_hint().unwrap_or(0)); + loop { + match seq.next_element()? { + Some(el) => res.push(el), + None => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T: 'de + serde::Deserialize<'de>>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(VecVisitor::<'de, T>{_vec_type: &std::marker::PhantomData}) + } +} + +pub mod repeated_visitor { + struct VecVisitor<'de, T> where T: serde::de::Visitor<'de> + crate::HasConstructor { + _vec_type: &'de std::marker::PhantomData, + } + + #[cfg(feature = "std")] + impl<'de, T> serde::de::Visitor<'de> for VecVisitor<'de, T> where + T: serde::de::Visitor<'de> + crate::HasConstructor, + { + type Value = Vec<>::Value>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid String string or integer") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de> { + let mut res = Self::Value::with_capacity(seq.size_hint().unwrap_or(0)); + loop { + let response: std::option::Option> = seq.next_element()?; + match response { + Some(el) => res.push(el.0), + None => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T: 'de + serde::de::Visitor<'de> + crate::HasConstructor>(deserializer: D) -> Result>::Value>, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(VecVisitor::<'de, T>{_vec_type: &std::marker::PhantomData}) + } + + pub fn serialize(value: Vec<::Value>, serializer: S) -> Result + where + S: serde::Serializer, + F: crate::SerializeMethod, + { + use serde::ser::SerializeSeq; + let mut seq = serializer.serialize_seq(Some(value.len()))?; + for e in value { + seq.serialize_element(&crate::MySeType::{val: e})?; + } + seq.end() + } +} + +pub trait SerializeMethod { + type Value; + fn serialize(value: &Value, serializer: S) -> Result where S: serde::Serializer; +} + +pub struct MySeType where T: SerializeMethod { + val: T::Value, +} + +impl serde::Serialize for MySeType { + fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { + T::serialize(&self.val, serializer) + } +} + +pub mod map_visitor { + struct MapVisitor<'de, K, V> + where K: serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, + V: serde::Deserialize<'de> + { + _key_type: &'de std::marker::PhantomData, + _value_type: &'de std::marker::PhantomData, + } + + #[cfg(feature = "std")] + impl<'de, K: serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, V: serde::Deserialize<'de>> serde::de::Visitor<'de> for MapVisitor<'de, K, V> { + type Value = std::collections::HashMap; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid String string or integer") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de> { + let mut res = Self::Value::with_capacity(map.size_hint().unwrap_or(0)); + loop { + match map.next_entry()? { + Some((k, v)) => {res.insert(k,v);}, + None => return Ok(res), + } + } + + } + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, K: 'de + serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, V: 'de + serde::Deserialize<'de>>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(MapVisitor::<'de, K, V>{ + _key_type: &std::marker::PhantomData, + _value_type: &std::marker::PhantomData, + }) + } +} + +pub mod btree_map_visitor { + struct MapVisitor<'de, K, V> + where K: serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, + V: serde::Deserialize<'de> + { + _key_type: &'de std::marker::PhantomData, + _value_type: &'de std::marker::PhantomData, + } + + #[cfg(feature = "std")] + impl<'de, K: serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, V: serde::Deserialize<'de>> serde::de::Visitor<'de> for MapVisitor<'de, K, V> { + type Value = std::collections::BTreeMap; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid String string or integer") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de> { + let mut res = Self::Value::new(); + loop { + match map.next_entry()? { + Some((k, v)) => {res.insert(k,v);}, + None => return Ok(res), + } + } + + } + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, K: 'de + serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, V: 'de + serde::Deserialize<'de>>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(MapVisitor::<'de, K, V>{ + _key_type: &std::marker::PhantomData, + _value_type: &std::marker::PhantomData, + }) + } +} + +pub mod string_visitor { + struct StringVisitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for StringVisitor { + type Value = std::string::String; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid string") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + return Ok(value.to_string()) + } + + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(StringVisitor) + } +} + +pub mod string_opt_visitor { + struct StringVisitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for StringVisitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid String string or integer") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + return Ok(Some(value.to_string())) + } + + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(StringVisitor) + } +} + + +pub mod bool_visitor { + struct BoolVisitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for BoolVisitor { + type Value = bool; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid Bool string or integer") + } + + fn visit_bool(self, value: bool) -> Result + where + E: serde::de::Error, + { + return Ok(value) + } + + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(bool::default()) + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(BoolVisitor) + } +} + +pub mod bool_opt_visitor { + struct BoolVisitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for BoolVisitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid Bool string or integer") + } + + fn visit_bool(self, value: bool) -> Result + where + E: serde::de::Error, + { + return Ok(Some(value)) + } + + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(BoolVisitor) + } +} + pub mod i32_visitor { - struct I32Visitor; + pub struct I32Visitor; + impl crate::HasConstructor for I32Visitor { + fn new() -> I32Visitor { + return I32Visitor{}; + } + } + #[cfg(feature = "std")] impl<'de> serde::de::Visitor<'de> for I32Visitor { type Value = i32; @@ -324,13 +719,14 @@ pub mod i32_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON { - return Err(serde::de::Error::invalid_type( + if (value.trunc() - value).abs() > f64::EPSILON || + value > i32::MAX as f64 || value < i32::MIN as f64 { + Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, - )); + )) } else { - // This is a round number, we can cast just fine. + // This is a round number in the proper range, we can cast just fine. Ok(value as i32) } } @@ -349,11 +745,18 @@ pub mod i32_visitor { { // If we have scientific notation or a decimal, parse float first. if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value.parse::().map(|x| x as i32).map_err(E::custom) + value.parse::().map_err(E::custom).and_then(|x| self.visit_f64(x)) } else { value.parse::().map_err(E::custom) } } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(i32::default()) + } } pub fn deserialize<'de, D>(deserializer: D) -> Result @@ -387,13 +790,14 @@ pub mod i32_opt_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON { - return Err(serde::de::Error::invalid_type( + if (value.trunc() - value).abs() > f64::EPSILON || + value > i32::MAX as f64 || value < i32::MIN as f64 { + Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, - )); + )) } else { - // This is a round number, we can cast just fine. + // This is a round number in the proper range, we can cast just fine. Ok(Some(value as i32)) } } @@ -412,10 +816,7 @@ pub mod i32_opt_visitor { { // If we have scientific notation or a decimal, parse float first. if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value - .parse::() - .map(|x| Some(x as i32)) - .map_err(E::custom) + value.parse::().map_err(E::custom).and_then(|x| self.visit_f64(x)) } else { value.parse::().map(|x| Some(x)).map_err(E::custom) } @@ -467,13 +868,14 @@ pub mod i64_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON { - return Err(serde::de::Error::invalid_type( + if (value.trunc() - value).abs() > f64::EPSILON || + value > i64::MAX as f64 || value < i64::MIN as f64 { + Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, - )); + )) } else { - // This is a round number, we can cast just fine. + // This is a round number in the proper range, we can cast just fine. Ok(value as i64) } } @@ -492,11 +894,19 @@ pub mod i64_visitor { { // If we have scientific notation or a decimal, parse float first. if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value.parse::().map(|x| x as i64).map_err(E::custom) + value.parse::().map_err(E::custom).and_then(|x| self.visit_f64(x)) } else { value.parse::().map_err(E::custom) } } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(i64::default()) + } + } pub fn deserialize<'de, D>(deserializer: D) -> Result @@ -529,13 +939,14 @@ pub mod i64_opt_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON { - return Err(serde::de::Error::invalid_type( + if (value.trunc() - value).abs() > f64::EPSILON || + value > i64::MAX as f64 || value < i64::MIN as f64 { + Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, - )); + )) } else { - // This is a round number, we can cast just fine. + // This is a round number in the proper range, we can cast just fine. Ok(Some(value as i64)) } } @@ -554,10 +965,7 @@ pub mod i64_opt_visitor { { // If we have scientific notation or a decimal, parse float first. if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value - .parse::() - .map(|x| Some(x as i64)) - .map_err(E::custom) + value.parse::().map_err(E::custom).and_then(|x| self.visit_f64(x)) } else { value.parse::().map(|x| Some(x)).map_err(E::custom) } @@ -610,13 +1018,14 @@ pub mod u32_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON { - return Err(serde::de::Error::invalid_type( + if (value.trunc() - value).abs() > f64::EPSILON || + value < 0.0 || value > u32::MAX as f64 { + Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, - )); + )) } else { - // This is a round number, we can cast just fine. + // This is a round number in the proper range, we can cast just fine. Ok(value as u32) } } @@ -635,11 +1044,18 @@ pub mod u32_visitor { { // If we have scientific notation or a decimal, parse float first. if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value.parse::().map(|x| x as u32).map_err(E::custom) + value.parse::().map_err(E::custom).and_then(|x| self.visit_f64(x)) } else { value.parse::().map_err(E::custom) } } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(u32::default()) + } } pub fn deserialize<'de, D>(deserializer: D) -> Result @@ -673,13 +1089,14 @@ pub mod u32_opt_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON { - return Err(serde::de::Error::invalid_type( + if (value.trunc() - value).abs() > f64::EPSILON || + value < 0.0 || value > u32::MAX as f64 { + Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, - )); + )) } else { - // This is a round number, we can cast just fine. + // This is a round number in the proper range, we can cast just fine. Ok(Some(value as u32)) } } @@ -700,8 +1117,8 @@ pub mod u32_opt_visitor { if value.contains('e') || value.contains('E') || value.ends_with(".0") { value .parse::() - .map(|x| Some(x as u32)) .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) } else { value.parse::().map(|x| Some(x)).map_err(E::custom) } @@ -720,6 +1137,7 @@ pub mod u32_opt_visitor { { Ok(None) } + } #[cfg(feature = "std")] @@ -753,13 +1171,14 @@ pub mod u64_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON { - return Err(serde::de::Error::invalid_type( + if (value.trunc() - value).abs() > f64::EPSILON || + value < 0.0 || value > u64::MAX as f64 { + Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, - )); + )) } else { - // This is a round number, we can cast just fine. + // This is a round number in the proper range, we can cast just fine. Ok(value as u64) } } @@ -770,11 +1189,18 @@ pub mod u64_visitor { { // If we have scientific notation or a decimal, parse float first. if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value.parse::().map(|x| x as u64).map_err(E::custom) + value.parse::().map_err(E::custom).and_then(|x| self.visit_f64(x)) } else { value.parse::().map_err(E::custom) } } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(u64::default()) + } } pub fn deserialize<'de, D>(deserializer: D) -> Result @@ -807,11 +1233,12 @@ pub mod u64_opt_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON { - return Err(serde::de::Error::invalid_type( + if (value.trunc() - value).abs() > f64::EPSILON || + value < 0.0 || value > u64::MAX as f64 { + Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, - )); + )) } else { // This is a round number, we can cast just fine. Ok(Some(value as u64)) @@ -826,8 +1253,7 @@ pub mod u64_opt_visitor { if value.contains('e') || value.contains('E') || value.ends_with(".0") { value .parse::() - .map(|x| Some(x as u64)) - .map_err(E::custom) + .map_err(E::custom).and_then(|x| self.visit_f64(x)) } else { value.parse::().map(|x| Some(x)).map_err(E::custom) } @@ -900,6 +1326,13 @@ pub mod f64_visitor { _ => value.parse::().map_err(E::custom), } } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(f64::default()) + } } pub fn deserialize<'de, D>(deserializer: D) -> Result @@ -1027,8 +1460,14 @@ pub mod f32_visitor { where E: serde::de::Error, { - // TODO figure out min/max bug. - Ok(value as f32) + if value < f32::MIN as f64 || value > f32::MAX as f64 { + Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )) + } else { + Ok(value as f32) + } } fn visit_u64(self, value: u64) -> Result @@ -1049,6 +1488,12 @@ pub mod f32_visitor { _ => value.parse::().map_err(E::custom), } } + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(f32::default()) + } } pub fn deserialize<'de, D>(deserializer: D) -> Result @@ -1097,8 +1542,14 @@ pub mod f32_opt_visitor { where E: serde::de::Error, { - // TODO figure out min/max bug. - Ok(Some(value as f32)) + if value < f32::MIN as f64 || value > f32::MAX as f64 { + Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )) + } else { + Ok(Some(value as f32)) + } } fn visit_u64(self, value: u64) -> Result @@ -1155,6 +1606,102 @@ pub mod f32_opt_visitor { } } +pub mod vec_u8_visitor { + struct VecU8Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for VecU8Visitor { + type Value = ::prost::alloc::vec::Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid Base64 encoded string") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + base64::decode(value).map_err(E::custom) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result<::prost::alloc::vec::Vec, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(VecU8Visitor) + } + + #[cfg(feature = "std")] + pub fn serialize(value: &::prost::alloc::vec::Vec, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&base64::encode(value)) + } +} + +pub mod vec_u8_opt_visitor { + struct VecU8Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for VecU8Visitor { + type Value = std::option::Option<::prost::alloc::vec::Vec>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid Base64 encoded string") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + base64::decode(value).map(|str| Some(str)).map_err(E::custom) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result>, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(VecU8Visitor) + } + + #[cfg(feature = "std")] + pub fn serialize(value: &std::option::Option<::prost::alloc::vec::Vec>, serializer: S) -> Result + where + S: serde::Serializer, + { + match value { + None => serializer.serialize_none(), + Some(value) => crate::vec_u8_visitor::serialize(value, serializer), + } + } +} + #[cfg(test)] mod tests { use std::time::{Duration, SystemTime, UNIX_EPOCH}; diff --git a/tests/Cargo.toml b/tests/Cargo.toml index 631d632aa..d750ddf20 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -23,6 +23,7 @@ prost-types = { path = "../prost-types" } protobuf = { path = "../protobuf" } serde = { version="1.0", features=["derive"] } serde_json = { version="1.0" } +serde_path_to_error = "0.1" [dev-dependencies] diff = "0.1" diff --git a/tests/src/lib.rs b/tests/src/lib.rs index b6dfbd33b..38e2a245b 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -216,9 +216,10 @@ where M: Message + Default + Serialize + Deserialize<'de>, { // Try to decode a message from the data. If decoding fails, continue. - let all_types: M = match serde_json::from_str(data) { + let jd = &mut serde_json::Deserializer::from_str(data); + let all_types: M = match serde_path_to_error::deserialize(jd) { Ok(all_types) => all_types, - Err(error) => return RoundtripResult::DecodeError(format!("step 1 {}", error.to_string())), + Err(error) => return RoundtripResult::DecodeError(format!("step 1 {} at {}", error.to_string(), error.path().to_string())), }; let str1 = match serde_json::to_string(&all_types) { From 6f542b64753f1c931e388540d164d8ae39599d03 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Mon, 1 Nov 2021 17:38:38 -0700 Subject: [PATCH 09/30] work --- prost-build/src/code_generator.rs | 45 +++- prost-types/src/lib.rs | 354 ++++++++++++++++++++---------- tests/src/lib.rs | 8 +- 3 files changed, 289 insertions(+), 118 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 569efc887..6d202546f 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -444,7 +444,25 @@ impl<'a> CodeGenerator<'a> { ("f64", false, false) => { push_indent(&mut self.buf, self.depth); self.buf - .push_str(r#"#[serde(with = "::prost_types::f64_visitor")]"#); + .push_str(r#"#[serde(serialize_with = "<::prost_types::f64_visitor::F64Serializer as ::prost_types::SerializeMethod>::serialize")]"#); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::f64_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("f64", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str( + r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::f64_visitor::F64Visitor>")]"#, + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(serialize_with = "::prost_types::repeated_visitor::serialize::<_, ::prost_types::f64_visitor::F64Serializer>")]"# + ); self.buf.push('\n'); } ("f64", true, false) => { @@ -456,7 +474,25 @@ impl<'a> CodeGenerator<'a> { ("f32", false, false) => { push_indent(&mut self.buf, self.depth); self.buf - .push_str(r#"#[serde(with = "::prost_types::f32_visitor")]"#); + .push_str(r#"#[serde(serialize_with = "<::prost_types::f32_visitor::F32Serializer as ::prost_types::SerializeMethod>::serialize")]"#); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::f32_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("f32", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str( + r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::f32_visitor::F32Visitor>")]"#, + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(serialize_with = "::prost_types::repeated_visitor::serialize::<_, ::prost_types::f32_visitor::F32Serializer>")]"# + ); self.buf.push('\n'); } ("f32", true, false) => { @@ -491,8 +527,9 @@ impl<'a> CodeGenerator<'a> { } (_, _, true) => { push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(deserialize_with = "::prost_types::vec_visitor::deserialize")]"#); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::vec_visitor::deserialize")]"#, + ); self.buf.push('\n'); } _ => {} diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 4b4e79c98..13b06c4bf 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -16,7 +16,6 @@ use core::i32; use core::i64; use core::time; - include!("protobuf.rs"); pub mod compiler { include!("compiler.rs"); @@ -302,14 +301,21 @@ pub trait HasConstructor { fn new() -> Self; } -pub struct MyType<'de, T: serde::de::Visitor<'de> + HasConstructor>(>::Value); +pub struct MyType<'de, T: serde::de::Visitor<'de> + HasConstructor>( + >::Value, +); -impl<'de, T> serde::Deserialize<'de> for MyType<'de, T> where T: serde::de::Visitor<'de> + HasConstructor { +impl<'de, T> serde::Deserialize<'de> for MyType<'de, T> +where + T: serde::de::Visitor<'de> + HasConstructor, +{ fn deserialize(deserializer: D) -> Result, D::Error> where D: serde::Deserializer<'de>, { - deserializer.deserialize_any(T::new()).map(|x| MyType{0: x}) + deserializer + .deserialize_any(T::new()) + .map(|x| MyType { 0: x }) } } @@ -318,7 +324,10 @@ pub fn is_default(t: &T) -> bool { } pub mod vec_visitor { - struct VecVisitor<'de, T> where T: serde::Deserialize<'de> { + struct VecVisitor<'de, T> + where + T: serde::Deserialize<'de>, + { _vec_type: &'de std::marker::PhantomData, } @@ -332,7 +341,8 @@ pub mod vec_visitor { fn visit_seq(self, mut seq: A) -> Result where - A: serde::de::SeqAccess<'de> { + A: serde::de::SeqAccess<'de>, + { let mut res = Self::Value::with_capacity(seq.size_hint().unwrap_or(0)); loop { match seq.next_element()? { @@ -341,7 +351,7 @@ pub mod vec_visitor { } } } - + fn visit_unit(self) -> Result where E: serde::de::Error, @@ -351,21 +361,29 @@ pub mod vec_visitor { } #[cfg(feature = "std")] - pub fn deserialize<'de, D, T: 'de + serde::Deserialize<'de>>(deserializer: D) -> Result, D::Error> + pub fn deserialize<'de, D, T: 'de + serde::Deserialize<'de>>( + deserializer: D, + ) -> Result, D::Error> where D: serde::Deserializer<'de>, { - deserializer.deserialize_any(VecVisitor::<'de, T>{_vec_type: &std::marker::PhantomData}) + deserializer.deserialize_any(VecVisitor::<'de, T> { + _vec_type: &std::marker::PhantomData, + }) } } pub mod repeated_visitor { - struct VecVisitor<'de, T> where T: serde::de::Visitor<'de> + crate::HasConstructor { + struct VecVisitor<'de, T> + where + T: serde::de::Visitor<'de> + crate::HasConstructor, + { _vec_type: &'de std::marker::PhantomData, } #[cfg(feature = "std")] - impl<'de, T> serde::de::Visitor<'de> for VecVisitor<'de, T> where + impl<'de, T> serde::de::Visitor<'de> for VecVisitor<'de, T> + where T: serde::de::Visitor<'de> + crate::HasConstructor, { type Value = Vec<>::Value>; @@ -376,7 +394,8 @@ pub mod repeated_visitor { fn visit_seq(self, mut seq: A) -> Result where - A: serde::de::SeqAccess<'de> { + A: serde::de::SeqAccess<'de>, + { let mut res = Self::Value::with_capacity(seq.size_hint().unwrap_or(0)); loop { let response: std::option::Option> = seq.next_element()?; @@ -386,7 +405,7 @@ pub mod repeated_visitor { } } } - + fn visit_unit(self) -> Result where E: serde::de::Error, @@ -396,22 +415,30 @@ pub mod repeated_visitor { } #[cfg(feature = "std")] - pub fn deserialize<'de, D, T: 'de + serde::de::Visitor<'de> + crate::HasConstructor>(deserializer: D) -> Result>::Value>, D::Error> + pub fn deserialize<'de, D, T: 'de + serde::de::Visitor<'de> + crate::HasConstructor>( + deserializer: D, + ) -> Result>::Value>, D::Error> where D: serde::Deserializer<'de>, { - deserializer.deserialize_any(VecVisitor::<'de, T>{_vec_type: &std::marker::PhantomData}) + deserializer.deserialize_any(VecVisitor::<'de, T> { + _vec_type: &std::marker::PhantomData, + }) } - pub fn serialize(value: Vec<::Value>, serializer: S) -> Result + pub fn serialize( + value: &Vec<::Value>, + serializer: S, + ) -> Result where S: serde::Serializer, F: crate::SerializeMethod, +// ::Value: Copy, { use serde::ser::SerializeSeq; let mut seq = serializer.serialize_seq(Some(value.len()))?; for e in value { - seq.serialize_element(&crate::MySeType::{val: e})?; + seq.serialize_element(&crate::MySeType:: { val: e })?; } seq.end() } @@ -419,30 +446,44 @@ pub mod repeated_visitor { pub trait SerializeMethod { type Value; - fn serialize(value: &Value, serializer: S) -> Result where S: serde::Serializer; + fn serialize(value: &Self::Value, serializer: S) -> Result + where + S: serde::Serializer; } -pub struct MySeType where T: SerializeMethod { - val: T::Value, +pub struct MySeType<'a, T> +where + T: SerializeMethod, +{ + val: &'a ::Value, } -impl serde::Serialize for MySeType { - fn serialize(&self, serializer: S) -> Result where S: serde::Serializer { - T::serialize(&self.val, serializer) +impl<'a, T: SerializeMethod> serde::Serialize for MySeType<'a, T> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + T::serialize(self.val, serializer) } } pub mod map_visitor { struct MapVisitor<'de, K, V> - where K: serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, - V: serde::Deserialize<'de> + where + K: serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, + V: serde::Deserialize<'de>, { _key_type: &'de std::marker::PhantomData, _value_type: &'de std::marker::PhantomData, } #[cfg(feature = "std")] - impl<'de, K: serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, V: serde::Deserialize<'de>> serde::de::Visitor<'de> for MapVisitor<'de, K, V> { + impl< + 'de, + K: serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, + V: serde::Deserialize<'de>, + > serde::de::Visitor<'de> for MapVisitor<'de, K, V> + { type Value = std::collections::HashMap; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { @@ -451,15 +492,17 @@ pub mod map_visitor { fn visit_map(self, mut map: A) -> Result where - A: serde::de::MapAccess<'de> { + A: serde::de::MapAccess<'de>, + { let mut res = Self::Value::with_capacity(map.size_hint().unwrap_or(0)); loop { match map.next_entry()? { - Some((k, v)) => {res.insert(k,v);}, + Some((k, v)) => { + res.insert(k, v); + } None => return Ok(res), } } - } fn visit_unit(self) -> Result where @@ -470,11 +513,18 @@ pub mod map_visitor { } #[cfg(feature = "std")] - pub fn deserialize<'de, D, K: 'de + serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, V: 'de + serde::Deserialize<'de>>(deserializer: D) -> Result, D::Error> + pub fn deserialize< + 'de, + D, + K: 'de + serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, + V: 'de + serde::Deserialize<'de>, + >( + deserializer: D, + ) -> Result, D::Error> where D: serde::Deserializer<'de>, { - deserializer.deserialize_any(MapVisitor::<'de, K, V>{ + deserializer.deserialize_any(MapVisitor::<'de, K, V> { _key_type: &std::marker::PhantomData, _value_type: &std::marker::PhantomData, }) @@ -483,15 +533,21 @@ pub mod map_visitor { pub mod btree_map_visitor { struct MapVisitor<'de, K, V> - where K: serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, - V: serde::Deserialize<'de> + where + K: serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, + V: serde::Deserialize<'de>, { _key_type: &'de std::marker::PhantomData, _value_type: &'de std::marker::PhantomData, } #[cfg(feature = "std")] - impl<'de, K: serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, V: serde::Deserialize<'de>> serde::de::Visitor<'de> for MapVisitor<'de, K, V> { + impl< + 'de, + K: serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, + V: serde::Deserialize<'de>, + > serde::de::Visitor<'de> for MapVisitor<'de, K, V> + { type Value = std::collections::BTreeMap; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { @@ -500,15 +556,17 @@ pub mod btree_map_visitor { fn visit_map(self, mut map: A) -> Result where - A: serde::de::MapAccess<'de> { + A: serde::de::MapAccess<'de>, + { let mut res = Self::Value::new(); loop { match map.next_entry()? { - Some((k, v)) => {res.insert(k,v);}, + Some((k, v)) => { + res.insert(k, v); + } None => return Ok(res), } } - } fn visit_unit(self) -> Result where @@ -519,11 +577,18 @@ pub mod btree_map_visitor { } #[cfg(feature = "std")] - pub fn deserialize<'de, D, K: 'de + serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, V: 'de + serde::Deserialize<'de>>(deserializer: D) -> Result, D::Error> + pub fn deserialize< + 'de, + D, + K: 'de + serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, + V: 'de + serde::Deserialize<'de>, + >( + deserializer: D, + ) -> Result, D::Error> where D: serde::Deserializer<'de>, { - deserializer.deserialize_any(MapVisitor::<'de, K, V>{ + deserializer.deserialize_any(MapVisitor::<'de, K, V> { _key_type: &std::marker::PhantomData, _value_type: &std::marker::PhantomData, }) @@ -545,10 +610,9 @@ pub mod string_visitor { where E: serde::de::Error, { - return Ok(value.to_string()) + return Ok(value.to_string()); } - fn visit_unit(self) -> Result where E: serde::de::Error, @@ -581,10 +645,9 @@ pub mod string_opt_visitor { where E: serde::de::Error, { - return Ok(Some(value.to_string())) + return Ok(Some(value.to_string())); } - fn visit_unit(self) -> Result where E: serde::de::Error, @@ -598,11 +661,12 @@ pub mod string_opt_visitor { { Ok(None) } - } #[cfg(feature = "std")] - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + pub fn deserialize<'de, D>( + deserializer: D, + ) -> Result, D::Error> where D: serde::Deserializer<'de>, { @@ -610,7 +674,6 @@ pub mod string_opt_visitor { } } - pub mod bool_visitor { struct BoolVisitor; @@ -626,10 +689,9 @@ pub mod bool_visitor { where E: serde::de::Error, { - return Ok(value) + return Ok(value); } - fn visit_unit(self) -> Result where E: serde::de::Error, @@ -661,10 +723,9 @@ pub mod bool_opt_visitor { where E: serde::de::Error, { - return Ok(Some(value)) + return Ok(Some(value)); } - fn visit_unit(self) -> Result where E: serde::de::Error, @@ -678,7 +739,6 @@ pub mod bool_opt_visitor { { Ok(None) } - } #[cfg(feature = "std")] @@ -695,10 +755,10 @@ pub mod i32_visitor { impl crate::HasConstructor for I32Visitor { fn new() -> I32Visitor { - return I32Visitor{}; + return I32Visitor {}; } } - + #[cfg(feature = "std")] impl<'de> serde::de::Visitor<'de> for I32Visitor { type Value = i32; @@ -719,8 +779,10 @@ pub mod i32_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON || - value > i32::MAX as f64 || value < i32::MIN as f64 { + if (value.trunc() - value).abs() > f64::EPSILON + || value > i32::MAX as f64 + || value < i32::MIN as f64 + { Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, @@ -745,7 +807,10 @@ pub mod i32_visitor { { // If we have scientific notation or a decimal, parse float first. if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value.parse::().map_err(E::custom).and_then(|x| self.visit_f64(x)) + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) } else { value.parse::().map_err(E::custom) } @@ -790,8 +855,10 @@ pub mod i32_opt_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON || - value > i32::MAX as f64 || value < i32::MIN as f64 { + if (value.trunc() - value).abs() > f64::EPSILON + || value > i32::MAX as f64 + || value < i32::MIN as f64 + { Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, @@ -816,7 +883,10 @@ pub mod i32_opt_visitor { { // If we have scientific notation or a decimal, parse float first. if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value.parse::().map_err(E::custom).and_then(|x| self.visit_f64(x)) + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) } else { value.parse::().map(|x| Some(x)).map_err(E::custom) } @@ -868,8 +938,10 @@ pub mod i64_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON || - value > i64::MAX as f64 || value < i64::MIN as f64 { + if (value.trunc() - value).abs() > f64::EPSILON + || value > i64::MAX as f64 + || value < i64::MIN as f64 + { Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, @@ -894,7 +966,10 @@ pub mod i64_visitor { { // If we have scientific notation or a decimal, parse float first. if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value.parse::().map_err(E::custom).and_then(|x| self.visit_f64(x)) + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) } else { value.parse::().map_err(E::custom) } @@ -906,7 +981,6 @@ pub mod i64_visitor { { Ok(i64::default()) } - } pub fn deserialize<'de, D>(deserializer: D) -> Result @@ -939,8 +1013,10 @@ pub mod i64_opt_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON || - value > i64::MAX as f64 || value < i64::MIN as f64 { + if (value.trunc() - value).abs() > f64::EPSILON + || value > i64::MAX as f64 + || value < i64::MIN as f64 + { Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, @@ -965,7 +1041,10 @@ pub mod i64_opt_visitor { { // If we have scientific notation or a decimal, parse float first. if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value.parse::().map_err(E::custom).and_then(|x| self.visit_f64(x)) + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) } else { value.parse::().map(|x| Some(x)).map_err(E::custom) } @@ -1018,8 +1097,10 @@ pub mod u32_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON || - value < 0.0 || value > u32::MAX as f64 { + if (value.trunc() - value).abs() > f64::EPSILON + || value < 0.0 + || value > u32::MAX as f64 + { Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, @@ -1044,7 +1125,10 @@ pub mod u32_visitor { { // If we have scientific notation or a decimal, parse float first. if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value.parse::().map_err(E::custom).and_then(|x| self.visit_f64(x)) + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) } else { value.parse::().map_err(E::custom) } @@ -1089,8 +1173,10 @@ pub mod u32_opt_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON || - value < 0.0 || value > u32::MAX as f64 { + if (value.trunc() - value).abs() > f64::EPSILON + || value < 0.0 + || value > u32::MAX as f64 + { Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, @@ -1137,7 +1223,6 @@ pub mod u32_opt_visitor { { Ok(None) } - } #[cfg(feature = "std")] @@ -1171,8 +1256,10 @@ pub mod u64_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON || - value < 0.0 || value > u64::MAX as f64 { + if (value.trunc() - value).abs() > f64::EPSILON + || value < 0.0 + || value > u64::MAX as f64 + { Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, @@ -1189,7 +1276,10 @@ pub mod u64_visitor { { // If we have scientific notation or a decimal, parse float first. if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value.parse::().map_err(E::custom).and_then(|x| self.visit_f64(x)) + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) } else { value.parse::().map_err(E::custom) } @@ -1233,9 +1323,11 @@ pub mod u64_opt_visitor { where E: serde::de::Error, { - if (value.trunc() - value).abs() > f64::EPSILON || - value < 0.0 || value > u64::MAX as f64 { - Err(serde::de::Error::invalid_type( + if (value.trunc() - value).abs() > f64::EPSILON + || value < 0.0 + || value > u64::MAX as f64 + { + Err(serde::de::Error::invalid_type( serde::de::Unexpected::Float(value), &self, )) @@ -1253,7 +1345,8 @@ pub mod u64_opt_visitor { if value.contains('e') || value.contains('E') || value.ends_with(".0") { value .parse::() - .map_err(E::custom).and_then(|x| self.visit_f64(x)) + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) } else { value.parse::().map(|x| Some(x)).map_err(E::custom) } @@ -1284,8 +1377,14 @@ pub mod u64_opt_visitor { } pub mod f64_visitor { - struct F64Visitor; + pub struct F64Visitor; + impl crate::HasConstructor for F64Visitor { + fn new() -> F64Visitor { + return F64Visitor {}; + } + } +, #[cfg(feature = "std")] impl<'de> serde::de::Visitor<'de> for F64Visitor { type Value = f64; @@ -1342,19 +1441,24 @@ pub mod f64_visitor { deserializer.deserialize_any(F64Visitor) } - #[cfg(feature = "std")] - pub fn serialize(value: &f64, serializer: S) -> Result - where - S: serde::Serializer, - { - if value.is_nan() { - serializer.serialize_str("NaN") - } else if value.is_infinite() && value.is_sign_negative() { - serializer.serialize_str("-Infinity") - } else if value.is_infinite() { - serializer.serialize_str("Infinity") - } else { - serializer.serialize_f64(*value) + pub struct F64Serializer; + + impl crate::SerializeMethod for F64Serializer { + type Value = f64; + #[cfg(feature = "std")] + fn serialize(value: &Self::Value, serializer: S) -> Result + where + S: serde::Serializer, + { + if value.is_nan() { + serializer.serialize_str("NaN") + } else if value.is_infinite() && value.is_sign_negative() { + serializer.serialize_str("-Infinity") + } else if value.is_infinite() { + serializer.serialize_str("Infinity") + } else { + serializer.serialize_f64(*value) + } } } } @@ -1431,15 +1535,22 @@ pub mod f64_opt_visitor { where S: serde::Serializer, { + use crate::SerializeMethod; match value { None => serializer.serialize_none(), - Some(double) => crate::f64_visitor::serialize(double, serializer), + Some(double) => crate::f64_visitor::F64Serializer::serialize(double, serializer), } } } pub mod f32_visitor { - struct F32Visitor; + pub struct F32Visitor; + + impl crate::HasConstructor for F32Visitor { + fn new() -> F32Visitor { + return F32Visitor {}; + } + } #[cfg(feature = "std")] impl<'de> serde::de::Visitor<'de> for F32Visitor { @@ -1503,19 +1614,25 @@ pub mod f32_visitor { deserializer.deserialize_any(F32Visitor) } - #[cfg(feature = "std")] - pub fn serialize(value: &f32, serializer: S) -> Result - where - S: serde::Serializer, - { - if value.is_nan() { - serializer.serialize_str("NaN") - } else if value.is_infinite() && value.is_sign_negative() { - serializer.serialize_str("-Infinity") - } else if value.is_infinite() { - serializer.serialize_str("Infinity") - } else { - serializer.serialize_f32(*value) + pub struct F32Serializer; + + impl crate::SerializeMethod for F32Serializer { + type Value = f32; + + #[cfg(feature = "std")] + fn serialize(value: &f32, serializer: S) -> Result + where + S: serde::Serializer, + { + if value.is_nan() { + serializer.serialize_str("NaN") + } else if value.is_infinite() && value.is_sign_negative() { + serializer.serialize_str("-Infinity") + } else if value.is_infinite() { + serializer.serialize_str("Infinity") + } else { + serializer.serialize_f32(*value) + } } } } @@ -1599,9 +1716,10 @@ pub mod f32_opt_visitor { where S: serde::Serializer, { + use crate::SerializeMethod; match value { None => serializer.serialize_none(), - Some(float) => crate::f32_visitor::serialize(float, serializer), + Some(float) => crate::f32_visitor::F32Serializer::serialize(float, serializer), } } } @@ -1641,7 +1759,10 @@ pub mod vec_u8_visitor { } #[cfg(feature = "std")] - pub fn serialize(value: &::prost::alloc::vec::Vec, serializer: S) -> Result + pub fn serialize( + value: &::prost::alloc::vec::Vec, + serializer: S, + ) -> Result where S: serde::Serializer, { @@ -1664,7 +1785,9 @@ pub mod vec_u8_opt_visitor { where E: serde::de::Error, { - base64::decode(value).map(|str| Some(str)).map_err(E::custom) + base64::decode(value) + .map(|str| Some(str)) + .map_err(E::custom) } fn visit_unit(self) -> Result @@ -1683,7 +1806,9 @@ pub mod vec_u8_opt_visitor { } #[cfg(feature = "std")] - pub fn deserialize<'de, D>(deserializer: D) -> Result>, D::Error> + pub fn deserialize<'de, D>( + deserializer: D, + ) -> Result>, D::Error> where D: serde::Deserializer<'de>, { @@ -1691,7 +1816,10 @@ pub mod vec_u8_opt_visitor { } #[cfg(feature = "std")] - pub fn serialize(value: &std::option::Option<::prost::alloc::vec::Vec>, serializer: S) -> Result + pub fn serialize( + value: &std::option::Option<::prost::alloc::vec::Vec>, + serializer: S, + ) -> Result where S: serde::Serializer, { diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 38e2a245b..0c5adab2c 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -219,7 +219,13 @@ where let jd = &mut serde_json::Deserializer::from_str(data); let all_types: M = match serde_path_to_error::deserialize(jd) { Ok(all_types) => all_types, - Err(error) => return RoundtripResult::DecodeError(format!("step 1 {} at {}", error.to_string(), error.path().to_string())), + Err(error) => { + return RoundtripResult::DecodeError(format!( + "step 1 {} at {}", + error.to_string(), + error.path().to_string() + )) + } }; let str1 = match serde_json::to_string(&all_types) { From 146906289d028edab7daade493535b21edd06811 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Mon, 1 Nov 2021 23:24:10 -0700 Subject: [PATCH 10/30] final check in before putting up for 1 level of review --- conformance/Cargo.toml | 2 + conformance/src/main.rs | 82 +++++++++++++++++++++++++++++-- prost-build/src/code_generator.rs | 48 +++++++++++++++++- prost-types/src/lib.rs | 68 +++++++++++++++++++------ tests/src/lib.rs | 29 ++--------- 5 files changed, 181 insertions(+), 48 deletions(-) diff --git a/conformance/Cargo.toml b/conformance/Cargo.toml index 83249f173..aabbd14c4 100644 --- a/conformance/Cargo.toml +++ b/conformance/Cargo.toml @@ -15,3 +15,5 @@ log = "0.4" prost = { path = ".." } protobuf = { path = "../protobuf" } tests = { path = "../tests" } +serde_json = { version="1.0" } +serde_path_to_error = "0.1" diff --git a/conformance/src/main.rs b/conformance/src/main.rs index 7e435886d..73dd8d402 100644 --- a/conformance/src/main.rs +++ b/conformance/src/main.rs @@ -100,7 +100,45 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { RoundtripResult::Error(error) => conformance_response::Result::RuntimeError(error), }; } - // TODO(konradjniemiec) support proto -> json and json -> proto conformance + if let Some(conformance_request::Payload::ProtobufPayload(buf)) = request.payload { + // proto -> json + return match &*request.message_type { + "protobuf_test_messages.proto2.TestAllTypesProto2" => { + let m = match TestAllTypesProto2::decode(&*buf) { + Ok(m) => m, + Err(error) => { + return conformance_response::Result::ParseError(error.to_string()) + }, + }; + match serde_json::to_string(&m) { + Ok(str) => conformance_response::Result::JsonPayload(str), + Err(error) => { + return conformance_response::Result::ParseError(error.to_string()) + } + } + } + "protobuf_test_messages.proto3.TestAllTypesProto3" => { + let m = match TestAllTypesProto3::decode(&*buf) { + Ok(m) => m, + Err(error) => { + return conformance_response::Result::ParseError(error.to_string()) + }, + }; + match serde_json::to_string(&m) { + Ok(str) => conformance_response::Result::JsonPayload(str), + Err(error) => { + return conformance_response::Result::ParseError(error.to_string()) + } + } + } + _ => { + return conformance_response::Result::ParseError(format!( + "unknown message type: {}", + request.message_type + )); + } + } + } return conformance_response::Result::Skipped( "only json <-> json is supported".to_string(), ); @@ -108,10 +146,44 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { let buf = match request.payload { None => return conformance_response::Result::ParseError("no payload".to_string()), - Some(conformance_request::Payload::JsonPayload(_)) => { - return conformance_response::Result::Skipped( - "JSON input is not supported".to_string(), - ); + Some(conformance_request::Payload::JsonPayload(str)) => { + // json -> proto + match &*request.message_type { + "protobuf_test_messages.proto2.TestAllTypesProto2" => { + let jd = &mut serde_json::Deserializer::from_str(&str); + let all_types: TestAllTypesProto2 = match serde_path_to_error::deserialize(jd) { + Ok(all_types) => all_types, + Err(error) => { + return conformance_response::Result::ParseError(format!( + "error deserializing json: {} at {}", + error.to_string(), + error.path().to_string() + )) + } + }; + return conformance_response::Result::ProtobufPayload(all_types.encode_to_vec()) + } + "protobuf_test_messages.proto3.TestAllTypesProto3" => { + let jd = &mut serde_json::Deserializer::from_str(&str); + let all_types: TestAllTypesProto3 = match serde_path_to_error::deserialize(jd) { + Ok(all_types) => all_types, + Err(error) => { + return conformance_response::Result::ParseError(format!( + "error deserializing json: {} at {}", + error.to_string(), + error.path().to_string() + )) + } + }; + return conformance_response::Result::ProtobufPayload(all_types.encode_to_vec()) + } + _ => { + return conformance_response::Result::ParseError(format!( + "unknown message type: {}", + request.message_type + )); + } + } } Some(conformance_request::Payload::JspbPayload(_)) => { return conformance_response::Result::Skipped( diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 6d202546f..38f384b3d 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -402,6 +402,13 @@ impl<'a> CodeGenerator<'a> { .push_str(r#"#[serde(deserialize_with = "::prost_types::bool_opt_visitor::deserialize")]"#); self.buf.push('\n'); } + ("bool", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::bool_visitor::BoolVisitor>")]"#, + ); + self.buf.push('\n'); + } ("i64", false, false) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( @@ -415,6 +422,13 @@ impl<'a> CodeGenerator<'a> { .push_str(r#"#[serde(deserialize_with = "::prost_types::i64_opt_visitor::deserialize")]"#); self.buf.push('\n'); } + ("i64", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::i64_visitor::I64Visitor>")]"#, + ); + self.buf.push('\n'); + } ("u32", false, false) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( @@ -428,6 +442,13 @@ impl<'a> CodeGenerator<'a> { .push_str(r#"#[serde(deserialize_with = "::prost_types::u32_opt_visitor::deserialize")]"#); self.buf.push('\n'); } + ("u32", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::u32_visitor::U32Visitor>")]"#, + ); + self.buf.push('\n'); + } ("u64", false, false) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( @@ -441,6 +462,13 @@ impl<'a> CodeGenerator<'a> { .push_str(r#"#[serde(deserialize_with = "::prost_types::u64_opt_visitor::deserialize")]"#); self.buf.push('\n'); } + ("u64", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::u64_visitor::U64Visitor>")]"#, + ); + self.buf.push('\n'); + } ("f64", false, false) => { push_indent(&mut self.buf, self.depth); self.buf @@ -516,7 +544,12 @@ impl<'a> CodeGenerator<'a> { ("::prost::alloc::vec::Vec", false, false) => { push_indent(&mut self.buf, self.depth); self.buf - .push_str(r#"#[serde(with = "::prost_types::vec_u8_visitor")]"#); + .push_str(r#"#[serde(serialize_with = "<::prost_types::vec_u8_visitor::VecU8Serializer as ::prost_types::SerializeMethod>::serialize")]"#); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(deserialize_with = "::prost_types::vec_u8_visitor::deserialize")]"# + ); self.buf.push('\n'); } ("::prost::alloc::vec::Vec", true, false) => { @@ -525,6 +558,19 @@ impl<'a> CodeGenerator<'a> { .push_str(r#"#[serde(with = "::prost_types::vec_u8_opt_visitor")]"#); self.buf.push('\n'); } + ("::prost::alloc::vec::Vec", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str( + r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::vec_u8_visitor::VecU8Visitor>")]"#, + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(serialize_with = "::prost_types::repeated_visitor::serialize::<_, ::prost_types::vec_u8_visitor::VecU8Serializer>")]"# + ); + self.buf.push('\n'); + } (_, _, true) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 13b06c4bf..50d3cd696 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -675,7 +675,13 @@ pub mod string_opt_visitor { } pub mod bool_visitor { - struct BoolVisitor; + pub struct BoolVisitor; + + impl crate::HasConstructor for BoolVisitor { + fn new() -> Self { + return Self{}; + } + } #[cfg(feature = "std")] impl<'de> serde::de::Visitor<'de> for BoolVisitor { @@ -917,7 +923,13 @@ pub mod i32_opt_visitor { } pub mod i64_visitor { - struct I64Visitor; + pub struct I64Visitor; + + impl crate::HasConstructor for I64Visitor { + fn new() -> Self { + return Self {}; + } + } #[cfg(feature = "std")] impl<'de> serde::de::Visitor<'de> for I64Visitor { @@ -1075,7 +1087,14 @@ pub mod i64_opt_visitor { } pub mod u32_visitor { - struct U32Visitor; + pub struct U32Visitor; + + impl crate::HasConstructor for U32Visitor { + fn new() -> Self { + return Self {}; + } + } + #[cfg(feature = "std")] impl<'de> serde::de::Visitor<'de> for U32Visitor { @@ -1235,7 +1254,13 @@ pub mod u32_opt_visitor { } pub mod u64_visitor { - struct U64Visitor; + pub struct U64Visitor; + + impl crate::HasConstructor for U64Visitor { + fn new() -> Self { + return Self {}; + } + } #[cfg(feature = "std")] impl<'de> serde::de::Visitor<'de> for U64Visitor { @@ -1384,7 +1409,7 @@ pub mod f64_visitor { return F64Visitor {}; } } -, + #[cfg(feature = "std")] impl<'de> serde::de::Visitor<'de> for F64Visitor { type Value = f64; @@ -1725,7 +1750,13 @@ pub mod f32_opt_visitor { } pub mod vec_u8_visitor { - struct VecU8Visitor; + pub struct VecU8Visitor; + + impl crate::HasConstructor for VecU8Visitor { + fn new() -> Self { + return Self {}; + } + } #[cfg(feature = "std")] impl<'de> serde::de::Visitor<'de> for VecU8Visitor { @@ -1758,15 +1789,19 @@ pub mod vec_u8_visitor { deserializer.deserialize_any(VecU8Visitor) } - #[cfg(feature = "std")] - pub fn serialize( - value: &::prost::alloc::vec::Vec, - serializer: S, - ) -> Result - where - S: serde::Serializer, - { - serializer.serialize_str(&base64::encode(value)) + + pub struct VecU8Serializer; + + impl crate::SerializeMethod for VecU8Serializer { + type Value = ::prost::alloc::vec::Vec; + + #[cfg(feature = "std")] + fn serialize(value: &Self::Value, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&base64::encode(value)) + } } } @@ -1823,9 +1858,10 @@ pub mod vec_u8_opt_visitor { where S: serde::Serializer, { + use crate::SerializeMethod; match value { None => serializer.serialize_none(), - Some(value) => crate::vec_u8_visitor::serialize(value, serializer), + Some(value) => crate::vec_u8_visitor::VecU8Serializer::serialize(value, serializer), } } } diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 0c5adab2c..a21acab4a 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -215,13 +215,13 @@ pub fn roundtrip_json<'de, M>(data: &'de str) -> RoundtripResult where M: Message + Default + Serialize + Deserialize<'de>, { - // Try to decode a message from the data. If decoding fails, continue. + let jd = &mut serde_json::Deserializer::from_str(data); let all_types: M = match serde_path_to_error::deserialize(jd) { Ok(all_types) => all_types, Err(error) => { return RoundtripResult::DecodeError(format!( - "step 1 {} at {}", + "error deserializing json: {} at {}", error.to_string(), error.path().to_string() )) @@ -230,32 +230,9 @@ where let str1 = match serde_json::to_string(&all_types) { Ok(str) => str, - Err(error) => return RoundtripResult::Error(format!("step 2 {}", error.to_string())), - }; - - /* if str1 != data { - return RoundtripResult::Error(format!( - "halftripped JSON encoded strings do not match\nstring: {}\noriginal provided data: {}", - str1, data - )); - }*/ - - let roundtrip: M = match serde_json::from_str::(data) { - Ok(roundtrip) => roundtrip, - Err(error) => return RoundtripResult::Error(format!("step 3 {}", error.to_string())), - }; - - let str2 = match serde_json::to_string(&roundtrip) { - Ok(str) => str, - Err(error) => return RoundtripResult::Error(format!("step 4 {}", error.to_string())), + Err(error) => return RoundtripResult::Error(format!("error encoding json {}", error.to_string())), }; - if str1 != str2 { - return RoundtripResult::Error(format!( - "roundtripped JSON encoded strings do not match {} {}", - str1, str2 - )); - } RoundtripResult::Ok(str1.into_bytes()) } From 06c0983201d87b34e5c4da655e1e1c51da0110fe Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Mon, 1 Nov 2021 23:29:53 -0700 Subject: [PATCH 11/30] fmt --- conformance/src/main.rs | 10 +++++----- prost-types/src/lib.rs | 6 ++---- tests/src/lib.rs | 5 +++-- 3 files changed, 10 insertions(+), 11 deletions(-) diff --git a/conformance/src/main.rs b/conformance/src/main.rs index 73dd8d402..553adaf0a 100644 --- a/conformance/src/main.rs +++ b/conformance/src/main.rs @@ -108,7 +108,7 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { Ok(m) => m, Err(error) => { return conformance_response::Result::ParseError(error.to_string()) - }, + } }; match serde_json::to_string(&m) { Ok(str) => conformance_response::Result::JsonPayload(str), @@ -122,7 +122,7 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { Ok(m) => m, Err(error) => { return conformance_response::Result::ParseError(error.to_string()) - }, + } }; match serde_json::to_string(&m) { Ok(str) => conformance_response::Result::JsonPayload(str), @@ -137,7 +137,7 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { request.message_type )); } - } + }; } return conformance_response::Result::Skipped( "only json <-> json is supported".to_string(), @@ -161,7 +161,7 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { )) } }; - return conformance_response::Result::ProtobufPayload(all_types.encode_to_vec()) + return conformance_response::Result::ProtobufPayload(all_types.encode_to_vec()); } "protobuf_test_messages.proto3.TestAllTypesProto3" => { let jd = &mut serde_json::Deserializer::from_str(&str); @@ -175,7 +175,7 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { )) } }; - return conformance_response::Result::ProtobufPayload(all_types.encode_to_vec()) + return conformance_response::Result::ProtobufPayload(all_types.encode_to_vec()); } _ => { return conformance_response::Result::ParseError(format!( diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 50d3cd696..c6a7d5b43 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -433,7 +433,7 @@ pub mod repeated_visitor { where S: serde::Serializer, F: crate::SerializeMethod, -// ::Value: Copy, + // ::Value: Copy, { use serde::ser::SerializeSeq; let mut seq = serializer.serialize_seq(Some(value.len()))?; @@ -679,7 +679,7 @@ pub mod bool_visitor { impl crate::HasConstructor for BoolVisitor { fn new() -> Self { - return Self{}; + return Self {}; } } @@ -1095,7 +1095,6 @@ pub mod u32_visitor { } } - #[cfg(feature = "std")] impl<'de> serde::de::Visitor<'de> for U32Visitor { type Value = u32; @@ -1789,7 +1788,6 @@ pub mod vec_u8_visitor { deserializer.deserialize_any(VecU8Visitor) } - pub struct VecU8Serializer; impl crate::SerializeMethod for VecU8Serializer { diff --git a/tests/src/lib.rs b/tests/src/lib.rs index a21acab4a..4718f4c2b 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -215,7 +215,6 @@ pub fn roundtrip_json<'de, M>(data: &'de str) -> RoundtripResult where M: Message + Default + Serialize + Deserialize<'de>, { - let jd = &mut serde_json::Deserializer::from_str(data); let all_types: M = match serde_path_to_error::deserialize(jd) { Ok(all_types) => all_types, @@ -230,7 +229,9 @@ where let str1 = match serde_json::to_string(&all_types) { Ok(str) => str, - Err(error) => return RoundtripResult::Error(format!("error encoding json {}", error.to_string())), + Err(error) => { + return RoundtripResult::Error(format!("error encoding json {}", error.to_string())) + } }; RoundtripResult::Ok(str1.into_bytes()) From a002a23f8d586c4f43b7f3fc95d7428f901fc788 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Mon, 1 Nov 2021 23:33:35 -0700 Subject: [PATCH 12/30] remove comment + add failed tests --- conformance/failing_tests.txt | 128 +++++++++++++++++++++++++++++++++- tests/src/lib.rs | 10 --- 2 files changed, 126 insertions(+), 12 deletions(-) diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index 1bb832783..2d2acf2d4 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -1,2 +1,126 @@ -Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput -Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput +Recommended.FieldMaskNumbersDontRoundTrip.JsonOutput +Recommended.FieldMaskPathsDontRoundTrip.JsonOutput +Recommended.FieldMaskTooManyUnderscore.JsonOutput +Recommended.Proto2.JsonInput.FieldNameExtension.Validator +Recommended.Proto3.JsonInput.BytesFieldBase64Url.JsonOutput +Recommended.Proto3.JsonInput.BytesFieldBase64Url.ProtobufOutput +Recommended.Proto3.JsonInput.DurationHas3FractionalDigits.Validator +Recommended.Proto3.JsonInput.DurationHas6FractionalDigits.Validator +Recommended.Proto3.JsonInput.DurationHas9FractionalDigits.Validator +Recommended.Proto3.JsonInput.DurationHasZeroFractionalDigit.Validator +Recommended.Proto3.JsonInput.Int64FieldBeString.Validator +Recommended.Proto3.JsonInput.NullValueInOtherOneofNewFormat.Validator +Recommended.Proto3.JsonInput.NullValueInOtherOneofOldFormat.Validator +Recommended.Proto3.JsonInput.OneofZeroBytes.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroBytes.ProtobufOutput +Recommended.Proto3.JsonInput.OneofZeroEnum.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroEnum.ProtobufOutput +Recommended.Proto3.JsonInput.RepeatedFieldPrimitiveElementIsNull +Recommended.Proto3.JsonInput.TimestampHas3FractionalDigits.Validator +Recommended.Proto3.JsonInput.TimestampHas6FractionalDigits.Validator +Recommended.Proto3.JsonInput.TimestampZeroNormalized.Validator +Recommended.Proto3.JsonInput.Uint64FieldBeString.Validator +Recommended.Proto3.ProtobufInput.OneofZeroBytes.JsonOutput +Required.DurationProtoInputTooLarge.JsonOutput +Required.DurationProtoInputTooSmall.JsonOutput +Required.Proto3.JsonInput.Any.JsonOutput +Required.Proto3.JsonInput.Any.ProtobufOutput +Required.Proto3.JsonInput.AnyNested.JsonOutput +Required.Proto3.JsonInput.AnyNested.ProtobufOutput +Required.Proto3.JsonInput.AnyUnorderedTypeTag.JsonOutput +Required.Proto3.JsonInput.AnyUnorderedTypeTag.ProtobufOutput +Required.Proto3.JsonInput.AnyWithDuration.JsonOutput +Required.Proto3.JsonInput.AnyWithDuration.ProtobufOutput +Required.Proto3.JsonInput.AnyWithFieldMask.JsonOutput +Required.Proto3.JsonInput.AnyWithFieldMask.ProtobufOutput +Required.Proto3.JsonInput.AnyWithInt32ValueWrapper.JsonOutput +Required.Proto3.JsonInput.AnyWithInt32ValueWrapper.ProtobufOutput +Required.Proto3.JsonInput.AnyWithStruct.JsonOutput +Required.Proto3.JsonInput.AnyWithStruct.ProtobufOutput +Required.Proto3.JsonInput.AnyWithTimestamp.JsonOutput +Required.Proto3.JsonInput.AnyWithTimestamp.ProtobufOutput +Required.Proto3.JsonInput.AnyWithValueForInteger.JsonOutput +Required.Proto3.JsonInput.AnyWithValueForInteger.ProtobufOutput +Required.Proto3.JsonInput.AnyWithValueForJsonObject.JsonOutput +Required.Proto3.JsonInput.AnyWithValueForJsonObject.ProtobufOutput +Required.Proto3.JsonInput.BoolMapEscapedKey.JsonOutput +Required.Proto3.JsonInput.BoolMapEscapedKey.ProtobufOutput +Required.Proto3.JsonInput.BoolMapField.JsonOutput +Required.Proto3.JsonInput.BoolMapField.ProtobufOutput +Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.JsonOutput +Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.ProtobufOutput +Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.JsonOutput +Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.ProtobufOutput +Required.Proto3.JsonInput.DurationMaxValue.JsonOutput +Required.Proto3.JsonInput.DurationMaxValue.ProtobufOutput +Required.Proto3.JsonInput.DurationMinValue.JsonOutput +Required.Proto3.JsonInput.DurationMinValue.ProtobufOutput +Required.Proto3.JsonInput.DurationRepeatedValue.JsonOutput +Required.Proto3.JsonInput.DurationRepeatedValue.ProtobufOutput +Required.Proto3.JsonInput.EmptyFieldMask.JsonOutput +Required.Proto3.JsonInput.EmptyFieldMask.ProtobufOutput +Required.Proto3.JsonInput.EnumField.JsonOutput +Required.Proto3.JsonInput.EnumField.ProtobufOutput +Required.Proto3.JsonInput.EnumFieldWithAlias.JsonOutput +Required.Proto3.JsonInput.EnumFieldWithAlias.ProtobufOutput +Required.Proto3.JsonInput.EnumFieldWithAliasDifferentCase.JsonOutput +Required.Proto3.JsonInput.EnumFieldWithAliasDifferentCase.ProtobufOutput +Required.Proto3.JsonInput.EnumFieldWithAliasLowerCase.JsonOutput +Required.Proto3.JsonInput.EnumFieldWithAliasLowerCase.ProtobufOutput +Required.Proto3.JsonInput.EnumFieldWithAliasUseAlias.JsonOutput +Required.Proto3.JsonInput.EnumFieldWithAliasUseAlias.ProtobufOutput +Required.Proto3.JsonInput.EnumRepeatedField.JsonOutput +Required.Proto3.JsonInput.EnumRepeatedField.ProtobufOutput +Required.Proto3.JsonInput.FieldMask.JsonOutput +Required.Proto3.JsonInput.FieldMask.ProtobufOutput +Required.Proto3.JsonInput.OneofFieldDuplicate +Required.Proto3.JsonInput.RepeatedListValue.JsonOutput +Required.Proto3.JsonInput.RepeatedListValue.ProtobufOutput +Required.Proto3.JsonInput.RepeatedValue.JsonOutput +Required.Proto3.JsonInput.RepeatedValue.ProtobufOutput +Required.Proto3.JsonInput.Struct.JsonOutput +Required.Proto3.JsonInput.Struct.ProtobufOutput +Required.Proto3.JsonInput.StructWithEmptyListValue.JsonOutput +Required.Proto3.JsonInput.StructWithEmptyListValue.ProtobufOutput +Required.Proto3.JsonInput.TimestampMinValue.JsonOutput +Required.Proto3.JsonInput.TimestampMinValue.ProtobufOutput +Required.Proto3.JsonInput.TimestampRepeatedValue.JsonOutput +Required.Proto3.JsonInput.TimestampRepeatedValue.ProtobufOutput +Required.Proto3.JsonInput.TimestampWithNegativeOffset.JsonOutput +Required.Proto3.JsonInput.TimestampWithNegativeOffset.ProtobufOutput +Required.Proto3.JsonInput.TimestampWithPositiveOffset.JsonOutput +Required.Proto3.JsonInput.TimestampWithPositiveOffset.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptBool.JsonOutput +Required.Proto3.JsonInput.ValueAcceptBool.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptFloat.JsonOutput +Required.Proto3.JsonInput.ValueAcceptFloat.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptInteger.JsonOutput +Required.Proto3.JsonInput.ValueAcceptInteger.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptList.JsonOutput +Required.Proto3.JsonInput.ValueAcceptList.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptNull.JsonOutput +Required.Proto3.JsonInput.ValueAcceptNull.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptObject.JsonOutput +Required.Proto3.JsonInput.ValueAcceptObject.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptString.JsonOutput +Required.Proto3.JsonInput.ValueAcceptString.ProtobufOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.Default.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateKey.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateKeyInMapEntry.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateValueInMapEntry.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.MissingDefault.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.NonDefault.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.Unordered.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.Default.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateKey.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateKeyInMapEntry.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateValueInMapEntry.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.MissingDefault.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.NonDefault.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.Unordered.JsonOutput +Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.DefaultValue.JsonOutput +Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.MultipleValuesForDifferentField.JsonOutput +Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.MultipleValuesForSameField.JsonOutput +Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.NonDefaultValue.JsonOutput +Required.TimestampProtoInputTooLarge.JsonOutput +Required.TimestampProtoInputTooSmall.JsonOutput diff --git a/tests/src/lib.rs b/tests/src/lib.rs index 4718f4c2b..c15f1a751 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -136,16 +136,6 @@ impl RoundtripResult { RoundtripResult::Error(error) => panic!("failed roundtrip: {}", error), } } - /* - /// Unwrap the roundtrip result. Panics if the result was a validation or re-encoding error. - pub fn unwrap_error(self) -> Result, prost::DecodeError> { - match self { - RoundtripResult::Ok(buf) => Ok(buf), - RoundtripResult::DecodeError(error) => Err(DecodeError(error.to_string())), - RoundtripResult::Error(error) => panic!("failed roundtrip: {}", error), - } - } - */ } /// Tests round-tripping a message type. The message should be compiled with `BTreeMap` fields, From 7ac0d015e3287927c7e8e573a3b40291303bf542 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Wed, 3 Nov 2021 11:54:48 -0700 Subject: [PATCH 13/30] almost ready for review --- conformance/failing_tests.txt | 2 + conformance/src/main.rs | 159 ++---------- prost-build/src/code_generator.rs | 412 +++++++++++++++--------------- prost-build/src/lib.rs | 12 +- tests/src/lib.rs | 102 +++++--- 5 files changed, 305 insertions(+), 382 deletions(-) diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index 2d2acf2d4..7cf33f7d2 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -1,3 +1,5 @@ +Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput +Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput Recommended.FieldMaskNumbersDontRoundTrip.JsonOutput Recommended.FieldMaskPathsDontRoundTrip.JsonOutput Recommended.FieldMaskTooManyUnderscore.JsonOutput diff --git a/conformance/src/main.rs b/conformance/src/main.rs index 553adaf0a..3aafa154e 100644 --- a/conformance/src/main.rs +++ b/conformance/src/main.rs @@ -8,7 +8,7 @@ use protobuf::conformance::{ }; use protobuf::test_messages::proto2::TestAllTypesProto2; use protobuf::test_messages::proto3::TestAllTypesProto3; -use tests::{roundtrip, roundtrip_json, RoundtripResult}; +use tests::{roundtrip, RoundtripResult}; fn main() -> io::Result<()> { env_logger::init(); @@ -49,158 +49,33 @@ fn main() -> io::Result<()> { } fn handle_request(request: ConformanceRequest) -> conformance_response::Result { - match request.requested_output_format() { - WireFormat::Unspecified => { + let rof = request.requested_output_format(); + match (rof, request.payload.as_ref()) { + (WireFormat::Unspecified, _) | (_, None) => { return conformance_response::Result::ParseError( - "output format unspecified".to_string(), + "input/output format unspecified".to_string(), ); } - WireFormat::Jspb => { + (WireFormat::Jspb, _) | (_, Some(conformance_request::Payload::JspbPayload(_))) => { return conformance_response::Result::Skipped( - "JSPB output is not supported".to_string(), + "JSPB input/output is not supported".to_string(), ); } - WireFormat::TextFormat => { + (WireFormat::TextFormat, _) | (_, Some(conformance_request::Payload::TextPayload(_))) => { return conformance_response::Result::Skipped( - "TEXT_FORMAT output is not supported".to_string(), + "TEXT_FORMAT input/output is not supported".to_string(), ); } - WireFormat::Protobuf | WireFormat::Json => (), + (WireFormat::Protobuf, _) | (WireFormat::Json, _) => (), }; - if let WireFormat::Json = request.requested_output_format() { - if let Some(conformance_request::Payload::JsonPayload(json_str)) = request.payload { - let roundtrip = match &*request.message_type { - "protobuf_test_messages.proto2.TestAllTypesProto2" => { - roundtrip_json::(&json_str) - } - "protobuf_test_messages.proto3.TestAllTypesProto3" => { - roundtrip_json::(&json_str) - } - _ => { - return conformance_response::Result::ParseError(format!( - "unknown message type: {}", - request.message_type - )); - } - }; - - return match roundtrip { - RoundtripResult::Ok(buf) => { - conformance_response::Result::JsonPayload(match std::str::from_utf8(&buf) { - Ok(str) => str.to_string(), - Err(error) => { - return conformance_response::Result::ParseError(error.to_string()) - } - }) - } - RoundtripResult::DecodeError(error) => { - conformance_response::Result::ParseError(error) - } - RoundtripResult::Error(error) => conformance_response::Result::RuntimeError(error), - }; - } - if let Some(conformance_request::Payload::ProtobufPayload(buf)) = request.payload { - // proto -> json - return match &*request.message_type { - "protobuf_test_messages.proto2.TestAllTypesProto2" => { - let m = match TestAllTypesProto2::decode(&*buf) { - Ok(m) => m, - Err(error) => { - return conformance_response::Result::ParseError(error.to_string()) - } - }; - match serde_json::to_string(&m) { - Ok(str) => conformance_response::Result::JsonPayload(str), - Err(error) => { - return conformance_response::Result::ParseError(error.to_string()) - } - } - } - "protobuf_test_messages.proto3.TestAllTypesProto3" => { - let m = match TestAllTypesProto3::decode(&*buf) { - Ok(m) => m, - Err(error) => { - return conformance_response::Result::ParseError(error.to_string()) - } - }; - match serde_json::to_string(&m) { - Ok(str) => conformance_response::Result::JsonPayload(str), - Err(error) => { - return conformance_response::Result::ParseError(error.to_string()) - } - } - } - _ => { - return conformance_response::Result::ParseError(format!( - "unknown message type: {}", - request.message_type - )); - } - }; - } - return conformance_response::Result::Skipped( - "only json <-> json is supported".to_string(), - ); - } - - let buf = match request.payload { - None => return conformance_response::Result::ParseError("no payload".to_string()), - Some(conformance_request::Payload::JsonPayload(str)) => { - // json -> proto - match &*request.message_type { - "protobuf_test_messages.proto2.TestAllTypesProto2" => { - let jd = &mut serde_json::Deserializer::from_str(&str); - let all_types: TestAllTypesProto2 = match serde_path_to_error::deserialize(jd) { - Ok(all_types) => all_types, - Err(error) => { - return conformance_response::Result::ParseError(format!( - "error deserializing json: {} at {}", - error.to_string(), - error.path().to_string() - )) - } - }; - return conformance_response::Result::ProtobufPayload(all_types.encode_to_vec()); - } - "protobuf_test_messages.proto3.TestAllTypesProto3" => { - let jd = &mut serde_json::Deserializer::from_str(&str); - let all_types: TestAllTypesProto3 = match serde_path_to_error::deserialize(jd) { - Ok(all_types) => all_types, - Err(error) => { - return conformance_response::Result::ParseError(format!( - "error deserializing json: {} at {}", - error.to_string(), - error.path().to_string() - )) - } - }; - return conformance_response::Result::ProtobufPayload(all_types.encode_to_vec()); - } - _ => { - return conformance_response::Result::ParseError(format!( - "unknown message type: {}", - request.message_type - )); - } - } - } - Some(conformance_request::Payload::JspbPayload(_)) => { - return conformance_response::Result::Skipped( - "JSON input is not supported".to_string(), - ); + let result = match (&*request.message_type, request.payload) { + ("protobuf_test_messages.proto2.TestAllTypesProto2", Some(payload)) => { + roundtrip::(payload, rof) } - Some(conformance_request::Payload::TextPayload(_)) => { - return conformance_response::Result::Skipped( - "JSON input is not supported".to_string(), - ); + ("protobuf_test_messages.proto3.TestAllTypesProto3", Some(payload)) => { + roundtrip::(payload, rof) } - Some(conformance_request::Payload::ProtobufPayload(buf)) => buf, - }; - - let roundtrip = match &*request.message_type { - "protobuf_test_messages.proto2.TestAllTypesProto2" => roundtrip::(&buf), - "protobuf_test_messages.proto3.TestAllTypesProto3" => roundtrip::(&buf), _ => { return conformance_response::Result::ParseError(format!( "unknown message type: {}", @@ -209,8 +84,8 @@ fn handle_request(request: ConformanceRequest) -> conformance_response::Result { } }; - match roundtrip { - RoundtripResult::Ok(buf) => conformance_response::Result::ProtobufPayload(buf), + match result { + RoundtripResult::Ok(result) => result, RoundtripResult::DecodeError(error) => conformance_response::Result::ParseError(error), RoundtripResult::Error(error) => conformance_response::Result::RuntimeError(error), } diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 38f384b3d..668a59d02 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -329,26 +329,28 @@ impl<'a> CodeGenerator<'a> { json_name: &str, map_type: Option<&str>, ) { - if let Some(_) = self.config.json_mapping.get_first(fq_message_name) { - if json_name.len() > 0 { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(&format!(r#"#[serde(rename = "{}")]"#, json_name,)); - self.buf.push('\n'); - } + if let None = self.config.json_mapping.get_first(fq_message_name) { + return; + } + if json_name.len() > 0 { push_indent(&mut self.buf, self.depth); self.buf - .push_str(&format!(r#"#[serde(alias = "{}")]"#, field_name,)); + .push_str(&format!(r#"#[serde(rename = "{}")]"#, json_name,)); + self.buf.push('\n'); + } + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(&format!(r#"#[serde(alias = "{}")]"#, field_name,)); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + if let Some(map_type) = map_type { + self.buf.push_str(&format!( + r#"#[serde(skip_serializing_if = "{}::is_empty")]"#, + map_type + )); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); - if let Some(map_type) = map_type { - self.buf.push_str(&format!( - r#"#[serde(skip_serializing_if = "{}::is_empty")]"#, - map_type - )); - self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); - match map_type { + match map_type { "::std::collections::HashMap" => self.buf.push_str( r#"#[serde(deserialize_with = "::prost_types::map_visitor::deserialize")]"#, @@ -360,226 +362,230 @@ impl<'a> CodeGenerator<'a> { _ => (), } - self.buf.push('\n'); - return; - } else { - self.buf - .push_str(r#"#[serde(skip_serializing_if = "::prost_types::is_default")]"#); - } self.buf.push('\n'); + return; + } else { + self.buf + .push_str(r#"#[serde(skip_serializing_if = "::prost_types::is_default")]"#); + } + self.buf.push('\n'); - match (ty, optional, repeated) { - ("i32", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::i32_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("i32", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( + match (ty, optional, repeated) { + ("i32", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::i32_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("i32", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::i32_visitor::I32Visitor>")]"#, ); - self.buf.push('\n'); - } - ("i32", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(deserialize_with = "::prost_types::i32_opt_visitor::deserialize")]"#); - self.buf.push('\n'); - } - ("bool", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::bool_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("bool", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf + self.buf.push('\n'); + } + ("i32", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::i32_opt_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("bool", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::bool_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("bool", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf .push_str(r#"#[serde(deserialize_with = "::prost_types::bool_opt_visitor::deserialize")]"#); - self.buf.push('\n'); - } - ("bool", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( + self.buf.push('\n'); + } + ("bool", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::bool_visitor::BoolVisitor>")]"#, ); - self.buf.push('\n'); - } - ("i64", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::i64_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("i64", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(deserialize_with = "::prost_types::i64_opt_visitor::deserialize")]"#); - self.buf.push('\n'); - } - ("i64", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( + self.buf.push('\n'); + } + ("i64", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::i64_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("i64", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::i64_opt_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("i64", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::i64_visitor::I64Visitor>")]"#, ); - self.buf.push('\n'); - } - ("u32", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::u32_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("u32", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(deserialize_with = "::prost_types::u32_opt_visitor::deserialize")]"#); - self.buf.push('\n'); - } - ("u32", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( + self.buf.push('\n'); + } + ("u32", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::u32_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("u32", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::u32_opt_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("u32", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::u32_visitor::U32Visitor>")]"#, ); - self.buf.push('\n'); - } - ("u64", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::u64_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("u64", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(deserialize_with = "::prost_types::u64_opt_visitor::deserialize")]"#); - self.buf.push('\n'); - } - ("u64", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( + self.buf.push('\n'); + } + ("u64", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::u64_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("u64", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::u64_opt_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("u64", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::u64_visitor::U64Visitor>")]"#, ); - self.buf.push('\n'); - } - ("f64", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf + self.buf.push('\n'); + } + ("f64", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf .push_str(r#"#[serde(serialize_with = "<::prost_types::f64_visitor::F64Serializer as ::prost_types::SerializeMethod>::serialize")]"#); - self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::f64_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("f64", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::f64_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("f64", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf .push_str( r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::f64_visitor::F64Visitor>")]"#, ); - self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); - self.buf + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf .push_str(r#"#[serde(serialize_with = "::prost_types::repeated_visitor::serialize::<_, ::prost_types::f64_visitor::F64Serializer>")]"# ); - self.buf.push('\n'); - } - ("f64", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(with = "::prost_types::f64_opt_visitor")]"#); - self.buf.push('\n'); - } - ("f32", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf + self.buf.push('\n'); + } + ("f64", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(with = "::prost_types::f64_opt_visitor")]"#); + self.buf.push('\n'); + } + ("f32", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf .push_str(r#"#[serde(serialize_with = "<::prost_types::f32_visitor::F32Serializer as ::prost_types::SerializeMethod>::serialize")]"#); - self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::f32_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("f32", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::f32_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("f32", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf .push_str( r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::f32_visitor::F32Visitor>")]"#, ); - self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); - self.buf + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf .push_str(r#"#[serde(serialize_with = "::prost_types::repeated_visitor::serialize::<_, ::prost_types::f32_visitor::F32Serializer>")]"# ); - self.buf.push('\n'); - } - ("f32", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(with = "::prost_types::f32_opt_visitor")]"#); - self.buf.push('\n'); - } - ("::prost::alloc::string::String", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(deserialize_with = "::prost_types::string_visitor::deserialize")]"#); - self.buf.push('\n'); - } - ("::prost::alloc::string::String", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf + self.buf.push('\n'); + } + ("f32", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(with = "::prost_types::f32_opt_visitor")]"#); + self.buf.push('\n'); + } + ("::prost::alloc::string::String", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::string_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("::prost::alloc::string::String", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf .push_str(r#"#[serde(deserialize_with = "::prost_types::string_opt_visitor::deserialize")]"#); - self.buf.push('\n'); - } - ("::prost::alloc::vec::Vec", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf + self.buf.push('\n'); + } + ("::prost::alloc::vec::Vec", false, false) => { + push_indent(&mut self.buf, self.depth); + self.buf .push_str(r#"#[serde(serialize_with = "<::prost_types::vec_u8_visitor::VecU8Serializer as ::prost_types::SerializeMethod>::serialize")]"#); - self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(deserialize_with = "::prost_types::vec_u8_visitor::deserialize")]"# - ); - self.buf.push('\n'); - } - ("::prost::alloc::vec::Vec", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(with = "::prost_types::vec_u8_opt_visitor")]"#); - self.buf.push('\n'); - } - ("::prost::alloc::vec::Vec", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::vec_u8_visitor::deserialize")]"#, + ); + self.buf.push('\n'); + } + ("::prost::alloc::vec::Vec", true, false) => { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(with = "::prost_types::vec_u8_opt_visitor")]"#); + self.buf.push('\n'); + } + ("::prost::alloc::vec::Vec", false, true) => { + push_indent(&mut self.buf, self.depth); + self.buf .push_str( r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::vec_u8_visitor::VecU8Visitor>")]"#, ); - self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); - self.buf + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf .push_str(r#"#[serde(serialize_with = "::prost_types::repeated_visitor::serialize::<_, ::prost_types::vec_u8_visitor::VecU8Serializer>")]"# ); - self.buf.push('\n'); - } - (_, _, true) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::vec_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - _ => {} + self.buf.push('\n'); + } + (_, _, true) => { + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::vec_visitor::deserialize")]"#, + ); + self.buf.push('\n'); } + _ => {} } } diff --git a/prost-build/src/lib.rs b/prost-build/src/lib.rs index 1dc6de50f..b18677dde 100644 --- a/prost-build/src/lib.rs +++ b/prost-build/src/lib.rs @@ -449,13 +449,21 @@ impl Config { } /// Generates serde attributes in order to conform to the proto to json spec. - // TODO MORE COMMENTS + /// Once applied, all messages will implement Serialize and Deserialize, and + /// serde_json can be used to go to/from json. + /// + /// Verification of the implementation is done in the `conformance` crate. See + /// the failed list for any limitations in the current implementation. + /// + /// More on the proto/json spec can be found [here](https://developers.google.com/protocol-buffers/docs/proto3#json). + /// + /// There are additional options that Google suggests, however none are currently + /// implemented. pub fn json_mapping(&mut self, paths: I) -> &mut Self where I: IntoIterator, S: AsRef, { - self.map_type.clear(); for matcher in paths { self.json_mapping.insert(matcher.as_ref().to_string(), ()); } diff --git a/tests/src/lib.rs b/tests/src/lib.rs index c15f1a751..b3eab3369 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -113,11 +113,11 @@ use bytes::Buf; use prost::Message; -use serde::{Deserialize, Serialize}; - +use protobuf::conformance::{conformance_request, conformance_response, WireFormat}; +use serde::{de::DeserializeOwned, Serialize}; pub enum RoundtripResult { /// The roundtrip succeeded. - Ok(Vec), + Ok(conformance_response::Result), /// The data could not be decoded. This could indicate a bug in prost, /// or it could indicate that the input was bogus. DecodeError(String), @@ -127,7 +127,7 @@ pub enum RoundtripResult { impl RoundtripResult { /// Unwrap the roundtrip result. - pub fn unwrap(self) -> Vec { + pub fn unwrap(self) -> conformance_response::Result { match self { RoundtripResult::Ok(buf) => buf, RoundtripResult::DecodeError(error) => { @@ -138,25 +138,60 @@ impl RoundtripResult { } } -/// Tests round-tripping a message type. The message should be compiled with `BTreeMap` fields, -/// otherwise the comparison may fail due to inconsistent `HashMap` entry encoding ordering. -pub fn roundtrip(data: &[u8]) -> RoundtripResult +fn decode(payload: conformance_request::Payload) -> Result where - M: Message + Default, + M: Message + Default + DeserializeOwned, { - // Try to decode a message from the data. If decoding fails, continue. - let all_types = match M::decode(data) { - Ok(all_types) => all_types, - Err(error) => return RoundtripResult::DecodeError(error.to_string()), - }; + match payload { + conformance_request::Payload::JsonPayload(str) => { + let jd = &mut serde_json::Deserializer::from_str(&str); + match serde_path_to_error::deserialize(jd) { + Ok(all_types) => Ok(all_types), + Err(error) => Err(format!( + "error deserializing json: {} at {}", + error.to_string(), + error.path().to_string() + )), + } + } + conformance_request::Payload::ProtobufPayload(buf) => match M::decode(&*buf) { + Ok(m) => Ok(m), + Err(error) => Err(error.to_string()), + }, + _ => panic!("only proto and json are supported"), + } +} - let encoded_len = all_types.encoded_len(); +fn encode( + message: M, + requested_output_format: WireFormat, +) -> Result +where + M: Message + Default + Serialize, +{ + match requested_output_format { + WireFormat::Json => match serde_json::to_string(&message) { + Ok(str) => Ok(conformance_response::Result::JsonPayload(str)), + Err(error) => Err(error.to_string()), + }, + WireFormat::Protobuf => Ok(conformance_response::Result::ProtobufPayload( + message.encode_to_vec(), + )), + _ => panic!("only proto and json are supported"), + } +} - // TODO: Reenable this once sign-extension in negative int32s is figured out. +fn proto_to_proto_checks() +where + M: Message + Default, +{ + /* + // TODO: Reenable this once sign-extension in negative int32s is figured out. // assert!(encoded_len <= data.len(), "encoded_len: {}, len: {}, all_types: {:?}", // encoded_len, data.len(), all_types); - let mut buf1 = Vec::new(); + + let mut buf1 = Vec::new(); if let Err(error) = all_types.encode(&mut buf1) { return RoundtripResult::Error(error.to_string()); } @@ -195,36 +230,33 @@ where "roundtripped encoded buffers do not match with `encode_to_vec`".to_string(), ); } - - RoundtripResult::Ok(buf1) + */ } /// Tests round-tripping a message type. The message should be compiled with `BTreeMap` fields, /// otherwise the comparison may fail due to inconsistent `HashMap` entry encoding ordering. -pub fn roundtrip_json<'de, M>(data: &'de str) -> RoundtripResult +pub fn roundtrip( + payload: conformance_request::Payload, + requested_output_format: WireFormat, +) -> RoundtripResult where - M: Message + Default + Serialize + Deserialize<'de>, + M: Message + Default + DeserializeOwned + Serialize, { - let jd = &mut serde_json::Deserializer::from_str(data); - let all_types: M = match serde_path_to_error::deserialize(jd) { + let all_types: M = match decode(payload.clone()) { Ok(all_types) => all_types, - Err(error) => { - return RoundtripResult::DecodeError(format!( - "error deserializing json: {} at {}", - error.to_string(), - error.path().to_string() - )) - } + Err(error) => return RoundtripResult::DecodeError(error), }; - let str1 = match serde_json::to_string(&all_types) { - Ok(str) => str, - Err(error) => { - return RoundtripResult::Error(format!("error encoding json {}", error.to_string())) + if let conformance_request::Payload::ProtobufPayload(_) = payload { + if requested_output_format == WireFormat::Protobuf { + proto_to_proto_checks::(); } - }; + } - RoundtripResult::Ok(str1.into_bytes()) + match encode(all_types, requested_output_format) { + Ok(result) => RoundtripResult::Ok(result), + Err(error) => RoundtripResult::Error(error), + } } /// Generic rountrip serialization check for messages. From 7766e33bdca378377ea48d495da1cbb31b2b0cab Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Wed, 3 Nov 2021 11:58:18 -0700 Subject: [PATCH 14/30] no-op --- conformance/failing_tests.txt | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index 7cf33f7d2..40b53f347 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -1,5 +1,3 @@ -Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput -Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput Recommended.FieldMaskNumbersDontRoundTrip.JsonOutput Recommended.FieldMaskPathsDontRoundTrip.JsonOutput Recommended.FieldMaskTooManyUnderscore.JsonOutput @@ -25,6 +23,7 @@ Recommended.Proto3.JsonInput.Uint64FieldBeString.Validator Recommended.Proto3.ProtobufInput.OneofZeroBytes.JsonOutput Required.DurationProtoInputTooLarge.JsonOutput Required.DurationProtoInputTooSmall.JsonOutput +Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput Required.Proto3.JsonInput.Any.JsonOutput Required.Proto3.JsonInput.Any.ProtobufOutput Required.Proto3.JsonInput.AnyNested.JsonOutput @@ -106,6 +105,7 @@ Required.Proto3.JsonInput.ValueAcceptObject.JsonOutput Required.Proto3.JsonInput.ValueAcceptObject.ProtobufOutput Required.Proto3.JsonInput.ValueAcceptString.JsonOutput Required.Proto3.JsonInput.ValueAcceptString.ProtobufOutput +Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.Default.JsonOutput Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateKey.JsonOutput Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateKeyInMapEntry.JsonOutput From 0d62c8cf2de5f1b0d94dcd209a08cb47a42db199 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Wed, 3 Nov 2021 14:19:32 -0700 Subject: [PATCH 15/30] final changes --- conformance/Cargo.toml | 2 -- prost-build/src/code_generator.rs | 7 +++++++ tests/src/lib.rs | 35 ++++++++++++++----------------- 3 files changed, 23 insertions(+), 21 deletions(-) diff --git a/conformance/Cargo.toml b/conformance/Cargo.toml index aabbd14c4..83249f173 100644 --- a/conformance/Cargo.toml +++ b/conformance/Cargo.toml @@ -15,5 +15,3 @@ log = "0.4" prost = { path = ".." } protobuf = { path = "../protobuf" } tests = { path = "../tests" } -serde_json = { version="1.0" } -serde_path_to_error = "0.1" diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 668a59d02..8f514a8fe 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -332,18 +332,23 @@ impl<'a> CodeGenerator<'a> { if let None = self.config.json_mapping.get_first(fq_message_name) { return; } + // If there is a json name specified, add it. if json_name.len() > 0 { push_indent(&mut self.buf, self.depth); self.buf .push_str(&format!(r#"#[serde(rename = "{}")]"#, json_name,)); self.buf.push('\n'); } + // Always alias to the field name for deserializing. push_indent(&mut self.buf, self.depth); self.buf .push_str(&format!(r#"#[serde(alias = "{}")]"#, field_name,)); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); + + // Special case maps. if let Some(map_type) = map_type { + // Use is_empty instead of is_default to avoid allocations. self.buf.push_str(&format!( r#"#[serde(skip_serializing_if = "{}::is_empty")]"#, map_type @@ -370,6 +375,8 @@ impl<'a> CodeGenerator<'a> { } self.buf.push('\n'); + // Add custom deserializers and optionally serializers for most primitive types + // and their optional and repeated counterparts. match (ty, optional, repeated) { ("i32", false, false) => { push_indent(&mut self.buf, self.depth); diff --git a/tests/src/lib.rs b/tests/src/lib.rs index b3eab3369..d972c45bd 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -181,22 +181,21 @@ where } } -fn proto_to_proto_checks() +/// Does additional checks on the binary output of the protobuf messages. +fn proto_checks(message: &M) -> Result<(), String> where M: Message + Default, { - /* - // TODO: Reenable this once sign-extension in negative int32s is figured out. + // TODO: Reenable this once sign-extension in negative int32s is figured out. // assert!(encoded_len <= data.len(), "encoded_len: {}, len: {}, all_types: {:?}", // encoded_len, data.len(), all_types); - - - let mut buf1 = Vec::new(); - if let Err(error) = all_types.encode(&mut buf1) { - return RoundtripResult::Error(error.to_string()); + let mut buf1 = Vec::new(); + if let Err(error) = message.encode(&mut buf1) { + return Err(error.to_string()); } + let encoded_len = message.encoded_len(); if encoded_len != buf1.len() { - return RoundtripResult::Error(format!( + return Err(format!( "expected encoded len ({}) did not match actual encoded len ({})", encoded_len, buf1.len() @@ -205,12 +204,12 @@ where let roundtrip = match M::decode(&*buf1) { Ok(roundtrip) => roundtrip, - Err(error) => return RoundtripResult::Error(error.to_string()), + Err(error) => return Err(error.to_string()), }; let mut buf2 = Vec::new(); if let Err(error) = roundtrip.encode(&mut buf2) { - return RoundtripResult::Error(error.to_string()); + return Err(error.to_string()); } let buf3 = roundtrip.encode_to_vec(); @@ -222,15 +221,13 @@ where */ if buf1 != buf2 { - return RoundtripResult::Error("roundtripped encoded buffers do not match".to_string()); + return Err("roundtripped encoded buffers do not match".to_string()); } if buf1 != buf3 { - return RoundtripResult::Error( - "roundtripped encoded buffers do not match with `encode_to_vec`".to_string(), - ); + return Err("roundtripped encoded buffers do not match with `encode_to_vec`".to_string()); } - */ + Ok(()) } /// Tests round-tripping a message type. The message should be compiled with `BTreeMap` fields, @@ -247,9 +244,9 @@ where Err(error) => return RoundtripResult::DecodeError(error), }; - if let conformance_request::Payload::ProtobufPayload(_) = payload { - if requested_output_format == WireFormat::Protobuf { - proto_to_proto_checks::(); + if requested_output_format == WireFormat::Protobuf { + if let Err(error) = proto_checks::(&all_types) { + return RoundtripResult::Error(error); } } From 0dd0652bc6d90fee2257a0f0bbce64786c6aa70c Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Thu, 11 Nov 2021 00:55:01 -0800 Subject: [PATCH 16/30] add maps, still haven't cleaned up some code --- prost-build/src/code_generator.rs | 237 +++++++++++++++--- prost-types/src/lib.rs | 403 +++++++++++++++++++++++++++++- 2 files changed, 599 insertions(+), 41 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 8f514a8fe..a73aa32a4 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -44,6 +44,48 @@ fn push_indent(buf: &mut String, depth: u8) { } } +/// Returns (serializer, deserializer) function names to use in serde +/// serialize_with and deserialize_with macros respectively. If none are +/// specified, the default works fine. +/// If collection is true, the return is no longer the function, but instead, +/// the Visitor type that will be used for either the repeated helper or +/// custom map helper. +fn get_custom_json_type_mappers( + ty: &str, + optional: bool, + collection: bool, +) -> (Option<&str>, Option<&str>) { + match (ty, optional, collection) { + ("bool", false, false) => (None, Some("::prost_types::bool_visitor::deserialize")), + ("bool", true, false) => (None, Some("::prost_types::bool_opt_visitor::deserialize")), + ("bool", false, true) => (None, Some("::prost_types::bool_visitor::BoolVisitor")), + ("i32", false, false) => (None, Some("::prost_types::i32_visitor::deserialize")), + ("i32", true, false) => (None, Some("::prost_types::i32_opt_visitor::deserialize")), + ("i32", false, true) => (None, Some("::prost_types::i32_visitor::I32Visitor")), + ("i64", false, false) => (None, Some("::prost_types::i64_visitor::deserialize")), + ("i64", true, false) => (None, Some("::prost_types::i64_opt_visitor::deserialize")), + ("i64", false, true) => (None, Some("::prost_types::i64_visitor::I64Visitor")), + ("u32", false, false) => (None, Some("::prost_types::u32_visitor::deserialize")), + ("u32", true, false) => (None, Some("::prost_types::u32_opt_visitor::deserialize")), + ("u32", false, true) => (None, Some("::prost_types::u32_visitor::U32Visitor")), + ("u64", false, false) => (None, Some("::prost_types::u64_visitor::deserialize")), + ("u64", true, false) => (None, Some("::prost_types::u64_opt_visitor::deserialize")), + ("u64", false, true) => (None, Some("::prost_types::u64_visitor::U64Visitor")), + ("f64", false, false) => (Some("<::prost_types::f64_visitor::F64Serializer as ::prost_types::SerializeMethod>::serialize"), Some("::prost_types::f64_visitor::deserialize")), + ("f64", true, false) => (Some("::prost_types::f64_opt_visitor::serialize"), Some("::prost_types::f64_opt_visitor::deserialize")), + ("f64", false, true) => (Some("::prost_types::f64_visitor::F64Serializer"), Some("::prost_types::f64_visitor::F64Visitor")), + ("f32", false, false) => (Some("<::prost_types::f32_visitor::F32Serializer as ::prost_types::SerializeMethod>::serialize"), Some("::prost_types::f32_visitor::deserialize")), + ("f32", true, false) => (Some("::prost_types::f32_opt_visitor::serialize"), Some("::prost_types::f32_opt_visitor::deserialize")), + ("f32", false, true) => (Some("::prost_types::f32_visitor::F32Serializer"), Some("::prost_types::f32_visitor::F32Visitor")), + ("::prost::alloc::string::String", false, false) => (None, Some("::prost_types::string_visitor::deserialize")), + ("::prost::alloc::string::String", true, false) => (None, Some("::prost_types::string_opt_visitor::deserialize")), + ("::prost::alloc::vec::Vec", false, false) => (Some("<::prost_types::vec_u8_visitor::VecU8Serializer as ::prost_types::SerializeMethod>::serialize"), Some("::prost_types::vec_u8_visitor::deserialize")), + ("::prost::alloc::vec::Vec", true, false) => (Some("::prost_types::vec_u8_opt_visitor::serialize"), Some("::prost_types::vec_u8_opt_visitor::deserialize")), + ("::prost::alloc::vec::Vec", false, true) => (Some("::prost_types::vec_u8_visitor::VecU8Serializer"), Some("::prost_types::vec_u8_visitor::VecU8Visitor")), + (_,_, _) => (None, None) + } +} + impl<'a> CodeGenerator<'a> { pub fn generate( config: &mut Config, @@ -319,19 +361,8 @@ impl<'a> CodeGenerator<'a> { } } - fn append_json_field_attributes( - &mut self, - fq_message_name: &str, - ty: &str, - field_name: &str, - optional: bool, - repeated: bool, - json_name: &str, - map_type: Option<&str>, - ) { - if let None = self.config.json_mapping.get_first(fq_message_name) { - return; - } + // Shared fields between field and map fields. + fn append_shared_json_field_attributes(&mut self, field_name: &str, json_name: &str) { // If there is a json name specified, add it. if json_name.len() > 0 { push_indent(&mut self.buf, self.depth); @@ -345,34 +376,162 @@ impl<'a> CodeGenerator<'a> { .push_str(&format!(r#"#[serde(alias = "{}")]"#, field_name,)); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); + } - // Special case maps. - if let Some(map_type) = map_type { - // Use is_empty instead of is_default to avoid allocations. - self.buf.push_str(&format!( - r#"#[serde(skip_serializing_if = "{}::is_empty")]"#, - map_type - )); - self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); - match map_type { - "::std::collections::HashMap" => - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::map_visitor::deserialize")]"#, - ), - "::prost::alloc::collections::BTreeMap" => + fn append_json_map_field_attributes( + &mut self, + fq_message_name: &str, + field_name: &str, + key_ty: &str, + value_ty: &str, + map_type: &str, + json_name: &str, + ) { + if let None = self.config.json_mapping.get_first(fq_message_name) { + return; + } + self.append_shared_json_field_attributes(field_name, json_name); + + // Use is_empty instead of is_default to avoid allocations. + push_indent(&mut self.buf, self.depth); + self.buf.push_str(&format!( + r#"#[serde(skip_serializing_if = "{}::is_empty")]"#, + map_type + )); + self.buf.push('\n'); + + let (key_se_opt, key_de_opt) = get_custom_json_type_mappers(key_ty, false, true); + let (value_se_opt, value_de_opt) = get_custom_json_type_mappers(value_ty, false, true); + + push_indent(&mut self.buf, self.depth); + match (key_se_opt, key_de_opt, value_se_opt, value_de_opt, map_type) { + (Some(key_se), Some(key_de), Some(value_se), Some(value_de), "::std::collections::HashMap") => { + self.buf.push_str( + &format!(r#"#[serde(serialize_with = "::prost_types::map_custom_to_custom_visitor::serialize::<_, {}, {}>")]"#, key_se, value_se) + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + ); + } + (None, Some(key_de), None, Some(value_de), "::std::collections::HashMap") => + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + ), + (Some(key_se), Some(key_de), None, Some(value_de), "::std::collections::HashMap") => { + self.buf.push_str( + &format!(r#"#[serde(serialize_with = "::prost_types::map_custom_visitor::serialize::<_, {}, _>")]"#, key_se) + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + ); + }, + (Some(key_se), Some(key_de), None, None, "::std::collections::HashMap") => { + self.buf.push_str( + &format!(r#"#[serde(serialize_with = "::prost_types::map_custom_visitor::serialize::<_, {}, _>")]"#, key_se) + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::map_custom_visitor::deserialize::<_, {}, _>")]"#, key_de) + ); + }, + (None, Some(key_de), None, None, "::std::collections::HashMap") => + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::map_custom_visitor::deserialize::<_, {}, _>")]"#, key_de) + ), + (None, Some(key_de), Some(value_se), Some(value_de), "::std::collections::HashMap") => { + self.buf.push_str( + &format!(r#"#[serde(serialize_with = "::prost_types::map_custom_serializer::serialize::<_, _, {}>")]"#, value_se) + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + ); + }, + (Some(key_se), Some(key_de), Some(value_se), Some(value_de), "::prost::alloc::collections::BTreeMap") => { + self.buf.push_str( + &format!(r#"#[serde(serialize_with = "::prost_types::btree_map_custom_to_custom_visitor::serialize::<_, {}, {}>")]"#, key_se, value_se) + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::btree_map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + ); + } + (None, Some(key_de), None, Some(value_de), "::prost::alloc::collections::BTreeMap") => + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::btree_map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + ), + (Some(key_se), Some(key_de), None, Some(value_de), "::prost::alloc::collections::BTreeMap") => { + self.buf.push_str( + &format!(r#"#[serde(serialize_with = "::prost_types::btree_map_custom_visitor::serialize::<_, {}, _>")]"#, key_se) + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::btree_map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + ); + }, + (Some(key_se), Some(key_de), None, None, "::prost::alloc::collections::BTreeMap") => { + self.buf.push_str( + &format!(r#"#[serde(serialize_with = "::prost_types::btree_map_custom_visitor::serialize::<_, {}, _>")]"#, key_se) + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::btree_map_custom_visitor::deserialize::<_, {}, _>")]"#, key_de) + ); + }, + (None, Some(key_de), None, None, "::prost::alloc::collections::BTreeMap") => + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::btree_map_custom_visitor::deserialize::<_, {}, _>")]"#, key_de) + ), + (None, Some(key_de), Some(value_se), Some(value_de), "::prost::alloc::collections::BTreeMap") => { + self.buf.push_str( + &format!(r#"#[serde(serialize_with = "::prost_types::btree_map_custom_serializer::serialize::<_, _, {}>")]"#, value_se) + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::btree_map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + ); + + }, + (_, _, _, _, "::std::collections::HashMap") => + self.buf.push_str( + r#"#[serde(deserialize_with = "::prost_types::map_visitor::deserialize")]"#, + ), + (_, _, _, _, "::prost::alloc::collections::BTreeMap") => self.buf.push_str( r#"#[serde(deserialize_with = "::prost_types::btree_map_visitor::deserialize")]"#, ), + _ => (), + } + self.buf.push('\n'); + } - _ => (), - } - self.buf.push('\n'); + fn append_json_field_attributes( + &mut self, + fq_message_name: &str, + ty: &str, + field_name: &str, + optional: bool, + repeated: bool, + json_name: &str, + ) { + if let None = self.config.json_mapping.get_first(fq_message_name) { return; - } else { - self.buf - .push_str(r#"#[serde(skip_serializing_if = "::prost_types::is_default")]"#); } + self.append_shared_json_field_attributes(field_name, json_name); + + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(skip_serializing_if = "::prost_types::is_default")]"#); self.buf.push('\n'); // Add custom deserializers and optionally serializers for most primitive types @@ -706,7 +865,6 @@ impl<'a> CodeGenerator<'a> { optional, repeated, field.json_name(), - None, ); self.push_indent(); self.buf.push_str("pub "); @@ -767,14 +925,13 @@ impl<'a> CodeGenerator<'a> { field.number() )); self.append_field_attributes(fq_message_name, field.name()); - self.append_json_field_attributes( + self.append_json_map_field_attributes( fq_message_name, - map_type.rust_type(), field.name(), - false, - false, + &key_ty, + &value_ty, + map_type.rust_type(), field.json_name(), - Some(map_type.rust_type()), ); self.push_indent(); self.buf.push_str(&format!( diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index c6a7d5b43..85af2eb19 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -433,7 +433,6 @@ pub mod repeated_visitor { where S: serde::Serializer, F: crate::SerializeMethod, - // ::Value: Copy, { use serde::ser::SerializeSeq; let mut seq = serializer.serialize_seq(Some(value.len()))?; @@ -444,6 +443,408 @@ pub mod repeated_visitor { } } +pub mod map_custom_serializer { + pub fn serialize( + value: &std::collections::HashMap::Value>, + serializer: S, + ) -> Result + where + S: serde::Serializer, + K: serde::Serialize + std::cmp::Eq + std::hash::Hash, + G: crate::SerializeMethod, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(value.len()))?; + for (key, value) in value { + map.serialize_entry(&key, &crate::MySeType:: { val: value })?; + } + map.end() + } +} + +pub mod btree_map_custom_serializer { + pub fn serialize( + value: &std::collections::BTreeMap::Value>, + serializer: S, + ) -> Result + where + S: serde::Serializer, + K: serde::Serialize + std::cmp::Eq + std::cmp::Ord, + G: crate::SerializeMethod, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(value.len()))?; + for (key, value) in value { + map.serialize_entry(&key, &crate::MySeType:: { val: value })?; + } + map.end() + } +} + +pub mod map_custom_visitor { + struct MapVisitor<'de, T, V> + where + T: serde::de::Visitor<'de> + crate::HasConstructor, + V: serde::Deserialize<'de>, + { + _map_type: fn() -> ( + std::marker::PhantomData<&'de T>, + std::marker::PhantomData<&'de V>, + ), + } + + #[cfg(feature = "std")] + impl<'de, T, V> serde::de::Visitor<'de> for MapVisitor<'de, T, V> + where + T: serde::de::Visitor<'de> + crate::HasConstructor, + V: serde::Deserialize<'de>, + >::Value: std::cmp::Eq + std::hash::Hash, + { + type Value = std::collections::HashMap<>::Value, V>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid String string or integer") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut res = Self::Value::with_capacity(map.size_hint().unwrap_or(0)); + loop { + let response: std::option::Option<(crate::MyType<'de, T>, V)> = map.next_entry()?; + match response { + Some((key, val)) => { + res.insert(key.0, val); + } + _ => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T, V>( + deserializer: D, + ) -> Result>::Value, V>, D::Error> + where + D: serde::Deserializer<'de>, + T: 'de + serde::de::Visitor<'de> + crate::HasConstructor, + V: 'de + serde::Deserialize<'de>, + >::Value: std::cmp::Eq + std::hash::Hash, + { + deserializer.deserialize_any(MapVisitor::<'de, T, V> { + _map_type: || (std::marker::PhantomData, std::marker::PhantomData), + }) + } + + pub fn serialize( + value: &std::collections::HashMap<::Value, V>, + serializer: S, + ) -> Result + where + S: serde::Serializer, + F: crate::SerializeMethod, + V: serde::Serialize, + ::Value: std::cmp::Eq + std::hash::Hash, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(value.len()))?; + for (key, value) in value { + map.serialize_entry(&crate::MySeType:: { val: key }, &value)?; + } + map.end() + } +} + +pub mod map_custom_to_custom_visitor { + struct MapVisitor<'de, T, S> + where + T: serde::de::Visitor<'de> + crate::HasConstructor, + S: serde::de::Visitor<'de> + crate::HasConstructor, + { + _map_type: fn() -> ( + std::marker::PhantomData<&'de T>, + std::marker::PhantomData<&'de S>, + ), + } + + #[cfg(feature = "std")] + impl<'de, T, S> serde::de::Visitor<'de> for MapVisitor<'de, T, S> + where + T: serde::de::Visitor<'de> + crate::HasConstructor, + S: serde::de::Visitor<'de> + crate::HasConstructor, + >::Value: std::cmp::Eq + std::hash::Hash, + { + type Value = std::collections::HashMap< + >::Value, + >::Value, + >; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid String string or integer") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut res = Self::Value::with_capacity(map.size_hint().unwrap_or(0)); + loop { + let response: std::option::Option<(crate::MyType<'de, T>, crate::MyType<'de, S>)> = + map.next_entry()?; + match response { + Some((key, val)) => { + res.insert(key.0, val.0); + } + _ => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T, S>( + deserializer: D, + ) -> Result< + std::collections::HashMap< + >::Value, + >::Value, + >, + D::Error, + > + where + D: serde::Deserializer<'de>, + T: 'de + serde::de::Visitor<'de> + crate::HasConstructor, + S: 'de + serde::de::Visitor<'de> + crate::HasConstructor, + >::Value: std::cmp::Eq + std::hash::Hash, + { + deserializer.deserialize_any(MapVisitor::<'de, T, S> { + _map_type: || (std::marker::PhantomData, std::marker::PhantomData), + }) + } + + pub fn serialize( + value: &std::collections::HashMap< + ::Value, + ::Value, + >, + serializer: S, + ) -> Result + where + S: serde::Serializer, + F: crate::SerializeMethod, + G: crate::SerializeMethod, + ::Value: std::cmp::Eq + std::hash::Hash, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(value.len()))?; + for (key, value) in value { + map.serialize_entry( + &crate::MySeType:: { val: key }, + &crate::MySeType:: { val: value }, + )?; + } + map.end() + } +} + +pub mod btree_map_custom_visitor { + struct MapVisitor<'de, T, V> + where + T: serde::de::Visitor<'de> + crate::HasConstructor, + V: serde::Deserialize<'de>, + { + _map_type: fn() -> ( + std::marker::PhantomData<&'de T>, + std::marker::PhantomData<&'de V>, + ), + } + + #[cfg(feature = "std")] + impl<'de, T, V> serde::de::Visitor<'de> for MapVisitor<'de, T, V> + where + T: serde::de::Visitor<'de> + crate::HasConstructor, + V: serde::Deserialize<'de>, + >::Value: std::cmp::Eq + std::cmp::Ord, + { + type Value = std::collections::BTreeMap<>::Value, V>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid String string or integer") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut res = Self::Value::new(); + loop { + let response: std::option::Option<(crate::MyType<'de, T>, V)> = map.next_entry()?; + match response { + Some((key, val)) => { + res.insert(key.0, val); + } + _ => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T, V>( + deserializer: D, + ) -> Result>::Value, V>, D::Error> + where + D: serde::Deserializer<'de>, + T: 'de + serde::de::Visitor<'de> + crate::HasConstructor, + V: 'de + serde::Deserialize<'de>, + >::Value: std::cmp::Eq + std::cmp::Ord, + { + deserializer.deserialize_any(MapVisitor::<'de, T, V> { + _map_type: || (std::marker::PhantomData, std::marker::PhantomData), + }) + } + + pub fn serialize( + value: &std::collections::BTreeMap<::Value, V>, + serializer: S, + ) -> Result + where + S: serde::Serializer, + F: crate::SerializeMethod, + V: serde::Serialize, + ::Value: std::cmp::Eq + std::cmp::Ord, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(value.len()))?; + for (key, value) in value { + map.serialize_entry(&crate::MySeType:: { val: key }, &value)?; + } + map.end() + } +} + +pub mod btree_map_custom_to_custom_visitor { + struct MapVisitor<'de, T, S> + where + T: serde::de::Visitor<'de> + crate::HasConstructor, + S: serde::de::Visitor<'de> + crate::HasConstructor, + { + _map_type: fn() -> ( + std::marker::PhantomData<&'de T>, + std::marker::PhantomData<&'de S>, + ), + } + + #[cfg(feature = "std")] + impl<'de, T, S> serde::de::Visitor<'de> for MapVisitor<'de, T, S> + where + T: serde::de::Visitor<'de> + crate::HasConstructor, + S: serde::de::Visitor<'de> + crate::HasConstructor, + >::Value: std::cmp::Eq + std::cmp::Ord, + { + type Value = std::collections::BTreeMap< + >::Value, + >::Value, + >; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid String string or integer") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut res = Self::Value::new(); + loop { + let response: std::option::Option<(crate::MyType<'de, T>, crate::MyType<'de, S>)> = + map.next_entry()?; + match response { + Some((key, val)) => { + res.insert(key.0, val.0); + } + _ => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T, S>( + deserializer: D, + ) -> Result< + std::collections::BTreeMap< + >::Value, + >::Value, + >, + D::Error, + > + where + D: serde::Deserializer<'de>, + T: 'de + serde::de::Visitor<'de> + crate::HasConstructor, + S: 'de + serde::de::Visitor<'de> + crate::HasConstructor, + >::Value: std::cmp::Eq + std::cmp::Ord, + { + deserializer.deserialize_any(MapVisitor::<'de, T, S> { + _map_type: || (std::marker::PhantomData, std::marker::PhantomData), + }) + } + + pub fn serialize( + value: &std::collections::BTreeMap< + ::Value, + ::Value, + >, + serializer: S, + ) -> Result + where + S: serde::Serializer, + F: crate::SerializeMethod, + G: crate::SerializeMethod, + ::Value: std::cmp::Eq + std::cmp::Ord, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(value.len()))?; + for (key, value) in value { + map.serialize_entry( + &crate::MySeType:: { val: key }, + &crate::MySeType:: { val: value }, + )?; + } + map.end() + } +} + pub trait SerializeMethod { type Value; fn serialize(value: &Self::Value, serializer: S) -> Result From 5f6c093be653f5cf9e4915edcf8a7000961edb67 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Sat, 13 Nov 2021 22:48:58 -0800 Subject: [PATCH 17/30] maps working correctly --- prost-build/src/code_generator.rs | 198 +++--------------------------- prost-derive/src/lib.rs | 94 ++++++++++++++ 2 files changed, 108 insertions(+), 184 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index a73aa32a4..3a73f4753 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -533,218 +533,48 @@ impl<'a> CodeGenerator<'a> { self.buf .push_str(r#"#[serde(skip_serializing_if = "::prost_types::is_default")]"#); self.buf.push('\n'); - // Add custom deserializers and optionally serializers for most primitive types // and their optional and repeated counterparts. - match (ty, optional, repeated) { - ("i32", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::i32_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("i32", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::i32_visitor::I32Visitor>")]"#, - ); - self.buf.push('\n'); - } - ("i32", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::i32_opt_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("bool", false, false) => { + match (get_custom_json_type_mappers(ty, optional, repeated), repeated) { + ((Some(se), Some(de)), false) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::bool_visitor::deserialize")]"#, + &format!(r#"#[serde(serialize_with = "{}")]"#, se), ); self.buf.push('\n'); - } - ("bool", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(deserialize_with = "::prost_types::bool_opt_visitor::deserialize")]"#); - self.buf.push('\n'); - } - ("bool", false, true) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::bool_visitor::BoolVisitor>")]"#, - ); - self.buf.push('\n'); - } - ("i64", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::i64_visitor::deserialize")]"#, + &format!(r#"#[serde(deserialize_with = "{}")]"#, de), ); self.buf.push('\n'); - } - ("i64", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::i64_opt_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("i64", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::i64_visitor::I64Visitor>")]"#, - ); - self.buf.push('\n'); - } - ("u32", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::u32_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("u32", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::u32_opt_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("u32", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::u32_visitor::U32Visitor>")]"#, - ); - self.buf.push('\n'); - } - ("u64", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::u64_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("u64", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::u64_opt_visitor::deserialize")]"#, - ); - self.buf.push('\n'); - } - ("u64", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::u64_visitor::U64Visitor>")]"#, - ); - self.buf.push('\n'); - } - ("f64", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(serialize_with = "<::prost_types::f64_visitor::F64Serializer as ::prost_types::SerializeMethod>::serialize")]"#); - self.buf.push('\n'); + }, + ((None, Some(de)), false) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::f64_visitor::deserialize")]"#, + &format!(r#"#[serde(deserialize_with = "{}")]"#, de), ); self.buf.push('\n'); - } - ("f64", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str( - r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::f64_visitor::F64Visitor>")]"#, - ); - self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(serialize_with = "::prost_types::repeated_visitor::serialize::<_, ::prost_types::f64_visitor::F64Serializer>")]"# - ); - self.buf.push('\n'); - } - ("f64", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(with = "::prost_types::f64_opt_visitor")]"#); - self.buf.push('\n'); - } - ("f32", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(serialize_with = "<::prost_types::f32_visitor::F32Serializer as ::prost_types::SerializeMethod>::serialize")]"#); - self.buf.push('\n'); + }, + ((Some(se), Some(de)), true) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::f32_visitor::deserialize")]"#, + &format!(r#"#[serde(serialize_with = "::prost_types::repeated_visitor::serialize::<_, {}>")]"#, se), ); self.buf.push('\n'); - } - ("f32", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str( - r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::f32_visitor::F32Visitor>")]"#, - ); - self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(serialize_with = "::prost_types::repeated_visitor::serialize::<_, ::prost_types::f32_visitor::F32Serializer>")]"# - ); - self.buf.push('\n'); - } - ("f32", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(with = "::prost_types::f32_opt_visitor")]"#); - self.buf.push('\n'); - } - ("::prost::alloc::string::String", false, false) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::string_visitor::deserialize")]"#, + &format!(r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, {}>")]"#, de), ); self.buf.push('\n'); } - ("::prost::alloc::string::String", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(deserialize_with = "::prost_types::string_opt_visitor::deserialize")]"#); - self.buf.push('\n'); - } - ("::prost::alloc::vec::Vec", false, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(serialize_with = "<::prost_types::vec_u8_visitor::VecU8Serializer as ::prost_types::SerializeMethod>::serialize")]"#); - self.buf.push('\n'); + ((None, Some(de)), true) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::vec_u8_visitor::deserialize")]"#, + &format!(r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, {}>")]"#, de), ); self.buf.push('\n'); } - ("::prost::alloc::vec::Vec", true, false) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(with = "::prost_types::vec_u8_opt_visitor")]"#); - self.buf.push('\n'); - } - ("::prost::alloc::vec::Vec", false, true) => { - push_indent(&mut self.buf, self.depth); - self.buf - .push_str( - r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, ::prost_types::vec_u8_visitor::VecU8Visitor>")]"#, - ); - self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(serialize_with = "::prost_types::repeated_visitor::serialize::<_, ::prost_types::vec_u8_visitor::VecU8Serializer>")]"# - ); - self.buf.push('\n'); - } - (_, _, true) => { + (_, true) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( r#"#[serde(deserialize_with = "::prost_types::vec_visitor::deserialize")]"#, diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 74608d56d..a6e37adf1 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -325,6 +325,100 @@ pub fn enumeration(input: TokenStream) -> TokenStream { try_enumeration(input).unwrap() } +fn try_json_enumeration(input: TokenStream) -> Result { + let input: DeriveInput = syn::parse(input)?; + let ident = input.ident; + + let generics = &input.generics; + let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); + + let punctuated_variants = match input.data { + Data::Enum(DataEnum { variants, .. }) => variants, + Data::Struct(_) => bail!("Enumeration can not be derived for a struct"), + Data::Union(..) => bail!("Enumeration can not be derived for a union"), + }; + + // Map the variants into 'fields'. + let mut variants: Vec<(Ident, Expr)> = Vec::new(); + for Variant { + ident, + fields, + discriminant, + .. + } in punctuated_variants + { + match fields { + Fields::Unit => (), + Fields::Named(_) | Fields::Unnamed(_) => { + bail!("Enumeration variants may not have fields") + } + } + + match discriminant { + Some((_, expr)) => variants.push((ident, expr)), + None => bail!("Enumeration variants must have a disriminant"), + } + } + + if variants.is_empty() { + panic!("Enumeration must have at least one variant"); + } + + let default = variants[0].0.clone(); + + let is_valid = variants + .iter() + .map(|&(_, ref value)| quote!(#value => true)); + let from = variants.iter().map( + |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)), + ); + + let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident); + let from_i32_doc = format!( + "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.", + ident + ); + + let expanded = quote! { + impl #impl_generics #ident #ty_generics #where_clause { + #[doc=#is_valid_doc] + pub fn is_valid(value: i32) -> bool { + match value { + #(#is_valid,)* + _ => false, + } + } + + #[doc=#from_i32_doc] + pub fn from_i32(value: i32) -> ::core::option::Option<#ident> { + match value { + #(#from,)* + _ => ::core::option::Option::None, + } + } + } + + impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause { + fn default() -> #ident { + #ident::#default + } + } + + impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause { + fn from(value: #ident) -> i32 { + value as i32 + } + } + }; + + Ok(expanded.into()) +} + +#[proc_macro_derive(JsonEnumeration, attributes(prost))] +pub fn json_enumeration(input: TokenStream) -> TokenStream { + try_json_enumeration(input).unwrap() +} + fn try_oneof(input: TokenStream) -> Result { let input: DeriveInput = syn::parse(input)?; From 372aa431822cd9f9d08fdd334e253ed9149ad4db Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Tue, 15 Feb 2022 10:09:44 -0800 Subject: [PATCH 18/30] updated w partial enums --- conformance/failing_tests.txt | 1 + prost-build/src/code_generator.rs | 4 +- prost-derive/src/field/mod.rs | 2 +- prost-derive/src/lib.rs | 133 ++++++++++-------------------- prost-types/src/compiler.rs | 2 + prost-types/src/lib.rs | 65 +++++++++++++++ prost-types/src/protobuf.rs | 59 +++++++++++++ 7 files changed, 173 insertions(+), 93 deletions(-) diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index 40b53f347..e45f5ac96 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -9,6 +9,7 @@ Recommended.Proto3.JsonInput.DurationHas6FractionalDigits.Validator Recommended.Proto3.JsonInput.DurationHas9FractionalDigits.Validator Recommended.Proto3.JsonInput.DurationHasZeroFractionalDigit.Validator Recommended.Proto3.JsonInput.Int64FieldBeString.Validator +Recommended.Proto3.JsonInput.MapFieldValueIsNull Recommended.Proto3.JsonInput.NullValueInOtherOneofNewFormat.Validator Recommended.Proto3.JsonInput.NullValueInOtherOneofOldFormat.Validator Recommended.Proto3.JsonInput.OneofZeroBytes.JsonOutput diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 3a73f4753..e9d822690 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -375,7 +375,6 @@ impl<'a> CodeGenerator<'a> { self.buf .push_str(&format!(r#"#[serde(alias = "{}")]"#, field_name,)); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); } fn append_json_map_field_attributes( @@ -970,6 +969,9 @@ impl<'a> CodeGenerator<'a> { self.append_doc(fq_enum_name, Some(value.name())); self.append_field_attributes(fq_enum_name, &value.name()); self.push_indent(); + self.buf.push_str(&format!(r#"#[prost(enum_field_name="{}")]"#, value.name())); + self.buf.push_str("\n"); + self.push_indent(); let name = to_upper_camel(value.name()); let name_unprefixed = match prefix_to_strip { Some(prefix) => strip_enum_prefix(&prefix, &name), diff --git a/prost-derive/src/field/mod.rs b/prost-derive/src/field/mod.rs index 09fef830e..537da18bf 100644 --- a/prost-derive/src/field/mod.rs +++ b/prost-derive/src/field/mod.rs @@ -224,7 +224,7 @@ impl fmt::Display for Label { } /// Get the items belonging to the 'prost' list attribute, e.g. `#[prost(foo, bar="baz")]`. -fn prost_attrs(attrs: Vec) -> Vec { +pub fn prost_attrs(attrs: Vec) -> Vec { attrs .iter() .flat_map(Attribute::parse_meta) diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index a6e37adf1..6d642be0a 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -246,10 +246,12 @@ fn try_enumeration(input: TokenStream) -> Result { // Map the variants into 'fields'. let mut variants: Vec<(Ident, Expr)> = Vec::new(); + let mut proto_names: Vec = Vec::new(); for Variant { ident, fields, discriminant, + attrs, .. } in punctuated_variants { @@ -262,11 +264,25 @@ fn try_enumeration(input: TokenStream) -> Result { match discriminant { Some((_, expr)) => variants.push((ident, expr)), - None => bail!("Enumeration variants must have a disriminant"), + None => bail!("Enumeration variants must have a discriminant"), + } + + let metas = crate::field::prost_attrs(attrs); + for meta in metas { + if let syn::Meta::NameValue(syn::MetaNameValue { path, lit, .. } ) = meta { + if path.is_ident("enum_field_name") { + if let syn::Lit::Str(lit_str) = lit { + proto_names.push(lit_str.value()); + break; + } + } + } } } - if variants.is_empty() { + // TODO(konradjniemiec): we need to default to not failing here, + // and instead deriving the proto names to avoid breaking changes. + if variants.is_empty() || variants.len() != proto_names.len() { panic!("Enumeration must have at least one variant"); } @@ -279,6 +295,11 @@ fn try_enumeration(input: TokenStream) -> Result { |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)), ); + let to_string = variants.iter().zip(proto_names.iter()).map(|(&(ref variant, _), proto_name)| quote!(#ident::#variant => #proto_name)); + assert!(to_string.len() > 0); + let from_string = variants.iter().zip(proto_names.iter()).map(|(&(ref variant, _), proto_name)| quote!(#proto_name => #ident::#variant)); + assert!(from_string.len() > 0); + let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident); let from_i32_doc = format!( "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.", @@ -315,98 +336,28 @@ fn try_enumeration(input: TokenStream) -> Result { value as i32 } } - }; - - Ok(expanded.into()) -} - -#[proc_macro_derive(Enumeration, attributes(prost))] -pub fn enumeration(input: TokenStream) -> TokenStream { - try_enumeration(input).unwrap() -} - -fn try_json_enumeration(input: TokenStream) -> Result { - let input: DeriveInput = syn::parse(input)?; - let ident = input.ident; - - let generics = &input.generics; - let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); - - let punctuated_variants = match input.data { - Data::Enum(DataEnum { variants, .. }) => variants, - Data::Struct(_) => bail!("Enumeration can not be derived for a struct"), - Data::Union(..) => bail!("Enumeration can not be derived for a union"), - }; - // Map the variants into 'fields'. - let mut variants: Vec<(Ident, Expr)> = Vec::new(); - for Variant { - ident, - fields, - discriminant, - .. - } in punctuated_variants - { - match fields { - Fields::Unit => (), - Fields::Named(_) | Fields::Unnamed(_) => { - bail!("Enumeration variants may not have fields") + impl #impl_generics ::core::convert::TryFrom:: for #ident #ty_generics #where_clause { + type Error = &'static str; + fn try_from(value: i32) -> Result { + Self::from_i32(value).ok_or_else(|| "invalid i32 value for enum") } } - match discriminant { - Some((_, expr)) => variants.push((ident, expr)), - None => bail!("Enumeration variants must have a disriminant"), - } - } - - if variants.is_empty() { - panic!("Enumeration must have at least one variant"); - } - - let default = variants[0].0.clone(); - - let is_valid = variants - .iter() - .map(|&(_, ref value)| quote!(#value => true)); - let from = variants.iter().map( - |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)), - ); - - let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident); - let from_i32_doc = format!( - "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.", - ident - ); - - let expanded = quote! { - impl #impl_generics #ident #ty_generics #where_clause { - #[doc=#is_valid_doc] - pub fn is_valid(value: i32) -> bool { - match value { - #(#is_valid,)* - _ => false, - } - } - - #[doc=#from_i32_doc] - pub fn from_i32(value: i32) -> ::core::option::Option<#ident> { - match value { - #(#from,)* - _ => ::core::option::Option::None, - } - } - } - - impl #impl_generics ::core::default::Default for #ident #ty_generics #where_clause { - fn default() -> #ident { - #ident::#default + impl #impl_generics ToString for #ident #ty_generics #where_clause { + fn to_string(&self) -> String { + match self { + #(#to_string,)* + }.to_string() } } - - impl #impl_generics ::core::convert::From::<#ident> for i32 #ty_generics #where_clause { - fn from(value: #ident) -> i32 { - value as i32 + impl #impl_generics ::core::str::FromStr for #ident #ty_generics #where_clause { + type Err = &'static str; + fn from_str(value: &str) -> Result { + Ok(match value { + #(#from_string,)* + _ => Self::default(), + }) } } }; @@ -414,9 +365,9 @@ fn try_json_enumeration(input: TokenStream) -> Result { Ok(expanded.into()) } -#[proc_macro_derive(JsonEnumeration, attributes(prost))] -pub fn json_enumeration(input: TokenStream) -> TokenStream { - try_json_enumeration(input).unwrap() +#[proc_macro_derive(Enumeration, attributes(prost))] +pub fn enumeration(input: TokenStream) -> TokenStream { + try_enumeration(input).unwrap() } fn try_oneof(input: TokenStream) -> Result { diff --git a/prost-types/src/compiler.rs b/prost-types/src/compiler.rs index da30df775..627c52261 100644 --- a/prost-types/src/compiler.rs +++ b/prost-types/src/compiler.rs @@ -133,7 +133,9 @@ pub mod code_generator_response { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum Feature { + #[prost(enum_field_name="FEATURE_NONE")] None = 0, + #[prost(enum_field_name="FEATURE_PROTO3_OPTIONAL")] Proto3Optional = 1, } } diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 85af2eb19..7729be44d 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -443,6 +443,71 @@ pub mod repeated_visitor { } } +pub mod enum_visitor { + struct EnumVisitor<'de, T> + where + T: ToString + std::str::FromStr + std::convert::Into + std::convert::TryFrom + Default, + { + _type: &'de std::marker::PhantomData, + } + + #[cfg(feature = "std")] + impl<'de, T> serde::de::Visitor<'de> for EnumVisitor<'de, T> + where + T: ToString + std::str::FromStr + std::convert::Into + std::convert::TryFrom + Default, + { + type Value = i32; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid String string or integer") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match T::from_str(value) { + Ok(en) => Ok(en.into()), + Err(_) => Err(serde::de::Error::invalid_value(serde::de::Unexpected::Str(value), &self)), + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T>( + deserializer: D, + ) -> Result + where + D: serde::Deserializer<'de>, + T: 'de + ToString + std::str::FromStr + std::convert::Into + std::convert::TryFrom + Default, + { + deserializer.deserialize_any(EnumVisitor::<'de, T> { + _type: &std::marker::PhantomData, + }) + } + + pub fn serialize( + value: &i32, + serializer: S, + ) -> Result + where + S: serde::Serializer, + T: ToString + std::str::FromStr + std::convert::Into + std::convert::TryFrom + Default, + { + match T::try_from(*value) { + Err(_) => Err(serde::ser::Error::custom("invalid enum value")), + Ok(t) => serializer.serialize_str(&t.to_string()) + } + } +} + pub mod map_custom_serializer { pub fn serialize( value: &std::collections::HashMap::Value>, diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index 5db743068..2afaeda0c 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -178,43 +178,64 @@ pub mod field_descriptor_proto { pub enum Type { /// 0 is reserved for errors. /// Order is weird for historical reasons. + #[prost(enum_field_name="TYPE_DOUBLE")] Double = 1, + #[prost(enum_field_name="TYPE_FLOAT")] Float = 2, /// Not ZigZag encoded. Negative numbers take 10 bytes. Use TYPE_SINT64 if /// negative values are likely. + #[prost(enum_field_name="TYPE_INT64")] Int64 = 3, + #[prost(enum_field_name="TYPE_UINT64")] Uint64 = 4, /// Not ZigZag encoded. Negative numbers take 10 bytes. Use TYPE_SINT32 if /// negative values are likely. + #[prost(enum_field_name="TYPE_INT32")] Int32 = 5, + #[prost(enum_field_name="TYPE_FIXED64")] Fixed64 = 6, + #[prost(enum_field_name="TYPE_FIXED32")] Fixed32 = 7, + #[prost(enum_field_name="TYPE_BOOL")] Bool = 8, + #[prost(enum_field_name="TYPE_STRING")] String = 9, /// Tag-delimited aggregate. /// Group type is deprecated and not supported in proto3. However, Proto3 /// implementations should still be able to parse the group wire format and /// treat group fields as unknown fields. + #[prost(enum_field_name="TYPE_GROUP")] Group = 10, /// Length-delimited aggregate. + #[prost(enum_field_name="TYPE_MESSAGE")] Message = 11, /// New in version 2. + #[prost(enum_field_name="TYPE_BYTES")] Bytes = 12, + #[prost(enum_field_name="TYPE_UINT32")] Uint32 = 13, + #[prost(enum_field_name="TYPE_ENUM")] Enum = 14, + #[prost(enum_field_name="TYPE_SFIXED32")] Sfixed32 = 15, + #[prost(enum_field_name="TYPE_SFIXED64")] Sfixed64 = 16, /// Uses ZigZag encoding. + #[prost(enum_field_name="TYPE_SINT32")] Sint32 = 17, /// Uses ZigZag encoding. + #[prost(enum_field_name="TYPE_SINT64")] Sint64 = 18, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum Label { /// 0 is reserved for errors + #[prost(enum_field_name="LABEL_OPTIONAL")] Optional = 1, + #[prost(enum_field_name="LABEL_REQUIRED")] Required = 2, + #[prost(enum_field_name="LABEL_REPEATED")] Repeated = 3, } } @@ -451,12 +472,15 @@ pub mod file_options { #[repr(i32)] pub enum OptimizeMode { /// Generate complete code for parsing, serialization, + #[prost(enum_field_name="SPEED")] Speed = 1, /// etc. /// /// Use ReflectionOps to implement these methods. + #[prost(enum_field_name="CODE_SIZE")] CodeSize = 2, /// Generate code using MessageLite and the lite runtime. + #[prost(enum_field_name="LITE_RUNTIME")] LiteRuntime = 3, } } @@ -597,18 +621,24 @@ pub mod field_options { #[repr(i32)] pub enum CType { /// Default mode. + #[prost(enum_field_name="STRING")] String = 0, + #[prost(enum_field_name="CORD")] Cord = 1, + #[prost(enum_field_name="STRING_PIECE")] StringPiece = 2, } #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum JsType { /// Use the default type. + #[prost(enum_field_name="JS_NORMAL")] JsNormal = 0, /// Use JavaScript strings. + #[prost(enum_field_name="JS_STRING")] JsString = 1, /// Use JavaScript numbers. + #[prost(enum_field_name="JS_NUMBER")] JsNumber = 2, } } @@ -690,10 +720,13 @@ pub mod method_options { #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, PartialOrd, Ord, ::prost::Enumeration)] #[repr(i32)] pub enum IdempotencyLevel { + #[prost(enum_field_name="IDEMPOTENCY_UNKNOWN")] IdempotencyUnknown = 0, /// implies idempotent + #[prost(enum_field_name="NO_SIDE_EFFECTS")] NoSideEffects = 1, /// idempotent, but may have side effects + #[prost(enum_field_name="IDEMPOTENT")] Idempotent = 2, } } @@ -1107,42 +1140,61 @@ pub mod field { #[repr(i32)] pub enum Kind { /// Field type unknown. + #[prost(enum_field_name="TYPE_UNKNOWN")] TypeUnknown = 0, /// Field type double. + #[prost(enum_field_name="TYPE_DOUBLE")] TypeDouble = 1, /// Field type float. + #[prost(enum_field_name="TYPE_FLOAT")] TypeFloat = 2, /// Field type int64. + #[prost(enum_field_name="TYPE_INT64")] TypeInt64 = 3, /// Field type uint64. + #[prost(enum_field_name="TYPE_UINT64")] TypeUint64 = 4, /// Field type int32. + #[prost(enum_field_name="TYPE_INT32")] TypeInt32 = 5, /// Field type fixed64. + #[prost(enum_field_name="TYPE_FIXED64")] TypeFixed64 = 6, /// Field type fixed32. + #[prost(enum_field_name="TYPE_FIXED32")] TypeFixed32 = 7, /// Field type bool. + #[prost(enum_field_name="TYPE_BOOL")] TypeBool = 8, /// Field type string. + #[prost(enum_field_name="TYPE_STRING")] TypeString = 9, /// Field type group. Proto2 syntax only, and deprecated. + #[prost(enum_field_name="TYPE_GROUP")] TypeGroup = 10, /// Field type message. + #[prost(enum_field_name="TYPE_MESSAGE")] TypeMessage = 11, /// Field type bytes. + #[prost(enum_field_name="TYPE_BYTES")] TypeBytes = 12, /// Field type uint32. + #[prost(enum_field_name="TYPE_UINT32")] TypeUint32 = 13, /// Field type enum. + #[prost(enum_field_name="TYPE_ENUM")] TypeEnum = 14, /// Field type sfixed32. + #[prost(enum_field_name="TYPE_SFIXED32")] TypeSfixed32 = 15, /// Field type sfixed64. + #[prost(enum_field_name="TYPE_SFIXED64")] TypeSfixed64 = 16, /// Field type sint32. + #[prost(enum_field_name="TYPE_SINT32")] TypeSint32 = 17, /// Field type sint64. + #[prost(enum_field_name="TYPE_SINT64")] TypeSint64 = 18, } /// Whether a field is optional, required, or repeated. @@ -1150,12 +1202,16 @@ pub mod field { #[repr(i32)] pub enum Cardinality { /// For fields with unknown cardinality. + #[prost(enum_field_name="CARDINALITY_UNKNOWN")] Unknown = 0, /// For optional fields. + #[prost(enum_field_name="CARDINALITY_OPTIONAL")] Optional = 1, /// For required fields. Proto2 syntax only. + #[prost(enum_field_name="CARDINALITY_REQUIRED")] Required = 2, /// For repeated fields. + #[prost(enum_field_name="CARDINALITY_REPEATED")] Repeated = 3, } } @@ -1213,8 +1269,10 @@ pub struct Option { #[repr(i32)] pub enum Syntax { /// Syntax `proto2`. + #[prost(enum_field_name="SYNTAX_PROTO2")] Proto2 = 0, /// Syntax `proto3`. + #[prost(enum_field_name="SYNTAX_PROTO3")] Proto3 = 1, } /// Api is a light-weight descriptor for an API Interface. @@ -1736,6 +1794,7 @@ pub struct ListValue { #[repr(i32)] pub enum NullValue { /// Null value. + #[prost(enum_field_name="NULL_VALUE")] NullValue = 0, } /// A Timestamp represents a point in time independent of any time zone or local From ac9adfd7f7db67b97aed3f64da11ebe0c1a3f40a Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Tue, 15 Feb 2022 10:12:42 -0800 Subject: [PATCH 19/30] fmt --- prost-build/src/code_generator.rs | 27 ++++++++++---------- prost-derive/src/lib.rs | 16 ++++++++---- prost-types/src/lib.rs | 41 +++++++++++++++++++++---------- 3 files changed, 53 insertions(+), 31 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index e9d822690..6735b75e1 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -534,26 +534,26 @@ impl<'a> CodeGenerator<'a> { self.buf.push('\n'); // Add custom deserializers and optionally serializers for most primitive types // and their optional and repeated counterparts. - match (get_custom_json_type_mappers(ty, optional, repeated), repeated) { + match ( + get_custom_json_type_mappers(ty, optional, repeated), + repeated, + ) { ((Some(se), Some(de)), false) => { push_indent(&mut self.buf, self.depth); - self.buf.push_str( - &format!(r#"#[serde(serialize_with = "{}")]"#, se), - ); + self.buf + .push_str(&format!(r#"#[serde(serialize_with = "{}")]"#, se)); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); - self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "{}")]"#, de), - ); + self.buf + .push_str(&format!(r#"#[serde(deserialize_with = "{}")]"#, de)); self.buf.push('\n'); - }, + } ((None, Some(de)), false) => { push_indent(&mut self.buf, self.depth); - self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "{}")]"#, de), - ); + self.buf + .push_str(&format!(r#"#[serde(deserialize_with = "{}")]"#, de)); self.buf.push('\n'); - }, + } ((Some(se), Some(de)), true) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( @@ -969,7 +969,8 @@ impl<'a> CodeGenerator<'a> { self.append_doc(fq_enum_name, Some(value.name())); self.append_field_attributes(fq_enum_name, &value.name()); self.push_indent(); - self.buf.push_str(&format!(r#"#[prost(enum_field_name="{}")]"#, value.name())); + self.buf + .push_str(&format!(r#"#[prost(enum_field_name="{}")]"#, value.name())); self.buf.push_str("\n"); self.push_indent(); let name = to_upper_camel(value.name()); diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 6d642be0a..8844ffa82 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -266,10 +266,10 @@ fn try_enumeration(input: TokenStream) -> Result { Some((_, expr)) => variants.push((ident, expr)), None => bail!("Enumeration variants must have a discriminant"), } - + let metas = crate::field::prost_attrs(attrs); for meta in metas { - if let syn::Meta::NameValue(syn::MetaNameValue { path, lit, .. } ) = meta { + if let syn::Meta::NameValue(syn::MetaNameValue { path, lit, .. }) = meta { if path.is_ident("enum_field_name") { if let syn::Lit::Str(lit_str) = lit { proto_names.push(lit_str.value()); @@ -295,11 +295,17 @@ fn try_enumeration(input: TokenStream) -> Result { |&(ref variant, ref value)| quote!(#value => ::core::option::Option::Some(#ident::#variant)), ); - let to_string = variants.iter().zip(proto_names.iter()).map(|(&(ref variant, _), proto_name)| quote!(#ident::#variant => #proto_name)); + let to_string = variants + .iter() + .zip(proto_names.iter()) + .map(|(&(ref variant, _), proto_name)| quote!(#ident::#variant => #proto_name)); assert!(to_string.len() > 0); - let from_string = variants.iter().zip(proto_names.iter()).map(|(&(ref variant, _), proto_name)| quote!(#proto_name => #ident::#variant)); + let from_string = variants + .iter() + .zip(proto_names.iter()) + .map(|(&(ref variant, _), proto_name)| quote!(#proto_name => #ident::#variant)); assert!(from_string.len() > 0); - + let is_valid_doc = format!("Returns `true` if `value` is a variant of `{}`.", ident); let from_i32_doc = format!( "Converts an `i32` to a `{}`, or `None` if `value` is not a valid variant.", diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 7729be44d..754d337f3 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -446,7 +446,11 @@ pub mod repeated_visitor { pub mod enum_visitor { struct EnumVisitor<'de, T> where - T: ToString + std::str::FromStr + std::convert::Into + std::convert::TryFrom + Default, + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, { _type: &'de std::marker::PhantomData, } @@ -454,7 +458,11 @@ pub mod enum_visitor { #[cfg(feature = "std")] impl<'de, T> serde::de::Visitor<'de> for EnumVisitor<'de, T> where - T: ToString + std::str::FromStr + std::convert::Into + std::convert::TryFrom + Default, + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, { type Value = i32; @@ -468,7 +476,10 @@ pub mod enum_visitor { { match T::from_str(value) { Ok(en) => Ok(en.into()), - Err(_) => Err(serde::de::Error::invalid_value(serde::de::Unexpected::Str(value), &self)), + Err(_) => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(value), + &self, + )), } } @@ -481,29 +492,33 @@ pub mod enum_visitor { } #[cfg(feature = "std")] - pub fn deserialize<'de, D, T>( - deserializer: D, - ) -> Result + pub fn deserialize<'de, D, T>(deserializer: D) -> Result where D: serde::Deserializer<'de>, - T: 'de + ToString + std::str::FromStr + std::convert::Into + std::convert::TryFrom + Default, + T: 'de + + ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, { deserializer.deserialize_any(EnumVisitor::<'de, T> { _type: &std::marker::PhantomData, }) } - pub fn serialize( - value: &i32, - serializer: S, - ) -> Result + pub fn serialize(value: &i32, serializer: S) -> Result where S: serde::Serializer, - T: ToString + std::str::FromStr + std::convert::Into + std::convert::TryFrom + Default, + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, { match T::try_from(*value) { Err(_) => Err(serde::ser::Error::custom("invalid enum value")), - Ok(t) => serializer.serialize_str(&t.to_string()) + Ok(t) => serializer.serialize_str(&t.to_string()), } } } From 9ef2b423a9484d77b848ff6dddb2e59044a754ec Mon Sep 17 00:00:00 2001 From: Mohamed Yassin Date: Wed, 16 Feb 2022 19:14:12 -0500 Subject: [PATCH 20/30] Added enum visitor and enum opt visitor --- .vscode/launch.json | 291 ++++++++++++++++++++++++++++++ conformance/failing_tests.txt | 134 +------------- conformance/succeeding_tests.txt | 3 + output.log | 55 ++++++ prost-build/src/code_generator.rs | 118 +++++++----- prost-derive/src/lib.rs | 5 + prost-types/src/lib.rs | 138 ++++++++++++++ 7 files changed, 569 insertions(+), 175 deletions(-) create mode 100644 .vscode/launch.json create mode 100644 conformance/succeeding_tests.txt create mode 100644 output.log diff --git a/.vscode/launch.json b/.vscode/launch.json new file mode 100644 index 000000000..e481d0956 --- /dev/null +++ b/.vscode/launch.json @@ -0,0 +1,291 @@ +{ + // Use IntelliSense to learn about possible attributes. + // Hover to view descriptions of existing attributes. + // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 + "version": "0.2.0", + "configurations": [ + { + "type": "lldb", + "request": "launch", + "name": "Debug executable 'conformance'", + "cargo": { + "args": [ + "build", + "--bin=conformance", + "--package=conformance" + ], + "filter": { + "name": "conformance", + "kind": "bin" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in executable 'conformance'", + "cargo": { + "args": [ + "test", + "--no-run", + "--bin=conformance", + "--package=conformance" + ], + "filter": { + "name": "conformance", + "kind": "bin" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug integration test 'conformance'", + "cargo": { + "args": [ + "test", + "--no-run", + "--test=conformance", + "--package=conformance" + ], + "filter": { + "name": "conformance", + "kind": "test" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'prost'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=prost" + ], + "filter": { + "name": "prost", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug benchmark 'varint'", + "cargo": { + "args": [ + "test", + "--no-run", + "--bench=varint", + "--package=prost" + ], + "filter": { + "name": "varint", + "kind": "bench" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'protobuf'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=protobuf" + ], + "filter": { + "name": "protobuf", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug benchmark 'dataset'", + "cargo": { + "args": [ + "test", + "--no-run", + "--bench=dataset", + "--package=protobuf" + ], + "filter": { + "name": "dataset", + "kind": "bench" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'prost-types'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=prost-types" + ], + "filter": { + "name": "prost-types", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'prost-build'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=prost-build" + ], + "filter": { + "name": "prost-build", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'tests'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=tests" + ], + "filter": { + "name": "tests", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'tests-2015'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=tests-2015" + ], + "filter": { + "name": "tests-2015", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'tests-no-std'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=tests-no-std" + ], + "filter": { + "name": "tests-no-std", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in library 'single_include'", + "cargo": { + "args": [ + "test", + "--no-run", + "--lib", + "--package=single_include" + ], + "filter": { + "name": "single_include", + "kind": "lib" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug executable 'single_include'", + "cargo": { + "args": [ + "build", + "--bin=single_include", + "--package=single_include" + ], + "filter": { + "name": "single_include", + "kind": "bin" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + }, + { + "type": "lldb", + "request": "launch", + "name": "Debug unit tests in executable 'single_include'", + "cargo": { + "args": [ + "test", + "--no-run", + "--bin=single_include", + "--package=single_include" + ], + "filter": { + "name": "single_include", + "kind": "bin" + } + }, + "args": [], + "cwd": "${workspaceFolder}" + } + ] +} \ No newline at end of file diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index e45f5ac96..bca91240e 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -1,129 +1,5 @@ -Recommended.FieldMaskNumbersDontRoundTrip.JsonOutput -Recommended.FieldMaskPathsDontRoundTrip.JsonOutput -Recommended.FieldMaskTooManyUnderscore.JsonOutput -Recommended.Proto2.JsonInput.FieldNameExtension.Validator -Recommended.Proto3.JsonInput.BytesFieldBase64Url.JsonOutput -Recommended.Proto3.JsonInput.BytesFieldBase64Url.ProtobufOutput -Recommended.Proto3.JsonInput.DurationHas3FractionalDigits.Validator -Recommended.Proto3.JsonInput.DurationHas6FractionalDigits.Validator -Recommended.Proto3.JsonInput.DurationHas9FractionalDigits.Validator -Recommended.Proto3.JsonInput.DurationHasZeroFractionalDigit.Validator -Recommended.Proto3.JsonInput.Int64FieldBeString.Validator -Recommended.Proto3.JsonInput.MapFieldValueIsNull -Recommended.Proto3.JsonInput.NullValueInOtherOneofNewFormat.Validator -Recommended.Proto3.JsonInput.NullValueInOtherOneofOldFormat.Validator -Recommended.Proto3.JsonInput.OneofZeroBytes.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroBytes.ProtobufOutput -Recommended.Proto3.JsonInput.OneofZeroEnum.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroEnum.ProtobufOutput -Recommended.Proto3.JsonInput.RepeatedFieldPrimitiveElementIsNull -Recommended.Proto3.JsonInput.TimestampHas3FractionalDigits.Validator -Recommended.Proto3.JsonInput.TimestampHas6FractionalDigits.Validator -Recommended.Proto3.JsonInput.TimestampZeroNormalized.Validator -Recommended.Proto3.JsonInput.Uint64FieldBeString.Validator -Recommended.Proto3.ProtobufInput.OneofZeroBytes.JsonOutput -Required.DurationProtoInputTooLarge.JsonOutput -Required.DurationProtoInputTooSmall.JsonOutput -Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput -Required.Proto3.JsonInput.Any.JsonOutput -Required.Proto3.JsonInput.Any.ProtobufOutput -Required.Proto3.JsonInput.AnyNested.JsonOutput -Required.Proto3.JsonInput.AnyNested.ProtobufOutput -Required.Proto3.JsonInput.AnyUnorderedTypeTag.JsonOutput -Required.Proto3.JsonInput.AnyUnorderedTypeTag.ProtobufOutput -Required.Proto3.JsonInput.AnyWithDuration.JsonOutput -Required.Proto3.JsonInput.AnyWithDuration.ProtobufOutput -Required.Proto3.JsonInput.AnyWithFieldMask.JsonOutput -Required.Proto3.JsonInput.AnyWithFieldMask.ProtobufOutput -Required.Proto3.JsonInput.AnyWithInt32ValueWrapper.JsonOutput -Required.Proto3.JsonInput.AnyWithInt32ValueWrapper.ProtobufOutput -Required.Proto3.JsonInput.AnyWithStruct.JsonOutput -Required.Proto3.JsonInput.AnyWithStruct.ProtobufOutput -Required.Proto3.JsonInput.AnyWithTimestamp.JsonOutput -Required.Proto3.JsonInput.AnyWithTimestamp.ProtobufOutput -Required.Proto3.JsonInput.AnyWithValueForInteger.JsonOutput -Required.Proto3.JsonInput.AnyWithValueForInteger.ProtobufOutput -Required.Proto3.JsonInput.AnyWithValueForJsonObject.JsonOutput -Required.Proto3.JsonInput.AnyWithValueForJsonObject.ProtobufOutput -Required.Proto3.JsonInput.BoolMapEscapedKey.JsonOutput -Required.Proto3.JsonInput.BoolMapEscapedKey.ProtobufOutput -Required.Proto3.JsonInput.BoolMapField.JsonOutput -Required.Proto3.JsonInput.BoolMapField.ProtobufOutput -Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.JsonOutput -Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.ProtobufOutput -Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.JsonOutput -Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.ProtobufOutput -Required.Proto3.JsonInput.DurationMaxValue.JsonOutput -Required.Proto3.JsonInput.DurationMaxValue.ProtobufOutput -Required.Proto3.JsonInput.DurationMinValue.JsonOutput -Required.Proto3.JsonInput.DurationMinValue.ProtobufOutput -Required.Proto3.JsonInput.DurationRepeatedValue.JsonOutput -Required.Proto3.JsonInput.DurationRepeatedValue.ProtobufOutput -Required.Proto3.JsonInput.EmptyFieldMask.JsonOutput -Required.Proto3.JsonInput.EmptyFieldMask.ProtobufOutput -Required.Proto3.JsonInput.EnumField.JsonOutput -Required.Proto3.JsonInput.EnumField.ProtobufOutput -Required.Proto3.JsonInput.EnumFieldWithAlias.JsonOutput -Required.Proto3.JsonInput.EnumFieldWithAlias.ProtobufOutput -Required.Proto3.JsonInput.EnumFieldWithAliasDifferentCase.JsonOutput -Required.Proto3.JsonInput.EnumFieldWithAliasDifferentCase.ProtobufOutput -Required.Proto3.JsonInput.EnumFieldWithAliasLowerCase.JsonOutput -Required.Proto3.JsonInput.EnumFieldWithAliasLowerCase.ProtobufOutput -Required.Proto3.JsonInput.EnumFieldWithAliasUseAlias.JsonOutput -Required.Proto3.JsonInput.EnumFieldWithAliasUseAlias.ProtobufOutput -Required.Proto3.JsonInput.EnumRepeatedField.JsonOutput -Required.Proto3.JsonInput.EnumRepeatedField.ProtobufOutput -Required.Proto3.JsonInput.FieldMask.JsonOutput -Required.Proto3.JsonInput.FieldMask.ProtobufOutput -Required.Proto3.JsonInput.OneofFieldDuplicate -Required.Proto3.JsonInput.RepeatedListValue.JsonOutput -Required.Proto3.JsonInput.RepeatedListValue.ProtobufOutput -Required.Proto3.JsonInput.RepeatedValue.JsonOutput -Required.Proto3.JsonInput.RepeatedValue.ProtobufOutput -Required.Proto3.JsonInput.Struct.JsonOutput -Required.Proto3.JsonInput.Struct.ProtobufOutput -Required.Proto3.JsonInput.StructWithEmptyListValue.JsonOutput -Required.Proto3.JsonInput.StructWithEmptyListValue.ProtobufOutput -Required.Proto3.JsonInput.TimestampMinValue.JsonOutput -Required.Proto3.JsonInput.TimestampMinValue.ProtobufOutput -Required.Proto3.JsonInput.TimestampRepeatedValue.JsonOutput -Required.Proto3.JsonInput.TimestampRepeatedValue.ProtobufOutput -Required.Proto3.JsonInput.TimestampWithNegativeOffset.JsonOutput -Required.Proto3.JsonInput.TimestampWithNegativeOffset.ProtobufOutput -Required.Proto3.JsonInput.TimestampWithPositiveOffset.JsonOutput -Required.Proto3.JsonInput.TimestampWithPositiveOffset.ProtobufOutput -Required.Proto3.JsonInput.ValueAcceptBool.JsonOutput -Required.Proto3.JsonInput.ValueAcceptBool.ProtobufOutput -Required.Proto3.JsonInput.ValueAcceptFloat.JsonOutput -Required.Proto3.JsonInput.ValueAcceptFloat.ProtobufOutput -Required.Proto3.JsonInput.ValueAcceptInteger.JsonOutput -Required.Proto3.JsonInput.ValueAcceptInteger.ProtobufOutput -Required.Proto3.JsonInput.ValueAcceptList.JsonOutput -Required.Proto3.JsonInput.ValueAcceptList.ProtobufOutput -Required.Proto3.JsonInput.ValueAcceptNull.JsonOutput -Required.Proto3.JsonInput.ValueAcceptNull.ProtobufOutput -Required.Proto3.JsonInput.ValueAcceptObject.JsonOutput -Required.Proto3.JsonInput.ValueAcceptObject.ProtobufOutput -Required.Proto3.JsonInput.ValueAcceptString.JsonOutput -Required.Proto3.JsonInput.ValueAcceptString.ProtobufOutput -Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.Default.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateKey.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateKeyInMapEntry.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateValueInMapEntry.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.MissingDefault.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.NonDefault.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.Unordered.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.Default.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateKey.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateKeyInMapEntry.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateValueInMapEntry.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.MissingDefault.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.NonDefault.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.Unordered.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.DefaultValue.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.MultipleValuesForDifferentField.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.MultipleValuesForSameField.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.NonDefaultValue.JsonOutput -Required.TimestampProtoInputTooLarge.JsonOutput -Required.TimestampProtoInputTooSmall.JsonOutput +Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.JsonOutput +Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.ProtobufOutput +Required.Proto3.JsonInput.EnumFieldNumericValueZero.JsonOutput +Required.Proto3.JsonInput.EnumFieldNumericValueZero.ProtobufOutput +Required.Proto3.JsonInput.EnumFieldUnknownValue.Validator diff --git a/conformance/succeeding_tests.txt b/conformance/succeeding_tests.txt new file mode 100644 index 000000000..a9fc4cc37 --- /dev/null +++ b/conformance/succeeding_tests.txt @@ -0,0 +1,3 @@ +Required.Proto3.JsonInput.EnumField.ProtobufOutput +Required.Proto3.JsonInput.EnumFieldWithAlias.JsonOutput +Required.Proto3.JsonInput.EnumFieldWithAlias.ProtobufOutput diff --git a/output.log b/output.log new file mode 100644 index 000000000..8567e88b3 --- /dev/null +++ b/output.log @@ -0,0 +1,55 @@ + Finished test [unoptimized + debuginfo] target(s) in 0.06s + Running unittests (target/debug/deps/conformance-96effac175dc0f5f) + +running 0 tests + +test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s + + Running tests/conformance.rs (target/debug/deps/conformance-9fdcb6d434b04b3a) + +running 1 test +thread 'main' panicked at 'all times should be after the epoch: SystemTimeError(62135596801s)', /Users/myassin/.cargo/registry/src/github.com-1ecc6299db9ec823/humantime-2.1.0/src/date.rs:255:14 +note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace +[libprotobuf ERROR /Users/myassin/Documents/prost/target/debug/build/protobuf-f93ae7e45fc7e3e4/out/protobufxBxnmy/protobuf-3.14.0/conformance/conformance_test_runner.cc:322] Required.TimestampProtoInputTooSmall.JsonOutput: unexpected EOF from test program +[libprotobuf INFO /Users/myassin/Documents/prost/target/debug/build/protobuf-f93ae7e45fc7e3e4/out/protobufxBxnmy/protobuf-3.14.0/conformance/conformance_test_runner.cc:163] Trying to reap child, pid=72362 +[libprotobuf INFO /Users/myassin/Documents/prost/target/debug/build/protobuf-f93ae7e45fc7e3e4/out/protobufxBxnmy/protobuf-3.14.0/conformance/conformance_test_runner.cc:176] child killed by signal 1 +thread 'main' panicked at 'a Display implementation returned an error unexpectedly: Error', /rustc/db9d1b20bba1968c1ec1fc49616d4742c1725b4b/library/alloc/src/string.rs:2401:14 +note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace +[libprotobuf ERROR /Users/myassin/Documents/prost/target/debug/build/protobuf-f93ae7e45fc7e3e4/out/protobufxBxnmy/protobuf-3.14.0/conformance/conformance_test_runner.cc:322] Required.TimestampProtoInputTooLarge.JsonOutput: unexpected EOF from test program +[libprotobuf INFO /Users/myassin/Documents/prost/target/debug/build/protobuf-f93ae7e45fc7e3e4/out/protobufxBxnmy/protobuf-3.14.0/conformance/conformance_test_runner.cc:163] Trying to reap child, pid=72365 +[libprotobuf INFO /Users/myassin/Documents/prost/target/debug/build/protobuf-f93ae7e45fc7e3e4/out/protobufxBxnmy/protobuf-3.14.0/conformance/conformance_test_runner.cc:176] child killed by signal 1 + +CONFORMANCE TEST BEGIN ==================================== + +ERROR, test=Required.Proto3.JsonInput.EnumFieldNumericValueZero.ProtobufOutput: Failed to parse input or produce output. request=json_payload: "{\"optionalNestedEnum\": 0}" requested_output_format: PROTOBUF message_type: "protobuf_test_messages.proto3.TestAllTypesProto3" test_category: JSON_TEST, response=parse_error: "error deserializing json: optionalNestedEnum: invalid type: integer `0`, expected a valid String string or integer at line 1 column 24 at optionalNestedEnum" +ERROR, test=Required.Proto3.JsonInput.EnumFieldNumericValueZero.JsonOutput: Failed to parse input or produce output. request=json_payload: "{\"optionalNestedEnum\": 0}" requested_output_format: JSON message_type: "protobuf_test_messages.proto3.TestAllTypesProto3" test_category: JSON_TEST, response=parse_error: "error deserializing json: optionalNestedEnum: invalid type: integer `0`, expected a valid String string or integer at line 1 column 24 at optionalNestedEnum" +ERROR, test=Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.ProtobufOutput: Failed to parse input or produce output. request=json_payload: "{\"optionalNestedEnum\": 1}" requested_output_format: PROTOBUF message_type: "protobuf_test_messages.proto3.TestAllTypesProto3" test_category: JSON_TEST, response=parse_error: "error deserializing json: optionalNestedEnum: invalid type: integer `1`, expected a valid String string or integer at line 1 column 24 at optionalNestedEnum" +ERROR, test=Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.JsonOutput: Failed to parse input or produce output. request=json_payload: "{\"optionalNestedEnum\": 1}" requested_output_format: JSON message_type: "protobuf_test_messages.proto3.TestAllTypesProto3" test_category: JSON_TEST, response=parse_error: "error deserializing json: optionalNestedEnum: invalid type: integer `1`, expected a valid String string or integer at line 1 column 24 at optionalNestedEnum" +ERROR, test=Required.Proto3.JsonInput.EnumFieldUnknownValue.Validator: Expected JSON payload but got type 1. request=json_payload: "{\"optionalNestedEnum\": 123}" requested_output_format: JSON message_type: "protobuf_test_messages.proto3.TestAllTypesProto3" test_category: JSON_TEST, response=parse_error: "error deserializing json: optionalNestedEnum: invalid type: integer `123`, expected a valid String string or integer at line 1 column 26 at optionalNestedEnum" + +These tests failed. If they can't be fixed right now, you can add them to the failure list so the overall suite can succeed. Add them to the failure list by running: + ./update_failure_list.py failing_tests.txt --add failing_tests.txt + + Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.JsonOutput + Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.ProtobufOutput + Required.Proto3.JsonInput.EnumFieldNumericValueZero.JsonOutput + Required.Proto3.JsonInput.EnumFieldNumericValueZero.ProtobufOutput + Required.Proto3.JsonInput.EnumFieldUnknownValue.Validator + +CONFORMANCE SUITE FAILED: 1883 successes, 0 skipped, 125 expected failures, 5 unexpected failures. + +test test_conformance ... FAILED + +failures: + +---- test_conformance stdout ---- +thread 'test_conformance' panicked at 'proto conformance test failed', conformance/tests/conformance.rs:32:5 +note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace + + +failures: + test_conformance + +test result: FAILED. 0 passed; 1 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.29s + +error: test failed, to rerun pass '-p conformance --test conformance' diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 6735b75e1..aec2e0b81 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -44,49 +44,54 @@ fn push_indent(buf: &mut String, depth: u8) { } } -/// Returns (serializer, deserializer) function names to use in serde -/// serialize_with and deserialize_with macros respectively. If none are -/// specified, the default works fine. -/// If collection is true, the return is no longer the function, but instead, -/// the Visitor type that will be used for either the repeated helper or -/// custom map helper. -fn get_custom_json_type_mappers( - ty: &str, - optional: bool, - collection: bool, -) -> (Option<&str>, Option<&str>) { - match (ty, optional, collection) { - ("bool", false, false) => (None, Some("::prost_types::bool_visitor::deserialize")), - ("bool", true, false) => (None, Some("::prost_types::bool_opt_visitor::deserialize")), - ("bool", false, true) => (None, Some("::prost_types::bool_visitor::BoolVisitor")), - ("i32", false, false) => (None, Some("::prost_types::i32_visitor::deserialize")), - ("i32", true, false) => (None, Some("::prost_types::i32_opt_visitor::deserialize")), - ("i32", false, true) => (None, Some("::prost_types::i32_visitor::I32Visitor")), - ("i64", false, false) => (None, Some("::prost_types::i64_visitor::deserialize")), - ("i64", true, false) => (None, Some("::prost_types::i64_opt_visitor::deserialize")), - ("i64", false, true) => (None, Some("::prost_types::i64_visitor::I64Visitor")), - ("u32", false, false) => (None, Some("::prost_types::u32_visitor::deserialize")), - ("u32", true, false) => (None, Some("::prost_types::u32_opt_visitor::deserialize")), - ("u32", false, true) => (None, Some("::prost_types::u32_visitor::U32Visitor")), - ("u64", false, false) => (None, Some("::prost_types::u64_visitor::deserialize")), - ("u64", true, false) => (None, Some("::prost_types::u64_opt_visitor::deserialize")), - ("u64", false, true) => (None, Some("::prost_types::u64_visitor::U64Visitor")), - ("f64", false, false) => (Some("<::prost_types::f64_visitor::F64Serializer as ::prost_types::SerializeMethod>::serialize"), Some("::prost_types::f64_visitor::deserialize")), - ("f64", true, false) => (Some("::prost_types::f64_opt_visitor::serialize"), Some("::prost_types::f64_opt_visitor::deserialize")), - ("f64", false, true) => (Some("::prost_types::f64_visitor::F64Serializer"), Some("::prost_types::f64_visitor::F64Visitor")), - ("f32", false, false) => (Some("<::prost_types::f32_visitor::F32Serializer as ::prost_types::SerializeMethod>::serialize"), Some("::prost_types::f32_visitor::deserialize")), - ("f32", true, false) => (Some("::prost_types::f32_opt_visitor::serialize"), Some("::prost_types::f32_opt_visitor::deserialize")), - ("f32", false, true) => (Some("::prost_types::f32_visitor::F32Serializer"), Some("::prost_types::f32_visitor::F32Visitor")), - ("::prost::alloc::string::String", false, false) => (None, Some("::prost_types::string_visitor::deserialize")), - ("::prost::alloc::string::String", true, false) => (None, Some("::prost_types::string_opt_visitor::deserialize")), - ("::prost::alloc::vec::Vec", false, false) => (Some("<::prost_types::vec_u8_visitor::VecU8Serializer as ::prost_types::SerializeMethod>::serialize"), Some("::prost_types::vec_u8_visitor::deserialize")), - ("::prost::alloc::vec::Vec", true, false) => (Some("::prost_types::vec_u8_opt_visitor::serialize"), Some("::prost_types::vec_u8_opt_visitor::deserialize")), - ("::prost::alloc::vec::Vec", false, true) => (Some("::prost_types::vec_u8_visitor::VecU8Serializer"), Some("::prost_types::vec_u8_visitor::VecU8Visitor")), +impl<'a> CodeGenerator<'a> { + /// Returns (serializer, deserializer) function names to use in serde + /// serialize_with and deserialize_with macros respectively. If none are + /// specified, the default works fine. + /// If collection is true, the return is no longer the function, but instead, + /// the Visitor type that will be used for either the repeated helper or + /// custom map helper. + fn get_custom_json_type_mappers( + &self, + ty: &str, + type_name: String, + optional: bool, + collection: bool, + ) -> (Option, Option) { + match (ty, optional, collection) { + ("bool", false, false) => (None, Some("::prost_types::bool_visitor::deserialize".to_string())), + ("bool", true, false) => (None, Some("::prost_types::bool_opt_visitor::deserialize".to_string())), + ("bool", false, true) => (None, Some("::prost_types::bool_visitor::BoolVisitor".to_string())), + ("i32", false, false) => (None, Some("::prost_types::i32_visitor::deserialize".to_string())), + ("i32", true, false) => (None, Some("::prost_types::i32_opt_visitor::deserialize".to_string())), + ("i32", false, true) => (None, Some("::prost_types::i32_visitor::I32Visitor".to_string())), + ("enum", false, false) => (Some(format!("::prost_types::enum_visitor::serialize::<_, {}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::enum_visitor::deserialize::<_, {}>", self.resolve_ident(&type_name)))), + ("enum", true, false) => (Some(format!("::prost_types::enum_opt_visitor::serialize::<_, {}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::enum_opt_visitor::deserialize::<_, {}>", self.resolve_ident(&type_name)))), + ("enum", false, true) => (None, Some("::prost_types::i32_visitor::I32Visitor".to_string())), + ("i64", false, false) => (None, Some("::prost_types::i64_visitor::deserialize".to_string())), + ("i64", true, false) => (None, Some("::prost_types::i64_opt_visitor::deserialize".to_string())), + ("i64", false, true) => (None, Some("::prost_types::i64_visitor::I64Visitor".to_string())), + ("u32", false, false) => (None, Some("::prost_types::u32_visitor::deserialize".to_string())), + ("u32", true, false) => (None, Some("::prost_types::u32_opt_visitor::deserialize".to_string())), + ("u32", false, true) => (None, Some("::prost_types::u32_visitor::U32Visitor".to_string())), + ("u64", false, false) => (None, Some("::prost_types::u64_visitor::deserialize".to_string())), + ("u64", true, false) => (None, Some("::prost_types::u64_opt_visitor::deserialize".to_string())), + ("u64", false, true) => (None, Some("::prost_types::u64_visitor::U64Visitor".to_string())), + ("f64", false, false) => (Some("<::prost_types::f64_visitor::F64Serializer as ::prost_types::SerializeMethod>::serialize".to_string()), Some("::prost_types::f64_visitor::deserialize".to_string())), + ("f64", true, false) => (Some("::prost_types::f64_opt_visitor::serialize".to_string()), Some("::prost_types::f64_opt_visitor::deserialize".to_string())), + ("f64", false, true) => (Some("::prost_types::f64_visitor::F64Serializer".to_string()), Some("::prost_types::f64_visitor::F64Visitor".to_string())), + ("f32", false, false) => (Some("<::prost_types::f32_visitor::F32Serializer as ::prost_types::SerializeMethod>::serialize".to_string()), Some("::prost_types::f32_visitor::deserialize".to_string())), + ("f32", true, false) => (Some("::prost_types::f32_opt_visitor::serialize".to_string()), Some("::prost_types::f32_opt_visitor::deserialize".to_string())), + ("f32", false, true) => (Some("::prost_types::f32_visitor::F32Serializer".to_string()), Some("::prost_types::f32_visitor::F32Visitor".to_string())), + ("::prost::alloc::string::String", false, false) => (None, Some("::prost_types::string_visitor::deserialize".to_string())), + ("::prost::alloc::string::String", true, false) => (None, Some("::prost_types::string_opt_visitor::deserialize".to_string())), + ("::prost::alloc::vec::Vec", false, false) => (Some("<::prost_types::vec_u8_visitor::VecU8Serializer as ::prost_types::SerializeMethod>::serialize".to_string()), Some("::prost_types::vec_u8_visitor::deserialize".to_string())), + ("::prost::alloc::vec::Vec", true, false) => (Some("::prost_types::vec_u8_opt_visitor::serialize".to_string()), Some("::prost_types::vec_u8_opt_visitor::deserialize".to_string())), + ("::prost::alloc::vec::Vec", false, true) => (Some("::prost_types::vec_u8_visitor::VecU8Serializer".to_string()), Some("::prost_types::vec_u8_visitor::VecU8Visitor".to_string())), (_,_, _) => (None, None) } -} + } -impl<'a> CodeGenerator<'a> { pub fn generate( config: &mut Config, message_graph: &MessageGraph, @@ -383,6 +388,8 @@ impl<'a> CodeGenerator<'a> { field_name: &str, key_ty: &str, value_ty: &str, + key_type_name: String, + value_type_name: String, map_type: &str, json_name: &str, ) { @@ -399,8 +406,10 @@ impl<'a> CodeGenerator<'a> { )); self.buf.push('\n'); - let (key_se_opt, key_de_opt) = get_custom_json_type_mappers(key_ty, false, true); - let (value_se_opt, value_de_opt) = get_custom_json_type_mappers(value_ty, false, true); + let (key_se_opt, key_de_opt) = + self.get_custom_json_type_mappers(key_ty, key_type_name, false, true); + let (value_se_opt, value_de_opt) = + self.get_custom_json_type_mappers(value_ty, value_type_name, false, true); push_indent(&mut self.buf, self.depth); match (key_se_opt, key_de_opt, value_se_opt, value_de_opt, map_type) { @@ -518,6 +527,7 @@ impl<'a> CodeGenerator<'a> { &mut self, fq_message_name: &str, ty: &str, + type_name: String, field_name: &str, optional: bool, repeated: bool, @@ -535,7 +545,7 @@ impl<'a> CodeGenerator<'a> { // Add custom deserializers and optionally serializers for most primitive types // and their optional and repeated counterparts. match ( - get_custom_json_type_mappers(ty, optional, repeated), + self.get_custom_json_type_mappers(ty, type_name, optional, repeated), repeated, ) { ((Some(se), Some(de)), false) => { @@ -687,9 +697,15 @@ impl<'a> CodeGenerator<'a> { self.buf.push_str("\")]\n"); self.append_field_attributes(fq_message_name, field.name()); + let ty_or_enum = match type_ { + Type::Enum => "enum".to_string(), + _ => ty.clone(), + }; + self.append_json_field_attributes( fq_message_name, - &ty, + &ty_or_enum, + field.type_name().to_string(), field.name(), optional, repeated, @@ -754,11 +770,21 @@ impl<'a> CodeGenerator<'a> { field.number() )); self.append_field_attributes(fq_message_name, field.name()); + let key_ty_or_enum = match key.r#type() { + Type::Enum => "enum".to_string(), + _ => key_ty.clone(), + }; + let value_ty_or_enum = match value.r#type() { + Type::Enum => "enum".to_string(), + _ => value_ty.clone(), + }; self.append_json_map_field_attributes( fq_message_name, field.name(), - &key_ty, - &value_ty, + &key_ty_or_enum, + &value_ty_or_enum, + key.type_name().to_string(), + value.type_name().to_string(), map_type.rust_type(), field.json_name(), ); diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 8844ffa82..bd8c8b08b 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -286,6 +286,10 @@ fn try_enumeration(input: TokenStream) -> Result { panic!("Enumeration must have at least one variant"); } + if variants.len() != proto_names.len() { + panic!("Number of annotated protonames was unexpected. You probably want to upgrade prost-build."); + } + let default = variants[0].0.clone(); let is_valid = variants @@ -357,6 +361,7 @@ fn try_enumeration(input: TokenStream) -> Result { }.to_string() } } + impl #impl_generics ::core::str::FromStr for #ident #ty_generics #where_clause { type Err = &'static str; fn from_str(value: &str) -> Result { diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 754d337f3..737073eac 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -483,6 +483,19 @@ pub mod enum_visitor { } } + fn visit_i32(self, value: i32) -> Result + where + E: serde::de::Error, + { + match T::try_from(value) { + Ok(en) => Ok(en.into()), + Err(_) => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Signed(value as i64), + &self, + )), + } + } + fn visit_unit(self) -> Result where E: serde::de::Error, @@ -521,6 +534,131 @@ pub mod enum_visitor { Ok(t) => serializer.serialize_str(&t.to_string()), } } + + pub struct EnumSerializer + where + T: std::convert::TryFrom + ToString, + { + _type: std::marker::PhantomData, + } + + impl crate::SerializeMethod for EnumSerializer + where + T: std::convert::TryFrom + ToString, + { + type Value = i32; + + fn serialize(value: &i32, serializer: S) -> Result + where + S: serde::Serializer, + { + match T::try_from(*value) { + Err(_) => Err(serde::ser::Error::custom("invalid enum value")), + Ok(t) => serializer.serialize_str(&t.to_string()), + } + } + } +} + +pub mod enum_opt_visitor { + struct EnumVisitor<'de, T> + where + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + _type: &'de std::marker::PhantomData, + } + + #[cfg(feature = "std")] + impl<'de, T> serde::de::Visitor<'de> for EnumVisitor<'de, T> + where + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid String string or integer") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match T::from_str(value) { + Ok(en) => Ok(Some(en.into())), + Err(_) => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(value), + &self, + )), + } + } + + fn visit_i32(self, value: i32) -> Result + where + E: serde::de::Error, + { + match T::try_from(value) { + Ok(en) => Ok(Some(en.into())), + Err(_) => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Signed(value as i64), + &self, + )), + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + T: 'de + + ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + deserializer.deserialize_any(EnumVisitor::<'de, T> { + _type: &std::marker::PhantomData, + }) + } + + pub fn serialize(value: &std::option::Option, serializer: S) -> Result + where + S: serde::Serializer, + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + use crate::SerializeMethod; + match value { + None => serializer.serialize_none(), + Some(enum_int) => crate::enum_visitor::EnumSerializer::::serialize(enum_int, serializer), + } + } } pub mod map_custom_serializer { From 1ab98b7a3b0b7011be2b60034c493e6a533a2f0e Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Thu, 17 Feb 2022 13:45:54 -0800 Subject: [PATCH 21/30] cleanup + make enums work mostly --- .gitignore | 1 + .vscode/launch.json | 291 ------------------------------ conformance/failing_tests.txt | 127 ++++++++++++- conformance/succeeding_tests.txt | 3 - output.log | 55 ------ prost-build/src/code_generator.rs | 2 +- prost-derive/src/lib.rs | 2 +- prost-types/src/lib.rs | 63 ++++++- 8 files changed, 180 insertions(+), 364 deletions(-) delete mode 100644 .vscode/launch.json delete mode 100644 conformance/succeeding_tests.txt delete mode 100644 output.log diff --git a/.gitignore b/.gitignore index a9d37c560..110d01053 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ target Cargo.lock +.vscode \ No newline at end of file diff --git a/.vscode/launch.json b/.vscode/launch.json deleted file mode 100644 index e481d0956..000000000 --- a/.vscode/launch.json +++ /dev/null @@ -1,291 +0,0 @@ -{ - // Use IntelliSense to learn about possible attributes. - // Hover to view descriptions of existing attributes. - // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387 - "version": "0.2.0", - "configurations": [ - { - "type": "lldb", - "request": "launch", - "name": "Debug executable 'conformance'", - "cargo": { - "args": [ - "build", - "--bin=conformance", - "--package=conformance" - ], - "filter": { - "name": "conformance", - "kind": "bin" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in executable 'conformance'", - "cargo": { - "args": [ - "test", - "--no-run", - "--bin=conformance", - "--package=conformance" - ], - "filter": { - "name": "conformance", - "kind": "bin" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug integration test 'conformance'", - "cargo": { - "args": [ - "test", - "--no-run", - "--test=conformance", - "--package=conformance" - ], - "filter": { - "name": "conformance", - "kind": "test" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in library 'prost'", - "cargo": { - "args": [ - "test", - "--no-run", - "--lib", - "--package=prost" - ], - "filter": { - "name": "prost", - "kind": "lib" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug benchmark 'varint'", - "cargo": { - "args": [ - "test", - "--no-run", - "--bench=varint", - "--package=prost" - ], - "filter": { - "name": "varint", - "kind": "bench" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in library 'protobuf'", - "cargo": { - "args": [ - "test", - "--no-run", - "--lib", - "--package=protobuf" - ], - "filter": { - "name": "protobuf", - "kind": "lib" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug benchmark 'dataset'", - "cargo": { - "args": [ - "test", - "--no-run", - "--bench=dataset", - "--package=protobuf" - ], - "filter": { - "name": "dataset", - "kind": "bench" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in library 'prost-types'", - "cargo": { - "args": [ - "test", - "--no-run", - "--lib", - "--package=prost-types" - ], - "filter": { - "name": "prost-types", - "kind": "lib" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in library 'prost-build'", - "cargo": { - "args": [ - "test", - "--no-run", - "--lib", - "--package=prost-build" - ], - "filter": { - "name": "prost-build", - "kind": "lib" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in library 'tests'", - "cargo": { - "args": [ - "test", - "--no-run", - "--lib", - "--package=tests" - ], - "filter": { - "name": "tests", - "kind": "lib" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in library 'tests-2015'", - "cargo": { - "args": [ - "test", - "--no-run", - "--lib", - "--package=tests-2015" - ], - "filter": { - "name": "tests-2015", - "kind": "lib" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in library 'tests-no-std'", - "cargo": { - "args": [ - "test", - "--no-run", - "--lib", - "--package=tests-no-std" - ], - "filter": { - "name": "tests-no-std", - "kind": "lib" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in library 'single_include'", - "cargo": { - "args": [ - "test", - "--no-run", - "--lib", - "--package=single_include" - ], - "filter": { - "name": "single_include", - "kind": "lib" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug executable 'single_include'", - "cargo": { - "args": [ - "build", - "--bin=single_include", - "--package=single_include" - ], - "filter": { - "name": "single_include", - "kind": "bin" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - }, - { - "type": "lldb", - "request": "launch", - "name": "Debug unit tests in executable 'single_include'", - "cargo": { - "args": [ - "test", - "--no-run", - "--bin=single_include", - "--package=single_include" - ], - "filter": { - "name": "single_include", - "kind": "bin" - } - }, - "args": [], - "cwd": "${workspaceFolder}" - } - ] -} \ No newline at end of file diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index bca91240e..fafabaabf 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -1,5 +1,124 @@ -Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.JsonOutput -Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.ProtobufOutput -Required.Proto3.JsonInput.EnumFieldNumericValueZero.JsonOutput -Required.Proto3.JsonInput.EnumFieldNumericValueZero.ProtobufOutput +Recommended.FieldMaskNumbersDontRoundTrip.JsonOutput +Recommended.FieldMaskPathsDontRoundTrip.JsonOutput +Recommended.FieldMaskTooManyUnderscore.JsonOutput +Recommended.Proto2.JsonInput.FieldNameExtension.Validator +Recommended.Proto3.JsonInput.BytesFieldBase64Url.JsonOutput +Recommended.Proto3.JsonInput.BytesFieldBase64Url.ProtobufOutput +Recommended.Proto3.JsonInput.DurationHas3FractionalDigits.Validator +Recommended.Proto3.JsonInput.DurationHas6FractionalDigits.Validator +Recommended.Proto3.JsonInput.DurationHas9FractionalDigits.Validator +Recommended.Proto3.JsonInput.DurationHasZeroFractionalDigit.Validator +Recommended.Proto3.JsonInput.Int64FieldBeString.Validator +Recommended.Proto3.JsonInput.MapFieldValueIsNull +Recommended.Proto3.JsonInput.NullValueInOtherOneofNewFormat.Validator +Recommended.Proto3.JsonInput.NullValueInOtherOneofOldFormat.Validator +Recommended.Proto3.JsonInput.OneofZeroBytes.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroBytes.ProtobufOutput +Recommended.Proto3.JsonInput.OneofZeroEnum.JsonOutput +Recommended.Proto3.JsonInput.OneofZeroEnum.ProtobufOutput +Recommended.Proto3.JsonInput.RepeatedFieldPrimitiveElementIsNull +Recommended.Proto3.JsonInput.TimestampHas3FractionalDigits.Validator +Recommended.Proto3.JsonInput.TimestampHas6FractionalDigits.Validator +Recommended.Proto3.JsonInput.TimestampZeroNormalized.Validator +Recommended.Proto3.JsonInput.Uint64FieldBeString.Validator +Recommended.Proto3.ProtobufInput.OneofZeroBytes.JsonOutput +Required.DurationProtoInputTooLarge.JsonOutput +Required.DurationProtoInputTooSmall.JsonOutput +Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput +Required.Proto3.JsonInput.Any.JsonOutput +Required.Proto3.JsonInput.Any.ProtobufOutput +Required.Proto3.JsonInput.AnyNested.JsonOutput +Required.Proto3.JsonInput.AnyNested.ProtobufOutput +Required.Proto3.JsonInput.AnyUnorderedTypeTag.JsonOutput +Required.Proto3.JsonInput.AnyUnorderedTypeTag.ProtobufOutput +Required.Proto3.JsonInput.AnyWithDuration.JsonOutput +Required.Proto3.JsonInput.AnyWithDuration.ProtobufOutput +Required.Proto3.JsonInput.AnyWithFieldMask.JsonOutput +Required.Proto3.JsonInput.AnyWithFieldMask.ProtobufOutput +Required.Proto3.JsonInput.AnyWithInt32ValueWrapper.JsonOutput +Required.Proto3.JsonInput.AnyWithInt32ValueWrapper.ProtobufOutput +Required.Proto3.JsonInput.AnyWithStruct.JsonOutput +Required.Proto3.JsonInput.AnyWithStruct.ProtobufOutput +Required.Proto3.JsonInput.AnyWithTimestamp.JsonOutput +Required.Proto3.JsonInput.AnyWithTimestamp.ProtobufOutput +Required.Proto3.JsonInput.AnyWithValueForInteger.JsonOutput +Required.Proto3.JsonInput.AnyWithValueForInteger.ProtobufOutput +Required.Proto3.JsonInput.AnyWithValueForJsonObject.JsonOutput +Required.Proto3.JsonInput.AnyWithValueForJsonObject.ProtobufOutput +Required.Proto3.JsonInput.BoolMapEscapedKey.JsonOutput +Required.Proto3.JsonInput.BoolMapEscapedKey.ProtobufOutput +Required.Proto3.JsonInput.BoolMapField.JsonOutput +Required.Proto3.JsonInput.BoolMapField.ProtobufOutput +Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.JsonOutput +Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.ProtobufOutput +Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.JsonOutput +Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.ProtobufOutput +Required.Proto3.JsonInput.DurationMaxValue.JsonOutput +Required.Proto3.JsonInput.DurationMaxValue.ProtobufOutput +Required.Proto3.JsonInput.DurationMinValue.JsonOutput +Required.Proto3.JsonInput.DurationMinValue.ProtobufOutput +Required.Proto3.JsonInput.DurationRepeatedValue.JsonOutput +Required.Proto3.JsonInput.DurationRepeatedValue.ProtobufOutput +Required.Proto3.JsonInput.EmptyFieldMask.JsonOutput +Required.Proto3.JsonInput.EmptyFieldMask.ProtobufOutput Required.Proto3.JsonInput.EnumFieldUnknownValue.Validator +Required.Proto3.JsonInput.EnumFieldWithAliasDifferentCase.JsonOutput +Required.Proto3.JsonInput.EnumFieldWithAliasDifferentCase.ProtobufOutput +Required.Proto3.JsonInput.EnumFieldWithAliasLowerCase.JsonOutput +Required.Proto3.JsonInput.EnumFieldWithAliasLowerCase.ProtobufOutput +Required.Proto3.JsonInput.EnumFieldWithAliasUseAlias.JsonOutput +Required.Proto3.JsonInput.EnumFieldWithAliasUseAlias.ProtobufOutput +Required.Proto3.JsonInput.FieldMask.JsonOutput +Required.Proto3.JsonInput.FieldMask.ProtobufOutput +Required.Proto3.JsonInput.OneofFieldDuplicate +Required.Proto3.JsonInput.RepeatedListValue.JsonOutput +Required.Proto3.JsonInput.RepeatedListValue.ProtobufOutput +Required.Proto3.JsonInput.RepeatedValue.JsonOutput +Required.Proto3.JsonInput.RepeatedValue.ProtobufOutput +Required.Proto3.JsonInput.Struct.JsonOutput +Required.Proto3.JsonInput.Struct.ProtobufOutput +Required.Proto3.JsonInput.StructWithEmptyListValue.JsonOutput +Required.Proto3.JsonInput.StructWithEmptyListValue.ProtobufOutput +Required.Proto3.JsonInput.TimestampMinValue.JsonOutput +Required.Proto3.JsonInput.TimestampMinValue.ProtobufOutput +Required.Proto3.JsonInput.TimestampRepeatedValue.JsonOutput +Required.Proto3.JsonInput.TimestampRepeatedValue.ProtobufOutput +Required.Proto3.JsonInput.TimestampWithNegativeOffset.JsonOutput +Required.Proto3.JsonInput.TimestampWithNegativeOffset.ProtobufOutput +Required.Proto3.JsonInput.TimestampWithPositiveOffset.JsonOutput +Required.Proto3.JsonInput.TimestampWithPositiveOffset.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptBool.JsonOutput +Required.Proto3.JsonInput.ValueAcceptBool.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptFloat.JsonOutput +Required.Proto3.JsonInput.ValueAcceptFloat.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptInteger.JsonOutput +Required.Proto3.JsonInput.ValueAcceptInteger.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptList.JsonOutput +Required.Proto3.JsonInput.ValueAcceptList.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptNull.JsonOutput +Required.Proto3.JsonInput.ValueAcceptNull.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptObject.JsonOutput +Required.Proto3.JsonInput.ValueAcceptObject.ProtobufOutput +Required.Proto3.JsonInput.ValueAcceptString.JsonOutput +Required.Proto3.JsonInput.ValueAcceptString.ProtobufOutput +Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.Default.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateKey.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateKeyInMapEntry.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateValueInMapEntry.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.MissingDefault.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.NonDefault.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.Unordered.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.Default.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateKey.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateKeyInMapEntry.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateValueInMapEntry.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.MissingDefault.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.NonDefault.JsonOutput +Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.Unordered.JsonOutput +Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.DefaultValue.JsonOutput +Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.MultipleValuesForDifferentField.JsonOutput +Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.MultipleValuesForSameField.JsonOutput +Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.NonDefaultValue.JsonOutput +Required.TimestampProtoInputTooLarge.JsonOutput +Required.TimestampProtoInputTooSmall.JsonOutput diff --git a/conformance/succeeding_tests.txt b/conformance/succeeding_tests.txt deleted file mode 100644 index a9fc4cc37..000000000 --- a/conformance/succeeding_tests.txt +++ /dev/null @@ -1,3 +0,0 @@ -Required.Proto3.JsonInput.EnumField.ProtobufOutput -Required.Proto3.JsonInput.EnumFieldWithAlias.JsonOutput -Required.Proto3.JsonInput.EnumFieldWithAlias.ProtobufOutput diff --git a/output.log b/output.log deleted file mode 100644 index 8567e88b3..000000000 --- a/output.log +++ /dev/null @@ -1,55 +0,0 @@ - Finished test [unoptimized + debuginfo] target(s) in 0.06s - Running unittests (target/debug/deps/conformance-96effac175dc0f5f) - -running 0 tests - -test result: ok. 0 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s - - Running tests/conformance.rs (target/debug/deps/conformance-9fdcb6d434b04b3a) - -running 1 test -thread 'main' panicked at 'all times should be after the epoch: SystemTimeError(62135596801s)', /Users/myassin/.cargo/registry/src/github.com-1ecc6299db9ec823/humantime-2.1.0/src/date.rs:255:14 -note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace -[libprotobuf ERROR /Users/myassin/Documents/prost/target/debug/build/protobuf-f93ae7e45fc7e3e4/out/protobufxBxnmy/protobuf-3.14.0/conformance/conformance_test_runner.cc:322] Required.TimestampProtoInputTooSmall.JsonOutput: unexpected EOF from test program -[libprotobuf INFO /Users/myassin/Documents/prost/target/debug/build/protobuf-f93ae7e45fc7e3e4/out/protobufxBxnmy/protobuf-3.14.0/conformance/conformance_test_runner.cc:163] Trying to reap child, pid=72362 -[libprotobuf INFO /Users/myassin/Documents/prost/target/debug/build/protobuf-f93ae7e45fc7e3e4/out/protobufxBxnmy/protobuf-3.14.0/conformance/conformance_test_runner.cc:176] child killed by signal 1 -thread 'main' panicked at 'a Display implementation returned an error unexpectedly: Error', /rustc/db9d1b20bba1968c1ec1fc49616d4742c1725b4b/library/alloc/src/string.rs:2401:14 -note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace -[libprotobuf ERROR /Users/myassin/Documents/prost/target/debug/build/protobuf-f93ae7e45fc7e3e4/out/protobufxBxnmy/protobuf-3.14.0/conformance/conformance_test_runner.cc:322] Required.TimestampProtoInputTooLarge.JsonOutput: unexpected EOF from test program -[libprotobuf INFO /Users/myassin/Documents/prost/target/debug/build/protobuf-f93ae7e45fc7e3e4/out/protobufxBxnmy/protobuf-3.14.0/conformance/conformance_test_runner.cc:163] Trying to reap child, pid=72365 -[libprotobuf INFO /Users/myassin/Documents/prost/target/debug/build/protobuf-f93ae7e45fc7e3e4/out/protobufxBxnmy/protobuf-3.14.0/conformance/conformance_test_runner.cc:176] child killed by signal 1 - -CONFORMANCE TEST BEGIN ==================================== - -ERROR, test=Required.Proto3.JsonInput.EnumFieldNumericValueZero.ProtobufOutput: Failed to parse input or produce output. request=json_payload: "{\"optionalNestedEnum\": 0}" requested_output_format: PROTOBUF message_type: "protobuf_test_messages.proto3.TestAllTypesProto3" test_category: JSON_TEST, response=parse_error: "error deserializing json: optionalNestedEnum: invalid type: integer `0`, expected a valid String string or integer at line 1 column 24 at optionalNestedEnum" -ERROR, test=Required.Proto3.JsonInput.EnumFieldNumericValueZero.JsonOutput: Failed to parse input or produce output. request=json_payload: "{\"optionalNestedEnum\": 0}" requested_output_format: JSON message_type: "protobuf_test_messages.proto3.TestAllTypesProto3" test_category: JSON_TEST, response=parse_error: "error deserializing json: optionalNestedEnum: invalid type: integer `0`, expected a valid String string or integer at line 1 column 24 at optionalNestedEnum" -ERROR, test=Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.ProtobufOutput: Failed to parse input or produce output. request=json_payload: "{\"optionalNestedEnum\": 1}" requested_output_format: PROTOBUF message_type: "protobuf_test_messages.proto3.TestAllTypesProto3" test_category: JSON_TEST, response=parse_error: "error deserializing json: optionalNestedEnum: invalid type: integer `1`, expected a valid String string or integer at line 1 column 24 at optionalNestedEnum" -ERROR, test=Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.JsonOutput: Failed to parse input or produce output. request=json_payload: "{\"optionalNestedEnum\": 1}" requested_output_format: JSON message_type: "protobuf_test_messages.proto3.TestAllTypesProto3" test_category: JSON_TEST, response=parse_error: "error deserializing json: optionalNestedEnum: invalid type: integer `1`, expected a valid String string or integer at line 1 column 24 at optionalNestedEnum" -ERROR, test=Required.Proto3.JsonInput.EnumFieldUnknownValue.Validator: Expected JSON payload but got type 1. request=json_payload: "{\"optionalNestedEnum\": 123}" requested_output_format: JSON message_type: "protobuf_test_messages.proto3.TestAllTypesProto3" test_category: JSON_TEST, response=parse_error: "error deserializing json: optionalNestedEnum: invalid type: integer `123`, expected a valid String string or integer at line 1 column 26 at optionalNestedEnum" - -These tests failed. If they can't be fixed right now, you can add them to the failure list so the overall suite can succeed. Add them to the failure list by running: - ./update_failure_list.py failing_tests.txt --add failing_tests.txt - - Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.JsonOutput - Required.Proto3.JsonInput.EnumFieldNumericValueNonZero.ProtobufOutput - Required.Proto3.JsonInput.EnumFieldNumericValueZero.JsonOutput - Required.Proto3.JsonInput.EnumFieldNumericValueZero.ProtobufOutput - Required.Proto3.JsonInput.EnumFieldUnknownValue.Validator - -CONFORMANCE SUITE FAILED: 1883 successes, 0 skipped, 125 expected failures, 5 unexpected failures. - -test test_conformance ... FAILED - -failures: - ----- test_conformance stdout ---- -thread 'test_conformance' panicked at 'proto conformance test failed', conformance/tests/conformance.rs:32:5 -note: run with `RUST_BACKTRACE=1` environment variable to display a backtrace - - -failures: - test_conformance - -test result: FAILED. 0 passed; 1 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.29s - -error: test failed, to rerun pass '-p conformance --test conformance' diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index aec2e0b81..47eb22140 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -67,7 +67,7 @@ impl<'a> CodeGenerator<'a> { ("i32", false, true) => (None, Some("::prost_types::i32_visitor::I32Visitor".to_string())), ("enum", false, false) => (Some(format!("::prost_types::enum_visitor::serialize::<_, {}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::enum_visitor::deserialize::<_, {}>", self.resolve_ident(&type_name)))), ("enum", true, false) => (Some(format!("::prost_types::enum_opt_visitor::serialize::<_, {}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::enum_opt_visitor::deserialize::<_, {}>", self.resolve_ident(&type_name)))), - ("enum", false, true) => (None, Some("::prost_types::i32_visitor::I32Visitor".to_string())), + ("enum", false, true) => (Some(format!("::prost_types::enum_visitor::EnumSerializer<{}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::enum_visitor::EnumVisitor::<{}>", self.resolve_ident(&type_name)))), ("i64", false, false) => (None, Some("::prost_types::i64_visitor::deserialize".to_string())), ("i64", true, false) => (None, Some("::prost_types::i64_opt_visitor::deserialize".to_string())), ("i64", false, true) => (None, Some("::prost_types::i64_visitor::I64Visitor".to_string())), diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index bd8c8b08b..3143ffe34 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -286,7 +286,7 @@ fn try_enumeration(input: TokenStream) -> Result { panic!("Enumeration must have at least one variant"); } - if variants.len() != proto_names.len() { + if variants.len() != proto_names.len() { panic!("Number of annotated protonames was unexpected. You probably want to upgrade prost-build."); } diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 737073eac..9560f504c 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -444,7 +444,7 @@ pub mod repeated_visitor { } pub mod enum_visitor { - struct EnumVisitor<'de, T> + pub struct EnumVisitor<'de, T> where T: ToString + std::str::FromStr @@ -455,6 +455,19 @@ pub mod enum_visitor { _type: &'de std::marker::PhantomData, } + impl crate::HasConstructor for EnumVisitor<'_, T> + where T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, +{ + fn new() -> Self { + return Self {_type: &std::marker::PhantomData}; + } + } + + #[cfg(feature = "std")] impl<'de, T> serde::de::Visitor<'de> for EnumVisitor<'de, T> where @@ -482,12 +495,11 @@ pub mod enum_visitor { )), } } - - fn visit_i32(self, value: i32) -> Result + fn visit_i64(self, value: i64) -> Result where E: serde::de::Error, { - match T::try_from(value) { + match T::try_from(value as i32) { Ok(en) => Ok(en.into()), Err(_) => Err(serde::de::Error::invalid_value( serde::de::Unexpected::Signed(value as i64), @@ -496,6 +508,20 @@ pub mod enum_visitor { } } + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + self.visit_i64(value as i64) + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + self.visit_i64(value as i64) + } + fn visit_unit(self) -> Result where E: serde::de::Error, @@ -584,7 +610,7 @@ pub mod enum_opt_visitor { type Value = std::option::Option; fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid String string or integer") + formatter.write_str("a valid string or integer representation of an enum") } fn visit_str(self, value: &str) -> Result @@ -600,11 +626,11 @@ pub mod enum_opt_visitor { } } - fn visit_i32(self, value: i32) -> Result + fn visit_i64(self, value: i64) -> Result where E: serde::de::Error, { - match T::try_from(value) { + match T::try_from(value as i32) { Ok(en) => Ok(Some(en.into())), Err(_) => Err(serde::de::Error::invalid_value( serde::de::Unexpected::Signed(value as i64), @@ -613,6 +639,20 @@ pub mod enum_opt_visitor { } } + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + self.visit_i64(value as i64) + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + self.visit_i64(value as i64) + } + fn visit_unit(self) -> Result where E: serde::de::Error, @@ -644,7 +684,10 @@ pub mod enum_opt_visitor { }) } - pub fn serialize(value: &std::option::Option, serializer: S) -> Result + pub fn serialize( + value: &std::option::Option, + serializer: S, + ) -> Result where S: serde::Serializer, T: ToString @@ -656,7 +699,9 @@ pub mod enum_opt_visitor { use crate::SerializeMethod; match value { None => serializer.serialize_none(), - Some(enum_int) => crate::enum_visitor::EnumSerializer::::serialize(enum_int, serializer), + Some(enum_int) => { + crate::enum_visitor::EnumSerializer::::serialize(enum_int, serializer) + } } } } From 9a547889ca04ae1c72575c80cc520ea570094f01 Mon Sep 17 00:00:00 2001 From: Mohamed Yassin Date: Thu, 17 Feb 2022 20:08:52 -0500 Subject: [PATCH 22/30] Fixed oneof and enum unknown handling --- conformance/failing_tests.txt | 7 ------- prost-build/src/code_generator.rs | 25 +++++++++++++++++++------ prost-types/src/lib.rs | 22 +++++----------------- 3 files changed, 24 insertions(+), 30 deletions(-) diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index fafabaabf..03dffb0f4 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -12,8 +12,6 @@ Recommended.Proto3.JsonInput.Int64FieldBeString.Validator Recommended.Proto3.JsonInput.MapFieldValueIsNull Recommended.Proto3.JsonInput.NullValueInOtherOneofNewFormat.Validator Recommended.Proto3.JsonInput.NullValueInOtherOneofOldFormat.Validator -Recommended.Proto3.JsonInput.OneofZeroBytes.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroBytes.ProtobufOutput Recommended.Proto3.JsonInput.OneofZeroEnum.JsonOutput Recommended.Proto3.JsonInput.OneofZeroEnum.ProtobufOutput Recommended.Proto3.JsonInput.RepeatedFieldPrimitiveElementIsNull @@ -21,7 +19,6 @@ Recommended.Proto3.JsonInput.TimestampHas3FractionalDigits.Validator Recommended.Proto3.JsonInput.TimestampHas6FractionalDigits.Validator Recommended.Proto3.JsonInput.TimestampZeroNormalized.Validator Recommended.Proto3.JsonInput.Uint64FieldBeString.Validator -Recommended.Proto3.ProtobufInput.OneofZeroBytes.JsonOutput Required.DurationProtoInputTooLarge.JsonOutput Required.DurationProtoInputTooSmall.JsonOutput Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput @@ -116,9 +113,5 @@ Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateValueInMapEntry Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.MissingDefault.JsonOutput Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.NonDefault.JsonOutput Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.Unordered.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.DefaultValue.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.MultipleValuesForDifferentField.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.MultipleValuesForSameField.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.NonDefaultValue.JsonOutput Required.TimestampProtoInputTooLarge.JsonOutput Required.TimestampProtoInputTooSmall.JsonOutput diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 47eb22140..eb4a6d5e6 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -532,16 +532,19 @@ impl<'a> CodeGenerator<'a> { optional: bool, repeated: bool, json_name: &str, + oneof: bool, ) { if let None = self.config.json_mapping.get_first(fq_message_name) { return; } self.append_shared_json_field_attributes(field_name, json_name); - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(skip_serializing_if = "::prost_types::is_default")]"#); - self.buf.push('\n'); + if !oneof { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(skip_serializing_if = "::prost_types::is_default")]"#); + self.buf.push('\n'); + } // Add custom deserializers and optionally serializers for most primitive types // and their optional and repeated counterparts. match ( @@ -710,6 +713,7 @@ impl<'a> CodeGenerator<'a> { optional, repeated, field.json_name(), + false, ); self.push_indent(); self.buf.push_str("pub "); @@ -871,9 +875,18 @@ impl<'a> CodeGenerator<'a> { field.number() )); self.append_field_attributes(&oneof_name, field.name()); - - self.push_indent(); let ty = self.resolve_type(&field, fq_message_name); + self.append_json_field_attributes( + &oneof_name, + &ty, + field.type_name().to_string(), + field.name(), + false, + false, + field.json_name(), + true, + ); + self.push_indent(); let boxed = (type_ == Type::Message || type_ == Type::Group) && self diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 9560f504c..6fb16a425 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -489,10 +489,7 @@ pub mod enum_visitor { { match T::from_str(value) { Ok(en) => Ok(en.into()), - Err(_) => Err(serde::de::Error::invalid_value( - serde::de::Unexpected::Str(value), - &self, - )), + Err(_) => Ok(T::default().into()), } } fn visit_i64(self, value: i64) -> Result @@ -501,10 +498,7 @@ pub mod enum_visitor { { match T::try_from(value as i32) { Ok(en) => Ok(en.into()), - Err(_) => Err(serde::de::Error::invalid_value( - serde::de::Unexpected::Signed(value as i64), - &self, - )), + Err(_) => Ok(T::default().into()), } } @@ -526,7 +520,7 @@ pub mod enum_visitor { where E: serde::de::Error, { - Ok(Self::Value::default()) + Ok(T::default().into()) } } @@ -619,10 +613,7 @@ pub mod enum_opt_visitor { { match T::from_str(value) { Ok(en) => Ok(Some(en.into())), - Err(_) => Err(serde::de::Error::invalid_value( - serde::de::Unexpected::Str(value), - &self, - )), + Err(_) => Ok(Some(T::default().into())), } } @@ -632,10 +623,7 @@ pub mod enum_opt_visitor { { match T::try_from(value as i32) { Ok(en) => Ok(Some(en.into())), - Err(_) => Err(serde::de::Error::invalid_value( - serde::de::Unexpected::Signed(value as i64), - &self, - )), + Err(_) => Ok(Some(T::default().into())), } } From 34c21f419246cd2f5b2ebd03d03f87d705059151 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Tue, 15 Mar 2022 21:42:22 -0700 Subject: [PATCH 23/30] review ready --- Cargo.toml | 1 + conformance/failing_tests.txt | 38 +- prost-build/src/code_generator.rs | 173 +- prost-derive/src/lib.rs | 2 +- prost-types/Cargo.toml | 3 +- prost-types/src/lib.rs | 2276 +------------------------- prost-types/src/protobuf.rs | 15 +- prost-types/src/serde.rs | 2546 +++++++++++++++++++++++++++++ protobuf/Cargo.toml | 2 +- 9 files changed, 2673 insertions(+), 2383 deletions(-) create mode 100644 prost-types/src/serde.rs diff --git a/Cargo.toml b/Cargo.toml index a1225ae6b..bbb7be207 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -43,6 +43,7 @@ bench = false default = ["prost-derive", "std"] no-recursion-limit = [] std = [] +json = [] [dependencies] bytes = { version = "1", default-features = false } diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index fafabaabf..e4c9be0ee 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -7,21 +7,17 @@ Recommended.Proto3.JsonInput.BytesFieldBase64Url.ProtobufOutput Recommended.Proto3.JsonInput.DurationHas3FractionalDigits.Validator Recommended.Proto3.JsonInput.DurationHas6FractionalDigits.Validator Recommended.Proto3.JsonInput.DurationHas9FractionalDigits.Validator -Recommended.Proto3.JsonInput.DurationHasZeroFractionalDigit.Validator Recommended.Proto3.JsonInput.Int64FieldBeString.Validator Recommended.Proto3.JsonInput.MapFieldValueIsNull Recommended.Proto3.JsonInput.NullValueInOtherOneofNewFormat.Validator Recommended.Proto3.JsonInput.NullValueInOtherOneofOldFormat.Validator -Recommended.Proto3.JsonInput.OneofZeroBytes.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroBytes.ProtobufOutput -Recommended.Proto3.JsonInput.OneofZeroEnum.JsonOutput -Recommended.Proto3.JsonInput.OneofZeroEnum.ProtobufOutput Recommended.Proto3.JsonInput.RepeatedFieldPrimitiveElementIsNull Recommended.Proto3.JsonInput.TimestampHas3FractionalDigits.Validator Recommended.Proto3.JsonInput.TimestampHas6FractionalDigits.Validator +Recommended.Proto3.JsonInput.TimestampHas9FractionalDigits.Validator +Recommended.Proto3.JsonInput.TimestampHasZeroFractionalDigit.Validator Recommended.Proto3.JsonInput.TimestampZeroNormalized.Validator Recommended.Proto3.JsonInput.Uint64FieldBeString.Validator -Recommended.Proto3.ProtobufInput.OneofZeroBytes.JsonOutput Required.DurationProtoInputTooLarge.JsonOutput Required.DurationProtoInputTooSmall.JsonOutput Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput @@ -45,20 +41,16 @@ Required.Proto3.JsonInput.AnyWithValueForInteger.JsonOutput Required.Proto3.JsonInput.AnyWithValueForInteger.ProtobufOutput Required.Proto3.JsonInput.AnyWithValueForJsonObject.JsonOutput Required.Proto3.JsonInput.AnyWithValueForJsonObject.ProtobufOutput -Required.Proto3.JsonInput.BoolMapEscapedKey.JsonOutput -Required.Proto3.JsonInput.BoolMapEscapedKey.ProtobufOutput -Required.Proto3.JsonInput.BoolMapField.JsonOutput -Required.Proto3.JsonInput.BoolMapField.ProtobufOutput Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.JsonOutput Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.ProtobufOutput Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.JsonOutput Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.ProtobufOutput +Required.Proto3.JsonInput.DurationJsonInputTooLarge +Required.Proto3.JsonInput.DurationJsonInputTooSmall Required.Proto3.JsonInput.DurationMaxValue.JsonOutput Required.Proto3.JsonInput.DurationMaxValue.ProtobufOutput Required.Proto3.JsonInput.DurationMinValue.JsonOutput Required.Proto3.JsonInput.DurationMinValue.ProtobufOutput -Required.Proto3.JsonInput.DurationRepeatedValue.JsonOutput -Required.Proto3.JsonInput.DurationRepeatedValue.ProtobufOutput Required.Proto3.JsonInput.EmptyFieldMask.JsonOutput Required.Proto3.JsonInput.EmptyFieldMask.ProtobufOutput Required.Proto3.JsonInput.EnumFieldUnknownValue.Validator @@ -83,10 +75,6 @@ Required.Proto3.JsonInput.TimestampMinValue.JsonOutput Required.Proto3.JsonInput.TimestampMinValue.ProtobufOutput Required.Proto3.JsonInput.TimestampRepeatedValue.JsonOutput Required.Proto3.JsonInput.TimestampRepeatedValue.ProtobufOutput -Required.Proto3.JsonInput.TimestampWithNegativeOffset.JsonOutput -Required.Proto3.JsonInput.TimestampWithNegativeOffset.ProtobufOutput -Required.Proto3.JsonInput.TimestampWithPositiveOffset.JsonOutput -Required.Proto3.JsonInput.TimestampWithPositiveOffset.ProtobufOutput Required.Proto3.JsonInput.ValueAcceptBool.JsonOutput Required.Proto3.JsonInput.ValueAcceptBool.ProtobufOutput Required.Proto3.JsonInput.ValueAcceptFloat.JsonOutput @@ -102,23 +90,5 @@ Required.Proto3.JsonInput.ValueAcceptObject.ProtobufOutput Required.Proto3.JsonInput.ValueAcceptString.JsonOutput Required.Proto3.JsonInput.ValueAcceptString.ProtobufOutput Required.Proto3.ProtobufInput.UnknownVarint.ProtobufOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.Default.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateKey.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateKeyInMapEntry.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.DuplicateValueInMapEntry.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.MissingDefault.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.NonDefault.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.BOOL.BOOL.Unordered.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.Default.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateKey.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateKeyInMapEntry.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.DuplicateValueInMapEntry.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.MissingDefault.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.NonDefault.JsonOutput -Required.Proto3.ProtobufInput.ValidDataMap.STRING.BYTES.Unordered.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.DefaultValue.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.MultipleValuesForDifferentField.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.MultipleValuesForSameField.JsonOutput -Required.Proto3.ProtobufInput.ValidDataOneof.BYTES.NonDefaultValue.JsonOutput Required.TimestampProtoInputTooLarge.JsonOutput Required.TimestampProtoInputTooSmall.JsonOutput diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 47eb22140..a5dd084d9 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -57,38 +57,40 @@ impl<'a> CodeGenerator<'a> { type_name: String, optional: bool, collection: bool, + map_key: bool, ) -> (Option, Option) { - match (ty, optional, collection) { - ("bool", false, false) => (None, Some("::prost_types::bool_visitor::deserialize".to_string())), - ("bool", true, false) => (None, Some("::prost_types::bool_opt_visitor::deserialize".to_string())), - ("bool", false, true) => (None, Some("::prost_types::bool_visitor::BoolVisitor".to_string())), - ("i32", false, false) => (None, Some("::prost_types::i32_visitor::deserialize".to_string())), - ("i32", true, false) => (None, Some("::prost_types::i32_opt_visitor::deserialize".to_string())), - ("i32", false, true) => (None, Some("::prost_types::i32_visitor::I32Visitor".to_string())), - ("enum", false, false) => (Some(format!("::prost_types::enum_visitor::serialize::<_, {}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::enum_visitor::deserialize::<_, {}>", self.resolve_ident(&type_name)))), - ("enum", true, false) => (Some(format!("::prost_types::enum_opt_visitor::serialize::<_, {}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::enum_opt_visitor::deserialize::<_, {}>", self.resolve_ident(&type_name)))), - ("enum", false, true) => (Some(format!("::prost_types::enum_visitor::EnumSerializer<{}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::enum_visitor::EnumVisitor::<{}>", self.resolve_ident(&type_name)))), - ("i64", false, false) => (None, Some("::prost_types::i64_visitor::deserialize".to_string())), - ("i64", true, false) => (None, Some("::prost_types::i64_opt_visitor::deserialize".to_string())), - ("i64", false, true) => (None, Some("::prost_types::i64_visitor::I64Visitor".to_string())), - ("u32", false, false) => (None, Some("::prost_types::u32_visitor::deserialize".to_string())), - ("u32", true, false) => (None, Some("::prost_types::u32_opt_visitor::deserialize".to_string())), - ("u32", false, true) => (None, Some("::prost_types::u32_visitor::U32Visitor".to_string())), - ("u64", false, false) => (None, Some("::prost_types::u64_visitor::deserialize".to_string())), - ("u64", true, false) => (None, Some("::prost_types::u64_opt_visitor::deserialize".to_string())), - ("u64", false, true) => (None, Some("::prost_types::u64_visitor::U64Visitor".to_string())), - ("f64", false, false) => (Some("<::prost_types::f64_visitor::F64Serializer as ::prost_types::SerializeMethod>::serialize".to_string()), Some("::prost_types::f64_visitor::deserialize".to_string())), - ("f64", true, false) => (Some("::prost_types::f64_opt_visitor::serialize".to_string()), Some("::prost_types::f64_opt_visitor::deserialize".to_string())), - ("f64", false, true) => (Some("::prost_types::f64_visitor::F64Serializer".to_string()), Some("::prost_types::f64_visitor::F64Visitor".to_string())), - ("f32", false, false) => (Some("<::prost_types::f32_visitor::F32Serializer as ::prost_types::SerializeMethod>::serialize".to_string()), Some("::prost_types::f32_visitor::deserialize".to_string())), - ("f32", true, false) => (Some("::prost_types::f32_opt_visitor::serialize".to_string()), Some("::prost_types::f32_opt_visitor::deserialize".to_string())), - ("f32", false, true) => (Some("::prost_types::f32_visitor::F32Serializer".to_string()), Some("::prost_types::f32_visitor::F32Visitor".to_string())), - ("::prost::alloc::string::String", false, false) => (None, Some("::prost_types::string_visitor::deserialize".to_string())), - ("::prost::alloc::string::String", true, false) => (None, Some("::prost_types::string_opt_visitor::deserialize".to_string())), - ("::prost::alloc::vec::Vec", false, false) => (Some("<::prost_types::vec_u8_visitor::VecU8Serializer as ::prost_types::SerializeMethod>::serialize".to_string()), Some("::prost_types::vec_u8_visitor::deserialize".to_string())), - ("::prost::alloc::vec::Vec", true, false) => (Some("::prost_types::vec_u8_opt_visitor::serialize".to_string()), Some("::prost_types::vec_u8_opt_visitor::deserialize".to_string())), - ("::prost::alloc::vec::Vec", false, true) => (Some("::prost_types::vec_u8_visitor::VecU8Serializer".to_string()), Some("::prost_types::vec_u8_visitor::VecU8Visitor".to_string())), - (_,_, _) => (None, None) + match (ty, optional, collection, map_key) { + ("bool", false, false, _) => (None, Some("::prost_types::serde::bool::deserialize".to_string())), + ("bool", true, false, _) => (None, Some("::prost_types::serde::bool_opt::deserialize".to_string())), + ("bool", _, true, false) => (None, Some("::prost_types::serde::bool::BoolVisitor".to_string())), + ("bool", _, true, true) => (Some("::prost_types::serde::bool_map_key::BoolKeySerializer".to_string()), Some("::prost_types::serde::bool_map_key::BoolVisitor".to_string())), + ("i32", false, false, _) => (None, Some("::prost_types::serde::i32::deserialize".to_string())), + ("i32", true, false, _) => (None, Some("::prost_types::serde::i32_opt::deserialize".to_string())), + ("i32", _, true, _) => (None, Some("::prost_types::serde::i32::I32Visitor".to_string())), + ("enum", false, false, _) => (Some(format!("::prost_types::serde::enum_serde::serialize::<_, {}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::serde::enum_serde::deserialize::<_, {}>", self.resolve_ident(&type_name)))), + ("enum", true, false, _) => (Some(format!("::prost_types::serde::enum_opt::serialize::<_, {}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::serde::enum_opt::deserialize::<_, {}>", self.resolve_ident(&type_name)))), + ("enum", _, true, _) => (Some(format!("::prost_types::serde::enum_serde::EnumSerializer<{}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::serde::enum_serde::EnumVisitor::<{}>", self.resolve_ident(&type_name)))), + ("i64", false, false, _) => (None, Some("::prost_types::serde::i64::deserialize".to_string())), + ("i64", true, false, _) => (None, Some("::prost_types::serde::i64_opt::deserialize".to_string())), + ("i64", _, true, _) => (None, Some("::prost_types::serde::i64::I64Visitor".to_string())), + ("u32", false, false, _) => (None, Some("::prost_types::serde::u32::deserialize".to_string())), + ("u32", true, false, _) => (None, Some("::prost_types::serde::u32_opt::deserialize".to_string())), + ("u32", _, true, _) => (None, Some("::prost_types::serde::u32::U32Visitor".to_string())), + ("u64", false, false, _) => (None, Some("::prost_types::serde::u64::deserialize".to_string())), + ("u64", true, false, _) => (None, Some("::prost_types::serde::u64_opt::deserialize".to_string())), + ("u64", _, true, _) => (None, Some("::prost_types::serde::u64::U64Visitor".to_string())), + ("f64", false, false, _) => (Some("<::prost_types::serde::f64::F64Serializer as ::prost_types::serde::SerializeMethod>::serialize".to_string()), Some("::prost_types::serde::f64::deserialize".to_string())), + ("f64", true, false, _) => (Some("::prost_types::serde::f64_opt::serialize".to_string()), Some("::prost_types::serde::f64_opt::deserialize".to_string())), + ("f64", _, true, _) => (Some("::prost_types::serde::f64::F64Serializer".to_string()), Some("::prost_types::serde::f64::F64Visitor".to_string())), + ("f32", false, false, _) => (Some("<::prost_types::serde::f32::F32Serializer as ::prost_types::serde::SerializeMethod>::serialize".to_string()), Some("::prost_types::serde::f32::deserialize".to_string())), + ("f32", true, false, _) => (Some("::prost_types::serde::f32_opt::serialize".to_string()), Some("::prost_types::serde::f32_opt::deserialize".to_string())), + ("f32", _, true, _) => (Some("::prost_types::serde::f32::F32Serializer".to_string()), Some("::prost_types::serde::f32::F32Visitor".to_string())), + ("::prost::alloc::string::String", false, false, _) => (None, Some("::prost_types::serde::string::deserialize".to_string())), + ("::prost::alloc::string::String", true, false, _) => (None, Some("::prost_types::serde::string_opt::deserialize".to_string())), + ("::prost::alloc::vec::Vec", false, false, _) => (Some("<::prost_types::serde::vec_u8::VecU8Serializer as ::prost_types::serde::SerializeMethod>::serialize".to_string()), Some("::prost_types::serde::vec_u8::deserialize".to_string())), + ("::prost::alloc::vec::Vec", true, false, _) => (Some("::prost_types::serde::vec_u8_opt::serialize".to_string()), Some("::prost_types::serde::vec_u8_opt::deserialize".to_string())), + ("::prost::alloc::vec::Vec", _, true, _) => (Some("::prost_types::serde::vec_u8::VecU8Serializer".to_string()), Some("::prost_types::serde::vec_u8::VecU8Visitor".to_string())), + (_,_, _, _) => (None, None) } } @@ -407,116 +409,137 @@ impl<'a> CodeGenerator<'a> { self.buf.push('\n'); let (key_se_opt, key_de_opt) = - self.get_custom_json_type_mappers(key_ty, key_type_name, false, true); + self.get_custom_json_type_mappers(key_ty, key_type_name, false, true, true); let (value_se_opt, value_de_opt) = - self.get_custom_json_type_mappers(value_ty, value_type_name, false, true); + self.get_custom_json_type_mappers(value_ty, value_type_name, false, true, false); push_indent(&mut self.buf, self.depth); match (key_se_opt, key_de_opt, value_se_opt, value_de_opt, map_type) { (Some(key_se), Some(key_de), Some(value_se), Some(value_de), "::std::collections::HashMap") => { self.buf.push_str( - &format!(r#"#[serde(serialize_with = "::prost_types::map_custom_to_custom_visitor::serialize::<_, {}, {}>")]"#, key_se, value_se) + &format!(r#"#[serde(serialize_with = "::prost_types::serde::map_custom_to_custom::serialize::<_, {}, {}>")]"#, key_se, value_se) ); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ); } (None, Some(key_de), None, Some(value_de), "::std::collections::HashMap") => self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ), (Some(key_se), Some(key_de), None, Some(value_de), "::std::collections::HashMap") => { self.buf.push_str( - &format!(r#"#[serde(serialize_with = "::prost_types::map_custom_visitor::serialize::<_, {}, _>")]"#, key_se) + &format!(r#"#[serde(serialize_with = "::prost_types::serde::map_custom::serialize::<_, {}, _>")]"#, key_se) ); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ); }, (Some(key_se), Some(key_de), None, None, "::std::collections::HashMap") => { self.buf.push_str( - &format!(r#"#[serde(serialize_with = "::prost_types::map_custom_visitor::serialize::<_, {}, _>")]"#, key_se) + &format!(r#"#[serde(serialize_with = "::prost_types::serde::map_custom::serialize::<_, {}, _>")]"#, key_se) ); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::map_custom_visitor::deserialize::<_, {}, _>")]"#, key_de) + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::map_custom::deserialize::<_, {}, _>")]"#, key_de) ); }, (None, Some(key_de), None, None, "::std::collections::HashMap") => self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::map_custom_visitor::deserialize::<_, {}, _>")]"#, key_de) + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::map_custom::deserialize::<_, {}, _>")]"#, key_de) ), (None, Some(key_de), Some(value_se), Some(value_de), "::std::collections::HashMap") => { self.buf.push_str( - &format!(r#"#[serde(serialize_with = "::prost_types::map_custom_serializer::serialize::<_, _, {}>")]"#, value_se) + &format!(r#"#[serde(serialize_with = "::prost_types::serde::map_custom_value::serialize::<_, _, {}>")]"#, value_se) ); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ); }, + (None, None, Some(value_se), Some(value_de), "::std::collections::HashMap") => { + self.buf.push_str( + &format!(r#"#[serde(serialize_with = "::prost_types::serde::map_custom_value::serialize::<_, _, {}>")]"#, value_se) + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::map_custom_value::deserialize::<_, _, {}>")]"#, value_de) + ); + }, + (Some(key_se), Some(key_de), Some(value_se), Some(value_de), "::prost::alloc::collections::BTreeMap") => { self.buf.push_str( - &format!(r#"#[serde(serialize_with = "::prost_types::btree_map_custom_to_custom_visitor::serialize::<_, {}, {}>")]"#, key_se, value_se) + &format!(r#"#[serde(serialize_with = "::prost_types::serde::btree_map_custom_to_custom::serialize::<_, {}, {}>")]"#, key_se, value_se) ); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::btree_map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::btree_map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ); } (None, Some(key_de), None, Some(value_de), "::prost::alloc::collections::BTreeMap") => self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::btree_map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::btree_map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ), (Some(key_se), Some(key_de), None, Some(value_de), "::prost::alloc::collections::BTreeMap") => { self.buf.push_str( - &format!(r#"#[serde(serialize_with = "::prost_types::btree_map_custom_visitor::serialize::<_, {}, _>")]"#, key_se) + &format!(r#"#[serde(serialize_with = "::prost_types::serde::btree_map_custom::serialize::<_, {}, _>")]"#, key_se) ); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::btree_map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::btree_map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ); }, (Some(key_se), Some(key_de), None, None, "::prost::alloc::collections::BTreeMap") => { self.buf.push_str( - &format!(r#"#[serde(serialize_with = "::prost_types::btree_map_custom_visitor::serialize::<_, {}, _>")]"#, key_se) + &format!(r#"#[serde(serialize_with = "::prost_types::serde::btree_map_custom::serialize::<_, {}, _>")]"#, key_se) ); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::btree_map_custom_visitor::deserialize::<_, {}, _>")]"#, key_de) + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::btree_map_custom::deserialize::<_, {}, _>")]"#, key_de) ); }, (None, Some(key_de), None, None, "::prost::alloc::collections::BTreeMap") => self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::btree_map_custom_visitor::deserialize::<_, {}, _>")]"#, key_de) + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::btree_map_custom::deserialize::<_, {}, _>")]"#, key_de) ), (None, Some(key_de), Some(value_se), Some(value_de), "::prost::alloc::collections::BTreeMap") => { self.buf.push_str( - &format!(r#"#[serde(serialize_with = "::prost_types::btree_map_custom_serializer::serialize::<_, _, {}>")]"#, value_se) + &format!(r#"#[serde(serialize_with = "::prost_types::serde::btree_map_custom_value::serialize::<_, _, {}>")]"#, value_se) ); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::btree_map_custom_to_custom_visitor::deserialize::<_, {}, {}>")]"#, key_de, value_de) + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::btree_map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ); }, + (None, None, Some(value_se), Some(value_de), "::prost::alloc::collections::BTreeMap") => { + self.buf.push_str( + &format!(r#"#[serde(serialize_with = "::prost_types::serde::btree_map_custom_value::serialize::<_, _, {}>")]"#, value_se) + ); + self.buf.push('\n'); + push_indent(&mut self.buf, self.depth); + self.buf.push_str( + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::btree_map_custom_value::deserialize::<_, _, {}>")]"#, value_de) + ); + }, (_, _, _, _, "::std::collections::HashMap") => self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::map_visitor::deserialize")]"#, + r#"#[serde(deserialize_with = "::prost_types::serde::map::deserialize")]"#, ), (_, _, _, _, "::prost::alloc::collections::BTreeMap") => self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::btree_map_visitor::deserialize")]"#, + r#"#[serde(deserialize_with = "::prost_types::serde::btree_map::deserialize")]"#, ), _ => (), } @@ -532,20 +555,23 @@ impl<'a> CodeGenerator<'a> { optional: bool, repeated: bool, json_name: &str, + is_oneof_field: bool, ) { if let None = self.config.json_mapping.get_first(fq_message_name) { return; } self.append_shared_json_field_attributes(field_name, json_name); - push_indent(&mut self.buf, self.depth); - self.buf - .push_str(r#"#[serde(skip_serializing_if = "::prost_types::is_default")]"#); - self.buf.push('\n'); + if !is_oneof_field { + push_indent(&mut self.buf, self.depth); + self.buf + .push_str(r#"#[serde(skip_serializing_if = "::prost_types::serde::is_default")]"#); + self.buf.push('\n'); + } // Add custom deserializers and optionally serializers for most primitive types // and their optional and repeated counterparts. match ( - self.get_custom_json_type_mappers(ty, type_name, optional, repeated), + self.get_custom_json_type_mappers(ty, type_name, optional, repeated, false), repeated, ) { ((Some(se), Some(de)), false) => { @@ -567,26 +593,26 @@ impl<'a> CodeGenerator<'a> { ((Some(se), Some(de)), true) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( - &format!(r#"#[serde(serialize_with = "::prost_types::repeated_visitor::serialize::<_, {}>")]"#, se), + &format!(r#"#[serde(serialize_with = "::prost_types::serde::repeated::serialize::<_, {}>")]"#, se), ); self.buf.push('\n'); push_indent(&mut self.buf, self.depth); self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, {}>")]"#, de), + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::repeated::deserialize::<_, {}>")]"#, de), ); self.buf.push('\n'); } ((None, Some(de)), true) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( - &format!(r#"#[serde(deserialize_with = "::prost_types::repeated_visitor::deserialize::<_, {}>")]"#, de), + &format!(r#"#[serde(deserialize_with = "::prost_types::serde::repeated::deserialize::<_, {}>")]"#, de), ); self.buf.push('\n'); } (_, true) => { push_indent(&mut self.buf, self.depth); self.buf.push_str( - r#"#[serde(deserialize_with = "::prost_types::vec_visitor::deserialize")]"#, + r#"#[serde(deserialize_with = "::prost_types::serde::vec::deserialize")]"#, ); self.buf.push('\n'); } @@ -710,6 +736,7 @@ impl<'a> CodeGenerator<'a> { optional, repeated, field.json_name(), + false, ); self.push_indent(); self.buf.push_str("pub "); @@ -872,8 +899,24 @@ impl<'a> CodeGenerator<'a> { )); self.append_field_attributes(&oneof_name, field.name()); - self.push_indent(); let ty = self.resolve_type(&field, fq_message_name); + let ty_or_enum = match type_ { + Type::Enum => "enum".to_string(), + _ => ty.clone(), + }; + + self.append_json_field_attributes( + &oneof_name, + &ty_or_enum, + field.type_name().to_string(), + field.name(), + false, + false, + field.json_name(), + true, + ); + + self.push_indent(); let boxed = (type_ == Type::Message || type_ == Type::Group) && self diff --git a/prost-derive/src/lib.rs b/prost-derive/src/lib.rs index 3143ffe34..bebdf9c11 100644 --- a/prost-derive/src/lib.rs +++ b/prost-derive/src/lib.rs @@ -282,7 +282,7 @@ fn try_enumeration(input: TokenStream) -> Result { // TODO(konradjniemiec): we need to default to not failing here, // and instead deriving the proto names to avoid breaking changes. - if variants.is_empty() || variants.len() != proto_names.len() { + if variants.is_empty() { panic!("Enumeration must have at least one variant"); } diff --git a/prost-types/Cargo.toml b/prost-types/Cargo.toml index 4f7903b6c..3ba14233f 100644 --- a/prost-types/Cargo.toml +++ b/prost-types/Cargo.toml @@ -18,13 +18,14 @@ doctest = false [features] default = ["std"] std = ["prost/std"] +json = ["prost/json", "chrono"] [dependencies] base64 = "0.13" bytes = { version = "1", default-features = false } serde = { version = "1", features = ["derive"] } -humantime = { version = "2.1" } prost = { version = "0.9.0", path = "..", default-features = false, features = ["prost-derive"] } +chrono = { version = "0.4", optional = true } [dev-dependencies] proptest = "1" diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 9560f504c..aa09d6eb8 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -21,6 +21,9 @@ pub mod compiler { include!("compiler.rs"); } +#[cfg(feature = "json")] +pub mod serde; + // The Protobuf `Duration` and `Timestamp` types can't delegate to the standard library equivalents // because the Protobuf versions are signed. To make them easier to work with, `From` conversions // are defined in both directions. @@ -255,2279 +258,6 @@ impl TryFrom for std::time::SystemTime { } } -#[cfg(feature = "std")] -impl serde::Serialize for Timestamp { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_str( - &humantime::format_rfc3339(std::time::SystemTime::try_from(self.clone()).unwrap()) - .to_string(), - ) - } -} - -struct TimestampVisitor; - -#[cfg(feature = "std")] -impl<'de> serde::de::Visitor<'de> for TimestampVisitor { - type Value = Timestamp; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid RFC 3339 timestamp string") - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - Ok(Timestamp::from( - humantime::parse_rfc3339(value).map_err(serde::de::Error::custom)?, - )) - } -} - -impl<'de> serde::Deserialize<'de> for Timestamp { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(TimestampVisitor) - } -} - -pub trait HasConstructor { - fn new() -> Self; -} - -pub struct MyType<'de, T: serde::de::Visitor<'de> + HasConstructor>( - >::Value, -); - -impl<'de, T> serde::Deserialize<'de> for MyType<'de, T> -where - T: serde::de::Visitor<'de> + HasConstructor, -{ - fn deserialize(deserializer: D) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer - .deserialize_any(T::new()) - .map(|x| MyType { 0: x }) - } -} - -pub fn is_default(t: &T) -> bool { - t == &T::default() -} - -pub mod vec_visitor { - struct VecVisitor<'de, T> - where - T: serde::Deserialize<'de>, - { - _vec_type: &'de std::marker::PhantomData, - } - - #[cfg(feature = "std")] - impl<'de, T: serde::Deserialize<'de>> serde::de::Visitor<'de> for VecVisitor<'de, T> { - type Value = Vec; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid String string or integer") - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let mut res = Self::Value::with_capacity(seq.size_hint().unwrap_or(0)); - loop { - match seq.next_element()? { - Some(el) => res.push(el), - None => return Ok(res), - } - } - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(Self::Value::default()) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D, T: 'de + serde::Deserialize<'de>>( - deserializer: D, - ) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(VecVisitor::<'de, T> { - _vec_type: &std::marker::PhantomData, - }) - } -} - -pub mod repeated_visitor { - struct VecVisitor<'de, T> - where - T: serde::de::Visitor<'de> + crate::HasConstructor, - { - _vec_type: &'de std::marker::PhantomData, - } - - #[cfg(feature = "std")] - impl<'de, T> serde::de::Visitor<'de> for VecVisitor<'de, T> - where - T: serde::de::Visitor<'de> + crate::HasConstructor, - { - type Value = Vec<>::Value>; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid String string or integer") - } - - fn visit_seq(self, mut seq: A) -> Result - where - A: serde::de::SeqAccess<'de>, - { - let mut res = Self::Value::with_capacity(seq.size_hint().unwrap_or(0)); - loop { - let response: std::option::Option> = seq.next_element()?; - match response { - Some(el) => res.push(el.0), - None => return Ok(res), - } - } - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(Self::Value::default()) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D, T: 'de + serde::de::Visitor<'de> + crate::HasConstructor>( - deserializer: D, - ) -> Result>::Value>, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(VecVisitor::<'de, T> { - _vec_type: &std::marker::PhantomData, - }) - } - - pub fn serialize( - value: &Vec<::Value>, - serializer: S, - ) -> Result - where - S: serde::Serializer, - F: crate::SerializeMethod, - { - use serde::ser::SerializeSeq; - let mut seq = serializer.serialize_seq(Some(value.len()))?; - for e in value { - seq.serialize_element(&crate::MySeType:: { val: e })?; - } - seq.end() - } -} - -pub mod enum_visitor { - pub struct EnumVisitor<'de, T> - where - T: ToString - + std::str::FromStr - + std::convert::Into - + std::convert::TryFrom - + Default, - { - _type: &'de std::marker::PhantomData, - } - - impl crate::HasConstructor for EnumVisitor<'_, T> - where T: ToString - + std::str::FromStr - + std::convert::Into - + std::convert::TryFrom - + Default, -{ - fn new() -> Self { - return Self {_type: &std::marker::PhantomData}; - } - } - - - #[cfg(feature = "std")] - impl<'de, T> serde::de::Visitor<'de> for EnumVisitor<'de, T> - where - T: ToString - + std::str::FromStr - + std::convert::Into - + std::convert::TryFrom - + Default, - { - type Value = i32; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid String string or integer") - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - match T::from_str(value) { - Ok(en) => Ok(en.into()), - Err(_) => Err(serde::de::Error::invalid_value( - serde::de::Unexpected::Str(value), - &self, - )), - } - } - fn visit_i64(self, value: i64) -> Result - where - E: serde::de::Error, - { - match T::try_from(value as i32) { - Ok(en) => Ok(en.into()), - Err(_) => Err(serde::de::Error::invalid_value( - serde::de::Unexpected::Signed(value as i64), - &self, - )), - } - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - self.visit_i64(value as i64) - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - self.visit_i64(value as i64) - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(Self::Value::default()) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D, T>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - T: 'de - + ToString - + std::str::FromStr - + std::convert::Into - + std::convert::TryFrom - + Default, - { - deserializer.deserialize_any(EnumVisitor::<'de, T> { - _type: &std::marker::PhantomData, - }) - } - - pub fn serialize(value: &i32, serializer: S) -> Result - where - S: serde::Serializer, - T: ToString - + std::str::FromStr - + std::convert::Into - + std::convert::TryFrom - + Default, - { - match T::try_from(*value) { - Err(_) => Err(serde::ser::Error::custom("invalid enum value")), - Ok(t) => serializer.serialize_str(&t.to_string()), - } - } - - pub struct EnumSerializer - where - T: std::convert::TryFrom + ToString, - { - _type: std::marker::PhantomData, - } - - impl crate::SerializeMethod for EnumSerializer - where - T: std::convert::TryFrom + ToString, - { - type Value = i32; - - fn serialize(value: &i32, serializer: S) -> Result - where - S: serde::Serializer, - { - match T::try_from(*value) { - Err(_) => Err(serde::ser::Error::custom("invalid enum value")), - Ok(t) => serializer.serialize_str(&t.to_string()), - } - } - } -} - -pub mod enum_opt_visitor { - struct EnumVisitor<'de, T> - where - T: ToString - + std::str::FromStr - + std::convert::Into - + std::convert::TryFrom - + Default, - { - _type: &'de std::marker::PhantomData, - } - - #[cfg(feature = "std")] - impl<'de, T> serde::de::Visitor<'de> for EnumVisitor<'de, T> - where - T: ToString - + std::str::FromStr - + std::convert::Into - + std::convert::TryFrom - + Default, - { - type Value = std::option::Option; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid string or integer representation of an enum") - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - match T::from_str(value) { - Ok(en) => Ok(Some(en.into())), - Err(_) => Err(serde::de::Error::invalid_value( - serde::de::Unexpected::Str(value), - &self, - )), - } - } - - fn visit_i64(self, value: i64) -> Result - where - E: serde::de::Error, - { - match T::try_from(value as i32) { - Ok(en) => Ok(Some(en.into())), - Err(_) => Err(serde::de::Error::invalid_value( - serde::de::Unexpected::Signed(value as i64), - &self, - )), - } - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - self.visit_i64(value as i64) - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - self.visit_i64(value as i64) - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D, T>(deserializer: D) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - T: 'de - + ToString - + std::str::FromStr - + std::convert::Into - + std::convert::TryFrom - + Default, - { - deserializer.deserialize_any(EnumVisitor::<'de, T> { - _type: &std::marker::PhantomData, - }) - } - - pub fn serialize( - value: &std::option::Option, - serializer: S, - ) -> Result - where - S: serde::Serializer, - T: ToString - + std::str::FromStr - + std::convert::Into - + std::convert::TryFrom - + Default, - { - use crate::SerializeMethod; - match value { - None => serializer.serialize_none(), - Some(enum_int) => { - crate::enum_visitor::EnumSerializer::::serialize(enum_int, serializer) - } - } - } -} - -pub mod map_custom_serializer { - pub fn serialize( - value: &std::collections::HashMap::Value>, - serializer: S, - ) -> Result - where - S: serde::Serializer, - K: serde::Serialize + std::cmp::Eq + std::hash::Hash, - G: crate::SerializeMethod, - { - use serde::ser::SerializeMap; - let mut map = serializer.serialize_map(Some(value.len()))?; - for (key, value) in value { - map.serialize_entry(&key, &crate::MySeType:: { val: value })?; - } - map.end() - } -} - -pub mod btree_map_custom_serializer { - pub fn serialize( - value: &std::collections::BTreeMap::Value>, - serializer: S, - ) -> Result - where - S: serde::Serializer, - K: serde::Serialize + std::cmp::Eq + std::cmp::Ord, - G: crate::SerializeMethod, - { - use serde::ser::SerializeMap; - let mut map = serializer.serialize_map(Some(value.len()))?; - for (key, value) in value { - map.serialize_entry(&key, &crate::MySeType:: { val: value })?; - } - map.end() - } -} - -pub mod map_custom_visitor { - struct MapVisitor<'de, T, V> - where - T: serde::de::Visitor<'de> + crate::HasConstructor, - V: serde::Deserialize<'de>, - { - _map_type: fn() -> ( - std::marker::PhantomData<&'de T>, - std::marker::PhantomData<&'de V>, - ), - } - - #[cfg(feature = "std")] - impl<'de, T, V> serde::de::Visitor<'de> for MapVisitor<'de, T, V> - where - T: serde::de::Visitor<'de> + crate::HasConstructor, - V: serde::Deserialize<'de>, - >::Value: std::cmp::Eq + std::hash::Hash, - { - type Value = std::collections::HashMap<>::Value, V>; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid String string or integer") - } - - fn visit_map(self, mut map: A) -> Result - where - A: serde::de::MapAccess<'de>, - { - let mut res = Self::Value::with_capacity(map.size_hint().unwrap_or(0)); - loop { - let response: std::option::Option<(crate::MyType<'de, T>, V)> = map.next_entry()?; - match response { - Some((key, val)) => { - res.insert(key.0, val); - } - _ => return Ok(res), - } - } - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(Self::Value::default()) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D, T, V>( - deserializer: D, - ) -> Result>::Value, V>, D::Error> - where - D: serde::Deserializer<'de>, - T: 'de + serde::de::Visitor<'de> + crate::HasConstructor, - V: 'de + serde::Deserialize<'de>, - >::Value: std::cmp::Eq + std::hash::Hash, - { - deserializer.deserialize_any(MapVisitor::<'de, T, V> { - _map_type: || (std::marker::PhantomData, std::marker::PhantomData), - }) - } - - pub fn serialize( - value: &std::collections::HashMap<::Value, V>, - serializer: S, - ) -> Result - where - S: serde::Serializer, - F: crate::SerializeMethod, - V: serde::Serialize, - ::Value: std::cmp::Eq + std::hash::Hash, - { - use serde::ser::SerializeMap; - let mut map = serializer.serialize_map(Some(value.len()))?; - for (key, value) in value { - map.serialize_entry(&crate::MySeType:: { val: key }, &value)?; - } - map.end() - } -} - -pub mod map_custom_to_custom_visitor { - struct MapVisitor<'de, T, S> - where - T: serde::de::Visitor<'de> + crate::HasConstructor, - S: serde::de::Visitor<'de> + crate::HasConstructor, - { - _map_type: fn() -> ( - std::marker::PhantomData<&'de T>, - std::marker::PhantomData<&'de S>, - ), - } - - #[cfg(feature = "std")] - impl<'de, T, S> serde::de::Visitor<'de> for MapVisitor<'de, T, S> - where - T: serde::de::Visitor<'de> + crate::HasConstructor, - S: serde::de::Visitor<'de> + crate::HasConstructor, - >::Value: std::cmp::Eq + std::hash::Hash, - { - type Value = std::collections::HashMap< - >::Value, - >::Value, - >; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid String string or integer") - } - - fn visit_map(self, mut map: A) -> Result - where - A: serde::de::MapAccess<'de>, - { - let mut res = Self::Value::with_capacity(map.size_hint().unwrap_or(0)); - loop { - let response: std::option::Option<(crate::MyType<'de, T>, crate::MyType<'de, S>)> = - map.next_entry()?; - match response { - Some((key, val)) => { - res.insert(key.0, val.0); - } - _ => return Ok(res), - } - } - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(Self::Value::default()) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D, T, S>( - deserializer: D, - ) -> Result< - std::collections::HashMap< - >::Value, - >::Value, - >, - D::Error, - > - where - D: serde::Deserializer<'de>, - T: 'de + serde::de::Visitor<'de> + crate::HasConstructor, - S: 'de + serde::de::Visitor<'de> + crate::HasConstructor, - >::Value: std::cmp::Eq + std::hash::Hash, - { - deserializer.deserialize_any(MapVisitor::<'de, T, S> { - _map_type: || (std::marker::PhantomData, std::marker::PhantomData), - }) - } - - pub fn serialize( - value: &std::collections::HashMap< - ::Value, - ::Value, - >, - serializer: S, - ) -> Result - where - S: serde::Serializer, - F: crate::SerializeMethod, - G: crate::SerializeMethod, - ::Value: std::cmp::Eq + std::hash::Hash, - { - use serde::ser::SerializeMap; - let mut map = serializer.serialize_map(Some(value.len()))?; - for (key, value) in value { - map.serialize_entry( - &crate::MySeType:: { val: key }, - &crate::MySeType:: { val: value }, - )?; - } - map.end() - } -} - -pub mod btree_map_custom_visitor { - struct MapVisitor<'de, T, V> - where - T: serde::de::Visitor<'de> + crate::HasConstructor, - V: serde::Deserialize<'de>, - { - _map_type: fn() -> ( - std::marker::PhantomData<&'de T>, - std::marker::PhantomData<&'de V>, - ), - } - - #[cfg(feature = "std")] - impl<'de, T, V> serde::de::Visitor<'de> for MapVisitor<'de, T, V> - where - T: serde::de::Visitor<'de> + crate::HasConstructor, - V: serde::Deserialize<'de>, - >::Value: std::cmp::Eq + std::cmp::Ord, - { - type Value = std::collections::BTreeMap<>::Value, V>; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid String string or integer") - } - - fn visit_map(self, mut map: A) -> Result - where - A: serde::de::MapAccess<'de>, - { - let mut res = Self::Value::new(); - loop { - let response: std::option::Option<(crate::MyType<'de, T>, V)> = map.next_entry()?; - match response { - Some((key, val)) => { - res.insert(key.0, val); - } - _ => return Ok(res), - } - } - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(Self::Value::default()) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D, T, V>( - deserializer: D, - ) -> Result>::Value, V>, D::Error> - where - D: serde::Deserializer<'de>, - T: 'de + serde::de::Visitor<'de> + crate::HasConstructor, - V: 'de + serde::Deserialize<'de>, - >::Value: std::cmp::Eq + std::cmp::Ord, - { - deserializer.deserialize_any(MapVisitor::<'de, T, V> { - _map_type: || (std::marker::PhantomData, std::marker::PhantomData), - }) - } - - pub fn serialize( - value: &std::collections::BTreeMap<::Value, V>, - serializer: S, - ) -> Result - where - S: serde::Serializer, - F: crate::SerializeMethod, - V: serde::Serialize, - ::Value: std::cmp::Eq + std::cmp::Ord, - { - use serde::ser::SerializeMap; - let mut map = serializer.serialize_map(Some(value.len()))?; - for (key, value) in value { - map.serialize_entry(&crate::MySeType:: { val: key }, &value)?; - } - map.end() - } -} - -pub mod btree_map_custom_to_custom_visitor { - struct MapVisitor<'de, T, S> - where - T: serde::de::Visitor<'de> + crate::HasConstructor, - S: serde::de::Visitor<'de> + crate::HasConstructor, - { - _map_type: fn() -> ( - std::marker::PhantomData<&'de T>, - std::marker::PhantomData<&'de S>, - ), - } - - #[cfg(feature = "std")] - impl<'de, T, S> serde::de::Visitor<'de> for MapVisitor<'de, T, S> - where - T: serde::de::Visitor<'de> + crate::HasConstructor, - S: serde::de::Visitor<'de> + crate::HasConstructor, - >::Value: std::cmp::Eq + std::cmp::Ord, - { - type Value = std::collections::BTreeMap< - >::Value, - >::Value, - >; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid String string or integer") - } - - fn visit_map(self, mut map: A) -> Result - where - A: serde::de::MapAccess<'de>, - { - let mut res = Self::Value::new(); - loop { - let response: std::option::Option<(crate::MyType<'de, T>, crate::MyType<'de, S>)> = - map.next_entry()?; - match response { - Some((key, val)) => { - res.insert(key.0, val.0); - } - _ => return Ok(res), - } - } - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(Self::Value::default()) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D, T, S>( - deserializer: D, - ) -> Result< - std::collections::BTreeMap< - >::Value, - >::Value, - >, - D::Error, - > - where - D: serde::Deserializer<'de>, - T: 'de + serde::de::Visitor<'de> + crate::HasConstructor, - S: 'de + serde::de::Visitor<'de> + crate::HasConstructor, - >::Value: std::cmp::Eq + std::cmp::Ord, - { - deserializer.deserialize_any(MapVisitor::<'de, T, S> { - _map_type: || (std::marker::PhantomData, std::marker::PhantomData), - }) - } - - pub fn serialize( - value: &std::collections::BTreeMap< - ::Value, - ::Value, - >, - serializer: S, - ) -> Result - where - S: serde::Serializer, - F: crate::SerializeMethod, - G: crate::SerializeMethod, - ::Value: std::cmp::Eq + std::cmp::Ord, - { - use serde::ser::SerializeMap; - let mut map = serializer.serialize_map(Some(value.len()))?; - for (key, value) in value { - map.serialize_entry( - &crate::MySeType:: { val: key }, - &crate::MySeType:: { val: value }, - )?; - } - map.end() - } -} - -pub trait SerializeMethod { - type Value; - fn serialize(value: &Self::Value, serializer: S) -> Result - where - S: serde::Serializer; -} - -pub struct MySeType<'a, T> -where - T: SerializeMethod, -{ - val: &'a ::Value, -} - -impl<'a, T: SerializeMethod> serde::Serialize for MySeType<'a, T> { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - T::serialize(self.val, serializer) - } -} - -pub mod map_visitor { - struct MapVisitor<'de, K, V> - where - K: serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, - V: serde::Deserialize<'de>, - { - _key_type: &'de std::marker::PhantomData, - _value_type: &'de std::marker::PhantomData, - } - - #[cfg(feature = "std")] - impl< - 'de, - K: serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, - V: serde::Deserialize<'de>, - > serde::de::Visitor<'de> for MapVisitor<'de, K, V> - { - type Value = std::collections::HashMap; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid String string or integer") - } - - fn visit_map(self, mut map: A) -> Result - where - A: serde::de::MapAccess<'de>, - { - let mut res = Self::Value::with_capacity(map.size_hint().unwrap_or(0)); - loop { - match map.next_entry()? { - Some((k, v)) => { - res.insert(k, v); - } - None => return Ok(res), - } - } - } - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(Self::Value::default()) - } - } - - #[cfg(feature = "std")] - pub fn deserialize< - 'de, - D, - K: 'de + serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, - V: 'de + serde::Deserialize<'de>, - >( - deserializer: D, - ) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(MapVisitor::<'de, K, V> { - _key_type: &std::marker::PhantomData, - _value_type: &std::marker::PhantomData, - }) - } -} - -pub mod btree_map_visitor { - struct MapVisitor<'de, K, V> - where - K: serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, - V: serde::Deserialize<'de>, - { - _key_type: &'de std::marker::PhantomData, - _value_type: &'de std::marker::PhantomData, - } - - #[cfg(feature = "std")] - impl< - 'de, - K: serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, - V: serde::Deserialize<'de>, - > serde::de::Visitor<'de> for MapVisitor<'de, K, V> - { - type Value = std::collections::BTreeMap; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid String string or integer") - } - - fn visit_map(self, mut map: A) -> Result - where - A: serde::de::MapAccess<'de>, - { - let mut res = Self::Value::new(); - loop { - match map.next_entry()? { - Some((k, v)) => { - res.insert(k, v); - } - None => return Ok(res), - } - } - } - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(Self::Value::default()) - } - } - - #[cfg(feature = "std")] - pub fn deserialize< - 'de, - D, - K: 'de + serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, - V: 'de + serde::Deserialize<'de>, - >( - deserializer: D, - ) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(MapVisitor::<'de, K, V> { - _key_type: &std::marker::PhantomData, - _value_type: &std::marker::PhantomData, - }) - } -} - -pub mod string_visitor { - struct StringVisitor; - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for StringVisitor { - type Value = std::string::String; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid string") - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - return Ok(value.to_string()); - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(Self::Value::default()) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(StringVisitor) - } -} - -pub mod string_opt_visitor { - struct StringVisitor; - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for StringVisitor { - type Value = std::option::Option; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid String string or integer") - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - return Ok(Some(value.to_string())); - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D>( - deserializer: D, - ) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(StringVisitor) - } -} - -pub mod bool_visitor { - pub struct BoolVisitor; - - impl crate::HasConstructor for BoolVisitor { - fn new() -> Self { - return Self {}; - } - } - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for BoolVisitor { - type Value = bool; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid Bool string or integer") - } - - fn visit_bool(self, value: bool) -> Result - where - E: serde::de::Error, - { - return Ok(value); - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(bool::default()) - } - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(BoolVisitor) - } -} - -pub mod bool_opt_visitor { - struct BoolVisitor; - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for BoolVisitor { - type Value = std::option::Option; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid Bool string or integer") - } - - fn visit_bool(self, value: bool) -> Result - where - E: serde::de::Error, - { - return Ok(Some(value)); - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(BoolVisitor) - } -} - -pub mod i32_visitor { - pub struct I32Visitor; - - impl crate::HasConstructor for I32Visitor { - fn new() -> I32Visitor { - return I32Visitor {}; - } - } - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for I32Visitor { - type Value = i32; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid I32 string or integer") - } - - fn visit_i64(self, value: i64) -> Result - where - E: serde::de::Error, - { - use std::convert::TryFrom; - i32::try_from(value).map_err(E::custom) - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - if (value.trunc() - value).abs() > f64::EPSILON - || value > i32::MAX as f64 - || value < i32::MIN as f64 - { - Err(serde::de::Error::invalid_type( - serde::de::Unexpected::Float(value), - &self, - )) - } else { - // This is a round number in the proper range, we can cast just fine. - Ok(value as i32) - } - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - use std::convert::TryFrom; - i32::try_from(value).map_err(E::custom) - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - // If we have scientific notation or a decimal, parse float first. - if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value - .parse::() - .map_err(E::custom) - .and_then(|x| self.visit_f64(x)) - } else { - value.parse::().map_err(E::custom) - } - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(i32::default()) - } - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(I32Visitor) - } -} - -pub mod i32_opt_visitor { - struct I32Visitor; - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for I32Visitor { - type Value = std::option::Option; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid I32 string or integer") - } - - fn visit_i64(self, value: i64) -> Result - where - E: serde::de::Error, - { - use std::convert::TryFrom; - i32::try_from(value).map(|x| Some(x)).map_err(E::custom) - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - if (value.trunc() - value).abs() > f64::EPSILON - || value > i32::MAX as f64 - || value < i32::MIN as f64 - { - Err(serde::de::Error::invalid_type( - serde::de::Unexpected::Float(value), - &self, - )) - } else { - // This is a round number in the proper range, we can cast just fine. - Ok(Some(value as i32)) - } - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - use std::convert::TryFrom; - i32::try_from(value).map(|x| Some(x)).map_err(E::custom) - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - // If we have scientific notation or a decimal, parse float first. - if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value - .parse::() - .map_err(E::custom) - .and_then(|x| self.visit_f64(x)) - } else { - value.parse::().map(|x| Some(x)).map_err(E::custom) - } - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(I32Visitor) - } -} - -pub mod i64_visitor { - pub struct I64Visitor; - - impl crate::HasConstructor for I64Visitor { - fn new() -> Self { - return Self {}; - } - } - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for I64Visitor { - type Value = i64; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid I64 string or integer") - } - - fn visit_i64(self, value: i64) -> Result - where - E: serde::de::Error, - { - Ok(value as i64) - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - if (value.trunc() - value).abs() > f64::EPSILON - || value > i64::MAX as f64 - || value < i64::MIN as f64 - { - Err(serde::de::Error::invalid_type( - serde::de::Unexpected::Float(value), - &self, - )) - } else { - // This is a round number in the proper range, we can cast just fine. - Ok(value as i64) - } - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - use std::convert::TryFrom; - i64::try_from(value).map_err(E::custom) - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - // If we have scientific notation or a decimal, parse float first. - if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value - .parse::() - .map_err(E::custom) - .and_then(|x| self.visit_f64(x)) - } else { - value.parse::().map_err(E::custom) - } - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(i64::default()) - } - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(I64Visitor) - } -} - -pub mod i64_opt_visitor { - struct I64Visitor; - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for I64Visitor { - type Value = std::option::Option; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid I64 string or integer") - } - - fn visit_i64(self, value: i64) -> Result - where - E: serde::de::Error, - { - Ok(Some(value as i64)) - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - if (value.trunc() - value).abs() > f64::EPSILON - || value > i64::MAX as f64 - || value < i64::MIN as f64 - { - Err(serde::de::Error::invalid_type( - serde::de::Unexpected::Float(value), - &self, - )) - } else { - // This is a round number in the proper range, we can cast just fine. - Ok(Some(value as i64)) - } - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - use std::convert::TryFrom; - i64::try_from(value).map(|x| Some(x)).map_err(E::custom) - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - // If we have scientific notation or a decimal, parse float first. - if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value - .parse::() - .map_err(E::custom) - .and_then(|x| self.visit_f64(x)) - } else { - value.parse::().map(|x| Some(x)).map_err(E::custom) - } - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(I64Visitor) - } -} - -pub mod u32_visitor { - pub struct U32Visitor; - - impl crate::HasConstructor for U32Visitor { - fn new() -> Self { - return Self {}; - } - } - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for U32Visitor { - type Value = u32; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid U32 string or integer") - } - - fn visit_i64(self, value: i64) -> Result - where - E: serde::de::Error, - { - use std::convert::TryFrom; - u32::try_from(value).map_err(E::custom) - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - if (value.trunc() - value).abs() > f64::EPSILON - || value < 0.0 - || value > u32::MAX as f64 - { - Err(serde::de::Error::invalid_type( - serde::de::Unexpected::Float(value), - &self, - )) - } else { - // This is a round number in the proper range, we can cast just fine. - Ok(value as u32) - } - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - use std::convert::TryFrom; - u32::try_from(value).map_err(E::custom) - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - // If we have scientific notation or a decimal, parse float first. - if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value - .parse::() - .map_err(E::custom) - .and_then(|x| self.visit_f64(x)) - } else { - value.parse::().map_err(E::custom) - } - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(u32::default()) - } - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(U32Visitor) - } -} - -pub mod u32_opt_visitor { - struct U32Visitor; - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for U32Visitor { - type Value = std::option::Option; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid U32 string or integer") - } - - fn visit_i64(self, value: i64) -> Result - where - E: serde::de::Error, - { - use std::convert::TryFrom; - u32::try_from(value).map(|x| Some(x)).map_err(E::custom) - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - if (value.trunc() - value).abs() > f64::EPSILON - || value < 0.0 - || value > u32::MAX as f64 - { - Err(serde::de::Error::invalid_type( - serde::de::Unexpected::Float(value), - &self, - )) - } else { - // This is a round number in the proper range, we can cast just fine. - Ok(Some(value as u32)) - } - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - use std::convert::TryFrom; - u32::try_from(value).map(|x| Some(x)).map_err(E::custom) - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - // If we have scientific notation or a decimal, parse float first. - if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value - .parse::() - .map_err(E::custom) - .and_then(|x| self.visit_f64(x)) - } else { - value.parse::().map(|x| Some(x)).map_err(E::custom) - } - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(U32Visitor) - } -} - -pub mod u64_visitor { - pub struct U64Visitor; - - impl crate::HasConstructor for U64Visitor { - fn new() -> Self { - return Self {}; - } - } - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for U64Visitor { - type Value = u64; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid U64 string or integer") - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - Ok(value as u64) - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - if (value.trunc() - value).abs() > f64::EPSILON - || value < 0.0 - || value > u64::MAX as f64 - { - Err(serde::de::Error::invalid_type( - serde::de::Unexpected::Float(value), - &self, - )) - } else { - // This is a round number in the proper range, we can cast just fine. - Ok(value as u64) - } - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - // If we have scientific notation or a decimal, parse float first. - if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value - .parse::() - .map_err(E::custom) - .and_then(|x| self.visit_f64(x)) - } else { - value.parse::().map_err(E::custom) - } - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(u64::default()) - } - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(U64Visitor) - } -} - -pub mod u64_opt_visitor { - struct U64Visitor; - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for U64Visitor { - type Value = std::option::Option; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid U64 string or integer") - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - Ok(Some(value as u64)) - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - if (value.trunc() - value).abs() > f64::EPSILON - || value < 0.0 - || value > u64::MAX as f64 - { - Err(serde::de::Error::invalid_type( - serde::de::Unexpected::Float(value), - &self, - )) - } else { - // This is a round number, we can cast just fine. - Ok(Some(value as u64)) - } - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - // If we have scientific notation or a decimal, parse float first. - if value.contains('e') || value.contains('E') || value.ends_with(".0") { - value - .parse::() - .map_err(E::custom) - .and_then(|x| self.visit_f64(x)) - } else { - value.parse::().map(|x| Some(x)).map_err(E::custom) - } - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(U64Visitor) - } -} - -pub mod f64_visitor { - pub struct F64Visitor; - - impl crate::HasConstructor for F64Visitor { - fn new() -> F64Visitor { - return F64Visitor {}; - } - } - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for F64Visitor { - type Value = f64; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid F64 string or integer") - } - - fn visit_i64(self, value: i64) -> Result - where - E: serde::de::Error, - { - Ok(value as f64) - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - Ok(value as f64) - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - Ok(value as f64) - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - match value { - "NaN" => Ok(f64::NAN), - "Infinity" => Ok(f64::INFINITY), - "-Infinity" => Ok(f64::NEG_INFINITY), - _ => value.parse::().map_err(E::custom), - } - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(f64::default()) - } - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(F64Visitor) - } - - pub struct F64Serializer; - - impl crate::SerializeMethod for F64Serializer { - type Value = f64; - #[cfg(feature = "std")] - fn serialize(value: &Self::Value, serializer: S) -> Result - where - S: serde::Serializer, - { - if value.is_nan() { - serializer.serialize_str("NaN") - } else if value.is_infinite() && value.is_sign_negative() { - serializer.serialize_str("-Infinity") - } else if value.is_infinite() { - serializer.serialize_str("Infinity") - } else { - serializer.serialize_f64(*value) - } - } - } -} - -pub mod f64_opt_visitor { - struct F64Visitor; - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for F64Visitor { - type Value = std::option::Option; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid F64 string or integer") - } - - fn visit_i64(self, value: i64) -> Result - where - E: serde::de::Error, - { - Ok(Some(value as f64)) - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - Ok(Some(value as f64)) - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - Ok(Some(value as f64)) - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - match value { - "NaN" => Ok(Some(f64::NAN)), - "Infinity" => Ok(Some(f64::INFINITY)), - "-Infinity" => Ok(Some(f64::NEG_INFINITY)), - _ => value.parse::().map(|x| Some(x)).map_err(E::custom), - } - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(F64Visitor) - } - - #[cfg(feature = "std")] - pub fn serialize(value: &std::option::Option, serializer: S) -> Result - where - S: serde::Serializer, - { - use crate::SerializeMethod; - match value { - None => serializer.serialize_none(), - Some(double) => crate::f64_visitor::F64Serializer::serialize(double, serializer), - } - } -} - -pub mod f32_visitor { - pub struct F32Visitor; - - impl crate::HasConstructor for F32Visitor { - fn new() -> F32Visitor { - return F32Visitor {}; - } - } - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for F32Visitor { - type Value = f32; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid F32 string or integer") - } - - fn visit_i64(self, value: i64) -> Result - where - E: serde::de::Error, - { - Ok(value as f32) - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - if value < f32::MIN as f64 || value > f32::MAX as f64 { - Err(serde::de::Error::invalid_type( - serde::de::Unexpected::Float(value), - &self, - )) - } else { - Ok(value as f32) - } - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - Ok(value as f32) - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - match value { - "NaN" => Ok(f32::NAN), - "Infinity" => Ok(f32::INFINITY), - "-Infinity" => Ok(f32::NEG_INFINITY), - _ => value.parse::().map_err(E::custom), - } - } - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(f32::default()) - } - } - - pub fn deserialize<'de, D>(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(F32Visitor) - } - - pub struct F32Serializer; - - impl crate::SerializeMethod for F32Serializer { - type Value = f32; - - #[cfg(feature = "std")] - fn serialize(value: &f32, serializer: S) -> Result - where - S: serde::Serializer, - { - if value.is_nan() { - serializer.serialize_str("NaN") - } else if value.is_infinite() && value.is_sign_negative() { - serializer.serialize_str("-Infinity") - } else if value.is_infinite() { - serializer.serialize_str("Infinity") - } else { - serializer.serialize_f32(*value) - } - } - } -} - -pub mod f32_opt_visitor { - struct F32Visitor; - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for F32Visitor { - type Value = std::option::Option; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid F32 string or integer") - } - - fn visit_i64(self, value: i64) -> Result - where - E: serde::de::Error, - { - Ok(Some(value as f32)) - } - - fn visit_f64(self, value: f64) -> Result - where - E: serde::de::Error, - { - if value < f32::MIN as f64 || value > f32::MAX as f64 { - Err(serde::de::Error::invalid_type( - serde::de::Unexpected::Float(value), - &self, - )) - } else { - Ok(Some(value as f32)) - } - } - - fn visit_u64(self, value: u64) -> Result - where - E: serde::de::Error, - { - Ok(Some(value as f32)) - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - match value { - "NaN" => Ok(Some(f32::NAN)), - "Infinity" => Ok(Some(f32::INFINITY)), - "-Infinity" => Ok(Some(f32::NEG_INFINITY)), - _ => value.parse::().map(|x| Some(x)).map_err(E::custom), - } - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(F32Visitor) - } - - #[cfg(feature = "std")] - pub fn serialize(value: &std::option::Option, serializer: S) -> Result - where - S: serde::Serializer, - { - use crate::SerializeMethod; - match value { - None => serializer.serialize_none(), - Some(float) => crate::f32_visitor::F32Serializer::serialize(float, serializer), - } - } -} - -pub mod vec_u8_visitor { - pub struct VecU8Visitor; - - impl crate::HasConstructor for VecU8Visitor { - fn new() -> Self { - return Self {}; - } - } - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for VecU8Visitor { - type Value = ::prost::alloc::vec::Vec; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid Base64 encoded string") - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - base64::decode(value).map_err(E::custom) - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(Self::Value::default()) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D>(deserializer: D) -> Result<::prost::alloc::vec::Vec, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(VecU8Visitor) - } - - pub struct VecU8Serializer; - - impl crate::SerializeMethod for VecU8Serializer { - type Value = ::prost::alloc::vec::Vec; - - #[cfg(feature = "std")] - fn serialize(value: &Self::Value, serializer: S) -> Result - where - S: serde::Serializer, - { - serializer.serialize_str(&base64::encode(value)) - } - } -} - -pub mod vec_u8_opt_visitor { - struct VecU8Visitor; - - #[cfg(feature = "std")] - impl<'de> serde::de::Visitor<'de> for VecU8Visitor { - type Value = std::option::Option<::prost::alloc::vec::Vec>; - - fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { - formatter.write_str("a valid Base64 encoded string") - } - - fn visit_str(self, value: &str) -> Result - where - E: serde::de::Error, - { - base64::decode(value) - .map(|str| Some(str)) - .map_err(E::custom) - } - - fn visit_unit(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - - fn visit_none(self) -> Result - where - E: serde::de::Error, - { - Ok(None) - } - } - - #[cfg(feature = "std")] - pub fn deserialize<'de, D>( - deserializer: D, - ) -> Result>, D::Error> - where - D: serde::Deserializer<'de>, - { - deserializer.deserialize_any(VecU8Visitor) - } - - #[cfg(feature = "std")] - pub fn serialize( - value: &std::option::Option<::prost::alloc::vec::Vec>, - serializer: S, - ) -> Result - where - S: serde::Serializer, - { - use crate::SerializeMethod; - match value { - None => serializer.serialize_none(), - Some(value) => crate::vec_u8_visitor::VecU8Serializer::serialize(value, serializer), - } - } -} - #[cfg(test)] mod tests { use std::time::{Duration, SystemTime, UNIX_EPOCH}; diff --git a/prost-types/src/protobuf.rs b/prost-types/src/protobuf.rs index 2afaeda0c..01e9bee69 100644 --- a/prost-types/src/protobuf.rs +++ b/prost-types/src/protobuf.rs @@ -1030,7 +1030,7 @@ pub mod generated_code_info { /// } /// // TODO(konradjniemiec) proper serialization -#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] +#[derive(Clone, PartialEq, ::prost::Message, ::serde::Serialize, ::serde::Deserialize)] pub struct Any { /// A URL/resource name that uniquely identifies the type of the serialized /// protocol buffer message. This string must contain at least @@ -1503,8 +1503,7 @@ pub struct Mixin { /// microsecond should be expressed in JSON format as "3.000001s". /// /// -// TODO(konradjniemiec) proper serialization -#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] +#[derive(Clone, PartialEq, ::prost::Message)] pub struct Duration { /// Signed seconds of the span of time. Must be from -315,576,000,000 /// to +315,576,000,000 inclusive. Note: these bounds are computed from: @@ -1720,7 +1719,7 @@ pub struct Duration { /// request should verify the included field paths, and return an /// `INVALID_ARGUMENT` error if any path is unmappable. // TODO(konradjniemiec) proper serialization -#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] +#[derive(Clone, PartialEq, ::prost::Message, ::serde::Serialize, ::serde::Deserialize)] pub struct FieldMask { /// The set of field mask paths. #[prost(string, repeated, tag="1")] @@ -1734,7 +1733,7 @@ pub struct FieldMask { /// with the proto support for the language. /// /// The JSON representation for `Struct` is JSON object. -#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] +#[derive(Clone, PartialEq, ::prost::Message, ::serde::Serialize, ::serde::Deserialize)] pub struct Struct { /// Unordered map of dynamically typed values. #[prost(btree_map="string, message", tag="1")] @@ -1746,7 +1745,7 @@ pub struct Struct { /// variants, absence of any variant indicates an error. /// /// The JSON representation for `Value` is JSON value. -#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] +#[derive(Clone, PartialEq, ::prost::Message, ::serde::Serialize, ::serde::Deserialize)] pub struct Value { /// The kind of value. #[prost(oneof="value::Kind", tags="1, 2, 3, 4, 5, 6")] @@ -1755,7 +1754,7 @@ pub struct Value { /// Nested message and enum types in `Value`. pub mod value { /// The kind of value. - #[derive(Clone, PartialEq, ::prost::Oneof, serde::Serialize, serde::Deserialize)] + #[derive(Clone, PartialEq, ::prost::Oneof, ::serde::Serialize, ::serde::Deserialize)] pub enum Kind { /// Represents a null value. #[prost(enumeration="super::NullValue", tag="1")] @@ -1780,7 +1779,7 @@ pub mod value { /// `ListValue` is a wrapper around a repeated field of values. /// /// The JSON representation for `ListValue` is JSON array. -#[derive(Clone, PartialEq, ::prost::Message, serde::Serialize, serde::Deserialize)] +#[derive(Clone, PartialEq, ::prost::Message, ::serde::Serialize, ::serde::Deserialize)] pub struct ListValue { /// Repeated field of dynamically typed values. #[prost(message, repeated, tag="1")] diff --git a/prost-types/src/serde.rs b/prost-types/src/serde.rs new file mode 100644 index 000000000..d6c6e7ef8 --- /dev/null +++ b/prost-types/src/serde.rs @@ -0,0 +1,2546 @@ +#[cfg(feature = "std")] +impl ::serde::Serialize for crate::Timestamp { + fn serialize(&self, serializer: S) -> Result + where + S: ::serde::Serializer, + { + use std::convert::TryInto; + serializer.serialize_str( + &chrono::DateTime::::from_utc( + chrono::NaiveDateTime::from_timestamp(self.seconds, self.nanos.try_into().unwrap()), + chrono::Utc, + ) + .to_rfc3339(), + ) + } +} + +struct TimestampVisitor; + +#[cfg(feature = "std")] +impl<'de> ::serde::de::Visitor<'de> for TimestampVisitor { + type Value = crate::Timestamp; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid RFC 3339 timestamp string") + } + + fn visit_str(self, value: &str) -> Result + where + E: ::serde::de::Error, + { + use std::convert::TryInto; + let dt = chrono::DateTime::parse_from_rfc3339(value) + .map_err(::serde::de::Error::custom)? + .naive_utc(); + Ok(crate::Timestamp::from( + std::time::UNIX_EPOCH + + std::time::Duration::new( + dt.timestamp() + .try_into() + .map_err(::serde::de::Error::custom)?, + dt.timestamp_subsec_nanos(), + ), + )) + } +} + +impl<'de> ::serde::Deserialize<'de> for crate::Timestamp { + fn deserialize(deserializer: D) -> Result + where + D: ::serde::Deserializer<'de>, + { + deserializer.deserialize_any(TimestampVisitor) + } +} + +impl ::serde::Serialize for crate::Duration { + fn serialize(&self, serializer: S) -> Result + where + S: ::serde::Serializer, + { + let mut nanos = self.nanos; + if nanos < 0 { + nanos = -nanos; + } + + while nanos > 0 && nanos % 1_000 == 0 { + nanos /= 1_000; + } + + if nanos == 0 { + serializer.serialize_str(&format!("{}s", self.seconds)) + } else { + serializer.serialize_str(&format!("{}.{}s", self.seconds, nanos)) + } + } +} + +struct DurationVisitor; + +#[cfg(feature = "std")] +impl<'de> ::serde::de::Visitor<'de> for DurationVisitor { + type Value = crate::Duration; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid duration string") + } + + fn visit_str(self, value: &str) -> Result + where + E: ::serde::de::Error, + { + let value = match value.strip_suffix('s') { + Some(value) => value, + None => { + return Err(::serde::de::Error::custom(format!( + "invalid duration: {}", + value + ))) + } + }; + let seconds = value.parse::().map_err(::serde::de::Error::custom)?; + + if seconds.is_sign_negative() { + let crate::Duration { seconds, nanos } = + std::time::Duration::from_secs_f64(-seconds).into(); + + Ok(crate::Duration { + seconds: -seconds, + nanos: -nanos, + }) + } else { + Ok(std::time::Duration::from_secs_f64(seconds).into()) + } + } +} + +impl<'de> ::serde::Deserialize<'de> for crate::Duration { + fn deserialize(deserializer: D) -> Result + where + D: ::serde::Deserializer<'de>, + { + deserializer.deserialize_any(DurationVisitor) + } +} + +pub trait HasConstructor { + fn new() -> Self; +} + +pub struct MyType<'de, T: serde::de::Visitor<'de> + HasConstructor>( + >::Value, +); + +impl<'de, T> serde::Deserialize<'de> for MyType<'de, T> +where + T: serde::de::Visitor<'de> + HasConstructor, +{ + fn deserialize(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer + .deserialize_any(T::new()) + .map(|x| MyType { 0: x }) + } +} + +pub fn is_default(t: &T) -> bool { + t == &T::default() +} + +pub mod vec { + struct VecVisitor<'de, T> + where + T: serde::Deserialize<'de>, + { + _vec_type: &'de std::marker::PhantomData, + } + + #[cfg(feature = "std")] + impl<'de, T: serde::Deserialize<'de>> serde::de::Visitor<'de> for VecVisitor<'de, T> { + type Value = Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid list") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut res = Self::Value::with_capacity(seq.size_hint().unwrap_or(0)); + loop { + match seq.next_element()? { + Some(el) => res.push(el), + None => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T: 'de + serde::Deserialize<'de>>( + deserializer: D, + ) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(VecVisitor::<'de, T> { + _vec_type: &std::marker::PhantomData, + }) + } +} + +pub mod repeated { + struct VecVisitor<'de, T> + where + T: serde::de::Visitor<'de> + crate::serde::HasConstructor, + { + _vec_type: &'de std::marker::PhantomData, + } + + #[cfg(feature = "std")] + impl<'de, T> serde::de::Visitor<'de> for VecVisitor<'de, T> + where + T: serde::de::Visitor<'de> + crate::serde::HasConstructor, + { + type Value = Vec<>::Value>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid repeated field") + } + + fn visit_seq(self, mut seq: A) -> Result + where + A: serde::de::SeqAccess<'de>, + { + let mut res = Self::Value::with_capacity(seq.size_hint().unwrap_or(0)); + loop { + let response: std::option::Option> = + seq.next_element()?; + match response { + Some(el) => res.push(el.0), + None => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T: 'de + serde::de::Visitor<'de> + crate::serde::HasConstructor>( + deserializer: D, + ) -> Result>::Value>, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(VecVisitor::<'de, T> { + _vec_type: &std::marker::PhantomData, + }) + } + + pub fn serialize( + value: &Vec<::Value>, + serializer: S, + ) -> Result + where + S: serde::Serializer, + F: crate::serde::SerializeMethod, + { + use serde::ser::SerializeSeq; + let mut seq = serializer.serialize_seq(Some(value.len()))?; + for e in value { + seq.serialize_element(&crate::serde::MySeType:: { val: e })?; + } + seq.end() + } +} + +pub mod enum_serde { + pub struct EnumVisitor<'de, T> + where + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + _type: &'de std::marker::PhantomData, + } + + impl crate::serde::HasConstructor for EnumVisitor<'_, T> + where + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + fn new() -> Self { + return Self { + _type: &std::marker::PhantomData, + }; + } + } + + #[cfg(feature = "std")] + impl<'de, T> serde::de::Visitor<'de> for EnumVisitor<'de, T> + where + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + type Value = i32; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid string or integer representation of an enum") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match T::from_str(value) { + Ok(en) => Ok(en.into()), + Err(_) => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(value), + &self, + )), + } + } + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + match T::try_from(value as i32) { + Ok(en) => Ok(en.into()), + Err(_) => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Signed(value as i64), + &self, + )), + } + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + self.visit_i64(value as i64) + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + self.visit_i64(value as i64) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + T: 'de + + ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + deserializer.deserialize_any(EnumVisitor::<'de, T> { + _type: &std::marker::PhantomData, + }) + } + + pub fn serialize(value: &i32, serializer: S) -> Result + where + S: serde::Serializer, + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + match T::try_from(*value) { + Err(_) => Err(serde::ser::Error::custom("invalid enum value")), + Ok(t) => serializer.serialize_str(&t.to_string()), + } + } + + pub struct EnumSerializer + where + T: std::convert::TryFrom + ToString, + { + _type: std::marker::PhantomData, + } + + impl crate::serde::SerializeMethod for EnumSerializer + where + T: std::convert::TryFrom + ToString, + { + type Value = i32; + + fn serialize(value: &i32, serializer: S) -> Result + where + S: serde::Serializer, + { + match T::try_from(*value) { + Err(_) => Err(serde::ser::Error::custom("invalid enum value")), + Ok(t) => serializer.serialize_str(&t.to_string()), + } + } + } +} + +pub mod enum_opt { + struct EnumVisitor<'de, T> + where + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + _type: &'de std::marker::PhantomData, + } + + #[cfg(feature = "std")] + impl<'de, T> serde::de::Visitor<'de> for EnumVisitor<'de, T> + where + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid string or integer representation of an enum") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match T::from_str(value) { + Ok(en) => Ok(Some(en.into())), + Err(_) => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Str(value), + &self, + )), + } + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + match T::try_from(value as i32) { + Ok(en) => Ok(Some(en.into())), + Err(_) => Err(serde::de::Error::invalid_value( + serde::de::Unexpected::Signed(value as i64), + &self, + )), + } + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + self.visit_i64(value as i64) + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + self.visit_i64(value as i64) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + T: 'de + + ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + deserializer.deserialize_any(EnumVisitor::<'de, T> { + _type: &std::marker::PhantomData, + }) + } + + pub fn serialize( + value: &std::option::Option, + serializer: S, + ) -> Result + where + S: serde::Serializer, + T: ToString + + std::str::FromStr + + std::convert::Into + + std::convert::TryFrom + + Default, + { + use crate::serde::SerializeMethod; + match value { + None => serializer.serialize_none(), + Some(enum_int) => { + crate::serde::enum_serde::EnumSerializer::::serialize(enum_int, serializer) + } + } + } +} + +pub mod btree_map_custom_value { + struct MapVisitor<'de, T, V> + where + T: serde::Deserialize<'de>, + V: serde::de::Visitor<'de> + crate::serde::HasConstructor, + { + _map_type: fn() -> ( + std::marker::PhantomData<&'de T>, + std::marker::PhantomData<&'de V>, + ), + } + + #[cfg(feature = "std")] + impl<'de, T, V> serde::de::Visitor<'de> for MapVisitor<'de, T, V> + where + T: serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, + V: serde::de::Visitor<'de> + crate::serde::HasConstructor, + { + type Value = std::collections::BTreeMap>::Value>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid map") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut res = Self::Value::new(); + loop { + let response: std::option::Option<(T, crate::serde::MyType<'de, V>)> = + map.next_entry()?; + match response { + Some((key, val)) => { + res.insert(key, val.0); + } + _ => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T, V>( + deserializer: D, + ) -> Result>::Value>, D::Error> + where + D: serde::Deserializer<'de>, + T: 'de + serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, + V: 'de + serde::de::Visitor<'de> + crate::serde::HasConstructor, + { + deserializer.deserialize_any(MapVisitor::<'de, T, V> { + _map_type: || (std::marker::PhantomData, std::marker::PhantomData), + }) + } + + pub fn serialize( + value: &std::collections::BTreeMap::Value>, + serializer: S, + ) -> Result + where + S: serde::Serializer, + T: serde::Serialize + std::cmp::Eq + std::cmp::Ord, + F: crate::serde::SerializeMethod, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(value.len()))?; + for (key, value) in value { + map.serialize_entry(&key, &crate::serde::MySeType:: { val: value })?; + } + map.end() + } +} + +pub mod map_custom_value { + struct MapVisitor<'de, T, V> + where + T: serde::Deserialize<'de>, + V: serde::de::Visitor<'de> + crate::serde::HasConstructor, + { + _map_type: fn() -> ( + std::marker::PhantomData<&'de T>, + std::marker::PhantomData<&'de V>, + ), + } + + #[cfg(feature = "std")] + impl<'de, T, V> serde::de::Visitor<'de> for MapVisitor<'de, T, V> + where + T: serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, + V: serde::de::Visitor<'de> + crate::serde::HasConstructor, + { + type Value = std::collections::HashMap>::Value>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid map") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut res = Self::Value::with_capacity(map.size_hint().unwrap_or(0)); + loop { + let response: std::option::Option<(T, crate::serde::MyType<'de, V>)> = + map.next_entry()?; + match response { + Some((key, val)) => { + res.insert(key, val.0); + } + _ => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T, V>( + deserializer: D, + ) -> Result>::Value>, D::Error> + where + D: serde::Deserializer<'de>, + T: 'de + serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, + V: 'de + serde::de::Visitor<'de> + crate::serde::HasConstructor, + { + deserializer.deserialize_any(MapVisitor::<'de, T, V> { + _map_type: || (std::marker::PhantomData, std::marker::PhantomData), + }) + } + + pub fn serialize( + value: &std::collections::HashMap::Value>, + serializer: S, + ) -> Result + where + S: serde::Serializer, + T: serde::Serialize + std::cmp::Eq + std::hash::Hash, + F: crate::serde::SerializeMethod, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(value.len()))?; + for (key, value) in value { + map.serialize_entry(&key, &crate::serde::MySeType:: { val: value })?; + } + map.end() + } +} + +pub mod map_custom { + struct MapVisitor<'de, T, V> + where + T: serde::de::Visitor<'de> + crate::serde::HasConstructor, + V: serde::Deserialize<'de>, + { + _map_type: fn() -> ( + std::marker::PhantomData<&'de T>, + std::marker::PhantomData<&'de V>, + ), + } + + #[cfg(feature = "std")] + impl<'de, T, V> serde::de::Visitor<'de> for MapVisitor<'de, T, V> + where + T: serde::de::Visitor<'de> + crate::serde::HasConstructor, + V: serde::Deserialize<'de>, + >::Value: std::cmp::Eq + std::hash::Hash, + { + type Value = std::collections::HashMap<>::Value, V>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid map") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut res = Self::Value::with_capacity(map.size_hint().unwrap_or(0)); + loop { + let response: std::option::Option<(crate::serde::MyType<'de, T>, V)> = + map.next_entry()?; + match response { + Some((key, val)) => { + res.insert(key.0, val); + } + _ => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T, V>( + deserializer: D, + ) -> Result>::Value, V>, D::Error> + where + D: serde::Deserializer<'de>, + T: 'de + serde::de::Visitor<'de> + crate::serde::HasConstructor, + V: 'de + serde::Deserialize<'de>, + >::Value: std::cmp::Eq + std::hash::Hash, + { + deserializer.deserialize_any(MapVisitor::<'de, T, V> { + _map_type: || (std::marker::PhantomData, std::marker::PhantomData), + }) + } + + pub fn serialize( + value: &std::collections::HashMap<::Value, V>, + serializer: S, + ) -> Result + where + S: serde::Serializer, + F: crate::serde::SerializeMethod, + V: serde::Serialize, + ::Value: std::cmp::Eq + std::hash::Hash, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(value.len()))?; + for (key, value) in value { + map.serialize_entry(&crate::serde::MySeType:: { val: key }, &value)?; + } + map.end() + } +} + +pub mod map_custom_to_custom { + struct MapVisitor<'de, T, S> + where + T: serde::de::Visitor<'de> + crate::serde::HasConstructor, + S: serde::de::Visitor<'de> + crate::serde::HasConstructor, + { + _map_type: fn() -> ( + std::marker::PhantomData<&'de T>, + std::marker::PhantomData<&'de S>, + ), + } + + #[cfg(feature = "std")] + impl<'de, T, S> serde::de::Visitor<'de> for MapVisitor<'de, T, S> + where + T: serde::de::Visitor<'de> + crate::serde::HasConstructor, + S: serde::de::Visitor<'de> + crate::serde::HasConstructor, + >::Value: std::cmp::Eq + std::hash::Hash, + { + type Value = std::collections::HashMap< + >::Value, + >::Value, + >; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid map") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut res = Self::Value::with_capacity(map.size_hint().unwrap_or(0)); + loop { + let response: std::option::Option<( + crate::serde::MyType<'de, T>, + crate::serde::MyType<'de, S>, + )> = map.next_entry()?; + match response { + Some((key, val)) => { + res.insert(key.0, val.0); + } + _ => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T, S>( + deserializer: D, + ) -> Result< + std::collections::HashMap< + >::Value, + >::Value, + >, + D::Error, + > + where + D: serde::Deserializer<'de>, + T: 'de + serde::de::Visitor<'de> + crate::serde::HasConstructor, + S: 'de + serde::de::Visitor<'de> + crate::serde::HasConstructor, + >::Value: std::cmp::Eq + std::hash::Hash, + { + deserializer.deserialize_any(MapVisitor::<'de, T, S> { + _map_type: || (std::marker::PhantomData, std::marker::PhantomData), + }) + } + + pub fn serialize( + value: &std::collections::HashMap< + ::Value, + ::Value, + >, + serializer: S, + ) -> Result + where + S: serde::Serializer, + F: crate::serde::SerializeMethod, + G: crate::serde::SerializeMethod, + ::Value: std::cmp::Eq + std::hash::Hash, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(value.len()))?; + for (key, value) in value { + map.serialize_entry( + &crate::serde::MySeType:: { val: key }, + &crate::serde::MySeType:: { val: value }, + )?; + } + map.end() + } +} + +pub mod btree_map_custom { + struct MapVisitor<'de, T, V> + where + T: serde::de::Visitor<'de> + crate::serde::HasConstructor, + V: serde::Deserialize<'de>, + { + _map_type: fn() -> ( + std::marker::PhantomData<&'de T>, + std::marker::PhantomData<&'de V>, + ), + } + + #[cfg(feature = "std")] + impl<'de, T, V> serde::de::Visitor<'de> for MapVisitor<'de, T, V> + where + T: serde::de::Visitor<'de> + crate::serde::HasConstructor, + V: serde::Deserialize<'de>, + >::Value: std::cmp::Eq + std::cmp::Ord, + { + type Value = std::collections::BTreeMap<>::Value, V>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid map") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut res = Self::Value::new(); + loop { + let response: std::option::Option<(crate::serde::MyType<'de, T>, V)> = + map.next_entry()?; + match response { + Some((key, val)) => { + res.insert(key.0, val); + } + _ => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T, V>( + deserializer: D, + ) -> Result>::Value, V>, D::Error> + where + D: serde::Deserializer<'de>, + T: 'de + serde::de::Visitor<'de> + crate::serde::HasConstructor, + V: 'de + serde::Deserialize<'de>, + >::Value: std::cmp::Eq + std::cmp::Ord, + { + deserializer.deserialize_any(MapVisitor::<'de, T, V> { + _map_type: || (std::marker::PhantomData, std::marker::PhantomData), + }) + } + + pub fn serialize( + value: &std::collections::BTreeMap<::Value, V>, + serializer: S, + ) -> Result + where + S: serde::Serializer, + F: crate::serde::SerializeMethod, + V: serde::Serialize, + ::Value: std::cmp::Eq + std::cmp::Ord, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(value.len()))?; + for (key, value) in value { + map.serialize_entry(&crate::serde::MySeType:: { val: key }, &value)?; + } + map.end() + } +} + +pub mod btree_map_custom_to_custom { + struct MapVisitor<'de, T, S> + where + T: serde::de::Visitor<'de> + crate::serde::HasConstructor, + S: serde::de::Visitor<'de> + crate::serde::HasConstructor, + { + _map_type: fn() -> ( + std::marker::PhantomData<&'de T>, + std::marker::PhantomData<&'de S>, + ), + } + + #[cfg(feature = "std")] + impl<'de, T, S> serde::de::Visitor<'de> for MapVisitor<'de, T, S> + where + T: serde::de::Visitor<'de> + crate::serde::HasConstructor, + S: serde::de::Visitor<'de> + crate::serde::HasConstructor, + >::Value: std::cmp::Eq + std::cmp::Ord, + { + type Value = std::collections::BTreeMap< + >::Value, + >::Value, + >; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid map") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut res = Self::Value::new(); + loop { + let response: std::option::Option<( + crate::serde::MyType<'de, T>, + crate::serde::MyType<'de, S>, + )> = map.next_entry()?; + match response { + Some((key, val)) => { + res.insert(key.0, val.0); + } + _ => return Ok(res), + } + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D, T, S>( + deserializer: D, + ) -> Result< + std::collections::BTreeMap< + >::Value, + >::Value, + >, + D::Error, + > + where + D: serde::Deserializer<'de>, + T: 'de + serde::de::Visitor<'de> + crate::serde::HasConstructor, + S: 'de + serde::de::Visitor<'de> + crate::serde::HasConstructor, + >::Value: std::cmp::Eq + std::cmp::Ord, + { + deserializer.deserialize_any(MapVisitor::<'de, T, S> { + _map_type: || (std::marker::PhantomData, std::marker::PhantomData), + }) + } + + pub fn serialize( + value: &std::collections::BTreeMap< + ::Value, + ::Value, + >, + serializer: S, + ) -> Result + where + S: serde::Serializer, + F: crate::serde::SerializeMethod, + G: crate::serde::SerializeMethod, + ::Value: std::cmp::Eq + std::cmp::Ord, + { + use serde::ser::SerializeMap; + let mut map = serializer.serialize_map(Some(value.len()))?; + for (key, value) in value { + map.serialize_entry( + &crate::serde::MySeType:: { val: key }, + &crate::serde::MySeType:: { val: value }, + )?; + } + map.end() + } +} + +pub trait SerializeMethod { + type Value; + fn serialize(value: &Self::Value, serializer: S) -> Result + where + S: serde::Serializer; +} + +pub struct MySeType<'a, T> +where + T: SerializeMethod, +{ + val: &'a ::Value, +} + +impl<'a, T: SerializeMethod> serde::Serialize for MySeType<'a, T> { + fn serialize(&self, serializer: S) -> Result + where + S: serde::Serializer, + { + T::serialize(self.val, serializer) + } +} + +pub mod map { + struct MapVisitor<'de, K, V> + where + K: serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, + V: serde::Deserialize<'de>, + { + _key_type: &'de std::marker::PhantomData, + _value_type: &'de std::marker::PhantomData, + } + + #[cfg(feature = "std")] + impl< + 'de, + K: serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, + V: serde::Deserialize<'de>, + > serde::de::Visitor<'de> for MapVisitor<'de, K, V> + { + type Value = std::collections::HashMap; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid map") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut res = Self::Value::with_capacity(map.size_hint().unwrap_or(0)); + loop { + match map.next_entry()? { + Some((k, v)) => { + res.insert(k, v); + } + None => return Ok(res), + } + } + } + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize< + 'de, + D, + K: 'de + serde::Deserialize<'de> + std::cmp::Eq + std::hash::Hash, + V: 'de + serde::Deserialize<'de>, + >( + deserializer: D, + ) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(MapVisitor::<'de, K, V> { + _key_type: &std::marker::PhantomData, + _value_type: &std::marker::PhantomData, + }) + } +} + +pub mod btree_map { + struct MapVisitor<'de, K, V> + where + K: serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, + V: serde::Deserialize<'de>, + { + _key_type: &'de std::marker::PhantomData, + _value_type: &'de std::marker::PhantomData, + } + + #[cfg(feature = "std")] + impl< + 'de, + K: serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, + V: serde::Deserialize<'de>, + > serde::de::Visitor<'de> for MapVisitor<'de, K, V> + { + type Value = std::collections::BTreeMap; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid map") + } + + fn visit_map(self, mut map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let mut res = Self::Value::new(); + loop { + match map.next_entry()? { + Some((k, v)) => { + res.insert(k, v); + } + None => return Ok(res), + } + } + } + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize< + 'de, + D, + K: 'de + serde::Deserialize<'de> + std::cmp::Eq + std::cmp::Ord, + V: 'de + serde::Deserialize<'de>, + >( + deserializer: D, + ) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(MapVisitor::<'de, K, V> { + _key_type: &std::marker::PhantomData, + _value_type: &std::marker::PhantomData, + }) + } +} + +pub mod string { + struct StringVisitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for StringVisitor { + type Value = std::string::String; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid string") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + return Ok(value.to_string()); + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(StringVisitor) + } +} + +pub mod string_opt { + struct StringVisitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for StringVisitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid string") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + return Ok(Some(value.to_string())); + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>( + deserializer: D, + ) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(StringVisitor) + } +} + +pub mod bool { + pub struct BoolVisitor; + + impl crate::serde::HasConstructor for BoolVisitor { + fn new() -> Self { + return Self {}; + } + } + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for BoolVisitor { + type Value = bool; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid boolean") + } + + fn visit_bool(self, value: bool) -> Result + where + E: serde::de::Error, + { + return Ok(value); + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(bool::default()) + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(BoolVisitor) + } +} + +pub mod bool_map_key { + pub struct BoolVisitor; + + impl crate::serde::HasConstructor for BoolVisitor { + fn new() -> Self { + return Self {}; + } + } + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for BoolVisitor { + type Value = bool; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid boolean") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + "true" => Ok(true), + "false" => Ok(false), + _ => Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Str(value), + &self, + )), + } + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(BoolVisitor) + } + + pub struct BoolKeySerializer; + + impl crate::serde::SerializeMethod for BoolKeySerializer { + type Value = bool; + #[cfg(feature = "std")] + fn serialize(value: &Self::Value, serializer: S) -> Result + where + S: serde::Serializer, + { + if *value { + serializer.serialize_str("true") + } else { + serializer.serialize_str("false") + } + } + } +} + +pub mod bool_opt { + struct BoolVisitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for BoolVisitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid boolean") + } + + fn visit_bool(self, value: bool) -> Result + where + E: serde::de::Error, + { + return Ok(Some(value)); + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(BoolVisitor) + } +} + +pub mod i32 { + pub struct I32Visitor; + + impl crate::serde::HasConstructor for I32Visitor { + fn new() -> I32Visitor { + return I32Visitor {}; + } + } + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for I32Visitor { + type Value = i32; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid i32") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + i32::try_from(value).map_err(E::custom) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON + || value > i32::MAX as f64 + || value < i32::MIN as f64 + { + Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )) + } else { + // This is a round number in the proper range, we can cast just fine. + Ok(value as i32) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + i32::try_from(value).map_err(E::custom) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) + } else { + value.parse::().map_err(E::custom) + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(i32::default()) + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(I32Visitor) + } +} + +pub mod i32_opt { + struct I32Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for I32Visitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid i32") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + i32::try_from(value).map(|x| Some(x)).map_err(E::custom) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON + || value > i32::MAX as f64 + || value < i32::MIN as f64 + { + Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )) + } else { + // This is a round number in the proper range, we can cast just fine. + Ok(Some(value as i32)) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + i32::try_from(value).map(|x| Some(x)).map_err(E::custom) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) + } else { + value.parse::().map(|x| Some(x)).map_err(E::custom) + } + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(I32Visitor) + } +} + +pub mod i64 { + pub struct I64Visitor; + + impl crate::serde::HasConstructor for I64Visitor { + fn new() -> Self { + return Self {}; + } + } + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for I64Visitor { + type Value = i64; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid i64") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(value as i64) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON + || value > i64::MAX as f64 + || value < i64::MIN as f64 + { + Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )) + } else { + // This is a round number in the proper range, we can cast just fine. + Ok(value as i64) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + i64::try_from(value).map_err(E::custom) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) + } else { + value.parse::().map_err(E::custom) + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(i64::default()) + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(I64Visitor) + } +} + +pub mod i64_opt { + struct I64Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for I64Visitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid i64") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as i64)) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON + || value > i64::MAX as f64 + || value < i64::MIN as f64 + { + Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )) + } else { + // This is a round number in the proper range, we can cast just fine. + Ok(Some(value as i64)) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + i64::try_from(value).map(|x| Some(x)).map_err(E::custom) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) + } else { + value.parse::().map(|x| Some(x)).map_err(E::custom) + } + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(I64Visitor) + } +} + +pub mod u32 { + pub struct U32Visitor; + + impl crate::serde::HasConstructor for U32Visitor { + fn new() -> Self { + return Self {}; + } + } + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for U32Visitor { + type Value = u32; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid u32") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + u32::try_from(value).map_err(E::custom) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON + || value < 0.0 + || value > u32::MAX as f64 + { + Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )) + } else { + // This is a round number in the proper range, we can cast just fine. + Ok(value as u32) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + u32::try_from(value).map_err(E::custom) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) + } else { + value.parse::().map_err(E::custom) + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(u32::default()) + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(U32Visitor) + } +} + +pub mod u32_opt { + struct U32Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for U32Visitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid u32") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + u32::try_from(value).map(|x| Some(x)).map_err(E::custom) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON + || value < 0.0 + || value > u32::MAX as f64 + { + Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )) + } else { + // This is a round number in the proper range, we can cast just fine. + Ok(Some(value as u32)) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + use std::convert::TryFrom; + u32::try_from(value).map(|x| Some(x)).map_err(E::custom) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) + } else { + value.parse::().map(|x| Some(x)).map_err(E::custom) + } + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(U32Visitor) + } +} + +pub mod u64 { + pub struct U64Visitor; + + impl crate::serde::HasConstructor for U64Visitor { + fn new() -> Self { + return Self {}; + } + } + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for U64Visitor { + type Value = u64; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid u64") + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(value as u64) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON + || value < 0.0 + || value > u64::MAX as f64 + { + Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )) + } else { + // This is a round number in the proper range, we can cast just fine. + Ok(value as u64) + } + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) + } else { + value.parse::().map_err(E::custom) + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(u64::default()) + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(U64Visitor) + } +} + +pub mod u64_opt { + struct U64Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for U64Visitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid u64") + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as u64)) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if (value.trunc() - value).abs() > f64::EPSILON + || value < 0.0 + || value > u64::MAX as f64 + { + Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )) + } else { + // This is a round number, we can cast just fine. + Ok(Some(value as u64)) + } + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + // If we have scientific notation or a decimal, parse float first. + if value.contains('e') || value.contains('E') || value.ends_with(".0") { + value + .parse::() + .map_err(E::custom) + .and_then(|x| self.visit_f64(x)) + } else { + value.parse::().map(|x| Some(x)).map_err(E::custom) + } + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(U64Visitor) + } +} + +pub mod f64 { + pub struct F64Visitor; + + impl crate::serde::HasConstructor for F64Visitor { + fn new() -> F64Visitor { + return F64Visitor {}; + } + } + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for F64Visitor { + type Value = f64; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid f64") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(value as f64) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + Ok(value as f64) + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(value as f64) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + "NaN" => Ok(f64::NAN), + "Infinity" => Ok(f64::INFINITY), + "-Infinity" => Ok(f64::NEG_INFINITY), + _ => value.parse::().map_err(E::custom), + } + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(f64::default()) + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(F64Visitor) + } + + pub struct F64Serializer; + + impl crate::serde::SerializeMethod for F64Serializer { + type Value = f64; + #[cfg(feature = "std")] + fn serialize(value: &Self::Value, serializer: S) -> Result + where + S: serde::Serializer, + { + if value.is_nan() { + serializer.serialize_str("NaN") + } else if value.is_infinite() && value.is_sign_negative() { + serializer.serialize_str("-Infinity") + } else if value.is_infinite() { + serializer.serialize_str("Infinity") + } else { + serializer.serialize_f64(*value) + } + } + } +} + +pub mod f64_opt { + struct F64Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for F64Visitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid f64") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as f64)) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as f64)) + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as f64)) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + "NaN" => Ok(Some(f64::NAN)), + "Infinity" => Ok(Some(f64::INFINITY)), + "-Infinity" => Ok(Some(f64::NEG_INFINITY)), + _ => value.parse::().map(|x| Some(x)).map_err(E::custom), + } + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(F64Visitor) + } + + #[cfg(feature = "std")] + pub fn serialize(value: &std::option::Option, serializer: S) -> Result + where + S: serde::Serializer, + { + use crate::serde::SerializeMethod; + match value { + None => serializer.serialize_none(), + Some(double) => crate::serde::f64::F64Serializer::serialize(double, serializer), + } + } +} + +pub mod f32 { + pub struct F32Visitor; + + impl crate::serde::HasConstructor for F32Visitor { + fn new() -> F32Visitor { + return F32Visitor {}; + } + } + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for F32Visitor { + type Value = f32; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid f32") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(value as f32) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if value < f32::MIN as f64 || value > f32::MAX as f64 { + Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )) + } else { + Ok(value as f32) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(value as f32) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + "NaN" => Ok(f32::NAN), + "Infinity" => Ok(f32::INFINITY), + "-Infinity" => Ok(f32::NEG_INFINITY), + _ => value.parse::().map_err(E::custom), + } + } + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(f32::default()) + } + } + + pub fn deserialize<'de, D>(deserializer: D) -> Result + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(F32Visitor) + } + + pub struct F32Serializer; + + impl crate::serde::SerializeMethod for F32Serializer { + type Value = f32; + + #[cfg(feature = "std")] + fn serialize(value: &f32, serializer: S) -> Result + where + S: serde::Serializer, + { + if value.is_nan() { + serializer.serialize_str("NaN") + } else if value.is_infinite() && value.is_sign_negative() { + serializer.serialize_str("-Infinity") + } else if value.is_infinite() { + serializer.serialize_str("Infinity") + } else { + serializer.serialize_f32(*value) + } + } + } +} + +pub mod f32_opt { + struct F32Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for F32Visitor { + type Value = std::option::Option; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid f32") + } + + fn visit_i64(self, value: i64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as f32)) + } + + fn visit_f64(self, value: f64) -> Result + where + E: serde::de::Error, + { + if value < f32::MIN as f64 || value > f32::MAX as f64 { + Err(serde::de::Error::invalid_type( + serde::de::Unexpected::Float(value), + &self, + )) + } else { + Ok(Some(value as f32)) + } + } + + fn visit_u64(self, value: u64) -> Result + where + E: serde::de::Error, + { + Ok(Some(value as f32)) + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + match value { + "NaN" => Ok(Some(f32::NAN)), + "Infinity" => Ok(Some(f32::INFINITY)), + "-Infinity" => Ok(Some(f32::NEG_INFINITY)), + _ => value.parse::().map(|x| Some(x)).map_err(E::custom), + } + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(F32Visitor) + } + + #[cfg(feature = "std")] + pub fn serialize(value: &std::option::Option, serializer: S) -> Result + where + S: serde::Serializer, + { + use crate::serde::SerializeMethod; + match value { + None => serializer.serialize_none(), + Some(float) => crate::serde::f32::F32Serializer::serialize(float, serializer), + } + } +} + +pub mod vec_u8 { + pub struct VecU8Visitor; + + impl crate::serde::HasConstructor for VecU8Visitor { + fn new() -> Self { + return Self {}; + } + } + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for VecU8Visitor { + type Value = ::prost::alloc::vec::Vec; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid base64 encoded string") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + base64::decode(value).map_err(E::custom) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Self::Value::default()) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result<::prost::alloc::vec::Vec, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(VecU8Visitor) + } + + pub struct VecU8Serializer; + + impl crate::serde::SerializeMethod for VecU8Serializer { + type Value = ::prost::alloc::vec::Vec; + + #[cfg(feature = "std")] + fn serialize(value: &Self::Value, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&base64::encode(value)) + } + } +} + +pub mod vec_u8_opt { + struct VecU8Visitor; + + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for VecU8Visitor { + type Value = std::option::Option<::prost::alloc::vec::Vec>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid base64 encoded string") + } + + fn visit_str(self, value: &str) -> Result + where + E: serde::de::Error, + { + base64::decode(value) + .map(|str| Some(str)) + .map_err(E::custom) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>( + deserializer: D, + ) -> Result>, D::Error> + where + D: serde::Deserializer<'de>, + { + deserializer.deserialize_any(VecU8Visitor) + } + + #[cfg(feature = "std")] + pub fn serialize( + value: &std::option::Option<::prost::alloc::vec::Vec>, + serializer: S, + ) -> Result + where + S: serde::Serializer, + { + use crate::serde::SerializeMethod; + match value { + None => serializer.serialize_none(), + Some(value) => crate::serde::vec_u8::VecU8Serializer::serialize(value, serializer), + } + } +} diff --git a/protobuf/Cargo.toml b/protobuf/Cargo.toml index 3515aadb7..34305dc6b 100644 --- a/protobuf/Cargo.toml +++ b/protobuf/Cargo.toml @@ -11,7 +11,7 @@ edition = "2018" [dependencies] bytes = { version = "1", default-features = false } prost = { path = ".." } -prost-types = { path = "../prost-types" } +prost-types = { path = "../prost-types", features = ["json"] } serde = { version = "1", features = ["derive"] } [build-dependencies] From 96755ddd9e807906f5a1d9e920386c7736be8205 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Wed, 16 Mar 2022 12:11:41 -0700 Subject: [PATCH 24/30] i64/u64 serialize as string, add empty visitor --- conformance/failing_tests.txt | 2 - prost-build/src/code_generator.rs | 12 ++--- prost-types/src/serde.rs | 84 ++++++++++++++++++++++++++++++- 3 files changed, 88 insertions(+), 10 deletions(-) diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index e4c9be0ee..2292c32f6 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -7,7 +7,6 @@ Recommended.Proto3.JsonInput.BytesFieldBase64Url.ProtobufOutput Recommended.Proto3.JsonInput.DurationHas3FractionalDigits.Validator Recommended.Proto3.JsonInput.DurationHas6FractionalDigits.Validator Recommended.Proto3.JsonInput.DurationHas9FractionalDigits.Validator -Recommended.Proto3.JsonInput.Int64FieldBeString.Validator Recommended.Proto3.JsonInput.MapFieldValueIsNull Recommended.Proto3.JsonInput.NullValueInOtherOneofNewFormat.Validator Recommended.Proto3.JsonInput.NullValueInOtherOneofOldFormat.Validator @@ -17,7 +16,6 @@ Recommended.Proto3.JsonInput.TimestampHas6FractionalDigits.Validator Recommended.Proto3.JsonInput.TimestampHas9FractionalDigits.Validator Recommended.Proto3.JsonInput.TimestampHasZeroFractionalDigit.Validator Recommended.Proto3.JsonInput.TimestampZeroNormalized.Validator -Recommended.Proto3.JsonInput.Uint64FieldBeString.Validator Required.DurationProtoInputTooLarge.JsonOutput Required.DurationProtoInputTooSmall.JsonOutput Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 8c6d36e61..3d72d8f37 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -70,15 +70,15 @@ impl<'a> CodeGenerator<'a> { ("enum", false, false, _) => (Some(format!("::prost_types::serde::enum_serde::serialize::<_, {}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::serde::enum_serde::deserialize::<_, {}>", self.resolve_ident(&type_name)))), ("enum", true, false, _) => (Some(format!("::prost_types::serde::enum_opt::serialize::<_, {}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::serde::enum_opt::deserialize::<_, {}>", self.resolve_ident(&type_name)))), ("enum", _, true, _) => (Some(format!("::prost_types::serde::enum_serde::EnumSerializer<{}>", self.resolve_ident(&type_name))), Some(format!("::prost_types::serde::enum_serde::EnumVisitor::<{}>", self.resolve_ident(&type_name)))), - ("i64", false, false, _) => (None, Some("::prost_types::serde::i64::deserialize".to_string())), - ("i64", true, false, _) => (None, Some("::prost_types::serde::i64_opt::deserialize".to_string())), - ("i64", _, true, _) => (None, Some("::prost_types::serde::i64::I64Visitor".to_string())), + ("i64", false, false, _) => (Some("<::prost_types::serde::i64::I64Serializer as ::prost_types::serde::SerializeMethod>::serialize".to_string()), Some("::prost_types::serde::i64::deserialize".to_string())), + ("i64", true, false, _) => (Some("::prost_types::serde::i64_opt::serialize".to_string()), Some("::prost_types::serde::i64_opt::deserialize".to_string())), + ("i64", _, true, _) => (Some("::prost_types::serde::i64::I64Serializer".to_string()), Some("::prost_types::serde::i64::I64Visitor".to_string())), ("u32", false, false, _) => (None, Some("::prost_types::serde::u32::deserialize".to_string())), ("u32", true, false, _) => (None, Some("::prost_types::serde::u32_opt::deserialize".to_string())), ("u32", _, true, _) => (None, Some("::prost_types::serde::u32::U32Visitor".to_string())), - ("u64", false, false, _) => (None, Some("::prost_types::serde::u64::deserialize".to_string())), - ("u64", true, false, _) => (None, Some("::prost_types::serde::u64_opt::deserialize".to_string())), - ("u64", _, true, _) => (None, Some("::prost_types::serde::u64::U64Visitor".to_string())), + ("u64", false, false, _) => (Some("<::prost_types::serde::u64::U64Serializer as ::prost_types::serde::SerializeMethod>::serialize".to_string()), Some("::prost_types::serde::u64::deserialize".to_string())), + ("u64", true, false, _) => (Some("::prost_types::serde::u64_opt::serialize".to_string()), Some("::prost_types::serde::u64_opt::deserialize".to_string())), + ("u64", _, true, _) => (Some("::prost_types::serde::u64::U64Serializer".to_string()), Some("::prost_types::serde::u64::U64Visitor".to_string())), ("f64", false, false, _) => (Some("<::prost_types::serde::f64::F64Serializer as ::prost_types::serde::SerializeMethod>::serialize".to_string()), Some("::prost_types::serde::f64::deserialize".to_string())), ("f64", true, false, _) => (Some("::prost_types::serde::f64_opt::serialize".to_string()), Some("::prost_types::serde::f64_opt::deserialize".to_string())), ("f64", _, true, _) => (Some("::prost_types::serde::f64::F64Serializer".to_string()), Some("::prost_types::serde::f64::F64Visitor".to_string())), diff --git a/prost-types/src/serde.rs b/prost-types/src/serde.rs index d6c6e7ef8..42a57a167 100644 --- a/prost-types/src/serde.rs +++ b/prost-types/src/serde.rs @@ -150,6 +150,35 @@ pub fn is_default(t: &T) -> bool { t == &T::default() } +pub mod empty { + struct EmptyVisitor; + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for EmptyVisitor { + type Value = (); + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid empty object") + } + + fn visit_map(self, map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let _ = map; + Ok(()) + } + } + + pub fn serialize(_: &(), serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeMap; + let map = serializer.serialize_map(Some(0))?; + map.end() + } +} + pub mod vec { struct VecVisitor<'de, T> where @@ -1682,6 +1711,19 @@ pub mod i64 { { deserializer.deserialize_any(I64Visitor) } + + pub struct I64Serializer; + + impl crate::serde::SerializeMethod for I64Serializer { + type Value = i64; + #[cfg(feature = "std")] + fn serialize(value: &Self::Value, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&value.to_string()) + } + } } pub mod i64_opt { @@ -1765,6 +1807,18 @@ pub mod i64_opt { { deserializer.deserialize_any(I64Visitor) } + + #[cfg(feature = "std")] + pub fn serialize(value: &std::option::Option, serializer: S) -> Result + where + S: serde::Serializer, + { + use crate::serde::SerializeMethod; + match value { + None => serializer.serialize_none(), + Some(double) => crate::serde::i64::I64Serializer::serialize(double, serializer), + } + } } pub mod u32 { @@ -2004,6 +2058,19 @@ pub mod u64 { { deserializer.deserialize_any(U64Visitor) } + + pub struct U64Serializer; + + impl crate::serde::SerializeMethod for U64Serializer { + type Value = u64; + #[cfg(feature = "std")] + fn serialize(value: &Self::Value, serializer: S) -> Result + where + S: serde::Serializer, + { + serializer.serialize_str(&value.to_string()) + } + } } pub mod u64_opt { @@ -2079,6 +2146,19 @@ pub mod u64_opt { { deserializer.deserialize_any(U64Visitor) } + + #[cfg(feature = "std")] + pub fn serialize(value: &std::option::Option, serializer: S) -> Result + where + S: serde::Serializer, + { + use crate::serde::SerializeMethod; + match value { + None => serializer.serialize_none(), + Some(double) => crate::serde::u64::U64Serializer::serialize(double, serializer), + } + } + } pub mod f64 { @@ -2109,7 +2189,7 @@ pub mod f64 { where E: serde::de::Error, { - Ok(value as f64) + Ok(value) } fn visit_u64(self, value: u64) -> Result @@ -2190,7 +2270,7 @@ pub mod f64_opt { where E: serde::de::Error, { - Ok(Some(value as f64)) + Ok(Some(value)) } fn visit_u64(self, value: u64) -> Result From 4b5a456bb2c2879c20a3d881f7c98eaeaeb4a20e Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Thu, 17 Mar 2022 13:31:06 -0700 Subject: [PATCH 25/30] empty + fmt --- prost-build/src/code_generator.rs | 2 + prost-types/src/lib.rs | 1 - prost-types/src/serde.rs | 84 +++++++++++++++++++++++++++---- 3 files changed, 75 insertions(+), 12 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index 3d72d8f37..e96c2d2a0 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -60,6 +60,8 @@ impl<'a> CodeGenerator<'a> { map_key: bool, ) -> (Option, Option) { match (ty, optional, collection, map_key) { + ("()", false, false, _) => (Some("::prost_types::serde::empty::serialize".to_string()), Some("::prost_types::serde::empty::deserialize".to_string())), + ("()", true, false, _) => (Some("::prost_types::serde::empty_opt::serialize".to_string()), Some("::prost_types::serde::empty_opt::deserialize".to_string())), ("bool", false, false, _) => (None, Some("::prost_types::serde::bool::deserialize".to_string())), ("bool", true, false, _) => (None, Some("::prost_types::serde::bool_opt::deserialize".to_string())), ("bool", _, true, false) => (None, Some("::prost_types::serde::bool::BoolVisitor".to_string())), diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index adfcf5cf9..1e5450020 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -258,7 +258,6 @@ impl TryFrom for std::time::SystemTime { } } - #[cfg(test)] mod tests { use std::time::{Duration, SystemTime, UNIX_EPOCH}; diff --git a/prost-types/src/serde.rs b/prost-types/src/serde.rs index 42a57a167..8fa7bb011 100644 --- a/prost-types/src/serde.rs +++ b/prost-types/src/serde.rs @@ -169,6 +169,13 @@ pub mod empty { } } + pub fn deserialize<'de, D>(deserializer: D) -> Result<(), D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_any(EmptyVisitor) + } + pub fn serialize(_: &(), serializer: S) -> Result where S: serde::Serializer, @@ -179,6 +186,62 @@ pub mod empty { } } +pub mod empty_opt { + struct EmptyVisitor; + #[cfg(feature = "std")] + impl<'de> serde::de::Visitor<'de> for EmptyVisitor { + type Value = std::option::Option<()>; + + fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result { + formatter.write_str("a valid empty object") + } + + fn visit_map(self, map: A) -> Result + where + A: serde::de::MapAccess<'de>, + { + let _ = map; + Ok(Some(())) + } + + fn visit_unit(self) -> Result + where + E: serde::de::Error, + { + Ok(Some(())) + } + + fn visit_none(self) -> Result + where + E: serde::de::Error, + { + Ok(None) + } + } + + #[cfg(feature = "std")] + pub fn deserialize<'de, D>(deserializer: D) -> Result, D::Error> + where + D: serde::de::Deserializer<'de>, + { + deserializer.deserialize_any(EmptyVisitor) + } + + #[cfg(feature = "std")] + pub fn serialize(opt: &std::option::Option<()>, serializer: S) -> Result + where + S: serde::Serializer, + { + use serde::ser::SerializeMap; + if opt.is_some() { + let map = serializer.serialize_map(Some(0))?; + map.end() + } else { + serializer.serialize_none() + } + } +} + pub mod vec { struct VecVisitor<'de, T> where @@ -2147,18 +2210,17 @@ pub mod u64_opt { deserializer.deserialize_any(U64Visitor) } - #[cfg(feature = "std")] - pub fn serialize(value: &std::option::Option, serializer: S) -> Result - where - S: serde::Serializer, - { - use crate::serde::SerializeMethod; - match value { - None => serializer.serialize_none(), - Some(double) => crate::serde::u64::U64Serializer::serialize(double, serializer), - } + #[cfg(feature = "std")] + pub fn serialize(value: &std::option::Option, serializer: S) -> Result + where + S: serde::Serializer, + { + use crate::serde::SerializeMethod; + match value { + None => serializer.serialize_none(), + Some(double) => crate::serde::u64::U64Serializer::serialize(double, serializer), } - + } } pub mod f64 { From db0d9a8ebe911d50769a672212c6e5688234499e Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Thu, 17 Mar 2022 15:43:17 -0700 Subject: [PATCH 26/30] add better empty impl --- prost-types/src/serde.rs | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/prost-types/src/serde.rs b/prost-types/src/serde.rs index 8fa7bb011..49daead12 100644 --- a/prost-types/src/serde.rs +++ b/prost-types/src/serde.rs @@ -160,12 +160,16 @@ pub mod empty { formatter.write_str("a valid empty object") } - fn visit_map(self, map: A) -> Result + fn visit_map(self, mut map: A) -> Result where A: serde::de::MapAccess<'de>, { - let _ = map; - Ok(()) + let tmp: std::option::Option<((), ())> = map.next_entry()?; + if tmp.is_some() { + Err(::serde::de::Error::custom("this is a message, not empty")) + } else { + Ok(()) + } } } @@ -196,12 +200,16 @@ pub mod empty_opt { formatter.write_str("a valid empty object") } - fn visit_map(self, map: A) -> Result + fn visit_map(self, mut map: A) -> Result where A: serde::de::MapAccess<'de>, { - let _ = map; - Ok(Some(())) + let tmp: std::option::Option<((), ())> = map.next_entry()?; + if tmp.is_some() { + Err(::serde::de::Error::custom("this is a message, not empty")) + } else { + Ok(Some(())) + } } fn visit_unit(self) -> Result From ac23da01a952441f54b7d264d5f261067d19eacc Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Thu, 17 Mar 2022 16:00:22 -0700 Subject: [PATCH 27/30] make a change for enum validations --- prost-types/src/serde.rs | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/prost-types/src/serde.rs b/prost-types/src/serde.rs index 49daead12..af57ee22f 100644 --- a/prost-types/src/serde.rs +++ b/prost-types/src/serde.rs @@ -425,16 +425,20 @@ pub mod enum_serde { )), } } + fn visit_i64(self, value: i64) -> Result where E: serde::de::Error, { match T::try_from(value as i32) { Ok(en) => Ok(en.into()), - Err(_) => Err(serde::de::Error::invalid_value( - serde::de::Unexpected::Signed(value as i64), - &self, - )), + // There is a test in the conformance tests: + // Required.Proto3.JsonInput.EnumFieldUnknownValue.Validator + // That implies this should return the default value, so we + // will. This also helps when parsing a oneof, since this means + // we won't fail to deserialize when we have an out of bounds + // enum value. + Err(_) => Ok(T::default().into()), } } @@ -562,10 +566,13 @@ pub mod enum_opt { { match T::try_from(value as i32) { Ok(en) => Ok(Some(en.into())), - Err(_) => Err(serde::de::Error::invalid_value( - serde::de::Unexpected::Signed(value as i64), - &self, - )), + // There is a test in the conformance tests: + // Required.Proto3.JsonInput.EnumFieldUnknownValue.Validator + // That implies this should return the default value, so we + // will. This also helps when parsing a oneof, since this means + // we won't fail to deserialize when we have an out of bounds + // enum value. + Err(_) => Ok(Some(T::default().into())), } } From 9374f797e2e739420f1b21d93bc507f718da1ebe Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Mon, 21 Mar 2022 00:31:47 -0700 Subject: [PATCH 28/30] timestamp implementation --- conformance/failing_tests.txt | 14 +- prost-types/Cargo.toml | 4 +- prost-types/src/datetime.rs | 534 ++++++++++++++++++++++++++++++++++ prost-types/src/lib.rs | 3 + prost-types/src/serde.rs | 28 +- 5 files changed, 551 insertions(+), 32 deletions(-) create mode 100644 prost-types/src/datetime.rs diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index 2292c32f6..87c529588 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -11,11 +11,6 @@ Recommended.Proto3.JsonInput.MapFieldValueIsNull Recommended.Proto3.JsonInput.NullValueInOtherOneofNewFormat.Validator Recommended.Proto3.JsonInput.NullValueInOtherOneofOldFormat.Validator Recommended.Proto3.JsonInput.RepeatedFieldPrimitiveElementIsNull -Recommended.Proto3.JsonInput.TimestampHas3FractionalDigits.Validator -Recommended.Proto3.JsonInput.TimestampHas6FractionalDigits.Validator -Recommended.Proto3.JsonInput.TimestampHas9FractionalDigits.Validator -Recommended.Proto3.JsonInput.TimestampHasZeroFractionalDigit.Validator -Recommended.Proto3.JsonInput.TimestampZeroNormalized.Validator Required.DurationProtoInputTooLarge.JsonOutput Required.DurationProtoInputTooSmall.JsonOutput Required.Proto2.ProtobufInput.UnknownVarint.ProtobufOutput @@ -69,10 +64,11 @@ Required.Proto3.JsonInput.Struct.JsonOutput Required.Proto3.JsonInput.Struct.ProtobufOutput Required.Proto3.JsonInput.StructWithEmptyListValue.JsonOutput Required.Proto3.JsonInput.StructWithEmptyListValue.ProtobufOutput -Required.Proto3.JsonInput.TimestampMinValue.JsonOutput -Required.Proto3.JsonInput.TimestampMinValue.ProtobufOutput -Required.Proto3.JsonInput.TimestampRepeatedValue.JsonOutput -Required.Proto3.JsonInput.TimestampRepeatedValue.ProtobufOutput +Required.Proto3.JsonInput.TimestampJsonInputLowercaseT +Required.Proto3.JsonInput.TimestampJsonInputLowercaseZ +Required.Proto3.JsonInput.TimestampJsonInputMissingT +Required.Proto3.JsonInput.TimestampJsonInputMissingZ +Required.Proto3.JsonInput.TimestampJsonInputTooSmall Required.Proto3.JsonInput.ValueAcceptBool.JsonOutput Required.Proto3.JsonInput.ValueAcceptBool.ProtobufOutput Required.Proto3.JsonInput.ValueAcceptFloat.JsonOutput diff --git a/prost-types/Cargo.toml b/prost-types/Cargo.toml index 3ba14233f..3ef4df3df 100644 --- a/prost-types/Cargo.toml +++ b/prost-types/Cargo.toml @@ -18,14 +18,14 @@ doctest = false [features] default = ["std"] std = ["prost/std"] -json = ["prost/json", "chrono"] +json = ["prost/json", "lexical"] [dependencies] base64 = "0.13" bytes = { version = "1", default-features = false } serde = { version = "1", features = ["derive"] } prost = { version = "0.9.0", path = "..", default-features = false, features = ["prost-derive"] } -chrono = { version = "0.4", optional = true } +lexical = { version = "6", default-features = false, features = ["parse-integers", "parse-floats"], optional = true } [dev-dependencies] proptest = "1" diff --git a/prost-types/src/datetime.rs b/prost-types/src/datetime.rs new file mode 100644 index 000000000..93712fe7c --- /dev/null +++ b/prost-types/src/datetime.rs @@ -0,0 +1,534 @@ +//! A date/time type which exists primarily to convert [`Timestamp`]s into an RFC 3339 formatted +//! string. + +/// A point in time, represented as a date and time in the UTC timezone. +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] +pub(crate) struct DateTime { + /// The year. + pub(crate) year: i64, + /// The month of the year, from 1 to 12, inclusive. + pub(crate) month: u8, + /// The day of the month, from 1 to 31, inclusive. + pub(crate) day: u8, + /// The hour of the day, from 0 to 23, inclusive. + pub(crate) hour: u8, + /// The minute of the hour, from 0 to 59, inclusive. + pub(crate) minute: u8, + /// The second of the minute, from 0 to 59, inclusive. + pub(crate) second: u8, + /// The nanoseconds, from 0 to 999_999_999, inclusive. + pub(crate) nanos: u32, +} + +impl DateTime { + /// The minimum representable [`Timestamp`] as a `DateTime`. + pub(crate) const MIN: DateTime = DateTime { + year: -292_277_022_657, + month: 1, + day: 27, + hour: 8, + minute: 29, + second: 52, + nanos: 0, + }; + + /// The maximum representable [`Timestamp`] as a `DateTime`. + pub(crate) const MAX: DateTime = DateTime { + year: 292_277_026_596, + month: 12, + day: 4, + hour: 15, + minute: 30, + second: 7, + nanos: 999_999_999, + }; + + /// Returns `true` if the `DateTime` is a valid calendar date. + pub(crate) fn is_valid(&self) -> bool { + self >= &DateTime::MIN + && self <= &DateTime::MAX + && self.month > 0 + && self.month <= 12 + && self.day > 0 + && self.day <= days_in_month(self.year, self.month) + && self.hour < 24 + && self.minute < 60 + && self.second < 60 + && self.nanos < 1_000_000_000 + } + + /// Returns a `Display`-able type which formats only the time portion of the datetime, e.g. `12:34:56.123456`. + pub(crate) fn time(self) -> impl std::fmt::Display { + struct Time { + inner: DateTime, + } + + impl std::fmt::Display for Time { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + // Format subseconds to either nothing, millis, micros, or nanos. + let nanos = self.inner.nanos; + let subsec = if nanos == 0 { + format!("") + } else if nanos % 1_000_000 == 0 { + format!(".{:03}", nanos / 1_000_000) + } else if nanos % 1_000 == 0 { + format!(".{:06}", nanos / 1_000) + } else { + format!(".{:09}", nanos) + }; + + write!( + f, + "{:02}:{:02}:{:02}{}", + self.inner.hour, self.inner.minute, self.inner.second, subsec, + ) + } + } + + Time { inner: self } + } +} + +impl std::fmt::Display for DateTime { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + // Pad years to at least 4 digits. + let year = if self.year > 9999 { + format!("+{}", self.year) + } else if self.year < 0 { + format!("{:05}", self.year) + } else { + format!("{:04}", self.year) + }; + + write!( + f, + "{}-{:02}-{:02}T{}Z", + year, + self.month, + self.day, + self.time() + ) + } +} + +impl From for DateTime { + /// musl's [`__secs_to_tm`][1] converted to Rust via [c2rust][2] and then cleaned up by hand. + /// + /// All existing `strftime`-like APIs in Rust are unable to handle the full range of timestamps + /// representable by `Timestamp`, including `strftime` itself, since tm.tm_year is an int. + /// + /// [1]: http://git.musl-libc.org/cgit/musl/tree/src/time/__secs_to_tm.c + /// [2]: https://c2rust.com/ + fn from(timestamp: crate::Timestamp) -> DateTime { + let t = timestamp.seconds; + let nanos = timestamp.nanos; + + // 2000-03-01 (mod 400 year, immediately after feb29 + const LEAPOCH: i64 = 946_684_800 + 86400 * (31 + 29); + const DAYS_PER_400Y: i32 = 365 * 400 + 97; + const DAYS_PER_100Y: i32 = 365 * 100 + 24; + const DAYS_PER_4Y: i32 = 365 * 4 + 1; + const DAYS_IN_MONTH: [u8; 12] = [31, 30, 31, 30, 31, 31, 30, 31, 30, 31, 31, 29]; + + // Note(dcb): this bit is rearranged slightly to avoid integer overflow. + let mut days: i64 = (t / 86_400) - (LEAPOCH / 86_400); + let mut remsecs: i32 = (t % 86_400) as i32; + if remsecs < 0i32 { + remsecs += 86_400; + days -= 1 + } + + let mut qc_cycles: i32 = (days / i64::from(DAYS_PER_400Y)) as i32; + let mut remdays: i32 = (days % i64::from(DAYS_PER_400Y)) as i32; + if remdays < 0 { + remdays += DAYS_PER_400Y; + qc_cycles -= 1; + } + + let mut c_cycles: i32 = remdays / DAYS_PER_100Y; + if c_cycles == 4 { + c_cycles -= 1; + } + remdays -= c_cycles * DAYS_PER_100Y; + + let mut q_cycles: i32 = remdays / DAYS_PER_4Y; + if q_cycles == 25 { + q_cycles -= 1; + } + remdays -= q_cycles * DAYS_PER_4Y; + + let mut remyears: i32 = remdays / 365; + if remyears == 4 { + remyears -= 1; + } + remdays -= remyears * 365; + + let mut years: i64 = i64::from(remyears) + + 4 * i64::from(q_cycles) + + 100 * i64::from(c_cycles) + + 400 * i64::from(qc_cycles); + + let mut months: i32 = 0; + while i32::from(DAYS_IN_MONTH[months as usize]) <= remdays { + remdays -= i32::from(DAYS_IN_MONTH[months as usize]); + months += 1 + } + + if months >= 10 { + months -= 12; + years += 1; + } + + let date_time = DateTime { + year: years + 2000, + month: (months + 3) as u8, + day: (remdays + 1) as u8, + hour: (remsecs / 3600) as u8, + minute: (remsecs / 60 % 60) as u8, + second: (remsecs % 60) as u8, + nanos: nanos as u32, + }; + debug_assert!(date_time.is_valid()); + date_time + } +} + +/// Returns the number of days in the month. +fn days_in_month(year: i64, month: u8) -> u8 { + const DAYS_IN_MONTH: [u8; 12] = [31, 28, 31, 30, 31, 30, 31, 31, 30, 31, 30, 31]; + let (_, is_leap) = year_to_seconds(year); + DAYS_IN_MONTH[usize::from(month - 1)] + u8::from(is_leap && month == 2) +} + + +macro_rules! ensure { + ($expr:expr) => {{ + if !$expr { + return None; + } + }}; +} + +/// Parses a date in RFC 3339 format from `b`, returning the year, month, day, and remaining input. +/// +/// The date is not validated according to a calendar. +fn parse_date(b: &[u8]) -> Option<(i64, u8, u8, &[u8])> { + // Smallest valid date is YYYY-MM-DD. + ensure!(b.len() >= 10); + + // Parse the year in one of three formats: + // * +YYYY[Y]+ + // * -[Y]+ + // * YYYY + let (year, b) = if b[0] == b'+' { + let (digits, b) = parse_digits(&b[1..]); + ensure!(digits.len() >= 5); + let date: i64 = lexical::parse(digits).ok()?; + (date, b) + } else if b[0] == b'-' { + let (digits, b) = parse_digits(&b[1..]); + ensure!(digits.len() >= 4); + let date: i64 = lexical::parse(digits).ok()?; + (-date, b) + } else { + // Parse a 4 digit numeric. + let (n1, b) = parse_two_digit_numeric(b)?; + let (n2, b) = parse_two_digit_numeric(b)?; + (i64::from(n1) * 100 + i64::from(n2), b) + }; + + let b = parse_char(b, b'-')?; + let (month, b) = parse_two_digit_numeric(b)?; + let b = parse_char(b, b'-')?; + let (day, b) = parse_two_digit_numeric(b)?; + Some((year, month, day, b)) +} + + +/// Parses a time in RFC 3339 format from `b`, returning the hour, minute, second, and nanos. +/// +/// The date is not validated according to a calendar. +fn parse_time(b: &[u8]) -> Option<(u8, u8, u8, u32, &[u8])> { + let (hour, b) = parse_two_digit_numeric(b)?; + let b = parse_char(b, b':')?; + let (minute, b) = parse_two_digit_numeric(b)?; + let b = parse_char(b, b':')?; + let (second, b) = parse_two_digit_numeric(b)?; + + // Parse the nanoseconds, if present. + let (nanos, b) = if let Some(b) = parse_char(b, b'.') { + let (digits, b) = parse_digits(b); + ensure!(digits.len() <= 9); + let nanos = 10u32.pow(9 - digits.len() as u32) * lexical::parse::(digits).ok()?; + (nanos, b) + } else { + (0, b) + }; + + Some((hour, minute, second, nanos, b)) +} + +/// Parses a timezone offset in RFC 3339 format from `b`, returning the offset hour, offset minute, +/// and remaining input. +fn parse_offset(b: &[u8]) -> Option<(i8, i8, &[u8])> { + if b.is_empty() { + // If no timezone specified, assume UTC. + return Some((0, 0, b)); + } + + // Snowflake's timestamp format contains a space seperator before the offset. + let b = parse_char(b, b' ').unwrap_or(b); + + if let Some(b) = parse_char_ignore_case(b, b'Z') { + Some((0, 0, b)) + } else { + let (is_positive, b) = if let Some(b) = parse_char(b, b'+') { + (true, b) + } else if let Some(b) = parse_char(b, b'-') { + (false, b) + } else { + return None; + }; + + let (hour, b) = parse_two_digit_numeric(b)?; + + let (minute, b) = if b.is_empty() { + // No offset minutes are sepcified, e.g. +00 or +07. + (0, b) + } else { + // Optional colon seperator between the hour and minute digits. + let b = parse_char(b, b':').unwrap_or(b); + let (minute, b) = parse_two_digit_numeric(b)?; + (minute, b) + }; + + // '-00:00' indicates an unknown local offset. + ensure!(is_positive || hour > 0 || minute > 0); + + ensure!(hour < 24 && minute < 60); + + let hour = hour as i8; + let minute = minute as i8; + + if is_positive { + Some((hour, minute, b)) + } else { + Some((-hour, -minute, b)) + } + } +} + +/// Parses a two-digit base-10 number from `b`, returning the number and the remaining bytes. +fn parse_two_digit_numeric(b: &[u8]) -> Option<(u8, &[u8])> { + ensure!(b.len() >= 2); + let (digits, b) = b.split_at(2); + Some((lexical::parse(digits).ok()?, b)) +} + +/// Splits `b` at the first occurance of a non-digit character. +fn parse_digits(b: &[u8]) -> (&[u8], &[u8]) { + let idx = b + .iter() + .position(|c| !c.is_ascii_digit()) + .unwrap_or_else(|| b.len()); + b.split_at(idx) +} + +/// Attempts to parse `c` from `b`, returning the remaining bytes. If the character can not be +/// parsed, returns `None`. +fn parse_char(b: &[u8], c: u8) -> Option<&[u8]> { + let (&first, rest) = b.split_first()?; + ensure!(first == c); + Some(rest) +} + +/// Attempts to parse `c` from `b`, ignoring ASCII case, returning the remaining bytes. If the +/// character can not be parsed, returns `None`. +fn parse_char_ignore_case(b: &[u8], c: u8) -> Option<&[u8]> { + let (first, rest) = b.split_first()?; + ensure!(first.eq_ignore_ascii_case(&c)); + Some(rest) +} + +/// Returns the offset in seconds from the Unix epoch of the date time. +/// +/// This is musl's [`__tm_to_secs`][1] converted to Rust via [c2rust[2] and then cleaned up by +/// hand. +/// +/// [1]: https://git.musl-libc.org/cgit/musl/tree/src/time/__tm_to_secs.c +/// [2]: https://c2rust.com/ +fn date_time_to_seconds(tm: &DateTime) -> i64 { + let (start_of_year, is_leap) = year_to_seconds(tm.year); + + let seconds_within_year = month_to_seconds(tm.month, is_leap) + + 86400 * u32::from(tm.day - 1) + + 3600 * u32::from(tm.hour) + + 60 * u32::from(tm.minute) + + u32::from(tm.second); + + (start_of_year + i128::from(seconds_within_year)) as i64 +} + +/// Returns the number of seconds in the year prior to the start of the provided month. +/// +/// This is musl's [`__month_to_secs`][1] converted to Rust via c2rust and then cleaned up by hand. +/// +/// [1]: https://git.musl-libc.org/cgit/musl/tree/src/time/__month_to_secs.c +fn month_to_seconds(month: u8, is_leap: bool) -> u32 { + const SECS_THROUGH_MONTH: [u32; 12] = [ + 0, + 31 * 86400, + 59 * 86400, + 90 * 86400, + 120 * 86400, + 151 * 86400, + 181 * 86400, + 212 * 86400, + 243 * 86400, + 273 * 86400, + 304 * 86400, + 334 * 86400, + ]; + let t = SECS_THROUGH_MONTH[usize::from(month - 1)]; + if is_leap && month > 2 { + t + 86400 + } else { + t + } +} + +/// Returns the offset in seconds from the Unix epoch of the start of a year. +/// +/// musl's [`__year_to_secs`][1] converted to Rust via c2rust and then cleaned up by hand. +/// +/// Returns an i128 because the start of the earliest supported year underflows i64. +/// +/// [1]: https://git.musl-libc.org/cgit/musl/tree/src/time/__year_to_secs.c +pub(crate) fn year_to_seconds(year: i64) -> (i128, bool) { + let is_leap; + let year = year - 1900; + + // Fast path for years 1900 - 2038. + if year as u64 <= 138 { + let mut leaps: i64 = (year - 68) >> 2; + if (year - 68).trailing_zeros() >= 2 { + leaps -= 1; + is_leap = true; + } else { + is_leap = false; + } + return ( + i128::from(31_536_000 * (year - 70) + 86400 * leaps), + is_leap, + ); + } + + let centuries: i64; + let mut leaps: i64; + + let mut cycles: i64 = (year - 100) / 400; + let mut rem: i64 = (year - 100) % 400; + + if rem < 0 { + cycles -= 1; + rem += 400 + } + if rem == 0 { + is_leap = true; + centuries = 0; + leaps = 0; + } else { + if rem >= 200 { + if rem >= 300 { + centuries = 3; + rem -= 300; + } else { + centuries = 2; + rem -= 200; + } + } else if rem >= 100 { + centuries = 1; + rem -= 100; + } else { + centuries = 0; + } + if rem == 0 { + is_leap = false; + leaps = 0; + } else { + leaps = rem / 4; + rem %= 4; + is_leap = rem == 0; + } + } + leaps += 97 * cycles + 24 * centuries - i64::from(is_leap); + + ( + i128::from((year - 100) * 31_536_000) + i128::from(leaps * 86400 + 946_684_800 + 86400), + is_leap, + ) +} + +/// Parses a timestamp in RFC 3339 format from `b`. +pub(crate) fn parse_timestamp(b: &[u8]) -> Option { + let (year, month, day, b) = parse_date(b)?; + + if b.is_empty() { + // The string only contained a date. + let date_time = DateTime { + year, + month, + day, + ..DateTime::default() + }; + + ensure!(date_time.is_valid()); + + return Some(crate::Timestamp::from(date_time)); + } + + // Accept either 'T' or ' ' as delimeter between date and time. + let b = parse_char_ignore_case(b, b'T').or_else(|| parse_char(b, b' '))?; + let (hour, minute, mut second, nanos, b) = parse_time(b)?; + let (offset_hour, offset_minute, b) = parse_offset(b)?; + + ensure!(b.is_empty()); + + // Detect whether the timestamp falls in a leap second. If this is the case, roll it back + // to the previous second. To be maximally conservative, this should be checking that the + // timestamp is the last second in the UTC day (23:59:60), and even potentially checking + // that it's the final day of the UTC month, however these checks are non-trivial because + // at this point we have, in effect, a local date time, since the offset has not been + // applied. + if second == 60 { + second = 59; + } + + let date_time = DateTime { + year, + month, + day, + hour, + minute, + second, + nanos, + }; + + ensure!(date_time.is_valid()); + + let crate::Timestamp { seconds, nanos } = crate::Timestamp::from(date_time); + + let seconds = + seconds.checked_sub(i64::from(offset_hour) * 3600 + i64::from(offset_minute) * 60)?; + + Some(crate::Timestamp { seconds, nanos }) +} + + +impl From for crate::Timestamp { + fn from(date_time: DateTime) -> crate::Timestamp { + let seconds = date_time_to_seconds(&date_time); + let nanos = date_time.nanos; + crate::Timestamp { seconds, nanos: nanos as i32 } + } +} diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 1e5450020..44d5bd6b4 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -24,6 +24,9 @@ pub mod compiler { #[cfg(feature = "json")] pub mod serde; +#[cfg(feature = "json")] +pub mod datetime; + // The Protobuf `Duration` and `Timestamp` types can't delegate to the standard library equivalents // because the Protobuf versions are signed. To make them easier to work with, `From` conversions // are defined in both directions. diff --git a/prost-types/src/serde.rs b/prost-types/src/serde.rs index af57ee22f..26a01afe3 100644 --- a/prost-types/src/serde.rs +++ b/prost-types/src/serde.rs @@ -4,14 +4,7 @@ impl ::serde::Serialize for crate::Timestamp { where S: ::serde::Serializer, { - use std::convert::TryInto; - serializer.serialize_str( - &chrono::DateTime::::from_utc( - chrono::NaiveDateTime::from_timestamp(self.seconds, self.nanos.try_into().unwrap()), - chrono::Utc, - ) - .to_rfc3339(), - ) + serializer.serialize_str(&crate::datetime::DateTime::from(self.clone()).to_string()) } } @@ -29,19 +22,12 @@ impl<'de> ::serde::de::Visitor<'de> for TimestampVisitor { where E: ::serde::de::Error, { - use std::convert::TryInto; - let dt = chrono::DateTime::parse_from_rfc3339(value) - .map_err(::serde::de::Error::custom)? - .naive_utc(); - Ok(crate::Timestamp::from( - std::time::UNIX_EPOCH - + std::time::Duration::new( - dt.timestamp() - .try_into() - .map_err(::serde::de::Error::custom)?, - dt.timestamp_subsec_nanos(), - ), - )) + crate::datetime::parse_timestamp(value.as_bytes()).ok_or_else( || { + serde::de::Error::invalid_value( + serde::de::Unexpected::Str(value), + &self, + ) + }) } } From 327c852da674a9af78a1b718b393da0809546e67 Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Mon, 21 Mar 2022 21:52:35 -0700 Subject: [PATCH 29/30] double precision --- conformance/failing_tests.txt | 4 ---- tests/Cargo.toml | 2 +- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/conformance/failing_tests.txt b/conformance/failing_tests.txt index 87c529588..bb578ef6e 100644 --- a/conformance/failing_tests.txt +++ b/conformance/failing_tests.txt @@ -34,10 +34,6 @@ Required.Proto3.JsonInput.AnyWithValueForInteger.JsonOutput Required.Proto3.JsonInput.AnyWithValueForInteger.ProtobufOutput Required.Proto3.JsonInput.AnyWithValueForJsonObject.JsonOutput Required.Proto3.JsonInput.AnyWithValueForJsonObject.ProtobufOutput -Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.JsonOutput -Required.Proto3.JsonInput.DoubleFieldMaxNegativeValue.ProtobufOutput -Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.JsonOutput -Required.Proto3.JsonInput.DoubleFieldMinPositiveValue.ProtobufOutput Required.Proto3.JsonInput.DurationJsonInputTooLarge Required.Proto3.JsonInput.DurationJsonInputTooSmall Required.Proto3.JsonInput.DurationMaxValue.JsonOutput diff --git a/tests/Cargo.toml b/tests/Cargo.toml index d750ddf20..f67f93827 100644 --- a/tests/Cargo.toml +++ b/tests/Cargo.toml @@ -22,7 +22,7 @@ prost = { path = ".." } prost-types = { path = "../prost-types" } protobuf = { path = "../protobuf" } serde = { version="1.0", features=["derive"] } -serde_json = { version="1.0" } +serde_json = { version="1.0", features=["float_roundtrip"]} serde_path_to_error = "0.1" [dev-dependencies] From a6a07779762a0d6fb150ba47e06bfaf686ade6ea Mon Sep 17 00:00:00 2001 From: Konrad Niemiec Date: Fri, 25 Mar 2022 12:00:03 -0500 Subject: [PATCH 30/30] most of the stuff --- prost-build/src/code_generator.rs | 85 ++++++++----- prost-types/Cargo.toml | 1 + prost-types/src/datetime.rs | 193 +++++++++++++++++++++++++++++- prost-types/src/lib.rs | 39 ++++++ prost-types/src/serde.rs | 65 +++++----- tests/src/lib.rs | 6 +- 6 files changed, 317 insertions(+), 72 deletions(-) diff --git a/prost-build/src/code_generator.rs b/prost-build/src/code_generator.rs index e96c2d2a0..50828d07f 100644 --- a/prost-build/src/code_generator.rs +++ b/prost-build/src/code_generator.rs @@ -322,7 +322,12 @@ impl<'a> CodeGenerator<'a> { } fn append_json_message_attributes(&mut self, fq_message_name: &str) { - if let Some(_) = self.config.json_mapping.get_first(fq_message_name) { + if self + .config + .json_mapping + .get_first(fq_message_name) + .is_some() + { self.push_indent(); self.buf .push_str("#[derive(serde::Deserialize, serde::Serialize)]"); @@ -337,7 +342,12 @@ impl<'a> CodeGenerator<'a> { } fn append_json_oneof_enum_attributes(&mut self, fq_message_name: &str) { - if let Some(_) = self.config.json_mapping.get_first(fq_message_name) { + if self + .config + .json_mapping + .get_first(fq_message_name) + .is_some() + { self.push_indent(); self.buf .push_str("#[derive(serde::Deserialize, serde::Serialize)]"); @@ -350,7 +360,12 @@ impl<'a> CodeGenerator<'a> { fn append_json_oneof_field_attributes(&mut self, fq_message_name: &str) { assert_eq!(b'.', fq_message_name.as_bytes()[0]); - if let Some(_) = self.config.json_mapping.get_first(fq_message_name) { + if self + .config + .json_mapping + .get_first(fq_message_name) + .is_some() + { self.push_indent(); self.buf.push_str("#[serde(flatten)]"); self.buf.push('\n'); @@ -373,19 +388,20 @@ impl<'a> CodeGenerator<'a> { // Shared fields between field and map fields. fn append_shared_json_field_attributes(&mut self, field_name: &str, json_name: &str) { // If there is a json name specified, add it. - if json_name.len() > 0 { - push_indent(&mut self.buf, self.depth); + if !json_name.is_empty() { + push_indent(self.buf, self.depth); self.buf .push_str(&format!(r#"#[serde(rename = "{}")]"#, json_name,)); self.buf.push('\n'); } // Always alias to the field name for deserializing. - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf .push_str(&format!(r#"#[serde(alias = "{}")]"#, field_name,)); self.buf.push('\n'); } + #[allow(clippy::too_many_arguments)] fn append_json_map_field_attributes( &mut self, fq_message_name: &str, @@ -397,13 +413,18 @@ impl<'a> CodeGenerator<'a> { map_type: &str, json_name: &str, ) { - if let None = self.config.json_mapping.get_first(fq_message_name) { + if self + .config + .json_mapping + .get_first(fq_message_name) + .is_none() + { return; } self.append_shared_json_field_attributes(field_name, json_name); // Use is_empty instead of is_default to avoid allocations. - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str(&format!( r#"#[serde(skip_serializing_if = "{}::is_empty")]"#, map_type @@ -415,14 +436,14 @@ impl<'a> CodeGenerator<'a> { let (value_se_opt, value_de_opt) = self.get_custom_json_type_mappers(value_ty, value_type_name, false, true, false); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); match (key_se_opt, key_de_opt, value_se_opt, value_de_opt, map_type) { (Some(key_se), Some(key_de), Some(value_se), Some(value_de), "::std::collections::HashMap") => { self.buf.push_str( &format!(r#"#[serde(serialize_with = "::prost_types::serde::map_custom_to_custom::serialize::<_, {}, {}>")]"#, key_se, value_se) ); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(deserialize_with = "::prost_types::serde::map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ); @@ -436,7 +457,7 @@ impl<'a> CodeGenerator<'a> { &format!(r#"#[serde(serialize_with = "::prost_types::serde::map_custom::serialize::<_, {}, _>")]"#, key_se) ); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(deserialize_with = "::prost_types::serde::map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ); @@ -446,7 +467,7 @@ impl<'a> CodeGenerator<'a> { &format!(r#"#[serde(serialize_with = "::prost_types::serde::map_custom::serialize::<_, {}, _>")]"#, key_se) ); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(deserialize_with = "::prost_types::serde::map_custom::deserialize::<_, {}, _>")]"#, key_de) ); @@ -460,7 +481,7 @@ impl<'a> CodeGenerator<'a> { &format!(r#"#[serde(serialize_with = "::prost_types::serde::map_custom_value::serialize::<_, _, {}>")]"#, value_se) ); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(deserialize_with = "::prost_types::serde::map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ); @@ -470,7 +491,7 @@ impl<'a> CodeGenerator<'a> { &format!(r#"#[serde(serialize_with = "::prost_types::serde::map_custom_value::serialize::<_, _, {}>")]"#, value_se) ); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(deserialize_with = "::prost_types::serde::map_custom_value::deserialize::<_, _, {}>")]"#, value_de) ); @@ -481,7 +502,7 @@ impl<'a> CodeGenerator<'a> { &format!(r#"#[serde(serialize_with = "::prost_types::serde::btree_map_custom_to_custom::serialize::<_, {}, {}>")]"#, key_se, value_se) ); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(deserialize_with = "::prost_types::serde::btree_map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ); @@ -495,7 +516,7 @@ impl<'a> CodeGenerator<'a> { &format!(r#"#[serde(serialize_with = "::prost_types::serde::btree_map_custom::serialize::<_, {}, _>")]"#, key_se) ); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(deserialize_with = "::prost_types::serde::btree_map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ); @@ -505,7 +526,7 @@ impl<'a> CodeGenerator<'a> { &format!(r#"#[serde(serialize_with = "::prost_types::serde::btree_map_custom::serialize::<_, {}, _>")]"#, key_se) ); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(deserialize_with = "::prost_types::serde::btree_map_custom::deserialize::<_, {}, _>")]"#, key_de) ); @@ -519,7 +540,7 @@ impl<'a> CodeGenerator<'a> { &format!(r#"#[serde(serialize_with = "::prost_types::serde::btree_map_custom_value::serialize::<_, _, {}>")]"#, value_se) ); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(deserialize_with = "::prost_types::serde::btree_map_custom_to_custom::deserialize::<_, {}, {}>")]"#, key_de, value_de) ); @@ -530,7 +551,7 @@ impl<'a> CodeGenerator<'a> { &format!(r#"#[serde(serialize_with = "::prost_types::serde::btree_map_custom_value::serialize::<_, _, {}>")]"#, value_se) ); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(deserialize_with = "::prost_types::serde::btree_map_custom_value::deserialize::<_, _, {}>")]"#, value_de) ); @@ -548,6 +569,7 @@ impl<'a> CodeGenerator<'a> { self.buf.push('\n'); } + #[allow(clippy::too_many_arguments)] fn append_json_field_attributes( &mut self, fq_message_name: &str, @@ -559,13 +581,18 @@ impl<'a> CodeGenerator<'a> { json_name: &str, oneof: bool, ) { - if let None = self.config.json_mapping.get_first(fq_message_name) { + if self + .config + .json_mapping + .get_first(fq_message_name) + .is_none() + { return; } self.append_shared_json_field_attributes(field_name, json_name); if !oneof { - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf .push_str(r#"#[serde(skip_serializing_if = "::prost_types::serde::is_default")]"#); self.buf.push('\n'); @@ -577,42 +604,42 @@ impl<'a> CodeGenerator<'a> { repeated, ) { ((Some(se), Some(de)), false) => { - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf .push_str(&format!(r#"#[serde(serialize_with = "{}")]"#, se)); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf .push_str(&format!(r#"#[serde(deserialize_with = "{}")]"#, de)); self.buf.push('\n'); } ((None, Some(de)), false) => { - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf .push_str(&format!(r#"#[serde(deserialize_with = "{}")]"#, de)); self.buf.push('\n'); } ((Some(se), Some(de)), true) => { - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(serialize_with = "::prost_types::serde::repeated::serialize::<_, {}>")]"#, se), ); self.buf.push('\n'); - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(deserialize_with = "::prost_types::serde::repeated::deserialize::<_, {}>")]"#, de), ); self.buf.push('\n'); } ((None, Some(de)), true) => { - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( &format!(r#"#[serde(deserialize_with = "::prost_types::serde::repeated::deserialize::<_, {}>")]"#, de), ); self.buf.push('\n'); } (_, true) => { - push_indent(&mut self.buf, self.depth); + push_indent(self.buf, self.depth); self.buf.push_str( r#"#[serde(deserialize_with = "::prost_types::serde::vec::deserialize")]"#, ); @@ -1040,7 +1067,7 @@ impl<'a> CodeGenerator<'a> { self.push_indent(); self.buf .push_str(&format!(r#"#[prost(enum_field_name="{}")]"#, value.name())); - self.buf.push_str("\n"); + self.buf.push('\n'); self.push_indent(); let name = to_upper_camel(value.name()); let name_unprefixed = match prefix_to_strip { diff --git a/prost-types/Cargo.toml b/prost-types/Cargo.toml index 3ef4df3df..b1efdea27 100644 --- a/prost-types/Cargo.toml +++ b/prost-types/Cargo.toml @@ -29,3 +29,4 @@ lexical = { version = "6", default-features = false, features = ["parse-integers [dev-dependencies] proptest = "1" +serde_json = "1" \ No newline at end of file diff --git a/prost-types/src/datetime.rs b/prost-types/src/datetime.rs index 93712fe7c..335847b7f 100644 --- a/prost-types/src/datetime.rs +++ b/prost-types/src/datetime.rs @@ -68,7 +68,7 @@ impl DateTime { // Format subseconds to either nothing, millis, micros, or nanos. let nanos = self.inner.nanos; let subsec = if nanos == 0 { - format!("") + String::new() } else if nanos % 1_000_000 == 0 { format!(".{:03}", nanos / 1_000_000) } else if nanos % 1_000 == 0 { @@ -200,7 +200,6 @@ fn days_in_month(year: i64, month: u8) -> u8 { DAYS_IN_MONTH[usize::from(month - 1)] + u8::from(is_leap && month == 2) } - macro_rules! ensure { ($expr:expr) => {{ if !$expr { @@ -244,7 +243,6 @@ fn parse_date(b: &[u8]) -> Option<(i64, u8, u8, &[u8])> { Some((year, month, day, b)) } - /// Parses a time in RFC 3339 format from `b`, returning the hour, minute, second, and nanos. /// /// The date is not validated according to a calendar. @@ -524,11 +522,196 @@ pub(crate) fn parse_timestamp(b: &[u8]) -> Option { Some(crate::Timestamp { seconds, nanos }) } - impl From for crate::Timestamp { fn from(date_time: DateTime) -> crate::Timestamp { let seconds = date_time_to_seconds(&date_time); let nanos = date_time.nanos; - crate::Timestamp { seconds, nanos: nanos as i32 } + crate::Timestamp { + seconds, + nanos: nanos as i32, + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::Timestamp; + + #[test] + fn test_min_max() { + assert_eq!( + DateTime::MIN, + DateTime::from(Timestamp { + seconds: i64::MIN, + nanos: 0 + }), + ); + assert_eq!( + DateTime::MAX, + DateTime::from(Timestamp { + seconds: i64::MAX, + nanos: 999_999_999 + }), + ); + } + + #[test] + fn test_datetime_from_timestamp() { + let case = |expected: &str, secs: i64, nanos: i32| { + let timestamp = Timestamp { + seconds: secs, + nanos, + }; + assert_eq!( + expected, + format!("{}", DateTime::from(timestamp.clone())), + "timestamp: {:?}", + timestamp + ); + }; + + // Mostly generated with: + // - date -jur +"%Y-%m-%dT%H:%M:%S.000000000Z" + // - http://unixtimestamp.50x.eu/ + + case("1970-01-01T00:00:00Z", 0, 0); + + case("1970-01-01T00:00:00.000000001Z", 0, 1); + case("1970-01-01T00:00:00.123450Z", 0, 123_450_000); + case("1970-01-01T00:00:00.050Z", 0, 50_000_000); + case("1970-01-01T00:00:01.000000001Z", 1, 1); + case("1970-01-01T00:01:01.000000001Z", 60 + 1, 1); + case("1970-01-01T01:01:01.000000001Z", 60 * 60 + 60 + 1, 1); + case( + "1970-01-02T01:01:01.000000001Z", + 24 * 60 * 60 + 60 * 60 + 60 + 1, + 1, + ); + + case("1969-12-31T23:59:59Z", -1, 0); + case("1969-12-31T23:59:59.000001Z", -1, 1_000); + case("1969-12-31T23:59:59.500Z", -1, 500_000_000); + case("1969-12-31T23:58:59.000001Z", -60 - 1, 1_000); + case("1969-12-31T22:58:59.000001Z", -60 * 60 - 60 - 1, 1_000); + case( + "1969-12-30T22:58:59.000000001Z", + -24 * 60 * 60 - 60 * 60 - 60 - 1, + 1, + ); + + case("2038-01-19T03:14:07Z", i32::MAX as i64, 0); + case("2038-01-19T03:14:08Z", i32::MAX as i64 + 1, 0); + case("1901-12-13T20:45:52Z", i32::MIN as i64, 0); + case("1901-12-13T20:45:51Z", i32::MIN as i64 - 1, 0); + case("+292277026596-12-04T15:30:07Z", i64::MAX, 0); + case("+292277026596-12-04T15:30:06Z", i64::MAX - 1, 0); + case("-292277022657-01-27T08:29:53Z", i64::MIN + 1, 0); + + case("1900-01-01T00:00:00Z", -2_208_988_800, 0); + case("1899-12-31T23:59:59Z", -2_208_988_801, 0); + case("0000-01-01T00:00:00Z", -62_167_219_200, 0); + case("-0001-12-31T23:59:59Z", -62_167_219_201, 0); + + case("1234-05-06T07:08:09Z", -23_215_049_511, 0); + case("-1234-05-06T07:08:09Z", -101_097_651_111, 0); + case("2345-06-07T08:09:01Z", 11_847_456_541, 0); + case("-2345-06-07T08:09:01Z", -136_154_620_259, 0); + } + + #[test] + fn test_parse_timestamp() { + // RFC 3339 Section 5.8 Examples + assert_eq!( + serde_json::from_str::("\"1985-04-12T23:20:50.52Z\"").unwrap(), + Timestamp::date_time_nanos(1985, 4, 12, 23, 20, 50, 520_000_000), + ); + assert_eq!( + serde_json::from_str::("\"1996-12-19T16:39:57-08:00\"").unwrap(), + Timestamp::date_time(1996, 12, 20, 0, 39, 57), + ); + assert_eq!( + serde_json::from_str::("\"1996-12-19T16:39:57-08:00\"").unwrap(), + Timestamp::date_time(1996, 12, 20, 0, 39, 57), + ); + assert_eq!( + serde_json::from_str::("\"1990-12-31T23:59:60Z\"").unwrap(), + Timestamp::date_time(1990, 12, 31, 23, 59, 59), + ); + assert_eq!( + serde_json::from_str::("\"1990-12-31T15:59:60-08:00\"").unwrap(), + Timestamp::date_time(1990, 12, 31, 23, 59, 59), + ); + assert_eq!( + serde_json::from_str::("\"1937-01-01T12:00:27.87+00:20\"").unwrap(), + Timestamp::date_time_nanos(1937, 1, 1, 11, 40, 27, 870_000_000), + ); + + // Date + assert_eq!( + serde_json::from_str::("\"1937-01-01\"").unwrap(), + Timestamp::date(1937, 1, 1), + ); + + // Negative year + assert_eq!( + serde_json::from_str::("\"-0008-01-01\"").unwrap(), + Timestamp::date(-8, 1, 1), + ); + + // Plus year + assert_eq!( + serde_json::from_str::("\"+19370-01-01\"").unwrap(), + Timestamp::date(19370, 1, 1), + ); + + // Full nanos + assert_eq!( + serde_json::from_str::("\"2020-02-03T01:02:03.123456789Z\"").unwrap(), + Timestamp::date_time_nanos(2020, 2, 3, 1, 2, 3, 123_456_789), + ); + + // Leap day + assert_eq!( + serde_json::from_str::("\"2020-02-29T01:02:03.00Z\"").unwrap(), + Timestamp::from(DateTime { + year: 2020, + month: 2, + day: 29, + hour: 1, + minute: 2, + second: 3, + nanos: 0, + }), + ); + + // Test extensions to RFC 3339. + // ' ' instead of 'T' as date/time separator. + assert_eq!( + serde_json::from_str::("\"1985-04-12 23:20:50.52Z\"").unwrap(), + Timestamp::date_time_nanos(1985, 4, 12, 23, 20, 50, 520_000_000), + ); + + // No time zone specified. + assert_eq!( + serde_json::from_str::("\"1985-04-12T23:20:50.52\"").unwrap(), + Timestamp::date_time_nanos(1985, 4, 12, 23, 20, 50, 520_000_000), + ); + + // Offset without minutes specified. + assert_eq!( + serde_json::from_str::("\"1996-12-19T16:39:57-08\"").unwrap(), + Timestamp::date_time(1996, 12, 20, 0, 39, 57), + ); + + // Snowflake stage style. + assert_eq!( + serde_json::from_str::("\"2015-09-12 00:47:19.591 Z\"").unwrap(), + Timestamp::date_time_nanos(2015, 9, 12, 0, 47, 19, 591_000_000), + ); + assert_eq!( + serde_json::from_str::("\"2020-06-15 00:01:02.123 +0800\"").unwrap(), + Timestamp::date_time_nanos(2020, 6, 14, 16, 1, 2, 123_000_000), + ); } } diff --git a/prost-types/src/lib.rs b/prost-types/src/lib.rs index 44d5bd6b4..cebc56db1 100644 --- a/prost-types/src/lib.rs +++ b/prost-types/src/lib.rs @@ -169,6 +169,45 @@ impl Timestamp { // debug_assert!(self.seconds >= -62_135_596_800 && self.seconds <= 253_402_300_799, // "invalid timestamp: {:?}", self); } + + /// Creates a new `Timestamp` at the start of the provided UTC date. + /// + /// This is primarily meant for creating timestamp literals in tests. + pub fn date(year: i64, month: u8, day: u8) -> Timestamp { + Timestamp::date_time_nanos(year, month, day, 0, 0, 0, 0) + } + + /// Creates a new `Timestamp` instance with the provided UTC date and time. + /// + /// This is primarily useful for creating timestamp literals in tests. + pub fn date_time(year: i64, month: u8, day: u8, hour: u8, minute: u8, second: u8) -> Timestamp { + Timestamp::date_time_nanos(year, month, day, hour, minute, second, 0) + } + + /// Creates a new `Timestamp` instance with the provided UTC date and time. + /// + /// This is primarily useful for creating timestamp literals in tests. + pub fn date_time_nanos( + year: i64, + month: u8, + day: u8, + hour: u8, + minute: u8, + second: u8, + nanos: u32, + ) -> Timestamp { + let date_time = crate::datetime::DateTime { + year, + month, + day, + hour, + minute, + second, + nanos, + }; + assert!(date_time.is_valid(), "invalid date time: {}", date_time); + Timestamp::from(date_time) + } } /// Implements the unstable/naive version of `Eq`: a basic equality check on the internal fields of the `Timestamp`. diff --git a/prost-types/src/serde.rs b/prost-types/src/serde.rs index 26a01afe3..97488ae43 100644 --- a/prost-types/src/serde.rs +++ b/prost-types/src/serde.rs @@ -22,11 +22,8 @@ impl<'de> ::serde::de::Visitor<'de> for TimestampVisitor { where E: ::serde::de::Error, { - crate::datetime::parse_timestamp(value.as_bytes()).ok_or_else( || { - serde::de::Error::invalid_value( - serde::de::Unexpected::Str(value), - &self, - ) + crate::datetime::parse_timestamp(value.as_bytes()).ok_or_else(|| { + serde::de::Error::invalid_value(serde::de::Unexpected::Str(value), &self) }) } } @@ -341,7 +338,7 @@ pub mod repeated { } pub fn serialize( - value: &Vec<::Value>, + value: &[::Value], serializer: S, ) -> Result where @@ -378,9 +375,9 @@ pub mod enum_serde { + Default, { fn new() -> Self { - return Self { + Self { _type: &std::marker::PhantomData, - }; + } } } @@ -1327,7 +1324,7 @@ pub mod string { where E: serde::de::Error, { - return Ok(value.to_string()); + Ok(value.to_string()) } fn visit_unit(self) -> Result @@ -1362,7 +1359,7 @@ pub mod string_opt { where E: serde::de::Error, { - return Ok(Some(value.to_string())); + Ok(Some(value.to_string())) } fn visit_unit(self) -> Result @@ -1396,7 +1393,7 @@ pub mod bool { impl crate::serde::HasConstructor for BoolVisitor { fn new() -> Self { - return Self {}; + Self {} } } @@ -1412,7 +1409,7 @@ pub mod bool { where E: serde::de::Error, { - return Ok(value); + Ok(value) } fn visit_unit(self) -> Result @@ -1436,7 +1433,7 @@ pub mod bool_map_key { impl crate::serde::HasConstructor for BoolVisitor { fn new() -> Self { - return Self {}; + Self {} } } @@ -1503,7 +1500,7 @@ pub mod bool_opt { where E: serde::de::Error, { - return Ok(Some(value)); + Ok(Some(value)) } fn visit_unit(self) -> Result @@ -1535,7 +1532,7 @@ pub mod i32 { impl crate::serde::HasConstructor for I32Visitor { fn new() -> I32Visitor { - return I32Visitor {}; + I32Visitor {} } } @@ -1628,7 +1625,7 @@ pub mod i32_opt { E: serde::de::Error, { use std::convert::TryFrom; - i32::try_from(value).map(|x| Some(x)).map_err(E::custom) + i32::try_from(value).map(Some).map_err(E::custom) } fn visit_f64(self, value: f64) -> Result @@ -1654,7 +1651,7 @@ pub mod i32_opt { E: serde::de::Error, { use std::convert::TryFrom; - i32::try_from(value).map(|x| Some(x)).map_err(E::custom) + i32::try_from(value).map(Some).map_err(E::custom) } fn visit_str(self, value: &str) -> Result @@ -1668,7 +1665,7 @@ pub mod i32_opt { .map_err(E::custom) .and_then(|x| self.visit_f64(x)) } else { - value.parse::().map(|x| Some(x)).map_err(E::custom) + value.parse::().map(Some).map_err(E::custom) } } @@ -1701,7 +1698,7 @@ pub mod i64 { impl crate::serde::HasConstructor for I64Visitor { fn new() -> Self { - return Self {}; + Self {} } } @@ -1831,7 +1828,7 @@ pub mod i64_opt { E: serde::de::Error, { use std::convert::TryFrom; - i64::try_from(value).map(|x| Some(x)).map_err(E::custom) + i64::try_from(value).map(Some).map_err(E::custom) } fn visit_str(self, value: &str) -> Result @@ -1845,7 +1842,7 @@ pub mod i64_opt { .map_err(E::custom) .and_then(|x| self.visit_f64(x)) } else { - value.parse::().map(|x| Some(x)).map_err(E::custom) + value.parse::().map(Some).map_err(E::custom) } } @@ -1890,7 +1887,7 @@ pub mod u32 { impl crate::serde::HasConstructor for U32Visitor { fn new() -> Self { - return Self {}; + Self {} } } @@ -1983,7 +1980,7 @@ pub mod u32_opt { E: serde::de::Error, { use std::convert::TryFrom; - u32::try_from(value).map(|x| Some(x)).map_err(E::custom) + u32::try_from(value).map(Some).map_err(E::custom) } fn visit_f64(self, value: f64) -> Result @@ -2009,7 +2006,7 @@ pub mod u32_opt { E: serde::de::Error, { use std::convert::TryFrom; - u32::try_from(value).map(|x| Some(x)).map_err(E::custom) + u32::try_from(value).map(Some).map_err(E::custom) } fn visit_str(self, value: &str) -> Result @@ -2023,7 +2020,7 @@ pub mod u32_opt { .map_err(E::custom) .and_then(|x| self.visit_f64(x)) } else { - value.parse::().map(|x| Some(x)).map_err(E::custom) + value.parse::().map(Some).map_err(E::custom) } } @@ -2056,7 +2053,7 @@ pub mod u64 { impl crate::serde::HasConstructor for U64Visitor { fn new() -> Self { - return Self {}; + Self {} } } @@ -2184,7 +2181,7 @@ pub mod u64_opt { .map_err(E::custom) .and_then(|x| self.visit_f64(x)) } else { - value.parse::().map(|x| Some(x)).map_err(E::custom) + value.parse::().map(Some).map_err(E::custom) } } @@ -2229,7 +2226,7 @@ pub mod f64 { impl crate::serde::HasConstructor for F64Visitor { fn new() -> F64Visitor { - return F64Visitor {}; + F64Visitor {} } } @@ -2351,7 +2348,7 @@ pub mod f64_opt { "NaN" => Ok(Some(f64::NAN)), "Infinity" => Ok(Some(f64::INFINITY)), "-Infinity" => Ok(Some(f64::NEG_INFINITY)), - _ => value.parse::().map(|x| Some(x)).map_err(E::custom), + _ => value.parse::().map(Some).map_err(E::custom), } } @@ -2396,7 +2393,7 @@ pub mod f32 { impl crate::serde::HasConstructor for F32Visitor { fn new() -> F32Visitor { - return F32Visitor {}; + F32Visitor {} } } @@ -2532,7 +2529,7 @@ pub mod f32_opt { "NaN" => Ok(Some(f32::NAN)), "Infinity" => Ok(Some(f32::INFINITY)), "-Infinity" => Ok(Some(f32::NEG_INFINITY)), - _ => value.parse::().map(|x| Some(x)).map_err(E::custom), + _ => value.parse::().map(Some).map_err(E::custom), } } @@ -2577,7 +2574,7 @@ pub mod vec_u8 { impl crate::serde::HasConstructor for VecU8Visitor { fn new() -> Self { - return Self {}; + Self {} } } @@ -2642,9 +2639,7 @@ pub mod vec_u8_opt { where E: serde::de::Error, { - base64::decode(value) - .map(|str| Some(str)) - .map_err(E::custom) + base64::decode(value).map(Some).map_err(E::custom) } fn visit_unit(self) -> Result diff --git a/tests/src/lib.rs b/tests/src/lib.rs index f585edbdf..58e0f4b09 100644 --- a/tests/src/lib.rs +++ b/tests/src/lib.rs @@ -148,8 +148,8 @@ where Ok(all_types) => Ok(all_types), Err(error) => Err(format!( "error deserializing json: {} at {}", - error.to_string(), - error.path().to_string() + error, + error.path(), )), } } @@ -238,7 +238,7 @@ pub fn roundtrip( where M: Message + Default + DeserializeOwned + Serialize, { - let all_types: M = match decode(payload.clone()) { + let all_types: M = match decode(payload) { Ok(all_types) => all_types, Err(error) => return RoundtripResult::DecodeError(error), };