modeling_roberta.py 69 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517151815191520152115221523152415251526152715281529153015311532153315341535153615371538153915401541154215431544154515461547154815491550155115521553155415551556155715581559156015611562
  1. # coding=utf-8
  2. # Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
  3. # Copyright (c) 2018, NVIDIA CORPORATION. All rights reserved.
  4. #
  5. # Licensed under the Apache License, Version 2.0 (the "License");
  6. # you may not use this file except in compliance with the License.
  7. # You may obtain a copy of the License at
  8. #
  9. # http://www.apache.org/licenses/LICENSE-2.0
  10. #
  11. # Unless required by applicable law or agreed to in writing, software
  12. # distributed under the License is distributed on an "AS IS" BASIS,
  13. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  14. # See the License for the specific language governing permissions and
  15. # limitations under the License.
  16. """PyTorch RoBERTa model."""
  17. import math
  18. from typing import Optional, Union
  19. import torch
  20. from torch import nn
  21. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  22. from ...activations import ACT2FN, gelu
  23. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  24. from ...generation import GenerationMixin
  25. from ...modeling_attn_mask_utils import _prepare_4d_attention_mask_for_sdpa, _prepare_4d_causal_attention_mask_for_sdpa
  26. from ...modeling_layers import GradientCheckpointingLayer
  27. from ...modeling_outputs import (
  28. BaseModelOutputWithPastAndCrossAttentions,
  29. BaseModelOutputWithPoolingAndCrossAttentions,
  30. CausalLMOutputWithCrossAttentions,
  31. MaskedLMOutput,
  32. MultipleChoiceModelOutput,
  33. QuestionAnsweringModelOutput,
  34. SequenceClassifierOutput,
  35. TokenClassifierOutput,
  36. )
  37. from ...modeling_utils import PreTrainedModel
  38. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  39. from ...utils import auto_docstring, logging
  40. from ...utils.deprecation import deprecate_kwarg
  41. from .configuration_roberta import RobertaConfig
