# Copyright (c) Alibaba, Inc. and its affiliates.
import json
import os
import re
from functools import partial
from typing import Any, Callable, Dict, List, Literal, Optional, Tuple, TypeVar, Union
import torch
import transformers
from packaging import version
from transformers import PreTrainedTokenizerBase
from transformers.dynamic_module_utils import get_class_from_dynamic_module
from transformers.integrations import is_deepspeed_zero3_enabled
from modelscope import get_logger
from .base import Template, TEMPLATE_MAPPING
from .utils import (load_audio_qwen, load_batch, load_image, load_video_cogvlm2, load_video_internvl,
load_video_llava, load_video_minicpmv_mplug_owl3, load_video_qwen2,
transform_image, upper_bound, fetch_one)
# Module-level logger obtained from modelscope's get_logger().
logger = get_logger()
# System prompt used as the default by several templates registered in this module.
DEFAULT_SYSTEM = 'You are a helpful assistant.'
# Conversation history: a list of rounds, each a (str, str) tuple or a list of strings.
History = List[Union[Tuple[str, str], List[str]]]
# Template fragments: each element is literal text, a list of token ids, or a list of
# strings (e.g. ['eos_token_id'], as used by the registrations in this file).
Prompt = List[Union[str, List[int], List[str]]]
# Stop words share the same fragment representation as prompts.
StopWords = Prompt
# One rendered piece of context: plain text or a list of token ids.
Context = Union[str, List[int]]
class TemplateType:
    """Namespace of the string identifiers used as template names.

    Every public class attribute is a template name; ``get_template_name_list``
    enumerates their values in declaration order.
    """
    # text-generation
    default_generation = 'default-generation'
    chatglm_generation = 'chatglm-generation'
    qwen_vl_generation = 'qwen-vl-generation'
    qwen_audio_generation = 'qwen-audio-generation'
    # chat
    default = 'default'
    qwen = 'qwen'
    qwen_vl = 'qwen-vl'
    qwen_audio = 'qwen-audio'
    qwen2_audio = 'qwen2-audio'
    qwen2_audio_generation = 'qwen2-audio-generation'
    qwen2_vl = 'qwen2-vl'
    modelscope_agent = 'modelscope-agent'
    baichuan = 'baichuan'
    chatglm2 = 'chatglm2'
    chatglm3 = 'chatglm3'
    chatglm4 = 'chatglm4'
    codegeex4 = 'codegeex4'
    llama = 'llama'  # llama2
    llama3 = 'llama3'
    reflection = 'reflection'
    longwriter_llama3 = 'longwriter-llama3'
    # llava-hf
    llava1_5 = 'llava1_5'
    llava_mistral = 'llava-mistral'
    llava_vicuna = 'llava-vicuna'
    llava_yi = 'llava-yi'
    llama3_llava_next_hf = 'llama-llava-next-hf'
    llava_next_llama3 = 'llava-next-llama3'
    llava_qwen_hf = 'llama-qwen-hf'
    llava_onevision_qwen = 'llava-onevision-qwen'
    # llava-video
    llava_next_video = 'llava-next-video'
    llava_next_video_yi = 'llava-next-video-yi'
    # lmms-lab:llava
    llama3_llava_next = 'llama3-llava-next'
    llava_qwen = 'llava-qwen'
    # xtuner:llava
    llava_llama_instruct = 'llava-llama-instruct'
    idefics3 = 'idefics3'
    mistral_nemo = 'mistral-nemo'
    openbuddy = 'openbuddy'
    openbuddy2 = 'openbuddy2'
    internlm = 'internlm'
    internlm2 = 'internlm2'
    internlm_xcomposer2 = 'internlm-xcomposer2'
    internlm_xcomposer2_4khd = 'internlm-xcomposer2-4khd'
    internlm_xcomposer2_5 = 'internlm-xcomposer2_5'
    internvl = 'internvl'
    internvl2 = 'internvl2'
    internvl_phi3 = 'internvl-phi3'
    internvl2_phi3 = 'internvl2-phi3'
    florence = 'florence'
    yi_coder = 'yi-coder'
    yi_vl = 'yi-vl'
    yuan = 'yuan'
    xverse = 'xverse'
    ziya = 'ziya'
    skywork = 'skywork'
    bluelm = 'bluelm'
    zephyr = 'zephyr'
    sus = 'sus'
    deepseek = 'deepseek'
    numina_math = 'numina-math'
    deepseek_coder = 'deepseek-coder'
    deepseek_vl = 'deepseek-vl'
    deepseek2 = 'deepseek2'
    deepseek2_5 = 'deepseek2_5'
    codefuse_codellama = 'codefuse-codellama'
    codefuse = 'codefuse'
    cogvlm = 'cogvlm'
    cogvlm2_video = 'cogvlm2-video'
    glm4v = 'glm4v'
    cogagent_chat = 'cogagent-chat'
    cogagent_instruct = 'cogagent-instruct'
    orion = 'orion'
    minicpm = 'minicpm'
    minicpm_v = 'minicpm-v'
    minicpm_v_v2_5 = 'minicpm-v-v2_5'
    minicpm_v_v2_6 = 'minicpm-v-v2_6'
    gemma = 'gemma'
    paligemma = 'paligemma'
    mplug_owl2 = 'mplug-owl2'
    mplug_owl3 = 'mplug_owl3'
    wizardlm2_awq = 'wizardlm2-awq'
    wizardlm2 = 'wizardlm2'
    atom = 'atom'
    phi3 = 'phi3'
    phi3_vl = 'phi3-vl'
    telechat = 'telechat'
    telechat_v2 = 'telechat-v2'
    dbrx = 'dbrx'
    mengzi = 'mengzi'
    c4ai = 'c4ai'
    chatml = 'chatml'
    # compatibility. (Deprecated)
    default_generation_bos = 'default-generation-bos'

    @classmethod
    def get_template_name_list(cls) -> List[str]:
        """Return the values of all template-name attributes declared directly on this class.

        Dunder attributes and this method itself are excluded; the order follows
        the declaration order of the class body.
        """
        own_name = 'get_template_name_list'
        return [
            value for attr, value in cls.__dict__.items()
            if not attr.startswith('__') and attr != own_name
        ]
