contract.py 34 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882
  1. """
  2. Contains the primary optimization and contraction routines.
  3. """
  4. from collections import namedtuple
  5. from decimal import Decimal
  6. import numpy as np
  7. from . import backends, blas, helpers, parser, paths, sharing
  8. __all__ = ["contract_path", "contract", "format_const_einsum_str", "ContractExpression", "shape_only"]
  9. class PathInfo(object):
  10. """A printable object to contain information about a contraction path.
  11. Attributes
  12. ----------
  13. naive_cost : int
  14. The estimate FLOP cost of a naive einsum contraction.
  15. opt_cost : int
  16. The estimate FLOP cost of this optimized contraction path.
  17. largest_intermediate : int
  18. The number of elements in the largest intermediate array that will be
  19. produced during the contraction.
  20. """
  21. def __init__(self, contraction_list, input_subscripts, output_subscript, indices, path, scale_list, naive_cost,
  22. opt_cost, size_list, size_dict):
  23. self.contraction_list = contraction_list
  24. self.input_subscripts = input_subscripts
  25. self.output_subscript = output_subscript
  26. self.path = path
  27. self.indices = indices
  28. self.scale_list = scale_list
  29. self.naive_cost = Decimal(naive_cost)
  30. self.opt_cost = Decimal(opt_cost)
  31. self.speedup = self.naive_cost / self.opt_cost
  32. self.size_list = size_list
  33. self.size_dict = size_dict
  34. self.shapes = [tuple(size_dict[k] for k in ks) for ks in input_subscripts.split(',')]
  35. self.eq = "{}->{}".format(input_subscripts, output_subscript)
  36. self.largest_intermediate = Decimal(max(size_list))
  37. def __repr__(self):
  38. # Return the path along with a nice string representation
  39. header = ("scaling", "BLAS", "current", "remaining")
  40. path_print = [
  41. " Complete contraction: {}\n".format(self.eq), " Naive scaling: {}\n".format(len(self.indices)),
  42. " Optimized scaling: {}\n".format(max(self.scale_list)), " Naive FLOP count: {:.3e}\n".format(
  43. self.naive_cost), " Optimized FLOP count: {:.3e}\n".format(self.opt_cost),
  44. " Theoretical speedup: {:.3e}\n".format(self.speedup),
  45. " Largest intermediate: {:.3e} elements\n".format(self.largest_intermediate), "-" * 80 + "\n",
  46. "{:>6} {:>11} {:>22} {:>37}\n".format(*header), "-" * 80
  47. ]
  48. for n, contraction in enumerate(self.contraction_list):
  49. inds, idx_rm, einsum_str, remaining, do_blas = contraction
  50. if remaining is not None:
  51. remaining_str = ",".join(remaining) + "->" + self.output_subscript
  52. else:
  53. remaining_str = "..."
  54. size_remaining = max(0, 56 - max(22, len(einsum_str)))
  55. path_run = (self.scale_list[n], do_blas, einsum_str, remaining_str, size_remaining)
  56. path_print.append("\n{:>4} {:>14} {:>22} {:>{}}".format(*path_run))
  57. return "".join(path_print)
  58. def _choose_memory_arg(memory_limit, size_list):
  59. if memory_limit == 'max_input':
  60. return max(size_list)
  61. if memory_limit is None:
  62. return None
  63. if memory_limit < 1:
  64. if memory_limit == -1:
  65. return None
  66. else:
  67. raise ValueError("Memory limit must be larger than 0, or -1")
  68. return int(memory_limit)
  69. _VALID_CONTRACT_KWARGS = {'optimize', 'path', 'memory_limit', 'einsum_call', 'use_blas', 'shapes'}
  70. def contract_path(*operands, **kwargs):
  71. """
  72. Find a contraction order 'path', without performing the contraction.
  73. Parameters
  74. ----------
  75. subscripts : str
  76. Specifies the subscripts for summation.
  77. *operands : list of array_like
  78. These are the arrays for the operation.
  79. optimize : str, list or bool, optional (default: ``auto``)
  80. Choose the type of path.
  81. - if a list is given uses this as the path.
  82. - ``'optimal'`` An algorithm that explores all possible ways of
  83. contracting the listed tensors. Scales factorially with the number of
  84. terms in the contraction.
  85. - ``'branch-all'`` An algorithm like optimal but that restricts itself
  86. to searching 'likely' paths. Still scales factorially.
  87. - ``'branch-2'`` An even more restricted version of 'branch-all' that
  88. only searches the best two options at each step. Scales exponentially
  89. with the number of terms in the contraction.
  90. - ``'greedy'`` An algorithm that heuristically chooses the best pair
  91. contraction at each step.
  92. - ``'auto'`` Choose the best of the above algorithms whilst aiming to
  93. keep the path finding time below 1ms.
  94. use_blas : bool
  95. Use BLAS functions or not
  96. memory_limit : int, optional (default: None)
  97. Maximum number of elements allowed in intermediate arrays.
  98. shapes : bool, optional
  99. Whether ``contract_path`` should assume arrays (the default) or array
  100. shapes have been supplied.
  101. Returns
  102. -------
  103. path : list of tuples
  104. The einsum path
  105. PathInfo : str
  106. A printable object containing various information about the path found.
  107. Notes
  108. -----
  109. The resulting path indicates which terms of the input contraction should be
  110. contracted first, the result of this contraction is then appended to the end of
  111. the contraction list.
  112. Examples
  113. --------
  114. We can begin with a chain dot example. In this case, it is optimal to
  115. contract the b and c tensors represented by the first element of the path (1,
  116. 2). The resulting tensor is added to the end of the contraction and the
  117. remaining contraction, ``(0, 1)``, is then executed.
  118. >>> a = np.random.rand(2, 2)
  119. >>> b = np.random.rand(2, 5)
  120. >>> c = np.random.rand(5, 2)
  121. >>> path_info = opt_einsum.contract_path('ij,jk,kl->il', a, b, c)
  122. >>> print(path_info[0])
  123. [(1, 2), (0, 1)]
  124. >>> print(path_info[1])
  125. Complete contraction: ij,jk,kl->il
  126. Naive scaling: 4
  127. Optimized scaling: 3
  128. Naive FLOP count: 1.600e+02
  129. Optimized FLOP count: 5.600e+01
  130. Theoretical speedup: 2.857
  131. Largest intermediate: 4.000e+00 elements
  132. -------------------------------------------------------------------------
  133. scaling current remaining
  134. -------------------------------------------------------------------------
  135. 3 kl,jk->jl ij,jl->il
  136. 3 jl,ij->il il->il
  137. A more complex index transformation example.
  138. >>> I = np.random.rand(10, 10, 10, 10)
  139. >>> C = np.random.rand(10, 10)
  140. >>> path_info = oe.contract_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C)
  141. >>> print(path_info[0])
  142. [(0, 2), (0, 3), (0, 2), (0, 1)]
  143. >>> print(path_info[1])
  144. Complete contraction: ea,fb,abcd,gc,hd->efgh
  145. Naive scaling: 8
  146. Optimized scaling: 5
  147. Naive FLOP count: 8.000e+08
  148. Optimized FLOP count: 8.000e+05
  149. Theoretical speedup: 1000.000
  150. Largest intermediate: 1.000e+04 elements
  151. --------------------------------------------------------------------------
  152. scaling current remaining
  153. --------------------------------------------------------------------------
  154. 5 abcd,ea->bcde fb,gc,hd,bcde->efgh
  155. 5 bcde,fb->cdef gc,hd,cdef->efgh
  156. 5 cdef,gc->defg hd,defg->efgh
  157. 5 defg,hd->efgh efgh->efgh
  158. """
  159. # Make sure all keywords are valid
  160. unknown_kwargs = set(kwargs) - _VALID_CONTRACT_KWARGS
  161. if len(unknown_kwargs):
  162. raise TypeError("einsum_path: Did not understand the following kwargs: {}".format(unknown_kwargs))
  163. path_type = kwargs.pop('optimize', 'auto')
  164. memory_limit = kwargs.pop('memory_limit', None)
  165. shapes = kwargs.pop('shapes', False)
  166. # Hidden option, only einsum should call this
  167. einsum_call_arg = kwargs.pop("einsum_call", False)
  168. use_blas = kwargs.pop('use_blas', True)
  169. # Python side parsing
  170. input_subscripts, output_subscript, operands = parser.parse_einsum_input(operands)
  171. # Build a few useful list and sets
  172. input_list = input_subscripts.split(',')
  173. input_sets = [set(x) for x in input_list]
  174. if shapes:
  175. input_shps = operands
  176. else:
  177. input_shps = [x.shape for x in operands]
  178. output_set = set(output_subscript)
  179. indices = set(input_subscripts.replace(',', ''))
  180. # Get length of each unique dimension and ensure all dimensions are correct
  181. size_dict = {}
  182. for tnum, term in enumerate(input_list):
  183. sh = input_shps[tnum]
  184. if len(sh) != len(term):
  185. raise ValueError("Einstein sum subscript '{}' does not contain the "
  186. "correct number of indices for operand {}.".format(input_list[tnum], tnum))
  187. for cnum, char in enumerate(term):
  188. dim = int(sh[cnum])
  189. if char in size_dict:
  190. # For broadcasting cases we always want the largest dim size
  191. if size_dict[char] == 1:
  192. size_dict[char] = dim
  193. elif dim not in (1, size_dict[char]):
  194. raise ValueError("Size of label '{}' for operand {} ({}) does not match previous "
  195. "terms ({}).".format(char, tnum, size_dict[char], dim))
  196. else:
  197. size_dict[char] = dim
  198. # Compute size of each input array plus the output array
  199. size_list = [helpers.compute_size_by_dict(term, size_dict) for term in input_list + [output_subscript]]
  200. memory_arg = _choose_memory_arg(memory_limit, size_list)
  201. num_ops = len(input_list)
  202. # Compute naive cost
  203. # This isnt quite right, need to look into exactly how einsum does this
  204. # indices_in_input = input_subscripts.replace(',', '')
  205. inner_product = (sum(len(x) for x in input_sets) - len(indices)) > 0
  206. naive_cost = helpers.flop_count(indices, inner_product, num_ops, size_dict)
  207. # Compute the path
  208. if not isinstance(path_type, (str, paths.PathOptimizer)):
  209. # Custom path supplied
  210. path = path_type
  211. elif num_ops <= 2:
  212. # Nothing to be optimized
  213. path = [tuple(range(num_ops))]
  214. elif isinstance(path_type, paths.PathOptimizer):
  215. # Custom path optimizer supplied
  216. path = path_type(input_sets, output_set, size_dict, memory_arg)
  217. else:
  218. path_optimizer = paths.get_path_fn(path_type)
  219. path = path_optimizer(input_sets, output_set, size_dict, memory_arg)
  220. cost_list = []
  221. scale_list = []
  222. size_list = []
  223. contraction_list = []
  224. # Build contraction tuple (positions, gemm, einsum_str, remaining)
  225. for cnum, contract_inds in enumerate(path):
  226. # Make sure we remove inds from right to left
  227. contract_inds = tuple(sorted(list(contract_inds), reverse=True))
  228. contract_tuple = helpers.find_contraction(contract_inds, input_sets, output_set)
  229. out_inds, input_sets, idx_removed, idx_contract = contract_tuple
  230. # Compute cost, scale, and size
  231. cost = helpers.flop_count(idx_contract, idx_removed, len(contract_inds), size_dict)
  232. cost_list.append(cost)
  233. scale_list.append(len(idx_contract))
  234. size_list.append(helpers.compute_size_by_dict(out_inds, size_dict))
  235. tmp_inputs = [input_list.pop(x) for x in contract_inds]
  236. tmp_shapes = [input_shps.pop(x) for x in contract_inds]
  237. if use_blas:
  238. do_blas = blas.can_blas(tmp_inputs, out_inds, idx_removed, tmp_shapes)
  239. else:
  240. do_blas = False
  241. # Last contraction
  242. if (cnum - len(path)) == -1:
  243. idx_result = output_subscript
  244. else:
  245. # use tensordot order to minimize transpositions
  246. all_input_inds = "".join(tmp_inputs)
  247. idx_result = "".join(sorted(out_inds, key=all_input_inds.find))
  248. shp_result = parser.find_output_shape(tmp_inputs, tmp_shapes, idx_result)
  249. input_list.append(idx_result)
  250. input_shps.append(shp_result)
  251. einsum_str = ",".join(tmp_inputs) + "->" + idx_result
  252. # for large expressions saving the remaining terms at each step can
  253. # incur a large memory footprint - and also be messy to print
  254. if len(input_list) <= 20:
  255. remaining = tuple(input_list)
  256. else:
  257. remaining = None
  258. contraction = (contract_inds, idx_removed, einsum_str, remaining, do_blas)
  259. contraction_list.append(contraction)
  260. opt_cost = sum(cost_list)
  261. if einsum_call_arg:
  262. return operands, contraction_list
  263. path_print = PathInfo(contraction_list, input_subscripts, output_subscript, indices, path, scale_list, naive_cost,
  264. opt_cost, size_list, size_dict)
  265. return path, path_print
  266. @sharing.einsum_cache_wrap
  267. def _einsum(*operands, **kwargs):
  268. """Base einsum, but with pre-parse for valid characters if a string is given.
  269. """
  270. fn = backends.get_func('einsum', kwargs.pop('backend', 'numpy'))
  271. if not isinstance(operands[0], str):
  272. return fn(*operands, **kwargs)
  273. einsum_str, operands = operands[0], operands[1:]
  274. # Do we need to temporarily map indices into [a-z,A-Z] range?
  275. if not parser.has_valid_einsum_chars_only(einsum_str):
  276. # Explicitly find output str first so as to maintain order
  277. if '->' not in einsum_str:
  278. einsum_str += '->' + parser.find_output_str(einsum_str)
  279. einsum_str = parser.convert_to_valid_einsum_chars(einsum_str)
  280. return fn(einsum_str, *operands, **kwargs)
  281. def _default_transpose(x, axes):
  282. # most libraries implement a method version
  283. return x.transpose(axes)
  284. @sharing.transpose_cache_wrap
  285. def _transpose(x, axes, backend='numpy'):
  286. """Base transpose.
  287. """
  288. fn = backends.get_func('transpose', backend, _default_transpose)
  289. return fn(x, axes)
  290. @sharing.tensordot_cache_wrap
  291. def _tensordot(x, y, axes, backend='numpy'):
  292. """Base tensordot.
  293. """
  294. fn = backends.get_func('tensordot', backend)
  295. return fn(x, y, axes=axes)
  296. # Rewrite einsum to handle different cases
  297. def contract(*operands, **kwargs):
  298. """
  299. contract(subscripts, *operands, out=None, dtype=None, order='K', casting='safe', use_blas=True, optimize=True, memory_limit=None, backend='numpy')
  300. Evaluates the Einstein summation convention on the operands. A drop in
  301. replacement for NumPy's einsum function that optimizes the order of contraction
  302. to reduce overall scaling at the cost of several intermediate arrays.
  303. Parameters
  304. ----------
  305. subscripts : str
  306. Specifies the subscripts for summation.
  307. *operands : list of array_like
  308. These are the arrays for the operation.
  309. out : array_like
  310. A output array in which set the resulting output.
  311. dtype : str
  312. The dtype of the given contraction, see np.einsum.
  313. order : str
  314. The order of the resulting contraction, see np.einsum.
  315. casting : str
  316. The casting procedure for operations of different dtype, see np.einsum.
  317. use_blas : bool
  318. Do you use BLAS for valid operations, may use extra memory for more intermediates.
  319. optimize : str, list or bool, optional (default: ``auto``)
  320. Choose the type of path.
  321. - if a list is given uses this as the path.
  322. - ``'optimal'`` An algorithm that explores all possible ways of
  323. contracting the listed tensors. Scales factorially with the number of
  324. terms in the contraction.
  325. - ``'dp'`` A faster (but essentially optimal) algorithm that uses
  326. dynamic programming to exhaustively search all contraction paths
  327. without outer-products.
  328. - ``'greedy'`` An cheap algorithm that heuristically chooses the best
  329. pairwise contraction at each step. Scales linearly in the number of
  330. terms in the contraction.
  331. - ``'random-greedy'`` Run a randomized version of the greedy algorithm
  332. 32 times and pick the best path.
  333. - ``'random-greedy-128'`` Run a randomized version of the greedy
  334. algorithm 128 times and pick the best path.
  335. - ``'branch-all'`` An algorithm like optimal but that restricts itself
  336. to searching 'likely' paths. Still scales factorially.
  337. - ``'branch-2'`` An even more restricted version of 'branch-all' that
  338. only searches the best two options at each step. Scales exponentially
  339. with the number of terms in the contraction.
  340. - ``'auto'`` Choose the best of the above algorithms whilst aiming to
  341. keep the path finding time below 1ms.
  342. - ``'auto-hq'`` Aim for a high quality contraction, choosing the best
  343. of the above algorithms whilst aiming to keep the path finding time
  344. below 1sec.
  345. memory_limit : {None, int, 'max_input'} (default: None)
  346. Give the upper bound of the largest intermediate tensor contract will build.
  347. - None or -1 means there is no limit
  348. - 'max_input' means the limit is set as largest input tensor
  349. - a positive integer is taken as an explicit limit on the number of elements
  350. The default is None. Note that imposing a limit can make contractions
  351. exponentially slower to perform.
  352. backend : str, optional (default: ``auto``)
  353. Which library to use to perform the required ``tensordot``, ``transpose``
  354. and ``einsum`` calls. Should match the types of arrays supplied, See
  355. :func:`contract_expression` for generating expressions which convert
  356. numpy arrays to and from the backend library automatically.
  357. Returns
  358. -------
  359. out : array_like
  360. The result of the einsum expression.
  361. Notes
  362. -----
  363. This function should produce a result identical to that of NumPy's einsum
  364. function. The primary difference is ``contract`` will attempt to form
  365. intermediates which reduce the overall scaling of the given einsum contraction.
  366. By default the worst intermediate formed will be equal to that of the largest
  367. input array. For large einsum expressions with many input arrays this can
  368. provide arbitrarily large (1000 fold+) speed improvements.
  369. For contractions with just two tensors this function will attempt to use
  370. NumPy's built-in BLAS functionality to ensure that the given operation is
  371. preformed optimally. When NumPy is linked to a threaded BLAS, potential
  372. speedups are on the order of 20-100 for a six core machine.
  373. Examples
  374. --------
  375. See :func:`opt_einsum.contract_path` or :func:`numpy.einsum`
  376. """
  377. optimize_arg = kwargs.pop('optimize', True)
  378. if optimize_arg is True:
  379. optimize_arg = 'auto'
  380. valid_einsum_kwargs = ['out', 'dtype', 'order', 'casting']
  381. einsum_kwargs = {k: v for (k, v) in kwargs.items() if k in valid_einsum_kwargs}
  382. # If no optimization, run pure einsum
  383. if optimize_arg is False:
  384. return _einsum(*operands, **einsum_kwargs)
  385. # Grab non-einsum kwargs
  386. use_blas = kwargs.pop('use_blas', True)
  387. memory_limit = kwargs.pop('memory_limit', None)
  388. backend = kwargs.pop('backend', 'auto')
  389. gen_expression = kwargs.pop('_gen_expression', False)
  390. constants_dict = kwargs.pop('_constants_dict', {})
  391. # Make sure remaining keywords are valid for einsum
  392. unknown_kwargs = [k for (k, v) in kwargs.items() if k not in valid_einsum_kwargs]
  393. if len(unknown_kwargs):
  394. raise TypeError("Did not understand the following kwargs: {}".format(unknown_kwargs))
  395. if gen_expression:
  396. full_str = operands[0]
  397. # Build the contraction list and operand
  398. operands, contraction_list = contract_path(*operands,
  399. optimize=optimize_arg,
  400. memory_limit=memory_limit,
  401. einsum_call=True,
  402. use_blas=use_blas)
  403. # check if performing contraction or just building expression
  404. if gen_expression:
  405. return ContractExpression(full_str, contraction_list, constants_dict, **einsum_kwargs)
  406. return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)
  407. def infer_backend(x):
  408. return x.__class__.__module__.split('.')[0]
  409. def parse_backend(arrays, backend):
  410. """Find out what backend we should use, dipatching based on the first
  411. array if ``backend='auto'`` is specified.
  412. """
  413. if backend != 'auto':
  414. return backend
  415. backend = infer_backend(arrays[0])
  416. # some arrays will be defined in modules that don't implement tensordot
  417. # etc. so instead default to numpy
  418. if not backends.has_tensordot(backend):
  419. return 'numpy'
  420. return backend
  421. def _core_contract(operands, contraction_list, backend='auto', evaluate_constants=False, **einsum_kwargs):
  422. """Inner loop used to perform an actual contraction given the output
  423. from a ``contract_path(..., einsum_call=True)`` call.
  424. """
  425. # Special handling if out is specified
  426. out_array = einsum_kwargs.pop('out', None)
  427. specified_out = out_array is not None
  428. backend = parse_backend(operands, backend)
  429. # try and do as much as possible without einsum if not available
  430. no_einsum = not backends.has_einsum(backend)
  431. # Start contraction loop
  432. for num, contraction in enumerate(contraction_list):
  433. inds, idx_rm, einsum_str, _, blas_flag = contraction
  434. # check if we are performing the pre-pass of an expression with constants,
  435. # if so, break out upon finding first non-constant (None) operand
  436. if evaluate_constants and any(operands[x] is None for x in inds):
  437. return operands, contraction_list[num:]
  438. tmp_operands = [operands.pop(x) for x in inds]
  439. # Do we need to deal with the output?
  440. handle_out = specified_out and ((num + 1) == len(contraction_list))
  441. # Call tensordot (check if should prefer einsum, but only if available)
  442. if blas_flag and ('EINSUM' not in blas_flag or no_einsum):
  443. # Checks have already been handled
  444. input_str, results_index = einsum_str.split('->')
  445. input_left, input_right = input_str.split(',')
  446. tensor_result = "".join(s for s in input_left + input_right if s not in idx_rm)
  447. # Find indices to contract over
  448. left_pos, right_pos = [], []
  449. for s in idx_rm:
  450. left_pos.append(input_left.find(s))
  451. right_pos.append(input_right.find(s))
  452. # Contract!
  453. new_view = _tensordot(*tmp_operands, axes=(tuple(left_pos), tuple(right_pos)), backend=backend)
  454. # Build a new view if needed
  455. if (tensor_result != results_index) or handle_out:
  456. transpose = tuple(map(tensor_result.index, results_index))
  457. new_view = _transpose(new_view, axes=transpose, backend=backend)
  458. if handle_out:
  459. out_array[:] = new_view
  460. # Call einsum
  461. else:
  462. # If out was specified
  463. if handle_out:
  464. einsum_kwargs["out"] = out_array
  465. # Do the contraction
  466. new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)
  467. # Append new items and dereference what we can
  468. operands.append(new_view)
  469. del tmp_operands, new_view
  470. if specified_out:
  471. return out_array
  472. else:
  473. return operands[0]
  474. def format_const_einsum_str(einsum_str, constants):
  475. """Add brackets to the constant terms in ``einsum_str``. For example:
  476. >>> format_const_einsum_str('ab,bc,cd->ad', [0, 2])
  477. 'bc,[ab,cd]->ad'
  478. No-op if there are no constants.
  479. """
  480. if not constants:
  481. return einsum_str
  482. if "->" in einsum_str:
  483. lhs, rhs = einsum_str.split('->')
  484. arrow = "->"
  485. else:
  486. lhs, rhs, arrow = einsum_str, "", ""
  487. wrapped_terms = ["[{}]".format(t) if i in constants else t for i, t in enumerate(lhs.split(','))]
  488. formatted_einsum_str = "{}{}{}".format(','.join(wrapped_terms), arrow, rhs)
  489. # merge adjacent constants
  490. formatted_einsum_str = formatted_einsum_str.replace("],[", ',')
  491. return formatted_einsum_str
  492. class ContractExpression:
  493. """Helper class for storing an explicit ``contraction_list`` which can
  494. then be repeatedly called solely with the array arguments.
  495. """
  496. def __init__(self, contraction, contraction_list, constants_dict, **einsum_kwargs):
  497. self.contraction_list = contraction_list
  498. self.einsum_kwargs = einsum_kwargs
  499. self.contraction = format_const_einsum_str(contraction, constants_dict.keys())
  500. # need to know _full_num_args to parse constants with, and num_args to call with
  501. self._full_num_args = contraction.count(',') + 1
  502. self.num_args = self._full_num_args - len(constants_dict)
  503. # likewise need to know full contraction list
  504. self._full_contraction_list = contraction_list
  505. self._constants_dict = constants_dict
  506. self._evaluated_constants = {}
  507. self._backend_expressions = {}
  508. def evaluate_constants(self, backend='auto'):
  509. """Convert any constant operands to the correct backend form, and
  510. perform as many contractions as possible to create a new list of
  511. operands, stored in ``self._evaluated_constants[backend]``. This also
  512. makes sure ``self.contraction_list`` only contains the remaining,
  513. non-const operations.
  514. """
  515. # prepare a list of operands, with `None` for non-consts
  516. tmp_const_ops = [self._constants_dict.get(i, None) for i in range(self._full_num_args)]
  517. backend = parse_backend(tmp_const_ops, backend)
  518. # get the new list of operands with constant operations performed, and remaining contractions
  519. try:
  520. new_ops, new_contraction_list = backends.evaluate_constants(backend, tmp_const_ops, self)
  521. except KeyError:
  522. new_ops, new_contraction_list = self(*tmp_const_ops, backend=backend, evaluate_constants=True)
  523. self._evaluated_constants[backend] = new_ops
  524. self.contraction_list = new_contraction_list
  525. def _get_evaluated_constants(self, backend):
  526. """Retrieve or generate the cached list of constant operators (mixed
  527. in with None representing non-consts) and the remaining contraction
  528. list.
  529. """
  530. try:
  531. return self._evaluated_constants[backend]
  532. except KeyError:
  533. self.evaluate_constants(backend)
  534. return self._evaluated_constants[backend]
  535. def _get_backend_expression(self, arrays, backend):
  536. try:
  537. return self._backend_expressions[backend]
  538. except KeyError:
  539. fn = backends.build_expression(backend, arrays, self)
  540. self._backend_expressions[backend] = fn
  541. return fn
  542. def _contract(self, arrays, out=None, backend='auto', evaluate_constants=False):
  543. """The normal, core contraction.
  544. """
  545. contraction_list = self._full_contraction_list if evaluate_constants else self.contraction_list
  546. return _core_contract(list(arrays),
  547. contraction_list,
  548. out=out,
  549. backend=backend,
  550. evaluate_constants=evaluate_constants,
  551. **self.einsum_kwargs)
  552. def _contract_with_conversion(self, arrays, out, backend, evaluate_constants=False):
  553. """Special contraction, i.e., contraction with a different backend
  554. but converting to and from that backend. Retrieves or generates a
  555. cached expression using ``arrays`` as templates, then calls it
  556. with ``arrays``.
  557. If ``evaluate_constants=True``, perform a partial contraction that
  558. prepares the constant tensors and operations with the right backend.
  559. """
  560. # convert consts to correct type & find reduced contraction list
  561. if evaluate_constants:
  562. return backends.evaluate_constants(backend, arrays, self)
  563. result = self._get_backend_expression(arrays, backend)(*arrays)
  564. if out is not None:
  565. out[()] = result
  566. return out
  567. return result
  568. def __call__(self, *arrays, **kwargs):
  569. """Evaluate this expression with a set of arrays.
  570. Parameters
  571. ----------
  572. arrays : seq of array
  573. The arrays to supply as input to the expression.
  574. out : array, optional (default: ``None``)
  575. If specified, output the result into this array.
  576. backend : str, optional (default: ``numpy``)
  577. Perform the contraction with this backend library. If numpy arrays
  578. are supplied then try to convert them to and from the correct
  579. backend array type.
  580. """
  581. out = kwargs.pop('out', None)
  582. backend = kwargs.pop('backend', 'auto')
  583. backend = parse_backend(arrays, backend)
  584. evaluate_constants = kwargs.pop('evaluate_constants', False)
  585. if kwargs:
  586. raise ValueError("The only valid keyword arguments to a `ContractExpression` "
  587. "call are `out=` or `backend=`. Got: {}.".format(kwargs))
  588. correct_num_args = self._full_num_args if evaluate_constants else self.num_args
  589. if len(arrays) != correct_num_args:
  590. raise ValueError("This `ContractExpression` takes exactly {} array arguments "
  591. "but received {}.".format(self.num_args, len(arrays)))
  592. if self._constants_dict and not evaluate_constants:
  593. # fill in the missing non-constant terms with newly supplied arrays
  594. ops_var, ops_const = iter(arrays), self._get_evaluated_constants(backend)
  595. ops = [next(ops_var) if op is None else op for op in ops_const]
  596. else:
  597. ops = arrays
  598. try:
  599. # Check if the backend requires special preparation / calling
  600. # but also ignore non-numpy arrays -> assume user wants same type back
  601. if backends.has_backend(backend) and all(isinstance(x, np.ndarray) for x in arrays):
  602. return self._contract_with_conversion(ops, out, backend, evaluate_constants=evaluate_constants)
  603. return self._contract(ops, out, backend, evaluate_constants=evaluate_constants)
  604. except ValueError as err:
  605. original_msg = str(err.args) if err.args else ""
  606. msg = ("Internal error while evaluating `ContractExpression`. Note that few checks are performed"
  607. " - the number and rank of the array arguments must match the original expression. "
  608. "The internal error was: '{}'".format(original_msg), )
  609. err.args = msg
  610. raise
  611. def __repr__(self):
  612. if self._constants_dict:
  613. constants_repr = ", constants={}".format(sorted(self._constants_dict))
  614. else:
  615. constants_repr = ""
  616. return "<ContractExpression('{}'{})>".format(self.contraction, constants_repr)
  617. def __str__(self):
  618. s = [self.__repr__()]
  619. for i, c in enumerate(self.contraction_list):
  620. s.append("\n {}. ".format(i + 1))
  621. s.append("'{}'".format(c[2]) + (" [{}]".format(c[-1]) if c[-1] else ""))
  622. if self.einsum_kwargs:
  623. s.append("\neinsum_kwargs={}".format(self.einsum_kwargs))
  624. return "".join(s)
  625. Shaped = namedtuple('Shaped', ['shape'])
  626. def shape_only(shape):
  627. """Dummy ``numpy.ndarray`` which has a shape only - for generating
  628. contract expressions.
  629. """
  630. return Shaped(shape)
  631. def contract_expression(subscripts, *shapes, **kwargs):
  632. """Generate a reusable expression for a given contraction with
  633. specific shapes, which can, for example, be cached.
  634. Parameters
  635. ----------
  636. subscripts : str
  637. Specifies the subscripts for summation.
  638. shapes : sequence of integer tuples
  639. Shapes of the arrays to optimize the contraction for.
  640. constants : sequence of int, optional
  641. The indices of any constant arguments in ``shapes``, in which case the
  642. actual array should be supplied at that position rather than just a
  643. shape. If these are specified, then constant parts of the contraction
  644. between calls will be reused. Additionally, if a GPU-enabled backend is
  645. used for example, then the constant tensors will be kept on the GPU,
  646. minimizing transfers.
  647. kwargs :
  648. Passed on to ``contract_path`` or ``einsum``. See ``contract``.
  649. Returns
  650. -------
  651. expr : ContractExpression
  652. Callable with signature ``expr(*arrays, out=None, backend='numpy')``
  653. where the array's shapes should match ``shapes``.
  654. Notes
  655. -----
  656. - The `out` keyword argument should be supplied to the generated expression
  657. rather than this function.
  658. - The `backend` keyword argument should also be supplied to the generated
  659. expression. If numpy arrays are supplied, if possible they will be
  660. converted to and back from the correct backend array type.
  661. - The generated expression will work with any arrays which have
  662. the same rank (number of dimensions) as the original shapes, however, if
  663. the actual sizes are different, the expression may no longer be optimal.
  664. - Constant operations will be computed upon the first call with a particular
  665. backend, then subsequently reused.
  666. Examples
  667. --------
  668. Basic usage:
  669. >>> expr = contract_expression("ab,bc->ac", (3, 4), (4, 5))
  670. >>> a, b = np.random.rand(3, 4), np.random.rand(4, 5)
  671. >>> c = expr(a, b)
  672. >>> np.allclose(c, a @ b)
  673. True
  674. Supply ``a`` as a constant:
  675. >>> expr = contract_expression("ab,bc->ac", a, (4, 5), constants=[0])
  676. >>> expr
  677. <ContractExpression('[ab],bc->ac', constants=[0])>
  678. >>> c = expr(b)
  679. >>> np.allclose(c, a @ b)
  680. True
  681. """
  682. if not kwargs.get('optimize', True):
  683. raise ValueError("Can only generate expressions for optimized contractions.")
  684. for arg in ('out', 'backend'):
  685. if kwargs.get(arg, None) is not None:
  686. raise ValueError("'{}' should only be specified when calling a "
  687. "`ContractExpression`, not when building it.".format(arg))
  688. if not isinstance(subscripts, str):
  689. subscripts, shapes = parser.convert_interleaved_input((subscripts, ) + shapes)
  690. kwargs['_gen_expression'] = True
  691. # build dict of constant indices mapped to arrays
  692. constants = kwargs.pop('constants', ())
  693. constants_dict = {i: shapes[i] for i in constants}
  694. kwargs['_constants_dict'] = constants_dict
  695. # apart from constant arguments, make dummy arrays
  696. dummy_arrays = [s if i in constants else shape_only(s) for i, s in enumerate(shapes)]
  697. return contract(subscripts, *dummy_arrays, **kwargs)