diff --git a/docs/marshal.rst b/docs/marshal.rst index 1d2dcccc..ee61cb2e 100644 --- a/docs/marshal.rst +++ b/docs/marshal.rst @@ -71,8 +71,8 @@ Protocol buffer type Python type Nullable msg_two = MyMessage(msg_dict) assert msg == msg_pb == msg_two - - + + Wrapper types ------------- diff --git a/proto/marshal/marshal.py b/proto/marshal/marshal.py index d0dc2ead..baac7adc 100644 --- a/proto/marshal/marshal.py +++ b/proto/marshal/marshal.py @@ -26,6 +26,7 @@ from proto.marshal.collections import Repeated from proto.marshal.collections import RepeatedComposite from proto.marshal.rules import bytes as pb_bytes +from proto.marshal.rules import stringy_numbers from proto.marshal.rules import dates from proto.marshal.rules import struct from proto.marshal.rules import wrappers @@ -147,6 +148,11 @@ def reset(self): # Special case for bytes to allow base64 encode/decode self.register(ProtoType.BYTES, pb_bytes.BytesRule()) + # Special case for int64 from strings because of dict round trip. + # See https://github.com/protocolbuffers/protobuf/issues/2679 + for rule_class in stringy_numbers.STRINGY_NUMBER_RULES: + self.register(rule_class._proto_type, rule_class()) + def to_python(self, proto_type, value, *, absent: bool = None): # Internal protobuf has its own special type for lists of values. # Return a view around it that implements MutableSequence. diff --git a/proto/marshal/rules/message.py b/proto/marshal/rules/message.py index e5ecf17b..c865b99d 100644 --- a/proto/marshal/rules/message.py +++ b/proto/marshal/rules/message.py @@ -29,7 +29,16 @@ def to_proto(self, value): if isinstance(value, self._wrapper): return self._wrapper.pb(value) if isinstance(value, dict) and not self.is_map: - return self._descriptor(**value) + # We need to use the wrapper's marshaling to handle + # potentially problematic nested messages. + try: + # Try the fast path first. + return self._descriptor(**value) + except TypeError as ex: + # If we have a type error, + # try the slow path in case the error + # was an int64/string issue + return self._wrapper(value)._pb return value @property diff --git a/proto/marshal/rules/stringy_numbers.py b/proto/marshal/rules/stringy_numbers.py new file mode 100644 index 00000000..0d808cc2 --- /dev/null +++ b/proto/marshal/rules/stringy_numbers.py @@ -0,0 +1,68 @@ +# Copyright (C) 2021 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from proto.primitives import ProtoType + + +class StringyNumberRule: + """A marshal between certain numeric types and strings + + This is a necessary hack to allow round trip conversion + from messages to dicts back to messages. + + See https://github.com/protocolbuffers/protobuf/issues/2679 + and + https://developers.google.com/protocol-buffers/docs/proto3#json + for more details. + """ + + def to_python(self, value, *, absent: bool = None): + return value + + def to_proto(self, value): + return self._python_type(value) + + +class Int64Rule(StringyNumberRule): + _python_type = int + _proto_type = ProtoType.INT64 + + +class UInt64Rule(StringyNumberRule): + _python_type = int + _proto_type = ProtoType.UINT64 + + +class SInt64Rule(StringyNumberRule): + _python_type = int + _proto_type = ProtoType.SINT64 + + +class Fixed64Rule(StringyNumberRule): + _python_type = int + _proto_type = ProtoType.FIXED64 + + +class SFixed64Rule(StringyNumberRule): + _python_type = int + _proto_type = ProtoType.SFIXED64 + + +STRINGY_NUMBER_RULES = [ + Int64Rule, + UInt64Rule, + SInt64Rule, + Fixed64Rule, + SFixed64Rule, +] diff --git a/proto/message.py b/proto/message.py index 00ec4cc7..3dd22414 100644 --- a/proto/message.py +++ b/proto/message.py @@ -394,7 +394,7 @@ def to_dict( determines whether field name representations preserve proto case (snake_case) or use lowerCamelCase. Default is True. including_default_value_fields (Optional(bool)): An option that - determines whether the default field values should be included in the results. + determines whether the default field values should be included in the results. Default is True. Returns: @@ -453,7 +453,9 @@ class Message(metaclass=MessageMeta): message. """ - def __init__(self, mapping=None, *, ignore_unknown_fields=False, **kwargs): + def __init__( + self, mapping=None, *, ignore_unknown_fields=False, **kwargs, + ): # We accept several things for `mapping`: # * An instance of this class. # * An instance of the underlying protobuf descriptor class. diff --git a/tests/test_fields_int.py b/tests/test_fields_int.py index c3a979c0..40f5aa38 100644 --- a/tests/test_fields_int.py +++ b/tests/test_fields_int.py @@ -93,3 +93,46 @@ class Foo(proto.Message): bar_field = Foo.meta.fields["bar"] assert bar_field.descriptor is bar_field.descriptor + + +def test_int64_dict_round_trip(): + # When converting a message to other types, protobuf turns int64 fields + # into decimal coded strings. + # This is not a problem for round trip JSON, but it is a problem + # when doing a round trip conversion from a message to a dict to a message. + # See https://github.com/protocolbuffers/protobuf/issues/2679 + # and + # https://developers.google.com/protocol-buffers/docs/proto3#json + # for more details. + class Squid(proto.Message): + mass_kg = proto.Field(proto.INT64, number=1) + length_cm = proto.Field(proto.UINT64, number=2) + age_s = proto.Field(proto.FIXED64, number=3) + depth_m = proto.Field(proto.SFIXED64, number=4) + serial_num = proto.Field(proto.SINT64, number=5) + + s = Squid(mass_kg=10, length_cm=20, age_s=30, depth_m=40, serial_num=50) + + s_dict = Squid.to_dict(s) + + s2 = Squid(s_dict) + + assert s == s2 + + # Double check that the conversion works with deeply nested messages. + class Clam(proto.Message): + class Shell(proto.Message): + class Pearl(proto.Message): + mass_kg = proto.Field(proto.INT64, number=1) + + pearl = proto.Field(Pearl, number=1) + + shell = proto.Field(Shell, number=1) + + c = Clam(shell=Clam.Shell(pearl=Clam.Shell.Pearl(mass_kg=10))) + + c_dict = Clam.to_dict(c) + + c2 = Clam(c_dict) + + assert c == c2