| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import io
- import os
- import re
- import sys
- import time
- from collections import OrderedDict
- from shutil import copytree, ignore_patterns, rmtree
- from typing import Callable, Dict, Optional, Union
- import json
- import torch
- from torch import nn
- from torch.optim import Optimizer
- from torch.optim.lr_scheduler import _LRScheduler
- from modelscope.fileio import File, LocalStorage
- from modelscope.utils.config import Config, JSONIteratorEncoder
- from modelscope.utils.constant import ConfigFields, ModelFile
- from modelscope.utils.file_utils import copytree_py37
- from modelscope.utils.logger import get_logger
- from modelscope.utils.torch_utils import is_master
# Module-level logger used for the warning/info messages emitted by the
# checkpoint helpers in this file.
logger = get_logger()
# Storage backend instance; save_configuration writes the configuration
# file through `storage.write`.
storage = LocalStorage()
def weights_to_cpu(state_dict):
    """Copy a model state_dict to cpu.

    Tensor values are moved with ``.cpu()`` (a no-op for tensors already on
    cpu). Non-tensor values, such as extra state returned by a module's
    ``get_extra_state``, are kept as-is instead of raising ``AttributeError``.

    Args:
        state_dict (OrderedDict): Model weights, on any device.

    Returns:
        OrderedDict: Model weights on CPU, in the same key order. The
        ``_metadata`` attribute of the input, if any, is carried over;
        otherwise an empty ``OrderedDict`` is attached.
    """
    state_dict_cpu = OrderedDict()
    for key, val in state_dict.items():
        state_dict_cpu[key] = val.cpu() if isinstance(val,
                                                      torch.Tensor) else val
    # Keep metadata in state_dict: torch uses it for versioned loading.
    state_dict_cpu._metadata = getattr(state_dict, '_metadata', OrderedDict())
    return state_dict_cpu
def save_checkpoint(model: torch.nn.Module,
                    filename: str,
                    optimizer: Optional[Union[Optimizer,
                                              Dict[str, Optimizer]]] = None,
                    lr_scheduler: Optional[_LRScheduler] = None,
                    meta: Optional[dict] = None,
                    with_meta: bool = True,
                    with_model: bool = True) -> None:
    """Save checkpoint to file.

    With ``with_meta=True`` the checkpoint is a dict holding ``meta`` and,
    when available, ``optimizer``, ``lr_scheduler`` and (if ``with_model``)
    ``state_dict``. ``meta`` always receives the modelscope version and the
    save time, plus ``CLASSES`` when the model defines a non-None one.

    With ``with_meta=False`` the file holds only the model state dict (moved
    to cpu); optimizer and lr_scheduler states are not written in that case.

    Args:
        model (Module): Module whose params are to be saved. A
            ``DistributedDataParallel`` wrapper is unwrapped first.
        filename (str): Checkpoint filename.
        optimizer (:obj:`Optimizer` | dict, optional): Optimizer to be saved,
            or a dict mapping names to optimizers.
        lr_scheduler(:obj:`_LRScheduler`, optional): LRScheduler to be saved.
        meta (dict, optional): Metadata to be saved in checkpoint. The
            caller's dict is copied, not modified.
        with_meta (bool, optional): Save meta info.
        with_model(bool, optional): Save model states.

    Raises:
        ValueError: If both ``with_meta`` and ``with_model`` are False.
        TypeError: If ``meta`` is neither a dict nor None.
    """
    if not with_meta and not with_model:
        raise ValueError(
            'Save meta by "with_meta=True" or model by "with_model=True"')

    # Unwrap DDP once so both the meta lookup and the state dict use the
    # underlying module.
    if isinstance(model, torch.nn.parallel.DistributedDataParallel):
        model = model.module

    checkpoint = {}
    if with_meta:
        if meta is None:
            meta = {}
        elif not isinstance(meta, dict):
            raise TypeError(
                f'meta must be a dict or None, but got {type(meta)}')
        else:
            # Copy so the updates below do not leak into the caller's dict.
            meta = meta.copy()
        from modelscope import __version__
        meta.update(modelscope=__version__, time=time.asctime())
        if getattr(model, 'CLASSES', None) is not None:
            # save class name to the meta
            meta.update(CLASSES=model.CLASSES)
        checkpoint['meta'] = meta

        # save optimizer state dict(s) in the checkpoint
        if isinstance(optimizer, Optimizer):
            checkpoint['optimizer'] = optimizer.state_dict()
        elif isinstance(optimizer, dict):
            checkpoint['optimizer'] = {
                name: optim.state_dict()
                for name, optim in optimizer.items()
            }

        # save lr_scheduler state dict in the checkpoint
        if lr_scheduler is not None and hasattr(lr_scheduler, 'state_dict'):
            checkpoint['lr_scheduler'] = lr_scheduler.state_dict()

    if with_model:
        weights = weights_to_cpu(model.state_dict())
        if with_meta:
            checkpoint['state_dict'] = weights
        else:
            # Without meta the file is the bare state dict itself.
            checkpoint = weights

    # Serialize into memory first, then hand the bytes to File.write.
    with io.BytesIO() as f:
        torch.save(checkpoint, f)
        File.write(f.getvalue(), filename)
