| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129 |
- # Copyright (c) OpenMMLab. All rights reserved.
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import numbers
- from abc import ABCMeta, abstractmethod
- import numpy as np
- import torch
- from modelscope.trainers.hooks.hook import Hook
- from modelscope.trainers.hooks.priority import Priority
- from modelscope.utils.constant import ModeKeys
class LoggerHook(Hook):
    """Base class for logger hooks.

    Subclasses implement :meth:`log`. This base class decides when the
    trainer's log buffer is averaged and when :meth:`log` is invoked.

    Args:
        interval (int): Logging interval (every k iterations). It is counted
            in iterations even when ``by_epoch`` is true.
        ignore_last (bool): Ignore the log of last iterations in each epoch
            if less than ``interval``.
        reset_flag (bool): Whether to clear the output buffer after logging.
        by_epoch (bool): Whether an epoch-based trainer is used.
    """

    # NOTE: this is the Python 2 spelling. Under Python 3 it is an ordinary
    # class attribute, so ``@abstractmethod`` on ``log`` is not enforced at
    # instantiation time. Kept as-is so existing subclasses keep working.
    __metaclass__ = ABCMeta
    PRIORITY = Priority.VERY_LOW

    def __init__(self,
                 interval=10,
                 ignore_last=True,
                 reset_flag=False,
                 by_epoch=True):
        self.interval = interval
        self.ignore_last = ignore_last
        self.reset_flag = reset_flag
        self.by_epoch = by_epoch

    @abstractmethod
    def log(self, trainer):
        """Emit the current contents of the trainer's log buffer.

        Args:
            trainer: The trainer whose logs should be written.
        """
        pass

    @staticmethod
    def is_scalar(val, include_np=True, include_torch=True):
        """Tell whether the input variable is a scalar or not.

        Args:
            val: Input variable.
            include_np (bool): Whether to treat 0-d np.ndarray as a scalar.
            include_torch (bool): Whether to treat a torch.Tensor holding
                exactly one element (0-d, or e.g. shape ``(1,)``) as a scalar.

        Returns:
            bool: True or False.
        """
        if isinstance(val, numbers.Number):
            return True
        if include_np and isinstance(val, np.ndarray) and val.ndim == 0:
            return True
        # Use numel() instead of len(): len() raises TypeError on a 0-d
        # tensor and would also report 1 for a tensor of shape (1, k).
        if (include_torch and isinstance(val, torch.Tensor)
                and val.numel() == 1):
            return True
        return False

    def fetch_tensor(self, trainer, n=0):
        """Convert tensors in the log buffer history to numpy arrays.

        Only the latest ``n`` entries of every key are processed; with
        ``n == 0`` the slice ``[-0:]`` selects the whole history.

        Args:
            trainer: The trainer owning ``log_buffer.val_history``.
            n (int): Number of most recent values to process, 0 for all.
        """
        assert n >= 0
        history = trainer.log_buffer.val_history
        for key in history:
            recent = history[key][-n:]
            for idx, item in enumerate(recent):
                if isinstance(item, torch.Tensor):
                    recent[idx] = item.clone().detach().cpu().numpy()
            # Write the converted values back over the same tail slice.
            history[key][-n:] = recent

    def get_epoch(self, trainer):
        """Return the 1-based epoch number of the trainer.

        Args:
            trainer: The trainer to query.

        Returns:
            ``trainer.epoch + 1``.

        Raises:
            ValueError: If ``trainer.mode`` is neither ``ModeKeys.TRAIN``
                nor ``ModeKeys.EVAL``.
        """
        if trainer.mode not in [ModeKeys.TRAIN, ModeKeys.EVAL]:
            raise ValueError(
                f'trainer mode should be {ModeKeys.TRAIN} or {ModeKeys.EVAL}, '
                f'but got {trainer.mode}')
        return trainer.epoch + 1

    def get_iter(self, trainer, inner_iter=False):
        """Get the current training iteration step (1-based).

        Args:
            trainer: The trainer to query.
            inner_iter (bool): When true and ``by_epoch`` is set, return the
                iteration inside the current epoch (``trainer.inner_iter``)
                instead of ``trainer.iter``.

        Returns:
            The selected iteration counter plus one.
        """
        if self.by_epoch and inner_iter:
            return trainer.inner_iter + 1
        return trainer.iter + 1

    def before_run(self, trainer):
        """Turn on ``reset_flag`` for the last LoggerHook in ``trainer.hooks``.

        Only the final logger hook in the list is flagged, so the output
        buffer is cleared after that hook has logged.
        """
        for hook in trainer.hooks[::-1]:
            if isinstance(hook, LoggerHook):
                hook.reset_flag = True
                break

    def before_epoch(self, trainer):
        """Clear the logs of the last epoch."""
        trainer.log_buffer.clear()  # clear logs of last epoch

    def _log_and_reset(self, trainer):
        """Call ``log`` and clear the output buffer when ``reset_flag`` is set."""
        self.log(trainer)
        if self.reset_flag:
            trainer.log_buffer.clear_output()

    def after_train_iter(self, trainer):
        """Average and log the buffer when a logging point is reached.

        A logging point is every ``interval`` inner iterations (epoch based)
        or every ``interval`` global iterations (iteration based). At the end
        of an epoch the remaining values are also averaged unless
        ``ignore_last`` is set.
        """
        # every_n_inner_iters / every_n_iters / end_of_epoch are not defined
        # in this class; presumably inherited from Hook.
        if self.by_epoch:
            at_interval = self.every_n_inner_iters(trainer, self.interval)
        else:
            at_interval = self.every_n_iters(trainer, self.interval)

        # The end-of-epoch case averages over a full interval as well:
        # not precise but more stable.
        if at_interval or (self.end_of_epoch(trainer)
                           and not self.ignore_last):
            self.fetch_tensor(trainer, self.interval)
            trainer.log_buffer.average(self.interval)

        if trainer.log_buffer.ready:
            self._log_and_reset(trainer)

    def after_train_epoch(self, trainer):
        """Log whatever is ready in the buffer at the end of a train epoch."""
        if trainer.log_buffer.ready:
            self._log_and_reset(trainer)

    def after_val_epoch(self, trainer):
        """Average the whole buffer history and log it after validation."""
        self.fetch_tensor(trainer)
        trainer.log_buffer.average()
        self._log_and_reset(trainer)
|