_internal.py 51 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import functools
  4. import hashlib
  5. import importlib.util
  6. import itertools
  7. import json
  8. import logging
  9. import os
  10. import os.path
  11. import pathlib
  12. import re
  13. import sys
  14. import tempfile
  15. import time
  16. import warnings
  17. from collections import defaultdict
  18. from dataclasses import dataclass, field
  19. from typing import Any, Callable, Generic, Optional, Union
  20. from typing_extensions import ParamSpec
  21. from weakref import WeakSet
  22. import torch._logging.structured
  23. from torch._guards import CompileId
  24. from torch._utils_internal import log_trace_structured_event
  25. from torch.utils._traceback import CapturedTraceback
  26. _P = ParamSpec("_P")
  27. log = logging.getLogger(__name__)
  28. # This is a synthetic logger which doesn't correspond to an actual logger,
  29. # but handles all of our "tracing" logging, which is structured and doesn't go
  30. # to stderr but always goes to a dedicated log file. We don't put these
  31. # loggers in the classic module hierarchy, because we don't want a suppression
  32. # of logs to also cause a trace to get suppressed (traces typically are not
  33. # collected, unless we are in prod, in which case they always are collected.)
  34. #
  35. # TODO: Maybe we should allow for some sub-hierarchy so you can control which
  36. # traces you want to collect, for performance reasons.
  37. #
  38. # See https://docs.google.com/document/d/1CX_hJ0PNy9f3R1y8TJrfkSeLkvGjjjLU84BSXgS2AZ8/edit
  39. trace_log = logging.getLogger("torch.__trace")
  40. DEFAULT_LOG_LEVEL = logging.WARNING
  41. LOG_ENV_VAR = "TORCH_LOGS"
  42. LOG_OUT_ENV_VAR = "TORCH_LOGS_OUT"
  43. LOG_FORMAT_ENV_VAR = "TORCH_LOGS_FORMAT"
  44. LOG_TRACE_ID_FILTER = "TORCH_LOGS_TRACE_ID_FILTER"
  45. TRACE_ENV_VAR = "TORCH_TRACE"
  46. DTRACE_ENV_VAR = "TORCH_DTRACE"
  47. LOG_TRACE_HANDLER: Optional["LazyTraceHandler"] = None
  48. GET_DTRACE_STRUCTURED = False
  49. @dataclass
  50. class LogRegistry:
  51. # shorthand name to log qualified name
  52. # Note: this only contains loggers registered
  53. # from register_log
  54. # e.g. "dynamo" -> "torch._dynamo"
  55. log_alias_to_log_qnames: dict[str, list[str]] = field(default_factory=dict)
  56. # artifact logger qualified names,
  57. # this is populated lazily, as calls to getArtifactLogger
  58. # currently formatted as <module>.__<artifact_name>
  59. # e.g. "torch._dynamo.convert_frame.__guards"
  60. artifact_log_qnames: set[str] = field(default_factory=set)
  61. # child logs of registered logs if specified via open
  62. # registration by the user (ie placing "torch._dynamo.output_graph" in the env var)
  63. # these need to be tracked so their levels can be reset properly
  64. # e.g. "torch._dynamo.output_graph"
  65. child_log_qnames: set[str] = field(default_factory=set)
  66. # artifact names, populated by register_artifact
  67. # e.g. "guards"
  68. artifact_names: set[str] = field(default_factory=set)
  69. # Artifacts that should be visible by default in the error message
  70. visible_artifacts: set[str] = field(default_factory=set)
  71. # A short description of each artifact
  72. artifact_descriptions: dict[str, str] = field(default_factory=dict)
  73. # artifacts which are not displayed unless explicitly named in the
  74. # settings. Ex. output_code is NOT displayed even if the inductor
  75. # log level is set to DEBUG. It must be explicitly named in the settings
  76. off_by_default_artifact_names: set[str] = field(default_factory=set)
  77. # logging format string for artifacts
  78. artifact_log_formatters: dict[str, logging.Formatter] = field(default_factory=dict)
  79. def is_artifact(self, name):
  80. return name in self.artifact_names
  81. def is_log(self, alias):
  82. return alias in self.log_alias_to_log_qnames
  83. # register a log with an alias
  84. def register_log(self, alias, log_qnames: Union[str, list[str]]) -> None:
  85. if isinstance(log_qnames, str):
  86. log_qnames = [log_qnames]
  87. self.log_alias_to_log_qnames[alias] = log_qnames
  88. # register an artifact name
  89. def register_artifact_name(
  90. self, name, description, visible, off_by_default, log_format
  91. ) -> None:
  92. self.artifact_names.add(name)
  93. if visible:
  94. self.visible_artifacts.add(name)
  95. self.artifact_descriptions[name] = description
  96. # if off by default, don't enable it
  97. # when log_name's log_level is set to DEBUG
  98. if off_by_default:
  99. self.off_by_default_artifact_names.add(name)
  100. if log_format is not None:
  101. self.artifact_log_formatters[name] = logging.Formatter(log_format)
  102. # register the qualified name of an artifact log
  103. # this is needed to know which logs need to be reset
  104. # whenever the log_state is changed
  105. def register_artifact_log(self, artifact_log_qname) -> None:
  106. self.artifact_log_qnames.add(artifact_log_qname)
  107. def register_child_log(self, log_qname) -> None:
  108. self.child_log_qnames.add(log_qname)
  109. # flattens all the qnames together (TODO: consider memoizing?)
  110. def get_log_qnames(self) -> set[str]:
  111. return set(itertools.chain.from_iterable(self.log_alias_to_log_qnames.values()))
  112. def get_artifact_log_qnames(self):
  113. return set(self.artifact_log_qnames)
  114. def get_child_log_qnames(self):
  115. return set(self.child_log_qnames)
  116. def is_off_by_default(self, artifact_qname):
  117. return artifact_qname in self.off_by_default_artifact_names
  118. @dataclass
  119. class LogState:
  120. # qualified log names -> currently set log level
  121. log_qname_to_level: dict[str, str] = field(default_factory=dict)
  122. # the set of currently enabled artifacts
  123. artifact_names: set[str] = field(default_factory=set)
  124. def enable_artifact(self, artifact_name) -> None:
  125. self.artifact_names.add(artifact_name)
  126. def is_artifact_enabled(self, name):
  127. return name in self.artifact_names
  128. def enable_log(self, log_qnames, log_level) -> None:
  129. if isinstance(log_qnames, str):
  130. log_qnames = [log_qnames]
  131. for log_qname in log_qnames:
  132. self.log_qname_to_level[log_qname] = log_level
  133. def get_log_level_pairs(self):
  134. """Returns all qualified module names for which the user requested
  135. explicit logging settings.
  136. .. warning:
  137. This function used to return all loggers, regardless of whether
  138. or not the user specified them or not; it now only returns logs
  139. which were explicitly mentioned by the user (and torch, which
  140. always is implicitly requested when we initialize our logging
  141. subsystem.)
  142. """
  143. return self.log_qname_to_level.items()
  144. def clear(self) -> None:
  145. self.log_qname_to_level.clear()
  146. self.artifact_names.clear()
  147. log_registry = LogRegistry()
  148. log_state = LogState()
  149. # sample usage: torch._logging.set_logs(**torch._logging.DEFAULT_LOGGING)
  150. DEFAULT_LOGGING = {
  151. "dynamo": logging.INFO,
  152. "aot": logging.INFO,
  153. "inductor": logging.INFO,
  154. "fsdp": logging.INFO,
  155. "ddp_graphs": True,
  156. "graph_breaks": True,
  157. "guards": True,
  158. "recompiles": True,
  159. "dynamic": logging.INFO,
  160. }
  161. def set_logs(
  162. *,
  163. all: Optional[int] = None,
  164. dynamo: Optional[int] = None,
  165. aot: Optional[int] = None,
  166. autograd: Optional[int] = None,
  167. dynamic: Optional[int] = None,
  168. inductor: Optional[int] = None,
  169. distributed: Optional[int] = None,
  170. c10d: Optional[int] = None,
  171. ddp: Optional[int] = None,
  172. fsdp: Optional[int] = None,
  173. dtensor: Optional[int] = None,
  174. onnx: Optional[int] = None,
  175. bytecode: bool = False,
  176. aot_graphs: bool = False,
  177. aot_joint_graph: bool = False,
  178. ddp_graphs: bool = False,
  179. graph: bool = False,
  180. graph_code: bool = False,
  181. graph_code_verbose: bool = False,
  182. graph_breaks: bool = False,
  183. graph_sizes: bool = False,
  184. guards: bool = False,
  185. recompiles: bool = False,
  186. recompiles_verbose: bool = False,
  187. trace_source: bool = False,
  188. trace_call: bool = False,
  189. trace_bytecode: bool = False,
  190. output_code: bool = False,
  191. kernel_code: bool = False,
  192. schedule: bool = False,
  193. perf_hints: bool = False,
  194. pre_grad_graphs: bool = False,
  195. post_grad_graphs: bool = False,
  196. ir_pre_fusion: bool = False,
  197. ir_post_fusion: bool = False,
  198. onnx_diagnostics: bool = False,
  199. fusion: bool = False,
  200. overlap: bool = False,
  201. export: Optional[int] = None,
  202. modules: Optional[dict[str, Union[int, bool]]] = None,
  203. cudagraphs: bool = False,
  204. sym_node: bool = False,
  205. compiled_autograd: bool = False,
  206. compiled_autograd_verbose: bool = False,
  207. cudagraph_static_inputs: bool = False,
  208. benchmarking: bool = False,
  209. autotuning: bool = False,
  210. graph_region_expansion: bool = False,
  211. inductor_metrics: bool = False,
  212. hierarchical_compile: bool = False,
  213. compute_dependencies: bool = False,
  214. ) -> None:
  215. """
  216. Sets the log level for individual components and toggles individual log
  217. artifact types.
  218. .. warning:: This feature is a prototype and may have compatibility
  219. breaking changes in the future.
  220. .. note:: The ``TORCH_LOGS`` environment variable has complete precedence
  221. over this function, so if it was set, this function does nothing.
  222. A component is a set of related features in PyTorch. All of the log
  223. messages emitted from a given component have their own log levels. If the
  224. log level of a particular message has priority greater than or equal to its
  225. component's log level setting, it is emitted. Otherwise, it is suppressed.
  226. This allows you to, for instance, silence large groups of log messages that
  227. are not relevant to you and increase verbosity of logs for components that
  228. are relevant. The expected log level values, ordered from highest to lowest
  229. priority, are:
  230. * ``logging.CRITICAL``
  231. * ``logging.ERROR``
  232. * ``logging.WARNING``
  233. * ``logging.INFO``
  234. * ``logging.DEBUG``
  235. * ``logging.NOTSET``
  236. See documentation for the Python ``logging`` module for more information on
  237. log levels: `<https://docs.python.org/3/library/logging.html#logging-levels>`_
  238. An artifact is a particular type of log message. Each artifact is assigned
  239. to a parent component. A component can emit many different kinds of
  240. artifacts. In general, an artifact is emitted if either its corresponding
  241. setting in the argument list below is turned on or if its parent component
  242. is set to a log level less than or equal to the log level of the artifact.
  243. Keyword args:
  244. all (:class:`Optional[int]`):
  245. The default log level for all components. Default: ``logging.WARN``
  246. dynamo (:class:`Optional[int]`):
  247. The log level for the TorchDynamo component. Default: ``logging.WARN``
  248. aot (:class:`Optional[int]`):
  249. The log level for the AOTAutograd component. Default: ``logging.WARN``
  250. autograd (:class:`Optional[int]`):
  251. The log level for autograd. Default: ``logging.WARN``
  252. inductor (:class:`Optional[int]`):
  253. The log level for the TorchInductor component. Default: ``logging.WARN``
  254. dynamic (:class:`Optional[int]`):
  255. The log level for dynamic shapes. Default: ``logging.WARN``
  256. distributed (:class:`Optional[int]`):
  257. Whether to log c10d communication operations and other debug info from PyTorch Distributed components.
  258. Default: ``logging.WARN``
  259. c10d (:class:`Optional[int]`):
  260. Whether to log c10d communication operations related debug info in PyTorch Distributed components.
  261. Default: ``logging.WARN``
  262. ddp (:class:`Optional[int]`):
  263. Whether to log debug info related to ``DistributedDataParallel``(DDP) from PyTorch Distributed components.
  264. Default: ``logging.WARN``
  265. fsdp (:class:`Optional[int]`):
  266. Whether to log debug info related to ``FullyShardedDataParallel``(FSDP) in PyTorch Distributed components.
  267. Default: ``logging.WARN``
  268. dtensor (:class:`Optional[int]`):
  269. Whether to log debug info related to ``DTensor``(DTensor) in PyTorch Distributed components.
  270. Default: ``logging.WARN``
  271. onnx (:class:`Optional[int]`):
  272. The log level for the ONNX exporter component. Default: ``logging.WARN``
  273. bytecode (:class:`bool`):
  274. Whether to emit the original and generated bytecode from TorchDynamo.
  275. Default: ``False``
  276. aot_graphs (:class:`bool`):
  277. Whether to emit the graphs generated by AOTAutograd. Default: ``False``
  278. aot_joint_graph (:class:`bool`):
  279. Whether to emit the joint forward-backward graph generated by AOTAutograd. Default: ``False``
  280. ddp_graphs (:class:`bool`):
  281. Whether to emit graphs generated by DDPOptimizer. Default: ``False``
  282. graph (:class:`bool`):
  283. Whether to emit the graph captured by TorchDynamo in tabular format.
  284. Default: ``False``
  285. graph_code (:class:`bool`):
  286. Whether to emit the python source of the graph captured by TorchDynamo.
  287. Default: ``False``
  288. graph_code_verbose (:class:`bool`):
  289. Whether to emit verbose/intermediate FX pass logs for graph code. Default: ``False``
  290. graph_breaks (:class:`bool`):
  291. Whether to emit the graph breaks encountered by TorchDynamo.
  292. Default: ``False``
  293. graph_sizes (:class:`bool`):
  294. Whether to emit tensor sizes of the graph captured by TorchDynamo.
  295. Default: ``False``
  296. guards (:class:`bool`):
  297. Whether to emit the guards generated by TorchDynamo for each compiled
  298. function. Default: ``False``
  299. recompiles (:class:`bool`):
  300. Whether to emit a guard failure reason and message every time
  301. TorchDynamo recompiles a function. Default: ``False``
  302. recompiles_verbose (:class:`bool`):
  303. Whether to emit all guard failure reasons when TorchDynamo recompiles
  304. a function, even those that are not actually run. Default: ``False``
  305. trace_source (:class:`bool`):
  306. Whether to emit when TorchDynamo begins tracing a new line. Default: ``False``
  307. trace_call (:class:`bool`):
  308. Whether to emit detailed line location when TorchDynamo creates an FX node
  309. corresponding to function call. Python 3.11+ only. Default: ``False``
  310. trace_bytecode (:class:`bool`):
  311. Whether to emit bytecode instructions and traced stack state as TorchDynamo
  312. traces bytecode. Default: ``False``
  313. output_code (:class:`bool`):
  314. Whether to emit the TorchInductor output code on a per-graph basis. Default: ``False``
  315. kernel_code (:class:`bool`):
  316. Whether to emit the TorchInductor output code on a per-kernel bases. Default: ``False``
  317. schedule (:class:`bool`):
  318. Whether to emit the TorchInductor schedule. Default: ``False``
  319. perf_hints (:class:`bool`):
  320. Whether to emit the TorchInductor perf hints. Default: ``False``
  321. pre_grad_graphs (:class:`bool`):
  322. Whether to emit the graphs before inductor grad passes. Default: ``False``
  323. post_grad_graphs (:class:`bool`):
  324. Whether to emit the graphs generated by after post grad passes. Default: ``False``
  325. ir_pre_fusion (:class:`bool`):
  326. Whether to emit the graphs before inductor fusion passes. Default: ``False``
  327. ir_post_fusion (:class:`bool`):
  328. Whether to emit the graphs after inductor fusion passes. Default: ``False``
  329. onnx_diagnostics (:class:`bool`):
  330. Whether to emit the ONNX exporter diagnostics in logging. Default: ``False``
  331. fusion (:class:`bool`):
  332. Whether to emit detailed Inductor fusion decisions. Default: ``False``
  333. overlap (:class:`bool`):
  334. Whether to emit detailed Inductor compute/comm overlap decisions. Default: ``False``
  335. sym_node (:class:`bool`):
  336. Whether to emit debug info for various SymNode opterations. Default: ``False``
  337. export (:class:`Optional[int]`):
  338. The log level for export. Default: ``logging.WARN``
  339. benchmarking (:class:`bool`):
  340. Whether to emit detailed Inductor benchmarking information. Default: ``False``
  341. modules (dict):
  342. This argument provides an alternate way to specify the above log
  343. component and artifact settings, in the format of a keyword args
  344. dictionary given as a single argument. There are two cases
  345. where this is useful (1) if a new log component or artifact has
  346. been registered but a keyword argument for it has not been added
  347. to this function and (2) if the log level for an unregistered module
  348. needs to be set. This can be done by providing the fully-qualified module
  349. name as the key, with the log level as the value. Default: ``None``
  350. cudagraph_static_inputs (:class:`bool`):
  351. Whether to emit debug info for cudagraph static input detection. Default: ``False``
  352. autotuning (:class:`bool`):
  353. Autotuning choice logs, such as kernel source, perf, and tuning parameters. Default: ``False``
  354. graph_region_expansion (:class:`bool`):
  355. Whether to emit the detailed steps of the duplicate graph region tracker expansion algorithm. Default: ``False``
  356. inductor_metrics (:class:`bool`):
  357. Whether to estimate the runtimes of the nodes in a graph and log them to the metrics table. Default: ``False``
  358. hierarchical_compile (:class:`bool`):
  359. Whether to emit debug info for hierarchical compilation. Default: ``False``
  360. Example::
  361. >>> # xdoctest: +SKIP
  362. >>> import logging
  363. # The following changes the "dynamo" component to emit DEBUG-level
  364. # logs, and to emit "graph_code" artifacts.
  365. >>> torch._logging.set_logs(dynamo=logging.DEBUG, graph_code=True)
  366. # The following enables the logs for a different module
  367. >>> torch._logging.set_logs(modules={"unregistered.module.name": logging.DEBUG})
  368. """
  369. # ignore if env var is set
  370. if LOG_ENV_VAR in os.environ:
  371. log.warning(
  372. "Using TORCH_LOGS environment variable for log settings, ignoring call to set_logs"
  373. )
  374. return
  375. log_state.clear()
  376. modules = modules or {}
  377. def _set_logs(**kwargs) -> None:
  378. for alias, val in itertools.chain(kwargs.items(), modules.items()): # type: ignore[union-attr]
  379. if val is None:
  380. continue
  381. if log_registry.is_artifact(alias):
  382. if not isinstance(val, bool):
  383. raise ValueError(
  384. f"Expected bool to enable artifact {alias}, received {val}"
  385. )
  386. if val:
  387. log_state.enable_artifact(alias)
  388. elif log_registry.is_log(alias) or alias in log_registry.child_log_qnames:
  389. if val not in logging._levelToName:
  390. raise ValueError(
  391. f"Unrecognized log level for log {alias}: {val}, valid level values "
  392. f"are: {','.join([str(k) for k in logging._levelToName.keys()])}"
  393. )
  394. log_state.enable_log(
  395. log_registry.log_alias_to_log_qnames.get(alias, alias), val
  396. )
  397. elif _is_valid_module(alias):
  398. if not _has_registered_parent(alias):
  399. log_registry.register_log(alias, alias)
  400. else:
  401. log_registry.register_child_log(alias)
  402. log_state.enable_log(
  403. log_registry.log_alias_to_log_qnames.get(alias, alias), val
  404. )
  405. else:
  406. raise ValueError(
  407. f"Unrecognized log or artifact name passed to set_logs: {alias}"
  408. )
  409. _init_logs()
  410. _set_logs(
  411. torch=all,
  412. dynamo=dynamo,
  413. aot=aot,
  414. autograd=autograd,
  415. inductor=inductor,
  416. dynamic=dynamic,
  417. bytecode=bytecode,
  418. aot_graphs=aot_graphs,
  419. aot_joint_graph=aot_joint_graph,
  420. ddp_graphs=ddp_graphs,
  421. distributed=distributed,
  422. c10d=c10d,
  423. ddp=ddp,
  424. fsdp=fsdp,
  425. dtensor=dtensor,
  426. graph=graph,
  427. graph_code=graph_code,
  428. graph_code_verbose=graph_code_verbose,
  429. graph_breaks=graph_breaks,
  430. graph_sizes=graph_sizes,
  431. guards=guards,
  432. recompiles=recompiles,
  433. recompiles_verbose=recompiles_verbose,
  434. trace_source=trace_source,
  435. trace_call=trace_call,
  436. trace_bytecode=trace_bytecode,
  437. output_code=output_code,
  438. kernel_code=kernel_code,
  439. schedule=schedule,
  440. perf_hints=perf_hints,
  441. pre_grad_graphs=pre_grad_graphs,
  442. post_grad_graphs=post_grad_graphs,
  443. ir_pre_fusion=ir_pre_fusion,
  444. ir_post_fusion=ir_post_fusion,
  445. onnx=onnx,
  446. onnx_diagnostics=onnx_diagnostics,
  447. fusion=fusion,
  448. overlap=overlap,
  449. sym_node=sym_node,
  450. export=export,
  451. cudagraphs=cudagraphs,
  452. compiled_autograd=compiled_autograd,
  453. compiled_autograd_verbose=compiled_autograd_verbose,
  454. cudagraph_static_inputs=cudagraph_static_inputs,
  455. benchmarking=benchmarking,
  456. autotuning=autotuning,
  457. graph_region_expansion=graph_region_expansion,
  458. inductor_metrics=inductor_metrics,
  459. hierarchical_compile=hierarchical_compile,
  460. compute_dependencies=compute_dependencies,
  461. )
  462. def get_loggers() -> list[logging.Logger]:
  463. """
  464. Returns: a list of all registered loggers
  465. """
  466. return [logging.getLogger(qname) for qname in log_registry.get_log_qnames()]
  467. def register_log(setting_name, log_name) -> None:
  468. """
  469. Enables a log to be controlled by the env var and user API with the setting_name
  470. Args:
  471. setting_name: the shorthand name used in the env var and user API
  472. log_name: the log name that the setting_name is associated with
  473. """
  474. log_registry.register_log(setting_name, log_name)
  475. def register_artifact(
  476. setting_name, description, visible=False, off_by_default=False, log_format=None
  477. ) -> None:
  478. """
  479. Enables an artifact to be controlled by the env var and user API with name
  480. Args:
  481. setting_name: the shorthand name used in the env var and user API
  482. description: A description of what this outputs
  483. visible: Whether it gets suggested to users by default
  484. off_by_default: whether this artifact should be logged when the ancestor loggers
  485. are enabled at level DEBUG
  486. """
  487. log_registry.register_artifact_name(
  488. setting_name, description, visible, off_by_default, log_format
  489. )
  490. def getArtifactLogger(module_qname, artifact_name) -> logging.Logger:
  491. if artifact_name not in log_registry.artifact_names:
  492. raise ValueError(
  493. f"Artifact name: {repr(artifact_name)} not registered,"
  494. f"please call register_artifact({repr(artifact_name)}) in torch._logging.registrations."
  495. )
  496. qname = module_qname + f".__{artifact_name}"
  497. log = logging.getLogger(qname)
  498. log.artifact_name = artifact_name # type: ignore[attr-defined]
  499. log_registry.register_artifact_log(qname)
  500. configure_artifact_log(log)
  501. return log
  502. INCR_VERBOSITY_CHAR = "+"
  503. DECR_VERBOSITY_CHAR = "-"
  504. VERBOSITY_REGEX = (
  505. "("
  506. + "|".join([re.escape(INCR_VERBOSITY_CHAR), re.escape(DECR_VERBOSITY_CHAR)])
  507. + "?)"
  508. )
  509. def configure_artifact_log(log) -> None:
  510. # If the artifact is off by default, then it should only be logged when explicitly
  511. # enabled; set propagate to False so that this artifact is not propagated
  512. # to its ancestor logger
  513. if log_registry.is_off_by_default(log.artifact_name):
  514. log.propagate = False
  515. # enable artifact logging when explicitly enabled
  516. if log_state.is_artifact_enabled(log.artifact_name):
  517. log.setLevel(logging.DEBUG)
  518. log.propagate = True
  519. # match a comma separated list of loggable names (whitespace allowed after commas)
  520. def _gen_settings_regex():
  521. return re.compile(r"((\+|-)?[\w\.]+,\s*)*(\+|-)?[\w\.]+?")
  522. def _validate_settings(settings):
  523. return re.fullmatch(_gen_settings_regex(), settings) is not None
  524. def help_message(verbose=False):
  525. def pad_to(s, length=30):
  526. assert len(s) <= length
  527. return s + " " * (length - len(s))
  528. if verbose:
  529. printed_artifacts = log_registry.artifact_names
  530. else:
  531. printed_artifacts = log_registry.visible_artifacts
  532. if verbose:
  533. heading = "All registered names"
  534. else:
  535. heading = "Visible registered names (use TORCH_LOGS='+help' for full list)"
  536. lines = (
  537. ["all"]
  538. + sorted(log_registry.log_alias_to_log_qnames.keys())
  539. + sorted(
  540. [
  541. f"{pad_to(name)}\t{log_registry.artifact_descriptions[name]}"
  542. for name in printed_artifacts
  543. ]
  544. )
  545. )
  546. setting_info = " " + "\n ".join(lines)
  547. examples = """
  548. Examples:
  549. TORCH_LOGS="+dynamo,aot" will set the log level of TorchDynamo to
  550. logging.DEBUG and AOT to logging.INFO
  551. TORCH_LOGS="-dynamo,+inductor" will set the log level of TorchDynamo to
  552. logging.ERROR and TorchInductor to logging.DEBUG
  553. TORCH_LOGS="aot_graphs" will enable the aot_graphs artifact
  554. TORCH_LOGS="+dynamo,schedule" will enable set the log level of TorchDynamo
  555. to logging.DEBUG and enable the schedule artifact
  556. TORCH_LOGS="+some.random.module,schedule" will set the log level of
  557. some.random.module to logging.DEBUG and enable the schedule artifact
  558. TORCH_LOGS_FORMAT="%(levelname)s: %(message)s" or any provided format
  559. string will set the output format
  560. Valid keys are "levelname", "message", "pathname", "levelno", "lineno",
  561. "filename" and "name".
  562. TORCH_LOGS_OUT=/tmp/output.txt will output the logs to /tmp/output.txt as
  563. well. This is useful when the output is long.
  564. """ # flake8: noqa: B950
  565. msg = f"""
  566. TORCH_LOGS Info
  567. {examples}
  568. {heading}
  569. {setting_info}
  570. """
  571. return msg
  572. def _invalid_settings_err_msg(settings, verbose=False):
  573. valid_settings = (
  574. ["all"]
  575. + list(log_registry.log_alias_to_log_qnames.keys())
  576. + list(log_registry.artifact_names)
  577. )
  578. valid_settings = ", ".join(sorted(valid_settings))
  579. msg = f"""
  580. Invalid log settings: {settings}, must be a comma separated list of fully
  581. qualified module names, registered log names or registered artifact names.
  582. For more info on various settings, try TORCH_LOGS="help"
  583. Valid settings:
  584. {valid_settings}
  585. """
  586. return msg
  587. def process_env_var_string_for_windows(env_var_str: str) -> str:
  588. """
  589. When we setup logging config as guide: https://docs.pytorch.org/docs/stable/logging.html
  590. Such as:
  591. TORCH_LOGS="+schedule,+inductor,+output_code"
  592. On Linux, it shows as:
  593. declare -x SSH_TTY="/dev/pts/0"
  594. declare -x TERM="xterm"
  595. declare -x TORCH_LOGS="+schedule,+inductor,+output_code"
  596. declare -x USER="xu"
  597. On Windows, it shows as:
  598. TORCHINDUCTOR_WINDOWS_TESTS=1
  599. TORCH_LOGS="+schedule,+inductor,+output_code"
  600. UCRTVersion=10.0.22000.0
  601. For Linux, it shows quotes by default, And Windows is not shows quotes.
  602. Besides that, Windows would auto assemble quotes when env var processing.
  603. On Linux, we will get variable: "+schedule,+inductor,+output_code"
  604. On Windows, we will get variable: '"+schedule,+inductor,+output_code"'
  605. So, we need remove the outer quotes for Windows.
  606. """
  607. _IS_WINDOWS = sys.platform == "win32"
  608. def remove_outer_quotes(s: str) -> str:
  609. if len(s) >= 2 and (
  610. (s[0] == '"' and s[-1] == '"') or (s[0] == "'" and s[-1] == "'")
  611. ):
  612. return s[1:-1]
  613. return s
  614. if _IS_WINDOWS:
  615. env_var_str = remove_outer_quotes(env_var_str)
  616. return env_var_str
  617. @functools.lru_cache
  618. def _parse_log_settings(settings):
  619. settings = process_env_var_string_for_windows(settings)
  620. if settings == "":
  621. return {}
  622. if settings == "help":
  623. raise ValueError(help_message(verbose=False))
  624. elif settings == "+help":
  625. raise ValueError(help_message(verbose=True))
  626. if not _validate_settings(settings):
  627. raise ValueError(_invalid_settings_err_msg(settings))
  628. settings = re.sub(r"\s+", "", settings)
  629. log_names = settings.split(",")
  630. def get_name_level_pair(name):
  631. clean_name = name.replace(INCR_VERBOSITY_CHAR, "")
  632. clean_name = clean_name.replace(DECR_VERBOSITY_CHAR, "")
  633. if name[0] == INCR_VERBOSITY_CHAR:
  634. level = logging.DEBUG
  635. elif name[0] == DECR_VERBOSITY_CHAR:
  636. level = logging.ERROR
  637. else:
  638. level = logging.INFO
  639. return clean_name, level
  640. log_state = LogState()
  641. for name in log_names:
  642. name, level = get_name_level_pair(name)
  643. if name == "all":
  644. name = "torch"
  645. if log_registry.is_log(name):
  646. assert level is not None
  647. log_qnames = log_registry.log_alias_to_log_qnames[name]
  648. log_state.enable_log(log_qnames, level)
  649. elif log_registry.is_artifact(name):
  650. log_state.enable_artifact(name)
  651. elif _is_valid_module(name):
  652. if not _has_registered_parent(name):
  653. log_registry.register_log(name, name)
  654. else:
  655. log_registry.register_child_log(name)
  656. log_state.enable_log(name, level)
  657. else:
  658. raise ValueError(_invalid_settings_err_msg(settings))
  659. return log_state
  660. def _is_valid_module(qname):
  661. spec = importlib.util.find_spec(qname)
  662. return spec is not None
  663. def _update_log_state_from_env() -> None:
  664. global log_state
  665. log_setting = os.environ.get(LOG_ENV_VAR, None)
  666. if log_setting is not None:
  667. log_state = _parse_log_settings(log_setting)
  668. def _has_registered_parent(log_qname) -> bool:
  669. cur_log = logging.getLogger(log_qname)
  670. registered_log_qnames = log_registry.get_log_qnames()
  671. while cur_log.parent:
  672. if cur_log.name in registered_log_qnames:
  673. return True
  674. cur_log = cur_log.parent
  675. return False
  676. def make_module_path_relative(abs_path):
  677. """
  678. Given an absolute filepath corresponding to a Python module which was
  679. loaded via normal import mechanisms using sys.path, convert it into
  680. a relative path relative to one of the Python search paths.
  681. """
  682. abs_path = pathlib.Path(abs_path).resolve()
  683. for path in sys.path:
  684. try:
  685. rel_path = abs_path.relative_to(path)
  686. except ValueError:
  687. continue
  688. else:
  689. return str(rel_path)
  690. return str(abs_path)
