checkpoint.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import io
  3. import os
  4. import re
  5. import sys
  6. import time
  7. from collections import OrderedDict
  8. from shutil import copytree, ignore_patterns, rmtree
  9. from typing import Callable, Dict, Optional, Union
  10. import json
  11. import torch
  12. from torch import nn
  13. from torch.optim import Optimizer
  14. from torch.optim.lr_scheduler import _LRScheduler
  15. from modelscope.fileio import File, LocalStorage
  16. from modelscope.utils.config import Config, JSONIteratorEncoder
  17. from modelscope.utils.constant import ConfigFields, ModelFile
  18. from modelscope.utils.file_utils import copytree_py37
  19. from modelscope.utils.logger import get_logger
  20. from modelscope.utils.torch_utils import is_master
# Module-wide logger obtained from modelscope's logging helper; used for the
# warnings/info emitted while loading checkpoints.
logger = get_logger()
# Local-filesystem storage backend; save_configuration writes the JSON
# configuration through it.
storage = LocalStorage()
  23. def weights_to_cpu(state_dict):
  24. """Copy a model state_dict to cpu.
  25. Args:
  26. state_dict (OrderedDict): Model weights on GPU.
  27. Returns:
  28. OrderedDict: Model weights on GPU.
  29. """
  30. state_dict_cpu = OrderedDict()
  31. for key, val in state_dict.items():
  32. state_dict_cpu[key] = val.cpu()
  33. # Keep metadata in state_dict
  34. state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
  35. return state_dict_cpu
  36. def save_checkpoint(model: torch.nn.Module,
  37. filename: str,
  38. optimizer: Optional[Optimizer] = None,
  39. lr_scheduler: Optional[_LRScheduler] = None,
  40. meta: Optional[dict] = None,
  41. with_meta: bool = True,
  42. with_model: bool = True) -> None:
  43. """Save checkpoint to file.
  44. The checkpoint will have 3 fields: ``meta``, ``state_dict`` and
  45. ``optimizer``. By default, ``meta`` will contain version and time info.
  46. Args:
  47. model (Module): Module whose params are to be saved.
  48. filename (str): Checkpoint filename.
  49. optimizer (:obj:`Optimizer`, optional): Optimizer to be saved.
  50. lr_scheduler(:obj:`_LRScheduler`, optional): LRScheduler to be saved.
  51. meta (dict, optional): Metadata to be saved in checkpoint.
  52. with_meta (bool, optional): Save meta info.
  53. with_model(bool, optional): Save model states.
  54. """
  55. checkpoint = {}
  56. if not with_meta and not with_model:
  57. raise ValueError(
  58. 'Save meta by "with_meta=True" or model by "with_model=True"')
  59. if with_meta:
  60. if meta is None:
  61. meta = {}
  62. elif not isinstance(meta, dict):
  63. raise TypeError(
  64. f'meta must be a dict or None, but got {type(meta)}')
  65. from modelscope import __version__
  66. meta.update(modelscope=__version__, time=time.asctime())
  67. if isinstance(model, torch.nn.parallel.DistributedDataParallel):
  68. model = model.module
  69. if hasattr(model, 'CLASSES') and model.CLASSES is not None:
  70. # save class name to the meta
  71. meta.update(CLASSES=model.CLASSES)
  72. checkpoint['meta'] = meta
  73. # save optimizer state dict in the checkpoint
  74. if isinstance(optimizer, Optimizer):
  75. checkpoint['optimizer'] = optimizer.state_dict()
  76. elif isinstance(optimizer, dict):
  77. checkpoint['optimizer'] = {}
  78. for name, optim in optimizer.items():
  79. checkpoint['optimizer'][name] = optim.state_dict()
  80. # save lr_scheduler state dict in the checkpoint
  81. if lr_scheduler is not None and hasattr(lr_scheduler, 'state_dict'):
  82. checkpoint['lr_scheduler'] = lr_scheduler.state_dict()
  83. if with_model:
  84. if isinstance(model, torch.nn.parallel.DistributedDataParallel):
  85. model = model.module
  86. _weights = weights_to_cpu(model.state_dict())
  87. if not with_meta:
  88. checkpoint = _weights
  89. else:
  90. checkpoint['state_dict'] = _weights
  91. with io.BytesIO() as f:
  92. torch.save(checkpoint, f)
  93. File.write(f.getvalue(), filename)
  94. def load_checkpoint(filename,
  95. model,
  96. optimizer: Optimizer = None,
  97. lr_scheduler: _LRScheduler = None):
  98. if not os.path.exists(filename):
  99. raise ValueError(f'Checkpoint file {filename} does not exist!')
  100. checkpoint = torch.load(filename, map_location='cpu', weights_only=True)
  101. if optimizer is not None:
  102. if 'optimizer' in checkpoint:
  103. if isinstance(optimizer, Optimizer):
  104. optimizer.load_state_dict(checkpoint['optimizer'])
  105. elif isinstance(optimizer, dict):
  106. optimizer_dict = checkpoint['optimizer']
  107. for key, optimizer_ins in optimizer.items():
  108. if key in optimizer_dict:
  109. optimizer_ins.load_state_dict(optimizer_dict[key])
  110. else:
  111. logger.warning(
  112. f'The state dict of optimizer {key} cannot be found in checkpoint file: {filename}'
  113. )
  114. else:
  115. logger.warning(
  116. f'The state dict of optimizer cannot be found in checkpoint file: {filename}'
  117. )
  118. if lr_scheduler is not None:
  119. if 'lr_scheduler' in checkpoint:
  120. lr_scheduler.load_state_dict(checkpoint['lr_scheduler'])
  121. else:
  122. logger.warning(
  123. f'The state dict of lr_scheduler cannot be found in checkpoint file: {filename}'
  124. )
  125. if model is not None:
  126. state_dict = checkpoint if 'state_dict' not in checkpoint else checkpoint[
  127. 'state_dict']
  128. model.load_state_dict(state_dict)
  129. return checkpoint.get('meta', {})
  130. def load_task_model_checkpoint(model_to_load,
  131. model_local_dir,
  132. default_dtype=None,
  133. load_state_fn=None,
  134. **kwargs):
  135. """
  136. Load model checkpoint file and feed the parameters into the model.
  137. Args:
  138. model_to_load: The model to be load
  139. model_local_dir: The actual checkpoint dir on local disk.
  140. default_dtype: Set the default float type by 'torch.set_default_dtype'
  141. load_state_fn: An optional load_state_fn used to load state_dict into the model.
  142. Returns:
  143. """
  144. def _add_head_prefix_to_state_dict(state_dicts, head_prefix,
  145. expected_keys_without_head_prefix,
  146. missing_keys):
  147. new_state_dict = OrderedDict()
  148. for name, module in state_dicts.items():
  149. if name in expected_keys_without_head_prefix:
  150. name_with_head = '.'.join([head_prefix, name])
  151. new_state_dict[name_with_head] = module
  152. expected_keys_without_head_prefix.remove(name)
  153. missing_keys = list(set(missing_keys) - set([name_with_head]))
  154. else:
  155. new_state_dict[name] = module
  156. missing_head_keys = []
  157. if len(expected_keys_without_head_prefix) > 0:
  158. missing_head_keys = expected_keys_without_head_prefix.copy()
  159. return new_state_dict, missing_head_keys, missing_keys
  160. def _find_mismatched_keys(
  161. state_dicts,
  162. model_state_dict,
  163. loaded_keys,
  164. prefix,
  165. add_prefix_to_model,
  166. remove_prefix_from_model,
  167. ignore_mismatched_sizes,
  168. ):
  169. mismatched_key = []
  170. if ignore_mismatched_sizes:
  171. for checkpoint_key in loaded_keys:
  172. model_key = checkpoint_key
  173. if remove_prefix_from_model:
  174. # The model key starts with `prefix` but `checkpoint_key` doesn't, so we add it.
  175. model_key = f'{prefix}.{checkpoint_key}'
  176. elif add_prefix_to_model:
  177. # The model key doesn't start with `prefix` but `checkpoint_key` does, so we remove it.
  178. model_key = '.'.join(checkpoint_key.split('.')[1:])
  179. if model_key in model_state_dict:
  180. model_shape = model_state_dict[model_key].shape
  181. checkpoint_shape = state_dicts[checkpoint_key].shape
  182. if checkpoint_shape != model_shape:
  183. mismatched_key.append(
  184. (checkpoint_key, state_dicts[checkpoint_key].shape,
  185. model_state_dict[model_key].shape))
  186. del state_dicts[checkpoint_key]
  187. return mismatched_key
  188. def _load_state_dict_into_model(
  189. model,
  190. state_dict,
  191. start_prefix,
  192. head_prefix_keys,
  193. load_state_fn=None,
  194. ):
  195. # Convert old format to new format if needed from a PyTorch state_dict
  196. old_keys = []
  197. new_keys = []
  198. for key in state_dict.keys():
  199. new_key = None
  200. if 'gamma' in key:
  201. new_key = key.replace('gamma', 'weight')
  202. if 'beta' in key:
  203. new_key = key.replace('beta', 'bias')
  204. if new_key:
  205. old_keys.append(key)
  206. new_keys.append(new_key)
  207. for old_key, new_key in zip(old_keys, new_keys):
  208. state_dict[new_key] = state_dict.pop(old_key)
  209. # copy state_dict so _load_from_state_dict can modify it
  210. metadata = getattr(state_dict, '_metadata', None)
  211. state_dict = state_dict.copy()
  212. if metadata is not None:
  213. state_dict._metadata = metadata
  214. error_msgs = []
  215. if load_state_fn is not None:
  216. load_state_fn(
  217. model,
  218. state_dict,
  219. prefix=start_prefix,
  220. head_prefix_keys=head_prefix_keys,
  221. local_metadata=None,
  222. error_msgs=error_msgs)
  223. else:
  224. def load(module: nn.Module, prefix=''):
  225. local_metadata = {} if metadata is None else metadata.get(
  226. prefix[:-1], {})
  227. args = (state_dict, prefix, local_metadata, True, [], [],
  228. error_msgs)
  229. module._load_from_state_dict(*args)
  230. for name, child in module._modules.items():
  231. if child is not None:
  232. load(child, prefix + name + '.')
  233. load(model, prefix=start_prefix)
  234. return error_msgs
  235. def _load_checkpoint(
  236. model,
  237. state_dict,
  238. load_state_fn,
  239. ignore_mismatched_sizes,
  240. _fast_init,
  241. ):
  242. # Retrieve missing & unexpected_keys
  243. model_state_dict = model.state_dict()
  244. expected_keys = list(model_state_dict.keys())
  245. keys_from_pretrained = list(state_dict.keys())
  246. prefix = model.base_model_prefix
  247. # during loading stage, base model prefix is complicated, should consider remove or add
  248. if len(prefix) > 0:
  249. # nlp: encoder, decoder
  250. pretrained_has_prefix_module = any(
  251. s.startswith(prefix) for s in keys_from_pretrained)
  252. model_expects_prefix_module = any(
  253. s.startswith(prefix) for s in expected_keys)
  254. else:
  255. # nlp:encoder-decoder, cv:backbone-head,
  256. pretrained_has_prefix_module = False
  257. model_expects_prefix_module = False
  258. remove_prefix_from_model = not pretrained_has_prefix_module and model_expects_prefix_module
  259. add_prefix_to_model = pretrained_has_prefix_module and not model_expects_prefix_module
  260. if remove_prefix_from_model:
  261. expected_keys_not_base_model_prefixed = [
  262. s for s in expected_keys if not s.startswith(prefix)
  263. ]
  264. expected_keys = [
  265. '.'.join(s.split('.')[1:]) if s.startswith(prefix) else s
  266. for s in expected_keys
  267. ]
  268. elif add_prefix_to_model:
  269. # backbone only
  270. expected_keys = ['.'.join([prefix, s]) for s in expected_keys]
  271. expected_keys_not_base_model_prefixed = []
  272. missing_keys = list(set(expected_keys) - set(keys_from_pretrained))
  273. unexpected_keys = list(set(keys_from_pretrained) - set(expected_keys))
  274. # during loading stage head prefix is simple, add or not add
  275. prefix_heads = model.head_prefix
  276. expected_head_keys_without_head_prefix = []
  277. missing_head_keys = []
  278. unexpected_head_keys = []
  279. pretrained_has_prefix_head = dict()
  280. head_prefix_keys = dict()
  281. # only for case of head mismatched with state-dict
  282. if len(prefix_heads) > 0 and len(unexpected_keys) > 0:
  283. if isinstance(prefix_heads, str):
  284. prefix_heads = [prefix_heads]
  285. # to double-check if head matched with state-dict
  286. for prefix_head in prefix_heads:
  287. pretrained_has_prefix_head[prefix_head] = any(
  288. s.startswith(prefix_head) for s in keys_from_pretrained)
  289. for prefix_head in prefix_heads:
  290. expected_keys_without_head_prefix = [
  291. '.'.join(s.split('.')[1:]) for s in expected_keys
  292. if s.startswith(prefix_head)
  293. ]
  294. expected_head_keys_without_head_prefix.extend(
  295. expected_keys_without_head_prefix)
  296. head_prefix_keys[
  297. prefix_head] = expected_keys_without_head_prefix
  298. unexpected_head_keys = list(
  299. set(unexpected_keys)
  300. - set(expected_head_keys_without_head_prefix))
  301. unexpected_keys = list(
  302. set(unexpected_keys)
  303. - set(expected_head_keys_without_head_prefix))
  304. _keys_to_ignore_on_load_missing = kwargs.pop(
  305. '_keys_to_ignore_on_load_missing', None)
  306. _keys_to_ignore_on_load_unexpected = kwargs.pop(
  307. '_keys_to_ignore_on_load_unexpected', None)
  308. # Some models may have keys that are not in the state by design, removing them before needlessly warning
  309. # the user.
  310. if _keys_to_ignore_on_load_missing is not None:
  311. for pat in _keys_to_ignore_on_load_missing:
  312. missing_keys = [
  313. k for k in missing_keys if re.search(pat, k) is None
  314. ]
  315. if _keys_to_ignore_on_load_unexpected is not None:
  316. for pat in _keys_to_ignore_on_load_unexpected:
  317. unexpected_keys = [
  318. k for k in unexpected_keys if re.search(pat, k) is None
  319. ]
  320. # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
  321. if _fast_init:
  322. uninitialized_modules = retrieve_modules_from_names(
  323. model,
  324. missing_keys,
  325. prefix=prefix,
  326. add_prefix=add_prefix_to_model,
  327. remove_prefix=remove_prefix_from_model)
  328. for module in uninitialized_modules:
  329. model._init_weights(module)
  330. # Make sure we are able to load head correctly by revise state-dict
  331. missing_head_keys_by_head = dict()
  332. if len(head_prefix_keys) > 0:
  333. for head_prefix in head_prefix_keys:
  334. if not pretrained_has_prefix_head[head_prefix]:
  335. state_dict, missing_head_keys, missing_keys = _add_head_prefix_to_state_dict(
  336. state_dict, head_prefix, head_prefix_keys[head_prefix],
  337. missing_keys)
  338. missing_head_keys_by_head[head_prefix] = missing_head_keys
  339. # Make sure we are able to load base models as well as derived models (with heads)
  340. start_prefix = ''
  341. model_to_load = model
  342. heads_to_load = dict()
  343. if len(model.base_model_prefix) > 0 and not hasattr(
  344. model,
  345. model.base_model_prefix) and pretrained_has_prefix_module:
  346. start_prefix = model.base_model_prefix + '.'
  347. if len(model.base_model_prefix) > 0 and hasattr(
  348. model,
  349. model.base_model_prefix) and not pretrained_has_prefix_module:
  350. model_to_load = getattr(model, model.base_model_prefix)
  351. for head_prefix in prefix_heads:
  352. heads_to_load[head_prefix] = getattr(model, head_prefix)
  353. if any(key in expected_keys_not_base_model_prefixed
  354. for key in keys_from_pretrained):
  355. raise ValueError(
  356. 'The state dictionary of the model you are trying to load is corrupted. Are you sure it was '
  357. 'properly saved?')
  358. # Whole checkpoint
  359. mismatched_keys = _find_mismatched_keys(
  360. state_dict,
  361. model_state_dict,
  362. keys_from_pretrained,
  363. prefix,
  364. add_prefix_to_model,
  365. remove_prefix_from_model,
  366. ignore_mismatched_sizes,
  367. )
  368. error_msgs = _load_state_dict_into_model(model_to_load, state_dict,
  369. start_prefix, load_state_fn)
  370. if len(heads_to_load) > 0:
  371. for head in heads_to_load:
  372. local_error_msgs = _load_state_dict_into_model(
  373. heads_to_load[head], state_dict, head + '.', load_state_fn)
  374. error_msgs.extend(local_error_msgs)
  375. if len(error_msgs) > 0:
  376. error_msg = '\n\t'.join(error_msgs)
  377. raise RuntimeError(
  378. f'Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}'
  379. )
  380. if len(unexpected_keys) > 0:
  381. logger.warning(
  382. f'Some weights of the model checkpoint were not used when'
  383. f' initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are'
  384. f' initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or'
  385. ' with another architecture (e.g. initializing a BertForTokenClassification model from a'
  386. ' BertForPreTraining model).\n- This IS NOT expected if you are initializing'
  387. f' {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical'
  388. ' (initializing a BertForTokenClassification model from a BertForTokenClassification model).'
  389. )
  390. elif len(unexpected_head_keys) > 0:
  391. logger.warning(
  392. f'Some weights of the model checkpoint were not used when'
  393. f' initializing {model.__class__.__name__}: {unexpected_head_keys}\n- This IS Not expected if you are'
  394. f' initializing {model.__class__.__name__} from the checkpoint of a model with a same task while the'
  395. ' structure is different (e.g. initializing a BertForTokenClassification model from a'
  396. ' BertForTokenClassification model).')
  397. else:
  398. logger.info(
  399. f'All model checkpoint weights were used when initializing {model.__class__.__name__}.\n'
  400. )
  401. if len(missing_keys) > 0:
  402. logger.warning(
  403. f'Some weights of {model.__class__.__name__} were not initialized from the model checkpoint'
  404. f' and are newly initialized: {missing_keys}\nYou should probably'
  405. ' TRAIN this model on a down-stream task to be able to use it for predictions and inference.'
  406. )
  407. elif len(mismatched_keys) == 0:
  408. logger.info(
  409. f'All the weights of {model.__class__.__name__} were initialized from the model checkpoint '
  410. f'If your task is similar to the task the model of the checkpoint'
  411. f' was trained on, you can already use {model.__class__.__name__} for predictions without further'
  412. ' training.')
  413. if len(mismatched_keys) > 0:
  414. mismatched_warning = '\n'.join([
  415. f'- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated'
  416. for key, shape1, shape2 in mismatched_keys
  417. ])
  418. logger.warning(
  419. f'Some weights of {model.__class__.__name__} were not initialized from the model checkpoint'
  420. f' and are newly initialized because the shapes did not'
  421. f' match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able'
  422. ' to use it for predictions and inference.')
  423. return missing_keys, unexpected_keys, mismatched_keys, error_msgs
  424. def retrieve_modules_from_names(model,
  425. names,
  426. prefix=None,
  427. add_prefix=False,
  428. remove_prefix=False):
  429. module_keys = set(['.'.join(key.split('.')[:-1]) for key in names])
  430. # torch.nn.ParameterList is a special case where two parameter keywords
  431. # are appended to the module name, *e.g.* bert.special_embeddings.0
  432. module_keys = module_keys.union(
  433. set([
  434. '.'.join(key.split('.')[:-2]) for key in names
  435. if key[-1].isdigit()
  436. ]))
  437. retrieved_modules = []
  438. # retrieve all modules that has at least one missing weight name
  439. for name, module in model.named_modules():
  440. if remove_prefix:
  441. name = '.'.join(
  442. name.split('.')[1:]) if name.startswith(prefix) else name
  443. elif add_prefix:
  444. name = '.'.join([prefix, name]) if len(name) > 0 else prefix
  445. if name in module_keys:
  446. retrieved_modules.append(module)
  447. return retrieved_modules
  448. def _tie_or_clone_weights(output_embeddings,
  449. input_embeddings,
  450. torchscript=False):
  451. if torchscript:
  452. output_embeddings.weight = nn.Parameter(
  453. input_embeddings.weight.clone())
  454. else:
  455. output_embeddings.weight = input_embeddings.weight
  456. if getattr(output_embeddings, 'bias', None) is not None:
  457. output_embeddings.bias.data = nn.functional.pad(
  458. output_embeddings.bias.data,
  459. (
  460. 0,
  461. output_embeddings.weight.shape[0]
  462. - output_embeddings.bias.shape[0],
  463. ),
  464. 'constant',
  465. 0,
  466. )
  467. if hasattr(output_embeddings, 'out_features') and hasattr(
  468. input_embeddings, 'num_embeddings'):
  469. output_embeddings.out_features = input_embeddings.num_embeddings
  470. def tie_weights(model, tie_word_embeddings=False):
  471. if tie_word_embeddings:
  472. output_embeddings = model.head.get_output_embeddings()
  473. if output_embeddings is not None:
  474. input_embeddings = model.encoder.get_input_embeddings()
  475. _tie_or_clone_weights(output_embeddings, input_embeddings)
  476. # TODO Sharded ckpt
  477. ckpt_file = os.path.join(model_local_dir, ModelFile.TORCH_MODEL_BIN_FILE)
  478. state_dict = torch.load(ckpt_file, map_location='cpu', weights_only=True)
  479. if default_dtype is not None:
  480. torch.set_default_dtype(default_dtype)
  481. missing_keys, unexpected_keys, mismatched_keys, error_msgs = _load_checkpoint(
  482. model_to_load,
  483. state_dict,
  484. load_state_fn=load_state_fn,
  485. ignore_mismatched_sizes=True,
  486. _fast_init=True,
  487. )
  488. if getattr(kwargs.get('head'), 'tie_word_embeddings', False):
  489. tie_weights(model_to_load, kwargs.get('head').tie_word_embeddings)
  490. return {
  491. 'model': model_to_load,
  492. 'missing_keys': missing_keys,
  493. 'unexpected_keys': unexpected_keys,
  494. 'mismatched_keys': mismatched_keys,
  495. 'error_msgs': error_msgs,
  496. }
  497. def save_configuration(target_folder, config: Dict):
  498. if isinstance(config, Config):
  499. config = config.to_dict()
  500. if ConfigFields.pipeline not in config:
  501. config[ConfigFields.pipeline] = {'type': config[ConfigFields.task]}
  502. cfg_str = json.dumps(config, indent=4, cls=JSONIteratorEncoder)
  503. config_file = os.path.join(target_folder, ModelFile.CONFIGURATION)
  504. storage.write(cfg_str.encode(), config_file)
  505. def save_pretrained(model,
  506. target_folder: Union[str, os.PathLike],
  507. save_checkpoint_name: str = None,
  508. save_function: Callable = None,
  509. **kwargs):
  510. """save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded
  511. Args:
  512. model (Model): Model whose params are to be saved.
  513. target_folder (Union[str, os.PathLike]):
  514. Directory to which to save. Will be created if it doesn't exist.
  515. save_checkpoint_name (str):
  516. The checkpoint name to be saved in the target_folder
  517. save_function (Callable):
  518. The function to use to save the state dictionary.
  519. """
  520. if save_function is None or not isinstance(save_function, Callable):
  521. raise Exception('A valid save function must be passed in')
  522. if target_folder is None or os.path.isfile(target_folder):
  523. raise ValueError(
  524. f'Provided path ({target_folder}) should be a directory, not a file'
  525. )
  526. if save_checkpoint_name is None:
  527. raise Exception(
  528. 'At least pass in one checkpoint name for saving method')
  529. # Single ckpt path, sharded ckpt logic will be added later
  530. output_ckpt_path = os.path.join(target_folder, save_checkpoint_name)
  531. # Save the files to be copied to the save directory, ignore the original ckpts and configuration
  532. origin_file_to_be_ignored = [save_checkpoint_name]
  533. ignore_file_set = set(origin_file_to_be_ignored)
  534. ignore_file_set.add(ModelFile.CONFIGURATION)
  535. ignore_file_set.add('*.safetensors')
  536. ignore_file_set.add('.*')
  537. if hasattr(model,
  538. 'model_dir') and model.model_dir is not None and is_master():
  539. if sys.version_info.minor >= 8:
  540. copytree_func = copytree
  541. else: # == 7
  542. copytree_func = copytree_py37
  543. copytree_func(
  544. model.model_dir,
  545. target_folder,
  546. ignore=ignore_patterns(*ignore_file_set),
  547. dirs_exist_ok=True)
  548. # Save the ckpt to the save directory
  549. try:
  550. save_function(model, output_ckpt_path, **kwargs)
  551. except Exception as e:
  552. raise Exception(
  553. f'During saving checkpoints, the error of "{type(e).__name__} '
  554. f'with msg {e} thrown')