wrappers.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. # mypy: allow-untyped-defs
  2. import inspect
  3. from contextlib import contextmanager
  4. from functools import wraps
  5. import torch
  6. import torch._custom_ops
  7. from torch._C import DispatchKey
  8. from torch._export.utils import _maybe_find_pre_dispatch_tf_mode_for_export
  9. from torch._higher_order_ops.flat_apply import (
  10. _ConstantFunction,
  11. flat_apply,
  12. to_graphable,
  13. )
  14. from torch._higher_order_ops.strict_mode import strict_mode
  15. from torch._higher_order_ops.utils import autograd_not_implemented
  16. from torch._ops import HigherOrderOperator
  17. from torch._subclasses.fake_tensor import FakeTensorMode
  18. from torch.fx.experimental.proxy_tensor import (
  19. PreDispatchTorchFunctionMode,
  20. ProxyTorchDispatchMode,
  21. track_tensor_tree,
  22. )
  23. from torch.utils import _pytree as pytree
  24. from torch.utils._python_dispatch import is_traceable_wrapper_subclass_type
  25. class ExportTracepoint(HigherOrderOperator):
  26. def __init__(self):
  27. super().__init__("_export_tracepoint")
  28. def __call__(self, *args, **kwargs):
  29. return super().__call__(*args, **kwargs)
  30. _export_tracepoint = ExportTracepoint()
  31. @_export_tracepoint.py_impl(ProxyTorchDispatchMode)
  32. def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
  33. p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
  34. proxy = mode.tracer.create_proxy(
  35. "call_function", _export_tracepoint, p_args, p_kwargs
  36. )
  37. return track_tensor_tree(args, proxy, constant=None, tracer=mode.tracer)
  38. @_export_tracepoint.py_impl(FakeTensorMode)
  39. def export_tracepoint_fake_tensor_mode(mode, *args, **kwargs):
  40. with mode:
  41. return args
  42. @_export_tracepoint.py_functionalize_impl
  43. def export_tracepoint_functional(ctx, *args, **kwargs):
  44. unwrapped_args = ctx.unwrap_tensors(args)
  45. unwrapped_kwargs = ctx.unwrap_tensors(kwargs)
  46. with ctx.redispatch_to_next():
  47. _export_tracepoint(*unwrapped_args, **unwrapped_kwargs)
  48. return args
