sentence_embedding.py 6.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import torch
  3. import torch.nn.functional as F
  4. from torch import nn
  5. from modelscope.metainfo import Models
  6. from modelscope.models import Model
  7. from modelscope.models.builder import MODELS
  8. from modelscope.outputs import SentencEmbeddingModelOutput
  9. from modelscope.utils.constant import Tasks
  10. from .backbone import BertModel, BertPreTrainedModel
  11. class Pooler(nn.Module):
  12. """
  13. Parameter-free poolers to get the sentence embedding
  14. 'cls': [CLS] representation with BERT/RoBERTa's MLP pooler.
  15. 'cls_before_pooler': [CLS] representation without the original MLP pooler.
  16. 'avg': average of the last layers' hidden states at each token.
  17. 'avg_top2': average of the last two layers.
  18. 'avg_first_last': average of the first and the last layers.
  19. """
  20. def __init__(self, pooler_type):
  21. super().__init__()
  22. self.pooler_type = pooler_type
  23. assert self.pooler_type in [
  24. 'cls', 'avg', 'avg_top2', 'avg_first_last'
  25. ], 'unrecognized pooling type %s' % self.pooler_type
  26. def forward(self, outputs, attention_mask):
  27. last_hidden = outputs.last_hidden_state
  28. hidden_states = outputs.hidden_states
  29. if self.pooler_type in ['cls']:
  30. return last_hidden[:, 0]
  31. elif self.pooler_type == 'avg':
  32. return ((last_hidden * attention_mask.unsqueeze(-1)).sum(1)
  33. / attention_mask.sum(-1).unsqueeze(-1))
  34. elif self.pooler_type == 'avg_first_last':
  35. first_hidden = hidden_states[1]
  36. last_hidden = hidden_states[-1]
  37. pooled_result = ((first_hidden + last_hidden) / 2.0
  38. * attention_mask.unsqueeze(-1)
  39. ).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
  40. return pooled_result
  41. elif self.pooler_type == 'avg_top2':
  42. second_last_hidden = hidden_states[-2]
  43. last_hidden = hidden_states[-1]
  44. pooled_result = ((last_hidden + second_last_hidden) / 2.0
  45. * attention_mask.unsqueeze(-1)
  46. ).sum(1) / attention_mask.sum(-1).unsqueeze(-1)
  47. return pooled_result
  48. else:
  49. raise NotImplementedError
  50. @MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert)
  51. class BertForSentenceEmbedding(BertPreTrainedModel):
  52. def __init__(self, config, **kwargs):
  53. super().__init__(config)
  54. self.config = config
  55. self.pooler_type = kwargs.get('emb_pooler_type', 'cls')
  56. self.pooler = Pooler(self.pooler_type)
  57. self.normalize = kwargs.get('normalize', False)
  58. setattr(self, self.base_model_prefix,
  59. BertModel(config, add_pooling_layer=False))
  60. def forward(self, query=None, docs=None, labels=None):
  61. r"""
  62. Args:
  63. query (:obj: `dict`): Dict of pretrained models's input for the query sequence. See
  64. :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
  65. for details.
  66. docs (:obj: `dict`): Dict of pretrained models's input for the query sequence. See
  67. :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
  68. for details.
  69. Returns:
  70. Returns `modelscope.outputs.SentencEmbeddingModelOutput
  71. Examples:
  72. >>> from modelscope.models import Model
  73. >>> from modelscope.preprocessors import Preprocessor
  74. >>> model = Model.from_pretrained('damo/nlp_corom_sentence-embedding_chinese-base')
  75. >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_corom_sentence-embedding_chinese-base')
  76. >>> print(model(**preprocessor('source_sentence':['This is a test'])))
  77. """
  78. query_embeddings, doc_embeddings = None, None
  79. if query is not None:
  80. query_embeddings = self.encode(**query)
  81. if docs is not None:
  82. doc_embeddings = self.encode(**docs)
  83. outputs = SentencEmbeddingModelOutput(
  84. query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)
  85. if query_embeddings is None or doc_embeddings is None:
  86. return outputs
  87. if self.base_model.training:
  88. loss_fct = nn.CrossEntropyLoss()
  89. scores = torch.matmul(query_embeddings, doc_embeddings.T)
  90. if labels is None:
  91. labels = torch.arange(
  92. scores.size(0), device=scores.device, dtype=torch.long)
  93. labels = labels * (
  94. doc_embeddings.size(0) // query_embeddings.size(0))
  95. loss = loss_fct(scores, labels)
  96. outputs.loss = loss
  97. return outputs
  98. def encode(
  99. self,
  100. input_ids=None,
  101. attention_mask=None,
  102. token_type_ids=None,
  103. position_ids=None,
  104. head_mask=None,
  105. inputs_embeds=None,
  106. output_attentions=None,
  107. output_hidden_states=None,
  108. return_dict=None,
  109. ):
  110. outputs = self.base_model.forward(
  111. input_ids,
  112. attention_mask=attention_mask,
  113. token_type_ids=token_type_ids,
  114. position_ids=position_ids,
  115. head_mask=head_mask,
  116. inputs_embeds=inputs_embeds,
  117. output_attentions=output_attentions,
  118. output_hidden_states=output_hidden_states,
  119. return_dict=return_dict)
  120. outputs = self.pooler(outputs, attention_mask)
  121. if self.normalize:
  122. outputs = F.normalize(outputs, p=2, dim=-1)
  123. return outputs
  124. @classmethod
  125. def _instantiate(cls, **kwargs):
  126. """Instantiate the model.
  127. Args:
  128. kwargs: Input args.
  129. model_dir: The model dir used to load the checkpoint and the label information.
  130. Returns:
  131. The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
  132. """
  133. model_dir = kwargs.get('model_dir')
  134. model_kwargs = {
  135. 'emb_pooler_type': kwargs.get('emb_pooler_type', 'cls'),
  136. 'normalize': kwargs.get('normalize', False)
  137. }
  138. model = super(Model, cls).from_pretrained(
  139. pretrained_model_name_or_path=model_dir, **model_kwargs)
  140. model.model_dir = model_dir
  141. return model