# -*- coding: utf-8 -*- # # Copyright (c), 2018-2019, SISSA (International School for Advanced Studies). # All rights reserved. # This file is distributed under the terms of the MIT License. # See the file 'LICENSE' in the root directory of the present # distribution, or http://opensource.org/licenses/MIT. # # @author Davide Brunato # """ XPath 2.0 implementation - part 2 (functions) """ import sys import decimal import math import datetime import time import re import locale import unicodedata from .compat import PY3, string_base_type, unicode_chr, urlparse, urljoin, urllib_quote, unicode_type from .datatypes import QNAME_PATTERN, DateTime10, Date10, Time, Timezone, Duration, DayTimeDuration from .namespaces import prefixed_to_qname, get_namespace from .xpath_context import XPathSchemaContext from .xpath_nodes import is_document_node, is_xpath_node, is_element_node, \ is_attribute_node, node_name, node_nilled, node_base_uri, node_document_uri from .xpath2_parser import XPath2Parser method = XPath2Parser.method function = XPath2Parser.function WRONG_REPLACEMENT_PATTERN = re.compile(r'(? 
###
# General functions for sequences
@method(function('empty', nargs=1))
@method(function('exists', nargs=1))
def evaluate(self, context=None):
    """Return the first value yielded by this token's ``select()`` method."""
    return next(iter(self.select(context)))


@method('empty')
def select(self, context=None):
    """Yield True if the argument selects no items, otherwise yield False."""
    try:
        next(iter(self[0].select(context)))
    except StopIteration:
        yield True
    else:
        yield False


@method('exists')
def select(self, context=None):
    """Yield True if the argument selects at least one item, otherwise yield False."""
    try:
        next(iter(self[0].select(context)))
    except StopIteration:
        yield False
    else:
        yield True


@method(function('distinct-values', nargs=(1, 2)))
def select(self, context=None):
    """
    Yield the data value of each selected item, skipping values already yielded.

    NaN never compares equal to itself, so it is tracked with a dedicated flag
    and yielded at most once. The optional second argument is accepted by the
    function registration but is not used here.
    """
    nan_seen = False
    results = []  # membership is tested with ==, so a list is used instead of a set
    for item in self[0].select(context):
        value = self.data_value(item)
        if context is not None:
            context.item = value
        if isinstance(value, float) and math.isnan(value):
            # Handle every NaN here: falling through to the equality test below
            # would report each further NaN as a new distinct value.
            if not nan_seen:
                yield value
                nan_seen = True
        elif value not in results:
            yield value
            results.append(value)


@method(function('insert-before', nargs=3))
def select(self, context=None):
    """
    Yield the first sequence with the items of the third argument inserted
    before the item at the 1-based position given by the second argument.

    A position lower than 1 inserts at the start; a position beyond the end of
    the first sequence appends the inserted items.
    """
    # Evaluate the position (as 'remove' and 'index-of' do with their second
    # argument) instead of reading the raw token value, so a computed position
    # is honoured too.
    insert_at_pos = max(0, self[1].evaluate(context) - 1)
    inserted = False
    for pos, result in enumerate(self[0].select(context)):
        if not inserted and pos == insert_at_pos:
            for item in self[2].select(context):
                yield item
            inserted = True
        yield result

    if not inserted:
        for item in self[2].select(context):
            yield item


# The search value is always read from the second argument, so at least two
# arguments are required.
@method(function('index-of', nargs=(2, 3)))
def select(self, context=None):
    """
    Yield the 1-based positions of the items of the first argument that are
    equal to the value of the second argument. A third argument is accepted
    but not used here.
    """
    value = self[1].evaluate(context)
    for pos, result in enumerate(self[0].select(context)):
        if result == value:
            yield pos + 1


@method(function('remove', nargs=2))
def select(self, context=None):
    """Yield the items of the first argument except the one at the given 1-based position."""
    target = self[1].evaluate(context) - 1
    for pos, result in enumerate(self[0].select(context)):
        if pos != target:
            yield result


@method(function('reverse', nargs=1))
def select(self, context=None):
    """Yield the items of the argument in reverse order."""
    for result in reversed(list(self[0].select(context))):
        yield result