# Module-level logger; used in this file for one-time warnings (e.g. the SDPA fallback notice).
logger = logging.get_logger(__name__)
  43. class RobertaEmbeddings(nn.Module):
  44. """
  45. Same as BertEmbeddings with a tiny tweak for positional embeddings indexing.
  46. """
  47. # Copied from transformers.models.bert.modeling_bert.BertEmbeddings.__init__
  48. def __init__(self, config):
  49. super().__init__()
  50. self.word_embeddings = nn.Embedding(config.vocab_size, config.hidden_size, padding_idx=config.pad_token_id)
  51. self.position_embeddings = nn.Embedding(config.max_position_embeddings, config.hidden_size)
  52. self.token_type_embeddings = nn.Embedding(config.type_vocab_size, config.hidden_size)
  53. # self.LayerNorm is not snake-cased to stick with TensorFlow model variable name and be able to load
  54. # any TensorFlow checkpoint file
  55. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  56. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  57. # position_ids (1, len position emb) is contiguous in memory and exported when serialized
  58. self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")
  59. self.register_buffer(
  60. "position_ids", torch.arange(config.max_position_embeddings).expand((1, -1)), persistent=False
  61. )
  62. self.register_buffer(
  63. "token_type_ids", torch.zeros(self.position_ids.size(), dtype=torch.long), persistent=False
  64. )
  65. # End copy
  66. self.padding_idx = config.pad_token_id
  67. self.position_embeddings = nn.Embedding(
  68. config.max_position_embeddings, config.hidden_size, padding_idx=self.padding_idx
  69. )
  70. def forward(
  71. self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
  72. ):
  73. if position_ids is None:
  74. if input_ids is not None:
  75. # Create the position ids from the input token ids. Any padded tokens remain padded.
  76. position_ids = create_position_ids_from_input_ids(input_ids, self.padding_idx, past_key_values_length)
  77. else:
  78. position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)
  79. if input_ids is not None:
  80. input_shape = input_ids.size()
  81. else:
  82. input_shape = inputs_embeds.size()[:-1]
  83. seq_length = input_shape[1]
  84. # Setting the token_type_ids to the registered buffer in constructor where it is all zeros, which usually occurs
  85. # when its auto-generated, registered buffer helps users when tracing the model without passing token_type_ids, solves
  86. # issue #5664
  87. if token_type_ids is None:
  88. if hasattr(self, "token_type_ids"):
  89. buffered_token_type_ids = self.token_type_ids[:, :seq_length]
  90. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(input_shape[0], seq_length)
  91. token_type_ids = buffered_token_type_ids_expanded
  92. else:
  93. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
  94. if inputs_embeds is None:
  95. inputs_embeds = self.word_embeddings(input_ids)
  96. token_type_embeddings = self.token_type_embeddings(token_type_ids)
  97. embeddings = inputs_embeds + token_type_embeddings
  98. if self.position_embedding_type == "absolute":
  99. position_embeddings = self.position_embeddings(position_ids)
  100. embeddings += position_embeddings
  101. embeddings = self.LayerNorm(embeddings)
  102. embeddings = self.dropout(embeddings)
  103. return embeddings
  104. def create_position_ids_from_inputs_embeds(self, inputs_embeds):
  105. """
  106. We are provided embeddings directly. We cannot infer which are padded so just generate sequential position ids.
  107. Args:
  108. inputs_embeds: torch.Tensor
  109. Returns: torch.Tensor
  110. """
  111. input_shape = inputs_embeds.size()[:-1]
  112. sequence_length = input_shape[1]
  113. position_ids = torch.arange(
  114. self.padding_idx + 1, sequence_length + self.padding_idx + 1, dtype=torch.long, device=inputs_embeds.device
  115. )
  116. return position_ids.unsqueeze(0).expand(input_shape)
  117. # Copied from transformers.models.bert.modeling_bert.BertSelfAttention with Bert->Roberta
  118. class RobertaSelfAttention(nn.Module):
  119. def __init__(self, config, position_embedding_type=None, layer_idx=None):
  120. super().__init__()
  121. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  122. raise ValueError(
  123. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  124. f"heads ({config.num_attention_heads})"
  125. )
  126. self.num_attention_heads = config.num_attention_heads
  127. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  128. self.all_head_size = self.num_attention_heads * self.attention_head_size
  129. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  130. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  131. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  132. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  133. self.position_embedding_type = position_embedding_type or getattr(
  134. config, "position_embedding_type", "absolute"
  135. )
  136. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  137. self.max_position_embeddings = config.max_position_embeddings
  138. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  139. self.is_decoder = config.is_decoder
  140. self.layer_idx = layer_idx
  141. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  142. def forward(
  143. self,
  144. hidden_states: torch.Tensor,
  145. attention_mask: Optional[torch.FloatTensor] = None,
  146. head_mask: Optional[torch.FloatTensor] = None,
  147. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  148. past_key_values: Optional[Cache] = None,
  149. output_attentions: Optional[bool] = False,
  150. cache_position: Optional[torch.Tensor] = None,
  151. ) -> tuple[torch.Tensor]:
  152. batch_size, seq_length, _ = hidden_states.shape
  153. query_layer = self.query(hidden_states)
  154. query_layer = query_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
  155. 1, 2
  156. )
  157. is_updated = False
  158. is_cross_attention = encoder_hidden_states is not None
  159. if past_key_values is not None:
  160. if isinstance(past_key_values, EncoderDecoderCache):
  161. is_updated = past_key_values.is_updated.get(self.layer_idx)
  162. if is_cross_attention:
  163. # after the first generated id, we can subsequently re-use all key/value_layer from cache
  164. curr_past_key_value = past_key_values.cross_attention_cache
  165. else:
  166. curr_past_key_value = past_key_values.self_attention_cache
  167. else:
  168. curr_past_key_value = past_key_values
  169. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  170. if is_cross_attention and past_key_values is not None and is_updated:
  171. # reuse k,v, cross_attentions
  172. key_layer = curr_past_key_value.layers[self.layer_idx].keys
  173. value_layer = curr_past_key_value.layers[self.layer_idx].values
  174. else:
  175. key_layer = self.key(current_states)
  176. key_layer = key_layer.view(batch_size, -1, self.num_attention_heads, self.attention_head_size).transpose(
  177. 1, 2
  178. )
  179. value_layer = self.value(current_states)
  180. value_layer = value_layer.view(
  181. batch_size, -1, self.num_attention_heads, self.attention_head_size
  182. ).transpose(1, 2)
  183. if past_key_values is not None:
  184. # save all key/value_layer to cache to be re-used for fast auto-regressive generation
  185. cache_position = cache_position if not is_cross_attention else None
  186. key_layer, value_layer = curr_past_key_value.update(
  187. key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
  188. )
  189. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  190. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  191. past_key_values.is_updated[self.layer_idx] = True
  192. # Take the dot product between "query" and "key" to get the raw attention scores.
  193. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  194. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  195. query_length, key_length = query_layer.shape[2], key_layer.shape[2]
  196. if past_key_values is not None:
  197. position_ids_l = torch.tensor(key_length - 1, dtype=torch.long, device=hidden_states.device).view(
  198. -1, 1
  199. )
  200. else:
  201. position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  202. position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  203. distance = position_ids_l - position_ids_r
  204. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  205. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  206. if self.position_embedding_type == "relative_key":
  207. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  208. attention_scores = attention_scores + relative_position_scores
  209. elif self.position_embedding_type == "relative_key_query":
  210. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  211. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  212. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  213. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  214. if attention_mask is not None:
  215. # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
  216. attention_scores = attention_scores + attention_mask
  217. # Normalize the attention scores to probabilities.
  218. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  219. # This is actually dropping out entire tokens to attend to, which might
  220. # seem a bit unusual, but is taken from the original Transformer paper.
  221. attention_probs = self.dropout(attention_probs)
  222. # Mask heads if we want to
  223. if head_mask is not None:
  224. attention_probs = attention_probs * head_mask
  225. context_layer = torch.matmul(attention_probs, value_layer)
  226. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  227. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  228. context_layer = context_layer.view(new_context_layer_shape)
  229. return context_layer, attention_probs
  230. # Copied from transformers.models.bert.modeling_bert.BertSdpaSelfAttention with Bert->Roberta
  231. class RobertaSdpaSelfAttention(RobertaSelfAttention):
  232. def __init__(self, config, position_embedding_type=None, layer_idx=None):
  233. super().__init__(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx)
  234. self.dropout_prob = config.attention_probs_dropout_prob
  235. # Adapted from RobertaSelfAttention
  236. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  237. def forward(
  238. self,
  239. hidden_states: torch.Tensor,
  240. attention_mask: Optional[torch.Tensor] = None,
  241. head_mask: Optional[torch.FloatTensor] = None,
  242. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  243. past_key_values: Optional[Cache] = None,
  244. output_attentions: Optional[bool] = False,
  245. cache_position: Optional[torch.Tensor] = None,
  246. ) -> tuple[torch.Tensor]:
  247. if self.position_embedding_type != "absolute" or output_attentions or head_mask is not None:
  248. # TODO: Improve this warning with e.g. `model.config._attn_implementation = "manual"` once implemented.
  249. logger.warning_once(
  250. "RobertaSdpaSelfAttention is used but `torch.nn.functional.scaled_dot_product_attention` does not support "
  251. "non-absolute `position_embedding_type` or `output_attentions=True` or `head_mask`. Falling back to "
  252. "the manual attention implementation, but specifying the manual implementation will be required from "
  253. "Transformers version v5.0.0 onwards. This warning can be removed using the argument "
  254. '`attn_implementation="eager"` when loading the model.'
  255. )
  256. return super().forward(
  257. hidden_states,
  258. attention_mask,
  259. head_mask,
  260. encoder_hidden_states,
  261. past_key_values,
  262. output_attentions,
  263. cache_position,
  264. )
  265. bsz, tgt_len, _ = hidden_states.size()
  266. query_layer = (
  267. self.query(hidden_states).view(bsz, -1, self.num_attention_heads, self.attention_head_size).transpose(1, 2)
  268. )
  269. is_updated = False
  270. is_cross_attention = encoder_hidden_states is not None
  271. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  272. if past_key_values is not None:
  273. if isinstance(past_key_values, EncoderDecoderCache):
  274. is_updated = past_key_values.is_updated.get(self.layer_idx)
  275. if is_cross_attention:
  276. # after the first generated id, we can subsequently re-use all key/value_states from cache
  277. curr_past_key_value = past_key_values.cross_attention_cache
  278. else:
  279. curr_past_key_value = past_key_values.self_attention_cache
  280. else:
  281. curr_past_key_value = past_key_values
  282. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  283. if is_cross_attention and past_key_values is not None and is_updated:
  284. # reuse k,v, cross_attentions
  285. key_layer = curr_past_key_value.layers[self.layer_idx].keys
  286. value_layer = curr_past_key_value.layers[self.layer_idx].values
  287. else:
  288. key_layer = (
  289. self.key(current_states)
  290. .view(bsz, -1, self.num_attention_heads, self.attention_head_size)
  291. .transpose(1, 2)
  292. )
  293. value_layer = (
  294. self.value(current_states)
  295. .view(bsz, -1, self.num_attention_heads, self.attention_head_size)
  296. .transpose(1, 2)
  297. )
  298. if past_key_values is not None:
  299. # save all key/value_layer to cache to be re-used for fast auto-regressive generation
  300. cache_position = cache_position if not is_cross_attention else None
  301. key_layer, value_layer = curr_past_key_value.update(
  302. key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
  303. )
  304. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  305. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  306. past_key_values.is_updated[self.layer_idx] = True
  307. # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment
  308. # in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
  309. # The tgt_len > 1 is necessary to match with AttentionMaskConverter.to_causal_4d that does not create
  310. # a causal mask in case tgt_len == 1.
  311. is_causal = self.is_decoder and not is_cross_attention and attention_mask is None and tgt_len > 1
  312. attn_output = torch.nn.functional.scaled_dot_product_attention(
  313. query_layer,
  314. key_layer,
  315. value_layer,
  316. attn_mask=attention_mask,
  317. dropout_p=self.dropout_prob if self.training else 0.0,
  318. is_causal=is_causal,
  319. )
  320. attn_output = attn_output.transpose(1, 2)
  321. attn_output = attn_output.reshape(bsz, tgt_len, self.all_head_size)
  322. return attn_output, None
  323. # Copied from transformers.models.bert.modeling_bert.BertSelfOutput
  324. class RobertaSelfOutput(nn.Module):
  325. def __init__(self, config):
  326. super().__init__()
  327. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  328. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  329. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  330. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  331. hidden_states = self.dense(hidden_states)
  332. hidden_states = self.dropout(hidden_states)
  333. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  334. return hidden_states
