diff --git a/docs/api/xml-nodes.rst b/docs/api/xml-nodes.rst index 8becdea46..45629000c 100644 --- a/docs/api/xml-nodes.rst +++ b/docs/api/xml-nodes.rst @@ -17,3 +17,4 @@ for models and their fields. UnionNode PrimitiveNode StandardNode + WrapperNode diff --git a/docs/examples.rst b/docs/examples.rst index c03d3b0bc..9b59ae606 100644 --- a/docs/examples.rst +++ b/docs/examples.rst @@ -34,6 +34,7 @@ Advance Topics examples/custom-property-names examples/custom-class-factory + examples/wrapped-list Test Suites diff --git a/docs/examples/wrapped-list.rst b/docs/examples/wrapped-list.rst new file mode 100644 index 000000000..3a40f4bd3 --- /dev/null +++ b/docs/examples/wrapped-list.rst @@ -0,0 +1,46 @@ +============ +Wrapped List +============ + +XML data structures commonly wrap element and primitive collections. +For instance, a library may have several books and and other stuff as well. +In terms of `OpenAPI 3 `_, +these data structures are `wrapped`. Hence, xsdata has the field parameter `wrapper`, +which wraps any element/primitive collection into a custom xml element without the +need of a dedicated wrapper class. + +.. doctest:: + + >>> from dataclasses import dataclass, field + >>> from typing import List + >>> from xsdata.formats.dataclass.serializers import XmlSerializer + >>> from xsdata.formats.dataclass.serializers.config import SerializerConfig + >>> + >>> config = SerializerConfig(pretty_print=True, xml_declaration=False) + >>> serializer = XmlSerializer(config=config) + >>> + >>> @dataclass + ... class Library: + ... books: List[str] = field( + ... metadata={ + ... "wrapper": "Books", + ... "name": "Title", + ... "type": "Element", + ... } + ... ) + ... + >>> obj = Library( + ... books = [ + ... "python for beginners", + ... "beautiful xml", + ... ] + ... ) + >>> + >>> print(serializer.render(obj)) + + + python for beginners + beautiful xml + + + diff --git a/docs/models.rst b/docs/models.rst index c34d0f6c5..570cae1d5 100644 --- a/docs/models.rst +++ b/docs/models.rst @@ -304,6 +304,10 @@ marshalling. * - default_factory - Any - Default value factory + * - wrapper + - str + - The element name to wrap a collection of elements or primitives + .. warning:: diff --git a/tests/formats/dataclass/models/test_builders.py b/tests/formats/dataclass/models/test_builders.py index 9c3efc009..9787d6e91 100644 --- a/tests/formats/dataclass/models/test_builders.py +++ b/tests/formats/dataclass/models/test_builders.py @@ -8,6 +8,7 @@ from typing import get_type_hints from typing import Iterator from typing import List +from typing import Tuple from typing import Union from unittest import mock from unittest import TestCase @@ -27,6 +28,7 @@ from xsdata.formats.dataclass.compat import class_types from xsdata.formats.dataclass.models.builders import XmlMetaBuilder from xsdata.formats.dataclass.models.builders import XmlVarBuilder +from xsdata.formats.dataclass.models.elements import XmlMeta from xsdata.formats.dataclass.models.elements import XmlType from xsdata.models.datatype import XmlDate from xsdata.utils import text @@ -103,6 +105,43 @@ class Meta: result = self.builder.build(Thug, None) self.assertEqual("thug", result.qname) + def test_wrapper(self): + @dataclass + class PrimitiveType: + attr: str = field(metadata={"wrapper": "Items"}) + + @dataclass + class UnionType: + attr: Union[str, int] = field(metadata={"wrapper": "Items"}) + + @dataclass + class UnionCollection: + union_collection: List[Union[str, int]] = field( + metadata={"wrapper": "Items"} + ) + + @dataclass + class ListType: + attr: List[str] = field(metadata={"wrapper": "Items"}) + + @dataclass + class TupleType: + attr: Tuple[str, ...] = field(metadata={"wrapper": "Items"}) + + # @dataclass + # class SetType: + # attr: Set[str] = field(metadata={"wrapper": "Items"}) + + with self.assertRaises(XmlContextError): + self.builder.build(PrimitiveType, None) + with self.assertRaises(XmlContextError): + self.builder.build(UnionType, None) + + self.assertIsInstance(self.builder.build(ListType, None), XmlMeta) + self.assertIsInstance(self.builder.build(TupleType, None), XmlMeta) + # not supported by analyze_types + # self.assertIsInstance(self.builder.build(SetType, None), XmlMeta) + def test_build_with_no_dataclass_raises_exception(self, *args): with self.assertRaises(XmlContextError) as cm: self.builder.build(int, None) diff --git a/tests/formats/dataclass/parsers/nodes/test_wrapper.py b/tests/formats/dataclass/parsers/nodes/test_wrapper.py new file mode 100644 index 000000000..65dec8b73 --- /dev/null +++ b/tests/formats/dataclass/parsers/nodes/test_wrapper.py @@ -0,0 +1,71 @@ +from dataclasses import dataclass +from dataclasses import field +from typing import List +from unittest import TestCase + +from xsdata.formats.dataclass.parsers import XmlParser + + +class WrapperTests(TestCase): + def setUp(self) -> None: + self.parser = XmlParser() + + def test_namespace(self): + @dataclass + class NamespaceWrapper: + items: List[str] = field( + metadata={ + "wrapper": "Items", + "type": "Element", + "name": "item", + "namespace": "ns", + } + ) + + xml = 'ab' + obj = self.parser.from_string(xml, clazz=NamespaceWrapper) + self.assertIsInstance(obj, NamespaceWrapper) + self.assertTrue(hasattr(obj, "items")) + self.assertEqual(len(obj.items), 2) + self.assertEqual(obj.items[0], "a") + self.assertEqual(obj.items[1], "b") + + def test_primitive(self): + @dataclass + class PrimitiveWrapper: + primitive_list: List[str] = field( + metadata={ + "wrapper": "PrimitiveList", + "type": "Element", + "name": "Value", + } + ) + + xml = r"Value 1Value 2" + obj = self.parser.from_string(xml, clazz=PrimitiveWrapper) + self.assertTrue(hasattr(obj, "primitive_list")) + self.assertIsInstance(obj.primitive_list, list) + self.assertEqual(len(obj.primitive_list), 2) + self.assertEqual(obj.primitive_list[0], "Value 1") + self.assertEqual(obj.primitive_list[1], "Value 2") + + def test_element(self): + @dataclass + class ElementObject: + content: str = field(metadata={"type": "Element"}) + + @dataclass + class ElementWrapper: + elements: List[ElementObject] = field( + metadata={"wrapper": "Elements", "type": "Element", "name": "Object"} + ) + + xml = "HelloWorld" + obj = self.parser.from_string(xml, clazz=ElementWrapper) + self.assertTrue(hasattr(obj, "elements")) + self.assertIsInstance(obj.elements, list) + self.assertEqual(len(obj.elements), 2) + self.assertIsInstance(obj.elements[0], ElementObject) + self.assertIsInstance(obj.elements[1], ElementObject) + self.assertEqual(obj.elements[0].content, "Hello") + self.assertEqual(obj.elements[1].content, "World") diff --git a/tests/formats/dataclass/serializers/test_xml.py b/tests/formats/dataclass/serializers/test_xml.py index 30d01c585..a17f05ef6 100644 --- a/tests/formats/dataclass/serializers/test_xml.py +++ b/tests/formats/dataclass/serializers/test_xml.py @@ -1,5 +1,9 @@ +import re +from dataclasses import dataclass +from dataclasses import field from dataclasses import make_dataclass from typing import Generator +from typing import List from unittest import TestCase from xml.etree.ElementTree import QName @@ -25,6 +29,58 @@ class XmlSerializerTests(TestCase): def setUp(self) -> None: self.serializer = XmlSerializer() + def test_wrapper_primitive(self): + @dataclass + class PrimitiveWrapper: + primitive_list: List[str] = field( + metadata={ + "wrapper": "PrimitiveList", + "type": "Element", + "name": "Value", + } + ) + + obj = PrimitiveWrapper(primitive_list=["Value 1", "Value 2"]) + xml = self.serializer.render(obj) + expected = r"Value 1Value 2" + self.assertIsNotNone(re.search(expected, xml)) + + def test_wrapper_element(self): + @dataclass + class ElementObject: + content: str = field(metadata={"type": "Element"}) + + @dataclass + class ElementWrapper: + elements: List[ElementObject] = field( + metadata={"wrapper": "Elements", "type": "Element", "name": "Object"} + ) + + obj = ElementWrapper( + elements=[ElementObject(content="Hello"), ElementObject(content="World")] + ) + xml = self.serializer.render(obj) + expected = "HelloWorld" + self.assertIsNotNone(re.search(expected, xml)) + + def test_wrapper_namespace(self): + @dataclass + class NamespaceWrapper: + items: List[str] = field( + metadata={ + "wrapper": "Items", + "type": "Element", + "name": "item", + "namespace": "ns", + } + ) + + ns_map = {"foo": "ns"} + obj = NamespaceWrapper(items=["a", "b"]) + xml = self.serializer.render(obj, ns_map=ns_map) + expected = 'ab' + self.assertIsNotNone(re.search(expected, xml)) + def test_write_object_with_derived_element(self): book = BookForm(id="123") obj = DerivedElement(qname="item", value=book) diff --git a/tests/formats/dataclass/test_elements.py b/tests/formats/dataclass/test_elements.py index 3b20255e5..4441a4996 100644 --- a/tests/formats/dataclass/test_elements.py +++ b/tests/formats/dataclass/test_elements.py @@ -169,6 +169,7 @@ def test__repr__(self): "wildcards=[], " "attributes={}, " "any_attributes=[], " + "wrappers={}, " "namespace=None, " "mixed_content=False)" ) diff --git a/xsdata/formats/dataclass/context.py b/xsdata/formats/dataclass/context.py index 16fb4319e..5ace31174 100644 --- a/xsdata/formats/dataclass/context.py +++ b/xsdata/formats/dataclass/context.py @@ -133,13 +133,21 @@ def find_type_by_fields(self, field_names: Set[str]) -> Optional[Type[T]]: :param field_names: A unique list of field names """ - self.build_xsi_cache() - for types in self.xsi_cache.values(): - for clazz in types: - if self.local_names_match(field_names, clazz): - return clazz + def get_field_diff(clazz: Type) -> int: + meta = self.cache[clazz] + local_names = {var.local_name for var in meta.get_all_vars()} + return len(local_names - field_names) - return None + self.build_xsi_cache() + choices = [ + (clazz, get_field_diff(clazz)) + for types in self.xsi_cache.values() + for clazz in types + if self.local_names_match(field_names, clazz) + ] + + choices.sort(key=lambda x: (x[1], x[0].__name__)) + return choices[0][0] if len(choices) > 0 else None def find_subclass(self, clazz: Type, qname: str) -> Optional[Type]: """ diff --git a/xsdata/formats/dataclass/models/builders.py b/xsdata/formats/dataclass/models/builders.py index 166b0e5fb..d97d86ef0 100644 --- a/xsdata/formats/dataclass/models/builders.py +++ b/xsdata/formats/dataclass/models/builders.py @@ -71,9 +71,12 @@ def build(self, clazz: Type, parent_namespace: Optional[str]) -> XmlMeta: choices = [] any_attributes = [] wildcards = [] + wrappers: Dict[str, List[XmlVar]] = defaultdict(list) text = None for var in class_vars: + if var.wrapper is not None: + wrappers[var.wrapper].append(var) if var.is_attribute: attributes[var.qname] = var elif var.is_element: @@ -98,6 +101,7 @@ def build(self, clazz: Type, parent_namespace: Optional[str]) -> XmlMeta: choices=choices, any_attributes=any_attributes, wildcards=wildcards, + wrappers=wrappers, ) def build_vars( @@ -269,6 +273,7 @@ def build( nillable = metadata.get("nillable", False) format_str = metadata.get("format", None) sequential = metadata.get("sequential", False) + wrapper = metadata.get("wrapper", None) origin, sub_origin, types = self.analyze_types(type_hint, globalns) @@ -277,6 +282,14 @@ def build( f"Xml type '{xml_type}' does not support typing: {type_hint}" ) + if wrapper is not None: + if not isinstance(origin, type) or not issubclass( + origin, (list, set, tuple) + ): + raise XmlContextError( + f"a wrapper requires a collection type on attribute {name}" + ) + local_name = self.build_local_name(xml_type, local_name, name) if tokens and sub_origin is None: @@ -291,6 +304,8 @@ def build( namespaces = self.resolve_namespaces(xml_type, namespace, parent_namespace) default_namespace = self.default_namespace(namespaces) qname = build_qname(default_namespace, local_name) + if wrapper is not None: + wrapper = build_qname(default_namespace, wrapper) elements = {} wildcards = [] @@ -323,6 +338,7 @@ def build( namespaces=namespaces, xml_type=xml_type, derived=False, + wrapper=wrapper, ) def build_choices( diff --git a/xsdata/formats/dataclass/models/elements.py b/xsdata/formats/dataclass/models/elements.py index 492b5110a..c1e72fa03 100644 --- a/xsdata/formats/dataclass/models/elements.py +++ b/xsdata/formats/dataclass/models/elements.py @@ -74,6 +74,7 @@ class XmlVar(MetaMixin): :param namespaces: List of the supported namespaces :param elements: Mapping of qname-repeatable elements :param wildcards: List of repeatable wildcards + :param wrapper: A name for the wrapper. Applies for list types only. """ __slots__ = ( @@ -96,6 +97,7 @@ class XmlVar(MetaMixin): "namespaces", "elements", "wildcards", + "wrapper", # Calculated "tokens", "list_element", @@ -132,6 +134,7 @@ def __init__( namespaces: Sequence[str], elements: Mapping[str, "XmlVar"], wildcards: Sequence["XmlVar"], + wrapper: Optional[str] = None, **kwargs: Any, ): self.index = index @@ -153,6 +156,7 @@ def __init__( self.namespaces = namespaces self.elements = elements self.wildcards = wildcards + self.wrapper = wrapper self.factory = factory self.tokens_factory = tokens_factory @@ -316,6 +320,7 @@ class XmlMeta(MetaMixin): "wildcards", "attributes", "any_attributes", + "wrappers", # Calculated "namespace", "mixed_content", @@ -333,6 +338,7 @@ def __init__( wildcards: Sequence[XmlVar], attributes: Mapping[str, XmlVar], any_attributes: Sequence[XmlVar], + wrappers: Mapping[str, Sequence[XmlVar]], **kwargs: Any, ): self.clazz = clazz @@ -347,6 +353,7 @@ def __init__( self.attributes = attributes self.any_attributes = any_attributes self.mixed_content = any(wildcard.mixed for wildcard in self.wildcards) + self.wrappers = wrappers @property def element_types(self) -> Set[Type]: diff --git a/xsdata/formats/dataclass/parsers/bases.py b/xsdata/formats/dataclass/parsers/bases.py index 31ad1c964..c835a6805 100644 --- a/xsdata/formats/dataclass/parsers/bases.py +++ b/xsdata/formats/dataclass/parsers/bases.py @@ -3,6 +3,7 @@ from dataclasses import dataclass from dataclasses import field from typing import Any +from typing import cast from typing import Dict from typing import List from typing import Optional @@ -82,9 +83,14 @@ def start( :param attrs: Attribute key-value map :param ns_map: Namespace prefix-URI map """ + from xsdata.formats.dataclass.parsers.nodes import ElementNode, WrapperNode + try: item = queue[-1] - child = item.child(qname, attrs, ns_map, len(objects)) + if isinstance(item, ElementNode) and qname in item.meta.wrappers: + child = cast(XmlNode, WrapperNode(parent=item)) + else: + child = item.child(qname, attrs, ns_map, len(objects)) except IndexError: xsi_type = ParserUtils.xsi_type(attrs, ns_map) @@ -108,8 +114,6 @@ def start( xsi_nil = ParserUtils.xsi_nil(attrs) - from xsdata.formats.dataclass.parsers.nodes import ElementNode - child = ElementNode( position=0, meta=meta, diff --git a/xsdata/formats/dataclass/parsers/nodes/__init__.py b/xsdata/formats/dataclass/parsers/nodes/__init__.py index 07c3aad91..9b53ea373 100644 --- a/xsdata/formats/dataclass/parsers/nodes/__init__.py +++ b/xsdata/formats/dataclass/parsers/nodes/__init__.py @@ -4,6 +4,7 @@ from xsdata.formats.dataclass.parsers.nodes.standard import StandardNode from xsdata.formats.dataclass.parsers.nodes.union import UnionNode from xsdata.formats.dataclass.parsers.nodes.wildcard import WildcardNode +from xsdata.formats.dataclass.parsers.nodes.wrapper import WrapperNode __all__ = [ "ElementNode", @@ -12,4 +13,5 @@ "StandardNode", "UnionNode", "WildcardNode", + "WrapperNode", ] diff --git a/xsdata/formats/dataclass/parsers/nodes/wrapper.py b/xsdata/formats/dataclass/parsers/nodes/wrapper.py new file mode 100644 index 000000000..a72f9da0d --- /dev/null +++ b/xsdata/formats/dataclass/parsers/nodes/wrapper.py @@ -0,0 +1,26 @@ +from typing import Dict +from typing import List +from typing import Optional + +from xsdata.formats.dataclass.parsers.mixins import XmlNode +from xsdata.formats.dataclass.parsers.nodes.element import ElementNode + + +class WrapperNode(XmlNode): + """ + XmlNode to wrap an element or primitive list. + + :param parent: The parent node + """ + + def __init__(self, parent: ElementNode): + self.parent = parent + self.ns_map = parent.ns_map + + def bind( + self, qname: str, text: Optional[str], tail: Optional[str], objects: List + ) -> bool: + return False + + def child(self, qname: str, attrs: Dict, ns_map: Dict, position: int) -> XmlNode: + return self.parent.child(qname, attrs, ns_map, position) diff --git a/xsdata/formats/dataclass/serializers/xml.py b/xsdata/formats/dataclass/serializers/xml.py index 0ae7eb0df..a13c24607 100644 --- a/xsdata/formats/dataclass/serializers/xml.py +++ b/xsdata/formats/dataclass/serializers/xml.py @@ -158,8 +158,14 @@ def write_list( self, values: Iterable, var: XmlVar, namespace: NoneStr ) -> Generator: """Produce an events stream for the given list of values.""" - for value in values: - yield from self.write_value(value, var, namespace) + if var.wrapper is not None: + yield XmlWriterEvent.START, var.wrapper + for value in values: + yield from self.write_value(value, var, namespace) + yield XmlWriterEvent.END, var.wrapper + else: + for value in values: + yield from self.write_value(value, var, namespace) def write_tokens(self, value: Any, var: XmlVar, namespace: NoneStr) -> Generator: """Produce an events stream for the given tokens list or list of tokens diff --git a/xsdata/utils/testing.py b/xsdata/utils/testing.py index 4fbde1b99..1605375b2 100644 --- a/xsdata/utils/testing.py +++ b/xsdata/utils/testing.py @@ -422,6 +422,7 @@ def create( namespaces=namespaces, elements=elements, wildcards=wildcards, + wrapper=None, ) @@ -476,6 +477,7 @@ def create( # type: ignore wildcards=wildcards, attributes=attributes, any_attributes=any_attributes, + wrappers={}, )