modeling_mt5.py 111 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453
  1. # coding=utf-8
  2. # Copyright 2020 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch mT5 model."""
  16. import copy
  17. import math
  18. import os
  19. import warnings
  20. from typing import Optional, Union
  21. import torch
  22. from torch import nn
  23. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  26. from ...generation import GenerationMixin
  27. from ...modeling_attn_mask_utils import AttentionMaskConverter
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import (
  30. BaseModelOutput,
  31. BaseModelOutputWithPastAndCrossAttentions,
  32. Seq2SeqLMOutput,
  33. Seq2SeqModelOutput,
  34. Seq2SeqQuestionAnsweringModelOutput,
  35. Seq2SeqSequenceClassifierOutput,
  36. TokenClassifierOutput,
  37. )
  38. from ...modeling_utils import PreTrainedModel
  39. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  40. from ...utils import (
  41. DUMMY_INPUTS,
  42. DUMMY_MASK,
  43. add_start_docstrings,
  44. auto_docstring,
  45. is_torch_flex_attn_available,
  46. is_torch_fx_proxy,
  47. is_torchdynamo_compiling,
  48. logging,
  49. )
  50. from ...utils.deprecation import deprecate_kwarg
  51. from ...utils.model_parallel_utils import assert_device_map, get_device_map
  52. from .configuration_mt5 import MT5Config
  if is_torch_flex_attn_available():
      # Optional imports: only attempted when the installed torch provides flex attention.
      from torch.nn.attention.flex_attention import BlockMask

      from ...integrations.flex_attention import make_flex_block_causal_mask

  # Module-level logger (used, for example, by `logger.warning_once` in `MT5Attention.__init__`).
  logger = logging.get_logger(__name__)
  ####################################################
  # Shared docstring text for the experimental model-parallel
  # `parallelize()` / `deparallelize()` API.
  ####################################################
  PARALLELIZE_DOCSTRING = r"""
      This is an experimental feature and is subject to change at a moment's notice.

      Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
      it will evenly distribute blocks across all devices.

      Args:
          device_map (`dict[int, list]`, *optional*):
              A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
              automatically mapped to the first device (for esoteric reasons). That means that the first device should
              have fewer attention modules mapped to it than other devices. For reference, the mt5 models have the
              following number of attention modules:

                  - mt5-small: 8
                  - mt5-base: 12
                  - mt5-large: 24
                  - mt5-xl: 24
                  - mt5-xxl: 24

      Example:

      ```python
      # Here is an example of a device map on a machine with 4 GPUs using mt5-xl, which has a total of 24 attention modules:
      model = MT5ForConditionalGeneration.from_pretrained("google/mt5-xl")
      device_map = {
          0: [0, 1, 2],
          1: [3, 4, 5, 6, 7, 8, 9],
          2: [10, 11, 12, 13, 14, 15, 16],
          3: [17, 18, 19, 20, 21, 22, 23],
      }
      model.parallelize(device_map)
      ```
  """
  DEPARALLELIZE_DOCSTRING = r"""
      Moves the model to cpu from a model parallel state.

      Example:

      ```python
      # On a 4 GPU machine with mt5-xl:
      model = MT5ForConditionalGeneration.from_pretrained("google/mt5-xl")
      device_map = {
          0: [0, 1, 2],
          1: [3, 4, 5, 6, 7, 8, 9],
          2: [10, 11, 12, 13, 14, 15, 16],
          3: [17, 18, 19, 20, 21, 22, 23],
      }
      model.parallelize(device_map)  # Splits the model across several devices
      model.deparallelize()  # Puts the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
      ```
  """
  # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->MT5
  class MT5LayerNorm(nn.Module):
      """
      Scale-only RMS normalisation in the MT5 style: no mean subtraction and no bias term.

      The mean of squares is always accumulated in float32, even for half-precision inputs. The normalised
      activations are cast back to the weight dtype only when that weight is float16 or bfloat16.
      """

      def __init__(self, hidden_size, eps=1e-6):
          super().__init__()
          self.weight = nn.Parameter(torch.ones(hidden_size))
          self.variance_epsilon = eps

      def forward(self, hidden_states):
          # Root Mean Square Layer Normalization (https://huggingface.co/papers/1910.07467):
          # divide by sqrt(mean(x^2) + eps); no centring and no learned shift.
          mean_square = hidden_states.to(torch.float32).square().mean(dim=-1, keepdim=True)
          normalized = hidden_states * (mean_square + self.variance_epsilon).rsqrt()

          weight_dtype = self.weight.dtype
          if weight_dtype in (torch.float16, torch.bfloat16):
              # Return to half precision only when the learned scale itself is half precision.
              normalized = normalized.to(weight_dtype)
          return self.weight * normalized
  # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->MT5
  class MT5DenseActDense(nn.Module):
      """
      Feed-forward block ``wo(dropout(act(wi(x))))`` with bias-free projections ``d_model -> d_ff -> d_model``.
      """

      def __init__(self, config: MT5Config):
          super().__init__()
          self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
          self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
          self.dropout = nn.Dropout(config.dropout_rate)
          self.act = ACT2FN[config.dense_act_fn]

      def forward(self, hidden_states):
          activated = self.dropout(self.act(self.wi(hidden_states)))

          # Match the output projection's dtype (e.g. when `wo` is kept in float32), but never cast
          # activations to int8 if the weight has been quantized to that dtype.
          out_weight = self.wo.weight
          if isinstance(out_weight, torch.Tensor) and out_weight.dtype not in (activated.dtype, torch.int8):
              activated = activated.to(out_weight.dtype)
          return self.wo(activated)
  # Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->MT5
  class MT5DenseGatedActDense(nn.Module):
      """
      Gated feed-forward block: ``wo(dropout(act(wi_0(x)) * wi_1(x)))``, all projections bias-free.
      """

      def __init__(self, config: MT5Config):
          super().__init__()
          self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
          self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
          self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
          self.dropout = nn.Dropout(config.dropout_rate)
          self.act = ACT2FN[config.dense_act_fn]

      def forward(self, hidden_states):
          gate = self.act(self.wi_0(hidden_states))
          linear_branch = self.wi_1(hidden_states)
          gated = self.dropout(gate * linear_branch)

          # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
          # See https://github.com/huggingface/transformers/issues/20287
          # We also make sure the weights are not in `int8` in case users force `_keep_in_fp32_modules` to be `None`.
          out_weight = self.wo.weight
          if isinstance(out_weight, torch.Tensor) and out_weight.dtype not in (gated.dtype, torch.int8):
              gated = gated.to(out_weight.dtype)
          return self.wo(gated)
  # Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->MT5
  class MT5LayerFF(nn.Module):
      """
      Pre-norm residual feed-forward sub-layer: ``x + dropout(FF(layer_norm(x)))``.

      ``FF`` is the gated variant when ``config.is_gated_act`` is true, otherwise the plain one.
      """

      def __init__(self, config: MT5Config):
          super().__init__()
          feed_forward_cls = MT5DenseGatedActDense if config.is_gated_act else MT5DenseActDense
          self.DenseReluDense = feed_forward_cls(config)
          self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
          self.dropout = nn.Dropout(config.dropout_rate)

      def forward(self, hidden_states):
          transformed = self.DenseReluDense(self.layer_norm(hidden_states))
          return hidden_states + self.dropout(transformed)
  # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->MT5
  class MT5Attention(nn.Module):
      """
      Multi-head attention with an optional learned, bucketed relative position bias.

      Used for self-attention (``key_value_states is None``) and for decoder cross-attention over
      ``key_value_states``. Scores are the raw ``q @ k^T`` product: no ``1/sqrt(d_kv)`` scaling is applied.
      """

      def __init__(
          self,
          config: MT5Config,
          has_relative_attention_bias=False,
          layer_idx: Optional[int] = None,
      ):
          super().__init__()
          self.is_decoder = config.is_decoder
          self.has_relative_attention_bias = has_relative_attention_bias
          self.relative_attention_num_buckets = config.relative_attention_num_buckets
          self.relative_attention_max_distance = config.relative_attention_max_distance
          self.d_model = config.d_model
          self.key_value_proj_dim = config.d_kv
          self.n_heads = config.num_heads
          self.dropout = config.dropout_rate
          self.inner_dim = self.n_heads * self.key_value_proj_dim
          self.layer_idx = layer_idx
          if layer_idx is None and self.is_decoder:
              logger.warning_once(
                  f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
                  "will lead to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
                  "when creating this class."
              )

          # Mesh TensorFlow initialization to avoid scaling before softmax
          self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
          self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
          self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
          self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)

          if self.has_relative_attention_bias:
              self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
          self.pruned_heads = set()
          self.gradient_checkpointing = False

      def prune_heads(self, heads):
          """
          Remove attention heads from the q/k/v/o projections.

          Head indices are resolved by ``find_pruneable_heads_and_indices`` against ``self.pruned_heads``. The
          relative attention bias embedding is not pruned; ``forward`` selects the surviving heads from the bias
          instead.
          """
          if len(heads) == 0:
              return
          heads, index = find_pruneable_heads_and_indices(
              heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
          )
          # Prune linear layers
          self.q = prune_linear_layer(self.q, index)
          self.k = prune_linear_layer(self.k, index)
          self.v = prune_linear_layer(self.v, index)
          self.o = prune_linear_layer(self.o, index, dim=1)
          # Update hyper params
          self.n_heads = self.n_heads - len(heads)
          self.inner_dim = self.key_value_proj_dim * self.n_heads
          self.pruned_heads = self.pruned_heads.union(heads)

      @staticmethod
      def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
          """
          Adapted from Mesh Tensorflow:
          https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

          Translate relative position to a bucket number for relative attention. The relative position is defined as
          memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
          position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
          small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
          positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
          This should allow for more graceful generalization to longer sequences than the model has been trained on.

          Args:
              relative_position: an integer Tensor
              bidirectional: a boolean - whether the attention is bidirectional
              num_buckets: an integer
              max_distance: an integer

          Returns:
              a ``torch.long`` Tensor with the same shape as relative_position, containing values in the range
              [0, num_buckets)
          """
          relative_buckets = 0
          if bidirectional:
              # Half of the buckets encode positive offsets, the other half non-positive ones.
              num_buckets //= 2
              relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
              relative_position = torch.abs(relative_position)
          else:
              relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
          # now relative_position is in the range [0, inf)

          # half of the buckets are for exact increments in positions
          max_exact = num_buckets // 2
          is_small = relative_position < max_exact

          # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
          relative_position_if_large = max_exact + (
              torch.log(relative_position.float() / max_exact)
              / math.log(max_distance / max_exact)
              * (num_buckets - max_exact)
          ).to(torch.long)
          relative_position_if_large = torch.min(
              relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
          )

          relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
          return relative_buckets

      def compute_bias(self, query_length, key_length, device=None, cache_position=None):
          """
          Compute binned relative position bias of shape ``(1, num_heads, query_length, key_length)``.

          When ``cache_position`` is given it supplies the query positions (and the query dimension becomes its
          length); otherwise queries are assumed to occupy positions ``0 .. query_length - 1``.
          """
          if device is None:
              device = self.relative_attention_bias.weight.device
          if cache_position is None:
              context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
          else:
              context_position = cache_position[:, None].to(device)
          memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
          relative_position = memory_position - context_position  # shape (query_length, key_length)
          relative_position_bucket = self._relative_position_bucket(
              relative_position,  # shape (query_length, key_length)
              bidirectional=(not self.is_decoder),
              num_buckets=self.relative_attention_num_buckets,
              max_distance=self.relative_attention_max_distance,
          )
          values = self.relative_attention_bias(relative_position_bucket)  # shape (query_length, key_length, num_heads)
          values = values.permute([2, 0, 1]).unsqueeze(0)  # shape (1, num_heads, query_length, key_length)
          return values

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states,
          mask=None,
          key_value_states=None,
          position_bias=None,
          past_key_values=None,
          layer_head_mask=None,
          query_length=None,
          use_cache=False,
          output_attentions=False,
          cache_position=None,
      ):
          """
          Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

          Args:
              hidden_states: query input, ``(batch_size, seq_length, dim)``.
              mask: additive mask summed into the position bias; ``(batch_size, 1, 1, key_length)`` (non-causal
                  encoder) or ``(batch_size, 1, seq_length, key_length)`` (causal decoder).
              key_value_states: source states for cross-attention, or ``None`` for self-attention.
              position_bias: precomputed additive bias; computed here when ``None``.
              past_key_values: optional cache; an ``EncoderDecoderCache`` selects its self- or cross-attention part.
              layer_head_mask: multiplied into the attention probabilities when given.
              query_length: total query length including cached positions. When ``None`` it is derived from
                  ``cache_position``, or, if that is also ``None``, from the key length.
              use_cache: not read by this method.
              output_attentions: also return the attention probabilities.
              cache_position: positions of the current queries (used for the cache update and the relative bias).

          Returns:
              ``(attn_output, position_bias)``, plus ``attn_weights`` when ``output_attentions`` is true. The
              returned ``position_bias`` already includes ``mask`` but not the pruned-head selection.
          """
          # Input is (batch_size, seq_length, dim)
          # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
          batch_size, seq_length = hidden_states.shape[:2]

          # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
          is_cross_attention = key_value_states is not None

          query_states = self.q(hidden_states)
          query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

          # Check if an encoder-decoder cache is being used. Otherwise we'll get a `DynamicCache`
          is_updated = False
          if isinstance(past_key_values, EncoderDecoderCache):
              is_updated = past_key_values.is_updated.get(self.layer_idx)
              if is_cross_attention:
                  # after the first generated id, we can subsequently re-use all key/value_states from cache
                  curr_past_key_value = past_key_values.cross_attention_cache
              else:
                  curr_past_key_value = past_key_values.self_attention_cache
          else:
              curr_past_key_value = past_key_values

          current_states = key_value_states if is_cross_attention else hidden_states
          if is_cross_attention and past_key_values is not None and is_updated:
              # reuse k,v, cross_attentions
              key_states = curr_past_key_value.layers[self.layer_idx].keys
              value_states = curr_past_key_value.layers[self.layer_idx].values
          else:
              key_states = self.k(current_states)
              value_states = self.v(current_states)
              key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
              value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)

              if past_key_values is not None:
                  # save all key/value_states to cache to be re-used for fast auto-regressive generation
                  cache_position = cache_position if not is_cross_attention else None
                  key_states, value_states = curr_past_key_value.update(
                      key_states, value_states, self.layer_idx, {"cache_position": cache_position}
                  )
                  # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
                  if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
                      past_key_values.is_updated[self.layer_idx] = True

          # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
          scores = torch.matmul(query_states, key_states.transpose(3, 2))

          if position_bias is None:
              key_length = key_states.shape[-2]
              if not self.has_relative_attention_bias:
                  # Allocate the *original* head count (before pruning) so that the pruned-head selection below
                  # indexes it with the same head ids as a relative bias computed from the unpruned embedding.
                  num_bias_heads = self.n_heads + len(self.pruned_heads)
                  position_bias = torch.zeros(
                      (1, num_bias_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
                  )
                  if self.gradient_checkpointing and self.training:
                      position_bias.requires_grad = True
              else:
                  if query_length is not None:
                      real_seq_length = query_length
                  elif cache_position is not None:
                      # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
                      real_seq_length = cache_position[-1] + 1
                  else:
                      # No positional information was passed: assume every cached key precedes the current
                      # queries, so the total query length (with past) equals the key length.
                      real_seq_length = key_length
                  position_bias = self.compute_bias(
                      real_seq_length, key_length, device=scores.device, cache_position=cache_position
                  )
                  # Keep only the rows for the queries processed in this call.
                  position_bias = position_bias[:, :, -seq_length:, :]

              if mask is not None:
                  causal_mask = mask[:, :, :, : key_states.shape[-2]]
                  position_bias = position_bias + causal_mask

          if self.pruned_heads:
              # Select the bias rows of the heads that survived pruning.
              keep_heads = torch.ones(position_bias.shape[1], device=position_bias.device)
              keep_heads[list(self.pruned_heads)] = 0
              position_bias_masked = position_bias[:, keep_heads.bool()]
          else:
              position_bias_masked = position_bias

          scores += position_bias_masked

          # (batch_size, n_heads, seq_length, key_length)
          attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
          attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)

          # Mask heads if we want to
          if layer_head_mask is not None:
              attn_weights = attn_weights * layer_head_mask

          attn_output = torch.matmul(attn_weights, value_states)
          attn_output = attn_output.transpose(1, 2).contiguous()
          attn_output = attn_output.view(batch_size, -1, self.inner_dim)
          attn_output = self.o(attn_output)

          outputs = (attn_output, position_bias)
          if output_attentions:
              outputs = outputs + (attn_weights,)
          return outputs
  # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->MT5
  class MT5LayerSelfAttention(nn.Module):
      """
      Pre-norm residual self-attention sub-layer: ``x + dropout(SelfAttention(layer_norm(x)))``.
      """

      def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
          super().__init__()
          self.SelfAttention = MT5Attention(
              config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
          )
          self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
          self.dropout = nn.Dropout(config.dropout_rate)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states,
          attention_mask=None,
          position_bias=None,
          layer_head_mask=None,
          past_key_values=None,
          use_cache=False,
          output_attentions=False,
          cache_position=None,
      ):
          """
          Returns a tuple whose first entry is the residual output; the remaining entries are whatever
          ``MT5Attention`` returned after its first element (position bias and, optionally, attention weights).
          """
          attn_output, *extra_outputs = self.SelfAttention(
              self.layer_norm(hidden_states),
              mask=attention_mask,
              position_bias=position_bias,
              layer_head_mask=layer_head_mask,
              past_key_values=past_key_values,
              use_cache=use_cache,
              output_attentions=output_attentions,
              cache_position=cache_position,
          )
          residual_output = hidden_states + self.dropout(attn_output)
          return (residual_output, *extra_outputs)
  # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->MT5
  class MT5LayerCrossAttention(nn.Module):
      """
      Pre-norm residual encoder-decoder attention sub-layer.

      Wraps an ``MT5Attention`` without relative attention bias. Queries come from ``hidden_states``;
      keys and values come from ``key_value_states``.
      """

      def __init__(self, config, layer_idx: Optional[int] = None):
          super().__init__()
          self.EncDecAttention = MT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
          self.layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
          self.dropout = nn.Dropout(config.dropout_rate)

      @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
      def forward(
          self,
          hidden_states,
          key_value_states,
          attention_mask=None,
          position_bias=None,
          layer_head_mask=None,
          past_key_values=None,
          use_cache=False,
          query_length=None,
          output_attentions=False,
          cache_position=None,
      ):
          """
          Returns a tuple ``(layer_output, *rest)`` where
          ``layer_output = hidden_states + dropout(EncDecAttention(layer_norm(hidden_states)))`` and ``rest`` are the
          attention module's remaining outputs (position bias and, optionally, attention weights).
          """
          cross_output, *passthrough = self.EncDecAttention(
              self.layer_norm(hidden_states),
              mask=attention_mask,
              key_value_states=key_value_states,
              position_bias=position_bias,
              layer_head_mask=layer_head_mask,
              past_key_values=past_key_values,
              use_cache=use_cache,
              query_length=query_length,
              output_attentions=output_attentions,
              cache_position=cache_position,
          )
          layer_output = hidden_states + self.dropout(cross_output)
          return (layer_output, *passthrough)
  461. # Copied from transformers.models.t5.modeling_t5.T5Block with T5->MT5
  462. class MT5Block(GradientCheckpointingLayer):
  463. def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
  464. super().__init__()
  465. self.is_decoder = config.is_decoder
  466. self.layer = nn.ModuleList()
  467. self.layer.append(
  468. MT5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
  469. )
  470. if self.is_decoder:
  471. self.layer.append(MT5LayerCrossAttention(config, layer_idx=layer_idx))
  472. self.layer.append(MT5LayerFF(config))
  473. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  474. def forward(
  475. self,
  476. hidden_states,
  477. attention_mask=None,
  478. position_bias=None,
  479. encoder_hidden_states=None,
  480. encoder_attention_mask=None,
  481. encoder_decoder_position_bias=None,
  482. layer_head_mask=None,
  483. cross_attn_layer_head_mask=None,
  484. past_key_values=None,
  485. use_cache=False,
  486. output_attentions=False,
  487. return_dict=True,
  488. cache_position=None,
  489. ):
  490. self_attention_outputs = self.layer[0](
  491. hidden_states,
  492. attention_mask=attention_mask,
  493. position_bias=position_bias,
  494. layer_head_mask=layer_head_mask,
  495. past_key_values=past_key_values,
  496. use_cache=use_cache,
  497. output_attentions=output_attentions,
  498. cache_position=cache_position,
  499. )
  500. hidden_states = self_attention_outputs[0]
  501. attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
  502. # clamp inf values to enable fp16 training
  503. if hidden_states.dtype == torch.float16:
  504. clamp_value = torch.where(
  505. torch.isinf(hidden_states).any(),
  506. torch.finfo(hidden_states.dtype).max - 1000,
  507. torch.finfo(hidden_states.dtype).max,
  508. )
  509. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  510. do_cross_attention = self.is_decoder and encoder_hidden_states is not None
  511. if do_cross_attention:
  512. cross_attention_outputs = self.layer[1](
  513. hidden_states,
  514. key_value_states=encoder_hidden_states,
  515. attention_mask=encoder_attention_mask,
  516. position_bias=encoder_decoder_position_bias,
  517. layer_head_mask=cross_attn_layer_head_mask,
  518. past_key_values=past_key_values,
  519. query_length=cache_position[-1] + 1,
  520. use_cache=use_cache,
  521. output_attentions=output_attentions,
  522. )
  523. hidden_states = cross_attention_outputs[0]
  524. # clamp inf values to enable fp16 training
  525. if hidden_states.dtype == torch.float16:
  526. clamp_value = torch.where(
  527. torch.isinf(hidden_states).any(),
  528. torch.finfo(hidden_states.dtype).max - 1000,
  529. torch.finfo(hidden_states.dtype).max,
  530. )
  531. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  532. # Keep cross-attention outputs and relative position weights
  533. attention_outputs = attention_outputs + cross_attention_outputs[1:]
  534. # Apply Feed Forward layer
  535. hidden_states = self.layer[-1](hidden_states)
  536. # clamp inf values to enable fp16 training
  537. if hidden_states.dtype == torch.float16:
  538. clamp_value = torch.where(
  539. torch.isinf(hidden_states).any(),
  540. torch.finfo(hidden_states.dtype).max - 1000,
  541. torch.finfo(hidden_states.dtype).max,
  542. )
  543. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  544. outputs = (hidden_states,)
  545. return (
  546. outputs + attention_outputs
  547. ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
  548. def load_tf_weights_in_mt5(model, config, tf_checkpoint_path):
  549. """Load tf checkpoints in a pytorch model."""
  550. try:
  551. import re
  552. import numpy as np
  553. import tensorflow as tf
  554. except ImportError:
  555. logger.error(
  556. "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
  557. "https://www.tensorflow.org/install/ for installation instructions."
  558. )
  559. raise
  560. tf_path = os.path.abspath(tf_checkpoint_path)
  561. logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
  562. # Load weights from TF model
  563. init_vars = tf.train.list_variables(tf_path)
  564. names = []
  565. tf_weights = {}
  566. for name, shape in init_vars:
  567. logger.info(f"Loading TF weight {name} with shape {shape}")
  568. array = tf.train.load_variable(tf_path, name)
  569. names.append(name)
  570. tf_weights[name] = array
  571. for txt_name in names:
  572. name = txt_name.split("/")
  573. # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  574. # which are not required for using pretrained model
  575. if any(
  576. n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
  577. for n in name
  578. ):
  579. logger.info(f"Skipping {'/'.join(name)}")
  580. tf_weights.pop(txt_name, None)
  581. continue
  582. if "_slot_" in name[-1]:
  583. logger.info(f"Skipping {'/'.join(name)}")
  584. tf_weights.pop(txt_name, None)
  585. continue
  586. pointer = model
  587. array = tf_weights[txt_name]
  588. for m_name in name:
  589. if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
  590. scope_names = re.split(r"_(\d+)", m_name)
  591. else:
  592. scope_names = [m_name]
  593. if scope_names[0] in ["kernel", "scale", "embedding"]:
  594. pointer = getattr(pointer, "weight")
  595. elif scope_names[0] == "self_attention":
  596. pointer = getattr(pointer, "layer")
  597. pointer = pointer[0]
  598. elif scope_names[0] == "enc_dec_attention":
  599. pointer = getattr(pointer, "layer")
  600. pointer = pointer[1]
  601. elif scope_names[0] == "dense_relu_dense":
  602. pointer = getattr(pointer, "layer")
  603. pointer = pointer[2]
  604. elif scope_names[0] == "rms_norm":
  605. if hasattr(pointer, "layer_norm"):
  606. pointer = getattr(pointer, "layer_norm")
  607. elif hasattr(pointer, "final_layer_norm"):
  608. pointer = getattr(pointer, "final_layer_norm")
  609. elif scope_names[0] == "scale":
  610. pointer = getattr(pointer, "weight")
  611. elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
  612. pointer = getattr(pointer, "bias")
  613. elif scope_names[0] == "squad":
  614. pointer = getattr(pointer, "classifier")
  615. elif scope_names[0] == "decoder" and name[1] == "logits":
  616. continue
  617. elif scope_names[0] == "logits":
  618. pointer = getattr(pointer, "lm_head")
  619. elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
  620. pointer = getattr(pointer, f"wi_{scope_names[1]}")
  621. continue
  622. else:
  623. try:
  624. pointer = getattr(pointer, scope_names[0])
  625. except AttributeError:
  626. logger.info(f"Skipping {'/'.join(name)}")
  627. continue
  628. if len(scope_names) >= 2:
  629. num = int(scope_names[1])
  630. pointer = pointer[num]
  631. if scope_names[0] not in ["kernel", "scale", "embedding"]:
  632. pointer = getattr(pointer, "weight")
  633. if scope_names[0] != "embedding":
  634. logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
  635. array = np.transpose(array)
  636. try:
  637. assert pointer.shape == array.shape, (
  638. f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched"
  639. )
  640. except AssertionError as e:
  641. e.args += (pointer.shape, array.shape)
  642. raise
  643. logger.info(f"Initialize PyTorch weight {name}")
  644. pointer.data = torch.from_numpy(array.astype(np.float32))
  645. tf_weights.pop(txt_name, None)
  646. logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
  647. return model
  648. # Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->MT5
  649. class MT5ClassificationHead(nn.Module):
  650. """Head for sentence-level classification tasks."""
  651. def __init__(self, config: MT5Config):
  652. super().__init__()
  653. self.dense = nn.Linear(config.d_model, config.d_model)
  654. self.dropout = nn.Dropout(p=config.classifier_dropout)
  655. self.out_proj = nn.Linear(config.d_model, config.num_labels)
  656. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  657. hidden_states = self.dropout(hidden_states)
  658. hidden_states = self.dense(hidden_states)
  659. hidden_states = torch.tanh(hidden_states)
  660. hidden_states = self.dropout(hidden_states)
  661. hidden_states = self.out_proj(hidden_states)
  662. return hidden_states
  663. @auto_docstring
  664. # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel with T5->MT5, t5->mt5
  665. class MT5PreTrainedModel(PreTrainedModel):
  666. config: MT5Config
  667. load_tf_weights = load_tf_weights_in_mt5
  668. base_model_prefix = "transformer"
  669. is_parallelizable = True
  670. supports_gradient_checkpointing = True
  671. _can_compile_fullgraph = True
  672. _no_split_modules = ["MT5Block"]
  673. _keep_in_fp32_modules = ["wo"]
  674. @property
  675. def dummy_inputs(self):
  676. input_ids = torch.tensor(DUMMY_INPUTS)
  677. input_mask = torch.tensor(DUMMY_MASK)
  678. dummy_inputs = {
  679. "decoder_input_ids": input_ids,
  680. "input_ids": input_ids,
  681. "decoder_attention_mask": input_mask,
  682. }
  683. return dummy_inputs
  684. def _init_weights(self, module):
  685. """Initialize the weights"""
  686. factor = self.config.initializer_factor # Used for testing weights initialization
  687. if isinstance(module, MT5LayerNorm):
  688. module.weight.data.fill_(factor * 1.0)
  689. elif isinstance(
  690. module,
  691. (MT5Model, MT5ForConditionalGeneration, MT5EncoderModel, MT5ForQuestionAnswering),
  692. ):
  693. # Mesh TensorFlow embeddings initialization
  694. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
  695. module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
  696. if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
  697. module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
  698. if hasattr(module, "qa_outputs"):
  699. module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  700. module.qa_outputs.bias.data.zero_()
  701. elif isinstance(module, MT5ForTokenClassification):
  702. if hasattr(module, "classifier"):
  703. module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0)
  704. module.classifier.bias.data.zero_()
  705. elif isinstance(module, MT5ClassificationHead):
  706. module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  707. if hasattr(module.dense, "bias") and module.dense.bias is not None:
  708. module.dense.bias.data.zero_()
  709. module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  710. if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
  711. module.out_proj.bias.data.zero_()
  712. elif isinstance(module, MT5DenseActDense):
  713. # Mesh TensorFlow FF initialization
  714. # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
  715. # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
  716. module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  717. if hasattr(module.wi, "bias") and module.wi.bias is not None:
  718. module.wi.bias.data.zero_()
  719. module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
  720. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  721. module.wo.bias.data.zero_()
  722. elif isinstance(module, MT5DenseGatedActDense):
  723. module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  724. if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
  725. module.wi_0.bias.data.zero_()
  726. module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  727. if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
  728. module.wi_1.bias.data.zero_()
  729. module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
  730. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  731. module.wo.bias.data.zero_()
  732. elif isinstance(module, MT5Attention):
  733. # Mesh TensorFlow attention initialization to avoid scaling before softmax
  734. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
  735. d_model = self.config.d_model
  736. key_value_proj_dim = self.config.d_kv
  737. n_heads = self.config.num_heads
  738. module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
  739. module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
  740. module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
  741. module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
  742. if module.has_relative_attention_bias:
  743. module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
  744. def _shift_right(self, input_ids):
  745. decoder_start_token_id = self.config.decoder_start_token_id
  746. pad_token_id = self.config.pad_token_id
  747. if decoder_start_token_id is None:
  748. raise ValueError(
  749. "self.model.config.decoder_start_token_id has to be defined. In MT5 it is usually set to the pad_token_id. "
  750. "See MT5 docs for more information."
  751. )
  752. # shift inputs to the right
  753. if is_torch_fx_proxy(input_ids):
  754. # Item assignment is not supported natively for proxies.
  755. shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
  756. shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
  757. else:
  758. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  759. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  760. shifted_input_ids[..., 0] = decoder_start_token_id
  761. if pad_token_id is None:
  762. raise ValueError("self.model.config.pad_token_id has to be defined.")
  763. # replace possible -100 values in labels by `pad_token_id`
  764. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  765. return shifted_input_ids
