ufunc.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209
  1. from __future__ import annotations
  2. from dataclasses import dataclass
  3. import torchgen.api.types as api_types
  4. from torchgen.api import cpp, structured
  5. from torchgen.api.types import (
  6. ArgName,
  7. BaseCppType,
  8. BaseCType,
  9. Binding,
  10. ConstRefCType,
  11. CType,
  12. NamedCType,
  13. scalarT,
  14. )
  15. from torchgen.model import (
  16. Argument,
  17. BaseTy,
  18. BaseType,
  19. DispatchKey,
  20. FunctionSchema,
  21. NativeFunctionsGroup,
  22. Type,
  23. )
  24. def schema_kernel_name(func: FunctionSchema, dispatch_key: DispatchKey) -> str:
  25. assert func.is_out_fn(), "ufunc.kernel_name should only be invoked on out schemas"
  26. return f"ufunc_{func.name.name}_{dispatch_key}"
  27. def kernel_name(g: NativeFunctionsGroup, dispatch_key: DispatchKey) -> str:
  28. return schema_kernel_name(g.out.func, dispatch_key)
  29. # Tensors are omitted (as they are stored in TensorIterator), everything else is
  30. # passed along (technically, we can pass tensors along too, it just wastes
  31. # argument registers)
  32. #
  33. # NB: used for CPU only
  34. def dispatchstub_type(t: Type, *, binds: ArgName) -> NamedCType | None:
  35. # Dispatch stubs are always plain ints
  36. r = cpp.valuetype_type(t, binds=binds, symint=False)
  37. if r is not None:
  38. return r
  39. if t == BaseType(BaseTy.Scalar):
  40. return NamedCType(binds, ConstRefCType(BaseCType(scalarT)))
  41. elif t == BaseType(BaseTy.Tensor):
  42. return None
  43. else:
  44. raise AssertionError(f"unrecognized type {repr(t)}")
  45. def opmath_type(scalar_t: BaseCppType) -> BaseCppType:
  46. if scalar_t == api_types.scalar_t:
  47. return api_types.opmath_t
  48. raise NotImplementedError
  49. # NB: Tensors in constructor are stored in opmath_t, not scalar_t
  50. # because Tensor in constructor = its a scalar tensor partially applied =
  51. # it can be higher precision and we want to compute in that higher precision
  52. #
  53. # NB: CUDA only
  54. def ufunctor_ctor_type(t: Type, *, binds: ArgName, scalar_t: BaseCppType) -> NamedCType:
  55. r = cpp.valuetype_type(t, binds=binds, symint=False)
  56. if r is not None:
  57. return r
  58. if t == BaseType(BaseTy.Scalar):
  59. return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
  60. elif t == BaseType(BaseTy.Tensor):
  61. return NamedCType(binds, BaseCType(opmath_type(scalar_t)))
  62. else:
  63. raise AssertionError(f"unrecognized type {repr(t)}")
  64. # Only Tensors ever get passed directly to operator()
  65. #
  66. # NB: CUDA only
  67. # (Actually, this works for CPU too)
  68. def ufunctor_apply_type(
  69. t: Type, *, binds: ArgName, scalar_t: BaseCppType
  70. ) -> NamedCType:
  71. if t == BaseType(BaseTy.Tensor):
  72. return NamedCType(binds, BaseCType(scalar_t))
  73. else:
  74. raise AssertionError(f"unrecognized type {repr(t)}")
  75. # The actual ufunc template function the user writes. Everything here
  76. # is done in the computation type. compute_t is opmath_t in CUDA and scalar_t
  77. # in CPU
  78. def ufunc_type(t: Type, *, binds: ArgName, compute_t: CType) -> NamedCType:
  79. r = cpp.valuetype_type(t, binds=binds, symint=False)
  80. if r is not None:
  81. return r
  82. if t == BaseType(BaseTy.Scalar):
  83. return NamedCType(binds, compute_t)
  84. elif t == BaseType(BaseTy.Tensor):
  85. return NamedCType(binds, compute_t)
  86. else:
  87. raise AssertionError(f"unrecognized type {repr(t)}")
  88. def ufunctor_ctor_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
  89. return Binding(
  90. nctype=ufunctor_ctor_type(a.type, binds=a.name, scalar_t=scalar_t),
  91. name=a.name,
  92. default=None,
  93. argument=a,
  94. )
  95. def ufunctor_apply_argument(a: Argument, scalar_t: BaseCppType) -> Binding:
  96. return Binding(
  97. nctype=ufunctor_apply_type(a.type, binds=a.name, scalar_t=scalar_t),
  98. name=a.name,
  99. default=None,
  100. argument=a,
  101. )
  102. def ufunc_argument(a: Argument, compute_t: CType) -> Binding:
  103. return Binding(
  104. nctype=ufunc_type(a.type, binds=a.name, compute_t=compute_t),
  105. name=a.name,
  106. default=None,
  107. argument=a,
  108. )
  109. @dataclass(frozen=True)
  110. class UfunctorBindings:
  111. ctor: list[Binding]
  112. apply: list[Binding]
  113. # ufunctors are a CUDA-only concept representing functors that take some of
  114. # their arguments on a host-side constructor, and the rest in the device-side
  115. # apply. E.g.,
  116. #
  117. # template <typename scalar_t>
  118. # struct CUDAFunctorOnSelf_add {
  119. # using opmath_t = at::opmath_type<scalar_t>;
  120. # opmath_t other_;
  121. # opmath_t alpha_;
  122. # CUDAFunctorOnSelf_add(opmath_t other, opmath_t alpha) : other_(other), alpha_(alpha) {}
  123. # __device__ scalar_t operator()(scalar_t self) {
  124. # return ufunc::add(static_cast<opmath_t>(self), other_, alpha_);
  125. # }
  126. # };
  127. #
  128. # The ctor refers to the constructor CUDAFunctorOnSelf_add, while apply refers
  129. # to the operator() definition
  130. def ufunctor_arguments(
  131. g: NativeFunctionsGroup, *, scalar_tensor_idx: int | None, scalar_t: BaseCppType
  132. ) -> UfunctorBindings:
  133. ctor = []
  134. apply = []
  135. for a in g.functional.func.arguments.flat_non_out:
  136. if a.type.is_tensor_like():
  137. if scalar_tensor_idx == 0:
  138. # put it in the ctor anyway
  139. ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
  140. scalar_tensor_idx = None
  141. else:
  142. if scalar_tensor_idx is not None:
  143. scalar_tensor_idx -= 1
  144. apply.append(ufunctor_apply_argument(a, scalar_t=scalar_t))
  145. else:
  146. ctor.append(ufunctor_ctor_argument(a, scalar_t=scalar_t))
  147. assert scalar_tensor_idx is None
  148. return UfunctorBindings(ctor=ctor, apply=apply)
  149. # ufuncs are the inner loop template functions that you wrote in ufunc/add.h
  150. # which do the actual computation in question. E.g.,
  151. #
  152. # template <typename T>
  153. # C10_HOST_DEVICE T add(T self, T other, T alpha) __ubsan_ignore_undefined__ {
  154. # return self + alpha * other;
  155. # }
  156. #
  157. # In this file, we refer to T as compute_t which is bound by caller
  158. def ufunc_arguments(g: NativeFunctionsGroup, *, compute_t: CType) -> list[Binding]:
  159. return [
  160. ufunc_argument(a, compute_t=compute_t)
  161. for a in g.functional.func.arguments.flat_non_out
  162. ]
  163. # Stubs are the DispatchStub trampolines that CPU kernels use to get to their
  164. # vectorized versions. E.g.,
  165. #
  166. # using structured_binary_fn_alpha = void(*)(TensorIteratorBase&, const Scalar& alpha);
  167. # DECLARE_DISPATCH(structured_binary_fn_alpha, add_stub);
  168. def stub_arguments(g: NativeFunctionsGroup) -> list[Binding]:
  169. # stubs drop all tensor arguments (they are implicit in the TensorIterator
  170. # argument and keep everything else)
  171. return [
  172. r
  173. for a in g.out.func.arguments.flat_non_out
  174. if not a.type.is_tensor_like()
  175. for r in structured.argument(a)
  176. ]