Skip to content

Commit

Permalink
msgspec.structs.fields works on parametrized generics
Browse files Browse the repository at this point in the history
Previously `msgspec.structs.fields` only worked on struct instances or
types. We now also support parametrized generic struct types, so
`msgspec.structs.fields(MyStruct[int])` also works.
  • Loading branch information
jcrist committed Oct 17, 2023
1 parent 9f5f50b commit 31b3fb3
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 9 deletions.
13 changes: 7 additions & 6 deletions msgspec/structs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
astuple,
replace,
)
from ._utils import get_type_hints as _get_type_hints
from ._utils import get_class_annotations as _get_class_annotations

__all__ = (
"FieldInfo",
Expand Down Expand Up @@ -71,13 +71,14 @@ def fields(type_or_instance: Struct | type[Struct]) -> tuple[FieldInfo]:
tuple[FieldInfo]
"""
if isinstance(type_or_instance, Struct):
cls = type(type_or_instance)
elif isinstance(type_or_instance, type) and issubclass(type_or_instance, Struct):
cls = type_or_instance
annotated_cls = cls = type(type_or_instance)
else:
raise TypeError("Must be called with a struct type or instance")
annotated_cls = type_or_instance
cls = getattr(type_or_instance, "__origin__", type_or_instance)
if not (isinstance(cls, type) and issubclass(cls, Struct)):
raise TypeError("Must be called with a struct type or instance")

hints = _get_type_hints(cls)
hints = _get_class_annotations(annotated_cls)
npos = len(cls.__struct_fields__) - len(cls.__struct_defaults__)
fields = []
for name, encode_name, default_obj in zip(
Expand Down
32 changes: 29 additions & 3 deletions tests/test_struct.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import weakref
from contextlib import contextmanager
from inspect import Parameter, Signature
from typing import Any, List, Optional
from typing import Any, List, Optional, Generic, TypeVar

import pytest
from utils import temp_module
Expand Down Expand Up @@ -2230,8 +2230,14 @@ def test_errors(self, func):

class TestInspectFields:
def test_fields_bad_arg(self):
with pytest.raises(TypeError, match="struct type or instance"):
msgspec.structs.fields(1)
T = TypeVar("T")

class Bad(Generic[T]):
x: T

for val in [1, int, Bad, Bad[int]]:
with pytest.raises(TypeError, match="struct type or instance"):
msgspec.structs.fields(val)

def test_fields_no_fields(self):
assert msgspec.structs.fields(msgspec.Struct) == ()
Expand Down Expand Up @@ -2289,6 +2295,26 @@ class Example(msgspec.Struct, rename="camel"):

assert msgspec.structs.fields(Example) == sol

def test_fields_generic(self):
T = TypeVar("T")

class Example(msgspec.Struct, Generic[T]):
x: T
y: int

sol = (
msgspec.structs.FieldInfo("x", "x", T),
msgspec.structs.FieldInfo("y", "y", int),
)
assert msgspec.structs.fields(Example) == sol
assert msgspec.structs.fields(Example(1, 2)) == sol

sol = (
msgspec.structs.FieldInfo("x", "x", str),
msgspec.structs.FieldInfo("y", "y", int),
)
assert msgspec.structs.fields(Example[str])


class TestClassVar:
def case1(self):
Expand Down

0 comments on commit 31b3fb3

Please sign in to comment.