llm_pipeline.py 30 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. from contextlib import contextmanager
  4. from threading import Lock
  5. from typing import (Any, Callable, Dict, Generator, Iterator, List, Optional,
  6. Tuple, Union)
  7. import json
  8. import numpy as np
  9. import torch
  10. from transformers import AutoConfig, PreTrainedModel, PreTrainedTokenizer
  11. from modelscope import (AutoModelForCausalLM, AutoTokenizer, Pipeline,
  12. snapshot_download)
  13. from modelscope.hub.file_download import model_file_download
  14. from modelscope.models.base import Model
  15. from modelscope.models.nlp import ChatGLM2Tokenizer, Llama2Tokenizer
  16. from modelscope.outputs import OutputKeys
  17. from modelscope.pipelines.base import Input
  18. from modelscope.pipelines.builder import PIPELINES
  19. from modelscope.pipelines.util import is_model, is_official_hub_path
  20. from modelscope.utils.config import Config
  21. from modelscope.utils.constant import Frameworks, Invoke, ModelFile, Tasks
  22. from modelscope.utils.device import create_device, device_placement
  23. from modelscope.utils.logger import get_logger
  24. from modelscope.utils.model_type_helper import ModelTypeHelper
  25. from modelscope.utils.streaming_output import (PipelineStreamingOutputMixin,
  26. StreamingOutputMixin,
  27. add_stream_generate)
