text2text_generation.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Copyright 2018 Mesh TensorFlow authors, T5 Authors and HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import copy
  16. import warnings
  17. from typing import Optional, Tuple, Union
  18. import torch
  19. from torch import nn
  20. from torch.nn import CrossEntropyLoss
  21. from transformers.utils.model_parallel_utils import (assert_device_map,
  22. get_device_map)
  23. from modelscope.metainfo import Models
  24. from modelscope.models.builder import MODELS
  25. from modelscope.outputs import (AttentionBackboneModelOutput, Seq2SeqLMOutput,
  26. TokenGeneratorOutput)
  27. from modelscope.utils.constant import Tasks
  28. from modelscope.utils.logger import get_logger
  29. from .backbone import T5PreTrainedModel, T5Stack
  30. from .configuration import T5Config
# Module-level logger obtained from modelscope's `get_logger()`.
logger = get_logger()

# Text of the FutureWarning emitted when a caller passes `head_mask` but no
# `decoder_head_mask`: `head_mask` was separated into two input args,
# `head_mask` and `decoder_head_mask`.
#
# NOTE(review): this name starts with two underscores and does not end with
# two, so a bare reference to it written inside a class body is rewritten by
# Python's private name mangling to `_<ClassName>__HEAD_MASK_WARNING_MSG`.
# From inside a class, read it by string key (e.g. via `globals()`) instead.
__HEAD_MASK_WARNING_MSG = """
The input argument `head_mask` was split into two arguments `head_mask` and
`decoder_head_mask`. Currently, `decoder_head_mask` is set to copy `head_mask`,
but this feature is deprecated and will be removed in future versions. If you do
not want to use any `decoder_head_mask` now, please set `decoder_head_mask =
torch.ones(num_layers, num_heads)`.
"""
  40. @MODELS.register_module(
  41. group_key=Tasks.text2text_generation,
  42. module_name=Models.T5,
  43. )
  44. class T5ForConditionalGeneration(T5PreTrainedModel):
  45. _keys_to_ignore_on_load_missing = [
  46. r'encoder\.embed_tokens\.weight',
  47. r'decoder\.embed_tokens\.weight',
  48. r'lm_head\.weight',
  49. ]
  50. _keys_to_ignore_on_load_unexpected = [
  51. r'decoder\.block\.0\.layer\.1\.EncDecAttention\.relative_attention_bias\.weight',
  52. ]
  53. def __init__(self, config: T5Config, device_map=None, **kwargs):
  54. super().__init__(config)
  55. self.model_dim = config.d_model
  56. self.shared = nn.Embedding(config.vocab_size, config.d_model)
  57. encoder_config = copy.deepcopy(config)
  58. encoder_config.is_decoder = False
  59. encoder_config.use_cache = False
  60. encoder_config.is_encoder_decoder = False
  61. self.encoder = T5Stack(encoder_config, self.shared)
  62. decoder_config = copy.deepcopy(config)
  63. decoder_config.is_decoder = True
  64. decoder_config.is_encoder_decoder = False
  65. decoder_config.num_layers = config.num_decoder_layers
  66. self.decoder = T5Stack(decoder_config, self.shared)
  67. self.lm_head = nn.Linear(config.d_model, config.vocab_size, bias=False)
  68. # Initialize weights and apply final processing
  69. self.post_init()
  70. # Model parallel
  71. self.model_parallel = False
  72. if device_map == 'auto':
  73. self.parallelize()
  74. def parallelize(self, device_map=None):
  75. self.device_map = (
  76. get_device_map(
  77. len(self.encoder.block), range(torch.cuda.device_count()))
  78. if device_map is None else device_map)
  79. assert_device_map(self.device_map, len(self.encoder.block))
  80. self.encoder.parallelize(self.device_map)
  81. self.decoder.parallelize(self.device_map)
  82. self.lm_head = self.lm_head.to(self.decoder.first_device)
  83. self.model_parallel = True
  84. def deparallelize(self):
  85. self.encoder.deparallelize()
  86. self.decoder.deparallelize()
  87. self.encoder = self.encoder.to('cpu')
  88. self.decoder = self.decoder.to('cpu')
  89. self.lm_head = self.lm_head.to('cpu')
  90. self.model_parallel = False
  91. self.device_map = None
  92. torch.cuda.empty_cache()
  93. def get_input_embeddings(self):
  94. return self.shared
  95. def set_input_embeddings(self, new_embeddings):
  96. self.shared = new_embeddings
  97. self.encoder.set_input_embeddings(new_embeddings)
  98. self.decoder.set_input_embeddings(new_embeddings)
  99. def set_output_embeddings(self, new_embeddings):
  100. self.lm_head = new_embeddings
  101. def get_output_embeddings(self):
  102. return self.lm_head
  103. def get_encoder(self):
  104. return self.encoder
  105. def get_decoder(self):
  106. return self.decoder
  107. def forward(self,
  108. input_ids: Optional[torch.LongTensor] = None,
  109. attention_mask: Optional[torch.FloatTensor] = None,
  110. decoder_input_ids: Optional[torch.LongTensor] = None,
  111. decoder_attention_mask: Optional[torch.BoolTensor] = None,
  112. head_mask: Optional[torch.FloatTensor] = None,
  113. decoder_head_mask: Optional[torch.FloatTensor] = None,
  114. cross_attn_head_mask: Optional[torch.Tensor] = None,
  115. encoder_outputs: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  116. past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  117. inputs_embeds: Optional[torch.FloatTensor] = None,
  118. decoder_inputs_embeds: Optional[torch.FloatTensor] = None,
  119. labels: Optional[torch.LongTensor] = None,
  120. use_cache: Optional[bool] = None,
  121. output_attentions: Optional[bool] = None,
  122. output_hidden_states: Optional[bool] = None,
  123. return_dict: Optional[bool] = None,
  124. **kwargs) -> Union[Tuple[torch.FloatTensor], Seq2SeqLMOutput]:
  125. r"""
  126. Args:
  127. input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
  128. Indices of input sequence tokens in the vocabulary. T5 is a model
  129. with relative position embeddings so you should be able to pad the
  130. inputs on both the right and the left.
  131. Indices can be obtained using [`T5Tokenizer`]. See
  132. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
  133. for detail.
  134. [What are input IDs?](../glossary#input-ids)
  135. To know more on how to prepare `input_ids` for pretraining take a
  136. look a [T5 Training](./t5#training).
  137. attention_mask (`torch.FloatTensor` of shape `(batch_size,sequence_length)`, *optional*):
  138. Mask to avoid performing attention on padding token indices. Mask
  139. values selected in `[0, 1]`:
  140. - 1 for tokens that are **not masked**,
  141. - 0 for tokens that are **masked**.
  142. [What are attention masks?](../glossary#attention-mask)
  143. decoder_input_ids (`torch.LongTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  144. Indices of decoder input sequence tokens in the vocabulary.
  145. Indices can be obtained using [`T5Tokenizer`]. See
  146. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`]
  147. for details.
  148. [What are decoder input IDs?](../glossary#decoder-input-ids)
  149. T5 uses the `pad_token_id` as the starting token for
  150. `decoder_input_ids` generation. If `past_key_values` is used,
  151. optionally only the last `decoder_input_ids` have to be input (see
  152. `past_key_values`).
  153. To know more on how to prepare `decoder_input_ids` for pretraining
  154. take a look at [T5 Training](./t5#training).
  155. decoder_attention_mask (`torch.BoolTensor` of shape `(batch_size, target_sequence_length)`, *optional*):
  156. Default behavior: generate a tensor that ignores pad tokens in
  157. `decoder_input_ids`. Causal mask will also be used by default.
  158. head_mask (`torch.FloatTensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  159. Mask to nullify selected heads of the self-attention modules in the
  160. encoder. Mask values selected in `[0, 1]`:
  161. - 1 indicates the head is **not masked**,
  162. - 0 indicates the head is **masked**.
  163. decoder_head_mask (`torch.FloatTensor` of shape `(num_heads,)` or
  164. `(num_layers, num_heads)`, *optional*):
  165. Mask to nullify selected heads of the self-attention modules in the
  166. decoder. Mask values selected in `[0, 1]`:
  167. - 1 indicates the head is **not masked**,
  168. - 0 indicates the head is **masked**.
  169. cross_attn_head_mask (`torch.Tensor` of shape `(num_heads,)` or `(num_layers, num_heads)`, *optional*):
  170. Mask to nullify selected heads of the cross-attention modules in
  171. the decoder. Mask values selected in `[0, 1]`:
  172. - 1 indicates the head is **not masked**,
  173. - 0 indicates the head is **masked**.
  174. encoder_outputs (`tuple(tuple(torch.FloatTensor)`, *optional*):
  175. Tuple consists of (`last_hidden_state`, `optional`: *hidden_states*,
  176. `optional`: *attentions*) `last_hidden_state` of shape `(batch_size,
  177. sequence_length, hidden_size)` is a sequence of hidden states at the
  178. output of the last layer of the encoder. Used in the cross-attention
  179. of the decoder.
  180. past_key_values (`tuple(tuple(torch.FloatTensor))` of length
  181. `config.n_layers` with each tuple having 4 tensors of shape
  182. `(batch_size, num_heads, sequence_length - 1, embed_size_per_head)`):
  183. Contains precomputed key and value hidden states of the attention
  184. blocks. Can be used to speed up decoding.
  185. If `past_key_values` are used, the user can optionally input only
  186. the last `decoder_input_ids` (those that don't have their past key
  187. value states given to this model) of shape `(batch_size, 1)` instead
  188. of all `decoder_input_ids` of shape `(batch_size, sequence_length)`.
  189. inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
  190. Optionally, instead of passing `input_ids` you can choose to
  191. directly pass an embedded representation. This is useful if you want
  192. more control over how to convert `input_ids` indices into associated
  193. vectors than the model's internal embedding lookup matrix.
  194. decoder_inputs_embeds (`torch.FloatTensor` of shape `(batch_size, target_sequence_length, hidden_size)`,
  195. *optional*):
  196. Optionally, instead of passing `decoder_input_ids` you can choose to
  197. directly pass an embedded representation. If `past_key_values` is
  198. used, optionally only the last `decoder_inputs_embeds` have to be
  199. input (see `past_key_values`). This is useful if you want more
  200. control over how to convert `decoder_input_ids` indices into
  201. associated vectors than the model's internal embedding lookup
  202. matrix.
  203. If `decoder_input_ids` and `decoder_inputs_embeds` are both unset,
  204. `decoder_inputs_embeds` takes the value of `inputs_embeds`.
  205. use_cache (`bool`, *optional*):
  206. If set to `True`, `past_key_values` key value states are returned
  207. and can be used to speed up decoding (see `past_key_values`).
  208. output_attentions (`bool`, *optional*):
  209. Whether or not to return the attentions tensors of all attention
  210. layers. See `attentions` under returned tensors for more detail.
  211. output_hidden_states (`bool`, *optional*):
  212. Whether or not to return the hidden states of all layers. See
  213. `hidden_states` under returned tensors for more detail.
  214. return_dict (`bool`, *optional*):
  215. Whether or not to return a [`~utils.ModelOutput`] instead of a plain
  216. tuple.
  217. labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
  218. Labels for computing the sequence classification/regression loss.
  219. Indices should be in `[-100, 0, ..., config.vocab_size - 1]`. All
  220. labels set to `-100` are ignored (masked), the loss is only computed
  221. for labels in `[0, ..., config.vocab_size]`
  222. Returns:
  223. Examples:
  224. >>> from transformers import T5Tokenizer, T5ForConditionalGeneration
  225. >>> tokenizer = T5Tokenizer.from_pretrained("t5-small")
  226. >>> model = T5ForConditionalGeneration.from_pretrained("t5-small")
  227. >>> # training
  228. >>> input_ids = tokenizer("The <extra_id_0> walks in <extra_id_1> park", return_tensors="pt").input_ids
  229. >>> labels = tokenizer("<extra_id_0> cute dog <extra_id_1> the <extra_id_2>", return_tensors="pt").input_ids
  230. >>> outputs = model(input_ids=input_ids, labels=labels)
  231. >>> loss = outputs.loss
  232. >>> logits = outputs.logits
  233. >>> # inference
  234. >>> input_ids = tokenizer(
  235. ... "summarize: studies have shown that owning a dog is good for you", return_tensors="pt"
  236. >>> ).input_ids # Batch size 1
  237. >>> outputs = model.generate(input_ids)
  238. >>> print(tokenizer.decode(outputs[0], skip_special_tokens=True))
  239. >>> # studies have shown that owning a dog is good for you.
  240. """
  241. use_cache = use_cache if use_cache is not None else self.config.use_cache
  242. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  243. # FutureWarning: head_mask was separated into two input args - head_mask, decoder_head_mask
  244. if head_mask is not None and decoder_head_mask is None:
  245. if self.config.num_layers == self.config.num_decoder_layers:
  246. warnings.warn(__HEAD_MASK_WARNING_MSG, FutureWarning)
  247. decoder_head_mask = head_mask
  248. # Encode if needed (training, first prediction pass)
  249. if encoder_outputs is None:
  250. # Convert encoder inputs in embeddings if needed
  251. encoder_outputs = self.encoder(
  252. input_ids=input_ids,
  253. attention_mask=attention_mask,
  254. inputs_embeds=inputs_embeds,
  255. head_mask=head_mask,
  256. output_attentions=output_attentions,
  257. output_hidden_states=output_hidden_states,
  258. return_dict=return_dict,
  259. )
  260. elif return_dict and not isinstance(encoder_outputs,
  261. AttentionBackboneModelOutput):
  262. encoder_outputs = AttentionBackboneModelOutput(
  263. last_hidden_state=encoder_outputs[0],
  264. hidden_states=encoder_outputs[1]
  265. if len(encoder_outputs) > 1 else None,
  266. attentions=encoder_outputs[2]
  267. if len(encoder_outputs) > 2 else None,
  268. )
  269. hidden_states = encoder_outputs[0]
  270. if self.model_parallel:
  271. torch.cuda.set_device(self.decoder.first_device)
  272. if labels is not None and decoder_input_ids is None and decoder_inputs_embeds is None:
  273. # get decoder inputs from shifting lm labels to the right
  274. decoder_input_ids = self._shift_right(labels)
  275. # Set device for model parallelism
  276. if self.model_parallel:
  277. torch.cuda.set_device(self.decoder.first_device)
  278. hidden_states = hidden_states.to(self.decoder.first_device)
  279. if decoder_input_ids is not None:
  280. decoder_input_ids = decoder_input_ids.to(
  281. self.decoder.first_device)
  282. if attention_mask is not None:
  283. attention_mask = attention_mask.to(self.decoder.first_device)
  284. if decoder_attention_mask is not None:
  285. decoder_attention_mask = decoder_attention_mask.to(
  286. self.decoder.first_device)
  287. # Decode
  288. decoder_outputs = self.decoder(
  289. input_ids=decoder_input_ids,
  290. attention_mask=decoder_attention_mask,
  291. inputs_embeds=decoder_inputs_embeds,
  292. past_key_values=past_key_values,
  293. encoder_hidden_states=hidden_states,
  294. encoder_attention_mask=attention_mask,
  295. head_mask=decoder_head_mask,
  296. cross_attn_head_mask=cross_attn_head_mask,
  297. use_cache=use_cache,
  298. output_attentions=output_attentions,
  299. output_hidden_states=output_hidden_states,
  300. return_dict=return_dict,
  301. )
  302. sequence_output = decoder_outputs[0]
  303. # Set device for model parallelism
  304. if self.model_parallel:
  305. torch.cuda.set_device(self.encoder.first_device)
  306. self.lm_head = self.lm_head.to(self.encoder.first_device)
  307. sequence_output = sequence_output.to(self.lm_head.weight.device)
  308. if self.config.tie_word_embeddings:
  309. # Rescale output before projecting on vocab See
  310. # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/transformer/transformer.py#L586
  311. sequence_output = sequence_output * (self.model_dim**-0.5)
  312. lm_logits = self.lm_head(sequence_output)
  313. loss = None
  314. if labels is not None:
  315. loss_fct = CrossEntropyLoss(ignore_index=-100)
  316. loss = loss_fct(
  317. lm_logits.view(-1, lm_logits.size(-1)), labels.view(-1))
  318. # TODO(thom): Add z_loss
  319. # https://github.com/tensorflow/mesh/blob/fa19d69eafc9a482aff0b59ddd96b025c0cb207d/mesh_tensorflow/layers.py#L666
  320. if not return_dict:
  321. output = (lm_logits, ) + decoder_outputs[1:] + encoder_outputs
  322. return ((loss, ) + output) if loss is not None else output
  323. return Seq2SeqLMOutput(
  324. loss=loss,
  325. logits=lm_logits,
  326. past_key_values=decoder_outputs.past_key_values,
  327. decoder_hidden_states=decoder_outputs.hidden_states,
  328. decoder_attentions=decoder_outputs.attentions,
  329. cross_attentions=decoder_outputs.cross_attentions,
  330. encoder_last_hidden_state=encoder_outputs.last_hidden_state,
  331. encoder_hidden_states=encoder_outputs.hidden_states,
  332. encoder_attentions=encoder_outputs.attentions,
  333. )
  334. def prepare_inputs_for_generation(self,
  335. input_ids,
  336. past=None,
  337. attention_mask=None,
  338. head_mask=None,
  339. decoder_head_mask=None,
  340. cross_attn_head_mask=None,
  341. use_cache=None,
  342. encoder_outputs=None,
  343. **kwargs):
  344. # cut decoder_input_ids if past is used
  345. if past is not None:
  346. input_ids = input_ids[:, -1:]
  347. return {
  348. 'decoder_input_ids': input_ids,
  349. 'past_key_values': past,
  350. 'encoder_outputs': encoder_outputs,
  351. 'attention_mask': attention_mask,
  352. 'head_mask': head_mask,
  353. 'decoder_head_mask': decoder_head_mask,
  354. 'cross_attn_head_mask': cross_attn_head_mask,
  355. 'use_cache': use_cache,
  356. }
  357. def prepare_decoder_input_ids_from_labels(self, labels: torch.Tensor):
  358. return self._shift_right(labels)
  359. def generate(
  360. self,
  361. *args,
  362. **kwargs,
  363. ):
  364. output = super().generate(*args, **kwargs)
  365. return TokenGeneratorOutput(
  366. sequences=output if isinstance(output, torch.Tensor) else output[0]
  367. )
  368. def _reorder_cache(self, past, beam_idx):
  369. # if decoder past is not included in output
  370. # speedy decoding is disabled and no need to reorder
  371. if past is None:
  372. logger.warning(
  373. 'You might want to consider setting `use_cache=True` to speed up decoding'
  374. )
  375. return past
  376. reordered_decoder_past = ()
  377. for layer_past_states in past:
  378. # get the correct batch idx from layer past batch dim
  379. # batch dim of `past` is at 2nd position
  380. reordered_layer_past_states = ()
  381. for layer_past_state in layer_past_states:
  382. # need to set correct `past` for each of the four key / value states
  383. reordered_layer_past_states = reordered_layer_past_states + (
  384. layer_past_state.index_select(
  385. 0, beam_idx.to(layer_past_state.device)), )
  386. assert reordered_layer_past_states[0].shape == layer_past_states[
  387. 0].shape
  388. assert len(reordered_layer_past_states) == len(layer_past_states)
  389. reordered_decoder_past = reordered_decoder_past + (
  390. reordered_layer_past_states, )
  391. return reordered_decoder_past