base.py 4.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. # Copyright (c) Alibaba, Inc. and its affiliates.
  3. import numbers
  4. from abc import ABCMeta, abstractmethod
  5. import numpy as np
  6. import torch
  7. from modelscope.trainers.hooks.hook import Hook
  8. from modelscope.trainers.hooks.priority import Priority
  9. from modelscope.utils.constant import ModeKeys
  10. class LoggerHook(Hook):
  11. """Base class for logger hooks.
  12. Args:
  13. interval (int): Logging interval (every k iterations). It is interval of iterations even by_epoch is true.
  14. ignore_last (bool): Ignore the log of last iterations in each epoch
  15. if less than `interval`.
  16. reset_flag (bool): Whether to clear the output buffer after logging.
  17. by_epoch (bool): Whether EpochBasedtrainer is used.
  18. """
  19. __metaclass__ = ABCMeta
  20. PRIORITY = Priority.VERY_LOW
  21. def __init__(self,
  22. interval=10,
  23. ignore_last=True,
  24. reset_flag=False,
  25. by_epoch=True):
  26. self.interval = interval
  27. self.ignore_last = ignore_last
  28. self.reset_flag = reset_flag
  29. self.by_epoch = by_epoch
  30. @abstractmethod
  31. def log(self, trainer):
  32. pass
  33. @staticmethod
  34. def is_scalar(val, include_np=True, include_torch=True):
  35. """Tell the input variable is a scalar or not.
  36. Args:
  37. val: Input variable.
  38. include_np (bool): Whether to treat 0-d np.ndarray as a scalar.
  39. include_torch (bool): Whether to treat 0-d torch.Tensor as a scalar.
  40. Returns:
  41. bool: True or False.
  42. """
  43. if isinstance(val, numbers.Number):
  44. return True
  45. elif include_np and isinstance(val, np.ndarray) and val.ndim == 0:
  46. return True
  47. elif include_torch and isinstance(val, torch.Tensor) and len(val) == 1:
  48. return True
  49. else:
  50. return False
  51. def fetch_tensor(self, trainer, n=0):
  52. """Fetch latest n values or all values, process tensor type, convert to numpy for dump logs."""
  53. assert n >= 0
  54. for key in trainer.log_buffer.val_history:
  55. values = trainer.log_buffer.val_history[key][-n:]
  56. for i, v in enumerate(values):
  57. if isinstance(v, torch.Tensor):
  58. values[i] = v.clone().detach().cpu().numpy()
  59. trainer.log_buffer.val_history[key][-n:] = values
  60. def get_epoch(self, trainer):
  61. if trainer.mode in [ModeKeys.TRAIN, ModeKeys.EVAL]:
  62. epoch = trainer.epoch + 1
  63. else:
  64. raise ValueError(
  65. f'trainer mode should be {ModeKeys.TRAIN} or {ModeKeys.EVAL}, '
  66. f'but got {trainer.mode}')
  67. return epoch
  68. def get_iter(self, trainer, inner_iter=False):
  69. """Get the current training iteration step."""
  70. if self.by_epoch and inner_iter:
  71. current_iter = trainer.inner_iter + 1
  72. else:
  73. current_iter = trainer.iter + 1
  74. return current_iter
  75. def before_run(self, trainer):
  76. for hook in trainer.hooks[::-1]:
  77. if isinstance(hook, LoggerHook):
  78. hook.reset_flag = True
  79. break
  80. def before_epoch(self, trainer):
  81. trainer.log_buffer.clear() # clear logs of last epoch
  82. def after_train_iter(self, trainer):
  83. if self.by_epoch and self.every_n_inner_iters(trainer, self.interval):
  84. self.fetch_tensor(trainer, self.interval)
  85. trainer.log_buffer.average(self.interval)
  86. elif not self.by_epoch and self.every_n_iters(trainer, self.interval):
  87. self.fetch_tensor(trainer, self.interval)
  88. trainer.log_buffer.average(self.interval)
  89. elif self.end_of_epoch(trainer) and not self.ignore_last:
  90. # not precise but more stable
  91. self.fetch_tensor(trainer, self.interval)
  92. trainer.log_buffer.average(self.interval)
  93. if trainer.log_buffer.ready:
  94. self.log(trainer)
  95. if self.reset_flag:
  96. trainer.log_buffer.clear_output()
  97. def after_train_epoch(self, trainer):
  98. if trainer.log_buffer.ready:
  99. self.log(trainer)
  100. if self.reset_flag:
  101. trainer.log_buffer.clear_output()
  102. def after_val_epoch(self, trainer):
  103. self.fetch_tensor(trainer)
  104. trainer.log_buffer.average()
  105. self.log(trainer)
  106. if self.reset_flag:
  107. trainer.log_buffer.clear_output()