# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from datasets import Dataset

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import \
    DocumentSegmentationTransformersPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger
  16. logger = get_logger()
  17. __all__ = ['DocumentSegmentationPipeline']
  18. @PIPELINES.register_module(
  19. Tasks.document_segmentation, module_name=Pipelines.document_segmentation)
  20. class DocumentSegmentationPipeline(Pipeline):
  21. def __init__(
  22. self,
  23. model: Union[Model, str],
  24. preprocessor: DocumentSegmentationTransformersPreprocessor = None,
  25. config_file: str = None,
  26. device: str = 'gpu',
  27. auto_collate=True,
  28. **kwargs):
  29. """The document segmentation pipeline.
  30. Args:
  31. model (str or Model): Supply either a local model dir or a model id from the model hub
  32. preprocessor (Preprocessor): An optional preprocessor instance, please make sure the preprocessor fits for
  33. the model if supplied.
  34. """
  35. super().__init__(
  36. model=model,
  37. preprocessor=preprocessor,
  38. config_file=config_file,
  39. device=device,
  40. auto_collate=auto_collate,
  41. **kwargs)
  42. kwargs.pop('compile', None)
  43. kwargs.pop('compile_options', None)
  44. self.model_dir = self.model.model_dir
  45. self.model_cfg = self.model.model_cfg
  46. if preprocessor is None:
  47. self.preprocessor = DocumentSegmentationTransformersPreprocessor(
  48. self.model_dir, self.model.config.max_position_embeddings,
  49. **kwargs)
  50. def __call__(
  51. self, documents: Union[List[List[str]], List[str],
  52. str]) -> Dict[str, Any]:
  53. output = self.predict(documents)
  54. output = self.postprocess(output)
  55. return output
  56. def predict(
  57. self, documents: Union[List[List[str]], List[str],
  58. str]) -> Dict[str, Any]:
  59. pred_samples = self.cut_documents(documents)
  60. if self.model_cfg['level'] == 'topic':
  61. paragraphs = pred_samples.pop('paragraphs')
  62. predict_examples = Dataset.from_dict(pred_samples)
  63. # Predict Feature Creation
  64. predict_dataset = self.preprocessor(predict_examples, self.model_cfg)
  65. num_examples = len(
  66. predict_examples[self.preprocessor.context_column_name])
  67. num_samples = len(
  68. predict_dataset[self.preprocessor.context_column_name])
  69. if self.model_cfg['type'] == 'bert':
  70. predict_dataset.pop('segment_ids')
  71. labels = predict_dataset.pop('labels')
  72. sentences = predict_dataset.pop('sentences')
  73. example_ids = predict_dataset.pop(
  74. self.preprocessor.example_id_column_name)
  75. if (self.model or (self.has_multiple_models and self.models[0])):
  76. if not self._model_prepare:
  77. self.prepare_model()
  78. with torch.no_grad():
  79. input = {
  80. key: torch.tensor(val).to(self.device)
  81. for key, val in predict_dataset.items()
  82. }
  83. predictions = self.model.forward(**input).logits.cpu()
  84. predictions = np.argmax(predictions, axis=2)
  85. assert len(sentences) == len(
  86. predictions), 'sample {} infer_sample {} prediction {}'.format(
  87. num_samples, len(sentences), len(predictions))
  88. # Remove ignored index (special tokens)
  89. true_predictions = [
  90. [
  91. self.preprocessor.label_list[p]
  92. for (p, l) in zip(prediction, label) if l != -100 # noqa *
  93. ] for prediction, label in zip(predictions, labels)
  94. ]
  95. true_labels = [
  96. [
  97. self.preprocessor.label_list[l]
  98. for (p, l) in zip(prediction, label) if l != -100 # noqa *
  99. ] for prediction, label in zip(predictions, labels)
  100. ]
  101. # Save predictions
  102. out = []
  103. for i in range(num_examples):
  104. if self.model_cfg['level'] == 'topic':
  105. out.append({
  106. 'sentences': [],
  107. 'labels': [],
  108. 'predictions': [],
  109. 'paragraphs': paragraphs[i]
  110. })
  111. else:
  112. out.append({'sentences': [], 'labels': [], 'predictions': []})
  113. for prediction, sentence_list, label, example_id in zip(
  114. true_predictions, sentences, true_labels, example_ids):
  115. if self.model_cfg['level'] == 'doc':
  116. if len(label) < len(sentence_list):
  117. label.append('B-EOP')
  118. prediction.append('B-EOP')
  119. assert len(sentence_list) == len(prediction), '{} {}'.format(
  120. len(sentence_list), len(prediction))
  121. assert len(sentence_list) == len(label), '{} {}'.format(
  122. len(sentence_list), len(label))
  123. out[example_id]['sentences'].extend(sentence_list)
  124. out[example_id]['labels'].extend(label)
  125. out[example_id]['predictions'].extend(prediction)
  126. if self.model_cfg['level'] == 'topic':
  127. for i in range(num_examples):
  128. assert len(out[i]['predictions']) + 1 == len(
  129. out[i]['paragraphs'])
  130. out[i]['predictions'].append('B-EOP')
  131. out[i]['labels'].append('B-EOP')
  132. return out
  133. def postprocess(self, inputs: Dict[str, Tensor]) -> Dict[str, Tensor]:
  134. """process the prediction results
  135. Args:
  136. inputs (Dict[str, Any]): _description_
  137. Returns:
  138. Dict[str, str]: the prediction results
  139. """
  140. result = []
  141. res_preds = []
  142. list_count = len(inputs)
  143. if self.model_cfg['level'] == 'topic':
  144. for num in range(list_count):
  145. res = []
  146. pred = []
  147. for s, p, l in zip(inputs[num]['paragraphs'],
  148. inputs[num]['predictions'],
  149. inputs[num]['labels']):
  150. s = s.strip()
  151. if p == 'B-EOP':
  152. s = ''.join([s, '\n\n\t'])
  153. pred.append(1)
  154. else:
  155. s = ''.join([s, '\n\t'])
  156. pred.append(0)
  157. res.append(s)
  158. res_preds.append(pred)
  159. document = ('\t' + ''.join(res).strip())
  160. result.append(document)
  161. else:
  162. for num in range(list_count):
  163. res = []
  164. for s, p in zip(inputs[num]['sentences'],
  165. inputs[num]['predictions']):
  166. s = s.strip()
  167. if p == 'B-EOP':
  168. s = ''.join([s, '\n\t'])
  169. res.append(s)
  170. document = ('\t' + ''.join(res))
  171. result.append(document)
  172. if list_count == 1:
  173. return {OutputKeys.TEXT: result[0]}
  174. else:
  175. return {OutputKeys.TEXT: result}
  176. def cut_documents(self, para: Union[List[List[str]], List[str], str]):
  177. document_list = para
  178. paragraphs = []
  179. sentences = []
  180. labels = []
  181. example_id = []
  182. id = 0
  183. if self.model_cfg['level'] == 'topic':
  184. if isinstance(para, str):
  185. document_list = [[para]]
  186. elif isinstance(para[0], str):
  187. document_list = [para]
  188. for document in document_list:
  189. sentence = []
  190. label = []
  191. for item in document:
  192. sentence_of_current_paragraph = self.cut_sentence(item)
  193. sentence.extend(sentence_of_current_paragraph)
  194. label.extend(['-100']
  195. * (len(sentence_of_current_paragraph) - 1)
  196. + ['B-EOP'])
  197. paragraphs.append(document)
  198. sentences.append(sentence)
  199. labels.append(label)
  200. example_id.append(id)
  201. id += 1
  202. return {
  203. 'example_id': example_id,
  204. 'sentences': sentences,
  205. 'paragraphs': paragraphs,
  206. 'labels': labels
  207. }
  208. else:
  209. if isinstance(para, str):
  210. document_list = [para]
  211. for document in document_list:
  212. sentence = self.cut_sentence(document)
  213. label = ['O'] * (len(sentence) - 1) + ['B-EOP']
  214. sentences.append(sentence)
  215. labels.append(label)
  216. example_id.append(id)
  217. id += 1
  218. return {
  219. 'example_id': example_id,
  220. 'sentences': sentences,
  221. 'labels': labels
  222. }
  223. def cut_sentence(self, para):
  224. para = re.sub(r'([。!.!?\?])([^”’])', r'\1\n\2', para) # noqa *
  225. para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para) # noqa *
  226. para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para) # noqa *
  227. para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para) # noqa *
  228. para = para.rstrip()
  229. return [_ for _ in para.split('\n') if _]