def register_template(template_type: str, template: Template, *, exist_ok: bool = False, **kwargs) -> None:
    """Store ``template`` in ``TEMPLATE_MAPPING`` under ``template_type``.

    Args:
        template_type: Name under which the template is registered.
        template: Template instance; its ``template_type`` attribute is set to the name.
        exist_ok: When False, registering an already-present name raises.
        **kwargs: Extra entries stored next to the template in the mapping entry.

    Raises:
        ValueError: If ``template_type`` is already registered and ``exist_ok`` is False.
    """
    already_registered = template_type in TEMPLATE_MAPPING
    if already_registered and not exist_ok:
        raise ValueError(f'The `{template_type}` has already been registered in the TEMPLATE_MAPPING.')
    template.template_type = template_type
    # The 'template' key comes first, followed by the caller-supplied extras.
    entry = {'template': template}
    entry.update(kwargs)
    TEMPLATE_MAPPING[template_type] = entry
# Register the generic 'default' chat template: '### Human:' / '### Assistant:' rounds with
# DEFAULT_SYSTEM as the system text. The meaning of each positional argument is defined by
# Template in .base (not shown in this file).
register_template(
TemplateType.default,
Template([], ['### Human:\n{{QUERY}}\n\n### Assistant:\n'], ['\n\n'], [['eos_token_id']],
DEFAULT_SYSTEM, ['{{SYSTEM}}\n\n'],
auto_add_bos=True))
# Setting the query to '' lets this template be used for pre-training.
class DefaultGenerationTemplate(Template):
    """Generation template whose prompt consists solely of the ``{{QUERY}}`` placeholder."""

    def __init__(self):
        query_placeholder = ['{{QUERY}}']
        eos_marker = [['eos_token_id']]
        super().__init__([], query_placeholder, None, eos_marker, auto_add_bos=True)
# Register the generation templates; both entries carry is_generation=True in TEMPLATE_MAPPING.
register_template(TemplateType.default_generation, DefaultGenerationTemplate(), is_generation=True)
# Deprecated compatibility variant that puts 'bos_token_id' explicitly in the first argument
# instead of passing auto_add_bos.
register_template(
TemplateType.default_generation_bos,
Template([['bos_token_id']], ['{{QUERY}}'], None, [['eos_token_id']]),
is_generation=True)
class ChatmlTemplateMixin:
    """Mixin that initialises a ``Template`` with ChatML ``<|im_start|>`` / ``<|im_end|>`` markup."""
    # System text passed to Template.__init__; subclasses may override it.
    system = None

    def __init__(self, auto_add_bos: bool = True):
        """Call ``Template.__init__`` with the ChatML fragments and this class's ``system``."""
        user_round = ['<|im_start|>user\n{{QUERY}}<|im_end|>\n<|im_start|>assistant\n']
        round_separator = ['<|im_end|>\n']
        end_marker = ['<|im_end|>']
        system_round = ['<|im_start|>system\n{{SYSTEM}}<|im_end|>\n']
        Template.__init__(
            self, [], user_round, round_separator, end_marker, self.system, system_round,
            auto_add_bos=auto_add_bos)
class ChatmlTemplate(ChatmlTemplateMixin, Template):
    """Concrete ChatML template; all behaviour comes from the mixin and ``Template``."""


