backbone.py 6.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # Copyright 2021-2022 The Alibaba DAMO Team Authors. All rights reserved.
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. from __future__ import absolute_import, division, print_function
  17. import os.path
  18. import torch
  19. import torch.nn.functional as F
  20. from torch import nn
  21. from torch.utils.checkpoint import checkpoint
  22. from transformers import (AutoConfig, DPRConfig, DPRQuestionEncoder,
  23. MT5ForConditionalGeneration, RagTokenForGeneration,
  24. XLMRobertaForSequenceClassification, XLMRobertaModel,
  25. XLMRobertaTokenizer)
  26. from modelscope.utils.logger import get_logger
  # Module-level logger obtained from modelscope's logging utilities.
  logger = get_logger()
  class Wrapper(nn.Module):
      """Adapt an encoder so it can be driven through ``checkpoint``.

      ``forward`` takes three positional tensors -- the form in which
      ``DPRModel.encode`` hands them to ``torch.utils.checkpoint.checkpoint``
      -- and returns the wrapped encoder's ``pooler_output``.
      """

      def __init__(self, encoder):
          super().__init__()
          # When ``encoder`` is an nn.Module this registers it as a submodule,
          # so its parameters are visible through the wrapper.
          self.encoder = encoder

      def forward(self, input_ids, attention_mask, dummy_tensor):
          """Return ``self.encoder(input_ids, attention_mask).pooler_output``.

          ``dummy_tensor`` is never read. ``DPRModel.encode`` passes a
          ``requires_grad=True`` tensor here so the reentrant checkpoint has
          at least one input that requires grad (integer token ids cannot),
          which is what allows gradients to reach the encoder's parameters.
          """
          encoded = self.encoder(input_ids, attention_mask)
          return encoded.pooler_output
  class DPRModel(nn.Module):
      """Dual-encoder retriever with separate query and context XLM-R towers.

      Each tower is an ``XLMRobertaModel`` built from the config stored in
      ``<model_dir>/qry_encoder`` or ``<model_dir>/ctx_encoder``. Only the
      config is read here (no ``from_pretrained``), so weights are presumably
      loaded by the caller. Both towers are wrapped in ``Wrapper`` so that
      ``encode`` can run them under gradient checkpointing.
      """

      def __init__(self, model_dir, config):
          super().__init__()
          self.config = config
          qry_encoder = XLMRobertaModel(
              config=AutoConfig.from_pretrained(
                  os.path.join(model_dir, 'qry_encoder')))
          ctx_encoder = XLMRobertaModel(
              config=AutoConfig.from_pretrained(
                  os.path.join(model_dir, 'ctx_encoder')))
          self.qry_encoder = Wrapper(qry_encoder)
          self.ctx_encoder = Wrapper(ctx_encoder)
          self.loss_fct = nn.CrossEntropyLoss()

      @staticmethod
      def encode(model, input_ids, attention_mask, gck_segment=32):
          """Encode a batch in checkpointed mini-batches and concatenate them.

          Args:
              model: module called as ``model(input_ids, attention_mask,
                  dummy_tensor)`` and returning a tensor (e.g. a ``Wrapper``).
              input_ids: tensor sliced into mini-batches along dim 0.
              attention_mask: tensor sliced the same way as ``input_ids``.
              gck_segment: number of rows per checkpointed mini-batch.

          Returns:
              The per-mini-batch outputs concatenated along dim 0.

          Raises:
              ValueError: if ``gck_segment`` is smaller than 1.
          """
          if gck_segment < 1:
              raise ValueError(
                  'gck_segment must be a positive integer, got {}'.format(
                      gck_segment))
          # Reentrant checkpointing only yields outputs that require grad when
          # at least one tensor input requires grad. Token ids cannot, so a
          # dummy tensor that does is threaded through every call.
          dummy_tensor = torch.ones(1, dtype=torch.float32, requires_grad=True)
          pooled_output = []
          for start in range(0, input_ids.shape[0], gck_segment):
              end = start + gck_segment
              pooled_output.append(
                  checkpoint(
                      model,
                      input_ids[start:end],
                      attention_mask[start:end],
                      dummy_tensor,
                      # Pinned explicitly: the dummy-tensor trick targets the
                      # reentrant variant (the historical default), and recent
                      # PyTorch warns when this argument is omitted.
                      use_reentrant=True))
          return torch.cat(pooled_output, dim=0)

      def forward(self,
                  query_input_ids,
                  query_attention_mask,
                  context_input_ids,
                  context_attention_mask,
                  labels,
                  gck_segment=32):
          """Score every query against every context in the batch.

          Returns:
              ``(loss, logits)``. ``logits[i, j]`` is the dot product of query
              ``i``'s and context ``j``'s pooled vectors. ``loss`` is
              cross-entropy of ``logits`` against ``labels`` (presumably the
              index of each query's positive context; confirm with the data
              pipeline).
          """
          query_vector = self.encode(self.qry_encoder, query_input_ids,
                                     query_attention_mask, gck_segment)
          context_vector = self.encode(self.ctx_encoder, context_input_ids,
                                       context_attention_mask, gck_segment)
          # In-batch scoring: one row of logits per query, one column per
          # context.
          logits = torch.matmul(query_vector, context_vector.T)
          loss = self.loss_fct(logits, labels)
          return loss, logits
  class ClassifyRerank(nn.Module):
      """Thin wrapper around a pretrained ``XLMRobertaForSequenceClassification``.

      ``forward`` mirrors the Hugging Face keyword signature but forwards only
      part of it. ``labels`` and any extra ``*args`` / ``**kwargs`` are
      accepted and dropped, so no loss is computed by the base model here.
      """

      def __init__(self, model_dir):
          super().__init__()
          self.base_model = XLMRobertaForSequenceClassification.from_pretrained(
              model_dir)

      def forward(self,
                  input_ids=None,
                  attention_mask=None,
                  token_type_ids=None,
                  position_ids=None,
                  head_mask=None,
                  inputs_embeds=None,
                  labels=None,
                  output_attentions=None,
                  output_hidden_states=None,
                  return_dict=None,
                  *args,
                  **kwargs):
          """Run the base classifier on the forwarded subset of arguments.

          NOTE(review): ``labels``, ``*args`` and ``**kwargs`` are ignored.
          """
          forwarded = dict(
              input_ids=input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          # ``.forward`` is invoked directly rather than ``self.base_model(...)``,
          # so hooks registered on ``base_model`` are not run.
          return self.base_model.forward(**forwarded)
  class Rerank(nn.Module):
      """Turn pairwise classifier logits into per-group passage log-probabilities.

      The encoder is called as ``encoder(**inputs)`` and its first output is
      treated as per-row class logits. The class-1 log-probability of each row
      (presumably the "relevant" class) is taken, rows are grouped into
      consecutive blocks of ``top_k`` (assumes group-major row order; confirm
      with the caller), and each block is renormalised with ``log_softmax``.
      """

      def __init__(self, encoder, top_k):
          super().__init__()
          self.encoder = encoder
          self.top_k = top_k

      def forward(self, inputs):
          """Return a ``(num_rows // top_k, top_k)`` tensor of log-probabilities."""
          class_logits = self.encoder(**inputs)[0]
          class_logprobs = F.log_softmax(class_logits, dim=-1)
          positive = class_logprobs[:, 1]
          grouped = positive.view(-1, self.top_k)
          return F.log_softmax(grouped, dim=-1)
  class Re2GModel(nn.Module):
      """Rerank-then-generate model: an XLM-R reranker feeding an mT5 RAG generator.

      The reranker (``Rerank`` over ``XLMRobertaForSequenceClassification``)
      produces ``doc_scores`` for ``top_k`` passages. These are passed with
      pre-tokenised contexts to a ``RagTokenForGeneration`` whose generator
      is ``MT5ForConditionalGeneration``. Retrieval is bypassed because the
      caller supplies ``context_input_ids`` and ``doc_scores`` directly.

      ``config`` must provide ``top_k``. ``generate`` additionally reads
      ``num_beams``, ``target_sequence_length`` and ``no_repeat_ngram_size``.
      Sub-models are built from the configs in ``<model_dir>/rerank`` and
      ``<model_dir>/generation`` only, so weights are presumably loaded
      elsewhere.
      """

      def __init__(self, model_dir, config):
          super(Re2GModel, self).__init__()
          self.config = config
          self.top_k = self.config['top_k']
          encoder = XLMRobertaForSequenceClassification(
              config=AutoConfig.from_pretrained(
                  os.path.join(model_dir, 'rerank')))
          generator = MT5ForConditionalGeneration(
              config=AutoConfig.from_pretrained(
                  os.path.join(model_dir, 'generation')))
          self.rerank = Rerank(encoder, self.top_k)
          # The RAG constructor is given a placeholder DPR question encoder
          # sized to the reranker's vocabulary (presumably because it needs
          # one to build its config). The encoder is discarded immediately,
          # since retrieval is never run.
          dpr_config = DPRConfig()
          dpr_config.vocab_size = encoder.config.vocab_size
          rag_model = RagTokenForGeneration(
              question_encoder=DPRQuestionEncoder(dpr_config),
              generator=generator)
          rag_model.rag.question_encoder = None
          self.generator = rag_model

      def forward(self, rerank_input_ids, input_ids, attention_mask, label_ids):
          """Run the reranker, then the RAG model with ``label_ids`` as labels.

          ``rerank_input_ids`` is the keyword-argument mapping handed to the
          reranker's encoder. ``input_ids`` / ``attention_mask`` are the
          generator contexts, presumably ``batch * top_k`` rows (confirm).
          Returns whatever the RAG model returns for these arguments.
          """
          doc_scores = self.rerank(rerank_input_ids)
          outputs = self.generator(
              labels=label_ids,
              context_input_ids=input_ids,
              context_attention_mask=attention_mask,
              doc_scores=doc_scores,
              n_docs=self.top_k)
          return outputs

      def generate(self, rerank_input_ids, input_ids, attention_mask):
          """Beam-search decode and return the generated token ids as a NumPy array.

          Accepts both a plain tensor and a ``ModelOutput`` carrying
          ``sequences`` as the return value of the RAG ``generate`` call.
          """
          # Pure inference: the reranker scores are consumed only by the
          # generator, so no autograd history needs to be recorded.
          with torch.no_grad():
              doc_scores = self.rerank(rerank_input_ids)
              # NOTE(review): ``encoder_input_ids`` does not look like a RAG
              # ``generate`` argument and is presumably ignored; confirm.
              beam_search_output = self.generator.generate(
                  n_docs=self.top_k,
                  encoder_input_ids=input_ids,
                  context_input_ids=input_ids,
                  context_attention_mask=attention_mask,
                  doc_scores=doc_scores,
                  num_beams=self.config['num_beams'],
                  max_length=self.config['target_sequence_length'],
                  early_stopping=True,
                  no_repeat_ngram_size=self.config['no_repeat_ngram_size'],
                  return_dict_in_generate=True,
                  output_scores=True)
          # With ``return_dict_in_generate=True`` Hugging Face ``generate``
          # returns a ModelOutput (no ``.detach``) holding the token ids in
          # ``.sequences``. A bare tensor is passed through unchanged.
          generated_ids = getattr(beam_search_output, 'sequences',
                                  beam_search_output)
          return generated_ids.detach().cpu().numpy()