Update XML resource iterfind() to fix issues #102 and #112

- Speed up admitting simple paths and checking only elements
    that match path level
  - Avoid selection for * paths (about 35% faster)
  - Add close() method to XmlResource
This commit is contained in:
Davide Brunato 2019-10-19 19:31:43 +02:00
parent 43322b6bc0
commit 8dd5d193ba
2 changed files with 200 additions and 18 deletions

View File

@ -11,7 +11,7 @@
import os.path
import re
import codecs
from elementpath import iter_select, Selector
from elementpath import iter_select, Selector, XPath1Parser
from .compat import (
PY3, StringIO, BytesIO, string_base_type, urlopen, urlsplit, urljoin, urlunsplit,
@ -26,8 +26,23 @@ from .etree import ElementTree, PyElementTree, SafeXMLParser, etree_tostring
DEFUSE_MODES = ('always', 'remote', 'never')
XML_RESOURCE_XPATH_SYMBOLS = {
'position', 'last', 'not', 'and', 'or', '!=', '<=', '>=', '(', ')', 'text',
'[', ']', '.', ',', '/', '|', '*', '=', '<', '>', ':', '(end)', '(name)',
'(string)', '(float)', '(decimal)', '(integer)'
}
class XmlResourceXPathParser(XPath1Parser):
symbol_table = {k: v for k, v in XPath1Parser.symbol_table.items() if k in XML_RESOURCE_XPATH_SYMBOLS}
SYMBOLS = XML_RESOURCE_XPATH_SYMBOLS
XmlResourceXPathParser.build_tokenizer()
def is_remote_url(url):
return url is not None and urlsplit(url).scheme not in ('', 'file')
return isinstance(url, string_base_type) and urlsplit(url).scheme not in ('', 'file')
def url_path_is_directory(url):
@ -424,14 +439,23 @@ class XMLResource(object):
def parse(self, source):
"""
An equivalent of *ElementTree.parse()* that can protect from XML entities attacks. When
protection is applied XML data are loaded and defused before building the ElementTree instance.
An equivalent of *ElementTree.parse()* that can protect from XML entities attacks.
When protection is applied XML data are loaded and defused before building the
ElementTree instance. The protection applied is based on value of *defuse*
attribute and *base_url* property.
:param source: a filename or file object containing XML data.
:returns: an ElementTree instance.
"""
if self.defuse == 'always' or self.defuse == 'remote' and is_remote_url(self._url):
text = source.read()
if self.defuse == 'always' or self.defuse == 'remote' and \
hasattr(source, 'read') and is_remote_url(self.base_url):
if hasattr(source, 'read'):
text = source.read()
else:
with open(source) as f:
text = f.read()
if isinstance(text, bytes):
self.defusing(BytesIO(text))
return ElementTree.parse(BytesIO(text))
@ -445,11 +469,14 @@ class XMLResource(object):
"""
An equivalent of *ElementTree.iterparse()* that can protect from XML entities attacks.
When protection is applied the iterator yields pure-Python Element instances.
The protection applied is based on resource *defuse* attribute and *base_url* property.
:param source: a filename or file object containing XML data.
:param events: a list of events to report back. If omitted, only end events are reported.
"""
if self.defuse == 'always' or self.defuse == 'remote' and is_remote_url(self._url):
if self.defuse == 'always' or self.defuse == 'remote' and \
hasattr(source, 'read') and is_remote_url(self.base_url):
parser = SafeXMLParser(target=PyElementTree.TreeBuilder())
try:
return PyElementTree.iterparse(source, events, parser)
@ -461,17 +488,20 @@ class XMLResource(object):
def fromstring(self, text):
"""
An equivalent of *ElementTree.fromstring()* that can protect from XML entities attacks.
The protection applied is based on resource *defuse* attribute and *base_url* property.
:param text: a string containing XML data.
:returns: the root Element instance.
"""
if self.defuse == 'always' or self.defuse == 'remote' and is_remote_url(self._url):
if self.defuse == 'always' or self.defuse == 'remote' and is_remote_url(self.base_url):
self.defusing(StringIO(text))
return ElementTree.fromstring(text)
def tostring(self, indent='', max_lines=None, spaces_for_tab=4, xml_declaration=False):
"""Generates a string representation of the XML resource."""
return etree_tostring(self._root, self.get_namespaces(), indent, max_lines, spaces_for_tab, xml_declaration)
elem = self._root
namespaces = self.get_namespaces()
return etree_tostring(elem, namespaces, indent, max_lines, spaces_for_tab, xml_declaration)
def copy(self, **kwargs):
"""Resource copy method. Change init parameters with keyword arguments."""
@ -502,6 +532,10 @@ class XMLResource(object):
raise XMLSchemaURLError(reason="cannot access to resource %r: %s" % (self._url, err.reason))
def seek(self, position):
"""
Change stream position if the XML resource was created with a seekable
file-like object. In the other cases this method has no effect.
"""
if not hasattr(self.source, 'read'):
return
@ -523,6 +557,16 @@ class XMLResource(object):
except AttributeError:
pass
def close(self):
"""
Close the XML resource if it's created with a file-like object.
In other cases this method has no effect.
"""
try:
self.source.close()
except (AttributeError, TypeError):
pass
def load(self):
"""
Loads the XML text from the data source. If the data source is an Element
@ -619,7 +663,11 @@ class XMLResource(object):
yield elem
elem.clear()
else:
selector = Selector(path, namespaces, strict=False)
selector = Selector(path, namespaces, strict=False, parser=XmlResourceXPathParser)
path.replace(' ', '').replace('./', '')
path_level = path.count('/') + 1
select_all = '*' in path and set(path).issubset({'*', '/'})
level = 0
for event, elem in self.iterparse(resource, events=('start', 'end')):
if event == "start":
@ -629,7 +677,8 @@ class XMLResource(object):
level += 1
else:
level -= 1
if elem in selector.select(self._root):
if level == path_level and \
(select_all or elem in selector.select(self._root)):
yield elem
elem.clear()
elif level == 0:

View File

@ -13,12 +13,14 @@
This module runs tests concerning resources.
"""
import unittest
import time
import os
import platform
try:
from pathlib import PureWindowsPath, PurePath
except ImportError:
# noinspection PyPackageRequirements
from pathlib2 import PureWindowsPath, PurePath
from xmlschema import (
@ -29,6 +31,7 @@ from xmlschema.tests import SKIP_REMOTE_TESTS, casepath
from xmlschema.compat import urlopen, urlsplit, uses_relative, StringIO
from xmlschema.etree import ElementTree, PyElementTree, lxml_etree, \
etree_element, py_etree_element
from xmlschema.namespaces import XSD_NAMESPACE
from xmlschema.helpers import is_etree_element
@ -344,14 +347,36 @@ class TestResources(unittest.TestCase):
resource.load()
self.assertTrue(resource.is_loaded())
def test_xml_resource_open(self):
def test_xml_resource_parse(self):
resource = XMLResource(self.vh_xml_file)
xml_file = resource.open()
data = xml_file.read().decode('utf-8')
self.assertTrue(data.startswith('<?xml '))
xml_file.close()
resource = XMLResource('<A/>')
self.assertRaises(ValueError, resource.open)
self.assertEqual(resource.defuse, 'remote')
xml_document = resource.parse(self.col_xml_file)
self.assertTrue(is_etree_element(xml_document.getroot()))
resource.defuse = 'always'
xml_document = resource.parse(self.col_xml_file)
self.assertTrue(is_etree_element(xml_document.getroot()))
def test_xml_resource_iterparse(self):
resource = XMLResource(self.vh_xml_file)
self.assertEqual(resource.defuse, 'remote')
for _, elem in resource.iterparse(self.col_xml_file, events=('end',)):
self.assertTrue(is_etree_element(elem))
resource.defuse = 'always'
for _, elem in resource.iterparse(self.col_xml_file, events=('end',)):
self.assertTrue(is_etree_element(elem))
def test_xml_resource_fromstring(self):
resource = XMLResource(self.vh_xml_file)
self.assertEqual(resource.defuse, 'remote')
self.assertEqual(resource.fromstring('<root/>').tag, 'root')
resource.defuse = 'always'
self.assertEqual(resource.fromstring('<root/>').tag, 'root')
def test_xml_resource_tostring(self):
resource = XMLResource(self.vh_xml_file)
@ -373,6 +398,114 @@ class TestResources(unittest.TestCase):
resource2 = resource.copy()
self.assertEqual(resource.text, resource2.text)
def test_xml_resource_open(self):
resource = XMLResource(self.vh_xml_file)
xml_file = resource.open()
self.assertIsNot(xml_file, resource.source)
data = xml_file.read().decode('utf-8')
self.assertTrue(data.startswith('<?xml '))
xml_file.close()
resource = XMLResource('<A/>')
self.assertRaises(ValueError, resource.open)
resource = XMLResource(source=open(self.vh_xml_file))
xml_file = resource.open()
self.assertIs(xml_file, resource.source)
xml_file.close()
def test_xml_resource_seek(self):
resource = XMLResource(self.vh_xml_file)
self.assertIsNone(resource.seek(0))
self.assertIsNone(resource.seek(1))
xml_file = open(self.vh_xml_file)
resource = XMLResource(source=xml_file)
self.assertEqual(resource.seek(0), 0)
self.assertEqual(resource.seek(1), 1)
xml_file.close()
def test_xml_resource_close(self):
resource = XMLResource(self.vh_xml_file)
resource.close()
xml_file = resource.open()
self.assertTrue(callable(xml_file.read))
xml_file = open(self.vh_xml_file)
resource = XMLResource(source=xml_file)
resource.close()
with self.assertRaises(ValueError):
resource.open()
def test_xml_resource_iter(self):
for lazy in (False, True):
resource = XMLResource(self.schema_class.meta_schema.source.url, lazy=lazy)
k = 0
for k, _ in enumerate(resource.iter()):
pass
self.assertEqual(k, 1389)
k = 0
for k, _ in enumerate(resource.iter('{%s}complexType' % XSD_NAMESPACE)):
pass
self.assertEqual(k, 55)
def test_xml_resource_iterfind(self):
resource = XMLResource(self.schema_class.meta_schema.source.url, lazy=False)
self.assertFalse(resource.is_lazy())
start_time = time.time()
for _ in range(10):
for _ in resource.iterfind():
pass
t1 = time.time() - start_time
start_time = time.time()
for _ in range(10):
for _ in resource.iterfind(path='.'):
pass
t2 = time.time() - start_time
self.assertLessEqual(t1, t2 / 30.0)
self.assertGreaterEqual(t1, t2 / 100.0)
start_time = time.time()
counter = 0
for _ in resource.iterfind(path='*'):
counter += 1
t3 = time.time() - start_time
self.assertGreaterEqual(t2, t3 / counter * 10)
resource = XMLResource(self.schema_class.meta_schema.source.url)
self.assertTrue(resource.is_lazy())
start_time = time.time()
for _ in range(10):
for _ in resource.iterfind():
pass
tl1 = time.time() - start_time
self.assertLessEqual(t1, tl1 / 1000.0)
self.assertGreaterEqual(t1, tl1 / 10000.0)
start_time = time.time()
for _ in range(10):
for _ in resource.iterfind(path='.'):
pass
tl2 = time.time() - start_time
self.assertLessEqual(t2, tl2 / 80.0)
self.assertGreaterEqual(t2, tl2 / 1000.0)
start_time = time.time()
counter3 = 0
for _ in resource.iterfind(path='*'):
counter3 += 1
tl3 = time.time() - start_time
self.assertGreaterEqual(tl2, tl3 / counter3 * 10)
start_time = time.time()
for _ in resource.iterfind(path='. /. / xs:complexType', namespaces={'xs': XSD_NAMESPACE}):
pass
tl4 = time.time() - start_time
self.assertTrue(0.7 < (tl3 / tl4) < 1)
def test_xml_resource_get_namespaces(self):
with open(self.vh_xml_file) as schema_file:
resource = XMLResource(schema_file)