| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596 |
- from typing import List, Optional, Union
- from modelscope import get_logger
- from modelscope.pipelines.accelerate.base import InferFramework
- from modelscope.utils.import_utils import is_vllm_available
# Module-level logger obtained from modelscope's `get_logger` helper.
logger = get_logger()
class Vllm(InferFramework):
    """Inference backend that runs a model through the vLLM engine.

    Wraps ``vllm.LLM`` and translates HuggingFace-style generation keyword
    arguments into vLLM ``SamplingParams``.
    """

    # Substrings that mark a model type as supported by this backend.
    # Matching is case-insensitive (see `model_type_supported`).
    _SUPPORTED_MODEL_KEYWORDS = (
        'llama',
        'baichuan',
        'internlm',
        'mistral',
        'aquila',
        'bloom',
        'falcon',
        'gpt',
        'mpt',
        'opt',
        'qwen',
    )

    def __init__(self,
                 model_id_or_dir: str,
                 dtype: str = 'auto',
                 quantization: Optional[str] = None,
                 tensor_parallel_size: int = 1,
                 trust_remote_code: Optional[bool] = None):
        """Create the vLLM engine for the given model.

        Args:
            model_id_or_dir: The model id or local model directory, handed
                to the base class constructor.
            dtype: The dtype to use, support `auto`, `float16`, `bfloat16`,
                `float32`.
            quantization: The quantization setting forwarded to `vllm.LLM`,
                default None means do not do any quantization.
            tensor_parallel_size: The tensor parallel size.
            trust_remote_code: Forwarded unchanged to `vllm.LLM`.

        Raises:
            ImportError: If the `vllm` package is not installed.
        """
        super().__init__(model_id_or_dir)
        if not is_vllm_available():
            raise ImportError(
                'Install vllm by `pip install vllm` before using vllm to accelerate inference'
            )
        # Imported lazily so that this module can be imported without vllm.
        from vllm import LLM

        # NOTE(review): `check_gpu_compatibility` is defined outside this
        # class body (presumably on `InferFramework`); the argument 8 looks
        # like the minimum compute capability for bfloat16 -- confirm.
        # When the check fails, `bfloat16`/`auto` fall back to `float16`.
        if not Vllm.check_gpu_compatibility(8) and (dtype
                                                    in ('bfloat16', 'auto')):
            dtype = 'float16'

        # `self.model_dir` is not assigned here; presumably the base class
        # constructor sets it from `model_id_or_dir`.
        self.model = LLM(
            self.model_dir,
            dtype=dtype,
            quantization=quantization,
            trust_remote_code=trust_remote_code,
            tensor_parallel_size=tensor_parallel_size)

    def __call__(self, prompts: Union[List[str], List[List[int]]],
                 **kwargs) -> List[str]:
        """Generate tokens.

        Args:
            prompts(`Union[List[str], List[List[int]]]`):
                The string batch or the token list batch to input to the
                model.
            kwargs: Sampling parameters. The HuggingFace-style keys
                `do_sample`, `num_beams` (or the legacy spelling
                `num_beam`), `max_length` and `max_new_tokens` are
                translated; every other key is passed to
                `vllm.SamplingParams` unchanged.

        Returns:
            The text of the first completion for each prompt, in the order
            returned by the engine. An empty batch yields an empty list.
        """
        # Nothing to generate; also avoids indexing `prompts[0]` below.
        if not prompts:
            return []

        # convert hf generate config to vllm
        do_sample = kwargs.pop('do_sample', None)
        # HuggingFace spells this `num_beams`; `num_beam` is still accepted
        # for callers of the previous behaviour. Both are popped so that
        # neither leaks into `SamplingParams`, which does not know them.
        legacy_num_beam = kwargs.pop('num_beam', None)
        num_beams = kwargs.pop('num_beams', None)
        if num_beams is None:
            num_beams = 1 if legacy_num_beam is None else legacy_num_beam
        max_length = kwargs.pop('max_length', None)
        max_new_tokens = kwargs.pop('max_new_tokens', None)

        # for vllm, default to do_sample/greedy(depends on temperature).
        # for hf, do_sample=false, num_beam=1 -> greedy(default)
        # do_sample=true, num_beam=1 -> sample
        # do_sample=false, num_beam>1 -> beam_search
        if not do_sample and num_beams > 1:
            kwargs['use_beam_search'] = True
            # vLLM validates that beam search has `best_of` > 1 and a zero
            # temperature; fill them in unless the caller chose values.
            kwargs.setdefault('best_of', num_beams)
            kwargs.setdefault('temperature', 0.0)

        if max_length:
            # HF `max_length` counts prompt plus generated tokens.
            # NOTE(review): only the first prompt is measured, and for
            # string prompts `len` counts characters, not tokens.
            kwargs['max_tokens'] = max_length - len(prompts[0])
        if max_new_tokens:
            # `max_new_tokens` wins over `max_length` when both are given.
            kwargs['max_tokens'] = max_new_tokens

        from vllm import SamplingParams
        sampling_params = SamplingParams(**kwargs)

        # Text prompts go in positionally, token-id prompts by keyword.
        if isinstance(prompts[0], str):
            request_outputs = self.model.generate(
                prompts, sampling_params=sampling_params)
        else:
            request_outputs = self.model.generate(
                prompt_token_ids=prompts, sampling_params=sampling_params)
        return [output.outputs[0].text for output in request_outputs]

    def model_type_supported(self, model_type: str) -> bool:
        """Tell whether `model_type` names a supported architecture.

        Args:
            model_type: The model type string; compared case-insensitively.

        Returns:
            True if any supported keyword is a substring of `model_type`.
        """
        lowered = model_type.lower()
        return any(keyword in lowered
                   for keyword in self._SUPPORTED_MODEL_KEYWORDS)
|