vllm.py 3.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596
  1. from typing import List, Optional, Union
  2. from modelscope import get_logger
  3. from modelscope.pipelines.accelerate.base import InferFramework
  4. from modelscope.utils.import_utils import is_vllm_available
# Module-level logger obtained from modelscope's `get_logger`.
logger = get_logger()
  6. class Vllm(InferFramework):
  7. def __init__(self,
  8. model_id_or_dir: str,
  9. dtype: str = 'auto',
  10. quantization: str = None,
  11. tensor_parallel_size: int = 1,
  12. trust_remote_code: Optional[bool] = None):
  13. """
  14. Args:
  15. dtype: The dtype to use, support `auto`, `float16`, `bfloat16`, `float32`
  16. quantization: The quantization bit, default None means do not do any quantization.
  17. tensor_parallel_size: The tensor parallel size.
  18. """
  19. super().__init__(model_id_or_dir)
  20. if not is_vllm_available():
  21. raise ImportError(
  22. 'Install vllm by `pip install vllm` before using vllm to accelerate inference'
  23. )
  24. from vllm import LLM
  25. if not Vllm.check_gpu_compatibility(8) and (dtype
  26. in ('bfloat16', 'auto')):
  27. dtype = 'float16'
  28. self.model = LLM(
  29. self.model_dir,
  30. dtype=dtype,
  31. quantization=quantization,
  32. trust_remote_code=trust_remote_code,
  33. tensor_parallel_size=tensor_parallel_size)
  34. def __call__(self, prompts: Union[List[str], List[List[int]]],
  35. **kwargs) -> List[str]:
  36. """Generate tokens.
  37. Args:
  38. prompts(`Union[List[str], List[List[int]]]`):
  39. The string batch or the token list batch to input to the model.
  40. kwargs: Sampling parameters.
  41. """
  42. # convert hf generate config to vllm
  43. do_sample = kwargs.pop('do_sample', None)
  44. num_beam = kwargs.pop('num_beam', 1)
  45. max_length = kwargs.pop('max_length', None)
  46. max_new_tokens = kwargs.pop('max_new_tokens', None)
  47. # for vllm, default to do_sample/greedy(depends on temperature).
  48. # for hf, do_sample=false, num_beam=1 -> greedy(default)
  49. # do_sample=true, num_beam=1 -> sample
  50. # do_sample=false, num_beam>1 -> beam_search
  51. if not do_sample and num_beam > 1:
  52. kwargs['use_beam_search'] = True
  53. if max_length:
  54. kwargs['max_tokens'] = max_length - len(prompts[0])
  55. if max_new_tokens:
  56. kwargs['max_tokens'] = max_new_tokens
  57. from vllm import SamplingParams
  58. sampling_params = SamplingParams(**kwargs)
  59. if isinstance(prompts[0], str):
  60. return [
  61. output.outputs[0].text for output in self.model.generate(
  62. prompts, sampling_params=sampling_params)
  63. ]
  64. else:
  65. return [
  66. output.outputs[0].text for output in self.model.generate(
  67. prompt_token_ids=prompts, sampling_params=sampling_params)
  68. ]
  69. def model_type_supported(self, model_type: str):
  70. return any([
  71. model in model_type.lower() for model in [
  72. 'llama',
  73. 'baichuan',
  74. 'internlm',
  75. 'mistral',
  76. 'aquila',
  77. 'bloom',
  78. 'falcon',
  79. 'gpt',
  80. 'mpt',
  81. 'opt',
  82. 'qwen',
  83. 'aquila',
  84. ]
  85. ])