functional.py 85 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195
  1. # mypy: allow-untyped-defs
  2. import itertools
  3. import operator
  4. from collections.abc import Sequence
  5. from typing import Any, Optional, TYPE_CHECKING, Union
  6. import torch
  7. import torch.nn.functional as F
  8. from torch import _VF, Tensor
  9. from torch._C import _add_docstr
  10. from torch._jit_internal import _overload as overload, boolean_dispatch
  11. from torch._lowrank import pca_lowrank, svd_lowrank
  12. from torch.overrides import (
  13. handle_torch_function,
  14. has_torch_function,
  15. has_torch_function_unary,
  16. has_torch_function_variadic,
  17. )
# Public API of this module: the names a star-import (``from <module> import *``)
# exports. Several of these are defined outside the visible portion of the file.
__all__ = [
    "atleast_1d",
    "atleast_2d",
    "atleast_3d",
    "align_tensors",
    "broadcast_shapes",
    "broadcast_tensors",
    "cartesian_prod",
    "block_diag",
    "cdist",
    "chain_matmul",
    "einsum",
    "istft",
    "lu",
    "norm",
    "meshgrid",
    "pca_lowrank",
    "split",
    "stft",
    "svd_lowrank",
    "tensordot",
    "unique",
    "unique_consecutive",
    "unravel_index",
]
  43. def broadcast_tensors(*tensors):
  44. r"""broadcast_tensors(*tensors) -> List of Tensors
  45. Broadcasts the given tensors according to :ref:`broadcasting-semantics`.
  46. Args:
  47. *tensors: any number of tensors of the same type
  48. .. warning::
  49. More than one element of a broadcasted tensor may refer to a single
  50. memory location. As a result, in-place operations (especially ones that
  51. are vectorized) may result in incorrect behavior. If you need to write
  52. to the tensors, please clone them first.
  53. Example::
  54. >>> x = torch.arange(3).view(1, 3)
  55. >>> y = torch.arange(2).view(2, 1)
  56. >>> a, b = torch.broadcast_tensors(x, y)
  57. >>> a.size()
  58. torch.Size([2, 3])
  59. >>> a
  60. tensor([[0, 1, 2],
  61. [0, 1, 2]])
  62. """
  63. # This wrapper exists to support variadic args.
  64. if has_torch_function(tensors):
  65. return handle_torch_function(broadcast_tensors, tensors, *tensors)
  66. return _VF.broadcast_tensors(tensors) # type: ignore[attr-defined]
  67. def broadcast_shapes(*shapes):
  68. r"""broadcast_shapes(*shapes) -> Size
  69. Similar to :func:`broadcast_tensors` but for shapes.
  70. This is equivalent to
  71. ``torch.broadcast_tensors(*map(torch.empty, shapes))[0].shape``
  72. but avoids the need create to intermediate tensors. This is useful for
  73. broadcasting tensors of common batch shape but different rightmost shape,
  74. e.g. to broadcast mean vectors with covariance matrices.
  75. Example::
  76. >>> torch.broadcast_shapes((2,), (3, 1), (1, 1, 1))
  77. torch.Size([1, 3, 2])
  78. Args:
  79. \*shapes (torch.Size): Shapes of tensors.
  80. Returns:
  81. shape (torch.Size): A shape compatible with all input shapes.
  82. Raises:
  83. RuntimeError: If shapes are incompatible.
  84. """
  85. # This wrapper exists to support variadic args.
  86. # TODO Move this to C++ once the jit has better support for torch.Size.
  87. if not torch.jit.is_tracing():
  88. result = torch._refs._broadcast_shapes(*shapes)
  89. if result is None:
  90. return torch.Size([])
  91. return torch.Size(result)
  92. else:
  93. # with implementation above, torch.jit.trace hardcodes the sizes which makes subsequent replays fail
  94. with torch.no_grad():
  95. scalar = torch.zeros((), device="cpu")
  96. tensors = [scalar.expand(shape) for shape in shapes]
  97. tensors = broadcast_tensors(*tensors)
  98. return tensors[0].shape
  99. def split(
  100. tensor: Tensor,
  101. split_size_or_sections: Union[int, list[int]],
  102. dim: int = 0,
  103. ) -> tuple[Tensor, ...]:
  104. r"""Splits the tensor into chunks. Each chunk is a view of the original tensor.
  105. If :attr:`split_size_or_sections` is an integer type, then :attr:`tensor` will
  106. be split into equally sized chunks (if possible). Last chunk will be smaller if
  107. the tensor size along the given dimension :attr:`dim` is not divisible by
  108. :attr:`split_size`.
  109. If :attr:`split_size_or_sections` is a list, then :attr:`tensor` will be split
  110. into ``len(split_size_or_sections)`` chunks with sizes in :attr:`dim` according
  111. to :attr:`split_size_or_sections`.
  112. Args:
  113. tensor (Tensor): tensor to split.
  114. split_size_or_sections (int) or (list(int)): size of a single chunk or
  115. list of sizes for each chunk
  116. dim (int): dimension along which to split the tensor.
  117. Example::
  118. >>> a = torch.arange(10).reshape(5, 2)
  119. >>> a
  120. tensor([[0, 1],
  121. [2, 3],
  122. [4, 5],
  123. [6, 7],
  124. [8, 9]])
  125. >>> torch.split(a, 2)
  126. (tensor([[0, 1],
  127. [2, 3]]),
  128. tensor([[4, 5],
  129. [6, 7]]),
  130. tensor([[8, 9]]))
  131. >>> torch.split(a, [1, 4])
  132. (tensor([[0, 1]]),
  133. tensor([[2, 3],
  134. [4, 5],
  135. [6, 7],
  136. [8, 9]]))
  137. """
  138. if has_torch_function_unary(tensor):
  139. return handle_torch_function(
  140. split, (tensor,), tensor, split_size_or_sections, dim=dim
  141. )
  142. # Overwriting reason:
  143. # This dispatches to two ATen functions depending on the type of
  144. # split_size_or_sections. The branching code is in _tensor.py, which we
  145. # call here.
  146. return tensor.split(split_size_or_sections, dim)
  147. def einsum(*args: Any) -> Tensor:
  148. r"""einsum(equation, *operands) -> Tensor
  149. Sums the product of the elements of the input :attr:`operands` along dimensions specified using a notation
  150. based on the Einstein summation convention.
  151. Einsum allows computing many common multi-dimensional linear algebraic array operations by representing them
  152. in a short-hand format based on the Einstein summation convention, given by :attr:`equation`. The details of
  153. this format are described below, but the general idea is to label every dimension of the input :attr:`operands`
  154. with some subscript and define which subscripts are part of the output. The output is then computed by summing
  155. the product of the elements of the :attr:`operands` along the dimensions whose subscripts are not part of the
  156. output. For example, matrix multiplication can be computed using einsum as `torch.einsum("ij,jk->ik", A, B)`.
  157. Here, j is the summation subscript and i and k the output subscripts (see section below for more details on why).
  158. Equation:
  159. The :attr:`equation` string specifies the subscripts (letters in `[a-zA-Z]`) for each dimension of
  160. the input :attr:`operands` in the same order as the dimensions, separating subscripts for each operand by a
  161. comma (','), e.g. `'ij,jk'` specify subscripts for two 2D operands. The dimensions labeled with the same subscript
  162. must be broadcastable, that is, their size must either match or be `1`. The exception is if a subscript is
  163. repeated for the same input operand, in which case the dimensions labeled with this subscript for this operand
  164. must match in size and the operand will be replaced by its diagonal along these dimensions. The subscripts that
  165. appear exactly once in the :attr:`equation` will be part of the output, sorted in increasing alphabetical order.
  166. The output is computed by multiplying the input :attr:`operands` element-wise, with their dimensions aligned based
  167. on the subscripts, and then summing out the dimensions whose subscripts are not part of the output.
  168. Optionally, the output subscripts can be explicitly defined by adding an arrow ('->') at the end of the equation
  169. followed by the subscripts for the output. For instance, the following equation computes the transpose of a
  170. matrix multiplication: 'ij,jk->ki'. The output subscripts must appear at least once for some input operand and
  171. at most once for the output.
  172. Ellipsis ('...') can be used in place of subscripts to broadcast the dimensions covered by the ellipsis.
  173. Each input operand may contain at most one ellipsis which will cover the dimensions not covered by subscripts,
  174. e.g. for an input operand with 5 dimensions, the ellipsis in the equation `'ab...c'` cover the third and fourth
  175. dimensions. The ellipsis does not need to cover the same number of dimensions across the :attr:`operands` but the
  176. 'shape' of the ellipsis (the size of the dimensions covered by them) must broadcast together. If the output is not
  177. explicitly defined with the arrow ('->') notation, the ellipsis will come first in the output (left-most dimensions),
  178. before the subscript labels that appear exactly once for the input operands. e.g. the following equation implements
  179. batch matrix multiplication `'...ij,...jk'`.
  180. A few final notes: the equation may contain whitespaces between the different elements (subscripts, ellipsis,
  181. arrow and comma) but something like `'. . .'` is not valid. An empty string `''` is valid for scalar operands.
  182. .. note::
  183. ``torch.einsum`` handles ellipsis ('...') differently from NumPy in that it allows dimensions
  184. covered by the ellipsis to be summed over, that is, ellipsis are not required to be part of the output.
  185. .. note::
  186. Please install opt-einsum (https://optimized-einsum.readthedocs.io/en/stable/) in order to enroll into a more
  187. performant einsum. You can install when installing torch like so: `pip install torch[opt-einsum]` or by itself
  188. with `pip install opt-einsum`.
  189. If opt-einsum is available, this function will automatically speed up computation and/or consume less memory
  190. by optimizing contraction order through our opt_einsum backend :mod:`torch.backends.opt_einsum` (The _ vs - is
  191. confusing, I know). This optimization occurs when there are at least three inputs, since the order does not matter
  192. otherwise. Note that finding `the` optimal path is an NP-hard problem, thus, opt-einsum relies on different
  193. heuristics to achieve near-optimal results. If opt-einsum is not available, the default order is to contract
  194. from left to right.
  195. To bypass this default behavior, add the following to disable opt_einsum and skip path calculation:
  196. ``torch.backends.opt_einsum.enabled = False``
  197. To specify which strategy you'd like for opt_einsum to compute the contraction path, add the following line:
  198. ``torch.backends.opt_einsum.strategy = 'auto'``. The default strategy is 'auto', and we also support 'greedy' and
  199. 'optimal'. Disclaimer that the runtime of 'optimal' is factorial in the number of inputs! See more details in
  200. the opt_einsum documentation (https://optimized-einsum.readthedocs.io/en/stable/path_finding.html).
  201. .. note::
  202. As of PyTorch 1.10 :func:`torch.einsum` also supports the sublist format (see examples below). In this format,
  203. subscripts for each operand are specified by sublists, list of integers in the range [0, 52). These sublists
  204. follow their operands, and an extra sublist can appear at the end of the input to specify the output's
  205. subscripts., e.g. `torch.einsum(op1, sublist1, op2, sublist2, ..., [subslist_out])`. Python's `Ellipsis` object
  206. may be provided in a sublist to enable broadcasting as described in the Equation section above.
  207. Args:
  208. equation (str): The subscripts for the Einstein summation.
  209. operands (List[Tensor]): The tensors to compute the Einstein summation of.
  210. Examples::
  211. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  212. >>> # trace
  213. >>> torch.einsum('ii', torch.randn(4, 4))
  214. tensor(-1.2104)
  215. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  216. >>> # diagonal
  217. >>> torch.einsum('ii->i', torch.randn(4, 4))
  218. tensor([-0.1034, 0.7952, -0.2433, 0.4545])
  219. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  220. >>> # outer product
  221. >>> x = torch.randn(5)
  222. >>> y = torch.randn(4)
  223. >>> torch.einsum('i,j->ij', x, y)
  224. tensor([[ 0.1156, -0.2897, -0.3918, 0.4963],
  225. [-0.3744, 0.9381, 1.2685, -1.6070],
  226. [ 0.7208, -1.8058, -2.4419, 3.0936],
  227. [ 0.1713, -0.4291, -0.5802, 0.7350],
  228. [ 0.5704, -1.4290, -1.9323, 2.4480]])
  229. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  230. >>> # batch matrix multiplication
  231. >>> As = torch.randn(3, 2, 5)
  232. >>> Bs = torch.randn(3, 5, 4)
  233. >>> torch.einsum('bij,bjk->bik', As, Bs)
  234. tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
  235. [-1.6706, -0.8097, -0.8025, -2.1183]],
  236. [[ 4.2239, 0.3107, -0.5756, -0.2354],
  237. [-1.4558, -0.3460, 1.5087, -0.8530]],
  238. [[ 2.8153, 1.8787, -4.3839, -1.2112],
  239. [ 0.3728, -2.1131, 0.0921, 0.8305]]])
  240. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  241. >>> # with sublist format and ellipsis
  242. >>> torch.einsum(As, [..., 0, 1], Bs, [..., 1, 2], [..., 0, 2])
  243. tensor([[[-1.0564, -1.5904, 3.2023, 3.1271],
  244. [-1.6706, -0.8097, -0.8025, -2.1183]],
  245. [[ 4.2239, 0.3107, -0.5756, -0.2354],
  246. [-1.4558, -0.3460, 1.5087, -0.8530]],
  247. [[ 2.8153, 1.8787, -4.3839, -1.2112],
  248. [ 0.3728, -2.1131, 0.0921, 0.8305]]])
  249. >>> # batch permute
  250. >>> A = torch.randn(2, 3, 4, 5)
  251. >>> torch.einsum('...ij->...ji', A).shape
  252. torch.Size([2, 3, 5, 4])
  253. >>> # equivalent to torch.nn.functional.bilinear
  254. >>> A = torch.randn(3, 5, 4)
  255. >>> l = torch.randn(2, 5)
  256. >>> r = torch.randn(2, 4)
  257. >>> torch.einsum('bn,anm,bm->ba', l, A, r)
  258. tensor([[-0.3430, -5.2405, 0.4494],
  259. [ 0.3311, 5.5201, -3.0356]])
  260. """
  261. import torch.backends.opt_einsum as opt_einsum
  262. # This wrapper exists to support variadic args.
  263. if len(args) < 2:
  264. raise ValueError(
  265. "einsum(): must specify the equation string and at least one operand, "
  266. "or at least one operand and its subscripts list"
  267. )
  268. equation = None
  269. operands = None
  270. if isinstance(args[0], torch.Tensor):
  271. # Convert the subscript list format which is an interleaving of operand and its subscripts
  272. # list with an optional output subscripts list at the end (see documentation for more details on this)
  273. # to the equation string format by creating the equation string from the subscripts list and grouping the
  274. # input operands into a tensorlist (List[Tensor]).
  275. def parse_subscript(n: int) -> str:
  276. if n == Ellipsis:
  277. return "..."
  278. if n >= 0 and n < 26:
  279. return chr(ord("A") + n)
  280. if n >= 26 and n < 52:
  281. return chr(ord("a") + n - 26)
  282. raise ValueError(
  283. "einsum(): subscript in subscript list is not within the valid range [0, 52)"
  284. )
  285. # Parse subscripts for input operands
  286. equation = ",".join("".join(parse_subscript(s) for s in l) for l in args[1::2])
  287. # Parse optional output subscripts (provided when the number of arguments is odd)
  288. if len(args) % 2 == 1:
  289. equation += "->" + "".join(parse_subscript(s) for s in args[-1])
  290. operands = args[:-1:2]
  291. else:
  292. operands = args[::2]
  293. else:
  294. equation = args[0]
  295. operands = args[1:]
  296. if has_torch_function(operands):
  297. return handle_torch_function(einsum, operands, equation, *operands)
  298. if len(operands) == 1 and isinstance(operands[0], (list, tuple)):
  299. # the old interface of passing the operands as one list argument
  300. _operands = operands[0]
  301. # recurse in case operands contains value that has torch function
  302. # in the original implementation this line is omitted
  303. return einsum(equation, *_operands)
  304. if len(operands) <= 2 or not opt_einsum.enabled:
  305. # the path for contracting 0 or 1 time(s) is already optimized
  306. # or the user has disabled using opt_einsum
  307. return _VF.einsum(equation, operands) # type: ignore[attr-defined]
  308. path = None
  309. if opt_einsum.is_available():
  310. _opt_einsum = opt_einsum.get_opt_einsum()
  311. tupled_path = _opt_einsum.contract_path(
  312. equation, *operands, optimize=opt_einsum.strategy
  313. )[0]
  314. # flatten path for dispatching to C++
  315. path = [*itertools.chain.from_iterable(tupled_path)]
  316. return _VF.einsum(equation, operands, path=path) # type: ignore[attr-defined]
  317. # This wrapper exists to support variadic args.
  318. if TYPE_CHECKING:
  319. # The JIT doesn't understand Union, so only add type annotation for mypy
  320. def meshgrid(
  321. *tensors: Union[Tensor, list[Tensor]], indexing: Optional[str] = None
  322. ) -> tuple[Tensor, ...]:
  323. return _meshgrid(*tensors, indexing=indexing)
  324. else:
  325. def meshgrid(*tensors, indexing: Optional[str] = None) -> tuple[Tensor, ...]:
  326. r"""Creates grids of coordinates specified by the 1D inputs in `attr`:tensors.
  327. This is helpful when you want to visualize data over some
  328. range of inputs. See below for a plotting example.
  329. Given :math:`N` 1D tensors :math:`T_0 \ldots T_{N-1}` as
  330. inputs with corresponding sizes :math:`S_0 \ldots S_{N-1}`,
  331. this creates :math:`N` N-dimensional tensors :math:`G_0 \ldots
  332. G_{N-1}`, each with shape :math:`(S_0, ..., S_{N-1})` where
  333. the output :math:`G_i` is constructed by expanding :math:`T_i`
  334. to the result shape.
  335. .. note::
  336. 0D inputs are treated equivalently to 1D inputs of a
  337. single element.
  338. .. warning::
  339. `torch.meshgrid(*tensors)` currently has the same behavior
  340. as calling `numpy.meshgrid(*arrays, indexing='ij')`.
  341. In the future `torch.meshgrid` will transition to
  342. `indexing='xy'` as the default.
  343. https://github.com/pytorch/pytorch/issues/50276 tracks
  344. this issue with the goal of migrating to NumPy's behavior.
  345. .. seealso::
  346. :func:`torch.cartesian_prod` has the same effect but it
  347. collects the data in a tensor of vectors.
  348. Args:
  349. tensors (list of Tensor): list of scalars or 1 dimensional tensors. Scalars will be
  350. treated as tensors of size :math:`(1,)` automatically
  351. indexing: (str, optional): the indexing mode, either "xy"
  352. or "ij", defaults to "ij". See warning for future changes.
  353. If "xy" is selected, the first dimension corresponds
  354. to the cardinality of the second input and the second
  355. dimension corresponds to the cardinality of the first
  356. input.
  357. If "ij" is selected, the dimensions are in the same
  358. order as the cardinality of the inputs.
  359. Returns:
  360. seq (sequence of Tensors): If the input has :math:`N`
  361. tensors of size :math:`S_0 \ldots S_{N-1}``, then the
  362. output will also have :math:`N` tensors, where each tensor
  363. is of shape :math:`(S_0, ..., S_{N-1})`.
  364. Example::
  365. >>> x = torch.tensor([1, 2, 3])
  366. >>> y = torch.tensor([4, 5, 6])
  367. Observe the element-wise pairings across the grid, (1, 4),
  368. (1, 5), ..., (3, 6). This is the same thing as the
  369. cartesian product.
  370. >>> grid_x, grid_y = torch.meshgrid(x, y, indexing='ij')
  371. >>> grid_x
  372. tensor([[1, 1, 1],
  373. [2, 2, 2],
  374. [3, 3, 3]])
  375. >>> grid_y
  376. tensor([[4, 5, 6],
  377. [4, 5, 6],
  378. [4, 5, 6]])
  379. This correspondence can be seen when these grids are
  380. stacked properly.
  381. >>> torch.equal(torch.cat(tuple(torch.dstack([grid_x, grid_y]))),
  382. ... torch.cartesian_prod(x, y))
  383. True
  384. `torch.meshgrid` is commonly used to produce a grid for
  385. plotting.
  386. >>> # xdoctest: +REQUIRES(module:matplotlib)
  387. >>> # xdoctest: +REQUIRES(env:DOCTEST_SHOW)
  388. >>> import matplotlib.pyplot as plt
  389. >>> xs = torch.linspace(-5, 5, steps=100)
  390. >>> ys = torch.linspace(-5, 5, steps=100)
  391. >>> x, y = torch.meshgrid(xs, ys, indexing='xy')
  392. >>> z = torch.sin(torch.sqrt(x * x + y * y))
  393. >>> ax = plt.axes(projection='3d')
  394. >>> ax.plot_surface(x.numpy(), y.numpy(), z.numpy())
  395. >>> plt.show()
  396. .. image:: ../_static/img/meshgrid.png
  397. :width: 512
  398. """
  399. return _meshgrid(*tensors, indexing=indexing)
  400. def _meshgrid(*tensors, indexing: Optional[str]):
  401. if has_torch_function(tensors):
  402. return handle_torch_function(meshgrid, tensors, *tensors, indexing=indexing)
  403. if len(tensors) == 1 and isinstance(tensors[0], (list, tuple)):
  404. # the old interface of passing the operands as one list argument
  405. tensors = tensors[0] # type: ignore[assignment]
  406. # Continue allowing call of old method that takes no indexing
  407. # kwarg for forward compatibility reasons.
  408. #
  409. # Remove this two weeks after landing.
  410. kwargs = {} if indexing is None else {"indexing": indexing}
  411. return _VF.meshgrid(tensors, **kwargs) # type: ignore[attr-defined]
  412. def stft(
  413. input: Tensor,
  414. n_fft: int,
  415. hop_length: Optional[int] = None,
  416. win_length: Optional[int] = None,
  417. window: Optional[Tensor] = None,
  418. center: bool = True,
  419. pad_mode: str = "reflect",
  420. normalized: bool = False,
  421. onesided: Optional[bool] = None,
  422. return_complex: Optional[bool] = None,
  423. align_to_window: Optional[bool] = None,
  424. ) -> Tensor:
  425. r"""Short-time Fourier transform (STFT).
  426. .. warning::
  427. From version 1.8.0, :attr:`return_complex` must always be given
  428. explicitly for real inputs and `return_complex=False` has been
  429. deprecated. Strongly prefer `return_complex=True` as in a future
  430. pytorch release, this function will only return complex tensors.
  431. Note that :func:`torch.view_as_real` can be used to recover a real
  432. tensor with an extra last dimension for real and imaginary components.
  433. .. warning::
  434. From version 2.1, a warning will be provided if a :attr:`window` is
  435. not specified. In a future release, this attribute will be required.
  436. Not providing a window currently defaults to using a rectangular window,
  437. which may result in undesirable artifacts. Consider using tapered windows,
  438. such as :func:`torch.hann_window`.
  439. The STFT computes the Fourier transform of short overlapping windows of the
  440. input. This giving frequency components of the signal as they change over
  441. time. The interface of this function is modeled after (but *not* a drop-in
  442. replacement for) librosa_ stft function.
  443. .. _librosa: https://librosa.org/doc/latest/generated/librosa.stft.html
  444. Ignoring the optional batch dimension, this method computes the following
  445. expression:
  446. .. math::
  447. X[\omega, m] = \sum_{k = 0}^{\text{win\_length-1}}%
  448. \text{window}[k]\ \text{input}[m \times \text{hop\_length} + k]\ %
  449. \exp\left(- j \frac{2 \pi \cdot \omega k}{\text{n\_fft}}\right),
  450. where :math:`m` is the index of the sliding window, and :math:`\omega` is
  451. the frequency :math:`0 \leq \omega < \text{n\_fft}` for ``onesided=False``,
  452. or :math:`0 \leq \omega < \lfloor \text{n\_fft} / 2 \rfloor + 1` for ``onesided=True``.
  453. * :attr:`input` must be either a 1-D time sequence or a 2-D batch of time
  454. sequences.
  455. * If :attr:`hop_length` is ``None`` (default), it is treated as equal to
  456. ``floor(n_fft / 4)``.
  457. * If :attr:`win_length` is ``None`` (default), it is treated as equal to
  458. :attr:`n_fft`.
  459. * :attr:`window` can be a 1-D tensor of size :attr:`win_length`, e.g., from
  460. :meth:`torch.hann_window`. If :attr:`window` is ``None`` (default), it is
  461. treated as if having :math:`1` everywhere in the window. If
  462. :math:`\text{win\_length} < \text{n\_fft}`, :attr:`window` will be padded on
  463. both sides to length :attr:`n_fft` before being applied.
  464. * If :attr:`center` is ``True`` (default), :attr:`input` will be padded on
  465. both sides so that the :math:`t`-th frame is centered at time
  466. :math:`t \times \text{hop\_length}`. Otherwise, the :math:`t`-th frame
  467. begins at time :math:`t \times \text{hop\_length}`.
  468. * :attr:`pad_mode` determines the padding method used on :attr:`input` when
  469. :attr:`center` is ``True``. See :meth:`torch.nn.functional.pad` for
  470. all available options. Default is ``"reflect"``.
  471. * If :attr:`onesided` is ``True`` (default for real input), only values for
  472. :math:`\omega` in :math:`\left[0, 1, 2, \dots, \left\lfloor
  473. \frac{\text{n\_fft}}{2} \right\rfloor + 1\right]` are returned because
  474. the real-to-complex Fourier transform satisfies the conjugate symmetry,
  475. i.e., :math:`X[m, \omega] = X[m, \text{n\_fft} - \omega]^*`.
  476. Note if the input or window tensors are complex, then :attr:`onesided`
  477. output is not possible.
  478. * If :attr:`normalized` is ``True`` (default is ``False``), the function
  479. returns the normalized STFT results, i.e., multiplied by :math:`(\text{frame\_length})^{-0.5}`.
  480. * If :attr:`return_complex` is ``True`` (default if input is complex), the
  481. return is a ``input.dim() + 1`` dimensional complex tensor. If ``False``,
  482. the output is a ``input.dim() + 2`` dimensional real tensor where the last
  483. dimension represents the real and imaginary components.
  484. Returns either a complex tensor of size :math:`(* \times N \times T)` if
  485. :attr:`return_complex` is true, or a real tensor of size :math:`(* \times N
  486. \times T \times 2)`. Where :math:`*` is the optional batch size of
  487. :attr:`input`, :math:`N` is the number of frequencies where STFT is applied
  488. and :math:`T` is the total number of frames used.
  489. .. warning::
  490. This function changed signature at version 0.4.1. Calling with the
  491. previous signature may cause error or return incorrect result.
  492. Args:
  493. input (Tensor): the input tensor of shape `(B?, L)` where `B?` is an optional
  494. batch dimension
  495. n_fft (int): size of Fourier transform
  496. hop_length (int, optional): the distance between neighboring sliding window
  497. frames. Default: ``None`` (treated as equal to ``floor(n_fft / 4)``)
  498. win_length (int, optional): the size of window frame and STFT filter.
  499. Default: ``None`` (treated as equal to :attr:`n_fft`)
  500. window (Tensor, optional): the optional window function.
  501. Shape must be 1d and `<= n_fft`
  502. Default: ``None`` (treated as window of all :math:`1` s)
  503. center (bool, optional): whether to pad :attr:`input` on both sides so
  504. that the :math:`t`-th frame is centered at time :math:`t \times \text{hop\_length}`.
  505. Default: ``True``
  506. pad_mode (str, optional): controls the padding method used when
  507. :attr:`center` is ``True``. Default: ``"reflect"``
  508. normalized (bool, optional): controls whether to return the normalized STFT results
  509. Default: ``False``
  510. onesided (bool, optional): controls whether to return half of results to
  511. avoid redundancy for real inputs.
  512. Default: ``True`` for real :attr:`input` and :attr:`window`, ``False`` otherwise.
  513. return_complex (bool, optional): whether to return a complex tensor, or
  514. a real tensor with an extra last dimension for the real and
  515. imaginary components.
  516. .. versionchanged:: 2.0
  517. ``return_complex`` is now a required argument for real inputs,
  518. as the default is being transitioned to ``True``.
  519. .. deprecated:: 2.0
  520. ``return_complex=False`` is deprecated, instead use ``return_complex=True``
  521. Note that calling :func:`torch.view_as_real` on the output will
  522. recover the deprecated output format.
  523. Returns:
  524. Tensor: A tensor containing the STFT result with shape `(B?, N, T, C?)` where
  525. - `B?` is an optional batch dimension from the input.
  526. - `N` is the number of frequency samples, `(n_fft // 2) + 1` for
  527. `onesided=True`, or otherwise `n_fft`.
  528. - `T` is the number of frames, `1 + L // hop_length`
  529. for `center=True`, or `1 + (L - n_fft) // hop_length` otherwise.
  530. - `C?` is an optional length-2 dimension of real and imaginary
  531. components, present when `return_complex=False`.
  532. """
  533. if has_torch_function_unary(input):
  534. return handle_torch_function(
  535. stft,
  536. (input,),
  537. input,
  538. n_fft,
  539. hop_length=hop_length,
  540. win_length=win_length,
  541. window=window,
  542. center=center,
  543. pad_mode=pad_mode,
  544. normalized=normalized,
  545. onesided=onesided,
  546. return_complex=return_complex,
  547. align_to_window=align_to_window,
  548. )
  549. if center and align_to_window is not None:
  550. raise RuntimeError(
  551. "stft align_to_window should only be set when center = false"
  552. )
  553. # NOTE: Do not edit. This code will be removed once the forward-compatibility
  554. # period is over for PR #73432
  555. if center:
  556. signal_dim = input.dim()
  557. extended_shape = [1] * (3 - signal_dim) + list(input.size())
  558. pad = int(n_fft // 2)
  559. input = F.pad(input.view(extended_shape), [pad, pad], pad_mode)
  560. input = input.view(input.shape[-signal_dim:])
  561. return _VF.stft( # type: ignore[attr-defined]
  562. input,
  563. n_fft,
  564. hop_length,
  565. win_length,
  566. window,
  567. normalized,
  568. onesided,
  569. return_complex,
  570. align_to_window,
  571. )
  572. istft = _add_docstr(
  573. torch.istft,
  574. "istft(input, n_fft, hop_length=None, win_length=None, window=None, center=True, "
  575. "normalized=False, onesided=None, length=None, return_complex=False) -> Tensor:\n"
  576. r"""
  577. Inverse short time Fourier Transform. This is expected to be the inverse of :func:`~torch.stft`.
  578. .. warning::
  579. From version 2.1, a warning will be provided if a :attr:`window` is
  580. not specified. In a future release, this attribute will be required.
  581. Please provide the same window used in the stft call.
  582. It has the same parameters (+ additional optional parameter of :attr:`length`) and it should return the
  583. least squares estimation of the original signal. The algorithm will check using the NOLA condition (
  584. nonzero overlap).
  585. Important consideration in the parameters :attr:`window` and :attr:`center` so that the envelope
  586. created by the summation of all the windows is never zero at certain point in time. Specifically,
  587. :math:`\sum_{t=-\infty}^{\infty} |w|^2[n-t\times hop\_length] \cancel{=} 0`.
  588. Since :func:`~torch.stft` discards elements at the end of the signal if they do not fit in a frame,
  589. ``istft`` may return a shorter signal than the original signal (can occur if :attr:`center` is False
  590. since the signal isn't padded). If `length` is given in the arguments and is longer than expected,
  591. ``istft`` will pad zeros to the end of the returned signal.
  592. If :attr:`center` is ``True``, then there will be padding e.g. ``'constant'``, ``'reflect'``, etc.
  593. Left padding can be trimmed off exactly because they can be calculated but right padding cannot be
  594. calculated without additional information.
  595. Example: Suppose the last window is:
  596. ``[17, 18, 0, 0, 0]`` vs ``[18, 0, 0, 0, 0]``
  597. The :attr:`n_fft`, :attr:`hop_length`, :attr:`win_length` are all the same which prevents the calculation
  598. of right padding. These additional values could be zeros or a reflection of the signal so providing
  599. :attr:`length` could be useful. If :attr:`length` is ``None`` then padding will be aggressively removed
  600. (some loss of signal).
  601. [1] D. W. Griffin and J. S. Lim, "Signal estimation from modified short-time Fourier transform,"
  602. IEEE Trans. ASSP, vol.32, no.2, pp.236-243, Apr. 1984.
  603. Args:
  604. input (Tensor): The input tensor. Expected to be in the format of :func:`~torch.stft`,
  605. output. That is a complex tensor of shape `(B?, N, T)` where
  606. - `B?` is an optional batch dimension
  607. - `N` is the number of frequency samples, `(n_fft // 2) + 1`
  608. for onesided input, or otherwise `n_fft`.
  609. - `T` is the number of frames, `1 + length // hop_length` for centered stft,
  610. or `1 + (length - n_fft) // hop_length` otherwise.
  611. .. versionchanged:: 2.0
  612. Real datatype inputs are no longer supported. Input must now have a
  613. complex datatype, as returned by ``stft(..., return_complex=True)``.
  614. n_fft (int): Size of Fourier transform
  615. hop_length (Optional[int]): The distance between neighboring sliding window frames.
  616. (Default: ``n_fft // 4``)
  617. win_length (Optional[int]): The size of window frame and STFT filter. (Default: ``n_fft``)
  618. window (Optional[torch.Tensor]): The optional window function.
  619. Shape must be 1d and `<= n_fft`
  620. (Default: ``torch.ones(win_length)``)
  621. center (bool): Whether :attr:`input` was padded on both sides so that the :math:`t`-th frame is
  622. centered at time :math:`t \times \text{hop\_length}`.
  623. (Default: ``True``)
  624. normalized (bool): Whether the STFT was normalized. (Default: ``False``)
  625. onesided (Optional[bool]): Whether the STFT was onesided.
  626. (Default: ``True`` if `n_fft != fft_size` in the input size)
  627. length (Optional[int]): The amount to trim the signal by (i.e. the
  628. original signal length). Defaults to `(T - 1) * hop_length` for
  629. centered stft, or `n_fft + (T - 1) * hop_length` otherwise, where `T`
  630. is the number of input frames.
  631. return_complex (Optional[bool]):
  632. Whether the output should be complex, or if the input should be
  633. assumed to derive from a real signal and window.
  634. Note that this is incompatible with ``onesided=True``.
  635. (Default: ``False``)
  636. Returns:
  637. Tensor: Least squares estimation of the original signal of shape `(B?, length)` where
  638. `B?` is an optional batch dimension from the input tensor.
  639. """,
  640. )
if TYPE_CHECKING:
    # These _impl functions return a variable number of tensors as output with
    # __torch_function__; tuple unpacking is done already rather than being
    # done by the caller of the _impl function
    _unique_impl_out = Any
else:
    # At runtime the return annotation is the fixed triple
    # (output, inverse_indices, counts) that the non-__torch_function__ path
    # of `_unique_impl` / `_unique_consecutive_impl` returns.
    _unique_impl_out = tuple[Tensor, Tensor, Tensor]
  648. def _unique_impl(
  649. input: Tensor,
  650. sorted: bool = True,
  651. return_inverse: bool = False,
  652. return_counts: bool = False,
  653. dim: Optional[int] = None,
  654. ) -> _unique_impl_out:
  655. r"""unique(input, sorted=True, return_inverse=False, return_counts=False, dim=None) -> tuple[Tensor, Tensor, Tensor]
  656. Returns the unique elements of the input tensor.
  657. .. note:: This function is different from :func:`torch.unique_consecutive` in the sense that
  658. this function also eliminates non-consecutive duplicate values.
  659. .. note:: Currently in the CUDA implementation and the CPU implementation,
  660. `torch.unique` always sort the tensor at the beginning regardless of the `sort` argument.
  661. Sorting could be slow, so if your input tensor is already sorted, it is recommended to use
  662. :func:`torch.unique_consecutive` which avoids the sorting.
  663. Args:
  664. input (Tensor): the input tensor
  665. sorted (bool): Whether to sort the unique elements in ascending order
  666. before returning as output.
  667. return_inverse (bool): Whether to also return the indices for where
  668. elements in the original input ended up in the returned unique list.
  669. return_counts (bool): Whether to also return the counts for each unique
  670. element.
  671. dim (int, optional): the dimension to operate upon. If ``None``, the
  672. unique of the flattened input is returned. Otherwise, each of the
  673. tensors indexed by the given dimension is treated as one of the
  674. elements to apply the unique operation upon. See examples for more
  675. details. Default: ``None``
  676. Returns:
  677. (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing
  678. - **output** (*Tensor*): the output list of unique scalar elements.
  679. - **inverse_indices** (*Tensor*): (optional) if
  680. :attr:`return_inverse` is True, there will be an additional
  681. returned tensor (same shape as input) representing the indices
  682. for where elements in the original input map to in the output;
  683. otherwise, this function will only return a single tensor.
  684. - **counts** (*Tensor*): (optional) if
  685. :attr:`return_counts` is True, there will be an additional
  686. returned tensor (same shape as output or output.size(dim),
  687. if dim was specified) representing the number of occurrences
  688. for each unique value or tensor.
  689. Example::
  690. >>> output = torch.unique(torch.tensor([1, 3, 2, 3], dtype=torch.long))
  691. >>> output
  692. tensor([1, 2, 3])
  693. >>> output, inverse_indices = torch.unique(
  694. ... torch.tensor([1, 3, 2, 3], dtype=torch.long), sorted=True, return_inverse=True)
  695. >>> output
  696. tensor([1, 2, 3])
  697. >>> inverse_indices
  698. tensor([0, 2, 1, 2])
  699. >>> output, inverse_indices = torch.unique(
  700. ... torch.tensor([[1, 3], [2, 3]], dtype=torch.long), sorted=True, return_inverse=True)
  701. >>> output
  702. tensor([1, 2, 3])
  703. >>> inverse_indices
  704. tensor([[0, 2],
  705. [1, 2]])
  706. >>> a = torch.tensor([
  707. ... [
  708. ... [1, 1, 0, 0],
  709. ... [1, 1, 0, 0],
  710. ... [0, 0, 1, 1],
  711. ... ],
  712. ... [
  713. ... [0, 0, 1, 1],
  714. ... [0, 0, 1, 1],
  715. ... [1, 1, 1, 1],
  716. ... ],
  717. ... [
  718. ... [1, 1, 0, 0],
  719. ... [1, 1, 0, 0],
  720. ... [0, 0, 1, 1],
  721. ... ],
  722. ... ])
  723. >>> # If we call `torch.unique(a, dim=0)`, each of the tensors `a[idx, :, :]`
  724. >>> # will be compared. We can see that `a[0, :, :]` and `a[2, :, :]` match
  725. >>> # each other, so one of them will be removed.
  726. >>> (a[0, :, :] == a[2, :, :]).all()
  727. tensor(True)
  728. >>> a_unique_dim0 = torch.unique(a, dim=0)
  729. >>> a_unique_dim0
  730. tensor([[[0, 0, 1, 1],
  731. [0, 0, 1, 1],
  732. [1, 1, 1, 1]],
  733. [[1, 1, 0, 0],
  734. [1, 1, 0, 0],
  735. [0, 0, 1, 1]]])
  736. >>> # Notice which sub-tensors from `a` match with the sub-tensors from
  737. >>> # `a_unique_dim0`:
  738. >>> (a_unique_dim0[0, :, :] == a[1, :, :]).all()
  739. tensor(True)
  740. >>> (a_unique_dim0[1, :, :] == a[0, :, :]).all()
  741. tensor(True)
  742. >>> # For `torch.unique(a, dim=1)`, each of the tensors `a[:, idx, :]` are
  743. >>> # compared. `a[:, 0, :]` and `a[:, 1, :]` match each other, so one of
  744. >>> # them will be removed.
  745. >>> (a[:, 0, :] == a[:, 1, :]).all()
  746. tensor(True)
  747. >>> torch.unique(a, dim=1)
  748. tensor([[[0, 0, 1, 1],
  749. [1, 1, 0, 0]],
  750. [[1, 1, 1, 1],
  751. [0, 0, 1, 1]],
  752. [[0, 0, 1, 1],
  753. [1, 1, 0, 0]]])
  754. >>> # For `torch.unique(a, dim=2)`, the tensors `a[:, :, idx]` are compared.
  755. >>> # `a[:, :, 0]` and `a[:, :, 1]` match each other. Also, `a[:, :, 2]` and
  756. >>> # `a[:, :, 3]` match each other as well. So in this case, two of the
  757. >>> # sub-tensors will be removed.
  758. >>> (a[:, :, 0] == a[:, :, 1]).all()
  759. tensor(True)
  760. >>> (a[:, :, 2] == a[:, :, 3]).all()
  761. tensor(True)
  762. >>> torch.unique(a, dim=2)
  763. tensor([[[0, 1],
  764. [0, 1],
  765. [1, 0]],
  766. [[1, 0],
  767. [1, 0],
  768. [1, 1]],
  769. [[0, 1],
  770. [0, 1],
  771. [1, 0]]])
  772. """
  773. if has_torch_function_unary(input):
  774. return handle_torch_function(
  775. unique,
  776. (input,),
  777. input,
  778. sorted=sorted,
  779. return_inverse=return_inverse,
  780. return_counts=return_counts,
  781. dim=dim,
  782. )
  783. if dim is not None:
  784. output, inverse_indices, counts = _VF.unique_dim(
  785. input,
  786. dim,
  787. sorted=sorted,
  788. return_inverse=return_inverse,
  789. return_counts=return_counts,
  790. )
  791. else:
  792. output, inverse_indices, counts = torch._unique2(
  793. input,
  794. sorted=sorted,
  795. return_inverse=return_inverse,
  796. return_counts=return_counts,
  797. )
  798. return output, inverse_indices, counts
  799. def _unique_consecutive_impl(
  800. input: Tensor,
  801. return_inverse: bool = False,
  802. return_counts: bool = False,
  803. dim: Optional[int] = None,
  804. ) -> _unique_impl_out:
  805. r"""Eliminates all but the first element from every consecutive group of equivalent elements.
  806. .. note:: This function is different from :func:`torch.unique` in the sense that this function
  807. only eliminates consecutive duplicate values. This semantics is similar to `std::unique`
  808. in C++.
  809. Args:
  810. input (Tensor): the input tensor
  811. return_inverse (bool): Whether to also return the indices for where
  812. elements in the original input ended up in the returned unique list.
  813. return_counts (bool): Whether to also return the counts for each unique
  814. element.
  815. dim (int): the dimension to apply unique. If ``None``, the unique of the
  816. flattened input is returned. default: ``None``
  817. Returns:
  818. (Tensor, Tensor (optional), Tensor (optional)): A tensor or a tuple of tensors containing
  819. - **output** (*Tensor*): the output list of unique scalar elements.
  820. - **inverse_indices** (*Tensor*): (optional) if
  821. :attr:`return_inverse` is True, there will be an additional
  822. returned tensor (same shape as input) representing the indices
  823. for where elements in the original input map to in the output;
  824. otherwise, this function will only return a single tensor.
  825. - **counts** (*Tensor*): (optional) if
  826. :attr:`return_counts` is True, there will be an additional
  827. returned tensor (same shape as output or output.size(dim),
  828. if dim was specified) representing the number of occurrences
  829. for each unique value or tensor.
  830. Example::
  831. >>> x = torch.tensor([1, 1, 2, 2, 3, 1, 1, 2])
  832. >>> output = torch.unique_consecutive(x)
  833. >>> output
  834. tensor([1, 2, 3, 1, 2])
  835. >>> output, inverse_indices = torch.unique_consecutive(x, return_inverse=True)
  836. >>> output
  837. tensor([1, 2, 3, 1, 2])
  838. >>> inverse_indices
  839. tensor([0, 0, 1, 1, 2, 3, 3, 4])
  840. >>> output, counts = torch.unique_consecutive(x, return_counts=True)
  841. >>> output
  842. tensor([1, 2, 3, 1, 2])
  843. >>> counts
  844. tensor([2, 2, 1, 2, 1])
  845. """
  846. if has_torch_function_unary(input):
  847. return handle_torch_function(
  848. unique_consecutive,
  849. (input,),
  850. input,
  851. return_inverse=return_inverse,
  852. return_counts=return_counts,
  853. dim=dim,
  854. )
  855. output, inverse_indices, counts = _VF.unique_consecutive( # type: ignore[attr-defined]
  856. input, return_inverse=return_inverse, return_counts=return_counts, dim=dim
  857. )
  858. return output, inverse_indices, counts
  859. def _return_counts(
  860. input,
  861. sorted=True,
  862. return_inverse=False,
  863. return_counts=False,
  864. dim=None,
  865. ):
  866. # type: (Tensor, bool, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
  867. if has_torch_function_unary(input):
  868. return _unique_impl(input, sorted, return_inverse, return_counts, dim)
  869. output, _, counts = _unique_impl(input, sorted, return_inverse, return_counts, dim)
  870. return output, counts
  871. def _return_output(
  872. input,
  873. sorted=True,
  874. return_inverse=False,
  875. return_counts=False,
  876. dim=None,
  877. ):
  878. # type: (Tensor, bool, bool, bool, Optional[int]) -> Tensor
  879. if has_torch_function_unary(input):
  880. return _unique_impl(input, sorted, return_inverse, return_counts, dim)
  881. output, _, _ = _unique_impl(input, sorted, return_inverse, return_counts, dim)
  882. return output
  883. def _return_inverse(
  884. input,
  885. sorted=True,
  886. return_inverse=False,
  887. return_counts=False,
  888. dim=None,
  889. ):
  890. # type: (Tensor, bool, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
  891. if has_torch_function_unary(input):
  892. return _unique_impl(input, sorted, return_inverse, return_counts, dim)
  893. output, inverse_indices, _ = _unique_impl(
  894. input, sorted, return_inverse, return_counts, dim
  895. )
  896. return output, inverse_indices
  897. _return_inverse_false = boolean_dispatch(
  898. arg_name="return_counts",
  899. arg_index=3,
  900. default=False,
  901. if_true=_return_counts,
  902. if_false=_return_output,
  903. module_name=__name__,
  904. func_name="unique",
  905. )
  906. _return_inverse_true = boolean_dispatch(
  907. arg_name="return_counts",
  908. arg_index=3,
  909. default=False,
  910. if_true=_unique_impl,
  911. if_false=_return_inverse,
  912. module_name=__name__,
  913. func_name="unique",
  914. )
  915. # The return type of unique depends on `return_inverse`, and `return_counts` so in order to
  916. # resolve the output type in TorchScript we need to statically know the value of both parameters
  917. unique = boolean_dispatch(
  918. arg_name="return_inverse",
  919. arg_index=2,
  920. default=False,
  921. if_true=_return_inverse_true,
  922. if_false=_return_inverse_false,
  923. module_name=__name__,
  924. func_name="unique",
  925. )
  926. unique.__doc__ = _unique_impl.__doc__
  927. def _consecutive_return_counts(
  928. input,
  929. return_inverse=False,
  930. return_counts=False,
  931. dim=None,
  932. ):
  933. # type: (Tensor, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
  934. if has_torch_function_unary(input):
  935. return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  936. output, _, counts = _unique_consecutive_impl(
  937. input, return_inverse, return_counts, dim
  938. )
  939. return output, counts
  940. def _consecutive_return_output(
  941. input,
  942. return_inverse=False,
  943. return_counts=False,
  944. dim=None,
  945. ):
  946. # type: (Tensor, bool, bool, Optional[int]) -> Tensor
  947. if has_torch_function_unary(input):
  948. return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  949. output, _, _ = _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  950. return output
  951. def _consecutive_return_inverse(
  952. input,
  953. return_inverse=False,
  954. return_counts=False,
  955. dim=None,
  956. ):
  957. # type: (Tensor, bool, bool, Optional[int]) -> tuple[Tensor, Tensor]
  958. if has_torch_function_unary(input):
  959. return _unique_consecutive_impl(input, return_inverse, return_counts, dim)
  960. output, inverse_indices, _ = _unique_consecutive_impl(
  961. input, return_inverse, return_counts, dim
  962. )
  963. return output, inverse_indices
  964. _consecutive_return_inverse_false = boolean_dispatch(
  965. arg_name="return_counts",
  966. arg_index=1,
  967. default=False,
  968. if_true=_consecutive_return_counts,
  969. if_false=_consecutive_return_output,
  970. module_name=__name__,
  971. func_name="unique_consecutive",
  972. )
  973. _consecutive_return_inverse_true = boolean_dispatch(
  974. arg_name="return_counts",
  975. arg_index=1,
  976. default=False,
  977. if_true=_unique_consecutive_impl,
  978. if_false=_consecutive_return_inverse,
  979. module_name=__name__,
  980. func_name="unique_consecutive",
  981. )
  982. # The return type of unique depends on `return_inverse`, and `return_counts` so in order to
  983. # resolve the output type in TorchScript we need to statically know the value of both parameters
  984. unique_consecutive = boolean_dispatch(
  985. arg_name="return_inverse",
  986. arg_index=2,
  987. default=False,
  988. if_true=_consecutive_return_inverse_true,
  989. if_false=_consecutive_return_inverse_false,
  990. module_name=__name__,
  991. func_name="unique_consecutive",
  992. )
  993. unique_consecutive.__doc__ = _unique_consecutive_impl.__doc__
if TYPE_CHECKING:
    pass
    # There's no good way to use this type annotation without breaking JIT
    # overloads. So leave untyped for mypy for now.
else:
    # Overload declarations for `tensordot`, one per accepted form of `dims`.
    # The implementation is the undecorated `tensordot` defined right after
    # this block. Under static type checking this branch is skipped, so mypy
    # only sees that implementation.
    # NOTE(review): the stub bodies presumably must stay a bare `pass` for the
    # JIT `overload` helper — confirm before adding anything to them.

    # dims as a non-negative int: the number of trailing dims of `a` and
    # leading dims of `b` to contract.
    @overload
    def tensordot(
        a,
        b,
        dims: int = 2,
        out: Optional[torch.Tensor] = None,
    ):
        pass

    # dims as an explicit (dims_of_a, dims_of_b) pair of lists.
    @overload
    def tensordot(  # noqa: F811
        a,
        b,
        dims: tuple[list[int], list[int]],
        out: Optional[torch.Tensor] = None,
    ):
        pass

    # dims as a list holding the two dimension lists for `a` and `b`.
    @overload
    def tensordot(  # noqa: F811
        a,
        b,
        dims: list[list[int]],
        out: Optional[torch.Tensor] = None,
    ):
        pass

    # dims as a Tensor: either a single element (treated like the int form)
    # or a tensor whose first dimension holds the two dimension lists.
    @overload
    def tensordot(  # noqa: F811
        a,
        b,
        dims: torch.Tensor,
        out: Optional[torch.Tensor] = None,
    ):
        pass
  1031. def tensordot( # noqa: F811
  1032. a,
  1033. b,
  1034. dims=2,
  1035. out: Optional[torch.Tensor] = None,
  1036. ):
  1037. r"""Returns a contraction of a and b over multiple dimensions.
  1038. :attr:`tensordot` implements a generalized matrix product.
  1039. Args:
  1040. a (Tensor): Left tensor to contract
  1041. b (Tensor): Right tensor to contract
  1042. dims (int or Tuple[List[int], List[int]] or List[List[int]] containing two lists or Tensor): number of dimensions to
  1043. contract or explicit lists of dimensions for :attr:`a` and
  1044. :attr:`b` respectively
  1045. When called with a non-negative integer argument :attr:`dims` = :math:`d`, and
  1046. the number of dimensions of :attr:`a` and :attr:`b` is :math:`m` and :math:`n`,
  1047. respectively, :func:`~torch.tensordot` computes
  1048. .. math::
  1049. r_{i_0,...,i_{m-d}, i_d,...,i_n}
  1050. = \sum_{k_0,...,k_{d-1}} a_{i_0,...,i_{m-d},k_0,...,k_{d-1}} \times b_{k_0,...,k_{d-1}, i_d,...,i_n}.
  1051. When called with :attr:`dims` of the list form, the given dimensions will be contracted
  1052. in place of the last :math:`d` of :attr:`a` and the first :math:`d` of :math:`b`. The sizes
  1053. in these dimensions must match, but :func:`~torch.tensordot` will deal with broadcasted
  1054. dimensions.
  1055. Examples::
  1056. >>> a = torch.arange(60.).reshape(3, 4, 5)
  1057. >>> b = torch.arange(24.).reshape(4, 3, 2)
  1058. >>> torch.tensordot(a, b, dims=([1, 0], [0, 1]))
  1059. tensor([[4400., 4730.],
  1060. [4532., 4874.],
  1061. [4664., 5018.],
  1062. [4796., 5162.],
  1063. [4928., 5306.]])
  1064. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
  1065. >>> a = torch.randn(3, 4, 5, device='cuda')
  1066. >>> b = torch.randn(4, 5, 6, device='cuda')
  1067. >>> c = torch.tensordot(a, b, dims=2).cpu()
  1068. tensor([[ 8.3504, -2.5436, 6.2922, 2.7556, -1.0732, 3.2741],
  1069. [ 3.3161, 0.0704, 5.0187, -0.4079, -4.3126, 4.8744],
  1070. [ 0.8223, 3.9445, 3.2168, -0.2400, 3.4117, 1.7780]])
  1071. >>> a = torch.randn(3, 5, 4, 6)
  1072. >>> b = torch.randn(6, 4, 5, 3)
  1073. >>> torch.tensordot(a, b, dims=([2, 1, 3], [1, 2, 0]))
  1074. tensor([[ 7.7193, -2.4867, -10.3204],
  1075. [ 1.5513, -14.4737, -6.5113],
  1076. [ -0.2850, 4.2573, -3.5997]])
  1077. """
  1078. if has_torch_function_variadic(a, b):
  1079. return handle_torch_function(tensordot, (a, b), a, b, dims=dims, out=out)
  1080. if not isinstance(dims, (tuple, list, torch.Tensor, int, torch.SymInt)):
  1081. raise RuntimeError(
  1082. "tensordot expects dims to be int or "
  1083. + "tuple[list[int], list[int]] or "
  1084. + "list[list[int]] containing two lists, but got "
  1085. + f"dims={dims}"
  1086. )
  1087. dims_a: list[int] = []
  1088. dims_b: list[int] = []
  1089. if isinstance(dims, (tuple, list)):
  1090. dims_a, dims_b = dims
  1091. if isinstance(dims, torch.Tensor):
  1092. num_elements = dims.numel()
  1093. if num_elements > 1:
  1094. assert dims.size()[0] == 2
  1095. dims_a = torch.jit.annotate(list[int], dims[0].tolist())
  1096. dims_b = torch.jit.annotate(list[int], dims[1].tolist())
  1097. else:
  1098. dims_val = int(dims.item())
  1099. if dims_val < 0:
  1100. raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
  1101. dims_a = list(range(-dims_val, 0))
  1102. dims_b = list(range(dims_val))
  1103. if isinstance(dims, (int, torch.SymInt)):
  1104. if dims < 0:
  1105. raise RuntimeError(f"tensordot expects dims >= 0, but got dims={dims}")
  1106. if dims > min(a.dim(), b.dim()):
  1107. raise RuntimeError(
  1108. f"tensordot expects dims < ndim_a or ndim_b, but got dims={dims}"
  1109. )
  1110. dims_a = list(range(-dims, 0))
  1111. dims_b = list(range(dims))
  1112. if out is None:
  1113. return _VF.tensordot(a, b, dims_a, dims_b) # type: ignore[attr-defined]
  1114. else:
  1115. return _VF.tensordot(a, b, dims_a, dims_b, out=out) # type: ignore[attr-defined]
  1116. def cartesian_prod(*tensors: Tensor) -> Tensor:
  1117. """Do cartesian product of the given sequence of tensors. The behavior is similar to
  1118. python's `itertools.product`.
  1119. Args:
  1120. *tensors: any number of 1 dimensional tensors.
  1121. Returns:
  1122. Tensor: A tensor equivalent to converting all the input tensors into lists,
  1123. do `itertools.product` on these lists, and finally convert the resulting list
  1124. into tensor.
  1125. Example::
  1126. >>> import itertools
  1127. >>> a = [1, 2, 3]
  1128. >>> b = [4, 5]
  1129. >>> list(itertools.product(a, b))
  1130. [(1, 4), (1, 5), (2, 4), (2, 5), (3, 4), (3, 5)]
  1131. >>> tensor_a = torch.tensor(a)
  1132. >>> tensor_b = torch.tensor(b)
  1133. >>> torch.cartesian_prod(tensor_a, tensor_b)
  1134. tensor([[1, 4],
  1135. [1, 5],
  1136. [2, 4],
  1137. [2, 5],
  1138. [3, 4],
  1139. [3, 5]])
  1140. """
  1141. # This wrapper exists to support variadic args.
  1142. if has_torch_function(tensors):
  1143. return handle_torch_function(cartesian_prod, tensors, *tensors)
  1144. return _VF.cartesian_prod(tensors) # type: ignore[attr-defined]
  1145. def block_diag(*tensors):
  1146. """Create a block diagonal matrix from provided tensors.
  1147. Args:
  1148. *tensors: One or more tensors with 0, 1, or 2 dimensions.
  1149. Returns:
  1150. Tensor: A 2 dimensional tensor with all the input tensors arranged in
  1151. order such that their upper left and lower right corners are
  1152. diagonally adjacent. All other elements are set to 0.
  1153. Example::
  1154. >>> import torch
  1155. >>> A = torch.tensor([[0, 1], [1, 0]])
  1156. >>> B = torch.tensor([[3, 4, 5], [6, 7, 8]])
  1157. >>> C = torch.tensor(7)
  1158. >>> D = torch.tensor([1, 2, 3])
  1159. >>> E = torch.tensor([[4], [5], [6]])
  1160. >>> torch.block_diag(A, B, C, D, E)
  1161. tensor([[0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
  1162. [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
  1163. [0, 0, 3, 4, 5, 0, 0, 0, 0, 0],
  1164. [0, 0, 6, 7, 8, 0, 0, 0, 0, 0],
  1165. [0, 0, 0, 0, 0, 7, 0, 0, 0, 0],
  1166. [0, 0, 0, 0, 0, 0, 1, 2, 3, 0],
  1167. [0, 0, 0, 0, 0, 0, 0, 0, 0, 4],
  1168. [0, 0, 0, 0, 0, 0, 0, 0, 0, 5],
  1169. [0, 0, 0, 0, 0, 0, 0, 0, 0, 6]])
  1170. """
  1171. # This wrapper exists to support variadic args.
  1172. if has_torch_function(tensors):
  1173. return handle_torch_function(block_diag, tensors, *tensors)
  1174. return torch._C._VariableFunctions.block_diag(tensors) # type: ignore[attr-defined]
  1175. def cdist(x1, x2, p=2.0, compute_mode="use_mm_for_euclid_dist_if_necessary"):
  1176. # type: (Tensor, Tensor, float, str) -> (Tensor)
  1177. r"""Computes batched the p-norm distance between each pair of the two collections of row vectors.
  1178. Args:
  1179. x1 (Tensor): input tensor where the last two dimensions represent the points and the feature dimension respectively.
  1180. The shape can be :math:`D_1 \times D_2 \times \cdots \times D_n \times P \times M`,
  1181. where :math:`P` is the number of points and :math:`M` is the feature dimension.
  1182. x2 (Tensor): input tensor where the last two dimensions also represent the points and the feature dimension respectively.
  1183. The shape can be :math:`D_1' \times D_2' \times \cdots \times D_m' \times R \times M`,
  1184. where :math:`R` is the number of points and :math:`M` is the feature dimension,
  1185. which should match the feature dimension of `x1`.
  1186. p: p value for the p-norm distance to calculate between each vector pair
  1187. :math:`\in [0, \infty]`.
  1188. compute_mode:
  1189. 'use_mm_for_euclid_dist_if_necessary' - will use matrix multiplication approach to calculate
  1190. euclidean distance (p = 2) if P > 25 or R > 25
  1191. 'use_mm_for_euclid_dist' - will always use matrix multiplication approach to calculate
  1192. euclidean distance (p = 2)
  1193. 'donot_use_mm_for_euclid_dist' - will never use matrix multiplication approach to calculate
  1194. euclidean distance (p = 2)
  1195. Default: use_mm_for_euclid_dist_if_necessary.
  1196. If x1 has shape :math:`B \times P \times M` and x2 has shape :math:`B \times R \times M` then the
  1197. output will have shape :math:`B \times P \times R`.
  1198. This function is equivalent to `scipy.spatial.distance.cdist(input,'minkowski', p=p)`
  1199. if :math:`p \in (0, \infty)`. When :math:`p = 0` it is equivalent to
  1200. `scipy.spatial.distance.cdist(input, 'hamming') * M`. When :math:`p = \infty`, the closest
  1201. scipy function is `scipy.spatial.distance.cdist(xn, lambda x, y: np.abs(x - y).max())`.
  1202. Example:
  1203. >>> a = torch.tensor([[0.9041, 0.0196], [-0.3108, -2.4423], [-0.4821, 1.059]])
  1204. >>> a
  1205. tensor([[ 0.9041, 0.0196],
  1206. [-0.3108, -2.4423],
  1207. [-0.4821, 1.0590]])
  1208. >>> b = torch.tensor([[-2.1763, -0.4713], [-0.6986, 1.3702]])
  1209. >>> b
  1210. tensor([[-2.1763, -0.4713],
  1211. [-0.6986, 1.3702]])
  1212. >>> torch.cdist(a, b, p=2)
  1213. tensor([[3.1193, 2.0959],
  1214. [2.7138, 3.8322],
  1215. [2.2830, 0.3791]])
  1216. """
  1217. if has_torch_function_variadic(x1, x2):
  1218. return handle_torch_function(
  1219. cdist, (x1, x2), x1, x2, p=p, compute_mode=compute_mode
  1220. )
  1221. if compute_mode == "use_mm_for_euclid_dist_if_necessary":
  1222. return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined]
  1223. elif compute_mode == "use_mm_for_euclid_dist":
  1224. return _VF.cdist(x1, x2, p, 1) # type: ignore[attr-defined]
  1225. elif compute_mode == "donot_use_mm_for_euclid_dist":
  1226. return _VF.cdist(x1, x2, p, 2) # type: ignore[attr-defined]
  1227. else:
  1228. raise ValueError(f"{compute_mode} is not a valid value for compute_mode")
  1229. def atleast_1d(*tensors):
  1230. r"""
  1231. Returns a 1-dimensional view of each input tensor with zero dimensions.
  1232. Input tensors with one or more dimensions are returned as-is.
  1233. Args:
  1234. input (Tensor or sequence of Tensors): tensor(s) to be converted to at least 1-dimensional.
  1235. Returns:
  1236. output (Tensor or tuple of Tensors)
  1237. Example::
  1238. >>> x = torch.arange(2)
  1239. >>> x
  1240. tensor([0, 1])
  1241. >>> torch.atleast_1d(x)
  1242. tensor([0, 1])
  1243. >>> x = torch.tensor(1.)
  1244. >>> x
  1245. tensor(1.)
  1246. >>> torch.atleast_1d(x)
  1247. tensor([1.])
  1248. >>> x = torch.tensor(0.5)
  1249. >>> y = torch.tensor(1.)
  1250. >>> torch.atleast_1d((x, y))
  1251. (tensor([0.5000]), tensor([1.]))
  1252. >>> torch.atleast_1d()
  1253. ()
  1254. """
  1255. # This wrapper exists to support variadic args.
  1256. if has_torch_function(tensors):
  1257. return handle_torch_function(atleast_1d, tensors, *tensors)
  1258. if len(tensors) == 1:
  1259. tensors = tensors[0]
  1260. return _VF.atleast_1d(tensors) # type: ignore[attr-defined]
  1261. def atleast_2d(*tensors):
  1262. r"""
  1263. Returns a 2-dimensional view of each input tensor with zero dimensions.
  1264. Input tensors with two or more dimensions are returned as-is.
  1265. Args:
  1266. input (Tensor or sequence of Tensors): tensor(s) to be converted to at least 2-dimensional.
  1267. Returns:
  1268. output (Tensor or tuple of Tensors)
  1269. Example::
  1270. >>> x = torch.tensor(1.)
  1271. >>> x
  1272. tensor(1.)
  1273. >>> torch.atleast_2d(x)
  1274. tensor([[1.]])
  1275. >>> x = torch.arange(4).view(2, 2)
  1276. >>> x
  1277. tensor([[0, 1],
  1278. [2, 3]])
  1279. >>> torch.atleast_2d(x)
  1280. tensor([[0, 1],
  1281. [2, 3]])
  1282. >>> x = torch.tensor(0.5)
  1283. >>> y = torch.tensor(1.)
  1284. >>> torch.atleast_2d((x, y))
  1285. (tensor([[0.5000]]), tensor([[1.]]))
  1286. >>> torch.atleast_2d()
  1287. ()
  1288. """
  1289. # This wrapper exists to support variadic args.
  1290. if has_torch_function(tensors):
  1291. return handle_torch_function(atleast_2d, tensors, *tensors)
  1292. if len(tensors) == 1:
  1293. tensors = tensors[0]
  1294. return _VF.atleast_2d(tensors) # type: ignore[attr-defined]
  1295. def atleast_3d(*tensors):
  1296. r"""
  1297. Returns a 3-dimensional view of each input tensor with zero dimensions.
  1298. Input tensors with three or more dimensions are returned as-is.
  1299. Args:
  1300. input (Tensor or sequence of Tensors): tensor(s) to be converted to at least 3-dimensional.
  1301. Returns:
  1302. output (Tensor or tuple of Tensors)
  1303. Example:
  1304. >>> x = torch.tensor(0.5)
  1305. >>> x
  1306. tensor(0.5000)
  1307. >>> torch.atleast_3d(x)
  1308. tensor([[[0.5000]]])
  1309. >>> y = torch.arange(4).view(2, 2)
  1310. >>> y
  1311. tensor([[0, 1],
  1312. [2, 3]])
  1313. >>> torch.atleast_3d(y)
  1314. tensor([[[0],
  1315. [1]],
  1316. <BLANKLINE>
  1317. [[2],
  1318. [3]]])
  1319. >>> x = torch.tensor(1).view(1, 1, 1)
  1320. >>> x
  1321. tensor([[[1]]])
  1322. >>> torch.atleast_3d(x)
  1323. tensor([[[1]]])
  1324. >>> x = torch.tensor(0.5)
  1325. >>> y = torch.tensor(1.0)
  1326. >>> torch.atleast_3d((x, y))
  1327. (tensor([[[0.5000]]]), tensor([[[1.]]]))
  1328. >>> torch.atleast_3d()
  1329. ()
  1330. """
  1331. # This wrapper exists to support variadic args.
  1332. if has_torch_function(tensors):
  1333. return handle_torch_function(atleast_3d, tensors, *tensors)
  1334. if len(tensors) == 1:
  1335. tensors = tensors[0]
  1336. return _VF.atleast_3d(tensors) # type: ignore[attr-defined]
