# _dynamism.py (4.4 KB)
  1. import re
  2. from typing import Any, Callable, Union
  3. import torch
  4. from torch.utils._pytree import tree_flatten_with_path, tree_map
  5. KeyPath = tuple[Any, ...]
  6. NonTensorShapeFn = Callable[[Union[int, float]], tuple[Any, ...]]
  7. __all__ = [
  8. "normalize_source_name",
  9. "module_to_nested_dict",
  10. "track_dynamism_across_examples",
  11. "clone_and_convert_to_meta",
  12. ]
  13. def normalize_source_name(name: str) -> str:
  14. # Match attribute access like .x and replace with ['x']
  15. return re.sub(r"\.([a-zA-Z_][a-zA-Z0-9_]*)", r"['\1']", name)
  16. def module_to_nested_dict(module: torch.nn.Module) -> dict[str, Any]:
  17. """Recursively converts an nn.Module into a nested dictionary with explicit 'parameters' and 'modules' keys."""
  18. self_dict: dict[str, Any] = {}
  19. self_dict["_parameters"] = {}
  20. self_dict["_modules"] = {}
  21. for attr_name in dir(module):
  22. try:
  23. if not attr_name.startswith("_") and not callable(
  24. getattr(module, attr_name)
  25. ):
  26. attr_value = getattr(module, attr_name)
  27. if (
  28. not isinstance(attr_value, torch.nn.Module)
  29. and isinstance(attr_value, (int, float, torch.Tensor))
  30. and type(attr_value) is not bool
  31. ):
  32. self_dict[attr_name] = attr_value
  33. except NotImplementedError:
  34. # Skip attributes that raise NotImplementedError since they won't
  35. # contain any dynamism anyways.
  36. continue
  37. for name, param in module.named_parameters(recurse=False):
  38. self_dict["_parameters"][name] = param
  39. for name, buffer in module.named_buffers(recurse=False):
  40. self_dict["_parameters"][name] = buffer
  41. for name, submodule in module.named_children():
  42. self_dict["_modules"][name] = module_to_nested_dict(submodule)
  43. return self_dict
  44. def track_dynamism_across_examples(
  45. example_inputs: list[Any],
  46. ) -> dict[Any, Any]:
  47. """
  48. This function analyzes a list of example inputs to determine the dynamism of their shapes.
  49. It tracks whether the dimensions of tensors or non-tensor values change across
  50. different examples. The function returns a dictionary where each key represents
  51. a path to a value in the input examples, and the corresponding value is a tuple
  52. indicating which dimensions are dynamic (i.e., change across examples). This
  53. helps in understanding how the structure of data varies across different instances.
  54. """
  55. tracking: dict[KeyPath, tuple[list[set[Any]], bool]] = {}
  56. for ex in example_inputs:
  57. if "self" in ex and isinstance(ex["self"], torch.nn.Module):
  58. ex["self"] = module_to_nested_dict(ex["self"])
  59. leaves_with_paths, _ = tree_flatten_with_path(ex)
  60. for key_path, value in leaves_with_paths:
  61. if not isinstance(value, (int, float, torch.Tensor)):
  62. continue
  63. if isinstance(value, torch.Tensor):
  64. shape: tuple[int | float, ...] = tuple(value.shape)
  65. is_tensor = True
  66. else:
  67. shape = (value,)
  68. is_tensor = False
  69. if key_path not in tracking:
  70. tracking[key_path] = ([set() for _ in range(len(shape))], is_tensor)
  71. else:
  72. dim_sets, flag = tracking[key_path]
  73. if flag != is_tensor:
  74. pass
  75. while len(dim_sets) < len(shape):
  76. dim_sets.append(set())
  77. for i, dim in enumerate(shape):
  78. tracking[key_path][0][i].add(dim)
  79. output: dict[Any, Any] = {}
  80. for key_path, (dim_sets, _is_tensor) in tracking.items():
  81. final_dyn = tuple(len(s) > 1 for s in dim_sets)
  82. key_str = "L" + "".join(f"{str(k)}" for k in key_path)
  83. key = key_path[0].key # type: ignore[attr-defined]
  84. if key not in output:
  85. output[key] = {}
  86. output[key][key_str] = final_dyn
  87. return output
  88. def clone_and_convert_to_meta(example_input: Any) -> Any:
  89. """
  90. This function takes a list of example inputs and for each tensor, clones it and converts it to device=meta.
  91. For non-tensor values, it keeps the reference. It uses pytree to handle nested structures recursively.
  92. """
  93. def transform_fn(value: Any) -> Any:
  94. if isinstance(value, torch.Tensor):
  95. return value.clone().to(device="meta")
  96. return value
  97. return tree_map(transform_fn, example_input)