checkpoint_hook.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import random
  4. import shutil
  5. from typing import Optional
  6. import json
  7. import numpy as np
  8. import torch
  9. from modelscope.hub.check_model import check_model_is_id
  10. from modelscope.hub.push_to_hub import (UploadStrategy, push_to_hub_in_queue,
  11. wait_for_done)
  12. from modelscope.metainfo import Hooks
  13. from modelscope.trainers.hooks.builder import HOOKS
  14. from modelscope.trainers.hooks.checkpoint.checkpoint_processor import \
  15. CheckpointProcessor
  16. from modelscope.trainers.hooks.hook import Hook
  17. from modelscope.trainers.hooks.priority import Priority
  18. from modelscope.utils.constant import (DEFAULT_REPOSITORY_REVISION, LogKeys,
  19. ModelFile)
  20. from modelscope.utils.logger import get_logger
  21. from modelscope.utils.torch_utils import is_master
  class CheckpointStrategy:
      """String constants naming when a checkpoint hook saves.

      The values are compared against `CheckpointHook.save_strategy`; any value
      other than `by_epoch` or `by_step` makes `CheckpointHook._should_save`
      return False.
      """
      # Save after training epochs.
      by_epoch = 'by_epoch'
      # Save after training iterations.
      by_step = 'by_step'
      # Do not save periodically.
      no = 'no'
  @HOOKS.register_module(module_name=Hooks.CheckpointHook)
  class CheckpointHook(Hook):
      """Save checkpoints periodically.

      Args:
          save_strategy (str): The strategy to save checkpoint, can be `by_epoch`, `by_step` or `no`.
          interval (int): The frequency to save model. It counts epochs when saving `by_epoch`
              and iterations when saving `by_step`.
          save_dir (str): The directory to save checkpoints. If is None, use `trainer.work_dir`.
          output_dir (str): The absolute path to save the output files for inference. If it's not specified,
              the default dir is `ModelFile.TRAIN_OUTPUT_DIR` under `save_dir`.
          save_last (bool): Whether to save the last checkpoint. Default: True.
          max_checkpoint_num (int): The max number of checkpoint files, default None which means never delete
              anything. A given value is clamped to at least 1. If the number exceeds the limit, earlier
              checkpoints will be deleted first.
          push_to_hub (bool): Whether push the checkpoint to modelhub.
          hub_repo_id (str): The hub repo id.
          hub_token (str): The token of the modelhub. You can also set the environment variable
              `MODELSCOPE_API_TOKEN`.
          private_hub (bool): Whether push to a private hub, default True.
          hub_revision (str): Which branch to push the model to, default is `DEFAULT_REPOSITORY_REVISION`.
          upload_strategy (str): The action adopted when the previous uploading is not done
              and the next one is coming, can be `cancel` or `wait`.
          save_trainer_state (bool): Save the trainer state for continue training, default True.
          kwargs:
              by_epoch (bool): Same with `save_strategy`, but has a higher priority, legacy argument.
              output_sub_dir (str): The folder under the `save_dir` to save the output checkpoint for inference.
                  When given, it takes precedence over `output_dir`.
                  This argument is kept to fit the existing configs.
      """

      PRIORITY = Priority.LOW

      # File written into `output_dir` holding the JSON-dumped `trainer.metric_values`.
      EVAL_RESULT_FILE = 'eval_result.txt'

      # Name of the upload queue handed to `push_to_hub_in_queue` / `wait_for_done`.
      PUSH_TO_HUB_QUEUE_NAME = 'train.checkpoint'

      def __init__(self,
                   save_strategy: Optional[str] = CheckpointStrategy.by_epoch,
                   interval: Optional[int] = 0,
                   save_dir: Optional[str] = None,
                   output_dir: Optional[str] = None,
                   save_last: Optional[bool] = True,
                   max_checkpoint_num: Optional[int] = None,
                   push_to_hub: Optional[bool] = False,
                   hub_repo_id: Optional[str] = None,
                   hub_token: Optional[str] = None,
                   private_hub: Optional[bool] = True,
                   hub_revision: Optional[str] = DEFAULT_REPOSITORY_REVISION,
                   upload_strategy: Optional[str] = UploadStrategy.cancel,
                   save_trainer_state: bool = True,
                   **kwargs):
          self.interval = interval
          self.save_dir = save_dir

          # The legacy boolean `by_epoch` overrides `save_strategy` when present.
          if 'by_epoch' in kwargs:
              self.save_strategy = (
                  CheckpointStrategy.by_epoch
                  if kwargs['by_epoch'] else CheckpointStrategy.by_step)
          else:
              self.save_strategy = save_strategy

          # `output_sub_dir` wins over `output_dir`; the final path is resolved
          # in `before_run`, once `save_dir` is known.
          if 'output_sub_dir' in kwargs:
              self.output_sub_dir = kwargs['output_sub_dir']
              self.output_dir = None
          else:
              self.output_sub_dir = None
              self.output_dir = output_dir

          self.save_last = save_last
          self.rng_state = None
          self.push_to_hub = push_to_hub
          self.hub_repo_id = hub_repo_id
          self.hub_token = hub_token
          self.private_hub = private_hub
          self.hub_revision = hub_revision
          self.upload_strategy = upload_strategy
          self.save_trainer_state = save_trainer_state
          # Incremented before every push; used to build the `v1.{tag}` tag.
          self.tag = -1
          # Lazily filled by `_push_to_hub` with the result of `check_model_is_id`.
          self.is_model_id = None
          self.max_checkpoint_num = None
          if max_checkpoint_num is not None:
              # Always keep at least one checkpoint.
              self.max_checkpoint_num = max(int(max_checkpoint_num), 1)
          # Checkpoint path prefixes in the order they were saved.
          self.history_checkpoints = []
          self.processor = CheckpointProcessor()

      def set_processor(self, processor):
          """
          The checkpoint hook accepts a processor to finish the actual saving/deleting action.
          """
          self.processor = processor

      def before_run(self, trainer):
          """Resolve the save/output directories and the logger before training starts."""
          self.tag = -1
          if not self.save_dir:
              self.save_dir = trainer.work_dir
          if not self.output_dir:
              sub_dir = self.output_sub_dir or ModelFile.TRAIN_OUTPUT_DIR
              self.output_dir = os.path.join(self.save_dir, sub_dir)

          os.makedirs(self.save_dir, exist_ok=True)

          if hasattr(trainer, 'logger'):
              self.logger = trainer.logger
          else:
              self.logger = get_logger()

          if is_master():
              # only global master prepares the output folder
              self.processor.prepare_output(trainer, self.output_dir)
              self.logger.info(f'Checkpoints will be saved to {self.save_dir}')

      def generate_prefix(self, trainer, save_strategy):
          """Return the checkpoint name prefix for the current epoch or iteration.

          The prefix is `LogKeys.EPOCH` or `LogKeys.ITER`, an underscore, and the
          trainer counter plus one.
          """
          if save_strategy == CheckpointStrategy.by_epoch:
              return f'{LogKeys.EPOCH}_{trainer.epoch + 1}'
          return f'{LogKeys.ITER}_{trainer.iter + 1}'

      def _do_save(self, trainer, save_strategy):
          """Save a checkpoint on the ranks chosen by the processor, then optionally push it.

          Args:
              trainer: The running trainer.
              save_strategy (str): `CheckpointStrategy.by_epoch` or `by_step`;
                  selects the prefix and the log message.
          """
          prefix = self.generate_prefix(trainer, save_strategy)
          if not self.processor.should_save_on_rank(trainer):
              return

          if is_master():
              if save_strategy == CheckpointStrategy.by_epoch:
                  self.logger.info(
                      f'Saving checkpoint at {trainer.epoch + 1} epoch')
              else:
                  self.logger.info(
                      f'Saving checkpoint at {trainer.iter + 1} iter')
          self._save_checkpoint(trainer, prefix)

          if is_master() and self.push_to_hub:
              if self.upload_strategy == UploadStrategy.cancel:
                  upload_dir = self.output_dir
                  delete_dir = False
              else:
                  # Upload from a per-checkpoint copy of the output folder and
                  # ask for that copy to be deleted afterwards.
                  upload_dir = self.output_dir + '_upload_' + prefix
                  shutil.copytree(
                      self.output_dir, upload_dir, dirs_exist_ok=True)
                  delete_dir = True
              self._push_to_hub(trainer, prefix, upload_dir, delete_dir)

      def after_train_epoch(self, trainer):
          """Save after an epoch when saving `by_epoch` and `_should_save` agrees."""
          if self.save_strategy != CheckpointStrategy.by_epoch:
              return
          if self._should_save(trainer):
              self._do_save(trainer, CheckpointStrategy.by_epoch)

      def after_train_iter(self, trainer):
          """Save after an iteration when saving `by_step` and `_should_save` agrees."""
          if self.save_strategy != CheckpointStrategy.by_step:
              return
          if self._should_save(trainer):
              self._do_save(trainer, CheckpointStrategy.by_step)

      def after_run(self, trainer):
          """Mark the upload queue as done and wait for it.

          NOTE(review): the queue is signalled and awaited even when
          `push_to_hub` is False; only the final log line is conditional.
          Presumably this is intentional because the queue name is shared with
          subclasses -- confirm before guarding it.
          """
          self.logger.info('Train finished. Uploading models, waiting...')
          push_to_hub_in_queue(
              self.PUSH_TO_HUB_QUEUE_NAME,
              strategy=self.upload_strategy,
              done=True)
          wait_for_done(self.PUSH_TO_HUB_QUEUE_NAME)
          if self.push_to_hub:
              self.logger.info('Uploading models done.')

      def _push_to_hub(self, trainer, prefix, output_dir, delete_dir=False):
          """Enqueue `output_dir` for upload and return what `push_to_hub_in_queue` returns.

          Args:
              trainer: The running trainer; its `input_model_id` is read.
              prefix (str): Used as the commit message.
              output_dir (str): The folder to upload.
              delete_dir (bool): Forwarded to `push_to_hub_in_queue`.
          """
          if self.is_model_id is None:
              # Checked once and cached for later pushes.
              self.is_model_id = check_model_is_id(trainer.input_model_id,
                                                   self.hub_token)
          self.tag += 1
          source_repo = trainer.input_model_id if self.is_model_id else ''
          return push_to_hub_in_queue(
              self.PUSH_TO_HUB_QUEUE_NAME,
              strategy=self.upload_strategy,
              repo_name=self.hub_repo_id,
              output_dir=output_dir,
              token=self.hub_token,
              private=self.private_hub,
              commit_message=prefix,
              tag=f'v1.{self.tag}',
              revision=self.hub_revision,
              source_repo=source_repo,
              delete_dir=delete_dir)

      def save_evaluate_results(self, trainer):
          """Write `trainer.metric_values` as JSON to `EVAL_RESULT_FILE` under `output_dir`."""
          result_file = os.path.join(self.output_dir, self.EVAL_RESULT_FILE)
          with open(result_file, 'w') as f:
              f.write(json.dumps(trainer.metric_values))

      def _save_checkpoint(self, trainer, prefix):
          """Save checkpoint files and remove obsolete ones

          Returns:
              The `prefix` argument, unchanged.
          """
          checkpoint_path_prefix = os.path.join(self.save_dir, prefix)
          meta = self._create_training_state(trainer)
          self.processor.save_checkpoints(trainer, checkpoint_path_prefix,
                                          self.output_dir, meta,
                                          self.save_trainer_state)
          self.save_evaluate_results(trainer)
          self.history_checkpoints.append(checkpoint_path_prefix)
          self._remove_obsolete_checkpoints(trainer)
          return prefix

      def _remove_obsolete_checkpoints(self, trainer):
          """Delete the oldest checkpoints so that at most `max_checkpoint_num` remain."""
          if self.max_checkpoint_num is None:
              return
          excess = len(self.history_checkpoints) - self.max_checkpoint_num
          if excess <= 0:
              return

          # The history is in save order, so the first `excess` entries are the oldest.
          obsolete = self.history_checkpoints[:excess]
          del self.history_checkpoints[:excess]
          for checkpoint_path_prefix in obsolete:
              self.logger.info(
                  f'deleting checkpoint: {checkpoint_path_prefix}')
              self.processor.remove_checkpoints(
                  trainer, checkpoint_path_prefix=checkpoint_path_prefix)

      def _should_save(self, trainer):
          """Return True when the interval is hit, or on the last epoch/iter if `save_last`.

          The interval and last-step checks are methods not defined in this
          class (presumably inherited from `Hook`).
          """
          if self.save_strategy == CheckpointStrategy.by_epoch:
              is_last, on_interval = self.is_last_epoch, self.every_n_epochs
          elif self.save_strategy == CheckpointStrategy.by_step:
              is_last, on_interval = self.is_last_iter, self.every_n_iters
          else:
              return False

          if on_interval(trainer, self.interval):
              return True
          return bool(self.save_last and is_last(trainer))

      def _create_training_state(self, trainer):
          """Collect the RNG states, progress counters and hook states to store with a checkpoint.

          Returns:
              dict: `epoch`, `iter`, `inner_iter` and `rng_state`, plus one
              `'{hook class}-{index}'` entry per stateful hook.
          """
          self.rng_state = {
              'random': random.getstate(),
              'numpy': np.random.get_state(),
              'cpu': torch.random.get_rng_state(),
              'cuda': torch.cuda.get_rng_state_all(),
          }

          # keep epoch/iter/inner_iter/random_state
          # `iter` and `inner_iter` are stored incremented by one, `epoch` as is.
          meta = {
              'epoch': trainer.epoch,
              'iter': trainer.iter + 1,
              'inner_iter': trainer.inner_iter + 1,
              'rng_state': self.rng_state,
          }

          # keep hooks state
          # A hook is included when it exposes `state_dict` and its
          # `_should_save` attribute, if it has one, is truthy. The index counts
          # only the included hooks.
          stateful_hooks = (
              hook for hook in trainer.hooks
              if hasattr(hook, 'state_dict')
              and getattr(hook, '_should_save', True))
          for i, hook in enumerate(stateful_hooks):
              meta[f'{hook.__class__}-{i}'] = hook.state_dict()

          return meta
  @HOOKS.register_module(module_name=Hooks.BestCkptSaverHook)
  class BestCkptSaverHook(CheckpointHook):
      """
      Save best checkpoints hook.

      Args:
          metric_key (str): Metric key to compare rule for best score.
          save_best (bool): Save the best checkpoint, if set to False, this hook will have no effect.
          rule (str): Comparison rule for best score. Support "max" and "min". If rule is "max", the checkpoint
              at the maximum `metric_key` will be saved, If rule is "min", the checkpoint at the minimum
              `metric_key` will be saved.
          save_file_name: The manual specified saving file name.
          restore_best (bool): Whether to restore the best checkpoint after training.
          max_checkpoint_num (int): The max number of checkpoint files, default 1. None means never delete
              anything. If the number exceeds the limit, checkpoints with worse metric will be deleted,
              which is judged by the `rule` and `metric_key` arguments.
          save_trainer_state (bool): Save the trainer state for continue training, default True.

      The `BestCkptSaverHook` class accepts `output_sub_dir` and `output_dir` argument as its super class do.
      If neither of them are passed, the output sub dir defaults to `ModelFile.TRAIN_BEST_OUTPUT_DIR`
      under `save_dir`.

      This class will not accept the `interval` or `save_strategy` argument (both are dropped from
      `kwargs`), because the saving interval will follow the `EvaluationHook`.
      """

      PRIORITY = Priority.LOW

      # Maps a rule name to a predicate `(candidate, current_best) -> is_better`.
      rule_map = {'max': lambda x, y: x > y, 'min': lambda x, y: x < y}

      def __init__(self,
                   metric_key: str,
                   save_best: Optional[bool] = True,
                   rule: Optional[str] = 'max',
                   save_file_name: Optional[str] = None,
                   restore_best: Optional[bool] = False,
                   max_checkpoint_num: Optional[int] = 1,
                   save_trainer_state: bool = True,
                   **kwargs):
          assert rule in ['max', 'min'], 'Only support "max" or "min" rule now.'

          output_kwargs = {}
          if 'output_sub_dir' not in kwargs and 'output_dir' not in kwargs:
              output_kwargs['output_sub_dir'] = ModelFile.TRAIN_BEST_OUTPUT_DIR
          # This hook saves right after evaluation, so these base-class
          # arguments are discarded.
          kwargs.pop('interval', None)
          kwargs.pop('save_strategy', None)
          super().__init__(
              max_checkpoint_num=max_checkpoint_num,
              save_trainer_state=save_trainer_state,
              **kwargs,
              **output_kwargs,
          )
          self.save_best = save_best
          self.metric_key = metric_key
          self.rule = rule
          self._best_metric = None
          self._best_ckpt_file = None
          self.save_file_name = save_file_name
          self.restore_best = restore_best
          # A set (the base class uses a list): ranking is by metric, not by age.
          self.history_checkpoints = set()

      def _save_if_evaluated(self, trainer, eval_tag, save_strategy):
          """Save when the first `EvaluationHook` last evaluated at `eval_tag` and the metric improved.

          Args:
              trainer: The running trainer.
              eval_tag (tuple): The value `last_eval_tag` must equal, e.g. `('epoch', trainer.epoch)`.
              save_strategy (str): Forwarded to `_do_save`.
          """
          from modelscope.trainers.hooks import EvaluationHook
          eval_hooks = trainer.get_hook(EvaluationHook)
          if len(eval_hooks) == 0:
              self.logger.error(
                  'Trying to save the best checkpoint, but there is no evaluation, skipping.'
              )
              # Return here: indexing the empty result below would raise IndexError.
              return
          if eval_hooks[0].last_eval_tag == eval_tag and self._should_save(
                  trainer):
              self._do_save(trainer, save_strategy)

      def after_train_epoch(self, trainer):
          """Save the best checkpoint if an evaluation just ran for this epoch."""
          self._save_if_evaluated(trainer, ('epoch', trainer.epoch),
                                  CheckpointStrategy.by_epoch)

      def after_train_iter(self, trainer):
          """Save the best checkpoint if an evaluation just ran for this iteration."""
          self._save_if_evaluated(trainer, ('iter', trainer.iter),
                                  CheckpointStrategy.by_step)

      def _should_save(self, trainer):
          """Return a truthy value when `save_best` is set and the current metric is a new best."""
          return self.save_best and self._is_best_metric(trainer.metric_values)

      def _is_best_metric(self, metric_values):
          """Compare `metric_values[metric_key]` with the stored best and update it on improvement.

          Returns:
              bool: False when `metric_values` is None or the metric did not
              improve; True for the first metric seen or an improvement.

          Raises:
              ValueError: If `metric_key` is missing from `metric_values`.
          """
          if metric_values is None:
              return False
          if self.metric_key not in metric_values:
              raise ValueError(
                  f'Not find metric_key: {self.metric_key} in {metric_values}')

          current = metric_values[self.metric_key]
          if self._best_metric is not None and not self.rule_map[self.rule](
                  current, self._best_metric):
              return False
          self._best_metric = current
          return True

      def generate_prefix(self, trainer, save_strategy):
          """Return a `best_...` prefix that embeds the metric key and the best metric value."""
          if save_strategy == CheckpointStrategy.by_epoch:
              return f'best_{LogKeys.EPOCH}{trainer.epoch + 1}_{self.metric_key}{self._best_metric}'
          return f'best_{LogKeys.ITER}{trainer.iter + 1}_{self.metric_key}{self._best_metric}'

      def _save_checkpoint(self, trainer, prefix):
          """Save the best checkpoint and prune worse ones.

          The checkpoint is named `save_file_name` when that was given,
          otherwise `prefix`; either way it is placed under `save_dir`.

          Returns:
              The `prefix` argument, unchanged.
          """
          file_name = prefix if self.save_file_name is None else self.save_file_name
          checkpoint_path_prefix = os.path.join(self.save_dir, file_name)
          # Remembered so `after_run` can restore it.
          self._best_ckpt_file = checkpoint_path_prefix
          meta = self._create_training_state(trainer)
          self.processor.save_checkpoints(trainer, checkpoint_path_prefix,
                                          self.output_dir, meta,
                                          self.save_trainer_state)
          self.save_evaluate_results(trainer)
          self.history_checkpoints.add(checkpoint_path_prefix)
          self._remove_obsolete_checkpoints(trainer)
          return prefix

      def _remove_obsolete_checkpoints(self, trainer):
          """Keep the `max_checkpoint_num` best checkpoints and delete the rest.

          The metric of each checkpoint is parsed back from its path, which
          ends with `{metric_key}{value}` (see `generate_prefix`).
          """
          if self.max_checkpoint_num is None or \
                  len(self.history_checkpoints) <= self.max_checkpoint_num:
              return

          def sort_key(path):
              # Split on the LAST occurrence: `metric_key` may also appear
              # earlier in the path (e.g. inside `save_dir`).
              metric = float(path.rsplit(self.metric_key, 1)[1])
              # Negate for 'max' so that an ascending sort puts the best first.
              if self.rule == 'max':
                  return -metric
              return metric

          ranked = sorted(self.history_checkpoints, key=sort_key)
          kept = ranked[:self.max_checkpoint_num]
          obsolete = ranked[self.max_checkpoint_num:]

          self.history_checkpoints.clear()
          self.history_checkpoints.update(kept)
          for checkpoint_path_prefix in obsolete:
              self.logger.info(
                  f'deleting checkpoint: {checkpoint_path_prefix}')
              self.processor.remove_checkpoints(
                  trainer, checkpoint_path_prefix=checkpoint_path_prefix)

      def state_dict(self):
          """Return the state to persist: only the best metric value."""
          return {
              'best_metric': self._best_metric,
          }

      def load_state_dict(self, state_dict):
          """Restore the best metric from `state_dict`, warning when it is None or empty."""
          if state_dict is not None and len(state_dict) > 0:
              self._best_metric = state_dict.get('best_metric')
          else:
              self.logger.warning(
                  'The state_dict is not available, the best metric value will be affected.'
              )

      def after_run(self, trainer):
          """Reload the best checkpoint into the trainer when `restore_best` is set.

          NOTE(review): this override does not call `CheckpointHook.after_run`,
          so this hook does not itself wait for the upload queue -- confirm
          that is intended.
          """
          if not self.restore_best:
              return
          if self._best_ckpt_file is None:
              # No best checkpoint was saved by this hook instance, so there is
              # no path to load.
              self.logger.warning(
                  'restore_best is set, but no best checkpoint was saved in this run, skipping.'
              )
              return
          # If restore_best is True, will call the LoadCheckpointHook to load the best checkpoint
          # for later evaluation or prediction.
          from modelscope.trainers.hooks.checkpoint.load_checkpoint_hook import LoadCheckpointHook
          LoadCheckpointHook.load_checkpoint(self._best_ckpt_file, trainer)