# apply custom formats to artifacts when necessary
class TorchLogsFormatter(logging.Formatter):
    """Formatter for torch loggers that emits glog-style prefixed lines.

    In normal mode every output line is prefixed with, in order:

    * ``[rankN]:`` when torch.distributed is initialized;
    * a one-letter level abbreviation (V/I/W/E/C);
    * a ``%m%d %H:%M:%S`` timestamp with a six-digit fractional part;
    * the process id;
    * ``<sys.path-relative file>:<lineno>]``;
    * `` [<trace id>]`` when a compile trace id is current;
    * `` [__<artifact>]`` for artifact loggers.

    Records from an artifact logger that has a custom formatter registered in
    ``log_registry.artifact_log_formatters`` are delegated to that formatter
    instead.

    With ``trace=True`` the formatter serializes structured-trace records. The
    rendered message must be empty, and the record must carry ``metadata``
    (JSON-serialized onto the prefix line) and ``payload`` (``None`` or a
    string whose lines are appended tab-indented). ``trace_structured`` in
    this file supplies both via ``extra=``. Rank and trace-id prefixes are not
    added in trace mode.
    """

    def __init__(
        self, *, trace: bool = False, trace_id_filter: Optional[set[str]] = None
    ) -> None:
        """
        Args:
            trace: format structured-trace records (see class docstring).
            trace_id_filter: if non-empty, records whose bracketed trace id
                (``"[" + str(trace_id) + "]"``, or ``""`` when there is none)
                is not in this set are formatted as ``""``.
        """
        super().__init__()
        self._is_trace = trace
        self._trace_id_filter = trace_id_filter

    def format(self, record):
        """Render ``record`` as one or more prefixed lines (see class docstring).

        Side effect: sets ``message``, ``asctime``, ``rankprefix``,
        ``traceid`` and ``artifactprefix`` attributes on ``record``, and may
        cache ``exc_text``.
        """
        # Artifact loggers carry an ``artifact_name`` attribute; those with a
        # registered custom formatter bypass the default layout entirely.
        artifact_name = getattr(logging.getLogger(record.name), "artifact_name", None)
        if artifact_name is not None:
            artifact_formatter = log_registry.artifact_log_formatters.get(
                artifact_name, None
            )
            if artifact_formatter is not None:
                return artifact_formatter.format(record)

        record.message = record.getMessage()
        record.asctime = self.formatTime(record, "%m%d %H:%M:%S")

        # exception handling - copied from logging.Formatter.format
        s = record.message
        if record.exc_info:
            # Cache the traceback text to avoid converting it multiple times
            # (it's constant anyway)
            if not record.exc_text:
                record.exc_text = self.formatException(record.exc_info)
        if record.exc_text:
            if s[-1:] != "\n":
                s = s + "\n"
            s = s + record.exc_text
        if record.stack_info:
            if s[-1:] != "\n":
                s = s + "\n"
            s = s + self.formatStack(record.stack_info)

        record.rankprefix = ""
        if not self._is_trace and dist.is_available() and dist.is_initialized():
            record.rankprefix = f"[rank{dist.get_rank()}]:"

        record.traceid = ""
        if (
            not self._is_trace
            and (trace_id := torch._guards.CompileContext.current_trace_id())
            is not None
        ):
            record.traceid = f" [{trace_id}]"

        glog_level_to_abbr = {
            "DEBUG": "V",  # V is for VERBOSE in glog
            "INFO": "I",
            "WARNING": "W",
            "ERROR": "E",
            "CRITICAL": "C",
        }

        # Custom level names fall back to the full level name.
        shortlevel = glog_level_to_abbr.get(record.levelname, record.levelname)

        record.artifactprefix = ""
        if artifact_name is not None:
            record.artifactprefix = f" [__{artifact_name}]"

        filepath = make_module_path_relative(record.pathname)

        # NOTE(review): returning "" does not stop the handler from emitting;
        # a logging.StreamHandler still writes its terminator, producing a blank
        # line for each filtered record. A logging.Filter on the handler would
        # drop them entirely — confirm whether blank lines are acceptable.
        if (
            self._trace_id_filter
            and record.traceid.strip() not in self._trace_id_filter
        ):
            return ""

        # record.msecs is the millisecond part of the creation time; scaling
        # by 1000 prints it as a six-digit fractional second.
        prefix = (
            f"{record.rankprefix}{shortlevel}{record.asctime}.{int(record.msecs * 1000):06d} {record.process} "
            f"{filepath}:"
            f"{record.lineno}]{record.traceid}{record.artifactprefix}"
        )
        if self._is_trace:
            # Structured records carry their content in metadata/payload,
            # never in the message text.
            assert s == ""
            try:
                r = f"{prefix} {json.dumps(record.metadata)}"
            except TypeError:
                log.warning("failing metadata: %r", record.metadata)
                raise
            if record.payload is not None:
                r += "".join(f"\n\t{l}" for l in record.payload.split("\n"))
            return r
        else:
            # Repeat the prefix on every line of multi-line messages/tracebacks.
            lines = s.split("\n")
            return "\n".join(f"{prefix} {l}" for l in lines)
  769. def _default_formatter():
  770. fmt = os.environ.get(LOG_FORMAT_ENV_VAR, None)
  771. trace_id_filter = {
  772. item.strip()
  773. for item in os.environ.get(LOG_TRACE_ID_FILTER, "").split(",")
  774. if item.strip()
  775. }
  776. if fmt is None:
  777. return TorchLogsFormatter(trace_id_filter=trace_id_filter)
  778. else:
  779. if fmt in ("short", "basic"):
  780. fmt = logging.BASIC_FORMAT
  781. return logging.Formatter(fmt)