class QwenTemplateMixin(ChatmlTemplateMixin):
    """ChatML mixin variant that uses ``DEFAULT_SYSTEM`` and passes ``auto_add_bos=False``."""
    system = DEFAULT_SYSTEM

    def __init__(self):
        # Unlike the ChatML default (auto_add_bos=True), Qwen passes False here.
        super().__init__(auto_add_bos=False)


class QwenTemplate(QwenTemplateMixin, Template):
    """Concrete Qwen chat template."""
class _QwenVLTemplateMixin:
    """Render Qwen-VL media and grounding placeholders as inline text markup.

    Images are referenced by path inside the prompt (``<img>path</img>``), grounding
    captions are wrapped in ``<ref>...</ref>`` and boxes in ``<box>...</box>``, which is
    the markup the Qwen-VL tokenizer documents.
    """
    # Presumably tells the base Template not to load media files, since this template
    # only needs their paths -- confirm against Template in .base.
    load_medias = False

    def check_example(self, example):
        """Assert that the example's images, if any, are given as string paths."""
        images = example.get('images') or []
        assert not images or isinstance(fetch_one(images), str), 'QwenVL only supports datasets with images paths!'

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    example: Dict[str, Any]) -> List[Context]:
        """Return the numbered picture reference for the ``index``-th image path."""
        assert media_type == 'image'
        images = example.get('images') or []
        image = images[index]
        assert isinstance(image, str)
        # The path must sit between <img> tags so the tokenizer can locate it.
        return [f'Picture {index + 1}:<img>{image}</img>\n']

    def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]:
        """Return the caption of the ``index``-th object wrapped in ``<ref>`` tags."""
        object_ = example['objects'][index]
        return [f'<ref>{object_["caption"]}</ref>']

    def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]:
        """Return the box(es) of the ``index``-th object, each wrapped in ``<box>`` tags.

        ``bbox`` is either a single ``[x1, y1, x2, y2]`` list or a list of such lists;
        multiple boxes are concatenated without a separator.
        """
        bbox = example['objects'][index]['bbox']
        boxes = bbox if isinstance(bbox[0], list) else [bbox]
        return [''.join(f'<box>({box[0]},{box[1]}),({box[2]},{box[3]})</box>' for box in boxes)]
# Register the Qwen chat template under TemplateType.qwen.
register_template(TemplateType.qwen, QwenTemplate())
class QwenVLTemplate(_QwenVLTemplateMixin, QwenTemplate):
    """Qwen chat template combined with the Qwen-VL media/grounding tag rendering."""


class QwenVLGenerationTemplate(_QwenVLTemplateMixin, DefaultGenerationTemplate):
    """Default generation template combined with the Qwen-VL media/grounding tag rendering."""
# Register the Qwen-VL chat and generation templates.
register_template(TemplateType.qwen_vl, QwenVLTemplate())
# NOTE(review): unlike the other *_generation registrations in this file, this one does not
# pass is_generation=True -- confirm whether that is intentional.
register_template(TemplateType.qwen_vl_generation, QwenVLGenerationTemplate())
# Plain ChatML template (system text defaults to None on ChatmlTemplateMixin).
register_template(TemplateType.chatml, ChatmlTemplate())
# ModelScope-Agent template using '<|system|>:', '<|user|>:' and '<|assistant|>:' role markers.
register_template(
TemplateType.modelscope_agent,
Template([], [' \n\n<|user|>:{{QUERY}} \n\n<|assistant|>:'], [], [' \n\n'], DEFAULT_SYSTEM,
[' \n\n<|system|>:{{SYSTEM}}']))
class _QwenAudioTemplateMixin:
    """Qwen-Audio support: inline audio references plus ``audio_info`` bookkeeping."""

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    example: Dict[str, Any]) -> List[Context]:
        """Return the numbered audio reference for the ``index``-th audio path."""
        assert media_type == 'audio'
        audios = example.get('audios') or []
        audio = audios[index]
        assert isinstance(audio, str)
        # The path has to appear inside <audio> tags: _get_tokenizer_kwargs hands the rendered
        # text to tokenizer.process_audio, which needs the reference in the text to find it.
        return [f'Audio {index + 1}:<audio>{audio}</audio>\n']

    def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Encode via ``Template._encode`` and merge the tokenizer kwargs into the inputs."""
        inputs, tokenizer_kwargs = Template._encode(self, example)
        if len(inputs) == 0:
            # Empty inputs are passed through untouched.
            return inputs, tokenizer_kwargs
        inputs.pop('loss_scale', None)
        inputs.update(tokenizer_kwargs)
        return inputs, tokenizer_kwargs

    def _get_tokenizer_kwargs(self, context: str) -> Dict[str, Any]:
        """Return the ``audio_info`` the tokenizer extracts from ``context``."""
        return {'audio_info': self.tokenizer.process_audio(context)}

    def _concat_tokenizer_kwargs(self, tokenizer_kwargs: Dict[str, Any], curr_tokenizer_kwargs: Dict[str, Any]) -> None:
        """Merge the ``audio_info`` of ``curr_tokenizer_kwargs`` into ``tokenizer_kwargs`` in place."""
        audio_info = curr_tokenizer_kwargs.get('audio_info')
        old_audio_info = tokenizer_kwargs.get('audio_info')
        if old_audio_info is None:
            tokenizer_kwargs['audio_info'] = audio_info
        elif audio_info is not None:
            # Tensor fields are concatenated along dim 0; list fields are extended.
            for k in ['input_audios', 'input_audio_lengths']:
                old_audio_info[k] = torch.concat([old_audio_info[k], audio_info[k]], dim=0)
            for k in ['audio_span_tokens', 'audio_urls']:
                old_audio_info[k] = old_audio_info[k] + audio_info[k]

    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Collate via ``Template.data_collator`` and attach the per-sample ``audio_info`` list."""
        res = Template.data_collator(self, batch, padding_to)
        # Only the first sample is inspected to decide whether the batch carries audio.
        if batch[0].get('audio_info') is not None:
            res['audio_info'] = [b['audio_info'] for b in batch]
        return res
