modeling_rag.py 87 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665
  1. # coding=utf-8
  2. # Copyright 2020, The RAG Authors and The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """RAG model implementation."""
  16. import copy
  17. from dataclasses import dataclass
  18. from typing import Callable, Optional, Union
  19. import torch
  20. from torch import nn
  21. from ...cache_utils import Cache, EncoderDecoderCache
  22. from ...configuration_utils import PretrainedConfig
  23. from ...generation import GenerationConfig, GenerationMixin, LogitsProcessorList, StoppingCriteriaList
  24. from ...modeling_outputs import ModelOutput
  25. from ...modeling_utils import PreTrainedModel
  26. from ...utils import auto_docstring, logging
  27. from .configuration_rag import RagConfig
  28. from .retrieval_rag import RagRetriever
# Module-level logger, created through the library's `utils.logging` helper and namespaced by `__name__`.
logger = logging.get_logger(__name__)
  30. @dataclass
  31. @auto_docstring(
  32. custom_intro="""
  33. Base class for retriever augmented marginalized models outputs.
  34. """
  35. )
  36. class RetrievAugLMMarginOutput(ModelOutput):
  37. r"""
  38. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  39. Language modeling loss.
  40. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  41. Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
  42. each vocabulary token.
  43. doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
  44. Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
  45. `question_encoder_last_hidden_state`.
  46. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  47. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  48. Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
  49. (see `past_key_values` input) to speed up sequential decoding.
  50. retrieved_doc_embeds (`torch.FloatTensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
  51. Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
  52. the `doc_scores`.
  53. retrieved_doc_ids (`torch.LongTensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
  54. The indexes of the embedded documents retrieved by the retriever.
  55. context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  56. Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
  57. context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  58. Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
  59. retriever.
  60. question_encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  61. Sequence of hidden states at the output of the last layer of the question encoder pooled output of the
  62. model.
  63. question_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  64. Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
  65. shape `(batch_size, sequence_length, hidden_size)`.
  66. Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
  67. question_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  68. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  69. sequence_length)`.
  70. Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
  71. average in the self-attention heads.
  72. generator_enc_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  73. Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
  74. generator_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  75. Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
  76. shape `(batch_size, sequence_length, hidden_size)`.
  77. Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
  78. generator_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  79. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  80. sequence_length)`.
  81. Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
  82. average in the self-attention heads.
  83. generator_dec_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  84. Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
  85. shape `(batch_size, sequence_length, hidden_size)`.
  86. Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
  87. generator_dec_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  88. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  89. sequence_length)`.
  90. Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
  91. average in the self-attention heads.
  92. generator_cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  93. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  94. sequence_length)`.
  95. Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
  96. weighted average in the cross-attention heads.
  97. """
  98. loss: Optional[torch.FloatTensor] = None
  99. logits: Optional[torch.FloatTensor] = None
  100. doc_scores: Optional[torch.FloatTensor] = None
  101. past_key_values: Optional[Cache] = None
  102. retrieved_doc_embeds: Optional[torch.FloatTensor] = None
  103. retrieved_doc_ids: Optional[torch.LongTensor] = None
  104. context_input_ids: Optional[torch.LongTensor] = None
  105. context_attention_mask: Optional[torch.LongTensor] = None
  106. question_encoder_last_hidden_state: Optional[torch.FloatTensor] = None
  107. question_enc_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  108. question_enc_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  109. generator_enc_last_hidden_state: Optional[torch.FloatTensor] = None
  110. generator_enc_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  111. generator_enc_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  112. generator_dec_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  113. generator_dec_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  114. generator_cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  115. @dataclass
  116. @auto_docstring
  117. class RetrievAugLMOutput(ModelOutput):
  118. r"""
  119. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  120. Prediction scores of the language modeling head. The score is possibly marginalized over all documents for
  121. each vocabulary token.
  122. doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
  123. Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
  124. `question_encoder_last_hidden_state`.
  125. past_key_values (`Cache`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
  126. It is a [`~cache_utils.Cache`] instance. For more details, see our [kv cache guide](https://huggingface.co/docs/transformers/en/kv_cache).
  127. Contains precomputed hidden-states (key and values in the attention blocks) of the decoder that can be used
  128. (see `past_key_values` input) to speed up sequential decoding.
  129. retrieved_doc_embeds (`torch.FloatTensor` of shape `(batch_size, config.n_docs, hidden_size)`, *optional*, returned when *output_retrieved=True*):
  130. Embedded documents retrieved by the retriever. Is used with `question_encoder_last_hidden_state` to compute
  131. the `doc_scores`.
  132. retrieved_doc_ids (`torch.LongTensor` of shape `(batch_size, config.n_docs)`, *optional*, returned when *output_retrieved=True*):
  133. The indexes of the embedded documents retrieved by the retriever.
  134. context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  135. Input ids post-processed from the retrieved documents and the question encoder input_ids by the retriever.
  136. context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  137. Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
  138. retriever.
  139. question_encoder_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  140. Sequence of hidden states at the output of the last layer of the question encoder pooled output of the
  141. model.
  142. question_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  143. Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
  144. shape `(batch_size, sequence_length, hidden_size)`.
  145. Hidden states of the question encoder at the output of each layer plus the initial embedding outputs.
  146. question_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  147. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  148. sequence_length)`.
  149. Attentions weights of the question encoder, after the attention softmax, used to compute the weighted
  150. average in the self-attention heads.
  151. generator_enc_last_hidden_state (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  152. Sequence of hidden-states at the output of the last layer of the generator encoder of the model.
  153. generator_enc_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  154. Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
  155. shape `(batch_size, sequence_length, hidden_size)`.
  156. Hidden states of the generator encoder at the output of each layer plus the initial embedding outputs.
  157. generator_enc_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  158. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  159. sequence_length)`.
  160. Attentions weights of the generator encoder, after the attention softmax, used to compute the weighted
  161. average in the self-attention heads.
  162. generator_dec_hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
  163. Tuple of `torch.FloatTensor` (one for the output of the embeddings and one for the output of each layer) of
  164. shape `(batch_size, sequence_length, hidden_size)`.
  165. Hidden states of the generator decoder at the output of each layer plus the initial embedding outputs.
  166. generator_dec_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  167. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  168. sequence_length)`.
  169. Attentions weights of the generator decoder, after the attention softmax, used to compute the weighted
  170. average in the self-attention heads.
  171. generator_cross_attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
  172. Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
  173. sequence_length)`.
  174. Cross-attentions weights of the generator decoder, after the attention softmax, used to compute the
  175. weighted average in the cross-attention heads.
  176. """
  177. logits: Optional[torch.FloatTensor] = None
  178. doc_scores: Optional[torch.FloatTensor] = None
  179. past_key_values: Optional[Cache] = None
  180. retrieved_doc_embeds: Optional[torch.FloatTensor] = None
  181. retrieved_doc_ids: Optional[torch.LongTensor] = None
  182. context_input_ids: Optional[torch.LongTensor] = None
  183. context_attention_mask: Optional[torch.LongTensor] = None
  184. question_encoder_last_hidden_state: Optional[torch.FloatTensor] = None
  185. question_enc_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  186. question_enc_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  187. generator_enc_last_hidden_state: Optional[torch.FloatTensor] = None
  188. generator_enc_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  189. generator_enc_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  190. generator_dec_hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  191. generator_dec_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  192. generator_cross_attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  193. @auto_docstring(
  194. custom_intro="""
  195. RAG models were released with the paper [Retrieval-Augmented Generation for Knowledge-Intensive NLP
  196. Tasks](https://huggingface.co/papers/2005.11401) by Patrick Lewis, Ethan Perez, Aleksandra Piktus et al.
  197. RAG is a retriever augmented model and encapsulate three components: a question encoder, a dataset retriever and a
  198. generator, the encoder and generator are trainable while the retriever is just an indexed dataset.
  199. """
  200. )
  201. @auto_docstring
  202. class RagPreTrainedModel(PreTrainedModel):
  203. config: RagConfig
  204. base_model_prefix = "rag"
  205. _supports_flash_attn = True
  206. _supports_sdpa = True
  207. @classmethod
  208. def from_pretrained_question_encoder_generator(
  209. cls,
  210. question_encoder_pretrained_model_name_or_path: Optional[str] = None,
  211. generator_pretrained_model_name_or_path: Optional[str] = None,
  212. retriever: RagRetriever = None,
  213. **kwargs,
  214. ) -> PreTrainedModel:
  215. r"""
  216. Instantiates an question encoder and a generator from one or two base classes of the library from pretrained
  217. model checkpoints.
  218. The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
  219. the model, you need to first set it back in training mode with `model.train()`.
  220. Params:
  221. question_encoder_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
  222. Information necessary to initiate the question encoder. Can be either:
  223. - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
  224. - A path to a *directory* containing model weights saved using
  225. [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
  226. - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
  227. this case, `from_tf` should be set to `True` and a configuration object should be provided as
  228. `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
  229. PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
  230. generator_pretrained_model_name_or_path (`str`, *optional*, defaults to `None`):
  231. Information necessary to initiate the generator. Can be either:
  232. - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
  233. - A path to a *directory* containing model weights saved using
  234. [`~PreTrainedModel.save_pretrained`], e.g., `./my_model_directory/`.
  235. - A path or url to a *tensorflow index checkpoint file* (e.g, `./tf_model/model.ckpt.index`). In
  236. this case, `from_tf` should be set to `True` and a configuration object should be provided as
  237. `config` argument. This loading path is slower than converting the TensorFlow checkpoint in a
  238. PyTorch model using the provided conversion scripts and loading the PyTorch model afterwards.
  239. model_args (remaining positional arguments, *optional*):
  240. All remaining positional arguments will be passed to the underlying model's `__init__` method.
  241. retriever ([`RagRetriever`], *optional*):
  242. The retriever to use.
  243. kwwargs (remaining dictionary of keyword arguments, *optional*):
  244. Can be used to update the configuration object (after it being loaded) and initiate the model (e.g.,
  245. `output_attentions=True`).
  246. - To update the question_encoder configuration, use the prefix *question_encoder_* for each
  247. configuration parameter.
  248. - To update the generator configuration, use the prefix *generator_* for each configuration parameter.
  249. - To update the parent model configuration, do not use a prefix for each configuration parameter.
  250. Behaves differently depending on whether a `config` is provided or automatically loaded.
  251. Example:
  252. ```python
  253. >>> from transformers import RagModel
  254. >>> # initialize a RAG from two pretrained models.
  255. >>> model = RagModel.from_pretrained_question_encoder_generator(
  256. ... "facebook/dpr-question_encoder-single-nq-base", "google-t5/t5-small"
  257. ... )
  258. >>> # saving model after fine-tuning
  259. >>> model.save_pretrained("./rag")
  260. >>> # load fine-tuned model
  261. >>> model = RagModel.from_pretrained("./rag")
  262. ```"""
  263. kwargs_question_encoder = {
  264. argument[len("question_encoder_") :]: value
  265. for argument, value in kwargs.items()
  266. if argument.startswith("question_encoder_")
  267. }
  268. kwargs_generator = {
  269. argument[len("generator_") :]: value
  270. for argument, value in kwargs.items()
  271. if argument.startswith("generator_")
  272. }
  273. # remove question_encoder, generator kwargs from kwargs
  274. for key in kwargs_question_encoder:
  275. del kwargs["question_encoder_" + key]
  276. for key in kwargs_generator:
  277. del kwargs["generator_" + key]
  278. # Load and initialize the question_encoder and generator
  279. # The distinction between question_encoder and generator at the model level is made
  280. # by the value of the flag `is_generator` that we need to set correctly.
  281. question_encoder = kwargs_question_encoder.pop("model", None)
  282. if question_encoder is None:
  283. assert question_encoder_pretrained_model_name_or_path is not None, (
  284. "If `model` is not defined as an argument, a `question_encoder_pretrained_model_name_or_path` has to"
  285. " be defined"
  286. )
  287. from ..auto.modeling_auto import AutoModel
  288. if "config" not in kwargs_question_encoder:
  289. from ..auto.configuration_auto import AutoConfig
  290. question_encoder_config, kwargs_question_encoder = AutoConfig.from_pretrained(
  291. question_encoder_pretrained_model_name_or_path,
  292. **kwargs_question_encoder,
  293. return_unused_kwargs=True,
  294. )
  295. kwargs_question_encoder["config"] = question_encoder_config
  296. question_encoder = AutoModel.from_pretrained(
  297. question_encoder_pretrained_model_name_or_path, **kwargs_question_encoder
  298. )
  299. generator = kwargs_generator.pop("model", None)
  300. if generator is None:
  301. assert generator_pretrained_model_name_or_path is not None, (
  302. "If `generator_model` is not defined as an argument, a `generator_pretrained_model_name_or_path` has"
  303. " to be defined"
  304. )
  305. from ..auto.modeling_auto import AutoModelForSeq2SeqLM
  306. if "config" not in kwargs_generator:
  307. from ..auto.configuration_auto import AutoConfig
  308. generator_config, kwargs_generator = AutoConfig.from_pretrained(
  309. generator_pretrained_model_name_or_path, **kwargs_generator, return_unused_kwargs=True
  310. )
  311. kwargs_generator["config"] = generator_config
  312. generator = AutoModelForSeq2SeqLM.from_pretrained(
  313. generator_pretrained_model_name_or_path, **kwargs_generator
  314. )
  315. # instantiate config with corresponding kwargs
  316. config = kwargs.get("config")
  317. if config is None:
  318. config = RagConfig.from_question_encoder_generator_configs(
  319. question_encoder.config, generator.config, **kwargs
  320. )
  321. return cls(question_encoder=question_encoder, generator=generator, config=config, retriever=retriever)
  322. @auto_docstring
  323. class RagModel(RagPreTrainedModel):
  324. def __init__(
  325. self,
  326. config: Optional[PretrainedConfig] = None,
  327. question_encoder: Optional[PreTrainedModel] = None,
  328. generator: Optional[PreTrainedModel] = None,
  329. retriever: Optional[RagRetriever] = None, # or maybe just use a `set_retriever(...)` method
  330. **kwargs,
  331. ):
  332. r"""
  333. question_encoder (`PreTrainedModel`, *optional*):
  334. The model responsible for encoding the question into hidden states for retrieval.
  335. generator (`PreTrainedModel`, *optional*):
  336. The model responsible for generating text based on retrieved documents.
  337. retriever (`RagRetriever`, *optional*):
  338. The component responsible for retrieving documents from a knowledge base given the encoded question.
  339. """
  340. assert config is not None or (question_encoder is not None and generator is not None), (
  341. "Either a configuration or an question_encoder and a generator has to be provided."
  342. )
  343. if config is None:
  344. config = RagConfig.from_question_encoder_generator_configs(
  345. question_encoder.config, generator.config, **kwargs
  346. )
  347. else:
  348. assert isinstance(config, self.config_class), f"config: {config} has to be of type {self.config_class}"
  349. super().__init__(config)
  350. if question_encoder is None:
  351. from ..auto.modeling_auto import AutoModel
  352. question_encoder = AutoModel.from_config(config.question_encoder)
  353. if generator is None:
  354. from ..auto.modeling_auto import AutoModelForSeq2SeqLM
  355. generator = AutoModelForSeq2SeqLM.from_config(config.generator)
  356. self.retriever = retriever
  357. if self.retriever is not None:
  358. assert isinstance(retriever, RagRetriever), (
  359. f"`self.retriever` is of type {type(self.retriever)}, but should be of type `RagRetriever`"
  360. )
  361. self.retriever = retriever
  362. self.question_encoder = question_encoder
  363. self.generator = generator
  364. self.ctx_encoder = None
  365. self.context_encoder_training = False
  @auto_docstring
  def forward(
      self,
      input_ids: Optional[torch.LongTensor] = None,
      attention_mask: Optional[torch.Tensor] = None,
      encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
      decoder_input_ids: Optional[torch.LongTensor] = None,
      decoder_attention_mask: Optional[torch.BoolTensor] = None,
      past_key_values: Optional[Cache] = None,
      doc_scores: Optional[torch.FloatTensor] = None,
      context_input_ids: Optional[torch.LongTensor] = None,
      context_attention_mask: Optional[torch.LongTensor] = None,
      use_cache: Optional[bool] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
      output_retrieved: Optional[bool] = None,
      n_docs: Optional[int] = None,
  ) -> Union[tuple[torch.Tensor], RetrievAugLMOutput]:
      r"""
      input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
          Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
          which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
          obtain the indices.

          [What are input IDs?](../glossary#input-ids)
      encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
          Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
          *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
          sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
          generator's encoder.

          Used by the ([`RagModel`]) model during decoding.
      decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
          Provide for generation tasks. `None` by default, construct as per instructions for the generator model
          you're using with your RAG instance.
      decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
          Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
          be used by default.
      doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
          Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
          `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever`, `doc_scores`
          has to be provided to the forward pass. `doc_scores` can be computed via
          `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
      context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
          Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
          retriever. If the model was not initialized with a `retriever`, `context_input_ids` has to be provided to
          the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
      context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
          Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
          retriever. If the model is not initialized with a `retriever`, `context_attention_mask` has to be
          provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
      output_retrieved (`bool`, *optional*):
          Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
          `context_attention_mask`. See returned tensors for more detail.
      n_docs (`int`, *optional*):
          The number of documents to retrieve.

      Example:

      ```python
      >>> from transformers import AutoTokenizer, RagRetriever, RagModel
      >>> import torch

      >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-base")
      >>> retriever = RagRetriever.from_pretrained(
      ...     "facebook/rag-token-base", index_name="exact", use_dummy_dataset=True
      ... )
      >>> # initialize with RagRetriever to do everything in one forward call
      >>> model = RagModel.from_pretrained("facebook/rag-token-base", retriever=retriever)

      >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
      >>> outputs = model(input_ids=inputs["input_ids"])
      ```"""
      n_docs = n_docs if n_docs is not None else self.config.n_docs
      use_cache = use_cache if use_cache is not None else self.config.use_cache
      output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
      output_hidden_states = (
          output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
      )
      output_retrieved = output_retrieved if output_retrieved is not None else self.config.output_retrieved

      # The retriever is only run when one is attached, the caller has not supplied the complete set of
      # retrieval results, and the generator encoder has not already been run (`encoder_outputs` given).
      has_to_retrieve = (
          self.retriever is not None
          and (context_input_ids is None or context_attention_mask is None or doc_scores is None)
          and encoder_outputs is None
      )

      # encoder_outputs are pre-computed during RAG-token generation
      if encoder_outputs is None:
          if has_to_retrieve:
              question_enc_outputs = self.question_encoder(
                  input_ids,
                  attention_mask=attention_mask,
                  output_attentions=output_attentions,
                  output_hidden_states=output_hidden_states,
                  return_dict=True,
              )
              question_encoder_last_hidden_state = question_enc_outputs[0]  # hidden states of question encoder

              # The retriever is given a detached CPU float32 NumPy copy of the question embeddings.
              retriever_outputs = self.retriever(
                  input_ids,
                  question_encoder_last_hidden_state.detach().to(device="cpu", dtype=torch.float32).numpy(),
                  prefix=self.generator.config.prefix,
                  n_docs=n_docs,
                  return_tensors="pt",
              )

              if self.context_encoder_training:
                  # Re-embed the retrieved documents with `self.ctx_encoder`, so `doc_scores` below is computed
                  # from the context encoder's output instead of the embeddings returned by the retriever.
                  retrieved_doc_input_ids = retriever_outputs["tokenized_doc_ids"].to(input_ids)
                  retrieved_doc_attention_mask = retriever_outputs["tokenized_doc_attention_mask"].to(input_ids)
                  retrieved_doc_embeds = self.ctx_encoder(
                      retrieved_doc_input_ids, attention_mask=retrieved_doc_attention_mask, return_dict=True
                  ).pooler_output
                  # reshape to (-1, n_docs, question hidden size) so it can be batch-multiplied below
                  retrieved_doc_embeds = retrieved_doc_embeds.view(
                      -1, n_docs, question_encoder_last_hidden_state.shape[1]
                  )
              else:
                  # set to correct device / dtype
                  retrieved_doc_embeds = retriever_outputs["retrieved_doc_embeds"].to(
                      question_encoder_last_hidden_state
                  )

              context_input_ids = retriever_outputs["context_input_ids"].to(input_ids)
              context_attention_mask = retriever_outputs["context_attention_mask"].to(input_ids)
              retrieved_doc_ids = retriever_outputs["doc_ids"]

              # doc_scores[b, d] is the dot product of question embedding b with document embedding (b, d)
              doc_scores = torch.bmm(
                  question_encoder_last_hidden_state.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)
              ).squeeze(1)
          else:
              assert context_input_ids is not None, (
                  "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can"
                  " set a retriever using the `set_retriever(...)` function."
              )
              assert context_attention_mask is not None, (
                  "Make sure that `context_attention_mask` are passed, if no `retriever` is set. Alternatively, you"
                  " can set a retriever using the `set_retriever(...)` function."
              )
              assert doc_scores is not None, (
                  "Make sure that `doc_scores` are passed, if no `retriever` is set. Alternatively, you can set a"
                  " retriever using the `set_retriever(...)` function."
              )

      assert doc_scores is not None, (
          "Make sure that `doc_scores` are passed when passing `encoder_outputs` to the forward function."
      )

      # `context_input_ids` may legitimately be None here (pre-computed `encoder_outputs`), so the message only
      # refers to `doc_scores`, which is the tensor actually being checked.
      assert (doc_scores.shape[1] % n_docs) == 0, (
          f" The second dimension of `doc_scores` should be a multiple of `n_docs`={n_docs}, but is"
          f" {doc_scores.shape[1]}."
      )

      # Decoder input without context documents: repeat every target once per document so that it lines up
      # with the `batch_size * n_docs` rows of `context_input_ids`.
      if decoder_input_ids is not None:
          decoder_input_ids = decoder_input_ids.repeat_interleave(n_docs, dim=0)

      if decoder_attention_mask is not None:
          decoder_attention_mask = decoder_attention_mask.repeat_interleave(n_docs, dim=0)

      gen_outputs = self.generator(
          input_ids=context_input_ids,
          attention_mask=context_attention_mask,
          encoder_outputs=encoder_outputs,
          decoder_input_ids=decoder_input_ids,
          decoder_attention_mask=decoder_attention_mask,
          past_key_values=past_key_values,
          use_cache=use_cache,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
          return_dict=True,
      )

      if has_to_retrieve:
          question_enc_hidden_states = question_enc_outputs.hidden_states
          question_enc_attentions = question_enc_outputs.attentions
      else:
          question_encoder_last_hidden_state = None
          question_enc_hidden_states = None
          question_enc_attentions = None
          retrieved_doc_embeds = None
          retrieved_doc_ids = None

      if not has_to_retrieve or not output_retrieved:
          # don't output retrieved docs
          context_input_ids = None
          context_attention_mask = None
          retrieved_doc_embeds = None
          retrieved_doc_ids = None

      return RetrievAugLMOutput(
          logits=gen_outputs.logits,
          doc_scores=doc_scores,
          past_key_values=gen_outputs.past_key_values,
          context_input_ids=context_input_ids,
          context_attention_mask=context_attention_mask,
          retrieved_doc_embeds=retrieved_doc_embeds,
          retrieved_doc_ids=retrieved_doc_ids,
          question_encoder_last_hidden_state=question_encoder_last_hidden_state,
          question_enc_hidden_states=question_enc_hidden_states,
          question_enc_attentions=question_enc_attentions,
          generator_enc_last_hidden_state=gen_outputs.encoder_last_hidden_state,
          generator_enc_hidden_states=gen_outputs.encoder_hidden_states,
          generator_enc_attentions=gen_outputs.encoder_attentions,
          generator_dec_hidden_states=gen_outputs.decoder_hidden_states,
          generator_dec_attentions=gen_outputs.decoder_attentions,
          generator_cross_attentions=gen_outputs.cross_attentions,
      )
  @auto_docstring(
      custom_intro="""
      A RAG-sequence model implementation. It performs RAG-sequence specific marginalization in the forward pass.
      """
  )
  class RagSequenceForGeneration(RagPreTrainedModel):
      def __init__(
          self,
          config: Optional[PretrainedConfig] = None,
          question_encoder: Optional[PreTrainedModel] = None,
          generator: Optional[PreTrainedModel] = None,
          retriever: Optional[RagRetriever] = None,
          **kwargs,
      ):
          r"""
          question_encoder (`PreTrainedModel`, *optional*):
              The model responsible for encoding the question into hidden states for retrieval.
          generator (`PreTrainedModel`, *optional*):
              The model responsible for generating text based on retrieved documents.
          retriever (`RagRetriever`, *optional*):
              The component responsible for retrieving documents from a knowledge base given the encoded question.
          """
          assert config is not None or (question_encoder is not None and generator is not None), (
              "Either a configuration or an encoder and a generator has to be provided."
          )

          if config is None:
              # Build the composite RAG config from the two sub-model configs.
              config = RagConfig.from_question_encoder_generator_configs(
                  question_encoder.config, generator.config, **kwargs
              )
          super().__init__(config)

          # instantiate model
          self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)

      def set_retriever(self, retriever: RagRetriever):
          """Attach (or replace) the retriever used by the underlying `RagModel`."""
          self.rag.retriever = retriever

      def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
          """
          Switch the underlying `RagModel` to context-encoder training: retrieved documents are re-embedded with
          `ctx_encoder` and `doc_scores` is computed from those embeddings.
          """
          self.rag.context_encoder_training = True
          self.rag.ctx_encoder = ctx_encoder

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.Tensor] = None,
          encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
          decoder_input_ids: Optional[torch.LongTensor] = None,
          decoder_attention_mask: Optional[torch.BoolTensor] = None,
          past_key_values: Optional[Cache] = None,
          context_input_ids: Optional[torch.LongTensor] = None,
          context_attention_mask: Optional[torch.LongTensor] = None,
          doc_scores: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          output_retrieved: Optional[bool] = None,
          exclude_bos_score: Optional[bool] = None,
          reduce_loss: Optional[bool] = None,
          labels: Optional[torch.LongTensor] = None,
          n_docs: Optional[int] = None,
          **kwargs,  # needs kwargs for generation
      ) -> RetrievAugLMMarginOutput:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
              Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
              which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
              obtain the indices.

              [What are input IDs?](../glossary#input-ids)
          encoder_outputs (`tuple(tuple(torch.FloatTensor))`, *optional*):
              Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
              *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
              sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
              generator's encoder.

              Used by the ([`RagModel`]) model during decoding.
          decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Provide for generation tasks. `None` by default, construct as per instructions for the generator model
              you're using with your RAG instance.
          decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
              Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
              be used by default.
          context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
              Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
              retriever. If the model was not initialized with a `retriever`, `context_input_ids` has to be provided to
              the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
          context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
              Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
              retriever. If the model is not initialized with a `retriever`, `context_attention_mask` has to be
              provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
          doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
              Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
              `question_encoder_last_hidden_state`. If the model is not initialized with a `retriever`, `doc_scores`
              has to be provided to the forward pass. `doc_scores` can be computed via
              `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
          output_retrieved (`bool`, *optional*):
              Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
              `context_attention_mask`. See returned tensors for more detail.
          exclude_bos_score (`bool`, *optional*):
              Only relevant if `labels` is passed. If `True`, the score of the BOS token is disregarded when computing
              the loss.
          reduce_loss (`bool`, *optional*):
              Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `torch.Tensor.sum`
              operation.
          n_docs (`int`, *optional*):
              The number of documents to retrieve.

          Example:

          ```python
          >>> from transformers import AutoTokenizer, RagRetriever, RagSequenceForGeneration
          >>> import torch

          >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-sequence-nq")
          >>> retriever = RagRetriever.from_pretrained(
          ...     "facebook/rag-sequence-nq", index_name="exact", use_dummy_dataset=True
          ... )
          >>> # initialize with RagRetriever to do everything in one forward call
          >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", retriever=retriever)

          >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
          >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
          >>> input_ids = inputs["input_ids"]
          >>> labels = targets["input_ids"]
          >>> outputs = model(input_ids=input_ids, labels=labels)

          >>> # or use retriever separately
          >>> model = RagSequenceForGeneration.from_pretrained("facebook/rag-sequence-nq", use_dummy_dataset=True)
          >>> # 1. Encode
          >>> question_hidden_states = model.question_encoder(input_ids)[0]
          >>> # 2. Retrieve
          >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
          >>> doc_scores = torch.bmm(
          ...     question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
          ... ).squeeze(1)
          >>> # 3. Forward to generator
          >>> outputs = model(
          ...     context_input_ids=docs_dict["context_input_ids"],
          ...     context_attention_mask=docs_dict["context_attention_mask"],
          ...     doc_scores=doc_scores,
          ...     decoder_input_ids=labels,
          ... )
          ```"""
          n_docs = n_docs if n_docs is not None else self.config.n_docs
          exclude_bos_score = exclude_bos_score if exclude_bos_score is not None else self.config.exclude_bos_score
          reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss

          if labels is not None:
              if decoder_input_ids is None:
                  decoder_input_ids = labels
              # no incremental decoding cache is needed when computing the loss
              use_cache = False

          outputs = self.rag(
              input_ids=input_ids,
              attention_mask=attention_mask,
              encoder_outputs=encoder_outputs,
              decoder_input_ids=decoder_input_ids,
              decoder_attention_mask=decoder_attention_mask,
              context_input_ids=context_input_ids,
              context_attention_mask=context_attention_mask,
              doc_scores=doc_scores,
              past_key_values=past_key_values,
              use_cache=use_cache,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              output_retrieved=output_retrieved,
              n_docs=n_docs,
          )

          loss = None
          if labels is not None:
              loss = self.get_nll(
                  outputs.logits,
                  outputs.doc_scores,
                  decoder_input_ids,
                  reduce_loss=reduce_loss,
                  epsilon=self.config.label_smoothing,
                  exclude_bos_score=exclude_bos_score,
                  n_docs=n_docs,
              )

          return RetrievAugLMMarginOutput(
              loss=loss,
              logits=outputs.logits,
              doc_scores=outputs.doc_scores,
              past_key_values=outputs.past_key_values,
              context_input_ids=outputs.context_input_ids,
              context_attention_mask=outputs.context_attention_mask,
              retrieved_doc_embeds=outputs.retrieved_doc_embeds,
              retrieved_doc_ids=outputs.retrieved_doc_ids,
              question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state,
              question_enc_hidden_states=outputs.question_enc_hidden_states,
              question_enc_attentions=outputs.question_enc_attentions,
              generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state,
              generator_enc_hidden_states=outputs.generator_enc_hidden_states,
              generator_enc_attentions=outputs.generator_enc_attentions,
              generator_dec_hidden_states=outputs.generator_dec_hidden_states,
              generator_dec_attentions=outputs.generator_dec_attentions,
              generator_cross_attentions=outputs.generator_cross_attentions,
          )

      @property
      def retriever(self):
          """The retriever attached to the underlying `RagModel` (`None` if none has been set)."""
          return self.rag.retriever

      @property
      def generator(self):
          """The generator model held by the underlying `RagModel`."""
          return self.rag.generator

      @property
      def question_encoder(self):
          """The question encoder held by the underlying `RagModel`."""
          return self.rag.question_encoder

      @torch.no_grad()
      def generate(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.LongTensor] = None,
          context_input_ids: Optional[torch.LongTensor] = None,
          context_attention_mask: Optional[torch.LongTensor] = None,
          doc_scores: Optional[torch.FloatTensor] = None,
          do_deduplication: Optional[bool] = None,  # defaults to True
          num_return_sequences: Optional[int] = None,  # defaults to 1
          num_beams: Optional[int] = None,  # defaults to 1
          n_docs: Optional[int] = None,
          **model_kwargs,
      ) -> torch.LongTensor:
          """
          Implements RAG sequence "thorough" decoding. Read the [`~generation.GenerationMixin.generate`] documentation
          for more information on how to set other generate input parameters.

          For every example, candidates are generated from each of its `n_docs` documents, optionally deduplicated,
          then re-scored with this model's forward pass (using the same `n_docs`); the best
          `num_return_sequences` candidates are kept.

          Args:
              input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                  The sequence used as a prompt for the generation. If `input_ids` is not passed, then
                  `context_input_ids` has to be provided.
              attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
                  Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

                  - 1 for tokens that are **not masked**,
                  - 0 for tokens that are **masked**.

                  [What are attention masks?](../glossary#attention-mask)
              context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                  Input IDs post-processed from the retrieved documents and the question encoder input_ids by the
                  retriever.
              context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
                  Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
                  retriever.

                  If the model is not initialized with a `retriever` or `input_ids` is not given, `context_input_ids` and
                  `context_attention_mask` have to be provided to the forward pass. They are returned by
                  [`~RagRetriever.__call__`].
              doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
                  Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
                  `question_encoder_last_hidden_state`.

                  If the model is not initialized with a `retriever` or `input_ids` is not given, `doc_scores` has to be
                  provided to the forward pass. `doc_scores` are returned by [`~RagRetriever.__call__`].
              do_deduplication (`bool`, *optional*):
                  Whether or not to deduplicate the generations from different context documents for a given input. Has
                  to be set to `False` if used while training with distributed backend.
              num_return_sequences(`int`, *optional*, defaults to 1):
                  The number of independently computed returned sequences for each element in the batch. Note that this
                  is not the value we pass to the `generator`'s [`~generation.GenerationMixin.generate`] function,
                  where we set `num_return_sequences` to `num_beams`.
              num_beams (`int`, *optional*, defaults to 1):
                  Number of beams for beam search. 1 means no beam search.
              n_docs (`int`, *optional*, defaults to `config.n_docs`)
                  Number of documents to retrieve and/or number of documents for which to generate an answer.
              kwargs (`dict[str, Any]`, *optional*):
                  Additional kwargs will be passed to [`~generation.GenerationMixin.generate`].

          Return:
              `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
              sequences. The second dimension (sequence length) is either equal to `max_length` or shorter if all batches
              finished early due to the `eos_token_id`.
          """
          n_docs = n_docs if n_docs is not None else self.config.n_docs
          do_deduplication = do_deduplication if do_deduplication is not None else self.config.do_deduplication
          num_doc_return_sequences = (
              num_return_sequences if num_return_sequences is not None else self.config.num_return_sequences
          )
          num_beams = num_beams if num_beams is not None else self.config.num_beams

          assert input_ids is not None or context_input_ids is not None, (
              " At least one of input_ids or context_input_ids must be given"
          )

          if self.retriever is not None and context_input_ids is None:
              question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
              context_input_ids = self.retriever(
                  input_ids,
                  question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
                  prefix=self.generator.config.prefix,
                  n_docs=n_docs,
                  return_tensors="pt",
              )["context_input_ids"]

              # set to correct device
              context_input_ids = context_input_ids.to(input_ids)

          # Without a retriever, the generator inputs must have been supplied by the caller.
          assert context_input_ids is not None, (
              "Make sure that `context_input_ids` are passed, if no `retriever` is set. Alternatively, you can"
              " set a retriever using the `set_retriever(...)` function."
          )

          # Re-scoring from `input_ids` re-runs retrieval inside `self.rag`, which needs a retriever; otherwise
          # the caller-provided context tensors and `doc_scores` are used.
          rescore_with_retriever = input_ids is not None and self.retriever is not None

          hypos = []
          model_kwargs["num_beams"] = num_beams
          model_kwargs["num_return_sequences"] = num_beams
          model_kwargs["attention_mask"] = None

          batch_size = input_ids.shape[0] if input_ids is not None else context_input_ids.shape[0] // n_docs

          for index in range(batch_size):
              # first, generate beams from documents:
              generator_input_ids = context_input_ids[index * n_docs : (index + 1) * n_docs]  # (n_docs, max_len)

              output_sequences = self.generator.generate(
                  generator_input_ids,
                  **model_kwargs,
              )  # n_docs * n_beam, tgt_len
              if do_deduplication:
                  # keep one copy of each distinct candidate sequence
                  output_sequences = torch.stack(list({str(k.tolist()): k for k in output_sequences}.values()))

              # after deduplication, this number can be less than n_docs*n_beam
              num_candidates = output_sequences.shape[0]

              # then, run model forwards to get nll scores (with the same `n_docs` used for generation):
              if rescore_with_retriever:
                  new_input_ids = input_ids[index : index + 1].repeat(num_candidates, 1)
                  new_attention_mask = (
                      attention_mask[index : index + 1].repeat(num_candidates, 1)
                      if attention_mask is not None
                      else None
                  )
                  outputs = self(
                      new_input_ids,
                      attention_mask=new_attention_mask,
                      labels=output_sequences,
                      exclude_bos_score=True,
                      n_docs=n_docs,
                  )
              else:  # need context_input_ids/mask and doc_scores
                  assert context_attention_mask is not None, (
                      "Make sure that `context_attention_mask` are passed, if no `input_ids` is set. Alternatively, you"
                      " can set a retriever using the `set_retriever(...)` function."
                  )
                  assert doc_scores is not None, (
                      "Make sure that `doc_scores` are passed, if no `input_ids` is set. Alternatively, you can set a"
                      " retriever using the `set_retriever(...)` function."
                  )

                  # (num_candidates*n_docs, max_len): each candidate is paired with all n_docs documents
                  individual_input_ids = generator_input_ids.repeat(num_candidates, 1)

                  individual_attention_mask = context_attention_mask[index * n_docs : (index + 1) * n_docs]
                  individual_attention_mask = individual_attention_mask.repeat(num_candidates, 1)

                  individual_doc_scores = doc_scores[index : (index + 1), :]  # doc_scores.shape = [batch, n_docs]
                  individual_doc_scores = individual_doc_scores.repeat(num_candidates, 1)  # [num_candidates, n_docs]

                  outputs = self(
                      context_input_ids=individual_input_ids,
                      context_attention_mask=individual_attention_mask,
                      doc_scores=individual_doc_scores,
                      labels=output_sequences,
                      exclude_bos_score=True,
                      n_docs=n_docs,
                  )

              # lowest loss == highest marginal likelihood
              top_cand_inds = (-outputs["loss"]).topk(num_doc_return_sequences)[1]

              # add hypothesis
              hypos.append(output_sequences[top_cand_inds])

          return self._cat_and_pad(hypos, pad_token_id=self.config.generator.pad_token_id)

      def get_nll(
          self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, exclude_bos_score=False, n_docs=None
      ):
          """
          Compute the (optionally label-smoothed) RAG-sequence negative log-likelihood.

          Args:
              seq_logits (`torch.Tensor`): generator logits with `batch_size * n_docs` rows; viewed below as
                  `(batch_size, n_docs, tgt_len, vocab_size)`.
              doc_scores (`torch.Tensor`): document scores of shape `(batch_size, n_docs)`, normalized with a
                  log-softmax over the documents.
              target (`torch.LongTensor`): target token ids of shape `(batch_size, tgt_len)`. They are shifted left
                  by one position and right-padded with the generator's `pad_token_id`.
              reduce_loss (`bool`): if `True`, sum the per-example losses into a scalar.
              epsilon (`float`): label-smoothing factor.
              exclude_bos_score (`bool`): if `True` and every shifted target starts with the BOS token, the score
                  of that first position is ignored.
              n_docs (`int`, *optional*): documents per example; defaults to `config.n_docs`.

          Returns:
              `torch.Tensor`: loss of shape `(batch_size,)`, or a scalar when `reduce_loss=True`.
          """
          pad_token_id = self.config.generator.pad_token_id

          # shift tokens left
          target = torch.cat([target[:, 1:], target.new_full((target.shape[0], 1), pad_token_id)], 1)

          n_docs = n_docs if n_docs is not None else self.config.n_docs

          # bos_token_id is None for T5; compare with None explicitly so a BOS id of 0 is not skipped
          bos_token_id = (
              self.config.bos_token_id if self.config.bos_token_id is not None else self.config.generator.bos_token_id
          )
          use_bos = bos_token_id is not None and target[:, 0].eq(bos_token_id).all()

          def _mask_pads(ll, smooth_obj):
              # `target` is read at call time, i.e. after it has been expanded to the shape of `ll` below
              pad_mask = target.eq(pad_token_id)
              if pad_mask.any():
                  ll.masked_fill_(pad_mask, 0.0)
                  smooth_obj.masked_fill_(pad_mask, 0.0)
              return ll.squeeze(-1), smooth_obj.squeeze(-1)

          # seq_logits dim = (batch*n_docs, tgt_len , #vocabs)
          seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
              seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
          )  # batch_size x n_docs x tgt_len x #vocab_size
          doc_logprobs = nn.functional.log_softmax(doc_scores, dim=1).unsqueeze(-1).unsqueeze(-1)

          # RAG-sequence marginalization: the document log-prob is added once, at the second position
          first_token_scores = seq_logprobs[:, :, :1, :]
          second_token_scores = seq_logprobs[:, :, 1:2, :]
          remainder = seq_logprobs[:, :, 2:, :]
          rag_logprobs = torch.cat([first_token_scores, second_token_scores + doc_logprobs, remainder], dim=2)

          # calculate loss
          target = target.unsqueeze(1).unsqueeze(-1).repeat(1, n_docs, 1, 1)
          assert target.dim() == rag_logprobs.dim()

          ll = rag_logprobs.gather(dim=-1, index=target)
          smooth_obj = rag_logprobs.sum(dim=-1, keepdim=True)  # total sum of all (normalised) logits

          ll, smooth_obj = _mask_pads(ll, smooth_obj)

          # sum over tokens, exclude bos while scoring
          ll = ll[:, :, 1:].sum(2) if exclude_bos_score and use_bos else ll.sum(2)
          smooth_obj = smooth_obj.sum(2)
          ll = ll.logsumexp(1)  # logsumexp over docs
          smooth_obj = smooth_obj.logsumexp(1)

          nll_loss = -ll
          smooth_loss = -smooth_obj

          if reduce_loss:
              nll_loss = nll_loss.sum()
              smooth_loss = smooth_loss.sum()

          eps_i = epsilon / rag_logprobs.size(-1)
          loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
          return loss

      @staticmethod
      def _cat_and_pad(tensors, pad_token_id):
          """Concatenate 2-D tensors along dim 0, right-padding shorter rows with `pad_token_id`."""
          output = tensors[0].new_full(
              (sum(t.shape[0] for t in tensors), max(t.shape[1] for t in tensors)), pad_token_id
          )
          ind = 0
          for t in tensors:
              output[ind : ind + t.shape[0], : t.shape[1]] = t
              ind += t.shape[0]
          return output
  950. @auto_docstring(
  951. custom_intro="""
  952. A RAG-token model implementation. It performs RAG-token specific marginalization in the forward pass.
  953. """
  954. )
  955. class RagTokenForGeneration(RagPreTrainedModel, GenerationMixin):
  def __init__(
      self,
      config: Optional[PretrainedConfig] = None,
      question_encoder: Optional[PreTrainedModel] = None,
      generator: Optional[PreTrainedModel] = None,
      retriever: Optional[RagRetriever] = None,
      **kwargs,
  ):
      r"""
      question_encoder (`PreTrainedModel`, *optional*):
          The model responsible for encoding the question into hidden states for retrieval.
      generator (`PreTrainedModel`, *optional*):
          The model responsible for generating text based on retrieved documents.
      retriever (`RagRetriever`, *optional*):
          The component responsible for retrieving documents from a knowledge base given the encoded question.
      """
      both_submodels_given = question_encoder is not None and generator is not None
      assert config is not None or both_submodels_given, (
          "Either a configuration or an encoder and a generator has to be provided."
      )

      if config is None:
          # Derive the composite RAG configuration from the two sub-model configurations.
          encoder_config = question_encoder.config
          generator_config = generator.config
          config = RagConfig.from_question_encoder_generator_configs(encoder_config, generator_config, **kwargs)

      super().__init__(config)

      # The wrapped RagModel owns the question encoder, the generator and the (optional) retriever.
      self.rag = RagModel(
          config=config,
          question_encoder=question_encoder,
          generator=generator,
          retriever=retriever,
      )
  977. question_encoder.config, generator.config, **kwargs
  978. )
  979. super().__init__(config)
  980. # instantiate model
  981. self.rag = RagModel(config=config, question_encoder=question_encoder, generator=generator, retriever=retriever)
  982. def set_retriever(self, retriever: RagRetriever):
  983. self.rag.retriever = retriever
  984. def set_context_encoder_for_training(self, ctx_encoder: PreTrainedModel):
  985. self.rag.context_encoder_training = True
  986. self.rag.ctx_encoder = ctx_encoder
  987. def prepare_inputs_for_generation(
  988. self,
  989. decoder_input_ids,
  990. past_key_values=None,
  991. attention_mask=None,
  992. use_cache=None,
  993. encoder_outputs=None,
  994. doc_scores=None,
  995. n_docs=None,
  996. **kwargs,
  997. ):
  998. # Overwritten -- `do_marginalize` is explicitly set in the output
  999. if past_key_values is not None:
  1000. # if past is defined use only last decoder_input_ids
  1001. decoder_input_ids = decoder_input_ids[:, -1:]
  1002. return {
  1003. "input_ids": None,
  1004. "encoder_outputs": encoder_outputs,
  1005. "doc_scores": doc_scores,
  1006. "context_attention_mask": attention_mask,
  1007. "decoder_input_ids": decoder_input_ids,
  1008. "past_key_values": past_key_values,
  1009. "use_cache": use_cache,
  1010. "do_marginalize": True,
  1011. "n_docs": n_docs,
  1012. }
  1013. @property
  1014. def retriever(self):
  1015. return self.rag.retriever
  1016. @property
  1017. def generator(self):
  1018. return self.rag.generator
  1019. @property
  1020. def question_encoder(self):
  1021. return self.rag.question_encoder
  1022. @staticmethod
  1023. def _reorder_cache(past_key_values, beam_idx):
  1024. """Reorders cache for generation. BART-inspired but we need to take care of the extra dimension for docs"""
  1025. def _reorder_stacked(hidden_states, new_order):
  1026. n_docs = hidden_states.shape[0] // new_order.shape[0]
  1027. hidden_states = hidden_states.view(-1, n_docs, *hidden_states.shape[1:])
  1028. hidden_states = hidden_states.index_select(0, new_order)
  1029. result = hidden_states.view(-1, *hidden_states.shape[2:])
  1030. return result
  1031. reordered_past = ()
  1032. for layer_past in past_key_values:
  1033. # get the correct batch idx from decoder layer's batch dim for cross and self-attn
  1034. reordered_past += (
  1035. tuple(_reorder_stacked(past_state, beam_idx.to(past_state.device)) for past_state in layer_past),
  1036. )
  1037. if isinstance(past_key_values, EncoderDecoderCache):
  1038. reordered_past = EncoderDecoderCache.from_legacy_cache(reordered_past)
  1039. return reordered_past
  1040. def marginalize(self, seq_logits, doc_scores, n_docs=None):
  1041. n_docs = n_docs if n_docs is not None else self.config.n_docs
  1042. # RAG-token marginalization
  1043. seq_logprobs = nn.functional.log_softmax(seq_logits, dim=-1).view(
  1044. seq_logits.shape[0] // n_docs, n_docs, -1, seq_logits.size(-1)
  1045. )
  1046. doc_logprobs = torch.log_softmax(doc_scores, dim=1)
  1047. log_prob_sum = seq_logprobs + doc_logprobs.unsqueeze(-1).unsqueeze(-1)
  1048. return torch.logsumexp(log_prob_sum, dim=1)
  1049. @auto_docstring
  1050. def forward(
  1051. self,
  1052. input_ids: Optional[torch.LongTensor] = None,
  1053. attention_mask: Optional[torch.FloatTensor] = None,
  1054. encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
  1055. decoder_input_ids: Optional[torch.LongTensor] = None,
  1056. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1057. past_key_values: Optional[Cache] = None,
  1058. context_input_ids: Optional[torch.LongTensor] = None,
  1059. context_attention_mask: Optional[torch.LongTensor] = None,
  1060. doc_scores: Optional[torch.FloatTensor] = None,
  1061. use_cache: Optional[bool] = None,
  1062. output_attentions: Optional[bool] = None,
  1063. output_hidden_states: Optional[bool] = None,
  1064. output_retrieved: Optional[bool] = None,
  1065. do_marginalize: Optional[bool] = None,
  1066. reduce_loss: Optional[bool] = None,
  1067. labels: Optional[torch.LongTensor] = None,
  1068. n_docs: Optional[int] = None,
  1069. **kwargs, # needs kwargs for generation
  1070. ) -> RetrievAugLMMarginOutput:
  1071. r"""
  1072. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1073. Indices of input sequence tokens in the vocabulary. [`RagConfig`], used to initialize the model, specifies
  1074. which generator to use, it also specifies a compatible generator tokenizer. Use that tokenizer class to
  1075. obtain the indices.
  1076. [What are input IDs?](../glossary#input-ids)
  1077. encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*)
  1078. Tuple consists of (`generator_enc_last_hidden_state`, *optional*: `generator_enc_hidden_states`,
  1079. *optional*: `generator_enc_attentions`). `generator_enc_last_hidden_state` of shape `(batch_size, n_docs *
  1080. sequence_length, hidden_size)` is a sequence of hidden-states at the output of the last layer of the
  1081. generator's encoder.
  1082. Used by the ([`RagModel`]) model during decoding.
  1083. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1084. Provide for generation tasks. `None` by default, construct as per instructions for the generator model
  1085. you're using with your RAG instance.
  1086. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1087. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1088. be used by default.
  1089. context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  1090. Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
  1091. retriever. If the model was not initialized with a `retriever` ``context_input_ids` has to be provided to
  1092. the forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
  1093. context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`,*optional*, returned when *output_retrieved=True*):
  1094. Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
  1095. retriever. If the model has is not initialized with a `retriever` `context_attention_mask` has to be
  1096. provided to the forward pass. `context_attention_mask` are returned by [`~RagRetriever.__call__`].
  1097. doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
  1098. Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
  1099. `question_encoder_last_hidden_state`. If the model has is not initialized with a `retriever` `doc_scores`
  1100. has to be provided to the forward pass. `doc_scores` can be computed via
  1101. `question_encoder_last_hidden_state` and `retrieved_doc_embeds`, see examples for more information.
  1102. output_retrieved (`bool`, *optional*):
  1103. Whether or not to return the `retrieved_doc_embeds`, `retrieved_doc_ids`, `context_input_ids` and
  1104. `context_attention_mask`. See returned tensors for more detail.
  1105. do_marginalize (`bool`, *optional*):
  1106. If `True`, the logits are marginalized over all documents by making use of
  1107. `torch.nn.functional.log_softmax`.
  1108. reduce_loss (`bool`, *optional*):
  1109. Only relevant if `labels` is passed. If `True`, the NLL loss is reduced using the `torch.Tensor.sum`
  1110. operation.
  1111. n_docs (`int`, *optional*):
  1112. The number of documents to retrieve.
  1113. Example:
  1114. ```python
  1115. >>> from transformers import AutoTokenizer, RagRetriever, RagTokenForGeneration
  1116. >>> import torch
  1117. >>> tokenizer = AutoTokenizer.from_pretrained("facebook/rag-token-nq")
  1118. >>> retriever = RagRetriever.from_pretrained(
  1119. ... "facebook/rag-token-nq", index_name="exact", use_dummy_dataset=True
  1120. ... )
  1121. >>> # initialize with RagRetriever to do everything in one forward call
  1122. >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", retriever=retriever)
  1123. >>> inputs = tokenizer("How many people live in Paris?", return_tensors="pt")
  1124. >>> targets = tokenizer(text_target="In Paris, there are 10 million people.", return_tensors="pt")
  1125. >>> input_ids = inputs["input_ids"]
  1126. >>> labels = targets["input_ids"]
  1127. >>> outputs = model(input_ids=input_ids, labels=labels)
  1128. >>> # or use retriever separately
  1129. >>> model = RagTokenForGeneration.from_pretrained("facebook/rag-token-nq", use_dummy_dataset=True)
  1130. >>> # 1. Encode
  1131. >>> question_hidden_states = model.question_encoder(input_ids)[0]
  1132. >>> # 2. Retrieve
  1133. >>> docs_dict = retriever(input_ids.numpy(), question_hidden_states.detach().numpy(), return_tensors="pt")
  1134. >>> doc_scores = torch.bmm(
  1135. ... question_hidden_states.unsqueeze(1), docs_dict["retrieved_doc_embeds"].float().transpose(1, 2)
  1136. ... ).squeeze(1)
  1137. >>> # 3. Forward to generator
  1138. >>> outputs = model(
  1139. ... context_input_ids=docs_dict["context_input_ids"],
  1140. ... context_attention_mask=docs_dict["context_attention_mask"],
  1141. ... doc_scores=doc_scores,
  1142. ... decoder_input_ids=labels,
  1143. ... )
  1144. >>> # or directly generate
  1145. >>> generated = model.generate(
  1146. ... context_input_ids=docs_dict["context_input_ids"],
  1147. ... context_attention_mask=docs_dict["context_attention_mask"],
  1148. ... doc_scores=doc_scores,
  1149. ... )
  1150. >>> generated_string = tokenizer.batch_decode(generated, skip_special_tokens=True)
  1151. ```"""
  1152. n_docs = n_docs if n_docs is not None else self.config.n_docs
  1153. do_marginalize = do_marginalize if do_marginalize is not None else self.config.do_marginalize
  1154. reduce_loss = reduce_loss if reduce_loss is not None else self.config.reduce_loss
  1155. if labels is not None:
  1156. if decoder_input_ids is None:
  1157. decoder_input_ids = labels
  1158. use_cache = False
  1159. outputs = self.rag(
  1160. input_ids=input_ids,
  1161. attention_mask=attention_mask,
  1162. encoder_outputs=encoder_outputs,
  1163. decoder_input_ids=decoder_input_ids,
  1164. decoder_attention_mask=decoder_attention_mask,
  1165. context_input_ids=context_input_ids,
  1166. context_attention_mask=context_attention_mask,
  1167. doc_scores=doc_scores,
  1168. past_key_values=past_key_values,
  1169. use_cache=use_cache,
  1170. output_attentions=output_attentions,
  1171. output_hidden_states=output_hidden_states,
  1172. output_retrieved=output_retrieved,
  1173. n_docs=n_docs,
  1174. )
  1175. loss = None
  1176. logits = outputs.logits
  1177. if labels is not None:
  1178. assert decoder_input_ids is not None
  1179. loss = self.get_nll(
  1180. outputs.logits,
  1181. outputs.doc_scores,
  1182. labels,
  1183. reduce_loss=reduce_loss,
  1184. epsilon=self.config.label_smoothing,
  1185. n_docs=n_docs,
  1186. )
  1187. if do_marginalize:
  1188. logits = self.marginalize(logits, outputs.doc_scores, n_docs)
  1189. return RetrievAugLMMarginOutput(
  1190. loss=loss,
  1191. logits=logits,
  1192. doc_scores=outputs.doc_scores,
  1193. past_key_values=outputs.past_key_values,
  1194. context_input_ids=outputs.context_input_ids,
  1195. context_attention_mask=outputs.context_attention_mask,
  1196. retrieved_doc_embeds=outputs.retrieved_doc_embeds,
  1197. retrieved_doc_ids=outputs.retrieved_doc_ids,
  1198. question_encoder_last_hidden_state=outputs.question_encoder_last_hidden_state,
  1199. question_enc_hidden_states=outputs.question_enc_hidden_states,
  1200. question_enc_attentions=outputs.question_enc_attentions,
  1201. generator_enc_last_hidden_state=outputs.generator_enc_last_hidden_state,
  1202. generator_enc_hidden_states=outputs.generator_enc_hidden_states,
  1203. generator_enc_attentions=outputs.generator_enc_attentions,
  1204. generator_dec_hidden_states=outputs.generator_dec_hidden_states,
  1205. generator_dec_attentions=outputs.generator_dec_attentions,
  1206. generator_cross_attentions=outputs.generator_cross_attentions,
  1207. )
  1208. @torch.no_grad()
  1209. def generate(
  1210. self,
  1211. input_ids: Optional[torch.LongTensor] = None,
  1212. attention_mask: Optional[torch.LongTensor] = None,
  1213. context_input_ids: Optional[torch.LongTensor] = None,
  1214. context_attention_mask: Optional[torch.LongTensor] = None,
  1215. doc_scores: Optional[torch.FloatTensor] = None,
  1216. n_docs: Optional[int] = None,
  1217. generation_config: Optional[GenerationConfig] = None,
  1218. prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor], list[int]]] = None,
  1219. logits_processor: Optional[LogitsProcessorList] = LogitsProcessorList(),
  1220. stopping_criteria: Optional[StoppingCriteriaList] = StoppingCriteriaList(),
  1221. **kwargs,
  1222. ) -> torch.LongTensor:
  1223. """
  1224. Implements RAG token decoding.
  1225. Args:
  1226. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1227. The sequence used as a prompt for the generation. If `input_ids` is not passed, then
  1228. `context_input_ids` has to be provided.
  1229. attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
  1230. Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
  1231. - 1 for tokens that are **not masked**,
  1232. - 0 for tokens that are **masked**.
  1233. [What are attention masks?](../glossary#attention-mask)
  1234. context_input_ids (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  1235. Input IDs post-processed from the retrieved documents and the question encoder `input_ids` by the
  1236. retriever.
  1237. If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
  1238. forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
  1239. context_attention_mask (`torch.LongTensor` of shape `(batch_size * config.n_docs, config.max_combined_length)`, *optional*, returned when *output_retrieved=True*):
  1240. Attention mask post-processed from the retrieved documents and the question encoder `input_ids` by the
  1241. retriever.
  1242. If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
  1243. forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
  1244. doc_scores (`torch.FloatTensor` of shape `(batch_size, config.n_docs)`):
  1245. Score between each retrieved document embeddings (see `retrieved_doc_embeds`) and
  1246. `question_encoder_last_hidden_state`.
  1247. If the model has is not initialized with a `retriever`, `context_input_ids` has to be provided to the
  1248. forward pass. `context_input_ids` are returned by [`~RagRetriever.__call__`].
  1249. n_docs (`int`, *optional*, defaults to `config.n_docs`)
  1250. Number of documents to retrieve and/or number of documents for which to generate an answer.
  1251. generation_config (`~generation.GenerationConfig`, *optional*):
  1252. The generation configuration to be used as base parametrization for the generation call. `**kwargs`
  1253. passed to generate matching the attributes of `generation_config` will override them. If
  1254. `generation_config` is not provided, the default will be used, which has the following loading
  1255. priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
  1256. configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
  1257. default values, whose documentation should be checked to parameterize generation.
  1258. prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], list[int]]`, *optional*):
  1259. If provided, this function constraints the beam search to allowed tokens only at each step. If not
  1260. provided no constraint is applied. This function takes 2 arguments `inputs_ids` and the batch ID
  1261. `batch_id`. It has to return a list with the allowed tokens for the next generation step conditioned on
  1262. the previously generated tokens `inputs_ids` and the batch ID `batch_id`. This argument is useful for
  1263. constrained generation conditioned on the prefix, as described in [Autoregressive Entity
  1264. Retrieval](https://huggingface.co/papers/2010.00904).
  1265. logits_processor (`LogitsProcessorList`, *optional*):
  1266. Custom logits processors that complement the default logits processors built from arguments and a
  1267. model's config. If a logit processor is passed that is already created with the arguments or a model's
  1268. config an error is thrown.
  1269. stopping_criteria (`StoppingCriteriaList`, *optional*):
  1270. Custom stopping criteria that complement the default stopping criteria built from arguments and a
  1271. model's config. If a stopping criteria is passed that is already created with the arguments or a
  1272. model's config an error is thrown.
  1273. kwargs (`dict[str, Any]`, *optional*):
  1274. Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
  1275. forwarded to the `forward` function of the model.
  1276. Return:
  1277. `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated
  1278. sequences. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches
  1279. finished early due to the `eos_token_id`.
  1280. """
  1281. # Handle `generation_config` and kwargs that might update it
  1282. if generation_config is None:
  1283. generation_config = self.generation_config
  1284. generation_config = copy.deepcopy(generation_config)
  1285. model_kwargs = generation_config.update(**kwargs) # All unused kwargs must be model kwargs
  1286. kwargs_has_attention_mask = model_kwargs.get("attention_mask", None) is not None
  1287. self._prepare_special_tokens(generation_config, kwargs_has_attention_mask)
  1288. # set default parameters
  1289. n_docs = n_docs if n_docs is not None else self.config.n_docs
  1290. # retrieve docs
  1291. if self.retriever is not None and context_input_ids is None:
  1292. question_hidden_states = self.question_encoder(input_ids, attention_mask=attention_mask)[0]
  1293. out = self.retriever(
  1294. input_ids,
  1295. question_hidden_states.detach().to(device="cpu", dtype=torch.float32).numpy(),
  1296. prefix=self.generator.config.prefix,
  1297. n_docs=n_docs,
  1298. return_tensors="pt",
  1299. )
  1300. context_input_ids, context_attention_mask, retrieved_doc_embeds = (
  1301. out["context_input_ids"],
  1302. out["context_attention_mask"],
  1303. out["retrieved_doc_embeds"],
  1304. )
  1305. # set to correct device
  1306. retrieved_doc_embeds = retrieved_doc_embeds.to(question_hidden_states)
  1307. context_input_ids = context_input_ids.to(input_ids)
  1308. context_attention_mask = context_attention_mask.to(input_ids)
  1309. # compute doc_scores
  1310. doc_scores = torch.bmm(question_hidden_states.unsqueeze(1), retrieved_doc_embeds.transpose(1, 2)).squeeze(
  1311. 1
  1312. )
  1313. assert (context_input_ids.shape[0] % n_docs) == 0, (
  1314. f" The first dimension of `context_input_ids` should be a multiple of `n_docs`={n_docs}, but is"
  1315. f" {context_input_ids.shape[0]}."
  1316. )
  1317. # batch_size
  1318. batch_size = context_input_ids.shape[0] // n_docs
  1319. encoder = self.rag.generator.get_encoder()
  1320. encoder_outputs = encoder(input_ids=context_input_ids, attention_mask=context_attention_mask, return_dict=True)
  1321. input_ids = torch.full(
  1322. (batch_size * generation_config.num_beams, 1),
  1323. generation_config.decoder_start_token_id,
  1324. dtype=torch.long,
  1325. device=next(self.parameters()).device,
  1326. )
  1327. input_ids_seq_length = input_ids.shape[-1]
  1328. last_hidden_state = encoder_outputs["last_hidden_state"]
  1329. def extend_enc_output(tensor, num_beams=None):
  1330. # split into `batch_size`, `num_beams`, `num_docs`
  1331. tensor = tensor[None, None, :].reshape((batch_size, 1, n_docs) + tensor.shape[1:])
  1332. # repeat same last hidden states over `num_beams` dimension
  1333. tensor = tensor.expand((batch_size, num_beams, n_docs) + tensor.shape[3:])
  1334. # merge `batch_size`, `num_beams`, `num_docs` dims again
  1335. return tensor.reshape((batch_size * num_beams * n_docs,) + tensor.shape[3:])
  1336. # correctly extend last_hidden_state and attention mask
  1337. context_attention_mask = extend_enc_output(context_attention_mask, num_beams=generation_config.num_beams)
  1338. encoder_outputs["last_hidden_state"] = extend_enc_output(
  1339. last_hidden_state, num_beams=generation_config.num_beams
  1340. )
  1341. doc_scores = doc_scores.repeat_interleave(generation_config.num_beams, dim=0)
  1342. # define start_len & additional parameters
  1343. model_kwargs["doc_scores"] = doc_scores
  1344. model_kwargs["encoder_outputs"] = encoder_outputs
  1345. model_kwargs["attention_mask"] = context_attention_mask
  1346. model_kwargs["n_docs"] = n_docs
  1347. pre_processor = self._get_logits_processor(
  1348. generation_config=generation_config,
  1349. input_ids_seq_length=input_ids_seq_length,
  1350. encoder_input_ids=context_input_ids,
  1351. prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
  1352. logits_processor=logits_processor,
  1353. device=input_ids.device,
  1354. )
  1355. prepared_stopping_criteria = self._get_stopping_criteria(
  1356. generation_config=generation_config, stopping_criteria=stopping_criteria
  1357. )
  1358. self._prepare_cache_for_generation(
  1359. generation_config,
  1360. model_kwargs,
  1361. generation_mode=None,
  1362. batch_size=input_ids.shape[0],
  1363. max_cache_length=generation_config.max_length - 1,
  1364. )
  1365. if generation_config.num_beams == 1:
  1366. if generation_config.num_return_sequences > 1:
  1367. raise ValueError(
  1368. f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
  1369. " greedy search."
  1370. )
  1371. return self._sample(
  1372. input_ids,
  1373. logits_processor=pre_processor,
  1374. stopping_criteria=prepared_stopping_criteria,
  1375. generation_config=generation_config,
  1376. synced_gpus=False,
  1377. streamer=None,
  1378. **model_kwargs,
  1379. )
  1380. elif generation_config.num_beams > 1:
  1381. if generation_config.num_return_sequences > generation_config.num_beams:
  1382. raise ValueError("`num_return_sequences` has to be smaller or equal to `num_beams`.")
  1383. return self._beam_search(
  1384. input_ids,
  1385. logits_processor=pre_processor,
  1386. stopping_criteria=prepared_stopping_criteria,
  1387. generation_config=generation_config,
  1388. synced_gpus=False,
  1389. **model_kwargs,
  1390. )
  1391. else:
  1392. raise ValueError(
  1393. f"`num_beams` has to be an integer strictly superior to 0 (≥ 1), but is {generation_config.num_beams}"
  1394. )
  1395. # Auxiliary functions for beam search
  1396. def _temporary_reorder_cache(self, past_key_values, beam_idx):
  1397. # RAG should always use the legacy path even though the LM backbone (T5) uses new cache format
  1398. # because RAG expands input for doc-size internally. TODO: raushan, remove me when all models support
  1399. # new cache format
  1400. past_key_values = self._reorder_cache(past_key_values, beam_idx)
  1401. return past_key_values
  1402. def get_input_embeddings(self):
  1403. return self.rag.generator.get_input_embeddings()
  1404. def get_output_embeddings(self):
  1405. return self.rag.generator.get_output_embeddings()
  1406. def set_output_embeddings(self, new_embeddings):
  1407. return self.rag.generator.set_output_embeddings(new_embeddings)
  1408. def shift_tokens_right(self, input_ids, start_token_id=None):
  1409. """Shift input ids one token to the right, and pad with start_token_id"""
  1410. if start_token_id is None:
  1411. start_token_id = self.config.decoder_start_token_id
  1412. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  1413. shifted_input_ids[:, 1:] = input_ids[:, :-1].clone()
  1414. shifted_input_ids[:, 0] = start_token_id
  1415. return shifted_input_ids
  1416. def get_nll(self, seq_logits, doc_scores, target, reduce_loss=False, epsilon=0.0, n_docs=None):
  1417. n_docs = n_docs if n_docs is not None else self.config.n_docs
  1418. # shift tokens left
  1419. target = torch.cat(
  1420. [target[:, 1:], target.new(target.shape[0], 1).fill_(self.config.generator.pad_token_id)], 1
  1421. )
  1422. def _mask_pads(ll, smooth_obj):
  1423. pad_mask = target.eq(self.config.generator.pad_token_id)
  1424. if pad_mask.any():
  1425. ll.masked_fill_(pad_mask, 0.0)
  1426. smooth_obj.masked_fill_(pad_mask, 0.0)
  1427. return ll.squeeze(-1), smooth_obj.squeeze(-1)
  1428. rag_logprobs = self.marginalize(seq_logits, doc_scores, n_docs)
  1429. target = target.unsqueeze(-1)
  1430. assert target.dim() == rag_logprobs.dim()
  1431. ll = rag_logprobs.gather(dim=-1, index=target)
  1432. smooth_obj = rag_logprobs.sum(dim=-1, keepdim=True) # total sum of all (normalised) logits
  1433. ll, smooth_obj = _mask_pads(ll, smooth_obj)
  1434. ll = ll.sum(1) # sum over tokens
  1435. smooth_obj = smooth_obj.sum(1)
  1436. nll_loss = -ll
  1437. smooth_loss = -smooth_obj
  1438. if reduce_loss:
  1439. nll_loss = nll_loss.sum()
  1440. smooth_loss = smooth_loss.sum()
  1441. eps_i = epsilon / rag_logprobs.size(-1)
  1442. loss = (1.0 - epsilon) * nll_loss + eps_i * smooth_loss
  1443. return loss
# Public names exported by `from <this module> import *`.
__all__ = ["RagModel", "RagPreTrainedModel", "RagSequenceForGeneration", "RagTokenForGeneration"]