# Copyright (c) Alibaba, Inc. and its affiliates.
import re
from typing import Any, Dict, List, Union

import numpy as np
import torch
from datasets import Dataset

from modelscope.metainfo import Pipelines
from modelscope.models import Model
from modelscope.outputs import OutputKeys
from modelscope.pipelines.base import Pipeline, Tensor
from modelscope.pipelines.builder import PIPELINES
from modelscope.preprocessors import \
    DocumentSegmentationTransformersPreprocessor
from modelscope.utils.constant import Tasks
from modelscope.utils.logger import get_logger

logger = get_logger()

__all__ = ['ExtractiveSummarizationPipeline']


@PIPELINES.register_module(
    Tasks.extractive_summarization,
    module_name=Pipelines.extractive_summarization)
class ExtractiveSummarizationPipeline(Pipeline):
    """Pipeline that selects summary sentences from one or more documents.

    Each document is split into sentences, tagged sentence-by-sentence by
    the model, and the sentences predicted as 'B-EOP' are joined with
    newlines to form the extractive summary.
    """

    def __init__(
            self,
            model: Union[Model, str],
            preprocessor: DocumentSegmentationTransformersPreprocessor = None,
            config_file: str = None,
            device: str = 'gpu',
            auto_collate=True,
            **kwargs):
        """Build the pipeline.

        Args:
            model: A model instance or a model id / local path.
            preprocessor: Optional preprocessor. When None, a
                DocumentSegmentationTransformersPreprocessor is built from
                the model directory and the model's
                ``max_position_embeddings``.
            config_file: Optional configuration file path.
            device: Device name forwarded to the base Pipeline.
            auto_collate: Forwarded to the base Pipeline.
            **kwargs: Forwarded to the base Pipeline and, when a default
                preprocessor is built, to that preprocessor.
        """
        super().__init__(
            model=model,
            preprocessor=preprocessor,
            config_file=config_file,
            device=device,
            auto_collate=auto_collate,
            **kwargs)

        # These options are consumed by the base Pipeline; drop them so they
        # are not forwarded to the default preprocessor below.
        kwargs.pop('compile', None)
        kwargs.pop('compile_options', None)

        self.model_dir = self.model.model_dir
        self.model_cfg = self.model.model_cfg

        if preprocessor is None:
            self.preprocessor = DocumentSegmentationTransformersPreprocessor(
                self.model_dir, self.model.config.max_position_embeddings,
                **kwargs)

    def __call__(self, documents: Union[List[str], str]) -> Dict[str, Any]:
        """Summarize ``documents`` (a single string or a list of strings)."""
        output = self.predict(documents)
        output = self.postprocess(output)
        return output

    def predict(
            self, documents: Union[List[str],
                                   str]) -> List[Dict[str, List[str]]]:
        """Run sentence tagging over the documents.

        Args:
            documents: A single document or a list of documents.

        Returns:
            One dict per input document, each with aligned lists under
            'sentences', 'labels' and 'predictions'.
        """
        pred_samples = self.cut_documents(documents)
        predict_examples = Dataset.from_dict(pred_samples)

        # Predict Feature Creation
        predict_dataset = self.preprocessor(predict_examples, self.model_cfg)
        num_examples = len(
            predict_examples[self.preprocessor.context_column_name])
        num_samples = len(
            predict_dataset[self.preprocessor.context_column_name])

        # Strip the bookkeeping columns; what remains is fed to the model.
        labels = predict_dataset.pop('labels')
        sentences = predict_dataset.pop('sentences')
        example_ids = predict_dataset.pop(
            self.preprocessor.example_id_column_name)

        with torch.no_grad():
            model_inputs = {
                key: torch.tensor(val)
                for key, val in predict_dataset.items()
            }
            logits = self.model.forward(**model_inputs).logits

        # Move torch logits to host memory first: NumPy cannot convert a
        # tensor that lives on an accelerator device.
        if isinstance(logits, torch.Tensor):
            logits = logits.detach().cpu().numpy()
        predictions = np.argmax(logits, axis=2)

        assert len(sentences) == len(
            predictions), 'sample {} infer_sample {} prediction {}'.format(
                num_samples, len(sentences), len(predictions))

        # Remove ignored index (special tokens): keep only positions whose
        # reference label is not -100, and map ids back to label strings.
        label_list = self.preprocessor.label_list
        true_predictions = []
        true_labels = []
        for prediction, label in zip(predictions, labels):
            kept = [(pred_id, label_id)
                    for pred_id, label_id in zip(prediction, label)
                    if label_id != -100]
            true_predictions.append(
                [label_list[pred_id] for pred_id, _ in kept])
            true_labels.append([label_list[label_id] for _, label_id in kept])

        # Save predictions, regrouping model samples by source document.
        out = [{
            'sentences': [],
            'labels': [],
            'predictions': []
        } for _ in range(num_examples)]
        for prediction, sentence_list, label, example_id in zip(
                true_predictions, sentences, true_labels, example_ids):
            # A sample may carry one fewer tagged position than sentences;
            # pad that position with 'O' so the lists stay aligned.
            if len(label) < len(sentence_list):
                label.append('O')
                prediction.append('O')
            assert len(sentence_list) == len(prediction), '{} {}'.format(
                len(sentence_list), len(prediction))
            assert len(sentence_list) == len(label), '{} {}'.format(
                len(sentence_list), len(label))
            out[example_id]['sentences'].extend(sentence_list)
            out[example_id]['labels'].extend(label)
            out[example_id]['predictions'].extend(prediction)

        return out

    def postprocess(self, inputs: List[Dict[str, List[str]]]) -> Dict[str, Any]:
        """Turn per-document predictions into summary text.

        Args:
            inputs: The list returned by ``predict``: one dict per document
                with aligned 'sentences' and 'predictions' lists.

        Returns:
            ``{OutputKeys.TEXT: summary}``. ``summary`` is a single string
            when exactly one document was given; otherwise it is a list of
            strings, one per document. Each summary is the stripped
            sentences predicted as 'B-EOP', joined with newlines.
        """
        result = []
        for document in inputs:
            selected = [
                sentence.strip() for sentence, prediction in zip(
                    document['sentences'], document['predictions'])
                if prediction == 'B-EOP'
            ]
            result.append('\n'.join(selected))

        if len(result) == 1:
            return {OutputKeys.TEXT: result[0]}
        return {OutputKeys.TEXT: result}

    def cut_documents(self, para: Union[List[str],
                                        str]) -> Dict[str, List[Any]]:
        """Split documents into sentences and build placeholder labels.

        Args:
            para: A single document or a list of documents.

        Returns:
            A column dict with 'example_id' (document index), 'sentences'
            (per-document sentence lists) and 'labels' (one label per
            sentence: 'O' everywhere except 'B-EOP' on the last sentence).
        """
        document_list = [para] if isinstance(para, str) else para

        sentences = []
        labels = []
        example_id = []
        for doc_index, document in enumerate(document_list):
            sentence = self.cut_sentence(document)
            # Keep exactly one label per sentence; an empty document has no
            # sentences and therefore no labels (not a lone 'B-EOP').
            label = (['O'] * (len(sentence) - 1)
                     + ['B-EOP']) if sentence else []
            sentences.append(sentence)
            labels.append(label)
            example_id.append(doc_index)
        return {
            'example_id': example_id,
            'sentences': sentences,
            'labels': labels
        }

    def cut_sentence(self, para):
        """Split ``para`` into non-empty sentences on end punctuation.

        A newline is inserted after sentence-final punctuation, then the
        text is split on newlines and empty pieces are dropped.
        """
        para = re.sub(r'([。!.!?\?])([^”’])', r'\1\n\2', para)  # noqa *
        para = re.sub(r'(\.{6})([^”’])', r'\1\n\2', para)  # noqa *
        para = re.sub(r'(\…{2})([^”’])', r'\1\n\2', para)  # noqa *
        para = re.sub(r'([。!?\?][”’])([^,。!?\?])', r'\1\n\2', para)  # noqa *
        para = para.rstrip()
        return [piece for piece in para.split('\n') if piece]