class QwenAudioTemplate(_QwenAudioTemplateMixin, QwenTemplate):
    """Qwen chat template combined with the Qwen-Audio handling."""


class QwenAudioGenerationTemplate(_QwenAudioTemplateMixin, DefaultGenerationTemplate):
    """Default generation template combined with the Qwen-Audio handling."""
# Register the Qwen-Audio chat and generation templates; both entries carry lazy_tokenize=True.
register_template(TemplateType.qwen_audio, QwenAudioTemplate(), lazy_tokenize=True)
register_template(
TemplateType.qwen_audio_generation, QwenAudioGenerationTemplate(), lazy_tokenize=True, is_generation=True)
class _Qwen2AudioTemplateMixin:
    """Qwen2-Audio support: run the processor's feature extractor and collate its outputs."""

    def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Encode text via ``Template._encode`` and add audio features when audios are present."""
        inputs, _ = Template._encode(self, example)
        if len(inputs) == 0:
            return inputs, {}
        processor = self.tokenizer.processor
        sampling_rate = processor.feature_extractor.sampling_rate
        audios = load_batch(
            example.get('audios') or [], load_func=partial(load_audio_qwen, sampling_rate=sampling_rate))
        if audios:
            audio_inputs = processor.feature_extractor(
                audios, sampling_rate=sampling_rate, return_attention_mask=True, return_tensors='pt')
            # Rename the extractor's mask, presumably so it cannot clash with the text attention mask.
            audio_inputs['feature_attention_mask'] = audio_inputs.pop('attention_mask')
            inputs.update(audio_inputs)
        return inputs, {}

    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Collate via ``Template.data_collator`` and concatenate the audio tensors.

        Samples without audio are skipped for both keys, so a batch mixing audio and
        text-only samples no longer raises ``KeyError`` on ``feature_attention_mask``.
        """
        res = Template.data_collator(self, batch, padding_to)
        for key in ('input_features', 'feature_attention_mask'):
            tensors = [b[key] for b in batch if b.get(key) is not None]
            if tensors:
                res[key] = torch.concat(tensors)
        return res
class Qwen2AudioTemplate(_Qwen2AudioTemplateMixin, QwenTemplate):
    """Qwen chat template with Qwen2-Audio placeholders numbered per audio."""

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    example: Dict[str, Any]) -> List[Context]:
        """Return the numbered audio placeholder for the ``index``-th audio."""
        assert media_type == 'audio'
        ordinal = index + 1
        return [f'Audio {ordinal}: <|audio_bos|><|AUDIO|><|audio_eos|>\n']


class Qwen2AudioGenerationTemplate(_Qwen2AudioTemplateMixin, DefaultGenerationTemplate):
    """Default generation template with an unnumbered Qwen2-Audio placeholder."""

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    example: Dict[str, Any]) -> List[Context]:
        """Return the audio placeholder; ``index`` is not used in the output."""
        assert media_type == 'audio'
        placeholder = '<|audio_bos|><|AUDIO|><|audio_eos|>\n'
        return [placeholder]


