sentence_embedding_pipeline.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Optional, Union
  3. import numpy as np
  4. import torch
  5. from modelscope.metainfo import Pipelines
  6. from modelscope.models import Model
  7. from modelscope.outputs import OutputKeys
  8. from modelscope.pipelines.base import Pipeline
  9. from modelscope.pipelines.builder import PIPELINES
  10. from modelscope.preprocessors import Preprocessor
  11. from modelscope.utils.constant import ModelFile, Tasks
  # Public API of this module: only the pipeline class is exported.
  __all__ = [
      'SentenceEmbeddingPipeline',
  ]
  @PIPELINES.register_module(
      Tasks.sentence_embedding, module_name=Pipelines.sentence_embedding)
  class SentenceEmbeddingPipeline(Pipeline):
      """Encode sentences with a sentence-embedding model and score them.

      ``forward`` passes the preprocessed inputs straight to the model;
      ``postprocess`` reads ``query_embeddings`` / ``doc_embeddings`` from the
      model output and returns the embeddings plus dot-product scores of the
      first query embedding against each document embedding.
      """

      def __init__(self,
                   model: Union[Model, str],
                   preprocessor: Optional[Preprocessor] = None,
                   config_file: Optional[str] = None,
                   device: str = 'gpu',
                   auto_collate=True,
                   sequence_length=128,
                   **kwargs):
          """Use `model` and `preprocessor` to build a sentence-embedding pipeline.

          Args:
              model (str or Model): A local model dir that supports the
                  sentence-embedding task, a model id from the model hub, or a
                  Model instance.
              preprocessor (Preprocessor, `optional`): A preprocessor instance;
                  make sure it fits the model if supplied. When None, one is
                  created with ``Preprocessor.from_pretrained`` from the model
                  directory.
              config_file (str, `optional`): Forwarded to the base ``Pipeline``.
              device (str): Forwarded to the base ``Pipeline``. Defaults to
                  'gpu'.
              auto_collate (bool): Forwarded to the base ``Pipeline``.
              sequence_length (int): Passed to ``Preprocessor.from_pretrained``
                  when no preprocessor is supplied; ignored otherwise.
              kwargs (dict, `optional`): ``compile`` and ``compile_options`` are
                  popped and forwarded to the base ``Pipeline``; the remaining
                  kwargs are passed to the preprocessor's constructor, but only
                  when ``preprocessor`` is None.

          Raises:
              AssertionError: If the resolved ``self.model`` is not a ``Model``
                  instance (the message points at a possibly missing model
                  configuration file).
          """
          # The pops below run before the base constructor does, so the
          # compile-related keys never reach Preprocessor.from_pretrained.
          super().__init__(
              model=model,
              preprocessor=preprocessor,
              config_file=config_file,
              device=device,
              auto_collate=auto_collate,
              compile=kwargs.pop('compile', False),
              compile_options=kwargs.pop('compile_options', {}))
          assert isinstance(self.model, Model), \
              f'please check whether model config exists in {ModelFile.CONFIGURATION}'
          if preprocessor is None:
              self.preprocessor = Preprocessor.from_pretrained(
                  self.model.model_dir,
                  sequence_length=sequence_length,
                  **kwargs)

      def forward(self, inputs: Dict[str, Any],
                  **forward_params) -> Dict[str, Any]:
          """Run the model on the preprocessed inputs.

          Args:
              inputs (Dict[str, Any]): Preprocessed model inputs, unpacked as
                  keyword arguments of the model call.
              **forward_params: Extra keyword arguments for the model call.

          Returns:
              The raw model output; ``postprocess`` reads ``query_embeddings``
              and ``doc_embeddings`` from it.
          """
          return self.model(**inputs, **forward_params)

      def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
          """Turn the model output into embeddings and similarity scores.

          Args:
              inputs (Dict[str, Any]): The model output. It must provide
                  ``query_embeddings`` (a tensor; presumably 2-D with one row
                  per sentence - TODO confirm) and ``doc_embeddings`` (a tensor
                  with the same trailing shape, or None when there are no
                  sentences to compare).

          Returns:
              Dict[str, Any]:
                  - ``OutputKeys.TEXT_EMBEDDING``: numpy array holding the query
                    rows followed by the document rows (if any), stacked along
                    dim 0.
                  - ``OutputKeys.SCORES``: list with the dot product of the
                    first query row against every document row, in document
                    order; an empty list when ``doc_embeddings`` is None.
          """
          query_embeddings = inputs['query_embeddings']
          doc_embeddings = inputs['doc_embeddings']
          if doc_embeddings is None:
              return {
                  OutputKeys.TEXT_EMBEDDING:
                  query_embeddings.detach().cpu().numpy(),
                  OutputKeys.SCORES: []
              }

          embeddings = torch.cat((query_embeddings, doc_embeddings), dim=0)
          embeddings = embeddings.detach().cpu().numpy()
          # Document rows start after *all* query rows. Slicing from row 1
          # would also score the first query against any additional query
          # rows, so SCORES would no longer line up with the documents.
          num_queries = query_embeddings.shape[0]
          first_query = embeddings[:1]
          doc_rows = embeddings[num_queries:]
          scores = np.dot(first_query, doc_rows.T)[0].tolist()
          return {
              OutputKeys.TEXT_EMBEDDING: embeddings,
              OutputKeys.SCORES: scores
          }