def load_checkpoint(filename,
                    model,
                    optimizer: Optimizer = None,
                    lr_scheduler: _LRScheduler = None):
    """Restore model / optimizer / lr_scheduler state from a checkpoint file.

    The file is read with ``torch.load(map_location='cpu',
    weights_only=True)``. Both layouts are accepted: a dict with
    ``state_dict`` / ``optimizer`` / ``lr_scheduler`` / ``meta`` entries, or a
    bare state dict.

    Args:
        filename: Path of the local checkpoint file.
        model: Module that receives the weights, or None to skip them.
        optimizer: An Optimizer, a dict of named optimizers, or None.
        lr_scheduler: Scheduler to restore, or None.

    Returns:
        dict: ``checkpoint['meta']`` if present, otherwise ``{}``.

    Raises:
        ValueError: If ``filename`` does not exist.
    """
    if not os.path.exists(filename):
        raise ValueError(f'Checkpoint file {filename} does not exist!')
    ckpt = torch.load(filename, map_location='cpu', weights_only=True)

    if optimizer is not None:
        if 'optimizer' not in ckpt:
            logger.warning(
                f'The state dict of optimizer cannot be found in checkpoint file: {filename}'
            )
        elif isinstance(optimizer, Optimizer):
            optimizer.load_state_dict(ckpt['optimizer'])
        elif isinstance(optimizer, dict):
            saved_states = ckpt['optimizer']
            for key, opt in optimizer.items():
                if key not in saved_states:
                    logger.warning(
                        f'The state dict of optimizer {key} cannot be found in checkpoint file: {filename}'
                    )
                    continue
                opt.load_state_dict(saved_states[key])

    if lr_scheduler is not None:
        if 'lr_scheduler' not in ckpt:
            logger.warning(
                f'The state dict of lr_scheduler cannot be found in checkpoint file: {filename}'
            )
        else:
            lr_scheduler.load_state_dict(ckpt['lr_scheduler'])

    if model is not None:
        # Fall back to the whole object when it is a bare state dict.
        model.load_state_dict(ckpt.get('state_dict', ckpt))

    return ckpt.get('meta', {})