# Maps `config._attn_implementation` to the self-attention class that RobertaAttention instantiates.
ROBERTA_SELF_ATTENTION_CLASSES = {
    "eager": RobertaSelfAttention,
    "sdpa": RobertaSdpaSelfAttention,
}
  339. # Copied from transformers.models.bert.modeling_bert.BertAttention with Bert->Roberta,BERT->ROBERTA
  340. class RobertaAttention(nn.Module):
  341. def __init__(self, config, position_embedding_type=None, layer_idx=None):
  342. super().__init__()
  343. self.self = ROBERTA_SELF_ATTENTION_CLASSES[config._attn_implementation](
  344. config,
  345. position_embedding_type=position_embedding_type,
  346. layer_idx=layer_idx,
  347. )
  348. self.output = RobertaSelfOutput(config)
  349. self.pruned_heads = set()
  350. def prune_heads(self, heads):
  351. if len(heads) == 0:
  352. return
  353. heads, index = find_pruneable_heads_and_indices(
  354. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  355. )
  356. # Prune linear layers
  357. self.self.query = prune_linear_layer(self.self.query, index)
  358. self.self.key = prune_linear_layer(self.self.key, index)
  359. self.self.value = prune_linear_layer(self.self.value, index)
  360. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  361. # Update hyper params and store pruned heads
  362. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  363. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  364. self.pruned_heads = self.pruned_heads.union(heads)
  365. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  366. def forward(
  367. self,
  368. hidden_states: torch.Tensor,
  369. attention_mask: Optional[torch.FloatTensor] = None,
  370. head_mask: Optional[torch.FloatTensor] = None,
  371. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  372. past_key_values: Optional[Cache] = None,
  373. output_attentions: Optional[bool] = False,
  374. cache_position: Optional[torch.Tensor] = None,
  375. ) -> tuple[torch.Tensor]:
  376. self_outputs = self.self(
  377. hidden_states,
  378. attention_mask=attention_mask,
  379. head_mask=head_mask,
  380. encoder_hidden_states=encoder_hidden_states,
  381. past_key_values=past_key_values,
  382. output_attentions=output_attentions,
  383. cache_position=cache_position,
  384. )
  385. attention_output = self.output(self_outputs[0], hidden_states)
  386. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  387. return outputs
  388. # Copied from transformers.models.bert.modeling_bert.BertIntermediate
  389. class RobertaIntermediate(nn.Module):
  390. def __init__(self, config):
  391. super().__init__()
  392. self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
  393. if isinstance(config.hidden_act, str):
  394. self.intermediate_act_fn = ACT2FN[config.hidden_act]
  395. else:
  396. self.intermediate_act_fn = config.hidden_act
  397. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  398. hidden_states = self.dense(hidden_states)
  399. hidden_states = self.intermediate_act_fn(hidden_states)
  400. return hidden_states
  401. # Copied from transformers.models.bert.modeling_bert.BertOutput
  402. class RobertaOutput(nn.Module):
  403. def __init__(self, config):
  404. super().__init__()
  405. self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
  406. self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  407. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  408. def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
  409. hidden_states = self.dense(hidden_states)
  410. hidden_states = self.dropout(hidden_states)
  411. hidden_states = self.LayerNorm(hidden_states + input_tensor)
  412. return hidden_states
  413. # Copied from transformers.models.bert.modeling_bert.BertLayer with Bert->Roberta
  414. class RobertaLayer(GradientCheckpointingLayer):
  415. def __init__(self, config, layer_idx=None):
  416. super().__init__()
  417. self.chunk_size_feed_forward = config.chunk_size_feed_forward
  418. self.seq_len_dim = 1
  419. self.attention = RobertaAttention(config, layer_idx=layer_idx)
  420. self.is_decoder = config.is_decoder
  421. self.add_cross_attention = config.add_cross_attention
  422. if self.add_cross_attention:
  423. if not self.is_decoder:
  424. raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
  425. self.crossattention = RobertaAttention(config, position_embedding_type="absolute", layer_idx=layer_idx)
  426. self.intermediate = RobertaIntermediate(config)
  427. self.output = RobertaOutput(config)
  428. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  429. def forward(
  430. self,
  431. hidden_states: torch.Tensor,
  432. attention_mask: Optional[torch.FloatTensor] = None,
  433. head_mask: Optional[torch.FloatTensor] = None,
  434. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  435. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  436. past_key_values: Optional[Cache] = None,
  437. output_attentions: Optional[bool] = False,
  438. cache_position: Optional[torch.Tensor] = None,
  439. ) -> tuple[torch.Tensor]:
  440. self_attention_outputs = self.attention(
  441. hidden_states,
  442. attention_mask=attention_mask,
  443. head_mask=head_mask,
  444. output_attentions=output_attentions,
  445. past_key_values=past_key_values,
  446. cache_position=cache_position,
  447. )
  448. attention_output = self_attention_outputs[0]
  449. outputs = self_attention_outputs[1:] # add self attentions if we output attention weights
  450. if self.is_decoder and encoder_hidden_states is not None:
  451. if not hasattr(self, "crossattention"):
  452. raise ValueError(
  453. f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
  454. " by setting `config.add_cross_attention=True`"
  455. )
  456. cross_attention_outputs = self.crossattention(
  457. attention_output,
  458. attention_mask=encoder_attention_mask,
  459. head_mask=head_mask,
  460. encoder_hidden_states=encoder_hidden_states,
  461. past_key_values=past_key_values,
  462. output_attentions=output_attentions,
  463. cache_position=cache_position,
  464. )
  465. attention_output = cross_attention_outputs[0]
  466. outputs = outputs + cross_attention_outputs[1:] # add cross attentions if we output attention weights
  467. layer_output = apply_chunking_to_forward(
  468. self.feed_forward_chunk, self.chunk_size_feed_forward, self.seq_len_dim, attention_output
  469. )
  470. outputs = (layer_output,) + outputs
  471. return outputs
  472. def feed_forward_chunk(self, attention_output):
  473. intermediate_output = self.intermediate(attention_output)
  474. layer_output = self.output(intermediate_output, attention_output)
  475. return layer_output
  476. # Copied from transformers.models.bert.modeling_bert.BertEncoder with Bert->Roberta
  477. class RobertaEncoder(nn.Module):
  478. def __init__(self, config, layer_idx=None):
  479. super().__init__()
  480. self.config = config
  481. self.layer = nn.ModuleList([RobertaLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
  482. self.gradient_checkpointing = False
  483. def forward(
  484. self,
  485. hidden_states: torch.Tensor,
  486. attention_mask: Optional[torch.FloatTensor] = None,
  487. head_mask: Optional[torch.FloatTensor] = None,
  488. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  489. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  490. past_key_values: Optional[Cache] = None,
  491. use_cache: Optional[bool] = None,
  492. output_attentions: Optional[bool] = False,
  493. output_hidden_states: Optional[bool] = False,
  494. return_dict: Optional[bool] = True,
  495. cache_position: Optional[torch.Tensor] = None,
  496. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
  497. all_hidden_states = () if output_hidden_states else None
  498. all_self_attentions = () if output_attentions else None
  499. all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
  500. if self.gradient_checkpointing and self.training:
  501. if use_cache:
  502. logger.warning_once(
  503. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  504. )
  505. use_cache = False
  506. if use_cache and self.config.is_decoder and past_key_values is None:
  507. past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
  508. if use_cache and self.config.is_decoder and isinstance(past_key_values, tuple):
  509. logger.warning_once(
  510. "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
  511. "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
  512. "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
  513. )
  514. past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)
  515. for i, layer_module in enumerate(self.layer):
  516. if output_hidden_states:
  517. all_hidden_states = all_hidden_states + (hidden_states,)
  518. layer_head_mask = head_mask[i] if head_mask is not None else None
  519. layer_outputs = layer_module(
  520. hidden_states,
  521. attention_mask,
  522. layer_head_mask,
  523. encoder_hidden_states, # as a positional argument for gradient checkpointing
  524. encoder_attention_mask=encoder_attention_mask,
  525. past_key_values=past_key_values,
  526. output_attentions=output_attentions,
  527. cache_position=cache_position,
  528. )
  529. hidden_states = layer_outputs[0]
  530. if output_attentions:
  531. all_self_attentions = all_self_attentions + (layer_outputs[1],)
  532. if self.config.add_cross_attention:
  533. all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
  534. if output_hidden_states:
  535. all_hidden_states = all_hidden_states + (hidden_states,)
  536. if not return_dict:
  537. return tuple(
  538. v
  539. for v in [
  540. hidden_states,
  541. past_key_values,
  542. all_hidden_states,
  543. all_self_attentions,
  544. all_cross_attentions,
  545. ]
  546. if v is not None
  547. )
  548. return BaseModelOutputWithPastAndCrossAttentions(
  549. last_hidden_state=hidden_states,
  550. past_key_values=past_key_values,
  551. hidden_states=all_hidden_states,
  552. attentions=all_self_attentions,
  553. cross_attentions=all_cross_attentions,
  554. )
  555. # Copied from transformers.models.bert.modeling_bert.BertPooler
  556. class RobertaPooler(nn.Module):
  557. def __init__(self, config):
  558. super().__init__()
  559. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  560. self.activation = nn.Tanh()
  561. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  562. # We "pool" the model by simply taking the hidden state corresponding
  563. # to the first token.
  564. first_token_tensor = hidden_states[:, 0]
  565. pooled_output = self.dense(first_token_tensor)
  566. pooled_output = self.activation(pooled_output)
  567. return pooled_output
  568. @auto_docstring
  569. class RobertaPreTrainedModel(PreTrainedModel):
  570. config: RobertaConfig
  571. base_model_prefix = "roberta"
  572. supports_gradient_checkpointing = True
  573. _no_split_modules = ["RobertaEmbeddings", "RobertaSelfAttention", "RobertaSdpaSelfAttention"]
  574. _supports_sdpa = True
  575. # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->RobertaLMHead
  576. def _init_weights(self, module):
  577. """Initialize the weights"""
  578. if isinstance(module, nn.Linear):
  579. # Slightly different from the TF version which uses truncated_normal for initialization
  580. # cf https://github.com/pytorch/pytorch/pull/5617
  581. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  582. if module.bias is not None:
  583. module.bias.data.zero_()
  584. elif isinstance(module, nn.Embedding):
  585. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  586. if module.padding_idx is not None:
  587. module.weight.data[module.padding_idx].zero_()
  588. elif isinstance(module, nn.LayerNorm):
  589. module.bias.data.zero_()
  590. module.weight.data.fill_(1.0)
  591. elif isinstance(module, RobertaLMHead):
  592. module.bias.data.zero_()
  593. @auto_docstring(
  594. custom_intro="""
  595. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  596. cross-attention is added between the self-attention layers, following the architecture described in [Attention is
  597. all you need](https://huggingface.co/papers/1706.03762) by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit,
  598. Llion Jones, Aidan N. Gomez, Lukasz Kaiser and Illia Polosukhin.
  599. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  600. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  601. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  602. """
  603. )
  604. # Copied from transformers.models.bert.modeling_bert.BertModel with Bert->Roberta, BERT->ROBERTA
  605. class RobertaModel(RobertaPreTrainedModel):
  606. _no_split_modules = ["RobertaEmbeddings", "RobertaLayer"]
  607. def __init__(self, config, add_pooling_layer=True):
  608. r"""
  609. add_pooling_layer (bool, *optional*, defaults to `True`):
  610. Whether to add a pooling layer
  611. """
  612. super().__init__(config)
  613. self.config = config
  614. self.embeddings = RobertaEmbeddings(config)
  615. self.encoder = RobertaEncoder(config)
  616. self.pooler = RobertaPooler(config) if add_pooling_layer else None
  617. self.attn_implementation = config._attn_implementation
  618. self.position_embedding_type = config.position_embedding_type
  619. # Initialize weights and apply final processing
  620. self.post_init()
  621. def get_input_embeddings(self):
  622. return self.embeddings.word_embeddings
  623. def set_input_embeddings(self, value):
  624. self.embeddings.word_embeddings = value
  625. def _prune_heads(self, heads_to_prune):
  626. """
  627. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  628. class PreTrainedModel
  629. """
  630. for layer, heads in heads_to_prune.items():
  631. self.encoder.layer[layer].attention.prune_heads(heads)
  632. @auto_docstring
  633. def forward(
  634. self,
  635. input_ids: Optional[torch.Tensor] = None,
  636. attention_mask: Optional[torch.Tensor] = None,
  637. token_type_ids: Optional[torch.Tensor] = None,
  638. position_ids: Optional[torch.Tensor] = None,
  639. head_mask: Optional[torch.Tensor] = None,
  640. inputs_embeds: Optional[torch.Tensor] = None,
  641. encoder_hidden_states: Optional[torch.Tensor] = None,
  642. encoder_attention_mask: Optional[torch.Tensor] = None,
  643. past_key_values: Optional[Cache] = None,
  644. use_cache: Optional[bool] = None,
  645. output_attentions: Optional[bool] = None,
  646. output_hidden_states: Optional[bool] = None,
  647. return_dict: Optional[bool] = None,
  648. cache_position: Optional[torch.Tensor] = None,
  649. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  650. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  651. output_hidden_states = (
  652. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  653. )
  654. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  655. if self.config.is_decoder:
  656. use_cache = use_cache if use_cache is not None else self.config.use_cache
  657. else:
  658. use_cache = False
  659. if input_ids is not None and inputs_embeds is not None:
  660. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  661. elif input_ids is not None:
  662. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  663. input_shape = input_ids.size()
  664. elif inputs_embeds is not None:
  665. input_shape = inputs_embeds.size()[:-1]
  666. else:
  667. raise ValueError("You have to specify either input_ids or inputs_embeds")
  668. batch_size, seq_length = input_shape
  669. device = input_ids.device if input_ids is not None else inputs_embeds.device
  670. past_key_values_length = 0
  671. if past_key_values is not None:
  672. past_key_values_length = (
  673. past_key_values[0][0].shape[-2]
  674. if not isinstance(past_key_values, Cache)
  675. else past_key_values.get_seq_length()
  676. )
  677. if token_type_ids is None:
  678. if hasattr(self.embeddings, "token_type_ids"):
  679. buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
  680. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
  681. token_type_ids = buffered_token_type_ids_expanded
  682. else:
  683. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  684. embedding_output = self.embeddings(
  685. input_ids=input_ids,
  686. position_ids=position_ids,
  687. token_type_ids=token_type_ids,
  688. inputs_embeds=inputs_embeds,
  689. past_key_values_length=past_key_values_length,
  690. )
  691. if attention_mask is None:
  692. attention_mask = torch.ones((batch_size, seq_length + past_key_values_length), device=device)
  693. use_sdpa_attention_masks = (
  694. self.attn_implementation == "sdpa"
  695. and self.position_embedding_type == "absolute"
  696. and head_mask is None
  697. and not output_attentions
  698. )
  699. # Expand the attention mask
  700. if use_sdpa_attention_masks and attention_mask.dim() == 2:
  701. # Expand the attention mask for SDPA.
  702. # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
  703. if self.config.is_decoder:
  704. extended_attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
  705. attention_mask,
  706. input_shape,
  707. embedding_output,
  708. past_key_values_length,
  709. )
  710. else:
  711. extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  712. attention_mask, embedding_output.dtype, tgt_len=seq_length
  713. )
  714. else:
  715. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  716. # ourselves in which case we just need to make it broadcastable to all heads.
  717. extended_attention_mask = self.get_extended_attention_mask(attention_mask, input_shape)
  718. # If a 2D or 3D attention mask is provided for the cross-attention
  719. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  720. if self.config.is_decoder and encoder_hidden_states is not None:
  721. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  722. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  723. if encoder_attention_mask is None:
  724. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  725. if use_sdpa_attention_masks and encoder_attention_mask.dim() == 2:
  726. # Expand the attention mask for SDPA.
  727. # [bsz, seq_len] -> [bsz, 1, seq_len, seq_len]
  728. encoder_extended_attention_mask = _prepare_4d_attention_mask_for_sdpa(
  729. encoder_attention_mask, embedding_output.dtype, tgt_len=seq_length
  730. )
  731. else:
  732. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  733. else:
  734. encoder_extended_attention_mask = None
  735. # Prepare head mask if needed
  736. # 1.0 in head_mask indicate we keep the head
  737. # attention_probs has shape bsz x n_heads x N x N
  738. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  739. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  740. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  741. encoder_outputs = self.encoder(
  742. embedding_output,
  743. attention_mask=extended_attention_mask,
  744. head_mask=head_mask,
  745. encoder_hidden_states=encoder_hidden_states,
  746. encoder_attention_mask=encoder_extended_attention_mask,
  747. past_key_values=past_key_values,
  748. use_cache=use_cache,
  749. output_attentions=output_attentions,
  750. output_hidden_states=output_hidden_states,
  751. return_dict=return_dict,
  752. cache_position=cache_position,
  753. )
  754. sequence_output = encoder_outputs[0]
  755. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  756. if not return_dict:
  757. return (sequence_output, pooled_output) + encoder_outputs[1:]
  758. return BaseModelOutputWithPoolingAndCrossAttentions(
  759. last_hidden_state=sequence_output,
  760. pooler_output=pooled_output,
  761. past_key_values=encoder_outputs.past_key_values,
  762. hidden_states=encoder_outputs.hidden_states,
  763. attentions=encoder_outputs.attentions,
  764. cross_attentions=encoder_outputs.cross_attentions,
  765. )
  766. @auto_docstring(
  767. custom_intro="""
  768. RoBERTa Model with a `language modeling` head on top for CLM fine-tuning.
  769. """
  770. )
  771. class RobertaForCausalLM(RobertaPreTrainedModel, GenerationMixin):
  772. _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
  773. def __init__(self, config):
  774. super().__init__(config)
  775. if not config.is_decoder:
  776. logger.warning("If you want to use `RobertaLMHeadModel` as a standalone, add `is_decoder=True.`")
  777. self.roberta = RobertaModel(config, add_pooling_layer=False)
  778. self.lm_head = RobertaLMHead(config)
  779. # Initialize weights and apply final processing
  780. self.post_init()
  781. def get_output_embeddings(self):
  782. return self.lm_head.decoder
  783. def set_output_embeddings(self, new_embeddings):
  784. self.lm_head.decoder = new_embeddings
  785. @auto_docstring
  786. def forward(
  787. self,
  788. input_ids: Optional[torch.LongTensor] = None,
  789. attention_mask: Optional[torch.FloatTensor] = None,
  790. token_type_ids: Optional[torch.LongTensor] = None,
  791. position_ids: Optional[torch.LongTensor] = None,
  792. head_mask: Optional[torch.FloatTensor] = None,
  793. inputs_embeds: Optional[torch.FloatTensor] = None,
  794. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  795. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  796. labels: Optional[torch.LongTensor] = None,
  797. past_key_values: Optional[Cache] = None,
  798. use_cache: Optional[bool] = None,
  799. output_attentions: Optional[bool] = None,
  800. output_hidden_states: Optional[bool] = None,
  801. return_dict: Optional[bool] = None,
  802. **kwargs,
  803. ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
  804. r"""
  805. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  806. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  807. - 0 corresponds to a *sentence A* token,
  808. - 1 corresponds to a *sentence B* token.
  809. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  810. >= 2. All the value in this tensor should be always < type_vocab_size.
  811. [What are token type IDs?](../glossary#token-type-ids)
  812. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  813. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  814. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  815. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  816. Example:
  817. ```python
  818. >>> from transformers import AutoTokenizer, RobertaForCausalLM, AutoConfig
  819. >>> import torch
  820. >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/roberta-base")
  821. >>> config = AutoConfig.from_pretrained("FacebookAI/roberta-base")
  822. >>> config.is_decoder = True
  823. >>> model = RobertaForCausalLM.from_pretrained("FacebookAI/roberta-base", config=config)
  824. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  825. >>> outputs = model(**inputs)
  826. >>> prediction_logits = outputs.logits
  827. ```"""
  828. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  829. if labels is not None:
  830. use_cache = False
  831. outputs = self.roberta(
  832. input_ids,
  833. attention_mask=attention_mask,
  834. token_type_ids=token_type_ids,
  835. position_ids=position_ids,
  836. head_mask=head_mask,
  837. inputs_embeds=inputs_embeds,
  838. encoder_hidden_states=encoder_hidden_states,
  839. encoder_attention_mask=encoder_attention_mask,
  840. past_key_values=past_key_values,
  841. use_cache=use_cache,
  842. output_attentions=output_attentions,
  843. output_hidden_states=output_hidden_states,
  844. return_dict=return_dict,
  845. )
  846. sequence_output = outputs[0]
  847. prediction_scores = self.lm_head(sequence_output)
  848. lm_loss = None
  849. if labels is not None:
  850. # move labels to correct device to enable model parallelism
  851. labels = labels.to(prediction_scores.device)
  852. lm_loss = self.loss_function(
  853. prediction_scores,
  854. labels,
  855. vocab_size=self.config.vocab_size,
  856. **kwargs,
  857. )
  858. if not return_dict:
  859. output = (prediction_scores,) + outputs[2:]
  860. return ((lm_loss,) + output) if lm_loss is not None else output
  861. return CausalLMOutputWithCrossAttentions(
  862. loss=lm_loss,
  863. logits=prediction_scores,
  864. past_key_values=outputs.past_key_values,
  865. hidden_states=outputs.hidden_states,
  866. attentions=outputs.attentions,
  867. cross_attentions=outputs.cross_attentions,
  868. )
  869. @auto_docstring
  870. class RobertaForMaskedLM(RobertaPreTrainedModel):
  871. _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
  872. def __init__(self, config):
  873. super().__init__(config)
  874. if config.is_decoder:
  875. logger.warning(
  876. "If you want to use `RobertaForMaskedLM` make sure `config.is_decoder=False` for "
  877. "bi-directional self-attention."
  878. )
  879. self.roberta = RobertaModel(config, add_pooling_layer=False)
  880. self.lm_head = RobertaLMHead(config)
  881. # Initialize weights and apply final processing
  882. self.post_init()
  883. def get_output_embeddings(self):
  884. return self.lm_head.decoder
  885. def set_output_embeddings(self, new_embeddings):
  886. self.lm_head.decoder = new_embeddings
  887. @auto_docstring
  888. def forward(
  889. self,
  890. input_ids: Optional[torch.LongTensor] = None,
  891. attention_mask: Optional[torch.FloatTensor] = None,
  892. token_type_ids: Optional[torch.LongTensor] = None,
  893. position_ids: Optional[torch.LongTensor] = None,
  894. head_mask: Optional[torch.FloatTensor] = None,
  895. inputs_embeds: Optional[torch.FloatTensor] = None,
  896. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  897. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  898. labels: Optional[torch.LongTensor] = None,
  899. output_attentions: Optional[bool] = None,
  900. output_hidden_states: Optional[bool] = None,
  901. return_dict: Optional[bool] = None,
  902. ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
  903. r"""
  904. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  905. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  906. - 0 corresponds to a *sentence A* token,
  907. - 1 corresponds to a *sentence B* token.
  908. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  909. >= 2. All the value in this tensor should be always < type_vocab_size.
  910. [What are token type IDs?](../glossary#token-type-ids)
  911. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  912. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  913. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  914. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  915. """
  916. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  917. outputs = self.roberta(
  918. input_ids,
  919. attention_mask=attention_mask,
  920. token_type_ids=token_type_ids,
  921. position_ids=position_ids,
  922. head_mask=head_mask,
  923. inputs_embeds=inputs_embeds,
  924. encoder_hidden_states=encoder_hidden_states,
  925. encoder_attention_mask=encoder_attention_mask,
  926. output_attentions=output_attentions,
  927. output_hidden_states=output_hidden_states,
  928. return_dict=return_dict,
  929. )
  930. sequence_output = outputs[0]
  931. prediction_scores = self.lm_head(sequence_output)
  932. masked_lm_loss = None
  933. if labels is not None:
  934. # move labels to correct device to enable model parallelism
  935. labels = labels.to(prediction_scores.device)
  936. loss_fct = CrossEntropyLoss()
  937. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  938. if not return_dict:
  939. output = (prediction_scores,) + outputs[2:]
  940. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  941. return MaskedLMOutput(
  942. loss=masked_lm_loss,
  943. logits=prediction_scores,
  944. hidden_states=outputs.hidden_states,
  945. attentions=outputs.attentions,
  946. )
  947. class RobertaLMHead(nn.Module):
  948. """Roberta Head for masked language modeling."""
  949. def __init__(self, config):
  950. super().__init__()
  951. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  952. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  953. self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
  954. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  955. self.decoder.bias = self.bias
  956. def forward(self, features, **kwargs):
  957. x = self.dense(features)
  958. x = gelu(x)
  959. x = self.layer_norm(x)
  960. # project back to size of vocabulary with bias
  961. x = self.decoder(x)
  962. return x
  963. def _tie_weights(self):
  964. # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
  965. # For accelerate compatibility and to not break backward compatibility
  966. if self.decoder.bias.device.type == "meta":
  967. self.decoder.bias = self.bias
  968. else:
  969. self.bias = self.decoder.bias
  970. @auto_docstring(
  971. custom_intro="""
  972. RoBERTa Model transformer with a sequence classification/regression head on top (a linear layer on top of the
  973. pooled output) e.g. for GLUE tasks.
  974. """
  975. )
  976. class RobertaForSequenceClassification(RobertaPreTrainedModel):
  977. def __init__(self, config):
  978. super().__init__(config)
  979. self.num_labels = config.num_labels
  980. self.config = config
  981. self.roberta = RobertaModel(config, add_pooling_layer=False)
  982. self.classifier = RobertaClassificationHead(config)
  983. # Initialize weights and apply final processing
  984. self.post_init()
  985. @auto_docstring
  986. def forward(
  987. self,
  988. input_ids: Optional[torch.LongTensor] = None,
  989. attention_mask: Optional[torch.FloatTensor] = None,
  990. token_type_ids: Optional[torch.LongTensor] = None,
  991. position_ids: Optional[torch.LongTensor] = None,
  992. head_mask: Optional[torch.FloatTensor] = None,
  993. inputs_embeds: Optional[torch.FloatTensor] = None,
  994. labels: Optional[torch.LongTensor] = None,
  995. output_attentions: Optional[bool] = None,
  996. output_hidden_states: Optional[bool] = None,
  997. return_dict: Optional[bool] = None,
  998. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  999. r"""
  1000. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1001. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  1002. - 0 corresponds to a *sentence A* token,
  1003. - 1 corresponds to a *sentence B* token.
  1004. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  1005. >= 2. All the value in this tensor should be always < type_vocab_size.
  1006. [What are token type IDs?](../glossary#token-type-ids)
  1007. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1008. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1009. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  1010. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1011. """
  1012. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1013. outputs = self.roberta(
  1014. input_ids,
  1015. attention_mask=attention_mask,
  1016. token_type_ids=token_type_ids,
  1017. position_ids=position_ids,
  1018. head_mask=head_mask,
  1019. inputs_embeds=inputs_embeds,
  1020. output_attentions=output_attentions,
  1021. output_hidden_states=output_hidden_states,
  1022. return_dict=return_dict,
  1023. )
  1024. sequence_output = outputs[0]
  1025. logits = self.classifier(sequence_output)
  1026. loss = None
  1027. if labels is not None:
  1028. # move labels to correct device to enable model parallelism
  1029. labels = labels.to(logits.device)
  1030. if self.config.problem_type is None:
  1031. if self.num_labels == 1:
  1032. self.config.problem_type = "regression"
  1033. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1034. self.config.problem_type = "single_label_classification"
  1035. else:
  1036. self.config.problem_type = "multi_label_classification"
  1037. if self.config.problem_type == "regression":
  1038. loss_fct = MSELoss()
  1039. if self.num_labels == 1:
  1040. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1041. else:
  1042. loss = loss_fct(logits, labels)
  1043. elif self.config.problem_type == "single_label_classification":
  1044. loss_fct = CrossEntropyLoss()
  1045. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1046. elif self.config.problem_type == "multi_label_classification":
  1047. loss_fct = BCEWithLogitsLoss()
  1048. loss = loss_fct(logits, labels)
  1049. if not return_dict:
  1050. output = (logits,) + outputs[2:]
  1051. return ((loss,) + output) if loss is not None else output
  1052. return SequenceClassifierOutput(
  1053. loss=loss,
  1054. logits=logits,
  1055. hidden_states=outputs.hidden_states,
  1056. attentions=outputs.attentions,
  1057. )
  1058. @auto_docstring
  1059. class RobertaForMultipleChoice(RobertaPreTrainedModel):
  1060. def __init__(self, config):
  1061. super().__init__(config)
  1062. self.roberta = RobertaModel(config)
  1063. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1064. self.classifier = nn.Linear(config.hidden_size, 1)
  1065. # Initialize weights and apply final processing
  1066. self.post_init()
  1067. @auto_docstring
  1068. def forward(
  1069. self,
  1070. input_ids: Optional[torch.LongTensor] = None,
  1071. token_type_ids: Optional[torch.LongTensor] = None,
  1072. attention_mask: Optional[torch.FloatTensor] = None,
  1073. labels: Optional[torch.LongTensor] = None,
  1074. position_ids: Optional[torch.LongTensor] = None,
  1075. head_mask: Optional[torch.FloatTensor] = None,
  1076. inputs_embeds: Optional[torch.FloatTensor] = None,
  1077. output_attentions: Optional[bool] = None,
  1078. output_hidden_states: Optional[bool] = None,
  1079. return_dict: Optional[bool] = None,
  1080. ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
  1081. r"""
  1082. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  1083. Indices of input sequence tokens in the vocabulary.
  1084. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1085. [`PreTrainedTokenizer.__call__`] for details.
  1086. [What are input IDs?](../glossary#input-ids)
  1087. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1088. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  1089. - 0 corresponds to a *sentence A* token,
  1090. - 1 corresponds to a *sentence B* token.
  1091. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  1092. >= 2. All the value in this tensor should be always < type_vocab_size.
  1093. [What are token type IDs?](../glossary#token-type-ids)
  1094. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1095. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1096. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1097. `input_ids` above)
  1098. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1099. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1100. config.max_position_embeddings - 1]`.
  1101. [What are position IDs?](../glossary#position-ids)
  1102. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1103. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1104. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1105. model's internal embedding lookup matrix.
  1106. """
  1107. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1108. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1109. flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1110. flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1111. flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1112. flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1113. flat_inputs_embeds = (
  1114. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1115. if inputs_embeds is not None
  1116. else None
  1117. )
  1118. outputs = self.roberta(
  1119. flat_input_ids,
  1120. position_ids=flat_position_ids,
  1121. token_type_ids=flat_token_type_ids,
  1122. attention_mask=flat_attention_mask,
  1123. head_mask=head_mask,
  1124. inputs_embeds=flat_inputs_embeds,
  1125. output_attentions=output_attentions,
  1126. output_hidden_states=output_hidden_states,
  1127. return_dict=return_dict,
  1128. )
  1129. pooled_output = outputs[1]
  1130. pooled_output = self.dropout(pooled_output)
  1131. logits = self.classifier(pooled_output)
  1132. reshaped_logits = logits.view(-1, num_choices)
  1133. loss = None
  1134. if labels is not None:
  1135. # move labels to correct device to enable model parallelism
  1136. labels = labels.to(reshaped_logits.device)
  1137. loss_fct = CrossEntropyLoss()
  1138. loss = loss_fct(reshaped_logits, labels)
  1139. if not return_dict:
  1140. output = (reshaped_logits,) + outputs[2:]
  1141. return ((loss,) + output) if loss is not None else output
  1142. return MultipleChoiceModelOutput(
  1143. loss=loss,
  1144. logits=reshaped_logits,
  1145. hidden_states=outputs.hidden_states,
  1146. attentions=outputs.attentions,
  1147. )
  1148. @auto_docstring
  1149. class RobertaForTokenClassification(RobertaPreTrainedModel):
  1150. def __init__(self, config):
  1151. super().__init__(config)
  1152. self.num_labels = config.num_labels
  1153. self.roberta = RobertaModel(config, add_pooling_layer=False)
  1154. classifier_dropout = (
  1155. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1156. )
  1157. self.dropout = nn.Dropout(classifier_dropout)
  1158. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1159. # Initialize weights and apply final processing
  1160. self.post_init()
  1161. @auto_docstring
  1162. def forward(
  1163. self,
  1164. input_ids: Optional[torch.LongTensor] = None,
  1165. attention_mask: Optional[torch.FloatTensor] = None,
  1166. token_type_ids: Optional[torch.LongTensor] = None,
  1167. position_ids: Optional[torch.LongTensor] = None,
  1168. head_mask: Optional[torch.FloatTensor] = None,
  1169. inputs_embeds: Optional[torch.FloatTensor] = None,
  1170. labels: Optional[torch.LongTensor] = None,
  1171. output_attentions: Optional[bool] = None,
  1172. output_hidden_states: Optional[bool] = None,
  1173. return_dict: Optional[bool] = None,
  1174. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  1175. r"""
  1176. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1177. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  1178. - 0 corresponds to a *sentence A* token,
  1179. - 1 corresponds to a *sentence B* token.
  1180. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  1181. >= 2. All the value in this tensor should be always < type_vocab_size.
  1182. [What are token type IDs?](../glossary#token-type-ids)
  1183. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1184. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1185. """
  1186. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1187. outputs = self.roberta(
  1188. input_ids,
  1189. attention_mask=attention_mask,
  1190. token_type_ids=token_type_ids,
  1191. position_ids=position_ids,
  1192. head_mask=head_mask,
  1193. inputs_embeds=inputs_embeds,
  1194. output_attentions=output_attentions,
  1195. output_hidden_states=output_hidden_states,
  1196. return_dict=return_dict,
  1197. )
  1198. sequence_output = outputs[0]
  1199. sequence_output = self.dropout(sequence_output)
  1200. logits = self.classifier(sequence_output)
  1201. loss = None
  1202. if labels is not None:
  1203. # move labels to correct device to enable model parallelism
  1204. labels = labels.to(logits.device)
  1205. loss_fct = CrossEntropyLoss()
  1206. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1207. if not return_dict:
  1208. output = (logits,) + outputs[2:]
  1209. return ((loss,) + output) if loss is not None else output
  1210. return TokenClassifierOutput(
  1211. loss=loss,
  1212. logits=logits,
  1213. hidden_states=outputs.hidden_states,
  1214. attentions=outputs.attentions,
  1215. )
  1216. class RobertaClassificationHead(nn.Module):
  1217. """Head for sentence-level classification tasks."""
  1218. def __init__(self, config):
  1219. super().__init__()
  1220. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  1221. classifier_dropout = (
  1222. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1223. )
  1224. self.dropout = nn.Dropout(classifier_dropout)
  1225. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  1226. def forward(self, features, **kwargs):
  1227. x = features[:, 0, :] # take <s> token (equiv. to [CLS])
  1228. x = self.dropout(x)
  1229. x = self.dense(x)
  1230. x = torch.tanh(x)
  1231. x = self.dropout(x)
  1232. x = self.out_proj(x)
  1233. return x
  1234. @auto_docstring
  1235. class RobertaForQuestionAnswering(RobertaPreTrainedModel):
  1236. def __init__(self, config):
  1237. super().__init__(config)
  1238. self.num_labels = config.num_labels
  1239. self.roberta = RobertaModel(config, add_pooling_layer=False)
  1240. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1241. # Initialize weights and apply final processing
  1242. self.post_init()
  1243. @auto_docstring
  1244. def forward(
  1245. self,
  1246. input_ids: Optional[torch.LongTensor] = None,
  1247. attention_mask: Optional[torch.FloatTensor] = None,
  1248. token_type_ids: Optional[torch.LongTensor] = None,
  1249. position_ids: Optional[torch.LongTensor] = None,
  1250. head_mask: Optional[torch.FloatTensor] = None,
  1251. inputs_embeds: Optional[torch.FloatTensor] = None,
  1252. start_positions: Optional[torch.LongTensor] = None,
  1253. end_positions: Optional[torch.LongTensor] = None,
  1254. output_attentions: Optional[bool] = None,
  1255. output_hidden_states: Optional[bool] = None,
  1256. return_dict: Optional[bool] = None,
  1257. ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
  1258. r"""
  1259. token_type_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1260. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,1]`:
  1261. - 0 corresponds to a *sentence A* token,
  1262. - 1 corresponds to a *sentence B* token.
  1263. This parameter can only be used when the model is initialized with `type_vocab_size` parameter with value
  1264. >= 2. All the value in this tensor should be always < type_vocab_size.
  1265. [What are token type IDs?](../glossary#token-type-ids)
  1266. """
  1267. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1268. outputs = self.roberta(
  1269. input_ids,
  1270. attention_mask=attention_mask,
  1271. token_type_ids=token_type_ids,
  1272. position_ids=position_ids,
  1273. head_mask=head_mask,
  1274. inputs_embeds=inputs_embeds,
  1275. output_attentions=output_attentions,
  1276. output_hidden_states=output_hidden_states,
  1277. return_dict=return_dict,
  1278. )
  1279. sequence_output = outputs[0]
  1280. logits = self.qa_outputs(sequence_output)
  1281. start_logits, end_logits = logits.split(1, dim=-1)
  1282. start_logits = start_logits.squeeze(-1).contiguous()
  1283. end_logits = end_logits.squeeze(-1).contiguous()
  1284. total_loss = None
  1285. if start_positions is not None and end_positions is not None:
  1286. # If we are on multi-GPU, split add a dimension
  1287. if len(start_positions.size()) > 1:
  1288. start_positions = start_positions.squeeze(-1)
  1289. if len(end_positions.size()) > 1:
  1290. end_positions = end_positions.squeeze(-1)
  1291. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1292. ignored_index = start_logits.size(1)
  1293. start_positions = start_positions.clamp(0, ignored_index)
  1294. end_positions = end_positions.clamp(0, ignored_index)
  1295. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1296. start_loss = loss_fct(start_logits, start_positions)
  1297. end_loss = loss_fct(end_logits, end_positions)
  1298. total_loss = (start_loss + end_loss) / 2
  1299. if not return_dict:
  1300. output = (start_logits, end_logits) + outputs[2:]
  1301. return ((total_loss,) + output) if total_loss is not None else output
  1302. return QuestionAnsweringModelOutput(
  1303. loss=total_loss,
  1304. start_logits=start_logits,
  1305. end_logits=end_logits,
  1306. hidden_states=outputs.hidden_states,
  1307. attentions=outputs.attentions,
  1308. )
  1309. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  1310. """
  1311. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  1312. are ignored. This is modified from fairseq's `utils.make_positions`.
  1313. Args:
  1314. x: torch.Tensor x:
  1315. Returns: torch.Tensor
  1316. """
  1317. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  1318. mask = input_ids.ne(padding_idx).int()
  1319. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  1320. return incremental_indices.long() + padding_idx
  1321. __all__ = [
  1322. "RobertaForCausalLM",
  1323. "RobertaForMaskedLM",
  1324. "RobertaForMultipleChoice",
  1325. "RobertaForQuestionAnswering",
  1326. "RobertaForSequenceClassification",
  1327. "RobertaForTokenClassification",
  1328. "RobertaModel",
  1329. "RobertaPreTrainedModel",
  1330. ]