modeling_deberta.py 48 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204
  1. # coding=utf-8
  2. # Copyright 2020 Microsoft and the Hugging Face Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch DeBERTa model."""
  16. from typing import Optional, Union
  17. import torch
  18. from torch import nn
  19. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  20. from ...activations import ACT2FN
  21. from ...modeling_layers import GradientCheckpointingLayer
  22. from ...modeling_outputs import (
  23. BaseModelOutput,
  24. MaskedLMOutput,
  25. QuestionAnsweringModelOutput,
  26. SequenceClassifierOutput,
  27. TokenClassifierOutput,
  28. )
  29. from ...modeling_utils import PreTrainedModel
  30. from ...utils import auto_docstring, logging
  31. from .configuration_deberta import DebertaConfig
  32. logger = logging.get_logger(__name__)
  33. class DebertaLayerNorm(nn.Module):
  34. """LayerNorm module in the TF style (epsilon inside the square root)."""
  35. def __init__(self, size, eps=1e-12):
  36. super().__init__()
  37. self.weight = nn.Parameter(torch.ones(size))
  38. self.bias = nn.Parameter(torch.zeros(size))
  39. self.variance_epsilon = eps
  40. def forward(self, hidden_states):
  41. input_type = hidden_states.dtype
  42. hidden_states = hidden_states.float()
  43. mean = hidden_states.mean(-1, keepdim=True)
  44. variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
  45. hidden_states = (hidden_states - mean) / torch.sqrt(variance + self.variance_epsilon)
  46. hidden_states = hidden_states.to(input_type)
  47. y = self.weight * hidden_states + self.bias
  48. return y
  49. class DebertaSelfOutput(nn.Module):
  50. def __init__(self, config):
  51. super().__init__()
  52. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  53. self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
  54. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  55. def forward(self, hidden_states, input_tensor):
  56. hidden_states = self.dense(hidden_states)
  57. hidden_states = self.dropout(hidden_states)
  58. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  59. return hidden_states
  60. @torch.jit.script
  61. def build_relative_position(query_layer, key_layer):
  62. """
  63. Build relative position according to the query and key
  64. We assume the absolute position of query \\(P_q\\) is range from (0, query_size) and the absolute position of key
  65. \\(P_k\\) is range from (0, key_size), The relative positions from query to key is \\(R_{q \\rightarrow k} = P_q -
  66. P_k\\)
  67. Args:
  68. query_size (int): the length of query
  69. key_size (int): the length of key
  70. Return:
  71. `torch.LongTensor`: A tensor with shape [1, query_size, key_size]
  72. """
  73. query_size = query_layer.size(-2)
  74. key_size = key_layer.size(-2)
  75. q_ids = torch.arange(query_size, dtype=torch.long, device=query_layer.device)
  76. k_ids = torch.arange(key_size, dtype=torch.long, device=key_layer.device)
  77. rel_pos_ids = q_ids[:, None] - k_ids.view(1, -1).repeat(query_size, 1)
  78. rel_pos_ids = rel_pos_ids[:query_size, :]
  79. rel_pos_ids = rel_pos_ids.unsqueeze(0)
  80. return rel_pos_ids
  81. @torch.jit.script
  82. def c2p_dynamic_expand(c2p_pos, query_layer, relative_pos):
  83. return c2p_pos.expand([query_layer.size(0), query_layer.size(1), query_layer.size(2), relative_pos.size(-1)])
  84. @torch.jit.script
  85. def p2c_dynamic_expand(c2p_pos, query_layer, key_layer):
  86. return c2p_pos.expand([query_layer.size(0), query_layer.size(1), key_layer.size(-2), key_layer.size(-2)])
  87. @torch.jit.script
  88. def pos_dynamic_expand(pos_index, p2c_att, key_layer):
  89. return pos_index.expand(p2c_att.size()[:2] + (pos_index.size(-2), key_layer.size(-2)))
