einsum.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import collections
  15. import itertools
  16. import re
  17. import string
  18. import numpy as np
  19. import opt_einsum
  20. from paddle import _C_ops
  21. from ..base.data_feeder import check_type, check_variable_and_dtype
  22. from ..base.framework import in_dynamic_or_pir_mode
  23. from ..base.layer_helper import LayerHelper
  24. from .linalg import matmul, transpose
  25. from .manipulation import reshape, squeeze, unsqueeze
  26. from .math import (
  27. multiply,
  28. sum as paddle_sum,
  29. )
  30. __all__ = []
  31. def parse_op_labels(labelstr, operand):
  32. '''
  33. Parse labels for an input operand.
  34. Parameters
  35. ----------
  36. labelstr:
  37. the input label string
  38. operand:
  39. the input operand
  40. Returns
  41. -------
  42. the input operand's full label string in which all anonymous dimensions are
  43. labeled in dots.
  44. '''
  45. # Sanity checks
  46. for c in labelstr.replace('.', ''):
  47. assert (
  48. c.isalpha()
  49. ), f"Invalid equation: {c} is not a valid label, which should be letters."
  50. assert (
  51. labelstr.replace('...', '', 1).find('.') == -1
  52. ), "Invalid equation: `.` is found outside of an ellipsis."
  53. ndims = len(operand.shape)
  54. full_labelstr = labelstr.replace('...', '.' * (ndims - len(labelstr) + 3))
  55. assert (
  56. len(full_labelstr) == ndims
  57. ), f"Invalid equation: the label string '{labelstr}' misses dimensions."
  58. return full_labelstr
  59. def parse_labels(labelstr, operands):
  60. '''
  61. Parse label strings for all input operands.
  62. Parameters
  63. ----------
  64. labelstr:
  65. The equation's label string
  66. operands:
  67. The input operands
  68. Returns
  69. -------
  70. list of full label strings for all input operands
  71. '''
  72. nop_labels = labelstr.split(',')
  73. assert len(nop_labels) == len(operands), (
  74. f"Invalid equation: the number of operands is {len(operands)}, "
  75. f"but found {len(nop_labels)} segments in the label equation."
  76. )
  77. return list(map(parse_op_labels, nop_labels, operands))
  78. def validate_rhs(rhs, input_labels, n_bcast_dims):
  79. '''
  80. Check whether the equation's right hand side is valid
  81. '''
  82. # Sanity check.
  83. if n_bcast_dims > 0:
  84. assert (
  85. '...' in rhs
  86. ), "Invalid equation: missing ellipsis in output labels."
  87. rhs = rhs.replace('...', '')
  88. rhs_set = set(rhs)
  89. # Hidden assumption: available labels don't include '.'
  90. assert '.' not in input_labels
  91. # Verify that output labels all come from the set of input labels
  92. non_input_labels = rhs_set.difference(input_labels)
  93. assert not non_input_labels, (
  94. f"Invalid equation: "
  95. f"output label {sorted(non_input_labels)} not used by any input."
  96. )
  97. # Verify that output labels are not duplicate
  98. assert len(rhs) == len(
  99. rhs_set
  100. ), "Invalid equation: duplicate output labels are found."
  101. def build_view(in_labels, out_labels):
  102. '''
  103. Build an inverse map of dimension indices. Three conditions must hold for
  104. the result to be meaningful.
  105. First, no duplicate letter labels in each label string.
  106. Second, the number of dots in dimout_labels >= that in in_labels.
  107. Third, dots are contiguous in each label string.
  108. Parameters
  109. ----------
  110. in_labels:
  111. The dimension labels to map to
  112. out_labels:
  113. The dimension labels to map from
  114. Returns
  115. -------
  116. The inverse map from out_labels to in_labels. The length of the inverse map equals that of
  117. out_labels. -1 is filled if there's no matching input dimension for a specific label.
  118. Examples
  119. --------
  120. in_labels = 'ij..', out_labels = '..ji'
  121. inv_map = [2, 3, 1, 0]
  122. in_labels = 'ij..', out_labels = '..kji'
  123. inv_map = [2, 3, -1, 1, 0]
  124. '''
  125. inv_map = [-1] * len(out_labels)
  126. # First build the broadcast dimension mapping
  127. # Find the broadcast index range in out_labels
  128. r = re.search(r'\.+', out_labels)
  129. if r:
  130. start, end = r.start(), r.end()
  131. s = re.search(r'\.+', in_labels)
  132. # fill the broadcast dimension indices from right to left.
  133. if s:
  134. for ax, dim in zip(
  135. range(start, end)[::-1], range(s.start(), s.end())[::-1]
  136. ):
  137. inv_map[ax] = dim
  138. # Now work on non-broadcast dimensions
  139. if r:
  140. it = itertools.chain(range(start), range(end, len(out_labels)))
  141. else:
  142. it = iter(range(len(out_labels)))
  143. for i in it:
  144. inv_map[i] = in_labels.find(out_labels[i])
  145. return inv_map
  146. def build_global_view(nop_labels, rhs, n_bcast_dims):
  147. '''
  148. Build the global view, which is a layout of all dimension labels
  149. plus an index table that maps from the layout to the dimensions
  150. in each operand. In the global view, the dimensions are arranged
  151. such that output ones are put on the left and contraction ones
  152. are put on the right.
  153. Parameters
  154. ----------
  155. nop_labels:
  156. The input full label strings of all input operands
  157. rhs:
  158. The equation right hand side
  159. n_bcast_dims:
  160. The maximum number of broadcast dimensions
  161. Returns
  162. -------
  163. A tuple of g_labels, g_view, g_nout, g_count
  164. g_labels:
  165. the layout of all labels in a string
  166. g_view:
  167. the index table
  168. g_nout:
  169. the number of output dimensions
  170. g_count:
  171. the counter array for dimension contractions
  172. '''
  173. # Put all labels in alphabetical order
  174. concat = sorted(''.join(nop_labels).replace('.', ''))
  175. labels, count = [], []
  176. for a, b in zip(['.'] + concat, concat):
  177. if a != b:
  178. labels.append(b)
  179. count.append(1)
  180. else:
  181. count[-1] += 1
  182. if rhs is not None:
  183. validate_rhs(rhs, labels, n_bcast_dims)
  184. g_labels_out = rhs.replace('...', '.' * n_bcast_dims)
  185. else:
  186. g_labels_out = '.' * n_bcast_dims + ''.join(
  187. l for l, c in zip(labels, count) if c == 1
  188. )
  189. for i in range(len(count))[::-1]:
  190. if labels[i] in g_labels_out:
  191. labels.pop(i)
  192. count.pop(i)
  193. g_labels_sum = ''.join(labels)
  194. g_labels = g_labels_out + g_labels_sum
  195. g_view = [build_view(i, g_labels) for i in nop_labels]
  196. g_nout = len(g_labels_out)
  197. g_count = count
  198. return g_labels, g_view, g_nout, g_count
  199. def build_global_shape(g_view, g_labels, op_shapes):
  200. '''
  201. The global shape is the shape of all dimensions rearranged and broadcasting
  202. to the global view. It's a reference data structure for einsum planning.
  203. Parameters
  204. ----------
  205. g_view:
  206. the global view
  207. op_shapes:
  208. the shapes of the all operands
  209. Returns
  210. -------
  211. g_shape:
  212. the global shape vector
  213. g_masks:
  214. list of shape masks for each operand. A dimension's shape mask is a boolean
  215. indicating whether its size > 1, in other words, it's not squeezable
  216. '''
  217. view_shapes = []
  218. g_masks = []
  219. for view, op_shape in zip(g_view, op_shapes):
  220. view_shapes.append([op_shape[dim] if dim > -1 else 1 for dim in view])
  221. g_shape = [set(sizes_per_ax) - {1} for sizes_per_ax in zip(*view_shapes)]
  222. non_bcastable = [ax for ax, sizes in enumerate(g_shape) if len(sizes) > 1]
  223. assert not non_bcastable, (
  224. f"Invalid operands: label {g_labels[non_bcastable[0]]} "
  225. f"corresponds to non-broadcastable dimensions."
  226. )
  227. g_shape = [sizes.pop() if len(sizes) > 0 else 1 for sizes in g_shape]
  228. g_masks = [
  229. [s > 1 or s == -1 for s in view_shape] for view_shape in view_shapes
  230. ]
  231. return g_shape, g_masks
  232. def has_duplicated_labels(labels):
  233. '''
  234. Returns True if there is any duplicate label.
  235. '''
  236. labels = labels.replace('.', '')
  237. return len(labels) > len(set(labels))
  238. def diagonalize(labels, operand):
  239. '''
  240. Merges dimensions with duplicate labels.
  241. For those dimensions with duplicate labels, merge them into one dimension
  242. which represents the diagonal elements. This requires the dimensions with
  243. duplicate labels are equal sized.
  244. Examples
  245. --------
  246. 'ijj...i' would be merged into 'ij...'
  247. '''
  248. assert not has_duplicated_labels(
  249. labels
  250. ), 'Duplicate labels are not supported.'
  251. return labels, operand
  252. def plan_reduce(plan, op, reduce_dims, keepdim):
  253. '''
  254. Add reduce to the plan
  255. '''
  256. varname = f'op{op}'
  257. f = lambda var, dims: paddle_sum(var, dims, keepdim=keepdim)
  258. step = f, [varname], varname, reduce_dims
  259. plan.add_step(step)
  260. def plan_scalar_prod(plan, op1, op2):
  261. varnames = [f'op{op1}', f'op{op2}']
  262. f = lambda var1, var2: paddle_sum(var1) * var2
  263. # f = lambda var1, var2: var1 * var2
  264. step = f, varnames, varnames[1]
  265. plan.add_step(step)
def plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K):
    '''
    Plan the pairwise contraction of operand ``op1`` into operand ``op2``.

    The result is written to the plan variable ``op{op2}``.

    Parameters
    ----------
    plan:
        the Plan to append steps to
    g_view:
        per-operand index tables into the global layout; ``g_view[op2]`` is
        replaced to describe the result
    op1, op2:
        operand indices
    g_supports:
        per-operand masks (True = dimension not squeezable); ``g_supports[op2]``
        is updated in place for the result
    g_shape:
        the global shape vector
    I:
        global axes kept in the result and aligned between both operands
    J1, J2:
        global axes present only in op1 / only in op2
    K:
        global axes contracted by this step

    Strategy, in order of preference:
    - reshape into (I..., J, K) and run a transposed matmul when there are J
      and K axes and no unknown (-1) sizes
    - a single transposed matmul when J1, J2 and K each hold exactly one axis
    - otherwise unsqueeze for the J axes, then one of:
      elementwise multiply (no K), a dot via matmul (one K axis, no J),
      reshape + matmul (several K axes, known sizes), or multiply + sum.
    '''
    # Transpose and re-shape op1 and op2 in I, J1, K and I, J2, K
    # Then apply matmul(x, y, transpose_x=False, transpose_y=True)
    var1, var2 = f'op{op1}', f'op{op2}'
    op1_view, op2_view = (g_view[op] for op in (op1, op2))
    # The subset of I axes each operand actually carries.
    I1 = [idx for idx in I if op1_view[idx] >= 0]
    I2 = [idx for idx in I if op2_view[idx] >= 0]
    op1_view = np.array(op1_view)
    # Operand dimension order needed for the layout [I..., J..., K...].
    op1_dims = op1_view[I1 + J1 + K]
    op2_view = np.array(op2_view)
    op2_dims = op2_view[I2 + J2 + K]
    op1_mask, op2_mask = (g_supports[op] for op in (op1, op2))
    # Per-operand sizes on the global axes; masked-out axes count as 1.
    op1_vshape = np.array([s if m else 1 for s, m in zip(g_shape, op1_mask)])
    op2_vshape = np.array([s if m else 1 for s, m in zip(g_shape, op2_mask)])
    # Broadcast result sizes on the global axes.
    vshape = np.maximum(op1_vshape, op2_vshape)
    i1, i2, j1, j2, k = map(len, (I1, I2, J1, J2, K))
    # Transpose op1 into [I1..., J1..., K...] unless already in that order.
    if any(op1_dims != np.arange(len(op1_dims))):
        step = transpose, [var1], var1, list(op1_dims)
        plan.add_step(step)
    # Transpose op2 into [I2..., J2..., K...] unless already in that order.
    if any(op2_dims != np.arange(len(op2_dims))):
        step = transpose, [var2], var2, list(op2_dims)
        plan.add_step(step)
    # Check if conditions hold for turning the operation into a matmul
    if (
        j1 + j2 > 0
        and k > 0
        and -1 not in np.concatenate((op1_vshape, op2_vshape))
    ):
        # Target shapes (I..., prod(J), prod(K)); absent I axes get size 1.
        op1_shape = (
            list(op1_vshape[I])
            + [np.prod(op1_vshape[J1])]
            + [np.prod(op1_vshape[K])]
        )
        op2_shape = (
            list(op2_vshape[I])
            + [np.prod(op2_vshape[J2])]
            + [np.prod(op2_vshape[K])]
        )
        # Merge J dims and K dims by reshaping
        step = reshape, [var1], var1, op1_shape
        plan.add_step(step)
        step = reshape, [var2], var2, op2_shape
        plan.add_step(step)
        # Matmul with op2 transposed: (..., J1, K) x (..., J2, K)^T
        step = matmul, [var1, var2], var2, False, True
        plan.add_step(step)
        # Reshape back
        shape = list(vshape[I + J1 + J2])
        step = reshape, [var2], var2, shape
        plan.add_step(step)
    elif j1 == j2 == k == 1:
        # Can still do matmul even unknown shapes are present
        step = matmul, [var1, var2], var2, False, True
        plan.add_step(step)
    # In the rest cases we opt for ops other than matmul
    else:
        # unsqueeze operands include J1...J2... dimensions
        if j2:
            # Give op1 size-1 slots for op2's J2 axes (after its I1 and J1).
            fill = list(range(i1 + j1, i1 + j1 + j2))
            step = unsqueeze, [var1], var1, fill
            plan.add_step(step)
        if j1:
            # Give op2 size-1 slots for op1's J1 axes (right after its I2).
            fill = list(range(i2, i2 + j1))
            step = unsqueeze, [var2], var2, fill
            plan.add_step(step)
        # In case of no dimensions to contract, do an elementwise multiply
        if k == 0:
            # make broadcast
            step = multiply, [var1, var2], var2
            plan.add_step(step)
        # Contract and no join, turn into a dot
        elif j1 + j2 == 0 and k == 1:
            # (..., 1, K) x (..., K, 1) -> (..., 1, 1), then squeeze both.
            step = unsqueeze, [var1], var1, [-2]
            plan.add_step(step)
            step = unsqueeze, [var2], var2, [-1]
            plan.add_step(step)
            step = matmul, [var1, var2], var2
            plan.add_step(step)
            step = squeeze, [var2], var2, [-1, -2]
            plan.add_step(step)
        elif j1 + j2 == 0 and -1 not in np.concatenate(
            (op1_vshape[K], op2_vshape[K])
        ):
            # Several known-size K axes: flatten them into one and dot.
            assert all(op1_vshape[K] == op2_vshape[K])
            step = (
                reshape,
                [var1],
                var1,
                list(op1_vshape[I]) + [1] + [np.prod(op1_vshape[K])],
            )
            plan.add_step(step)
            step = (
                reshape,
                [var2],
                var2,
                list(op2_vshape[I]) + [1] + [np.prod(op2_vshape[K])],
            )
            plan.add_step(step)
            step = matmul, [var1, var2], var2, False, True
            plan.add_step(step)
            step = squeeze, [var2], var2, [-1, -2]
            plan.add_step(step)
        else:
            # Fallback: elementwise multiply, then sum out the trailing K axes.
            step = multiply, [var1, var2], var2
            plan.add_step(step)
            reduce_dims = list(range(-k, 0))
            plan_reduce(plan, op2, reduce_dims, keepdim=False)
    # Wrap up, updating auxiliary data
    # Updating g_mask for I and J axes
    for ax in I + J1 + J2:
        op2_mask[ax] = vshape[ax] > 1 or vshape[ax] == -1
    # Contracted axes no longer exist in the result.
    for ax in K:
        op2_mask[ax] = False
    # The result's dimensions are laid out as [I..., J1..., J2...].
    for ax in range(len(op2_view)):
        op2_view[ax] = -1
    dim = 0
    for ax in I + J1 + J2:
        op2_view[ax], dim = dim, dim + 1
    g_view[op2] = list(op2_view)
  390. def plan_summation(
  391. plan, g_view, op1, op2, g_supports, g_shape, g_count, n_bcast
  392. ):
  393. '''
  394. Plan various kinds of summation
  395. '''
  396. op1_view, op2_view = g_view[op1], g_view[op2]
  397. op1_mask, op2_mask = g_supports[op1], g_supports[op2]
  398. ndim = len(op1_view)
  399. nout = ndim - len(g_count)
  400. count = [0] * nout + g_count
  401. I, K, J1, J2 = list(range(n_bcast)), [], [], []
  402. for ax, dim1, dim2 in zip(
  403. range(n_bcast, ndim), op1_view[n_bcast:], op2_view[n_bcast:]
  404. ):
  405. if (dim1 != -1) != (dim2 != -1):
  406. if dim1 != -1:
  407. J1.append(ax)
  408. else:
  409. J2.append(ax)
  410. elif dim1 != -1:
  411. fold = int(op1_mask[ax]) + int(op2_mask[ax])
  412. if ax >= nout and fold == count[ax]:
  413. # Ready to fold the dimensions
  414. K.append(ax)
  415. count[ax] -= fold
  416. else:
  417. I.append(ax)
  418. count[ax] -= max(fold - 1, 0)
  419. # Update g_count
  420. g_count[:] = count[nout:]
  421. # Now it's OK to merge the K dims as the same shape holds
  422. # print(f'I: {I} J1: {J1} J2: {J2} K: {K}')
  423. plan_matmul(plan, g_view, op1, op2, g_supports, g_shape, I, J1, J2, K)
  424. def rearrange(axes):
  425. perm, fill = [], []
  426. for ax, dim in enumerate(axes):
  427. if dim < 0:
  428. fill.append(ax)
  429. else:
  430. perm.append(dim)
  431. # Trivial permutation returns []
  432. if all(i == dim for i, dim in enumerate(perm)):
  433. perm = []
  434. return perm, fill
  435. def plan_broadcast(plan, operands, nop_axes):
  436. '''
  437. Plan broadcast across
  438. '''
  439. nop = len(operands)
  440. varnames = [f'op{i}' for i in range(nop)]
  441. for i, op_axes in zip(range(nop), nop_axes):
  442. # Re-arrange the dimensions according to the global layout
  443. perm, fill = rearrange(op_axes)
  444. var = varnames[i]
  445. if perm:
  446. step = transpose, [var], var, perm
  447. plan.add_step(step)
  448. if fill:
  449. step = unsqueeze, [var], var, fill
  450. plan.add_step(step)
  451. def f(*args):
  452. expr = ' * '.join(varnames)
  453. return eval(expr, dict(zip(varnames, args)))
  454. step = f, varnames, None
  455. plan.add_step(step)
  456. class Plan:
  457. def __init__(self):
  458. self.env = {}
  459. self.steps = []
  460. def add_step(self, step):
  461. self.steps.append(step)
  462. def get_var(self, varname):
  463. return self.env[varname] if varname in self.env else None
  464. def set_var(self, varname, var):
  465. self.env[varname] = var
  466. def show(self):
  467. res = None
  468. for f, in_varnames, out_varname, *args in self.steps:
  469. print(repr((out_varname, f, *in_varnames, *args)))
  470. return res
  471. def execute(self):
  472. res = None
  473. for f, in_varnames, out_varname, *args in self.steps:
  474. res = f(*map(self.get_var, in_varnames), *args)
  475. if out_varname:
  476. self.set_var(out_varname, res)
  477. return res
