__init__.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. import functorch._C
  2. import torch
  3. from functorch._C import dim as _C
  4. from .tree_map import tree_flatten, tree_map
  5. from .wrap_type import wrap_type
  6. _C._patch_tensor_class()
  7. dims, DimList, dimlists = _C.dims, _C.DimList, _C.dimlists
  8. class DimensionMismatchError(Exception):
  9. pass
  10. class DimensionBindError(Exception):
  11. pass
  12. from . import op_properties
  13. # use dict to avoid writing C++ bindings for set
  14. pointwise = dict.fromkeys(op_properties.pointwise, True)
  15. class _Tensor:
  16. # fast path around slow wrapping/unwrapping logic for simply queries used
  17. # by the implementation...
  18. @property
  19. def dims(self):
  20. return tuple(d for d in self._levels if isinstance(d, Dim))
  21. def dim(self):
  22. return self.ndim
  23. __torch_function__ = classmethod(_C.__torch_function__)
  24. expand = _C._instancemethod(_C.expand)
  25. index = _C._instancemethod(_C.index)
  26. def __repr__(self):
  27. tensor, levels, ndim = self._tensor, self._levels, self.ndim
  28. return f"{tensor}\nwith dims={tuple(l + ndim if isinstance(l, int) else l for l in levels)} sizes={tuple(tensor.size())}"
  29. TensorLike = (_Tensor, torch.Tensor)
  30. class Dim(_C.Dim, _Tensor):
  31. # note that _C.Dim comes before tensor because we want the Dim API for things like size to take precedence.
  32. # Tensor defines format, but we want to print Dims with special formatting
  33. __format__ = object.__format__
  34. class Tensor(_Tensor, _C.Tensor):
  35. from_positional = staticmethod(_C.Tensor_from_positional)
  36. sum = _C._instancemethod(_C.Tensor_sum)
  37. def cat(tensors, dim, new_dim):
  38. n = dims()
  39. return stack(tensors, n, dim).index([n, dim], new_dim)
  40. _wrap = _C._wrap
  41. def _def(name, *args, **kwargs):
  42. orig = getattr(torch.Tensor, name)
  43. setattr(_Tensor, name, _C._instancemethod(_wrap(orig, *args, **kwargs)))
  44. t__getitem__ = _C._instancemethod(_C.__getitem__)
  45. stack = _C.stack
  46. split = _C._instancemethod(_C.split)
  47. # note: there is no python reference
  48. t__setitem__ = _C._instancemethod(_C.__setitem__)
  49. # this is patched in the C API because otherwise torch.Tensor will
  50. # no longer be considered a sequence and things will break
  51. # torch.Tensor.__getitem__ = t__getitem__
  52. _Tensor.__getitem__ = t__getitem__
  53. # torch.Tensor.__setitem__ = t__setitem__
  54. _Tensor.__setitem__ = t__setitem__
  55. torch.Tensor.split = split
  56. _Tensor.split = split
  57. torch.Tensor.expand = _C._instancemethod(_C.expand)
  58. torch.Tensor.index = _C._instancemethod(_C.index)
  59. wrap_type(_Tensor, torch.Tensor, _Tensor.__torch_function__)
  60. del _Tensor.ndim
  61. _Tensor.order = _C._instancemethod(_C.order)
  62. _def("mean")
  63. _def("sum")
  64. _def("all")
  65. _def("amax")
  66. _def("amin")
  67. _def("aminmax")
  68. _def("any")
  69. _def("count_nonzero")
  70. _def("logsumexp")
  71. _def("nanmean")
  72. _def("nansum")
  73. _def("prod")
  74. _def("std", keepdim_offset=2)
  75. _def("var", keepdim_offset=2)
  76. _def("max", single_dim=True)
  77. _def("min", single_dim=True)
  78. _def("argmax", single_dim=True)
  79. _def("argmin", single_dim=True)
  80. _def("kthvalue", single_dim=True)
  81. _def("median", single_dim=True)
  82. _def("nanmedian", single_dim=True)
  83. _def("mode", single_dim=True)
  84. _def("sort", reduce=False)
  85. _def("argsort", reduce=False)
  86. _def("unbind", single_dim=True)
  87. _def("chunk", dim_offset=1, reduce=False)
  88. _def("cummax", single_dim=True, reduce=False)
  89. _def("cummin", single_dim=True, reduce=False)
  90. _def("cumprod", single_dim=True, reduce=False)
  91. _def("cumprod_", single_dim=True, reduce=False)
  92. _def("cumsum", single_dim=True, reduce=False)
  93. _def("cumsum_", single_dim=True, reduce=False)
  94. _def("logcumsumexp", single_dim=True, reduce=False)
  95. _def("renorm", dim_offset=1, single_dim=True, reduce=False)
  96. _def("softmax", single_dim=True, reduce=False)
  97. softmax = _wrap(torch.nn.functional.softmax, single_dim=True, reduce=False)
# Stuff to handle in the future, because these ops require special
# binding logic for dims:
# cross
# diag_embed
# diagonal
# diagonal_scatter
# diff
# nanquantile
# quantile
# roll
# rot90
# topk (new dims on output)
# Should these all be subsumed by in-place indexing?
# index_add_
# index_add
# index_copy
# index_copy_
# index_fill
# index_fill_
# index_select
# scatter
# scatter_
# scatter_add
# scatter_add_
# scatter_reduce