# base_model.py
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import os.path as osp
  4. from abc import ABC, abstractmethod
  5. from typing import Any, Dict, List, Optional, Union
  6. from modelscope.hub.snapshot_download import snapshot_download
  7. from modelscope.metainfo import Tasks
  8. from modelscope.models.builder import build_backbone, build_model
  9. from modelscope.utils.automodel_utils import (can_load_by_ms,
  10. check_model_from_owner_group,
  11. try_to_load_hf_model)
  12. from modelscope.utils.config import Config, ConfigDict
  13. from modelscope.utils.constant import DEFAULT_MODEL_REVISION, Invoke, ModelFile
  14. from modelscope.utils.device import verify_device
  15. from modelscope.utils.logger import get_logger
  16. from modelscope.utils.plugins import (register_modelhub_repo,
  17. register_plugins_repo)
# Module-level logger shared by everything in this file.
logger = get_logger()

# Alias for "a tensor from either framework". Both members are string forward
# references, so the alias itself does not require torch or tensorflow to be
# importable when this module is loaded.
Tensor = Union['torch.Tensor', 'tf.Tensor']
  20. class Model(ABC):
  21. """Base model interface.
  22. """
  23. def __init__(self, model_dir, *args, **kwargs):
  24. self.model_dir = model_dir
  25. device_name = kwargs.get('device', 'gpu')
  26. verify_device(device_name)
  27. self._device_name = device_name
  28. self.trust_remote_code = kwargs.get('trust_remote_code', False)
  29. def __call__(self, *args, **kwargs) -> Dict[str, Any]:
  30. return self.postprocess(self.forward(*args, **kwargs))
  31. def check_trust_remote_code(self,
  32. info_str: Optional[str] = None,
  33. model_dir: Optional[str] = None):
  34. """Check trust_remote_code if the model needs to import extra libs
  35. Args:
  36. info_str(str): The info showed to user if trust_remote_code is `False`.
  37. model_dir(`Optional[str]`): The local model directory. If is a trusted model, check remote code will pass.
  38. """
  39. info_str = info_str or (
  40. 'This model requires `trust_remote_code` to be `True` because it needs to '
  41. 'import extra libs or execute the code in the model repo, setting this to true '
  42. 'means you trust the files in it.')
  43. if not check_model_from_owner_group(model_dir=model_dir):
  44. assert self.trust_remote_code, info_str
  45. @abstractmethod
  46. def forward(self, *args, **kwargs) -> Dict[str, Any]:
  47. """
  48. Run the forward pass for a model.
  49. Returns:
  50. Dict[str, Any]: output from the model forward pass
  51. """
  52. pass
  53. def postprocess(self, inputs: Dict[str, Any], **kwargs) -> Dict[str, Any]:
  54. """ Model specific postprocess and convert model output to
  55. standard model outputs.
  56. Args:
  57. inputs: input data
  58. Return:
  59. dict of results: a dict containing outputs of model, each
  60. output should have the standard output name.
  61. """
  62. return inputs
  63. @classmethod
  64. def _instantiate(cls, **kwargs):
  65. """ Define the instantiation method of a model,default method is by
  66. calling the constructor. Note that in the case of no loading model
  67. process in constructor of a task model, a load_model method is
  68. added, and thus this method is overloaded
  69. """
  70. return cls(**kwargs)
  71. @classmethod
  72. def from_pretrained(cls,
  73. model_name_or_path: str,
  74. revision: Optional[str] = DEFAULT_MODEL_REVISION,
  75. cfg_dict: Config = None,
  76. device: str = None,
  77. trust_remote_code: Optional[bool] = False,
  78. **kwargs):
  79. """Instantiate a model from local directory or remote model repo. Note
  80. that when loading from remote, the model revision can be specified.
  81. Args:
  82. model_name_or_path(str): A model dir or a model id to be loaded
  83. revision(str, `optional`): The revision used when the model_name_or_path is
  84. a model id of the remote hub. default `master`.
  85. cfg_dict(Config, `optional`): An optional model config. If provided, it will replace
  86. the config read out of the `model_name_or_path`
  87. device(str, `optional`): The device to load the model.
  88. trust_remote_code(bool, `optional`): Whether to trust and allow execution of remote code. Default is False.
  89. **kwargs:
  90. task(str, `optional`): The `Tasks` enumeration value to replace the task value
  91. read out of config in the `model_name_or_path`. This is useful when the model to be loaded is not
  92. equal to the model saved.
  93. For example, load a `backbone` into a `text-classification` model.
  94. Other kwargs will be directly fed into the `model` key, to replace the default configs.
  95. use_hf(bool, `optional`):
  96. If set to True, it will initialize the model using AutoModel or AutoModelFor* from hf.
  97. If set to False, the model is loaded using the modelscope mode.
  98. If set to None, the loading mode will be automatically selected.
  99. ignore_file_pattern(List[str], `optional`):
  100. This parameter is passed to snapshot_download
  101. device_map(str | Dict[str, str], `optional`):
  102. This parameter is passed to AutoModel or AutoModelFor*
  103. torch_dtype(torch.dtype, `optional`):
  104. This parameter is passed to AutoModel or AutoModelFor*
  105. config(PretrainedConfig, `optional`):
  106. This parameter is passed to AutoModel or AutoModelFor*
  107. Returns:
  108. A model instance.
  109. Examples:
  110. >>> from modelscope.models import Model
  111. >>> Model.from_pretrained('damo/nlp_structbert_backbone_base_std', task='text-classification')
  112. """
  113. prefetched = kwargs.get('model_prefetched')
  114. if prefetched is not None:
  115. kwargs.pop('model_prefetched')
  116. invoked_by = kwargs.get(Invoke.KEY)
  117. if invoked_by is not None:
  118. kwargs.pop(Invoke.KEY)
  119. else:
  120. invoked_by = Invoke.PRETRAINED
  121. ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)
  122. if osp.exists(model_name_or_path):
  123. local_model_dir = model_name_or_path
  124. else:
  125. if prefetched is True:
  126. raise RuntimeError(
  127. 'Expecting model is pre-fetched locally, but is not found.'
  128. )
  129. invoked_by = '%s/%s' % (Invoke.KEY, invoked_by)
  130. local_model_dir = snapshot_download(
  131. model_name_or_path,
  132. revision,
  133. user_agent=invoked_by,
  134. ignore_file_pattern=ignore_file_pattern)
  135. logger.info(f'initialize model from {local_model_dir}')
  136. configuration_path = osp.join(local_model_dir, ModelFile.CONFIGURATION)
  137. cfg = None
  138. if cfg_dict is not None:
  139. cfg = cfg_dict
  140. elif os.path.exists(configuration_path):
  141. cfg = Config.from_file(configuration_path)
  142. task_name = getattr(cfg, 'task', None)
  143. if 'task' in kwargs:
  144. task_name = kwargs.pop('task')
  145. model_cfg = getattr(cfg, 'model', ConfigDict())
  146. if hasattr(model_cfg, 'model_type') and not hasattr(model_cfg, 'type'):
  147. model_cfg.type = model_cfg.model_type
  148. model_type = getattr(model_cfg, 'type', None)
  149. if isinstance(device, str) and device.startswith('gpu'):
  150. device = 'cuda' + device[3:]
  151. use_hf = kwargs.pop('use_hf', None)
  152. if use_hf is None and can_load_by_ms(local_model_dir, task_name,
  153. model_type):
  154. use_hf = False
  155. model = None
  156. if use_hf in {True, None}:
  157. model = try_to_load_hf_model(local_model_dir, task_name, use_hf,
  158. **kwargs)
  159. if model is not None:
  160. device_map = kwargs.pop('device_map', None)
  161. if device_map is None and device is not None:
  162. model = model.to(device)
  163. return model
  164. # use ms
  165. if cfg is None:
  166. raise FileNotFoundError(
  167. f'`{ModelFile.CONFIGURATION}` file not found.')
  168. model_cfg.model_dir = local_model_dir
  169. # Security check: Only allow execution of remote code or plugins if trust_remote_code is True
  170. plugins = cfg.safe_get('plugins')
  171. if plugins and not trust_remote_code:
  172. raise RuntimeError(
  173. 'Detected plugins field in the model configuration file, but '
  174. 'trust_remote_code=True was not explicitly set.\n'
  175. 'To prevent potential execution of malicious code, loading has been refused.\n'
  176. 'If you trust this model repository, please pass trust_remote_code=True to from_pretrained.'
  177. )
  178. if plugins and trust_remote_code:
  179. logger.warning(
  180. 'Use trust_remote_code=True. Will invoke codes or install plugins from remote model repo. '
  181. 'Please make sure that you can trust the external codes.')
  182. register_modelhub_repo(local_model_dir, allow_remote=trust_remote_code)
  183. default_args = {}
  184. if trust_remote_code:
  185. default_args = {'trust_remote_code': trust_remote_code}
  186. register_plugins_repo(plugins)
  187. for k, v in kwargs.items():
  188. model_cfg[k] = v
  189. if device is not None:
  190. model_cfg.device = device
  191. if task_name is Tasks.backbone:
  192. model_cfg.init_backbone = True
  193. model = build_backbone(model_cfg)
  194. else:
  195. model = build_model(
  196. model_cfg, task_name=task_name, default_args=default_args)
  197. # dynamically add pipeline info to model for pipeline inference
  198. if hasattr(cfg, 'pipeline'):
  199. model.pipeline = cfg.pipeline
  200. if not hasattr(model, 'cfg'):
  201. model.cfg = cfg
  202. model_cfg.pop('model_dir', None)
  203. model.name = model_name_or_path
  204. model.model_dir = local_model_dir
  205. return model
  206. def save_pretrained(self,
  207. target_folder: Union[str, os.PathLike],
  208. save_checkpoint_names: Union[str, List[str]] = None,
  209. config: Optional[dict] = None,
  210. **kwargs):
  211. """save the pretrained model, its configuration and other related files to a directory,
  212. so that it can be re-loaded
  213. Args:
  214. target_folder (Union[str, os.PathLike]):
  215. Directory to which to save. Will be created if it doesn't exist.
  216. save_checkpoint_names (Union[str, List[str]]):
  217. The checkpoint names to be saved in the target_folder
  218. config (Optional[dict], optional):
  219. The config for the configuration.json, might not be identical with model.config
  220. """
  221. raise NotImplementedError(
  222. 'save_pretrained method need to be implemented by the subclass.')