Add select_results() to XPathToken

This commit is contained in:
Davide Brunato 2019-09-30 17:17:05 +02:00
parent 4aeee2bb34
commit 22d9f4258d
6 changed files with 38 additions and 43 deletions

View File

@ -198,7 +198,6 @@ class XPath1Parser(Parser):
def parse(self, source):
root_token = super(XPath1Parser, self).parse(source)
root_token.is_root = True
try:
root_token.evaluate() # Static context evaluation
except MissingContextError:
@ -257,10 +256,6 @@ def select(self, context=None):
for item in context.iter_children_or_self():
xsd_type = self.match_xsd_type(item, name)
if xsd_type is not None:
if self.is_root:
yield item
continue
primitive_type = self.parser.schema.get_primitive_type(xsd_type)
value = XSD_BUILTIN_TYPES[primitive_type.local_name].value
if isinstance(item, AttributeNode):
@ -280,18 +275,12 @@ def select(self, context=None):
for item in context.iter_children_or_self():
try:
if is_attribute_node(item, name):
if self.is_root:
yield self.xsd_type.decode(item[1])
else:
yield TypedAttribute(item, self.xsd_type.decode(item[1]))
yield TypedAttribute(item, self.xsd_type.decode(item[1]))
elif is_element_node(item, tag):
if isinstance(item, TypedElement):
yield item[1] if self.is_root else item
yield item
elif self.xsd_type.is_simple() or self.xsd_type.has_simple_content():
if self.is_root:
yield self.xsd_type.decode(item.text)
else:
yield TypedElement(item, self.xsd_type.decode(item.text))
yield TypedElement(item, self.xsd_type.decode(item.text))
else:
yield item
except (TypeError, ValueError):
@ -644,7 +633,7 @@ def select(self, context=None):
yield item
else:
results = {item for k in range(2) for item in self[k].select(context.copy())}
for item in context.iter_results(results, self.is_root):
for item in context.iter_results(results):
yield item
@ -684,10 +673,7 @@ def select(self, context=None):
elif len(self) == 1:
context.item = None
for result in self[0].select(context):
if isinstance(result, (AttributeNode, TypedAttribute, TypedElement)):
yield result[1] if self.is_root else result
else:
yield result
yield result
else:
items = []
context2 = context.copy()
@ -704,10 +690,10 @@ def select(self, context=None):
elif isinstance(result, (TypedAttribute, TypedElement)):
if result[0] not in items:
items.append(result)
yield result[1] if self.is_root else result
yield result
elif isinstance(result, AttributeNode):
items.append(result)
yield result[1] if self.is_root else result
yield result
else:
items.append(result)
yield result
@ -880,7 +866,7 @@ def select(self, context=None):
for _ in context.iter_attributes():
for result in self[0].select(context):
yield result[1] if self.is_root else result
yield result
@method(axis('namespace'))

View File

@ -339,8 +339,6 @@ class XPath2Parser(XPath1Parser):
def parse(self, source):
root_token = super(XPath1Parser, self).parse(source)
root_token.is_root = True
if self.schema is None:
try:
root_token.evaluate() # Static context evaluation
@ -391,7 +389,7 @@ XPath2Parser.duplicate('|', 'union')
def select(self, context=None):
if context is not None:
results = set(self[0].select(context.copy())) & set(self[1].select(context.copy()))
for item in context.iter_results(results, self.is_root):
for item in context.iter_results(results):
yield item
@ -399,7 +397,7 @@ def select(self, context=None):
def select(self, context=None):
if context is not None:
results = set(self[0].select(context.copy())) - set(self[1].select(context.copy()))
for item in context.iter_results(results, self.is_root):
for item in context.iter_results(results):
yield item

View File