###### To support a general trace, the operations below have to be defined separately, because they use Python
# objects (sizes), which are not supported by torch.jit.trace.
# Full credits to @Szustarol
  93. @torch.jit.script
  94. def scaled_size_sqrt(query_layer: torch.Tensor, scale_factor: int):
  95. return torch.sqrt(torch.tensor(query_layer.size(-1), dtype=torch.float) * scale_factor)
  96. @torch.jit.script
  97. def build_rpos(query_layer: torch.Tensor, key_layer: torch.Tensor, relative_pos):
  98. if query_layer.size(-2) != key_layer.size(-2):
  99. return build_relative_position(query_layer, key_layer)
  100. else:
  101. return relative_pos
  102. @torch.jit.script
  103. def compute_attention_span(query_layer: torch.Tensor, key_layer: torch.Tensor, max_relative_positions: int):
  104. return torch.tensor(min(max(query_layer.size(-2), key_layer.size(-2)), max_relative_positions))
  105. @torch.jit.script
  106. def uneven_size_corrected(p2c_att, query_layer: torch.Tensor, key_layer: torch.Tensor, relative_pos):
  107. if query_layer.size(-2) != key_layer.size(-2):
  108. pos_index = relative_pos[:, :, :, 0].unsqueeze(-1)
  109. return torch.gather(p2c_att, dim=2, index=pos_dynamic_expand(pos_index, p2c_att, key_layer))
  110. else:
  111. return p2c_att
  112. ########################################################################################################################
  113. class DisentangledSelfAttention(nn.Module):
  114. """
  115. Disentangled self-attention module
  116. Parameters:
  117. config (`str`):
  118. A model config class instance with the configuration to build a new model. The schema is similar to
  119. *BertConfig*, for more details, please refer [`DebertaConfig`]
  120. """
  121. def __init__(self, config):
  122. super().__init__()
  123. if config.hidden_size % config.num_attention_heads != 0:
  124. raise ValueError(
  125. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  126. f"heads ({config.num_attention_heads})"
  127. )
  128. self.num_attention_heads = config.num_attention_heads
  129. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  130. self.all_head_size = self.num_attention_heads * self.attention_head_size
  131. self.in_proj = nn.Linear(config.hidden_size, self.all_head_size * 3, bias=False)
  132. self.q_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
  133. self.v_bias = nn.Parameter(torch.zeros((self.all_head_size), dtype=torch.float))
  134. self.pos_att_type = config.pos_att_type if config.pos_att_type is not None else []
  135. self.relative_attention = getattr(config, "relative_attention", False)
  136. self.talking_head = getattr(config, "talking_head", False)
  137. if self.talking_head:
  138. self.head_logits_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
  139. self.head_weights_proj = nn.Linear(config.num_attention_heads, config.num_attention_heads, bias=False)
  140. else:
  141. self.head_logits_proj = None
  142. self.head_weights_proj = None
  143. if self.relative_attention:
  144. self.max_relative_positions = getattr(config, "max_relative_positions", -1)
  145. if self.max_relative_positions < 1:
  146. self.max_relative_positions = config.max_position_embeddings
  147. self.pos_dropout = nn.Dropout(config.hidden_dropout_prob)
  148. if "c2p" in self.pos_att_type:
  149. self.pos_proj = nn.Linear(config.hidden_size, self.all_head_size, bias=False)
  150. if "p2c" in self.pos_att_type:
  151. self.pos_q_proj = nn.Linear(config.hidden_size, self.all_head_size)
  152. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  153. def transpose_for_scores(self, x):
  154. new_x_shape = x.size()[:-1] + (self.num_attention_heads, -1)
  155. x = x.view(new_x_shape)
  156. return x.permute(0, 2, 1, 3)
  157. def forward(
  158. self,
  159. hidden_states: torch.Tensor,
  160. attention_mask: torch.Tensor,
  161. output_attentions: bool = False,
  162. query_states: Optional[torch.Tensor] = None,
  163. relative_pos: Optional[torch.Tensor] = None,
  164. rel_embeddings: Optional[torch.Tensor] = None,
  165. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  166. """
  167. Call the module
  168. Args:
  169. hidden_states (`torch.FloatTensor`):
  170. Input states to the module usually the output from previous layer, it will be the Q,K and V in
  171. *Attention(Q,K,V)*
  172. attention_mask (`torch.BoolTensor`):
  173. An attention mask matrix of shape [*B*, *N*, *N*] where *B* is the batch size, *N* is the maximum
  174. sequence length in which element [i,j] = *1* means the *i* th token in the input can attend to the *j*
  175. th token.
  176. output_attentions (`bool`, *optional*):
  177. Whether return the attention matrix.
  178. query_states (`torch.FloatTensor`, *optional*):
  179. The *Q* state in *Attention(Q,K,V)*.
  180. relative_pos (`torch.LongTensor`):
  181. The relative position encoding between the tokens in the sequence. It's of shape [*B*, *N*, *N*] with
  182. values ranging in [*-max_relative_positions*, *max_relative_positions*].
  183. rel_embeddings (`torch.FloatTensor`):
  184. The embedding of relative distances. It's a tensor of shape [\\(2 \\times
  185. \\text{max_relative_positions}\\), *hidden_size*].
  186. """
  187. if query_states is None:
  188. qp = self.in_proj(hidden_states) # .split(self.all_head_size, dim=-1)
  189. query_layer, key_layer, value_layer = self.transpose_for_scores(qp).chunk(3, dim=-1)
  190. else:
  191. ws = self.in_proj.weight.chunk(self.num_attention_heads * 3, dim=0)
  192. qkvw = [torch.cat([ws[i * 3 + k] for i in range(self.num_attention_heads)], dim=0) for k in range(3)]
  193. q = torch.matmul(qkvw[0], query_states.t().to(dtype=qkvw[0].dtype))
  194. k = torch.matmul(qkvw[1], hidden_states.t().to(dtype=qkvw[1].dtype))
  195. v = torch.matmul(qkvw[2], hidden_states.t().to(dtype=qkvw[2].dtype))
  196. query_layer, key_layer, value_layer = [self.transpose_for_scores(x) for x in [q, k, v]]
  197. query_layer = query_layer + self.transpose_for_scores(self.q_bias[None, None, :])
  198. value_layer = value_layer + self.transpose_for_scores(self.v_bias[None, None, :])
  199. rel_att: int = 0
  200. # Take the dot product between "query" and "key" to get the raw attention scores.
  201. scale_factor = 1 + len(self.pos_att_type)
  202. scale = scaled_size_sqrt(query_layer, scale_factor)
  203. query_layer = query_layer / scale.to(dtype=query_layer.dtype)
  204. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  205. if self.relative_attention and rel_embeddings is not None and relative_pos is not None:
  206. rel_embeddings = self.pos_dropout(rel_embeddings)
  207. rel_att = self.disentangled_att_bias(query_layer, key_layer, relative_pos, rel_embeddings, scale_factor)
  208. if rel_att is not None:
  209. attention_scores = attention_scores + rel_att
  210. # bxhxlxd
  211. if self.head_logits_proj is not None:
  212. attention_scores = self.head_logits_proj(attention_scores.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  213. attention_mask = attention_mask.bool()
  214. attention_scores = attention_scores.masked_fill(~(attention_mask), torch.finfo(query_layer.dtype).min)
  215. # bsz x height x length x dimension
  216. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  217. attention_probs = self.dropout(attention_probs)
  218. if self.head_weights_proj is not None:
  219. attention_probs = self.head_weights_proj(attention_probs.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
  220. context_layer = torch.matmul(attention_probs, value_layer)
  221. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  222. new_context_layer_shape = context_layer.size()[:-2] + (-1,)
  223. context_layer = context_layer.view(new_context_layer_shape)
  224. if not output_attentions:
  225. return (context_layer, None)
  226. return (context_layer, attention_probs)
  227. def disentangled_att_bias(
  228. self,
  229. query_layer: torch.Tensor,
  230. key_layer: torch.Tensor,
  231. relative_pos: torch.Tensor,
  232. rel_embeddings: torch.Tensor,
  233. scale_factor: int,
  234. ):
  235. if relative_pos is None:
  236. relative_pos = build_relative_position(query_layer, key_layer, query_layer.device)
  237. if relative_pos.dim() == 2:
  238. relative_pos = relative_pos.unsqueeze(0).unsqueeze(0)
  239. elif relative_pos.dim() == 3:
  240. relative_pos = relative_pos.unsqueeze(1)
  241. # bxhxqxk
  242. elif relative_pos.dim() != 4:
  243. raise ValueError(f"Relative position ids must be of dim 2 or 3 or 4. {relative_pos.dim()}")
  244. att_span = compute_attention_span(query_layer, key_layer, self.max_relative_positions)
  245. relative_pos = relative_pos.long()
  246. rel_embeddings = rel_embeddings[
  247. self.max_relative_positions - att_span : self.max_relative_positions + att_span, :
  248. ].unsqueeze(0)
  249. score = 0
  250. # content->position
  251. if "c2p" in self.pos_att_type:
  252. pos_key_layer = self.pos_proj(rel_embeddings)
  253. pos_key_layer = self.transpose_for_scores(pos_key_layer)
  254. c2p_att = torch.matmul(query_layer, pos_key_layer.transpose(-1, -2))
  255. c2p_pos = torch.clamp(relative_pos + att_span, 0, att_span * 2 - 1)
  256. c2p_att = torch.gather(c2p_att, dim=-1, index=c2p_dynamic_expand(c2p_pos, query_layer, relative_pos))
  257. score += c2p_att
  258. # position->content
  259. if "p2c" in self.pos_att_type:
  260. pos_query_layer = self.pos_q_proj(rel_embeddings)
  261. pos_query_layer = self.transpose_for_scores(pos_query_layer)
  262. pos_query_layer /= scaled_size_sqrt(pos_query_layer, scale_factor)
  263. r_pos = build_rpos(
  264. query_layer,
  265. key_layer,
  266. relative_pos,
  267. )
  268. p2c_pos = torch.clamp(-r_pos + att_span, 0, att_span * 2 - 1)
  269. p2c_att = torch.matmul(key_layer, pos_query_layer.transpose(-1, -2).to(dtype=key_layer.dtype))
  270. p2c_att = torch.gather(
  271. p2c_att, dim=-1, index=p2c_dynamic_expand(p2c_pos, query_layer, key_layer)
  272. ).transpose(-1, -2)
  273. p2c_att = uneven_size_corrected(p2c_att, query_layer, key_layer, relative_pos)
  274. score += p2c_att
  275. return score
  276. class DebertaEmbeddings(nn.Module):
  277. """Construct the embeddings from word, position and token_type embeddings."""
  278. def __init__(self, config):
  279. super().__init__()
  280. pad_token_id = getattr(config, "pad_token_id", 0)
  281. self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
  282. self.word_embeddings = nn.Embedding(config.vocab_size, self.embedding_size, padding_idx=pad_token_id)
  283. self.position_biased_input = getattr(config, "position_biased_input", True)
  284. if not self.position_biased_input:
  285. self.position_embeddings = None
  286. else:
  287. self.position_embeddings = nn.Embedding(config.max_position_embeddings, self.embedding_size)
  288. if config.type_vocab_size > 0:
  289. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, self.embedding_size)
  290. else:
  291. self.token_type_embeddings = None
  292. if self.embedding_size != config.hidden_size:
  293. self.embed_proj = nn.Linear(self.embedding_size, config.hidden_size, bias=False)
  294. else:
  295. self.embed_proj = None
  296. self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
  297. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  298. self.config = config
  299. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  300. self.register_buffer(
  301. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  302. )
  303. def forward(self, input_ids=None, token_type_ids=None, position_ids=None, mask=None, inputs_embeds=None):
  304. if input_ids is not None:
  305. input_shape = input_ids.size()
  306. else:
  307. input_shape = inputs_embeds.size()[:-1]
  308. seq_length = input_shape[1]
  309. if position_ids is None:
  310. position_ids = self.position_ids[:, :seq_length]
  311. if token_type_ids is None:
  312. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  313. if inputs_embeds is None:
  314. inputs_embeds = self.word_embeddings(input_ids)
  315. if self.position_embeddings is not None:
  316. position_embeddings = self.position_embeddings(position_ids.long())
  317. else:
  318. position_embeddings = torch.zeros_like(inputs_embeds)
  319. embeddings = inputs_embeds
  320. if self.position_biased_input:
  321. embeddings = embeddings + position_embeddings
  322. if self.token_type_embeddings is not None:
  323. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  324. embeddings = embeddings + token_type_embeddings
  325. if self.embed_proj is not None:
  326. embeddings = self.embed_proj(embeddings)
  327. embeddings = self.LayerNorm(embeddings)
  328. if mask is not None:
  329. if mask.dim() != embeddings.dim():
  330. if mask.dim() == 4:
  331. mask = mask.squeeze(1).squeeze(1)
  332. mask = mask.unsqueeze(2)
  333. mask = mask.to(embeddings.dtype)
  334. embeddings = embeddings * mask
  335. embeddings = self.dropout(embeddings)
  336. return embeddings
  337. class DebertaAttention(nn.Module):
  338. def __init__(self, config):
  339. super().__init__()
  340. self.self = DisentangledSelfAttention(config)
  341. self.output = DebertaSelfOutput(config)
  342. self.config = config
  343. def forward(
  344. self,
  345. hidden_states,
  346. attention_mask,
  347. output_attentions: bool = False,
  348. query_states=None,
  349. relative_pos=None,
  350. rel_embeddings=None,
  351. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  352. self_output, att_matrix = self.self(
  353. hidden_states,
  354. attention_mask,
  355. output_attentions,
  356. query_states=query_states,
  357. relative_pos=relative_pos,
  358. rel_embeddings=rel_embeddings,
  359. )
  360. if query_states is None:
  361. query_states = hidden_states
  362. attention_output = self.output(self_output, query_states)
  363. if output_attentions:
  364. return (attention_output, att_matrix)
  365. else:
  366. return (attention_output, None)
  367. # Copied from transformers.models.bert.modeling_bert.BertIntermediate with Bert->Deberta
  368. class DebertaIntermediate(nn.Module):
  369. def __init__(self, config):
  370. super().__init__()
  371. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  372. if isinstance(config.hidden_act, str):
  373. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  374. else:
  375. self.intermediate_act_fn = config.hidden_act
  376. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  377. hidden_states = self.dense(hidden_states)
  378. hidden_states = self.intermediate_act_fn(hidden_states)
  379. return hidden_states
  380. class DebertaOutput(nn.Module):
  381. def __init__(self, config):
  382. super().__init__()
  383. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  384. self.LayerNorm = DebertaLayerNorm(config.hidden_size, config.layer_norm_eps)
  385. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  386. self.config = config
  387. def forward(self, hidden_states, input_tensor):
  388. hidden_states = self.dense(hidden_states)
  389. hidden_states = self.dropout(hidden_states)
  390. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  391. return hidden_states
  392. class DebertaLayer(GradientCheckpointingLayer):
  393. def __init__(self, config):
  394. super().__init__()
  395. self.attention = DebertaAttention(config)
  396. self.intermediate = DebertaIntermediate(config)
  397. self.output = DebertaOutput(config)
  398. def forward(
  399. self,
  400. hidden_states,
  401. attention_mask,
  402. query_states=None,
  403. relative_pos=None,
  404. rel_embeddings=None,
  405. output_attentions: bool = False,
  406. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  407. attention_output, att_matrix = self.attention(
  408. hidden_states,
  409. attention_mask,
  410. output_attentions=output_attentions,
  411. query_states=query_states,
  412. relative_pos=relative_pos,
  413. rel_embeddings=rel_embeddings,
  414. )
  415. intermediate_output = self.intermediate(attention_output)
  416. layer_output = self.output(intermediate_output, attention_output)
  417. if output_attentions:
  418. return (layer_output, att_matrix)
  419. else:
  420. return (layer_output, None)
  421. class DebertaEncoder(nn.Module):
  422. """Modified BertEncoder with relative position bias support"""
  423. def __init__(self, config):
  424. super().__init__()
  425. self.layer = nn.ModuleList([DebertaLayer(config) for _ in range(config.num_hidden_layers)])
  426. self.relative_attention = getattr(config, "relative_attention", False)
  427. if self.relative_attention:
  428. self.max_relative_positions = getattr(config, "max_relative_positions", -1)
  429. if self.max_relative_positions < 1:
  430. self.max_relative_positions = config.max_position_embeddings
  431. self.rel_embeddings = nn.Embedding(self.max_relative_positions * 2, config.hidden_size)
  432. self.gradient_checkpointing = False
  433. def get_rel_embedding(self):
  434. rel_embeddings = self.rel_embeddings.weight if self.relative_attention else None
  435. return rel_embeddings
  436. def get_attention_mask(self, attention_mask):
  437. if attention_mask.dim() <= 2:
  438. extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
  439. attention_mask = extended_attention_mask * extended_attention_mask.squeeze(-2).unsqueeze(-1)
  440. elif attention_mask.dim() == 3:
  441. attention_mask = attention_mask.unsqueeze(1)
  442. return attention_mask
  443. def get_rel_pos(self, hidden_states, query_states=None, relative_pos=None):
  444. if self.relative_attention and relative_pos is None:
  445. if query_states is not None:
  446. relative_pos = build_relative_position(query_states, hidden_states)
  447. else:
  448. relative_pos = build_relative_position(hidden_states, hidden_states)
  449. return relative_pos
  450. def forward(
  451. self,
  452. hidden_states: torch.Tensor,
  453. attention_mask: torch.Tensor,
  454. output_hidden_states: bool = True,
  455. output_attentions: bool = False,
  456. query_states=None,
  457. relative_pos=None,
  458. return_dict: bool = True,
  459. ):
  460. attention_mask = self.get_attention_mask(attention_mask)
  461. relative_pos = self.get_rel_pos(hidden_states, query_states, relative_pos)
  462. all_hidden_states: Optional[tuple[torch.Tensor]] = (hidden_states,) if output_hidden_states else None
  463. all_attentions = () if output_attentions else None
  464. next_kv = hidden_states
  465. rel_embeddings = self.get_rel_embedding()
  466. for i, layer_module in enumerate(self.layer):
  467. hidden_states, att_m = layer_module(
  468. next_kv,
  469. attention_mask,
  470. query_states=query_states,
  471. relative_pos=relative_pos,
  472. rel_embeddings=rel_embeddings,
  473. output_attentions=output_attentions,
  474. )
  475. if output_hidden_states:
  476. all_hidden_states = all_hidden_states + (hidden_states,)
  477. if query_states is not None:
  478. query_states = hidden_states
  479. else:
  480. next_kv = hidden_states
  481. if output_attentions:
  482. all_attentions = all_attentions + (att_m,)
  483. if not return_dict:
  484. return tuple(v for v in [hidden_states, all_hidden_states, all_attentions] if v is not None)
  485. return BaseModelOutput(
  486. last_hidden_state=hidden_states, hidden_states=all_hidden_states, attentions=all_attentions
  487. )
  488. @auto_docstring
  489. class DebertaPreTrainedModel(PreTrainedModel):
  490. config: DebertaConfig
  491. base_model_prefix = "deberta"
  492. _keys_to_ignore_on_load_unexpected = ["position_embeddings"]
  493. supports_gradient_checkpointing = True
  494. def _init_weights(self, module):
  495. """Initialize the weights."""
  496. if isinstance(module, nn.Linear):
  497. # Slightly different from the TF version which uses truncated_normal for initialization
  498. # cf https://github.com/pytorch/pytorch/pull/5617
  499. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  500. if module.bias is not None:
  501. module.bias.data.zero_()
  502. elif isinstance(module, nn.Embedding):
  503. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  504. if module.padding_idx is not None:
  505. module.weight.data[module.padding_idx].zero_()
  506. elif isinstance(module, (nn.LayerNorm, DebertaLayerNorm)):
  507. module.weight.data.fill_(1.0)
  508. module.bias.data.zero_()
  509. elif isinstance(module, DisentangledSelfAttention):
  510. module.q_bias.data.zero_()
  511. module.v_bias.data.zero_()
  512. elif isinstance(module, (LegacyDebertaLMPredictionHead, DebertaLMPredictionHead)):
  513. module.bias.data.zero_()
  514. @auto_docstring
  515. class DebertaModel(DebertaPreTrainedModel):
  516. def __init__(self, config):
  517. super().__init__(config)
  518. self.embeddings = DebertaEmbeddings(config)
  519. self.encoder = DebertaEncoder(config)
  520. self.z_steps = 0
  521. self.config = config
  522. # Initialize weights and apply final processing
  523. self.post_init()
  524. def get_input_embeddings(self):
  525. return self.embeddings.word_embeddings
  526. def set_input_embeddings(self, new_embeddings):
  527. self.embeddings.word_embeddings = new_embeddings
  528. def _prune_heads(self, heads_to_prune):
  529. """
  530. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  531. class PreTrainedModel
  532. """
  533. raise NotImplementedError("The prune function is not implemented in DeBERTa model.")
  534. @auto_docstring
  535. def forward(
  536. self,
  537. input_ids: Optional[torch.Tensor] = None,
  538. attention_mask: Optional[torch.Tensor] = None,
  539. token_type_ids: Optional[torch.Tensor] = None,
  540. position_ids: Optional[torch.Tensor] = None,
  541. inputs_embeds: Optional[torch.Tensor] = None,
  542. output_attentions: Optional[bool] = None,
  543. output_hidden_states: Optional[bool] = None,
  544. return_dict: Optional[bool] = None,
  545. ) -> Union[tuple, BaseModelOutput]:
  546. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  547. output_hidden_states = (
  548. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  549. )
  550. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  551. if input_ids is not None and inputs_embeds is not None:
  552. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  553. elif input_ids is not None:
  554. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  555. input_shape = input_ids.size()
  556. elif inputs_embeds is not None:
  557. input_shape = inputs_embeds.size()[:-1]
  558. else:
  559. raise ValueError("You have to specify either input_ids or inputs_embeds")
  560. device = input_ids.device if input_ids is not None else inputs_embeds.device
  561. if attention_mask is None:
  562. attention_mask = torch.ones(input_shape, device=device)
  563. if token_type_ids is None:
  564. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  565. embedding_output = self.embeddings(
  566. input_ids=input_ids,
  567. token_type_ids=token_type_ids,
  568. position_ids=position_ids,
  569. mask=attention_mask,
  570. inputs_embeds=inputs_embeds,
  571. )
  572. encoder_outputs = self.encoder(
  573. embedding_output,
  574. attention_mask,
  575. output_hidden_states=True,
  576. output_attentions=output_attentions,
  577. return_dict=return_dict,
  578. )
  579. encoded_layers = encoder_outputs[1]
  580. if self.z_steps > 1:
  581. hidden_states = encoded_layers[-2]
  582. layers = [self.encoder.layer[-1] for _ in range(self.z_steps)]
  583. query_states = encoded_layers[-1]
  584. rel_embeddings = self.encoder.get_rel_embedding()
  585. attention_mask = self.encoder.get_attention_mask(attention_mask)
  586. rel_pos = self.encoder.get_rel_pos(embedding_output)
  587. for layer in layers[1:]:
  588. query_states = layer(
  589. hidden_states,
  590. attention_mask,
  591. output_attentions=False,
  592. query_states=query_states,
  593. relative_pos=rel_pos,
  594. rel_embeddings=rel_embeddings,
  595. )
  596. encoded_layers.append(query_states)
  597. sequence_output = encoded_layers[-1]
  598. if not return_dict:
  599. return (sequence_output,) + encoder_outputs[(1 if output_hidden_states else 2) :]
  600. return BaseModelOutput(
  601. last_hidden_state=sequence_output,
  602. hidden_states=encoder_outputs.hidden_states if output_hidden_states else None,
  603. attentions=encoder_outputs.attentions,
  604. )
  605. class LegacyDebertaPredictionHeadTransform(nn.Module):
  606. def __init__(self, config):
  607. super().__init__()
  608. self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
  609. self.dense = nn.Linear(config.hidden_size, self.embedding_size)
  610. if isinstance(config.hidden_act, str):
  611. self.transform_act_fn = ACT2FN[config.hidden_act]
  612. else:
  613. self.transform_act_fn = config.hidden_act
  614. self.LayerNorm = nn.LayerNorm(self.embedding_size, eps=config.layer_norm_eps)
  615. def forward(self, hidden_states):
  616. hidden_states = self.dense(hidden_states)
  617. hidden_states = self.transform_act_fn(hidden_states)
  618. hidden_states = self.LayerNorm(hidden_states)
  619. return hidden_states
  620. class LegacyDebertaLMPredictionHead(nn.Module):
  621. def __init__(self, config):
  622. super().__init__()
  623. self.transform = LegacyDebertaPredictionHeadTransform(config)
  624. self.embedding_size = getattr(config, "embedding_size", config.hidden_size)
  625. # The output weights are the same as the input embeddings, but there is
  626. # an output-only bias for each token.
  627. self.decoder = nn.Linear(self.embedding_size, config.vocab_size, bias=False)
  628. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  629. # Need a link between the two variables so that the bias is correctly resized with `resize_token_embeddings`
  630. self.decoder.bias = self.bias
  631. def _tie_weights(self):
  632. self.decoder.bias = self.bias
  633. def forward(self, hidden_states):
  634. hidden_states = self.transform(hidden_states)
  635. hidden_states = self.decoder(hidden_states)
  636. return hidden_states
# NOTE(review): the `# Copied from` marker below ties this class to `BertOnlyMLMHead` (with `Bert` renamed to
# `LegacyDeberta`). Presumably repository tooling checks that the copy stays identical to its source, so the
# class body is intentionally left untouched -- confirm before editing it here.

# Copied from transformers.models.bert.modeling_bert.BertOnlyMLMHead with Bert->LegacyDeberta
class LegacyDebertaOnlyMLMHead(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.predictions = LegacyDebertaLMPredictionHead(config)

    def forward(self, sequence_output: torch.Tensor) -> torch.Tensor:
        prediction_scores = self.predictions(sequence_output)
        return prediction_scores
  645. class DebertaLMPredictionHead(nn.Module):
  646. """https://github.com/microsoft/DeBERTa/blob/master/DeBERTa/deberta/bert.py#L270"""
  647. def __init__(self, config):
  648. super().__init__()
  649. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  650. if isinstance(config.hidden_act, str):
  651. self.transform_act_fn = ACT2FN[config.hidden_act]
  652. else:
  653. self.transform_act_fn = config.hidden_act
  654. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps, elementwise_affine=True)
  655. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  656. # note that the input embeddings must be passed as an argument
  657. def forward(self, hidden_states, word_embeddings):
  658. hidden_states = self.dense(hidden_states)
  659. hidden_states = self.transform_act_fn(hidden_states)
  660. hidden_states = self.LayerNorm(
  661. hidden_states
  662. ) # original used MaskedLayerNorm, but passed no mask. This is equivalent.
  663. hidden_states = torch.matmul(hidden_states, word_embeddings.weight.t()) + self.bias
  664. return hidden_states
  665. class DebertaOnlyMLMHead(nn.Module):
  666. def __init__(self, config):
  667. super().__init__()
  668. self.lm_head = DebertaLMPredictionHead(config)
  669. # note that the input embeddings must be passed as an argument
  670. def forward(self, sequence_output, word_embeddings):
  671. prediction_scores = self.lm_head(sequence_output, word_embeddings)
  672. return prediction_scores
  673. @auto_docstring
  674. class DebertaForMaskedLM(DebertaPreTrainedModel):
  675. _tied_weights_keys = ["cls.predictions.decoder.weight", "cls.predictions.decoder.bias"]
  676. def __init__(self, config):
  677. super().__init__(config)
  678. self.legacy = config.legacy
  679. self.deberta = DebertaModel(config)
  680. if self.legacy:
  681. self.cls = LegacyDebertaOnlyMLMHead(config)
  682. else:
  683. self._tied_weights_keys = ["lm_predictions.lm_head.weight", "deberta.embeddings.word_embeddings.weight"]
  684. self.lm_predictions = DebertaOnlyMLMHead(config)
  685. # Initialize weights and apply final processing
  686. self.post_init()
  687. def get_output_embeddings(self):
  688. if self.legacy:
  689. return self.cls.predictions.decoder
  690. else:
  691. return self.lm_predictions.lm_head.dense
  692. def set_output_embeddings(self, new_embeddings):
  693. if self.legacy:
  694. self.cls.predictions.decoder = new_embeddings
  695. self.cls.predictions.bias = new_embeddings.bias
  696. else:
  697. self.lm_predictions.lm_head.dense = new_embeddings
  698. self.lm_predictions.lm_head.bias = new_embeddings.bias
  699. @auto_docstring
  700. def forward(
  701. self,
  702. input_ids: Optional[torch.Tensor] = None,
  703. attention_mask: Optional[torch.Tensor] = None,
  704. token_type_ids: Optional[torch.Tensor] = None,
  705. position_ids: Optional[torch.Tensor] = None,
  706. inputs_embeds: Optional[torch.Tensor] = None,
  707. labels: Optional[torch.Tensor] = None,
  708. output_attentions: Optional[bool] = None,
  709. output_hidden_states: Optional[bool] = None,
  710. return_dict: Optional[bool] = None,
  711. ) -> Union[tuple, MaskedLMOutput]:
  712. r"""
  713. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  714. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  715. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  716. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  717. """
  718. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  719. outputs = self.deberta(
  720. input_ids,
  721. attention_mask=attention_mask,
  722. token_type_ids=token_type_ids,
  723. position_ids=position_ids,
  724. inputs_embeds=inputs_embeds,
  725. output_attentions=output_attentions,
  726. output_hidden_states=output_hidden_states,
  727. return_dict=return_dict,
  728. )
  729. sequence_output = outputs[0]
  730. if self.legacy:
  731. prediction_scores = self.cls(sequence_output)
  732. else:
  733. prediction_scores = self.lm_predictions(sequence_output, self.deberta.embeddings.word_embeddings)
  734. masked_lm_loss = None
  735. if labels is not None:
  736. loss_fct = CrossEntropyLoss() # -100 index = padding token
  737. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  738. if not return_dict:
  739. output = (prediction_scores,) + outputs[1:]
  740. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  741. return MaskedLMOutput(
  742. loss=masked_lm_loss,
  743. logits=prediction_scores,
  744. hidden_states=outputs.hidden_states,
  745. attentions=outputs.attentions,
  746. )
  747. class ContextPooler(nn.Module):
  748. def __init__(self, config):
  749. super().__init__()
  750. self.dense = nn.Linear(config.pooler_hidden_size, config.pooler_hidden_size)
  751. self.dropout = nn.Dropout(config.pooler_dropout)
  752. self.config = config
  753. def forward(self, hidden_states):
  754. # We "pool" the model by simply taking the hidden state corresponding
  755. # to the first token.
  756. context_token = hidden_states[:, 0]
  757. context_token = self.dropout(context_token)
  758. pooled_output = self.dense(context_token)
  759. pooled_output = ACT2FN[self.config.pooler_hidden_act](pooled_output)
  760. return pooled_output
  761. @property
  762. def output_dim(self):
  763. return self.config.hidden_size
  764. @auto_docstring(
  765. custom_intro="""
  766. DeBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  767. pooled output) e.g. for GLUE tasks.
  768. """
  769. )
  770. class DebertaForSequenceClassification(DebertaPreTrainedModel):
  771. def __init__(self, config):
  772. super().__init__(config)
  773. num_labels = getattr(config, "num_labels", 2)
  774. self.num_labels = num_labels
  775. self.deberta = DebertaModel(config)
  776. self.pooler = ContextPooler(config)
  777. output_dim = self.pooler.output_dim
  778. self.classifier = nn.Linear(output_dim, num_labels)
  779. drop_out = getattr(config, "cls_dropout", None)
  780. drop_out = self.config.hidden_dropout_prob if drop_out is None else drop_out
  781. self.dropout = nn.Dropout(drop_out)
  782. # Initialize weights and apply final processing
  783. self.post_init()
  784. def get_input_embeddings(self):
  785. return self.deberta.get_input_embeddings()
  786. def set_input_embeddings(self, new_embeddings):
  787. self.deberta.set_input_embeddings(new_embeddings)
  788. @auto_docstring
  789. def forward(
  790. self,
  791. input_ids: Optional[torch.Tensor] = None,
  792. attention_mask: Optional[torch.Tensor] = None,
  793. token_type_ids: Optional[torch.Tensor] = None,
  794. position_ids: Optional[torch.Tensor] = None,
  795. inputs_embeds: Optional[torch.Tensor] = None,
  796. labels: Optional[torch.Tensor] = None,
  797. output_attentions: Optional[bool] = None,
  798. output_hidden_states: Optional[bool] = None,
  799. return_dict: Optional[bool] = None,
  800. ) -> Union[tuple, SequenceClassifierOutput]:
  801. r"""
  802. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  803. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  804. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  805. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  806. """
  807. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  808. outputs = self.deberta(
  809. input_ids,
  810. token_type_ids=token_type_ids,
  811. attention_mask=attention_mask,
  812. position_ids=position_ids,
  813. inputs_embeds=inputs_embeds,
  814. output_attentions=output_attentions,
  815. output_hidden_states=output_hidden_states,
  816. return_dict=return_dict,
  817. )
  818. encoder_layer = outputs[0]
  819. pooled_output = self.pooler(encoder_layer)
  820. pooled_output = self.dropout(pooled_output)
  821. logits = self.classifier(pooled_output)
  822. loss = None
  823. if labels is not None:
  824. if self.config.problem_type is None:
  825. if self.num_labels == 1:
  826. # regression task
  827. loss_fn = nn.MSELoss()
  828. logits = logits.view(-1).to(labels.dtype)
  829. loss = loss_fn(logits, labels.view(-1))
  830. elif labels.dim() == 1 or labels.size(-1) == 1:
  831. label_index = (labels >= 0).nonzero()
  832. labels = labels.long()
  833. if label_index.size(0) > 0:
  834. labeled_logits = torch.gather(
  835. logits, 0, label_index.expand(label_index.size(0), logits.size(1))
  836. )
  837. labels = torch.gather(labels, 0, label_index.view(-1))
  838. loss_fct = CrossEntropyLoss()
  839. loss = loss_fct(labeled_logits.view(-1, self.num_labels).float(), labels.view(-1))
  840. else:
  841. loss = torch.tensor(0).to(logits)
  842. else:
  843. log_softmax = nn.LogSoftmax(-1)
  844. loss = -((log_softmax(logits) * labels).sum(-1)).mean()
  845. elif self.config.problem_type == "regression":
  846. loss_fct = MSELoss()
  847. if self.num_labels == 1:
  848. loss = loss_fct(logits.squeeze(), labels.squeeze())
  849. else:
  850. loss = loss_fct(logits, labels)
  851. elif self.config.problem_type == "single_label_classification":
  852. loss_fct = CrossEntropyLoss()
  853. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  854. elif self.config.problem_type == "multi_label_classification":
  855. loss_fct = BCEWithLogitsLoss()
  856. loss = loss_fct(logits, labels)
  857. if not return_dict:
  858. output = (logits,) + outputs[1:]
  859. return ((loss,) + output) if loss is not None else output
  860. return SequenceClassifierOutput(
  861. loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  862. )
  863. @auto_docstring
  864. class DebertaForTokenClassification(DebertaPreTrainedModel):
  865. def __init__(self, config):
  866. super().__init__(config)
  867. self.num_labels = config.num_labels
  868. self.deberta = DebertaModel(config)
  869. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  870. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  871. # Initialize weights and apply final processing
  872. self.post_init()
  873. @auto_docstring
  874. def forward(
  875. self,
  876. input_ids: Optional[torch.Tensor] = None,
  877. attention_mask: Optional[torch.Tensor] = None,
  878. token_type_ids: Optional[torch.Tensor] = None,
  879. position_ids: Optional[torch.Tensor] = None,
  880. inputs_embeds: Optional[torch.Tensor] = None,
  881. labels: Optional[torch.Tensor] = None,
  882. output_attentions: Optional[bool] = None,
  883. output_hidden_states: Optional[bool] = None,
  884. return_dict: Optional[bool] = None,
  885. ) -> Union[tuple, TokenClassifierOutput]:
  886. r"""
  887. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  888. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  889. """
  890. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  891. outputs = self.deberta(
  892. input_ids,
  893. attention_mask=attention_mask,
  894. token_type_ids=token_type_ids,
  895. position_ids=position_ids,
  896. inputs_embeds=inputs_embeds,
  897. output_attentions=output_attentions,
  898. output_hidden_states=output_hidden_states,
  899. return_dict=return_dict,
  900. )
  901. sequence_output = outputs[0]
  902. sequence_output = self.dropout(sequence_output)
  903. logits = self.classifier(sequence_output)
  904. loss = None
  905. if labels is not None:
  906. loss_fct = CrossEntropyLoss()
  907. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  908. if not return_dict:
  909. output = (logits,) + outputs[1:]
  910. return ((loss,) + output) if loss is not None else output
  911. return TokenClassifierOutput(
  912. loss=loss, logits=logits, hidden_states=outputs.hidden_states, attentions=outputs.attentions
  913. )
  914. @auto_docstring
  915. class DebertaForQuestionAnswering(DebertaPreTrainedModel):
  916. def __init__(self, config):
  917. super().__init__(config)
  918. self.num_labels = config.num_labels
  919. self.deberta = DebertaModel(config)
  920. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  921. # Initialize weights and apply final processing
  922. self.post_init()
  923. @auto_docstring
  924. def forward(
  925. self,
  926. input_ids: Optional[torch.Tensor] = None,
  927. attention_mask: Optional[torch.Tensor] = None,
  928. token_type_ids: Optional[torch.Tensor] = None,
  929. position_ids: Optional[torch.Tensor] = None,
  930. inputs_embeds: Optional[torch.Tensor] = None,
  931. start_positions: Optional[torch.Tensor] = None,
  932. end_positions: Optional[torch.Tensor] = None,
  933. output_attentions: Optional[bool] = None,
  934. output_hidden_states: Optional[bool] = None,
  935. return_dict: Optional[bool] = None,
  936. ) -> Union[tuple, QuestionAnsweringModelOutput]:
  937. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  938. outputs = self.deberta(
  939. input_ids,
  940. attention_mask=attention_mask,
  941. token_type_ids=token_type_ids,
  942. position_ids=position_ids,
  943. inputs_embeds=inputs_embeds,
  944. output_attentions=output_attentions,
  945. output_hidden_states=output_hidden_states,
  946. return_dict=return_dict,
  947. )
  948. sequence_output = outputs[0]
  949. logits = self.qa_outputs(sequence_output)
  950. start_logits, end_logits = logits.split(1, dim=-1)
  951. start_logits = start_logits.squeeze(-1).contiguous()
  952. end_logits = end_logits.squeeze(-1).contiguous()
  953. total_loss = None
  954. if start_positions is not None and end_positions is not None:
  955. # If we are on multi-GPU, split add a dimension
  956. if len(start_positions.size()) > 1:
  957. start_positions = start_positions.squeeze(-1)
  958. if len(end_positions.size()) > 1:
  959. end_positions = end_positions.squeeze(-1)
  960. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  961. ignored_index = start_logits.size(1)
  962. start_positions = start_positions.clamp(0, ignored_index)
  963. end_positions = end_positions.clamp(0, ignored_index)
  964. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  965. start_loss = loss_fct(start_logits, start_positions)
  966. end_loss = loss_fct(end_logits, end_positions)
  967. total_loss = (start_loss + end_loss) / 2
  968. if not return_dict:
  969. output = (start_logits, end_logits) + outputs[1:]
  970. return ((total_loss,) + output) if total_loss is not None else output
  971. return QuestionAnsweringModelOutput(
  972. loss=total_loss,
  973. start_logits=start_logits,
  974. end_logits=end_logits,
  975. hidden_states=outputs.hidden_states,
  976. attentions=outputs.attentions,
  977. )
# Public names exported by this module (what `from <module> import *` exposes).
# Only the task models and base classes are listed; the prediction heads and `ContextPooler`
# defined above are not.
__all__ = [
    "DebertaForMaskedLM",
    "DebertaForQuestionAnswering",
    "DebertaForSequenceClassification",
    "DebertaForTokenClassification",
    "DebertaModel",
    "DebertaPreTrainedModel",
]