diff --git a/tests/codegen/handlers/test_attribute_default_value.py b/tests/codegen/handlers/test_attribute_default_value.py new file mode 100644 index 000000000..fd30a28d9 --- /dev/null +++ b/tests/codegen/handlers/test_attribute_default_value.py @@ -0,0 +1,147 @@ +from unittest import mock + +from xsdata.codegen.container import ClassContainer +from xsdata.codegen.handlers import AttributeDefaultValueHandler +from xsdata.models.enums import Namespace +from xsdata.utils.testing import AttrFactory +from xsdata.utils.testing import AttrTypeFactory +from xsdata.utils.testing import ClassFactory +from xsdata.utils.testing import FactoryTestCase + + +class AttributeDefaultValueHandlerTests(FactoryTestCase): + def setUp(self): + super().setUp() + + container = ClassContainer() + self.processor = AttributeDefaultValueHandler(container=container) + + def test_process_attribute_with_enumeration(self): + target = ClassFactory.create() + attr = AttrFactory.enumeration() + attr.restrictions.max_occurs = 2 + attr.fixed = True + + self.processor.process_attribute(target, attr) + self.assertTrue(attr.fixed) + + def test_process_attribute_with_optional_field(self): + target = ClassFactory.create() + attr = AttrFactory.create(fixed=True, default=2) + attr.restrictions.min_occurs = 0 + self.processor.process_attribute(target, attr) + self.assertFalse(attr.fixed) + self.assertIsNone(attr.default) + + def test_process_attribute_with_xsi_type(self): + target = ClassFactory.create() + attr = AttrFactory.create( + fixed=True, default=2, name="type", namespace=Namespace.XSI.uri + ) + self.processor.process_attribute(target, attr) + self.assertFalse(attr.fixed) + self.assertIsNone(attr.default) + + def test_process_attribute_with_valid_case(self): + target = ClassFactory.create() + attr = AttrFactory.create(fixed=True, default=2) + self.processor.process_attribute(target, attr) + self.assertTrue(attr.fixed) + self.assertEqual(2, attr.default) + + @mock.patch("xsdata.codegen.handlers.attribute_default_value.logger.warning") + @mock.patch.object(AttributeDefaultValueHandler, "find_enum") + def test_process_attribute_enum(self, mock_find_enum, mock_logger_warning): + enum_one = ClassFactory.enumeration(1, qname="{a}root") + enum_one.attrs[0].default = "1" + enum_one.attrs[0].name = "one" + enum_two = ClassFactory.enumeration(1, qname="inner") + enum_two.attrs[0].default = "2" + enum_two.attrs[0].name = "two" + enum_three = ClassFactory.enumeration(2, qname="missing_member") + enum_three.attrs[0].default = "4" + enum_three.attrs[0].name = "four" + enum_three.attrs[1].default = "5" + enum_three.attrs[1].name = "five" + + mock_find_enum.side_effect = [ + None, + enum_one, + None, + enum_two, + enum_three, + enum_three, + ] + + target = ClassFactory.create( + qname="target", + attrs=[ + AttrFactory.create( + types=[ + AttrTypeFactory.create(), + AttrTypeFactory.create(qname="foo"), + ], + default="1", + ), + AttrFactory.create( + types=[ + AttrTypeFactory.create(), + AttrTypeFactory.create(qname="bar", forward=True), + ], + default="2", + ), + AttrFactory.create(default="3"), + AttrFactory.create(default=" 4 5"), + ], + ) + + actual = [] + for attr in target.attrs: + self.processor.process_attribute(target, attr) + actual.append(attr.default) + + self.assertEqual( + [ + "@enum@{a}root::one", + "@enum@inner::two", + None, + "@enum@missing_member::four@five", + ], + actual, + ) + mock_logger_warning.assert_called_once_with( + "No enumeration member matched %s.%s default value `%s`", + target.name, + target.attrs[2].local_name, + "3", + ) + + def test_find_enum(self): + native_type = AttrTypeFactory.create() + matching_external = AttrTypeFactory.create("foo") + missing_external = AttrTypeFactory.create("bar") + enumeration = ClassFactory.enumeration(1, qname="foo") + inner = ClassFactory.enumeration(1, qname="foobar") + + target = ClassFactory.create( + attrs=[ + AttrFactory.create( + types=[ + native_type, + matching_external, + missing_external, + ] + ) + ], + inner=[inner], + ) + self.processor.container.extend([target, enumeration]) + + actual = self.processor.find_enum(native_type) + self.assertIsNone(actual) + + actual = self.processor.find_enum(matching_external) + self.assertEqual(enumeration, actual) + + actual = self.processor.find_enum(missing_external) + self.assertIsNone(actual) diff --git a/tests/codegen/handlers/test_attribute_name_conflict.py b/tests/codegen/handlers/test_attribute_name_conflict.py new file mode 100644 index 000000000..031e20428 --- /dev/null +++ b/tests/codegen/handlers/test_attribute_name_conflict.py @@ -0,0 +1,52 @@ +from xsdata.codegen.handlers import AttributeNameConflictHandler +from xsdata.models.enums import Tag +from xsdata.utils.testing import AttrFactory +from xsdata.utils.testing import ClassFactory +from xsdata.utils.testing import FactoryTestCase + + +class AttributeNameConflictHandlerTests(FactoryTestCase): + def setUp(self): + super().setUp() + + self.processor = AttributeNameConflictHandler() + + def test_process(self): + attrs = [ + AttrFactory.create(name="a", tag=Tag.ELEMENT), + AttrFactory.create(name="a", tag=Tag.ATTRIBUTE), + AttrFactory.create(name="b", tag=Tag.ATTRIBUTE), + AttrFactory.create(name="c", tag=Tag.ATTRIBUTE), + AttrFactory.create(name="c", tag=Tag.ELEMENT), + AttrFactory.create(name="d", tag=Tag.ELEMENT), + AttrFactory.create(name="d", tag=Tag.ELEMENT), + AttrFactory.create(name="e", tag=Tag.ELEMENT, namespace="b"), + AttrFactory.create(name="e", tag=Tag.ELEMENT), + AttrFactory.create(name="f", tag=Tag.ELEMENT), + AttrFactory.create(name="f", tag=Tag.ELEMENT, namespace="a"), + AttrFactory.create(name="gA", tag=Tag.ENUMERATION), + AttrFactory.create(name="g[A]", tag=Tag.ENUMERATION), + AttrFactory.create(name="g_a", tag=Tag.ENUMERATION), + AttrFactory.create(name="g_a_1", tag=Tag.ENUMERATION), + ] + target = ClassFactory.create(attrs=attrs) + + self.processor.process(target) + expected = [ + "a", + "a_Attribute", + "b", + "c_Attribute", + "c", + "d_Element", + "d", + "b_e", + "e", + "f", + "a_f", + "gA", + "g[A]_2", + "g_a_3", + "g_a_1", + ] + self.assertEqual(expected, [x.name for x in attrs]) diff --git a/tests/codegen/handlers/test_attribute_restrictions.py b/tests/codegen/handlers/test_attribute_restrictions.py new file mode 100644 index 000000000..ea201de31 --- /dev/null +++ b/tests/codegen/handlers/test_attribute_restrictions.py @@ -0,0 +1,112 @@ +from xsdata.codegen.handlers import AttributeRestrictionsHandler +from xsdata.codegen.models import Class +from xsdata.codegen.models import Restrictions +from xsdata.models.enums import Tag +from xsdata.utils.testing import AttrFactory +from xsdata.utils.testing import ClassFactory +from xsdata.utils.testing import FactoryTestCase + + +class AttributeRestrictionsHandlerTests(FactoryTestCase): + def setUp(self): + super().setUp() + + self.processor = AttributeRestrictionsHandler() + + def test_reset_occurrences(self): + required = Restrictions(min_occurs=1, max_occurs=1) + attr = AttrFactory.attribute(restrictions=required.clone()) + self.processor.reset_occurrences(attr) + self.assertIsNone(attr.restrictions.min_occurs) + self.assertIsNone(attr.restrictions.max_occurs) + + tokens = Restrictions(required=True, tokens=True, min_occurs=1, max_occurs=1) + attr = AttrFactory.element(restrictions=tokens.clone()) + self.processor.reset_occurrences(attr) + self.assertFalse(attr.restrictions.required) + self.assertIsNone(attr.restrictions.min_occurs) + self.assertIsNone(attr.restrictions.max_occurs) + + attr = AttrFactory.element(restrictions=tokens.clone()) + attr.restrictions.max_occurs = 2 + self.processor.reset_occurrences(attr) + self.assertFalse(attr.restrictions.required) + self.assertIsNotNone(attr.restrictions.min_occurs) + self.assertIsNotNone(attr.restrictions.max_occurs) + + multiple = Restrictions(min_occurs=0, max_occurs=2) + attr = AttrFactory.create(tag=Tag.EXTENSION, restrictions=multiple) + self.processor.reset_occurrences(attr) + self.assertTrue(attr.restrictions.required) + self.assertIsNone(attr.restrictions.min_occurs) + self.assertIsNone(attr.restrictions.max_occurs) + + multiple = Restrictions(max_occurs=2, required=True) + attr = AttrFactory.element(restrictions=multiple, fixed=True) + self.processor.reset_occurrences(attr) + self.assertIsNone(attr.restrictions.required) + self.assertEqual(0, attr.restrictions.min_occurs) + self.assertFalse(attr.fixed) + + attr = AttrFactory.element(restrictions=required.clone()) + self.processor.reset_occurrences(attr) + self.assertTrue(attr.restrictions.required) + self.assertIsNone(attr.restrictions.min_occurs) + self.assertIsNone(attr.restrictions.max_occurs) + + restrictions = Restrictions(required=True, min_occurs=0, max_occurs=1) + attr = AttrFactory.element(restrictions=restrictions, default="A", fixed=True) + self.processor.reset_occurrences(attr) + self.assertIsNone(attr.restrictions.required) + self.assertIsNone(attr.restrictions.min_occurs) + self.assertIsNone(attr.restrictions.max_occurs) + self.assertIsNone(attr.default) + self.assertFalse(attr.fixed) + + attr = AttrFactory.element(restrictions=required.clone(), default="A") + self.processor.reset_occurrences(attr) + self.assertIsNone(attr.restrictions.required) + + attr = AttrFactory.element(restrictions=required.clone(), fixed=True) + self.processor.reset_occurrences(attr) + self.assertIsNone(attr.restrictions.required) + + attr = AttrFactory.element(restrictions=required.clone()) + attr.restrictions.nillable = True + self.processor.reset_occurrences(attr) + self.assertIsNone(attr.restrictions.required) + + def test_reset_sequential(self): + def len_sequential(target: Class): + return len([attr for attr in target.attrs if attr.restrictions.sequential]) + + restrictions = Restrictions(max_occurs=2, sequential=True) + target = ClassFactory.create( + attrs=[ + AttrFactory.create(restrictions=restrictions.clone()), + AttrFactory.create(restrictions=restrictions.clone()), + ] + ) + + attrs_clone = [attr.clone() for attr in target.attrs] + + self.processor.reset_sequential(target, 0) + self.assertEqual(2, len_sequential(target)) + + target.attrs[0].restrictions.sequential = False + self.processor.reset_sequential(target, 0) + self.assertEqual(1, len_sequential(target)) + + self.processor.reset_sequential(target, 1) + self.assertEqual(0, len_sequential(target)) + + target.attrs = attrs_clone + target.attrs[1].restrictions.sequential = False + self.processor.reset_sequential(target, 0) + self.assertEqual(0, len_sequential(target)) + + target.attrs[0].restrictions.sequential = True + target.attrs[0].restrictions.max_occurs = 0 + target.attrs[1].restrictions.sequential = True + self.processor.reset_sequential(target, 0) + self.assertEqual(1, len_sequential(target)) diff --git a/tests/codegen/handlers/test_attribute_sanitizer.py b/tests/codegen/handlers/test_attribute_sanitizer.py index a901a531d..7a6652145 100644 --- a/tests/codegen/handlers/test_attribute_sanitizer.py +++ b/tests/codegen/handlers/test_attribute_sanitizer.py @@ -4,7 +4,6 @@ from xsdata.codegen.models import Status from xsdata.models.enums import DataType from xsdata.models.enums import Tag -from xsdata.utils import collections from xsdata.utils.testing import AttrFactory from xsdata.utils.testing import AttrTypeFactory from xsdata.utils.testing import ClassFactory diff --git a/tests/codegen/handlers/test_choice_group.py b/tests/codegen/handlers/test_choice_group.py new file mode 100644 index 000000000..e3ab501e3 --- /dev/null +++ b/tests/codegen/handlers/test_choice_group.py @@ -0,0 +1,154 @@ +from unittest import mock + +from xsdata.codegen.handlers import ChoiceGroupHandler +from xsdata.codegen.models import Restrictions +from xsdata.models.enums import DataType +from xsdata.utils.testing import AttrFactory +from xsdata.utils.testing import AttrTypeFactory +from xsdata.utils.testing import ClassFactory +from xsdata.utils.testing import FactoryTestCase + + +class ChoiceGroupHandlerTests(FactoryTestCase): + def setUp(self): + super().setUp() + + self.processor = ChoiceGroupHandler() + + @mock.patch.object(ChoiceGroupHandler, "group_fields") + def test_process(self, mock_group_fields): + target = ClassFactory.elements(8) + # First group repeating + target.attrs[0].restrictions.choice = "1" + target.attrs[1].restrictions.choice = "1" + target.attrs[1].restrictions.max_occurs = 2 + # Second group repeating + target.attrs[2].restrictions.choice = "2" + target.attrs[3].restrictions.choice = "2" + target.attrs[3].restrictions.max_occurs = 2 + # Third group optional + target.attrs[4].restrictions.choice = "3" + target.attrs[5].restrictions.choice = "3" + + self.processor.process(target) + mock_group_fields.assert_has_calls( + [ + mock.call(target, target.attrs[0:2]), + mock.call(target, target.attrs[2:4]), + ] + ) + + def test_group_fields(self): + target = ClassFactory.create(attrs=AttrFactory.list(2)) + target.attrs[0].restrictions.choice = "1" + target.attrs[1].restrictions.choice = "1" + target.attrs[0].restrictions.min_occurs = 10 + target.attrs[0].restrictions.max_occurs = 15 + target.attrs[1].restrictions.min_occurs = 5 + target.attrs[1].restrictions.max_occurs = 20 + + expected = AttrFactory.create( + name="attr_B_Or_attr_C", + tag="Choice", + index=0, + types=[AttrTypeFactory.native(DataType.ANY_TYPE)], + choices=[ + AttrFactory.create( + tag=target.attrs[0].tag, + name="attr_B", + types=target.attrs[0].types, + ), + AttrFactory.create( + tag=target.attrs[1].tag, + name="attr_C", + types=target.attrs[1].types, + ), + ], + ) + expected_res = Restrictions(min_occurs=5, max_occurs=20) + + self.processor.group_fields(target, list(target.attrs)) + self.assertEqual(1, len(target.attrs)) + self.assertEqual(expected, target.attrs[0]) + self.assertEqual(expected_res, target.attrs[0].restrictions) + + def test_group_fields_with_effective_choices_sums_occurs(self): + target = ClassFactory.create(attrs=AttrFactory.list(2)) + target.attrs[0].restrictions.choice = "effective_1" + target.attrs[1].restrictions.choice = "effective_1" + target.attrs[0].restrictions.min_occurs = 1 + target.attrs[0].restrictions.max_occurs = 2 + target.attrs[1].restrictions.min_occurs = 3 + target.attrs[1].restrictions.max_occurs = 4 + + expected_res = Restrictions(min_occurs=4, max_occurs=6) + + self.processor.group_fields(target, list(target.attrs)) + self.assertEqual(1, len(target.attrs)) + self.assertEqual(expected_res, target.attrs[0].restrictions) + + def test_group_fields_limit_name(self): + target = ClassFactory.create(attrs=AttrFactory.list(3)) + for attr in target.attrs: + attr.restrictions.choice = "1" + + self.processor.group_fields(target, list(target.attrs)) + + self.assertEqual(1, len(target.attrs)) + self.assertEqual("attr_B_Or_attr_C_Or_attr_D", target.attrs[0].name) + + target = ClassFactory.create(attrs=AttrFactory.list(4)) + for attr in target.attrs: + attr.restrictions.choice = "1" + + self.processor.group_fields(target, list(target.attrs)) + self.assertEqual("choice", target.attrs[0].name) + + target = ClassFactory.create() + attr = AttrFactory.element(restrictions=Restrictions(choice="1")) + target.attrs.append(attr) + target.attrs.append(attr.clone()) + self.processor.group_fields(target, list(target.attrs)) + self.assertEqual("choice", target.attrs[0].name) + + def test_build_attr_choice(self): + attr = AttrFactory.create( + name="a", namespace="xsdata", default="123", help="help", fixed=True + ) + attr.local_name = "aaa" + attr.restrictions = Restrictions( + required=True, + prohibited=None, + min_occurs=1, + max_occurs=1, + min_exclusive="1.1", + min_inclusive="1", + min_length=1, + max_exclusive="1", + max_inclusive="1.1", + max_length=10, + total_digits=333, + fraction_digits=2, + length=5, + white_space="collapse", + pattern=r"[A-Z]", + explicit_timezone="+1", + nillable=True, + choice="abc", + sequential=True, + ) + expected_res = attr.restrictions.clone() + expected_res.min_occurs = None + expected_res.max_occurs = None + expected_res.sequential = None + + actual = self.processor.build_attr_choice(attr) + + self.assertEqual(attr.local_name, actual.name) + self.assertEqual(attr.namespace, actual.namespace) + self.assertEqual(attr.default, actual.default) + self.assertEqual(attr.tag, actual.tag) + self.assertEqual(attr.types, actual.types) + self.assertEqual(expected_res, actual.restrictions) + self.assertEqual(attr.help, actual.help) + self.assertFalse(actual.fixed) diff --git a/tests/codegen/handlers/test_class_name_conflict.py b/tests/codegen/handlers/test_class_name_conflict.py new file mode 100644 index 000000000..bebcb3b62 --- /dev/null +++ b/tests/codegen/handlers/test_class_name_conflict.py @@ -0,0 +1,187 @@ +from unittest import mock + +from xsdata.codegen.container import ClassContainer +from xsdata.codegen.handlers import ClassNameConflictHandler +from xsdata.models.config import StructureStyle +from xsdata.models.enums import Tag +from xsdata.utils.testing import AttrFactory +from xsdata.utils.testing import AttrTypeFactory +from xsdata.utils.testing import ClassFactory +from xsdata.utils.testing import ExtensionFactory +from xsdata.utils.testing import FactoryTestCase + + +class ClassNameConflictHandlerTests(FactoryTestCase): + def setUp(self): + super().setUp() + + self.container = ClassContainer() + self.processor = ClassNameConflictHandler(container=self.container) + + @mock.patch.object(ClassNameConflictHandler, "rename_classes") + def test_process(self, mock_rename_classes): + classes = [ + ClassFactory.create(qname="{foo}A"), + ClassFactory.create(qname="{foo}a"), + ClassFactory.create(qname="_a"), + ClassFactory.create(qname="_b"), + ClassFactory.create(qname="b"), + ] + self.container.extend(classes) + self.processor.process() + + mock_rename_classes.assert_has_calls( + [ + mock.call(classes[:2], False), + mock.call(classes[3:], False), + ] + ) + + @mock.patch.object(ClassNameConflictHandler, "rename_classes") + def test_process_with_single_package_structure(self, mock_rename_classes): + classes = [ + ClassFactory.create(qname="{foo}a"), + ClassFactory.create(qname="{bar}a"), + ClassFactory.create(qname="a"), + ] + + self.container.config.output.structure = StructureStyle.SINGLE_PACKAGE + self.container.extend(classes) + self.processor.process() + + mock_rename_classes.assert_called_once_with(classes, True) + + @mock.patch.object(ClassNameConflictHandler, "rename_class") + def test_rename_classes(self, mock_rename_class): + classes = [ + ClassFactory.create(qname="_a", tag=Tag.ELEMENT), + ClassFactory.create(qname="_A", tag=Tag.ELEMENT), + ClassFactory.create(qname="a", tag=Tag.COMPLEX_TYPE), + ] + self.processor.rename_classes(classes, False) + self.processor.rename_classes(classes, True) + + mock_rename_class.assert_has_calls( + [ + mock.call(classes[0], False), + mock.call(classes[1], False), + mock.call(classes[2], False), + mock.call(classes[0], True), + mock.call(classes[1], True), + mock.call(classes[2], True), + ] + ) + + @mock.patch.object(ClassNameConflictHandler, "rename_class") + def test_rename_classes_protects_single_element(self, mock_rename_class): + classes = [ + ClassFactory.create(qname="_a", tag=Tag.ELEMENT), + ClassFactory.create(qname="a", tag=Tag.COMPLEX_TYPE), + ] + self.processor.rename_classes(classes, False) + + mock_rename_class.assert_called_once_with(classes[1], False) + + @mock.patch.object(ClassNameConflictHandler, "rename_class_dependencies") + def test_rename_class(self, mock_rename_class_dependencies): + target = ClassFactory.create(qname="{foo}_a") + self.processor.container.add(target) + self.processor.container.add(ClassFactory.create(qname="{foo}a_1")) + self.processor.container.add(ClassFactory.create(qname="{foo}A_2")) + self.processor.container.add(ClassFactory.create(qname="{bar}a_3")) + self.processor.rename_class(target, False) + + self.assertEqual("{foo}_a_3", target.qname) + self.assertEqual("_a", target.meta_name) + + mock_rename_class_dependencies.assert_has_calls( + mock.call(item, id(target), "{foo}_a_3") + for item in self.processor.container.iterate() + ) + + self.assertEqual([target], self.container.data["{foo}_a_3"]) + self.assertEqual([], self.container.data["{foo}_a"]) + + @mock.patch.object(ClassNameConflictHandler, "rename_class_dependencies") + def test_rename_class_by_name(self, mock_rename_class_dependencies): + target = ClassFactory.create(qname="{foo}_a") + self.processor.container.add(target) + self.processor.container.add(ClassFactory.create(qname="{bar}a_1")) + self.processor.container.add(ClassFactory.create(qname="{thug}A_2")) + self.processor.container.add(ClassFactory.create(qname="{bar}a_3")) + self.processor.rename_class(target, True) + + self.assertEqual("{foo}_a_4", target.qname) + self.assertEqual("_a", target.meta_name) + + mock_rename_class_dependencies.assert_has_calls( + mock.call(item, id(target), "{foo}_a_4") + for item in self.processor.container.iterate() + ) + + self.assertEqual([target], self.container.data["{foo}_a_4"]) + self.assertEqual([], self.container.data["{foo}_a"]) + + def test_rename_class_dependencies(self): + attr_type = AttrTypeFactory.create(qname="{foo}bar", reference=1) + + target = ClassFactory.create( + extensions=[ + ExtensionFactory.create(), + ExtensionFactory.create(attr_type.clone()), + ], + attrs=[ + AttrFactory.create(), + AttrFactory.create(types=[AttrTypeFactory.create(), attr_type.clone()]), + ], + inner=[ + ClassFactory.create( + extensions=[ExtensionFactory.create(attr_type.clone())], + attrs=[ + AttrFactory.create(), + AttrFactory.create( + types=[AttrTypeFactory.create(), attr_type.clone()] + ), + ], + ) + ], + ) + + self.processor.rename_class_dependencies(target, 1, "thug") + dependencies = set(target.dependencies()) + self.assertNotIn("{foo}bar", dependencies) + self.assertIn("thug", dependencies) + + def test_rename_attr_dependencies_with_default_enum(self): + attr_type = AttrTypeFactory.create(qname="{foo}bar", reference=1) + target = ClassFactory.create( + attrs=[ + AttrFactory.create( + types=[attr_type], + default=f"@enum@{attr_type.qname}::member", + ), + ] + ) + + self.processor.rename_class_dependencies(target, 1, "thug") + dependencies = set(target.dependencies()) + self.assertEqual("@enum@thug::member", target.attrs[0].default) + self.assertNotIn("{foo}bar", dependencies) + self.assertIn("thug", dependencies) + + def test_rename_attr_dependencies_with_choices(self): + attr_type = AttrTypeFactory.create(qname="foo", reference=1) + target = ClassFactory.create( + attrs=[ + AttrFactory.create( + choices=[ + AttrFactory.create(types=[attr_type.clone()]), + ] + ) + ] + ) + + self.processor.rename_class_dependencies(target, 1, "bar") + dependencies = set(target.dependencies()) + self.assertNotIn("foo", dependencies) + self.assertIn("bar", dependencies) diff --git a/tests/codegen/test_analyzer.py b/tests/codegen/test_analyzer.py index bafdc2db2..bfcad1ec5 100644 --- a/tests/codegen/test_analyzer.py +++ b/tests/codegen/test_analyzer.py @@ -2,7 +2,6 @@ from xsdata.codegen.analyzer import ClassAnalyzer from xsdata.codegen.container import ClassContainer -from xsdata.codegen.sanitizer import ClassSanitizer from xsdata.codegen.validator import ClassValidator from xsdata.exceptions import AnalyzerValueError from xsdata.models.config import GeneratorConfig @@ -14,16 +13,12 @@ class ClassAnalyzerTests(FactoryTestCase): @mock.patch.object(ClassAnalyzer, "validate_references") - @mock.patch.object(ClassSanitizer, "process") - @mock.patch.object(ClassContainer, "filter_classes") @mock.patch.object(ClassContainer, "process") @mock.patch.object(ClassValidator, "process") def test_process( self, mock_validator_process, mock_container_process, - mock_container_filter_classes, - mock_sanitizer_process, mock_validate_references, ): config = GeneratorConfig() @@ -37,8 +32,6 @@ def test_process( mock_validator_process.assert_called_once_with() mock_container_process.assert_called_once_with() - mock_container_filter_classes.assert_called_once_with() - mock_sanitizer_process.assert_called_once_with() mock_validate_references.assert_called_once_with(classes) def test_class_references(self): diff --git a/tests/codegen/test_container.py b/tests/codegen/test_container.py index 837ba9c27..2f16619a6 100644 --- a/tests/codegen/test_container.py +++ b/tests/codegen/test_container.py @@ -41,18 +41,18 @@ def test_initialize(self): "AttributeMixedContentHandler", "AttributeSanitizerHandler", ], - [x.__class__.__name__ for x in container.processors], + [x.__class__.__name__ for x in container.pre_processors], ) - @mock.patch.object(ClassContainer, "process_class") - def test_find(self, mock_process_class): - def process_class(x: Class): + @mock.patch.object(ClassContainer, "pre_process_class") + def test_find(self, mock_pre_process_class): + def pre_process_class(x: Class): x.status = Status.PROCESSED class_a = ClassFactory.create(qname="a") class_b = ClassFactory.create(qname="b", status=Status.PROCESSED) class_c = ClassFactory.enumeration(2, qname="b", status=Status.PROCESSING) - mock_process_class.side_effect = process_class + mock_pre_process_class.side_effect = pre_process_class self.container.extend([class_a, class_b, class_c]) self.assertIsNone(self.container.find("nope")) @@ -61,29 +61,29 @@ def process_class(x: Class): self.assertEqual( class_c, self.container.find(class_b.qname, lambda x: x.is_enumeration) ) - mock_process_class.assert_called_once_with(class_a) + mock_pre_process_class.assert_called_once_with(class_a) - @mock.patch.object(ClassContainer, "process_class") - def test_find_inner(self, mock_process_class): + @mock.patch.object(ClassContainer, "pre_process_class") + def test_find_inner(self, mock_pre_process_class): obj = ClassFactory.create() first = ClassFactory.create(qname="{a}a") second = ClassFactory.create(qname="{a}b", status=Status.PROCESSED) obj.inner.extend((first, second)) - def process_class(x: Class): + def pre_process_class(x: Class): x.status = Status.PROCESSED - mock_process_class.side_effect = process_class + mock_pre_process_class.side_effect = pre_process_class self.assertEqual(first, self.container.find_inner(obj, "{a}a")) self.assertEqual(second, self.container.find_inner(obj, "{a}b")) - mock_process_class.assert_called_once_with(first) + mock_pre_process_class.assert_called_once_with(first) def test_process(self): target = ClassFactory.create(inner=ClassFactory.list(2)) self.container.add(target) - self.container.process_class(target) + self.container.process() self.assertEqual(Status.PROCESSED, target.status) self.assertEqual(Status.PROCESSED, target.inner[0].status) self.assertEqual(Status.PROCESSED, target.inner[1].status) diff --git a/tests/codegen/test_sanitizer.py b/tests/codegen/test_sanitizer.py deleted file mode 100644 index 8f6997747..000000000 --- a/tests/codegen/test_sanitizer.py +++ /dev/null @@ -1,659 +0,0 @@ -from unittest import mock - -from xsdata.codegen.container import ClassContainer -from xsdata.codegen.models import Class -from xsdata.codegen.models import Restrictions -from xsdata.codegen.sanitizer import ClassSanitizer -from xsdata.models.config import GeneratorConfig -from xsdata.models.config import StructureStyle -from xsdata.models.enums import DataType -from xsdata.models.enums import Namespace -from xsdata.models.enums import Tag -from xsdata.utils.testing import AttrFactory -from xsdata.utils.testing import AttrTypeFactory -from xsdata.utils.testing import ClassFactory -from xsdata.utils.testing import ExtensionFactory -from xsdata.utils.testing import FactoryTestCase - - -class ClassSanitizerTest(FactoryTestCase): - def setUp(self): - super().setUp() - - self.container = ClassContainer(config=GeneratorConfig()) - self.sanitizer = ClassSanitizer(container=self.container) - - @mock.patch.object(ClassSanitizer, "resolve_conflicts") - @mock.patch.object(ClassSanitizer, "process_class") - def test_process(self, mock_process_class, mock_resolve_conflicts): - classes = ClassFactory.list(2) - self.container.extend(classes) - self.sanitizer.process() - - mock_process_class.assert_has_calls(list(map(mock.call, classes))) - mock_resolve_conflicts.assert_called_once_with() - - @mock.patch.object(ClassSanitizer, "process_duplicate_attribute_names") - @mock.patch.object(ClassSanitizer, "process_attribute_sequence") - @mock.patch.object(ClassSanitizer, "process_attribute_restrictions") - @mock.patch.object(ClassSanitizer, "process_attribute_default") - def test_process_class( - self, - mock_process_attribute_default, - mock_process_attribute_restrictions, - mock_process_attribute_sequence, - mock_process_duplicate_attribute_names, - ): - target = ClassFactory.elements(2) - inner = ClassFactory.elements(1) - target.inner.append(inner) - - self.sanitizer.process_class(target) - - calls_with_target = [ - mock.call(target.inner[0], target.inner[0].attrs[0]), - mock.call(target, target.attrs[0]), - mock.call(target, target.attrs[1]), - ] - - calls_without_target = [ - mock.call(target.inner[0].attrs[0]), - mock.call(target.attrs[0]), - mock.call(target.attrs[1]), - ] - - mock_process_attribute_default.assert_has_calls(calls_with_target) - mock_process_attribute_restrictions.assert_has_calls(calls_without_target) - mock_process_attribute_sequence.assert_has_calls(calls_with_target) - mock_process_duplicate_attribute_names.assert_has_calls( - [mock.call(target.inner[0].attrs), mock.call(target.attrs)] - ) - - @mock.patch.object(ClassSanitizer, "group_compound_fields") - def test_process_class_group_compound_fields(self, mock_group_compound_fields): - target = ClassFactory.create() - inner = ClassFactory.create() - target.inner.append(inner) - - self.container.config.output.compound_fields = True - self.sanitizer.process_class(target) - - mock_group_compound_fields.assert_has_calls( - [ - mock.call(inner), - mock.call(target), - ] - ) - - def test_process_attribute_default_with_enumeration(self): - target = ClassFactory.create() - attr = AttrFactory.enumeration() - attr.restrictions.max_occurs = 2 - attr.fixed = True - - self.sanitizer.process_attribute_default(target, attr) - self.assertTrue(attr.fixed) - - def test_process_attribute_default_with_optional_field(self): - target = ClassFactory.create() - attr = AttrFactory.create(fixed=True, default=2) - attr.restrictions.min_occurs = 0 - self.sanitizer.process_attribute_default(target, attr) - self.assertFalse(attr.fixed) - self.assertIsNone(attr.default) - - def test_process_attribute_default_with_xsi_type(self): - target = ClassFactory.create() - attr = AttrFactory.create( - fixed=True, default=2, name="type", namespace=Namespace.XSI.uri - ) - self.sanitizer.process_attribute_default(target, attr) - self.assertFalse(attr.fixed) - self.assertIsNone(attr.default) - - def test_process_attribute_default_with_valid_case(self): - target = ClassFactory.create() - attr = AttrFactory.create(fixed=True, default=2) - self.sanitizer.process_attribute_default(target, attr) - self.assertTrue(attr.fixed) - self.assertEqual(2, attr.default) - - @mock.patch("xsdata.codegen.sanitizer.logger.warning") - @mock.patch.object(ClassSanitizer, "find_enum") - def test_process_attribute_default_enum(self, mock_find_enum, mock_logger_warning): - enum_one = ClassFactory.enumeration(1, qname="{a}root") - enum_one.attrs[0].default = "1" - enum_one.attrs[0].name = "one" - enum_two = ClassFactory.enumeration(1, qname="inner") - enum_two.attrs[0].default = "2" - enum_two.attrs[0].name = "two" - enum_three = ClassFactory.enumeration(2, qname="missing_member") - enum_three.attrs[0].default = "4" - enum_three.attrs[0].name = "four" - enum_three.attrs[1].default = "5" - enum_three.attrs[1].name = "five" - - mock_find_enum.side_effect = [ - None, - enum_one, - None, - enum_two, - enum_three, - enum_three, - ] - - target = ClassFactory.create( - qname="target", - attrs=[ - AttrFactory.create( - types=[ - AttrTypeFactory.create(), - AttrTypeFactory.create(qname="foo"), - ], - default="1", - ), - AttrFactory.create( - types=[ - AttrTypeFactory.create(), - AttrTypeFactory.create(qname="bar", forward=True), - ], - default="2", - ), - AttrFactory.create(default="3"), - AttrFactory.create(default=" 4 5"), - ], - ) - - actual = [] - for attr in target.attrs: - self.sanitizer.process_attribute_default(target, attr) - actual.append(attr.default) - - self.assertEqual( - [ - "@enum@{a}root::one", - "@enum@inner::two", - None, - "@enum@missing_member::four@five", - ], - actual, - ) - mock_logger_warning.assert_called_once_with( - "No enumeration member matched %s.%s default value `%s`", - target.name, - target.attrs[2].local_name, - "3", - ) - - def test_find_enum(self): - native_type = AttrTypeFactory.create() - matching_external = AttrTypeFactory.create("foo") - missing_external = AttrTypeFactory.create("bar") - enumeration = ClassFactory.enumeration(1, qname="foo") - inner = ClassFactory.enumeration(1, qname="foobar") - - target = ClassFactory.create( - attrs=[ - AttrFactory.create( - types=[ - native_type, - matching_external, - missing_external, - ] - ) - ], - inner=[inner], - ) - self.sanitizer.container.extend([target, enumeration]) - - actual = self.sanitizer.find_enum(native_type) - self.assertIsNone(actual) - - actual = self.sanitizer.find_enum(matching_external) - self.assertEqual(enumeration, actual) - - actual = self.sanitizer.find_enum(missing_external) - self.assertIsNone(actual) - - def test_process_attribute_restrictions(self): - required = Restrictions(min_occurs=1, max_occurs=1) - attr = AttrFactory.attribute(restrictions=required.clone()) - self.sanitizer.process_attribute_restrictions(attr) - self.assertIsNone(attr.restrictions.min_occurs) - self.assertIsNone(attr.restrictions.max_occurs) - - tokens = Restrictions(required=True, tokens=True, min_occurs=1, max_occurs=1) - attr = AttrFactory.element(restrictions=tokens.clone()) - self.sanitizer.process_attribute_restrictions(attr) - self.assertFalse(attr.restrictions.required) - self.assertIsNone(attr.restrictions.min_occurs) - self.assertIsNone(attr.restrictions.max_occurs) - - attr = AttrFactory.element(restrictions=tokens.clone()) - attr.restrictions.max_occurs = 2 - self.sanitizer.process_attribute_restrictions(attr) - self.assertFalse(attr.restrictions.required) - self.assertIsNotNone(attr.restrictions.min_occurs) - self.assertIsNotNone(attr.restrictions.max_occurs) - - multiple = Restrictions(min_occurs=0, max_occurs=2) - attr = AttrFactory.create(tag=Tag.EXTENSION, restrictions=multiple) - self.sanitizer.process_attribute_restrictions(attr) - self.assertTrue(attr.restrictions.required) - self.assertIsNone(attr.restrictions.min_occurs) - self.assertIsNone(attr.restrictions.max_occurs) - - multiple = Restrictions(max_occurs=2, required=True) - attr = AttrFactory.element(restrictions=multiple, fixed=True) - self.sanitizer.process_attribute_restrictions(attr) - self.assertIsNone(attr.restrictions.required) - self.assertEqual(0, attr.restrictions.min_occurs) - self.assertFalse(attr.fixed) - - attr = AttrFactory.element(restrictions=required.clone()) - self.sanitizer.process_attribute_restrictions(attr) - self.assertTrue(attr.restrictions.required) - self.assertIsNone(attr.restrictions.min_occurs) - self.assertIsNone(attr.restrictions.max_occurs) - - restrictions = Restrictions(required=True, min_occurs=0, max_occurs=1) - attr = AttrFactory.element(restrictions=restrictions, default="A", fixed=True) - self.sanitizer.process_attribute_restrictions(attr) - self.assertIsNone(attr.restrictions.required) - self.assertIsNone(attr.restrictions.min_occurs) - self.assertIsNone(attr.restrictions.max_occurs) - self.assertIsNone(attr.default) - self.assertFalse(attr.fixed) - - attr = AttrFactory.element(restrictions=required.clone(), default="A") - self.sanitizer.process_attribute_restrictions(attr) - self.assertIsNone(attr.restrictions.required) - - attr = AttrFactory.element(restrictions=required.clone(), fixed=True) - self.sanitizer.process_attribute_restrictions(attr) - self.assertIsNone(attr.restrictions.required) - - attr = AttrFactory.element(restrictions=required.clone()) - attr.restrictions.nillable = True - self.sanitizer.process_attribute_restrictions(attr) - self.assertIsNone(attr.restrictions.required) - - def test_sanitize_duplicate_attribute_names(self): - attrs = [ - AttrFactory.create(name="a", tag=Tag.ELEMENT), - AttrFactory.create(name="a", tag=Tag.ATTRIBUTE), - AttrFactory.create(name="b", tag=Tag.ATTRIBUTE), - AttrFactory.create(name="c", tag=Tag.ATTRIBUTE), - AttrFactory.create(name="c", tag=Tag.ELEMENT), - AttrFactory.create(name="d", tag=Tag.ELEMENT), - AttrFactory.create(name="d", tag=Tag.ELEMENT), - AttrFactory.create(name="e", tag=Tag.ELEMENT, namespace="b"), - AttrFactory.create(name="e", tag=Tag.ELEMENT), - AttrFactory.create(name="f", tag=Tag.ELEMENT), - AttrFactory.create(name="f", tag=Tag.ELEMENT, namespace="a"), - AttrFactory.create(name="gA", tag=Tag.ENUMERATION), - AttrFactory.create(name="g[A]", tag=Tag.ENUMERATION), - AttrFactory.create(name="g_a", tag=Tag.ENUMERATION), - AttrFactory.create(name="g_a_1", tag=Tag.ENUMERATION), - ] - - self.sanitizer.process_duplicate_attribute_names(attrs) - expected = [ - "a", - "a_Attribute", - "b", - "c_Attribute", - "c", - "d_Element", - "d", - "b_e", - "e", - "f", - "a_f", - "gA", - "g[A]_2", - "g_a_3", - "g_a_1", - ] - self.assertEqual(expected, [x.name for x in attrs]) - - def test_sanitize_attribute_sequence(self): - def len_sequential(target: Class): - return len([attr for attr in target.attrs if attr.restrictions.sequential]) - - restrictions = Restrictions(max_occurs=2, sequential=True) - target = ClassFactory.create( - attrs=[ - AttrFactory.create(restrictions=restrictions.clone()), - AttrFactory.create(restrictions=restrictions.clone()), - ] - ) - - attrs_clone = [attr.clone() for attr in target.attrs] - - self.sanitizer.process_attribute_sequence(target, target.attrs[0]) - self.assertEqual(2, len_sequential(target)) - - target.attrs[0].restrictions.sequential = False - self.sanitizer.process_attribute_sequence(target, target.attrs[0]) - self.assertEqual(1, len_sequential(target)) - - self.sanitizer.process_attribute_sequence(target, target.attrs[1]) - self.assertEqual(0, len_sequential(target)) - - target.attrs = attrs_clone - target.attrs[1].restrictions.sequential = False - self.sanitizer.process_attribute_sequence(target, target.attrs[0]) - self.assertEqual(0, len_sequential(target)) - - target.attrs[0].restrictions.sequential = True - target.attrs[0].restrictions.max_occurs = 0 - target.attrs[1].restrictions.sequential = True - self.sanitizer.process_attribute_sequence(target, target.attrs[0]) - self.assertEqual(1, len_sequential(target)) - - @mock.patch.object(ClassSanitizer, "rename_classes") - def test_resolve_conflicts(self, mock_rename_classes): - classes = [ - ClassFactory.create(qname="{foo}A"), - ClassFactory.create(qname="{foo}a"), - ClassFactory.create(qname="_a"), - ClassFactory.create(qname="_b"), - ClassFactory.create(qname="b"), - ] - self.sanitizer.container.extend(classes) - self.sanitizer.resolve_conflicts() - - mock_rename_classes.assert_has_calls( - [ - mock.call(classes[:2], False), - mock.call(classes[3:], False), - ] - ) - - @mock.patch.object(ClassSanitizer, "rename_classes") - def test_resolve_conflicts_with_single_package_structure(self, mock_rename_classes): - classes = [ - ClassFactory.create(qname="{foo}a"), - ClassFactory.create(qname="{bar}a"), - ClassFactory.create(qname="a"), - ] - - self.sanitizer.container.config.output.structure = StructureStyle.SINGLE_PACKAGE - self.sanitizer.container.extend(classes) - self.sanitizer.resolve_conflicts() - - mock_rename_classes.assert_called_once_with(classes, True) - - @mock.patch.object(ClassSanitizer, "rename_class") - def test_rename_classes(self, mock_rename_class): - classes = [ - ClassFactory.create(qname="_a", tag=Tag.ELEMENT), - ClassFactory.create(qname="_A", tag=Tag.ELEMENT), - ClassFactory.create(qname="a", tag=Tag.COMPLEX_TYPE), - ] - self.sanitizer.rename_classes(classes, False) - self.sanitizer.rename_classes(classes, True) - - mock_rename_class.assert_has_calls( - [ - mock.call(classes[0], False), - mock.call(classes[1], False), - mock.call(classes[2], False), - mock.call(classes[0], True), - mock.call(classes[1], True), - mock.call(classes[2], True), - ] - ) - - @mock.patch.object(ClassSanitizer, "rename_class") - def test_rename_classes_protects_single_element(self, mock_rename_class): - classes = [ - ClassFactory.create(qname="_a", tag=Tag.ELEMENT), - ClassFactory.create(qname="a", tag=Tag.COMPLEX_TYPE), - ] - self.sanitizer.rename_classes(classes, False) - - mock_rename_class.assert_called_once_with(classes[1], False) - - @mock.patch.object(ClassSanitizer, "rename_class_dependencies") - def test_rename_class(self, mock_rename_class_dependencies): - target = ClassFactory.create(qname="{foo}_a") - self.sanitizer.container.add(target) - self.sanitizer.container.add(ClassFactory.create(qname="{foo}a_1")) - self.sanitizer.container.add(ClassFactory.create(qname="{foo}A_2")) - self.sanitizer.container.add(ClassFactory.create(qname="{bar}a_3")) - self.sanitizer.rename_class(target, False) - - self.assertEqual("{foo}_a_3", target.qname) - self.assertEqual("_a", target.meta_name) - - mock_rename_class_dependencies.assert_has_calls( - mock.call(item, id(target), "{foo}_a_3") - for item in self.sanitizer.container.iterate() - ) - - self.assertEqual([target], self.container.data["{foo}_a_3"]) - self.assertEqual([], self.container.data["{foo}_a"]) - - @mock.patch.object(ClassSanitizer, "rename_class_dependencies") - def test_rename_class_by_name(self, mock_rename_class_dependencies): - target = ClassFactory.create(qname="{foo}_a") - self.sanitizer.container.add(target) - self.sanitizer.container.add(ClassFactory.create(qname="{bar}a_1")) - self.sanitizer.container.add(ClassFactory.create(qname="{thug}A_2")) - self.sanitizer.container.add(ClassFactory.create(qname="{bar}a_3")) - self.sanitizer.rename_class(target, True) - - self.assertEqual("{foo}_a_4", target.qname) - self.assertEqual("_a", target.meta_name) - - mock_rename_class_dependencies.assert_has_calls( - mock.call(item, id(target), "{foo}_a_4") - for item in self.sanitizer.container.iterate() - ) - - self.assertEqual([target], self.container.data["{foo}_a_4"]) - self.assertEqual([], self.container.data["{foo}_a"]) - - def test_rename_class_dependencies(self): - attr_type = AttrTypeFactory.create(qname="{foo}bar", reference=1) - - target = ClassFactory.create( - extensions=[ - ExtensionFactory.create(), - ExtensionFactory.create(attr_type.clone()), - ], - attrs=[ - AttrFactory.create(), - AttrFactory.create(types=[AttrTypeFactory.create(), attr_type.clone()]), - ], - inner=[ - ClassFactory.create( - extensions=[ExtensionFactory.create(attr_type.clone())], - attrs=[ - AttrFactory.create(), - AttrFactory.create( - types=[AttrTypeFactory.create(), attr_type.clone()] - ), - ], - ) - ], - ) - - self.sanitizer.rename_class_dependencies(target, 1, "thug") - dependencies = set(target.dependencies()) - self.assertNotIn("{foo}bar", dependencies) - self.assertIn("thug", dependencies) - - def test_rename_attr_dependencies_with_default_enum(self): - attr_type = AttrTypeFactory.create(qname="{foo}bar", reference=1) - target = ClassFactory.create( - attrs=[ - AttrFactory.create( - types=[attr_type], - default=f"@enum@{attr_type.qname}::member", - ), - ] - ) - - self.sanitizer.rename_class_dependencies(target, 1, "thug") - dependencies = set(target.dependencies()) - self.assertEqual("@enum@thug::member", target.attrs[0].default) - self.assertNotIn("{foo}bar", dependencies) - self.assertIn("thug", dependencies) - - def test_rename_attr_dependencies_with_choices(self): - attr_type = AttrTypeFactory.create(qname="foo", reference=1) - target = ClassFactory.create( - attrs=[ - AttrFactory.create( - choices=[ - AttrFactory.create(types=[attr_type.clone()]), - ] - ) - ] - ) - - self.sanitizer.rename_class_dependencies(target, 1, "bar") - dependencies = set(target.dependencies()) - self.assertNotIn("foo", dependencies) - self.assertIn("bar", dependencies) - - @mock.patch.object(ClassSanitizer, "group_fields") - def test_group_compound_fields(self, mock_group_fields): - target = ClassFactory.elements(8) - # First group repeating - target.attrs[0].restrictions.choice = "1" - target.attrs[1].restrictions.choice = "1" - target.attrs[1].restrictions.max_occurs = 2 - # Second group repeating - target.attrs[2].restrictions.choice = "2" - target.attrs[3].restrictions.choice = "2" - target.attrs[3].restrictions.max_occurs = 2 - # Third group optional - target.attrs[4].restrictions.choice = "3" - target.attrs[5].restrictions.choice = "3" - - self.sanitizer.group_compound_fields(target) - mock_group_fields.assert_has_calls( - [ - mock.call(target, target.attrs[0:2]), - mock.call(target, target.attrs[2:4]), - ] - ) - - def test_group_fields(self): - target = ClassFactory.create(attrs=AttrFactory.list(2)) - target.attrs[0].restrictions.choice = "1" - target.attrs[1].restrictions.choice = "1" - target.attrs[0].restrictions.min_occurs = 10 - target.attrs[0].restrictions.max_occurs = 15 - target.attrs[1].restrictions.min_occurs = 5 - target.attrs[1].restrictions.max_occurs = 20 - - expected = AttrFactory.create( - name="attr_B_Or_attr_C", - tag="Choice", - index=0, - types=[AttrTypeFactory.native(DataType.ANY_TYPE)], - choices=[ - AttrFactory.create( - tag=target.attrs[0].tag, - name="attr_B", - types=target.attrs[0].types, - ), - AttrFactory.create( - tag=target.attrs[1].tag, - name="attr_C", - types=target.attrs[1].types, - ), - ], - ) - expected_res = Restrictions(min_occurs=5, max_occurs=20) - - self.sanitizer.group_fields(target, list(target.attrs)) - self.assertEqual(1, len(target.attrs)) - self.assertEqual(expected, target.attrs[0]) - self.assertEqual(expected_res, target.attrs[0].restrictions) - - def test_group_fields_with_effective_choices_sums_occurs(self): - target = ClassFactory.create(attrs=AttrFactory.list(2)) - target.attrs[0].restrictions.choice = "effective_1" - target.attrs[1].restrictions.choice = "effective_1" - target.attrs[0].restrictions.min_occurs = 1 - target.attrs[0].restrictions.max_occurs = 2 - target.attrs[1].restrictions.min_occurs = 3 - target.attrs[1].restrictions.max_occurs = 4 - - expected_res = Restrictions(min_occurs=4, max_occurs=6) - - self.sanitizer.group_fields(target, list(target.attrs)) - self.assertEqual(1, len(target.attrs)) - self.assertEqual(expected_res, target.attrs[0].restrictions) - - def test_group_fields_limit_name(self): - target = ClassFactory.create(attrs=AttrFactory.list(3)) - for attr in target.attrs: - attr.restrictions.choice = "1" - - self.sanitizer.group_fields(target, list(target.attrs)) - - self.assertEqual(1, len(target.attrs)) - self.assertEqual("attr_B_Or_attr_C_Or_attr_D", target.attrs[0].name) - - target = ClassFactory.create(attrs=AttrFactory.list(4)) - for attr in target.attrs: - attr.restrictions.choice = "1" - - self.sanitizer.group_fields(target, list(target.attrs)) - self.assertEqual("choice", target.attrs[0].name) - - target = ClassFactory.create() - attr = AttrFactory.element(restrictions=Restrictions(choice="1")) - target.attrs.append(attr) - target.attrs.append(attr.clone()) - self.sanitizer.group_fields(target, list(target.attrs)) - self.assertEqual("choice", target.attrs[0].name) - - def test_build_attr_choice(self): - attr = AttrFactory.create( - name="a", namespace="xsdata", default="123", help="help", fixed=True - ) - attr.local_name = "aaa" - attr.restrictions = Restrictions( - required=True, - prohibited=None, - min_occurs=1, - max_occurs=1, - min_exclusive="1.1", - min_inclusive="1", - min_length=1, - max_exclusive="1", - max_inclusive="1.1", - max_length=10, - total_digits=333, - fraction_digits=2, - length=5, - white_space="collapse", - pattern=r"[A-Z]", - explicit_timezone="+1", - nillable=True, - choice="abc", - sequential=True, - ) - expected_res = attr.restrictions.clone() - expected_res.min_occurs = None - expected_res.max_occurs = None - expected_res.sequential = None - - actual = self.sanitizer.build_attr_choice(attr) - - self.assertEqual(attr.local_name, actual.name) - self.assertEqual(attr.namespace, actual.namespace) - self.assertEqual(attr.default, actual.default) - self.assertEqual(attr.tag, actual.tag) - self.assertEqual(attr.types, actual.types) - self.assertEqual(expected_res, actual.restrictions) - self.assertEqual(attr.help, actual.help) - self.assertFalse(actual.fixed) diff --git a/xsdata/codegen/analyzer.py b/xsdata/codegen/analyzer.py index 6d44c22e3..6e6f4e2ed 100644 --- a/xsdata/codegen/analyzer.py +++ b/xsdata/codegen/analyzer.py @@ -2,7 +2,6 @@ from xsdata.codegen.container import ClassContainer from xsdata.codegen.models import Class -from xsdata.codegen.sanitizer import ClassSanitizer from xsdata.codegen.validator import ClassValidator from xsdata.exceptions import AnalyzerValueError @@ -21,12 +20,6 @@ def process(cls, container: ClassContainer) -> List[Class]: # Run analyzer handlers container.process() - # Filter classes that should be generated. - container.filter_classes() - - # Sanitize class attributes after merging and flattening types and extensions. - ClassSanitizer(container).process() - classes = container.class_list cls.validate_references(classes) diff --git a/xsdata/codegen/container.py b/xsdata/codegen/container.py index da886ef88..341b1f9f8 100644 --- a/xsdata/codegen/container.py +++ b/xsdata/codegen/container.py @@ -7,14 +7,18 @@ from typing import List from typing import Optional +from xsdata.codegen.handlers import AttributeDefaultValueHandler from xsdata.codegen.handlers import AttributeGroupHandler from xsdata.codegen.handlers import AttributeMergeHandler from xsdata.codegen.handlers import AttributeMixedContentHandler +from xsdata.codegen.handlers import AttributeRestrictionsHandler from xsdata.codegen.handlers import AttributeSanitizerHandler from xsdata.codegen.handlers import AttributeSubstitutionHandler from xsdata.codegen.handlers import AttributeTypeHandler +from xsdata.codegen.handlers import ChoiceGroupHandler from xsdata.codegen.handlers import ClassEnumerationHandler from xsdata.codegen.handlers import ClassExtensionHandler +from xsdata.codegen.handlers import ClassNameConflictHandler from xsdata.codegen.mixins import ContainerInterface from xsdata.codegen.mixins import HandlerInterface from xsdata.codegen.models import Class @@ -31,10 +35,11 @@ class ClassContainer(ContainerInterface): data: Dict = field(default_factory=dict) config: GeneratorConfig = field(default_factory=GeneratorConfig) - processors: List[HandlerInterface] = field(init=False) + pre_processors: List[HandlerInterface] = field(init=False) + post_processors: List[HandlerInterface] = field(init=False) def __post_init__(self): - self.processors: List[HandlerInterface] = [ + self.pre_processors: List[HandlerInterface] = [ AttributeGroupHandler(self), ClassExtensionHandler(self), ClassEnumerationHandler(self), @@ -45,6 +50,13 @@ def __post_init__(self): AttributeSanitizerHandler(self), ] + self.post_processors: List[HandlerInterface] = [ + AttributeDefaultValueHandler(self), + AttributeRestrictionsHandler(), + ] + if self.config.output.compound_fields: + self.post_processors.insert(0, ChoiceGroupHandler()) + @property def class_list(self) -> List[Class]: return list(self.iterate()) @@ -60,7 +72,7 @@ def find(self, qname: str, condition: Callable = return_true) -> Optional[Class] for row in self.data.get(qname, []): if condition(row): if row.status == Status.RAW: - self.process_class(row) + self.pre_process_class(row) return self.find(qname, condition) return row @@ -69,7 +81,7 @@ def find(self, qname: str, condition: Callable = return_true) -> Optional[Class] def find_inner(self, source: Class, qname: str) -> Class: inner = ClassUtils.find_inner(source, qname) if inner.status == Status.RAW: - self.process_class(inner) + self.pre_process_class(inner) return inner @@ -77,23 +89,38 @@ def process(self): """Run the process handlers for ever non processed class.""" for obj in self.iterate(): if obj.status == Status.RAW: - self.process_class(obj) + self.pre_process_class(obj) + + self.filter_classes() + + for obj in self.iterate(): + self.post_process_class(obj) - def process_class(self, target: Class): + conflict_resolver = ClassNameConflictHandler(self) + conflict_resolver.process() + + def pre_process_class(self, target: Class): """Run the process handlers for the target class.""" target.status = Status.PROCESSING - for processor in self.processors: + for processor in self.pre_processors: processor.process(target) # We go top to bottom because it's easier to handle circular # references. for inner in target.inner: if inner.status == Status.RAW: - self.process_class(inner) + self.pre_process_class(inner) target.status = Status.PROCESSED + def post_process_class(self, target: Class): + for processor in self.post_processors: + processor.process(target) + + for inner in target.inner: + self.post_process_class(inner) + def filter_classes(self): """If there is any class derived from complexType or element then filter classes that should be generated, otherwise leave the container diff --git a/xsdata/codegen/handlers/__init__.py b/xsdata/codegen/handlers/__init__.py index 9d13f6a72..3ff1d3eb7 100644 --- a/xsdata/codegen/handlers/__init__.py +++ b/xsdata/codegen/handlers/__init__.py @@ -1,11 +1,17 @@ +from xsdata.codegen.handlers.attribute_default_value import AttributeDefaultValueHandler from xsdata.codegen.handlers.attribute_group import AttributeGroupHandler from xsdata.codegen.handlers.attribute_merge import AttributeMergeHandler from xsdata.codegen.handlers.attribute_mixed_content import AttributeMixedContentHandler +from xsdata.codegen.handlers.attribute_name_conflict import AttributeNameConflictHandler +from xsdata.codegen.handlers.attribute_restrictions import AttributeRestrictionsHandler from xsdata.codegen.handlers.attribute_sanitizer import AttributeSanitizerHandler from xsdata.codegen.handlers.attribute_substitution import AttributeSubstitutionHandler from xsdata.codegen.handlers.attribute_type import AttributeTypeHandler +from xsdata.codegen.handlers.choice_group import ChoiceGroupHandler from xsdata.codegen.handlers.class_enumeration import ClassEnumerationHandler from xsdata.codegen.handlers.class_extension import ClassExtensionHandler +from xsdata.codegen.handlers.class_name_conflict import ClassNameConflictHandler + __all__ = [ "ClassEnumerationHandler", @@ -16,4 +22,10 @@ "AttributeSubstitutionHandler", "AttributeTypeHandler", "ClassExtensionHandler", + "ChoiceGroupHandler", + "AttributeRestrictionsHandler", + "AttributeDefaultValueHandler", + "AttributeRestrictionsHandler", + "AttributeNameConflictHandler", + "ClassNameConflictHandler", ] diff --git a/xsdata/codegen/handlers/attribute_default_value.py b/xsdata/codegen/handlers/attribute_default_value.py new file mode 100644 index 000000000..a2673f171 --- /dev/null +++ b/xsdata/codegen/handlers/attribute_default_value.py @@ -0,0 +1,92 @@ +from dataclasses import dataclass +from typing import Optional + +from xsdata.codegen.mixins import ContainerInterface +from xsdata.codegen.mixins import HandlerInterface +from xsdata.codegen.models import Attr +from xsdata.codegen.models import AttrType +from xsdata.codegen.models import Class +from xsdata.logger import logger + + +@dataclass +class AttributeDefaultValueHandler(HandlerInterface): + container: ContainerInterface + + def process(self, target: Class): + for attr in target.attrs: + self.process_attribute(target, attr) + + def process_attribute(self, target: Class, attr: Attr): + """ + Sanitize attribute default value. + + Cases: + 1. Ignore enumerations. + 2. List fields can not have a fixed value. + 3. Optional fields or xsi:type can not have a default or fixed value. + 4. Convert string literal default value for enum fields. + """ + + if attr.is_enumeration: + return + + if attr.is_optional or attr.is_xsi_type: + attr.fixed = False + attr.default = None + + if attr.default: + self.process_attribute_default_enum(target, attr) + + def process_attribute_default_enum(self, target: Class, attr: Attr): + """ + Convert string literal default value for enum fields. + + Loop through all attributes types and search for enum sources. + If an enum source exist map the default string literal value to + a qualified name. If the source class in inner promote it to + root classes. + """ + + source_found = False + + assert attr.default is not None + + for attr_type in attr.types: + source = self.find_enum(attr_type) + if not source: + continue + + source_found = True + value_members = {x.default: x.name for x in source.attrs} + name = value_members.get(attr.default) + if name: + attr.default = f"@enum@{source.qname}::{name}" + return + + names = [ + value_members[token] + for token in attr.default.split() + if token in value_members + ] + if names: + attr.default = f"@enum@{source.qname}::{'@'.join(names)}" + return + + if source_found: + logger.warning( + "No enumeration member matched %s.%s default value `%s`", + target.name, + attr.local_name, + attr.default, + ) + attr.default = None + + def find_enum(self, attr_type: AttrType) -> Optional[Class]: + """Find an enumeration class byte the attribute type.""" + if attr_type.native: + return None + + return self.container.find( + attr_type.qname, condition=lambda x: x.is_enumeration + ) diff --git a/xsdata/codegen/handlers/attribute_name_conflict.py b/xsdata/codegen/handlers/attribute_name_conflict.py new file mode 100644 index 000000000..a803cf650 --- /dev/null +++ b/xsdata/codegen/handlers/attribute_name_conflict.py @@ -0,0 +1,20 @@ +from xsdata.codegen.mixins import HandlerInterface +from xsdata.codegen.models import Class +from xsdata.codegen.utils import ClassUtils +from xsdata.utils import text +from xsdata.utils.collections import group_by + + +class AttributeNameConflictHandler(HandlerInterface): + """Enumeration class processor.""" + + def process(self, target: Class): + """Sanitize duplicate attribute names that might exist by applying + rename strategies.""" + grouped = group_by(target.attrs, lambda attr: text.alnum(attr.name)) + for items in grouped.values(): + total = len(items) + if total == 2 and not items[0].is_enumeration: + ClassUtils.rename_attribute_by_preference(*items) + elif total > 1: + ClassUtils.rename_attributes_by_index(target.attrs, items) diff --git a/xsdata/codegen/handlers/attribute_restrictions.py b/xsdata/codegen/handlers/attribute_restrictions.py new file mode 100644 index 000000000..75c3ff605 --- /dev/null +++ b/xsdata/codegen/handlers/attribute_restrictions.py @@ -0,0 +1,68 @@ +from xsdata.codegen.mixins import HandlerInterface +from xsdata.codegen.models import Attr +from xsdata.codegen.models import Class + + +class AttributeRestrictionsHandler(HandlerInterface): + """Enumeration class processor.""" + + def process(self, target: Class): + + for index, attr in enumerate(target.attrs): + self.reset_occurrences(attr) + self.reset_sequential(target, index) + + @classmethod + def reset_occurrences(cls, attr: Attr): + """Sanitize attribute required flag by comparing the min/max + occurrences restrictions.""" + restrictions = attr.restrictions + min_occurs = restrictions.min_occurs or 0 + max_occurs = restrictions.max_occurs or 0 + + if attr.is_attribute: + restrictions.min_occurs = None + restrictions.max_occurs = None + elif attr.is_tokens: + restrictions.required = None + if max_occurs <= 1: + restrictions.min_occurs = None + restrictions.max_occurs = None + elif attr.xml_type is None or min_occurs == max_occurs == 1: + restrictions.required = True + restrictions.min_occurs = None + restrictions.max_occurs = None + elif min_occurs == 0 and max_occurs < 2: + restrictions.required = None + restrictions.min_occurs = None + restrictions.max_occurs = None + attr.default = None + attr.fixed = False + else: # max_occurs > 1 + restrictions.min_occurs = min_occurs + restrictions.required = None + attr.fixed = False + + if attr.default or attr.fixed or attr.restrictions.nillable: + restrictions.required = None + + @classmethod + def reset_sequential(cls, target: Class, index: int): + """Reset the attribute at the given index if it has no siblings with + the sequential restriction.""" + + attr = target.attrs[index] + before = target.attrs[index - 1] if index - 1 >= 0 else None + after = target.attrs[index + 1] if index + 1 < len(target.attrs) else None + + if not attr.is_list: + attr.restrictions.sequential = False + + if ( + not attr.restrictions.sequential + or (before and before.restrictions.sequential) + or (after and after.restrictions.sequential and after.is_list) + ): + return + + attr.restrictions.sequential = False diff --git a/xsdata/codegen/handlers/choice_group.py b/xsdata/codegen/handlers/choice_group.py new file mode 100644 index 000000000..440db5f32 --- /dev/null +++ b/xsdata/codegen/handlers/choice_group.py @@ -0,0 +1,81 @@ +from typing import List + +from xsdata.codegen.mixins import HandlerInterface +from xsdata.codegen.models import Attr +from xsdata.codegen.models import AttrType +from xsdata.codegen.models import Class +from xsdata.codegen.models import Restrictions +from xsdata.models.enums import DataType +from xsdata.models.enums import Tag +from xsdata.utils.collections import group_by + + +class ChoiceGroupHandler(HandlerInterface): + """Enumeration class processor.""" + + def process(self, target: Class): + groups = group_by(target.attrs, lambda x: x.restrictions.choice) + for choice, attrs in groups.items(): + if choice and len(attrs) > 1 and any(attr.is_list for attr in attrs): + self.group_fields(target, attrs) + + @classmethod + def group_fields(cls, target: Class, attrs: List[Attr]): + """Group attributes into a new compound field.""" + + pos = target.attrs.index(attrs[0]) + choice = attrs[0].restrictions.choice + sum_occurs = choice and choice.startswith("effective_") + names = [] + choices = [] + min_occurs = [] + max_occurs = [] + for attr in attrs: + target.attrs.remove(attr) + names.append(attr.local_name) + min_occurs.append(attr.restrictions.min_occurs or 0) + max_occurs.append(attr.restrictions.max_occurs or 0) + choices.append(cls.build_attr_choice(attr)) + + if len(names) > 3 or len(names) != len(set(names)): + name = "choice" + else: + name = "_Or_".join(names) + + target.attrs.insert( + pos, + Attr( + name=name, + index=0, + types=[AttrType(qname=str(DataType.ANY_TYPE), native=True)], + tag=Tag.CHOICE, + restrictions=Restrictions( + min_occurs=sum(min_occurs) if sum_occurs else min(min_occurs), + max_occurs=sum(max_occurs) if sum_occurs else max(max_occurs), + ), + choices=choices, + ), + ) + + @classmethod + def build_attr_choice(cls, attr: Attr) -> Attr: + """ + Converts the given attr to a choice. + + The most important part is the reset of certain restrictions + that don't make sense as choice metadata like occurrences. + """ + restrictions = attr.restrictions.clone() + restrictions.min_occurs = None + restrictions.max_occurs = None + restrictions.sequential = None + + return Attr( + name=attr.local_name, + namespace=attr.namespace, + default=attr.default, + types=attr.types, + tag=attr.tag, + help=attr.help, + restrictions=restrictions, + ) diff --git a/xsdata/codegen/handlers/class_name_conflict.py b/xsdata/codegen/handlers/class_name_conflict.py new file mode 100644 index 000000000..826f6a6e5 --- /dev/null +++ b/xsdata/codegen/handlers/class_name_conflict.py @@ -0,0 +1,102 @@ +import operator +from dataclasses import dataclass +from typing import List + +from xsdata.codegen.mixins import ContainerInterface +from xsdata.codegen.mixins import HandlerInterface +from xsdata.codegen.models import Attr +from xsdata.codegen.models import Class +from xsdata.models.config import StructureStyle +from xsdata.utils import text +from xsdata.utils.collections import group_by +from xsdata.utils.namespaces import build_qname +from xsdata.utils.namespaces import split_qname + + +@dataclass +class ClassNameConflictHandler(HandlerInterface): + container: ContainerInterface + + def process(self): + use_name = ( + self.container.config.output.structure == StructureStyle.SINGLE_PACKAGE + ) + getter = operator.attrgetter("name" if use_name else "qname") + groups = group_by(self.container.iterate(), lambda x: text.alnum(getter(x))) + + for classes in groups.values(): + if len(classes) > 1: + self.rename_classes(classes, use_name) + + def rename_classes(self, classes: List[Class], use_name: bool): + """ + Rename all the classes in the list. + + Protect classes derived from xs:element if there is only one in + the list. + """ + total_elements = sum(x.is_element for x in classes) + for target in classes: + if not target.is_element or total_elements > 1: + self.rename_class(target, use_name) + + def rename_class(self, target: Class, use_name: bool): + """Find the next available class identifier, save the original name in + the class metadata and update the class qualified name and all classes + that depend on the target class.""" + + qname = target.qname + namespace, name = split_qname(target.qname) + target.qname = self.next_qname(namespace, name, use_name) + target.meta_name = name + self.container.reset(target, qname) + + for item in self.container.iterate(): + self.rename_class_dependencies(item, id(target), target.qname) + + def next_qname(self, namespace: str, name: str, use_name: bool) -> str: + """Append the next available index number for the given namespace and + local name.""" + index = 0 + + if use_name: + reserved = {text.alnum(obj.name) for obj in self.container.iterate()} + else: + reserved = {text.alnum(obj.qname) for obj in self.container.iterate()} + + while True: + index += 1 + new_name = f"{name}_{index}" + qname = build_qname(namespace, new_name) + cmp = text.alnum(new_name if use_name else qname) + + if cmp not in reserved: + return qname + + def rename_class_dependencies(self, target: Class, reference: int, replace: str): + """Search and replace the old qualified attribute type name with the + new one if it exists in the target class attributes, extensions and + inner classes.""" + for attr in target.attrs: + self.rename_attr_dependencies(attr, reference, replace) + + for ext in target.extensions: + if ext.type.reference == reference: + ext.type.qname = replace + + for inner in target.inner: + self.rename_class_dependencies(inner, reference, replace) + + def rename_attr_dependencies(self, attr: Attr, reference: int, replace: str): + """Search and replace the old qualified attribute type name with the + new one in the attr types, choices and default value.""" + for attr_type in attr.types: + if attr_type.reference == reference: + attr_type.qname = replace + + if isinstance(attr.default, str) and attr.default.startswith("@enum@"): + members = text.suffix(attr.default, "::") + attr.default = f"@enum@{replace}::{members}" + + for choice in attr.choices: + self.rename_attr_dependencies(choice, reference, replace) diff --git a/xsdata/codegen/mixins.py b/xsdata/codegen/mixins.py index f74ff481e..126f14923 100644 --- a/xsdata/codegen/mixins.py +++ b/xsdata/codegen/mixins.py @@ -36,6 +36,10 @@ def add(self, item: Class): def extend(self, items: List[Class]): """Add a list of classes the container.""" + @abc.abstractmethod + def reset(self, item: Class, qname: str): + """Update the given class qualified name.""" + class HandlerInterface(metaclass=abc.ABCMeta): """Class handler interface.""" diff --git a/xsdata/codegen/models.py b/xsdata/codegen/models.py index 26272f993..f6d8dd098 100644 --- a/xsdata/codegen/models.py +++ b/xsdata/codegen/models.py @@ -380,6 +380,7 @@ class Status(IntEnum): RAW = 0 PROCESSING = 1 PROCESSED = 2 + SANITIZED = 3 @dataclass diff --git a/xsdata/codegen/sanitizer.py b/xsdata/codegen/sanitizer.py deleted file mode 100644 index cc846e0f7..000000000 --- a/xsdata/codegen/sanitizer.py +++ /dev/null @@ -1,370 +0,0 @@ -import operator -from dataclasses import dataclass -from typing import List -from typing import Optional - -from xsdata.codegen.container import ClassContainer -from xsdata.codegen.models import Attr -from xsdata.codegen.models import AttrType -from xsdata.codegen.models import Class -from xsdata.codegen.models import Restrictions -from xsdata.codegen.utils import ClassUtils -from xsdata.logger import logger -from xsdata.models.config import StructureStyle -from xsdata.models.enums import DataType -from xsdata.models.enums import Tag -from xsdata.utils import collections -from xsdata.utils import text -from xsdata.utils.collections import group_by -from xsdata.utils.namespaces import build_qname -from xsdata.utils.namespaces import split_qname -from xsdata.utils.text import alnum - - -@dataclass -class ClassSanitizer: - """Prepare all the classes attributes for code generation and cleanup after - the analyzer processors.""" - - container: ClassContainer - - def process(self): - """Iterate through all classes and run the sanitizer procedure.""" - - for target in self.container.iterate(): - self.process_class(target) - - self.resolve_conflicts() - - def process_class(self, target: Class): - """ - Sanitize the attributes of the given class. After applying all the - flattening handlers the attributes need to be further sanitized to - squash common issues like duplicate attribute names. - - Steps: - 1. Sanitize inner classes - 2. Sanitize attributes default value - 3. Sanitize attributes name - 4. Sanitize attributes sequential flag - 5. Sanitize duplicate attribute names - """ - collections.apply(target.inner, self.process_class) - - if self.container.config.output.compound_fields: - self.group_compound_fields(target) - - for attr in target.attrs: - self.process_attribute_restrictions(attr) - self.process_attribute_default(target, attr) - self.process_attribute_sequence(target, attr) - - self.process_duplicate_attribute_names(target.attrs) - - def group_compound_fields(self, target: Class): - """Group and process target attributes by the choice group.""" - - groups = group_by(target.attrs, lambda x: x.restrictions.choice) - for choice, attrs in groups.items(): - if choice and len(attrs) > 1 and any(attr.is_list for attr in attrs): - self.group_fields(target, attrs) - - def group_fields(self, target: Class, attrs: List[Attr]): - """Group attributes into a new compound field.""" - - pos = target.attrs.index(attrs[0]) - choice = attrs[0].restrictions.choice - sum_occurs = choice and choice.startswith("effective_") - names = [] - choices = [] - min_occurs = [] - max_occurs = [] - for attr in attrs: - target.attrs.remove(attr) - names.append(attr.local_name) - min_occurs.append(attr.restrictions.min_occurs or 0) - max_occurs.append(attr.restrictions.max_occurs or 0) - choices.append(self.build_attr_choice(attr)) - - if len(names) > 3 or len(names) != len(set(names)): - name = "choice" - else: - name = "_Or_".join(names) - - target.attrs.insert( - pos, - Attr( - name=name, - index=0, - types=[AttrType(qname=str(DataType.ANY_TYPE), native=True)], - tag=Tag.CHOICE, - restrictions=Restrictions( - min_occurs=sum(min_occurs) if sum_occurs else min(min_occurs), - max_occurs=sum(max_occurs) if sum_occurs else max(max_occurs), - ), - choices=choices, - ), - ) - - def process_attribute_default(self, target: Class, attr: Attr): - """ - Sanitize attribute default value. - - Cases: - 1. Ignore enumerations. - 2. List fields can not have a fixed value. - 3. Optional fields or xsi:type can not have a default or fixed value. - 4. Convert string literal default value for enum fields. - """ - - if attr.is_enumeration: - return - - if attr.is_optional or attr.is_xsi_type: - attr.fixed = False - attr.default = None - - if attr.default: - self.process_attribute_default_enum(target, attr) - - def process_attribute_default_enum(self, target: Class, attr: Attr): - """ - Convert string literal default value for enum fields. - - Loop through all attributes types and search for enum sources. - If an enum source exist map the default string literal value to - a qualified name. If the source class in inner promote it to - root classes. - """ - - source_found = False - - assert attr.default is not None - - for attr_type in attr.types: - source = self.find_enum(attr_type) - if not source: - continue - - source_found = True - value_members = {x.default: x.name for x in source.attrs} - name = value_members.get(attr.default) - if name: - attr.default = f"@enum@{source.qname}::{name}" - return - - names = [ - value_members[token] - for token in attr.default.split() - if token in value_members - ] - if names: - attr.default = f"@enum@{source.qname}::{'@'.join(names)}" - return - - if source_found: - logger.warning( - "No enumeration member matched %s.%s default value `%s`", - target.name, - attr.local_name, - attr.default, - ) - attr.default = None - - def resolve_conflicts(self): - """Find classes with the same case insensitive qualified name and - rename them.""" - use_name = ( - self.container.config.output.structure == StructureStyle.SINGLE_PACKAGE - ) - getter = operator.attrgetter("name" if use_name else "qname") - groups = group_by(self.container.iterate(), lambda x: alnum(getter(x))) - - for classes in groups.values(): - if len(classes) > 1: - self.rename_classes(classes, use_name) - - def rename_classes(self, classes: List[Class], use_name: bool): - """ - Rename all the classes in the list. - - Protect classes derived from xs:element if there is only one in - the list. - """ - total_elements = sum(x.is_element for x in classes) - for target in classes: - if not target.is_element or total_elements > 1: - self.rename_class(target, use_name) - - def rename_class(self, target: Class, use_name: bool): - """Find the next available class identifier, save the original name in - the class metadata and update the class qualified name and all classes - that depend on the target class.""" - - qname = target.qname - namespace, name = split_qname(target.qname) - target.qname = self.next_qname(namespace, name, use_name) - target.meta_name = name - self.container.reset(target, qname) - - for item in self.container.iterate(): - self.rename_class_dependencies(item, id(target), target.qname) - - def next_qname(self, namespace: str, name: str, use_name: bool) -> str: - """Append the next available index number for the given namespace and - local name.""" - index = 0 - - if use_name: - reserved = {alnum(obj.name) for obj in self.container.iterate()} - else: - reserved = set(map(alnum, self.container.data.keys())) - - while True: - index += 1 - new_name = f"{name}_{index}" - qname = build_qname(namespace, new_name) - cmp = alnum(new_name if use_name else qname) - - if cmp not in reserved: - return qname - - def rename_class_dependencies(self, target: Class, reference: int, replace: str): - """Search and replace the old qualified attribute type name with the - new one if it exists in the target class attributes, extensions and - inner classes.""" - for attr in target.attrs: - self.rename_attr_dependencies(attr, reference, replace) - - for ext in target.extensions: - if ext.type.reference == reference: - ext.type.qname = replace - - for inner in target.inner: - self.rename_class_dependencies(inner, reference, replace) - - def rename_attr_dependencies(self, attr: Attr, reference: int, replace: str): - """Search and replace the old qualified attribute type name with the - new one in the attr types, choices and default value.""" - for attr_type in attr.types: - if attr_type.reference == reference: - attr_type.qname = replace - - if isinstance(attr.default, str) and attr.default.startswith("@enum@"): - members = text.suffix(attr.default, "::") - attr.default = f"@enum@{replace}::{members}" - - for choice in attr.choices: - self.rename_attr_dependencies(choice, reference, replace) - - def find_enum(self, attr_type: AttrType) -> Optional[Class]: - """Find an enumeration class byte the attribute type.""" - if attr_type.native: - return None - - return self.container.find( - attr_type.qname, condition=lambda x: x.is_enumeration - ) - - @classmethod - def process_attribute_restrictions(cls, attr: Attr): - """Sanitize attribute required flag by comparing the min/max - occurrences restrictions.""" - restrictions = attr.restrictions - min_occurs = restrictions.min_occurs or 0 - max_occurs = restrictions.max_occurs or 0 - - if attr.is_attribute: - restrictions.min_occurs = None - restrictions.max_occurs = None - elif attr.is_tokens: - restrictions.required = None - if max_occurs <= 1: - restrictions.min_occurs = None - restrictions.max_occurs = None - elif attr.xml_type is None or min_occurs == max_occurs == 1: - restrictions.required = True - restrictions.min_occurs = None - restrictions.max_occurs = None - elif min_occurs == 0 and max_occurs < 2: - restrictions.required = None - restrictions.min_occurs = None - restrictions.max_occurs = None - attr.default = None - attr.fixed = False - else: # max_occurs > 1 - restrictions.min_occurs = min_occurs - restrictions.required = None - attr.fixed = False - - if attr.default or attr.fixed or attr.restrictions.nillable: - restrictions.required = None - - @classmethod - def process_attribute_sequence(cls, target: Class, attr: Attr): - """Reset the attribute at the given index if it has no siblings with - the sequential restriction.""" - - index = target.attrs.index(attr) - before = target.attrs[index - 1] if index - 1 >= 0 else None - after = target.attrs[index + 1] if index + 1 < len(target.attrs) else None - - if not attr.is_list: - attr.restrictions.sequential = False - - if ( - not attr.restrictions.sequential - or (before and before.restrictions.sequential) - or (after and after.restrictions.sequential and after.is_list) - ): - return - - attr.restrictions.sequential = False - - @classmethod - def process_duplicate_attribute_names(cls, attrs: List[Attr]) -> None: - """Sanitize duplicate attribute names that might exist by applying - rename strategies.""" - grouped = group_by(attrs, lambda attr: alnum(attr.name)) - for items in grouped.values(): - total = len(items) - if total == 2 and not items[0].is_enumeration: - ClassUtils.rename_attribute_by_preference(*items) - elif total > 1: - cls.rename_attributes_with_index(attrs, items) - - @classmethod - def rename_attributes_with_index(cls, attrs: List[Attr], rename: List[Attr]): - """Append the next available index number to all the rename attributes - names.""" - for index in range(1, len(rename)): - num = 1 - name = rename[index].name - - while any(alnum(attr.name) == alnum(f"{name}_{num}") for attr in attrs): - num += 1 - - rename[index].name = f"{name}_{num}" - - @classmethod - def build_attr_choice(cls, attr: Attr) -> Attr: - """ - Converts the given attr to a choice. - - The most important part is the reset of certain restrictions - that don't make sense as choice metadata like occurrences. - """ - restrictions = attr.restrictions.clone() - restrictions.min_occurs = None - restrictions.max_occurs = None - restrictions.sequential = None - - return Attr( - name=attr.local_name, - namespace=attr.namespace, - default=attr.default, - types=attr.types, - tag=attr.tag, - help=attr.help, - restrictions=restrictions, - ) diff --git a/xsdata/codegen/utils.py b/xsdata/codegen/utils.py index 295973381..adbafa1fa 100644 --- a/xsdata/codegen/utils.py +++ b/xsdata/codegen/utils.py @@ -10,6 +10,7 @@ from xsdata.codegen.models import Restrictions from xsdata.exceptions import CodeGenerationError from xsdata.utils import collections +from xsdata.utils import text from xsdata.utils.namespaces import build_qname from xsdata.utils.namespaces import clean_uri from xsdata.utils.namespaces import split_qname @@ -208,3 +209,17 @@ def rename_attribute_by_preference(cls, a: Attr, b: Attr): else: change = b if b.is_attribute else a change.name = f"{change.name}_{change.tag}" + + @classmethod + def rename_attributes_by_index(cls, attrs: List[Attr], rename: List[Attr]): + """Append the next available index number to all the rename attributes + names.""" + for index in range(1, len(rename)): + num = 1 + name = rename[index].name + + reserved = {text.alnum(attr.name) for attr in attrs} + while text.alnum(f"{name}_{num}") in reserved: + num += 1 + + rename[index].name = f"{name}_{num}"