| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356 |
- """
Utility functions that are only used by the unittests.
Placing them in the test/ directory seems to go against convention, so they are part of the library instead.
- """
- from __future__ import print_function, division, absolute_import
- import random
- import copy
- import warnings
- import tempfile
- import shutil
- import re
- import sys
- import numpy as np
- import six.moves as sm
- # unittest.mock is not available in 2.7 (though unittest2 might contain it?)
- try:
- import unittest.mock as mock
- except ImportError:
- import mock
- try:
- import cPickle as pickle
- except ImportError:
- import pickle
- import imgaug as ia
- import imgaug.random as iarandom
- from imgaug.augmentables.kps import KeypointsOnImage
class ArgCopyingMagicMock(mock.MagicMock):
    """A MagicMock that deep-copies its call args/kwargs before recording them.

    This is useful for imgaug as many augmentation methods change data
    in-place. A plain ``MagicMock`` only stores references to the call
    arguments, so later in-place modifications would also show up in
    ``call_args`` / ``call_args_list``.

    The copy is made in ``__call__`` (as in the "Coping with mutable
    arguments" recipe of the ``unittest.mock`` documentation). Overriding
    ``_mock_call`` is not sufficient on python 3.8+, where the call is
    recorded (in ``_increment_mock_call``) before ``_mock_call`` is invoked.

    Originally taken from https://stackoverflow.com/a/23264042/3760780
    """

    def __call__(self, *args, **kwargs):
        # Copy first, then delegate: both the recorded call and the
        # side_effect/return_value handling receive the copies.
        args_copy = copy.deepcopy(args)
        kwargs_copy = copy.deepcopy(kwargs)
        return super(ArgCopyingMagicMock, self).__call__(
            *args_copy, **kwargs_copy)
# Added in 0.4.0.
def assert_cbaois_equal(observed, expected, max_distance=1e-4):
    """Assert that two coordinate-based augmentables on images are equal.

    Parameters
    ----------
    observed, expected : object or list
        Either two objects exposing ``items`` and ``shape`` (e.g.
        ``PolygonsOnImage``), or two (possibly nested) lists of such
        objects. If either one is a list, both must be lists of equal
        length, and their elements are compared pairwise via recursion.
    max_distance : float, optional
        Passed on to each item's ``coords_almost_equals()``.

    Raises
    ------
    AssertionError
        If the exact types, the number of items, the shapes or the item
        coordinates differ, or -- for ``PolygonsOnImage`` -- if an observed
        polygon is invalid while the corresponding expected one is valid.
    """
    # pylint: disable=unidiomatic-typecheck
    is_list_pair = isinstance(observed, list) or isinstance(expected, list)
    if is_list_pair:
        assert isinstance(observed, list)
        assert isinstance(expected, list)
        assert len(observed) == len(expected)
        for obs_elem, exp_elem in zip(observed, expected):
            assert_cbaois_equal(obs_elem, exp_elem,
                                max_distance=max_distance)
        return

    # Exact type match on purpose (no subclass tolerance).
    assert type(observed) == type(expected)
    assert len(observed.items) == len(expected.items)
    assert observed.shape == expected.shape

    for obs_item, exp_item in zip(observed.items, expected.items):
        assert obs_item.coords_almost_equals(exp_item,
                                             max_distance=max_distance)

    if not isinstance(expected, ia.PolygonsOnImage):
        return

    # A polygon that is valid in `expected` must also be valid in `observed`.
    for obs_item, exp_item in zip(observed.items, expected.items):
        if not exp_item.is_valid:
            continue
        assert obs_item.is_valid
def create_random_images(size):
    """Sample uint8 images with random pixel values.

    Values are drawn from numpy's global RNG via
    ``np.random.uniform(0, 255, size)`` and truncated to ``uint8``, so every
    pixel lies in ``[0, 254]``.

    Parameters
    ----------
    size : int or tuple of int
        Output shape, e.g. ``(N, H, W, C)``.

    Returns
    -------
    numpy.ndarray
        Array of dtype ``uint8`` and shape `size`.
    """
    samples = np.random.uniform(0, 255, size)
    return samples.astype(np.uint8)
