Fix assertions

- Add a custom XPath parser for the assertion facet (without usable position() and last() functions).
  - Move the parser initialization to XsdAssert.parse_xpath_test(), because
    all the schema components must be defined before the parser is created.
This commit is contained in:
Davide Brunato 2019-10-01 06:44:31 +02:00
parent b6c6e2ac8f
commit 844ddec3ba
9 changed files with 96 additions and 62 deletions

View File

@ -2,7 +2,7 @@
setuptools
tox
coverage
elementpath~=1.2.0
elementpath~=1.3.0
lxml
memory_profiler
pathlib2 # For Py27 tests on resources

View File

@ -39,7 +39,7 @@ class InstallCommand(install):
setup(
name='xmlschema',
version='1.0.15',
install_requires=['elementpath~=1.2.0'],
install_requires=['elementpath~=1.3.0'],
packages=['xmlschema'],
include_package_data=True,
cmdclass={

View File

@ -11,7 +11,7 @@ toxworkdir = {homedir}/.tox/xmlschema
[testenv]
deps =
lxml
elementpath~=1.2.0
elementpath~=1.3.0
py27: pathlib2
memory: memory_profiler
docs: Sphinx
@ -25,7 +25,7 @@ whitelist_externals = make
[testenv:py38]
deps =
lxml==4.3.5
elementpath~=1.2.0
elementpath~=1.3.0
[testenv:package]
commands = python xmlschema/tests/test_package.py

View File

@ -45,43 +45,43 @@ class XsdXPathTest(unittest.TestCase):
self.assertTrue(self.xs1.findall('.'))
self.assertTrue(isinstance(self.xs1.find('.'), XMLSchema))
self.assertTrue(sorted(self.xs1.findall("*"), key=lambda x: x.name) == elements)
self.assertTrue(self.xs1.findall("*") == self.xs1.findall("./*"))
self.assertTrue(self.xs1.find("./vh:bikes") == self.xs1.elements['bikes'])
self.assertTrue(self.xs1.find("./vh:vehicles/vh:cars").name == self.xs1.elements['cars'].name)
self.assertFalse(self.xs1.find("./vh:vehicles/vh:cars") == self.xs1.elements['cars'])
self.assertFalse(self.xs1.find("/vh:vehicles/vh:cars") == self.xs1.elements['cars'])
self.assertTrue(self.xs1.find("vh:vehicles/vh:cars/..") == self.xs1.elements['vehicles'])
self.assertTrue(self.xs1.find("vh:vehicles/*/..") == self.xs1.elements['vehicles'])
self.assertTrue(self.xs1.find("vh:vehicles/vh:cars/../vh:cars") == self.xs1.find("vh:vehicles/vh:cars"))
self.assertListEqual(self.xs1.findall("*"), self.xs1.findall("./*"))
self.assertEqual(self.xs1.find("./vh:bikes"), self.xs1.elements['bikes'])
self.assertEqual(self.xs1.find("./vh:vehicles/vh:cars").name, self.xs1.elements['cars'].name)
self.assertNotEqual(self.xs1.find("./vh:vehicles/vh:cars"), self.xs1.elements['cars'])
self.assertNotEqual(self.xs1.find("/vh:vehicles/vh:cars"), self.xs1.elements['cars'])
self.assertEqual(self.xs1.find("vh:vehicles/vh:cars/.."), self.xs1.elements['vehicles'])
self.assertEqual(self.xs1.find("vh:vehicles/*/.."), self.xs1.elements['vehicles'])
self.assertEqual(self.xs1.find("vh:vehicles/vh:cars/../vh:cars"), self.xs1.find("vh:vehicles/vh:cars"))
def test_xpath_axis(self):
self.assertTrue(self.xs1.find("vh:vehicles/child::vh:cars/..") == self.xs1.elements['vehicles'])
self.assertEqual(self.xs1.find("vh:vehicles/child::vh:cars/.."), self.xs1.elements['vehicles'])
def test_xpath_subscription(self):
self.assertTrue(len(self.xs1.findall("./vh:vehicles/*")) == 2)
self.assertTrue(self.xs1.findall("./vh:vehicles/*[2]") == [self.bikes])
self.assertTrue(self.xs1.findall("./vh:vehicles/*[3]") == [])
self.assertTrue(self.xs1.findall("./vh:vehicles/*[last()-1]") == [self.cars])
self.assertTrue(self.xs1.findall("./vh:vehicles/*[position()=last()]") == [self.bikes])
self.assertEqual(len(self.xs1.findall("./vh:vehicles/*")), 2)
self.assertListEqual(self.xs1.findall("./vh:vehicles/*[2]"), [self.bikes])
self.assertListEqual(self.xs1.findall("./vh:vehicles/*[3]"), [])
self.assertListEqual(self.xs1.findall("./vh:vehicles/*[last()-1]"), [self.cars])
self.assertListEqual(self.xs1.findall("./vh:vehicles/*[position()=last()]"), [self.bikes])
def test_xpath_group(self):
self.assertTrue(self.xs1.findall("/(vh:vehicles/*/*)") == self.xs1.findall("/vh:vehicles/*/*"))
self.assertTrue(self.xs1.findall("/(vh:vehicles/*/*)[1]") == self.xs1.findall("/vh:vehicles/*/*[1]"))
self.assertEqual(self.xs1.findall("/(vh:vehicles/*/*)"), self.xs1.findall("/vh:vehicles/*/*"))
self.assertEqual(self.xs1.findall("/(vh:vehicles/*/*)[1]"), self.xs1.findall("/vh:vehicles/*/*[1]")[:1])
def test_xpath_predicate(self):
car = self.xs1.elements['cars'].type.content_type[0]
self.assertTrue(self.xs1.findall("./vh:vehicles/vh:cars/vh:car[@make]") == [car])
self.assertTrue(self.xs1.findall("./vh:vehicles/vh:cars/vh:car[@make]") == [car])
self.assertTrue(self.xs1.findall("./vh:vehicles/vh:cars['ciao']") == [self.cars])
self.assertTrue(self.xs1.findall("./vh:vehicles/*['']") == [])
self.assertListEqual(self.xs1.findall("./vh:vehicles/vh:cars/vh:car[@make]"), [car])
self.assertListEqual(self.xs1.findall("./vh:vehicles/vh:cars/vh:car[@make]"), [car])
self.assertListEqual(self.xs1.findall("./vh:vehicles/vh:cars['ciao']"), [self.cars])
self.assertListEqual(self.xs1.findall("./vh:vehicles/*['']"), [])
def test_xpath_descendants(self):
selector = Selector('.//xs:element', self.xs2.namespaces, parser=XPath1Parser)
elements = list(selector.iter_select(self.xs2.root))
self.assertTrue(len(elements) == 14)
self.assertEqual(len(elements), 14)
selector = Selector('.//xs:element|.//xs:attribute|.//xs:keyref', self.xs2.namespaces, parser=XPath1Parser)
elements = list(selector.iter_select(self.xs2.root))
self.assertTrue(len(elements) == 17)
self.assertEqual(len(elements), 17)
def test_xpath_issues(self):
namespaces = {'ps': "http://schemas.microsoft.com/powershell/2004/04"}

View File

@ -32,48 +32,54 @@ class XsdAssert(XsdComponent, ElementPathMixin):
"""
_ADMITTED_TAGS = {XSD_ASSERT}
token = None
parser = None
path = 'true()'
def __init__(self, elem, schema, parent, base_type):
    """
    Create the assertion component for `elem`.

    :param elem: the element that defines the assertion.
    :param schema: the schema instance passed to the base initializer.
    :param parent: the parent component passed to the base initializer.
    :param base_type: the type definition that this assertion refers to.
    """
    # Stored before delegating to the base initializer; presumably the base
    # initializer triggers _parse(), which reads self.base_type -- TODO confirm.
    self.base_type = base_type
    super(XsdAssert, self).__init__(elem, schema, parent)
def __repr__(self):
    """Return a representation made of the class name and the XPath test expression."""
    return '%s(test=%r)' % (self.__class__.__name__, self.path)
def _parse(self):
super(XsdAssert, self)._parse()
if self.base_type.is_complex():
if self.base_type.is_simple():
self.parse_error("base_type=%r is not a complexType definition" % self.base_type)
else:
try:
self.path = self.elem.attrib['test']
except KeyError as err:
self.parse_error(str(err), elem=self.elem)
self.path = 'true()'
if not self.base_type.has_simple_content():
variables = {'value': datatypes.XSD_BUILTIN_TYPES['anyType'].value}
else:
try:
builtin_type_name = self.base_type.content_type.primitive_type.local_name
except AttributeError:
variables = {'value': datatypes.XSD_BUILTIN_TYPES['anySimpleType'].value}
else:
variables = {'value': datatypes.XSD_BUILTIN_TYPES[builtin_type_name].value}
else:
self.parse_error("base_type=%r is not a complexType definition" % self.base_type)
self.path = 'true()'
variables = None
if 'xpathDefaultNamespace' in self.elem.attrib:
self.xpath_default_namespace = self._parse_xpath_default_namespace(self.elem)
else:
self.xpath_default_namespace = self.schema.xpath_default_namespace
self.parser = XPath2Parser(self.namespaces, variables, False,
self.xpath_default_namespace, schema=self.xpath_proxy)
self.xpath_proxy = XMLSchemaProxy(self.schema, self)
@property
def built(self):
    """
    True when the test expression has been parsed into a token and the
    base type either has no parent or reports itself as built.
    """
    return self.token is not None and (self.base_type.parent is None or self.base_type.built)
def parse_xpath_test(self):
self.parser.schema = XMLSchemaProxy(self.schema, self)
if self.base_type.has_simple_content():
variables = {'value': datatypes.XSD_BUILTIN_TYPES['anyType'].value}
elif self.base_type.is_complex():
try:
builtin_type_name = self.base_type.content_type.primitive_type.local_name
except AttributeError:
variables = {'value': datatypes.XSD_BUILTIN_TYPES['anySimpleType'].value}
else:
variables = {'value': datatypes.XSD_BUILTIN_TYPES[builtin_type_name].value}
else:
variables = None
self.parser = XPath2Parser(
self.namespaces, variables, False, self.xpath_default_namespace, schema=self.xpath_proxy
)
try:
self.token = self.parser.parse(self.path)
except ElementPathError as err:
@ -81,10 +87,16 @@ class XsdAssert(XsdComponent, ElementPathMixin):
self.token = self.parser.parse('true()')
def __call__(self, elem, value=None, source=None, **kwargs):
self.parser.variables['value'] = value
root = elem if source is None else source.root
if value is not None:
self.parser.variables['value'] = self.base_type.text_decode(value)
if source is None:
context = XPathContext(root=elem)
else:
context = XPathContext(root=source.root, item=elem)
try:
if not self.token.evaluate(XPathContext(root=root, item=elem)):
if not self.token.evaluate(context.copy()):
msg = "expression is not true with test path %r."
yield XMLSchemaValidationError(self, obj=elem, reason=msg % self.path)
except ElementPathError as err:

View File

@ -546,7 +546,7 @@ class XsdComplexType(XsdType, ValidationMixin):
for obj in self.base_type.iter_components(xsd_classes):
yield obj
for obj in self.assertions:
for obj in filter(lambda x: x.base_type is self, self.assertions):
if xsd_classes is None or isinstance(obj, xsd_classes):
yield obj
@ -857,7 +857,7 @@ class Xsd11ComplexType(XsdComplexType):
def _parse_content_tail(self, elem, **kwargs):
self.attributes = self.schema.BUILDERS.attribute_group_class(elem, self.schema, self, **kwargs)
self.assertions = []
for child in filter(lambda x: x.tag != XSD_ANNOTATION, elem):
if child.tag == XSD_ASSERT:
self.assertions.append(XsdAssert(child, self.schema, self, self))
self.assertions = [XsdAssert(e, self.schema, self, self) for e in elem if e.tag == XSD_ASSERT]
if getattr(self.base_type, 'assertions', None):
self.assertions.extend(assertion for assertion in self.base_type.assertions)

View File

@ -643,6 +643,25 @@ class XsdPatternFacets(MutableSequence, XsdFacet):
return [e.get('value', '') for e in self._elements]
class XsdAssertionXPathParser(XPath2Parser):
    """Parser for XSD 1.1 assertion facets."""


# Drop the inherited definitions of last() and position() so that they can be
# registered again below with assertion-specific behaviour.
XsdAssertionXPathParser.unregister('last')
XsdAssertionXPathParser.unregister('position')


# NOTE: both functions below must be named `evaluate`; presumably the `method`
# decorator binds the function to the token class under its own name -- TODO
# confirm against the elementpath documentation.
@XsdAssertionXPathParser.method(XsdAssertionXPathParser.function('last', nargs=0))
def evaluate(self, context=None):
    # last() takes no arguments and always reports a missing context
    # (presumably missing_context() raises an XPath error -- TODO confirm).
    self.missing_context("Context item size is undefined")


@XsdAssertionXPathParser.method(XsdAssertionXPathParser.function('position', nargs=0))
def evaluate(self, context=None):
    # position() takes no arguments and always reports a missing context.
    self.missing_context("Context item position is undefined")


# Presumably rebuilds the tokenizer so that it matches the modified symbol
# table -- TODO confirm.
XsdAssertionXPathParser.build_tokenizer()
class XsdAssertionFacet(XsdFacet):
"""
XSD 1.1 *assertion* facet for simpleType definitions.
@ -678,8 +697,8 @@ class XsdAssertionFacet(XsdFacet):
self.xpath_default_namespace = self._parse_xpath_default_namespace(self.elem)
else:
self.xpath_default_namespace = self.schema.xpath_default_namespace
self.parser = XPath2Parser(self.namespaces, strict=False, variables=variables,
default_namespace=self.xpath_default_namespace)
self.parser = XsdAssertionXPathParser(self.namespaces, strict=False, variables=variables,
default_namespace=self.xpath_default_namespace)
try:
self.token = self.parser.parse(self.path)

View File

@ -814,7 +814,7 @@ class XMLSchemaBase(XsdValidator, ValidationMixin, ElementPathMixin):
def get_element(self, tag, path=None, namespaces=None):
if not path:
return self.find(tag)
return self.find(tag, namespaces)
elif path[-1] == '*':
return self.find(path[:-1] + tag, namespaces)
else:
@ -1185,7 +1185,7 @@ class XMLSchemaBase(XsdValidator, ValidationMixin, ElementPathMixin):
schema_path = '/%s/*' % source.root.tag
for elem in source.iterfind(path, namespaces):
xsd_element = self.get_element(elem.tag, schema_path, namespaces)
xsd_element = self.get_element(elem.tag, schema_path, self.namespaces)
if xsd_element is None:
yield self.validation_error('lax', "%r is not an element of the schema" % elem, elem)

View File

@ -213,7 +213,7 @@ class ElementPathMixin(Sequence):
default_namespace=self.xpath_default_namespace)
root_token = parser.parse(path)
context = XMLSchemaContext(self)
return root_token.select(context)
return root_token.select_results(context)
def find(self, path, namespaces=None):
"""
@ -226,14 +226,17 @@ class ElementPathMixin(Sequence):
path = path.strip()
if path.startswith('/') and not path.startswith('//'):
path = ''.join(['/', XSD_SCHEMA, path])
if namespaces is None:
namespaces = {k: v for k, v in self.namespaces.items() if k}
namespaces[''] = self.xpath_default_namespace
elif '' not in namespaces:
namespaces[''] = self.xpath_default_namespace
parser = XPath2Parser(namespaces, strict=False, schema=self.xpath_proxy,
default_namespace=self.xpath_default_namespace)
parser = XPath2Parser(namespaces, strict=False, schema=self.xpath_proxy)
root_token = parser.parse(path)
context = XMLSchemaContext(self)
return next(root_token.select(context), None)
return next(root_token.select_results(context), None)
def findall(self, path, namespaces=None):
"""