text_generation.py 9.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254
  1. # Copyright (c) Alibaba Cloud.
  2. #
  3. # This source code is licensed under the license found in the
  4. # LICENSE file in the root directory of this source tree.
  5. import warnings
  6. from typing import TYPE_CHECKING, Callable, List, Optional, Tuple, Union
  7. import torch
  8. import torch.nn as nn
  9. import torch.nn.functional as F
  10. import torch.utils.checkpoint
  11. from torch.nn import CrossEntropyLoss
  12. from transformers import (GenerationConfig, PreTrainedTokenizer,
  13. StoppingCriteriaList)
  14. from transformers.generation.logits_process import LogitsProcessorList
  15. from transformers.generation.utils import GenerateOutput
  16. from transformers.modeling_outputs import CausalLMOutputWithPast
  17. from modelscope.metainfo import Models
  18. from modelscope.outputs import OutputKeys
  19. from modelscope.utils.constant import Tasks
  20. from modelscope.utils.logger import get_logger
  21. from ... import MODELS
  22. from .backbone import QWenModel, QWenPreTrainedModel
  23. from .qwen_generation_utils import (BatchTokensType, HistoryType,
  24. StopWordsLogitsProcessor, decode_tokens,
  25. get_batch, get_stop_words_ids,
  26. make_context, pad_batch, switch,
  27. top_k_logits)
  28. if TYPE_CHECKING:
  29. from transformers.generation.streamers import BaseStreamer
  30. logger = get_logger()
  31. @MODELS.register_module(Tasks.text_generation, module_name=Models.qwen_7b)
  32. @MODELS.register_module(Tasks.chat, module_name=Models.qwen_7b)
  33. class QWenForTextGeneration(QWenPreTrainedModel):
  34. _keys_to_ignore_on_load_missing = [r'h\.\d+\.attn\.rotary_emb\.inv_freq']
  35. _keys_to_ignore_on_load_unexpected = [r'h\.\d+\.attn\.masked_bias']
  36. def __init__(self, config):
  37. super().__init__(config)
  38. self.transformer = QWenModel(config)
  39. self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
  40. assert not (config.bf16 and config.fp16
  41. ), 'In config, bf16 and fp16 cannot both be true'
  42. if config.bf16:
  43. self.transformer.bfloat16()
  44. self.lm_head.bfloat16()
  45. if config.fp16:
  46. self.transformer.half()
  47. self.lm_head.half()
  48. self.post_init()
  49. def get_output_embeddings(self):
  50. return self.lm_head
  51. def set_output_embeddings(self, new_embeddings):
  52. self.lm_head = new_embeddings
  53. def prepare_inputs_for_generation(self,
  54. input_ids,
  55. past_key_values=None,
  56. inputs_embeds=None,
  57. **kwargs):
  58. token_type_ids = kwargs.get('token_type_ids', None)
  59. if past_key_values:
  60. input_ids = input_ids[:, -1].unsqueeze(-1)
  61. if token_type_ids is not None:
  62. token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
  63. attention_mask = kwargs.get('attention_mask', None)
  64. position_ids = kwargs.get('position_ids', None)
  65. if attention_mask is not None and position_ids is None:
  66. position_ids = attention_mask.long().cumsum(-1) - 1
  67. position_ids.masked_fill_(attention_mask == 0, 1)
  68. if past_key_values:
  69. position_ids = position_ids[:, -1].unsqueeze(-1)
  70. else:
  71. position_ids = None
  72. if inputs_embeds is not None and past_key_values is None:
  73. model_inputs = {'inputs_embeds': inputs_embeds}
  74. else:
  75. model_inputs = {'input_ids': input_ids}
  76. model_inputs.update({
  77. 'past_key_values': past_key_values,
  78. 'use_cache': kwargs.get('use_cache'),
  79. 'position_ids': position_ids,
  80. 'attention_mask': attention_mask,
  81. 'token_type_ids': token_type_ids,
  82. })
  83. return model_inputs
  84. def forward(
  85. self,
  86. input_ids: Optional[torch.LongTensor] = None,
  87. past_key_values: Optional[Tuple[Tuple[torch.Tensor]]] = None,
  88. attention_mask: Optional[torch.FloatTensor] = None,
  89. token_type_ids: Optional[torch.LongTensor] = None,
  90. position_ids: Optional[torch.LongTensor] = None,
  91. head_mask: Optional[torch.FloatTensor] = None,
  92. inputs_embeds: Optional[torch.FloatTensor] = None,
  93. encoder_hidden_states: Optional[torch.Tensor] = None,
  94. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  95. labels: Optional[torch.LongTensor] = None,
  96. use_cache: Optional[bool] = None,
  97. output_attentions: Optional[bool] = None,
  98. output_hidden_states: Optional[bool] = None,
  99. return_dict: Optional[bool] = None,
  100. ) -> Union[Tuple, CausalLMOutputWithPast]:
  101. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  102. transformer_outputs = self.transformer(
  103. input_ids,
  104. past_key_values=past_key_values,
  105. attention_mask=attention_mask,
  106. token_type_ids=token_type_ids,
  107. position_ids=position_ids,
  108. head_mask=head_mask,
  109. inputs_embeds=inputs_embeds,
  110. encoder_hidden_states=encoder_hidden_states,
  111. encoder_attention_mask=encoder_attention_mask,
  112. use_cache=use_cache,
  113. output_attentions=output_attentions,
  114. output_hidden_states=output_hidden_states,
  115. return_dict=return_dict,
  116. )
  117. hidden_states = transformer_outputs[0]
  118. lm_logits = self.lm_head(hidden_states)
  119. loss = None
  120. if labels is not None:
  121. labels = labels.to(lm_logits.device)
  122. shift_logits = lm_logits[..., :-1, :].contiguous()
  123. shift_labels = labels[..., 1:].contiguous()
  124. loss_fct = CrossEntropyLoss()
  125. loss = loss_fct(
  126. shift_logits.view(-1, shift_logits.size(-1)),
  127. shift_labels.view(-1))
  128. if not return_dict:
  129. output = (lm_logits, ) + transformer_outputs[1:]
  130. return ((loss, ) + output) if loss is not None else output
  131. return CausalLMOutputWithPast(
  132. loss=loss,
  133. logits=lm_logits,
  134. past_key_values=transformer_outputs.past_key_values,
  135. hidden_states=transformer_outputs.hidden_states,
  136. attentions=transformer_outputs.attentions,
  137. )
  138. @staticmethod
  139. def _reorder_cache(past_key_values: Tuple[Tuple[torch.Tensor]],
  140. beam_idx: torch.Tensor) -> Tuple[Tuple[torch.Tensor]]:
  141. return tuple(
  142. tuple(
  143. past_state.index_select(0, beam_idx.to(past_state.device))
  144. for past_state in layer_past)
  145. for layer_past in past_key_values)
  146. def chat(
  147. self,
  148. tokenizer: PreTrainedTokenizer,
  149. query: str,
  150. history: Optional[HistoryType],
  151. system: str = 'You are a helpful assistant.',
  152. append_history: bool = True,
  153. ) -> Tuple[str, HistoryType]:
  154. if history is None:
  155. history = []
  156. raw_text, context_tokens = make_context(
  157. tokenizer,
  158. query,
  159. history=history,
  160. system=system,
  161. max_window_size=6144,
  162. chat_format=self.generation_config.chat_format)
  163. stop_words_ids = get_stop_words_ids(self.generation_config.chat_format,
  164. tokenizer)
  165. input_ids = torch.tensor([context_tokens]).to(self.device)
  166. outputs = self.generate(
  167. input_ids,
  168. stop_words_ids=stop_words_ids,
  169. return_dict_in_generate=False,
  170. )
  171. response = decode_tokens(
  172. outputs[0],
  173. tokenizer,
  174. raw_text_len=len(raw_text),
  175. context_length=len(context_tokens),
  176. chat_format=self.generation_config.chat_format,
  177. verbose=False,
  178. )
  179. if append_history:
  180. history.append((query, response))
  181. return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history}
  182. def generate(
  183. self,
  184. inputs: Optional[torch.Tensor] = None,
  185. generation_config: Optional[GenerationConfig] = None,
  186. logits_processor: Optional[LogitsProcessorList] = None,
  187. stopping_criteria: Optional[StoppingCriteriaList] = None,
  188. prefix_allowed_tokens_fn: Optional[Callable[[int, torch.Tensor],
  189. List[int]]] = None,
  190. synced_gpus: Optional[bool] = None,
  191. streamer: Optional['BaseStreamer'] = None,
  192. **kwargs,
  193. ) -> Union[GenerateOutput, torch.LongTensor]:
  194. # Process stop_words_ids
  195. stop_words_ids = kwargs.pop('stop_words_ids', None)
  196. if stop_words_ids is None and generation_config is not None:
  197. stop_words_ids = getattr(generation_config, 'stop_words_ids', None)
  198. if stop_words_ids is None:
  199. stop_words_ids = getattr(self.generation_config, 'stop_words_ids',
  200. None)
  201. if stop_words_ids is not None:
  202. stop_words_logits_processor = StopWordsLogitsProcessor(
  203. stop_words_ids=stop_words_ids,
  204. eos_token_id=self.generation_config.eos_token_id)
  205. if logits_processor is None:
  206. logits_processor = LogitsProcessorList(
  207. [stop_words_logits_processor])
  208. else:
  209. logits_processor.append(stop_words_logits_processor)
  210. return super().generate(
  211. inputs,
  212. generation_config,
  213. logits_processor,
  214. stopping_criteria,
  215. prefix_allowed_tokens_fn,
  216. synced_gpus,
  217. streamer=streamer,
  218. **kwargs,
  219. )