# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
from transformers import BloomConfig
from transformers import BloomModel as BloomModelTransform

from modelscope.metainfo import Models
from modelscope.models import MODELS, TorchModel
from modelscope.outputs import SentencEmbeddingModelOutput
from modelscope.utils.constant import Tasks


class DecoderPooler(torch.nn.Module):
    """
    Parameter-free poolers to get the sentence embedding
    'last': the state of the last attended token.
    'weighted_mean': position weighted average of all token states.
    """

    def __init__(self, pooler_type):
        """
        Args:
            pooler_type (str): Either 'last' or 'weighted_mean'.
        """
        super().__init__()
        self.pooler_type = pooler_type
        assert self.pooler_type in [
            'last', 'weighted_mean'
        ], 'unrecognized pooling type %s' % self.pooler_type

    def forward(self, outputs, attention_mask):
        """Pool token states into one vector per sequence.

        Args:
            outputs: Object exposing `last_hidden_state`, unpacked below as
                [batch, seq_len, hidden].
            attention_mask: [batch, seq_len] mask; non-zero marks a token
                that takes part in pooling.

        Returns:
            Tensor of shape [batch, hidden].
        """
        last_hidden = outputs.last_hidden_state
        if self.pooler_type == 'last':
            batch_size, seq_len, hidden_size = last_hidden.shape
            # 1-based positions multiplied by the 0/1 mask: the row maximum
            # is the 1-based position of the last attended token, or 0 when
            # nothing is attended. Unlike searching for the first zero, this
            # also picks the right token for left-padded batches.
            positions = torch.arange(
                1, seq_len + 1, device=attention_mask.device)
            attended = (attention_mask != 0).long()
            gather_indices = (attended * positions).max(dim=1).values - 1
            # Empty sequences would yield -1, which gather cannot index.
            gather_indices = torch.clamp(gather_indices, min=0)
            # Turn indices from shape [n] --> [n, 1, h]
            gather_indices = gather_indices.view(batch_size, 1, 1).expand(
                batch_size, 1, hidden_size)
            # Gather along the 1st dim (l) (n, l, h -> n, h)
            pooled_output = torch.gather(last_hidden, 1,
                                         gather_indices).squeeze(dim=1)
        elif self.pooler_type == 'weighted_mean':
            seq_len = last_hidden.shape[1]
            # Position weights 1..seq_len, shaped [1, seq_len, 1] so they
            # broadcast instead of materialising [n, l, h] copies.
            weights = torch.arange(
                1,
                seq_len + 1,
                dtype=torch.float,
                device=last_hidden.device).view(1, seq_len, 1)
            # [n, l, 1]: position weight where attended, 0 elsewhere.
            weighted_mask = attention_mask.unsqueeze(-1).float() * weights
            sum_embeddings = torch.sum(last_hidden * weighted_mask, 1)
            # Clamp so fully masked rows do not divide by zero.
            sum_mask = torch.clamp(weighted_mask.sum(1), min=1e-9)
            pooled_output = sum_embeddings / sum_mask
        else:
            # Only reachable if `pooler_type` is changed after construction.
            raise NotImplementedError
        return pooled_output


@MODELS.register_module(
    group_key=Tasks.sentence_embedding, module_name=Models.bloom)
