builder.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import inspect
  3. import torch
  4. from packaging import version
  5. from modelscope.utils.config import ConfigDict
  6. from modelscope.utils.registry import Registry, build_from_cfg, default_group
# Registry of lr scheduler classes. ``build_lr_scheduler`` builds schedulers
# from it via ``build_from_cfg``, and ``register_torch_lr_scheduler`` fills it
# with torch's built-in schedulers.
LR_SCHEDULER = Registry('lr_scheduler')
  8. def build_lr_scheduler(cfg: ConfigDict, default_args: dict = None):
  9. """ build lr scheduler from given lr scheduler config dict
  10. Args:
  11. cfg (:obj:`ConfigDict`): config dict for lr scheduler object.
  12. default_args (dict, optional): Default initialization arguments.
  13. """
  14. if cfg['type'].lower().endswith('warmup'):
  15. # build warmup lr scheduler
  16. if not hasattr(cfg, 'base_scheduler'):
  17. if default_args is None or ('base_scheduler' not in default_args):
  18. raise ValueError(
  19. 'Must provide ``base_scheduler`` which is an instance of ``torch.optim.lr_scheduler._LRScheduler`` '
  20. 'for build warmup lr scheduler.')
  21. else:
  22. # build lr scheduler without warmup
  23. if not hasattr(cfg, 'optimizer'):
  24. if default_args is None or ('optimizer' not in default_args):
  25. raise ValueError(
  26. 'Must provide ``optimizer`` which is an instance of ``torch.optim.Optimizer`` '
  27. 'for build lr scheduler')
  28. return build_from_cfg(
  29. cfg, LR_SCHEDULER, group_key=default_group, default_args=default_args)
  30. def register_torch_lr_scheduler():
  31. from torch.optim import lr_scheduler
  32. if version.parse(torch.__version__) < version.parse('2.0.0.dev'):
  33. from torch.optim.lr_scheduler import _LRScheduler
  34. else:
  35. from torch.optim.lr_scheduler import LRScheduler as _LRScheduler
  36. members = inspect.getmembers(lr_scheduler)
  37. for name, obj in members:
  38. if (inspect.isclass(obj) and issubclass(
  39. obj, _LRScheduler)) or name in ['ReduceLROnPlateau']:
  40. LR_SCHEDULER.register_module(module_name=name, module_cls=obj)
# Runs at import time, so torch's built-in schedulers are registered in
# LR_SCHEDULER as soon as this module is imported.
register_torch_lr_scheduler()