_backends.py 21 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719
"""
Backends in `einops` are organized to meet the following requirements:

- backends are not imported unless they are actually needed, because
    - backends may not be installed
    - importing all available backends would lead to a significant memory footprint
    - backends may be present but installed with errors (and never used),
      so importing them may lead to crashes
- each backend is either symbolic or imperative
    - this determines which methods (from_numpy/to_numpy or create_symbol/eval_symbol) should be defined
- if a backend can't provide symbols for shape dimensions, UnknownSize objects are used
"""

import sys

__author__ = "Alex Rogozhnikov"

# framework_name -> backend instance; filled lazily by get_backend as backends are instantiated
_loaded_backends: dict = {}
# exact tensor type -> backend instance; the first cache consulted by get_backend
_type2backend: dict = {}
# when True, get_backend prints which backend subclasses it tests and instantiates
_debug_importing = False
  17. def get_backend(tensor) -> "AbstractBackend":
  18. """
  19. Takes a correct backend (e.g. numpy backend if tensor is numpy.ndarray) for a tensor.
  20. If needed, imports package and creates backend
  21. """
  22. _type = type(tensor)
  23. _result = _type2backend.get(_type, None)
  24. if _result is not None:
  25. return _result
  26. for framework_name, backend in list(_loaded_backends.items()):
  27. if backend.is_appropriate_type(tensor):
  28. _type2backend[_type] = backend
  29. return backend
  30. # Find backend subclasses recursively
  31. backend_subclasses = []
  32. backends = AbstractBackend.__subclasses__()
  33. while backends:
  34. backend = backends.pop()
  35. backends += backend.__subclasses__()
  36. backend_subclasses.append(backend)
  37. for BackendSubclass in backend_subclasses:
  38. if _debug_importing:
  39. print("Testing for subclass of ", BackendSubclass)
  40. if BackendSubclass.framework_name not in _loaded_backends:
  41. # check that module was already imported. Otherwise it can't be imported
  42. if BackendSubclass.framework_name in sys.modules:
  43. if _debug_importing:
  44. print("Imported backend for ", BackendSubclass.framework_name)
  45. backend = BackendSubclass()
  46. _loaded_backends[backend.framework_name] = backend
  47. if backend.is_appropriate_type(tensor):
  48. _type2backend[_type] = backend
  49. return backend
  50. raise RuntimeError("Tensor type unknown to einops {}".format(type(tensor)))
  51. class AbstractBackend:
  52. """Base backend class, major part of methods are only for debugging purposes."""
  53. framework_name: str
  54. def is_appropriate_type(self, tensor):
  55. """helper method should recognize tensors it can handle"""
  56. raise NotImplementedError()
  57. def from_numpy(self, x):
  58. raise NotImplementedError("framework doesn't support imperative execution")
  59. def to_numpy(self, x):
  60. raise NotImplementedError("framework doesn't support imperative execution")
  61. def create_symbol(self, shape):
  62. raise NotImplementedError("framework doesn't support symbolic computations")
  63. def eval_symbol(self, symbol, symbol_value_pairs):
  64. # symbol-value pairs is list[tuple[symbol, value-tensor]]
  65. raise NotImplementedError("framework doesn't support symbolic computations")
  66. def arange(self, start, stop):
  67. # supplementary method used only in testing, so should implement CPU version
  68. raise NotImplementedError("framework doesn't implement arange")
  69. def shape(self, x):
  70. """shape should return a tuple with integers or "shape symbols" (which will evaluate to actual size)"""
  71. return x.shape
  72. def reshape(self, x, shape):
  73. return x.reshape(shape)
  74. def transpose(self, x, axes):
  75. return x.transpose(axes)
  76. def reduce(self, x, operation, axes):
  77. return getattr(x, operation)(axis=axes)
  78. def stack_on_zeroth_dimension(self, tensors: list):
  79. raise NotImplementedError()
  80. def add_axis(self, x, new_position):
  81. raise NotImplementedError()
  82. def add_axes(self, x, n_axes, pos2len):
  83. repeats = [1] * n_axes
  84. for axis_position, axis_length in pos2len.items():
  85. x = self.add_axis(x, axis_position)
  86. repeats[axis_position] = axis_length
  87. return self.tile(x, tuple(repeats))
  88. def tile(self, x, repeats):
  89. """repeats - same lengths as x.shape"""
  90. raise NotImplementedError()
  91. def concat(self, tensors, axis: int):
  92. """concatenates tensors along axis.
  93. Assume identical across tensors: devices, dtypes and shapes except selected axis."""
  94. raise NotImplementedError()
  95. def is_float_type(self, x):
  96. # some backends (torch) can't compute average for non-floating types.
  97. # Decided to drop average for all backends if type is not floating
  98. raise NotImplementedError()
  99. def layers(self):
  100. raise NotImplementedError("backend does not provide layers")
  101. def __repr__(self):
  102. return "<einops backend for {}>".format(self.framework_name)
  103. def einsum(self, pattern, *x):
  104. raise NotImplementedError("backend does not support einsum")
  105. class UnknownSize:
  106. """pseudo-symbol for symbolic frameworks which do not provide symbols for shape elements"""
  107. def __floordiv__(self, other):
  108. return self
  109. def __eq__(self, other):
  110. return True # we don't know actual size
  111. def __mul__(self, other):
  112. return self
  113. def __rmul__(self, other):
  114. return self
  115. def __hash__(self):
  116. return hash(None)
  117. class NumpyBackend(AbstractBackend):
  118. framework_name = "numpy"
  119. def __init__(self):
  120. import numpy
  121. self.np = numpy
  122. def is_appropriate_type(self, tensor):
  123. return isinstance(tensor, self.np.ndarray)
  124. def from_numpy(self, x):
  125. return x
  126. def to_numpy(self, x):
  127. return x
  128. def arange(self, start, stop):
  129. return self.np.arange(start, stop)
  130. def stack_on_zeroth_dimension(self, tensors: list):
  131. return self.np.stack(tensors)
  132. def tile(self, x, repeats):
  133. return self.np.tile(x, repeats)
  134. def concat(self, tensors, axis: int):
  135. return self.np.concatenate(tensors, axis=axis)
  136. def is_float_type(self, x):
  137. return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")
  138. def add_axis(self, x, new_position):
  139. return self.np.expand_dims(x, new_position)
  140. def einsum(self, pattern, *x):
  141. return self.np.einsum(pattern, *x)
  142. class JaxBackend(NumpyBackend):
  143. framework_name = "jax"
  144. def __init__(self):
  145. super(JaxBackend, self).__init__()
  146. self.onp = self.np
  147. import jax.numpy
  148. self.np = jax.numpy
  149. def from_numpy(self, x):
  150. return self.np.asarray(x)
  151. def to_numpy(self, x):
  152. return self.onp.asarray(x)
  153. class TorchBackend(AbstractBackend):
  154. framework_name = "torch"
  155. def __init__(self):
  156. import torch
  157. self.torch = torch
  158. # importing would register operations in torch._dynamo for torch.compile
  159. from . import _torch_specific # noqa
  160. def is_appropriate_type(self, tensor):
  161. return isinstance(tensor, self.torch.Tensor)
  162. def from_numpy(self, x):
  163. variable = self.torch.from_numpy(x)
  164. if self.is_float_type(variable):
  165. # attach grad only to floating types
  166. variable.requires_grad = True
  167. return variable
  168. def to_numpy(self, x):
  169. return x.detach().cpu().numpy()
  170. def arange(self, start, stop):
  171. return self.torch.arange(start, stop, dtype=self.torch.int64)
  172. def reduce(self, x, operation, reduced_axes):
  173. if operation == "min":
  174. return x.amin(dim=reduced_axes)
  175. elif operation == "max":
  176. return x.amax(dim=reduced_axes)
  177. elif operation == "sum":
  178. return x.sum(dim=reduced_axes)
  179. elif operation == "mean":
  180. return x.mean(dim=reduced_axes)
  181. elif operation in ("any", "all", "prod"):
  182. # pytorch supports reducing only one operation at a time
  183. for i in list(sorted(reduced_axes))[::-1]:
  184. x = getattr(x, operation)(dim=i)
  185. return x
  186. else:
  187. raise NotImplementedError("Unknown reduction ", operation)
  188. def transpose(self, x, axes):
  189. return x.permute(axes)
  190. def stack_on_zeroth_dimension(self, tensors: list):
  191. return self.torch.stack(tensors)
  192. def add_axes(self, x, n_axes, pos2len):
  193. repeats = [-1] * n_axes
  194. for axis_position, axis_length in pos2len.items():
  195. x = self.add_axis(x, axis_position)
  196. repeats[axis_position] = axis_length
  197. return x.expand(repeats)
  198. def tile(self, x, repeats):
  199. return x.repeat(repeats)
  200. def concat(self, tensors, axis: int):
  201. return self.torch.cat(tensors, dim=axis)
  202. def add_axis(self, x, new_position):
  203. return self.torch.unsqueeze(x, new_position)
  204. def is_float_type(self, x):
  205. return x.dtype in [self.torch.float16, self.torch.float32, self.torch.float64, self.torch.bfloat16]
  206. def layers(self):
  207. from .layers import torch
  208. return torch
  209. def einsum(self, pattern, *x):
  210. return self.torch.einsum(pattern, *x)
  211. class CupyBackend(AbstractBackend):
  212. framework_name = "cupy"
  213. def __init__(self):
  214. import cupy
  215. self.cupy = cupy
  216. def is_appropriate_type(self, tensor):
  217. return isinstance(tensor, self.cupy.ndarray)
  218. def from_numpy(self, x):
  219. return self.cupy.asarray(x)
  220. def to_numpy(self, x):
  221. return self.cupy.asnumpy(x)
  222. def arange(self, start, stop):
  223. return self.cupy.arange(start, stop)
  224. def stack_on_zeroth_dimension(self, tensors: list):
  225. return self.cupy.stack(tensors)
  226. def tile(self, x, repeats):
  227. return self.cupy.tile(x, repeats)
  228. def concat(self, tensors, axis: int):
  229. return self.cupy.concatenate(tensors, axis=axis)
  230. def add_axis(self, x, new_position):
  231. return self.cupy.expand_dims(x, new_position)
  232. def is_float_type(self, x):
  233. return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")
  234. def einsum(self, pattern, *x):
  235. return self.cupy.einsum(pattern, *x)
  236. class HashableTuple:
  237. """Overcomes non-hashability of symbolic elements"""
  238. def __init__(self, elements: tuple):
  239. self.elements = elements
  240. def __iter__(self):
  241. for x in self.elements:
  242. yield x
  243. def __len__(self):
  244. return len(self.elements)
  245. def __getitem__(self, item):
  246. return self.elements[item]
  247. # default equality and hash is used (True only with itself, hash taken of id)
  248. class TensorflowBackend(AbstractBackend):
  249. framework_name = "tensorflow"
  250. def __init__(self):
  251. import tensorflow
  252. self.tf = tensorflow
  253. def is_appropriate_type(self, tensor):
  254. return isinstance(tensor, (self.tf.Tensor, self.tf.Variable))
  255. def from_numpy(self, x):
  256. assert self.tf.executing_eagerly()
  257. return self.tf.convert_to_tensor(x)
  258. def to_numpy(self, x):
  259. assert self.tf.executing_eagerly()
  260. return x.numpy()
  261. def arange(self, start, stop):
  262. return self.tf.range(start, stop)
  263. def shape(self, x):
  264. if self.tf.executing_eagerly():
  265. return tuple(UnknownSize() if d is None else int(d) for d in x.shape)
  266. else:
  267. static_shape = x.shape.as_list()
  268. tf_shape = self.tf.shape(x)
  269. # use the static shape where known, otherwise use the TF shape components
  270. shape = tuple([s or tf_shape[dim] for dim, s in enumerate(static_shape)])
  271. try:
  272. hash(shape)
  273. return shape
  274. except BaseException:
  275. # unhashable symbols in shape. Wrap tuple to be hashable.
  276. return HashableTuple(shape)
  277. def reduce(self, x, operation, axes):
  278. return getattr(self.tf, "reduce_" + operation)(x, axis=axes)
  279. def reshape(self, x, shape):
  280. return self.tf.reshape(x, shape)
  281. def transpose(self, x, axes):
  282. return self.tf.transpose(x, axes)
  283. def stack_on_zeroth_dimension(self, tensors: list):
  284. return self.tf.stack(tensors)
  285. def tile(self, x, repeats):
  286. return self.tf.tile(x, repeats)
  287. def concat(self, tensors, axis: int):
  288. return self.tf.concat(tensors, axis=axis)
  289. def add_axis(self, x, new_position):
  290. return self.tf.expand_dims(x, new_position)
  291. def is_float_type(self, x):
  292. return x.dtype in ("float16", "float32", "float64", "float128", "bfloat16")
  293. def layers(self):
  294. from .layers import tensorflow
  295. return tensorflow
  296. def einsum(self, pattern, *x):
  297. return self.tf.einsum(pattern, *x)
  298. class TFKerasBackend(AbstractBackend):
  299. framework_name = "tensorflow.keras"
  300. def __init__(self):
  301. import tensorflow as tf
  302. self.tf = tf
  303. self.keras = tf.keras
  304. self.K = tf.keras.backend
  305. def is_appropriate_type(self, tensor):
  306. return self.tf.is_tensor(tensor) and self.K.is_keras_tensor(tensor)
  307. def create_symbol(self, shape):
  308. return self.keras.Input(batch_shape=shape)
  309. def eval_symbol(self, symbol, symbol_value_pairs):
  310. model = self.keras.models.Model([var for (var, _) in symbol_value_pairs], symbol)
  311. return model.predict_on_batch([val for (_, val) in symbol_value_pairs])
  312. def arange(self, start, stop):
  313. return self.K.arange(start, stop)
  314. def shape(self, x):
  315. shape = self.K.shape(x) # tf tensor
  316. return HashableTuple(tuple(shape))
  317. def reduce(self, x, operation, axes):
  318. return getattr(self.K, operation)(x, axis=axes)
  319. def reshape(self, x, shape):
  320. return self.K.reshape(x, shape)
  321. def transpose(self, x, axes):
  322. return self.K.permute_dimensions(x, axes)
  323. def stack_on_zeroth_dimension(self, tensors: list):
  324. return self.K.stack(tensors)
  325. def tile(self, x, repeats):
  326. return self.K.tile(x, repeats)
  327. def concat(self, tensors, axis: int):
  328. return self.K.concatenate(tensors, axis=axis)
  329. def add_axis(self, x, new_position):
  330. return self.K.expand_dims(x, new_position)
  331. def is_float_type(self, x):
  332. return "float" in self.K.dtype(x)
  333. def layers(self):
  334. from .layers import keras
  335. return keras
  336. class OneFlowBackend(AbstractBackend):
  337. framework_name = "oneflow"
  338. def __init__(self):
  339. import oneflow as flow
  340. self.flow = flow
  341. def is_appropriate_type(self, tensor):
  342. return isinstance(tensor, self.flow.Tensor)
  343. def from_numpy(self, x):
  344. variable = self.flow.from_numpy(x)
  345. if self.is_float_type(variable):
  346. # attach grad only to floating types
  347. variable.requires_grad = True
  348. return variable
  349. def to_numpy(self, x):
  350. return x.detach().cpu().numpy()
  351. def arange(self, start, stop):
  352. return self.flow.arange(start, stop, dtype=self.flow.int64)
  353. def reduce(self, x, operation, reduced_axes):
  354. for axis in sorted(reduced_axes, reverse=True):
  355. if operation == "min":
  356. x, _ = x.min(dim=axis)
  357. elif operation == "max":
  358. x, _ = x.max(dim=axis)
  359. elif operation in ["sum", "mean", "prod", "any", "all"]:
  360. x = getattr(x, operation)(dim=axis)
  361. else:
  362. raise NotImplementedError("Unknown reduction ", operation)
  363. return x
  364. def transpose(self, x, axes):
  365. return x.permute(axes)
  366. def stack_on_zeroth_dimension(self, tensors: list):
  367. return self.flow.stack(tensors)
  368. def add_axes(self, x, n_axes, pos2len):
  369. repeats = [-1] * n_axes
  370. for axis_position, axis_length in pos2len.items():
  371. x = self.add_axis(x, axis_position)
  372. repeats[axis_position] = axis_length
  373. return x.expand(*repeats)
  374. def tile(self, x, repeats):
  375. return x.repeat(repeats)
  376. def concat(self, tensors, axis: int):
  377. return self.flow.concat(tensors, dim=axis)
  378. def add_axis(self, x, new_position):
  379. return self.flow.unsqueeze(x, new_position)
  380. def is_float_type(self, x):
  381. return x.dtype in [self.flow.float16, self.flow.float32, self.flow.float64]
  382. def layers(self):
  383. from .layers import oneflow
  384. return oneflow
  385. def einsum(self, pattern, *x):
  386. return self.flow.einsum(pattern, *x)
  387. class PaddleBackend(AbstractBackend):
  388. framework_name = "paddle"
  389. def __init__(self):
  390. import paddle
  391. self.paddle = paddle
  392. def is_appropriate_type(self, tensor):
  393. return self.paddle.is_tensor(tensor)
  394. def from_numpy(self, x):
  395. tensor = self.paddle.to_tensor(x)
  396. tensor.stop_gradient = False
  397. return tensor
  398. def to_numpy(self, x):
  399. return x.detach().numpy()
  400. def arange(self, start, stop):
  401. return self.paddle.arange(start, stop, dtype=self.paddle.int64)
  402. def reduce(self, x, operation, axes):
  403. if len(axes) == x.ndim:
  404. # currently paddle returns 1d tensor instead of 0d
  405. return super().reduce(x, operation, axes).squeeze(0)
  406. else:
  407. return super().reduce(x, operation, axes)
  408. def transpose(self, x, axes):
  409. return x.transpose(axes)
  410. def add_axes(self, x, n_axes, pos2len):
  411. repeats = [-1] * n_axes
  412. for axis_position, axis_length in pos2len.items():
  413. x = self.add_axis(x, axis_position)
  414. repeats[axis_position] = axis_length
  415. return x.expand(repeats)
  416. def stack_on_zeroth_dimension(self, tensors: list):
  417. return self.paddle.stack(tensors)
  418. def reshape(self, x, shape):
  419. return x.reshape(shape)
  420. def tile(self, x, repeats):
  421. return x.tile(repeats)
  422. def concat(self, tensors, axis: int):
  423. return self.paddle.concat(tensors, axis=axis)
  424. def add_axis(self, x, new_position):
  425. return x.unsqueeze(new_position)
  426. def is_float_type(self, x):
  427. return x.dtype in [self.paddle.float16, self.paddle.float32, self.paddle.float64]
  428. def layers(self):
  429. from .layers import paddle
  430. return paddle
  431. def einsum(self, pattern, *x):
  432. return self.paddle.einsum(pattern, *x)
  433. def shape(self, x):
  434. return tuple(x.shape)
  435. class TinygradBackend(AbstractBackend):
  436. framework_name = "tinygrad"
  437. def __init__(self):
  438. import tinygrad
  439. self.tinygrad = tinygrad
  440. def is_appropriate_type(self, tensor):
  441. return isinstance(tensor, self.tinygrad.Tensor)
  442. def from_numpy(self, x):
  443. return self.tinygrad.Tensor(x)
  444. def to_numpy(self, x):
  445. return x.numpy()
  446. def arange(self, start, stop):
  447. return self.tinygrad.Tensor.arange(start, stop)
  448. def shape(self, x):
  449. return x.shape
  450. def reshape(self, x, shape):
  451. return x.reshape(shape)
  452. def transpose(self, x, axes):
  453. return x.permute(axes)
  454. def reduce(self, x, operation, axes):
  455. for axis in sorted(axes, reverse=True):
  456. x = getattr(x, operation)(axis=axis)
  457. return x
  458. def stack_on_zeroth_dimension(self, tensors: list):
  459. return self.tinygrad.Tensor.stack(tensors)
  460. def add_axis(self, x, new_position):
  461. return x.unsqueeze(new_position)
  462. def tile(self, x, repeats):
  463. return x.repeat(repeats)
  464. def concat(self, tensors, axis: int):
  465. return tensors[0].cat(*tensors[1:], dim=axis) if len(tensors) > 1 else tensors[0]
  466. def is_float_type(self, x):
  467. return self.tinygrad.dtypes.is_float(x.dtype)
  468. def einsum(self, pattern, *x):
  469. return self.tinygrad.Tensor.einsum(pattern, *x)
  470. class PyTensorBackend(AbstractBackend):
  471. framework_name = "pytensor"
  472. def __init__(self):
  473. from pytensor import tensor
  474. self.pt = tensor
  475. def is_appropriate_type(self, tensor):
  476. return isinstance(tensor, self.pt.TensorVariable)
  477. def is_float_type(self, x):
  478. return x.dtype in self.pt.type.float_dtypes
  479. def from_numpy(self, x):
  480. return self.pt.as_tensor(x)
  481. def to_numpy(self, x):
  482. return x.eval() # Will only work if there are no symbolic inputs
  483. def create_symbol(self, shape):
  484. if not isinstance(shape, tuple | list):
  485. shape = (shape,)
  486. return self.pt.tensor(shape=shape)
  487. def eval_symbol(self, symbol, symbol_value_pairs):
  488. return symbol.eval(dict(symbol_value_pairs))
  489. def arange(self, start, stop):
  490. return self.pt.arange(start, stop)
  491. def shape(self, x):
  492. # use the static shape dimensions where known
  493. return tuple(
  494. static_dim if static_dim is not None else symbolic_dim
  495. for static_dim, symbolic_dim in zip(x.type.shape, x.shape)
  496. )
  497. def stack_on_zeroth_dimension(self, tensors: list):
  498. return self.pt.stack(tensors)
  499. def tile(self, x, repeats):
  500. return self.pt.tile(x, repeats)
  501. def concat(self, tensors, axis: int):
  502. return self.pt.concatenate(tensors, axis=axis)
  503. def add_axis(self, x, new_position):
  504. return self.pt.expand_dims(x, new_position)
  505. def einsum(self, pattern, *x):
  506. return self.pt.einsum(pattern, *x)