Add iter_group() to ModelVisitor

This commit is contained in:
Davide Brunato 2019-11-05 11:09:34 +01:00
parent b95d890f51
commit dd2ab72654
4 changed files with 72 additions and 14 deletions

View File

@ -651,6 +651,51 @@ class TestModelValidation11(TestModelValidation):
self.assertIsNone(schema.validate(xml_data))
def test_all_model_with_relaxed_occurs(self):
schema = self.schema_class(
"""<?xml version="1.0" encoding="UTF-8"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="root">
<xs:complexType>
<xs:all>
<xs:element name="a" minOccurs="0" maxOccurs="5"/>
<xs:element name="b" maxOccurs="5"/>
<xs:element name="c" minOccurs="2" maxOccurs="unbounded"/>
<xs:element name="d" />
</xs:all>
</xs:complexType>
</xs:element>
</xs:schema>
""")
xml_data = '<root><a/><b/><d/><c/><a/><c/><c/><a/><a/><b/></root>'
self.assertIsNone(schema.validate(xml_data))
schema = self.schema_class(
"""<?xml version="1.0" encoding="UTF-8"?>
<xs:schema xmlns:xs="http://www.w3.org/2001/XMLSchema">
<xs:element name="root">
<xs:complexType>
<xs:all>
<xs:element name="a" minOccurs="0" maxOccurs="5"/>
<xs:group ref="group1"/>
</xs:all>
</xs:complexType>
</xs:element>
<xs:group name="group1">
<xs:all>
<xs:element name="b" maxOccurs="5"/>
<xs:element name="c" minOccurs="2" maxOccurs="unbounded"/>
<xs:element name="d" />
</xs:all>
</xs:group>
</xs:schema>
""")
self.assertIsNone(schema.validate(xml_data))
class TestModelBasedSorting(XsdValidatorTestCase):

View File

@ -235,7 +235,8 @@ class XsdAttribute(XsdComponent, ValidationMixin):
elif text == self.fixed or validation == 'skip':
pass
elif self.type.text_decode(text) != self.type.text_decode(self.fixed):
yield self.validation_error(validation, "value differs from fixed value", text, **kwargs)
msg = "attribute {!r} has a fixed value {!r}".format(self.name, self.fixed)
yield self.validation_error(validation, msg, text, **kwargs)
for result in self.type.iter_decode(text, validation, **kwargs):
if isinstance(result, XMLSchemaValidationError):

View File

@ -821,11 +821,6 @@ class Xsd11Group(XsdGroup):
Content: (annotation?, (element | any | group)*)
</all>
"""
def __iter__(self):
if self.model == 'sequence':
return iter(self._group)
return iter(sorted(self._group, key=lambda x: isinstance(x, XsdAnyElement)))
def _parse_content_model(self, content_model):
self.model = local_name(content_model.tag)
if self.model == 'all':
@ -855,7 +850,7 @@ class Xsd11Group(XsdGroup):
if ref != self.name:
self.append(Xsd11Group(child, self.schema, self))
if (self.model != 'all') ^ (self[-1].model != 'all'):
msg = "an xs:%s group cannot reference to an x:%s group"
msg = "an xs:%s group cannot include a reference to an x:%s group"
self.parse_error(msg % (self.model, self[-1].model))
self.pop()

View File

@ -338,7 +338,9 @@ class ModelVisitor(MutableSequence):
self.occurs = Counter()
self._subgroups = []
self.element = None
self.group, self.items, self.match = root, iter(root), False
self.group = root
self.items = self.iter_group()
self.match = False
self._start()
def __str__(self):
@ -374,7 +376,9 @@ class ModelVisitor(MutableSequence):
del self._subgroups[:]
self.occurs.clear()
self.element = None
self.group, self.items, self.match = self.root, iter(self.root), False
self.group = self.root
self.items = self.iter_group()
self.match = False
def _start(self):
while True:
@ -421,6 +425,18 @@ class ModelVisitor(MutableSequence):
for e in self.advance():
yield e
def iter_group(self):
if self.group.model != 'all':
for item in self.group:
yield item
elif not self.occurs:
for e in self.group.iter_elements():
yield e
else:
for e in self.group.iter_elements():
if not e.is_over(self.occurs[e]):
yield e
def advance(self, match=False):
"""
Generator function for advance to the next element. Yields tuples with
@ -448,7 +464,7 @@ class ModelVisitor(MutableSequence):
if model == 'choice':
occurs[item] = 0
occurs[self.group] += 1
self.items, self.match = iter(self.group), False
self.items, self.match = self.iter_group(), False
elif model == 'sequence' and item is self.group[-1]:
self.occurs[self.group] += 1
return item.is_missing(item_occurs)
@ -473,7 +489,7 @@ class ModelVisitor(MutableSequence):
occurs[element] += 1
self.match = True
if self.group.model == 'all':
self.items = (e for e in self.group if not e.is_over(occurs[e]))
self.items = (e for e in self.group.iter_elements() if not e.is_over(occurs[e]))
elif not element.is_over(occurs[element]):
return
@ -490,15 +506,16 @@ class ModelVisitor(MutableSequence):
if obj is None:
if not self.match:
if self.group.model == 'all':
if all(e.min_occurs <= occurs[e] for e in self.group):
if all(e.min_occurs <= occurs[e] for e in self.group.iter_elements()):
occurs[self.group] = 1
group, expected = self.group, self.expected
if stop_item(group) and expected:
yield group, occurs[group], expected
elif self.group.model != 'all':
self.items, self.match = iter(self.group), False
self.items, self.match = self.iter_group(), False
elif any(not e.is_over(occurs[e]) for e in self.group):
self.items, self.match = (e for e in self.group if not e.is_over(occurs[e])), False
self.items = self.iter_group()
self.match = False
else:
occurs[self.group] = 1