| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import logging
- from torch.nn.utils import clip_grad
- from modelscope.metainfo import Hooks
- from modelscope.outputs import OutputKeys
- from modelscope.trainers.hooks.builder import HOOKS
- from modelscope.trainers.hooks.hook import Hook
- from modelscope.trainers.hooks.priority import Priority
class OptimizerProcessor:
    """Default strategy for the optimizer-related work of ``OptimizerHook``.

    Each public method is a strategic function; another hook can replace the
    whole strategy by installing its own processor object on the hook.
    """

    def initialize_optimizer(self, trainer):
        """Initialize the optimizer by clearing any existing gradients.

        This is a strategic function which can be registered by other hook's function.

        Args:
            trainer(`EpochBasedTrainer`): The trainer instance.
        """
        trainer.optimizer.zero_grad()

    def before_forward(self, trainer):
        """Called before each forward pass; does nothing by default.

        Args:
            trainer(`EpochBasedTrainer`): The trainer instance.
        """
        pass

    def backward(self, trainer, loss_keys, cumulative_iters, grad_clip):
        """Do module backward, optimizer's step and zero_grad and clip the grads.

        This is a strategic function which can be registered by other hook's function.

        Every call backpropagates the scaled losses. Only every
        ``cumulative_iters`` iterations are the gradients optionally clipped,
        then the optimizer stepped and zeroed.

        Args:
            trainer(`EpochBasedTrainer`): The trainer instance.
            loss_keys(`list`): The list of loss keys.
            cumulative_iters(`int`): The cumulative iters for gradients.
            grad_clip(`dict`): The grad clipping options.
        """
        self._scale_and_backward(trainer.train_outputs, loss_keys,
                                 cumulative_iters)

        if Hook.every_n_iters(trainer, cumulative_iters):
            if grad_clip is not None:
                self.clip_grads(trainer.model.parameters(), **grad_clip)

            trainer.optimizer.step()
            trainer.optimizer.zero_grad()

    @staticmethod
    def _scale_and_backward(outputs, loss_keys, cumulative_iters):
        """Scale each loss by ``1 / cumulative_iters`` and backpropagate it.

        The scaled loss replaces the original entry in ``outputs``.

        Args:
            outputs(`dict`): Mapping holding the loss tensors.
            loss_keys(`list`): Keys of the losses to backpropagate, in order.
            cumulative_iters(`int`): The cumulative iters for gradients.
        """
        last_idx = len(loss_keys) - 1
        for idx, key in enumerate(loss_keys):
            # Divide out of place: an in-place ``/=`` bumps the tensor's
            # version counter and makes backward fail when the loss's last op
            # saved its own output for the gradient (e.g. ``exp``).
            loss = outputs[key] / cumulative_iters
            outputs[key] = loss
            # Several losses computed from one forward pass share part of the
            # autograd graph; keep it alive until the final backward call.
            loss.backward(retain_graph=idx < last_idx)

    @staticmethod
    def clip_grads(params, **clip_args):
        """Clip the total gradient norm of parameters that have gradients.

        Args:
            params: Iterable of parameters, e.g. ``model.parameters()``.
            **clip_args: Forwarded to ``clip_grad.clip_grad_norm_``
                (e.g. ``max_norm``, ``norm_type``).

        Returns:
            The total norm returned by ``clip_grad_norm_``, or ``None`` when
            no parameter requires grad and has a gradient.
        """
        params = [p for p in params if p.requires_grad and p.grad is not None]
        if params:
            return clip_grad.clip_grad_norm_(params, **clip_args)
        return None
@HOOKS.register_module(module_name=Hooks.OptimizerHook)
class OptimizerHook(Hook):
    """Optimizer hook.

    Delegates the optimizer work of each training iteration to a processor
    object (``OptimizerProcessor`` by default), which can be replaced through
    ``set_processor``.

    Args:
        cumulative_iters (int): interval of gradients accumulation, must be
            >= 1. Default: 1
        grad_clip (dict): Default None. Containing keys:
            max_norm (float or int): max norm of the gradients
            norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for infinity norm.
            More details please refer to `torch.nn.utils.clip_grad.clip_grad_norm_`
        loss_keys (str | list | tuple): key(s) of the loss(es); a single
            str is wrapped into a one-element list.
        **kwargs: Extra keyword arguments are accepted and ignored.

    Raises:
        TypeError: If ``loss_keys`` is not a str, list or tuple.
        ValueError: If ``cumulative_iters`` is smaller than 1.
    """

    PRIORITY = Priority.ABOVE_NORMAL

    def __init__(self,
                 cumulative_iters=1,
                 grad_clip=None,
                 loss_keys=OutputKeys.LOSS,
                 **kwargs) -> None:
        if isinstance(loss_keys, str):
            loss_keys = [loss_keys]
        # Explicit checks instead of ``assert`` so validation survives ``-O``.
        if not isinstance(loss_keys, (tuple, list)):
            raise TypeError('loss_keys should be a str, list or tuple, '
                            f'got {type(loss_keys).__name__}')
        # cumulative_iters divides the loss and sets the step interval, so
        # zero or negative values are meaningless.
        if cumulative_iters < 1:
            raise ValueError(
                f'cumulative_iters should be >= 1, got {cumulative_iters}')
        self.loss_keys = loss_keys
        self.cumulative_iters = cumulative_iters
        self.grad_clip = grad_clip
        self.processor = OptimizerProcessor()

    def set_processor(self, processor):
        """Replace the processor that performs the optimizer steps."""
        self.processor = processor

    def before_run(self, trainer):
        """Publish ``cumulative_iters`` on the trainer and initialize the optimizer."""
        trainer.cumulative_iters = self.cumulative_iters
        self.processor.initialize_optimizer(trainer)

    def before_train_iter(self, trainer):
        """Let the processor prepare before the forward pass."""
        self.processor.before_forward(trainer)

    def after_train_iter(self, trainer):
        """Run backward, and optionally clip/step/zero_grad, via the processor."""
        self.processor.backward(trainer, self.loss_keys, self.cumulative_iters,
                                self.grad_clip)
@HOOKS.register_module(module_name=Hooks.NoneOptimizerHook)
class NoneOptimizerHook(OptimizerHook):
    """An ``OptimizerHook`` variant that leaves the optimizer untouched.

    ``before_run`` and ``after_train_iter`` do nothing. As a result, this hook
    neither initializes the optimizer nor sets ``trainer.cumulative_iters``
    at start-up. It also runs no backward or optimizer step after an
    iteration. ``before_train_iter`` is inherited unchanged.

    Args:
        cumulative_iters (int): Stored on the hook. Default: 1
        grad_clip (dict): Stored on the hook; unused by this class's no-op
            methods. Default: None
        loss_keys (str | list): Key(s) of the loss. Default: ``'loss'``
    """

    def __init__(self, cumulative_iters=1, grad_clip=None, loss_keys='loss'):
        super().__init__(
            cumulative_iters=cumulative_iters,
            grad_clip=grad_clip,
            loss_keys=loss_keys,
        )

    def before_run(self, trainer):
        """Skip optimizer initialization; does nothing."""

    def after_train_iter(self, trainer):
        """Skip backward and optimizer step; does nothing."""
|