| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126 |
- # Copyright 2022 The HuggingFace Team. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import functools
- import logging
- import os
- from typing import Optional
- from .state import PartialState
class MultiProcessAdapter(logging.LoggerAdapter):
    """
    An adapter to assist with logging in multiprocess.

    `log` takes in an additional `main_process_only` kwarg, which dictates whether it should be called on all processes
    or only the main executed one. Default is `main_process_only=True`.

    Does not require an `Accelerator` object to be created first, but the accelerate state must have been initialized
    (through either `PartialState()` or `Accelerator()`) before anything is logged.
    """

    @staticmethod
    def _should_log(main_process_only):
        "Check if log should be performed"
        state = PartialState()
        # Log everywhere when not restricted to the main process, otherwise only on the main process.
        return not main_process_only or state.is_main_process

    def log(self, level, msg, *args, **kwargs):
        """
        Delegates logger call after checking if we should log.

        Accepts a new kwarg of `main_process_only`, which will dictate whether it will be logged across all processes
        or only the main executed one. Default is `True` if not passed

        Also accepts "in_order", which if `True` makes the processes log one by one, in order. This is much easier to
        read, but comes at the cost of sometimes needing to wait for the other processes. Default is `False` to not
        break with the previous behavior.

        `in_order` only takes effect together with `main_process_only=False`; it is ignored when the record is
        restricted to the main process.

        Raises:
            RuntimeError: If the accelerate state has not been initialized yet.
        """
        if PartialState._shared_state == {}:
            raise RuntimeError(
                "You must initialize the accelerate state by calling either `PartialState()` or `Accelerator()` before using the logging utility."
            )
        main_process_only = kwargs.pop("main_process_only", True)
        in_order = kwargs.pop("in_order", False)
        # set `stacklevel` to exclude ourself in `Logger.findCaller()` while respecting user's choice
        kwargs.setdefault("stacklevel", 2)

        if not self.isEnabledFor(level):
            return

        if in_order and not main_process_only:
            # Ordered logging must be entered by *every* process: each one calls `wait_for_everyone()` once per
            # rank, so the number of calls has to match across processes (presumably a barrier -- a process that
            # skipped this branch would leave the others waiting). Hence this check comes before `_should_log`.
            state = PartialState()
            for rank in range(state.num_processes):
                if rank == state.process_index:
                    msg, kwargs = self.process(msg, kwargs)
                    self.logger.log(level, msg, *args, **kwargs)
                state.wait_for_everyone()
        elif self._should_log(main_process_only):
            msg, kwargs = self.process(msg, kwargs)
            self.logger.log(level, msg, *args, **kwargs)

    # NOTE: the cache key includes `self`, so the cache holds a reference to every adapter that ever warned.
    @functools.lru_cache(None)
    def warning_once(self, *args, **kwargs):
        """
        This method is identical to `logger.warning()`, but will emit the warning with the same message only once

        Note: The cache is for the function arguments, so 2 different callers using the same arguments will hit the
        cache. The assumption here is that all warning messages are unique across the code. If they aren't then need to
        switch to another type of cache that includes the caller frame information in the hashing function.
        """
        self.warning(*args, **kwargs)
def get_logger(name: str, log_level: Optional[str] = None):
    """
    Returns a `logging.Logger` for `name` that can handle multiprocessing.

    If a log should be called on all processes, pass `main_process_only=False` If a log should be called on all
    processes and in order, also pass `in_order=True`

    Args:
        name (`str`):
            The name for the logger, such as `__file__`
        log_level (`str`, *optional*):
            The log level to use. If not passed, falls back to the `ACCELERATE_LOG_LEVEL` environment variable. If
            that is unset as well, no level is changed. When a level is resolved it is upper-cased and applied to
            both the named logger and the root logger.

    Returns:
        `MultiProcessAdapter`: An adapter wrapping `logging.getLogger(name)`.

    Example:

    ```python
    >>> from accelerate.logging import get_logger
    >>> from accelerate import Accelerator

    >>> logger = get_logger(__name__)

    >>> accelerator = Accelerator()
    >>> logger.info("My log", main_process_only=False)
    >>> logger.debug("My log", main_process_only=True)

    >>> logger = get_logger(__name__, log_level="DEBUG")
    >>> logger.info("My log")
    >>> logger.debug("My second log")

    >>> array = ["a", "b", "c", "d"]
    >>> letter_at_rank = array[accelerator.process_index]
    >>> logger.info(letter_at_rank, main_process_only=False, in_order=True)
    ```
    """
    # An explicit argument wins; otherwise consult the environment (which may also yield None).
    requested_level = log_level if log_level is not None else os.environ.get("ACCELERATE_LOG_LEVEL", None)
    base_logger = logging.getLogger(name)
    if requested_level is not None:
        level_name = requested_level.upper()
        base_logger.setLevel(level_name)
        # The root logger is adjusted too, not just the named one.
        base_logger.root.setLevel(level_name)
    return MultiProcessAdapter(base_logger, {})
|