base.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import logging
  3. from torch.nn.utils import clip_grad
  4. from modelscope.metainfo import Hooks
  5. from modelscope.outputs import OutputKeys
  6. from modelscope.trainers.hooks.builder import HOOKS
  7. from modelscope.trainers.hooks.hook import Hook
  8. from modelscope.trainers.hooks.priority import Priority
  9. class OptimizerProcessor:
  10. def initialize_optimizer(self, trainer):
  11. """Initialize the optimizer.
  12. This is a strategic function which can be registered by other hook's function.
  13. """
  14. trainer.optimizer.zero_grad()
  15. def before_forward(self, trainer):
  16. pass
  17. def backward(self, trainer, loss_keys, cumulative_iters, grad_clip):
  18. """Do module backward, optimizer's step and zero_grad and clip the grads.
  19. This is a strategic function which can be registered by other hook's function.
  20. Args:
  21. trainer(`EpochBasedTrainer`): The trainer instance.
  22. loss_keys(`list`): The list of loss keys.
  23. cumulative_iters(`int`): The cumulative iters for gradients.
  24. grad_clip(`dict`): The grad clipping options.
  25. """
  26. for k in loss_keys:
  27. trainer.train_outputs[k] /= cumulative_iters
  28. trainer.train_outputs[k].backward()
  29. if Hook.every_n_iters(trainer, cumulative_iters):
  30. if grad_clip is not None:
  31. self.clip_grads(trainer.model.parameters(), **grad_clip)
  32. trainer.optimizer.step()
  33. trainer.optimizer.zero_grad()
  34. @staticmethod
  35. def clip_grads(params, **clip_args):
  36. params = list(
  37. filter(lambda p: p.requires_grad and p.grad is not None, params))
  38. if len(params) > 0:
  39. return clip_grad.clip_grad_norm_(params, **clip_args)
  40. @HOOKS.register_module(module_name=Hooks.OptimizerHook)
  41. class OptimizerHook(Hook):
  42. """Optimizer hook
  43. Args:
  44. cumulative_iters (int): interval of gradients accumulation. Default: 1
  45. grad_clip (dict): Default None. Containing keys:
  46. max_norm (float or int): max norm of the gradients
  47. norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
  48. More details please refer to `torch.nn.utils.clip_grad.clip_grad_norm_`
  49. loss_keys (str | list): keys list of loss
  50. """
  51. PRIORITY = Priority.ABOVE_NORMAL
  52. def __init__(self,
  53. cumulative_iters=1,
  54. grad_clip=None,
  55. loss_keys=OutputKeys.LOSS,
  56. **kwargs) -> None:
  57. if isinstance(loss_keys, str):
  58. loss_keys = [loss_keys]
  59. assert isinstance(loss_keys, (tuple, list))
  60. self.loss_keys = loss_keys
  61. self.cumulative_iters = cumulative_iters
  62. self.grad_clip = grad_clip
  63. self.processor = OptimizerProcessor()
  64. def set_processor(self, processor):
  65. self.processor = processor
  66. def before_run(self, trainer):
  67. trainer.cumulative_iters = self.cumulative_iters
  68. self.processor.initialize_optimizer(trainer)
  69. def before_train_iter(self, trainer):
  70. self.processor.before_forward(trainer)
  71. def after_train_iter(self, trainer):
  72. self.processor.backward(trainer, self.loss_keys, self.cumulative_iters,
  73. self.grad_clip)
  74. @HOOKS.register_module(module_name=Hooks.NoneOptimizerHook)
  75. class NoneOptimizerHook(OptimizerHook):
  76. def __init__(self, cumulative_iters=1, grad_clip=None, loss_keys='loss'):
  77. super(NoneOptimizerHook, self).__init__(
  78. grad_clip=grad_clip, loss_keys=loss_keys)
  79. self.cumulative_iters = cumulative_iters
  80. def before_run(self, trainer):
  81. return
  82. def after_train_iter(self, trainer):
  83. return