einsumfunc.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650
  1. """
  2. Implementation of optimized einsum.
  3. """
  4. import functools
  5. import itertools
  6. import operator
  7. from numpy._core.multiarray import c_einsum, matmul
  8. from numpy._core.numeric import asanyarray, reshape
  9. from numpy._core.overrides import array_function_dispatch
  10. from numpy._core.umath import multiply
  11. __all__ = ['einsum', 'einsum_path']
  12. # importing string for string.ascii_letters would be too slow
  13. # the first import before caching has been measured to take 800 µs (#23777)
  14. # symbols begin with uppercase to mimic ASCII ordering and avoid sorting issues
  15. einsum_symbols = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
  16. einsum_symbols_set = set(einsum_symbols)
  17. def _flop_count(idx_contraction, inner, num_terms, size_dictionary):
  18. """
  19. Computes the number of FLOPS in the contraction.
  20. Parameters
  21. ----------
  22. idx_contraction : iterable
  23. The indices involved in the contraction
  24. inner : bool
  25. Does this contraction require an inner product?
  26. num_terms : int
  27. The number of terms in a contraction
  28. size_dictionary : dict
  29. The size of each of the indices in idx_contraction
  30. Returns
  31. -------
  32. flop_count : int
  33. The total number of FLOPS required for the contraction.
  34. Examples
  35. --------
  36. >>> _flop_count('abc', False, 1, {'a': 2, 'b':3, 'c':5})
  37. 30
  38. >>> _flop_count('abc', True, 2, {'a': 2, 'b':3, 'c':5})
  39. 60
  40. """
  41. overall_size = _compute_size_by_dict(idx_contraction, size_dictionary)
  42. op_factor = max(1, num_terms - 1)
  43. if inner:
  44. op_factor += 1
  45. return overall_size * op_factor
  46. def _compute_size_by_dict(indices, idx_dict):
  47. """
  48. Computes the product of the elements in indices based on the dictionary
  49. idx_dict.
  50. Parameters
  51. ----------
  52. indices : iterable
  53. Indices to base the product on.
  54. idx_dict : dictionary
  55. Dictionary of index sizes
  56. Returns
  57. -------
  58. ret : int
  59. The resulting product.
  60. Examples
  61. --------
  62. >>> _compute_size_by_dict('abbc', {'a': 2, 'b':3, 'c':5})
  63. 90
  64. """
  65. ret = 1
  66. for i in indices:
  67. ret *= idx_dict[i]
  68. return ret
  69. def _find_contraction(positions, input_sets, output_set):
  70. """
  71. Finds the contraction for a given set of input and output sets.
  72. Parameters
  73. ----------
  74. positions : iterable
  75. Integer positions of terms used in the contraction.
  76. input_sets : list
  77. List of sets that represent the lhs side of the einsum subscript
  78. output_set : set
  79. Set that represents the rhs side of the overall einsum subscript
  80. Returns
  81. -------
  82. new_result : set
  83. The indices of the resulting contraction
  84. remaining : list
  85. List of sets that have not been contracted, the new set is appended to
  86. the end of this list
  87. idx_removed : set
  88. Indices removed from the entire contraction
  89. idx_contraction : set
  90. The indices used in the current contraction
  91. Examples
  92. --------
  93. # A simple dot product test case
  94. >>> pos = (0, 1)
  95. >>> isets = [set('ab'), set('bc')]
  96. >>> oset = set('ac')
  97. >>> _find_contraction(pos, isets, oset)
  98. ({'a', 'c'}, [{'a', 'c'}], {'b'}, {'a', 'b', 'c'})
  99. # A more complex case with additional terms in the contraction
  100. >>> pos = (0, 2)
  101. >>> isets = [set('abd'), set('ac'), set('bdc')]
  102. >>> oset = set('ac')
  103. >>> _find_contraction(pos, isets, oset)
  104. ({'a', 'c'}, [{'a', 'c'}, {'a', 'c'}], {'b', 'd'}, {'a', 'b', 'c', 'd'})
  105. """
  106. idx_contract = set()
  107. idx_remain = output_set.copy()
  108. remaining = []
  109. for ind, value in enumerate(input_sets):
  110. if ind in positions:
  111. idx_contract |= value
  112. else:
  113. remaining.append(value)
  114. idx_remain |= value
  115. new_result = idx_remain & idx_contract
  116. idx_removed = (idx_contract - new_result)
  117. remaining.append(new_result)
  118. return (new_result, remaining, idx_removed, idx_contract)
  119. def _optimal_path(input_sets, output_set, idx_dict, memory_limit):
  120. """
  121. Computes all possible pair contractions, sieves the results based
  122. on ``memory_limit`` and returns the lowest cost path. This algorithm
  123. scales factorial with respect to the elements in the list ``input_sets``.
  124. Parameters
  125. ----------
  126. input_sets : list
  127. List of sets that represent the lhs side of the einsum subscript
  128. output_set : set
  129. Set that represents the rhs side of the overall einsum subscript
  130. idx_dict : dictionary
  131. Dictionary of index sizes
  132. memory_limit : int
  133. The maximum number of elements in a temporary array
  134. Returns
  135. -------
  136. path : list
  137. The optimal contraction order within the memory limit constraint.
  138. Examples
  139. --------
  140. >>> isets = [set('abd'), set('ac'), set('bdc')]
  141. >>> oset = set()
  142. >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
  143. >>> _optimal_path(isets, oset, idx_sizes, 5000)
  144. [(0, 2), (0, 1)]
  145. """
  146. full_results = [(0, [], input_sets)]
  147. for iteration in range(len(input_sets) - 1):
  148. iter_results = []
  149. # Compute all unique pairs
  150. for curr in full_results:
  151. cost, positions, remaining = curr
  152. for con in itertools.combinations(
  153. range(len(input_sets) - iteration), 2
  154. ):
  155. # Find the contraction
  156. cont = _find_contraction(con, remaining, output_set)
  157. new_result, new_input_sets, idx_removed, idx_contract = cont
  158. # Sieve the results based on memory_limit
  159. new_size = _compute_size_by_dict(new_result, idx_dict)
  160. if new_size > memory_limit:
  161. continue
  162. # Build (total_cost, positions, indices_remaining)
  163. total_cost = cost + _flop_count(
  164. idx_contract, idx_removed, len(con), idx_dict
  165. )
  166. new_pos = positions + [con]
  167. iter_results.append((total_cost, new_pos, new_input_sets))
  168. # Update combinatorial list, if we did not find anything return best
  169. # path + remaining contractions
  170. if iter_results:
  171. full_results = iter_results
  172. else:
  173. path = min(full_results, key=lambda x: x[0])[1]
  174. path += [tuple(range(len(input_sets) - iteration))]
  175. return path
  176. # If we have not found anything return single einsum contraction
  177. if len(full_results) == 0:
  178. return [tuple(range(len(input_sets)))]
  179. path = min(full_results, key=lambda x: x[0])[1]
  180. return path
  181. def _parse_possible_contraction(
  182. positions, input_sets, output_set, idx_dict,
  183. memory_limit, path_cost, naive_cost
  184. ):
  185. """Compute the cost (removed size + flops) and resultant indices for
  186. performing the contraction specified by ``positions``.
  187. Parameters
  188. ----------
  189. positions : tuple of int
  190. The locations of the proposed tensors to contract.
  191. input_sets : list of sets
  192. The indices found on each tensors.
  193. output_set : set
  194. The output indices of the expression.
  195. idx_dict : dict
  196. Mapping of each index to its size.
  197. memory_limit : int
  198. The total allowed size for an intermediary tensor.
  199. path_cost : int
  200. The contraction cost so far.
  201. naive_cost : int
  202. The cost of the unoptimized expression.
  203. Returns
  204. -------
  205. cost : (int, int)
  206. A tuple containing the size of any indices removed, and the flop cost.
  207. positions : tuple of int
  208. The locations of the proposed tensors to contract.
  209. new_input_sets : list of sets
  210. The resulting new list of indices if this proposed contraction
  211. is performed.
  212. """
  213. # Find the contraction
  214. contract = _find_contraction(positions, input_sets, output_set)
  215. idx_result, new_input_sets, idx_removed, idx_contract = contract
  216. # Sieve the results based on memory_limit
  217. new_size = _compute_size_by_dict(idx_result, idx_dict)
  218. if new_size > memory_limit:
  219. return None
  220. # Build sort tuple
  221. old_sizes = (
  222. _compute_size_by_dict(input_sets[p], idx_dict) for p in positions
  223. )
  224. removed_size = sum(old_sizes) - new_size
  225. # NB: removed_size used to be just the size of any removed indices i.e.:
  226. # helpers.compute_size_by_dict(idx_removed, idx_dict)
  227. cost = _flop_count(idx_contract, idx_removed, len(positions), idx_dict)
  228. sort = (-removed_size, cost)
  229. # Sieve based on total cost as well
  230. if (path_cost + cost) > naive_cost:
  231. return None
  232. # Add contraction to possible choices
  233. return [sort, positions, new_input_sets]
  234. def _update_other_results(results, best):
  235. """Update the positions and provisional input_sets of ``results``
  236. based on performing the contraction result ``best``. Remove any
  237. involving the tensors contracted.
  238. Parameters
  239. ----------
  240. results : list
  241. List of contraction results produced by
  242. ``_parse_possible_contraction``.
  243. best : list
  244. The best contraction of ``results`` i.e. the one that
  245. will be performed.
  246. Returns
  247. -------
  248. mod_results : list
  249. The list of modified results, updated with outcome of
  250. ``best`` contraction.
  251. """
  252. best_con = best[1]
  253. bx, by = best_con
  254. mod_results = []
  255. for cost, (x, y), con_sets in results:
  256. # Ignore results involving tensors just contracted
  257. if x in best_con or y in best_con:
  258. continue
  259. # Update the input_sets
  260. del con_sets[by - int(by > x) - int(by > y)]
  261. del con_sets[bx - int(bx > x) - int(bx > y)]
  262. con_sets.insert(-1, best[2][-1])
  263. # Update the position indices
  264. mod_con = x - int(x > bx) - int(x > by), y - int(y > bx) - int(y > by)
  265. mod_results.append((cost, mod_con, con_sets))
  266. return mod_results
  267. def _greedy_path(input_sets, output_set, idx_dict, memory_limit):
  268. """
  269. Finds the path by contracting the best pair until the input list is
  270. exhausted. The best pair is found by minimizing the tuple
  271. ``(-prod(indices_removed), cost)``. What this amounts to is prioritizing
  272. matrix multiplication or inner product operations, then Hadamard like
  273. operations, and finally outer operations. Outer products are limited by
  274. ``memory_limit``. This algorithm scales cubically with respect to the
  275. number of elements in the list ``input_sets``.
  276. Parameters
  277. ----------
  278. input_sets : list
  279. List of sets that represent the lhs side of the einsum subscript
  280. output_set : set
  281. Set that represents the rhs side of the overall einsum subscript
  282. idx_dict : dictionary
  283. Dictionary of index sizes
  284. memory_limit : int
  285. The maximum number of elements in a temporary array
  286. Returns
  287. -------
  288. path : list
  289. The greedy contraction order within the memory limit constraint.
  290. Examples
  291. --------
  292. >>> isets = [set('abd'), set('ac'), set('bdc')]
  293. >>> oset = set()
  294. >>> idx_sizes = {'a': 1, 'b':2, 'c':3, 'd':4}
  295. >>> _greedy_path(isets, oset, idx_sizes, 5000)
  296. [(0, 2), (0, 1)]
  297. """
  298. # Handle trivial cases that leaked through
  299. if len(input_sets) == 1:
  300. return [(0,)]
  301. elif len(input_sets) == 2:
  302. return [(0, 1)]
  303. # Build up a naive cost
  304. contract = _find_contraction(
  305. range(len(input_sets)), input_sets, output_set
  306. )
  307. idx_result, new_input_sets, idx_removed, idx_contract = contract
  308. naive_cost = _flop_count(
  309. idx_contract, idx_removed, len(input_sets), idx_dict
  310. )
  311. # Initially iterate over all pairs
  312. comb_iter = itertools.combinations(range(len(input_sets)), 2)
  313. known_contractions = []
  314. path_cost = 0
  315. path = []
  316. for iteration in range(len(input_sets) - 1):
  317. # Iterate over all pairs on the first step, only previously
  318. # found pairs on subsequent steps
  319. for positions in comb_iter:
  320. # Always initially ignore outer products
  321. if input_sets[positions[0]].isdisjoint(input_sets[positions[1]]):
  322. continue
  323. result = _parse_possible_contraction(
  324. positions, input_sets, output_set, idx_dict,
  325. memory_limit, path_cost, naive_cost
  326. )
  327. if result is not None:
  328. known_contractions.append(result)
  329. # If we do not have a inner contraction, rescan pairs
  330. # including outer products
  331. if len(known_contractions) == 0:
  332. # Then check the outer products
  333. for positions in itertools.combinations(
  334. range(len(input_sets)), 2
  335. ):
  336. result = _parse_possible_contraction(
  337. positions, input_sets, output_set, idx_dict,
  338. memory_limit, path_cost, naive_cost
  339. )
  340. if result is not None:
  341. known_contractions.append(result)
  342. # If we still did not find any remaining contractions,
  343. # default back to einsum like behavior
  344. if len(known_contractions) == 0:
  345. path.append(tuple(range(len(input_sets))))
  346. break
  347. # Sort based on first index
  348. best = min(known_contractions, key=lambda x: x[0])
  349. # Now propagate as many unused contractions as possible
  350. # to the next iteration
  351. known_contractions = _update_other_results(known_contractions, best)
  352. # Next iteration only compute contractions with the new tensor
  353. # All other contractions have been accounted for
  354. input_sets = best[2]
  355. new_tensor_pos = len(input_sets) - 1
  356. comb_iter = ((i, new_tensor_pos) for i in range(new_tensor_pos))
  357. # Update path and total cost
  358. path.append(best[1])
  359. path_cost += best[0][1]
  360. return path
  361. def _parse_einsum_input(operands):
  362. """
  363. A reproduction of einsum c side einsum parsing in python.
  364. Returns
  365. -------
  366. input_strings : str
  367. Parsed input strings
  368. output_string : str
  369. Parsed output string
  370. operands : list of array_like
  371. The operands to use in the numpy contraction
  372. Examples
  373. --------
  374. The operand list is simplified to reduce printing:
  375. >>> np.random.seed(123)
  376. >>> a = np.random.rand(4, 4)
  377. >>> b = np.random.rand(4, 4, 4)
  378. >>> _parse_einsum_input(('...a,...a->...', a, b))
  379. ('za,xza', 'xz', [a, b]) # may vary
  380. >>> _parse_einsum_input((a, [Ellipsis, 0], b, [Ellipsis, 0]))
  381. ('za,xza', 'xz', [a, b]) # may vary
  382. """
  383. if len(operands) == 0:
  384. raise ValueError("No input operands")
  385. if isinstance(operands[0], str):
  386. subscripts = operands[0].replace(" ", "")
  387. operands = [asanyarray(v) for v in operands[1:]]
  388. # Ensure all characters are valid
  389. for s in subscripts:
  390. if s in '.,->':
  391. continue
  392. if s not in einsum_symbols:
  393. raise ValueError(f"Character {s} is not a valid symbol.")
  394. else:
  395. tmp_operands = list(operands)
  396. operand_list = []
  397. subscript_list = []
  398. for p in range(len(operands) // 2):
  399. operand_list.append(tmp_operands.pop(0))
  400. subscript_list.append(tmp_operands.pop(0))
  401. output_list = tmp_operands[-1] if len(tmp_operands) else None
  402. operands = [asanyarray(v) for v in operand_list]
  403. subscripts = ""
  404. last = len(subscript_list) - 1
  405. for num, sub in enumerate(subscript_list):
  406. for s in sub:
  407. if s is Ellipsis:
  408. subscripts += "..."
  409. else:
  410. try:
  411. s = operator.index(s)
  412. except TypeError as e:
  413. raise TypeError(
  414. "For this input type lists must contain "
  415. "either int or Ellipsis"
  416. ) from e
  417. subscripts += einsum_symbols[s]
  418. if num != last:
  419. subscripts += ","
  420. if output_list is not None:
  421. subscripts += "->"
  422. for s in output_list:
  423. if s is Ellipsis:
  424. subscripts += "..."
  425. else:
  426. try:
  427. s = operator.index(s)
  428. except TypeError as e:
  429. raise TypeError(
  430. "For this input type lists must contain "
  431. "either int or Ellipsis"
  432. ) from e
  433. subscripts += einsum_symbols[s]
  434. # Check for proper "->"
  435. if ("-" in subscripts) or (">" in subscripts):
  436. invalid = (subscripts.count("-") > 1) or (subscripts.count(">") > 1)
  437. if invalid or (subscripts.count("->") != 1):
  438. raise ValueError("Subscripts can only contain one '->'.")
  439. # Parse ellipses
  440. if "." in subscripts:
  441. used = subscripts.replace(".", "").replace(",", "").replace("->", "")
  442. unused = list(einsum_symbols_set - set(used))
  443. ellipse_inds = "".join(unused)
  444. longest = 0
  445. if "->" in subscripts:
  446. input_tmp, output_sub = subscripts.split("->")
  447. split_subscripts = input_tmp.split(",")
  448. out_sub = True
  449. else:
  450. split_subscripts = subscripts.split(',')
  451. out_sub = False
  452. for num, sub in enumerate(split_subscripts):
  453. if "." in sub:
  454. if (sub.count(".") != 3) or (sub.count("...") != 1):
  455. raise ValueError("Invalid Ellipses.")
  456. # Take into account numerical values
  457. if operands[num].shape == ():
  458. ellipse_count = 0
  459. else:
  460. ellipse_count = max(operands[num].ndim, 1)
  461. ellipse_count -= (len(sub) - 3)
  462. if ellipse_count > longest:
  463. longest = ellipse_count
  464. if ellipse_count < 0:
  465. raise ValueError("Ellipses lengths do not match.")
  466. elif ellipse_count == 0:
  467. split_subscripts[num] = sub.replace('...', '')
  468. else:
  469. rep_inds = ellipse_inds[-ellipse_count:]
  470. split_subscripts[num] = sub.replace('...', rep_inds)
  471. subscripts = ",".join(split_subscripts)
  472. if longest == 0:
  473. out_ellipse = ""
  474. else:
  475. out_ellipse = ellipse_inds[-longest:]
  476. if out_sub:
  477. subscripts += "->" + output_sub.replace("...", out_ellipse)
  478. else:
  479. # Special care for outputless ellipses
  480. output_subscript = ""
  481. tmp_subscripts = subscripts.replace(",", "")
  482. for s in sorted(set(tmp_subscripts)):
  483. if s not in (einsum_symbols):
  484. raise ValueError(f"Character {s} is not a valid symbol.")
  485. if tmp_subscripts.count(s) == 1:
  486. output_subscript += s
  487. normal_inds = ''.join(sorted(set(output_subscript) -
  488. set(out_ellipse)))
  489. subscripts += "->" + out_ellipse + normal_inds
  490. # Build output string if does not exist
  491. if "->" in subscripts:
  492. input_subscripts, output_subscript = subscripts.split("->")
  493. else:
  494. input_subscripts = subscripts
  495. # Build output subscripts
  496. tmp_subscripts = subscripts.replace(",", "")
  497. output_subscript = ""
  498. for s in sorted(set(tmp_subscripts)):
  499. if s not in einsum_symbols:
  500. raise ValueError(f"Character {s} is not a valid symbol.")
  501. if tmp_subscripts.count(s) == 1:
  502. output_subscript += s
  503. # Make sure output subscripts are in the input
  504. for char in output_subscript:
  505. if output_subscript.count(char) != 1:
  506. raise ValueError("Output character %s appeared more than once in "
  507. "the output." % char)
  508. if char not in input_subscripts:
  509. raise ValueError(f"Output character {char} did not appear in the input")
  510. # Make sure number operands is equivalent to the number of terms
  511. if len(input_subscripts.split(',')) != len(operands):
  512. raise ValueError("Number of einsum subscripts must be equal to the "
  513. "number of operands.")
  514. return (input_subscripts, output_subscript, operands)
  515. def _einsum_path_dispatcher(*operands, optimize=None, einsum_call=None):
  516. # NOTE: technically, we should only dispatch on array-like arguments, not
  517. # subscripts (given as strings). But separating operands into
  518. # arrays/subscripts is a little tricky/slow (given einsum's two supported
  519. # signatures), so as a practical shortcut we dispatch on everything.
  520. # Strings will be ignored for dispatching since they don't define
  521. # __array_function__.
  522. return operands
  523. @array_function_dispatch(_einsum_path_dispatcher, module='numpy')
  524. def einsum_path(*operands, optimize='greedy', einsum_call=False):
  525. """
  526. einsum_path(subscripts, *operands, optimize='greedy')
  527. Evaluates the lowest cost contraction order for an einsum expression by
  528. considering the creation of intermediate arrays.
  529. Parameters
  530. ----------
  531. subscripts : str
  532. Specifies the subscripts for summation.
  533. *operands : list of array_like
  534. These are the arrays for the operation.
  535. optimize : {bool, list, tuple, 'greedy', 'optimal'}
  536. Choose the type of path. If a tuple is provided, the second argument is
  537. assumed to be the maximum intermediate size created. If only a single
  538. argument is provided the largest input or output array size is used
  539. as a maximum intermediate size.
  540. * if a list is given that starts with ``einsum_path``, uses this as the
  541. contraction path
  542. * if False no optimization is taken
  543. * if True defaults to the 'greedy' algorithm
  544. * 'optimal' An algorithm that combinatorially explores all possible
  545. ways of contracting the listed tensors and chooses the least costly
  546. path. Scales exponentially with the number of terms in the
  547. contraction.
  548. * 'greedy' An algorithm that chooses the best pair contraction
  549. at each step. Effectively, this algorithm searches the largest inner,
  550. Hadamard, and then outer products at each step. Scales cubically with
  551. the number of terms in the contraction. Equivalent to the 'optimal'
  552. path for most contractions.
  553. Default is 'greedy'.
  554. Returns
  555. -------
  556. path : list of tuples
  557. A list representation of the einsum path.
  558. string_repr : str
  559. A printable representation of the einsum path.
  560. Notes
  561. -----
  562. The resulting path indicates which terms of the input contraction should be
  563. contracted first, the result of this contraction is then appended to the
  564. end of the contraction list. This list can then be iterated over until all
  565. intermediate contractions are complete.
  566. See Also
  567. --------
  568. einsum, linalg.multi_dot
  569. Examples
  570. --------
  571. We can begin with a chain dot example. In this case, it is optimal to
  572. contract the ``b`` and ``c`` tensors first as represented by the first
  573. element of the path ``(1, 2)``. The resulting tensor is added to the end
  574. of the contraction and the remaining contraction ``(0, 1)`` is then
  575. completed.
  576. >>> np.random.seed(123)
  577. >>> a = np.random.rand(2, 2)
  578. >>> b = np.random.rand(2, 5)
  579. >>> c = np.random.rand(5, 2)
  580. >>> path_info = np.einsum_path('ij,jk,kl->il', a, b, c, optimize='greedy')
  581. >>> print(path_info[0])
  582. ['einsum_path', (1, 2), (0, 1)]
  583. >>> print(path_info[1])
  584. Complete contraction: ij,jk,kl->il # may vary
  585. Naive scaling: 4
  586. Optimized scaling: 3
  587. Naive FLOP count: 1.600e+02
  588. Optimized FLOP count: 5.600e+01
  589. Theoretical speedup: 2.857
  590. Largest intermediate: 4.000e+00 elements
  591. -------------------------------------------------------------------------
  592. scaling current remaining
  593. -------------------------------------------------------------------------
  594. 3 kl,jk->jl ij,jl->il
  595. 3 jl,ij->il il->il
  596. A more complex index transformation example.
  597. >>> I = np.random.rand(10, 10, 10, 10)
  598. >>> C = np.random.rand(10, 10)
  599. >>> path_info = np.einsum_path('ea,fb,abcd,gc,hd->efgh', C, C, I, C, C,
  600. ... optimize='greedy')
  601. >>> print(path_info[0])
  602. ['einsum_path', (0, 2), (0, 3), (0, 2), (0, 1)]
  603. >>> print(path_info[1])
  604. Complete contraction: ea,fb,abcd,gc,hd->efgh # may vary
  605. Naive scaling: 8
  606. Optimized scaling: 5
  607. Naive FLOP count: 8.000e+08
  608. Optimized FLOP count: 8.000e+05
  609. Theoretical speedup: 1000.000
  610. Largest intermediate: 1.000e+04 elements
  611. --------------------------------------------------------------------------
  612. scaling current remaining
  613. --------------------------------------------------------------------------
  614. 5 abcd,ea->bcde fb,gc,hd,bcde->efgh
  615. 5 bcde,fb->cdef gc,hd,cdef->efgh
  616. 5 cdef,gc->defg hd,defg->efgh
  617. 5 defg,hd->efgh efgh->efgh
  618. """
  619. # Figure out what the path really is
  620. path_type = optimize
  621. if path_type is True:
  622. path_type = 'greedy'
  623. if path_type is None:
  624. path_type = False
  625. explicit_einsum_path = False
  626. memory_limit = None
  627. # No optimization or a named path algorithm
  628. if (path_type is False) or isinstance(path_type, str):
  629. pass
  630. # Given an explicit path
  631. elif len(path_type) and (path_type[0] == 'einsum_path'):
  632. explicit_einsum_path = True
  633. # Path tuple with memory limit
  634. elif ((len(path_type) == 2) and isinstance(path_type[0], str) and
  635. isinstance(path_type[1], (int, float))):
  636. memory_limit = int(path_type[1])
  637. path_type = path_type[0]
  638. else:
  639. raise TypeError(f"Did not understand the path: {str(path_type)}")
  640. # Hidden option, only einsum should call this
  641. einsum_call_arg = einsum_call
  642. # Python side parsing
  643. input_subscripts, output_subscript, operands = (
  644. _parse_einsum_input(operands)
  645. )
  646. # Build a few useful list and sets
  647. input_list = input_subscripts.split(',')
  648. num_inputs = len(input_list)
  649. input_sets = [set(x) for x in input_list]
  650. output_set = set(output_subscript)
  651. indices = set(input_subscripts.replace(',', ''))
  652. num_indices = len(indices)
  653. # Get length of each unique dimension and ensure all dimensions are correct
  654. dimension_dict = {}
  655. for tnum, term in enumerate(input_list):
  656. sh = operands[tnum].shape
  657. if len(sh) != len(term):
  658. raise ValueError("Einstein sum subscript %s does not contain the "
  659. "correct number of indices for operand %d."
  660. % (input_subscripts[tnum], tnum))
  661. for cnum, char in enumerate(term):
  662. dim = sh[cnum]
  663. if char in dimension_dict.keys():
  664. # For broadcasting cases we always want the largest dim size
  665. if dimension_dict[char] == 1:
  666. dimension_dict[char] = dim
  667. elif dim not in (1, dimension_dict[char]):
  668. raise ValueError("Size of label '%s' for operand %d (%d) "
  669. "does not match previous terms (%d)."
  670. % (char, tnum, dimension_dict[char], dim))
  671. else:
  672. dimension_dict[char] = dim
  673. # Compute size of each input array plus the output array
  674. size_list = [_compute_size_by_dict(term, dimension_dict)
  675. for term in input_list + [output_subscript]]
  676. max_size = max(size_list)
  677. if memory_limit is None:
  678. memory_arg = max_size
  679. else:
  680. memory_arg = memory_limit
  681. # Compute the path
  682. if explicit_einsum_path:
  683. path = path_type[1:]
  684. elif (
  685. (path_type is False)
  686. or (num_inputs in [1, 2])
  687. or (indices == output_set)
  688. ):
  689. # Nothing to be optimized, leave it to einsum
  690. path = [tuple(range(num_inputs))]
  691. elif path_type == "greedy":
  692. path = _greedy_path(
  693. input_sets, output_set, dimension_dict, memory_arg
  694. )
  695. elif path_type == "optimal":
  696. path = _optimal_path(
  697. input_sets, output_set, dimension_dict, memory_arg
  698. )
  699. else:
  700. raise KeyError("Path name %s not found", path_type)
  701. cost_list, scale_list, size_list, contraction_list = [], [], [], []
  702. # Build contraction tuple (positions, gemm, einsum_str, remaining)
  703. for cnum, contract_inds in enumerate(path):
  704. # Make sure we remove inds from right to left
  705. contract_inds = tuple(sorted(contract_inds, reverse=True))
  706. contract = _find_contraction(contract_inds, input_sets, output_set)
  707. out_inds, input_sets, idx_removed, idx_contract = contract
  708. if not einsum_call_arg:
  709. # these are only needed for printing info
  710. cost = _flop_count(
  711. idx_contract, idx_removed, len(contract_inds), dimension_dict
  712. )
  713. cost_list.append(cost)
  714. scale_list.append(len(idx_contract))
  715. size_list.append(_compute_size_by_dict(out_inds, dimension_dict))
  716. tmp_inputs = []
  717. for x in contract_inds:
  718. tmp_inputs.append(input_list.pop(x))
  719. # Last contraction
  720. if (cnum - len(path)) == -1:
  721. idx_result = output_subscript
  722. else:
  723. sort_result = [(dimension_dict[ind], ind) for ind in out_inds]
  724. idx_result = "".join([x[1] for x in sorted(sort_result)])
  725. input_list.append(idx_result)
  726. einsum_str = ",".join(tmp_inputs) + "->" + idx_result
  727. contraction = (contract_inds, einsum_str, input_list[:])
  728. contraction_list.append(contraction)
  729. if len(input_list) != 1:
  730. # Explicit "einsum_path" is usually trusted, but we detect this kind of
  731. # mistake in order to prevent from returning an intermediate value.
  732. raise RuntimeError(
  733. f"Invalid einsum_path is specified: {len(input_list) - 1} more "
  734. "operands has to be contracted.")
  735. if einsum_call_arg:
  736. return (operands, contraction_list)
  737. # Return the path along with a nice string representation
  738. overall_contraction = input_subscripts + "->" + output_subscript
  739. header = ("scaling", "current", "remaining")
  740. # Compute naive cost
  741. # This isn't quite right, need to look into exactly how einsum does this
  742. inner_product = (
  743. sum(len(set(x)) for x in input_subscripts.split(',')) - num_indices
  744. ) > 0
  745. naive_cost = _flop_count(
  746. indices, inner_product, num_inputs, dimension_dict
  747. )
  748. opt_cost = sum(cost_list) + 1
  749. speedup = naive_cost / opt_cost
  750. max_i = max(size_list)
  751. path_print = f" Complete contraction: {overall_contraction}\n"
  752. path_print += f" Naive scaling: {num_indices}\n"
  753. path_print += " Optimized scaling: %d\n" % max(scale_list)
  754. path_print += f" Naive FLOP count: {naive_cost:.3e}\n"
  755. path_print += f" Optimized FLOP count: {opt_cost:.3e}\n"
  756. path_print += f" Theoretical speedup: {speedup:3.3f}\n"
  757. path_print += f" Largest intermediate: {max_i:.3e} elements\n"
  758. path_print += "-" * 74 + "\n"
  759. path_print += "%6s %24s %40s\n" % header
  760. path_print += "-" * 74
  761. for n, contraction in enumerate(contraction_list):
  762. _, einsum_str, remaining = contraction
  763. remaining_str = ",".join(remaining) + "->" + output_subscript
  764. path_run = (scale_list[n], einsum_str, remaining_str)
  765. path_print += "\n%4d %24s %40s" % path_run
  766. path = ['einsum_path'] + path
  767. return (path, path_print)
  768. def _parse_eq_to_pure_multiplication(a_term, shape_a, b_term, shape_b, out):
  769. """If there are no contracted indices, then we can directly transpose and
  770. insert singleton dimensions into ``a`` and ``b`` such that (broadcast)
  771. elementwise multiplication performs the einsum.
  772. No need to cache this as it is within the cached
  773. ``_parse_eq_to_batch_matmul``.
  774. """
  775. desired_a = ""
  776. desired_b = ""
  777. new_shape_a = []
  778. new_shape_b = []
  779. for ix in out:
  780. if ix in a_term:
  781. desired_a += ix
  782. new_shape_a.append(shape_a[a_term.index(ix)])
  783. else:
  784. new_shape_a.append(1)
  785. if ix in b_term:
  786. desired_b += ix
  787. new_shape_b.append(shape_b[b_term.index(ix)])
  788. else:
  789. new_shape_b.append(1)
  790. if desired_a != a_term:
  791. eq_a = f"{a_term}->{desired_a}"
  792. else:
  793. eq_a = None
  794. if desired_b != b_term:
  795. eq_b = f"{b_term}->{desired_b}"
  796. else:
  797. eq_b = None
  798. return (
  799. eq_a,
  800. eq_b,
  801. new_shape_a,
  802. new_shape_b,
  803. None, # new_shape_ab, not needed since not fusing
  804. None, # perm_ab, not needed as we transpose a and b first
  805. True, # pure_multiplication=True
  806. )
  807. @functools.lru_cache(2**12)
  808. def _parse_eq_to_batch_matmul(eq, shape_a, shape_b):
  809. """Cached parsing of a two term einsum equation into the necessary
  810. sequence of arguments for contracttion via batched matrix multiplication.
  811. The steps we need to specify are:
  812. 1. Remove repeated and trivial indices from the left and right terms,
  813. and transpose them, done as a single einsum.
  814. 2. Fuse the remaining indices so we have two 3D tensors.
  815. 3. Perform the batched matrix multiplication.
  816. 4. Unfuse the output to get the desired final index order.
  817. """
  818. lhs, out = eq.split("->")
  819. a_term, b_term = lhs.split(",")
  820. if len(a_term) != len(shape_a):
  821. raise ValueError(f"Term '{a_term}' does not match shape {shape_a}.")
  822. if len(b_term) != len(shape_b):
  823. raise ValueError(f"Term '{b_term}' does not match shape {shape_b}.")
  824. sizes = {}
  825. singletons = set()
  826. # parse left term to unique indices with size > 1
  827. left = {}
  828. for ix, d in zip(a_term, shape_a):
  829. if d == 1:
  830. # everything (including broadcasting) works nicely if simply ignore
  831. # such dimensions, but we do need to track if they appear in output
  832. # and thus should be reintroduced later
  833. singletons.add(ix)
  834. continue
  835. if sizes.setdefault(ix, d) != d:
  836. # set and check size
  837. raise ValueError(
  838. f"Index {ix} has mismatched sizes {sizes[ix]} and {d}."
  839. )
  840. left[ix] = True
  841. # parse right term to unique indices with size > 1
  842. right = {}
  843. for ix, d in zip(b_term, shape_b):
  844. # broadcast indices (size 1 on one input and size != 1
  845. # on the other) should not be treated as singletons
  846. if d == 1:
  847. if ix not in left:
  848. singletons.add(ix)
  849. continue
  850. singletons.discard(ix)
  851. if sizes.setdefault(ix, d) != d:
  852. # set and check size
  853. raise ValueError(
  854. f"Index {ix} has mismatched sizes {sizes[ix]} and {d}."
  855. )
  856. right[ix] = True
  857. # now we classify the unique size > 1 indices only
  858. bat_inds = [] # appears on A, B, O
  859. con_inds = [] # appears on A, B, .
  860. a_keep = [] # appears on A, ., O
  861. b_keep = [] # appears on ., B, O
  862. # other indices (appearing on A or B only) will
  863. # be summed or traced out prior to the matmul
  864. for ix in left:
  865. if right.pop(ix, False):
  866. if ix in out:
  867. bat_inds.append(ix)
  868. else:
  869. con_inds.append(ix)
  870. elif ix in out:
  871. a_keep.append(ix)
  872. # now only indices unique to right remain
  873. for ix in right:
  874. if ix in out:
  875. b_keep.append(ix)
  876. if not con_inds:
  877. # contraction is pure multiplication, prepare inputs differently
  878. return _parse_eq_to_pure_multiplication(
  879. a_term, shape_a, b_term, shape_b, out
  880. )
  881. # only need the size one indices that appear in the output
  882. singletons = [ix for ix in out if ix in singletons]
  883. # take diagonal, remove any trivial axes and transpose left
  884. desired_a = "".join((*bat_inds, *a_keep, *con_inds))
  885. if a_term != desired_a:
  886. eq_a = f"{a_term}->{desired_a}"
  887. else:
  888. eq_a = None
  889. # take diagonal, remove any trivial axes and transpose right
  890. desired_b = "".join((*bat_inds, *con_inds, *b_keep))
  891. if b_term != desired_b:
  892. eq_b = f"{b_term}->{desired_b}"
  893. else:
  894. eq_b = None
  895. # then we want to reshape
  896. if bat_inds:
  897. lgroups = (bat_inds, a_keep, con_inds)
  898. rgroups = (bat_inds, con_inds, b_keep)
  899. ogroups = (bat_inds, a_keep, b_keep)
  900. else:
  901. # avoid size 1 batch dimension if no batch indices
  902. lgroups = (a_keep, con_inds)
  903. rgroups = (con_inds, b_keep)
  904. ogroups = (a_keep, b_keep)
  905. if any(len(group) != 1 for group in lgroups):
  906. # need to fuse 'kept' and contracted indices
  907. # (though could allow batch indices to be broadcast)
  908. new_shape_a = tuple(
  909. functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1)
  910. for ix_group in lgroups
  911. )
  912. else:
  913. new_shape_a = None
  914. if any(len(group) != 1 for group in rgroups):
  915. # need to fuse 'kept' and contracted indices
  916. # (though could allow batch indices to be broadcast)
  917. new_shape_b = tuple(
  918. functools.reduce(operator.mul, (sizes[ix] for ix in ix_group), 1)
  919. for ix_group in rgroups
  920. )
  921. else:
  922. new_shape_b = None
  923. if any(len(group) != 1 for group in ogroups) or singletons:
  924. new_shape_ab = (1,) * len(singletons) + tuple(
  925. sizes[ix] for ix_group in ogroups for ix in ix_group
  926. )
  927. else:
  928. new_shape_ab = None
  929. # then we might need to permute the matmul produced output:
  930. out_produced = "".join((*singletons, *bat_inds, *a_keep, *b_keep))
  931. if out_produced != out:
  932. perm_ab = tuple(out_produced.index(ix) for ix in out)
  933. else:
  934. perm_ab = None
  935. return (
  936. eq_a,
  937. eq_b,
  938. new_shape_a,
  939. new_shape_b,
  940. new_shape_ab,
  941. perm_ab,
  942. False, # pure_multiplication=False
  943. )
  944. @functools.lru_cache(maxsize=64)
  945. def _parse_output_order(order, a_is_fcontig, b_is_fcontig):
  946. order = order.upper()
  947. if order == "K":
  948. return None
  949. elif order in "CF":
  950. return order
  951. elif order == "A":
  952. if a_is_fcontig and b_is_fcontig:
  953. return "F"
  954. else:
  955. return "C"
  956. else:
  957. raise ValueError(
  958. "ValueError: order must be one of "
  959. f"'C', 'F', 'A', or 'K' (got '{order}')"
  960. )
  961. def bmm_einsum(eq, a, b, out=None, **kwargs):
  962. """Perform arbitrary pairwise einsums using only ``matmul``, or
  963. ``multiply`` if no contracted indices are involved (plus maybe single term
  964. ``einsum`` to prepare the terms individually). The logic for each is cached
  965. based on the equation and array shape, and each step is only performed if
  966. necessary.
  967. Parameters
  968. ----------
  969. eq : str
  970. The einsum equation.
  971. a : array_like
  972. The first array to contract.
  973. b : array_like
  974. The second array to contract.
  975. Returns
  976. -------
  977. array_like
  978. Notes
  979. -----
  980. A fuller description of this algorithm, and original source for this
  981. implementation, can be found at https://github.com/jcmgray/einsum_bmm.
  982. """
  983. (
  984. eq_a,
  985. eq_b,
  986. new_shape_a,
  987. new_shape_b,
  988. new_shape_ab,
  989. perm_ab,
  990. pure_multiplication,
  991. ) = _parse_eq_to_batch_matmul(eq, a.shape, b.shape)
  992. # n.b. one could special case various cases to call c_einsum directly here
  993. # need to handle `order` a little manually, since we do transpose
  994. # operations before and potentially after the ufunc calls
  995. output_order = _parse_output_order(
  996. kwargs.pop("order", "K"), a.flags.f_contiguous, b.flags.f_contiguous
  997. )
  998. # prepare left
  999. if eq_a is not None:
  1000. # diagonals, sums, and tranpose
  1001. a = c_einsum(eq_a, a)
  1002. if new_shape_a is not None:
  1003. a = reshape(a, new_shape_a)
  1004. # prepare right
  1005. if eq_b is not None:
  1006. # diagonals, sums, and tranpose
  1007. b = c_einsum(eq_b, b)
  1008. if new_shape_b is not None:
  1009. b = reshape(b, new_shape_b)
  1010. if pure_multiplication:
  1011. # no contracted indices
  1012. if output_order is not None:
  1013. kwargs["order"] = output_order
  1014. # do the 'contraction' via multiplication!
  1015. return multiply(a, b, out=out, **kwargs)
  1016. # can only supply out here if no other reshaping / transposing
  1017. matmul_out_compatible = (new_shape_ab is None) and (perm_ab is None)
  1018. if matmul_out_compatible:
  1019. kwargs["out"] = out
  1020. # do the contraction!
  1021. ab = matmul(a, b, **kwargs)
  1022. # prepare the output
  1023. if new_shape_ab is not None:
  1024. ab = reshape(ab, new_shape_ab)
  1025. if perm_ab is not None:
  1026. ab = ab.transpose(perm_ab)
  1027. if (out is not None) and (not matmul_out_compatible):
  1028. # handle case where out is specified, but we also needed
  1029. # to reshape / transpose ``ab`` after the matmul
  1030. out[:] = ab
  1031. ab = out
  1032. elif output_order is not None:
  1033. ab = asanyarray(ab, order=output_order)
  1034. return ab
  1035. def _einsum_dispatcher(*operands, out=None, optimize=None, **kwargs):
  1036. # Arguably we dispatch on more arguments than we really should; see note in
  1037. # _einsum_path_dispatcher for why.
  1038. yield from operands
  1039. yield out
  1040. # Rewrite einsum to handle different cases
  1041. @array_function_dispatch(_einsum_dispatcher, module='numpy')
  1042. def einsum(*operands, out=None, optimize=False, **kwargs):
  1043. """
  1044. einsum(subscripts, *operands, out=None, dtype=None, order='K',
  1045. casting='safe', optimize=False)
  1046. Evaluates the Einstein summation convention on the operands.
  1047. Using the Einstein summation convention, many common multi-dimensional,
  1048. linear algebraic array operations can be represented in a simple fashion.
  1049. In *implicit* mode `einsum` computes these values.
  1050. In *explicit* mode, `einsum` provides further flexibility to compute
  1051. other array operations that might not be considered classical Einstein
  1052. summation operations, by disabling, or forcing summation over specified
  1053. subscript labels.
  1054. See the notes and examples for clarification.
  1055. Parameters
  1056. ----------
  1057. subscripts : str
  1058. Specifies the subscripts for summation as comma separated list of
  1059. subscript labels. An implicit (classical Einstein summation)
  1060. calculation is performed unless the explicit indicator '->' is
  1061. included as well as subscript labels of the precise output form.
  1062. operands : list of array_like
  1063. These are the arrays for the operation.
  1064. out : ndarray, optional
  1065. If provided, the calculation is done into this array.
  1066. dtype : {data-type, None}, optional
  1067. If provided, forces the calculation to use the data type specified.
  1068. Note that you may have to also give a more liberal `casting`
  1069. parameter to allow the conversions. Default is None.
  1070. order : {'C', 'F', 'A', 'K'}, optional
  1071. Controls the memory layout of the output. 'C' means it should
  1072. be C contiguous. 'F' means it should be Fortran contiguous,
  1073. 'A' means it should be 'F' if the inputs are all 'F', 'C' otherwise.
  1074. 'K' means it should be as close to the layout as the inputs as
  1075. is possible, including arbitrarily permuted axes.
  1076. Default is 'K'.
  1077. casting : {'no', 'equiv', 'safe', 'same_kind', 'unsafe'}, optional
  1078. Controls what kind of data casting may occur. Setting this to
  1079. 'unsafe' is not recommended, as it can adversely affect accumulations.
  1080. * 'no' means the data types should not be cast at all.
  1081. * 'equiv' means only byte-order changes are allowed.
  1082. * 'safe' means only casts which can preserve values are allowed.
  1083. * 'same_kind' means only safe casts or casts within a kind,
  1084. like float64 to float32, are allowed.
  1085. * 'unsafe' means any data conversions may be done.
  1086. Default is 'safe'.
  1087. optimize : {False, True, 'greedy', 'optimal'}, optional
  1088. Controls if intermediate optimization should occur. No optimization
  1089. will occur if False and True will default to the 'greedy' algorithm.
  1090. Also accepts an explicit contraction list from the ``np.einsum_path``
  1091. function. See ``np.einsum_path`` for more details. Defaults to False.
  1092. Returns
  1093. -------
  1094. output : ndarray
  1095. The calculation based on the Einstein summation convention.
  1096. See Also
  1097. --------
  1098. einsum_path, dot, inner, outer, tensordot, linalg.multi_dot
  1099. einsum:
  1100. Similar verbose interface is provided by the
  1101. `einops <https://github.com/arogozhnikov/einops>`_ package to cover
  1102. additional operations: transpose, reshape/flatten, repeat/tile,
  1103. squeeze/unsqueeze and reductions.
  1104. The `opt_einsum <https://optimized-einsum.readthedocs.io/en/stable/>`_
  1105. optimizes contraction order for einsum-like expressions
  1106. in backend-agnostic manner.
  1107. Notes
  1108. -----
  1109. The Einstein summation convention can be used to compute
  1110. many multi-dimensional, linear algebraic array operations. `einsum`
  1111. provides a succinct way of representing these.
  1112. A non-exhaustive list of these operations,
  1113. which can be computed by `einsum`, is shown below along with examples:
  1114. * Trace of an array, :py:func:`numpy.trace`.
  1115. * Return a diagonal, :py:func:`numpy.diag`.
  1116. * Array axis summations, :py:func:`numpy.sum`.
  1117. * Transpositions and permutations, :py:func:`numpy.transpose`.
  1118. * Matrix multiplication and dot product, :py:func:`numpy.matmul`
  1119. :py:func:`numpy.dot`.
  1120. * Vector inner and outer products, :py:func:`numpy.inner`
  1121. :py:func:`numpy.outer`.
  1122. * Broadcasting, element-wise and scalar multiplication,
  1123. :py:func:`numpy.multiply`.
  1124. * Tensor contractions, :py:func:`numpy.tensordot`.
  1125. * Chained array operations, in efficient calculation order,
  1126. :py:func:`numpy.einsum_path`.
  1127. The subscripts string is a comma-separated list of subscript labels,
  1128. where each label refers to a dimension of the corresponding operand.
  1129. Whenever a label is repeated it is summed, so ``np.einsum('i,i', a, b)``
  1130. is equivalent to :py:func:`np.inner(a,b) <numpy.inner>`. If a label
  1131. appears only once, it is not summed, so ``np.einsum('i', a)``
  1132. produces a view of ``a`` with no changes. A further example
  1133. ``np.einsum('ij,jk', a, b)`` describes traditional matrix multiplication
  1134. and is equivalent to :py:func:`np.matmul(a,b) <numpy.matmul>`.
  1135. Repeated subscript labels in one operand take the diagonal.
  1136. For example, ``np.einsum('ii', a)`` is equivalent to
  1137. :py:func:`np.trace(a) <numpy.trace>`.
  1138. In *implicit mode*, the chosen subscripts are important
  1139. since the axes of the output are reordered alphabetically. This
  1140. means that ``np.einsum('ij', a)`` doesn't affect a 2D array, while
  1141. ``np.einsum('ji', a)`` takes its transpose. Additionally,
  1142. ``np.einsum('ij,jk', a, b)`` returns a matrix multiplication, while,
  1143. ``np.einsum('ij,jh', a, b)`` returns the transpose of the
  1144. multiplication since subscript 'h' precedes subscript 'i'.
  1145. In *explicit mode* the output can be directly controlled by
  1146. specifying output subscript labels. This requires the
  1147. identifier '->' as well as the list of output subscript labels.
  1148. This feature increases the flexibility of the function since
  1149. summing can be disabled or forced when required. The call
  1150. ``np.einsum('i->', a)`` is like :py:func:`np.sum(a) <numpy.sum>`
  1151. if ``a`` is a 1-D array, and ``np.einsum('ii->i', a)``
  1152. is like :py:func:`np.diag(a) <numpy.diag>` if ``a`` is a square 2-D array.
  1153. The difference is that `einsum` does not allow broadcasting by default.
  1154. Additionally ``np.einsum('ij,jh->ih', a, b)`` directly specifies the
  1155. order of the output subscript labels and therefore returns matrix
  1156. multiplication, unlike the example above in implicit mode.
  1157. To enable and control broadcasting, use an ellipsis. Default
  1158. NumPy-style broadcasting is done by adding an ellipsis
  1159. to the left of each term, like ``np.einsum('...ii->...i', a)``.
  1160. ``np.einsum('...i->...', a)`` is like
  1161. :py:func:`np.sum(a, axis=-1) <numpy.sum>` for array ``a`` of any shape.
  1162. To take the trace along the first and last axes,
  1163. you can do ``np.einsum('i...i', a)``, or to do a matrix-matrix
  1164. product with the left-most indices instead of rightmost, one can do
  1165. ``np.einsum('ij...,jk...->ik...', a, b)``.
  1166. When there is only one operand, no axes are summed, and no output
  1167. parameter is provided, a view into the operand is returned instead
  1168. of a new array. Thus, taking the diagonal as ``np.einsum('ii->i', a)``
  1169. produces a view (changed in version 1.10.0).
  1170. `einsum` also provides an alternative way to provide the subscripts and
  1171. operands as ``einsum(op0, sublist0, op1, sublist1, ..., [sublistout])``.
  1172. If the output shape is not provided in this format `einsum` will be
  1173. calculated in implicit mode, otherwise it will be performed explicitly.
  1174. The examples below have corresponding `einsum` calls with the two
  1175. parameter methods.
  1176. Views returned from einsum are now writeable whenever the input array
  1177. is writeable. For example, ``np.einsum('ijk...->kji...', a)`` will now
  1178. have the same effect as :py:func:`np.swapaxes(a, 0, 2) <numpy.swapaxes>`
  1179. and ``np.einsum('ii->i', a)`` will return a writeable view of the diagonal
  1180. of a 2D array.
  1181. Added the ``optimize`` argument which will optimize the contraction order
  1182. of an einsum expression. For a contraction with three or more operands
  1183. this can greatly increase the computational efficiency at the cost of
  1184. a larger memory footprint during computation.
  1185. Typically a 'greedy' algorithm is applied which empirical tests have shown
  1186. returns the optimal path in the majority of cases. In some cases 'optimal'
  1187. will return the superlative path through a more expensive, exhaustive
  1188. search. For iterative calculations it may be advisable to calculate
  1189. the optimal path once and reuse that path by supplying it as an argument.
  1190. An example is given below.
  1191. See :py:func:`numpy.einsum_path` for more details.
  1192. Examples
  1193. --------
  1194. >>> a = np.arange(25).reshape(5,5)
  1195. >>> b = np.arange(5)
  1196. >>> c = np.arange(6).reshape(2,3)
  1197. Trace of a matrix:
  1198. >>> np.einsum('ii', a)
  1199. 60
  1200. >>> np.einsum(a, [0,0])
  1201. 60
  1202. >>> np.trace(a)
  1203. 60
  1204. Extract the diagonal (requires explicit form):
  1205. >>> np.einsum('ii->i', a)
  1206. array([ 0, 6, 12, 18, 24])
  1207. >>> np.einsum(a, [0,0], [0])
  1208. array([ 0, 6, 12, 18, 24])
  1209. >>> np.diag(a)
  1210. array([ 0, 6, 12, 18, 24])
  1211. Sum over an axis (requires explicit form):
  1212. >>> np.einsum('ij->i', a)
  1213. array([ 10, 35, 60, 85, 110])
  1214. >>> np.einsum(a, [0,1], [0])
  1215. array([ 10, 35, 60, 85, 110])
  1216. >>> np.sum(a, axis=1)
  1217. array([ 10, 35, 60, 85, 110])
  1218. For higher dimensional arrays summing a single axis can be done
  1219. with ellipsis:
  1220. >>> np.einsum('...j->...', a)
  1221. array([ 10, 35, 60, 85, 110])
  1222. >>> np.einsum(a, [Ellipsis,1], [Ellipsis])
  1223. array([ 10, 35, 60, 85, 110])
  1224. Compute a matrix transpose, or reorder any number of axes:
  1225. >>> np.einsum('ji', c)
  1226. array([[0, 3],
  1227. [1, 4],
  1228. [2, 5]])
  1229. >>> np.einsum('ij->ji', c)
  1230. array([[0, 3],
  1231. [1, 4],
  1232. [2, 5]])
  1233. >>> np.einsum(c, [1,0])
  1234. array([[0, 3],
  1235. [1, 4],
  1236. [2, 5]])
  1237. >>> np.transpose(c)
  1238. array([[0, 3],
  1239. [1, 4],
  1240. [2, 5]])
  1241. Vector inner products:
  1242. >>> np.einsum('i,i', b, b)
  1243. 30
  1244. >>> np.einsum(b, [0], b, [0])
  1245. 30
  1246. >>> np.inner(b,b)
  1247. 30
  1248. Matrix vector multiplication:
  1249. >>> np.einsum('ij,j', a, b)
  1250. array([ 30, 80, 130, 180, 230])
  1251. >>> np.einsum(a, [0,1], b, [1])
  1252. array([ 30, 80, 130, 180, 230])
  1253. >>> np.dot(a, b)
  1254. array([ 30, 80, 130, 180, 230])
  1255. >>> np.einsum('...j,j', a, b)
  1256. array([ 30, 80, 130, 180, 230])
  1257. Broadcasting and scalar multiplication:
  1258. >>> np.einsum('..., ...', 3, c)
  1259. array([[ 0, 3, 6],
  1260. [ 9, 12, 15]])
  1261. >>> np.einsum(',ij', 3, c)
  1262. array([[ 0, 3, 6],
  1263. [ 9, 12, 15]])
  1264. >>> np.einsum(3, [Ellipsis], c, [Ellipsis])
  1265. array([[ 0, 3, 6],
  1266. [ 9, 12, 15]])
  1267. >>> np.multiply(3, c)
  1268. array([[ 0, 3, 6],
  1269. [ 9, 12, 15]])
  1270. Vector outer product:
  1271. >>> np.einsum('i,j', np.arange(2)+1, b)
  1272. array([[0, 1, 2, 3, 4],
  1273. [0, 2, 4, 6, 8]])
  1274. >>> np.einsum(np.arange(2)+1, [0], b, [1])
  1275. array([[0, 1, 2, 3, 4],
  1276. [0, 2, 4, 6, 8]])
  1277. >>> np.outer(np.arange(2)+1, b)
  1278. array([[0, 1, 2, 3, 4],
  1279. [0, 2, 4, 6, 8]])
  1280. Tensor contraction:
  1281. >>> a = np.arange(60.).reshape(3,4,5)
  1282. >>> b = np.arange(24.).reshape(4,3,2)
  1283. >>> np.einsum('ijk,jil->kl', a, b)
  1284. array([[4400., 4730.],
  1285. [4532., 4874.],
  1286. [4664., 5018.],
  1287. [4796., 5162.],
  1288. [4928., 5306.]])
  1289. >>> np.einsum(a, [0,1,2], b, [1,0,3], [2,3])
  1290. array([[4400., 4730.],
  1291. [4532., 4874.],
  1292. [4664., 5018.],
  1293. [4796., 5162.],
  1294. [4928., 5306.]])
  1295. >>> np.tensordot(a,b, axes=([1,0],[0,1]))
  1296. array([[4400., 4730.],
  1297. [4532., 4874.],
  1298. [4664., 5018.],
  1299. [4796., 5162.],
  1300. [4928., 5306.]])
  1301. Writeable returned arrays (since version 1.10.0):
  1302. >>> a = np.zeros((3, 3))
  1303. >>> np.einsum('ii->i', a)[:] = 1
  1304. >>> a
  1305. array([[1., 0., 0.],
  1306. [0., 1., 0.],
  1307. [0., 0., 1.]])
  1308. Example of ellipsis use:
  1309. >>> a = np.arange(6).reshape((3,2))
  1310. >>> b = np.arange(12).reshape((4,3))
  1311. >>> np.einsum('ki,jk->ij', a, b)
  1312. array([[10, 28, 46, 64],
  1313. [13, 40, 67, 94]])
  1314. >>> np.einsum('ki,...k->i...', a, b)
  1315. array([[10, 28, 46, 64],
  1316. [13, 40, 67, 94]])
  1317. >>> np.einsum('k...,jk', a, b)
  1318. array([[10, 28, 46, 64],
  1319. [13, 40, 67, 94]])
  1320. Chained array operations. For more complicated contractions, speed ups
  1321. might be achieved by repeatedly computing a 'greedy' path or pre-computing
  1322. the 'optimal' path and repeatedly applying it, using an `einsum_path`
  1323. insertion (since version 1.12.0). Performance improvements can be
  1324. particularly significant with larger arrays:
  1325. >>> a = np.ones(64).reshape(2,4,8)
  1326. Basic `einsum`: ~1520ms (benchmarked on 3.1GHz Intel i5.)
  1327. >>> for iteration in range(500):
  1328. ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a)
  1329. Sub-optimal `einsum` (due to repeated path calculation time): ~330ms
  1330. >>> for iteration in range(500):
  1331. ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
  1332. ... optimize='optimal')
  1333. Greedy `einsum` (faster optimal path approximation): ~160ms
  1334. >>> for iteration in range(500):
  1335. ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize='greedy')
  1336. Optimal `einsum` (best usage pattern in some use cases): ~110ms
  1337. >>> path = np.einsum_path('ijk,ilm,njm,nlk,abc->',a,a,a,a,a,
  1338. ... optimize='optimal')[0]
  1339. >>> for iteration in range(500):
  1340. ... _ = np.einsum('ijk,ilm,njm,nlk,abc->',a,a,a,a,a, optimize=path)
  1341. """
  1342. # Special handling if out is specified
  1343. specified_out = out is not None
  1344. # If no optimization, run pure einsum
  1345. if optimize is False:
  1346. if specified_out:
  1347. kwargs['out'] = out
  1348. return c_einsum(*operands, **kwargs)
  1349. # Check the kwargs to avoid a more cryptic error later, without having to
  1350. # repeat default values here
  1351. valid_einsum_kwargs = ['dtype', 'order', 'casting']
  1352. unknown_kwargs = [k for (k, v) in kwargs.items() if
  1353. k not in valid_einsum_kwargs]
  1354. if len(unknown_kwargs):
  1355. raise TypeError(f"Did not understand the following kwargs: {unknown_kwargs}")
  1356. # Build the contraction list and operand
  1357. operands, contraction_list = einsum_path(*operands, optimize=optimize,
  1358. einsum_call=True)
  1359. # Start contraction loop
  1360. for num, contraction in enumerate(contraction_list):
  1361. inds, einsum_str, _ = contraction
  1362. tmp_operands = [operands.pop(x) for x in inds]
  1363. # Do we need to deal with the output?
  1364. handle_out = specified_out and ((num + 1) == len(contraction_list))
  1365. # If out was specified
  1366. if handle_out:
  1367. kwargs["out"] = out
  1368. if len(tmp_operands) == 2:
  1369. # Call (batched) matrix multiplication if possible
  1370. new_view = bmm_einsum(einsum_str, *tmp_operands, **kwargs)
  1371. else:
  1372. # Call einsum
  1373. new_view = c_einsum(einsum_str, *tmp_operands, **kwargs)
  1374. # Append new items and dereference what we can
  1375. operands.append(new_view)
  1376. del tmp_operands, new_view
  1377. if specified_out:
  1378. return out
  1379. else:
  1380. return operands[0]