chat_template_utils.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import inspect
  15. import json
  16. import re
  17. import types
  18. from contextlib import contextmanager
  19. from copy import deepcopy
  20. from datetime import datetime
  21. from functools import lru_cache
  22. from inspect import isfunction
  23. from typing import (
  24. Any,
  25. Callable,
  26. Literal,
  27. Optional,
  28. Union,
  29. get_args,
  30. get_origin,
  31. get_type_hints,
  32. )
  33. from packaging import version
  34. from . import logging
  35. from .import_utils import is_jinja_available, is_torch_available, is_vision_available
  36. logger = logging.get_logger(__name__)
  37. if is_jinja_available():
  38. import jinja2
  39. from jinja2.ext import Extension
  40. from jinja2.sandbox import ImmutableSandboxedEnvironment
  41. else:
  42. jinja2 = None
  43. if is_vision_available():
  44. from PIL.Image import Image
  45. if is_torch_available():
  46. from torch import Tensor
# Primitive hints: the Python scalar types plus `Any`, `NoneType` and the `...` placeholder
BASIC_TYPES = (int, float, str, bool, Any, type(None), ...)
# Extracts the initial segment of the docstring, containing the function description.
# Group 1 is everything before the first `Args:` / `Returns:` / `Raises:` header (or the end of the string).
description_re = re.compile(r"^(.*?)[\n\s]*(Args:|Returns:|Raises:|\Z)", re.DOTALL)
# Extracts the Args: block from the docstring.
# Group 1 is the body of the block, ending at `Returns:`, `Raises:` or the end of the string.
args_re = re.compile(r"\n\s*Args:\n\s*(.*?)[\n\s]*(Returns:|Raises:|\Z)", re.DOTALL)
# Splits the Args: block into individual arguments.
# Compiled with re.VERBOSE, so whitespace and `#` comments inside the pattern are ignored by the regex engine.
# Each match yields (argument name, argument description).
args_split_re = re.compile(
    r"""
(?:^|\n) # Match the start of the args block, or a newline
\s*(\w+):\s* # Capture the argument name and strip spacing
(.*?)\s* # Capture the argument description, which can span multiple lines, and strip trailing spacing
(?=\n\s*\w+:|\Z) # Stop when you hit the next argument or the end of the block
""",
    re.DOTALL | re.VERBOSE,
)
# Extracts the Returns: block from the docstring, if present. Note that most chat templates ignore the return type/doc!
# Group 1 is the body of the block, ending at `Raises:` or the end of the string.
returns_re = re.compile(r"\n\s*Returns:\n\s*(.*?)[\n\s]*(Raises:|\Z)", re.DOTALL)
  64. class TypeHintParsingException(Exception):
  65. """Exception raised for errors in parsing type hints to generate JSON schemas"""
  66. pass
  67. class DocstringParsingException(Exception):
  68. """Exception raised for errors in parsing docstrings to generate JSON schemas"""
  69. pass
  70. def _get_json_schema_type(param_type: type) -> dict[str, str]:
  71. type_mapping = {
  72. int: {"type": "integer"},
  73. float: {"type": "number"},
  74. str: {"type": "string"},
  75. bool: {"type": "boolean"},
  76. type(None): {"type": "null"},
  77. Any: {},
  78. }
  79. if is_vision_available():
  80. type_mapping[Image] = {"type": "image"}
  81. if is_torch_available():
  82. type_mapping[Tensor] = {"type": "audio"}
  83. return type_mapping.get(param_type, {"type": "object"})
  84. def _parse_type_hint(hint: str) -> dict:
  85. origin = get_origin(hint)
  86. args = get_args(hint)
  87. if origin is None:
  88. try:
  89. return _get_json_schema_type(hint)
  90. except KeyError:
  91. raise TypeHintParsingException(
  92. "Couldn't parse this type hint, likely due to a custom class or object: ", hint
  93. )
  94. elif origin is Union or (hasattr(types, "UnionType") and origin is types.UnionType):
  95. # Recurse into each of the subtypes in the Union, except None, which is handled separately at the end
  96. subtypes = [_parse_type_hint(t) for t in args if t is not type(None)]
  97. if len(subtypes) == 1:
  98. # A single non-null type can be expressed directly
  99. return_dict = subtypes[0]
  100. elif all(isinstance(subtype["type"], str) for subtype in subtypes):
  101. # A union of basic types can be expressed as a list in the schema
  102. return_dict = {"type": sorted([subtype["type"] for subtype in subtypes])}
  103. else:
  104. # A union of more complex types requires "anyOf"
  105. return_dict = {"anyOf": subtypes}
  106. if type(None) in args:
  107. return_dict["nullable"] = True
  108. return return_dict
  109. elif origin is Literal and len(args) > 0:
  110. LITERAL_TYPES = (int, float, str, bool, type(None))
  111. args_types = []
  112. for arg in args:
  113. if type(arg) not in LITERAL_TYPES:
  114. raise TypeHintParsingException("Only the valid python literals can be listed in typing.Literal.")
  115. arg_type = _get_json_schema_type(type(arg)).get("type")
  116. if arg_type is not None and arg_type not in args_types:
  117. args_types.append(arg_type)
  118. return {
  119. "type": args_types.pop() if len(args_types) == 1 else list(args_types),
  120. "enum": list(args),
  121. }
  122. elif origin is list:
  123. if not args:
  124. return {"type": "array"}
  125. else:
  126. # Lists can only have a single type argument, so recurse into it
  127. return {"type": "array", "items": _parse_type_hint(args[0])}
  128. elif origin is tuple:
  129. if not args:
  130. return {"type": "array"}
  131. if len(args) == 1:
  132. raise TypeHintParsingException(
  133. f"The type hint {str(hint).replace('typing.', '')} is a Tuple with a single element, which "
  134. "we do not automatically convert to JSON schema as it is rarely necessary. If this input can contain "
  135. "more than one element, we recommend "
  136. "using a list[] type instead, or if it really is a single element, remove the tuple[] wrapper and just "
  137. "pass the element directly."
  138. )
  139. if ... in args:
  140. raise TypeHintParsingException(
  141. "Conversion of '...' is not supported in Tuple type hints. "
  142. "Use list[] types for variable-length"
  143. " inputs instead."
  144. )
  145. return {"type": "array", "prefixItems": [_parse_type_hint(t) for t in args]}
  146. elif origin is dict:
  147. # The JSON equivalent to a dict is 'object', which mandates that all keys are strings
  148. # However, we can specify the type of the dict values with "additionalProperties"
  149. out = {"type": "object"}
  150. if len(args) == 2:
  151. out["additionalProperties"] = _parse_type_hint(args[1])
  152. return out
  153. raise TypeHintParsingException("Couldn't parse this type hint, likely due to a custom class or object: ", hint)
  154. def _convert_type_hints_to_json_schema(func: Callable) -> dict:
  155. type_hints = get_type_hints(func)
  156. signature = inspect.signature(func)
  157. required = []
  158. for param_name, param in signature.parameters.items():
  159. if param.annotation == inspect.Parameter.empty:
  160. raise TypeHintParsingException(f"Argument {param.name} is missing a type hint in function {func.__name__}")
  161. if param.default == inspect.Parameter.empty:
  162. required.append(param_name)
  163. properties = {}
  164. for param_name, param_type in type_hints.items():
  165. properties[param_name] = _parse_type_hint(param_type)
  166. schema = {"type": "object", "properties": properties}
  167. if required:
  168. schema["required"] = required
  169. return schema
  170. def parse_google_format_docstring(docstring: str) -> tuple[Optional[str], Optional[dict], Optional[str]]:
  171. """
  172. Parses a Google-style docstring to extract the function description,
  173. argument descriptions, and return description.
  174. Args:
  175. docstring (str): The docstring to parse.
  176. Returns:
  177. The function description, arguments, and return description.
  178. """
  179. # Extract the sections
  180. description_match = description_re.search(docstring)
  181. args_match = args_re.search(docstring)
  182. returns_match = returns_re.search(docstring)
  183. # Clean and store the sections
  184. description = description_match.group(1).strip() if description_match else None
  185. docstring_args = args_match.group(1).strip() if args_match else None
  186. returns = returns_match.group(1).strip() if returns_match else None
  187. # Parsing the arguments into a dictionary
  188. if docstring_args is not None:
  189. docstring_args = "\n".join([line for line in docstring_args.split("\n") if line.strip()]) # Remove blank lines
  190. matches = args_split_re.findall(docstring_args)
  191. args_dict = {match[0]: re.sub(r"\s*\n+\s*", " ", match[1].strip()) for match in matches}
  192. else:
  193. args_dict = {}
  194. return description, args_dict, returns
  195. def get_json_schema(func: Callable) -> dict:
  196. """
  197. This function generates a JSON schema for a given function, based on its docstring and type hints. This is
  198. mostly used for passing lists of tools to a chat template. The JSON schema contains the name and description of
  199. the function, as well as the names, types and descriptions for each of its arguments. `get_json_schema()` requires
  200. that the function has a docstring, and that each argument has a description in the docstring, in the standard
  201. Google docstring format shown below. It also requires that all the function arguments have a valid Python type hint.
  202. Although it is not required, a `Returns` block can also be added, which will be included in the schema. This is
  203. optional because most chat templates ignore the return value of the function.
  204. Args:
  205. func: The function to generate a JSON schema for.
  206. Returns:
  207. A dictionary containing the JSON schema for the function.
  208. Examples:
  209. ```python
  210. >>> def multiply(x: float, y: float):
  211. >>> '''
  212. >>> A function that multiplies two numbers
  213. >>>
  214. >>> Args:
  215. >>> x: The first number to multiply
  216. >>> y: The second number to multiply
  217. >>> '''
  218. >>> return x * y
  219. >>>
  220. >>> print(get_json_schema(multiply))
  221. {
  222. "name": "multiply",
  223. "description": "A function that multiplies two numbers",
  224. "parameters": {
  225. "type": "object",
  226. "properties": {
  227. "x": {"type": "number", "description": "The first number to multiply"},
  228. "y": {"type": "number", "description": "The second number to multiply"}
  229. },
  230. "required": ["x", "y"]
  231. }
  232. }
  233. ```
  234. The general use for these schemas is that they are used to generate tool descriptions for chat templates that
  235. support them, like so:
  236. ```python
  237. >>> from transformers import AutoTokenizer
  238. >>> from transformers.utils import get_json_schema
  239. >>>
  240. >>> def multiply(x: float, y: float):
  241. >>> '''
  242. >>> A function that multiplies two numbers
  243. >>>
  244. >>> Args:
  245. >>> x: The first number to multiply
  246. >>> y: The second number to multiply
  247. >>> return x * y
  248. >>> '''
  249. >>>
  250. >>> multiply_schema = get_json_schema(multiply)
  251. >>> tokenizer = AutoTokenizer.from_pretrained("CohereForAI/c4ai-command-r-v01")
  252. >>> messages = [{"role": "user", "content": "What is 179 x 4571?"}]
  253. >>> formatted_chat = tokenizer.apply_chat_template(
  254. >>> messages,
  255. >>> tools=[multiply_schema],
  256. >>> chat_template="tool_use",
  257. >>> return_dict=True,
  258. >>> return_tensors="pt",
  259. >>> add_generation_prompt=True
  260. >>> )
  261. >>> # The formatted chat can now be passed to model.generate()
  262. ```
  263. Each argument description can also have an optional `(choices: ...)` block at the end, such as
  264. `(choices: ["tea", "coffee"])`, which will be parsed into an `enum` field in the schema. Note that this will
  265. only be parsed correctly if it is at the end of the line:
  266. ```python
  267. >>> def drink_beverage(beverage: str):
  268. >>> '''
  269. >>> A function that drinks a beverage
  270. >>>
  271. >>> Args:
  272. >>> beverage: The beverage to drink (choices: ["tea", "coffee"])
  273. >>> '''
  274. >>> pass
  275. >>>
  276. >>> print(get_json_schema(drink_beverage))
  277. ```
  278. {
  279. 'name': 'drink_beverage',
  280. 'description': 'A function that drinks a beverage',
  281. 'parameters': {
  282. 'type': 'object',
  283. 'properties': {
  284. 'beverage': {
  285. 'type': 'string',
  286. 'enum': ['tea', 'coffee'],
  287. 'description': 'The beverage to drink'
  288. }
  289. },
  290. 'required': ['beverage']
  291. }
  292. }
  293. """
  294. doc = inspect.getdoc(func)
  295. if not doc:
  296. raise DocstringParsingException(
  297. f"Cannot generate JSON schema for {func.__name__} because it has no docstring!"
  298. )
  299. doc = doc.strip()
  300. main_doc, param_descriptions, return_doc = parse_google_format_docstring(doc)
  301. json_schema = _convert_type_hints_to_json_schema(func)
  302. if (return_dict := json_schema["properties"].pop("return", None)) is not None:
  303. if return_doc is not None: # We allow a missing return docstring since most templates ignore it
  304. return_dict["description"] = return_doc
  305. for arg, schema in json_schema["properties"].items():
  306. if arg not in param_descriptions:
  307. raise DocstringParsingException(
  308. f"Cannot generate JSON schema for {func.__name__} because the docstring has no description for the argument '{arg}'"
  309. )
  310. desc = param_descriptions[arg]
  311. enum_choices = re.search(r"\(choices:\s*(.*?)\)\s*$", desc, flags=re.IGNORECASE)
  312. if enum_choices:
  313. schema["enum"] = [c.strip() for c in json.loads(enum_choices.group(1))]
  314. desc = enum_choices.string[: enum_choices.start()].strip()
  315. schema["description"] = desc
  316. output = {"name": func.__name__, "description": main_doc, "parameters": json_schema}
  317. if return_dict is not None:
  318. output["return"] = return_dict
  319. return {"type": "function", "function": output}
  320. def _render_with_assistant_indices(
  321. compiled_template, messages, tools, documents, add_generation_prompt, **template_kwargs
  322. ):
  323. rendered_blocks = []
  324. generation_indices = []
  325. with compiled_template.environment.activate_tracker(rendered_blocks, generation_indices):
  326. for block in compiled_template.generate(
  327. messages=messages,
  328. tools=tools,
  329. documents=documents,
  330. add_generation_prompt=add_generation_prompt,
  331. **template_kwargs,
  332. ):
  333. rendered_blocks.append(block)
  334. rendered_chat = "".join(rendered_blocks)
  335. return rendered_chat, generation_indices
  336. @lru_cache
  337. def _compile_jinja_template(chat_template):
  338. if not is_jinja_available():
  339. raise ImportError(
  340. "apply_chat_template requires jinja2 to be installed. Please install it using `pip install jinja2`."
  341. )
  342. class AssistantTracker(Extension):
  343. # This extension is used to track the indices of assistant-generated tokens in the rendered chat
  344. tags = {"generation"}
  345. def __init__(self, environment: ImmutableSandboxedEnvironment):
  346. # The class is only initiated by jinja.
  347. super().__init__(environment)
  348. environment.extend(activate_tracker=self.activate_tracker)
  349. self._rendered_blocks = None
  350. self._generation_indices = None
  351. def parse(self, parser: jinja2.parser.Parser) -> jinja2.nodes.CallBlock:
  352. lineno = next(parser.stream).lineno
  353. body = parser.parse_statements(["name:endgeneration"], drop_needle=True)
  354. return jinja2.nodes.CallBlock(self.call_method("_generation_support"), [], [], body).set_lineno(lineno)
  355. @jinja2.pass_eval_context
  356. def _generation_support(self, context: jinja2.nodes.EvalContext, caller: jinja2.runtime.Macro) -> str:
  357. rv = caller()
  358. if self.is_active():
  359. # Only track generation indices if the tracker is active
  360. start_index = len("".join(self._rendered_blocks))
  361. end_index = start_index + len(rv)
  362. self._generation_indices.append((start_index, end_index))
  363. return rv
  364. def is_active(self) -> bool:
  365. return self._rendered_blocks or self._generation_indices
  366. @contextmanager
  367. def activate_tracker(self, rendered_blocks: list[int], generation_indices: list[int]):
  368. try:
  369. if self.is_active():
  370. raise ValueError("AssistantTracker should not be reused before closed")
  371. self._rendered_blocks = rendered_blocks
  372. self._generation_indices = generation_indices
  373. yield
  374. finally:
  375. self._rendered_blocks = None
  376. self._generation_indices = None
  377. if version.parse(jinja2.__version__) < version.parse("3.1.0"):
  378. raise ImportError(
  379. f"apply_chat_template requires jinja2>=3.1.0 to be installed. Your version is {jinja2.__version__}."
  380. )
  381. def raise_exception(message):
  382. raise jinja2.exceptions.TemplateError(message)
  383. def tojson(x, ensure_ascii=False, indent=None, separators=None, sort_keys=False):
  384. # We override the built-in tojson filter because Jinja's default filter escapes HTML characters
  385. # We also expose some options like custom indents and separators
  386. return json.dumps(x, ensure_ascii=ensure_ascii, indent=indent, separators=separators, sort_keys=sort_keys)
  387. def strftime_now(format):
  388. return datetime.now().strftime(format)
  389. jinja_env = ImmutableSandboxedEnvironment(
  390. trim_blocks=True, lstrip_blocks=True, extensions=[AssistantTracker, jinja2.ext.loopcontrols]
  391. )
  392. jinja_env.filters["tojson"] = tojson
  393. jinja_env.globals["raise_exception"] = raise_exception
  394. jinja_env.globals["strftime_now"] = strftime_now
  395. return jinja_env.from_string(chat_template)
  396. def render_jinja_template(
  397. conversations: list[list[dict[str, str]]],
  398. tools: Optional[list[Union[dict, Callable]]] = None,
  399. documents: Optional[list[dict[str, str]]] = None,
  400. chat_template: Optional[str] = None,
  401. return_assistant_tokens_mask: bool = False,
  402. continue_final_message: bool = False,
  403. add_generation_prompt: bool = False,
  404. **kwargs,
  405. ) -> str:
  406. if return_assistant_tokens_mask and not re.search(r"\{\%-?\s*generation\s*-?\%\}", chat_template):
  407. logger.warning_once(
  408. "return_assistant_tokens_mask==True but chat template does not contain `{% generation %}` keyword."
  409. )
  410. # Compilation function uses a cache to avoid recompiling the same template
  411. compiled_template = _compile_jinja_template(chat_template)
  412. # We accept either JSON schemas or functions for tools. If we get functions, we convert them to schemas
  413. if tools is not None:
  414. tool_schemas = []
  415. for tool in tools:
  416. if isinstance(tool, dict):
  417. tool_schemas.append(tool)
  418. elif isfunction(tool):
  419. tool_schemas.append(get_json_schema(tool))
  420. else:
  421. raise ValueError(
  422. "Tools should either be a JSON schema, or a callable function with type hints "
  423. "and a docstring suitable for auto-conversion to a schema."
  424. )
  425. else:
  426. tool_schemas = None
  427. if documents is not None:
  428. for document in documents:
  429. if not isinstance(document, dict):
  430. raise TypeError("Documents should be a list of dicts with 'title' and 'text' keys!")
  431. rendered = []
  432. all_generation_indices = []
  433. continue_final_message_tag = "CONTINUE_FINAL_MESSAGE_TAG "
  434. for chat in conversations:
  435. if hasattr(chat, "messages"):
  436. # Indicates it's a Conversation object
  437. chat = chat.messages
  438. if continue_final_message:
  439. chat = deepcopy(chat)
  440. final_message = chat[-1]["content"]
  441. if isinstance(final_message, (list, tuple)):
  442. for content_block in reversed(final_message):
  443. if "text" in content_block:
  444. # Pick the last text block in the message (the first one we hit while iterating in reverse)
  445. final_message = content_block["text"]
  446. content_block["text"] = content_block["text"] + continue_final_message_tag
  447. break
  448. else:
  449. raise ValueError(
  450. "continue_final_message is set but we could not find any text to continue in the final message!"
  451. )
  452. else:
  453. chat[-1]["content"] = chat[-1]["content"] + continue_final_message_tag
  454. if return_assistant_tokens_mask:
  455. rendered_chat, generation_indices = _render_with_assistant_indices(
  456. compiled_template=compiled_template,
  457. messages=chat,
  458. tools=tool_schemas,
  459. documents=documents,
  460. add_generation_prompt=add_generation_prompt,
  461. **kwargs,
  462. )
  463. all_generation_indices.append(generation_indices)
  464. else:
  465. rendered_chat = compiled_template.render(
  466. messages=chat,
  467. tools=tool_schemas,
  468. documents=documents,
  469. add_generation_prompt=add_generation_prompt,
  470. **kwargs,
  471. )
  472. if continue_final_message:
  473. if (final_message.strip() not in rendered_chat) or (
  474. continue_final_message_tag.strip() not in rendered_chat
  475. ):
  476. raise ValueError(
  477. "continue_final_message is set but the final message does not appear in the chat after "
  478. "applying the chat template! This can happen if the chat template deletes portions of "
  479. "the final message. Please verify the chat template and final message in your chat to "
  480. "ensure they are compatible."
  481. )
  482. tag_loc = rendered_chat.rindex(continue_final_message_tag.strip())
  483. if rendered_chat[tag_loc : tag_loc + len(continue_final_message_tag)] == continue_final_message_tag:
  484. # The template preserves spacing, so things are simple
  485. rendered_chat = rendered_chat[:tag_loc]
  486. else:
  487. # The message has trailing spacing that was trimmed, so we must be more cautious
  488. rendered_chat = rendered_chat[:tag_loc].rstrip()
  489. rendered.append(rendered_chat)
  490. return rendered, all_generation_indices