146 lines
5.0 KiB
Python
146 lines
5.0 KiB
Python
from __future__ import absolute_import
|
|
|
|
import re
|
|
import sys
|
|
import warnings
|
|
|
|
try:
|
|
import unittest # noqa
|
|
unittest.skip
|
|
from unittest.util import safe_repr, unorderable_list_difference
|
|
except AttributeError:
|
|
import unittest2 as unittest # noqa
|
|
from unittest2.util import safe_repr, unorderable_list_difference # noqa
|
|
|
|
from billiard.five import string_t, items, values
|
|
|
|
from .compat import catch_warnings
|
|
|
|
# -- adds assertWarns from recent unittest2, not in Python 2.7.
|
|
|
|
|
|
class _AssertRaisesBaseContext(object):
|
|
|
|
def __init__(self, expected, test_case, callable_obj=None,
|
|
expected_regex=None):
|
|
self.expected = expected
|
|
self.failureException = test_case.failureException
|
|
self.obj_name = None
|
|
if isinstance(expected_regex, string_t):
|
|
expected_regex = re.compile(expected_regex)
|
|
self.expected_regex = expected_regex
|
|
|
|
|
|
class _AssertWarnsContext(_AssertRaisesBaseContext):
    """A context manager used to implement TestCase.assertWarns* methods.

    On entry it clears every loaded module's ``__warningregistry__``, starts
    recording warnings and forces warnings of class ``self.expected`` to
    always be emitted.  On a clean exit it searches the recorded warnings for
    an instance of ``self.expected`` whose text also matches
    ``self.expected_regex`` (when one is set).  The first match is stored as
    ``self.warning`` together with ``self.filename`` and ``self.lineno``.
    If there is no match, ``self.failureException`` is raised.
    """

    def __enter__(self):
        # The __warningregistry__'s need to be in a pristine state for tests
        # to work properly.
        # Iterate over a snapshot: getattr() on a module can trigger lazy
        # imports, and resizing sys.modules mid-iteration raises
        # "dictionary changed size during iteration".
        for module in list(values(sys.modules)):
            if getattr(module, '__warningregistry__', None):
                module.__warningregistry__ = {}
        self.warnings_manager = catch_warnings(record=True)
        self.warnings = self.warnings_manager.__enter__()
        # Reset the filters only *inside* catch_warnings.  Resetting before
        # entering it makes the saved snapshot the emptied list, which would
        # permanently discard the caller's filters for the whole process.
        # NOTE(review): assumes .compat.catch_warnings saves/restores
        # warnings.filters like the stdlib version -- confirm.
        warnings.resetwarnings()
        warnings.simplefilter('always', self.expected)
        return self

    def __exit__(self, exc_type, exc_value, tb):
        self.warnings_manager.__exit__(exc_type, exc_value, tb)
        if exc_type is not None:
            # let unexpected exceptions pass through
            return
        try:
            exc_name = self.expected.__name__
        except AttributeError:
            exc_name = str(self.expected)
        first_matching = None
        for m in self.warnings:
            w = m.message
            if not isinstance(w, self.expected):
                continue
            if first_matching is None:
                first_matching = w
            if (self.expected_regex is not None and
                    not self.expected_regex.search(str(w))):
                continue
            # store warning for later retrieval
            self.warning = w
            self.filename = m.filename
            self.lineno = m.lineno
            return
        # Now we simply try to choose a helpful failure message
        if first_matching is not None:
            # Only reachable when a regex was given and no warning of the
            # right class matched it.
            raise self.failureException(
                '%r does not match %r' % (
                    self.expected_regex.pattern, str(first_matching)))
        if self.obj_name:
            raise self.failureException(
                '%s not triggered by %s' % (exc_name, self.obj_name))
        else:
            raise self.failureException('%s not triggered' % exc_name)
|
|
|
|
|
|
class Case(unittest.TestCase):
    """``unittest.TestCase`` subclass with extra assertion helpers.

    Provides ``assertWarns``/``assertWarnsRegex`` (context-manager form
    only), ``assertDictContainsSubset`` and ``assertItemsEqual``.
    """

    def assertWarns(self, expected_warning):
        """Return a context manager that fails unless a warning of class
        *expected_warning* is issued inside its ``with`` block."""
        return _AssertWarnsContext(expected_warning, self, None)

    def assertWarnsRegex(self, expected_warning, expected_regex):
        """Like :meth:`assertWarns`, but the warning text must also match
        *expected_regex* (a pattern string or a compiled pattern)."""
        return _AssertWarnsContext(expected_warning, self,
                                   None, expected_regex)

    def assertDictContainsSubset(self, expected, actual, msg=None):
        """Fail unless every key of *expected* is in *actual* with an
        equal value.

        Missing keys and mismatched values are reported together; *msg* is
        combined with that report via ``_formatMessage``.
        """
        missing, mismatched = [], []

        for key, value in items(expected):
            if key not in actual:
                missing.append(key)
            elif value != actual[key]:
                mismatched.append('%s, expected: %s, actual: %s' % (
                    safe_repr(key), safe_repr(value),
                    safe_repr(actual[key])))

        if not (missing or mismatched):
            return

        standard_msg = ''
        if missing:
            standard_msg = 'Missing: %s' % ','.join(map(safe_repr, missing))

        if mismatched:
            if standard_msg:
                standard_msg += '; '
            standard_msg += 'Mismatched values: %s' % (
                ','.join(mismatched))

        self.fail(self._formatMessage(msg, standard_msg))

    def assertItemsEqual(self, expected_seq, actual_seq, msg=None):
        """Fail unless both iterables hold the same items with the same
        multiplicities, in any order.

        Orderable items are sorted and compared with
        ``assertSequenceEqual``.  If sorting raises ``TypeError``, the items
        are diffed one by one with ``unorderable_list_difference``.  Missing
        and unexpected items are then reported.

        :raises self.failureException: when the contents differ.
        """
        # Materialize both inputs once.  A one-shot iterator consumed by a
        # failed sorted() call would otherwise be re-read as empty (or
        # partial) in the fallback branch below.
        expected = list(expected_seq)
        actual = list(actual_seq)
        try:
            expected_sorted = sorted(expected)
            actual_sorted = sorted(actual)
        except TypeError:
            # Unorderable items (example: complex(), mixed types, ...).
            # NOTE(review): sets do not raise here -- ``<`` on sets is a
            # subset test -- so lists of sets take the sorted path and may
            # end up compared in an input-dependent order.
            # unorderable_list_difference may mutate the lists it is given;
            # that is fine because these are our own copies.
            missing, unexpected = unorderable_list_difference(
                expected, actual)
        else:
            return self.assertSequenceEqual(
                expected_sorted, actual_sorted, msg=msg)

        errors = []
        if missing:
            errors.append(
                'Expected, but missing:\n %s' % (safe_repr(missing), ),
            )
        if unexpected:
            errors.append(
                'Unexpected, but present:\n %s' % (safe_repr(unexpected), ),
            )
        if errors:
            standardMsg = '\n'.join(errors)
            self.fail(self._formatMessage(msg, standardMsg))
|