token_classification.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230
  1. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. # All rights reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import torch
  18. import torch.nn as nn
  19. import torch.utils.checkpoint
  20. from torch.nn import CrossEntropyLoss
  21. from modelscope.metainfo import Models
  22. from modelscope.models.builder import MODELS
  23. from modelscope.outputs import AttentionTokenClassificationModelOutput
  24. from modelscope.utils import logger as logging
  25. from modelscope.utils.constant import Tasks
  26. from .adv_utils import compute_adv_loss
  27. from .backbone import SbertModel, SbertPreTrainedModel
  28. from .configuration import SbertConfig
  29. logger = logging.get_logger()
  30. @MODELS.register_module(
  31. Tasks.token_classification, module_name=Models.structbert)
  32. @MODELS.register_module(Tasks.word_segmentation, module_name=Models.structbert)
  33. @MODELS.register_module(Tasks.part_of_speech, module_name=Models.structbert)
  34. class SbertForTokenClassification(SbertPreTrainedModel):
  35. r"""StructBERT Model with a token classification head on top (a linear layer on top of the hidden-states output)
  36. e.g. for Named-Entity-Recognition (NER) tasks.
  37. This model inherits from :class:`~transformers.PreTrainedModel`. Check the superclass documentation for the generic
  38. methods the library implements for all its model (such as downloading or saving, resizing the input embeddings,
  39. pruning heads etc.)
  40. This model is also a PyTorch `torch.nn.Module <https://pytorch.org/docs/stable/nn.html#torch.nn.Module>`__
  41. subclass. Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to
  42. general usage and behavior.
  43. Preprocessor:
  44. This is the token-classification model of StructBERT, the preprocessor of this model
  45. is `modelscope.preprocessors.TokenClassificationTransformersPreprocessor`.
  46. Trainer:
  47. This model is a normal PyTorch model, and can be trained by variable trainers, like EpochBasedTrainer,
  48. NlpEpochBasedTrainer, or trainers from other frameworks.
  49. The preferred trainer in modelscope is NlpEpochBasedTrainer.
  50. Parameters:
  51. config (:class:`~modelscope.models.nlp.structbert.SbertConfig`): Model configuration class with
  52. all the parameters of the model.
  53. Initializing with a config file does not load the weights associated with the model, only the
  54. configuration. Check out the :meth:`~transformers.PreTrainedModel.from_pretrained` method to load the model
  55. weights.
  56. """
  57. _keys_to_ignore_on_load_unexpected = [r'pooler']
  58. def __init__(self, config: SbertConfig, **kwargs):
  59. super().__init__(config)
  60. self.num_labels = config.num_labels
  61. self.config = config
  62. if self.config.adv_grad_factor is None:
  63. logger.warning(
  64. 'Adv parameters not set, skipping compute_adv_loss.')
  65. setattr(self, self.base_model_prefix,
  66. SbertModel(config, add_pooling_layer=False))
  67. classifier_dropout = (
  68. config.classifier_dropout if config.classifier_dropout is not None
  69. else config.hidden_dropout_prob)
  70. self.dropout = nn.Dropout(classifier_dropout)
  71. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  72. self.init_weights()
  73. def _forward_call(self, **kwargs):
  74. outputs = self.bert(**kwargs)
  75. sequence_output = outputs[0]
  76. sequence_output = self.dropout(sequence_output)
  77. logits = self.classifier(sequence_output)
  78. outputs['logits'] = logits
  79. outputs.kwargs = kwargs
  80. return outputs
  81. def forward(
  82. self,
  83. input_ids=None,
  84. attention_mask=None,
  85. token_type_ids=None,
  86. position_ids=None,
  87. head_mask=None,
  88. inputs_embeds=None,
  89. labels=None,
  90. output_attentions=None,
  91. output_hidden_states=None,
  92. return_dict=None,
  93. offset_mapping=None,
  94. label_mask=None,
  95. ):
  96. r"""
  97. Args:
  98. input_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`):
  99. Indices of input sequence tokens in the vocabulary.
  100. Indices can be obtained using :class:`~modelscope.models.nlp.structbert.SbertTokenizer`. See
  101. :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
  102. for details.
  103. attention_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  104. Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:
  105. - 1 for tokens that are **not masked**,
  106. - 0 for tokens that are **masked**.
  107. token_type_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  108. Segment token indices to indicate first and second portions of the inputs. Indices are selected in ``[0,
  109. 1]``:
  110. - 0 corresponds to a `sentence A` token,
  111. - 1 corresponds to a `sentence B` token.
  112. position_ids (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  113. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range
  114. ``[0, config.max_position_embeddings - 1]``.
  115. head_mask (:obj:`torch.FloatTensor` of shape :obj:`(num_heads,)` or :obj:`(num_layers, num_heads)`,
  116. `optional`):
  117. Mask to nullify selected heads of the self-attention modules. Mask values selected in ``[0, 1]``:
  118. - 1 indicates the head is **not masked**,
  119. - 0 indicates the head is **masked**.
  120. inputs_embeds (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length, hidden_size)`,
  121. `optional`):
  122. Optionally, instead of passing :obj:`input_ids` you can choose to directly pass an embedded
  123. representation. This is useful if you want more control over how to convert :obj:`input_ids` indices
  124. into associated vectors than the model's internal embedding lookup matrix.
  125. output_attentions (:obj:`bool`, `optional`):
  126. Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under
  127. returned tensors for more detail.
  128. output_hidden_states (:obj:`bool`, `optional`):
  129. Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors
  130. for more detail.
  131. return_dict (:obj:`bool`, `optional`):
  132. Whether or not to return a :class:`~transformers.ModelOutput` instead of a plain tuple.
  133. labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  134. Labels for computing the token classification loss. Indices should be in
  135. ``[0, ..., config.num_labels - 1]``.
  136. offset_mapping (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  137. Indices of positions of each input sequence tokens in the sentence.
  138. Selected in the range ``[0, sequence_length - 1]``.
  139. label_mask (:obj:`torch.FloatTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  140. Mask to avoid performing attention on padding token indices. Mask
  141. values selected in ``[0, 1]``:
  142. - 1 for tokens that are **not masked**,
  143. - 0 for tokens that are **masked**.
  144. Returns:
  145. Returns `modelscope.outputs.AttentionTokenClassificationModelOutput`
  146. Examples:
  147. >>> from modelscope.models import Model
  148. >>> from modelscope.preprocessors import Preprocessor
  149. >>> model = Model.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base')
  150. >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_structbert_word-segmentation_chinese-base')
  151. >>> print(model(**preprocessor(('This is a test', 'This is also a test'))))
  152. """
  153. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  154. if not return_dict:
  155. logger.error('Return tuple in sbert is not supported now.')
  156. outputs = self._forward_call(
  157. input_ids=input_ids,
  158. attention_mask=attention_mask,
  159. token_type_ids=token_type_ids,
  160. position_ids=position_ids,
  161. head_mask=head_mask,
  162. inputs_embeds=inputs_embeds,
  163. output_attentions=output_attentions,
  164. output_hidden_states=output_hidden_states,
  165. return_dict=return_dict)
  166. logits = outputs.logits
  167. embedding_output = outputs.embedding_output
  168. loss = None
  169. if labels is not None:
  170. loss_fct = CrossEntropyLoss()
  171. # Only keep active parts of the loss
  172. if attention_mask is not None:
  173. active_loss = attention_mask.view(-1) == 1
  174. active_logits = logits.view(-1, self.num_labels)
  175. active_labels = torch.where(
  176. active_loss, labels.view(-1),
  177. torch.tensor(loss_fct.ignore_index).type_as(labels))
  178. loss = loss_fct(active_logits, active_labels)
  179. else:
  180. loss = loss_fct(
  181. logits.view(-1, self.num_labels), labels.view(-1))
  182. if self.config.adv_grad_factor is not None and self.training:
  183. loss = compute_adv_loss(
  184. embedding=embedding_output,
  185. model=self._forward_call,
  186. ori_logits=logits,
  187. ori_loss=loss,
  188. adv_bound=self.config.adv_bound,
  189. adv_grad_factor=self.config.adv_grad_factor,
  190. sigma=self.config.sigma,
  191. with_attention_mask=attention_mask is not None,
  192. **outputs.kwargs)
  193. return AttentionTokenClassificationModelOutput(
  194. loss=loss,
  195. logits=logits,
  196. hidden_states=outputs.hidden_states,
  197. attentions=outputs.attentions,
  198. offset_mapping=offset_mapping,
  199. label_mask=label_mask,
  200. )