base.py 2.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778
  1. import os.path
  2. from abc import abstractmethod
  3. from typing import List, Union
  4. import torch.cuda
  5. from modelscope import read_config, snapshot_download
  6. from modelscope.utils.config import Config
  7. class InferFramework:
  8. def __init__(self, model_id_or_dir: str, **kwargs):
  9. """
  10. Args:
  11. model_id_or_dir(`str`): The model id of the modelhub or a local dir containing model files.
  12. """
  13. if os.path.exists(model_id_or_dir):
  14. self.model_dir = model_id_or_dir
  15. else:
  16. self.model_dir = snapshot_download(model_id_or_dir)
  17. model_supported = self.model_type_supported(model_id_or_dir)
  18. config: Config = read_config(self.model_dir)
  19. model_type = config.safe_get('model.type')
  20. if model_type is not None:
  21. model_supported = model_supported or self.model_type_supported(
  22. model_type)
  23. config_file = os.path.join(self.model_dir, 'config.json')
  24. if os.path.isfile(config_file):
  25. config = Config.from_file(config_file)
  26. model_type = config.safe_get('model_type')
  27. if model_type is not None:
  28. model_supported = model_supported or self.model_type_supported(
  29. model_type)
  30. if not model_supported:
  31. raise ValueError(
  32. f'Model accelerating not supported: {model_id_or_dir}')
  33. @abstractmethod
  34. def __call__(self, prompts: Union[List[str], List[List[int]]],
  35. **kwargs) -> List[str]:
  36. """
  37. Args:
  38. prompts(`Union[List[str], List[List[int]]]`):
  39. The string batch or the token list batch to input to the model.
  40. Returns:
  41. The answers in list according to the input prompt batch.
  42. """
  43. pass
  44. def model_type_supported(self, model_type: str):
  45. return False
  46. @staticmethod
  47. def check_gpu_compatibility(major_version: int):
  48. """Check the GPU compatibility.
  49. """
  50. major, _ = torch.cuda.get_device_capability()
  51. return major >= major_version
  52. @classmethod
  53. def from_pretrained(cls, model_id_or_dir, framework='vllm', **kwargs):
  54. """Instantiate the model wrapped by an accelerate framework.
  55. Args:
  56. model_id_or_dir(`str`): The model id of the modelhub or a local dir containing model files.
  57. framework(`str`): The framework to use.
  58. Returns:
  59. The wrapped model.
  60. """
  61. if framework == 'vllm':
  62. from .vllm import Vllm
  63. vllm = Vllm(model_id_or_dir, **kwargs)
  64. vllm.llm_framework = framework
  65. return vllm
  66. else:
  67. raise ValueError(f'Framework not supported: {framework}')