Add select_results() to XPathToken
This commit is contained in:
parent
4aeee2bb34
commit
22d9f4258d
|
@ -198,7 +198,6 @@ class XPath1Parser(Parser):
|
|||
|
||||
def parse(self, source):
|
||||
root_token = super(XPath1Parser, self).parse(source)
|
||||
root_token.is_root = True
|
||||
try:
|
||||
root_token.evaluate() # Static context evaluation
|
||||
except MissingContextError:
|
||||
|
@ -257,10 +256,6 @@ def select(self, context=None):
|
|||
for item in context.iter_children_or_self():
|
||||
xsd_type = self.match_xsd_type(item, name)
|
||||
if xsd_type is not None:
|
||||
if self.is_root:
|
||||
yield item
|
||||
continue
|
||||
|
||||
primitive_type = self.parser.schema.get_primitive_type(xsd_type)
|
||||
value = XSD_BUILTIN_TYPES[primitive_type.local_name].value
|
||||
if isinstance(item, AttributeNode):
|
||||
|
@ -280,18 +275,12 @@ def select(self, context=None):
|
|||
for item in context.iter_children_or_self():
|
||||
try:
|
||||
if is_attribute_node(item, name):
|
||||
if self.is_root:
|
||||
yield self.xsd_type.decode(item[1])
|
||||
else:
|
||||
yield TypedAttribute(item, self.xsd_type.decode(item[1]))
|
||||
yield TypedAttribute(item, self.xsd_type.decode(item[1]))
|
||||
elif is_element_node(item, tag):
|
||||
if isinstance(item, TypedElement):
|
||||
yield item[1] if self.is_root else item
|
||||
yield item
|
||||
elif self.xsd_type.is_simple() or self.xsd_type.has_simple_content():
|
||||
if self.is_root:
|
||||
yield self.xsd_type.decode(item.text)
|
||||
else:
|
||||
yield TypedElement(item, self.xsd_type.decode(item.text))
|
||||
yield TypedElement(item, self.xsd_type.decode(item.text))
|
||||
else:
|
||||
yield item
|
||||
except (TypeError, ValueError):
|
||||
|
@ -644,7 +633,7 @@ def select(self, context=None):
|
|||
yield item
|
||||
else:
|
||||
results = {item for k in range(2) for item in self[k].select(context.copy())}
|
||||
for item in context.iter_results(results, self.is_root):
|
||||
for item in context.iter_results(results):
|
||||
yield item
|
||||
|
||||
|
||||
|
@ -684,10 +673,7 @@ def select(self, context=None):
|
|||
elif len(self) == 1:
|
||||
context.item = None
|
||||
for result in self[0].select(context):
|
||||
if isinstance(result, (AttributeNode, TypedAttribute, TypedElement)):
|
||||
yield result[1] if self.is_root else result
|
||||
else:
|
||||
yield result
|
||||
yield result
|
||||
else:
|
||||
items = []
|
||||
context2 = context.copy()
|
||||
|
@ -704,10 +690,10 @@ def select(self, context=None):
|
|||
elif isinstance(result, (TypedAttribute, TypedElement)):
|
||||
if result[0] not in items:
|
||||
items.append(result)
|
||||
yield result[1] if self.is_root else result
|
||||
yield result
|
||||
elif isinstance(result, AttributeNode):
|
||||
items.append(result)
|
||||
yield result[1] if self.is_root else result
|
||||
yield result
|
||||
else:
|
||||
items.append(result)
|
||||
yield result
|
||||
|
@ -880,7 +866,7 @@ def select(self, context=None):
|
|||
|
||||
for _ in context.iter_attributes():
|
||||
for result in self[0].select(context):
|
||||
yield result[1] if self.is_root else result
|
||||
yield result
|
||||
|
||||
|
||||
@method(axis('namespace'))
|
||||
|
|
|
@ -339,8 +339,6 @@ class XPath2Parser(XPath1Parser):
|
|||
|
||||
def parse(self, source):
|
||||
root_token = super(XPath1Parser, self).parse(source)
|
||||
root_token.is_root = True
|
||||
|
||||
if self.schema is None:
|
||||
try:
|
||||
root_token.evaluate() # Static context evaluation
|
||||
|
@ -391,7 +389,7 @@ XPath2Parser.duplicate('|', 'union')
|
|||
def select(self, context=None):
|
||||
if context is not None:
|
||||
results = set(self[0].select(context.copy())) & set(self[1].select(context.copy()))
|
||||
for item in context.iter_results(results, self.is_root):
|
||||
for item in context.iter_results(results):
|
||||
yield item
|
||||
|
||||
|
||||
|
@ -399,7 +397,7 @@ def select(self, context=None):
|
|||
def select(self, context=None):
|
||||
if context is not None:
|
||||
results = set(self[0].select(context.copy())) - set(self[1].select(context.copy()))
|
||||
for item in context.iter_results(results, self.is_root):
|
||||
for item in context.iter_results(results):
|
||||
yield item
|
||||
|
||||
|
||||
|
|
|
@ -258,27 +258,24 @@ class XPathContext(object):
|
|||
|
||||
self.item, self.size, self.position, self.axis = status
|
||||
|
||||
def iter_results(self, results, is_root=False):
|
||||
def iter_results(self, results):
|
||||
"""Iterates results in document order."""
|
||||
status = self.item, self.size, self.position
|
||||
self.item = self.root
|
||||
|
||||
for item in self._iter_context():
|
||||
if item in results:
|
||||
if is_attribute_node(item):
|
||||
yield item[1] if is_root else item
|
||||
else:
|
||||
yield item
|
||||
yield item
|
||||
elif isinstance(item, AttributeNode):
|
||||
# Match XSD decoded attributes
|
||||
for attr in filter(lambda x: isinstance(x, TypedAttribute), results):
|
||||
if attr[0] == item:
|
||||
yield attr[1] if is_root else attr
|
||||
yield attr
|
||||
elif is_etree_element(item):
|
||||
# Match XSD decoded elements
|
||||
for elem in filter(lambda x: isinstance(x, TypedElement), results):
|
||||
if elem[0] is item:
|
||||
yield elem[1] if is_root else elem
|
||||
yield elem
|
||||
|
||||
self.item, self.size, self.position = status
|
||||
|
||||
|
|
|
@ -45,7 +45,7 @@ def iter_select(root, path, namespaces=None, parser=None, **kwargs):
|
|||
parser = (parser or XPath2Parser)(namespaces, **kwargs)
|
||||
root_token = parser.parse(path)
|
||||
context = XPathContext(root)
|
||||
return root_token.select(context)
|
||||
return root_token.select_results(context)
|
||||
|
||||
|
||||
class Selector(object):
|
||||
|
@ -98,4 +98,4 @@ class Selector(object):
|
|||
:return: A generator of the XPath expression results.
|
||||
"""
|
||||
context = XPathContext(root)
|
||||
return self.root_token.select(context)
|
||||
return self.root_token.select_results(context)
|
||||
|
|
|
@ -27,10 +27,10 @@ from decimal import Decimal
|
|||
from .compat import string_base_type, unicode_type
|
||||
from .exceptions import xpath_error
|
||||
from .namespaces import XQT_ERRORS_NAMESPACE
|
||||
from .xpath_nodes import AttributeNode, TypedElement, is_etree_element, is_attribute_node, \
|
||||
elem_iter_strings, is_text_node, is_namespace_node, is_comment_node, \
|
||||
is_processing_instruction_node, is_element_node, is_document_node, \
|
||||
is_xpath_node, is_schema_node
|
||||
from .xpath_nodes import AttributeNode, NamespaceNode, TypedElement, is_etree_element, \
|
||||
is_attribute_node, elem_iter_strings, is_text_node, is_namespace_node, \
|
||||
is_comment_node, is_processing_instruction_node, is_element_node, \
|
||||
is_document_node, is_xpath_node, is_schema_node
|
||||
from .datatypes import UntypedAtomic, Timezone, DayTimeDuration, XSD_BUILTIN_TYPES
|
||||
from .tdop_parser import Token
|
||||
|
||||
|
@ -52,8 +52,6 @@ def ordinal(n):
|
|||
|
||||
class XPathToken(Token):
|
||||
"""Base class for XPath tokens."""
|
||||
|
||||
is_root = False # Flag that is set to True for root token instances
|
||||
comment = None # for XPath 2.0+ comments
|
||||
xsd_type = None # fox XPath 2.0+ schema types labeling
|
||||
|
||||
|
@ -279,6 +277,22 @@ class XPathToken(Token):
|
|||
return [(self.data_value(value1), self.data_value(value2))
|
||||
for value1 in operand1 for value2 in operand2]
|
||||
|
||||
def select_results(self, context):
|
||||
"""
|
||||
Generates formatted XPath results.
|
||||
|
||||
:param context: the XPath dynamic context.
|
||||
"""
|
||||
for result in self.select(context):
|
||||
if not isinstance(result, tuple):
|
||||
yield result # not a namedtuple-wrapped result
|
||||
elif hasattr(result[0], 'type'):
|
||||
yield result[0] # an XSD schema attribute/element
|
||||
elif not isinstance(result, NamespaceNode):
|
||||
yield result[1]
|
||||
else:
|
||||
yield result
|
||||
|
||||
def get_results(self, context):
|
||||
"""
|
||||
Returns formatted XPath results.
|
||||
|
@ -287,7 +301,7 @@ class XPathToken(Token):
|
|||
:return: a list or a simple datatype when the result is a single simple type \
|
||||
generated by a literal or function token.
|
||||
"""
|
||||
results = list(self.select(context))
|
||||
results = [x for x in self.select_results(context)]
|
||||
if len(results) == 1:
|
||||
res = results[0]
|
||||
if isinstance(res, (bool, int, float, Decimal)):
|
||||
|
|
|
@ -259,7 +259,7 @@ class XPath2ParserTest(test_xpath1_parser.XPath1ParserTest):
|
|||
self.check_value('($a, $b) = ($c, 2.0)', True, context=context)
|
||||
|
||||
root = self.etree.XML('<root min="10" max="7"/>')
|
||||
self.check_value('@min', ['10'], context=XPathContext(root=root))
|
||||
self.check_value('@min', [AttributeNode('min', '10')], context=XPathContext(root=root))
|
||||
self.check_value('@min le @max', True, context=XPathContext(root=root))
|
||||
root = self.etree.XML('<root min="80" max="7"/>')
|
||||
self.check_value('@min le @max', False, context=XPathContext(root=root))
|
||||
|
|
Loading…
Reference in New Issue