@method(function('subsequence', nargs=(2, 3)))
def select(self, context=None):
    """
    Yield the items of the first argument starting at the 1-based location
    given by the second argument, limited to the number of items given by the
    optional third argument.

    An explicit length of zero (or a negative one) yields nothing; only a
    missing third argument means "up to the end of the sequence".
    """
    starting_loc = self[1].evaluate(context) - 1
    # None (not 0) marks a missing length, so that an explicit 0 is respected.
    length = self[2].evaluate(context) if len(self) >= 3 else None
    for pos, result in enumerate(self[0].select(context)):
        if pos < starting_loc:
            continue
        if length is not None and pos >= starting_loc + length:
            break  # positions only grow: nothing else can be selected
        yield result


@method(function('unordered', nargs=1))
def select(self, context=None):
    """Yield the items of the argument sorted by their string value."""
    for result in sorted(self[0].select(context), key=self.string_value):
        yield result


###
# Cardinality functions for sequences
@method(function('zero-or-one', nargs=1))
def select(self, context=None):
    """
    Yield the only item of the argument, or nothing if the argument is empty.

    Raises the error built by ``self.error('FORG0003')`` if the argument
    contains more than one item.
    """
    results = iter(self[0].select(context))
    try:
        item = next(results)
    except StopIteration:
        return

    try:
        next(results)
    except StopIteration:
        yield item
    else:
        raise self.error('FORG0003')


@method(function('one-or-more', nargs=1))
def select(self, context=None):
    """
    Yield all the items of the argument.

    Raises the error built by ``self.error('FORG0004')`` if the argument is empty.
    """
    results = iter(self[0].select(context))
    try:
        item = next(results)
    except StopIteration:
        raise self.error('FORG0004')

    yield item
    for item in results:
        yield item


@method(function('exactly-one', nargs=1))
def select(self, context=None):
    """
    Yield the only item of the argument.

    Raises the error built by ``self.error('FORG0005')`` if the argument is
    empty or contains more than one item.
    """
    results = iter(self[0].select(context))
    try:
        item = next(results)
    except StopIteration:
        raise self.error('FORG0005')

    try:
        next(results)
    except StopIteration:
        yield item
    else:
        raise self.error('FORG0005')


###
# Regex
@method(function('matches', nargs=(2, 3)))
def evaluate(self, context=None):
    """
    Return True if the pattern (2nd argument) is found in the input string
    (1st argument, '' by default), searching with Python's ``re`` module.

    The optional third argument is a string of flags, each one of 's', 'm',
    'i' or 'x', mapped to the ``re`` flag of the same upper-case name.

    Raises the error built by ``self.error`` with code FORX0001 for an unknown
    flag and FORX0002 if ``re`` rejects the pattern.
    """
    input_string = self.get_argument(context, default='', cls=string_base_type)
    pattern = self.get_argument(context, 1, required=True, cls=string_base_type)
    flags = 0
    if len(self) > 2:
        for c in self.get_argument(context, 2, required=True, cls=string_base_type):
            if c in 'smix':
                flags |= getattr(re, c.upper())
            else:
                raise self.error('FORX0001', "Invalid regular expression flag %r" % c)

    try:
        return re.search(pattern, input_string, flags=flags) is not None
    except re.error:
        # TODO: full XML regex syntax
        raise self.error('FORX0002', "Invalid regular expression %r" % pattern)
context=None): input_string = self.get_argument(context, default='', cls=string_base_type) pattern = self.get_argument(context, 1, required=True, cls=string_base_type) replacement = self.get_argument(context, 2, required=True, cls=string_base_type) flags = 0 if len(self) > 3: for c in self.get_argument(context, 3, required=True, cls=string_base_type): if c in 'smix': flags |= getattr(re, c.upper()) else: raise self.error('FORX0001', "Invalid regular expression flag %r" % c) try: pattern = re.compile(pattern, flags=flags) except re.error: raise self.error('FORX0002', "Invalid regular expression %r" % pattern) # TODO: full XML regex syntax else: if pattern.search(''): raise self.error('FORX0003', "Regular expression %r matches zero-length string" % pattern.pattern) elif WRONG_REPLACEMENT_PATTERN.search(replacement): raise self.error('FORX0004', "Invalid replacement string %r" % replacement) else: if sys.version_info >= (3, 5): for g in range(pattern.groups + 1): if '$%d' % g in replacement: replacement = re.sub(r'(?' % g, replacement) else: match = pattern.search(input_string) for g in range(pattern.groups + 1): if '$%d' % g in replacement: if match and match.group(g) is not None: replacement = re.sub(r'(?' % g, replacement) else: replacement = re.sub(r'(? 
###
# Functions on anyURI
@method(function('resolve-uri', nargs=(1, 2)))
def evaluate(self, context=None):
    """
    Resolve the relative URI given as first argument against a base URI.

    The base URI is the second argument when provided, otherwise the parser's
    ``base_uri``; if that is None the error built by ``self.error('FONS0005')``
    is raised. Nothing is returned explicitly (i.e. None) when the first
    argument evaluates to None.
    """
    relative = self.get_argument(context, cls=string_base_type)
    if len(self) == 2:
        base_uri = self.get_argument(context, index=1, required=True, cls=string_base_type)
        base_uri = urlparse(base_uri).geturl()
    elif self.parser.base_uri is None:
        raise self.error('FONS0005')
    else:
        base_uri = self.parser.base_uri

    if relative is not None:
        url_parts = urlparse(relative)
        if url_parts.path.startswith('/'):
            return relative
        elif url_parts.scheme:
            # NOTE(review): only the text between the first and the second ':'
            # is joined to the base URI - confirm this is intended for URIs
            # that carry a scheme.
            return urljoin(base_uri, relative.split(':')[1])
        else:
            return urljoin(base_uri, relative)


