text_generation.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
  3. #
  4. # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
  5. # and OPT implementations in this library. It has been modified from its
  6. # original forms to accommodate minor architectural differences compared
  7. # to GPT-NeoX and OPT used by the Meta AI team that trained the model.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from typing import Dict, List, Optional, Tuple, Union
  21. import torch
  22. from transformers.models.llama import LlamaForCausalLM
  23. from modelscope.metainfo import Models
  24. from modelscope.models.base import TorchModel
  25. from modelscope.models.builder import MODELS
  26. from modelscope.outputs import OutputKeys
  27. from modelscope.utils.constant import Tasks
  28. from .backbone import MsModelMixin
  29. def get_chat_prompt(system: str, text: str, history: List[Tuple[str, str]],
  30. max_length: int, tokenizer):
  31. system_prompt = f'<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n'
  32. system_ids = tokenizer(
  33. system_prompt, add_special_tokens=False, return_tensors='pt').input_ids
  34. text_prompt = f'{text.strip()} [/INST]'
  35. text_ids = tokenizer(
  36. text_prompt, add_special_tokens=False, return_tensors='pt').input_ids
  37. prompt_length = system_ids.shape[-1] + text_ids.shape[-1]
  38. if prompt_length > max_length:
  39. raise RuntimeError(
  40. f'prepend prompt length {prompt_length} is bigger than max_length {max_length}'
  41. )
  42. history_prompt = ''
  43. history_ids_list = []
  44. # traverse history in reverse order
  45. for user, bot in history[::-1]:
  46. assert isinstance(user, str)
  47. assert isinstance(bot, str)
  48. round_prompt = f'{user.strip()} [/INST] {bot.strip()} </s><s>[INST] '
  49. round_ids = tokenizer(
  50. round_prompt, add_special_tokens=False,
  51. return_tensors='pt').input_ids
  52. if prompt_length + round_ids.shape[-1] > max_length:
  53. # excess history should not be appended to the prompt
  54. break
  55. else:
  56. history_prompt = round_prompt + history_prompt
  57. history_ids_list = [round_ids] + history_ids_list
  58. prompt_length += round_ids.shape[-1]
  59. prompt_list = [system_prompt, history_prompt, text_prompt]
  60. prompt_ids_list = [system_ids] + history_ids_list + [text_ids]
  61. return ''.join(prompt_list), torch.cat(prompt_ids_list, dim=1)
  62. # This file is mainly copied from the llama code of transformers
  63. @MODELS.register_module(Tasks.chat, module_name=Models.llama2)
  64. @MODELS.register_module(Tasks.chat, module_name=Models.llama)
  65. @MODELS.register_module(Tasks.text_generation, module_name=Models.llama2)
  66. @MODELS.register_module(Tasks.text_generation, module_name=Models.llama)
  67. class LlamaForTextGeneration(MsModelMixin, LlamaForCausalLM, TorchModel):
  68. def chat(self, input: Dict, tokenizer) -> Dict:
  69. import copy
  70. gen_kwargs = copy.copy(input)
  71. if 'text' not in input:
  72. text: str = ''
  73. else:
  74. text: str = input['text']
  75. gen_kwargs.pop('text')
  76. if 'system' not in input:
  77. system: str = ''
  78. else:
  79. system: str = input['system']
  80. gen_kwargs.pop('system')
  81. if 'history' not in input:
  82. history = []
  83. else:
  84. history: List[Tuple] = copy.copy(input['history'])
  85. gen_kwargs.pop('history')
  86. if 'max_length' not in gen_kwargs:
  87. gen_kwargs['max_length'] = 4096
  88. prompt, prompt_ids = get_chat_prompt(
  89. system=system,
  90. text=text,
  91. history=history,
  92. max_length=gen_kwargs['max_length'],
  93. tokenizer=tokenizer)
  94. input_ids = prompt_ids.to(self.device)
  95. generate_ids = self.generate(input_ids, **gen_kwargs)
  96. # remove input tokens
  97. generate_ids = generate_ids[:, input_ids.shape[1]:]
  98. response = tokenizer.batch_decode(
  99. generate_ids,
  100. skip_special_tokens=True,
  101. clean_up_tokenization_spaces=False)[0]
  102. response = response.strip()
  103. history.append((text, response))
  104. return {OutputKeys.RESPONSE: response, OutputKeys.HISTORY: history}