sentence_embedding.py 6.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import torch
  3. from transformers import BloomConfig
  4. from transformers import BloomModel as BloomModelTransform
  5. from modelscope.metainfo import Models
  6. from modelscope.models import MODELS, TorchModel
  7. from modelscope.outputs import SentencEmbeddingModelOutput
  8. from modelscope.utils.constant import Tasks
  9. class DecoderPooler(torch.nn.Module):
  10. """
  11. Parameter-free poolers to get the sentence embedding
  12. 'last': the last token state.
  13. 'weighted_mean': position weighted average of all token states.
  14. """
  15. def __init__(self, pooler_type):
  16. super().__init__()
  17. self.pooler_type = pooler_type
  18. assert self.pooler_type in [
  19. 'last', 'weighted_mean'
  20. ], 'unrecognized pooling type %s' % self.pooler_type
  21. def forward(self, outputs, attention_mask):
  22. last_hidden = outputs.last_hidden_state
  23. if self.pooler_type in ['last']:
  24. n, l, h = last_hidden.shape
  25. # Get shape [n] indices of the last token (i.e. the last token for each batch item)
  26. # Any sequence where min == 1, we use the entire sequence length since argmin = 0
  27. values, indices = torch.min(attention_mask, 1, keepdim=False)
  28. gather_indices = torch.where(values == 0, indices,
  29. l) - 1 # Shape [n]
  30. # There are empty sequences, where the index would become -1 which will crash
  31. gather_indices = torch.clamp(gather_indices, min=0)
  32. # Turn indices from shape [n] --> [n, 1, h]
  33. gather_indices = gather_indices.unsqueeze(1).unsqueeze(1).expand(
  34. n, 1, h)
  35. # Gather along the 1st dim (l) (n, l, h -> n, h)
  36. pooled_output = torch.gather(last_hidden, 1,
  37. gather_indices).squeeze(dim=1)
  38. elif self.pooler_type == 'weighted_mean':
  39. input_mask_expanded = attention_mask.unsqueeze(-1).expand(
  40. last_hidden.size()).float()
  41. # last_hidden shape: bs, seq, hidden_dim
  42. weights = (
  43. torch.arange(start=1, end=last_hidden.shape[1]
  44. + 1).unsqueeze(0).unsqueeze(-1).expand(
  45. last_hidden.size()).float().to(
  46. last_hidden.device))
  47. assert weights.shape == last_hidden.shape == input_mask_expanded.shape
  48. input_mask_expanded = input_mask_expanded * weights
  49. sum_embeddings = torch.sum(last_hidden * input_mask_expanded, 1)
  50. sum_mask = input_mask_expanded.sum(1)
  51. sum_mask = torch.clamp(sum_mask, min=1e-9)
  52. pooled_output = sum_embeddings / sum_mask
  53. else:
  54. raise NotImplementedError
  55. return pooled_output
  56. @MODELS.register_module(
  57. group_key=Tasks.sentence_embedding, module_name=Models.bloom)
  58. class BloomForSentenceEmbedding(BloomModelTransform, TorchModel):
  59. r"""
  60. This model represent a text to a dense vector by the last token state or weighted mean of all token states.
  61. See `Language Models are Universal Embedders
  62. <https://arxiv.org/pdf/2310.08232.pdf>`_ for details.
  63. """
  64. def __init__(self, config, **kwargs):
  65. super().__init__(config)
  66. self.config = config
  67. self.pooler_type = kwargs.get('emb_pooler_type', 'weighted_mean')
  68. self.pooler = DecoderPooler(self.pooler_type)
  69. self.normalize = kwargs.get('normalize', False)
  70. setattr(self, self.base_model_prefix, BloomModelTransform(config))
  71. def forward(self, query=None, docs=None, labels=None):
  72. r"""
  73. Args:
  74. query (:obj: `dict`): Dict of pretrained models's input for the query sequence. See
  75. :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
  76. for details.
  77. docs (:obj: `dict`): Dict of pretrained models's input for the query sequence. See
  78. :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__`
  79. for details.
  80. Returns:
  81. Returns `modelscope.outputs.SentencEmbeddingModelOutput
  82. Examples:
  83. >>> from modelscope.models import Model
  84. >>> from modelscope.preprocessors import Preprocessor
  85. >>> model = Model.from_pretrained('damo/nlp_udever_bloom_560m')
  86. >>> preprocessor = Preprocessor.from_pretrained('damo/nlp_udever_bloom_560m')
  87. >>> inputs = preprocessor({'source_sentence': ['This is a test']})
  88. >>> outputs = model(**inputs)
  89. >>> print(outputs)
  90. """
  91. query_embeddings, doc_embeddings = None, None
  92. if query is not None:
  93. query_embeddings = self.encode(**query)
  94. if docs is not None:
  95. doc_embeddings = self.encode(**docs)
  96. outputs = SentencEmbeddingModelOutput(
  97. query_embeddings=query_embeddings, doc_embeddings=doc_embeddings)
  98. if query_embeddings is None or doc_embeddings is None:
  99. return outputs
  100. if self.base_model.training:
  101. loss_fct = torch.nn.CrossEntropyLoss()
  102. scores = torch.matmul(query_embeddings, doc_embeddings.T)
  103. if labels is None:
  104. labels = torch.arange(
  105. scores.size(0), device=scores.device, dtype=torch.long)
  106. labels = labels * (
  107. doc_embeddings.size(0) // query_embeddings.size(0))
  108. loss = loss_fct(scores, labels)
  109. outputs.loss = loss
  110. return outputs
  111. def encode(
  112. self,
  113. input_ids=None,
  114. attention_mask=None,
  115. ):
  116. outputs = self.base_model.forward(
  117. input_ids, attention_mask=attention_mask)
  118. embeddings = self.pooler(outputs, attention_mask)
  119. if self.normalize:
  120. embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
  121. return embeddings
  122. @classmethod
  123. def _instantiate(cls, **kwargs):
  124. """Instantiate the model.
  125. Args:
  126. kwargs: Input args.
  127. model_dir: The model dir used to load the checkpoint and the label information.
  128. Returns:
  129. The loaded model, which is initialized by transformers.PreTrainedModel.from_pretrained
  130. """
  131. model_dir = kwargs.get('model_dir')
  132. model_kwargs = {
  133. 'emb_pooler_type': kwargs.get('emb_pooler_type', 'weighted_mean'),
  134. 'normalize': kwargs.get('normalize', False)
  135. }
  136. if model_dir is None:
  137. config = BloomConfig(**kwargs)
  138. model = cls(config)
  139. else:
  140. model = super(BloomModelTransform, cls).from_pretrained(
  141. pretrained_model_name_or_path=model_dir, **model_kwargs)
  142. model.model_dir = model_dir
  143. return model