diff --git a/elementpath/xpath1_parser.py b/elementpath/xpath1_parser.py index 61da0d4..95a77aa 100644 --- a/elementpath/xpath1_parser.py +++ b/elementpath/xpath1_parser.py @@ -280,14 +280,13 @@ def select(self, context=None): for item in context.iter_children_or_self(): if is_attribute_node(item, name) or is_element_node(item, tag): path = context.get_path(item) + xsd_component = self.parser.schema.find(path, self.parser.namespaces) - - # print(path, xsd_component) - if xsd_component is not None: self.xsd_type = xsd_component.type else: self.xsd_type = self.parser.schema + yield self.get_typed_node(context, item) else: # XSD typed selection diff --git a/elementpath/xpath_context.py b/elementpath/xpath_context.py index 6e3fffe..ede8eea 100644 --- a/elementpath/xpath_context.py +++ b/elementpath/xpath_context.py @@ -91,6 +91,8 @@ class XPathContext(object): that are not included in the tree. Uses a LRU cache to minimize parent map rebuilding for trees processed with an incremental parser. """ + if isinstance(elem, TypedElement): + elem = elem[0] if elem is self.root: return diff --git a/elementpath/xpath_nodes.py b/elementpath/xpath_nodes.py index a588e2c..310ed6c 100644 --- a/elementpath/xpath_nodes.py +++ b/elementpath/xpath_nodes.py @@ -131,11 +131,11 @@ def is_schema_node(obj): def is_comment_node(obj): - return is_etree_element(obj) and callable(obj.tag) and obj.tag.__name__ == 'Comment' + return hasattr(obj, 'tag') and callable(obj.tag) and obj.tag.__name__ == 'Comment' def is_processing_instruction_node(obj): - return is_etree_element(obj) and callable(obj.tag) and obj.tag.__name__ == 'ProcessingInstruction' + return hasattr(obj, 'tag') and callable(obj.tag) and obj.tag.__name__ == 'ProcessingInstruction' def is_document_node(obj): @@ -155,8 +155,8 @@ else: def is_xpath_node(obj): - return isinstance(obj, tuple) or is_etree_element(obj) or \ - is_document_node(obj) or is_text_node(obj) or is_schema_node(obj) + return isinstance(obj, tuple) or is_etree_element(obj) or is_schema_node(obj) or \ + is_document_node(obj) or is_text_node(obj) ### diff --git a/elementpath/xpath_token.py b/elementpath/xpath_token.py index c2b2e37..beb18df 100644 --- a/elementpath/xpath_token.py +++ b/elementpath/xpath_token.py @@ -27,7 +27,7 @@ from decimal import Decimal from .compat import string_base_type, unicode_type from .exceptions import xpath_error from .namespaces import XQT_ERRORS_NAMESPACE -from .xpath_nodes import AttributeNode, NamespaceNode, TypedAttribute, TypedElement, \ +from .xpath_nodes import AttributeNode, TypedAttribute, TypedElement, \ is_etree_element, is_attribute_node, elem_iter_strings, is_text_node, \ is_namespace_node, is_comment_node, is_processing_instruction_node, \ is_element_node, is_document_node, is_xpath_node, is_schema_node @@ -291,8 +291,10 @@ class XPathToken(Token): for result in self.select(context): if isinstance(result, TypedElement): yield result[0] - elif isinstance(result, (AttributeNode, TypedAttribute)): + elif isinstance(result, AttributeNode): yield result[1] + elif isinstance(result, TypedAttribute): + yield result[0][1] if hasattr(result[0][1], 'type') else result[1] else: yield result diff --git a/tests/test_elementpath.py b/tests/test_elementpath.py index eb48490..8080449 100644 --- a/tests/test_elementpath.py +++ b/tests/test_elementpath.py @@ -38,7 +38,7 @@ if __name__ == '__main__': except ImportError: # Python 2 fallback from test_exceptions import ExceptionsTest - from test_namespaces import NamespacessTest + from test_namespaces import NamespacesTest from test_datatypes import UntypedAtomicTest, DateTimeTypesTest, DurationTypesTest, TimezoneTypeTest from test_xpath_nodes import XPathNodesTest from test_xpath_token import XPathTokenTest diff --git a/tests/test_xpath2_parser.py b/tests/test_xpath2_parser.py index 6917c6c..3cad794 100644 --- a/tests/test_xpath2_parser.py +++ b/tests/test_xpath2_parser.py @@ -658,12 +658,18 @@ class XPath2ParserTest(test_xpath1_parser.XPath1ParserTest): ' ' ' ' '') + + namespaces = {'p0': 'ns0', 'p2': 'ns2'} + prefixes = select(root, "fn:in-scope-prefixes(.)", namespaces, parser=self.parser.__class__) + if self.etree is lxml_etree: - prefixes = {'p0', 'p1'} + self.assertIn('p0', prefixes) + self.assertIn('p1', prefixes) + self.assertNotIn('p2', prefixes) else: - prefixes = {'p0', 'p2', 'fn', 'xlink', 'err', 'vc', 'xslt', '', 'hfp'} - prefixes |= {x for x in self.etree._namespace_map.values()} - self.check_selector("fn:in-scope-prefixes(.)", root, prefixes, namespaces={'p0': 'ns0', 'p2': 'ns2'}) + self.assertIn('p0', prefixes) + self.assertNotIn('p1', prefixes) + self.assertIn('p2', prefixes) def test_string_constructors(self): self.check_value("xs:string(5.0)", '5.0') diff --git a/tests/test_xpath_context.py b/tests/test_xpath_context.py index 74d37de..c6e2d44 100644 --- a/tests/test_xpath_context.py +++ b/tests/test_xpath_context.py @@ -13,6 +13,7 @@ import unittest import xml.etree.ElementTree as ElementTree from elementpath import * +from elementpath.compat import PY3 class XPathContextTest(unittest.TestCase): @@ -48,7 +49,35 @@ class XPathContextTest(unittest.TestCase): root[2]: root, root[2][0]: root[2], root[2][1]: root[2] }) - def test_path(self): + def test_get_parent(self): + root = ElementTree.XML('') + + context = XPathContext(root) + + self.assertIsNone(context._parent_map) + self.assertIsNone(context.get_parent(root)) + + self.assertIsNone(context._parent_map) + self.assertEqual(context.get_parent(root[0]), root) + self.assertIsInstance(context._parent_map, dict) + parent_map_id = id(context._parent_map) + + self.assertEqual(context.get_parent(root[1]), root) + self.assertEqual(context.get_parent(root[2]), root) + self.assertEqual(context.get_parent(root[2][1]), root[2]) + + self.assertEqual(context.get_parent(TypedElement(root[2][1], None)), root[2]) + self.assertEqual(id(context._parent_map), parent_map_id) + + self.assertIsNone(context.get_parent(AttributeNode('max', '10'))) + self.assertNotEqual(id(context._parent_map), parent_map_id) + + parent_map_id = id(context._parent_map) + self.assertIsNone(context.get_parent(AttributeNode('max', '10'))) + if PY3: + self.assertEqual(id(context._parent_map), parent_map_id) # LRU cache prevents parent map rebuild + + def test_get_path(self): root = ElementTree.XML('') context = XPathContext(root) diff --git a/tests/test_xpath_token.py b/tests/test_xpath_token.py index 99e414e..3d63804 100644 --- a/tests/test_xpath_token.py +++ b/tests/test_xpath_token.py @@ -14,6 +14,7 @@ import unittest import io import math import xml.etree.ElementTree as ElementTree +from collections import namedtuple from elementpath.namespaces import XSD_NAMESPACE from elementpath.xpath_nodes import AttributeNode, TypedAttribute, TypedElement, NamespaceNode @@ -61,6 +62,10 @@ class XPathTokenTest(unittest.TestCase): context = XPathContext(elem, item=TypedAttribute(AttributeNode('max', '30'), 30)) self.assertListEqual(list(token.select_results(context)), [30]) + attribute = namedtuple('XsdAttribute', 'name type')('max', 'xs:string') + context = XPathContext(elem, item=TypedAttribute(AttributeNode('max', attribute), 30)) + self.assertListEqual(list(token.select_results(context)), [attribute]) + context = XPathContext(elem, item=10) self.assertListEqual(list(token.select_results(context)), [10])