extractive_summarization_pipeline.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import re
  3. from typing import Any, Dict, List, Union
  4. import numpy as np
  5. import torch
  6. from datasets import Dataset
  7. from modelscope.metainfo import Pipelines
  8. from modelscope.models import Model
  9. from modelscope.outputs import OutputKeys
  10. from modelscope.pipelines.base import Pipeline, Tensor
  11. from modelscope.pipelines.builder import PIPELINES
  12. from modelscope.preprocessors import \
  13. DocumentSegmentationTransformersPreprocessor
  14. from modelscope.utils.constant import Tasks
  15. from modelscope.utils.logger import get_logger
  16. logger = get_logger()
  17. __all__ = ['ExtractiveSummarizationPipeline']
  18. @PIPELINES.register_module(
  19. Tasks.extractive_summarization,
  20. module_name=Pipelines.extractive_summarization)
  21. class ExtractiveSummarizationPipeline(Pipeline):
  22. def __init__(
  23. self,
  24. model: Union[Model, str],
  25. preprocessor: DocumentSegmentationTransformersPreprocessor = None,
  26. config_file: str = None,
  27. device: str = 'gpu',
  28. auto_collate=True,
  29. **kwargs):
  30. super().__init__(
  31. model=model,
  32. preprocessor=preprocessor,
  33. config_file=config_file,
  34. device=device,
  35. auto_collate=auto_collate,
  36. **kwargs)
  37. kwargs.pop('compile', None)
  38. kwargs.pop('compile_options', None)
  39. self.model_dir = self.model.model_dir
  40. self.model_cfg = self.model.model_cfg
  41. if preprocessor is None:
  42. self.preprocessor = DocumentSegmentationTransformersPreprocessor(
  43. self.model_dir, self.model.config.max_position_embeddings,
  44. **kwargs)
  45. def __call__(self, documents: Union[List[str], str]) -> Dict[str, Any]:
  46. output = self.predict(documents)
  47. output = self.postprocess(output)
  48. return output
  49. def predict(self, documents: Union[List[str], str]) -> Dict[str, Any]:
  50. pred_samples = self.cut_documents(documents)
  51. predict_examples = Dataset.from_dict(pred_samples)
  52. # Predict Feature Creation
  53. predict_dataset = self.preprocessor(predict_examples, self.model_cfg)
  54. num_examples = len(
  55. predict_examples[self.preprocessor.context_column_name])
  56. num_samples = len(
  57. predict_dataset[self.preprocessor.context_column_name])
  58. labels = predict_dataset.pop('labels')
  59. sentences = predict_dataset.pop('sentences')
  60. example_ids = predict_dataset.pop(
  61. self.preprocessor.example_id_column_name)
  62. with torch.no_grad():
  63. input = {
  64. key: torch.tensor(val)
  65. for key, val in predict_dataset.items()
  66. }
  67. logits = self.model.forward(**input).logits
  68. predictions = np.argmax(logits, axis=2)
  69. assert len(sentences) == len(
  70. predictions), 'sample {} infer_sample {} prediction {}'.format(
  71. num_samples, len(sentences), len(predictions))
  72. # Remove ignored index (special tokens)
  73. true_predictions = [
  74. [
  75. self.preprocessor.label_list[p]
  76. for (p, l) in zip(prediction, label) if l != -100 # noqa *
  77. ] for prediction, label in zip(predictions, labels)
  78. ]
  79. true_labels = [
  80. [
  81. self.preprocessor.label_list[l]
  82. for (p, l) in zip(prediction, label) if l != -100 # noqa *
  83. ] for prediction, label in zip(predictions, labels)
  84. ]
  85. # Save predictions
  86. out = []
  87. for i in range(num_examples):
  88. out.append({'sentences': [], 'labels': [], 'predictions': []})
  89. for prediction, sentence_list, label, example_id in zip(
  90. true_predictions, sentences, true_labels, example_ids):
  91. if len(label) < len(sentence_list):
  92. label.append('O')
  93. prediction.append('O')
  94. assert len(sentence_list) == len(prediction), '{} {}'.format(
  95. len(sentence_list), len(prediction))
  96. assert len(sentence_list) == len(label), '{} {}'.format(
  97. len(sentence_list), len(label))
  98. out[example_id]['sentences'].extend(sentence_list)
  99. out[example_id]['labels'].extend(label)
  100. out[example_id]['predictions'].extend(prediction)
  101. return out
  102. def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
  103. """process the prediction results
  104. Args:
  105. inputs (Dict[str, Any]): _description_
  106. Returns:
  107. Dict[str, str]: the prediction results
  108. """
  109. result = []
  110. list_count = len(inputs)
  111. for num in range(list_count):
  112. res = []
  113. for s, p in zip(inputs[num]['sentences'],
  114. inputs[num]['predictions']):
  115. s = s.strip()
  116. if p == 'B-EOP':
  117. res.append(s)
  118. result.append('\n'.join(res))
  119. if list_count == 1:
  120. return {OutputKeys.TEXT: result[0]}
  121. else:
  122. return {OutputKeys.TEXT: result}
  123. def cut_documents(self, para: Union[List[str], str]):
  124. if isinstance(para, str):
  125. document_list = [para]
  126. else:
  127. document_list = para
  128. sentences = []
  129. labels = []
  130. example_id = []
  131. id = 0
  132. for document in document_list:
  133. sentence = self.cut_sentence(document)
  134. label = ['O'] * (len(sentence) - 1) + ['B-EOP']
  135. sentences.append(sentence)
  136. labels.append(label)
  137. example_id.append(id)
  138. id += 1
  139. return {
  140. 'example_id': example_id,
  141. 'sentences': sentences,
  142. 'labels': labels
  143. }
  144. def cut_sentence(self, para):
  145. para = re.sub(r'([。!.!?\?])([^”’])', r'\1\n\2', para) # noqa *
  146. para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para) # noqa *
  147. para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para) # noqa *
  148. para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para) # noqa *
  149. para = para.rstrip()
  150. return [_ for _ in para.split('\n') if _]