# _utils_internal.py
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import logging
  4. import os
  5. import sys
  6. import tempfile
  7. import typing_extensions
  8. from collections.abc import Callable
  9. from typing import Any, TypeVar
  10. from typing_extensions import ParamSpec
  11. import torch
  12. from torch._strobelight.compile_time_profiler import StrobelightCompileTimeProfiler
  13. _T = TypeVar("_T")
  14. _P = ParamSpec("_P")
  15. log = logging.getLogger(__name__)
  16. if os.environ.get("TORCH_COMPILE_STROBELIGHT", False):
  17. import shutil
  18. if not shutil.which("strobeclient"):
  19. log.info(
  20. "TORCH_COMPILE_STROBELIGHT is true, but seems like you are not on a FB machine."
  21. )
  22. else:
  23. log.info("Strobelight profiler is enabled via environment variable")
  24. StrobelightCompileTimeProfiler.enable()
  25. # this arbitrary-looking assortment of functionality is provided here
  26. # to have a central place for overridable behavior. The motivating
  27. # use is the FB build environment, where this source file is replaced
  28. # by an equivalent.
  29. if os.path.basename(os.path.dirname(__file__)) == "shared":
  30. torch_parent = os.path.dirname(os.path.dirname(os.path.dirname(__file__)))
  31. else:
  32. torch_parent = os.path.dirname(os.path.dirname(__file__))
  33. def get_file_path(*path_components: str) -> str:
  34. return os.path.join(torch_parent, *path_components)
  35. def get_file_path_2(*path_components: str) -> str:
  36. return os.path.join(*path_components)
  37. def get_writable_path(path: str) -> str:
  38. if os.access(path, os.W_OK):
  39. return path
  40. return tempfile.mkdtemp(suffix=os.path.basename(path))
  41. def prepare_multiprocessing_environment(path: str) -> None:
  42. pass
  43. def resolve_library_path(path: str) -> str:
  44. return os.path.realpath(path)
  45. def throw_abstract_impl_not_imported_error(opname, module, context):
  46. if module in sys.modules:
  47. raise NotImplementedError(
  48. f"{opname}: We could not find the fake impl for this operator. "
  49. )
  50. else:
  51. raise NotImplementedError(
  52. f"{opname}: We could not find the fake impl for this operator. "
  53. f"The operator specified that you may need to import the '{module}' "
  54. f"Python module to load the fake impl. {context}"
  55. )
  56. # NB! This treats "skip" kwarg specially!!
  57. def compile_time_strobelight_meta(
  58. phase_name: str,
  59. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
  60. def compile_time_strobelight_meta_inner(
  61. function: Callable[_P, _T],
  62. ) -> Callable[_P, _T]:
  63. @functools.wraps(function)
  64. def wrapper_function(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  65. if "skip" in kwargs and isinstance(
  66. # pyrefly: ignore [unsupported-operation]
  67. skip := kwargs["skip"],
  68. int,
  69. ):
  70. kwargs["skip"] = skip + 1
  71. # This is not needed but we have it here to avoid having profile_compile_time
  72. # in stack traces when profiling is not enabled.
  73. if not StrobelightCompileTimeProfiler.enabled:
  74. return function(*args, **kwargs)
  75. return StrobelightCompileTimeProfiler.profile_compile_time(
  76. function, phase_name, *args, **kwargs
  77. )
  78. return wrapper_function
  79. return compile_time_strobelight_meta_inner
  80. # Meta only, see
  81. # https://www.internalfb.com/intern/wiki/ML_Workflow_Observability/User_Guides/Adding_instrumentation_to_your_code/
  82. #
  83. # This will cause an event to get logged to Scuba via the signposts API. You
  84. # can view samples on the API at https://fburl.com/scuba/workflow_signpost/zh9wmpqs
  85. # we log to subsystem "torch", and the category and name you provide here.
  86. # Each of the arguments translate into a Scuba column. We're still figuring
  87. # out local conventions in PyTorch, but category should be something like
  88. # "dynamo" or "inductor", and name should be a specific string describing what
  89. # kind of event happened.
  90. #
  91. # Killswitch is at
  92. # https://www.internalfb.com/intern/justknobs/?name=pytorch%2Fsignpost#event
  93. def signpost_event(category: str, name: str, parameters: dict[str, Any]):
  94. log.info("%s %s: %r", category, name, parameters)
  95. def add_mlhub_insight(category: str, insight: str, insight_description: str):
  96. pass
  97. def log_compilation_event(metrics):
  98. log.info("%s", metrics)
  99. def upload_graph(graph):
  100. pass
  101. def set_pytorch_distributed_envs_from_justknobs():
  102. pass
  103. def log_export_usage(**kwargs):
  104. pass
  105. def log_draft_export_usage(**kwargs):
  106. pass
  107. def log_trace_structured_event(*args, **kwargs) -> None:
  108. pass
  109. def log_cache_bypass(*args, **kwargs) -> None:
  110. pass
  111. def log_torchscript_usage(api: str, **kwargs):
  112. _ = api
  113. return
  114. def check_if_torch_exportable():
  115. return False
  116. def export_training_ir_rollout_check() -> bool:
  117. return True
  118. def full_aoti_runtime_assert() -> bool:
  119. return True
  120. def log_torch_jit_trace_exportability(
  121. api: str,
  122. type_of_export: str,
  123. export_outcome: str,
  124. result: str,
  125. ):
  126. _, _, _, _ = api, type_of_export, export_outcome, result
  127. return
  128. DISABLE_JUSTKNOBS = True
  129. def justknobs_check(name: str, default: bool = True) -> bool:
  130. """
  131. This function can be used to killswitch functionality in FB prod,
  132. where you can toggle this value to False in JK without having to
  133. do a code push. In OSS, we always have everything turned on all
  134. the time, because downstream users can simply choose to not update
  135. PyTorch. (If more fine-grained enable/disable is needed, we could
  136. potentially have a map we lookup name in to toggle behavior. But
  137. the point is that it's all tied to source code in OSS, since there's
  138. no live server to query.)
  139. This is the bare minimum functionality I needed to do some killswitches.
  140. We have a more detailed plan at
  141. https://docs.google.com/document/d/1Ukerh9_42SeGh89J-tGtecpHBPwGlkQ043pddkKb3PU/edit
  142. In particular, in some circumstances it may be necessary to read in
  143. a knob once at process start, and then use it consistently for the
  144. rest of the process. Future functionality will codify these patterns
  145. into a better high level API.
  146. WARNING: Do NOT call this function at module import time, JK is not
  147. fork safe and you will break anyone who forks the process and then
  148. hits JK again.
  149. """
  150. return default
  151. def justknobs_getval_int(name: str) -> int:
  152. """
  153. Read warning on justknobs_check
  154. """
  155. return 0
  156. def is_fb_unit_test() -> bool:
  157. return False
  158. @functools.cache
  159. def max_clock_rate():
  160. """
  161. unit: MHz
  162. """
  163. if not torch.version.hip:
  164. from triton.testing import nvsmi
  165. return nvsmi(["clocks.max.sm"])[0]
  166. else:
  167. # Manually set max-clock speeds on ROCm until equivalent nvmsi
  168. # functionality in triton.testing or via pyamdsmi enablement. Required
  169. # for test_snode_runtime unit tests.
  170. gcn_arch = str(torch.cuda.get_device_properties(0).gcnArchName.split(":", 1)[0])
  171. if "gfx94" in gcn_arch:
  172. return 1700
  173. elif "gfx90a" in gcn_arch:
  174. return 1700
  175. elif "gfx908" in gcn_arch:
  176. return 1502
  177. elif "gfx12" in gcn_arch:
  178. return 1700
  179. elif "gfx11" in gcn_arch:
  180. return 1700
  181. elif "gfx103" in gcn_arch:
  182. return 1967
  183. elif "gfx101" in gcn_arch:
  184. return 1144
  185. elif "gfx95" in gcn_arch:
  186. return 1700 # TODO: placeholder, get actual value
  187. else:
  188. return 1100
  189. def get_mast_job_name_version() -> tuple[str, int] | None:
  190. return None
  191. TEST_MASTER_ADDR = "127.0.0.1"
  192. TEST_MASTER_PORT = 29500
  193. # USE_GLOBAL_DEPS controls whether __init__.py tries to load
  194. # libtorch_global_deps, see Note [Global dependencies]
  195. USE_GLOBAL_DEPS = True
  196. # USE_RTLD_GLOBAL_WITH_LIBTORCH controls whether __init__.py tries to load
  197. # _C.so with RTLD_GLOBAL during the call to dlopen.
  198. USE_RTLD_GLOBAL_WITH_LIBTORCH = False
  199. # If an op was defined in C++ and extended from Python using the
  200. # torch.library.register_fake, returns if we require that there be a
  201. # m.set_python_module("mylib.ops") call from C++ that associates
  202. # the C++ op with a python module.
  203. REQUIRES_SET_PYTHON_MODULE = False
  204. def maybe_upload_prof_stats_to_manifold(profile_path: str) -> str | None:
  205. print("Uploading profile stats (fb-only otherwise no-op)")
  206. return None
  207. def log_chromium_event_internal(
  208. event: dict[str, Any],
  209. stack: list[str],
  210. logger_uuid: str,
  211. start_time_ns: int,
  212. ):
  213. return None
  214. def record_chromium_event_internal(
  215. event: dict[str, Any],
  216. ):
  217. return None
  218. def profiler_allow_cudagraph_cupti_lazy_reinit_cuda12():
  219. return True
  220. def deprecated():
  221. """
  222. When we deprecate a function that might still be in use, we make it internal
  223. by adding a leading underscore. This decorator is used with a private function,
  224. and creates a public alias without the leading underscore, but has a deprecation
  225. warning. This tells users "THIS FUNCTION IS DEPRECATED, please use something else"
  226. without breaking them, however, if they still really really want to use the
  227. deprecated function without the warning, they can do so by using the internal
  228. function name.
  229. """
  230. def decorator(func: Callable[_P, _T]) -> Callable[_P, _T]:
  231. # Validate naming convention - single leading underscore, not dunder
  232. if not (func.__name__.startswith("_")):
  233. raise ValueError(
  234. "@deprecate must decorate a function whose name "
  235. "starts with a single leading underscore (e.g. '_foo') as the api should be considered internal for deprecation."
  236. )
  237. public_name = func.__name__[1:] # drop exactly one leading underscore
  238. module = sys.modules[func.__module__]
  239. # Don't clobber an existing symbol accidentally.
  240. if hasattr(module, public_name):
  241. raise RuntimeError(
  242. f"Cannot create alias '{public_name}' -> symbol already exists in {module.__name__}. \
  243. Please rename it or consult a pytorch developer on what to do"
  244. )
  245. warning_msg = f"{func.__name__[1:]} is DEPRECATED, please consider using an alternative API(s). "
  246. # public deprecated alias
  247. alias = typing_extensions.deprecated(
  248. # pyrefly: ignore [bad-argument-type]
  249. warning_msg,
  250. category=UserWarning,
  251. stacklevel=1,
  252. )(func)
  253. alias.__name__ = public_name
  254. # Adjust qualname if nested inside a class or another function
  255. if "." in func.__qualname__:
  256. alias.__qualname__ = func.__qualname__.rsplit(".", 1)[0] + "." + public_name
  257. else:
  258. alias.__qualname__ = public_name
  259. setattr(module, public_name, alias)
  260. return func
  261. return decorator
  262. def get_default_numa_options():
  263. """
  264. When using elastic agent, if no numa options are provided, we will use these
  265. as the default.
  266. For external use cases, we return None, i.e. no numa binding. If you would like
  267. to use torch's automatic numa binding capabilities, you should provide
  268. NumaOptions to your launch config directly or use the numa binding option
  269. available in torchrun.
  270. Must return None or NumaOptions, but not specifying to avoid circular import.
  271. """
  272. return None
  273. def log_triton_builds(fail: str | None):
  274. pass
  275. def find_compile_subproc_binary() -> str | None:
  276. """
  277. Allows overriding the binary used for subprocesses
  278. """
  279. return None