pipeline_builder.py 3.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192
  1. import os
  2. from typing import Optional, Union
  3. from modelscope.hub import snapshot_download
  4. from modelscope.utils.hf_util.patcher import _patch_pretrained_class
  5. from modelscope.utils.logger import get_logger
  6. logger = get_logger()
  7. def _get_hf_device(device):
  8. if isinstance(device, str):
  9. device_name = device.lower()
  10. eles = device_name.split(':')
  11. if eles[0] == 'gpu':
  12. eles = ['cuda'] + eles[1:]
  13. device = ''.join(eles)
  14. return device
  15. def _get_hf_pipeline_class(task, model):
  16. from transformers.pipelines import check_task, get_task
  17. if not task:
  18. task = get_task(model)
  19. normalized_task, targeted_task, task_options = check_task(task)
  20. pipeline_class = targeted_task['impl']
  21. pipeline_class = _patch_pretrained_class([pipeline_class])[0]
  22. return pipeline_class
  23. def hf_pipeline(
  24. task: str = None,
  25. model: Optional[Union[str, 'PreTrainedModel', 'TFPreTrainedModel']] = None,
  26. framework: Optional[str] = None,
  27. device: Optional[Union[int, str, 'torch.device']] = None,
  28. **kwargs,
  29. ) -> 'transformers.Pipeline':
  30. from transformers import pipeline
  31. if isinstance(model, str):
  32. if not os.path.exists(model):
  33. model = snapshot_download(model)
  34. framework = 'pt' if framework == 'pytorch' else framework
  35. device = _get_hf_device(device)
  36. pipeline_class = _get_hf_pipeline_class(task, model)
  37. kwargs.pop('external_engine_for_llm', None)
  38. kwargs.pop('llm_framework', None)
  39. return pipeline(
  40. task=task,
  41. model=model,
  42. framework=framework,
  43. device=device,
  44. pipeline_class=pipeline_class,
  45. **kwargs)
  46. def sentence_transformers_pipeline(model: str, **kwargs):
  47. try:
  48. from sentence_transformers import SentenceTransformer
  49. except ImportError:
  50. raise ImportError(
  51. 'Could not import sentence_transformers, please upgrade to the latest version of sentence_transformers '
  52. "with: 'pip install -U sentence_transformers'") from None
  53. if isinstance(model, str):
  54. if not os.path.exists(model):
  55. model = snapshot_download(model)
  56. from modelscope.pipelines import Pipeline
  57. class SentenceTransformerPipeline(Pipeline):
  58. """A wrapper for sentence_transformers.SentenceTransformer to make it compatible
  59. with the modelscope pipeline conventions."""
  60. def __init__(self, model_path: str, **kwargs):
  61. self.model = SentenceTransformer(model_path, **kwargs)
  62. def __call__(self,
  63. sentences: str | list[str] | None = None,
  64. prompt_name: str | None = None,
  65. **kwargs):
  66. input_data = kwargs.pop('input', None)
  67. if input_data is not None:
  68. sentences = input_data['source_sentence']
  69. res = self.model.encode(sentences, **kwargs)
  70. return {'text_embedding': res}
  71. return self.model.encode(
  72. sentences, prompt_name=prompt_name, **kwargs)
  73. return SentenceTransformerPipeline(model, **kwargs)