Skip to content

Commit

Permalink
Guard against multiple substitution group runs
Browse files Browse the repository at this point in the history
  • Loading branch information
tefra committed Jun 23, 2021
1 parent e83e202 commit 5361381
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 1 deletion.
4 changes: 4 additions & 0 deletions tests/codegen/handlers/test_attribute_substitution.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,10 @@ def test_process_attribute(self, mock_find):

self.assertEqual(4, len(target.attrs))

# Guard against multiple runs in case of xs:groups
self.processor.process_attribute(target, first_attr)
self.assertEqual(4, len(target.attrs))

self.assertEqual(reference_attrs[0], target.attrs[0])
self.assertIsNot(reference_attrs[0], target.attrs[0])
self.assertEqual(reference_attrs[1], target.attrs[3])
Expand Down
8 changes: 7 additions & 1 deletion xsdata/codegen/handlers/attribute_substitution.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
class AttributeSubstitutionHandler(RelativeHandlerInterface):
"""Apply substitution attributes to the given class recursively."""

__slots__ = ("substitutions",)
__slots__ = "substitutions"

def __init__(self, container: ContainerInterface):
super().__init__(container)
Expand Down Expand Up @@ -43,11 +43,17 @@ def process_attribute(self, target: Class, attr: Attr):
The cloned attributes are placed below the attribute the are
supposed to substitute.
Guard against multiple substitutions in case of xs:groups.
"""
index = target.attrs.index(attr)
assert self.substitutions is not None

for attr_type in attr.types:
if attr_type.substituted:
continue

attr_type.substituted = True
for substitution in self.substitutions.get(attr_type.qname, []):
clone = ClassUtils.clone_attribute(substitution, attr.restrictions)

Expand Down
2 changes: 2 additions & 0 deletions xsdata/codegen/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,6 +190,7 @@ class AttrType:
:param native:
:param forward:
:param circular:
:param substituted:
"""

qname: str
Expand All @@ -198,6 +199,7 @@ class AttrType:
native: bool = field(default=False)
forward: bool = field(default=False)
circular: bool = field(default=False)
substituted: bool = field(default=False, compare=False)

@property
def datatype(self) -> Optional[DataType]:
Expand Down

0 comments on commit 5361381

Please sign in to comment.