| 12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from typing import Dict, List, Optional, Tuple
- from modelscope.utils.config import Config
# Default hook list, keyed by the dotted key chain 'train.hooks'.
# merge_cfg() passes this mapping to Config.merge_from_dict(force=False).
DEFAULT_HOOKS_CONFIG = {
    'train.hooks': [
        {'type': 'CheckpointHook', 'interval': 1},
        {'type': 'TextLoggerHook', 'interval': 10},
        {'type': 'IterTimerHook'},
    ]
}
# Hook type name -> dotted config key chain holding that hook's parameters.
# merge_hooks() turns a populated key chain into a hook dict of this type;
# update_cfg() moves a hook of this type from train.hooks to the key chain.
_HOOK_KEY_CHAIN_MAP = dict(
    TextLoggerHook='train.logging',
    CheckpointHook='train.checkpoint.period',
    BestCkptSaverHook='train.checkpoint.best',
    EvaluationHook='evaluation.period',
)
def merge_cfg(cfg: Config):
    """Merge the default hook configuration into the input cfg in place.

    ``DEFAULT_HOOKS_CONFIG`` is handed to ``cfg.merge_from_dict`` with
    ``force=False``; nothing is returned.

    NOTE(review): ``force=False`` presumably keeps values that ``cfg``
    already defines -- confirm against ``Config.merge_from_dict``.
    NOTE(review): this function does not itself drop the default
    CheckpointHook when a BestCkptSaverHook is configured; confirm whether
    that is handled elsewhere.

    Args:
        cfg: The input cfg to be merged into.
    """
    default_hooks = DEFAULT_HOOKS_CONFIG
    cfg.merge_from_dict(default_hooks, force=False)
def merge_hooks(cfg: Config) -> List[Dict]:
    """Collect the explicit hooks plus the hooks implied by key chains.

    Starts from a ``.copy()`` of ``cfg.train.hooks`` (an empty list when the
    attribute is missing) and appends, in ``_HOOK_KEY_CHAIN_MAP`` order, one
    hook dict for every key chain that ``_key_chain_to_hook`` resolves.

    Args:
        cfg: The configuration to read hooks from.

    Returns:
        The combined list of hook dicts.

    Raises:
        AssertionError: Propagated from ``_check_basic_hook`` when a hook
            type is configured both through its key chain and in
            ``train.hooks``.
    """
    merged = getattr(cfg.train, 'hooks', []).copy()
    resolved = (_key_chain_to_hook(cfg, chain, name)
                for name, chain in _HOOK_KEY_CHAIN_MAP.items())
    # Key chains that are not set resolve to None and are skipped.
    merged.extend(item for item in resolved if item is not None)
    return merged
def update_cfg(cfg: Config) -> Config:
    """Move mapped hooks from ``train.hooks`` to their key-chain form.

    Each non-empty entry of ``cfg.train.hooks`` whose ``type`` is a key of
    ``_HOOK_KEY_CHAIN_MAP`` has its remaining parameters recorded under the
    mapped key chain; the entry is then emptied in place and removed from
    ``cfg.train.hooks``. The recorded parameters are merged back with
    ``cfg.merge_from_dict(..., force=False)``.

    NOTE(review): ``force=False`` presumably leaves key chains that ``cfg``
    already defines untouched -- confirm against ``Config.merge_from_dict``.

    Args:
        cfg: The configuration to rewrite; modified in place.

    Returns:
        The same ``cfg`` object.
    """
    if 'hooks' not in cfg.train:
        return cfg

    relocated = {}
    for entry in cfg.train.hooks:
        if not entry:
            continue
        hook_type, params = _hook_split(entry)
        if hook_type in _HOOK_KEY_CHAIN_MAP:
            relocated[_HOOK_KEY_CHAIN_MAP[hook_type]] = params
            # Emptied in place so the falsy-entry filter below drops it.
            entry.clear()

    cfg.train.hooks = [entry for entry in cfg.train.hooks if entry]
    cfg.merge_from_dict(relocated, force=False)
    return cfg
def _key_chain_to_hook(cfg: Config, key_chain: str,
                       hook_type: str) -> Optional[Dict]:
    """Build a hook dict from the parameters stored under ``key_chain``.

    Args:
        cfg: The configuration to read from.
        key_chain: Dotted key chain looked up with ``cfg.safe_get``.
        hook_type: Value stored under the ``'type'`` key of the result.

    Returns:
        ``None`` when ``_check_basic_hook`` rejects the key chain; otherwise
        a new dict holding ``'type'`` followed by the stored parameters
        (a ``'type'`` key among the parameters would overwrite
        ``hook_type``, since the parameters are applied last).
    """
    if _check_basic_hook(cfg, key_chain, hook_type):
        params: Dict = cfg.safe_get(key_chain)
        built = {'type': hook_type}
        built.update(params)
        return built
    return None
def _check_basic_hook(cfg: Config, key_chain: str, hook_type: str) -> bool:
    """Tell whether ``key_chain`` is set and may be converted into a hook.

    Args:
        cfg: The configuration to inspect.
        key_chain: Dotted key chain looked up with ``cfg.safe_get``.
        hook_type: Hook type name that must not also appear in
            ``cfg.train.hooks``.

    Returns:
        ``False`` when ``cfg.safe_get(key_chain)`` is ``None``, ``True``
        otherwise.

    Raises:
        AssertionError: If the key chain is set and ``cfg.train.hooks``
            also contains a hook whose ``'type'`` equals ``hook_type``.
    """
    if cfg.safe_get(key_chain) is None:
        return False

    legacy_hooks = getattr(cfg.train, 'hooks', [])
    if any(hook['type'] == hook_type for hook in legacy_hooks):
        # Raised explicitly instead of via an ``assert`` statement so the
        # check still runs under ``python -O``; the exception type is kept
        # as AssertionError for callers that already catch it.
        raise AssertionError(
            f'The key_chain {key_chain} and the traditional hook '
            f'cannot exist at the same time, '
            f'please delete {hook_type} in the configuration file.')
    return True
def _hook_split(hook: Dict) -> Tuple[str, Dict]:
    """Separate a hook dict into its type and its remaining parameters.

    Works on a ``.copy()`` of ``hook``, so the argument is left unchanged.

    Args:
        hook: Hook dict containing a ``'type'`` key.

    Returns:
        A ``(type, params)`` tuple where ``params`` is the copy without the
        ``'type'`` key.

    Raises:
        KeyError: If ``hook`` has no ``'type'`` key.
    """
    params = hook.copy()
    hook_type = params.pop('type')
    return hook_type, params
|