# Register the Autograd-key implementation of `_export_tracepoint`: the
# callable built by `autograd_not_implemented(..., deferred_error=True)`.
# NOTE(review): presumably this defers the "not implemented" error until
# autograd is actually exercised -- confirm against autograd_not_implemented.
_export_tracepoint.py_impl(DispatchKey.Autograd)(
    autograd_not_implemented(_export_tracepoint, deferred_error=True)
)
  52. @_export_tracepoint.py_impl(DispatchKey.CPU)
  53. def export_tracepoint_cpu(*args, **kwargs):
  54. return args
  55. def _wrap_submodule(mod, path, module_call_specs):
  56. assert isinstance(mod, torch.nn.Module)
  57. assert path != ""
  58. submodule = torch.fx.graph_module._get_attr(mod, path)
  59. def update_module_call_signatures(path, in_spec, out_spec):
  60. if path in module_call_specs:
  61. assert module_call_specs[path]["in_spec"] == in_spec
  62. assert module_call_specs[path]["out_spec"] == out_spec
  63. module_call_specs[path] = {"in_spec": in_spec, "out_spec": out_spec}
  64. def check_flattened(flat_args):
  65. for a in flat_args:
  66. if not (isinstance(a, (torch.Tensor, str, int, float, bool)) or a is None):
  67. raise AssertionError(
  68. f"Only Tensors or scalars are supported as pytree flattened inputs, got: {a}"
  69. )
  70. def pre_hook(module, args, kwargs):
  71. flat_args, in_spec = pytree.tree_flatten((args, kwargs))
  72. check_flattened(flat_args)
  73. flat_args = _export_tracepoint(*flat_args, kind="module_call_inputs", path=path)
  74. args, kwargs = pytree.tree_unflatten(flat_args, in_spec)
  75. return args, kwargs
  76. def post_hook(module, args, kwargs, res):
  77. _, in_spec = pytree.tree_flatten((args, kwargs))
  78. flat_res, out_spec = pytree.tree_flatten(res)
  79. check_flattened(flat_res)
  80. flat_res = _export_tracepoint(*flat_res, kind="module_call_outputs", path=path)
  81. update_module_call_signatures(path, in_spec, out_spec)
  82. return pytree.tree_unflatten(flat_res, out_spec)
  83. pre_handle = submodule.register_forward_pre_hook(pre_hook, with_kwargs=True)
  84. post_handle = submodule.register_forward_hook(post_hook, with_kwargs=True)
  85. return pre_handle, post_handle
  86. @contextmanager
  87. def _wrap_submodules(f, preserve_signature, module_call_signatures):
  88. handles = []
  89. try:
  90. for path in preserve_signature:
  91. handles.extend(_wrap_submodule(f, path, module_call_signatures))
  92. yield
  93. finally:
  94. for handle in handles:
  95. handle.remove()
  96. def _mark_strict_experimental(cls):
  97. def call(self, *args):
  98. return strict_mode(self, args)
  99. cls.__call__ = call
  100. return cls
  101. def _register_func_spec_proxy_in_tracer(tracer, name, spec):
  102. """
  103. This is a wrapper utility method on top of tracer to cache the
  104. already registered subclass spec attribute. This is useful because
  105. Subclass.__init__ will be same for each subclass. By default, fx will
  106. create multiple attributes/proxies for given attribute.
  107. """
  108. fx_name = name + "0"
  109. if hasattr(tracer.root, fx_name):
  110. assert getattr(tracer.root, fx_name) == spec
  111. return tracer.create_proxy("get_attr", fx_name, (), {})
  112. qualname = tracer.get_fresh_qualname(name)
  113. setattr(tracer.root, qualname, spec)
  114. return tracer.create_proxy("get_attr", qualname, (), {})
  115. def _emit_flat_apply_call(
  116. *,
  117. tracer,
  118. spec_name: str,
  119. const_target_for_apply,
  120. graphable_args,
  121. track_value,
  122. call_spec_cache_key: str,
  123. ):
  124. # Flatten to graphable form and record the spec on the FX root
  125. flat_args, in_spec = to_graphable(graphable_args)
  126. qualname = tracer.get_fresh_qualname(spec_name) # type: ignore[union-attr]
  127. setattr(tracer.root, qualname, in_spec) # type: ignore[union-attr]
  128. spec_proxy = tracer.create_proxy("get_attr", qualname, (), {})
  129. # Reuse/cached ConstantFunction spec on the root
  130. _, func_spec = pytree.tree_flatten(_ConstantFunction(const_target_for_apply))
  131. func_spec_proxy = _register_func_spec_proxy_in_tracer(
  132. tracer, f"{call_spec_cache_key}_const_func_spec", func_spec
  133. )
  134. # Map runtime args -> proxies (always via tracer.unwrap_proxy now)
  135. flat_proxy_args = pytree.tree_map(tracer.unwrap_proxy, flat_args)
  136. # Emit flat_apply and track result structure
  137. out_proxy = tracer.create_proxy(
  138. "call_function", flat_apply, (func_spec_proxy, spec_proxy, *flat_proxy_args), {}
  139. )
  140. track_tensor_tree(track_value, out_proxy, constant=None, tracer=tracer)
  141. def _is_init(fn):
  142. return callable(fn) and fn.__name__ == "__init__"
  143. def mark_subclass_constructor_exportable_experimental(constructor_subclass):
  144. """
  145. Experimental decorator that makes subclass to be traceable in export
  146. with pre-dispatch IR. To make your subclass traceble in export, you need to:
  147. 1. Implement __init__ method for your subclass (Look at DTensor implementation)
  148. 2. Decorate your __init__ method with _mark_constructor_exportable_experimental
  149. 3. Put torch._dynamo_disable decorator to prevent dynamo from peeking into its' impl
  150. Example:
  151. class FooTensor(torch.Tensor):
  152. @staticmethod
  153. def __new__(cls, elem, *, requires_grad=False):
  154. # ...
  155. return torch.Tensor._make_subclass(cls, elem, requires_grad=requires_grad)
  156. @torch._dynamo_disable
  157. @mark_subclass_constructor_exportable_experimental
  158. def __init__(self, elem, ...):
  159. # ...
  160. """
  161. if not _is_init(constructor_subclass):
  162. raise RuntimeError(
  163. f"torch._export.wrappers.mark_constructor_exportable_experimental can only be applied on subclass tensor.__init__"
  164. f"But, you are adding it on {constructor_subclass.__name__} which is not supported. "
  165. f"If __init__ doesn't exist on your subclass, please add it. Look at DTensor.__init__ implementation for example"
  166. )
  167. def wrapper(*args, **kwargs):
  168. constructor_subclass(*args, **kwargs)
  169. if not torch.compiler.is_exporting():
  170. return
  171. if not is_traceable_wrapper_subclass_type(type(args[0])):
  172. assert constructor_subclass.__qualname__.endswith("__init__")
  173. obj_name = constructor_subclass.__qualname__[: -len("__init__")]
  174. raise RuntimeError(
  175. f"Can't intercept {obj_name} in export because this object is not a traceable "
  176. f"tensor subclass. Please look at DTensor.__init__ implementation as an example of proper usage of this API."
  177. )
  178. mode = _maybe_find_pre_dispatch_tf_mode_for_export()
  179. if mode is None:
  180. return
  181. assert isinstance(mode, PreDispatchTorchFunctionMode)
  182. tracer = mode.tracer
  183. subclass = args[0]
  184. graphable = (tuple(args[1:]), kwargs)
  185. spec_name = "_".join(constructor_subclass.__qualname__.lower().split("."))
  186. call_spec_cache_key = type(subclass).__name__.lower()
  187. _emit_flat_apply_call(
  188. tracer=tracer,
  189. spec_name=spec_name,
  190. const_target_for_apply=type(subclass),
  191. graphable_args=graphable,
  192. track_value=subclass, # track the constructed subclass instance
  193. call_spec_cache_key=call_spec_cache_key,
  194. )
  195. return
  196. return wrapper
  197. def allow_in_pre_dispatch_graph(func):
  198. """
  199. Experimental decorator that adds user function to export pre-dispatch graph. Note that
  200. we only support custom autograd function/subclass constructors today. To use this function:
  201. 1. For subclasses:
  202. 1. refer to instructions in mark_subclass_constructor_exportable_experimental
  203. 2. Define apply method on your custom autograd function and apply this decorator.
  204. Example:
  205. class MyCoolCustomAutogradFunc(autograd.Function):
  206. @classmethod
  207. @torch._export.wrappers.allow_in_pre_dispatch_graph
  208. def apply(cls, *args, **kwargs):
  209. return super(MyCoolCustomAutogradFunc, cls).apply(*args, **kwargs)
  210. """
  211. if _is_init(func):
  212. return mark_subclass_constructor_exportable_experimental(func)
  213. if not (_is_init(func) or func.__name__ == "apply"):
  214. raise RuntimeError(
  215. f"torch._export.wrappers.allow_in_pre_dispatch_graph can only be applied on subclass tensor.__init_ "
  216. f"or custom_autograd_function.apply. "
  217. f"But, you are adding it on {func.__name__} which is not supported. "
  218. f"If __init__ doesn't exist on your subclass, please add it. Look at DTensor.__init__ implementation for example. "
  219. f"If you are adding it on custom autograd function, please add it on apply method. "
  220. f"If anything else, file an issue on github and we may consider extending our support. "
  221. )
  222. @wraps(func)
  223. def wrapper(*args, **kwargs):
  224. if not torch.compiler.is_exporting():
  225. return func(*args, **kwargs)
  226. if not inspect.isclass(args[0]):
  227. return func(*args, **kwargs)
  228. if not issubclass(args[0], torch.autograd.Function):
  229. return func(*args, **kwargs)
  230. from torch._ops import _get_dispatch_mode_pre_dispatch
  231. mode = _get_dispatch_mode_pre_dispatch(torch._C._TorchDispatchModeKey.PROXY)
  232. if mode is None:
  233. return func(*args, **kwargs)
  234. # Sometimes custom autograd functions can call into HOPs that don't have proxy impl
  235. # at PreDispatch level, so we just dispatch it below to get the concrete result.
  236. include_to_set = torch._C._dispatch_tls_local_include_set().remove(
  237. torch._C.DispatchKey.PreDispatch
  238. )
  239. exclude_to_set = (
  240. torch._C._dispatch_tls_local_exclude_set()
  241. | torch._C.DispatchKeySet(torch._C.DispatchKey.PreDispatch)
  242. )
  243. with torch._C._ForceDispatchKeyGuard(include_to_set, exclude_to_set):
  244. out = func(*args, **kwargs)
  245. assert mode.pre_dispatch, "Should only do this in predispatch"
  246. tracer = mode.tracer
  247. function_cls_name = f"{args[0].__module__}.{args[0].__qualname__}"
  248. graphable = ((function_cls_name, *args[1:]), kwargs)
  249. from torch.export.custom_ops import (
  250. _call_custom_autograd_function_in_pre_dispatch,
  251. )
  252. spec_name = "_".join(function_cls_name.split("."))
  253. call_spec_cache_key = type(
  254. _call_custom_autograd_function_in_pre_dispatch
  255. ).__name__.lower()
  256. _emit_flat_apply_call(
  257. tracer=tracer,
  258. spec_name=spec_name,
  259. const_target_for_apply=_call_custom_autograd_function_in_pre_dispatch,
  260. graphable_args=graphable,
  261. track_value=out,
  262. call_spec_cache_key=call_spec_cache_key,
  263. )
  264. return out
  265. return wrapper