###
# String functions
@method(function('codepoints-to-string', nargs=1))
def evaluate(self, context=None):
    """Return the string built from the sequence of code points given as argument."""
    return ''.join(unicode_chr(cp) for cp in self[0].select(context))


@method(function('string-to-codepoints', nargs=1))
def select(self, context=None):
    """Yield the code point of each character of the evaluated argument."""
    for char in self[0].evaluate(context):
        yield ord(char)


@method(function('compare', nargs=(2, 3)))
def evaluate(self, context=None):
    """
    Compare two strings with ``locale.strcoll`` and return -1, 0 or 1.

    Returns an empty list if either operand is None. With a third argument the
    comparison runs inside ``self.use_locale(collation=...)``; without it the
    locale is set from the environment first.
    """
    comp1 = self.get_argument(context, 0, cls=string_base_type)
    comp2 = self.get_argument(context, 1, cls=string_base_type)
    if comp1 is None or comp2 is None:
        return []

    if len(self) < 3:
        # NOTE(review): this changes the process-wide locale and does not
        # restore it afterwards - confirm callers rely on that.
        locale.setlocale(locale.LC_ALL, '')
        value = locale.strcoll(comp1, comp2)
    else:
        with self.use_locale(collation=self.get_argument(context, 2)):
            value = locale.strcoll(comp1, comp2)

    return 1 if value > 0 else -1 if value < 0 else 0


@method(function('codepoint-equal', nargs=2))
def evaluate(self, context=None):
    """
    Return True if the two strings have the same length and the same code
    points, False otherwise; return an empty list if either operand is None.
    """
    comp1 = self.get_argument(context, 0, cls=string_base_type)
    comp2 = self.get_argument(context, 1, cls=string_base_type)
    if comp1 is None or comp2 is None:
        return []
    elif len(comp1) != len(comp2):
        return False
    else:
        return all(ord(c1) == ord(c2) for c1, c2 in zip(comp1, comp2))


@method(function('string-join', nargs=2))
def evaluate(self, context=None):
    """
    Join the items of the first argument using the second argument as separator.

    Element nodes are replaced by their string value before joining. Failures
    of the join are reported through ``self.wrong_type`` (presumably raising a
    type error - confirm against the parser implementation).
    """
    items = [self.string_value(s) if is_element_node(s) else s for s in self[0].select(context)]
    try:
        return self.get_argument(context, 1, cls=string_base_type).join(items)
    except AttributeError as err:
        self.wrong_type("the separator must be a string: %s" % err)
    except TypeError as err:
        self.wrong_type("the values must be strings: %s" % err)


@method(function('normalize-unicode', nargs=(1, 2)))
def evaluate(self, context=None):
    """
    Return the argument normalized with ``unicodedata.normalize``.

    The normalization form is the stripped, upper-cased second argument, or
    'NFC' when it is not provided. 'FULLY-NORMALIZED' raises
    NotImplementedError; a form rejected by ``unicodedata`` raises the error
    built by ``self.error`` with code FOCH0003.
    """
    arg = self.get_argument(context, default='', cls=string_base_type)
    if len(self) > 1:
        normalization_form = self.get_argument(context, 1, cls=string_base_type)
        if normalization_form is None:
            self.wrong_type("2nd argument can't be an empty sequence")
        else:
            normalization_form = normalization_form.strip().upper()
    else:
        normalization_form = 'NFC'

    if normalization_form == 'FULLY-NORMALIZED':
        raise NotImplementedError("%r normalization form not supported" % normalization_form)

    if arg is None:
        return ''
    elif not isinstance(arg, unicode_type):
        arg = arg.decode('utf-8')  # byte strings are decoded before normalizing

    try:
        return unicodedata.normalize(normalization_form, arg)
    except ValueError:
        raise self.error('FOCH0003', "unsupported normalization form %r" % normalization_form)


