test_other.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291
  1. from doctest import testmod
  2. import numpy
  3. import pytest
  4. import einops
  5. import einops.layers
  6. import einops.parsing
  7. from einops._backends import AbstractBackend
  8. from einops.einops import rearrange, parse_shape, _optimize_transformation
  9. from einops.tests import collect_test_backends, is_backend_tested
# Module-level authorship metadata.
__author__ = "Alex Rogozhnikov"
  11. def test_doctests_examples():
  12. # tests docstrings, additionally
  13. testmod(einops.layers, raise_on_error=True, extraglobs=dict(np=numpy))
  14. testmod(einops.einops, raise_on_error=True, extraglobs=dict(np=numpy))
  15. def test_backends_installed():
  16. """
  17. This test will fail if some of backends are not installed or can't be imported
  18. Other tests will just work and only test installed backends.
  19. """
  20. from . import parse_backends_to_test
  21. backends_to_test = parse_backends_to_test()
  22. errors = []
  23. for backend_type in AbstractBackend.__subclasses__():
  24. if backend_type.framework_name not in backends_to_test:
  25. continue
  26. try:
  27. # instantiate
  28. backend_type()
  29. except Exception as e:
  30. errors.append((backend_type.framework_name, e))
  31. assert len(errors) == 0, errors
  32. def test_optimize_transformations_numpy():
  33. print("Testing optimizations")
  34. shapes = [[2] * n_dimensions for n_dimensions in range(14)]
  35. shapes += [[3] * n_dimensions for n_dimensions in range(6)]
  36. shapes += [[2, 3, 5, 7]]
  37. shapes += [[2, 3, 5, 7, 11, 17]]
  38. for shape in shapes:
  39. for attempt in range(5):
  40. n_dimensions = len(shape)
  41. x = numpy.random.randint(0, 2**12, size=shape).reshape([-1])
  42. init_shape = shape[:]
  43. n_reduced = numpy.random.randint(0, n_dimensions + 1)
  44. reduced_axes = tuple(numpy.random.permutation(n_dimensions)[:n_reduced])
  45. axes_reordering = numpy.random.permutation(n_dimensions - n_reduced)
  46. final_shape = numpy.random.randint(0, 1024, size=333) # just random
  47. init_shape2, reduced_axes2, axes_reordering2, final_shape2 = combination2 = _optimize_transformation(
  48. init_shape, reduced_axes, axes_reordering, final_shape
  49. )
  50. assert numpy.array_equal(final_shape, final_shape2)
  51. result1 = x.reshape(init_shape).sum(axis=reduced_axes).transpose(axes_reordering).reshape([-1])
  52. result2 = x.reshape(init_shape2).sum(axis=reduced_axes2).transpose(axes_reordering2).reshape([-1])
  53. assert numpy.array_equal(result1, result2)
  54. # testing we can't optimize this formula again
  55. combination3 = _optimize_transformation(*combination2)
  56. for a, b in zip(combination2, combination3):
  57. assert numpy.array_equal(a, b)
# Backends used by the imperative parse_shape tests in this module
# (collected with symbolic=False, layers=False).
_IMPERATIVE_BACKENDS = collect_test_backends(symbolic=False, layers=False)

