Add get_operands() instance method to XPathToken class

This commit is contained in:
Davide Brunato 2019-08-08 09:23:09 +02:00
parent 5ae6d0f0ff
commit 852d50c087
4 changed files with 80 additions and 54 deletions

View File

@ -241,6 +241,8 @@ class Token(MutableSequence):
symbol = self.symbol
if symbol == '(name)':
return self.value
elif symbol == '(decimal)':
return str(self.value)
elif SPECIAL_SYMBOL_PATTERN.match(symbol) is not None:
return repr(self.value)
else:

View File

@ -15,7 +15,7 @@ import decimal
from .compat import PY3, string_base_type
from .exceptions import ElementPathSyntaxError, ElementPathNameError, MissingContextError
from .datatypes import UntypedAtomic, DayTimeDuration, YearMonthDuration, \
NumericTypeProxy, XSD_BUILTIN_TYPES
NumericTypeProxy, ArithmeticTypeProxy, XSD_BUILTIN_TYPES
from .xpath_context import XPathSchemaContext
from .tdop_parser import Parser, MultiLabel
from .namespaces import XML_ID, XML_LANG, XPATH_1_DEFAULT_NAMESPACES, \
@ -561,55 +561,59 @@ def evaluate(self, context=None):
return
elif len(self) == 1:
arg = self.get_argument(context, cls=NumericTypeProxy)
if arg is None:
return
try:
return +arg
except TypeError:
raise self.wrong_type("numeric value is required: %r" % arg)
if arg is not None:
try:
return +arg
except TypeError:
raise self.wrong_type("numeric value is required: %r" % arg)
else:
arg1 = self.get_argument(context)
arg2 = self.get_argument(context, index=1)
if arg1 is None or arg2 is None:
return
elif isinstance(arg1, string_base_type):
if isinstance(arg2, string_base_type):
raise self.wrong_type("unsupported operands %r and %r" % (arg1, arg2))
elif isinstance(arg2, NumericTypeProxy):
arg1 = float(arg1)
try:
return self[0].evaluate(context) + self[1].evaluate(context)
except TypeError as err:
raise self.wrong_type(str(err))
op1, op2 = self.get_operands(context, cls=ArithmeticTypeProxy)
if op1 is not None:
try:
return op1 + op2
except TypeError as err:
raise self.wrong_type(str(err))
@method(infix('-', bp=40))
def evaluate(self, context=None):
try:
try:
return self[0].evaluate(context) - self[1].evaluate(context)
except TypeError:
self.wrong_type("values must be numeric: %r" % [tk.evaluate(context) for tk in self])
except IndexError:
try:
return -self[0].evaluate(context)
except TypeError:
self.wrong_type("value must be numeric: %r" % self[0].evaluate(context))
if len(self) == 1:
arg = self.get_argument(context, cls=NumericTypeProxy)
if arg is not None:
try:
return -arg
except TypeError:
raise self.wrong_type("numeric value is required: %r" % arg)
else:
op1, op2 = self.get_operands(context, cls=ArithmeticTypeProxy)
if op1 is not None:
try:
return op1 - op2
except TypeError as err:
raise self.wrong_type(str(err))
@method(infix('*', bp=45))
def evaluate(self, context=None):
if self:
return self[0].evaluate(context) * self[1].evaluate(context)
op1, op2 = self.get_operands(context, cls=ArithmeticTypeProxy)
if op1 is not None:
try:
return op1 * op2
except TypeError as err:
raise self.wrong_type(str(err))
@method(infix('div', bp=45))
def evaluate(self, context=None):
dividend = self[0].evaluate(context)
divisor = self[1].evaluate(context)
if divisor != 0:
return dividend / divisor
dividend, divisor = self.get_operands(context, cls=ArithmeticTypeProxy)
if dividend is None:
return
elif divisor != 0:
try:
return dividend / divisor
except TypeError as err:
raise self.wrong_type(str(err))
elif dividend == 0:
return float('nan')
elif dividend > 0:
@ -620,11 +624,10 @@ def evaluate(self, context=None):
@method(infix('mod', bp=45))
def evaluate(self, context=None):
arg1 = self.get_argument(context, cls=NumericTypeProxy)
arg2 = self.get_argument(context, index=1, cls=NumericTypeProxy)
if arg1 is not None and arg2 is not None:
op1, op2 = self.get_operands(context, cls=NumericTypeProxy)
if op1 is not None:
try:
return arg1 % arg2
return op1 % op2
except TypeError as err:
raise self.wrong_type(str(err))