@method(function('upper-case', nargs=1))
def evaluate(self, context=None):
    """Return the argument converted to upper case, or '' if it is None."""
    arg = self.get_argument(context, cls=string_base_type)
    try:
        return '' if arg is None else arg.upper()
    except AttributeError:
        self.wrong_type("the argument must be a string: %r" % arg)


@method(function('lower-case', nargs=1))
def evaluate(self, context=None):
    """Return the argument converted to lower case, or '' if it is None."""
    arg = self.get_argument(context, cls=string_base_type)
    try:
        return '' if arg is None else arg.lower()
    except AttributeError:
        self.wrong_type("the argument must be a string: %r" % arg)


@method(function('encode-for-uri', nargs=1))
def evaluate(self, context=None):
    """Return the argument percent-quoted keeping only '~' as extra safe character, or '' if None."""
    uri_part = self.get_argument(context, cls=string_base_type)
    try:
        return '' if uri_part is None else urllib_quote(uri_part, safe='~')
    except TypeError:
        self.wrong_type("the argument must be a string: %r" % uri_part)


@method(function('iri-to-uri', nargs=1))
def evaluate(self, context=None):
    """Return the argument percent-quoted leaving URI punctuation untouched, or '' if None."""
    iri = self.get_argument(context, cls=string_base_type)
    try:
        return '' if iri is None else urllib_quote(iri, safe='-_.!~*\'()#;/?:@&=+$,[]%')
    except TypeError:
        self.wrong_type("the argument must be a string: %r" % iri)


@method(function('escape-html-uri', nargs=1))
def evaluate(self, context=None):
    """
    Return the argument percent-quoted leaving every character with code point
    from 32 to 126 untouched, or '' if the argument is None.
    """
    uri = self.get_argument(context, cls=string_base_type)
    try:
        return '' if uri is None else urllib_quote(uri, safe=''.join(chr(cp) for cp in range(32, 127)))
    except TypeError:
        self.wrong_type("the argument must be a string: %r" % uri)


@method(function('starts-with', nargs=(2, 3)))
def evaluate(self, context=None):
    """Return True if the first argument starts with the second ('' is the default for both)."""
    arg1 = self.get_argument(context, default='', cls=string_base_type)
    arg2 = self.get_argument(context, index=1, default='', cls=string_base_type)
    return arg1.startswith(arg2)


@method(function('ends-with', nargs=(2, 3)))
def evaluate(self, context=None):
    """Return True if the first argument ends with the second ('' is the default for both)."""
    arg1 = self.get_argument(context, default='', cls=string_base_type)
    arg2 = self.get_argument(context, index=1, default='', cls=string_base_type)
    return arg1.endswith(arg2)


###
# Functions on durations, dates and times
#
# Each component function returns an empty list when its argument is None.
# For durations the sign of the whole value is applied to the component, that
# is computed on the absolute value (truncation toward zero).
@method(function('years-from-duration', nargs=1))
def evaluate(self, context=None):
    """Return the whole years contained in the months of the duration."""
    item = self.get_argument(context, cls=Duration)
    if item is None:
        return []
    else:
        return item.months // 12 if item.months >= 0 else -(abs(item.months) // 12)


@method(function('months-from-duration', nargs=1))
def evaluate(self, context=None):
    """Return the months of the duration that exceed its whole years."""
    item = self.get_argument(context, cls=Duration)
    if item is None:
        return []
    else:
        return item.months % 12 if item.months >= 0 else -(abs(item.months) % 12)


@method(function('days-from-duration', nargs=1))
def evaluate(self, context=None):
    """Return the whole days (86400 s) contained in the seconds of the duration."""
    item = self.get_argument(context, cls=Duration)
    if item is None:
        return []
    else:
        return item.seconds // 86400 if item.seconds >= 0 else -(abs(item.seconds) // 86400)


@method(function('hours-from-duration', nargs=1))
def evaluate(self, context=None):
    """Return the hours component (0-23 in absolute value) of the duration."""
    item = self.get_argument(context, cls=Duration)
    if item is None:
        return []
    else:
        return item.seconds // 3600 % 24 if item.seconds >= 0 else -(abs(item.seconds) // 3600 % 24)


