modeling_xmod.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493149414951496149714981499150015011502150315041505150615071508150915101511151215131514151515161517
  1. # coding=utf-8
  2. # Copyright 2023 Meta AI Team and the HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch X-MOD model."""
  16. import math
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  21. from ...activations import ACT2FN, gelu
  22. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  23. from ...generation import GenerationMixin
  24. from ...modeling_layers import GradientCheckpointingLayer
  25. from ...modeling_outputs import (
  26. BaseModelOutputWithPastAndCrossAttentions,
  27. BaseModelOutputWithPoolingAndCrossAttentions,
  28. CausalLMOutputWithCrossAttentions,
  29. MaskedLMOutput,
  30. MultipleChoiceModelOutput,
  31. QuestionAnsweringModelOutput,
  32. SequenceClassifierOutput,
  33. TokenClassifierOutput,
  34. )
  35. from ...modeling_utils import PreTrainedModel
  36. from ...pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
  37. from ...utils import auto_docstring, logging
  38. from ...utils.deprecation import deprecate_kwarg
  39. from .configuration_xmod import XmodConfig
  40. logger = logging.get_logger(__name__)
  41. # Copied from transformers.models.roberta.modeling_roberta.RobertaEmbeddings with Roberta->Xmod
  class XmodEmbeddings(nn.Module):
      """
      Sum of word, token-type and (optionally) absolute position embeddings, followed by LayerNorm and dropout.

      Same as BertEmbeddings with a tiny tweak for positional embeddings indexing: position ids start at
      `padding_idx + 1`.
      """

      def __init__(self, config):
          super().__init__()
          hidden_size = config.hidden_size
          max_positions = config.max_position_embeddings

          # The first part follows BertEmbeddings.__init__.
          self.word_embeddings = nn.Embedding(config.vocab_size, hidden_size, padding_idx=config.pad_token_id)
          self.position_embeddings = nn.Embedding(max_positions, hidden_size)
          self.token_type_embeddings = nn.Embedding(config.type_vocab_size, hidden_size)

          # `LayerNorm` keeps its TensorFlow-style (non snake-case) name so that TensorFlow checkpoint
          # variable names can still be matched.
          self.LayerNorm = nn.LayerNorm(hidden_size, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          self.position_embedding_type = getattr(config, "position_embedding_type", "absolute")

          # Both buffers are non-persistent: they are rebuilt here instead of being stored in the state dict.
          all_positions = torch.arange(max_positions).expand((1, -1))
          self.register_buffer("position_ids", all_positions, persistent=False)
          zero_token_types = torch.zeros(self.position_ids.size(), dtype=torch.long)
          self.register_buffer("token_type_ids", zero_token_types, persistent=False)

          # X-MOD specific part: the position table is re-created with a padding index.
          self.padding_idx = config.pad_token_id
          self.position_embeddings = nn.Embedding(max_positions, hidden_size, padding_idx=self.padding_idx)

      def forward(
          self, input_ids=None, token_type_ids=None, position_ids=None, inputs_embeds=None, past_key_values_length=0
      ):
          """
          Embed the inputs.

          Either `input_ids` or `inputs_embeds` has to be provided. When `position_ids` is missing it is derived
          from `input_ids` (padded tokens remain padded) or, if only `inputs_embeds` is available, generated
          sequentially. When `token_type_ids` is missing, all zeros are used.
          """
          have_ids = input_ids is not None

          if position_ids is None:
              if have_ids:
                  # Create the position ids from the input token ids. Any padded tokens remain padded.
                  position_ids = create_position_ids_from_input_ids(
                      input_ids, self.padding_idx, past_key_values_length
                  )
              else:
                  position_ids = self.create_position_ids_from_inputs_embeds(inputs_embeds)

          input_shape = input_ids.size() if have_ids else inputs_embeds.size()[:-1]
          seq_length = input_shape[1]

          if token_type_ids is None:
              if hasattr(self, "token_type_ids"):
                  # Use the all-zero buffer registered in the constructor; this helps when tracing the model
                  # without passing token_type_ids (issue #5664).
                  token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)
              else:
                  token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)

          if inputs_embeds is None:
              inputs_embeds = self.word_embeddings(input_ids)

          embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
          if self.position_embedding_type == "absolute":
              embeddings += self.position_embeddings(position_ids)

          return self.dropout(self.LayerNorm(embeddings))

      def create_position_ids_from_inputs_embeds(self, inputs_embeds):
          """
          We are provided embeddings directly. We cannot infer which are padded so just generate sequential
          position ids, starting right after the padding index.

          Args:
              inputs_embeds: torch.Tensor

          Returns: torch.Tensor
          """
          leading_shape = inputs_embeds.size()[:-1]
          first_position = self.padding_idx + 1
          position_ids = torch.arange(
              first_position, leading_shape[1] + first_position, dtype=torch.long, device=inputs_embeds.device
          )
          return position_ids.unsqueeze(0).expand(leading_shape)
  # Adapted from transformers.models.roberta.modeling_roberta.RobertaSelfAttention with Roberta->Xmod
  # (the relative-position query offsets used together with a key/value cache differ, see `forward`).
  class XmodSelfAttention(nn.Module):
      """
      Multi-head scaled dot-product attention, used both as self-attention and (when `encoder_hidden_states` is
      given) as cross-attention. Supports the `"absolute"`, `"relative_key"` and `"relative_key_query"` position
      embedding types and an optional key/value cache.
      """

      def __init__(self, config, position_embedding_type=None, layer_idx=None):
          super().__init__()
          if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
              raise ValueError(
                  f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
                  f"heads ({config.num_attention_heads})"
              )

          self.num_attention_heads = config.num_attention_heads
          self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
          self.all_head_size = self.num_attention_heads * self.attention_head_size

          self.query = nn.Linear(config.hidden_size, self.all_head_size)
          self.key = nn.Linear(config.hidden_size, self.all_head_size)
          self.value = nn.Linear(config.hidden_size, self.all_head_size)

          self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
          self.position_embedding_type = position_embedding_type or getattr(
              config, "position_embedding_type", "absolute"
          )
          if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
              self.max_position_embeddings = config.max_position_embeddings
              # One embedding per signed distance in [-(max_pos - 1), max_pos - 1].
              self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)

          self.is_decoder = config.is_decoder
          self.layer_idx = layer_idx

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          encoder_hidden_states: Optional[torch.FloatTensor] = None,
          past_key_values: Optional[Cache] = None,
          output_attentions: Optional[bool] = False,
          cache_position: Optional[torch.Tensor] = None,
      ) -> tuple[torch.Tensor]:
          """
          Args:
              hidden_states: Query-side input; its first dimension is used as the batch size.
              attention_mask: Optional additive mask that is added to the scaled attention scores.
              head_mask: Optional multiplicative mask applied to the attention probabilities.
              encoder_hidden_states: When given, keys and values are computed from it (cross-attention).
              past_key_values: Optional key/value cache. It is updated in place with this layer's keys/values.
              output_attentions: Accepted for interface compatibility; it is not read here and the attention
                  probabilities are always returned.
              cache_position: Forwarded to the cache update for self-attention only.

          Returns:
              `(context_layer, attention_probs)`.
          """
          batch_size = hidden_states.shape[0]
          head_view = (batch_size, -1, self.num_attention_heads, self.attention_head_size)
          query_layer = self.query(hidden_states).view(*head_view).transpose(1, 2)

          is_updated = False
          is_cross_attention = encoder_hidden_states is not None
          if past_key_values is not None:
              if isinstance(past_key_values, EncoderDecoderCache):
                  is_updated = past_key_values.is_updated.get(self.layer_idx)
                  if is_cross_attention:
                      # after the first generated id, we can subsequently re-use all key/value_layer from cache
                      curr_past_key_value = past_key_values.cross_attention_cache
                  else:
                      curr_past_key_value = past_key_values.self_attention_cache
              else:
                  curr_past_key_value = past_key_values

          current_states = encoder_hidden_states if is_cross_attention else hidden_states
          if is_cross_attention and past_key_values is not None and is_updated:
              # reuse k,v, cross_attentions
              key_layer = curr_past_key_value.layers[self.layer_idx].keys
              value_layer = curr_past_key_value.layers[self.layer_idx].values
          else:
              key_layer = self.key(current_states).view(*head_view).transpose(1, 2)
              value_layer = self.value(current_states).view(*head_view).transpose(1, 2)

              if past_key_values is not None:
                  # save all key/value_layer to cache to be re-used for fast auto-regressive generation
                  cache_position = cache_position if not is_cross_attention else None
                  key_layer, value_layer = curr_past_key_value.update(
                      key_layer, value_layer, self.layer_idx, {"cache_position": cache_position}
                  )
                  # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                  if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                      past_key_values.is_updated[self.layer_idx] = True

          # Take the dot product between "query" and "key" to get the raw attention scores.
          attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

          if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
              query_length, key_length = query_layer.shape[2], key_layer.shape[2]
              if past_key_values is not None:
                  # With a cache the queries are the last `query_length` positions of the key sequence. For a
                  # single new token this is `key_length - 1`; for several tokens at once (e.g. prompt prefill)
                  # each query keeps its own position instead of all of them collapsing onto the last one.
                  position_ids_l = torch.arange(
                      key_length - query_length, key_length, dtype=torch.long, device=hidden_states.device
                  ).view(-1, 1)
              else:
                  position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
              position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
              distance = position_ids_l - position_ids_r

              # Shift signed distances into the non-negative index range of the embedding table.
              positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
              positional_embedding = positional_embedding.to(dtype=query_layer.dtype)  # fp16 compatibility

              if self.position_embedding_type == "relative_key":
                  relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                  attention_scores = attention_scores + relative_position_scores
              elif self.position_embedding_type == "relative_key_query":
                  relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
                  relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
                  attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key

          attention_scores = attention_scores / math.sqrt(self.attention_head_size)
          if attention_mask is not None:
              # Apply the attention mask (precomputed for all layers in XmodModel forward() function)
              attention_scores = attention_scores + attention_mask

          # Normalize the attention scores to probabilities.
          attention_probs = nn.functional.softmax(attention_scores, dim=-1)

          # This is actually dropping out entire tokens to attend to, which might
          # seem a bit unusual, but is taken from the original Transformer paper.
          attention_probs = self.dropout(attention_probs)

          # Mask heads if we want to
          if head_mask is not None:
              attention_probs = attention_probs * head_mask

          context_layer = torch.matmul(attention_probs, value_layer)

          # Merge the head dimension back into the feature dimension.
          context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
          new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
          context_layer = context_layer.view(new_context_layer_shape)

          return context_layer, attention_probs
  class XmodSelfOutput(nn.Module):
      """
      Output projection of the attention block: dense -> dropout -> residual add.

      The `LayerNorm` module is created here but is not applied in `forward`.
      """

      def __init__(self, config):
          super().__init__()
          width = config.hidden_size
          self.dense = nn.Linear(width, width)
          self.LayerNorm = nn.LayerNorm(width, eps=config.layer_norm_eps)
          self.dropout = nn.Dropout(config.hidden_dropout_prob)

      def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor) -> torch.Tensor:
          """Project `hidden_states`, apply dropout and add the residual `input_tensor`."""
          projected = self.dropout(self.dense(hidden_states))
          return projected + input_tensor
  class XmodAttention(nn.Module):
      """
      Attention sub-layer: `XmodSelfAttention` followed by `XmodSelfOutput`, with the LayerNorm applied either
      before the attention (`config.pre_norm`) or after the residual connection.
      """

      def __init__(self, config, position_embedding_type=None, layer_idx=None):
          super().__init__()
          self.self = XmodSelfAttention(config, position_embedding_type=position_embedding_type, layer_idx=layer_idx)
          self.output = XmodSelfOutput(config)
          self.pruned_heads = set()
          self.pre_norm = config.pre_norm

      def prune_heads(self, heads):
          """Remove the given attention heads from the query/key/value and output projections."""
          if len(heads) == 0:
              return

          attn = self.self
          heads, index = find_pruneable_heads_and_indices(
              heads, attn.num_attention_heads, attn.attention_head_size, self.pruned_heads
          )

          # Prune linear layers
          attn.query = prune_linear_layer(attn.query, index)
          attn.key = prune_linear_layer(attn.key, index)
          attn.value = prune_linear_layer(attn.value, index)
          self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)

          # Update hyper params and store pruned heads
          attn.num_attention_heads = attn.num_attention_heads - len(heads)
          attn.all_head_size = attn.attention_head_size * attn.num_attention_heads
          self.pruned_heads = self.pruned_heads.union(heads)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          encoder_hidden_states: Optional[torch.FloatTensor] = None,
          past_key_values: Optional[Cache] = None,
          output_attentions: Optional[bool] = False,
          cache_position: Optional[torch.Tensor] = None,
      ) -> tuple[torch.Tensor]:
          """
          Run attention with a residual connection around it.

          Returns a tuple whose first element is the attention output, followed by whatever else the inner
          attention module returned.
          """
          # The residual always uses the un-normalized input.
          residual = hidden_states
          attention_input = self.output.LayerNorm(hidden_states) if self.pre_norm else hidden_states

          self_outputs = self.self(
              attention_input,
              attention_mask,
              head_mask,
              encoder_hidden_states,
              past_key_values,
              output_attentions,
              cache_position,
          )

          attention_output = self.output(self_outputs[0], residual)
          if not self.pre_norm:
              # Post-norm: normalize after the residual connection.
              attention_output = self.output.LayerNorm(attention_output)

          # add attentions if we output them
          return (attention_output,) + self_outputs[1:]
  292. # Copied from transformers.models.roberta.modeling_roberta.RobertaIntermediate
  class XmodIntermediate(nn.Module):
      """Feed-forward expansion: a linear layer from hidden to intermediate size followed by an activation."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.intermediate_size)
          activation = config.hidden_act
          # A string is looked up in ACT2FN; anything else is used as the activation itself.
          self.intermediate_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Return `act(dense(hidden_states))`."""
          return self.intermediate_act_fn(self.dense(hidden_states))
  class XmodAdapter(nn.Module):
      """
      Bottleneck adapter: down-projection, activation, up-projection.

      The bottleneck width is `config.hidden_size // config.adapter_reduction_factor`.
      """

      def __init__(self, config):
          super().__init__()
          width = config.hidden_size
          self.bottleneck_size = width // config.adapter_reduction_factor
          self.dense1 = nn.Linear(width, self.bottleneck_size)
          self.dense2 = nn.Linear(self.bottleneck_size, width)
          activation = config.hidden_act
          # A string is looked up in ACT2FN; anything else is used as the activation itself.
          self.adapter_act_fn = ACT2FN[activation] if isinstance(activation, str) else activation

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Return `dense2(act(dense1(hidden_states)))`."""
          bottleneck = self.adapter_act_fn(self.dense1(hidden_states))
          return self.dense2(bottleneck)
  class XmodOutput(nn.Module):
      """
      Output part of the feed-forward block, followed by the per-language adapter.

      `forward` projects back to the hidden size, applies dropout, adds the residual and then routes every
      sample through the adapter selected by its language id. The `LayerNorm` module created here is only used
      inside this class when `config.adapter_reuse_layer_norm` is set and no dedicated adapter LayerNorm exists.
      """

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.intermediate_size, config.hidden_size)
          self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.ln_before_adapter = config.ln_before_adapter
          self.dropout = nn.Dropout(config.hidden_dropout_prob)
          if config.adapter_layer_norm:
              self.adapter_layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          else:
              self.adapter_layer_norm = None
          self.adapter_reuse_layer_norm = config.adapter_reuse_layer_norm
          # One adapter per configured language; the insertion order defines the language id of each adapter.
          self.adapter_modules = nn.ModuleDict({})
          for language in config.languages:
              self.adapter_modules[str(language)] = XmodAdapter(config)

      def forward(self, hidden_states: torch.Tensor, input_tensor: torch.Tensor, lang_ids: torch.Tensor) -> torch.Tensor:
          """
          Args:
              hidden_states: Output of the intermediate (expansion) layer.
              input_tensor: Residual that is added after the projection and dropout.
              lang_ids: Language id of each sample along dimension 0.

          Returns:
              The hidden states after the residual connection and the language adapter.
          """
          hidden_states = self.dense(hidden_states)
          hidden_states = self.dropout(hidden_states)
          hidden_states = hidden_states + input_tensor
          hidden_states = self.lang_adapter(lang_ids, hidden_states)
          return hidden_states

      def lang_adapter(self, lang_ids: torch.Tensor, hidden_states: torch.Tensor):
          """
          Apply the language adapters with a residual connection.

          Consecutive samples that share a language id are processed together.

          Raises:
              IndexError: If a language id does not refer to one of the adapters. Negative ids are rejected
                  explicitly so that they cannot silently select an adapter via negative indexing.
          """
          # Process subsequent samples with the same lang_id in parallel
          lang_ids, lang_lengths = torch.unique_consecutive(lang_ids, return_counts=True)

          # The residual is taken before or after the normalization, depending on `ln_before_adapter`.
          unnormalized = hidden_states
          if self.adapter_layer_norm is not None:
              hidden_states = self.adapter_layer_norm(hidden_states)
          elif self.adapter_reuse_layer_norm:
              hidden_states = self.LayerNorm(hidden_states)
          residual = hidden_states if self.ln_before_adapter else unnormalized

          # Built once instead of on every loop iteration.
          adapter_names = list(self.adapter_modules.keys())
          split_hidden_states = torch.split(hidden_states, lang_lengths.tolist(), 0)

          lang_wise_outputs = []
          for raw_id, split_hidden_state in zip(lang_ids.tolist(), split_hidden_states):
              adapter_index = int(raw_id)
              if not 0 <= adapter_index < len(adapter_names):
                  raise IndexError(
                      f"lang_id {adapter_index} is out of range for {len(adapter_names)} language adapter(s)"
                  )
              adapter = self.adapter_modules[adapter_names[adapter_index]]
              lang_wise_outputs.append(adapter(split_hidden_state))

          hidden_states = torch.cat(lang_wise_outputs, 0)
          hidden_states = self.dropout(hidden_states)
          return hidden_states + residual
  class XmodLayer(GradientCheckpointingLayer):
      """
      One transformer layer: self-attention, optional cross-attention (decoder only), feed-forward expansion
      and the language-adapter output block. `config.pre_norm` selects pre- or post-LayerNorm.
      """

      def __init__(self, config, layer_idx=None):
          super().__init__()
          self.chunk_size_feed_forward = config.chunk_size_feed_forward
          self.seq_len_dim = 1
          self.attention = XmodAttention(config, layer_idx=layer_idx)
          self.is_decoder = config.is_decoder
          self.add_cross_attention = config.add_cross_attention
          if self.add_cross_attention and not self.is_decoder:
              raise ValueError(f"{self} should be used as a decoder model if cross attention is added")
          if self.add_cross_attention:
              self.crossattention = XmodAttention(config, position_embedding_type="absolute", layer_idx=layer_idx)
          self.intermediate = XmodIntermediate(config)
          self.output = XmodOutput(config)
          self.pre_norm = config.pre_norm

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states: torch.Tensor,
          lang_ids: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          encoder_hidden_states: Optional[torch.FloatTensor] = None,
          encoder_attention_mask: Optional[torch.FloatTensor] = None,
          past_key_values: Optional[Cache] = None,
          output_attentions: Optional[bool] = False,
          cache_position: Optional[torch.Tensor] = None,
      ) -> tuple[torch.Tensor]:
          """
          Run the layer.

          Returns a tuple whose first element is the layer output, followed by the extra outputs of the
          self-attention and, when cross-attention ran, of the cross-attention.

          Raises:
              ValueError: If `encoder_hidden_states` is passed to a decoder layer that has no cross-attention.
          """
          attn_results = self.attention(
              hidden_states,
              attention_mask=attention_mask,
              head_mask=head_mask,
              output_attentions=output_attentions,
              past_key_values=past_key_values,
              cache_position=cache_position,
          )
          attention_output = attn_results[0]
          # add self attentions if we output attention weights
          outputs = attn_results[1:]

          wants_cross_attention = self.is_decoder and encoder_hidden_states is not None
          if wants_cross_attention:
              if not hasattr(self, "crossattention"):
                  raise ValueError(
                      f"If `encoder_hidden_states` are passed, {self} has to be instantiated with cross-attention layers"
                      " by setting `config.add_cross_attention=True`"
                  )
              cross_results = self.crossattention(
                  attention_output,
                  attention_mask=encoder_attention_mask,
                  head_mask=head_mask,
                  encoder_hidden_states=encoder_hidden_states,
                  past_key_values=past_key_values,
                  output_attentions=output_attentions,
                  cache_position=cache_position,
              )
              attention_output = cross_results[0]
              # add cross attentions if we output attention weights
              outputs = outputs + cross_results[1:]

          # The residual of the feed-forward block always uses the un-normalized attention output.
          residual = attention_output
          ffn_input = self.output.LayerNorm(attention_output) if self.pre_norm else attention_output
          intermediate_output = apply_chunking_to_forward(
              self.feed_forward_chunk,
              self.chunk_size_feed_forward,
              self.seq_len_dim,
              ffn_input,
          )

          layer_output = self.output(intermediate_output, residual, lang_ids)
          if not self.pre_norm:
              layer_output = self.output.LayerNorm(layer_output)
          return (layer_output,) + outputs

      def feed_forward_chunk(self, attention_output):
          """Apply the intermediate (expansion) layer to one chunk."""
          return self.intermediate(attention_output)
  class XmodEncoder(nn.Module):
      """
      Stack of `config.num_hidden_layers` `XmodLayer` modules. With `config.pre_norm` a final LayerNorm is
      applied to the output of the last layer.
      """

      def __init__(self, config):
          super().__init__()
          self.config = config
          self.layer = nn.ModuleList([XmodLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)])
          self.is_pre_norm = config.pre_norm
          if self.is_pre_norm:
              self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
          self.gradient_checkpointing = False

      def forward(
          self,
          hidden_states: torch.Tensor,
          lang_ids: torch.Tensor,
          attention_mask: Optional[torch.FloatTensor] = None,
          head_mask: Optional[torch.FloatTensor] = None,
          encoder_hidden_states: Optional[torch.FloatTensor] = None,
          encoder_attention_mask: Optional[torch.FloatTensor] = None,
          past_key_values: Optional[Cache] = None,
          use_cache: Optional[bool] = None,
          output_attentions: Optional[bool] = False,
          output_hidden_states: Optional[bool] = False,
          return_dict: Optional[bool] = True,
          cache_position: Optional[torch.Tensor] = None,
      ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPastAndCrossAttentions]:
          """
          Run all layers in order.

          Args:
              hidden_states: Input to the first layer.
              lang_ids: Language ids, passed unchanged to every layer.
              attention_mask: Passed unchanged to every layer.
              head_mask: Indexed per layer (`head_mask[i]`) when given.
              encoder_hidden_states / encoder_attention_mask: Passed unchanged to every layer.
              past_key_values: Key/value cache. A new `EncoderDecoderCache` is created when `use_cache` is set
                  and none is given; a legacy tuple is converted (deprecated).
              use_cache: Disabled (with a warning) while training with gradient checkpointing.
              output_attentions: Collect the attention outputs of every layer.
              output_hidden_states: Collect the input of every layer plus the final output.
              return_dict: Return a `BaseModelOutputWithPastAndCrossAttentions` instead of a tuple.
              cache_position: Passed unchanged to every layer.

          Returns:
              A `BaseModelOutputWithPastAndCrossAttentions`, or a tuple of its non-`None` fields when
              `return_dict` is falsy.
          """
          if self.gradient_checkpointing and self.training:
              if use_cache:
                  logger.warning_once(
                      "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                  )
                  use_cache = False

          if use_cache and past_key_values is None:
              past_key_values = EncoderDecoderCache(DynamicCache(config=self.config), DynamicCache(config=self.config))
          if use_cache and isinstance(past_key_values, tuple):
              logger.warning_once(
                  "Passing a tuple of `past_key_values` is deprecated and will be removed in Transformers v4.58.0. "
                  "You should pass an instance of `EncoderDecoderCache` instead, e.g. "
                  "`past_key_values=EncoderDecoderCache.from_legacy_cache(past_key_values)`."
              )
              past_key_values = EncoderDecoderCache.from_legacy_cache(past_key_values)

          all_hidden_states = () if output_hidden_states else None
          all_self_attentions = () if output_attentions else None
          all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None

          for i, layer_module in enumerate(self.layer):
              if output_hidden_states:
                  all_hidden_states = all_hidden_states + (hidden_states,)

              layer_head_mask = head_mask[i] if head_mask is not None else None

              layer_outputs = layer_module(
                  hidden_states,
                  lang_ids,
                  attention_mask,
                  layer_head_mask,
                  encoder_hidden_states,
                  encoder_attention_mask,
                  past_key_values,
                  output_attentions,
                  cache_position,
              )

              hidden_states = layer_outputs[0]
              if output_attentions:
                  all_self_attentions = all_self_attentions + (layer_outputs[1],)
                  # A layer only returns cross-attention weights when cross-attention actually ran, i.e. when
                  # `encoder_hidden_states` was provided. Without this length check a model configured with
                  # `add_cross_attention=True` but called without encoder states fails with an IndexError.
                  if self.config.add_cross_attention and len(layer_outputs) > 2:
                      all_cross_attentions = all_cross_attentions + (layer_outputs[2],)

          if self.is_pre_norm:
              # Pre-norm layers leave their output un-normalized, so normalize once at the end.
              hidden_states = self.LayerNorm(hidden_states)

          if output_hidden_states:
              all_hidden_states = all_hidden_states + (hidden_states,)

          if not return_dict:
              return tuple(
                  v
                  for v in [
                      hidden_states,
                      past_key_values,
                      all_hidden_states,
                      all_self_attentions,
                      all_cross_attentions,
                  ]
                  if v is not None
              )
          return BaseModelOutputWithPastAndCrossAttentions(
              last_hidden_state=hidden_states,
              past_key_values=past_key_values,
              hidden_states=all_hidden_states,
              attentions=all_self_attentions,
              cross_attentions=all_cross_attentions,
          )
  516. # Copied from transformers.models.roberta.modeling_roberta.RobertaPooler
  class XmodPooler(nn.Module):
      """Pool a sequence by passing the hidden state of its first token through a dense layer and tanh."""

      def __init__(self, config):
          super().__init__()
          self.dense = nn.Linear(config.hidden_size, config.hidden_size)
          self.activation = nn.Tanh()

      def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
          """Return `tanh(dense(hidden_states[:, 0]))`."""
          # "Pooling" here simply means selecting the hidden state of the first token.
          first_token = hidden_states[:, 0]
          return self.activation(self.dense(first_token))
  529. @auto_docstring
  530. class XmodPreTrainedModel(PreTrainedModel):
  531. config: XmodConfig
  532. base_model_prefix = "roberta"
  533. supports_gradient_checkpointing = True
  534. # Copied from transformers.models.bert.modeling_bert.BertPreTrainedModel._init_weights with BertLMPredictionHead->XmodLMHead
  535. def _init_weights(self, module):
  536. """Initialize the weights"""
  537. if isinstance(module, nn.Linear):
  538. # Slightly different from the TF version which uses truncated_normal for initialization
  539. # cf https://github.com/pytorch/pytorch/pull/5617
  540. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  541. if module.bias is not None:
  542. module.bias.data.zero_()
  543. elif isinstance(module, nn.Embedding):
  544. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  545. if module.padding_idx is not None:
  546. module.weight.data[module.padding_idx].zero_()
  547. elif isinstance(module, nn.LayerNorm):
  548. module.bias.data.zero_()
  549. module.weight.data.fill_(1.0)
  550. elif isinstance(module, XmodLMHead):
  551. module.bias.data.zero_()
  552. def set_default_language(self, language: str):
  553. """
  554. Set the default language code for the model. This is used when the language is not specified in the input.
  555. Args:
  556. language (`str`): The language code, such as `"en_XX"` or `"de_DE"`.
  557. """
  558. if language not in self.config.languages:
  559. raise ValueError(
  560. f"{self} does not have an adapter for {language}. Supported languages: {list(self.config.languages)}"
  561. )
  562. self.config.default_language = language
  563. def freeze_embeddings_and_language_adapters(self):
  564. """
  565. Freeze the embeddings and language adapters of the model. Usually, this is applied before the model is
  566. fine-tuned on a downstream task.
  567. """
  568. logger.info("Freezing embeddings")
  569. for parameter in self.roberta.embeddings.parameters():
  570. parameter.requires_grad = False
  571. logger.info("Freezing adapters")
  572. for layer in self.roberta.encoder.layer:
  573. if layer.output.adapter_layer_norm is not None:
  574. for parameter in layer.output.adapter_layer_norm.parameters():
  575. parameter.requires_grad = False
  576. for parameter in layer.output.adapter_modules.parameters():
  577. parameter.requires_grad = False
  578. @auto_docstring(
  579. custom_intro="""
  580. The model can behave as an encoder (with only self-attention) as well as a decoder, in which case a layer of
  581. cross-attention is added between the self-attention layers, following the architecture described in *Attention is
  582. all you need*_ by Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz
  583. Kaiser and Illia Polosukhin.
  584. To behave as an decoder the model needs to be initialized with the `is_decoder` argument of the configuration set
  585. to `True`. To be used in a Seq2Seq model, the model needs to initialized with both `is_decoder` argument and
  586. `add_cross_attention` set to `True`; an `encoder_hidden_states` is then expected as an input to the forward pass.
  587. .. _*Attention is all you need*: https://huggingface.co/papers/1706.03762
  588. """
  589. )
  590. class XmodModel(XmodPreTrainedModel):
  591. # Copied from transformers.models.clap.modeling_clap.ClapTextModel.__init__ with ClapText->Xmod
  592. def __init__(self, config, add_pooling_layer=True):
  593. r"""
  594. add_pooling_layer (bool, *optional*, defaults to `True`):
  595. Whether to add a pooling layer
  596. """
  597. super().__init__(config)
  598. self.config = config
  599. self.embeddings = XmodEmbeddings(config)
  600. self.encoder = XmodEncoder(config)
  601. self.pooler = XmodPooler(config) if add_pooling_layer else None
  602. # Initialize weights and apply final processing
  603. self.post_init()
  604. # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.get_input_embeddings
  605. def get_input_embeddings(self):
  606. return self.embeddings.word_embeddings
  607. # Copied from transformers.models.roberta.modeling_roberta.RobertaModel.set_input_embeddings
  608. def set_input_embeddings(self, value):
  609. self.embeddings.word_embeddings = value
  610. # Copied from transformers.models.roberta.modeling_roberta.RobertaModel._prune_heads
  611. def _prune_heads(self, heads_to_prune):
  612. """
  613. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  614. class PreTrainedModel
  615. """
  616. for layer, heads in heads_to_prune.items():
  617. self.encoder.layer[layer].attention.prune_heads(heads)
  618. @auto_docstring
  619. def forward(
  620. self,
  621. input_ids: Optional[torch.Tensor] = None,
  622. lang_ids: Optional[torch.LongTensor] = None,
  623. attention_mask: Optional[torch.Tensor] = None,
  624. token_type_ids: Optional[torch.Tensor] = None,
  625. position_ids: Optional[torch.Tensor] = None,
  626. head_mask: Optional[torch.Tensor] = None,
  627. inputs_embeds: Optional[torch.Tensor] = None,
  628. encoder_hidden_states: Optional[torch.Tensor] = None,
  629. encoder_attention_mask: Optional[torch.Tensor] = None,
  630. past_key_values: Optional[Cache] = None,
  631. use_cache: Optional[bool] = None,
  632. output_attentions: Optional[bool] = None,
  633. output_hidden_states: Optional[bool] = None,
  634. return_dict: Optional[bool] = None,
  635. cache_position: Optional[torch.Tensor] = None,
  636. ) -> Union[tuple[torch.Tensor], BaseModelOutputWithPoolingAndCrossAttentions]:
  637. r"""
  638. lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  639. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  640. that corresponds to `self.config.default_language`.
  641. """
  642. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  643. output_hidden_states = (
  644. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  645. )
  646. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  647. if self.config.is_decoder:
  648. use_cache = use_cache if use_cache is not None else self.config.use_cache
  649. else:
  650. use_cache = False
  651. if input_ids is not None and inputs_embeds is not None:
  652. raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
  653. elif input_ids is not None:
  654. self.warn_if_padding_and_no_attention_mask(input_ids, attention_mask)
  655. input_shape = input_ids.size()
  656. elif inputs_embeds is not None:
  657. input_shape = inputs_embeds.size()[:-1]
  658. else:
  659. raise ValueError("You have to specify either input_ids or inputs_embeds")
  660. batch_size, seq_length = input_shape
  661. device = input_ids.device if input_ids is not None else inputs_embeds.device
  662. past_key_values_length = 0
  663. if past_key_values is not None:
  664. past_key_values_length = (
  665. past_key_values[0][0].shape[-2]
  666. if not isinstance(past_key_values, Cache)
  667. else past_key_values.get_seq_length()
  668. )
  669. if lang_ids is None:
  670. if self.config.default_language is None:
  671. raise ValueError("Input language unknown. Please call `XmodPreTrainedModel.set_default_language()`")
  672. adapter_languages = list(self.encoder.layer[0].output.adapter_modules.keys())
  673. default_lang_id = adapter_languages.index(self.config.default_language)
  674. lang_ids = default_lang_id * torch.ones(batch_size, device=device)
  675. if attention_mask is None:
  676. attention_mask = torch.ones(((batch_size, seq_length + past_key_values_length)), device=device)
  677. if token_type_ids is None:
  678. if hasattr(self.embeddings, "token_type_ids"):
  679. buffered_token_type_ids = self.embeddings.token_type_ids[:, :seq_length]
  680. buffered_token_type_ids_expanded = buffered_token_type_ids.expand(batch_size, seq_length)
  681. token_type_ids = buffered_token_type_ids_expanded
  682. else:
  683. token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=device)
  684. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  685. # ourselves in which case we just need to make it broadcastable to all heads.
  686. extended_attention_mask: torch.Tensor = self.get_extended_attention_mask(attention_mask, input_shape)
  687. # If a 2D or 3D attention mask is provided for the cross-attention
  688. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  689. if self.config.is_decoder and encoder_hidden_states is not None:
  690. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  691. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  692. if encoder_attention_mask is None:
  693. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=device)
  694. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  695. else:
  696. encoder_extended_attention_mask = None
  697. # Prepare head mask if needed
  698. # 1.0 in head_mask indicate we keep the head
  699. # attention_probs has shape bsz x n_heads x N x N
  700. # input head_mask has shape [num_heads] or [num_hidden_layers x num_heads]
  701. # and head_mask is converted to shape [num_hidden_layers x batch x num_heads x seq_length x seq_length]
  702. head_mask = self.get_head_mask(head_mask, self.config.num_hidden_layers)
  703. embedding_output = self.embeddings(
  704. input_ids=input_ids,
  705. position_ids=position_ids,
  706. token_type_ids=token_type_ids,
  707. inputs_embeds=inputs_embeds,
  708. past_key_values_length=past_key_values_length,
  709. )
  710. encoder_outputs = self.encoder(
  711. embedding_output,
  712. lang_ids=lang_ids,
  713. attention_mask=extended_attention_mask,
  714. head_mask=head_mask,
  715. encoder_hidden_states=encoder_hidden_states,
  716. encoder_attention_mask=encoder_extended_attention_mask,
  717. past_key_values=past_key_values,
  718. use_cache=use_cache,
  719. output_attentions=output_attentions,
  720. output_hidden_states=output_hidden_states,
  721. return_dict=return_dict,
  722. cache_position=cache_position,
  723. )
  724. sequence_output = encoder_outputs[0]
  725. pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
  726. if not return_dict:
  727. return (sequence_output, pooled_output) + encoder_outputs[1:]
  728. return BaseModelOutputWithPoolingAndCrossAttentions(
  729. last_hidden_state=sequence_output,
  730. pooler_output=pooled_output,
  731. past_key_values=encoder_outputs.past_key_values,
  732. hidden_states=encoder_outputs.hidden_states,
  733. attentions=encoder_outputs.attentions,
  734. cross_attentions=encoder_outputs.cross_attentions,
  735. )
  736. @auto_docstring(
  737. custom_intro="""
  738. X-MOD Model with a `language modeling` head on top for CLM fine-tuning.
  739. """
  740. )
  741. class XmodForCausalLM(XmodPreTrainedModel, GenerationMixin):
  742. _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
  743. # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.__init__ with Roberta->Xmod
  744. def __init__(self, config):
  745. super().__init__(config)
  746. if not config.is_decoder:
  747. logger.warning("If you want to use `XmodLMHeadModel` as a standalone, add `is_decoder=True.`")
  748. self.roberta = XmodModel(config, add_pooling_layer=False)
  749. self.lm_head = XmodLMHead(config)
  750. # Initialize weights and apply final processing
  751. self.post_init()
  752. # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.get_output_embeddings
  753. def get_output_embeddings(self):
  754. return self.lm_head.decoder
  755. # Copied from transformers.models.roberta.modeling_roberta.RobertaForCausalLM.set_output_embeddings
  756. def set_output_embeddings(self, new_embeddings):
  757. self.lm_head.decoder = new_embeddings
  758. @auto_docstring
  759. def forward(
  760. self,
  761. input_ids: Optional[torch.LongTensor] = None,
  762. lang_ids: Optional[torch.LongTensor] = None,
  763. attention_mask: Optional[torch.FloatTensor] = None,
  764. token_type_ids: Optional[torch.LongTensor] = None,
  765. position_ids: Optional[torch.LongTensor] = None,
  766. head_mask: Optional[torch.FloatTensor] = None,
  767. inputs_embeds: Optional[torch.FloatTensor] = None,
  768. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  769. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  770. labels: Optional[torch.LongTensor] = None,
  771. past_key_values: Optional[Cache] = None,
  772. use_cache: Optional[bool] = None,
  773. output_attentions: Optional[bool] = None,
  774. output_hidden_states: Optional[bool] = None,
  775. return_dict: Optional[bool] = None,
  776. cache_position: Optional[torch.Tensor] = None,
  777. **kwargs,
  778. ) -> Union[tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
  779. r"""
  780. lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  781. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  782. that corresponds to `self.config.default_language`.
  783. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  784. Labels for computing the left-to-right language modeling loss (next word prediction). Indices should be in
  785. `[-100, 0, ..., config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are
  786. ignored (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  787. Example:
  788. ```python
  789. >>> from transformers import AutoTokenizer, XmodForCausalLM, AutoConfig
  790. >>> import torch
  791. >>> tokenizer = AutoTokenizer.from_pretrained("FacebookAI/xlm-roberta-base")
  792. >>> config = AutoConfig.from_pretrained("facebook/xmod-base")
  793. >>> config.is_decoder = True
  794. >>> model = XmodForCausalLM.from_pretrained("facebook/xmod-base", config=config)
  795. >>> model.set_default_language("en_XX")
  796. >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
  797. >>> outputs = model(**inputs)
  798. >>> prediction_logits = outputs.logits
  799. ```"""
  800. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  801. if labels is not None:
  802. use_cache = False
  803. outputs = self.roberta(
  804. input_ids,
  805. lang_ids=lang_ids,
  806. attention_mask=attention_mask,
  807. token_type_ids=token_type_ids,
  808. position_ids=position_ids,
  809. head_mask=head_mask,
  810. inputs_embeds=inputs_embeds,
  811. encoder_hidden_states=encoder_hidden_states,
  812. encoder_attention_mask=encoder_attention_mask,
  813. past_key_values=past_key_values,
  814. use_cache=use_cache,
  815. output_attentions=output_attentions,
  816. output_hidden_states=output_hidden_states,
  817. return_dict=return_dict,
  818. cache_position=cache_position,
  819. )
  820. sequence_output = outputs[0]
  821. prediction_scores = self.lm_head(sequence_output)
  822. lm_loss = None
  823. if labels is not None:
  824. lm_loss = self.loss_function(
  825. prediction_scores,
  826. labels,
  827. vocab_size=self.config.vocab_size,
  828. **kwargs,
  829. )
  830. if not return_dict:
  831. output = (prediction_scores,) + outputs[2:]
  832. return ((lm_loss,) + output) if lm_loss is not None else output
  833. return CausalLMOutputWithCrossAttentions(
  834. loss=lm_loss,
  835. logits=prediction_scores,
  836. past_key_values=outputs.past_key_values,
  837. hidden_states=outputs.hidden_states,
  838. attentions=outputs.attentions,
  839. cross_attentions=outputs.cross_attentions,
  840. )
  841. @auto_docstring
  842. class XmodForMaskedLM(XmodPreTrainedModel):
  843. _tied_weights_keys = ["lm_head.decoder.weight", "lm_head.decoder.bias"]
  844. # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.__init__ with Roberta->Xmod
  845. def __init__(self, config):
  846. super().__init__(config)
  847. if config.is_decoder:
  848. logger.warning(
  849. "If you want to use `XmodForMaskedLM` make sure `config.is_decoder=False` for "
  850. "bi-directional self-attention."
  851. )
  852. self.roberta = XmodModel(config, add_pooling_layer=False)
  853. self.lm_head = XmodLMHead(config)
  854. # Initialize weights and apply final processing
  855. self.post_init()
  856. # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.get_output_embeddings
  857. def get_output_embeddings(self):
  858. return self.lm_head.decoder
  859. # Copied from transformers.models.roberta.modeling_roberta.RobertaForMaskedLM.set_output_embeddings
  860. def set_output_embeddings(self, new_embeddings):
  861. self.lm_head.decoder = new_embeddings
  862. @auto_docstring
  863. def forward(
  864. self,
  865. input_ids: Optional[torch.LongTensor] = None,
  866. lang_ids: Optional[torch.LongTensor] = None,
  867. attention_mask: Optional[torch.FloatTensor] = None,
  868. token_type_ids: Optional[torch.LongTensor] = None,
  869. position_ids: Optional[torch.LongTensor] = None,
  870. head_mask: Optional[torch.FloatTensor] = None,
  871. inputs_embeds: Optional[torch.FloatTensor] = None,
  872. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  873. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  874. labels: Optional[torch.LongTensor] = None,
  875. output_attentions: Optional[bool] = None,
  876. output_hidden_states: Optional[bool] = None,
  877. return_dict: Optional[bool] = None,
  878. ) -> Union[tuple[torch.Tensor], MaskedLMOutput]:
  879. r"""
  880. lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  881. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  882. that corresponds to `self.config.default_language`.
  883. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  884. Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
  885. config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked), the
  886. loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
  887. """
  888. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  889. outputs = self.roberta(
  890. input_ids,
  891. lang_ids=lang_ids,
  892. attention_mask=attention_mask,
  893. token_type_ids=token_type_ids,
  894. position_ids=position_ids,
  895. head_mask=head_mask,
  896. inputs_embeds=inputs_embeds,
  897. encoder_hidden_states=encoder_hidden_states,
  898. encoder_attention_mask=encoder_attention_mask,
  899. output_attentions=output_attentions,
  900. output_hidden_states=output_hidden_states,
  901. return_dict=return_dict,
  902. )
  903. sequence_output = outputs[0]
  904. prediction_scores = self.lm_head(sequence_output)
  905. masked_lm_loss = None
  906. if labels is not None:
  907. loss_fct = CrossEntropyLoss()
  908. masked_lm_loss = loss_fct(prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
  909. if not return_dict:
  910. output = (prediction_scores,) + outputs[2:]
  911. return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
  912. return MaskedLMOutput(
  913. loss=masked_lm_loss,
  914. logits=prediction_scores,
  915. hidden_states=outputs.hidden_states,
  916. attentions=outputs.attentions,
  917. )
  918. # Copied from transformers.models.roberta.modeling_roberta.RobertaLMHead
  919. class XmodLMHead(nn.Module):
  920. """Roberta Head for masked language modeling."""
  921. def __init__(self, config):
  922. super().__init__()
  923. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  924. self.layer_norm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)
  925. self.decoder = nn.Linear(config.hidden_size, config.vocab_size)
  926. self.bias = nn.Parameter(torch.zeros(config.vocab_size))
  927. self.decoder.bias = self.bias
  928. def forward(self, features, **kwargs):
  929. x = self.dense(features)
  930. x = gelu(x)
  931. x = self.layer_norm(x)
  932. # project back to size of vocabulary with bias
  933. x = self.decoder(x)
  934. return x
  935. def _tie_weights(self):
  936. # To tie those two weights if they get disconnected (on TPU or when the bias is resized)
  937. # For accelerate compatibility and to not break backward compatibility
  938. if self.decoder.bias.device.type == "meta":
  939. self.decoder.bias = self.bias
  940. else:
  941. self.bias = self.decoder.bias
  942. @auto_docstring(
  943. custom_intro="""
  944. X-MOD Model transformer with a sequence classification/regression head on top (a linear layer on top of the pooled
  945. output) e.g. for GLUE tasks.
  946. """
  947. )
  948. class XmodForSequenceClassification(XmodPreTrainedModel):
  949. # Copied from transformers.models.roberta.modeling_roberta.RobertaForSequenceClassification.__init__ with Roberta->Xmod
  950. def __init__(self, config):
  951. super().__init__(config)
  952. self.num_labels = config.num_labels
  953. self.config = config
  954. self.roberta = XmodModel(config, add_pooling_layer=False)
  955. self.classifier = XmodClassificationHead(config)
  956. # Initialize weights and apply final processing
  957. self.post_init()
  958. @auto_docstring
  959. def forward(
  960. self,
  961. input_ids: Optional[torch.LongTensor] = None,
  962. lang_ids: Optional[torch.LongTensor] = None,
  963. attention_mask: Optional[torch.FloatTensor] = None,
  964. token_type_ids: Optional[torch.LongTensor] = None,
  965. position_ids: Optional[torch.LongTensor] = None,
  966. head_mask: Optional[torch.FloatTensor] = None,
  967. inputs_embeds: Optional[torch.FloatTensor] = None,
  968. labels: Optional[torch.LongTensor] = None,
  969. output_attentions: Optional[bool] = None,
  970. output_hidden_states: Optional[bool] = None,
  971. return_dict: Optional[bool] = None,
  972. ) -> Union[tuple[torch.Tensor], SequenceClassifierOutput]:
  973. r"""
  974. lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  975. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  976. that corresponds to `self.config.default_language`.
  977. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  978. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  979. config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
  980. `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  981. """
  982. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  983. outputs = self.roberta(
  984. input_ids,
  985. lang_ids=lang_ids,
  986. attention_mask=attention_mask,
  987. token_type_ids=token_type_ids,
  988. position_ids=position_ids,
  989. head_mask=head_mask,
  990. inputs_embeds=inputs_embeds,
  991. output_attentions=output_attentions,
  992. output_hidden_states=output_hidden_states,
  993. return_dict=return_dict,
  994. )
  995. sequence_output = outputs[0]
  996. logits = self.classifier(sequence_output)
  997. loss = None
  998. if labels is not None:
  999. if self.config.problem_type is None:
  1000. if self.num_labels == 1:
  1001. self.config.problem_type = "regression"
  1002. elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1003. self.config.problem_type = "single_label_classification"
  1004. else:
  1005. self.config.problem_type = "multi_label_classification"
  1006. if self.config.problem_type == "regression":
  1007. loss_fct = MSELoss()
  1008. if self.num_labels == 1:
  1009. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1010. else:
  1011. loss = loss_fct(logits, labels)
  1012. elif self.config.problem_type == "single_label_classification":
  1013. loss_fct = CrossEntropyLoss()
  1014. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1015. elif self.config.problem_type == "multi_label_classification":
  1016. loss_fct = BCEWithLogitsLoss()
  1017. loss = loss_fct(logits, labels)
  1018. if not return_dict:
  1019. output = (logits,) + outputs[2:]
  1020. return ((loss,) + output) if loss is not None else output
  1021. return SequenceClassifierOutput(
  1022. loss=loss,
  1023. logits=logits,
  1024. hidden_states=outputs.hidden_states,
  1025. attentions=outputs.attentions,
  1026. )
  1027. @auto_docstring
  1028. class XmodForMultipleChoice(XmodPreTrainedModel):
  1029. # Copied from transformers.models.roberta.modeling_roberta.RobertaForMultipleChoice.__init__ with Roberta->Xmod
  1030. def __init__(self, config):
  1031. super().__init__(config)
  1032. self.roberta = XmodModel(config)
  1033. self.dropout = nn.Dropout(config.hidden_dropout_prob)
  1034. self.classifier = nn.Linear(config.hidden_size, 1)
  1035. # Initialize weights and apply final processing
  1036. self.post_init()
  1037. @auto_docstring
  1038. def forward(
  1039. self,
  1040. input_ids: Optional[torch.LongTensor] = None,
  1041. lang_ids: Optional[torch.LongTensor] = None,
  1042. token_type_ids: Optional[torch.LongTensor] = None,
  1043. attention_mask: Optional[torch.FloatTensor] = None,
  1044. labels: Optional[torch.LongTensor] = None,
  1045. position_ids: Optional[torch.LongTensor] = None,
  1046. head_mask: Optional[torch.FloatTensor] = None,
  1047. inputs_embeds: Optional[torch.FloatTensor] = None,
  1048. output_attentions: Optional[bool] = None,
  1049. output_hidden_states: Optional[bool] = None,
  1050. return_dict: Optional[bool] = None,
  1051. ) -> Union[tuple[torch.Tensor], MultipleChoiceModelOutput]:
  1052. r"""
  1053. input_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`):
  1054. Indices of input sequence tokens in the vocabulary.
  1055. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1056. [`PreTrainedTokenizer.__call__`] for details.
  1057. [What are input IDs?](../glossary#input-ids)
  1058. lang_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1059. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  1060. that corresponds to `self.config.default_language`.
  1061. token_type_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1062. Segment token indices to indicate first and second portions of the inputs. Indices are selected in `[0,
  1063. 1]`:
  1064. - 0 corresponds to a *sentence A* token,
  1065. - 1 corresponds to a *sentence B* token.
  1066. [What are token type IDs?](../glossary#token-type-ids)
  1067. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1068. Labels for computing the multiple choice classification loss. Indices should be in `[0, ...,
  1069. num_choices-1]` where `num_choices` is the size of the second dimension of the input tensors. (See
  1070. `input_ids` above)
  1071. position_ids (`torch.LongTensor` of shape `(batch_size, num_choices, sequence_length)`, *optional*):
  1072. Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
  1073. config.max_position_embeddings - 1]`.
  1074. [What are position IDs?](../glossary#position-ids)
  1075. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, num_choices, sequence_length, hidden_size)`, *optional*):
  1076. Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
  1077. is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
  1078. model's internal embedding lookup matrix.
  1079. """
  1080. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1081. num_choices = input_ids.shape[1] if input_ids is not None else inputs_embeds.shape[1]
  1082. flat_input_ids = input_ids.view(-1, input_ids.size(-1)) if input_ids is not None else None
  1083. flat_lang_ids = lang_ids.repeat(input_ids.size(0) * input_ids.size(1)) if lang_ids is not None else None
  1084. flat_position_ids = position_ids.view(-1, position_ids.size(-1)) if position_ids is not None else None
  1085. flat_token_type_ids = token_type_ids.view(-1, token_type_ids.size(-1)) if token_type_ids is not None else None
  1086. flat_attention_mask = attention_mask.view(-1, attention_mask.size(-1)) if attention_mask is not None else None
  1087. flat_inputs_embeds = (
  1088. inputs_embeds.view(-1, inputs_embeds.size(-2), inputs_embeds.size(-1))
  1089. if inputs_embeds is not None
  1090. else None
  1091. )
  1092. outputs = self.roberta(
  1093. flat_input_ids,
  1094. lang_ids=flat_lang_ids,
  1095. position_ids=flat_position_ids,
  1096. token_type_ids=flat_token_type_ids,
  1097. attention_mask=flat_attention_mask,
  1098. head_mask=head_mask,
  1099. inputs_embeds=flat_inputs_embeds,
  1100. output_attentions=output_attentions,
  1101. output_hidden_states=output_hidden_states,
  1102. return_dict=return_dict,
  1103. )
  1104. pooled_output = outputs[1]
  1105. pooled_output = self.dropout(pooled_output)
  1106. logits = self.classifier(pooled_output)
  1107. reshaped_logits = logits.view(-1, num_choices)
  1108. loss = None
  1109. if labels is not None:
  1110. loss_fct = CrossEntropyLoss()
  1111. loss = loss_fct(reshaped_logits, labels)
  1112. if not return_dict:
  1113. output = (reshaped_logits,) + outputs[2:]
  1114. return ((loss,) + output) if loss is not None else output
  1115. return MultipleChoiceModelOutput(
  1116. loss=loss,
  1117. logits=reshaped_logits,
  1118. hidden_states=outputs.hidden_states,
  1119. attentions=outputs.attentions,
  1120. )
  1121. @auto_docstring
  1122. class XmodForTokenClassification(XmodPreTrainedModel):
  1123. # Copied from transformers.models.roberta.modeling_roberta.RobertaForTokenClassification.__init__ with Roberta->Xmod
  1124. def __init__(self, config):
  1125. super().__init__(config)
  1126. self.num_labels = config.num_labels
  1127. self.roberta = XmodModel(config, add_pooling_layer=False)
  1128. classifier_dropout = (
  1129. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1130. )
  1131. self.dropout = nn.Dropout(classifier_dropout)
  1132. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1133. # Initialize weights and apply final processing
  1134. self.post_init()
  1135. @auto_docstring
  1136. def forward(
  1137. self,
  1138. input_ids: Optional[torch.LongTensor] = None,
  1139. lang_ids: Optional[torch.LongTensor] = None,
  1140. attention_mask: Optional[torch.FloatTensor] = None,
  1141. token_type_ids: Optional[torch.LongTensor] = None,
  1142. position_ids: Optional[torch.LongTensor] = None,
  1143. head_mask: Optional[torch.FloatTensor] = None,
  1144. inputs_embeds: Optional[torch.FloatTensor] = None,
  1145. labels: Optional[torch.LongTensor] = None,
  1146. output_attentions: Optional[bool] = None,
  1147. output_hidden_states: Optional[bool] = None,
  1148. return_dict: Optional[bool] = None,
  1149. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  1150. r"""
  1151. lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1152. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  1153. that corresponds to `self.config.default_language`.
  1154. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1155. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1156. """
  1157. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1158. outputs = self.roberta(
  1159. input_ids,
  1160. lang_ids=lang_ids,
  1161. attention_mask=attention_mask,
  1162. token_type_ids=token_type_ids,
  1163. position_ids=position_ids,
  1164. head_mask=head_mask,
  1165. inputs_embeds=inputs_embeds,
  1166. output_attentions=output_attentions,
  1167. output_hidden_states=output_hidden_states,
  1168. return_dict=return_dict,
  1169. )
  1170. sequence_output = outputs[0]
  1171. sequence_output = self.dropout(sequence_output)
  1172. logits = self.classifier(sequence_output)
  1173. loss = None
  1174. if labels is not None:
  1175. loss_fct = CrossEntropyLoss()
  1176. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1177. if not return_dict:
  1178. output = (logits,) + outputs[2:]
  1179. return ((loss,) + output) if loss is not None else output
  1180. return TokenClassifierOutput(
  1181. loss=loss,
  1182. logits=logits,
  1183. hidden_states=outputs.hidden_states,
  1184. attentions=outputs.attentions,
  1185. )
  1186. # Copied from transformers.models.roberta.modeling_roberta.RobertaClassificationHead
  1187. class XmodClassificationHead(nn.Module):
  1188. """Head for sentence-level classification tasks."""
  1189. def __init__(self, config):
  1190. super().__init__()
  1191. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  1192. classifier_dropout = (
  1193. config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
  1194. )
  1195. self.dropout = nn.Dropout(classifier_dropout)
  1196. self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
  1197. def forward(self, features, **kwargs):
  1198. x = features[:, 0, :] # take <s> token (equiv. to [CLS])
  1199. x = self.dropout(x)
  1200. x = self.dense(x)
  1201. x = torch.tanh(x)
  1202. x = self.dropout(x)
  1203. x = self.out_proj(x)
  1204. return x
  1205. @auto_docstring
  1206. class XmodForQuestionAnswering(XmodPreTrainedModel):
  1207. # Copied from transformers.models.roberta.modeling_roberta.RobertaForQuestionAnswering.__init__ with Roberta->Xmod
  1208. def __init__(self, config):
  1209. super().__init__(config)
  1210. self.num_labels = config.num_labels
  1211. self.roberta = XmodModel(config, add_pooling_layer=False)
  1212. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1213. # Initialize weights and apply final processing
  1214. self.post_init()
  1215. @auto_docstring
  1216. def forward(
  1217. self,
  1218. input_ids: Optional[torch.LongTensor] = None,
  1219. lang_ids: Optional[torch.LongTensor] = None,
  1220. attention_mask: Optional[torch.FloatTensor] = None,
  1221. token_type_ids: Optional[torch.LongTensor] = None,
  1222. position_ids: Optional[torch.LongTensor] = None,
  1223. head_mask: Optional[torch.FloatTensor] = None,
  1224. inputs_embeds: Optional[torch.FloatTensor] = None,
  1225. start_positions: Optional[torch.LongTensor] = None,
  1226. end_positions: Optional[torch.LongTensor] = None,
  1227. output_attentions: Optional[bool] = None,
  1228. output_hidden_states: Optional[bool] = None,
  1229. return_dict: Optional[bool] = None,
  1230. ) -> Union[tuple[torch.Tensor], QuestionAnsweringModelOutput]:
  1231. r"""
  1232. lang_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1233. Indices of the language adapters that should be activated for each sample, respectively. Default: the index
  1234. that corresponds to `self.config.default_language`.
  1235. """
  1236. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1237. outputs = self.roberta(
  1238. input_ids,
  1239. lang_ids=lang_ids,
  1240. attention_mask=attention_mask,
  1241. token_type_ids=token_type_ids,
  1242. position_ids=position_ids,
  1243. head_mask=head_mask,
  1244. inputs_embeds=inputs_embeds,
  1245. output_attentions=output_attentions,
  1246. output_hidden_states=output_hidden_states,
  1247. return_dict=return_dict,
  1248. )
  1249. sequence_output = outputs[0]
  1250. logits = self.qa_outputs(sequence_output)
  1251. start_logits, end_logits = logits.split(1, dim=-1)
  1252. start_logits = start_logits.squeeze(-1).contiguous()
  1253. end_logits = end_logits.squeeze(-1).contiguous()
  1254. total_loss = None
  1255. if start_positions is not None and end_positions is not None:
  1256. # If we are on multi-GPU, split add a dimension
  1257. if len(start_positions.size()) > 1:
  1258. start_positions = start_positions.squeeze(-1)
  1259. if len(end_positions.size()) > 1:
  1260. end_positions = end_positions.squeeze(-1)
  1261. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1262. ignored_index = start_logits.size(1)
  1263. start_positions = start_positions.clamp(0, ignored_index)
  1264. end_positions = end_positions.clamp(0, ignored_index)
  1265. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1266. start_loss = loss_fct(start_logits, start_positions)
  1267. end_loss = loss_fct(end_logits, end_positions)
  1268. total_loss = (start_loss + end_loss) / 2
  1269. if not return_dict:
  1270. output = (start_logits, end_logits) + outputs[2:]
  1271. return ((total_loss,) + output) if total_loss is not None else output
  1272. return QuestionAnsweringModelOutput(
  1273. loss=total_loss,
  1274. start_logits=start_logits,
  1275. end_logits=end_logits,
  1276. hidden_states=outputs.hidden_states,
  1277. attentions=outputs.attentions,
  1278. )
  1279. # Copied from transformers.models.roberta.modeling_roberta.create_position_ids_from_input_ids
  1280. def create_position_ids_from_input_ids(input_ids, padding_idx, past_key_values_length=0):
  1281. """
  1282. Replace non-padding symbols with their position numbers. Position numbers begin at padding_idx+1. Padding symbols
  1283. are ignored. This is modified from fairseq's `utils.make_positions`.
  1284. Args:
  1285. x: torch.Tensor x:
  1286. Returns: torch.Tensor
  1287. """
  1288. # The series of casts and type-conversions here are carefully balanced to both work with ONNX export and XLA.
  1289. mask = input_ids.ne(padding_idx).int()
  1290. incremental_indices = (torch.cumsum(mask, dim=1).type_as(mask) + past_key_values_length) * mask
  1291. return incremental_indices.long() + padding_idx
# Public API of this module: the names exported by `from <module> import *`.
__all__ = [
    "XmodForCausalLM",
    "XmodForMaskedLM",
    "XmodForMultipleChoice",
    "XmodForQuestionAnswering",
    "XmodForSequenceClassification",
    "XmodForTokenClassification",
    "XmodModel",
    "XmodPreTrainedModel",
]