_compile.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859
  1. """
  2. APIs related to torch.compile which lazily import torch._dynamo to avoid
  3. circular dependencies.
  4. """
  5. import functools
  6. from typing import Callable, Literal, Optional, overload, TypeVar, Union
  7. from typing_extensions import ParamSpec
  8. _T = TypeVar("_T")
  9. _P = ParamSpec("_P")
  10. @overload
  11. def _disable_dynamo(
  12. fn: Callable[_P, _T], recursive: bool = True
  13. ) -> Callable[_P, _T]: ...
  14. @overload
  15. def _disable_dynamo(
  16. fn: Literal[None] = None, recursive: bool = True
  17. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: ...
  18. def _disable_dynamo(
  19. fn: Optional[Callable[_P, _T]] = None, recursive: bool = True
  20. ) -> Union[Callable[_P, _T], Callable[[Callable[_P, _T]], Callable[_P, _T]]]:
  21. """
  22. This API should be only used inside torch, external users should still use
  23. torch._dynamo.disable. The main goal of this API is to avoid circular
  24. imports issues that is common while using _dynamo.disable inside torch
  25. itself.
  26. This API avoids it by lazily importing torch._dynamo from the import time to
  27. the invocation of the decorated function.
  28. """
  29. if fn is not None:
  30. @functools.wraps(fn)
  31. def inner(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  32. # cache this on the first invocation to avoid adding too much overhead.
  33. disable_fn = getattr(fn, "__dynamo_disable", None)
  34. if disable_fn is None:
  35. import torch._dynamo
  36. # We can safely turn off functools.wraps here because the inner
  37. # already wraps fn in the outer scope.
  38. disable_fn = torch._dynamo.disable(fn, recursive, wrapping=False)
  39. fn.__dynamo_disable = disable_fn # type: ignore[attr-defined]
  40. return disable_fn(*args, **kwargs)
  41. return inner
  42. else:
  43. # decorator usage like @_disable_dynamo(recursive=False). The resulting
  44. # object expects the original decorated function as the arg.
  45. return functools.partial(_disable_dynamo, recursive=recursive)