def create_random_keypoints(size_images, nb_keypoints_per_img):
    """Create one ``KeypointsOnImage`` with random keypoints per image.

    Parameters
    ----------
    size_images : tuple of int
        Shape of an image batch. ``size_images[0]`` is the number of
        ``KeypointsOnImage`` to create, ``size_images[1]`` and
        ``size_images[2]`` are used as image height and width, and
        ``size_images[1:]`` is the shape given to each ``KeypointsOnImage``.
    nb_keypoints_per_img : int
        Number of keypoints placed on each image.

    Returns
    -------
    list of imgaug.augmentables.kps.KeypointsOnImage
        Keypoint coordinates are integers drawn from numpy's global RNG.
    """
    result = []
    for _ in sm.xrange(size_images[0]):
        kps = []
        height, width = size_images[1], size_images[2]
        # np.random.randint() excludes its upper bound, so coordinates are
        # drawn from [0, width-2] and [0, height-2]. The bound is clamped to
        # at least 1: for a height/width <= 1 the unclamped bound made
        # randint() raise ValueError ("low >= high"); such axes now always
        # yield coordinate 0. Sampled values for larger images are unchanged.
        x_high = max(width - 1, 1)
        y_high = max(height - 1, 1)
        for _ in sm.xrange(nb_keypoints_per_img):
            x = np.random.randint(0, x_high)
            y = np.random.randint(0, y_high)
            kps.append(ia.Keypoint(x=x, y=y))
        result.append(ia.KeypointsOnImage(kps, shape=size_images[1:]))
    return result
def array_equal_lists(list1, list2):
    """Return whether two lists of arrays are pairwise equal.

    Lists of different lengths are never equal; otherwise every pair of
    elements is compared with ``np.array_equal``.

    Raises
    ------
    AssertionError
        If `list1` or `list2` is not a ``list``.
    """
    assert isinstance(list1, list), (
        "Expected list1 to be a list, got type %s." % (type(list1),))
    assert isinstance(list2, list), (
        "Expected list2 to be a list, got type %s." % (type(list2),))
    same_length = len(list1) == len(list2)
    return same_length and all(
        np.array_equal(arr_a, arr_b) for arr_a, arr_b in zip(list1, list2))
def keypoints_equal(kpsois1, kpsois2, eps=0.001):
    """Check whether two (lists of) ``KeypointsOnImage`` have equal keypoints.

    Parameters
    ----------
    kpsois1, kpsois2 : KeypointsOnImage or list of KeypointsOnImage
        Keypoint containers to compare. If `kpsois1` is a single
        ``KeypointsOnImage``, `kpsois2` must be one as well.
    eps : float, optional
        Maximum allowed absolute difference per coordinate (inclusive).

    Returns
    -------
    bool
        ``True`` if both sides hold the same number of containers, each
        pair of containers has the same number of keypoints and all
        keypoint coordinates agree within `eps`. Image shapes are not
        compared.
    """
    def _within_eps(value, reference):
        return (float(reference) - eps
                <= float(value)
                <= float(reference) + eps)

    if isinstance(kpsois1, KeypointsOnImage):
        assert isinstance(kpsois2, KeypointsOnImage)
        kpsois1, kpsois2 = [kpsois1], [kpsois2]

    if len(kpsois1) != len(kpsois2):
        return False

    for kpsoi_a, kpsoi_b in zip(kpsois1, kpsois2):
        keypoints_a, keypoints_b = kpsoi_a.keypoints, kpsoi_b.keypoints
        if len(keypoints_a) != len(keypoints_b):
            return False
        for kp_a, kp_b in zip(keypoints_a, keypoints_b):
            x_close = _within_eps(kp_a.x, kp_b.x)
            y_close = _within_eps(kp_a.y, kp_b.y)
            if not (x_close and y_close):
                return False
    return True
def reseed(seed=0):
    """Seed the global random number generators used by the tests.

    Calls, in this order and with the same `seed`: ``imgaug.random.seed``,
    ``np.random.seed`` and ``random.seed``.
    """
    seeders = (iarandom.seed, np.random.seed, random.seed)
    for seeder in seeders:
        seeder(seed)