# Register the Qwen2-Audio chat template.
register_template(TemplateType.qwen2_audio, Qwen2AudioTemplate(), lazy_tokenize=True)
def _process_image_qwen(image):
    """Resize ``image`` to the dimensions chosen by ``qwen_vl_utils``' ``smart_resize``.

    The settings ``size_factor``, ``resized_height``, ``resized_width``, ``min_pixels`` and
    ``max_pixels`` are looked up through ``get_env_args`` (defined outside this block).
    When both resized dimensions are set they are used as the resize request; otherwise
    the image's own size is used together with the pixel bounds.

    Returns:
        The result of ``image.resize((width, height))``.
    """
    from qwen_vl_utils.vision_process import IMAGE_FACTOR, MIN_PIXELS, MAX_PIXELS, smart_resize

    factor = get_env_args('size_factor', int, IMAGE_FACTOR)
    requested_height = get_env_args('resized_height', int, None)
    requested_width = get_env_args('resized_width', int, None)
    if requested_height and requested_width:
        target_height, target_width = smart_resize(requested_height, requested_width, factor=factor)
    else:
        # image.size is (width, height); smart_resize takes height first.
        source_width, source_height = image.size
        pixel_floor = get_env_args('min_pixels', int, MIN_PIXELS)
        pixel_ceiling = get_env_args('max_pixels', int, MAX_PIXELS)
        target_height, target_width = smart_resize(
            source_height, source_width, factor=factor, min_pixels=pixel_floor, max_pixels=pixel_ceiling)
    return image.resize((target_width, target_height))
class Qwen2VLTemplate(QwenTemplate):
    """Qwen2-VL chat template: vision placeholders, grounding markup and media expansion."""

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index: int,
                    example: Dict[str, Any]) -> List[Context]:
        """Return the vision placeholder for an image or a video."""
        assert media_type in {'image', 'video'}
        if media_type == 'image':
            return ['<|vision_start|><|image_pad|><|vision_end|>']
        else:
            return ['<|vision_start|><|video_pad|><|vision_end|>']

    def replace_object(self, index: int, example: Dict[str, Any]) -> List[Context]:
        """Return the ``index``-th object's caption between object-ref tokens, or ``['']`` without objects."""
        objects = example.get('objects')
        if objects:
            object_ = objects[index]
            return ['<|object_ref_start|>', object_['caption'], '<|object_ref_end|>']
        else:
            return ['']

    def replace_box(self, index: int, example: Dict[str, Any]) -> List[Context]:
        """Return the ``index``-th object's box(es) between box tokens, or ``['']`` without objects.

        ``bbox`` is either one ``[x1, y1, x2, y2]`` list or a list of such lists; multiple
        boxes are concatenated without a separator.
        """
        objects = example.get('objects')
        if not objects:
            return ['']
        bbox = objects[index]['bbox']
        boxes = bbox if isinstance(bbox[0], list) else [bbox]
        return [
            ''.join(f'<|box_start|>({box[0]},{box[1]}),({box[2]},{box[3]})<|box_end|>' for box in boxes)
        ]

    def _encode(self, example: Dict[str, Any]) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Encode the example and expand every media placeholder to its real token count.

        Each single placeholder token in ``input_ids`` is replaced by
        ``grid_thw.prod() // merge_size**2`` copies of itself; ``labels`` receive the same
        number of ``-100`` entries so both sequences stay aligned.
        """
        inputs, _ = super()._encode(example)
        if len(inputs) == 0:
            return inputs, {}
        processor = self.tokenizer.processor
        input_ids = inputs['input_ids']
        labels = inputs['labels']
        # Explicit mapping instead of looking the lists up through locals().
        medias = {
            'images': example.get('images') or [],
            'videos': example.get('videos') or [],
        }
        for media_type in ['images', 'videos']:
            media = medias[media_type]
            if not media:
                continue
            if media_type == 'images':
                # 151655 / 151656 are presumably the ids of <|image_pad|> / <|video_pad|> -- TODO confirm.
                media_token = 151655
                media = load_batch(media, _process_image_qwen)
                media_inputs = processor.image_processor(images=media, videos=None, return_tensors='pt')
                media_grid_thw = media_inputs['image_grid_thw']
            else:
                media_token = 151656
                media = load_batch(media, load_video_qwen2)
                media_inputs = processor.image_processor(images=None, videos=media, return_tensors='pt')
                media_grid_thw = media_inputs['video_grid_thw']
            # Loop-invariant: patches merged into a single token.
            merge_length = processor.image_processor.merge_size**2
            idx_list = _findall(input_ids, media_token)
            added_tokens_len = 0
            for i, idx in enumerate(idx_list):
                # Convert to a plain int so list repetition and offsets never involve tensors.
                token_len = int(media_grid_thw[i].prod() // merge_length)
                # Positions found before expansion shift by the tokens inserted so far.
                pos = idx + added_tokens_len
                input_ids = input_ids[:pos] + [media_token] * token_len + input_ids[pos + 1:]
                if labels:
                    labels = labels[:pos] + [-100] * token_len + labels[pos + 1:]
                added_tokens_len += token_len - 1
            inputs.update(media_inputs)
        inputs['input_ids'] = input_ids
        inputs['labels'] = labels
        return inputs, {}

    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Collate via the parent and concatenate the image/video ``*_grid_thw`` tensors that are present."""
        res = super().data_collator(batch, padding_to)
        for media_type in ['image', 'video']:
            key = f'{media_type}_grid_thw'
            grid_thw = [b[key] for b in batch if b.get(key) is not None]
            if grid_thw:
                res[key] = torch.concat(grid_thw)
        return res
