modeling_funnel.py 60 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452
  1. # coding=utf-8
  2. # Copyright 2020-present Google Brain and Carnegie Mellon University Authors and the HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch Funnel Transformer model."""
  16. import os
  17. from dataclasses import dataclass
  18. from typing import Optional, Union
  19. import numpy as np
  20. import torch
  21. from torch import nn
  22. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  23. from ...activations import ACT2FN
  24. from ...modeling_outputs import (
  25. BaseModelOutput,
  26. MaskedLMOutput,
  27. MultipleChoiceModelOutput,
  28. QuestionAnsweringModelOutput,
  29. SequenceClassifierOutput,
  30. TokenClassifierOutput,
  31. )
  32. from ...modeling_utils import PreTrainedModel
  33. from ...utils import ModelOutput, auto_docstring, logging
  34. from .configuration_funnel import FunnelConfig
  35. logger = logging.get_logger(__name__)
  36. INF = 1e6
  37. def load_tf_weights_in_funnel(model, config, tf_checkpoint_path):
  38. """Load tf checkpoints in a pytorch model."""
  39. try:
  40. import re
  41. import numpy as np
  42. import tensorflow as tf
  43. except ImportError:
  44. logger.error(
  45. "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
  46. "https://www.tensorflow.org/install/ for installation instructions."
  47. )
  48. raise
  49. tf_path = os.path.abspath(tf_checkpoint_path)
  50. logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
  51. # Load weights from TF model
  52. init_vars = tf.train.list_variables(tf_path)
  53. names = []
  54. arrays = []
  55. for name, shape in init_vars:
  56. logger.info(f"Loading TF weight {name} with shape {shape}")
  57. array = tf.train.load_variable(tf_path, name)
  58. names.append(name)
  59. arrays.append(array)
  60. _layer_map = {
  61. "k": "k_head",
  62. "q": "q_head",
  63. "v": "v_head",
  64. "o": "post_proj",
  65. "layer_1": "linear_1",
  66. "layer_2": "linear_2",
  67. "rel_attn": "attention",
  68. "ff": "ffn",
  69. "kernel": "weight",
  70. "gamma": "weight",
  71. "beta": "bias",
  72. "lookup_table": "weight",
  73. "word_embedding": "word_embeddings",
  74. "input": "embeddings",
  75. }
  76. for name, array in zip(names, arrays):
  77. name = name.split("/")
  78. # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  79. # which are not required for using pretrained model
  80. if any(
  81. n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
  82. for n in name
  83. ):
  84. logger.info(f"Skipping {'/'.join(name)}")
  85. continue
  86. if name[0] == "generator":
  87. continue
  88. pointer = model
  89. skipped = False
  90. for m_name in name[1:]:
  91. if not isinstance(pointer, FunnelPositionwiseFFN) and re.fullmatch(r"layer_\d+", m_name):
  92. layer_index = int(re.search(r"layer_(\d+)", m_name).groups()[0])
  93. if layer_index < config.num_hidden_layers:
  94. block_idx = 0
  95. while layer_index >= config.block_sizes[block_idx]:
  96. layer_index -= config.block_sizes[block_idx]
  97. block_idx += 1
  98. pointer = pointer.blocks[block_idx][layer_index]
  99. else:
  100. layer_index -= config.num_hidden_layers
  101. pointer = pointer.layers[layer_index]
  102. elif m_name == "r" and isinstance(pointer, FunnelRelMultiheadAttention):
  103. pointer = pointer.r_kernel
  104. break
  105. elif m_name in _layer_map:
  106. pointer = getattr(pointer, _layer_map[m_name])
  107. else:
  108. try:
  109. pointer = getattr(pointer, m_name)
  110. except AttributeError:
  111. print(f"Skipping {'/'.join(name)}", array.shape)
  112. skipped = True
  113. break
  114. if not skipped:
  115. if len(pointer.shape) != len(array.shape):
  116. array = array.reshape(pointer.shape)
  117. if m_name == "kernel":
  118. array = np.transpose(array)
  119. pointer.data = torch.from_numpy(array)
  120. return model
  121. class FunnelEmbeddings(nn.Module):
  122. def __init__(self, config: FunnelConfig) -> None:
  123. super().__init__()
  124. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  125. self.layer_norm = nn.LayerNorm(config.d_model, eps=config.layer_norm_eps)
  126. self.dropout = nn.Dropout(config.hidden_dropout)
  127. def forward(
  128. self, input_ids: Optional[torch.Tensor] = None, inputs_embeds: Optional[torch.Tensor] = None
  129. ) -> torch.Tensor:
  130. if inputs_embeds is None:
  131. inputs_embeds = self.word_embeddings(input_ids)
  132. embeddings = self.layer_norm(inputs_embeds)
  133. embeddings = self.dropout(embeddings)
  134. return embeddings
  135. class FunnelAttentionStructure(nn.Module):
  136. """
  137. Contains helpers for `FunnelRelMultiheadAttention `.
  138. """
  139. cls_token_type_id: int = 2
  140. def __init__(self, config: FunnelConfig) -> None:
  141. super().__init__()
  142. self.config = config
  143. self.sin_dropout = nn.Dropout(config.hidden_dropout)
  144. self.cos_dropout = nn.Dropout(config.hidden_dropout)
  145. # Track where we are at in terms of pooling from the original input, e.g., by how much the sequence length was
  146. # divided.
  147. self.pooling_mult = None
  148. def init_attention_inputs(
  149. self,
  150. inputs_embeds: torch.Tensor,
  151. attention_mask: Optional[torch.Tensor] = None,
  152. token_type_ids: Optional[torch.Tensor] = None,
  153. ) -> tuple[torch.Tensor]:
  154. """Returns the attention inputs associated to the inputs of the model."""
  155. # inputs_embeds has shape batch_size x seq_len x d_model
  156. # attention_mask and token_type_ids have shape batch_size x seq_len
  157. self.pooling_mult = 1
  158. self.seq_len = seq_len = inputs_embeds.size(1)
  159. position_embeds = self.get_position_embeds(seq_len, inputs_embeds.dtype, inputs_embeds.device)
  160. token_type_mat = self.token_type_ids_to_mat(token_type_ids) if token_type_ids is not None else None
  161. cls_mask = (
  162. nn.functional.pad(inputs_embeds.new_ones([seq_len - 1, seq_len - 1]), (1, 0, 1, 0))
  163. if self.config.separate_cls
  164. else None
  165. )
  166. return (position_embeds, token_type_mat, attention_mask, cls_mask)
  167. def token_type_ids_to_mat(self, token_type_ids: torch.Tensor) -> torch.Tensor:
  168. """Convert `token_type_ids` to `token_type_mat`."""
  169. token_type_mat = token_type_ids[:, :, None] == token_type_ids[:, None]
  170. # Treat <cls> as in the same segment as both A & B
  171. cls_ids = token_type_ids == self.cls_token_type_id
  172. cls_mat = cls_ids[:, :, None] | cls_ids[:, None]
  173. return cls_mat | token_type_mat
  174. def get_position_embeds(
  175. self, seq_len: int, dtype: torch.dtype, device: torch.device
  176. ) -> Union[tuple[torch.Tensor], list[list[torch.Tensor]]]:
  177. """
  178. Create and cache inputs related to relative position encoding. Those are very different depending on whether we
  179. are using the factorized or the relative shift attention:
  180. For the factorized attention, it returns the matrices (phi, pi, psi, omega) used in the paper, appendix A.2.2,
  181. final formula.
  182. For the relative shift attention, it returns all possible vectors R used in the paper, appendix A.2.1, final
  183. formula.
  184. Paper link: https://huggingface.co/papers/2006.03236
  185. """
  186. d_model = self.config.d_model
  187. if self.config.attention_type == "factorized":
  188. # Notations from the paper, appending A.2.2, final formula.
  189. # We need to create and return the matrices phi, psi, pi and omega.
  190. pos_seq = torch.arange(0, seq_len, 1.0, dtype=torch.int64, device=device).to(dtype)
  191. freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=torch.int64, device=device).to(dtype)
  192. inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
  193. sinusoid = pos_seq[:, None] * inv_freq[None]
  194. sin_embed = torch.sin(sinusoid)
  195. sin_embed_d = self.sin_dropout(sin_embed)
  196. cos_embed = torch.cos(sinusoid)
  197. cos_embed_d = self.cos_dropout(cos_embed)
  198. # This is different from the formula on the paper...
  199. phi = torch.cat([sin_embed_d, sin_embed_d], dim=-1)
  200. psi = torch.cat([cos_embed, sin_embed], dim=-1)
  201. pi = torch.cat([cos_embed_d, cos_embed_d], dim=-1)
  202. omega = torch.cat([-sin_embed, cos_embed], dim=-1)
  203. return (phi, pi, psi, omega)
  204. else:
  205. # Notations from the paper, appending A.2.1, final formula.
  206. # We need to create and return all the possible vectors R for all blocks and shifts.
  207. freq_seq = torch.arange(0, d_model // 2, 1.0, dtype=torch.int64, device=device).to(dtype)
  208. inv_freq = 1 / (10000 ** (freq_seq / (d_model // 2)))
  209. # Maximum relative positions for the first input
  210. rel_pos_id = torch.arange(-seq_len * 2, seq_len * 2, 1.0, dtype=torch.int64, device=device).to(dtype)
  211. zero_offset = seq_len * 2
  212. sinusoid = rel_pos_id[:, None] * inv_freq[None]
  213. sin_embed = self.sin_dropout(torch.sin(sinusoid))
  214. cos_embed = self.cos_dropout(torch.cos(sinusoid))
  215. pos_embed = torch.cat([sin_embed, cos_embed], dim=-1)
  216. pos = torch.arange(0, seq_len, dtype=torch.int64, device=device).to(dtype)
  217. pooled_pos = pos
  218. position_embeds_list = []
  219. for block_index in range(0, self.config.num_blocks):
  220. # For each block with block_index > 0, we need two types position embeddings:
  221. # - Attention(pooled-q, unpooled-kv)
  222. # - Attention(pooled-q, pooled-kv)
  223. # For block_index = 0 we only need the second one and leave the first one as None.
  224. # First type
  225. if block_index == 0:
  226. position_embeds_pooling = None
  227. else:
  228. pooled_pos = self.stride_pool_pos(pos, block_index)
  229. # construct rel_pos_id
  230. stride = 2 ** (block_index - 1)
  231. rel_pos = self.relative_pos(pos, stride, pooled_pos, shift=2)
  232. rel_pos = rel_pos[:, None] + zero_offset
  233. rel_pos = rel_pos.expand(rel_pos.size(0), d_model)
  234. position_embeds_pooling = torch.gather(pos_embed, 0, rel_pos)
  235. # Second type
  236. pos = pooled_pos
  237. stride = 2**block_index
  238. rel_pos = self.relative_pos(pos, stride)
  239. rel_pos = rel_pos[:, None] + zero_offset
  240. rel_pos = rel_pos.expand(rel_pos.size(0), d_model)
  241. position_embeds_no_pooling = torch.gather(pos_embed, 0, rel_pos)
  242. position_embeds_list.append([position_embeds_no_pooling, position_embeds_pooling])
  243. return position_embeds_list
  244. def stride_pool_pos(self, pos_id: torch.Tensor, block_index: int):
  245. """
  246. Pool `pos_id` while keeping the cls token separate (if `config.separate_cls=True`).
  247. """
  248. if self.config.separate_cls:
  249. # Under separate <cls>, we treat the <cls> as the first token in
  250. # the previous block of the 1st real block. Since the 1st real
  251. # block always has position 1, the position of the previous block
  252. # will be at `1 - 2 ** block_index`.
  253. cls_pos = pos_id.new_tensor([-(2**block_index) + 1])
  254. pooled_pos_id = pos_id[1:-1] if self.config.truncate_seq else pos_id[1:]
  255. return torch.cat([cls_pos, pooled_pos_id[::2]], 0)
  256. else:
  257. return pos_id[::2]
  258. def relative_pos(self, pos: torch.Tensor, stride: int, pooled_pos=None, shift: int = 1) -> torch.Tensor:
  259. """
  260. Build the relative positional vector between `pos` and `pooled_pos`.
  261. """
  262. if pooled_pos is None:
  263. pooled_pos = pos
  264. ref_point = pooled_pos[0] - pos[0]
  265. num_remove = shift * len(pooled_pos)
  266. max_dist = ref_point + num_remove * stride
  267. min_dist = pooled_pos[0] - pos[-1]
  268. return torch.arange(max_dist, min_dist - 1, -stride, dtype=torch.long, device=pos.device)
  269. def stride_pool(
  270. self,
  271. tensor: Union[torch.Tensor, tuple[torch.Tensor], list[torch.Tensor]],
  272. axis: Union[int, tuple[int], list[int]],
  273. ) -> torch.Tensor:
  274. """
  275. Perform pooling by stride slicing the tensor along the given axis.
  276. """
  277. if tensor is None:
  278. return None
  279. # Do the stride pool recursively if axis is a list or a tuple of ints.
  280. if isinstance(axis, (list, tuple)):
  281. for ax in axis:
  282. tensor = self.stride_pool(tensor, ax)
  283. return tensor
  284. # Do the stride pool recursively if tensor is a list or tuple of tensors.
  285. if isinstance(tensor, (tuple, list)):
  286. return type(tensor)(self.stride_pool(x, axis) for x in tensor)
  287. # Deal with negative axis
  288. axis %= tensor.ndim
  289. axis_slice = (
  290. slice(None, -1, 2) if self.config.separate_cls and self.config.truncate_seq else slice(None, None, 2)
  291. )
  292. enc_slice = [slice(None)] * axis + [axis_slice]
  293. if self.config.separate_cls:
  294. cls_slice = [slice(None)] * axis + [slice(None, 1)]
  295. tensor = torch.cat([tensor[cls_slice], tensor], axis=axis)
  296. return tensor[enc_slice]
  297. def pool_tensor(
  298. self, tensor: Union[torch.Tensor, tuple[torch.Tensor], list[torch.Tensor]], mode: str = "mean", stride: int = 2
  299. ) -> torch.Tensor:
  300. """Apply 1D pooling to a tensor of size [B x T (x H)]."""
  301. if tensor is None:
  302. return None
  303. # Do the pool recursively if tensor is a list or tuple of tensors.
  304. if isinstance(tensor, (tuple, list)):
  305. return type(tensor)(self.pool_tensor(tensor, mode=mode, stride=stride) for x in tensor)
  306. if self.config.separate_cls:
  307. suffix = tensor[:, :-1] if self.config.truncate_seq else tensor
  308. tensor = torch.cat([tensor[:, :1], suffix], dim=1)
  309. ndim = tensor.ndim
  310. if ndim == 2:
  311. tensor = tensor[:, None, :, None]
  312. elif ndim == 3:
  313. tensor = tensor[:, None, :, :]
  314. # Stride is applied on the second-to-last dimension.
  315. stride = (stride, 1)
  316. if mode == "mean":
  317. tensor = nn.functional.avg_pool2d(tensor, stride, stride=stride, ceil_mode=True)
  318. elif mode == "max":
  319. tensor = nn.functional.max_pool2d(tensor, stride, stride=stride, ceil_mode=True)
  320. elif mode == "min":
  321. tensor = -nn.functional.max_pool2d(-tensor, stride, stride=stride, ceil_mode=True)
  322. else:
  323. raise NotImplementedError("The supported modes are 'mean', 'max' and 'min'.")
  324. if ndim == 2:
  325. return tensor[:, 0, :, 0]
  326. elif ndim == 3:
  327. return tensor[:, 0]
  328. return tensor
  329. def pre_attention_pooling(
  330. self, output, attention_inputs: tuple[torch.Tensor]
  331. ) -> tuple[torch.Tensor, tuple[torch.Tensor]]:
  332. """Pool `output` and the proper parts of `attention_inputs` before the attention layer."""
  333. position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
  334. if self.config.pool_q_only:
  335. if self.config.attention_type == "factorized":
  336. position_embeds = self.stride_pool(position_embeds[:2], 0) + position_embeds[2:]
  337. token_type_mat = self.stride_pool(token_type_mat, 1)
  338. cls_mask = self.stride_pool(cls_mask, 0)
  339. output = self.pool_tensor(output, mode=self.config.pooling_type)
  340. else:
  341. self.pooling_mult *= 2
  342. if self.config.attention_type == "factorized":
  343. position_embeds = self.stride_pool(position_embeds, 0)
  344. token_type_mat = self.stride_pool(token_type_mat, [1, 2])
  345. cls_mask = self.stride_pool(cls_mask, [1, 2])
  346. attention_mask = self.pool_tensor(attention_mask, mode="min")
  347. output = self.pool_tensor(output, mode=self.config.pooling_type)
  348. attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
  349. return output, attention_inputs
  350. def post_attention_pooling(self, attention_inputs: tuple[torch.Tensor]) -> tuple[torch.Tensor]:
  351. """Pool the proper parts of `attention_inputs` after the attention layer."""
  352. position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
  353. if self.config.pool_q_only:
  354. self.pooling_mult *= 2
  355. if self.config.attention_type == "factorized":
  356. position_embeds = position_embeds[:2] + self.stride_pool(position_embeds[2:], 0)
  357. token_type_mat = self.stride_pool(token_type_mat, 2)
  358. cls_mask = self.stride_pool(cls_mask, 1)
  359. attention_mask = self.pool_tensor(attention_mask, mode="min")
  360. attention_inputs = (position_embeds, token_type_mat, attention_mask, cls_mask)
  361. return attention_inputs
  362. def _relative_shift_gather(positional_attn: torch.Tensor, context_len: int, shift: int) -> torch.Tensor:
  363. batch_size, n_head, seq_len, max_rel_len = positional_attn.shape
  364. # max_rel_len = 2 * context_len + shift -1 is the numbers of possible relative positions i-j
  365. # What's next is the same as doing the following gather, which might be clearer code but less efficient.
  366. # idxs = context_len + torch.arange(0, context_len).unsqueeze(0) - torch.arange(0, seq_len).unsqueeze(1)
  367. # # matrix of context_len + i-j
  368. # return positional_attn.gather(3, idxs.expand([batch_size, n_head, context_len, context_len]))
  369. positional_attn = torch.reshape(positional_attn, [batch_size, n_head, max_rel_len, seq_len])
  370. positional_attn = positional_attn[:, :, shift:, :]
  371. positional_attn = torch.reshape(positional_attn, [batch_size, n_head, seq_len, max_rel_len - shift])
  372. positional_attn = positional_attn[..., :context_len]
  373. return positional_attn
  374. class FunnelRelMultiheadAttention(nn.Module):
  375. def __init__(self, config: FunnelConfig, block_index: int) -> None:
  376. super().__init__()
  377. self.config = config
  378. self.block_index = block_index
  379. d_model, n_head, d_head = config.d_model, config.n_head, config.d_head
  380. self.hidden_dropout = nn.Dropout(config.hidden_dropout)
  381. self.attention_dropout = nn.Dropout(config.attention_dropout)
  382. self.q_head = nn.Linear(d_model, n_head * d_head, bias=False)
  383. self.k_head = nn.Linear(d_model, n_head * d_head)
  384. self.v_head = nn.Linear(d_model, n_head * d_head)
  385. self.r_w_bias = nn.Parameter(torch.zeros([n_head, d_head]))
  386. self.r_r_bias = nn.Parameter(torch.zeros([n_head, d_head]))
  387. self.r_kernel = nn.Parameter(torch.zeros([d_model, n_head, d_head]))
  388. self.r_s_bias = nn.Parameter(torch.zeros([n_head, d_head]))
  389. self.seg_embed = nn.Parameter(torch.zeros([2, n_head, d_head]))
  390. self.post_proj = nn.Linear(n_head * d_head, d_model)
  391. self.layer_norm = nn.LayerNorm(d_model, eps=config.layer_norm_eps)
  392. self.scale = 1.0 / (d_head**0.5)
  393. def relative_positional_attention(self, position_embeds, q_head, context_len, cls_mask=None):
  394. """Relative attention score for the positional encodings"""
  395. # q_head has shape batch_size x sea_len x n_head x d_head
  396. if self.config.attention_type == "factorized":
  397. # Notations from the paper, appending A.2.2, final formula (https://huggingface.co/papers/2006.03236)
  398. # phi and pi have shape seq_len x d_model, psi and omega have shape context_len x d_model
  399. phi, pi, psi, omega = position_embeds
  400. # Shape n_head x d_head
  401. u = self.r_r_bias * self.scale
  402. # Shape d_model x n_head x d_head
  403. w_r = self.r_kernel
  404. # Shape batch_size x sea_len x n_head x d_model
  405. q_r_attention = torch.einsum("binh,dnh->bind", q_head + u, w_r)
  406. q_r_attention_1 = q_r_attention * phi[:, None]
  407. q_r_attention_2 = q_r_attention * pi[:, None]
  408. # Shape batch_size x n_head x seq_len x context_len
  409. positional_attn = torch.einsum("bind,jd->bnij", q_r_attention_1, psi) + torch.einsum(
  410. "bind,jd->bnij", q_r_attention_2, omega
  411. )
  412. else:
  413. shift = 2 if q_head.shape[1] != context_len else 1
  414. # Notations from the paper, appending A.2.1, final formula (https://huggingface.co/papers/2006.03236)
  415. # Grab the proper positional encoding, shape max_rel_len x d_model
  416. r = position_embeds[self.block_index][shift - 1]
  417. # Shape n_head x d_head
  418. v = self.r_r_bias * self.scale
  419. # Shape d_model x n_head x d_head
  420. w_r = self.r_kernel
  421. # Shape max_rel_len x n_head x d_model
  422. r_head = torch.einsum("td,dnh->tnh", r, w_r)
  423. # Shape batch_size x n_head x seq_len x max_rel_len
  424. positional_attn = torch.einsum("binh,tnh->bnit", q_head + v, r_head)
  425. # Shape batch_size x n_head x seq_len x context_len
  426. positional_attn = _relative_shift_gather(positional_attn, context_len, shift)
  427. if cls_mask is not None:
  428. positional_attn *= cls_mask
  429. return positional_attn
  430. def relative_token_type_attention(self, token_type_mat, q_head, cls_mask=None):
  431. """Relative attention score for the token_type_ids"""
  432. if token_type_mat is None:
  433. return 0
  434. batch_size, seq_len, context_len = token_type_mat.shape
  435. # q_head has shape batch_size x seq_len x n_head x d_head
  436. # Shape n_head x d_head
  437. r_s_bias = self.r_s_bias * self.scale
  438. # Shape batch_size x n_head x seq_len x 2
  439. token_type_bias = torch.einsum("bind,snd->bnis", q_head + r_s_bias, self.seg_embed)
  440. # Shape batch_size x n_head x seq_len x context_len
  441. token_type_mat = token_type_mat[:, None].expand([batch_size, q_head.shape[2], seq_len, context_len])
  442. # Shapes batch_size x n_head x seq_len
  443. diff_token_type, same_token_type = torch.split(token_type_bias, 1, dim=-1)
  444. # Shape batch_size x n_head x seq_len x context_len
  445. token_type_attn = torch.where(
  446. token_type_mat, same_token_type.expand(token_type_mat.shape), diff_token_type.expand(token_type_mat.shape)
  447. )
  448. if cls_mask is not None:
  449. token_type_attn *= cls_mask
  450. return token_type_attn
  451. def forward(
  452. self,
  453. query: torch.Tensor,
  454. key: torch.Tensor,
  455. value: torch.Tensor,
  456. attention_inputs: tuple[torch.Tensor],
  457. output_attentions: bool = False,
  458. ) -> tuple[torch.Tensor, ...]:
  459. # query has shape batch_size x seq_len x d_model
  460. # key and value have shapes batch_size x context_len x d_model
  461. position_embeds, token_type_mat, attention_mask, cls_mask = attention_inputs
  462. batch_size, seq_len, _ = query.shape
  463. context_len = key.shape[1]
  464. n_head, d_head = self.config.n_head, self.config.d_head
  465. # Shape batch_size x seq_len x n_head x d_head
  466. q_head = self.q_head(query).view(batch_size, seq_len, n_head, d_head)
  467. # Shapes batch_size x context_len x n_head x d_head
  468. k_head = self.k_head(key).view(batch_size, context_len, n_head, d_head)
  469. v_head = self.v_head(value).view(batch_size, context_len, n_head, d_head)
  470. q_head = q_head * self.scale
  471. # Shape n_head x d_head
  472. r_w_bias = self.r_w_bias * self.scale
  473. # Shapes batch_size x n_head x seq_len x context_len
  474. content_score = torch.einsum("bind,bjnd->bnij", q_head + r_w_bias, k_head)
  475. positional_attn = self.relative_positional_attention(position_embeds, q_head, context_len, cls_mask)
  476. token_type_attn = self.relative_token_type_attention(token_type_mat, q_head, cls_mask)
  477. # merge attention scores
  478. attn_score = content_score + positional_attn + token_type_attn
  479. # precision safe in case of mixed precision training
  480. dtype = attn_score.dtype
  481. attn_score = attn_score.float()
  482. # perform masking
  483. if attention_mask is not None:
  484. attn_score = attn_score - INF * (1 - attention_mask[:, None, None].float())
  485. # attention probability
  486. attn_prob = torch.softmax(attn_score, dim=-1, dtype=dtype)
  487. attn_prob = self.attention_dropout(attn_prob)
  488. # attention output, shape batch_size x seq_len x n_head x d_head
  489. attn_vec = torch.einsum("bnij,bjnd->bind", attn_prob, v_head)
  490. # Shape shape batch_size x seq_len x d_model
  491. attn_out = self.post_proj(attn_vec.reshape(batch_size, seq_len, n_head * d_head))
  492. attn_out = self.hidden_dropout(attn_out)
  493. output = self.layer_norm(query + attn_out)
  494. return (output, attn_prob) if output_attentions else (output,)
  495. class FunnelPositionwiseFFN(nn.Module):
  496. def __init__(self, config: FunnelConfig) -> None:
  497. super().__init__()
  498. self.linear_1 = nn.Linear(config.d_model, config.d_inner)
  499. self.activation_function = ACT2FN[config.hidden_act]
  500. self.activation_dropout = nn.Dropout(config.activation_dropout)
  501. self.linear_2 = nn.Linear(config.d_inner, config.d_model)
  502. self.dropout = nn.Dropout(config.hidden_dropout)
  503. self.layer_norm = nn.LayerNorm(config.d_model, config.layer_norm_eps)
  504. def forward(self, hidden: torch.Tensor) -> torch.Tensor:
  505. h = self.linear_1(hidden)
  506. h = self.activation_function(h)
  507. h = self.activation_dropout(h)
  508. h = self.linear_2(h)
  509. h = self.dropout(h)
  510. return self.layer_norm(hidden + h)
  511. class FunnelLayer(nn.Module):
  512. def __init__(self, config: FunnelConfig, block_index: int) -> None:
  513. super().__init__()
  514. self.attention = FunnelRelMultiheadAttention(config, block_index)
  515. self.ffn = FunnelPositionwiseFFN(config)
  516. def forward(
  517. self,
  518. query: torch.Tensor,
  519. key: torch.Tensor,
  520. value: torch.Tensor,
  521. attention_inputs,
  522. output_attentions: bool = False,
  523. ) -> tuple:
  524. attn = self.attention(query, key, value, attention_inputs, output_attentions=output_attentions)
  525. output = self.ffn(attn[0])
  526. return (output, attn[1]) if output_attentions else (output,)
  527. class FunnelEncoder(nn.Module):
  528. def __init__(self, config: FunnelConfig) -> None:
  529. super().__init__()
  530. self.config = config
  531. self.attention_structure = FunnelAttentionStructure(config)
  532. self.blocks = nn.ModuleList(
  533. [
  534. nn.ModuleList([FunnelLayer(config, block_index) for _ in range(block_size)])
  535. for block_index, block_size in enumerate(config.block_sizes)
  536. ]
  537. )
  538. def forward(
  539. self,
  540. inputs_embeds: torch.Tensor,
  541. attention_mask: Optional[torch.Tensor] = None,
  542. token_type_ids: Optional[torch.Tensor] = None,
  543. output_attentions: bool = False,
  544. output_hidden_states: bool = False,
  545. return_dict: bool = True,
  546. ) -> Union[tuple, BaseModelOutput]:
  547. # The pooling is not implemented on long tensors, so we convert this mask.
  548. attention_mask = attention_mask.type_as(inputs_embeds)
  549. attention_inputs = self.attention_structure.init_attention_inputs(
  550. inputs_embeds,
  551. attention_mask=attention_mask,
  552. token_type_ids=token_type_ids,
  553. )
  554. hidden = inputs_embeds
  555. all_hidden_states = (inputs_embeds,) if output_hidden_states else None
  556. all_attentions = () if output_attentions else None
  557. for block_index, block in enumerate(self.blocks):
  558. pooling_flag = hidden.size(1) > (2 if self.config.separate_cls else 1)
  559. pooling_flag = pooling_flag and block_index > 0
  560. if pooling_flag:
  561. pooled_hidden, attention_inputs = self.attention_structure.pre_attention_pooling(
  562. hidden, attention_inputs
  563. )
  564. for layer_index, layer in enumerate(block):
  565. for repeat_index in range(self.config.block_repeats[block_index]):
  566. do_pooling = (repeat_index == 0) and (layer_index == 0) and pooling_flag
  567. if do_pooling:
  568. query = pooled_hidden
  569. key = value = hidden if self.config.pool_q_only else pooled_hidden
  570. else:
  571. query = key = value = hidden
  572. layer_output = layer(query, key, value, attention_inputs, output_attentions=output_attentions)
  573. hidden = layer_output[0]
  574. if do_pooling:
  575. attention_inputs = self.attention_structure.post_attention_pooling(attention_inputs)
  576. if output_attentions:
  577. all_attentions = all_attentions + layer_output[1:]
  578. if output_hidden_states:
  579. all_hidden_states = all_hidden_states + (hidden,)
  580. if not return_dict:
  581. return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None)
  582. return BaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions)
  583. def upsample(
  584. x: torch.Tensor, stride: int, target_len: int, separate_cls: bool = True, truncate_seq: bool = False
  585. ) -> torch.Tensor:
  586. """
  587. Upsample tensor `x` to match `target_len` by repeating the tokens `stride` time on the sequence length dimension.
  588. """
  589. if stride == 1:
  590. return x
  591. if separate_cls:
  592. cls = x[:, :1]
  593. x = x[:, 1:]
  594. output = torch.repeat_interleave(x, repeats=stride, dim=1)
  595. if separate_cls:
  596. if truncate_seq:
  597. output = nn.functional.pad(output, (0, 0, 0, stride - 1, 0, 0))
  598. output = output[:, : target_len - 1]
  599. output = torch.cat([cls, output], dim=1)
  600. else:
  601. output = output[:, :target_len]
  602. return output
  603. class FunnelDecoder(nn.Module):
  604. def __init__(self, config: FunnelConfig) -> None:
  605. super().__init__()
  606. self.config = config
  607. self.attention_structure = FunnelAttentionStructure(config)
  608. self.layers = nn.ModuleList([FunnelLayer(config, 0) for _ in range(config.num_decoder_layers)])
  609. def forward(
  610. self,
  611. final_hidden: torch.Tensor,
  612. first_block_hidden: torch.Tensor,
  613. attention_mask: Optional[torch.Tensor] = None,
  614. token_type_ids: Optional[torch.Tensor] = None,
  615. output_attentions: bool = False,
  616. output_hidden_states: bool = False,
  617. return_dict: bool = True,
  618. ) -> Union[tuple, BaseModelOutput]:
  619. upsampled_hidden = upsample(
  620. final_hidden,
  621. stride=2 ** (len(self.config.block_sizes) - 1),
  622. target_len=first_block_hidden.shape[1],
  623. separate_cls=self.config.separate_cls,
  624. truncate_seq=self.config.truncate_seq,
  625. )
  626. hidden = upsampled_hidden + first_block_hidden
  627. all_hidden_states = (hidden,) if output_hidden_states else None
  628. all_attentions = () if output_attentions else None
  629. attention_inputs = self.attention_structure.init_attention_inputs(
  630. hidden,
  631. attention_mask=attention_mask,
  632. token_type_ids=token_type_ids,
  633. )
  634. for layer in self.layers:
  635. layer_output = layer(hidden, hidden, hidden, attention_inputs, output_attentions=output_attentions)
  636. hidden = layer_output[0]
  637. if output_attentions:
  638. all_attentions = all_attentions + layer_output[1:]
  639. if output_hidden_states:
  640. all_hidden_states = all_hidden_states + (hidden,)
  641. if not return_dict:
  642. return tuple(v for v in [hidden, all_hidden_states, all_attentions] if v is not None)
  643. return BaseModelOutput(last_hidden_state=hidden, hidden_states=all_hidden_states, attentions=all_attentions)
  644. class FunnelDiscriminatorPredictions(nn.Module):
  645. """Prediction module for the discriminator, made up of two dense layers."""
  646. def __init__(self, config: FunnelConfig) -> None:
  647. super().__init__()
  648. self.config = config
  649. self.dense = nn.Linear(config.d_model, config.d_model)
  650. self.dense_prediction = nn.Linear(config.d_model, 1)
  651. def forward(self, discriminator_hidden_states: torch.Tensor) -> torch.Tensor:
  652. hidden_states = self.dense(discriminator_hidden_states)
  653. hidden_states = ACT2FN[self.config.hidden_act](hidden_states)
  654. logits = self.dense_prediction(hidden_states).squeeze(-1)
  655. return logits
  656. @auto_docstring
  657. class FunnelPreTrainedModel(PreTrainedModel):
  658. config: FunnelConfig
  659. load_tf_weights = load_tf_weights_in_funnel
  660. base_model_prefix = "funnel"
  661. def _init_weights(self, module):
  662. classname = module.__class__.__name__
  663. if classname.find("Linear") != -1:
  664. if getattr(module, "weight", None) is not None:
  665. if self.config.initializer_std is None:
  666. fan_out, fan_in = module.weight.shape
  667. std = np.sqrt(1.0 / float(fan_in + fan_out))
  668. else:
  669. std = self.config.initializer_std
  670. nn.init.normal_(module.weight, std=std)
  671. if getattr(module, "bias", None) is not None:
  672. nn.init.constant_(module.bias, 0.0)
  673. elif classname == "FunnelRelMultiheadAttention":
  674. nn.init.uniform_(module.r_w_bias, b=self.config.initializer_range)
  675. nn.init.uniform_(module.r_r_bias, b=self.config.initializer_range)
  676. nn.init.uniform_(module.r_kernel, b=self.config.initializer_range)
  677. nn.init.uniform_(module.r_s_bias, b=self.config.initializer_range)
  678. nn.init.uniform_(module.seg_embed, b=self.config.initializer_range)
  679. elif classname == "FunnelEmbeddings":
  680. std = 1.0 if self.config.initializer_std is None else self.config.initializer_std
  681. nn.init.normal_(module.word_embeddings.weight, std=std)
  682. if module.word_embeddings.padding_idx is not None:
  683. module.word_embeddings.weight.data[module.word_embeddings.padding_idx].zero_()
  684. class FunnelClassificationHead(nn.Module):
  685. def __init__(self, config: FunnelConfig, n_labels: int) -> None:
  686. super().__init__()
  687. self.linear_hidden = nn.Linear(config.d_model, config.d_model)
  688. self.dropout = nn.Dropout(config.hidden_dropout)
  689. self.linear_out = nn.Linear(config.d_model, n_labels)
  690. def forward(self, hidden: torch.Tensor) -> torch.Tensor:
  691. hidden = self.linear_hidden(hidden)
  692. hidden = torch.tanh(hidden)
  693. hidden = self.dropout(hidden)
  694. return self.linear_out(hidden)
@dataclass
@auto_docstring(
    custom_intro="""
    Output type of [`FunnelForPreTraining`].
    """
)
class FunnelForPreTrainingOutput(ModelOutput):
    r"""
    loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
        Total loss of the ELECTRA-style objective.
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
        Prediction scores of the head (one raw score per token, before the sigmoid applied by the
        binary cross-entropy-with-logits loss).
    """

    # Set only when `labels` is passed to `FunnelForPreTraining.forward`.
    loss: Optional[torch.FloatTensor] = None
    # Output of the discriminator prediction head.
    logits: Optional[torch.FloatTensor] = None
    # Copied from the wrapped `FunnelModel` output; `None` when hidden states were not requested.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    # Copied from the wrapped `FunnelModel` output; `None` when attentions were not requested.
    attentions: Optional[tuple[torch.FloatTensor]] = None
  712. @auto_docstring(
  713. custom_intro="""
  714. The base Funnel Transformer Model transformer outputting raw hidden-states without upsampling head (also called
  715. decoder) or any task-specific head on top.
  716. """
  717. )
  718. class FunnelBaseModel(FunnelPreTrainedModel):
  719. def __init__(self, config: FunnelConfig) -> None:
  720. super().__init__(config)
  721. self.embeddings = FunnelEmbeddings(config)
  722. self.encoder = FunnelEncoder(config)
  723. # Initialize weights and apply final processing
  724. self.post_init()
  725. def get_input_embeddings(self) -> nn.Embedding:
  726. return self.embeddings.word_embeddings
  727. def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
  728. self.embeddings.word_embeddings = new_embeddings
  729. @auto_docstring
  730. def forward(
  731. self,
  732. input_ids: Optional[torch.Tensor] = None,
  733. attention_mask: Optional[torch.Tensor] = None,
  734. token_type_ids: Optional[torch.Tensor] = None,
  735. position_ids: Optional[torch.Tensor] = None,
  736. head_mask: Optional[torch.Tensor] = None,
  737. inputs_embeds: Optional[torch.Tensor] = None,
  738. output_attentions: Optional[bool] = None,
  739. output_hidden_states: Optional[bool] = None,
  740. return_dict: Optional[bool] = None,
  741. ) -> Union[tuple, BaseModelOutput]:
  742. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  743. output_hidden_states = (
  744. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  745. )
  746. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  747. if input_ids is not None and inputs_embeds is not None:
  748. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  749. elif input_ids is not None:
  750. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  751. input_shape = input_ids.size()
  752. elif inputs_embeds is not None:
  753. input_shape = inputs_embeds.size()[:-1]
  754. else:
  755. raise ValueError("You have to specify either input_ids or inputs_embeds")
  756. device = input_ids.device if input_ids is not None else inputs_embeds.device
  757. if attention_mask is None:
  758. attention_mask = torch.ones(input_shape, device=device)
  759. if token_type_ids is None:
  760. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  761. # TODO: deal with head_mask
  762. inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)
  763. encoder_outputs = self.encoder(
  764. inputs_embeds,
  765. attention_mask=attention_mask,
  766. token_type_ids=token_type_ids,
  767. output_attentions=output_attentions,
  768. output_hidden_states=output_hidden_states,
  769. return_dict=return_dict,
  770. )
  771. return encoder_outputs
  772. @auto_docstring
  773. class FunnelModel(FunnelPreTrainedModel):
  774. def __init__(self, config: FunnelConfig) -> None:
  775. super().__init__(config)
  776. self.config = config
  777. self.embeddings = FunnelEmbeddings(config)
  778. self.encoder = FunnelEncoder(config)
  779. self.decoder = FunnelDecoder(config)
  780. # Initialize weights and apply final processing
  781. self.post_init()
  782. def get_input_embeddings(self) -> nn.Embedding:
  783. return self.embeddings.word_embeddings
  784. def set_input_embeddings(self, new_embeddings: nn.Embedding) -> None:
  785. self.embeddings.word_embeddings = new_embeddings
  786. @auto_docstring
  787. def forward(
  788. self,
  789. input_ids: Optional[torch.Tensor] = None,
  790. attention_mask: Optional[torch.Tensor] = None,
  791. token_type_ids: Optional[torch.Tensor] = None,
  792. inputs_embeds: Optional[torch.Tensor] = None,
  793. output_attentions: Optional[bool] = None,
  794. output_hidden_states: Optional[bool] = None,
  795. return_dict: Optional[bool] = None,
  796. ) -> Union[tuple, BaseModelOutput]:
  797. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  798. output_hidden_states = (
  799. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  800. )
  801. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  802. if input_ids is not None and inputs_embeds is not None:
  803. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  804. elif input_ids is not None:
  805. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  806. input_shape = input_ids.size()
  807. elif inputs_embeds is not None:
  808. input_shape = inputs_embeds.size()[:-1]
  809. else:
  810. raise ValueError("You have to specify either input_ids or inputs_embeds")
  811. device = input_ids.device if input_ids is not None else inputs_embeds.device
  812. if attention_mask is None:
  813. attention_mask = torch.ones(input_shape, device=device)
  814. if token_type_ids is None:
  815. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  816. # TODO: deal with head_mask
  817. inputs_embeds = self.embeddings(input_ids, inputs_embeds=inputs_embeds)
  818. encoder_outputs = self.encoder(
  819. inputs_embeds,
  820. attention_mask=attention_mask,
  821. token_type_ids=token_type_ids,
  822. output_attentions=output_attentions,
  823. output_hidden_states=True,
  824. return_dict=return_dict,
  825. )
  826. decoder_outputs = self.decoder(
  827. final_hidden=encoder_outputs[0],
  828. first_block_hidden=encoder_outputs[1][self.config.block_sizes[0]],
  829. attention_mask=attention_mask,
  830. token_type_ids=token_type_ids,
  831. output_attentions=output_attentions,
  832. output_hidden_states=output_hidden_states,
  833. return_dict=return_dict,
  834. )
  835. if not return_dict:
  836. idx = 0
  837. outputs = (decoder_outputs[0],)
  838. if output_hidden_states:
  839. idx += 1
  840. outputs = outputs + (encoder_outputs[1] + decoder_outputs[idx],)
  841. if output_attentions:
  842. idx += 1
  843. outputs = outputs + (encoder_outputs[2] + decoder_outputs[idx],)
  844. return outputs
  845. return BaseModelOutput(
  846. last_hidden_state=decoder_outputs[0],
  847. hidden_states=(encoder_outputs.hidden_states + decoder_outputs.hidden_states)
  848. if output_hidden_states
  849. else None,
  850. attentions=(encoder_outputs.attentions + decoder_outputs.attentions) if output_attentions else None,
  851. )
  852. @auto_docstring(
  853. custom_intro="""
  854. Funnel Transformer model with a binary classification head on top as used during pretraining for identifying
  855. generated tokens.
  856. """
  857. )
  858. class FunnelForPreTraining(FunnelPreTrainedModel):
  859. def __init__(self, config: FunnelConfig) -> None:
  860. super().__init__(config)
  861. self.funnel = FunnelModel(config)
  862. self.discriminator_predictions = FunnelDiscriminatorPredictions(config)
  863. # Initialize weights and apply final processing
  864. self.post_init()
  865. @auto_docstring
  866. def forward(
  867. self,
  868. input_ids: Optional[torch.Tensor] = None,
  869. attention_mask: Optional[torch.Tensor] = None,
  870. token_type_ids: Optional[torch.Tensor] = None,
  871. inputs_embeds: Optional[torch.Tensor] = None,
  872. labels: Optional[torch.Tensor] = None,
  873. output_attentions: Optional[bool] = None,
  874. output_hidden_states: Optional[bool] = None,
  875. return_dict: Optional[bool] = None,
  876. ) -> Union[tuple, FunnelForPreTrainingOutput]:
  877. r"""
  878. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  879. Labels for computing the ELECTRA-style loss. Input should be a sequence of tokens (see `input_ids`
  880. docstring) Indices should be in `[0, 1]`:
  881. - 0 indicates the token is an original token,
  882. - 1 indicates the token was replaced.
  883. Examples:
  884. ```python
  885. >>> from transformers import AutoTokenizer, FunnelForPreTraining
  886. >>> import torch
  887. >>> tokenizer = AutoTokenizer.from_pretrained("funnel-transformer/small")
  888. >>> model = FunnelForPreTraining.from_pretrained("funnel-transformer/small")
  889. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  890. >>> logits = model(**inputs).logits
  891. ```"""
  892. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  893. discriminator_hidden_states = self.funnel(
  894. input_ids,
  895. attention_mask=attention_mask,
  896. token_type_ids=token_type_ids,
  897. inputs_embeds=inputs_embeds,
  898. output_attentions=output_attentions,
  899. output_hidden_states=output_hidden_states,
  900. return_dict=return_dict,
  901. )
  902. discriminator_sequence_output = discriminator_hidden_states[0]
  903. logits = self.discriminator_predictions(discriminator_sequence_output)
  904. loss = None
  905. if labels is not None:
  906. loss_fct = nn.BCEWithLogitsLoss()
  907. if attention_mask is not None:
  908. active_loss = attention_mask.view(-1, discriminator_sequence_output.shape[1]) == 1
  909. active_logits = logits.view(-1, discriminator_sequence_output.shape[1])[active_loss]
  910. active_labels = labels[active_loss]
  911. loss = loss_fct(active_logits, active_labels.float())
  912. else:
  913. loss = loss_fct(logits.view(-1, discriminator_sequence_output.shape[1]), labels.float())
  914. if not return_dict:
  915. output = (logits,) + discriminator_hidden_states[1:]
  916. return ((loss,) + output) if loss is not None else output
  917. return FunnelForPreTrainingOutput(
  918. loss=loss,
  919. logits=logits,
  920. hidden_states=discriminator_hidden_states.hidden_states,
  921. attentions=discriminator_hidden_states.attentions,
  922. )
  923. @auto_docstring
  924. class FunnelForMaskedLM(FunnelPreTrainedModel):
  925. _tied_weights_keys = ["lm_head.weight"]
  926. def __init__(self, config: FunnelConfig) -> None:
  927. super().__init__(config)
  928. self.funnel = FunnelModel(config)
  929. self.lm_head = nn.Linear(config.d_model, config.vocab_size)
  930. # Initialize weights and apply final processing
  931. self.post_init()
  932. def get_output_embeddings(self) -> nn.Linear:
  933. return self.lm_head
  934. def set_output_embeddings(self, new_embeddings: nn.Embedding) -> None:
  935. self.lm_head = new_embeddings
  936. @auto_docstring
  937. def forward(
  938. self,
  939. input_ids: Optional[torch.Tensor] = None,
  940. attention_mask: Optional[torch.Tensor] = None,
  941. token_type_ids: Optional[torch.Tensor] = None,
  942. inputs_embeds: Optional[torch.Tensor] = None,
  943. labels: Optional[torch.Tensor] = None,
  944. output_attentions: Optional[bool] = None,
  945. output_hidden_states: Optional[bool] = None,
  946. return_dict: Optional[bool] = None,
  947. ) -> Union[tuple, MaskedLMOutput]:
  948. r"""
  949. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  950. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  951. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  952. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  953. """
  954. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  955. outputs = self.funnel(
  956. input_ids,
  957. attention_mask=attention_mask,
  958. token_type_ids=token_type_ids,
  959. inputs_embeds=inputs_embeds,
  960. output_attentions=output_attentions,
  961. output_hidden_states=output_hidden_states,
  962. return_dict=return_dict,
  963. )
  964. last_hidden_state = outputs[0]
  965. prediction_logits = self.lm_head(last_hidden_state)
  966. masked_lm_loss = None
  967. if labels is not None:
  968. loss_fct = CrossEntropyLoss() # -100 index = padding token
  969. masked_lm_loss = loss_fct(prediction_logits.view(-1, self.config.vocab_size), labels.view(-1))
  970. if not return_dict:
  971. output = (prediction_logits,) + outputs[1:]
  972. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  973. return MaskedLMOutput(
  974. loss=masked_lm_loss,
  975. logits=prediction_logits,
  976. hidden_states=outputs.hidden_states,
  977. attentions=outputs.attentions,
  978. )
  979. @auto_docstring(
  980. custom_intro="""
  981. Funnel Transformer Model with a sequence classification/regression head on top (two linear layer on top of the
  982. first timestep of the last hidden state) e.g. for GLUE tasks.
  983. """
  984. )
  985. class FunnelForSequenceClassification(FunnelPreTrainedModel):
  986. def __init__(self, config: FunnelConfig) -> None:
  987. super().__init__(config)
  988. self.num_labels = config.num_labels
  989. self.config = config
  990. self.funnel = FunnelBaseModel(config)
  991. self.classifier = FunnelClassificationHead(config, config.num_labels)
  992. # Initialize weights and apply final processing
  993. self.post_init()
  994. @auto_docstring
  995. def forward(
  996. self,
  997. input_ids: Optional[torch.Tensor] = None,
  998. attention_mask: Optional[torch.Tensor] = None,
  999. token_type_ids: Optional[torch.Tensor] = None,
  1000. inputs_embeds: Optional[torch.Tensor] = None,
  1001. labels: Optional[torch.Tensor] = None,
  1002. output_attentions: Optional[bool] = None,
  1003. output_hidden_states: Optional[bool] = None,
  1004. return_dict: Optional[bool] = None,
  1005. ) -> Union[tuple, SequenceClassifierOutput]:
  1006. r"""
  1007. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1008. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1009. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1010. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1011. """
  1012. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1013. outputs = self.funnel(
  1014. input_ids,
  1015. attention_mask=attention_mask,
  1016. token_type_ids=token_type_ids,
  1017. inputs_embeds=inputs_embeds,
  1018. output_attentions=output_attentions,
  1019. output_hidden_states=output_hidden_states,
  1020. return_dict=return_dict,
  1021. )
  1022. last_hidden_state = outputs[0]
  1023. pooled_output = last_hidden_state[:, 0]
  1024. logits = self.classifier(pooled_output)
  1025. loss = None
  1026. if labels is not None:
  1027. if self.config.problem_type is None:
  1028. if self.num_labels == 1:
  1029. self.config.problem_type = "regression"
  1030. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1031. self.config.problem_type = "single_label_classification"
  1032. else:
  1033. self.config.problem_type = "multi_label_classification"
  1034. if self.config.problem_type == "regression":
  1035. loss_fct = MSELoss()
  1036. if self.num_labels == 1:
  1037. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1038. else:
  1039. loss = loss_fct(logits, labels)
  1040. elif self.config.problem_type == "single_label_classification":
  1041. loss_fct = CrossEntropyLoss()
  1042. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1043. elif self.config.problem_type == "multi_label_classification":
  1044. loss_fct = BCEWithLogitsLoss()
  1045. loss = loss_fct(logits, labels)
  1046. if not return_dict:
  1047. output = (logits,) + outputs[1:]
  1048. return ((loss,) + output) if loss is not None else output
  1049. return SequenceClassifierOutput(
  1050. loss=loss,
  1051. logits=logits,
  1052. hidden_states=outputs.hidden_states,
  1053. attentions=outputs.attentions,
  1054. )
  1055. @auto_docstring
  1056. class FunnelForMultipleChoice(FunnelPreTrainedModel):
  1057. def __init__(self, config: FunnelConfig) -> None:
  1058. super().__init__(config)
  1059. self.funnel = FunnelBaseModel(config)
  1060. self.classifier = FunnelClassificationHead(config, 1)
  1061. # Initialize weights and apply final processing
  1062. self.post_init()
  1063. @auto_docstring
  1064. def forward(
  1065. self,
  1066. input_ids: Optional[torch.Tensor] = None,
  1067. attention_mask: Optional[torch.Tensor] = None,
  1068. token_type_ids: Optional[torch.Tensor] = None,
  1069. inputs_embeds: Optional[torch.Tensor] = None,
  1070. labels: Optional[torch.Tensor] = None,
  1071. output_attentions: Optional[bool] = None,
  1072. output_hidden_states: Optional[bool] = None,
  1073. return_dict: Optional[bool] = None,
  1074. ) -> Union[tuple, MultipleChoiceModelOutput]:
  1075. r"""
  1076. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1077. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1078. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1079. `input_ids` above)
  1080. """
  1081. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1082. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1083. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1084. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1085. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1086. inputs_embeds = (
  1087. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1088. if inputs_embeds is not None
  1089. else None
  1090. )
  1091. outputs = self.funnel(
  1092. input_ids,
  1093. attention_mask=attention_mask,
  1094. token_type_ids=token_type_ids,
  1095. inputs_embeds=inputs_embeds,
  1096. output_attentions=output_attentions,
  1097. output_hidden_states=output_hidden_states,
  1098. return_dict=return_dict,
  1099. )
  1100. last_hidden_state = outputs[0]
  1101. pooled_output = last_hidden_state[:, 0]
  1102. logits = self.classifier(pooled_output)
  1103. reshaped_logits = logits.view(-1, num_choices)
  1104. loss = None
  1105. if labels is not None:
  1106. loss_fct = CrossEntropyLoss()
  1107. loss = loss_fct(reshaped_logits, labels)
  1108. if not return_dict:
  1109. output = (reshaped_logits,) + outputs[1:]
  1110. return ((loss,) + output) if loss is not None else output
  1111. return MultipleChoiceModelOutput(
  1112. loss=loss,
  1113. logits=reshaped_logits,
  1114. hidden_states=outputs.hidden_states,
  1115. attentions=outputs.attentions,
  1116. )
  1117. @auto_docstring
  1118. class FunnelForTokenClassification(FunnelPreTrainedModel):
  1119. def __init__(self, config: FunnelConfig) -> None:
  1120. super().__init__(config)
  1121. self.num_labels = config.num_labels
  1122. self.funnel = FunnelModel(config)
  1123. self.dropout = nn.Dropout(config.hidden_dropout)
  1124. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1125. # Initialize weights and apply final processing
  1126. self.post_init()
  1127. @auto_docstring
  1128. def forward(
  1129. self,
  1130. input_ids: Optional[torch.Tensor] = None,
  1131. attention_mask: Optional[torch.Tensor] = None,
  1132. token_type_ids: Optional[torch.Tensor] = None,
  1133. inputs_embeds: Optional[torch.Tensor] = None,
  1134. labels: Optional[torch.Tensor] = None,
  1135. output_attentions: Optional[bool] = None,
  1136. output_hidden_states: Optional[bool] = None,
  1137. return_dict: Optional[bool] = None,
  1138. ) -> Union[tuple, TokenClassifierOutput]:
  1139. r"""
  1140. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1141. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1142. """
  1143. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1144. outputs = self.funnel(
  1145. input_ids,
  1146. attention_mask=attention_mask,
  1147. token_type_ids=token_type_ids,
  1148. inputs_embeds=inputs_embeds,
  1149. output_attentions=output_attentions,
  1150. output_hidden_states=output_hidden_states,
  1151. return_dict=return_dict,
  1152. )
  1153. last_hidden_state = outputs[0]
  1154. last_hidden_state = self.dropout(last_hidden_state)
  1155. logits = self.classifier(last_hidden_state)
  1156. loss = None
  1157. if labels is not None:
  1158. loss_fct = CrossEntropyLoss()
  1159. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1160. if not return_dict:
  1161. output = (logits,) + outputs[1:]
  1162. return ((loss,) + output) if loss is not None else output
  1163. return TokenClassifierOutput(
  1164. loss=loss,
  1165. logits=logits,
  1166. hidden_states=outputs.hidden_states,
  1167. attentions=outputs.attentions,
  1168. )
  1169. @auto_docstring
  1170. class FunnelForQuestionAnswering(FunnelPreTrainedModel):
  1171. def __init__(self, config: FunnelConfig) -> None:
  1172. super().__init__(config)
  1173. self.num_labels = config.num_labels
  1174. self.funnel = FunnelModel(config)
  1175. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1176. # Initialize weights and apply final processing
  1177. self.post_init()
  1178. @auto_docstring
  1179. def forward(
  1180. self,
  1181. input_ids: Optional[torch.Tensor] = None,
  1182. attention_mask: Optional[torch.Tensor] = None,
  1183. token_type_ids: Optional[torch.Tensor] = None,
  1184. inputs_embeds: Optional[torch.Tensor] = None,
  1185. start_positions: Optional[torch.Tensor] = None,
  1186. end_positions: Optional[torch.Tensor] = None,
  1187. output_attentions: Optional[bool] = None,
  1188. output_hidden_states: Optional[bool] = None,
  1189. return_dict: Optional[bool] = None,
  1190. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  1191. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1192. outputs = self.funnel(
  1193. input_ids,
  1194. attention_mask=attention_mask,
  1195. token_type_ids=token_type_ids,
  1196. inputs_embeds=inputs_embeds,
  1197. output_attentions=output_attentions,
  1198. output_hidden_states=output_hidden_states,
  1199. return_dict=return_dict,
  1200. )
  1201. last_hidden_state = outputs[0]
  1202. logits = self.qa_outputs(last_hidden_state)
  1203. start_logits, end_logits = logits.split(1, dim=-1)
  1204. start_logits = start_logits.squeeze(-1).contiguous()
  1205. end_logits = end_logits.squeeze(-1).contiguous()
  1206. total_loss = None
  1207. if start_positions is not None and end_positions is not None:
  1208. # If we are on multi-GPU, split add a dimension
  1209. if len(start_positions.size()) > 1:
  1210. start_positions = start_positions.squeze(-1)
  1211. if len(end_positions.size()) > 1:
  1212. end_positions = end_positions.squeeze(-1)
  1213. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1214. ignored_index = start_logits.size(1)
  1215. start_positions = start_positions.clamp(0, ignored_index)
  1216. end_positions = end_positions.clamp(0, ignored_index)
  1217. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1218. start_loss = loss_fct(start_logits, start_positions)
  1219. end_loss = loss_fct(end_logits, end_positions)
  1220. total_loss = (start_loss + end_loss) / 2
  1221. if not return_dict:
  1222. output = (start_logits, end_logits) + outputs[1:]
  1223. return ((total_loss,) + output) if total_loss is not None else output
  1224. return QuestionAnsweringModelOutput(
  1225. loss=total_loss,
  1226. start_logits=start_logits,
  1227. end_logits=end_logits,
  1228. hidden_states=outputs.hidden_states,
  1229. attentions=outputs.attentions,
  1230. )
# Names exported by a star-import of this module.
__all__ = [
    "FunnelBaseModel",
    "FunnelForMaskedLM",
    "FunnelForMultipleChoice",
    "FunnelForPreTraining",
    "FunnelForQuestionAnswering",
    "FunnelForSequenceClassification",
    "FunnelForTokenClassification",
    "FunnelModel",
    "FunnelPreTrainedModel",
    "load_tf_weights_in_funnel",
]