backbone.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """ PyTorch T5 model."""
  16. import copy
  17. import math
  18. import os
  19. import warnings
  20. from typing import Optional, Tuple, Union
  21. import torch
  22. from torch import nn
  23. from torch.utils.checkpoint import checkpoint
  24. from transformers.activations import ACT2FN
  25. from transformers.modeling_outputs import \
  26. BaseModelOutputWithPastAndCrossAttentions
  27. from transformers.modeling_utils import (PreTrainedModel,
  28. find_pruneable_heads_and_indices,
  29. prune_linear_layer)
  30. from transformers.utils import (DUMMY_INPUTS, DUMMY_MASK, add_start_docstrings,
  31. add_start_docstrings_to_model_forward,
  32. is_torch_fx_proxy, replace_return_docstrings)
  33. from transformers.utils.model_parallel_utils import (assert_device_map,
  34. get_device_map)
  35. from modelscope.metainfo import Models
  36. from modelscope.models.base import Model, Tensor, TorchModel
  37. from modelscope.models.builder import MODELS
  38. from modelscope.outputs import AttentionBackboneModelOutput, Seq2SeqModelOutput
  39. from modelscope.utils.constant import Tasks
  40. from modelscope.utils.logger import get_logger
  41. from .configuration import T5Config
# Module-level logger used by the weight-conversion helper and the layers
# below (obtained from modelscope's logging utilities).
logger = get_logger()
###################################################
# This is a conversion method from TF 1.0 to PyTorch
# More details: https://medium.com/huggingface/from-tensorflow-to-pytorch-265f40ef2a28
####################################################
  47. def load_tf_weights_in_t5(model, config, tf_checkpoint_path):
  48. """Load tf checkpoints in a pytorch model."""
  49. try:
  50. import re
  51. import numpy as np
  52. import tensorflow as tf
  53. except ImportError:
  54. logger.error(
  55. 'Loading a TensorFlow model in PyTorch, requires TensorFlow to be installed. Please see '
  56. 'https://www.tensorflow.org/install/ for installation instructions.'
  57. )
  58. raise
  59. tf_path = os.path.abspath(tf_checkpoint_path)
  60. logger.info(f'Converting TensorFlow checkpoint from {tf_path}')
  61. # Load weights from TF model
  62. init_vars = tf.train.list_variables(tf_path)
  63. names = []
  64. tf_weights = {}
  65. for name, shape in init_vars:
  66. logger.info(f'Loading TF weight {name} with shape {shape}')
  67. array = tf.train.load_variable(tf_path, name)
  68. names.append(name)
  69. tf_weights[name] = array
  70. for txt_name in names:
  71. name = txt_name.split('/')
  72. # adam_v and adam_m are variables used in AdamWeightDecayOptimizer to calculated m and v
  73. # which are not required for using pretrained model
  74. if any(n in [
  75. 'adam_v', 'adam_m', 'AdamWeightDecayOptimizer',
  76. 'AdamWeightDecayOptimizer_1', 'global_step'
  77. ] for n in name):
  78. logger.info(f"Skipping {'/'.join(name)}")
  79. tf_weights.pop(txt_name, None)
  80. continue
  81. if '_slot_' in name[-1]:
  82. logger.info(f"Skipping {'/'.join(name)}")
  83. tf_weights.pop(txt_name, None)
  84. continue
  85. pointer = model
  86. array = tf_weights[txt_name]
  87. for m_name in name:
  88. if re.fullmatch(r'[A-Za-z]+_\d+', m_name):
  89. scope_names = re.split(r'_(\d+)', m_name)
  90. else:
  91. scope_names = [m_name]
  92. if scope_names[0] in ['kernel', 'scale', 'embedding']:
  93. pointer = getattr(pointer, 'weight')
  94. elif scope_names[0] == 'self_attention':
  95. pointer = getattr(pointer, 'layer')
  96. pointer = pointer[0]
  97. elif scope_names[0] == 'enc_dec_attention':
  98. pointer = getattr(pointer, 'layer')
  99. pointer = pointer[1]
  100. elif scope_names[0] == 'dense_relu_dense':
  101. pointer = getattr(pointer, 'layer')
  102. pointer = pointer[2]
  103. elif scope_names[0] == 'rms_norm':
  104. if hasattr(pointer, 'layer_norm'):
  105. pointer = getattr(pointer, 'layer_norm')
  106. elif hasattr(pointer, 'final_layer_norm'):
  107. pointer = getattr(pointer, 'final_layer_norm')
  108. elif scope_names[0] == 'scale':
  109. pointer = getattr(pointer, 'weight')
  110. elif scope_names[0] == 'output_bias' or scope_names[0] == 'beta':
  111. pointer = getattr(pointer, 'bias')
  112. elif scope_names[0] == 'squad':
  113. pointer = getattr(pointer, 'classifier')
  114. elif scope_names[0] == 'decoder' and name[1] == 'logits':
  115. continue
  116. elif scope_names[0] == 'logits':
  117. pointer = getattr(pointer, 'lm_head')
  118. elif scope_names[0] == 'wi' and len(
  119. scope_names) > 1 and scope_names[1].isdigit():
  120. pointer = getattr(pointer, f'wi_{scope_names[1]}')
  121. continue
  122. else:
  123. try:
  124. pointer = getattr(pointer, scope_names[0])
  125. except AttributeError:
  126. logger.info(f"Skipping {'/'.join(name)}")
  127. continue
  128. if len(scope_names) >= 2:
  129. num = int(scope_names[1])
  130. pointer = pointer[num]
  131. if scope_names[0] not in ['kernel', 'scale', 'embedding']:
  132. pointer = getattr(pointer, 'weight')
  133. if scope_names[0] != 'embedding':
  134. logger.info(
  135. f'Transposing numpy weight of shape {array.shape} for {name}')
  136. array = np.transpose(array)
  137. try:
  138. assert (
  139. pointer.shape == array.shape
  140. ), f'Pointer shape {pointer.shape} and array shape {array.shape} mismatched'
  141. except AssertionError as e:
  142. e.args += (pointer.shape, array.shape)
  143. raise
  144. logger.info(f'Initialize PyTorch weight {name}')
  145. pointer.data = torch.from_numpy(array.astype(np.float32))
  146. tf_weights.pop(txt_name, None)
  147. logger.info(
  148. f"Weights not copied to PyTorch model: {', '.join(tf_weights.keys())}."
  149. )
  150. return model
  151. class T5LayerNorm(nn.Module):
  152. def __init__(self, hidden_size, eps=1e-6):
  153. """
  154. Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
  155. """
  156. super().__init__()
  157. self.weight = nn.Parameter(torch.ones(hidden_size))
  158. self.variance_epsilon = eps
  159. def forward(self, hidden_states):
  160. # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
  161. # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
  162. # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
  163. # half-precision inputs is done in fp32
  164. variance = hidden_states.to(torch.float32).pow(2).mean(
  165. -1, keepdim=True)
  166. hidden_states = hidden_states * torch.rsqrt(variance
  167. + self.variance_epsilon)
  168. # convert into half-precision if necessary
  169. if self.weight.dtype in [torch.float16, torch.bfloat16]:
  170. hidden_states = hidden_states.to(self.weight.dtype)
  171. return self.weight * hidden_states
  172. class T5DenseReluDense(nn.Module):
  173. def __init__(self, config: T5Config):
  174. super().__init__()
  175. self.wi = nn.Linear(config.d_model, config.d_ff, bias=False)
  176. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  177. self.dropout = nn.Dropout(config.dropout_rate)
  178. def forward(self, hidden_states):
  179. hidden_states = self.wi(hidden_states)
  180. hidden_states = nn.functional.relu(hidden_states)
  181. hidden_states = self.dropout(hidden_states)
  182. hidden_states = self.wo(hidden_states)
  183. return hidden_states
  184. class T5DenseGatedGeluDense(nn.Module):
  185. def __init__(self, config: T5Config):
  186. super().__init__()
  187. self.wi_0 = nn.Linear(config.d_model, config.d_ff, bias=False)
  188. self.wi_1 = nn.Linear(config.d_model, config.d_ff, bias=False)
  189. self.wo = nn.Linear(config.d_ff, config.d_model, bias=False)
  190. self.dropout = nn.Dropout(config.dropout_rate)
  191. self.gelu_act = ACT2FN['gelu_new']
  192. def forward(self, hidden_states):
  193. hidden_gelu = self.gelu_act(self.wi_0(hidden_states))
  194. hidden_linear = self.wi_1(hidden_states)
  195. hidden_states = hidden_gelu * hidden_linear
  196. hidden_states = self.dropout(hidden_states)
  197. hidden_states = self.wo(hidden_states)
  198. return hidden_states
  199. class T5LayerFF(nn.Module):
  200. def __init__(self, config: T5Config):
  201. super().__init__()
  202. if config.feed_forward_proj == 'relu':
  203. self.DenseReluDense = T5DenseReluDense(config)
  204. elif config.feed_forward_proj == 'gated-gelu':
  205. self.DenseReluDense = T5DenseGatedGeluDense(config)
  206. else:
  207. raise ValueError(
  208. f'{self.config.feed_forward_proj} is not supported. Choose between `relu` and `gated-gelu`'
  209. )
  210. self.layer_norm = T5LayerNorm(
  211. config.d_model, eps=config.layer_norm_epsilon)
  212. self.dropout = nn.Dropout(config.dropout_rate)
  213. def forward(self, hidden_states):
  214. forwarded_states = self.layer_norm(hidden_states)
  215. forwarded_states = self.DenseReluDense(forwarded_states)
  216. hidden_states = hidden_states + self.dropout(forwarded_states)
  217. return hidden_states
