From 31b3fb3bdf36149c48897958647d708e216bf5ef Mon Sep 17 00:00:00 2001 From: Jim Crist-Harif Date: Tue, 17 Oct 2023 10:04:45 -0500 Subject: [PATCH] `msgspec.structs.fields` works on parametrized generics Previously `msgspec.structs.fields` only worked on struct instances or types. We now also support parametrized generic struct types, so `msgspec.structs.fields(MyStruct[int])` also works. --- msgspec/structs.py | 13 +++++++------ tests/test_struct.py | 32 +++++++++++++++++++++++++++++--- 2 files changed, 36 insertions(+), 9 deletions(-) diff --git a/msgspec/structs.py b/msgspec/structs.py index cf9f207a..57d5fb24 100644 --- a/msgspec/structs.py +++ b/msgspec/structs.py @@ -10,7 +10,7 @@ astuple, replace, ) -from ._utils import get_type_hints as _get_type_hints +from ._utils import get_class_annotations as _get_class_annotations __all__ = ( "FieldInfo", @@ -71,13 +71,14 @@ def fields(type_or_instance: Struct | type[Struct]) -> tuple[FieldInfo]: tuple[FieldInfo] """ if isinstance(type_or_instance, Struct): - cls = type(type_or_instance) - elif isinstance(type_or_instance, type) and issubclass(type_or_instance, Struct): - cls = type_or_instance + annotated_cls = cls = type(type_or_instance) else: - raise TypeError("Must be called with a struct type or instance") + annotated_cls = type_or_instance + cls = getattr(type_or_instance, "__origin__", type_or_instance) + if not (isinstance(cls, type) and issubclass(cls, Struct)): + raise TypeError("Must be called with a struct type or instance") - hints = _get_type_hints(cls) + hints = _get_class_annotations(annotated_cls) npos = len(cls.__struct_fields__) - len(cls.__struct_defaults__) fields = [] for name, encode_name, default_obj in zip( diff --git a/tests/test_struct.py b/tests/test_struct.py index 79bb308a..80255f52 100644 --- a/tests/test_struct.py +++ b/tests/test_struct.py @@ -8,7 +8,7 @@ import weakref from contextlib import contextmanager from inspect import Parameter, Signature -from typing import Any, List, Optional +from typing import Any, List, Optional, Generic, TypeVar import pytest from utils import temp_module @@ -2230,8 +2230,14 @@ def test_errors(self, func): class TestInspectFields: def test_fields_bad_arg(self): - with pytest.raises(TypeError, match="struct type or instance"): - msgspec.structs.fields(1) + T = TypeVar("T") + + class Bad(Generic[T]): + x: T + + for val in [1, int, Bad, Bad[int]]: + with pytest.raises(TypeError, match="struct type or instance"): + msgspec.structs.fields(val) def test_fields_no_fields(self): assert msgspec.structs.fields(msgspec.Struct) == () @@ -2289,6 +2295,26 @@ class Example(msgspec.Struct, rename="camel"): assert msgspec.structs.fields(Example) == sol + def test_fields_generic(self): + T = TypeVar("T") + + class Example(msgspec.Struct, Generic[T]): + x: T + y: int + + sol = ( + msgspec.structs.FieldInfo("x", "x", T), + msgspec.structs.FieldInfo("y", "y", int), + ) + assert msgspec.structs.fields(Example) == sol + assert msgspec.structs.fields(Example(1, 2)) == sol + + sol = ( + msgspec.structs.FieldInfo("x", "x", str), + msgspec.structs.FieldInfo("y", "y", int), + ) + assert msgspec.structs.fields(Example[str]) + class TestClassVar: def case1(self):