token_classification_pipeline.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import math
  3. from typing import Any, Dict, List, Optional, Tuple, Union
  4. import numpy as np
  5. import torch
  6. from modelscope.metainfo import Pipelines
  7. from modelscope.models import Model
  8. from modelscope.outputs import OutputKeys
  9. from modelscope.pipelines.base import Input, Pipeline
  10. from modelscope.pipelines.builder import PIPELINES
  11. from modelscope.preprocessors import Preprocessor
  12. from modelscope.utils.constant import ModelFile, Tasks
  13. from modelscope.utils.tensor_utils import (torch_nested_detach,
  14. torch_nested_numpify)
  15. __all__ = ['TokenClassificationPipeline']
  16. @PIPELINES.register_module(
  17. Tasks.token_classification, module_name=Pipelines.token_classification)
  18. @PIPELINES.register_module(
  19. Tasks.token_classification, module_name=Pipelines.part_of_speech)
  20. @PIPELINES.register_module(
  21. Tasks.token_classification, module_name=Pipelines.word_segmentation)
  22. @PIPELINES.register_module(
  23. Tasks.token_classification, module_name=Pipelines.named_entity_recognition)
  24. @PIPELINES.register_module(
  25. Tasks.part_of_speech, module_name=Pipelines.part_of_speech)
  26. class TokenClassificationPipeline(Pipeline):
  27. def __init__(self,
  28. model: Union[Model, str],
  29. preprocessor: Optional[Preprocessor] = None,
  30. config_file: str = None,
  31. device: str = 'gpu',
  32. auto_collate=True,
  33. sequence_length=512,
  34. **kwargs):
  35. """use `model` and `preprocessor` to create a token classification pipeline for prediction
  36. Args:
  37. model (str or Model): A model instance or a model local dir or a model id in the model hub.
  38. preprocessor (Preprocessor): a preprocessor instance, must not be None.
  39. kwargs (dict, `optional`):
  40. Extra kwargs passed into the preprocessor's constructor.
  41. """
  42. super().__init__(
  43. model=model,
  44. preprocessor=preprocessor,
  45. config_file=config_file,
  46. device=device,
  47. auto_collate=auto_collate,
  48. compile=kwargs.pop('compile', False),
  49. compile_options=kwargs.pop('compile_options', {}))
  50. assert isinstance(self.model, Model), \
  51. f'please check whether model config exists in {ModelFile.CONFIGURATION}'
  52. if preprocessor is None:
  53. self.preprocessor = Preprocessor.from_pretrained(
  54. self.model.model_dir,
  55. sequence_length=sequence_length,
  56. **kwargs)
  57. self.model.eval()
  58. self.sequence_length = sequence_length
  59. assert hasattr(self.preprocessor, 'id2label')
  60. self.id2label = self.preprocessor.id2label
  61. def forward(self, inputs: Dict[str, Any],
  62. **forward_params) -> Dict[str, Any]:
  63. text = inputs.pop(OutputKeys.TEXT)
  64. with torch.no_grad():
  65. return {
  66. **self.model(**inputs, **forward_params), OutputKeys.TEXT: text
  67. }
  68. def postprocess(self, inputs: Dict[str, Any],
  69. **postprocess_params) -> Dict[str, Any]:
  70. """Process the prediction results
  71. Args:
  72. inputs (Dict[str, Any]): should be tensors from model
  73. Returns:
  74. Dict[str, Any]: the prediction results
  75. """
  76. chunks = self._chunk_process(inputs, **postprocess_params)
  77. return {OutputKeys.OUTPUT: chunks}
  78. def _chunk_process(self, inputs: Dict[str, Any],
  79. **postprocess_params) -> List:
  80. """process the prediction results and output as chunks
  81. Args:
  82. inputs (Dict[str, Any]): should be tensors from model
  83. Returns:
  84. List: The output chunks
  85. """
  86. text = inputs['text']
  87. # TODO post_process does not support batch for now.
  88. if OutputKeys.PREDICTIONS not in inputs:
  89. logits = inputs[OutputKeys.LOGITS]
  90. if len(logits.shape) == 3:
  91. logits = logits[0]
  92. predictions = torch.argmax(logits, dim=-1)
  93. else:
  94. predictions = inputs[OutputKeys.PREDICTIONS]
  95. if len(predictions.shape) == 2:
  96. predictions = predictions[0]
  97. offset_mapping = inputs['offset_mapping']
  98. if len(offset_mapping.shape) == 3:
  99. offset_mapping = offset_mapping[0]
  100. label_mask = inputs.get('label_mask')
  101. if label_mask is not None:
  102. masked_lengths = label_mask.sum(-1).long().cpu().item()
  103. offset_mapping = torch.narrow(
  104. offset_mapping, 0, 0,
  105. masked_lengths) # index_select only move loc, not resize
  106. if len(label_mask.shape) == 2:
  107. label_mask = label_mask[0]
  108. predictions = predictions.masked_select(label_mask)
  109. offset_mapping = torch_nested_numpify(
  110. torch_nested_detach(offset_mapping))
  111. predictions = torch_nested_numpify(torch_nested_detach(predictions))
  112. labels = [self.id2label[x] for x in predictions]
  113. return_prob = postprocess_params.pop('return_prob', True)
  114. if return_prob:
  115. if OutputKeys.LOGITS in inputs:
  116. logits = inputs[OutputKeys.LOGITS]
  117. if len(logits.shape) == 3:
  118. logits = logits[0]
  119. probs = torch_nested_numpify(
  120. torch_nested_detach(logits.softmax(-1)))
  121. else:
  122. return_prob = False
  123. chunks = []
  124. chunk = {}
  125. for i, (label, offsets) in enumerate(zip(labels, offset_mapping)):
  126. if label[0] in 'BS':
  127. if chunk:
  128. chunk['span'] = text[chunk['start']:chunk['end']]
  129. chunks.append(chunk)
  130. chunk = {
  131. 'type': label[2:],
  132. 'start': offsets[0],
  133. 'end': offsets[1]
  134. }
  135. if return_prob:
  136. chunk['prob'] = probs[i][predictions[i]]
  137. if label[0] in 'I':
  138. if not chunk:
  139. chunk = {
  140. 'type': label[2:],
  141. 'start': offsets[0],
  142. 'end': offsets[1]
  143. }
  144. if return_prob:
  145. chunk['prob'] = probs[i][predictions[i]]
  146. if label[0] in 'E':
  147. if not chunk:
  148. chunk = {
  149. 'type': label[2:],
  150. 'start': offsets[0],
  151. 'end': offsets[1]
  152. }
  153. if return_prob:
  154. chunk['prob'] = probs[i][predictions[i]]
  155. if label[0] in 'IES':
  156. if chunk:
  157. chunk['end'] = offsets[1]
  158. if label[0] in 'ES':
  159. if chunk:
  160. chunk['span'] = text[chunk['start']:chunk['end']]
  161. chunks.append(chunk)
  162. chunk = {}
  163. if chunk:
  164. chunk['span'] = text[chunk['start']:chunk['end']]
  165. chunks.append(chunk)
  166. return chunks
  167. def _process_single(self, input: Input, *args, **kwargs) -> Dict[str, Any]:
  168. split_max_length = kwargs.pop('split_max_length',
  169. 0) # default: no split
  170. if split_max_length <= 0:
  171. return super()._process_single(input, *args, **kwargs)
  172. else:
  173. split_texts, index_mapping = self._auto_split([input],
  174. split_max_length)
  175. outputs = []
  176. for text in split_texts:
  177. outputs.append(super()._process_single(text, *args, **kwargs))
  178. return self._auto_join(outputs, index_mapping)[0]
  179. def _process_batch(self, input: List[Input], batch_size: int, *args,
  180. **kwargs) -> List[Dict[str, Any]]:
  181. split_max_length = kwargs.pop('split_max_length',
  182. 0) # default: no split
  183. if split_max_length <= 0:
  184. return super()._process_batch(
  185. input, batch_size=batch_size, *args, **kwargs)
  186. else:
  187. split_texts, index_mapping = self._auto_split(
  188. input, split_max_length)
  189. outputs = super()._process_batch(
  190. split_texts, batch_size=batch_size, *args, **kwargs)
  191. return self._auto_join(outputs, index_mapping)
  192. def _auto_split(self, input_texts: List[str], split_max_length: int):
  193. split_texts = []
  194. index_mapping = {}
  195. new_idx = 0
  196. for raw_idx, text in enumerate(input_texts):
  197. if len(text) < split_max_length:
  198. split_texts.append(text)
  199. index_mapping[new_idx] = (raw_idx, 0)
  200. new_idx += 1
  201. else:
  202. n_split = math.ceil(len(text) / split_max_length)
  203. for i in range(n_split):
  204. offset = i * split_max_length
  205. split_texts.append(text[offset:offset + split_max_length])
  206. index_mapping[new_idx] = (raw_idx, offset)
  207. new_idx += 1
  208. return split_texts, index_mapping
  209. def _auto_join(
  210. self, outputs: List[Dict[str, Any]],
  211. index_mapping: Dict[int, Tuple[int, int]]) -> List[Dict[str, Any]]:
  212. joined_outputs = []
  213. for idx, output in enumerate(outputs):
  214. raw_idx, offset = index_mapping[idx]
  215. if raw_idx >= len(joined_outputs):
  216. joined_outputs.append(output)
  217. else:
  218. for chunk in output[OutputKeys.OUTPUT]:
  219. chunk['start'] += offset
  220. chunk['end'] += offset
  221. joined_outputs[raw_idx][OutputKeys.OUTPUT].append(chunk)
  222. return joined_outputs