modeling_splinter.py 37 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862
  1. # coding=utf-8
  2. # Copyright 2021 Tel AViv University, AllenAI and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Splinter model."""
  16. from dataclasses import dataclass
  17. from typing import Callable, Optional, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import CrossEntropyLoss
  21. from ...activations import ACT2FN
  22. from ...modeling_layers import GradientCheckpointingLayer
  23. from ...modeling_outputs import BaseModelOutput, ModelOutput, QuestionAnsweringModelOutput
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  26. from ...utils import (
  27. auto_docstring,
  28. can_return_tuple,
  29. logging,
  30. )
  31. from .configuration_splinter import SplinterConfig
  32. logger = logging.get_logger(__name__)
  33. class SplinterEmbeddings(nn.Module):
  34. """Construct the embeddings from word, position and token_type embeddings."""
  35. def __init__(self, config):
  36. super().__init__()
  37. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  38. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  39. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  40. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  41. # any TensorFlow checkpoint file
  42. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  43. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  44. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  45. self.register_buffer(
  46. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  47. )
  48. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  49. def forward(
  50. self,
  51. input_ids: Optional[torch.LongTensor] = None,
  52. token_type_ids: Optional[torch.LongTensor] = None,
  53. position_ids: Optional[torch.LongTensor] = None,
  54. inputs_embeds: Optional[torch.FloatTensor] = None,
  55. ) -> tuple:
  56. if input_ids is not None:
  57. input_shape = input_ids.size()
  58. else:
  59. input_shape = inputs_embeds.size()[:-1]
  60. seq_length = input_shape[1]
  61. if position_ids is None:
  62. position_ids = self.position_ids[:, :seq_length]
  63. if token_type_ids is None:
  64. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  65. if inputs_embeds is None:
  66. inputs_embeds = self.word_embeddings(input_ids)
  67. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  68. embeddings = inputs_embeds + token_type_embeddings
  69. if self.position_embedding_type == "absolute":
  70. position_embeddings = self.position_embeddings(position_ids)
  71. embeddings += position_embeddings
  72. embeddings = self.LayerNorm(embeddings)
  73. embeddings = self.dropout(embeddings)
  74. return embeddings
  75. # Copied from transformers.models.align.modeling_align.eager_attention_forward
  76. def eager_attention_forward(
  77. module: nn.Module,
  78. query: torch.Tensor,
  79. key: torch.Tensor,
  80. value: torch.Tensor,
  81. attention_mask: Optional[torch.Tensor],
  82. scaling: float,
  83. dropout: float = 0.0,
  84. head_mask: Optional[torch.Tensor] = None,
  85. **kwargs,
  86. ):
  87. attn_weights = torch.matmul(query, key.transpose(2, 3)) * scaling
  88. if attention_mask is not None:
  89. causal_mask = attention_mask[:, :, :, : key.shape[-2]]
  90. attn_weights = attn_weights + causal_mask
  91. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  92. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  93. if head_mask is not None:
  94. attn_weights = attn_weights * head_mask.view(1, -1, 1, 1)
  95. attn_output = torch.matmul(attn_weights, value)
  96. attn_output = attn_output.transpose(1, 2).contiguous()
  97. return attn_output, attn_weights
  98. # Copied from transformers.models.align.modeling_align.AlignTextSelfAttention with AlignText->Splinter
  99. class SplinterSelfAttention(nn.Module):
  100. def __init__(self, config):
  101. super().__init__()
  102. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  103. raise ValueError(
  104. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  105. f"heads ({config.num_attention_heads})"
  106. )
  107. self.config = config
  108. self.num_attention_heads = config.num_attention_heads
  109. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  110. self.all_head_size = self.num_attention_heads * self.attention_head_size
  111. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  112. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  113. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  114. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  115. self.attention_dropout = config.attention_probs_dropout_prob
  116. self.scaling = self.attention_head_size**-0.5
  117. def forward(
  118. self,
  119. hidden_states: torch.Tensor,
  120. attention_mask: Optional[torch.FloatTensor] = None,
  121. head_mask: Optional[torch.FloatTensor] = None,
  122. output_attentions: Optional[bool] = False,
  123. **kwargs,
  124. ) -> tuple[torch.Tensor]:
  125. input_shape = hidden_states.shape[:-1]
  126. hidden_shape = (*input_shape, -1, self.attention_head_size)
  127. query_states = self.query(hidden_states).view(hidden_shape).transpose(1, 2)
  128. key_states = self.key(hidden_states).view(hidden_shape).transpose(1, 2)
  129. value_states = self.value(hidden_states).view(hidden_shape).transpose(1, 2)
  130. attention_interface: Callable = eager_attention_forward
  131. if self.config._attn_implementation != "eager":
  132. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  133. attn_output, attn_weights = attention_interface(
  134. self,
  135. query_states,
  136. key_states,
  137. value_states,
  138. attention_mask,
  139. dropout=0.0 if not self.training else self.attention_dropout,
  140. scaling=self.scaling,
  141. head_mask=head_mask,
  142. **kwargs,
  143. )
  144. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  145. outputs = (attn_output, attn_weights) if output_attentions else (attn_output,)
  146. return outputs
  147. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput with Bert->Splinter
  148. class SplinterSelfOutput(nn.Module):
  149. def __init__(self, config):
  150. super().__init__()
  151. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  152. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  153. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  154. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  155. hidden_states = self.dense(hidden_states)
  156. hidden_states = self.dropout(hidden_states)
  157. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  158. return hidden_states
  159. # Copied from transformers.models.align.modeling_align.AlignTextAttention with AlignText->Splinter
  160. class SplinterAttention(nn.Module):
  161. def __init__(self, config):
  162. super().__init__()
  163. self.self = SplinterSelfAttention(config)
  164. self.output = SplinterSelfOutput(config)
  165. self.pruned_heads = set()
  166. def prune_heads(self, heads):
  167. if len(heads) == 0:
  168. return
  169. heads, index = find_pruneable_heads_and_indices(
  170. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  171. )
  172. # Prune linear layers
  173. self.self.query = prune_linear_layer(self.self.query, index)
  174. self.self.key = prune_linear_layer(self.self.key, index)
  175. self.self.value = prune_linear_layer(self.self.value, index)
  176. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  177. # Update hyper params and store pruned heads
  178. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  179. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  180. self.pruned_heads = self.pruned_heads.union(heads)
  181. def forward(
  182. self,
  183. hidden_states: torch.Tensor,
  184. attention_mask: Optional[torch.FloatTensor] = None,
  185. head_mask: Optional[torch.FloatTensor] = None,
  186. output_attentions: Optional[bool] = False,
  187. **kwargs,
  188. ) -> tuple[torch.Tensor]:
  189. self_outputs = self.self(
  190. hidden_states,
  191. attention_mask=attention_mask,
  192. head_mask=head_mask,
  193. output_attentions=output_attentions,
  194. **kwargs,
  195. )
  196. attention_output = self.output(self_outputs[0], hidden_states)
  197. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  198. return outputs
  199. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Splinter
  200. class SplinterIntermediate(nn.Module):
  201. def __init__(self, config):
  202. super().__init__()
  203. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  204. if isinstance(config.hidden_act, str):
  205. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  206. else:
  207. self.intermediate_act_fn = config.hidden_act
  208. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  209. hidden_states = self.dense(hidden_states)
  210. hidden_states = self.intermediate_act_fn(hidden_states)
  211. return hidden_states
  212. # Copied from transformers.models.bert.modeling_bert.BertOutput with Bert->Splinter
  213. class SplinterOutput(nn.Module):
  214. def __init__(self, config):
  215. super().__init__()
  216. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  217. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  218. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  219. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  220. hidden_states = self.dense(hidden_states)
  221. hidden_states = self.dropout(hidden_states)
  222. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  223. return hidden_states
  224. # Copied from transformers.models.align.modeling_align.AlignTextLayer with AlignText->Splinter
  225. class SplinterLayer(GradientCheckpointingLayer):
  226. def __init__(self, config):
  227. super().__init__()
  228. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  229. self.seq_len_dim = 1
  230. self.attention = SplinterAttention(config)
  231. self.intermediate = SplinterIntermediate(config)
  232. self.output = SplinterOutput(config)
  233. def forward(
  234. self,
  235. hidden_states: torch.Tensor,
  236. attention_mask: Optional[torch.FloatTensor] = None,
  237. head_mask: Optional[torch.FloatTensor] = None,
  238. output_attentions: Optional[bool] = False,
  239. **kwargs,
  240. ) -> tuple[torch.Tensor]:
  241. self_attention_outputs = self.attention(
  242. hidden_states,
  243. attention_mask=attention_mask,
  244. head_mask=head_mask,
  245. output_attentions=output_attentions,
  246. **kwargs,
  247. )
  248. attention_output = self_attention_outputs[0]
  249. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  250. layer_output = apply_chunking_to_forward(
  251. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  252. )
  253. outputs = (layer_output,) + outputs
  254. return outputs
  255. def feed_forward_chunk(self, attention_output):
  256. intermediate_output = self.intermediate(attention_output)
  257. layer_output = self.output(intermediate_output, attention_output)
  258. return layer_output
  259. # Copied from transformers.models.align.modeling_align.AlignTextEncoder with AlignText->Splinter
  260. class SplinterEncoder(nn.Module):
  261. def __init__(self, config):
  262. super().__init__()
  263. self.config = config
  264. self.layer = nn.ModuleList([SplinterLayer(config) for i in range(config.num_hidden_layers)])
  265. self.gradient_checkpointing = False
  266. @can_return_tuple
  267. def forward(
  268. self,
  269. hidden_states: torch.Tensor,
  270. attention_mask: Optional[torch.FloatTensor] = None,
  271. head_mask: Optional[torch.FloatTensor] = None,
  272. output_attentions: Optional[bool] = False,
  273. output_hidden_states: Optional[bool] = False,
  274. return_dict: Optional[bool] = True,
  275. **kwargs,
  276. ) -> Union[tuple[torch.Tensor], BaseModelOutput]:
  277. all_hidden_states = () if output_hidden_states else None
  278. all_self_attentions = () if output_attentions else None
  279. for i, layer_module in enumerate(self.layer):
  280. if output_hidden_states:
  281. all_hidden_states = all_hidden_states + (hidden_states,)
  282. layer_head_mask = head_mask[i] if head_mask is not None else None
  283. layer_outputs = layer_module(
  284. hidden_states=hidden_states,
  285. attention_mask=attention_mask,
  286. head_mask=layer_head_mask,
  287. output_attentions=output_attentions,
  288. **kwargs,
  289. )
  290. hidden_states = layer_outputs[0]
  291. if output_attentions:
  292. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  293. if output_hidden_states:
  294. all_hidden_states = all_hidden_states + (hidden_states,)
  295. return BaseModelOutput(
  296. last_hidden_state=hidden_states,
  297. hidden_states=all_hidden_states,
  298. attentions=all_self_attentions,
  299. )
  300. @auto_docstring
  301. class SplinterPreTrainedModel(PreTrainedModel):
  302. config: SplinterConfig
  303. base_model_prefix = "splinter"
  304. supports_gradient_checkpointing = True
  305. def _init_weights(self, module):
  306. """Initialize the weights"""
  307. if isinstance(module, nn.Linear):
  308. # Slightly different from the TF version which uses truncated_normal for initialization
  309. # cf https://github.com/pytorch/pytorch/pull/5617
  310. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  311. if module.bias is not None:
  312. module.bias.data.zero_()
  313. elif isinstance(module, nn.Embedding):
  314. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  315. if module.padding_idx is not None:
  316. module.weight.data[module.padding_idx].zero_()
  317. elif isinstance(module, nn.LayerNorm):
  318. module.bias.data.zero_()
  319. module.weight.data.fill_(1.0)
  320. @auto_docstring
  321. class SplinterModel(SplinterPreTrainedModel):
  322. """
  323. The model is an encoder (with only self-attention) following the architecture described in [Attention is all you
  324. need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones,
  325. Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  326. """
  327. def __init__(self, config):
  328. super().__init__(config)
  329. self.config = config
  330. self.embeddings = SplinterEmbeddings(config)
  331. self.encoder = SplinterEncoder(config)
  332. # Initialize weights and apply final processing
  333. self.post_init()
  334. def get_input_embeddings(self):
  335. return self.embeddings.word_embeddings
  336. def set_input_embeddings(self, value):
  337. self.embeddings.word_embeddings = value
  338. def _prune_heads(self, heads_to_prune):
  339. """
  340. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  341. class PreTrainedModel
  342. """
  343. for layer, heads in heads_to_prune.items():
  344. self.encoder.layer[layer].attention.prune_heads(heads)
  345. @can_return_tuple
  346. @auto_docstring
  347. def forward(
  348. self,
  349. input_ids: Optional[torch.Tensor] = None,
  350. attention_mask: Optional[torch.Tensor] = None,
  351. token_type_ids: Optional[torch.Tensor] = None,
  352. position_ids: Optional[torch.Tensor] = None,
  353. head_mask: Optional[torch.Tensor] = None,
  354. inputs_embeds: Optional[torch.Tensor] = None,
  355. output_attentions: Optional[bool] = None,
  356. output_hidden_states: Optional[bool] = None,
  357. return_dict: Optional[bool] = None,
  358. ) -> Union[tuple, BaseModelOutput]:
  359. r"""
  360. token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  361. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  362. 1]`:
  363. - 0 corresponds to a *sentence A* token,
  364. - 1 corresponds to a *sentence B* token.
  365. [What are token type IDs?](../glossary#token-type-ids)
  366. position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  367. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  368. config.max_position_embeddings - 1]`.
  369. [What are position IDs?](../glossary#position-ids)
  370. """
  371. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  372. output_hidden_states = (
  373. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  374. )
  375. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  376. if input_ids is not None and inputs_embeds is not None:
  377. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  378. elif input_ids is not None:
  379. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  380. input_shape = input_ids.size()
  381. elif inputs_embeds is not None:
  382. input_shape = inputs_embeds.size()[:-1]
  383. else:
  384. raise ValueError("You have to specify either input_ids or inputs_embeds")
  385. batch_size, seq_length = input_shape
  386. device = input_ids.device if input_ids is not None else inputs_embeds.device
  387. if attention_mask is None:
  388. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  389. if token_type_ids is None:
  390. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  391. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  392. # ourselves in which case we just need to make it broadcastable to all heads.
  393. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  394. # Prepare head mask if needed
  395. # 1.0 in head_mask indicate we keep the head
  396. # attention_probs has shape bsz x n_heads x N x N
  397. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  398. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  399. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  400. embedding_output = self.embeddings(
  401. input_ids=input_ids,
  402. position_ids=position_ids,
  403. token_type_ids=token_type_ids,
  404. inputs_embeds=inputs_embeds,
  405. )
  406. encoder_outputs = self.encoder(
  407. embedding_output,
  408. attention_mask=extended_attention_mask,
  409. head_mask=head_mask,
  410. output_attentions=output_attentions,
  411. output_hidden_states=output_hidden_states,
  412. return_dict=True,
  413. )
  414. sequence_output = encoder_outputs[0]
  415. return BaseModelOutput(
  416. last_hidden_state=sequence_output,
  417. hidden_states=encoder_outputs.hidden_states,
  418. attentions=encoder_outputs.attentions,
  419. )
  420. class SplinterFullyConnectedLayer(nn.Module):
  421. def __init__(self, input_dim, output_dim, hidden_act="gelu"):
  422. super().__init__()
  423. self.input_dim = input_dim
  424. self.output_dim = output_dim
  425. self.dense = nn.Linear(self.input_dim, self.output_dim)
  426. self.act_fn = ACT2FN[hidden_act]
  427. self.LayerNorm = nn.LayerNorm(self.output_dim)
  428. def forward(self, inputs: torch.Tensor) -> torch.Tensor:
  429. hidden_states = self.dense(inputs)
  430. hidden_states = self.act_fn(hidden_states)
  431. hidden_states = self.LayerNorm(hidden_states)
  432. return hidden_states
  433. class QuestionAwareSpanSelectionHead(nn.Module):
  434. """
  435. Implementation of Question-Aware Span Selection (QASS) head, described in Splinter's paper:
  436. """
  437. def __init__(self, config):
  438. super().__init__()
  439. self.query_start_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
  440. self.query_end_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
  441. self.start_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
  442. self.end_transform = SplinterFullyConnectedLayer(config.hidden_size, config.hidden_size)
  443. self.start_classifier = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
  444. self.end_classifier = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
  445. def forward(self, inputs, positions):
  446. _, _, dim = inputs.size()
  447. index = positions.unsqueeze(-1).repeat(1, 1, dim) # [batch_size, num_positions, dim]
  448. gathered_reps = torch.gather(inputs, dim=1, index=index) # [batch_size, num_positions, dim]
  449. query_start_reps = self.query_start_transform(gathered_reps) # [batch_size, num_positions, dim]
  450. query_end_reps = self.query_end_transform(gathered_reps) # [batch_size, num_positions, dim]
  451. start_reps = self.start_transform(inputs) # [batch_size, seq_length, dim]
  452. end_reps = self.end_transform(inputs) # [batch_size, seq_length, dim]
  453. hidden_states = self.start_classifier(query_start_reps) # [batch_size, num_positions, dim]
  454. start_reps = start_reps.permute(0, 2, 1) # [batch_size, dim, seq_length]
  455. start_logits = torch.matmul(hidden_states, start_reps)
  456. hidden_states = self.end_classifier(query_end_reps)
  457. end_reps = end_reps.permute(0, 2, 1)
  458. end_logits = torch.matmul(hidden_states, end_reps)
  459. return start_logits, end_logits
  460. @auto_docstring
  461. class SplinterForQuestionAnswering(SplinterPreTrainedModel):
  462. def __init__(self, config):
  463. super().__init__(config)
  464. self.splinter = SplinterModel(config)
  465. self.splinter_qass = QuestionAwareSpanSelectionHead(config)
  466. self.question_token_id = config.question_token_id
  467. # Initialize weights and apply final processing
  468. self.post_init()
  469. @auto_docstring
  470. def forward(
  471. self,
  472. input_ids: Optional[torch.Tensor] = None,
  473. attention_mask: Optional[torch.Tensor] = None,
  474. token_type_ids: Optional[torch.Tensor] = None,
  475. position_ids: Optional[torch.Tensor] = None,
  476. head_mask: Optional[torch.Tensor] = None,
  477. inputs_embeds: Optional[torch.Tensor] = None,
  478. start_positions: Optional[torch.LongTensor] = None,
  479. end_positions: Optional[torch.LongTensor] = None,
  480. output_attentions: Optional[bool] = None,
  481. output_hidden_states: Optional[bool] = None,
  482. return_dict: Optional[bool] = None,
  483. question_positions: Optional[torch.LongTensor] = None,
  484. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  485. r"""
  486. token_type_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  487. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  488. 1]`:
  489. - 0 corresponds to a *sentence A* token,
  490. - 1 corresponds to a *sentence B* token.
  491. [What are token type IDs?](../glossary#token-type-ids)
  492. position_ids (`torch.LongTensor` of shape `batch_size, sequence_length`, *optional*):
  493. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  494. config.max_position_embeddings - 1]`.
  495. [What are position IDs?](../glossary#position-ids)
  496. question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
  497. The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size,
  498. num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be
  499. the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size,
  500. sequence_length)`.
  501. """
  502. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  503. question_positions_were_none = False
  504. if question_positions is None:
  505. if input_ids is not None:
  506. question_position_for_each_example = torch.argmax(
  507. (torch.eq(input_ids, self.question_token_id)).int(), dim=-1
  508. )
  509. else:
  510. question_position_for_each_example = torch.zeros(
  511. inputs_embeds.size(0), dtype=torch.long, layout=inputs_embeds.layout, device=inputs_embeds.device
  512. )
  513. question_positions = question_position_for_each_example.unsqueeze(-1)
  514. question_positions_were_none = True
  515. outputs = self.splinter(
  516. input_ids,
  517. attention_mask=attention_mask,
  518. token_type_ids=token_type_ids,
  519. position_ids=position_ids,
  520. head_mask=head_mask,
  521. inputs_embeds=inputs_embeds,
  522. output_attentions=output_attentions,
  523. output_hidden_states=output_hidden_states,
  524. return_dict=return_dict,
  525. )
  526. sequence_output = outputs[0]
  527. start_logits, end_logits = self.splinter_qass(sequence_output, question_positions)
  528. if question_positions_were_none:
  529. start_logits, end_logits = start_logits.squeeze(1), end_logits.squeeze(1)
  530. if attention_mask is not None:
  531. start_logits = start_logits + (1 - attention_mask) * torch.finfo(start_logits.dtype).min
  532. end_logits = end_logits + (1 - attention_mask) * torch.finfo(end_logits.dtype).min
  533. total_loss = None
  534. if start_positions is not None and end_positions is not None:
  535. # If we are on multi-GPU, split add a dimension
  536. if len(start_positions.size()) > 1:
  537. start_positions = start_positions.squeeze(-1)
  538. if len(end_positions.size()) > 1:
  539. end_positions = end_positions.squeeze(-1)
  540. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  541. ignored_index = start_logits.size(1)
  542. start_positions.clamp_(0, ignored_index)
  543. end_positions.clamp_(0, ignored_index)
  544. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  545. start_loss = loss_fct(start_logits, start_positions)
  546. end_loss = loss_fct(end_logits, end_positions)
  547. total_loss = (start_loss + end_loss) / 2
  548. if not return_dict:
  549. output = (start_logits, end_logits) + outputs[1:]
  550. return ((total_loss,) + output) if total_loss is not None else output
  551. return QuestionAnsweringModelOutput(
  552. loss=total_loss,
  553. start_logits=start_logits,
  554. end_logits=end_logits,
  555. hidden_states=outputs.hidden_states,
  556. attentions=outputs.attentions,
  557. )
@dataclass
@auto_docstring(
    custom_intro="""
    Class for outputs of Splinter as a span selection model.
    """
)
class SplinterForPreTrainingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when start and end positions are provided):
        Total span extraction loss is the sum of a Cross-Entropy for the start and end positions.
    start_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
        Span-start scores (before SoftMax).
    end_logits (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length)`):
        Span-end scores (before SoftMax).
    """

    # NOTE(review): ModelOutput subclasses are typically indexed / converted to tuples in field order,
    # so keep this field order stable.
    # Span-extraction loss; left as None when no labels were supplied.
    loss: Optional[torch.FloatTensor] = None
    # Per-question span-start scores.
    start_logits: Optional[torch.FloatTensor] = None
    # Per-question span-end scores.
    end_logits: Optional[torch.FloatTensor] = None
    # Presumably the encoder's per-layer hidden states when requested — confirm against SplinterForPreTraining.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Presumably the encoder's per-layer attention weights when requested — confirm against SplinterForPreTraining.
    attentions: Optional[tuple[torch.FloatTensor]] = None
  578. @auto_docstring(
  579. custom_intro="""
  580. Splinter Model for the recurring span selection task as done during the pretraining. The difference to the QA task
  581. is that we do not have a question, but multiple question tokens that replace the occurrences of recurring spans
  582. instead.
  583. """
  584. )
  585. class SplinterForPreTraining(SplinterPreTrainedModel):
  586. def __init__(self, config):
  587. super().__init__(config)
  588. self.splinter = SplinterModel(config)
  589. self.splinter_qass = QuestionAwareSpanSelectionHead(config)
  590. self.question_token_id = config.question_token_id
  591. # Initialize weights and apply final processing
  592. self.post_init()
  593. @auto_docstring
  594. def forward(
  595. self,
  596. input_ids: Optional[torch.Tensor] = None,
  597. attention_mask: Optional[torch.Tensor] = None,
  598. token_type_ids: Optional[torch.Tensor] = None,
  599. position_ids: Optional[torch.Tensor] = None,
  600. head_mask: Optional[torch.Tensor] = None,
  601. inputs_embeds: Optional[torch.Tensor] = None,
  602. start_positions: Optional[torch.LongTensor] = None,
  603. end_positions: Optional[torch.LongTensor] = None,
  604. output_attentions: Optional[bool] = None,
  605. output_hidden_states: Optional[bool] = None,
  606. return_dict: Optional[bool] = None,
  607. question_positions: Optional[torch.LongTensor] = None,
  608. ) -> Union[tuple, SplinterForPreTrainingOutput]:
  609. r"""
  610. input_ids (`torch.LongTensor` of shape `(batch_size, num_questions, sequence_length)`):
  611. Indices of input sequence tokens in the vocabulary.
  612. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  613. [`PreTrainedTokenizer.__call__`] for details.
  614. [What are input IDs?](../glossary#input-ids)
  615. token_type_ids (`torch.LongTensor` of shape `batch_size, num_questions, sequence_length`, *optional*):
  616. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  617. 1]`:
  618. - 0 corresponds to a *sentence A* token,
  619. - 1 corresponds to a *sentence B* token.
  620. [What are token type IDs?](../glossary#token-type-ids)
  621. position_ids (`torch.LongTensor` of shape `batch_size, num_questions, sequence_length`, *optional*):
  622. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  623. config.max_position_embeddings - 1]`.
  624. [What are position IDs?](../glossary#position-ids)
  625. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_questions, sequence_length, hidden_size)`, *optional*):
  626. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  627. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  628. model's internal embedding lookup matrix.
  629. start_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
  630. Labels for position (index) of the start of the labelled span for computing the token classification loss.
  631. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  632. are not taken into account for computing the loss.
  633. end_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
  634. Labels for position (index) of the end of the labelled span for computing the token classification loss.
  635. Positions are clamped to the length of the sequence (`sequence_length`). Position outside of the sequence
  636. are not taken into account for computing the loss.
  637. question_positions (`torch.LongTensor` of shape `(batch_size, num_questions)`, *optional*):
  638. The positions of all question tokens. If given, start_logits and end_logits will be of shape `(batch_size,
  639. num_questions, sequence_length)`. If None, the first question token in each sequence in the batch will be
  640. the only one for which start_logits and end_logits are calculated and they will be of shape `(batch_size,
  641. sequence_length)`.
  642. """
  643. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  644. if question_positions is None and start_positions is not None and end_positions is not None:
  645. raise TypeError("question_positions must be specified in order to calculate the loss")
  646. elif question_positions is None and input_ids is None:
  647. raise TypeError("question_positions must be specified when input_embeds is used")
  648. elif question_positions is None:
  649. question_positions = self._prepare_question_positions(input_ids)
  650. outputs = self.splinter(
  651. input_ids,
  652. attention_mask=attention_mask,
  653. token_type_ids=token_type_ids,
  654. position_ids=position_ids,
  655. head_mask=head_mask,
  656. inputs_embeds=inputs_embeds,
  657. output_attentions=output_attentions,
  658. output_hidden_states=output_hidden_states,
  659. return_dict=return_dict,
  660. )
  661. sequence_output = outputs[0]
  662. batch_size, sequence_length, dim = sequence_output.size()
  663. # [batch_size, num_questions, sequence_length]
  664. start_logits, end_logits = self.splinter_qass(sequence_output, question_positions)
  665. num_questions = question_positions.size(1)
  666. if attention_mask is not None:
  667. attention_mask_for_each_question = attention_mask.unsqueeze(1).expand(
  668. batch_size, num_questions, sequence_length
  669. )
  670. start_logits = start_logits + (1 - attention_mask_for_each_question) * torch.finfo(start_logits.dtype).min
  671. end_logits = end_logits + (1 - attention_mask_for_each_question) * torch.finfo(end_logits.dtype).min
  672. total_loss = None
  673. # [batch_size, num_questions, sequence_length]
  674. if start_positions is not None and end_positions is not None:
  675. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  676. start_positions.clamp_(0, max(0, sequence_length - 1))
  677. end_positions.clamp_(0, max(0, sequence_length - 1))
  678. # Ignore zero positions in the loss. Splinter never predicts zero
  679. # during pretraining and zero is used for padding question
  680. # tokens as well as for start and end positions of padded
  681. # question tokens.
  682. loss_fct = CrossEntropyLoss(ignore_index=self.config.pad_token_id)
  683. start_loss = loss_fct(
  684. start_logits.view(batch_size * num_questions, sequence_length),
  685. start_positions.view(batch_size * num_questions),
  686. )
  687. end_loss = loss_fct(
  688. end_logits.view(batch_size * num_questions, sequence_length),
  689. end_positions.view(batch_size * num_questions),
  690. )
  691. total_loss = (start_loss + end_loss) / 2
  692. if not return_dict:
  693. output = (start_logits, end_logits) + outputs[1:]
  694. return ((total_loss,) + output) if total_loss is not None else output
  695. return SplinterForPreTrainingOutput(
  696. loss=total_loss,
  697. start_logits=start_logits,
  698. end_logits=end_logits,
  699. hidden_states=outputs.hidden_states,
  700. attentions=outputs.attentions,
  701. )
  702. def _prepare_question_positions(self, input_ids: torch.Tensor) -> torch.Tensor:
  703. rows, flat_positions = torch.where(input_ids == self.config.question_token_id)
  704. num_questions = torch.bincount(rows)
  705. positions = torch.full(
  706. (input_ids.size(0), num_questions.max()),
  707. self.config.pad_token_id,
  708. dtype=torch.long,
  709. device=input_ids.device,
  710. )
  711. cols = torch.cat([torch.arange(n) for n in num_questions])
  712. positions[rows, cols] = flat_positions
  713. return positions
  714. __all__ = [
  715. "SplinterForQuestionAnswering",
  716. "SplinterForPreTraining",
  717. "SplinterLayer",
  718. "SplinterModel",
  719. "SplinterPreTrainedModel",
  720. ]