logger.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118
  1. # mypy: allow-untyped-defs
  2. import functools
  3. import logging
  4. import time
  5. from typing import Any, Callable, TypeVar
  6. from typing_extensions import ParamSpec
  7. from uuid import uuid4
  8. import torch.distributed.c10d_logger as c10d_logger
  9. from torch.distributed.checkpoint.logging_handlers import DCP_LOGGER_NAME
  10. logger = logging.getLogger()
  11. __all__: list[str] = []
  12. global _dcp_logger
  13. _dcp_logger = c10d_logger._get_or_create_logger(DCP_LOGGER_NAME)
  14. _T = TypeVar("_T")
  15. _P = ParamSpec("_P")
  16. def _msg_dict_from_dcp_method_args(*args, **kwargs) -> dict[str, Any]:
  17. """
  18. Extracts log data from dcp method args
  19. """
  20. msg_dict = {}
  21. # checkpoint ID can be passed in through the serializer or through the checkpoint id directly
  22. storage_writer = kwargs.get("storage_writer", None)
  23. storage_reader = kwargs.get("storage_reader", None)
  24. planner = kwargs.get("planner", None)
  25. checkpoint_id = kwargs.get("checkpoint_id", None)
  26. if not checkpoint_id and (serializer := storage_writer or storage_reader):
  27. checkpoint_id = getattr(serializer, "checkpoint_id", None)
  28. msg_dict["checkpoint_id"] = (
  29. str(checkpoint_id) if checkpoint_id is not None else checkpoint_id
  30. )
  31. # Uniquely identify a _dcp_method_logger wrapped function call.
  32. msg_dict["uuid"] = str(uuid4().int)
  33. if storage_writer:
  34. msg_dict["storage_writer"] = storage_writer.__class__.__name__
  35. if storage_reader:
  36. msg_dict["storage_reader"] = storage_reader.__class__.__name__
  37. if planner:
  38. msg_dict["planner"] = planner.__class__.__name__
  39. return msg_dict
  40. def _get_msg_dict(func_name, *args, **kwargs) -> dict[str, Any]:
  41. msg_dict = _msg_dict_from_dcp_method_args(*args, **kwargs)
  42. msg_dict.update(c10d_logger._get_msg_dict(func_name, *args, **kwargs))
  43. return msg_dict
  44. def _dcp_method_logger(
  45. log_exceptions: bool = False, **wrapper_kwargs: Any
  46. ) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]: # pyre-ignore
  47. """This method decorator logs the start, end, and exception of wrapped events."""
  48. def decorator(func: Callable[_P, _T]):
  49. @functools.wraps(func)
  50. def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _T:
  51. msg_dict = _get_msg_dict(
  52. func.__name__, *args, **{**wrapper_kwargs, **kwargs}
  53. )
  54. # log start event
  55. msg_dict["event"] = "start"
  56. t0 = time.time_ns()
  57. msg_dict["time"] = t0
  58. msg_dict["log_exceptions"] = log_exceptions
  59. _dcp_logger.debug(msg_dict)
  60. # exceptions
  61. try:
  62. result = func(*args, **kwargs)
  63. except BaseException as error:
  64. if log_exceptions:
  65. msg_dict["event"] = "exception"
  66. msg_dict["error"] = f"{error}"
  67. msg_dict["time"] = time.time_ns()
  68. _dcp_logger.error(msg_dict)
  69. raise
  70. # end event
  71. msg_dict["event"] = "end"
  72. t1 = time.time_ns()
  73. msg_dict["time"] = time.time_ns()
  74. msg_dict["times_spent"] = t1 - t0
  75. _dcp_logger.debug(msg_dict)
  76. return result
  77. return wrapper
  78. return decorator
  79. def _init_logger(rank: int):
  80. logger.setLevel(logging.INFO)
  81. ch = logging.StreamHandler()
  82. ch.setLevel(logging.INFO)
  83. formatter = logging.Formatter(
  84. f"[{rank}] %(asctime)s - %(name)s - %(levelname)s - %(message)s"
  85. )
  86. ch.setFormatter(formatter)
  87. logger.addHandler(ch)