base.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import json
  3. import re
  4. from copy import deepcopy
  5. from typing import Any, Dict, List, Literal, Optional, Tuple, Union
  6. import torch
  7. import torch.nn.functional as F
  8. from modelscope import get_logger
  9. from torch.nn import Module
  10. from torch.nn.utils.rnn import pad_sequence
  11. from transformers import PreTrainedTokenizerBase, StoppingCriteria
  12. from .loss_scale import loss_scale_map
  13. from .tools_prompt import get_tools_prompt
  14. from .utils import load_batch, load_image, rescale_image, fetch_one, to_device, decode_base64
  15. from .utils import History, Prompt, StopWords, Context, Messages
# Module-level logger obtained from modelscope.
logger = get_logger()
# Fallback system prompt string; presumably passed as `default_system` by templates registered
# elsewhere in this module — confirm against the registration code.
DEFAULT_SYSTEM = 'You are a helpful assistant.'
# Registry consulted by `get_template`: maps a template_type to an info dict whose 'template'
# entry holds the (uninitialized) Template prototype.
TEMPLATE_MAPPING: Dict[str, Dict[str, Any]] = {}
  19. def get_template(
  20. template_type: str,
  21. tokenizer: PreTrainedTokenizerBase,
  22. default_system: Optional[str] = None,
  23. max_length: Optional[int] = None,
  24. truncation_strategy: Literal['delete', 'truncation_left'] = 'delete',
  25. **kwargs,
  26. ) -> 'Template':
  27. template_info = TEMPLATE_MAPPING[template_type]
  28. template = deepcopy(template_info['template'])
  29. template.init_template(tokenizer, default_system, max_length, truncation_strategy, **kwargs)
  30. return template
  31. def _findall(token_list: List[int], sub_token_list: Union[int, List[int]]) -> List[int]:
  32. """Find the index of a token in the token_list."""
  33. if isinstance(sub_token_list, int):
  34. sub_token_list = [sub_token_list]
  35. res = []
  36. idx = -1
  37. try:
  38. while True:
  39. idx = token_list.index(sub_token_list[0], idx + 1)
  40. if len(sub_token_list) == 1 or sub_token_list == token_list[idx:idx + len(sub_token_list)]:
  41. res.append(idx)
  42. except ValueError:
  43. pass
  44. return res
  45. def replace_img_tag(messages: Messages,
  46. replace_token: str,
  47. pattern=r'<img>(.+?)</img>') -> Tuple[str, History, List[str]]:
  48. images_path = []
  49. new_messages = []
  50. for i, m in enumerate(messages):
  51. m = m.copy()
  52. if m['content'] is None or m['role'] in ('tool', 'system', 'assistant'):
  53. new_messages.append(m)
  54. else:
  55. images_path += re.findall(pattern, m['content'])
  56. m['content'] = re.sub(pattern, replace_token, m['content'])
  57. new_messages.append(m)
  58. return messages, images_path
  59. class StopWordsCriteria(StoppingCriteria):
  60. """Adding extra stop words in template to prevent unstoppable generation
  61. Like suffixes and chat seps in the template.
  62. """
  63. def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_words: StopWords, **tokenizer_kwargs) -> None:
  64. self.tokenizer = tokenizer
  65. self.stop_words = stop_words
  66. self.tokenizer_kwargs = tokenizer_kwargs
  67. self.start_idx = -1
  68. def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs) -> bool:
  69. if self.start_idx == -1:
  70. self.start_idx = len(input_ids[0]) - 1
  71. tokenizer = self.tokenizer
  72. stop_words = self.stop_words
  73. # [-20:]: Assuming the end tokens do not exceed 20 tokens,
  74. # to avoid input_ids being too long and affecting efficiency.
  75. text = tokenizer.decode(input_ids[0, self.start_idx:][-20:], **self.tokenizer_kwargs)
  76. for stop_word in stop_words:
  77. if isinstance(stop_word, str):
  78. if stop_word in text:
  79. return True
  80. else: # list
  81. if len(stop_word) > 0 and input_ids[0].tolist()[-len(stop_word):] == stop_word:
  82. return True
  83. return False
  84. class Template:
  85. """A template class for all supported models.
  86. Args:
  87. prefix: Prefix tokens before the first turn's prompt
  88. prompt: A list of elements whose types are str and list of integers. The input query part of every turn.
  89. chat_sep: The chat separators between every turn.
  90. suffix: The end tokens after the chat finished.
  91. default_system: A default system instruction.
  92. system_prefix: The prefix if the `system` is not empty.
  93. auto_add_bos: By default, the bos_token is not added. The auto_add_bos option will determine
  94. whether to add it based on `tokenizer.encode('')`.
  95. tools_prompt: The tools prompt name
  96. tool_prompt: The tool prompt, usually useful when there is a tool role
  97. padding_side: The padding side
  98. infer_media_type: The media type supported by the multi-modals
  99. Examples:
  100. <start_of_output>system\nYou are a helpful assistant!<end_of_output>\n<bos><start_of_output>Who are you?<end_of_output>\n<start_of_output>assistant:I am a robot<end_of_output>\n<start_of_output>Who are you?<end_of_output>\n<start_of_output>assistant:I am a robot<end_of_output> # noqa
  101. ----------system------------ ---query---- --response- -----chatsep----- ---query--- --response- ----suffix-----
  102. ----------------------------system_prefix---------------------------- ---------------------------- prompt ------------------------------------- ---------------------------- prompt -------------------------------------
  103. """
  104. special_tokens = ['<image>', '<video>', '<audio>', '<bbox>', '<ref-object>']
  105. special_keys = ['images', 'videos', 'audios', 'objects']
  106. grounding_type = 'norm_1000'
  107. image_placeholder = ['<image>']
  108. load_medias = True
  109. compute_per_round_loss = True # for rlhf
  110. output_prompt_answer = False # for encoder-decoder & kto
  111. def __init__(self,
  112. prefix: Prompt,
  113. prompt: Prompt,
  114. chat_sep: Optional[Prompt],
  115. suffix: Prompt,
  116. default_system: Optional[str] = None,
  117. system_prefix: Optional[Prompt] = None,
  118. auto_add_bos: bool = False,
  119. tools_prompt: str = 'react_en',
  120. tool_prompt: Optional[Prompt] = None,
  121. padding_side: Literal['left', 'right'] = 'right',
  122. infer_media_type: Literal['interleave', 'dialogue', 'round'] = 'interleave') -> None:
  123. # check
  124. for x in [prefix, prompt, chat_sep, suffix, system_prefix]:
  125. assert x is None or isinstance(x, list)
  126. if default_system == '':
  127. default_system = None
  128. if self._has_system(prefix):
  129. assert system_prefix is None, 'The prefix already contains {{SYSTEM}}.'
  130. system_prefix = prefix
  131. prefix = self._replace_system(prefix)
  132. self.prefix = prefix
  133. self.system_prefix = system_prefix
  134. if self.system_prefix is None and not any(['{{SYSTEM}}' in context for context in prompt]):
  135. assert default_system is None, 'The template does not support `system`.'
  136. self.prompt = prompt
  137. self.chat_sep = chat_sep
  138. self.support_multi_round = self.chat_sep is not None
  139. self.suffix = suffix
  140. self.default_system = default_system
  141. self.use_default_system = True
  142. self.auto_add_bos = auto_add_bos
  143. self._is_init = False
  144. self.tools_prompt = tools_prompt
  145. self.tool_prompt = tool_prompt if tool_prompt is not None else self.prompt # default as user
  146. self.padding_side = padding_side
  147. self.infer_media_type = infer_media_type
  148. @staticmethod
  149. def _replace_system(prefix: Prompt) -> Prompt:
  150. """Replace system with the """
  151. return [p.replace('{{SYSTEM}}', '') for p in prefix if '{{SYSTEM}}' in p]
  152. @staticmethod
  153. def _has_system(prefix: Prompt) -> bool:
  154. return any(['{{SYSTEM}}' in p for p in prefix])
  155. @staticmethod
  156. def token_attr_to_id(tokenizer: PreTrainedTokenizerBase, value: Optional[Prompt]) -> Optional[Prompt]:
  157. """Turn `eos_token_id` to token id
  158. e.g. [['eos_token_id']] -> [[2]]
  159. """
  160. if value is None:
  161. return None
  162. res_value = []
  163. for v in value:
  164. if isinstance(v, list):
  165. res_v = []
  166. for sub_v in v:
  167. if isinstance(sub_v, str):
  168. sub_v = getattr(tokenizer, sub_v)
  169. res_v.append(sub_v)
  170. v = res_v
  171. res_value.append(v)
  172. return res_value
  173. def init_template(self,
  174. tokenizer: PreTrainedTokenizerBase,
  175. default_system: Optional[str] = None,
  176. max_length: Optional[int] = None,
  177. truncation_strategy: Literal['delete', 'truncation_left'] = 'delete',
  178. loss_scale: str = 'default',
  179. rescale_image: int = -1,
  180. **kwargs) -> None:
  181. """Init template by a tokenizer
  182. Args:
  183. tokenizer: The tokenizer to tokenize the sentence
  184. default_system: The default system to use if the dataset does not provide one
  185. max_length: Max length of the sequence
  186. truncation_strategy: The truncation strategy
  187. loss_scale: The loss scale function to use
  188. rescale_image: Rescale image to reduce memory usage, default `-1` means no limitation
  189. """
  190. assert self._is_init is False, 'The template has been initialized.'
  191. self._is_init = True
  192. self.tokenizer = tokenizer
  193. self.is_multimodal = getattr(tokenizer, 'is_multimodal', None)
  194. # if default_system is None. not change self.default_system
  195. if default_system == '':
  196. self.default_system = None
  197. elif default_system is not None:
  198. assert self.system_prefix is not None, (
  199. f'The template does not support `system`, template_type: {getattr(self, "template_type", None)}')
  200. self.default_system = default_system
  201. self.max_length = max_length
  202. self.truncation_strategy = truncation_strategy
  203. if isinstance(loss_scale, str):
  204. self.loss_scale = loss_scale_map.get(loss_scale, None)
  205. else:
  206. self.loss_scale = loss_scale
  207. self.rescale_image = rescale_image
  208. for key in ['prefix', 'prompt', 'chat_sep', 'suffix', 'system_prefix']:
  209. value = getattr(self, key)
  210. value = self.token_attr_to_id(tokenizer, value)
  211. setattr(self, key, value)
  212. def post_encode(self, model: Module, data: Any) -> Dict[str, Any]:
  213. """This method will be called after data_collator and before the forward
  214. Args:
  215. data: The `_data` field from the example batch, this field should be packed manually
  216. Returns:
  217. Any extra fields need to be passed into the model.forward
  218. """
  219. return {}
  220. def check_example(self, example: Dict[str, Any]) -> None:
  221. """Check example valid"""
  222. pass
  223. def add_default_tags(self, example: Dict[str, Any]) -> None:
  224. """Add default tags to example, this is for the multi-modal datasets
  225. 1. For the round infer_media_type, this method will check the tag equals with the chat round
  226. 2. Else, this method will try to add tags to the head of the messages
  227. Args:
  228. example: The input example
  229. """
  230. messages = example['messages']
  231. for media_key, media_tag in [('videos', '<video>'), ('images', '<image>'), ('audios', '<audio>')]:
  232. if example.get(media_key):
  233. _messages = [message for message in messages if message['role']!='system']
  234. n_round = len(_messages)
  235. assert n_round % 2 == 0
  236. history = [_messages[i:i+2] for i in range(n_round // 2)]
  237. if self.infer_media_type == 'round':
  238. for i, h, m in zip(range(n_round // 2), history, example[media_key]):
  239. num_media_tags = len(re.findall(media_tag, h[0]['content']))
  240. if m:
  241. assert num_media_tags <= 1, (
  242. 'The model includes at most one media per round. However, '
  243. f'this round contains {num_media_tags} media_tags. query: {h[0]}')
  244. if num_media_tags == 0:
  245. h[0]['content'] = media_tag + h[0]['content']
  246. else:
  247. assert num_media_tags == 0, f'Missing media. query: {h[0]}'
  248. example[media_key] = [m for m in example[media_key] if m]
  249. else:
  250. num_media_tags = len(re.findall(media_tag, '\n'.join([h[0]['content'] for h in history])))
  251. example[media_key] = [m for m in example[media_key] if m]
  252. num_media = len(example[media_key])
  253. num_new_tags = num_media - num_media_tags
  254. assert num_new_tags >= 0, f'Number of media: {num_media}, number of media_tags: {num_media_tags}'
  255. history[0][0]['content'] = media_tag * num_new_tags + history[0][0]['content']
  256. def replace_media_tags(self, example) -> None:
  257. """Replace the <img></img> with the images key and <image> tag
  258. Args:
  259. example: The input example
  260. """
  261. # Parse <img></img> format images and merged into images key
  262. if self.is_multimodal in {True, None}: # If False, do not perform replace_img_tag
  263. example['messages'], images_path = replace_img_tag(
  264. example.get('messages'), '<image>')
  265. if example.get('images') and images_path:
  266. raise ValueError('Do not mix use the <img></img> tag and <image> tag.')
  267. example['images'] = example.get('images') or [] + images_path
  268. # audio, video
  269. if self.is_multimodal in {True, None}:
  270. for k, tag, pattern in zip(['audios', 'videos'], ['<audio>', '<video>'],
  271. [r'<audio>(.+?)</audio>', r'<video>(.+?)</video>']):
  272. example['messages'], medias_path = replace_img_tag(
  273. example.get('messages'), tag, pattern)
  274. example[k] = example.get(k) or [] + medias_path
  275. def _preprocess_media(self, example):
  276. """Preprocess multi-modal media resources in one example
  277. 1. Wrap all values in media keys to list
  278. 2. Replace <img></img> tags
  279. 3. Add or check missing tags to examples
  280. 4. Parse the string field in the `objects` field to jsons
  281. 5. Load images if needed
  282. Args:
  283. example: The input example
  284. """
  285. multimodal_keys = {
  286. 'audio': 'audios',
  287. 'image': 'images',
  288. 'video': 'videos',
  289. }
  290. # Format media_keys to list
  291. for media_key in multimodal_keys.values():
  292. if example.get(media_key) and not isinstance(example[media_key], (tuple, list)):
  293. # change images field to list
  294. example[media_key] = [example[media_key]]
  295. self.replace_media_tags(example)
  296. # Add default tags to examples to note where to put the medias into the sequence
  297. self.add_default_tags(example)
  298. # Format objects(groundings/refs) to json
  299. if example.get('objects') and isinstance(example['objects'], str):
  300. # reload grounding from str
  301. example['objects'] = json.loads(example['objects'])
  302. objects = []
  303. for object in example['objects']:
  304. # Compatible with list format
  305. if isinstance(object, list):
  306. object = {
  307. 'caption': object[0],
  308. 'bbox': object[1],
  309. 'bbox_type': None,
  310. 'image': 0,
  311. }
  312. objects.append(object)
  313. example['objects'] = objects
  314. # Load image into PIL format
  315. images = example.get('images') or []
  316. if images:
  317. if example.get('objects') or self.load_medias:
  318. images = load_batch(images, load_image) # base64/local_path -> PIL.Image
  319. if example.get('objects'):
  320. # Normalize grounding bboxes
  321. self.normalize_bbox(example['objects'], images, to_type=self.grounding_type)
  322. if self.load_medias and self.grounding_type != 'real':
  323. images = [rescale_image(img, self.rescale_image) for img in images]
  324. if not self.load_medias: # fix pt & qwen-vl
  325. images = decode_base64(images=images)['images'] # PIL.Image/base64 -> local_path
  326. example['images'] = images
  327. def preprocess(self, example):
  328. # Duplicate example and create a new one to prepare in-place changes
  329. example = example.copy()
  330. template_type: Optional[str] = getattr(self, 'template_type', None)
  331. tools: Union[List[Any], str] = example.get('tools') or []
  332. # Template needs to be initialized
  333. if not self._is_init:
  334. raise ValueError(
  335. 'Template is not initialized, please use the `get_template` function to obtain the template.')
  336. messages = example['messages']
  337. system_round = [message for message in messages if message['role'] == 'system']
  338. messages = [message for message in messages if message['role'] != 'system']
  339. # Reset system (by default value and agent tools)
  340. system: Optional[str] = system_round[0]['content'] if system_round else ''
  341. if not system:
  342. if self.use_default_system:
  343. system = self.default_system
  344. else:
  345. assert self.system_prefix is not None, (
  346. f'The template does not support `system`, template_type: {template_type}')
  347. if tools:
  348. if isinstance(tools, str):
  349. tools = json.loads(tools)
  350. if system is None:
  351. system = ''
  352. system += get_tools_prompt(tools, self.tools_prompt)
  353. if system:
  354. if not system_round:
  355. system_round = [{'role': 'system', 'content': None}]
  356. system_round[0]['content'] = system
  357. if len(messages) > 1:
  358. assert self.support_multi_round, (
  359. f'The template does not support multi-round chat, template_type: {template_type}')
  360. example['messages'] = system_round + messages
  361. self._preprocess_media(example)
  362. # Check the example that whether matching the very template's rules
  363. self.check_example(example)
  364. return example
  365. def encode(self, example: Dict[str, Any], streaming: bool = False, is_training: bool = False, **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  366. """The entrance method of Template!
  367. Args:
  368. example: The input example
  369. streaming: If is streaming mode
  370. is_training: Use template in training
  371. **kwargs:
  372. model: The model instance, use only in `is_training=False`
  373. Returns:
  374. if not streaming mode, returns tuple of (example, tokenizer_kwargs), else return example only
  375. """
  376. example = self.preprocess(example)
  377. res = self._encode(example, **kwargs)
  378. inputs = res[0]
  379. if not is_training and '_data' in inputs:
  380. model = kwargs.get('model')
  381. assert model is not None
  382. data = inputs.pop('_data')
  383. data = to_device(data, model.device)
  384. inputs.update(self.post_encode(model, data))
  385. return res if not streaming else inputs
  386. def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  387. """return: inputs, tokenizer_kwargs"""
  388. messages = example['messages']
  389. is_multi_modal: bool = any([example.get(key) for key in Template.special_keys])
  390. inputs, tokenizer_kwargs = self._concat_and_tokenize(
  391. messages,
  392. self.truncation_strategy,
  393. auto_add_bos=self.auto_add_bos,
  394. is_multi_modal=is_multi_modal,
  395. example=example)
  396. if inputs.get('labels') is None:
  397. inputs.pop('loss_scale', None)
  398. return inputs, tokenizer_kwargs
  399. def _concat_context_list(
  400. self,
  401. context_list: List[Context],
  402. res_context_list: List[Context], # inplace
  403. loss_scale_list: List[float], # inplace
  404. system: Optional[str] = None,
  405. query: Optional[str] = None,
  406. response: Optional[str] = None,
  407. round0: Optional[int] = None,
  408. compute_loss: bool = True) -> None:
  409. """Concat context list and replace placeholder"""
  410. round1 = None
  411. if round0 is not None:
  412. round1 = str(round0 + 1)
  413. round0 = str(round0)
  414. for context in context_list:
  415. if isinstance(context, str):
  416. if '{{RESPONSE}}' == context:
  417. assert response is not None
  418. if compute_loss:
  419. content_part, weight_part = self.loss_scale(query, response)
  420. else:
  421. content_part, weight_part = [response], [0.]
  422. res_context_list.extend(content_part)
  423. loss_scale_list.extend(weight_part)
  424. continue
  425. old_str_list = ['{{SYSTEM}}', '{{QUERY}}', '{{ROUND0}}', '{{ROUND1}}']
  426. new_str_list = [system, query, round0, round1]
  427. for (old_str, new_str) in zip(old_str_list, new_str_list):
  428. if new_str is not None and old_str in context:
  429. context = context.replace(old_str, new_str)
  430. if len(context) == 0:
  431. continue
  432. res_context_list.append(context)
  433. loss_scale_list.append(0.)
  434. def _simplify_context_list(self, context_list: List[Context], loss_scale_list: List[float],
  435. **kwargs) -> Tuple[List[Context], List[float]]:
  436. """Merge anything in the context to simplify the inputs"""
  437. is_multi_modal: bool = kwargs.pop('is_multi_modal', False)
  438. if is_multi_modal:
  439. context_list, loss_scale_list = self.split_special_tokens(context_list, loss_scale_list)
  440. context_list, loss_scale_list = self.pre_tokenize(context_list, loss_scale_list, **kwargs)
  441. res: List[Context] = [] # result of context_list
  442. res_loss_scale: List[float] = [] # result of loss_scale_list
  443. temp: List[str] = []
  444. temp_loss_scale = 0.
  445. for i, (context, loss_scale) in enumerate(zip(context_list, loss_scale_list)):
  446. if isinstance(context, str) and (loss_scale == temp_loss_scale):
  447. temp.append(context)
  448. else:
  449. if len(temp) > 0:
  450. res.append(''.join(temp))
  451. res_loss_scale.append(temp_loss_scale)
  452. temp.clear()
  453. if isinstance(context, str): # loss_scale diff
  454. temp.append(context)
  455. else:
  456. res.append(context)
  457. res_loss_scale.append(loss_scale)
  458. temp_loss_scale = loss_scale
  459. if len(temp) > 0:
  460. res.append(''.join(temp))
  461. res_loss_scale.append(temp_loss_scale)
  462. return res, res_loss_scale
  463. @staticmethod
  464. def split_special_tokens(context_list: List[Context],
  465. loss_scale_list: List[float]) -> Tuple[List[Context], List[float]]:
  466. """Split special tokens, for example `<image>`, `<video>`, this will help the replace_tag operation"""
  467. from .utils import split_str_parts_by
  468. res: List[Context] = []
  469. loss_scale_res: List[float] = []
  470. for context, loss_scale in zip(context_list, loss_scale_list):
  471. contexts = []
  472. if isinstance(fetch_one(context), str):
  473. for d in split_str_parts_by(context, Template.special_tokens):
  474. contexts.extend([d['key'], d['content']])
  475. contexts = [c for c in contexts if c]
  476. res.extend(contexts)
  477. loss_scale_res.extend([loss_scale] * len(contexts))
  478. else:
  479. res.append(context)
  480. loss_scale_res.append(loss_scale)
  481. return res, loss_scale_res
  482. def _tokenize(self, context, **tokenizer_kwargs):
  483. return self.tokenizer(
  484. context, return_attention_mask=False, add_special_tokens=False, **tokenizer_kwargs)['input_ids']
  485. def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
  486. example: Dict[str, Any]) -> List[Context]:
  487. """Override this function to do your own replace operation.
  488. This method is used to replace standard tags like `<image>` to some tokens that the model needs.
  489. Args:
  490. media_type: The modal.
  491. index: The index of the medias, for example 0 represents the first elements in `images`
  492. example: The input example
  493. Returns:
  494. The content or input_ids after replacement.
  495. """
  496. if media_type == 'image':
  497. return self.image_placeholder
  498. elif media_type == 'video':
  499. return ['<video>']
  500. elif media_type == 'audio':
  501. return ['<audio>']
  502. def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]:
  503. """Replace objects referenced by the bbox to contents or input_ids. This is useful in the grounding task.
  504. Override this function to do your own replace operation.
  505. Args:
  506. index: The index in the `objects` key
  507. example: The input example
  508. Returns:
  509. The contents or input_ids replaced
  510. """
  511. objects = example.get('objects')
  512. if objects:
  513. object_ = objects[index]
  514. return [object_['caption']]
  515. else:
  516. return ['<ref-object>']
  517. def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]:
  518. """Replace bbox pointing to the objects to contents or input_ids. This is useful in the grounding task.
  519. Override this function to do your own replace operation.
  520. Args:
  521. index: The index in the `objects` key
  522. example: The input example
  523. Returns:
  524. The contents or input_ids replaced
  525. """
  526. objects = example.get('objects')
  527. if objects:
  528. object_ = objects[index]
  529. if isinstance(object_['bbox'][0], list):
  530. all_objects = ''
  531. for sub_object in object_['bbox']:
  532. all_objects += f'[({sub_object[0]},{sub_object[1]}),' f'({sub_object[2]},{sub_object[3]})],'
  533. all_objects = all_objects[:-1]
  534. return [all_objects]
  535. else:
  536. return [f'[({object_["bbox"][0]},{object_["bbox"][1]}),({object_["bbox"][2]},{object_["bbox"][3]})]']
  537. else:
  538. return ['<bbox>']
  539. @classmethod
  540. def normalize_bbox(cls, objects: List[Dict[str, Any]], images: List[Any],
  541. to_type: Literal['real', 'norm_1000', 'norm_1']) -> None:
  542. """Normalize bbox to needed.
  543. to_type support real/norm_1000/norm_1, which literally means the coordinates in real, or normalized by 1000,
  544. or normalized by 1.
  545. Args:
  546. objects: The objects containing the bbox
  547. images: The images list
  548. to_type: The coordinate type needed by the model.
  549. """
  550. if not objects or not images:
  551. return
  552. for object in objects:
  553. bbox = object['bbox']
  554. bbox_type = object['bbox_type']
  555. idx = object['image']
  556. image = images[idx]
  557. if bbox_type == 'real':
  558. if to_type == 'real':
  559. continue
  560. width, height = image.width, image.height
  561. if isinstance(bbox[0], list):
  562. bboxes = []
  563. for _box in bbox:
  564. bboxes.append([
  565. int(coord / dim * 999) if to_type == 'norm_1000' else coord / dim
  566. for coord, dim in zip(_box, [width, height, width, height])
  567. ])
  568. object['bbox'] = bboxes
  569. else:
  570. object['bbox'] = [
  571. int(coord / dim * 999) if to_type == 'norm_1000' else coord / dim
  572. for coord, dim in zip(bbox, [width, height, width, height])
  573. ]
  574. object['bbox_type'] = to_type
  575. elif bbox_type == 'norm_1000':
  576. if to_type == 'norm_1000':
  577. continue
  578. if to_type == 'norm_1':
  579. object['bbox'] = [coord / 999. for coord in bbox]
  580. elif to_type == 'real':
  581. width, height = image.width, image.height
  582. object['bbox'] = [
  583. int(coord / 999. * dim) for coord, dim in zip(bbox, [width, height, width, height])
  584. ]
  585. object['bbox_type'] = to_type
  586. elif bbox_type == 'norm_1':
  587. if to_type == 'norm_1':
  588. continue
  589. if to_type == 'norm_1000':
  590. object['bbox'] = [int(coord * 999) for coord in bbox]
  591. elif to_type == 'real':
  592. width, height = image.width, image.height
  593. object['bbox'] = [int(coord * dim) for coord, dim in zip(bbox, [width, height, width, height])]
  594. object['bbox_type'] = to_type
  595. def pre_tokenize(self, context_list: List[Context], loss_scale_list: List[float],
  596. **kwargs) -> Tuple[List[Context], List[float]]:
  597. """This method happens before tokenization, replace standard tags to the contents or input_ids needed by
  598. the model.
  599. Args:
  600. context_list: The content list
  601. loss_scale_list: The loss scale list
  602. Returns:
  603. The context_list and loss_scale_list after replacement.
  604. """
  605. example = kwargs.get('example') # get x_index
  606. res: List[Context] = [] # result of context_list
  607. res_loss_scale: List[float] = [] # result of loss_scale_list
  608. for k in ['image', 'video', 'audio']:
  609. example[f'{k}_index'] = 0
  610. for context, loss_scale in zip(context_list, loss_scale_list):
  611. for k in ['image', 'video', 'audio']:
  612. if context == f'<{k}>':
  613. c_list = self.replace_tag(k, example[f'{k}_index'], example)
  614. example[f'{k}_index'] += 1
  615. break
  616. else:
  617. if context == '<ref-object>':
  618. c_list = self.replace_object(example.get('object_index', 0), example)
  619. example['object_index'] = example.get('object_index', 0) + 1
  620. elif context == '<bbox>':
  621. c_list = self.replace_box(example.get('box_index', 0), example)
  622. example['box_index'] = example.get('box_index', 0) + 1
  623. else:
  624. c_list = [context]
  625. res += c_list
  626. res_loss_scale += [loss_scale] * len(c_list)
  627. return res, res_loss_scale
  628. def _encode_context_list(
  629. self,
  630. context_list: List[Context],
  631. loss_scale_list: Optional[List[float]] = None) -> Tuple[List[int], List[int], List[float], Dict[str, Any]]:
  632. """return: input_ids, labels, tokenizer_kwargs"""
  633. input_ids: List[int] = []
  634. labels: List[int] = []
  635. loss_scale: List[float] = []
  636. tokenizer_kwargs = {}
  637. if loss_scale_list is None:
  638. loss_scale_list = [0.] * len(context_list)
  639. for i, (context, loss_weight) in enumerate(zip(context_list, loss_scale_list)):
  640. if isinstance(context, str):
  641. # tokenizer_kwargs is the returned tokenizer_kwargs,
  642. # while curr_tokenizer_kwargs is the tokenizer_kwargs for the current context.
  643. curr_tokenizer_kwargs = self._get_tokenizer_kwargs(context)
  644. self._concat_tokenizer_kwargs(tokenizer_kwargs, curr_tokenizer_kwargs)
  645. token_list = self._tokenize(context, **curr_tokenizer_kwargs)
  646. else:
  647. token_list = context
  648. input_ids += token_list
  649. if loss_scale_list[i] > 0.0:
  650. labels += token_list
  651. else:
  652. labels += [-100] * len(token_list)
  653. loss_scale.extend([loss_weight] * len(token_list))
  654. return input_ids, labels, loss_scale, tokenizer_kwargs
  655. @staticmethod
  656. def use_dynamic_eos(labels: List[int], suffix_tokens_id: List[int]) -> None:
  657. suffix_len = len(suffix_tokens_id)
  658. start = 0
  659. for i in range(1, len(labels)):
  660. if labels[i - 1] >= 0 and labels[i] == -100:
  661. start = i
  662. if start > 0 and labels[i - 1] == -100 and labels[i] >= 0:
  663. # [0, 1, 2, -100(start), -100, 3(i), 4]
  664. length = i - start
  665. if length >= suffix_len:
  666. labels[start:start + suffix_len] = suffix_tokens_id
  667. def _concat_and_tokenize(self,
  668. messages: List[Dict[str, str]],
  669. truncation_strategy: str,
  670. auto_add_bos: bool = False,
  671. **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  672. """
  673. return: inputs, tokenizer_kwargs
  674. """
  675. system = [message for message in messages if message['role'] == 'system']
  676. messages = [message for message in messages if message['role'] != 'system']
  677. if len(system) > 0:
  678. system = system[0]['content']
  679. else:
  680. system = None
  681. assert len(messages) >= 1
  682. if len(messages) == 1:
  683. if messages['role'] == 'response':
  684. history = [None, messages['content']]
  685. history_roles = [None, messages['role']]
  686. else:
  687. history = [messages['content'], None]
  688. history_roles = [messages['role'], None]
  689. else:
  690. assert len(messages) % 2 == 0
  691. history = [[messages[i]['content'], messages[i+1]['content']] for i in range(len(messages) // 2)]
  692. history_roles = [[messages[i]['role'], messages[i + 1]['role']] for i in range(len(messages) // 2)]
  693. res_context_list: List[Context] = []
  694. loss_scale_list: List[float] = []
  695. if auto_add_bos:
  696. bos_token_id = self.tokenizer.bos_token_id
  697. if isinstance(bos_token_id, int) and bos_token_id in self.tokenizer.encode(''):
  698. res_context_list.append([bos_token_id])
  699. loss_scale_list.append(0.)
  700. prompt = self.prompt.copy()
  701. if system is None:
  702. prompt = [context for context in prompt if '{{SYSTEM}}' not in context]
  703. if system is None or any(['{{SYSTEM}}' in context for context in prompt]):
  704. prefix = self.prefix
  705. else:
  706. prefix = self.system_prefix
  707. self._concat_context_list(prefix, res_context_list, loss_scale_list, system=system)
  708. for i, ((q, r), (qr, rr)) in enumerate(zip(history, history_roles)):
  709. context_list = self.tool_prompt.copy() if qr == 'tool' else prompt.copy()
  710. extra_context_list = []
  711. is_suffix = False
  712. if i < len(history) - 1:
  713. context_list = [context for context in context_list if '{{SYSTEM}}' not in context]
  714. context_list.append('{{RESPONSE}}')
  715. if history[i + 1][0]:
  716. extra_context_list = self.chat_sep
  717. elif r is not None:
  718. # last response
  719. context_list.append('{{RESPONSE}}')
  720. extra_context_list = self.suffix
  721. is_suffix = True
  722. if q or r:
  723. self._concat_context_list(
  724. context_list,
  725. res_context_list,
  726. loss_scale_list,
  727. query=q,
  728. response=r,
  729. system=system,
  730. round0=i,
  731. compute_loss=self.compute_per_round_loss or is_suffix)
  732. res_context_list += extra_context_list
  733. loss_scale_list += ([1.] if is_suffix else [0.]) * len(extra_context_list)
  734. inputs = {}
  735. if self.output_prompt_answer:
  736. # tokenizer_kwargs: use prompt
  737. answer_len = len(extra_context_list) + bool(history[-1][-1] is not None)
  738. total_len = len(res_context_list)
  739. for key, _slice in zip(['answer', 'prompt'],
  740. [slice(total_len - answer_len, total_len),
  741. slice(0, total_len - answer_len)]):
  742. _res_context_list, _loss_scale_list = self._simplify_context_list(res_context_list[_slice],
  743. loss_scale_list[_slice], **kwargs)
  744. input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
  745. _res_context_list, _loss_scale_list)
  746. inputs[f'{key}_input_ids'], inputs[f'{key}_labels'] = input_ids, labels
  747. if self.loss_scale:
  748. inputs[f'{key}_loss_scale'] = loss_scale
  749. input_ids = inputs['prompt_input_ids'] + inputs['answer_input_ids']
  750. labels = inputs['prompt_labels'] + inputs['answer_labels']
  751. if history[-1][-1] is None:
  752. assert len(inputs['answer_labels']) == 0
  753. inputs['answer_labels'] = None
  754. else:
  755. res_context_list, loss_scale_list = self._simplify_context_list(res_context_list, loss_scale_list, **kwargs)
  756. input_ids, labels, loss_scale, tokenizer_kwargs = self._encode_context_list(
  757. res_context_list, loss_scale_list)
  758. if labels is not None:
  759. self.use_dynamic_eos(labels, self._encode_context_list(self.suffix)[0])
  760. if history[-1][-1] is None:
  761. labels = None
  762. if self.max_length is not None:
  763. if truncation_strategy == 'delete' and len(input_ids) > self.max_length:
  764. logger.warn(f'Current length of row({len(input_ids)}) is larger'
  765. f' than the max_length({self.max_length}), deleted.')
  766. return {}, {}
  767. input_ids = input_ids[-self.max_length:]
  768. if labels is not None:
  769. labels = labels[-self.max_length:]
  770. if loss_scale is not None:
  771. loss_scale = loss_scale[-self.max_length:]
  772. inputs['input_ids'] = input_ids
  773. inputs['labels'] = labels
  774. if self.loss_scale:
  775. inputs['loss_scale'] = loss_scale
  776. return inputs, tokenizer_kwargs
  777. def _get_tokenizer_kwargs(self, context: str) -> Dict[str, Any]:
  778. """return: curr_tokenizer_kwargs"""
  779. return {}
  780. def _concat_tokenizer_kwargs(self, tokenizer_kwargs: Dict[str, Any], curr_tokenizer_kwargs: Dict[str, Any]) -> None:
  781. assert len(tokenizer_kwargs) == 0
  782. @staticmethod
  783. def pad_sequence(sequences: List[torch.Tensor],
  784. padding_value: float = 0.,
  785. padding_side: Literal['right', 'left'] = 'right') -> torch.Tensor:
  786. """Pad sequence by some side
  787. Args:
  788. sequences: The input sequences in tensor.
  789. padding_value: The padding value
  790. padding_side: The padding side
  791. Returns:
  792. A tensor after padding
  793. """
  794. padding_right = padding_side == 'right'
  795. if padding_right:
  796. return pad_sequence(sequences, batch_first=True, padding_value=padding_value)
  797. max_len = max([s.size(0) for s in sequences])
  798. padded_sequences = []
  799. for seq in sequences:
  800. pad_length = max_len - seq.size(0)
  801. pad_tuple = [0] * ((seq.dim() - 1) * 2) + [pad_length, 0]
  802. padded_seq = F.pad(seq, tuple(pad_tuple), 'constant', padding_value)
  803. padded_sequences.append(padded_seq)
  804. return torch.stack(padded_sequences)
  805. def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
  806. """
  807. Args:
  808. batch(`List[Dict[str, Any]]`): The input data in batch
  809. padding_to(`int`, optional): Whether padding the batch to a fixed length, if none, the batch
  810. will be padded to the `longest`
  811. """
  812. tokenizer = self.tokenizer
  813. assert tokenizer.pad_token_id is not None
  814. padding_right = self.padding_side == 'right'
  815. res = {}
  816. if 'inputs_embeds' in batch[0]:
  817. inputs_embeds = [b['inputs_embeds'] for b in batch]
  818. res['inputs_embeds'] = inputs_embeds
  819. res['attention_mask'] = [
  820. torch.ones((inputs_embeds[i].shape[0]), dtype=torch.int64) for i in range(len(inputs_embeds))
  821. ]
  822. elif 'input_ids' in batch[0]:
  823. input_ids = [torch.tensor(b['input_ids']) for b in batch]
  824. res['input_ids'] = input_ids
  825. res['attention_mask'] = [torch.ones(len(input_ids[i]), dtype=torch.int64) for i in range(len(input_ids))]
  826. for key in ['labels', 'loss_scale', 'position_ids']:
  827. if key in batch[0]:
  828. res[key] = [torch.tensor(b[key]) for b in batch]
  829. if padding_to is not None:
  830. assert 'input_ids' in res
  831. padding_len = padding_to - res['input_ids'][0].shape[-1]
  832. if padding_len > 0:
  833. for key, value in zip(['input_ids', 'attention_mask', 'labels', 'loss_scale', 'position_ids'],
  834. [tokenizer.pad_token_id, 0, -100, 0., -1]):
  835. if key in res:
  836. res[key][0] = F.pad(res[key][0], (0, padding_len) if padding_right else (padding_len, 0),
  837. 'constant', value)
  838. for key, value in zip(['input_ids', 'inputs_embeds', 'attention_mask', 'labels', 'loss_scale', 'position_ids'],
  839. [tokenizer.pad_token_id, 0., 0, -100, 0., -1]):
  840. if key in res:
  841. res[key] = self.pad_sequence(res[key], value, self.padding_side)
  842. if '_data' in batch[0]:
  843. res['_data'] = [b['_data'] for b in batch]
  844. # multimodal
  845. pixel_values = [b['pixel_values'] for b in batch if b.get('pixel_values') is not None]
  846. if len(pixel_values) > 0:
  847. res['pixel_values'] = torch.concat(pixel_values)
  848. image_sizes = [b['image_sizes'] for b in batch if b.get('image_sizes') is not None]
  849. if len(image_sizes) > 0:
  850. res['image_sizes'] = torch.concat(image_sizes)
  851. pixel_values_videos = [b['pixel_values_videos'] for b in batch if b.get('pixel_values_videos') is not None]
  852. if len(pixel_values_videos) > 0:
  853. res['pixel_values_videos'] = torch.concat(pixel_values_videos)
  854. return res
  855. @classmethod
  856. def get_generate_ids(cls, generate_ids: torch.Tensor, input_token_len: int) -> List[int]:
  857. if isinstance(generate_ids, torch.Tensor):
  858. generate_ids = generate_ids.tolist()
  859. if len(generate_ids) >= 1 and isinstance(generate_ids[0], (list, tuple)):
  860. generate_ids = generate_ids[0]
  861. return cls._get_generate_ids(generate_ids, input_token_len)
  862. @staticmethod
  863. def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
  864. return generate_ids[input_token_len:]
  865. @staticmethod
  866. def _is_chinese_char(cp: int) -> bool:
  867. """Checks whether CP is the codepoint of a CJK character."""
  868. # copy from transformers.generation.streamers.TextStreamer
  869. if ((0x4E00 <= cp <= 0x9FFF) or (0x3400 <= cp <= 0x4DBF) or (0x20000 <= cp <= 0x2A6DF)
  870. or (0x2A700 <= cp <= 0x2B73F) or (0x2B740 <= cp <= 0x2B81F) or (0x2B820 <= cp <= 0x2CEAF)
  871. or (0xF900 <= cp <= 0xFAFF) or (0x2F800 <= cp <= 0x2FA1F)):
  872. return True
  873. return False
  874. @classmethod
  875. def _get_safe_print_idx(cls, response: str, print_idx: int, is_finished: bool = False) -> int:
  876. if is_finished:
  877. return len(response)
  878. if response.endswith('\n') or len(response) > 0 and cls._is_chinese_char(ord(response[-1])):
  879. print_idx = len(response)
  880. else:
  881. print_idx = max(response.rfind(' ') + 1, print_idx)
  882. return print_idx
  883. def generate_ids_to_response(
  884. self,
  885. generate_ids: List[int],
  886. is_finished: bool = True,
  887. *,
  888. tokenizer_kwargs: Optional[Dict[str, Any]] = None,
  889. # only stream=True
  890. return_delta: bool = False,
  891. print_idx: Optional[List[int]] = None,
  892. first_num_space: Optional[List[int]] = None,
  893. ):
  894. if tokenizer_kwargs is None:
  895. tokenizer_kwargs = {}
  896. tokenizer = self.tokenizer
  897. if hasattr(generate_ids, 'tolist'):
  898. generate_ids = generate_ids.tolist()
  899. # avoid printing template.suffix[-1])
  900. if isinstance(self.suffix[-1], list) and (not is_finished or is_finished
  901. and generate_ids[-len(self.suffix[-1]):] == self.suffix[-1]):
  902. generate_ids = generate_ids[:-len(self.suffix[-1])]
  903. if not is_finished or is_finished and generate_ids[-1:] == [self.tokenizer.eos_token_id]:
  904. generate_ids = generate_ids[:-1]
  905. response = tokenizer.decode(generate_ids, **tokenizer_kwargs)
  906. if first_num_space is not None:
  907. # Avoid the occurrence of repeated words in sentence.
  908. res_fns = first_num_space # res_first_num_space
  909. first_num_space = first_num_space[0]
  910. cur_num_space = len(response) - len(response.lstrip(' '))
  911. if not is_finished and first_num_space == -1:
  912. first_num_space = cur_num_space
  913. res_fns[0] = first_num_space
  914. if cur_num_space < first_num_space:
  915. response = ' ' * (first_num_space - cur_num_space) + response
  916. elif cur_num_space > first_num_space:
  917. response = response[cur_num_space - first_num_space:]
  918. if isinstance(self.suffix[-1],
  919. str) and (not is_finished or is_finished and response[-len(self.suffix[-1]):] == self.suffix[-1]):
  920. idx = max(len(response) - len(self.suffix[-1]), 0)
  921. # To avoid response length being shorter than previous response length during streaming.
  922. if print_idx is not None:
  923. idx = max(idx, print_idx[0])
  924. response = response[:idx]
  925. if print_idx is not None:
  926. old_print_idx = print_idx[0]
  927. if not is_finished:
  928. # avoid printing incomplete words
  929. print_idx[0] = self._get_safe_print_idx(response, print_idx[0])
  930. response = response[:print_idx[0]]
  931. if return_delta:
  932. response = response[old_print_idx:]
  933. else:
  934. assert is_finished and not return_delta
  935. return response
  936. def post_process_generate_response(self, response: str, example: dict) -> str:
  937. return response