| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from copy import deepcopy
- import torch
- from torch import nn
- from modelscope.metainfo import Models
- from modelscope.models.base import TorchModel
- from modelscope.models.builder import MODELS
- from modelscope.utils.constant import ModelFile, Tasks
- from .backbone import BertEncoder, BertModel, BertPreTrainedModel
- __all__ = ['SiameseUieModel']
@MODELS.register_module(Tasks.siamese_uie, module_name=Models.bert)
class SiameseUieModel(BertPreTrainedModel):
    r"""SiameseUIE universal information extraction model.

    The model combines a prompt (hint) with a text and uses a pointer
    network to perform span extraction, which in turn supports named
    entity recognition (NER), relation extraction (RE), event extraction
    (EE), attribute sentiment extraction (ABSA) and similar tasks.

    Text and prompt are encoded separately by the shared ``plm`` encoder
    and then fused by the ``crossattention`` layers, which run
    self-attention over the concatenation of both sequences.
    """

    def __init__(self, config, **kwargs):
        """Build the encoder, the span classifiers and the fusion layers.

        Args:
            config: model configuration. ``hidden_dropout_prob``,
                ``hidden_size`` and ``num_hidden_layers`` are read here.
                Note that ``num_hidden_layers`` is reduced in place by
                ``set_crossattention_layer``.
            **kwargs: accepted for call compatibility; not used here.
        """
        super().__init__(config)
        self.config = config
        self.plm = BertModel(self.config, add_pooling_layer=True)
        # Kept as a registered sub-module; none of the methods of this
        # class apply it.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        # One scalar score per token for span start / span end.
        self.head_clsf = nn.Linear(config.hidden_size, 1)
        self.tail_clsf = nn.Linear(config.hidden_size, 1)
        self.set_crossattention_layer()

    def set_crossattention_layer(self, num_hidden_layers=6):
        """Move the last encoder layers of ``plm`` into ``crossattention``.

        The last ``num_hidden_layers`` layers of ``self.plm.encoder`` are
        re-homed in a new ``BertEncoder`` stored as
        ``self.crossattention``; ``self.plm.encoder`` keeps the remaining
        leading layers.

        Side effect: ``self.config.num_hidden_layers`` is decreased in
        place by ``num_hidden_layers``. ``self.config`` is the very object
        passed to ``__init__``, so the caller's config changes as well.

        Args:
            num_hidden_layers (int): number of trailing layers to move.
        """
        crossattention_config = deepcopy(self.config)
        crossattention_config.num_hidden_layers = num_hidden_layers
        self.config.num_hidden_layers -= num_hidden_layers
        self.crossattention = BertEncoder(crossattention_config)
        # The layers created by BertEncoder above are discarded; the
        # trailing layers of the text encoder take their place.
        split_index = self.config.num_hidden_layers
        self.crossattention.layer = self.plm.encoder.layer[split_index:]
        self.plm.encoder.layer = self.plm.encoder.layer[:split_index]

    def circle_loss(self, y_pred, y_true):
        """Log-sum-exp loss over the positive and negative positions.

        Per sample this evaluates
        ``log(1 + sum_neg exp(s)) + log(1 + sum_pos exp(-s))`` and returns
        the mean over the batch.

        Args:
            y_pred (Tensor): scores; flattened to (batch size, -1).
            y_true (Tensor): labels with the same number of elements as
                ``y_pred``. Assumed to hold only 0 and 1 -- the masking
                below relies on it (TODO confirm against the data
                pipeline).

        Returns:
            Tensor: scalar loss.
        """
        batch_size = y_true.size(0)
        y_true = y_true.view(batch_size, -1)
        y_pred = y_pred.view(batch_size, -1)
        # Flip the sign of positive scores so that both terms below can
        # be written as a plain logsumexp.
        y_pred = (1 - 2 * y_true) * y_pred
        # Subtracting a huge constant removes the other class from each
        # term (its exp() underflows to zero).
        y_pred_neg = y_pred - y_true * 1e12
        y_pred_pos = y_pred - (1 - y_true) * 1e12
        # The appended zero contributes exp(0) = 1, i.e. the "1 +" term.
        zeros = torch.zeros_like(y_pred[:, :1])
        y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
        y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
        neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
        pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
        return (neg_loss + pos_loss).mean()

    def get_cross_attention_output(self, hidden_states, attention_mask,
                                   encoder_hidden_states,
                                   encoder_attention_mask):
        """Fuse two sequences and return the states of the first one.

        Both sequences are concatenated along dim 1, passed through
        ``self.crossattention`` and the output is cut back to the length
        of ``hidden_states``.

        Args:
            hidden_states (Tensor): states of the first sequence.
            attention_mask (Tensor): mask of the first sequence.
            encoder_hidden_states (Tensor): states of the second sequence.
            encoder_attention_mask (Tensor): mask of the second sequence.

        Returns:
            Tensor: fused states with the sequence length of
                ``hidden_states``.
        """
        first_length = hidden_states.size()[1]
        cat_hidden_states = torch.cat([hidden_states, encoder_hidden_states],
                                      dim=1)
        cat_attention_mask = torch.cat(
            [attention_mask, encoder_attention_mask], dim=1)
        cat_attention_mask = self.plm.get_extended_attention_mask(
            cat_attention_mask,
            cat_hidden_states.size()[:2])
        fused_states = self.crossattention(
            hidden_states=cat_hidden_states,
            attention_mask=cat_attention_mask)[0]
        return fused_states[:, :first_length, :]

    def get_plm_sequence_output(self,
                                input_ids,
                                attention_mask,
                                position_ids=None,
                                is_hint=False):
        """Encode a token sequence with ``self.plm``.

        Args:
            input_ids (Tensor): token ids.
            attention_mask (Tensor): attention mask of ``input_ids``.
            position_ids (Tensor, optional): explicit position ids that
                are forwarded to ``self.plm``.
            is_hint (bool): when True, all-ones token type ids are passed
                to the encoder; otherwise ``None`` is passed.

        Returns:
            Tensor: the first element of the encoder output.
        """
        token_type_ids = torch.ones_like(attention_mask) if is_hint else None
        sequence_output = self.plm(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids)[0]
        return sequence_output

    def _span_logits(self, sequence_output, attention_masks, hint_ids,
                     cross_attention_masks, text_length):
        """Encode the prompt, fuse it with the text, score every token.

        Shared by ``forward`` and ``fast_inference``.

        Args:
            sequence_output (Tensor): encoded text.
            attention_masks (Tensor): attention mask of the text.
            hint_ids (Tensor): token ids of the prompt.
            cross_attention_masks (Tensor): attention mask of the prompt.
            text_length (int): offset added to the prompt position ids so
                that they continue after the text positions.

        Returns:
            Tuple[Tensor, Tensor]: head and tail logits with the trailing
                size-1 dimension squeezed away.
        """
        # Built directly on the target device, shape (1, prompt length).
        position_ids = torch.arange(
            hint_ids.size(1),
            device=sequence_output.device).unsqueeze(0) + text_length
        hint_sequence_output = self.get_plm_sequence_output(
            hint_ids, cross_attention_masks, position_ids, is_hint=True)
        fused_output = self.get_cross_attention_output(
            sequence_output, attention_masks, hint_sequence_output,
            cross_attention_masks)
        head_logits = self.head_clsf(fused_output).squeeze(-1)
        tail_logits = self.tail_clsf(fused_output).squeeze(-1)
        return head_logits, tail_logits

    def forward(self, input_ids, attention_masks, hint_ids,
                cross_attention_masks, head_labels, tail_labels):
        """train forward

        Args:
            input_ids (Tensor): input token ids of text.
            attention_masks (Tensor): attention_masks of text.
            hint_ids (Tensor): input token ids of prompt.
            cross_attention_masks (Tensor): attention_masks of prompt.
            head_labels (Tensor): labels of start position.
            tail_labels (Tensor): labels of end position.

        Returns:
            Dict[str, Tensor]: the summed head and tail loss under the
                key ``'loss'``.

        Raises:
            ValueError: if text length plus prompt length exceeds 512.
        """
        max_total_length = 512
        text_length = input_ids.size(1)
        total_length = hint_ids.size(1) + text_length
        # Explicit check instead of ``assert`` so that it is still
        # enforced under ``python -O``; done before any encoding work.
        if total_length > max_total_length:
            raise ValueError(
                'text length + prompt length must not exceed '
                f'{max_total_length}, got {total_length}')
        sequence_output = self.get_plm_sequence_output(input_ids,
                                                       attention_masks)
        head_logits, tail_logits = self._span_logits(
            sequence_output, attention_masks, hint_ids,
            cross_attention_masks, text_length)
        head_loss = self.circle_loss(head_logits, head_labels)
        tail_loss = self.circle_loss(tail_logits, tail_labels)
        return {'loss': head_loss + tail_loss}

    def fast_inference(self, sequence_output, attention_masks, hint_ids,
                       cross_attention_masks):
        """Score span starts and ends for an already encoded text.

        Args:
            sequence_output(tensor): 3-dimension tensor
                (batch size, sequence length, hidden size)
            attention_masks(tensor): attention mask, 2-dimension tensor
                (batch size, sequence length)
            hint_ids(tensor): token ids of prompt, 2-dimension tensor
                (batch size, sequence length)
            cross_attention_masks(tensor): cross attention mask,
                2-dimension tensor (batch size, sequence length)

        Returns:
            head_probs(tensor): 2-dimension tensor
                (batch size, sequence length)
            tail_probs(tensor): 2-dimension tensor
                (batch size, sequence length)

            Despite the names, both are raw scores (no sigmoid or softmax
            is applied); positions whose attention mask is 0 are shifted
            by -10000.
        """
        head_logits, tail_logits = self._span_logits(
            sequence_output, attention_masks, hint_ids,
            cross_attention_masks, sequence_output.size(1))
        head_probs = head_logits + (1 - attention_masks) * -10000
        tail_probs = tail_logits + (1 - attention_masks) * -10000
        return head_probs, tail_probs
|