logging.py 5.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import functools
  15. import logging
  16. import os
  17. from typing import Optional
  18. from .state import PartialState
  19. class MultiProcessAdapter(logging.LoggerAdapter):
  20. """
  21. An adapter to assist with logging in multiprocess.
  22. `log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes
  23. or only the main executed one. Default is `main_process_only=True`.
  24. Does not require an `Accelerator` object to be created first.
  25. """
  26. @staticmethod
  27. def _should_log(main_process_only):
  28. "Check if log should be performed"
  29. state = PartialState()
  30. return not main_process_only or (main_process_only and state.is_main_process)
  31. def log(self, level, msg, *args, **kwargs):
  32. """
  33. Delegates logger call after checking if we should log.
  34. Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes
  35. or only the main executed one. Default is `True` if not passed
  36. Also accepts "in_order", which if `True` makes the processes log one by one, in order. This is much easier to
  37. read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
  38. break with the previous behavior.
  39. `in_order` is ignored if `main_process_only` is passed.
  40. """
  41. if PartialState._shared_state == {}:
  42. raise RuntimeError(
  43. "You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
  44. )
  45. main_process_only = kwargs.pop("main_process_only", True)
  46. in_order = kwargs.pop("in_order", False)
  47. # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
  48. kwargs.setdefault("stacklevel", 2)
  49. if self.isEnabledFor(level):
  50. if self._should_log(main_process_only):
  51. msg, kwargs = self.process(msg, kwargs)
  52. self.logger.log(level, msg, *args, **kwargs)
  53. elif in_order:
  54. state = PartialState()
  55. for i in range(state.num_processes):
  56. if i == state.process_index:
  57. msg, kwargs = self.process(msg, kwargs)
  58. self.logger.log(level, msg, *args, **kwargs)
  59. state.wait_for_everyone()
  60. @functools.lru_cache(None)
  61. def warning_once(self, *args, **kwargs):
  62. """
  63. This method is identical to `logger.warning()`, but will emit the warning with the same message only once
  64. Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
  65. cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to
  66. switch to another type of cache that includes the caller frame information in the hashing function.
  67. """
  68. self.warning(*args, **kwargs)
  69. def get_logger(name: str, log_level: Optional[str] = None):
  70. """
  71. Returns a `logging.Logger` for `name` that can handle multiprocessing.
  72. If a log should be called on all processes, pass `main_process_only=False` If a log should be called on all
  73. processes and in order, also pass `in_order=True`
  74. Args:
  75. name (`str`):
  76. The name for the logger, such as `__file__`
  77. log_level (`str`, *optional*):
  78. The log level to use. If not passed, will default to the `LOG_LEVEL` environment variable, or `INFO` if not
  79. Example:
  80. ```python
  81. >>> from accelerate.logging import get_logger
  82. >>> from accelerate import Accelerator
  83. >>> logger = get_logger(__name__)
  84. >>> accelerator = Accelerator()
  85. >>> logger.info("My log", main_process_only=False)
  86. >>> logger.debug("My log", main_process_only=True)
  87. >>> logger = get_logger(__name__, log_level="DEBUG")
  88. >>> logger.info("My log")
  89. >>> logger.debug("My second log")
  90. >>> array = ["a", "b", "c", "d"]
  91. >>> letter_at_rank = array[accelerator.process_index]
  92. >>> logger.info(letter_at_rank, in_order=True)
  93. ```
  94. """
  95. if log_level is None:
  96. log_level = os.environ.get("ACCELERATE_LOG_LEVEL", None)
  97. logger = logging.getLogger(name)
  98. if log_level is not None:
  99. logger.setLevel(log_level.upper())
  100. logger.root.setLevel(log_level.upper())
  101. return MultiProcessAdapter(logger, {})