base.py 24 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import os.path as osp
  4. import random
  5. from abc import ABC, abstractmethod
  6. from functools import partial
  7. from multiprocessing import Pool
  8. from threading import Lock
  9. from typing import Any, Dict, Generator, List, Mapping, Optional, Union
  10. import numpy as np
  11. from packaging import version
  12. from modelscope.models.base import Model
  13. from modelscope.msdatasets import MsDataset
  14. from modelscope.outputs import TASK_OUTPUTS, ModelOutputBase
  15. from modelscope.pipeline_inputs import TASK_INPUTS, check_input_type
  16. from modelscope.preprocessors import Preprocessor
  17. from modelscope.utils.config import Config
  18. from modelscope.utils.constant import Frameworks, Invoke, ModelFile
  19. from modelscope.utils.device import (create_device, device_placement,
  20. verify_device)
  21. from modelscope.utils.hub import read_config, snapshot_download
  22. from modelscope.utils.import_utils import is_tf_available, is_torch_available
  23. from modelscope.utils.logger import get_logger
  24. from modelscope.utils.torch_utils import compile_model
  25. from ..utils.automodel_utils import check_model_from_owner_group
  26. from .util import is_model, is_official_hub_path
  27. if is_torch_available():
  28. import torch
  29. if is_tf_available():
  30. pass
  31. Tensor = Union['torch.Tensor', 'tf.Tensor']
  32. Input = Union[str, tuple, MsDataset, 'Image.Image', 'numpy.ndarray']
  33. InputModel = Union[str, Model, 'torch.nn.Module']
  34. logger = get_logger()
  35. class Pipeline(ABC):
  36. """Pipeline base.
  37. """
  38. def initiate_single_model(self, model, **kwargs):
  39. if self.trust_remote_code:
  40. kwargs['trust_remote_code'] = True
  41. if isinstance(model, str):
  42. logger.info(f'initiate model from {model}')
  43. if isinstance(model, str) and is_official_hub_path(model):
  44. logger.info(f'initiate model from location {model}.')
  45. # expecting model has been prefetched to local cache beforehand
  46. return Model.from_pretrained(
  47. model,
  48. device=self.device_name,
  49. model_prefetched=True,
  50. invoked_by=Invoke.PIPELINE,
  51. device_map=self.device_map,
  52. **kwargs) if is_model(model) else model
  53. else:
  54. return model
  55. def initiate_multiple_models(self, input_models: List[InputModel]):
  56. models = []
  57. for model in input_models:
  58. models.append(self.initiate_single_model(model))
  59. return models
  60. def __init__(self,
  61. config_file: str = None,
  62. model: Union[InputModel, List[InputModel]] = None,
  63. preprocessor: Union[Preprocessor, List[Preprocessor]] = None,
  64. device: str = 'gpu',
  65. auto_collate=True,
  66. device_map=None,
  67. **kwargs):
  68. """ Base class for pipeline.
  69. If config_file is provided, model and preprocessor will be
  70. instantiated from corresponding config. Otherwise, model
  71. and preprocessor will be constructed separately.
  72. Args:
  73. config_file(str, optional): Filepath to configuration file.
  74. model: (list of) Model name or model object
  75. preprocessor: (list of) Preprocessor object
  76. device (str): device str, should be either cpu, cuda, gpu, gpu:X or cuda:X
  77. auto_collate (bool): automatically to convert data to tensor or not.
  78. compile (bool, optional): Compile the model with torch 2.0, default False
  79. compile_options (dict, optional): The compile options if compile=True,
  80. default None to use the default params of 'TorchModel.compile'.
  81. """
  82. if device_map is not None:
  83. assert device == 'gpu', '`device` and `device_map` cannot be input at the same time!'
  84. self.device_map = device_map
  85. verify_device(device)
  86. self.device_name = device
  87. self.trust_remote_code = kwargs.get('trust_remote_code', False)
  88. if not isinstance(model, List):
  89. self.model = self.initiate_single_model(model, **kwargs)
  90. self.models = [self.model]
  91. else:
  92. self.model = None
  93. self.models = self.initiate_multiple_models(model)
  94. self.has_multiple_models = len(self.models) > 1
  95. if config_file is not None:
  96. self.cfg = Config.from_file(config_file)
  97. model_dir = os.path.dirname(config_file)
  98. elif not self.has_multiple_models:
  99. if isinstance(self.model, str):
  100. model_dir = self.model
  101. else:
  102. model_dir = self.model.model_dir
  103. self.cfg = read_config(model_dir)
  104. if preprocessor is None and not self.has_multiple_models:
  105. self.preprocessor = Preprocessor.from_pretrained(model_dir)
  106. else:
  107. self.preprocessor = preprocessor
  108. if self.model or (self.has_multiple_models and self.models[0]):
  109. self.framework = self._get_framework()
  110. else:
  111. self.framework = None
  112. if self.framework == Frameworks.torch:
  113. self.device = create_device(self.device_name)
  114. self._model_prepare = False
  115. self._model_prepare_lock = Lock()
  116. self._auto_collate = auto_collate
  117. self._compile = kwargs.get('compile', False)
  118. self._compile_options = kwargs.get('compile_options', {})
  119. def check_trust_remote_code(self,
  120. info_str: Optional[str] = None,
  121. model_dir: Optional[str] = None):
  122. """Check trust_remote_code if the pipeline needs to import extra libs
  123. Args:
  124. info_str(str): The info showed to user if trust_remote_code is `False`.
  125. model_dir(`Optional[str]`): The local model directory. If is a trusted model, check remote code will pass.
  126. """
  127. info_str = info_str or (
  128. 'This pipeline requires `trust_remote_code` to be `True` because it needs to '
  129. 'import extra libs or execute the code in the model repo, setting this to true '
  130. 'means you trust the files in it.')
  131. if not check_model_from_owner_group(model_dir=model_dir):
  132. assert self.trust_remote_code, info_str
  133. def prepare_model(self):
  134. """ Place model on certain device for pytorch models before first inference
  135. """
  136. self._model_prepare_lock.acquire(timeout=600)
  137. def _prepare_single(model):
  138. if not isinstance(model, torch.nn.Module) and hasattr(
  139. model, 'model'):
  140. model = model.model
  141. if not isinstance(model, torch.nn.Module):
  142. return
  143. model.eval()
  144. from modelscope.utils.torch_utils import is_on_same_device
  145. if is_on_same_device(model):
  146. model.to(self.device)
  147. if not self._model_prepare:
  148. # prepare model for pytorch
  149. if self.framework == Frameworks.torch:
  150. if self.has_multiple_models:
  151. for m in self.models:
  152. _prepare_single(m)
  153. if self._compile:
  154. self.models = [
  155. compile_model(m, **self._compile_options)
  156. for m in self.models
  157. ]
  158. else:
  159. _prepare_single(self.model)
  160. if self._compile:
  161. self.model = compile_model(self.model,
  162. **self._compile_options)
  163. self._model_prepare = True
  164. self._model_prepare_lock.release()
  165. def _get_framework(self) -> str:
  166. frameworks = []
  167. for m in self.models:
  168. if isinstance(m, str):
  169. model_dir = m
  170. else:
  171. model_dir = m.model_dir
  172. cfg_file = osp.join(model_dir, ModelFile.CONFIGURATION)
  173. cfg = Config.from_file(cfg_file)
  174. frameworks.append(cfg.framework)
  175. if not all(x == frameworks[0] for x in frameworks):
  176. logger.warning(
  177. f'got multiple models, but they are in different frameworks {frameworks}'
  178. )
  179. return None
  180. return frameworks[0]
  181. def __call__(self, input: Union[Input, List[Input]], *args,
  182. **kwargs) -> Union[Dict[str, Any], Generator]:
  183. # model provider should leave it as it is
  184. # modelscope library developer will handle this function
  185. # place model to cpu or gpu
  186. if (self.model or (self.has_multiple_models and self.models[0])):
  187. if not self._model_prepare:
  188. self.prepare_model()
  189. # simple showcase, need to support iterator type for both tensorflow and pytorch
  190. # input_dict = self._handle_input(input)
  191. # sanitize the parameters
  192. batch_size = kwargs.pop('batch_size', None)
  193. preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
  194. **kwargs)
  195. kwargs['preprocess_params'] = preprocess_params
  196. kwargs['forward_params'] = forward_params
  197. kwargs['postprocess_params'] = postprocess_params
  198. # for LLMPipeline, we shall support treating list of roles as a
  199. # one single 'messages' input
  200. if 'LLMPipeline' in type(self).__name__ and isinstance(input, list):
  201. input = {'messages': input}
  202. kwargs['is_message'] = True
  203. if isinstance(input, list):
  204. if batch_size is None:
  205. output = []
  206. for ele in input:
  207. output.append(self._process_single(ele, *args, **kwargs))
  208. else:
  209. output = self._process_batch(input, batch_size, **kwargs)
  210. elif isinstance(input, MsDataset):
  211. return self._process_iterator(input, *args, **kwargs)
  212. else:
  213. output = self._process_single(input, *args, **kwargs)
  214. return output
  215. def _sanitize_parameters(self, **pipeline_parameters):
  216. """
  217. this method should sanitize the keyword args to preprocessor params,
  218. forward params and postprocess params on '__call__' or '_process_single' method
  219. considered to be a normal classmethod with default implementation / output
  220. Default Returns:
  221. Dict[str, str]: preprocess_params = {}
  222. Dict[str, str]: forward_params = {}
  223. Dict[str, str]: postprocess_params = pipeline_parameters
  224. """
  225. return {}, {}, pipeline_parameters
  226. def _process_iterator(self, input: Input, *args, **kwargs):
  227. for ele in input:
  228. yield self._process_single(ele, *args, **kwargs)
  229. def _collate_fn(self, data):
  230. return collate_fn(data, self.device)
  231. def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
  232. preprocess_params = kwargs.get('preprocess_params', {})
  233. forward_params = kwargs.get('forward_params', {})
  234. postprocess_params = kwargs.get('postprocess_params', {})
  235. self._check_input(input)
  236. out = self.preprocess(input, **preprocess_params)
  237. with device_placement(self.framework, self.device_name):
  238. if self.framework == Frameworks.torch:
  239. with torch.no_grad():
  240. if self._auto_collate:
  241. out = self._collate_fn(out)
  242. out = self.forward(out, **forward_params)
  243. else:
  244. out = self.forward(out, **forward_params)
  245. out = self.postprocess(out, **postprocess_params)
  246. self._check_output(out)
  247. return out
  248. def _batch(self, data_list):
  249. batch_data = {}
  250. for sample_preprocessed in data_list:
  251. for k, v in sample_preprocessed.items():
  252. value_list = batch_data.get(k, [])
  253. value_list.append(v)
  254. batch_data[k] = value_list
  255. for k in batch_data.keys():
  256. if isinstance(batch_data[k][0], torch.Tensor):
  257. batch_data[k] = torch.cat(batch_data[k])
  258. return batch_data
  259. def _process_batch(self, input: List[Input], batch_size,
  260. **kwargs) -> Dict[str, Any]:
  261. preprocess_params = kwargs.get('preprocess_params')
  262. forward_params = kwargs.get('forward_params')
  263. postprocess_params = kwargs.get('postprocess_params')
  264. # batch data
  265. output_list = []
  266. for i in range(0, len(input), batch_size):
  267. end = min(i + batch_size, len(input))
  268. real_batch_size = end - i
  269. preprocessed_list = [
  270. self.preprocess(i, **preprocess_params) for i in input[i:end]
  271. ]
  272. with device_placement(self.framework, self.device_name):
  273. if self.framework == Frameworks.torch:
  274. with torch.no_grad():
  275. batched_out = self._batch(preprocessed_list)
  276. if self._auto_collate:
  277. batched_out = self._collate_fn(batched_out)
  278. batched_out = self.forward(batched_out,
  279. **forward_params)
  280. else:
  281. batched_out = self._batch(preprocessed_list)
  282. batched_out = self.forward(batched_out, **forward_params)
  283. for batch_idx in range(real_batch_size):
  284. out = {}
  285. for k, element in batched_out.items():
  286. if element is not None:
  287. if isinstance(element, (tuple, list)):
  288. if isinstance(element[0], torch.Tensor):
  289. out[k] = type(element)(
  290. e[batch_idx:batch_idx + 1]
  291. for e in element)
  292. else:
  293. # Compatible with traditional pipelines
  294. out[k] = element[batch_idx]
  295. else:
  296. out[k] = element[batch_idx:batch_idx + 1]
  297. out = self.postprocess(out, **postprocess_params)
  298. self._check_output(out)
  299. output_list.append(out)
  300. return output_list
  301. def _check_input(self, input):
  302. task_name = self.group_key
  303. if task_name in TASK_INPUTS:
  304. input_type = TASK_INPUTS[task_name]
  305. # if multiple input formats are defined, we first
  306. # found the one that match input data and check
  307. if isinstance(input_type, list):
  308. matched_type = None
  309. for t in input_type:
  310. if isinstance(input, (dict, tuple)):
  311. if type(t) == type(input):
  312. matched_type = t
  313. break
  314. elif isinstance(t, str):
  315. matched_type = t
  316. break
  317. if matched_type is None:
  318. err_msg = 'input data format for current pipeline should be one of following: \n'
  319. for t in input_type:
  320. err_msg += f'{t}\n'
  321. raise ValueError(err_msg)
  322. else:
  323. input_type = matched_type
  324. if isinstance(input_type, str):
  325. check_input_type(input_type, input)
  326. elif isinstance(input_type, tuple):
  327. assert isinstance(input, tuple), 'input should be a tuple'
  328. for t, input_ele in zip(input_type, input):
  329. check_input_type(t, input_ele)
  330. elif isinstance(input_type, dict):
  331. for k in input_type.keys():
  332. # allow single input for multi-modal models
  333. if isinstance(input, dict) and k in input:
  334. check_input_type(input_type[k], input[k])
  335. else:
  336. raise ValueError(f'invalid input_type definition {input_type}')
  337. elif not getattr(self, '_input_has_warned', False):
  338. logger.warning(f'task {task_name} input definition is missing')
  339. self._input_has_warned = True
  340. def _check_output(self, input):
  341. # this attribute is dynamically attached by registry
  342. # when cls is registered in registry using task name
  343. task_name = self.group_key
  344. if task_name not in TASK_OUTPUTS:
  345. if not getattr(self, '_output_has_warned', False):
  346. logger.warning(f'task {task_name} output keys are missing')
  347. self._output_has_warned = True
  348. return
  349. output_keys = TASK_OUTPUTS[task_name]
  350. missing_keys = []
  351. input = input.keys() if isinstance(input,
  352. (dict, ModelOutputBase)) else input
  353. for k in output_keys:
  354. if isinstance(k, (dict, ModelOutputBase)) and k not in input:
  355. missing_keys.append(k)
  356. if len(missing_keys) > 0:
  357. raise ValueError(f'expected output keys are {output_keys}, '
  358. f'those {missing_keys} are missing')
  359. def preprocess(self, inputs: Input, **preprocess_params) -> Dict[str, Any]:
  360. """ Provide default implementation based on preprocess_cfg and user can reimplement it
  361. """
  362. assert self.preprocessor is not None, 'preprocess method should be implemented'
  363. assert not isinstance(self.preprocessor, List),\
  364. 'default implementation does not support using multiple preprocessors.'
  365. return self.preprocessor(inputs, **preprocess_params)
  366. def forward(self, inputs: Dict[str, Any],
  367. **forward_params) -> Dict[str, Any]:
  368. """ Provide default implementation using self.model and user can reimplement it
  369. """
  370. assert self.model is not None, 'forward method should be implemented'
  371. assert not self.has_multiple_models, 'default implementation does not support multiple models in a pipeline.'
  372. return self.model(inputs, **forward_params)
  373. def postprocess(self, inputs: Dict[str, Any],
  374. **post_params) -> Dict[str, Any]:
  375. """ If current pipeline support model reuse, common postprocess
  376. code should be write here.
  377. Args:
  378. inputs: input data
  379. post_params: post process parameters
  380. Return:
  381. dict of results: a dict containing outputs of model, each
  382. output should have the standard output name.
  383. """
  384. raise NotImplementedError('postprocess')
  385. class DistributedPipeline(Pipeline):
  386. """This pipeline is used to load multi gpu models.
  387. What will this class do:
  388. 1. Read the global config from the configuration.json
  389. 2. Set the multiprocessing method to spawn
  390. 3. Open a multiprocessing pool of the world_size to instantiate model pieces.
  391. 4. Set the master port and ip
  392. 5. Call _instantiate_one to instantiate one model piece,
  393. This method should be implemented by the derived class.
  394. 6. After the forward method is called, do preprocess in main process and
  395. call _forward_one to collect results, and do post process in main process.
  396. NOTE: _instantiate_one and _forward_one are class methods, any derived class should implement them and
  397. store the model handler in the class field.
  398. """
  399. def __init__(self,
  400. model: str = None,
  401. preprocessor: Union[Preprocessor, List[Preprocessor]] = None,
  402. auto_collate=True,
  403. **kwargs):
  404. # DistributedPipeline uses classmethod to initialize model
  405. # without calling super().__init__ method
  406. self.preprocessor = preprocessor
  407. self._model_prepare = False
  408. self._model_prepare_lock = Lock()
  409. self._auto_collate = auto_collate
  410. if os.path.exists(model):
  411. self.model_dir = model
  412. else:
  413. self.model_dir = snapshot_download(model)
  414. self.cfg = read_config(self.model_dir)
  415. self.world_size = self._get_world_size(self.cfg)
  416. self.model_pool = None
  417. self.device_name = 'cpu'
  418. self.device = create_device(self.device_name)
  419. self.has_multiple_models = False
  420. self.framework = self.cfg.framework
  421. torch.multiprocessing.set_start_method('spawn', force=True)
  422. ranks = list(range(self.world_size))
  423. self.model_pool = Pool(self.world_size)
  424. if 'master_ip' not in kwargs:
  425. kwargs['master_ip'] = '127.0.0.1'
  426. master_port = int(kwargs['master_port']
  427. ) if 'master_port' in kwargs else random.randint(
  428. 29500, 39500)
  429. from modelscope.utils.torch_utils import _find_free_port, _is_free_port
  430. if not _is_free_port(master_port):
  431. master_port = _find_free_port()
  432. kwargs['master_port'] = str(master_port)
  433. # TODO: Pass ip and port to megatron_util for initialization
  434. os.environ['MASTER_ADDR'] = kwargs['master_ip']
  435. os.environ['MASTER_PORT'] = kwargs['master_port']
  436. self.model_pool.map(
  437. partial(
  438. self.__class__._instantiate_one,
  439. model_dir=self.model_dir,
  440. **self.cfg.model,
  441. **kwargs), ranks)
  442. self.models = []
  443. def __del__(self):
  444. if hasattr(self, 'model_pool') and self.model_pool is not None:
  445. try:
  446. self.model_pool.terminate()
  447. except AttributeError:
  448. pass
  449. def __getstate__(self):
  450. self_dict = self.__dict__.copy()
  451. del self_dict['model_pool']
  452. del self_dict['preprocessor']
  453. del self_dict['_model_prepare_lock']
  454. return self_dict
  455. @classmethod
  456. def _instantiate_one(cls, rank, model_dir, **kwargs):
  457. """Instantiate one model piece.
  458. Args:
  459. rank: The model rank.
  460. model_dir: The model_dir in the node.
  461. kwargs: Any extra args.
  462. Returns:
  463. None. The model handler should be kept in the class field.
  464. """
  465. pass
  466. def forward(self, inputs: Dict[str, Any],
  467. **forward_params) -> Dict[str, Any]:
  468. inputs = {
  469. 'inputs': inputs,
  470. 'forward_params': forward_params,
  471. }
  472. res = self.model_pool.map(self.__class__._forward_one,
  473. [inputs] * self.world_size)
  474. return res[0]
  475. @classmethod
  476. def _forward_one(cls, inputs):
  477. """Forward the inputs to one model piece.
  478. Use the model handler kept in the class field to forward.
  479. Args:
  480. inputs: The inputs after the preprocessing.
  481. Returns:
  482. The forward results.
  483. """
  484. pass
  485. def _get_world_size(self, cfg: Config) -> int:
  486. m_world_size = cfg.safe_get('megatron.world_size')
  487. if m_world_size is None:
  488. return cfg.safe_get('model.world_size')
  489. return m_world_size
  490. def collate_fn(data, device):
  491. """Prepare the input just before the forward function.
  492. This method will move the tensors to the right device.
  493. Usually this method does not need to be overridden.
  494. Args:
  495. data: The data out of the dataloader.
  496. device: The device to move data to.
  497. Returns: The processed data.
  498. """
  499. from torch.utils.data.dataloader import default_collate
  500. def get_class_name(obj):
  501. return obj.__class__.__name__
  502. if isinstance(data, dict) or isinstance(data, Mapping):
  503. # add compatibility for img_metas for mmlab models
  504. return type(data)({
  505. k: collate_fn(v, device) if k != 'img_metas' else v
  506. for k, v in data.items()
  507. })
  508. elif isinstance(data, (tuple, list)):
  509. if 0 == len(data):
  510. return torch.Tensor([])
  511. if isinstance(data[0], (int, float)):
  512. return default_collate(data).to(device)
  513. else:
  514. return type(data)(collate_fn(v, device) for v in data)
  515. elif isinstance(data, np.ndarray):
  516. if data.dtype.type is np.str_:
  517. return data
  518. else:
  519. return collate_fn(torch.from_numpy(data), device)
  520. elif isinstance(data, torch.Tensor):
  521. return data.to(device)
  522. elif isinstance(data, (bytes, str, int, float, bool, type(None))):
  523. return data
  524. elif get_class_name(data) == 'InputFeatures':
  525. # modelscope.preprocessors.nlp.InputFeatures
  526. return data
  527. elif get_class_name(data) == 'DataContainer':
  528. # mmcv.parallel.DataContainer
  529. return data
  530. else:
  531. raise ValueError(f'Unsupported data type {type(data)}')