| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import torch
- import torch.nn.functional as F
- from torch import nn
- from modelscope.metainfo import Models
- from modelscope.models import Model
- from modelscope.models.builder import MODELS
- from modelscope.outputs import SentencEmbeddingModelOutput
- from modelscope.utils.constant import Tasks
- from .backbone import BertModel, BertPreTrainedModel
class Pooler(nn.Module):
    """Parameter-free pooler that turns encoder outputs into sentence embeddings.

    Supported ``pooler_type`` values:

    * ``'cls'``: the last-layer hidden state at position 0 ([CLS]), returned
      as-is (no MLP pooler is applied here).
    * ``'avg'``: attention-mask-weighted average of the last layer's hidden
      states over the token positions.
    * ``'avg_top2'``: mask-weighted average of the mean of
      ``hidden_states[-2]`` and ``hidden_states[-1]``.
    * ``'avg_first_last'``: mask-weighted average of the mean of
      ``hidden_states[1]`` and ``hidden_states[-1]`` (presumably the first
      encoder layer if index 0 holds the embeddings — confirm against the
      backbone).

    ``'avg_top2'`` and ``'avg_first_last'`` read ``outputs.hidden_states``, so
    the encoder must be run with ``output_hidden_states=True``.
    """

    _POOLER_TYPES = ('cls', 'avg', 'avg_top2', 'avg_first_last')

    def __init__(self, pooler_type):
        """
        Args:
            pooler_type: one of ``'cls'``, ``'avg'``, ``'avg_top2'``,
                ``'avg_first_last'``.

        Raises:
            ValueError: if ``pooler_type`` is not supported.
        """
        super().__init__()
        if pooler_type not in self._POOLER_TYPES:
            # A real exception rather than ``assert`` so the check survives
            # ``python -O``.
            raise ValueError('unrecognized pooling type %s' % pooler_type)
        self.pooler_type = pooler_type

    @staticmethod
    def _masked_mean(hidden, attention_mask):
        """Average ``hidden`` over dim 1, weighting positions by ``attention_mask``.

        When ``attention_mask`` is None every position counts equally.
        """
        if attention_mask is None:
            return hidden.mean(dim=1)
        return ((hidden * attention_mask.unsqueeze(-1)).sum(1)
                / attention_mask.sum(-1).unsqueeze(-1))

    def forward(self, outputs, attention_mask):
        """Pool ``outputs`` into one embedding per input sequence.

        Args:
            outputs: encoder output exposing ``last_hidden_state`` and, for the
                ``'avg_top2'`` / ``'avg_first_last'`` modes, ``hidden_states``.
            attention_mask: per-token mask (non-zero marks a real token), or
                None to count every position. Ignored by ``'cls'``.

        Returns:
            The pooled embeddings, one row per sequence.

        Raises:
            ValueError: if a mode that needs ``hidden_states`` gets None.
        """
        last_hidden = outputs.last_hidden_state
        if self.pooler_type == 'cls':
            return last_hidden[:, 0]
        if self.pooler_type == 'avg':
            return self._masked_mean(last_hidden, attention_mask)

        hidden_states = outputs.hidden_states
        if hidden_states is None:
            raise ValueError(
                'pooler_type %r needs hidden_states; run the encoder with '
                'output_hidden_states=True' % self.pooler_type)
        if self.pooler_type == 'avg_first_last':
            return self._masked_mean(
                (hidden_states[1] + hidden_states[-1]) / 2.0, attention_mask)
        if self.pooler_type == 'avg_top2':
            return self._masked_mean(
                (hidden_states[-1] + hidden_states[-2]) / 2.0, attention_mask)
        # Unreachable after __init__ validation; kept as a defensive guard.
        raise NotImplementedError
