logging.py 2.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
"""Logging utilities for Dynamo and Inductor.

This module provides specialized logging functionality, including:

- Step-based logging that prepends step numbers to log messages
- Progress bar management for compilation phases
- Centralized logger management for Dynamo and Inductor components

The logging system helps track the progress of compilation phases and provides
structured logging output for debugging and monitoring.
"""
  9. import itertools
  10. import logging
  11. from typing import Any, Callable
  12. from torch.hub import _Faketqdm, tqdm
# The progress bar is disabled by default. This flag lives here rather than in
# the dynamo config because putting it there would create a circular import.
# It is read at import time below (to decide whether to create ``pbar``) and
# again on every get_step_logger() call.
disable_progress = True
  15. # Return all loggers that torchdynamo/torchinductor is responsible for
  16. def get_loggers() -> list[logging.Logger]:
  17. return [
  18. logging.getLogger("torch.fx.experimental.symbolic_shapes"),
  19. logging.getLogger("torch._dynamo"),
  20. logging.getLogger("torch._inductor"),
  21. ]
# Shared step counter used by get_step_logger(), which creates a logging
# function that logs a message with a step number prepended. Numbering starts
# at 1, and every get_step_logger() call consumes the next value.
#
# get_step_logger should be called lazily (i.e. at runtime, not at module-load
# time) so that step numbers are initialized properly. For example:
#
#     @functools.cache
#     def _step_logger():
#         return get_step_logger(logging.getLogger(...))
#
#     def fn():
#         _step_logger()(logging.INFO, "msg")
_step_counter = itertools.count(1)
  31. # Update num_steps if more phases are added: Dynamo, AOT, Backend
  32. # This is very inductor centric
  33. # _inductor.utils.has_triton() gives a circular import error here
  34. if not disable_progress:
  35. try:
  36. import triton # noqa: F401
  37. num_steps = 3
  38. except ImportError:
  39. num_steps = 2
  40. pbar = tqdm(total=num_steps, desc="torch.compile()", delay=0)
  41. def get_step_logger(logger: logging.Logger) -> Callable[..., None]:
  42. if not disable_progress:
  43. pbar.update(1)
  44. if not isinstance(pbar, _Faketqdm):
  45. pbar.set_postfix_str(f"{logger.name}")
  46. step = next(_step_counter)
  47. def log(level: int, msg: str, **kwargs: Any) -> None:
  48. if "stacklevel" not in kwargs:
  49. kwargs["stacklevel"] = 2
  50. logger.log(level, "Step %s: %s", step, msg, **kwargs)
  51. return log