hook.py 6.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. # Copyright (c) Alibaba, Inc. and its affiliates.
  3. from functools import wraps
  4. from modelscope.utils.constant import TrainerStages
  5. from modelscope.utils.import_utils import is_method_overridden
  6. from .priority import Priority
  7. class Hook:
  8. """
  9. The Hook base class of any modelscope trainer. You can build your own hook inherited from this class.
  10. """
  11. stages = (TrainerStages.after_init, TrainerStages.before_run,
  12. TrainerStages.before_val, TrainerStages.before_train_epoch,
  13. TrainerStages.before_train_iter, TrainerStages.after_train_iter,
  14. TrainerStages.after_train_epoch, TrainerStages.before_val_epoch,
  15. TrainerStages.before_val_iter, TrainerStages.after_val_iter,
  16. TrainerStages.after_val_epoch, TrainerStages.after_run,
  17. TrainerStages.after_val)
  18. PRIORITY = Priority.NORMAL
  19. def after_init(self, trainer):
  20. """
  21. Will be called at the end of the trainer's `__init__` method
  22. """
  23. pass
  24. def before_run(self, trainer):
  25. """
  26. Will be called before trainer loop begins.
  27. Args:
  28. trainer: The trainer instance.
  29. Returns: None
  30. """
  31. pass
  32. def after_run(self, trainer):
  33. """
  34. Will be called after trainer loop end.
  35. Args:
  36. trainer: The trainer instance.
  37. Returns: None
  38. """
  39. pass
  40. def before_val(self, trainer):
  41. """
  42. Will be called before eval loop begins.
  43. Args:
  44. trainer: The trainer instance.
  45. Returns: None
  46. """
  47. pass
  48. def after_val(self, trainer):
  49. """
  50. Will be called after eval loop end.
  51. Args:
  52. trainer: The trainer instance.
  53. Returns: None
  54. """
  55. pass
  56. def before_epoch(self, trainer):
  57. """
  58. Will be called before every epoch begins.
  59. Args:
  60. trainer: The trainer instance.
  61. Returns: None
  62. """
  63. pass
  64. def after_epoch(self, trainer):
  65. """
  66. Will be called after every epoch ends.
  67. Args:
  68. trainer: The trainer instance.
  69. Returns: None
  70. """
  71. pass
  72. def before_iter(self, trainer):
  73. """
  74. Will be called before every loop begins.
  75. Args:
  76. trainer: The trainer instance.
  77. Returns: None
  78. """
  79. pass
  80. def after_iter(self, trainer):
  81. """
  82. Will be called after every loop ends.
  83. Args:
  84. trainer: The trainer instance.
  85. Returns: None
  86. """
  87. pass
  88. def before_train_epoch(self, trainer):
  89. """
  90. Will be called before every train epoch begins. Default call ``self.before_epoch``
  91. Args:
  92. trainer: The trainer instance.
  93. Returns: None
  94. """
  95. self.before_epoch(trainer)
  96. def before_val_epoch(self, trainer):
  97. """
  98. Will be called before every validation epoch begins. Default call ``self.before_epoch``
  99. Args:
  100. trainer: The trainer instance.
  101. Returns: None
  102. """
  103. self.before_epoch(trainer)
  104. def after_train_epoch(self, trainer):
  105. """
  106. Will be called after every train epoch ends. Default call ``self.after_epoch``
  107. Args:
  108. trainer: The trainer instance.
  109. Returns: None
  110. """
  111. self.after_epoch(trainer)
  112. def after_val_epoch(self, trainer):
  113. """
  114. Will be called after every validation epoch ends. Default call ``self.after_epoch``
  115. Args:
  116. trainer: The trainer instance.
  117. Returns: None
  118. """
  119. self.after_epoch(trainer)
  120. def before_train_iter(self, trainer):
  121. """
  122. Will be called before every train loop begins. Default call ``self.before_iter``
  123. Args:
  124. trainer: The trainer instance.
  125. Returns: None
  126. """
  127. self.before_iter(trainer)
  128. def before_val_iter(self, trainer):
  129. """
  130. Will be called before every validation loop begins. Default call ``self.before_iter``
  131. Args:
  132. trainer: The trainer instance.
  133. Returns: None
  134. """
  135. self.before_iter(trainer)
  136. def after_train_iter(self, trainer):
  137. """
  138. Will be called after every train loop ends. Default call ``self.after_iter``
  139. Args:
  140. trainer: The trainer instance.
  141. Returns: None
  142. """
  143. self.after_iter(trainer)
  144. def after_val_iter(self, trainer):
  145. """
  146. Will be called after every validation loop ends. Default call ``self.after_iter``
  147. Args:
  148. trainer: The trainer instance.
  149. Returns: None
  150. """
  151. self.after_iter(trainer)
  152. @staticmethod
  153. def every_n_epochs(trainer, n):
  154. """
  155. Whether to reach every ``n`` epochs
  156. Returns: bool
  157. """
  158. return (trainer.epoch + 1) % n == 0 if n > 0 else False
  159. @staticmethod
  160. def every_n_inner_iters(runner, n):
  161. """
  162. Whether to reach every ``n`` iterations at every epoch
  163. Returns: bool
  164. """
  165. return (runner.inner_iter + 1) % n == 0 if n > 0 else False
  166. @staticmethod
  167. def every_n_iters(trainer, n):
  168. """
  169. Whether to reach every ``n`` iterations
  170. Returns: bool
  171. """
  172. return (trainer.iter + 1) % n == 0 if n > 0 else False
  173. @staticmethod
  174. def end_of_epoch(trainer):
  175. """
  176. Whether to reach the end of every epoch
  177. Returns: bool
  178. """
  179. return trainer.inner_iter + 1 == trainer.iters_per_epoch
  180. @staticmethod
  181. def is_last_epoch(trainer):
  182. """
  183. Whether to reach the last epoch
  184. Returns: bool
  185. """
  186. return trainer.epoch + 1 == trainer.max_epochs
  187. @staticmethod
  188. def is_last_iter(trainer):
  189. """
  190. Whether to reach the last iteration in the entire training process
  191. Returns: bool
  192. """
  193. return trainer.iter + 1 == trainer.max_iters
  194. def get_triggered_stages(self):
  195. trigger_stages = set()
  196. for stage in Hook.stages:
  197. if is_method_overridden(stage, Hook, self):
  198. trigger_stages.add(stage)
  199. return [stage for stage in Hook.stages if stage in trigger_stages]
  200. def state_dict(self):
  201. return {}
  202. def load_state_dict(self, state_dict):
  203. pass