# Register the Qwen2-VL chat template and the Qwen2-Audio generation template.
register_template(TemplateType.qwen2_vl, Qwen2VLTemplate(), lazy_tokenize=True)
register_template(
TemplateType.qwen2_audio_generation, Qwen2AudioGenerationTemplate(), lazy_tokenize=True, is_generation=True)
class YiCoderTemplate(ChatmlTemplate):
    """ChatML template for Yi-Coder with a non-empty default system text."""
    # Overrides the None default of ChatmlTemplateMixin.
    system = 'You are a helpful assistant.'


# Register the Yi-Coder template.
register_template(TemplateType.yi_coder, YiCoderTemplate())
# System text passed to the yi_vl registration in this file. It is one string made of an
# English part followed by a Chinese part (adjacent literals are concatenated).
yi_vl_default_system = (
'This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. '
"Read all the images carefully, and respond to the human's questions with informative, "
'helpful, detailed and polite answers. '
'这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。'
'仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。')
class YiVLTemplate(Template):
    """Yi-VL template: images are marked by the id ``-200`` and preprocessed to square tensors."""

    def replace_tag(self, media_type, index, example) -> List[Context]:
        """Return the image marker (token id ``-200``) followed by a newline."""
        assert media_type == 'image'
        return [[-200], '\n']

    def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Encode the example and attach preprocessed image tensors under ``inputs['images']``.

        ``kwargs['dtype']`` is required whenever the example has images. The caller's
        ``example['images']`` list is left untouched: the squared copies are kept in a
        new list instead of overwriting the originals.
        """
        inputs, _ = super()._encode(example)
        if len(inputs) == 0:
            return inputs, {}
        inputs.pop('loss_scale', None)
        from llava.mm_utils import expand2square

        # This processor should be put from the `model.vision_tower.image_processor`
        image_processor = self.tokenizer.image_processor
        images = example.get('images') or []
        if images:
            # Loop-invariant padding colour derived from the processor's image mean.
            background_color = tuple(int(x * 255) for x in image_processor.image_mean)
            squared_images = [expand2square(image, background_color) for image in images]
            image_tensor = image_processor.preprocess(squared_images, return_tensors='pt')['pixel_values']
            inputs['images'] = image_tensor.to(kwargs['dtype'])
        return inputs, {}

    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Collate via the parent, concatenate image tensors and reject mixed batches.

        Raises:
            AssertionError: If some, but not all, rows of ``input_ids`` contain the ``-200`` marker.
        """
        res = super().data_collator(batch, padding_to)
        images = [b['images'] for b in batch if 'images' in b]
        if images:
            res['images'] = torch.concat(images)
        row_has_image = [bool((row == -200).sum() > 0) for row in res['input_ids']]
        assert all(row_has_image) or not any(row_has_image), (
            'YIVL does not support mix-batch nlp dataset and multi-modal dataset')
        return res
class GLMTemplate(Template):
    """Template that prepends the ids from ``tokenizer.encode('')`` to its prefixes."""

    def _init_template(self, tokenizer: PreTrainedTokenizerBase, *args, **kwargs) -> None:
        """Run the parent initialisation, then put the empty-string encoding in front of the prefixes.

        The same id list is inserted at position 0 of ``self.prefix`` and, when it is not
        None, of ``self.system_prefix``. The parent's return value is passed through.
        """
        result = super()._init_template(tokenizer, *args, **kwargs)
        # Whatever ids the tokenizer yields for empty text (presumably its special
        # leading tokens -- confirm for the GLM tokenizers).
        leading_ids = tokenizer.encode('')
        self.prefix.insert(0, leading_ids)
        system_prefix = self.system_prefix
        if system_prefix is not None:
            system_prefix.insert(0, leading_ids)
        return result
