_ops.py 65 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811
  1. # mypy: allow-untyped-defs
  2. import warnings
  3. from typing import Any, Callable, Optional, TYPE_CHECKING, TypeVar, Union
  4. from typing_extensions import ParamSpec, TypeAlias
  5. import torch
  6. from torch import sym_float, Tensor
  7. from torch._prims_common import corresponding_real_dtype
  8. from torch.masked import _docs
  9. from torch.masked.maskedtensor.core import is_masked_tensor, MaskedTensor
  10. from torch.masked.maskedtensor.creation import as_masked_tensor
# Type aliases used in annotations throughout this module. Static type
# checkers see the precise types; at runtime simpler stand-ins are used.
if TYPE_CHECKING:
    from torch._prims_common import DimsType
    from torch.types import _dtype as DType

    DimOrDims: TypeAlias = Optional[DimsType]
else:
    # The JIT doesn't understand Union, nor torch.dtype here
    DType = int
    DimOrDims = Optional[tuple[int, ...]]

# Names of public functions; populated by the _apply_docstring_templates
# decorator each time it is applied to a function.
__all__: list[str] = []

# Generic return type and parameter specification used to type decorators
# so that they preserve the signature of the decorated function.
_T = TypeVar("_T")
_P = ParamSpec("_P")
  22. # All masked reduction/normalization operations have the same
  23. # signatures. Here we introduce docstring templates that are applied
  24. # to docstrings of reduction/normalization functions via
  25. # _apply_docstring_templates decorator.
  26. def _apply_docstring_templates(func: Callable[_P, _T]) -> Callable[_P, _T]:
  27. """Decorator that applies docstring templates to function docstring
  28. and returns the function instance.
  29. """
  30. doc_string = getattr(_docs, f"{func.__name__}_docstring", None)
  31. if doc_string is None:
  32. warnings.warn(
  33. f"No documentation string available for {func.__name__}."
  34. " PyTorch team should run `python tools/update_masked_docs.py`"
  35. " to generate the missing docstrings."
  36. )
  37. else:
  38. func.__doc__ = doc_string
  39. # Expose function as public symbol
  40. __all__.append(func.__name__)
  41. return func
