| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121 |
- # Copyright 2021-2022 The Alibaba DAMO NLP Team Authors.
- # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
- # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
- # All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import copy
- from torch.nn import CrossEntropyLoss, MSELoss
- from modelscope.metainfo import Models
- from modelscope.models.builder import MODELS
- from modelscope.outputs import AttentionTextClassificationModelOutput
- from modelscope.utils import logger as logging
- from modelscope.utils.constant import Tasks
- from .backbone import (PeerClassificationHead, PeerModel, PeerPreTrainedModel,
- PeerTopModel)
- logger = logging.get_logger()
@MODELS.register_module(Tasks.text_classification, module_name=Models.peer)
@MODELS.register_module(Tasks.nli, module_name=Models.peer)
@MODELS.register_module(
    Tasks.sentiment_classification, module_name=Models.peer)
@MODELS.register_module(Tasks.sentence_similarity, module_name=Models.peer)
@MODELS.register_module(
    Tasks.zero_shot_classification, module_name=Models.peer)
class PeerForSequenceClassification(PeerPreTrainedModel):
    """PEER model with a sequence classification / regression head.

    The encoder stack is split in two parts that are run back to back:

    * ``teams1_shared_bottom`` -- a ``PeerModel`` built with
      ``config.num_hidden_layers_shared`` layers.
    * ``teams1_discr_top`` -- a ``PeerTopModel`` built with the remaining
      ``config.num_hidden_layers - config.num_hidden_layers_shared`` layers.

    The first output of the top model is passed to a
    ``PeerClassificationHead`` to produce the logits.

    The class is registered for the text-classification, NLI,
    sentiment-classification, sentence-similarity and zero-shot-classification
    tasks under the ``Models.peer`` module name.
    """

    def __init__(self, config, **kwargs):
        """Build the two-part encoder and the classification head.

        Args:
            config: Model configuration. Must provide ``num_labels``,
                ``num_hidden_layers`` and ``num_hidden_layers_shared``.
            **kwargs: Accepted for interface compatibility; not used here.

        Raises:
            ValueError: If ``config.num_hidden_layers_shared`` is not
                greater than 0.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.config = config

        # Validate explicitly rather than with ``assert`` so the check is
        # still enforced when Python runs with optimizations (-O).
        if not config.num_hidden_layers_shared > 0:
            raise ValueError(
                'config.num_hidden_layers_shared should be greater than 0!')

        # Each sub-model gets its own copy of the config so that overriding
        # ``num_hidden_layers`` does not affect the caller's config object.
        config_discr_top = copy.deepcopy(config)
        config_shared_bottom = copy.deepcopy(config)

        config_shared_bottom.num_hidden_layers = config.num_hidden_layers_shared
        config_discr_top.num_hidden_layers = (
            config_discr_top.num_hidden_layers
            - config_discr_top.num_hidden_layers_shared)

        self.teams1_shared_bottom = PeerModel(config_shared_bottom)
        self.teams1_discr_top = PeerTopModel(config_discr_top)

        # The head is built from the original (unmodified) config.
        self.classifier = PeerClassificationHead(config)

        self.init_weights()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        output_attentions=None,
        output_hidden_states=None,
        side_info_sets=None,
        return_dict=None,
    ):
        r"""Run both encoder parts and the classification head.

        ``input_ids``, ``attention_mask``, ``token_type_ids``,
        ``position_ids``, ``head_mask``, ``inputs_embeds``,
        ``output_attentions`` and ``output_hidden_states`` are forwarded
        unchanged to both the bottom and the top model.

        labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size,)`, `optional`):
            Labels for computing the sequence classification/regression loss. Indices should be in :obj:`[0, ...,
            config.num_labels - 1]`. If :obj:`config.num_labels == 1` a regression loss is computed (Mean-Square loss),
            If :obj:`config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        side_info_sets (:obj:`dict`, `optional`):
            Extra information forwarded to both sub-models. When omitted (or ``None``) a fresh empty dict is
            created for this call.
        return_dict (:obj:`bool`, `optional`):
            Whether to return an :obj:`AttentionTextClassificationModelOutput` instead of a plain tuple. Falls
            back to ``self.config.use_return_dict`` when ``None``.

        Returns:
            If ``return_dict`` is falsy: a tuple ``(logits, *extra)`` built from the top model's outputs after
            the first one, prefixed with ``loss`` when ``labels`` is given. Otherwise an
            :obj:`AttentionTextClassificationModelOutput` carrying ``loss``, ``logits``, ``hidden_states`` and
            ``attentions``.
        """
        # A fresh dict per call: a ``dict()`` default in the signature would
        # be one shared object reused (and possibly mutated) across calls.
        if side_info_sets is None:
            side_info_sets = {}

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # NOTE: arguments are passed positionally; the order must match the
        # sub-models' ``forward`` signatures (defined in ``.backbone``).
        hidden_states_discr_bottom = self.teams1_shared_bottom(
            input_ids, attention_mask, token_type_ids, position_ids, head_mask,
            inputs_embeds, output_attentions, output_hidden_states,
            side_info_sets, return_dict)

        # The top model consumes the bottom model's first output in addition
        # to the original inputs.
        hidden_states_discr_top = self.teams1_discr_top(
            hidden_states_discr_bottom[0], input_ids, attention_mask,
            token_type_ids, position_ids, head_mask, inputs_embeds,
            output_attentions, output_hidden_states, side_info_sets,
            return_dict)

        discriminator_hidden_states = hidden_states_discr_top
        sequence_output = discriminator_hidden_states[0]

        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            if self.num_labels == 1:
                # We are doing regression
                loss_fct = MSELoss()
                loss = loss_fct(logits.view(-1), labels.view(-1))
            else:
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(
                    logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits, ) + discriminator_hidden_states[1:]
            return ((loss, ) + output) if loss is not None else output

        return AttentionTextClassificationModelOutput(
            loss=loss,
            logits=logits,
            hidden_states=discriminator_hidden_states.hidden_states,
            attentions=discriminator_hidden_states.attentions,
        )
|