modeling_t5.py 107 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407
  1. # coding=utf-8
  2. # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch T5 model."""
  16. import copy
  17. import math
  18. import os
  19. import warnings
  20. from typing import Optional, Union
  21. import torch
  22. from torch import nn
  23. from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
  24. from ...activations import ACT2FN
  25. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  26. from ...generation import GenerationMixin
  27. from ...modeling_attn_mask_utils import AttentionMaskConverter
  28. from ...modeling_layers import GradientCheckpointingLayer
  29. from ...modeling_outputs import (
  30. BaseModelOutput,
  31. BaseModelOutputWithPastAndCrossAttentions,
  32. Seq2SeqLMOutput,
  33. Seq2SeqModelOutput,
  34. Seq2SeqQuestionAnsweringModelOutput,
  35. Seq2SeqSequenceClassifierOutput,
  36. TokenClassifierOutput,
  37. )
  38. from ...modeling_utils import PreTrainedModel
  39. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  40. from ...utils import (
  41. DUMMY_INPUTS,
  42. DUMMY_MASK,
  43. add_start_docstrings,
  44. auto_docstring,
  45. is_torch_flex_attn_available,
  46. is_torch_fx_proxy,
  47. is_torchdynamo_compiling,
  48. logging,
  49. )
  50. from ...utils.deprecation import deprecate_kwarg
  51. from ...utils.model_parallel_utils import assert_device_map, get_device_map
  52. from .configuration_t5 import T5Config
  53. if is_torch_flex_attn_available():
  54. from torch.nn.attention.flex_attention import BlockMask
  55. from ...integrations.flex_attention import make_flex_block_causal_mask
# Module-level logger shared by the functions and layers defined in this file.
logger = logging.get_logger(__name__)

####################################################
# This dict contains ids and associated url
# for the pretrained weights provided with the models
# NOTE(review): no such dict is defined under this banner; it appears to be a leftover.
####################################################

####################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
  65. def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
  66. """Load tf checkpoints in a pytorch model."""
  67. try:
  68. import re
  69. import numpy as np
  70. import tensorflow as tf
  71. except ImportError:
  72. logger.error(
  73. "Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see "
  74. "https://www.tensorflow.org/install/ for installation instructions."
  75. )
  76. raise
  77. tf_path = os.path.abspath(tf_checkpoint_path)
  78. logger.info(f"Converting TensorFlow checkpoint from {tf_path}")
  79. # Load weights from TF model
  80. init_vars = tf.train.list_variables(tf_path)
  81. names = []
  82. tf_weights = {}
  83. for name, shape in init_vars:
  84. logger.info(f"Loading TF weight {name} with shape {shape}")
  85. array = tf.train.load_variable(tf_path, name)
  86. names.append(name)
  87. tf_weights[name] = array
  88. for txt_name in names:
  89. name = txt_name.split("/")
  90. # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  91. # which are not required for using pretrained model
  92. if any(
  93. n in ["adam_v", "adam_m", "AdamWeightDecayOptimizer", "AdamWeightDecayOptimizer_1", "global_step"]
  94. for n in name
  95. ):
  96. logger.info(f"Skipping {'/'.join(name)}")
  97. tf_weights.pop(txt_name, None)
  98. continue
  99. if "_slot_" in name[-1]:
  100. logger.info(f"Skipping {'/'.join(name)}")
  101. tf_weights.pop(txt_name, None)
  102. continue
  103. pointer = model
  104. array = tf_weights[txt_name]
  105. for m_name in name:
  106. if re.fullmatch(r"[A-Za-z]+_\d+", m_name):
  107. scope_names = re.split(r"_(\d+)", m_name)
  108. else:
  109. scope_names = [m_name]
  110. if scope_names[0] in ["kernel", "scale", "embedding"]:
  111. pointer = getattr(pointer, "weight")
  112. elif scope_names[0] == "self_attention":
  113. pointer = getattr(pointer, "layer")
  114. pointer = pointer[0]
  115. elif scope_names[0] == "enc_dec_attention":
  116. pointer = getattr(pointer, "layer")
  117. pointer = pointer[1]
  118. elif scope_names[0] == "dense_relu_dense":
  119. pointer = getattr(pointer, "layer")
  120. pointer = pointer[2]
  121. elif scope_names[0] == "rms_norm":
  122. if hasattr(pointer, "layer_norm"):
  123. pointer = getattr(pointer, "layer_norm")
  124. elif hasattr(pointer, "final_layer_norm"):
  125. pointer = getattr(pointer, "final_layer_norm")
  126. elif scope_names[0] == "scale":
  127. pointer = getattr(pointer, "weight")
  128. elif scope_names[0] == "output_bias" or scope_names[0] == "beta":
  129. pointer = getattr(pointer, "bias")
  130. elif scope_names[0] == "squad":
  131. pointer = getattr(pointer, "classifier")
  132. elif scope_names[0] == "decoder" and name[1] == "logits":
  133. continue
  134. elif scope_names[0] == "logits":
  135. pointer = getattr(pointer, "lm_head")
  136. elif scope_names[0] == "wi" and len(scope_names) > 1 and scope_names[1].isdigit():
  137. pointer = getattr(pointer, f"wi_{scope_names[1]}")
  138. continue
  139. else:
  140. try:
  141. pointer = getattr(pointer, scope_names[0])
  142. except AttributeError:
  143. logger.info(f"Skipping {'/'.join(name)}")
  144. continue
  145. if len(scope_names) >= 2:
  146. num = int(scope_names[1])
  147. pointer = pointer[num]
  148. if scope_names[0] not in ["kernel", "scale", "embedding"]:
  149. pointer = getattr(pointer, "weight")
  150. if scope_names[0] != "embedding":
  151. logger.info(f"Transposing numpy weight of shape {array.shape} for {name}")
  152. array = np.transpose(array)
  153. try:
  154. if pointer.shape != array.shape:
  155. raise ValueError(f"Pointer shape {pointer.shape} and array shape {array.shape} mismatched")
  156. except AssertionError as e:
  157. e.args += (pointer.shape, array.shape)
  158. raise
  159. logger.info(f"Initialize PyTorch weight {name}")
  160. pointer.data = torch.from_numpy(array.astype(np.float32))
  161. tf_weights.pop(txt_name, None)
  162. logger.info(f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}.")
  163. return model
  164. ####################################################
  165. # PyTorch Models are constructed by sub-classing
  166. # - torch.nn.Module for the layers and
