model_debugging_utils.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456
  1. # Copyright 2025 The HuggingFace Inc. team.
  2. # All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import functools
  16. import json
  17. import os
  18. import re
  19. from contextlib import contextmanager, redirect_stdout
  20. from io import StringIO
  21. from typing import Optional
  22. from .utils import logging
  23. from .utils.import_utils import is_torch_available, requires
  24. if is_torch_available():
  25. import torch
  26. from safetensors.torch import save_file
  27. _torch_distributed_available = False
  28. # Note to code inspectors: this toolbox is intended for people who add models to `transformers`.
  29. if torch.distributed.is_available():
  30. import torch.distributed.tensor
  31. _torch_distributed_available = True
  32. else:
  33. _torch_distributed_available = False
  34. logger = logging.get_logger(__name__)
  35. def _is_rank_zero():
  36. """Return True if rank=0 or we aren't running distributed."""
  37. if not (_torch_distributed_available and torch.distributed.is_initialized()):
  38. return True
  39. return torch.distributed.get_rank() == 0
  40. MEMORY_ADDRESS_REGEX = re.compile(r"object at 0x[0-9A-Fa-f]+")
  41. def _sanitize_repr_for_diff(x_str: str) -> str:
  42. """
  43. Replace memory addresses in an object's repr with a stable placeholder
  44. so that beautiful JSON diffs won't be ruined by ephemeral addresses.
  45. """
  46. return MEMORY_ADDRESS_REGEX.sub("object at 0xXXXXXXXX", x_str)
  47. def _dtensor_repr(x):
  48. """Return a stable string representation for a DTensor-like object."""
  49. if _is_rank_zero():
  50. return f"DTensor (rank0) -> {repr(x._local_tensor)}"
  51. return "DTensor(non-rank0)"
  52. def _serialize_tensor_like_io(
  53. value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None
  54. ):
  55. """
  56. Converts Tensors and DTensors to a JSON-serializable dictionary representation.
  57. Args:
  58. value: Any Python object, often including torch Tensors, lists, dicts, etc.
  59. debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
  60. use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensor as the
  61. `value` property in the asscoiated FULL_TENSORS.json file, or to store the full tensors in separate
  62. SafeTensors file and store the relative path to that file in the `value` property in the dictionary.
  63. path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
  64. tensor value if `use_repr=False`.
  65. Returns:
  66. A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
  67. """
  68. torch.set_printoptions(sci_mode=True)
  69. if use_repr:
  70. value_out = _repr_to_list(value)
  71. elif path_to_value:
  72. if not path_to_value.endswith(".safetensors"):
  73. path_to_value += ".safetensors"
  74. filepath = os.path.join(debug_path, path_to_value) if debug_path else path_to_value
  75. save_file({"data": value.contiguous().detach().cpu()}, filepath)
  76. value_out = f"./{path_to_value}"
  77. else:
  78. raise ValueError(f"{use_repr=} and {path_to_value=} cannot both be falsy.")
  79. out = {
  80. "shape": repr(value.shape),
  81. "dtype": repr(value.dtype),
  82. "value": value_out,
  83. }
  84. if value.dtype in {torch.float16, torch.float32, torch.bfloat16}:
  85. out.update(
  86. {
  87. "mean": _sanitize_repr_for_diff(repr(value.mean())),
  88. "std": _sanitize_repr_for_diff(repr(value.std())),
  89. "min": _sanitize_repr_for_diff(repr(value.min())),
  90. "max": _sanitize_repr_for_diff(repr(value.max())),
  91. }
  92. )
  93. return out
  94. def _serialize_io(value, debug_path: Optional[str] = None, use_repr: bool = True, path_to_value: Optional[str] = None):
  95. """
  96. Recursively build a JSON-serializable Python structure from `value`.
  97. Tensors and DTensors become either sanitized repr strings, or are saved to disk as SafeTensors files and their
  98. relative paths are recorded in the returned Python structure.
  99. Lists/tuples/dicts are recursed into.
  100. All memory addresses are replaced with a stable placeholder.
  101. Args:
  102. value: Any Python object, often including torch Tensors, lists, dicts, etc.
  103. debug_path (`str`, *optional*, defaults to `None`): Directory to dump debug JSON and SafeTensors files.
  104. use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the
  105. `value` property in the asscoiated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors
  106. files and store the relative path to that file in the `value` property.
  107. path_to_value (`str`, *optional*, defaults to `None`): The file name for the SafeTensors file holding the full
  108. tensor value if `use_repr=False`.
  109. Returns:
  110. A nested Python structure (list, dict, or sanitized string) that is safe to json.dump.
  111. """
  112. if isinstance(value, (list, tuple)):
  113. return [
  114. _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{i}")
  115. for i, v in enumerate(value)
  116. ]
  117. if isinstance(value, dict):
  118. return {
  119. k: _serialize_io(v, debug_path=debug_path, use_repr=use_repr, path_to_value=f"{path_to_value}_{k}")
  120. for k, v in value.items()
  121. }
  122. if hasattr(value, "_local_tensor"):
  123. return _serialize_tensor_like_io(
  124. value._local_tensor, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value
  125. )
  126. if isinstance(value, torch.Tensor):
  127. return _serialize_tensor_like_io(value, debug_path=debug_path, use_repr=use_repr, path_to_value=path_to_value)
  128. return _sanitize_repr_for_diff(repr(value))
  129. def _repr_to_list(value: torch.Tensor):
  130. """
  131. Converts a tensor into a sanitized multi-line string representation.
  132. Args:
  133. value (`torch.Tensor`): The tensor to represent.
  134. Returns:
  135. `list[str]`: List of string lines representing the tensor.
  136. """
  137. torch.set_printoptions(sci_mode=True, linewidth=120)
  138. with StringIO() as buf, redirect_stdout(buf):
  139. print(value) # to redirected stdout to avoid line splits
  140. raw = buf.getvalue()
  141. return _sanitize_repr_for_diff(raw).splitlines()
  142. def prune_outputs_if_children(node):
  143. # if there are children, remove this node's "outputs"
  144. # so we only see outputs at the leaf level
  145. if node.get("children"):
  146. node.pop("outputs", None)
  147. for child in node["children"]:
  148. prune_outputs_if_children(child)
  149. LAYER_SUFFIX_RE = re.compile(r"(.*)\.(\d+)$") # should be generic enough, ends with a number
  150. def is_layer_block(node):
  151. """
  152. Checks whether a node represents a layer block with submodules.
  153. Args:
  154. node (`dict`): A node from the call tree.
  155. Returns:
  156. `bool`: Whether the node is a layer block.
  157. """
  158. match = LAYER_SUFFIX_RE.match(node.get("module_path", ""))
  159. if not match or not node.get("children"):
  160. return False
  161. number = match.group(2)
  162. return any(f".{number}." in child.get("module_path", "") for child in node["children"])
  163. def prune_intermediate_layers(node):
  164. """
  165. Recursively removes intermediate layers from the tree to improve readability.
  166. Keeps at least the first and last layers if many consecutive layers are present.
  167. Args:
  168. node (`dict`): The root or subnode to prune recursively.
  169. """
  170. if not node.get("children"):
  171. return
  172. layer_blocks = [(i, child) for i, child in enumerate(node["children"]) if is_layer_block(child)]
  173. if len(layer_blocks) > 2:
  174. to_remove = [i for i, _ in layer_blocks[1:-1]]
  175. node["children"] = [child for i, child in enumerate(node["children"]) if i not in to_remove]
  176. for child in node["children"]:
  177. prune_intermediate_layers(child)
  178. def log_model_debug_trace(debug_path: Optional[str], model):
  179. if debug_path:
  180. try:
  181. os.makedirs(debug_path, exist_ok=True)
  182. base = os.path.join(debug_path, model._debugger_module_dump_name + "_debug_tree")
  183. except Exception as e:
  184. raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e
  185. else:
  186. base = model._debugger_module_dump_name + "_debug_tree"
  187. logger.info(f"Writing model trace at {base}.json")
  188. full_path = base + "_FULL_TENSORS.json"
  189. summary_path = base + "_SUMMARY.json"
  190. prune_outputs_if_children(model._call_tree)
  191. with open(full_path, "w") as f:
  192. json.dump(model._call_tree, f, indent=2)
  193. # summary-only version for readability - traversing the tree again #TODO optimize?
  194. def strip_values(node):
  195. def clean(val):
  196. if isinstance(val, dict):
  197. val.pop("value", None)
  198. for v in val.values():
  199. clean(v)
  200. elif isinstance(val, list):
  201. for item in val:
  202. clean(item)
  203. clean(node.get("inputs", {}))
  204. clean(node.get("outputs", {}))
  205. for child in node.get("children", []):
  206. strip_values(child)
  207. tree_copy = json.loads(json.dumps(model._call_tree)) # deep copy
  208. strip_values(tree_copy)
  209. with open(summary_path, "w") as f:
  210. json.dump(tree_copy, f, indent=2)
  211. def _attach_debugger_logic(
  212. model,
  213. debug_path: str = ".",
  214. do_prune_layers: bool = True,
  215. use_repr: bool = True,
  216. ):
  217. """
  218. Attaches a debugging wrapper to every module in the model.
  219. This records structured inputs and outputs during the forward pass into a call tree.
  220. Args:
  221. model (`PreTrainedModel`, `nn.Module`): Model to wrap.
  222. debug_path (`str`): Optional directory to dump debug JSON files.
  223. do_prune_layers (`bool`, *optional*, defaults to `True`): Whether to prune intermediate layers.
  224. use_repr (bool, *optional*, defaults to `True`): Whether to save a `repr()`-ized version of the tensors as the
  225. `value` property in the associated FULL_TENSORS.json file, or to store full tensors in separate SafeTensors
  226. files and store the relative path to that file in the `value` property.
  227. """
  228. class_name = model.__class__.__name__
  229. # Prepare data structures on the model object
  230. model._call_tree = {"module_path": class_name, "inputs": None, "outputs": None, "children": []}
  231. model._debugger_model_call_stack = []
  232. model._debugger_module_dump_name = class_name # used for final JSON filename
  233. if debug_path:
  234. try:
  235. os.makedirs(debug_path, exist_ok=True)
  236. except Exception as e:
  237. raise ValueError(f"Unexpected or existing debug_path={debug_path}.") from e
  238. def wrap_forward(module, full_path):
  239. orig_forward = module.forward
  240. @functools.wraps(orig_forward)
  241. def wrapped_forward(*inps, **kws):
  242. if _is_rank_zero():
  243. dict_inputs = {"args": inps, "kwargs": kws}
  244. dict_inputs = {k: dict_inputs[k] for k in dict_inputs if len(dict_inputs[k]) > 0}
  245. node = {
  246. "module_path": full_path,
  247. "inputs": _serialize_io(
  248. dict_inputs,
  249. debug_path=debug_path,
  250. use_repr=use_repr,
  251. path_to_value=f"{full_path}_inputs",
  252. ),
  253. "outputs": None,
  254. "children": [],
  255. }
  256. model._debugger_model_call_stack.append(node)
  257. with torch.no_grad():
  258. out = orig_forward(*inps, **kws)
  259. if _is_rank_zero():
  260. if sum(1 for _ in module.named_children()) > 0:
  261. node["outputs"] = None
  262. else:
  263. node["outputs"] = _serialize_io(
  264. out,
  265. debug_path=debug_path,
  266. use_repr=use_repr,
  267. path_to_value=f"{full_path}_outputs",
  268. )
  269. finished = model._debugger_model_call_stack.pop()
  270. # prune empty vertices here as well (mostly empty children nodes)
  271. if not finished["children"]:
  272. finished.pop("children")
  273. if model._debugger_model_call_stack:
  274. model._debugger_model_call_stack[-1]["children"].append(finished)
  275. return out
  276. module.forward = wrapped_forward
  277. # wrap all submodules
  278. for name, submodule in model.named_modules():
  279. if name == "":
  280. continue
  281. wrap_forward(submodule, f"{class_name}.{name}")
  282. # wrap top-level forward
  283. real_top_forward = model.forward
  284. @functools.wraps(real_top_forward)
  285. def top_wrapped_forward(*inps, **kws):
  286. if _is_rank_zero():
  287. top_node = {
  288. "module_path": f"{class_name} (top-level)",
  289. "inputs": _serialize_io(
  290. {"args": inps, "kwargs": kws},
  291. debug_path=debug_path,
  292. use_repr=use_repr,
  293. path_to_value=f"{class_name}_inputs",
  294. ),
  295. "outputs": None,
  296. "children": [],
  297. }
  298. model._debugger_model_call_stack.append(top_node)
  299. out = real_top_forward(*inps, **kws)
  300. if _is_rank_zero() and model._debugger_model_call_stack:
  301. top_node["outputs"] = _serialize_io(
  302. out,
  303. debug_path=debug_path,
  304. use_repr=use_repr,
  305. path_to_value=f"{class_name}_outputs",
  306. )
  307. finished = model._debugger_model_call_stack.pop()
  308. model._call_tree["inputs"] = finished["inputs"]
  309. model._call_tree["outputs"] = finished["outputs"]
  310. model._call_tree["children"] = finished["children"]
  311. # prune empty stuff for visibility
  312. [model._call_tree.pop(k, None) for k in list(model._call_tree.keys()) if not model._call_tree[k]]
  313. # prune layers that are not 0 or last
  314. if do_prune_layers:
  315. prune_intermediate_layers(model._call_tree)
  316. # Write final JSON trace here
  317. log_model_debug_trace(debug_path=debug_path, model=model)
  318. return out
  319. model.forward = top_wrapped_forward
  320. @requires(backends=("torch",))
  321. @contextmanager
  322. def model_addition_debugger_context(
  323. model,
  324. debug_path: Optional[str] = None,
  325. do_prune_layers: bool = True,
  326. use_repr: bool = True,
  327. ):
  328. """
  329. # Model addition debugger - context manager for model adders
  330. This context manager is a power user tool intended for model adders.
  331. It tracks all forward calls within a model forward and logs a slice of each input and output on a nested JSON file.
  332. If `use_repr=True` (the default), the JSON file will record a `repr()`-ized version of the tensors as a list of
  333. strings. If `use_repr=False`, the full tensors will be stored in separate SafeTensors files and the JSON file will
  334. provide a relative path to that file.
  335. To note, this context manager enforces `torch.no_grad()`.
  336. ## Usage
  337. add the context manager to a model to debug
  338. ```python
  339. import torch
  340. from PIL import Image
  341. from transformers import LlavaProcessor, LlavaForConditionalGeneration, model_addition_debugger_context
  342. torch.random.manual_seed(673)
  343. # load pretrained model and processor
  344. model_id = "llava-hf/llava-1.5-7b-hf"
  345. processor = LlavaProcessor.from_pretrained(model_id)
  346. model = LlavaForConditionalGeneration.from_pretrained(model_id)
  347. # create random image input
  348. random_image = Image.fromarray(torch.randint(0, 256, (224, 224, 3), dtype=torch.uint8).numpy())
  349. # prompt
  350. prompt = "<image>Describe this image."
  351. # process inputs
  352. inputs = processor(text=prompt, images=random_image, return_tensors="pt")
  353. # call forward method (not .generate!)
  354. with model_addition_debugger_context(model, debug_path="Your_debug_path", do_prune_layers=False):
  355. output = model.forward(**inputs)
  356. ```
  357. """
  358. orig_forwards = {m: m.forward for _, m in model.named_modules()}
  359. orig_forwards[model] = model.forward
  360. _attach_debugger_logic(model, debug_path, do_prune_layers, use_repr)
  361. try:
  362. yield model
  363. finally:
  364. for module_instance, forward_method in orig_forwards.items():
  365. module_instance.forward = forward_method