siamese_uie.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from copy import deepcopy
  3. import torch
  4. from torch import nn
  5. from modelscope.metainfo import Models
  6. from modelscope.models.base import TorchModel
  7. from modelscope.models.builder import MODELS
  8. from modelscope.utils.constant import ModelFile, Tasks
  9. from .backbone import BertEncoder, BertModel, BertPreTrainedModel
  10. __all__ = ['SiameseUieModel']
  11. @MODELS.register_module(Tasks.siamese_uie, module_name=Models.bert)
  12. class SiameseUieModel(BertPreTrainedModel):
  13. r"""SiameseUIE general information extraction model,
  14. based on the construction idea of prompt (Prompt) + text (Text),
  15. uses pointer network (Pointer Network) to
  16. realize segment extraction (Span Extraction), so as to
  17. realize named entity recognition (NER), relation extraction (RE),
  18. Extraction of various tasks such as event extraction (EE),
  19. attribute sentiment extraction (ABSA), etc. Different from
  20. the existing general information extraction tasks on the market:
  21. """
  22. def __init__(self, config, **kwargs):
  23. super().__init__(config)
  24. self.config = config
  25. self.plm = BertModel(self.config, add_pooling_layer=True)
  26. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  27. self.head_clsf = nn.Linear(config.hidden_size, 1)
  28. self.tail_clsf = nn.Linear(config.hidden_size, 1)
  29. self.set_crossattention_layer()
  30. def set_crossattention_layer(self, num_hidden_layers=6):
  31. crossattention_config = deepcopy(self.config)
  32. crossattention_config.num_hidden_layers = num_hidden_layers
  33. self.config.num_hidden_layers -= num_hidden_layers
  34. self.crossattention = BertEncoder(crossattention_config)
  35. self.crossattention.layer = self.plm.encoder.layer[self.config.
  36. num_hidden_layers:]
  37. self.plm.encoder.layer = self.plm.encoder.layer[:self.config.
  38. num_hidden_layers]
  39. def circle_loss(self, y_pred, y_true):
  40. batch_size = y_true.size(0)
  41. y_true = y_true.view(batch_size, -1)
  42. y_pred = y_pred.view(batch_size, -1)
  43. y_pred = (1 - 2 * y_true) * y_pred
  44. y_pred_neg = y_pred - y_true * 1e12
  45. y_pred_pos = y_pred - (1 - y_true) * 1e12
  46. zeros = torch.zeros_like(y_pred[:, :1])
  47. y_pred_neg = torch.cat([y_pred_neg, zeros], dim=-1)
  48. y_pred_pos = torch.cat([y_pred_pos, zeros], dim=-1)
  49. neg_loss = torch.logsumexp(y_pred_neg, dim=-1)
  50. pos_loss = torch.logsumexp(y_pred_pos, dim=-1)
  51. return (neg_loss + pos_loss).mean()
  52. def get_cross_attention_output(self, hidden_states, attention_mask,
  53. encoder_hidden_states,
  54. encoder_attention_mask):
  55. cat_hidden_states = torch.cat([hidden_states, encoder_hidden_states],
  56. dim=1)
  57. cat_attention_mask = torch.cat(
  58. [attention_mask, encoder_attention_mask], dim=1)
  59. cat_attention_mask = self.plm.get_extended_attention_mask(
  60. cat_attention_mask,
  61. cat_hidden_states.size()[:2])
  62. hidden_states = self.crossattention(
  63. hidden_states=cat_hidden_states, attention_mask=cat_attention_mask
  64. )[0][:, :hidden_states.size()[1], :]
  65. return hidden_states
  66. def get_plm_sequence_output(self,
  67. input_ids,
  68. attention_mask,
  69. position_ids=None,
  70. is_hint=False):
  71. token_type_ids = torch.ones_like(attention_mask) if is_hint else None
  72. sequence_output = self.plm(
  73. input_ids,
  74. attention_mask=attention_mask,
  75. token_type_ids=token_type_ids,
  76. position_ids=position_ids)[0]
  77. return sequence_output
  78. def forward(self, input_ids, attention_masks, hint_ids,
  79. cross_attention_masks, head_labels, tail_labels):
  80. """train forward
  81. Args:
  82. input_ids (Tensor): input token ids of text.
  83. attention_masks (Tensor): attention_masks of text.
  84. hint_ids (Tensor): input token ids of prompt.
  85. cross_attention_masks (Tensor): attention_masks of prompt.
  86. head_labels (Tensor): labels of start position.
  87. tail_labels (Tensor): labels of end position.
  88. Returns:
  89. Dict[str, float]: the loss
  90. Example:
  91. {"loss": 0.5091743}
  92. """
  93. sequence_output = self.get_plm_sequence_output(input_ids,
  94. attention_masks)
  95. assert hint_ids.size(1) + input_ids.size(1) <= 512
  96. position_ids = torch.arange(hint_ids.size(1)).expand(
  97. (1, -1)) + input_ids.size(1)
  98. position_ids = position_ids.to(sequence_output.device)
  99. hint_sequence_output = self.get_plm_sequence_output(
  100. hint_ids, cross_attention_masks, position_ids, is_hint=True)
  101. sequence_output = self.get_cross_attention_output(
  102. sequence_output, attention_masks, hint_sequence_output,
  103. cross_attention_masks)
  104. # (b, l, n)
  105. head_logits = self.head_clsf(sequence_output).squeeze(-1)
  106. tail_logits = self.tail_clsf(sequence_output).squeeze(-1)
  107. loss_func = self.circle_loss
  108. head_loss = loss_func(head_logits, head_labels)
  109. tail_loss = loss_func(tail_logits, tail_labels)
  110. return {'loss': head_loss + tail_loss}
  111. def fast_inference(self, sequence_output, attention_masks, hint_ids,
  112. cross_attention_masks):
  113. """
  114. Args:
  115. sequence_output(tensor): 3-dimension tensor (batch size, sequence length, hidden size)
  116. attention_masks(tensor): attention mask, 2-dimension tensor (batch size, sequence length)
  117. hint_ids(tensor): token ids of prompt 2-dimension tensor (batch size, sequence length)
  118. cross_attention_masks(tensor): cross attention mask, 2-dimension tensor (batch size, sequence length)
  119. Default Returns:
  120. head_probs(tensor): 2-dimension tensor(batch size, sequence length)
  121. tail_probs(tensor): 2-dimension tensor(batch size, sequence length)
  122. """
  123. position_ids = torch.arange(hint_ids.size(1)).expand(
  124. (1, -1)) + sequence_output.size(1)
  125. position_ids = position_ids.to(sequence_output.device)
  126. hint_sequence_output = self.get_plm_sequence_output(
  127. hint_ids, cross_attention_masks, position_ids, is_hint=True)
  128. sequence_output = self.get_cross_attention_output(
  129. sequence_output, attention_masks, hint_sequence_output,
  130. cross_attention_masks)
  131. # (b, l, n)
  132. head_logits = self.head_clsf(sequence_output).squeeze(-1)
  133. tail_logits = self.tail_clsf(sequence_output).squeeze(-1)
  134. head_probs = head_logits + (1 - attention_masks) * -10000
  135. tail_probs = tail_logits + (1 - attention_masks) * -10000
  136. return head_probs, tail_probs