# Shared numpy input of shape (10, 20, 30, 40) for the parse_shape tests.
x_np = numpy.zeros([10, 20, 30, 40])
  60. def test_parse_shape_imperative():
  61. for backend in _IMPERATIVE_BACKENDS:
  62. print("Shape parsing for ", backend.framework_name)
  63. parsed1 = parse_shape(x_np, "a b c d")
  64. parsed2 = parse_shape(backend.from_numpy(x_np), "a b c d")
  65. assert parsed1 == parsed2 == dict(a=10, b=20, c=30, d=40)
  66. assert parsed1 != dict(a=1, b=20, c=30, d=40) != parsed2
  67. def test_underscore():
  68. for backend in _IMPERATIVE_BACKENDS:
  69. parsed1 = parse_shape(x_np, "_ _ _ _")
  70. parsed2 = parse_shape(backend.from_numpy(x_np), "_ _ _ _")
  71. assert parsed1 == parsed2 == dict()
  72. def test_underscore_one():
  73. for backend in _IMPERATIVE_BACKENDS:
  74. parsed1 = parse_shape(x_np, "_ _ _ hello")
  75. parsed2 = parse_shape(backend.from_numpy(x_np), "_ _ _ hello")
  76. assert parsed1 == parsed2 == dict(hello=40)
  77. def test_underscore_several():
  78. for backend in _IMPERATIVE_BACKENDS:
  79. parsed1 = parse_shape(x_np, "_ _ a1 a1a111a")
  80. parsed2 = parse_shape(backend.from_numpy(x_np), "_ _ a1 a1a111a")
  81. assert parsed1 == parsed2 == dict(a1=30, a1a111a=40)
  82. def test_repeating():
  83. with pytest.raises(einops.EinopsError):
  84. parse_shape(x_np, "a a b b")
  85. for backend in _IMPERATIVE_BACKENDS:
  86. with pytest.raises(einops.EinopsError):
  87. parse_shape(backend.from_numpy(x_np), "a a b b")
  88. def test_ellipsis():
  89. for backend in _IMPERATIVE_BACKENDS:
  90. for shape, pattern, expected in [
  91. ([10, 20], "...", dict()),
  92. ([10], "... a", dict(a=10)),
  93. ([10, 20], "... a", dict(a=20)),
  94. ([10, 20, 30], "... a", dict(a=30)),
  95. ([10, 20, 30, 40], "... a", dict(a=40)),
  96. ([10], "a ...", dict(a=10)),
  97. ([10, 20], "a ...", dict(a=10)),
  98. ([10, 20, 30], "a ...", dict(a=10)),
  99. ([10, 20, 30, 40], "a ...", dict(a=10)),
  100. ([10, 20, 30, 40], " a ... b", dict(a=10, b=40)),
  101. ([10, 40], " a ... b", dict(a=10, b=40)),
  102. ]:
  103. x = numpy.ones(shape)
  104. parsed1 = parse_shape(x, pattern)
  105. parsed2 = parse_shape(backend.from_numpy(x), pattern)
  106. assert parsed1 == parsed2 == expected
  107. def test_parse_with_anonymous_axes():
  108. for backend in _IMPERATIVE_BACKENDS:
  109. for shape, pattern, expected in [
  110. ([1, 2, 3, 4], "1 2 3 a", dict(a=4)),
  111. ([10, 1, 2], "a 1 2", dict(a=10)),
  112. ([10, 1, 2], "a () 2", dict(a=10)),
  113. ]:
  114. x = numpy.ones(shape)
  115. parsed1 = parse_shape(x, pattern)
  116. parsed2 = parse_shape(backend.from_numpy(x), pattern)
  117. assert parsed1 == parsed2 == expected
  118. def test_failures():
  119. for backend in _IMPERATIVE_BACKENDS:
  120. # every test should fail
  121. for shape, pattern in [
  122. ([1, 2, 3, 4], "a b c"),
  123. ([1, 2, 3, 4], "2 a b c"),
  124. ([1, 2, 3, 4], "a b c ()"),
  125. ([1, 2, 3, 4], "a b c d e"),
  126. ([1, 2, 3, 4], "a b c d e ..."),
  127. ([1, 2, 3, 4], "a b c ()"),
  128. ]:
  129. with pytest.raises(RuntimeError):
  130. x = numpy.ones(shape)
  131. parse_shape(backend.from_numpy(x), pattern)
  132. _SYMBOLIC_BACKENDS = [
  133. *collect_test_backends(symbolic=True, layers=False),
  134. *collect_test_backends(symbolic=True, layers=True),
  135. ]
  136. # tensorflow.keras needs special way to compile,
  137. # shape vars can be used only inside layers but not as outputs
  138. _SYMBOLIC_BACKENDS = [backend for backend in _SYMBOLIC_BACKENDS if backend.framework_name != "tensorflow.keras"]
  139. @pytest.mark.parametrize("backend", _SYMBOLIC_BACKENDS)
  140. def test_parse_shape_symbolic(backend):
  141. for shape in [
  142. [10, 20, 30, 40],
  143. [10, 20, None, None],
  144. [None, None, None, None],
  145. ]:
  146. print(
  147. f"special shape parsing {backend.framework_name=} {shape=}",
  148. )
  149. input_symbol = backend.create_symbol(shape)
  150. shape_placeholder = parse_shape(input_symbol, "a b c d")
  151. shape = {}
  152. for name, symbol in shape_placeholder.items():
  153. shape[name] = (
  154. symbol
  155. if isinstance(symbol, int)
  156. else backend.eval_symbol(symbol, [(input_symbol, numpy.zeros([10, 20, 30, 40]))])
  157. )
  158. print(shape)
  159. result_placeholder = rearrange(
  160. input_symbol, "a b (c1 c2) (d1 d2) -> (a b d1) c1 (c2 d2)", **parse_shape(input_symbol, "a b c1 _"), d2=2
  161. )
  162. result = backend.eval_symbol(result_placeholder, [(input_symbol, numpy.zeros([10, 20, 30, 40]))])
  163. print(result.shape)
  164. assert result.shape == (10 * 20 * 20, 30, 1 * 2)
  165. assert numpy.allclose(result, 0)
  166. @pytest.mark.parametrize("backend", _SYMBOLIC_BACKENDS)
  167. def test_parse_shape_symbolic_ellipsis(backend):
  168. for static_shape, shape, pattern, expected in [
  169. ([10, 20], [None, None], "...", dict()),
  170. ([10], [None], "... a", dict(a=10)),
  171. ([10, 20], [None, None], "... a", dict(a=20)),
  172. ([10, 20, 30], [None, None, None], "... a", dict(a=30)),
  173. ([10, 20, 30, 40], [None, None, None, None], "... a", dict(a=40)),
  174. ([10], [None], "a ...", dict(a=10)),
  175. ([10, 20], [None, None], "a ...", dict(a=10)),
  176. ([10, 20, 30], [None, None, None], "a ...", dict(a=10)),
  177. ([10, 20, 30, 40], [None, None, None, None], "a ...", dict(a=10)),
  178. ([10, 20, 30, 40], [None, None, None, None], " a ... b", dict(a=10, b=40)),
  179. ([10, 40], [None, None], " a ... b ", dict(a=10, b=40)),
  180. ]:
  181. input_symbol = backend.create_symbol(shape)
  182. shape_placeholder = parse_shape(input_symbol, pattern)
  183. out_shape = {}
  184. for name, symbol in shape_placeholder.items():
  185. if isinstance(symbol, int):
  186. out_shape[name] = symbol
  187. else:
  188. out_shape[name] = backend.eval_symbol(symbol, [(input_symbol, numpy.zeros(static_shape))])
  189. assert out_shape == expected
  190. def test_is_float_type():
  191. backends = collect_test_backends(symbolic=False, layers=False)
  192. backends += collect_test_backends(symbolic=False, layers=True)
  193. for backend in backends:
  194. for dtype in ["int32", "int64", "float32", "float64"]:
  195. is_float = "float" in dtype
  196. input = numpy.zeros([3, 4, 5], dtype=dtype)
  197. input = backend.from_numpy(input)
  198. assert backend.is_float_type(input) == is_float, (dtype, backend, input.dtype)