# Copied from transformers.models.t5.modeling_t5.T5Stack with T5->MT5
class MT5Stack(MT5PreTrainedModel):
    """Stack of `MT5Block` layers used as either the MT5 encoder or the MT5 decoder.

    `config.is_decoder` selects the mode. A decoder stack builds a causal self-attention mask, attends
    over `encoder_hidden_states` when they are given, and can create and return a key/value cache. An
    encoder stack drops any passed cache and uses a plain padding mask.
    """

    def __init__(self, config, embed_tokens=None):
        super().__init__(config)
        # Token embedding module supplied by the caller (possibly shared with another stack). When it
        # is None, `forward` requires `inputs_embeds`.
        self.embed_tokens = embed_tokens
        self.is_decoder = config.is_decoder

        # Only the first block gets a relative attention bias; `forward` feeds the position bias it
        # returns to all later blocks.
        self.block = nn.ModuleList(
            [MT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)]
        )
        self.final_layer_norm = MT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
        self.dropout = nn.Dropout(config.dropout_rate)

        # Initialize weights and apply final processing
        self.post_init()
        # Model parallel
        self.model_parallel = False
        self.device_map = None
        self.gradient_checkpointing = False

    @add_start_docstrings(PARALLELIZE_DOCSTRING)
    def parallelize(self, device_map=None):
        # Deprecated. `device_map` maps a CUDA device index to the list of block indices placed on that
        # device; when omitted, blocks are spread over all visible CUDA devices via `get_device_map`.
        warnings.warn(
            "`MT5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
            " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
            " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
            " 'block.1': 1, ...}",
            FutureWarning,
        )
        # Check validity of device_map
        self.device_map = (
            get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
        )
        assert_device_map(self.device_map, len(self.block))
        self.model_parallel = True
        self.first_device = "cpu" if "cpu" in self.device_map else "cuda:" + str(min(self.device_map.keys()))
        self.last_device = "cuda:" + str(max(self.device_map.keys()))
        # Load onto devices
        for k, v in self.device_map.items():
            for layer in v:
                cuda_device = "cuda:" + str(k)
                self.block[layer] = self.block[layer].to(cuda_device)

        # Set embed_tokens to first layer
        self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Set final layer norm to last device
        self.final_layer_norm = self.final_layer_norm.to(self.last_device)

    @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
    def deparallelize(self):
        # Deprecated. Moves every block, the embeddings and the final layer norm back to the CPU and
        # clears the model-parallel state.
        warnings.warn(
            "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
            FutureWarning,
        )
        self.model_parallel = False
        self.device_map = None
        self.first_device = "cpu"
        self.last_device = "cpu"
        for i in range(len(self.block)):
            self.block[i] = self.block[i].to("cpu")
        self.embed_tokens = self.embed_tokens.to("cpu")
        self.final_layer_norm = self.final_layer_norm.to("cpu")
        torch.cuda.empty_cache()

    def set_input_embeddings(self, new_embeddings):
        self.embed_tokens = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        """Run the embeddings (if needed), every block, the final layer norm and dropout.

        Returns:
            A `BaseModelOutputWithPastAndCrossAttentions` when `return_dict` is true. Otherwise, a tuple of
            the non-None values among (last hidden state, `past_key_values`, all hidden states,
            self-attentions, cross-attentions).

        Raises:
            ValueError: If both or neither of `input_ids` / `inputs_embeds` are given, if token embeddings
                are needed but `embed_tokens` is None, or if `use_cache` resolves to True on an encoder stack.
        """
        # Model parallel
        if self.model_parallel:
            torch.cuda.set_device(self.first_device)
            self.embed_tokens = self.embed_tokens.to(self.first_device)
        # Unset flags fall back to the corresponding config values.
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of input_ids / inputs_embeds must be provided.
        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            # Reshape the ids to 2D, keeping only the last dimension.
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            if self.embed_tokens is None:
                raise ValueError("You have to initialize the model with valid token embeddings")
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        if use_cache is True:
            if not self.is_decoder:
                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

        if self.is_decoder:
            # Create an empty cache when caching is requested but none was passed in. Encoder-decoder
            # models get separate self-attention and cross-attention caches.
            if use_cache and past_key_values is None:
                if self.config.is_encoder_decoder:
                    past_key_values = EncoderDecoderCache(
                        DynamicCache(config=self.config), DynamicCache(config=self.config)
                    )
                else:
                    past_key_values = DynamicCache(config=self.config)
        elif not self.is_decoder:
            # do not pass cache object down the line for encoder stack
            # it messes indexing later in decoder-stack because cache object is modified in-place
            past_key_values = None

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            # Positions of the current tokens, continuing right after the cached prefix.
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
            )

        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past cache
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        if self.config.is_decoder:
            # Decoder: causal mask (may be None / backend-specific; see `_update_causal_mask`).
            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values.self_attention_cache
                if isinstance(past_key_values, EncoderDecoderCache)
                else past_key_values,
                output_attentions,
            )
        elif attention_mask is not None:
            # Encoder: insert two singleton dims before the last axis (assumes a 2D mask), then map
            # 1 -> 0.0 and 0 -> dtype minimum to obtain an additive mask.
            causal_mask = attention_mask[:, None, None, :]
            causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
            causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
        else:
            causal_mask = None

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(
                    encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
                )
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if (output_attentions and self.is_decoder) else None
        # Filled in by the first block's outputs and then reused by every later block.
        position_bias = None
        encoder_decoder_position_bias = None

        hidden_states = self.dropout(inputs_embeds)

        for i, layer_module in enumerate(self.block):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]
            # Model parallel
            if self.model_parallel:
                torch.cuda.set_device(hidden_states.device)
                # Ensure that attention_mask is always on the same device as hidden_states
                if causal_mask is not None:
                    causal_mask = causal_mask.to(hidden_states.device)
                if position_bias is not None:
                    position_bias = position_bias.to(hidden_states.device)
                if encoder_hidden_states is not None:
                    encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
                if encoder_extended_attention_mask is not None:
                    encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
                if encoder_decoder_position_bias is not None:
                    encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
                if layer_head_mask is not None:
                    layer_head_mask = layer_head_mask.to(hidden_states.device)
                if cross_attn_layer_head_mask is not None:
                    cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                causal_mask,
                position_bias,
                encoder_hidden_states,
                encoder_extended_attention_mask,
                encoder_decoder_position_bias,  # as a positional argument for gradient checkpointing
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                return_dict=return_dict,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]

            # We share the position biases between the layers - the first layer store them
            # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
            # (cross-attention position bias), (cross-attention weights)
            position_bias = layer_outputs[1]
            if self.is_decoder and encoder_hidden_states is not None:
                encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]

            if output_attentions:
                all_attentions = all_attentions + (layer_outputs[2],)
                if self.is_decoder:
                    all_cross_attentions = all_cross_attentions + (layer_outputs[4],)

            # Model Parallel: If it's the last layer for that device, put things on the next device
            if self.model_parallel:
                for k, v in self.device_map.items():
                    if i == v[-1] and "cuda:" + str(k) != self.last_device:
                        hidden_states = hidden_states.to("cuda:" + str(k + 1))

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )

    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Build the decoder self-attention mask for the configured attention implementation.

        - flash_attention_2: returns `attention_mask` only if it contains a 0, else None.
        - flex_attention: converts a tensor mask to a block mask; any other value is returned as given.
        - sdpa: may return None (so SDPA's causal path can be used).
        - Otherwise: returns a 4D additive causal mask built by
          `_prepare_4d_causal_attention_mask_with_cache_position`.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        # Key length of the mask: the full cache size for compilable caches, otherwise the provided
        # mask's length (or past + current + 1 when no tensor mask is available).
        if using_compilable_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask

    @staticmethod
    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
    def _prepare_4d_causal_attention_mask_with_cache_position(
        attention_mask: torch.Tensor,
        sequence_length: int,
        target_length: int,
        dtype: torch.dtype,
        cache_position: torch.Tensor,
        batch_size: int,
        **kwargs,
    ):
        """
        Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
        `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.

        Args:
            attention_mask (`torch.Tensor`):
                A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
                `(batch_size, 1, query_length, key_value_length)`.
            sequence_length (`int`):
                The sequence length being processed.
            target_length (`int`):
                The target length: when generating with static cache, the mask should be as long as the static cache,
                to account for the 0 padding, the part of the cache that is not filled yet.
            dtype (`torch.dtype`):
                The dtype to use for the 4D attention mask.
            cache_position (`torch.Tensor`):
                Indices depicting the position of the input sequence tokens in the sequence.
            batch_size (`int`):
                Batch size.

        Returns:
            An additive mask: 0 where attention is allowed, `torch.finfo(dtype).min` where it is masked.
            The input `attention_mask` is returned unchanged if it is already 4D.
        """
        if attention_mask is not None and attention_mask.dim() == 4:
            # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
            causal_mask = attention_mask
        else:
            min_dtype = torch.finfo(dtype).min
            causal_mask = torch.full(
                (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
            )
            if sequence_length != 1:
                causal_mask = torch.triu(causal_mask, diagonal=1)
            # Keep the masking value only at key positions beyond each query's own cache position.
            causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
            causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
            if attention_mask is not None:
                causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
                mask_length = attention_mask.shape[-1]
                # A sum of 0 means "causally allowed" (0) and "padding" (attention_mask == 0): mask it too.
                padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
                    causal_mask.device
                )
                padding_mask = padding_mask == 0
                causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
                    padding_mask, min_dtype
                )

        return causal_mask
