| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import re
- from typing import Any, Dict, List, Union
- import numpy as np
- import torch
- from datasets import Dataset
- from modelscope.metainfo import Pipelines
- from modelscope.models import Model
- from modelscope.outputs import OutputKeys
- from modelscope.pipelines.base import Pipeline, Tensor
- from modelscope.pipelines.builder import PIPELINES
- from modelscope.preprocessors import \
- DocumentSegmentationTransformersPreprocessor
- from modelscope.utils.constant import Tasks
- from modelscope.utils.logger import get_logger
# Names exported by ``from <module> import *``.
__all__ = ['ExtractiveSummarizationPipeline']

# Module-level logger obtained from modelscope's logging utility.
logger = get_logger()
@PIPELINES.register_module(
    Tasks.extractive_summarization,
    module_name=Pipelines.extractive_summarization)
class ExtractiveSummarizationPipeline(Pipeline):
    """Extractive summarization pipeline.

    Each input document is split into sentences, a label is predicted for
    every sentence, and the sentences predicted as ``'B-EOP'`` are joined
    (one per line) to form the summary text.
    """

    def __init__(
            self,
            model: Union[Model, str],
            preprocessor: DocumentSegmentationTransformersPreprocessor = None,
            config_file: str = None,
            device: str = 'gpu',
            auto_collate=True,
            **kwargs):
        """Build the pipeline.

        Args:
            model: A model instance or a model id / local directory.
            preprocessor: Optional preprocessor; when ``None`` a
                ``DocumentSegmentationTransformersPreprocessor`` is created
                from the model directory and the model's
                ``max_position_embeddings``.
            config_file: Optional configuration file path.
            device: Device name passed to the base ``Pipeline``.
            auto_collate: Passed through to the base ``Pipeline``.
            **kwargs: Forwarded to the base ``Pipeline`` and, except for
                ``compile``/``compile_options``, to the default preprocessor.
        """
        super().__init__(
            model=model,
            preprocessor=preprocessor,
            config_file=config_file,
            device=device,
            auto_collate=auto_collate,
            **kwargs)
        # Presumably these options are meant only for the base Pipeline;
        # drop them so they are not forwarded to the preprocessor below.
        kwargs.pop('compile', None)
        kwargs.pop('compile_options', None)
        self.model_dir = self.model.model_dir
        self.model_cfg = self.model.model_cfg
        if preprocessor is None:
            self.preprocessor = DocumentSegmentationTransformersPreprocessor(
                self.model_dir, self.model.config.max_position_embeddings,
                **kwargs)

    def __call__(self, documents: Union[List[str], str]) -> Dict[str, Any]:
        """Summarize one document or a list of documents.

        Args:
            documents: A single document string or a list of them.

        Returns:
            ``{OutputKeys.TEXT: summary}`` where ``summary`` is a string when
            :meth:`predict` yields exactly one entry, otherwise a list of
            strings (one per entry).
        """
        return self.postprocess(self.predict(documents))

    def predict(
            self,
            documents: Union[List[str], str]) -> List[Dict[str, List[str]]]:
        """Run the model and align its per-sentence labels with the input.

        Args:
            documents: A single document string or a list of them.

        Returns:
            One dict per input example, each holding parallel lists under the
            keys ``'sentences'``, ``'labels'`` and ``'predictions'``.

        Raises:
            AssertionError: If the model output cannot be aligned with the
                sentences produced by the preprocessor.
        """
        pred_samples = self.cut_documents(documents)
        predict_examples = Dataset.from_dict(pred_samples)

        # Predict feature creation.
        predict_dataset = self.preprocessor(predict_examples, self.model_cfg)
        num_examples = len(
            predict_examples[self.preprocessor.context_column_name])
        num_samples = len(
            predict_dataset[self.preprocessor.context_column_name])

        labels = predict_dataset.pop('labels')
        sentences = predict_dataset.pop('sentences')
        example_ids = predict_dataset.pop(
            self.preprocessor.example_id_column_name)

        with torch.no_grad():
            model_inputs = {
                key: torch.tensor(val)
                for key, val in predict_dataset.items()
            }
            logits = self.model.forward(**model_inputs).logits

        # np.argmax falls back to np.asarray for torch tensors, which fails
        # for tensors on an accelerator; move torch output to host first.
        if isinstance(logits, torch.Tensor):
            logits = logits.detach().cpu().numpy()
        predictions = np.argmax(logits, axis=2)

        assert len(sentences) == len(predictions), (
            'sample {} infer_sample {} prediction {}'.format(
                num_samples, len(sentences), len(predictions)))

        # Remove ignored index (special tokens): positions whose reference
        # label is -100 are dropped from both predictions and labels.
        label_list = self.preprocessor.label_list
        true_predictions = []
        true_labels = []
        for prediction, label in zip(predictions, labels):
            kept = [(pred_id, label_id)
                    for pred_id, label_id in zip(prediction, label)
                    if label_id != -100]
            true_predictions.append(
                [label_list[pred_id] for pred_id, _ in kept])
            true_labels.append([label_list[label_id] for _, label_id in kept])

        # Several model samples can share an example_id; their sentences and
        # labels are concatenated, in order, into that example's entry.
        out = [{
            'sentences': [],
            'labels': [],
            'predictions': []
        } for _ in range(num_examples)]
        for prediction, sentence_list, label, example_id in zip(
                true_predictions, sentences, true_labels, example_ids):
            # NOTE(review): pads at most ONE missing trailing position with
            # 'O'; presumably the last sentence of a sample can lose its
            # label (e.g. to truncation) - confirm against the preprocessor.
            if len(label) < len(sentence_list):
                label.append('O')
                prediction.append('O')
            assert len(sentence_list) == len(prediction), '{} {}'.format(
                len(sentence_list), len(prediction))
            assert len(sentence_list) == len(label), '{} {}'.format(
                len(sentence_list), len(label))
            entry = out[example_id]
            entry['sentences'].extend(sentence_list)
            entry['labels'].extend(label)
            entry['predictions'].extend(prediction)
        return out

    def postprocess(self, inputs: List[Dict[str, Any]]) -> Dict[str, Any]:
        """Turn aligned predictions into summary text.

        Args:
            inputs: The list returned by :meth:`predict`; each entry holds
                parallel ``'sentences'`` and ``'predictions'`` lists.

        Returns:
            ``{OutputKeys.TEXT: text}``. Each summary is the stripped
            sentences predicted as ``'B-EOP'`` joined with ``'\\n'``.
            ``text`` is a single string when ``inputs`` has exactly one
            entry, and a list of strings otherwise.
        """
        summaries = [
            '\n'.join(
                sentence.strip()
                for sentence, pred in zip(item['sentences'],
                                          item['predictions'])
                if pred == 'B-EOP') for item in inputs
        ]
        if len(summaries) == 1:
            return {OutputKeys.TEXT: summaries[0]}
        return {OutputKeys.TEXT: summaries}

    def cut_documents(self, para: Union[List[str], str]) -> Dict[str, list]:
        """Split each document into sentences and build the input columns.

        Args:
            para: A single document string or a list of them.

        Returns:
            A column dict with ``'example_id'`` (the document's index),
            ``'sentences'`` (each document's sentence list) and ``'labels'``
            (``'O'`` for every sentence except the last, which gets
            ``'B-EOP'``).
        """
        document_list = [para] if isinstance(para, str) else para
        sentences = []
        labels = []
        example_id = []
        for doc_index, document in enumerate(document_list):
            sentence = self.cut_sentence(document)
            # NOTE(review): a document with no sentences still gets
            # ['B-EOP'], i.e. one more label than sentences.
            label = ['O'] * (len(sentence) - 1) + ['B-EOP']
            sentences.append(sentence)
            labels.append(label)
            example_id.append(doc_index)
        return {
            'example_id': example_id,
            'sentences': sentences,
            'labels': labels
        }

    def cut_sentence(self, para: str) -> List[str]:
        """Split ``para`` into a list of non-empty sentence strings.

        A newline is inserted after sentence-ending punctuation that is not
        followed by a closing quote, and after punctuation-plus-closing-quote
        pairs that are not followed by more punctuation. The text is then
        right-stripped and split on newlines, and empty pieces are dropped.
        """
        para = re.sub(r'([。!.!?\?])([^”’])', r'\1\n\2', para)  # noqa *
        para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para)  # noqa *
        para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para)  # noqa *
        para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)  # noqa *
        para = para.rstrip()
        return [piece for piece in para.split('\n') if piece]
|