@MODELS.register_module(Tasks.sentence_embedding, module_name=Models.bert)
class BertForSentenceEmbedding(BertPreTrainedModel):
    """BERT encoder producing sentence embeddings for queries and documents.

    The backbone is built without its pooling layer. A parameter-free
    :class:`Pooler` selected by ``emb_pooler_type`` turns token states into one
    vector per sequence. The vector is optionally L2-normalized.
    """

    # Pooling modes that index ``outputs.hidden_states`` (intermediate layers).
    _HIDDEN_STATE_POOLERS = ('avg_top2', 'avg_first_last')

    def __init__(self, config, **kwargs):
        """
        Args:
            config: the backbone configuration.
            kwargs:
                emb_pooler_type: pooling mode passed to :class:`Pooler`
                    (default ``'cls'``).
                normalize: whether :meth:`encode` L2-normalizes its output
                    (default ``False``).
        """
        super().__init__(config)
        self.config = config
        self.pooler_type = kwargs.get('emb_pooler_type', 'cls')
        self.pooler = Pooler(self.pooler_type)
        self.normalize = kwargs.get('normalize', False)
        # Registered under ``base_model_prefix``; presumably this is how
        # ``self.base_model`` (used in forward/encode) resolves to the backbone.
        setattr(self, self.base_model_prefix,
                BertModel(config, add_pooling_layer=False))

    def forward(self, query=None, docs=None, labels=None):
        r"""Encode queries and/or documents; in training, also compute a loss.

        Args:
            query (:obj:`dict`, optional): Tokenized inputs for the query
                sequences, passed as keyword arguments to :meth:`encode`. See
                :meth:`transformers.PreTrainedTokenizer.encode` and
                :meth:`transformers.PreTrainedTokenizer.__call__` for details.
            docs (:obj:`dict`, optional): Tokenized inputs for the document
                sequences, passed to :meth:`encode` the same way.
            labels (:obj:`torch.LongTensor`, optional): For each query, the
                column index of its positive document in the query-document
                score matrix. It is only used when both ``query`` and ``docs``
                are given and the backbone is in training mode. When it is
                omitted, the documents are assumed to form equal-sized
                consecutive groups, one group per query, with the positive
                first in each group.

        Returns:
            :class:`modelscope.outputs.SentencEmbeddingModelOutput` with
            ``query_embeddings`` and ``doc_embeddings``. Either is None when
            the corresponding input is not given. ``loss`` is set only in
            training mode with both inputs present.

        Examples:
            >>> from modelscope.models import Model
            >>> from modelscope.preprocessors import Preprocessor
            >>> model = Model.from_pretrained('damo/nlp_corom_sentence-embedding_chinese-base')
            >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_corom_sentence-embedding_chinese-base')
            >>> print(model(**preprocessor({'source_sentence': ['This is a test']})))
        """
        query_embeddings = self.encode(**query) if query is not None else None
        doc_embeddings = self.encode(**docs) if docs is not None else None
        outputs = SentencEmbeddingModelOutput(
            query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)
        if query_embeddings is None or doc_embeddings is None:
            return outputs

        if self.base_model.training:
            # In-batch contrastive loss: each query scored against every doc.
            scores = torch.matmul(query_embeddings, doc_embeddings.T)
            if labels is None:
                # Query i's positive is the first doc of the i-th group.
                group_size = doc_embeddings.size(0) // query_embeddings.size(0)
                labels = torch.arange(
                    scores.size(0), device=scores.device,
                    dtype=torch.long) * group_size
            outputs.loss = nn.CrossEntropyLoss()(scores, labels)
        return outputs

    def encode(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the backbone and pool its output into sentence embeddings.

        The arguments are forwarded to the backbone. Two exceptions apply:

        * ``output_hidden_states`` is forced to True for pooling modes that
          need intermediate layers.
        * ``return_dict`` is always True, because the pooler reads output
          attributes. The ``return_dict`` parameter is kept only for
          signature compatibility.

        Returns:
            The pooled embeddings, L2-normalized along the last dim when
            ``self.normalize`` is set.
        """
        if self.pooler_type in self._HIDDEN_STATE_POOLERS:
            # Without this the backbone presumably returns hidden_states=None
            # and these pooling modes cannot run.
            output_hidden_states = True
        outputs = self.base_model.forward(
            input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=True)
        embeddings = self.pooler(outputs, attention_mask)
        if self.normalize:
            embeddings = F.normalize(embeddings, p=2, dim=-1)
        return embeddings

    @classmethod
    def _instantiate(cls, **kwargs):
        """Instantiate the model from a checkpoint directory.

        Args:
            kwargs: Input args.
                model_dir: The model dir used to load the checkpoint and the
                    label information.
                emb_pooler_type: pooling mode passed to ``__init__``
                    (default ``'cls'``).
                normalize: whether embeddings are L2-normalized
                    (default ``False``).

        Returns:
            The loaded model, which is initialized by
            transformers.PreTrainedModel.from_pretrained. Its ``model_dir``
            attribute is set to ``model_dir``.
        """
        model_dir = kwargs.get('model_dir')
        model_kwargs = {
            'emb_pooler_type': kwargs.get('emb_pooler_type', 'cls'),
            'normalize': kwargs.get('normalize', False),
        }
        # ``super(Model, cls)`` starts the MRO lookup after ``Model``, so the
        # ``from_pretrained`` found is the next one after it, not Model's.
        model = super(Model, cls).from_pretrained(
            pretrained_model_name_or_path=model_dir, **model_kwargs)
        model.model_dir = model_dir
        return model
|