def plan_einsum(operands, g_view, g_shape, g_supports, g_count, n_bcast):
    '''
    Plans the actual execution steps.

    Operands are bound to plan variables ``op0 .. op{n-1}``. Contractions
    are folded pairwise from left to right, so the final result lives in
    ``op{n-1}``.

    Note: ``g_count``, the masks in ``g_supports`` and ``g_view[-1]`` are
    modified in place.

    Returns
    -------
    the execution plan (a Plan instance)
    '''
    nop = len(operands)
    ndim = len(g_view[0])
    # Output axes come first in the global layout; the rest are contracted.
    nout = ndim - len(g_count)
    # Initialize a plan with an environment
    plan = Plan()
    op_names = [f'op{i}' for i in range(nop)]
    list(map(plan.set_var, op_names, operands))
    # In case no dimensions to combine, do broadcast straight across
    if not g_count:
        plan_broadcast(plan, operands, g_view)
        return plan
    # Down count degenerate contraction dimensions.
    for view, support in zip(g_view, g_supports):
        # To collect the down count number, we use a type casting trick:
        # 1 when the operand carries the axis (d != -1) but it is squeezable.
        down_count = [
            int((d + 1) and (not s))
            for d, s in zip(view[nout:], support[nout:])
        ]
        for i, count in enumerate(down_count):
            g_count[i] -= count
    # Reduce any dimension for which g_support is set and g_count == 1
    for i, view, mask in zip(range(nop), g_view, g_supports):
        to_reduce = []
        for dim, masked, count in zip(view[nout:], mask[nout:], g_count):
            to_reduce.append(dim if (masked and count == 1) else -1)
        reduce_dims = list(filter(lambda x: x > -1, to_reduce))
        if reduce_dims:
            # keepdim=True so the operand's view indices stay valid.
            plan_reduce(plan, i, reduce_dims, keepdim=True)
        # Unset mask and decrease g_count for the reduced dimensions
        # NOTE: this inner loop rebinds `i`; the outer loop is unaffected
        # because its `i` comes from the zip iterator.
        for i, d in enumerate(to_reduce):
            ax = i + nout
            mask[ax] = mask[ax] and (d == -1)
            g_count[i] -= 0 if d == -1 else 1
    # Plan the summations over the operand sequence
    for i in range(nop):
        # plan a single step: fold operand i-1 into operand i
        if i == 0:
            continue
        # We'd like to arrange the dimensions in the following way:
        # [I... J... K...]
        # [I... J... K...]
        # where
        # I... are aligned and not to be combined immediately
        # J... are not aligned and not to be combined immediately
        # K... are aligned and should be immediately combined
        # At this point the non-trivial broadcast dimensions in K are already reduced
        # and removed. That means all K dimensions are aligned and their sizes are not 1.
        # We then inspect the layout of I,J,K plus the above observation to make
        # specialization decisions. The current strategy is set as follows:
        # (1) if I... J... K... are all empty, it's multiplying a scalar
        # (2) if K... are empty, better use a broadcast
        # (3) if I... J... empty and K... not empty, a vector-vector multiply (or a dot)
        # (4) Elsewise, either I... or J... not empty, and K... not empty, use a general matmul
        # Resolve the summation kind: dot, matmul or *
        if not any(g_supports[i - 1]):
            # op1 is a one element tensor.
            plan_scalar_prod(plan, i - 1, i)
        else:
            plan_summation(
                plan, g_view, i - 1, i, g_supports, g_shape, g_count, n_bcast
            )
    # After all summations no contraction axis of the last operand may still
    # be non-squeezable.
    assert all(not masked for masked in g_supports[nop - 1][nout:])
    view = g_view[-1]
    # If the output axes are out of order, transpose, renumber the view to
    # the new dimension order, and unsqueeze output axes the result lacks.
    if any(ax != dim for ax, dim in enumerate(view[:nout])):
        perm = [dim for dim in view if dim >= 0]
        if sorted(perm) != perm:
            varname = f'op{nop - 1}'
            step = transpose, [varname], varname, perm
            plan.add_step(step)
        dim = 0
        unsqueeze_dims = []
        for ax, d in enumerate(view):
            if d != -1:
                view[ax], dim = dim, dim + 1
        for ax, d in enumerate(view[:nout]):
            if d == -1:
                unsqueeze_dims.append(ax)
        if unsqueeze_dims:
            varname = f'op{nop - 1}'
            step = unsqueeze, [varname], varname, unsqueeze_dims
            plan.add_step(step)
    # Contraction axes that still occupy a dimension (e.g. kept by a
    # keepdim=True reduce) are squeezed away.
    squeeze_dims = [dim for dim in view[nout:] if dim != -1]
    if squeeze_dims:
        varname = f'op{nop - 1}'
        step = squeeze, [varname], varname, squeeze_dims
        plan.add_step(step)
    return plan
  575. def replace_ellipsis(left_equation, rhs, *operands):
  576. """
  577. we replace ... as unused variables to simplify the EinsumOp implementation.
  578. """
  579. ellipsis_strings = None
  580. max_ndim = 0
  581. new_operands = []
  582. unused_variables = {chr(c) for c in range(ord('a'), ord('z'))}
  583. for equ, operand in zip(left_equation.split(','), operands):
  584. ndims = len(operand.shape) - len(equ.replace("...", ""))
  585. max_ndim = max(max_ndim, ndims)
  586. for c in equ:
  587. unused_variables.discard(c)
  588. for equ, operand in zip(left_equation.split(','), operands):
  589. if '...' in equ:
  590. start_unsqueeze_idx = equ.index('...')
  591. to_squeeze_num = max_ndim - (
  592. len(operand.shape) - len(equ.replace("...", ""))
  593. )
  594. operand = unsqueeze(
  595. operand,
  596. axis=[i + start_unsqueeze_idx for i in range(to_squeeze_num)],
  597. )
  598. new_operands.append(operand)
  599. operands = new_operands
  600. ellipsis_strings = ''.join(unused_variables.pop() for _ in range(max_ndim))
  601. if ellipsis_strings is not None:
  602. left_equation = left_equation.replace('...', ellipsis_strings)
  603. rhs = rhs.replace('...', ellipsis_strings)
  604. return left_equation, rhs, operands
  605. def preprocess(equation, *operands):
  606. """
  607. check equation / raise error, default right labels generation
  608. """
  609. equation = equation.replace(" ", "")
  610. nop = len(operands)
  611. assert nop > 0, (
  612. "Required at least one operand in Einsum API, but received %s " % nop
  613. )
  614. # Part the equation to left hand side and right hand side
  615. lhs, *rhs = equation.lower().split('->')
  616. assert len(rhs) < 2, "Invalid equation: multiple `->` were found."
  617. labels = parse_labels(lhs, operands)
  618. # Note, we distinguish between 'ij->' and 'ij' by setting rhs to '' and None
  619. rhs = rhs[0] if rhs else None
  620. if rhs is None:
  621. rhs = rhs_inference(lhs)
  622. assert len(lhs.split(',')) == len(operands), (
  623. f"Invalid equation: the number of operands is {len(operands)}, "
  624. f"but found {len(lhs.split(','))} segments in the label equation."
  625. )
  626. assert not (
  627. '...' in lhs and '...' not in rhs
  628. ), 'Invalid equation: missing ellipsis in output labels.'
  629. lhs, rhs, operands = replace_ellipsis(lhs, rhs, *operands)
  630. return lhs, rhs, labels, operands
  631. def parse_fake_shape(equation, operands, labels):
  632. """
  633. this shape is just used for operands planning. may differ with the original shape.
  634. for example:
  635. ... is replaced by 1
  636. -1 is replaced by 1
  637. Results
  638. -------
  639. list of shape
  640. """
  641. origin_labels = (x.strip() for x in equation.split(','))
  642. shaped = collections.namedtuple('shaped', ['shape'])
  643. def fake_shape(ori_label, label, op):
  644. """
  645. 1. ori_label is the original labels, not aligned by '....'
  646. 2. if the '...' is evaluated to empty list, there is no '.' in label
  647. """
  648. assert len(op.shape) == len(label), (
  649. "length of shape and length of label must be the same, but received %d != %d"
  650. % (len(op.shape), len(label))
  651. )
  652. fakes = [s for i, (l, s) in enumerate(zip(label, op.shape)) if l != '.']
  653. fakes = list(map(abs, fakes)) # make -1 -> 1
  654. if '.' in ori_label:
  655. fakes.insert(ori_label.index('.'), 1)
  656. return shaped(fakes)
  657. out = list(map(fake_shape, origin_labels, labels, operands))
  658. return out
  659. def rhs_inference(lhs):
  660. def is_free(key):
  661. return cnt.get(key) == 1 and key not in ['.', ',']
  662. cnt = collections.Counter(lhs)
  663. rhs = "..." if '...' in lhs else ""
  664. rhs = rhs + "".join(filter(is_free, sorted(cnt.elements())))
  665. return rhs
  666. def gen_equation_for_opteinsum(lhs, rhs):
  667. """
  668. 1. gen rhs if rhs is None
  669. 2. '...' -> 'A'
  670. """
  671. def get_used_label(counter):
  672. used = set(counter.elements())
  673. for c in string.ascii_lowercase:
  674. if c not in used:
  675. return c
  676. raise ValueError(
  677. "You have used all `a` - `z`, there can't find a unused char for einsum optimization"
  678. )
  679. cnt = collections.Counter(lhs)
  680. broadcast_label = get_used_label(cnt)
  681. if rhs is None:
  682. rhs = rhs_inference(lhs)
  683. lhs = lhs.replace("...", broadcast_label)
  684. rhs = rhs.replace("...", broadcast_label)
  685. return lhs + "->" + rhs, broadcast_label
  686. def einsum_v2(equation, *operands):
  687. """
  688. einsum v2 implementation.
  689. 1. Implement C++ EinsumOp.
  690. 2. V2 create the EinsumOp to calculate, so just a little verify work in python.
  691. 3. V2 use opt_einsum.contract_path to optimize the multivariable einsum.
  692. """
  693. n_op = len(operands)
  694. lhs, rhs, labels, operands = preprocess(equation, *operands)
  695. if n_op <= 2:
  696. return gen_einsum_op(lhs + '->' + rhs, *operands)
  697. shapes = parse_fake_shape(lhs, operands, labels)
  698. opt_equation, broadcast_label = gen_equation_for_opteinsum(lhs, rhs)
  699. _, cons = opt_einsum.contract_path(opt_equation, *shapes, einsum_call=True)
  700. var_list = list(operands)
  701. for path in cons:
  702. (a, b), _, eq, *__ = path
  703. assert (
  704. a > b
  705. ), "Assume the first var_idx is smaller than the second_idx. opt_einsum can guarantee it."
  706. var_s = [var_list.pop(a), var_list.pop(b)]
  707. eq = eq.replace(broadcast_label, "...")
  708. var_list.append(gen_einsum_op(eq, *var_s))
  709. assert (
  710. len(var_list) == 1
  711. ), "There must be one elements in list, but received %d." % len(var_list)
  712. return var_list[0]
