diff --git a/msgspec/_core.c b/msgspec/_core.c index 00e6abdc..f2510352 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -3706,7 +3706,7 @@ typenode_from_collect_state(TypeNodeCollectState *state) { static bool get_msgspec_cache(MsgspecState *mod, PyObject *obj, PyTypeObject *type, PyObject **out) { - PyObject *cached = PyObject_GetAttr(obj, mod->str___msgspec_cache__); + PyObject *cached = PyObject_GenericGetAttr(obj, mod->str___msgspec_cache__); if (cached != NULL) { if (Py_TYPE(cached) != type) { Py_DECREF(cached); diff --git a/tests/test_common.py b/tests/test_common.py index d422ad44..9886da0d 100644 --- a/tests/test_common.py +++ b/tests/test_common.py @@ -2582,6 +2582,20 @@ class Ex: assert dec.decode(proto.encode(msg)) == msg assert dec2.decode(proto.encode(msg)) == msg + def test_decode_dataclass_subclasses(self, proto): + @dataclass + class Base: + x: int + + @dataclass + class Sub(Base): + y: int + + msg = proto.encode({"x": 1, "y": 2}) + + assert proto.decode(msg, type=Base) == Base(1) + assert proto.decode(msg, type=Sub) == Sub(1, 2) + def test_multiple_dataclasses_errors(self, proto): @dataclass class Ex1: