document_segmentation.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Any, Dict
  3. import torch
  4. from torch import nn
  5. from torch.nn import CrossEntropyLoss
  6. from modelscope.metainfo import Models
  7. from modelscope.models.base import Model
  8. from modelscope.models.builder import MODELS
  9. from modelscope.models.nlp.bert import BertConfig
  10. from modelscope.outputs import AttentionTokenClassificationModelOutput
  11. from modelscope.utils.constant import Tasks
  12. from .backbone import PoNetModel, PoNetPreTrainedModel
  13. from .configuration import PoNetConfig
  14. __all__ = ['PoNetForDocumentSegmentation']
  15. @MODELS.register_module(
  16. Tasks.document_segmentation, module_name=Models.ponet_for_ds)
  17. @MODELS.register_module(
  18. Tasks.extractive_summarization, module_name=Models.ponet_for_ds)
  19. class PoNetForDocumentSegmentation(PoNetPreTrainedModel):
  20. _keys_to_ignore_on_load_unexpected = [r'pooler']
  21. def __init__(self, config, **kwargs):
  22. super().__init__(config)
  23. self.num_labels = config.num_labels
  24. self.ponet = PoNetModel(config, add_pooling_layer=False)
  25. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  26. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  27. self.init_weights()
  28. def forward(
  29. self,
  30. input_ids=None,
  31. attention_mask=None,
  32. token_type_ids=None,
  33. segment_ids=None,
  34. position_ids=None,
  35. head_mask=None,
  36. inputs_embeds=None,
  37. labels=None,
  38. output_attentions=None,
  39. output_hidden_states=None,
  40. return_dict=None,
  41. ):
  42. r"""
  43. labels (:obj:`torch.LongTensor` of shape :obj:`(batch_size, sequence_length)`, `optional`):
  44. Labels for computing the token classification loss. Indices should be in ``[0, ..., config.num_labels -
  45. 1]``.
  46. """
  47. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  48. outputs = self.ponet(
  49. input_ids,
  50. attention_mask=attention_mask,
  51. token_type_ids=token_type_ids,
  52. segment_ids=segment_ids,
  53. position_ids=position_ids,
  54. head_mask=head_mask,
  55. inputs_embeds=inputs_embeds,
  56. output_attentions=output_attentions,
  57. output_hidden_states=output_hidden_states,
  58. return_dict=return_dict,
  59. )
  60. sequence_output = outputs[0]
  61. sequence_output = self.dropout(sequence_output)
  62. logits = self.classifier(sequence_output)
  63. loss = None
  64. if labels is not None:
  65. loss_fct = CrossEntropyLoss()
  66. # Only keep active parts of the loss
  67. if attention_mask is not None:
  68. active_loss = attention_mask.view(-1) == 1
  69. active_logits = logits.view(-1, self.num_labels)
  70. active_labels = torch.where(
  71. active_loss, labels.view(-1),
  72. torch.tensor(loss_fct.ignore_index).type_as(labels))
  73. loss = loss_fct(active_logits, active_labels)
  74. else:
  75. loss = loss_fct(
  76. logits.view(-1, self.num_labels), labels.view(-1))
  77. if not return_dict:
  78. output = (logits, ) + outputs[2:]
  79. return ((loss, ) + output) if loss is not None else output
  80. return AttentionTokenClassificationModelOutput(
  81. loss=loss,
  82. logits=logits,
  83. hidden_states=outputs.hidden_states,
  84. attentions=outputs.attentions,
  85. )
  86. @classmethod
  87. def _instantiate(cls, model_dir, model_config: Dict[str, Any], **kwargs):
  88. if model_config['type'] == 'bert':
  89. config = BertConfig.from_pretrained(model_dir, num_labels=2)
  90. elif model_config['type'] == 'ponet':
  91. config = PoNetConfig.from_pretrained(model_dir, num_labels=2)
  92. else:
  93. raise ValueError(
  94. f'Expected config type bert and ponet, which is : {model_config["type"]}'
  95. )
  96. model = super(Model, cls).from_pretrained(model_dir, config=config)
  97. model.model_dir = model_dir
  98. model.model_cfg = model_config
  99. return model