# Formatter shared by every handler created through _setup_handlers. It is
# built once, when this module is executed, from the environment variables
# read by _default_formatter.
DEFAULT_FORMATTER = _default_formatter()
  783. def _setup_handlers(create_handler_fn, log) -> None:
  784. debug_handler = _track_handler(create_handler_fn())
  785. debug_handler.setFormatter(DEFAULT_FORMATTER)
  786. debug_handler.setLevel(logging.DEBUG)
  787. log.addHandler(debug_handler)
# Handlers created by torch logging (see _track_handler / _is_torch_handler).
# A WeakSet does not keep handlers alive; one that is garbage-collected
# elsewhere simply disappears from this set.
handlers = WeakSet()  # type: ignore[var-annotated]
  789. # mark handlers that we've created
  790. # so we don't modify user handlers
  791. def _track_handler(handler):
  792. handlers.add(handler)
  793. return handler
  794. def _is_torch_handler(handler):
  795. return handler in handlers
  796. # clears all torch handlers on specified loggers
  797. def _clear_handlers(log) -> None:
  798. to_remove = [handler for handler in log.handlers if _is_torch_handler(handler)]
  799. for handler in to_remove:
  800. log.removeHandler(handler)
  801. def _reset_logs() -> None:
  802. # reset all registered logs
  803. for log_qname in log_registry.get_log_qnames():
  804. log = logging.getLogger(log_qname)
  805. log.setLevel(logging.WARNING)
  806. log.propagate = False
  807. _clear_handlers(log)
  808. # reset all artifact and child logs
  809. for artifact_log_qname in itertools.chain(
  810. log_registry.get_artifact_log_qnames(), log_registry.get_child_log_qnames()
  811. ):
  812. log = logging.getLogger(artifact_log_qname)
  813. log.setLevel(logging.NOTSET)
  814. log.propagate = True
  815. trace_log.propagate = False
  816. _clear_handlers(trace_log)
  817. def _get_log_state():
  818. return log_state
  819. def _set_log_state(state) -> None:
  820. global log_state
  821. log_state = state