# Overload declarations for ``norm``. Static type checkers see no declaration
# here (see the note in the TYPE_CHECKING branch). At runtime, each ``@overload``
# stub below declares one accepted combination of argument types through its
# signature comment. The real implementation is the ``norm`` defined after this
# block. NOTE(review): the note below says these are "JIT overloads";
# presumably ``overload`` is the TorchScript overload registry. Confirm where
# ``overload`` is imported.
if TYPE_CHECKING:
    pass
    # There's no good way to use this type annotation; cannot rename norm() to
    # _norm_impl() in a way that doesn't break JIT overloads. So leave untyped
    # for mypy for now.
    #    def norm(input: Tensor,
    #             p: Optional[Union[str, Number]] = "fro",
    #             dim: Optional[Union[int, List[int]]] = None,
    #             keepdim: bool = False,
    #             out: Optional[Tensor] = None,
    #             dtype: _dtype = None) -> Tensor:
    #        return _norm_impl(input, p, dim, keepdim, out, dtype)
else:
    # TODO: type dim as BroadcastingList when
    # https://github.com/pytorch/pytorch/issues/33782 is fixed

    # Overload: string ``p`` with ``dim`` given as an optional list of ints.
    @overload
    def norm(
        input,
        p="fro",
        dim=None,
        keepdim=False,
        out=None,
        dtype=None,
    ):
        # type: (Tensor, str, Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass

    # Overload: numeric (or None) ``p`` with ``dim`` given as an optional list of ints.
    @overload
    def norm(  # noqa: F811
        input,
        p="fro",
        dim=None,
        keepdim=False,
        out=None,
        dtype=None,
    ):
        # type: (Tensor, Optional[number], Optional[List[int]], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass

    # Overload: numeric (or None) ``p`` with ``dim`` given as a single optional int.
    @overload
    def norm(  # noqa: F811
        input,
        p="fro",
        dim=None,
        keepdim=False,
        out=None,
        dtype=None,
    ):
        # type: (Tensor, Optional[number], Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass

    # Overload: string ``p`` with ``dim`` given as a single optional int.
    @overload
    def norm(  # noqa: F811
        input,
        p="fro",
        dim=None,
        keepdim=False,
        out=None,
        dtype=None,
    ):
        # type: (Tensor, str, Optional[int], bool, Optional[Tensor], Optional[int]) -> Tensor
        pass