# Added in 0.4.0.
def runtest_pickleable_uint8_img(augmenter, shape=(15, 15, 3), iterations=3):
    """Assert that an augmenter behaves identically after a pickle round-trip.

    A deterministic uint8 test image of the given `shape` is built (values
    ``0, 1, 2, ...`` wrapped modulo 256). The augmenter is pickled and
    unpickled; then the original and the restored copy are each applied to
    that image `iterations` times (original first) and every pair of outputs
    must be equal.

    Raises
    ------
    AssertionError
        If any pair of outputs differs.
    """
    nb_values = int(np.prod(shape))
    image = np.mod(np.arange(nb_values), 256).astype(np.uint8).reshape(shape)
    restored = pickle.loads(pickle.dumps(augmenter, protocol=-1))
    for _ in np.arange(iterations):
        assert np.array_equal(augmenter(image=image), restored(image=image))
def wrap_shift_deprecation(func, *args, **kwargs):
    """Helper for tests of CBA shift() functions.

    `func` is always called without arguments; `args` and `kwargs` are only
    inspected to decide whether the call uses one of the deprecated
    ``top``/``right``/``bottom``/``left`` arguments (so `func` is expected to
    already have its arguments bound, e.g. via a lambda).

    Without deprecated arguments, ``func()`` is simply called. Otherwise
    ``func()`` is called while recording all warnings, and at least one
    recorded warning must contain the deprecation message.

    Added in 0.4.0.

    Returns
    -------
    object
        Whatever ``func()`` returned.

    Raises
    ------
    AssertionError
        If deprecated arguments were given but no recorded warning contains
        the expected deprecation message (previously an ``IndexError`` was
        raised when no warning was recorded at all).
    """
    deprecated_kwargs = ["top", "right", "bottom", "left"]
    if not any(kwname in kwargs for kwname in deprecated_kwargs):
        return func()

    expected_msg = "These are deprecated. Use `x` and `y` instead."
    with warnings.catch_warnings(record=True) as caught_warnings:
        warnings.simplefilter("always")
        result = func()

    # Search all recorded warnings instead of only the last one, so that
    # unrelated warnings emitted after the deprecation warning don't break
    # the check, and an empty record fails as an assertion.
    messages = [str(caught.message) for caught in caught_warnings]
    assert any(expected_msg in message for message in messages), (
        "Expected a deprecation warning containing '%s', got warnings: %s"
        % (expected_msg, messages))
    return result
class TemporaryDirectory(object):
    """Create a context for a temporary directory.

    The directory is created right away in ``__init__`` via
    ``tempfile.mkdtemp(suffix, prefix, dir)``; its path is stored in
    ``self.name`` and returned by ``__enter__``. At the end of the context
    the directory and all of its contents are removed via
    ``shutil.rmtree``. If the object is never used as a context, the
    directory is not removed.

    This context is also available as ``tempfile.TemporaryDirectory``, but
    only from python 3.2+.

    Added in 0.4.0.

    Parameters
    ----------
    suffix : str, optional
        Suffix of the directory name.
    prefix : str, optional
        Prefix of the directory name.
    dir : None or str, optional
        Parent directory in which to create the directory. ``None`` lets
        ``tempfile`` choose its default location.
    """

    def __init__(self, suffix="", prefix="tmp", dir=None):
        # pylint: disable=redefined-builtin
        self.name = tempfile.mkdtemp(suffix, prefix, dir)

    def __enter__(self):
        return self.name

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so exceptions raised inside the context propagate.
        shutil.rmtree(self.name)
# Copied from
# https://github.com/python/cpython/blob/master/Lib/unittest/case.py
# at commit 293dd23 (Nov 19, 2019).
# Required at least to enable assertWarns() in python <3.2.
# Added in 0.4.0.
def _is_subtype(expected, basetype):
    """Return whether `expected` is a subclass of `basetype`.

    `expected` may also be a (possibly nested) tuple of classes, in which
    case every member must qualify (an empty tuple qualifies). Values that
    are neither classes nor tuples yield ``False``.
    """
    if isinstance(expected, tuple):
        for member in expected:
            if not _is_subtype(member, basetype):
                return False
        return True
    if not isinstance(expected, type):
        return False
    return issubclass(expected, basetype)
