text_classification.py 4.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. # All rights reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. import copy
  18. from torch.nn import CrossEntropyLoss, MSELoss
  19. from modelscope.metainfo import Models
  20. from modelscope.models.builder import MODELS
  21. from modelscope.outputs import AttentionTextClassificationModelOutput
  22. from modelscope.utils import logger as logging
  23. from modelscope.utils.constant import Tasks
  24. from .backbone import (PeerClassificationHead, PeerModel, PeerPreTrainedModel,
  25. PeerTopModel)
# Module-level logger from modelscope.utils.logger (imported above as
# ``logging``); not referenced by the model code in this file.
logger = logging.get_logger()
  27. @MODELS.register_module(Tasks.text_classification, module_name=Models.peer)
  28. @MODELS.register_module(Tasks.nli, module_name=Models.peer)
  29. @MODELS.register_module(
  30. Tasks.sentiment_classification, module_name=Models.peer)
  31. @MODELS.register_module(Tasks.sentence_similarity, module_name=Models.peer)
  32. @MODELS.register_module(
  33. Tasks.zero_shot_classification, module_name=Models.peer)
  34. class PeerForSequenceClassification(PeerPreTrainedModel):
  35. def __init__(self, config, **kwargs):
  36. super().__init__(config)
  37. self.num_labels = config.num_labels
  38. self.config = config
  39. config_discr_top = copy.deepcopy(config)
  40. config_shared_bottom = copy.deepcopy(config)
  41. assert config.num_hidden_layers_shared > 0, 'config.num_hidden_layers_shared should be greater than 0!'
  42. config_shared_bottom.num_hidden_layers = config.num_hidden_layers_shared
  43. config_discr_top.num_hidden_layers = config_discr_top.num_hidden_layers \
  44. - config_discr_top.num_hidden_layers_shared
  45. self.teams1_shared_bottom = PeerModel(config_shared_bottom)
  46. self.teams1_discr_top = PeerTopModel(config_discr_top)
  47. self.classifier = PeerClassificationHead(config)
  48. self.init_weights()
  49. def forward(
  50. self,
  51. input_ids=None,
  52. attention_mask=None,
  53. token_type_ids=None,
  54. position_ids=None,
  55. head_mask=None,
  56. inputs_embeds=None,
  57. labels=None,
  58. output_attentions=None,
  59. output_hidden_states=None,
  60. side_info_sets=dict(),
  61. return_dict=None,
  62. ):
  63. r"""
  64. labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
  65. Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
  66. config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
  67. If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  68. """
  69. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  70. hidden_states_discr_bottom = self.teams1_shared_bottom(
  71. input_ids, attention_mask, token_type_ids, position_ids, head_mask,
  72. inputs_embeds, output_attentions, output_hidden_states,
  73. side_info_sets, return_dict)
  74. hidden_states_discr_top = self.teams1_discr_top(
  75. hidden_states_discr_bottom[0], input_ids, attention_mask,
  76. token_type_ids, position_ids, head_mask, inputs_embeds,
  77. output_attentions, output_hidden_states, side_info_sets,
  78. return_dict)
  79. discriminator_hidden_states = hidden_states_discr_top
  80. sequence_output = discriminator_hidden_states[0]
  81. logits = self.classifier(sequence_output)
  82. loss = None
  83. if labels is not None:
  84. if self.num_labels == 1:
  85. # We are doing regression
  86. loss_fct = MSELoss()
  87. loss = loss_fct(logits.view(-1), labels.view(-1))
  88. else:
  89. loss_fct = CrossEntropyLoss()
  90. loss = loss_fct(
  91. logits.view(-1, self.num_labels), labels.view(-1))
  92. if not return_dict:
  93. output = (logits, ) + discriminator_hidden_states[1:]
  94. return ((loss, ) + output) if loss is not None else output
  95. return AttentionTextClassificationModelOutput(
  96. loss=loss,
  97. logits=logits,
  98. hidden_states=discriminator_hidden_states.hidden_states,
  99. attentions=discriminator_hidden_states.attentions,
  100. )