modeling_ibert.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253
  1. # coding=utf-8
  2. # Copyright 2021 The I-BERT Authors (Sehoon Kim, Amir Gholami, Zhewei Yao,
  3. # Michael Mahoney, Kurt Keutzer - UC Berkeley) and The HuggingFace Inc. team.
  4. # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
  5. #
  6. # Licensed under the Apache License, Version 2.0 (the "License");
  7. # you may not use this file except in compliance with the License.
  8. # You may obtain a copy of the License at
  9. #
  10. # http://www.apache.org/licenses/LICENSE-2.0
  11. #
  12. # Unless required by applicable law or agreed to in writing, software
  13. # distributed under the License is distributed on an "AS IS" BASIS,
  14. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  15. # See the License for the specific language governing permissions and
  16. # limitations under the License.
  17. """PyTorch I-BERT model."""
  18. import math
  19. from typing import Optional, Union
  20. import torch
  21. from torch import nn
  22. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  23. from ...activations import gelu
  24. from ...modeling_outputs import (
  25. BaseModelOutputWithPastAndCrossAttentions,
  26. BaseModelOutputWithPoolingAndCrossAttentions,
  27. MaskedLMOutput,
  28. MultipleChoiceModelOutput,
  29. QuestionAnsweringModelOutput,
  30. SequenceClassifierOutput,
  31. TokenClassifierOutput,
  32. )
  33. from ...modeling_utils import PreTrainedModel
  34. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  35. from ...utils import auto_docstring, logging
  36. from .configuration_ibert import IBertConfig
  37. from .quant_modules import IntGELU, IntLayerNorm, IntSoftmax, QuantAct, QuantEmbedding, QuantLinear
  38. logger = logging.get_logger(__name__)
  39. class IBertEmbeddings(nn.Module):
  40. """
  41. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  42. """
  43. def __init__(self, config):
  44. super().__init__()
  45. self.quant_mode = config.quant_mode
  46. self.embedding_bit = 8
  47. self.embedding_act_bit = 16
  48. self.act_bit = 8
  49. self.ln_input_bit = 22
  50. self.ln_output_bit = 32
  51. self.word_embeddings = QuantEmbedding(
  52. config.vocab_size,
  53. config.hidden_size,
  54. padding_idx=config.pad_token_id,
  55. weight_bit=self.embedding_bit,
  56. quant_mode=self.quant_mode,
  57. )
  58. self.token_type_embeddings = QuantEmbedding(
  59. config.type_vocab_size, config.hidden_size, weight_bit=self.embedding_bit, quant_mode=self.quant_mode
  60. )
  61. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  62. self.register_buffer(
  63. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  64. )
  65. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  66. # End copy
  67. self.padding_idx = config.pad_token_id
  68. self.position_embeddings = QuantEmbedding(
  69. config.max_position_embeddings,
  70. config.hidden_size,
  71. padding_idx=self.padding_idx,
  72. weight_bit=self.embedding_bit,
  73. quant_mode=self.quant_mode,
  74. )
  75. # Integer-only addition between embeddings
  76. self.embeddings_act1 = QuantAct(self.embedding_act_bit, quant_mode=self.quant_mode)
  77. self.embeddings_act2 = QuantAct(self.embedding_act_bit, quant_mode=self.quant_mode)
  78. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  79. # any TensorFlow checkpoint file
  80. self.LayerNorm = IntLayerNorm(
  81. config.hidden_size,
  82. eps=config.layer_norm_eps,
  83. output_bit=self.ln_output_bit,
  84. quant_mode=self.quant_mode,
  85. force_dequant=config.force_dequant,
  86. )
  87. self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  88. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  89. def forward(
  90. self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
  91. ):
  92. if position_ids is None:
  93. if input_ids is not None:
  94. # Create the position ids from the input token ids. Any padded tokens remain padded.
  95. position_ids = create_position_ids_from_input_ids(
  96. input_ids, self.padding_idx, past_key_values_length
  97. ).to(input_ids.device)
  98. else:
  99. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  100. if input_ids is not None:
  101. input_shape = input_ids.size()
  102. else:
  103. input_shape = inputs_embeds.size()[:-1]
  104. if token_type_ids is None:
  105. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  106. if inputs_embeds is None:
  107. inputs_embeds, inputs_embeds_scaling_factor = self.word_embeddings(input_ids)
  108. else:
  109. inputs_embeds_scaling_factor = None
  110. token_type_embeddings, token_type_embeddings_scaling_factor = self.token_type_embeddings(token_type_ids)
  111. embeddings, embeddings_scaling_factor = self.embeddings_act1(
  112. inputs_embeds,
  113. inputs_embeds_scaling_factor,
  114. identity=token_type_embeddings,
  115. identity_scaling_factor=token_type_embeddings_scaling_factor,
  116. )
  117. if self.position_embedding_type == "absolute":
  118. position_embeddings, position_embeddings_scaling_factor = self.position_embeddings(position_ids)
  119. embeddings, embeddings_scaling_factor = self.embeddings_act1(
  120. embeddings,
  121. embeddings_scaling_factor,
  122. identity=position_embeddings,
  123. identity_scaling_factor=position_embeddings_scaling_factor,
  124. )
  125. embeddings, embeddings_scaling_factor = self.LayerNorm(embeddings, embeddings_scaling_factor)
  126. embeddings = self.dropout(embeddings)
  127. embeddings, embeddings_scaling_factor = self.output_activation(embeddings, embeddings_scaling_factor)
  128. return embeddings, embeddings_scaling_factor
  129. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  130. """
  131. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  132. Args:
  133. inputs_embeds: torch.Tensor
  134. Returns: torch.Tensor
  135. """
  136. input_shape = inputs_embeds.size()[:-1]
  137. sequence_length = input_shape[1]
  138. position_ids = torch.arange(
  139. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  140. )
  141. return position_ids.unsqueeze(0).expand(input_shape)
  142. class IBertSelfAttention(nn.Module):
  143. def __init__(self, config):
  144. super().__init__()
  145. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  146. raise ValueError(
  147. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  148. f"heads ({config.num_attention_heads})"
  149. )
  150. self.quant_mode = config.quant_mode
  151. self.weight_bit = 8
  152. self.bias_bit = 32
  153. self.act_bit = 8
  154. self.num_attention_heads = config.num_attention_heads
  155. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  156. self.all_head_size = self.num_attention_heads * self.attention_head_size
  157. # Q, K, V Linear layers
  158. self.query = QuantLinear(
  159. config.hidden_size,
  160. self.all_head_size,
  161. bias=True,
  162. weight_bit=self.weight_bit,
  163. bias_bit=self.bias_bit,
  164. quant_mode=self.quant_mode,
  165. per_channel=True,
  166. )
  167. self.key = QuantLinear(
  168. config.hidden_size,
  169. self.all_head_size,
  170. bias=True,
  171. weight_bit=self.weight_bit,
  172. bias_bit=self.bias_bit,
  173. quant_mode=self.quant_mode,
  174. per_channel=True,
  175. )
  176. self.value = QuantLinear(
  177. config.hidden_size,
  178. self.all_head_size,
  179. bias=True,
  180. weight_bit=self.weight_bit,
  181. bias_bit=self.bias_bit,
  182. quant_mode=self.quant_mode,
  183. per_channel=True,
  184. )
  185. # Requantization (32bit -> 8bit) for Q, K, V activations
  186. self.query_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  187. self.key_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  188. self.value_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  189. self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  190. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  191. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  192. if self.position_embedding_type != "absolute":
  193. raise ValueError("I-BERT only supports 'absolute' for `config.position_embedding_type`")
  194. self.softmax = IntSoftmax(self.act_bit, quant_mode=self.quant_mode, force_dequant=config.force_dequant)
  195. def forward(
  196. self,
  197. hidden_states,
  198. hidden_states_scaling_factor,
  199. attention_mask=None,
  200. head_mask=None,
  201. output_attentions=False,
  202. ):
  203. # Projection
  204. mixed_query_layer, mixed_query_layer_scaling_factor = self.query(hidden_states, hidden_states_scaling_factor)
  205. mixed_key_layer, mixed_key_layer_scaling_factor = self.key(hidden_states, hidden_states_scaling_factor)
  206. mixed_value_layer, mixed_value_layer_scaling_factor = self.value(hidden_states, hidden_states_scaling_factor)
  207. # Requantization
  208. query_layer, query_layer_scaling_factor = self.query_activation(
  209. mixed_query_layer, mixed_query_layer_scaling_factor
  210. )
  211. key_layer, key_layer_scaling_factor = self.key_activation(mixed_key_layer, mixed_key_layer_scaling_factor)
  212. value_layer, value_layer_scaling_factor = self.value_activation(
  213. mixed_value_layer, mixed_value_layer_scaling_factor
  214. )
  215. # Transpose
  216. batch_size, seq_length, _ = hidden_states.shape
  217. query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
  218. 1, 2
  219. )
  220. key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  221. value_layer = value_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
  222. 1, 2
  223. )
  224. # Take the dot product between "query" and "key" to get the raw attention scores.
  225. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  226. scale = math.sqrt(self.attention_head_size)
  227. attention_scores = attention_scores / scale
  228. if self.quant_mode:
  229. attention_scores_scaling_factor = query_layer_scaling_factor * key_layer_scaling_factor / scale
  230. else:
  231. attention_scores_scaling_factor = None
  232. if attention_mask is not None:
  233. # Apply the attention mask is (precomputed for all layers in IBertModel forward() function)
  234. attention_scores = attention_scores + attention_mask
  235. # Normalize the attention scores to probabilities.
  236. attention_probs, attention_probs_scaling_factor = self.softmax(
  237. attention_scores, attention_scores_scaling_factor
  238. )
  239. # This is actually dropping out entire tokens to attend to, which might
  240. # seem a bit unusual, but is taken from the original Transformer paper.
  241. attention_probs = self.dropout(attention_probs)
  242. # Mask heads if we want to
  243. if head_mask is not None:
  244. attention_probs = attention_probs * head_mask
  245. context_layer = torch.matmul(attention_probs, value_layer)
  246. if attention_probs_scaling_factor is not None:
  247. context_layer_scaling_factor = attention_probs_scaling_factor * value_layer_scaling_factor
  248. else:
  249. context_layer_scaling_factor = None
  250. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  251. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  252. context_layer = context_layer.view(*new_context_layer_shape)
  253. # requantization: 32-bit -> 8-bit
  254. context_layer, context_layer_scaling_factor = self.output_activation(
  255. context_layer, context_layer_scaling_factor
  256. )
  257. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  258. output_scaling_factor = (
  259. (context_layer_scaling_factor, attention_probs_scaling_factor)
  260. if output_attentions
  261. else (context_layer_scaling_factor,)
  262. )
  263. return outputs, output_scaling_factor
  264. class IBertSelfOutput(nn.Module):
  265. def __init__(self, config):
  266. super().__init__()
  267. self.quant_mode = config.quant_mode
  268. self.act_bit = 8
  269. self.weight_bit = 8
  270. self.bias_bit = 32
  271. self.ln_input_bit = 22
  272. self.ln_output_bit = 32
  273. self.dense = QuantLinear(
  274. config.hidden_size,
  275. config.hidden_size,
  276. bias=True,
  277. weight_bit=self.weight_bit,
  278. bias_bit=self.bias_bit,
  279. quant_mode=self.quant_mode,
  280. per_channel=True,
  281. )
  282. self.ln_input_act = QuantAct(self.ln_input_bit, quant_mode=self.quant_mode)
  283. self.LayerNorm = IntLayerNorm(
  284. config.hidden_size,
  285. eps=config.layer_norm_eps,
  286. output_bit=self.ln_output_bit,
  287. quant_mode=self.quant_mode,
  288. force_dequant=config.force_dequant,
  289. )
  290. self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  291. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  292. def forward(self, hidden_states, hidden_states_scaling_factor, input_tensor, input_tensor_scaling_factor):
  293. hidden_states, hidden_states_scaling_factor = self.dense(hidden_states, hidden_states_scaling_factor)
  294. hidden_states = self.dropout(hidden_states)
  295. hidden_states, hidden_states_scaling_factor = self.ln_input_act(
  296. hidden_states,
  297. hidden_states_scaling_factor,
  298. identity=input_tensor,
  299. identity_scaling_factor=input_tensor_scaling_factor,
  300. )
  301. hidden_states, hidden_states_scaling_factor = self.LayerNorm(hidden_states, hidden_states_scaling_factor)
  302. hidden_states, hidden_states_scaling_factor = self.output_activation(
  303. hidden_states, hidden_states_scaling_factor
  304. )
  305. return hidden_states, hidden_states_scaling_factor
  306. class IBertAttention(nn.Module):
  307. def __init__(self, config):
  308. super().__init__()
  309. self.quant_mode = config.quant_mode
  310. self.self = IBertSelfAttention(config)
  311. self.output = IBertSelfOutput(config)
  312. self.pruned_heads = set()
  313. def prune_heads(self, heads):
  314. if len(heads) == 0:
  315. return
  316. heads, index = find_pruneable_heads_and_indices(
  317. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  318. )
  319. # Prune linear layers
  320. self.self.query = prune_linear_layer(self.self.query, index)
  321. self.self.key = prune_linear_layer(self.self.key, index)
  322. self.self.value = prune_linear_layer(self.self.value, index)
  323. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  324. # Update hyper params and store pruned heads
  325. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  326. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  327. self.pruned_heads = self.pruned_heads.union(heads)
  328. def forward(
  329. self,
  330. hidden_states,
  331. hidden_states_scaling_factor,
  332. attention_mask=None,
  333. head_mask=None,
  334. output_attentions=False,
  335. ):
  336. self_outputs, self_outputs_scaling_factor = self.self(
  337. hidden_states,
  338. hidden_states_scaling_factor,
  339. attention_mask,
  340. head_mask,
  341. output_attentions,
  342. )
  343. attention_output, attention_output_scaling_factor = self.output(
  344. self_outputs[0], self_outputs_scaling_factor[0], hidden_states, hidden_states_scaling_factor
  345. )
  346. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  347. outputs_scaling_factor = (attention_output_scaling_factor,) + self_outputs_scaling_factor[1:]
  348. return outputs, outputs_scaling_factor
  349. class IBertIntermediate(nn.Module):
  350. def __init__(self, config):
  351. super().__init__()
  352. self.quant_mode = config.quant_mode
  353. self.act_bit = 8
  354. self.weight_bit = 8
  355. self.bias_bit = 32
  356. self.dense = QuantLinear(
  357. config.hidden_size,
  358. config.intermediate_size,
  359. bias=True,
  360. weight_bit=self.weight_bit,
  361. bias_bit=self.bias_bit,
  362. quant_mode=self.quant_mode,
  363. per_channel=True,
  364. )
  365. if config.hidden_act != "gelu":
  366. raise ValueError("I-BERT only supports 'gelu' for `config.hidden_act`")
  367. self.intermediate_act_fn = IntGELU(quant_mode=self.quant_mode, force_dequant=config.force_dequant)
  368. self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  369. def forward(self, hidden_states, hidden_states_scaling_factor):
  370. hidden_states, hidden_states_scaling_factor = self.dense(hidden_states, hidden_states_scaling_factor)
  371. hidden_states, hidden_states_scaling_factor = self.intermediate_act_fn(
  372. hidden_states, hidden_states_scaling_factor
  373. )
  374. # Requantization: 32bit -> 8-bit
  375. hidden_states, hidden_states_scaling_factor = self.output_activation(
  376. hidden_states, hidden_states_scaling_factor
  377. )
  378. return hidden_states, hidden_states_scaling_factor
  379. class IBertOutput(nn.Module):
  380. def __init__(self, config):
  381. super().__init__()
  382. self.quant_mode = config.quant_mode
  383. self.act_bit = 8
  384. self.weight_bit = 8
  385. self.bias_bit = 32
  386. self.ln_input_bit = 22
  387. self.ln_output_bit = 32
  388. self.dense = QuantLinear(
  389. config.intermediate_size,
  390. config.hidden_size,
  391. bias=True,
  392. weight_bit=self.weight_bit,
  393. bias_bit=self.bias_bit,
  394. quant_mode=self.quant_mode,
  395. per_channel=True,
  396. )
  397. self.ln_input_act = QuantAct(self.ln_input_bit, quant_mode=self.quant_mode)
  398. self.LayerNorm = IntLayerNorm(
  399. config.hidden_size,
  400. eps=config.layer_norm_eps,
  401. output_bit=self.ln_output_bit,
  402. quant_mode=self.quant_mode,
  403. force_dequant=config.force_dequant,
  404. )
  405. self.output_activation = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  406. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  407. def forward(self, hidden_states, hidden_states_scaling_factor, input_tensor, input_tensor_scaling_factor):
  408. hidden_states, hidden_states_scaling_factor = self.dense(hidden_states, hidden_states_scaling_factor)
  409. hidden_states = self.dropout(hidden_states)
  410. hidden_states, hidden_states_scaling_factor = self.ln_input_act(
  411. hidden_states,
  412. hidden_states_scaling_factor,
  413. identity=input_tensor,
  414. identity_scaling_factor=input_tensor_scaling_factor,
  415. )
  416. hidden_states, hidden_states_scaling_factor = self.LayerNorm(hidden_states, hidden_states_scaling_factor)
  417. hidden_states, hidden_states_scaling_factor = self.output_activation(
  418. hidden_states, hidden_states_scaling_factor
  419. )
  420. return hidden_states, hidden_states_scaling_factor
  421. class IBertLayer(nn.Module):
  422. def __init__(self, config):
  423. super().__init__()
  424. self.quant_mode = config.quant_mode
  425. self.act_bit = 8
  426. self.seq_len_dim = 1
  427. self.attention = IBertAttention(config)
  428. self.intermediate = IBertIntermediate(config)
  429. self.output = IBertOutput(config)
  430. self.pre_intermediate_act = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  431. self.pre_output_act = QuantAct(self.act_bit, quant_mode=self.quant_mode)
  432. def forward(
  433. self,
  434. hidden_states,
  435. hidden_states_scaling_factor,
  436. attention_mask=None,
  437. head_mask=None,
  438. output_attentions=False,
  439. ):
  440. self_attention_outputs, self_attention_outputs_scaling_factor = self.attention(
  441. hidden_states,
  442. hidden_states_scaling_factor,
  443. attention_mask,
  444. head_mask,
  445. output_attentions=output_attentions,
  446. )
  447. attention_output = self_attention_outputs[0]
  448. attention_output_scaling_factor = self_attention_outputs_scaling_factor[0]
  449. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  450. layer_output, layer_output_scaling_factor = self.feed_forward_chunk(
  451. attention_output, attention_output_scaling_factor
  452. )
  453. outputs = (layer_output,) + outputs
  454. return outputs
  455. def feed_forward_chunk(self, attention_output, attention_output_scaling_factor):
  456. attention_output, attention_output_scaling_factor = self.pre_intermediate_act(
  457. attention_output, attention_output_scaling_factor
  458. )
  459. intermediate_output, intermediate_output_scaling_factor = self.intermediate(
  460. attention_output, attention_output_scaling_factor
  461. )
  462. intermediate_output, intermediate_output_scaling_factor = self.pre_output_act(
  463. intermediate_output, intermediate_output_scaling_factor
  464. )
  465. layer_output, layer_output_scaling_factor = self.output(
  466. intermediate_output, intermediate_output_scaling_factor, attention_output, attention_output_scaling_factor
  467. )
  468. return layer_output, layer_output_scaling_factor
  469. class IBertEncoder(nn.Module):
  470. def __init__(self, config):
  471. super().__init__()
  472. self.config = config
  473. self.quant_mode = config.quant_mode
  474. self.layer = nn.ModuleList([IBertLayer(config) for _ in range(config.num_hidden_layers)])
  475. def forward(
  476. self,
  477. hidden_states,
  478. hidden_states_scaling_factor,
  479. attention_mask=None,
  480. head_mask=None,
  481. output_attentions=False,
  482. output_hidden_states=False,
  483. return_dict=True,
  484. ):
  485. all_hidden_states = () if output_hidden_states else None
  486. all_self_attentions = () if output_attentions else None
  487. all_cross_attentions = None # `config.add_cross_attention` is not supported
  488. for i, layer_module in enumerate(self.layer):
  489. if output_hidden_states:
  490. all_hidden_states = all_hidden_states + (hidden_states,)
  491. layer_head_mask = head_mask[i] if head_mask is not None else None
  492. layer_outputs = layer_module(
  493. hidden_states,
  494. hidden_states_scaling_factor,
  495. attention_mask,
  496. layer_head_mask,
  497. output_attentions,
  498. )
  499. hidden_states = layer_outputs[0]
  500. if output_attentions:
  501. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  502. if output_hidden_states:
  503. all_hidden_states = all_hidden_states + (hidden_states,)
  504. if not return_dict:
  505. return tuple(
  506. v
  507. for v in [
  508. hidden_states,
  509. all_hidden_states,
  510. all_self_attentions,
  511. all_cross_attentions,
  512. ]
  513. if v is not None
  514. )
  515. return BaseModelOutputWithPastAndCrossAttentions(
  516. last_hidden_state=hidden_states,
  517. hidden_states=all_hidden_states,
  518. attentions=all_self_attentions,
  519. cross_attentions=all_cross_attentions,
  520. )
  521. class IBertPooler(nn.Module):
  522. def __init__(self, config):
  523. super().__init__()
  524. self.quant_mode = config.quant_mode
  525. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  526. self.activation = nn.Tanh()
  527. def forward(self, hidden_states):
  528. # We "pool" the model by simply taking the hidden state corresponding
  529. # to the first token.
  530. first_token_tensor = hidden_states[:, 0]
  531. pooled_output = self.dense(first_token_tensor)
  532. pooled_output = self.activation(pooled_output)
  533. return pooled_output
  534. @auto_docstring
  535. class IBertPreTrainedModel(PreTrainedModel):
  536. config: IBertConfig
  537. base_model_prefix = "ibert"
  538. def _init_weights(self, module):
  539. """Initialize the weights"""
  540. if isinstance(module, (QuantLinear, nn.Linear)):
  541. # Slightly different from the TF version which uses truncated_normal for initialization
  542. # cf https://github.com/pytorch/pytorch/pull/5617
  543. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  544. if module.bias is not None:
  545. module.bias.data.zero_()
  546. elif isinstance(module, (QuantEmbedding, nn.Embedding)):
  547. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  548. if module.padding_idx is not None:
  549. module.weight.data[module.padding_idx].zero_()
  550. elif isinstance(module, (IntLayerNorm, nn.LayerNorm)):
  551. module.bias.data.zero_()
  552. module.weight.data.fill_(1.0)
  553. elif isinstance(module, IBertLMHead):
  554. module.bias.data.zero_()
  555. def resize_token_embeddings(self, new_num_tokens=None):
  556. raise NotImplementedError("`resize_token_embeddings` is not supported for I-BERT.")
  557. @auto_docstring
  558. class IBertModel(IBertPreTrainedModel):
  559. """
  560. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  561. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  562. all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  563. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  564. """
  565. def __init__(self, config, add_pooling_layer=True):
  566. r"""
  567. add_pooling_layer (bool, *optional*, defaults to `True`):
  568. Whether to add a pooling layer
  569. """
  570. super().__init__(config)
  571. self.config = config
  572. self.quant_mode = config.quant_mode
  573. self.embeddings = IBertEmbeddings(config)
  574. self.encoder = IBertEncoder(config)
  575. self.pooler = IBertPooler(config) if add_pooling_layer else None
  576. # Initialize weights and apply final processing
  577. self.post_init()
  578. def get_input_embeddings(self):
  579. return self.embeddings.word_embeddings
  580. def set_input_embeddings(self, value):
  581. self.embeddings.word_embeddings = value
  582. def _prune_heads(self, heads_to_prune):
  583. """
  584. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  585. class PreTrainedModel
  586. """
  587. for layer, heads in heads_to_prune.items():
  588. self.encoder.layer[layer].attention.prune_heads(heads)
  589. @auto_docstring
  590. def forward(
  591. self,
  592. input_ids: Optional[torch.LongTensor] = None,
  593. attention_mask: Optional[torch.FloatTensor] = None,
  594. token_type_ids: Optional[torch.LongTensor] = None,
  595. position_ids: Optional[torch.LongTensor] = None,
  596. head_mask: Optional[torch.FloatTensor] = None,
  597. inputs_embeds: Optional[torch.FloatTensor] = None,
  598. output_attentions: Optional[bool] = None,
  599. output_hidden_states: Optional[bool] = None,
  600. return_dict: Optional[bool] = None,
  601. ) -> Union[BaseModelOutputWithPoolingAndCrossAttentions, tuple[torch.FloatTensor]]:
  602. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  603. output_hidden_states = (
  604. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  605. )
  606. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  607. if input_ids is not None and inputs_embeds is not None:
  608. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  609. elif input_ids is not None:
  610. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  611. input_shape = input_ids.size()
  612. elif inputs_embeds is not None:
  613. input_shape = inputs_embeds.size()[:-1]
  614. else:
  615. raise ValueError("You have to specify either input_ids or inputs_embeds")
  616. batch_size, seq_length = input_shape
  617. device = input_ids.device if input_ids is not None else inputs_embeds.device
  618. if attention_mask is None:
  619. attention_mask = torch.ones(((batch_size, seq_length)), device=device)
  620. if token_type_ids is None:
  621. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  622. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  623. # ourselves in which case we just need to make it broadcastable to all heads.
  624. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  625. # Prepare head mask if needed
  626. # 1.0 in head_mask indicate we keep the head
  627. # attention_probs has shape bsz x n_heads x N x N
  628. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  629. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  630. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  631. embedding_output, embedding_output_scaling_factor = self.embeddings(
  632. input_ids=input_ids,
  633. position_ids=position_ids,
  634. token_type_ids=token_type_ids,
  635. inputs_embeds=inputs_embeds,
  636. )
  637. encoder_outputs = self.encoder(
  638. embedding_output,
  639. embedding_output_scaling_factor,
  640. attention_mask=extended_attention_mask,
  641. head_mask=head_mask,
  642. output_attentions=output_attentions,
  643. output_hidden_states=output_hidden_states,
  644. return_dict=return_dict,
  645. )
  646. sequence_output = encoder_outputs[0]
  647. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  648. if not return_dict:
  649. return (sequence_output, pooled_output) + encoder_outputs[1:]
  650. return BaseModelOutputWithPoolingAndCrossAttentions(
  651. last_hidden_state=sequence_output,
  652. pooler_output=pooled_output,
  653. hidden_states=encoder_outputs.hidden_states,
  654. attentions=encoder_outputs.attentions,
  655. cross_attentions=encoder_outputs.cross_attentions,
  656. )
  657. @auto_docstring
  658. class IBertForMaskedLM(IBertPreTrainedModel):
  659. _tied_weights_keys = ["lm_head.decoder.bias", "lm_head.decoder.weight"]
  660. def __init__(self, config):
  661. super().__init__(config)
  662. self.ibert = IBertModel(config, add_pooling_layer=False)
  663. self.lm_head = IBertLMHead(config)
  664. # Initialize weights and apply final processing
  665. self.post_init()
  666. def get_output_embeddings(self):
  667. return self.lm_head.decoder
  668. def set_output_embeddings(self, new_embeddings):
  669. self.lm_head.decoder = new_embeddings
  670. self.lm_head.bias = new_embeddings.bias
  671. @auto_docstring
  672. def forward(
  673. self,
  674. input_ids: Optional[torch.LongTensor] = None,
  675. attention_mask: Optional[torch.FloatTensor] = None,
  676. token_type_ids: Optional[torch.LongTensor] = None,
  677. position_ids: Optional[torch.LongTensor] = None,
  678. head_mask: Optional[torch.FloatTensor] = None,
  679. inputs_embeds: Optional[torch.FloatTensor] = None,
  680. labels: Optional[torch.LongTensor] = None,
  681. output_attentions: Optional[bool] = None,
  682. output_hidden_states: Optional[bool] = None,
  683. return_dict: Optional[bool] = None,
  684. ) -> Union[MaskedLMOutput, tuple[torch.FloatTensor]]:
  685. r"""
  686. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  687. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  688. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  689. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  690. """
  691. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  692. outputs = self.ibert(
  693. input_ids,
  694. attention_mask=attention_mask,
  695. token_type_ids=token_type_ids,
  696. position_ids=position_ids,
  697. head_mask=head_mask,
  698. inputs_embeds=inputs_embeds,
  699. output_attentions=output_attentions,
  700. output_hidden_states=output_hidden_states,
  701. return_dict=return_dict,
  702. )
  703. sequence_output = outputs[0]
  704. prediction_scores = self.lm_head(sequence_output)
  705. masked_lm_loss = None
  706. if labels is not None:
  707. loss_fct = CrossEntropyLoss()
  708. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  709. if not return_dict:
  710. output = (prediction_scores,) + outputs[2:]
  711. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  712. return MaskedLMOutput(
  713. loss=masked_lm_loss,
  714. logits=prediction_scores,
  715. hidden_states=outputs.hidden_states,
  716. attentions=outputs.attentions,
  717. )
  718. class IBertLMHead(nn.Module):
  719. """I-BERT Head for masked language modeling."""
  720. def __init__(self, config):
  721. super().__init__()
  722. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  723. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  724. self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
  725. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  726. self.decoder.bias = self.bias
  727. def forward(self, features, **kwargs):
  728. x = self.dense(features)
  729. x = gelu(x)
  730. x = self.layer_norm(x)
  731. # project back to size of vocabulary with bias
  732. x = self.decoder(x)
  733. return x
  734. def _tie_weights(self) -> None:
  735. # For accelerate compatibility and to not break backward compatibility
  736. if self.decoder.bias.device.type == "meta":
  737. self.decoder.bias = self.bias
  738. else:
  739. # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
  740. self.bias = self.decoder.bias
  741. @auto_docstring(
  742. custom_intro="""
  743. I-BERT Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  744. output) e.g. for GLUE tasks.
  745. """
  746. )
  747. class IBertForSequenceClassification(IBertPreTrainedModel):
  748. def __init__(self, config):
  749. super().__init__(config)
  750. self.num_labels = config.num_labels
  751. self.ibert = IBertModel(config, add_pooling_layer=False)
  752. self.classifier = IBertClassificationHead(config)
  753. # Initialize weights and apply final processing
  754. self.post_init()
  755. @auto_docstring
  756. def forward(
  757. self,
  758. input_ids: Optional[torch.LongTensor] = None,
  759. attention_mask: Optional[torch.FloatTensor] = None,
  760. token_type_ids: Optional[torch.LongTensor] = None,
  761. position_ids: Optional[torch.LongTensor] = None,
  762. head_mask: Optional[torch.FloatTensor] = None,
  763. inputs_embeds: Optional[torch.FloatTensor] = None,
  764. labels: Optional[torch.LongTensor] = None,
  765. output_attentions: Optional[bool] = None,
  766. output_hidden_states: Optional[bool] = None,
  767. return_dict: Optional[bool] = None,
  768. ) -> Union[SequenceClassifierOutput, tuple[torch.FloatTensor]]:
  769. r"""
  770. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  771. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  772. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  773. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  774. """
  775. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  776. outputs = self.ibert(
  777. input_ids,
  778. attention_mask=attention_mask,
  779. token_type_ids=token_type_ids,
  780. position_ids=position_ids,
  781. head_mask=head_mask,
  782. inputs_embeds=inputs_embeds,
  783. output_attentions=output_attentions,
  784. output_hidden_states=output_hidden_states,
  785. return_dict=return_dict,
  786. )
  787. sequence_output = outputs[0]
  788. logits = self.classifier(sequence_output)
  789. loss = None
  790. if labels is not None:
  791. if self.config.problem_type is None:
  792. if self.num_labels == 1:
  793. self.config.problem_type = "regression"
  794. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  795. self.config.problem_type = "single_label_classification"
  796. else:
  797. self.config.problem_type = "multi_label_classification"
  798. if self.config.problem_type == "regression":
  799. loss_fct = MSELoss()
  800. if self.num_labels == 1:
  801. loss = loss_fct(logits.squeeze(), labels.squeeze())
  802. else:
  803. loss = loss_fct(logits, labels)
  804. elif self.config.problem_type == "single_label_classification":
  805. loss_fct = CrossEntropyLoss()
  806. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  807. elif self.config.problem_type == "multi_label_classification":
  808. loss_fct = BCEWithLogitsLoss()
  809. loss = loss_fct(logits, labels)
  810. if not return_dict:
  811. output = (logits,) + outputs[2:]
  812. return ((loss,) + output) if loss is not None else output
  813. return SequenceClassifierOutput(
  814. loss=loss,
  815. logits=logits,
  816. hidden_states=outputs.hidden_states,
  817. attentions=outputs.attentions,
  818. )
  819. @auto_docstring
  820. class IBertForMultipleChoice(IBertPreTrainedModel):
  821. def __init__(self, config):
  822. super().__init__(config)
  823. self.ibert = IBertModel(config)
  824. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  825. self.classifier = nn.Linear(config.hidden_size, 1)
  826. # Initialize weights and apply final processing
  827. self.post_init()
  828. @auto_docstring
  829. def forward(
  830. self,
  831. input_ids: Optional[torch.LongTensor] = None,
  832. token_type_ids: Optional[torch.LongTensor] = None,
  833. attention_mask: Optional[torch.FloatTensor] = None,
  834. labels: Optional[torch.LongTensor] = None,
  835. position_ids: Optional[torch.LongTensor] = None,
  836. head_mask: Optional[torch.FloatTensor] = None,
  837. inputs_embeds: Optional[torch.FloatTensor] = None,
  838. output_attentions: Optional[bool] = None,
  839. output_hidden_states: Optional[bool] = None,
  840. return_dict: Optional[bool] = None,
  841. ) -> Union[MultipleChoiceModelOutput, tuple[torch.FloatTensor]]:
  842. r"""
  843. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  844. Indices of input sequence tokens in the vocabulary.
  845. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  846. [`PreTrainedTokenizer.__call__`] for details.
  847. [What are input IDs?](../glossary#input-ids)
  848. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  849. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  850. 1]`:
  851. - 0 corresponds to a *sentence A* token,
  852. - 1 corresponds to a *sentence B* token.
  853. [What are token type IDs?](../glossary#token-type-ids)
  854. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  855. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  856. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  857. `input_ids` above)
  858. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  859. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  860. config.max_position_embeddings - 1]`.
  861. [What are position IDs?](../glossary#position-ids)
  862. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  863. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  864. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  865. model's internal embedding lookup matrix.
  866. """
  867. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  868. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  869. flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  870. flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  871. flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  872. flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  873. flat_inputs_embeds = (
  874. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  875. if inputs_embeds is not None
  876. else None
  877. )
  878. outputs = self.ibert(
  879. flat_input_ids,
  880. position_ids=flat_position_ids,
  881. token_type_ids=flat_token_type_ids,
  882. attention_mask=flat_attention_mask,
  883. head_mask=head_mask,
  884. inputs_embeds=flat_inputs_embeds,
  885. output_attentions=output_attentions,
  886. output_hidden_states=output_hidden_states,
  887. return_dict=return_dict,
  888. )
  889. pooled_output = outputs[1]
  890. pooled_output = self.dropout(pooled_output)
  891. logits = self.classifier(pooled_output)
  892. reshaped_logits = logits.view(-1, num_choices)
  893. loss = None
  894. if labels is not None:
  895. loss_fct = CrossEntropyLoss()
  896. loss = loss_fct(reshaped_logits, labels)
  897. if not return_dict:
  898. output = (reshaped_logits,) + outputs[2:]
  899. return ((loss,) + output) if loss is not None else output
  900. return MultipleChoiceModelOutput(
  901. loss=loss,
  902. logits=reshaped_logits,
  903. hidden_states=outputs.hidden_states,
  904. attentions=outputs.attentions,
  905. )
  906. @auto_docstring
  907. class IBertForTokenClassification(IBertPreTrainedModel):
  908. def __init__(self, config):
  909. super().__init__(config)
  910. self.num_labels = config.num_labels
  911. self.ibert = IBertModel(config, add_pooling_layer=False)
  912. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  913. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  914. # Initialize weights and apply final processing
  915. self.post_init()
  916. @auto_docstring
  917. def forward(
  918. self,
  919. input_ids: Optional[torch.LongTensor] = None,
  920. attention_mask: Optional[torch.FloatTensor] = None,
  921. token_type_ids: Optional[torch.LongTensor] = None,
  922. position_ids: Optional[torch.LongTensor] = None,
  923. head_mask: Optional[torch.FloatTensor] = None,
  924. inputs_embeds: Optional[torch.FloatTensor] = None,
  925. labels: Optional[torch.LongTensor] = None,
  926. output_attentions: Optional[bool] = None,
  927. output_hidden_states: Optional[bool] = None,
  928. return_dict: Optional[bool] = None,
  929. ) -> Union[TokenClassifierOutput, tuple[torch.FloatTensor]]:
  930. r"""
  931. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  932. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  933. """
  934. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  935. outputs = self.ibert(
  936. input_ids,
  937. attention_mask=attention_mask,
  938. token_type_ids=token_type_ids,
  939. position_ids=position_ids,
  940. head_mask=head_mask,
  941. inputs_embeds=inputs_embeds,
  942. output_attentions=output_attentions,
  943. output_hidden_states=output_hidden_states,
  944. return_dict=return_dict,
  945. )
  946. sequence_output = outputs[0]
  947. sequence_output = self.dropout(sequence_output)
  948. logits = self.classifier(sequence_output)
  949. loss = None
  950. if labels is not None:
  951. loss_fct = CrossEntropyLoss()
  952. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  953. if not return_dict:
  954. output = (logits,) + outputs[2:]
  955. return ((loss,) + output) if loss is not None else output
  956. return TokenClassifierOutput(
  957. loss=loss,
  958. logits=logits,
  959. hidden_states=outputs.hidden_states,
  960. attentions=outputs.attentions,
  961. )
  962. class IBertClassificationHead(nn.Module):
  963. """Head for sentence-level classification tasks."""
  964. def __init__(self, config):
  965. super().__init__()
  966. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  967. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  968. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  969. def forward(self, features, **kwargs):
  970. hidden_states = features[:, 0, :] # take <s> token (equiv. to [CLS])
  971. hidden_states = self.dropout(hidden_states)
  972. hidden_states = self.dense(hidden_states)
  973. hidden_states = torch.tanh(hidden_states)
  974. hidden_states = self.dropout(hidden_states)
  975. hidden_states = self.out_proj(hidden_states)
  976. return hidden_states
  977. @auto_docstring
  978. class IBertForQuestionAnswering(IBertPreTrainedModel):
  979. def __init__(self, config):
  980. super().__init__(config)
  981. self.num_labels = config.num_labels
  982. self.ibert = IBertModel(config, add_pooling_layer=False)
  983. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  984. # Initialize weights and apply final processing
  985. self.post_init()
  986. @auto_docstring
  987. def forward(
  988. self,
  989. input_ids: Optional[torch.LongTensor] = None,
  990. attention_mask: Optional[torch.FloatTensor] = None,
  991. token_type_ids: Optional[torch.LongTensor] = None,
  992. position_ids: Optional[torch.LongTensor] = None,
  993. head_mask: Optional[torch.FloatTensor] = None,
  994. inputs_embeds: Optional[torch.FloatTensor] = None,
  995. start_positions: Optional[torch.LongTensor] = None,
  996. end_positions: Optional[torch.LongTensor] = None,
  997. output_attentions: Optional[bool] = None,
  998. output_hidden_states: Optional[bool] = None,
  999. return_dict: Optional[bool] = None,
  1000. ) -> Union[QuestionAnsweringModelOutput, tuple[torch.FloatTensor]]:
  1001. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1002. outputs = self.ibert(
  1003. input_ids,
  1004. attention_mask=attention_mask,
  1005. token_type_ids=token_type_ids,
  1006. position_ids=position_ids,
  1007. head_mask=head_mask,
  1008. inputs_embeds=inputs_embeds,
  1009. output_attentions=output_attentions,
  1010. output_hidden_states=output_hidden_states,
  1011. return_dict=return_dict,
  1012. )
  1013. sequence_output = outputs[0]
  1014. logits = self.qa_outputs(sequence_output)
  1015. start_logits, end_logits = logits.split(1, dim=-1)
  1016. start_logits = start_logits.squeeze(-1).contiguous()
  1017. end_logits = end_logits.squeeze(-1).contiguous()
  1018. total_loss = None
  1019. if start_positions is not None and end_positions is not None:
  1020. # If we are on multi-GPU, split add a dimension
  1021. if len(start_positions.size()) > 1:
  1022. start_positions = start_positions.squeeze(-1)
  1023. if len(end_positions.size()) > 1:
  1024. end_positions = end_positions.squeeze(-1)
  1025. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1026. ignored_index = start_logits.size(1)
  1027. start_positions = start_positions.clamp(0, ignored_index)
  1028. end_positions = end_positions.clamp(0, ignored_index)
  1029. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1030. start_loss = loss_fct(start_logits, start_positions)
  1031. end_loss = loss_fct(end_logits, end_positions)
  1032. total_loss = (start_loss + end_loss) / 2
  1033. if not return_dict:
  1034. output = (start_logits, end_logits) + outputs[2:]
  1035. return ((total_loss,) + output) if total_loss is not None else output
  1036. return QuestionAnsweringModelOutput(
  1037. loss=total_loss,
  1038. start_logits=start_logits,
  1039. end_logits=end_logits,
  1040. hidden_states=outputs.hidden_states,
  1041. attentions=outputs.attentions,
  1042. )
  1043. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  1044. """
  1045. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  1046. are ignored. This is modified from fairseq's *utils.make_positions*.
  1047. Args:
  1048. input_ids (`torch.LongTensor`):
  1049. Indices of input sequence tokens in the vocabulary.
  1050. Returns: torch.Tensor
  1051. """
  1052. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  1053. mask = input_ids.ne(padding_idx).int()
  1054. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  1055. return incremental_indices.long() + padding_idx
# Public names of this module: the list below is what `from <module> import *` exports.
__all__ = [
    "IBertForMaskedLM",
    "IBertForMultipleChoice",
    "IBertForQuestionAnswering",
    "IBertForSequenceClassification",
    "IBertForTokenClassification",
    "IBertModel",
    "IBertPreTrainedModel",
]