def norm(  # noqa: F811
    input,
    p: Optional[Union[float, str]] = "fro",
    dim=None,
    keepdim=False,
    out=None,
    dtype=None,
):
    r"""Returns the matrix norm or vector norm of a given tensor.

    .. warning::

        torch.norm is deprecated and may be removed in a future PyTorch release.
        Its documentation and behavior may be incorrect, and it is no longer
        actively maintained.

        Use :func:`torch.linalg.vector_norm` when computing vector norms and
        :func:`torch.linalg.matrix_norm` when computing matrix norms.
        For a function with a similar behavior as this one see :func:`torch.linalg.norm`.
        Note, however, the signature for these functions is slightly different than the
        signature for ``torch.norm``.

    Args:
        input (Tensor): The input tensor. Its data type must be either a floating
            point or complex type. For complex inputs, the norm is calculated using the
            absolute value of each element. If the input is complex and neither
            :attr:`dtype` nor :attr:`out` is specified, the result's data type will
            be the corresponding floating point type (e.g. float if :attr:`input` is
            complexfloat).

        p (int, float, inf, -inf, 'fro', 'nuc', optional): the order of norm. Default: ``'fro'``
            The following norms can be calculated:

            ======  ==============  ==========================
            p       matrix norm     vector norm
            ======  ==============  ==========================
            'fro'   Frobenius norm  --
            'nuc'   nuclear norm    --
            Number  --              sum(abs(x)**p)**(1./p)
            ======  ==============  ==========================

            The vector norm can be calculated across any number of dimensions.
            The corresponding dimensions of :attr:`input` are flattened into
            one dimension, and the norm is calculated on the flattened
            dimension.

            Frobenius norm produces the same result as ``p=2`` in all cases
            except when :attr:`dim` is a list of three or more dims, in which
            case Frobenius norm throws an error.

            Nuclear norm can only be calculated across exactly two dimensions.

        dim (int, tuple of ints, list of ints, optional):
            Specifies which dimension or dimensions of :attr:`input` to
            calculate the norm across. If :attr:`dim` is ``None``, the norm will
            be calculated across all dimensions of :attr:`input`. If the norm
            type indicated by :attr:`p` does not support the specified number of
            dimensions, an error will occur.
        keepdim (bool, optional): whether the output tensors have :attr:`dim`
            retained or not. Ignored if :attr:`dim` = ``None`` and
            :attr:`out` = ``None``. Default: ``False``
        out (Tensor, optional): the output tensor.
        dtype (:class:`torch.dtype`, optional): the desired data type of
            returned tensor. If specified, the input tensor is cast to
            :attr:`dtype` while performing the operation. Default: None.

    .. note::
        Even though ``p='fro'`` supports any number of dimensions, the true
        mathematical definition of Frobenius norm only applies to tensors with
        exactly two dimensions. :func:`torch.linalg.matrix_norm` with ``ord='fro'``
        aligns with the mathematical definition, since it can only be applied across
        exactly two dimensions.

    Example::

        >>> import torch
        >>> a = torch.arange(9, dtype= torch.float) - 4
        >>> b = a.reshape((3, 3))
        >>> torch.norm(a)
        tensor(7.7460)
        >>> torch.norm(b)
        tensor(7.7460)
        >>> torch.norm(a, float('inf'))
        tensor(4.)
        >>> torch.norm(b, float('inf'))
        tensor(4.)
        >>> c = torch.tensor([[ 1, 2, 3], [-1, 1, 4]] , dtype=torch.float)
        >>> torch.norm(c, dim=0)
        tensor([1.4142, 2.2361, 5.0000])
        >>> torch.norm(c, dim=1)
        tensor([3.7417, 4.2426])
        >>> torch.norm(c, p=1, dim=1)
        tensor([6., 6.])
        >>> d = torch.arange(8, dtype=torch.float).reshape(2, 2, 2)
        >>> torch.norm(d, dim=(1, 2))
        tensor([ 3.7417, 11.2250])
        >>> torch.norm(d[0, :, :]), torch.norm(d[1, :, :])
        (tensor(3.7417), tensor(11.2250))
    """

    # Give a __torch_function__ override on ``input`` the chance to handle the call.
    if has_torch_function_unary(input):
        return handle_torch_function(
            norm, (input,), input, p=p, dim=dim, keepdim=keepdim, out=out, dtype=dtype
        )

    # NB. All the repeated code and weird python is to please TorchScript.
    #     For a more compact implementation see the relevant function in `_refs/__init__.py`

    # We don't do this for MPS or sparse tensors
    # Fast path: strided tensors on the device types listed below are routed to
    # torch.linalg.vector_norm / torch.linalg.matrix_norm.
    if input.layout == torch.strided and input.device.type in (
        "cpu",
        "cuda",
        "xpu",
        "meta",
        torch.utils.backend_registration._privateuse1_backend_name,
    ):
        # Normalize ``dim``: a single int becomes a one-element list, None stays None.
        if dim is not None:
            if isinstance(dim, (int, torch.SymInt)):
                _dim = [dim]
            else:
                _dim = dim
        else:
            _dim = None  # type: ignore[assignment]

        if isinstance(p, str):
            # 'fro' over all dims, a single dim, or at most two dims is computed
            # as a vector 2-norm over the selected dimensions.
            if p == "fro" and (
                dim is None or isinstance(dim, (int, torch.SymInt)) or len(dim) <= 2
            ):
                if out is None:
                    return torch.linalg.vector_norm(
                        input, 2, _dim, keepdim, dtype=dtype
                    )
                else:
                    return torch.linalg.vector_norm(
                        input, 2, _dim, keepdim, dtype=dtype, out=out
                    )

            # Here we either call the nuclear norm, or we call matrix_norm with some arguments
            # that will throw an error
            # (e.g. 'fro' over three or more dims, or an unsupported string ``p``).
            # Without an explicit ``dim``, every dimension of ``input`` is passed on.
            if _dim is None:
                _dim = list(range(input.ndim))
            if out is None:
                return torch.linalg.matrix_norm(input, p, _dim, keepdim, dtype=dtype)
            else:
                return torch.linalg.matrix_norm(
                    input, p, _dim, keepdim, dtype=dtype, out=out
                )
        else:
            # NB. p should be Union[str, number], not Optional!
            # A ``p`` of None is treated as the 2-norm.
            _p = 2.0 if p is None else p
            if out is None:
                return torch.linalg.vector_norm(input, _p, _dim, keepdim, dtype=dtype)
            else:
                return torch.linalg.vector_norm(
                    input, _p, _dim, keepdim, dtype=dtype, out=out
                )

    # Legacy path for every layout/device not handled above, built on _VF kernels.
    ndim = input.dim()

    # catch default case
    # Only ``p`` (and ``keepdim``) given: 'fro' calls frobenius_norm with
    # ``dim=()`` and a numeric ``p`` reduces over all dimensions. Other strings
    # (e.g. 'nuc') fall through to the general handling below.
    if dim is None and out is None and dtype is None and p is not None:
        if isinstance(p, str):
            if p == "fro":
                return _VF.frobenius_norm(input, dim=(), keepdim=keepdim)
        if not isinstance(p, str):
            _dim = list(range(ndim))
            return _VF.norm(input, p, dim=_dim, keepdim=keepdim)  # type: ignore[attr-defined]

    # TODO: when https://github.com/pytorch/pytorch/issues/33782 is fixed
    # remove the overloads where dim is an int and replace with BroadcastingList1
    # and remove next four lines, replace _dim with dim
    if dim is not None:
        if isinstance(dim, (int, torch.SymInt)):
            _dim = [dim]
        else:
            _dim = dim
    else:
        _dim = None  # type: ignore[assignment]

    if isinstance(p, str):
        if p == "fro":
            # The legacy Frobenius kernel takes no dtype argument.
            if dtype is not None:
                raise ValueError("dtype argument is not supported in frobenius norm")

            if _dim is None:
                _dim = list(range(ndim))
            if out is None:
                return _VF.frobenius_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
            else:
                return _VF.frobenius_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
        elif p == "nuc":
            # The legacy nuclear-norm kernel takes no dtype argument.
            if dtype is not None:
                raise ValueError("dtype argument is not supported in nuclear norm")
            if _dim is None:
                if out is None:
                    return _VF.nuclear_norm(input, keepdim=keepdim)  # type: ignore[arg-type]
                else:
                    return _VF.nuclear_norm(input, keepdim=keepdim, out=out)  # type: ignore[arg-type]
            else:
                if out is None:
                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim)  # type: ignore[arg-type]
                else:
                    return _VF.nuclear_norm(input, _dim, keepdim=keepdim, out=out)  # type: ignore[arg-type]
        # Any other string value of ``p`` is rejected.
        raise RuntimeError(f"only valid string values are 'fro' and 'nuc', found {p}")
    else:
        # Numeric ``p``: reduce over ``dim``, or over every dimension when it is None.
        if _dim is None:
            _dim = list(range(ndim))

        if out is None:
            if dtype is None:
                return _VF.norm(input, p, _dim, keepdim=keepdim)  # type: ignore[attr-defined]
            else:
                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype)  # type: ignore[attr-defined]
        else:
            if dtype is None:
                return _VF.norm(input, p, _dim, keepdim=keepdim, out=out)  # type: ignore[attr-defined]
            else:
                return _VF.norm(input, p, _dim, keepdim=keepdim, dtype=dtype, out=out)  # type: ignore[attr-defined]
  1591. def unravel_index(
  1592. indices: Tensor,
  1593. shape: Union[int, Sequence[int], torch.Size],
  1594. ) -> tuple[Tensor, ...]:
  1595. r"""Converts a tensor of flat indices into a tuple of coordinate tensors that
  1596. index into an arbitrary tensor of the specified shape.
  1597. Args:
  1598. indices (Tensor): An integer tensor containing indices into the
  1599. flattened version of an arbitrary tensor of shape :attr:`shape`.
  1600. All elements must be in the range ``[0, prod(shape) - 1]``.
  1601. shape (int, sequence of ints, or torch.Size): The shape of the arbitrary
  1602. tensor. All elements must be non-negative.
  1603. Returns:
  1604. tuple of Tensors: Each ``i``-th tensor in the output corresponds with
  1605. dimension ``i`` of :attr:`shape`. Each tensor has the same shape as
  1606. ``indices`` and contains one index into dimension ``i`` for each of the
  1607. flat indices given by ``indices``.
  1608. Example::
  1609. >>> import torch
  1610. >>> torch.unravel_index(torch.tensor(4), (3, 2))
  1611. (tensor(2),
  1612. tensor(0))
  1613. >>> torch.unravel_index(torch.tensor([4, 1]), (3, 2))
  1614. (tensor([2, 0]),
  1615. tensor([0, 1]))
  1616. >>> torch.unravel_index(torch.tensor([0, 1, 2, 3, 4, 5]), (3, 2))
  1617. (tensor([0, 0, 1, 1, 2, 2]),
  1618. tensor([0, 1, 0, 1, 0, 1]))
  1619. >>> torch.unravel_index(torch.tensor([1234, 5678]), (10, 10, 10, 10))
  1620. (tensor([1, 5]),
  1621. tensor([2, 6]),
  1622. tensor([3, 7]),
  1623. tensor([4, 8]))
  1624. >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (10, 10, 10, 10))
  1625. (tensor([[1], [5]]),
  1626. tensor([[2], [6]]),
  1627. tensor([[3], [7]]),
  1628. tensor([[4], [8]]))
  1629. >>> torch.unravel_index(torch.tensor([[1234], [5678]]), (100, 100))
  1630. (tensor([[12], [56]]),
  1631. tensor([[34], [78]]))
  1632. """
  1633. if has_torch_function_unary(indices):
  1634. return handle_torch_function(unravel_index, (indices,), indices, shape=shape)
  1635. res_tensor = _unravel_index(indices, shape)
  1636. return res_tensor.unbind(-1)
  1637. def _unravel_index(indices: Tensor, shape: Union[int, Sequence[int]]) -> Tensor:
  1638. torch._check_type(
  1639. not indices.is_complex()
  1640. and not indices.is_floating_point()
  1641. and not indices.dtype == torch.bool,
  1642. lambda: f"expected 'indices' to be integer dtype, but got {indices.dtype}",
  1643. )
  1644. torch._check_type(
  1645. isinstance(shape, (int, torch.SymInt, Sequence)),
  1646. lambda: f"expected 'shape' to be int or sequence of ints, but got {type(shape)}",
  1647. )
  1648. if isinstance(shape, (int, torch.SymInt)):
  1649. shape = torch.Size([shape])
  1650. else:
  1651. for dim in shape:
  1652. torch._check_type(
  1653. isinstance(dim, (int, torch.SymInt)),
  1654. lambda: f"expected 'shape' sequence to only contain ints, but got {type(dim)}",
  1655. )
  1656. shape = torch.Size(shape)
  1657. torch._check_value(
  1658. all(dim >= 0 for dim in shape),
  1659. lambda: f"'shape' cannot have negative values, but got {tuple(shape)}",
  1660. )
  1661. coefs = list(
  1662. reversed(
  1663. list(
  1664. itertools.accumulate(
  1665. reversed(shape[1:] + torch.Size([1])), func=operator.mul
  1666. )
  1667. )
  1668. )
  1669. )
  1670. return indices.unsqueeze(-1).floor_divide(
  1671. torch.tensor(coefs, device=indices.device, dtype=torch.int64)
  1672. ) % torch.tensor(shape, device=indices.device, dtype=torch.int64)
  1673. def chain_matmul(*matrices, out=None):
  1674. r"""Returns the matrix product of the :math:`N` 2-D tensors. This product is efficiently computed
  1675. using the matrix chain order algorithm which selects the order in which incurs the lowest cost in terms
  1676. of arithmetic operations (`[CLRS]`_). Note that since this is a function to compute the product, :math:`N`
  1677. needs to be greater than or equal to 2; if equal to 2 then a trivial matrix-matrix product is returned.
  1678. If :math:`N` is 1, then this is a no-op - the original matrix is returned as is.
  1679. .. warning::
  1680. :func:`torch.chain_matmul` is deprecated and will be removed in a future PyTorch release.
  1681. Use :func:`torch.linalg.multi_dot` instead, which accepts a list of two or more tensors
  1682. rather than multiple arguments.
  1683. Args:
  1684. matrices (Tensors...): a sequence of 2 or more 2-D tensors whose product is to be determined.
  1685. out (Tensor, optional): the output tensor. Ignored if :attr:`out` = ``None``.
  1686. Returns:
  1687. Tensor: if the :math:`i^{th}` tensor was of dimensions :math:`p_{i} \times p_{i + 1}`, then the product
  1688. would be of dimensions :math:`p_{1} \times p_{N + 1}`.
  1689. Example::
  1690. >>> # xdoctest: +SKIP
  1691. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  1692. >>> a = torch.randn(3, 4)
  1693. >>> b = torch.randn(4, 5)
  1694. >>> c = torch.randn(5, 6)
  1695. >>> d = torch.randn(6, 7)
  1696. >>> # will raise a deprecation warning
  1697. >>> torch.chain_matmul(a, b, c, d)
  1698. tensor([[ -2.3375, -3.9790, -4.1119, -6.6577, 9.5609, -11.5095, -3.2614],
  1699. [ 21.4038, 3.3378, -8.4982, -5.2457, -10.2561, -2.4684, 2.7163],
  1700. [ -0.9647, -5.8917, -2.3213, -5.2284, 12.8615, -12.2816, -2.5095]])
  1701. .. _`[CLRS]`: https://mitpress.mit.edu/books/introduction-algorithms-third-edition
  1702. """
  1703. # This wrapper exists to support variadic args.
  1704. if has_torch_function(matrices):
  1705. return handle_torch_function(chain_matmul, matrices, *matrices)
  1706. if out is None:
  1707. return _VF.chain_matmul(matrices) # type: ignore[attr-defined]
  1708. else:
  1709. return _VF.chain_matmul(matrices, out=out) # type: ignore[attr-defined]
  1710. def _lu_impl(A, pivot=True, get_infos=False, out=None):
  1711. # type: (Tensor, bool, bool, Any) -> tuple[Tensor, Tensor, Tensor]
  1712. r"""Computes the LU factorization of a matrix or batches of matrices
  1713. :attr:`A`. Returns a tuple containing the LU factorization and
  1714. pivots of :attr:`A`. Pivoting is done if :attr:`pivot` is set to
  1715. ``True``.
  1716. .. warning::
  1717. :func:`torch.lu` is deprecated in favor of :func:`torch.linalg.lu_factor`
  1718. and :func:`torch.linalg.lu_factor_ex`. :func:`torch.lu` will be removed in a
  1719. future PyTorch release.
  1720. ``LU, pivots, info = torch.lu(A, compute_pivots)`` should be replaced with
  1721. .. code:: python
  1722. LU, pivots = torch.linalg.lu_factor(A, compute_pivots)
  1723. ``LU, pivots, info = torch.lu(A, compute_pivots, get_infos=True)`` should be replaced with
  1724. .. code:: python
  1725. LU, pivots, info = torch.linalg.lu_factor_ex(A, compute_pivots)
  1726. .. note::
  1727. * The returned permutation matrix for every matrix in the batch is
  1728. represented by a 1-indexed vector of size ``min(A.shape[-2], A.shape[-1])``.
  1729. ``pivots[i] == j`` represents that in the ``i``-th step of the algorithm,
  1730. the ``i``-th row was permuted with the ``j-1``-th row.
  1731. * LU factorization with :attr:`pivot` = ``False`` is not available
  1732. for CPU, and attempting to do so will throw an error. However,
  1733. LU factorization with :attr:`pivot` = ``False`` is available for
  1734. CUDA.
  1735. * This function does not check if the factorization was successful
  1736. or not if :attr:`get_infos` is ``True`` since the status of the
  1737. factorization is present in the third element of the return tuple.
  1738. * In the case of batches of square matrices with size less or equal
  1739. to 32 on a CUDA device, the LU factorization is repeated for
  1740. singular matrices due to the bug in the MAGMA library
  1741. (see magma issue 13).
  1742. * ``L``, ``U``, and ``P`` can be derived using :func:`torch.lu_unpack`.
  1743. .. warning::
  1744. The gradients of this function will only be finite when :attr:`A` is full rank.
  1745. This is because the LU decomposition is just differentiable at full rank matrices.
  1746. Furthermore, if :attr:`A` is close to not being full rank,
  1747. the gradient will be numerically unstable as it depends on the computation of :math:`L^{-1}` and :math:`U^{-1}`.
  1748. Args:
  1749. A (Tensor): the tensor to factor of size :math:`(*, m, n)`
  1750. pivot (bool, optional): Whether to compute the LU decomposition with partial pivoting, or the regular LU
  1751. decomposition. :attr:`pivot`\ `= False` not supported on CPU. Default: `True`.
  1752. get_infos (bool, optional): if set to ``True``, returns an info IntTensor.
  1753. Default: ``False``
  1754. out (tuple, optional): optional output tuple. If :attr:`get_infos` is ``True``,
  1755. then the elements in the tuple are Tensor, IntTensor,
  1756. and IntTensor. If :attr:`get_infos` is ``False``, then the
  1757. elements in the tuple are Tensor, IntTensor. Default: ``None``
  1758. Returns:
  1759. (Tensor, IntTensor, IntTensor (optional)): A tuple of tensors containing
  1760. - **factorization** (*Tensor*): the factorization of size :math:`(*, m, n)`
  1761. - **pivots** (*IntTensor*): the pivots of size :math:`(*, \text{min}(m, n))`.
  1762. ``pivots`` stores all the intermediate transpositions of rows.
  1763. The final permutation ``perm`` could be reconstructed by
  1764. applying ``swap(perm[i], perm[pivots[i] - 1])`` for ``i = 0, ..., pivots.size(-1) - 1``,
  1765. where ``perm`` is initially the identity permutation of :math:`m` elements
  1766. (essentially this is what :func:`torch.lu_unpack` is doing).
  1767. - **infos** (*IntTensor*, *optional*): if :attr:`get_infos` is ``True``, this is a tensor of
  1768. size :math:`(*)` where non-zero values indicate whether factorization for the matrix or
  1769. each minibatch has succeeded or failed
  1770. Example::
  1771. >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
  1772. >>> # xdoctest: +IGNORE_WANT("non-deterministic")
  1773. >>> A = torch.randn(2, 3, 3)
  1774. >>> A_LU, pivots = torch.lu(A)
  1775. >>> A_LU
  1776. tensor([[[ 1.3506, 2.5558, -0.0816],
  1777. [ 0.1684, 1.1551, 0.1940],
  1778. [ 0.1193, 0.6189, -0.5497]],
  1779. [[ 0.4526, 1.2526, -0.3285],
  1780. [-0.7988, 0.7175, -0.9701],
  1781. [ 0.2634, -0.9255, -0.3459]]])
  1782. >>> pivots
  1783. tensor([[ 3, 3, 3],
  1784. [ 3, 3, 3]], dtype=torch.int32)
  1785. >>> A_LU, pivots, info = torch.lu(A, get_infos=True)
  1786. >>> if info.nonzero().size(0) == 0:
  1787. ... print('LU factorization succeeded for all samples!')
  1788. LU factorization succeeded for all samples!
  1789. """
  1790. # If get_infos is True, then we don't need to check for errors and vice versa
  1791. return torch._lu_with_info(A, pivot=pivot, check_errors=(not get_infos))
