# testutils.py
  1. """
  2. Some utility functions that are only used for unittests.
  3. Placing them in test/ directory seems to be against convention, so they are part of the library.
  4. """
  5. from __future__ import print_function, division, absolute_import
  6. import random
  7. import copy
  8. import warnings
  9. import tempfile
  10. import shutil
  11. import re
  12. import sys
  13. import numpy as np
  14. import six.moves as sm
  15. # unittest.mock is not available in 2.7 (though unittest2 might contain it?)
  16. try:
  17. import unittest.mock as mock
  18. except ImportError:
  19. import mock
  20. try:
  21. import cPickle as pickle
  22. except ImportError:
  23. import pickle
  24. import imgaug as ia
  25. import imgaug.random as iarandom
  26. from imgaug.augmentables.kps import KeypointsOnImage
  27. class ArgCopyingMagicMock(mock.MagicMock):
  28. """A MagicMock that copies its call args/kwargs before storing the call.
  29. This is useful for imgaug as many augmentation methods change data
  30. in-place.
  31. Taken from https://stackoverflow.com/a/23264042/3760780
  32. """
  33. def _mock_call(self, *args, **kwargs):
  34. args_copy = copy.deepcopy(args)
  35. kwargs_copy = copy.deepcopy(kwargs)
  36. return super(ArgCopyingMagicMock, self)._mock_call(
  37. *args_copy, **kwargs_copy)
  38. # Added in 0.4.0.
  39. def assert_cbaois_equal(observed, expected, max_distance=1e-4):
  40. # pylint: disable=unidiomatic-typecheck
  41. if isinstance(observed, list) or isinstance(expected, list):
  42. assert isinstance(observed, list)
  43. assert isinstance(expected, list)
  44. assert len(observed) == len(expected)
  45. for observed_i, expected_i in zip(observed, expected):
  46. assert_cbaois_equal(observed_i, expected_i,
  47. max_distance=max_distance)
  48. else:
  49. assert type(observed) == type(expected)
  50. assert len(observed.items) == len(expected.items)
  51. assert observed.shape == expected.shape
  52. for item_a, item_b in zip(observed.items, expected.items):
  53. assert item_a.coords_almost_equals(item_b,
  54. max_distance=max_distance)
  55. if isinstance(expected, ia.PolygonsOnImage):
  56. for item_obs, item_exp in zip(observed.items, expected.items):
  57. if item_exp.is_valid:
  58. assert item_obs.is_valid
  59. def create_random_images(size):
  60. return np.random.uniform(0, 255, size).astype(np.uint8)
  61. def create_random_keypoints(size_images, nb_keypoints_per_img):
  62. result = []
  63. for _ in sm.xrange(size_images[0]):
  64. kps = []
  65. height, width = size_images[1], size_images[2]
  66. for _ in sm.xrange(nb_keypoints_per_img):
  67. x = np.random.randint(0, width-1)
  68. y = np.random.randint(0, height-1)
  69. kps.append(ia.Keypoint(x=x, y=y))
  70. result.append(ia.KeypointsOnImage(kps, shape=size_images[1:]))
  71. return result
  72. def array_equal_lists(list1, list2):
  73. assert isinstance(list1, list), (
  74. "Expected list1 to be a list, got type %s." % (type(list1),))
  75. assert isinstance(list2, list), (
  76. "Expected list2 to be a list, got type %s." % (type(list2),))
  77. if len(list1) != len(list2):
  78. return False
  79. for arr1, arr2 in zip(list1, list2):
  80. if not np.array_equal(arr1, arr2):
  81. return False
  82. return True
  83. def keypoints_equal(kpsois1, kpsois2, eps=0.001):
  84. if isinstance(kpsois1, KeypointsOnImage):
  85. assert isinstance(kpsois2, KeypointsOnImage)
  86. kpsois1 = [kpsois1]
  87. kpsois2 = [kpsois2]
  88. if len(kpsois1) != len(kpsois2):
  89. return False
  90. for kpsoi1, kpsoi2 in zip(kpsois1, kpsois2):
  91. kps1 = kpsoi1.keypoints
  92. kps2 = kpsoi2.keypoints
  93. if len(kps1) != len(kps2):
  94. return False
  95. for kp1, kp2 in zip(kps1, kps2):
  96. x_equal = (float(kp2.x) - eps
  97. <= float(kp1.x)
  98. <= float(kp2.x) + eps)
  99. y_equal = (float(kp2.y) - eps
  100. <= float(kp1.y)
  101. <= float(kp2.y) + eps)
  102. if not x_equal or not y_equal:
  103. return False
  104. return True
  105. def reseed(seed=0):
  106. iarandom.seed(seed)
  107. np.random.seed(seed)
  108. random.seed(seed)
  109. # Added in 0.4.0.
  110. def runtest_pickleable_uint8_img(augmenter, shape=(15, 15, 3), iterations=3):
  111. image = np.mod(np.arange(int(np.prod(shape))), 256).astype(np.uint8)
  112. image = image.reshape(shape)
  113. augmenter_pkl = pickle.loads(pickle.dumps(augmenter, protocol=-1))
  114. for _ in np.arange(iterations):
  115. image_aug = augmenter(image=image)
  116. image_aug_pkl = augmenter_pkl(image=image)
  117. assert np.array_equal(image_aug, image_aug_pkl)
  118. def wrap_shift_deprecation(func, *args, **kwargs):
  119. """Helper for tests of CBA shift() functions.
  120. Added in 0.4.0.
  121. """
  122. # No deprecated arguments? Just call the functions directly.
  123. deprecated_kwargs = ["top", "right", "bottom", "left"]
  124. if not any([kwname in kwargs for kwname in deprecated_kwargs]):
  125. return func()
  126. # Deprecated arguments? Log warnings and assume that there was a
  127. # deprecation warning with expected message.
  128. with warnings.catch_warnings(record=True) as caught_warnings:
  129. warnings.simplefilter("always")
  130. result = func()
  131. assert (
  132. "These are deprecated. Use `x` and `y` instead."
  133. in str(caught_warnings[-1].message)
  134. )
  135. return result
  136. class TemporaryDirectory(object):
  137. """Create a context for a temporary directory.
  138. The directory is automatically removed at the end of the context.
  139. This context is available in ``tmpfile.TemporaryDirectory``, but only
  140. from 3.2+.
  141. Added in 0.4.0.
  142. """
  143. def __init__(self, suffix="", prefix="tmp", dir=None):
  144. # pylint: disable=redefined-builtin
  145. self.name = tempfile.mkdtemp(suffix, prefix, dir)
  146. def __enter__(self):
  147. return self.name
  148. def __exit__(self, exc_type, exc_val, exc_tb):
  149. shutil.rmtree(self.name)
  150. # Copied from
  151. # https://github.com/python/cpython/blob/master/Lib/unittest/case.py
  152. # at commit 293dd23 (Nov 19, 2019).
  153. # Required at least to enable assertWarns() in python <3.2.
  154. # Added in 0.4.0.
  155. def _is_subtype(expected, basetype):
  156. if isinstance(expected, tuple):
  157. return all(_is_subtype(e, basetype) for e in expected)
  158. return isinstance(expected, type) and issubclass(expected, basetype)
  159. # Copied from
  160. # https://github.com/python/cpython/blob/master/Lib/unittest/case.py
  161. # at commit 293dd23 (Nov 19, 2019).
  162. # Required at least to enable assertWarns() in python <3.2.
  163. # Added in 0.4.0.
  164. class _BaseTestCaseContext:
  165. # Added in 0.4.0.
  166. def __init__(self, test_case):
  167. self.test_case = test_case
  168. # Added in 0.4.0.
  169. def _raiseFailure(self, standardMsg):
  170. # pylint: disable=invalid-name, protected-access, no-member
  171. msg = self.test_case._formatMessage(self.msg, standardMsg)
  172. raise self.test_case.failureException(msg)
  173. # Copied from
  174. # https://github.com/python/cpython/blob/master/Lib/unittest/case.py
  175. # at commit 293dd23 (Nov 19, 2019).
  176. # Required at least to enable assertWarns() in python <3.2.
  177. # Added in 0.4.0.
  178. class _AssertRaisesBaseContext(_BaseTestCaseContext):
  179. # Added in 0.4.0.
  180. def __init__(self, expected, test_case, expected_regex=None):
  181. _BaseTestCaseContext.__init__(self, test_case)
  182. self.expected = expected
  183. self.test_case = test_case
  184. if expected_regex is not None:
  185. expected_regex = re.compile(expected_regex)
  186. self.expected_regex = expected_regex
  187. self.obj_name = None
  188. self.msg = None
  189. # Added in 0.4.0.
  190. # pylint: disable=inconsistent-return-statements
  191. def handle(self, name, args, kwargs):
  192. """
  193. If args is empty, assertRaises/Warns is being used as a
  194. context manager, so check for a 'msg' kwarg and return self.
  195. If args is not empty, call a callable passing positional and keyword
  196. arguments.
  197. """
  198. # pylint: disable=no-member, self-cls-assignment, not-context-manager
  199. try:
  200. if not _is_subtype(self.expected, self._base_type):
  201. raise TypeError('%s() arg 1 must be %s' %
  202. (name, self._base_type_str))
  203. if not args:
  204. self.msg = kwargs.pop('msg', None)
  205. if kwargs:
  206. raise TypeError('%r is an invalid keyword argument for '
  207. 'this function' % (next(iter(kwargs)),))
  208. return self
  209. callable_obj = args[0]
  210. args = args[1:]
  211. try:
  212. self.obj_name = callable_obj.__name__
  213. except AttributeError:
  214. self.obj_name = str(callable_obj)
  215. with self:
  216. callable_obj(*args, **kwargs)
  217. finally:
  218. # bpo-23890: manually break a reference cycle
  219. self = None
  220. # pylint: enable=inconsistent-return-statements
  221. # Copied from
  222. # https://github.com/python/cpython/blob/master/Lib/unittest/case.py
  223. # at commit 293dd23 (Nov 19, 2019).
  224. # Required at least to enable assertWarns() in python <3.2.
  225. # Added in 0.4.0.
  226. class _AssertWarnsContext(_AssertRaisesBaseContext):
  227. """A context manager used to implement TestCase.assertWarns* methods."""
  228. _base_type = Warning
  229. _base_type_str = 'a warning type or tuple of warning types'
  230. # Added in 0.4.0.
  231. def __enter__(self):
  232. # The __warningregistry__'s need to be in a pristine state for tests
  233. # to work properly.
  234. # pylint: disable=invalid-name, attribute-defined-outside-init
  235. for v in sys.modules.values():
  236. if getattr(v, '__warningregistry__', None):
  237. v.__warningregistry__ = {}
  238. self.warnings_manager = warnings.catch_warnings(record=True)
  239. self.warnings = self.warnings_manager.__enter__()
  240. warnings.simplefilter("always", self.expected)
  241. return self
  242. # Added in 0.4.0.
  243. def __exit__(self, exc_type, exc_value, tb):
  244. # pylint: disable=invalid-name, attribute-defined-outside-init
  245. self.warnings_manager.__exit__(exc_type, exc_value, tb)
  246. if exc_type is not None:
  247. # let unexpected exceptions pass through
  248. return
  249. try:
  250. exc_name = self.expected.__name__
  251. except AttributeError:
  252. exc_name = str(self.expected)
  253. first_matching = None
  254. for m in self.warnings:
  255. w = m.message
  256. if not isinstance(w, self.expected):
  257. continue
  258. if first_matching is None:
  259. first_matching = w
  260. if (self.expected_regex is not None and
  261. not self.expected_regex.search(str(w))):
  262. continue
  263. # store warning for later retrieval
  264. self.warning = w
  265. self.filename = m.filename
  266. self.lineno = m.lineno
  267. return
  268. # Now we simply try to choose a helpful failure message
  269. if first_matching is not None:
  270. self._raiseFailure('"{}" does not match "{}"'.format(
  271. self.expected_regex.pattern, str(first_matching)))
  272. if self.obj_name:
  273. self._raiseFailure("{} not triggered by {}".format(exc_name,
  274. self.obj_name))
  275. else:
  276. self._raiseFailure("{} not triggered".format(exc_name))
  277. # Partially copied from
  278. # https://github.com/python/cpython/blob/master/Lib/unittest/case.py
  279. # at commit 293dd23 (Nov 19, 2019).
  280. # Required at least to enable assertWarns() in python <3.2.
  281. def assertWarns(testcase, expected_warning, *args, **kwargs):
  282. """Context with same functionality as ``assertWarns`` in ``unittest``.
  283. Note that ``assertWarns`` is only available in python 3.2+.
  284. Added in 0.4.0.
  285. """
  286. # pylint: disable=invalid-name
  287. context = _AssertWarnsContext(expected_warning, testcase)
  288. return context.handle('assertWarns', args, kwargs)