def gen_einsum_op(equation, *operands):
    """
    EinsumOp Python Interface: run or append a single einsum op.

    In dynamic / PIR mode this calls ``_C_ops.einsum`` and returns the first
    element of its result. Otherwise it appends an 'einsum' op through
    LayerHelper and returns the created output variable.

    In that static-graph path:
    - at most two operands are accepted;
    - each operand must be float32 or float64;
    - ``equation`` must be a str (violations are reported by
      check_variable_and_dtype / check_type).
    """
    if in_dynamic_or_pir_mode():
        # NOTE(review): presumably index 0 is the einsum output and the
        # remaining entries are auxiliary outputs — confirm against _C_ops.
        return _C_ops.einsum(operands, equation)[0]
    else:
        assert len(operands) <= 2, "Only support two operands in EinsumOp."
        for inp in operands:
            check_variable_and_dtype(
                inp, 'dtype', ['float32', 'float64'], 'einsum'
            )
        check_type(equation, 'equation', str, 'einsum')
        # NOTE: locals() forwards every local defined so far (equation,
        # operands, inp) to LayerHelper; adding or renaming locals above
        # this line changes what LayerHelper receives.
        helper = LayerHelper('einsum', **locals())
        out = helper.create_variable_for_type_inference(dtype=operands[0].dtype)
        attrs = {}
        attrs['equation'] = equation
        # One "InnerCache" and one "XShape" output variable per operand.
        caches = [
            helper.create_variable_for_type_inference(dtype=operands[0].dtype)
            for i in range(len(operands))
        ]
        xshape = [
            helper.create_variable_for_type_inference(dtype=operands[0].dtype)
            for i in range(len(operands))
        ]
        helper.append_op(
            type='einsum',
            inputs={'Operands': operands},
            outputs={'Out': out, "InnerCache": caches, "XShape": xshape},
            attrs=attrs,
        )
        return out
  745. def einsum(equation, *operands):
  746. r"""
  747. einsum(equation, *operands)
  748. The current version of this API should be used in dynamic graph only mode.
  749. Einsum offers a tensor operation API which allows using the Einstein summation
  750. convention or Einstain notation. It takes as input one or multiple tensors and
  751. produces as output one tensor.
  752. Einsum is able to perform a variety of tensor operations. Following lists a few:
  753. - for single operand
  754. - trace
  755. - diagonal
  756. - transpose
  757. - sum
  758. - for double operands
  759. - dot
  760. - outer
  761. - broadcasting and elementwise multiply
  762. - matrix multiply
  763. - batched matrix multiply
  764. - for many operads
  765. - broadcasting multiply
  766. - chained matrix multiply
  767. **The summation notation**
  768. - The tensor dimensions are labeled using uncased English letters. E.g., `ijk`
  769. relates to a three dimensional tensor whose dimensions are labeled i, j, and k.
  770. - The equation is `,` separated into terms, each being a distinct input's
  771. dimension label string.
  772. - Ellipsis `...` enables broadcasting by automatically converting the unlabeled
  773. dimensions into broadcasting dimensions.
  774. - Singular labels are called free labels, duplicate are dummy labels. Dummy labeled
  775. dimensions will be reduced and removed in the output.
  776. - Output labels can be explicitly specified on the right hand side of `->` or omitted.
  777. In the latter case, the output labels will be inferred from the input labels.
  778. - Inference of output labels
  779. - Broadcasting label `...`, if present, is put on the leftmost position.
  780. - Free labels are reordered alphabetically and put after `...`.
  781. - On explicit output labels
  782. - If broadcasting is enabled, then `...` must be present.
  783. - The output labels can be an empty, an indication to output as a scalar
  784. the sum over the original output.
  785. - Non-input labels are invalid.
  786. - Duplicate labels are invalid.
  787. - For any dummy label which is present for the output, it's promoted to
  788. a free label.
  789. - For any free label which is not present for the output, it's lowered to
  790. a dummy label.
  791. - Examples
  792. - '...ij, ...jk', where i and k are free labels, j is dummy. The output label
  793. string is '...ik'
  794. - 'ij -> i', where i is a free label and j is a dummy label.
  795. - '...ij, ...jk -> ...ijk', where i, j and k are all free labels.
  796. - '...ij, ...jk -> ij', an invalid equation since `...` is not present for
  797. the output.
  798. **The summation rule**
  799. The summation procedure can be outlined as follows, although the actual steps taken
  800. may vary significantly due to implementation specific optimization.
  801. - Step 1: preparation for broadcasting, that is, transposing and unsqueezing
  802. the input operands to have each resulting dimension identically labeled across
  803. all the input operands.
  804. - Step 2: broadcasting multiply all the resulting operands from step 1.
  805. - Step 3: reducing dummy labeled dimensions.
  806. - Step 4: transposing the result tensor to match the output labels.
  807. **On trace and diagonal**
  808. The trace and diagonal are planned yet unimplemented features.
  809. Args:
  810. equation (`str`):
  811. The summation terms using the Einstein summation notation.
  812. operands (`list|Tensor`):
  813. The input tensors over which to compute the Einstein summation. The number of
  814. operands should equal the number of input terms in the equation.
  815. Returns:
  816. result (`Tensor`), the result tensor.
  817. Examples:
  818. .. code-block:: python
  819. >>> import paddle
  820. >>> paddle.seed(102)
  821. >>> x = paddle.rand([4])
  822. >>> y = paddle.rand([5])
  823. >>> # sum
  824. >>> print(paddle.einsum('i->', x))
  825. Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
  826. 1.81225157)
  827. >>> # dot
  828. >>> print(paddle.einsum('i,i->', x, x))
  829. Tensor(shape=[], dtype=float32, place=Place(cpu), stop_gradient=True,
  830. 1.13530672)
  831. >>> # outer
  832. >>> print(paddle.einsum("i,j->ij", x, y))
  833. Tensor(shape=[4, 5], dtype=float32, place=Place(cpu), stop_gradient=True,
  834. [[0.26443148, 0.05962684, 0.25360870, 0.21900642, 0.56994802],
  835. [0.20955276, 0.04725220, 0.20097610, 0.17355499, 0.45166403],
  836. [0.35836059, 0.08080698, 0.34369346, 0.29680005, 0.77240014],
  837. [0.00484230, 0.00109189, 0.00464411, 0.00401047, 0.01043695]])
  838. >>> A = paddle.rand([2, 3, 2])
  839. >>> B = paddle.rand([2, 2, 3])
  840. >>> # transpose
  841. >>> print(paddle.einsum('ijk->kji', A))
  842. Tensor(shape=[2, 3, 2], dtype=float32, place=Place(cpu), stop_gradient=True,
  843. [[[0.50882483, 0.56067896],
  844. [0.84598064, 0.36310029],
  845. [0.55289471, 0.33273944]],
  846. [[0.04836850, 0.73811269],
  847. [0.29769155, 0.28137168],
  848. [0.84636718, 0.67521429]]])
  849. >>> # batch matrix multiplication
  850. >>> print(paddle.einsum('ijk, ikl->ijl', A,B))
  851. Tensor(shape=[2, 3, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
  852. [[[0.36321065, 0.42009076, 0.40849245],
  853. [0.74353045, 0.79189068, 0.81345987],
  854. [0.90488225, 0.79786193, 0.93451476]],
  855. [[0.12680580, 1.06945944, 0.79821426],
  856. [0.07774551, 0.55068684, 0.44512171],
  857. [0.08053084, 0.80583858, 0.56031936]]])
  858. >>> # Ellipsis transpose
  859. >>> print(paddle.einsum('...jk->...kj', A))
  860. Tensor(shape=[2, 2, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
  861. [[[0.50882483, 0.84598064, 0.55289471],
  862. [0.04836850, 0.29769155, 0.84636718]],
  863. [[0.56067896, 0.36310029, 0.33273944],
  864. [0.73811269, 0.28137168, 0.67521429]]])
  865. >>> # Ellipsis batch matrix multiplication
  866. >>> print(paddle.einsum('...jk, ...kl->...jl', A,B))
  867. Tensor(shape=[2, 3, 3], dtype=float32, place=Place(cpu), stop_gradient=True,
  868. [[[0.36321065, 0.42009076, 0.40849245],
  869. [0.74353045, 0.79189068, 0.81345987],
  870. [0.90488225, 0.79786193, 0.93451476]],
  871. [[0.12680580, 1.06945944, 0.79821426],
  872. [0.07774551, 0.55068684, 0.44512171],
  873. [0.08053084, 0.80583858, 0.56031936]]])
  874. """
  875. import os
  876. if int(os.environ.get('FLAGS_new_einsum', "1")):
  877. return einsum_v2(equation, *operands)
  878. nop = len(operands)
  879. assert nop > 0, "At least one operand is expected."
  880. # Part the equation to left hand side and right hand side
  881. lhs, *rhs = equation.lower().replace(' ', '').split('->')
  882. assert len(rhs) < 2, "Invalid equation: multiple `->` were found."
  883. # Note, we distinguish between 'ij->' and 'ij' by setting rhs to '' and None
  884. rhs = rhs[0] if rhs else None
  885. # Parse labels for each operand and count the number of occurrences for each alphabet label
  886. nop_labels = parse_labels(lhs, operands)
  887. # Diagonalize the operands which have duplicate labels
  888. nop_labels, operands = list(zip(*map(diagonalize, nop_labels, operands)))
  889. # To handle broadcasting, we should first know how many dimensions are there
  890. # We need to use that number to generate output labels
  891. # e.g. 1 for ['ij', 'i.', '.k']
  892. n_bcast_dims = max(s.count('.') for s in nop_labels)
  893. # Build the data structures for planning. It's helpful to think of all the operands
  894. # broadcasting together from a global view. In this view, dimensions from multiple
  895. # operands are mapped to the same position if they are labeled uniquely. Broadcasting
  896. # dimensions are mapped to adjacent positions with the right bound fixed. Subject to
  897. # each operand, the map is injective but for all operands the map is on-to.
  898. # g_labels:
  899. # The labels of the global view
  900. # g_view:
  901. # Includes a list of maps from each operand's dimensions to the global view's dimensions
  902. # which we refer to as ax or axes in the code to distinguish from operand's dims
  903. # g_shape:
  904. # The shape of the global view. The size of each dimension is what the aligned dimensions
  905. # should broadcast to
  906. # g_nout:
  907. # Number of output axes
  908. # g_supports
  909. # Booleans indicating each operand's non-trivial dimensions
  910. # g_count
  911. # Counting how many non-trivial dimensions remain for each ax
  912. g_labels, g_view, g_nout, g_count = build_global_view(
  913. nop_labels, rhs, n_bcast_dims
  914. )
  915. g_shape, g_supports = build_global_shape(
  916. g_view, g_labels, [op.shape for op in operands]
  917. )
  918. # Now we're ready to build up an execution plan
  919. args = operands, g_view, g_shape, g_supports, g_count, n_bcast_dims
  920. plan = plan_einsum(*args)
  921. result = plan.execute()
  922. return result