@method(function('minutes-from-duration', nargs=1))
def evaluate(self, context=None):
    """Return the minutes component (0-59 in absolute value) of the duration."""
    item = self.get_argument(context, cls=Duration)
    if item is None:
        return []
    else:
        return item.seconds // 60 % 60 if item.seconds >= 0 else -(abs(item.seconds) // 60 % 60)


@method(function('seconds-from-duration', nargs=1))
def evaluate(self, context=None):
    """Return the seconds component (below 60 in absolute value) of the duration."""
    item = self.get_argument(context, cls=Duration)
    if item is None:
        return []
    else:
        return item.seconds % 60 if item.seconds >= 0 else -(abs(item.seconds) % 60)


@method(function('year-from-dateTime', nargs=1))
def evaluate(self, context=None):
    """Return the year of the dateTime; for a BCE value return ``-(year + 1)``."""
    item = self.get_argument(context, cls=DateTime10)
    return [] if item is None else -(item.year + 1) if item.bce else item.year


@method(function('month-from-dateTime', nargs=1))
def evaluate(self, context=None):
    """Return the month of the dateTime."""
    item = self.get_argument(context, cls=DateTime10)
    return [] if item is None else item.month


@method(function('day-from-dateTime', nargs=1))
def evaluate(self, context=None):
    """Return the day of the dateTime."""
    item = self.get_argument(context, cls=DateTime10)
    return [] if item is None else item.day


@method(function('hours-from-dateTime', nargs=1))
def evaluate(self, context=None):
    """Return the hour of the dateTime."""
    item = self.get_argument(context, cls=DateTime10)
    return [] if item is None else item.hour


@method(function('minutes-from-dateTime', nargs=1))
def evaluate(self, context=None):
    """Return the minute of the dateTime."""
    item = self.get_argument(context, cls=DateTime10)
    return [] if item is None else item.minute


@method(function('seconds-from-dateTime', nargs=1))
def evaluate(self, context=None):
    """Return the second of the dateTime."""
    # NOTE(review): unlike 'seconds-from-time' the microseconds are not added
    # here - confirm whether the fractional part should be included.
    item = self.get_argument(context, cls=DateTime10)
    return [] if item is None else item.second


@method(function('timezone-from-dateTime', nargs=1))
def evaluate(self, context=None):
    """
    Return the timezone offset of the dateTime as a DayTimeDuration, or an
    empty list if the argument is None or has no timezone.
    """
    item = self.get_argument(context, cls=DateTime10)
    if item is None or item.tzinfo is None:
        return []  # a value without timezone has no offset to report
    return DayTimeDuration(seconds=item.tzinfo.offset.total_seconds())


@method(function('year-from-date', nargs=1))
def evaluate(self, context=None):
    """Return the year of the date."""
    item = self.get_argument(context, cls=Date10)
    return [] if item is None else item.year


@method(function('month-from-date', nargs=1))
def evaluate(self, context=None):
    """Return the month of the date."""
    item = self.get_argument(context, cls=Date10)
    return [] if item is None else item.month


@method(function('day-from-date', nargs=1))
def evaluate(self, context=None):
    """Return the day of the date."""
    item = self.get_argument(context, cls=Date10)
    return [] if item is None else item.day


@method(function('timezone-from-date', nargs=1))
def evaluate(self, context=None):
    """
    Return the timezone offset of the date as a DayTimeDuration, or an empty
    list if the argument is None or has no timezone.
    """
    item = self.get_argument(context, cls=Date10)
    if item is None or item.tzinfo is None:
        return []  # a value without timezone has no offset to report
    return DayTimeDuration(seconds=item.tzinfo.offset.total_seconds())


@method(function('hours-from-time', nargs=1))
def evaluate(self, context=None):
    """Return the hour of the time."""
    item = self.get_argument(context, cls=Time)
    return [] if item is None else item.hour


@method(function('minutes-from-time', nargs=1))
def evaluate(self, context=None):
    """Return the minute of the time."""
    item = self.get_argument(context, cls=Time)
    return [] if item is None else item.minute


@method(function('seconds-from-time', nargs=1))
def evaluate(self, context=None):
    """Return the seconds of the time as a float that includes the microseconds."""
    item = self.get_argument(context, cls=Time)
    return [] if item is None else item.second + item.microsecond / 1000000.0