def _init_logs(log_file_name=None) -> None:
    """(Re)configure torch loggers and handlers from the environment.

    The steps run in this order, and the order matters:

    1. ``_reset_logs()``, then reload ``log_state`` from ``LOG_ENV_VAR``.
    2. ``LOG_OUT_ENV_VAR``, when set, overrides ``log_file_name``.
    3. Every registered log except "torch" is set to NOTSET.
    4. The levels requested in ``log_state`` are applied.
    5. Each registered log gets a stream handler, plus a file handler for
       ``log_file_name`` when one is given.
    6. Artifact loggers are configured last, because ancestor levels are
       taken into account.
    7. ``trace_log`` is set to DEBUG and wired to the shared
       ``LazyTraceHandler``, whose directory comes from ``TRACE_ENV_VAR`` or
       ``DTRACE_ENV_VAR``.

    Args:
        log_file_name: optional path of a file to mirror log output into.
            It is ignored when ``LOG_OUT_ENV_VAR`` is set.

    NOTE(review): ``GET_DTRACE_STRUCTURED`` is only ever set to True here and
    never reset. ``LOG_TRACE_HANDLER`` is created on the first call only, so
    a later call with a different ``TRACE_ENV_VAR`` keeps the first handler's
    directory.
    """
    global GET_DTRACE_STRUCTURED

    _reset_logs()
    _update_log_state_from_env()

    out = os.environ.get(LOG_OUT_ENV_VAR, None)
    if out is not None:
        log_file_name = out

    # First, reset all known (registered) loggers to NOTSET, so that they
    # respect their parent log level
    for log_qname in log_registry.get_log_qnames():
        # But not the top level torch level: this defaults to WARNING so
        # that our log messages don't leak to the lower levels
        if log_qname == "torch":
            continue
        log = logging.getLogger(log_qname)
        log.setLevel(logging.NOTSET)

    # Now, for all loggers which the user requested to have non-standard
    # logging behavior, modify their log levels
    for log_qname, level in log_state.get_log_level_pairs():
        log = logging.getLogger(log_qname)
        log.setLevel(level)

    # Finally, setup handlers for all registered loggers
    for log_qname in log_registry.get_log_qnames():
        log = logging.getLogger(log_qname)
        _setup_handlers(
            logging.StreamHandler,
            log,
        )

        if log_file_name is not None:
            _setup_handlers(
                lambda: logging.FileHandler(log_file_name),
                log,
            )

    # configure artifact loggers, note: this must happen last
    # since the levels of ancestor loggers are taken into account
    for artifact_log_qname in log_registry.get_artifact_log_qnames():
        log = logging.getLogger(artifact_log_qname)
        configure_artifact_log(log)

    # Setup handler for the special trace_log, with different default
    # configuration
    trace_dir_name = os.environ.get(TRACE_ENV_VAR, None)

    # A non-empty DTRACE_ENV_VAR both enables dtrace_structured output and
    # takes precedence over TRACE_ENV_VAR as the trace directory.
    if dtrace_dir_name := os.environ.get(DTRACE_ENV_VAR, None):
        GET_DTRACE_STRUCTURED = True
        trace_dir_name = dtrace_dir_name

    # This handler may remove itself if trace_dir_name is None and we are not
    # actually in an FB environment. This allows us to defer actually
    # initializing it until we actually need to log anything. This is
    # important because JK initializes a C++ singleton, which can break our
    # process if we subsequently fork.
    global LOG_TRACE_HANDLER
    if LOG_TRACE_HANDLER is None:
        LOG_TRACE_HANDLER = LazyTraceHandler(trace_dir_name)
    # This log is ALWAYS at debug level. We will additionally test if there
    # are any handlers before deciding to actually call logging on this. Do
    # not manually call
    trace_log.setLevel(logging.DEBUG)
    trace_log_handler = _track_handler(LOG_TRACE_HANDLER)
    trace_log_handler.setFormatter(TorchLogsFormatter(trace=True))
    trace_log.addHandler(trace_log_handler)