def load_task_model_checkpoint(model_to_load,
                               model_local_dir,
                               default_dtype=None,
                               load_state_fn=None,
                               **kwargs):
    """
    Load model checkpoint file and feed the parameters into the model.

    Args:
        model_to_load: The model to be loaded. Its ``base_model_prefix`` and
            ``head_prefix`` attributes are read and its ``_init_weights`` is
            called on modules whose weights are missing from the checkpoint.
        model_local_dir: The actual checkpoint dir on local disk; the file
            ``ModelFile.TORCH_MODEL_BIN_FILE`` inside it is loaded.
        default_dtype: If not None, passed to 'torch.set_default_dtype' after
            the state dict has been read (a process-global setting).
        load_state_fn: An optional load_state_fn used to load state_dict into
            the model instead of the default recursive
            ``_load_from_state_dict`` walk. It is called as
            ``load_state_fn(module, state_dict, prefix=...,
            head_prefix_keys=..., local_metadata=None, error_msgs=[...])``
            and may append error strings to ``error_msgs``.
        **kwargs: ``_keys_to_ignore_on_load_missing`` and
            ``_keys_to_ignore_on_load_unexpected`` (iterables of regex
            patterns) filter the reported key lists; a ``head`` object with a
            truthy ``tie_word_embeddings`` attribute triggers weight tying.

    Returns:
        dict: ``model`` (the same object as ``model_to_load``),
        ``missing_keys``, ``unexpected_keys``, ``mismatched_keys`` (list of
        ``(key, checkpoint_shape, model_shape)``) and ``error_msgs``.

    Raises:
        ValueError: If the checkpoint's keys conflict with the model's base
            prefix layout.
        RuntimeError: If loading the state dict reports errors.
    """

    def _add_head_prefix_to_state_dict(state_dicts, head_prefix,
                                       expected_keys_without_head_prefix,
                                       missing_keys):
        """Prepend ``head_prefix`` to checkpoint keys the head expects.

        ``expected_keys_without_head_prefix`` is consumed in place: each key
        found in ``state_dicts`` is removed, so what remains are head keys the
        checkpoint does not provide.

        Returns:
            (new_state_dict, missing_head_keys, missing_keys)
        """
        new_state_dict = OrderedDict()
        renamed_keys = set()
        for name, module in state_dicts.items():
            if name in expected_keys_without_head_prefix:
                name_with_head = '.'.join([head_prefix, name])
                new_state_dict[name_with_head] = module
                expected_keys_without_head_prefix.remove(name)
                renamed_keys.add(name_with_head)
            else:
                new_state_dict[name] = module
        # Keep torch's versioning metadata across the rebuild.
        metadata = getattr(state_dicts, '_metadata', None)
        if metadata is not None:
            new_state_dict._metadata = metadata
        if renamed_keys:
            # Renamed keys are now supplied by the checkpoint; one set
            # difference instead of one per renamed key.
            missing_keys = list(set(missing_keys) - renamed_keys)
        missing_head_keys = expected_keys_without_head_prefix.copy()
        return new_state_dict, missing_head_keys, missing_keys

    def _find_mismatched_keys(
        state_dicts,
        model_state_dict,
        loaded_keys,
        prefix,
        add_prefix_to_model,
        remove_prefix_from_model,
        ignore_mismatched_sizes,
    ):
        """Collect and drop checkpoint entries whose shape differs from the model's.

        Only active when ``ignore_mismatched_sizes`` is True; otherwise an
        empty list is returned and ``state_dicts`` is untouched.

        Returns:
            list of ``(checkpoint_key, checkpoint_shape, model_shape)``.
        """
        mismatched_key = []
        if ignore_mismatched_sizes:
            for checkpoint_key in loaded_keys:
                model_key = checkpoint_key
                if remove_prefix_from_model:
                    # The model key starts with `prefix` but `checkpoint_key` doesn't, so we add it.
                    model_key = f'{prefix}.{checkpoint_key}'
                elif add_prefix_to_model:
                    # The model key doesn't start with `prefix` but `checkpoint_key` does, so we remove it.
                    model_key = '.'.join(checkpoint_key.split('.')[1:])

                # `loaded_keys` is captured before head prefixes may be added
                # to `state_dicts`, so the key might no longer be present.
                if (model_key in model_state_dict
                        and checkpoint_key in state_dicts):
                    model_shape = model_state_dict[model_key].shape
                    checkpoint_shape = state_dicts[checkpoint_key].shape
                    if checkpoint_shape != model_shape:
                        mismatched_key.append(
                            (checkpoint_key, checkpoint_shape, model_shape))
                        del state_dicts[checkpoint_key]
        return mismatched_key

    def _load_state_dict_into_model(
        model,
        state_dict,
        start_prefix,
        head_prefix_keys,
        load_state_fn=None,
    ):
        """Copy ``state_dict`` into ``model`` and return the error messages.

        Keys containing ``gamma``/``beta`` are first renamed in place to
        ``weight``/``bias``. Loading is delegated to ``load_state_fn`` when
        given; otherwise every submodule's ``_load_from_state_dict`` is called
        recursively with strict=True, starting at ``start_prefix``.
        """
        # Convert old format to new format if needed from a PyTorch state_dict
        old_keys = []
        new_keys = []
        for key in state_dict.keys():
            new_key = None
            if 'gamma' in key:
                new_key = key.replace('gamma', 'weight')
            if 'beta' in key:
                new_key = key.replace('beta', 'bias')
            if new_key:
                old_keys.append(key)
                new_keys.append(new_key)
        for old_key, new_key in zip(old_keys, new_keys):
            state_dict[new_key] = state_dict.pop(old_key)

        # copy state_dict so _load_from_state_dict can modify it
        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        error_msgs = []

        if load_state_fn is not None:
            load_state_fn(
                model,
                state_dict,
                prefix=start_prefix,
                head_prefix_keys=head_prefix_keys,
                local_metadata=None,
                error_msgs=error_msgs)
        else:

            def load(module: nn.Module, prefix=''):
                local_metadata = {} if metadata is None else metadata.get(
                    prefix[:-1], {})
                # (state_dict, prefix, local_metadata, strict, missing_keys,
                #  unexpected_keys, error_msgs)
                args = (state_dict, prefix, local_metadata, True, [], [],
                        error_msgs)
                module._load_from_state_dict(*args)

                for name, child in module._modules.items():
                    if child is not None:
                        load(child, prefix + name + '.')

            load(model, prefix=start_prefix)

        return error_msgs

    def _load_checkpoint(
        model,
        state_dict,
        load_state_fn,
        ignore_mismatched_sizes,
        _fast_init,
    ):
        """Reconcile checkpoint keys with the model's keys and load them.

        Handles a base-model prefix present on only one side, head prefixes
        absent from the checkpoint, shape mismatches and the
        ``_keys_to_ignore_on_load_*`` filters taken from the enclosing
        ``kwargs``.

        Returns:
            (missing_keys, unexpected_keys, mismatched_keys, error_msgs)
        """
        # Retrieve missing & unexpected_keys
        model_state_dict = model.state_dict()
        expected_keys = list(model_state_dict.keys())
        keys_from_pretrained = list(state_dict.keys())

        prefix = model.base_model_prefix

        # during loading stage, base model prefix is complicated, should consider remove or add
        if len(prefix) > 0:
            # nlp: encoder, decoder
            pretrained_has_prefix_module = any(
                s.startswith(prefix) for s in keys_from_pretrained)
            model_expects_prefix_module = any(
                s.startswith(prefix) for s in expected_keys)
        else:
            # nlp:encoder-decoder, cv:backbone-head,
            pretrained_has_prefix_module = False
            model_expects_prefix_module = False

        remove_prefix_from_model = not pretrained_has_prefix_module and model_expects_prefix_module
        add_prefix_to_model = pretrained_has_prefix_module and not model_expects_prefix_module

        # Initialised up front: it is read further below even when neither
        # prefix adjustment applies (previously a possible NameError).
        expected_keys_not_base_model_prefixed = []
        if remove_prefix_from_model:
            expected_keys_not_base_model_prefixed = [
                s for s in expected_keys if not s.startswith(prefix)
            ]
            expected_keys = [
                '.'.join(s.split('.')[1:]) if s.startswith(prefix) else s
                for s in expected_keys
            ]
        elif add_prefix_to_model:
            # backbone only
            expected_keys = ['.'.join([prefix, s]) for s in expected_keys]

        missing_keys = list(set(expected_keys) - set(keys_from_pretrained))
        unexpected_keys = list(set(keys_from_pretrained) - set(expected_keys))

        # during loading stage head prefix is simple, add or not add
        prefix_heads = model.head_prefix
        # Normalise to a list before any use: prefix_heads is iterated when
        # choosing heads to load, and iterating a bare string would yield its
        # characters (getattr(model, 'h') ...).
        if prefix_heads is None:
            prefix_heads = []
        elif isinstance(prefix_heads, str):
            prefix_heads = [prefix_heads] if prefix_heads else []

        expected_head_keys_without_head_prefix = []
        missing_head_keys = []
        unexpected_head_keys = []
        pretrained_has_prefix_head = dict()
        head_prefix_keys = dict()

        # only for case of head mismatched with state-dict
        if len(prefix_heads) > 0 and len(unexpected_keys) > 0:
            # to double-check if head matched with state-dict
            for prefix_head in prefix_heads:
                pretrained_has_prefix_head[prefix_head] = any(
                    s.startswith(prefix_head) for s in keys_from_pretrained)

            for prefix_head in prefix_heads:
                expected_keys_without_head_prefix = [
                    '.'.join(s.split('.')[1:]) for s in expected_keys
                    if s.startswith(prefix_head)
                ]
                expected_head_keys_without_head_prefix.extend(
                    expected_keys_without_head_prefix)
                head_prefix_keys[
                    prefix_head] = expected_keys_without_head_prefix
            # NOTE(review): both lists below are computed identically, so the
            # `elif unexpected_head_keys` log branch further down can never be
            # taken; the intended definition of unexpected_head_keys is unclear.
            unexpected_head_keys = list(
                set(unexpected_keys)
                - set(expected_head_keys_without_head_prefix))
            unexpected_keys = list(
                set(unexpected_keys)
                - set(expected_head_keys_without_head_prefix))

        _keys_to_ignore_on_load_missing = kwargs.pop(
            '_keys_to_ignore_on_load_missing', None)
        _keys_to_ignore_on_load_unexpected = kwargs.pop(
            '_keys_to_ignore_on_load_unexpected', None)
        # Some models may have keys that are not in the state by design, removing them before needlessly warning
        # the user.
        if _keys_to_ignore_on_load_missing is not None:
            for pat in _keys_to_ignore_on_load_missing:
                missing_keys = [
                    k for k in missing_keys if re.search(pat, k) is None
                ]

        if _keys_to_ignore_on_load_unexpected is not None:
            for pat in _keys_to_ignore_on_load_unexpected:
                unexpected_keys = [
                    k for k in unexpected_keys if re.search(pat, k) is None
                ]

        # retrieve uninitialized modules and initialize before maybe overriding that with the pretrained weights.
        if _fast_init:
            uninitialized_modules = retrieve_modules_from_names(
                model,
                missing_keys,
                prefix=prefix,
                add_prefix=add_prefix_to_model,
                remove_prefix=remove_prefix_from_model)
            for module in uninitialized_modules:
                model._init_weights(module)

        # Make sure we are able to load head correctly by revise state-dict
        missing_head_keys_by_head = dict()
        if len(head_prefix_keys) > 0:
            for head_prefix in head_prefix_keys:
                if not pretrained_has_prefix_head[head_prefix]:
                    state_dict, missing_head_keys, missing_keys = _add_head_prefix_to_state_dict(
                        state_dict, head_prefix, head_prefix_keys[head_prefix],
                        missing_keys)
                    missing_head_keys_by_head[head_prefix] = missing_head_keys

        # Make sure we are able to load base models as well as derived models (with heads)
        start_prefix = ''
        model_to_load = model
        heads_to_load = dict()
        if len(model.base_model_prefix) > 0 and not hasattr(
                model,
                model.base_model_prefix) and pretrained_has_prefix_module:
            start_prefix = model.base_model_prefix + '.'
        if len(model.base_model_prefix) > 0 and hasattr(
                model,
                model.base_model_prefix) and not pretrained_has_prefix_module:
            model_to_load = getattr(model, model.base_model_prefix)
            for head_prefix in prefix_heads:
                heads_to_load[head_prefix] = getattr(model, head_prefix)
            not_prefixed = set(expected_keys_not_base_model_prefixed)
            if any(key in not_prefixed for key in keys_from_pretrained):
                raise ValueError(
                    'The state dictionary of the model you are trying to load is corrupted. Are you sure it was '
                    'properly saved?')

        # Whole checkpoint
        mismatched_keys = _find_mismatched_keys(
            state_dict,
            model_state_dict,
            keys_from_pretrained,
            prefix,
            add_prefix_to_model,
            remove_prefix_from_model,
            ignore_mismatched_sizes,
        )

        # load_state_fn must go by keyword: positionally it would land in the
        # head_prefix_keys slot and never be used.
        error_msgs = _load_state_dict_into_model(
            model_to_load,
            state_dict,
            start_prefix,
            head_prefix_keys,
            load_state_fn=load_state_fn)
        for head, head_module in heads_to_load.items():
            local_error_msgs = _load_state_dict_into_model(
                head_module,
                state_dict,
                head + '.',
                head_prefix_keys,
                load_state_fn=load_state_fn)
            error_msgs.extend(local_error_msgs)

        if len(error_msgs) > 0:
            error_msg = '\n\t'.join(error_msgs)
            raise RuntimeError(
                f'Error(s) in loading state_dict for {model.__class__.__name__}:\n\t{error_msg}'
            )

        if len(unexpected_keys) > 0:
            logger.warning(
                f'Some weights of the model checkpoint were not used when'
                f' initializing {model.__class__.__name__}: {unexpected_keys}\n- This IS expected if you are'
                f' initializing {model.__class__.__name__} from the checkpoint of a model trained on another task or'
                ' with another architecture (e.g. initializing a BertForTokenClassification model from a'
                ' BertForPreTraining model).\n- This IS NOT expected if you are initializing'
                f' {model.__class__.__name__} from the checkpoint of a model that you expect to be exactly identical'
                ' (initializing a BertForTokenClassification model from a BertForTokenClassification model).'
            )
        elif len(unexpected_head_keys) > 0:
            logger.warning(
                f'Some weights of the model checkpoint were not used when'
                f' initializing {model.__class__.__name__}: {unexpected_head_keys}\n- This IS Not expected if you are'
                f' initializing {model.__class__.__name__} from the checkpoint of a model with a same task while the'
                ' structure is different (e.g. initializing a BertForTokenClassification model from a'
                ' BertForTokenClassification model).')
        else:
            logger.info(
                f'All model checkpoint weights were used when initializing {model.__class__.__name__}.\n'
            )
        if len(missing_keys) > 0:
            logger.warning(
                f'Some weights of {model.__class__.__name__} were not initialized from the model checkpoint'
                f' and are newly initialized: {missing_keys}\nYou should probably'
                ' TRAIN this model on a down-stream task to be able to use it for predictions and inference.'
            )
        elif len(mismatched_keys) == 0:
            logger.info(
                f'All the weights of {model.__class__.__name__} were initialized from the model checkpoint '
                f'If your task is similar to the task the model of the checkpoint'
                f' was trained on, you can already use {model.__class__.__name__} for predictions without further'
                ' training.')
        if len(mismatched_keys) > 0:
            mismatched_warning = '\n'.join([
                f'- {key}: found shape {shape1} in the checkpoint and {shape2} in the model instantiated'
                for key, shape1, shape2 in mismatched_keys
            ])
            logger.warning(
                f'Some weights of {model.__class__.__name__} were not initialized from the model checkpoint'
                f' and are newly initialized because the shapes did not'
                f' match:\n{mismatched_warning}\nYou should probably TRAIN this model on a down-stream task to be able'
                ' to use it for predictions and inference.')

        return missing_keys, unexpected_keys, mismatched_keys, error_msgs

    def retrieve_modules_from_names(model,
                                    names,
                                    prefix=None,
                                    add_prefix=False,
                                    remove_prefix=False):
        """Return the submodules of ``model`` owning at least one of ``names``.

        Module names are adjusted by removing or adding ``prefix`` so they are
        comparable with the (possibly re-prefixed) parameter names.
        """
        module_keys = {'.'.join(key.split('.')[:-1]) for key in names}

        # torch.nn.ParameterList is a special case where two parameter keywords
        # are appended to the module name, *e.g.* bert.special_embeddings.0
        module_keys = module_keys.union({
            '.'.join(key.split('.')[:-2])
            for key in names if key[-1].isdigit()
        })

        retrieved_modules = []
        # retrieve all modules that has at least one missing weight name
        for name, module in model.named_modules():
            if remove_prefix:
                name = '.'.join(
                    name.split('.')[1:]) if name.startswith(prefix) else name
            elif add_prefix:
                name = '.'.join([prefix, name]) if len(name) > 0 else prefix

            if name in module_keys:
                retrieved_modules.append(module)

        return retrieved_modules

    def _tie_or_clone_weights(output_embeddings,
                              input_embeddings,
                              torchscript=False):
        """Share (or, for torchscript, clone) the input embedding weight into
        the output embedding, padding the output bias to the new size."""
        if torchscript:
            output_embeddings.weight = nn.Parameter(
                input_embeddings.weight.clone())
        else:
            output_embeddings.weight = input_embeddings.weight

        if getattr(output_embeddings, 'bias', None) is not None:
            output_embeddings.bias.data = nn.functional.pad(
                output_embeddings.bias.data,
                (
                    0,
                    output_embeddings.weight.shape[0]
                    - output_embeddings.bias.shape[0],
                ),
                'constant',
                0,
            )
        if hasattr(output_embeddings, 'out_features') and hasattr(
                input_embeddings, 'num_embeddings'):
            output_embeddings.out_features = input_embeddings.num_embeddings

    def tie_weights(model, tie_word_embeddings=False):
        """Tie ``model.head``'s output embeddings to ``model.encoder``'s input
        embeddings when ``tie_word_embeddings`` is truthy."""
        if tie_word_embeddings:
            output_embeddings = model.head.get_output_embeddings()
            if output_embeddings is not None:
                input_embeddings = model.encoder.get_input_embeddings()
                _tie_or_clone_weights(output_embeddings, input_embeddings)

    # TODO Sharded ckpt
    ckpt_file = os.path.join(model_local_dir, ModelFile.TORCH_MODEL_BIN_FILE)
    state_dict = torch.load(ckpt_file, map_location='cpu', weights_only=True)
    if default_dtype is not None:
        torch.set_default_dtype(default_dtype)

    missing_keys, unexpected_keys, mismatched_keys, error_msgs = _load_checkpoint(
        model_to_load,
        state_dict,
        load_state_fn=load_state_fn,
        ignore_mismatched_sizes=True,
        _fast_init=True,
    )

    if getattr(kwargs.get('head'), 'tie_word_embeddings', False):
        tie_weights(model_to_load, kwargs.get('head').tie_word_embeddings)

    return {
        'model': model_to_load,
        'missing_keys': missing_keys,
        'unexpected_keys': unexpected_keys,
        'mismatched_keys': mismatched_keys,
        'error_msgs': error_msgs,
    }