# Copied from
# https://github.com/python/cpython/blob/master/Lib/unittest/case.py
# at commit 293dd23 (Nov 19, 2019).
# Required at least to enable assertWarns() in python <3.2.
# Added in 0.4.0.
class _BaseTestCaseContext(object):
    """Base class of the assertion contexts used for ``assertWarns()``.

    Stores the test case whose ``_formatMessage()`` and
    ``failureException`` are used to report failures. ``self.msg`` is read
    by ``_raiseFailure()`` but not set here; subclasses must set it.

    Unlike the python 3 original, this derives explicitly from ``object`` so
    that it is a new-style class on python 2 as well, like the other
    classes in this module.
    """

    # Added in 0.4.0.
    def __init__(self, test_case):
        self.test_case = test_case

    # Added in 0.4.0.
    def _raiseFailure(self, standardMsg):
        """Raise ``test_case.failureException`` with a formatted message.

        `standardMsg` is combined with the user-provided ``self.msg`` via
        the test case's ``_formatMessage()``.
        """
        # pylint: disable=invalid-name, protected-access, no-member
        msg = self.test_case._formatMessage(self.msg, standardMsg)
        raise self.test_case.failureException(msg)
# Copied from
# https://github.com/python/cpython/blob/master/Lib/unittest/case.py
# at commit 293dd23 (Nov 19, 2019).
# Required at least to enable assertWarns() in python <3.2.
# Added in 0.4.0.
class _AssertRaisesBaseContext(_BaseTestCaseContext):
    """Shared base of the assertRaises/assertWarns-style contexts.

    Holds the expected exception/warning type(s), an optional compiled regex
    the message must match, and ``msg``/``obj_name`` used in failure
    messages. Subclasses must define ``_base_type`` and ``_base_type_str``
    (both read in ``handle()``) and implement ``__enter__``/``__exit__``
    (``handle()`` uses ``self`` in a ``with`` statement).
    """

    # Added in 0.4.0.
    def __init__(self, expected, test_case, expected_regex=None):
        # `expected` is validated against ``self._base_type`` in handle().
        _BaseTestCaseContext.__init__(self, test_case)
        self.expected = expected
        self.test_case = test_case
        if expected_regex is not None:
            expected_regex = re.compile(expected_regex)
        self.expected_regex = expected_regex
        # Set later by handle(): callable name and user-provided message.
        self.obj_name = None
        self.msg = None

    # Added in 0.4.0.
    # pylint: disable=inconsistent-return-statements
    def handle(self, name, args, kwargs):
        """Run the assertion either as a context manager or on a callable.

        If args is empty, assertRaises/Warns is being used as a
        context manager, so check for a 'msg' kwarg and return self.
        If args is not empty, call a callable passing positional and keyword
        arguments.

        Parameters
        ----------
        name : str
            Name of the public assertion function; only used in the
            ``TypeError`` message.
        args : tuple
            Either empty, or ``(callable, *callable_args)``.
        kwargs : dict
            In context-manager mode only ``msg`` is allowed (it is popped,
            i.e. the dict is modified in place); otherwise the keyword
            arguments for the callable.

        Returns
        -------
        self or None
            ``self`` in context-manager mode, ``None`` otherwise.

        Raises
        ------
        TypeError
            If ``self.expected`` is not a subtype of ``self._base_type``, or
            if context-manager mode receives keyword arguments besides
            ``msg``.
        """
        # pylint: disable=no-member, self-cls-assignment, not-context-manager
        try:
            if not _is_subtype(self.expected, self._base_type):
                raise TypeError('%s() arg 1 must be %s' %
                                (name, self._base_type_str))
            if not args:
                self.msg = kwargs.pop('msg', None)
                if kwargs:
                    raise TypeError('%r is an invalid keyword argument for '
                                    'this function' % (next(iter(kwargs)),))
                # The object is returned before the `finally` below rebinds
                # the local name `self`, so the caller still receives it.
                return self

            callable_obj = args[0]
            args = args[1:]
            try:
                self.obj_name = callable_obj.__name__
            except AttributeError:
                self.obj_name = str(callable_obj)
            # Run the callable inside this context; the checks happen in
            # the subclass's __exit__.
            with self:
                callable_obj(*args, **kwargs)
        finally:
            # bpo-23890: manually break a reference cycle
            self = None
    # pylint: enable=inconsistent-return-statements