View File

@ -21,13 +21,14 @@ for documents. Generic tuples are used for representing attributes and named-tup
"""
import locale
import contextlib
from decimal import Decimal
from .compat import string_base_type
from .exceptions import xpath_error
from .namespaces import XQT_ERRORS_NAMESPACE
from .xpath_nodes import AttributeNode, is_etree_element, \
is_element_node, is_document_node, is_xpath_node, node_string_value
from .datatypes import UntypedAtomic, Timezone, DayTimeDuration, NumericTypeProxy, XSD_BUILTIN_TYPES
from .datatypes import UntypedAtomic, Timezone, DayTimeDuration, XSD_BUILTIN_TYPES
from .tdop_parser import Token
@ -166,7 +167,7 @@ class XPathToken(Token):
if self.parser.compatibility_mode:
if issubclass(cls, string_base_type):
return self.string_value(item)
elif issubclass(cls, float) or cls is NumericTypeProxy:
elif issubclass(cls, float) or issubclass(float, cls):
return self.number_value(item)
if self.parser.version > '1.0':
@ -175,16 +176,12 @@ class XPathToken(Token):
return value
elif isinstance(value, UntypedAtomic):
try:
if cls is NumericTypeProxy:
return float(value)
elif issubclass(cls, string_base_type):
if issubclass(cls, string_base_type):
return str(value)
else:
return cls(value)
except (TypeError, ValueError):
pass
elif issubclass(cls, float) and isinstance(value, NumericTypeProxy):
return self.number_value(value)
code = 'XPTY0004' if self.label == 'function' else 'FORG0006'
message = "the %s argument %r is not an instance of %r"
@ -289,6 +286,31 @@ class XPathToken(Token):
else:
return results
def get_operands(self, context, cls=None):
"""
Returns the operands for a binary operator. Float arguments are converted
to decimal if the other argument is a `Decimal` instance.
:param context: the XPath dynamic context.
:param cls: if a type is provided performs a type checking on item.
:return: a couple of values representing the operands. If any operand \
is not available returns a `(None, None)` couple.
"""
arg1 = self.get_argument(context, cls=cls)
if arg1 is None:
return None, None
arg2 = self.get_argument(context, index=1, cls=cls)
if arg2 is None:
return None, None
if isinstance(arg1, Decimal) and isinstance(arg2, float):
return arg1, Decimal(arg2)
elif isinstance(arg2, Decimal) and isinstance(arg1, float):
return Decimal(arg1), arg2
return arg1, arg2
def adjust_datetime(self, context, cls):
"""
XSD datetime adjust function helper.

View File

@ -818,15 +818,14 @@ class XPath1ParserTest(unittest.TestCase):
root = self.etree.XML(XML_DATA_TEST)
if self.parser.version == '1.0':
self.check_value("'9' - 5.0", 4)
self.check_selector("/values/a mod 2", root, [1.4])
self.check_value("/values/b mod 2", float('nan'), context=XPathContext(root))
self.check_value("'9' + 5.0", 14)
self.check_selector("/values/a + 2", root, [5.4])
self.check_value("/values/b + 2", float('nan'), context=XPathContext(root))
else:
self.check_selector("/values/a mod 2", root, TypeError)
self.check_value("/values/b mod 2", TypeError, context=XPathContext(root))
self.check_selector("/values/a + 2", root, TypeError)
self.check_value("/values/b + 2", TypeError, context=XPathContext(root))
self.check_selector("/values/d mod 3", root, [2])
self.check_selector("/values/d + 3", root, [47])
def test_numerical_mod_operator(self):
self.check_value("11 mod 3", 2)