class GLM4VTemplate(GLMTemplate):
    """GLM-4V template: at most one image per example, marked by the id ``-100`` during rendering."""

    def __init__(self):
        user_round = ['<|user|>\n{{QUERY}}<|assistant|>']
        end_marker = ['<|endoftext|>']
        system_round = ['<|system|>\n{{SYSTEM}}']
        super().__init__([], user_round, [], end_marker, None, system_round)

    def check_example(self, example):
        """Assert that the example carries no more than one image."""
        images = example.get('images') or []
        assert len(images) <= 1

    def replace_tag(self, media_type: Literal['image', 'video', 'audio'], index, example) -> List[Context]:
        """Return the temporary image marker (token id ``-100``)."""
        assert media_type == 'image'
        return [[-100]]

    def _encode(self, example: Dict[str, Any], **kwargs) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        """Encode the example and swap the first image marker for the image placeholder ids.

        When a marker is found, the first message of ``example['messages']`` gets an
        ``'image'`` entry (in place) and the tokenizer's chat template is used to obtain
        the ``images`` value stored in the inputs.
        """
        inputs, _ = super()._encode(example)
        if len(inputs) == 0:
            return inputs, {}
        input_ids = inputs['input_ids']
        labels = inputs['labels']
        marker_positions = _findall(input_ids, -100)
        if marker_positions:
            pos = marker_positions[0]
            image = example.get('images')[0]
            placeholder = '<|begin_of_image|><|endoftext|><|end_of_image|>'
            placeholder_id = self.tokenizer.encode(placeholder, add_special_tokens=False)
            input_ids = input_ids[:pos] + placeholder_id + input_ids[pos + 1:]
            if labels is not None:
                # Placeholder positions are filled with -100 in the labels.
                masked = [-100] * len(placeholder_id)
                labels = labels[:pos] + masked + labels[pos + 1:]
            messages = example['messages']
            messages[0]['image'] = image
            inputs2: Dict[str, Any] = self.tokenizer.apply_chat_template(messages, return_dict=True)
            inputs['images'] = inputs2['images']
        inputs['input_ids'] = input_ids
        inputs['labels'] = labels
        return inputs, {}

    def data_collator(self, batch: List[Dict[str, Any]], padding_to: Optional[int] = None) -> Dict[str, Any]:
        """Collate via the parent and concatenate the ``images`` tensors of samples that have them."""
        res = super().data_collator(batch, padding_to)
        image_tensors = [sample['images'] for sample in batch if 'images' in sample]
        if image_tensors:
            res['images'] = torch.concat(image_tensors)
        return res
# Register the GLM-4V template.
register_template(TemplateType.glm4v, GLM4VTemplate(), infer_media_type='dialogue', lazy_tokenize=True, use_model=False)
# Register the Yi-VL template; [8308] is a literal token id placed before 'Human:' and
# 'Assistant:' (its meaning is not shown in this file).
register_template(
TemplateType.yi_vl,
YiVLTemplate([], [[8308], 'Human: {{QUERY}}\n', [8308], 'Assistant:'], ['\n'], ['\n', [8308]], yi_vl_default_system,
['{{SYSTEM}}\n\n']),
use_model=False,
infer_media_type='round',
lazy_tokenize=True)
# Baichuan: the query is wrapped between the literal token ids 195 and 196.
register_template(TemplateType.baichuan, Template(['{{SYSTEM}}'], [[195], '{{QUERY}}', [196]], [], [['eos_token_id']]))
# ChatGLM2: numbered rounds ('[Round {{ROUND1}}]') in a Chinese question/answer format.
register_template(
TemplateType.chatglm2,
GLMTemplate(['{{SYSTEM}}'], ['[Round {{ROUND1}}]\n\n问:{{QUERY}}\n\n答:'], ['\n\n'], [['eos_token_id']]))
# ChatGLM generation template: the prompt is the bare query.
register_template(
TemplateType.chatglm_generation, GLMTemplate([], ['{{QUERY}}'], None, [['eos_token_id']]), is_generation=True)
# ChatGLM3: '<|user|>' / '<|assistant|>' role markers.
register_template(
TemplateType.chatglm3,
GLMTemplate([], ['<|user|>\n{{QUERY}}<|assistant|>\n'], [], ['<|user|>'], None, ['<|system|>\n{{SYSTEM}}']))
# ChatGLM4: same markers as ChatGLM3 plus a 'glm4' tools prompt and an '<|observation|>' tool round.
register_template(
TemplateType.chatglm4,
GLMTemplate([], ['<|user|>\n{{QUERY}}<|assistant|>\n'], [], ['<|user|>'],
None, ['<|system|>\n{{SYSTEM}}'],
tools_prompt='glm4',
tool_prompt=['<|observation|>\n{{QUERY}}<|assistant|>\n']))
# System text (in Chinese) passed to the CodeGeeX4 registration below.
codegeex4_system = '你是一位智能编程助手,你叫CodeGeeX。你会为用户回答关于编程、代码、计算机方面的任何问题,并提供格式规范、可以执行、准确安全的代码,并在必要时提供详细的解释。'
# CodeGeeX4: ChatGLM-style role markers with '<|endoftext|>' in place of '<|user|>' as the fourth argument.
register_template(
TemplateType.codegeex4,
GLMTemplate([], ['<|user|>\n{{QUERY}}<|assistant|>\n'], [], ['<|endoftext|>'], codegeex4_system,
['<|system|>\n{{SYSTEM}}']))
# DeepSeek: 'User:' / 'Assistant:' rounds; bos/eos are given by tokenizer attribute name.
register_template(
TemplateType.deepseek,
Template([['bos_token_id']], ['User: {{QUERY}}\n\nAssistant:'], [['eos_token_id']], [['eos_token_id']], None,
[['bos_token_id'], '{{SYSTEM}}\n\n']))
# NuminaMath: '### Problem:' / '### Solution:' layout.
register_template(
TemplateType.numina_math,
Template([['bos_token_id']], ['### Problem: {{QUERY}}\n### Solution: '], ['\n'], [['eos_token_id']], None,
[['bos_token_id'], '{{SYSTEM}}']))
# DeepSeek2: same text layout as 'deepseek', but with the literal token ids 100000 and 100001
# where 'deepseek' uses 'bos_token_id' and 'eos_token_id'.
register_template(
TemplateType.deepseek2,
Template([[100000]], ['User: {{QUERY}}\n\nAssistant:'], [[100001]], [[100001]], None, [[100000], '{{SYSTEM}}\n\n']))
# DeepSeek-V2.5. The special tokens must match the tokenizer's vocabulary exactly: they use the
# fullwidth bar '｜' and the '▁' separator. The end token is '<｜end▁of▁sentence｜>'; a
# misspelled or ASCII-bar variant would be tokenized as ordinary text.
register_template(
    TemplateType.deepseek2_5,
    Template(['<｜begin▁of▁sentence｜>'], ['<｜User｜>{{QUERY}}<｜Assistant｜>'], ['<｜end▁of▁sentence｜>'],
             ['<｜end▁of▁sentence｜>'], None, ['<｜begin▁of▁sentence｜>{{SYSTEM}}']))
