From 852d50c0878bb883200bba088dfa76fb36295d43 Mon Sep 17 00:00:00 2001 From: Davide Brunato Date: Thu, 8 Aug 2019 09:23:09 +0200 Subject: [PATCH] Add get_operands() instance method to XPathToken class --- elementpath/tdop_parser.py | 2 + elementpath/xpath1_parser.py | 83 +++++++++++++++++++----------------- elementpath/xpath_token.py | 36 +++++++++++++--- tests/test_xpath1_parser.py | 13 +++--- 4 files changed, 80 insertions(+), 54 deletions(-) diff --git a/elementpath/tdop_parser.py b/elementpath/tdop_parser.py index ba1c366..530efd6 100644 --- a/elementpath/tdop_parser.py +++ b/elementpath/tdop_parser.py @@ -241,6 +241,8 @@ class Token(MutableSequence): symbol = self.symbol if symbol == '(name)': return self.value + elif symbol == '(decimal)': + return str(self.value) elif SPECIAL_SYMBOL_PATTERN.match(symbol) is not None: return repr(self.value) else: diff --git a/elementpath/xpath1_parser.py b/elementpath/xpath1_parser.py index 79bba69..7dad23d 100644 --- a/elementpath/xpath1_parser.py +++ b/elementpath/xpath1_parser.py @@ -15,7 +15,7 @@ import decimal from .compat import PY3, string_base_type from .exceptions import ElementPathSyntaxError, ElementPathNameError, MissingContextError from .datatypes import UntypedAtomic, DayTimeDuration, YearMonthDuration, \ - NumericTypeProxy, XSD_BUILTIN_TYPES + NumericTypeProxy, ArithmeticTypeProxy, XSD_BUILTIN_TYPES from .xpath_context import XPathSchemaContext from .tdop_parser import Parser, MultiLabel from .namespaces import XML_ID, XML_LANG, XPATH_1_DEFAULT_NAMESPACES, \ @@ -561,55 +561,59 @@ def evaluate(self, context=None): return elif len(self) == 1: arg = self.get_argument(context, cls=NumericTypeProxy) - if arg is None: - return - try: - return +arg - except TypeError: - raise self.wrong_type("numeric value is required: %r" % arg) + if arg is not None: + try: + return +arg + except TypeError: + raise self.wrong_type("numeric value is required: %r" % arg) else: - arg1 = self.get_argument(context) - arg2 = self.get_argument(context, index=1) - if arg1 is None or arg2 is None: - return - elif isinstance(arg1, string_base_type): - if isinstance(arg2, string_base_type): - raise self.wrong_type("unsupported operands %r and %r" % (arg1, arg2)) - elif isinstance(arg2, NumericTypeProxy): - arg1 = float(arg1) - - try: - return self[0].evaluate(context) + self[1].evaluate(context) - except TypeError as err: - raise self.wrong_type(str(err)) + op1, op2 = self.get_operands(context, cls=ArithmeticTypeProxy) + if op1 is not None: + try: + return op1 + op2 + except TypeError as err: + raise self.wrong_type(str(err)) @method(infix('-', bp=40)) def evaluate(self, context=None): - try: - try: - return self[0].evaluate(context) - self[1].evaluate(context) - except TypeError: - self.wrong_type("values must be numeric: %r" % [tk.evaluate(context) for tk in self]) - except IndexError: - try: - return -self[0].evaluate(context) - except TypeError: - self.wrong_type("value must be numeric: %r" % self[0].evaluate(context)) + if len(self) == 1: + arg = self.get_argument(context, cls=NumericTypeProxy) + if arg is not None: + try: + return -arg + except TypeError: + raise self.wrong_type("numeric value is required: %r" % arg) + else: + op1, op2 = self.get_operands(context, cls=ArithmeticTypeProxy) + if op1 is not None: + try: + return op1 - op2 + except TypeError as err: + raise self.wrong_type(str(err)) @method(infix('*', bp=45)) def evaluate(self, context=None): if self: - return self[0].evaluate(context) * self[1].evaluate(context) + op1, op2 = self.get_operands(context, cls=ArithmeticTypeProxy) + if op1 is not None: + try: + return op1 * op2 + except TypeError as err: + raise self.wrong_type(str(err)) @method(infix('div', bp=45)) def evaluate(self, context=None): - dividend = self[0].evaluate(context) - divisor = self[1].evaluate(context) - if divisor != 0: - return dividend / divisor + dividend, divisor = self.get_operands(context, cls=ArithmeticTypeProxy) + if dividend is None: + return + elif divisor != 0: + try: + return dividend / divisor + except TypeError as err: + raise self.wrong_type(str(err)) elif dividend == 0: return float('nan') elif dividend > 0: @@ -620,11 +624,10 @@ def evaluate(self, context=None): @method(infix('mod', bp=45)) def evaluate(self, context=None): - arg1 = self.get_argument(context, cls=NumericTypeProxy) - arg2 = self.get_argument(context, index=1, cls=NumericTypeProxy) - if arg1 is not None and arg2 is not None: + op1, op2 = self.get_operands(context, cls=NumericTypeProxy) + if op1 is not None: try: - return arg1 % arg2 + return op1 % op2 except TypeError as err: raise self.wrong_type(str(err)) diff --git a/elementpath/xpath_token.py b/elementpath/xpath_token.py index 26ee019..aebb2b0 100644 --- a/elementpath/xpath_token.py +++ b/elementpath/xpath_token.py @@ -21,13 +21,14 @@ for documents. Generic tuples are used for representing attributes and named-tup """ import locale import contextlib +from decimal import Decimal from .compat import string_base_type from .exceptions import xpath_error from .namespaces import XQT_ERRORS_NAMESPACE from .xpath_nodes import AttributeNode, is_etree_element, \ is_element_node, is_document_node, is_xpath_node, node_string_value -from .datatypes import UntypedAtomic, Timezone, DayTimeDuration, NumericTypeProxy, XSD_BUILTIN_TYPES +from .datatypes import UntypedAtomic, Timezone, DayTimeDuration, XSD_BUILTIN_TYPES from .tdop_parser import Token @@ -166,7 +167,7 @@ class XPathToken(Token): if self.parser.compatibility_mode: if issubclass(cls, string_base_type): return self.string_value(item) - elif issubclass(cls, float) or cls is NumericTypeProxy: + elif issubclass(cls, float) or issubclass(float, cls): return self.number_value(item) if self.parser.version > '1.0': @@ -175,16 +176,12 @@ class XPathToken(Token): return value elif isinstance(value, UntypedAtomic): try: - if cls is NumericTypeProxy: - return float(value) - elif issubclass(cls, string_base_type): + if issubclass(cls, string_base_type): return str(value) else: return cls(value) except (TypeError, ValueError): pass - elif issubclass(cls, float) and isinstance(value, NumericTypeProxy): - return self.number_value(value) code = 'XPTY0004' if self.label == 'function' else 'FORG0006' message = "the %s argument %r is not an instance of %r" @@ -289,6 +286,31 @@ class XPathToken(Token): else: return results + def get_operands(self, context, cls=None): + """ + Returns the operands for a binary operator. Float arguments are converted + to decimal if the other argument is a `Decimal` instance. + + :param context: the XPath dynamic context. + :param cls: if a type is provided performs a type checking on item. + :return: a couple of values representing the operands. If any operand \ + is not available returns a `(None, None)` couple. + """ + arg1 = self.get_argument(context, cls=cls) + if arg1 is None: + return None, None + + arg2 = self.get_argument(context, index=1, cls=cls) + if arg2 is None: + return None, None + + if isinstance(arg1, Decimal) and isinstance(arg2, float): + return arg1, Decimal(arg2) + elif isinstance(arg2, Decimal) and isinstance(arg1, float): + return Decimal(arg1), arg2 + + return arg1, arg2 + def adjust_datetime(self, context, cls): """ XSD datetime adjust function helper. diff --git a/tests/test_xpath1_parser.py b/tests/test_xpath1_parser.py index b0980d1..9f7e4f9 100644 --- a/tests/test_xpath1_parser.py +++ b/tests/test_xpath1_parser.py @@ -818,15 +818,14 @@ class XPath1ParserTest(unittest.TestCase): root = self.etree.XML(XML_DATA_TEST) if self.parser.version == '1.0': - self.check_value("'9' - 5.0", 4) - - self.check_selector("/values/a mod 2", root, [1.4]) - self.check_value("/values/b mod 2", float('nan'), context=XPathContext(root)) + self.check_value("'9' + 5.0", 14) + self.check_selector("/values/a + 2", root, [5.4]) + self.check_value("/values/b + 2", float('nan'), context=XPathContext(root)) else: - self.check_selector("/values/a mod 2", root, TypeError) - self.check_value("/values/b mod 2", TypeError, context=XPathContext(root)) + self.check_selector("/values/a + 2", root, TypeError) + self.check_value("/values/b + 2", TypeError, context=XPathContext(root)) - self.check_selector("/values/d mod 3", root, [2]) + self.check_selector("/values/d + 3", root, [47]) def test_numerical_mod_operator(self): self.check_value("11 mod 3", 2)