builder.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Any, Dict, List, Optional, Union
  4. from modelscope.hub.snapshot_download import snapshot_download
  5. from modelscope.metainfo import DEFAULT_MODEL_FOR_PIPELINE
  6. from modelscope.models.base import Model
  7. from modelscope.utils.config import ConfigDict, check_config
  8. from modelscope.utils.constant import (DEFAULT_MODEL_REVISION, Invoke, Tasks,
  9. ThirdParty)
  10. from modelscope.utils.hub import read_config
  11. from modelscope.utils.import_utils import is_transformers_available
  12. from modelscope.utils.logger import get_logger
  13. from modelscope.utils.plugins import (register_modelhub_repo,
  14. register_plugins_repo)
  15. from modelscope.utils.registry import Registry, build_from_cfg
  16. from modelscope.utils.task_utils import is_embedding_task
  17. from .base import Pipeline
  18. from .util import is_official_hub_path
# Registry of pipeline classes. build_pipeline() resolves configs against it
# and get_default_pipeline_info() reads its per-task ``modules`` mapping.
PIPELINES = Registry('pipelines')

# Module-level logger shared by the functions in this file.
logger = get_logger()
  21. def normalize_model_input(model,
  22. model_revision,
  23. third_party=None,
  24. ignore_file_pattern=None):
  25. """ normalize the input model, to ensure that a model str is a valid local path: in other words,
  26. for model represented by a model id, the model shall be downloaded locally
  27. """
  28. if isinstance(model, str) and is_official_hub_path(model, model_revision):
  29. # skip revision download if model is a local directory
  30. if not os.path.exists(model):
  31. # note that if there is already a local copy, snapshot_download will check and skip downloading
  32. user_agent = {Invoke.KEY: Invoke.PIPELINE}
  33. if third_party is not None:
  34. user_agent[ThirdParty.KEY] = third_party
  35. model = snapshot_download(
  36. model,
  37. revision=model_revision,
  38. user_agent=user_agent,
  39. ignore_file_pattern=ignore_file_pattern)
  40. elif isinstance(model, list) and isinstance(model[0], str):
  41. for idx in range(len(model)):
  42. if is_official_hub_path(
  43. model[idx],
  44. model_revision) and not os.path.exists(model[idx]):
  45. user_agent = {Invoke.KEY: Invoke.PIPELINE}
  46. if third_party is not None:
  47. user_agent[ThirdParty.KEY] = third_party
  48. model[idx] = snapshot_download(
  49. model[idx], revision=model_revision, user_agent=user_agent)
  50. return model
  51. def build_pipeline(cfg: ConfigDict,
  52. task_name: str = None,
  53. default_args: dict = None):
  54. """ build pipeline given model config dict.
  55. Args:
  56. cfg (:obj:`ConfigDict`): config dict for model object.
  57. task_name (str, optional): task name, refer to
  58. :obj:`Tasks` for more details.
  59. default_args (dict, optional): Default initialization arguments.
  60. """
  61. return build_from_cfg(
  62. cfg, PIPELINES, group_key=task_name, default_args=default_args)
  63. def pipeline(task: str = None,
  64. model: Union[str, List[str], Model, List[Model]] = None,
  65. preprocessor=None,
  66. config_file: str = None,
  67. pipeline_name: str = None,
  68. framework: str = None,
  69. device: str = None,
  70. model_revision: Optional[str] = DEFAULT_MODEL_REVISION,
  71. ignore_file_pattern: List[str] = None,
  72. **kwargs) -> Pipeline:
  73. """ Factory method to build an obj:`Pipeline`.
  74. Args:
  75. task (str): Task name defining which pipeline will be returned.
  76. model (str or List[str] or obj:`Model` or obj:list[`Model`]): (list of) model name or model object.
  77. preprocessor: preprocessor object.
  78. config_file (str, optional): path to config file.
  79. pipeline_name (str, optional): pipeline class name or alias name.
  80. framework (str, optional): framework type.
  81. model_revision: revision of model(s) if getting from model hub, for multiple models, expecting
  82. all models to have the same revision
  83. device (str, optional): whether to use gpu or cpu is used to do inference.
  84. ignore_file_pattern(`str` or `List`, *optional*, default to `None`):
  85. Any file pattern to be ignored in downloading, like exact file names or file extensions.
  86. Return:
  87. pipeline (obj:`Pipeline`): pipeline object for certain task.
  88. Examples:
  89. >>> # Using default model for a task
  90. >>> p = pipeline('image-classification')
  91. >>> # Using pipeline with a model name
  92. >>> p = pipeline('text-classification', model='damo/distilbert-base-uncased')
  93. >>> # Using pipeline with a model object
  94. >>> resnet = Model.from_pretrained('Resnet')
  95. >>> p = pipeline('image-classification', model=resnet)
  96. >>> # Using pipeline with a list of model names
  97. >>> p = pipeline('audio-kws', model=['damo/audio-tts', 'damo/auto-tts2'])
  98. """
  99. if task is None and pipeline_name is None:
  100. raise ValueError('task or pipeline_name is required')
  101. pipeline_props = None
  102. if pipeline_name is None:
  103. # get default pipeline for this task
  104. if isinstance(model, str) \
  105. or (isinstance(model, list) and isinstance(model[0], str)):
  106. if is_official_hub_path(model, revision=model_revision):
  107. # read config file from hub and parse
  108. cfg = read_config(
  109. model, revision=model_revision) if isinstance(
  110. model, str) else read_config(
  111. model[0], revision=model_revision)
  112. if cfg:
  113. pipeline_name = cfg.safe_get('pipeline',
  114. {}).get('type', None)
  115. if pipeline_name is None:
  116. prefer_llm_pipeline = kwargs.get('external_engine_for_llm')
  117. # if not specified in both args and configuration.json, prefer llm pipeline for aforementioned tasks
  118. if task is not None and task.lower() in [
  119. Tasks.text_generation, Tasks.chat
  120. ]:
  121. if prefer_llm_pipeline is None:
  122. prefer_llm_pipeline = True
  123. # for llm pipeline, if llm_framework is not specified, default to swift instead
  124. # TODO: port the swift infer based on transformer into ModelScope
  125. if prefer_llm_pipeline:
  126. if kwargs.get('llm_framework') is None:
  127. kwargs['llm_framework'] = 'swift'
  128. pipeline_name = external_engine_for_llm_checker(
  129. model, model_revision, kwargs)
  130. if pipeline_name is None or pipeline_name != 'llm':
  131. third_party = kwargs.get(ThirdParty.KEY)
  132. if third_party is not None:
  133. kwargs.pop(ThirdParty.KEY)
  134. model = normalize_model_input(
  135. model,
  136. model_revision,
  137. third_party=third_party,
  138. ignore_file_pattern=ignore_file_pattern)
  139. register_plugins_repo(cfg.safe_get('plugins'))
  140. register_modelhub_repo(model,
  141. cfg.get('allow_remote', False))
  142. if pipeline_name:
  143. pipeline_props = {'type': pipeline_name}
  144. else:
  145. try:
  146. check_config(cfg)
  147. pipeline_props = cfg.pipeline
  148. except AssertionError as e:
  149. logger.info(str(e))
  150. elif model is not None:
  151. # get pipeline info from Model object
  152. first_model = model[0] if isinstance(model, list) else model
  153. if not hasattr(first_model, 'pipeline'):
  154. # model is instantiated by user, we should parse config again
  155. cfg = read_config(first_model.model_dir)
  156. try:
  157. check_config(cfg)
  158. first_model.pipeline = cfg.pipeline
  159. except AssertionError as e:
  160. logger.info(str(e))
  161. if first_model.__dict__.get('pipeline'):
  162. pipeline_props = first_model.pipeline
  163. else:
  164. pipeline_name, default_model_repo = get_default_pipeline_info(task)
  165. model = normalize_model_input(default_model_repo, model_revision)
  166. pipeline_props = {'type': pipeline_name}
  167. else:
  168. pipeline_props = {'type': pipeline_name}
  169. if not pipeline_props and is_embedding_task(task):
  170. try:
  171. from modelscope.utils.hf_util import sentence_transformers_pipeline
  172. return sentence_transformers_pipeline(model=model, **kwargs)
  173. except Exception:
  174. logger.exception(
  175. 'We could not find a suitable pipeline from modelscope, so we tried to load it using the '
  176. 'sentence_transformers, but that also failed.')
  177. raise
  178. if not pipeline_props and is_transformers_available():
  179. try:
  180. from modelscope.utils.hf_util import hf_pipeline
  181. return hf_pipeline(
  182. task=task,
  183. model=model,
  184. framework=framework,
  185. device=device,
  186. **kwargs)
  187. except Exception as e:
  188. logger.error(
  189. 'We couldn\'t find a suitable pipeline from ms, so we tried to load it using the transformers pipeline,'
  190. ' but that also failed.')
  191. raise e
  192. if not device:
  193. device = 'gpu'
  194. pipeline_props['model'] = model
  195. pipeline_props['device'] = device
  196. cfg = ConfigDict(pipeline_props)
  197. clear_llm_info(kwargs, pipeline_name)
  198. if kwargs:
  199. cfg.update(kwargs)
  200. if preprocessor is not None:
  201. cfg.preprocessor = preprocessor
  202. return build_pipeline(cfg, task_name=task)
  203. def add_default_pipeline_info(task: str,
  204. model_name: str,
  205. modelhub_name: str = None,
  206. overwrite: bool = False):
  207. """ Add default model for a task.
  208. Args:
  209. task (str): task name.
  210. model_name (str): model_name.
  211. modelhub_name (str): name for default modelhub.
  212. overwrite (bool): overwrite default info.
  213. """
  214. if not overwrite:
  215. assert task not in DEFAULT_MODEL_FOR_PIPELINE, \
  216. f'task {task} already has default model.'
  217. DEFAULT_MODEL_FOR_PIPELINE[task] = (model_name, modelhub_name)
  218. def get_default_pipeline_info(task):
  219. """ Get default info for certain task.
  220. Args:
  221. task (str): task name.
  222. Return:
  223. A tuple: first element is pipeline name(model_name), second element
  224. is modelhub name.
  225. """
  226. if task not in DEFAULT_MODEL_FOR_PIPELINE:
  227. # support pipeline which does not register default model
  228. pipeline_name = list(PIPELINES.modules[task].keys())[0]
  229. default_model = None
  230. else:
  231. pipeline_name, default_model = DEFAULT_MODEL_FOR_PIPELINE[task]
  232. return pipeline_name, default_model
  233. def external_engine_for_llm_checker(model: Union[str, List[str], Model,
  234. List[Model]],
  235. revision: Optional[str],
  236. kwargs: Dict[str, Any]) -> Optional[str]:
  237. from .nlp.llm_pipeline import ModelTypeHelper, LLMAdapterRegistry
  238. from ..hub.check_model import get_model_id_from_cache
  239. if isinstance(model, list):
  240. model = model[0]
  241. if not isinstance(model, str):
  242. model = model.model_dir
  243. llm_framework = kwargs.get('llm_framework', '')
  244. if llm_framework == 'swift':
  245. from swift.llm import get_model_info_meta
  246. # check if swift supports
  247. if os.path.exists(model):
  248. model_id = get_model_id_from_cache(model)
  249. else:
  250. model_id = model
  251. try:
  252. info = get_model_info_meta(model_id)
  253. model_type = info[0].model_type
  254. except Exception as e:
  255. logger.warning(f'Cannot using llm_framework with {model_id}, '
  256. f'ignoring llm_framework={llm_framework} : {e}')
  257. model_type = None
  258. if model_type:
  259. return 'llm'
  260. model_type = ModelTypeHelper.get(
  261. model, revision, with_adapter=True, split='-', use_cache=True)
  262. if LLMAdapterRegistry.contains(model_type):
  263. return 'llm'
  264. def clear_llm_info(kwargs: Dict, pipeline_name: str):
  265. from modelscope.utils.model_type_helper import ModelTypeHelper
  266. kwargs.pop('external_engine_for_llm', None)
  267. if pipeline_name != 'llm':
  268. kwargs.pop('llm_framework', None)
  269. ModelTypeHelper.clear_cache()