test_einsum.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352
  1. from typing import Any, Callable
  2. from einops.tests import collect_test_backends
  3. from einops.einops import _compactify_pattern_for_einsum, einsum, EinopsError
  4. import numpy as np
  5. import pytest
  6. import string
  7. class Arguments:
  8. def __init__(self, *args: Any, **kargs: Any):
  9. self.args = args
  10. self.kwargs = kargs
  11. def __call__(self, function: Callable):
  12. return function(*self.args, **self.kwargs)
  13. test_layer_cases = [
  14. (
  15. Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape=None, c_out=13, c_in=12),
  16. (2, 12, 3, 4),
  17. (4, 13, 3, 2),
  18. ),
  19. (
  20. Arguments("b c_in h w -> w c_out h b", "c_in c_out", bias_shape="c_out", c_out=13, c_in=12),
  21. (2, 12, 3, 4),
  22. (4, 13, 3, 2),
  23. ),
  24. (
  25. Arguments("b c_in h w -> w c_in h b", "", bias_shape=None, c_in=12),
  26. (2, 12, 3, 4),
  27. (4, 12, 3, 2),
  28. ),
  29. (
  30. Arguments("b c_in h w -> b c_out", "c_in h w c_out", bias_shape=None, c_in=12, h=3, w=4, c_out=5),
  31. (2, 12, 3, 4),
  32. (2, 5),
  33. ),
  34. (
  35. Arguments("b t head c_in -> b t head c_out", "head c_in c_out", bias_shape=None, head=4, c_in=5, c_out=6),
  36. (2, 3, 4, 5),
  37. (2, 3, 4, 6),
  38. ),
  39. ]
  40. # Each of the form:
  41. # (Arguments, true_einsum_pattern, in_shapes, out_shape)
  42. test_functional_cases = [
  43. (
  44. # Basic:
  45. "b c h w, b w -> b h",
  46. "abcd,ad->ac",
  47. ((2, 3, 4, 5), (2, 5)),
  48. (2, 4),
  49. ),
  50. (
  51. # Three tensors:
  52. "b c h w, b w, b c -> b h",
  53. "abcd,ad,ab->ac",
  54. ((2, 3, 40, 5), (2, 5), (2, 3)),
  55. (2, 40),
  56. ),
  57. (
  58. # Ellipsis, and full names:
  59. "... one two three, three four five -> ... two five",
  60. "...abc,cde->...be",
  61. ((32, 5, 2, 3, 4), (4, 5, 6)),
  62. (32, 5, 3, 6),
  63. ),
  64. (
  65. # Ellipsis at the end:
  66. "one two three ..., three four five -> two five ...",
  67. "abc...,cde->be...",
  68. ((2, 3, 4, 32, 5), (4, 5, 6)),
  69. (3, 6, 32, 5),
  70. ),
  71. (
  72. # Ellipsis on multiple tensors:
  73. "... one two three, ... three four five -> ... two five",
  74. "...abc,...cde->...be",
  75. ((32, 5, 2, 3, 4), (32, 5, 4, 5, 6)),
  76. (32, 5, 3, 6),
  77. ),
  78. (
  79. # One tensor, and underscores:
  80. "first_tensor second_tensor -> first_tensor",
  81. "ab->a",
  82. ((5, 4),),
  83. (5,),
  84. ),
  85. (
  86. # Trace (repeated index)
  87. "i i -> ",
  88. "aa->",
  89. ((5, 5),),
  90. (),
  91. ),
  92. (
  93. # Too many spaces in string:
  94. " one two , three four->two four ",
  95. "ab,cd->bd",
  96. ((2, 3), (4, 5)),
  97. (3, 5),
  98. ),
  99. # The following tests were inspired by numpy's einsum tests
  100. # https://github.com/numpy/numpy/blob/v1.23.0/numpy/core/tests/test_einsum.py
  101. (
  102. # Trace with other indices
  103. "i middle i -> middle",
  104. "aba->b",
  105. ((5, 10, 5),),
  106. (10,),
  107. ),
  108. (
  109. # Ellipsis in the middle:
  110. "i ... i -> ...",
  111. "a...a->...",
  112. ((5, 3, 2, 1, 4, 5),),
  113. (3, 2, 1, 4),
  114. ),
  115. (
  116. # Product of first and last axes:
  117. "i ... i -> i ...",
  118. "a...a->a...",
  119. ((5, 3, 2, 1, 4, 5),),
  120. (5, 3, 2, 1, 4),
  121. ),
  122. (
  123. # Triple diagonal
  124. "one one one -> one",
  125. "aaa->a",
  126. ((5, 5, 5),),
  127. (5,),
  128. ),
  129. (
  130. # Axis swap:
  131. "i j k -> j i k",
  132. "abc->bac",
  133. ((1, 2, 3),),
  134. (2, 1, 3),
  135. ),
  136. (
  137. # Identity:
  138. "... -> ...",
  139. "...->...",
  140. ((5, 4, 3, 2, 1),),
  141. (5, 4, 3, 2, 1),
  142. ),
  143. (
  144. # Elementwise product of three tensors
  145. "..., ..., ... -> ...",
  146. "...,...,...->...",
  147. ((3, 2), (3, 2), (3, 2)),
  148. (3, 2),
  149. ),
  150. (
  151. # Basic summation:
  152. "index ->",
  153. "a->",
  154. ((10,)),
  155. (()),
  156. ),
  157. ]
  158. def test_layer():
  159. for backend in collect_test_backends(layers=True, symbolic=False):
  160. if backend.framework_name in ["tensorflow", "torch", "oneflow", "paddle"]:
  161. layer_type = backend.layers().EinMix
  162. for args, in_shape, out_shape in test_layer_cases:
  163. layer = args(layer_type)
  164. print("Running", layer.einsum_pattern, "for", backend.framework_name)
  165. input = np.random.uniform(size=in_shape).astype("float32")
  166. input_framework = backend.from_numpy(input)
  167. output_framework = layer(input_framework)
  168. output = backend.to_numpy(output_framework)
  169. assert output.shape == out_shape
  170. valid_backends_functional = [
  171. "tensorflow",
  172. "torch",
  173. "jax",
  174. "numpy",
  175. "oneflow",
  176. "cupy",
  177. "tensorflow.keras",
  178. "paddle",
  179. "pytensor",
  180. ]
  181. def test_functional():
  182. # Functional tests:
  183. backends = filter(lambda x: x.framework_name in valid_backends_functional, collect_test_backends())
  184. for backend in backends:
  185. for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
  186. print(f"Running '{einops_pattern}' for {backend.framework_name}")
  187. # Create pattern:
  188. predicted_pattern = _compactify_pattern_for_einsum(einops_pattern)
  189. assert predicted_pattern == true_pattern
  190. # Generate example data:
  191. rstate = np.random.RandomState(0)
  192. in_arrays = [rstate.uniform(size=shape).astype("float32") for shape in in_shapes]
  193. in_arrays_framework = [backend.from_numpy(array) for array in in_arrays]
  194. # Loop over whether we call it manually with the backend,
  195. # or whether we use `einops.einsum`.
  196. for do_manual_call in [True, False]:
  197. # Actually run einsum:
  198. if do_manual_call:
  199. out_array = backend.einsum(predicted_pattern, *in_arrays_framework)
  200. else:
  201. out_array = einsum(*in_arrays_framework, einops_pattern)
  202. # Check shape:
  203. if tuple(out_array.shape) != out_shape:
  204. raise ValueError(f"Expected output shape {out_shape} but got {out_array.shape}")
  205. # Check values:
  206. true_out_array = np.einsum(true_pattern, *in_arrays)
  207. predicted_out_array = backend.to_numpy(out_array)
  208. np.testing.assert_array_almost_equal(predicted_out_array, true_out_array, decimal=5)
  209. def test_functional_symbolic():
  210. backends = filter(
  211. lambda x: x.framework_name in valid_backends_functional, collect_test_backends(symbolic=True, layers=False)
  212. )
  213. for backend in backends:
  214. for einops_pattern, true_pattern, in_shapes, out_shape in test_functional_cases:
  215. print(f"Running '{einops_pattern}' for symbolic {backend.framework_name}")
  216. # Create pattern:
  217. predicted_pattern = _compactify_pattern_for_einsum(einops_pattern)
  218. assert predicted_pattern == true_pattern
  219. rstate = np.random.RandomState(0)
  220. in_syms = [backend.create_symbol(in_shape) for in_shape in in_shapes]
  221. in_data = [rstate.uniform(size=in_shape).astype("float32") for in_shape in in_shapes]
  222. expected_out_data = np.einsum(true_pattern, *in_data)
  223. for do_manual_call in [True, False]:
  224. if do_manual_call:
  225. predicted_out_symbol = backend.einsum(predicted_pattern, *in_syms)
  226. else:
  227. predicted_out_symbol = einsum(*in_syms, einops_pattern)
  228. predicted_out_data = backend.eval_symbol(
  229. predicted_out_symbol,
  230. list(zip(in_syms, in_data)),
  231. )
  232. if predicted_out_data.shape != out_shape:
  233. raise ValueError(f"Expected output shape {out_shape} but got {predicted_out_data.shape}")
  234. np.testing.assert_array_almost_equal(predicted_out_data, expected_out_data, decimal=5)
  235. def test_functional_errors():
  236. # Specific backend does not matter, as errors are raised
  237. # during the pattern creation.
  238. rstate = np.random.RandomState(0)
  239. def create_tensor(*shape):
  240. return rstate.uniform(size=shape).astype("float32")
  241. # raise NotImplementedError("Singleton () axes are not yet supported in einsum.")
  242. with pytest.raises(NotImplementedError, match="^Singleton"):
  243. einsum(
  244. create_tensor(5, 1),
  245. "i () -> i",
  246. )
  247. # raise NotImplementedError("Shape rearrangement is not yet supported in einsum.")
  248. with pytest.raises(NotImplementedError, match="^Shape rearrangement"):
  249. einsum(
  250. create_tensor(5, 1),
  251. "a b -> (a b)",
  252. )
  253. with pytest.raises(NotImplementedError, match="^Shape rearrangement"):
  254. einsum(
  255. create_tensor(10, 1),
  256. "(a b) -> a b",
  257. )
  258. # raise RuntimeError("Encountered empty axis name in einsum.")
  259. # raise RuntimeError("Axis name in einsum must be a string.")
  260. # ^ Not tested, these are just a failsafe in case an unexpected error occurs.
  261. # raise NotImplementedError("Anonymous axes are not yet supported in einsum.")
  262. with pytest.raises(NotImplementedError, match="^Anonymous axes"):
  263. einsum(
  264. create_tensor(5, 1),
  265. "i 2 -> i",
  266. )
  267. # ParsedExpression error:
  268. with pytest.raises(EinopsError, match="^Invalid axis identifier"):
  269. einsum(
  270. create_tensor(5, 1),
  271. "i 2j -> i",
  272. )
  273. # raise ValueError("Einsum pattern must contain '->'.")
  274. with pytest.raises(ValueError, match="^Einsum pattern"):
  275. einsum(
  276. create_tensor(5, 3, 2),
  277. "i j k",
  278. )
  279. # raise RuntimeError("Too many axes in einsum.")
  280. with pytest.raises(RuntimeError, match="^Too many axes"):
  281. einsum(
  282. create_tensor(1),
  283. " ".join(string.ascii_letters) + " extra ->",
  284. )
  285. # raise RuntimeError("Unknown axis on right side of einsum.")
  286. with pytest.raises(RuntimeError, match="^Unknown axis"):
  287. einsum(
  288. create_tensor(5, 1),
  289. "i j -> k",
  290. )
  291. # raise ValueError(
  292. # "The last argument passed to `einops.einsum` must be a string,"
  293. # " representing the einsum pattern."
  294. # )
  295. with pytest.raises(ValueError, match="^The last argument"):
  296. einsum(
  297. "i j k -> i",
  298. create_tensor(5, 4, 3),
  299. )
  300. # raise ValueError(
  301. # "`einops.einsum` takes at minimum two arguments: the tensors,"
  302. # " followed by the pattern."
  303. # )
  304. with pytest.raises(ValueError, match="^`einops.einsum` takes"):
  305. einsum(
  306. "i j k -> i",
  307. )
  308. with pytest.raises(ValueError, match="^`einops.einsum` takes"):
  309. einsum(
  310. create_tensor(5, 1),
  311. )
  312. # TODO: Include check for giving normal einsum pattern rather than einops.