Skip to content

Commit

Permalink
Clean occurs checking and add helper methods to ModelVisitor
Browse files Browse the repository at this point in the history
  • Loading branch information
brunato committed Aug 18, 2024
1 parent 1049929 commit fbe4622
Show file tree
Hide file tree
Showing 6 changed files with 233 additions and 153 deletions.
8 changes: 4 additions & 4 deletions tests/validators/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -909,7 +909,7 @@ def test_model_visitor_copy(self):
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="root">
<xs:complexType>
<xs:sequence minOccurs="0" maxOccurs="unbounded">
<xs:sequence minOccurs="0" maxOccurs="unbounded">
<xs:group ref="group1" minOccurs="2" maxOccurs="unbounded"/>
<xs:group ref="group2" minOccurs="0" maxOccurs="unbounded"/>
<xs:group ref="group3" maxOccurs="unbounded"/>
Expand Down Expand Up @@ -1001,7 +1001,7 @@ def test_model_visitor_copy_nested(self):
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="root">
<xs:complexType>
<xs:sequence>
<xs:sequence>
<xs:element name="a1"/>
<xs:group ref="group1" maxOccurs="unbounded"/>
</xs:sequence>
Expand Down Expand Up @@ -1054,7 +1054,7 @@ def test_model_visitor_copy_nested(self):

