modeling_albert.py 56 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349
  1. # coding=utf-8
  2. # Copyright 2018 Google AI, Google Brain and the HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ALBERT model."""
  16. import math
  17. import os
  18. from dataclasses import dataclass
  19. from typing import Optional, Union
  20. import torch
  21. from torch import nn
  22. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  23. from ...activations import ACT2FN
  24. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa
  25. from ...modeling_outputs import (
  26. BaseModelOutput,
  27. BaseModelOutputWithPooling,
  28. MaskedLMOutput,
  29. MultipleChoiceModelOutput,
  30. QuestionAnsweringModelOutput,
  31. SequenceClassifierOutput,
  32. TokenClassifierOutput,
  33. )
  34. from ...modeling_utils import PreTrainedModel
  35. from ...pytorch_utils import (
  36. apply_chunking_to_forward,
  37. find_pruneable_heads_and_indices,
  38. prune_linear_layer,
  39. )
  40. from ...utils import ModelOutput, auto_docstring, logging
  41. from .configuration_albert import AlbertConfig
# Module-level logger from the library's `utils.logging`; the TF-loading helper and the SDPA attention
# fallback below report through it.
logger = logging.get_logger(__name__)
  43. def load_tf_weights_in_albert(model, config, tf_checkpoint_path):
  44. """Load tf checkpoints in a pytorch model."""
  45. try:
  46. import re
  47. import numpy as np
  48. import tensorflow as tf
  49. except ImportError:
  50. logger.error(
  51. "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
  52. "https://www.tensorflow.org/install/ for installation instructions."
  53. )
  54. raise
  55. tf_path = os.path.abspath(tf_checkpoint_path)
  56. logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
  57. # Load weights from TF model
  58. init_vars = tf.train.list_variables(tf_path)
  59. names = []
  60. arrays = []
  61. for name, shape in init_vars:
  62. logger.info(f"Loading TF weight {name} with shape {shape}")
  63. array = tf.train.load_variable(tf_path, name)
  64. names.append(name)
  65. arrays.append(array)
  66. for name, array in zip(names, arrays):
  67. print(name)
  68. for name, array in zip(names, arrays):
  69. original_name = name
  70. # If saved from the TF HUB module
  71. name = name.replace("module/", "")
  72. # Renaming and simplifying
  73. name = name.replace("ffn_1", "ffn")
  74. name = name.replace("bert/", "albert/")
  75. name = name.replace("attention_1", "attention")
  76. name = name.replace("transform/", "")
  77. name = name.replace("LayerNorm_1", "full_layer_layer_norm")
  78. name = name.replace("LayerNorm", "attention/LayerNorm")
  79. name = name.replace("transformer/", "")
  80. # The feed forward layer had an 'intermediate' step which has been abstracted away
  81. name = name.replace("intermediate/dense/", "")
  82. name = name.replace("ffn/intermediate/output/dense/", "ffn_output/")
  83. # ALBERT attention was split between self and output which have been abstracted away
  84. name = name.replace("/output/", "/")
  85. name = name.replace("/self/", "/")
  86. # The pooler is a linear layer
  87. name = name.replace("pooler/dense", "pooler")
  88. # The classifier was simplified to predictions from cls/predictions
  89. name = name.replace("cls/predictions", "predictions")
  90. name = name.replace("predictions/attention", "predictions")
  91. # Naming was changed to be more explicit
  92. name = name.replace("embeddings/attention", "embeddings")
  93. name = name.replace("inner_group_", "albert_layers/")
  94. name = name.replace("group_", "albert_layer_groups/")
  95. # Classifier
  96. if len(name.split("/")) == 1 and ("output_bias" in name or "output_weights" in name):
  97. name = "classifier/" + name
  98. # No ALBERT model currently handles the next sentence prediction task
  99. if "seq_relationship" in name:
  100. name = name.replace("seq_relationship/output_", "sop_classifier/classifier/")
  101. name = name.replace("weights", "weight")
  102. name = name.split("/")
  103. # Ignore the gradients applied by the LAMB/ADAM optimizers.
  104. if (
  105. "adam_m" in name
  106. or "adam_v" in name
  107. or "AdamWeightDecayOptimizer" in name
  108. or "AdamWeightDecayOptimizer_1" in name
  109. or "global_step" in name
  110. ):
  111. logger.info(f"Skipping {'/'.join(name)}")
  112. continue
  113. pointer = model
  114. for m_name in name:
  115. if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
  116. scope_names = re.split(r"_(\d+)", m_name)
  117. else:
  118. scope_names = [m_name]
  119. if scope_names[0] == "kernel" or scope_names[0] == "gamma":
  120. pointer = getattr(pointer, "weight")
  121. elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
  122. pointer = getattr(pointer, "bias")
  123. elif scope_names[0] == "output_weights":
  124. pointer = getattr(pointer, "weight")
  125. elif scope_names[0] == "squad":
  126. pointer = getattr(pointer, "classifier")
  127. else:
  128. try:
  129. pointer = getattr(pointer, scope_names[0])
  130. except AttributeError:
  131. logger.info(f"Skipping {'/'.join(name)}")
  132. continue
  133. if len(scope_names) >= 2:
  134. num = int(scope_names[1])
  135. pointer = pointer[num]
  136. if m_name[-11:] == "_embeddings":
  137. pointer = getattr(pointer, "weight")
  138. elif m_name == "kernel":
  139. array = np.transpose(array)
  140. try:
  141. if pointer.shape != array.shape:
  142. raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
  143. except ValueError as e:
  144. e.args += (pointer.shape, array.shape)
  145. raise
  146. print(f"Initialize PyTorch weight {name} from {original_name}")
  147. pointer.data = torch.from_numpy(array)
  148. return model
  149. class AlbertEmbeddings(nn.Module):
  150. """
  151. Construct the embeddings from word, position and token_type embeddings.
  152. """
  153. def __init__(self, config: AlbertConfig):
  154. super().__init__()
  155. self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
  156. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
  157. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
  158. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  159. # any TensorFlow checkpoint file
  160. self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
  161. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  162. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  163. self.register_buffer(
  164. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  165. )
  166. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  167. self.register_buffer(
  168. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  169. )
  170. # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.forward
  171. def forward(
  172. self,
  173. input_ids: Optional[torch.LongTensor] = None,
  174. token_type_ids: Optional[torch.LongTensor] = None,
  175. position_ids: Optional[torch.LongTensor] = None,
  176. inputs_embeds: Optional[torch.FloatTensor] = None,
  177. past_key_values_length: int = 0,
  178. ) -> torch.Tensor:
  179. if input_ids is not None:
  180. input_shape = input_ids.size()
  181. else:
  182. input_shape = inputs_embeds.size()[:-1]
  183. seq_length = input_shape[1]
  184. if position_ids is None:
  185. position_ids = self.position_ids[:, past_key_values_length : seq_length + past_key_values_length]
  186. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  187. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  188. # issue #5664
  189. if token_type_ids is None:
  190. if hasattr(self, "token_type_ids"):
  191. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  192. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  193. token_type_ids = buffered_token_type_ids_expanded
  194. else:
  195. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  196. if inputs_embeds is None:
  197. inputs_embeds = self.word_embeddings(input_ids)
  198. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  199. embeddings = inputs_embeds + token_type_embeddings
  200. if self.position_embedding_type == "absolute":
  201. position_embeddings = self.position_embeddings(position_ids)
  202. embeddings += position_embeddings
  203. embeddings = self.LayerNorm(embeddings)
  204. embeddings = self.dropout(embeddings)
  205. return embeddings
  206. class AlbertAttention(nn.Module):
  207. def __init__(self, config: AlbertConfig):
  208. super().__init__()
  209. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  210. raise ValueError(
  211. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  212. f"heads ({config.num_attention_heads}"
  213. )
  214. self.num_attention_heads = config.num_attention_heads
  215. self.hidden_size = config.hidden_size
  216. self.attention_head_size = config.hidden_size // config.num_attention_heads
  217. self.all_head_size = self.num_attention_heads * self.attention_head_size
  218. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  219. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  220. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  221. self.attention_dropout = nn.Dropout(config.attention_probs_dropout_prob)
  222. self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
  223. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  224. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  225. self.pruned_heads = set()
  226. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  227. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  228. self.max_position_embeddings = config.max_position_embeddings
  229. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  230. def prune_heads(self, heads: list[int]) -> None:
  231. if len(heads) == 0:
  232. return
  233. heads, index = find_pruneable_heads_and_indices(
  234. heads, self.num_attention_heads, self.attention_head_size, self.pruned_heads
  235. )
  236. # Prune linear layers
  237. self.query = prune_linear_layer(self.query, index)
  238. self.key = prune_linear_layer(self.key, index)
  239. self.value = prune_linear_layer(self.value, index)
  240. self.dense = prune_linear_layer(self.dense, index, dim=1)
  241. # Update hyper params and store pruned heads
  242. self.num_attention_heads = self.num_attention_heads - len(heads)
  243. self.all_head_size = self.attention_head_size * self.num_attention_heads
  244. self.pruned_heads = self.pruned_heads.union(heads)
  245. def forward(
  246. self,
  247. hidden_states: torch.Tensor,
  248. attention_mask: Optional[torch.FloatTensor] = None,
  249. head_mask: Optional[torch.FloatTensor] = None,
  250. output_attentions: bool = False,
  251. ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
  252. batch_size, seq_length, _ = hidden_states.shape
  253. query_layer = self.query(hidden_states)
  254. key_layer = self.key(hidden_states)
  255. value_layer = self.value(hidden_states)
  256. query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
  257. 1, 2
  258. )
  259. key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  260. value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
  261. 1, 2
  262. )
  263. # Take the dot product between "query" and "key" to get the raw attention scores.
  264. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  265. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  266. if attention_mask is not None:
  267. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  268. attention_scores = attention_scores + attention_mask
  269. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  270. seq_length = hidden_states.size()[1]
  271. position_ids_l = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  272. position_ids_r = torch.arange(seq_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  273. distance = position_ids_l - position_ids_r
  274. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  275. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  276. if self.position_embedding_type == "relative_key":
  277. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  278. attention_scores = attention_scores + relative_position_scores
  279. elif self.position_embedding_type == "relative_key_query":
  280. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  281. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  282. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  283. # Normalize the attention scores to probabilities.
  284. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  285. # This is actually dropping out entire tokens to attend to, which might
  286. # seem a bit unusual, but is taken from the original Transformer paper.
  287. attention_probs = self.attention_dropout(attention_probs)
  288. # Mask heads if we want to
  289. if head_mask is not None:
  290. attention_probs = attention_probs * head_mask
  291. context_layer = torch.matmul(attention_probs, value_layer)
  292. context_layer = context_layer.transpose(2, 1).flatten(2)
  293. projected_context_layer = self.dense(context_layer)
  294. projected_context_layer_dropout = self.output_dropout(projected_context_layer)
  295. layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
  296. return (layernormed_context_layer, attention_probs) if output_attentions else (layernormed_context_layer,)
  297. class AlbertSdpaAttention(AlbertAttention):
  298. def __init__(self, config):
  299. super().__init__(config)
  300. self.dropout_prob = config.attention_probs_dropout_prob
  301. def forward(
  302. self,
  303. hidden_states: torch.Tensor,
  304. attention_mask: Optional[torch.FloatTensor] = None,
  305. head_mask: Optional[torch.FloatTensor] = None,
  306. output_attentions: bool = False,
  307. ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
  308. if self.position_embedding_type != "absolute" or output_attentions:
  309. logger.warning(
  310. "AlbertSdpaAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
  311. "non-absolute `position_embedding_type` or `output_attentions=True` . Falling back to "
  312. "the eager attention implementation, but specifying the eager implementation will be required from "
  313. "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
  314. '`attn_implementation="eager"` when loading the model.'
  315. )
  316. return super().forward(hidden_states, attention_mask, output_attentions=output_attentions)
  317. batch_size, seq_len, _ = hidden_states.size()
  318. query_layer = (
  319. self.query(hidden_states)
  320. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  321. .transpose(1, 2)
  322. )
  323. key_layer = (
  324. self.key(hidden_states)
  325. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  326. .transpose(1, 2)
  327. )
  328. value_layer = (
  329. self.value(hidden_states)
  330. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  331. .transpose(1, 2)
  332. )
  333. attention_output = torch.nn.functional.scaled_dot_product_attention(
  334. query=query_layer,
  335. key=key_layer,
  336. value=value_layer,
  337. attn_mask=attention_mask,
  338. dropout_p=self.dropout_prob if self.training else 0.0,
  339. is_causal=False,
  340. )
  341. attention_output = attention_output.transpose(1, 2)
  342. attention_output = attention_output.reshape(batch_size, seq_len, self.all_head_size)
  343. projected_context_layer = self.dense(attention_output)
  344. projected_context_layer_dropout = self.output_dropout(projected_context_layer)
  345. layernormed_context_layer = self.LayerNorm(hidden_states + projected_context_layer_dropout)
  346. return (layernormed_context_layer,)
  347. ALBERT_ATTENTION_CLASSES = {
  348. "eager": AlbertAttention,
  349. "sdpa": AlbertSdpaAttention,
  350. }
  351. class AlbertLayer(nn.Module):
  352. def __init__(self, config: AlbertConfig):
  353. super().__init__()
  354. self.config = config
  355. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  356. self.seq_len_dim = 1
  357. self.full_layer_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  358. self.attention = ALBERT_ATTENTION_CLASSES[config._attn_implementation](config)
  359. self.ffn = nn.Linear(config.hidden_size, config.intermediate_size)
  360. self.ffn_output = nn.Linear(config.intermediate_size, config.hidden_size)
  361. self.activation = ACT2FN[config.hidden_act]
  362. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  363. def forward(
  364. self,
  365. hidden_states: torch.Tensor,
  366. attention_mask: Optional[torch.FloatTensor] = None,
  367. head_mask: Optional[torch.FloatTensor] = None,
  368. output_attentions: bool = False,
  369. output_hidden_states: bool = False,
  370. ) -> tuple[torch.Tensor, torch.Tensor]:
  371. attention_output = self.attention(hidden_states, attention_mask, head_mask, output_attentions)
  372. ffn_output = apply_chunking_to_forward(
  373. self.ff_chunk,
  374. self.chunk_size_feed_forward,
  375. self.seq_len_dim,
  376. attention_output[0],
  377. )
  378. hidden_states = self.full_layer_layer_norm(ffn_output + attention_output[0])
  379. return (hidden_states,) + attention_output[1:] # add attentions if we output them
  380. def ff_chunk(self, attention_output: torch.Tensor) -> torch.Tensor:
  381. ffn_output = self.ffn(attention_output)
  382. ffn_output = self.activation(ffn_output)
  383. ffn_output = self.ffn_output(ffn_output)
  384. return ffn_output
  385. class AlbertLayerGroup(nn.Module):
  386. def __init__(self, config: AlbertConfig):
  387. super().__init__()
  388. self.albert_layers = nn.ModuleList([AlbertLayer(config) for _ in range(config.inner_group_num)])
  389. def forward(
  390. self,
  391. hidden_states: torch.Tensor,
  392. attention_mask: Optional[torch.FloatTensor] = None,
  393. head_mask: Optional[torch.FloatTensor] = None,
  394. output_attentions: bool = False,
  395. output_hidden_states: bool = False,
  396. ) -> tuple[Union[torch.Tensor, tuple[torch.Tensor]], ...]:
  397. layer_hidden_states = ()
  398. layer_attentions = ()
  399. for layer_index, albert_layer in enumerate(self.albert_layers):
  400. layer_output = albert_layer(hidden_states, attention_mask, head_mask[layer_index], output_attentions)
  401. hidden_states = layer_output[0]
  402. if output_attentions:
  403. layer_attentions = layer_attentions + (layer_output[1],)
  404. if output_hidden_states:
  405. layer_hidden_states = layer_hidden_states + (hidden_states,)
  406. outputs = (hidden_states,)
  407. if output_hidden_states:
  408. outputs = outputs + (layer_hidden_states,)
  409. if output_attentions:
  410. outputs = outputs + (layer_attentions,)
  411. return outputs # last-layer hidden state, (layer hidden states), (layer attentions)
  412. class AlbertTransformer(nn.Module):
  413. def __init__(self, config: AlbertConfig):
  414. super().__init__()
  415. self.config = config
  416. self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
  417. self.albert_layer_groups = nn.ModuleList([AlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
  418. def forward(
  419. self,
  420. hidden_states: torch.Tensor,
  421. attention_mask: Optional[torch.FloatTensor] = None,
  422. head_mask: Optional[torch.FloatTensor] = None,
  423. output_attentions: bool = False,
  424. output_hidden_states: bool = False,
  425. return_dict: bool = True,
  426. ) -> Union[BaseModelOutput, tuple]:
  427. hidden_states = self.embedding_hidden_mapping_in(hidden_states)
  428. all_hidden_states = (hidden_states,) if output_hidden_states else None
  429. all_attentions = () if output_attentions else None
  430. head_mask = [None] * self.config.num_hidden_layers if head_mask is None else head_mask
  431. for i in range(self.config.num_hidden_layers):
  432. # Number of layers in a hidden group
  433. layers_per_group = int(self.config.num_hidden_layers / self.config.num_hidden_groups)
  434. # Index of the hidden group
  435. group_idx = int(i / (self.config.num_hidden_layers / self.config.num_hidden_groups))
  436. layer_group_output = self.albert_layer_groups[group_idx](
  437. hidden_states,
  438. attention_mask,
  439. head_mask[group_idx * layers_per_group : (group_idx + 1) * layers_per_group],
  440. output_attentions,
  441. output_hidden_states,
  442. )
  443. hidden_states = layer_group_output[0]
  444. if output_attentions:
  445. all_attentions = all_attentions + layer_group_output[-1]
  446. if output_hidden_states:
  447. all_hidden_states = all_hidden_states + (hidden_states,)
  448. if not return_dict:
  449. return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
  450. return BaseModelOutput(
  451. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
  452. )
  453. @auto_docstring
  454. class AlbertPreTrainedModel(PreTrainedModel):
  455. config: AlbertConfig
  456. load_tf_weights = load_tf_weights_in_albert
  457. base_model_prefix = "albert"
  458. _supports_sdpa = True
  459. def _init_weights(self, module):
  460. """Initialize the weights."""
  461. if isinstance(module, nn.Linear):
  462. # Slightly different from the TF version which uses truncated_normal for initialization
  463. # cf https://github.com/pytorch/pytorch/pull/5617
  464. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  465. if module.bias is not None:
  466. module.bias.data.zero_()
  467. elif isinstance(module, nn.Embedding):
  468. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  469. if module.padding_idx is not None:
  470. module.weight.data[module.padding_idx].zero_()
  471. elif isinstance(module, nn.LayerNorm):
  472. module.bias.data.zero_()
  473. module.weight.data.fill_(1.0)
  474. elif isinstance(module, AlbertMLMHead):
  475. module.bias.data.zero_()
  476. @dataclass
  477. @auto_docstring(
  478. custom_intro="""
  479. Output type of [`AlbertForPreTraining`].
  480. """
  481. )
  482. class AlbertForPreTrainingOutput(ModelOutput):
  483. r"""
  484. loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  485. Total loss as the sum of the masked language modeling loss and the next sequence prediction
  486. (classification) loss.
  487. prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  488. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  489. sop_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
  490. Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
  491. before SoftMax).
  492. """
  493. loss: Optional[torch.FloatTensor] = None
  494. prediction_logits: Optional[torch.FloatTensor] = None
  495. sop_logits: Optional[torch.FloatTensor] = None
  496. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  497. attentions: Optional[tuple[torch.FloatTensor]] = None
  498. @auto_docstring
  499. class AlbertModel(AlbertPreTrainedModel):
  500. config: AlbertConfig
  501. base_model_prefix = "albert"
  502. def __init__(self, config: AlbertConfig, add_pooling_layer: bool = True):
  503. r"""
  504. add_pooling_layer (bool, *optional*, defaults to `True`):
  505. Whether to add a pooling layer
  506. """
  507. super().__init__(config)
  508. self.config = config
  509. self.embeddings = AlbertEmbeddings(config)
  510. self.encoder = AlbertTransformer(config)
  511. if add_pooling_layer:
  512. self.pooler = nn.Linear(config.hidden_size, config.hidden_size)
  513. self.pooler_activation = nn.Tanh()
  514. else:
  515. self.pooler = None
  516. self.pooler_activation = None
  517. self.attn_implementation = config._attn_implementation
  518. self.position_embedding_type = config.position_embedding_type
  519. # Initialize weights and apply final processing
  520. self.post_init()
  521. def get_input_embeddings(self) -> nn.Embedding:
  522. return self.embeddings.word_embeddings
  523. def set_input_embeddings(self, value: nn.Embedding) -> None:
  524. self.embeddings.word_embeddings = value
  525. def _prune_heads(self, heads_to_prune: dict[int, list[int]]) -> None:
  526. """
  527. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} ALBERT has
  528. a different architecture in that its layers are shared across groups, which then has inner groups. If an ALBERT
  529. model has 12 hidden layers and 2 hidden groups, with two inner groups, there is a total of 4 different layers.
  530. These layers are flattened: the indices [0,1] correspond to the two inner groups of the first hidden layer,
  531. while [2,3] correspond to the two inner groups of the second hidden layer.
  532. Any layer with in index other than [0,1,2,3] will result in an error. See base class PreTrainedModel for more
  533. information about head pruning
  534. """
  535. for layer, heads in heads_to_prune.items():
  536. group_idx = int(layer / self.config.inner_group_num)
  537. inner_group_idx = int(layer - group_idx * self.config.inner_group_num)
  538. self.encoder.albert_layer_groups[group_idx].albert_layers[inner_group_idx].attention.prune_heads(heads)
  539. @auto_docstring
  540. def forward(
  541. self,
  542. input_ids: Optional[torch.LongTensor] = None,
  543. attention_mask: Optional[torch.FloatTensor] = None,
  544. token_type_ids: Optional[torch.LongTensor] = None,
  545. position_ids: Optional[torch.LongTensor] = None,
  546. head_mask: Optional[torch.FloatTensor] = None,
  547. inputs_embeds: Optional[torch.FloatTensor] = None,
  548. output_attentions: Optional[bool] = None,
  549. output_hidden_states: Optional[bool] = None,
  550. return_dict: Optional[bool] = None,
  551. ) -> Union[BaseModelOutputWithPooling, tuple]:
  552. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  553. output_hidden_states = (
  554. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  555. )
  556. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  557. if input_ids is not None and inputs_embeds is not None:
  558. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  559. elif input_ids is not None:
  560. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  561. input_shape = input_ids.size()
  562. elif inputs_embeds is not None:
  563. input_shape = inputs_embeds.size()[:-1]
  564. else:
  565. raise ValueError("You have to specify either input_ids or inputs_embeds")
  566. batch_size, seq_length = input_shape
  567. device = input_ids.device if input_ids is not None else inputs_embeds.device
  568. if attention_mask is None:
  569. attention_mask = torch.ones(input_shape, device=device)
  570. if token_type_ids is None:
  571. if hasattr(self.embeddings, "token_type_ids"):
  572. buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
  573. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
  574. token_type_ids = buffered_token_type_ids_expanded
  575. else:
  576. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  577. embedding_output = self.embeddings(
  578. input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  579. )
  580. use_sdpa_attention_mask = (
  581. self.attn_implementation == "sdpa"
  582. and self.position_embedding_type == "absolute"
  583. and head_mask is None
  584. and not output_attentions
  585. )
  586. if use_sdpa_attention_mask:
  587. extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  588. attention_mask, embedding_output.dtype, tgt_len=seq_length
  589. )
  590. else:
  591. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  592. extended_attention_mask = extended_attention_mask.to(dtype=self.dtype) # fp16 compatibility
  593. extended_attention_mask = (1.0 - extended_attention_mask) * torch.finfo(self.dtype).min
  594. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  595. encoder_outputs = self.encoder(
  596. embedding_output,
  597. extended_attention_mask,
  598. head_mask=head_mask,
  599. output_attentions=output_attentions,
  600. output_hidden_states=output_hidden_states,
  601. return_dict=return_dict,
  602. )
  603. sequence_output = encoder_outputs[0]
  604. pooled_output = self.pooler_activation(self.pooler(sequence_output[:, 0])) if self.pooler is not None else None
  605. if not return_dict:
  606. return (sequence_output, pooled_output) + encoder_outputs[1:]
  607. return BaseModelOutputWithPooling(
  608. last_hidden_state=sequence_output,
  609. pooler_output=pooled_output,
  610. hidden_states=encoder_outputs.hidden_states,
  611. attentions=encoder_outputs.attentions,
  612. )
  613. @auto_docstring(
  614. custom_intro="""
  615. Albert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
  616. `sentence order prediction (classification)` head.
  617. """
  618. )
  619. class AlbertForPreTraining(AlbertPreTrainedModel):
  620. _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
  621. def __init__(self, config: AlbertConfig):
  622. super().__init__(config)
  623. self.albert = AlbertModel(config)
  624. self.predictions = AlbertMLMHead(config)
  625. self.sop_classifier = AlbertSOPHead(config)
  626. # Initialize weights and apply final processing
  627. self.post_init()
  628. def get_output_embeddings(self) -> nn.Linear:
  629. return self.predictions.decoder
  630. def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
  631. self.predictions.decoder = new_embeddings
  632. def get_input_embeddings(self) -> nn.Embedding:
  633. return self.albert.embeddings.word_embeddings
  634. @auto_docstring
  635. def forward(
  636. self,
  637. input_ids: Optional[torch.LongTensor] = None,
  638. attention_mask: Optional[torch.FloatTensor] = None,
  639. token_type_ids: Optional[torch.LongTensor] = None,
  640. position_ids: Optional[torch.LongTensor] = None,
  641. head_mask: Optional[torch.FloatTensor] = None,
  642. inputs_embeds: Optional[torch.FloatTensor] = None,
  643. labels: Optional[torch.LongTensor] = None,
  644. sentence_order_label: Optional[torch.LongTensor] = None,
  645. output_attentions: Optional[bool] = None,
  646. output_hidden_states: Optional[bool] = None,
  647. return_dict: Optional[bool] = None,
  648. ) -> Union[AlbertForPreTrainingOutput, tuple]:
  649. r"""
  650. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  651. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  652. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  653. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  654. sentence_order_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  655. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
  656. (see `input_ids` docstring) Indices should be in `[0, 1]`. `0` indicates original order (sequence A, then
  657. sequence B), `1` indicates switched order (sequence B, then sequence A).
  658. Example:
  659. ```python
  660. >>> from transformers import AutoTokenizer, AlbertForPreTraining
  661. >>> import torch
  662. >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
  663. >>> model = AlbertForPreTraining.from_pretrained("albert/albert-base-v2")
  664. >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
  665. >>> # Batch size 1
  666. >>> outputs = model(input_ids)
  667. >>> prediction_logits = outputs.prediction_logits
  668. >>> sop_logits = outputs.sop_logits
  669. ```"""
  670. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  671. outputs = self.albert(
  672. input_ids,
  673. attention_mask=attention_mask,
  674. token_type_ids=token_type_ids,
  675. position_ids=position_ids,
  676. head_mask=head_mask,
  677. inputs_embeds=inputs_embeds,
  678. output_attentions=output_attentions,
  679. output_hidden_states=output_hidden_states,
  680. return_dict=return_dict,
  681. )
  682. sequence_output, pooled_output = outputs[:2]
  683. prediction_scores = self.predictions(sequence_output)
  684. sop_scores = self.sop_classifier(pooled_output)
  685. total_loss = None
  686. if labels is not None and sentence_order_label is not None:
  687. loss_fct = CrossEntropyLoss()
  688. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  689. sentence_order_loss = loss_fct(sop_scores.view(-1, 2), sentence_order_label.view(-1))
  690. total_loss = masked_lm_loss + sentence_order_loss
  691. if not return_dict:
  692. output = (prediction_scores, sop_scores) + outputs[2:]
  693. return ((total_loss,) + output) if total_loss is not None else output
  694. return AlbertForPreTrainingOutput(
  695. loss=total_loss,
  696. prediction_logits=prediction_scores,
  697. sop_logits=sop_scores,
  698. hidden_states=outputs.hidden_states,
  699. attentions=outputs.attentions,
  700. )
  701. class AlbertMLMHead(nn.Module):
  702. def __init__(self, config: AlbertConfig):
  703. super().__init__()
  704. self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
  705. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  706. self.dense = nn.Linear(config.hidden_size, config.embedding_size)
  707. self.decoder = nn.Linear(config.embedding_size, config.vocab_size)
  708. self.activation = ACT2FN[config.hidden_act]
  709. self.decoder.bias = self.bias
  710. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  711. hidden_states = self.dense(hidden_states)
  712. hidden_states = self.activation(hidden_states)
  713. hidden_states = self.LayerNorm(hidden_states)
  714. hidden_states = self.decoder(hidden_states)
  715. prediction_scores = hidden_states
  716. return prediction_scores
  717. def _tie_weights(self) -> None:
  718. # For accelerate compatibility and to not break backward compatibility
  719. if self.decoder.bias.device.type == "meta":
  720. self.decoder.bias = self.bias
  721. else:
  722. # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
  723. self.bias = self.decoder.bias
  724. class AlbertSOPHead(nn.Module):
  725. def __init__(self, config: AlbertConfig):
  726. super().__init__()
  727. self.dropout = nn.Dropout(config.classifier_dropout_prob)
  728. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  729. def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
  730. dropout_pooled_output = self.dropout(pooled_output)
  731. logits = self.classifier(dropout_pooled_output)
  732. return logits
  733. @auto_docstring
  734. class AlbertForMaskedLM(AlbertPreTrainedModel):
  735. _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]
  736. def __init__(self, config):
  737. super().__init__(config)
  738. self.albert = AlbertModel(config, add_pooling_layer=False)
  739. self.predictions = AlbertMLMHead(config)
  740. # Initialize weights and apply final processing
  741. self.post_init()
  742. def get_output_embeddings(self) -> nn.Linear:
  743. return self.predictions.decoder
  744. def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
  745. self.predictions.decoder = new_embeddings
  746. self.predictions.bias = new_embeddings.bias
  747. def get_input_embeddings(self) -> nn.Embedding:
  748. return self.albert.embeddings.word_embeddings
  749. @auto_docstring
  750. def forward(
  751. self,
  752. input_ids: Optional[torch.LongTensor] = None,
  753. attention_mask: Optional[torch.FloatTensor] = None,
  754. token_type_ids: Optional[torch.LongTensor] = None,
  755. position_ids: Optional[torch.LongTensor] = None,
  756. head_mask: Optional[torch.FloatTensor] = None,
  757. inputs_embeds: Optional[torch.FloatTensor] = None,
  758. labels: Optional[torch.LongTensor] = None,
  759. output_attentions: Optional[bool] = None,
  760. output_hidden_states: Optional[bool] = None,
  761. return_dict: Optional[bool] = None,
  762. ) -> Union[MaskedLMOutput, tuple]:
  763. r"""
  764. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  765. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  766. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  767. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  768. Example:
  769. ```python
  770. >>> import torch
  771. >>> from transformers import AutoTokenizer, AlbertForMaskedLM
  772. >>> tokenizer = AutoTokenizer.from_pretrained("albert/albert-base-v2")
  773. >>> model = AlbertForMaskedLM.from_pretrained("albert/albert-base-v2")
  774. >>> # add mask_token
  775. >>> inputs = tokenizer("The capital of [MASK] is Paris.", return_tensors="pt")
  776. >>> with torch.no_grad():
  777. ... logits = model(**inputs).logits
  778. >>> # retrieve index of [MASK]
  779. >>> mask_token_index = (inputs.input_ids == tokenizer.mask_token_id)[0].nonzero(as_tuple=True)[0]
  780. >>> predicted_token_id = logits[0, mask_token_index].argmax(axis=-1)
  781. >>> tokenizer.decode(predicted_token_id)
  782. 'france'
  783. ```
  784. ```python
  785. >>> labels = tokenizer("The capital of France is Paris.", return_tensors="pt")["input_ids"]
  786. >>> labels = torch.where(inputs.input_ids == tokenizer.mask_token_id, labels, -100)
  787. >>> outputs = model(**inputs, labels=labels)
  788. >>> round(outputs.loss.item(), 2)
  789. 0.81
  790. ```
  791. """
  792. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  793. outputs = self.albert(
  794. input_ids=input_ids,
  795. attention_mask=attention_mask,
  796. token_type_ids=token_type_ids,
  797. position_ids=position_ids,
  798. head_mask=head_mask,
  799. inputs_embeds=inputs_embeds,
  800. output_attentions=output_attentions,
  801. output_hidden_states=output_hidden_states,
  802. return_dict=return_dict,
  803. )
  804. sequence_outputs = outputs[0]
  805. prediction_scores = self.predictions(sequence_outputs)
  806. masked_lm_loss = None
  807. if labels is not None:
  808. loss_fct = CrossEntropyLoss()
  809. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  810. if not return_dict:
  811. output = (prediction_scores,) + outputs[2:]
  812. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  813. return MaskedLMOutput(
  814. loss=masked_lm_loss,
  815. logits=prediction_scores,
  816. hidden_states=outputs.hidden_states,
  817. attentions=outputs.attentions,
  818. )
  819. @auto_docstring(
  820. custom_intro="""
  821. Albert Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  822. output) e.g. for GLUE tasks.
  823. """
  824. )
  825. class AlbertForSequenceClassification(AlbertPreTrainedModel):
  826. def __init__(self, config: AlbertConfig):
  827. super().__init__(config)
  828. self.num_labels = config.num_labels
  829. self.config = config
  830. self.albert = AlbertModel(config)
  831. self.dropout = nn.Dropout(config.classifier_dropout_prob)
  832. self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
  833. # Initialize weights and apply final processing
  834. self.post_init()
  835. @auto_docstring
  836. def forward(
  837. self,
  838. input_ids: Optional[torch.LongTensor] = None,
  839. attention_mask: Optional[torch.FloatTensor] = None,
  840. token_type_ids: Optional[torch.LongTensor] = None,
  841. position_ids: Optional[torch.LongTensor] = None,
  842. head_mask: Optional[torch.FloatTensor] = None,
  843. inputs_embeds: Optional[torch.FloatTensor] = None,
  844. labels: Optional[torch.LongTensor] = None,
  845. output_attentions: Optional[bool] = None,
  846. output_hidden_states: Optional[bool] = None,
  847. return_dict: Optional[bool] = None,
  848. ) -> Union[SequenceClassifierOutput, tuple]:
  849. r"""
  850. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  851. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  852. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  853. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  854. """
  855. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  856. outputs = self.albert(
  857. input_ids=input_ids,
  858. attention_mask=attention_mask,
  859. token_type_ids=token_type_ids,
  860. position_ids=position_ids,
  861. head_mask=head_mask,
  862. inputs_embeds=inputs_embeds,
  863. output_attentions=output_attentions,
  864. output_hidden_states=output_hidden_states,
  865. return_dict=return_dict,
  866. )
  867. pooled_output = outputs[1]
  868. pooled_output = self.dropout(pooled_output)
  869. logits = self.classifier(pooled_output)
  870. loss = None
  871. if labels is not None:
  872. if self.config.problem_type is None:
  873. if self.num_labels == 1:
  874. self.config.problem_type = "regression"
  875. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  876. self.config.problem_type = "single_label_classification"
  877. else:
  878. self.config.problem_type = "multi_label_classification"
  879. if self.config.problem_type == "regression":
  880. loss_fct = MSELoss()
  881. if self.num_labels == 1:
  882. loss = loss_fct(logits.squeeze(), labels.squeeze())
  883. else:
  884. loss = loss_fct(logits, labels)
  885. elif self.config.problem_type == "single_label_classification":
  886. loss_fct = CrossEntropyLoss()
  887. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  888. elif self.config.problem_type == "multi_label_classification":
  889. loss_fct = BCEWithLogitsLoss()
  890. loss = loss_fct(logits, labels)
  891. if not return_dict:
  892. output = (logits,) + outputs[2:]
  893. return ((loss,) + output) if loss is not None else output
  894. return SequenceClassifierOutput(
  895. loss=loss,
  896. logits=logits,
  897. hidden_states=outputs.hidden_states,
  898. attentions=outputs.attentions,
  899. )
  900. @auto_docstring
  901. class AlbertForTokenClassification(AlbertPreTrainedModel):
  902. def __init__(self, config: AlbertConfig):
  903. super().__init__(config)
  904. self.num_labels = config.num_labels
  905. self.albert = AlbertModel(config, add_pooling_layer=False)
  906. classifier_dropout_prob = (
  907. config.classifier_dropout_prob
  908. if config.classifier_dropout_prob is not None
  909. else config.hidden_dropout_prob
  910. )
  911. self.dropout = nn.Dropout(classifier_dropout_prob)
  912. self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
  913. # Initialize weights and apply final processing
  914. self.post_init()
  915. @auto_docstring
  916. def forward(
  917. self,
  918. input_ids: Optional[torch.LongTensor] = None,
  919. attention_mask: Optional[torch.FloatTensor] = None,
  920. token_type_ids: Optional[torch.LongTensor] = None,
  921. position_ids: Optional[torch.LongTensor] = None,
  922. head_mask: Optional[torch.FloatTensor] = None,
  923. inputs_embeds: Optional[torch.FloatTensor] = None,
  924. labels: Optional[torch.LongTensor] = None,
  925. output_attentions: Optional[bool] = None,
  926. output_hidden_states: Optional[bool] = None,
  927. return_dict: Optional[bool] = None,
  928. ) -> Union[TokenClassifierOutput, tuple]:
  929. r"""
  930. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  931. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  932. """
  933. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  934. outputs = self.albert(
  935. input_ids,
  936. attention_mask=attention_mask,
  937. token_type_ids=token_type_ids,
  938. position_ids=position_ids,
  939. head_mask=head_mask,
  940. inputs_embeds=inputs_embeds,
  941. output_attentions=output_attentions,
  942. output_hidden_states=output_hidden_states,
  943. return_dict=return_dict,
  944. )
  945. sequence_output = outputs[0]
  946. sequence_output = self.dropout(sequence_output)
  947. logits = self.classifier(sequence_output)
  948. loss = None
  949. if labels is not None:
  950. loss_fct = CrossEntropyLoss()
  951. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  952. if not return_dict:
  953. output = (logits,) + outputs[2:]
  954. return ((loss,) + output) if loss is not None else output
  955. return TokenClassifierOutput(
  956. loss=loss,
  957. logits=logits,
  958. hidden_states=outputs.hidden_states,
  959. attentions=outputs.attentions,
  960. )
  961. @auto_docstring
  962. class AlbertForQuestionAnswering(AlbertPreTrainedModel):
  963. def __init__(self, config: AlbertConfig):
  964. super().__init__(config)
  965. self.num_labels = config.num_labels
  966. self.albert = AlbertModel(config, add_pooling_layer=False)
  967. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  968. # Initialize weights and apply final processing
  969. self.post_init()
  970. @auto_docstring
  971. def forward(
  972. self,
  973. input_ids: Optional[torch.LongTensor] = None,
  974. attention_mask: Optional[torch.FloatTensor] = None,
  975. token_type_ids: Optional[torch.LongTensor] = None,
  976. position_ids: Optional[torch.LongTensor] = None,
  977. head_mask: Optional[torch.FloatTensor] = None,
  978. inputs_embeds: Optional[torch.FloatTensor] = None,
  979. start_positions: Optional[torch.LongTensor] = None,
  980. end_positions: Optional[torch.LongTensor] = None,
  981. output_attentions: Optional[bool] = None,
  982. output_hidden_states: Optional[bool] = None,
  983. return_dict: Optional[bool] = None,
  984. ) -> Union[AlbertForPreTrainingOutput, tuple]:
  985. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  986. outputs = self.albert(
  987. input_ids=input_ids,
  988. attention_mask=attention_mask,
  989. token_type_ids=token_type_ids,
  990. position_ids=position_ids,
  991. head_mask=head_mask,
  992. inputs_embeds=inputs_embeds,
  993. output_attentions=output_attentions,
  994. output_hidden_states=output_hidden_states,
  995. return_dict=return_dict,
  996. )
  997. sequence_output = outputs[0]
  998. logits: torch.Tensor = self.qa_outputs(sequence_output)
  999. start_logits, end_logits = logits.split(1, dim=-1)
  1000. start_logits = start_logits.squeeze(-1).contiguous()
  1001. end_logits = end_logits.squeeze(-1).contiguous()
  1002. total_loss = None
  1003. if start_positions is not None and end_positions is not None:
  1004. # If we are on multi-GPU, split add a dimension
  1005. if len(start_positions.size()) > 1:
  1006. start_positions = start_positions.squeeze(-1)
  1007. if len(end_positions.size()) > 1:
  1008. end_positions = end_positions.squeeze(-1)
  1009. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1010. ignored_index = start_logits.size(1)
  1011. start_positions = start_positions.clamp(0, ignored_index)
  1012. end_positions = end_positions.clamp(0, ignored_index)
  1013. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1014. start_loss = loss_fct(start_logits, start_positions)
  1015. end_loss = loss_fct(end_logits, end_positions)
  1016. total_loss = (start_loss + end_loss) / 2
  1017. if not return_dict:
  1018. output = (start_logits, end_logits) + outputs[2:]
  1019. return ((total_loss,) + output) if total_loss is not None else output
  1020. return QuestionAnsweringModelOutput(
  1021. loss=total_loss,
  1022. start_logits=start_logits,
  1023. end_logits=end_logits,
  1024. hidden_states=outputs.hidden_states,
  1025. attentions=outputs.attentions,
  1026. )
  1027. @auto_docstring
  1028. class AlbertForMultipleChoice(AlbertPreTrainedModel):
  1029. def __init__(self, config: AlbertConfig):
  1030. super().__init__(config)
  1031. self.albert = AlbertModel(config)
  1032. self.dropout = nn.Dropout(config.classifier_dropout_prob)
  1033. self.classifier = nn.Linear(config.hidden_size, 1)
  1034. # Initialize weights and apply final processing
  1035. self.post_init()
  1036. @auto_docstring
  1037. def forward(
  1038. self,
  1039. input_ids: Optional[torch.LongTensor] = None,
  1040. attention_mask: Optional[torch.FloatTensor] = None,
  1041. token_type_ids: Optional[torch.LongTensor] = None,
  1042. position_ids: Optional[torch.LongTensor] = None,
  1043. head_mask: Optional[torch.FloatTensor] = None,
  1044. inputs_embeds: Optional[torch.FloatTensor] = None,
  1045. labels: Optional[torch.LongTensor] = None,
  1046. output_attentions: Optional[bool] = None,
  1047. output_hidden_states: Optional[bool] = None,
  1048. return_dict: Optional[bool] = None,
  1049. ) -> Union[AlbertForPreTrainingOutput, tuple]:
  1050. r"""
  1051. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  1052. Indices of input sequence tokens in the vocabulary.
  1053. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.__call__`] and
  1054. [`PreTrainedTokenizer.encode`] for details.
  1055. [What are input IDs?](../glossary#input-ids)
  1056. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1057. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1058. 1]`:
  1059. - 0 corresponds to a *sentence A* token,
  1060. - 1 corresponds to a *sentence B* token.
  1061. [What are token type IDs?](../glossary#token-type-ids)
  1062. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1063. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1064. config.max_position_embeddings - 1]`.
  1065. [What are position IDs?](../glossary#position-ids)
  1066. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1067. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1068. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1069. model's internal embedding lookup matrix.
  1070. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1071. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1072. num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
  1073. *input_ids* above)
  1074. """
  1075. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1076. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1077. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1078. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1079. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1080. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1081. inputs_embeds = (
  1082. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1083. if inputs_embeds is not None
  1084. else None
  1085. )
  1086. outputs = self.albert(
  1087. input_ids,
  1088. attention_mask=attention_mask,
  1089. token_type_ids=token_type_ids,
  1090. position_ids=position_ids,
  1091. head_mask=head_mask,
  1092. inputs_embeds=inputs_embeds,
  1093. output_attentions=output_attentions,
  1094. output_hidden_states=output_hidden_states,
  1095. return_dict=return_dict,
  1096. )
  1097. pooled_output = outputs[1]
  1098. pooled_output = self.dropout(pooled_output)
  1099. logits: torch.Tensor = self.classifier(pooled_output)
  1100. reshaped_logits = logits.view(-1, num_choices)
  1101. loss = None
  1102. if labels is not None:
  1103. loss_fct = CrossEntropyLoss()
  1104. loss = loss_fct(reshaped_logits, labels)
  1105. if not return_dict:
  1106. output = (reshaped_logits,) + outputs[2:]
  1107. return ((loss,) + output) if loss is not None else output
  1108. return MultipleChoiceModelOutput(
  1109. loss=loss,
  1110. logits=reshaped_logits,
  1111. hidden_states=outputs.hidden_states,
  1112. attentions=outputs.attentions,
  1113. )
  1114. __all__ = [
  1115. "load_tf_weights_in_albert",
  1116. "AlbertPreTrainedModel",
  1117. "AlbertModel",
  1118. "AlbertForPreTraining",
  1119. "AlbertForMaskedLM",
  1120. "AlbertForSequenceClassification",
  1121. "AlbertForTokenClassification",
  1122. "AlbertForQuestionAnswering",
  1123. "AlbertForMultipleChoice",
  1124. ]