# Copyright (c) Alibaba, Inc. and its affiliates. import json import re from copy import deepcopy from typing import Any, Dict, List, Literal, Optional, Tuple, Union import torch import torch.nn.functional as F from modelscope import get_logger from torch.nn import Module from torch.nn.utils.rnn import pad_sequence from transformers import PreTrainedTokenizerBase, StoppingCriteria from .loss_scale import loss_scale_map from .tools_prompt import get_tools_prompt from .utils import load_batch, load_image, rescale_image, fetch_one, to_device, decode_base64 from .utils import History, Prompt, StopWords, Context, Messages logger = get_logger() DEFAULT_SYSTEM = 'You are a helpful assistant.' TEMPLATE_MAPPING: Dict[str, Dict[str, Any]] = {} def get_template( template_type: str, tokenizer: PreTrainedTokenizerBase, default_system: Optional[str] = None, max_length: Optional[int] = None, truncation_strategy: Literal['delete', 'truncation_left'] = 'delete', **kwargs, ) -> 'Template': template_info = TEMPLATE_MAPPING[template_type] template = deepcopy(template_info['template']) template.init_template(tokenizer, default_system, max_length, truncation_strategy, **kwargs) return template def _findall(token_list: List[int], sub_token_list: Union[int, List[int]]) -> List[int]: """Find the index of a token in the token_list.""" if isinstance(sub_token_list, int): sub_token_list = [sub_token_list] res = [] idx = -1 try: while True: idx = token_list.index(sub_token_list[0], idx + 1) if len(sub_token_list) == 1 or sub_token_list == token_list[idx:idx + len(sub_token_list)]: res.append(idx) except ValueError: pass return res def replace_img_tag(messages: Messages, replace_token: str, pattern=r'(.+?)') -> Tuple[str, History, List[str]]: images_path = [] new_messages = [] for i, m in enumerate(messages): m = m.copy() if m['content'] is None or m['role'] in ('tool', 'system', 'assistant'): new_messages.append(m) else: images_path += re.findall(pattern, m['content']) m['content'] = 
class StopWordsCriteria(StoppingCriteria):
    """Stop generation as soon as one of the template's stop words appears.

    Adds extra stop words (such as the suffixes and chat separators of a
    template) to prevent unstoppable generation.

    Args:
        tokenizer: Tokenizer used to decode the most recent tokens.
        stop_words: Stop words to look for. A ``str`` entry is searched for in
            the decoded text; any other entry is treated as a list of token
            ids and compared against the tail of the sequence.
        **tokenizer_kwargs: Extra keyword arguments passed to
            ``tokenizer.decode``.
    """

    def __init__(self, tokenizer: PreTrainedTokenizerBase, stop_words: StopWords, **tokenizer_kwargs) -> None:
        self.tokenizer = tokenizer
        self.stop_words = stop_words
        self.tokenizer_kwargs = tokenizer_kwargs
        # -1 means "not seen any input yet"; set on the first __call__.
        self.start_idx = -1

    def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor, **kwargs) -> bool:
        """Return True if any stop word is found in the first sequence.

        Only ``input_ids[0]`` is inspected. ``scores`` and ``**kwargs`` are
        accepted but not used.
        """
        if self.start_idx == -1:
            # Remember where the first-seen sequence ended so that text before
            # that position is never decoded.
            # NOTE(review): presumably this is the first generated token --
            # confirm against how the criteria is invoked.
            self.start_idx = len(input_ids[0]) - 1
        tokenizer = self.tokenizer
        stop_words = self.stop_words
        # [-20:]: Assuming the end tokens do not exceed 20 tokens,
        #   to avoid input_ids being too long and affecting efficiency.
        text = tokenizer.decode(input_ids[0, self.start_idx:][-20:], **self.tokenizer_kwargs)
        # Converted to a Python list at most once per call, and only if a
        # token-id stop word actually needs it.
        token_ids = None
        for stop_word in stop_words:
            if isinstance(stop_word, str):
                if stop_word in text:
                    return True
            elif len(stop_word) > 0:  # list of token ids
                if token_ids is None:
                    token_ids = input_ids[0].tolist()
                if token_ids[-len(stop_word):] == stop_word:
                    return True
        return False
tools_prompt: The tools prompt name tool_prompt: The tool prompt, usually useful when there is a tool role padding_side: The padding side infer_media_type: The media type supported by the multi-modals Examples: system\nYou are a helpful assistant!\nWho are you?\nassistant:I am a robot\nWho are you?\nassistant:I am a robot # noqa ----------system------------ ---query---- --response- -----chatsep----- ---query--- --response- ----suffix----- ----------------------------system_prefix---------------------------- ---------------------------- prompt ------------------------------------- ---------------------------- prompt ------------------------------------- """ special_tokens = ['', '