modeling_squeezebert.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962
  1. # coding=utf-8
  2. # Copyright 2020 The SqueezeBert authors and The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch SqueezeBert model."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ...activations import ACT2FN
  22. from ...modeling_outputs import (
  23. BaseModelOutput,
  24. BaseModelOutputWithPooling,
  25. MaskedLMOutput,
  26. MultipleChoiceModelOutput,
  27. QuestionAnsweringModelOutput,
  28. SequenceClassifierOutput,
  29. TokenClassifierOutput,
  30. )
  31. from ...modeling_utils import PreTrainedModel
  32. from ...utils import (
  33. auto_docstring,
  34. logging,
  35. )
  36. from .configuration_squeezebert import SqueezeBertConfig
# Module-level logger obtained from the library's own logging utilities (`...utils.logging`).
logger = logging.get_logger(__name__)
  38. class SqueezeBertEmbeddings(nn.Module):
  39. """Construct the embeddings from word, position and token_type embeddings."""
  40. def __init__(self, config):
  41. super().__init__()
  42. self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
  43. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
  44. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
  45. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  46. # any TensorFlow checkpoint file
  47. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  48. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  49. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  50. self.register_buffer(
  51. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  52. )
  53. def forward(self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None):
  54. if input_ids is not None:
  55. input_shape = input_ids.size()
  56. else:
  57. input_shape = inputs_embeds.size()[:-1]
  58. seq_length = input_shape[1]
  59. if position_ids is None:
  60. position_ids = self.position_ids[:, :seq_length]
  61. if token_type_ids is None:
  62. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  63. if inputs_embeds is None:
  64. inputs_embeds = self.word_embeddings(input_ids)
  65. position_embeddings = self.position_embeddings(position_ids)
  66. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  67. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  68. embeddings = self.LayerNorm(embeddings)
  69. embeddings = self.dropout(embeddings)
  70. return embeddings
  71. class MatMulWrapper(nn.Module):
  72. """
  73. Wrapper for torch.matmul(). This makes flop-counting easier to implement. Note that if you directly call
  74. torch.matmul() in your code, the flop counter will typically ignore the flops of the matmul.
  75. """
  76. def __init__(self):
  77. super().__init__()
  78. def forward(self, mat1, mat2):
  79. """
  80. :param inputs: two torch tensors :return: matmul of these tensors
  81. Here are the typical dimensions found in BERT (the B is optional) mat1.shape: [B, <optional extra dims>, M, K]
  82. mat2.shape: [B, <optional extra dims>, K, N] output shape: [B, <optional extra dims>, M, N]
  83. """
  84. return torch.matmul(mat1, mat2)
  85. class SqueezeBertLayerNorm(nn.LayerNorm):
  86. """
  87. This is a nn.LayerNorm subclass that accepts NCW data layout and performs normalization in the C dimension.
  88. N = batch C = channels W = sequence length
  89. """
  90. def __init__(self, hidden_size, eps=1e-12):
  91. nn.LayerNorm.__init__(self, normalized_shape=hidden_size, eps=eps) # instantiates self.{weight, bias, eps}
  92. def forward(self, x):
  93. x = x.permute(0, 2, 1)
  94. x = nn.LayerNorm.forward(self, x)
  95. return x.permute(0, 2, 1)
  96. class ConvDropoutLayerNorm(nn.Module):
  97. """
  98. ConvDropoutLayerNorm: Conv, Dropout, LayerNorm
  99. """
  100. def __init__(self, cin, cout, groups, dropout_prob):
  101. super().__init__()
  102. self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
  103. self.layernorm = SqueezeBertLayerNorm(cout)
  104. self.dropout = nn.Dropout(dropout_prob)
  105. def forward(self, hidden_states, input_tensor):
  106. x = self.conv1d(hidden_states)
  107. x = self.dropout(x)
  108. x = x + input_tensor
  109. x = self.layernorm(x)
  110. return x
  111. class ConvActivation(nn.Module):
  112. """
  113. ConvActivation: Conv, Activation
  114. """
  115. def __init__(self, cin, cout, groups, act):
  116. super().__init__()
  117. self.conv1d = nn.Conv1d(in_channels=cin, out_channels=cout, kernel_size=1, groups=groups)
  118. self.act = ACT2FN[act]
  119. def forward(self, x):
  120. output = self.conv1d(x)
  121. return self.act(output)
  122. class SqueezeBertSelfAttention(nn.Module):
  123. def __init__(self, config, cin, q_groups=1, k_groups=1, v_groups=1):
  124. """
  125. config = used for some things; ignored for others (work in progress...) cin = input channels = output channels
  126. groups = number of groups to use in conv1d layers
  127. """
  128. super().__init__()
  129. if cin % config.num_attention_heads != 0:
  130. raise ValueError(
  131. f"cin ({cin}) is not a multiple of the number of attention heads ({config.num_attention_heads})"
  132. )
  133. self.num_attention_heads = config.num_attention_heads
  134. self.attention_head_size = int(cin / config.num_attention_heads)
  135. self.all_head_size = self.num_attention_heads * self.attention_head_size
  136. self.query = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=q_groups)
  137. self.key = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=k_groups)
  138. self.value = nn.Conv1d(in_channels=cin, out_channels=cin, kernel_size=1, groups=v_groups)
  139. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  140. self.softmax = nn.Softmax(dim=-1)
  141. self.matmul_qk = MatMulWrapper()
  142. self.matmul_qkv = MatMulWrapper()
  143. def transpose_for_scores(self, x):
  144. """
  145. - input: [N, C, W]
  146. - output: [N, C1, W, C2] where C1 is the head index, and C2 is one head's contents
  147. """
  148. new_x_shape = (x.size()[0], self.num_attention_heads, self.attention_head_size, x.size()[-1]) # [N, C1, C2, W]
  149. x = x.view(*new_x_shape)
  150. return x.permute(0, 1, 3, 2) # [N, C1, C2, W] --> [N, C1, W, C2]
  151. def transpose_key_for_scores(self, x):
  152. """
  153. - input: [N, C, W]
  154. - output: [N, C1, C2, W] where C1 is the head index, and C2 is one head's contents
  155. """
  156. new_x_shape = (x.size()[0], self.num_attention_heads, self.attention_head_size, x.size()[-1]) # [N, C1, C2, W]
  157. x = x.view(*new_x_shape)
  158. # no `permute` needed
  159. return x
  160. def transpose_output(self, x):
  161. """
  162. - input: [N, C1, W, C2]
  163. - output: [N, C, W]
  164. """
  165. x = x.permute(0, 1, 3, 2).contiguous() # [N, C1, C2, W]
  166. new_x_shape = (x.size()[0], self.all_head_size, x.size()[3]) # [N, C, W]
  167. x = x.view(*new_x_shape)
  168. return x
  169. def forward(self, hidden_states, attention_mask, output_attentions):
  170. """
  171. expects hidden_states in [N, C, W] data layout.
  172. The attention_mask data layout is [N, W], and it does not need to be transposed.
  173. """
  174. mixed_query_layer = self.query(hidden_states)
  175. mixed_key_layer = self.key(hidden_states)
  176. mixed_value_layer = self.value(hidden_states)
  177. query_layer = self.transpose_for_scores(mixed_query_layer)
  178. key_layer = self.transpose_key_for_scores(mixed_key_layer)
  179. value_layer = self.transpose_for_scores(mixed_value_layer)
  180. # Take the dot product between "query" and "key" to get the raw attention scores.
  181. attention_score = self.matmul_qk(query_layer, key_layer)
  182. attention_score = attention_score / math.sqrt(self.attention_head_size)
  183. # Apply the attention mask is (precomputed for all layers in BertModel forward() function)
  184. attention_score = attention_score + attention_mask
  185. # Normalize the attention scores to probabilities.
  186. attention_probs = self.softmax(attention_score)
  187. # This is actually dropping out entire tokens to attend to, which might
  188. # seem a bit unusual, but is taken from the original Transformer paper.
  189. attention_probs = self.dropout(attention_probs)
  190. context_layer = self.matmul_qkv(attention_probs, value_layer)
  191. context_layer = self.transpose_output(context_layer)
  192. result = {"context_layer": context_layer}
  193. if output_attentions:
  194. result["attention_score"] = attention_score
  195. return result
  196. class SqueezeBertModule(nn.Module):
  197. def __init__(self, config):
  198. """
  199. - hidden_size = input chans = output chans for Q, K, V (they are all the same ... for now) = output chans for
  200. the module
  201. - intermediate_size = output chans for intermediate layer
  202. - groups = number of groups for all layers in the BertModule. (eventually we could change the interface to
  203. allow different groups for different layers)
  204. """
  205. super().__init__()
  206. c0 = config.hidden_size
  207. c1 = config.hidden_size
  208. c2 = config.intermediate_size
  209. c3 = config.hidden_size
  210. self.attention = SqueezeBertSelfAttention(
  211. config=config, cin=c0, q_groups=config.q_groups, k_groups=config.k_groups, v_groups=config.v_groups
  212. )
  213. self.post_attention = ConvDropoutLayerNorm(
  214. cin=c0, cout=c1, groups=config.post_attention_groups, dropout_prob=config.hidden_dropout_prob
  215. )
  216. self.intermediate = ConvActivation(cin=c1, cout=c2, groups=config.intermediate_groups, act=config.hidden_act)
  217. self.output = ConvDropoutLayerNorm(
  218. cin=c2, cout=c3, groups=config.output_groups, dropout_prob=config.hidden_dropout_prob
  219. )
  220. def forward(self, hidden_states, attention_mask, output_attentions):
  221. att = self.attention(hidden_states, attention_mask, output_attentions)
  222. attention_output = att["context_layer"]
  223. post_attention_output = self.post_attention(attention_output, hidden_states)
  224. intermediate_output = self.intermediate(post_attention_output)
  225. layer_output = self.output(intermediate_output, post_attention_output)
  226. output_dict = {"feature_map": layer_output}
  227. if output_attentions:
  228. output_dict["attention_score"] = att["attention_score"]
  229. return output_dict
  230. class SqueezeBertEncoder(nn.Module):
  231. def __init__(self, config):
  232. super().__init__()
  233. assert config.embedding_size == config.hidden_size, (
  234. "If you want embedding_size != intermediate hidden_size, "
  235. "please insert a Conv1d layer to adjust the number of channels "
  236. "before the first SqueezeBertModule."
  237. )
  238. self.layers = nn.ModuleList(SqueezeBertModule(config) for _ in range(config.num_hidden_layers))
  239. def forward(
  240. self,
  241. hidden_states,
  242. attention_mask=None,
  243. head_mask=None,
  244. output_attentions=False,
  245. output_hidden_states=False,
  246. return_dict=True,
  247. ):
  248. if head_mask is None:
  249. head_mask_is_all_none = True
  250. elif head_mask.count(None) == len(head_mask):
  251. head_mask_is_all_none = True
  252. else:
  253. head_mask_is_all_none = False
  254. assert head_mask_is_all_none is True, "head_mask is not yet supported in the SqueezeBert implementation."
  255. # [batch_size, sequence_length, hidden_size] --> [batch_size, hidden_size, sequence_length]
  256. hidden_states = hidden_states.permute(0, 2, 1)
  257. all_hidden_states = () if output_hidden_states else None
  258. all_attentions = () if output_attentions else None
  259. for layer in self.layers:
  260. if output_hidden_states:
  261. hidden_states = hidden_states.permute(0, 2, 1)
  262. all_hidden_states += (hidden_states,)
  263. hidden_states = hidden_states.permute(0, 2, 1)
  264. layer_output = layer.forward(hidden_states, attention_mask, output_attentions)
  265. hidden_states = layer_output["feature_map"]
  266. if output_attentions:
  267. all_attentions += (layer_output["attention_score"],)
  268. # [batch_size, hidden_size, sequence_length] --> [batch_size, sequence_length, hidden_size]
  269. hidden_states = hidden_states.permute(0, 2, 1)
  270. if output_hidden_states:
  271. all_hidden_states += (hidden_states,)
  272. if not return_dict:
  273. return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
  274. return BaseModelOutput(
  275. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
  276. )
  277. class SqueezeBertPooler(nn.Module):
  278. def __init__(self, config):
  279. super().__init__()
  280. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  281. self.activation = nn.Tanh()
  282. def forward(self, hidden_states):
  283. # We "pool" the model by simply taking the hidden state corresponding
  284. # to the first token.
  285. first_token_tensor = hidden_states[:, 0]
  286. pooled_output = self.dense(first_token_tensor)
  287. pooled_output = self.activation(pooled_output)
  288. return pooled_output
  289. class SqueezeBertPredictionHeadTransform(nn.Module):
  290. def __init__(self, config):
  291. super().__init__()
  292. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  293. if isinstance(config.hidden_act, str):
  294. self.transform_act_fn = ACT2FN[config.hidden_act]
  295. else:
  296. self.transform_act_fn = config.hidden_act
  297. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  298. def forward(self, hidden_states):
  299. hidden_states = self.dense(hidden_states)
  300. hidden_states = self.transform_act_fn(hidden_states)
  301. hidden_states = self.LayerNorm(hidden_states)
  302. return hidden_states
  303. class SqueezeBertLMPredictionHead(nn.Module):
  304. def __init__(self, config):
  305. super().__init__()
  306. self.transform = SqueezeBertPredictionHeadTransform(config)
  307. # The output weights are the same as the input embeddings, but there is
  308. # an output-only bias for each token.
  309. self.decoder = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  310. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  311. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  312. self.decoder.bias = self.bias
  313. def _tie_weights(self) -> None:
  314. self.decoder.bias = self.bias
  315. def forward(self, hidden_states):
  316. hidden_states = self.transform(hidden_states)
  317. hidden_states = self.decoder(hidden_states)
  318. return hidden_states
  319. class SqueezeBertOnlyMLMHead(nn.Module):
  320. def __init__(self, config):
  321. super().__init__()
  322. self.predictions = SqueezeBertLMPredictionHead(config)
  323. def forward(self, sequence_output):
  324. prediction_scores = self.predictions(sequence_output)
  325. return prediction_scores
  326. @auto_docstring
  327. class SqueezeBertPreTrainedModel(PreTrainedModel):
  328. config: SqueezeBertConfig
  329. base_model_prefix = "transformer"
  330. def _init_weights(self, module):
  331. """Initialize the weights"""
  332. if isinstance(module, (nn.Linear, nn.Conv1d)):
  333. # Slightly different from the TF version which uses truncated_normal for initialization
  334. # cf https://github.com/pytorch/pytorch/pull/5617
  335. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  336. if module.bias is not None:
  337. module.bias.data.zero_()
  338. elif isinstance(module, nn.Embedding):
  339. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  340. if module.padding_idx is not None:
  341. module.weight.data[module.padding_idx].zero_()
  342. elif isinstance(module, nn.LayerNorm):
  343. module.bias.data.zero_()
  344. module.weight.data.fill_(1.0)
  345. elif isinstance(module, SqueezeBertLMPredictionHead):
  346. module.bias.data.zero_()
  347. @auto_docstring
  348. class SqueezeBertModel(SqueezeBertPreTrainedModel):
  349. def __init__(self, config):
  350. super().__init__(config)
  351. self.embeddings = SqueezeBertEmbeddings(config)
  352. self.encoder = SqueezeBertEncoder(config)
  353. self.pooler = SqueezeBertPooler(config)
  354. # Initialize weights and apply final processing
  355. self.post_init()
  356. def get_input_embeddings(self):
  357. return self.embeddings.word_embeddings
  358. def set_input_embeddings(self, new_embeddings):
  359. self.embeddings.word_embeddings = new_embeddings
  360. def _prune_heads(self, heads_to_prune):
  361. """
  362. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  363. class PreTrainedModel
  364. """
  365. for layer, heads in heads_to_prune.items():
  366. self.encoder.layer[layer].attention.prune_heads(heads)
  367. @auto_docstring
  368. def forward(
  369. self,
  370. input_ids: Optional[torch.Tensor] = None,
  371. attention_mask: Optional[torch.Tensor] = None,
  372. token_type_ids: Optional[torch.Tensor] = None,
  373. position_ids: Optional[torch.Tensor] = None,
  374. head_mask: Optional[torch.Tensor] = None,
  375. inputs_embeds: Optional[torch.FloatTensor] = None,
  376. output_attentions: Optional[bool] = None,
  377. output_hidden_states: Optional[bool] = None,
  378. return_dict: Optional[bool] = None,
  379. ) -> Union[tuple, BaseModelOutputWithPooling]:
  380. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  381. output_hidden_states = (
  382. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  383. )
  384. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  385. if input_ids is not None and inputs_embeds is not None:
  386. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  387. elif input_ids is not None:
  388. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  389. input_shape = input_ids.size()
  390. elif inputs_embeds is not None:
  391. input_shape = inputs_embeds.size()[:-1]
  392. else:
  393. raise ValueError("You have to specify either input_ids or inputs_embeds")
  394. device = input_ids.device if input_ids is not None else inputs_embeds.device
  395. if attention_mask is None:
  396. attention_mask = torch.ones(input_shape, device=device)
  397. if token_type_ids is None:
  398. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  399. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
  400. # Prepare head mask if needed
  401. # 1.0 in head_mask indicate we keep the head
  402. # attention_probs has shape bsz x n_heads x N x N
  403. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  404. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  405. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  406. embedding_output = self.embeddings(
  407. input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
  408. )
  409. encoder_outputs = self.encoder(
  410. hidden_states=embedding_output,
  411. attention_mask=extended_attention_mask,
  412. head_mask=head_mask,
  413. output_attentions=output_attentions,
  414. output_hidden_states=output_hidden_states,
  415. return_dict=return_dict,
  416. )
  417. sequence_output = encoder_outputs[0]
  418. pooled_output = self.pooler(sequence_output)
  419. if not return_dict:
  420. return (sequence_output, pooled_output) + encoder_outputs[1:]
  421. return BaseModelOutputWithPooling(
  422. last_hidden_state=sequence_output,
  423. pooler_output=pooled_output,
  424. hidden_states=encoder_outputs.hidden_states,
  425. attentions=encoder_outputs.attentions,
  426. )
  427. @auto_docstring
  428. class SqueezeBertForMaskedLM(SqueezeBertPreTrainedModel):
  429. _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
  430. def __init__(self, config):
  431. super().__init__(config)
  432. self.transformer = SqueezeBertModel(config)
  433. self.cls = SqueezeBertOnlyMLMHead(config)
  434. # Initialize weights and apply final processing
  435. self.post_init()
  436. def get_output_embeddings(self):
  437. return self.cls.predictions.decoder
  438. def set_output_embeddings(self, new_embeddings):
  439. self.cls.predictions.decoder = new_embeddings
  440. self.cls.predictions.bias = new_embeddings.bias
  441. @auto_docstring
  442. def forward(
  443. self,
  444. input_ids: Optional[torch.Tensor] = None,
  445. attention_mask: Optional[torch.Tensor] = None,
  446. token_type_ids: Optional[torch.Tensor] = None,
  447. position_ids: Optional[torch.Tensor] = None,
  448. head_mask: Optional[torch.Tensor] = None,
  449. inputs_embeds: Optional[torch.Tensor] = None,
  450. labels: Optional[torch.Tensor] = None,
  451. output_attentions: Optional[bool] = None,
  452. output_hidden_states: Optional[bool] = None,
  453. return_dict: Optional[bool] = None,
  454. ) -> Union[tuple, MaskedLMOutput]:
  455. r"""
  456. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  457. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  458. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  459. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  460. """
  461. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  462. outputs = self.transformer(
  463. input_ids,
  464. attention_mask=attention_mask,
  465. token_type_ids=token_type_ids,
  466. position_ids=position_ids,
  467. head_mask=head_mask,
  468. inputs_embeds=inputs_embeds,
  469. output_attentions=output_attentions,
  470. output_hidden_states=output_hidden_states,
  471. return_dict=return_dict,
  472. )
  473. sequence_output = outputs[0]
  474. prediction_scores = self.cls(sequence_output)
  475. masked_lm_loss = None
  476. if labels is not None:
  477. loss_fct = CrossEntropyLoss() # -100 index = padding token
  478. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  479. if not return_dict:
  480. output = (prediction_scores,) + outputs[2:]
  481. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  482. return MaskedLMOutput(
  483. loss=masked_lm_loss,
  484. logits=prediction_scores,
  485. hidden_states=outputs.hidden_states,
  486. attentions=outputs.attentions,
  487. )
  488. @auto_docstring(
  489. custom_intro="""
  490. SqueezeBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  491. pooled output) e.g. for GLUE tasks.
  492. """
  493. )
  494. class SqueezeBertForSequenceClassification(SqueezeBertPreTrainedModel):
  495. def __init__(self, config):
  496. super().__init__(config)
  497. self.num_labels = config.num_labels
  498. self.config = config
  499. self.transformer = SqueezeBertModel(config)
  500. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  501. self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)
  502. # Initialize weights and apply final processing
  503. self.post_init()
  504. @auto_docstring
  505. def forward(
  506. self,
  507. input_ids: Optional[torch.Tensor] = None,
  508. attention_mask: Optional[torch.Tensor] = None,
  509. token_type_ids: Optional[torch.Tensor] = None,
  510. position_ids: Optional[torch.Tensor] = None,
  511. head_mask: Optional[torch.Tensor] = None,
  512. inputs_embeds: Optional[torch.Tensor] = None,
  513. labels: Optional[torch.Tensor] = None,
  514. output_attentions: Optional[bool] = None,
  515. output_hidden_states: Optional[bool] = None,
  516. return_dict: Optional[bool] = None,
  517. ) -> Union[tuple, SequenceClassifierOutput]:
  518. r"""
  519. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  520. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  521. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  522. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  523. """
  524. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  525. outputs = self.transformer(
  526. input_ids,
  527. attention_mask=attention_mask,
  528. token_type_ids=token_type_ids,
  529. position_ids=position_ids,
  530. head_mask=head_mask,
  531. inputs_embeds=inputs_embeds,
  532. output_attentions=output_attentions,
  533. output_hidden_states=output_hidden_states,
  534. return_dict=return_dict,
  535. )
  536. pooled_output = outputs[1]
  537. pooled_output = self.dropout(pooled_output)
  538. logits = self.classifier(pooled_output)
  539. loss = None
  540. if labels is not None:
  541. if self.config.problem_type is None:
  542. if self.num_labels == 1:
  543. self.config.problem_type = "regression"
  544. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  545. self.config.problem_type = "single_label_classification"
  546. else:
  547. self.config.problem_type = "multi_label_classification"
  548. if self.config.problem_type == "regression":
  549. loss_fct = MSELoss()
  550. if self.num_labels == 1:
  551. loss = loss_fct(logits.squeeze(), labels.squeeze())
  552. else:
  553. loss = loss_fct(logits, labels)
  554. elif self.config.problem_type == "single_label_classification":
  555. loss_fct = CrossEntropyLoss()
  556. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  557. elif self.config.problem_type == "multi_label_classification":
  558. loss_fct = BCEWithLogitsLoss()
  559. loss = loss_fct(logits, labels)
  560. if not return_dict:
  561. output = (logits,) + outputs[2:]
  562. return ((loss,) + output) if loss is not None else output
  563. return SequenceClassifierOutput(
  564. loss=loss,
  565. logits=logits,
  566. hidden_states=outputs.hidden_states,
  567. attentions=outputs.attentions,
  568. )
  569. @auto_docstring
  570. class SqueezeBertForMultipleChoice(SqueezeBertPreTrainedModel):
  571. def __init__(self, config):
  572. super().__init__(config)
  573. self.transformer = SqueezeBertModel(config)
  574. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  575. self.classifier = nn.Linear(config.hidden_size, 1)
  576. # Initialize weights and apply final processing
  577. self.post_init()
  578. @auto_docstring
  579. def forward(
  580. self,
  581. input_ids: Optional[torch.Tensor] = None,
  582. attention_mask: Optional[torch.Tensor] = None,
  583. token_type_ids: Optional[torch.Tensor] = None,
  584. position_ids: Optional[torch.Tensor] = None,
  585. head_mask: Optional[torch.Tensor] = None,
  586. inputs_embeds: Optional[torch.Tensor] = None,
  587. labels: Optional[torch.Tensor] = None,
  588. output_attentions: Optional[bool] = None,
  589. output_hidden_states: Optional[bool] = None,
  590. return_dict: Optional[bool] = None,
  591. ) -> Union[tuple, MultipleChoiceModelOutput]:
  592. r"""
  593. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  594. Indices of input sequence tokens in the vocabulary.
  595. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  596. [`PreTrainedTokenizer.__call__`] for details.
  597. [What are input IDs?](../glossary#input-ids)
  598. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  599. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  600. 1]`:
  601. - 0 corresponds to a *sentence A* token,
  602. - 1 corresponds to a *sentence B* token.
  603. [What are token type IDs?](../glossary#token-type-ids)
  604. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  605. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  606. config.max_position_embeddings - 1]`.
  607. [What are position IDs?](../glossary#position-ids)
  608. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  609. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  610. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  611. model's internal embedding lookup matrix.
  612. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  613. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  614. num_choices-1]` where *num_choices* is the size of the second dimension of the input tensors. (see
  615. *input_ids* above)
  616. """
  617. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  618. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  619. input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  620. attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  621. token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  622. position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  623. inputs_embeds = (
  624. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  625. if inputs_embeds is not None
  626. else None
  627. )
  628. outputs = self.transformer(
  629. input_ids,
  630. attention_mask=attention_mask,
  631. token_type_ids=token_type_ids,
  632. position_ids=position_ids,
  633. head_mask=head_mask,
  634. inputs_embeds=inputs_embeds,
  635. output_attentions=output_attentions,
  636. output_hidden_states=output_hidden_states,
  637. return_dict=return_dict,
  638. )
  639. pooled_output = outputs[1]
  640. pooled_output = self.dropout(pooled_output)
  641. logits = self.classifier(pooled_output)
  642. reshaped_logits = logits.view(-1, num_choices)
  643. loss = None
  644. if labels is not None:
  645. loss_fct = CrossEntropyLoss()
  646. loss = loss_fct(reshaped_logits, labels)
  647. if not return_dict:
  648. output = (reshaped_logits,) + outputs[2:]
  649. return ((loss,) + output) if loss is not None else output
  650. return MultipleChoiceModelOutput(
  651. loss=loss,
  652. logits=reshaped_logits,
  653. hidden_states=outputs.hidden_states,
  654. attentions=outputs.attentions,
  655. )
  656. @auto_docstring
  657. class SqueezeBertForTokenClassification(SqueezeBertPreTrainedModel):
  658. def __init__(self, config):
  659. super().__init__(config)
  660. self.num_labels = config.num_labels
  661. self.transformer = SqueezeBertModel(config)
  662. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  663. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  664. # Initialize weights and apply final processing
  665. self.post_init()
  666. @auto_docstring
  667. def forward(
  668. self,
  669. input_ids: Optional[torch.Tensor] = None,
  670. attention_mask: Optional[torch.Tensor] = None,
  671. token_type_ids: Optional[torch.Tensor] = None,
  672. position_ids: Optional[torch.Tensor] = None,
  673. head_mask: Optional[torch.Tensor] = None,
  674. inputs_embeds: Optional[torch.Tensor] = None,
  675. labels: Optional[torch.Tensor] = None,
  676. output_attentions: Optional[bool] = None,
  677. output_hidden_states: Optional[bool] = None,
  678. return_dict: Optional[bool] = None,
  679. ) -> Union[tuple, TokenClassifierOutput]:
  680. r"""
  681. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  682. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  683. """
  684. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  685. outputs = self.transformer(
  686. input_ids,
  687. attention_mask=attention_mask,
  688. token_type_ids=token_type_ids,
  689. position_ids=position_ids,
  690. head_mask=head_mask,
  691. inputs_embeds=inputs_embeds,
  692. output_attentions=output_attentions,
  693. output_hidden_states=output_hidden_states,
  694. return_dict=return_dict,
  695. )
  696. sequence_output = outputs[0]
  697. sequence_output = self.dropout(sequence_output)
  698. logits = self.classifier(sequence_output)
  699. loss = None
  700. if labels is not None:
  701. loss_fct = CrossEntropyLoss()
  702. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  703. if not return_dict:
  704. output = (logits,) + outputs[2:]
  705. return ((loss,) + output) if loss is not None else output
  706. return TokenClassifierOutput(
  707. loss=loss,
  708. logits=logits,
  709. hidden_states=outputs.hidden_states,
  710. attentions=outputs.attentions,
  711. )
  712. @auto_docstring
  713. class SqueezeBertForQuestionAnswering(SqueezeBertPreTrainedModel):
  714. def __init__(self, config):
  715. super().__init__(config)
  716. self.num_labels = config.num_labels
  717. self.transformer = SqueezeBertModel(config)
  718. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  719. # Initialize weights and apply final processing
  720. self.post_init()
  721. @auto_docstring
  722. def forward(
  723. self,
  724. input_ids: Optional[torch.Tensor] = None,
  725. attention_mask: Optional[torch.Tensor] = None,
  726. token_type_ids: Optional[torch.Tensor] = None,
  727. position_ids: Optional[torch.Tensor] = None,
  728. head_mask: Optional[torch.Tensor] = None,
  729. inputs_embeds: Optional[torch.Tensor] = None,
  730. start_positions: Optional[torch.Tensor] = None,
  731. end_positions: Optional[torch.Tensor] = None,
  732. output_attentions: Optional[bool] = None,
  733. output_hidden_states: Optional[bool] = None,
  734. return_dict: Optional[bool] = None,
  735. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  736. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  737. outputs = self.transformer(
  738. input_ids,
  739. attention_mask=attention_mask,
  740. token_type_ids=token_type_ids,
  741. position_ids=position_ids,
  742. head_mask=head_mask,
  743. inputs_embeds=inputs_embeds,
  744. output_attentions=output_attentions,
  745. output_hidden_states=output_hidden_states,
  746. return_dict=return_dict,
  747. )
  748. sequence_output = outputs[0]
  749. logits = self.qa_outputs(sequence_output)
  750. start_logits, end_logits = logits.split(1, dim=-1)
  751. start_logits = start_logits.squeeze(-1).contiguous()
  752. end_logits = end_logits.squeeze(-1).contiguous()
  753. total_loss = None
  754. if start_positions is not None and end_positions is not None:
  755. # If we are on multi-GPU, split add a dimension
  756. if len(start_positions.size()) > 1:
  757. start_positions = start_positions.squeeze(-1)
  758. if len(end_positions.size()) > 1:
  759. end_positions = end_positions.squeeze(-1)
  760. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  761. ignored_index = start_logits.size(1)
  762. start_positions = start_positions.clamp(0, ignored_index)
  763. end_positions = end_positions.clamp(0, ignored_index)
  764. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  765. start_loss = loss_fct(start_logits, start_positions)
  766. end_loss = loss_fct(end_logits, end_positions)
  767. total_loss = (start_loss + end_loss) / 2
  768. if not return_dict:
  769. output = (start_logits, end_logits) + outputs[2:]
  770. return ((total_loss,) + output) if total_loss is not None else output
  771. return QuestionAnsweringModelOutput(
  772. loss=total_loss,
  773. start_logits=start_logits,
  774. end_logits=end_logits,
  775. hidden_states=outputs.hidden_states,
  776. attentions=outputs.attentions,
  777. )
# Public names exported by `from <this module> import *`. Keep one quoted name per line and no comments
# inside the brackets: tooling may read this list textually, line by line.
__all__ = [
    "SqueezeBertForMaskedLM",
    "SqueezeBertForMultipleChoice",
    "SqueezeBertForQuestionAnswering",
    "SqueezeBertForSequenceClassification",
    "SqueezeBertForTokenClassification",
    "SqueezeBertModel",
    "SqueezeBertModule",
    "SqueezeBertPreTrainedModel",
]