# Copied from
# https://github.com/python/cpython/blob/master/Lib/unittest/case.py
# at commit 293dd23 (Nov 19, 2019).
# Required at least to enable assertWarns() in python <3.2.
# The sys.modules iteration in __enter__ additionally includes the later
# upstream fix of iterating over a snapshot (bpo-29620).
# Added in 0.4.0.
class _AssertWarnsContext(_AssertRaisesBaseContext):
    """A context manager used to implement TestCase.assertWarns* methods.

    On enter, every module's ``__warningregistry__`` is reset and warnings
    are recorded, with the "always" filter for the expected categories. On
    a clean exit, the first recorded warning that is an instance of
    ``self.expected`` (and, if given, matches ``self.expected_regex``) is
    stored in ``self.warning`` together with ``self.filename`` and
    ``self.lineno``. If there is no such warning, a failure is raised via
    ``_raiseFailure()``. Exceptions raised inside the context are not
    suppressed.
    """
    _base_type = Warning
    _base_type_str = 'a warning type or tuple of warning types'

    # Added in 0.4.0.
    def __enter__(self):
        # The __warningregistry__'s need to be in a pristine state for tests
        # to work properly.
        # pylint: disable=invalid-name, attribute-defined-outside-init
        # Iterate over a snapshot: getattr() on a module can trigger imports
        # (e.g. modules with a lazy __getattr__), which would otherwise change
        # sys.modules during iteration and raise RuntimeError.
        for v in list(sys.modules.values()):
            if getattr(v, '__warningregistry__', None):
                v.__warningregistry__ = {}
        self.warnings_manager = warnings.catch_warnings(record=True)
        self.warnings = self.warnings_manager.__enter__()
        warnings.simplefilter("always", self.expected)
        return self

    # Added in 0.4.0.
    def __exit__(self, exc_type, exc_value, tb):
        # pylint: disable=invalid-name, attribute-defined-outside-init
        self.warnings_manager.__exit__(exc_type, exc_value, tb)
        if exc_type is not None:
            # let unexpected exceptions pass through
            return
        try:
            exc_name = self.expected.__name__
        except AttributeError:
            exc_name = str(self.expected)
        first_matching = None
        for m in self.warnings:
            w = m.message
            if not isinstance(w, self.expected):
                continue
            if first_matching is None:
                first_matching = w
            if (self.expected_regex is not None and
                    not self.expected_regex.search(str(w))):
                continue
            # store warning for later retrieval
            self.warning = w
            self.filename = m.filename
            self.lineno = m.lineno
            return
        # Now we simply try to choose a helpful failure message.
        # A set `first_matching` here implies expected_regex is not None;
        # otherwise the loop above would already have returned.
        if first_matching is not None:
            self._raiseFailure('"{}" does not match "{}"'.format(
                self.expected_regex.pattern, str(first_matching)))
        if self.obj_name:
            self._raiseFailure("{} not triggered by {}".format(exc_name,
                                                               self.obj_name))
        else:
            self._raiseFailure("{} not triggered".format(exc_name))
# Partially copied from
# https://github.com/python/cpython/blob/master/Lib/unittest/case.py
# at commit 293dd23 (Nov 19, 2019).
# Required at least to enable assertWarns() in python <3.2.
def assertWarns(testcase, expected_warning, *args, **kwargs):
    """Context with same functionality as ``assertWarns`` in ``unittest``.

    Note that ``assertWarns`` is only available in python 3.2+.

    Added in 0.4.0.

    Parameters
    ----------
    testcase : object
        Object providing ``failureException`` and ``_formatMessage()``
        (e.g. a ``unittest.TestCase``); used to report failures.
    expected_warning : type or tuple of type
        Warning class(es) that must be emitted.
    *args
        Empty to use the result as a context manager; otherwise a callable
        followed by its positional arguments, which is called immediately.
    **kwargs
        In context-manager mode only ``msg`` is accepted; otherwise the
        keyword arguments for the callable.

    Returns
    -------
    object or None
        The context object in context-manager mode, ``None`` otherwise.
    """
    # pylint: disable=invalid-name
    return _AssertWarnsContext(expected_warning, testcase).handle(
        'assertWarns', args, kwargs)
|