test_packing.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309
  1. import dataclasses
  2. import typing
  3. import numpy as np
  4. import pytest
  5. from einops import EinopsError, asnumpy, pack, unpack
  6. from einops.tests import collect_test_backends
  7. def pack_unpack(xs, pattern):
  8. x, ps = pack(xs, pattern)
  9. unpacked = unpack(xs, ps, pattern)
  10. assert len(unpacked) == len(xs)
  11. for a, b in zip(unpacked, xs):
  12. assert np.allclose(asnumpy(a), asnumpy(b))
  13. def unpack_and_pack(x, ps, pattern: str):
  14. unpacked = unpack(x, ps, pattern)
  15. packed, ps2 = pack(unpacked, pattern=pattern)
  16. assert np.allclose(asnumpy(packed), asnumpy(x))
  17. return unpacked
  18. def unpack_and_pack_against_numpy(x, ps, pattern: str):
  19. capturer_backend = CaptureException()
  20. capturer_numpy = CaptureException()
  21. with capturer_backend:
  22. unpacked = unpack(x, ps, pattern)
  23. packed, ps2 = pack(unpacked, pattern=pattern)
  24. with capturer_numpy:
  25. x_np = asnumpy(x)
  26. unpacked_np = unpack(x_np, ps, pattern)
  27. packed_np, ps3 = pack(unpacked_np, pattern=pattern)
  28. assert type(capturer_numpy.exception) == type(capturer_backend.exception) # noqa E721
  29. if capturer_numpy.exception is not None:
  30. # both failed
  31. return
  32. else:
  33. # neither failed, check results are identical
  34. assert np.allclose(asnumpy(packed), asnumpy(x))
  35. assert np.allclose(asnumpy(packed_np), asnumpy(x))
  36. assert len(unpacked) == len(unpacked_np)
  37. for a, b in zip(unpacked, unpacked_np):
  38. assert np.allclose(asnumpy(a), b)
  39. class CaptureException:
  40. def __enter__(self):
  41. self.exception = None
  42. def __exit__(self, exc_type, exc_val, exc_tb):
  43. self.exception = exc_val
  44. return True
  45. def test_numpy_trivial(H=13, W=17):
  46. def rand(*shape):
  47. return np.random.random(shape)
  48. def check(a, b):
  49. assert a.dtype == b.dtype
  50. assert a.shape == b.shape
  51. assert np.all(a == b)
  52. r, g, b = rand(3, H, W)
  53. embeddings = rand(H, W, 32)
  54. check(
  55. np.stack([r, g, b], axis=2),
  56. pack([r, g, b], "h w *")[0],
  57. )
  58. check(
  59. np.stack([r, g, b], axis=1),
  60. pack([r, g, b], "h * w")[0],
  61. )
  62. check(
  63. np.stack([r, g, b], axis=0),
  64. pack([r, g, b], "* h w")[0],
  65. )
  66. check(
  67. np.concatenate([r, g, b], axis=1),
  68. pack([r, g, b], "h *")[0],
  69. )
  70. check(
  71. np.concatenate([r, g, b], axis=0),
  72. pack([r, g, b], "* w")[0],
  73. )
  74. i = np.index_exp[:, :, None]
  75. check(
  76. np.concatenate([r[i], g[i], b[i], embeddings], axis=2),
  77. pack([r, g, b, embeddings], "h w *")[0],
  78. )
  79. with pytest.raises(EinopsError):
  80. pack([r, g, b, embeddings], "h w nonexisting_axis *")
  81. pack([r, g, b], "some_name_for_H some_name_for_w1 *")
  82. with pytest.raises(EinopsError):
  83. pack([r, g, b, embeddings], "h _w *") # no leading underscore
  84. with pytest.raises(EinopsError):
  85. pack([r, g, b, embeddings], "h_ w *") # no trailing underscore
  86. with pytest.raises(EinopsError):
  87. pack([r, g, b, embeddings], "1h_ w *")
  88. with pytest.raises(EinopsError):
  89. pack([r, g, b, embeddings], "1 w *")
  90. with pytest.raises(EinopsError):
  91. pack([r, g, b, embeddings], "h h *")
  92. # capital and non-capital are different
  93. pack([r, g, b, embeddings], "h H *")
  94. @dataclasses.dataclass
  95. class UnpackTestCase:
  96. shape: typing.Tuple[int, ...]
  97. pattern: str
  98. def dim(self):
  99. return self.pattern.split().index("*")
  100. def selfcheck(self):
  101. assert self.shape[self.dim()] == 5
  102. cases = [
  103. # NB: in all cases unpacked axis is of length 5.
  104. # that's actively used in tests below
  105. UnpackTestCase((5,), "*"),
  106. UnpackTestCase((5, 7), "* seven"),
  107. UnpackTestCase((7, 5), "seven *"),
  108. UnpackTestCase((5, 3, 4), "* three four"),
  109. UnpackTestCase((4, 5, 3), "four * three"),
  110. UnpackTestCase((3, 4, 5), "three four *"),
  111. ]
  112. def test_pack_unpack_with_numpy():
  113. case: UnpackTestCase
  114. for case in cases:
  115. shape = case.shape
  116. pattern = case.pattern
  117. x = np.random.random(shape)
  118. # all correct, no minus 1
  119. unpack_and_pack(x, [[2], [1], [2]], pattern)
  120. # no -1, asking for wrong shapes
  121. with pytest.raises(BaseException):
  122. unpack_and_pack(x, [[2], [1], [2]], pattern + " non_existent_axis")
  123. with pytest.raises(BaseException):
  124. unpack_and_pack(x, [[2], [1], [1]], pattern)
  125. with pytest.raises(BaseException):
  126. unpack_and_pack(x, [[4], [1], [1]], pattern)
  127. # all correct, with -1
  128. unpack_and_pack(x, [[2], [1], [-1]], pattern)
  129. unpack_and_pack(x, [[2], [-1], [2]], pattern)
  130. unpack_and_pack(x, [[-1], [1], [2]], pattern)
  131. _, _, last = unpack_and_pack(x, [[2], [3], [-1]], pattern)
  132. assert last.shape[case.dim()] == 0
  133. # asking for more elements than available
  134. with pytest.raises(BaseException):
  135. unpack(x, [[2], [4], [-1]], pattern)
  136. # this one does not raise, because indexing x[2:1] just returns zero elements
  137. # with pytest.raises(BaseException):
  138. # unpack(x, [[2], [-1], [4]], pattern)
  139. with pytest.raises(BaseException):
  140. unpack(x, [[-1], [1], [5]], pattern)
  141. # all correct, -1 nested
  142. rs = unpack_and_pack(x, [[1, 2], [1, 1], [-1, 1]], pattern)
  143. assert all(len(r.shape) == len(x.shape) + 1 for r in rs)
  144. rs = unpack_and_pack(x, [[1, 2], [1, -1], [1, 1]], pattern)
  145. assert all(len(r.shape) == len(x.shape) + 1 for r in rs)
  146. rs = unpack_and_pack(x, [[2, -1], [1, 2], [1, 1]], pattern)
  147. assert all(len(r.shape) == len(x.shape) + 1 for r in rs)
  148. # asking for more elements, -1 nested
  149. with pytest.raises(BaseException):
  150. unpack(x, [[-1, 2], [1], [5]], pattern)
  151. with pytest.raises(BaseException):
  152. unpack(x, [[2, 2], [2], [5, -1]], pattern)
  153. # asking for non-divisible number of elements
  154. with pytest.raises(BaseException):
  155. unpack(x, [[2, 1], [1], [3, -1]], pattern)
  156. with pytest.raises(BaseException):
  157. unpack(x, [[2, 1], [3, -1], [1]], pattern)
  158. with pytest.raises(BaseException):
  159. unpack(x, [[3, -1], [2, 1], [1]], pattern)
  160. # -1 takes zero
  161. unpack_and_pack(x, [[0], [5], [-1]], pattern)
  162. unpack_and_pack(x, [[0], [-1], [5]], pattern)
  163. unpack_and_pack(x, [[-1], [5], [0]], pattern)
  164. # -1 takes zero, -1
  165. unpack_and_pack(x, [[2, -1], [1, 5]], pattern)
  166. def test_pack_unpack_against_numpy():
  167. for backend in collect_test_backends(symbolic=False, layers=False):
  168. print(f"test packing against numpy for {backend.framework_name}")
  169. check_zero_len = True
  170. for case in cases:
  171. unpack_and_pack = unpack_and_pack_against_numpy
  172. shape = case.shape
  173. pattern = case.pattern
  174. x = np.random.random(shape)
  175. x = backend.from_numpy(x)
  176. # all correct, no minus 1
  177. unpack_and_pack(x, [[2], [1], [2]], pattern)
  178. # no -1, asking for wrong shapes
  179. with pytest.raises(BaseException):
  180. unpack(x, [[2], [1], [1]], pattern)
  181. with pytest.raises(BaseException):
  182. unpack(x, [[4], [1], [1]], pattern)
  183. # all correct, with -1
  184. unpack_and_pack(x, [[2], [1], [-1]], pattern)
  185. unpack_and_pack(x, [[2], [-1], [2]], pattern)
  186. unpack_and_pack(x, [[-1], [1], [2]], pattern)
  187. # asking for more elements than available
  188. with pytest.raises(BaseException):
  189. unpack(x, [[2], [4], [-1]], pattern)
  190. # this one does not raise, because indexing x[2:1] just returns zero elements
  191. # with pytest.raises(BaseException):
  192. # unpack(x, [[2], [-1], [4]], pattern)
  193. with pytest.raises(BaseException):
  194. unpack(x, [[-1], [1], [5]], pattern)
  195. # all correct, -1 nested
  196. unpack_and_pack(x, [[1, 2], [1, 1], [-1, 1]], pattern)
  197. unpack_and_pack(x, [[1, 2], [1, -1], [1, 1]], pattern)
  198. unpack_and_pack(x, [[2, -1], [1, 2], [1, 1]], pattern)
  199. # asking for more elements, -1 nested
  200. with pytest.raises(BaseException):
  201. unpack(x, [[-1, 2], [1], [5]], pattern)
  202. with pytest.raises(BaseException):
  203. unpack(x, [[2, 2], [2], [5, -1]], pattern)
  204. # asking for non-divisible number of elements
  205. with pytest.raises(BaseException):
  206. unpack(x, [[2, 1], [1], [3, -1]], pattern)
  207. with pytest.raises(BaseException):
  208. unpack(x, [[2, 1], [3, -1], [1]], pattern)
  209. with pytest.raises(BaseException):
  210. unpack(x, [[3, -1], [2, 1], [1]], pattern)
  211. if check_zero_len:
  212. # -1 takes zero
  213. unpack_and_pack(x, [[2], [3], [-1]], pattern)
  214. unpack_and_pack(x, [[0], [5], [-1]], pattern)
  215. unpack_and_pack(x, [[0], [-1], [5]], pattern)
  216. unpack_and_pack(x, [[-1], [5], [0]], pattern)
  217. # -1 takes zero, -1
  218. unpack_and_pack(x, [[2, -1], [1, 5]], pattern)
  219. def test_pack_unpack_array_api():
  220. from einops import array_api as AA
  221. import numpy as xp
  222. if xp.__version__ < "2.0.0":
  223. pytest.skip()
  224. for case in cases:
  225. shape = case.shape
  226. pattern = case.pattern
  227. x_np = np.random.random(shape)
  228. x_xp = xp.from_dlpack(x_np)
  229. for ps in [
  230. [[2], [1], [2]],
  231. [[1], [1], [-1]],
  232. [[1], [1], [-1, 3]],
  233. [[2, 1], [1, 1, 1], [-1]],
  234. ]:
  235. x_np_split = unpack(x_np, ps, pattern)
  236. x_xp_split = AA.unpack(x_xp, ps, pattern)
  237. for a, b in zip(x_np_split, x_xp_split):
  238. assert np.allclose(a, AA.asnumpy(b + 0))
  239. x_agg_np, ps1 = pack(x_np_split, pattern)
  240. x_agg_xp, ps2 = AA.pack(x_xp_split, pattern)
  241. assert ps1 == ps2
  242. assert np.allclose(x_agg_np, AA.asnumpy(x_agg_xp))
  243. for ps in [
  244. [[2, 3]],
  245. [[1], [5]],
  246. [[1], [5], [-1]],
  247. [[1], [2, 3]],
  248. [[1], [5], [-1, 2]],
  249. ]:
  250. with pytest.raises(BaseException):
  251. unpack(x_np, ps, pattern)