def _generate_docstring(func):
    """A utility function called from tools/update_masked_docs.py
    script to update the module torch.masked._docs.py

    The documentation text is assembled from the string templates defined
    below. ``func.__name__`` selects the argument lists, the human-readable
    operation name and the set of documentation sections. ``func`` itself
    is called once on a small example input (with an example mask) so that
    the "Example" section shows a real output.

    Args:
        func: a masked reduction or normalization function. Its
            ``__name__`` must be a key of ``args_and_kwargs`` and of either
            ``reduction_names`` or ``normalization_names`` below.

    Returns:
        The formatted documentation string. When ``func.__doc__`` is set it
        is used as the template to be formatted; otherwise a default
        template made of the sections for the operation kind is used.

    Raises:
        KeyError: if ``func.__name__`` is not a key of ``args_and_kwargs``.
        AssertionError: if ``func.__name__`` is in neither of the operation
            name dictionaries.
    """
    # Templates keyed by "<op kind>_<section>". Placeholders in braces are
    # filled in from template_data further below.
    docstring_templates = dict(
        reduction_signature="""\
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
        reduction_descr="""\
Returns {operation name} of all the elements in the :attr:`input`
tensor along the given dimension(s) :attr:`dim` while the :attr:`input`
elements are masked out according to the boolean tensor
:attr:`mask`.""",
        reduction_args="""\
If :attr:`keepdim` is ``True``, the output tensor is of the same size
as :attr:`input` except in the dimension(s) :attr:`dim` where it is of
size 1. Otherwise, :attr:`dim` is squeezed (see
:func:`torch.squeeze`), resulting in the output tensor having 1 (or
``len(dim)``) fewer dimension(s).
The boolean tensor :attr:`mask` defines the "validity" of
:attr:`input` tensor elements: if :attr:`mask` element is True
then the corresponding element in :attr:`input` tensor will be
included in {operation name} computation, otherwise the element is
ignored.
When all elements of :attr:`input` along the given dimension
:attr:`dim` are ignored (fully masked-out), the corresponding element
of the output tensor will have undefined value: it may or may not
correspond to the identity value of {operation name} operation; the
choice may correspond to the value that leads to the most efficient
storage of :attr:`output` tensor.
The mask of the output tensor can be computed as
``torch.any(torch.broadcast_to(mask, input.shape), dim, keepdim=keepdim,
dtype=torch.bool)``.
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
tensor must not be greater than of the :attr:`input` tensor.
Args:
input (Tensor): the input tensor
{args_declarations}
Keyword args:
{kwargs_declarations}""",
        reduction_example="""\
Example::
>>> input = {example_input}
>>> input
{indent_example_input}
>>> mask = {example_mask}
>>> mask
{indent_example_mask}
>>> {full_function_name}(input, {example_args}, mask=mask)
{indent_example_output}
""",
        reduction_identity="""\
The identity value of {operation name} operation, which is used to start the reduction, is ``{identity_int32}``.""",
        reduction_identity_dtype="""\
The identity value of {operation name} operation, which is used to start the
reduction, depends on input dtype. For instance, for float32, uint8,
and int32 dtypes, the identity values are ``{identity_float32}``, ``{identity_uint8}``, and ``{identity_int32}``, respectively.""",
        normalization_signature="""\
{function_name}(input, {operation_args}, *, {operation_kwargs}) -> Tensor""",
        normalization_descr="""\
Returns {operation name} of all the slices in the :attr:`input` tensor
along :attr:`dim` while the :attr:`input` elements are masked out
according to the boolean tensor :attr:`mask`.
{definition}""",
        normalization_args="""\
The boolean tensor :attr:`mask` defines the "validity" of
:attr:`input` tensor elements: if :attr:`mask` element is True then
the corresponding element in :attr:`input` tensor will be included in
{operation name} computation, otherwise the element is ignored.
The values of masked-out elements of the output tensor have undefined
value: it may or may not be set to zero or nan; the choice may correspond to
the value that leads to the most efficient storage of :attr:`output`
tensor.
The mask of the {operation name} output tensor can be computed as
``torch.broadcast_to(mask, input.shape)``.
The shapes of the :attr:`mask` tensor and the :attr:`input` tensor
don't need to match, but they must be :ref:`broadcastable
<broadcasting-semantics>` and the dimensionality of the :attr:`mask`
tensor must not be greater than of the :attr:`input` tensor.
Args:
input (Tensor): the input tensor
{args_declarations}
Keyword args:
{kwargs_declarations}""",
        normalization_example="""\
Example::
>>> input = {example_input}
>>> input
{indent_example_input}
>>> mask = {example_mask}
>>> mask
{indent_example_mask}
>>> {full_function_name}(input, {example_args}, mask=mask)
{indent_example_output}
""",
    )

    # Maps a function name to a pair (positional argument names,
    # keyword arguments given as "name=default" strings).
    args_and_kwargs = {
        # argument name suffixes separated by double underscore will
        # be removed in the final documentation string.
        "sum": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        "prod": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        "cumsum": (("dim__as_int",), ("dtype=None", "mask=None")),
        "cumprod": (("dim__as_int",), ("dtype=None", "mask=None")),
        "amin": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        "amax": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        "argmin": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
        "argmax": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
        "mean": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        "median": (("dim__as_int",), ("keepdim=False", "dtype=None", "mask=None")),
        "norm": (
            (
                "ord",
                "dim",
            ),
            ("keepdim=False", "dtype=None", "mask=None"),
        ),
        "var": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
        "std": (("dim", "unbiased"), ("keepdim=False", "dtype=None", "mask=None")),
        "logsumexp": (("dim",), ("keepdim=False", "dtype=None", "mask=None")),
        "softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
        "log_softmax": (("dim__as_int",), ("dtype=None", "mask=None")),
        "softmin": (("dim__as_int",), ("dtype=None", "mask=None")),
        "normalize": (
            (
                "ord__required",
                "dim__as_int",
            ),
            ("eps=1e-12", "dtype=None", "mask=None"),
        ),
    }

    # Per-argument documentation text, keyed by the (possibly suffixed)
    # argument name used in args_and_kwargs.
    argument_declarations = {
        "dim": """\
dim (int or tuple of ints, optional): the dimension or dimensions to reduce.
Default: None that is equivalent to ``tuple(range(input.ndim))``.""",
        "dim__as_int": """\
dim (int): the dimension along which {operation name} is computed.""",
        "ord": """\
ord (int, float, optional): the order of vector norm. Default: 2.
See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
        # NOTE(review): the "__required" suffix suggests the argument has no
        # default, yet the text below still says "Default: 2." -- confirm
        # against the signature of the documented function.
        "ord__required": """\
ord (int, float): the order of vector norm. Default: 2.
See :func:`torch.linalg.vector_norm` for a list of supported norms.""",
        "unbiased": """\
unbiased (bool): when True, use Bessel's correction, otherwise, compute
the uncorrected sample variance.""",
        "eps": """\
eps (float, optional): small value to avoid division by zero. Default: {default}.""",
        "keepdim": """\
keepdim (bool, optional): whether the output tensor has
:attr:`dim` retained or not. Default: {default}.""",
        "dtype": """\
dtype (:class:`torch.dtype`, optional): the desired data type
of returned tensor. If specified, the input tensor is
casted to :attr:`dtype` before the operation is
performed. Default: {default}.""",
        "mask": """\
mask (:class:`torch.Tensor`, optional): the boolean tensor
containing the binary mask of validity of input tensor
elements.
Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.""",
    }

    # Mathematical definitions substituted for {definition} in the
    # normalization description template.
    definitions = {
        "softmax": """\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Softmax of i-th element in ``x`` is
defined as ``exp(x[i])/sum(exp(x))``.""",
        "log_softmax": """\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. LogSoftmax of i-th element in ``x`` is
defined as ``log(exp(x[i])/sum(exp(x)))``.""",
        "softmin": """\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Softmin of i-th element in ``x`` is
defined as ``exp(-x[i])/sum(exp(-x))``.""",
        "normalize": """\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Normalize of i-th element in ``x`` is
defined as ``x[i]/max(norm(x, p), eps)``.""",
        "cumsum": """\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
defined as ``sum(x[:i])``.""",
        # NOTE(review): the text below says "Cumsum" although this is the
        # cumprod entry -- looks like a copy-paste slip; confirm and fix
        # together with the generated torch.masked._docs module.
        "cumprod": """\
Let ``x`` be a sequence of unmasked elements of one-dimensional slice
of the :attr:`input` tensor. Cumsum of i-th element in ``x`` is
defined as ``prod(x[:i])``.""",
    }

    # Human-readable operation names substituted for {operation name}.
    reduction_names = {
        "sum": "sum",
        "prod": "product",
        "amax": "maximum",
        "amin": "minimum",
        "argmax": "argmax",
        "argmin": "argmin",
        "mean": "mean",
        "median": "median",
        "norm": "norm",
        "var": "variance",
        "std": "standard_deviation",
        "logsumexp": "logsumexp",
    }

    normalization_names = {
        "softmax": "softmax",
        "log_softmax": "log_softmax",
        "softmin": "softmin",
        "normalize": "normalize",
        "cumsum": "cumulative_sum",
        "cumprod": "cumulative_prod",
    }

    operation_names = {}
    operation_names.update(reduction_names)
    operation_names.update(normalization_names)

    # Default example data:
    example_dim = 1
    example_input = torch.tensor([[-3, -2, -1], [0, 1, 2]])
    example_mask = torch.tensor([[True, False, True], [False, False, False]])
    example_args: tuple[Any, ...]
    # Pick the positional example arguments (and, for some operations, a
    # float32 example input) according to the function being documented.
    if func.__name__ in {"norm", "normalize"}:
        example_args = (2.0, example_dim)
        example_input = example_input.to(dtype=torch.float32)
    elif func.__name__ in {"var", "std"}:
        example_args = (example_dim, False)
    elif func.__name__ == "median":
        example_args = (example_dim,)
        example_input = example_input.to(dtype=torch.float32)
    else:
        example_args = (example_dim,)

    operation_args: tuple[str, ...]
    operation_kwargs: tuple[str, ...]
    operation_args, operation_kwargs = args_and_kwargs[func.__name__]
    # Documentation of positional arguments; unknown names fall back to a
    # "<name>: TBD." placeholder.
    arg_declarations = [
        "\n ".join(
            argument_declarations.get(a, f"{a.split('__', 1)[0]}: TBD.").splitlines()
        )
        for a in operation_args
    ]
    # Documentation of keyword arguments; the text after "=" in each entry
    # is substituted for the {default} placeholder.
    kwarg_declarations = [
        "\n ".join(
            argument_declarations.get(
                a.split("=", 1)[0], f"{a.split('__', 1)[0]}: TBD."
            )
            .format(default=a.split("=", 1)[1])
            .splitlines()
        )
        for a in operation_kwargs
    ]

    if func.__name__ in reduction_names:
        op_kind = "reduction"
        doc_sections = ["signature", "descr", "identity", "args", "example"]
    elif func.__name__ in normalization_names:
        op_kind = "normalization"
        doc_sections = ["signature", "descr", "args", "example"]
        example_input = example_input.to(dtype=torch.float32)
    else:
        assert 0  # add function name to operation names dictionaries
    # Run the documented function on the example data to obtain the output
    # shown in the "Example" section.
    example_output = func(example_input, *example_args, mask=example_mask)

    template_data = {
        "function_name": func.__name__,
        "full_function_name": func.__module__ + "." + func.__name__,
        "operation name": operation_names[func.__name__],
        "operation_args": ", ".join(a.split("__", 1)[0] for a in operation_args),
        "operation_kwargs": ", ".join(a.split("__", 1)[0] for a in operation_kwargs),
        # one-line representation of a tensor:
        "example_input": " ".join(str(example_input).split()),
        "example_args": ", ".join(map(str, example_args)),
        "example_mask": " ".join(str(example_mask).split()),
        # multi-line representation of a tensor with indent
        "indent_example_input": ("\n ").join(str(example_input).splitlines()),
        "indent_example_mask": ("\n ").join(str(example_mask).splitlines()),
        "indent_example_output": ("\n ").join(str(example_output).splitlines()),
    }

    if func.__name__ in reduction_names:
        # Identity values for a few dtypes, referenced by the
        # reduction_identity* templates.
        template_data.update(
            identity_uint8=_reduction_identity(
                func.__name__, torch.tensor(0, dtype=torch.uint8)
            ),
            identity_int32=_reduction_identity(
                func.__name__, torch.tensor(0, dtype=torch.int32)
            ),
            identity_float32=_reduction_identity(
                func.__name__, torch.tensor(0, dtype=torch.float32)
            ),
        )
        if func.__name__ == "norm":
            template_data.update(
                identity_ord_ninf=_reduction_identity(
                    func.__name__, torch.tensor(0, dtype=torch.float32), float("-inf")
                )
            )
    elif func.__name__ in normalization_names:
        template_data.update(definition=definitions[func.__name__])
    else:
        assert 0  # add function name to operation names dictionaries
    # The argument declarations may themselves contain placeholders (e.g.
    # {operation name}), hence they are formatted with template_data.
    template_data.update(
        args_declarations=("\n ".join(arg_declarations)).format_map(template_data)
    )
    template_data.update(
        kwargs_declarations=("\n ".join(kwarg_declarations)).format_map(
            template_data
        )
    )

    # Apply function name info to docstring templates:
    templates = {
        k: v.format_map(template_data)
        for k, v in docstring_templates.items()
        if k.startswith(op_kind)
    }
    # Make the plain template data available as well; string values are
    # formatted once more, non-string values are passed through unchanged.
    templates.update(
        (k, v.format_map(template_data) if isinstance(v, str) else v)
        for k, v in template_data.items()
    )

    # Apply docstring templates to function docstring:
    if func.__doc__ is None:
        doc_template = "\n\n".join([f"{{{op_kind}_{sec}}}" for sec in doc_sections])
    else:
        doc_template = func.__doc__
    return doc_template.format_map(templates)
  360. def _reduction_identity(op_name: str, input: Tensor, *args):
  361. """Return identity value as scalar tensor of a reduction operation on
  362. given input, or None, if the identity value cannot be uniquely
  363. defined for the given input.
  364. The identity value of the operation is defined as the initial
  365. value to reduction operation that has a property ``op(op_identity,
  366. value) == value`` for any value in the domain of the operation.
  367. Or put it another way, including or excluding the identity value in
  368. a list of operands will not change the reduction result.
  369. See https://github.com/pytorch/rfcs/pull/27 for more information.
  370. """
  371. dtype: DType = input.dtype
  372. device = input.device
  373. op_name = op_name.rsplit(".", 1)[-1] # lstrip module name when present
  374. if op_name in {"sum", "cumsum"}:
  375. return torch.tensor(0, dtype=dtype, device=device)
  376. elif op_name in {"prod", "cumprod"}:
  377. return torch.tensor(1, dtype=dtype, device=device)
  378. elif op_name in {"amax", "argmax", "logaddexp"}:
  379. if torch.is_floating_point(input):
  380. return torch.tensor(-torch.inf, dtype=dtype, device=device)
  381. elif torch.is_signed(input) or dtype == torch.uint8:
  382. return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
  383. elif op_name in {"logsumexp"}:
  384. if torch.is_floating_point(input):
  385. return torch.tensor(-torch.inf, dtype=dtype, device=device)
  386. elif torch.is_complex(input):
  387. return torch.tensor(-torch.inf + 0j, dtype=dtype, device=device)
  388. elif torch.is_signed(input) or dtype == torch.uint8:
  389. return torch.tensor(torch.iinfo(dtype).min, dtype=dtype, device=device)
  390. elif op_name in {"amin", "argmin"}:
  391. if torch.is_floating_point(input):
  392. return torch.tensor(torch.inf, dtype=dtype, device=device)
  393. elif torch.is_signed(input) or dtype == torch.uint8:
  394. return torch.tensor(torch.iinfo(dtype).max, dtype=dtype, device=device)
  395. elif op_name == "mean":
  396. # Strictly speaking, the identity value of the mean operation
  397. # is the mean of the input. Since the mean value depends on
  398. # the dim argument and it may be a non-scalar tensor, we
  399. # consider the identity value of the mean operation ambiguous.
  400. # Moreover, the mean value of empty input is undefined.
  401. return None
  402. elif op_name == "norm":
  403. ord = args[0] if args else 2
  404. if ord == float("-inf"):
  405. assert torch.is_floating_point(input), input.dtype
  406. return torch.tensor(torch.inf, dtype=dtype, device=device)
  407. return torch.tensor(0, dtype=dtype, device=device)
  408. elif op_name == "median":
  409. # We use NaN for now because the implementation is currently using torch.nanmedian
  410. # and NaN is the identity for that function since it gets ignored
  411. dtype = input.dtype if torch.is_floating_point(input) else torch.float
  412. return torch.tensor(torch.nan, dtype=dtype, device=device)
  413. elif op_name in {"var", "std"}:
  414. return None
  415. raise NotImplementedError(f"identity of {op_name} on {dtype} input")
  416. def _canonical_dim(dim: DimOrDims, ndim: int) -> tuple[int, ...]:
  417. """Return dim argument as a tuple of sorted dim values."""
  418. dims: list[int] = []
  419. if dim == ():
  420. # Currently, `dim=()` in reductions operations means "reduce
  421. # over all dimensions" while in future, it will read "no
  422. # reduce". See https://github.com/pytorch/pytorch/issues/29137
  423. # When gh-29137 is resolved, this if-block must be deleted.
  424. dim = None
  425. if dim is None:
  426. return tuple(range(ndim))
  427. ndim = max(ndim, 1)
  428. dim_ = (dim,) if isinstance(dim, (int, torch.SymInt)) else dim
  429. for d in dim_:
  430. if d in dims:
  431. raise RuntimeError(f"dim={d} appears multiple times in the list of dims")
  432. if d >= ndim or d < -ndim:
  433. raise IndexError(
  434. f"Dimension out of range (expected to be in range of [{-ndim}, {ndim - 1}], but got {d})"
  435. )
  436. dims.append(d % ndim)
  437. return tuple(sorted(dims))
  438. def _sparse_coo_flatten_indices(indices: Tensor, shape: tuple):
  439. # Flatted N-D indices to 1-D indices
  440. flat_indices = indices.new_zeros(indices.size(1))
  441. for d, sz in enumerate(shape):
  442. flat_indices.mul_(sz)
  443. flat_indices.add_(indices[d])
  444. return flat_indices
  445. def _any(input: Tensor, dim: tuple, keepdim: bool):
  446. # Support torch.any with tuple dim argument.
  447. # Workaround of https://github.com/pytorch/pytorch/issues/56586
  448. r = input
  449. for d in reversed(dim):
  450. r = r.any(dim=d, keepdim=keepdim)
  451. return r
def _sparse_coo_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
    """Sparse variant of torch.where. Supports sparse COO and hybrid sparse COO tensors.

    _sparse_coo_where implements the following invariant:

      _sparse_coo_where(mask, input, fill_value).to_dense(fill_value) ==
        torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))

    where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
    tensor, and `to_dense(fill_value)` is like `to_dense()` except
    that the unspecified elements are mapped to `fill_value` rather
    than to `0`.

    Args:
        mask: sparse COO tensor with the same shape and the same number of
            dense dimensions as ``input`` (both are asserted below).
        input: sparse COO tensor; it is coalesced inside this function.
        fill_value: value used for masked-out elements in the dense part
            of the result values.

    Returns a sparse COO tensor with the following features:

    - all specified elements correspond to masked-in elements that
      have the values of the input tensor. If there exists a masked-in
      element (as specified by mask) that is not specified in the
      input, in the result tensor, the corresponding element has value
      0. In the dense part of the sparse tensor, the masked-out
      elements are replaced with fill_value.

    - all unspecified elements correspond to masked-out elements.
    """
    assert input.layout == torch.sparse_coo
    assert mask.layout == input.layout
    assert mask.shape == input.shape
    assert mask.dense_dim() == input.dense_dim()  # TODO: eliminate this restriction

    input = input.coalesce()
    # NOTE(review): mask is not coalesced here although mask.indices() and
    # mask.values() are called below -- presumably callers pass a
    # coalesced mask; confirm.

    # For set operations on sparse tensor indices, we'll convert
    # multi-dimensional indices to 1-D indices for efficiency.
    input_flat_indices = _sparse_coo_flatten_indices(
        input.indices(), input.shape[: input.sparse_dim()]
    )
    mask_flat_indices = _sparse_coo_flatten_indices(
        mask.indices(), mask.shape[: mask.sparse_dim()]
    )

    # the set of mask flat indices that define masked-in elements:
    if mask.dense_dim() > 0:
        # A specified mask element counts as masked-in when any item of
        # its dense part is set.
        # NOTE(review): the reduced dims are derived from
        # input.sparse_dim(); one would expect the dense dims of
        # mask.values() here -- confirm this is intended when the number
        # of sparse and dense dimensions differ.
        mask_values = _any(
            mask.values(), tuple(range(1, input.sparse_dim() + 1)), False
        )
    else:
        mask_values = mask.values()
    maskin_flat_indices = mask_flat_indices[mask_values.nonzero()[:, 0]]

    def intersection(i1, i2):
        # Returns the unique values of the concatenation of i1 and i2
        # together with the positions (within that unique array) of the
        # values that occur more than once, i.e. values present in both
        # inputs provided neither input contains duplicates.
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return union, torch.where(counts.gt(1))

    def minus(i1, i2):
        # Values occurring exactly once in the concatenation belong to
        # only one of the inputs; intersecting those with i1 keeps the
        # ones that come from i1. Same return convention as intersection.
        union, counts = torch.cat([i1, i2]).unique(return_counts=True)
        return intersection(union[torch.where(counts.eq(1))], i1)

    def _apply(a):
        # Resolve the (array, positions) pair returned by intersection or
        # minus into the selected values.
        obj, w = a
        return obj[w]

    # the set of input flat indices of specified and masked-in elements:
    maskin_input_flat_indices = _apply(
        intersection(maskin_flat_indices, input_flat_indices)
    )
    # w holds positions within the unique array built by intersection; it
    # is used below to index the columns of input.indices() -- presumably
    # this relies on input_flat_indices being ordered and duplicate-free
    # (input is coalesced) so that both orderings agree.
    _, w = intersection(input_flat_indices, maskin_input_flat_indices)

    # the indices and values of masked-in elements
    where_input_indices = input.indices()[(slice(None),) + w]
    where_input_values = input.values()[w]

    if mask.dense_dim() > 0:
        # apply mask to the dense part of the input values:
        _, w1 = intersection(mask_flat_indices, maskin_input_flat_indices)
        where_mask_values = mask.values()[w1]
        where_input_values = torch.where(
            where_mask_values, where_input_values, fill_value
        )

    # the set of flat indices of unspecified input and masked-in elements:
    maskin_zero_flat_indices = _apply(
        minus(maskin_flat_indices, maskin_input_flat_indices)
    )

    # the indices of masked-in zero elements
    _, w = intersection(mask_flat_indices, maskin_zero_flat_indices)
    where_zero_indices = mask.indices()[(slice(None),) + w]

    # construct result
    n = where_zero_indices.size(1)
    if n == 0:
        # the input is coalesced, hence input_flat_indices are ordered
        # and the result is guaranteed to be coalesced:
        result = torch.sparse_coo_tensor(
            where_input_indices, where_input_values, input.shape
        )
        return result._coalesced_(True)

    # Append explicit zeros for the masked-in elements that the input does
    # not specify.
    where_indices = torch.cat([where_input_indices, where_zero_indices], dim=1)
    where_values = torch.cat(
        [
            where_input_values,
            where_input_values.new_zeros((n,) + where_input_values.shape[1:]),
        ]
    )
    result = torch.sparse_coo_tensor(where_indices, where_values, input.shape)

    # appending zero elements leads to uncoalesced sparse tensor
    return result.coalesce()
  541. def _sparse_coo_scatter_reduction_helper(
  542. op,
  543. mask_input: Tensor,
  544. dims: tuple[int, ...],
  545. keepdim: bool,
  546. dtype: Optional[DType] = None,
  547. ) -> Tensor:
  548. reduce = op.__name__
  549. valid_reductions = ["sum", "prod", "amax", "amin"]
  550. if reduce not in valid_reductions:
  551. raise ValueError(
  552. f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
  553. )
  554. output_dtype = dtype
  555. values, indices = mask_input._values(), mask_input._indices()
  556. input_dims = mask_input.dim()
  557. num_sparse_dims = mask_input.sparse_dim()
  558. reduced_sparse_dims = []
  559. retained_sparse_dims = []
  560. reduced_dense_dims = []
  561. # promote dtype if specified
  562. if values.dtype != output_dtype:
  563. values = values.to(output_dtype)
  564. if keepdim:
  565. output_shape = tuple(
  566. 1 if i in dims else si for (i, si) in enumerate(mask_input.shape)
  567. )
  568. else:
  569. output_shape = tuple(
  570. si for (i, si) in enumerate(mask_input.shape) if i not in dims
  571. )
  572. for d in dims:
  573. if d >= input_dims:
  574. continue
  575. if d < num_sparse_dims:
  576. reduced_sparse_dims.append(d)
  577. else:
  578. reduced_dense_dims.append(d + 1 - num_sparse_dims)
  579. # Reduce dense dimensions
  580. if len(reduced_dense_dims) > 0:
  581. if reduce == "sum":
  582. new_values = values
  583. new_values = op(new_values, dim=reduced_dense_dims, keepdim=bool(keepdim))
  584. else:
  585. # FIXME: Implement reductions for dense dimensions for ops with non-zero reduction identities
  586. return NotImplemented
  587. else:
  588. new_values = values.clone()
  589. # Reduce sparse dimensions
  590. if len(reduced_sparse_dims) == num_sparse_dims:
  591. if reduce in {"amax", "amin"} and new_values.size(0) == 0:
  592. # IndexError: amax(): Expected reduction dim 0 to have non-zero size.
  593. # sum()/prod() return the reduction identity when dim has size 0 but amax()/amin() do not
  594. # See https://github.com/pytorch/pytorch/issues/61901
  595. new_values = _reduction_identity(reduce, new_values)
  596. else:
  597. new_values = op(new_values, dim=0)
  598. if keepdim:
  599. for _ in range(num_sparse_dims):
  600. new_values = new_values.unsqueeze(0)
  601. return new_values.to(dtype=output_dtype).to_sparse()
  602. else:
  603. new_indices = indices.clone()
  604. if keepdim:
  605. # zero out reduced sparse dimensions if keepdim = True
  606. # ensures that the call to torch.unique folds duplicated indices together while preserving the dimension
  607. new_indices[reduced_sparse_dims, :] = 0
  608. else:
  609. # remove reduced sparse dimensions if keepdim = False
  610. if len(reduced_sparse_dims) > 0:
  611. retained_sparse_dims = [
  612. i
  613. for i in range(num_sparse_dims)
  614. if i not in set(reduced_sparse_dims)
  615. ]
  616. new_indices = new_indices.index_select(
  617. 0, torch.tensor(retained_sparse_dims).to(mask_input.device)
  618. )
  619. # Use scatter_reduce to reduce items in the new_values tensor that correspond to the same indices in new_indices
  620. if new_indices.numel() > 0:
  621. # lexsort indices and get index tensor for scatter reduction
  622. new_indices, inverse_indices = torch.unique(
  623. new_indices, return_inverse=True, dim=1
  624. )
  625. out_shape = list(new_values.shape)
  626. out_shape[0] = new_indices.shape[1]
  627. for _ in range(new_values.ndim - 1):
  628. inverse_indices = inverse_indices.unsqueeze(-1)
  629. scatter_indices = inverse_indices.expand(new_values.shape)
  630. # FIXME: temporary workaround for issue with bfloat16/float16 remove when acctype is implemented for scatter_reduce
  631. if output_dtype in {torch.bfloat16, torch.float16}:
  632. new_values = new_values.to(torch.float)
  633. out = new_values.new_empty(out_shape)
  634. new_values = out.scatter_reduce_(
  635. 0, scatter_indices, new_values, reduce=reduce, include_self=False
  636. )
  637. new_values = new_values.to(dtype=output_dtype)
  638. else:
  639. out = new_values.new_empty(out_shape)
  640. new_values = out.scatter_reduce_(
  641. 0, scatter_indices, new_values, reduce=reduce, include_self=False
  642. )
  643. return torch.sparse_coo_tensor(
  644. new_indices,
  645. new_values,
  646. output_shape,
  647. dtype=output_dtype,
  648. device=mask_input.device,
  649. )
  650. def _sparse_csr_segment_reduction_helper(
  651. op,
  652. mask_input: Tensor,
  653. dims: tuple[int, ...],
  654. keepdim: bool,
  655. dtype: Optional[DType] = None,
  656. ) -> Tensor:
  657. # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True
  658. # FIXME: when dense dimensions are implemented for CSR tensors
  659. assert keepdim, (
  660. "reduction operations on CSR tensors with keepdim=False is unsupported"
  661. )
  662. reduce = op.__name__
  663. valid_reductions = ["sum", "prod", "mean", "amax", "amin"]
  664. if reduce not in valid_reductions:
  665. raise ValueError(
  666. f"op must be one of {' '.join(valid_reductions)}, but got {reduce} instead"
  667. )
  668. device = mask_input.device
  669. output_dtype = dtype
  670. values, crow_indices, col_indices = (
  671. mask_input.values(),
  672. mask_input.crow_indices(),
  673. mask_input.col_indices(),
  674. )
  675. # promote dtype if specified
  676. if values.dtype != output_dtype:
  677. values = values.to(output_dtype)
  678. if len(dims) == 0:
  679. return mask_input
  680. if len(dims) == 1:
  681. if dims[0] == 0:
  682. new_col_indices, scatter_indices = torch.unique(
  683. col_indices, return_inverse=True
  684. )
  685. new_nnz = new_col_indices.shape[0]
  686. new_crow_indices = torch.tensor([0, new_nnz])
  687. new_values = values.new_empty(new_col_indices.shape)
  688. new_values.scatter_reduce_(
  689. 0, scatter_indices, values, reduce, include_self=False
  690. )
  691. new_shape = [1, mask_input.size(1)]
  692. else:
  693. assert dims[0] == 1, (
  694. "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1."
  695. )
  696. # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1
  697. # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0
  698. new_crow_indices = torch.cat(
  699. (
  700. crow_indices.new_zeros(1),
  701. torch.cumsum(torch.diff(crow_indices) != 0, 0),
  702. ),
  703. 0,
  704. )
  705. new_nnz = new_crow_indices[-1]
  706. new_col_indices = col_indices.new_zeros(new_nnz) # type: ignore[call-overload]
  707. new_values = torch._segment_reduce(values, reduce, offsets=crow_indices) # type: ignore[attr-defined]
  708. new_shape = [mask_input.size(0), 1]
  709. else:
  710. assert len(dims) == 2
  711. nnz = min(1, values.numel())
  712. if nnz == 1:
  713. op_kwargs = {"keepdim": True, "dtype": output_dtype}
  714. # amax and amin do not support dtype kwarg
  715. if reduce in ["amax", "amin"]:
  716. del op_kwargs["dtype"]
  717. new_values = op(values, 0, **op_kwargs)
  718. else:
  719. new_values = torch.empty(0, dtype=output_dtype)
  720. new_col_indices = col_indices.new_zeros(nnz)
  721. new_crow_indices = torch.tensor([0, nnz])
  722. new_shape = [1, nnz]
  723. return torch.sparse_csr_tensor(
  724. new_crow_indices,
  725. new_col_indices,
  726. new_values,
  727. new_shape,
  728. dtype=output_dtype,
  729. device=device,
  730. )
  731. def _sparse_csr_where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
  732. """Sparse variant of torch.where. Supports sparse CSR tensors."""
  733. # TODO: implement sparse CSR specific where operator for efficiency
  734. return _sparse_coo_where(
  735. mask.to_sparse_coo(), input.to_sparse_coo(), fill_value
  736. ).to_sparse_csr()
  737. def _where(mask: Tensor, input: Tensor, fill_value: Tensor) -> Tensor:
  738. """torch.where with sparse inputs support.
  739. _where implements the following invariant:
  740. _where(mask, input, fill_value).to_dense(fill_value) ==
  741. torch.where(mask.to_dense(), input.to_dense(), torch.full(input.shape, fill_value))
  742. where `a == b` means `assertEqual(a, b)`, mask is boolean sparse
  743. tensor, and `to_dense(fill_value)` is like `to_dense()` except
  744. that the unspecified elements are mapped to `fill_value` rather
  745. than to `0`.
  746. Returns a sparse tensor with the following features:
  747. - all specified elements correspond to masked-in elements that
  748. have the values of the input tensor. If there exists a masked-in
  749. element (as specified by mask) that is not specified in the
  750. input, in the result tensor, the corresponding element has value
  751. 0. In the dense part of the sparse tensor, the masked-out
  752. elements are replaced with fill_value.
  753. - all unspecified elements correspond to masked-out elements.
  754. """
  755. if mask.layout == torch.strided:
  756. return torch.where(mask, input, fill_value)
  757. elif mask.layout == torch.sparse_coo:
  758. return _sparse_coo_where(mask, input, fill_value)
  759. elif mask.layout == torch.sparse_csr:
  760. return _sparse_csr_where(mask, input, fill_value)
  761. else:
  762. raise ValueError(
  763. f"_where expects strided or sparse COO or sparse CSR tensor but got {mask.layout}"
  764. )
  765. def _input_mask(input: Union[Tensor, MaskedTensor], *args, **kwargs) -> Tensor:
  766. """Return canonical input mask.
  767. A canonical input mask is defined as a boolean mask tensor that
  768. shape and layout matches with the shape and the layout of the
  769. input.
  770. The canonical input mask is computed from the :attr:`mask` tensor
  771. content to meet the following criteria:
  772. 1. The shape of the canonical input mask is the same as the shape
  773. of :attr:`input` tensor. If the mask tensor has a smaller shape
  774. than the shape of the :attr:`input`, broadcasting rules will be
  775. applied. Downcasting of mask is not supported.
  776. 2. The layout of the canonical input mask is the same as the
  777. layout of the :attr:`input` tensor. If the mask has different
  778. layout, it will be converted to the expected layout. In the
  779. case of sparse COO layout, the canonical input mask will be
  780. coalesced.
  781. 3. The dtype of the canonical input mask is torch.bool. If the
  782. mask dtype is not bool then it will be converted to bool dtype
  783. using `.to(dtype=bool)` method call.
  784. 4. The elements of the canonical input mask have boolean values
  785. copied from the content of the :attr:`mask` tensor (after
  786. possible broadcasting and dtype conversion transforms). In
  787. general, the sparsity pattern of the sparse canonical input
  788. mask need not to be the same as the sparsity pattern of the
  789. sparse :attr:`input` tensor.
  790. """
  791. if input.layout not in {torch.strided, torch.sparse_coo, torch.sparse_csr}:
  792. raise ValueError(
  793. f"_input_mask expects strided or sparse COO or sparse CSR tensor but got {input.layout}"
  794. )
  795. mask = kwargs.get("mask")
  796. # default mask
  797. if mask is None:
  798. raise ValueError("_input_mask requires explicit mask")
  799. # mask shape must match with input shape
  800. if mask.shape != input.shape:
  801. if mask.ndim > input.ndim:
  802. raise IndexError(
  803. "_input_mask expected broadcastable mask (got mask dimensionality higher than of the input)"
  804. )
  805. if mask.layout == torch.strided:
  806. mask = torch.broadcast_to(mask.clone(), input.shape).to(dtype=torch.bool)
  807. elif mask.layout == torch.sparse_coo:
  808. mask = torch._sparse_broadcast_to(mask, input.shape)
  809. else:
  810. assert mask.layout == torch.sparse_csr
  811. # Broadcasting of CSR tensors is not implemented. Working
  812. # around by using COO layout.
  813. mask = torch._sparse_broadcast_to(
  814. mask.to_sparse(), input.shape
  815. ).to_sparse_csr()
  816. # mask layout must match with input layout
  817. if mask.layout != input.layout:
  818. if input.layout == torch.strided:
  819. mask = mask.to_dense()
  820. elif input.layout == torch.sparse_coo:
  821. if mask.layout == torch.strided:
  822. mask = mask.to_sparse(input.sparse_dim())
  823. else:
  824. mask = mask.to_sparse()
  825. else:
  826. assert input.layout == torch.sparse_csr
  827. mask = mask.to_sparse_csr()
  828. # sparse mask must be coalesced
  829. if mask.layout == torch.sparse_coo:
  830. mask = mask.coalesce()
  831. # mask is a boolean tensor
  832. mask = mask.to(dtype=torch.bool)
  833. return mask
  834. def _output_mask(op, input: Tensor, *args, **kwargs) -> Tensor:
  835. """Return output mask of masked operation applied to given arguments."""
  836. if callable(op):
  837. is_reduction = op.__name__ in {
  838. "sum",
  839. "prod",
  840. "amax",
  841. "amin",
  842. "argmax",
  843. "argmin",
  844. "mean",
  845. "median",
  846. "norm",
  847. "var",
  848. "std",
  849. "logsumexp",
  850. }
  851. is_normalization = op.__name__ in {
  852. "softmax",
  853. "log_softmax",
  854. "softmin",
  855. "normalize",
  856. "cumsum",
  857. "cumprod",
  858. }
  859. if is_reduction:
  860. if op.__name__ == "norm":
  861. if args:
  862. args = args[1:] # lstrip ord argument
  863. dim = args[0] if args else kwargs.get("dim")
  864. outmask = _input_mask(input, *args, **kwargs)
  865. keepdim = kwargs.get("keepdim", False)
  866. dim_ = _canonical_dim(dim, input.ndim)
  867. return _any(outmask, dim_, bool(keepdim))
  868. elif is_normalization:
  869. return _input_mask(input, *args, **kwargs)
  870. else:
  871. raise ValueError(
  872. f"_output_mask expected masked operation (got callable {op.__module__}.{op.__name__})"
  873. )
  874. else:
  875. raise ValueError(
  876. f"_output_mask expected masked operation (got {type(op).__name__} object)"
  877. )
  878. def _combine_input_and_mask(
  879. op, input: Union[MaskedTensor, Tensor], mask, *args
  880. ) -> Tensor:
  881. def helper(input, mask):
  882. if mask is None:
  883. return input
  884. canonical_mask = _input_mask(input, mask=mask)
  885. if callable(op):
  886. fill_value = _reduction_identity(op.__name__, input, *args)
  887. return _where(canonical_mask, input, fill_value)
  888. else:
  889. raise ValueError(
  890. f"_combine_input_and_mask expected masked operation (got {type(op).__name__} object)"
  891. )
  892. class Combine(torch.autograd.Function):
  893. @staticmethod
  894. def forward(ctx, input, mask):
  895. """Return input with masked-out elements eliminated for the given operations."""
  896. ctx.save_for_backward(mask)
  897. if mask is not None:
  898. ctx.mark_non_differentiable(mask)
  899. return helper(input, mask)
  900. @staticmethod
  901. def backward(ctx, grad_output):
  902. (mask,) = ctx.saved_tensors
  903. grad_data = (
  904. grad_output.get_data() if is_masked_tensor(grad_output) else grad_output
  905. )
  906. result = as_masked_tensor(grad_data, mask)
  907. return result, None
  908. return (
  909. Combine.apply(input.get_data(), input.get_mask()) # type: ignore[union-attr]
  910. if is_masked_tensor(input)
  911. else helper(input, mask)
  912. )
  913. @_apply_docstring_templates
  914. def sum(
  915. input: Union[Tensor, MaskedTensor],
  916. dim: DimOrDims = None,
  917. *,
  918. keepdim: Optional[bool] = False,
  919. dtype: Optional[DType] = None,
  920. mask: Optional[Tensor] = None,
  921. ) -> Tensor:
  922. # __doc__ is generated by _apply_docstring_templates decorator
  923. if dtype is None:
  924. # promote integer types to int64 when output dtype is not specified
  925. if input.layout == torch.sparse_csr:
  926. if input.dtype in {
  927. torch.uint8,
  928. torch.bool,
  929. torch.int8,
  930. torch.int16,
  931. torch.int32,
  932. }:
  933. # csr.to(dtype=torch.int64) is not implemented, so
  934. # using coo.to on input to ensure the promoted dtype
  935. input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
  936. else:
  937. dtype = input.dtype
  938. else:
  939. dtype = input.dtype
  940. if input.dtype in {
  941. torch.uint8,
  942. torch.bool,
  943. torch.int8,
  944. torch.int16,
  945. torch.int32,
  946. }:
  947. dtype = torch.int64
  948. dim_ = _canonical_dim(dim, input.ndim)
  949. mask_input = _combine_input_and_mask(sum, input, mask)
  950. if mask_input.layout == torch.strided:
  951. return torch.sum(mask_input, dim_, bool(keepdim), dtype=dtype)
  952. elif mask_input.layout == torch.sparse_coo:
  953. return _sparse_coo_scatter_reduction_helper(
  954. torch.sum, mask_input, dim_, bool(keepdim), dtype
  955. )
  956. elif mask_input.layout == torch.sparse_csr:
  957. return torch._sparse_csr_sum(
  958. mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
  959. )
  960. else:
  961. raise ValueError(
  962. f"masked sum expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
  963. )
  964. @_apply_docstring_templates
  965. def prod(
  966. input: Union[Tensor, MaskedTensor],
  967. dim: DimOrDims = None,
  968. *,
  969. keepdim: Optional[bool] = False,
  970. dtype: Optional[DType] = None,
  971. mask: Optional[Tensor] = None,
  972. ) -> Tensor:
  973. # __doc__ is generated by _apply_docstring_templates decorator
  974. if dtype is None:
  975. # promote integer types to int64 when output dtype is not specified
  976. if input.layout == torch.sparse_csr:
  977. if input.dtype in {
  978. torch.uint8,
  979. torch.bool,
  980. torch.int8,
  981. torch.int16,
  982. torch.int32,
  983. }:
  984. # csr.to(dtype=torch.int64) is not implemented, so
  985. # using coo.to on input to ensure the promoted dtype
  986. input = input.to_sparse_coo().to(dtype=torch.int64).to_sparse_csr()
  987. else:
  988. dtype = input.dtype
  989. else:
  990. dtype = input.dtype
  991. if input.dtype in {
  992. torch.uint8,
  993. torch.bool,
  994. torch.int8,
  995. torch.int16,
  996. torch.int32,
  997. }:
  998. dtype = torch.int64
  999. dim_ = _canonical_dim(dim, input.ndim)
  1000. mask_input = _combine_input_and_mask(prod, input, mask)
  1001. if mask_input.layout == torch.strided:
  1002. # Workaround https://github.com/pytorch/pytorch/issues/56586
  1003. result = mask_input
  1004. result = result.to(dtype=dtype)
  1005. for d in reversed(dim_):
  1006. result = result.prod(dim=d, keepdim=bool(keepdim))
  1007. return result
  1008. elif mask_input.layout == torch.sparse_coo:
  1009. if mask is None:
  1010. # See comment in the sparse_csr branch, the same issue arises for sparse_coo tensors
  1011. raise ValueError(
  1012. "masked prod expects explicit mask for sparse_coo tensor input"
  1013. )
  1014. return _sparse_coo_scatter_reduction_helper(
  1015. torch.prod, mask_input, dim_, bool(keepdim), dtype
  1016. )
  1017. elif mask_input.layout == torch.sparse_csr:
  1018. if mask is None:
  1019. # mask is None corresponds to all-True mask. The
  1020. # unspecified elements in the CSR tensor correspond to
  1021. # zero values. Hence, the prod reduction result is
  1022. # automatically zero unless all elements are specified.
  1023. # A semi-optimal way to take this into account is to use:
  1024. #
  1025. # masked_prod(csr, ..., mask=None) == torch._sparse_csr_prod(csr, ...) * all(csr.nonzero(), ...)
  1026. #
  1027. # but that requires implementing `all` and `nonzero`
  1028. # support for sparse csr tensors.
  1029. raise ValueError(
  1030. "masked prod expects explicit mask for sparse_csr tensor input"
  1031. )
  1032. return torch._sparse_csr_prod(
  1033. mask_input, dim=list(dim_), keepdim=bool(keepdim), dtype=dtype
  1034. )
  1035. else:
  1036. raise ValueError(
  1037. f"masked prod expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
  1038. )
  1039. @_apply_docstring_templates
  1040. def cumsum(
  1041. input: Tensor,
  1042. dim: int,
  1043. *,
  1044. dtype: Optional[DType] = None,
  1045. mask: Optional[Tensor] = None,
  1046. ) -> Tensor:
  1047. if dtype is None:
  1048. dtype = input.dtype
  1049. dim_ = _canonical_dim(dim, input.ndim)[0]
  1050. mask_input = _combine_input_and_mask(sum, input, mask)
  1051. if mask_input.layout == torch.strided:
  1052. return torch.cumsum(mask_input, dim_, dtype=dtype).to(dtype=dtype)
  1053. else:
  1054. raise ValueError(
  1055. f"masked cumsum expects strided tensor (got {mask_input.layout} tensor)"
  1056. )
  1057. @_apply_docstring_templates
  1058. def cumprod(
  1059. input: Tensor,
  1060. dim: int,
  1061. *,
  1062. dtype: Optional[DType] = None,
  1063. mask: Optional[Tensor] = None,
  1064. ) -> Tensor:
  1065. if dtype is None:
  1066. dtype = input.dtype
  1067. dim_ = _canonical_dim(dim, input.ndim)[0]
  1068. mask_input = _combine_input_and_mask(prod, input, mask)
  1069. if mask_input.layout == torch.strided:
  1070. return torch.cumprod(mask_input, dim_, dtype=dtype).to(dtype=dtype)
  1071. else:
  1072. raise ValueError(
  1073. f"masked cumprod expects strided tensor (got {mask_input.layout} tensor)"
  1074. )
  1075. @_apply_docstring_templates
  1076. def amax(
  1077. input: Union[Tensor, MaskedTensor],
  1078. dim: DimOrDims = None,
  1079. *,
  1080. keepdim: Optional[bool] = False,
  1081. dtype: Optional[DType] = None,
  1082. mask: Optional[Tensor] = None,
  1083. ) -> Tensor:
  1084. """\
  1085. {reduction_signature}
  1086. {reduction_descr}
  1087. {reduction_identity_dtype}
  1088. {reduction_args}
  1089. {reduction_example}"""
  1090. if dtype is None:
  1091. dtype = input.dtype
  1092. mask_input = _combine_input_and_mask(amax, input, mask)
  1093. dim_ = _canonical_dim(dim, mask_input.ndim)
  1094. if mask_input.layout == torch.strided:
  1095. return torch.amax(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
  1096. elif mask_input.layout == torch.sparse_coo:
  1097. if mask is None:
  1098. # See comment in the sparse_csr branch of prod, a similar issue arises here
  1099. # where unspecified elements along a dimension may need to be reduced with the result
  1100. raise ValueError(
  1101. "masked amax expects explicit mask for sparse_coo tensor input"
  1102. )
  1103. return _sparse_coo_scatter_reduction_helper(
  1104. torch.amax, mask_input, dim_, bool(keepdim), dtype
  1105. )
  1106. elif mask_input.layout == torch.sparse_csr:
  1107. if mask is None:
  1108. raise ValueError(
  1109. "masked amax expects explicit mask for sparse_csr tensor input"
  1110. )
  1111. return _sparse_csr_segment_reduction_helper(
  1112. torch.amax, mask_input, dim_, bool(keepdim), dtype
  1113. )
  1114. else:
  1115. raise ValueError(
  1116. f"masked amax expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
  1117. )
  1118. @_apply_docstring_templates
  1119. def amin(
  1120. input: Union[Tensor, MaskedTensor],
  1121. dim: DimOrDims = None,
  1122. *,
  1123. keepdim: Optional[bool] = False,
  1124. dtype: Optional[DType] = None,
  1125. mask: Optional[Tensor] = None,
  1126. ) -> Tensor:
  1127. """\
  1128. {reduction_signature}
  1129. {reduction_descr}
  1130. {reduction_identity_dtype}
  1131. {reduction_args}
  1132. {reduction_example}"""
  1133. if dtype is None:
  1134. dtype = input.dtype
  1135. mask_input = _combine_input_and_mask(amin, input, mask)
  1136. dim_ = _canonical_dim(dim, mask_input.ndim)
  1137. if mask_input.layout == torch.strided:
  1138. return torch.amin(mask_input, dim_, bool(keepdim)).to(dtype=dtype)
  1139. elif mask_input.layout == torch.sparse_coo:
  1140. if mask is None:
  1141. # See comment in the sparse_csr branch of prod, a similar issue arises here
  1142. # where unspecified elements along a dimension may need to be reduced with the result
  1143. raise ValueError(
  1144. "masked amax expects explicit mask for sparse_coo tensor input"
  1145. )
  1146. return _sparse_coo_scatter_reduction_helper(
  1147. torch.amin, mask_input, dim_, bool(keepdim), dtype
  1148. )
  1149. elif mask_input.layout == torch.sparse_csr:
  1150. if mask is None:
  1151. raise ValueError(
  1152. "masked amin expects explicit mask for sparse_csr tensor input"
  1153. )
  1154. return _sparse_csr_segment_reduction_helper(
  1155. torch.amin, mask_input, dim_, bool(keepdim), dtype
  1156. )
  1157. else:
  1158. raise ValueError(
  1159. f"masked amin expects strided, sparse_coo or sparse_csr tensor (got {mask_input.layout} tensor)"
  1160. )
  1161. @_apply_docstring_templates
  1162. def argmax(
  1163. input: Union[Tensor, MaskedTensor],
  1164. dim: Optional[int] = None,
  1165. *,
  1166. keepdim: Optional[bool] = False,
  1167. dtype: Optional[DType] = None,
  1168. mask: Optional[Tensor] = None,
  1169. ) -> Tensor:
  1170. """\
  1171. {reduction_signature}
  1172. {reduction_descr}
  1173. {reduction_identity_dtype}
  1174. {reduction_args}
  1175. {reduction_example}"""
  1176. if dtype is None:
  1177. dtype = input.dtype
  1178. mask_input = _combine_input_and_mask(argmax, input, mask)
  1179. if mask_input.layout == torch.strided:
  1180. return torch.argmax(mask_input, dim, bool(keepdim)).to(dtype=dtype)
  1181. else:
  1182. raise ValueError(
  1183. f"masked argmax expects strided tensor (got {mask_input.layout} tensor)"
  1184. )
  1185. @_apply_docstring_templates
  1186. def argmin(
  1187. input: Union[Tensor, MaskedTensor],
  1188. dim: Optional[int] = None,
  1189. *,
  1190. keepdim: Optional[bool] = False,
  1191. dtype: Optional[DType] = None,
  1192. mask: Optional[Tensor] = None,
  1193. ) -> Tensor:
  1194. """\
  1195. {reduction_signature}
  1196. {reduction_descr}
  1197. {reduction_identity_dtype}
  1198. {reduction_args}
  1199. {reduction_example}"""
  1200. if dtype is None:
  1201. dtype = input.dtype
  1202. mask_input = _combine_input_and_mask(argmin, input, mask)
  1203. if mask_input.layout == torch.strided:
  1204. return torch.argmin(mask_input, dim, bool(keepdim)).to(dtype=dtype)
  1205. else:
  1206. raise ValueError(
  1207. f"masked argmin expects strided tensor (got {mask_input.layout} tensor)"
  1208. )
  1209. @_apply_docstring_templates
  1210. def mean(
  1211. input: Union[Tensor, MaskedTensor],
  1212. dim: DimOrDims = None,
  1213. *,
  1214. keepdim: Optional[bool] = False,
  1215. dtype: Optional[DType] = None,
  1216. mask: Optional[Tensor] = None,
  1217. ) -> Tensor:
  1218. """\
  1219. {reduction_signature}
  1220. {reduction_descr}
  1221. By definition, the identity value of a mean operation is the mean
  1222. value of the tensor. If all elements of the input tensor along given
  1223. dimension(s) :attr:`dim` are masked-out, the identity value of the
  1224. mean is undefined. Due to this ambiguity, the elements of output
  1225. tensor with strided layout, that correspond to fully masked-out
  1226. elements, have ``nan`` values.
  1227. {reduction_args}
  1228. {reduction_example}"""
  1229. dtype_source = "Optional"
  1230. if dtype is None:
  1231. dtype = input.dtype
  1232. dtype_source = "Input"
  1233. if not (dtype.is_floating_point or dtype.is_complex):
  1234. raise ValueError(
  1235. f"mean(): Could not infer output dtype. {dtype_source} dtype must be either "
  1236. f"a floating point or complex dtype. Got: {dtype}"
  1237. )
  1238. if input.layout == torch.strided:
  1239. if mask is None:
  1240. # TODO: compute count analytically
  1241. count = sum(
  1242. torch.ones(input.shape, dtype=torch.int64, device=input.device),
  1243. dim,
  1244. keepdim=keepdim,
  1245. )
  1246. total = sum(input, dim, keepdim=keepdim, dtype=dtype)
  1247. else:
  1248. inmask = _input_mask(input, mask=mask)
  1249. count = inmask.sum(dim=dim, keepdim=bool(keepdim))
  1250. total = sum(input, dim, keepdim=keepdim, dtype=dtype, mask=inmask)
  1251. return total / count
  1252. elif input.layout == torch.sparse_csr:
  1253. mask_input = _combine_input_and_mask(mean, input, mask)
  1254. dim_ = _canonical_dim(dim, mask_input.ndim)
  1255. if mask is None:
  1256. raise ValueError(
  1257. "masked mean expects explicit mask for sparse_csr tensor input"
  1258. )
  1259. return _sparse_csr_segment_reduction_helper(
  1260. torch.mean, mask_input, dim_, bool(keepdim), dtype
  1261. )
  1262. else:
  1263. raise ValueError(
  1264. f"masked mean expects strided or sparse_csr tensor (got {input.layout} tensor)"
  1265. )
  1266. @_apply_docstring_templates
  1267. def median(
  1268. input: Union[Tensor, MaskedTensor],
  1269. dim: int = -1,
  1270. *,
  1271. keepdim: bool = False,
  1272. dtype: Optional[DType] = None,
  1273. mask: Optional[Tensor] = None,
  1274. ) -> Tensor:
  1275. """\
  1276. {reduction_signature}
  1277. {reduction_descr}
  1278. By definition, the identity value of a median operation is the median
  1279. value of the tensor. If all elements of the input tensor along given
  1280. dimension(s) :attr:`dim` are masked-out, the identity value of the
  1281. median is undefined. Due to this ambiguity, the elements of output
  1282. tensor with strided layout, that correspond to fully masked-out
  1283. elements, have ``nan`` values.
  1284. {reduction_args}
  1285. {reduction_example}"""
  1286. if dtype is None:
  1287. dtype = input.dtype
  1288. dim_ = _canonical_dim(dim, input.ndim)[0]
  1289. is_float = torch.is_floating_point(input)
  1290. if not is_float:
  1291. input = input.to(dtype=torch.float)
  1292. mask_input = _combine_input_and_mask(median, input, mask)
  1293. if mask_input.layout == torch.strided:
  1294. output = torch.nanmedian(mask_input, dim_, keepdim).values
  1295. if is_float:
  1296. return output
  1297. elif not is_float and not torch.isnan(output).any():
  1298. return output.to(dtype=dtype)
  1299. else:
  1300. raise ValueError(
  1301. "masked median expects no fully masked out rows if dtype is not floating point"
  1302. )
  1303. else:
  1304. raise ValueError(
  1305. f"masked median expects strided tensor (got {mask_input.layout} tensor)"
  1306. )
  1307. @_apply_docstring_templates
  1308. def logsumexp(
  1309. input: Tensor,
  1310. dim: DimOrDims = None,
  1311. *,
  1312. keepdim: bool = False,
  1313. dtype: Optional[DType] = None,
  1314. mask: Optional[Tensor] = None,
  1315. ) -> Tensor:
  1316. if dtype is None:
  1317. dtype = input.dtype
  1318. dim_ = _canonical_dim(dim, input.ndim)
  1319. mask_input = _combine_input_and_mask(logsumexp, input, mask)
  1320. if mask_input.layout == torch.strided:
  1321. return torch.logsumexp(mask_input, dim_, keepdim=keepdim).to(dtype=dtype)
  1322. else:
  1323. raise ValueError(
  1324. f"masked logsumexp expects strided tensor (got {mask_input.layout} tensor)"
  1325. )
  1326. # Cannot use _apply_docstring_templates as it is only set up for reductions and normalizations
  1327. def logaddexp(
  1328. input: Union[Tensor, MaskedTensor],
  1329. other: Union[Tensor, MaskedTensor],
  1330. *,
  1331. dtype: Optional[DType] = None,
  1332. input_mask: Optional[Tensor] = None,
  1333. other_mask: Optional[Tensor] = None,
  1334. ) -> Tensor:
  1335. """logaddexp(input, other, *, dtype=None, input_mask=None, other_mask=None) -> Tensor
  1336. Returns logaddexp of all the elements in the :attr:`input` and the :attr:`other`
  1337. tensor. The :attr:`input` elements are masked out according to the boolean tensor
  1338. :attr:`input_mask` and the attr:`other` elements are masked out according to the boolean tensor
  1339. :attr:`other_mask`.
  1340. The shapes of a mask tensor and the tensor to be masked
  1341. don't need to match, but they must be :ref:`broadcastable
  1342. <broadcasting-semantics>` and the dimensionality of the mask
  1343. tensor must not be greater than of the tensor to be masked.
  1344. Args:
  1345. input (Tensor): the input tensor
  1346. other (Tensor): the second input tensor
  1347. Keyword args:
  1348. dtype (:class:`torch.dtype`, optional): the desired data type
  1349. of returned tensor. If specified, the output tensor is
  1350. casted to :attr:`dtype` after the operation is
  1351. performed. Default: None.
  1352. input_mask (:class:`torch.Tensor`, optional): the boolean tensor
  1353. containing the binary mask of validity of :attr:`input` tensor elements.
  1354. Default: None that is equivalent to ``torch.ones(input.shape, dtype=torch.bool)``.
  1355. other_mask (:class:`torch.Tensor`, optional): the boolean tensor
  1356. containing the binary mask of validity of :attr:`other` tensor elements.
  1357. Default: None that is equivalent to ``torch.ones(other.shape, dtype=torch.bool)``.
  1358. Example::
  1359. >>> input = torch.tensor([-100.0, -200, -300])
  1360. >>> input
  1361. tensor([-100., -200., -300.])
  1362. >>> other = torch.tensor([-1.0, -2, -3])
  1363. >>> other
  1364. tensor([-1., -2., -3.])
  1365. >>> mask = torch.tensor([True, False, True])
  1366. >>> mask
  1367. tensor([ True, False, True])
  1368. >>> torch.masked._ops.logaddexp(input, other, input_mask=mask, other_mask=mask)
  1369. tensor([-1., -inf, -3.])"""
  1370. if dtype is None:
  1371. dtype = input.dtype
  1372. if input.layout == torch.strided and other.layout == torch.strided:
  1373. mask_input = _combine_input_and_mask(logaddexp, input, input_mask)
  1374. mask_other = _combine_input_and_mask(logaddexp, other, other_mask)
  1375. return torch.logaddexp(mask_input, mask_other).to(dtype=dtype)
  1376. else:
  1377. raise ValueError(
  1378. f"masked logaddexp expects strided tensors (got {input.layout} tensor for input, {other.layout} for other)"
  1379. )
  1380. @_apply_docstring_templates
  1381. def norm(
  1382. input: Union[Tensor, MaskedTensor],
  1383. ord: Optional[float] = 2.0,
  1384. dim: DimOrDims = None,
  1385. *,
  1386. keepdim: Optional[bool] = False,
  1387. dtype: Optional[DType] = None,
  1388. mask: Optional[Tensor] = None,
  1389. ) -> Tensor:
  1390. """\
  1391. {reduction_signature}
  1392. {reduction_descr}
  1393. The identity value of norm operation, which is used to start the
  1394. reduction, is ``{identity_float32}``, except for ``ord=-inf`` it is
  1395. ``{identity_ord_ninf}``.
  1396. {reduction_args}
  1397. {reduction_example}"""
  1398. if dtype is None:
  1399. dtype = input.dtype
  1400. mask_input = _combine_input_and_mask(norm, input, mask, ord)
  1401. if mask_input.layout == torch.strided:
  1402. dim_ = _canonical_dim(dim, input.ndim)
  1403. return torch.linalg.vector_norm(
  1404. mask_input, ord, dim_, bool(keepdim), dtype=dtype
  1405. )
  1406. else:
  1407. raise ValueError(
  1408. f"masked norm expects strided tensor (got {mask_input.layout} tensor)"
  1409. )
  1410. def _std_var(
  1411. input: Union[Tensor, MaskedTensor],
  1412. dim: DimOrDims,
  1413. unbiased: Optional[bool],
  1414. *,
  1415. correction_opt: Optional[Union[int, float]],
  1416. keepdim: Optional[bool],
  1417. dtype: Optional[DType],
  1418. mask: Optional[Tensor],
  1419. take_sqrt: Optional[bool],
  1420. ) -> Tensor:
  1421. assert unbiased is None or correction_opt is None, (
  1422. "Only one of unbiased and correction may be given"
  1423. )
  1424. correction = 1.0
  1425. if unbiased is not None:
  1426. correction = 1.0 if unbiased else 0.0
  1427. if correction_opt is not None:
  1428. correction = sym_float(correction_opt)
  1429. if dtype is None:
  1430. dtype = input.dtype
  1431. if not (dtype.is_floating_point or dtype.is_complex):
  1432. dtype = torch.float32
  1433. compute_dtype = dtype
  1434. if not (compute_dtype.is_floating_point or compute_dtype.is_complex):
  1435. compute_dtype = torch.float32
  1436. if input.layout == torch.strided:
  1437. if mask is None:
  1438. # TODO: compute count analytically
  1439. count = sum(
  1440. torch.ones(input.shape, dtype=torch.int64, device=input.device),
  1441. dim,
  1442. keepdim=True,
  1443. )
  1444. sample_total = sum(input, dim, keepdim=True, dtype=dtype)
  1445. else:
  1446. inmask = _input_mask(input, mask=mask)
  1447. count = inmask.sum(dim=dim, keepdim=True)
  1448. sample_total = sum(input, dim, keepdim=True, dtype=dtype, mask=inmask)
  1449. # TODO: replace torch.subtract/divide/square/maximum with
  1450. # masked subtract/divide/square/maximum when these will be
  1451. # available.
  1452. sample_mean = torch.divide(sample_total, count)
  1453. x = torch.subtract(input, sample_mean)
  1454. if mask is None:
  1455. total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype)
  1456. else:
  1457. total = sum(
  1458. x * x.conj(),
  1459. dim,
  1460. keepdim=keepdim,
  1461. dtype=compute_dtype,
  1462. mask=inmask, # type: ignore[possibly-undefined]
  1463. )
  1464. if not keepdim:
  1465. count = count.reshape(total.shape)
  1466. if correction != 0:
  1467. real_dtype = (
  1468. corresponding_real_dtype(compute_dtype)
  1469. if compute_dtype.is_complex
  1470. else compute_dtype
  1471. )
  1472. count = count.to(real_dtype)
  1473. count = torch.subtract(count, correction)
  1474. count = torch.maximum(count, count.new_zeros([]))
  1475. output = torch.divide(total, count).to(dtype=dtype)
  1476. if take_sqrt:
  1477. output = torch.sqrt(output)
  1478. return output
  1479. else:
  1480. raise ValueError(
  1481. f"masked std/var expects strided tensor (got {input.layout} tensor)"
  1482. )
  1483. @_apply_docstring_templates
  1484. def var(
  1485. input: Union[Tensor, MaskedTensor],
  1486. dim: DimOrDims = None,
  1487. unbiased: Optional[bool] = None,
  1488. *,
  1489. correction: Optional[Union[int, float]] = None,
  1490. keepdim: Optional[bool] = False,
  1491. dtype: Optional[DType] = None,
  1492. mask: Optional[Tensor] = None,
  1493. ) -> Tensor:
  1494. """\
  1495. {reduction_signature}
  1496. {reduction_descr}
  1497. The identity value of sample variance operation is undefined. The
  1498. elements of output tensor with strided layout, that correspond to
  1499. fully masked-out elements, have ``nan`` values.
  1500. {reduction_args}
  1501. {reduction_example}"""
  1502. return _std_var(
  1503. input=input,
  1504. dim=dim,
  1505. unbiased=unbiased,
  1506. correction_opt=correction,
  1507. keepdim=keepdim,
  1508. dtype=dtype,
  1509. mask=mask,
  1510. take_sqrt=False,
  1511. )
  1512. @_apply_docstring_templates
  1513. def std(
  1514. input: Union[Tensor, MaskedTensor],
  1515. dim: DimOrDims = None,
  1516. unbiased: Optional[bool] = None,
  1517. *,
  1518. correction: Optional[int] = None,
  1519. keepdim: Optional[bool] = False,
  1520. dtype: Optional[DType] = None,
  1521. mask: Optional[Tensor] = None,
  1522. ) -> Tensor:
  1523. """\
  1524. {reduction_signature}
  1525. {reduction_descr}
  1526. The identity value of sample standard deviation operation is undefined. The
  1527. elements of output tensor with strided layout, that correspond to
  1528. fully masked-out elements, have ``nan`` values.
  1529. {reduction_args}
  1530. {reduction_example}"""
  1531. return _std_var(
  1532. input=input,
  1533. dim=dim,
  1534. unbiased=unbiased,
  1535. correction_opt=correction,
  1536. keepdim=keepdim,
  1537. dtype=dtype,
  1538. mask=mask,
  1539. take_sqrt=True,
  1540. )
  1541. @_apply_docstring_templates
  1542. def softmax(
  1543. input: Union[Tensor, MaskedTensor],
  1544. dim: int,
  1545. *,
  1546. dtype: Optional[DType] = None,
  1547. mask: Optional[Tensor] = None,
  1548. ) -> Tensor:
  1549. if dtype is None:
  1550. dtype = input.dtype
  1551. dim_ = _canonical_dim(dim, input.ndim)[0]
  1552. mask_input = _combine_input_and_mask(amax, input, mask)
  1553. if mask_input.layout == torch.strided:
  1554. return torch.nn.functional.softmax(mask_input, dim_, dtype=dtype)
  1555. else:
  1556. raise ValueError(
  1557. f"masked softmax expects strided tensor (got {mask_input.layout} tensor)"
  1558. )
  1559. @_apply_docstring_templates
  1560. def log_softmax(
  1561. input: Union[Tensor, MaskedTensor],
  1562. dim: int,
  1563. *,
  1564. dtype: Optional[DType] = None,
  1565. mask: Optional[Tensor] = None,
  1566. ) -> Tensor:
  1567. if dtype is None:
  1568. dtype = input.dtype
  1569. dim_ = _canonical_dim(dim, input.ndim)[0]
  1570. mask_input = _combine_input_and_mask(amax, input, mask)
  1571. if mask_input.layout == torch.strided:
  1572. return torch.nn.functional.log_softmax(mask_input, dim_, dtype=dtype)
  1573. else:
  1574. raise ValueError(
  1575. f"masked log_softmax expects strided tensor (got {mask_input.layout} tensor)"
  1576. )
  1577. @_apply_docstring_templates
  1578. def softmin(
  1579. input: Union[Tensor, MaskedTensor],
  1580. dim: int,
  1581. *,
  1582. dtype: Optional[DType] = None,
  1583. mask: Optional[Tensor] = None,
  1584. ) -> Tensor:
  1585. if dtype is None:
  1586. dtype = input.dtype
  1587. dim_ = _canonical_dim(dim, input.ndim)[0]
  1588. mask_input = _combine_input_and_mask(amin, input, mask)
  1589. if mask_input.layout == torch.strided:
  1590. return torch.nn.functional.softmin(mask_input, dim_, dtype=dtype)
  1591. else:
  1592. raise ValueError(
  1593. f"masked softmin expects strided tensor (got {mask_input.layout} tensor)"
  1594. )
  1595. @_apply_docstring_templates
  1596. def normalize(
  1597. input: Union[Tensor, MaskedTensor],
  1598. ord: float,
  1599. dim: int,
  1600. *,
  1601. eps: float = 1e-12,
  1602. dtype: Optional[DType] = None,
  1603. mask: Optional[Tensor] = None,
  1604. ) -> Tensor:
  1605. if dtype is None:
  1606. dtype = input.dtype
  1607. # TODO: eliminate mask_input as unnecessary when using masked divide.
  1608. mask_input = _combine_input_and_mask(sum, input, mask)
  1609. if mask_input.layout == torch.strided:
  1610. nrm_ = norm(input, ord, dim, keepdim=True, dtype=dtype, mask=mask)
  1611. # TODO: replace torch.maximum with masked maximum when available.
  1612. denom = torch.maximum(nrm_, nrm_.new_full([], eps))
  1613. # TODO: replace torch.divide with masked divide when available.
  1614. return torch.divide(mask_input, denom)
  1615. else:
  1616. raise ValueError(
  1617. f"masked normalize expects strided tensor (got {mask_input.layout} tensor)"
  1618. )