| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import torch
- from transformers import BloomConfig
- from transformers import BloomModel as BloomModelTransform
- from modelscope.metainfo import Models
- from modelscope.models import MODELS, TorchModel
- from modelscope.outputs import SentencEmbeddingModelOutput
- from modelscope.utils.constant import Tasks
class DecoderPooler(torch.nn.Module):
    """Parameter-free poolers that reduce token states to a sentence embedding.

    Supported pooling types:
        'last': the state of the last attended token of each sequence.
        'weighted_mean': position weighted average of all attended token
            states; the token at 1-based position ``i`` gets weight ``i``.
    """

    def __init__(self, pooler_type):
        """
        Args:
            pooler_type (str): Either 'last' or 'weighted_mean'.

        Raises:
            AssertionError: If `pooler_type` is not a supported pooling type.
        """
        super().__init__()
        self.pooler_type = pooler_type
        assert self.pooler_type in [
            'last', 'weighted_mean'
        ], 'unrecognized pooling type %s' % self.pooler_type

    def forward(self, outputs, attention_mask):
        """Pool the token states of `outputs` into one vector per sequence.

        Args:
            outputs: Object exposing `last_hidden_state`, a tensor of shape
                [n, l, h] (batch, sequence, hidden).
            attention_mask: Tensor of shape [n, l]; non-zero entries mark the
                tokens that take part in the pooling.

        Returns:
            Tensor of shape [n, h] holding one embedding per batch item.

        Raises:
            NotImplementedError: If `self.pooler_type` was changed to an
                unsupported value after construction.
        """
        last_hidden = outputs.last_hidden_state
        if self.pooler_type == 'last':
            n, seq_len, h = last_hidden.shape
            # 1-based positions, so that masked-out tokens (value 0) can never
            # win the max below.
            positions = torch.arange(
                1, seq_len + 1, device=attention_mask.device)
            attended = (attention_mask != 0).to(positions.dtype)
            # Index of the last attended token per row. Taking the max over
            # attended positions (instead of locating the first padding
            # token) gives the right answer for right-padded, left-padded and
            # unpadded rows alike.
            last_index = (attended * positions).max(dim=1).values - 1
            # Rows without any attended token yield -1, which would crash the
            # gather; fall back to the first position for them.
            last_index = torch.clamp(last_index, min=0)
            # Turn indices from shape [n] --> [n, 1, h]
            gather_indices = last_index.to(
                last_hidden.device).unsqueeze(1).unsqueeze(1).expand(n, 1, h)
            # Gather along the 1st dim (l) (n, l, h -> n, h)
            pooled_output = torch.gather(last_hidden, 1,
                                         gather_indices).squeeze(dim=1)
        elif self.pooler_type == 'weighted_mean':
            seq_len = last_hidden.shape[1]
            # Keep mask and weights at [n, l, 1] / [1, l, 1] and let
            # broadcasting do the rest rather than materialising two extra
            # [n, l, h] tensors.
            mask = attention_mask.unsqueeze(-1).float()
            weights = torch.arange(
                1, seq_len + 1,
                device=last_hidden.device).float().view(1, seq_len, 1)
            # NOTE(review): weights follow the absolute position in the padded
            # sequence, so this branch presumably expects right-padded
            # inputs -- confirm against the preprocessor.
            weighted_mask = mask * weights
            sum_embeddings = torch.sum(last_hidden * weighted_mask, 1)
            # Avoid a division by zero for rows without attended tokens.
            sum_mask = torch.clamp(weighted_mask.sum(1), min=1e-9)
            pooled_output = sum_embeddings / sum_mask
        else:
            raise NotImplementedError
        return pooled_output
@MODELS.register_module(
    group_key=Tasks.sentence_embedding, module_name=Models.bloom)
class BloomForSentenceEmbedding(BloomModelTransform, TorchModel):
    r"""
    This model represents a text as a dense vector, using either the last token state or the
    position weighted mean of all token states.
    See `Language Models are Universal Embedders
    <https://arxiv.org/pdf/2310.08232.pdf>`_ for details.
    """

    def __init__(self, config, **kwargs):
        r"""
        Args:
            config: The Bloom configuration used to build the backbone.
            kwargs:
                emb_pooler_type (:obj:`str`): 'last' or 'weighted_mean' (default), see `DecoderPooler`.
                normalize (:obj:`bool`): Whether to L2-normalize the embeddings, defaults to `False`.
        """
        super().__init__(config)
        self.config = config
        self.pooler_type = kwargs.get('emb_pooler_type', 'weighted_mean')
        self.pooler = DecoderPooler(self.pooler_type)
        self.normalize = kwargs.get('normalize', False)
        # The backbone that `encode` runs is stored under the base model
        # prefix and reached through `self.base_model`.
        setattr(self, self.base_model_prefix, BloomModelTransform(config))

    def forward(self, query=None, docs=None, labels=None):
        r"""
        Args:
            query (:obj: `dict`): Dict of pretrained models's input for the query sequence. See
                :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
                for details.
            docs (:obj: `dict`): Dict of pretrained models's input for the document sequences. See
                :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
                for details.
            labels (:obj: `torch.Tensor`, `optional`): Index of the target document for each query, only used
                for the training loss. When omitted, query ``i`` is paired with document
                ``i * (num_docs // num_queries)``.
        Returns:
            Returns `modelscope.outputs.SentencEmbeddingModelOutput`. `loss` is only set when both `query` and
            `docs` are given and the base model is in training mode.
        Examples:
            >>> from modelscope.models import Model
            >>> from modelscope.preprocessors import Preprocessor
            >>> model = Model.from_pretrained('damo/nlp_udever_bloom_560m')
            >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_udever_bloom_560m')
            >>> inputs = preprocessor({'source_sentence': ['This is a test']})
            >>> outputs = model(**inputs)
            >>> print(outputs)
        """
        query_embeddings, doc_embeddings = None, None
        if query is not None:
            query_embeddings = self.encode(**query)
        if docs is not None:
            doc_embeddings = self.encode(**docs)
        outputs = SentencEmbeddingModelOutput(
            query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)
        # A loss needs both sides of the pair.
        if query_embeddings is None or doc_embeddings is None:
            return outputs
        if self.base_model.training:
            loss_fct = torch.nn.CrossEntropyLoss()
            # Similarity of every query against every document in the batch.
            scores = torch.matmul(query_embeddings, doc_embeddings.T)
            if labels is None:
                labels = torch.arange(
                    scores.size(0), device=scores.device, dtype=torch.long)
                # Stride by the number of documents per query so that each
                # label points at the first document of its query's group.
                labels = labels * (
                    doc_embeddings.size(0) // query_embeddings.size(0))
            loss = loss_fct(scores, labels)
            outputs.loss = loss
        return outputs

    def encode(
        self,
        input_ids=None,
        attention_mask=None,
    ):
        r"""Run the backbone and pool its token states into sentence embeddings.

        Args:
            input_ids (:obj: `torch.Tensor`): Token ids passed to the backbone.
            attention_mask (:obj: `torch.Tensor`, `optional`): Mask of the tokens to attend to and to pool over.
                When omitted, every token of `input_ids` is used.
        Returns:
            The pooled embeddings, L2-normalized along the last dim when `self.normalize` is set.
        """
        if attention_mask is None and input_ids is not None:
            # The pooler cannot work without a mask; treat all tokens as real.
            attention_mask = torch.ones_like(input_ids)
        # Call the module rather than `.forward` so that registered hooks run.
        outputs = self.base_model(input_ids, attention_mask=attention_mask)
        embeddings = self.pooler(outputs, attention_mask)
        if self.normalize:
            embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        return embeddings

    @classmethod
    def _instantiate(cls, **kwargs):
        """Instantiate the model.

        Args:
            kwargs: Input args.
                model_dir: The model dir used to load the checkpoint and the label information.
                    When missing, a randomly initialized model is built from `BloomConfig(**kwargs)`.
                emb_pooler_type: The pooling type, defaults to 'weighted_mean'.
                normalize: Whether to L2-normalize the embeddings, defaults to `False`.
        Returns:
            The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
        """
        model_dir = kwargs.get('model_dir')
        model_kwargs = {
            'emb_pooler_type': kwargs.get('emb_pooler_type', 'weighted_mean'),
            'normalize': kwargs.get('normalize', False)
        }
        if model_dir is None:
            config = BloomConfig(**kwargs)
            # Forward the pooling options as well; otherwise a requested
            # pooler type or normalization is silently ignored on this path.
            model = cls(config, **model_kwargs)
        else:
            model = super(BloomModelTransform, cls).from_pretrained(
                pretrained_model_name_or_path=model_dir, **model_kwargs)
        model.model_dir = model_dir
        return model
|