_traverse.py 6.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198
  1. # Copyright (c) Meta Platforms, Inc. and affiliates
  2. from collections.abc import Collection, Mapping, MutableMapping
  3. from typing import Callable, cast, Optional, TypeVar, Union
  4. import torch
  5. from torch.distributed._shard.sharded_tensor.api import ShardedTensor
  6. from torch.distributed.checkpoint.metadata import STATE_DICT_TYPE
  7. from torch.distributed.tensor import DTensor
# One component of an object path: a mapping key or a list index.
PATH_ITEM = Union[str, int]
# Tuple of path components leading from the root of a state_dict to one value.
OBJ_PATH = tuple[PATH_ITEM, ...]
# Type variable letting ``get_element``'s return type follow ``default_value``.
T = TypeVar("T")
# Any value stored in a state_dict; deliberately unconstrained.
STATE_DICT_ITEM = object
# Static type used when casting containers while walking a path; at runtime the
# container may actually be a list (``set_element`` handles that case explicitly).
CONTAINER_TYPE = MutableMapping[PATH_ITEM, STATE_DICT_ITEM]
# Public API of this module.
__all__ = ["traverse_state_dict", "set_element", "get_element", "print_tensor"]
  14. def _keep_visiting_tensors(value: STATE_DICT_ITEM) -> bool:
  15. return isinstance(value, torch.Tensor)
  16. # TODO: update docstring for traverse.py
  17. def traverse_state_dict(
  18. state_dict: STATE_DICT_TYPE,
  19. visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
  20. keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
  21. ) -> None:
  22. """
  23. Invoke ``visitor`` for each value recursively in ``state_dict``.
  24. Mapping will be traversed and ``visitor`` will be applied to the leaf elements.
  25. ``visitor`` will only be applied to elements in a list or a tuple, if the
  26. container contains tensors or mappings.
  27. """
  28. def _is_terminal(value: STATE_DICT_ITEM) -> bool:
  29. values: Collection[STATE_DICT_ITEM]
  30. if isinstance(value, Mapping):
  31. return False
  32. elif isinstance(value, list):
  33. values = value
  34. else:
  35. return True
  36. for entry in values:
  37. if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
  38. return False
  39. if keep_traversing is not None and keep_traversing(entry):
  40. return False
  41. return True
  42. def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
  43. if isinstance(value, Mapping):
  44. for k, v in value.items():
  45. _traverse_obj(path + (str(k),), v)
  46. elif _is_terminal(value):
  47. visitor(path, value)
  48. elif isinstance(value, (list, tuple)):
  49. for i, v in enumerate(value):
  50. _traverse_obj(path + (i,), v)
  51. for key, value in state_dict.items():
  52. _traverse_obj((str(key),), value)
  53. def traverse_state_dict_v_2_3(
  54. state_dict: STATE_DICT_TYPE,
  55. visitor: Callable[[OBJ_PATH, STATE_DICT_ITEM], None],
  56. keep_traversing: Callable[[STATE_DICT_ITEM], bool] = _keep_visiting_tensors,
  57. ) -> None:
  58. """
  59. Traversal is short-circuited when if finds a collection for which ``keep_visiting_tensors`` evaluates
  60. to false for all elements.
  61. By default, all collections with at least one ``torch.Tensor`` element are traversed.
  62. Visitor takes a path argument that is a tuple of the keys used to reach it.
  63. """
  64. # a value is terminal if it has no other containers values inside it
  65. def _is_terminal(value: STATE_DICT_ITEM) -> bool:
  66. values: Collection[STATE_DICT_ITEM]
  67. if isinstance(value, Mapping):
  68. values = value.values()
  69. elif isinstance(value, list):
  70. values = value
  71. else:
  72. return True
  73. for entry in values:
  74. if isinstance(entry, (Mapping, list)) and not _is_terminal(entry):
  75. return False
  76. if keep_traversing is not None and keep_traversing(entry):
  77. return False
  78. return True
  79. def _traverse_obj(path: OBJ_PATH, value: STATE_DICT_ITEM) -> None:
  80. if _is_terminal(value):
  81. visitor(path, value)
  82. elif isinstance(value, Mapping):
  83. for k, v in value.items():
  84. _traverse_obj(path + (str(k),), v)
  85. elif isinstance(value, list):
  86. for i, v in enumerate(value):
  87. _traverse_obj(path + (i,), v)
  88. for key, value in state_dict.items():
  89. _traverse_obj((str(key),), value)
  90. def set_element(
  91. root_dict: STATE_DICT_TYPE, path: OBJ_PATH, value: STATE_DICT_ITEM
  92. ) -> None:
  93. """Set ``value`` in ``root_dict`` along the ``path`` object path."""
  94. cur_container = cast(CONTAINER_TYPE, root_dict)
  95. def extend_list(lst: list[STATE_DICT_ITEM], idx: int) -> None:
  96. while len(lst) <= idx:
  97. lst.append(None)
  98. for i in range(1, len(path)):
  99. prev_key = path[i - 1]
  100. key = path[i]
  101. def_val = cast(STATE_DICT_ITEM, {} if type(key) == str else [])
  102. if isinstance(cur_container, Mapping):
  103. cur_container = cast(
  104. CONTAINER_TYPE, cur_container.setdefault(prev_key, def_val)
  105. )
  106. else:
  107. extend_list(cur_container, prev_key)
  108. if cur_container[prev_key] is None:
  109. cur_container[prev_key] = def_val
  110. cur_container = cur_container[prev_key]
  111. key = path[-1]
  112. if type(key) == int:
  113. extend_list(cast(list[STATE_DICT_ITEM], cur_container), key)
  114. cur_container[key] = value
  115. def get_element(
  116. root_dict: STATE_DICT_TYPE,
  117. path: OBJ_PATH,
  118. default_value: Optional[T] = None,
  119. ) -> Optional[T]:
  120. """Retrieve the value at ``path``from ``root_dict``, returning ``default_value`` if not found."""
  121. cur_value = cast(CONTAINER_TYPE, root_dict)
  122. for part in path:
  123. if type(part) is int:
  124. if not isinstance(cur_value, list) or len(cur_value) < part:
  125. return default_value
  126. elif not isinstance(cur_value, Mapping) or part not in cur_value:
  127. return default_value
  128. cur_value = cast(CONTAINER_TYPE, cur_value[part])
  129. return cast(Optional[T], cur_value)
  130. def _print_nested(
  131. value: STATE_DICT_ITEM,
  132. prefix: str = "",
  133. print_fun: Callable[[str], None] = print,
  134. ) -> None:
  135. if type(value) is ShardedTensor:
  136. print_fun(f"{prefix} ShardedTensor size: {value.size()}")
  137. for shard in value.local_shards():
  138. _print_nested(
  139. shard.tensor,
  140. f"{shard.metadata.shard_offsets} ",
  141. print_fun=print_fun,
  142. )
  143. elif type(value) is (DTensor):
  144. print_fun(f"{prefix} DistributedTensor size: {value.size()}")
  145. # TODO: add local offset for _local_tensor in print_nested.
  146. _print_nested(
  147. value._local_tensor,
  148. print_fun=print_fun,
  149. )
  150. elif isinstance(value, torch.Tensor):
  151. print_fun(f"{prefix} Tensor size: {value.size()}")
  152. else:
  153. print_fun(f"{prefix} Type: {type(value)}")
  154. def print_tensor(
  155. path: OBJ_PATH,
  156. value: STATE_DICT_ITEM,
  157. print_fun: Callable[[str], None] = print,
  158. ) -> None:
  159. """
  160. Use this callback with traverse_state_dict to print its content.
  161. By default the content is printed using the builtin ``print`` but this can
  162. be change by passing a different ``print_fun` callable.
  163. """
  164. _print_nested(value, prefix=str(path), print_fun=print_fun)