Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve (de-)serialization performance for scalar arrays #517

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 56 additions & 22 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,26 @@ def map_field(
)


def _is_sequence(x: Any) -> bool:
return (
not isinstance(x, str)
and not isinstance(x, bytes)
and isinstance(x, typing.Sequence)
)


def _is_empty_sequence(x: Any) -> bool:
return _is_sequence(x) and len(x) == 0


def _is_nonempty_sequence(x: Any) -> bool:
return _is_sequence(x) and len(x) != 0


def _is_sequence_type(t: Any) -> bool:
return getattr(t, "_name", None) in ["List", "Sequence"]


def _pack_fmt(proto_type: str) -> str:
"""Returns a little-endian format string for reading/writing binary."""
return {
Expand Down Expand Up @@ -919,7 +939,10 @@ def dump(self, stream: BinaryIO) -> None:
field_name=field_name, meta=meta
)

if value == self._get_field_default(field_name) and not (
if (
_is_empty_sequence(value)
or value == self._get_field_default(field_name)
) and not (
selected_in_group or serialize_empty or include_default_value_for_oneof
):
# Default (zero) values are not serialized. Two exceptions are
Expand All @@ -928,7 +951,11 @@ def dump(self, stream: BinaryIO) -> None:
# set by the user).
continue

if isinstance(value, list):
if isinstance(value, ScalarArray) and meta.proto_type in FIXED_TYPES:
if value._ScalarArray__proto_type != meta.proto_type:
raise ValueError("Scalar array has incompatible type")
output += _serialize_single(meta.number, TYPE_BYTES, bytes(value))
elif _is_sequence(value):
if meta.proto_type in PACKED_TYPES:
# Packed lists look like a length-delimited field. First,
# preprocess/encode each value into a buffer and then
Expand Down Expand Up @@ -1125,6 +1152,8 @@ def _get_field_default(self, field_name: str) -> Any:
with warnings.catch_warnings():
# ignore warnings when initialising deprecated field defaults
warnings.filterwarnings("ignore", category=DeprecationWarning)
if _is_sequence_type(self._betterproto.default_gen[field_name]):
return []
return self._betterproto.default_gen[field_name]()

@classmethod
Expand All @@ -1135,7 +1164,7 @@ def _get_field_default_gen(cls, field: dataclasses.Field) -> Any:
if t.__origin__ is dict:
# This is some kind of map (dict in Python).
return dict
elif t.__origin__ is list:
elif _is_sequence_type(t.__origin__):
# This is some kind of list (repeated) field.
return list
elif t.__origin__ is Union and t.__args__[1] is type(None):
Expand Down Expand Up @@ -1240,22 +1269,18 @@ def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
value: Any
if parsed.wire_type == WIRE_LEN_DELIM and meta.proto_type in PACKED_TYPES:
# This is a packed repeated field.
pos = 0
value = []
while pos < len(parsed.value):
if meta.proto_type in (TYPE_FLOAT, TYPE_FIXED32, TYPE_SFIXED32):
decoded, pos = parsed.value[pos : pos + 4], pos + 4
wire_type = WIRE_FIXED_32
elif meta.proto_type in (TYPE_DOUBLE, TYPE_FIXED64, TYPE_SFIXED64):
decoded, pos = parsed.value[pos : pos + 8], pos + 8
wire_type = WIRE_FIXED_64
else:
if meta.proto_type in FIXED_TYPES:
value = ScalarArray(parsed.value, meta.proto_type)
else:
pos = 0
value = []
while pos < len(parsed.value):
decoded, pos = decode_varint(parsed.value, pos)
wire_type = WIRE_VARINT
decoded = self._postprocess_single(
wire_type, meta, field_name, decoded
)
value.append(decoded)
decoded = self._postprocess_single(
wire_type, meta, field_name, decoded
)
value.append(decoded)
else:
value = self._postprocess_single(
parsed.wire_type, meta, field_name, parsed.value
Expand All @@ -1270,7 +1295,7 @@ def load(self: T, stream: BinaryIO, size: Optional[int] = None) -> T:
if meta.proto_type == TYPE_MAP:
# Value represents a single key/value pair entry in the map.
current[value.key] = value.value
elif isinstance(current, list) and not isinstance(value, list):
elif _is_sequence(current) and not _is_sequence(value):
current.append(value)
else:
setattr(self, field_name, value)
Expand Down Expand Up @@ -1364,7 +1389,7 @@ def to_dict(
field_types = self._type_hints()
defaults = self._betterproto.default_gen
for field_name, meta in self._betterproto.meta_by_field_name.items():
field_is_repeated = defaults[field_name] is list
field_is_repeated = _is_sequence_type(defaults[field_name])
try:
value = getattr(self, field_name)
except AttributeError:
Expand Down Expand Up @@ -1425,7 +1450,10 @@ def to_dict(
if value or include_default_values:
output[cased_name] = output_map
elif (
value != self._get_field_default(field_name)
(
_is_nonempty_sequence(value)
or value != self._get_field_default(field_name)
)
or include_default_values
or self._include_default_value_for_oneof(
field_name=field_name, meta=meta
Expand Down Expand Up @@ -1594,6 +1622,7 @@ def to_json(
return json.dumps(
self.to_dict(include_default_values=include_default_values, casing=casing),
indent=indent,
default=lambda x: x.__json__(),
)

def from_json(self: T, value: Union[str, bytes]) -> T:
Expand Down Expand Up @@ -1641,7 +1670,7 @@ def to_pydict(
output: Dict[str, Any] = {}
defaults = self._betterproto.default_gen
for field_name, meta in self._betterproto.meta_by_field_name.items():
field_is_repeated = defaults[field_name] is list
field_is_repeated = _is_sequence_type(defaults[field_name])
value = getattr(self, field_name)
cased_name = casing(field_name).rstrip("_") # type: ignore
if meta.proto_type == TYPE_MESSAGE:
Expand Down Expand Up @@ -1690,7 +1719,10 @@ def to_pydict(
if value or include_default_values:
output[cased_name] = value
elif (
value != self._get_field_default(field_name)
(
_is_nonempty_sequence(value)
or value != self._get_field_default(field_name)
)
or include_default_values
or self._include_default_value_for_oneof(
field_name=field_name, meta=meta
Expand Down Expand Up @@ -1846,6 +1878,8 @@ def which_one_of(message: Message, group_name: str) -> Tuple[str, Optional[Any]]
UInt64Value,
)

from .scalar_array import ScalarArray


class _Duration(Duration):
@classmethod
Expand Down
8 changes: 4 additions & 4 deletions src/betterproto/plugin/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,7 @@ def py_name(self) -> str:
@property
def annotation(self) -> str:
if self.repeated:
return f"List[{self.py_name}]"
return f"Sequence[{self.py_name}]"
return self.py_name

@property
Expand Down Expand Up @@ -440,8 +440,8 @@ def typing_imports(self) -> Set[str]:
annotation = self.annotation
if "Optional[" in annotation:
imports.add("Optional")
if "List[" in annotation:
imports.add("List")
if "Sequence[" in annotation:
imports.add("Sequence")
if "Dict[" in annotation:
imports.add("Dict")
return imports
Expand Down Expand Up @@ -572,7 +572,7 @@ def annotation(self) -> str:
if self.use_builtins:
py_type = f"builtins.{py_type}"
if self.repeated:
return f"List[{py_type}]"
return f"Sequence[{py_type}]"
if self.optional:
return f"Optional[{py_type}]"
return py_type
Expand Down
104 changes: 104 additions & 0 deletions src/betterproto/scalar_array.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
import struct
from collections.abc import Sequence
from . import (
TYPE_DOUBLE,
TYPE_FIXED32,
TYPE_FIXED64,
TYPE_FLOAT,
TYPE_SFIXED32,
TYPE_SFIXED64,
_pack_fmt,
)

NP_DOUBLE = "float64"
NP_FLOAT = "float32"
NP_SFIXED32 = "int32"
NP_FIXED32 = "uint32"
NP_SFIXED64 = "int64"
NP_FIXED64 = "uint64"


def _convert_types_np2proto(np_type: str) -> str:
return {
NP_DOUBLE: TYPE_DOUBLE,
NP_FLOAT: TYPE_FLOAT,
NP_SFIXED32: TYPE_SFIXED32,
NP_FIXED32: TYPE_FIXED32,
NP_SFIXED64: TYPE_SFIXED64,
NP_FIXED64: TYPE_FIXED64,
}[np_type]


def _convert_types_proto2np(proto_type: str) -> str:
return {
TYPE_DOUBLE: NP_DOUBLE,
TYPE_FLOAT: NP_FLOAT,
TYPE_SFIXED32: NP_SFIXED32,
TYPE_FIXED32: NP_FIXED32,
TYPE_SFIXED64: NP_SFIXED64,
TYPE_FIXED64: NP_FIXED64,
}[proto_type]


def _item_size(proto_type: str) -> int:
return {
TYPE_DOUBLE: 8,
TYPE_FLOAT: 4,
TYPE_SFIXED32: 4,
TYPE_FIXED32: 4,
TYPE_SFIXED64: 8,
TYPE_FIXED64: 8,
}[proto_type]


class ScalarArray(Sequence):
__data: bytes
__item_size: int
__proto_type: str

def __init__(self, data: bytes, proto_type: str) -> None:
self.__data = data
self.__item_size = _item_size(proto_type)
self.__proto_type = proto_type

def __len__(self) -> int:
return len(self.__data) // self.__item_size

def __getitem__(self, i: int):
if i < 0:
i += len(self)
if i < 0 or i >= len(self):
raise IndexError

value = self.__data[i * self.__item_size : (i + 1) * self.__item_size]
value = struct.unpack(_pack_fmt(self.__proto_type), value)[0]
return value

def __bytes__(self) -> bytes:
return self.__data

def __repr__(self) -> str:
return str(list(self))

def __array__(self):
import numpy as np

return np.frombuffer(
self.__data, dtype=_convert_types_proto2np(self.__proto_type)
)

def __json__(self):
return list(self)

def __eq__(self, other):
if isinstance(other, ScalarArray):
return (
self.__data == other.__data
and self.__item_size == other.__item_size
and self.__proto_type == other.__proto_type
)
return isinstance(other, Sequence) and list(self) == list(other)

@staticmethod
def from_numpy(ar) -> "ScalarArray":
return ScalarArray(bytes(ar), _convert_types_np2proto(str(ar.dtype)))