builder.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import inspect
  3. from typing import Iterable, Union
  4. import torch
  5. from modelscope.utils.config import ConfigDict
  6. from modelscope.utils.registry import Registry, build_from_cfg, default_group
# Registry of optimizer classes. `register_torch_optimizers` fills it with the
# torch.optim optimizers, and `build_optimizer` resolves configs against it
# through `build_from_cfg`.
OPTIMIZERS = Registry('optimizer')
  8. def build_optimizer(model: Union[torch.nn.Module,
  9. Iterable[torch.nn.parameter.Parameter]],
  10. cfg: ConfigDict,
  11. default_args: dict = None):
  12. """ build optimizer from optimizer config dict
  13. Args:
  14. model: A torch.nn.Module or an iterable of parameters.
  15. cfg (:obj:`ConfigDict`): config dict for optimizer object.
  16. default_args (dict, optional): Default initialization arguments.
  17. """
  18. if default_args is None:
  19. default_args = {}
  20. if isinstance(model, torch.nn.Module) or (hasattr(
  21. model, 'module') and isinstance(model.module, torch.nn.Module)):
  22. if hasattr(model, 'module'):
  23. model = model.module
  24. default_args['params'] = model.parameters()
  25. else:
  26. # Input is a iterable of parameters, this case fits for the scenario of user-defined parameter groups.
  27. default_args['params'] = model
  28. return build_from_cfg(
  29. cfg, OPTIMIZERS, group_key=default_group, default_args=default_args)
  30. def register_torch_optimizers():
  31. for name, module in inspect.getmembers(torch.optim):
  32. if name.startswith('__'):
  33. continue
  34. if inspect.isclass(module) and issubclass(module,
  35. torch.optim.Optimizer):
  36. OPTIMIZERS.register_module(
  37. default_group, module_name=name, module_cls=module)
# Import-time side effect: fill OPTIMIZERS with torch.optim's optimizer
# classes so that configs can reference them by class name.
register_torch_optimizers()