Skip to content

Commit

Permalink
dataclass: refactor evaluating string fields
Browse files Browse the repository at this point in the history
  • Loading branch information
alexfikl authored and inducer committed Dec 17, 2024
1 parent c4f00b8 commit 9ac8bc8
Showing 1 changed file with 44 additions and 20 deletions.
64 changes: 44 additions & 20 deletions arraycontext/container/dataclass.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,22 @@
"""

from collections.abc import Mapping, Sequence
from dataclasses import Field, fields, is_dataclass
from typing import Union, get_args, get_origin
from dataclasses import fields, is_dataclass
from typing import NamedTuple, Union, get_args, get_origin

from arraycontext.container import is_array_container_type


# {{{ dataclass containers

class _Field(NamedTuple):
"""Small lookalike for :class:`dataclasses.Field`."""

init: bool
name: str
type: type


def is_array_type(tp: type) -> bool:
from arraycontext import Array
return tp is Array or is_array_container_type(tp)
Expand Down Expand Up @@ -73,7 +81,9 @@ def dataclass_array_container(cls: type) -> type:

assert is_dataclass(cls)

def is_array_field(f: Field, field_type: type) -> bool:
def is_array_field(f: _Field) -> bool:
field_type = f.type

# NOTE: unions of array containers are treated separately to handle
# unions of only array containers, e.g. `Union[np.ndarray, Array]`, as
# they can work seamlessly with arithmetic and traversal.
Expand All @@ -96,10 +106,8 @@ def is_array_field(f: Field, field_type: type) -> bool:
f"Field '{f.name}' union contains non-array container "
"arguments. All arguments must be array containers.")

if isinstance(field_type, str):
raise TypeError(
f"String annotation on field '{f.name}' not supported. "
"(this may be due to 'from __future__ import annotations')")
# NOTE: this should never happen due to using `inspect.get_annotations`
assert not isinstance(field_type, str)

if __debug__:
if not f.init:
Expand Down Expand Up @@ -127,36 +135,52 @@ def is_array_field(f: Field, field_type: type) -> bool:

return is_array_type(field_type)

from pytools import partition

array_fields = _get_annotated_fields(cls)
array_fields, non_array_fields = partition(is_array_field, array_fields)

if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")

return _inject_dataclass_serialization(cls, array_fields, non_array_fields)


def _get_annotated_fields(cls: type) -> Sequence[_Field]:
"""Get a list of fields in the class *cls* with evaluated types.
If any of the fields in *cls* have type annotations that are strings, e.g.
from using ``from __future__ import annotations``, this function evaluates
them using :func:`inspect.get_annotations`. Note that this requires the class
to live in a module that is importable.
:return: a list of fields.
"""

from inspect import get_annotations

array_fields: list[Field] = []
non_array_fields: list[Field] = []
result = []
cls_ann: Mapping[str, type] | None = None
for field in fields(cls):
field_type_or_str = field.type
if isinstance(field_type_or_str, str):
if cls_ann is None:
cls_ann = get_annotations(cls, eval_str=True)

field_type = cls_ann[field.name]
else:
field_type = field_type_or_str

if is_array_field(field, field_type):
array_fields.append(field)
else:
non_array_fields.append(field)

if not array_fields:
raise ValueError(f"'{cls}' must have fields with array container type "
"in order to use the 'dataclass_array_container' decorator")
result.append(_Field(init=field.init, name=field.name, type=field_type))

return _inject_dataclass_serialization(cls, array_fields, non_array_fields)
return result


def _inject_dataclass_serialization(
cls: type,
array_fields: Sequence[Field],
non_array_fields: Sequence[Field]) -> type:
array_fields: Sequence[_Field],
non_array_fields: Sequence[_Field]) -> type:
"""Implements :func:`~arraycontext.serialize_container` and
:func:`~arraycontext.deserialize_container` for the given dataclass *cls*.
Expand Down

0 comments on commit 9ac8bc8

Please sign in to comment.