modeling_fnet.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093
  1. # coding=utf-8
  2. # Copyright 2021 Google Research and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch FNet model."""
  16. import warnings
  17. from dataclasses import dataclass
  18. from functools import partial
  19. from typing import Optional, Union
  20. import torch
  21. from torch import nn
  22. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  23. from ...utils import auto_docstring, is_scipy_available
  24. if is_scipy_available():
  25. from scipy import linalg
  26. from ...activations import ACT2FN
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import (
  29. BaseModelOutput,
  30. BaseModelOutputWithPooling,
  31. MaskedLMOutput,
  32. ModelOutput,
  33. MultipleChoiceModelOutput,
  34. NextSentencePredictorOutput,
  35. QuestionAnsweringModelOutput,
  36. SequenceClassifierOutput,
  37. TokenClassifierOutput,
  38. )
  39. from ...modeling_utils import PreTrainedModel
  40. from ...pytorch_utils import apply_chunking_to_forward
  41. from ...utils import logging
  42. from .configuration_fnet import FNetConfig
  43. logger = logging.get_logger(__name__)
  44. # Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
  45. def _two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
  46. """Applies 2D matrix multiplication to 3D input arrays."""
  47. seq_length = x.shape[1]
  48. matrix_dim_one = matrix_dim_one[:seq_length, :seq_length]
  49. x = x.type(torch.complex64)
  50. return torch.einsum("bij,jk,ni->bnk", x, matrix_dim_two, matrix_dim_one)
  # Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
  def two_dim_matmul(x, matrix_dim_one, matrix_dim_two):
      """Public wrapper that forwards its arguments unchanged to `_two_dim_matmul` and returns its result."""
      result = _two_dim_matmul(x, matrix_dim_one=matrix_dim_one, matrix_dim_two=matrix_dim_two)
      return result
  # Adapted from https://github.com/google-research/google-research/blob/master/f_net/fourier.py
  def fftn(x):
      """
      Applies a Fast Fourier Transform (FFT) along every axis of `x` except the first one.

      Axis 0 is treated as a batch axis and left untouched. For a 3D input this matches
      `torch.fft.fftn(x, dim=(1, 2))`.

      Args:
          x: Input n-dimensional tensor.

      Returns:
          Tensor of the same shape as `x` holding its Fourier transform over axes `1 .. x.ndim - 1`
          (complex-valued whenever at least one axis is transformed; `x` itself is returned when `x.ndim <= 1`).
      """
      out = x
      # Transform axes from last to first, stopping before axis 0 (the batch axis is not transformed).
      for dim in range(x.ndim - 1, 0, -1):
          out = torch.fft.fft(out, dim=dim)
      return out
  class FNetEmbeddings(nn.Module):
      """Sums word, token-type and position embeddings, then applies LayerNorm, a linear projection and dropout."""

      def __init__(self, config):
          super().__init__()
          hidden = config.hidden_size
          self.word_embeddings = nn.Embedding(config.vocab_size, hidden, padding_idx=config.pad_token_id)
          self.position_embeddings = nn.Embedding(config.max_position_embeddings, hidden)
          self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden)
          # `LayerNorm` keeps its CamelCase name to match TensorFlow checkpoint variable names.
          self.LayerNorm = nn.LayerNorm(hidden, eps=config.layer_norm_eps)
          # Projection layer: the original implementation allows the embedding and model dimensions to differ.
          self.projection = nn.Linear(hidden, hidden)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          # Non-persistent buffers: default position ids of shape (1, max_position_embeddings) and all-zero
          # token type ids of the same shape.
          default_positions = torch.arange(config.max_position_embeddings).expand((1, -1))
          self.register_buffer("position_ids", default_positions, persistent=False)
          self.register_buffer(
              "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
          )

      def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
          """Embed `input_ids` (or use `inputs_embeds`) and return the projected, normalized embeddings."""
          input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
          batch_size, seq_length = input_shape[0], input_shape[1]

          if position_ids is None:
              position_ids = self.position_ids[:, :seq_length]

          # Default to the all-zero registered buffer so the model can be traced without token_type_ids
          # (issue #5664).
          if token_type_ids is None:
              if hasattr(self, "token_type_ids"):
                  token_type_ids = self.token_type_ids[:, :seq_length].expand(batch_size, seq_length)
              else:
                  token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

          if inputs_embeds is None:
              inputs_embeds = self.word_embeddings(input_ids)

          embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
          embeddings += self.position_embeddings(position_ids)
          return self.dropout(self.projection(self.LayerNorm(embeddings)))
  class FNetBasicFourierTransform(nn.Module):
      """Token-mixing sublayer: 2D Fourier transform over the sequence and hidden axes, keeping the real part.

      The implementation is chosen once from the config:
        * `use_tpu_fourier_optimizations` is False: `torch.fft.fftn` over dims (1, 2).
        * TPU optimizations, `max_position_embeddings <= 4096` and SciPy available: multiplication by DFT matrices
          stored as buffers (`dft_mat_seq` of size `tpu_short_seq_length`, `dft_mat_hidden` of size `hidden_size`).
        * Otherwise: the module-level `fftn` helper.
      """

      def __init__(self, config):
          super().__init__()
          self._init_fourier_transform(config)

      def _init_fourier_transform(self, config):
          if not config.use_tpu_fourier_optimizations:
              self.fourier_transform = partial(torch.fft.fftn, dim=(1, 2))
          elif config.max_position_embeddings <= 4096:
              if is_scipy_available():
                  self.register_buffer(
                      "dft_mat_hidden", torch.tensor(linalg.dft(config.hidden_size), dtype=torch.complex64)
                  )
                  self.register_buffer(
                      "dft_mat_seq", torch.tensor(linalg.dft(config.tpu_short_seq_length), dtype=torch.complex64)
                  )
                  # Bind a method instead of `partial(..., matrix_dim_one=self.dft_mat_seq, ...)`: a partial would
                  # keep the tensors that existed at construction time, while `Module.to(...)` and checkpoint
                  # loading replace the buffer tensors. The method reads the current buffers on every call.
                  self.fourier_transform = self._dft_matmul
              else:
                  # Use the module-level `logger`: `logging` here is the transformers logging utilities module,
                  # not a logger. NOTE(review): it does not appear to provide a module-level `warning` function.
                  logger.warning(
                      "SciPy is needed for DFT matrix calculation and is not found. Using TPU optimized fast fourier"
                      " transform instead."
                  )
                  self.fourier_transform = fftn
          else:
              self.fourier_transform = fftn

      def _dft_matmul(self, hidden_states):
          """DFT via matrix multiplication, using the module's current `dft_mat_seq` / `dft_mat_hidden` buffers."""
          return two_dim_matmul(hidden_states, matrix_dim_one=self.dft_mat_seq, matrix_dim_two=self.dft_mat_hidden)

      def forward(self, hidden_states):
          # The selected transform is applied to the whole batch at once; only the real part is kept.
          outputs = self.fourier_transform(hidden_states).real
          return (outputs,)
  class FNetBasicOutput(nn.Module):
      """Residual connection around the Fourier sublayer followed by LayerNorm."""

      def __init__(self, config):
          super().__init__()
          self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

      def forward(self, hidden_states, input_tensor):
          residual_sum = input_tensor + hidden_states
          return self.LayerNorm(residual_sum)
  class FNetFourierTransform(nn.Module):
      """Fourier mixing (`self`) followed by residual + LayerNorm (`output`); returns a 1-tuple."""

      def __init__(self, config):
          super().__init__()
          self.self = FNetBasicFourierTransform(config)
          self.output = FNetBasicOutput(config)

      def forward(self, hidden_states):
          mixed = self.self(hidden_states)[0]
          return (self.output(mixed, hidden_states),)
  # Same computation as transformers.models.bert.modeling_bert.BertIntermediate (Bert->FNet).
  class FNetIntermediate(nn.Module):
      """Feed-forward expansion: linear `hidden_size -> intermediate_size` followed by the configured activation."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
          act = config.hidden_act
          # `hidden_act` is either a key into ACT2FN or a callable used as-is.
          self.intermediate_act_fn = ACT2FN[act] if isinstance(act, str) else act

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          return self.intermediate_act_fn(self.dense(hidden_states))
  # Same computation as transformers.models.bert.modeling_bert.BertOutput (Bert->FNet).
  class FNetOutput(nn.Module):
      """Feed-forward contraction `intermediate_size -> hidden_size`, dropout, then residual + LayerNorm."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
          self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
          projected = self.dropout(self.dense(hidden_states))
          return self.LayerNorm(projected + input_tensor)
  class FNetLayer(GradientCheckpointingLayer):
      """One FNet encoder block: Fourier mixing sublayer followed by a (possibly chunked) feed-forward sublayer."""

      def __init__(self, config):
          super().__init__()
          self.chunk_size_feed_forward = config.chunk_size_feed_forward
          # Sequence-length axis, passed to `apply_chunking_to_forward` as the chunking dimension.
          self.seq_len_dim = 1
          self.fourier = FNetFourierTransform(config)
          self.intermediate = FNetIntermediate(config)
          self.output = FNetOutput(config)

      def forward(self, hidden_states):
          fourier_output = self.fourier(hidden_states)[0]
          layer_output = apply_chunking_to_forward(
              self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, fourier_output
          )
          return (layer_output,)

      def feed_forward_chunk(self, fourier_output):
          """Feed-forward sublayer: expansion + activation, then contraction with residual and LayerNorm."""
          expanded = self.intermediate(fourier_output)
          return self.output(expanded, fourier_output)
  class FNetEncoder(nn.Module):
      """Stack of `config.num_hidden_layers` FNet layers applied sequentially."""

      def __init__(self, config):
          super().__init__()
          self.config = config
          self.layer = nn.ModuleList([FNetLayer(config) for _ in range(config.num_hidden_layers)])
          # Not read in this class. NOTE(review): presumably toggled by the gradient-checkpointing machinery —
          # confirm before removing.
          self.gradient_checkpointing = False

      def forward(self, hidden_states, output_hidden_states=False, return_dict=True):
          """Run every layer in order.

          Args:
              hidden_states: Input to the first layer.
              output_hidden_states: If True, also collect the input of every layer plus the final output.
              return_dict: If True return a `BaseModelOutput`, otherwise a tuple.

          Returns:
              `BaseModelOutput(last_hidden_state, hidden_states)`, or a tuple holding the final hidden state
              followed by the collected hidden states (only when `output_hidden_states` is set).
          """
          all_hidden_states = () if output_hidden_states else None
          for layer_module in self.layer:
              if output_hidden_states:
                  all_hidden_states = all_hidden_states + (hidden_states,)
              hidden_states = layer_module(hidden_states)[0]

          if output_hidden_states:
              all_hidden_states = all_hidden_states + (hidden_states,)

          if not return_dict:
              return tuple(v for v in [hidden_states, all_hidden_states] if v is not None)
          return BaseModelOutput(last_hidden_state=hidden_states, hidden_states=all_hidden_states)
  # Same computation as transformers.models.bert.modeling_bert.BertPooler (Bert->FNet).
  class FNetPooler(nn.Module):
      """Pools a sequence by passing the hidden state of its first token through a dense layer and tanh."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.hidden_size)
          self.activation = nn.Tanh()

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          first_token_state = hidden_states[:, 0]
          return self.activation(self.dense(first_token_state))
  # Same computation as transformers.models.bert.modeling_bert.BertPredictionHeadTransform (Bert->FNet).
  class FNetPredictionHeadTransform(nn.Module):
      """Dense layer, activation and LayerNorm applied to hidden states before the vocabulary projection."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.hidden_size)
          act = config.hidden_act
          # `hidden_act` is either a key into ACT2FN or a callable used as-is.
          self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
          self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          activated = self.transform_act_fn(self.dense(hidden_states))
          return self.LayerNorm(activated)
  class FNetLMPredictionHead(nn.Module):
      """Masked-LM head: transform the hidden states, then project them to vocabulary logits."""

      def __init__(self, config):
          super().__init__()
          self.transform = FNetPredictionHeadTransform(config)
          # The output weights are the same as the input embeddings; the output-only per-token bias lives in
          # `self.bias` and is shared with `decoder.bias`.
          self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
          self.bias = nn.Parameter(torch.zeros(config.vocab_size))
          self.decoder.bias = self.bias

      def forward(self, hidden_states):
          return self.decoder(self.transform(hidden_states))

      def _tie_weights(self) -> None:
          # Re-share the bias between `self.bias` and `self.decoder.bias` (accelerate / backward compatibility).
          if self.decoder.bias.device.type != "meta":
              # The decoder bias may have been disconnected (on TPU or after resizing): follow the decoder.
              self.bias = self.decoder.bias
          else:
              self.decoder.bias = self.bias
  class FNetOnlyMLMHead(nn.Module):
      """Wraps a single masked-LM prediction head under the `predictions` attribute."""

      def __init__(self, config):
          super().__init__()
          self.predictions = FNetLMPredictionHead(config)

      def forward(self, sequence_output):
          return self.predictions(sequence_output)
  # Same computation as transformers.models.bert.modeling_bert.BertOnlyNSPHead (Bert->FNet).
  class FNetOnlyNSPHead(nn.Module):
      """Next-sentence-prediction head: a linear map from the pooled output to two logits."""

      def __init__(self, config):
          super().__init__()
          self.seq_relationship = nn.Linear(config.hidden_size, 2)

      def forward(self, pooled_output):
          return self.seq_relationship(pooled_output)
  # Same computation as transformers.models.bert.modeling_bert.BertPreTrainingHeads (Bert->FNet).
  class FNetPreTrainingHeads(nn.Module):
      """Both pretraining heads: masked-LM logits from the sequence output, NSP logits from the pooled output."""

      def __init__(self, config):
          super().__init__()
          self.predictions = FNetLMPredictionHead(config)
          self.seq_relationship = nn.Linear(config.hidden_size, 2)

      def forward(self, sequence_output, pooled_output):
          return self.predictions(sequence_output), self.seq_relationship(pooled_output)
  @auto_docstring
  class FNetPreTrainedModel(PreTrainedModel):
      config: FNetConfig
      base_model_prefix = "fnet"
      supports_gradient_checkpointing = True

      def _init_weights(self, module):
          """Initialize the weights of a single submodule in place.

          Linear and embedding weights are drawn from N(0, config.initializer_range); linear biases and the
          embedding padding row are zeroed; LayerNorm gets bias 0 and weight 1. Other modules are left untouched.
          """
          if isinstance(module, nn.Linear):
              # A plain normal is used, slightly different from the TF version's truncated normal; cf
              # https://github.com/pytorch/pytorch/pull/5617
              module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              # NOTE: the original code initializes biases the same way as the weights; here they are zeroed.
              if module.bias is not None:
                  module.bias.data.zero_()
              return
          if isinstance(module, nn.Embedding):
              module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
              if module.padding_idx is not None:
                  module.weight.data[module.padding_idx].zero_()
              return
          if isinstance(module, nn.LayerNorm):
              module.bias.data.zero_()
              module.weight.data.fill_(1.0)
  @dataclass
  @auto_docstring(
      custom_intro="""
      Output type of [`FNetForPreTraining`].
      """
  )
  class FNetForPreTrainingOutput(ModelOutput):
      r"""
      loss (*optional*, returned when both `labels` and `next_sentence_label` are provided, `torch.FloatTensor` of shape `(1,)`):
          Total loss as the sum of the masked language modeling loss and the next sequence prediction
          (classification) loss.
      prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
          Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
      seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
          Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
          before SoftMax).
      """

      # Every field defaults to None so any subset can be populated.
      # NOTE(review): the field order presumably also defines the positional order when the output is indexed or
      # converted to a tuple (ModelOutput behavior) — confirm before reordering.
      loss: Optional[torch.FloatTensor] = None
      prediction_logits: Optional[torch.FloatTensor] = None
      seq_relationship_logits: Optional[torch.FloatTensor] = None
      # Filled from the base model's `hidden_states` output.
      hidden_states: Optional[tuple[torch.FloatTensor]] = None
  @auto_docstring
  class FNetModel(FNetPreTrainedModel):
      """
      The model can behave as an encoder, following the architecture described in [FNet: Mixing Tokens with Fourier
      Transforms](https://huggingface.co/papers/2105.03824) by James Lee-Thorp, Joshua Ainslie, Ilya Eckstein, Santiago Ontanon.
      """

      def __init__(self, config, add_pooling_layer=True):
          r"""
          add_pooling_layer (bool, *optional*, defaults to `True`):
              Whether to add a pooling layer
          """
          super().__init__(config)
          self.config = config
          self.embeddings = FNetEmbeddings(config)
          self.encoder = FNetEncoder(config)
          self.pooler = FNetPooler(config) if add_pooling_layer else None
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.embeddings.word_embeddings

      def set_input_embeddings(self, value):
          self.embeddings.word_embeddings = value

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          token_type_ids: Optional[torch.LongTensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, BaseModelOutputWithPooling]:
          output_hidden_states = (
              output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
          )
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          # Exactly one of `input_ids` / `inputs_embeds` must be given; the unpacking below requires a batch axis.
          if input_ids is not None and inputs_embeds is not None:
              raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
          elif input_ids is not None:
              input_shape = input_ids.size()
              batch_size, seq_length = input_shape
          elif inputs_embeds is not None:
              input_shape = inputs_embeds.size()[:-1]
              batch_size, seq_length = input_shape
          else:
              raise ValueError("You have to specify either input_ids or inputs_embeds")

          # With TPU optimizations, DFT matrices of size `tpu_short_seq_length` are only built when
          # `max_position_embeddings <= 4096` (the same condition the Fourier sublayer uses). Slicing a larger DFT
          # matrix does not give the DFT of a shorter sequence, so the lengths must match exactly. Otherwise the
          # FFT-based path is used, which handles any length, and no check is needed.
          if (
              self.config.use_tpu_fourier_optimizations
              and self.config.max_position_embeddings <= 4096
              and self.config.tpu_short_seq_length != seq_length
          ):
              raise ValueError(
                  "The `tpu_short_seq_length` in FNetConfig should be set equal to the sequence length being passed to"
                  " the model when using TPU optimizations."
              )

          device = input_ids.device if input_ids is not None else inputs_embeds.device

          # Default token types: the embeddings' all-zero buffer, expanded to the batch.
          if token_type_ids is None:
              if hasattr(self.embeddings, "token_type_ids"):
                  buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
                  token_type_ids = buffered_token_type_ids.expand(batch_size, seq_length)
              else:
                  token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

          embedding_output = self.embeddings(
              input_ids=input_ids,
              position_ids=position_ids,
              token_type_ids=token_type_ids,
              inputs_embeds=inputs_embeds,
          )
          encoder_outputs = self.encoder(
              embedding_output,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          sequence_output = encoder_outputs[0]

          pooler_output = self.pooler(sequence_output) if self.pooler is not None else None

          if not return_dict:
              return (sequence_output, pooler_output) + encoder_outputs[1:]

          return BaseModelOutputWithPooling(
              last_hidden_state=sequence_output,
              pooler_output=pooler_output,
              hidden_states=encoder_outputs.hidden_states,
          )
  @auto_docstring(
      custom_intro="""
      FNet Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
      sentence prediction (classification)` head.
      """
  )
  class FNetForPreTraining(FNetPreTrainedModel):
      # Parameter names of the MLM decoder that are reported as tied weights.
      _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

      def __init__(self, config):
          super().__init__(config)
          self.fnet = FNetModel(config)
          self.cls = FNetPreTrainingHeads(config)
          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          return self.cls.predictions.decoder

      def set_output_embeddings(self, new_embeddings):
          mlm_head = self.cls.predictions
          mlm_head.decoder = new_embeddings
          # Keep the head's standalone `bias` parameter pointing at the new decoder's bias.
          mlm_head.bias = new_embeddings.bias

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          next_sentence_label: Optional[torch.Tensor] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, FNetForPreTrainingOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
              config.vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
          next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
              (see `input_ids` docstring). Indices should be in `[0, 1]`:

              - 0 indicates sequence B is a continuation of sequence A,
              - 1 indicates sequence B is a random sequence.

              The combined loss is only computed when both `labels` and `next_sentence_label` are given.

          Example:

          ```python
          >>> from transformers import AutoTokenizer, FNetForPreTraining
          >>> import torch

          >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
          >>> model = FNetForPreTraining.from_pretrained("google/fnet-base")
          >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
          >>> outputs = model(**inputs)
          >>> prediction_logits = outputs.prediction_logits
          >>> seq_relationship_logits = outputs.seq_relationship_logits
          ```"""
          if return_dict is None:
              return_dict = self.config.use_return_dict

          outputs = self.fnet(
              input_ids,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          sequence_output, pooled_output = outputs[:2]
          prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)

          total_loss = None
          if labels is not None and next_sentence_label is not None:
              ce = CrossEntropyLoss()
              mlm_loss = ce(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
              nsp_loss = ce(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
              total_loss = mlm_loss + nsp_loss

          if return_dict:
              return FNetForPreTrainingOutput(
                  loss=total_loss,
                  prediction_logits=prediction_scores,
                  seq_relationship_logits=seq_relationship_score,
                  hidden_states=outputs.hidden_states,
              )

          tuple_output = (prediction_scores, seq_relationship_score) + outputs[2:]
          return tuple_output if total_loss is None else (total_loss,) + tuple_output
  @auto_docstring
  class FNetForMaskedLM(FNetPreTrainedModel):
      # Parameter names of the MLM decoder that are reported as tied weights.
      _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]

      def __init__(self, config):
          super().__init__(config)
          self.fnet = FNetModel(config)
          self.cls = FNetOnlyMLMHead(config)
          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          return self.cls.predictions.decoder

      def set_output_embeddings(self, new_embeddings):
          mlm_head = self.cls.predictions
          mlm_head.decoder = new_embeddings
          # Keep the head's standalone `bias` parameter pointing at the new decoder's bias.
          mlm_head.bias = new_embeddings.bias

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, MaskedLMOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
              config.vocab_size - 1]` (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
              (masked); the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size - 1]`.
          """
          if return_dict is None:
              return_dict = self.config.use_return_dict

          outputs = self.fnet(
              input_ids,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          prediction_scores = self.cls(outputs[0])

          masked_lm_loss = None
          if labels is not None:
              # CrossEntropyLoss ignores targets equal to -100 (its default ignore_index).
              flat_logits = prediction_scores.view(-1, self.config.vocab_size)
              masked_lm_loss = CrossEntropyLoss()(flat_logits, labels.view(-1))

          if return_dict:
              return MaskedLMOutput(loss=masked_lm_loss, logits=prediction_scores, hidden_states=outputs.hidden_states)

          tuple_output = (prediction_scores,) + outputs[2:]
          return tuple_output if masked_lm_loss is None else (masked_lm_loss,) + tuple_output
  @auto_docstring(
      custom_intro="""
      FNet Model with a `next sentence prediction (classification)` head on top.
      """
  )
  class FNetForNextSentencePrediction(FNetPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.fnet = FNetModel(config)
          self.cls = FNetOnlyNSPHead(config)
          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.Tensor] = None,
          token_type_ids: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.Tensor] = None,
          inputs_embeds: Optional[torch.Tensor] = None,
          labels: Optional[torch.Tensor] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
          **kwargs,
      ) -> Union[tuple, NextSentencePredictorOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
              (see `input_ids` docstring). Indices should be in `[0, 1]`:

              - 0 indicates sequence B is a continuation of sequence A,
              - 1 indicates sequence B is a random sequence.

          Example:

          ```python
          >>> from transformers import AutoTokenizer, FNetForNextSentencePrediction
          >>> import torch

          >>> tokenizer = AutoTokenizer.from_pretrained("google/fnet-base")
          >>> model = FNetForNextSentencePrediction.from_pretrained("google/fnet-base")
          >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
          >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
          >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
          >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
          >>> logits = outputs.logits
          >>> assert logits[0, 0] < logits[0, 1]  # next sentence was random
          ```"""
          # Backward compatibility: `next_sentence_label` is the deprecated name of `labels`; when given it wins.
          if "next_sentence_label" in kwargs:
              warnings.warn(
                  "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
                  " `labels` instead.",
                  FutureWarning,
              )
              labels = kwargs.pop("next_sentence_label")

          if return_dict is None:
              return_dict = self.config.use_return_dict

          outputs = self.fnet(
              input_ids,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              inputs_embeds=inputs_embeds,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          # Index 1 is the pooled output.
          seq_relationship_scores = self.cls(outputs[1])

          next_sentence_loss = None
          if labels is not None:
              next_sentence_loss = CrossEntropyLoss()(seq_relationship_scores.view(-1, 2), labels.view(-1))

          if return_dict:
              return NextSentencePredictorOutput(
                  loss=next_sentence_loss,
                  logits=seq_relationship_scores,
                  hidden_states=outputs.hidden_states,
              )

          tuple_output = (seq_relationship_scores,) + outputs[2:]
          return tuple_output if next_sentence_loss is None else (next_sentence_loss,) + tuple_output
  627. @auto_docstring(
  628. custom_intro="""
  629. FNet Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  630. output) e.g. for GLUE tasks.
  631. """
  632. )
  633. class FNetForSequenceClassification(FNetPreTrainedModel):
  634. def __init__(self, config):
  635. super().__init__(config)
  636. self.num_labels = config.num_labels
  637. self.fnet = FNetModel(config)
  638. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  639. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  640. # Initialize weights and apply final processing
  641. self.post_init()
  642. @auto_docstring
  643. def forward(
  644. self,
  645. input_ids: Optional[torch.Tensor] = None,
  646. token_type_ids: Optional[torch.Tensor] = None,
  647. position_ids: Optional[torch.Tensor] = None,
  648. inputs_embeds: Optional[torch.Tensor] = None,
  649. labels: Optional[torch.Tensor] = None,
  650. output_hidden_states: Optional[bool] = None,
  651. return_dict: Optional[bool] = None,
  652. ) -> Union[tuple, SequenceClassifierOutput]:
  653. r"""
  654. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  655. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  656. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  657. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  658. """
  659. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  660. outputs = self.fnet(
  661. input_ids,
  662. token_type_ids=token_type_ids,
  663. position_ids=position_ids,
  664. inputs_embeds=inputs_embeds,
  665. output_hidden_states=output_hidden_states,
  666. return_dict=return_dict,
  667. )
  668. pooled_output = outputs[1]
  669. pooled_output = self.dropout(pooled_output)
  670. logits = self.classifier(pooled_output)
  671. loss = None
  672. if labels is not None:
  673. if self.config.problem_type is None:
  674. if self.num_labels == 1:
  675. self.config.problem_type = "regression"
  676. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  677. self.config.problem_type = "single_label_classification"
  678. else:
  679. self.config.problem_type = "multi_label_classification"
  680. if self.config.problem_type == "regression":
  681. loss_fct = MSELoss()
  682. if self.num_labels == 1:
  683. loss = loss_fct(logits.squeeze(), labels.squeeze())
  684. else:
  685. loss = loss_fct(logits, labels)
  686. elif self.config.problem_type == "single_label_classification":
  687. loss_fct = CrossEntropyLoss()
  688. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  689. elif self.config.problem_type == "multi_label_classification":
  690. loss_fct = BCEWithLogitsLoss()
  691. loss = loss_fct(logits, labels)
  692. if not return_dict:
  693. output = (logits,) + outputs[2:]
  694. return ((loss,) + output) if loss is not None else output
  695. return SequenceClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
  696. @auto_docstring
  697. class FNetForMultipleChoice(FNetPreTrainedModel):
  698. def __init__(self, config):
  699. super().__init__(config)
  700. self.fnet = FNetModel(config)
  701. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  702. self.classifier = nn.Linear(config.hidden_size, 1)
  703. # Initialize weights and apply final processing
  704. self.post_init()
  705. @auto_docstring
  706. def forward(
  707. self,
  708. input_ids: Optional[torch.Tensor] = None,
  709. token_type_ids: Optional[torch.Tensor] = None,
  710. position_ids: Optional[torch.Tensor] = None,
  711. inputs_embeds: Optional[torch.Tensor] = None,
  712. labels: Optional[torch.Tensor] = None,
  713. output_hidden_states: Optional[bool] = None,
  714. return_dict: Optional[bool] = None,
  715. ) -> Union[tuple, MultipleChoiceModelOutput]:
  716. r"""
  717. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  718. Indices of input sequence tokens in the vocabulary.
  719. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  720. [`PreTrainedTokenizer.__call__`] for details.
  721. [What are input IDs?](../glossary#input-ids)
  722. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  723. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  724. 1]`:
  725. - 0 corresponds to a *sentence A* token,
  726. - 1 corresponds to a *sentence B* token.
  727. [What are token type IDs?](../glossary#token-type-ids)
  728. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  729. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  730. config.max_position_embeddings - 1]`.
  731. [What are position IDs?](../glossary#position-ids)
  732. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  733. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  734. is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
  735. model's internal embedding lookup matrix.
  736. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  737. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  738. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  739. `input_ids` above)
  740. """
  741. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  742. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  743. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  744. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  745. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  746. inputs_embeds = (
  747. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  748. if inputs_embeds is not None
  749. else None
  750. )
  751. outputs = self.fnet(
  752. input_ids,
  753. token_type_ids=token_type_ids,
  754. position_ids=position_ids,
  755. inputs_embeds=inputs_embeds,
  756. output_hidden_states=output_hidden_states,
  757. return_dict=return_dict,
  758. )
  759. pooled_output = outputs[1]
  760. pooled_output = self.dropout(pooled_output)
  761. logits = self.classifier(pooled_output)
  762. reshaped_logits = logits.view(-1, num_choices)
  763. loss = None
  764. if labels is not None:
  765. loss_fct = CrossEntropyLoss()
  766. loss = loss_fct(reshaped_logits, labels)
  767. if not return_dict:
  768. output = (reshaped_logits,) + outputs[2:]
  769. return ((loss,) + output) if loss is not None else output
  770. return MultipleChoiceModelOutput(loss=loss, logits=reshaped_logits, hidden_states=outputs.hidden_states)
  771. @auto_docstring
  772. class FNetForTokenClassification(FNetPreTrainedModel):
  773. def __init__(self, config):
  774. super().__init__(config)
  775. self.num_labels = config.num_labels
  776. self.fnet = FNetModel(config)
  777. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  778. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  779. # Initialize weights and apply final processing
  780. self.post_init()
  781. @auto_docstring
  782. def forward(
  783. self,
  784. input_ids: Optional[torch.Tensor] = None,
  785. token_type_ids: Optional[torch.Tensor] = None,
  786. position_ids: Optional[torch.Tensor] = None,
  787. inputs_embeds: Optional[torch.Tensor] = None,
  788. labels: Optional[torch.Tensor] = None,
  789. output_hidden_states: Optional[bool] = None,
  790. return_dict: Optional[bool] = None,
  791. ) -> Union[tuple, TokenClassifierOutput]:
  792. r"""
  793. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  794. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  795. """
  796. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  797. outputs = self.fnet(
  798. input_ids,
  799. token_type_ids=token_type_ids,
  800. position_ids=position_ids,
  801. inputs_embeds=inputs_embeds,
  802. output_hidden_states=output_hidden_states,
  803. return_dict=return_dict,
  804. )
  805. sequence_output = outputs[0]
  806. sequence_output = self.dropout(sequence_output)
  807. logits = self.classifier(sequence_output)
  808. loss = None
  809. if labels is not None:
  810. loss_fct = CrossEntropyLoss()
  811. # Only keep active parts of the loss
  812. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  813. if not return_dict:
  814. output = (logits,) + outputs[2:]
  815. return ((loss,) + output) if loss is not None else output
  816. return TokenClassifierOutput(loss=loss, logits=logits, hidden_states=outputs.hidden_states)
  817. @auto_docstring
  818. class FNetForQuestionAnswering(FNetPreTrainedModel):
  819. def __init__(self, config):
  820. super().__init__(config)
  821. self.num_labels = config.num_labels
  822. self.fnet = FNetModel(config)
  823. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  824. # Initialize weights and apply final processing
  825. self.post_init()
  826. @auto_docstring
  827. def forward(
  828. self,
  829. input_ids: Optional[torch.Tensor] = None,
  830. token_type_ids: Optional[torch.Tensor] = None,
  831. position_ids: Optional[torch.Tensor] = None,
  832. inputs_embeds: Optional[torch.Tensor] = None,
  833. start_positions: Optional[torch.Tensor] = None,
  834. end_positions: Optional[torch.Tensor] = None,
  835. output_hidden_states: Optional[bool] = None,
  836. return_dict: Optional[bool] = None,
  837. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  838. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  839. outputs = self.fnet(
  840. input_ids,
  841. token_type_ids=token_type_ids,
  842. position_ids=position_ids,
  843. inputs_embeds=inputs_embeds,
  844. output_hidden_states=output_hidden_states,
  845. return_dict=return_dict,
  846. )
  847. sequence_output = outputs[0]
  848. logits = self.qa_outputs(sequence_output)
  849. start_logits, end_logits = logits.split(1, dim=-1)
  850. start_logits = start_logits.squeeze(-1).contiguous()
  851. end_logits = end_logits.squeeze(-1).contiguous()
  852. total_loss = None
  853. if start_positions is not None and end_positions is not None:
  854. # If we are on multi-GPU, split add a dimension
  855. if len(start_positions.size()) > 1:
  856. start_positions = start_positions.squeeze(-1)
  857. if len(end_positions.size()) > 1:
  858. end_positions = end_positions.squeeze(-1)
  859. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  860. ignored_index = start_logits.size(1)
  861. start_positions = start_positions.clamp(0, ignored_index)
  862. end_positions = end_positions.clamp(0, ignored_index)
  863. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  864. start_loss = loss_fct(start_logits, start_positions)
  865. end_loss = loss_fct(end_logits, end_positions)
  866. total_loss = (start_loss + end_loss) / 2
  867. if not return_dict:
  868. output = (start_logits, end_logits) + outputs[2:]
  869. return ((total_loss,) + output) if total_loss is not None else output
  870. return QuestionAnsweringModelOutput(
  871. loss=total_loss, start_logits=start_logits, end_logits=end_logits, hidden_states=outputs.hidden_states
  872. )
# Public names exported by this module (controls `from ... import *`).
# NOTE(review): kept as a plain list literal of string names, one per line;
# library tooling may read this list textually — confirm before changing its form.
__all__ = [
    "FNetForMaskedLM",
    "FNetForMultipleChoice",
    "FNetForNextSentencePrediction",
    "FNetForPreTraining",
    "FNetForQuestionAnswering",
    "FNetForSequenceClassification",
    "FNetForTokenClassification",
    "FNetLayer",
    "FNetModel",
    "FNetPreTrainedModel",
]