class LazyTraceHandler(logging.StreamHandler):
    """Like FileHandler, but the file is allocated lazily only upon the first log message.

    If ``root_dir`` is None, the first ``emit`` decides whether tracing is
    possible at all. When it is, the handler logs to ``/logs``; when it is
    not, the handler removes itself from ``trace_log`` and drops the record.
    Output goes to a ``NamedTemporaryFile`` named
    ``dedicated_log_torch_trace_[rank_N_]*.log`` in the chosen directory. The
    file is kept after closing (``delete=False``).
    """

    def __init__(self, root_dir: Optional[str]) -> None:
        # This is implemented in the same way that delay is implemented on
        # FileHandler: only the base Handler is initialized, and the stream
        # stays None until the first emit().
        self.root_dir = root_dir
        logging.Handler.__init__(self)
        self.stream = None
        # Mirrors FileHandler; not referenced elsewhere in this class.
        self._builtin_open = open

    # cloned from FileHandler in cpython
    def close(self) -> None:
        """Flush and close the stream (if one was opened), then close the handler."""
        self.acquire()
        try:
            try:
                if self.stream:
                    try:
                        self.flush()
                    finally:
                        stream = self.stream
                        self.stream = None
                        if hasattr(stream, "close"):
                            stream.close()
            finally:
                # Issue #19523: call unconditionally to
                # prevent a handler leak when delay is set
                # Also see Issue #42378: we also rely on
                # self._closed being set to True there
                logging.StreamHandler.close(self)
        finally:
            self.release()

    def emit(self, record) -> None:
        """Open the trace file on first use, then write ``record`` to it.

        If no usable directory can be determined, the handler detaches itself
        from ``trace_log`` and the record is dropped.
        """
        if self.stream is None:
            if self.root_dir is None:
                # No explicit directory: fall back to /logs only when every
                # check below passes.
                TRACE_LOG_DIR = "/logs"

                import torch.version as torch_version

                # Presumably git_version is present on OSS builds (hence the
                # "not fbcode" message) — TODO confirm.
                if (
                    hasattr(torch_version, "git_version")
                    and os.getenv("MAST_HPC_JOB_NAME") is None
                ):
                    log.info(
                        "LazyTraceHandler: disabled because not fbcode or conda on mast"
                    )
                elif not torch._utils_internal.justknobs_check("pytorch/trace:enable"):
                    log.info(
                        "LazyTraceHandler: disabled because justknobs_check('pytorch/trace:enable') returned False"
                    )
                elif not os.path.exists(TRACE_LOG_DIR):
                    log.info(
                        "LazyTraceHandler: disabled because %s does not exist",
                        TRACE_LOG_DIR,
                    )
                elif not os.access(TRACE_LOG_DIR, os.W_OK):
                    log.info(
                        "LazyTraceHandler: disabled because %s is not writeable",
                        TRACE_LOG_DIR,
                    )
                else:
                    self.root_dir = TRACE_LOG_DIR

            if self.root_dir is not None:
                os.makedirs(self.root_dir, exist_ok=True)
                ranksuffix = ""
                if dist.is_available() and dist.is_initialized():
                    ranksuffix = f"rank_{dist.get_rank()}_"
                # NOTE(review): no encoding is passed, so the platform default
                # text encoding is used for the trace file — confirm that
                # non-ASCII payloads are acceptable on all platforms.
                self.stream = tempfile.NamedTemporaryFile(
                    mode="w+",
                    suffix=".log",
                    prefix=f"dedicated_log_torch_trace_{ranksuffix}",
                    dir=self.root_dir,
                    delete=False,
                )
                log.info("LazyTraceHandler: logging to %s", self.stream.name)
            else:
                # We go poof, remove and no-op
                trace_log.removeHandler(self)
                return
        if self.stream:
            super().emit(record)
  958. @functools.cache
  959. def warning_once(logger_obj, *args, **kwargs) -> None:
  960. """
  961. This function is similar to `logger.warning()`, but will emit the warning with the same message only once
  962. Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the cache.
  963. The assumption here is that all warning messages are unique across the code. If they aren't then need to switch to
  964. another type of cache that includes the caller frame information in the hashing function.
  965. """
  966. logger_obj.warning(*args, **kwargs)
  967. def safe_grad_filter(message, category, filename, lineno, file=None, line=None) -> bool:
  968. return "The .grad attribute of a Tensor" not in str(message)
  969. def user_warning_filter(
  970. message, category, filename, lineno, file=None, line=None
  971. ) -> bool:
  972. return not category == UserWarning
  973. @contextlib.contextmanager
  974. def hide_warnings(filter_fn=lambda *args, **kwargs: True):
  975. """
  976. A context manager that temporarily suppresses warnings,
  977. using public API: https://docs.python.org/3/library/warnings.html#warnings.showwarning.
  978. Useful to hide warnings without mutating warnings module state, see:
  979. https://github.com/pytorch/pytorch/issues/128427#issuecomment-2161496162.
  980. NOTE: Warnings issued under this context will still be cached in the __warningregistry__
  981. and count towards the once/default rule. So you should NEVER use this on a user-land function.
  982. Filter must implement the showwarning API:
  983. def filter_fn(message, category, filename, lineno, file=None, line=None) -> bool:
  984. return True # show this warning entry
  985. """
  986. prior = warnings.showwarning
  987. def _showwarning(*args, **kwargs):
  988. if filter_fn(*args, **kwargs):
  989. prior(*args, **kwargs)
  990. try:
  991. warnings.showwarning = _showwarning
  992. yield
  993. finally:
  994. warnings.showwarning = prior
  995. class LazyString(Generic[_P]):
  996. def __init__(
  997. self, func: Callable[_P, str], *args: _P.args, **kwargs: _P.kwargs
  998. ) -> None:
  999. self.func = func
  1000. self.args = args
  1001. self.kwargs = kwargs
  1002. def __str__(self) -> str:
  1003. return self.func(*self.args, **self.kwargs)
