modeling_longt5.py 102 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227
  1. # coding=utf-8
  2. # Copyright 2022 Google LLC., LongT5 Authors and HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """PyTorch LongT5 model."""
  16. import copy
  17. import math
  18. import warnings
  19. from typing import Any, Optional, Union
  20. import torch
  21. from torch import nn
  22. from torch.nn import CrossEntropyLoss
  23. from ...activations import ACT2FN
  24. from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
  25. from ...generation import GenerationMixin
  26. from ...modeling_attn_mask_utils import AttentionMaskConverter
  27. from ...modeling_layers import GradientCheckpointingLayer
  28. from ...modeling_outputs import (
  29. BaseModelOutput,
  30. BaseModelOutputWithPastAndCrossAttentions,
  31. Seq2SeqLMOutput,
  32. Seq2SeqModelOutput,
  33. )
  34. from ...modeling_utils import PreTrainedModel
  35. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  36. from ...utils import (
  37. DUMMY_INPUTS,
  38. DUMMY_MASK,
  39. auto_docstring,
  40. is_torch_flex_attn_available,
  41. is_torch_fx_proxy,
  42. is_torchdynamo_compiling,
  43. logging,
  44. )
  45. from ...utils.deprecation import deprecate_kwarg
  46. from .configuration_longt5 import LongT5Config
  47. if is_torch_flex_attn_available():
  48. from torch.nn.attention.flex_attention import BlockMask
  49. from ...integrations.flex_attention import make_flex_block_causal_mask
  50. logger = logging.get_logger(__name__)
  51. # TODO: Update before the merge
  52. def _pad_to_multiple(x: torch.Tensor, block_len: int, dim: int, pad_value: int = 0) -> torch.Tensor:
  53. """Pad a tensor so that a sequence length will be a multiple of `block_len`"""
  54. pad_len = -x.shape[dim] % block_len
  55. # Handle cases when an empty input sequence is given
  56. if not all(x.shape):
  57. new_shape = list(x.shape)
  58. new_shape[dim] += pad_len
  59. return torch.zeros(new_shape, dtype=x.dtype)
  60. pad = [(0, 0)] * x.ndim
  61. pad[dim] = (0, pad_len)
  62. pad = sum(pad[::-1], ())
  63. x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value)
  64. return x
  65. def _split_into_blocks(x: torch.Tensor, block_len: int, dim: int) -> torch.Tensor:
  66. """Split an input tensor into blocks of a given `block_len` along the given `dim`. If the dimension length
  67. is not a multiple of `block_len`, it will be padded first with selected `pad_value`.
  68. """
  69. # pad tensor to multiple of block_len
  70. if x.shape[dim] % block_len != 0:
  71. x = _pad_to_multiple(x, block_len, dim, pad_value=0)
  72. num_blocks = x.shape[dim] // block_len
  73. output_shape = x.shape[:dim] + (num_blocks, block_len) + x.shape[(dim + 1) :]
  74. # If 0 is in output_shape, we cannot apply reshape because of incompatibility with ONNX conversion
  75. if 0 in output_shape:
  76. return torch.empty(output_shape, dtype=x.dtype, device=x.device)
  77. return x.reshape(output_shape)
  78. def _concatenate_3_blocks(x: torch.Tensor, block_dim: int, sequence_dim: int, pad_value: int = 0) -> torch.Tensor:
  79. """Concatenate three consecutive blocks for each input block for local attentiont.
  80. For more information, see: https://huggingface.co/papers/2112.07916.
  81. """
  82. num_blocks = x.shape[block_dim]
  83. pad = [(0, 0)] * x.ndim
  84. pad[block_dim] = (1, 1)
  85. pad = sum(pad[::-1], ())
  86. # [batch_size, num_blocks, block_len] -> [batch_size, num_blocks + 2, block_len]
  87. x = nn.functional.pad(x, pad=pad, mode="constant", value=pad_value)
  88. blocks_list: list[torch.Tensor] = []
  89. for i in range(3):
  90. # We use indexing approach here:
  91. # https://numpy.org/doc/stable/user/basics.indexing.html#dealing-with-variable-numbers-of-indices-within-programs
  92. indices = [slice(0, None)] * x.ndim
  93. indices[block_dim] = slice(i, i + num_blocks)
  94. indices = tuple(indices)
  95. blocks_list.append(x[indices])
  96. # [batch_size, num_blocks, 3 * block_len, ...]
  97. return torch.cat(blocks_list, dim=sequence_dim)
  98. def _make_3block_relative_position_ids(block_len: int) -> torch.Tensor:
  99. """Makes 3-blocked relative position ids for local attention."""
  100. position_ids = torch.arange(3 * block_len, dtype=torch.int32)
  101. center_position_ids = position_ids[block_len:-block_len]
  102. # [block_len, 3 * block_len]
  103. relative_position_ids = position_ids.unsqueeze(0) - center_position_ids.unsqueeze(1)
  104. return relative_position_ids
  105. def _mask_local_attention_mask(local_attention_mask: torch.Tensor, block_len: int) -> torch.Tensor:
  106. """Mask local attention mask to enforce that tokens are not allowed to attend tokens farther than ``local_radius."""
  107. relative_position_ids = _make_3block_relative_position_ids(block_len)
  108. locality_mask = torch.abs(relative_position_ids) < block_len
  109. locality_mask = locality_mask[None, None, :, :]
  110. locality_mask = locality_mask.to(local_attention_mask.device)
  111. return torch.logical_and(local_attention_mask, locality_mask)
  112. def _get_local_attention_mask(attention_mask: torch.Tensor, block_len: int, device: torch.device) -> torch.Tensor:
  113. """Prepare attention mask to be applied for a local attention."""
  114. # [batch_size, num_blocks, block_len]
  115. _blocked_attention_mask = _split_into_blocks(attention_mask, block_len, dim=1)
  116. # [batch_size, num_block, 3 * block_len]
  117. _3blocked_attention_mask = _concatenate_3_blocks(_blocked_attention_mask, block_dim=1, sequence_dim=2)
  118. _blocked_attention_mask = _blocked_attention_mask.unsqueeze(-1)
  119. _3blocked_attention_mask = _3blocked_attention_mask.unsqueeze(-2)
  120. # [batch_size, num_block, block_len, 3 * block_len]
  121. local_attention_mask = torch.logical_and(_blocked_attention_mask, _3blocked_attention_mask)
  122. local_attention_mask = _mask_local_attention_mask(local_attention_mask, block_len)
  123. # [batch_size, 1, num_block, block_len, 3 * block_len]
  124. return local_attention_mask.unsqueeze(1).to(device)
  125. def _make_global_fixed_block_ids(
  126. attention_mask: torch.Tensor, global_block_size: int
  127. ) -> tuple[torch.Tensor, torch.Tensor]:
  128. """Obtain the "fixed block" global id corresponding to each input token.
  129. This implementation is a simplified version of the original Flaxformr implementation adopted from:
  130. https://github.com/google/flaxformer/blob/main/flaxformer/architectures/longt5/long_attention.py.
  131. In our scenario, as we use this strategy only for a decoder, orphan tokens, i.e. those tokens which do not make for
  132. the whole fixed block, are assigned to the preceding block.
  133. Padding tokens from the original sequence are represented by -1.
  134. """
  135. batch_size, seq_len = attention_mask.shape[:2]
  136. def handle_orphan_tokens(block_ids: torch.Tensor) -> torch.Tensor:
  137. block_ends = (torch.arange(seq_len) % global_block_size) == global_block_size - 1
  138. block_ends = block_ends.to(block_ids.device)
  139. true_block_ends = torch.logical_and(block_ends, block_ids >= 0)
  140. full_blocks = true_block_ends.sum(-1).unsqueeze(-1).type(block_ids.dtype) - 1
  141. block_ids = torch.where(block_ids < full_blocks, block_ids, full_blocks)
  142. return block_ids
  143. fixed_block_mask = torch.ones_like(attention_mask, device=attention_mask.device) / global_block_size
  144. fixed_block_mask = torch.cumsum(fixed_block_mask, axis=1) - fixed_block_mask
  145. mask = torch.where(attention_mask != 0.0, 1.0, -1000.0).type(attention_mask.dtype)
  146. global_block_ids = torch.floor(mask + fixed_block_mask - 1.0).type(attention_mask.dtype)
  147. _global_block_ids_lower_bound = torch.tensor(-1, dtype=global_block_ids.dtype, device=global_block_ids.device)
  148. global_block_ids = torch.where(
  149. global_block_ids > _global_block_ids_lower_bound, global_block_ids, _global_block_ids_lower_bound
  150. )
  151. # set padding tokens to -1
  152. global_block_ids = (global_block_ids * attention_mask) + (attention_mask - 1)
  153. # [batch_size, seq_len]
  154. global_block_ids = handle_orphan_tokens(global_block_ids)
  155. num_globals = seq_len // global_block_size
  156. # [batch_size, seq_len // global_block_size]
  157. if num_globals > 0:
  158. _sequence_block_ids_max = torch.max(global_block_ids, dim=-1).values.repeat(num_globals, 1).transpose(0, 1)
  159. else:
  160. _sequence_block_ids_max = torch.zeros(
  161. batch_size, 0, dtype=global_block_ids.dtype, device=global_block_ids.device
  162. )
  163. global_segment_ids = torch.cumsum(torch.ones(batch_size, num_globals), dim=-1) - 1
  164. global_segment_ids = global_segment_ids.to(attention_mask.device)
  165. global_segment_ids = torch.where(global_segment_ids <= _sequence_block_ids_max, 1, 0)
  166. return global_block_ids.type(torch.int), global_segment_ids.type(torch.int)
  167. def _make_side_relative_position_ids(attention_mask: torch.Tensor, global_block_size: int) -> torch.Tensor:
  168. """Create the relative position tensor for local -> global attention."""
  169. block_ids, global_segment_ids = _make_global_fixed_block_ids(attention_mask, global_block_size)
  170. global_seq_len = global_segment_ids.shape[-1]
  171. global_positions = torch.arange(global_seq_len, device=block_ids.device)
  172. side_relative_position = global_positions - block_ids[..., None]
  173. return side_relative_position.type(torch.int64)
  174. def _create_global_aggregates(
  175. hidden_states: torch.Tensor, block_ids: torch.Tensor, global_seq_len: int
  176. ) -> torch.Tensor:
  177. """Compute individual block aggregates by summing over individual blocks."""
  178. # (batch..., seq_len, global_seq_len))
  179. block_ids = block_ids.where(
  180. block_ids >= 0, torch.tensor(global_seq_len, dtype=block_ids.dtype, device=block_ids.device)
  181. )
  182. one_hot_block_ids = nn.functional.one_hot(block_ids.type(torch.int64), global_seq_len + 1)[:, :, :-1]
  183. return torch.einsum("...nd,...ng->...gd", hidden_states, one_hot_block_ids.type(hidden_states.dtype))
  184. # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->LongT5
  185. class LongT5LayerNorm(nn.Module):
  186. def __init__(self, hidden_size, eps=1e-6):
  187. """
  188. Construct a layernorm module in the LongT5 style. No bias and no subtraction of mean.
  189. """
  190. super().__init__()
  191. self.weight = nn.Parameter(torch.ones(hidden_size))
  192. self.variance_epsilon = eps
  193. def forward(self, hidden_states):
  194. # LongT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
  195. # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
  196. # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
  197. # half-precision inputs is done in fp32
  198. variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
  199. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  200. # convert into half-precision if necessary
  201. if self.weight.dtype in [torch.float16, torch.bfloat16]:
  202. hidden_states = hidden_states.to(self.weight.dtype)
  203. return self.weight * hidden_states
  204. try:
  205. from apex.normalization import FusedRMSNorm
  206. LongT5LayerNorm = FusedRMSNorm
  207. logger.info("Discovered apex.normalization.FusedRMSNorm - will use it instead of LongT5LayerNorm")
  208. except ImportError:
  209. # using the normal LongT5LayerNorm
  210. pass
  211. except Exception:
  212. logger.warning("discovered apex but it failed to load, falling back to LongT5LayerNorm")
  213. pass
  214. # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->LongT5
  215. class LongT5DenseActDense(nn.Module):
  216. def __init__(self, config: LongT5Config):
  217. super().__init__()
  218. self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
  219. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  220. self.dropout = nn.Dropout(config.dropout_rate)
  221. self.act = ACT2FN[config.dense_act_fn]
  222. def forward(self, hidden_states):
  223. hidden_states = self.wi(hidden_states)
  224. hidden_states = self.act(hidden_states)
  225. hidden_states = self.dropout(hidden_states)
  226. if (
  227. isinstance(self.wo.weight, torch.Tensor)
  228. and hidden_states.dtype != self.wo.weight.dtype
  229. and self.wo.weight.dtype != torch.int8
  230. ):
  231. hidden_states = hidden_states.to(self.wo.weight.dtype)
  232. hidden_states = self.wo(hidden_states)
  233. return hidden_states
  234. class LongT5DenseGatedActDense(nn.Module):
  235. def __init__(self, config: LongT5Config):
  236. super().__init__()
  237. self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
  238. self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
  239. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  240. self.dropout = nn.Dropout(config.dropout_rate)
  241. self.act = ACT2FN[config.dense_act_fn]
  242. def forward(self, hidden_states):
  243. hidden_gelu = self.act(self.wi_0(hidden_states))
  244. hidden_linear = self.wi_1(hidden_states)
  245. hidden_states = hidden_gelu * hidden_linear
  246. hidden_states = self.dropout(hidden_states)
  247. hidden_states = self.wo(hidden_states)
  248. return hidden_states
  249. # Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->LongT5
  250. class LongT5LayerFF(nn.Module):
  251. def __init__(self, config: LongT5Config):
  252. super().__init__()
  253. if config.is_gated_act:
  254. self.DenseReluDense = LongT5DenseGatedActDense(config)
  255. else:
  256. self.DenseReluDense = LongT5DenseActDense(config)
  257. self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  258. self.dropout = nn.Dropout(config.dropout_rate)
  259. def forward(self, hidden_states):
  260. forwarded_states = self.layer_norm(hidden_states)
  261. forwarded_states = self.DenseReluDense(forwarded_states)
  262. hidden_states = hidden_states + self.dropout(forwarded_states)
  263. return hidden_states
  264. # Copied from transformers.models.t5.modeling_t5.T5Attention with T5->LongT5
  265. class LongT5Attention(nn.Module):
  266. def __init__(
  267. self,
  268. config: LongT5Config,
  269. has_relative_attention_bias=False,
  270. layer_idx: Optional[int] = None,
  271. ):
  272. super().__init__()
  273. self.is_decoder = config.is_decoder
  274. self.has_relative_attention_bias = has_relative_attention_bias
  275. self.relative_attention_num_buckets = config.relative_attention_num_buckets
  276. self.relative_attention_max_distance = config.relative_attention_max_distance
  277. self.d_model = config.d_model
  278. self.key_value_proj_dim = config.d_kv
  279. self.n_heads = config.num_heads
  280. self.dropout = config.dropout_rate
  281. self.inner_dim = self.n_heads * self.key_value_proj_dim
  282. self.layer_idx = layer_idx
  283. if layer_idx is None and self.is_decoder:
  284. logger.warning_once(
  285. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  286. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  287. "when creating this class."
  288. )
  289. # Mesh TensorFlow initialization to avoid scaling before softmax
  290. self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
  291. self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
  292. self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
  293. self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
  294. if self.has_relative_attention_bias:
  295. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  296. self.pruned_heads = set()
  297. self.gradient_checkpointing = False
  298. def prune_heads(self, heads):
  299. if len(heads) == 0:
  300. return
  301. heads, index = find_pruneable_heads_and_indices(
  302. heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
  303. )
  304. # Prune linear layers
  305. self.q = prune_linear_layer(self.q, index)
  306. self.k = prune_linear_layer(self.k, index)
  307. self.v = prune_linear_layer(self.v, index)
  308. self.o = prune_linear_layer(self.o, index, dim=1)
  309. # Update hyper params
  310. self.n_heads = self.n_heads - len(heads)
  311. self.inner_dim = self.key_value_proj_dim * self.n_heads
  312. self.pruned_heads = self.pruned_heads.union(heads)
  313. @staticmethod
  314. def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  315. """
  316. Adapted from Mesh Tensorflow:
  317. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  318. Translate relative position to a bucket number for relative attention. The relative position is defined as
  319. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  320. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
  321. small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
  322. positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
  323. This should allow for more graceful generalization to longer sequences than the model has been trained on
  324. Args:
  325. relative_position: an int32 Tensor
  326. bidirectional: a boolean - whether the attention is bidirectional
  327. num_buckets: an integer
  328. max_distance: an integer
  329. Returns:
  330. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  331. """
  332. relative_buckets = 0
  333. if bidirectional:
  334. num_buckets //= 2
  335. relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
  336. relative_position = torch.abs(relative_position)
  337. else:
  338. relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
  339. # now relative_position is in the range [0, inf)
  340. # half of the buckets are for exact increments in positions
  341. max_exact = num_buckets // 2
  342. is_small = relative_position < max_exact
  343. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  344. relative_position_if_large = max_exact + (
  345. torch.log(relative_position.float() / max_exact)
  346. / math.log(max_distance / max_exact)
  347. * (num_buckets - max_exact)
  348. ).to(torch.long)
  349. relative_position_if_large = torch.min(
  350. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  351. )
  352. relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
  353. return relative_buckets
  354. def compute_bias(self, query_length, key_length, device=None, cache_position=None):
  355. """Compute binned relative position bias"""
  356. if device is None:
  357. device = self.relative_attention_bias.weight.device
  358. if cache_position is None:
  359. context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
  360. else:
  361. context_position = cache_position[:, None].to(device)
  362. memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
  363. relative_position = memory_position - context_position # shape (query_length, key_length)
  364. relative_position_bucket = self._relative_position_bucket(
  365. relative_position, # shape (query_length, key_length)
  366. bidirectional=(not self.is_decoder),
  367. num_buckets=self.relative_attention_num_buckets,
  368. max_distance=self.relative_attention_max_distance,
  369. )
  370. values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
  371. values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
  372. return values
  373. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  374. def forward(
  375. self,
  376. hidden_states,
  377. mask=None,
  378. key_value_states=None,
  379. position_bias=None,
  380. past_key_values=None,
  381. layer_head_mask=None,
  382. query_length=None,
  383. use_cache=False,
  384. output_attentions=False,
  385. cache_position=None,
  386. ):
  387. """
  388. Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).
  389. """
  390. # Input is (batch_size, seq_length, dim)
  391. # Mask is (batch_size, 1, 1, key_length) (non-causal encoder) or (batch_size, 1, seq_length, key_length) (causal decoder)
  392. batch_size, seq_length = hidden_states.shape[:2]
  393. # if key_value_states are provided this layer is used as a cross-attention layer for the decoder
  394. is_cross_attention = key_value_states is not None
  395. query_states = self.q(hidden_states)
  396. query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  397. # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
  398. is_updated = False
  399. if isinstance(past_key_values, EncoderDecoderCache):
  400. is_updated = past_key_values.is_updated.get(self.layer_idx)
  401. if is_cross_attention:
  402. # after the first generated id, we can subsequently re-use all key/value_states from cache
  403. curr_past_key_value = past_key_values.cross_attention_cache
  404. else:
  405. curr_past_key_value = past_key_values.self_attention_cache
  406. else:
  407. curr_past_key_value = past_key_values
  408. current_states = key_value_states if is_cross_attention else hidden_states
  409. if is_cross_attention and past_key_values is not None and is_updated:
  410. # reuse k,v, cross_attentions
  411. key_states = curr_past_key_value.layers[self.layer_idx].keys
  412. value_states = curr_past_key_value.layers[self.layer_idx].values
  413. else:
  414. key_states = self.k(current_states)
  415. value_states = self.v(current_states)
  416. key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  417. value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  418. if past_key_values is not None:
  419. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  420. cache_position = cache_position if not is_cross_attention else None
  421. key_states, value_states = curr_past_key_value.update(
  422. key_states, value_states, self.layer_idx, {"cache_position": cache_position}
  423. )
  424. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  425. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  426. past_key_values.is_updated[self.layer_idx] = True
  427. # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
  428. scores = torch.matmul(query_states, key_states.transpose(3, 2))
  429. if position_bias is None:
  430. key_length = key_states.shape[-2]
  431. # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
  432. real_seq_length = query_length if query_length is not None else cache_position[-1] + 1
  433. if not self.has_relative_attention_bias:
  434. position_bias = torch.zeros(
  435. (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
  436. )
  437. if self.gradient_checkpointing and self.training:
  438. position_bias.requires_grad = True
  439. else:
  440. position_bias = self.compute_bias(
  441. real_seq_length, key_length, device=scores.device, cache_position=cache_position
  442. )
  443. position_bias = position_bias[:, :, -seq_length:, :]
  444. if mask is not None:
  445. causal_mask = mask[:, :, :, : key_states.shape[-2]]
  446. position_bias = position_bias + causal_mask
  447. if self.pruned_heads:
  448. mask = torch.ones(position_bias.shape[1])
  449. mask[list(self.pruned_heads)] = 0
  450. position_bias_masked = position_bias[:, mask.bool()]
  451. else:
  452. position_bias_masked = position_bias
  453. scores += position_bias_masked
  454. # (batch_size, n_heads, seq_length, key_length)
  455. attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
  456. attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  457. # Mask heads if we want to
  458. if layer_head_mask is not None:
  459. attn_weights = attn_weights * layer_head_mask
  460. attn_output = torch.matmul(attn_weights, value_states)
  461. attn_output = attn_output.transpose(1, 2).contiguous()
  462. attn_output = attn_output.view(batch_size, -1, self.inner_dim)
  463. attn_output = self.o(attn_output)
  464. outputs = (attn_output, position_bias)
  465. if output_attentions:
  466. outputs = outputs + (attn_weights,)
  467. return outputs
  468. class LongT5LocalAttention(nn.Module):
  469. def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None:
  470. super().__init__()
  471. self.is_decoder = config.is_decoder
  472. self.has_relative_attention_bias = has_relative_attention_bias
  473. self.relative_attention_num_buckets = config.relative_attention_num_buckets
  474. self.relative_attention_max_distance = config.relative_attention_max_distance
  475. self.d_model = config.d_model
  476. self.key_value_proj_dim = config.d_kv
  477. self.n_heads = config.num_heads
  478. self.local_radius = config.local_radius
  479. self.block_len = self.local_radius + 1
  480. self.dropout = config.dropout_rate
  481. self.inner_dim = self.n_heads * self.key_value_proj_dim
  482. # Mesh TensorFlow initialization to avoid scaling before softmax
  483. self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
  484. self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
  485. self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
  486. self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
  487. if self.has_relative_attention_bias:
  488. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  489. self.pruned_heads = set()
  490. self.gradient_checkpointing = False
  491. # Copied from transformers.models.t5.modeling_t5.T5Attention.prune_heads
  492. def prune_heads(self, heads):
  493. if len(heads) == 0:
  494. return
  495. heads, index = find_pruneable_heads_and_indices(
  496. heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
  497. )
  498. # Prune linear layers
  499. self.q = prune_linear_layer(self.q, index)
  500. self.k = prune_linear_layer(self.k, index)
  501. self.v = prune_linear_layer(self.v, index)
  502. self.o = prune_linear_layer(self.o, index, dim=1)
  503. # Update hyper params
  504. self.n_heads = self.n_heads - len(heads)
  505. self.inner_dim = self.key_value_proj_dim * self.n_heads
  506. self.pruned_heads = self.pruned_heads.union(heads)
  507. @staticmethod
  508. # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
  509. def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  510. """
  511. Adapted from Mesh Tensorflow:
  512. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  513. Translate relative position to a bucket number for relative attention. The relative position is defined as
  514. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  515. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
  516. small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
  517. positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
  518. This should allow for more graceful generalization to longer sequences than the model has been trained on
  519. Args:
  520. relative_position: an int32 Tensor
  521. bidirectional: a boolean - whether the attention is bidirectional
  522. num_buckets: an integer
  523. max_distance: an integer
  524. Returns:
  525. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  526. """
  527. relative_buckets = 0
  528. if bidirectional:
  529. num_buckets //= 2
  530. relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
  531. relative_position = torch.abs(relative_position)
  532. else:
  533. relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
  534. # now relative_position is in the range [0, inf)
  535. # half of the buckets are for exact increments in positions
  536. max_exact = num_buckets // 2
  537. is_small = relative_position < max_exact
  538. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  539. relative_position_if_large = max_exact + (
  540. torch.log(relative_position.float() / max_exact)
  541. / math.log(max_distance / max_exact)
  542. * (num_buckets - max_exact)
  543. ).to(torch.long)
  544. relative_position_if_large = torch.min(
  545. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  546. )
  547. relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
  548. return relative_buckets
  549. def compute_bias(self, block_length: int):
  550. """Compute binned relative position bias"""
  551. target_device = (
  552. self.relative_attention_bias.weight.device
  553. if self.relative_attention_bias.weight.device.type != "meta"
  554. else None
  555. )
  556. memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
  557. context_position = memory_position[block_length:-block_length]
  558. # (block_length, 3 * block_length)
  559. relative_position = memory_position[None, :] - context_position[:, None]
  560. relative_position_bucket = self._relative_position_bucket(
  561. relative_position, # (block_length, 3 * block_length)
  562. bidirectional=(not self.is_decoder),
  563. num_buckets=self.relative_attention_num_buckets,
  564. max_distance=self.relative_attention_max_distance,
  565. )
  566. # (block_length, 3 * block_length, num_heads)
  567. values = self.relative_attention_bias(relative_position_bucket)
  568. # (1, 1, num_heads, block_length, 3 * block_length)
  569. values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0)
  570. return values
  571. def forward(
  572. self,
  573. hidden_states,
  574. mask=None,
  575. position_bias=None,
  576. layer_head_mask=None,
  577. output_attentions=False,
  578. ):
  579. batch_size, seq_length = hidden_states.shape[:2]
  580. def shape(states):
  581. """projection"""
  582. return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
  583. def unshape(states):
  584. """reshape"""
  585. return states.contiguous().view(batch_size, -1, self.inner_dim)
  586. # get query/key/value states -> (batch_size, seq_length, n_heads, dim_per_head)
  587. query_states = shape(self.q(hidden_states))
  588. key_states = shape(self.k(hidden_states))
  589. value_states = shape(self.v(hidden_states))
  590. # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
  591. query_states = _split_into_blocks(query_states, self.block_len, dim=1)
  592. key_states = _split_into_blocks(key_states, self.block_len, dim=1)
  593. value_states = _split_into_blocks(value_states, self.block_len, dim=1)
  594. # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
  595. key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
  596. value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)
  597. # Compute scores
  598. scores = torch.einsum(
  599. "...qhd,...khd->...hqk", query_states, key_states
  600. ) # (batch_size, num_block, n_heads, block_len, 3 * block_len)
  601. if position_bias is None:
  602. # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
  603. if not self.has_relative_attention_bias:
  604. position_bias = torch.zeros(
  605. (1, 1, self.n_heads, self.block_len, 3 * self.block_len), device=scores.device, dtype=scores.dtype
  606. )
  607. if self.gradient_checkpointing and self.training:
  608. position_bias.requires_grad = True
  609. else:
  610. position_bias = self.compute_bias(self.block_len)
  611. if mask is not None:
  612. # Replace masked positions with -1e10 (according to the original implementation)
  613. mask = torch.where(mask > 0, 0.0, -1e10)
  614. # We need to adjust position bias shape to be sum with mask
  615. position_bias = position_bias + mask.transpose(1, 2)
  616. scores += position_bias
  617. # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
  618. attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
  619. # (batch_size, num_blocks, n_heads, block_len, 3 * block_len)
  620. attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  621. # Mask heads if we want to
  622. if layer_head_mask is not None:
  623. attn_weights = attn_weights * layer_head_mask
  624. attn_weights = attn_weights.type(value_states.dtype)
  625. attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states))
  626. attn_output = attn_output[:, :seq_length, :]
  627. attn_output = self.o(attn_output)
  628. outputs = (
  629. attn_output,
  630. position_bias,
  631. )
  632. if output_attentions:
  633. outputs = outputs + (attn_weights,)
  634. return outputs
  635. class LongT5TransientGlobalAttention(nn.Module):
  636. def __init__(self, config: LongT5Config, has_relative_attention_bias: bool = False) -> None:
  637. super().__init__()
  638. self.is_decoder = config.is_decoder
  639. self.has_relative_attention_bias = has_relative_attention_bias
  640. self.relative_attention_num_buckets = config.relative_attention_num_buckets
  641. self.relative_attention_max_distance = config.relative_attention_max_distance
  642. self.d_model = config.d_model
  643. self.key_value_proj_dim = config.d_kv
  644. self.n_heads = config.num_heads
  645. self.local_radius = config.local_radius
  646. self.block_len = self.local_radius + 1
  647. self.global_block_size = config.global_block_size
  648. self.dropout = config.dropout_rate
  649. self.inner_dim = self.n_heads * self.key_value_proj_dim
  650. # Mesh TensorFlow initialization to avoid scaling before softmax
  651. self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
  652. self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
  653. self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
  654. self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
  655. if self.has_relative_attention_bias:
  656. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  657. self.pruned_heads = set()
  658. # Relativen attention bias & Layer norm for global attention
  659. if self.has_relative_attention_bias:
  660. self.global_relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  661. self.global_input_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  662. # Copied from transformers.models.t5.modeling_t5.T5Attention.prune_heads
  663. def prune_heads(self, heads):
  664. if len(heads) == 0:
  665. return
  666. heads, index = find_pruneable_heads_and_indices(
  667. heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads
  668. )
  669. # Prune linear layers
  670. self.q = prune_linear_layer(self.q, index)
  671. self.k = prune_linear_layer(self.k, index)
  672. self.v = prune_linear_layer(self.v, index)
  673. self.o = prune_linear_layer(self.o, index, dim=1)
  674. # Update hyper params
  675. self.n_heads = self.n_heads - len(heads)
  676. self.inner_dim = self.key_value_proj_dim * self.n_heads
  677. self.pruned_heads = self.pruned_heads.union(heads)
  678. @staticmethod
  679. # Copied from transformers.models.t5.modeling_t5.T5Attention._relative_position_bucket
  680. def _relative_position_bucket(relative_position, bidirectional=True, num_buckets=32, max_distance=128):
  681. """
  682. Adapted from Mesh Tensorflow:
  683. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  684. Translate relative position to a bucket number for relative attention. The relative position is defined as
  685. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  686. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
  687. small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
  688. positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
  689. This should allow for more graceful generalization to longer sequences than the model has been trained on
  690. Args:
  691. relative_position: an int32 Tensor
  692. bidirectional: a boolean - whether the attention is bidirectional
  693. num_buckets: an integer
  694. max_distance: an integer
  695. Returns:
  696. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  697. """
  698. relative_buckets = 0
  699. if bidirectional:
  700. num_buckets //= 2
  701. relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
  702. relative_position = torch.abs(relative_position)
  703. else:
  704. relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
  705. # now relative_position is in the range [0, inf)
  706. # half of the buckets are for exact increments in positions
  707. max_exact = num_buckets // 2
  708. is_small = relative_position < max_exact
  709. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  710. relative_position_if_large = max_exact + (
  711. torch.log(relative_position.float() / max_exact)
  712. / math.log(max_distance / max_exact)
  713. * (num_buckets - max_exact)
  714. ).to(torch.long)
  715. relative_position_if_large = torch.min(
  716. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  717. )
  718. relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
  719. return relative_buckets
  720. def compute_bias(self, block_length: int):
  721. """Compute binned relative position bias"""
  722. target_device = (
  723. self.relative_attention_bias.weight.device
  724. if self.relative_attention_bias.weight.device.type != "meta"
  725. else None
  726. )
  727. memory_position = torch.arange(3 * block_length, dtype=torch.long, device=target_device)
  728. context_position = memory_position[block_length:-block_length]
  729. # (block_length, 3 * block_length)
  730. relative_position = memory_position[None, :] - context_position[:, None]
  731. relative_position_bucket = self._relative_position_bucket(
  732. relative_position, # (block_length, 3 * block_length)
  733. bidirectional=(not self.is_decoder),
  734. num_buckets=self.relative_attention_num_buckets,
  735. max_distance=self.relative_attention_max_distance,
  736. )
  737. # (block_length, 3 * block_length, num_heads)
  738. values = self.relative_attention_bias(relative_position_bucket)
  739. # (1, 1, num_heads, block_length, 3 * block_length)
  740. values = values.permute([2, 0, 1]).unsqueeze(0).unsqueeze(0)
  741. return values
  742. def compute_side_bias(self, mask: torch.Tensor, global_segment_ids: torch.Tensor) -> torch.Tensor:
  743. # (batch_size, 1, seq_len, global_seq_len)
  744. side_attention_mask = torch.eq(mask[..., None], global_segment_ids[:, None, :])[:, None, ...]
  745. attention_side_bias = torch.where(side_attention_mask > 0, 0.0, -1e10)
  746. # (batch_size, seq_len, global_seq_len)
  747. side_relative_position = _make_side_relative_position_ids(mask, self.global_block_size)
  748. side_relative_position_bucket = self._relative_position_bucket(
  749. side_relative_position,
  750. bidirectional=(not self.is_decoder),
  751. num_buckets=self.relative_attention_num_buckets,
  752. max_distance=self.relative_attention_max_distance,
  753. )
  754. # (batch_size, seq_len, global_seq_len, num_heads)
  755. side_bias = self.global_relative_attention_bias(side_relative_position_bucket)
  756. # (batch_size, num_heads, seq_len, global_seq_len)
  757. side_bias = side_bias.permute([0, 3, 1, 2])
  758. # (batch_size, num_heads, seq_len, global_seq_len)
  759. attention_side_bias = attention_side_bias + side_bias
  760. return attention_side_bias
  761. def forward(
  762. self,
  763. hidden_states,
  764. mask=None,
  765. position_bias=None,
  766. layer_head_mask=None,
  767. output_attentions=False,
  768. ):
  769. batch_size, seq_length = hidden_states.shape[:2]
  770. def shape(states):
  771. """projection"""
  772. return states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim)
  773. def unshape(states):
  774. """reshape"""
  775. return states.contiguous().view(batch_size, -1, self.inner_dim)
  776. # Prepare components for transient-global attention
  777. # Obtain block_ids and global_segment_ids
  778. # global_seq_len := seq_len // self.global_block_size
  779. # shapes: (batch_size, seq_len) & (batch_size, global_seq_len)
  780. block_ids, global_segment_ids = _make_global_fixed_block_ids(
  781. mask if mask is not None else torch.ones(hidden_states.shape[:-1]),
  782. self.global_block_size,
  783. )
  784. # Create global inputs
  785. _global_seq_len = global_segment_ids.shape[-1]
  786. global_inputs = _create_global_aggregates(hidden_states, block_ids, _global_seq_len)
  787. global_inputs = self.global_input_layer_norm(global_inputs)
  788. # get query states -> (batch_size, seq_length, n_heads, dim_per_head)
  789. query_states = shape(self.q(hidden_states))
  790. key_states = shape(self.k(hidden_states))
  791. value_states = shape(self.v(hidden_states))
  792. # Get global/side key/value states shape: (batch_size, global_seq_len, n_heads, dim_per_head)
  793. side_key_states = shape(self.k(global_inputs))
  794. side_value_states = shape(self.v(global_inputs))
  795. # Split into blocks -> (batch_size, num_blocks, block_len, n_heads, dim_per_head)
  796. query_states = _split_into_blocks(query_states, self.block_len, dim=1)
  797. key_states = _split_into_blocks(key_states, self.block_len, dim=1)
  798. value_states = _split_into_blocks(value_states, self.block_len, dim=1)
  799. # Concatenate 3 blocks for keys and values -> (batch_size, num_blocks, 3 * block_len, n_heads, dim_per_head)
  800. key_states = _concatenate_3_blocks(key_states, block_dim=1, sequence_dim=2)
  801. value_states = _concatenate_3_blocks(value_states, block_dim=1, sequence_dim=2)
  802. # Tile side inputs across local key/value blocks
  803. # New shape: (batch_size, num_blocks, global_seq_len, n_heads, dim_per_head)
  804. reps = [1] * (side_key_states.ndim + 1)
  805. reps[1] = key_states.shape[1]
  806. side_key_states = side_key_states.unsqueeze(1).repeat(reps)
  807. side_value_states = side_value_states.unsqueeze(1).repeat(reps)
  808. # Concatenate "local" and "side"/"global" key/value states to allow each token to attend global aggregated ones
  809. # New shape: (batch_size, num_blocks, 3 * block_len + global_seq_len, n_heads, dim_per_head)
  810. key_states = torch.cat([key_states, side_key_states], dim=2)
  811. value_states = torch.cat([value_states, side_value_states], dim=2)
  812. # Compute scores -> (batch_size, num_block, n_heads, block_len, 3 * block_len + global_seq_len)
  813. scores = torch.einsum("...qhd,...khd->...hqk", query_states, key_states)
  814. if mask is not None:
  815. # We need to adjust position bias shape to be sum with mask
  816. local_attention_mask = _get_local_attention_mask(mask, self.block_len, hidden_states.device)
  817. # Replace masked positions with -10_000 (according to the original implementation)
  818. local_attention_mask = torch.where(local_attention_mask > 0, 0.0, -1e10)
  819. else:
  820. local_attention_mask = None
  821. if position_bias is None:
  822. # position_bias shape: # (1, 1, n_heads, block_len, 3 * block_len)
  823. if not self.has_relative_attention_bias:
  824. position_bias = torch.zeros(
  825. (1, 1, self.n_heads, self.block_len, 3 * self.block_len),
  826. device=scores.device,
  827. dtype=scores.dtype,
  828. )
  829. if self.gradient_checkpointing and self.training:
  830. position_bias.requires_grad = True
  831. else:
  832. position_bias = self.compute_bias(self.block_len)
  833. if local_attention_mask is not None:
  834. # (batch_size, 1, n_heads, block_len, 3 * block_len)
  835. position_bias = position_bias + local_attention_mask.transpose(1, 2)
  836. position_bias = position_bias.type(scores.dtype)
  837. # Calculate global/side bias - shape: # (batch_size, num_heads, seq_len, global_seq_len)
  838. if mask is None:
  839. mask = torch.ones(batch_size, seq_length)
  840. # (batch_size, num_heads, seq_len, global_seq_len)
  841. side_position_bias = self.compute_side_bias(mask, global_segment_ids)
  842. # (batch_size, num_blocks, num_heads, block_len, global_seq_len)
  843. side_position_bias = _split_into_blocks(side_position_bias, self.block_len, dim=-2).transpose(1, 2)
  844. side_position_bias = side_position_bias.type(scores.dtype).to(scores.device)
  845. # (batch_size, num_blocks, num_heads, block_len, 3 * block_len + global_seq_len)
  846. position_bias = torch.cat([position_bias, side_position_bias], dim=-1)
  847. scores += position_bias
  848. # (batch_size, num_blocks, n_heads, block_len, 3 * block_len + global_seq_len)
  849. attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
  850. attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  851. # Mask heads if we want to
  852. if layer_head_mask is not None:
  853. attn_weights = attn_weights * layer_head_mask
  854. attn_weights = attn_weights.type(value_states.dtype)
  855. attn_output = unshape(torch.einsum("...hqk,...khd->...qhd", attn_weights, value_states))
  856. attn_output = attn_output[:, :seq_length, :]
  857. attn_output = self.o(attn_output)
  858. outputs = (attn_output, position_bias)
  859. if output_attentions:
  860. outputs = outputs + (attn_weights,)
  861. return outputs
  862. # Copied from transformers.models.t5.modeling_t5.T5LayerSelfAttention with T5->LongT5
  863. class LongT5LayerSelfAttention(nn.Module):
  864. def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
  865. super().__init__()
  866. self.SelfAttention = LongT5Attention(
  867. config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx
  868. )
  869. self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  870. self.dropout = nn.Dropout(config.dropout_rate)
  871. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  872. def forward(
  873. self,
  874. hidden_states,
  875. attention_mask=None,
  876. position_bias=None,
  877. layer_head_mask=None,
  878. past_key_values=None,
  879. use_cache=False,
  880. output_attentions=False,
  881. cache_position=None,
  882. ):
  883. normed_hidden_states = self.layer_norm(hidden_states)
  884. attention_output = self.SelfAttention(
  885. normed_hidden_states,
  886. mask=attention_mask,
  887. position_bias=position_bias,
  888. layer_head_mask=layer_head_mask,
  889. past_key_values=past_key_values,
  890. use_cache=use_cache,
  891. output_attentions=output_attentions,
  892. cache_position=cache_position,
  893. )
  894. hidden_states = hidden_states + self.dropout(attention_output[0])
  895. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  896. return outputs
  897. class LongT5LayerLocalSelfAttention(nn.Module):
  898. """Local self attention used in encoder"""
  899. def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
  900. super().__init__()
  901. self.LocalSelfAttention = LongT5LocalAttention(config, has_relative_attention_bias=has_relative_attention_bias)
  902. self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  903. self.dropout = nn.Dropout(config.dropout_rate)
  904. def forward(
  905. self,
  906. hidden_states,
  907. attention_mask=None,
  908. position_bias=None,
  909. layer_head_mask=None,
  910. output_attentions=False,
  911. **kwargs: Any, # to accept past_key_values and use_cache kwargs
  912. ):
  913. normed_hidden_states = self.layer_norm(hidden_states)
  914. attention_output = self.LocalSelfAttention(
  915. normed_hidden_states,
  916. mask=attention_mask,
  917. position_bias=position_bias,
  918. layer_head_mask=layer_head_mask,
  919. output_attentions=output_attentions,
  920. )
  921. hidden_states = hidden_states + self.dropout(attention_output[0])
  922. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  923. return outputs
  924. class LongT5LayerTransientGlobalSelfAttention(nn.Module):
  925. """Transient-Global self attention used in encoder"""
  926. def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
  927. super().__init__()
  928. self.TransientGlobalSelfAttention = LongT5TransientGlobalAttention(
  929. config, has_relative_attention_bias=has_relative_attention_bias
  930. )
  931. self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  932. self.dropout = nn.Dropout(config.dropout_rate)
  933. def forward(
  934. self,
  935. hidden_states,
  936. attention_mask=None,
  937. position_bias=None,
  938. layer_head_mask=None,
  939. output_attentions=False,
  940. **kwargs: Any, # to accept past_key_values and use_cache kwargs
  941. ):
  942. normed_hidden_states = self.layer_norm(hidden_states)
  943. attention_output = self.TransientGlobalSelfAttention(
  944. normed_hidden_states,
  945. mask=attention_mask,
  946. position_bias=position_bias,
  947. layer_head_mask=layer_head_mask,
  948. output_attentions=output_attentions,
  949. )
  950. hidden_states = hidden_states + self.dropout(attention_output[0])
  951. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  952. return outputs
  953. # Copied from transformers.models.t5.modeling_t5.T5LayerCrossAttention with T5->LongT5
  954. class LongT5LayerCrossAttention(nn.Module):
  955. def __init__(self, config, layer_idx: Optional[int] = None):
  956. super().__init__()
  957. self.EncDecAttention = LongT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
  958. self.layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  959. self.dropout = nn.Dropout(config.dropout_rate)
  960. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  961. def forward(
  962. self,
  963. hidden_states,
  964. key_value_states,
  965. attention_mask=None,
  966. position_bias=None,
  967. layer_head_mask=None,
  968. past_key_values=None,
  969. use_cache=False,
  970. query_length=None,
  971. output_attentions=False,
  972. cache_position=None,
  973. ):
  974. normed_hidden_states = self.layer_norm(hidden_states)
  975. attention_output = self.EncDecAttention(
  976. normed_hidden_states,
  977. mask=attention_mask,
  978. key_value_states=key_value_states,
  979. position_bias=position_bias,
  980. layer_head_mask=layer_head_mask,
  981. past_key_values=past_key_values,
  982. use_cache=use_cache,
  983. query_length=query_length,
  984. output_attentions=output_attentions,
  985. cache_position=cache_position,
  986. )
  987. layer_output = hidden_states + self.dropout(attention_output[0])
  988. outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
  989. return outputs
  990. class LongT5Block(GradientCheckpointingLayer):
  991. def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
  992. super().__init__()
  993. self.is_decoder = config.is_decoder
  994. if config.is_decoder:
  995. attention_layer = LongT5LayerSelfAttention
  996. elif config.encoder_attention_type == "local":
  997. attention_layer = LongT5LayerLocalSelfAttention
  998. elif config.encoder_attention_type == "transient-global":
  999. attention_layer = LongT5LayerTransientGlobalSelfAttention
  1000. else:
  1001. raise ValueError(
  1002. "For encoder attention mechanism, either `local` or `transient-global` attention type is expected, "
  1003. f"but got {config.encoder_attention_type}."
  1004. )
  1005. self.layer = nn.ModuleList()
  1006. self.layer.append(
  1007. attention_layer(config, has_relative_attention_bias=has_relative_attention_bias, layer_idx=layer_idx)
  1008. )
  1009. if self.is_decoder:
  1010. self.layer.append(LongT5LayerCrossAttention(config, layer_idx=layer_idx))
  1011. self.layer.append(LongT5LayerFF(config))
  1012. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  1013. def forward(
  1014. self,
  1015. hidden_states,
  1016. attention_mask=None,
  1017. position_bias=None,
  1018. encoder_hidden_states=None,
  1019. encoder_attention_mask=None,
  1020. encoder_decoder_position_bias=None,
  1021. layer_head_mask=None,
  1022. cross_attn_layer_head_mask=None,
  1023. past_key_values=None,
  1024. use_cache=False,
  1025. output_attentions=False,
  1026. return_dict=True,
  1027. cache_position=None,
  1028. ):
  1029. self_attention_outputs = self.layer[0](
  1030. hidden_states,
  1031. attention_mask=attention_mask,
  1032. position_bias=position_bias,
  1033. layer_head_mask=layer_head_mask,
  1034. past_key_values=past_key_values,
  1035. use_cache=use_cache,
  1036. output_attentions=output_attentions,
  1037. cache_position=cache_position,
  1038. )
  1039. hidden_states = self_attention_outputs[0]
  1040. attention_outputs = self_attention_outputs[1:] # Keep self-attention outputs and relative position weights
  1041. # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
  1042. if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
  1043. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  1044. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  1045. do_cross_attention = self.is_decoder and encoder_hidden_states is not None
  1046. if do_cross_attention:
  1047. cross_attention_outputs = self.layer[1](
  1048. hidden_states,
  1049. key_value_states=encoder_hidden_states,
  1050. attention_mask=encoder_attention_mask,
  1051. position_bias=encoder_decoder_position_bias,
  1052. layer_head_mask=cross_attn_layer_head_mask,
  1053. past_key_values=past_key_values,
  1054. query_length=cache_position[-1] + 1,
  1055. use_cache=use_cache,
  1056. output_attentions=output_attentions,
  1057. cache_position=cache_position,
  1058. )
  1059. hidden_states = cross_attention_outputs[0]
  1060. # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
  1061. if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
  1062. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  1063. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  1064. # Keep cross-attention outputs and relative position weights
  1065. attention_outputs = attention_outputs + cross_attention_outputs[1:]
  1066. # Apply Feed Forward layer
  1067. hidden_states = self.layer[-1](hidden_states)
  1068. # clamp inf values to enable fp16 inference - check https://github.com/huggingface/transformers/pull/19229/
  1069. if hidden_states.dtype == torch.float16 and torch.isinf(hidden_states).any():
  1070. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  1071. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  1072. return (
  1073. (hidden_states,) + attention_outputs
  1074. ) # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
  1075. @auto_docstring
  1076. class LongT5PreTrainedModel(PreTrainedModel):
  1077. config: LongT5Config
  1078. base_model_prefix = "transformer"
  1079. supports_gradient_checkpointing = True
  1080. _no_split_modules = ["LongT5Block"]
  1081. _can_compile_fullgraph = False # TODO: @raushan more involved due to local/global attn
  1082. @property
  1083. # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel.dummy_inputs
  1084. def dummy_inputs(self):
  1085. input_ids = torch.tensor(DUMMY_INPUTS)
  1086. input_mask = torch.tensor(DUMMY_MASK)
  1087. dummy_inputs = {
  1088. "decoder_input_ids": input_ids,
  1089. "input_ids": input_ids,
  1090. "decoder_attention_mask": input_mask,
  1091. }
  1092. return dummy_inputs
  1093. def _try_load_missing_tied_module(self, key):
  1094. module = self
  1095. key = key.removesuffix(".weight")
  1096. for sub_key in key.split("."):
  1097. if not hasattr(module, sub_key):
  1098. return
  1099. module = getattr(module, sub_key)
  1100. self._tie_or_clone_weights(module, self.shared)
  1101. @classmethod
  1102. def from_pretrained(self, *args, **kwargs):
  1103. requested_loading_info = kwargs.get("output_loading_info", False)
  1104. kwargs["output_loading_info"] = True
  1105. model, loading_info = super().from_pretrained(*args, **kwargs)
  1106. missing_keys = loading_info.get("missing_keys", [])
  1107. if hasattr(model, "shared") and hasattr(model, "_tied_weights_keys"):
  1108. for missing_key in missing_keys:
  1109. logger.warning(
  1110. f"Recovering a missing tied weight {missing_key} from a legacy LongT5 checkpoint. "
  1111. f"Consider saving {missing_key} in your checkpoint or updating the config (tie_word_embeddings=true)."
  1112. )
  1113. model._try_load_missing_tied_module(missing_key)
  1114. if requested_loading_info:
  1115. return model, loading_info
  1116. return model
  1117. def _init_weights(self, module):
  1118. """Initialize the weights"""
  1119. factor = self.config.initializer_factor # Used for testing weights initialization
  1120. if isinstance(module, LongT5LayerNorm):
  1121. module.weight.data.fill_(factor * 1.0)
  1122. elif isinstance(module, (LongT5Model, LongT5ForConditionalGeneration, LongT5EncoderModel)):
  1123. # Mesh TensorFlow embeddings initialization
  1124. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
  1125. module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
  1126. if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
  1127. module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
  1128. elif isinstance(module, LongT5DenseActDense):
  1129. # Mesh TensorFlow FF initialization
  1130. # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
  1131. # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
  1132. module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  1133. if hasattr(module.wi, "bias") and module.wi.bias is not None:
  1134. module.wi.bias.data.zero_()
  1135. module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
  1136. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  1137. module.wo.bias.data.zero_()
  1138. elif isinstance(module, LongT5DenseGatedActDense):
  1139. module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  1140. if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
  1141. module.wi_0.bias.data.zero_()
  1142. module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  1143. if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
  1144. module.wi_1.bias.data.zero_()
  1145. module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
  1146. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  1147. module.wo.bias.data.zero_()
  1148. elif isinstance(module, (LongT5Attention, LongT5LocalAttention, LongT5TransientGlobalAttention)):
  1149. # Mesh TensorFlow attention initialization to avoid scaling before softmax
  1150. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
  1151. d_model = self.config.d_model
  1152. key_value_proj_dim = self.config.d_kv
  1153. n_heads = self.config.num_heads
  1154. module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
  1155. module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
  1156. module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
  1157. module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
  1158. if module.has_relative_attention_bias:
  1159. module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
  1160. if isinstance(module, LongT5TransientGlobalAttention):
  1161. module.global_relative_attention_bias.weight.data.normal_(
  1162. mean=0.0, std=factor * ((d_model) ** -0.5)
  1163. )
  1164. # Copied from transformers.models.t5.modeling_t5.T5PreTrainedModel._shift_right with T5->LongT5
  1165. def _shift_right(self, input_ids):
  1166. decoder_start_token_id = self.config.decoder_start_token_id
  1167. pad_token_id = self.config.pad_token_id
  1168. if decoder_start_token_id is None:
  1169. raise ValueError(
  1170. "self.model.config.decoder_start_token_id has to be defined. In LongT5 it is usually set to the pad_token_id. "
  1171. "See LongT5 docs for more information."
  1172. )
  1173. # shift inputs to the right
  1174. if is_torch_fx_proxy(input_ids):
  1175. # Item assignment is not supported natively for proxies.
  1176. shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
  1177. shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
  1178. else:
  1179. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  1180. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  1181. shifted_input_ids[..., 0] = decoder_start_token_id
  1182. if pad_token_id is None:
  1183. raise ValueError("self.model.config.pad_token_id has to be defined.")
  1184. # replace possible -100 values in labels by `pad_token_id`
  1185. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  1186. return shifted_input_ids
  1187. class LongT5Stack(LongT5PreTrainedModel):
  1188. def __init__(self, config, embed_tokens=None):
  1189. super().__init__(config)
  1190. self.embed_tokens = nn.Embedding(config.vocab_size, config.d_model)
  1191. if embed_tokens is not None:
  1192. self.embed_tokens.weight = embed_tokens.weight
  1193. self.is_decoder = config.is_decoder
  1194. self.local_radius = config.local_radius
  1195. self.block_len = self.local_radius + 1
  1196. self.block = nn.ModuleList(
  1197. [
  1198. LongT5Block(config, has_relative_attention_bias=bool(i == 0), layer_idx=i)
  1199. for i in range(config.num_layers)
  1200. ]
  1201. )
  1202. self.final_layer_norm = LongT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  1203. self.dropout = nn.Dropout(config.dropout_rate)
  1204. self.gradient_checkpointing = False
  1205. # Initialize weights and apply final processing
  1206. self.post_init()
    # Copied from transformers.models.t5.modeling_t5.T5Stack.set_input_embeddings
    def set_input_embeddings(self, new_embeddings):
        """Replace this stack's `embed_tokens` module with `new_embeddings`."""
        self.embed_tokens = new_embeddings
  1210. def forward(
  1211. self,
  1212. input_ids=None,
  1213. attention_mask=None,
  1214. encoder_hidden_states=None,
  1215. encoder_attention_mask=None,
  1216. inputs_embeds=None,
  1217. head_mask=None,
  1218. cross_attn_head_mask=None,
  1219. past_key_values=None,
  1220. use_cache=None,
  1221. output_attentions=None,
  1222. output_hidden_states=None,
  1223. return_dict=None,
  1224. cache_position=None,
  1225. ):
  1226. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1227. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  1228. output_hidden_states = (
  1229. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1230. )
  1231. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1232. if input_ids is not None and inputs_embeds is not None:
  1233. err_msg_prefix = "decoder_" if self.is_decoder else ""
  1234. raise ValueError(
  1235. f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
  1236. )
  1237. elif input_ids is not None:
  1238. input_shape = input_ids.size()
  1239. input_ids = input_ids.view(-1, input_shape[-1])
  1240. elif inputs_embeds is not None:
  1241. input_shape = inputs_embeds.size()[:-1]
  1242. else:
  1243. err_msg_prefix = "decoder_" if self.is_decoder else ""
  1244. raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")
  1245. if self.gradient_checkpointing and self.training:
  1246. if use_cache:
  1247. logger.warning_once(
  1248. "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
  1249. )
  1250. use_cache = False
  1251. if inputs_embeds is None:
  1252. assert self.embed_tokens is not None, "You have to initialize the model with valid token embeddings"
  1253. inputs_embeds = self.embed_tokens(input_ids)
  1254. batch_size, seq_length = input_shape
  1255. if self.is_decoder:
  1256. if use_cache and past_key_values is None:
  1257. if self.config.is_encoder_decoder:
  1258. past_key_values = EncoderDecoderCache(
  1259. DynamicCache(config=self.config), DynamicCache(config=self.config)
  1260. )
  1261. else:
  1262. past_key_values = DynamicCache(config=self.config)
  1263. elif not self.is_decoder:
  1264. # do not pass cache object down the line for encoder stack
  1265. # it messes indexing later in decoder-stack because cache object is modified in-place
  1266. past_key_values = None
  1267. past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
  1268. if cache_position is None:
  1269. cache_position = torch.arange(
  1270. past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
  1271. )
  1272. if attention_mask is None and not is_torchdynamo_compiling():
  1273. # required mask seq length can be calculated via length of past
  1274. mask_seq_length = past_key_values_length + seq_length
  1275. attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)
  1276. if self.is_decoder:
  1277. causal_mask = self._update_causal_mask(
  1278. attention_mask,
  1279. inputs_embeds,
  1280. cache_position,
  1281. past_key_values.self_attention_cache
  1282. if isinstance(past_key_values, EncoderDecoderCache)
  1283. else past_key_values,
  1284. output_attentions,
  1285. )
  1286. # We use local attention in encoder self-attention, otherwise standard self & cross attentions are used
  1287. elif self.config.encoder_attention_type == "local":
  1288. causal_mask = _get_local_attention_mask(attention_mask, self.block_len, inputs_embeds.device)
  1289. else: # we need to use both local attention mask and standard extended mask for transient-global attention
  1290. causal_mask = attention_mask
  1291. # If a 2D or 3D attention mask is provided for the cross-attention
  1292. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  1293. if self.is_decoder and encoder_hidden_states is not None:
  1294. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
  1295. encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
  1296. if encoder_attention_mask is None:
  1297. encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
  1298. encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
  1299. else:
  1300. encoder_extended_attention_mask = None
  1301. # Prepare head mask if needed
  1302. head_mask = self.get_head_mask(head_mask, self.config.num_layers)
  1303. cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
  1304. all_hidden_states = () if output_hidden_states else None
  1305. all_attentions = () if output_attentions else None
  1306. all_cross_attentions = () if (output_attentions and self.is_decoder) else None
  1307. position_bias = None
  1308. encoder_decoder_position_bias = None
  1309. hidden_states = self.dropout(inputs_embeds)
  1310. for i, layer_module in enumerate(self.block):
  1311. layer_head_mask = head_mask[i]
  1312. cross_attn_layer_head_mask = cross_attn_head_mask[i]
  1313. if output_hidden_states:
  1314. all_hidden_states = all_hidden_states + (hidden_states,)
  1315. layer_outputs = layer_module(
  1316. hidden_states,
  1317. causal_mask,
  1318. position_bias,
  1319. encoder_hidden_states,
  1320. encoder_extended_attention_mask,
  1321. encoder_decoder_position_bias, # as a positional argument for gradient checkpointing
  1322. layer_head_mask=layer_head_mask,
  1323. cross_attn_layer_head_mask=cross_attn_layer_head_mask,
  1324. past_key_values=past_key_values,
  1325. use_cache=use_cache,
  1326. output_attentions=output_attentions,
  1327. return_dict=return_dict,
  1328. cache_position=cache_position,
  1329. )
  1330. # layer_outputs is a tuple with:
  1331. # hidden-states, (self-attention position bias), (self-attention weights), (cross-attention position bias), (cross-attention weights)
  1332. hidden_states = layer_outputs[0]
  1333. # We share the position biases between the layers - the first layer store them
  1334. # layer_outputs = hidden-states, key-value-states (self-attention position bias), (self-attention weights),
  1335. # (cross-attention position bias), (cross-attention weights)
  1336. position_bias = layer_outputs[1]
  1337. if self.is_decoder and encoder_hidden_states is not None:
  1338. encoder_decoder_position_bias = layer_outputs[3 if output_attentions else 2]
  1339. if output_attentions:
  1340. all_attentions = all_attentions + (layer_outputs[2],)
  1341. if self.is_decoder:
  1342. all_cross_attentions = all_cross_attentions + (layer_outputs[4],)
  1343. hidden_states = self.final_layer_norm(hidden_states)
  1344. hidden_states = self.dropout(hidden_states)
  1345. # Add last layer
  1346. if output_hidden_states:
  1347. all_hidden_states = all_hidden_states + (hidden_states,)
  1348. if not return_dict:
  1349. return tuple(
  1350. v
  1351. for v in [
  1352. hidden_states,
  1353. past_key_values,
  1354. all_hidden_states,
  1355. all_attentions,
  1356. all_cross_attentions,
  1357. ]
  1358. if v is not None
  1359. )
  1360. return BaseModelOutputWithPastAndCrossAttentions(
  1361. last_hidden_state=hidden_states,
  1362. past_key_values=past_key_values,
  1363. hidden_states=all_hidden_states,
  1364. attentions=all_attentions,
  1365. cross_attentions=all_cross_attentions,
  1366. )
    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """Build the self-attention mask for the configured attention implementation.

        Depending on `config._attn_implementation` this returns the mask unchanged,
        `None`, a flex-attention block mask, or a 4D causal mask built by
        `_prepare_4d_causal_attention_mask_with_cache_position`.
        """
        # flash_attention_2: only pass the mask on when it actually masks something (contains a 0).
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        # flex_attention: convert a plain tensor mask to a block mask; anything else is passed through.
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        # Key/value length of the mask: the cache's maximum shape for a compilable cache, otherwise
        # the width of the given mask (or past + current + 1 when the mask is not a tensor).
        if using_compilable_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
  1429. @staticmethod
  1430. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  1431. def _prepare_4d_causal_attention_mask_with_cache_position(
  1432. attention_mask: torch.Tensor,
  1433. sequence_length: int,
  1434. target_length: int,
  1435. dtype: torch.dtype,
  1436. cache_position: torch.Tensor,
  1437. batch_size: int,
  1438. **kwargs,
  1439. ):
  1440. """
  1441. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  1442. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  1443. Args:
  1444. attention_mask (`torch.Tensor`):
  1445. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  1446. `(batch_size, 1, query_length, key_value_length)`.
  1447. sequence_length (`int`):
  1448. The sequence length being processed.
  1449. target_length (`int`):
  1450. The target length: when generating with static cache, the mask should be as long as the static cache,
  1451. to account for the 0 padding, the part of the cache that is not filled yet.
  1452. dtype (`torch.dtype`):
  1453. The dtype to use for the 4D attention mask.
  1454. cache_position (`torch.Tensor`):
  1455. Indices depicting the position of the input sequence tokens in the sequence.
  1456. batch_size (`torch.Tensor`):
  1457. Batch size.
  1458. """
  1459. if attention_mask is not None and attention_mask.dim() == 4:
  1460. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  1461. causal_mask = attention_mask
  1462. else:
  1463. min_dtype = torch.finfo(dtype).min
  1464. causal_mask = torch.full(
  1465. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  1466. )
  1467. if sequence_length != 1:
  1468. causal_mask = torch.triu(causal_mask, diagonal=1)
  1469. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  1470. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  1471. if attention_mask is not None:
  1472. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  1473. mask_length = attention_mask.shape[-1]
  1474. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  1475. causal_mask.device
  1476. )
  1477. padding_mask = padding_mask == 0
  1478. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  1479. padding_mask, min_dtype
  1480. )
  1481. return causal_mask
# Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
# NOTE(review): this name starts with two underscores and does not end with two, so a bare
# reference to it from inside a class body is name-mangled by Python (e.g. to
# `_LongT5Model__HEAD_MASK_WARNING_MSG`) and will not resolve to this module-level constant.
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and `decoder_head_mask`. Currently,
`decoder_head_mask` is set to copy `head_mask`, but this feature is deprecated and will be removed in future versions.
If you do not want to use any `decoder_head_mask` now, please set `decoder_head_mask = torch.ones(num_layers,
num_heads)`.
"""
  1489. @auto_docstring
  1490. class LongT5Model(LongT5PreTrainedModel):
  1491. _keys_to_ignore_on_load_unexpected = [
  1492. r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
  1493. ]
  1494. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1495. def __init__(self, config: LongT5Config):
  1496. super().__init__(config)
  1497. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1498. encoder_config = copy.deepcopy(config)
  1499. encoder_config.is_decoder = False
  1500. encoder_config.use_cache = False
  1501. encoder_config.tie_encoder_decoder = False
  1502. self.encoder = LongT5Stack(encoder_config, self.shared)
  1503. decoder_config = copy.deepcopy(config)
  1504. decoder_config.is_decoder = True
  1505. decoder_config.tie_encoder_decoder = False
  1506. decoder_config.num_layers = config.num_decoder_layers
  1507. self.decoder = LongT5Stack(decoder_config, self.shared)
  1508. # Initialize weights and apply final processing
  1509. self.post_init()
    def get_input_embeddings(self):
        """Return the `shared` embedding module of this model."""
        return self.shared
  1512. def set_input_embeddings(self, new_embeddings):
  1513. self.shared = new_embeddings
  1514. self.encoder.set_input_embeddings(new_embeddings)
  1515. self.decoder.set_input_embeddings(new_embeddings)
  1516. def _tie_weights(self):
  1517. if self.config.tie_word_embeddings:
  1518. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1519. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
    def get_encoder(self):
        """Return the encoder stack (`self.encoder`)."""
        return self.encoder
  1522. def _prune_heads(self, heads_to_prune):
  1523. """
  1524. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  1525. class PreTrainedModel
  1526. """
  1527. for layer, heads in heads_to_prune.items():
  1528. self.encoder.layer[layer].attention.prune_heads(heads)
  1529. @auto_docstring
  1530. def forward(
  1531. self,
  1532. input_ids: Optional[torch.LongTensor] = None,
  1533. attention_mask: Optional[torch.FloatTensor] = None,
  1534. decoder_input_ids: Optional[torch.LongTensor] = None,
  1535. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1536. head_mask: Optional[torch.FloatTensor] = None,
  1537. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1538. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1539. encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
  1540. past_key_values: Optional[Cache] = None,
  1541. inputs_embeds: Optional[torch.Tensor] = None,
  1542. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  1543. use_cache: Optional[bool] = None,
  1544. output_attentions: Optional[bool] = None,
  1545. output_hidden_states: Optional[bool] = None,
  1546. return_dict: Optional[bool] = None,
  1547. cache_position: Optional[torch.LongTensor] = None,
  1548. ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
  1549. r"""
  1550. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1551. Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
  1552. you should be able to pad the inputs on both the right and the left.
  1553. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1554. [`PreTrainedTokenizer.__call__`] for detail.
  1555. [What are input IDs?](../glossary#input-ids)
  1556. To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5
  1557. Training](./longt5#training).
  1558. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1559. Indices of decoder input sequence tokens in the vocabulary.
  1560. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1561. [`PreTrainedTokenizer.__call__`] for details.
  1562. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1563. LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  1564. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1565. `past_key_values`).
  1566. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [LONGT5
  1567. Training](./longt5#training).
  1568. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1569. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1570. be used by default.
  1571. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1572. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1573. 1]`:
  1574. - 1 indicates the head is **not masked**,
  1575. - 0 indicates the head is **masked**.
  1576. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1577. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1578. `[0, 1]`:
  1579. - 1 indicates the head is **not masked**,
  1580. - 0 indicates the head is **masked**.
  1581. Example:
  1582. ```python
  1583. >>> from transformers import AutoTokenizer, LongT5Model
  1584. >>> tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
  1585. >>> model = LongT5Model.from_pretrained("google/long-t5-local-base")
  1586. >>> # Let's try a very long encoder input.
  1587. >>> input_ids = tokenizer(
  1588. ... 100 * "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1589. ... ).input_ids # Batch size 1
  1590. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  1591. >>> # forward pass
  1592. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1593. >>> last_hidden_states = outputs.last_hidden_state
  1594. ```"""
  1595. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1596. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1597. # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  1598. if head_mask is not None and decoder_head_mask is None:
  1599. if self.config.num_layers == self.config.num_decoder_layers:
  1600. warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
  1601. decoder_head_mask = head_mask
  1602. # Encode if needed (training, first prediction pass)
  1603. if encoder_outputs is None:
  1604. encoder_outputs = self.encoder(
  1605. input_ids=input_ids,
  1606. attention_mask=attention_mask,
  1607. inputs_embeds=inputs_embeds,
  1608. head_mask=head_mask,
  1609. output_attentions=output_attentions,
  1610. output_hidden_states=output_hidden_states,
  1611. return_dict=return_dict,
  1612. )
  1613. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1614. encoder_outputs = BaseModelOutput(
  1615. last_hidden_state=encoder_outputs[0],
  1616. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1617. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1618. )
  1619. hidden_states = encoder_outputs[0]
  1620. # Decode
  1621. decoder_outputs = self.decoder(
  1622. input_ids=decoder_input_ids,
  1623. attention_mask=decoder_attention_mask,
  1624. inputs_embeds=decoder_inputs_embeds,
  1625. past_key_values=past_key_values,
  1626. encoder_hidden_states=hidden_states,
  1627. encoder_attention_mask=attention_mask,
  1628. head_mask=decoder_head_mask,
  1629. cross_attn_head_mask=cross_attn_head_mask,
  1630. use_cache=use_cache,
  1631. output_attentions=output_attentions,
  1632. output_hidden_states=output_hidden_states,
  1633. return_dict=return_dict,
  1634. cache_position=cache_position,
  1635. )
  1636. if not return_dict:
  1637. return decoder_outputs + encoder_outputs
  1638. return Seq2SeqModelOutput(
  1639. last_hidden_state=decoder_outputs.last_hidden_state,
  1640. past_key_values=decoder_outputs.past_key_values,
  1641. decoder_hidden_states=decoder_outputs.hidden_states,
  1642. decoder_attentions=decoder_outputs.attentions,
  1643. cross_attentions=decoder_outputs.cross_attentions,
  1644. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1645. encoder_hidden_states=encoder_outputs.hidden_states,
  1646. encoder_attentions=encoder_outputs.attentions,
  1647. )
  1648. @auto_docstring(
  1649. custom_intro="""
  1650. LONGT5 Model with a `language modeling` head on top.
  1651. """
  1652. )
  1653. class LongT5ForConditionalGeneration(LongT5PreTrainedModel, GenerationMixin):
  1654. _keys_to_ignore_on_load_unexpected = [
  1655. r"decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight",
  1656. ]
  1657. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
  1658. def __init__(self, config: LongT5Config):
  1659. super().__init__(config)
  1660. self.model_dim = config.d_model
  1661. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1662. encoder_config = copy.deepcopy(config)
  1663. encoder_config.is_decoder = False
  1664. encoder_config.use_cache = False
  1665. encoder_config.tie_encoder_decoder = False
  1666. self.encoder = LongT5Stack(encoder_config, self.shared)
  1667. decoder_config = copy.deepcopy(config)
  1668. decoder_config.is_decoder = True
  1669. decoder_config.tie_encoder_decoder = False
  1670. decoder_config.num_layers = config.num_decoder_layers
  1671. self.decoder = LongT5Stack(decoder_config, self.shared)
  1672. self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
  1673. # Initialize weights and apply final processing
  1674. self.post_init()
  1675. def get_input_embeddings(self):
  1676. return self.shared
  1677. def set_input_embeddings(self, new_embeddings):
  1678. self.shared = new_embeddings
  1679. self.encoder.set_input_embeddings(new_embeddings)
  1680. self.decoder.set_input_embeddings(new_embeddings)
  1681. def _tie_weights(self):
  1682. if self.config.tie_word_embeddings:
  1683. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1684. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  1685. def get_encoder(self):
  1686. return self.encoder
  1687. @auto_docstring
  1688. def forward(
  1689. self,
  1690. input_ids: Optional[torch.LongTensor] = None,
  1691. attention_mask: Optional[torch.FloatTensor] = None,
  1692. decoder_input_ids: Optional[torch.LongTensor] = None,
  1693. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1694. head_mask: Optional[torch.FloatTensor] = None,
  1695. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1696. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1697. encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
  1698. past_key_values: Optional[Cache] = None,
  1699. inputs_embeds: Optional[torch.FloatTensor] = None,
  1700. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1701. labels: Optional[torch.LongTensor] = None,
  1702. use_cache: Optional[bool] = None,
  1703. output_attentions: Optional[bool] = None,
  1704. output_hidden_states: Optional[bool] = None,
  1705. return_dict: Optional[bool] = None,
  1706. cache_position: Optional[torch.LongTensor] = None,
  1707. ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
  1708. r"""
  1709. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1710. Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
  1711. you should be able to pad the inputs on both the right and the left.
  1712. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1713. [`PreTrainedTokenizer.__call__`] for detail.
  1714. [What are input IDs?](../glossary#input-ids)
  1715. To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5
  1716. Training](./longt5#training).
  1717. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1718. Indices of decoder input sequence tokens in the vocabulary.
  1719. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1720. [`PreTrainedTokenizer.__call__`] for details.
  1721. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1722. LONGT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If
  1723. `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
  1724. `past_key_values`).
  1725. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [LONGT5
  1726. Training](./longt5#training).
  1727. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1728. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1729. be used by default.
  1730. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1731. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1732. 1]`:
  1733. - 1 indicates the head is **not masked**,
  1734. - 0 indicates the head is **masked**.
  1735. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1736. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1737. `[0, 1]`:
  1738. - 1 indicates the head is **not masked**,
  1739. - 0 indicates the head is **masked**.
  1740. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1741. Labels for computing the sequence classification/regression loss. Indices should be in `[-100, 0, ...,
  1742. config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
  1743. labels in `[0, ..., config.vocab_size]`
  1744. Examples:
  1745. ```python
  1746. >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration
  1747. >>> tokenizer = AutoTokenizer.from_pretrained("Stancld/longt5-tglobal-large-16384-pubmed-3k_steps")
  1748. >>> model = LongT5ForConditionalGeneration.from_pretrained(
  1749. ... "Stancld/longt5-tglobal-large-16384-pubmed-3k_steps"
  1750. ... )
  1751. >>> # Let's try a very long input.
  1752. >>> inputs = tokenizer(100 * "studies have shown that owning a dog is good for you ", return_tensors="pt")
  1753. >>> input_ids = inputs.input_ids
  1754. >>> outputs = model.generate(input_ids)
  1755. >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  1756. abstractthe aim of this article is to provide an overview of the literature on the role of dog
  1757. ```"""
  1758. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1759. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1760. # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  1761. if head_mask is not None and decoder_head_mask is None:
  1762. if self.config.num_layers == self.config.num_decoder_layers:
  1763. warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
  1764. decoder_head_mask = head_mask
  1765. # Encode if needed (training, first prediction pass)
  1766. if encoder_outputs is None:
  1767. # Convert encoder inputs in embeddings if needed
  1768. encoder_outputs = self.encoder(
  1769. input_ids=input_ids,
  1770. attention_mask=attention_mask,
  1771. inputs_embeds=inputs_embeds,
  1772. head_mask=head_mask,
  1773. output_attentions=output_attentions,
  1774. output_hidden_states=output_hidden_states,
  1775. return_dict=return_dict,
  1776. )
  1777. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1778. encoder_outputs = BaseModelOutput(
  1779. last_hidden_state=encoder_outputs[0],
  1780. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1781. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1782. )
  1783. hidden_states = encoder_outputs[0]
  1784. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  1785. # get decoder inputs from shifting lm labels to the right
  1786. decoder_input_ids = self._shift_right(labels)
  1787. # Decode
  1788. decoder_outputs = self.decoder(
  1789. input_ids=decoder_input_ids,
  1790. attention_mask=decoder_attention_mask,
  1791. inputs_embeds=decoder_inputs_embeds,
  1792. past_key_values=past_key_values,
  1793. encoder_hidden_states=hidden_states,
  1794. encoder_attention_mask=attention_mask,
  1795. head_mask=decoder_head_mask,
  1796. cross_attn_head_mask=cross_attn_head_mask,
  1797. use_cache=use_cache,
  1798. output_attentions=output_attentions,
  1799. output_hidden_states=output_hidden_states,
  1800. return_dict=return_dict,
  1801. cache_position=cache_position,
  1802. )
  1803. sequence_output = decoder_outputs[0]
  1804. if self.config.tie_word_embeddings:
  1805. # Rescale output before projecting on vocab
  1806. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
  1807. sequence_output = sequence_output * (self.model_dim**-0.5)
  1808. lm_logits = self.lm_head(sequence_output)
  1809. loss = None
  1810. if labels is not None:
  1811. loss_fct = CrossEntropyLoss(ignore_index=-100)
  1812. labels = labels.to(lm_logits.device)
  1813. loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
  1814. # TODO(thom): Add z_loss https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
  1815. if not return_dict:
  1816. output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
  1817. return ((loss,) + output) if loss is not None else output
  1818. return Seq2SeqLMOutput(
  1819. loss=loss,
  1820. logits=lm_logits,
  1821. past_key_values=decoder_outputs.past_key_values,
  1822. decoder_hidden_states=decoder_outputs.hidden_states,
  1823. decoder_attentions=decoder_outputs.attentions,
  1824. cross_attentions=decoder_outputs.cross_attentions,
  1825. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1826. encoder_hidden_states=encoder_outputs.hidden_states,
  1827. encoder_attentions=encoder_outputs.attentions,
  1828. )
  1829. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1830. return self._shift_right(labels)
  1831. @auto_docstring
  1832. class LongT5EncoderModel(LongT5PreTrainedModel):
  1833. _tied_weights_keys = ["encoder.embed_tokens.weight"]
  1834. _keys_to_ignore_on_load_unexpected = [r"decoder"]
  1835. def __init__(self, config: LongT5Config):
  1836. super().__init__(config)
  1837. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1838. encoder_config = copy.deepcopy(config)
  1839. encoder_config.use_cache = False
  1840. encoder_config.tie_encoder_decoder = False
  1841. self.encoder = LongT5Stack(encoder_config, self.shared)
  1842. # Initialize weights and apply final processing
  1843. self.post_init()
  1844. def get_input_embeddings(self):
  1845. return self.shared
  1846. def set_input_embeddings(self, new_embeddings):
  1847. self.shared = new_embeddings
  1848. self.encoder.set_input_embeddings(new_embeddings)
  1849. def _tie_weights(self):
  1850. if self.config.tie_word_embeddings:
  1851. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1852. def get_encoder(self):
  1853. return self.encoder
  1854. def _prune_heads(self, heads_to_prune):
  1855. """
  1856. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  1857. class PreTrainedModel
  1858. """
  1859. for layer, heads in heads_to_prune.items():
  1860. self.encoder.layer[layer].attention.prune_heads(heads)
  1861. @auto_docstring
  1862. def forward(
  1863. self,
  1864. input_ids: Optional[torch.LongTensor] = None,
  1865. attention_mask: Optional[torch.FloatTensor] = None,
  1866. head_mask: Optional[torch.FloatTensor] = None,
  1867. inputs_embeds: Optional[torch.FloatTensor] = None,
  1868. output_attentions: Optional[bool] = None,
  1869. output_hidden_states: Optional[bool] = None,
  1870. return_dict: Optional[bool] = None,
  1871. ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
  1872. r"""
  1873. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1874. Indices of input sequence tokens in the vocabulary. LongT5 is a model with relative position embeddings so
  1875. you should be able to pad the inputs on both the right and the left.
  1876. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1877. [`PreTrainedTokenizer.__call__`] for detail.
  1878. To know more on how to prepare `input_ids` for pretraining take a look a [LONGT5
  1879. Training](./longt5#training).
  1880. Example:
  1881. ```python
  1882. >>> from transformers import AutoTokenizer, LongT5ForConditionalGeneration
  1883. >>> tokenizer = AutoTokenizer.from_pretrained("google/long-t5-local-base")
  1884. >>> model = LongT5EncoderModel.from_pretrained("google/long-t5-local-base")
  1885. >>> input_ids = tokenizer(
  1886. ... 100 * "Studies have been shown that owning a dog is good for you ", return_tensors="pt"
  1887. ... ).input_ids # Batch size 1
  1888. >>> outputs = model(input_ids=input_ids)
  1889. >>> last_hidden_states = outputs.last_hidden_state
  1890. ```"""
  1891. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1892. encoder_outputs = self.encoder(
  1893. input_ids=input_ids,
  1894. attention_mask=attention_mask,
  1895. inputs_embeds=inputs_embeds,
  1896. head_mask=head_mask,
  1897. output_attentions=output_attentions,
  1898. output_hidden_states=output_hidden_states,
  1899. return_dict=return_dict,
  1900. )
  1901. return encoder_outputs
  1902. __all__ = ["LongT5EncoderModel", "LongT5ForConditionalGeneration", "LongT5Model", "LongT5PreTrainedModel"]