modeling_convbert.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334
  1. # coding=utf-8
  2. # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch ConvBERT model."""
  16. import math
  17. import os
  18. from operator import attrgetter
  19. from typing import Callable, Optional, Union
  20. import torch
  21. from torch import nn
  22. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  23. from ...activations import ACT2FN, get_activation
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutputWithCrossAttentions,
  27. MaskedLMOutput,
  28. MultipleChoiceModelOutput,
  29. QuestionAnsweringModelOutput,
  30. SequenceClassifierOutput,
  31. TokenClassifierOutput,
  32. )
  33. from ...modeling_utils import PreTrainedModel
  34. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  35. from ...utils import (
  36. auto_docstring,
  37. logging,
  38. )
  39. from .configuration_convbert import ConvBertConfig
# Module-level logger, obtained from the `logging` helper imported from the package's `utils`.
logger = logging.get_logger(__name__)
  41. def load_tf_weights_in_convbert(model, config, tf_checkpoint_path):
  42. """Load tf checkpoints in a pytorch model."""
  43. try:
  44. import tensorflow as tf
  45. except ImportError:
  46. logger.error(
  47. "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
  48. "https://www.tensorflow.org/install/ for installation instructions."
  49. )
  50. raise
  51. tf_path = os.path.abspath(tf_checkpoint_path)
  52. logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
  53. # Load weights from TF model
  54. init_vars = tf.train.list_variables(tf_path)
  55. tf_data = {}
  56. for name, shape in init_vars:
  57. logger.info(f"Loading TF weight {name} with shape {shape}")
  58. array = tf.train.load_variable(tf_path, name)
  59. tf_data[name] = array
  60. param_mapping = {
  61. "embeddings.word_embeddings.weight": "electra/embeddings/word_embeddings",
  62. "embeddings.position_embeddings.weight": "electra/embeddings/position_embeddings",
  63. "embeddings.token_type_embeddings.weight": "electra/embeddings/token_type_embeddings",
  64. "embeddings.LayerNorm.weight": "electra/embeddings/LayerNorm/gamma",
  65. "embeddings.LayerNorm.bias": "electra/embeddings/LayerNorm/beta",
  66. "embeddings_project.weight": "electra/embeddings_project/kernel",
  67. "embeddings_project.bias": "electra/embeddings_project/bias",
  68. }
  69. if config.num_groups > 1:
  70. group_dense_name = "g_dense"
  71. else:
  72. group_dense_name = "dense"
  73. for j in range(config.num_hidden_layers):
  74. param_mapping[f"encoder.layer.{j}.attention.self.query.weight"] = (
  75. f"electra/encoder/layer_{j}/attention/self/query/kernel"
  76. )
  77. param_mapping[f"encoder.layer.{j}.attention.self.query.bias"] = (
  78. f"electra/encoder/layer_{j}/attention/self/query/bias"
  79. )
  80. param_mapping[f"encoder.layer.{j}.attention.self.key.weight"] = (
  81. f"electra/encoder/layer_{j}/attention/self/key/kernel"
  82. )
  83. param_mapping[f"encoder.layer.{j}.attention.self.key.bias"] = (
  84. f"electra/encoder/layer_{j}/attention/self/key/bias"
  85. )
  86. param_mapping[f"encoder.layer.{j}.attention.self.value.weight"] = (
  87. f"electra/encoder/layer_{j}/attention/self/value/kernel"
  88. )
  89. param_mapping[f"encoder.layer.{j}.attention.self.value.bias"] = (
  90. f"electra/encoder/layer_{j}/attention/self/value/bias"
  91. )
  92. param_mapping[f"encoder.layer.{j}.attention.self.key_conv_attn_layer.depthwise.weight"] = (
  93. f"electra/encoder/layer_{j}/attention/self/conv_attn_key/depthwise_kernel"
  94. )
  95. param_mapping[f"encoder.layer.{j}.attention.self.key_conv_attn_layer.pointwise.weight"] = (
  96. f"electra/encoder/layer_{j}/attention/self/conv_attn_key/pointwise_kernel"
  97. )
  98. param_mapping[f"encoder.layer.{j}.attention.self.key_conv_attn_layer.bias"] = (
  99. f"electra/encoder/layer_{j}/attention/self/conv_attn_key/bias"
  100. )
  101. param_mapping[f"encoder.layer.{j}.attention.self.conv_kernel_layer.weight"] = (
  102. f"electra/encoder/layer_{j}/attention/self/conv_attn_kernel/kernel"
  103. )
  104. param_mapping[f"encoder.layer.{j}.attention.self.conv_kernel_layer.bias"] = (
  105. f"electra/encoder/layer_{j}/attention/self/conv_attn_kernel/bias"
  106. )
  107. param_mapping[f"encoder.layer.{j}.attention.self.conv_out_layer.weight"] = (
  108. f"electra/encoder/layer_{j}/attention/self/conv_attn_point/kernel"
  109. )
  110. param_mapping[f"encoder.layer.{j}.attention.self.conv_out_layer.bias"] = (
  111. f"electra/encoder/layer_{j}/attention/self/conv_attn_point/bias"
  112. )
  113. param_mapping[f"encoder.layer.{j}.attention.output.dense.weight"] = (
  114. f"electra/encoder/layer_{j}/attention/output/dense/kernel"
  115. )
  116. param_mapping[f"encoder.layer.{j}.attention.output.LayerNorm.weight"] = (
  117. f"electra/encoder/layer_{j}/attention/output/LayerNorm/gamma"
  118. )
  119. param_mapping[f"encoder.layer.{j}.attention.output.dense.bias"] = (
  120. f"electra/encoder/layer_{j}/attention/output/dense/bias"
  121. )
  122. param_mapping[f"encoder.layer.{j}.attention.output.LayerNorm.bias"] = (
  123. f"electra/encoder/layer_{j}/attention/output/LayerNorm/beta"
  124. )
  125. param_mapping[f"encoder.layer.{j}.intermediate.dense.weight"] = (
  126. f"electra/encoder/layer_{j}/intermediate/{group_dense_name}/kernel"
  127. )
  128. param_mapping[f"encoder.layer.{j}.intermediate.dense.bias"] = (
  129. f"electra/encoder/layer_{j}/intermediate/{group_dense_name}/bias"
  130. )
  131. param_mapping[f"encoder.layer.{j}.output.dense.weight"] = (
  132. f"electra/encoder/layer_{j}/output/{group_dense_name}/kernel"
  133. )
  134. param_mapping[f"encoder.layer.{j}.output.dense.bias"] = (
  135. f"electra/encoder/layer_{j}/output/{group_dense_name}/bias"
  136. )
  137. param_mapping[f"encoder.layer.{j}.output.LayerNorm.weight"] = (
  138. f"electra/encoder/layer_{j}/output/LayerNorm/gamma"
  139. )
  140. param_mapping[f"encoder.layer.{j}.output.LayerNorm.bias"] = f"electra/encoder/layer_{j}/output/LayerNorm/beta"
  141. for param in model.named_parameters():
  142. param_name = param[0]
  143. retriever = attrgetter(param_name)
  144. result = retriever(model)
  145. tf_name = param_mapping[param_name]
  146. value = torch.from_numpy(tf_data[tf_name])
  147. logger.info(f"TF: {tf_name}, PT: {param_name} ")
  148. if tf_name.endswith("/kernel"):
  149. if not tf_name.endswith("/intermediate/g_dense/kernel"):
  150. if not tf_name.endswith("/output/g_dense/kernel"):
  151. value = value.T
  152. if tf_name.endswith("/depthwise_kernel"):
  153. value = value.permute(1, 2, 0) # 2, 0, 1
  154. if tf_name.endswith("/pointwise_kernel"):
  155. value = value.permute(2, 1, 0) # 2, 1, 0
  156. if tf_name.endswith("/conv_attn_key/bias"):
  157. value = value.unsqueeze(-1)
  158. result.data = value
  159. return model
  160. class ConvBertEmbeddings(nn.Module):
  161. """Construct the embeddings from word, position and token_type embeddings."""
  162. def __init__(self, config):
  163. super().__init__()
  164. self.word_embeddings = nn.Embedding(config.vocab_size, config.embedding_size, padding_idx=config.pad_token_id)
  165. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.embedding_size)
  166. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.embedding_size)
  167. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  168. # any TensorFlow checkpoint file
  169. self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
  170. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  171. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  172. self.register_buffer(
  173. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  174. )
  175. self.register_buffer(
  176. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  177. )
  178. def forward(
  179. self,
  180. input_ids: Optional[torch.LongTensor] = None,
  181. token_type_ids: Optional[torch.LongTensor] = None,
  182. position_ids: Optional[torch.LongTensor] = None,
  183. inputs_embeds: Optional[torch.FloatTensor] = None,
  184. ) -> torch.LongTensor:
  185. if input_ids is not None:
  186. input_shape = input_ids.size()
  187. else:
  188. input_shape = inputs_embeds.size()[:-1]
  189. seq_length = input_shape[1]
  190. if position_ids is None:
  191. position_ids = self.position_ids[:, :seq_length]
  192. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  193. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  194. # issue #5664
  195. if token_type_ids is None:
  196. if hasattr(self, "token_type_ids"):
  197. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  198. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  199. token_type_ids = buffered_token_type_ids_expanded
  200. else:
  201. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  202. if inputs_embeds is None:
  203. inputs_embeds = self.word_embeddings(input_ids)
  204. position_embeddings = self.position_embeddings(position_ids)
  205. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  206. embeddings = inputs_embeds + position_embeddings + token_type_embeddings
  207. embeddings = self.LayerNorm(embeddings)
  208. embeddings = self.dropout(embeddings)
  209. return embeddings
  210. @auto_docstring
  211. class ConvBertPreTrainedModel(PreTrainedModel):
  212. config: ConvBertConfig
  213. load_tf_weights = load_tf_weights_in_convbert
  214. base_model_prefix = "convbert"
  215. supports_gradient_checkpointing = True
  216. def _init_weights(self, module):
  217. """Initialize the weights"""
  218. if isinstance(module, (nn.Linear, nn.Conv1d)):
  219. # Slightly different from the TF version which uses truncated_normal for initialization
  220. # cf https://github.com/pytorch/pytorch/pull/5617
  221. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  222. if module.bias is not None:
  223. module.bias.data.zero_()
  224. elif isinstance(module, nn.Embedding):
  225. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  226. if module.padding_idx is not None:
  227. module.weight.data[module.padding_idx].zero_()
  228. elif isinstance(module, nn.LayerNorm):
  229. module.bias.data.zero_()
  230. module.weight.data.fill_(1.0)
  231. elif isinstance(module, SeparableConv1D):
  232. module.bias.data.zero_()
  233. elif isinstance(module, GroupedLinearLayer):
  234. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  235. module.bias.data.zero_()
  236. class SeparableConv1D(nn.Module):
  237. """This class implements separable convolution, i.e. a depthwise and a pointwise layer"""
  238. def __init__(self, config, input_filters, output_filters, kernel_size, **kwargs):
  239. super().__init__()
  240. self.depthwise = nn.Conv1d(
  241. input_filters,
  242. input_filters,
  243. kernel_size=kernel_size,
  244. groups=input_filters,
  245. padding=kernel_size // 2,
  246. bias=False,
  247. )
  248. self.pointwise = nn.Conv1d(input_filters, output_filters, kernel_size=1, bias=False)
  249. self.bias = nn.Parameter(torch.zeros(output_filters, 1))
  250. self.depthwise.weight.data.normal_(mean=0.0, std=config.initializer_range)
  251. self.pointwise.weight.data.normal_(mean=0.0, std=config.initializer_range)
  252. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  253. x = self.depthwise(hidden_states)
  254. x = self.pointwise(x)
  255. x += self.bias
  256. return x
  257. class ConvBertSelfAttention(nn.Module):
  258. def __init__(self, config):
  259. super().__init__()
  260. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  261. raise ValueError(
  262. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  263. f"heads ({config.num_attention_heads})"
  264. )
  265. new_num_attention_heads = config.num_attention_heads // config.head_ratio
  266. if new_num_attention_heads < 1:
  267. self.head_ratio = config.num_attention_heads
  268. self.num_attention_heads = 1
  269. else:
  270. self.num_attention_heads = new_num_attention_heads
  271. self.head_ratio = config.head_ratio
  272. self.conv_kernel_size = config.conv_kernel_size
  273. if config.hidden_size % self.num_attention_heads != 0:
  274. raise ValueError("hidden_size should be divisible by num_attention_heads")
  275. self.attention_head_size = (config.hidden_size // self.num_attention_heads) // 2
  276. self.all_head_size = self.num_attention_heads * self.attention_head_size
  277. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  278. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  279. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  280. self.key_conv_attn_layer = SeparableConv1D(
  281. config, config.hidden_size, self.all_head_size, self.conv_kernel_size
  282. )
  283. self.conv_kernel_layer = nn.Linear(self.all_head_size, self.num_attention_heads * self.conv_kernel_size)
  284. self.conv_out_layer = nn.Linear(config.hidden_size, self.all_head_size)
  285. self.unfold = nn.Unfold(
  286. kernel_size=[self.conv_kernel_size, 1], padding=[int((self.conv_kernel_size - 1) / 2), 0]
  287. )
  288. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  289. def forward(
  290. self,
  291. hidden_states: torch.Tensor,
  292. attention_mask: Optional[torch.FloatTensor] = None,
  293. head_mask: Optional[torch.FloatTensor] = None,
  294. encoder_hidden_states: Optional[torch.Tensor] = None,
  295. output_attentions: Optional[bool] = False,
  296. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  297. batch_size, seq_length, _ = hidden_states.shape
  298. # If this is instantiated as a cross-attention module, the keys
  299. # and values come from an encoder; the attention mask needs to be
  300. # such that the encoder's padding tokens are not attended to.
  301. if encoder_hidden_states is not None:
  302. mixed_key_layer = self.key(encoder_hidden_states)
  303. mixed_value_layer = self.value(encoder_hidden_states)
  304. else:
  305. mixed_key_layer = self.key(hidden_states)
  306. mixed_value_layer = self.value(hidden_states)
  307. mixed_key_conv_attn_layer = self.key_conv_attn_layer(hidden_states.transpose(1, 2))
  308. mixed_key_conv_attn_layer = mixed_key_conv_attn_layer.transpose(1, 2)
  309. mixed_query_layer = self.query(hidden_states)
  310. query_layer = mixed_query_layer.view(
  311. batch_size, -1, self.num_attention_heads, self.attention_head_size
  312. ).transpose(1, 2)
  313. key_layer = mixed_key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
  314. 1, 2
  315. )
  316. value_layer = mixed_value_layer.view(
  317. batch_size, -1, self.num_attention_heads, self.attention_head_size
  318. ).transpose(1, 2)
  319. conv_attn_layer = torch.multiply(mixed_key_conv_attn_layer, mixed_query_layer)
  320. conv_kernel_layer = self.conv_kernel_layer(conv_attn_layer)
  321. conv_kernel_layer = torch.reshape(conv_kernel_layer, [-1, self.conv_kernel_size, 1])
  322. conv_kernel_layer = torch.softmax(conv_kernel_layer, dim=1)
  323. conv_out_layer = self.conv_out_layer(hidden_states)
  324. conv_out_layer = torch.reshape(conv_out_layer, [batch_size, -1, self.all_head_size])
  325. conv_out_layer = conv_out_layer.transpose(1, 2).contiguous().unsqueeze(-1)
  326. conv_out_layer = nn.functional.unfold(
  327. conv_out_layer,
  328. kernel_size=[self.conv_kernel_size, 1],
  329. dilation=1,
  330. padding=[(self.conv_kernel_size - 1) // 2, 0],
  331. stride=1,
  332. )
  333. conv_out_layer = conv_out_layer.transpose(1, 2).reshape(
  334. batch_size, -1, self.all_head_size, self.conv_kernel_size
  335. )
  336. conv_out_layer = torch.reshape(conv_out_layer, [-1, self.attention_head_size, self.conv_kernel_size])
  337. conv_out_layer = torch.matmul(conv_out_layer, conv_kernel_layer)
  338. conv_out_layer = torch.reshape(conv_out_layer, [-1, self.all_head_size])
  339. # Take the dot product between "query" and "key" to get the raw attention scores.
  340. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  341. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  342. if attention_mask is not None:
  343. # Apply the attention mask is (precomputed for all layers in ConvBertModel forward() function)
  344. attention_scores = attention_scores + attention_mask
  345. # Normalize the attention scores to probabilities.
  346. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  347. # This is actually dropping out entire tokens to attend to, which might
  348. # seem a bit unusual, but is taken from the original Transformer paper.
  349. attention_probs = self.dropout(attention_probs)
  350. # Mask heads if we want to
  351. if head_mask is not None:
  352. attention_probs = attention_probs * head_mask
  353. context_layer = torch.matmul(attention_probs, value_layer)
  354. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  355. conv_out = torch.reshape(conv_out_layer, [batch_size, -1, self.num_attention_heads, self.attention_head_size])
  356. context_layer = torch.cat([context_layer, conv_out], 2)
  357. # conv and context
  358. new_context_layer_shape = context_layer.size()[:-2] + (
  359. self.num_attention_heads * self.attention_head_size * 2,
  360. )
  361. context_layer = context_layer.view(*new_context_layer_shape)
  362. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  363. return outputs
  364. class ConvBertSelfOutput(nn.Module):
  365. def __init__(self, config):
  366. super().__init__()
  367. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  368. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  369. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  370. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  371. hidden_states = self.dense(hidden_states)
  372. hidden_states = self.dropout(hidden_states)
  373. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  374. return hidden_states
  375. class ConvBertAttention(nn.Module):
  376. def __init__(self, config):
  377. super().__init__()
  378. self.self = ConvBertSelfAttention(config)
  379. self.output = ConvBertSelfOutput(config)
  380. self.pruned_heads = set()
  381. def prune_heads(self, heads):
  382. if len(heads) == 0:
  383. return
  384. heads, index = find_pruneable_heads_and_indices(
  385. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  386. )
  387. # Prune linear layers
  388. self.self.query = prune_linear_layer(self.self.query, index)
  389. self.self.key = prune_linear_layer(self.self.key, index)
  390. self.self.value = prune_linear_layer(self.self.value, index)
  391. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  392. # Update hyper params and store pruned heads
  393. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  394. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  395. self.pruned_heads = self.pruned_heads.union(heads)
  396. def forward(
  397. self,
  398. hidden_states: torch.Tensor,
  399. attention_mask: Optional[torch.FloatTensor] = None,
  400. head_mask: Optional[torch.FloatTensor] = None,
  401. encoder_hidden_states: Optional[torch.Tensor] = None,
  402. output_attentions: Optional[bool] = False,
  403. ) -> tuple[torch.Tensor, Optional[torch.FloatTensor]]:
  404. self_outputs = self.self(
  405. hidden_states,
  406. attention_mask,
  407. head_mask,
  408. encoder_hidden_states,
  409. output_attentions,
  410. )
  411. attention_output = self.output(self_outputs[0], hidden_states)
  412. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  413. return outputs
  414. class GroupedLinearLayer(nn.Module):
  415. def __init__(self, input_size, output_size, num_groups):
  416. super().__init__()
  417. self.input_size = input_size
  418. self.output_size = output_size
  419. self.num_groups = num_groups
  420. self.group_in_dim = self.input_size // self.num_groups
  421. self.group_out_dim = self.output_size // self.num_groups
  422. self.weight = nn.Parameter(torch.empty(self.num_groups, self.group_in_dim, self.group_out_dim))
  423. self.bias = nn.Parameter(torch.empty(output_size))
  424. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  425. batch_size = list(hidden_states.size())[0]
  426. x = torch.reshape(hidden_states, [-1, self.num_groups, self.group_in_dim])
  427. x = x.permute(1, 0, 2)
  428. x = torch.matmul(x, self.weight)
  429. x = x.permute(1, 0, 2)
  430. x = torch.reshape(x, [batch_size, -1, self.output_size])
  431. x = x + self.bias
  432. return x
  433. class ConvBertIntermediate(nn.Module):
  434. def __init__(self, config):
  435. super().__init__()
  436. if config.num_groups == 1:
  437. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  438. else:
  439. self.dense = GroupedLinearLayer(
  440. input_size=config.hidden_size, output_size=config.intermediate_size, num_groups=config.num_groups
  441. )
  442. if isinstance(config.hidden_act, str):
  443. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  444. else:
  445. self.intermediate_act_fn = config.hidden_act
  446. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  447. hidden_states = self.dense(hidden_states)
  448. hidden_states = self.intermediate_act_fn(hidden_states)
  449. return hidden_states
  450. class ConvBertOutput(nn.Module):
  451. def __init__(self, config):
  452. super().__init__()
  453. if config.num_groups == 1:
  454. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  455. else:
  456. self.dense = GroupedLinearLayer(
  457. input_size=config.intermediate_size, output_size=config.hidden_size, num_groups=config.num_groups
  458. )
  459. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  460. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  461. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  462. hidden_states = self.dense(hidden_states)
  463. hidden_states = self.dropout(hidden_states)
  464. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  465. return hidden_states
  466. class ConvBertLayer(GradientCheckpointingLayer):
  467. def __init__(self, config):
  468. super().__init__()
  469. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  470. self.seq_len_dim = 1
  471. self.attention = ConvBertAttention(config)
  472. self.is_decoder = config.is_decoder
  473. self.add_cross_attention = config.add_cross_attention
  474. if self.add_cross_attention:
  475. if not self.is_decoder:
  476. raise TypeError(f"{self} should be used as a decoder model if cross attention is added")
  477. self.crossattention = ConvBertAttention(config)
  478. self.intermediate = ConvBertIntermediate(config)
  479. self.output = ConvBertOutput(config)
  480. def forward(
  481. self,
  482. hidden_states: torch.Tensor,
  483. attention_mask: Optional[torch.FloatTensor] = None,
  484. head_mask: Optional[torch.FloatTensor] = None,
  485. encoder_hidden_states: Optional[torch.Tensor] = None,
  486. encoder_attention_mask: Optional[torch.Tensor] = None,
  487. output_attentions: Optional[bool] = False,
  488. ) -> tuple[torch.Tensor, Optional[torch.FloatTensor]]:
  489. self_attention_outputs = self.attention(
  490. hidden_states,
  491. attention_mask,
  492. head_mask,
  493. output_attentions=output_attentions,
  494. )
  495. attention_output = self_attention_outputs[0]
  496. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  497. if self.is_decoder and encoder_hidden_states is not None:
  498. if not hasattr(self, "crossattention"):
  499. raise AttributeError(
  500. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  501. " by setting `config.add_cross_attention=True`"
  502. )
  503. cross_attention_outputs = self.crossattention(
  504. attention_output,
  505. encoder_attention_mask,
  506. head_mask,
  507. encoder_hidden_states,
  508. output_attentions,
  509. )
  510. attention_output = cross_attention_outputs[0]
  511. outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
  512. layer_output = apply_chunking_to_forward(
  513. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  514. )
  515. outputs = (layer_output,) + outputs
  516. return outputs
  517. def feed_forward_chunk(self, attention_output):
  518. intermediate_output = self.intermediate(attention_output)
  519. layer_output = self.output(intermediate_output, attention_output)
  520. return layer_output
  class ConvBertEncoder(nn.Module):
      """Runs the hidden states through `config.num_hidden_layers` stacked `ConvBertLayer` modules."""

      def __init__(self, config):
          super().__init__()
          self.config = config
          stacked_layers = [ConvBertLayer(config) for _ in range(config.num_hidden_layers)]
          self.layer = nn.ModuleList(stacked_layers)
          self.gradient_checkpointing = False

      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          encoder_hidden_states: Optional[torch.Tensor] = None,
          encoder_attention_mask: Optional[torch.Tensor] = None,
          output_attentions: Optional[bool] = False,
          output_hidden_states: Optional[bool] = False,
          return_dict: Optional[bool] = True,
      ) -> Union[tuple, BaseModelOutputWithCrossAttentions]:
          """
          Feed `hidden_states` through every layer in order.

          When `output_hidden_states` is set, the input of each layer plus the final output are collected. When
          `output_attentions` is set, element 1 of each layer's output is collected, and element 2 as well if
          `config.add_cross_attention` is enabled.

          Returns:
              A `BaseModelOutputWithCrossAttentions` if `return_dict` is truthy, otherwise a tuple holding the final
              hidden states followed by whichever of the collected tuples are not `None`.
          """
          all_hidden_states = () if output_hidden_states else None
          all_self_attentions = () if output_attentions else None
          all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

          for index, block in enumerate(self.layer):
              if output_hidden_states:
                  # Record what goes *into* this layer.
                  all_hidden_states += (hidden_states,)

              per_layer_head_mask = None if head_mask is None else head_mask[index]
              block_outputs = block(
                  hidden_states,
                  attention_mask,
                  per_layer_head_mask,
                  encoder_hidden_states,
                  encoder_attention_mask,
                  output_attentions,
              )
              hidden_states = block_outputs[0]

              if output_attentions:
                  all_self_attentions += (block_outputs[1],)
                  if self.config.add_cross_attention:
                      all_cross_attentions += (block_outputs[2],)

          if output_hidden_states:
              # Append the output of the last layer as well.
              all_hidden_states += (hidden_states,)

          if return_dict:
              return BaseModelOutputWithCrossAttentions(
                  last_hidden_state=hidden_states,
                  hidden_states=all_hidden_states,
                  attentions=all_self_attentions,
                  cross_attentions=all_cross_attentions,
              )

          candidates = (hidden_states, all_hidden_states, all_self_attentions, all_cross_attentions)
          return tuple(item for item in candidates if item is not None)
  class ConvBertPredictionHeadTransform(nn.Module):
      """Dense projection, activation and layer normalization, all at `config.hidden_size` width."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.hidden_size)
          # `hidden_act` is either a key into `ACT2FN` or the activation callable itself.
          act = config.hidden_act
          self.transform_act_fn = ACT2FN[act] if isinstance(act, str) else act
          self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Return `LayerNorm(act(dense(hidden_states)))`."""
          projected = self.dense(hidden_states)
          activated = self.transform_act_fn(projected)
          return self.LayerNorm(activated)
  586. # Copied from transformers.models.xlm.modeling_xlm.XLMSequenceSummary with XLM->ConvBert
  class ConvBertSequenceSummary(nn.Module):
      r"""
      Reduce the hidden states of a sequence to a single summary vector.

      Args:
          config ([`ConvBertConfig`]):
              The model configuration. The following attributes are read (each is optional unless noted):

              - **summary_type** (`str`) -- How the summary vector is selected. Defaults to `"last"`. One of:

                  - `"last"` -- hidden state of the last token
                  - `"first"` -- hidden state of the first token
                  - `"mean"` -- mean over all token hidden states
                  - `"cls_index"` -- hidden state at a caller-supplied classification token position
                  - `"attn"` -- not implemented; raises `NotImplementedError`

              - **summary_use_proj** (`bool`) -- Add a linear projection after the vector extraction.
              - **summary_proj_to_labels** (`bool`) -- If `True` (and `config.num_labels > 0`), the projection outputs
                `config.num_labels` features, otherwise `config.hidden_size`.
              - **summary_activation** (`Optional[str]`) -- Name of an activation resolved through `get_activation`;
                a falsy value means no activation.
              - **summary_first_dropout** (`float`) -- Dropout probability applied before the projection.
              - **summary_last_dropout** (`float`) -- Dropout probability applied after the activation.
      """

      def __init__(self, config: ConvBertConfig):
          super().__init__()

          self.summary_type = getattr(config, "summary_type", "last")
          if self.summary_type == "attn":
              # We should use a standard multi-head attention module with absolute positional embedding for that.
              # Cf. https://github.com/zihangdai/xlnet/blob/master/modeling.py#L253-L276
              # We can probably just use the multi-head attention module of PyTorch >=1.1.0
              raise NotImplementedError

          # Optional projection; its width depends on whether it should emit label logits directly.
          if getattr(config, "summary_use_proj", False):
              to_labels = getattr(config, "summary_proj_to_labels", False) and config.num_labels > 0
              out_features = config.num_labels if to_labels else config.hidden_size
              self.summary = nn.Linear(config.hidden_size, out_features)
          else:
              self.summary = nn.Identity()

          activation_name = getattr(config, "summary_activation", None)
          if activation_name:
              self.activation: Callable = get_activation(activation_name)
          else:
              self.activation: Callable = nn.Identity()

          first_p = getattr(config, "summary_first_dropout", 0)
          self.first_dropout = nn.Dropout(first_p) if first_p > 0 else nn.Identity()

          last_p = getattr(config, "summary_last_dropout", 0)
          self.last_dropout = nn.Dropout(last_p) if last_p > 0 else nn.Identity()

      def forward(
          self, hidden_states: torch.FloatTensor, cls_index: Optional[torch.LongTensor] = None
      ) -> torch.FloatTensor:
          """
          Reduce `hidden_states` to one vector per sequence and post-process it.

          Args:
              hidden_states (`torch.FloatTensor` of shape `[batch_size, seq_len, hidden_size]`):
                  The hidden states of the last layer.
              cls_index (`torch.LongTensor` of shape `[batch_size]` or `[batch_size, ...]` where ... are optional leading dimensions of `hidden_states`, *optional*):
                  Only used when `summary_type == "cls_index"`. If omitted, the last position of the sequence is
                  used as the classification token.

          Returns:
              `torch.FloatTensor`: The summary of the sequence hidden states.
          """
          kind = self.summary_type
          if kind == "last":
              pooled = hidden_states[:, -1]
          elif kind == "first":
              pooled = hidden_states[:, 0]
          elif kind == "mean":
              pooled = hidden_states.mean(dim=1)
          elif kind == "cls_index":
              if cls_index is None:
                  # Default to the final position along the sequence axis.
                  last_position = hidden_states.shape[-2] - 1
                  gather_index = torch.full_like(hidden_states[..., :1, :], last_position, dtype=torch.long)
              else:
                  gather_index = cls_index.unsqueeze(-1).unsqueeze(-1)
                  target_shape = (-1,) * (gather_index.dim() - 1) + (hidden_states.size(-1),)
                  gather_index = gather_index.expand(target_shape)
              # shape of gather_index: (bsz, XX, 1, hidden_size) where XX are optional leading dim of hidden_states
              pooled = hidden_states.gather(-2, gather_index).squeeze(-2)  # shape (bsz, XX, hidden_size)
          elif kind == "attn":
              raise NotImplementedError

          pooled = self.first_dropout(pooled)
          pooled = self.summary(pooled)
          pooled = self.activation(pooled)
          return self.last_dropout(pooled)
  @auto_docstring
  class ConvBertModel(ConvBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.embeddings = ConvBertEmbeddings(config)

          # A projection is only created when the embedding width differs from the encoder width.
          if config.embedding_size != config.hidden_size:
              self.embeddings_project = nn.Linear(config.embedding_size, config.hidden_size)

          self.encoder = ConvBertEncoder(config)
          self.config = config
          # Initialize weights and apply final processing
          self.post_init()

      def get_input_embeddings(self):
          return self.embeddings.word_embeddings

      def set_input_embeddings(self, value):
          self.embeddings.word_embeddings = value

      def _prune_heads(self, heads_to_prune):
          """
          Prune attention heads. `heads_to_prune` maps a layer index to the list of heads to prune in that layer; see
          the base class `PreTrainedModel`.
          """
          for layer_index, head_ids in heads_to_prune.items():
              self.encoder.layer[layer_index].attention.prune_heads(head_ids)

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          token_type_ids: Optional[torch.LongTensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, BaseModelOutputWithCrossAttentions]:
          # Unset flags fall back to the values stored on the config.
          if output_attentions is None:
              output_attentions = self.config.output_attentions
          if output_hidden_states is None:
              output_hidden_states = self.config.output_hidden_states
          if return_dict is None:
              return_dict = self.config.use_return_dict

          # Exactly one of `input_ids` / `inputs_embeds` must be given.
          have_ids = input_ids is not None
          have_embeds = inputs_embeds is not None
          if have_ids and have_embeds:
              raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
          if not (have_ids or have_embeds):
              raise ValueError("You have to specify either input_ids or inputs_embeds")

          if have_ids:
              self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
              input_shape = input_ids.size()
              device = input_ids.device
          else:
              input_shape = inputs_embeds.size()[:-1]
              device = inputs_embeds.device
          batch_size, seq_length = input_shape

          if attention_mask is None:
              attention_mask = torch.ones(input_shape, device=device)

          if token_type_ids is None:
              if hasattr(self.embeddings, "token_type_ids"):
                  # Reuse the buffer registered on the embeddings, cut to the sequence length.
                  buffered = self.embeddings.token_type_ids[:, :seq_length]
                  token_type_ids = buffered.expand(batch_size, seq_length)
              else:
                  token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)

          extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
          head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)

          embedded = self.embeddings(
              input_ids=input_ids, position_ids=position_ids, token_type_ids=token_type_ids, inputs_embeds=inputs_embeds
          )
          if hasattr(self, "embeddings_project"):
              embedded = self.embeddings_project(embedded)

          encoder_outputs = self.encoder(
              embedded,
              attention_mask=extended_attention_mask,
              head_mask=head_mask,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          return encoder_outputs
  class ConvBertGeneratorPredictions(nn.Module):
      """Generator prediction transform: dense (`hidden_size` -> `embedding_size`), GELU, then layer normalization."""

      def __init__(self, config):
          super().__init__()

          self.activation = get_activation("gelu")
          self.LayerNorm = nn.LayerNorm(config.embedding_size, eps=config.layer_norm_eps)
          self.dense = nn.Linear(config.hidden_size, config.embedding_size)

      def forward(self, generator_hidden_states: torch.FloatTensor) -> torch.FloatTensor:
          """Return `LayerNorm(activation(dense(generator_hidden_states)))`."""
          projected = self.dense(generator_hidden_states)
          return self.LayerNorm(self.activation(projected))
  @auto_docstring
  class ConvBertForMaskedLM(ConvBertPreTrainedModel):
      # Must name the weight of the module stored as `self.generator_lm_head` (the one returned by
      # `get_output_embeddings`). The previous value, "generator.lm_head.weight", did not correspond to any
      # attribute path of this class.
      _tied_weights_keys = ["generator_lm_head.weight"]

      def __init__(self, config):
          super().__init__(config)

          self.convbert = ConvBertModel(config)
          self.generator_predictions = ConvBertGeneratorPredictions(config)

          self.generator_lm_head = nn.Linear(config.embedding_size, config.vocab_size)
          # Initialize weights and apply final processing
          self.post_init()

      def get_output_embeddings(self):
          return self.generator_lm_head

      def set_output_embeddings(self, word_embeddings):
          self.generator_lm_head = word_embeddings

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          token_type_ids: Optional[torch.LongTensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, MaskedLMOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
              config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
              loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
          """
          return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          # Keyword arguments keep this call correct even if the backbone's parameter order changes, and match
          # how the other task heads in this file invoke `self.convbert`.
          generator_hidden_states = self.convbert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          generator_sequence_output = generator_hidden_states[0]

          # Transform the hidden states, then project them onto the vocabulary.
          prediction_scores = self.generator_predictions(generator_sequence_output)
          prediction_scores = self.generator_lm_head(prediction_scores)

          loss = None
          # Masked language modeling softmax layer
          if labels is not None:
              loss_fct = nn.CrossEntropyLoss()  # -100 index = padding token
              loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))

          if not return_dict:
              output = (prediction_scores,) + generator_hidden_states[1:]
              return ((loss,) + output) if loss is not None else output

          return MaskedLMOutput(
              loss=loss,
              logits=prediction_scores,
              hidden_states=generator_hidden_states.hidden_states,
              attentions=generator_hidden_states.attentions,
          )
  class ConvBertClassificationHead(nn.Module):
      """Sentence-level classification head applied to the hidden state at position 0."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.hidden_size)

          # Prefer the dedicated classifier dropout; fall back to the generic hidden dropout.
          dropout_p = config.classifier_dropout
          if dropout_p is None:
              dropout_p = config.hidden_dropout_prob
          self.dropout = nn.Dropout(dropout_p)

          self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
          self.config = config

      def forward(self, hidden_states: torch.Tensor, **kwargs) -> torch.Tensor:
          """Return logits of width `config.num_labels` computed from `hidden_states[:, 0, :]`."""
          first_token = hidden_states[:, 0, :]  # take <s> token (equiv. to [CLS])
          first_token = self.dropout(first_token)
          projected = self.dense(first_token)
          activated = ACT2FN[self.config.hidden_act](projected)
          activated = self.dropout(activated)
          return self.out_proj(activated)
  @auto_docstring(
      custom_intro="""
      ConvBERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the
      pooled output) e.g. for GLUE tasks.
      """
  )
  class ConvBertForSequenceClassification(ConvBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels
          self.config = config
          self.convbert = ConvBertModel(config)
          self.classifier = ConvBertClassificationHead(config)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          token_type_ids: Optional[torch.LongTensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, SequenceClassifierOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
              config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
              `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
          """
          if return_dict is None:
              return_dict = self.config.use_return_dict

          outputs = self.convbert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )
          logits = self.classifier(outputs[0])

          loss = None
          if labels is not None:
              if self.config.problem_type is None:
                  # Infer the task from the label count / label dtype and store it on the config.
                  if self.num_labels == 1:
                      inferred = "regression"
                  elif self.num_labels > 1 and labels.dtype in (torch.long, torch.int):
                      inferred = "single_label_classification"
                  else:
                      inferred = "multi_label_classification"
                  self.config.problem_type = inferred

              problem_type = self.config.problem_type
              if problem_type == "regression":
                  criterion = MSELoss()
                  if self.num_labels == 1:
                      loss = criterion(logits.squeeze(), labels.squeeze())
                  else:
                      loss = criterion(logits, labels)
              elif problem_type == "single_label_classification":
                  criterion = CrossEntropyLoss()
                  loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))
              elif problem_type == "multi_label_classification":
                  criterion = BCEWithLogitsLoss()
                  loss = criterion(logits, labels)

          if return_dict:
              return SequenceClassifierOutput(
                  loss=loss,
                  logits=logits,
                  hidden_states=outputs.hidden_states,
                  attentions=outputs.attentions,
              )

          # Tuple form: logits followed by the remaining backbone outputs, with the loss in front when present.
          tail = (logits,) + outputs[1:]
          return tail if loss is None else (loss,) + tail
  @auto_docstring
  class ConvBertForMultipleChoice(ConvBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          self.convbert = ConvBertModel(config)
          self.sequence_summary = ConvBertSequenceSummary(config)
          self.classifier = nn.Linear(config.hidden_size, 1)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          token_type_ids: Optional[torch.LongTensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, MultipleChoiceModelOutput]:
          r"""
          input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
              Indices of input sequence tokens in the vocabulary.

              Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are input IDs?](../glossary#input-ids)
          token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
              Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
              1]`:

              - 0 corresponds to a *sentence A* token,
              - 1 corresponds to a *sentence B* token.

              [What are token type IDs?](../glossary#token-type-ids)
          position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
              Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
              config.max_position_embeddings - 1]`.

              [What are position IDs?](../glossary#position-ids)
          inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
              Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
              is useful if you want more control over how to convert *input_ids* indices into associated vectors than the
              model's internal embedding lookup matrix.
          labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
              Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
              num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
              `input_ids` above)
          """

          def fold_choices(tensor):
              # Merge every leading dimension into one, keeping only the last axis; `None` passes through.
              return None if tensor is None else tensor.view(-1, tensor.size(-1))

          if return_dict is None:
              return_dict = self.config.use_return_dict

          # Must be read before the inputs are flattened below.
          num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]

          input_ids = fold_choices(input_ids)
          attention_mask = fold_choices(attention_mask)
          token_type_ids = fold_choices(token_type_ids)
          position_ids = fold_choices(position_ids)
          if inputs_embeds is not None:
              # Embeddings carry one extra trailing axis, so the last two axes are kept.
              inputs_embeds = inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))

          outputs = self.convbert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          pooled_output = self.sequence_summary(outputs[0])
          # One score per flattened row, regrouped so that each row holds the scores of all choices.
          reshaped_logits = self.classifier(pooled_output).view(-1, num_choices)

          loss = None
          if labels is not None:
              loss = CrossEntropyLoss()(reshaped_logits, labels)

          if return_dict:
              return MultipleChoiceModelOutput(
                  loss=loss,
                  logits=reshaped_logits,
                  hidden_states=outputs.hidden_states,
                  attentions=outputs.attentions,
              )

          tail = (reshaped_logits,) + outputs[1:]
          return tail if loss is None else (loss,) + tail
  @auto_docstring
  class ConvBertForTokenClassification(ConvBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)
          self.num_labels = config.num_labels

          self.convbert = ConvBertModel(config)

          # Prefer the dedicated classifier dropout; fall back to the generic hidden dropout.
          dropout_p = config.classifier_dropout
          if dropout_p is None:
              dropout_p = config.hidden_dropout_prob
          self.dropout = nn.Dropout(dropout_p)
          self.classifier = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          token_type_ids: Optional[torch.LongTensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, TokenClassifierOutput]:
          r"""
          labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
              Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
          """
          if return_dict is None:
              return_dict = self.config.use_return_dict

          outputs = self.convbert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          token_states = self.dropout(outputs[0])
          logits = self.classifier(token_states)

          loss = None
          if labels is not None:
              # Flatten so that every token is scored as an independent classification example.
              criterion = CrossEntropyLoss()
              loss = criterion(logits.view(-1, self.num_labels), labels.view(-1))

          if return_dict:
              return TokenClassifierOutput(
                  loss=loss,
                  logits=logits,
                  hidden_states=outputs.hidden_states,
                  attentions=outputs.attentions,
              )

          tail = (logits,) + outputs[1:]
          return tail if loss is None else (loss,) + tail
  @auto_docstring
  class ConvBertForQuestionAnswering(ConvBertPreTrainedModel):
      def __init__(self, config):
          super().__init__(config)

          self.num_labels = config.num_labels
          self.convbert = ConvBertModel(config)
          self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)

          # Initialize weights and apply final processing
          self.post_init()

      @auto_docstring
      def forward(
          self,
          input_ids: Optional[torch.LongTensor] = None,
          attention_mask: Optional[torch.FloatTensor] = None,
          token_type_ids: Optional[torch.LongTensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          start_positions: Optional[torch.LongTensor] = None,
          end_positions: Optional[torch.LongTensor] = None,
          output_attentions: Optional[bool] = None,
          output_hidden_states: Optional[bool] = None,
          return_dict: Optional[bool] = None,
      ) -> Union[tuple, QuestionAnsweringModelOutput]:
          if return_dict is None:
              return_dict = self.config.use_return_dict

          outputs = self.convbert(
              input_ids,
              attention_mask=attention_mask,
              token_type_ids=token_type_ids,
              position_ids=position_ids,
              head_mask=head_mask,
              inputs_embeds=inputs_embeds,
              output_attentions=output_attentions,
              output_hidden_states=output_hidden_states,
              return_dict=return_dict,
          )

          # Split the last axis into width-1 slices; the unpacking requires exactly two of them (start, end).
          span_logits = self.qa_outputs(outputs[0])
          start_logits, end_logits = span_logits.split(1, dim=-1)
          start_logits = start_logits.squeeze(-1).contiguous()
          end_logits = end_logits.squeeze(-1).contiguous()

          total_loss = None
          if start_positions is not None and end_positions is not None:
              # If we are on multi-GPU, split add a dimension
              if start_positions.dim() > 1:
                  start_positions = start_positions.squeeze(-1)
              if end_positions.dim() > 1:
                  end_positions = end_positions.squeeze(-1)

              # sometimes the start/end positions are outside our model inputs, we ignore these terms
              ignored_index = start_logits.size(1)
              start_positions = start_positions.clamp(0, ignored_index)
              end_positions = end_positions.clamp(0, ignored_index)

              criterion = CrossEntropyLoss(ignore_index=ignored_index)
              start_loss = criterion(start_logits, start_positions)
              end_loss = criterion(end_logits, end_positions)
              total_loss = (start_loss + end_loss) / 2

          if return_dict:
              return QuestionAnsweringModelOutput(
                  loss=total_loss,
                  start_logits=start_logits,
                  end_logits=end_logits,
                  hidden_states=outputs.hidden_states,
                  attentions=outputs.attentions,
              )

          tail = (start_logits, end_logits) + outputs[1:]
          return tail if total_loss is None else (total_loss,) + tail
  # Names exported by `from <this module> import *`.
  __all__ = [
      "ConvBertForMaskedLM",
      "ConvBertForMultipleChoice",
      "ConvBertForQuestionAnswering",
      "ConvBertForSequenceClassification",
      "ConvBertForTokenClassification",
      "ConvBertLayer",
      "ConvBertModel",
      "ConvBertPreTrainedModel",
      "load_tf_weights_in_convbert",
  ]