text_ranking_pipeline.py 2.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Optional, Union
  3. import numpy as np
  4. from modelscope.metainfo import Pipelines
  5. from modelscope.models import Model
  6. from modelscope.outputs import OutputKeys
  7. from modelscope.pipelines.base import Pipeline
  8. from modelscope.pipelines.builder import PIPELINES
  9. from modelscope.preprocessors import (Preprocessor,
  10. TextRankingTransformersPreprocessor)
  11. from modelscope.utils.constant import ModelFile, Tasks
  12. __all__ = ['TextRankingPipeline']
  13. @PIPELINES.register_module(
  14. Tasks.text_ranking, module_name=Pipelines.text_ranking)
  15. class TextRankingPipeline(Pipeline):
  16. def __init__(self,
  17. model: Union[Model, str],
  18. preprocessor: Optional[Preprocessor] = None,
  19. config_file: str = None,
  20. device: str = 'gpu',
  21. auto_collate=True,
  22. sequence_length=128,
  23. **kwargs):
  24. """Use `model` and `preprocessor` to create a nlp word segment pipeline for prediction.
  25. Args:
  26. model (str or Model): Supply either a local model dir which supported the WS task,
  27. or a model id from the model hub, or a torch model instance.
  28. preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
  29. the model if supplied.
  30. kwargs (dict, `optional`):
  31. Extra kwargs passed into the preprocessor's constructor.
  32. """
  33. super().__init__(
  34. model=model,
  35. preprocessor=preprocessor,
  36. config_file=config_file,
  37. device=device,
  38. auto_collate=auto_collate,
  39. compile=kwargs.pop('compile', False),
  40. compile_options=kwargs.pop('compile_options', {}))
  41. assert isinstance(self.model, Model), \
  42. f'please check whether model config exists in {ModelFile.CONFIGURATION}'
  43. if preprocessor is None:
  44. self.preprocessor = Preprocessor.from_pretrained(
  45. self.model.model_dir,
  46. sequence_length=sequence_length,
  47. **kwargs)
  48. def forward(self, inputs: Dict[str, Any],
  49. **forward_params) -> Dict[str, Any]:
  50. return self.model(**inputs, **forward_params)
  51. def postprocess(self, inputs: Dict[str, Any]) -> Dict[str, Any]:
  52. """process the prediction results
  53. Args:
  54. inputs (Dict[str, Any]): _description_
  55. Returns:
  56. Dict[str, Any]: the predicted text representation
  57. """
  58. def sigmoid(logits):
  59. return np.exp(logits) / (1 + np.exp(logits))
  60. logits = inputs[OutputKeys.LOGITS].squeeze(-1).detach().cpu().numpy()
  61. pred_list = sigmoid(logits).tolist()
  62. return {OutputKeys.SCORES: pred_list}