# Logs the time it takes to do structured logging by frame/compile id
# key is always {frame_id}_{frame_compile_id}
# Values are accumulated durations; trace_structured reports them in seconds.
# defaultdict(float) lets add_structured_logging_overhead use += on new keys.
structured_logging_overhead: dict[str, float] = defaultdict(float)
  1007. def add_structured_logging_overhead(time_spent: float) -> None:
  1008. global structured_logging_overhead
  1009. key = None
  1010. if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None:
  1011. frame_id = trace_id.compile_id.frame_id
  1012. frame_compile_id = trace_id.compile_id.frame_compile_id
  1013. # Why not trace_id.attempt, like structured logging?
  1014. # We aggregate across all attempts because
  1015. # a compilation metric is logged per successful attempt
  1016. key = f"{frame_id}_{frame_compile_id}"
  1017. # TODO: deal with structured logging that occurs outside of specific compile ids
  1018. # It's hard to figure out where we would log that if we want it in compilation metrics
  1019. # itself.
  1020. if key is not None:
  1021. key = str(key)
  1022. structured_logging_overhead[key] += time_spent
  1023. def get_structured_logging_overhead() -> Optional[float]:
  1024. key = None
  1025. if (trace_id := torch._guards.CompileContext.current_trace_id()) is not None:
  1026. frame_id = trace_id.compile_id.frame_id
  1027. frame_compile_id = trace_id.compile_id.frame_compile_id
  1028. key = f"{frame_id}_{frame_compile_id}"
  1029. if key is not None:
  1030. return structured_logging_overhead.get(key)
  1031. else:
  1032. return None
  1033. def trace_structured_artifact(
  1034. name: str, # this will go in metadata
  1035. encoding: str,
  1036. payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
  1037. compile_id: Optional[CompileId] = None,
  1038. ) -> None:
  1039. trace_structured(
  1040. "artifact",
  1041. metadata_fn=lambda: {
  1042. "name": name,
  1043. "encoding": encoding,
  1044. },
  1045. payload_fn=payload_fn,
  1046. compile_id=compile_id,
  1047. )
  1048. def trace_structured(
  1049. name: str,
  1050. # NB: metadata expected to be dict so adding more info is forward compatible
  1051. # Tuple[str, int] is a special case for string interning
  1052. metadata_fn: Callable[[], Union[dict[str, Any], tuple[str, int]]] = dict,
  1053. *,
  1054. payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
  1055. suppress_context: bool = False,
  1056. expect_trace_id: bool = True, # Whether or not we expect to have a current trace id
  1057. record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging
  1058. compile_id: Optional[CompileId] = None, # Optional if unavailable in the trace
  1059. ) -> None:
  1060. """
  1061. metadata is an arbitrary JSON compatible struct, but it's expected to not be
  1062. too long (e.g., less than 1MB)
  1063. payload is an arbitrary string, which can be arbitrarily long (but expected to have
  1064. newlines so no lines are too long)
  1065. """
  1066. assert name not in [
  1067. "rank",
  1068. "compiled_autograd_id",
  1069. "frame_id",
  1070. "frame_compile_id",
  1071. "attempt",
  1072. "severity",
  1073. "timestamp",
  1074. "pathname",
  1075. "thread",
  1076. ]
  1077. assert callable(metadata_fn), (
  1078. f"metadata_fn should be callable, but got {type(metadata_fn)}"
  1079. )
  1080. assert callable(payload_fn), (
  1081. f"payload_fn should be callable, but got {type(payload_fn)}"
  1082. )
  1083. # trace_log never propagates and is ALWAYS DEBUG, so also check that there
  1084. # are handlers instead of checking the log level
  1085. if trace_log.handlers:
  1086. start_time = time.time_ns()
  1087. record: dict[str, object] = {}
  1088. record[name] = metadata_fn()
  1089. if not suppress_context:
  1090. # TODO: Actually, the rank probably should just be emitted once at
  1091. # the top, and not repeatedly spammed in all the logs, since it
  1092. # never changes and we assume no interleaving
  1093. if dist.is_available() and dist.is_initialized():
  1094. record["rank"] = dist.get_rank()
  1095. trace_id = torch._guards.CompileContext.current_trace_id()
  1096. if expect_trace_id and trace_id is None and compile_id is None:
  1097. # Record the stack of the log call to better diagnose why we
  1098. # don't have a frame id for it
  1099. record["stack"] = torch._logging.structured.from_traceback(
  1100. CapturedTraceback.extract(skip=1).summary()
  1101. )
  1102. else:
  1103. cid = trace_id.compile_id if trace_id else compile_id
  1104. if cid is not None:
  1105. if cid.compiled_autograd_id is not None:
  1106. record["compiled_autograd_id"] = cid.compiled_autograd_id
  1107. if cid.frame_id is not None:
  1108. record["frame_id"] = cid.frame_id
  1109. if cid.frame_compile_id is not None:
  1110. record["frame_compile_id"] = cid.frame_compile_id
  1111. if trace_id:
  1112. record["attempt"] = trace_id.attempt
  1113. payload = payload_fn()
  1114. if payload is not None:
  1115. if not isinstance(payload, str):
  1116. if isinstance(payload, list):
  1117. # special case to look better
  1118. payload = "[\n" + ",\n".join(json.dumps(i) for i in payload) + "\n]"
  1119. else:
  1120. def json_default(obj):
  1121. # Sets aren't json serializable
  1122. if isinstance(obj, set):
  1123. return list(obj)
  1124. raise TypeError(
  1125. f"Object of type {type(obj)} is not JSON serializable"
  1126. )
  1127. # force newlines so we are unlikely to overflow line limit
  1128. payload = json.dumps(payload, default=json_default, indent=0)
  1129. h = hashlib.md5(usedforsecurity=False)
  1130. h.update(payload.encode("utf-8"))
  1131. record["has_payload"] = h.hexdigest()
  1132. trace_log.debug(
  1133. "", extra={"metadata": record, "payload": payload}, stacklevel=2
  1134. )
  1135. log_trace_structured_event(name, record)
  1136. if record_logging_overhead:
  1137. # Convert to seconds from nanoseconds, add it to the frame compile total
  1138. structured_logging_overhead_s = (time.time_ns() - start_time) / 1e9
  1139. add_structured_logging_overhead(structured_logging_overhead_s)
  1140. def dtrace_structured(
  1141. name: str,
  1142. # NB: metadata expected to be dict so adding more info is forward compatible
  1143. # Tuple[str, int] is a special case for string interning
  1144. metadata_fn: Callable[[], Union[dict[str, Any], tuple[str, int]]] = dict,
  1145. *,
  1146. payload_fn: Callable[[], Optional[Union[str, object]]] = lambda: None,
  1147. suppress_context: bool = False,
  1148. expect_trace_id: bool = False, # Whether or not we expect to have a current trace id
  1149. record_logging_overhead: bool = True, # Whether or not to record the time spent on structured logging
  1150. ) -> None:
  1151. """
  1152. For logging more detailed information used for debugging. This may result in
  1153. the program becoming slow.
  1154. """
  1155. if GET_DTRACE_STRUCTURED:
  1156. trace_structured(
  1157. name,
  1158. metadata_fn,
  1159. payload_fn=payload_fn,
  1160. suppress_context=suppress_context,
  1161. expect_trace_id=expect_trace_id,
  1162. record_logging_overhead=record_logging_overhead,
  1163. )
  1164. import torch._guards
  1165. import torch._utils_internal
  1166. import torch.distributed as dist