# - PreTrainedModel for the models (itself a subclass of nn.Module)
  168. ####################################################
  169. PARALLELIZE_DOCSTRING = r"""
  170. This is an experimental feature and is a subject to change at a moment's notice.
  171. Uses a device map to distribute attention modules of the model across several devices. If no device map is given,
  172. it will evenly distribute blocks across all devices.
  173. Args:
  174. device_map (`dict[int, list]`, *optional*):
  175. A dictionary that maps attention modules to devices. Note that the embedding module and LMHead are always
  176. automatically mapped to the first device (for esoteric reasons). That means that the first device should
  177. have fewer attention modules mapped to it than other devices. For reference, the t5 models have the
  178. following number of attention modules:
  179. - google-t5/t5-small: 6
  180. - google-t5/t5-base: 12
  181. - google-t5/t5-large: 24
  182. - google-t5/t5-3b: 24
  183. - google-t5/t5-11b: 24
  184. Example:
  185. ```python
  186. # Here is an example of a device map on a machine with 4 GPUs using google-t5/t5-3b, which has a total of 24 attention modules:
  187. model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-3b")
  188. device_map = {
  189. 0: [0, 1, 2],
  190. 1: [3, 4, 5, 6, 7, 8, 9],
  191. 2: [10, 11, 12, 13, 14, 15, 16],
  192. 3: [17, 18, 19, 20, 21, 22, 23],
  193. }
  194. model.parallelize(device_map)
  195. ```
  196. """
  197. DEPARALLELIZE_DOCSTRING = r"""
  198. Moves the model to cpu from a model parallel state.
  199. Example:
  200. ```python
  201. # On a 4 GPU machine with google-t5/t5-3b:
  202. model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-3b")
  203. device_map = {
  204. 0: [0, 1, 2],
  205. 1: [3, 4, 5, 6, 7, 8, 9],
  206. 2: [10, 11, 12, 13, 14, 15, 16],
  207. 3: [17, 18, 19, 20, 21, 22, 23],
  208. }
  209. model.parallelize(device_map) # Splits the model across several devices
  210. model.deparallelize() # Put the model back on cpu and cleans memory by calling torch.cuda.empty_cache()
  211. ```
  212. """
  213. class T5LayerNorm(nn.Module):
  214. def __init__(self, hidden_size, eps=1e-6):
  215. """
  216. Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
  217. """
  218. super().__init__()
  219. self.weight = nn.Parameter(torch.ones(hidden_size))
  220. self.variance_epsilon = eps
  221. def forward(self, hidden_states):
  222. # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
  223. # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
  224. # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
  225. # half-precision inputs is done in fp32
  226. variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
  227. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  228. # convert into half-precision if necessary
  229. if self.weight.dtype in [torch.float16, torch.bfloat16]:
  230. hidden_states = hidden_states.to(self.weight.dtype)
  231. return self.weight * hidden_states
  232. try:
  233. from apex.normalization import FusedRMSNorm
  234. T5LayerNorm = FusedRMSNorm
  235. logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of T5LayerNorm")
  236. except ImportError:
  237. # using the normal T5LayerNorm
  238. pass
  239. except Exception:
  240. logger.warning("discovered apex but it failed to load, falling back to T5LayerNorm")
  241. pass
  242. class T5DenseActDense(nn.Module):
  243. def __init__(self, config: T5Config):
  244. super().__init__()
  245. self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
  246. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  247. self.dropout = nn.Dropout(config.dropout_rate)
  248. self.act = ACT2FN[config.dense_act_fn]
  249. def forward(self, hidden_states):
  250. hidden_states = self.wi(hidden_states)
  251. hidden_states = self.act(hidden_states)
  252. hidden_states = self.dropout(hidden_states)
  253. if (
  254. isinstance(self.wo.weight, torch.Tensor)
  255. and hidden_states.dtype != self.wo.weight.dtype
  256. and self.wo.weight.dtype != torch.int8
  257. ):
  258. hidden_states = hidden_states.to(self.wo.weight.dtype)
  259. hidden_states = self.wo(hidden_states)
  260. return hidden_states
  261. class T5DenseGatedActDense(nn.Module):
  262. def __init__(self, config: T5Config):
  263. super().__init__()
  264. self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
  265. self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
  266. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  267. self.dropout = nn.Dropout(config.dropout_rate)
  268. self.act = ACT2FN[config.dense_act_fn]
  269. def forward(self, hidden_states):
  270. hidden_gelu = self.act(self.wi_0(hidden_states))
  271. hidden_linear = self.wi_1(hidden_states)
  272. hidden_states = hidden_gelu * hidden_linear
  273. hidden_states = self.dropout(hidden_states)
  274. # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
  275. # See https://github.com/huggingface/transformers/issues/20287
  276. # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
  277. if (
  278. isinstance(self.wo.weight, torch.Tensor)
  279. and hidden_states.dtype != self.wo.weight.dtype
  280. and self.wo.weight.dtype != torch.int8
  281. ):
  282. hidden_states = hidden_states.to(self.wo.weight.dtype)
  283. hidden_states = self.wo(hidden_states)
  284. return hidden_states
  285. class T5LayerFF(nn.Module):
  286. def __init__(self, config: T5Config):
  287. super().__init__()
  288. if config.is_gated_act:
  289. self.DenseReluDense = T5DenseGatedActDense(config)
  290. else:
  291. self.DenseReluDense = T5DenseActDense(config)
  292. self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  293. self.dropout = nn.Dropout(config.dropout_rate)
  294. def forward(self, hidden_states):
  295. forwarded_states = self.layer_norm(hidden_states)
  296. forwarded_states = self.DenseReluDense(forwarded_states)
  297. hidden_states = hidden_states + self.dropout(forwarded_states)
  298. return hidden_states
  299. class T5Attention(nn.Module):
  300. def __init__(
  301. self,
  302. config: T5Config,
  303. has_relative_attention_bias=False,
  304. layer_idx: Optional[int] = None,
  305. ):
  306. super().__init__()
  307. self.is_decoder = config.is_decoder
  308. self.has_relative_attention_bias = has_relative_attention_bias
  309. self.relative_attention_num_buckets = config.relative_attention_num_buckets
  310. self.relative_attention_max_distance = config.relative_attention_max_distance
  311. self.d_model = config.d_model
  312. self.key_value_proj_dim = config.d_kv
  313. self.n_heads = config.num_heads
  314. self.dropout = config.dropout_rate
  315. self.inner_dim = self.n_heads * self.key_value_proj_dim
  316. self.layer_idx = layer_idx
  317. if layer_idx is None and self.is_decoder:
  318. logger.warning_once(
  319. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  320. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  321. "when creating this class."
  322. )
  323. # Mesh TensorFlow initialization to avoid scaling before softmax
  324. self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
  325. self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
  326. self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
  327. self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
  328. if self.has_relative_attention_bias:
  329. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  330. self.pruned_heads = set()
  331. self.gradient_checkpointing = False
  332. def prune_heads(self, heads):
  333. if len(heads) == 0:
  334. return
  335. heads, index = find_pruneable_heads_and_indices(
  336. heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
  337. )
  338. # Prune linear layers
  339. self.q = prune_linear_layer(self.q, index)
  340. self.k = prune_linear_layer(self.k, index)
  341. self.v = prune_linear_layer(self.v, index)
  342. self.o = prune_linear_layer(self.o, index, dim=1)
  343. # Update hyper params
  344. self.n_heads = self.n_heads - len(heads)
  345. self.inner_dim = self.key_value_proj_dim * self.n_heads
  346. self.pruned_heads = self.pruned_heads.union(heads)
  347. @staticmethod
  348. def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  349. """
  350. Adapted from Mesh Tensorflow:
  351. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  352. Translate relative position to a bucket number for relative attention. The relative position is defined as
  353. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  354. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
  355. small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
  356. positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
  357. This should allow for more graceful generalization to longer sequences than the model has been trained on
  358. Args:
  359. relative_position: an int32 Tensor
  360. bidirectional: a boolean - whether the attention is bidirectional
  361. num_buckets: an integer
  362. max_distance: an integer
  363. Returns:
  364. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  365. """
  366. relative_buckets = 0
  367. if bidirectional:
  368. num_buckets //= 2
  369. relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
  370. relative_position = torch.abs(relative_position)
  371. else:
  372. relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
  373. # now relative_position is in the range [0, inf)
  374. # half of the buckets are for exact increments in positions
  375. max_exact = num_buckets // 2
  376. is_small = relative_position < max_exact
  377. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  378. relative_position_if_large = max_exact + (
  379. torch.log(relative_position.float() / max_exact)
  380. / math.log(max_distance / max_exact)
  381. * (num_buckets - max_exact)
  382. ).to(torch.long)
  383. relative_position_if_large = torch.min(
  384. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  385. )
  386. relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
  387. return relative_buckets
  388. def compute_bias(self, query_length, key_length, device=None, cache_position=None):
  389. """Compute binned relative position bias"""
  390. if device is None:
  391. device = self.relative_attention_bias.weight.device
  392. if cache_position is None:
  393. context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
  394. else:
  395. context_position = cache_position[:, None].to(device)
  396. memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
  397. relative_position = memory_position - context_position # shape (query_length, key_length)
  398. relative_position_bucket = self._relative_position_bucket(
  399. relative_position, # shape (query_length, key_length)
  400. bidirectional=(not self.is_decoder),
  401. num_buckets=self.relative_attention_num_buckets,
  402. max_distance=self.relative_attention_max_distance,
  403. )
  404. values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
  405. values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
  406. return values
  407. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  408. def forward(
  409. self,
  410. hidden_states,
  411. mask=None,
  412. key_value_states=None,
  413. position_bias=None,
  414. past_key_values=None,
  415. layer_head_mask=None,
  416. query_length=None,
  417. use_cache=False,
  418. output_attentions=False,
  419. cache_position=None,
  420. ):
  421. """
  422. Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
  423. """
  424. # Input is (batch_size, seq_length, dim)
  425. # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
  426. batch_size, seq_length = hidden_states.shape[:2]
  427. # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
  428. is_cross_attention = key_value_states is not None
  429. query_states = self.q(hidden_states)
  430. query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  431. # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
  432. is_updated = False
  433. if isinstance(past_key_values, EncoderDecoderCache):
  434. is_updated = past_key_values.is_updated.get(self.layer_idx)
  435. if is_cross_attention:
  436. # after the first generated id, we can subsequently re-use all key/value_states from cache
  437. curr_past_key_value = past_key_values.cross_attention_cache
  438. else:
  439. curr_past_key_value = past_key_values.self_attention_cache
  440. else:
  441. curr_past_key_value = past_key_values
  442. current_states = key_value_states if is_cross_attention else hidden_states
  443. if is_cross_attention and past_key_values is not None and is_updated:
  444. # reuse k,v, cross_attentions
  445. key_states = curr_past_key_value.layers[self.layer_idx].keys
  446. value_states = curr_past_key_value.layers[self.layer_idx].values
  447. else:
  448. key_states = self.k(current_states)
  449. value_states = self.v(current_states)
  450. key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  451. value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  452. if past_key_values is not None:
  453. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  454. cache_position = cache_position if not is_cross_attention else None
  455. key_states, value_states = curr_past_key_value.update(
  456. key_states, value_states, self.layer_idx, {"cache_position": cache_position}
  457. )
  458. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  459. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  460. past_key_values.is_updated[self.layer_idx] = True
  461. # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
  462. scores = torch.matmul(query_states, key_states.transpose(3, 2))
  463. if position_bias is None:
  464. key_length = key_states.shape[-2]
  465. # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
  466. real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
  467. if not self.has_relative_attention_bias:
  468. position_bias = torch.zeros(
  469. (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
  470. )
  471. if self.gradient_checkpointing and self.training:
  472. position_bias.requires_grad = True
  473. else:
  474. position_bias = self.compute_bias(
  475. real_seq_length, key_length, device=scores.device, cache_position=cache_position
  476. )
  477. position_bias = position_bias[:, :, -seq_length:, :]
  478. if mask is not None:
  479. causal_mask = mask[:, :, :, : key_states.shape[-2]]
  480. position_bias = position_bias + causal_mask
  481. if self.pruned_heads:
  482. mask = torch.ones(position_bias.shape[1])
  483. mask[list(self.pruned_heads)] = 0
  484. position_bias_masked = position_bias[:, mask.bool()]
  485. else:
  486. position_bias_masked = position_bias
  487. scores += position_bias_masked
  488. # (batch_size, n_heads, seq_length, key_length)
  489. attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
  490. attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  491. # Mask heads if we want to
  492. if layer_head_mask is not None:
  493. attn_weights = attn_weights * layer_head_mask
  494. attn_output = torch.matmul(attn_weights, value_states)
  495. attn_output = attn_output.transpose(1, 2).contiguous()
  496. attn_output = attn_output.view(batch_size, -1, self.inner_dim)
  497. attn_output = self.o(attn_output)
  498. outputs = (attn_output, position_bias)
  499. if output_attentions:
  500. outputs = outputs + (attn_weights,)
  501. return outputs
  502. class T5LayerSelfAttention(nn.Module):
  503. def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
  504. super().__init__()
  505. self.SelfAttention = T5Attention(
  506. config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
  507. )
  508. self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  509. self.dropout = nn.Dropout(config.dropout_rate)
  510. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  511. def forward(
  512. self,
  513. hidden_states,
  514. attention_mask=None,
  515. position_bias=None,
  516. layer_head_mask=None,
  517. past_key_values=None,
  518. use_cache=False,
  519. output_attentions=False,
  520. cache_position=None,
  521. ):
  522. normed_hidden_states = self.layer_norm(hidden_states)
  523. attention_output = self.SelfAttention(
  524. normed_hidden_states,
  525. mask=attention_mask,
  526. position_bias=position_bias,
  527. layer_head_mask=layer_head_mask,
  528. past_key_values=past_key_values,
  529. use_cache=use_cache,
  530. output_attentions=output_attentions,
  531. cache_position=cache_position,
  532. )
  533. hidden_states = hidden_states + self.dropout(attention_output[0])
  534. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  535. return outputs
  536. class T5LayerCrossAttention(nn.Module):
  537. def __init__(self, config, layer_idx: Optional[int] = None):
  538. super().__init__()
  539. self.EncDecAttention = T5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
  540. self.layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  541. self.dropout = nn.Dropout(config.dropout_rate)
  542. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  543. def forward(
  544. self,
  545. hidden_states,
  546. key_value_states,
  547. attention_mask=None,
  548. position_bias=None,
  549. layer_head_mask=None,
  550. past_key_values=None,
  551. use_cache=False,
  552. query_length=None,
  553. output_attentions=False,
  554. cache_position=None,
  555. ):
  556. normed_hidden_states = self.layer_norm(hidden_states)
  557. attention_output = self.EncDecAttention(
  558. normed_hidden_states,
  559. mask=attention_mask,
  560. key_value_states=key_value_states,
  561. position_bias=position_bias,
  562. layer_head_mask=layer_head_mask,
  563. past_key_values=past_key_values,
  564. use_cache=use_cache,
  565. query_length=query_length,
  566. output_attentions=output_attentions,
  567. cache_position=cache_position,
  568. )
  569. layer_output = hidden_states + self.dropout(attention_output[0])
  570. outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
  571. return outputs
  572. class T5Block(GradientCheckpointingLayer):
  573. def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
  574. super().__init__()
  575. self.is_decoder = config.is_decoder
  576. self.layer = nn.ModuleList()
  577. self.layer.append(
  578. T5LayerSelfAttention(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
  579. )
  580. if self.is_decoder:
  581. self.layer.append(T5LayerCrossAttention(config, layer_idx=layer_idx))
  582. self.layer.append(T5LayerFF(config))
  583. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  584. def forward(
  585. self,
  586. hidden_states,
  587. attention_mask=None,
  588. position_bias=None,
  589. encoder_hidden_states=None,
  590. encoder_attention_mask=None,
  591. encoder_decoder_position_bias=None,
  592. layer_head_mask=None,
  593. cross_attn_layer_head_mask=None,
  594. past_key_values=None,
  595. use_cache=False,
  596. output_attentions=False,
  597. return_dict=True,
  598. cache_position=None,
  599. ):
  600. self_attention_outputs = self.layer[0](
  601. hidden_states,
  602. attention_mask=attention_mask,
  603. position_bias=position_bias,
  604. layer_head_mask=layer_head_mask,
  605. past_key_values=past_key_values,
  606. use_cache=use_cache,
  607. output_attentions=output_attentions,
  608. cache_position=cache_position,
  609. )
  610. hidden_states = self_attention_outputs[0]
  611. attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
  612. # clamp inf values to enable fp16 training
  613. if hidden_states.dtype == torch.float16:
  614. clamp_value = torch.where(
  615. torch.isinf(hidden_states).any(),
  616. torch.finfo(hidden_states.dtype).max - 1000,
  617. torch.finfo(hidden_states.dtype).max,
  618. )
  619. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  620. do_cross_attention = self.is_decoder and encoder_hidden_states is not None
  621. if do_cross_attention:
  622. cross_attention_outputs = self.layer[1](
  623. hidden_states,
  624. key_value_states=encoder_hidden_states,
  625. attention_mask=encoder_attention_mask,
  626. position_bias=encoder_decoder_position_bias,
  627. layer_head_mask=cross_attn_layer_head_mask,
  628. past_key_values=past_key_values,
  629. query_length=cache_position[-1] + 1,
  630. use_cache=use_cache,
  631. output_attentions=output_attentions,
  632. )
  633. hidden_states = cross_attention_outputs[0]
  634. # clamp inf values to enable fp16 training
  635. if hidden_states.dtype == torch.float16:
  636. clamp_value = torch.where(
  637. torch.isinf(hidden_states).any(),
  638. torch.finfo(hidden_states.dtype).max - 1000,
  639. torch.finfo(hidden_states.dtype).max,
  640. )
  641. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  642. # Keep cross-attention outputs and relative position weights
  643. attention_outputs = attention_outputs + cross_attention_outputs[1:]
  644. # Apply Feed Forward layer
  645. hidden_states = self.layer[-1](hidden_states)
  646. # clamp inf values to enable fp16 training
  647. if hidden_states.dtype == torch.float16:
  648. clamp_value = torch.where(
  649. torch.isinf(hidden_states).any(),
  650. torch.finfo(hidden_states.dtype).max - 1000,
  651. torch.finfo(hidden_states.dtype).max,
  652. )
  653. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  654. outputs = (hidden_states,)
  655. return (
  656. outputs + attention_outputs
  657. ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
  658. class T5ClassificationHead(nn.Module):
  659. """Head for sentence-level classification tasks."""
  660. def __init__(self, config: T5Config):
  661. super().__init__()
  662. self.dense = nn.Linear(config.d_model, config.d_model)
  663. self.dropout = nn.Dropout(p=config.classifier_dropout)
  664. self.out_proj = nn.Linear(config.d_model, config.num_labels)
  665. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  666. hidden_states = self.dropout(hidden_states)
  667. hidden_states = self.dense(hidden_states)
  668. hidden_states = torch.tanh(hidden_states)
  669. hidden_states = self.dropout(hidden_states)
  670. hidden_states = self.out_proj(hidden_states)
  671. return hidden_states
  672. @auto_docstring
  673. class T5PreTrainedModel(PreTrainedModel):
  674. config: T5Config
  675. load_tf_weights = load_tf_weights_in_t5
  676. base_model_prefix = "transformer"
  677. is_parallelizable = True
  678. supports_gradient_checkpointing = True
  679. _can_compile_fullgraph = True
  680. _no_split_modules = ["T5Block"]
  681. _keep_in_fp32_modules = ["wo"]
  682. @property
  683. def dummy_inputs(self):
  684. input_ids = torch.tensor(DUMMY_INPUTS)
  685. input_mask = torch.tensor(DUMMY_MASK)
  686. dummy_inputs = {
  687. "decoder_input_ids": input_ids,
  688. "input_ids": input_ids,
  689. "decoder_attention_mask": input_mask,
  690. }
  691. return dummy_inputs
  692. def _init_weights(self, module):
  693. """Initialize the weights"""
  694. factor = self.config.initializer_factor # Used for testing weights initialization
  695. if isinstance(module, T5LayerNorm):
  696. module.weight.data.fill_(factor * 1.0)
  697. elif isinstance(
  698. module,
  699. (T5Model, T5ForConditionalGeneration, T5EncoderModel, T5ForQuestionAnswering),
  700. ):
  701. # Mesh TensorFlow embeddings initialization
  702. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
  703. module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
  704. if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
  705. module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
  706. if hasattr(module, "qa_outputs"):
  707. module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  708. module.qa_outputs.bias.data.zero_()
  709. elif isinstance(module, T5ForTokenClassification):
  710. if hasattr(module, "classifier"):
  711. module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0)
  712. module.classifier.bias.data.zero_()
  713. elif isinstance(module, T5ClassificationHead):
  714. module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  715. if hasattr(module.dense, "bias") and module.dense.bias is not None:
  716. module.dense.bias.data.zero_()
  717. module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  718. if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
  719. module.out_proj.bias.data.zero_()
  720. elif isinstance(module, T5DenseActDense):
  721. # Mesh TensorFlow FF initialization
  722. # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
  723. # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
  724. module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  725. if hasattr(module.wi, "bias") and module.wi.bias is not None:
  726. module.wi.bias.data.zero_()
  727. module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
  728. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  729. module.wo.bias.data.zero_()
  730. elif isinstance(module, T5DenseGatedActDense):
  731. module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  732. if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
  733. module.wi_0.bias.data.zero_()
  734. module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  735. if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
  736. module.wi_1.bias.data.zero_()
  737. module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
  738. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  739. module.wo.bias.data.zero_()
  740. elif isinstance(module, T5Attention):
  741. # Mesh TensorFlow attention initialization to avoid scaling before softmax
  742. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
  743. d_model = self.config.d_model
  744. key_value_proj_dim = self.config.d_kv
  745. n_heads = self.config.num_heads
  746. module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
  747. module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
  748. module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
  749. module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
  750. if module.has_relative_attention_bias:
  751. module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
  752. def _shift_right(self, input_ids):
  753. decoder_start_token_id = self.config.decoder_start_token_id
  754. pad_token_id = self.config.pad_token_id
  755. if decoder_start_token_id is None:
  756. raise ValueError(
  757. "self.model.config.decoder_start_token_id has to be defined. In T5 it is usually set to the pad_token_id. "
  758. "See T5 docs for more information."
  759. )
  760. # shift inputs to the right
  761. if is_torch_fx_proxy(input_ids):
  762. # Item assignment is not supported natively for proxies.
  763. shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
  764. shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
  765. else:
  766. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  767. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  768. shifted_input_ids[..., 0] = decoder_start_token_id
  769. if pad_token_id is None:
  770. raise ValueError("self.model.config.pad_token_id has to be defined.")
  771. # replace possible -100 values in labels by `pad_token_id`
  772. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  773. return shifted_input_ids
  774. class T5Stack(T5PreTrainedModel):
  775. def __init__(self, config, embed_tokens=None):
  776. super().__init__(config)
  777. self.embed_tokens = embed_tokens
  778. self.is_decoder = config.is_decoder
  779. self.block = nn.ModuleList(
  780. [T5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i) for i in range(config.num_layers)]
  781. )
  782. self.final_layer_norm = T5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  783. self.dropout = nn.Dropout(config.dropout_rate)
  784. # Initialize weights and apply final processing
  785. self.post_init()
  786. # Model parallel
  787. self.model_parallel = False
  788. self.device_map = None
  789. self.gradient_checkpointing = False
  790. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  791. def parallelize(self, device_map=None):
  792. warnings.warn(
  793. "`T5Stack.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
  794. " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
  795. " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
  796. " 'block.1': 1, ...}",
  797. FutureWarning,
  798. )
  799. # Check validity of device_map
  800. self.device_map = (
  801. get_device_map(len(self.block), range(torch.cuda.device_count())) if device_map is None else device_map
  802. )
  803. assert_device_map(self.device_map, len(self.block))
  804. self.model_parallel = True
  805. self.first_device = "cpu" if "cpu" in self.device_map else "cuda:" + str(min(self.device_map.keys()))
  806. self.last_device = "cuda:" + str(max(self.device_map.keys()))
  807. # Load onto devices
  808. for k, v in self.device_map.items():
  809. for layer in v:
  810. cuda_device = "cuda:" + str(k)
  811. self.block[layer] = self.block[layer].to(cuda_device)
  812. # Set embed_tokens to first layer
  813. self.embed_tokens = self.embed_tokens.to(self.first_device)
  814. # Set final layer norm to last device
  815. self.final_layer_norm = self.final_layer_norm.to(self.last_device)
  816. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  817. def deparallelize(self):
  818. warnings.warn(
  819. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  820. FutureWarning,
  821. )
  822. self.model_parallel = False
  823. self.device_map = None
  824. self.first_device = "cpu"
  825. self.last_device = "cpu"
  826. for i in range(len(self.block)):
  827. self.block[i] = self.block[i].to("cpu")
  828. self.embed_tokens = self.embed_tokens.to("cpu")
  829. self.final_layer_norm = self.final_layer_norm.to("cpu")
  830. torch.cuda.empty_cache()
  831. def set_input_embeddings(self, new_embeddings):
  832. self.embed_tokens = new_embeddings
  833. def forward(
  834. self,
  835. input_ids=None,
  836. attention_mask=None,
  837. encoder_hidden_states=None,
  838. encoder_attention_mask=None,
  839. inputs_embeds=None,
  840. head_mask=None,
  841. cross_attn_head_mask=None,
  842. past_key_values=None,
  843. use_cache=None,
  844. output_attentions=None,
  845. output_hidden_states=None,
  846. return_dict=None,
  847. cache_position=None,
  848. ):
  849. # Model parallel
  850. if self.model_parallel:
  851. torch.cuda.set_device(self.first_device)
  852. self.embed_tokens = self.embed_tokens.to(self.first_device)
  853. use_cache = use_cache if use_cache is not None else self.config.use_cache
  854. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  855. output_hidden_states = (
  856. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  857. )
  858. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  859. if input_ids is not None and inputs_embeds is not None:
  860. err_msg_prefix = "decoder_" if self.is_decoder else ""
  861. raise ValueError(
  862. f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
  863. )
  864. elif input_ids is not None:
  865. input_shape = input_ids.size()
  866. input_ids = input_ids.view(-1, input_shape[-1])
  867. elif inputs_embeds is not None:
  868. input_shape = inputs_embeds.size()[:-1]
  869. else:
  870. err_msg_prefix = "decoder_" if self.is_decoder else ""
  871. raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
  872. if self.gradient_checkpointing and self.training:
  873. if use_cache:
  874. logger.warning_once(
  875. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  876. )
  877. use_cache = False
  878. if inputs_embeds is None:
  879. if self.embed_tokens is None:
  880. raise ValueError("You have to initialize the model with valid token embeddings")
  881. inputs_embeds = self.embed_tokens(input_ids)
  882. batch_size, seq_length = input_shape
  883. if use_cache is True:
  884. if not self.is_decoder:
  885. raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")
  886. if self.is_decoder:
  887. if use_cache and past_key_values is None:
  888. if self.config.is_encoder_decoder:
  889. past_key_values = EncoderDecoderCache(
  890. DynamicCache(config=self.config), DynamicCache(config=self.config)
  891. )
  892. else:
  893. past_key_values = DynamicCache(config=self.config)
  894. elif not self.is_decoder:
  895. # do not pass cache object down the line for encoder stack
  896. # it messes indexing later in decoder-stack because cache object is modified in-place
  897. past_key_values = None
  898. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  899. if cache_position is None:
  900. cache_position = torch.arange(
  901. past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
  902. )
  903. if attention_mask is None and not is_torchdynamo_compiling():
  904. # required mask seq length can be calculated via length of past cache
  905. mask_seq_length = past_key_values_length + seq_length
  906. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  907. if self.config.is_decoder:
  908. causal_mask = self._update_causal_mask(
  909. attention_mask,
  910. inputs_embeds,
  911. cache_position,
  912. past_key_values.self_attention_cache
  913. if isinstance(past_key_values, EncoderDecoderCache)
  914. else past_key_values,
  915. output_attentions,
  916. )
  917. elif attention_mask is not None:
  918. causal_mask = attention_mask[:, None, None, :]
  919. causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
  920. causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
  921. else:
  922. causal_mask = None
  923. # If a 2D or 3D attention mask is provided for the cross-attention
  924. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  925. if self.is_decoder and encoder_hidden_states is not None:
  926. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  927. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  928. if encoder_attention_mask is None:
  929. encoder_attention_mask = torch.ones(
  930. encoder_hidden_shape, device=inputs_embeds.device, dtype=torch.long
  931. )
  932. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  933. else:
  934. encoder_extended_attention_mask = None
  935. # Prepare head mask if needed
  936. head_mask = self.get_head_mask(head_mask, self.config.num_layers)
  937. cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
  938. all_hidden_states = () if output_hidden_states else None
  939. all_attentions = () if output_attentions else None
  940. all_cross_attentions = () if (output_attentions and self.is_decoder) else None
  941. position_bias = None
  942. encoder_decoder_position_bias = None
  943. hidden_states = self.dropout(inputs_embeds)
  944. for i, layer_module in enumerate(self.block):
  945. layer_head_mask = head_mask[i]
  946. cross_attn_layer_head_mask = cross_attn_head_mask[i]
  947. # Model parallel
  948. if self.model_parallel:
  949. torch.cuda.set_device(hidden_states.device)
  950. # Ensure that attention_mask is always on the same device as hidden_states
  951. if causal_mask is not None:
  952. causal_mask = causal_mask.to(hidden_states.device)
  953. if position_bias is not None:
  954. position_bias = position_bias.to(hidden_states.device)
  955. if encoder_hidden_states is not None:
  956. encoder_hidden_states = encoder_hidden_states.to(hidden_states.device)
  957. if encoder_extended_attention_mask is not None:
  958. encoder_extended_attention_mask = encoder_extended_attention_mask.to(hidden_states.device)
  959. if encoder_decoder_position_bias is not None:
  960. encoder_decoder_position_bias = encoder_decoder_position_bias.to(hidden_states.device)
  961. if layer_head_mask is not None:
  962. layer_head_mask = layer_head_mask.to(hidden_states.device)
  963. if cross_attn_layer_head_mask is not None:
  964. cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(hidden_states.device)
  965. if output_hidden_states:
  966. all_hidden_states = all_hidden_states + (hidden_states,)
  967. layer_outputs = layer_module(
  968. hidden_states,
  969. causal_mask,
  970. position_bias,
  971. encoder_hidden_states,
  972. encoder_extended_attention_mask,
  973. encoder_decoder_position_bias, # as a positional argument for gradient checkpointing
  974. layer_head_mask=layer_head_mask,
  975. cross_attn_layer_head_mask=cross_attn_layer_head_mask,
  976. past_key_values=past_key_values,
  977. use_cache=use_cache,
  978. output_attentions=output_attentions,
  979. return_dict=return_dict,
  980. cache_position=cache_position,
  981. )
  982. hidden_states = layer_outputs[0]
  983. # We share the position biases between the layers - the first layer store them
  984. # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
  985. # (cross-attention position bias), (cross-attention weights)
  986. position_bias = layer_outputs[1]
  987. if self.is_decoder and encoder_hidden_states is not None:
  988. encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
  989. if output_attentions:
  990. all_attentions = all_attentions + (layer_outputs[2],)
  991. if self.is_decoder:
  992. all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
  993. # Model Parallel: If it's the last layer for that device, put things on the next device
  994. if self.model_parallel:
  995. for k, v in self.device_map.items():
  996. if i == v[-1] and "cuda:" + str(k) != self.last_device:
  997. hidden_states = hidden_states.to("cuda:" + str(k + 1))
  998. hidden_states = self.final_layer_norm(hidden_states)
  999. hidden_states = self.dropout(hidden_states)
  1000. # Add last layer
  1001. if output_hidden_states:
  1002. all_hidden_states = all_hidden_states + (hidden_states,)
  1003. if not return_dict:
  1004. return tuple(
  1005. v
  1006. for v in [
  1007. hidden_states,
  1008. past_key_values,
  1009. all_hidden_states,
  1010. all_attentions,
  1011. all_cross_attentions,
  1012. ]
  1013. if v is not None
  1014. )
  1015. return BaseModelOutputWithPastAndCrossAttentions(
  1016. last_hidden_state=hidden_states,
  1017. past_key_values=past_key_values,
  1018. hidden_states=all_hidden_states,
  1019. attentions=all_attentions,
  1020. cross_attentions=all_cross_attentions,
  1021. )
  1022. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
  1023. def _update_causal_mask(
  1024. self,
  1025. attention_mask: Union[torch.Tensor, "BlockMask"],
  1026. input_tensor: torch.Tensor,
  1027. cache_position: torch.Tensor,
  1028. past_key_values: Cache,
  1029. output_attentions: bool = False,
  1030. ):
  1031. if self.config._attn_implementation == "flash_attention_2":
  1032. if attention_mask is not None and (attention_mask == 0.0).any():
  1033. return attention_mask
  1034. return None
  1035. if self.config._attn_implementation == "flex_attention":
  1036. if isinstance(attention_mask, torch.Tensor):
  1037. attention_mask = make_flex_block_causal_mask(attention_mask)
  1038. return attention_mask
  1039. # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
  1040. # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
  1041. # to infer the attention mask.
  1042. past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
  1043. using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False
  1044. # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
  1045. if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
  1046. if AttentionMaskConverter._ignore_causal_mask_sdpa(
  1047. attention_mask,
  1048. inputs_embeds=input_tensor,
  1049. past_key_values_length=past_seen_tokens,
  1050. is_training=self.training,
  1051. ):
  1052. return None
  1053. dtype = input_tensor.dtype
  1054. sequence_length = input_tensor.shape[1]
  1055. if using_compilable_cache:
  1056. target_length = past_key_values.get_max_cache_shape()
  1057. else:
  1058. target_length = (
  1059. attention_mask.shape[-1]
  1060. if isinstance(attention_mask, torch.Tensor)
  1061. else past_seen_tokens + sequence_length + 1
  1062. )
  1063. # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
  1064. causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
  1065. attention_mask,
  1066. sequence_length=sequence_length,
  1067. target_length=target_length,
  1068. dtype=dtype,
  1069. cache_position=cache_position,
  1070. batch_size=input_tensor.shape[0],
  1071. )
  1072. if (
  1073. self.config._attn_implementation == "sdpa"
  1074. and attention_mask is not None
  1075. and attention_mask.device.type in ["cuda", "xpu", "npu"]
  1076. and not output_attentions
  1077. ):
  1078. # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
  1079. # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
  1080. # Details: https://github.com/pytorch/pytorch/issues/110213
  1081. min_dtype = torch.finfo(dtype).min
  1082. causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)
  1083. return causal_mask
  1084. @staticmethod
  1085. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  1086. def _prepare_4d_causal_attention_mask_with_cache_position(
  1087. attention_mask: torch.Tensor,
  1088. sequence_length: int,
  1089. target_length: int,
  1090. dtype: torch.dtype,
  1091. cache_position: torch.Tensor,
  1092. batch_size: int,
  1093. **kwargs,
  1094. ):
  1095. """
  1096. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  1097. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  1098. Args:
  1099. attention_mask (`torch.Tensor`):
  1100. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  1101. `(batch_size, 1, query_length, key_value_length)`.
  1102. sequence_length (`int`):
  1103. The sequence length being processed.
  1104. target_length (`int`):
  1105. The target length: when generating with static cache, the mask should be as long as the static cache,
  1106. to account for the 0 padding, the part of the cache that is not filled yet.
  1107. dtype (`torch.dtype`):
  1108. The dtype to use for the 4D attention mask.
  1109. cache_position (`torch.Tensor`):
  1110. Indices depicting the position of the input sequence tokens in the sequence.
  1111. batch_size (`torch.Tensor`):
  1112. Batch size.
  1113. """
  1114. if attention_mask is not None and attention_mask.dim() == 4:
  1115. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  1116. causal_mask = attention_mask
  1117. else:
  1118. min_dtype = torch.finfo(dtype).min
  1119. causal_mask = torch.full(
  1120. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  1121. )
  1122. if sequence_length != 1:
  1123. causal_mask = torch.triu(causal_mask, diagonal=1)
  1124. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  1125. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  1126. if attention_mask is not None:
  1127. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  1128. mask_length = attention_mask.shape[-1]
  1129. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  1130. causal_mask.device
  1131. )
  1132. padding_mask = padding_mask == 0
  1133. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  1134. padding_mask, min_dtype
  1135. )
  1136. return causal_mask
# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
# NOTE(review): this global starts with two underscores. If it is referenced from inside a class body (for example
# in a model's `forward`), Python's private-name mangling rewrites the reference to `_<ClassName>__HEAD_MASK_WARNING_MSG`,
# which does not exist at module level and would raise NameError. Verify the usages, which are not visible here.
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
  1144. @auto_docstring
  1145. class T5Model(T5PreTrainedModel):
  1146. _keys_to_ignore_on_load_unexpected = [
  1147. "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
  1148. ]
  1149. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1150. def __init__(self, config: T5Config):
  1151. super().__init__(config)
  1152. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1153. encoder_config = copy.deepcopy(config)
  1154. encoder_config.is_decoder = False
  1155. encoder_config.use_cache = False
  1156. encoder_config.tie_encoder_decoder = False
  1157. self.encoder = T5Stack(encoder_config, self.shared)
  1158. decoder_config = copy.deepcopy(config)
  1159. decoder_config.is_decoder = True
  1160. decoder_config.tie_encoder_decoder = False
  1161. decoder_config.num_layers = config.num_decoder_layers
  1162. self.decoder = T5Stack(decoder_config, self.shared)
  1163. # Initialize weights and apply final processing
  1164. self.post_init()
  1165. # Model parallel
  1166. self.model_parallel = False
  1167. self.device_map = None
  1168. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  1169. def parallelize(self, device_map=None):
  1170. warnings.warn(
  1171. "`T5Model.parallelize` is deprecated and will be removed in v5 of Transformers, you should load your model"
  1172. " with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
  1173. " `device_map` but it needs to be a dictionary module_name to device, so for instance {'encoder.block.0':"
  1174. " 0, 'encoder.block.1': 1, ...}",
  1175. FutureWarning,
  1176. )
  1177. self.device_map = (
  1178. get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
  1179. if device_map is None
  1180. else device_map
  1181. )
  1182. assert_device_map(self.device_map, len(self.encoder.block))
  1183. self.encoder.parallelize(self.device_map)
  1184. self.decoder.parallelize(self.device_map)
  1185. self.model_parallel = True
  1186. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  1187. def deparallelize(self):
  1188. warnings.warn(
  1189. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  1190. FutureWarning,
  1191. )
  1192. self.encoder.deparallelize()
  1193. self.decoder.deparallelize()
  1194. self.encoder = self.encoder.to("cpu")
  1195. self.decoder = self.decoder.to("cpu")
  1196. self.model_parallel = False
  1197. self.device_map = None
  1198. torch.cuda.empty_cache()
  1199. def get_input_embeddings(self):
  1200. return self.shared
  1201. def set_input_embeddings(self, new_embeddings):
  1202. self.shared = new_embeddings
  1203. self.encoder.set_input_embeddings(new_embeddings)
  1204. self.decoder.set_input_embeddings(new_embeddings)
  1205. def _tie_weights(self):
  1206. if self.config.tie_word_embeddings:
  1207. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1208. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  1209. def get_encoder(self):
  1210. return self.encoder
  1211. def _prune_heads(self, heads_to_prune):
  1212. """
  1213. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  1214. class PreTrainedModel
  1215. """
  1216. for layer, heads in heads_to_prune.items():
  1217. self.encoder.layer[layer].attention.prune_heads(heads)
  1218. @auto_docstring
  1219. def forward(
  1220. self,
  1221. input_ids: Optional[torch.LongTensor] = None,
  1222. attention_mask: Optional[torch.FloatTensor] = None,
  1223. decoder_input_ids: Optional[torch.LongTensor] = None,
  1224. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1225. head_mask: Optional[torch.FloatTensor] = None,
  1226. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1227. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1228. encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
  1229. past_key_values: Optional[Cache] = None,
  1230. inputs_embeds: Optional[torch.Tensor] = None,
  1231. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  1232. use_cache: Optional[bool] = None,
  1233. output_attentions: Optional[bool] = None,
  1234. output_hidden_states: Optional[bool] = None,
  1235. return_dict: Optional[bool] = None,
  1236. cache_position: Optional[torch.LongTensor] = None,
  1237. ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
  1238. r"""
  1239. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1240. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1241. should be able to pad the inputs on both the right and the left.
  1242. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1243. [`PreTrainedTokenizer.__call__`] for detail.
  1244. [What are input IDs?](../glossary#input-ids)
  1245. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  1246. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1247. Indices of decoder input sequence tokens in the vocabulary.
  1248. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1249. [`PreTrainedTokenizer.__call__`] for details.
  1250. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1251. T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1252. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1253. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
  1254. Training](./t5#training).
  1255. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1256. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1257. be used by default.
  1258. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1259. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1260. 1]`:
  1261. - 1 indicates the head is **not masked**,
  1262. - 0 indicates the head is **masked**.
  1263. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1264. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1265. `[0, 1]`:
  1266. - 1 indicates the head is **not masked**,
  1267. - 0 indicates the head is **masked**.
  1268. Example:
  1269. ```python
  1270. >>> from transformers import AutoTokenizer, T5Model
  1271. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
  1272. >>> model = T5Model.from_pretrained("google-t5/t5-small")
  1273. >>> input_ids = tokenizer(
  1274. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1275. ... ).input_ids # Batch size 1
  1276. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  1277. >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for T5Model.
  1278. >>> # This is not needed for torch's T5ForConditionalGeneration as it does this internally using labels arg.
  1279. >>> decoder_input_ids = model._shift_right(decoder_input_ids)
  1280. >>> # forward pass
  1281. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1282. >>> last_hidden_states = outputs.last_hidden_state
  1283. ```"""
  1284. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1285. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1286. # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  1287. if head_mask is not None and decoder_head_mask is None:
  1288. if self.config.num_layers == self.config.num_decoder_layers:
  1289. warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
  1290. decoder_head_mask = head_mask
  1291. # Encode if needed (training, first prediction pass)
  1292. if encoder_outputs is None:
  1293. encoder_outputs = self.encoder(
  1294. input_ids=input_ids,
  1295. attention_mask=attention_mask,
  1296. inputs_embeds=inputs_embeds,
  1297. head_mask=head_mask,
  1298. output_attentions=output_attentions,
  1299. output_hidden_states=output_hidden_states,
  1300. return_dict=return_dict,
  1301. )
  1302. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1303. encoder_outputs = BaseModelOutput(
  1304. last_hidden_state=encoder_outputs[0],
  1305. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1306. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1307. )
  1308. hidden_states = encoder_outputs[0]
  1309. # Set device for model parallelism
  1310. if self.model_parallel:
  1311. torch.cuda.set_device(self.decoder.first_device)
  1312. hidden_states = hidden_states.to(self.decoder.first_device)
  1313. if decoder_input_ids is not None:
  1314. decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
  1315. if attention_mask is not None:
  1316. attention_mask = attention_mask.to(self.decoder.first_device)
  1317. if decoder_attention_mask is not None:
  1318. decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
  1319. # Decode
  1320. decoder_outputs = self.decoder(
  1321. input_ids=decoder_input_ids,
  1322. attention_mask=decoder_attention_mask,
  1323. inputs_embeds=decoder_inputs_embeds,
  1324. past_key_values=past_key_values,
  1325. encoder_hidden_states=hidden_states,
  1326. encoder_attention_mask=attention_mask,
  1327. head_mask=decoder_head_mask,
  1328. cross_attn_head_mask=cross_attn_head_mask,
  1329. use_cache=use_cache,
  1330. output_attentions=output_attentions,
  1331. output_hidden_states=output_hidden_states,
  1332. return_dict=return_dict,
  1333. cache_position=cache_position,
  1334. )
  1335. if not return_dict:
  1336. return decoder_outputs + encoder_outputs
  1337. return Seq2SeqModelOutput(
  1338. last_hidden_state=decoder_outputs.last_hidden_state,
  1339. past_key_values=decoder_outputs.past_key_values,
  1340. decoder_hidden_states=decoder_outputs.hidden_states,
  1341. decoder_attentions=decoder_outputs.attentions,
  1342. cross_attentions=decoder_outputs.cross_attentions,
  1343. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1344. encoder_hidden_states=encoder_outputs.hidden_states,
  1345. encoder_attentions=encoder_outputs.attentions,
  1346. )
  1347. @auto_docstring(
  1348. custom_intro="""
  1349. T5 Model with a `language modeling` head on top.
  1350. """
  1351. )
  1352. class T5ForConditionalGeneration(T5PreTrainedModel, GenerationMixin):
  1353. _keys_to_ignore_on_load_unexpected = [
  1354. "decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
  1355. ]
  1356. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
  1357. def __init__(self, config: T5Config):
  1358. super().__init__(config)
  1359. self.model_dim = config.d_model
  1360. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1361. encoder_config = copy.deepcopy(config)
  1362. encoder_config.is_decoder = False
  1363. encoder_config.use_cache = False
  1364. encoder_config.tie_encoder_decoder = False
  1365. self.encoder = T5Stack(encoder_config, self.shared)
  1366. decoder_config = copy.deepcopy(config)
  1367. decoder_config.is_decoder = True
  1368. decoder_config.tie_encoder_decoder = False
  1369. decoder_config.num_layers = config.num_decoder_layers
  1370. self.decoder = T5Stack(decoder_config, self.shared)
  1371. self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
  1372. # Initialize weights and apply final processing
  1373. self.post_init()
  1374. # Model parallel
  1375. self.model_parallel = False
  1376. self.device_map = None
  1377. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  1378. def parallelize(self, device_map=None):
  1379. warnings.warn(
  1380. "`T5ForConditionalGeneration.parallelize` is deprecated and will be removed in v5 of Transformers, you"
  1381. " should load your model with `device_map='balanced'` in the call to `from_pretrained`. You can also"
  1382. " provide your own `device_map` but it needs to be a dictionary module_name to device, so for instance"
  1383. " {'encoder.block.0': 0, 'encoder.block.1': 1, ...}",
  1384. FutureWarning,
  1385. )
  1386. self.device_map = (
  1387. get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
  1388. if device_map is None
  1389. else device_map
  1390. )
  1391. assert_device_map(self.device_map, len(self.encoder.block))
  1392. self.encoder.parallelize(self.device_map)
  1393. self.decoder.parallelize(self.device_map)
  1394. self.lm_head = self.lm_head.to(self.decoder.first_device)
  1395. self.model_parallel = True
  1396. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  1397. def deparallelize(self):
  1398. warnings.warn(
  1399. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  1400. FutureWarning,
  1401. )
  1402. self.encoder.deparallelize()
  1403. self.decoder.deparallelize()
  1404. self.encoder = self.encoder.to("cpu")
  1405. self.decoder = self.decoder.to("cpu")
  1406. self.lm_head = self.lm_head.to("cpu")
  1407. self.model_parallel = False
  1408. self.device_map = None
  1409. torch.cuda.empty_cache()
  1410. def get_input_embeddings(self):
  1411. return self.shared
  1412. def set_input_embeddings(self, new_embeddings):
  1413. self.shared = new_embeddings
  1414. self.encoder.set_input_embeddings(new_embeddings)
  1415. self.decoder.set_input_embeddings(new_embeddings)
  1416. def _tie_weights(self):
  1417. if self.config.tie_word_embeddings:
  1418. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1419. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  1420. def get_encoder(self):
  1421. return self.encoder
  1422. @auto_docstring
  1423. def forward(
  1424. self,
  1425. input_ids: Optional[torch.LongTensor] = None,
  1426. attention_mask: Optional[torch.FloatTensor] = None,
  1427. decoder_input_ids: Optional[torch.LongTensor] = None,
  1428. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1429. head_mask: Optional[torch.FloatTensor] = None,
  1430. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1431. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1432. encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
  1433. past_key_values: Optional[Cache] = None,
  1434. inputs_embeds: Optional[torch.FloatTensor] = None,
  1435. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1436. labels: Optional[torch.LongTensor] = None,
  1437. use_cache: Optional[bool] = None,
  1438. output_attentions: Optional[bool] = None,
  1439. output_hidden_states: Optional[bool] = None,
  1440. return_dict: Optional[bool] = None,
  1441. cache_position: Optional[torch.LongTensor] = None,
  1442. ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
  1443. r"""
  1444. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1445. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1446. should be able to pad the inputs on both the right and the left.
  1447. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1448. [`PreTrainedTokenizer.__call__`] for detail.
  1449. [What are input IDs?](../glossary#input-ids)
  1450. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  1451. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1452. Indices of decoder input sequence tokens in the vocabulary.
  1453. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1454. [`PreTrainedTokenizer.__call__`] for details.
  1455. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1456. T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1457. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1458. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
  1459. Training](./t5#training).
  1460. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1461. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1462. be used by default.
  1463. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1464. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1465. 1]`:
  1466. - 1 indicates the head is **not masked**,
  1467. - 0 indicates the head is **masked**.
  1468. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1469. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1470. `[0, 1]`:
  1471. - 1 indicates the head is **not masked**,
  1472. - 0 indicates the head is **masked**.
  1473. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1474. Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
  1475. config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
  1476. labels in `[0, ..., config.vocab_size]`
  1477. Examples:
  1478. ```python
  1479. >>> from transformers import AutoTokenizer, T5ForConditionalGeneration
  1480. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
  1481. >>> model = T5ForConditionalGeneration.from_pretrained("google-t5/t5-small")
  1482. >>> # training
  1483. >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
  1484. >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
  1485. >>> outputs = model(input_ids=input_ids, labels=labels)
  1486. >>> loss = outputs.loss
  1487. >>> logits = outputs.logits
  1488. >>> # inference
  1489. >>> input_ids = tokenizer(
  1490. ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
  1491. ... ).input_ids # Batch size 1
  1492. >>> outputs = model.generate(input_ids)
  1493. >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  1494. >>> # studies have shown that owning a dog is good for you.
  1495. ```"""
  1496. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1497. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1498. # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  1499. if head_mask is not None and decoder_head_mask is None:
  1500. if self.config.num_layers == self.config.num_decoder_layers:
  1501. warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
  1502. decoder_head_mask = head_mask
  1503. # Encode if needed (training, first prediction pass)
  1504. if encoder_outputs is None:
  1505. # Convert encoder inputs in embeddings if needed
  1506. encoder_outputs = self.encoder(
  1507. input_ids=input_ids,
  1508. attention_mask=attention_mask,
  1509. inputs_embeds=inputs_embeds,
  1510. head_mask=head_mask,
  1511. output_attentions=output_attentions,
  1512. output_hidden_states=output_hidden_states,
  1513. return_dict=return_dict,
  1514. )
  1515. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1516. encoder_outputs = BaseModelOutput(
  1517. last_hidden_state=encoder_outputs[0],
  1518. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1519. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1520. )
  1521. hidden_states = encoder_outputs[0]
  1522. if self.model_parallel:
  1523. torch.cuda.set_device(self.decoder.first_device)
  1524. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  1525. # get decoder inputs from shifting lm labels to the right
  1526. decoder_input_ids = self._shift_right(labels)
  1527. # Set device for model parallelism
  1528. if self.model_parallel:
  1529. torch.cuda.set_device(self.decoder.first_device)
  1530. hidden_states = hidden_states.to(self.decoder.first_device)
  1531. if decoder_input_ids is not None:
  1532. decoder_input_ids = decoder_input_ids.to(self.decoder.first_device)
  1533. if attention_mask is not None:
  1534. attention_mask = attention_mask.to(self.decoder.first_device)
  1535. if decoder_attention_mask is not None:
  1536. decoder_attention_mask = decoder_attention_mask.to(self.decoder.first_device)
  1537. # Decode
  1538. decoder_outputs = self.decoder(
  1539. input_ids=decoder_input_ids,
  1540. attention_mask=decoder_attention_mask,
  1541. inputs_embeds=decoder_inputs_embeds,
  1542. past_key_values=past_key_values,
  1543. encoder_hidden_states=hidden_states,
  1544. encoder_attention_mask=attention_mask,
  1545. head_mask=decoder_head_mask,
  1546. cross_attn_head_mask=cross_attn_head_mask,
  1547. use_cache=use_cache,
  1548. output_attentions=output_attentions,
  1549. output_hidden_states=output_hidden_states,
  1550. return_dict=return_dict,
  1551. cache_position=cache_position,
  1552. )
  1553. sequence_output = decoder_outputs[0]
  1554. # Set device for model parallelism
  1555. if self.model_parallel:
  1556. torch.cuda.set_device(self.encoder.first_device)
  1557. self.lm_head = self.lm_head.to(self.encoder.first_device)
  1558. sequence_output = sequence_output.to(self.lm_head.weight.device)
  1559. if self.config.tie_word_embeddings:
  1560. # Rescale output before projecting on vocab
  1561. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
  1562. sequence_output = sequence_output * (self.model_dim**-0.5)
  1563. lm_logits = self.lm_head(sequence_output)
  1564. loss = None
  1565. if labels is not None:
  1566. loss_fct = CrossEntropyLoss(ignore_index=-100)
  1567. # move labels to correct device to enable PP
  1568. labels = labels.to(lm_logits.device)
  1569. loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
  1570. # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
  1571. if not return_dict:
  1572. output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
  1573. return ((loss,) + output) if loss is not None else output
  1574. return Seq2SeqLMOutput(
  1575. loss=loss,
  1576. logits=lm_logits,
  1577. past_key_values=decoder_outputs.past_key_values,
  1578. decoder_hidden_states=decoder_outputs.hidden_states,
  1579. decoder_attentions=decoder_outputs.attentions,
  1580. cross_attentions=decoder_outputs.cross_attentions,
  1581. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1582. encoder_hidden_states=encoder_outputs.hidden_states,
  1583. encoder_attentions=encoder_outputs.attentions,
  1584. )
  1585. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1586. return self._shift_right(labels)
  1587. @auto_docstring
  1588. class T5EncoderModel(T5PreTrainedModel):
  1589. _tied_weights_keys = ["encoder.embed_tokens.weight"]
  1590. _keys_to_ignore_on_load_unexpected = [r"decoder"]
  1591. def __init__(self, config: T5Config):
  1592. super().__init__(config)
  1593. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1594. encoder_config = config
  1595. encoder_config.use_cache = False
  1596. encoder_config.is_encoder_decoder = False
  1597. self.encoder = T5Stack(encoder_config, self.shared)
  1598. # Initialize weights and apply final processing
  1599. self.post_init()
  1600. # Model parallel
  1601. self.model_parallel = False
  1602. self.device_map = None
  1603. @add_start_docstrings(PARALLELIZE_DOCSTRING)
  1604. def parallelize(self, device_map=None):
  1605. warnings.warn(
  1606. "`T5EncoderModel.parallelize` is deprecated and will be removed in v5 of Transformers, you should load"
  1607. " your model with `device_map='balanced'` in the call to `from_pretrained`. You can also provide your own"
  1608. " `device_map` but it needs to be a dictionary module_name to device, so for instance {'block.0': 0,"
  1609. " 'block.1': 1, ...}",
  1610. FutureWarning,
  1611. )
  1612. self.device_map = (
  1613. get_device_map(len(self.encoder.block), range(torch.cuda.device_count()))
  1614. if device_map is None
  1615. else device_map
  1616. )
  1617. assert_device_map(self.device_map, len(self.encoder.block))
  1618. self.encoder.parallelize(self.device_map)
  1619. self.model_parallel = True
  1620. @add_start_docstrings(DEPARALLELIZE_DOCSTRING)
  1621. def deparallelize(self):
  1622. warnings.warn(
  1623. "Like `parallelize`, `deparallelize` is deprecated and will be removed in v5 of Transformers.",
  1624. FutureWarning,
  1625. )
  1626. self.encoder.deparallelize()
  1627. self.encoder = self.encoder.to("cpu")
  1628. self.model_parallel = False
  1629. self.device_map = None
  1630. torch.cuda.empty_cache()
  1631. def get_input_embeddings(self):
  1632. return self.shared
  1633. def set_input_embeddings(self, new_embeddings):
  1634. self.shared = new_embeddings
  1635. self.encoder.set_input_embeddings(new_embeddings)
  1636. def _tie_weights(self):
  1637. if self.config.tie_word_embeddings:
  1638. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1639. def get_encoder(self):
  1640. return self.encoder
  1641. def _prune_heads(self, heads_to_prune):
  1642. """
  1643. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  1644. class PreTrainedModel
  1645. """
  1646. for layer, heads in heads_to_prune.items():
  1647. self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
  1648. @auto_docstring
  1649. def forward(
  1650. self,
  1651. input_ids: Optional[torch.LongTensor] = None,
  1652. attention_mask: Optional[torch.FloatTensor] = None,
  1653. head_mask: Optional[torch.FloatTensor] = None,
  1654. inputs_embeds: Optional[torch.FloatTensor] = None,
  1655. output_attentions: Optional[bool] = None,
  1656. output_hidden_states: Optional[bool] = None,
  1657. return_dict: Optional[bool] = None,
  1658. ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
  1659. r"""
  1660. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1661. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1662. should be able to pad the inputs on both the right and the left.
  1663. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1664. [`PreTrainedTokenizer.__call__`] for detail.
  1665. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  1666. Example:
  1667. ```python
  1668. >>> from transformers import AutoTokenizer, T5EncoderModel
  1669. >>> tokenizer = AutoTokenizer.from_pretrained("google-t5/t5-small")
  1670. >>> model = T5EncoderModel.from_pretrained("google-t5/t5-small")
  1671. >>> input_ids = tokenizer(
  1672. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1673. ... ).input_ids # Batch size 1
  1674. >>> outputs = model(input_ids=input_ids)
  1675. >>> last_hidden_states = outputs.last_hidden_state
  1676. ```"""
  1677. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1678. encoder_outputs = self.encoder(
  1679. input_ids=input_ids,
  1680. attention_mask=attention_mask,
  1681. inputs_embeds=inputs_embeds,
  1682. head_mask=head_mask,
  1683. output_attentions=output_attentions,
  1684. output_hidden_states=output_hidden_states,
  1685. return_dict=return_dict,
  1686. )
  1687. return encoder_outputs
  1688. @auto_docstring(
  1689. custom_intro="""
  1690. T5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
  1691. tasks.
  1692. """
  1693. )
  1694. class T5ForSequenceClassification(T5PreTrainedModel):
  1695. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1696. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1697. def __init__(self, config: T5Config):
  1698. super().__init__(config)
  1699. self.transformer = T5Model(config)
  1700. self.classification_head = T5ClassificationHead(config)
  1701. # Initialize weights and apply final processing
  1702. self.post_init()
  1703. self.model_parallel = False
  1704. @auto_docstring
  1705. def forward(
  1706. self,
  1707. input_ids: Optional[torch.LongTensor] = None,
  1708. attention_mask: Optional[torch.Tensor] = None,
  1709. decoder_input_ids: Optional[torch.LongTensor] = None,
  1710. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1711. head_mask: Optional[torch.Tensor] = None,
  1712. decoder_head_mask: Optional[torch.Tensor] = None,
  1713. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1714. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  1715. inputs_embeds: Optional[torch.FloatTensor] = None,
  1716. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1717. labels: Optional[torch.LongTensor] = None,
  1718. use_cache: Optional[bool] = None,
  1719. output_attentions: Optional[bool] = None,
  1720. output_hidden_states: Optional[bool] = None,
  1721. return_dict: Optional[bool] = None,
  1722. ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
  1723. r"""
  1724. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1725. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1726. should be able to pad the inputs on both the right and the left.
  1727. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1728. [`PreTrainedTokenizer.__call__`] for detail.
  1729. [What are input IDs?](../glossary#input-ids)
  1730. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  1731. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1732. Indices of decoder input sequence tokens in the vocabulary.
  1733. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1734. [`PreTrainedTokenizer.__call__`] for details.
  1735. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1736. T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1737. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1738. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
  1739. Training](./t5#training).
  1740. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1741. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1742. be used by default.
  1743. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1744. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1745. 1]`:
  1746. - 1 indicates the head is **not masked**,
  1747. - 0 indicates the head is **masked**.
  1748. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1749. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1750. `[0, 1]`:
  1751. - 1 indicates the head is **not masked**,
  1752. - 0 indicates the head is **masked**.
  1753. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1754. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1755. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1756. """
  1757. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1758. if labels is not None:
  1759. use_cache = False
  1760. if input_ids is None and inputs_embeds is not None:
  1761. raise NotImplementedError(
  1762. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  1763. )
  1764. # Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates
  1765. # decoder_input_ids from input_ids if no decoder_input_ids are provided
  1766. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1767. if input_ids is None:
  1768. raise ValueError(
  1769. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1770. "passed, `input_ids` cannot be `None`. Please pass either "
  1771. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1772. )
  1773. decoder_input_ids = self._shift_right(input_ids)
  1774. outputs = self.transformer(
  1775. input_ids,
  1776. attention_mask=attention_mask,
  1777. decoder_input_ids=decoder_input_ids,
  1778. decoder_attention_mask=decoder_attention_mask,
  1779. head_mask=head_mask,
  1780. decoder_head_mask=decoder_head_mask,
  1781. cross_attn_head_mask=cross_attn_head_mask,
  1782. encoder_outputs=encoder_outputs,
  1783. inputs_embeds=inputs_embeds,
  1784. decoder_inputs_embeds=decoder_inputs_embeds,
  1785. use_cache=use_cache,
  1786. output_attentions=output_attentions,
  1787. output_hidden_states=output_hidden_states,
  1788. return_dict=return_dict,
  1789. )
  1790. sequence_output = outputs[0]
  1791. eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
  1792. if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
  1793. raise ValueError("All examples must have the same number of <eos> tokens.")
  1794. batch_size, _, hidden_size = sequence_output.shape
  1795. sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
  1796. logits = self.classification_head(sentence_representation)
  1797. loss = None
  1798. if labels is not None:
  1799. labels = labels.to(logits.device)
  1800. if self.config.problem_type is None:
  1801. if self.config.num_labels == 1:
  1802. self.config.problem_type = "regression"
  1803. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1804. self.config.problem_type = "single_label_classification"
  1805. else:
  1806. self.config.problem_type = "multi_label_classification"
  1807. if self.config.problem_type == "regression":
  1808. loss_fct = MSELoss()
  1809. if self.config.num_labels == 1:
  1810. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1811. else:
  1812. loss = loss_fct(logits, labels)
  1813. elif self.config.problem_type == "single_label_classification":
  1814. loss_fct = CrossEntropyLoss()
  1815. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1816. elif self.config.problem_type == "multi_label_classification":
  1817. loss_fct = BCEWithLogitsLoss()
  1818. loss = loss_fct(logits, labels)
  1819. if not return_dict:
  1820. output = (logits,) + outputs[1:]
  1821. return ((loss,) + output) if loss is not None else output
  1822. return Seq2SeqSequenceClassifierOutput(
  1823. loss=loss,
  1824. logits=logits,
  1825. past_key_values=outputs.past_key_values,
  1826. decoder_hidden_states=outputs.decoder_hidden_states,
  1827. decoder_attentions=outputs.decoder_attentions,
  1828. cross_attentions=outputs.cross_attentions,
  1829. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1830. encoder_hidden_states=outputs.encoder_hidden_states,
  1831. encoder_attentions=outputs.encoder_attentions,
  1832. )
  1833. @auto_docstring
  1834. class T5ForTokenClassification(T5PreTrainedModel):
  1835. _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"]
  1836. def __init__(self, config: T5Config):
  1837. super().__init__(config)
  1838. self.num_labels = config.num_labels
  1839. self.transformer = T5EncoderModel(config)
  1840. self.dropout = nn.Dropout(config.classifier_dropout)
  1841. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1842. # Initialize weights and apply final processing
  1843. self.post_init()
  1844. @auto_docstring
  1845. def forward(
  1846. self,
  1847. input_ids: Optional[torch.Tensor] = None,
  1848. attention_mask: Optional[torch.Tensor] = None,
  1849. head_mask: Optional[torch.Tensor] = None,
  1850. inputs_embeds: Optional[torch.Tensor] = None,
  1851. labels: Optional[torch.Tensor] = None,
  1852. output_attentions: Optional[bool] = None,
  1853. output_hidden_states: Optional[bool] = None,
  1854. return_dict: Optional[bool] = None,
  1855. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  1856. r"""
  1857. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1858. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1859. should be able to pad the inputs on both the right and the left.
  1860. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1861. [`PreTrainedTokenizer.__call__`] for detail.
  1862. [What are input IDs?](../glossary#input-ids)
  1863. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  1864. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1865. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1866. """
  1867. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1868. outputs = self.transformer(
  1869. input_ids,
  1870. attention_mask=attention_mask,
  1871. head_mask=head_mask,
  1872. inputs_embeds=inputs_embeds,
  1873. output_attentions=output_attentions,
  1874. output_hidden_states=output_hidden_states,
  1875. return_dict=return_dict,
  1876. )
  1877. hidden_states = outputs[0]
  1878. hidden_states = self.dropout(hidden_states)
  1879. logits = self.classifier(hidden_states)
  1880. loss = None
  1881. if labels is not None:
  1882. loss_fct = CrossEntropyLoss()
  1883. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1884. if not return_dict:
  1885. output = (logits, outputs[2:-1])
  1886. return ((loss,) + output) if loss is not None else output
  1887. return TokenClassifierOutput(
  1888. loss=loss,
  1889. logits=logits,
  1890. hidden_states=outputs.hidden_states,
  1891. attentions=outputs.attentions,
  1892. )
  1893. @auto_docstring
  1894. class T5ForQuestionAnswering(T5PreTrainedModel):
  1895. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1896. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1897. def __init__(self, config: T5Config):
  1898. super().__init__(config)
  1899. self.model_dim = config.d_model
  1900. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1901. encoder_config = copy.deepcopy(config)
  1902. encoder_config.is_decoder = False
  1903. encoder_config.use_cache = False
  1904. encoder_config.tie_encoder_decoder = False
  1905. self.encoder = T5Stack(encoder_config, self.shared)
  1906. decoder_config = copy.deepcopy(config)
  1907. decoder_config.is_decoder = True
  1908. decoder_config.tie_encoder_decoder = False
  1909. decoder_config.num_layers = config.num_decoder_layers
  1910. self.decoder = T5Stack(decoder_config, self.shared)
  1911. self.num_labels = config.num_labels
  1912. self.qa_outputs = nn.Linear(config.hidden_size, config.num_labels)
  1913. # Initialize weights and apply final processing
  1914. self.post_init()
  1915. self.model_parallel = False
  1916. def get_input_embeddings(self):
  1917. return self.shared
  1918. def set_input_embeddings(self, new_embeddings):
  1919. self.shared = new_embeddings
  1920. self.encoder.set_input_embeddings(new_embeddings)
  1921. self.decoder.set_input_embeddings(new_embeddings)
  1922. def _tie_weights(self):
  1923. if self.config.tie_word_embeddings:
  1924. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1925. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  1926. def get_encoder(self):
  1927. return self.encoder
  1928. @auto_docstring
  1929. def forward(
  1930. self,
  1931. input_ids: Optional[torch.LongTensor] = None,
  1932. attention_mask: Optional[torch.FloatTensor] = None,
  1933. decoder_input_ids: Optional[torch.LongTensor] = None,
  1934. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1935. head_mask: Optional[torch.FloatTensor] = None,
  1936. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1937. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1938. encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
  1939. start_positions: Optional[torch.LongTensor] = None,
  1940. end_positions: Optional[torch.LongTensor] = None,
  1941. inputs_embeds: Optional[torch.FloatTensor] = None,
  1942. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1943. use_cache: Optional[bool] = None,
  1944. output_attentions: Optional[bool] = None,
  1945. output_hidden_states: Optional[bool] = None,
  1946. return_dict: Optional[bool] = None,
  1947. ) -> Union[tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
  1948. r"""
  1949. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1950. Indices of input sequence tokens in the vocabulary. T5 is a model with relative position embeddings so you
  1951. should be able to pad the inputs on both the right and the left.
  1952. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1953. [`PreTrainedTokenizer.__call__`] for detail.
  1954. [What are input IDs?](../glossary#input-ids)
  1955. To know more on how to prepare `input_ids` for pretraining take a look a [T5 Training](./t5#training).
  1956. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1957. Indices of decoder input sequence tokens in the vocabulary.
  1958. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1959. [`PreTrainedTokenizer.__call__`] for details.
  1960. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1961. T5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1962. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1963. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [T5
  1964. Training](./t5#training).
  1965. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1966. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1967. be used by default.
  1968. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1969. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1970. 1]`:
  1971. - 1 indicates the head is **not masked**,
  1972. - 0 indicates the head is **masked**.
  1973. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1974. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1975. `[0, 1]`:
  1976. - 1 indicates the head is **not masked**,
  1977. - 0 indicates the head is **masked**.
  1978. """
  1979. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1980. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1981. if start_positions is not None and end_positions is not None:
  1982. use_cache = False
  1983. # Copied from models.bart.modeling_bart.BartModel.forward
  1984. # different to other models, T5 automatically creates decoder_input_ids from
  1985. # input_ids if no decoder_input_ids are provided
  1986. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1987. if input_ids is None:
  1988. raise ValueError(
  1989. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1990. "passed, `input_ids` cannot be `None`. Please pass either "
  1991. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1992. )
  1993. decoder_input_ids = self._shift_right(input_ids)
  1994. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1995. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1996. # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  1997. if head_mask is not None and decoder_head_mask is None:
  1998. if self.config.num_layers == self.config.num_decoder_layers:
  1999. warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
  2000. decoder_head_mask = head_mask
  2001. # Encode if needed (training, first prediction pass)
  2002. if encoder_outputs is None:
  2003. encoder_outputs = self.encoder(
  2004. input_ids=input_ids,
  2005. attention_mask=attention_mask,
  2006. inputs_embeds=inputs_embeds,
  2007. head_mask=head_mask,
  2008. output_attentions=output_attentions,
  2009. output_hidden_states=output_hidden_states,
  2010. return_dict=return_dict,
  2011. )
  2012. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  2013. encoder_outputs = BaseModelOutput(
  2014. last_hidden_state=encoder_outputs[0],
  2015. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  2016. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  2017. )
  2018. hidden_states = encoder_outputs[0]
  2019. # Decode
  2020. decoder_outputs = self.decoder(
  2021. input_ids=decoder_input_ids,
  2022. attention_mask=decoder_attention_mask,
  2023. inputs_embeds=decoder_inputs_embeds,
  2024. past_key_values=None,
  2025. encoder_hidden_states=hidden_states,
  2026. encoder_attention_mask=attention_mask,
  2027. head_mask=decoder_head_mask,
  2028. cross_attn_head_mask=cross_attn_head_mask,
  2029. use_cache=use_cache,
  2030. output_attentions=output_attentions,
  2031. output_hidden_states=output_hidden_states,
  2032. return_dict=return_dict,
  2033. )
  2034. sequence_output = decoder_outputs[0]
  2035. logits = self.qa_outputs(sequence_output)
  2036. start_logits, end_logits = logits.split(1, dim=-1)
  2037. start_logits = start_logits.squeeze(-1).contiguous()
  2038. end_logits = end_logits.squeeze(-1).contiguous()
  2039. total_loss = None
  2040. if start_positions is not None and end_positions is not None:
  2041. # If we are on multi-GPU, split add a dimension
  2042. if len(start_positions.size()) > 1:
  2043. start_positions = start_positions.squeeze(-1).to(start_logits.device)
  2044. if len(end_positions.size()) > 1:
  2045. end_positions = end_positions.squeeze(-1).to(end_logits.device)
  2046. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  2047. ignored_index = start_logits.size(1)
  2048. start_positions = start_positions.clamp(0, ignored_index)
  2049. end_positions = end_positions.clamp(0, ignored_index)
  2050. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  2051. start_loss = loss_fct(start_logits, start_positions)
  2052. end_loss = loss_fct(end_logits, end_positions)
  2053. total_loss = (start_loss + end_loss) / 2
  2054. if not return_dict:
  2055. output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
  2056. return ((total_loss,) + output) if total_loss is not None else output
  2057. return Seq2SeqQuestionAnsweringModelOutput(
  2058. loss=total_loss,
  2059. start_logits=start_logits,
  2060. end_logits=end_logits,
  2061. past_key_values=decoder_outputs.past_key_values,
  2062. decoder_hidden_states=decoder_outputs.hidden_states,
  2063. decoder_attentions=decoder_outputs.attentions,
  2064. cross_attentions=decoder_outputs.cross_attentions,
  2065. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  2066. encoder_hidden_states=encoder_outputs.hidden_states,
  2067. encoder_attentions=encoder_outputs.attentions,
  2068. )
  2069. __all__ = [
  2070. "T5EncoderModel",
  2071. "T5ForConditionalGeneration",
  2072. "T5Model",
  2073. "T5PreTrainedModel",
  2074. "load_tf_weights_in_t5",
  2075. "T5ForQuestionAnswering",
  2076. "T5ForSequenceClassification",
  2077. "T5ForTokenClassification",
  2078. ]