diff --git a/README.md b/README.md index 97e9c05..4ee87b9 100644 --- a/README.md +++ b/README.md @@ -158,6 +158,9 @@ All other types fall into the "complex" category. They currently consist of: - `unions`: Unions of serializable types are supported as well. - `Structured`-derived types: You can use any of your `Structured`-derived classes as a type-hint, and the variable will be serialized as well. +- `typing.Self`: This type-hint denotes that the attribute should be unpacked as an instance of + the containing class itself. Note that due to the recursive posibilities this allows, care + must be taken to avoid hitting the recursion limit of Python. ### Tuples diff --git a/structured/serializers/__init__.py b/structured/serializers/__init__.py index 374730c..7265306 100644 --- a/structured/serializers/__init__.py +++ b/structured/serializers/__init__.py @@ -5,3 +5,4 @@ from .structured import * from .tuples import * from .unions import * +from .self import * diff --git a/structured/serializers/api.py b/structured/serializers/api.py index 1cf3238..e77100e 100644 --- a/structured/serializers/api.py +++ b/structured/serializers/api.py @@ -10,6 +10,17 @@ - Modified packing method `pack` - All unpacking methods may return an iterable of values instead of a tuple. For more details, check the docstrings on each method or attribute. + +A note on "container" serializers (for example, CompoundSerializer and +ArraySerializer): Due to the posibility of recursive nesting via the +`typing.Self` type-hint as a serializable type, care must be taken with +delegating to sub-serializers. In particular, only updating `self.size` at the +*end* of a pack/unpack operation ensures that nested usages of the same +serializer won't overwrite intermediate values. + +Similarly (although this is true regardless of nesting), you almost always want +a custom `prepack` and `preunpack` method, to pass that information along to +the nested serializers. """ from __future__ import annotations diff --git a/structured/serializers/arrays.py b/structured/serializers/arrays.py index cf5ca97..f113cbc 100644 --- a/structured/serializers/arrays.py +++ b/structured/serializers/arrays.py @@ -96,26 +96,38 @@ def _check_data_size(self, expected: int, actual: int) -> None: raise ValueError( f'Array data size {actual} does not match expected size {expected}' ) + + def prepack(self, partial_object) -> Self: + self._partial_object = partial_object + return self + + def preunpack(self, partial_object) -> Self: + self._partial_object = partial_object + return self def pack(self, *values: Unpack[tuple[list[T]]]) -> bytes: data = [b''] - self.size = header_size = self.header_serializer.size + size = header_size = self.header_serializer.size + item_serializer = self.item_serializer.prepack(self._partial_object) for item in values[0]: - data.append(self.item_serializer.pack(item)) - self.size += self.item_serializer.size - header_values = self._header_pack_values(values[0], self.size - header_size) + data.append(item_serializer.pack(item)) + size += item_serializer.size + header_values = self._header_pack_values(values[0], size - header_size) data[0] = self.header_serializer.pack(*header_values) + self.size = size return b''.join(data) def pack_into( self, buffer: WritableBuffer, offset: int, *values: Unpack[tuple[list[T]]] ) -> None: items = values[0] - self.size = header_size = self.header_serializer.size + size = header_size = self.header_serializer.size + item_serializer = self.item_serializer.prepack(self._partial_object) for item in items: - self.item_serializer.pack_into(buffer, offset + self.size, item) - self.size += self.item_serializer.size - header_values = self._header_pack_values(items, self.size - header_size) + item_serializer.pack_into(buffer, offset + size, item) + size += item_serializer.size + header_values = self._header_pack_values(items, size - header_size) + self.size = size self.header_serializer.pack_into(buffer, offset, *header_values) def pack_write(self, writable: BinaryIO, *values: Unpack[tuple[list[T]]]) -> None: @@ -125,34 +137,40 @@ def pack_write(self, writable: BinaryIO, *values: Unpack[tuple[list[T]]]) -> Non def unpack(self, buffer: ReadableBuffer) -> tuple[list[T]]: header = self.header_serializer.unpack(buffer) count, data_size = self._header_unpack_values(*header) - self.size = header_size = self.header_serializer.size + size = header_size = self.header_serializer.size + item_serializer = self.item_serializer.preunpack(self._partial_object) items = [] for _ in range(count): - items.extend(self.item_serializer.unpack(buffer[self.size :])) - self.size += self.item_serializer.size - self._check_data_size(data_size, self.size - header_size) + items.extend(item_serializer.unpack(buffer[size :])) + size += item_serializer.size + self._check_data_size(data_size, size - header_size) + self.size = size return (items,) def unpack_from(self, buffer: ReadableBuffer, offset: int) -> tuple[list[T]]: header = self.header_serializer.unpack_from(buffer, offset) count, data_size = self._header_unpack_values(*header) - self.size = header_size = self.header_serializer.size + size = header_size = self.header_serializer.size + item_serializer = self.item_serializer.preunpack(self._partial_object) items = [] for _ in range(count): - items.extend(self.item_serializer.unpack_from(buffer, offset + self.size)) - self.size += self.item_serializer.size - self._check_data_size(data_size, self.size - header_size) + items.extend(item_serializer.unpack_from(buffer, offset + size)) + size += item_serializer.size + self._check_data_size(data_size, size - header_size) + self.size = size return (items,) def unpack_read(self, readable: BinaryIO) -> tuple[list[T]]: header = self.header_serializer.unpack_read(readable) count, data_size = self._header_unpack_values(*header) - self.size = header_size = self.header_serializer.size + size = header_size = self.header_serializer.size + item_serializer = self.item_serializer.preunpack(self._partial_object) items = [] for _ in range(count): - items.extend(self.item_serializer.unpack_read(readable)) - self.size += self.item_serializer.size - self._check_data_size(data_size, self.size - header_size) + items.extend(item_serializer.unpack_read(readable)) + size += item_serializer.size + self._check_data_size(data_size, size - header_size) + self.size = size return (items,) diff --git a/structured/serializers/self.py b/structured/serializers/self.py new file mode 100644 index 0000000..8a09259 --- /dev/null +++ b/structured/serializers/self.py @@ -0,0 +1,42 @@ +""" +Serializer for special handling of the typing.Self typehint. +""" + +__all__ = [ + 'SelfSerializer', +] + + +from ..type_checking import ( + Any, + ClassVar, + TYPE_CHECKING, + annotated, + Self, +) +from .api import Serializer +from .structured import StructuredSerializer + + +if TYPE_CHECKING: + from ..structured import Structured, _Proxy +else: + Structured = 'Structured' + _Proxy = '_Proxy' + +class SelfSerializer(Serializer[Structured]): + num_values: ClassVar[int] = 1 + + def prepack(self, partial_object: Structured) -> Serializer: + return StructuredSerializer(type(partial_object)) + + def preunpack(self, partial_object: _Proxy) -> Serializer: + return StructuredSerializer(partial_object.cls) + + @classmethod + def _transform(cls, unwrapped: Any, actual: Any) -> Any: + if unwrapped is Self: + return cls() + + +annotated.register_transform(SelfSerializer._transform) \ No newline at end of file diff --git a/structured/serializers/structured.py b/structured/serializers/structured.py index 127e59c..e9774f4 100644 --- a/structured/serializers/structured.py +++ b/structured/serializers/structured.py @@ -42,32 +42,39 @@ class StructuredSerializer(Generic[TStructured], Serializer[TStructured]): def __init__(self, obj_type: type[TStructured]) -> None: self.obj_type = obj_type - - @property - def size(self) -> int: - return self.obj_type.serializer.size + self.size = 0 def pack(self, values: TStructured) -> bytes: - return values.pack() + data = values.pack() + self.size = values.serializer.size + return data def pack_into( self, buffer: WritableBuffer, offset: int, values: TStructured ) -> None: values.pack_into(buffer, offset) + self.size = values.serializer.size def pack_write(self, writable: BinaryIO, values: TStructured) -> None: values.pack_write(writable) + self.size = values.serializer.size def unpack(self, buffer: ReadableBuffer) -> tuple[TStructured]: - return (self.obj_type.create_unpack(buffer),) + value = self.obj_type.create_unpack(buffer) + self.size = self.obj_type.serializer.size + return (value, ) def unpack_from( self, buffer: ReadableBuffer, offset: int = 0 ) -> tuple[TStructured]: - return (self.obj_type.create_unpack_from(buffer, offset),) + value = self.obj_type.create_unpack_from(buffer, offset) + self.size = self.obj_type.serializer.size + return (value, ) def unpack_read(self, readable: BinaryIO) -> tuple[TStructured]: - return (self.obj_type.create_unpack_read(readable),) + value = self.obj_type.create_unpack_read(readable) + self.size = self.obj_type.serializer.size + return (value, ) @classmethod def _transform(cls, unwrapped: Any, actual: Any) -> Any: diff --git a/structured/serializers/unions.py b/structured/serializers/unions.py index a4766ff..f3545d8 100644 --- a/structured/serializers/unions.py +++ b/structured/serializers/unions.py @@ -20,6 +20,7 @@ ClassVar, Iterable, ReadableBuffer, + WritableBuffer, annotated, get_union_args, ) @@ -48,7 +49,7 @@ def __init__(self, result_map: dict[Any, Any], default: Any = None) -> None: key: self.validate_serializer(serializer) for key, serializer in result_map.items() } - self._last_serializer = self.default + self.size = 0 @staticmethod def validate_serializer(hint) -> Serializer: @@ -59,16 +60,15 @@ def validate_serializer(hint) -> Serializer: raise ValueError('Union results must serializer a single item.') return serializer - @property - def size(self) -> int: - if self._last_serializer: - return self._last_serializer.size - else: - return 0 + def prepack(self, partial_object) -> Serializer: + self._partial_object = partial_object + return self + + def preunpack(self, partial_object) -> Serializer: + self._partial_object = partial_object + return self - def get_serializer( - self, decider_result: Any, partial_object: Any, packing: bool - ) -> Serializer: + def get_serializer(self, decider_result: Any, packing: bool) -> Serializer: """Given a target used to decide, return a serializer used to unpack.""" if self.default is None: try: @@ -80,11 +80,9 @@ def get_serializer( else: serializer = self.result_map.get(decider_result, self.default) if packing: - serializer = serializer.prepack(partial_object) + return serializer.prepack(self._partial_object) else: - serializer = serializer.preunpack(partial_object) - self._last_serializer = serializer - return self._last_serializer + return serializer.preunpack(self._partial_object) @staticmethod def _transform(unwrapped: Any, actual: Any) -> Any: @@ -120,13 +118,43 @@ def __init__( super().__init__(result_map, default) self.decider = decider - def prepack(self, partial_object: Any) -> Serializer: - result = self.decider(partial_object) - return self.get_serializer(result, partial_object, True) + def decide(self, packing: bool) -> Serializer: + result = self.decider(self._partial_object) + return self.get_serializer(result, packing) + + def pack(self, *values: Any) -> bytes: + serializer = self.decide(True) + data = serializer.pack(*values) + self.size = serializer.size + return data + + def pack_into(self, buffer: WritableBuffer, offset: int, *values: Any) -> None: + serializer = self.decide(True) + serializer.pack_into(buffer, offset, *values) + self.size = serializer.size + + def pack_write(self, writable: BinaryIO, *values: Any) -> None: + serializer = self.decide(True) + serializer.pack_write(writable, *values) + self.size = serializer.size - def preunpack(self, partial_object: Any) -> Serializer: - result = self.decider(partial_object) - return self.get_serializer(result, partial_object, False) + def unpack(self, buffer: ReadableBuffer) -> Iterable: + serializer = self.decide(False) + value = serializer.unpack(buffer) + self.size = serializer.size + return value + + def unpack_from(self, buffer: ReadableBuffer, offset: int = 0) -> Iterable: + serializer = self.decide(False) + value = serializer.unpack_from(buffer, offset) + self.size = serializer.size + return value + + def unpack_read(self, readable: BinaryIO) -> Iterable: + serializer = self.decide(False) + value = serializer.unpack_read(readable) + self.size = serializer.size + return value class LookaheadDecider(AUnion): @@ -153,19 +181,43 @@ def __init__( ) self.read_ahead_serializer = serializer - def prepack(self, partial_object: Any) -> Serializer: - result = self.decider(partial_object) - return self.get_serializer(result, partial_object, True) + def pack(self, *values: Any) -> bytes: + result = self.decider(self._partial_object) + serializer = self.get_serializer(result, True) + data = serializer.pack(*values) + self.size = serializer.size + return data + + def pack_into(self, buffer: WritableBuffer, offset: int, *values: Any) -> None: + result = self.decider(self._partial_object) + serializer = self.get_serializer(result, True) + serializer.pack_into(buffer, offset, *values) + self.size = serializer.size + + def pack_write(self, writable: BinaryIO, *values: Any) -> None: + result = self.decider(self._partial_object) + serializer = self.get_serializer(result, True) + serializer.pack_write(writable, *values) + self.size = serializer.size def unpack(self, buffer: ReadableBuffer) -> Iterable: result = tuple(self.read_ahead_serializer.unpack(buffer))[0] - return self.get_serializer(result, None, False).unpack(buffer) + serializer = self.get_serializer(result, False) + values = serializer.unpack(buffer) + self.size = serializer.size + return values def unpack_from(self, buffer: ReadableBuffer, offset: int = 0) -> Iterable: result = tuple(self.read_ahead_serializer.unpack_from(buffer, offset))[0] - return self.get_serializer(result, None, False).unpack_from(buffer, offset) + serializer = self.get_serializer(result, False) + values = serializer.unpack_from(buffer, offset) + self.size = serializer.size + return values def unpack_read(self, readable: BinaryIO) -> Iterable: result = tuple(self.read_ahead_serializer.unpack_read(readable))[0] readable.seek(-self.read_ahead_serializer.size, os.SEEK_CUR) - return self.get_serializer(result, None, False).unpack_read(readable) + serializer = self.get_serializer(result, False) + values = serializer.unpack_read(readable) + self.size = serializer.size + return values diff --git a/structured/structured.py b/structured/structured.py index aa4afb6..f46184d 100644 --- a/structured/structured.py +++ b/structured/structured.py @@ -285,7 +285,7 @@ def _create_proxy(cls) -> tuple[_Proxy, Serializer]: """Create a proxy object for this class, which can be used to create new instances of this class. """ - proxy = _Proxy(cls.attrs) + proxy = _Proxy(cls) return proxy, cls.serializer.preunpack(proxy) @classmethod @@ -464,8 +464,9 @@ class _Proxy: # NOTE: Only using __dunder__ methods, so any attributes on the class this # is a proxy for won't be shadowed. - def __init__(self, attrs: tuple[str, ...]) -> None: - self.__attrs = attrs + def __init__(self, cls: type[Structured]) -> None: + self.__attrs = cls.attrs + self.cls = cls def __call__(self, values: Iterable[Any]) -> None: for attr, value in zips(self.__attrs, values, strict=True): diff --git a/tests/__init__.py b/tests/__init__.py index d4e842d..f8ac28c 100644 --- a/tests/__init__.py +++ b/tests/__init__.py @@ -7,7 +7,7 @@ def standard_tests(target_obj: Structured, target_data: bytes): target_size = len(target_data) assert target_obj.pack() == target_data assert type(target_obj).create_unpack(target_data) == target_obj - assert target_obj.serializer.size == target_size + assert target_obj.serializer.size == target_size, f'{target_obj.serializer.size} != {target_size}' buffer = bytearray(len(target_data)) target_obj.pack_into(buffer) diff --git a/tests/test_self.py b/tests/test_self.py new file mode 100644 index 0000000..0943a36 --- /dev/null +++ b/tests/test_self.py @@ -0,0 +1,91 @@ +import struct +from operator import attrgetter + +import pytest + +from structured import * +from structured.type_checking import Self, Generic, TypeVar, Annotated + +from . import standard_tests + + +class TestSelf: + def test_detection(self) -> None: + class Base(Structured): + a: Self + + assert isinstance(Base.serializer, SelfSerializer) + + with pytest.raises(RecursionError): + Base.create_unpack(b'') + + class Derived(Base): + b: int8 + + assert isinstance(Derived.serializer, CompoundSerializer) + assert isinstance(Derived.serializer.serializers[0], SelfSerializer) + + + T = TypeVar('T') + class BaseGeneric(Generic[T], Structured): + a: T + b: Self + + class DerivedGeneric(BaseGeneric[int8]): + pass + + assert isinstance(DerivedGeneric.serializer, CompoundSerializer) + assert isinstance(DerivedGeneric.serializer.serializers[1], SelfSerializer) + assert isinstance(DerivedGeneric.serializer.serializers[0], struct.Struct) + assert DerivedGeneric.serializer.serializers[0].format == 'b' + + + def test_arrays(self) -> None: + # Test nesting to at least 2 levels + class Base(Structured): + a: array[Header[uint32], Self] + b: uint8 + + level2_items = [ + Base([], 42), + ] + level1_items = [ + Base([], 1), + Base([], 2), + Base(level2_items, 3), + ] + level0_item = Base(level1_items, 0) + + # Level 2 data + item_data = struct.pack('IB', 0, 42) + # Level 1 data + item1_data = struct.pack('IB', 0, 1) + item2_data = struct.pack('IB', 0, 2) + item3_data = struct.pack('I', 1) + item_data + struct.pack('B', 3) + # Level 0 data + container_data = struct.pack('I', 3) + item1_data + item2_data + item3_data + struct.pack('B', 0) + + standard_tests(level0_item, container_data) + + unpacked_obj = Base.create_unpack(container_data) + assert isinstance(unpacked_obj.a[0], Base) + + def test_unions(self) -> None: + decider = LookbackDecider( + attrgetter('type_flag'), + { + 0: Self, + 1: uint64, + } + ) + class Base(Structured): + type_flag: uint8 + data: Annotated[Self | uint64, decider] + + nested_obj = Base(1, 42) + # Note: not the same as pack('BQ', ...), because of padding inserted + nested_data = struct.pack('B', 1) + struct.pack('Q', 42) + container_obj = Base(0, nested_obj) + container_data = struct.pack('B', 0) + nested_data + assert container_obj.pack() == container_data + standard_tests(container_obj, container_data) diff --git a/tests/test_unions.py b/tests/test_unions.py index 3329dc8..7c89938 100644 --- a/tests/test_unions.py +++ b/tests/test_unions.py @@ -33,7 +33,8 @@ class Proxy: # No default specified, and decider returned an invalid value serializer = LookbackDecider(attrgetter('a'), {1: int32}, None) with pytest.raises(ValueError): - serializer.prepack(a) + serializer = serializer.prepack(a) + serializer.pack(1) def test_lookback() -> None: