Skip to content

Commit

Permalink
Add test for ModelVisitor and fix XsdGroup.get_subgroups()
Browse files Browse the repository at this point in the history
  • Loading branch information
brunato committed Aug 19, 2024
1 parent fbe4622 commit 82ff2d4
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 14 deletions.
226 changes: 223 additions & 3 deletions tests/validators/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

from xmlschema import XMLSchema10, XMLSchema11
from xmlschema.exceptions import XMLSchemaValueError
from xmlschema.validators.exceptions import XMLSchemaValidationError
from xmlschema.validators.exceptions import XMLSchemaValidationError, XMLSchemaModelError
from xmlschema.validators.particles import ParticleMixin
from xmlschema.validators.models import distinguishable_paths, ModelVisitor, \
sort_content, iter_collapsed_content
Expand Down Expand Up @@ -68,7 +68,7 @@ def check_advance_true(self, model, expected=None):
def check_advance_false(self, model, expected=None):
"""
Advances a model with a no-match condition and checks the
expected error list or or exception.
expected error list or exception.
:param model: an ModelGroupVisitor instance.
:param expected: can be an exception class or a list. Leaving `None` means that \
Expand Down Expand Up @@ -1080,7 +1080,7 @@ def test_stoppable_property(self):
self.check_advance_true(model) # <b> matching
self.assertTrue(model.stoppable)

def test_occurs_check_methods_for_elements(self):
def test_particle_occurs_check_methods(self):
schema = self.schema_class(dedent(
"""<?xml version="1.0" encoding="UTF-8"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
Expand Down Expand Up @@ -1131,6 +1131,226 @@ def test_occurs_check_methods_for_elements(self):
self.assertRaises(ValueError, model.is_over)
self.assertRaises(ValueError, model.is_exceeded)

def test_get_model_particle(self):
schema = self.schema_class(dedent(
"""<?xml version="1.0" encoding="UTF-8"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="root">
<xs:complexType>
<xs:choice>
<xs:group ref="top"/>
<xs:element name="c" minOccurs="1"/>
</xs:choice>
</xs:complexType>
</xs:element>
<xs:element name="b"/>
<xs:group name="top">
<xs:sequence>
<xs:element name="a" minOccurs="0"/>
<xs:element ref="b" minOccurs="0" maxOccurs="2"/>
</xs:sequence>
</xs:group>
</xs:schema>
"""))

group = schema.elements['root'].type.content
top, c = group[:]
a, b = schema.groups['top']

model = ModelVisitor(group)
self.assertIs(model.get_model_particle(a), a)
self.assertIs(model.get_model_particle(b), b)
self.assertIs(model.get_model_particle(c), c)
self.assertIs(model.get_model_particle(top), top)

# Global model groups head declaration doesn't belong to any concrete model
with self.assertRaises(XMLSchemaModelError) as ctx:
model.get_model_particle(b.ref)
self.assertIn("not a particle of the model group", str(ctx.exception))

with self.assertRaises(XMLSchemaModelError) as ctx:
model.get_model_particle(top.ref)
self.assertIn("not a particle of the model group", str(ctx.exception))

self.assertIs(model.get_model_particle(), model.element)
self.assertListEqual(list(model.stop()), [])

with self.assertRaises(XMLSchemaValueError) as ctx:
model.get_model_particle()
self.assertIn("can't defaults to current element", str(ctx.exception))

def test_model_occurs_check_methods(self):
schema = self.schema_class(dedent(
"""<?xml version="1.0" encoding="UTF-8"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="root">
<xs:complexType>
<xs:sequence maxOccurs="25">
<xs:element name="a" minOccurs="0"/>
<xs:element name="b" maxOccurs="2"/>
<xs:element name="c" minOccurs="4" maxOccurs="unbounded"/>
</xs:sequence>
</xs:complexType>
</xs:element>
</xs:schema>
"""))

group = schema.elements['root'].type.content
a, b, c = group[:]

model = ModelVisitor(group)
self.assertEqual(model.overall_min_occurs(a), 0)
self.assertEqual(model.overall_min_occurs(b), 1)
self.assertEqual(model.overall_min_occurs(c), 4)

self.assertEqual(model.overall_max_occurs(a), 25)
self.assertEqual(model.overall_max_occurs(b), 50)
self.assertIsNone(model.overall_max_occurs(c))

self.assertTrue(model.is_optional(a))
self.assertFalse(model.is_optional(b))
self.assertFalse(model.is_optional(c))

self.assertIs(model.element, a)
self.assertListEqual(list(model.advance(True)), [])
self.assertIs(model.element, b)
self.assertListEqual(list(model.advance(True)), [])
self.assertIs(model.element, b)
self.assertListEqual(list(model.advance(False)), [])
self.assertIs(model.element, c)
self.assertListEqual(list(model.advance(True)), [])

self.assertEqual(model.overall_min_occurs(a), 0)
self.assertEqual(model.overall_min_occurs(b), 0)
self.assertEqual(model.overall_min_occurs(c), 3)

self.assertEqual(model.overall_max_occurs(a), 24)
self.assertEqual(model.overall_max_occurs(b), 49)
self.assertIsNone(model.overall_max_occurs(c))

self.assertTrue(model.is_optional(a))
self.assertTrue(model.is_optional(b))
self.assertFalse(model.is_optional(c))

schema = self.schema_class(dedent(
"""<?xml version="1.0" encoding="UTF-8"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="root">
<xs:complexType>
<xs:choice maxOccurs="10">
<xs:group ref="top" maxOccurs="25"/>
<xs:element name="d" minOccurs="1"/>
</xs:choice>
</xs:complexType>
</xs:element>
<xs:group name="top">
<xs:sequence>
<xs:element name="a" minOccurs="0"/>
<xs:element name="b" maxOccurs="2"/>
<xs:element name="c" minOccurs="4" maxOccurs="unbounded"/>
</xs:sequence>
</xs:group>
</xs:schema>
"""))

group = schema.elements['root'].type.content
top, d = group[:]
a, b, c = schema.groups['top']

model = ModelVisitor(group)
self.assertEqual(model.overall_min_occurs(a), 0)
self.assertEqual(model.overall_min_occurs(b), 0)
self.assertEqual(model.overall_min_occurs(c), 0)
self.assertEqual(model.overall_min_occurs(top), 0)
self.assertEqual(model.overall_min_occurs(d), 0)

self.assertEqual(model.overall_max_occurs(a), 250)
self.assertEqual(model.overall_max_occurs(b), 500)
self.assertIsNone(model.overall_max_occurs(c))
self.assertEqual(model.overall_max_occurs(top), 250)
self.assertEqual(model.overall_max_occurs(d), 10)

self.assertIs(model.element, a)
self.assertListEqual(list(model.advance(False)), [])
self.assertIs(model.element, b)
self.assertListEqual(list(model.advance_until('d')), [])
self.assertIs(model.element, a)

self.assertEqual(model.overall_min_occurs(a), 0)
self.assertEqual(model.overall_min_occurs(b), 0)
self.assertEqual(model.overall_min_occurs(c), 0)
self.assertEqual(model.overall_min_occurs(top), 0)
self.assertEqual(model.overall_min_occurs(d), 0)

self.assertEqual(model.overall_max_occurs(a), 225)
self.assertEqual(model.overall_max_occurs(b), 450)
self.assertIsNone(model.overall_max_occurs(c))
self.assertEqual(model.overall_max_occurs(top), 225)
self.assertEqual(model.overall_max_occurs(d), 9)

def test_check_following(self):
schema = self.schema_class(dedent(
"""<?xml version="1.0" encoding="UTF-8"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="root">
<xs:complexType>
<xs:sequence>
<xs:element name="a" minOccurs="0"/>
<xs:element name="b" minOccurs="3" maxOccurs="8"/>
<xs:element name="c" minOccurs="2" maxOccurs="unbounded"/>
<xs:element name="d"/>
</xs:sequence>
</xs:complexType>
</xs:element>
</xs:schema>
"""))

group = schema.elements['root'].type.content
a, b, c, d = group

model = ModelVisitor(group)
self.assertTrue(model.check_following(a.name))
self.assertTrue(model.check_following(b.name))
self.assertTrue(model.check_following(a.name, b.name))
self.assertFalse(model.check_following(c.name))
self.assertFalse(model.check_following(d.name))

def test_advance_smart_methods(self):
schema = self.schema_class(dedent(
"""<?xml version="1.0" encoding="UTF-8"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="root">
<xs:complexType>
<xs:sequence>
<xs:element name="a" minOccurs="0"/>
<xs:element name="b" minOccurs="3" maxOccurs="8"/>
<xs:element name="c" minOccurs="2" maxOccurs="unbounded"/>
<xs:element name="d"/>
</xs:sequence>
</xs:complexType>
</xs:element>
</xs:schema>
"""))

group = schema.elements['root'].type.content
a, b, c, d = group

model = group.get_model_visitor()
self.assertIs(model.element, a)
self.assertFalse(model.advance_safe(c.name))
self.assertIs(model.element, a)
self.assertTrue(model.advance_safe(a.name, b.name, b.name, b.name, c.name))
self.assertIs(model.element, c)

model = group.get_model_visitor()
self.assertIs(model.element, a)
self.assertTrue(list(model.advance_until(c.name)))
self.assertIs(model.element, c)

model.restart()
self.assertIs(model.element, a)
self.assertListEqual(list(model.advance_until(b.name)), [])


class TestModelValidation11(TestModelValidation):
schema_class = XMLSchema11
Expand Down
4 changes: 2 additions & 2 deletions xmlschema/validators/groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,7 @@ def get_subgroups(self, particle: ModelParticleType) -> List['XsdGroup']:
group.
"""
subgroups: List[Tuple[XsdGroup, Iterator[ModelParticleType]]] = []
group, children = self, iter(self)
group, children = self, iter(self if self.ref is None else self.ref)

while True:
for child in children:
Expand All @@ -387,7 +387,7 @@ def get_subgroups(self, particle: ModelParticleType) -> List['XsdGroup']:
if len(subgroups) > limits.MAX_MODEL_DEPTH:
raise XMLSchemaModelDepthError(self)
subgroups.append((group, children))
group, children = child, iter(child)
group, children = child, iter(child if child.ref is None else child.ref)
break
else:
try:
Expand Down
18 changes: 9 additions & 9 deletions xmlschema/validators/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,7 @@ def get_model_particle(self, particle: Optional[ModelParticleType] = None) \
is ended.
"""
if particle is not None:
for _group in self.group.get_subgroups(particle):
for _group in self.root.get_subgroups(particle):
break
return particle
elif self.element is not None:
Expand All @@ -563,13 +563,13 @@ def overall_min_occurs(self, particle: Optional[ModelParticleType] = None) -> in
"""
particle = self.get_model_particle(particle)
min_occurs = 1
for group in self.group.get_subgroups(particle):
for group in self.root.get_subgroups(particle):
group_min_occurs = group.min_occurs - self.occurs[group]
if group_min_occurs <= 0 or group.model == 'choice' and len(group) > 1:
return 0
min_occurs *= group_min_occurs

return min_occurs * particle.min_occurs - self.occurs[particle]
return max(0, min_occurs * particle.min_occurs - self.occurs[particle])

def overall_max_occurs(self, particle: Optional[ModelParticleType] = None) -> Optional[int]:
"""
Expand All @@ -580,7 +580,7 @@ def overall_max_occurs(self, particle: Optional[ModelParticleType] = None) -> Op
particle = self.get_model_particle(particle)
max_occurs: Optional[int] = 1

for group in self.group.get_subgroups(particle):
for group in self.root.get_subgroups(particle):
group_max_occurs = group.max_occurs
if group_max_occurs == 0:
return 0
Expand Down Expand Up @@ -632,9 +632,9 @@ def is_exceeded(self, particle: Optional[ModelParticleType] = None) -> bool:

def advance_until(self, tag: str) -> Iterator[AdvanceYieldedType]:
"""
Advance until an element that matches `tag` is found. Stops after
an error in advancing. If the model ends before the tag is found
raise an `XMLSchemaValueError`.
Advances until an element matching `tag` is found. Stops after
an error in advancing. If the model ends before the tag is found,
it throws an `XMLSchemaValueError`.
"""
_err: Optional[AdvanceYieldedType] = None
while True:
Expand All @@ -649,7 +649,7 @@ def advance_until(self, tag: str) -> Iterator[AdvanceYieldedType]:
for _err in self.advance(False):
yield _err

def check_followings(self, *tags: str) -> bool:
def check_following(self, *tags: str) -> bool:
"""
Returns `True` if the model can be advanced without errors adding
the provided sequence of elements, represented by their tags.
Expand All @@ -673,7 +673,7 @@ def advance_safe(self, *tags: str) -> bool:
produce errors or the ending of the model. Returns `True` if the advance has
been done, `False` otherwise.
"""
if not self.check_followings(*tags):
if not self.check_following(*tags):
return False

for tag in tags:
Expand Down

0 comments on commit 82ff2d4

Please sign in to comment.