@method(function('timezone-from-time', nargs=1))
def evaluate(self, context=None):
    """
    Return the timezone offset of the time as a DayTimeDuration, or an empty
    list if the argument is None or has no timezone.
    """
    item = self.get_argument(context, cls=Time)
    if item is None or item.tzinfo is None:
        return []  # a value without timezone has no offset to report
    return DayTimeDuration(seconds=item.tzinfo.offset.total_seconds())


###
# Timezone adjustment functions
@method(function('adjust-dateTime-to-timezone', nargs=(1, 2)))
def evaluate(self, context=None):
    """Delegate to ``self.adjust_datetime`` for DateTime10 values."""
    return self.adjust_datetime(context, DateTime10)


@method(function('adjust-date-to-timezone', nargs=(1, 2)))
def evaluate(self, context=None):
    """Delegate to ``self.adjust_datetime`` for Date10 values."""
    return self.adjust_datetime(context, Date10)


@method(function('adjust-time-to-timezone', nargs=(1, 2)))
def evaluate(self, context=None):
    """Delegate to ``self.adjust_datetime`` for Time values."""
    return self.adjust_datetime(context, Time)


###
# Context functions
#
# The current-* functions read the context's ``current_dt``, or the system
# clock (``datetime.datetime.now()``) when no context is provided.
@method(function('current-dateTime', nargs=0))
def evaluate(self, context=None):
    """Return the current date and time as a DateTime10 instance."""
    dt = datetime.datetime.now() if context is None else context.current_dt
    return DateTime10(dt.year, dt.month, dt.day, dt.hour, dt.minute,
                      dt.second, dt.microsecond, dt.tzinfo)


@method(function('current-date', nargs=0))
def evaluate(self, context=None):
    """Return the current date as a Date10 instance."""
    dt = datetime.datetime.now() if context is None else context.current_dt
    return Date10(dt.year, dt.month, dt.day, tzinfo=dt.tzinfo)


@method(function('current-time', nargs=0))
def evaluate(self, context=None):
    """Return the current time as a Time instance."""
    dt = datetime.datetime.now() if context is None else context.current_dt
    return Time(dt.hour, dt.minute, dt.second, dt.microsecond, dt.tzinfo)


@method(function('implicit-timezone', nargs=0))
def evaluate(self, context=None):
    """
    Return the timezone of the context if it defines one, otherwise a Timezone
    built from the local (non-DST) offset of the system.
    """
    if context is not None and context.timezone is not None:
        return context.timezone
    else:
        # time.timezone counts seconds *west* of UTC, so it is negated to get
        # a UTC offset. Assumes Timezone takes an east-positive offset, like
        # datetime.timezone - TODO confirm against the datatypes module.
        return Timezone(datetime.timedelta(seconds=-time.timezone))


@method(function('static-base-uri', nargs=0))
def evaluate(self, context=None):
    """Return the parser's base URI, or nothing (None) if it is not set."""
    if self.parser.base_uri is not None:
        return self.parser.base_uri


###
# The root function (Ref: https://www.w3.org/TR/2010/REC-xpath-functions-20101214/#func-root)
@method(function('root', nargs=(0, 1)))
def evaluate(self, context=None):
    """
    Evaluate the root function on the argument, or on the context item when
    called without arguments.

    Returns an empty list if the item is None. Raises the error built by
    ``self.error`` with code XPDY0002 if there is neither argument nor
    context, and XPTY0004 if the item is not an XPath node.
    """
    if self:
        item = self.get_argument(context)
    elif context is None:
        raise self.error('XPDY0002')
    else:
        item = context.item

    if item is None:
        return []
    elif is_xpath_node(item):
        # NOTE(review): the node itself is returned, no ancestor lookup is
        # done here - confirm against the referenced specification.
        return item
    else:
        raise self.error('XPTY0004')


###
# The error function (Ref: https://www.w3.org/TR/xpath20/#func-error)
@method(function('error', nargs=(0, 3)))
def evaluate(self, context=None):
    """
    Always raise the error built by ``self.error``.

    The error code is the first argument, or 'FOER0000' if it is missing or
    evaluates to a false value. The second argument, if provided and not None,
    is passed as the error description. A third argument is accepted but is
    not used.
    """
    if not self:
        raise self.error('FOER0000')

    code = self.get_argument(context) or 'FOER0000'
    if len(self) == 1:
        raise self.error(code)

    # Two or three arguments: the function must raise here too, instead of
    # silently returning None.
    description = self.get_argument(context, index=1)
    if description is None:
        raise self.error(code)
    raise self.error(code, description)


# XPath 2.0 definitions continue into module xpath2_constructors