modeling_mobilebert.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482
  1. # MIT License
  2. #
  3. # Copyright (c) 2020 The Google AI Language Team Authors, The HuggingFace Inc. team and github/lonePatient
  4. #
  5. # Permission is hereby granted, free of charge, to any person obtaining a copy
  6. # of this software and associated documentation files (the "Software"), to deal
  7. # in the Software without restriction, including without limitation the rights
  8. # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
  9. # copies of the Software, and to permit persons to whom the Software is
  10. # furnished to do so, subject to the following conditions:
  11. #
  12. # The above copyright notice and this permission notice shall be included in all
  13. # copies or substantial portions of the Software.
  14. #
  15. # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
  16. # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
  17. # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
  18. # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
  19. # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  20. # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
  21. # SOFTWARE.
  22. import math
  23. import os
  24. import warnings
  25. from dataclasses import dataclass
  26. from typing import Optional, Union
  27. import torch
  28. from torch import nn
  29. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  30. from ...activations import ACT2FN
  31. from ...modeling_outputs import (
  32. BaseModelOutput,
  33. BaseModelOutputWithPooling,
  34. MaskedLMOutput,
  35. MultipleChoiceModelOutput,
  36. NextSentencePredictorOutput,
  37. QuestionAnsweringModelOutput,
  38. SequenceClassifierOutput,
  39. TokenClassifierOutput,
  40. )
  41. from ...modeling_utils import PreTrainedModel
  42. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  43. from ...utils import ModelOutput, auto_docstring, logging
  44. from .configuration_mobilebert import MobileBertConfig
  45. logger = logging.get_logger(__name__)
  46. def load_tf_weights_in_mobilebert(model, config, tf_checkpoint_path):
  47. """Load tf checkpoints in a pytorch model."""
  48. try:
  49. import re
  50. import numpy as np
  51. import tensorflow as tf
  52. except ImportError:
  53. logger.error(
  54. "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
  55. "https://www.tensorflow.org/install/ for installation instructions."
  56. )
  57. raise
  58. tf_path = os.path.abspath(tf_checkpoint_path)
  59. logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
  60. # Load weights from TF model
  61. init_vars = tf.train.list_variables(tf_path)
  62. names = []
  63. arrays = []
  64. for name, shape in init_vars:
  65. logger.info(f"Loading TF weight {name} with shape {shape}")
  66. array = tf.train.load_variable(tf_path, name)
  67. names.append(name)
  68. arrays.append(array)
  69. for name, array in zip(names, arrays):
  70. name = name.replace("ffn_layer", "ffn")
  71. name = name.replace("FakeLayerNorm", "LayerNorm")
  72. name = name.replace("extra_output_weights", "dense/kernel")
  73. name = name.replace("bert", "mobilebert")
  74. name = name.split("/")
  75. # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  76. # which are not required for using pretrained model
  77. if any(
  78. n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
  79. for n in name
  80. ):
  81. logger.info(f"Skipping {'/'.join(name)}")
  82. continue
  83. pointer = model
  84. for m_name in name:
  85. if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
  86. scope_names = re.split(r"_(\d+)", m_name)
  87. else:
  88. scope_names = [m_name]
  89. if scope_names[0] == "kernel" or scope_names[0] == "gamma":
  90. pointer = getattr(pointer, "weight")
  91. elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
  92. pointer = getattr(pointer, "bias")
  93. elif scope_names[0] == "output_weights":
  94. pointer = getattr(pointer, "weight")
  95. elif scope_names[0] == "squad":
  96. pointer = getattr(pointer, "classifier")
  97. else:
  98. try:
  99. pointer = getattr(pointer, scope_names[0])
  100. except AttributeError:
  101. logger.info(f"Skipping {'/'.join(name)}")
  102. continue
  103. if len(scope_names) >= 2:
  104. num = int(scope_names[1])
  105. pointer = pointer[num]
  106. if m_name[-11:] == "_embeddings":
  107. pointer = getattr(pointer, "weight")
  108. elif m_name == "kernel":
  109. array = np.transpose(array)
  110. try:
  111. assert pointer.shape == array.shape, (
  112. f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
  113. )
  114. except AssertionError as e:
  115. e.args += (pointer.shape, array.shape)
  116. raise
  117. logger.info(f"Initialize PyTorch weight {name}")
  118. pointer.data = torch.from_numpy(array)
  119. return model
  120. class NoNorm(nn.Module):
  121. def __init__(self, feat_size, eps=None):
  122. super().__init__()
  123. self.bias = nn.Parameter(torch.zeros(feat_size))
  124. self.weight = nn.Parameter(torch.ones(feat_size))
  125. def forward(self, input_tensor: torch.Tensor) -> torch.Tensor:
  126. return input_tensor * self.weight + self.bias
  127. NORM2FN = {"layer_norm": nn.LayerNorm, "no_norm": NoNorm}
  128. class MobileBertEmbeddings(nn.Module):
  129. """Construct the embeddings from word, position and token_type embeddings."""
  130. def __init__(self, config):
  131. super().__init__()
  132. self.trigram_input = config.trigram_input
  133. self.embedding_size = config.embedding_size
  134. self.hidden_size = config.hidden_size
  135. self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
  136. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  137. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  138. embed_dim_multiplier = 3 if self.trigram_input else 1
  139. embedded_input_size = self.embedding_size * embed_dim_multiplier
  140. self.embedding_transformation = nn.Linear(embedded_input_size, config.hidden_size)
  141. self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size)
  142. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  143. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  144. self.register_buffer(
  145. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  146. )
  147. def forward(
  148. self,
  149. input_ids: Optional[torch.LongTensor] = None,
  150. token_type_ids: Optional[torch.LongTensor] = None,
  151. position_ids: Optional[torch.LongTensor] = None,
  152. inputs_embeds: Optional[torch.FloatTensor] = None,
  153. ) -> torch.Tensor:
  154. if input_ids is not None:
  155. input_shape = input_ids.size()
  156. else:
  157. input_shape = inputs_embeds.size()[:-1]
  158. seq_length = input_shape[1]
  159. if position_ids is None:
  160. position_ids = self.position_ids[:, :seq_length]
  161. if token_type_ids is None:
  162. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  163. if inputs_embeds is None:
  164. inputs_embeds = self.word_embeddings(input_ids)
  165. if self.trigram_input:
  166. # From the paper MobileBERT: a Compact Task-Agnostic BERT for Resource-Limited
  167. # Devices (https://huggingface.co/papers/2004.02984)
  168. #
  169. # The embedding table in BERT models accounts for a substantial proportion of model size. To compress
  170. # the embedding layer, we reduce the embedding dimension to 128 in MobileBERT.
  171. # Then, we apply a 1D convolution with kernel size 3 on the raw token embedding to produce a 512
  172. # dimensional output.
  173. inputs_embeds = torch.cat(
  174. [
  175. nn.functional.pad(inputs_embeds[:, 1:], [0, 0, 0, 1, 0, 0], value=0.0),
  176. inputs_embeds,
  177. nn.functional.pad(inputs_embeds[:, :-1], [0, 0, 1, 0, 0, 0], value=0.0),
  178. ],
  179. dim=2,
  180. )
  181. if self.trigram_input or self.embedding_size != self.hidden_size:
  182. inputs_embeds = self.embedding_transformation(inputs_embeds)
  183. # Add positional embeddings and token type embeddings, then layer
  184. # normalize and perform dropout.
  185. position_embeddings = self.position_embeddings(position_ids)
  186. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  187. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  188. embeddings = self.LayerNorm(embeddings)
  189. embeddings = self.dropout(embeddings)
  190. return embeddings
  191. class MobileBertSelfAttention(nn.Module):
  192. def __init__(self, config):
  193. super().__init__()
  194. self.num_attention_heads = config.num_attention_heads
  195. self.attention_head_size = int(config.true_hidden_size / config.num_attention_heads)
  196. self.all_head_size = self.num_attention_heads * self.attention_head_size
  197. self.query = nn.Linear(config.true_hidden_size, self.all_head_size)
  198. self.key = nn.Linear(config.true_hidden_size, self.all_head_size)
  199. self.value = nn.Linear(
  200. config.true_hidden_size if config.use_bottleneck_attention else config.hidden_size, self.all_head_size
  201. )
  202. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  203. def forward(
  204. self,
  205. query_tensor: torch.Tensor,
  206. key_tensor: torch.Tensor,
  207. value_tensor: torch.Tensor,
  208. attention_mask: Optional[torch.FloatTensor] = None,
  209. head_mask: Optional[torch.FloatTensor] = None,
  210. output_attentions: Optional[bool] = None,
  211. ) -> tuple[torch.Tensor]:
  212. batch_size, seq_length, _ = query_tensor.shape
  213. query_layer = (
  214. self.query(query_tensor)
  215. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  216. .transpose(1, 2)
  217. )
  218. key_layer = (
  219. self.key(key_tensor)
  220. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  221. .transpose(1, 2)
  222. )
  223. value_layer = (
  224. self.value(value_tensor)
  225. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  226. .transpose(1, 2)
  227. )
  228. # Take the dot product between "query" and "key" to get the raw attention scores.
  229. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  230. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  231. if attention_mask is not None:
  232. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  233. attention_scores = attention_scores + attention_mask
  234. # Normalize the attention scores to probabilities.
  235. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  236. # This is actually dropping out entire tokens to attend to, which might
  237. # seem a bit unusual, but is taken from the original Transformer paper.
  238. attention_probs = self.dropout(attention_probs)
  239. # Mask heads if we want to
  240. if head_mask is not None:
  241. attention_probs = attention_probs * head_mask
  242. context_layer = torch.matmul(attention_probs, value_layer)
  243. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  244. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  245. context_layer = context_layer.view(new_context_layer_shape)
  246. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  247. return outputs
  248. class MobileBertSelfOutput(nn.Module):
  249. def __init__(self, config):
  250. super().__init__()
  251. self.use_bottleneck = config.use_bottleneck
  252. self.dense = nn.Linear(config.true_hidden_size, config.true_hidden_size)
  253. self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps)
  254. if not self.use_bottleneck:
  255. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  256. def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
  257. layer_outputs = self.dense(hidden_states)
  258. if not self.use_bottleneck:
  259. layer_outputs = self.dropout(layer_outputs)
  260. layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
  261. return layer_outputs
  262. class MobileBertAttention(nn.Module):
  263. def __init__(self, config):
  264. super().__init__()
  265. self.self = MobileBertSelfAttention(config)
  266. self.output = MobileBertSelfOutput(config)
  267. self.pruned_heads = set()
  268. def prune_heads(self, heads):
  269. if len(heads) == 0:
  270. return
  271. heads, index = find_pruneable_heads_and_indices(
  272. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  273. )
  274. # Prune linear layers
  275. self.self.query = prune_linear_layer(self.self.query, index)
  276. self.self.key = prune_linear_layer(self.self.key, index)
  277. self.self.value = prune_linear_layer(self.self.value, index)
  278. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  279. # Update hyper params and store pruned heads
  280. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  281. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  282. self.pruned_heads = self.pruned_heads.union(heads)
  283. def forward(
  284. self,
  285. query_tensor: torch.Tensor,
  286. key_tensor: torch.Tensor,
  287. value_tensor: torch.Tensor,
  288. layer_input: torch.Tensor,
  289. attention_mask: Optional[torch.FloatTensor] = None,
  290. head_mask: Optional[torch.FloatTensor] = None,
  291. output_attentions: Optional[bool] = None,
  292. ) -> tuple[torch.Tensor]:
  293. self_outputs = self.self(
  294. query_tensor,
  295. key_tensor,
  296. value_tensor,
  297. attention_mask,
  298. head_mask,
  299. output_attentions,
  300. )
  301. # Run a linear projection of `hidden_size` then add a residual
  302. # with `layer_input`.
  303. attention_output = self.output(self_outputs[0], layer_input)
  304. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  305. return outputs
  306. class MobileBertIntermediate(nn.Module):
  307. def __init__(self, config):
  308. super().__init__()
  309. self.dense = nn.Linear(config.true_hidden_size, config.intermediate_size)
  310. if isinstance(config.hidden_act, str):
  311. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  312. else:
  313. self.intermediate_act_fn = config.hidden_act
  314. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  315. hidden_states = self.dense(hidden_states)
  316. hidden_states = self.intermediate_act_fn(hidden_states)
  317. return hidden_states
  318. class OutputBottleneck(nn.Module):
  319. def __init__(self, config):
  320. super().__init__()
  321. self.dense = nn.Linear(config.true_hidden_size, config.hidden_size)
  322. self.LayerNorm = NORM2FN[config.normalization_type](config.hidden_size, eps=config.layer_norm_eps)
  323. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  324. def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
  325. layer_outputs = self.dense(hidden_states)
  326. layer_outputs = self.dropout(layer_outputs)
  327. layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
  328. return layer_outputs
  329. class MobileBertOutput(nn.Module):
  330. def __init__(self, config):
  331. super().__init__()
  332. self.use_bottleneck = config.use_bottleneck
  333. self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size)
  334. self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size)
  335. if not self.use_bottleneck:
  336. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  337. else:
  338. self.bottleneck = OutputBottleneck(config)
  339. def forward(
  340. self, intermediate_states: torch.Tensor, residual_tensor_1: torch.Tensor, residual_tensor_2: torch.Tensor
  341. ) -> torch.Tensor:
  342. layer_output = self.dense(intermediate_states)
  343. if not self.use_bottleneck:
  344. layer_output = self.dropout(layer_output)
  345. layer_output = self.LayerNorm(layer_output + residual_tensor_1)
  346. else:
  347. layer_output = self.LayerNorm(layer_output + residual_tensor_1)
  348. layer_output = self.bottleneck(layer_output, residual_tensor_2)
  349. return layer_output
  350. class BottleneckLayer(nn.Module):
  351. def __init__(self, config):
  352. super().__init__()
  353. self.dense = nn.Linear(config.hidden_size, config.intra_bottleneck_size)
  354. self.LayerNorm = NORM2FN[config.normalization_type](config.intra_bottleneck_size, eps=config.layer_norm_eps)
  355. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  356. layer_input = self.dense(hidden_states)
  357. layer_input = self.LayerNorm(layer_input)
  358. return layer_input
  359. class Bottleneck(nn.Module):
  360. def __init__(self, config):
  361. super().__init__()
  362. self.key_query_shared_bottleneck = config.key_query_shared_bottleneck
  363. self.use_bottleneck_attention = config.use_bottleneck_attention
  364. self.input = BottleneckLayer(config)
  365. if self.key_query_shared_bottleneck:
  366. self.attention = BottleneckLayer(config)
  367. def forward(self, hidden_states: torch.Tensor) -> tuple[torch.Tensor]:
  368. # This method can return three different tuples of values. These different values make use of bottlenecks,
  369. # which are linear layers used to project the hidden states to a lower-dimensional vector, reducing memory
  370. # usage. These linear layer have weights that are learned during training.
  371. #
  372. # If `config.use_bottleneck_attention`, it will return the result of the bottleneck layer four times for the
  373. # key, query, value, and "layer input" to be used by the attention layer.
  374. # This bottleneck is used to project the hidden. This last layer input will be used as a residual tensor
  375. # in the attention self output, after the attention scores have been computed.
  376. #
  377. # If not `config.use_bottleneck_attention` and `config.key_query_shared_bottleneck`, this will return
  378. # four values, three of which have been passed through a bottleneck: the query and key, passed through the same
  379. # bottleneck, and the residual layer to be applied in the attention self output, through another bottleneck.
  380. #
  381. # Finally, in the last case, the values for the query, key and values are the hidden states without bottleneck,
  382. # and the residual layer will be this value passed through a bottleneck.
  383. bottlenecked_hidden_states = self.input(hidden_states)
  384. if self.use_bottleneck_attention:
  385. return (bottlenecked_hidden_states,) * 4
  386. elif self.key_query_shared_bottleneck:
  387. shared_attention_input = self.attention(hidden_states)
  388. return (shared_attention_input, shared_attention_input, hidden_states, bottlenecked_hidden_states)
  389. else:
  390. return (hidden_states, hidden_states, hidden_states, bottlenecked_hidden_states)
  391. class FFNOutput(nn.Module):
  392. def __init__(self, config):
  393. super().__init__()
  394. self.dense = nn.Linear(config.intermediate_size, config.true_hidden_size)
  395. self.LayerNorm = NORM2FN[config.normalization_type](config.true_hidden_size, eps=config.layer_norm_eps)
  396. def forward(self, hidden_states: torch.Tensor, residual_tensor: torch.Tensor) -> torch.Tensor:
  397. layer_outputs = self.dense(hidden_states)
  398. layer_outputs = self.LayerNorm(layer_outputs + residual_tensor)
  399. return layer_outputs
  400. class FFNLayer(nn.Module):
  401. def __init__(self, config):
  402. super().__init__()
  403. self.intermediate = MobileBertIntermediate(config)
  404. self.output = FFNOutput(config)
  405. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  406. intermediate_output = self.intermediate(hidden_states)
  407. layer_outputs = self.output(intermediate_output, hidden_states)
  408. return layer_outputs
  409. class MobileBertLayer(nn.Module):
  410. def __init__(self, config):
  411. super().__init__()
  412. self.use_bottleneck = config.use_bottleneck
  413. self.num_feedforward_networks = config.num_feedforward_networks
  414. self.attention = MobileBertAttention(config)
  415. self.intermediate = MobileBertIntermediate(config)
  416. self.output = MobileBertOutput(config)
  417. if self.use_bottleneck:
  418. self.bottleneck = Bottleneck(config)
  419. if config.num_feedforward_networks > 1:
  420. self.ffn = nn.ModuleList([FFNLayer(config) for _ in range(config.num_feedforward_networks - 1)])
  421. def forward(
  422. self,
  423. hidden_states: torch.Tensor,
  424. attention_mask: Optional[torch.FloatTensor] = None,
  425. head_mask: Optional[torch.FloatTensor] = None,
  426. output_attentions: Optional[bool] = None,
  427. ) -> tuple[torch.Tensor]:
  428. if self.use_bottleneck:
  429. query_tensor, key_tensor, value_tensor, layer_input = self.bottleneck(hidden_states)
  430. else:
  431. query_tensor, key_tensor, value_tensor, layer_input = [hidden_states] * 4
  432. self_attention_outputs = self.attention(
  433. query_tensor,
  434. key_tensor,
  435. value_tensor,
  436. layer_input,
  437. attention_mask,
  438. head_mask,
  439. output_attentions=output_attentions,
  440. )
  441. attention_output = self_attention_outputs[0]
  442. s = (attention_output,)
  443. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  444. if self.num_feedforward_networks != 1:
  445. for i, ffn_module in enumerate(self.ffn):
  446. attention_output = ffn_module(attention_output)
  447. s += (attention_output,)
  448. intermediate_output = self.intermediate(attention_output)
  449. layer_output = self.output(intermediate_output, attention_output, hidden_states)
  450. outputs = (
  451. (layer_output,)
  452. + outputs
  453. + (
  454. torch.tensor(1000),
  455. query_tensor,
  456. key_tensor,
  457. value_tensor,
  458. layer_input,
  459. attention_output,
  460. intermediate_output,
  461. )
  462. + s
  463. )
  464. return outputs
  465. class MobileBertEncoder(nn.Module):
  466. def __init__(self, config):
  467. super().__init__()
  468. self.layer = nn.ModuleList([MobileBertLayer(config) for _ in range(config.num_hidden_layers)])
  469. def forward(
  470. self,
  471. hidden_states: torch.Tensor,
  472. attention_mask: Optional[torch.FloatTensor] = None,
  473. head_mask: Optional[torch.FloatTensor] = None,
  474. output_attentions: Optional[bool] = False,
  475. output_hidden_states: Optional[bool] = False,
  476. return_dict: Optional[bool] = True,
  477. ) -> Union[tuple, BaseModelOutput]:
  478. all_hidden_states = () if output_hidden_states else None
  479. all_attentions = () if output_attentions else None
  480. for i, layer_module in enumerate(self.layer):
  481. if output_hidden_states:
  482. all_hidden_states = all_hidden_states + (hidden_states,)
  483. layer_outputs = layer_module(
  484. hidden_states,
  485. attention_mask,
  486. head_mask[i],
  487. output_attentions,
  488. )
  489. hidden_states = layer_outputs[0]
  490. if output_attentions:
  491. all_attentions = all_attentions + (layer_outputs[1],)
  492. # Add last layer
  493. if output_hidden_states:
  494. all_hidden_states = all_hidden_states + (hidden_states,)
  495. if not return_dict:
  496. return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
  497. return BaseModelOutput(
  498. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
  499. )
  500. class MobileBertPooler(nn.Module):
  501. def __init__(self, config):
  502. super().__init__()
  503. self.do_activate = config.classifier_activation
  504. if self.do_activate:
  505. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  506. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  507. # We "pool" the model by simply taking the hidden state corresponding
  508. # to the first token.
  509. first_token_tensor = hidden_states[:, 0]
  510. if not self.do_activate:
  511. return first_token_tensor
  512. else:
  513. pooled_output = self.dense(first_token_tensor)
  514. pooled_output = torch.tanh(pooled_output)
  515. return pooled_output
  516. class MobileBertPredictionHeadTransform(nn.Module):
  517. def __init__(self, config):
  518. super().__init__()
  519. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  520. if isinstance(config.hidden_act, str):
  521. self.transform_act_fn = ACT2FN[config.hidden_act]
  522. else:
  523. self.transform_act_fn = config.hidden_act
  524. self.LayerNorm = NORM2FN["layer_norm"](config.hidden_size, eps=config.layer_norm_eps)
  525. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  526. hidden_states = self.dense(hidden_states)
  527. hidden_states = self.transform_act_fn(hidden_states)
  528. hidden_states = self.LayerNorm(hidden_states)
  529. return hidden_states
  530. class MobileBertLMPredictionHead(nn.Module):
  531. def __init__(self, config):
  532. super().__init__()
  533. self.transform = MobileBertPredictionHeadTransform(config)
  534. # The output weights are the same as the input embeddings, but there is
  535. # an output-only bias for each token.
  536. self.dense = nn.Linear(config.vocab_size, config.hidden_size - config.embedding_size, bias=False)
  537. self.decoder = nn.Linear(config.embedding_size, config.vocab_size, bias=False)
  538. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  539. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  540. self.decoder.bias = self.bias
  541. def _tie_weights(self) -> None:
  542. self.decoder.bias = self.bias
  543. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  544. hidden_states = self.transform(hidden_states)
  545. hidden_states = hidden_states.matmul(torch.cat([self.decoder.weight.t(), self.dense.weight], dim=0))
  546. hidden_states += self.decoder.bias
  547. return hidden_states
  548. class MobileBertOnlyMLMHead(nn.Module):
  549. def __init__(self, config):
  550. super().__init__()
  551. self.predictions = MobileBertLMPredictionHead(config)
  552. def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
  553. prediction_scores = self.predictions(sequence_output)
  554. return prediction_scores
  555. class MobileBertPreTrainingHeads(nn.Module):
  556. def __init__(self, config):
  557. super().__init__()
  558. self.predictions = MobileBertLMPredictionHead(config)
  559. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  560. def forward(self, sequence_output: torch.Tensor, pooled_output: torch.Tensor) -> tuple[torch.Tensor]:
  561. prediction_scores = self.predictions(sequence_output)
  562. seq_relationship_score = self.seq_relationship(pooled_output)
  563. return prediction_scores, seq_relationship_score
  564. @auto_docstring
  565. class MobileBertPreTrainedModel(PreTrainedModel):
  566. config: MobileBertConfig
  567. load_tf_weights = load_tf_weights_in_mobilebert
  568. base_model_prefix = "mobilebert"
  569. def _init_weights(self, module):
  570. """Initialize the weights"""
  571. if isinstance(module, nn.Linear):
  572. # Slightly different from the TF version which uses truncated_normal for initialization
  573. # cf https://github.com/pytorch/pytorch/pull/5617
  574. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  575. if module.bias is not None:
  576. module.bias.data.zero_()
  577. elif isinstance(module, nn.Embedding):
  578. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  579. if module.padding_idx is not None:
  580. module.weight.data[module.padding_idx].zero_()
  581. elif isinstance(module, (nn.LayerNorm, NoNorm)):
  582. module.bias.data.zero_()
  583. module.weight.data.fill_(1.0)
  584. elif isinstance(module, MobileBertLMPredictionHead):
  585. module.bias.data.zero_()
  586. @dataclass
  587. @auto_docstring(
  588. custom_intro="""
  589. Output type of [`MobileBertForPreTraining`].
  590. """
  591. )
  592. class MobileBertForPreTrainingOutput(ModelOutput):
  593. r"""
  594. loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
  595. Total loss as the sum of the masked language modeling loss and the next sequence prediction
  596. (classification) loss.
  597. prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  598. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  599. seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
  600. Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
  601. before SoftMax).
  602. """
  603. loss: Optional[torch.FloatTensor] = None
  604. prediction_logits: Optional[torch.FloatTensor] = None
  605. seq_relationship_logits: Optional[torch.FloatTensor] = None
  606. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  607. attentions: Optional[tuple[torch.FloatTensor]] = None
  608. @auto_docstring
  609. class MobileBertModel(MobileBertPreTrainedModel):
  610. """
  611. https://huggingface.co/papers/2004.02984
  612. """
  613. def __init__(self, config, add_pooling_layer=True):
  614. r"""
  615. add_pooling_layer (bool, *optional*, defaults to `True`):
  616. Whether to add a pooling layer
  617. """
  618. super().__init__(config)
  619. self.config = config
  620. self.embeddings = MobileBertEmbeddings(config)
  621. self.encoder = MobileBertEncoder(config)
  622. self.pooler = MobileBertPooler(config) if add_pooling_layer else None
  623. # Initialize weights and apply final processing
  624. self.post_init()
  625. def get_input_embeddings(self):
  626. return self.embeddings.word_embeddings
  627. def set_input_embeddings(self, value):
  628. self.embeddings.word_embeddings = value
  629. def _prune_heads(self, heads_to_prune):
  630. """
  631. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  632. class PreTrainedModel
  633. """
  634. for layer, heads in heads_to_prune.items():
  635. self.encoder.layer[layer].attention.prune_heads(heads)
  636. @auto_docstring
  637. def forward(
  638. self,
  639. input_ids: Optional[torch.LongTensor] = None,
  640. attention_mask: Optional[torch.FloatTensor] = None,
  641. token_type_ids: Optional[torch.LongTensor] = None,
  642. position_ids: Optional[torch.LongTensor] = None,
  643. head_mask: Optional[torch.FloatTensor] = None,
  644. inputs_embeds: Optional[torch.FloatTensor] = None,
  645. output_hidden_states: Optional[bool] = None,
  646. output_attentions: Optional[bool] = None,
  647. return_dict: Optional[bool] = None,
  648. ) -> Union[tuple, BaseModelOutputWithPooling]:
  649. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  650. output_hidden_states = (
  651. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  652. )
  653. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  654. if input_ids is not None and inputs_embeds is not None:
  655. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  656. elif input_ids is not None:
  657. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  658. input_shape = input_ids.size()
  659. elif inputs_embeds is not None:
  660. input_shape = inputs_embeds.size()[:-1]
  661. else:
  662. raise ValueError("You have to specify either input_ids or inputs_embeds")
  663. device = input_ids.device if input_ids is not None else inputs_embeds.device
  664. if attention_mask is None:
  665. attention_mask = torch.ones(input_shape, device=device)
  666. if token_type_ids is None:
  667. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  668. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  669. # ourselves in which case we just need to make it broadcastable to all heads.
  670. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  671. # Prepare head mask if needed
  672. # 1.0 in head_mask indicate we keep the head
  673. # attention_probs has shape bsz x n_heads x N x N
  674. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  675. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  676. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  677. embedding_output = self.embeddings(
  678. input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  679. )
  680. encoder_outputs = self.encoder(
  681. embedding_output,
  682. attention_mask=extended_attention_mask,
  683. head_mask=head_mask,
  684. output_attentions=output_attentions,
  685. output_hidden_states=output_hidden_states,
  686. return_dict=return_dict,
  687. )
  688. sequence_output = encoder_outputs[0]
  689. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  690. if not return_dict:
  691. return (sequence_output, pooled_output) + encoder_outputs[1:]
  692. return BaseModelOutputWithPooling(
  693. last_hidden_state=sequence_output,
  694. pooler_output=pooled_output,
  695. hidden_states=encoder_outputs.hidden_states,
  696. attentions=encoder_outputs.attentions,
  697. )
  698. @auto_docstring(
  699. custom_intro="""
  700. MobileBert Model with two heads on top as done during the pretraining: a `masked language modeling` head and a
  701. `next sentence prediction (classification)` head.
  702. """
  703. )
  704. class MobileBertForPreTraining(MobileBertPreTrainedModel):
  705. _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
  706. def __init__(self, config):
  707. super().__init__(config)
  708. self.mobilebert = MobileBertModel(config)
  709. self.cls = MobileBertPreTrainingHeads(config)
  710. # Initialize weights and apply final processing
  711. self.post_init()
  712. def get_output_embeddings(self):
  713. return self.cls.predictions.decoder
  714. def set_output_embeddings(self, new_embeddings):
  715. self.cls.predictions.decoder = new_embeddings
  716. self.cls.predictions.bias = new_embeddings.bias
  717. def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
  718. # resize dense output embedings at first
  719. self.cls.predictions.dense = self._get_resized_lm_head(
  720. self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
  721. )
  722. return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
  723. @auto_docstring
  724. def forward(
  725. self,
  726. input_ids: Optional[torch.LongTensor] = None,
  727. attention_mask: Optional[torch.FloatTensor] = None,
  728. token_type_ids: Optional[torch.LongTensor] = None,
  729. position_ids: Optional[torch.LongTensor] = None,
  730. head_mask: Optional[torch.FloatTensor] = None,
  731. inputs_embeds: Optional[torch.FloatTensor] = None,
  732. labels: Optional[torch.LongTensor] = None,
  733. next_sentence_label: Optional[torch.LongTensor] = None,
  734. output_attentions: Optional[torch.FloatTensor] = None,
  735. output_hidden_states: Optional[torch.FloatTensor] = None,
  736. return_dict: Optional[torch.FloatTensor] = None,
  737. ) -> Union[tuple, MobileBertForPreTrainingOutput]:
  738. r"""
  739. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  740. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  741. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  742. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  743. next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  744. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
  745. (see `input_ids` docstring) Indices should be in `[0, 1]`:
  746. - 0 indicates sequence B is a continuation of sequence A,
  747. - 1 indicates sequence B is a random sequence.
  748. Examples:
  749. ```python
  750. >>> from transformers import AutoTokenizer, MobileBertForPreTraining
  751. >>> import torch
  752. >>> tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
  753. >>> model = MobileBertForPreTraining.from_pretrained("google/mobilebert-uncased")
  754. >>> input_ids = torch.tensor(tokenizer.encode("Hello, my dog is cute", add_special_tokens=True)).unsqueeze(0)
  755. >>> # Batch size 1
  756. >>> outputs = model(input_ids)
  757. >>> prediction_logits = outputs.prediction_logits
  758. >>> seq_relationship_logits = outputs.seq_relationship_logits
  759. ```"""
  760. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  761. outputs = self.mobilebert(
  762. input_ids,
  763. attention_mask=attention_mask,
  764. token_type_ids=token_type_ids,
  765. position_ids=position_ids,
  766. head_mask=head_mask,
  767. inputs_embeds=inputs_embeds,
  768. output_attentions=output_attentions,
  769. output_hidden_states=output_hidden_states,
  770. return_dict=return_dict,
  771. )
  772. sequence_output, pooled_output = outputs[:2]
  773. prediction_scores, seq_relationship_score = self.cls(sequence_output, pooled_output)
  774. total_loss = None
  775. if labels is not None and next_sentence_label is not None:
  776. loss_fct = CrossEntropyLoss()
  777. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  778. next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), next_sentence_label.view(-1))
  779. total_loss = masked_lm_loss + next_sentence_loss
  780. if not return_dict:
  781. output = (prediction_scores, seq_relationship_score) + outputs[2:]
  782. return ((total_loss,) + output) if total_loss is not None else output
  783. return MobileBertForPreTrainingOutput(
  784. loss=total_loss,
  785. prediction_logits=prediction_scores,
  786. seq_relationship_logits=seq_relationship_score,
  787. hidden_states=outputs.hidden_states,
  788. attentions=outputs.attentions,
  789. )
  790. @auto_docstring
  791. class MobileBertForMaskedLM(MobileBertPreTrainedModel):
  792. _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
  793. def __init__(self, config):
  794. super().__init__(config)
  795. self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
  796. self.cls = MobileBertOnlyMLMHead(config)
  797. self.config = config
  798. # Initialize weights and apply final processing
  799. self.post_init()
  800. def get_output_embeddings(self):
  801. return self.cls.predictions.decoder
  802. def set_output_embeddings(self, new_embeddings):
  803. self.cls.predictions.decoder = new_embeddings
  804. self.cls.predictions.bias = new_embeddings.bias
  805. def resize_token_embeddings(self, new_num_tokens: Optional[int] = None) -> nn.Embedding:
  806. # resize dense output embedings at first
  807. self.cls.predictions.dense = self._get_resized_lm_head(
  808. self.cls.predictions.dense, new_num_tokens=new_num_tokens, transposed=True
  809. )
  810. return super().resize_token_embeddings(new_num_tokens=new_num_tokens)
  811. @auto_docstring
  812. def forward(
  813. self,
  814. input_ids: Optional[torch.LongTensor] = None,
  815. attention_mask: Optional[torch.FloatTensor] = None,
  816. token_type_ids: Optional[torch.LongTensor] = None,
  817. position_ids: Optional[torch.LongTensor] = None,
  818. head_mask: Optional[torch.FloatTensor] = None,
  819. inputs_embeds: Optional[torch.FloatTensor] = None,
  820. labels: Optional[torch.LongTensor] = None,
  821. output_attentions: Optional[bool] = None,
  822. output_hidden_states: Optional[bool] = None,
  823. return_dict: Optional[bool] = None,
  824. ) -> Union[tuple, MaskedLMOutput]:
  825. r"""
  826. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  827. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  828. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  829. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  830. """
  831. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  832. outputs = self.mobilebert(
  833. input_ids,
  834. attention_mask=attention_mask,
  835. token_type_ids=token_type_ids,
  836. position_ids=position_ids,
  837. head_mask=head_mask,
  838. inputs_embeds=inputs_embeds,
  839. output_attentions=output_attentions,
  840. output_hidden_states=output_hidden_states,
  841. return_dict=return_dict,
  842. )
  843. sequence_output = outputs[0]
  844. prediction_scores = self.cls(sequence_output)
  845. masked_lm_loss = None
  846. if labels is not None:
  847. loss_fct = CrossEntropyLoss() # -100 index = padding token
  848. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  849. if not return_dict:
  850. output = (prediction_scores,) + outputs[2:]
  851. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  852. return MaskedLMOutput(
  853. loss=masked_lm_loss,
  854. logits=prediction_scores,
  855. hidden_states=outputs.hidden_states,
  856. attentions=outputs.attentions,
  857. )
  858. class MobileBertOnlyNSPHead(nn.Module):
  859. def __init__(self, config):
  860. super().__init__()
  861. self.seq_relationship = nn.Linear(config.hidden_size, 2)
  862. def forward(self, pooled_output: torch.Tensor) -> torch.Tensor:
  863. seq_relationship_score = self.seq_relationship(pooled_output)
  864. return seq_relationship_score
  865. @auto_docstring(
  866. custom_intro="""
  867. MobileBert Model with a `next sentence prediction (classification)` head on top.
  868. """
  869. )
  870. class MobileBertForNextSentencePrediction(MobileBertPreTrainedModel):
  871. def __init__(self, config):
  872. super().__init__(config)
  873. self.mobilebert = MobileBertModel(config)
  874. self.cls = MobileBertOnlyNSPHead(config)
  875. # Initialize weights and apply final processing
  876. self.post_init()
  877. @auto_docstring
  878. def forward(
  879. self,
  880. input_ids: Optional[torch.LongTensor] = None,
  881. attention_mask: Optional[torch.FloatTensor] = None,
  882. token_type_ids: Optional[torch.LongTensor] = None,
  883. position_ids: Optional[torch.LongTensor] = None,
  884. head_mask: Optional[torch.FloatTensor] = None,
  885. inputs_embeds: Optional[torch.FloatTensor] = None,
  886. labels: Optional[torch.LongTensor] = None,
  887. output_attentions: Optional[bool] = None,
  888. output_hidden_states: Optional[bool] = None,
  889. return_dict: Optional[bool] = None,
  890. **kwargs,
  891. ) -> Union[tuple, NextSentencePredictorOutput]:
  892. r"""
  893. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  894. Labels for computing the next sequence prediction (classification) loss. Input should be a sequence pair
  895. (see `input_ids` docstring) Indices should be in `[0, 1]`.
  896. - 0 indicates sequence B is a continuation of sequence A,
  897. - 1 indicates sequence B is a random sequence.
  898. Examples:
  899. ```python
  900. >>> from transformers import AutoTokenizer, MobileBertForNextSentencePrediction
  901. >>> import torch
  902. >>> tokenizer = AutoTokenizer.from_pretrained("google/mobilebert-uncased")
  903. >>> model = MobileBertForNextSentencePrediction.from_pretrained("google/mobilebert-uncased")
  904. >>> prompt = "In Italy, pizza served in formal settings, such as at a restaurant, is presented unsliced."
  905. >>> next_sentence = "The sky is blue due to the shorter wavelength of blue light."
  906. >>> encoding = tokenizer(prompt, next_sentence, return_tensors="pt")
  907. >>> outputs = model(**encoding, labels=torch.LongTensor([1]))
  908. >>> loss = outputs.loss
  909. >>> logits = outputs.logits
  910. ```"""
  911. if "next_sentence_label" in kwargs:
  912. warnings.warn(
  913. "The `next_sentence_label` argument is deprecated and will be removed in a future version, use"
  914. " `labels` instead.",
  915. FutureWarning,
  916. )
  917. labels = kwargs.pop("next_sentence_label")
  918. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  919. outputs = self.mobilebert(
  920. input_ids,
  921. attention_mask=attention_mask,
  922. token_type_ids=token_type_ids,
  923. position_ids=position_ids,
  924. head_mask=head_mask,
  925. inputs_embeds=inputs_embeds,
  926. output_attentions=output_attentions,
  927. output_hidden_states=output_hidden_states,
  928. return_dict=return_dict,
  929. )
  930. pooled_output = outputs[1]
  931. seq_relationship_score = self.cls(pooled_output)
  932. next_sentence_loss = None
  933. if labels is not None:
  934. loss_fct = CrossEntropyLoss()
  935. next_sentence_loss = loss_fct(seq_relationship_score.view(-1, 2), labels.view(-1))
  936. if not return_dict:
  937. output = (seq_relationship_score,) + outputs[2:]
  938. return ((next_sentence_loss,) + output) if next_sentence_loss is not None else output
  939. return NextSentencePredictorOutput(
  940. loss=next_sentence_loss,
  941. logits=seq_relationship_score,
  942. hidden_states=outputs.hidden_states,
  943. attentions=outputs.attentions,
  944. )
  945. @auto_docstring(
  946. custom_intro="""
  947. MobileBert Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  948. pooled output) e.g. for GLUE tasks.
  949. """
  950. )
  951. # Copied from transformers.models.bert.modeling_bert.BertForSequenceClassification with Bert->MobileBert all-casing
  952. class MobileBertForSequenceClassification(MobileBertPreTrainedModel):
  953. def __init__(self, config):
  954. super().__init__(config)
  955. self.num_labels = config.num_labels
  956. self.config = config
  957. self.mobilebert = MobileBertModel(config)
  958. classifier_dropout = (
  959. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  960. )
  961. self.dropout = nn.Dropout(classifier_dropout)
  962. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  963. # Initialize weights and apply final processing
  964. self.post_init()
  965. @auto_docstring
  966. def forward(
  967. self,
  968. input_ids: Optional[torch.Tensor] = None,
  969. attention_mask: Optional[torch.Tensor] = None,
  970. token_type_ids: Optional[torch.Tensor] = None,
  971. position_ids: Optional[torch.Tensor] = None,
  972. head_mask: Optional[torch.Tensor] = None,
  973. inputs_embeds: Optional[torch.Tensor] = None,
  974. labels: Optional[torch.Tensor] = None,
  975. output_attentions: Optional[bool] = None,
  976. output_hidden_states: Optional[bool] = None,
  977. return_dict: Optional[bool] = None,
  978. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  979. r"""
  980. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  981. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  982. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  983. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  984. """
  985. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  986. outputs = self.mobilebert(
  987. input_ids,
  988. attention_mask=attention_mask,
  989. token_type_ids=token_type_ids,
  990. position_ids=position_ids,
  991. head_mask=head_mask,
  992. inputs_embeds=inputs_embeds,
  993. output_attentions=output_attentions,
  994. output_hidden_states=output_hidden_states,
  995. return_dict=return_dict,
  996. )
  997. pooled_output = outputs[1]
  998. pooled_output = self.dropout(pooled_output)
  999. logits = self.classifier(pooled_output)
  1000. loss = None
  1001. if labels is not None:
  1002. if self.config.problem_type is None:
  1003. if self.num_labels == 1:
  1004. self.config.problem_type = "regression"
  1005. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1006. self.config.problem_type = "single_label_classification"
  1007. else:
  1008. self.config.problem_type = "multi_label_classification"
  1009. if self.config.problem_type == "regression":
  1010. loss_fct = MSELoss()
  1011. if self.num_labels == 1:
  1012. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1013. else:
  1014. loss = loss_fct(logits, labels)
  1015. elif self.config.problem_type == "single_label_classification":
  1016. loss_fct = CrossEntropyLoss()
  1017. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1018. elif self.config.problem_type == "multi_label_classification":
  1019. loss_fct = BCEWithLogitsLoss()
  1020. loss = loss_fct(logits, labels)
  1021. if not return_dict:
  1022. output = (logits,) + outputs[2:]
  1023. return ((loss,) + output) if loss is not None else output
  1024. return SequenceClassifierOutput(
  1025. loss=loss,
  1026. logits=logits,
  1027. hidden_states=outputs.hidden_states,
  1028. attentions=outputs.attentions,
  1029. )
  1030. @auto_docstring
  1031. # Copied from transformers.models.bert.modeling_bert.BertForQuestionAnswering with Bert->MobileBert all-casing
  1032. class MobileBertForQuestionAnswering(MobileBertPreTrainedModel):
  1033. def __init__(self, config):
  1034. super().__init__(config)
  1035. self.num_labels = config.num_labels
  1036. self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
  1037. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1038. # Initialize weights and apply final processing
  1039. self.post_init()
  1040. @auto_docstring
  1041. def forward(
  1042. self,
  1043. input_ids: Optional[torch.Tensor] = None,
  1044. attention_mask: Optional[torch.Tensor] = None,
  1045. token_type_ids: Optional[torch.Tensor] = None,
  1046. position_ids: Optional[torch.Tensor] = None,
  1047. head_mask: Optional[torch.Tensor] = None,
  1048. inputs_embeds: Optional[torch.Tensor] = None,
  1049. start_positions: Optional[torch.Tensor] = None,
  1050. end_positions: Optional[torch.Tensor] = None,
  1051. output_attentions: Optional[bool] = None,
  1052. output_hidden_states: Optional[bool] = None,
  1053. return_dict: Optional[bool] = None,
  1054. ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
  1055. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1056. outputs = self.mobilebert(
  1057. input_ids,
  1058. attention_mask=attention_mask,
  1059. token_type_ids=token_type_ids,
  1060. position_ids=position_ids,
  1061. head_mask=head_mask,
  1062. inputs_embeds=inputs_embeds,
  1063. output_attentions=output_attentions,
  1064. output_hidden_states=output_hidden_states,
  1065. return_dict=return_dict,
  1066. )
  1067. sequence_output = outputs[0]
  1068. logits = self.qa_outputs(sequence_output)
  1069. start_logits, end_logits = logits.split(1, dim=-1)
  1070. start_logits = start_logits.squeeze(-1).contiguous()
  1071. end_logits = end_logits.squeeze(-1).contiguous()
  1072. total_loss = None
  1073. if start_positions is not None and end_positions is not None:
  1074. # If we are on multi-GPU, split add a dimension
  1075. if len(start_positions.size()) > 1:
  1076. start_positions = start_positions.squeeze(-1)
  1077. if len(end_positions.size()) > 1:
  1078. end_positions = end_positions.squeeze(-1)
  1079. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1080. ignored_index = start_logits.size(1)
  1081. start_positions = start_positions.clamp(0, ignored_index)
  1082. end_positions = end_positions.clamp(0, ignored_index)
  1083. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1084. start_loss = loss_fct(start_logits, start_positions)
  1085. end_loss = loss_fct(end_logits, end_positions)
  1086. total_loss = (start_loss + end_loss) / 2
  1087. if not return_dict:
  1088. output = (start_logits, end_logits) + outputs[2:]
  1089. return ((total_loss,) + output) if total_loss is not None else output
  1090. return QuestionAnsweringModelOutput(
  1091. loss=total_loss,
  1092. start_logits=start_logits,
  1093. end_logits=end_logits,
  1094. hidden_states=outputs.hidden_states,
  1095. attentions=outputs.attentions,
  1096. )
  1097. @auto_docstring
  1098. # Copied from transformers.models.bert.modeling_bert.BertForMultipleChoice with Bert->MobileBert all-casing
  1099. class MobileBertForMultipleChoice(MobileBertPreTrainedModel):
  1100. def __init__(self, config):
  1101. super().__init__(config)
  1102. self.mobilebert = MobileBertModel(config)
  1103. classifier_dropout = (
  1104. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1105. )
  1106. self.dropout = nn.Dropout(classifier_dropout)
  1107. self.classifier = nn.Linear(config.hidden_size, 1)
  1108. # Initialize weights and apply final processing
  1109. self.post_init()
  1110. @auto_docstring
  1111. def forward(
  1112. self,
  1113. input_ids: Optional[torch.Tensor] = None,
  1114. attention_mask: Optional[torch.Tensor] = None,
  1115. token_type_ids: Optional[torch.Tensor] = None,
  1116. position_ids: Optional[torch.Tensor] = None,
  1117. head_mask: Optional[torch.Tensor] = None,
  1118. inputs_embeds: Optional[torch.Tensor] = None,
  1119. labels: Optional[torch.Tensor] = None,
  1120. output_attentions: Optional[bool] = None,
  1121. output_hidden_states: Optional[bool] = None,
  1122. return_dict: Optional[bool] = None,
  1123. ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
  1124. r"""
  1125. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  1126. Indices of input sequence tokens in the vocabulary.
  1127. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1128. [`PreTrainedTokenizer.__call__`] for details.
  1129. [What are input IDs?](../glossary#input-ids)
  1130. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1131. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1132. 1]`:
  1133. - 0 corresponds to a *sentence A* token,
  1134. - 1 corresponds to a *sentence B* token.
  1135. [What are token type IDs?](../glossary#token-type-ids)
  1136. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1137. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1138. config.max_position_embeddings - 1]`.
  1139. [What are position IDs?](../glossary#position-ids)
  1140. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1141. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1142. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1143. model's internal embedding lookup matrix.
  1144. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1145. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1146. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1147. `input_ids` above)
  1148. """
  1149. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1150. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1151. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1152. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1153. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1154. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1155. inputs_embeds = (
  1156. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1157. if inputs_embeds is not None
  1158. else None
  1159. )
  1160. outputs = self.mobilebert(
  1161. input_ids,
  1162. attention_mask=attention_mask,
  1163. token_type_ids=token_type_ids,
  1164. position_ids=position_ids,
  1165. head_mask=head_mask,
  1166. inputs_embeds=inputs_embeds,
  1167. output_attentions=output_attentions,
  1168. output_hidden_states=output_hidden_states,
  1169. return_dict=return_dict,
  1170. )
  1171. pooled_output = outputs[1]
  1172. pooled_output = self.dropout(pooled_output)
  1173. logits = self.classifier(pooled_output)
  1174. reshaped_logits = logits.view(-1, num_choices)
  1175. loss = None
  1176. if labels is not None:
  1177. loss_fct = CrossEntropyLoss()
  1178. loss = loss_fct(reshaped_logits, labels)
  1179. if not return_dict:
  1180. output = (reshaped_logits,) + outputs[2:]
  1181. return ((loss,) + output) if loss is not None else output
  1182. return MultipleChoiceModelOutput(
  1183. loss=loss,
  1184. logits=reshaped_logits,
  1185. hidden_states=outputs.hidden_states,
  1186. attentions=outputs.attentions,
  1187. )
  1188. @auto_docstring
  1189. # Copied from transformers.models.bert.modeling_bert.BertForTokenClassification with Bert->MobileBert all-casing
  1190. class MobileBertForTokenClassification(MobileBertPreTrainedModel):
  1191. def __init__(self, config):
  1192. super().__init__(config)
  1193. self.num_labels = config.num_labels
  1194. self.mobilebert = MobileBertModel(config, add_pooling_layer=False)
  1195. classifier_dropout = (
  1196. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1197. )
  1198. self.dropout = nn.Dropout(classifier_dropout)
  1199. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1200. # Initialize weights and apply final processing
  1201. self.post_init()
  1202. @auto_docstring
  1203. def forward(
  1204. self,
  1205. input_ids: Optional[torch.Tensor] = None,
  1206. attention_mask: Optional[torch.Tensor] = None,
  1207. token_type_ids: Optional[torch.Tensor] = None,
  1208. position_ids: Optional[torch.Tensor] = None,
  1209. head_mask: Optional[torch.Tensor] = None,
  1210. inputs_embeds: Optional[torch.Tensor] = None,
  1211. labels: Optional[torch.Tensor] = None,
  1212. output_attentions: Optional[bool] = None,
  1213. output_hidden_states: Optional[bool] = None,
  1214. return_dict: Optional[bool] = None,
  1215. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  1216. r"""
  1217. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1218. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1219. """
  1220. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1221. outputs = self.mobilebert(
  1222. input_ids,
  1223. attention_mask=attention_mask,
  1224. token_type_ids=token_type_ids,
  1225. position_ids=position_ids,
  1226. head_mask=head_mask,
  1227. inputs_embeds=inputs_embeds,
  1228. output_attentions=output_attentions,
  1229. output_hidden_states=output_hidden_states,
  1230. return_dict=return_dict,
  1231. )
  1232. sequence_output = outputs[0]
  1233. sequence_output = self.dropout(sequence_output)
  1234. logits = self.classifier(sequence_output)
  1235. loss = None
  1236. if labels is not None:
  1237. loss_fct = CrossEntropyLoss()
  1238. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1239. if not return_dict:
  1240. output = (logits,) + outputs[2:]
  1241. return ((loss,) + output) if loss is not None else output
  1242. return TokenClassifierOutput(
  1243. loss=loss,
  1244. logits=logits,
  1245. hidden_states=outputs.hidden_states,
  1246. attentions=outputs.attentions,
  1247. )
# Public API of this module: the names exported by `from <module> import *`.
# Kept as a plain literal list of strings. NOTE(review): transformers' lazy-import machinery likely
# reads this list statically from the source, so avoid computing it or adding comments inside the
# brackets — confirm before restructuring.
__all__ = [
    "MobileBertForMaskedLM",
    "MobileBertForMultipleChoice",
    "MobileBertForNextSentencePrediction",
    "MobileBertForPreTraining",
    "MobileBertForQuestionAnswering",
    "MobileBertForSequenceClassification",
    "MobileBertForTokenClassification",
    "MobileBertLayer",
    "MobileBertModel",
    "MobileBertPreTrainedModel",
    "load_tf_weights_in_mobilebert",
]