modeling_umt5.py 89 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950
# coding=utf-8
# Copyright 2023 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""PyTorch UMT5 model."""
import copy
import math
from typing import Optional, Union

import torch
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from ...activations import ACT2FN
from ...cache_utils import Cache, DynamicCache, EncoderDecoderCache
from ...generation import GenerationMixin
from ...modeling_attn_mask_utils import AttentionMaskConverter
from ...modeling_layers import GradientCheckpointingLayer
from ...modeling_outputs import (
    BaseModelOutput,
    BaseModelOutputWithPastAndCrossAttentions,
    Seq2SeqLMOutput,
    Seq2SeqModelOutput,
    Seq2SeqQuestionAnsweringModelOutput,
    Seq2SeqSequenceClassifierOutput,
    TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...utils import (
    DUMMY_INPUTS,
    DUMMY_MASK,
    auto_docstring,
    is_torch_flex_attn_available,
    is_torch_fx_proxy,
    is_torchdynamo_compiling,
    logging,
)
from ...utils.deprecation import deprecate_kwarg
from .configuration_umt5 import UMT5Config
# `BlockMask` and `make_flex_block_causal_mask` are only imported when
# `is_torch_flex_attn_available()` returns a truthy value.
if is_torch_flex_attn_available():
    from torch.nn.attention.flex_attention import BlockMask

    from ...integrations.flex_attention import make_flex_block_causal_mask

# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
  52. # Copied from transformers.models.t5.modeling_t5.T5LayerNorm with T5->UMT5
  53. class UMT5LayerNorm(nn.Module):
  54. def __init__(self, hidden_size, eps=1e-6):
  55. """
  56. Construct a layernorm module in the UMT5 style. No bias and no subtraction of mean.
  57. """
  58. super().__init__()
  59. self.weight = nn.Parameter(torch.ones(hidden_size))
  60. self.variance_epsilon = eps
  61. def forward(self, hidden_states):
  62. # UMT5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
  63. # Square Layer Normalization https://huggingface.co/papers/1910.07467 thus variance is calculated
  64. # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
  65. # half-precision inputs is done in fp32
  66. variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
  67. hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
  68. # convert into half-precision if necessary
  69. if self.weight.dtype in [torch.float16, torch.bfloat16]:
  70. hidden_states = hidden_states.to(self.weight.dtype)
  71. return self.weight * hidden_states
  72. # Copied from transformers.models.t5.modeling_t5.T5DenseActDense with T5->UMT5
  73. class UMT5DenseActDense(nn.Module):
  74. def __init__(self, config: UMT5Config):
  75. super().__init__()
  76. self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
  77. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  78. self.dropout = nn.Dropout(config.dropout_rate)
  79. self.act = ACT2FN[config.dense_act_fn]
  80. def forward(self, hidden_states):
  81. hidden_states = self.wi(hidden_states)
  82. hidden_states = self.act(hidden_states)
  83. hidden_states = self.dropout(hidden_states)
  84. if (
  85. isinstance(self.wo.weight, torch.Tensor)
  86. and hidden_states.dtype != self.wo.weight.dtype
  87. and self.wo.weight.dtype != torch.int8
  88. ):
  89. hidden_states = hidden_states.to(self.wo.weight.dtype)
  90. hidden_states = self.wo(hidden_states)
  91. return hidden_states
  92. # Copied from transformers.models.t5.modeling_t5.T5DenseGatedActDense with T5->UMT5
  93. class UMT5DenseGatedActDense(nn.Module):
  94. def __init__(self, config: UMT5Config):
  95. super().__init__()
  96. self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
  97. self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
  98. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  99. self.dropout = nn.Dropout(config.dropout_rate)
  100. self.act = ACT2FN[config.dense_act_fn]
  101. def forward(self, hidden_states):
  102. hidden_gelu = self.act(self.wi_0(hidden_states))
  103. hidden_linear = self.wi_1(hidden_states)
  104. hidden_states = hidden_gelu * hidden_linear
  105. hidden_states = self.dropout(hidden_states)
  106. # To make 8bit quantization work for google/flan-t5-xxl, self.wo is kept in float32.
  107. # See https://github.com/huggingface/transformers/issues/20287
  108. # we also make sure the weights are not in `int8` in case users will force `_keep_in_fp32_modules` to be `None``
  109. if (
  110. isinstance(self.wo.weight, torch.Tensor)
  111. and hidden_states.dtype != self.wo.weight.dtype
  112. and self.wo.weight.dtype != torch.int8
  113. ):
  114. hidden_states = hidden_states.to(self.wo.weight.dtype)
  115. hidden_states = self.wo(hidden_states)
  116. return hidden_states
  117. # Copied from transformers.models.t5.modeling_t5.T5LayerFF with T5->UMT5
  118. class UMT5LayerFF(nn.Module):
  119. def __init__(self, config: UMT5Config):
  120. super().__init__()
  121. if config.is_gated_act:
  122. self.DenseReluDense = UMT5DenseGatedActDense(config)
  123. else:
  124. self.DenseReluDense = UMT5DenseActDense(config)
  125. self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  126. self.dropout = nn.Dropout(config.dropout_rate)
  127. def forward(self, hidden_states):
  128. forwarded_states = self.layer_norm(hidden_states)
  129. forwarded_states = self.DenseReluDense(forwarded_states)
  130. hidden_states = hidden_states + self.dropout(forwarded_states)
  131. return hidden_states
  132. class UMT5Attention(nn.Module):
  133. """
  134. T5's attention using relative_attention_bias.
  135. """
  136. def __init__(self, config, has_relative_attention_bias=False, layer_idx: Optional[int] = None):
  137. super().__init__()
  138. self.is_decoder = config.is_decoder
  139. self.has_relative_attention_bias = has_relative_attention_bias
  140. self.relative_attention_num_buckets = config.relative_attention_num_buckets
  141. self.relative_attention_max_distance = config.relative_attention_max_distance
  142. self.d_model = config.d_model
  143. self.key_value_proj_dim = config.d_kv
  144. self.n_heads = config.num_heads
  145. self.dropout = config.dropout_rate
  146. self.inner_dim = self.n_heads * self.key_value_proj_dim
  147. self.layer_idx = layer_idx
  148. if layer_idx is None and self.is_decoder:
  149. logger.warning_once(
  150. f"Instantiating a decoder {self.__class__.__name__} without passing `layer_idx` is not recommended and "
  151. "will to errors during the forward call, if caching is used. Please make sure to provide a `layer_idx` "
  152. "when creating this class."
  153. )
  154. # Mesh TensorFlow initialization to avoid scaling before softmax
  155. self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
  156. self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
  157. self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
  158. self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
  159. if self.has_relative_attention_bias:
  160. self.relative_attention_bias = nn.Embedding(self.relative_attention_num_buckets, self.n_heads)
  161. self.pruned_heads = set()
  162. def _shape(self, projection: torch.Tensor) -> torch.Tensor:
  163. new_projection_shape = projection.size()[:-1] + (self.n_heads, self.key_value_proj_dim)
  164. # move heads to 2nd position (B, T, H * D) -> (B, T, H, D) -> (B, H, T, D)
  165. new_projection = projection.view(new_projection_shape).permute(0, 2, 1, 3)
  166. return new_projection
  167. def _relative_position_bucket(self, relative_position):
  168. """
  169. Adapted from Mesh Tensorflow:
  170. https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593
  171. Translate relative position to a bucket number for relative attention. The relative position is defined as
  172. memory_position - query_position, i.e. the distance in tokens from the attending position to the attended-to
  173. position. If bidirectional=False, then positive relative positions are invalid. We use smaller buckets for
  174. small absolute relative_position and larger buckets for larger absolute relative_positions. All relative
  175. positions >=max_distance map to the same bucket. All relative positions <=-max_distance map to the same bucket.
  176. This should allow for more graceful generalization to longer sequences than the model has been trained on
  177. Args:
  178. relative_position: an int32 Tensor
  179. bidirectional: a boolean - whether the attention is bidirectional
  180. num_buckets: an integer
  181. max_distance: an integer
  182. Returns:
  183. a Tensor with the same shape as relative_position, containing int32 values in the range [0, num_buckets)
  184. """
  185. relative_buckets = 0
  186. num_buckets = self.relative_attention_num_buckets
  187. max_distance = self.relative_attention_max_distance
  188. if not self.is_decoder:
  189. num_buckets //= 2
  190. relative_buckets += (relative_position > 0).to(torch.long) * num_buckets
  191. relative_position = torch.abs(relative_position)
  192. else:
  193. relative_position = -torch.min(relative_position, torch.zeros_like(relative_position))
  194. # now relative_position is in the range [0, inf)
  195. # half of the buckets are for exact increments in positions
  196. max_exact = num_buckets // 2
  197. is_small = relative_position < max_exact
  198. # The other half of the buckets are for logarithmically bigger bins in positions up to max_distance
  199. log_ratio = torch.log(relative_position.float() / max_exact) / math.log(max_distance / max_exact)
  200. log_ratio = log_ratio * (num_buckets - max_exact)
  201. relative_position_if_large = max_exact + log_ratio.to(torch.long)
  202. relative_position_if_large = torch.min(
  203. relative_position_if_large, torch.full_like(relative_position_if_large, num_buckets - 1)
  204. )
  205. relative_buckets += torch.where(is_small, relative_position, relative_position_if_large)
  206. return relative_buckets
  207. def compute_bias(self, query_length, key_length, device=None, cache_position=None):
  208. """Compute binned relative position bias"""
  209. if device is None:
  210. device = self.relative_attention_bias.weight.device
  211. if cache_position is None:
  212. context_position = torch.arange(query_length, dtype=torch.long, device=device)[:, None]
  213. else:
  214. context_position = cache_position[:, None]
  215. memory_position = torch.arange(key_length, dtype=torch.long, device=device)[None, :]
  216. relative_position = memory_position - context_position # shape (query_length, key_length)
  217. relative_position_bucket = self._relative_position_bucket(relative_position)
  218. values = self.relative_attention_bias(relative_position_bucket) # shape (query_length, key_length, num_heads)
  219. values = values.permute([2, 0, 1]).unsqueeze(0) # shape (1, num_heads, query_length, key_length)
  220. return values
  221. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  222. def forward(
  223. self,
  224. hidden_states: torch.Tensor,
  225. encoder_hidden_states: Optional[torch.Tensor] = None,
  226. past_key_values: Optional[Cache] = None,
  227. attention_mask: Optional[torch.Tensor] = None,
  228. layer_head_mask: Optional[torch.Tensor] = None,
  229. cache_position: Optional[torch.Tensor] = None,
  230. ):
  231. batch_size, seq_length = hidden_states.shape[:2]
  232. # if encoder_hidden_states are provided this layer is used as a cross-attention layer for the decoder
  233. is_cross_attention = encoder_hidden_states is not None
  234. query_states = self.q(hidden_states)
  235. query_states = query_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  236. # Check is encoder-decoder model is being used. Otherwise we'll get `DynamicCache`
  237. is_updated = False
  238. if past_key_values is not None and isinstance(past_key_values, EncoderDecoderCache):
  239. is_updated = past_key_values.is_updated.get(self.layer_idx)
  240. if is_cross_attention:
  241. # after the first generated id, we can subsequently re-use all key/value_states from cache
  242. curr_past_key_value = past_key_values.cross_attention_cache
  243. else:
  244. curr_past_key_value = past_key_values.self_attention_cache
  245. else:
  246. curr_past_key_value = past_key_values
  247. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  248. if is_cross_attention and past_key_values is not None and is_updated:
  249. # reuse k,v, cross_attentions
  250. key_states = curr_past_key_value.layers[self.layer_idx].keys
  251. value_states = curr_past_key_value.layers[self.layer_idx].values
  252. else:
  253. key_states = self.k(current_states)
  254. value_states = self.v(current_states)
  255. key_states = key_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  256. value_states = value_states.view(batch_size, -1, self.n_heads, self.key_value_proj_dim).transpose(1, 2)
  257. if past_key_values is not None:
  258. # save all key/value_states to cache to be re-used for fast auto-regressive generation
  259. cache_position = cache_position if not is_cross_attention else None
  260. key_states, value_states = curr_past_key_value.update(
  261. key_states, value_states, self.layer_idx, {"cache_position": cache_position}
  262. )
  263. # set flag that curr layer for cross-attn is already updated so we can re-use in subsequent calls
  264. if is_cross_attention and isinstance(past_key_values, EncoderDecoderCache):
  265. past_key_values.is_updated[self.layer_idx] = True
  266. # compute scores, equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
  267. scores = torch.matmul(query_states, key_states.transpose(3, 2))
  268. # cache position is 0-indexed so we add 1 to get the real length of queries (aka with past)
  269. real_seq_length = seq_length + past_key_values.get_seq_length() if past_key_values is not None else seq_length
  270. key_length = key_states.shape[-2]
  271. if not self.has_relative_attention_bias:
  272. position_bias = torch.zeros(
  273. (1, self.n_heads, seq_length, key_length), device=scores.device, dtype=scores.dtype
  274. )
  275. else:
  276. position_bias = self.compute_bias(
  277. real_seq_length, key_length, device=scores.device, cache_position=cache_position
  278. )
  279. position_bias = position_bias[:, :, -seq_length:, :]
  280. if attention_mask is not None:
  281. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  282. position_bias = position_bias + causal_mask
  283. if self.pruned_heads:
  284. mask = torch.ones(position_bias.shape[1])
  285. mask[list(self.pruned_heads)] = 0
  286. position_bias_masked = position_bias[:, mask.bool()]
  287. else:
  288. position_bias_masked = position_bias
  289. scores += position_bias_masked
  290. # (batch_size, n_heads, seq_length, key_length)
  291. attn_weights = nn.functional.softmax(scores.float(), dim=-1).type_as(scores)
  292. attn_weights = nn.functional.dropout(attn_weights, p=self.dropout, training=self.training)
  293. # Mask heads if we want to
  294. if layer_head_mask is not None:
  295. attn_weights = attn_weights * layer_head_mask
  296. attn_output = torch.matmul(attn_weights, value_states)
  297. attn_output = attn_output.transpose(1, 2).contiguous()
  298. attn_output = attn_output.view(batch_size, seq_length, -1)
  299. attn_output = self.o(attn_output)
  300. return attn_output, attn_weights
  301. class UMT5LayerSelfAttention(nn.Module):
  302. def __init__(self, config, layer_idx: Optional[int] = None):
  303. super().__init__()
  304. self.SelfAttention = UMT5Attention(config, has_relative_attention_bias=True, layer_idx=layer_idx)
  305. self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  306. self.dropout = nn.Dropout(config.dropout_rate)
  307. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  308. def forward(
  309. self,
  310. hidden_states,
  311. attention_mask=None,
  312. layer_head_mask=None,
  313. past_key_values=None,
  314. cache_position=None,
  315. ):
  316. normed_hidden_states = self.layer_norm(hidden_states)
  317. attention_output = self.SelfAttention(
  318. normed_hidden_states,
  319. attention_mask=attention_mask,
  320. layer_head_mask=layer_head_mask,
  321. past_key_values=past_key_values,
  322. cache_position=cache_position,
  323. )
  324. hidden_states = hidden_states + self.dropout(attention_output[0])
  325. outputs = (hidden_states,) + attention_output[1:] # add attentions if we output them
  326. return outputs
  327. class UMT5LayerCrossAttention(nn.Module):
  328. def __init__(self, config, layer_idx: Optional[int] = None):
  329. super().__init__()
  330. self.EncDecAttention = UMT5Attention(config, has_relative_attention_bias=False, layer_idx=layer_idx)
  331. self.layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  332. self.dropout = nn.Dropout(config.dropout_rate)
  333. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  334. def forward(
  335. self,
  336. hidden_states,
  337. encoder_hidden_states=None,
  338. attention_mask=None,
  339. layer_head_mask=None,
  340. past_key_values=None,
  341. cache_position=None,
  342. ):
  343. normed_hidden_states = self.layer_norm(hidden_states)
  344. attention_output = self.EncDecAttention(
  345. normed_hidden_states,
  346. encoder_hidden_states=encoder_hidden_states,
  347. attention_mask=attention_mask,
  348. layer_head_mask=layer_head_mask,
  349. past_key_values=past_key_values,
  350. cache_position=cache_position,
  351. )
  352. layer_output = hidden_states + self.dropout(attention_output[0])
  353. outputs = (layer_output,) + attention_output[1:] # add attentions if we output them
  354. return outputs
  355. class UMT5Block(GradientCheckpointingLayer):
  356. def __init__(self, config, layer_idx: Optional[int] = None):
  357. super().__init__()
  358. self.is_decoder = config.is_decoder
  359. self.layer = nn.ModuleList()
  360. self.layer.append(UMT5LayerSelfAttention(config, layer_idx=layer_idx))
  361. if self.is_decoder:
  362. self.layer.append(UMT5LayerCrossAttention(config, layer_idx=layer_idx))
  363. self.layer.append(UMT5LayerFF(config))
  364. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  365. def forward(
  366. self,
  367. hidden_states,
  368. attention_mask=None,
  369. encoder_hidden_states=None,
  370. encoder_attention_mask=None,
  371. layer_head_mask=None,
  372. cross_attn_layer_head_mask=None,
  373. past_key_values=None,
  374. use_cache=False,
  375. output_attentions=False,
  376. cache_position=None,
  377. ):
  378. hidden_states, self_attn_weights = self.layer[0](
  379. hidden_states,
  380. attention_mask=attention_mask,
  381. layer_head_mask=layer_head_mask,
  382. past_key_values=past_key_values,
  383. cache_position=cache_position,
  384. )
  385. # clamp inf values to enable fp16 training
  386. if hidden_states.dtype == torch.float16:
  387. max_dtype = torch.finfo(hidden_states.dtype).max
  388. clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype)
  389. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  390. # Cross-Attention Block
  391. cross_attn_weights = None
  392. do_cross_attention = self.is_decoder and encoder_hidden_states is not None
  393. if do_cross_attention:
  394. hidden_states, cross_attn_weights = self.layer[1](
  395. hidden_states,
  396. encoder_hidden_states=encoder_hidden_states,
  397. attention_mask=encoder_attention_mask,
  398. layer_head_mask=cross_attn_layer_head_mask,
  399. past_key_values=past_key_values,
  400. cache_position=cache_position,
  401. )
  402. # clamp inf values to enable fp16 training
  403. if hidden_states.dtype == torch.float16:
  404. max_dtype = torch.finfo(hidden_states.dtype).max
  405. clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype)
  406. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  407. # Apply Feed Forward layer
  408. hidden_states = self.layer[-1](hidden_states)
  409. # clamp inf values to enable fp16 training
  410. if hidden_states.dtype == torch.float16:
  411. max_dtype = torch.finfo(hidden_states.dtype).max
  412. clamp_value = torch.where(torch.isinf(hidden_states).any(), max_dtype - 1000, max_dtype)
  413. hidden_states = torch.clamp(hidden_states, min=-clamp_value, max=clamp_value)
  414. outputs = (hidden_states,)
  415. if output_attentions:
  416. outputs += (self_attn_weights, cross_attn_weights)
  417. return outputs
  418. # Copied from transformers.models.t5.modeling_t5.T5ClassificationHead with T5->UMT5
  419. class UMT5ClassificationHead(nn.Module):
  420. """Head for sentence-level classification tasks."""
  421. def __init__(self, config: UMT5Config):
  422. super().__init__()
  423. self.dense = nn.Linear(config.d_model, config.d_model)
  424. self.dropout = nn.Dropout(p=config.classifier_dropout)
  425. self.out_proj = nn.Linear(config.d_model, config.num_labels)
  426. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  427. hidden_states = self.dropout(hidden_states)
  428. hidden_states = self.dense(hidden_states)
  429. hidden_states = torch.tanh(hidden_states)
  430. hidden_states = self.dropout(hidden_states)
  431. hidden_states = self.out_proj(hidden_states)
  432. return hidden_states
  433. @auto_docstring
  434. class UMT5PreTrainedModel(PreTrainedModel):
  435. config: UMT5Config
  436. base_model_prefix = "transformer"
  437. supports_gradient_checkpointing = True
  438. _can_compile_fullgraph = True
  439. _no_split_modules = ["UMT5Block"]
  440. _keep_in_fp32_modules = ["wo"]
  441. @property
  442. def dummy_inputs(self):
  443. input_ids = torch.tensor(DUMMY_INPUTS)
  444. input_mask = torch.tensor(DUMMY_MASK)
  445. dummy_inputs = {
  446. "decoder_input_ids": input_ids,
  447. "input_ids": input_ids,
  448. "decoder_attention_mask": input_mask,
  449. }
  450. return dummy_inputs
  451. def _init_weights(self, module):
  452. """Initialize the weights"""
  453. factor = self.config.initializer_factor # Used for testing weights initialization
  454. if isinstance(module, UMT5LayerNorm):
  455. module.weight.data.fill_(factor * 1.0)
  456. elif isinstance(
  457. module,
  458. (
  459. UMT5Model,
  460. UMT5ForConditionalGeneration,
  461. UMT5EncoderModel,
  462. UMT5ForQuestionAnswering,
  463. ),
  464. ):
  465. # Mesh TensorFlow embeddings initialization
  466. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
  467. module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
  468. if hasattr(module, "lm_head") and not self.config.tie_word_embeddings:
  469. module.lm_head.weight.data.normal_(mean=0.0, std=factor * 1.0)
  470. if hasattr(module, "qa_outputs"):
  471. module.qa_outputs.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  472. module.qa_outputs.bias.data.zero_()
  473. elif isinstance(module, UMT5ForTokenClassification):
  474. if hasattr(module, "classifier"):
  475. module.classifier.weight.data.normal_(mean=0.0, std=factor * 1.0)
  476. module.classifier.bias.data.zero_()
  477. elif isinstance(module, UMT5ClassificationHead):
  478. module.dense.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  479. if hasattr(module.dense, "bias") and module.dense.bias is not None:
  480. module.dense.bias.data.zero_()
  481. module.out_proj.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  482. if hasattr(module.out_proj, "bias") and module.out_proj.bias is not None:
  483. module.out_proj.bias.data.zero_()
  484. elif isinstance(module, UMT5DenseActDense):
  485. # Mesh TensorFlow FF initialization
  486. # See https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
  487. # and https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
  488. module.wi.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  489. if hasattr(module.wi, "bias") and module.wi.bias is not None:
  490. module.wi.bias.data.zero_()
  491. module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
  492. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  493. module.wo.bias.data.zero_()
  494. elif isinstance(module, UMT5DenseGatedActDense):
  495. module.wi_0.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  496. if hasattr(module.wi_0, "bias") and module.wi_0.bias is not None:
  497. module.wi_0.bias.data.zero_()
  498. module.wi_1.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_model) ** -0.5))
  499. if hasattr(module.wi_1, "bias") and module.wi_1.bias is not None:
  500. module.wi_1.bias.data.zero_()
  501. module.wo.weight.data.normal_(mean=0.0, std=factor * ((self.config.d_ff) ** -0.5))
  502. if hasattr(module.wo, "bias") and module.wo.bias is not None:
  503. module.wo.bias.data.zero_()
  504. elif isinstance(module, UMT5Attention):
  505. # Mesh TensorFlow attention initialization to avoid scaling before softmax
  506. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
  507. d_model = self.config.d_model
  508. key_value_proj_dim = self.config.d_kv
  509. n_heads = self.config.num_heads
  510. module.q.weight.data.normal_(mean=0.0, std=factor * ((d_model * key_value_proj_dim) ** -0.5))
  511. module.k.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
  512. module.v.weight.data.normal_(mean=0.0, std=factor * (d_model**-0.5))
  513. module.o.weight.data.normal_(mean=0.0, std=factor * ((n_heads * key_value_proj_dim) ** -0.5))
  514. if module.has_relative_attention_bias:
  515. module.relative_attention_bias.weight.data.normal_(mean=0.0, std=factor * ((d_model) ** -0.5))
  516. def _shift_right(self, input_ids):
  517. decoder_start_token_id = self.config.decoder_start_token_id
  518. pad_token_id = self.config.pad_token_id
  519. if decoder_start_token_id is None:
  520. raise ValueError(
  521. "self.model.config.decoder_start_token_id has to be defined. In UMT5 it is usually set to the pad_token_id. "
  522. "See UMT5 docs for more information."
  523. )
  524. # shift inputs to the right
  525. if is_torch_fx_proxy(input_ids):
  526. # Item assignment is not supported natively for proxies.
  527. shifted_input_ids = torch.full(input_ids.shape[:-1] + (1,), decoder_start_token_id)
  528. shifted_input_ids = torch.cat([shifted_input_ids, input_ids[..., :-1]], dim=-1)
  529. else:
  530. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  531. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  532. shifted_input_ids[..., 0] = decoder_start_token_id
  533. if pad_token_id is None:
  534. raise ValueError("self.model.config.pad_token_id has to be defined.")
  535. # replace possible -100 values in labels by `pad_token_id`
  536. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  537. return shifted_input_ids
  538. class UMT5Stack(UMT5PreTrainedModel):
  539. def __init__(self, config, embed_tokens=None):
  540. super().__init__(config)
  541. self.embed_tokens = embed_tokens
  542. self.is_decoder = config.is_decoder
  543. self.block = nn.ModuleList([UMT5Block(config, layer_idx=i) for i in range(config.num_layers)])
  544. self.final_layer_norm = UMT5LayerNorm(config.d_model, eps=config.layer_norm_epsilon)
  545. self.dropout = nn.Dropout(config.dropout_rate)
  546. # Initialize weights and apply final processing
  547. self.gradient_checkpointing = False
  548. self.post_init()
  549. def set_input_embeddings(self, new_embeddings):
  550. self.embed_tokens = new_embeddings
    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        inputs_embeds=None,
        head_mask=None,
        cross_attn_head_mask=None,
        past_key_values=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        """
        Run the stack's blocks over the given token ids or embeddings.

        Args:
            input_ids: Token ids; reshaped to `(-1, input_shape[-1])`. Mutually exclusive with `inputs_embeds`.
            attention_mask: Mask over key positions. When `None` (and not compiling with torchdynamo) an all-ones
                mask of shape `(batch_size, past_length + seq_length)` is created.
            encoder_hidden_states: Encoder output used for cross-attention; only read when this stack is a decoder.
                Its size is unpacked into three values, of which the first two are used as batch size and
                sequence length.
            encoder_attention_mask: Mask for `encoder_hidden_states`; defaults to all ones when cross-attending.
            inputs_embeds: Pre-computed embeddings; all but the last dimension are read as
                `(batch_size, seq_length)`. Mutually exclusive with `input_ids`.
            head_mask: Self-attention head mask, expanded per layer with `get_head_mask`.
            cross_attn_head_mask: Cross-attention head mask, expanded per layer with `get_head_mask`.
            past_key_values: Cache object. Created here for a decoder when `use_cache` is set and none is given;
                discarded for an encoder stack.
            use_cache: Whether to use a key/value cache. Defaults to `config.use_cache`.
            output_attentions: Whether to collect attention outputs. Defaults to `config.output_attentions`.
            output_hidden_states: Whether to collect per-layer hidden states. Defaults to
                `config.output_hidden_states`.
            return_dict: Whether to return a model-output object instead of a tuple. Defaults to
                `config.use_return_dict`.
            cache_position: Positions of the current tokens; defaults to
                `arange(past_length, past_length + seq_length)`.

        Returns:
            A `BaseModelOutputWithPastAndCrossAttentions` when `return_dict` is truthy, otherwise a tuple holding
            the non-`None` values among (hidden states, cache, all hidden states, attentions, cross-attentions).

        Raises:
            ValueError: If both or neither of `input_ids` / `inputs_embeds` are given, if embeddings are needed
                but `embed_tokens` is `None`, or if `use_cache` is `True` on a non-decoder stack.
        """
        # Fall back to the config for every flag the caller left unset.
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Exactly one of `input_ids` / `inputs_embeds` must be supplied.
        if input_ids is not None and inputs_embeds is not None:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(
                f"You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time"
            )
        elif input_ids is not None:
            input_shape = input_ids.size()
            input_ids = input_ids.view(-1, input_shape[-1])
        elif inputs_embeds is not None:
            input_shape = inputs_embeds.size()[:-1]
        else:
            err_msg_prefix = "decoder_" if self.is_decoder else ""
            raise ValueError(f"You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds")

        # Caching is switched off while training with gradient checkpointing.
        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        if inputs_embeds is None:
            if self.embed_tokens is None:
                raise ValueError("You have to initialize the model with valid token embeddings")
            inputs_embeds = self.embed_tokens(input_ids)

        batch_size, seq_length = input_shape

        if use_cache is True:
            if not self.is_decoder:
                raise ValueError(f"`use_cache` can only be set to `True` if {self} is used as a decoder")

        # initialize past_key_values
        if self.is_decoder:
            if use_cache and past_key_values is None:
                if self.config.is_encoder_decoder:
                    # Two caches are passed: the first one is read back below as `self_attention_cache`.
                    past_key_values = EncoderDecoderCache(
                        DynamicCache(config=self.config), DynamicCache(config=self.config)
                    )
                else:
                    past_key_values = DynamicCache(config=self.config)
        elif not self.is_decoder:
            # do not pass cache object down the line for encoder stack
            # it messes indexing later in decoder-stack because cache object is modified in-place
            past_key_values = None

        past_key_values_length = past_key_values.get_seq_length() if past_key_values is not None else 0
        if cache_position is None:
            # Positions of the current tokens continue where the cache ends.
            cache_position = torch.arange(
                past_key_values_length, past_key_values_length + seq_length, device=inputs_embeds.device
            )

        if attention_mask is None and not is_torchdynamo_compiling():
            # required mask seq length can be calculated via length of past cache
            mask_seq_length = past_key_values_length + seq_length
            attention_mask = torch.ones(batch_size, mask_seq_length, device=inputs_embeds.device)

        if self.is_decoder:
            # Decoder: delegate mask construction to `_update_causal_mask`.
            causal_mask = self._update_causal_mask(
                attention_mask,
                inputs_embeds,
                cache_position,
                past_key_values.self_attention_cache
                if isinstance(past_key_values, EncoderDecoderCache)
                else past_key_values,
                output_attentions,
            )
        elif attention_mask is not None:
            # Encoder: turn the mask into an additive one (0 where the mask is 1, dtype-min where it is 0).
            causal_mask = attention_mask[:, None, None, :]
            causal_mask = causal_mask.to(dtype=inputs_embeds.dtype)
            causal_mask = (1.0 - causal_mask) * torch.finfo(inputs_embeds.dtype).min
        else:
            causal_mask = None

        # If a 2D or 3D attention mask is provided for the cross-attention
        # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
        if self.is_decoder and encoder_hidden_states is not None:
            encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size()
            encoder_hidden_shape = (encoder_batch_size, encoder_sequence_length)
            if encoder_attention_mask is None:
                encoder_attention_mask = torch.ones(encoder_hidden_shape, device=inputs_embeds.device)
            encoder_extended_attention_mask = self.invert_attention_mask(encoder_attention_mask)
        else:
            encoder_extended_attention_mask = None

        # Prepare head mask if needed
        head_mask = self.get_head_mask(head_mask, self.config.num_layers)
        cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask, self.config.num_layers)
        # Accumulators stay `None` when the corresponding output was not requested.
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None
        all_cross_attentions = () if output_attentions and self.is_decoder else None

        hidden_states = self.dropout(inputs_embeds)

        for i, layer_module in enumerate(self.block):
            layer_head_mask = head_mask[i]
            cross_attn_layer_head_mask = cross_attn_head_mask[i]

            # Record the input of each layer (the final output is appended after the loop).
            if output_hidden_states:
                all_hidden_states = all_hidden_states + (hidden_states,)

            layer_outputs = layer_module(
                hidden_states,
                causal_mask,
                encoder_hidden_states,  # as a positional argument for gradient checkpointing
                encoder_attention_mask=encoder_extended_attention_mask,
                layer_head_mask=layer_head_mask,
                cross_attn_layer_head_mask=cross_attn_layer_head_mask,
                past_key_values=past_key_values,
                use_cache=use_cache,
                output_attentions=output_attentions,
                cache_position=cache_position,
            )

            hidden_states = layer_outputs[0]

            if output_attentions:
                all_attentions += (layer_outputs[1],)
                if self.is_decoder:
                    all_cross_attentions += (layer_outputs[2],)

        hidden_states = self.final_layer_norm(hidden_states)
        hidden_states = self.dropout(hidden_states)

        # Add last layer
        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        if not return_dict:
            # Tuple form drops every entry that is `None`.
            return tuple(
                v
                for v in [
                    hidden_states,
                    past_key_values,
                    all_hidden_states,
                    all_attentions,
                    all_cross_attentions,
                ]
                if v is not None
            )
        return BaseModelOutputWithPastAndCrossAttentions(
            last_hidden_state=hidden_states,
            past_key_values=past_key_values,
            hidden_states=all_hidden_states,
            attentions=all_attentions,
            cross_attentions=all_cross_attentions,
        )
    # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._update_causal_mask
    def _update_causal_mask(
        self,
        attention_mask: Union[torch.Tensor, "BlockMask"],
        input_tensor: torch.Tensor,
        cache_position: torch.Tensor,
        past_key_values: Cache,
        output_attentions: bool = False,
    ):
        """
        Produce the self-attention mask in the form expected by the configured attention implementation.

        Behavior depends on `config._attn_implementation`:
        - `"flash_attention_2"`: returns `attention_mask` unchanged if it contains any `0.0`, otherwise `None`.
        - `"flex_attention"`: a tensor mask is converted with `make_flex_block_causal_mask`; the (possibly
          converted) mask is returned.
        - otherwise: a 4D mask is built with `_prepare_4d_causal_attention_mask_with_cache_position`. With
          `"sdpa"` this may instead return `None` when `AttentionMaskConverter._ignore_causal_mask_sdpa` reports
          that the mask can be skipped.

        Args:
            attention_mask: Mask supplied by the caller (may be `None`).
            input_tensor: Current input embeddings; provides dtype, batch size (`shape[0]`) and sequence length
                (`shape[1]`).
            cache_position: Positions of the current tokens, forwarded to the 4D mask builder.
            past_key_values: Cache object, or `None`.
            output_attentions: Whether attention weights were requested.

        Returns:
            The mask to pass to the attention layers, or `None`.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and (attention_mask == 0.0).any():
                return attention_mask
            return None
        if self.config._attn_implementation == "flex_attention":
            if isinstance(attention_mask, torch.Tensor):
                attention_mask = make_flex_block_causal_mask(attention_mask)
            return attention_mask

        # For SDPA, when possible, we will rely on its `is_causal` argument instead of its `attn_mask` argument, in
        # order to dispatch on Flash Attention 2. This feature is not compatible with static cache, as SDPA will fail
        # to infer the attention mask.
        past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
        using_compilable_cache = past_key_values.is_compileable if past_key_values is not None else False

        # When output attentions is True, sdpa implementation's forward method calls the eager implementation's forward
        if self.config._attn_implementation == "sdpa" and not using_compilable_cache and not output_attentions:
            if AttentionMaskConverter._ignore_causal_mask_sdpa(
                attention_mask,
                inputs_embeds=input_tensor,
                past_key_values_length=past_seen_tokens,
                is_training=self.training,
            ):
                return None

        dtype = input_tensor.dtype
        sequence_length = input_tensor.shape[1]
        # Key/value length of the mask: the cache's maximum size for a compilable cache, else the provided mask's
        # last dimension, else past + current length + 1.
        if using_compilable_cache:
            target_length = past_key_values.get_max_cache_shape()
        else:
            target_length = (
                attention_mask.shape[-1]
                if isinstance(attention_mask, torch.Tensor)
                else past_seen_tokens + sequence_length + 1
            )

        # In case the provided `attention` mask is 2D, we generate a causal mask here (4D).
        causal_mask = self._prepare_4d_causal_attention_mask_with_cache_position(
            attention_mask,
            sequence_length=sequence_length,
            target_length=target_length,
            dtype=dtype,
            cache_position=cache_position,
            batch_size=input_tensor.shape[0],
        )

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type in ["cuda", "xpu", "npu"]
            and not output_attentions
        ):
            # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
            # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
            # Details: https://github.com/pytorch/pytorch/issues/110213
            min_dtype = torch.finfo(dtype).min
            causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask
  763. @staticmethod
  764. # Copied from transformers.models.gptj.modeling_gptj.GPTJModel._prepare_4d_causal_attention_mask_with_cache_position
  765. def _prepare_4d_causal_attention_mask_with_cache_position(
  766. attention_mask: torch.Tensor,
  767. sequence_length: int,
  768. target_length: int,
  769. dtype: torch.dtype,
  770. cache_position: torch.Tensor,
  771. batch_size: int,
  772. **kwargs,
  773. ):
  774. """
  775. Creates a causal 4D mask of shape `(batch_size, 1, query_length, key_value_length)` from a 2D mask of shape
  776. `(batch_size, key_value_length)`, or if the input `attention_mask` is already 4D, do nothing.
  777. Args:
  778. attention_mask (`torch.Tensor`):
  779. A 2D attention mask of shape `(batch_size, key_value_length)` or a 4D attention mask of shape
  780. `(batch_size, 1, query_length, key_value_length)`.
  781. sequence_length (`int`):
  782. The sequence length being processed.
  783. target_length (`int`):
  784. The target length: when generating with static cache, the mask should be as long as the static cache,
  785. to account for the 0 padding, the part of the cache that is not filled yet.
  786. dtype (`torch.dtype`):
  787. The dtype to use for the 4D attention mask.
  788. cache_position (`torch.Tensor`):
  789. Indices depicting the position of the input sequence tokens in the sequence.
  790. batch_size (`torch.Tensor`):
  791. Batch size.
  792. """
  793. if attention_mask is not None and attention_mask.dim() == 4:
  794. # In this case we assume that the mask comes already in inverted form and requires no inversion or slicing.
  795. causal_mask = attention_mask
  796. else:
  797. min_dtype = torch.finfo(dtype).min
  798. causal_mask = torch.full(
  799. (sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=cache_position.device
  800. )
  801. if sequence_length != 1:
  802. causal_mask = torch.triu(causal_mask, diagonal=1)
  803. causal_mask *= torch.arange(target_length, device=cache_position.device) > cache_position.reshape(-1, 1)
  804. causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1)
  805. if attention_mask is not None:
  806. causal_mask = causal_mask.clone() # copy to contiguous memory for in-place edit
  807. mask_length = attention_mask.shape[-1]
  808. padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :].to(
  809. causal_mask.device
  810. )
  811. padding_mask = padding_mask == 0
  812. causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(
  813. padding_mask, min_dtype
  814. )
  815. return causal_mask
  816. @auto_docstring
  817. class UMT5Model(UMT5PreTrainedModel):
  818. r"""
  819. Examples:
  820. ```python
  821. >>> from transformers import UMT5Model, AutoTokenizer
  822. >>> model = UMT5Model.from_pretrained("google/umt5-small")
  823. >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
  824. >>> noisy_text = "UN Offizier sagt, dass weiter <extra_id_0> werden muss in Syrien."
  825. >>> label = "<extra_id_0> verhandelt"
    >>> inputs = tokenizer(noisy_text, return_tensors="pt")
    >>> labels = tokenizer(text_target=label, return_tensors="pt")
  828. >>> outputs = model(input_ids=inputs["input_ids"], decoder_input_ids=labels["input_ids"])
  829. >>> hidden_states = outputs.last_hidden_state
  830. ```"""
  831. model_type = "umt5"
  832. config: UMT5Config
  833. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  834. def __init__(self, config):
  835. super().__init__(config)
  836. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  837. encoder_config = copy.deepcopy(config)
  838. encoder_config.is_decoder = False
  839. encoder_config.use_cache = False
  840. encoder_config.tie_encoder_decoder = False
  841. self.encoder = UMT5Stack(encoder_config, self.shared)
  842. decoder_config = copy.deepcopy(config)
  843. decoder_config.is_decoder = True
  844. decoder_config.tie_encoder_decoder = False
  845. decoder_config.num_layers = config.num_decoder_layers
  846. self.decoder = UMT5Stack(decoder_config, self.shared)
  847. # Initialize weights and apply final processing
  848. self.post_init()
  849. # Copied from transformers.models.t5.modeling_t5.T5Model.get_input_embeddings
  850. def get_input_embeddings(self):
  851. return self.shared
  852. # Copied from transformers.models.t5.modeling_t5.T5Model.set_input_embeddings
  853. def set_input_embeddings(self, new_embeddings):
  854. self.shared = new_embeddings
  855. self.encoder.set_input_embeddings(new_embeddings)
  856. self.decoder.set_input_embeddings(new_embeddings)
  857. # Copied from transformers.models.t5.modeling_t5.T5Model._tie_weights
  858. def _tie_weights(self):
  859. if self.config.tie_word_embeddings:
  860. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  861. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  862. # Copied from transformers.models.t5.modeling_t5.T5Model.get_encoder
  863. def get_encoder(self):
  864. return self.encoder
    # Copied from transformers.models.t5.modeling_t5.T5Model._prune_heads
    def _prune_heads(self, heads_to_prune):
        """
        Prunes heads of the model.

        Args:
            heads_to_prune: dict of {layer_num: list of heads to prune in this layer}. See base class
                PreTrainedModel.
        """
        # NOTE(review): `UMT5Stack.__init__` in this file registers its layers as `self.block` and does not assign
        # a `layer` attribute, so `self.encoder.layer` looks like it would raise `AttributeError` here. Confirm,
        # and point this at the right sub-module if head pruning is meant to be supported.
        for layer, heads in heads_to_prune.items():
            self.encoder.layer[layer].attention.prune_heads(heads)
  873. @auto_docstring
  874. def forward(
  875. self,
  876. input_ids: Optional[torch.LongTensor] = None,
  877. attention_mask: Optional[torch.FloatTensor] = None,
  878. decoder_input_ids: Optional[torch.LongTensor] = None,
  879. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  880. head_mask: Optional[torch.FloatTensor] = None,
  881. decoder_head_mask: Optional[torch.FloatTensor] = None,
  882. cross_attn_head_mask: Optional[torch.Tensor] = None,
  883. encoder_outputs: Optional[tuple[tuple[torch.FloatTensor]]] = None,
  884. past_key_values: Optional[Cache] = None,
  885. inputs_embeds: Optional[torch.Tensor] = None,
  886. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  887. use_cache: Optional[bool] = None,
  888. output_attentions: Optional[bool] = None,
  889. output_hidden_states: Optional[bool] = None,
  890. return_dict: Optional[bool] = None,
  891. cache_position: Optional[torch.LongTensor] = None,
  892. ) -> Union[tuple[torch.FloatTensor], Seq2SeqModelOutput]:
  893. r"""
  894. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  895. Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so
  896. you should be able to pad the inputs on both the right and the left.
  897. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  898. [`PreTrainedTokenizer.__call__`] for detail.
  899. [What are input IDs?](../glossary#input-ids)
  900. To know more on how to prepare `input_ids` for pretraining take a look a [UMT5 Training](./umt5#training).
  901. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  902. Indices of decoder input sequence tokens in the vocabulary.
  903. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  904. [`PreTrainedTokenizer.__call__`] for details.
  905. [What are decoder input IDs?](../glossary#decoder-input-ids)
  906. UMT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  907. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  908. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [UMT5
  909. Training](./umt5#training).
  910. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  911. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  912. be used by default.
  913. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  914. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  915. 1]`:
  916. - 1 indicates the head is **not masked**,
  917. - 0 indicates the head is **masked**.
  918. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  919. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  920. `[0, 1]`:
  921. - 1 indicates the head is **not masked**,
  922. - 0 indicates the head is **masked**.
  923. Example:
  924. ```python
  925. >>> from transformers import AutoTokenizer, UMT5Model
  926. >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
  927. >>> model = UMT5Model.from_pretrained("google/umt5-small")
  928. >>> input_ids = tokenizer(
  929. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  930. ... ).input_ids # Batch size 1
  931. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  932. >>> # preprocess: Prepend decoder_input_ids with start token which is pad token for UMT5Model.
  933. >>> # This is not needed for torch's UMT5ForConditionalGeneration as it does this internally using labels arg.
  934. >>> decoder_input_ids = model._shift_right(decoder_input_ids)
  935. >>> # forward pass
  936. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  937. >>> last_hidden_states = outputs.last_hidden_state
  938. ```"""
  939. use_cache = use_cache if use_cache is not None else self.config.use_cache
  940. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  941. # Encode if needed (training, first prediction pass)
  942. if encoder_outputs is None:
  943. encoder_outputs = self.encoder(
  944. input_ids=input_ids,
  945. attention_mask=attention_mask,
  946. inputs_embeds=inputs_embeds,
  947. head_mask=head_mask,
  948. output_attentions=output_attentions,
  949. output_hidden_states=output_hidden_states,
  950. return_dict=return_dict,
  951. )
  952. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  953. encoder_outputs = BaseModelOutput(
  954. last_hidden_state=encoder_outputs[0],
  955. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  956. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  957. )
  958. hidden_states = encoder_outputs[0]
  959. # Decode
  960. decoder_outputs = self.decoder(
  961. input_ids=decoder_input_ids,
  962. attention_mask=decoder_attention_mask,
  963. inputs_embeds=decoder_inputs_embeds,
  964. past_key_values=past_key_values,
  965. encoder_hidden_states=hidden_states,
  966. encoder_attention_mask=attention_mask,
  967. head_mask=decoder_head_mask,
  968. cross_attn_head_mask=cross_attn_head_mask,
  969. use_cache=use_cache,
  970. output_attentions=output_attentions,
  971. output_hidden_states=output_hidden_states,
  972. return_dict=return_dict,
  973. cache_position=cache_position,
  974. )
  975. if not return_dict:
  976. return decoder_outputs + encoder_outputs
  977. return Seq2SeqModelOutput(
  978. last_hidden_state=decoder_outputs.last_hidden_state,
  979. past_key_values=decoder_outputs.past_key_values,
  980. decoder_hidden_states=decoder_outputs.hidden_states,
  981. decoder_attentions=decoder_outputs.attentions,
  982. cross_attentions=decoder_outputs.cross_attentions,
  983. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  984. encoder_hidden_states=encoder_outputs.hidden_states,
  985. encoder_attentions=encoder_outputs.attentions,
  986. )
  987. @auto_docstring(
  988. custom_intro="""
  989. UMT5 Model with a `language modeling` head on top.
  990. """
  991. )
  992. class UMT5ForConditionalGeneration(UMT5PreTrainedModel, GenerationMixin):
  993. r"""
  994. Examples:
  995. ```python
  996. >>> from transformers import UMT5ForConditionalGeneration, AutoTokenizer
  997. >>> model = UMT5ForConditionalGeneration.from_pretrained("google/umt5-small")
  998. >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
  999. >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
  1000. >>> summary = "Weiter Verhandlung in Syrien."
  1001. >>> inputs = tokenizer(article, text_target=summary, return_tensors="pt")
  1002. >>> outputs = model(**inputs)
  1003. >>> loss = outputs.loss
  1004. ```"""
  1005. model_type = "umt5"
  1006. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight", "lm_head.weight"]
  1007. def __init__(self, config):
  1008. super().__init__(config)
  1009. self.model_dim = config.d_model
  1010. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1011. encoder_config = copy.deepcopy(config)
  1012. encoder_config.is_decoder = False
  1013. encoder_config.use_cache = False
  1014. encoder_config.tie_encoder_decoder = False
  1015. self.encoder = UMT5Stack(encoder_config, self.shared)
  1016. decoder_config = copy.deepcopy(config)
  1017. decoder_config.is_decoder = True
  1018. decoder_config.tie_encoder_decoder = False
  1019. decoder_config.num_layers = config.num_decoder_layers
  1020. self.decoder = UMT5Stack(decoder_config, self.shared)
  1021. self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
  1022. # Initialize weights and apply final processing
  1023. self.post_init()
  1024. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_input_embeddings
  1025. def get_input_embeddings(self):
  1026. return self.shared
  1027. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.set_input_embeddings
  1028. def set_input_embeddings(self, new_embeddings):
  1029. self.shared = new_embeddings
  1030. self.encoder.set_input_embeddings(new_embeddings)
  1031. self.decoder.set_input_embeddings(new_embeddings)
  1032. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration._tie_weights
  1033. def _tie_weights(self):
  1034. if self.config.tie_word_embeddings:
  1035. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1036. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  1037. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.get_encoder
  1038. def get_encoder(self):
  1039. return self.encoder
  1040. @auto_docstring
  1041. def forward(
  1042. self,
  1043. input_ids: Optional[torch.LongTensor] = None,
  1044. attention_mask: Optional[torch.FloatTensor] = None,
  1045. decoder_input_ids: Optional[torch.LongTensor] = None,
  1046. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1047. head_mask: Optional[torch.FloatTensor] = None,
  1048. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1049. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1050. encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
  1051. past_key_values: Optional[Cache] = None,
  1052. inputs_embeds: Optional[torch.FloatTensor] = None,
  1053. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1054. labels: Optional[torch.LongTensor] = None,
  1055. use_cache: Optional[bool] = None,
  1056. output_attentions: Optional[bool] = None,
  1057. output_hidden_states: Optional[bool] = None,
  1058. return_dict: Optional[bool] = None,
  1059. cache_position: Optional[torch.LongTensor] = None,
  1060. ) -> Union[tuple[torch.FloatTensor], Seq2SeqLMOutput]:
  1061. r"""
  1062. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1063. Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so
  1064. you should be able to pad the inputs on both the right and the left.
  1065. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1066. [`PreTrainedTokenizer.__call__`] for detail.
  1067. [What are input IDs?](../glossary#input-ids)
            To know more on how to prepare `input_ids` for pretraining take a look at [UMT5 Training](./umt5#training).
  1069. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1070. Indices of decoder input sequence tokens in the vocabulary.
  1071. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1072. [`PreTrainedTokenizer.__call__`] for details.
  1073. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1074. UMT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1075. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1076. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [UMT5
  1077. Training](./umt5#training).
  1078. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1079. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1080. be used by default.
  1081. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1082. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1083. 1]`:
  1084. - 1 indicates the head is **not masked**,
  1085. - 0 indicates the head is **masked**.
  1086. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1087. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1088. `[0, 1]`:
  1089. - 1 indicates the head is **not masked**,
  1090. - 0 indicates the head is **masked**.
        labels (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
            Labels for computing the language modeling loss. Indices should be in `[-100, 0, ...,
            config.vocab_size - 1]`. All labels set to `-100` are ignored (masked), the loss is only computed for
            labels in `[0, ..., config.vocab_size - 1]`.
  1095. Examples:
  1096. ```python
  1097. >>> from transformers import AutoTokenizer, UMT5ForConditionalGeneration
  1098. >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
  1099. >>> model = UMT5ForConditionalGeneration.from_pretrained("google/umt5-small")
  1100. >>> # training
  1101. >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
  1102. >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
  1103. >>> outputs = model(input_ids=input_ids, labels=labels)
  1104. >>> loss = outputs.loss
  1105. >>> logits = outputs.logits
  1106. >>> # inference
  1107. >>> input_ids = tokenizer("Studies have shown that <extra_id_0> good for you", return_tensors="pt").input_ids
  1108. >>> outputs = model.generate(input_ids)
  1109. >>> tokenizer.decode(outputs[0], skip_special_tokens=True)
  1110. ```"""
  1111. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1112. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1113. # Encode if needed (training, first prediction pass)
  1114. if encoder_outputs is None:
  1115. # Convert encoder inputs in embeddings if needed
  1116. encoder_outputs = self.encoder(
  1117. input_ids=input_ids,
  1118. attention_mask=attention_mask,
  1119. inputs_embeds=inputs_embeds,
  1120. head_mask=head_mask,
  1121. output_attentions=output_attentions,
  1122. output_hidden_states=output_hidden_states,
  1123. return_dict=return_dict,
  1124. )
  1125. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1126. encoder_outputs = BaseModelOutput(
  1127. last_hidden_state=encoder_outputs[0],
  1128. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1129. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1130. )
  1131. hidden_states = encoder_outputs[0]
  1132. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  1133. # get decoder inputs from shifting lm labels to the right
  1134. decoder_input_ids = self._shift_right(labels)
  1135. # Decode
  1136. decoder_outputs = self.decoder(
  1137. input_ids=decoder_input_ids,
  1138. attention_mask=decoder_attention_mask,
  1139. inputs_embeds=decoder_inputs_embeds,
  1140. past_key_values=past_key_values,
  1141. encoder_hidden_states=hidden_states,
  1142. encoder_attention_mask=attention_mask,
  1143. head_mask=decoder_head_mask,
  1144. cross_attn_head_mask=cross_attn_head_mask,
  1145. use_cache=use_cache,
  1146. output_attentions=output_attentions,
  1147. output_hidden_states=output_hidden_states,
  1148. return_dict=return_dict,
  1149. cache_position=cache_position,
  1150. )
  1151. sequence_output = decoder_outputs[0]
  1152. if self.config.tie_word_embeddings:
  1153. # Rescale output before projecting on vocab
  1154. # See https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
  1155. sequence_output = sequence_output * (self.model_dim**-0.5)
  1156. lm_logits = self.lm_head(sequence_output)
  1157. loss = None
  1158. if labels is not None:
  1159. loss_fct = CrossEntropyLoss(ignore_index=-100)
  1160. # move labels to correct device to enable PP
  1161. labels = labels.to(lm_logits.device)
  1162. loss = loss_fct(lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
  1163. if not return_dict:
  1164. output = (lm_logits,) + decoder_outputs[1:] + encoder_outputs
  1165. return ((loss,) + output) if loss is not None else output
  1166. return Seq2SeqLMOutput(
  1167. loss=loss,
  1168. logits=lm_logits,
  1169. past_key_values=decoder_outputs.past_key_values,
  1170. decoder_hidden_states=decoder_outputs.hidden_states,
  1171. decoder_attentions=decoder_outputs.attentions,
  1172. cross_attentions=decoder_outputs.cross_attentions,
  1173. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1174. encoder_hidden_states=encoder_outputs.hidden_states,
  1175. encoder_attentions=encoder_outputs.attentions,
  1176. )
  1177. # Copied from transformers.models.t5.modeling_t5.T5ForConditionalGeneration.prepare_decoder_input_ids_from_labels
  1178. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  1179. return self._shift_right(labels)
  1180. @auto_docstring
  1181. class UMT5EncoderModel(UMT5PreTrainedModel):
  1182. r"""
  1183. Examples:
  1184. ```python
  1185. >>> from transformers import UMT5EncoderModel, AutoTokenizer
  1186. >>> model = UMT5EncoderModel.from_pretrained("google/umt5-small")
  1187. >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
  1188. >>> article = "UN Offizier sagt, dass weiter verhandelt werden muss in Syrien."
  1189. >>> input_ids = tokenizer(article, return_tensors="pt").input_ids
  1190. >>> outputs = model(input_ids)
  1191. >>> hidden_state = outputs.last_hidden_state
  1192. ```"""
  1193. model_type = "umt5"
  1194. # config_class = UMT5Config
  1195. _tied_weights_keys = ["encoder.embed_tokens.weight"]
  1196. def __init__(self, config):
  1197. super().__init__(config)
  1198. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1199. encoder_config = copy.deepcopy(config)
  1200. encoder_config.use_cache = False
  1201. encoder_config.is_encoder_decoder = False
  1202. self.encoder = UMT5Stack(encoder_config, self.shared)
  1203. # Initialize weights and apply final processing
  1204. self.post_init()
  1205. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_input_embeddings
  1206. def get_input_embeddings(self):
  1207. return self.shared
  1208. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.set_input_embeddings
  1209. def set_input_embeddings(self, new_embeddings):
  1210. self.shared = new_embeddings
  1211. self.encoder.set_input_embeddings(new_embeddings)
  1212. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel._tie_weights
  1213. def _tie_weights(self):
  1214. if self.config.tie_word_embeddings:
  1215. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1216. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.get_encoder
  1217. def get_encoder(self):
  1218. return self.encoder
  1219. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel._prune_heads
  1220. def _prune_heads(self, heads_to_prune):
  1221. """
  1222. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of heads to prune in this layer} See base
  1223. class PreTrainedModel
  1224. """
  1225. for layer, heads in heads_to_prune.items():
  1226. self.encoder.block[layer].layer[0].SelfAttention.prune_heads(heads)
  1227. @auto_docstring
  1228. # Copied from transformers.models.t5.modeling_t5.T5EncoderModel.forward with T5->UMT5, google-t5/t5-small->google/umt5-small, t5#training->umt5#training
  1229. def forward(
  1230. self,
  1231. input_ids: Optional[torch.LongTensor] = None,
  1232. attention_mask: Optional[torch.FloatTensor] = None,
  1233. head_mask: Optional[torch.FloatTensor] = None,
  1234. inputs_embeds: Optional[torch.FloatTensor] = None,
  1235. output_attentions: Optional[bool] = None,
  1236. output_hidden_states: Optional[bool] = None,
  1237. return_dict: Optional[bool] = None,
  1238. ) -> Union[tuple[torch.FloatTensor], BaseModelOutput]:
  1239. r"""
  1240. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1241. Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so you
  1242. should be able to pad the inputs on both the right and the left.
  1243. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1244. [`PreTrainedTokenizer.__call__`] for detail.
  1245. To know more on how to prepare `input_ids` for pretraining take a look a [UMT5 Training](./umt5#training).
  1246. Example:
  1247. ```python
  1248. >>> from transformers import AutoTokenizer, UMT5EncoderModel
  1249. >>> tokenizer = AutoTokenizer.from_pretrained("google/umt5-small")
  1250. >>> model = UMT5EncoderModel.from_pretrained("google/umt5-small")
  1251. >>> input_ids = tokenizer(
  1252. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1253. ... ).input_ids # Batch size 1
  1254. >>> outputs = model(input_ids=input_ids)
  1255. >>> last_hidden_states = outputs.last_hidden_state
  1256. ```"""
  1257. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1258. encoder_outputs = self.encoder(
  1259. input_ids=input_ids,
  1260. attention_mask=attention_mask,
  1261. inputs_embeds=inputs_embeds,
  1262. head_mask=head_mask,
  1263. output_attentions=output_attentions,
  1264. output_hidden_states=output_hidden_states,
  1265. return_dict=return_dict,
  1266. )
  1267. return encoder_outputs
  1268. @auto_docstring(
  1269. custom_intro="""
  1270. UMT5 model with a sequence classification/head on top (a linear layer on top of the pooled output) e.g. for GLUE
  1271. tasks.
  1272. """
  1273. )
  1274. class UMT5ForSequenceClassification(UMT5PreTrainedModel):
  1275. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1276. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1277. # Copied from transformers.models.t5.modeling_t5.T5ForSequenceClassification.__init__ with T5->UMT5
  1278. def __init__(self, config: UMT5Config):
  1279. super().__init__(config)
  1280. self.transformer = UMT5Model(config)
  1281. self.classification_head = UMT5ClassificationHead(config)
  1282. # Initialize weights and apply final processing
  1283. self.post_init()
  1284. self.model_parallel = False
  1285. @auto_docstring
  1286. def forward(
  1287. self,
  1288. input_ids: Optional[torch.LongTensor] = None,
  1289. attention_mask: Optional[torch.Tensor] = None,
  1290. decoder_input_ids: Optional[torch.LongTensor] = None,
  1291. decoder_attention_mask: Optional[torch.LongTensor] = None,
  1292. head_mask: Optional[torch.Tensor] = None,
  1293. decoder_head_mask: Optional[torch.Tensor] = None,
  1294. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1295. encoder_outputs: Optional[list[torch.FloatTensor]] = None,
  1296. inputs_embeds: Optional[torch.FloatTensor] = None,
  1297. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1298. labels: Optional[torch.LongTensor] = None,
  1299. use_cache: Optional[bool] = None,
  1300. output_attentions: Optional[bool] = None,
  1301. output_hidden_states: Optional[bool] = None,
  1302. return_dict: Optional[bool] = None,
  1303. ) -> Union[tuple, Seq2SeqSequenceClassifierOutput]:
  1304. r"""
  1305. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1306. Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so
  1307. you should be able to pad the inputs on both the right and the left.
  1308. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1309. [`PreTrainedTokenizer.__call__`] for detail.
  1310. [What are input IDs?](../glossary#input-ids)
  1311. To know more on how to prepare `input_ids` for pretraining take a look a [UMT5 Training](./umt5#training).
  1312. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1313. Indices of decoder input sequence tokens in the vocabulary.
  1314. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1315. [`PreTrainedTokenizer.__call__`] for details.
  1316. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1317. UMT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1318. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1319. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [UMT5
  1320. Training](./umt5#training).
  1321. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1322. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1323. be used by default.
  1324. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1325. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1326. 1]`:
  1327. - 1 indicates the head is **not masked**,
  1328. - 0 indicates the head is **masked**.
  1329. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1330. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1331. `[0, 1]`:
  1332. - 1 indicates the head is **not masked**,
  1333. - 0 indicates the head is **masked**.
  1334. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  1335. Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
  1336. config.num_labels - 1]`. If `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
  1337. """
  1338. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1339. if labels is not None:
  1340. use_cache = False
  1341. if input_ids is None and inputs_embeds is not None:
  1342. raise NotImplementedError(
  1343. f"Passing input embeddings is currently not supported for {self.__class__.__name__}"
  1344. )
  1345. # Copied from models.bart.modeling_bart.BartModel.forward different to other models, T5 automatically creates
  1346. # decoder_input_ids from input_ids if no decoder_input_ids are provided
  1347. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1348. if input_ids is None:
  1349. raise ValueError(
  1350. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1351. "passed, `input_ids` cannot be `None`. Please pass either "
  1352. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1353. )
  1354. decoder_input_ids = self._shift_right(input_ids)
  1355. outputs = self.transformer(
  1356. input_ids,
  1357. attention_mask=attention_mask,
  1358. decoder_input_ids=decoder_input_ids,
  1359. decoder_attention_mask=decoder_attention_mask,
  1360. head_mask=head_mask,
  1361. decoder_head_mask=decoder_head_mask,
  1362. cross_attn_head_mask=cross_attn_head_mask,
  1363. encoder_outputs=encoder_outputs,
  1364. inputs_embeds=inputs_embeds,
  1365. decoder_inputs_embeds=decoder_inputs_embeds,
  1366. use_cache=use_cache,
  1367. output_attentions=output_attentions,
  1368. output_hidden_states=output_hidden_states,
  1369. return_dict=return_dict,
  1370. )
  1371. sequence_output = outputs[0]
  1372. eos_mask = input_ids.eq(self.config.eos_token_id).to(sequence_output.device)
  1373. if len(torch.unique_consecutive(eos_mask.sum(1))) > 1:
  1374. raise ValueError("All examples must have the same number of <eos> tokens.")
  1375. batch_size, _, hidden_size = sequence_output.shape
  1376. sentence_representation = sequence_output[eos_mask, :].view(batch_size, -1, hidden_size)[:, -1, :]
  1377. logits = self.classification_head(sentence_representation)
  1378. loss = None
  1379. if labels is not None:
  1380. labels = labels.to(logits.device)
  1381. if self.config.problem_type is None:
  1382. if self.config.num_labels == 1:
  1383. self.config.problem_type = "regression"
  1384. elif self.config.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
  1385. self.config.problem_type = "single_label_classification"
  1386. else:
  1387. self.config.problem_type = "multi_label_classification"
  1388. if self.config.problem_type == "regression":
  1389. loss_fct = MSELoss()
  1390. if self.config.num_labels == 1:
  1391. loss = loss_fct(logits.squeeze(), labels.squeeze())
  1392. else:
  1393. loss = loss_fct(logits, labels)
  1394. elif self.config.problem_type == "single_label_classification":
  1395. loss_fct = CrossEntropyLoss()
  1396. loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))
  1397. elif self.config.problem_type == "multi_label_classification":
  1398. loss_fct = BCEWithLogitsLoss()
  1399. loss = loss_fct(logits, labels)
  1400. if not return_dict:
  1401. output = (logits,) + outputs[1:]
  1402. return ((loss,) + output) if loss is not None else output
  1403. return Seq2SeqSequenceClassifierOutput(
  1404. loss=loss,
  1405. logits=logits,
  1406. past_key_values=outputs.past_key_values,
  1407. decoder_hidden_states=outputs.decoder_hidden_states,
  1408. decoder_attentions=outputs.decoder_attentions,
  1409. cross_attentions=outputs.cross_attentions,
  1410. encoder_last_hidden_state=outputs.encoder_last_hidden_state,
  1411. encoder_hidden_states=outputs.encoder_hidden_states,
  1412. encoder_attentions=outputs.encoder_attentions,
  1413. )
  1414. @auto_docstring
  1415. class UMT5ForTokenClassification(UMT5PreTrainedModel):
  1416. _keys_to_ignore_on_load_unexpected = ["decoder.block.0.layer.1.EncDecAttention.relative_attention_bias.weight"]
  1417. _tied_weights_keys = ["transformer.encoder.embed_tokens.weight"]
  1418. # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.__init__ with T5->UMT5
  1419. def __init__(self, config: UMT5Config):
  1420. super().__init__(config)
  1421. self.num_labels = config.num_labels
  1422. self.transformer = UMT5EncoderModel(config)
  1423. self.dropout = nn.Dropout(config.classifier_dropout)
  1424. self.classifier = nn.Linear(config.hidden_size, config.num_labels)
  1425. # Initialize weights and apply final processing
  1426. self.post_init()
  1427. @auto_docstring
  1428. # Copied from transformers.models.t5.modeling_t5.T5ForTokenClassification.forward with T5->UMT5, t5->umt5
  1429. def forward(
  1430. self,
  1431. input_ids: Optional[torch.Tensor] = None,
  1432. attention_mask: Optional[torch.Tensor] = None,
  1433. head_mask: Optional[torch.Tensor] = None,
  1434. inputs_embeds: Optional[torch.Tensor] = None,
  1435. labels: Optional[torch.Tensor] = None,
  1436. output_attentions: Optional[bool] = None,
  1437. output_hidden_states: Optional[bool] = None,
  1438. return_dict: Optional[bool] = None,
  1439. ) -> Union[tuple[torch.Tensor], TokenClassifierOutput]:
  1440. r"""
  1441. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1442. Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so you
  1443. should be able to pad the inputs on both the right and the left.
  1444. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1445. [`PreTrainedTokenizer.__call__`] for detail.
  1446. [What are input IDs?](../glossary#input-ids)
  1447. To know more on how to prepare `input_ids` for pretraining take a look a [UMT5 Training](./umt5#training).
  1448. labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
  1449. Labels for computing the token classification loss. Indices should be in `[0, ..., config.num_labels - 1]`.
  1450. """
  1451. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1452. outputs = self.transformer(
  1453. input_ids,
  1454. attention_mask=attention_mask,
  1455. head_mask=head_mask,
  1456. inputs_embeds=inputs_embeds,
  1457. output_attentions=output_attentions,
  1458. output_hidden_states=output_hidden_states,
  1459. return_dict=return_dict,
  1460. )
  1461. hidden_states = outputs[0]
  1462. hidden_states = self.dropout(hidden_states)
  1463. logits = self.classifier(hidden_states)
  1464. loss = None
  1465. if labels is not None:
  1466. loss_fct = CrossEntropyLoss()
  1467. loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
  1468. if not return_dict:
  1469. output = (logits, outputs[2:-1])
  1470. return ((loss,) + output) if loss is not None else output
  1471. return TokenClassifierOutput(
  1472. loss=loss,
  1473. logits=logits,
  1474. hidden_states=outputs.hidden_states,
  1475. attentions=outputs.attentions,
  1476. )
  1477. @auto_docstring
  1478. class UMT5ForQuestionAnswering(UMT5PreTrainedModel):
  1479. _tied_weights_keys = ["encoder.embed_tokens.weight", "decoder.embed_tokens.weight"]
  1480. def __init__(self, config):
  1481. super().__init__(config)
  1482. self.model_dim = config.d_model
  1483. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1484. encoder_config = copy.deepcopy(config)
  1485. encoder_config.is_decoder = False
  1486. encoder_config.use_cache = False
  1487. encoder_config.tie_encoder_decoder = False
  1488. self.encoder = UMT5Stack(encoder_config, self.shared)
  1489. decoder_config = copy.deepcopy(config)
  1490. decoder_config.is_decoder = True
  1491. decoder_config.tie_encoder_decoder = False
  1492. decoder_config.num_layers = config.num_decoder_layers
  1493. self.decoder = UMT5Stack(decoder_config, self.shared)
  1494. self.num_labels = config.num_labels
  1495. self.qa_outputs = nn.Linear(config.d_model, config.num_labels)
  1496. # Initialize weights and apply final processing
  1497. self.post_init()
  1498. # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_input_embeddings
  1499. def get_input_embeddings(self):
  1500. return self.shared
  1501. # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.set_input_embeddings
  1502. def set_input_embeddings(self, new_embeddings):
  1503. self.shared = new_embeddings
  1504. self.encoder.set_input_embeddings(new_embeddings)
  1505. self.decoder.set_input_embeddings(new_embeddings)
  1506. # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering._tie_weights
  1507. def _tie_weights(self):
  1508. if self.config.tie_word_embeddings:
  1509. self._tie_or_clone_weights(self.encoder.embed_tokens, self.shared)
  1510. self._tie_or_clone_weights(self.decoder.embed_tokens, self.shared)
  1511. # Copied from transformers.models.t5.modeling_t5.T5ForQuestionAnswering.get_encoder
  1512. def get_encoder(self):
  1513. return self.encoder
  1514. @auto_docstring
  1515. def forward(
  1516. self,
  1517. input_ids: Optional[torch.LongTensor] = None,
  1518. attention_mask: Optional[torch.FloatTensor] = None,
  1519. decoder_input_ids: Optional[torch.LongTensor] = None,
  1520. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1521. head_mask: Optional[torch.FloatTensor] = None,
  1522. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1523. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1524. encoder_outputs: Optional[tuple[tuple[torch.Tensor]]] = None,
  1525. start_positions: Optional[torch.LongTensor] = None,
  1526. end_positions: Optional[torch.LongTensor] = None,
  1527. inputs_embeds: Optional[torch.FloatTensor] = None,
  1528. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  1529. use_cache: Optional[bool] = None,
  1530. output_attentions: Optional[bool] = None,
  1531. output_hidden_states: Optional[bool] = None,
  1532. return_dict: Optional[bool] = None,
  1533. ) -> Union[tuple[torch.FloatTensor], Seq2SeqQuestionAnsweringModelOutput]:
  1534. r"""
  1535. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1536. Indices of input sequence tokens in the vocabulary. UMT5 is a model with relative position embeddings so
  1537. you should be able to pad the inputs on both the right and the left.
  1538. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1539. [`PreTrainedTokenizer.__call__`] for detail.
  1540. [What are input IDs?](../glossary#input-ids)
  1541. To know more on how to prepare `input_ids` for pretraining take a look a [UMT5 Training](./umt5#training).
  1542. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1543. Indices of decoder input sequence tokens in the vocabulary.
  1544. Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
  1545. [`PreTrainedTokenizer.__call__`] for details.
  1546. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1547. UMT5 uses the `pad_token_id` as the starting token for `decoder_input_ids` generation. If `past_key_values`
  1548. is used, optionally only the last `decoder_input_ids` have to be input (see `past_key_values`).
  1549. To know more on how to prepare `decoder_input_ids` for pretraining take a look at [UMT5
  1550. Training](./umt5#training).
  1551. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  1552. Default behavior: generate a tensor that ignores pad tokens in `decoder_input_ids`. Causal mask will also
  1553. be used by default.
  1554. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1555. Mask to nullify selected heads of the self-attention modules in the decoder. Mask values selected in `[0,
  1556. 1]`:
  1557. - 1 indicates the head is **not masked**,
  1558. - 0 indicates the head is **masked**.
  1559. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  1560. Mask to nullify selected heads of the cross-attention modules in the decoder. Mask values selected in
  1561. `[0, 1]`:
  1562. - 1 indicates the head is **not masked**,
  1563. - 0 indicates the head is **masked**.
  1564. """
  1565. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1566. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1567. if start_positions is not None and end_positions is not None:
  1568. use_cache = False
  1569. # Copied from models.bart.modeling_bart.BartModel.forward
  1570. # different to other models, T5 automatically creates decoder_input_ids from
  1571. # input_ids if no decoder_input_ids are provided
  1572. if decoder_input_ids is None and decoder_inputs_embeds is None:
  1573. if input_ids is None:
  1574. raise ValueError(
  1575. "If no `decoder_input_ids` or `decoder_inputs_embeds` are "
  1576. "passed, `input_ids` cannot be `None`. Please pass either "
  1577. "`input_ids` or `decoder_input_ids` or `decoder_inputs_embeds`."
  1578. )
  1579. decoder_input_ids = self._shift_right(input_ids)
  1580. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1581. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1582. # Encode if needed (training, first prediction pass)
  1583. if encoder_outputs is None:
  1584. encoder_outputs = self.encoder(
  1585. input_ids=input_ids,
  1586. attention_mask=attention_mask,
  1587. inputs_embeds=inputs_embeds,
  1588. head_mask=head_mask,
  1589. output_attentions=output_attentions,
  1590. output_hidden_states=output_hidden_states,
  1591. return_dict=return_dict,
  1592. )
  1593. elif return_dict and not isinstance(encoder_outputs, BaseModelOutput):
  1594. encoder_outputs = BaseModelOutput(
  1595. last_hidden_state=encoder_outputs[0],
  1596. hidden_states=encoder_outputs[1] if len(encoder_outputs) > 1 else None,
  1597. attentions=encoder_outputs[2] if len(encoder_outputs) > 2 else None,
  1598. )
  1599. hidden_states = encoder_outputs[0]
  1600. # Decode
  1601. decoder_outputs = self.decoder(
  1602. input_ids=decoder_input_ids,
  1603. attention_mask=decoder_attention_mask,
  1604. inputs_embeds=decoder_inputs_embeds,
  1605. past_key_values=None,
  1606. encoder_hidden_states=hidden_states,
  1607. encoder_attention_mask=attention_mask,
  1608. head_mask=decoder_head_mask,
  1609. cross_attn_head_mask=cross_attn_head_mask,
  1610. use_cache=use_cache,
  1611. output_attentions=output_attentions,
  1612. output_hidden_states=output_hidden_states,
  1613. return_dict=return_dict,
  1614. )
  1615. sequence_output = decoder_outputs[0]
  1616. logits = self.qa_outputs(sequence_output)
  1617. start_logits, end_logits = logits.split(1, dim=-1)
  1618. start_logits = start_logits.squeeze(-1).contiguous()
  1619. end_logits = end_logits.squeeze(-1).contiguous()
  1620. total_loss = None
  1621. if start_positions is not None and end_positions is not None:
  1622. # If we are on multi-GPU, split add a dimension
  1623. if len(start_positions.size()) > 1:
  1624. start_positions = start_positions.squeeze(-1).to(start_logits.device)
  1625. if len(end_positions.size()) > 1:
  1626. end_positions = end_positions.squeeze(-1).to(end_logits.device)
  1627. # sometimes the start/end positions are outside our model inputs, we ignore these terms
  1628. ignored_index = start_logits.size(1)
  1629. start_positions = start_positions.clamp(0, ignored_index)
  1630. end_positions = end_positions.clamp(0, ignored_index)
  1631. loss_fct = CrossEntropyLoss(ignore_index=ignored_index)
  1632. start_loss = loss_fct(start_logits, start_positions)
  1633. end_loss = loss_fct(end_logits, end_positions)
  1634. total_loss = (start_loss + end_loss) / 2
  1635. if not return_dict:
  1636. output = (start_logits, end_logits) + decoder_outputs[1:] + encoder_outputs
  1637. return ((total_loss,) + output) if total_loss is not None else output
  1638. return Seq2SeqQuestionAnsweringModelOutput(
  1639. loss=total_loss,
  1640. start_logits=start_logits,
  1641. end_logits=end_logits,
  1642. past_key_values=decoder_outputs.past_key_values,
  1643. decoder_hidden_states=decoder_outputs.hidden_states,
  1644. decoder_attentions=decoder_outputs.attentions,
  1645. cross_attentions=decoder_outputs.cross_attentions,
  1646. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1647. encoder_hidden_states=encoder_outputs.hidden_states,
  1648. encoder_attentions=encoder_outputs.attentions,
  1649. )
# Names exported by `from <this module> import *`: the UMT5 model classes.
__all__ = [
    "UMT5EncoderModel",
    "UMT5ForConditionalGeneration",
    "UMT5ForQuestionAnswering",
    "UMT5ForSequenceClassification",
    "UMT5ForTokenClassification",
    "UMT5Model",
    "UMT5PreTrainedModel",
]