def save_configuration(target_folder, config: Dict):
    """Write ``config`` as indented JSON to the configuration file in
    ``target_folder``.

    Args:
        target_folder: Directory in which ``ModelFile.CONFIGURATION`` is
            written (via ``storage.write``).
        config: A plain dict or a ``Config`` (converted with ``to_dict()``).
            If it has no pipeline entry, one is derived from the task entry
            as ``{'type': <task>}``. A plain dict passed in is not modified.

    Raises:
        KeyError: If neither a pipeline nor a task entry is present.
    """
    if isinstance(config, Config):
        config = config.to_dict()
    else:
        # Shallow copy so adding the pipeline entry does not mutate the
        # caller's dict.
        config = dict(config)
    if ConfigFields.pipeline not in config:
        config[ConfigFields.pipeline] = {'type': config[ConfigFields.task]}
    cfg_str = json.dumps(config, indent=4, cls=JSONIteratorEncoder)
    config_file = os.path.join(target_folder, ModelFile.CONFIGURATION)
    storage.write(cfg_str.encode(), config_file)
def save_pretrained(model,
                    target_folder: Union[str, os.PathLike],
                    save_checkpoint_name: str = None,
                    save_function: Callable = None,
                    **kwargs):
    """save the pretrained model, its configuration and other related files to a directory, so that it can be re-loaded

    On the master process, files from ``model.model_dir`` (when set) are
    copied into ``target_folder``. The copy skips the checkpoint name, the
    configuration file, ``*.safetensors`` and hidden files. Then
    ``save_function(model, <target_folder>/<save_checkpoint_name>, **kwargs)``
    is called.

    Args:
        model (Model): Model whose params are to be saved.
        target_folder (Union[str, os.PathLike]):
            Directory to which to save. Will be created if it doesn't exist.
        save_checkpoint_name (str):
            The checkpoint name to be saved in the target_folder
        save_function (Callable):
            The function to use to save the state dictionary.

    Raises:
        Exception: If ``save_function`` is missing or not callable, if
            ``save_checkpoint_name`` is None, or if ``save_function`` raises.
            In the last case the original error is chained.
        ValueError: If ``target_folder`` is None or an existing file.
    """
    if save_function is None or not callable(save_function):
        raise Exception('A valid save function must be passed in')

    if target_folder is None or os.path.isfile(target_folder):
        raise ValueError(
            f'Provided path ({target_folder}) should be a directory, not a file'
        )

    if save_checkpoint_name is None:
        raise Exception(
            'At least pass in one checkpoint name for saving method')

    # Honour the documented contract even when nothing is copied below
    # (no model_dir, or not the master process).
    os.makedirs(target_folder, exist_ok=True)

    # Single ckpt path, sharded ckpt logic will be added later
    output_ckpt_path = os.path.join(target_folder, save_checkpoint_name)

    # Save the files to be copied to the save directory, ignore the original ckpts and configuration
    ignore_file_set = {
        save_checkpoint_name, ModelFile.CONFIGURATION, '*.safetensors', '.*'
    }

    if getattr(model, 'model_dir', None) is not None and is_master():
        # shutil.copytree gained `dirs_exist_ok` in Python 3.8.
        copytree_func = copytree if sys.version_info >= (3,
                                                         8) else copytree_py37
        copytree_func(
            model.model_dir,
            target_folder,
            ignore=ignore_patterns(*ignore_file_set),
            dirs_exist_ok=True)

    # Save the ckpt to the save directory
    try:
        save_function(model, output_ckpt_path, **kwargs)
    except Exception as e:
        raise Exception(
            f'During saving checkpoints, the error of "{type(e).__name__}" '
            f'with msg {e} thrown') from e
|