diff --git a/tests/serial/test_serialize.py b/tests/serial/test_serialize.py index b362559c..11f9162b 100644 --- a/tests/serial/test_serialize.py +++ b/tests/serial/test_serialize.py @@ -4,6 +4,7 @@ from dataclasses import dataclass from dataclasses import field +from dataclasses import fields from typing import ClassVar import pytest @@ -16,6 +17,7 @@ from kio.schema.types import TopicName from kio.serial import entity_writer from kio.serial import writers +from kio.serial._serialize import get_field_writer from kio.serial._serialize import get_writer from kio.serial._shared import NullableEntityMarker from kio.serial.readers import read_boolean @@ -29,6 +31,7 @@ from kio.serial.readers import read_uuid from kio.static.constants import EntityType from kio.static.constants import ErrorCode +from kio.static.primitive import i8 from kio.static.primitive import i16 from kio.static.primitive import i32 from kio.static.primitive import i32Timedelta @@ -136,6 +139,127 @@ def test_raises_not_implemented_error_for_invalid_combination( get_writer(kafka_type, flexible, optional) +class TestGetFieldWriter: + def test_special_cases_request_header_client_id(self) -> None: + @dataclass + class E: + client_id: int + + [field] = fields(E) + writer = get_field_writer( + field=field, flexible=True, is_request_header=True, is_tag=False + ) + assert writer == writers.write_nullable_legacy_string + + def test_returns_plain_writer_for_primitive_field(self) -> None: + @dataclass + class E: + a: i8 = field(metadata={"kafka_type": "int8"}) + + [extracted_field] = fields(E) + writer = get_field_writer( + field=extracted_field, flexible=True, is_request_header=True, is_tag=False + ) + assert writer == writers.write_int8 + + def test_returns_array_writer_for_primitive_tuple_field( + self, + buffer: io.BytesIO, + ) -> None: + @dataclass + class E: + a: tuple[i8, ...] = field(metadata={"kafka_type": "int8"}) + + [extracted_field] = fields(E) + writer = get_field_writer( + field=extracted_field, flexible=True, is_request_header=True, is_tag=False + ) + + # We test that the returned object is an int8 array writer, by making + # sure it behaves like one. + writer(buffer, [1, 2, 3]) + buffer.seek(0) + assert read_compact_array_length(buffer) == 3 + assert read_int8(buffer) == 1 + assert read_int8(buffer) == 2 + assert read_int8(buffer) == 3 + + def test_returns_entity_writer_for_entity_field( + self, + buffer: io.BytesIO, + ) -> None: + @dataclass + class A: + __flexible__: ClassVar = True + f: i8 = field(metadata={"kafka_type": "int8"}) + + @dataclass + class B: + a: A + + [extracted_field] = fields(B) + writer = get_field_writer( + field=extracted_field, flexible=True, is_request_header=True, is_tag=False + ) + + writer(buffer, A(f=i8(23))) + buffer.seek(0) + assert read_int8(buffer) == 23 + assert read_unsigned_varint(buffer) == 0 # tags + + def test_returns_entity_writer_for_nullable_entity_field( + self, + buffer: io.BytesIO, + ) -> None: + @dataclass + class A: + __flexible__: ClassVar = True + f: i8 = field(metadata={"kafka_type": "int8"}) + + @dataclass + class B: + a: A | None + + [extracted_field] = fields(B) + writer = get_field_writer( + field=extracted_field, flexible=True, is_request_header=True, is_tag=False + ) + + writer(buffer, A(f=i8(23))) + writer(buffer, None) + buffer.seek(0) + assert NullableEntityMarker(read_int8(buffer)) is NullableEntityMarker.not_null + assert read_int8(buffer) == 23 + assert read_unsigned_varint(buffer) == 0 # tags + assert NullableEntityMarker(read_int8(buffer)) is NullableEntityMarker.null + + def test_returns_entity_tuple_writer_for_entity_tuple_field( + self, + buffer: io.BytesIO, + ) -> None: + @dataclass + class A: + __flexible__: ClassVar = True + f: i8 = field(metadata={"kafka_type": "int8"}) + + @dataclass + class B: + a: tuple[A, ...] + + [extracted_field] = fields(B) + writer = get_field_writer( + field=extracted_field, flexible=True, is_request_header=True, is_tag=False + ) + + writer(buffer, [A(f=i8(23)), A(f=i8(17))]) + buffer.seek(0) + assert read_compact_array_length(buffer) == 2 + assert read_int8(buffer) == 23 + assert read_unsigned_varint(buffer) == 0 # tags + assert read_int8(buffer) == 17 + assert read_unsigned_varint(buffer) == 0 # tags + + @dataclass(frozen=True, slots=True, kw_only=True) class LegacyWithTag: __type__: ClassVar = EntityType.header