class T5Attention(nn.Module):
    """Multi-head attention used by T5 for self-attention and cross-attention.

    The query/key/value/output projections have no bias. Attention scores are
    not divided by sqrt(d_kv) before the softmax. If
    ``has_relative_attention_bias`` is True, the layer owns an
    ``nn.Embedding(relative_attention_num_buckets, num_heads)``. Its values
    are added to the scores as a bucketed relative-position bias.
    """

    def __init__(self, config: T5Config, has_relative_attention_bias=False):
        super().__init__()
        self.is_decoder = config.is_decoder
        self.has_relative_attention_bias = has_relative_attention_bias
        self.relative_attention_num_buckets = config.relative_attention_num_buckets
        self.relative_attention_max_distance = config.relative_attention_max_distance
        self.d_model = config.d_model
        self.key_value_proj_dim = config.d_kv
        self.n_heads = config.num_heads
        # Dropout probability applied to the attention weights in forward().
        self.dropout = config.dropout_rate
        self.inner_dim = self.n_heads * self.key_value_proj_dim
        # Mesh TensorFlow initialization to avoid scaling before softmax
        self.q = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.k = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.v = nn.Linear(self.d_model, self.inner_dim, bias=False)
        self.o = nn.Linear(self.inner_dim, self.d_model, bias=False)
        if self.has_relative_attention_bias:
            self.relative_attention_bias = nn.Embedding(
                self.relative_attention_num_buckets, self.n_heads)
        self.pruned_heads = set()
        # Toggled from outside; when True in training mode, a freshly created
        # all-zero position bias is marked requires_grad in forward().
        self.gradient_checkpointing = False

    def prune_heads(self, heads):
        """Remove the given attention heads from the q/k/v/o projections.

        Args:
            heads: Collection of head indices to prune. ``self.pruned_heads``
                is passed to ``find_pruneable_heads_and_indices`` together with
                them (that helper's exact filtering is not shown here).
        """
        if len(heads) == 0:
            return
        heads, index = find_pruneable_heads_and_indices(
            heads, self.n_heads, self.key_value_proj_dim, self.pruned_heads)
        # Prune linear layers: q/k/v use prune_linear_layer's default dim,
        # while o is pruned along dim=1 (its input features).
        # NOTE(review): relative_attention_bias is not pruned, so
        # compute_bias() still returns the pre-pruning number of heads.
        self.q = prune_linear_layer(self.q, index)
        self.k = prune_linear_layer(self.k, index)
        self.v = prune_linear_layer(self.v, index)
        self.o = prune_linear_layer(self.o, index, dim=1)
        # Update hyper params
        self.n_heads = self.n_heads - len(heads)
        self.inner_dim = self.key_value_proj_dim * self.n_heads
        self.pruned_heads = self.pruned_heads.union(heads)

    @staticmethod
    def _relative_position_bucket(relative_position,
                                  bidirectional=True,
                                  num_buckets=32,
                                  max_distance=128):
        """
        Adapted from Mesh Tensorflow:
        https://github.com/tensorflow/mesh/blob/0cb87fe07da627bf0b7e60475d59f95ed6b5be3d/mesh_tensorflow/transformer/transformer_layers.py#L593

        Translate a relative position into a bucket number for relative
        attention. The relative position is defined as
        memory_position - query_position, i.e. the distance in tokens from the
        attending position to the attended-to position.

        If bidirectional=True, half of the buckets are reserved for positive
        offsets. If bidirectional=False, positive offsets are clamped to 0.
        Small absolute offsets get one bucket each. Larger offsets share
        logarithmically sized buckets, and every offset at or beyond
        max_distance falls into the last bucket of its half. This should
        generalize more gracefully to sequences longer than those seen in
        training.

        Args:
            relative_position: integer Tensor of relative offsets
                (torch.long when called from compute_bias).
            bidirectional: whether the attention is bidirectional.
            num_buckets: total number of buckets.
            max_distance: offset beyond which all positions share a bucket.

        Returns:
            A Tensor with the same shape as relative_position, containing
            bucket indices in the range [0, num_buckets).
        """
        relative_buckets = 0
        if bidirectional:
            # Upper half of the buckets is used for positive offsets.
            num_buckets //= 2
            relative_buckets += (relative_position > 0).to(
                torch.long) * num_buckets
            relative_position = torch.abs(relative_position)
        else:
            # Positive offsets become 0; negative offsets become their size.
            relative_position = -torch.min(relative_position,
                                           torch.zeros_like(relative_position))
        # now relative_position is in the range [0, inf)
        # half of the buckets are for exact increments in positions
        max_exact = num_buckets // 2
        is_small = relative_position < max_exact
        # The other half of the buckets are for logarithmically bigger bins in
        # positions up to max_distance. (log(0) = -inf for zero offsets is
        # harmless: those entries are is_small and torch.where picks the
        # exact value for them.)
        relateive_pos_log = torch.log(relative_position.float() / max_exact)
        max_dis_log = math.log(max_distance / max_exact)
        origin_relative_position = relateive_pos_log / max_dis_log * (
            num_buckets - max_exact)
        relative_postion_if_large = max_exact + origin_relative_position.to(
            torch.long)
        # Offsets past max_distance are clamped into the last bucket.
        relative_postion_if_large = torch.min(
            relative_postion_if_large,
            torch.full_like(relative_postion_if_large, num_buckets - 1))
        relative_buckets += torch.where(is_small, relative_position,
                                        relative_postion_if_large)
        return relative_buckets

    def compute_bias(self, query_length, key_length):
        """Compute binned relative position bias.

        Returns:
            Tensor of shape (1, num_heads, query_length, key_length) looked up
            from ``self.relative_attention_bias``. Only valid when the layer
            was built with ``has_relative_attention_bias=True``.
        """
        context_position = torch.arange(
            query_length,
            dtype=torch.long,
            device=self.relative_attention_bias.weight.device)[:, None]
        memory_position = torch.arange(
            key_length,
            dtype=torch.long,
            device=self.relative_attention_bias.weight.device)[None, :]
        relative_position = memory_position - context_position  # shape (query_length, key_length)
        relative_position_bucket = self._relative_position_bucket(
            relative_position,  # shape (query_length, key_length)
            bidirectional=(not self.is_decoder),
            num_buckets=self.relative_attention_num_buckets,
            max_distance=self.relative_attention_max_distance,
        )
        values = self.relative_attention_bias(
            relative_position_bucket
        )  # shape (query_length, key_length, num_heads)
        values = values.permute([2, 0, 1]).unsqueeze(
            0)  # shape (1, num_heads, query_length, key_length)
        return values

    def forward(
        self,
        hidden_states,
        mask=None,
        key_value_states=None,
        position_bias=None,
        past_key_value=None,
        layer_head_mask=None,
        query_length=None,
        use_cache=False,
        output_attentions=False,
    ):
        """
        Self-attention (if key_value_states is None) or attention over source sentence (provided by key_value_states).

        Args:
            hidden_states: query-side input of shape (batch_size, seq_length, d_model).
            mask: additive mask. It is added to the position bias only when
                ``position_bias`` is None, so a passed-in ``position_bias``
                must already contain it.
            key_value_states: encoder states for cross-attention, or None for
                self-attention.
            position_bias: precomputed bias to add to the scores, or None to
                build one here (relative bias or zeros).
            past_key_value: cached ``(keys, values)`` pair, or None.
            layer_head_mask: multiplier applied to the attention weights.
            query_length: overrides the cached-length increment used to
                compute ``real_seq_length`` when ``past_key_value`` is given.
            use_cache: return the key/value states (decoder layers only).
            output_attentions: also return the attention weights.

        Returns:
            Tuple ``(attn_output, present_key_value_state, position_bias)``,
            followed by ``attn_weights`` when ``output_attentions`` is True.
            ``present_key_value_state`` is ``(key_states, value_states)`` for a
            decoder layer with ``use_cache``, else None. The returned
            ``position_bias`` includes ``mask`` if this call built it.
        """
        # Input is (batch_size, seq_length, dim)
        # Mask is (batch_size, key_length) (non-causal) or (batch_size, key_length, key_length)
        # NOTE(review): mask is added directly to a (1, n_heads, q, k) bias
        # below, so it presumably arrives already broadcastable to
        # (batch_size, n_heads, q, k) — confirm with the caller.
        # past_key_value[0] is (batch_size, n_heads, q_len - 1, dim_per_head)
        batch_size, seq_length = hidden_states.shape[:2]
        real_seq_length = seq_length
        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f'past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states'
            real_seq_length += past_key_value[0].shape[
                2] if query_length is None else query_length
        key_length = real_seq_length if key_value_states is None else key_value_states.shape[
            1]

        def shape(states):
            """Split heads: (batch, seq, inner_dim) -> (batch, n_heads, seq, dim_per_head)."""
            return states.view(batch_size, -1, self.n_heads,
                               self.key_value_proj_dim).transpose(1, 2)

        def unshape(states):
            """Merge heads: (batch, n_heads, seq, dim_per_head) -> (batch, seq, inner_dim)."""
            return states.transpose(1, 2).contiguous().view(
                batch_size, -1, self.inner_dim)

        def project(hidden_states, proj_layer, key_value_states,
                    past_key_value):
            """Project inputs to key or value states, reusing/extending the cache."""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))
            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn: append the new step(s) to the cached states
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states],
                                              dim=2)
                else:
                    # cross-attn: cached encoder projections are reused as-is
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(self.q(
            hidden_states))  # (batch_size, n_heads, seq_length, dim_per_head)
        # get key/value states
        key_states = project(
            hidden_states, self.k, key_value_states,
            past_key_value[0] if past_key_value is not None else None)
        value_states = project(
            hidden_states, self.v, key_value_states,
            past_key_value[1] if past_key_value is not None else None)
        # compute scores (no 1/sqrt(d_kv) scaling, see __init__)
        scores = torch.matmul(
            query_states, key_states.transpose(3, 2)
        )  # equivalent of torch.einsum("bnqd,bnkd->bnqk", query_states, key_states), compatible with onnx op>9
        if position_bias is None:
            if not self.has_relative_attention_bias:
                position_bias = torch.zeros(
                    (1, self.n_heads, real_seq_length, key_length),
                    device=scores.device,
                    dtype=scores.dtype)
                if self.gradient_checkpointing and self.training:
                    position_bias.requires_grad = True
            else:
                position_bias = self.compute_bias(real_seq_length, key_length)
            # if key and values are already calculated
            # we want only the last query position bias
            if past_key_value is not None:
                position_bias = position_bias[:, :, -hidden_states.size(1):, :]
            if mask is not None:
                position_bias = position_bias + mask  # (batch_size, n_heads, seq_length, key_length)
        scores += position_bias
        attn_weights = nn.functional.softmax(
            scores.float(), dim=-1).type_as(
                scores)  # (batch_size, n_heads, seq_length, key_length)
        attn_weights = nn.functional.dropout(
            attn_weights, p=self.dropout, training=self.training
        )  # (batch_size, n_heads, seq_length, key_length)
        # Mask heads if we want to
        if layer_head_mask is not None:
            attn_weights = attn_weights * layer_head_mask
        attn_output = unshape(torch.matmul(
            attn_weights, value_states))  # (batch_size, seq_length, dim)
        attn_output = self.o(attn_output)
        present_key_value_state = (key_states,
                                   value_states) if (self.is_decoder
                                                     and use_cache) else None
        outputs = (attn_output, ) + (present_key_value_state, ) + (
            position_bias, )
        if output_attentions:
            outputs = outputs + (attn_weights, )
        return outputs
  434. class T5LayerSelfAttention(nn.Module):
  435. def __init__(self, config, has_relative_attention_bias=False):
  436. super().__init__()
  437. self.SelfAttention = T5Attention(
  438. config, has_relative_attention_bias=has_relative_attention_bias)
  439. self.layer_norm = T5LayerNorm(
  440. config.d_model, eps=config.layer_norm_epsilon)
  441. self.dropout = nn.Dropout(config.dropout_rate)
  442. def forward(
  443. self,
  444. hidden_states,
  445. attention_mask=None,
  446. position_bias=None,
  447. layer_head_mask=None,
  448. past_key_value=None,
  449. use_cache=False,
  450. output_attentions=False,
  451. ):
  452. normed_hidden_states = self.layer_norm(hidden_states)
  453. attention_output = self.SelfAttention(
  454. normed_hidden_states,
  455. mask=attention_mask,
  456. position_bias=position_bias,
  457. layer_head_mask=layer_head_mask,
  458. past_key_value=past_key_value,
  459. use_cache=use_cache,
  460. output_attentions=output_attentions,
  461. )
  462. hidden_states = hidden_states + self.dropout(attention_output[0])
  463. outputs = (hidden_states,
  464. ) + attention_output[1:] # add attentions if we output them
  465. return outputs
  466. class T5LayerCrossAttention(nn.Module):
  467. def __init__(self, config):
  468. super().__init__()
  469. self.EncDecAttention = T5Attention(
  470. config, has_relative_attention_bias=False)
  471. self.layer_norm = T5LayerNorm(
  472. config.d_model, eps=config.layer_norm_epsilon)
  473. self.dropout = nn.Dropout(config.dropout_rate)
  474. def forward(
  475. self,
  476. hidden_states,
  477. key_value_states,
  478. attention_mask=None,
  479. position_bias=None,
  480. layer_head_mask=None,
  481. past_key_value=None,
  482. use_cache=False,
  483. query_length=None,
  484. output_attentions=False,
  485. ):
  486. normed_hidden_states = self.layer_norm(hidden_states)
  487. attention_output = self.EncDecAttention(
  488. normed_hidden_states,
  489. mask=attention_mask,
  490. key_value_states=key_value_states,
  491. position_bias=position_bias,
  492. layer_head_mask=layer_head_mask,
  493. past_key_value=past_key_value,
  494. use_cache=use_cache,
  495. query_length=query_length,
  496. output_attentions=output_attentions,
  497. )
  498. layer_output = hidden_states + self.dropout(attention_output[0])
  499. outputs = (layer_output,
  500. ) + attention_output[1:] # add attentions if we output them
  501. return outputs
  502. class T5Block(nn.Module):
  503. def __init__(self, config, has_relative_attention_bias=False):
  504. super().__init__()
  505. self.is_decoder = config.is_decoder
  506. self.layer = nn.ModuleList()
  507. self.layer.append(
  508. T5LayerSelfAttention(
  509. config,
  510. has_relative_attention_bias=has_relative_attention_bias))
  511. if self.is_decoder:
  512. self.layer.append(T5LayerCrossAttention(config))
  513. self.layer.append(T5LayerFF(config))
  514. def forward(
  515. self,
  516. hidden_states,
  517. attention_mask=None,
  518. position_bias=None,
  519. encoder_hidden_states=None,
  520. encoder_attention_mask=None,
  521. encoder_decoder_position_bias=None,
  522. layer_head_mask=None,
  523. cross_attn_layer_head_mask=None,
  524. past_key_value=None,
  525. use_cache=False,
  526. output_attentions=False,
  527. return_dict=True,
  528. ):
  529. if past_key_value is not None:
  530. if not self.is_decoder:
  531. logger.warning(
  532. '`past_key_values` is passed to the encoder. Please make sure this is intended.'
  533. )
  534. expected_num_past_key_values = 2 if encoder_hidden_states is None else 4
  535. if len(past_key_value) != expected_num_past_key_values:
  536. raise ValueError(
  537. f'There should be {expected_num_past_key_values} past states. '
  538. f"{'2 (past / key) for cross attention. ' if expected_num_past_key_values == 4 else ''}"
  539. f'Got {len(past_key_value)} past key / value states')
  540. self_attn_past_key_value = past_key_value[:2]
  541. cross_attn_past_key_value = past_key_value[2:]
  542. else:
  543. self_attn_past_key_value, cross_attn_past_key_value = None, None
  544. self_attention_outputs = self.layer[0](
  545. hidden_states,
  546. attention_mask=attention_mask,
  547. position_bias=position_bias,
  548. layer_head_mask=layer_head_mask,
  549. past_key_value=self_attn_past_key_value,
  550. use_cache=use_cache,
  551. output_attentions=output_attentions,
  552. )
  553. hidden_states, present_key_value_state = self_attention_outputs[:2]
  554. attention_outputs = self_attention_outputs[
  555. 2:] # Keep self-attention outputs and relative position weights
  556. # clamp inf values to enable fp16 training
  557. if hidden_states.dtype == torch.float16 and torch.isinf(
  558. hidden_states).any():
  559. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  560. hidden_states = torch.clamp(
  561. hidden_states, min=-clamp_value, max=clamp_value)
  562. do_cross_attention = self.is_decoder and encoder_hidden_states is not None
  563. if do_cross_attention:
  564. # the actual query length is unknown for cross attention
  565. # if using past key value states. Need to inject it here
  566. if present_key_value_state is not None:
  567. query_length = present_key_value_state[0].shape[2]
  568. else:
  569. query_length = None
  570. cross_attention_outputs = self.layer[1](
  571. hidden_states,
  572. key_value_states=encoder_hidden_states,
  573. attention_mask=encoder_attention_mask,
  574. position_bias=encoder_decoder_position_bias,
  575. layer_head_mask=cross_attn_layer_head_mask,
  576. past_key_value=cross_attn_past_key_value,
  577. query_length=query_length,
  578. use_cache=use_cache,
  579. output_attentions=output_attentions,
  580. )
  581. hidden_states = cross_attention_outputs[0]
  582. # clamp inf values to enable fp16 training
  583. if hidden_states.dtype == torch.float16 and torch.isinf(
  584. hidden_states).any():
  585. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  586. hidden_states = torch.clamp(
  587. hidden_states, min=-clamp_value, max=clamp_value)
  588. # Combine self attn and cross attn key value states
  589. if present_key_value_state is not None:
  590. present_key_value_state = present_key_value_state + cross_attention_outputs[
  591. 1]
  592. # Keep cross-attention outputs and relative position weights
  593. attention_outputs = attention_outputs + cross_attention_outputs[2:]
  594. # Apply Feed Forward layer
  595. hidden_states = self.layer[-1](hidden_states)
  596. # clamp inf values to enable fp16 training
  597. if hidden_states.dtype == torch.float16 and torch.isinf(
  598. hidden_states).any():
  599. clamp_value = torch.finfo(hidden_states.dtype).max - 1000
  600. hidden_states = torch.clamp(
  601. hidden_states, min=-clamp_value, max=clamp_value)
  602. outputs = (hidden_states, )
  603. if use_cache:
  604. outputs = outputs + (present_key_value_state, ) + attention_outputs
  605. else:
  606. outputs = outputs + attention_outputs
  607. # hidden-states, present_key_value_states, (self-attention position
  608. # bias), (self-attention weights), (cross-attention position bias),
  609. # (cross-attention weights)
  610. return outputs
  611. class T5PreTrainedModel(TorchModel, PreTrainedModel):
  612. """
  613. An abstract class to handle weights initialization and a simple interface
  614. for downloading and loading pretrained models.
  615. """
  616. config_class = T5Config
  617. load_tf_weights = load_tf_weights_in_t5
  618. base_model_prefix = 'transformer'
  619. is_parallelizable = True
  620. supports_gradient_checkpointing = True
  621. def __init__(self, config, **kwargs):
  622. super().__init__(config.name_or_path, **kwargs)
  623. super(Model, self).__init__(config)
  624. @property
  625. def dummy_inputs(self):
  626. input_ids = torch.tensor(DUMMY_INPUTS)
  627. input_mask = torch.tensor(DUMMY_MASK)
  628. dummy_inputs = {
  629. 'decoder_input_ids': input_ids,
  630. 'input_ids': input_ids,
  631. 'decoder_attention_mask': input_mask,
  632. }
  633. return dummy_inputs
  634. def _init_weights(self, module):
  635. """Initialize the weights"""
  636. factor = self.config.initializer_factor # Used for testing weights initialization
  637. if isinstance(module, T5LayerNorm):
  638. module.weight.data.fill_(factor * 1.0)
  639. elif isinstance(module, T5Model):
  640. # Mesh TensorFlow embeddings initialization See
  641. # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L1624
  642. module.shared.weight.data.normal_(mean=0.0, std=factor * 1.0)
  643. elif isinstance(module, T5DenseReluDense):
  644. # Mesh TensorFlow FF initialization See
  645. # https://github.com/tensorflow/mesh/blob/master/mesh_tensorflow/transformer/transformer_layers.py#L56
  646. # and
  647. # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L89
  648. module.wi.weight.data.normal_(
  649. mean=0.0, std=factor * ((self.config.d_model)**-0.5))
  650. if hasattr(module.wi, 'bias') and module.wi.bias is not None:
  651. module.wi.bias.data.zero_()
  652. module.wo.weight.data.normal_(
  653. mean=0.0, std=factor * ((self.config.d_ff)**-0.5))
  654. if hasattr(module.wo, 'bias') and module.wo.bias is not None:
  655. module.wo.bias.data.zero_()
  656. elif isinstance(module, T5DenseGatedGeluDense):
  657. module.wi_0.weight.data.normal_(
  658. mean=0.0, std=factor * ((self.config.d_model)**-0.5))
  659. if hasattr(module.wi_0, 'bias') and module.wi_0.bias is not None:
  660. module.wi_0.bias.data.zero_()
  661. module.wi_1.weight.data.normal_(
  662. mean=0.0, std=factor * ((self.config.d_model)**-0.5))
  663. if hasattr(module.wi_1, 'bias') and module.wi_1.bias is not None:
  664. module.wi_1.bias.data.zero_()
  665. module.wo.weight.data.normal_(
  666. mean=0.0, std=factor * ((self.config.d_ff)**-0.5))
  667. if hasattr(module.wo, 'bias') and module.wo.bias is not None:
  668. module.wo.bias.data.zero_()
  669. elif isinstance(module, T5Attention):
  670. # Mesh TensorFlow attention initialization to avoid scaling before
  671. # softmax See
  672. # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/attention.py#L136
  673. d_model = self.config.d_model
  674. key_value_proj_dim = self.config.d_kv
  675. n_heads = self.config.num_heads
  676. module.q.weight.data.normal_(
  677. mean=0.0, std=factor * ((d_model * key_value_proj_dim)**-0.5))
  678. module.k.weight.data.normal_(
  679. mean=0.0, std=factor * (d_model**-0.5))
  680. module.v.weight.data.normal_(
  681. mean=0.0, std=factor * (d_model**-0.5))
  682. module.o.weight.data.normal_(
  683. mean=0.0, std=factor * ((n_heads * key_value_proj_dim)**-0.5))
  684. if module.has_relative_attention_bias:
  685. module.relative_attention_bias.weight.data.normal_(
  686. mean=0.0, std=factor * ((d_model)**-0.5))
  687. def _set_gradient_checkpointing(self, module, value=False):
  688. if isinstance(module, (T5Attention, T5Stack)):
  689. module.gradient_checkpointing = value
  690. def _shift_right(self, input_ids):
  691. decoder_start_token_id = self.config.decoder_start_token_id
  692. pad_token_id = self.config.pad_token_id
  693. assert (
  694. decoder_start_token_id is not None
  695. ), 'self.model.config.decoder_start_token_id has to be defined.'
  696. # shift inputs to the right
  697. if is_torch_fx_proxy(input_ids):
  698. # Item assignment is not supported natively for proxies.
  699. shifted_input_ids = torch.full(input_ids.shape[:-1] + (1, ),
  700. decoder_start_token_id)
  701. shifted_input_ids = torch.cat(
  702. [shifted_input_ids, input_ids[..., :-1]], dim=-1)
  703. else:
  704. shifted_input_ids = input_ids.new_zeros(input_ids.shape)
  705. shifted_input_ids[..., 1:] = input_ids[..., :-1].clone()
  706. shifted_input_ids[..., 0] = decoder_start_token_id
  707. assert pad_token_id is not None, 'self.model.config.pad_token_id has to be defined.'
  708. # replace possible -100 values in labels by `pad_token_id`
  709. shifted_input_ids.masked_fill_(shifted_input_ids == -100, pad_token_id)
  710. assert torch.all(shifted_input_ids >= 0).item(
  711. ), 'Verify that `shifted_input_ids` has only positive values'
  712. return shifted_input_ids
  713. @classmethod
  714. def _instantiate(cls, **kwargs):
  715. """Instantiate the model.
  716. Args:
  717. kwargs: Input args.
  718. model_dir: The model dir used to load the checkpoint and the
  719. label information. num_labels: An optional arg to tell the
  720. model how many classes to initialize.
  721. Method will call utils.parse_label_mapping
  722. if num_labels not supplied. If num_labels is
  723. not found, the model will use the default
  724. setting (2 classes).
  725. Returns:
  726. The loaded model, which is initialized by
  727. transformers.PreTrainedModel.from_pretrained
  728. """
  729. model_dir = kwargs.get('model_dir', None)
  730. if model_dir is None:
  731. config = T5Config(**kwargs)
  732. model = cls(config)
  733. else:
  734. model_kwargs = {}
  735. model = super(Model, cls).from_pretrained(
  736. pretrained_model_name_or_path=model_dir, **model_kwargs)
  737. model.model_dir = model_dir
  738. return model
  739. class T5Stack(T5PreTrainedModel):
  740. def __init__(self, config, embed_tokens=None):
  741. super().__init__(config)
  742. self.embed_tokens = embed_tokens
  743. self.is_decoder = config.is_decoder
  744. self.block = nn.ModuleList([
  745. T5Block(config, has_relative_attention_bias=bool(i == 0))
  746. for i in range(config.num_layers)
  747. ])
  748. self.final_layer_norm = T5LayerNorm(
  749. config.d_model, eps=config.layer_norm_epsilon)
  750. self.dropout = nn.Dropout(config.dropout_rate)
  751. # Initialize weights and apply final processing
  752. self.post_init()
  753. # Model parallel
  754. self.model_parallel = False
  755. self.device_map = None
  756. self.gradient_checkpointing = False
  757. def parallelize(self, device_map=None):
  758. r"""
  759. This is an experimental feature and is a subject to change at a
  760. moment's notice.
  761. Uses a device map to distribute attention modules of the model
  762. across several devices. If no device map is given, it will evenly
  763. distribute blocks across all devices.
  764. Args:
  765. device_map (`Dict[int, list]`, optional, defaults to None):
  766. A dictionary that maps attention modules to devices. Note
  767. that the embedding module and LMHead are always
  768. automatically mapped to the first device (for esoteric
  769. reasons). That means that the first device should have fewer
  770. attention modules mapped to it than other devices. For
  771. reference, the t5 models have the following number of
  772. attention modules:
  773. - t5-small: 6
  774. - t5-base: 12
  775. - t5-large: 24
  776. - t5-3b: 24
  777. - t5-11b: 24
  778. Example:
  779. >>> # Here is an example of a device map on a machine with 4 GPUs
  780. >>> # using t5-3b, which has a total of 24 attention modules:
  781. >>> model = T5ForConditionalGeneration.from_pretrained("t5-3b")
  782. >>> device_map = {
  783. >>> 0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9], 2: [10, 11, 12, 13, 14,
  784. >>> 15, 16], 3: [17, 18, 19, 20, 21, 22, 23],
  785. >>> }
  786. >>> model.parallelize(device_map)
  787. >>> # all of the parallelize methods in this file are the same
  788. """
  789. # Check validity of device_map
  790. self.device_map = (
  791. get_device_map(len(self.block), range(torch.cuda.device_count()))
  792. if device_map is None else device_map)
  793. assert_device_map(self.device_map, len(self.block))
  794. self.model_parallel = True
  795. self.first_device = 'cpu' if 'cpu' in self.device_map.keys(
  796. ) else 'cuda:' + str(min(self.device_map.keys()))
  797. self.last_device = 'cuda:' + str(max(self.device_map.keys()))
  798. # Load onto devices
  799. for k, v in self.device_map.items():
  800. for layer in v:
  801. cuda_device = 'cuda:' + str(k)
  802. self.block[layer] = self.block[layer].to(cuda_device)
  803. # Set embed_tokens to first layer
  804. self.embed_tokens = self.embed_tokens.to(self.first_device)
  805. # Set final layer norm to last device
  806. self.final_layer_norm = self.final_layer_norm.to(self.last_device)
  807. def deparallelize(self):
  808. r"""
  809. Moves the model to cpu from a model parallel state.
  810. Example:
  811. >>> # On a 4 GPU machine with t5-3b:
  812. >>> model = T5ForConditionalGeneration.from_pretrained("t5-3b")
  813. >>> device_map = {
  814. >>> 0: [0, 1, 2], 1: [3, 4, 5, 6, 7, 8, 9], 2: [10, 11, 12, 13, 14,
  815. >>> 15, 16], 3: [17, 18, 19, 20, 21, 22, 23],
  816. >>> }
  817. >>> model.parallelize(device_map)
  818. >>> # Splits the model across several devices model.deparallelize()
  819. >>> # Put the model back on cpu and
  820. >>> # cleans memory by calling torch.cuda.empty_cache()
  821. >>> # all of the deparallelize methods in this file are the same
  822. """
  823. self.model_parallel = False
  824. self.device_map = None
  825. self.first_device = 'cpu'
  826. self.last_device = 'cpu'
  827. for i in range(len(self.block)):
  828. self.block[i] = self.block[i].to('cpu')
  829. self.embed_tokens = self.embed_tokens.to('cpu')
  830. self.final_layer_norm = self.final_layer_norm.to('cpu')
  831. torch.cuda.empty_cache()
  832. def get_input_embeddings(self):
  833. return self.embed_tokens
  834. def set_input_embeddings(self, new_embeddings):
  835. self.embed_tokens = new_embeddings
  836. def forward(
  837. self,
  838. input_ids=None,
  839. attention_mask=None,
  840. encoder_hidden_states=None,
  841. encoder_attention_mask=None,
  842. inputs_embeds=None,
  843. head_mask=None,
  844. cross_attn_head_mask=None,
  845. past_key_values=None,
  846. use_cache=None,
  847. output_attentions=None,
  848. output_hidden_states=None,
  849. return_dict=None,
  850. ):
  851. # Model parallel
  852. if self.model_parallel:
  853. torch.cuda.set_device(self.first_device)
  854. self.embed_tokens = self.embed_tokens.to(self.first_device)
  855. use_cache = use_cache if use_cache is not None else self.config.use_cache
  856. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  857. output_hidden_states = (
  858. output_hidden_states if output_hidden_states is not None else
  859. self.config.output_hidden_states)
  860. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  861. if input_ids is not None and inputs_embeds is not None:
  862. err_msg_prefix = 'decoder_' if self.is_decoder else ''
  863. raise ValueError(
  864. f'You cannot specify both {err_msg_prefix}input_ids and {err_msg_prefix}inputs_embeds at the same time'
  865. )
  866. elif input_ids is not None:
  867. input_shape = input_ids.size()
  868. input_ids = input_ids.view(-1, input_shape[-1])
  869. elif inputs_embeds is not None:
  870. input_shape = inputs_embeds.size()[:-1]
  871. else:
  872. err_msg_prefix = 'decoder_' if self.is_decoder else ''
  873. raise ValueError(
  874. f'You have to specify either {err_msg_prefix}input_ids or {err_msg_prefix}inputs_embeds'
  875. )
  876. if inputs_embeds is None:
  877. assert self.embed_tokens is not None, 'You have to initialize the model with valid token embeddings'
  878. inputs_embeds = self.embed_tokens(input_ids)
  879. batch_size, seq_length = input_shape
  880. # required mask seq length can be calculated via length of past
  881. mask_seq_length = past_key_values[0][0].shape[
  882. 2] + seq_length if past_key_values is not None else seq_length
  883. if use_cache is True:
  884. assert self.is_decoder, f'`use_cache` can only be set to `True` if {self} is used as a decoder'
  885. if attention_mask is None:
  886. attention_mask = torch.ones(batch_size, mask_seq_length).to(
  887. inputs_embeds.device)
  888. if self.is_decoder and encoder_attention_mask is None and encoder_hidden_states is not None:
  889. encoder_seq_length = encoder_hidden_states.shape[1]
  890. encoder_attention_mask = torch.ones(
  891. batch_size,
  892. encoder_seq_length,
  893. device=inputs_embeds.device,
  894. dtype=torch.long)
  895. # initialize past_key_values with `None` if past does not exist
  896. if past_key_values is None:
  897. past_key_values = [None] * len(self.block)
  898. # We can provide a self-attention mask of dimensions [batch_size, from_seq_length, to_seq_length]
  899. # ourselves in which case we just need to make it broadcastable to all heads.
  900. extended_attention_mask = self.get_extended_attention_mask(
  901. attention_mask, input_shape, inputs_embeds.device)
  902. # If a 2D or 3D attention mask is provided for the cross-attention
  903. # we need to make broadcastable to [batch_size, num_heads, seq_length, seq_length]
  904. if self.is_decoder and encoder_hidden_states is not None:
  905. encoder_batch_size, encoder_sequence_length, _ = encoder_hidden_states.size(
  906. )
  907. encoder_hidden_shape = (encoder_batch_size,
  908. encoder_sequence_length)
  909. if encoder_attention_mask is None:
  910. encoder_attention_mask = torch.ones(
  911. encoder_hidden_shape, device=inputs_embeds.device)
  912. encoder_extended_attention_mask = self.invert_attention_mask(
  913. encoder_attention_mask)
  914. else:
  915. encoder_extended_attention_mask = None
  916. # Prepare head mask if needed
  917. head_mask = self.get_head_mask(head_mask, self.config.num_layers)
  918. cross_attn_head_mask = self.get_head_mask(cross_attn_head_mask,
  919. self.config.num_layers)
  920. present_key_value_states = () if use_cache else None
  921. all_hidden_states = () if output_hidden_states else None
  922. all_attentions = () if output_attentions else None
  923. all_cross_attentions = () if (output_attentions
  924. and self.is_decoder) else None
  925. position_bias = None
  926. encoder_decoder_position_bias = None
  927. hidden_states = self.dropout(inputs_embeds)
  928. for i, (layer_module,
  929. past_key_value) in enumerate(zip(self.block, past_key_values)):
  930. layer_head_mask = head_mask[i]
  931. cross_attn_layer_head_mask = cross_attn_head_mask[i]
  932. # Model parallel
  933. if self.model_parallel:
  934. torch.cuda.set_device(hidden_states.device)
  935. # Ensure that attention_mask is always on the same device as hidden_states
  936. if attention_mask is not None:
  937. attention_mask = attention_mask.to(hidden_states.device)
  938. if position_bias is not None:
  939. position_bias = position_bias.to(hidden_states.device)
  940. if encoder_hidden_states is not None:
  941. encoder_hidden_states = encoder_hidden_states.to(
  942. hidden_states.device)
  943. if encoder_extended_attention_mask is not None:
  944. encoder_extended_attention_mask = encoder_extended_attention_mask.to(
  945. hidden_states.device)
  946. if encoder_decoder_position_bias is not None:
  947. encoder_decoder_position_bias = encoder_decoder_position_bias.to(
  948. hidden_states.device)
  949. if layer_head_mask is not None:
  950. layer_head_mask = layer_head_mask.to(hidden_states.device)
  951. if cross_attn_layer_head_mask is not None:
  952. cross_attn_layer_head_mask = cross_attn_layer_head_mask.to(
  953. hidden_states.device)
  954. if output_hidden_states:
  955. all_hidden_states = all_hidden_states + (hidden_states, )
  956. if self.gradient_checkpointing and self.training:
  957. if use_cache:
  958. logger.warning(
  959. '`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...'
  960. )
  961. use_cache = False
  962. def create_custom_forward(module):
  963. def custom_forward(*inputs):
  964. return tuple(
  965. module(*inputs, use_cache, output_attentions))
  966. return custom_forward
  967. layer_outputs = checkpoint(
  968. create_custom_forward(layer_module),
  969. hidden_states,
  970. extended_attention_mask,
  971. position_bias,
  972. encoder_hidden_states,
  973. encoder_extended_attention_mask,
  974. encoder_decoder_position_bias,
  975. layer_head_mask,
  976. cross_attn_layer_head_mask,
  977. None, # past_key_value is always None with gradient checkpointing
  978. )
  979. else:
  980. layer_outputs = layer_module(
  981. hidden_states,
  982. attention_mask=extended_attention_mask,
  983. position_bias=position_bias,
  984. encoder_hidden_states=encoder_hidden_states,
  985. encoder_attention_mask=encoder_extended_attention_mask,
  986. encoder_decoder_position_bias=encoder_decoder_position_bias,
  987. layer_head_mask=layer_head_mask,
  988. cross_attn_layer_head_mask=cross_attn_layer_head_mask,
  989. past_key_value=past_key_value,
  990. use_cache=use_cache,
  991. output_attentions=output_attentions,
  992. )
  993. # layer_outputs is a tuple with: hidden-states, key-value-states,
  994. # (self-attention position bias), (self-attention weights),
  995. # (cross-attention position bias), (cross-attention weights)
  996. if use_cache is False:
  997. layer_outputs = layer_outputs[:1] + (
  998. None, ) + layer_outputs[1:]
  999. hidden_states, present_key_value_state = layer_outputs[:2]
  1000. # We share the position biases between the layers - the first layer
  1001. # store them layer_outputs = hidden-states, key-value-states
  1002. # (self-attention position bias), (self-attention weights),
  1003. # (cross-attention position bias), (cross-attention weights)
  1004. position_bias = layer_outputs[2]
  1005. if self.is_decoder and encoder_hidden_states is not None:
  1006. encoder_decoder_position_bias = layer_outputs[
  1007. 4 if output_attentions else 3]
  1008. # append next layer key value states
  1009. if use_cache:
  1010. present_key_value_states = present_key_value_states + (
  1011. present_key_value_state, )
  1012. if output_attentions:
  1013. all_attentions = all_attentions + (layer_outputs[3], )
  1014. if self.is_decoder:
  1015. all_cross_attentions = all_cross_attentions + (
  1016. layer_outputs[5], )
  1017. # Model Parallel: If it's the last layer for that device, put things on the next device
  1018. if self.model_parallel:
  1019. for k, v in self.device_map.items():
  1020. if i == v[-1] and 'cuda:' + str(k) != self.last_device:
  1021. hidden_states = hidden_states.to('cuda:' + str(k + 1))
  1022. hidden_states = self.final_layer_norm(hidden_states)
  1023. hidden_states = self.dropout(hidden_states)
  1024. # Add last layer
  1025. if output_hidden_states:
  1026. all_hidden_states = all_hidden_states + (hidden_states, )
  1027. if not return_dict:
  1028. return tuple(v for v in [
  1029. hidden_states,
  1030. present_key_value_states,
  1031. all_hidden_states,
  1032. all_attentions,
  1033. all_cross_attentions,
  1034. ] if v is not None)
  1035. return BaseModelOutputWithPastAndCrossAttentions(
  1036. last_hidden_state=hidden_states,
  1037. past_key_values=present_key_value_states,
  1038. hidden_states=all_hidden_states,
  1039. attentions=all_attentions,
  1040. cross_attentions=all_cross_attentions,
  1041. )
  1042. # Warning message for FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  1043. __HEAD_MASK_WARNING_MSG = """
  1044. The input argument `head_mask` was split into two arguments `head_mask` and
  1045. `decoder_head_mask`. Currently, `decoder_head_mask` is set to copy `head_mask`,
  1046. but this feature is deprecated and will be removed in future versions. If you do
  1047. not want to use any `decoder_head_mask` now, please set `decoder_head_mask =
  1048. torch.ones(num_layers, num_heads)`.
  1049. """
  1050. @MODELS.register_module(group_key=Tasks.backbone, module_name=Models.T5)
  1051. class T5Model(T5PreTrainedModel):
  1052. """The bare T5 Model transformer outputting raw hidden-states without any
  1053. specific head on top.
  1054. The T5 model was proposed in [Exploring the Limits of Transfer Learning with
  1055. a Unified Text-to-Text Transformer](https://arxiv.org/abs/1910.10683) by
  1056. Colin Raffel, Noam Shazeer, Adam Roberts, Katherine Lee, Sharan Narang,
  1057. Michael Matena, Yanqi Zhou, Wei Li, Peter J. Liu. It's an encoder decoder
  1058. transformer pre-trained in a text-to-text denoising generative setting.
  1059. This model inherits from [`PreTrainedModel`]. Check the superclass
  1060. documentation for the generic methods the library implements for all its
  1061. model (such as downloading or saving, resizing the input embeddings, pruning
  1062. heads etc.)
  1063. This model is also a PyTorch
  1064. [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module)
  1065. subclass. Use it as a regular PyTorch Module and refer to the PyTorch
  1066. documentation for all matter related to general usage and behavior.
  1067. Parameters:
  1068. config ([`T5Config`]): Model configuration class with all the parameters
  1069. of the model.
  1070. Initializing with a config file does not load the weights associated
  1071. with the model, only the configuration. Check out the
  1072. [`~PreTrainedModel.from_pretrained`] method to load the model
  1073. weights.
  1074. """
  1075. _keys_to_ignore_on_load_missing = [
  1076. r'encoder\.embed_tokens\.weight',
  1077. r'decoder\.embed_tokens\.weight',
  1078. ]
  1079. _keys_to_ignore_on_load_unexpected = [
  1080. r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight',
  1081. ]
  1082. def __init__(self, config: T5Config):
  1083. super().__init__(config)
  1084. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  1085. encoder_config = copy.deepcopy(config)
  1086. encoder_config.is_decoder = False
  1087. encoder_config.use_cache = False
  1088. encoder_config.is_encoder_decoder = False
  1089. self.encoder = T5Stack(encoder_config, self.shared)
  1090. decoder_config = copy.deepcopy(config)
  1091. decoder_config.is_decoder = True
  1092. decoder_config.is_encoder_decoder = False
  1093. decoder_config.num_layers = config.num_decoder_layers
  1094. self.decoder = T5Stack(decoder_config, self.shared)
  1095. # Initialize weights and apply final processing
  1096. self.post_init()
  1097. # Model parallel
  1098. self.model_parallel = False
  1099. self.device_map = None
  1100. def parallelize(self, device_map=None):
  1101. self.device_map = (
  1102. get_device_map(
  1103. len(self.encoder.block), range(torch.cuda.device_count()))
  1104. if device_map is None else device_map)
  1105. assert_device_map(self.device_map, len(self.encoder.block))
  1106. self.encoder.parallelize(self.device_map)
  1107. self.decoder.parallelize(self.device_map)
  1108. self.model_parallel = True
  1109. def deparallelize(self):
  1110. self.encoder.deparallelize()
  1111. self.decoder.deparallelize()
  1112. self.encoder = self.encoder.to('cpu')
  1113. self.decoder = self.decoder.to('cpu')
  1114. self.model_parallel = False
  1115. self.device_map = None
  1116. torch.cuda.empty_cache()
  1117. def get_input_embeddings(self):
  1118. return self.shared
  1119. def set_input_embeddings(self, new_embeddings):
  1120. self.shared = new_embeddings
  1121. self.encoder.set_input_embeddings(new_embeddings)
  1122. self.decoder.set_input_embeddings(new_embeddings)
  1123. def get_encoder(self):
  1124. return self.encoder
  1125. def get_decoder(self):
  1126. return self.decoder
  1127. def _prune_heads(self, heads_to_prune):
  1128. """
  1129. Prunes heads of the model. heads_to_prune: dict of {layer_num: list of
  1130. heads to prune in this layer} See base class PreTrainedModel
  1131. """
  1132. for layer, heads in heads_to_prune.items():
  1133. self.encoder.layer[layer].attention.prune_heads(heads)
  1134. def forward(
  1135. self,
  1136. input_ids: Optional[torch.LongTensor] = None,
  1137. attention_mask: Optional[torch.FloatTensor] = None,
  1138. decoder_input_ids: Optional[torch.LongTensor] = None,
  1139. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  1140. head_mask: Optional[torch.FloatTensor] = None,
  1141. decoder_head_mask: Optional[torch.FloatTensor] = None,
  1142. cross_attn_head_mask: Optional[torch.Tensor] = None,
  1143. encoder_outputs: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  1144. past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None,
  1145. inputs_embeds: Optional[torch.Tensor] = None,
  1146. decoder_inputs_embeds: Optional[torch.Tensor] = None,
  1147. use_cache: Optional[bool] = None,
  1148. output_attentions: Optional[bool] = None,
  1149. output_hidden_states: Optional[bool] = None,
  1150. return_dict: Optional[bool] = None,
  1151. ) -> Union[Tuple[torch.FloatTensor], Seq2SeqModelOutput]:
  1152. r"""
  1153. Args:
  1154. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  1155. Indices of input sequence tokens in the vocabulary. T5 is a model
  1156. with relative position embeddings so you should be able to pad the
  1157. inputs on both the right and the left.
  1158. Indices can be obtained using [`T5Tokenizer`]. See
  1159. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
  1160. for detail.
  1161. [What are input IDs?](../glossary#input-ids)
  1162. To know more on how to prepare `input_ids` for pretraining take a
  1163. look a [T5 Training](./t5#training).
  1164. attention_mask (`torch.FloatTensor` of shape `(batch_size,
  1165. sequence_length)`, *optional*):
  1166. Mask to avoid performing attention on padding token indices. Mask
  1167. values selected in `[0, 1]`:
  1168. - 1 for tokens that are **not masked**,
  1169. - 0 for tokens that are **masked**.
  1170. [What are attention masks?](../glossary#attention-mask)
  1171. decoder_input_ids (`torch.LongTensor` of shape `(batch_size,
  1172. target_sequence_length)`, *optional*):
  1173. Indices of decoder input sequence tokens in the vocabulary.
  1174. Indices can be obtained using [`T5Tokenizer`]. See
  1175. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
  1176. for details.
  1177. [What are decoder input IDs?](../glossary#decoder-input-ids)
  1178. T5 uses the `pad_token_id` as the starting token for
  1179. `decoder_input_ids` generation. If `past_key_values` is used,
  1180. optionally only the last `decoder_input_ids` have to be input (see
  1181. `past_key_values`).
  1182. To know more on how to prepare `decoder_input_ids` for pretraining
  1183. take a look at [T5 Training](./t5#training).
  1184. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size,
  1185. target_sequence_length)`, *optional*):
  1186. Default behavior: generate a tensor that ignores pad tokens in
  1187. `decoder_input_ids`. Causal mask will also be used by default.
  1188. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers,
  1189. num_heads)`, *optional*):
  1190. Mask to nullify selected heads of the self-attention modules in the
  1191. encoder. Mask values selected in `[0, 1]`:
  1192. - 1 indicates the head is **not masked**,
  1193. - 0 indicates the head is **masked**.
  1194. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or
  1195. `(num_layers, num_heads)`, *optional*):
  1196. Mask to nullify selected heads of the self-attention modules in the
  1197. decoder. Mask values selected in `[0, 1]`:
  1198. - 1 indicates the head is **not masked**,
  1199. - 0 indicates the head is **masked**.
  1200. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or
  1201. `(num_layers, num_heads)`, *optional*):
  1202. Mask to nullify selected heads of the cross-attention modules in
  1203. the decoder. Mask values selected in `[0, 1]`:
  1204. - 1 indicates the head is **not masked**,
  1205. - 0 indicates the head is **masked**.
  1206. encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
  1207. Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*,
  1208. `optional`: *attentions*) `last_hidden_state` of shape `(batch_size,
  1209. sequence_length, hidden_size)` is a sequence of hidden states at the
  1210. output of the last layer of the encoder. Used in the cross-attention
  1211. of the decoder.
  1212. past_key_values (`tuple(tuple(torch.FloatTensor))` of length
  1213. `config.n_layers` with each tuple having 4 tensors of shape
  1214. `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
  1215. Contains precomputed key and value hidden states of the attention
  1216. blocks. Can be used to speed up decoding.
  1217. If `past_key_values` are used, the user can optionally input only
  1218. the last `decoder_input_ids` (those that don't have their past key
  1219. value states given to this model) of shape `(batch_size, 1)` instead
  1220. of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  1221. inputs_embeds (`torch.FloatTensor` of shape `(batch_size,
  1222. sequence_length, hidden_size)`, *optional*):
  1223. Optionally, instead of passing `input_ids` you can choose to
  1224. directly pass an embedded representation. This is useful if you want
  1225. more control over how to convert `input_ids` indices into associated
  1226. vectors than the model's internal embedding lookup matrix.
  1227. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size,
  1228. target_sequence_length, hidden_size)`, *optional*):
  1229. Optionally, instead of passing `decoder_input_ids` you can choose to
  1230. directly pass an embedded representation. If `past_key_values` is
  1231. used, optionally only the last `decoder_inputs_embeds` have to be
  1232. input (see `past_key_values`). This is useful if you want more
  1233. control over how to convert `decoder_input_ids` indices into
  1234. associated vectors than the model's internal embedding lookup
  1235. matrix.
  1236. If `decoder_input_ids` and `decoder_inputs_embeds` are both unset,
  1237. `decoder_inputs_embeds` takes the value of `inputs_embeds`.
  1238. use_cache (`bool`, *optional*):
  1239. If set to `True`, `past_key_values` key value states are returned
  1240. and can be used to speed up decoding (see `past_key_values`).
  1241. output_attentions (`bool`, *optional*):
  1242. Whether or not to return the attentions tensors of all attention
  1243. layers. See `attentions` under returned tensors for more detail.
  1244. output_hidden_states (`bool`, *optional*):
  1245. Whether or not to return the hidden states of all layers. See
  1246. `hidden_states` under returned tensors for more detail.
  1247. return_dict (`bool`, *optional*):
  1248. Whether or not to return a [`~utils.ModelOutput`] instead of a plain
  1249. tuple.
  1250. Returns:
  1251. Example:
  1252. >>> from transformers import T5Tokenizer, T5Model
  1253. >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
  1254. >>> model = T5Model.from_pretrained("t5-small")
  1255. >>> input_ids = tokenizer(
  1256. ... "Studies have been shown that owning a dog is good for you", return_tensors="pt"
  1257. >>> ).input_ids # Batch size 1
  1258. >>> decoder_input_ids = tokenizer("Studies show that", return_tensors="pt").input_ids # Batch size 1
  1259. >>> # forward pass
  1260. >>> outputs = model(input_ids=input_ids, decoder_input_ids=decoder_input_ids)
  1261. >>> last_hidden_states = outputs.last_hidden_state
  1262. """
  1263. use_cache = use_cache if use_cache is not None else self.config.use_cache
  1264. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  1265. # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  1266. if head_mask is not None and decoder_head_mask is None:
  1267. if self.config.num_layers == self.config.num_decoder_layers:
  1268. warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
  1269. decoder_head_mask = head_mask
  1270. # Encode if needed (training, first prediction pass)
  1271. if encoder_outputs is None:
  1272. encoder_outputs = self.encoder(
  1273. input_ids=input_ids,
  1274. attention_mask=attention_mask,
  1275. inputs_embeds=inputs_embeds,
  1276. head_mask=head_mask,
  1277. output_attentions=output_attentions,
  1278. output_hidden_states=output_hidden_states,
  1279. return_dict=return_dict,
  1280. )
  1281. elif return_dict and not isinstance(encoder_outputs,
  1282. AttentionBackboneModelOutput):
  1283. encoder_outputs = AttentionBackboneModelOutput(
  1284. last_hidden_state=encoder_outputs[0],
  1285. hidden_states=encoder_outputs[1]
  1286. if len(encoder_outputs) > 1 else None,
  1287. attentions=encoder_outputs[2]
  1288. if len(encoder_outputs) > 2 else None,
  1289. )
  1290. hidden_states = encoder_outputs[0]
  1291. if self.model_parallel:
  1292. torch.cuda.set_device(self.decoder.first_device)
  1293. # Set device for model parallelism
  1294. if self.model_parallel:
  1295. torch.cuda.set_device(self.decoder.first_device)
  1296. hidden_states = hidden_states.to(self.decoder.first_device)
  1297. if decoder_input_ids is not None:
  1298. decoder_input_ids = decoder_input_ids.to(
  1299. self.decoder.first_device)
  1300. if attention_mask is not None:
  1301. attention_mask = attention_mask.to(self.decoder.first_device)
  1302. if decoder_attention_mask is not None:
  1303. decoder_attention_mask = decoder_attention_mask.to(
  1304. self.decoder.first_device)
  1305. # Decode
  1306. decoder_outputs = self.decoder(
  1307. input_ids=decoder_input_ids,
  1308. attention_mask=decoder_attention_mask,
  1309. inputs_embeds=decoder_inputs_embeds,
  1310. past_key_values=past_key_values,
  1311. encoder_hidden_states=hidden_states,
  1312. encoder_attention_mask=attention_mask,
  1313. head_mask=decoder_head_mask,
  1314. cross_attn_head_mask=cross_attn_head_mask,
  1315. use_cache=use_cache,
  1316. output_attentions=output_attentions,
  1317. output_hidden_states=output_hidden_states,
  1318. return_dict=return_dict,
  1319. )
  1320. if not return_dict:
  1321. return decoder_outputs + encoder_outputs
  1322. return Seq2SeqModelOutput(
  1323. last_hidden_state=decoder_outputs.last_hidden_state,
  1324. past_key_values=decoder_outputs.past_key_values,
  1325. decoder_hidden_states=decoder_outputs.hidden_states,
  1326. decoder_attentions=decoder_outputs.attentions,
  1327. cross_attentions=decoder_outputs.cross_attentions,
  1328. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  1329. encoder_hidden_states=encoder_outputs.hidden_states,
  1330. encoder_attentions=encoder_outputs.attentions,
  1331. )