document_segmentation.py 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import torch
  4. from torch import nn
  5. from torch.nn import CrossEntropyLoss
  6. from modelscope.metainfo import Models
  7. from modelscope.models import Model
  8. from modelscope.models.builder import MODELS
  9. from modelscope.models.nlp.ponet import PoNetConfig
  10. from modelscope.outputs import AttentionTokenClassificationModelOutput
  11. from modelscope.utils.constant import Tasks
  12. from .backbone import BertModel, BertPreTrainedModel
  13. from .configuration import BertConfig
  14. __all__ = ['BertForDocumentSegmentation']
  15. @MODELS.register_module(
  16. Tasks.document_segmentation, module_name=Models.bert_for_ds)
  17. class BertForDocumentSegmentation(BertPreTrainedModel):
  18. _keys_to_ignore_on_load_unexpected = [r'pooler']
  19. def __init__(self, config, **kwargs):
  20. super().__init__(config)
  21. self.num_labels = config.num_labels
  22. self.sentence_pooler_type = None
  23. self.bert = BertModel(config, add_pooling_layer=False)
  24. classifier_dropout = config.hidden_dropout_prob
  25. self.dropout = nn.Dropout(classifier_dropout)
  26. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  27. self.class_weights = None
  28. self.init_weights()
  29. def forward(self,
  30. input_ids=None,
  31. attention_mask=None,
  32. token_type_ids=None,
  33. position_ids=None,
  34. head_mask=None,
  35. sentence_attention_mask=None,
  36. inputs_embeds=None,
  37. labels=None,
  38. output_attentions=None,
  39. output_hidden_states=None,
  40. return_dict=None):
  41. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  42. outputs = self.bert(
  43. input_ids,
  44. attention_mask=attention_mask,
  45. token_type_ids=token_type_ids,
  46. position_ids=position_ids,
  47. head_mask=head_mask,
  48. inputs_embeds=inputs_embeds,
  49. output_attentions=output_attentions,
  50. output_hidden_states=output_hidden_states,
  51. return_dict=return_dict,
  52. )
  53. sequence_output = outputs[0]
  54. if self.sentence_pooler_type is not None:
  55. raise NotImplementedError
  56. else:
  57. sequence_output = self.dropout(sequence_output)
  58. logits = self.classifier(sequence_output)
  59. loss = None
  60. if labels is not None:
  61. loss_fct = CrossEntropyLoss(weight=self.class_weights)
  62. if sentence_attention_mask is not None:
  63. active_loss = sentence_attention_mask.view(-1) == 1
  64. active_logits = logits.view(-1, self.num_labels)
  65. active_labels = torch.where(
  66. active_loss, labels.view(-1),
  67. torch.tensor(loss_fct.ignore_index).type_as(labels))
  68. loss = loss_fct(active_logits, active_labels)
  69. else:
  70. loss = loss_fct(
  71. logits.view(-1, self.num_labels), labels.view(-1))
  72. if not return_dict:
  73. output = (logits, ) + outputs[2:]
  74. return ((loss, ) + output) if loss is not None else output
  75. return AttentionTokenClassificationModelOutput(
  76. loss=loss,
  77. logits=logits,
  78. hidden_states=outputs.hidden_states,
  79. attentions=outputs.attentions,
  80. )
  81. @classmethod
  82. def _instantiate(cls, model_dir, model_config: Dict[str, Any], **kwargs):
  83. if model_config['type'] == 'bert':
  84. config = BertConfig.from_pretrained(model_dir, num_labels=2)
  85. elif model_config['type'] == 'ponet':
  86. config = PoNetConfig.from_pretrained(model_dir, num_labels=2)
  87. else:
  88. raise ValueError(
  89. f'Expected config type bert and ponet, which is : {model_config["type"]}'
  90. )
  91. model = super(Model, cls).from_pretrained(
  92. model_dir, from_tf=False, config=config)
  93. model.model_dir = model_dir
  94. model.model_cfg = model_config
  95. return model