# Annotation alias for the ``out`` argument of ``_check_list_size``. Static type
# checkers see the more permissive ``Sequence[Tensor]``, while at runtime the
# alias is the concrete ``list[Tensor]``. NOTE(review): presumably the runtime
# form is concrete so the JIT can resolve it; confirm before simplifying.
if TYPE_CHECKING:
    _ListOrSeq = Sequence[Tensor]
else:
    _ListOrSeq = list[Tensor]
  1796. def _check_list_size(out_len: int, get_infos: bool, out: _ListOrSeq) -> None:
  1797. get_infos_int = 1 if get_infos else 0
  1798. if out_len - get_infos_int != 2:
  1799. raise TypeError(
  1800. f"expected tuple of {2 + int(get_infos)} elements but got {out_len}"
  1801. )
  1802. if not isinstance(out, (tuple, list)):
  1803. raise TypeError(
  1804. f"argument 'out' must be tuple of Tensors, not {type(out).__name__}"
  1805. )
  1806. def _lu_with_infos(A, pivot=True, get_infos=False, out=None):
  1807. # type: (Tensor, bool, bool, Optional[tuple[Tensor, Tensor, Tensor]]) -> tuple[Tensor, Tensor, Tensor]
  1808. if has_torch_function_unary(A):
  1809. return handle_torch_function(
  1810. lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
  1811. )
  1812. result = _lu_impl(A, pivot, get_infos, out)
  1813. if out is not None:
  1814. _check_list_size(len(out), get_infos, out)
  1815. for i in range(len(out)):
  1816. out[i].resize_as_(result[i]).copy_(result[i])
  1817. return out
  1818. else:
  1819. return result # A_LU, pivots, infos
  1820. def _lu_no_infos(A, pivot=True, get_infos=False, out=None):
  1821. # type: (Tensor, bool, bool, Optional[tuple[Tensor, Tensor]]) -> tuple[Tensor, Tensor]
  1822. # need to check for torch_function here so that we exit if
  1823. if has_torch_function_unary(A):
  1824. return handle_torch_function(
  1825. lu, (A,), A, pivot=pivot, get_infos=get_infos, out=out
  1826. )
  1827. result = _lu_impl(A, pivot, get_infos, out)
  1828. if out is not None:
  1829. _check_list_size(len(out), get_infos, out)
  1830. for i in range(len(out)):
  1831. out[i].resize_as_(result[i]).copy_(result[i])
  1832. return out
  1833. else:
  1834. return result[0], result[1] # A_LU, pivots
  1835. # The return type of lu depends on `get_infos`, so in order to resolve the output type
  1836. # of lu in TorchScript we need to statically know the value of `get_infos`
  1837. lu = boolean_dispatch(
  1838. arg_name="get_infos",
  1839. arg_index=2,
  1840. default=False,
  1841. if_true=_lu_with_infos,
  1842. if_false=_lu_no_infos,
  1843. module_name=__name__,
  1844. func_name="lu",
  1845. )
  1846. lu.__doc__ = _lu_impl.__doc__
  1847. def align_tensors(*tensors):
  1848. raise RuntimeError("`align_tensors` not yet implemented.")