modeling_flaubert.py 78 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562156315641565156615671568156915701571157215731574157515761577157815791580158115821583158415851586158715881589159015911592159315941595159615971598159916001601160216031604160516061607160816091610161116121613161416151616161716181619162016211622162316241625162616271628162916301631163216331634163516361637163816391640164116421643164416451646164716481649165016511652165316541655165616571658165916601661166216631664166516661667166816691670167116721673167416751676167716781679168016811682168316841685168616871688168916901691169216931694169516961697169816991700
  1. # coding=utf-8
  2. # Copyright 2019-present CNRS, Facebook Inc. and the HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Flaubert model, based on XLM."""
  16. import math
  17. from dataclasses import dataclass
  18. from typing import Callable, Optional, Union
  19. import numpy as np
  20. import torch
  21. from torch import nn
  22. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  23. from ...activations import gelu, get_activation
  24. from ...cache_utils import DynamicCache, EncoderDecoderCache
  25. from ...generation import GenerationMixin
  26. from ...modeling_outputs import (
  27. BaseModelOutput,
  28. MaskedLMOutput,
  29. MultipleChoiceModelOutput,
  30. QuestionAnsweringModelOutput,
  31. SequenceClassifierOutput,
  32. TokenClassifierOutput,
  33. )
  34. from ...modeling_utils import PreTrainedModel
  35. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  36. from ...utils import ModelOutput, auto_docstring, logging
  37. from .configuration_flaubert import FlaubertConfig
  38. logger = logging.get_logger(__name__)
  # Adapted from transformers.models.xlm.modeling_xlm.create_sinusoidal_embeddings
  def create_sinusoidal_embeddings(n_pos, dim, out):
      """
      Fill `out` in place with fixed sinusoidal position embeddings and freeze it.

      Entry `(pos, j)` is `sin(angle)` for even `j` and `cos(angle)` for odd `j`, where
      `angle = pos / 10000 ** (2 * (j // 2) / dim)`.

      Args:
          n_pos (`int`): Number of positions (rows) to generate.
          dim (`int`): Embedding size (columns).
          out (`torch.Tensor`): Destination written in place; expected to be shaped `(n_pos, dim)`.
              Its `requires_grad` flag is turned off and it is detached so the table is not trained.
      """
      # Broadcast (n_pos, 1) against (1, dim) instead of a Python-level double loop: same float64
      # values, but computed in a single vectorised NumPy expression. Also works for n_pos == 0.
      positions = np.arange(n_pos)[:, None]
      dims = np.arange(dim)[None, :]
      position_enc = positions / np.power(10000, 2 * (dims // 2) / dim)  # (n_pos, dim)
      out.requires_grad = False
      out[:, 0::2] = torch.FloatTensor(np.sin(position_enc[:, 0::2]))
      out[:, 1::2] = torch.FloatTensor(np.cos(position_enc[:, 1::2]))
      out.detach_()
  # Adapted from transformers.models.xlm.modeling_xlm.get_masks
  def get_masks(slen, lengths, causal, padding_mask=None):
      """
      Build the hidden-states mask and the attention mask for a batch.

      Args:
          slen (`int`): Sequence length.
          lengths (`torch.Tensor`): Per-sequence lengths; its first dimension is the batch size and its
              device is used for the generated tensors.
          causal (`bool`): When truthy, the attention mask is lower-triangular (query `i` may attend to
              key `j` only if `j <= i`).
          padding_mask (`torch.Tensor`, *optional*): Used directly as the hidden-states mask instead of
              deriving one from `lengths`.

      Returns:
          `tuple`: `(mask, attn_mask)`. `mask` has shape `(bs, slen)`; `attn_mask` is the very same object
          as `mask` when not causal, otherwise a boolean tensor of shape `(bs, slen, slen)`.
      """
      positions = torch.arange(slen, dtype=torch.long, device=lengths.device)

      if padding_mask is None:
          assert lengths.max().item() <= slen
          mask = positions < lengths[:, None]
      else:
          mask = padding_mask

      bs = lengths.size(0)
      if not causal:
          attn_mask = mask
      else:
          # attn_mask[b, i, j] is True exactly when key position j <= query position i.
          key_positions = positions[None, None, :].repeat(bs, slen, 1)
          attn_mask = key_positions <= positions[None, :, None]

      # Sanity checks on the shapes produced above.
      assert mask.size() == (bs, slen)
      assert causal is False or attn_mask.size() == (bs, slen, slen)
      return mask, attn_mask
  # Adapted from transformers.models.xlm.modeling_xlm.MultiHeadAttention
  class MultiHeadAttention(nn.Module):
      """
      Multi-head scaled dot-product attention used by the Flaubert/XLM transformer layers.

      Args:
          n_heads (`int`): Number of attention heads; `dim` must be divisible by it.
          dim (`int`): Model dimension; every projection maps `dim -> dim`.
          config: Model configuration; only `config.attention_dropout` is read.
          layer_idx (`int`, *optional*, defaults to 0): Index of this layer, used as the key into the cache.
      """

      def __init__(self, n_heads, dim, config, layer_idx: int = 0):
          super().__init__()
          self.layer_id = layer_idx
          self.dim = dim
          self.n_heads = n_heads
          self.head_dim = dim // n_heads
          self.dropout = config.attention_dropout
          assert self.dim % self.n_heads == 0

          self.q_lin = nn.Linear(dim, dim)
          self.k_lin = nn.Linear(dim, dim)
          self.v_lin = nn.Linear(dim, dim)
          self.out_lin = nn.Linear(dim, dim)
          self.pruned_heads = set()

      def prune_heads(self, heads):
          """Remove `heads` from the four projections and update `n_heads`, `dim` and `pruned_heads`."""
          attention_head_size = self.dim // self.n_heads
          if len(heads) == 0:
              return
          heads, index = find_pruneable_heads_and_indices(heads, self.n_heads, attention_head_size, self.pruned_heads)
          # Q/K/V lose output features; the output projection loses the matching input features.
          self.q_lin = prune_linear_layer(self.q_lin, index)
          self.k_lin = prune_linear_layer(self.k_lin, index)
          self.v_lin = prune_linear_layer(self.v_lin, index)
          self.out_lin = prune_linear_layer(self.out_lin, index, dim=1)
          # Update hyper params
          self.n_heads = self.n_heads - len(heads)
          self.dim = attention_head_size * self.n_heads
          self.pruned_heads = self.pruned_heads.union(heads)

      def forward(
          self,
          input,
          mask,
          kv=None,
          cache=None,
          head_mask=None,
          output_attentions=False,
          cache_position=None,
      ):
          """
          Self-attention (if `kv` is None) or attention over a source sentence (provided by `kv`).

          Args:
              input (`torch.Tensor` of shape `(bs, qlen, dim)`): Query-side hidden states.
              mask (`torch.Tensor`): 2-D `(bs, klen)` or 3-D `(bs, qlen, klen)`; positions equal to 0 are masked.
              kv (`torch.Tensor`, *optional*): Key/value source for cross-attention.
              cache (*optional*): An `EncoderDecoderCache` (split into its self/cross parts) or any other
                  cache object, which is used directly.
              head_mask (`torch.Tensor`, *optional*): Multiplied into the attention weights.
              output_attentions (`bool`): Whether to also return the attention weights.
              cache_position (*optional*): Forwarded to the cache update for self-attention only.

          Returns:
              `tuple`: `(output,)`, or `(output, weights)` when `output_attentions` is set.
          """
          bs, qlen, _ = input.size()
          is_cross_attention = kv is not None
          # 3-D masks carry one row per query; 2-D masks are broadcast across all queries.
          mask_reshape = (bs, 1, qlen, -1) if mask.dim() == 3 else (bs, 1, 1, -1)

          q = self.q_lin(input).view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)  # (bs, n_heads, qlen, head_dim)

          # Default to "not cached yet" so the reuse check below is defined for every cache type;
          # only `EncoderDecoderCache` tracks the per-layer `is_updated` flag.
          is_updated = False
          if cache is not None:
              if isinstance(cache, EncoderDecoderCache):
                  is_updated = cache.is_updated.get(self.layer_id)
                  if is_cross_attention:
                      # after the first generated id, we can subsequently re-use all key/value_states from cache
                      curr_past_key_value = cache.cross_attention_cache
                  else:
                      curr_past_key_value = cache.self_attention_cache
              else:
                  curr_past_key_value = cache

          current_states = kv if is_cross_attention else input
          if is_cross_attention and cache is not None and is_updated:
              # reuse k,v, cross_attentions
              k = curr_past_key_value.key_cache[self.layer_id]
              v = curr_past_key_value.value_cache[self.layer_id]
          else:
              k = self.k_lin(current_states)
              v = self.v_lin(current_states)
              k = k.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)
              v = v.view(bs, -1, self.n_heads, self.head_dim).transpose(1, 2)

              if cache is not None:
                  # save all key/value_states to cache to be re-used for fast auto-regressive generation
                  cache_position = cache_position if not is_cross_attention else None
                  k, v = curr_past_key_value.update(k, v, self.layer_id, {"cache_position": cache_position})
                  # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                  if is_cross_attention and isinstance(cache, EncoderDecoderCache):
                      cache.is_updated[self.layer_id] = True

          q = q / math.sqrt(self.head_dim)  # (bs, n_heads, qlen, head_dim)
          scores = torch.matmul(q, k.transpose(2, 3))  # (bs, n_heads, qlen, klen)
          # Masked positions get the dtype minimum so their softmax weight underflows to ~0.
          masked = (mask == 0).view(mask_reshape).expand_as(scores)  # (bs, n_heads, qlen, klen)
          scores.masked_fill_(masked, torch.finfo(scores.dtype).min)

          # Softmax in float32, then cast back to the score dtype.
          weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)  # (bs, n_heads, qlen, klen)
          weights = nn.functional.dropout(weights, p=self.dropout, training=self.training)

          # Mask heads if we want to
          if head_mask is not None:
              weights = weights * head_mask

          context = torch.matmul(weights, v)  # (bs, n_heads, qlen, head_dim)
          context = context.transpose(1, 2).contiguous().view(bs, -1, self.n_heads * self.head_dim)

          outputs = (self.out_lin(context),)
          if output_attentions:
              outputs = outputs + (weights,)
          return outputs
  # Adapted from transformers.models.xlm.modeling_xlm.TransformerFFN
  class TransformerFFN(nn.Module):
      """
      Position-wise feed-forward block: `lin1` -> activation -> `lin2` -> dropout.

      The activation is `gelu` when `config.gelu_activation` is truthy and ReLU otherwise. `forward` routes
      the computation through `apply_chunking_to_forward`, chunking along dimension 1 with
      `config.chunk_size_feed_forward`.
      """

      def __init__(self, in_dim, dim_hidden, out_dim, config):
          super().__init__()
          self.dropout = config.dropout
          self.lin1 = nn.Linear(in_dim, dim_hidden)
          self.lin2 = nn.Linear(dim_hidden, out_dim)
          if config.gelu_activation:
              self.act = gelu
          else:
              self.act = nn.functional.relu
          self.chunk_size_feed_forward = config.chunk_size_feed_forward
          self.seq_len_dim = 1

      def forward(self, input):
          """Apply `ff_chunk`, optionally in chunks along the sequence dimension."""
          return apply_chunking_to_forward(self.ff_chunk, self.chunk_size_feed_forward, self.seq_len_dim, input)

      def ff_chunk(self, input):
          """Run the two-layer MLP with dropout on a single chunk."""
          hidden = self.act(self.lin1(input))
          projected = self.lin2(hidden)
          return nn.functional.dropout(projected, p=self.dropout, training=self.training)
  # Adapted from transformers.models.xlm.modeling_xlm.XLMPredLayer with XLM->Flaubert
  class FlaubertPredLayer(nn.Module):
      """
      Prediction layer (cross_entropy or adaptive_softmax).

      Projects hidden states of size `config.emb_dim` onto the `config.n_words` vocabulary, with a plain
      linear layer when `config.asm is False` and with `nn.AdaptiveLogSoftmaxWithLoss` otherwise.
      """

      def __init__(self, config):
          super().__init__()
          self.asm = config.asm
          self.n_words = config.n_words
          self.pad_index = config.pad_index
          dim = config.emb_dim

          if config.asm is False:
              self.proj = nn.Linear(dim, config.n_words, bias=True)
          else:
              self.proj = nn.AdaptiveLogSoftmaxWithLoss(
                  in_features=dim,
                  n_classes=config.n_words,
                  cutoffs=config.asm_cutoffs,
                  div_value=config.asm_div_value,
                  head_bias=True,  # default is False
              )

      def forward(self, x, y=None):
          """
          Compute the scores and, when targets are given, the loss.

          Args:
              x (`torch.Tensor`): Hidden states whose last dimension is `config.emb_dim`.
              y (`torch.LongTensor`, *optional*): Target word indices.

          Returns:
              `tuple`: `(scores,)` when `y` is None, otherwise `(loss, scores)`. With the linear projection the
              scores are logits and the loss is the mean cross-entropy; with adaptive softmax the scores come
              from `log_prob` and the loss is the one returned by the adaptive-softmax module.
          """
          if self.asm is False:
              scores = self.proj(x)
              if y is None:
                  return (scores,)
              loss = nn.functional.cross_entropy(scores.view(-1, self.n_words), y.view(-1), reduction="mean")
              return (loss, scores)

          scores = self.proj.log_prob(x)
          if y is None:
              return (scores,)
          _, loss = self.proj(x, y)
          return (loss, scores)
  @dataclass
  @auto_docstring(
      custom_intro="""
      Base class for outputs of question answering models using a [`~modeling_utils.FlaubertSQuADHead`].
      """
  )
  # Adapted from transformers.models.xlm.modeling_xlm.XLMSquadHeadOutput with XLM->Flaubert
  class FlaubertSquadHeadOutput(ModelOutput):
      r"""
      loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
          Training loss built from the start-token and end-token classification losses (plus the `is_impossible`
          classification loss when that label is provided).
      start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
          Scores of the `config.start_n_top` best start-token candidates (beam-search).
      start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
          Positions of the `config.start_n_top` best start-token candidates (beam-search).
      end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
          Scores of the `config.start_n_top * config.end_n_top` best end-token candidates
          (beam-search).
      end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
          Positions of the `config.start_n_top * config.end_n_top` best end-token candidates (beam-search).
      cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
          Logits for the `is_impossible` label of the answers.
      """

      # Training output.
      loss: Optional[torch.FloatTensor] = None
      # Inference (beam-search) outputs.
      start_top_log_probs: Optional[torch.FloatTensor] = None
      start_top_index: Optional[torch.LongTensor] = None
      end_top_log_probs: Optional[torch.FloatTensor] = None
      end_top_index: Optional[torch.LongTensor] = None
      cls_logits: Optional[torch.FloatTensor] = None
  # Adapted from transformers.models.xlm.modeling_xlm.XLMPoolerStartLogits with XLM->Flaubert
  class FlaubertPoolerStartLogits(nn.Module):
      """
      Score every position of a sequence as a candidate SQuAD answer start.

      Args:
          config ([`FlaubertConfig`]):
              Model configuration; only `hidden_size` is read.
      """

      def __init__(self, config: FlaubertConfig):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, 1)

      def forward(
          self, hidden_states: torch.FloatTensor, p_mask: Optional[torch.FloatTensor] = None
      ) -> torch.FloatTensor:
          """
          Args:
              hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                  Final hidden states of the model.
              p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                  Mask of invalid positions (query, special tokens such as PAD/SEP/CLS); 1.0 marks a masked token.

          Returns:
              `torch.FloatTensor`: Start logits, one per position.
          """
          logits = self.dense(hidden_states).squeeze(-1)
          if p_mask is None:
              return logits
          # Pick a large negative fill that is still representable when the mask is half precision.
          large_negative = 65500 if p_mask.dtype == torch.float16 else 1e30
          return logits * (1 - p_mask) - large_negative * p_mask
  # Adapted from transformers.models.xlm.modeling_xlm.XLMPoolerEndLogits with XLM->Flaubert
  class FlaubertPoolerEndLogits(nn.Module):
      """
      Score every position as a candidate SQuAD answer end, conditioned on a start token.

      Args:
          config ([`FlaubertConfig`]):
              Model configuration; `hidden_size` and `layer_norm_eps` are read.
      """

      def __init__(self, config: FlaubertConfig):
          super().__init__()
          self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
          self.activation = nn.Tanh()
          self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dense_1 = nn.Linear(config.hidden_size, 1)

      def forward(
          self,
          hidden_states: torch.FloatTensor,
          start_states: Optional[torch.FloatTensor] = None,
          start_positions: Optional[torch.LongTensor] = None,
          p_mask: Optional[torch.FloatTensor] = None,
      ) -> torch.FloatTensor:
          """
          Args:
              hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                  Final hidden states of the model.
              start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                  Hidden states of the start token, broadcast over the sequence.
              start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                  Index of the start token of the labeled span.
              p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
                  Mask of invalid positions (query, special tokens such as PAD/SEP/CLS); 1.0 marks a masked token.

          <Tip>

          At least one of `start_states` or `start_positions` must be given. When both are, `start_positions` wins.

          </Tip>

          Returns:
              `torch.FloatTensor`: End logits, one per position.
          """
          assert start_states is not None or start_positions is not None, (
              "One of start_states, start_positions should be not None"
          )
          if start_positions is not None:
              seq_len, hidden_size = hidden_states.shape[-2:]
              gather_index = start_positions[:, None, None].expand(-1, -1, hidden_size)  # (bsz, 1, hsz)
              start_states = hidden_states.gather(-2, gather_index).expand(-1, seq_len, -1)  # (bsz, slen, hsz)

          pair = torch.cat([hidden_states, start_states], dim=-1)
          hidden = self.LayerNorm(self.activation(self.dense_0(pair)))
          logits = self.dense_1(hidden).squeeze(-1)

          if p_mask is None:
              return logits
          # Pick a large negative fill that is still representable when the mask is half precision.
          large_negative = 65500 if p_mask.dtype == torch.float16 else 1e30
          return logits * (1 - p_mask) - large_negative * p_mask
  # Adapted from transformers.models.xlm.modeling_xlm.XLMPoolerAnswerClass with XLM->Flaubert
  class FlaubertPoolerAnswerClass(nn.Module):
      """
      Predict SQuAD 2.0 answerability from the start-token and CLS-token hidden states.

      Args:
          config ([`FlaubertConfig`]):
              Model configuration; only `hidden_size` is read.
      """

      def __init__(self, config: FlaubertConfig):
          super().__init__()
          self.dense_0 = nn.Linear(config.hidden_size * 2, config.hidden_size)
          self.activation = nn.Tanh()
          self.dense_1 = nn.Linear(config.hidden_size, 1, bias=False)

      def forward(
          self,
          hidden_states: torch.FloatTensor,
          start_states: Optional[torch.FloatTensor] = None,
          start_positions: Optional[torch.LongTensor] = None,
          cls_index: Optional[torch.LongTensor] = None,
      ) -> torch.FloatTensor:
          """
          Args:
              hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
                  Final hidden states of the model.
              start_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`, *optional*):
                  Hidden states of the start token of the labeled span.
              start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                  Index of the start token of the labeled span.
              cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
                  Index of the CLS token per example; the last token is used when omitted.

          <Tip>

          At least one of `start_states` or `start_positions` must be given. When both are, `start_positions` wins.

          </Tip>

          Returns:
              `torch.FloatTensor`: One answerability logit per example.
          """
          # Independent of the end token, so each example yields a single `cls_logits` value.
          hidden_size = hidden_states.shape[-1]
          assert start_states is not None or start_positions is not None, (
              "One of start_states, start_positions should be not None"
          )
          if start_positions is not None:
              start_index = start_positions[:, None, None].expand(-1, -1, hidden_size)  # (bsz, 1, hsz)
              start_states = hidden_states.gather(-2, start_index).squeeze(-2)  # (bsz, hsz)

          if cls_index is None:
              cls_token_state = hidden_states[:, -1, :]  # (bsz, hsz)
          else:
              cls_gather = cls_index[:, None, None].expand(-1, -1, hidden_size)  # (bsz, 1, hsz)
              cls_token_state = hidden_states.gather(-2, cls_gather).squeeze(-2)  # (bsz, hsz)

          features = torch.cat([start_states, cls_token_state], dim=-1)
          return self.dense_1(self.activation(self.dense_0(features))).squeeze(-1)
  # Adapted from transformers.models.xlm.modeling_xlm.XLMSQuADHead with XLM->Flaubert
  class FlaubertSQuADHead(nn.Module):
      r"""
      A SQuAD head inspired by XLNet.

      Args:
          config ([`FlaubertConfig`]):
              Model configuration. This module reads `start_n_top` and `end_n_top`; its pooler sub-modules read
              `hidden_size` and `layer_norm_eps`.
      """

      def __init__(self, config: FlaubertConfig):
          super().__init__()
          self.start_n_top = config.start_n_top
          self.end_n_top = config.end_n_top

          self.start_logits = FlaubertPoolerStartLogits(config)
          self.end_logits = FlaubertPoolerEndLogits(config)
          self.answer_class = FlaubertPoolerAnswerClass(config)

      @auto_docstring
      def forward(
          self,
          hidden_states: torch.FloatTensor,
          start_positions: Optional[torch.LongTensor] = None,
          end_positions: Optional[torch.LongTensor] = None,
          cls_index: Optional[torch.LongTensor] = None,
          is_impossible: Optional[torch.LongTensor] = None,
          p_mask: Optional[torch.FloatTensor] = None,
          return_dict: bool = False,
      ) -> Union[FlaubertSquadHeadOutput, tuple[torch.FloatTensor]]:
          r"""
          hidden_states (`torch.FloatTensor` of shape `(batch_size, seq_len, hidden_size)`):
              Final hidden states of the model on the sequence tokens.
          start_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Positions of the first token for the labeled span.
          end_positions (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Positions of the last token for the labeled span.
          cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Position of the CLS token for each sentence in the batch. If `None`, takes the last token.
          is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Whether the question has a possible answer in the paragraph or not.
          p_mask (`torch.FloatTensor` of shape `(batch_size, seq_len)`, *optional*):
              Mask for tokens at invalid position, such as query and special symbols (PAD, SEP, CLS). 1.0 means token
              should be masked.
          """
          start_logits = self.start_logits(hidden_states, p_mask=p_mask)

          if start_positions is not None and end_positions is not None:
              # If we are on multi-GPU, remove the trailing dimension added by batch splitting.
              # Squeeze out of place so the caller's label tensors are not reshaped as a side effect.
              start_positions, end_positions, cls_index, is_impossible = (
                  t.squeeze(-1) if t is not None and t.dim() > 1 else t
                  for t in (start_positions, end_positions, cls_index, is_impossible)
              )

              # during training, compute the end logits based on the ground truth of the start position
              end_logits = self.end_logits(hidden_states, start_positions=start_positions, p_mask=p_mask)

              loss_fct = CrossEntropyLoss()
              start_loss = loss_fct(start_logits, start_positions)
              end_loss = loss_fct(end_logits, end_positions)
              total_loss = (start_loss + end_loss) / 2

              if cls_index is not None and is_impossible is not None:
                  # Predict answerability from the representation of CLS and START
                  cls_logits = self.answer_class(hidden_states, start_positions=start_positions, cls_index=cls_index)
                  loss_fct_cls = nn.BCEWithLogitsLoss()
                  # BCEWithLogitsLoss needs a floating-point target matching the logits; `is_impossible` is
                  # documented as a LongTensor, so cast it explicitly.
                  cls_loss = loss_fct_cls(cls_logits, is_impossible.to(cls_logits.dtype))

                  # note(zhiliny): by default multiply the loss by 0.5 so that the scale is comparable to start_loss and end_loss
                  total_loss = total_loss + cls_loss * 0.5

              return FlaubertSquadHeadOutput(loss=total_loss) if return_dict else (total_loss,)

          # during inference, compute the end logits based on beam search
          _, slen, hsz = hidden_states.size()
          # NOTE: despite the `log_probs` names (kept from XLNet), these are softmax probabilities.
          start_log_probs = nn.functional.softmax(start_logits, dim=-1)  # shape (bsz, slen)

          start_top_log_probs, start_top_index = torch.topk(
              start_log_probs, self.start_n_top, dim=-1
          )  # shape (bsz, start_n_top)
          start_top_index_exp = start_top_index.unsqueeze(-1).expand(-1, -1, hsz)  # shape (bsz, start_n_top, hsz)
          start_states = torch.gather(hidden_states, -2, start_top_index_exp)  # shape (bsz, start_n_top, hsz)
          start_states = start_states.unsqueeze(1).expand(-1, slen, -1, -1)  # shape (bsz, slen, start_n_top, hsz)

          hidden_states_expanded = hidden_states.unsqueeze(2).expand_as(
              start_states
          )  # shape (bsz, slen, start_n_top, hsz)
          p_mask = p_mask.unsqueeze(-1) if p_mask is not None else None
          end_logits = self.end_logits(hidden_states_expanded, start_states=start_states, p_mask=p_mask)
          end_log_probs = nn.functional.softmax(end_logits, dim=1)  # shape (bsz, slen, start_n_top)

          end_top_log_probs, end_top_index = torch.topk(
              end_log_probs, self.end_n_top, dim=1
          )  # shape (bsz, end_n_top, start_n_top)
          end_top_log_probs = end_top_log_probs.view(-1, self.start_n_top * self.end_n_top)
          end_top_index = end_top_index.view(-1, self.start_n_top * self.end_n_top)

          # Probability-weighted average of the hidden states as the start representation for answerability.
          start_states = torch.einsum("blh,bl->bh", hidden_states, start_log_probs)
          cls_logits = self.answer_class(hidden_states, start_states=start_states, cls_index=cls_index)

          if not return_dict:
              return (start_top_log_probs, start_top_index, end_top_log_probs, end_top_index, cls_logits)
          return FlaubertSquadHeadOutput(
              start_top_log_probs=start_top_log_probs,
              start_top_index=start_top_index,
              end_top_log_probs=end_top_log_probs,
              end_top_index=end_top_index,
              cls_logits=cls_logits,
          )
  485. # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->Flaubert
  486. class FlaubertSequenceSummary(nn.Module):
  487. r"""
  488. Compute a single vector summary of a sequence hidden states.
  489. Args:
  490. config ([`FlaubertConfig`]):
  491. The config used by the model. Relevant arguments in the config class of the model are (refer to the actual
  492. config class of your model for the default values it uses):
  493. - **summary_type** (`str`) -- The method to use to make this summary. Accepted values are:
  494. - `"last"` -- Take the last token hidden state (like XLNet)
  495. - `"first"` -- Take the first token hidden state (like Bert)
  496. - `"mean"` -- Take the mean of all tokens hidden states
  497. - `"cls_index"` -- Supply a Tensor of classification token position (GPT/GPT-2)
  498. - `"attn"` -- Not implemented now, use multi-head attention
  499. - **summary_use_proj** (`bool`) -- Add a projection after the vector extraction.
  500. - **summary_proj_to_labels** (`bool`) -- If `True`, the projection outputs to `config.num_labels` classes
  501. (otherwise to `config.hidden_size`).
  502. - **summary_activation** (`Optional[str]`) -- Set to `"tanh"` to add a tanh activation to the output,
  503. another string or `None` will add no activation.
  504. - **summary_first_dropout** (`float`) -- Optional dropout probability before the projection and activation.
  505. - **summary_last_dropout** (`float`)-- Optional dropout probability after the projection and activation.
  506. """
  507. def __init__(self, config: FlaubertConfig):
  508. super().__init__()
  509. self.summary_type = getattr(config, "summary_type", "last")
  510. if self.summary_type == "attn":
  511. # We should use a standard multi-head attention module with absolute positional embedding for that.
  512. # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
  513. # We can probably just use the multi-head attention module of PyTorch >=1.1.0
  514. raise NotImplementedError
  515. self.summary = nn.Identity()
  516. if hasattr(config, "summary_use_proj") and config.summary_use_proj:
  517. if hasattr(config, "summary_proj_to_labels") and config.summary_proj_to_labels and config.num_labels > 0:
  518. num_classes = config.num_labels
  519. else:
  520. num_classes = config.hidden_size
  521. self.summary = nn.Linear(config.hidden_size, num_classes)
  522. activation_string = getattr(config, "summary_activation", None)
  523. self.activation: Callable = get_activation(activation_string) if activation_string else nn.Identity()
  524. self.first_dropout = nn.Identity()
  525. if hasattr(config, "summary_first_dropout") and config.summary_first_dropout > 0:
  526. self.first_dropout = nn.Dropout(config.summary_first_dropout)
  527. self.last_dropout = nn.Identity()
  528. if hasattr(config, "summary_last_dropout") and config.summary_last_dropout > 0:
  529. self.last_dropout = nn.Dropout(config.summary_last_dropout)
  530. def forward(
  531. self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
  532. ) -> torch.FloatTensor:
  533. """
  534. Compute a single vector summary of a sequence hidden states.
  535. Args:
  536. hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
  537. The hidden states of the last layer.
  538. cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
  539. Used if `summary_type == "cls_index"` and takes the last token of the sequence as classification token.
  540. Returns:
  541. `torch.FloatTensor`: The summary of the sequence hidden states.
  542. """
  543. if self.summary_type == "last":
  544. output = hidden_states[:, -1]
  545. elif self.summary_type == "first":
  546. output = hidden_states[:, 0]
  547. elif self.summary_type == "mean":
  548. output = hidden_states.mean(dim=1)
  549. elif self.summary_type == "cls_index":
  550. if cls_index is None:
  551. cls_index = torch.full_like(
  552. hidden_states[..., :1, :],
  553. hidden_states.shape[-2] - 1,
  554. dtype=torch.long,
  555. )
  556. else:
  557. cls_index = cls_index.unsqueeze(-1).unsqueeze(-1)
  558. cls_index = cls_index.expand((-1,) * (cls_index.dim() - 1) + (hidden_states.size(-1),))
  559. # shape of cls_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
  560. output = hidden_states.gather(-2, cls_index).squeeze(-2) # shape (bsz, XX, hidden_size)
  561. elif self.summary_type == "attn":
  562. raise NotImplementedError
  563. output = self.first_dropout(output)
  564. output = self.summary(output)
  565. output = self.activation(output)
  566. output = self.last_dropout(output)
  567. return output
  568. @auto_docstring
  569. # Copied from transformers.models.xlm.modeling_xlm.XLMPreTrainedModel with XLM->Flaubert
  570. class FlaubertPreTrainedModel(PreTrainedModel):
  571. config: FlaubertConfig
  572. load_tf_weights = None
  573. base_model_prefix = "transformer"
  574. def __init__(self, *inputs, **kwargs):
  575. super().__init__(*inputs, **kwargs)
  576. @property
  577. def dummy_inputs(self):
  578. inputs_list = torch.tensor([[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]])
  579. attns_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
  580. if self.config.use_lang_emb and self.config.n_langs > 1:
  581. langs_list = torch.tensor([[1, 1, 0, 0, 1], [1, 1, 1, 0, 0], [1, 0, 0, 1, 1]])
  582. else:
  583. langs_list = None
  584. return {"input_ids": inputs_list, "attention_mask": attns_list, "langs": langs_list}
  585. def _init_weights(self, module):
  586. """Initialize the weights."""
  587. if isinstance(module, nn.Embedding):
  588. if self.config is not None and self.config.embed_init_std is not None:
  589. nn.init.normal_(module.weight, mean=0, std=self.config.embed_init_std)
  590. if module.padding_idx is not None:
  591. module.weight.data[module.padding_idx].zero_()
  592. if isinstance(module, nn.Linear):
  593. if self.config is not None and self.config.init_std is not None:
  594. nn.init.normal_(module.weight, mean=0, std=self.config.init_std)
  595. if module.bias is not None:
  596. nn.init.constant_(module.bias, 0.0)
  597. if isinstance(module, nn.LayerNorm):
  598. module.bias.data.zero_()
  599. module.weight.data.fill_(1.0)
  600. if isinstance(module, FlaubertModel) and self.config.sinusoidal_embeddings:
  601. create_sinusoidal_embeddings(
  602. self.config.max_position_embeddings, self.config.emb_dim, out=module.position_embeddings.weight
  603. )
  604. @auto_docstring
  605. class FlaubertModel(FlaubertPreTrainedModel):
  606. def __init__(self, config): # , dico, is_encoder, with_output):
  607. super().__init__(config)
  608. # encoder / decoder, output layer
  609. self.is_encoder = config.is_encoder
  610. self.is_decoder = not config.is_encoder
  611. if self.is_decoder:
  612. raise NotImplementedError("Currently Flaubert can only be used as an encoder")
  613. # self.with_output = with_output
  614. self.causal = config.causal
  615. # dictionary / languages
  616. self.n_langs = config.n_langs
  617. self.use_lang_emb = config.use_lang_emb
  618. self.n_words = config.n_words
  619. self.eos_index = config.eos_index
  620. self.pad_index = config.pad_index
  621. # self.dico = dico
  622. # self.id2lang = config.id2lang
  623. # self.lang2id = config.lang2id
  624. # assert len(self.dico) == self.n_words
  625. # assert len(self.id2lang) == len(self.lang2id) == self.n_langs
  626. # model parameters
  627. self.dim = config.emb_dim # 512 by default
  628. self.hidden_dim = self.dim * 4 # 2048 by default
  629. self.n_heads = config.n_heads # 8 by default
  630. self.n_layers = config.n_layers
  631. self.dropout = config.dropout
  632. self.attention_dropout = config.attention_dropout
  633. assert self.dim % self.n_heads == 0, "transformer dim must be a multiple of n_heads"
  634. # embeddings
  635. self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.dim)
  636. if config.n_langs > 1 and config.use_lang_emb:
  637. self.lang_embeddings = nn.Embedding(self.n_langs, self.dim)
  638. self.embeddings = nn.Embedding(self.n_words, self.dim, padding_idx=self.pad_index)
  639. self.layer_norm_emb = nn.LayerNorm(self.dim, eps=config.layer_norm_eps)
  640. # transformer layers
  641. self.attentions = nn.ModuleList()
  642. self.layer_norm1 = nn.ModuleList()
  643. self.ffns = nn.ModuleList()
  644. self.layer_norm2 = nn.ModuleList()
  645. # if self.is_decoder:
  646. # self.layer_norm15 = nn.ModuleList()
  647. # self.encoder_attn = nn.ModuleList()
  648. for i in range(self.n_layers):
  649. self.attentions.append(MultiHeadAttention(self.n_heads, self.dim, config=config, layer_idx=i))
  650. self.layer_norm1.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
  651. # if self.is_decoder:
  652. # self.layer_norm15.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
  653. # self.encoder_attn.append(MultiHeadAttention(self.n_heads, self.dim, dropout=self.attention_dropout))
  654. self.ffns.append(TransformerFFN(self.dim, self.hidden_dim, self.dim, config=config))
  655. self.layer_norm2.append(nn.LayerNorm(self.dim, eps=config.layer_norm_eps))
  656. if hasattr(config, "pruned_heads"):
  657. pruned_heads = config.pruned_heads.copy().items()
  658. config.pruned_heads = {}
  659. for layer, heads in pruned_heads:
  660. if self.attentions[int(layer)].n_heads == config.n_heads:
  661. self.prune_heads({int(layer): list(map(int, heads))})
  662. # Initialize weights and apply final processing
  663. self.post_init()
  664. self.layerdrop = getattr(config, "layerdrop", 0.0)
  665. self.pre_norm = getattr(config, "pre_norm", False)
  666. self.register_buffer(
  667. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  668. )
  669. # Copied from transformers.models.xlm.modeling_xlm.XLMModel.get_input_embeddings
  670. def get_input_embeddings(self):
  671. return self.embeddings
  672. # Copied from transformers.models.xlm.modeling_xlm.XLMModel.set_input_embeddings
  673. def set_input_embeddings(self, new_embeddings):
  674. self.embeddings = new_embeddings
  675. # Copied from transformers.models.xlm.modeling_xlm.XLMModel._prune_heads
  676. def _prune_heads(self, heads_to_prune):
  677. """
  678. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  679. class PreTrainedModel
  680. """
  681. for layer, heads in heads_to_prune.items():
  682. self.attentions[layer].prune_heads(heads)
  683. @auto_docstring
  684. def forward(
  685. self,
  686. input_ids: Optional[torch.LongTensor] = None,
  687. attention_mask: Optional[torch.FloatTensor] = None,
  688. langs: Optional[torch.Tensor] = None,
  689. token_type_ids: Optional[torch.LongTensor] = None,
  690. position_ids: Optional[torch.LongTensor] = None,
  691. lengths: Optional[torch.LongTensor] = None,
  692. cache: Optional[dict[str, torch.FloatTensor]] = None,
  693. head_mask: Optional[torch.FloatTensor] = None,
  694. inputs_embeds: Optional[torch.FloatTensor] = None,
  695. output_attentions: Optional[bool] = None,
  696. output_hidden_states: Optional[bool] = None,
  697. return_dict: Optional[bool] = None,
  698. cache_position: Optional[torch.Tensor] = None,
  699. ) -> Union[tuple, BaseModelOutput]:
  700. r"""
  701. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  702. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  703. languages ids which can be obtained from the language names by using two conversion mappings provided in
  704. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  705. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  706. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  707. See usage examples detailed in the [multilingual documentation](../multilingual).
  708. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  709. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  710. also use `attention_mask` for the same result (see above), kept here for compatibility. Indices selected in
  711. `[0, ..., input_ids.size(-1)]`:
  712. cache (`dict[str, torch.FloatTensor]`, *optional*):
  713. Dictionary strings to `torch.FloatTensor` that contains precomputed hidden-states (key and values in the
  714. attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential
  715. decoding. The dictionary object will be modified in-place during the forward pass to add newly computed
  716. hidden-states.
  717. """
  718. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  719. output_hidden_states = (
  720. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  721. )
  722. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  723. # removed: src_enc=None, src_len=None
  724. if input_ids is not None:
  725. bs, slen = input_ids.size()
  726. else:
  727. bs, slen = inputs_embeds.size()[:-1]
  728. device = input_ids.device if input_ids is not None else inputs_embeds.device
  729. if cache is None:
  730. cache = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  731. if isinstance(cache, tuple):
  732. cache = EncoderDecoderCache.from_legacy_cache(cache)
  733. if lengths is None:
  734. if input_ids is not None:
  735. lengths = (input_ids != self.pad_index).sum(dim=1).long()
  736. else:
  737. lengths = torch.tensor([slen] * bs, device=device)
  738. # mask = input_ids != self.pad_index
  739. # check inputs
  740. assert lengths.size(0) == bs
  741. assert lengths.max().item() <= slen
  742. # input_ids = input_ids.transpose(0, 1) # batch size as dimension 0
  743. # assert (src_enc is None) == (src_len is None)
  744. # if src_enc is not None:
  745. # assert self.is_decoder
  746. # assert src_enc.size(0) == bs
  747. # generate masks
  748. mask, attn_mask = get_masks(slen, lengths, self.causal, padding_mask=attention_mask)
  749. # if self.is_decoder and src_enc is not None:
  750. # src_mask = torch.arange(src_len.max(), dtype=torch.long, device=lengths.device) < src_len[:, None]
  751. # Setting the position-ids to the registered buffer in constructor, it helps
  752. # when tracing the model without passing position-ids, solves
  753. # issues similar to issue #5664
  754. if position_ids is None:
  755. if hasattr(self, "position_ids"):
  756. position_ids = self.position_ids[:, :slen]
  757. position_ids = position_ids.expand((bs, slen))
  758. else:
  759. position_ids = torch.arange(slen, dtype=torch.long, device=device)
  760. position_ids = position_ids.unsqueeze(0).expand((bs, slen))
  761. else:
  762. assert position_ids.size() == (bs, slen) # (slen, bs)
  763. # position_ids = position_ids.transpose(0, 1)
  764. # langs
  765. if langs is not None:
  766. assert langs.size() == (bs, slen) # (slen, bs)
  767. # langs = langs.transpose(0, 1)
  768. # Prepare head mask if needed
  769. head_mask = self.get_head_mask(head_mask, self.config.n_layers)
  770. # do not recompute cached elements
  771. if cache is not None and input_ids is not None:
  772. _slen = slen - cache.get_seq_length()
  773. input_ids = input_ids[:, -_slen:]
  774. position_ids = position_ids[:, -_slen:]
  775. if langs is not None:
  776. langs = langs[:, -_slen:]
  777. mask = mask[:, -_slen:]
  778. attn_mask = attn_mask[:, -_slen:]
  779. # embeddings
  780. if inputs_embeds is None:
  781. inputs_embeds = self.embeddings(input_ids)
  782. tensor = inputs_embeds + self.position_embeddings(position_ids).expand_as(inputs_embeds)
  783. if langs is not None and self.use_lang_emb and self.config.n_langs > 1:
  784. tensor = tensor + self.lang_embeddings(langs)
  785. if token_type_ids is not None:
  786. tensor = tensor + self.embeddings(token_type_ids)
  787. tensor = self.layer_norm_emb(tensor)
  788. tensor = nn.functional.dropout(tensor, p=self.dropout, training=self.training)
  789. tensor *= mask.unsqueeze(-1).to(tensor.dtype)
  790. # transformer layers
  791. hidden_states = () if output_hidden_states else None
  792. attentions = () if output_attentions else None
  793. for i in range(self.n_layers):
  794. # LayerDrop
  795. if self.training:
  796. dropout_probability = torch.rand([])
  797. if dropout_probability < self.layerdrop:
  798. continue
  799. if output_hidden_states:
  800. hidden_states = hidden_states + (tensor,)
  801. # self attention
  802. if not self.pre_norm:
  803. attn_outputs = self.attentions[i](
  804. tensor,
  805. attn_mask,
  806. cache=cache,
  807. head_mask=head_mask[i],
  808. output_attentions=output_attentions,
  809. cache_position=cache_position,
  810. )
  811. attn = attn_outputs[0]
  812. if output_attentions:
  813. attentions = attentions + (attn_outputs[1],)
  814. attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
  815. tensor = tensor + attn
  816. tensor = self.layer_norm1[i](tensor)
  817. else:
  818. tensor_normalized = self.layer_norm1[i](tensor)
  819. attn_outputs = self.attentions[i](tensor_normalized, attn_mask, cache=cache, head_mask=head_mask[i])
  820. attn = attn_outputs[0]
  821. if output_attentions:
  822. attentions = attentions + (attn_outputs[1],)
  823. attn = nn.functional.dropout(attn, p=self.dropout, training=self.training)
  824. tensor = tensor + attn
  825. # FFN
  826. if not self.pre_norm:
  827. tensor = tensor + self.ffns[i](tensor)
  828. tensor = self.layer_norm2[i](tensor)
  829. else:
  830. tensor_normalized = self.layer_norm2[i](tensor)
  831. tensor = tensor + self.ffns[i](tensor_normalized)
  832. tensor *= mask.unsqueeze(-1).to(tensor.dtype)
  833. # Add last hidden state
  834. if output_hidden_states:
  835. hidden_states = hidden_states + (tensor,)
  836. if not return_dict:
  837. return tuple(v for v in [tensor, hidden_states, attentions] if v is not None)
  838. return BaseModelOutput(last_hidden_state=tensor, hidden_states=hidden_states, attentions=attentions)
  839. @auto_docstring(
  840. custom_intro="""
  841. The Flaubert Model transformer with a language modeling head on top (linear layer with weights tied to the input
  842. embeddings).
  843. """
  844. )
  845. class FlaubertWithLMHeadModel(FlaubertPreTrainedModel, GenerationMixin):
  846. _tied_weights_keys = ["pred_layer.proj.weight"]
  847. def __init__(self, config):
  848. super().__init__(config)
  849. self.transformer = FlaubertModel(config)
  850. self.pred_layer = FlaubertPredLayer(config)
  851. # Initialize weights and apply final processing
  852. self.post_init()
  853. def get_output_embeddings(self):
  854. return self.pred_layer.proj
  855. def set_output_embeddings(self, new_embeddings):
  856. self.pred_layer.proj = new_embeddings
  857. def prepare_inputs_for_generation(self, input_ids, **kwargs):
  858. # Overwritten -- uses a language id
  859. mask_token_id = self.config.mask_token_id
  860. lang_id = self.config.lang_id
  861. effective_batch_size = input_ids.shape[0]
  862. mask_token = torch.full((effective_batch_size, 1), mask_token_id, dtype=torch.long, device=input_ids.device)
  863. input_ids = torch.cat([input_ids, mask_token], dim=1)
  864. if lang_id is not None:
  865. langs = torch.full_like(input_ids, lang_id)
  866. else:
  867. langs = None
  868. return {"input_ids": input_ids, "langs": langs}
  869. @auto_docstring
  870. def forward(
  871. self,
  872. input_ids: Optional[torch.Tensor] = None,
  873. attention_mask: Optional[torch.Tensor] = None,
  874. langs: Optional[torch.Tensor] = None,
  875. token_type_ids: Optional[torch.Tensor] = None,
  876. position_ids: Optional[torch.Tensor] = None,
  877. lengths: Optional[torch.Tensor] = None,
  878. cache: Optional[dict[str, torch.Tensor]] = None,
  879. head_mask: Optional[torch.Tensor] = None,
  880. inputs_embeds: Optional[torch.Tensor] = None,
  881. labels: Optional[torch.Tensor] = None,
  882. output_attentions: Optional[bool] = None,
  883. output_hidden_states: Optional[bool] = None,
  884. return_dict: Optional[bool] = None,
  885. ) -> Union[tuple, MaskedLMOutput]:
  886. r"""
  887. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  888. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  889. languages ids which can be obtained from the language names by using two conversion mappings provided in
  890. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  891. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  892. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  893. See usage examples detailed in the [multilingual documentation](../multilingual).
  894. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  895. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  896. also use `attention_mask` for the same result (see above), kept here for compatibility. Indices selected in
  897. `[0, ..., input_ids.size(-1)]`:
  898. cache (`dict[str, torch.FloatTensor]`, *optional*):
  899. Dictionary strings to `torch.FloatTensor` that contains precomputed hidden-states (key and values in the
  900. attention blocks) as computed by the model (see `cache` output below). Can be used to speed up sequential
  901. decoding. The dictionary object will be modified in-place during the forward pass to add newly computed
  902. hidden-states.
  903. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  904. Labels for language modeling. Note that the labels **are shifted** inside the model, i.e. you can set
  905. `labels = input_ids` Indices are selected in `[-100, 0, ..., config.vocab_size]` All labels set to `-100`
  906. are ignored (masked), the loss is only computed for labels in `[0, ..., config.vocab_size]`
  907. """
  908. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  909. transformer_outputs = self.transformer(
  910. input_ids,
  911. attention_mask=attention_mask,
  912. langs=langs,
  913. token_type_ids=token_type_ids,
  914. position_ids=position_ids,
  915. lengths=lengths,
  916. cache=cache,
  917. head_mask=head_mask,
  918. inputs_embeds=inputs_embeds,
  919. output_attentions=output_attentions,
  920. output_hidden_states=output_hidden_states,
  921. return_dict=return_dict,
  922. )
  923. output = transformer_outputs[0]
  924. outputs = self.pred_layer(output, labels) # (loss, logits) or (logits,) depending on if labels are provided.
  925. if not return_dict:
  926. return outputs + transformer_outputs[1:]
  927. return MaskedLMOutput(
  928. loss=outputs[0] if labels is not None else None,
  929. logits=outputs[0] if labels is None else outputs[1],
  930. hidden_states=transformer_outputs.hidden_states,
  931. attentions=transformer_outputs.attentions,
  932. )
  933. @auto_docstring(
  934. custom_intro="""
  935. Flaubert Model with a sequence classification/regression head on top (a linear layer on top of the pooled output)
  936. e.g. for GLUE tasks.
  937. """
  938. )
  939. # Copied from transformers.models.xlm.modeling_xlm.XLMForSequenceClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
  940. class FlaubertForSequenceClassification(FlaubertPreTrainedModel):
  941. def __init__(self, config):
  942. super().__init__(config)
  943. self.num_labels = config.num_labels
  944. self.config = config
  945. self.transformer = FlaubertModel(config)
  946. self.sequence_summary = FlaubertSequenceSummary(config)
  947. # Initialize weights and apply final processing
  948. self.post_init()
  949. @auto_docstring
  950. def forward(
  951. self,
  952. input_ids: Optional[torch.Tensor] = None,
  953. attention_mask: Optional[torch.Tensor] = None,
  954. langs: Optional[torch.Tensor] = None,
  955. token_type_ids: Optional[torch.Tensor] = None,
  956. position_ids: Optional[torch.Tensor] = None,
  957. lengths: Optional[torch.Tensor] = None,
  958. cache: Optional[dict[str, torch.Tensor]] = None,
  959. head_mask: Optional[torch.Tensor] = None,
  960. inputs_embeds: Optional[torch.Tensor] = None,
  961. labels: Optional[torch.Tensor] = None,
  962. output_attentions: Optional[bool] = None,
  963. output_hidden_states: Optional[bool] = None,
  964. return_dict: Optional[bool] = None,
  965. ) -> Union[tuple, SequenceClassifierOutput]:
  966. r"""
  967. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  968. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  969. languages ids which can be obtained from the language names by using two conversion mappings provided in
  970. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  971. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  972. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  973. See usage examples detailed in the [multilingual documentation](../multilingual).
  974. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  975. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  976. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  977. `[0, ..., input_ids.size(-1)]`.
  978. cache (`dict[str, torch.FloatTensor]`, *optional*):
  979. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  980. decoding.
  981. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  982. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  983. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  984. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  985. """
  986. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  987. transformer_outputs = self.transformer(
  988. input_ids,
  989. attention_mask=attention_mask,
  990. langs=langs,
  991. token_type_ids=token_type_ids,
  992. position_ids=position_ids,
  993. lengths=lengths,
  994. cache=cache,
  995. head_mask=head_mask,
  996. inputs_embeds=inputs_embeds,
  997. output_attentions=output_attentions,
  998. output_hidden_states=output_hidden_states,
  999. return_dict=return_dict,
  1000. )
  1001. output = transformer_outputs[0]
  1002. logits = self.sequence_summary(output)
  1003. loss = None
  1004. if labels is not None:
  1005. if self.config.problem_type is None:
  1006. if self.num_labels == 1:
  1007. self.config.problem_type = "regression"
  1008. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1009. self.config.problem_type = "single_label_classification"
  1010. else:
  1011. self.config.problem_type = "multi_label_classification"
  1012. if self.config.problem_type == "regression":
  1013. loss_fct = MSELoss()
  1014. if self.num_labels == 1:
  1015. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1016. else:
  1017. loss = loss_fct(logits, labels)
  1018. elif self.config.problem_type == "single_label_classification":
  1019. loss_fct = CrossEntropyLoss()
  1020. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1021. elif self.config.problem_type == "multi_label_classification":
  1022. loss_fct = BCEWithLogitsLoss()
  1023. loss = loss_fct(logits, labels)
  1024. if not return_dict:
  1025. output = (logits,) + transformer_outputs[1:]
  1026. return ((loss,) + output) if loss is not None else output
  1027. return SequenceClassifierOutput(
  1028. loss=loss,
  1029. logits=logits,
  1030. hidden_states=transformer_outputs.hidden_states,
  1031. attentions=transformer_outputs.attentions,
  1032. )
  1033. @auto_docstring
  1034. # Copied from transformers.models.xlm.modeling_xlm.XLMForTokenClassification with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
  1035. class FlaubertForTokenClassification(FlaubertPreTrainedModel):
  1036. def __init__(self, config):
  1037. super().__init__(config)
  1038. self.num_labels = config.num_labels
  1039. self.transformer = FlaubertModel(config)
  1040. self.dropout = nn.Dropout(config.dropout)
  1041. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1042. # Initialize weights and apply final processing
  1043. self.post_init()
  1044. @auto_docstring
  1045. def forward(
  1046. self,
  1047. input_ids: Optional[torch.Tensor] = None,
  1048. attention_mask: Optional[torch.Tensor] = None,
  1049. langs: Optional[torch.Tensor] = None,
  1050. token_type_ids: Optional[torch.Tensor] = None,
  1051. position_ids: Optional[torch.Tensor] = None,
  1052. lengths: Optional[torch.Tensor] = None,
  1053. cache: Optional[dict[str, torch.Tensor]] = None,
  1054. head_mask: Optional[torch.Tensor] = None,
  1055. inputs_embeds: Optional[torch.Tensor] = None,
  1056. labels: Optional[torch.Tensor] = None,
  1057. output_attentions: Optional[bool] = None,
  1058. output_hidden_states: Optional[bool] = None,
  1059. return_dict: Optional[bool] = None,
  1060. ) -> Union[tuple, TokenClassifierOutput]:
  1061. r"""
  1062. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1063. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  1064. languages ids which can be obtained from the language names by using two conversion mappings provided in
  1065. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  1066. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  1067. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  1068. See usage examples detailed in the [multilingual documentation](../multilingual).
  1069. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1070. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  1071. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  1072. `[0, ..., input_ids.size(-1)]`.
  1073. cache (`dict[str, torch.FloatTensor]`, *optional*):
  1074. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  1075. decoding.
  1076. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1077. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1078. """
  1079. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1080. outputs = self.transformer(
  1081. input_ids,
  1082. attention_mask=attention_mask,
  1083. langs=langs,
  1084. token_type_ids=token_type_ids,
  1085. position_ids=position_ids,
  1086. lengths=lengths,
  1087. cache=cache,
  1088. head_mask=head_mask,
  1089. inputs_embeds=inputs_embeds,
  1090. output_attentions=output_attentions,
  1091. output_hidden_states=output_hidden_states,
  1092. return_dict=return_dict,
  1093. )
  1094. sequence_output = outputs[0]
  1095. sequence_output = self.dropout(sequence_output)
  1096. logits = self.classifier(sequence_output)
  1097. loss = None
  1098. if labels is not None:
  1099. loss_fct = CrossEntropyLoss()
  1100. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1101. if not return_dict:
  1102. output = (logits,) + outputs[1:]
  1103. return ((loss,) + output) if loss is not None else output
  1104. return TokenClassifierOutput(
  1105. loss=loss,
  1106. logits=logits,
  1107. hidden_states=outputs.hidden_states,
  1108. attentions=outputs.attentions,
  1109. )
  1110. @auto_docstring(
  1111. custom_intro="""
  1112. Flaubert Model with a span classification head on top for extractive question-answering tasks like SQuAD (a linear
  1113. layers on top of the hidden-states output to compute `span start logits` and `span end logits`).
  1114. """
  1115. )
  1116. # Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnsweringSimple with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
  1117. class FlaubertForQuestionAnsweringSimple(FlaubertPreTrainedModel):
  1118. def __init__(self, config):
  1119. super().__init__(config)
  1120. self.transformer = FlaubertModel(config)
  1121. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1122. # Initialize weights and apply final processing
  1123. self.post_init()
  1124. @auto_docstring
  1125. def forward(
  1126. self,
  1127. input_ids: Optional[torch.Tensor] = None,
  1128. attention_mask: Optional[torch.Tensor] = None,
  1129. langs: Optional[torch.Tensor] = None,
  1130. token_type_ids: Optional[torch.Tensor] = None,
  1131. position_ids: Optional[torch.Tensor] = None,
  1132. lengths: Optional[torch.Tensor] = None,
  1133. cache: Optional[dict[str, torch.Tensor]] = None,
  1134. head_mask: Optional[torch.Tensor] = None,
  1135. inputs_embeds: Optional[torch.Tensor] = None,
  1136. start_positions: Optional[torch.Tensor] = None,
  1137. end_positions: Optional[torch.Tensor] = None,
  1138. output_attentions: Optional[bool] = None,
  1139. output_hidden_states: Optional[bool] = None,
  1140. return_dict: Optional[bool] = None,
  1141. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  1142. r"""
  1143. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1144. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  1145. languages ids which can be obtained from the language names by using two conversion mappings provided in
  1146. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  1147. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  1148. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  1149. See usage examples detailed in the [multilingual documentation](../multilingual).
  1150. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1151. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  1152. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  1153. `[0, ..., input_ids.size(-1)]`.
  1154. cache (`dict[str, torch.FloatTensor]`, *optional*):
  1155. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  1156. decoding.
  1157. """
  1158. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1159. transformer_outputs = self.transformer(
  1160. input_ids,
  1161. attention_mask=attention_mask,
  1162. langs=langs,
  1163. token_type_ids=token_type_ids,
  1164. position_ids=position_ids,
  1165. lengths=lengths,
  1166. cache=cache,
  1167. head_mask=head_mask,
  1168. inputs_embeds=inputs_embeds,
  1169. output_attentions=output_attentions,
  1170. output_hidden_states=output_hidden_states,
  1171. return_dict=return_dict,
  1172. )
  1173. sequence_output = transformer_outputs[0]
  1174. logits = self.qa_outputs(sequence_output)
  1175. start_logits, end_logits = logits.split(1, dim=-1)
  1176. start_logits = start_logits.squeeze(-1).contiguous()
  1177. end_logits = end_logits.squeeze(-1).contiguous()
  1178. total_loss = None
  1179. if start_positions is not None and end_positions is not None:
  1180. # If we are on multi-GPU, split add a dimension
  1181. if len(start_positions.size()) > 1:
  1182. start_positions = start_positions.squeeze(-1)
  1183. if len(end_positions.size()) > 1:
  1184. end_positions = end_positions.squeeze(-1)
  1185. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1186. ignored_index = start_logits.size(1)
  1187. start_positions = start_positions.clamp(0, ignored_index)
  1188. end_positions = end_positions.clamp(0, ignored_index)
  1189. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1190. start_loss = loss_fct(start_logits, start_positions)
  1191. end_loss = loss_fct(end_logits, end_positions)
  1192. total_loss = (start_loss + end_loss) / 2
  1193. if not return_dict:
  1194. output = (start_logits, end_logits) + transformer_outputs[1:]
  1195. return ((total_loss,) + output) if total_loss is not None else output
  1196. return QuestionAnsweringModelOutput(
  1197. loss=total_loss,
  1198. start_logits=start_logits,
  1199. end_logits=end_logits,
  1200. hidden_states=transformer_outputs.hidden_states,
  1201. attentions=transformer_outputs.attentions,
  1202. )
  1203. @dataclass
  1204. @auto_docstring(
  1205. custom_intro="""
  1206. Base class for outputs of question answering models using a `SquadHead`.
  1207. """
  1208. )
  1209. # Copied from transformer.models.xlm.modeling_xlm.XLMForQuestionAnsweringOutput with XLM->Flaubert
  1210. class FlaubertForQuestionAnsweringOutput(ModelOutput):
  1211. r"""
  1212. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned if both `start_positions` and `end_positions` are provided):
  1213. Classification loss as the sum of start token, end token (and is_impossible if provided) classification
  1214. losses.
  1215. start_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  1216. Log probabilities for the top config.start_n_top start token possibilities (beam-search).
  1217. start_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  1218. Indices for the top config.start_n_top start token possibilities (beam-search).
  1219. end_top_log_probs (`torch.FloatTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  1220. Log probabilities for the top `config.start_n_top * config.end_n_top` end token possibilities
  1221. (beam-search).
  1222. end_top_index (`torch.LongTensor` of shape `(batch_size, config.start_n_top * config.end_n_top)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  1223. Indices for the top `config.start_n_top * config.end_n_top` end token possibilities (beam-search).
  1224. cls_logits (`torch.FloatTensor` of shape `(batch_size,)`, *optional*, returned if `start_positions` or `end_positions` is not provided):
  1225. Log probabilities for the `is_impossible` label of the answers.
  1226. """
  1227. loss: Optional[torch.FloatTensor] = None
  1228. start_top_log_probs: Optional[torch.FloatTensor] = None
  1229. start_top_index: Optional[torch.LongTensor] = None
  1230. end_top_log_probs: Optional[torch.FloatTensor] = None
  1231. end_top_index: Optional[torch.LongTensor] = None
  1232. cls_logits: Optional[torch.FloatTensor] = None
  1233. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  1234. attentions: Optional[tuple[torch.FloatTensor]] = None
  1235. @auto_docstring
  1236. # Copied from transformers.models.xlm.modeling_xlm.XLMForQuestionAnswering with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
  1237. class FlaubertForQuestionAnswering(FlaubertPreTrainedModel):
  1238. def __init__(self, config):
  1239. super().__init__(config)
  1240. self.transformer = FlaubertModel(config)
  1241. self.qa_outputs = FlaubertSQuADHead(config)
  1242. # Initialize weights and apply final processing
  1243. self.post_init()
  1244. @auto_docstring
  1245. def forward(
  1246. self,
  1247. input_ids: Optional[torch.Tensor] = None,
  1248. attention_mask: Optional[torch.Tensor] = None,
  1249. langs: Optional[torch.Tensor] = None,
  1250. token_type_ids: Optional[torch.Tensor] = None,
  1251. position_ids: Optional[torch.Tensor] = None,
  1252. lengths: Optional[torch.Tensor] = None,
  1253. cache: Optional[dict[str, torch.Tensor]] = None,
  1254. head_mask: Optional[torch.Tensor] = None,
  1255. inputs_embeds: Optional[torch.Tensor] = None,
  1256. start_positions: Optional[torch.Tensor] = None,
  1257. end_positions: Optional[torch.Tensor] = None,
  1258. is_impossible: Optional[torch.Tensor] = None,
  1259. cls_index: Optional[torch.Tensor] = None,
  1260. p_mask: Optional[torch.Tensor] = None,
  1261. output_attentions: Optional[bool] = None,
  1262. output_hidden_states: Optional[bool] = None,
  1263. return_dict: Optional[bool] = None,
  1264. ) -> Union[tuple, FlaubertForQuestionAnsweringOutput]:
  1265. r"""
  1266. langs (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1267. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  1268. languages ids which can be obtained from the language names by using two conversion mappings provided in
  1269. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  1270. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  1271. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  1272. See usage examples detailed in the [multilingual documentation](../multilingual).
  1273. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1274. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  1275. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  1276. `[0, ..., input_ids.size(-1)]`.
  1277. cache (`dict[str, torch.FloatTensor]`, *optional*):
  1278. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  1279. decoding.
  1280. is_impossible (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1281. Labels whether a question has an answer or no answer (SQuAD 2.0)
  1282. cls_index (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1283. Labels for position (index) of the classification token to use as input for computing plausibility of the
  1284. answer.
  1285. p_mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1286. Optional mask of tokens which can't be in answers (e.g. [CLS], [PAD], ...). 1.0 means token should be
  1287. masked. 0.0 mean token is not masked.
  1288. Example:
  1289. ```python
  1290. >>> from transformers import AutoTokenizer, FlaubertForQuestionAnswering
  1291. >>> import torch
  1292. >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-mlm-en-2048")
  1293. >>> model = FlaubertForQuestionAnswering.from_pretrained("FacebookAI/xlm-mlm-en-2048")
  1294. >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(
  1295. ... 0
  1296. ... ) # Batch size 1
  1297. >>> start_positions = torch.tensor([1])
  1298. >>> end_positions = torch.tensor([3])
  1299. >>> outputs = model(input_ids, start_positions=start_positions, end_positions=end_positions)
  1300. >>> loss = outputs.loss
  1301. ```"""
  1302. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1303. transformer_outputs = self.transformer(
  1304. input_ids,
  1305. attention_mask=attention_mask,
  1306. langs=langs,
  1307. token_type_ids=token_type_ids,
  1308. position_ids=position_ids,
  1309. lengths=lengths,
  1310. cache=cache,
  1311. head_mask=head_mask,
  1312. inputs_embeds=inputs_embeds,
  1313. output_attentions=output_attentions,
  1314. output_hidden_states=output_hidden_states,
  1315. return_dict=return_dict,
  1316. )
  1317. output = transformer_outputs[0]
  1318. outputs = self.qa_outputs(
  1319. output,
  1320. start_positions=start_positions,
  1321. end_positions=end_positions,
  1322. cls_index=cls_index,
  1323. is_impossible=is_impossible,
  1324. p_mask=p_mask,
  1325. return_dict=return_dict,
  1326. )
  1327. if not return_dict:
  1328. return outputs + transformer_outputs[1:]
  1329. return FlaubertForQuestionAnsweringOutput(
  1330. loss=outputs.loss,
  1331. start_top_log_probs=outputs.start_top_log_probs,
  1332. start_top_index=outputs.start_top_index,
  1333. end_top_log_probs=outputs.end_top_log_probs,
  1334. end_top_index=outputs.end_top_index,
  1335. cls_logits=outputs.cls_logits,
  1336. hidden_states=transformer_outputs.hidden_states,
  1337. attentions=transformer_outputs.attentions,
  1338. )
  1339. @auto_docstring
  1340. # Copied from transformers.models.xlm.modeling_xlm.XLMForMultipleChoice with XLM_INPUTS->FLAUBERT_INPUTS,XLM->Flaubert
  1341. class FlaubertForMultipleChoice(FlaubertPreTrainedModel):
  1342. def __init__(self, config, *inputs, **kwargs):
  1343. super().__init__(config, *inputs, **kwargs)
  1344. self.transformer = FlaubertModel(config)
  1345. self.sequence_summary = FlaubertSequenceSummary(config)
  1346. self.logits_proj = nn.Linear(config.num_labels, 1)
  1347. # Initialize weights and apply final processing
  1348. self.post_init()
  1349. @auto_docstring
  1350. def forward(
  1351. self,
  1352. input_ids: Optional[torch.Tensor] = None,
  1353. attention_mask: Optional[torch.Tensor] = None,
  1354. langs: Optional[torch.Tensor] = None,
  1355. token_type_ids: Optional[torch.Tensor] = None,
  1356. position_ids: Optional[torch.Tensor] = None,
  1357. lengths: Optional[torch.Tensor] = None,
  1358. cache: Optional[dict[str, torch.Tensor]] = None,
  1359. head_mask: Optional[torch.Tensor] = None,
  1360. inputs_embeds: Optional[torch.Tensor] = None,
  1361. labels: Optional[torch.Tensor] = None,
  1362. output_attentions: Optional[bool] = None,
  1363. output_hidden_states: Optional[bool] = None,
  1364. return_dict: Optional[bool] = None,
  1365. ) -> Union[tuple, MultipleChoiceModelOutput]:
  1366. r"""
  1367. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  1368. Indices of input sequence tokens in the vocabulary.
  1369. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1370. [`PreTrainedTokenizer.__call__`] for details.
  1371. [What are input IDs?](../glossary#input-ids)
  1372. langs (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1373. A parallel sequence of tokens to be used to indicate the language of each token in the input. Indices are
  1374. languages ids which can be obtained from the language names by using two conversion mappings provided in
  1375. the configuration of the model (only provided for multilingual models). More precisely, the *language name
  1376. to language id* mapping is in `model.config.lang2id` (which is a dictionary string to int) and the
  1377. *language id to language name* mapping is in `model.config.id2lang` (dictionary int to string).
  1378. See usage examples detailed in the [multilingual documentation](../multilingual).
  1379. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1380. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1381. 1]`:
  1382. - 0 corresponds to a *sentence A* token,
  1383. - 1 corresponds to a *sentence B* token.
  1384. [What are token type IDs?](../glossary#token-type-ids)
  1385. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1386. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1387. config.max_position_embeddings - 1]`.
  1388. [What are position IDs?](../glossary#position-ids)
  1389. lengths (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1390. Length of each sentence that can be used to avoid performing attention on padding token indices. You can
  1391. also use *attention_mask* for the same result (see above), kept here for compatibility. Indices selected in
  1392. `[0, ..., input_ids.size(-1)]`.
  1393. cache (`dict[str, torch.FloatTensor]`, *optional*):
  1394. Instance of `EncoderDecoderCache` that contains precomputed KV states. Can be used to speed up sequential
  1395. decoding.
  1396. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1397. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1398. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1399. model's internal embedding lookup matrix.
  1400. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1401. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1402. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1403. `input_ids` above)
  1404. """
  1405. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1406. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1407. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1408. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1409. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1410. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1411. langs = langs.view(-1, langs.size(-1)) if langs is not None else None
  1412. inputs_embeds = (
  1413. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1414. if inputs_embeds is not None
  1415. else None
  1416. )
  1417. if lengths is not None:
  1418. logger.warning(
  1419. "The `lengths` parameter cannot be used with the Flaubert multiple choice models. Please use the "
  1420. "attention mask instead."
  1421. )
  1422. lengths = None
  1423. transformer_outputs = self.transformer(
  1424. input_ids=input_ids,
  1425. attention_mask=attention_mask,
  1426. langs=langs,
  1427. token_type_ids=token_type_ids,
  1428. position_ids=position_ids,
  1429. lengths=lengths,
  1430. cache=cache,
  1431. head_mask=head_mask,
  1432. inputs_embeds=inputs_embeds,
  1433. output_attentions=output_attentions,
  1434. output_hidden_states=output_hidden_states,
  1435. return_dict=return_dict,
  1436. )
  1437. output = transformer_outputs[0]
  1438. logits = self.sequence_summary(output)
  1439. logits = self.logits_proj(logits)
  1440. reshaped_logits = logits.view(-1, num_choices)
  1441. loss = None
  1442. if labels is not None:
  1443. loss_fct = CrossEntropyLoss()
  1444. loss = loss_fct(reshaped_logits, labels)
  1445. if not return_dict:
  1446. output = (reshaped_logits,) + transformer_outputs[1:]
  1447. return ((loss,) + output) if loss is not None else output
  1448. return MultipleChoiceModelOutput(
  1449. loss=loss,
  1450. logits=reshaped_logits,
  1451. hidden_states=transformer_outputs.hidden_states,
  1452. attentions=transformer_outputs.attentions,
  1453. )
  1454. __all__ = [
  1455. "FlaubertForMultipleChoice",
  1456. "FlaubertForQuestionAnswering",
  1457. "FlaubertForQuestionAnsweringSimple",
  1458. "FlaubertForSequenceClassification",
  1459. "FlaubertForTokenClassification",
  1460. "FlaubertModel",
  1461. "FlaubertWithLMHeadModel",
  1462. "FlaubertPreTrainedModel",
  1463. ]