feature_extraction_pipeline.py 3.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from typing import Any, Dict, Optional, Union
  4. import torch
  5. from modelscope.metainfo import Pipelines
  6. from modelscope.models import Model
  7. from modelscope.outputs import OutputKeys
  8. from modelscope.pipelines.base import Pipeline, Tensor
  9. from modelscope.pipelines.builder import PIPELINES
  10. from modelscope.preprocessors import (FillMaskTransformersPreprocessor,
  11. Preprocessor)
  12. from modelscope.utils.config import Config
  13. from modelscope.utils.constant import ModelFile, Tasks
  14. __all__ = ['FeatureExtractionPipeline']
  15. @PIPELINES.register_module(
  16. Tasks.feature_extraction, module_name=Pipelines.feature_extraction)
  17. class FeatureExtractionPipeline(Pipeline):
  18. def __init__(self,
  19. model: Union[Model, str],
  20. preprocessor: Optional[Preprocessor] = None,
  21. config_file: str = None,
  22. device: str = 'gpu',
  23. auto_collate=True,
  24. padding=False,
  25. sequence_length=128,
  26. **kwargs):
  27. """Use `model` and `preprocessor` to create a nlp feature extraction pipeline for prediction
  28. Args:
  29. model (str or Model): Supply either a local model dir which supported feature extraction task, or a
  30. no-head model id from the model hub, or a torch model instance.
  31. preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
  32. the model if supplied.
  33. kwargs (dict, `optional`):
  34. Extra kwargs passed into the preprocessor's constructor.
  35. Examples:
  36. >>> from modelscope.pipelines import pipeline
  37. >>> pipe_ins = pipeline('feature_extraction', model='damo/nlp_structbert_feature-extraction_english-large')
  38. >>> input = 'Everything you love is treasure'
  39. >>> print(pipe_ins(input))
  40. """
  41. super().__init__(
  42. model=model,
  43. preprocessor=preprocessor,
  44. config_file=config_file,
  45. device=device,
  46. auto_collate=auto_collate,
  47. compile=kwargs.pop('compile', False),
  48. compile_options=kwargs.pop('compile_options', {}))
  49. assert isinstance(self.model, Model), \
  50. f'please check whether model config exists in {ModelFile.CONFIGURATION}'
  51. if preprocessor is None:
  52. self.preprocessor = Preprocessor.from_pretrained(
  53. self.model.model_dir,
  54. padding=padding,
  55. sequence_length=sequence_length,
  56. **kwargs)
  57. self.model.eval()
  58. def forward(self, inputs: Dict[str, Any],
  59. **forward_params) -> Dict[str, Any]:
  60. with torch.no_grad():
  61. return self.model(**inputs, **forward_params)
  62. def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
  63. """process the prediction results
  64. Args:
  65. inputs (Dict[str, Any]): _description_
  66. Returns:
  67. Dict[str, str]: the prediction results
  68. """
  69. return {
  70. OutputKeys.TEXT_EMBEDDING:
  71. inputs[OutputKeys.TEXT_EMBEDDING].tolist()
  72. }