Skip to content

Commit

Permalink
refactor: Improve typing of classify_field() (#132)
Browse files Browse the repository at this point in the history
This addresses a part of the code base that had a lot of type ignores.
This was due to how the return type of the `classify_field()` function
did not maintain the relationship of its two values, and not having
precise enough information for the case when the inner type is a subtype
of `Entity`.

By introducing specialized types for each four cases in conjunction with
a union return type, mypy is able to much better keep track of the
relationship that a field classified as an "entity field" must carry a
type that is a subtype of `Entity`.

As can be seen in the diff, test_introspect.py is the only changed test
module, this only has fallout on the internal introspection API. The
improvement and intent of this commit is to reduce the amount of type
ignores in `get_field_reader()` and `get_field_writer()`.

Co-authored-by: Anton Agestam <anton.agestam@aiven.io>
  • Loading branch information
antonagestam and aiven-anton authored Oct 2, 2024
1 parent 5262430 commit c30f54b
Show file tree
Hide file tree
Showing 4 changed files with 112 additions and 80 deletions.
56 changes: 43 additions & 13 deletions src/kio/serial/_introspect.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
import enum

from dataclasses import Field
from dataclasses import dataclass
from dataclasses import is_dataclass
from types import EllipsisType
from types import NoneType
from types import UnionType
from typing import ClassVar
from typing import TypeAlias
from typing import TypeVar
from typing import Union
from typing import final
from typing import get_args
from typing import get_origin

from kio.static.protocol import Entity

from .errors import SchemaError


Expand Down Expand Up @@ -43,21 +47,47 @@ def is_optional(field: Field) -> bool:
return NoneType in get_args(inner_type)


class FieldKind(enum.Enum):
primitive = enum.auto()
primitive_tuple = enum.auto()
entity = enum.auto()
entity_tuple = enum.auto()
@final
@dataclass(frozen=True, slots=True)
class PrimitiveField:
is_array: ClassVar = False
type_: type


@final
@dataclass(frozen=True, slots=True)
class PrimitiveTupleField:
is_array: ClassVar = True
type_: type


@final
@dataclass(frozen=True, slots=True)
class EntityField:
is_array: ClassVar = False
type_: type[Entity]


@final
@dataclass(frozen=True, slots=True)
class EntityTupleField:
is_array: ClassVar = True
type_: type[Entity]


FieldClass: TypeAlias = (
PrimitiveField | PrimitiveTupleField | EntityField | EntityTupleField
)


T = TypeVar("T")


def classify_field(field: Field[T]) -> tuple[FieldKind, type[T]]:
def classify_field(field: Field[T]) -> FieldClass:
return _classify_field(field.type, field.name)


def _classify_field(field_type: type[T], field_name: str) -> tuple[FieldKind, type[T]]:
def _classify_field(field_type: type[T], field_name: str) -> FieldClass:
type_origin = get_origin(field_type)

if type_origin is UnionType:
Expand All @@ -79,18 +109,18 @@ def _classify_field(field_type: type[T], field_name: str) -> tuple[FieldKind, ty

if type_origin is not tuple:
return (
(FieldKind.entity, field_type) # type: ignore[return-value]
EntityField(field_type) # type: ignore[arg-type]
if is_dataclass(field_type)
else (FieldKind.primitive, field_type)
else PrimitiveField(field_type)
)

type_args = get_args(field_type)

match type_args:
case (inner_type, EllipsisType()) if is_dataclass(inner_type):
return FieldKind.entity_tuple, inner_type
return EntityTupleField(inner_type)
case (inner_type, EllipsisType()):
return FieldKind.primitive_tuple, inner_type
return PrimitiveTupleField(inner_type)

raise SchemaError(f"Field {field_name} has invalid tuple type args: {type_args}")

Expand Down
51 changes: 28 additions & 23 deletions src/kio/serial/_parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,10 @@
from kio.static.protocol import Entity

from . import readers
from ._introspect import FieldKind
from ._introspect import EntityField
from ._introspect import EntityTupleField
from ._introspect import PrimitiveField
from ._introspect import PrimitiveTupleField
from ._introspect import classify_field
from ._introspect import get_field_tag
from ._introspect import get_schema_field_type
Expand Down Expand Up @@ -94,40 +97,42 @@ def get_field_reader(
if is_request_header and field.name == "client_id":
return readers.read_nullable_legacy_string # type: ignore[return-value]

field_kind, field_type = classify_field(field)
flexible = entity_type.__flexible__
array_reader = (
readers.compact_array_reader if flexible else readers.legacy_array_reader
)
field_class = classify_field(field)

match field_kind:
case FieldKind.primitive:
return get_reader(
match field_class:
case PrimitiveField():
inner_type_reader = get_reader(
kafka_type=get_schema_field_type(field),
flexible=flexible,
optional=is_optional(field) and not is_tagged_field,
)
case FieldKind.primitive_tuple:
return array_reader( # type: ignore[return-value]
get_reader(
kafka_type=get_schema_field_type(field),
flexible=flexible,
optional=is_optional(field),
)
case PrimitiveTupleField():
inner_type_reader = get_reader(
kafka_type=get_schema_field_type(field),
flexible=flexible,
optional=is_optional(field),
)
case FieldKind.entity:
return ( # type: ignore[no-any-return]
entity_reader(field_type, nullable=True) # type: ignore[call-overload]
case EntityField(field_type):
inner_type_reader = (
entity_reader(field_type, nullable=True)
if is_optional(field)
else entity_reader(field_type, nullable=False) # type: ignore[call-overload]
)
case FieldKind.entity_tuple:
return array_reader( # type: ignore[return-value]
entity_reader(field_type) # type: ignore[type-var]
else entity_reader(field_type, nullable=False)
)
case EntityTupleField(field_type):
inner_type_reader = entity_reader(field_type)
case no_match:
assert_never(no_match)

if field_class.is_array:
array_reader = (
readers.compact_array_reader if flexible else readers.legacy_array_reader
)
# mypy fails to bind T to Sequence[object] here.
return array_reader(inner_type_reader) # type: ignore[return-value]

return inner_type_reader


E = TypeVar("E", bound=Entity)

Expand Down
44 changes: 21 additions & 23 deletions src/kio/serial/_serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,10 @@
from kio.static.protocol import Entity

from . import writers
from ._introspect import FieldKind
from ._introspect import EntityField
from ._introspect import EntityTupleField
from ._introspect import PrimitiveField
from ._introspect import PrimitiveTupleField
from ._introspect import classify_field
from ._introspect import get_field_tag
from ._introspect import get_schema_field_type
Expand Down Expand Up @@ -104,45 +107,40 @@ def get_field_writer(
if is_request_header and field.name == "client_id":
return writers.write_nullable_legacy_string # type: ignore[return-value]

field_kind, field_type = classify_field(field)
array_writer = compact_array_writer if flexible else legacy_array_writer

# Optionality needs to be special cased for tagged fields, because they are optional
# by definition. This optionality is implemented in a different way from normal
# fields, it's implemented by the presence or absence by the tag itself. Hence, we
# can have optional fields with in-transit value types that cannot represent None.
# To be able to match an optional tagged field to a writer that cannot accept None,
# we hard-code all tagged fields as not optional here.
optional = False if is_tag else is_optional(field)
field_class = classify_field(field)

match field_kind:
case FieldKind.primitive:
return get_writer(
match field_class:
case PrimitiveField() | PrimitiveTupleField():
inner_type_writer = get_writer(
kafka_type=get_schema_field_type(field),
flexible=flexible,
optional=optional,
)
case FieldKind.primitive_tuple:
return array_writer( # type: ignore[return-value]
get_writer(
kafka_type=get_schema_field_type(field),
flexible=flexible,
optional=optional,
)
)
case FieldKind.entity:
return ( # type: ignore[no-any-return]
entity_writer(field_type, nullable=True) # type: ignore[call-overload]
case EntityField(field_type):
inner_type_writer = (
entity_writer(field_type, nullable=True)
if optional
else entity_writer(field_type, nullable=False) # type: ignore[call-overload]
)
case FieldKind.entity_tuple:
return array_writer( # type: ignore[return-value]
entity_writer(field_type) # type: ignore[type-var]
else entity_writer(field_type, nullable=False)
)
case EntityTupleField(field_type):
inner_type_writer = entity_writer(field_type)
case no_match:
assert_never(no_match)

if field_class.is_array:
array_writer = compact_array_writer if flexible else legacy_array_writer
# mypy fails to bind T to Sequence[object] here.
return array_writer(inner_type_writer) # type: ignore[return-value]

return inner_type_writer


E = TypeVar("E", bound=Entity)

Expand Down
41 changes: 20 additions & 21 deletions tests/serial/test_introspect.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,28 @@
from dataclasses import dataclass
from dataclasses import field
from dataclasses import fields
from typing import ClassVar
from uuid import UUID

import pytest

from kio.serial._introspect import FieldKind
from kio.serial._introspect import EntityField
from kio.serial._introspect import EntityTupleField
from kio.serial._introspect import PrimitiveField
from kio.serial._introspect import PrimitiveTupleField
from kio.serial._introspect import classify_field
from kio.serial._introspect import get_schema_field_type
from kio.serial._introspect import is_optional
from kio.serial.errors import SchemaError
from kio.static.constants import EntityType
from kio.static.primitive import i16


@dataclass
class Nested: ...
class Nested:
__type__: ClassVar = EntityType.data
__version__: ClassVar = i16(0)
__flexible__: ClassVar = True


@dataclass
Expand Down Expand Up @@ -109,28 +118,21 @@ def test_raises_schema_error_for_non_none_union(self) -> None:
classify_field(model_fields["verbose_union_without_none"])

def test_can_classify_primitive_field(self) -> None:
assert classify_field(model_fields["primitive"]) == (FieldKind.primitive, int)
assert classify_field(model_fields["primitive"]) == PrimitiveField(int)

def test_can_classify_primitive_tuple_field(self) -> None:
assert classify_field(model_fields["primitive_tuple"]) == (
FieldKind.primitive_tuple,
int,
)
expected = PrimitiveTupleField(int)
assert classify_field(model_fields["primitive_tuple"]) == expected

def test_can_classify_entity_tuple_field(self) -> None:
assert classify_field(model_fields["entity_tuple"]) == (
FieldKind.entity_tuple,
Nested,
)
assert classify_field(model_fields["entity_tuple"]) == EntityTupleField(Nested)

def test_can_classify_nullable_nested_entity_tuple(self) -> None:
assert classify_field(model_fields["nullable_entity_tuple"]) == (
FieldKind.entity_tuple,
Nested,
)
expected = EntityTupleField(Nested)
assert classify_field(model_fields["nullable_entity_tuple"]) == expected

def test_can_classify_simple_nested_entity(self) -> None:
assert classify_field(model_fields["entity"]) == (FieldKind.entity, Nested)
assert classify_field(model_fields["entity"]) == EntityField(Nested)

# See KIP-893.
@pytest.mark.parametrize(
Expand All @@ -141,10 +143,7 @@ def test_can_classify_simple_nested_entity(self) -> None:
),
)
def test_can_classify_nullable_nested_entity(self, field: Field) -> None:
assert classify_field(field) == (FieldKind.entity, Nested)
assert classify_field(field) == EntityField(Nested)

def test_can_classify_uuid_or_none(self) -> None:
assert classify_field(model_fields["uuid_or_none"]) == (
FieldKind.primitive,
UUID,
)
assert classify_field(model_fields["uuid_or_none"]) == PrimitiveField(UUID)

0 comments on commit c30f54b

Please sign in to comment.