diff --git a/msgspec/__init__.pyi b/msgspec/__init__.pyi index 86a1051a..1f5873b9 100644 --- a/msgspec/__init__.pyi +++ b/msgspec/__init__.pyi @@ -164,6 +164,7 @@ def convert( dec_hook: Optional[Callable[[type, Any], Any]] = None, builtin_types: Union[Iterable[type], None] = None, str_keys: bool = False, + allow_tagged_struct_subtypes: bool = False, ) -> T: ... @overload def convert( @@ -175,6 +176,7 @@ def convert( dec_hook: Optional[Callable[[type, Any], Any]] = None, builtin_types: Union[Iterable[type], None] = None, str_keys: bool = False, + allow_tagged_struct_subtypes: bool = False, ) -> Any: ... # TODO: deprecated diff --git a/msgspec/_core.c b/msgspec/_core.c index 1bf21643..53a51a21 100644 --- a/msgspec/_core.c +++ b/msgspec/_core.c @@ -2920,11 +2920,11 @@ static PyTypeObject NamedTupleInfo_Type; static PyTypeObject StructInfo_Type; static PyTypeObject StructMetaType; static PyTypeObject Ext_Type; -static TypeNode* TypeNode_Convert(PyObject *type); -static PyObject* StructInfo_Convert(PyObject*); -static PyObject* TypedDictInfo_Convert(PyObject*); -static PyObject* DataclassInfo_Convert(PyObject*); -static PyObject* NamedTupleInfo_Convert(PyObject*); +static TypeNode *TypeNode_Convert(PyObject *type, bool allow_tagged_struct_subtypes); +static PyObject* StructInfo_Convert(PyObject*, bool allow_tagged_struct_subtypes); +static PyObject* TypedDictInfo_Convert(PyObject*, bool allow_tagged_struct_subtypes); +static PyObject* DataclassInfo_Convert(PyObject*, bool allow_tagged_struct_subtypes); +static PyObject* NamedTupleInfo_Convert(PyObject*, bool allow_tagged_struct_subtypes); #define StructMeta_GET_FIELDS(s) (((StructMetaObject *)(s))->struct_fields) #define StructMeta_GET_NFIELDS(s) (PyTuple_GET_SIZE((((StructMetaObject *)(s))->struct_fields))) @@ -3374,6 +3374,7 @@ typedef struct { PyObject *literal_str_values; PyObject *literal_str_lookup; bool literal_none; + bool allow_tagged_struct_subtypes; /* Constraints */ int64_t c_int_min; int64_t c_int_max; @@ -3816,17 +3817,17 @@ typenode_from_collect_state(TypeNodeCollectState *state) { out->details[e_ind++].pointer = state->literal_str_lookup; } if (state->typeddict_obj != NULL) { - PyObject *info = TypedDictInfo_Convert(state->typeddict_obj); + PyObject *info = TypedDictInfo_Convert(state->typeddict_obj, state->allow_tagged_struct_subtypes); if (info == NULL) goto error; out->details[e_ind++].pointer = info; } if (state->dataclass_obj != NULL) { - PyObject *info = DataclassInfo_Convert(state->dataclass_obj); + PyObject *info = DataclassInfo_Convert(state->dataclass_obj, state->allow_tagged_struct_subtypes); if (info == NULL) goto error; out->details[e_ind++].pointer = info; } if (state->namedtuple_obj != NULL) { - PyObject *info = NamedTupleInfo_Convert(state->namedtuple_obj); + PyObject *info = NamedTupleInfo_Convert(state->namedtuple_obj, state->allow_tagged_struct_subtypes); if (info == NULL) goto error; out->details[e_ind++].pointer = info; } @@ -3835,10 +3836,10 @@ typenode_from_collect_state(TypeNodeCollectState *state) { out->details[e_ind++].pointer = state->c_str_regex; } if (state->dict_key_obj != NULL) { - TypeNode *temp = TypeNode_Convert(state->dict_key_obj); + TypeNode *temp = TypeNode_Convert(state->dict_key_obj, state->allow_tagged_struct_subtypes); if (temp == NULL) goto error; out->details[e_ind++].pointer = temp; - temp = TypeNode_Convert(state->dict_val_obj); + temp = TypeNode_Convert(state->dict_val_obj, state->allow_tagged_struct_subtypes); if (temp == NULL) goto error; out->details[e_ind++].pointer = temp; } @@ -3848,14 +3849,14 @@ typenode_from_collect_state(TypeNodeCollectState *state) { for (Py_ssize_t i = 0; i < fixtuple_size; i++) { TypeNode *temp = TypeNode_Convert( - PyTuple_GET_ITEM(state->array_el_obj, i) + PyTuple_GET_ITEM(state->array_el_obj, i), state->allow_tagged_struct_subtypes ); if (temp == NULL) goto error; out->details[e_ind++].pointer = temp; } } else { - TypeNode *temp = TypeNode_Convert(state->array_el_obj); + TypeNode *temp = TypeNode_Convert(state->array_el_obj, state->allow_tagged_struct_subtypes); if (temp == NULL) goto error; out->details[e_ind++].pointer = temp; } @@ -4387,22 +4388,108 @@ typenode_collect_convert_literals(TypeNodeCollectState *state) { return 0; } +static PyObject *get_all_subclasses(PyObject *cls, PyObject *visited) +{ + bool top_level = false; + + if (visited == NULL) + { + visited = PyFrozenSet_New(NULL); + if (visited == NULL) + return NULL; + + top_level = true; + } + + PyObject *subclasses = PyObject_CallMethod(cls, "__subclasses__", NULL); + if (subclasses == NULL) goto error; + + Py_ssize_t len = PyList_Size(subclasses); + for (Py_ssize_t i = 0; i < len; i++) + { + PyObject *subclass = PyList_GetItem(subclasses, i); + if (subclass == NULL) goto error; + + int is_in = PySet_Contains(visited, subclass); + if (is_in == -1) goto error; + + if (!is_in) + { + int result = PySet_Add(visited, subclass); + if (result == -1) goto error; + + PyObject *recursive_result = get_all_subclasses(subclass, visited); + if (recursive_result == NULL) goto error; + } + } + + Py_DECREF(subclasses); + return visited; + +error: + Py_XDECREF(subclasses); + + if (top_level) + { + Py_DECREF(visited); + } + return NULL; +} + +int set_update(PyObject *set1, PyObject *set2) +{ + PyObject *set_item = NULL; + Py_ssize_t set_pos = 0; + Py_hash_t set_hash; + + while (_PySet_NextEntry(set2, &set_pos, &set_item, &set_hash)) + { + int result = PySet_Add(set1, set_item); + if (result == -1) return -1; + } + + return 0; +} + static int typenode_collect_convert_structs(TypeNodeCollectState *state) { if (state->struct_obj == NULL && state->structs_set == NULL) { return 0; } - else if (state->struct_obj != NULL) { - /* Single struct */ - state->struct_info = StructInfo_Convert(state->struct_obj); - if (state->struct_info == NULL) return -1; - if (((StructInfo *)state->struct_info)->class->array_like == OPT_TRUE) { - state->types |= MS_TYPE_STRUCT_ARRAY; - } - else { - state->types |= MS_TYPE_STRUCT; + else if (state->struct_obj != NULL) + { + bool allow_tagged_struct_subtypes = ((StructMetaObject *)state->struct_obj)->struct_tag_field != NULL && state->allow_tagged_struct_subtypes; + + if (!allow_tagged_struct_subtypes) + { + /* Single struct and no tagged subclasses support */ + state->struct_info = StructInfo_Convert(state->struct_obj, state->allow_tagged_struct_subtypes); + if (state->struct_info == NULL) + return -1; + if (((StructInfo *)state->struct_info)->class->array_like == OPT_TRUE) + { + state->types |= MS_TYPE_STRUCT_ARRAY; + } + else + { + state->types |= MS_TYPE_STRUCT; + } + return 0; } - return 0; + } + + /* Either an actual upstream set of struct types (union) or a tagged struct + * and subclasses decoding requested + */ + PyObject *structs_set = state->structs_set; + + if (structs_set == NULL) { + structs_set = PyFrozenSet_New(NULL); + } + + if (state->struct_obj != NULL) { + int result = PySet_Add(structs_set, state->struct_obj); + if (result == -1) return -1; } /* Multiple structs. @@ -4411,7 +4498,7 @@ typenode_collect_convert_structs(TypeNodeCollectState *state) { * new one below. */ PyObject *lookup = PyDict_GetItem( - state->mod->struct_lookup_cache, state->structs_set + state->mod->struct_lookup_cache, structs_set ); if (lookup != NULL) { /* Lookup was in the cache, update the state and return */ @@ -4448,9 +4535,32 @@ typenode_collect_convert_structs(TypeNodeCollectState *state) { tag_mapping = PyDict_New(); if (tag_mapping == NULL) goto cleanup; + + PyObject *lookup_structs_set = PySet_New(structs_set); + + if (state->allow_tagged_struct_subtypes) { + /* Extend the lookup set with all subclasses of all classes in the original set */ + while (_PySet_NextEntry(structs_set, &set_pos, &set_item, &set_hash)) + { + /* Single struct, assume this struct and all of it's subclasses */ + PyObject *subclasses_set = get_all_subclasses(set_item, NULL); + + if (subclasses_set == NULL) goto cleanup; + + int result = set_update(lookup_structs_set, subclasses_set); + Py_DECREF(subclasses_set); + + if (result == -1) goto cleanup; + } + } - while (_PySet_NextEntry(state->structs_set, &set_pos, &set_item, &set_hash)) { - struct_info = StructInfo_Convert(set_item); + set_pos = 0; + set_item = NULL; + set_hash = 0; + + while (_PySet_NextEntry(lookup_structs_set, &set_pos, &set_item, &set_hash)) + { + struct_info = StructInfo_Convert(set_item, state->allow_tagged_struct_subtypes); if (struct_info == NULL) goto cleanup; StructMetaObject *struct_type = ((StructInfo *)struct_info)->class; @@ -4542,7 +4652,7 @@ typenode_collect_convert_structs(TypeNodeCollectState *state) { } /* Add the new lookup to the cache */ - if (PyDict_SetItem(state->mod->struct_lookup_cache, state->structs_set, lookup) < 0) { + if (PyDict_SetItem(state->mod->struct_lookup_cache, structs_set, lookup) < 0) { goto cleanup; } @@ -4992,11 +5102,13 @@ typenode_collect_type(TypeNodeCollectState *state, PyObject *obj) { } static TypeNode * -TypeNode_Convert(PyObject *obj) { +TypeNode_Convert(PyObject *obj, bool allow_tagged_struct_subtypes) +{ TypeNode *out = NULL; TypeNodeCollectState state = {0}; state.mod = msgspec_get_global_state(); state.context = obj; + state.allow_tagged_struct_subtypes = allow_tagged_struct_subtypes; if (Py_EnterRecursiveCall(" while analyzing a type")) return NULL; @@ -6670,7 +6782,8 @@ static PyTypeObject StructInfo_Type = { }; static PyObject * -StructInfo_Convert(PyObject *obj) { +StructInfo_Convert(PyObject *obj, bool allow_tagged_struct_subtypes) +{ MsgspecState *mod = msgspec_get_global_state(); StructMetaObject *class; PyObject *annotations = NULL; @@ -6750,7 +6863,7 @@ StructInfo_Convert(PyObject *obj) { PyObject *field = PyTuple_GET_ITEM(class->struct_fields, i); PyObject *field_type = PyDict_GetItem(annotations, field); if (field_type == NULL) goto error; - TypeNode *type = TypeNode_Convert(field_type); + TypeNode *type = TypeNode_Convert(field_type, allow_tagged_struct_subtypes); if (type == NULL) goto error; info->types[i] = type; } @@ -8200,7 +8313,7 @@ static PyTypeObject LiteralInfo_Type = { }; static PyObject * -TypedDictInfo_Convert(PyObject *obj) { +TypedDictInfo_Convert(PyObject *obj, bool allow_tagged_struct_subtypes) { PyObject *annotations = NULL, *required = NULL; TypedDictInfo *info = NULL; MsgspecState *mod = msgspec_get_global_state(); @@ -8243,7 +8356,7 @@ TypedDictInfo_Convert(PyObject *obj) { Py_ssize_t pos = 0, i = 0; PyObject *key, *val; while (PyDict_Next(annotations, &pos, &key, &val)) { - TypeNode *type = TypeNode_Convert(val); + TypeNode *type = TypeNode_Convert(val, allow_tagged_struct_subtypes); if (type == NULL) goto cleanup; Py_INCREF(key); info->fields[i].key = key; @@ -8364,7 +8477,8 @@ static PyTypeObject TypedDictInfo_Type = { }; static PyObject * -DataclassInfo_Convert(PyObject *obj) { +DataclassInfo_Convert(PyObject *obj, bool allow_tagged_struct_subtypes) +{ PyObject *cls = NULL, *fields = NULL, *field_defaults = NULL; PyObject *pre_init = NULL, *post_init = NULL; DataclassInfo *info = NULL; @@ -8428,7 +8542,7 @@ DataclassInfo_Convert(PyObject *obj) { /* Traverse fields and initialize DataclassInfo */ for (Py_ssize_t i = 0; i < nfields; i++) { PyObject *field = PyTuple_GET_ITEM(fields, i); - TypeNode *type = TypeNode_Convert(PyTuple_GET_ITEM(field, 1)); + TypeNode *type = TypeNode_Convert(PyTuple_GET_ITEM(field, 1), allow_tagged_struct_subtypes); if (type == NULL) goto cleanup; /* If field has a default factory, set extra flag bit */ if (PyObject_IsTrue(PyTuple_GET_ITEM(field, 2))) { @@ -8584,7 +8698,7 @@ static PyTypeObject DataclassInfo_Type = { }; static PyObject * -NamedTupleInfo_Convert(PyObject *obj) { +NamedTupleInfo_Convert(PyObject *obj, bool allow_tagged_struct_subtypes) { MsgspecState *mod = msgspec_get_global_state(); NamedTupleInfo *info = NULL; PyObject *annotations = NULL, *cls = NULL, *fields = NULL; @@ -8642,7 +8756,7 @@ NamedTupleInfo_Convert(PyObject *obj) { type_obj = mod->typing_any; } /* Convert the type to a TypeNode */ - TypeNode *type = TypeNode_Convert(type_obj); + TypeNode *type = TypeNode_Convert(type_obj, allow_tagged_struct_subtypes); if (type == NULL) goto cleanup; info->types[i] = type; /* Get the field default (if any), and append it to the list */ @@ -14497,7 +14611,7 @@ Decoder_init(Decoder *self, PyObject *args, PyObject *kwds) self->ext_hook = ext_hook; /* Handle type */ - self->type = TypeNode_Convert(type); + self->type = TypeNode_Convert(type, false); if (self->type == NULL) { return -1; } @@ -16258,7 +16372,7 @@ msgspec_msgpack_decode(PyObject *self, PyObject *const *args, Py_ssize_t nargs, state.type = &typenode_any; } else if (Py_TYPE(type) == &StructMetaType) { - PyObject *info = StructInfo_Convert(type); + PyObject *info = StructInfo_Convert(type, false); if (info == NULL) return NULL; bool array_like = ((StructMetaObject *)type)->array_like == OPT_TRUE; typenode_struct.types = array_like ? MS_TYPE_STRUCT_ARRAY : MS_TYPE_STRUCT; @@ -16266,7 +16380,7 @@ msgspec_msgpack_decode(PyObject *self, PyObject *const *args, Py_ssize_t nargs, state.type = (TypeNode *)(&typenode_struct); } else { - state.type = TypeNode_Convert(type); + state.type = TypeNode_Convert(type, false); if (state.type == NULL) return NULL; } @@ -16405,7 +16519,7 @@ JSONDecoder_init(JSONDecoder *self, PyObject *args, PyObject *kwds) self->strict = strict; /* Handle type */ - self->type = TypeNode_Convert(type); + self->type = TypeNode_Convert(type, false); if (self->type == NULL) return -1; Py_INCREF(type); self->orig_type = type; @@ -19291,7 +19405,7 @@ msgspec_json_decode(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyO state.type = &typenode_any; } else if (Py_TYPE(type) == &StructMetaType) { - PyObject *info = StructInfo_Convert(type); + PyObject *info = StructInfo_Convert(type, false); if (info == NULL) return NULL; bool array_like = ((StructMetaObject *)type)->array_like == OPT_TRUE; typenode_struct.types = array_like ? MS_TYPE_STRUCT_ARRAY : MS_TYPE_STRUCT; @@ -19299,7 +19413,7 @@ msgspec_json_decode(PyObject *self, PyObject *const *args, Py_ssize_t nargs, PyO state.type = (TypeNode *)(&typenode_struct); } else { - state.type = TypeNode_Convert(type); + state.type = TypeNode_Convert(type, false); if (state.type == NULL) return NULL; } @@ -21619,7 +21733,8 @@ convert( } PyDoc_STRVAR(msgspec_convert__doc__, -"convert(obj, type, *, strict=True, from_attributes=False, dec_hook=None, str_keys=False, builtin_types=None)\n" +"convert(obj, type, *, strict=True, from_attributes=False, dec_hook=None, str_keys=False, builtin_types=None" +", allow_tagged_struct_subtypes=False)\n" "--\n" "\n" "Convert the input object to the specified type, or error accordingly.\n" @@ -21660,6 +21775,11 @@ PyDoc_STRVAR(msgspec_convert__doc__, " Useful for wrapping other serialization protocols. Indicates whether the\n" " wrapped protocol only supports string keys. Setting to True enables a wider\n" " set of coercion rules from string to non-string types for dict keys.\n" +"allow_tagged_struct_subtypes: bool, optional\n" +" If you'd like to allow converting a dictionary corresponding to a subtype of\n" +" a tagged struct where type annotation specifies a supertype, set this to True.\n" +" Useful for nested structure, such as trees consisting of various nodes all\n" +" inheriting from a single base type.\n" " Default is False.\n" "\n" "Returns\n" @@ -21688,18 +21808,18 @@ static PyObject* msgspec_convert(PyObject *self, PyObject *args, PyObject *kwargs) { PyObject *obj = NULL, *pytype = NULL, *builtin_types = NULL, *dec_hook = NULL; - int str_keys = false, strict = true, from_attributes = false; + int str_keys = false, strict = true, from_attributes = false, allow_tagged_struct_subtypes = false; ConvertState state; char *kwlist[] = { "obj", "type", "strict", "from_attributes", "dec_hook", "builtin_types", - "str_keys", NULL + "str_keys", "allow_tagged_struct_subtypes", NULL }; /* Parse arguments */ if (!PyArg_ParseTupleAndKeywords( - args, kwargs, "OO|$ppOOp", kwlist, - &obj, &pytype, &strict, &from_attributes, &dec_hook, &builtin_types, &str_keys + args, kwargs, "OO|$ppOOpp", kwlist, + &obj, &pytype, &strict, &from_attributes, &dec_hook, &builtin_types, &str_keys, &allow_tagged_struct_subtypes )) { return NULL; } @@ -21729,7 +21849,7 @@ msgspec_convert(PyObject *self, PyObject *args, PyObject *kwargs) /* Avoid allocating a new TypeNode for struct types */ if (Py_TYPE(pytype) == &StructMetaType) { - PyObject *info = StructInfo_Convert(pytype); + PyObject *info = StructInfo_Convert(pytype, allow_tagged_struct_subtypes); if (info == NULL) return NULL; bool array_like = ((StructMetaObject *)pytype)->array_like == OPT_TRUE; TypeNodeSimple type; @@ -21740,7 +21860,7 @@ msgspec_convert(PyObject *self, PyObject *args, PyObject *kwargs) return out; } - TypeNode *type = TypeNode_Convert(pytype); + TypeNode *type = TypeNode_Convert(pytype, allow_tagged_struct_subtypes); if (type == NULL) return NULL; PyObject *out = convert(&state, obj, type, NULL); TypeNode_Free(type); diff --git a/tests/test_convert.py b/tests/test_convert.py index 95d496e3..4a745fde 100644 --- a/tests/test_convert.py +++ b/tests/test_convert.py @@ -1867,6 +1867,22 @@ class Test(Struct, tag=tag): assert f"Invalid value {bad!r}" in str(rec.value) assert "`$.type`" in str(rec.value) + def test_tagged_struct_subclass(self): + class Base(Struct, tag=True): + pass + + class A(Base): + pass + + class C(Base): + other: Base + + assert convert( + {"type": "C", "other": {"type": "C", "other": {"type": "A"}}}, + type=C, + allow_tagged_struct_subtypes=True, + ) == C(other=C(other=A())) + @pytest.mark.parametrize("tag_val", [2**64 - 1, 2**64, -(2**63) - 1]) @mapcls_and_from_attributes def test_tagged_struct_int_tag_not_int64_always_invalid(