modeling_bert.py 76 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801
  1. # coding=utf-8
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch BERT model."""
  17. import math
  18. import os
  19. import warnings
  20. from dataclasses import dataclass
  21. from typing import Optional, Union
  22. import torch
  23. from torch import nn
  24. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  25. from ...activations import ACT2FN
  26. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  27. from ...generation import GenerationMixin
  28. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
  29. from ...modeling_layers import GradientCheckpointingLayer
  30. from ...modeling_outputs import (
  31. BaseModelOutputWithPastAndCrossAttentions,
  32. BaseModelOutputWithPoolingAndCrossAttentions,
  33. CausalLMOutputWithCrossAttentions,
  34. MaskedLMOutput,
  35. MultipleChoiceModelOutput,
  36. NextSentencePredictorOutput,
  37. QuestionAnsweringModelOutput,
  38. SequenceClassifierOutput,
  39. TokenClassifierOutput,
  40. )
  41. from ...modeling_utils import PreTrainedModel
  42. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  43. from ...utils import ModelOutput, auto_docstring, logging
  44. from ...utils.deprecation import deprecate_kwarg
  45. from .configuration_bert import BertConfig
  46. logger = logging.get_logger(__name__)
  47. def load_tf_weights_in_bert(model, config, tf_checkpoint_path):
  48. """Load tf checkpoints in a pytorch model."""
  49. try:
  50. import re
  51. import numpy as np
  52. import tensorflow as tf
  53. except ImportError:
  54. logger.error(
  55. "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
  56. "https://www.tensorflow.org/install/ for installation instructions."
  57. )
  58. raise
  59. tf_path = os.path.abspath(tf_checkpoint_path)
  60. logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
  61. # Load weights from TF model
  62. init_vars = tf.train.list_variables(tf_path)
  63. names = []
  64. arrays = []
  65. for name, shape in init_vars:
  66. logger.info(f"Loading TF weight {name} with shape {shape}")
  67. array = tf.train.load_variable(tf_path, name)
  68. names.append(name)
  69. arrays.append(array)
  70. for name, array in zip(names, arrays):
  71. name = name.split("/")
  72. # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  73. # which are not required for using pretrained model
  74. if any(
  75. n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
  76. for n in name
  77. ):
  78. logger.info(f"Skipping {'/'.join(name)}")
  79. continue
  80. pointer = model
  81. for m_name in name:
  82. if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
  83. scope_names = re.split(r"_(\d+)", m_name)
  84. else:
  85. scope_names = [m_name]
  86. if scope_names[0] == "kernel" or scope_names[0] == "gamma":
  87. pointer = getattr(pointer, "weight")
  88. elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
  89. pointer = getattr(pointer, "bias")
  90. elif scope_names[0] == "output_weights":
  91. pointer = getattr(pointer, "weight")
  92. elif scope_names[0] == "squad":
  93. pointer = getattr(pointer, "classifier")
  94. else:
  95. try:
  96. pointer = getattr(pointer, scope_names[0])
  97. except AttributeError:
  98. logger.info(f"Skipping {'/'.join(name)}")
  99. continue
  100. if len(scope_names) >= 2:
  101. num = int(scope_names[1])
  102. pointer = pointer[num]
  103. if m_name[-11:] == "_embeddings":
  104. pointer = getattr(pointer, "weight")
  105. elif m_name == "kernel":
  106. array = np.transpose(array)
  107. try:
  108. if pointer.shape != array.shape:
  109. raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
  110. except ValueError as e:
  111. e.args += (pointer.shape, array.shape)
  112. raise
  113. logger.info(f"Initialize PyTorch weight {name}")
  114. pointer.data = torch.from_numpy(array)
  115. return model
  116. class BertEmbeddings(nn.Module):
  117. """Construct the embeddings from word, position and token_type embeddings."""
  118. def __init__(self, config):
  119. super().__init__()
  120. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  121. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  122. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  123. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  124. # any TensorFlow checkpoint file
  125. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  126. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  127. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  128. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  129. self.register_buffer(
  130. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  131. )
  132. self.register_buffer(
  133. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  134. )
  135. def forward(
  136. self,
  137. input_ids: Optional[torch.LongTensor] = None,
  138. token_type_ids: Optional[torch.LongTensor] = None,
  139. position_ids: Optional[torch.LongTensor] = None,
  140. inputs_embeds: Optional[torch.FloatTensor] = None,
  141. past_key_values_length: int = 0,
  142. ) -> torch.Tensor:
  143. if input_ids is not None:
  144. input_shape = input_ids.size()
  145. else:
  146. input_shape = inputs_embeds.size()[:-1]
  147. seq_length = input_shape[1]
  148. if position_ids is None:
  149. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
  150. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  151. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  152. # issue #5664
  153. if token_type_ids is None:
  154. if hasattr(self, "token_type_ids"):
  155. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  156. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  157. token_type_ids = buffered_token_type_ids_expanded
  158. else:
  159. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  160. if inputs_embeds is None:
  161. inputs_embeds = self.word_embeddings(input_ids)
  162. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  163. embeddings = inputs_embeds + token_type_embeddings
  164. if self.position_embedding_type == "absolute":
  165. position_embeddings = self.position_embeddings(position_ids)
  166. embeddings += position_embeddings
  167. embeddings = self.LayerNorm(embeddings)
  168. embeddings = self.dropout(embeddings)
  169. return embeddings
  170. class BertSelfAttention(nn.Module):
  171. def __init__(self, config, position_embedding_type=None, layer_idx=None):
  172. super().__init__()
  173. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  174. raise ValueError(
  175. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  176. f"heads ({config.num_attention_heads})"
  177. )
  178. self.num_attention_heads = config.num_attention_heads
  179. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  180. self.all_head_size = self.num_attention_heads * self.attention_head_size
  181. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  182. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  183. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  184. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  185. self.position_embedding_type = position_embedding_type or getattr(
  186. config, "position_embedding_type", "absolute"
  187. )
  188. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  189. self.max_position_embeddings = config.max_position_embeddings
  190. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  191. self.is_decoder = config.is_decoder
  192. self.layer_idx = layer_idx
  193. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  194. def forward(
  195. self,
  196. hidden_states: torch.Tensor,
  197. attention_mask: Optional[torch.FloatTensor] = None,
  198. head_mask: Optional[torch.FloatTensor] = None,
  199. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  200. past_key_values: Optional[Cache] = None,
  201. output_attentions: Optional[bool] = False,
  202. cache_position: Optional[torch.Tensor] = None,
  203. ) -> tuple[torch.Tensor]:
  204. batch_size, seq_length, _ = hidden_states.shape
  205. query_layer = self.query(hidden_states)
  206. query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
  207. 1, 2
  208. )
  209. is_updated = False
  210. is_cross_attention = encoder_hidden_states is not None
  211. if past_key_values is not None:
  212. if isinstance(past_key_values, EncoderDecoderCache):
  213. is_updated = past_key_values.is_updated.get(self.layer_idx)
  214. if is_cross_attention:
  215. # after the first generated id, we can subsequently re-use all key/value_layer from cache
  216. curr_past_key_value = past_key_values.cross_attention_cache
  217. else:
  218. curr_past_key_value = past_key_values.self_attention_cache
  219. else:
  220. curr_past_key_value = past_key_values
  221. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  222. if is_cross_attention and past_key_values is not None and is_updated:
  223. # reuse k,v, cross_attentions
  224. key_layer = curr_past_key_value.layers[self.layer_idx].keys
  225. value_layer = curr_past_key_value.layers[self.layer_idx].values
  226. else:
  227. key_layer = self.key(current_states)
  228. key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
  229. 1, 2
  230. )
  231. value_layer = self.value(current_states)
  232. value_layer = value_layer.view(
  233. batch_size, -1, self.num_attention_heads, self.attention_head_size
  234. ).transpose(1, 2)
  235. if past_key_values is not None:
  236. # save all key/value_layer to cache to be re-used for fast auto-regressive generation
  237. cache_position = cache_position if not is_cross_attention else None
  238. key_layer, value_layer = curr_past_key_value.update(
  239. key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
  240. )
  241. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  242. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  243. past_key_values.is_updated[self.layer_idx] = True
  244. # Take the dot product between "query" and "key" to get the raw attention scores.
  245. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  246. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  247. query_length, key_length = query_layer.shape[2], key_layer.shape[2]
  248. if past_key_values is not None:
  249. position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
  250. -1, 1
  251. )
  252. else:
  253. position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  254. position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  255. distance = position_ids_l - position_ids_r
  256. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  257. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  258. if self.position_embedding_type == "relative_key":
  259. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  260. attention_scores = attention_scores + relative_position_scores
  261. elif self.position_embedding_type == "relative_key_query":
  262. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  263. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  264. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  265. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  266. if attention_mask is not None:
  267. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  268. attention_scores = attention_scores + attention_mask
  269. # Normalize the attention scores to probabilities.
  270. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  271. # This is actually dropping out entire tokens to attend to, which might
  272. # seem a bit unusual, but is taken from the original Transformer paper.
  273. attention_probs = self.dropout(attention_probs)
  274. # Mask heads if we want to
  275. if head_mask is not None:
  276. attention_probs = attention_probs * head_mask
  277. context_layer = torch.matmul(attention_probs, value_layer)
  278. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  279. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  280. context_layer = context_layer.view(new_context_layer_shape)
  281. return context_layer, attention_probs
  282. class BertSdpaSelfAttention(BertSelfAttention):
  283. def __init__(self, config, position_embedding_type=None, layer_idx=None):
  284. super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx)
  285. self.dropout_prob = config.attention_probs_dropout_prob
  286. # Adapted from BertSelfAttention
  287. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  288. def forward(
  289. self,
  290. hidden_states: torch.Tensor,
  291. attention_mask: Optional[torch.Tensor] = None,
  292. head_mask: Optional[torch.FloatTensor] = None,
  293. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  294. past_key_values: Optional[Cache] = None,
  295. output_attentions: Optional[bool] = False,
  296. cache_position: Optional[torch.Tensor] = None,
  297. ) -> tuple[torch.Tensor]:
  298. if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
  299. # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
  300. logger.warning_once(
  301. "BertSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
  302. "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
  303. "the manual attention implementation, but specifying the manual implementation will be required from "
  304. "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
  305. '`attn_implementation="eager"` when loading the model.'
  306. )
  307. return super().forward(
  308. hidden_states,
  309. attention_mask,
  310. head_mask,
  311. encoder_hidden_states,
  312. past_key_values,
  313. output_attentions,
  314. cache_position,
  315. )
  316. bsz, tgt_len, _ = hidden_states.size()
  317. query_layer = (
  318. self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  319. )
  320. is_updated = False
  321. is_cross_attention = encoder_hidden_states is not None
  322. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  323. if past_key_values is not None:
  324. if isinstance(past_key_values, EncoderDecoderCache):
  325. is_updated = past_key_values.is_updated.get(self.layer_idx)
  326. if is_cross_attention:
  327. # after the first generated id, we can subsequently re-use all key/value_states from cache
  328. curr_past_key_value = past_key_values.cross_attention_cache
  329. else:
  330. curr_past_key_value = past_key_values.self_attention_cache
  331. else:
  332. curr_past_key_value = past_key_values
  333. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  334. if is_cross_attention and past_key_values is not None and is_updated:
  335. # reuse k,v, cross_attentions
  336. key_layer = curr_past_key_value.layers[self.layer_idx].keys
  337. value_layer = curr_past_key_value.layers[self.layer_idx].values
  338. else:
  339. key_layer = (
  340. self.key(current_states)
  341. .view(bsz, -1, self.num_attention_heads, self.attention_head_size)
  342. .transpose(1, 2)
  343. )
  344. value_layer = (
  345. self.value(current_states)
  346. .view(bsz, -1, self.num_attention_heads, self.attention_head_size)
  347. .transpose(1, 2)
  348. )
  349. if past_key_values is not None:
  350. # save all key/value_layer to cache to be re-used for fast auto-regressive generation
  351. cache_position = cache_position if not is_cross_attention else None
  352. key_layer, value_layer = curr_past_key_value.update(
  353. key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
  354. )
  355. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  356. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  357. past_key_values.is_updated[self.layer_idx] = True
  358. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  359. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  360. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
  361. # a causal mask in case tgt_len == 1.
  362. is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1
  363. attn_output = torch.nn.functional.scaled_dot_product_attention(
  364. query_layer,
  365. key_layer,
  366. value_layer,
  367. attn_mask=attention_mask,
  368. dropout_p=self.dropout_prob if self.training else 0.0,
  369. is_causal=is_causal,
  370. )
  371. attn_output = attn_output.transpose(1, 2)
  372. attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
  373. return attn_output, None
  374. class BertSelfOutput(nn.Module):
  375. def __init__(self, config):
  376. super().__init__()
  377. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  378. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  379. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  380. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  381. hidden_states = self.dense(hidden_states)
  382. hidden_states = self.dropout(hidden_states)
  383. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  384. return hidden_states
  385. BERT_SELF_ATTENTION_CLASSES = {
  386. "eager": BertSelfAttention,
  387. "sdpa": BertSdpaSelfAttention,
  388. }
  389. class BertAttention(nn.Module):
  390. def __init__(self, config, position_embedding_type=None, layer_idx=None):
  391. super().__init__()
  392. self.self = BERT_SELF_ATTENTION_CLASSES[config._attn_implementation](
  393. config,
  394. position_embedding_type=position_embedding_type,
  395. layer_idx=layer_idx,
  396. )
  397. self.output = BertSelfOutput(config)
  398. self.pruned_heads = set()
  399. def prune_heads(self, heads):
  400. if len(heads) == 0:
  401. return
  402. heads, index = find_pruneable_heads_and_indices(
  403. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  404. )
  405. # Prune linear layers
  406. self.self.query = prune_linear_layer(self.self.query, index)
  407. self.self.key = prune_linear_layer(self.self.key, index)
  408. self.self.value = prune_linear_layer(self.self.value, index)
  409. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  410. # Update hyper params and store pruned heads
  411. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  412. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  413. self.pruned_heads = self.pruned_heads.union(heads)
  414. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  415. def forward(
  416. self,
  417. hidden_states: torch.Tensor,
  418. attention_mask: Optional[torch.FloatTensor] = None,
  419. head_mask: Optional[torch.FloatTensor] = None,
  420. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  421. past_key_values: Optional[Cache] = None,
  422. output_attentions: Optional[bool] = False,
  423. cache_position: Optional[torch.Tensor] = None,
  424. ) -> tuple[torch.Tensor]:
  425. self_outputs = self.self(
  426. hidden_states,
  427. attention_mask=attention_mask,
  428. head_mask=head_mask,
  429. encoder_hidden_states=encoder_hidden_states,
  430. past_key_values=past_key_values,
  431. output_attentions=output_attentions,
  432. cache_position=cache_position,
  433. )
  434. attention_output = self.output(self_outputs[0], hidden_states)
  435. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  436. return outputs
  437. class BertIntermediate(nn.Module):
  438. def __init__(self, config):
  439. super().__init__()
  440. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  441. if isinstance(config.hidden_act, str):
  442. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  443. else:
  444. self.intermediate_act_fn = config.hidden_act
  445. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  446. hidden_states = self.dense(hidden_states)
  447. hidden_states = self.intermediate_act_fn(hidden_states)
  448. return hidden_states
  449. class BertOutput(nn.Module):
  450. def __init__(self, config):
  451. super().__init__()
  452. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  453. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  454. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  455. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  456. hidden_states = self.dense(hidden_states)
  457. hidden_states = self.dropout(hidden_states)
  458. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  459. return hidden_states
  460. class BertLayer(GradientCheckpointingLayer):
  461. def __init__(self, config, layer_idx=None):
  462. super().__init__()
  463. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  464. self.seq_len_dim = 1
  465. self.attention = BertAttention(config, layer_idx=layer_idx)
  466. self.is_decoder = config.is_decoder
  467. self.add_cross_attention = config.add_cross_attention
  468. if self.add_cross_attention:
  469. if not self.is_decoder:
  470. raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
  471. self.crossattention = BertAttention(config, position_embedding_type="absolute", layer_idx=layer_idx)
  472. self.intermediate = BertIntermediate(config)
  473. self.output = BertOutput(config)
  474. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  475. def forward(
  476. self,
  477. hidden_states: torch.Tensor,
  478. attention_mask: Optional[torch.FloatTensor] = None,
  479. head_mask: Optional[torch.FloatTensor] = None,
  480. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  481. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  482. past_key_values: Optional[Cache] = None,
  483. output_attentions: Optional[bool] = False,
  484. cache_position: Optional[torch.Tensor] = None,
  485. ) -> tuple[torch.Tensor]:
  486. self_attention_outputs = self.attention(
  487. hidden_states,
  488. attention_mask=attention_mask,
  489. head_mask=head_mask,
  490. output_attentions=output_attentions,
  491. past_key_values=past_key_values,
  492. cache_position=cache_position,
  493. )
  494. attention_output = self_attention_outputs[0]
  495. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  496. if self.is_decoder and encoder_hidden_states is not None:
  497. if not hasattr(self, "crossattention"):
  498. raise ValueError(
  499. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  500. " by setting `config.add_cross_attention=True`"
  501. )
  502. cross_attention_outputs = self.crossattention(
  503. attention_output,
  504. attention_mask=encoder_attention_mask,
  505. head_mask=head_mask,
  506. encoder_hidden_states=encoder_hidden_states,
  507. past_key_values=past_key_values,
  508. output_attentions=output_attentions,
  509. cache_position=cache_position,
  510. )
  511. attention_output = cross_attention_outputs[0]
  512. outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
  513. layer_output = apply_chunking_to_forward(
  514. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  515. )
  516. outputs = (layer_output,) + outputs
  517. return outputs
  518. def feed_forward_chunk(self, attention_output):
  519. intermediate_output = self.intermediate(attention_output)
  520. layer_output = self.output(intermediate_output, attention_output)
  521. return layer_output
  522. class BertEncoder(nn.Module):
  523. def __init__(self, config, layer_idx=None):
  524. super().__init__()
  525. self.config = config
  526. self.layer = nn.ModuleList([BertLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  527. self.gradient_checkpointing = False
  528. def forward(
  529. self,
  530. hidden_states: torch.Tensor,
  531. attention_mask: Optional[torch.FloatTensor] = None,
  532. head_mask: Optional[torch.FloatTensor] = None,
  533. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  534. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  535. past_key_values: Optional[Cache] = None,
  536. use_cache: Optional[bool] = None,
  537. output_attentions: Optional[bool] = False,
  538. output_hidden_states: Optional[bool] = False,
  539. return_dict: Optional[bool] = True,
  540. cache_position: Optional[torch.Tensor] = None,
  541. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
  542. all_hidden_states = () if output_hidden_states else None
  543. all_self_attentions = () if output_attentions else None
  544. all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
  545. if self.gradient_checkpointing and self.training:
  546. if use_cache:
  547. logger.warning_once(
  548. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  549. )
  550. use_cache = False
  551. if use_cache and self.config.is_decoder and past_key_values is None:
  552. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  553. if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
  554. logger.warning_once(
  555. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  556. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  557. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  558. )
  559. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  560. for i, layer_module in enumerate(self.layer):
  561. if output_hidden_states:
  562. all_hidden_states = all_hidden_states + (hidden_states,)
  563. layer_head_mask = head_mask[i] if head_mask is not None else None
  564. layer_outputs = layer_module(
  565. hidden_states,
  566. attention_mask,
  567. layer_head_mask,
  568. encoder_hidden_states, # as a positional argument for gradient checkpointing
  569. encoder_attention_mask=encoder_attention_mask,
  570. past_key_values=past_key_values,
  571. output_attentions=output_attentions,
  572. cache_position=cache_position,
  573. )
  574. hidden_states = layer_outputs[0]
  575. if output_attentions:
  576. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  577. if self.config.add_cross_attention:
  578. all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
  579. if output_hidden_states:
  580. all_hidden_states = all_hidden_states + (hidden_states,)
  581. if not return_dict:
  582. return tuple(
  583. v
  584. for v in [
  585. hidden_states,
  586. past_key_values,
  587. all_hidden_states,
  588. all_self_attentions,
  589. all_cross_attentions,
  590. ]
  591. if v is not None
  592. )
  593. return BaseModelOutputWithPastAndCrossAttentions(
  594. last_hidden_state=hidden_states,
  595. past_key_values=past_key_values,
  596. hidden_states=all_hidden_states,
  597. attentions=all_self_attentions,
  598. cross_attentions=all_cross_attentions,
  599. )
  600. class BertPooler(nn.Module):
  601. def __init__(self, config):
  602. super().__init__()
  603. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  604. self.activation = nn.Tanh()
  605. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  606. # We "pool" the model by simply taking the hidden state corresponding
  607. # to the first token.
  608. first_token_tensor = hidden_states[:, 0]
  609. pooled_output = self.dense(first_token_tensor)
  610. pooled_output = self.activation(pooled_output)
  611. return pooled_output
  612. class BertPredictionHeadTransform(nn.Module):
  613. def __init__(self, config):
  614. super().__init__()
  615. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  616. if isinstance(config.hidden_act, str):
  617. self.transform_act_fn = ACT2FN[config.hidden_act]
  618. else:
  619. self.transform_act_fn = config.hidden_act
  620. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  621. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  622. hidden_states = self.dense(hidden_states)
  623. hidden_states = self.transform_act_fn(hidden_states)
  624. hidden_states = self.LayerNorm(hidden_states)
  625. return hidden_states
  626. class BertLMPredictionHead(nn.Module):
  627. def __init__(self, config):
  628. super().__init__()
  629. self.transform = BertPredictionHeadTransform(config)
  630. # The output weights are the same as the input embeddings, but there is
  631. # an output-only bias for each token.
  632. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  633. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  634. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  635. self.decoder.bias = self.bias
  636. def _tie_weights(self):
  637. self.decoder.bias = self.bias
  638. def forward(self, hidden_states):
  639. hidden_states = self.transform(hidden_states)
  640. hidden_states = self.decoder(hidden_states)
  641. return hidden_states
  642. class BertOnlyMLMHead(nn.Module):
  643. def __init__(self, config):
  644. super().__init__()
  645. self.predictions = BertLMPredictionHead(config)
  646. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  647. prediction_scores = self.predictions(sequence_output)
  648. return prediction_scores
  649. class BertOnlyNSPHead(nn.Module):
  650. def __init__(self, config):
  651. super().__init__()
  652. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  653. def forward(self, pooled_output):
  654. seq_relationship_score = self.seq_relationship(pooled_output)
  655. return seq_relationship_score
  656. class BertPreTrainingHeads(nn.Module):
  657. def __init__(self, config):
  658. super().__init__()
  659. self.predictions = BertLMPredictionHead(config)
  660. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  661. def forward(self, sequence_output, pooled_output):
  662. prediction_scores = self.predictions(sequence_output)
  663. seq_relationship_score = self.seq_relationship(pooled_output)
  664. return prediction_scores, seq_relationship_score
  665. @auto_docstring
  666. class BertPreTrainedModel(PreTrainedModel):
  667. config: BertConfig
  668. load_tf_weights = load_tf_weights_in_bert
  669. base_model_prefix = "bert"
  670. supports_gradient_checkpointing = True
  671. _supports_sdpa = True
  672. def _init_weights(self, module):
  673. """Initialize the weights"""
  674. if isinstance(module, nn.Linear):
  675. # Slightly different from the TF version which uses truncated_normal for initialization
  676. # cf https://github.com/pytorch/pytorch/pull/5617
  677. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  678. if module.bias is not None:
  679. module.bias.data.zero_()
  680. elif isinstance(module, nn.Embedding):
  681. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  682. if module.padding_idx is not None:
  683. module.weight.data[module.padding_idx].zero_()
  684. elif isinstance(module, nn.LayerNorm):
  685. module.bias.data.zero_()
  686. module.weight.data.fill_(1.0)
  687. elif isinstance(module, BertLMPredictionHead):
  688. module.bias.data.zero_()
  689. @dataclass
  690. @auto_docstring(
  691. custom_intro="""
  692. Output type of [`BertForPreTraining`].
  693. """
  694. )
  695. class BertForPreTrainingOutput(ModelOutput):
  696. r"""
  697. loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  698. Total loss as the sum of the masked language modeling loss and the next sequence prediction
  699. (classification) loss.
  700. prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  701. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  702. seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
  703. Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
  704. before SoftMax).
  705. """
  706. loss: Optional[torch.FloatTensor] = None
  707. prediction_logits: Optional[torch.FloatTensor] = None
  708. seq_relationship_logits: Optional[torch.FloatTensor] = None
  709. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  710. attentions: Optional[tuple[torch.FloatTensor]] = None
  711. @auto_docstring(
  712. custom_intro="""
  713. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  714. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  715. all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  716. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  717. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  718. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  719. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  720. """
  721. )
  722. class BertModel(BertPreTrainedModel):
  723. _no_split_modules = ["BertEmbeddings", "BertLayer"]
  724. def __init__(self, config, add_pooling_layer=True):
  725. r"""
  726. add_pooling_layer (bool, *optional*, defaults to `True`):
  727. Whether to add a pooling layer
  728. """
  729. super().__init__(config)
  730. self.config = config
  731. self.embeddings = BertEmbeddings(config)
  732. self.encoder = BertEncoder(config)
  733. self.pooler = BertPooler(config) if add_pooling_layer else None
  734. self.attn_implementation = config._attn_implementation
  735. self.position_embedding_type = config.position_embedding_type
  736. # Initialize weights and apply final processing
  737. self.post_init()
  738. def get_input_embeddings(self):
  739. return self.embeddings.word_embeddings
  740. def set_input_embeddings(self, value):
  741. self.embeddings.word_embeddings = value
  742. def _prune_heads(self, heads_to_prune):
  743. """
  744. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  745. class PreTrainedModel
  746. """
  747. for layer, heads in heads_to_prune.items():
  748. self.encoder.layer[layer].attention.prune_heads(heads)
  749. @auto_docstring
  750. def forward(
  751. self,
  752. input_ids: Optional[torch.Tensor] = None,
  753. attention_mask: Optional[torch.Tensor] = None,
  754. token_type_ids: Optional[torch.Tensor] = None,
  755. position_ids: Optional[torch.Tensor] = None,
  756. head_mask: Optional[torch.Tensor] = None,
  757. inputs_embeds: Optional[torch.Tensor] = None,
  758. encoder_hidden_states: Optional[torch.Tensor] = None,
  759. encoder_attention_mask: Optional[torch.Tensor] = None,
  760. past_key_values: Optional[Cache] = None,
  761. use_cache: Optional[bool] = None,
  762. output_attentions: Optional[bool] = None,
  763. output_hidden_states: Optional[bool] = None,
  764. return_dict: Optional[bool] = None,
  765. cache_position: Optional[torch.Tensor] = None,
  766. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  767. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  768. output_hidden_states = (
  769. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  770. )
  771. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  772. if self.config.is_decoder:
  773. use_cache = use_cache if use_cache is not None else self.config.use_cache
  774. else:
  775. use_cache = False
  776. if input_ids is not None and inputs_embeds is not None:
  777. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  778. elif input_ids is not None:
  779. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  780. input_shape = input_ids.size()
  781. elif inputs_embeds is not None:
  782. input_shape = inputs_embeds.size()[:-1]
  783. else:
  784. raise ValueError("You have to specify either input_ids or inputs_embeds")
  785. batch_size, seq_length = input_shape
  786. device = input_ids.device if input_ids is not None else inputs_embeds.device
  787. past_key_values_length = 0
  788. if past_key_values is not None:
  789. past_key_values_length = (
  790. past_key_values[0][0].shape[-2]
  791. if not isinstance(past_key_values, Cache)
  792. else past_key_values.get_seq_length()
  793. )
  794. if token_type_ids is None:
  795. if hasattr(self.embeddings, "token_type_ids"):
  796. buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
  797. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
  798. token_type_ids = buffered_token_type_ids_expanded
  799. else:
  800. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  801. embedding_output = self.embeddings(
  802. input_ids=input_ids,
  803. position_ids=position_ids,
  804. token_type_ids=token_type_ids,
  805. inputs_embeds=inputs_embeds,
  806. past_key_values_length=past_key_values_length,
  807. )
  808. if attention_mask is None:
  809. attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
  810. use_sdpa_attention_masks = (
  811. self.attn_implementation == "sdpa"
  812. and self.position_embedding_type == "absolute"
  813. and head_mask is None
  814. and not output_attentions
  815. )
  816. # Expand the attention mask
  817. if use_sdpa_attention_masks and attention_mask.dim() == 2:
  818. # Expand the attention mask for SDPA.
  819. # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
  820. if self.config.is_decoder:
  821. extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
  822. attention_mask,
  823. input_shape,
  824. embedding_output,
  825. past_key_values_length,
  826. )
  827. else:
  828. extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  829. attention_mask, embedding_output.dtype, tgt_len=seq_length
  830. )
  831. else:
  832. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  833. # ourselves in which case we just need to make it broadcastable to all heads.
  834. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
  835. # If a 2D or 3D attention mask is provided for the cross-attention
  836. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  837. if self.config.is_decoder and encoder_hidden_states is not None:
  838. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  839. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  840. if encoder_attention_mask is None:
  841. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  842. if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
  843. # Expand the attention mask for SDPA.
  844. # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
  845. encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  846. encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
  847. )
  848. else:
  849. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  850. else:
  851. encoder_extended_attention_mask = None
  852. # Prepare head mask if needed
  853. # 1.0 in head_mask indicate we keep the head
  854. # attention_probs has shape bsz x n_heads x N x N
  855. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  856. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  857. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  858. encoder_outputs = self.encoder(
  859. embedding_output,
  860. attention_mask=extended_attention_mask,
  861. head_mask=head_mask,
  862. encoder_hidden_states=encoder_hidden_states,
  863. encoder_attention_mask=encoder_extended_attention_mask,
  864. past_key_values=past_key_values,
  865. use_cache=use_cache,
  866. output_attentions=output_attentions,
  867. output_hidden_states=output_hidden_states,
  868. return_dict=return_dict,
  869. cache_position=cache_position,
  870. )
  871. sequence_output = encoder_outputs[0]
  872. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  873. if not return_dict:
  874. return (sequence_output, pooled_output) + encoder_outputs[1:]
  875. return BaseModelOutputWithPoolingAndCrossAttentions(
  876. last_hidden_state=sequence_output,
  877. pooler_output=pooled_output,
  878. past_key_values=encoder_outputs.past_key_values,
  879. hidden_states=encoder_outputs.hidden_states,
  880. attentions=encoder_outputs.attentions,
  881. cross_attentions=encoder_outputs.cross_attentions,
  882. )
  883. @auto_docstring(
  884. custom_intro="""
  885. Bert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a `next
  886. sentence prediction (classification)` head.
  887. """
  888. )
  889. class BertForPreTraining(BertPreTrainedModel):
  890. _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
  891. def __init__(self, config):
  892. super().__init__(config)
  893. self.bert = BertModel(config)
  894. self.cls = BertPreTrainingHeads(config)
  895. # Initialize weights and apply final processing
  896. self.post_init()
  897. def get_output_embeddings(self):
  898. return self.cls.predictions.decoder
  899. def set_output_embeddings(self, new_embeddings):
  900. self.cls.predictions.decoder = new_embeddings
  901. self.cls.predictions.bias = new_embeddings.bias
  902. @auto_docstring
  903. def forward(
  904. self,
  905. input_ids: Optional[torch.Tensor] = None,
  906. attention_mask: Optional[torch.Tensor] = None,
  907. token_type_ids: Optional[torch.Tensor] = None,
  908. position_ids: Optional[torch.Tensor] = None,
  909. head_mask: Optional[torch.Tensor] = None,
  910. inputs_embeds: Optional[torch.Tensor] = None,
  911. labels: Optional[torch.Tensor] = None,
  912. next_sentence_label: Optional[torch.Tensor] = None,
  913. output_attentions: Optional[bool] = None,
  914. output_hidden_states: Optional[bool] = None,
  915. return_dict: Optional[bool] = None,
  916. ) -> Union[tuple[torch.Tensor], BertForPreTrainingOutput]:
  917. r"""
  918. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  919. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  920. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
  921. the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  922. next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  923. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
  924. pair (see `input_ids` docstring) Indices should be in `[0, 1]`:
  925. - 0 indicates sequence B is a continuation of sequence A,
  926. - 1 indicates sequence B is a random sequence.
  927. Example:
  928. ```python
  929. >>> from transformers import AutoTokenizer, BertForPreTraining
  930. >>> import torch
  931. >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
  932. >>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased")
  933. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  934. >>> outputs = model(**inputs)
  935. >>> prediction_logits = outputs.prediction_logits
  936. >>> seq_relationship_logits = outputs.seq_relationship_logits
  937. ```
  938. """
  939. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  940. outputs = self.bert(
  941. input_ids,
  942. attention_mask=attention_mask,
  943. token_type_ids=token_type_ids,
  944. position_ids=position_ids,
  945. head_mask=head_mask,
  946. inputs_embeds=inputs_embeds,
  947. output_attentions=output_attentions,
  948. output_hidden_states=output_hidden_states,
  949. return_dict=return_dict,
  950. )
  951. sequence_output, pooled_output = outputs[:2]
  952. prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
  953. total_loss = None
  954. if labels is not None and next_sentence_label is not None:
  955. loss_fct = CrossEntropyLoss()
  956. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  957. next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
  958. total_loss = masked_lm_loss + next_sentence_loss
  959. if not return_dict:
  960. output = (prediction_scores, seq_relationship_score) + outputs[2:]
  961. return ((total_loss,) + output) if total_loss is not None else output
  962. return BertForPreTrainingOutput(
  963. loss=total_loss,
  964. prediction_logits=prediction_scores,
  965. seq_relationship_logits=seq_relationship_score,
  966. hidden_states=outputs.hidden_states,
  967. attentions=outputs.attentions,
  968. )
  969. @auto_docstring(
  970. custom_intro="""
  971. Bert Model with a `language modeling` head on top for CLM fine-tuning.
  972. """
  973. )
  974. class BertLMHeadModel(BertPreTrainedModel, GenerationMixin):
  975. _tied_weights_keys = ["cls.predictions.decoder.bias", "cls.predictions.decoder.weight"]
  976. def __init__(self, config):
  977. super().__init__(config)
  978. if not config.is_decoder:
  979. logger.warning("If you want to use `BertLMHeadModel` as a standalone, add `is_decoder=True.`")
  980. self.bert = BertModel(config, add_pooling_layer=False)
  981. self.cls = BertOnlyMLMHead(config)
  982. # Initialize weights and apply final processing
  983. self.post_init()
  984. def get_output_embeddings(self):
  985. return self.cls.predictions.decoder
  986. def set_output_embeddings(self, new_embeddings):
  987. self.cls.predictions.decoder = new_embeddings
  988. self.cls.predictions.bias = new_embeddings.bias
  989. @auto_docstring
  990. def forward(
  991. self,
  992. input_ids: Optional[torch.Tensor] = None,
  993. attention_mask: Optional[torch.Tensor] = None,
  994. token_type_ids: Optional[torch.Tensor] = None,
  995. position_ids: Optional[torch.Tensor] = None,
  996. head_mask: Optional[torch.Tensor] = None,
  997. inputs_embeds: Optional[torch.Tensor] = None,
  998. encoder_hidden_states: Optional[torch.Tensor] = None,
  999. encoder_attention_mask: Optional[torch.Tensor] = None,
  1000. labels: Optional[torch.Tensor] = None,
  1001. past_key_values: Optional[Cache] = None,
  1002. use_cache: Optional[bool] = None,
  1003. output_attentions: Optional[bool] = None,
  1004. output_hidden_states: Optional[bool] = None,
  1005. return_dict: Optional[bool] = None,
  1006. cache_position: Optional[torch.Tensor] = None,
  1007. **loss_kwargs,
  1008. ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
  1009. r"""
  1010. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1011. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  1012. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  1013. ignored (masked), the loss is only computed for the tokens with labels n `[0, ..., config.vocab_size]`
  1014. """
  1015. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1016. if labels is not None:
  1017. use_cache = False
  1018. outputs = self.bert(
  1019. input_ids,
  1020. attention_mask=attention_mask,
  1021. token_type_ids=token_type_ids,
  1022. position_ids=position_ids,
  1023. head_mask=head_mask,
  1024. inputs_embeds=inputs_embeds,
  1025. encoder_hidden_states=encoder_hidden_states,
  1026. encoder_attention_mask=encoder_attention_mask,
  1027. past_key_values=past_key_values,
  1028. use_cache=use_cache,
  1029. output_attentions=output_attentions,
  1030. output_hidden_states=output_hidden_states,
  1031. return_dict=return_dict,
  1032. cache_position=cache_position,
  1033. )
  1034. sequence_output = outputs[0]
  1035. prediction_scores = self.cls(sequence_output)
  1036. lm_loss = None
  1037. if labels is not None:
  1038. lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)
  1039. if not return_dict:
  1040. output = (prediction_scores,) + outputs[2:]
  1041. return ((lm_loss,) + output) if lm_loss is not None else output
  1042. return CausalLMOutputWithCrossAttentions(
  1043. loss=lm_loss,
  1044. logits=prediction_scores,
  1045. past_key_values=outputs.past_key_values,
  1046. hidden_states=outputs.hidden_states,
  1047. attentions=outputs.attentions,
  1048. cross_attentions=outputs.cross_attentions,
  1049. )
  1050. @auto_docstring
  1051. class BertForMaskedLM(BertPreTrainedModel):
  1052. _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
  1053. def __init__(self, config):
  1054. super().__init__(config)
  1055. if config.is_decoder:
  1056. logger.warning(
  1057. "If you want to use `BertForMaskedLM` make sure `config.is_decoder=False` for "
  1058. "bi-directional self-attention."
  1059. )
  1060. self.bert = BertModel(config, add_pooling_layer=False)
  1061. self.cls = BertOnlyMLMHead(config)
  1062. # Initialize weights and apply final processing
  1063. self.post_init()
  1064. def get_output_embeddings(self):
  1065. return self.cls.predictions.decoder
  1066. def set_output_embeddings(self, new_embeddings):
  1067. self.cls.predictions.decoder = new_embeddings
  1068. self.cls.predictions.bias = new_embeddings.bias
  1069. @auto_docstring
  1070. def forward(
  1071. self,
  1072. input_ids: Optional[torch.Tensor] = None,
  1073. attention_mask: Optional[torch.Tensor] = None,
  1074. token_type_ids: Optional[torch.Tensor] = None,
  1075. position_ids: Optional[torch.Tensor] = None,
  1076. head_mask: Optional[torch.Tensor] = None,
  1077. inputs_embeds: Optional[torch.Tensor] = None,
  1078. encoder_hidden_states: Optional[torch.Tensor] = None,
  1079. encoder_attention_mask: Optional[torch.Tensor] = None,
  1080. labels: Optional[torch.Tensor] = None,
  1081. output_attentions: Optional[bool] = None,
  1082. output_hidden_states: Optional[bool] = None,
  1083. return_dict: Optional[bool] = None,
  1084. ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
  1085. r"""
  1086. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1087. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  1088. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  1089. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  1090. """
  1091. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1092. outputs = self.bert(
  1093. input_ids,
  1094. attention_mask=attention_mask,
  1095. token_type_ids=token_type_ids,
  1096. position_ids=position_ids,
  1097. head_mask=head_mask,
  1098. inputs_embeds=inputs_embeds,
  1099. encoder_hidden_states=encoder_hidden_states,
  1100. encoder_attention_mask=encoder_attention_mask,
  1101. output_attentions=output_attentions,
  1102. output_hidden_states=output_hidden_states,
  1103. return_dict=return_dict,
  1104. )
  1105. sequence_output = outputs[0]
  1106. prediction_scores = self.cls(sequence_output)
  1107. masked_lm_loss = None
  1108. if labels is not None:
  1109. loss_fct = CrossEntropyLoss() # -100 index = padding token
  1110. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  1111. if not return_dict:
  1112. output = (prediction_scores,) + outputs[2:]
  1113. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  1114. return MaskedLMOutput(
  1115. loss=masked_lm_loss,
  1116. logits=prediction_scores,
  1117. hidden_states=outputs.hidden_states,
  1118. attentions=outputs.attentions,
  1119. )
  1120. def prepare_inputs_for_generation(self, input_ids, attention_mask=None, **model_kwargs):
  1121. input_shape = input_ids.shape
  1122. effective_batch_size = input_shape[0]
  1123. # add a dummy token
  1124. if self.config.pad_token_id is None:
  1125. raise ValueError("The PAD token should be defined for generation")
  1126. attention_mask = torch.cat([attention_mask, attention_mask.new_zeros((attention_mask.shape[0], 1))], dim=-1)
  1127. dummy_token = torch.full(
  1128. (effective_batch_size, 1), self.config.pad_token_id, dtype=torch.long, device=input_ids.device
  1129. )
  1130. input_ids = torch.cat([input_ids, dummy_token], dim=1)
  1131. return {"input_ids": input_ids, "attention_mask": attention_mask}
  1132. @classmethod
  1133. def can_generate(cls) -> bool:
  1134. """
  1135. Legacy correction: BertForMaskedLM can't call `generate()` from `GenerationMixin`, even though it has a
  1136. `prepare_inputs_for_generation` method.
  1137. """
  1138. return False
  1139. @auto_docstring(
  1140. custom_intro="""
  1141. Bert Model with a `next sentence prediction (classification)` head on top.
  1142. """
  1143. )
  1144. class BertForNextSentencePrediction(BertPreTrainedModel):
  1145. def __init__(self, config):
  1146. super().__init__(config)
  1147. self.bert = BertModel(config)
  1148. self.cls = BertOnlyNSPHead(config)
  1149. # Initialize weights and apply final processing
  1150. self.post_init()
  1151. @auto_docstring
  1152. def forward(
  1153. self,
  1154. input_ids: Optional[torch.Tensor] = None,
  1155. attention_mask: Optional[torch.Tensor] = None,
  1156. token_type_ids: Optional[torch.Tensor] = None,
  1157. position_ids: Optional[torch.Tensor] = None,
  1158. head_mask: Optional[torch.Tensor] = None,
  1159. inputs_embeds: Optional[torch.Tensor] = None,
  1160. labels: Optional[torch.Tensor] = None,
  1161. output_attentions: Optional[bool] = None,
  1162. output_hidden_states: Optional[bool] = None,
  1163. return_dict: Optional[bool] = None,
  1164. **kwargs,
  1165. ) -> Union[tuple[torch.Tensor], NextSentencePredictorOutput]:
  1166. r"""
  1167. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1168. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
  1169. (see `input_ids` docstring). Indices should be in `[0, 1]`:
  1170. - 0 indicates sequence B is a continuation of sequence A,
  1171. - 1 indicates sequence B is a random sequence.
  1172. Example:
  1173. ```python
  1174. >>> from transformers import AutoTokenizer, BertForNextSentencePrediction
  1175. >>> import torch
  1176. >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
  1177. >>> model = BertForNextSentencePrediction.from_pretrained("google-bert/bert-base-uncased")
  1178. >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
  1179. >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
  1180. >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
  1181. >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
  1182. >>> logits = outputs.logits
  1183. >>> assert logits[0, 0] < logits[0, 1] # next sentence was random
  1184. ```
  1185. """
  1186. if "next_sentence_label" in kwargs:
  1187. warnings.warn(
  1188. "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
  1189. " `labels` instead.",
  1190. FutureWarning,
  1191. )
  1192. labels = kwargs.pop("next_sentence_label")
  1193. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1194. outputs = self.bert(
  1195. input_ids,
  1196. attention_mask=attention_mask,
  1197. token_type_ids=token_type_ids,
  1198. position_ids=position_ids,
  1199. head_mask=head_mask,
  1200. inputs_embeds=inputs_embeds,
  1201. output_attentions=output_attentions,
  1202. output_hidden_states=output_hidden_states,
  1203. return_dict=return_dict,
  1204. )
  1205. pooled_output = outputs[1]
  1206. seq_relationship_scores = self.cls(pooled_output)
  1207. next_sentence_loss = None
  1208. if labels is not None:
  1209. loss_fct = CrossEntropyLoss()
  1210. next_sentence_loss = loss_fct(seq_relationship_scores.view(-1, 2), labels.view(-1))
  1211. if not return_dict:
  1212. output = (seq_relationship_scores,) + outputs[2:]
  1213. return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
  1214. return NextSentencePredictorOutput(
  1215. loss=next_sentence_loss,
  1216. logits=seq_relationship_scores,
  1217. hidden_states=outputs.hidden_states,
  1218. attentions=outputs.attentions,
  1219. )
  1220. @auto_docstring(
  1221. custom_intro="""
  1222. Bert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  1223. output) e.g. for GLUE tasks.
  1224. """
  1225. )
  1226. class BertForSequenceClassification(BertPreTrainedModel):
  1227. def __init__(self, config):
  1228. super().__init__(config)
  1229. self.num_labels = config.num_labels
  1230. self.config = config
  1231. self.bert = BertModel(config)
  1232. classifier_dropout = (
  1233. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1234. )
  1235. self.dropout = nn.Dropout(classifier_dropout)
  1236. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1237. # Initialize weights and apply final processing
  1238. self.post_init()
  1239. @auto_docstring
  1240. def forward(
  1241. self,
  1242. input_ids: Optional[torch.Tensor] = None,
  1243. attention_mask: Optional[torch.Tensor] = None,
  1244. token_type_ids: Optional[torch.Tensor] = None,
  1245. position_ids: Optional[torch.Tensor] = None,
  1246. head_mask: Optional[torch.Tensor] = None,
  1247. inputs_embeds: Optional[torch.Tensor] = None,
  1248. labels: Optional[torch.Tensor] = None,
  1249. output_attentions: Optional[bool] = None,
  1250. output_hidden_states: Optional[bool] = None,
  1251. return_dict: Optional[bool] = None,
  1252. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  1253. r"""
  1254. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1255. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1256. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1257. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1258. """
  1259. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1260. outputs = self.bert(
  1261. input_ids,
  1262. attention_mask=attention_mask,
  1263. token_type_ids=token_type_ids,
  1264. position_ids=position_ids,
  1265. head_mask=head_mask,
  1266. inputs_embeds=inputs_embeds,
  1267. output_attentions=output_attentions,
  1268. output_hidden_states=output_hidden_states,
  1269. return_dict=return_dict,
  1270. )
  1271. pooled_output = outputs[1]
  1272. pooled_output = self.dropout(pooled_output)
  1273. logits = self.classifier(pooled_output)
  1274. loss = None
  1275. if labels is not None:
  1276. if self.config.problem_type is None:
  1277. if self.num_labels == 1:
  1278. self.config.problem_type = "regression"
  1279. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1280. self.config.problem_type = "single_label_classification"
  1281. else:
  1282. self.config.problem_type = "multi_label_classification"
  1283. if self.config.problem_type == "regression":
  1284. loss_fct = MSELoss()
  1285. if self.num_labels == 1:
  1286. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1287. else:
  1288. loss = loss_fct(logits, labels)
  1289. elif self.config.problem_type == "single_label_classification":
  1290. loss_fct = CrossEntropyLoss()
  1291. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1292. elif self.config.problem_type == "multi_label_classification":
  1293. loss_fct = BCEWithLogitsLoss()
  1294. loss = loss_fct(logits, labels)
  1295. if not return_dict:
  1296. output = (logits,) + outputs[2:]
  1297. return ((loss,) + output) if loss is not None else output
  1298. return SequenceClassifierOutput(
  1299. loss=loss,
  1300. logits=logits,
  1301. hidden_states=outputs.hidden_states,
  1302. attentions=outputs.attentions,
  1303. )
  1304. @auto_docstring
  1305. class BertForMultipleChoice(BertPreTrainedModel):
  1306. def __init__(self, config):
  1307. super().__init__(config)
  1308. self.bert = BertModel(config)
  1309. classifier_dropout = (
  1310. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1311. )
  1312. self.dropout = nn.Dropout(classifier_dropout)
  1313. self.classifier = nn.Linear(config.hidden_size, 1)
  1314. # Initialize weights and apply final processing
  1315. self.post_init()
  1316. @auto_docstring
  1317. def forward(
  1318. self,
  1319. input_ids: Optional[torch.Tensor] = None,
  1320. attention_mask: Optional[torch.Tensor] = None,
  1321. token_type_ids: Optional[torch.Tensor] = None,
  1322. position_ids: Optional[torch.Tensor] = None,
  1323. head_mask: Optional[torch.Tensor] = None,
  1324. inputs_embeds: Optional[torch.Tensor] = None,
  1325. labels: Optional[torch.Tensor] = None,
  1326. output_attentions: Optional[bool] = None,
  1327. output_hidden_states: Optional[bool] = None,
  1328. return_dict: Optional[bool] = None,
  1329. ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
  1330. r"""
  1331. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  1332. Indices of input sequence tokens in the vocabulary.
  1333. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1334. [`PreTrainedTokenizer.__call__`] for details.
  1335. [What are input IDs?](../glossary#input-ids)
  1336. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1337. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1338. 1]`:
  1339. - 0 corresponds to a *sentence A* token,
  1340. - 1 corresponds to a *sentence B* token.
  1341. [What are token type IDs?](../glossary#token-type-ids)
  1342. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1343. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1344. config.max_position_embeddings - 1]`.
  1345. [What are position IDs?](../glossary#position-ids)
  1346. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1347. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1348. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1349. model's internal embedding lookup matrix.
  1350. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1351. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1352. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1353. `input_ids` above)
  1354. """
  1355. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1356. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1357. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1358. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1359. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1360. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1361. inputs_embeds = (
  1362. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1363. if inputs_embeds is not None
  1364. else None
  1365. )
  1366. outputs = self.bert(
  1367. input_ids,
  1368. attention_mask=attention_mask,
  1369. token_type_ids=token_type_ids,
  1370. position_ids=position_ids,
  1371. head_mask=head_mask,
  1372. inputs_embeds=inputs_embeds,
  1373. output_attentions=output_attentions,
  1374. output_hidden_states=output_hidden_states,
  1375. return_dict=return_dict,
  1376. )
  1377. pooled_output = outputs[1]
  1378. pooled_output = self.dropout(pooled_output)
  1379. logits = self.classifier(pooled_output)
  1380. reshaped_logits = logits.view(-1, num_choices)
  1381. loss = None
  1382. if labels is not None:
  1383. loss_fct = CrossEntropyLoss()
  1384. loss = loss_fct(reshaped_logits, labels)
  1385. if not return_dict:
  1386. output = (reshaped_logits,) + outputs[2:]
  1387. return ((loss,) + output) if loss is not None else output
  1388. return MultipleChoiceModelOutput(
  1389. loss=loss,
  1390. logits=reshaped_logits,
  1391. hidden_states=outputs.hidden_states,
  1392. attentions=outputs.attentions,
  1393. )
  1394. @auto_docstring
  1395. class BertForTokenClassification(BertPreTrainedModel):
  1396. def __init__(self, config):
  1397. super().__init__(config)
  1398. self.num_labels = config.num_labels
  1399. self.bert = BertModel(config, add_pooling_layer=False)
  1400. classifier_dropout = (
  1401. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1402. )
  1403. self.dropout = nn.Dropout(classifier_dropout)
  1404. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1405. # Initialize weights and apply final processing
  1406. self.post_init()
  1407. @auto_docstring
  1408. def forward(
  1409. self,
  1410. input_ids: Optional[torch.Tensor] = None,
  1411. attention_mask: Optional[torch.Tensor] = None,
  1412. token_type_ids: Optional[torch.Tensor] = None,
  1413. position_ids: Optional[torch.Tensor] = None,
  1414. head_mask: Optional[torch.Tensor] = None,
  1415. inputs_embeds: Optional[torch.Tensor] = None,
  1416. labels: Optional[torch.Tensor] = None,
  1417. output_attentions: Optional[bool] = None,
  1418. output_hidden_states: Optional[bool] = None,
  1419. return_dict: Optional[bool] = None,
  1420. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  1421. r"""
  1422. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1423. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1424. """
  1425. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1426. outputs = self.bert(
  1427. input_ids,
  1428. attention_mask=attention_mask,
  1429. token_type_ids=token_type_ids,
  1430. position_ids=position_ids,
  1431. head_mask=head_mask,
  1432. inputs_embeds=inputs_embeds,
  1433. output_attentions=output_attentions,
  1434. output_hidden_states=output_hidden_states,
  1435. return_dict=return_dict,
  1436. )
  1437. sequence_output = outputs[0]
  1438. sequence_output = self.dropout(sequence_output)
  1439. logits = self.classifier(sequence_output)
  1440. loss = None
  1441. if labels is not None:
  1442. loss_fct = CrossEntropyLoss()
  1443. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1444. if not return_dict:
  1445. output = (logits,) + outputs[2:]
  1446. return ((loss,) + output) if loss is not None else output
  1447. return TokenClassifierOutput(
  1448. loss=loss,
  1449. logits=logits,
  1450. hidden_states=outputs.hidden_states,
  1451. attentions=outputs.attentions,
  1452. )
  1453. @auto_docstring
  1454. class BertForQuestionAnswering(BertPreTrainedModel):
  1455. def __init__(self, config):
  1456. super().__init__(config)
  1457. self.num_labels = config.num_labels
  1458. self.bert = BertModel(config, add_pooling_layer=False)
  1459. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1460. # Initialize weights and apply final processing
  1461. self.post_init()
  1462. @auto_docstring
  1463. def forward(
  1464. self,
  1465. input_ids: Optional[torch.Tensor] = None,
  1466. attention_mask: Optional[torch.Tensor] = None,
  1467. token_type_ids: Optional[torch.Tensor] = None,
  1468. position_ids: Optional[torch.Tensor] = None,
  1469. head_mask: Optional[torch.Tensor] = None,
  1470. inputs_embeds: Optional[torch.Tensor] = None,
  1471. start_positions: Optional[torch.Tensor] = None,
  1472. end_positions: Optional[torch.Tensor] = None,
  1473. output_attentions: Optional[bool] = None,
  1474. output_hidden_states: Optional[bool] = None,
  1475. return_dict: Optional[bool] = None,
  1476. ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
  1477. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1478. outputs = self.bert(
  1479. input_ids,
  1480. attention_mask=attention_mask,
  1481. token_type_ids=token_type_ids,
  1482. position_ids=position_ids,
  1483. head_mask=head_mask,
  1484. inputs_embeds=inputs_embeds,
  1485. output_attentions=output_attentions,
  1486. output_hidden_states=output_hidden_states,
  1487. return_dict=return_dict,
  1488. )
  1489. sequence_output = outputs[0]
  1490. logits = self.qa_outputs(sequence_output)
  1491. start_logits, end_logits = logits.split(1, dim=-1)
  1492. start_logits = start_logits.squeeze(-1).contiguous()
  1493. end_logits = end_logits.squeeze(-1).contiguous()
  1494. total_loss = None
  1495. if start_positions is not None and end_positions is not None:
  1496. # If we are on multi-GPU, split add a dimension
  1497. if len(start_positions.size()) > 1:
  1498. start_positions = start_positions.squeeze(-1)
  1499. if len(end_positions.size()) > 1:
  1500. end_positions = end_positions.squeeze(-1)
  1501. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1502. ignored_index = start_logits.size(1)
  1503. start_positions = start_positions.clamp(0, ignored_index)
  1504. end_positions = end_positions.clamp(0, ignored_index)
  1505. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1506. start_loss = loss_fct(start_logits, start_positions)
  1507. end_loss = loss_fct(end_logits, end_positions)
  1508. total_loss = (start_loss + end_loss) / 2
  1509. if not return_dict:
  1510. output = (start_logits, end_logits) + outputs[2:]
  1511. return ((total_loss,) + output) if total_loss is not None else output
  1512. return QuestionAnsweringModelOutput(
  1513. loss=total_loss,
  1514. start_logits=start_logits,
  1515. end_logits=end_logits,
  1516. hidden_states=outputs.hidden_states,
  1517. attentions=outputs.attentions,
  1518. )
# Public names of this module: what `from <module> import *` exports.
# NOTE(review): kept as a plain literal list on purpose — repository tooling likely reads it
# statically, so do not replace it with a computed expression; confirm before changing its form.
__all__ = [
    "BertForMaskedLM",
    "BertForMultipleChoice",
    "BertForNextSentencePrediction",
    "BertForPreTraining",
    "BertForQuestionAnswering",
    "BertForSequenceClassification",
    "BertForTokenClassification",
    "BertLayer",
    "BertLMHeadModel",
    "BertModel",
    "BertPreTrainedModel",
    "load_tf_weights_in_bert",
]