class BloomForSentenceEmbedding(BloomModelTransform, TorchModel):
    r"""
    This model represents a text as a dense vector, using either the last
    token state or the weighted mean of all token states.
    See "Language Models are Universal Embedders" for details.
    """

    def __init__(self, config, **kwargs):
        """
        Args:
            config: Model configuration, forwarded to the parent constructor
                and to the wrapped `BloomModelTransform`.
            emb_pooler_type (str, optional): 'last' or 'weighted_mean'
                (default 'weighted_mean').
            normalize (bool, optional): L2-normalize embeddings in `encode`
                (default False).
        """
        super().__init__(config)
        self.config = config
        self.pooler_type = kwargs.get('emb_pooler_type', 'weighted_mean')
        self.pooler = DecoderPooler(self.pooler_type)
        self.normalize = kwargs.get('normalize', False)
        # NOTE(review): a second BloomModelTransform is attached under
        # `base_model_prefix` in addition to the parent constructor above;
        # presumably this matches the checkpoint key layout - confirm before
        # changing.
        setattr(self, self.base_model_prefix, BloomModelTransform(config))

    def forward(self, query=None, docs=None, labels=None):
        r"""
        Args:
            query (:obj: `dict`): Dict of pretrained models's input for the
                query sequence. See
                :meth:`transformers.PreTrainedTokenizer.encode` and
                :meth:`transformers.PreTrainedTokenizer.__call__` for
                details.
            docs (:obj: `dict`): Dict of pretrained models's input for the
                doc sequences. See
                :meth:`transformers.PreTrainedTokenizer.encode` and
                :meth:`transformers.PreTrainedTokenizer.__call__` for
                details.
            labels (optional): Targets for the cross-entropy loss over the
                query-doc score matrix. When omitted, query ``i`` is paired
                with doc ``i * (num_docs // num_queries)``.

        Returns:
            Returns `modelscope.outputs.SentencEmbeddingModelOutput`. The
            `loss` field is only set when both inputs are given and the
            base model is in training mode.

        Examples:
            >>> from modelscope.models import Model
            >>> from modelscope.preprocessors import Preprocessor
            >>> model = Model.from_pretrained('damo/nlp_udever_bloom_560m')
            >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_udever_bloom_560m')
            >>> inputs = preprocessor({'source_sentence': ['This is a test']})
            >>> outputs = model(**inputs)
            >>> print(outputs)
        """
        query_embeddings, doc_embeddings = None, None
        if query is not None:
            query_embeddings = self.encode(**query)
        if docs is not None:
            doc_embeddings = self.encode(**docs)
        outputs = SentencEmbeddingModelOutput(
            query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)
        # A loss needs both sides; with only one there is nothing to score.
        if query_embeddings is None or doc_embeddings is None:
            return outputs
        if self.base_model.training:
            loss_fct = torch.nn.CrossEntropyLoss()
            scores = torch.matmul(query_embeddings, doc_embeddings.T)
            if labels is None:
                labels = torch.arange(
                    scores.size(0), device=scores.device, dtype=torch.long)
                # Scale row indices by the number of docs per query so each
                # label points at the first doc of that query's group.
                labels = labels * (
                    doc_embeddings.size(0) // query_embeddings.size(0))
            outputs.loss = loss_fct(scores, labels)
        return outputs

    def encode(
        self,
        input_ids=None,
        attention_mask=None,
    ):
        """Embed a batch of token sequences.

        Runs the base model, pools its output with `self.pooler`, and
        L2-normalizes the result when `self.normalize` is set.
        """
        outputs = self.base_model.forward(
            input_ids, attention_mask=attention_mask)
        embeddings = self.pooler(outputs, attention_mask)
        if self.normalize:
            embeddings = torch.nn.functional.normalize(
                embeddings, p=2, dim=-1)
        return embeddings

    @classmethod
    def _instantiate(cls, **kwargs):
        """Instantiate the model.

        Args:
            kwargs: Input args.
                model_dir: The model dir used to load the checkpoint and the
                    label information.
                emb_pooler_type: Pooling strategy (default 'weighted_mean').
                normalize: Whether to L2-normalize embeddings (default
                    False).

        Returns:
            The loaded model, which is initialized by
            transformers.PreTrainedModel.from_pretrained when `model_dir` is
            given, or built from a fresh `BloomConfig(**kwargs)` otherwise.
        """
        model_dir = kwargs.get('model_dir')
        model_kwargs = {
            'emb_pooler_type': kwargs.get('emb_pooler_type', 'weighted_mean'),
            'normalize': kwargs.get('normalize', False)
        }
        if model_dir is None:
            config = BloomConfig(**kwargs)
            # Forward the pooling options here as well; otherwise they are
            # silently dropped and the defaults are always used.
            model = cls(config, **model_kwargs)
        else:
            model = super(BloomModelTransform, cls).from_pretrained(
                pretrained_model_name_or_path=model_dir, **model_kwargs)
        model.model_dir = model_dir
        return model