test_layers.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469
  1. import pickle
  2. from collections import namedtuple
  3. import numpy
  4. import pytest
  5. from einops import rearrange, reduce, EinopsError
  6. from einops.tests import collect_test_backends, is_backend_tested, FLOAT_REDUCTIONS as REDUCTIONS
  7. __author__ = "Alex Rogozhnikov"
  8. testcase = namedtuple("testcase", ["pattern", "axes_lengths", "input_shape", "wrong_shapes"])
  9. rearrangement_patterns = [
  10. testcase(
  11. "b c h w -> b (c h w)",
  12. dict(c=20),
  13. (10, 20, 30, 40),
  14. [(), (10,), (10, 10, 10), (10, 21, 30, 40), [1, 20, 1, 1, 1]],
  15. ),
  16. testcase(
  17. "b c (h1 h2) (w1 w2) -> b (c h2 w2) h1 w1",
  18. dict(h2=2, w2=2),
  19. (10, 20, 30, 40),
  20. [(), (1, 1, 1, 1), (1, 10, 3), ()],
  21. ),
  22. testcase(
  23. "b ... c -> c b ...",
  24. dict(b=10),
  25. (10, 20, 30),
  26. [(), (10,), (5, 10)],
  27. ),
  28. ]
  29. def test_rearrange_imperative():
  30. for backend in collect_test_backends(symbolic=False, layers=True):
  31. print("Test layer for ", backend.framework_name)
  32. for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns:
  33. x = numpy.arange(numpy.prod(input_shape), dtype="float32").reshape(input_shape)
  34. result_numpy = rearrange(x, pattern, **axes_lengths)
  35. layer = backend.layers().Rearrange(pattern, **axes_lengths)
  36. for shape in wrong_shapes:
  37. try:
  38. layer(backend.from_numpy(numpy.zeros(shape, dtype="float32")))
  39. except BaseException:
  40. pass
  41. else:
  42. raise AssertionError("Failure expected")
  43. # simple pickling / unpickling
  44. layer2 = pickle.loads(pickle.dumps(layer))
  45. result1 = backend.to_numpy(layer(backend.from_numpy(x)))
  46. result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
  47. assert numpy.allclose(result_numpy, result1)
  48. assert numpy.allclose(result1, result2)
  49. just_sum = backend.layers().Reduce("...->", reduction="sum")
  50. variable = backend.from_numpy(x)
  51. result = just_sum(layer(variable))
  52. result.backward()
  53. assert numpy.allclose(backend.to_numpy(variable.grad), 1)
  54. def test_rearrange_symbolic():
  55. for backend in collect_test_backends(symbolic=True, layers=True):
  56. print("Test layer for ", backend.framework_name)
  57. for pattern, axes_lengths, input_shape, wrong_shapes in rearrangement_patterns:
  58. x = numpy.arange(numpy.prod(input_shape), dtype="float32").reshape(input_shape)
  59. result_numpy = rearrange(x, pattern, **axes_lengths)
  60. layer = backend.layers().Rearrange(pattern, **axes_lengths)
  61. input_shape_of_nones = [None] * len(input_shape)
  62. shapes = [input_shape, input_shape_of_nones]
  63. for shape in shapes:
  64. symbol = backend.create_symbol(shape)
  65. eval_inputs = [(symbol, x)]
  66. result_symbol1 = layer(symbol)
  67. result1 = backend.eval_symbol(result_symbol1, eval_inputs)
  68. assert numpy.allclose(result_numpy, result1)
  69. layer2 = pickle.loads(pickle.dumps(layer))
  70. result_symbol2 = layer2(symbol)
  71. result2 = backend.eval_symbol(result_symbol2, eval_inputs)
  72. assert numpy.allclose(result1, result2)
  73. # now testing back-propagation
  74. just_sum = backend.layers().Reduce("...->", reduction="sum")
  75. result_sum1 = backend.eval_symbol(just_sum(result_symbol1), eval_inputs)
  76. result_sum2 = numpy.sum(x)
  77. assert numpy.allclose(result_sum1, result_sum2)
  78. reduction_patterns = rearrangement_patterns + [
  79. testcase("b c h w -> b ()", dict(b=10), (10, 20, 30, 40), [(10,), (10, 20, 30)]),
  80. testcase("b c (h1 h2) (w1 w2) -> b c h1 w1", dict(h1=15, h2=2, w2=2), (10, 20, 30, 40), [(10, 20, 31, 40)]),
  81. testcase("b ... c -> b", dict(b=10), (10, 20, 30, 40), [(10,), (11, 10)]),
  82. ]
  83. def test_reduce_imperative():
  84. for backend in collect_test_backends(symbolic=False, layers=True):
  85. print("Test layer for ", backend.framework_name)
  86. for reduction in REDUCTIONS:
  87. for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns:
  88. print(backend, reduction, pattern, axes_lengths, input_shape, wrong_shapes)
  89. x = numpy.arange(1, 1 + numpy.prod(input_shape), dtype="float32").reshape(input_shape)
  90. x /= x.mean()
  91. result_numpy = reduce(x, pattern, reduction, **axes_lengths)
  92. layer = backend.layers().Reduce(pattern, reduction, **axes_lengths)
  93. for shape in wrong_shapes:
  94. try:
  95. layer(backend.from_numpy(numpy.zeros(shape, dtype="float32")))
  96. except BaseException:
  97. pass
  98. else:
  99. raise AssertionError("Failure expected")
  100. # simple pickling / unpickling
  101. layer2 = pickle.loads(pickle.dumps(layer))
  102. result1 = backend.to_numpy(layer(backend.from_numpy(x)))
  103. result2 = backend.to_numpy(layer2(backend.from_numpy(x)))
  104. assert numpy.allclose(result_numpy, result1)
  105. assert numpy.allclose(result1, result2)
  106. just_sum = backend.layers().Reduce("...->", reduction="sum")
  107. variable = backend.from_numpy(x)
  108. result = just_sum(layer(variable))
  109. result.backward()
  110. grad = backend.to_numpy(variable.grad)
  111. if reduction == "sum":
  112. assert numpy.allclose(grad, 1)
  113. if reduction == "mean":
  114. assert numpy.allclose(grad, grad.min())
  115. if reduction in ["max", "min"]:
  116. assert numpy.all(numpy.in1d(grad, [0, 1]))
  117. assert numpy.sum(grad) > 0.5
  118. def test_reduce_symbolic():
  119. for backend in collect_test_backends(symbolic=True, layers=True):
  120. print("Test layer for ", backend.framework_name)
  121. for reduction in REDUCTIONS:
  122. for pattern, axes_lengths, input_shape, wrong_shapes in reduction_patterns:
  123. x = numpy.arange(1, 1 + numpy.prod(input_shape), dtype="float32").reshape(input_shape)
  124. x /= x.mean()
  125. result_numpy = reduce(x, pattern, reduction, **axes_lengths)
  126. layer = backend.layers().Reduce(pattern, reduction, **axes_lengths)
  127. input_shape_of_nones = [None] * len(input_shape)
  128. shapes = [input_shape, input_shape_of_nones]
  129. for shape in shapes:
  130. symbol = backend.create_symbol(shape)
  131. eval_inputs = [(symbol, x)]
  132. result_symbol1 = layer(symbol)
  133. result1 = backend.eval_symbol(result_symbol1, eval_inputs)
  134. assert numpy.allclose(result_numpy, result1)
  135. layer2 = pickle.loads(pickle.dumps(layer))
  136. result_symbol2 = layer2(symbol)
  137. result2 = backend.eval_symbol(result_symbol2, eval_inputs)
  138. assert numpy.allclose(result1, result2)
  139. def create_torch_model(use_reduce=False, add_scripted_layer=False):
  140. if not is_backend_tested("torch"):
  141. pytest.skip()
  142. else:
  143. from torch.nn import Sequential, Conv2d, MaxPool2d, Linear, ReLU
  144. from einops.layers.torch import Rearrange, Reduce, EinMix
  145. import torch.jit
  146. return Sequential(
  147. Conv2d(3, 6, kernel_size=(5, 5)),
  148. Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2) if use_reduce else MaxPool2d(kernel_size=2),
  149. Conv2d(6, 16, kernel_size=(5, 5)),
  150. Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2),
  151. torch.jit.script(Rearrange("b c h w -> b (c h w)"))
  152. if add_scripted_layer
  153. else Rearrange("b c h w -> b (c h w)"),
  154. Linear(16 * 5 * 5, 120),
  155. ReLU(),
  156. Linear(120, 84),
  157. ReLU(),
  158. EinMix("b c1 -> (b c2)", weight_shape="c1 c2", bias_shape="c2", c1=84, c2=84),
  159. EinMix("(b c2) -> b c3", weight_shape="c2 c3", bias_shape="c3", c2=84, c3=84),
  160. Linear(84, 10),
  161. )
  162. def test_torch_layer():
  163. if not is_backend_tested("torch"):
  164. pytest.skip()
  165. else:
  166. # checked that torch present
  167. import torch
  168. import torch.jit
  169. model1 = create_torch_model(use_reduce=True)
  170. model2 = create_torch_model(use_reduce=False)
  171. input = torch.randn([10, 3, 32, 32])
  172. # random models have different predictions
  173. assert not torch.allclose(model1(input), model2(input))
  174. model2.load_state_dict(pickle.loads(pickle.dumps(model1.state_dict())))
  175. assert torch.allclose(model1(input), model2(input))
  176. # tracing (freezing)
  177. model3 = torch.jit.trace(model2, example_inputs=input)
  178. torch.testing.assert_close(model1(input), model3(input), atol=1e-3, rtol=1e-3)
  179. torch.testing.assert_close(model1(input + 1), model3(input + 1), atol=1e-3, rtol=1e-3)
  180. model4 = torch.jit.trace(model2, example_inputs=input)
  181. torch.testing.assert_close(model1(input), model4(input), atol=1e-3, rtol=1e-3)
  182. torch.testing.assert_close(model1(input + 1), model4(input + 1), atol=1e-3, rtol=1e-3)
  183. def test_torch_layers_scripting():
  184. if not is_backend_tested("torch"):
  185. pytest.skip()
  186. else:
  187. import torch
  188. for script_layer in [False, True]:
  189. model1 = create_torch_model(use_reduce=True, add_scripted_layer=script_layer)
  190. model2 = torch.jit.script(model1)
  191. input = torch.randn([10, 3, 32, 32])
  192. torch.testing.assert_close(model1(input), model2(input), atol=1e-3, rtol=1e-3)
  193. def test_keras_layer():
  194. if not is_backend_tested("tensorflow"):
  195. pytest.skip()
  196. else:
  197. import tensorflow as tf
  198. if tf.__version__ < "2.16.":
  199. # current implementation of layers follows new TF interface
  200. pytest.skip()
  201. from tensorflow.keras.models import Sequential
  202. from tensorflow.keras.layers import Conv2D as Conv2d, Dense as Linear, ReLU
  203. from einops.layers.keras import Rearrange, Reduce, EinMix, keras_custom_objects
  204. def create_keras_model():
  205. return Sequential(
  206. [
  207. Conv2d(6, kernel_size=5, input_shape=[32, 32, 3]),
  208. Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2),
  209. Conv2d(16, kernel_size=5),
  210. Reduce("b c (h h2) (w w2) -> b c h w", "max", h2=2, w2=2),
  211. Rearrange("b c h w -> b (c h w)"),
  212. Linear(120),
  213. ReLU(),
  214. Linear(84),
  215. ReLU(),
  216. EinMix("b c1 -> (b c2)", weight_shape="c1 c2", bias_shape="c2", c1=84, c2=84),
  217. EinMix("(b c2) -> b c3", weight_shape="c2 c3", bias_shape="c3", c2=84, c3=84),
  218. Linear(10),
  219. ]
  220. )
  221. model1 = create_keras_model()
  222. model2 = create_keras_model()
  223. input = numpy.random.normal(size=[10, 32, 32, 3]).astype("float32")
  224. # two randomly init models should provide different outputs
  225. assert not numpy.allclose(model1.predict_on_batch(input), model2.predict_on_batch(input))
  226. # get some temp filename
  227. tmp_model_filename = "/tmp/einops_tf_model.h5"
  228. # save arch + weights
  229. print("temp_path_keras1", tmp_model_filename)
  230. tf.keras.models.save_model(model1, tmp_model_filename)
  231. model3 = tf.keras.models.load_model(tmp_model_filename, custom_objects=keras_custom_objects)
  232. numpy.testing.assert_allclose(model1.predict_on_batch(input), model3.predict_on_batch(input))
  233. weight_filename = "/tmp/einops_tf_model.weights.h5"
  234. # save arch as json
  235. model4 = tf.keras.models.model_from_json(model1.to_json(), custom_objects=keras_custom_objects)
  236. model1.save_weights(weight_filename)
  237. model4.load_weights(weight_filename)
  238. model2.load_weights(weight_filename)
  239. # check that differently-inialized model receives same weights
  240. numpy.testing.assert_allclose(model1.predict_on_batch(input), model2.predict_on_batch(input))
  241. # ulimate test
  242. # save-load architecture, and then load weights - should return same result
  243. numpy.testing.assert_allclose(model1.predict_on_batch(input), model4.predict_on_batch(input))
  244. def test_flax_layers():
  245. """
  246. One-off simple tests for Flax layers.
  247. Unfortunately, Flax layers have a different interface from other layers.
  248. """
  249. if not is_backend_tested("jax"):
  250. pytest.skip()
  251. else:
  252. import jax
  253. import jax.numpy as jnp
  254. import flax
  255. from flax import linen as nn
  256. from einops.layers.flax import EinMix, Reduce, Rearrange
  257. class NN(nn.Module):
  258. @nn.compact
  259. def __call__(self, x):
  260. x = EinMix(
  261. "b (h h2) (w w2) c -> b h w c_out", "h2 w2 c c_out", "c_out", sizes=dict(h2=2, w2=3, c=4, c_out=5)
  262. )(x)
  263. x = Rearrange("b h w c -> b (w h c)", sizes=dict(c=5))(x)
  264. x = Reduce("b hwc -> b", "mean", dict(hwc=2 * 3 * 5))(x)
  265. return x
  266. model = NN()
  267. fixed_input = jnp.ones([10, 2 * 2, 3 * 3, 4])
  268. params = model.init(jax.random.PRNGKey(0), fixed_input)
  269. def eval_at_point(params):
  270. return jnp.linalg.norm(model.apply(params, fixed_input))
  271. vandg = jax.value_and_grad(eval_at_point)
  272. value0 = eval_at_point(params)
  273. value1, grad1 = vandg(params)
  274. assert jnp.allclose(value0, value1)
  275. params2 = jax.tree_map(lambda x1, x2: x1 - x2 * 0.001, params, grad1)
  276. value2 = eval_at_point(params2)
  277. assert value0 >= value2, (value0, value2)
  278. # check serialization
  279. fbytes = flax.serialization.to_bytes(params)
  280. _loaded = flax.serialization.from_bytes(params, fbytes)
  281. def test_einmix_decomposition():
  282. """
  283. Testing that einmix correctly decomposes into smaller transformations.
  284. """
  285. from einops.layers._einmix import _EinmixDebugger
  286. mixin1 = _EinmixDebugger(
  287. "a b c d e -> e d c b a",
  288. weight_shape="d a b",
  289. d=2, a=3, b=5,
  290. ) # fmt: off
  291. assert mixin1.pre_reshape_pattern is None
  292. assert mixin1.post_reshape_pattern is None
  293. assert mixin1.einsum_pattern == "abcde,dab->edcba"
  294. assert mixin1.saved_weight_shape == [2, 3, 5]
  295. assert mixin1.saved_bias_shape is None
  296. mixin2 = _EinmixDebugger(
  297. "a b c d e -> e d c b a",
  298. weight_shape="d a b",
  299. bias_shape="a b c d e",
  300. a=1, b=2, c=3, d=4, e=5,
  301. ) # fmt: off
  302. assert mixin2.pre_reshape_pattern is None
  303. assert mixin2.post_reshape_pattern is None
  304. assert mixin2.einsum_pattern == "abcde,dab->edcba"
  305. assert mixin2.saved_weight_shape == [4, 1, 2]
  306. assert mixin2.saved_bias_shape == [5, 4, 3, 2, 1]
  307. mixin3 = _EinmixDebugger(
  308. "... -> ...",
  309. weight_shape="",
  310. bias_shape="",
  311. ) # fmt: off
  312. assert mixin3.pre_reshape_pattern is None
  313. assert mixin3.post_reshape_pattern is None
  314. assert mixin3.einsum_pattern == "...,->..."
  315. assert mixin3.saved_weight_shape == []
  316. assert mixin3.saved_bias_shape == []
  317. mixin4 = _EinmixDebugger(
  318. "b a ... -> b c ...",
  319. weight_shape="b a c",
  320. a=1, b=2, c=3,
  321. ) # fmt: off
  322. assert mixin4.pre_reshape_pattern is None
  323. assert mixin4.post_reshape_pattern is None
  324. assert mixin4.einsum_pattern == "ba...,bac->bc..."
  325. assert mixin4.saved_weight_shape == [2, 1, 3]
  326. assert mixin4.saved_bias_shape is None
  327. mixin5 = _EinmixDebugger(
  328. "(b a) ... -> b c (...)",
  329. weight_shape="b a c",
  330. a=1, b=2, c=3,
  331. ) # fmt: off
  332. assert mixin5.pre_reshape_pattern == "(b a) ... -> b a ..."
  333. assert mixin5.pre_reshape_lengths == dict(a=1, b=2)
  334. assert mixin5.post_reshape_pattern == "b c ... -> b c (...)"
  335. assert mixin5.einsum_pattern == "ba...,bac->bc..."
  336. assert mixin5.saved_weight_shape == [2, 1, 3]
  337. assert mixin5.saved_bias_shape is None
  338. mixin6 = _EinmixDebugger(
  339. "b ... (a c) -> b ... (a d)",
  340. weight_shape="c d",
  341. bias_shape="a d",
  342. a=1, c=3, d=4,
  343. ) # fmt: off
  344. assert mixin6.pre_reshape_pattern == "b ... (a c) -> b ... a c"
  345. assert mixin6.pre_reshape_lengths == dict(a=1, c=3)
  346. assert mixin6.post_reshape_pattern == "b ... a d -> b ... (a d)"
  347. assert mixin6.einsum_pattern == "b...ac,cd->b...ad"
  348. assert mixin6.saved_weight_shape == [3, 4]
  349. assert mixin6.saved_bias_shape == [1, 1, 4] # (b) a d, ellipsis does not participate
  350. mixin7 = _EinmixDebugger(
  351. "a ... (b c) -> a (... d b)",
  352. weight_shape="c d b",
  353. bias_shape="d b",
  354. b=2, c=3, d=4,
  355. ) # fmt: off
  356. assert mixin7.pre_reshape_pattern == "a ... (b c) -> a ... b c"
  357. assert mixin7.pre_reshape_lengths == dict(b=2, c=3)
  358. assert mixin7.post_reshape_pattern == "a ... d b -> a (... d b)"
  359. assert mixin7.einsum_pattern == "a...bc,cdb->a...db"
  360. assert mixin7.saved_weight_shape == [3, 4, 2]
  361. assert mixin7.saved_bias_shape == [1, 4, 2] # (a) d b, ellipsis does not participate
  362. def test_einmix_restrictions():
  363. """
  364. Testing different cases
  365. """
  366. from einops.layers._einmix import _EinmixDebugger
  367. with pytest.raises(EinopsError):
  368. _EinmixDebugger(
  369. "a b c d e -> e d c b a",
  370. weight_shape="d a b",
  371. d=2, a=3, # missing b
  372. ) # fmt: off
  373. with pytest.raises(EinopsError):
  374. _EinmixDebugger(
  375. "a b c d e -> e d c b a",
  376. weight_shape="w a b",
  377. d=2, a=3, b=1 # missing d
  378. ) # fmt: off
  379. with pytest.raises(EinopsError):
  380. _EinmixDebugger(
  381. "(...) a -> ... a",
  382. weight_shape="a", a=1, # ellipsis on the left
  383. ) # fmt: off
  384. with pytest.raises(EinopsError):
  385. _EinmixDebugger(
  386. "(...) a -> a ...",
  387. weight_shape="a", a=1, # ellipsis on the right side after bias axis
  388. bias_shape='a',
  389. ) # fmt: off