# ref: https://github.com/facebookresearch/llama/blob/main/llama/generation.py
# System text passed to the llama (Llama 2) registration in this file; adjacent string
# literals are concatenated into a single prompt made of two paragraphs.
LLAMA_DEFAULT_SYSTEM = (
'You are a helpful, respectful and honest assistant. '
'Always answer as helpfully as possible, while being safe. '
'Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. '
'Please ensure that your responses are socially unbiased and positive in nature.\n\n'
'If a question does not make any sense, or is not factually coherent, '
'explain why instead of answering something not correct. '
"If you don't know the answer to a question, please don't share false information.")
# Llama 2 chat format: the system text is wrapped in '<<SYS>>' ... '<</SYS>>' inside the first
# '[INST]' block, as in the reference generation.py of facebookresearch/llama.
register_template(
    TemplateType.llama,
    Template(['<s>[INST] '], ['{{QUERY}} [/INST]'], ['</s><s>[INST] '], ['</s>'], LLAMA_DEFAULT_SYSTEM,
             ['<s>[INST] <<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
# LongWriter-llama3: '[INST]' ... '[/INST]' rounds without surrounding spaces. The system text
# uses the Llama-2-style '<<SYS>>' ... '<</SYS>>' wrapper.
# NOTE(review): the wrapper was restored from a mangled '<>' form -- confirm against the model card.
register_template(
    TemplateType.longwriter_llama3,
    Template(['[INST]'], ['{{QUERY}}[/INST]'], ['[INST]'], ['<|end_of_text|>'], None,
             ['<<SYS>>\n{{SYSTEM}}\n<</SYS>>\n\n']))
# Mistral-Nemo: the '{{SYSTEM}}' placeholder sits inside the second argument, directly before
# the query, rather than in a separate system argument.
register_template(TemplateType.mistral_nemo,
Template(['[INST] '], ['{{SYSTEM}}\n\n', '{{QUERY}}[/INST]'], ['[INST] '], ['']))
class Llama3TemplateMixin:
    """Mixin that initialises a ``Template`` with the Llama 3 header/eot markup."""
    # System text passed to Template.__init__; subclasses may override it.
    system = None

    def __init__(self):
        """Call ``Template.__init__`` with the Llama 3 fragments and the 'toolbench' tools prompt."""
        user_round = [
            '<|start_header_id|>user<|end_header_id|>\n\n{{QUERY}}<|eot_id|>'
            '<|start_header_id|>assistant<|end_header_id|>\n\n'
        ]
        tool_round = [
            '<|start_header_id|>tool<|end_header_id|>\n\n{{QUERY}}<|eot_id|>'
            '<|start_header_id|>assistant<|end_header_id|>\n\n'
        ]
        system_round = ['<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n{{SYSTEM}}<|eot_id|>']
        Template.__init__(
            self, ['<|begin_of_text|>'], user_round, ['<|eot_id|>'], ['<|eot_id|>'], self.system, system_round,
            tools_prompt='toolbench',
            tool_prompt=tool_round)
class Llama3Template(Llama3TemplateMixin, Template):
    """Concrete Llama 3 chat template; all behaviour comes from the mixin and ``Template``."""
class ReflectionTemplate(Llama3TemplateMixin, Template):
system = ('You are a world-class AI system, capable of complex reasoning and reflection. '
'Reason through the query inside tags, and then provide your final '
'response inside