fake_impl.py 8.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import functools
  4. from typing import Callable
  5. from typing_extensions import deprecated
  6. import torch
  7. from torch._library.utils import Kernel, RegistrationHandle
  8. class FakeImplHolder:
  9. """A holder where one can register an fake impl to."""
  10. def __init__(self, qualname: str):
  11. self.qualname: str = qualname
  12. # kernels stores all registered fake kernels, ordered by registration
  13. # time ascendingly (newer registration after older registration). If an
  14. # operator library gets loaded that overrides an existing fake kernel,
  15. # both kernels will be in the list, but the newest one will be the one
  16. # that is run. If the library is unloaded, we will remove the kernel
  17. # from this list.
  18. self.kernels: list[Kernel] = []
  19. @property
  20. def kernel(self):
  21. if len(self.kernels) == 0:
  22. return None
  23. return self.kernels[-1]
  24. @kernel.setter
  25. def kernel(self, value):
  26. raise RuntimeError("Unable to directly set kernel.")
  27. def register(
  28. self, func: Callable, source: str, lib, *, allow_override=False
  29. ) -> RegistrationHandle:
  30. """Register an fake impl.
  31. Returns a RegistrationHandle that one can use to de-register this
  32. fake impl.
  33. """
  34. if not allow_override:
  35. if self.kernel is not None:
  36. raise RuntimeError(
  37. f"register_fake(...): the operator {self.qualname} "
  38. f"already has an fake impl registered at "
  39. f"{self.kernel.source}."
  40. )
  41. if torch._C._dispatch_has_kernel_for_dispatch_key(self.qualname, "Meta"):
  42. raise RuntimeError(
  43. f"register_fake(...): the operator {self.qualname} "
  44. f"already has an DispatchKey::Meta implementation via a "
  45. f"pre-existing torch.library or TORCH_LIBRARY registration. "
  46. f"Please either remove that registration or don't call "
  47. f"register_fake."
  48. )
  49. if torch._C._dispatch_has_kernel_for_dispatch_key(
  50. self.qualname, "CompositeImplicitAutograd"
  51. ):
  52. raise RuntimeError(
  53. f"register_fake(...): the operator {self.qualname} "
  54. f"already has an implementation for this device type via a "
  55. f"pre-existing registration to "
  56. f"DispatchKey::CompositeImplicitAutograd."
  57. f"CompositeImplicitAutograd operators do not need an fake "
  58. f"impl; "
  59. f"instead, the operator will decompose into its constituents "
  60. f"and those "
  61. f"can have fake impls defined on them."
  62. )
  63. # Store the kernel in this holder
  64. kernel = Kernel(func, source)
  65. self.kernels.append(kernel)
  66. def deregister_fake_kernel():
  67. self.kernels.remove(kernel)
  68. meta_kernel = construct_meta_kernel(self.qualname, self)
  69. lib.impl(self.qualname, meta_kernel, "Meta", allow_override=allow_override)
  70. handle = RegistrationHandle(deregister_fake_kernel)
  71. return handle
  72. def construct_meta_kernel(qualname: str, fake_impl_holder: FakeImplHolder) -> Callable:
  73. assert fake_impl_holder.kernel is not None
  74. @functools.wraps(fake_impl_holder.kernel.func)
  75. def meta_kernel(*args, **kwargs):
  76. assert fake_impl_holder.kernel is not None
  77. source = fake_impl_holder.kernel.source
  78. def error_on_ctx():
  79. raise RuntimeError(
  80. f"{qualname} ({source}): You're trying to run this operator "
  81. f"with meta Tensors (as opposed to FakeTensors), but this "
  82. f"operator may return an output Tensor with data-dependent shape. Meta "
  83. f"Tensors don't support operators with outputs that have data-dependent shapes "
  84. f"but FakeTensors do. "
  85. f"If your operator does not return an output with data-dependent shape, "
  86. f"make sure the FakeTensor and/or meta kernel does not call "
  87. f"torch.library.get_ctx(). Otherwise, please use FakeTensors."
  88. )
  89. with set_ctx_getter(error_on_ctx):
  90. return fake_impl_holder.kernel(*args, **kwargs)
  91. return meta_kernel
  92. def get_none():
  93. return None
  94. global_ctx_getter: Callable = get_none
  95. @contextlib.contextmanager
  96. def set_ctx_getter(ctx_getter):
  97. global global_ctx_getter
  98. prev = global_ctx_getter
  99. try:
  100. global_ctx_getter = ctx_getter
  101. yield
  102. finally:
  103. global_ctx_getter = prev
  104. class FakeImplCtx:
  105. """
  106. Context object for writing fake implementations for custom operators.
  107. """
  108. def __init__(self, _fake_mode, _op):
  109. self._fake_mode = _fake_mode
  110. self._shape_env = _fake_mode.shape_env
  111. self._op = _op
  112. @deprecated(
  113. "`create_unbacked_symint` is deprecated, please use `new_dynamic_size` instead",
  114. category=FutureWarning,
  115. )
  116. def create_unbacked_symint(self, *, min=2, max=None) -> torch.SymInt:
  117. return self.new_dynamic_size(min=min, max=max)
  118. def new_dynamic_size(self, *, min=0, max=None) -> torch.SymInt:
  119. """Constructs a new symint (symbolic int) representing a data-dependent value.
  120. This is useful for writing the fake implementation (which is necessary
  121. for torch.compile) for a CustomOp where an output Tensor has a size
  122. that depends on the data of the input Tensors.
  123. Args:
  124. min (int): A statically known inclusive lower bound for this symint. Default: 0
  125. max (Optional[int]): A statically known inclusive upper bound for this
  126. symint. Default: None
  127. .. warning:
  128. It is important that the ``min`` and ``max`` (if not None) values are set
  129. correctly, otherwise, there will be undefined behavior under
  130. torch.compile. The default value of ``min`` is 2 due to torch.compile
  131. specializing on 0/1 sizes.
  132. You must also verify that your implementation on concrete Tensors
  133. (e.g. CPU/CUDA) only returns Tensors where the size that corresponds
  134. to the symint also has respects these constraint.
  135. The easiest way to do this is to add an assertion in the CPU/CUDA/etc
  136. implementation that the size follows these bounds.
  137. Example::
  138. >>> # An operator with data-dependent output shape
  139. >>> lib = torch.library.Library("mymodule", "FRAGMENT")
  140. >>> lib.define("mymodule::custom_nonzero(Tensor x) -> Tensor")
  141. >>>
  142. >>> @torch.library.register_fake("mymodule::custom_nonzero")
  143. >>> def _(x):
  144. >>> # Number of nonzero-elements is data-dependent.
  145. >>> # Since we cannot peek at the data in an fake impl,
  146. >>> # we use the ctx object to construct a new symint that
  147. >>> # represents the data-dependent size.
  148. >>> ctx = torch.library.get_ctx()
  149. >>> nnz = ctx.new_dynamic_size()
  150. >>> shape = [nnz, x.dim()]
  151. >>> result = x.new_empty(shape, dtype=torch.int64)
  152. >>> return result
  153. >>>
  154. >>> @torch.library.impl(lib, "custom_nonzero", "CPU")
  155. >>> def _(x):
  156. >>> x_np = x.numpy()
  157. >>> res = np.stack(np.nonzero(x_np), axis=1)
  158. >>> return torch.tensor(res, device=x.device)
  159. """
  160. if (
  161. self._shape_env is None
  162. or not self._shape_env.allow_dynamic_output_shape_ops
  163. ):
  164. raise torch._subclasses.fake_tensor.DynamicOutputShapeException(self._op)
  165. if isinstance(min, torch.SymInt) or isinstance(max, torch.SymInt):
  166. raise ValueError(
  167. f"ctx.new_dynamic_size(min={min}, max={max}): expected "
  168. f"min and max to be statically known ints but got SymInt. "
  169. f"This is not supported."
  170. )
  171. if min < 0:
  172. raise ValueError(
  173. f"ctx.new_dynamic_size(min={min}, ...): expected min to be "
  174. f"greater than or equal to 0: this API can only create "
  175. f"non-negative sizes."
  176. )
  177. return allocate_size(self._shape_env, min, max)
  178. def allocate_size(shape_env, min_val=0, max_val=None):
  179. result = shape_env.create_unbacked_symint()
  180. torch.fx.experimental.symbolic_shapes._constrain_range_for_size(
  181. result, min=min_val, max=max_val
  182. )
  183. return result