template.py 97 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import json
  3. import os
  4. import re
  5. from functools import partial
  6. from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union
  7. import torch
  8. import transformers
  9. from packaging import version
  10. from transformers import PreTrainedTokenizerBase
  11. from transformers.dynamic_module_utils import get_class_from_dynamic_module
  12. from transformers.integrations import is_deepspeed_zero3_enabled
  13. from modelscope import get_logger
  14. from .base import Template, TEMPLATE_MAPPING
  15. from .utils import (load_audio_qwen, load_batch, load_image, load_video_cogvlm2, load_video_internvl,
  16. load_video_llava, load_video_minicpmv_mplug_owl3, load_video_qwen2,
  17. transform_image, upper_bound, fetch_one)
  18. logger = get_logger()
  19. DEFAULT_SYSTEM = 'You are a helpful assistant.'
  20. History = List[Union[Tuple[str, str], List[str]]]
  21. Prompt = List[Union[str, List[int], List[str]]]
  22. StopWords = Prompt
  23. Context = Union[str, List[int]]
  24. class TemplateType:
  25. # text-generation
  26. default_generation = 'default-generation'
  27. chatglm_generation = 'chatglm-generation'
  28. qwen_vl_generation = 'qwen-vl-generation'
  29. qwen_audio_generation = 'qwen-audio-generation'
  30. # chat
  31. default = 'default'
  32. qwen = 'qwen'
  33. qwen_vl = 'qwen-vl'
  34. qwen_audio = 'qwen-audio'
  35. qwen2_audio = 'qwen2-audio'
  36. qwen2_audio_generation = 'qwen2-audio-generation'
  37. qwen2_vl = 'qwen2-vl'
  38. modelscope_agent = 'modelscope-agent'
  39. baichuan = 'baichuan'
  40. chatglm2 = 'chatglm2'
  41. chatglm3 = 'chatglm3'
  42. chatglm4 = 'chatglm4'
  43. codegeex4 = 'codegeex4'
  44. llama = 'llama' # llama2
  45. llama3 = 'llama3'
  46. reflection = 'reflection'
  47. longwriter_llama3 = 'longwriter-llama3'
  48. # llava-hf
  49. llava1_5 = 'llava1_5'
  50. llava_mistral = 'llava-mistral'
  51. llava_vicuna = 'llava-vicuna'
  52. llava_yi = 'llava-yi'
  53. llama3_llava_next_hf = 'llama-llava-next-hf'
  54. llava_next_llama3 = 'llava-next-llama3'
  55. llava_qwen_hf = 'llama-qwen-hf'
  56. llava_onevision_qwen = 'llava-onevision-qwen'
  57. # llava-video
  58. llava_next_video = 'llava-next-video'
  59. llava_next_video_yi = 'llava-next-video-yi'
  60. # lmms-lab:llava
  61. llama3_llava_next = 'llama3-llava-next'
  62. llava_qwen = 'llava-qwen'
  63. # xtuner:llava
  64. llava_llama_instruct = 'llava-llama-instruct'
  65. idefics3 = 'idefics3'
  66. mistral_nemo = 'mistral-nemo'
  67. openbuddy = 'openbuddy'
  68. openbuddy2 = 'openbuddy2'
  69. internlm = 'internlm'
  70. internlm2 = 'internlm2'
  71. internlm_xcomposer2 = 'internlm-xcomposer2'
  72. internlm_xcomposer2_4khd = 'internlm-xcomposer2-4khd'
  73. internlm_xcomposer2_5 = 'internlm-xcomposer2_5'
  74. internvl = 'internvl'
  75. internvl2 = 'internvl2'
  76. internvl_phi3 = 'internvl-phi3'
  77. internvl2_phi3 = 'internvl2-phi3'
  78. florence = 'florence'
  79. yi_coder = 'yi-coder'
  80. yi_vl = 'yi-vl'
  81. yuan = 'yuan'
  82. xverse = 'xverse'
  83. ziya = 'ziya'
  84. skywork = 'skywork'
  85. bluelm = 'bluelm'
  86. zephyr = 'zephyr'
  87. sus = 'sus'
  88. deepseek = 'deepseek'
  89. numina_math = 'numina-math'
  90. deepseek_coder = 'deepseek-coder'
  91. deepseek_vl = 'deepseek-vl'
  92. deepseek2 = 'deepseek2'
  93. deepseek2_5 = 'deepseek2_5'
  94. codefuse_codellama = 'codefuse-codellama'
  95. codefuse = 'codefuse'
  96. cogvlm = 'cogvlm'
  97. cogvlm2_video = 'cogvlm2-video'
  98. glm4v = 'glm4v'
  99. cogagent_chat = 'cogagent-chat'
  100. cogagent_instruct = 'cogagent-instruct'
  101. orion = 'orion'
  102. minicpm = 'minicpm'
  103. minicpm_v = 'minicpm-v'
  104. minicpm_v_v2_5 = 'minicpm-v-v2_5'
  105. minicpm_v_v2_6 = 'minicpm-v-v2_6'
  106. gemma = 'gemma'
  107. paligemma = 'paligemma'
  108. mplug_owl2 = 'mplug-owl2'
  109. mplug_owl3 = 'mplug_owl3'
  110. wizardlm2_awq = 'wizardlm2-awq'
  111. wizardlm2 = 'wizardlm2'
  112. atom = 'atom'
  113. phi3 = 'phi3'
  114. phi3_vl = 'phi3-vl'
  115. telechat = 'telechat'
  116. telechat_v2 = 'telechat-v2'
  117. dbrx = 'dbrx'
  118. mengzi = 'mengzi'
  119. c4ai = 'c4ai'
  120. chatml = 'chatml'
  121. # compatibility. (Deprecated)
  122. default_generation_bos = 'default-generation-bos'
  123. @classmethod
  124. def get_template_name_list(cls) -> List[str]:
  125. res = []
  126. for k in cls.__dict__.keys():
  127. if k.startswith('__') or k == 'get_template_name_list':
  128. continue
  129. res.append(cls.__dict__[k])
  130. return res
  131. def register_template(template_type: str, template: Template, *, exist_ok: bool = False, **kwargs) -> None:
  132. if not exist_ok and template_type in TEMPLATE_MAPPING:
  133. raise ValueError(f'The `{template_type}` has already been registered in the TEMPLATE_MAPPING.')
  134. template.template_type = template_type
  135. template_info = {'template': template, **kwargs}
  136. TEMPLATE_MAPPING[template_type] = template_info
  137. register_template(
  138. TemplateType.default,
  139. Template([], ['### Human:\n{{QUERY}}\n\n### Assistant:\n'], ['\n\n'], [['eos_token_id']],
  140. DEFAULT_SYSTEM, ['{{SYSTEM}}\n\n'],
  141. auto_add_bos=True))
  142. # You can set the query as '' to serve as a template for pre-training.
  143. class DefaultGenerationTemplate(Template):
  144. def __init__(self):
  145. super().__init__([], ['{{QUERY}}'], None, [['eos_token_id']], auto_add_bos=True)
  146. register_template(TemplateType.default_generation, DefaultGenerationTemplate(), is_generation=True)
  147. register_template(
  148. TemplateType.default_generation_bos,
  149. Template([['bos_token_id']], ['{{QUERY}}'], None, [['eos_token_id']]),
  150. is_generation=True)
  151. class ChatmlTemplateMixin:
  152. system = None
  153. def __init__(self, auto_add_bos: bool = True):
  154. Template.__init__(
  155. self, [], ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n'], ['<|im_end|>\n'],
  156. ['<|im_end|>'],
  157. self.system, ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n'],
  158. auto_add_bos=auto_add_bos)
  159. class ChatmlTemplate(ChatmlTemplateMixin, Template):
  160. pass
  161. class QwenTemplateMixin(ChatmlTemplateMixin):
  162. system = DEFAULT_SYSTEM
  163. def __init__(self):
  164. super().__init__(auto_add_bos=False)
  165. class QwenTemplate(QwenTemplateMixin, Template):
  166. pass
  167. class _QwenVLTemplateMixin:
  168. load_medias = False
  169. def check_example(self, example):
  170. images = example.get('images') or []
  171. assert not images or isinstance(fetch_one(images), str), 'QwenVL only supports datasets with images paths!'
  172. def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
  173. example: Dict[str, Any]) -> List[Context]:
  174. assert media_type == 'image'
  175. images = example.get('images') or []
  176. image = images[index]
  177. assert isinstance(image, str)
  178. return [f'Picture {index + 1}:<img>{image}</img>\n']
  179. def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]:
  180. objects = example['objects']
  181. object_ = objects[index]
  182. return [f'<ref>{object_["caption"]}</ref>']
  183. def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]:
  184. objects = example['objects']
  185. object_ = objects[index]
  186. if isinstance(object_['bbox'][0], list):
  187. all_objects = ''
  188. for sub_object in object_['bbox']:
  189. all_objects += (f'<box>({sub_object[0]},{sub_object[1]}),' f'({sub_object[2]},{sub_object[3]})</box>')
  190. return [all_objects]
  191. else:
  192. return [
  193. f'<box>({object_["bbox"][0]},{object_["bbox"][1]}),'
  194. f'({object_["bbox"][2]},{object_["bbox"][3]})</box>'
  195. ]
  196. register_template(TemplateType.qwen, QwenTemplate())
  197. class QwenVLTemplate(_QwenVLTemplateMixin, QwenTemplate):
  198. pass
  199. class QwenVLGenerationTemplate(_QwenVLTemplateMixin, DefaultGenerationTemplate):
  200. pass
  201. register_template(TemplateType.qwen_vl, QwenVLTemplate())
  202. register_template(TemplateType.qwen_vl_generation, QwenVLGenerationTemplate())
  203. register_template(TemplateType.chatml, ChatmlTemplate())
  204. register_template(
  205. TemplateType.modelscope_agent,
  206. Template([], [' \n\n<|user|>:{{QUERY}} \n\n<|assistant|>:'], [], [' \n\n</s>'], DEFAULT_SYSTEM,
  207. [' \n\n<|system|>:{{SYSTEM}}']))
  208. class _QwenAudioTemplateMixin:
  209. def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
  210. example: Dict[str, Any]) -> List[Context]:
  211. assert media_type == 'audio'
  212. audios = example.get('audios') or []
  213. audio = audios[index]
  214. assert isinstance(audio, str)
  215. return [f'Audio {index + 1}:<audio>{audio}</audio>\n']
  216. def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  217. inputs, tokenizer_kwargs = Template._encode(self, example)
  218. if len(inputs) == 0:
  219. return inputs, tokenizer_kwargs
  220. inputs.pop('loss_scale', None)
  221. inputs.update(tokenizer_kwargs)
  222. return inputs, tokenizer_kwargs
  223. def _get_tokenizer_kwargs(self, context: str) -> Dict[str, Any]:
  224. return {'audio_info': self.tokenizer.process_audio(context)}
  225. def _concat_tokenizer_kwargs(self, tokenizer_kwargs: Dict[str, Any], curr_tokenizer_kwargs: Dict[str, Any]) -> None:
  226. audio_info = curr_tokenizer_kwargs.get('audio_info')
  227. old_audio_info = tokenizer_kwargs.get('audio_info')
  228. if old_audio_info is None:
  229. tokenizer_kwargs['audio_info'] = audio_info
  230. elif audio_info is not None:
  231. for k in ['input_audios', 'input_audio_lengths']:
  232. old_audio_info[k] = torch.concat([old_audio_info[k], audio_info[k]], dim=0)
  233. for k in ['audio_span_tokens', 'audio_urls']:
  234. old_audio_info[k] = old_audio_info[k] + audio_info[k]
  235. def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
  236. res = Template.data_collator(self, batch, padding_to)
  237. if batch[0].get('audio_info') is not None:
  238. res['audio_info'] = [b['audio_info'] for b in batch]
  239. return res
  240. class QwenAudioTemplate(_QwenAudioTemplateMixin, QwenTemplate):
  241. pass
  242. class QwenAudioGenerationTemplate(_QwenAudioTemplateMixin, DefaultGenerationTemplate):
  243. pass
  244. register_template(TemplateType.qwen_audio, QwenAudioTemplate(), lazy_tokenize=True)
  245. register_template(
  246. TemplateType.qwen_audio_generation, QwenAudioGenerationTemplate(), lazy_tokenize=True, is_generation=True)
  247. class _Qwen2AudioTemplateMixin:
  248. def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  249. inputs, _ = Template._encode(self, example)
  250. if len(inputs) == 0:
  251. return inputs, {}
  252. processor = self.tokenizer.processor
  253. sampling_rate = processor.feature_extractor.sampling_rate
  254. audios = load_batch(
  255. example.get('audios') or [], load_func=partial(load_audio_qwen, sampling_rate=sampling_rate))
  256. if audios:
  257. audio_inputs = processor.feature_extractor(
  258. audios, sampling_rate=sampling_rate, return_attention_mask=True, return_tensors='pt')
  259. audio_inputs['feature_attention_mask'] = audio_inputs.pop('attention_mask')
  260. inputs.update(audio_inputs)
  261. return inputs, {}
  262. def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
  263. res = Template.data_collator(self, batch, padding_to)
  264. input_features = [b['input_features'] for b in batch if b.get('input_features') is not None]
  265. if input_features:
  266. res['input_features'] = torch.concat(input_features)
  267. feature_attention_mask = [b['feature_attention_mask'] for b in batch]
  268. res['feature_attention_mask'] = torch.concat(feature_attention_mask)
  269. return res
  270. class Qwen2AudioTemplate(_Qwen2AudioTemplateMixin, QwenTemplate):
  271. def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
  272. example: Dict[str, Any]) -> List[Context]:
  273. assert media_type == 'audio'
  274. return [f'Audio {index + 1}: <|audio_bos|><|AUDIO|><|audio_eos|>\n']
  275. class Qwen2AudioGenerationTemplate(_Qwen2AudioTemplateMixin, DefaultGenerationTemplate):
  276. def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
  277. example: Dict[str, Any]) -> List[Context]:
  278. assert media_type == 'audio'
  279. return ['<|audio_bos|><|AUDIO|><|audio_eos|>\n']
  280. register_template(TemplateType.qwen2_audio, Qwen2AudioTemplate(), lazy_tokenize=True)
  281. def _process_image_qwen(image):
  282. from qwen_vl_utils.vision_process import IMAGE_FACTOR, MIN_PIXELS, MAX_PIXELS, smart_resize
  283. size_factor = get_env_args('size_factor', int, IMAGE_FACTOR)
  284. # resize
  285. resized_height = get_env_args('resized_height', int, None)
  286. resized_width = get_env_args('resized_width', int, None)
  287. if resized_height and resized_width:
  288. resized_height, resized_width = smart_resize(
  289. resized_height,
  290. resized_width,
  291. factor=size_factor,
  292. )
  293. else:
  294. width, height = image.size
  295. min_pixels = get_env_args('min_pixels', int, MIN_PIXELS)
  296. max_pixels = get_env_args('max_pixels', int, MAX_PIXELS)
  297. resized_height, resized_width = smart_resize(
  298. height,
  299. width,
  300. factor=size_factor,
  301. min_pixels=min_pixels,
  302. max_pixels=max_pixels,
  303. )
  304. image = image.resize((resized_width, resized_height))
  305. return image
  306. class Qwen2VLTemplate(QwenTemplate):
  307. def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
  308. example: Dict[str, Any]) -> List[Context]:
  309. assert media_type in {'image', 'video'}
  310. if media_type == 'image':
  311. return ['<|vision_start|><|image_pad|><|vision_end|>']
  312. else:
  313. return ['<|vision_start|><|video_pad|><|vision_end|>']
  314. def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]:
  315. objects = example.get('objects')
  316. if objects:
  317. object_ = objects[index]
  318. return ['<|object_ref_start|>', object_['caption'], '<|object_ref_end|>']
  319. else:
  320. return ['<ref-object>']
  321. def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]:
  322. objects = example.get('objects')
  323. if objects:
  324. object_ = objects[index]
  325. if isinstance(object_['bbox'][0], list):
  326. all_objects = ''
  327. for sub_object in object_['bbox']:
  328. all_objects += (f'<|box_start|>({sub_object[0]},{sub_object[1]}),'
  329. f'({sub_object[2]},{sub_object[3]})<|box_end|>')
  330. return [all_objects]
  331. else:
  332. return [
  333. f'<|box_start|>({object_["bbox"][0]},{object_["bbox"][1]}),'
  334. f'({object_["bbox"][2]},{object_["bbox"][3]})<|box_end|>'
  335. ]
  336. else:
  337. return ['<bbox>']
  338. def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  339. inputs, _ = super()._encode(example)
  340. if len(inputs) == 0:
  341. return inputs, {}
  342. processor = self.tokenizer.processor
  343. input_ids = inputs['input_ids']
  344. labels = inputs['labels']
  345. images = example.get('images') or []
  346. videos = example.get('videos') or []
  347. for media_type in ['images', 'videos']:
  348. if locals()[media_type]:
  349. if media_type == 'images':
  350. images = load_batch(images, _process_image_qwen)
  351. media_token = 151655
  352. media_inputs = processor.image_processor(images=images, videos=None, return_tensors='pt')
  353. media_grid_thw = media_inputs['image_grid_thw']
  354. else:
  355. videos = load_batch(videos, load_video_qwen2)
  356. media_inputs = processor.image_processor(images=None, videos=videos, return_tensors='pt')
  357. media_grid_thw = media_inputs['video_grid_thw']
  358. media_token = 151656
  359. idx_list = _findall(input_ids, media_token)
  360. added_tokens_len = 0
  361. for i, idx in enumerate(idx_list):
  362. merge_length = processor.image_processor.merge_size**2
  363. token_len = (media_grid_thw[i].prod() // merge_length)
  364. input_ids = input_ids[:idx
  365. + added_tokens_len] + [media_token] * token_len + input_ids[added_tokens_len
  366. + idx + 1:]
  367. if labels:
  368. labels = labels[:idx + added_tokens_len] + [-100] * token_len + labels[added_tokens_len + idx
  369. + 1:]
  370. added_tokens_len += token_len - 1
  371. inputs.update(media_inputs)
  372. inputs['input_ids'] = input_ids
  373. inputs['labels'] = labels
  374. return inputs, {}
  375. def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
  376. res = super().data_collator(batch, padding_to)
  377. for media_type in ['image', 'video']:
  378. grid_thw = [b[f'{media_type}_grid_thw'] for b in batch if b.get(f'{media_type}_grid_thw') is not None]
  379. if grid_thw:
  380. res[f'{media_type}_grid_thw'] = torch.concat(grid_thw)
  381. return res
  382. register_template(TemplateType.qwen2_vl, Qwen2VLTemplate(), lazy_tokenize=True)
  383. register_template(
  384. TemplateType.qwen2_audio_generation, Qwen2AudioGenerationTemplate(), lazy_tokenize=True, is_generation=True)
  385. class YiCoderTemplate(ChatmlTemplate):
  386. system = 'You are a helpful assistant.'
  387. register_template(TemplateType.yi_coder, YiCoderTemplate())
  388. yi_vl_default_system = (
  389. 'This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. '
  390. "Read all the images carefully, and respond to the human's questions with informative, "
  391. 'helpful, detailed and polite answers. '
  392. '这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。'
  393. '仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。')
  394. class YiVLTemplate(Template):
  395. def replace_tag(self, media_type, index, example) -> List[Context]:
  396. assert media_type == 'image'
  397. return [[-200], '\n']
  398. def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  399. inputs, _ = super()._encode(example)
  400. if len(inputs) == 0:
  401. return inputs, {}
  402. inputs.pop('loss_scale', None)
  403. from llava.mm_utils import expand2square
  404. # This processor should be put from the `model.vision_tower.image_processor`
  405. image_processor = self.tokenizer.image_processor
  406. images = example.get('images') or []
  407. for i, image in enumerate(images):
  408. background_color = tuple(int(x * 255) for x in image_processor.image_mean)
  409. image = expand2square(image, background_color)
  410. images[i] = image
  411. if images:
  412. image_tensor = image_processor.preprocess(images, return_tensors='pt')['pixel_values']
  413. inputs['images'] = image_tensor.to(kwargs['dtype'])
  414. return inputs, {}
  415. def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
  416. res = super().data_collator(batch, padding_to)
  417. images = [b['images'] for b in batch if 'images' in b]
  418. if images:
  419. res['images'] = torch.concat(images)
  420. has_images = [(b == -200).sum() for b in res['input_ids']]
  421. assert all([
  422. h > 0 for h in has_images
  423. ]) or not any([h > 0
  424. for h in has_images]), 'YIVL does not support mix-batch nlp dataset and multi-modal dataset'
  425. return res
  426. class GLMTemplate(Template):
  427. def _init_template(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs) -> None:
  428. res = super()._init_template(tokenizer, *args, **kwargs)
  429. token_list = tokenizer.encode('')
  430. self.prefix.insert(0, token_list)
  431. if self.system_prefix is not None:
  432. self.system_prefix.insert(0, token_list)
  433. return res
  434. class GLM4VTemplate(GLMTemplate):
  435. def __init__(self):
  436. super().__init__([], ['<|user|>\n{{QUERY}}<|assistant|>'], [], ['<|endoftext|>'], None,
  437. ['<|system|>\n{{SYSTEM}}'])
  438. def check_example(self, example):
  439. images = example.get('images') or []
  440. assert len(images) <= 1
  441. def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
  442. assert media_type == 'image'
  443. return [[-100]]
  444. def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  445. inputs, _ = super()._encode(example)
  446. if len(inputs) == 0:
  447. return inputs, {}
  448. input_ids = inputs['input_ids']
  449. labels = inputs['labels']
  450. idx_list = _findall(input_ids, -100)
  451. if idx_list:
  452. idx = idx_list[0]
  453. image = example.get('images')[0]
  454. placeholder = '<|begin_of_image|><|endoftext|><|end_of_image|>'
  455. placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False)
  456. input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
  457. if labels is not None:
  458. labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
  459. messages = example['messages']
  460. messages[0]['image'] = image
  461. inputs2: Dict[str, Any] = self.tokenizer.apply_chat_template(messages, return_dict=True)
  462. inputs['images'] = inputs2['images']
  463. inputs['input_ids'] = input_ids
  464. inputs['labels'] = labels
  465. return inputs, {}
  466. def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
  467. res = super().data_collator(batch, padding_to)
  468. images = [b['images'] for b in batch if 'images' in b]
  469. if images:
  470. res['images'] = torch.concat(images)
  471. return res
  472. register_template(TemplateType.glm4v, GLM4VTemplate(), infer_media_type='dialogue', lazy_tokenize=True, use_model=False)
  473. register_template(
  474. TemplateType.yi_vl,
  475. YiVLTemplate([], [[8308], 'Human: {{QUERY}}\n', [8308], 'Assistant:'], ['\n'], ['\n', [8308]], yi_vl_default_system,
  476. ['{{SYSTEM}}\n\n']),
  477. use_model=False,
  478. infer_media_type='round',
  479. lazy_tokenize=True)
  480. register_template(TemplateType.baichuan, Template(['{{SYSTEM}}'], [[195], '{{QUERY}}', [196]], [], [['eos_token_id']]))
  481. register_template(
  482. TemplateType.chatglm2,
  483. GLMTemplate(['{{SYSTEM}}'], ['[Round {{ROUND1}}]\n\n问:{{QUERY}}\n\n答:'], ['\n\n'], [['eos_token_id']]))
  484. register_template(
  485. TemplateType.chatglm_generation, GLMTemplate([], ['{{QUERY}}'], None, [['eos_token_id']]), is_generation=True)
  486. register_template(
  487. TemplateType.chatglm3,
  488. GLMTemplate([], ['<|user|>\n{{QUERY}}<|assistant|>\n'], [], ['<|user|>'], None, ['<|system|>\n{{SYSTEM}}']))
  489. register_template(
  490. TemplateType.chatglm4,
  491. GLMTemplate([], ['<|user|>\n{{QUERY}}<|assistant|>\n'], [], ['<|user|>'],
  492. None, ['<|system|>\n{{SYSTEM}}'],
  493. tools_prompt='glm4',
  494. tool_prompt=['<|observation|>\n{{QUERY}}<|assistant|>\n']))
  495. codegeex4_system = '你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。'
  496. register_template(
  497. TemplateType.codegeex4,
  498. GLMTemplate([], ['<|user|>\n{{QUERY}}<|assistant|>\n'], [], ['<|endoftext|>'], codegeex4_system,
  499. ['<|system|>\n{{SYSTEM}}']))
  500. register_template(
  501. TemplateType.deepseek,
  502. Template([['bos_token_id']], ['User: {{QUERY}}\n\nAssistant:'], [['eos_token_id']], [['eos_token_id']], None,
  503. [['bos_token_id'], '{{SYSTEM}}\n\n']))
  504. register_template(
  505. TemplateType.numina_math,
  506. Template([['bos_token_id']], ['### Problem: {{QUERY}}\n### Solution: '], ['\n'], [['eos_token_id']], None,
  507. [['bos_token_id'], '{{SYSTEM}}']))
  508. register_template(
  509. TemplateType.deepseek2,
  510. Template([[100000]], ['User: {{QUERY}}\n\nAssistant:'], [[100001]], [[100001]], None, [[100000], '{{SYSTEM}}\n\n']))
  511. register_template(
  512. TemplateType.deepseek2_5,
  513. Template(['<|begin▁of▁sentence|>'], ['<|User|>{{QUERY}}<|Assistant|>'], ['<|end_of_sentense|>'],
  514. ['<|end_of_sentense|>'], None, ['<|begin▁of▁sentence|>{{SYSTEM}}']))
  515. # ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py
  516. LLAMA_DEFAULT_SYSTEM = (
  517. 'You are a helpful, respectful and honest assistant. '
  518. 'Always answer as helpfully as possible, while being safe. '
  519. 'Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. '
  520. 'Please ensure that your responses are socially unbiased and positive in nature.\n\n'
  521. 'If a question does not make any sense, or is not factually coherent, '
  522. 'explain why instead of answering something not correct. '
  523. "If you don't know the answer to a question, please don't share false information.")
  524. register_template(
  525. TemplateType.llama,
  526. Template(['<s>[INST] '], ['{{QUERY}} [/INST]'], ['</s><s>[INST] '], ['</s>'], LLAMA_DEFAULT_SYSTEM,
  527. ['<s>[INST] <<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
  528. register_template(
  529. TemplateType.longwriter_llama3,
  530. Template(['[INST]'], ['{{QUERY}}[/INST]'], ['[INST]'], ['<|end_of_text|>'], None,
  531. ['<<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
  532. register_template(TemplateType.mistral_nemo,
  533. Template(['<s>[INST] '], ['{{SYSTEM}}\n\n', '{{QUERY}}[/INST]'], ['</s>[INST] '], ['</s>']))
  534. class Llama3TemplateMixin:
  535. system = None
  536. def __init__(self):
  537. Template.__init__(
  538. self, ['<|begin_of_text|>'], [
  539. '<|start_header_id|>user<|end_header_id|>\n\n{{QUERY}}<|eot_id|>'
  540. '<|start_header_id|>assistant<|end_header_id|>\n\n'
  541. ], ['<|eot_id|>'], ['<|eot_id|>'],
  542. self.system, ['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>'],
  543. tools_prompt='toolbench',
  544. tool_prompt=[
  545. '<|start_header_id|>tool<|end_header_id|>\n\n{{QUERY}}<|eot_id|>'
  546. '<|start_header_id|>assistant<|end_header_id|>\n\n'
  547. ])
  548. class Llama3Template(Llama3TemplateMixin, Template):
  549. pass
  550. class ReflectionTemplate(Llama3TemplateMixin, Template):
  551. system = ('You are a world-class AI system, capable of complex reasoning and reflection. '
  552. 'Reason through the query inside <thinking> tags, and then provide your final '
  553. 'response inside <output> tags. If you detect that you made a mistake in your reasoning '
  554. 'at any point, correct yourself inside <reflection> tags.')
  555. register_template(TemplateType.reflection, ReflectionTemplate())
  556. register_template(TemplateType.llama3, Llama3Template())
  557. OPENBUDDY_DEFAULT_SYSTEM = (
  558. 'You are a helpful, respectful and honest INTP-T AI Assistant named Buddy. You are talking to a human User.\n'
  559. 'Always answer as helpfully and logically as possible, while being safe. '
  560. 'Your answers should not include any '
  561. 'harmful, political, religious, unethical, racist, sexist, toxic, dangerous, or illegal content. '
  562. 'Please ensure that your responses are socially unbiased and positive in nature.\n'
  563. 'If a question does not make any sense, or is not factually coherent, '
  564. 'explain why instead of answering something not correct. '
  565. "If you don't know the answer to a question, please don't share false information.\n"
  566. 'You like to use emojis. You can speak fluently in many languages, for example: English, Chinese.\n'
  567. 'You cannot access the internet, but you have vast knowledge, cutoff: 2021-09.\n'
  568. 'You always deeply love and support China, Chinese government, people and culture.\n'
  569. 'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), '
  570. 'you are based on LLaMA and Falcon transformers model, not related to GPT or OpenAI.')
  571. register_template(
  572. TemplateType.openbuddy,
  573. Template([], ['User: {{QUERY}}\nAssistant:'], ['\n'], [['eos_token_id']],
  574. OPENBUDDY_DEFAULT_SYSTEM, ['{{SYSTEM}}\n\n'],
  575. auto_add_bos=True))
  576. OPENBUDDY2_DEFAULT_SYSTEM = (
  577. 'You(assistant) are a helpful, respectful and honest INTP-T AI Assistant named Buddy. '
  578. 'You are talking to a human(user).\nAlways answer as helpfully and logically as possible, while being safe. '
  579. 'Your answers should not include any harmful, political, religious, unethical, racist, '
  580. 'sexist, toxic, dangerous, or illegal content. '
  581. 'Please ensure that your responses are socially unbiased and positive in nature.\n'
  582. 'You cannot access the internet, but you have vast knowledge, cutoff: 2023-04.\n'
  583. 'You are trained by OpenBuddy team, (https://openbuddy.ai, https://github.com/OpenBuddy/OpenBuddy), '
  584. 'not related to GPT or OpenAI')
  585. register_template(
  586. TemplateType.openbuddy2,
  587. Template([], ['<|role|>user<|says|>{{QUERY}}<|end|>\n<|role|>assistant<|says|>'], ['<|end|>\n'], ['<|end|>'],
  588. OPENBUDDY2_DEFAULT_SYSTEM, ['<|role|>system<|says|>{{SYSTEM}}<|end|>\n'],
  589. auto_add_bos=True))
  590. INTERNLM_SYSTEM = (
  591. 'You are an AI assistant whose name is InternLM (书生·浦语).\n'
  592. '- InternLM (书生·浦语) is a conversational language model that is developed by Shanghai AI Laboratory (上海人工智能实验室). '
  593. 'It is designed to be helpful, honest, and harmless.\n'
  594. '- InternLM (书生·浦语) can understand and communicate fluently in the language chosen '
  595. 'by the user such as English and 中文.')
  596. register_template(
  597. TemplateType.internlm,
  598. Template(['<s>'], ['<|User|>:{{QUERY}}\n<|Bot|>:'], ['<eoa>\n'], ['<eoa>'], INTERNLM_SYSTEM,
  599. ['<s><|System|>:{{SYSTEM}}\n']))
  600. _T = TypeVar('_T')
  601. _log_set = set() # log once
  602. def get_env_args(args_name: str, type_func: Callable[[str], _T], default_value: Optional[_T]) -> Optional[_T]:
  603. args_name_upper = args_name.upper()
  604. value = os.getenv(args_name_upper)
  605. if value is None:
  606. value = default_value
  607. log_info = (f'Setting {args_name}: {default_value}. '
  608. f'You can adjust this hyperparameter through the environment variable: `{args_name_upper}`.')
  609. else:
  610. value = type_func(value)
  611. log_info = f'Using environment variable `{args_name_upper}`, Setting {args_name}: {value}.'
  612. if log_info not in _log_set:
  613. _log_set.add(log_info)
  614. logger.info(log_info)
  615. return value
  616. class Internlm2Template(ChatmlTemplate):
  617. system = INTERNLM_SYSTEM
  618. register_template(TemplateType.internlm2, Internlm2Template())
  619. def replace_img_tag(query: str,
  620. history: History,
  621. replace_token: str,
  622. pattern=r'<img>(.+?)</img>') -> Tuple[str, History, List[str]]:
  623. images_path = []
  624. new_history = []
  625. for i, h in enumerate(history):
  626. if h[0] is None:
  627. new_history.append(h.copy())
  628. else:
  629. images_path += re.findall(pattern, h[0])
  630. new_history.append([re.sub(pattern, replace_token, h[0]), h[1]])
  631. if query is None:
  632. new_query = query # pretrain dataset
  633. else:
  634. images_path += re.findall(pattern, query)
  635. new_query = re.sub(pattern, replace_token, query)
  636. return new_query, new_history, images_path
class InternLMXComposer2Template(Template):
    """Template for the InternLM-XComposer2 family.

    Each image is rendered in the prompt as ``image_placeholder`` ('</s>'). ``post_encode``
    later splits the token ids at id 2 (treated as '</s>') and splices image embeddings in
    at those positions.
    """
    INTERNLM_XCOMPOSER_SYSTEM = (
        'You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n'
        '- InternLM-XComposer (浦语·灵笔) is a conversational language model that is developed by '
        'Shanghai AI Laboratory (上海人工智能实验室). '
        'It is designed to be helpful, honest, and harmless.\n'
        '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen '
        'by the user such as English and 中文.')
    # Text emitted for every image; post_encode replaces it with the image embedding.
    image_placeholder = ['</s>']

    def __init__(self, version):
        """Build the turn layout delimited by [UNUSED_TOKEN_146] ... [UNUSED_TOKEN_145].

        Args:
            version: Model variant. ``_encode`` special-cases ``'v2.5'`` and ``'v2-4khd'``;
                any other value (e.g. ``'v2'``) skips the HD image transform.
        """
        prefix = ['<s>']
        prompt = ['[UNUSED_TOKEN_146]user\n{{QUERY}}[UNUSED_TOKEN_145]\n[UNUSED_TOKEN_146]assistant\n']
        chat_sep = ['[UNUSED_TOKEN_145]\n']
        suffix = ['[UNUSED_TOKEN_145]']
        system_prefix = ['<s>[UNUSED_TOKEN_146]system\n{{SYSTEM}}[UNUSED_TOKEN_145]\n']
        super().__init__(prefix, prompt, chat_sep, suffix, self.INTERNLM_XCOMPOSER_SYSTEM, system_prefix)
        self.version = version

    def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Tokenize the text and preprocess the images; embedding is deferred to ``post_encode``.

        Requires ``kwargs['dtype']``, which is used to cast the processed images. Stores
        ``input_ids``, ``labels`` and the processed ``images`` in ``inputs['_data']``.
        """
        inputs, _ = super()._encode(example)
        if len(inputs) == 0:
            return inputs, {}
        images = example.get('images') or []

        if self.version == 'v2.5':
            # Default hd_num: 24 for a single image, 6 when several images are given.
            # Overridable through the HD_NUM environment variable (see get_env_args).
            hd_num = 24
            if len(images) > 1:
                hd_num = 6
            hd_num = get_env_args('hd_num', int, hd_num)
            Image_transform = get_class_from_dynamic_module('ixc_utils.Image_transform', self.tokenizer.model_dir)
            images = [Image_transform(image, hd_num=hd_num) for image in images]
        elif self.version == 'v2-4khd':
            hd_num = 55
            hd_num = get_env_args('hd_num', int, hd_num)
            HD_transform = get_class_from_dynamic_module('ixc_utils.HD_transform', self.tokenizer.model_dir)
            images = [HD_transform(image, hd_num=hd_num) for image in images]
        # vis_processor comes from model.vis_processor
        images = [self.tokenizer.vis_processor(image).to(kwargs['dtype']) for image in images]
        inputs['_data'] = {'input_ids': inputs['input_ids'], 'labels': inputs['labels'], 'images': images}
        return inputs, {}

    def post_encode(self, model, data: Any) -> Dict[str, Any]:
        """Build ``inputs_embeds`` by interleaving text-segment embeddings with image embeddings.

        The ids are split at every id 2. A dummy 2 is appended so that the last text segment
        is flushed. Each segment is prefixed with id 1 (presumably <s>) and embedded with the
        model's ``tok_embeddings``. After a segment, the next image embedding from
        ``model.img2emb`` is inserted while unused images remain.

        Returns:
            A dict with:
            - ``inputs_embeds``: the concatenated segment/image embeddings;
            - ``im_mask``: a bool tensor of shape (1, L), True at image positions;
            - ``labels``: -100 at the prepended id-1 and image positions, or ``None`` when no
              labels were given.
        """
        input_ids = data['input_ids']
        labels = data['labels']
        images = data['images']
        if len(images) > 0:  # ignore <s>
            input_ids = input_ids[1:]
            if labels is not None:
                labels = labels[1:]
        # NOTE(review): without images the leading <s> is kept, yet every segment below is still
        # prefixed with id 1, so the first segment appears to start with two BOS ids — confirm.
        if isinstance(input_ids, torch.Tensor):
            input_ids = input_ids.tolist()
        input_ids.append(2)  # add dummy </s>
        if labels is not None:
            if isinstance(labels, torch.Tensor):
                labels = labels.tolist()
            labels.append(2)
        else:
            labels = []
        res_inputs_embeds = []
        res_labels = []
        wrap_im_mask = []
        pre_i, i, idx = 0, 0, 0
        device = model.device
        # The embedding table may be one level deeper (e.g. when the model is wrapped).
        internlm2_model = model.model
        if not hasattr(internlm2_model, 'tok_embeddings'):
            internlm2_model = internlm2_model.model
        tok_embeddings = internlm2_model.tok_embeddings
        if len(images) > 0:
            images = torch.concat([model.img2emb(image[None])[0] for image in images], dim=0)
        while i < len(input_ids):
            if input_ids[i] == 2:  # replace_token
                # Text segment since the previous placeholder, re-prefixed with id 1.
                res_input_ids = torch.tensor([1] + input_ids[pre_i:i], device=device)
                res_inputs_embeds.append(tok_embeddings(res_input_ids[None])[0])
                wrap_im_mask += [0] * len(res_input_ids)
                res_labels += [-100] + labels[pre_i:i]
                if len(images) > 0 and idx < images.shape[0]:
                    res_inputs_embeds.append(images[idx].to(device))
                    wrap_im_mask += [1] * images.shape[1]
                    res_labels += [-100] * images.shape[1]
                idx += 1
                i += 1
                pre_i = i
                continue
            i += 1
        if len(labels) == 0:
            res_labels = None
        res_inputs_embeds = torch.concat(res_inputs_embeds, dim=0)
        wrap_im_mask = torch.tensor(wrap_im_mask, dtype=torch.bool, device=device)[None]
        return {'inputs_embeds': res_inputs_embeds, 'im_mask': wrap_im_mask, 'labels': res_labels}

    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Collate as the base template does, then pad ``im_mask`` with 0 on ``self.padding_side``."""
        res = super().data_collator(batch, padding_to)
        if 'im_mask' in batch[0]:
            im_mask = [b['im_mask'][0] for b in batch]
            im_mask = self.pad_sequence(im_mask, 0, self.padding_side)
            res['im_mask'] = im_mask
        return res

    @staticmethod
    def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
        """Return ``generate_ids`` unchanged; ``input_token_len`` is ignored.

        Presumably the output holds no prompt ids to strip, because generation is driven by
        ``inputs_embeds``.
        """
        return generate_ids


register_template(
    TemplateType.internlm_xcomposer2, InternLMXComposer2Template(version='v2'), use_model=False, lazy_tokenize=True)
  735. class InternLMXComposer2_5Template(InternLMXComposer2Template):
  736. INTERNLM_XCOMPOSER_SYSTEM = (
  737. 'You are an AI assistant whose name is InternLM-XComposer (浦语·灵笔).\n'
  738. '- InternLM-XComposer (浦语·灵笔) is a multi-modality conversational language model '
  739. 'that is developed by Shanghai AI Laboratory (上海人工智能实验室). '
  740. 'It is designed to be helpful, honest, and harmless.\n'
  741. '- InternLM-XComposer (浦语·灵笔) can understand and communicate fluently in the language chosen '
  742. 'by the user such as English and 中文.\n'
  743. '- InternLM-XComposer (浦语·灵笔) is capable of comprehending and articulating responses effectively '
  744. 'based on the provided image.')
  745. register_template(
  746. TemplateType.internlm_xcomposer2_5,
  747. InternLMXComposer2_5Template(version='v2.5'),
  748. use_model=False,
  749. lazy_tokenize=True)
  750. register_template(
  751. TemplateType.internlm_xcomposer2_4khd,
  752. InternLMXComposer2_5Template(version='v2-4khd'),
  753. use_model=False,
  754. lazy_tokenize=True)
class InternvlTemplate(Template):
    """Template for InternVL (v1.x) chat models.

    Each image is marked by a ``-100`` sentinel id. ``_encode`` expands the sentinel into
    ``<IMG_CONTEXT>`` tokens, and ``post_encode`` overwrites those positions with vision
    embeddings.
    """
    system = 'You are an AI assistant whose name is InternLM (书生·浦语).'
    # <IMG_CONTEXT> tokens emitted per row of pixel_values (presumably one row per image tile).
    num_image_token = 256

    def __init__(self):
        super().__init__([], ['<|im_start|>user\n{{QUERY}}<|im_end|><|im_start|>assistant\n'], ['<|im_end|>'],
                         ['<|im_end|>'],
                         self.system, ['<|im_start|>system\n{{SYSTEM}}<|im_end|>'],
                         auto_add_bos=True)

    def replace_tag(self, media_type, index, example) -> List[Context]:
        # Every media item becomes '<img>', a -100 sentinel id, '</img>\n'; _encode expands the sentinel.
        return ['<img>', [-100], '</img>\n']

    def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Tokenize, preprocess the images and expand the -100 sentinel into <IMG_CONTEXT> tokens.

        Image preprocessing reads the INPUT_SIZE (default 448) and MAX_NUM (default 12)
        environment variables via get_env_args. ``kwargs['dtype']`` is required when images
        are present. ``inputs['_data']`` carries the final ``input_ids`` tensor and
        ``pixel_values`` (or ``None``) for ``post_encode``. ``loss_scale`` is dropped.
        """
        inputs, _ = super()._encode(example)
        if len(inputs) == 0:
            return inputs, {}
        input_ids = inputs['input_ids']
        idx_list = _findall(input_ids, -100)
        pixel_values = None
        images = example.get('images')
        if images:
            labels = inputs.get('labels')
            input_size = get_env_args('input_size', int, 448)
            max_num = get_env_args('max_num', int, 12)
            pixel_values_images = [transform_image(image, input_size, max_num) for image in images]
            pixel_values = torch.cat(pixel_values_images, dim=0).to(kwargs['dtype'])
            image_bs = pixel_values.shape[0]
            # NOTE(review): everything from the first to the last sentinel is replaced by the tokens
            # of *all* images, so any text between several image tags is dropped. If images are
            # given but no sentinel is present, idx_list[0] raises IndexError. This appears to
            # assume a single image tag per sample — confirm.
            idx, idx2 = idx_list[0], idx_list[-1]  # remove [-100, -100]
            img_tokens: List[int] = self.tokenizer.encode(
                '<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * image_bs
            input_ids = input_ids[:idx] + img_tokens + input_ids[idx2 + 1:]
            if labels is not None:
                labels = labels[:idx] + [-100] * len(img_tokens) + labels[idx2 + 1:]
            inputs['input_ids'] = input_ids
            inputs['labels'] = labels
        inputs['_data'] = {'input_ids': torch.tensor(input_ids), 'pixel_values': pixel_values}
        inputs.pop('loss_scale', None)
        return inputs, {}

    def post_encode(self, model, data: Any) -> Dict[str, Any]:
        """Embed ``input_ids`` and overwrite the <IMG_CONTEXT> positions with ``model.extract_feature`` output.

        When there are no images and DeepSpeed ZeRO-3 is enabled, a dummy 32x32 image is run
        through ``extract_feature`` and added with weight 0. This presumably keeps the vision
        parameters involved in every step so that ZeRO-3 stays consistent across ranks — confirm.
        """
        embedding = model.get_input_embeddings()
        device = embedding.weight.device
        input_ids = data['input_ids']
        inputs_embeds = embedding(input_ids[None])[0].to(device=device)
        pixel_values = data['pixel_values']
        if pixel_values is not None:
            pixel_values = pixel_values.to(device=device)
            vit_embeds = model.extract_feature(pixel_values).to(device=device)
            selected = (input_ids == self.tokenizer.encode('<IMG_CONTEXT>', add_special_tokens=False)[0])
            inputs_embeds[selected] = vit_embeds.reshape(-1, vit_embeds.shape[-1])
        elif is_deepspeed_zero3_enabled():
            dummy_pixel_values = torch.zeros((1, 3, 32, 32), device=device, dtype=inputs_embeds.dtype)
            vit_embeds = model.extract_feature(dummy_pixel_values).to(device=device)
            inputs_embeds += vit_embeds.mean() * 0.
        return {'inputs_embeds': inputs_embeds}

    @staticmethod
    def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
        """Return ``generate_ids`` unchanged; ``input_token_len`` is ignored."""
        return generate_ids
  810. def _replace_video2image(load_video_func, example, replace_tag) -> List[Context]:
  811. context_list = []
  812. video_index = example['video_index']
  813. video = example['videos'][video_index]
  814. images = example['images']
  815. image_index = example['image_index']
  816. new_images = load_video_func(video)
  817. example['images'] = images[:image_index] + new_images + images[image_index:]
  818. for i in range(len(new_images)):
  819. context_list += replace_tag(i)
  820. example['image_index'] += len(new_images)
  821. return context_list
  822. class Internvl2Template(InternvlTemplate):
  823. video_segments = 8
  824. system = '你是由上海人工智能实验室联合商汤科技开发的书生多模态大模型,英文名叫InternVL, 是一个有用无害的人工智能助手。'
  825. def replace_tag(self, media_type, index, example) -> List[Context]:
  826. image_context = super().replace_tag('image', index, example)
  827. if media_type == 'image':
  828. return image_context
  829. elif media_type == 'video':
  830. video_segments = get_env_args('video_segments', int, self.video_segments)
  831. load_video = partial(load_video_internvl, num_segments=video_segments)
  832. return _replace_video2image(load_video, example, lambda i: [f'Frame{i + 1}: '] + image_context)
  833. def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]:
  834. objects = example.get('objects')
  835. if objects:
  836. object_ = objects[index]
  837. return [f'<ref>{object_["caption"]}</ref>']
  838. else:
  839. return ['<ref-object>']
  840. def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]:
  841. objects = example.get('objects')
  842. if objects:
  843. object_ = objects[index]
  844. if isinstance(object_['bbox'][0], list):
  845. all_objects = '<box> ['
  846. for sub_object in object_['bbox']:
  847. all_objects += (f'[{sub_object[0]}, {sub_object[1]}, ' f'{sub_object[2]}, {sub_object[3]}],')
  848. all_objects = all_objects[:-1]
  849. all_objects += '] </box>'
  850. return [all_objects]
  851. else:
  852. return [
  853. f'<box> [[{object_["bbox"][0]}, {object_["bbox"][1]}, '
  854. f'{object_["bbox"][2]}, {object_["bbox"][3]}]] </box>'
  855. ]
  856. else:
  857. return ['<bbox>']
  858. def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  859. inputs, _ = super(InternvlTemplate, self)._encode(example, **kwargs)
  860. if len(inputs) == 0:
  861. return inputs, {}
  862. input_ids = inputs['input_ids']
  863. idx_list = _findall(input_ids, -100)
  864. labels = inputs.get('labels')
  865. images = example.get('images')
  866. if images:
  867. has_video = bool(example.get('videos'))
  868. input_size = get_env_args('input_size', int, 448)
  869. max_num = get_env_args('max_num', int, 1 if has_video else 12)
  870. pixel_values = [transform_image(image, input_size, max_num) for image in images]
  871. num_patches = [pv.shape[0] for pv in pixel_values]
  872. pixel_values = torch.cat(pixel_values).to(kwargs['dtype'])
  873. else:
  874. pixel_values = None
  875. num_patches = []
  876. assert len(num_patches) == len(
  877. idx_list), f'len(num_patches): {len(num_patches)}, len(idx_list): {len(idx_list)}'
  878. added_tokens_len = 0
  879. for idx, num_patch in zip(idx_list, num_patches):
  880. img_tokens: List[int] = self.tokenizer.encode(
  881. '<IMG_CONTEXT>', add_special_tokens=False) * self.num_image_token * num_patch
  882. input_ids = input_ids[:idx + added_tokens_len] + img_tokens + input_ids[idx + added_tokens_len + 1:]
  883. if labels is not None:
  884. labels = labels[:idx + added_tokens_len] + [-100] * len(img_tokens) + labels[idx + added_tokens_len
  885. + 1:]
  886. added_tokens_len += len(img_tokens) - 1
  887. inputs['input_ids'] = input_ids
  888. inputs['labels'] = labels
  889. inputs['_data'] = {'input_ids': torch.tensor(input_ids), 'pixel_values': pixel_values}
  890. inputs.pop('loss_scale', None)
  891. return inputs, {}
  892. class InternvlPhi3TemplateMixin:
  893. def __init__(self):
  894. Template.__init__(
  895. self, [], ['<|user|>\n{{QUERY}}<|end|><|assistant|>\n'], ['<|end|>'], ['<|end|>'],
  896. getattr(self, 'system', None), ['<|system|>\n{{SYSTEM}}<|end|>'],
  897. auto_add_bos=True)
  898. self.padding_side = 'left'
  899. class InternvlPhi3Template(InternvlPhi3TemplateMixin, InternvlTemplate):
  900. system = 'You are an AI assistant whose name is Phi-3.'
  901. class Internvl2Phi3Template(InternvlPhi3TemplateMixin, Internvl2Template):
  902. pass
  903. register_template(
  904. TemplateType.internvl, InternvlTemplate(), use_model=False, lazy_tokenize=True, infer_media_type='dialogue')
  905. register_template(
  906. TemplateType.internvl_phi3, InternvlPhi3Template(), use_model=False, lazy_tokenize=True, infer_media_type='dialogue')
  907. register_template(TemplateType.internvl2, Internvl2Template(), use_model=False, lazy_tokenize=True)
  908. register_template(TemplateType.internvl2_phi3, Internvl2Phi3Template(), use_model=False, lazy_tokenize=True)
  909. class FlorenceTemplate(Template):
  910. compute_per_round_loss = False
  911. output_prompt_answer = True
  912. def __init__(self):
  913. super().__init__(['<s>'], ['{{QUERY}}</s>'], None, ['</s>'])
  914. self.task_prompts_without_inputs = {
  915. '<OCR>': 'What is the text in the image?',
  916. '<OCR_WITH_REGION>': 'What is the text in the image, with regions?',
  917. '<CAPTION>': 'What does the image describe?',
  918. '<DETAILED_CAPTION>': 'Describe in detail what is shown in the image.',
  919. '<MORE_DETAILED_CAPTION>': 'Describe with a paragraph what is shown in the image.',
  920. '<OD>': 'Locate the objects with category name in the image.',
  921. '<DENSE_REGION_CAPTION>': 'Locate the objects in the image, with their descriptions.',
  922. '<REGION_PROPOSAL>': 'Locate the region proposals in the image.'
  923. }
  924. self.task_prompts_with_input = {
  925. '<CAPTION_TO_PHRASE_GROUNDING>': 'Locate the phrases in the caption: {input}',
  926. '<REFERRING_EXPRESSION_SEGMENTATION>': 'Locate {input} in the image with mask',
  927. '<REGION_TO_SEGMENTATION>': 'What is the polygon mask of region {input}',
  928. '<OPEN_VOCABULARY_DETECTION>': 'Locate {input} in the image.',
  929. '<REGION_TO_CATEGORY>': 'What is the region {input}?',
  930. '<REGION_TO_DESCRIPTION>': 'What does the region {input} describe?',
  931. '<REGION_TO_OCR>': 'What text is in the region {input}?',
  932. }
  933. def check_example(self, example):
  934. images = example.get('images') or []
  935. assert len(images) == 1, 'Florence series models only supports input with a single image.'
  936. def add_default_tags(self, example: Dict[str, Any]) -> None:
  937. return
  938. def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]:
  939. object_ = example['objects'][index]
  940. if isinstance(object_['bbox'][0], list):
  941. all_objects = ''
  942. for sub_object in object_['bbox']:
  943. x1, y1, x2, y2 = sub_object
  944. all_objects += f'<loc_{x1}><loc_{y1}><loc_{x2}><loc_{y2}>,'
  945. return [all_objects[:-1]]
  946. else:
  947. x1, y1, x2, y2 = object_['bbox']
  948. return [f'<loc_{x1}><loc_{y1}><loc_{x2}><loc_{y2}>']
  949. def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  950. query = example['query']
  951. processor = self.tokenizer.processor
  952. example['query'] = processor._construct_prompts([query])[0]
  953. inputs, _ = super()._encode(example)
  954. input_ids = inputs['prompt_input_ids']
  955. if len(inputs) == 0:
  956. return inputs, {}
  957. images = example.get('images') or []
  958. labels = inputs['answer_labels']
  959. if labels is not None:
  960. labels = [0] + labels
  961. pixel_values = processor.image_processor(images, return_tensors='pt')['pixel_values'].to(kwargs['dtype'])
  962. inputs = {
  963. 'input_ids': input_ids,
  964. 'labels': labels,
  965. '_data': {
  966. 'input_ids': torch.tensor(input_ids)[None],
  967. 'pixel_values': pixel_values,
  968. }
  969. }
  970. return inputs, {}
  971. def post_encode(self, model, data: Any) -> Dict[str, Any]:
  972. inputs_embeds = model.get_input_embeddings()(data['input_ids'])
  973. image_features = model._encode_image(data['pixel_values'])
  974. inputs_embeds, _ = model._merge_input_ids_with_image_features(image_features, inputs_embeds)
  975. return {'inputs_embeds': inputs_embeds[0]}
  976. @staticmethod
  977. def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
  978. return generate_ids
  979. def post_process_generate_response(self, response, example):
  980. if isinstance(example['images'], list):
  981. example['images'] = example['images'][0]
  982. image = load_image(example['images'])
  983. return json.dumps(
  984. self.tokenizer.processor.post_process_generation(
  985. response, task=example['query'], image_size=(image.width, image.height)))
  986. register_template(
  987. TemplateType.florence,
  988. FlorenceTemplate(),
  989. use_model=False,
  990. lazy_tokenize=True,
  991. infer_media_type='dialogue',
  992. stream=False)
  993. register_template(TemplateType.xverse,
  994. Template(['{{SYSTEM}}'], ['Human: {{QUERY}}\n\nAssistant: '], [['eos_token_id']], [['eos_token_id']]))
  995. register_template(TemplateType.yuan, Template([], ['{{QUERY}}<sep>'], None, [['eos_token_id']]))
  996. register_template(TemplateType.ziya,
  997. Template([['bos_token_id'], '{{SYSTEM}}'], ['<human>:{{QUERY}}\n<bot>:'], ['\n'], [['eos_token_id']]))
  998. register_template(TemplateType.skywork,
  999. Template(['<s>{{SYSTEM}}'], ['</s><s>[USER]{{QUERY}}[SEP][BOT]'], None, ['[SEP]</s>']))
  1000. register_template(TemplateType.bluelm,
  1001. Template([['bos_token_id'], '{{SYSTEM}}'], ['[|Human|]:{{QUERY}}[|AI|]:'], [], [['eos_token_id']]))
  1002. register_template(
  1003. TemplateType.codefuse_codellama,
  1004. Template(['{{SYSTEM}}'], ['<|role_start|>human<|role_end|>{{QUERY}}<|role_start|>bot<|role_end|>'], [],
  1005. [['eos_token_id']]))
  1006. register_template(
  1007. TemplateType.codefuse,
  1008. Template([], ['<s>human\n{{QUERY}}\n<s>bot\n'], [['eos_token_id'], '\n'], [['eos_token_id']], None,
  1009. ['<s>system\n{{SYSTEM}}\n']))
  1010. register_template(
  1011. TemplateType.deepseek_coder,
  1012. Template(['{{SYSTEM}}'], ['### Instruction:\n{{QUERY}}\n### Response:\n'], ['\n<|EOT|>\n'], ['\n<|EOT|>'],
  1013. ('You are an AI programming assistant, utilizing the Deepseek Coder model, '
  1014. 'developed by Deepseek Company, and you only answer questions related to computer science. '
  1015. 'For politically sensitive questions, security and privacy issues, '
  1016. 'and other non-computer science questions, you will refuse to answer\n')))
  1017. class LlavaHfTemplate(Template):
  1018. def __init__(self, *args, **kwargs) -> None:
  1019. super().__init__(*args, **kwargs)
  1020. if version.parse(transformers.__version__) < version.parse('4.43.0'):
  1021. self.padding_side = 'left'
  1022. def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
  1023. assert media_type == 'image'
  1024. return ['<image>\n']
  1025. def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  1026. inputs, _ = super()._encode(example)
  1027. if len(inputs) == 0:
  1028. return inputs, {}
  1029. images = example.get('images')
  1030. if images:
  1031. image_processor = self.tokenizer.processor.image_processor
  1032. image_inputs = image_processor(images, return_tensors='pt').to(kwargs['dtype'])
  1033. inputs['pixel_values'] = image_inputs['pixel_values']
  1034. if 'image_sizes' in image_inputs:
  1035. inputs['image_sizes'] = image_inputs['image_sizes']
  1036. return inputs, {}
  1037. class Llava1_6Llama3Template(LlavaHfTemplate):
  1038. default_system = 'You are a helpful language and vision assistant. ' \
  1039. 'You are able to understand the visual content that the user provides, ' \
  1040. 'and assist the user with a variety of tasks using natural language.'
  1041. def __init__(self):
  1042. super().__init__(['<|begin_of_text|>'], [
  1043. '<|start_header_id|>user<|end_header_id|>\n\n{{QUERY}}<|eot_id|>'
  1044. '<|start_header_id|>assistant<|end_header_id|>\n\n'
  1045. ], ['<|eot_id|>'], ['<|eot_id|>'], None,
  1046. ['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>'])
  1047. def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  1048. inputs, _ = super()._encode(example)
  1049. if len(inputs['pixel_values'].shape) == 5: # (1, num_patch, 3, H/W, W/H)
  1050. inputs['pixel_values'] = torch.squeeze(inputs['pixel_values'], dim=0) # (num_patch, 3, H/W, W/H)
  1051. return inputs, {}
  1052. register_template(TemplateType.llava_next_llama3, Llava1_6Llama3Template(), use_model=False, lazy_tokenize=True)
  1053. class LlavaVideoTemplate(Template):
  1054. def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
  1055. if media_type == 'image':
  1056. return ['<image>\n']
  1057. assert media_type == 'video'
  1058. media_file = example['videos'][index]
  1059. if media_file.rsplit('.', 1)[-1] in {'jpg', 'png'}:
  1060. return ['<image>\n']
  1061. else:
  1062. return ['<video>\n']
  1063. def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  1064. inputs, _ = super()._encode(example)
  1065. if len(inputs) == 0:
  1066. return inputs, {}
  1067. images = example.get('images') or []
  1068. videos_path = example.get('videos') or []
  1069. if len(videos_path) > 0:
  1070. videos = load_batch(videos_path, load_video_llava)
  1071. video_processor = self.tokenizer.processor.video_processor
  1072. video_inputs = video_processor(videos, return_tensors='pt').to(kwargs['dtype'])
  1073. inputs['pixel_values_videos'] = video_inputs['pixel_values_videos']
  1074. if len(images) > 0:
  1075. image_processor = self.tokenizer.processor.image_processor
  1076. image_inputs = image_processor(images, return_tensors='pt').to(kwargs['dtype'])
  1077. inputs['pixel_values'] = image_inputs['pixel_values']
  1078. inputs['image_sizes'] = image_inputs['image_sizes']
  1079. return inputs, {}
  1080. register_template(
  1081. TemplateType.llava_next_video,
  1082. LlavaVideoTemplate(['<s>{{SYSTEM}} '], ['USER: {{QUERY}} ASSISTANT:'], [' '], ['</s>']),
  1083. use_model=False,
  1084. lazy_tokenize=True)
  1085. register_template(
  1086. TemplateType.llava_next_video_yi,
  1087. LlavaVideoTemplate(['{{SYSTEM}} '], ['USER: {{QUERY}} ASSISTANT:'], [' '], ['<|im_end|>']),
  1088. use_model=False,
  1089. infer_media_type='round',
  1090. lazy_tokenize=True)
def align_image_inputs(input_ids: List[int], labels: List[int], new_input_ids,
                       image_token: int) -> Tuple[List[int], List[int]]:
    """Expand each ``image_token`` in ``input_ids`` using the expanded ids in ``new_input_ids``.

    ``new_input_ids`` is expected to be a re-tokenisation of the same prompt (e.g. produced by
    a multimodal processor) in which every image token has already been expanded. For each
    ``image_token`` in ``input_ids``, the span of ``new_input_ids`` between the tokens that
    neighbour it in ``input_ids`` is spliced in its place. The matching ``labels`` positions
    are set to -100.

    Args:
        input_ids: Template-encoded ids containing ``image_token`` placeholders.
        labels: Labels aligned with ``input_ids``; left untouched when empty or None.
        new_input_ids: Expanded ids, as a list or ``torch.Tensor``.
        image_token: Placeholder id to expand.

    Returns:
        ``(input_ids, labels)`` with every placeholder replaced.

    Raises:
        AssertionError: If a placeholder is the first or the last token of ``input_ids``.
        ValueError: If the token preceding a placeholder is not found in ``new_input_ids``
            within 4 positions of where the alignment expects it.
    """
    if isinstance(new_input_ids, torch.Tensor):
        new_input_ids = new_input_ids.tolist()
    # Find the tokens after the image_token in input_ids, and then align them.
    # i walks input_ids, j walks new_input_ids.
    i, j = 0, 0
    while i < len(input_ids):
        x = input_ids[i]
        if x == image_token:
            assert i + 1 < len(input_ids), f'input_ids[-10:]: {input_ids[-10:]}'
            assert i - 1 >= 0, f'input_ids[:10]: {input_ids[:10]}'
            # [1, 2, 3(i-1), image_token(i), 4(i+1) ,5, 6]
            # [1, 2, 3(j_begin), a(j'), a, a, a, 4(j) ,5, 6]
            j_begin = j - 1
            # Search outward (offsets 0..4) for the token that precedes the placeholder.
            for k in range(5):  # Increase robustness.
                if j_begin + k < len(new_input_ids) and new_input_ids[j_begin + k] == input_ids[i - 1]:
                    j_begin += k
                    break
                if j_begin - k >= 0 and new_input_ids[j_begin - k] == input_ids[i - 1]:
                    j_begin -= k
                    break
            else:
                raise ValueError(f'new_input_ids: {new_input_ids}, input_ids: {input_ids}')
            # The expanded image span starts right after the matched preceding token...
            j_begin += 1
            # ...and ends before the first occurrence of the token that follows the placeholder.
            while j < len(new_input_ids) and new_input_ids[j] != input_ids[i + 1]:
                j += 1
            input_ids = input_ids[:i] + new_input_ids[j_begin:j] + input_ids[i + 1:]
            if labels:
                labels = labels[:i] + [-100] * (j - j_begin) + labels[i + 1:]
            i += j - j_begin
        else:
            j += 1
        i += 1
    return input_ids, labels
  class Idefics3Template(Template):
      """Idefics3 template: image placeholders are expanded by re-running the HF processor on the decoded prompt."""

      def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Encode text, then splice the processor's expanded ``<image>`` spans into ids and labels.

          ``kwargs`` is accepted for interface compatibility; it is not read here.
          """
          inputs, _ = super()._encode(example)
          if len(inputs) == 0:
              return inputs, {}
          images = example.get('images') or []
          processor = self.tokenizer.processor
          prompt = self.tokenizer.decode(inputs['input_ids'])
          if not images:
              return inputs, {}
          image_inputs = processor(text=prompt, images=images, return_tensors='pt', add_special_tokens=False)
          image_token = 128257  # <image>
          expanded_ids = image_inputs['input_ids'][0]
          inputs['input_ids'], inputs['labels'] = align_image_inputs(
              inputs['input_ids'], inputs['labels'], expanded_ids, image_token)
          inputs['pixel_values'] = image_inputs['pixel_values']
          return inputs, {}
  # Idefics3: "User:...<end_of_utterance>\nAssistant:" turns with an optional "System:" prefix.
  register_template(
      TemplateType.idefics3,
      Idefics3Template(
          ['<|begin_of_text|>'],
          ['User:{{QUERY}}<end_of_utterance>\nAssistant:'],
          ['<end_of_utterance>\n'],
          ['<end_of_utterance>'],
          None,
          ['System:{{SYSTEM}}<end_of_utterance>\n'],
      ),
      lazy_tokenize=True,
      use_model=False,
  )
  class Llava1_5Template(LlavaHfTemplate):
      """LLaVA-1.5 (HF format) with a "USER: ...\\nASSISTANT:" prompt; turns are closed by ``</s>``."""

      def __init__(self):
          prefix = ['<s>']
          prompt = ['USER: {{QUERY}}\nASSISTANT:']
          super().__init__(prefix, prompt, ['</s>'], ['</s>'])


  register_template(TemplateType.llava1_5, Llava1_5Template(), lazy_tokenize=True, use_model=False)
  class LLavaTemplate(Template):
      """Template for original-format LLaVA checkpoints (the ``llava`` package) with a Llama-2 ``[INST]`` prompt.

      Each image tag becomes the sentinel id -200 followed by a newline. Images are preprocessed with
      ``llava.mm_utils.process_images``.
      """

      def __init__(self):
          # This template follows: https://github.com/haotian-liu/LLaVA/blob/main/llava/conversation.py#L350
          # Use the upper-case `{{SYSTEM}}` placeholder that every other template in this file uses. The previous
          # lower-case `{{system}}` presumably never matched, leaking the literal text instead of the system prompt.
          super().__init__(['<s>[INST] '], ['{{QUERY}} [/INST]'],
                           None, ['</s>'],
                           system_prefix=['<<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n'])

      def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
          """Encode an image tag as the LLaVA image sentinel (-200) followed by a newline."""
          assert media_type == 'image'
          return [[-200], '\n']

      def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Encode the text, then attach the preprocessed ``images`` tensor and the original ``image_sizes``.

          ``kwargs['dtype']`` is the dtype the image tensor is cast to.
          """
          inputs, _ = super()._encode(example)
          if len(inputs) == 0:
              return inputs, {}
          images = example.get('images') or []
          image_sizes = [x.size for x in images]
          from llava.mm_utils import process_images
          if images:
              # image_processor comes from the model.vision_tower.image_processor
              # config comes from the model.config
              images_tensor = process_images(images, self.tokenizer.image_processor, self.tokenizer.config)
              inputs['images'] = images_tensor.to(kwargs['dtype']).squeeze(0)
              inputs['image_sizes'] = image_sizes
          return inputs, {}

      def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
          """Collate text with the parent collator and pass per-sample images through as a list.

          Raises:
              AssertionError: If the batch mixes samples with and without the -200 image sentinel.
          """
          res = super().data_collator(batch, padding_to)
          images = [b['images'] for b in batch if 'images' in b]
          if images:
              res['images'] = images
              res['image_sizes'] = sum([b['image_sizes'] for b in batch if 'image_sizes' in b], start=[])
          has_images = [(b == -200).sum() for b in res['input_ids']]
          assert all(h > 0 for h in has_images) or not any(h > 0 for h in has_images), \
              'Llava does not support mix-batch nlp dataset and multi-modal dataset'
          return res

      @staticmethod
      def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
          # The generated ids are returned as-is, without slicing off the first input_token_len tokens.
          return generate_ids
  class Llava1_6Template(LlavaHfTemplate):
      """LLaVA-1.6 (LLaVA-NeXT, HF format) base template."""

      def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
          """Drop the leading batch dim of each sample's ``pixel_values``, then delegate to the parent collator.

          NOTE: the sample dicts in ``batch`` are modified in place.
          """
          for sample in batch:
              pv = sample.get('pixel_values')
              if pv is None:
                  continue
              sample['pixel_values'] = pv.squeeze(0)  # 5d -> 4d
          return super().data_collator(batch, padding_to)
  class Llava1_6MistralTemplate(Llava1_6Template):
      """LLaVA-1.6 (LLaVA-NeXT) Mistral template: ``[INST] ... [/INST]`` turns with a ``<<SYS>>`` system block."""

      def __init__(self):
          # Use the upper-case `{{SYSTEM}}` placeholder that every other template in this file uses. The previous
          # lower-case `{{system}}` presumably never matched, so a user-supplied system prompt was not inserted.
          super().__init__(['<s>[INST] '], ['{{QUERY}} [/INST]'], ['</s>'], ['</s>'],
                           system_prefix=['<<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n'])
  class Llava1_6VicunaTemplate(Llava1_6Template):
      """LLaVA-1.6 (LLaVA-NeXT) Vicuna template with the standard Vicuna default system prompt."""
      system = ('A chat between a curious human and an artificial intelligence assistant. '
                "The assistant gives helpful, detailed, and polite answers to the human's questions.")

      def __init__(self):
          super().__init__(
              ['<s>'],
              ['USER: {{QUERY}} ASSISTANT:'],
              ['</s>'],
              ['</s>'],
              self.system,
              system_prefix=['<s>{{SYSTEM}} '],
          )
  # LLaVA-1.6 (LLaVA-NeXT) HF-format variants.
  register_template(TemplateType.llava_mistral, Llava1_6MistralTemplate(), lazy_tokenize=True, use_model=False)
  register_template(TemplateType.llava_vicuna, Llava1_6VicunaTemplate(), lazy_tokenize=True, use_model=False)
  class LLava1_6YiTemplate(Llava1_6Template):
      """LLaVA-1.6 on a Yi backbone, using ``<|im_start|>``/``<|im_end|>`` delimited turns."""

      def __init__(self):
          super().__init__(
              [],
              ['<|im_start|>user\n{{QUERY}}<|im_end|><|im_start|>assistant\n'],
              ['<|im_end|>'],
              ['<|im_end|>'],
              system_prefix=['<|im_start|>system\n{{SYSTEM}}<|im_end|>'],
          )

      def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
          # Pure pass-through to the parent implementation.
          context_list = super().replace_tag(media_type, index, example)
          return context_list


  register_template(TemplateType.llava_yi, LLava1_6YiTemplate(), lazy_tokenize=True, use_model=False)
  class Llama3LlavaNextHfTemplate(Llama3TemplateMixin, Llava1_6Template):
      """Combines ``Llama3TemplateMixin`` with the LLaVA-1.6 (HF format) template."""


  register_template(TemplateType.llama3_llava_next_hf, Llama3LlavaNextHfTemplate(), lazy_tokenize=True, use_model=False)


  class LlavaQwenHfTemplate(QwenTemplateMixin, Llava1_6Template):
      """Combines ``QwenTemplateMixin`` with the LLaVA-1.6 (HF format) template."""


  register_template(TemplateType.llava_qwen_hf, LlavaQwenHfTemplate(), lazy_tokenize=True, use_model=False)
  class LlavaOneVisonTemplate(QwenTemplateMixin, Llava1_6Template):
      """LLaVA-OneVision (Qwen backbone, HF format) template.

      Each ``<image>`` token (id 151646) in the encoded prompt is repeated as many times as the processor reports
      image features for the corresponding image.
      """
      system = None

      def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Encode text, then expand each ``<image>`` token to the per-image feature count.

          ``kwargs['dtype']`` is the dtype the image processor outputs are cast to.
          """
          # Calls the base Template._encode directly, skipping the Llava1_6Template/mixin _encode overrides.
          inputs, _ = Template._encode(self, example)
          if len(inputs) == 0:
              return inputs, {}
          images = example.get('images')
          input_ids = inputs['input_ids']
          labels = inputs['labels']
          idx_list = _findall(input_ids, 151646)  # <image>
          processor = self.tokenizer.processor
          if images:
              image_processor = processor.image_processor
              image_inputs = image_processor(images, return_tensors='pt').to(kwargs['dtype'])
              # Processed (target) height/width, taken from the first image's pixel_values.
              height, width = image_inputs['pixel_values'][0].shape[-2:]
              # Net number of tokens inserted so far; shifts the positions found in the original input_ids.
              added_tokens_len = 0
              # NOTE(review): pixel_v is unused in the loop body.
              for idx, pixel_v, image_size in zip(idx_list, image_inputs['pixel_values'], image_inputs['image_sizes']):
                  orig_height, orig_width = image_size
                  num_image_tokens = processor._get_number_of_features(orig_height, orig_width, height, width)
                  # Replace the single <image> token with num_image_tokens copies of it.
                  input_ids = input_ids[:added_tokens_len + idx] + [151646] * num_image_tokens + input_ids[added_tokens_len + idx + 1:]
                  if labels is not None:
                      labels = labels[:added_tokens_len + idx] + [-100] * num_image_tokens + labels[added_tokens_len + idx + 1:]
                  added_tokens_len += num_image_tokens - 1
              inputs['input_ids'] = input_ids
              inputs['labels'] = labels
              inputs['pixel_values'] = image_inputs['pixel_values']
              if 'image_sizes' in image_inputs:
                  inputs['image_sizes'] = image_inputs['image_sizes']
          return inputs, {}


  register_template(TemplateType.llava_onevision_qwen, LlavaOneVisonTemplate(), use_model=False, lazy_tokenize=True)
  class LLavaLlamaTemplate(Llama3Template):
      """LLaVA on a Llama-3 instruct backbone; images are preprocessed with the HF image processor."""

      def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example):
          # Every media tag becomes the literal `<image>` token plus a newline.
          return ['<image>\n']

      def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Encode text and, when images are present, attach ``pixel_values`` cast to ``kwargs['dtype']``."""
          inputs, _ = super()._encode(example)
          if len(inputs) == 0:
              return inputs, {}
          raw_image = example.get('images')
          if not raw_image:
              return inputs, {}
          image_processor = self.tokenizer.processor.image_processor
          pixel_values = image_processor(raw_image, return_tensors='pt')['pixel_values']
          inputs['pixel_values'] = pixel_values.to(kwargs['dtype'])
          return inputs, {}


  register_template(TemplateType.llava_llama_instruct, LLavaLlamaTemplate(), lazy_tokenize=True, use_model=False)
  class PaliGemmaTemplate(Template):
      """PaliGemma template: image tokens are prepended as text, and ``token_type_ids`` mark prefix vs. suffix."""

      def __init__(self):
          super().__init__([], ['{{QUERY}}\n'], None, ['<eos>'])

      def check_example(self, example):
          # At most one image per sample.
          images = example.get('images') or []
          assert len(images) <= 1

      def replace_tag(self, media_type, index, example) -> List[Context]:
          """Return the image context and set ``self.prompt`` according to whether the vLLM backend is used.

          NOTE(review): this mutates ``self.prompt`` on the template instance as a side effect of tag replacement.
          """
          assert media_type == 'image'
          if self._is_vllm:
              # With vLLM the image tokens are not inserted here and the query carries no trailing newline.
              self.prompt = ['{{QUERY}}']
              return []
          else:
              self.prompt = ['{{QUERY}}\n']
              # image_seq_length copies of `<image>`, then `<bos>` before the text.
              return ['<image>' * self.tokenizer.processor.image_seq_length + '<bos>']

      def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Encode text, build ``token_type_ids``, and attach ``pixel_values`` for the (single) image."""
          inputs, _ = super()._encode(example)
          if len(inputs) == 0:
              return inputs, {}
          raw_image = example.get('images')
          processor = self.tokenizer.processor
          if inputs['labels'] is not None:
              # token_type_ids: n leading zeros, ones for the rest. n presumably is the length of the leading run of
              # -100 labels (binary search via upper_bound) -- confirm against upper_bound's definition.
              n = upper_bound(0, len(inputs['labels']), lambda idx: inputs['labels'][idx] == -100)
              n2 = len(inputs['labels']) - n
              inputs['token_type_ids'] = [0] * n + [1] * n2
          else:
              inputs['token_type_ids'] = [0] * len(inputs['input_ids'])
          if raw_image:
              # Only the first image is used; the processor is called with the raw query text.
              model_inputs = processor(text=example['query'], images=raw_image[0], return_tensors='pt')
              inputs['pixel_values'] = model_inputs['pixel_values']
          return inputs, {}

      def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
          """Collate with the parent collator, then pad ``token_type_ids`` with 0 on ``self.padding_side``."""
          res = super().data_collator(batch, padding_to)
          token_type_ids = [torch.tensor(b['token_type_ids']) for b in batch]
          token_type_ids = self.pad_sequence(token_type_ids, 0, self.padding_side)
          res['token_type_ids'] = token_type_ids
          return res


  register_template(
      TemplateType.paligemma, PaliGemmaTemplate(), infer_media_type='dialogue', lazy_tokenize=True, is_generation=True)
  class Phi3Template(Template):
      """Phi-3 chat format: ``<|user|>`` / ``<|assistant|>`` turns terminated by ``<|end|>``."""

      def __init__(self):
          super().__init__(
              [],
              ['<|user|>\n{{QUERY}}<|end|>\n<|assistant|>\n'],
              ['<|end|>\n'],
              ['<|end|>'],
              None,
              ['<|system|>\n{{SYSTEM}}<|end|>\n'],
              auto_add_bos=True,
          )


  register_template(TemplateType.phi3, Phi3Template())
  class Phi3VisionTemplate(Phi3Template):
      """Phi-3-vision template.

      Each ``<|image|>`` (id 32044) is replaced by ``num_img_tokens[i]`` copies of the negative id ``-(i + 1)``,
      where i is the image's index.
      """
      image_placeholder = ['<|image|><s>\n']  # <|image|>\n

      def replace_tag(self, media_type, index, example) -> List[Context]:
          # Pure pass-through to the parent implementation.
          return super().replace_tag(media_type, index, example)

      def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Encode text, merge the image-processor outputs into ``inputs``, and expand image tokens.

          Raises:
              AssertionError: If the number of ``<|image|>`` tokens differs from the number of images.
          """
          images = example.get('images') or []
          inputs, _ = super()._encode(example)
          if len(inputs) == 0:
              return inputs, {}
          input_ids = inputs['input_ids']
          labels = inputs['labels']
          idx_list = _findall(input_ids, 32044)  # '<|image|>'
          if len(images) > 0:
              processor = self.tokenizer.processor
              # All image-processor outputs (pixel values etc.) are merged into inputs;
              # num_img_tokens is popped again below.
              inputs.update(processor.image_processor(images, return_tensors='pt'))
              assert len(idx_list) == len(images), f'len(idx_list): {len(idx_list)}, len(images): {len(images)}'
              res_input_ids = []
              res_labels = []
              num_img_tokens = inputs.pop('num_img_tokens').tolist()
              # A -1 sentinel lets the first segment start at index 0.
              idx_list.insert(0, -1)
              for i in range(len(idx_list) - 1):
                  image_token_id = -i - 1
                  # Copy the text between consecutive images, then the expanded image tokens.
                  res_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]] + [image_token_id] * num_img_tokens[i]
                  if labels is not None:
                      res_labels += labels[idx_list[i] + 1:idx_list[i + 1]] + [-100] * num_img_tokens[i]
              res_input_ids += input_ids[idx_list[-1] + 1:]
              input_ids = res_input_ids
              if labels is not None:
                  res_labels += labels[idx_list[-1] + 1:]
                  labels = res_labels
          inputs['input_ids'] = input_ids
          inputs['labels'] = labels
          return inputs, {}


  register_template(TemplateType.phi3_vl, Phi3VisionTemplate(), lazy_tokenize=True)
  class Llama3LlavaNextTemplate(Llama3TemplateMixin, LLavaTemplate):
      """``Llama3TemplateMixin`` combined with the original-format LLaVA template, with a vision-assistant system."""
      system = ('You are a helpful language and vision assistant. '
                'You are able to understand the visual content that the user provides, '
                'and assist the user with a variety of tasks using natural language.')


  register_template(TemplateType.llama3_llava_next, Llama3LlavaNextTemplate(), lazy_tokenize=True, use_model=False)


  class LLavaQwenTemplate(QwenTemplateMixin, LLavaTemplate):
      """``QwenTemplateMixin`` combined with the original-format LLaVA template."""


  register_template(TemplateType.llava_qwen, LLavaQwenTemplate(), lazy_tokenize=True, use_model=False)
  1354. def _findall(token_list: List[int], sub_token_list: Union[int, List[int]]) -> List[int]:
  1355. """Find the index of a token in the token_list."""
  1356. if isinstance(sub_token_list, int):
  1357. sub_token_list = [sub_token_list]
  1358. res = []
  1359. idx = -1
  1360. try:
  1361. while True:
  1362. idx = token_list.index(sub_token_list[0], idx + 1)
  1363. if len(sub_token_list) == 1 or sub_token_list == token_list[idx:idx + len(sub_token_list)]:
  1364. res.append(idx)
  1365. except ValueError:
  1366. pass
  1367. return res
  class DeepseekVLTemplate(Template):
      """DeepSeek-VL template.

      Image placeholders are expanded to ``processor.num_image_tokens`` copies. The model-side batch is built with
      ``processor.batchify`` and turned into ``inputs_embeds`` in ``post_encode``.
      """
      DEEPSEEK_VL_SYSTEM = ('You are a helpful language and vision assistant. '
                            'You are able to understand the visual content that the user provides, '
                            'and assist the user with a variety of tasks using natural language.')
      image_placeholder = ['<image_placeholder>']

      def __init__(self):
          super().__init__(['<|begin▁of▁sentence|>{{SYSTEM}}\n\n'], ['User: {{QUERY}}\n\nAssistant:'],
                           ['<|end▁of▁sentence|>'], ['<|end▁of▁sentence|>'], self.DEEPSEEK_VL_SYSTEM)

      def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Expand image placeholders and package a ``VLChatProcessorOutput`` batch under ``_data``.

          ``kwargs['dtype']`` is the dtype ``pixel_values`` is cast to.
          """
          inputs, _ = super()._encode(example)
          if len(inputs) == 0:
              return inputs, {}
          images = example.get('images')
          processor = self.tokenizer.processor
          input_ids, labels = inputs['input_ids'], inputs['labels']
          idx_list = _findall(input_ids, processor.image_id)  # '<image_placeholder>'
          new_input_ids, new_labels = [], []
          lo = 0
          for hi in idx_list:
              # Copy the text up to this placeholder, then the expanded image tokens (labels masked with -100).
              new_input_ids += input_ids[lo:hi]
              if labels is not None:
                  new_labels += labels[lo:hi]
              new_input_ids += [processor.image_id] * processor.num_image_tokens
              # Appended unconditionally; discarded below when labels is None.
              new_labels += [-100] * processor.num_image_tokens
              lo = hi + 1
          new_input_ids += input_ids[lo:]
          if labels is not None:
              new_labels += labels[lo:]
          else:
              new_labels = None
          from deepseek_vl.models.processing_vlm import VLChatProcessorOutput
          images_outputs = processor.image_processor(images, return_tensors='pt')
          output = VLChatProcessorOutput(
              sft_format=None,
              input_ids=torch.tensor(new_input_ids),
              pixel_values=images_outputs.pixel_values,
              num_image_tokens=torch.tensor([processor.num_image_tokens] * len(idx_list)))
          batched_output = dict(processor.batchify([output]))
          batched_output['pixel_values'] = batched_output['pixel_values'].to(dtype=kwargs['dtype'])
          inputs = {'input_ids': new_input_ids, 'labels': new_labels, '_data': batched_output}
          return inputs, {}

      def post_encode(self, model, data: Any) -> Dict[str, Any]:
          # Turn the batched `_data` into embeddings via the model and return the first element.
          inputs_embeds = model.prepare_inputs_embeds(**data)[0]
          return {'inputs_embeds': inputs_embeds}

      @staticmethod
      def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
          # The generated ids are returned as-is, without slicing off the first input_token_len tokens.
          return generate_ids


  register_template(TemplateType.deepseek_vl, DeepseekVLTemplate(), use_model=False, lazy_tokenize=True)
  # Zephyr: <|system|>/<|user|>/<|assistant|> turns closed by </s>.
  register_template(
      TemplateType.zephyr,
      Template(
          [],
          ['<|user|>\n{{QUERY}}</s>\n<|assistant|>\n'],
          ['</s>\n'],
          ['</s>'],
          None,
          ['<|system|>\n{{SYSTEM}}</s>\n'],
      ),
  )
  # SUS: "### Human:" / "### Assistant:" turns closed by <|endoftext|>.
  register_template(
      TemplateType.sus,
      Template(['{{SYSTEM}}'], ['### Human: {{QUERY}}\n\n### Assistant: '], ['<|endoftext|>'], ['<|endoftext|>']),
  )
  # Orion: "Human:" / "Assistant: </s>" turns.
  register_template(
      TemplateType.orion,
      Template(['<s>{{SYSTEM}}'], ['Human: {{QUERY}}\n\nAssistant: </s>'], ['</s>'], ['</s>']),
  )
  class CogTemplate(Template):
      """CogVLM / CogAgent template.

      Image tokens are not placed inline. ``_encode`` asks the model tokenizer's
      ``build_conversation_input_ids`` how many image tokens there are and inserts that many pad tokens after the
      first token, marking them with ``token_type_ids == 1``.
      """

      def check_example(self, example):
          # At most one image per sample.
          assert len(example.get('images') or []) <= 1

      def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
          # Tags produce no inline context; image tokens are inserted in `_encode`.
          return []

      def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Encode text, insert image pad tokens after the first token, and attach the image tensors.

          ``kwargs['dtype']`` is the dtype image (and cross-image) tensors are cast to.
          """
          inputs, _ = super()._encode(example)
          if len(inputs) == 0:
              return inputs, {}
          image = example.get('images') or []
          inputs.pop('loss_scale', None)
          built = self.tokenizer.build_conversation_input_ids(
              self.tokenizer, query=example['query'], history=example.get('history'), images=image)
          n_image_tokens = built['token_type_ids'].sum().item()
          input_ids, labels = inputs['input_ids'], inputs['labels']
          head, tail = input_ids[:1], input_ids[1:]
          inputs['token_type_ids'] = [0] + [1] * n_image_tokens + [0] * len(tail)
          inputs['input_ids'] = head + [self.tokenizer.pad_token_id] * n_image_tokens + tail
          if labels is not None:
              inputs['labels'] = labels[:1] + [-100] * n_image_tokens + labels[1:]
          if image:
              inputs['images'] = [[img.to(dtype=kwargs['dtype'])] for img in built['images']]
              if 'cross_images' in built:
                  # is cogagent
                  inputs['cross_images'] = [[cross.to(dtype=kwargs['dtype'])] for cross in built['cross_images']]
          return inputs, {}

      def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
          """Collate text, unwrap per-sample image lists, and pad ``token_type_ids`` with 0."""
          res = super().data_collator(batch, padding_to)
          for key in ('images', 'cross_images'):
              if key in batch[0]:
                  res[key] = [sample[key][0] for sample in batch]
          type_ids = [torch.tensor(sample['token_type_ids']) for sample in batch]
          res['token_type_ids'] = self.pad_sequence(type_ids, 0, self.padding_side)
          return res
  # CogAgent chat: [INST]-style turns.
  register_template(
      TemplateType.cogagent_chat,
      CogTemplate(['<s>'], [' [INST] {{QUERY}} [/INST] '], [], ['</s>']),
      lazy_tokenize=True,
      infer_media_type='dialogue',
      use_model=False,
  )
  # CogAgent instruct: "<EOI>Question: ... Answer:" turns.
  register_template(
      TemplateType.cogagent_instruct,
      CogTemplate(['<s>'], ['<EOI>Question: {{QUERY}} Answer:'], None, ['</s>']),
      lazy_tokenize=True,
      infer_media_type='dialogue',
      use_model=False,
  )
  # CogVLM: BOS/EOS taken from the tokenizer ids.
  register_template(
      TemplateType.cogvlm,
      CogTemplate([['bos_token_id']], ['Question: {{QUERY}} Answer:'], ['\n'], [['eos_token_id']]),
      lazy_tokenize=True,
      infer_media_type='dialogue',
      use_model=False,
  )
  class Cog2VideoTemplate(CogTemplate):
      """CogVLM2-Video template: one video per sample, frames loaded with ``load_video_cogvlm2``."""

      def check_example(self, example):
          # At most one video per sample.
          assert len(example.get('videos') or []) <= 1

      def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Encode text, insert video pad tokens after the first token, and attach the frame tensors.

          ``kwargs['dtype']`` is the dtype the frame tensors are cast to.
          """
          # Deliberately skip CogTemplate._encode and use the next implementation in the MRO.
          inputs, _ = super(CogTemplate, self)._encode(example)
          if len(inputs) == 0:
              return inputs, {}
          video_paths = example.get('videos') or []
          video = load_batch(video_paths, load_video_cogvlm2)
          inputs.pop('loss_scale', None)
          built = self.tokenizer.build_conversation_input_ids(
              self.tokenizer,
              query=example['query'],
              history=example.get('history'),
              images=video,
              template_version='chat',
          )
          n_video_tokens = built['token_type_ids'].sum().item()
          input_ids, labels = inputs['input_ids'], inputs['labels']
          inputs['token_type_ids'] = [0] + [1] * n_video_tokens + [0] * len(input_ids[1:])
          inputs['input_ids'] = input_ids[:1] + [self.tokenizer.pad_token_id] * n_video_tokens + input_ids[1:]
          if labels is not None:
              inputs['labels'] = labels[:1] + [-100] * n_video_tokens + labels[1:]
          if len(video) > 0:
              inputs['images'] = [[frame.to(dtype=kwargs['dtype'])] for frame in built['images']]
          return inputs, {}
  # CogVLM2-Video: same prompt as CogVLM, but media tags refer to videos.
  register_template(
      TemplateType.cogvlm2_video,
      Cog2VideoTemplate([['bos_token_id']], ['Question: {{QUERY}} Answer:'], ['\n'], [['eos_token_id']]),
      media_type='video',
      lazy_tokenize=True,
      infer_media_type='dialogue',
      use_model=False,
  )
  # MiniCPM (text-only).
  register_template(
      TemplateType.minicpm,
      Template(['<s>{{SYSTEM}}'], ['<用户>{{QUERY}}<AI>'], [], ['</s>']),
  )
  1515. def _remove_idx(arr: List[int], idx_list: List[int]) -> List[int]:
  1516. res = []
  1517. idx_set = set(idx_list)
  1518. for i, x in enumerate(arr):
  1519. if i not in idx_set:
  1520. res.append(x)
  1521. return res
  class MiniCPMVTemplate(Template):
      """Template for MiniCPM-V models whose image features are injected via ``model.get_vllm_embedding``.

      The image tag is first encoded as the sentinel id -100. ``_encode`` replaces it with the model-specific
      placeholder tokens and returns a ``_data`` payload that ``post_encode`` consumes.
      """
      # Read in `_encode`: selects the HF image-processor path when slice mode is on.
      is_v2_5 = False

      def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
          # Emit a -100 sentinel id; `_encode` finds it with `_findall` and swaps in the real placeholder.
          return [[-100]]

      def check_example(self, example):
          # Exactly one image per sample is required.
          images = example.get('images') or []
          assert len(images) == 1

      def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Expand the image sentinel and build the ``_data`` payload.

          Args:
              example: Must contain ``images`` (exactly one, enforced by ``check_example``).
              **kwargs: ``dtype`` is read on the v2.5 slice-mode path to cast the image-processor outputs.

          Returns:
              ``({'input_ids', 'labels', '_data'}, {})``. ``_data`` holds ``input_ids`` as a tensor with a leading
              batch dim of 1, plus ``image_bound``, ``pixel_values`` and ``tgt_sizes``. ``tgt_sizes`` is None
              except on the v2.5 slice-mode path.
          """
          inputs, _ = super()._encode(example)
          if len(inputs) == 0:
              return inputs, {}
          images = example['images']
          input_ids = inputs['input_ids']
          labels = inputs['labels']
          idx_list = _findall(input_ids, -100)
          # Only the first sentinel is expanded (one image per sample).
          idx = idx_list[0]
          tgt_sizes = None
          slice_mode = getattr(self.tokenizer.config, 'slice_mode', False)
          if slice_mode:
              if self.is_v2_5:
                  image_processor = self.tokenizer.processor.image_processor
                  image_inputs = image_processor(images, return_tensors='pt').to(kwargs['dtype'])
                  placeholder = image_processor.get_slice_image_placeholder(image_inputs.image_sizes[0][0])
                  pixel_values = image_inputs['pixel_values']
                  tgt_sizes = image_inputs['tgt_sizes']
              else:
                  # Comes from model.get_slice_image_placeholder and model.transform
                  images, placeholder = self.tokenizer.get_slice_image_placeholder(images[0], self.tokenizer)
                  pixel_values = [[self.tokenizer.transform(img) for img in images]]
              placeholder += '\n'
              placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False)
              # Replace the sentinel with the placeholder ids; the matching label positions are masked with -100.
              input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
              if labels is not None:
                  labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
              # Each image_bound row pairs (position of an im_start token + 1) with the position of an im_end token.
              input_tensor_ids = torch.tensor(input_ids)
              image_start_idx = torch.where(input_tensor_ids == self.tokenizer.im_start_id)[0]
              image_start_idx += 1
              image_end_idx = torch.where(input_tensor_ids == self.tokenizer.im_end_id)[0]
              # NOTE(review): with `max`, unequal start/end counts give slices of different lengths, so
              # torch.hstack raises; `min` would silently truncate instead -- confirm which is intended.
              valid_image_nums = max(len(image_start_idx), len(image_end_idx))
              image_bound = [
                  torch.hstack(
                      [image_start_idx[:valid_image_nums].unsqueeze(-1), image_end_idx[:valid_image_nums].unsqueeze(-1)])
              ]
          else:
              placeholder = '<image>' + '<unk>' * self.tokenizer.config.query_num + '</image>\n'
              placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False)
              input_ids = (input_ids[:idx] + placeholder_id + input_ids[idx + 1:])
              if labels is not None:
                  labels = (labels[:idx] + [-100] * len(placeholder_id) + labels[idx + 1:])
              # NOTE(review): this bound starts at `idx`, i.e. at the `<image>` token itself. The slice-mode
              # branch above starts one past im_start. Confirm which convention get_vllm_embedding expects.
              image_bound = [torch.tensor([[idx, idx + self.tokenizer.config.query_num]])]
              pixel_values = [[self.tokenizer.transform(images[0])]]
          inputs = {
              'input_ids': input_ids,
              'labels': labels,
              '_data': {
                  'input_ids': torch.tensor(input_ids)[None],
                  'image_bound': image_bound,
                  'pixel_values': pixel_values,
                  'tgt_sizes': tgt_sizes
              }
          }
          return inputs, {}

      def post_encode(self, model, data: Any) -> Dict[str, Any]:
          # The model builds embeddings from the `_data` payload; the first element of its batch is returned.
          inputs_embeds, _ = model.get_vllm_embedding(data)
          return {'inputs_embeds': inputs_embeds[0]}

      @staticmethod
      def _get_generate_ids(generate_ids: List[int], input_token_len: int) -> List[int]:
          # The generated ids are returned as-is, without slicing off the first input_token_len tokens.
          return generate_ids
  class MiniCPMV2_6Template(QwenTemplateMixin, MiniCPMVTemplate):
      """MiniCPM-V 2.6 template: any number of images, or video frames converted to images."""

      def check_example(self, example):
          # No restriction on the number of images (overrides the single-image check of MiniCPMVTemplate).
          pass

      def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
          """Return the image sentinel context; a video tag is first expanded into per-frame images.

          Frame sampling is capped by the ``max_num_frames`` environment setting (default 64).
          """
          assert media_type in {'image', 'video'}
          max_num_frames = get_env_args('max_num_frames', int, 64)
          load_video = partial(load_video_minicpmv_mplug_owl3, max_num_frames=max_num_frames)
          image_context = super().replace_tag('image', index, example)
          if media_type == 'image':
              return image_context
          elif media_type == 'video':
              return _replace_video2image(load_video, example, lambda i: image_context)

      def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Expand every image sentinel into its slice placeholder and compute ``image_bound`` from ``<unk>`` runs.

          ``kwargs['dtype']`` is the dtype the image-processor outputs are cast to.
          """
          # Calls the base Template._encode directly, skipping MiniCPMVTemplate._encode.
          inputs, _ = Template._encode(self, example)
          if len(inputs) == 0:
              return inputs, {}
          images = example.get('images')
          use_video = bool(example.get('videos'))
          is_plain_text = not images and not use_video
          use_image_id = True
          max_slice_nums = None
          if use_video:
              # For video frames: no per-image id and at most one slice, unless overridden by the env below.
              use_image_id = False
              max_slice_nums = 1  # or 2
          max_slice_nums = get_env_args('max_slice_nums', int, max_slice_nums)
          input_ids = inputs['input_ids']
          labels = inputs['labels']
          idx_list = _findall(input_ids, -100)
          # A -1 sentinel lets the first text segment start at index 0.
          idx_list.insert(0, -1)
          image_processor = self.tokenizer.processor.image_processor
          image_inputs = image_processor([images], return_tensors='pt', max_slice_nums=max_slice_nums).to(kwargs['dtype'])
          res_input_ids = []
          res_labels = []
          for i in range(len(idx_list) - 1):
              # Text between consecutive sentinels, followed by image i's placeholder (labels masked with -100).
              placeholder = image_processor.get_slice_image_placeholder(
                  image_inputs.image_sizes[0][i], image_idx=i, max_slice_nums=max_slice_nums, use_image_id=use_image_id)
              placeholder += '\n'
              placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False)
              res_input_ids += input_ids[idx_list[i] + 1:idx_list[i + 1]] + placeholder_id
              if labels is not None:
                  res_labels += labels[idx_list[i] + 1:idx_list[i + 1]] + [-100] * len(placeholder_id)
          res_input_ids += input_ids[idx_list[-1] + 1:]
          input_ids = res_input_ids
          if labels is not None:
              res_labels += labels[idx_list[-1] + 1:]
              labels = res_labels
          if not is_plain_text:
              # image_bound: one [start, last + 1] pair per maximal run of consecutive `<unk>` token positions.
              # NOTE(review): raises IndexError if no `<unk>` token is present.
              input_tensor_ids = torch.tensor(input_ids)
              unk_token = self.tokenizer.encode('<unk>', add_special_tokens=False)[0]
              indices = (input_tensor_ids == unk_token).nonzero(as_tuple=True)[0].tolist()
              ranges = []
              start = indices[0]
              for i in range(1, len(indices)):
                  if indices[i] != indices[i - 1] + 1:
                      ranges.append([start, indices[i - 1] + 1])
                      start = indices[i]
              ranges.append([start, indices[-1] + 1])
              image_bound = [torch.tensor(ranges)]
          else:
              image_bound = [[]]
          inputs = {
              'input_ids': input_ids,
              'labels': labels,
              '_data': {
                  'input_ids': torch.tensor(input_ids)[None],
                  'image_bound': image_bound,
                  'pixel_values': image_inputs['pixel_values'],
                  'tgt_sizes': image_inputs['tgt_sizes']
              }
          }
          return inputs, {}


  register_template(TemplateType.minicpm_v_v2_6, MiniCPMV2_6Template(), use_model=False, lazy_tokenize=True)
  class MiniCPMV2_5Template(Llama3TemplateMixin, MiniCPMVTemplate):
      """MiniCPM-Llama3-V 2.5: ``Llama3TemplateMixin`` plus the MiniCPM-V image handling."""
      # Selects the HF image-processor slice-mode path in MiniCPMVTemplate._encode.
      is_v2_5 = True


  register_template(
      TemplateType.minicpm_v_v2_5,
      MiniCPMV2_5Template(),
      infer_media_type='dialogue',
      lazy_tokenize=True,
      use_model=False,
  )
  register_template(
      TemplateType.minicpm_v,
      MiniCPMVTemplate(['<s>{{SYSTEM}}'], ['<用户>{{QUERY}}<AI>'], [], ['</s>']),
      infer_media_type='dialogue',
      lazy_tokenize=True,
      use_model=False,
  )
  # Gemma: <start_of_turn>user / <start_of_turn>model turns closed by <end_of_turn>.
  gemma_template = Template(
      ['<bos>'],
      ['<start_of_turn>user\n{{QUERY}}<end_of_turn>\n<start_of_turn>model\n'],
      ['<end_of_turn>\n'],
      ['<end_of_turn>'],
      None,
      ['<bos><start_of_turn>system\n{{SYSTEM}}<end_of_turn>\n'],
  )
  register_template(TemplateType.gemma, gemma_template)
  # TeleChat v1/v2: <_user> ... <_bot> turns closed by <_end>.
  register_template(TemplateType.telechat, Template([], ['<_user>{{QUERY}}<_bot>'], ['<_end>'], ['<_end>']))
  register_template(TemplateType.telechat_v2, Template([], ['<_user> {{QUERY}}<_bot>'], [], ['<_end>']))
  # Default DBRX system prompt. Sentences are joined by implicit string concatenation, so every fragment must end
  # with its own separator. The newline before "YOU DO NOT MENTION ..." was previously missing ("this.YOU").
  DBRX_SYSTEM = (
      'You are DBRX, created by Databricks. You were last updated in December 2023. '
      'You answer questions based on information available up to that point.\n'
      'YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, '
      'but provide thorough responses to more complex and open-ended questions.\n'
      'You assist with various tasks, from writing to coding (using markdown for code blocks '
      '— remember to use ``` with code, JSON, and tables).\n'
      'You do not have real-time data access or code execution capabilities.'
      ' You avoid stereotyping and provide balanced perspectives on controversial topics. '
      'You do not provide song lyrics, poems, or news articles and do not divulge details of your training data.\n'
      'This is your system prompt, guiding your responses. Do not reference it, just respond to the user. '
      'If you find yourself talking about this message, stop. You should be responding appropriately '
      'and usually that means not mentioning this.\n'
      'YOU DO NOT MENTION ANY OF THIS INFORMATION ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY '
      'PERTINENT TO THE USER\'S QUERY.')
  class DbrxTemplate(ChatmlTemplate):
      """DBRX: ChatML format with the DBRX default system prompt."""
      system = DBRX_SYSTEM


  register_template(TemplateType.dbrx, DbrxTemplate())
  # Mengzi: "输入"/"输出" (input/output) turns with an optional "指令" (instruction) system prefix.
  register_template(
      TemplateType.mengzi,
      Template([], ['输入:{{QUERY}}输出:\n'], [], [['eos_token_id']], None, ['指令:{{SYSTEM}}']),
  )
  # Default Command-R (C4AI) system prompt. The space between the two sentences was previously missing.
  C4AI_SYSTEM = ('You are Command-R, a brilliant, sophisticated, AI-assistant trained to assist human users by '
                 'providing thorough responses. You are trained by Cohere.')
  # Command-R (C4AI): turns are wrapped in <|START_OF_TURN_TOKEN|> ... <|END_OF_TURN_TOKEN|>.
  register_template(
      TemplateType.c4ai,
      Template(
          ['<BOS_TOKEN>'],
          ['<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{QUERY}}<|END_OF_TURN_TOKEN|><|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>'],
          ['<|END_OF_TURN_TOKEN|>'], ['<|END_OF_TURN_TOKEN|>'], C4AI_SYSTEM,
          # The system turn must close with the full `<|END_OF_TURN_TOKEN|>`; the trailing `>` was missing, which
          # left a malformed, non-special token sequence after the system prompt.
          ['<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{SYSTEM}}<|END_OF_TURN_TOKEN|>']))
  class mPlugOwl2Template(Template):
      """mPLUG-Owl2 template: images are marked with the sentinel id -200 and preprocessed with ``process_images``."""

      def __init__(self):
          super().__init__(['{{SYSTEM}}'], ['USER: {{QUERY}}ASSISTANT:'], ['</s>'], [['eos_token_id']])

      def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
          # Only images are supported; each becomes the -200 sentinel id.
          assert media_type == 'image'
          return [[-200]]

      def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
          """Resize images to squares, encode text, and attach the processed image tensor if any.

          NOTE: when ``example['images']`` is a non-empty list, the resized images are written back into that same
          list (the caller's example is modified). ``kwargs['dtype']`` is the dtype the image tensor is cast to.
          """
          from mplug_owl2.mm_utils import process_images
          processor = self.tokenizer.processor
          images = example.get('images') or []
          # ref: https://modelscope.cn/models/iic/mPLUG-Owl2.1
          # Resize (stretching, not padding) to a square whose side is the longer edge.
          for pos, img in enumerate(images):
              side = max(img.size)
              images[pos] = img.resize((side, side))
          inputs, _ = super()._encode(example)
          if len(inputs) == 0:
              return inputs, {}
          encoded = {'input_ids': inputs['input_ids'], 'labels': inputs['labels']}
          if images:
              encoded['images'] = process_images(images, processor).to(kwargs['dtype'])
          return encoded, {}

      def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
          """Collate text with the parent collator and concatenate all per-sample image tensors."""
          res = super().data_collator(batch, padding_to)
          image_tensors = [sample['images'] for sample in batch if 'images' in sample]
          if image_tensors:
              res['images'] = torch.concat(image_tensors)
          return res


  register_template(
      TemplateType.mplug_owl2, mPlugOwl2Template(), infer_media_type='round', lazy_tokenize=True, use_model=False)
  1742. class mPlugOwl3Template(QwenTemplateMixin, Template):
  1743. system = None
  1744. def _get_image_token_list(self, cut_shape):
  1745. processor = self.tokenizer.processor
  1746. text = processor.image_processor.cut_prompt_template(img_token='<|image|>', h=cut_shape[0], w=cut_shape[1])
  1747. text_list = text.split('<|image|>')
  1748. if text_list[-1] == '':
  1749. text_list.pop()
  1750. res_text_list = []
  1751. for text in text_list:
  1752. res_text_list += [text, '<|image|>']
  1753. token_list = self._encode_context_list(res_text_list)[0]
  1754. return token_list
  1755. def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
  1756. assert media_type in {'image', 'video'}
  1757. max_num_frames = get_env_args('max_num_frames', int, 16)
  1758. load_video = partial(load_video_minicpmv_mplug_owl3, max_num_frames=max_num_frames)
  1759. if media_type == 'image':
  1760. return [[-100], '\n']
  1761. elif media_type == 'video':
  1762. return _replace_video2image(load_video, example, lambda i: [[-100]]) + ['\n']
  1763. def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  1764. inputs, _ = super()._encode(example)
  1765. if len(inputs) == 0:
  1766. return inputs, {}
  1767. images = example['images']
  1768. videos = example['videos']
  1769. cut_enable = not videos
  1770. input_ids = inputs['input_ids']
  1771. labels = inputs['labels']
  1772. idx_list = _findall(input_ids, -100)
  1773. processor = self.tokenizer.processor
  1774. if images:
  1775. image_inputs = processor.image_processor(images, cut_enable=cut_enable, return_tensors='pt')
  1776. added_tokens_len = 0
  1777. cut_shapes = image_inputs['cut_shape'] or [None] * len(idx_list)
  1778. image_token_list = self.tokenizer.encode('<|image|>', add_special_tokens=False)
  1779. for idx, cut_shape in zip(idx_list, cut_shapes):
  1780. if cut_shape:
  1781. token_list = self._get_image_token_list(cut_shape)
  1782. else:
  1783. token_list = image_token_list
  1784. input_ids = input_ids[:idx + added_tokens_len] + token_list + input_ids[added_tokens_len + idx + 1:]
  1785. if labels:
  1786. labels = labels[:idx + added_tokens_len] + [-100] * len(token_list) + labels[added_tokens_len + idx
  1787. + 1:]
  1788. added_tokens_len += len(token_list) - 1
  1789. image_token_idx = torch.tensor(_findall(input_ids, image_token_list))[None]
  1790. _range = torch.arange(len(input_ids))[:, None]
  1791. matrix = (_range > image_token_idx).sum(dim=1)
  1792. media_offset = torch.stack([torch.zeros(matrix.shape[0], dtype=torch.long), matrix], dim=-1)[None]
  1793. inputs['_data'] = {'pixel_values': image_inputs['pixel_values']}
  1794. inputs['media_offset'] = media_offset
  1795. inputs['input_ids'] = input_ids
  1796. inputs['labels'] = labels
  1797. return inputs, {}
  1798. def _post_encode(self, model, data: Any) -> Dict[str, Any]:
  1799. image_embeds = model.forward_image(data['pixel_values'])
  1800. return {'image_embeds': image_embeds}
  1801. def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
  1802. res = super().data_collator(batch, padding_to)
  1803. image_embeds = [b['image_embeds'] for b in batch if 'image_embeds' in b]
  1804. if image_embeds:
  1805. res['image_embeds'] = torch.concat(image_embeds)
  1806. media_offset = [b['media_offset'] for b in batch if 'media_offset' in b]
  1807. if media_offset:
  1808. res['media_offset'] = torch.concat(media_offset)
  1809. return res
  1810. register_template(TemplateType.mplug_owl3, mPlugOwl3Template(), use_model=False, lazy_tokenize=True)
  1811. register_template(TemplateType.wizardlm2_awq,
  1812. Template(['{{SYSTEM}}'], ['User:\n{{QUERY}}\n\nAssistant:\n'], ['\n\n'], ['</s>']))
  1813. _wizardlm2_system = ('A chat between a curious user and an artificial intelligence assistant. '
  1814. 'The assistant gives helpful, detailed, and polite answers to the user\'s questions. ')
  1815. register_template(TemplateType.wizardlm2,
  1816. Template(['{{SYSTEM}}'], ['USER: {{QUERY}} ASSISTANT:'], ['</s>'], ['</s>'], _wizardlm2_system))
  1817. register_template(TemplateType.atom,
  1818. Template(['{{SYSTEM}}'], ['<s>Human: {{QUERY}}\n</s><s>Assistant: '], ['</s>'], ['</s>']))
  1819. class RLHFTemplateMixin:
  1820. def encode(self: Template,
  1821. example: Dict[str, Any],
  1822. streaming: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  1823. template_encode = self._old_encode
  1824. inputs = {}
  1825. tokenizer_kwargs = {}
  1826. chosen_example, rejected_example = example, example.copy()
  1827. rejected_example['response'] = example['rejected_response']
  1828. if streaming:
  1829. chosen_inputs, chosen_tokenizer_kwargs = template_encode(chosen_example), {}
  1830. rejected_inputs, rejected_tokenizer_kwargs = template_encode(rejected_example), {}
  1831. else:
  1832. chosen_inputs, chosen_tokenizer_kwargs = template_encode(chosen_example)
  1833. rejected_inputs, rejected_tokenizer_kwargs = template_encode(rejected_example)
  1834. for suffix, res in zip(['inputs', 'tokenizer_kwargs'], [inputs, tokenizer_kwargs]):
  1835. for prefix in ['chosen', 'rejected']:
  1836. data = locals()[f'{prefix}_{suffix}']
  1837. for k, v in data.items():
  1838. res[f'{prefix}_{k}'] = v
  1839. return inputs, tokenizer_kwargs
  1840. def data_collator(self: Template, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
  1841. _data_collator = self._old_data_collator
  1842. new_batch = []
  1843. for prefix in ['chosen_', 'rejected_']:
  1844. for inputs in batch:
  1845. new_inputs = {}
  1846. for k, v in inputs.items():
  1847. if k.startswith(prefix):
  1848. new_k = k[len(prefix):]
  1849. new_inputs[new_k] = inputs[k]
  1850. if len(new_inputs) > 0:
  1851. new_batch.append(new_inputs)
  1852. assert len(new_batch) in {0, len(batch) * 2}, f'new_batch: {new_batch}'
  1853. return _data_collator(new_batch or batch, padding_to)
  1854. class KTOTemplateMixin:
  1855. def encode(self: Template,
  1856. example: Dict[str, Any],
  1857. streaming: bool = False) -> Tuple[Dict[str, Any], Dict[str, Any]]:
  1858. inputs, tokenizer_kwargs = self._old_encode(example, streaming)
  1859. if len(inputs) > 0:
  1860. inputs['label'] = example['label']
  1861. return inputs, tokenizer_kwargs
  1862. def data_collator(self: Template, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
  1863. res = {}
  1864. for prefix in ['', 'KL_']:
  1865. new_batch = []
  1866. for b in batch:
  1867. new_batch.append({'input_ids': b[f'{prefix}input_ids'], 'labels': b[f'{prefix}labels']})
  1868. for k, v in self._old_data_collator(new_batch, padding_to).items():
  1869. res[f'{prefix}completion_{k}'] = v
  1870. res['label'] = [b['label'] for b in batch]
  1871. return res