_helpers.py 7.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
  1. """ Model creation / weight loading / state_dict helpers
  2. Hacked together by / Copyright 2020 Ross Wightman
  3. """
  4. import logging
  5. import os
  6. from typing import Any, Callable, Dict, Optional, Union
  7. import torch
# safetensors is treated as optional: record whether the import succeeded so
# that loaders can check the flag before touching `safetensors.torch`.
try:
    import safetensors.torch
    _has_safetensors = True
except ImportError:
    _has_safetensors = False

# Module-level logger, named after this module.
_logger = logging.getLogger(__name__)

# Names exported by a star-import of this module.
__all__ = ['clean_state_dict', 'load_state_dict', 'load_checkpoint', 'remap_state_dict', 'resume_checkpoint']
  15. def _remove_prefix(text: str, prefix: str) -> str:
  16. # FIXME replace with 3.9 stdlib fn when min at 3.9
  17. if text.startswith(prefix):
  18. return text[len(prefix):]
  19. return text
  20. def clean_state_dict(state_dict: Dict[str, Any]) -> Dict[str, Any]:
  21. # 'clean' checkpoint by removing .module prefix from state dict if it exists from parallel training
  22. cleaned_state_dict = {}
  23. to_remove = (
  24. 'module.', # DDP wrapper
  25. '_orig_mod.', # torchcompile dynamo wrapper
  26. )
  27. for k, v in state_dict.items():
  28. for r in to_remove:
  29. k = _remove_prefix(k, r)
  30. cleaned_state_dict[k] = v
  31. return cleaned_state_dict
  32. def load_state_dict(
  33. checkpoint_path: str,
  34. use_ema: bool = True,
  35. device: Union[str, torch.device] = 'cpu',
  36. weights_only: bool = False,
  37. ) -> Dict[str, Any]:
  38. """Load state dictionary from checkpoint file.
  39. Args:
  40. checkpoint_path: Path to checkpoint file.
  41. use_ema: Whether to use EMA weights if available.
  42. device: Device to load checkpoint to.
  43. weights_only: Whether to load only weights (torch.load parameter).
  44. Returns:
  45. State dictionary loaded from checkpoint.
  46. """
  47. if checkpoint_path and os.path.isfile(checkpoint_path):
  48. # Check if safetensors or not and load weights accordingly
  49. if str(checkpoint_path).endswith(".safetensors"):
  50. assert _has_safetensors, "`pip install safetensors` to use .safetensors"
  51. checkpoint = safetensors.torch.load_file(checkpoint_path, device=device)
  52. else:
  53. try:
  54. checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=weights_only)
  55. except TypeError:
  56. checkpoint = torch.load(checkpoint_path, map_location=device)
  57. state_dict_key = ''
  58. if isinstance(checkpoint, dict):
  59. if use_ema and checkpoint.get('state_dict_ema', None) is not None:
  60. state_dict_key = 'state_dict_ema'
  61. elif use_ema and checkpoint.get('model_ema', None) is not None:
  62. state_dict_key = 'model_ema'
  63. elif 'state_dict' in checkpoint:
  64. state_dict_key = 'state_dict'
  65. elif 'model' in checkpoint:
  66. state_dict_key = 'model'
  67. state_dict = clean_state_dict(checkpoint[state_dict_key] if state_dict_key else checkpoint)
  68. _logger.info("Loaded {} from checkpoint '{}'".format(state_dict_key, checkpoint_path))
  69. return state_dict
  70. else:
  71. _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
  72. raise FileNotFoundError()
  73. def load_checkpoint(
  74. model: torch.nn.Module,
  75. checkpoint_path: str,
  76. use_ema: bool = True,
  77. device: Union[str, torch.device] = 'cpu',
  78. strict: bool = True,
  79. remap: bool = False,
  80. filter_fn: Optional[Callable] = None,
  81. weights_only: bool = False,
  82. ) -> Any:
  83. """Load checkpoint into model.
  84. Args:
  85. model: Model to load checkpoint into.
  86. checkpoint_path: Path to checkpoint file.
  87. use_ema: Whether to use EMA weights if available.
  88. device: Device to load checkpoint to.
  89. strict: Whether to strictly enforce state_dict keys match.
  90. remap: Whether to remap state dict keys by order.
  91. filter_fn: Optional function to filter state dict.
  92. weights_only: Whether to load only weights (torch.load parameter).
  93. Returns:
  94. Incompatible keys from model.load_state_dict().
  95. """
  96. if os.path.splitext(checkpoint_path)[-1].lower() in ('.npz', '.npy'):
  97. # numpy checkpoint, try to load via model specific load_pretrained fn
  98. if hasattr(model, 'load_pretrained'):
  99. model.load_pretrained(checkpoint_path)
  100. else:
  101. raise NotImplementedError('Model cannot load numpy checkpoint')
  102. return
  103. state_dict = load_state_dict(checkpoint_path, use_ema, device=device, weights_only=weights_only)
  104. if remap:
  105. state_dict = remap_state_dict(state_dict, model)
  106. elif filter_fn:
  107. state_dict = filter_fn(state_dict, model)
  108. incompatible_keys = model.load_state_dict(state_dict, strict=strict)
  109. return incompatible_keys
  110. def remap_state_dict(
  111. state_dict: Dict[str, Any],
  112. model: torch.nn.Module,
  113. allow_reshape: bool = True
  114. ) -> Dict[str, Any]:
  115. """Remap checkpoint by iterating over state dicts in order (ignoring original keys).
  116. This assumes models (and originating state dict) were created with params registered in same order.
  117. Args:
  118. state_dict: State dict to remap.
  119. model: Model whose state dict keys to use.
  120. allow_reshape: Whether to allow reshaping tensors to match.
  121. Returns:
  122. Remapped state dictionary.
  123. """
  124. out_dict = {}
  125. for (ka, va), (kb, vb) in zip(model.state_dict().items(), state_dict.items()):
  126. assert va.numel() == vb.numel(), f'Tensor size mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
  127. if va.shape != vb.shape:
  128. if allow_reshape:
  129. vb = vb.reshape(va.shape)
  130. else:
  131. assert False, f'Tensor shape mismatch {ka}: {va.shape} vs {kb}: {vb.shape}. Remap failed.'
  132. out_dict[ka] = vb
  133. return out_dict
  134. def resume_checkpoint(
  135. model: torch.nn.Module,
  136. checkpoint_path: str,
  137. optimizer: Optional[torch.optim.Optimizer] = None,
  138. loss_scaler: Optional[Any] = None,
  139. log_info: bool = True,
  140. ) -> Optional[int]:
  141. """Resume training from checkpoint.
  142. Args:
  143. model: Model to load checkpoint into.
  144. checkpoint_path: Path to checkpoint file.
  145. optimizer: Optional optimizer to restore state.
  146. loss_scaler: Optional AMP loss scaler to restore state.
  147. log_info: Whether to log loading info.
  148. Returns:
  149. Resume epoch number if available, else None.
  150. """
  151. resume_epoch = None
  152. if os.path.isfile(checkpoint_path):
  153. checkpoint = torch.load(checkpoint_path, map_location='cpu', weights_only=False)
  154. if isinstance(checkpoint, dict) and 'state_dict' in checkpoint:
  155. if log_info:
  156. _logger.info('Restoring model state from checkpoint...')
  157. state_dict = clean_state_dict(checkpoint['state_dict'])
  158. model.load_state_dict(state_dict)
  159. if optimizer is not None and 'optimizer' in checkpoint:
  160. if log_info:
  161. _logger.info('Restoring optimizer state from checkpoint...')
  162. optimizer.load_state_dict(checkpoint['optimizer'])
  163. if loss_scaler is not None and loss_scaler.state_dict_key in checkpoint:
  164. if log_info:
  165. _logger.info('Restoring AMP loss scaler state from checkpoint...')
  166. loss_scaler.load_state_dict(checkpoint[loss_scaler.state_dict_key])
  167. if 'epoch' in checkpoint:
  168. resume_epoch = checkpoint['epoch']
  169. if 'version' in checkpoint and checkpoint['version'] > 1:
  170. resume_epoch += 1 # start at the next epoch, old checkpoints incremented before save
  171. if log_info:
  172. _logger.info("Loaded checkpoint '{}' (epoch {})".format(checkpoint_path, checkpoint['epoch']))
  173. else:
  174. model.load_state_dict(checkpoint)
  175. if log_info:
  176. _logger.info("Loaded checkpoint '{}'".format(checkpoint_path))
  177. return resume_epoch
  178. else:
  179. _logger.error("No checkpoint found at '{}'".format(checkpoint_path))
  180. raise FileNotFoundError()