# Module-level logger obtained from modelscope's logging helper.
logger = get_logger()
# Marker string compared against `llm_framework` and against the
# `adapter_cfg.tuner_backend` configuration entry further down in this file.
SWIFT_FRAMEWORK = 'swift'
  30. class LLMAdapterRegistry:
  31. llm_format_map = {'qwen': [None, None, None]}
  32. @classmethod
  33. def _add_to_map(cls, model_type: str, value_index: int = 0, member=None):
  34. assert model_type or ModelTypeHelper.current_model_type
  35. if model_type is None:
  36. model_type = ModelTypeHelper.current_model_type
  37. if model_type not in cls.llm_format_map:
  38. cls.llm_format_map[model_type] = [None, None, None]
  39. assert cls.llm_format_map[model_type][value_index] is None
  40. cls.llm_format_map[model_type][value_index] = member
  41. return member
  42. @classmethod
  43. def _wrapper(cls, model_type: str, value_index: int = 0, member=None):
  44. if member is not None:
  45. return cls._add_to_map(model_type, value_index, member)
  46. def _register(member):
  47. return cls._add_to_map(model_type, value_index, member)
  48. return _register
  49. @classmethod
  50. def register_format_messages(cls, model_type: str = None, function=None):
  51. return cls._wrapper(model_type, 0, function)
  52. @classmethod
  53. def register_format_output(cls, model_type: str = None, function=None):
  54. return cls._wrapper(model_type, 1, function)
  55. @classmethod
  56. def register_tokenizer(cls, model_type: str = None, tokenizer_class=None):
  57. return cls._wrapper(model_type, 2, tokenizer_class)
  58. @classmethod
  59. def contains(cls, model_name: str) -> bool:
  60. return model_name in cls.llm_format_map
  61. @classmethod
  62. def get(cls, model_name: str) -> bool:
  63. return cls.llm_format_map[model_name]
  64. @PIPELINES.register_module(Tasks.chat, module_name='llm')
  65. @PIPELINES.register_module(Tasks.text_generation, module_name='llm')
  66. class LLMPipeline(Pipeline, PipelineStreamingOutputMixin):
  67. def initiate_single_model(self, model):
  68. from swift import Swift
  69. if isinstance(model, str):
  70. logger.info(f'initiate model from {model}')
  71. if self._is_swift_model(model):
  72. if self.llm_framework is not None:
  73. logger.warning(
  74. f'Cannot use swift with llm_framework, ignoring {self.llm_framework}.'
  75. )
  76. base_model = self.cfg.safe_get('adapter_cfg.model_id_or_path')
  77. assert base_model is not None, 'Cannot get adapter_cfg.model_id_or_path from configuration.json file.'
  78. revision = self.cfg.safe_get('adapter_cfg.model_revision',
  79. 'master')
  80. base_model = Model.from_pretrained(
  81. base_model,
  82. revision,
  83. invoked_by=Invoke.PIPELINE,
  84. device_map=self.device_map,
  85. torch_dtype=self.torch_dtype,
  86. trust_remote_code=self.trust_remote_code)
  87. swift_model = Swift.from_pretrained(base_model, model_id=model)
  88. return swift_model
  89. if isinstance(model, str) and is_official_hub_path(model):
  90. logger.info(f'initiate model from location {model}.')
  91. if self.llm_framework:
  92. model_dir = model if os.path.exists(
  93. model) else snapshot_download(model)
  94. try:
  95. model = self._wrap_infer_framework(model_dir,
  96. self.llm_framework)
  97. logger.info(f'initiate model with {self.llm_framework}.')
  98. return model
  99. except Exception as e:
  100. logger.warning(
  101. f'Cannot using llm_framework with {model}, '
  102. f'ignoring llm_framework={self.llm_framework} : {e}')
  103. self.llm_framework = None
  104. if is_model(model):
  105. return Model.from_pretrained(
  106. model,
  107. invoked_by=Invoke.PIPELINE,
  108. device_map=self.device_map,
  109. torch_dtype=self.torch_dtype,
  110. ignore_file_pattern=self.ignore_file_pattern)
  111. else:
  112. model_dir = model if os.path.exists(
  113. model) else snapshot_download(model)
  114. # TODO: Temporary use of AutoModelForCausalLM
  115. # Need to be updated into a universal solution
  116. model = AutoModelForCausalLM.from_pretrained(
  117. model_dir,
  118. device_map=self.device_map,
  119. trust_remote_code=self.trust_remote_code)
  120. model.model_dir = model_dir
  121. return model
  122. else:
  123. return model
  124. def _is_swift_model(self, model: Union[str, Any]) -> bool:
  125. if not isinstance(model, str):
  126. return False
  127. if os.path.exists(model):
  128. cfg_file = os.path.join(model, ModelFile.CONFIGURATION)
  129. else:
  130. try:
  131. cfg_file = model_file_download(model, ModelFile.CONFIGURATION)
  132. except Exception:
  133. return False
  134. self.cfg = Config.from_file(cfg_file)
  135. return self.cfg.safe_get(
  136. 'adapter_cfg.tuner_backend') == SWIFT_FRAMEWORK
  137. def _wrap_infer_framework(self, model_dir, framework='vllm'):
  138. from modelscope.pipelines.accelerate.base import InferFramework
  139. return InferFramework.from_pretrained(model_dir, framework)
  140. def __init__(self,
  141. format_messages: Union[Callable, str] = None,
  142. format_output: Callable = None,
  143. tokenizer: PreTrainedTokenizer = None,
  144. llm_framework: str = None,
  145. trust_remote_code: Optional[bool] = None,
  146. *args,
  147. **kwargs):
  148. self.device_map = kwargs.pop('device_map', None)
  149. self.trust_remote_code = trust_remote_code
  150. self.llm_framework = llm_framework
  151. if os.path.exists(kwargs['model']):
  152. config = AutoConfig.from_pretrained(
  153. kwargs['model'], trust_remote_code=self.trust_remote_code)
  154. q_config = config.__dict__.get('quantization_config', None)
  155. if q_config:
  156. if q_config.get(
  157. 'quant_method',
  158. 'gptq') == 'gptq' and torch.cuda.device_count():
  159. self.device_map = 'cuda'
  160. self.torch_dtype = kwargs.pop('torch_dtype', None)
  161. self.ignore_file_pattern = kwargs.pop('ignore_file_pattern', None)
  162. if llm_framework == SWIFT_FRAMEWORK:
  163. self._init_swift(kwargs['model'], kwargs.get('device', 'gpu'))
  164. return
  165. with self._temp_configuration_file(kwargs):
  166. super().__init__(*args, **kwargs)
  167. if isinstance(self.model, PreTrainedModel):
  168. self.model = add_stream_generate(self.model)
  169. tokenizer_class = None
  170. if isinstance(format_messages, str):
  171. assert LLMAdapterRegistry.contains(format_messages), \
  172. f'Can not find function for `{format_messages}`!'
  173. format_messages, format_output, tokenizer_class = \
  174. LLMAdapterRegistry.get(format_messages)
  175. if format_messages is None:
  176. model_type = ModelTypeHelper.get(self.model.model_dir, split='-')
  177. if LLMAdapterRegistry.contains(model_type):
  178. format_messages, format_output, tokenizer_class = \
  179. LLMAdapterRegistry.get(model_type)
  180. if format_messages is not None:
  181. self.format_messages = format_messages
  182. if format_output is not None:
  183. self.format_output = format_output
  184. self.tokenizer = self._get_tokenizer(
  185. tokenizer_class) if tokenizer is None else tokenizer
  186. def _init_swift(self, model_id, device) -> None:
  187. from swift.llm import prepare_model_template
  188. from swift.llm import InferArguments, get_model_info_meta
  189. def format_messages(messages: Dict[str, List[Dict[str, str]]],
  190. tokenizer: PreTrainedTokenizer,
  191. **kwargs) -> Dict[str, torch.Tensor]:
  192. inputs = self.template.encode(messages)
  193. inputs.pop('labels', None)
  194. if 'input_ids' in inputs:
  195. input_ids = torch.tensor(inputs['input_ids'])[None]
  196. inputs['input_ids'] = input_ids
  197. token_len = input_ids.shape[1]
  198. if 'inputs_embeds' in inputs:
  199. inputs_embeds = inputs['inputs_embeds'][None]
  200. inputs['inputs_embeds'] = inputs_embeds
  201. token_len = inputs_embeds.shape[1]
  202. inputs['attention_mask'] = torch.ones(token_len)[None]
  203. if 'token_type_ids' in inputs:
  204. inputs['token_type_ids'] = torch.tensor(
  205. inputs['token_type_ids'])[None]
  206. return inputs
  207. def get_example(
  208. messages: Dict[str, List[Dict[str, str]]]) -> Dict[str, str]:
  209. messages = messages['messages']
  210. assert len(messages) > 0, 'messages cannot be empty!'
  211. system = None
  212. if messages[0]['role'] == 'system':
  213. system = messages[0]['content']
  214. messages = messages[1:]
  215. assert len(messages) % 2 == 1, 'Unsupported messages format!'
  216. contents = [message['content'] for message in messages]
  217. prompt = contents[-1]
  218. history = list(zip(contents[::2], contents[1::2]))
  219. if self.llm_framework == SWIFT_FRAMEWORK:
  220. return dict(system=system, query=prompt, history=history)
  221. else:
  222. return dict(system=system, prompt=prompt, history=history)
  223. args = InferArguments(model=model_id)
  224. model, template = prepare_model_template(
  225. args, device_map=self.device_map)
  226. self.model = add_stream_generate(model)
  227. template.model = self.model
  228. self.template = template
  229. self.tokenizer = template.tokenizer
  230. self.format_messages = format_messages
  231. self.has_multiple_models = False
  232. self.framework = Frameworks.torch
  233. self.device_name = device
  234. self.device = create_device(device)
  235. self._model_prepare = False
  236. self._model_prepare_lock = Lock()
  237. self._auto_collate = True
  238. self._compile = False
  239. @contextmanager
  240. def _temp_configuration_file(self, kwargs: Dict[str, Any]):
  241. kwargs['model'] = model = self.initiate_single_model(kwargs['model'])
  242. model_dir = model if isinstance(model, str) else model.model_dir
  243. configuration_path = os.path.join(model_dir, ModelFile.CONFIGURATION)
  244. if os.path.exists(configuration_path):
  245. yield
  246. else:
  247. with open(configuration_path, 'w') as f:
  248. json.dump({'framework': 'pytorch', 'task': 'chat'}, f)
  249. yield
  250. os.remove(configuration_path)
  251. def _process_single(self, inputs, *args, **kwargs) -> Dict[str, Any]:
  252. preprocess_params = kwargs.get('preprocess_params', {})
  253. forward_params = kwargs.get('forward_params', {})
  254. postprocess_params = kwargs.get('postprocess_params', {})
  255. preprocess_params['is_messages'] = postprocess_params['is_messages'] \
  256. = isinstance(inputs, dict) and 'messages' in inputs
  257. tokens = self.preprocess(inputs, **preprocess_params)
  258. if self.llm_framework in (None, SWIFT_FRAMEWORK):
  259. # pytorch model
  260. if hasattr(self.model, 'generate'):
  261. outputs = self.model.generate(**tokens, **forward_params)
  262. elif hasattr(self.model, 'model') and hasattr(
  263. self.model.model, 'generate'):
  264. outputs = self.model.model.generate(**tokens, **forward_params)
  265. else:
  266. raise ValueError('model does not support `generate`!')
  267. else:
  268. tokens = [list(tokens['inputs'].flatten().numpy())]
  269. outputs = self.model(tokens, **forward_params)[0]
  270. if self.llm_framework in (None, SWIFT_FRAMEWORK):
  271. # pytorch model
  272. outputs = outputs.tolist()[0][len(tokens['inputs'][0]):]
  273. response = self.postprocess(outputs, **postprocess_params)
  274. return response
  275. def stream_generate(self, inputs: Union[Input, List[Input]], *args,
  276. **kwargs) -> Generator:
  277. assert isinstance(self.model, StreamingOutputMixin
  278. ), 'pipeline.model must be StreamingOutputMixin!'
  279. if (self.model or (self.has_multiple_models and self.models[0])):
  280. if not self._model_prepare:
  281. self.prepare_model()
  282. preprocess_params, forward_params, postprocess_params = self._sanitize_parameters(
  283. **kwargs)
  284. preprocess_params['is_messages'] = postprocess_params['is_messages'] \
  285. = isinstance(inputs, dict) and 'messages' in inputs
  286. if isinstance(inputs, list):
  287. model_input_list = [
  288. self._preprocess_with_check(i, preprocess_params)
  289. for i in inputs
  290. ]
  291. output = []
  292. for ele in model_input_list:
  293. output.append(
  294. self._stream_single(ele, forward_params,
  295. postprocess_params))
  296. else:
  297. model_input = self._preprocess_with_check(inputs,
  298. preprocess_params)
  299. output = self._stream_single(model_input, forward_params,
  300. postprocess_params)
  301. return output
  302. def _stream_single(self, model_input: Dict[str, Any],
  303. forward_params: Dict[str, Any],
  304. postprocess_params: Dict[str, Any]) -> Generator:
  305. with device_placement(self.framework, self.device_name):
  306. if self.framework == Frameworks.torch:
  307. with torch.no_grad():
  308. if self._auto_collate:
  309. model_input = self._collate_fn(model_input)
  310. stream = self.model.stream_generate(
  311. **model_input, **forward_params)
  312. else:
  313. stream = self.model.stream_generate(**model_input,
  314. **forward_params)
  315. for out in stream:
  316. out = out.tolist()[0][len(model_input['inputs'][0]):]
  317. out = self.postprocess(out, **postprocess_params)
  318. self._check_output(out)
  319. yield out
  320. def preprocess(self, inputs: Union[str, Dict], **kwargs):
  321. is_messages = kwargs.pop('is_messages')
  322. if is_messages:
  323. tokens = self.format_messages(inputs, self.tokenizer, **kwargs)
  324. else:
  325. tokens = self.tokenizer(inputs, return_tensors='pt', **kwargs)
  326. tokens['inputs'] = tokens.pop('input_ids')
  327. if hasattr(self.model, 'device'):
  328. device = self.model.device
  329. elif hasattr(self.model, 'model') and hasattr(self.model.model,
  330. 'device'):
  331. device = self.model.model.device
  332. elif hasattr(self.model, 'llm_framework'):
  333. device = 'cpu'
  334. else:
  335. raise ValueError('model does not have `device` attribute!')
  336. return {
  337. k: (v.to(device) if torch.is_tensor(v) else v)
  338. for k, v in tokens.items()
  339. }
  340. def postprocess(self, outputs, **kwargs):
  341. is_messages = kwargs.pop('is_messages')
  342. if not isinstance(outputs, str):
  343. shape_type = (torch.Tensor, np.ndarray)
  344. if isinstance(outputs, shape_type) and len(outputs.shape) > 1:
  345. outputs = outputs[0]
  346. response = self.tokenizer.decode(
  347. outputs, skip_special_tokens=True, **kwargs)
  348. else:
  349. response = outputs
  350. if is_messages:
  351. response = self.format_output(response, **kwargs)
  352. else:
  353. response = {OutputKeys.TEXT: response}
  354. return response
  355. def _sanitize_parameters(self, **generate_parameter):
  356. """
  357. this method should sanitize the keyword args to preprocessor params,
  358. forward params and postprocess params on '__call__' or '_process_single' method
  359. considered to be a normal classmethod with default implementation / output
  360. Default Returns:
  361. Dict[str, str]: preprocess_params = {}
  362. Dict[str, str]: forward_params = {}
  363. Dict[str, str]: postprocess_params = pipeline_parameters
  364. """
  365. return {}, generate_parameter, {}
  366. def _get_tokenizer(self, tokenizer_class=None):
  367. if isinstance(self.model, str):
  368. model_dir = self.model
  369. else:
  370. model_dir = self.model.model_dir
  371. if tokenizer_class is None:
  372. tokenizer_class = AutoTokenizer
  373. return tokenizer_class.from_pretrained(
  374. model_dir, trust_remote_code=self.trust_remote_code)
  375. @staticmethod
  376. def format_messages(messages: Dict[str, List[Dict[str, str]]],
  377. tokenizer: PreTrainedTokenizer,
  378. **kwargs) -> Dict[str, torch.Tensor]:
  379. # {"messages":[{"role": "system", "content": "You are a helpful assistant."}...]}
  380. tokens = []
  381. for role, content in LLMPipeline._message_iter(messages):
  382. tokens = LLMPipeline._concat_with_special_tokens(
  383. tokens, role, content, tokenizer)
  384. return {'input_ids': torch.tensor([tokens], dtype=torch.int64)}
  385. @staticmethod
  386. def format_output(response: str, **kwargs):
  387. response = response.strip()
  388. message = {'message': {'role': 'assistant', 'content': response}}
  389. return message
  390. @staticmethod
  391. def _message_iter(
  392. data: Dict[str, List[Dict[str,
  393. str]]]) -> Iterator[Tuple[str, str]]:
  394. for pair in data['messages']:
  395. yield pair['role'], pair['content']
  396. @staticmethod
  397. def _concat_with_special_tokens(
  398. ids: List[int], role: str, content: Union[str, List[Dict[str,
  399. str]]],
  400. tokenizer: PreTrainedTokenizer) -> List[int]:
  401. im_start = tokenizer.im_start_id
  402. im_end = tokenizer.im_end_id
  403. nl_token = tokenizer.encode('\n')
  404. role = tokenizer.encode(role.strip())
  405. content = LLMPipeline._encode(tokenizer, content)
  406. return LLMPipeline._concat(ids, im_start, role, nl_token, content,
  407. im_end, nl_token)
  408. @staticmethod
  409. def _encode(tokenizer: PreTrainedTokenizer,
  410. content: Union[str, List[Dict[str, str]]]):
  411. if isinstance(content, str):
  412. return tokenizer.encode(content.rstrip())
  413. encoded = []
  414. for pair in content:
  415. (modal, value), = pair.items()
  416. if modal == 'image':
  417. img_token_span = getattr(tokenizer, 'img_token_span', 256)
  418. img_start_id = tokenizer.img_start_id
  419. img_end_id = img_start_id + 1
  420. img_pad_id = img_start_id + 2
  421. list_int_url = list(bytes(value, encoding='utf-8'))
  422. assert len(
  423. list_int_url) <= img_token_span, 'Image url is too long.'
  424. pad_ids = [img_pad_id] * (img_token_span - len(list_int_url))
  425. encoded = LLMPipeline._concat(encoded, img_start_id,
  426. list_int_url, pad_ids,
  427. img_end_id)
  428. else: # text
  429. encoded.extend(tokenizer.encode(value))
  430. return encoded
  431. @staticmethod
  432. def _concat(ids: List[int], *args: Union[int, List[int]]) -> List[int]:
  433. for item in args:
  434. if isinstance(item, list):
  435. ids.extend(item)
  436. else:
  437. ids.append(item)
  438. return ids
  439. @LLMAdapterRegistry.register_format_messages('chatglm2')
  440. def chatglm2_format_messages(messages, tokenizer, **kwargs):
  441. def build_chatglm2_prompt(messages, **kwargs):
  442. prompt = ''
  443. messages = messages['messages']
  444. # chatglm2 does not have system messages
  445. assert messages[0][
  446. 'role'] == 'user', 'chatglm2 does not have system messages'
  447. for i in range(0, len(messages) - 1, 2):
  448. prompt += '[Round {}]\n\n问:{}\n\n答:{}\n\n'.format(
  449. i // 2 + 1, messages[i]['content'], messages[i + 1]['content'])
  450. prompt += '[Round {}]\n\n问:{}\n\n答:'.format(
  451. len(messages) // 2 + 1, messages[-1]['content'])
  452. return prompt
  453. prompt = build_chatglm2_prompt(messages, **kwargs)
  454. return tokenizer(prompt, return_token_type_ids=False, return_tensors='pt')
  455. @LLMAdapterRegistry.register_format_output('chatglm')
  456. @LLMAdapterRegistry.register_format_output('chatglm2')
  457. def chatglm2_format_output(response, **kwargs):
  458. response = response.strip()
  459. response = response.replace('[[训练时间]]', '2023年')
  460. messages = {'role': 'assistant', 'content': response}
  461. outputs = {
  462. 'message': messages,
  463. }
  464. return outputs
  465. @LLMAdapterRegistry.register_format_messages('llama')
  466. @LLMAdapterRegistry.register_format_messages('llama2')
  467. def llama2_format_messages(messages, tokenizer, **kwargs):
  468. from transformers import BatchEncoding
  469. def build_llama2_prompt(messages, tokenizer, **kwargs):
  470. max_length = kwargs.get('max_length', 2048)
  471. default_system_message = 'you are a helpful assistant!'
  472. messages = messages['messages']
  473. # llama2 have system messages
  474. if messages[0]['role'] != 'system':
  475. messages = [{
  476. 'role': 'system',
  477. 'content': default_system_message
  478. }] + messages
  479. system = messages[0]['content']
  480. system_prompt = f'<s>[INST] <<SYS>>\n{system}\n<</SYS>>\n\n'
  481. system_ids = tokenizer(system_prompt, return_tensors='pt').input_ids
  482. text = messages[-1]['content']
  483. text_prompt = f'{text.strip()} [/INST]'
  484. text_ids = tokenizer(text_prompt, return_tensors='pt').input_ids
  485. prompt_length = system_ids.shape[-1] + text_ids.shape[-1]
  486. if prompt_length > max_length:
  487. raise RuntimeError(
  488. f'prepend prompt length {prompt_length} is bigger than max_length {max_length}'
  489. )
  490. # history items
  491. history_prompt = ''
  492. history_ids_list = []
  493. for i in range(len(messages) - 2, 0, -2):
  494. user, assistant = messages[i]['content'], messages[i
  495. + 1]['content']
  496. round_prompt = f'{user.strip()} [/INST] {assistant.strip()} </s><s>[INST] '
  497. round_ids = tokenizer(round_prompt, return_tensors='pt').input_ids
  498. if prompt_length + round_ids.shape[-1] > max_length:
  499. # excess history should not be appended to the prompt
  500. break
  501. else:
  502. history_prompt = round_prompt + history_prompt
  503. history_ids_list = [round_ids] + history_ids_list
  504. prompt_length += round_ids.shape[-1]
  505. prompt_list = [system_prompt, history_prompt, text_prompt]
  506. prompt_ids_list = [system_ids] + history_ids_list + [text_ids]
  507. return ''.join(prompt_list), torch.cat(prompt_ids_list, dim=-1)
  508. prompt, tokens = build_llama2_prompt(messages, tokenizer, **kwargs)
  509. return BatchEncoding({'input_ids': tokens})
  510. @LLMAdapterRegistry.register_format_messages('baichuan')
  511. @LLMAdapterRegistry.register_format_messages('baichuan2')
  512. def baichuan_format_messages(messages, tokenizer, **kwargs):
  513. from transformers import BatchEncoding
  514. def _parse_messages(messages, split_role='user'):
  515. system, rounds = '', []
  516. round = []
  517. for i, message in enumerate(messages):
  518. if message['role'] == 'system':
  519. assert i == 0, 'first message should be system message.'
  520. system = message['content']
  521. continue
  522. if message['role'] == split_role and round:
  523. rounds.append(round)
  524. round = []
  525. round.append(message)
  526. if round:
  527. rounds.append(round)
  528. return system, rounds
  529. messages = messages['messages']
  530. assistant_token_id = 196
  531. user_token_id = 195
  532. max_new_tokens = kwargs.get('max_new_tokens', None) or 2048
  533. model_max_length = 4096
  534. max_input_tokens = model_max_length - max_new_tokens
  535. system, rounds = _parse_messages(messages, split_role='user')
  536. system_tokens = tokenizer.encode(system)
  537. max_history_tokens = max_input_tokens - len(system_tokens)
  538. history_tokens = []
  539. for round in rounds[::-1]:
  540. round_tokens = []
  541. for message in round:
  542. if message['role'] == 'user':
  543. round_tokens.append(user_token_id)
  544. else:
  545. round_tokens.append(assistant_token_id)
  546. round_tokens.extend(tokenizer.encode(message['content']))
  547. if len(history_tokens) == 0 or len(history_tokens) + len(
  548. round_tokens) <= max_history_tokens:
  549. history_tokens = round_tokens + history_tokens # concat left
  550. if len(history_tokens) < max_history_tokens:
  551. continue
  552. break
  553. input_tokens = system_tokens + history_tokens
  554. if messages[-1]['role'] != 'assistant':
  555. input_tokens.append(assistant_token_id)
  556. input_tokens = input_tokens[-max_input_tokens:] # truncate left
  557. input_tokens = torch.LongTensor([input_tokens])
  558. return BatchEncoding({'input_ids': input_tokens})
  559. @LLMAdapterRegistry.register_format_messages('wizardlm')
  560. def wizardlm_format_messages(messages, tokenizer, **kwargs):
  561. def build_wizardlm_prompt(messages, tokenizer, **kwargs):
  562. default_system_message = 'A chat between a curious user and an artificial intelligence assistant.'
  563. 'The assistant gives helpful, detailed, and polite answers to the user\'s questions.'
  564. messages = messages['messages']
  565. # llama2 have system messages
  566. if messages[0]['role'] != 'system':
  567. messages = [{
  568. 'role': 'system',
  569. 'content': default_system_message
  570. }] + messages
  571. system_prompt = messages[0]['content']
  572. prompt_list = [system_prompt]
  573. for i, message in enumerate(messages[1:]):
  574. if message['role'] == 'user':
  575. user_prompt = message['content']
  576. prompt_list.append(f'USER: {user_prompt}')
  577. elif message['role'] == 'assistant':
  578. user_prompt = message['content']
  579. prompt_list.append(f'ASSISTANT: {user_prompt}</s>')
  580. prompts = ' '.join(prompt_list)
  581. return prompts
  582. prompts = build_wizardlm_prompt(messages, tokenizer, **kwargs)
  583. return tokenizer(prompts, return_token_type_ids=False, return_tensors='pt')
  584. @LLMAdapterRegistry.register_format_messages('wizardcode')
  585. def wizardcode_format_messages(messages, tokenizer, **kwargs):
  586. messages = messages['messages']
  587. assert len(messages) == 2, 'wizard code only support two messages.'
  588. system, user = '', ''
  589. for i, message in enumerate(messages):
  590. if message['role'] == 'system':
  591. assert i == 0, 'first message should be system message.'
  592. system = message['content']
  593. if message['role'] == 'user':
  594. assert i == 1, 'second message should be user message.'
  595. user = message['content']
  596. prompt = system + '\n\n### Instruction:\n' + user + '\n\n### Response:'
  597. inputs = tokenizer(
  598. prompt,
  599. return_token_type_ids=False,
  600. padding=False,
  601. add_special_tokens=False,
  602. return_tensors='pt')
  603. return inputs
  604. @LLMAdapterRegistry.register_format_messages('chatglm')
  605. def chatglm3_format_messages(messages, tokenizer, **kwargs):
  606. messages = messages['messages']
  607. query, history = messages[-1]['content'], messages[:-1]
  608. inputs = tokenizer.build_chat_input(query, history=history)
  609. eos_token_id = [
  610. tokenizer.eos_token_id,
  611. tokenizer.get_command('<|user|>'),
  612. tokenizer.get_command('<|observation|>')
  613. ]
  614. inputs['eos_token_id'] = eos_token_id
  615. return inputs
  616. @LLMAdapterRegistry.register_format_messages('qwen2')
  617. def qwen2_format_messages(messages, tokenizer, **kwargs):
  618. messages = messages['messages']
  619. text = tokenizer.apply_chat_template(
  620. messages, tokenize=False, add_generation_prompt=True)
  621. return tokenizer([text], return_tensors='pt')
  622. LLMAdapterRegistry.register_tokenizer('chatglm2', ChatGLM2Tokenizer)
  623. LLMAdapterRegistry.register_tokenizer('llama', Llama2Tokenizer)
  624. LLMAdapterRegistry.register_tokenizer('llama2', Llama2Tokenizer)