@ -258,27 +258,24 @@ class XPathContext(object):
self.item, self.size, self.position, self.axis = status
def iter_results(self, results, is_root=False):
def iter_results(self, results):
"""Iterates results in document order."""
status = self.item, self.size, self.position
self.item = self.root
for item in self._iter_context():
if item in results:
if is_attribute_node(item):
yield item[1] if is_root else item
else:
yield item
yield item
elif isinstance(item, AttributeNode):
# Match XSD decoded attributes
for attr in filter(lambda x: isinstance(x, TypedAttribute), results):
if attr[0] == item:
yield attr[1] if is_root else attr
yield attr
elif is_etree_element(item):
# Match XSD decoded elements
for elem in filter(lambda x: isinstance(x, TypedElement), results):
if elem[0] is item:
yield elem[1] if is_root else elem
yield elem
self.item, self.size, self.position = status

View File

@ -45,7 +45,7 @@ def iter_select(root, path, namespaces=None, parser=None, **kwargs):
parser = (parser or XPath2Parser)(namespaces, **kwargs)
root_token = parser.parse(path)
context = XPathContext(root)
return root_token.select(context)
return root_token.select_results(context)
class Selector(object):
@ -98,4 +98,4 @@ class Selector(object):
:return: A generator of the XPath expression results.
"""
context = XPathContext(root)
return self.root_token.select(context)
return self.root_token.select_results(context)

View File

@ -27,10 +27,10 @@ from decimal import Decimal
from .compat import string_base_type, unicode_type
from .exceptions import xpath_error
from .namespaces import XQT_ERRORS_NAMESPACE
from .xpath_nodes import AttributeNode, TypedElement, is_etree_element, is_attribute_node, \
elem_iter_strings, is_text_node, is_namespace_node, is_comment_node, \
is_processing_instruction_node, is_element_node, is_document_node, \
is_xpath_node, is_schema_node
from .xpath_nodes import AttributeNode, NamespaceNode, TypedElement, is_etree_element, \
is_attribute_node, elem_iter_strings, is_text_node, is_namespace_node, \
is_comment_node, is_processing_instruction_node, is_element_node, \
is_document_node, is_xpath_node, is_schema_node
from .datatypes import UntypedAtomic, Timezone, DayTimeDuration, XSD_BUILTIN_TYPES
from .tdop_parser import Token
@ -52,8 +52,6 @@ def ordinal(n):
class XPathToken(Token):
"""Base class for XPath tokens."""
is_root = False # Flag that is set to True for root token instances
comment = None # for XPath 2.0+ comments
xsd_type = None # for XPath 2.0+ schema types labeling
@ -279,6 +277,22 @@ class XPathToken(Token):
return [(self.data_value(value1), self.data_value(value2))
for value1 in operand1 for value2 in operand2]
def select_results(self, context):
    """
    Generates the items selected by this token in their formatted result form.

    Items that are not tuples are passed through untouched. For tuple items,
    the first member is yielded when it exposes a `type` attribute, namespace
    nodes are yielded whole, and any other tuple is reduced to its second member.

    :param context: the XPath dynamic context.
    """
    for selected in self.select(context):
        if not isinstance(selected, tuple):
            # not a namedtuple-wrapped result
            yield selected
            continue

        if hasattr(selected[0], 'type'):
            # an XSD schema attribute/element
            yield selected[0]
        elif isinstance(selected, NamespaceNode):
            # namespace nodes are returned as they are
            yield selected
        else:
            # any other wrapped result: keep only its second member
            yield selected[1]
def get_results(self, context):
"""
Returns formatted XPath results.
@ -287,7 +301,7 @@ class XPathToken(Token):
:return: a list or a simple datatype when the result is a single simple type \
generated by a literal or function token.
"""
results = list(self.select(context))
results = [x for x in self.select_results(context)]
if len(results) == 1:
res = results[0]
if isinstance(res, (bool, int, float, Decimal)):

View File

@ -259,7 +259,7 @@ class XPath2ParserTest(test_xpath1_parser.XPath1ParserTest):
self.check_value('($a, $b) = ($c, 2.0)', True, context=context)
root = self.etree.XML('<root min="10" max="7"/>')
self.check_value('@min', ['10'], context=XPathContext(root=root))
self.check_value('@min', [AttributeNode('min', '10')], context=XPathContext(root=root))
self.check_value('@min le @max', True, context=XPathContext(root=root))
root = self.etree.XML('<root min="80" max="7"/>')
self.check_value('@min le @max', False, context=XPathContext(root=root))