diff --git a/src/betterproto/__init__.py b/src/betterproto/__init__.py
index da6e57fa8..26311f7af 100644
--- a/src/betterproto/__init__.py
+++ b/src/betterproto/__init__.py
@@ -310,6 +310,26 @@ def map_field(
     )
 
 
+def _is_sequence(x: Any) -> bool:
+    return (
+        not isinstance(x, str)
+        and not isinstance(x, bytes)
+        and isinstance(x, typing.Sequence)
+    )
+
+
+def _is_empty_sequence(x: Any) -> bool:
+    return _is_sequence(x) and len(x) == 0
+
+
+def _is_nonempty_sequence(x: Any) -> bool:
+    return _is_sequence(x) and len(x) != 0
+
+
+def _is_sequence_type(t: Any) -> bool:
+    return getattr(t, "_name", None) in ["List", "Sequence"]
+
+
 def _pack_fmt(proto_type: str) -> str:
     """Returns a little-endian format string for reading/writing binary."""
     return {
@@ -919,7 +939,10 @@ def dump(self, stream: BinaryIO) -> None:
                 field_name=field_name, meta=meta
             )
 
-            if value == self._get_field_default(field_name) and not (
+            if (
+                _is_empty_sequence(value)
+                or value == self._get_field_default(field_name)
+            ) and not (
                 selected_in_group or serialize_empty or include_default_value_for_oneof
             ):
                 # Default (zero) values are not serialized. Two exceptions are
@@ -928,7 +951,11 @@ def dump(self, stream: BinaryIO) -> None:
                 # set by the user).
                 continue
 
-            if isinstance(value, list):
+            if isinstance(value, ScalarArray) and meta.proto_type in FIXED_TYPES:
+                if value._ScalarArray__proto_type != meta.proto_type:
+                    raise ValueError("Scalar array has incompatible type")
+                output += _serialize_single(meta.number, TYPE_BYTES, bytes(value))
+            elif _is_sequence(value):
                 if meta.proto_type in PACKED_TYPES:
                     # Packed lists look like a length-delimited field. First,
                     # preprocess/encode each value into a buffer and then
@@ -1125,6 +1152,8 @@ def _get_field_default(self, field_name: str) -> Any:
         with warnings.catch_warnings():
             # ignore warnings when initialising deprecated field defaults
             warnings.filterwarnings("ignore", category=DeprecationWarning)
+            if _is_sequence_type(self._betterproto.default_gen[field_name]):
+                return []
             return self._betterproto.default_gen[field_name]()
 
     @classmethod
@@ -1135,7 +1164,7 @@ def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
         if t.__origin__ is dict:
             # This is some kind of map (dict in Python).
             return dict
-        elif t.__origin__ is list:
+        elif _is_sequence_type(t.__origin__):
             # This is some kind of list (repeated) field.
             return list
         elif t.__origin__ is Union and t.__args__[1] is type(None):
@@ -1240,22 +1269,18 @@ def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
             value: Any
             if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
                 # This is a packed repeated field.
-                pos = 0
-                value = []
-                while pos < len(parsed.value):
-                    if meta.proto_type in (TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32):
-                        decoded, pos = parsed.value[pos : pos + 4], pos + 4
-                        wire_type = WIRE_FIXED_32
-                    elif meta.proto_type in (TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64):
-                        decoded, pos = parsed.value[pos : pos + 8], pos + 8
-                        wire_type = WIRE_FIXED_64
-                    else:
+                if meta.proto_type in FIXED_TYPES:
+                    value = ScalarArray(parsed.value, meta.proto_type)
+                else:
+                    pos = 0
+                    value = []
+                    while pos < len(parsed.value):
                         decoded, pos = decode_varint(parsed.value, pos)
                         wire_type = WIRE_VARINT
-                    decoded = self._postprocess_single(
-                        wire_type, meta, field_name, decoded
-                    )
-                    value.append(decoded)
+                        decoded = self._postprocess_single(
+                            wire_type, meta, field_name, decoded
+                        )
+                        value.append(decoded)
             else:
                 value = self._postprocess_single(
                     parsed.wire_type, meta, field_name, parsed.value
@@ -1270,7 +1295,7 @@ def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
             if meta.proto_type == TYPE_MAP:
                 # Value represents a single key/value pair entry in the map.
                 current[value.key] = value.value
-            elif isinstance(current, list) and not isinstance(value, list):
+            elif _is_sequence(current) and not _is_sequence(value):
                 current.append(value)
             else:
                 setattr(self, field_name, value)
@@ -1364,7 +1389,7 @@ def to_dict(
         field_types = self._type_hints()
         defaults = self._betterproto.default_gen
         for field_name, meta in self._betterproto.meta_by_field_name.items():
-            field_is_repeated = defaults[field_name] is list
+            field_is_repeated = _is_sequence_type(defaults[field_name])
             try:
                 value = getattr(self, field_name)
             except AttributeError:
@@ -1425,7 +1450,10 @@ def to_dict(
                 if value or include_default_values:
                     output[cased_name] = output_map
             elif (
-                value != self._get_field_default(field_name)
+                (
+                    _is_nonempty_sequence(value)
+                    or value != self._get_field_default(field_name)
+                )
                 or include_default_values
                 or self._include_default_value_for_oneof(
                     field_name=field_name, meta=meta
@@ -1594,6 +1622,7 @@ def to_json(
         return json.dumps(
             self.to_dict(include_default_values=include_default_values, casing=casing),
             indent=indent,
+            default=lambda x: x.__json__(),
         )
 
     def from_json(self: T, value: Union[str, bytes]) -> T:
@@ -1641,7 +1670,7 @@ def to_pydict(
        output: Dict[str, Any] = {}
         defaults = self._betterproto.default_gen
         for field_name, meta in self._betterproto.meta_by_field_name.items():
-            field_is_repeated = defaults[field_name] is list
+            field_is_repeated = _is_sequence_type(defaults[field_name])
             value = getattr(self, field_name)
             cased_name = casing(field_name).rstrip("_")  # type: ignore
             if meta.proto_type == TYPE_MESSAGE:
@@ -1690,7 +1719,10 @@ def to_pydict(
                 if value or include_default_values:
                     output[cased_name] = value
             elif (
-                value != self._get_field_default(field_name)
+                (
+                    _is_nonempty_sequence(value)
+                    or value != self._get_field_default(field_name)
+                )
                 or include_default_values
                 or self._include_default_value_for_oneof(
                     field_name=field_name, meta=meta
@@ -1846,6 +1878,8 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]
     UInt64Value,
 )
 
+from .scalar_array import ScalarArray
+
 
 class _Duration(Duration):
     @classmethod
diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py
index ea819d44d..c7c2cb854 100644
--- a/src/betterproto/plugin/models.py
+++ b/src/betterproto/plugin/models.py
@@ -322,7 +322,7 @@ def py_name(self) -> str:
     @property
     def annotation(self) -> str:
         if self.repeated:
-            return f"List[{self.py_name}]"
+            return f"Sequence[{self.py_name}]"
         return self.py_name
 
     @property
@@ -440,8 +440,8 @@ def typing_imports(self) -> Set[str]:
         annotation = self.annotation
         if "Optional[" in annotation:
             imports.add("Optional")
-        if "List[" in annotation:
-            imports.add("List")
+        if "Sequence[" in annotation:
+            imports.add("Sequence")
         if "Dict[" in annotation:
             imports.add("Dict")
         return imports
@@ -572,7 +572,7 @@ def annotation(self) -> str:
         if self.use_builtins:
             py_type = f"builtins.{py_type}"
         if self.repeated:
-            return f"List[{py_type}]"
+            return f"Sequence[{py_type}]"
         if self.optional:
             return f"Optional[{py_type}]"
         return py_type
diff --git a/src/betterproto/scalar_array.py b/src/betterproto/scalar_array.py
new file mode 100644
index 000000000..18553f9ba
--- /dev/null
+++ b/src/betterproto/scalar_array.py
@@ -0,0 +1,104 @@
+import struct
+from collections.abc import Sequence
+from . import (
+    TYPE_DOUBLE,
+    TYPE_FIXED32,
+    TYPE_FIXED64,
+    TYPE_FLOAT,
+    TYPE_SFIXED32,
+    TYPE_SFIXED64,
+    _pack_fmt,
+)
+
+NP_DOUBLE = "float64"
+NP_FLOAT = "float32"
+NP_SFIXED32 = "int32"
+NP_FIXED32 = "uint32"
+NP_SFIXED64 = "int64"
+NP_FIXED64 = "uint64"
+
+
+def _convert_types_np2proto(np_type: str) -> str:
+    return {
+        NP_DOUBLE: TYPE_DOUBLE,
+        NP_FLOAT: TYPE_FLOAT,
+        NP_SFIXED32: TYPE_SFIXED32,
+        NP_FIXED32: TYPE_FIXED32,
+        NP_SFIXED64: TYPE_SFIXED64,
+        NP_FIXED64: TYPE_FIXED64,
+    }[np_type]
+
+
+def _convert_types_proto2np(proto_type: str) -> str:
+    return {
+        TYPE_DOUBLE: NP_DOUBLE,
+        TYPE_FLOAT: NP_FLOAT,
+        TYPE_SFIXED32: NP_SFIXED32,
+        TYPE_FIXED32: NP_FIXED32,
+        TYPE_SFIXED64: NP_SFIXED64,
+        TYPE_FIXED64: NP_FIXED64,
+    }[proto_type]
+
+
+def _item_size(proto_type: str) -> int:
+    return {
+        TYPE_DOUBLE: 8,
+        TYPE_FLOAT: 4,
+        TYPE_SFIXED32: 4,
+        TYPE_FIXED32: 4,
+        TYPE_SFIXED64: 8,
+        TYPE_FIXED64: 8,
+    }[proto_type]
+
+
+class ScalarArray(Sequence):
+    __data: bytes
+    __item_size: int
+    __proto_type: str
+
+    def __init__(self, data: bytes, proto_type: str) -> None:
+        self.__data = data
+        self.__item_size = _item_size(proto_type)
+        self.__proto_type = proto_type
+
+    def __len__(self) -> int:
+        return len(self.__data) // self.__item_size
+
+    def __getitem__(self, i: int):
+        if i < 0:
+            i += len(self)
+        if i < 0 or i >= len(self):
+            raise IndexError
+
+        value = self.__data[i * self.__item_size : (i + 1) * self.__item_size]
+        value = struct.unpack(_pack_fmt(self.__proto_type), value)[0]
+        return value
+
+    def __bytes__(self) -> bytes:
+        return self.__data
+
+    def __repr__(self) -> str:
+        return str(list(self))
+
+    def __array__(self):
+        import numpy as np
+
+        return np.frombuffer(
+            self.__data, dtype=_convert_types_proto2np(self.__proto_type)
+        )
+
+    def __json__(self):
+        return list(self)
+
+    def __eq__(self, other):
+        if isinstance(other, ScalarArray):
+            return (
+                self.__data == other.__data
+                and self.__item_size == other.__item_size
+                and self.__proto_type == other.__proto_type
+            )
+        return isinstance(other, Sequence) and list(self) == list(other)
+
+    @staticmethod
+    def from_numpy(ar) -> "ScalarArray":
+        return ScalarArray(bytes(ar), _convert_types_np2proto(str(ar.dtype)))