def test_stoppable_property(self):
schema = self.schema_class(dedent(
"""<?xml version="1.0" encoding="UTF-8"?>
"""<?xml version="1.0" encoding="UTF-8"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="root">
<xs:complexType>
Expand Down Expand Up @@ -1082,7 +1082,7 @@ def test_stoppable_property(self):

def test_occurs_check_methods_for_elements(self):
schema = self.schema_class(dedent(
"""<?xml version="1.0" encoding="UTF-8"?>
"""<?xml version="1.0" encoding="UTF-8"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="root">
<xs:complexType>
Expand Down
39 changes: 30 additions & 9 deletions tests/validators/test_particles.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
#
import os
import unittest
from collections import Counter
from xml.etree import ElementTree

from xmlschema import XMLSchema10, XMLSchemaParseError
Expand Down Expand Up @@ -65,16 +66,36 @@ def test_is_univocal(self):
self.assertTrue(self.schema.elements['cars'].is_univocal())
self.assertFalse(self.schema.elements['cars'].type.content[0].is_univocal())

def test_is_missing(self):
self.assertTrue(self.schema.elements['cars'].is_missing(0))
self.assertFalse(self.schema.elements['cars'].is_missing(1))
self.assertFalse(self.schema.elements['cars'].is_missing(2))
self.assertFalse(self.schema.elements['cars'].type.content[0].is_missing(0))
def test_occurs_checkers(self):
xsd_element = self.schema.elements['cars']

def test_is_over(self):
self.assertFalse(self.schema.elements['cars'].is_over(0))
self.assertTrue(self.schema.elements['cars'].is_over(1))
self.assertFalse(self.schema.elements['cars'].type.content[0].is_over(1000))
occurs = Counter()
self.assertTrue(xsd_element.is_missing(occurs))
self.assertFalse(xsd_element.is_over(occurs))
self.assertFalse(xsd_element.is_exceeded(occurs))

occurs[xsd_element] += 1
self.assertFalse(xsd_element.is_missing(occurs))
self.assertTrue(xsd_element.is_over(occurs))
self.assertFalse(xsd_element.is_exceeded(occurs))

occurs[xsd_element] += 1
self.assertFalse(xsd_element.is_missing(occurs))
self.assertTrue(xsd_element.is_over(occurs))
self.assertTrue(xsd_element.is_exceeded(occurs))

xsd_element = self.schema.elements['cars'].type.content[0] # car
self.assertTrue(xsd_element.min_occurs == 0)
self.assertTrue(xsd_element.max_occurs is None)

self.assertFalse(xsd_element.is_missing(occurs))
self.assertFalse(xsd_element.is_over(occurs))
self.assertFalse(xsd_element.is_exceeded(occurs))

occurs[xsd_element] += 1000
self.assertFalse(xsd_element.is_missing(occurs))
self.assertFalse(xsd_element.is_over(occurs))
self.assertFalse(xsd_element.is_exceeded(occurs))

def test_has_occurs_restriction(self):
schema = XMLSchema10("""<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
Expand Down
126 changes: 33 additions & 93 deletions xmlschema/validators/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,63 +368,18 @@ def iter_elements(self) -> Iterator[SchemaElementType]:
except IndexError:
return

def is_optional(self, particle: ModelParticleType) -> bool:
"""
Returns `True` if a particle can be optional in the model. This a raw check,
because the optionality can depend on the presence and the position of other
elements, that can be checked only using a `ModelVisitor` instance. Raises an
`XMLSchemaValueError` if the particle is not part of the model group.
"""
if self.max_occurs == 0:
raise XMLSchemaValueError("the model group is empty")

groups = [self]
iterators: List[Iterator[ModelParticleType]] = []
particles = iter(self)

while True:
for item in particles:
if item is particle:
if item.min_occurs == 0:
return True

for group in reversed(groups):
if group.min_occurs == 0:
return True
elif group.model == 'choice' and len(group) > 1:
return True
else:
return False

if isinstance(item, XsdGroup):
if item.max_occurs == 0:
continue

groups.append(item)
iterators.append(particles)
particles = iter(item.content)
if len(iterators) > limits.MAX_MODEL_DEPTH:
raise XMLSchemaModelDepthError(self)
break
else:
try:
groups.pop()
particles = iterators.pop()
except IndexError:
msg = "The provided particle is not part of the model group"
raise XMLSchemaValueError(msg)

def get_subgroups(self, item: ModelParticleType) -> List['XsdGroup']:
def get_subgroups(self, particle: ModelParticleType) -> List['XsdGroup']:
"""
Returns a list of the groups that represent the path to the enclosed particle.
Raises an `XMLSchemaModelError` if *item* is not a particle of the model group.
Raises an `XMLSchemaModelError` if the argument is not a particle of the model
group.
"""
subgroups: List[Tuple[XsdGroup, Iterator[ModelParticleType]]] = []
group, children = self, iter(self)

while True:
for child in children:
if child is item:
if child is particle:
_subgroups = [x[0] for x in subgroups]
_subgroups.append(group)
return _subgroups
Expand All @@ -439,41 +394,38 @@ def get_subgroups(self, item: ModelParticleType) -> List['XsdGroup']:
group, children = subgroups.pop()
except IndexError:
msg = _('{!r} is not a particle of the model group')
raise XMLSchemaModelError(self, msg.format(item)) from None
raise XMLSchemaModelError(self, msg.format(particle)) from None

def overall_min_occurs(self, item: ModelParticleType) -> int:
"""Returns the overall min occurs of a particle in the model."""
min_occurs = item.min_occurs

for group in self.get_subgroups(item):
if group.model == 'choice' and len(group) > 1:
return 0
min_occurs *= group.min_occurs

return min_occurs

def overall_max_occurs(self, item: ModelParticleType) -> Optional[int]:
"""Returns the overall max occurs of a particle in the model."""
max_occurs = item.max_occurs
def get_model_visitor(self) -> ModelVisitor:
if self.open_content is None or self.open_content.mode == 'none':
return ModelVisitor(self)
elif self.open_content.mode == 'interleave':
return InterleavedModelVisitor(self, self.open_content.any_element)
else:
return SuffixedModelVisitor(self, self.open_content.any_element)

for group in self.get_subgroups(item):
if max_occurs == 0:
return 0
elif max_occurs is None:
continue
elif group.max_occurs is None:
max_occurs = None
else:
max_occurs *= group.max_occurs
def overall_min_occurs(self, particle: ModelParticleType) -> int:
"""
Returns the overall min occurs of a particle in the model group.
"""
model = self.get_model_visitor()
return model.overall_min_occurs(particle)

return max_occurs
def overall_max_occurs(self, particle: ModelParticleType) -> Optional[int]:
"""
Returns the overall max occurs of a particle in the model group.
"""
model = self.get_model_visitor()
return model.overall_max_occurs(particle)

def is_missing(self, occurs: Union[OccursCounterType, int]) -> bool:
try:
value = occurs[self.oid] or occurs[self] # type: ignore[index]
except TypeError:
value = occurs
def is_optional(self, particle: ModelParticleType) -> bool:
"""
Returns `True` if the provided particle can be optional in the model group.
"""
return self.overall_min_occurs(particle) == 0

def is_missing(self, occurs: OccursCounterType) -> bool:
value = occurs[self.oid] or occurs[self]
return not self.is_emptiable() if value == 0 else self.min_occurs > value

def get_expected(self, occurs: OccursCounterType) -> List[SchemaElementType]:
Expand Down Expand Up @@ -1042,16 +994,10 @@ def iter_decode(self, obj: ElementType, validation: str = 'lax', **kwargs: Any)
xsd_element: Optional[SchemaElementType]
expected: Optional[List[SchemaElementType]]

if self.open_content is None or self.open_content.mode == 'none':
model = ModelVisitor(self)
elif self.open_content.mode == 'interleave':
model = InterleavedModelVisitor(self, self.open_content.any_element)
else:
model = SuffixedModelVisitor(self, self.open_content.any_element)

errors = []
broken_model = False
namespaces = converter.namespaces
model = self.get_model_visitor()

for index, child in enumerate(obj):
if callable(child.tag):
Expand Down Expand Up @@ -1168,16 +1114,10 @@ def iter_encode(self, obj: ElementData, validation: str = 'lax', **kwargs: Any)
padding = '\n' + ' ' * converter.indent * level
default_namespace = converter.get('')

if self.open_content is None or self.open_content.mode == 'none':
model = ModelVisitor(self)
elif self.open_content.mode == 'interleave':
model = InterleavedModelVisitor(self, self.open_content.any_element)
else:
model = SuffixedModelVisitor(self, self.open_content.any_element)

index = cdata_index = 0
wrong_content_type = False
over_max_depth = 'max_depth' in kwargs and kwargs['max_depth'] <= level
model = self.get_model_visitor()

content: Iterable[Any]
if not obj.content:
Expand Down
Loading

0 comments on commit fbe4622

Please sign in to comment.