default_config.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict, List, Optional, Tuple
  3. from modelscope.utils.config import Config
  4. DEFAULT_HOOKS_CONFIG = {
  5. 'train.hooks': [{
  6. 'type': 'CheckpointHook',
  7. 'interval': 1
  8. }, {
  9. 'type': 'TextLoggerHook',
  10. 'interval': 10
  11. }, {
  12. 'type': 'IterTimerHook'
  13. }]
  14. }
  15. _HOOK_KEY_CHAIN_MAP = {
  16. 'TextLoggerHook': 'train.logging',
  17. 'CheckpointHook': 'train.checkpoint.period',
  18. 'BestCkptSaverHook': 'train.checkpoint.best',
  19. 'EvaluationHook': 'evaluation.period',
  20. }
  21. def merge_cfg(cfg: Config):
  22. """Merge the default config into the input cfg.
  23. This function will pop the default CheckpointHook when the BestCkptSaverHook exists in the input cfg.
  24. Aegs:
  25. cfg: The input cfg to be merged into.
  26. """
  27. cfg.merge_from_dict(DEFAULT_HOOKS_CONFIG, force=False)
  28. def merge_hooks(cfg: Config) -> List[Dict]:
  29. hooks = getattr(cfg.train, 'hooks', []).copy()
  30. for hook_type, key_chain in _HOOK_KEY_CHAIN_MAP.items():
  31. hook = _key_chain_to_hook(cfg, key_chain, hook_type)
  32. if hook is not None:
  33. hooks.append(hook)
  34. return hooks
  35. def update_cfg(cfg: Config) -> Config:
  36. if 'hooks' not in cfg.train:
  37. return cfg
  38. key_chain_map = {}
  39. for hook in cfg.train.hooks:
  40. if not hook:
  41. continue
  42. key, value = _hook_split(hook)
  43. if key not in _HOOK_KEY_CHAIN_MAP:
  44. continue
  45. key_chain_map[_HOOK_KEY_CHAIN_MAP[key]] = value
  46. hook.clear()
  47. cfg.train.hooks = list(filter(bool, cfg.train.hooks))
  48. cfg.merge_from_dict(key_chain_map, force=False)
  49. return cfg
  50. def _key_chain_to_hook(cfg: Config, key_chain: str,
  51. hook_type: str) -> Optional[Dict]:
  52. if not _check_basic_hook(cfg, key_chain, hook_type):
  53. return None
  54. hook_params: Dict = cfg.safe_get(key_chain)
  55. hook = {'type': hook_type}
  56. hook.update(hook_params)
  57. return hook
  58. def _check_basic_hook(cfg: Config, key_chain: str, hook_type: str) -> bool:
  59. if cfg.safe_get(key_chain) is None:
  60. return False
  61. hooks = list(
  62. filter(lambda hook: hook['type'] == hook_type,
  63. getattr(cfg.train, 'hooks', [])))
  64. assert len(hooks) == 0, f'The key_chain {key_chain} and the traditional hook ' \
  65. f'cannot exist at the same time, ' \
  66. f'please delete {hook_type} in the configuration file.'
  67. return True
  68. def _hook_split(hook: Dict) -> Tuple[str, Dict]:
  69. hook = hook.copy()
  70. return hook.pop('type'), hook