modeling_dpr.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592
  1. # coding=utf-8
  2. # Copyright 2018 DPR Authors, The Hugging Face Team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch DPR model for Open Domain Question Answering."""
  16. from dataclasses import dataclass
  17. from typing import Optional, Union
  18. import torch
  19. from torch import Tensor, nn
  20. from ...modeling_outputs import BaseModelOutputWithPooling
  21. from ...modeling_utils import PreTrainedModel
  22. from ...utils import (
  23. ModelOutput,
  24. auto_docstring,
  25. logging,
  26. )
  27. from ..bert.modeling_bert import BertModel
  28. from .configuration_dpr import DPRConfig
# Module-level logger, obtained from the package's `logging.get_logger` helper and keyed by this module's name.
logger = logging.get_logger(__name__)
  30. ##########
  31. # Outputs
  32. ##########
  33. @dataclass
  34. @auto_docstring(
  35. custom_intro="""
  36. Class for outputs of [`DPRQuestionEncoder`].
  37. """
  38. )
  39. class DPRContextEncoderOutput(ModelOutput):
  40. r"""
  41. pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
  42. The DPR encoder outputs the *pooler_output* that corresponds to the context representation. Last layer
  43. hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
  44. This output is to be used to embed contexts for nearest neighbors queries with questions embeddings.
  45. """
  46. pooler_output: torch.FloatTensor
  47. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  48. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  49. @dataclass
  50. @auto_docstring(
  51. custom_intro="""
  52. Class for outputs of [`DPRQuestionEncoder`].
  53. """
  54. )
  55. class DPRQuestionEncoderOutput(ModelOutput):
  56. r"""
  57. pooler_output (`torch.FloatTensor` of shape `(batch_size, embeddings_size)`):
  58. The DPR encoder outputs the *pooler_output* that corresponds to the question representation. Last layer
  59. hidden-state of the first token of the sequence (classification token) further processed by a Linear layer.
  60. This output is to be used to embed questions for nearest neighbors queries with context embeddings.
  61. """
  62. pooler_output: torch.FloatTensor
  63. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  64. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  65. @dataclass
  66. @auto_docstring(
  67. custom_intro="""
  68. Class for outputs of [`DPRQuestionEncoder`].
  69. """
  70. )
  71. class DPRReaderOutput(ModelOutput):
  72. r"""
  73. start_logits (`torch.FloatTensor` of shape `(n_passages, sequence_length)`):
  74. Logits of the start index of the span for each passage.
  75. end_logits (`torch.FloatTensor` of shape `(n_passages, sequence_length)`):
  76. Logits of the end index of the span for each passage.
  77. relevance_logits (`torch.FloatTensor` of shape `(n_passages, )`):
  78. Outputs of the QA classifier of the DPRReader that corresponds to the scores of each passage to answer the
  79. question, compared to all the other passages.
  80. """
  81. start_logits: torch.FloatTensor
  82. end_logits: Optional[torch.FloatTensor] = None
  83. relevance_logits: Optional[torch.FloatTensor] = None
  84. hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
  85. attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  86. @auto_docstring
  87. class DPRPreTrainedModel(PreTrainedModel):
  88. _supports_sdpa = True
  89. def _init_weights(self, module):
  90. """Initialize the weights"""
  91. if isinstance(module, nn.Linear):
  92. # Slightly different from the TF version which uses truncated_normal for initialization
  93. # cf https://github.com/pytorch/pytorch/pull/5617
  94. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  95. if module.bias is not None:
  96. module.bias.data.zero_()
  97. elif isinstance(module, nn.Embedding):
  98. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  99. if module.padding_idx is not None:
  100. module.weight.data[module.padding_idx].zero_()
  101. elif isinstance(module, nn.LayerNorm):
  102. module.bias.data.zero_()
  103. module.weight.data.fill_(1.0)
  104. class DPREncoder(DPRPreTrainedModel):
  105. base_model_prefix = "bert_model"
  106. def __init__(self, config: DPRConfig):
  107. super().__init__(config)
  108. self.bert_model = BertModel(config, add_pooling_layer=False)
  109. if self.bert_model.config.hidden_size <= 0:
  110. raise ValueError("Encoder hidden_size can't be zero")
  111. self.projection_dim = config.projection_dim
  112. if self.projection_dim > 0:
  113. self.encode_proj = nn.Linear(self.bert_model.config.hidden_size, config.projection_dim)
  114. # Initialize weights and apply final processing
  115. self.post_init()
  116. def forward(
  117. self,
  118. input_ids: Tensor,
  119. attention_mask: Optional[Tensor] = None,
  120. token_type_ids: Optional[Tensor] = None,
  121. inputs_embeds: Optional[Tensor] = None,
  122. output_attentions: bool = False,
  123. output_hidden_states: bool = False,
  124. return_dict: bool = False,
  125. ) -> Union[BaseModelOutputWithPooling, tuple[Tensor, ...]]:
  126. outputs = self.bert_model(
  127. input_ids=input_ids,
  128. attention_mask=attention_mask,
  129. token_type_ids=token_type_ids,
  130. inputs_embeds=inputs_embeds,
  131. output_attentions=output_attentions,
  132. output_hidden_states=output_hidden_states,
  133. return_dict=return_dict,
  134. )
  135. sequence_output = outputs[0]
  136. pooled_output = sequence_output[:, 0, :]
  137. if self.projection_dim > 0:
  138. pooled_output = self.encode_proj(pooled_output)
  139. if not return_dict:
  140. return (sequence_output, pooled_output) + outputs[2:]
  141. return BaseModelOutputWithPooling(
  142. last_hidden_state=sequence_output,
  143. pooler_output=pooled_output,
  144. hidden_states=outputs.hidden_states,
  145. attentions=outputs.attentions,
  146. )
  147. @property
  148. def embeddings_size(self) -> int:
  149. if self.projection_dim > 0:
  150. return self.encode_proj.out_features
  151. return self.bert_model.config.hidden_size
  152. class DPRSpanPredictor(DPRPreTrainedModel):
  153. base_model_prefix = "encoder"
  154. def __init__(self, config: DPRConfig):
  155. super().__init__(config)
  156. self.encoder = DPREncoder(config)
  157. self.qa_outputs = nn.Linear(self.encoder.embeddings_size, 2)
  158. self.qa_classifier = nn.Linear(self.encoder.embeddings_size, 1)
  159. # Initialize weights and apply final processing
  160. self.post_init()
  161. def forward(
  162. self,
  163. input_ids: Tensor,
  164. attention_mask: Tensor,
  165. inputs_embeds: Optional[Tensor] = None,
  166. output_attentions: bool = False,
  167. output_hidden_states: bool = False,
  168. return_dict: bool = False,
  169. ) -> Union[DPRReaderOutput, tuple[Tensor, ...]]:
  170. # notations: N - number of questions in a batch, M - number of passages per questions, L - sequence length
  171. n_passages, sequence_length = input_ids.size() if input_ids is not None else inputs_embeds.size()[:2]
  172. # feed encoder
  173. outputs = self.encoder(
  174. input_ids,
  175. attention_mask=attention_mask,
  176. inputs_embeds=inputs_embeds,
  177. output_attentions=output_attentions,
  178. output_hidden_states=output_hidden_states,
  179. return_dict=return_dict,
  180. )
  181. sequence_output = outputs[0]
  182. # compute logits
  183. logits = self.qa_outputs(sequence_output)
  184. start_logits, end_logits = logits.split(1, dim=-1)
  185. start_logits = start_logits.squeeze(-1).contiguous()
  186. end_logits = end_logits.squeeze(-1).contiguous()
  187. relevance_logits = self.qa_classifier(sequence_output[:, 0, :])
  188. # resize
  189. start_logits = start_logits.view(n_passages, sequence_length)
  190. end_logits = end_logits.view(n_passages, sequence_length)
  191. relevance_logits = relevance_logits.view(n_passages)
  192. if not return_dict:
  193. return (start_logits, end_logits, relevance_logits) + outputs[2:]
  194. return DPRReaderOutput(
  195. start_logits=start_logits,
  196. end_logits=end_logits,
  197. relevance_logits=relevance_logits,
  198. hidden_states=outputs.hidden_states,
  199. attentions=outputs.attentions,
  200. )
  201. ##################
  202. # PreTrainedModel
  203. ##################
  204. class DPRPretrainedContextEncoder(DPRPreTrainedModel):
  205. """
  206. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  207. models.
  208. """
  209. config: DPRConfig
  210. load_tf_weights = None
  211. base_model_prefix = "ctx_encoder"
  212. class DPRPretrainedQuestionEncoder(DPRPreTrainedModel):
  213. """
  214. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  215. models.
  216. """
  217. config: DPRConfig
  218. load_tf_weights = None
  219. base_model_prefix = "question_encoder"
  220. class DPRPretrainedReader(DPRPreTrainedModel):
  221. """
  222. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  223. models.
  224. """
  225. config: DPRConfig
  226. load_tf_weights = None
  227. base_model_prefix = "span_predictor"
  228. ###############
  229. # Actual Models
  230. ###############
  231. @auto_docstring(
  232. custom_intro="""
  233. The bare DPRContextEncoder transformer outputting pooler outputs as context representations.
  234. """
  235. )
  236. class DPRContextEncoder(DPRPretrainedContextEncoder):
  237. def __init__(self, config: DPRConfig):
  238. super().__init__(config)
  239. self.config = config
  240. self.ctx_encoder = DPREncoder(config)
  241. # Initialize weights and apply final processing
  242. self.post_init()
  243. @auto_docstring
  244. def forward(
  245. self,
  246. input_ids: Optional[Tensor] = None,
  247. attention_mask: Optional[Tensor] = None,
  248. token_type_ids: Optional[Tensor] = None,
  249. inputs_embeds: Optional[Tensor] = None,
  250. output_attentions: Optional[bool] = None,
  251. output_hidden_states: Optional[bool] = None,
  252. return_dict: Optional[bool] = None,
  253. ) -> Union[DPRContextEncoderOutput, tuple[Tensor, ...]]:
  254. r"""
  255. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  256. Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be
  257. formatted with [CLS] and [SEP] tokens as follows:
  258. (a) For sequence pairs (for a pair title+text for example):
  259. ```
  260. tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
  261. token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
  262. ```
  263. (b) For single sequences (for a question for example):
  264. ```
  265. tokens: [CLS] the dog is hairy . [SEP]
  266. token_type_ids: 0 0 0 0 0 0 0
  267. ```
  268. DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
  269. rather than the left.
  270. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  271. [`PreTrainedTokenizer.__call__`] for details.
  272. [What are input IDs?](../glossary#input-ids)
  273. Examples:
  274. ```python
  275. >>> from transformers import DPRContextEncoder, DPRContextEncoderTokenizer
  276. >>> tokenizer = DPRContextEncoderTokenizer.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
  277. >>> model = DPRContextEncoder.from_pretrained("facebook/dpr-ctx_encoder-single-nq-base")
  278. >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"]
  279. >>> embeddings = model(input_ids).pooler_output
  280. ```"""
  281. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  282. output_hidden_states = (
  283. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  284. )
  285. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  286. if input_ids is not None and inputs_embeds is not None:
  287. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  288. elif input_ids is not None:
  289. input_shape = input_ids.size()
  290. elif inputs_embeds is not None:
  291. input_shape = inputs_embeds.size()[:-1]
  292. else:
  293. raise ValueError("You have to specify either input_ids or inputs_embeds")
  294. device = input_ids.device if input_ids is not None else inputs_embeds.device
  295. if attention_mask is None:
  296. attention_mask = (
  297. torch.ones(input_shape, device=device)
  298. if input_ids is None
  299. else (input_ids != self.config.pad_token_id)
  300. )
  301. if token_type_ids is None:
  302. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  303. outputs = self.ctx_encoder(
  304. input_ids=input_ids,
  305. attention_mask=attention_mask,
  306. token_type_ids=token_type_ids,
  307. inputs_embeds=inputs_embeds,
  308. output_attentions=output_attentions,
  309. output_hidden_states=output_hidden_states,
  310. return_dict=return_dict,
  311. )
  312. if not return_dict:
  313. return outputs[1:]
  314. return DPRContextEncoderOutput(
  315. pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  316. )
  317. @auto_docstring(
  318. custom_intro="""
  319. The bare DPRQuestionEncoder transformer outputting pooler outputs as question representations.
  320. """
  321. )
  322. class DPRQuestionEncoder(DPRPretrainedQuestionEncoder):
  323. def __init__(self, config: DPRConfig):
  324. super().__init__(config)
  325. self.config = config
  326. self.question_encoder = DPREncoder(config)
  327. # Initialize weights and apply final processing
  328. self.post_init()
  329. @auto_docstring
  330. def forward(
  331. self,
  332. input_ids: Optional[Tensor] = None,
  333. attention_mask: Optional[Tensor] = None,
  334. token_type_ids: Optional[Tensor] = None,
  335. inputs_embeds: Optional[Tensor] = None,
  336. output_attentions: Optional[bool] = None,
  337. output_hidden_states: Optional[bool] = None,
  338. return_dict: Optional[bool] = None,
  339. ) -> Union[DPRQuestionEncoderOutput, tuple[Tensor, ...]]:
  340. r"""
  341. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  342. Indices of input sequence tokens in the vocabulary. To match pretraining, DPR input sequence should be
  343. formatted with [CLS] and [SEP] tokens as follows:
  344. (a) For sequence pairs (for a pair title+text for example):
  345. ```
  346. tokens: [CLS] is this jack ##son ##ville ? [SEP] no it is not . [SEP]
  347. token_type_ids: 0 0 0 0 0 0 0 0 1 1 1 1 1 1
  348. ```
  349. (b) For single sequences (for a question for example):
  350. ```
  351. tokens: [CLS] the dog is hairy . [SEP]
  352. token_type_ids: 0 0 0 0 0 0 0
  353. ```
  354. DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
  355. rather than the left.
  356. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  357. [`PreTrainedTokenizer.__call__`] for details.
  358. [What are input IDs?](../glossary#input-ids)
  359. Examples:
  360. ```python
  361. >>> from transformers import DPRQuestionEncoder, DPRQuestionEncoderTokenizer
  362. >>> tokenizer = DPRQuestionEncoderTokenizer.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
  363. >>> model = DPRQuestionEncoder.from_pretrained("facebook/dpr-question_encoder-single-nq-base")
  364. >>> input_ids = tokenizer("Hello, is my dog cute ?", return_tensors="pt")["input_ids"]
  365. >>> embeddings = model(input_ids).pooler_output
  366. ```
  367. """
  368. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  369. output_hidden_states = (
  370. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  371. )
  372. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  373. if input_ids is not None and inputs_embeds is not None:
  374. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  375. elif input_ids is not None:
  376. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  377. input_shape = input_ids.size()
  378. elif inputs_embeds is not None:
  379. input_shape = inputs_embeds.size()[:-1]
  380. else:
  381. raise ValueError("You have to specify either input_ids or inputs_embeds")
  382. device = input_ids.device if input_ids is not None else inputs_embeds.device
  383. if attention_mask is None:
  384. attention_mask = (
  385. torch.ones(input_shape, device=device)
  386. if input_ids is None
  387. else (input_ids != self.config.pad_token_id)
  388. )
  389. if token_type_ids is None:
  390. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  391. outputs = self.question_encoder(
  392. input_ids=input_ids,
  393. attention_mask=attention_mask,
  394. token_type_ids=token_type_ids,
  395. inputs_embeds=inputs_embeds,
  396. output_attentions=output_attentions,
  397. output_hidden_states=output_hidden_states,
  398. return_dict=return_dict,
  399. )
  400. if not return_dict:
  401. return outputs[1:]
  402. return DPRQuestionEncoderOutput(
  403. pooler_output=outputs.pooler_output, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  404. )
  405. @auto_docstring(
  406. custom_intro="""
  407. The bare DPRReader transformer outputting span predictions.
  408. """
  409. )
  410. class DPRReader(DPRPretrainedReader):
  411. def __init__(self, config: DPRConfig):
  412. super().__init__(config)
  413. self.config = config
  414. self.span_predictor = DPRSpanPredictor(config)
  415. # Initialize weights and apply final processing
  416. self.post_init()
  417. @auto_docstring
  418. def forward(
  419. self,
  420. input_ids: Optional[Tensor] = None,
  421. attention_mask: Optional[Tensor] = None,
  422. inputs_embeds: Optional[Tensor] = None,
  423. output_attentions: Optional[bool] = None,
  424. output_hidden_states: Optional[bool] = None,
  425. return_dict: Optional[bool] = None,
  426. ) -> Union[DPRReaderOutput, tuple[Tensor, ...]]:
  427. r"""
  428. input_ids (`tuple[torch.LongTensor]` of shapes `(n_passages, sequence_length)`):
  429. Indices of input sequence tokens in the vocabulary. It has to be a sequence triplet with 1) the question
  430. and 2) the passages titles and 3) the passages texts To match pretraining, DPR `input_ids` sequence should
  431. be formatted with [CLS] and [SEP] with the format:
  432. `[CLS] <question token ids> [SEP] <titles ids> [SEP] <texts ids>`
  433. DPR is a model with absolute position embeddings so it's usually advised to pad the inputs on the right
  434. rather than the left.
  435. Indices can be obtained using [`DPRReaderTokenizer`]. See this class documentation for more details.
  436. [What are input IDs?](../glossary#input-ids)
  437. inputs_embeds (`torch.FloatTensor` of shape `(n_passages, sequence_length, hidden_size)`, *optional*):
  438. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  439. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  440. model's internal embedding lookup matrix.
  441. Examples:
  442. ```python
  443. >>> from transformers import DPRReader, DPRReaderTokenizer
  444. >>> tokenizer = DPRReaderTokenizer.from_pretrained("facebook/dpr-reader-single-nq-base")
  445. >>> model = DPRReader.from_pretrained("facebook/dpr-reader-single-nq-base")
  446. >>> encoded_inputs = tokenizer(
  447. ... questions=["What is love ?"],
  448. ... titles=["Haddaway"],
  449. ... texts=["'What Is Love' is a song recorded by the artist Haddaway"],
  450. ... return_tensors="pt",
  451. ... )
  452. >>> outputs = model(**encoded_inputs)
  453. >>> start_logits = outputs.start_logits
  454. >>> end_logits = outputs.end_logits
  455. >>> relevance_logits = outputs.relevance_logits
  456. ```
  457. """
  458. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  459. output_hidden_states = (
  460. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  461. )
  462. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  463. if input_ids is not None and inputs_embeds is not None:
  464. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  465. elif input_ids is not None:
  466. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  467. input_shape = input_ids.size()
  468. elif inputs_embeds is not None:
  469. input_shape = inputs_embeds.size()[:-1]
  470. else:
  471. raise ValueError("You have to specify either input_ids or inputs_embeds")
  472. device = input_ids.device if input_ids is not None else inputs_embeds.device
  473. if attention_mask is None:
  474. attention_mask = torch.ones(input_shape, device=device)
  475. return self.span_predictor(
  476. input_ids,
  477. attention_mask,
  478. inputs_embeds=inputs_embeds,
  479. output_attentions=output_attentions,
  480. output_hidden_states=output_hidden_states,
  481. return_dict=return_dict,
  482. )
# Public API of this module. The building blocks DPREncoder and DPRSpanPredictor, also defined here, are not
# exported.
__all__ = [
    "DPRContextEncoder",
    "DPRPretrainedContextEncoder",
    "DPRPreTrainedModel",
    "DPRPretrainedQuestionEncoder",
    "DPRPretrainedReader",
    "DPRQuestionEncoder",
    "DPRReader",
]
  489. "DPRQuestionEncoder",
  490. "DPRReader",
  491. ]