# Warning message for FutureWarning: `head_mask` was separated into two input arguments, `head_mask` and
# `decoder_head_mask`.
# NOTE(review): the name starts with two underscores, so any reference to it written inside a class body is
# name-mangled by Python (e.g. to `_MT5Model__HEAD_MASK_WARNING_MSG`) and would not find this module-level
# constant. Confirm that every usage resolves.
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
  1137. @auto_docstring
  1138. class MT5Model(MT5PreTrainedModel):
  1139. r"""
  1140. Examples:
  1141. ```python
  1142. >>> from transformers import MT5Model, AutoTokenizer
  1143. >>> model = MT5Model.from_pretrained("google/mt5-small")
  1144. >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
  1145. >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
  1146. >>> summary = "Weiter Verhandlung in Syrien."
  1147. >>> inputs = tokenizer(article, return_tensors="pt")
  1148. >>> labels = tokenizer(text_target=summary, return_tensors="pt")
  1149. >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
  1150. >>> hidden_states = outputs.last_hidden_state
  1151. ```"""
  1152. model_type = "mt5"
  1153. config: MT5Config
  1154. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1155. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1156. # Copied from transformers.models.t5.modeling_t5.T5Model.__init__ with T5->MT5
  1157. def __init__(self, config: MT5Config):
  1158. super().__init__(config)
  1159. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1160. encoder_config = copy.deepcopy(config)
  1161. encoder_config.is_decoder = False
  1162. encoder_config.use_cache = False
  1163. encoder_config.tie_encoder_decoder = False
  1164. self.encoder = MT5Stack(encoder_config, self.shared)
  1165. decoder_config = copy.deepcopy(config)
  1166. decoder_config.is_decoder = True
  1167. decoder_config.tie_encoder_decoder = False
  1168. decoder_config.num_layers = config.num_decoder_layers
  1169. self.decoder = MT5Stack(decoder_config, self.shared)
  1170. # Initialize weights and apply final processing
  1171. self.post_init()
  1172. # Model parallel
  1173. self.model_parallel = False
  1174. self.device_map = None
  1175. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  1176. # Copied from transformers.models.t5.modeling_t5.T5Model.parallelize
  1177. def parallelize(self, device_map=None):
  1178. warnings.warn(
  1179. "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
  1180. " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
  1181. " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
  1182. " 0, 'encoder.block.1': 1, ...}",
  1183. FutureWarning,
  1184. )
  1185. self.device_map = (
  1186. get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
  1187. if device_map is None
  1188. else device_map
  1189. )
  1190. assert_device_map(self.device_map, len(self.encoder.block))
  1191. self.encoder.parallelize(self.device_map)
  1192. self.decoder.parallelize(self.device_map)
  1193. self.model_parallel = True
  1194. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  1195. # Copied from transformers.models.t5.modeling_t5.T5Model.deparallelize
  1196. def deparallelize(self):
  1197. warnings.warn(
  1198. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  1199. FutureWarning,
  1200. )
  1201. self.encoder.deparallelize()
  1202. self.decoder.deparallelize()
  1203. self.encoder = self.encoder.to("cpu")
  1204. self.decoder = self.decoder.to("cpu")
  1205. self.model_parallel = False
  1206. self.device_map = None
  1207. torch.cuda.empty_cache()
  1208. # Copied from transformers.models.t5.modeling_t5.T5Model.get_input_embeddings
  1209. def get_input_embeddings(self):
  1210. return self.shared
  1211. # Copied from transformers.models.t5.modeling_t5.T5Model.set_input_embeddings
  1212. def set_input_embeddings(self, new_embeddings):
  1213. self.shared = new_embeddings
  1214. self.encoder.set_input_embeddings(new_embeddings)
  1215. self.decoder.set_input_embeddings(new_embeddings)
  1216. # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder
  1217. def get_encoder(self):
  1218. return self.encoder
  1219. # Copied from transformers.models.t5.modeling_t5.T5Model._prune_heads
  1220. def _prune_heads(self, heads_to_prune):
  1221. """
  1222. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  1223. class PreTrainedModel
  1224. """
  1225. for layer, heads in heads_to_prune.items():
  1226. self.encoder.layer[layer].attention.prune_heads(heads)
  1227. @auto_docstring
  1228. # Copied from transformers.models.t5.modeling_t5.T5Model.forward with google-t5/->google/, T5->MT5, t5->mt5
  1229. def forward(
  1230. self,
  1231. input_ids: Optional[torch.LongTensor] = None,
  1232. attention_mask: Optional[torch.FloatTensor] = None,
  1233. decoder_input_ids: Optional[torch.LongTensor] = None,
  1234. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1235. head_mask: Optional[torch.FloatTensor] = None,
  1236. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1237. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1238. encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
  1239. past_key_values: Optional[Cache] = None,
  1240. inputs_embeds: Optional[torch.Tensor] = None,
  1241. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  1242. use_cache: Optional[bool] = None,
  1243. output_attentions: Optional[bool] = None,
  1244. output_hidden_states: Optional[bool] = None,
  1245. return_dict: Optional[bool] = None,
  1246. cache_position: Optional[torch.LongTensor] = None,
  1247. ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
  1248. r"""
  1249. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1250. Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
  1251. should be able to pad the inputs on both the right and the left.
  1252. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1253. [`PreTrainedTokenizer.__call__`] for detail.
  1254. [What are input IDs?](../glossary#input-ids)
  1255. To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training).
  1256. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1257. Indices of decoder input sequence tokens in the vocabulary.
  1258. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1259. [`PreTrainedTokenizer.__call__`] for details.
  1260. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1261. MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1262. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1263. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [MT5
  1264. Training](./mt5#training).
  1265. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1266. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1267. be used by default.
  1268. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1269. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1270. 1]`:
  1271. - 1 indicates the head is **not masked**,
  1272. - 0 indicates the head is **masked**.
  1273. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1274. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1275. `[0, 1]`:
  1276. - 1 indicates the head is **not masked**,
  1277. - 0 indicates the head is **masked**.
  1278. Example:
  1279. ```python
  1280. >>> from transformers import AutoTokenizer, MT5Model
  1281. >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
  1282. >>> model = MT5Model.from_pretrained("google/mt5-small")
  1283. >>> input_ids = tokenizer(
  1284. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1285. ... ).input_ids # Batch size 1
  1286. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  1287. >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for MT5Model.
  1288. >>> # This is not needed for torch's MT5ForConditionalGeneration as it does this internally using labels arg.
  1289. >>> decoder_input_ids = model._shift_right(decoder_input_ids)
  1290. >>> # forward pass
  1291. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1292. >>> last_hidden_states = outputs.last_hidden_state
  1293. ```"""
  1294. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1295. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1296. # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  1297. if head_mask is not None and decoder_head_mask is None:
  1298. if self.config.num_layers == self.config.num_decoder_layers:
  1299. warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
  1300. decoder_head_mask = head_mask
  1301. # Encode if needed (training, first prediction pass)
  1302. if encoder_outputs is None:
  1303. encoder_outputs = self.encoder(
  1304. input_ids=input_ids,
  1305. attention_mask=attention_mask,
  1306. inputs_embeds=inputs_embeds,
  1307. head_mask=head_mask,
  1308. output_attentions=output_attentions,
  1309. output_hidden_states=output_hidden_states,
  1310. return_dict=return_dict,
  1311. )
  1312. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1313. encoder_outputs = BaseModelOutput(
  1314. last_hidden_state=encoder_outputs[0],
  1315. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1316. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1317. )
  1318. hidden_states = encoder_outputs[0]
  1319. # Set device for model parallelism
  1320. if self.model_parallel:
  1321. torch.cuda.set_device(self.decoder.first_device)
  1322. hidden_states = hidden_states.to(self.decoder.first_device)
  1323. if decoder_input_ids is not None:
  1324. decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
  1325. if attention_mask is not None:
  1326. attention_mask = attention_mask.to(self.decoder.first_device)
  1327. if decoder_attention_mask is not None:
  1328. decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
  1329. # Decode
  1330. decoder_outputs = self.decoder(
  1331. input_ids=decoder_input_ids,
  1332. attention_mask=decoder_attention_mask,
  1333. inputs_embeds=decoder_inputs_embeds,
  1334. past_key_values=past_key_values,
  1335. encoder_hidden_states=hidden_states,
  1336. encoder_attention_mask=attention_mask,
  1337. head_mask=decoder_head_mask,
  1338. cross_attn_head_mask=cross_attn_head_mask,
  1339. use_cache=use_cache,
  1340. output_attentions=output_attentions,
  1341. output_hidden_states=output_hidden_states,
  1342. return_dict=return_dict,
  1343. cache_position=cache_position,
  1344. )
  1345. if not return_dict:
  1346. return decoder_outputs + encoder_outputs
  1347. return Seq2SeqModelOutput(
  1348. last_hidden_state=decoder_outputs.last_hidden_state,
  1349. past_key_values=decoder_outputs.past_key_values,
  1350. decoder_hidden_states=decoder_outputs.hidden_states,
  1351. decoder_attentions=decoder_outputs.attentions,
  1352. cross_attentions=decoder_outputs.cross_attentions,
  1353. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1354. encoder_hidden_states=encoder_outputs.hidden_states,
  1355. encoder_attentions=encoder_outputs.attentions,
  1356. )
  1357. @auto_docstring(
  1358. custom_intro="""
  1359. MT5 Model with a `language modeling` head on top.
  1360. """
  1361. )
  1362. class MT5ForConditionalGeneration(MT5PreTrainedModel, GenerationMixin):
  1363. r"""
  1364. Examples:
  1365. ```python
  1366. >>> from transformers import MT5ForConditionalGeneration, AutoTokenizer
  1367. >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
  1368. >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
  1369. >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
  1370. >>> summary = "Weiter Verhandlung in Syrien."
  1371. >>> inputs = tokenizer(article, text_target=summary, return_tensors="pt")
  1372. >>> outputs = model(**inputs)
  1373. >>> loss = outputs.loss
  1374. ```"""
  1375. model_type = "mt5"
  1376. config: MT5Config
  1377. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1378. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
  1379. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.__init__ with T5->MT5
  1380. def __init__(self, config: MT5Config):
  1381. super().__init__(config)
  1382. self.model_dim = config.d_model
  1383. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1384. encoder_config = copy.deepcopy(config)
  1385. encoder_config.is_decoder = False
  1386. encoder_config.use_cache = False
  1387. encoder_config.tie_encoder_decoder = False
  1388. self.encoder = MT5Stack(encoder_config, self.shared)
  1389. decoder_config = copy.deepcopy(config)
  1390. decoder_config.is_decoder = True
  1391. decoder_config.tie_encoder_decoder = False
  1392. decoder_config.num_layers = config.num_decoder_layers
  1393. self.decoder = MT5Stack(decoder_config, self.shared)
  1394. self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
  1395. # Initialize weights and apply final processing
  1396. self.post_init()
  1397. # Model parallel
  1398. self.model_parallel = False
  1399. self.device_map = None
  1400. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  1401. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.parallelize
  1402. def parallelize(self, device_map=None):
  1403. warnings.warn(
  1404. "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
  1405. " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
  1406. " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
  1407. " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
  1408. FutureWarning,
  1409. )
  1410. self.device_map = (
  1411. get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
  1412. if device_map is None
  1413. else device_map
  1414. )
  1415. assert_device_map(self.device_map, len(self.encoder.block))
  1416. self.encoder.parallelize(self.device_map)
  1417. self.decoder.parallelize(self.device_map)
  1418. self.lm_head = self.lm_head.to(self.decoder.first_device)
  1419. self.model_parallel = True
  1420. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  1421. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.deparallelize
  1422. def deparallelize(self):
  1423. warnings.warn(
  1424. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  1425. FutureWarning,
  1426. )
  1427. self.encoder.deparallelize()
  1428. self.decoder.deparallelize()
  1429. self.encoder = self.encoder.to("cpu")
  1430. self.decoder = self.decoder.to("cpu")
  1431. self.lm_head = self.lm_head.to("cpu")
  1432. self.model_parallel = False
  1433. self.device_map = None
  1434. torch.cuda.empty_cache()
  1435. def get_input_embeddings(self):
  1436. return self.shared
  1437. def set_input_embeddings(self, new_embeddings):
  1438. self.shared = new_embeddings
  1439. self.encoder.set_input_embeddings(new_embeddings)
  1440. self.decoder.set_input_embeddings(new_embeddings)
  1441. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_encoder
  1442. def get_encoder(self):
  1443. return self.encoder
  1444. @auto_docstring
  1445. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.forward with google-t5/->google/, T5->MT5, t5->mt5
  1446. def forward(
  1447. self,
  1448. input_ids: Optional[torch.LongTensor] = None,
  1449. attention_mask: Optional[torch.FloatTensor] = None,
  1450. decoder_input_ids: Optional[torch.LongTensor] = None,
  1451. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1452. head_mask: Optional[torch.FloatTensor] = None,
  1453. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1454. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1455. encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
  1456. past_key_values: Optional[Cache] = None,
  1457. inputs_embeds: Optional[torch.FloatTensor] = None,
  1458. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1459. labels: Optional[torch.LongTensor] = None,
  1460. use_cache: Optional[bool] = None,
  1461. output_attentions: Optional[bool] = None,
  1462. output_hidden_states: Optional[bool] = None,
  1463. return_dict: Optional[bool] = None,
  1464. cache_position: Optional[torch.LongTensor] = None,
  1465. ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
  1466. r"""
  1467. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1468. Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
  1469. should be able to pad the inputs on both the right and the left.
  1470. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1471. [`PreTrainedTokenizer.__call__`] for detail.
  1472. [What are input IDs?](../glossary#input-ids)
  1473. To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training).
  1474. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1475. Indices of decoder input sequence tokens in the vocabulary.
  1476. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1477. [`PreTrainedTokenizer.__call__`] for details.
  1478. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1479. MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1480. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1481. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [MT5
  1482. Training](./mt5#training).
  1483. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1484. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1485. be used by default.
  1486. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1487. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1488. 1]`:
  1489. - 1 indicates the head is **not masked**,
  1490. - 0 indicates the head is **masked**.
  1491. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1492. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1493. `[0, 1]`:
  1494. - 1 indicates the head is **not masked**,
  1495. - 0 indicates the head is **masked**.
  1496. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1497. Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
  1498. config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
  1499. labels in `[0, ..., config.vocab_size]`
  1500. Examples:
  1501. ```python
  1502. >>> from transformers import AutoTokenizer, MT5ForConditionalGeneration
  1503. >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
  1504. >>> model = MT5ForConditionalGeneration.from_pretrained("google/mt5-small")
  1505. >>> # training
  1506. >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
  1507. >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
  1508. >>> outputs = model(input_ids=input_ids, labels=labels)
  1509. >>> loss = outputs.loss
  1510. >>> logits = outputs.logits
  1511. >>> # inference
  1512. >>> input_ids = tokenizer(
  1513. ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
  1514. ... ).input_ids # Batch size 1
  1515. >>> outputs = model.generate(input_ids)
  1516. >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  1517. >>> # studies have shown that owning a dog is good for you.
  1518. ```"""
  1519. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1520. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1521. # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  1522. if head_mask is not None and decoder_head_mask is None:
  1523. if self.config.num_layers == self.config.num_decoder_layers:
  1524. warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
  1525. decoder_head_mask = head_mask
  1526. # Encode if needed (training, first prediction pass)
  1527. if encoder_outputs is None:
  1528. # Convert encoder inputs in embeddings if needed
  1529. encoder_outputs = self.encoder(
  1530. input_ids=input_ids,
  1531. attention_mask=attention_mask,
  1532. inputs_embeds=inputs_embeds,
  1533. head_mask=head_mask,
  1534. output_attentions=output_attentions,
  1535. output_hidden_states=output_hidden_states,
  1536. return_dict=return_dict,
  1537. )
  1538. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1539. encoder_outputs = BaseModelOutput(
  1540. last_hidden_state=encoder_outputs[0],
  1541. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1542. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1543. )
  1544. hidden_states = encoder_outputs[0]
  1545. if self.model_parallel:
  1546. torch.cuda.set_device(self.decoder.first_device)
  1547. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  1548. # get decoder inputs from shifting lm labels to the right
  1549. decoder_input_ids = self._shift_right(labels)
  1550. # Set device for model parallelism
  1551. if self.model_parallel:
  1552. torch.cuda.set_device(self.decoder.first_device)
  1553. hidden_states = hidden_states.to(self.decoder.first_device)
  1554. if decoder_input_ids is not None:
  1555. decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
  1556. if attention_mask is not None:
  1557. attention_mask = attention_mask.to(self.decoder.first_device)
  1558. if decoder_attention_mask is not None:
  1559. decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
  1560. # Decode
  1561. decoder_outputs = self.decoder(
  1562. input_ids=decoder_input_ids,
  1563. attention_mask=decoder_attention_mask,
  1564. inputs_embeds=decoder_inputs_embeds,
  1565. past_key_values=past_key_values,
  1566. encoder_hidden_states=hidden_states,
  1567. encoder_attention_mask=attention_mask,
  1568. head_mask=decoder_head_mask,
  1569. cross_attn_head_mask=cross_attn_head_mask,
  1570. use_cache=use_cache,
  1571. output_attentions=output_attentions,
  1572. output_hidden_states=output_hidden_states,
  1573. return_dict=return_dict,
  1574. cache_position=cache_position,
  1575. )
  1576. sequence_output = decoder_outputs[0]
  1577. # Set device for model parallelism
  1578. if self.model_parallel:
  1579. torch.cuda.set_device(self.encoder.first_device)
  1580. self.lm_head = self.lm_head.to(self.encoder.first_device)
  1581. sequence_output = sequence_output.to(self.lm_head.weight.device)
  1582. if self.config.tie_word_embeddings:
  1583. # Rescale output before projecting on vocab
  1584. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
  1585. sequence_output = sequence_output * (self.model_dim**-0.5)
  1586. lm_logits = self.lm_head(sequence_output)
  1587. loss = None
  1588. if labels is not None:
  1589. loss_fct = CrossEntropyLoss(ignore_index=-100)
  1590. # move labels to correct device to enable PP
  1591. labels = labels.to(lm_logits.device)
  1592. loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
  1593. # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
  1594. if not return_dict:
  1595. output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
  1596. return ((loss,) + output) if loss is not None else output
  1597. return Seq2SeqLMOutput(
  1598. loss=loss,
  1599. logits=lm_logits,
  1600. past_key_values=decoder_outputs.past_key_values,
  1601. decoder_hidden_states=decoder_outputs.hidden_states,
  1602. decoder_attentions=decoder_outputs.attentions,
  1603. cross_attentions=decoder_outputs.cross_attentions,
  1604. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1605. encoder_hidden_states=encoder_outputs.hidden_states,
  1606. encoder_attentions=encoder_outputs.attentions,
  1607. )
  1608. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels
  1609. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1610. return self._shift_right(labels)
  1611. @auto_docstring
  1612. class MT5EncoderModel(MT5PreTrainedModel):
  1613. r"""
  1614. Examples:
  1615. ```python
  1616. >>> from transformers import MT5EncoderModel, AutoTokenizer
  1617. >>> model = MT5EncoderModel.from_pretrained("google/mt5-small")
  1618. >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
  1619. >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
  1620. >>> input_ids = tokenizer(article, return_tensors="pt").input_ids
  1621. >>> outputs = model(input_ids)
  1622. >>> hidden_state = outputs.last_hidden_state
  1623. ```"""
  1624. model_type = "mt5"
  1625. config: MT5Config
  1626. _tied_weights_keys = ["encoder.embed_tokens.weight"]
  1627. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.__init__ with T5->MT5
  1628. def __init__(self, config: MT5Config):
  1629. super().__init__(config)
  1630. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1631. encoder_config = config
  1632. encoder_config.use_cache = False
  1633. encoder_config.is_encoder_decoder = False
  1634. self.encoder = MT5Stack(encoder_config, self.shared)
  1635. # Initialize weights and apply final processing
  1636. self.post_init()
  1637. # Model parallel
  1638. self.model_parallel = False
  1639. self.device_map = None
  1640. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  1641. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.parallelize
  1642. def parallelize(self, device_map=None):
  1643. warnings.warn(
  1644. "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
  1645. " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
  1646. " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
  1647. " 'block.1': 1, ...}",
  1648. FutureWarning,
  1649. )
  1650. self.device_map = (
  1651. get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
  1652. if device_map is None
  1653. else device_map
  1654. )
  1655. assert_device_map(self.device_map, len(self.encoder.block))
  1656. self.encoder.parallelize(self.device_map)
  1657. self.model_parallel = True
  1658. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  1659. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.deparallelize
  1660. def deparallelize(self):
  1661. warnings.warn(
  1662. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  1663. FutureWarning,
  1664. )
  1665. self.encoder.deparallelize()
  1666. self.encoder = self.encoder.to("cpu")
  1667. self.model_parallel = False
  1668. self.device_map = None
  1669. torch.cuda.empty_cache()
  1670. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_input_embeddings
  1671. def get_input_embeddings(self):
  1672. return self.shared
  1673. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.set_input_embeddings
  1674. def set_input_embeddings(self, new_embeddings):
  1675. self.shared = new_embeddings
  1676. self.encoder.set_input_embeddings(new_embeddings)
  1677. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder
  1678. def get_encoder(self):
  1679. return self.encoder
  1680. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel._prune_heads
  1681. def _prune_heads(self, heads_to_prune):
  1682. """
  1683. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  1684. class PreTrainedModel
  1685. """
  1686. for layer, heads in heads_to_prune.items():
  1687. self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
  1688. @auto_docstring
  1689. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.forward with google-t5/->google/, T5->MT5, t5->mt5
  1690. def forward(
  1691. self,
  1692. input_ids: Optional[torch.LongTensor] = None,
  1693. attention_mask: Optional[torch.FloatTensor] = None,
  1694. head_mask: Optional[torch.FloatTensor] = None,
  1695. inputs_embeds: Optional[torch.FloatTensor] = None,
  1696. output_attentions: Optional[bool] = None,
  1697. output_hidden_states: Optional[bool] = None,
  1698. return_dict: Optional[bool] = None,
  1699. ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
  1700. r"""
  1701. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1702. Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
  1703. should be able to pad the inputs on both the right and the left.
  1704. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1705. [`PreTrainedTokenizer.__call__`] for detail.
  1706. To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training).
  1707. Example:
  1708. ```python
  1709. >>> from transformers import AutoTokenizer, MT5EncoderModel
  1710. >>> tokenizer = AutoTokenizer.from_pretrained("google/mt5-small")
  1711. >>> model = MT5EncoderModel.from_pretrained("google/mt5-small")
  1712. >>> input_ids = tokenizer(
  1713. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1714. ... ).input_ids # Batch size 1
  1715. >>> outputs = model(input_ids=input_ids)
  1716. >>> last_hidden_states = outputs.last_hidden_state
  1717. ```"""
  1718. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1719. encoder_outputs = self.encoder(
  1720. input_ids=input_ids,
  1721. attention_mask=attention_mask,
  1722. inputs_embeds=inputs_embeds,
  1723. head_mask=head_mask,
  1724. output_attentions=output_attentions,
  1725. output_hidden_states=output_hidden_states,
  1726. return_dict=return_dict,
  1727. )
  1728. return encoder_outputs
  1729. @auto_docstring(
  1730. custom_intro="""
  1731. MT5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
  1732. tasks.
  1733. """
  1734. )
  1735. class MT5ForSequenceClassification(MT5PreTrainedModel):
  1736. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1737. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1738. # Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.__init__ with T5->MT5
  1739. def __init__(self, config: MT5Config):
  1740. super().__init__(config)
  1741. self.transformer = MT5Model(config)
  1742. self.classification_head = MT5ClassificationHead(config)
  1743. # Initialize weights and apply final processing
  1744. self.post_init()
  1745. self.model_parallel = False
  1746. @auto_docstring
  1747. # Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.forward with T5->MT5, t5->mt5
  1748. def forward(
  1749. self,
  1750. input_ids: Optional[torch.LongTensor] = None,
  1751. attention_mask: Optional[torch.Tensor] = None,
  1752. decoder_input_ids: Optional[torch.LongTensor] = None,
  1753. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1754. head_mask: Optional[torch.Tensor] = None,
  1755. decoder_head_mask: Optional[torch.Tensor] = None,
  1756. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1757. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  1758. inputs_embeds: Optional[torch.FloatTensor] = None,
  1759. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1760. labels: Optional[torch.LongTensor] = None,
  1761. use_cache: Optional[bool] = None,
  1762. output_attentions: Optional[bool] = None,
  1763. output_hidden_states: Optional[bool] = None,
  1764. return_dict: Optional[bool] = None,
  1765. ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
  1766. r"""
  1767. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1768. Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
  1769. should be able to pad the inputs on both the right and the left.
  1770. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1771. [`PreTrainedTokenizer.__call__`] for detail.
  1772. [What are input IDs?](../glossary#input-ids)
  1773. To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./mt5#training).
  1774. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1775. Indices of decoder input sequence tokens in the vocabulary.
  1776. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1777. [`PreTrainedTokenizer.__call__`] for details.
  1778. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1779. MT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1780. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1781. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [MT5
  1782. Training](./mt5#training).
  1783. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1784. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1785. be used by default.
  1786. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1787. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1788. 1]`:
  1789. - 1 indicates the head is **not masked**,
  1790. - 0 indicates the head is **masked**.
  1791. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1792. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1793. `[0, 1]`:
  1794. - 1 indicates the head is **not masked**,
  1795. - 0 indicates the head is **masked**.
  1796. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1797. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1798. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1799. """
  1800. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1801. if labels is not None:
  1802. use_cache = False
  1803. if input_ids is None and inputs_embeds is not None:
  1804. raise NotImplementedError(
  1805. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  1806. )
  1807. # Copied from models.bart.modeling_bart.BartModel.forward different to other models, MT5 automatically creates
  1808. # decoder_input_ids from input_ids if no decoder_input_ids are provided
  1809. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1810. if input_ids is None:
  1811. raise ValueError(
  1812. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1813. "passed, `input_ids` cannot be `None`. Please pass either "
  1814. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1815. )
  1816. decoder_input_ids = self._shift_right(input_ids)
  1817. outputs = self.transformer(
  1818. input_ids,
  1819. attention_mask=attention_mask,
  1820. decoder_input_ids=decoder_input_ids,
  1821. decoder_attention_mask=decoder_attention_mask,
  1822. head_mask=head_mask,
  1823. decoder_head_mask=decoder_head_mask,
  1824. cross_attn_head_mask=cross_attn_head_mask,
  1825. encoder_outputs=encoder_outputs,
  1826. inputs_embeds=inputs_embeds,
  1827. decoder_inputs_embeds=decoder_inputs_embeds,
  1828. use_cache=use_cache,
  1829. output_attentions=output_attentions,
  1830. output_hidden_states=output_hidden_states,
  1831. return_dict=return_dict,
  1832. )
  1833. sequence_output = outputs[0]
  1834. eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
  1835. if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
  1836. raise ValueError("All examples must have the same number of <eos> tokens.")
  1837. batch_size, _, hidden_size = sequence_output.shape
  1838. sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
  1839. logits = self.classification_head(sentence_representation)
  1840. loss = None
  1841. if labels is not None:
  1842. labels = labels.to(logits.device)
  1843. if self.config.problem_type is None:
  1844. if self.config.num_labels == 1:
  1845. self.config.problem_type = "regression"
  1846. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1847. self.config.problem_type = "single_label_classification"
  1848. else:
  1849. self.config.problem_type = "multi_label_classification"
  1850. if self.config.problem_type == "regression":
  1851. loss_fct = MSELoss()
  1852. if self.config.num_labels == 1:
  1853. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1854. else:
  1855. loss = loss_fct(logits, labels)
  1856. elif self.config.problem_type == "single_label_classification":
  1857. loss_fct = CrossEntropyLoss()
  1858. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1859. elif self.config.problem_type == "multi_label_classification":
  1860. loss_fct = BCEWithLogitsLoss()
  1861. loss = loss_fct(logits, labels)
  1862. if not return_dict:
  1863. output = (logits,) + outputs[1:]
  1864. return ((loss,) + output) if loss is not None else output
  1865. return Seq2SeqSequenceClassifierOutput(
  1866. loss=loss,
  1867. logits=logits,
  1868. past_key_values=outputs.past_key_values,
  1869. decoder_hidden_states=outputs.decoder_hidden_states,
  1870. decoder_attentions=outputs.decoder_attentions,
  1871. cross_attentions=outputs.cross_attentions,
  1872. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1873. encoder_hidden_states=outputs.encoder_hidden_states,
  1874. encoder_attentions=outputs.encoder_attentions,
  1875. )
  1876. @auto_docstring
  1877. class MT5ForTokenClassification(MT5PreTrainedModel):
  1878. _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"]
  1879. # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.__init__ with T5->MT5
  1880. def __init__(self, config: MT5Config):
  1881. super().__init__(config)
  1882. self.num_labels = config.num_labels
  1883. self.transformer = MT5EncoderModel(config)
  1884. self.dropout = nn.Dropout(config.classifier_dropout)
  1885. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1886. # Initialize weights and apply final processing
  1887. self.post_init()
  1888. @auto_docstring
  1889. # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.forward with T5->MT5
  1890. def forward(
  1891. self,
  1892. input_ids: Optional[torch.Tensor] = None,
  1893. attention_mask: Optional[torch.Tensor] = None,
  1894. head_mask: Optional[torch.Tensor] = None,
  1895. inputs_embeds: Optional[torch.Tensor] = None,
  1896. labels: Optional[torch.Tensor] = None,
  1897. output_attentions: Optional[bool] = None,
  1898. output_hidden_states: Optional[bool] = None,
  1899. return_dict: Optional[bool] = None,
  1900. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  1901. r"""
  1902. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1903. Indices of input sequence tokens in the vocabulary. MT5 is a model with relative position embeddings so you
  1904. should be able to pad the inputs on both the right and the left.
  1905. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1906. [`PreTrainedTokenizer.__call__`] for detail.
  1907. [What are input IDs?](../glossary#input-ids)
  1908. To know more on how to prepare `input_ids` for pretraining take a look a [MT5 Training](./t5#training).
  1909. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1910. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1911. """
  1912. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1913. outputs = self.transformer(
  1914. input_ids,
  1915. attention_mask=attention_mask,
  1916. head_mask=head_mask,
  1917. inputs_embeds=inputs_embeds,
  1918. output_attentions=output_attentions,
  1919. output_hidden_states=output_hidden_states,
  1920. return_dict=return_dict,
  1921. )
  1922. hidden_states = outputs[0]
  1923. hidden_states = self.dropout(hidden_states)
  1924. logits = self.classifier(hidden_states)
  1925. loss = None
  1926. if labels is not None:
  1927. loss_fct = CrossEntropyLoss()
  1928. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1929. if not return_dict:
  1930. output = (logits, outputs[2:-1])
  1931. return ((loss,) + output) if loss is not None else output
  1932. return TokenClassifierOutput(
  1933. loss=loss,
  1934. logits=logits,
  1935. hidden_states=outputs.hidden_states,
  1936. attentions=outputs.attentions,
  1937. )
  1938. @auto_docstring
  1939. class MT5ForQuestionAnswering(MT5PreTrainedModel):
  1940. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1941. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1942. # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.__init__ with T5->MT5
  1943. def __init__(self, config: MT5Config):
  1944. super().__init__(config)
  1945. self.model_dim = config.d_model
  1946. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1947. encoder_config = copy.deepcopy(config)
  1948. encoder_config.is_decoder = False
  1949. encoder_config.use_cache = False
  1950. encoder_config.tie_encoder_decoder = False
  1951. self.encoder = MT5Stack(encoder_config, self.shared)
  1952. decoder_config = copy.deepcopy(config)
  1953. decoder_config.is_decoder = True
  1954. decoder_config.tie_encoder_decoder = False
  1955. decoder_config.num_layers = config.num_decoder_layers
  1956. self.decoder = MT5Stack(decoder_config, self.shared)
  1957. self.num_labels = config.num_labels
  1958. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1959. # Initialize weights and apply final processing
  1960. self.post_init()
  1961. self.model_parallel = False
  1962. # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_input_embeddings
  1963. def get_input_embeddings(self):
  1964. return self.shared
  1965. # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.set_input_embeddings
  1966. def set_input_embeddings(self, new_embeddings):
  1967. self.shared = new_embeddings
  1968. self.encoder.set_input_embeddings(new_embeddings)
  1969. self.decoder.set_input_embeddings(new_embeddings)
  1970. # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder
  1971. def get_encoder(self):
  1972. return self.encoder
  1973. @auto_docstring
  1974. # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.forward
  1975. def forward(
  1976. self,
  1977. input_ids: Optional[torch.LongTensor] = None,
  1978. attention_mask: Optional[torch.FloatTensor] = None,
  1979. decoder_input_ids: Optional[torch.LongTensor] = None,
  1980. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1981. head_mask: Optional[torch.FloatTensor] = None,
  1982. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1983. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1984. encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
  1985. start_positions: Optional[torch.LongTensor] = None,
  1986. end_positions: Optional[torch.LongTensor] = None,
  1987. inputs_embeds: Optional[torch.FloatTensor] = None,
  1988. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1989. use_cache: Optional[bool] = None,
  1990. output_attentions: Optional[bool] = None,
  1991. output_hidden_states: Optional[bool] = None,
  1992. return_dict: Optional[bool] = None,
  1993. ) -> Union[tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
  1994. r"""
  1995. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1996. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1997. should be able to pad the inputs on both the right and the left.
  1998. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1999. [`PreTrainedTokenizer.__call__`] for detail.
  2000. [What are input IDs?](../glossary#input-ids)
  2001. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  2002. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  2003. Indices of decoder input sequence tokens in the vocabulary.
  2004. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  2005. [`PreTrainedTokenizer.__call__`] for details.
  2006. [What are decoder input IDs?](../glossary#decoder-input-ids)
  2007. T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  2008. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  2009. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
  2010. Training](./t5#training).
  2011. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  2012. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  2013. be used by default.
  2014. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  2015. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  2016. 1]`:
  2017. - 1 indicates the head is **not masked**,
  2018. - 0 indicates the head is **masked**.
  2019. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  2020. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  2021. `[0, 1]`:
  2022. - 1 indicates the head is **not masked**,
  2023. - 0 indicates the head is **masked**.
  2024. """
  2025. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  2026. use_cache = use_cache if use_cache is not None else self.config.use_cache
  2027. if start_positions is not None and end_positions is not None:
  2028. use_cache = False
  2029. # Copied from models.bart.modeling_bart.BartModel.forward
  2030. # different to other models, T5 automatically creates decoder_input_ids from
  2031. # input_ids if no decoder_input_ids are provided
  2032. if decoder_input_ids is None and decoder_inputs_embeds is None:
  2033. if input_ids is None:
  2034. raise ValueError(
  2035. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  2036. "passed, `input_ids` cannot be `None`. Please pass either "
  2037. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  2038. )
  2039. decoder_input_ids = self._shift_right(input_ids)
  2040. use_cache = use_cache if use_cache is not None else self.config.use_cache
  2041. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  2042. # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  2043. if head_mask is not None and decoder_head_mask is None:
  2044. if self.config.num_layers == self.config.num_decoder_layers:
  2045. warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
  2046. decoder_head_mask = head_mask
  2047. # Encode if needed (training, first prediction pass)
  2048. if encoder_outputs is None:
  2049. encoder_outputs = self.encoder(
  2050. input_ids=input_ids,
  2051. attention_mask=attention_mask,
  2052. inputs_embeds=inputs_embeds,
  2053. head_mask=head_mask,
  2054. output_attentions=output_attentions,
  2055. output_hidden_states=output_hidden_states,
  2056. return_dict=return_dict,
  2057. )
  2058. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  2059. encoder_outputs = BaseModelOutput(
  2060. last_hidden_state=encoder_outputs[0],
  2061. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  2062. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  2063. )
  2064. hidden_states = encoder_outputs[0]
  2065. # Decode
  2066. decoder_outputs = self.decoder(
  2067. input_ids=decoder_input_ids,
  2068. attention_mask=decoder_attention_mask,
  2069. inputs_embeds=decoder_inputs_embeds,
  2070. past_key_values=None,
  2071. encoder_hidden_states=hidden_states,
  2072. encoder_attention_mask=attention_mask,
  2073. head_mask=decoder_head_mask,
  2074. cross_attn_head_mask=cross_attn_head_mask,
  2075. use_cache=use_cache,
  2076. output_attentions=output_attentions,
  2077. output_hidden_states=output_hidden_states,
  2078. return_dict=return_dict,
  2079. )
  2080. sequence_output = decoder_outputs[0]
  2081. logits = self.qa_outputs(sequence_output)
  2082. start_logits, end_logits = logits.split(1, dim=-1)
  2083. start_logits = start_logits.squeeze(-1).contiguous()
  2084. end_logits = end_logits.squeeze(-1).contiguous()
  2085. total_loss = None
  2086. if start_positions is not None and end_positions is not None:
  2087. # If we are on multi-GPU, split add a dimension
  2088. if len(start_positions.size()) > 1:
  2089. start_positions = start_positions.squeeze(-1).to(start_logits.device)
  2090. if len(end_positions.size()) > 1:
  2091. end_positions = end_positions.squeeze(-1).to(end_logits.device)
  2092. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  2093. ignored_index = start_logits.size(1)
  2094. start_positions = start_positions.clamp(0, ignored_index)
  2095. end_positions = end_positions.clamp(0, ignored_index)
  2096. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  2097. start_loss = loss_fct(start_logits, start_positions)
  2098. end_loss = loss_fct(end_logits, end_positions)
  2099. total_loss = (start_loss + end_loss) / 2
  2100. if not return_dict:
  2101. output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
  2102. return ((total_loss,) + output) if total_loss is not None else output
  2103. return Seq2SeqQuestionAnsweringModelOutput(
  2104. loss=total_loss,
  2105. start_logits=start_logits,
  2106. end_logits=end_logits,
  2107. past_key_values=decoder_outputs.past_key_values,
  2108. decoder_hidden_states=decoder_outputs.hidden_states,
  2109. decoder_attentions=decoder_outputs.attentions,
  2110. cross_attentions=decoder_outputs.cross_attentions,
  2111. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  2112. encoder_hidden_states=encoder_outputs.hidden_states,
  2113. encoder_attentions=encoder_outputs.attentions,
  2114. )
# Public names of this module: the targets of `from ... import *`. Kept as a plain, sorted list
# literal; NOTE(review): the library's lazy-import tooling may read this literal statically — confirm
# before turning it into a computed expression.
__all__ = [
    "MT5EncoderModel",
    "MT5ForConditionalGeneration",
    "MT5ForQuestionAnswering",
    "MT5ForSequenceClassification",
    "MT5ForTokenClassification",
    "MT5Model",
    "MT5PreTrainedModel",
]