| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from typing import Any, Dict
- import torch
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from modelscope.metainfo import Models
- from modelscope.models import Model
- from modelscope.models.builder import MODELS
- from modelscope.models.nlp.ponet import PoNetConfig
- from modelscope.outputs import AttentionTokenClassificationModelOutput
- from modelscope.utils.constant import Tasks
- from .backbone import BertModel, BertPreTrainedModel
- from .configuration import BertConfig
- __all__ = ['BertForDocumentSegmentation']
@MODELS.register_module(
    Tasks.document_segmentation, module_name=Models.bert_for_ds)
class BertForDocumentSegmentation(BertPreTrainedModel):
    """BERT encoder followed by a dropout + linear classification head.

    The class is registered in ``MODELS`` for the document segmentation task.
    The head is applied to every position of the encoder output; an optional
    ``sentence_attention_mask`` restricts which positions contribute to the
    loss.

    Attributes set in ``__init__``:
        num_labels: Number of output classes, read from ``config.num_labels``.
        sentence_pooler_type: Always ``None`` here; any other value makes
            ``forward`` raise ``NotImplementedError``.
        class_weights: Optional per-class weight handed to
            ``CrossEntropyLoss``. It is a plain attribute (not a registered
            buffer), so it is moved to the logits' device at loss time.
    """

    # The encoder is built without a pooling layer, so pooler weights found in
    # a checkpoint are unexpected and should be ignored when loading.
    _keys_to_ignore_on_load_unexpected = [r'pooler']

    def __init__(self, config, **kwargs):
        """Build the encoder and the classification head from ``config``.

        Args:
            config: Model configuration; ``num_labels``,
                ``hidden_dropout_prob`` and ``hidden_size`` are read here.
            **kwargs: Accepted for call-site compatibility and not used.
        """
        super().__init__(config)
        self.num_labels = config.num_labels
        self.sentence_pooler_type = None

        self.bert = BertModel(config, add_pooling_layer=False)
        classifier_dropout = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)
        self.class_weights = None

        self.init_weights()

    def forward(self,
                input_ids=None,
                attention_mask=None,
                token_type_ids=None,
                position_ids=None,
                head_mask=None,
                sentence_attention_mask=None,
                inputs_embeds=None,
                labels=None,
                output_attentions=None,
                output_hidden_states=None,
                return_dict=None):
        """Run the encoder and classifier and optionally compute the loss.

        Args:
            input_ids, attention_mask, token_type_ids, position_ids,
            head_mask, inputs_embeds, output_attentions,
            output_hidden_states: Forwarded unchanged to the BERT encoder.
            sentence_attention_mask: Optional mask; positions whose value is
                ``1`` keep their label, all others are replaced by the loss
                function's ``ignore_index``. Presumably marks the
                sentence-boundary positions -- confirm against the caller.
            labels: Optional class indices. A cross-entropy loss is computed
                only when this is given.
            return_dict: Whether to return a structured output. Defaults to
                ``self.config.use_return_dict`` when ``None``.

        Returns:
            ``AttentionTokenClassificationModelOutput`` when ``return_dict``
            is truthy; otherwise a tuple ``(logits, *outputs[2:])`` prefixed
            with ``loss`` when a loss was computed.

        Raises:
            NotImplementedError: If ``self.sentence_pooler_type`` is not
                ``None``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        outputs = self.bert(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        sequence_output = outputs[0]
        if self.sentence_pooler_type is not None:
            raise NotImplementedError
        sequence_output = self.dropout(sequence_output)
        logits = self.classifier(sequence_output)

        loss = None
        if labels is not None:
            weight = self.class_weights
            if isinstance(weight, torch.Tensor):
                # class_weights is not a buffer, so module.to(device) never
                # moves it; align it with the logits to avoid a device
                # mismatch inside the loss.
                weight = weight.to(logits.device)
            loss_fct = CrossEntropyLoss(weight=weight)

            flat_logits = logits.view(-1, self.num_labels)
            flat_labels = labels.view(-1)
            if sentence_attention_mask is not None:
                # Keep labels only where the mask equals 1; every other
                # position gets ignore_index so the loss skips it.
                active_loss = sentence_attention_mask.view(-1) == 1
                flat_labels = torch.where(
                    active_loss, flat_labels,
                    labels.new_tensor(loss_fct.ignore_index))
            loss = loss_fct(flat_logits, flat_labels)

        if not return_dict:
            output = (logits, ) + outputs[2:]
            return ((loss, ) + output) if loss is not None else output

        return AttentionTokenClassificationModelOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    @classmethod
    def _instantiate(cls, model_dir, model_config: Dict[str, Any], **kwargs):
        """Load a two-label model from ``model_dir``.

        Args:
            model_dir: Directory passed to ``from_pretrained`` for both the
                configuration and the weights.
            model_config: Model configuration mapping; its ``'type'`` entry
                selects the configuration class (``'bert'`` or ``'ponet'``).
            **kwargs: Accepted for call-site compatibility and not used.

        Returns:
            The loaded model, with ``model_dir`` and ``model_cfg`` attached.

        Raises:
            ValueError: If ``model_config['type']`` is neither ``'bert'`` nor
                ``'ponet'``.
        """
        config_type = model_config['type']
        if config_type == 'bert':
            config = BertConfig.from_pretrained(model_dir, num_labels=2)
        elif config_type == 'ponet':
            config = PoNetConfig.from_pretrained(model_dir, num_labels=2)
        else:
            raise ValueError(
                f'Expected config type to be bert or ponet, but got: {config_type}'
            )

        # super(Model, cls) starts the lookup after ``Model`` in the MRO, so
        # the ``from_pretrained`` used is not the one defined on ``Model``.
        model = super(Model, cls).from_pretrained(
            model_dir, from_tf=False, config=config)
        model.model_dir = model_dir
        model.model_cfg = model_config
        return model
|