text_classification_pipeline.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict, Union
  3. import numpy as np
  4. import torch
  5. from modelscope.metainfo import Pipelines, Preprocessors
  6. from modelscope.models.base import Model
  7. from modelscope.outputs import OutputKeys, TextClassificationModelOutput
  8. from modelscope.pipelines.base import Pipeline
  9. from modelscope.pipelines.builder import PIPELINES
  10. from modelscope.pipelines.util import batch_process
  11. from modelscope.preprocessors import Preprocessor
  12. from modelscope.utils.constant import Fields, ModelFile, Tasks
  13. from modelscope.utils.logger import get_logger
# Module-level logger; postprocess() uses it to warn when no id2label
# mapping is available and raw class ids are returned instead of labels.
logger = get_logger()
  15. @PIPELINES.register_module(
  16. Tasks.text_classification, module_name=Pipelines.sentiment_analysis)
  17. @PIPELINES.register_module(Tasks.nli, module_name=Pipelines.nli)
  18. @PIPELINES.register_module(
  19. Tasks.sentence_similarity, module_name=Pipelines.sentence_similarity)
  20. @PIPELINES.register_module(
  21. Tasks.text_classification, module_name=Pipelines.text_classification)
  22. @PIPELINES.register_module(
  23. Tasks.text_classification, module_name=Pipelines.sentiment_classification)
  24. @PIPELINES.register_module(
  25. Tasks.text_classification, module_name=Pipelines.sentence_similarity)
  26. @PIPELINES.register_module(
  27. Tasks.sentiment_classification,
  28. module_name=Pipelines.sentiment_classification)
  29. class TextClassificationPipeline(Pipeline):
  30. def __init__(self,
  31. model: Union[Model, str],
  32. preprocessor: Preprocessor = None,
  33. config_file: str = None,
  34. device: str = 'gpu',
  35. auto_collate=True,
  36. **kwargs):
  37. """The inference pipeline for all the text classification sub-tasks.
  38. Args:
  39. model (`str` or `Model` or module instance): A model instance or a model local dir
  40. or a model id in the model hub.
  41. preprocessor (`Preprocessor`, `optional`): A Preprocessor instance.
  42. kwargs (dict, `optional`):
  43. Extra kwargs passed into the preprocessor's constructor.
  44. Examples:
  45. >>> from modelscope.pipelines import pipeline
  46. >>> pipeline_ins = pipeline('text-classification',
  47. model='damo/nlp_structbert_sentence-similarity_chinese-base')
  48. >>> input = ('这是个测试', '这也是个测试')
  49. >>> print(pipeline_ins(input))
  50. """
  51. super().__init__(
  52. model=model,
  53. preprocessor=preprocessor,
  54. config_file=config_file,
  55. device=device,
  56. auto_collate=auto_collate,
  57. compile=kwargs.pop('compile', False),
  58. compile_options=kwargs.pop('compile_options', {}))
  59. assert isinstance(self.model, Model), \
  60. f'please check whether model config exists in {ModelFile.CONFIGURATION}'
  61. if preprocessor is None:
  62. if self.model.__class__.__name__ == 'OfaForAllTasks':
  63. self.preprocessor = Preprocessor.from_pretrained(
  64. model_name_or_path=self.model.model_dir,
  65. type=Preprocessors.ofa_tasks_preprocessor,
  66. field=Fields.multi_modal,
  67. **kwargs)
  68. else:
  69. first_sequence = kwargs.pop('first_sequence', 'text')
  70. second_sequence = kwargs.pop('second_sequence', None)
  71. sequence_length = kwargs.pop('sequence_length', 512)
  72. self.preprocessor = Preprocessor.from_pretrained(
  73. self.model.model_dir, **{
  74. 'first_sequence': first_sequence,
  75. 'second_sequence': second_sequence,
  76. 'sequence_length': sequence_length,
  77. **kwargs
  78. })
  79. if hasattr(self.preprocessor, 'id2label'):
  80. self.id2label = self.preprocessor.id2label
  81. def _batch(self, data):
  82. if self.model.__class__.__name__ == 'OfaForAllTasks':
  83. return batch_process(self.model, data)
  84. else:
  85. return super(TextClassificationPipeline, self)._batch(data)
  86. def forward(self, inputs: Dict[str, Any],
  87. **forward_params) -> Dict[str, Any]:
  88. if self.model.__class__.__name__ == 'OfaForAllTasks':
  89. with torch.no_grad():
  90. return super().forward(inputs, **forward_params)
  91. return self.model(**inputs, **forward_params)
  92. def postprocess(self,
  93. inputs: Union[Dict[str, Any],
  94. TextClassificationModelOutput],
  95. topk: int = None) -> Dict[str, Any]:
  96. """Process the prediction results
  97. Args:
  98. inputs (`Dict[str, Any]` or `TextClassificationModelOutput`): The model output, please check
  99. the `TextClassificationModelOutput` class for details.
  100. topk (int): The topk probs to take
  101. Returns:
  102. Dict[str, Any]: the prediction results.
  103. scores: The probabilities of each label.
  104. labels: The real labels.
  105. Label at index 0 is the smallest probability.
  106. """
  107. if self.model.__class__.__name__ == 'OfaForAllTasks':
  108. return inputs
  109. else:
  110. if getattr(self, 'id2label', None) is None:
  111. logger.warning(
  112. 'The id2label mapping is None, will return original ids.')
  113. logits = inputs[OutputKeys.LOGITS].cpu().numpy()
  114. if logits.shape[0] == 1:
  115. logits = logits[0]
  116. def softmax(logits):
  117. exp = np.exp(logits - np.max(logits, axis=-1, keepdims=True))
  118. return exp / exp.sum(axis=-1, keepdims=True)
  119. probs = softmax(logits)
  120. num_classes = probs.shape[-1]
  121. topk = min(topk, num_classes) if topk is not None else num_classes
  122. top_indices = np.argpartition(probs, -topk)[-topk:]
  123. probs = np.take_along_axis(probs, top_indices, axis=-1).tolist()
  124. def map_to_label(id):
  125. if getattr(self, 'id2label', None) is not None:
  126. if id in self.id2label:
  127. return self.id2label[id]
  128. elif str(id) in self.id2label:
  129. return self.id2label[str(id)]
  130. else:
  131. raise Exception(
  132. f'id {id} not found in id2label: {self.id2label}')
  133. else:
  134. return id
  135. v_func = np.vectorize(map_to_label)
  136. top_indices = v_func(top_indices).tolist()
  137. probs = list(reversed(probs))
  138. top_indices = list(reversed(top_indices))
  139. return {OutputKeys.SCORES: probs, OutputKeys.LABELS: top_indices}