def test_torch_compile():
    """
    Test ensures that allow_ops_in_compiled_graph allows compiling in a single graph
    Additionally we ensure that after compilation cache works properly
    (by changing shapes and patterns)
    We additionally check that pack/unpack still can be handled
    despite variable number of inputs/outputs
    """
    # nothing to check when the torch backend is not among the tested ones
    if not is_backend_tested("torch"):
        pytest.skip()
    # torch-dependent imports happen only after the skip check above
    import torch
    from torch import nn
    from einops import repeat, reduce, pack, unpack, einsum
    from einops._torch_specific import allow_ops_in_compiled_graph

    allow_ops_in_compiled_graph()

    class TorchModuleWithOperations(nn.Module):
        def __init__(self) -> None:
            super().__init__()

        def forward(self, x_abc, suffix=""):
            # only b is used below (as an explicit axis length for rearrange)
            a, b, c = x_abc.shape

            def suf(pattern):
                # append `suffix` to every whitespace-separated token ending in a, c or d
                parts = pattern.split()
                return " ".join([p if p[-1] not in "acd" else p + suffix for p in parts])

            # patterns look a bit strange because names a, c, d will be modified on every run
            # by suf function
            x_abcd = repeat(x_abc, suf("a b c -> a b c 4"))
            x_abc = reduce(x_abcd, suf("a b c d -> a b c"), "min")
            # number of packed tensors depends on the suffix length
            x_abdc, ps = pack([x_abc] * (2 + len(suffix)), suf("a b * c"))
            # the unpack pattern is passed as is, without going through suf
            x_array = unpack(rearrange(x_abdc, suf("a b d c -> (a b ) 1 c d")), ps, "ab one1 c *")
            x1 = x_array[0] + len(x_array)
            x1 = rearrange(x1, suf("(a b ) 1 c -> a b c"), b=b)
            addition = einsum(x_abc, x_abcd, suf("a b c , a b c d -> d"))[0]
            return x1 + addition

    original = TorchModuleWithOperations()
    # fullgraph=True asks for compilation as a single graph (see docstring)
    compiled = torch.compile(original, fullgraph=True, backend="aot_eager")
    # vary both input sizes and patterns; compiled and eager outputs must agree
    for size in [10, 20, 40]:
        x = torch.rand([size, size + 1, size + 2])
        for suffix in ["", "suf1", "other_suffix"]:
            result1 = compiled(x, suffix)
            result2 = original(x, suffix)
            assert torch.allclose(result1, result2)