utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477
  1. # mypy: allow-untyped-defs
  2. import cProfile
  3. import inspect
  4. import io
  5. import itertools
  6. import os
  7. import warnings
  8. from collections.abc import Sequence
  9. from contextlib import contextmanager
  10. from functools import wraps
  11. from pstats import Stats
  12. from typing import Any, Callable, cast, Optional, TypeVar, Union
  13. import torch
  14. import torch.distributed as dist
  15. from torch.distributed._shard.sharded_tensor import ShardedTensor
  16. from torch.distributed._shard.sharded_tensor.shard import Shard
  17. from .api import (
  18. _is_wrapped_exception,
  19. _wrap_exception,
  20. CheckpointException,
  21. WRAPPED_EXCEPTION,
  22. )
  23. from .metadata import MetadataIndex, STATE_DICT_TYPE
  24. __all__ = ["find_tensor_shard", "find_state_dict_object"]
  25. T = TypeVar("T")
  26. R = TypeVar("R")
  27. def _get_failure_dict(
  28. results: list[Union[T, WRAPPED_EXCEPTION]],
  29. ) -> dict[int, WRAPPED_EXCEPTION]:
  30. return cast(
  31. dict[int, WRAPPED_EXCEPTION],
  32. {i: err for i, err in enumerate(results) if _is_wrapped_exception(err)},
  33. )
  34. def _all_gather_keys(
  35. local_dict: dict[str, Any], group: Optional[dist.ProcessGroup] = None
  36. ) -> set[str]:
  37. """Gathers all keys, and returns them sorted."""
  38. keys = list(local_dict.keys())
  39. gathered_keys: list[list[str]] = [None] * dist.get_world_size(group) # type: ignore[list-item]
  40. dist.all_gather_object(gathered_keys, keys, group=group)
  41. return set(itertools.chain.from_iterable(gathered_keys))
  42. def _assert_same_keys(
  43. state_dict: dict[str, Any], process_group: Optional[dist.ProcessGroup] = None
  44. ) -> None:
  45. """
  46. Asserts that all ranks have the same keys in their state dict.
  47. This is a collective call which requires all ranks in ``process_group`` to
  48. join. It will also induce cross-rank communication and block CPU.
  49. """
  50. if dist.get_world_size(process_group) == 1:
  51. return
  52. all_keys = _all_gather_keys(state_dict, process_group)
  53. my_keys = set(state_dict.keys())
  54. diff = all_keys - my_keys
  55. if len(diff) > 0:
  56. raise AssertionError(
  57. f"Key(s) present in other ranks but not this one, difference: {diff}"
  58. )
  59. class _DistWrapper:
  60. """
  61. This is a wrapper around PG that provides a series of features around object collectives.
  62. It works without distributed initialized, where most collectives turns into nops.
  63. All variants that take functions are exception robust, meaning that if one or more
  64. ranks raise errors, all ranks will observe those.
  65. """
  66. def __init__(
  67. self,
  68. group: Optional[dist.ProcessGroup],
  69. use_dist: bool,
  70. coordinator_rank: int,
  71. ):
  72. self.group = group
  73. self.use_dist = use_dist
  74. self.coordinator_rank = coordinator_rank
  75. if self.use_dist:
  76. self.global_coordinator_rank = (
  77. dist.get_global_rank(group, coordinator_rank)
  78. if group is not None
  79. else coordinator_rank
  80. )
  81. self.rank = dist.get_rank(group)
  82. self.is_coordinator = self.rank == coordinator_rank
  83. else:
  84. self.global_coordinator_rank = 0
  85. self.rank = 0
  86. self.is_coordinator = True
  87. def get_rank(self) -> int:
  88. return self.rank
  89. def get_world_size(self) -> int:
  90. if self.use_dist:
  91. return dist.get_world_size(self.group)
  92. return 1
  93. def broadcast_object(self, object: Optional[T]) -> T:
  94. """Implement functionality similar to c10d::broadcast_object_list but without distributed enabled."""
  95. object_list = [object]
  96. if self.use_dist:
  97. dist.broadcast_object_list(
  98. object_list=object_list,
  99. group=self.group,
  100. src=self.global_coordinator_rank,
  101. )
  102. return cast(T, object_list[0])
  103. def gather_object(self, object: T) -> Optional[list[T]]:
  104. """Implement functionality similar to c10d::gather_object but without distributed enabled."""
  105. if self.use_dist:
  106. gather_objs = (
  107. cast(list[T], [None] * dist.get_world_size(self.group))
  108. if self.is_coordinator
  109. else None
  110. )
  111. dist.gather_object(
  112. obj=object,
  113. object_gather_list=gather_objs if self.is_coordinator else None,
  114. dst=self.global_coordinator_rank,
  115. group=self.group,
  116. )
  117. result = gather_objs
  118. else:
  119. result = [object]
  120. return result
  121. def all_gather_object(self, object: T) -> list[T]:
  122. """Implement functionality similar to c10d::all_gather_object but without distributed enabled."""
  123. if self.use_dist:
  124. gather_objs = cast(list[T], [None] * dist.get_world_size(self.group))
  125. dist.all_gather_object(
  126. object_list=gather_objs, obj=object, group=self.group
  127. )
  128. else:
  129. gather_objs = [object]
  130. return gather_objs
  131. def scatter_object(self, object_list: Optional[list[T]]) -> T:
  132. """Implement functionality similar to c10d::scatter_object but without distributed enabled."""
  133. if self.use_dist:
  134. gather_result = cast(list[T], [None])
  135. dist.scatter_object_list(
  136. scatter_object_output_list=gather_result,
  137. scatter_object_input_list=object_list if self.is_coordinator else None,
  138. src=self.global_coordinator_rank,
  139. group=self.group,
  140. )
  141. local_reply = gather_result[0]
  142. else:
  143. assert object_list is not None
  144. local_reply = object_list[0]
  145. return local_reply
  146. def reduce_scatter(
  147. self,
  148. step: str,
  149. map_fun: Callable[[], T],
  150. reduce_fun: Callable[[list[T]], list[R]],
  151. ) -> R:
  152. """
  153. Compute a value on each rank, then do centralized reduce on a single rank, followed by a scatter.
  154. This method operates in the following way:
  155. Run ``map_fun`` on all ranks
  156. Gather results on rank 0
  157. Call ``reduce_fun`` on all those values
  158. Scatter to each rank part of the result.
  159. """
  160. local_data: Union[WRAPPED_EXCEPTION, T]
  161. try:
  162. local_data = map_fun()
  163. except BaseException as e: # noqa: B036
  164. local_data = _wrap_exception(e)
  165. all_data = self.gather_object(local_data)
  166. all_results: Optional[list[Union[R, CheckpointException]]] = None
  167. if self.is_coordinator:
  168. assert all_data is not None
  169. node_failures = _get_failure_dict(all_data)
  170. if len(node_failures) == 0:
  171. try:
  172. # N.B. why can't mypy cast List[R] to List[Union[R, WRAPPED_EXCEPTION]]?
  173. all_results = cast(
  174. list[Union[R, CheckpointException]],
  175. reduce_fun(cast(list[T], all_data)),
  176. )
  177. except BaseException as e: # noqa: B036
  178. node_failures[self.rank] = _wrap_exception(e)
  179. if len(node_failures) > 0:
  180. all_results = [
  181. CheckpointException(step, node_failures)
  182. ] * self.get_world_size()
  183. result = self.scatter_object(all_results)
  184. if isinstance(result, CheckpointException):
  185. raise result
  186. return result
  187. def all_reduce(
  188. self,
  189. step: str,
  190. map_fun: Callable[[], T],
  191. reduce_fun: Callable[[list[T]], R],
  192. ) -> R:
  193. """
  194. Compute a value on each rank, then do centralized reduce on a single rank, followed by a broadcast.
  195. This method operates in the following way:
  196. Run ``map_fun`` on all ranks
  197. Gather results on rank 0
  198. Call ``reduce_fun`` on all those values
  199. Broadcast the reduced value to all ranks.
  200. """
  201. local_data: Union[T, WRAPPED_EXCEPTION]
  202. try:
  203. local_data = map_fun()
  204. except BaseException as e: # noqa: B036
  205. local_data = _wrap_exception(e)
  206. all_data = self.gather_object(local_data)
  207. result: Optional[Union[R, CheckpointException]] = None
  208. if self.is_coordinator:
  209. assert all_data is not None
  210. node_failures = _get_failure_dict(all_data)
  211. if len(node_failures) == 0:
  212. try:
  213. result = reduce_fun(cast(list[T], all_data))
  214. except BaseException as e: # noqa: B036
  215. node_failures[self.rank] = _wrap_exception(e)
  216. if len(node_failures) > 0:
  217. result = CheckpointException(step, node_failures)
  218. final_result = self.broadcast_object(result)
  219. if isinstance(final_result, CheckpointException):
  220. raise final_result
  221. return cast(R, final_result)
  222. def all_gather(
  223. self,
  224. step: str,
  225. map_fun: Callable[[], T],
  226. ) -> list[T]:
  227. """
  228. Compute a value on each rank, then all_gather them.
  229. This method operates in the following way:
  230. Run ``map_cp`` on all ranks
  231. all_gather the values to all ranks
  232. """
  233. result: Union[T, WRAPPED_EXCEPTION]
  234. try:
  235. result = map_fun()
  236. except BaseException as e: # noqa: B036
  237. result = _wrap_exception(e)
  238. all_results = self.all_gather_object(result)
  239. node_failures = _get_failure_dict(all_results)
  240. if len(node_failures) > 0:
  241. raise CheckpointException(step, node_failures)
  242. return cast(list[T], all_results)
  243. def broadcast(
  244. self,
  245. step: str,
  246. map_fun: Callable[[], T],
  247. ) -> T:
  248. """
  249. Compute a value on rank 0 and broadcast it.
  250. This method operates in the following way:
  251. Run ``map_cp`` on rank 0
  252. broadcast the value
  253. """
  254. result: Optional[Union[T, CheckpointException]] = None
  255. if self.is_coordinator:
  256. try:
  257. result = map_fun()
  258. except BaseException as e: # noqa: B036
  259. result = CheckpointException(step, {self.rank: _wrap_exception(e)})
  260. final_result = self.broadcast_object(result)
  261. if isinstance(final_result, CheckpointException):
  262. raise final_result
  263. return cast(T, final_result)
  264. def barrier(self) -> None:
  265. """
  266. Add a synchronization point across all processes when using distributed.
  267. If torch.distributed is initialized, this function will invoke a barrier across the global process group.
  268. If torch.distributed is not initialized, this function is a no-op.
  269. """
  270. if not self.use_dist:
  271. return
  272. dist.barrier(group=self.group)
  273. def _find_shard(tensor: ShardedTensor, index: MetadataIndex) -> Shard:
  274. if index.offset is None:
  275. raise ValueError(
  276. f"Cannot lookup {index.fqn} since its a ShardedTensor and no offset was provided"
  277. )
  278. shards = tensor.local_shards()
  279. # index fast path
  280. if index.index is not None:
  281. if (
  282. len(shards) > index.index
  283. and torch.Size(shards[index.index].metadata.shard_offsets) == index.offset
  284. ):
  285. return shards[index.index]
  286. for shard in shards:
  287. if torch.Size(shard.metadata.shard_offsets) == index.offset:
  288. return shard
  289. raise ValueError(f"Could not find shard at '{index.offset}' for FQN: '{index.fqn}'")
  290. def find_tensor_shard(tensor: torch.Tensor, index: MetadataIndex) -> torch.Tensor:
  291. if hasattr(tensor, "__get_tensor_shard__"):
  292. # DTensor implements _Checkpointable
  293. return tensor.__get_tensor_shard__(index) # type: ignore[attr-defined]
  294. if isinstance(tensor, ShardedTensor):
  295. return _find_shard(tensor, index).tensor
  296. if index.offset is not None:
  297. # special case looking up a tensor by origin
  298. if index.offset == torch.Size([0] * len(tensor.size())):
  299. return tensor
  300. raise ValueError(
  301. f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
  302. )
  303. return tensor
  304. def find_state_dict_object(state_dict: STATE_DICT_TYPE, index: MetadataIndex) -> Any:
  305. if index.fqn not in state_dict:
  306. raise ValueError(f"Could not find FQN: '{index.fqn}'")
  307. obj = state_dict[index.fqn]
  308. if isinstance(obj, torch.Tensor):
  309. return find_tensor_shard(obj, index)
  310. elif index.offset is not None:
  311. raise ValueError(
  312. f"FQN: '{index.fqn}' is not a ShardedTensor, can't find by offset: '{index.offset}'"
  313. )
  314. return obj
  315. def _element_wise_add(a: Sequence[int], b: Sequence[int]) -> list[int]:
  316. return [i_a + i_b for i_a, i_b in zip(a, b)]
  317. def _element_wise_sub(a: Sequence[int], b: Sequence[int]) -> list[int]:
  318. return [i_a - i_b for i_a, i_b in zip(a, b)]
  319. class _ReaderView(io.IOBase):
  320. def __init__(self, base_stream: io.IOBase, offset: int, len: int):
  321. super().__init__()
  322. self.offset = offset
  323. self.len = len
  324. self.base_stream = base_stream
  325. self.seek(0)
  326. def seek(self, offset: int, whence: int = os.SEEK_SET, /) -> int:
  327. if whence == os.SEEK_SET:
  328. offset = self.offset + offset
  329. elif whence == os.SEEK_END:
  330. whence = os.SEEK_SET
  331. offset = (self.offset + self.len) - offset
  332. return self.base_stream.seek(offset, whence)
  333. def tell(self) -> int:
  334. return self.base_stream.tell() - self.offset
  335. def readable(self) -> bool:
  336. return self.base_stream.readable()
  337. def seekable(self) -> bool:
  338. return self.base_stream.seekable()
  339. def readinto(self, b):
  340. max_size = self.len - self.tell()
  341. if max_size == 0:
  342. return 0
  343. if len(b) > max_size:
  344. b = memoryview(b)[:max_size]
  345. return self.base_stream.readinto(b) # type: ignore[attr-defined]
  346. def read(self, size=-1):
  347. max_size = self.len - self.tell()
  348. if size == -1 or size > max_size:
  349. size = max_size
  350. return self.base_stream.read(size)
  351. def _create_file_view(file: io.IOBase, offset: int, length: int) -> io.IOBase:
  352. # FIXME (kumpera) torch.load fails if we wrap with io.BufferedReader
  353. return _ReaderView(file, offset, length)
  354. def _normalize_device_info(device_type: str, device_id: int) -> str:
  355. """Device info normalization."""
  356. if device_type == "cpu":
  357. return "cpu"
  358. return f"{device_type}:{device_id}"
  359. # TODO: integrate with distributed logging flag
  360. ENABLE_PROFILE = False
  361. @contextmanager
  362. def _profile():
  363. # Only log the profiling when it is enable and is on rank0 or dist is not
  364. # available.
  365. if ENABLE_PROFILE and (not dist.is_available() or dist.get_rank() == 0):
  366. profiler = cProfile.Profile()
  367. profiler.enable()
  368. try:
  369. yield
  370. finally:
  371. profiler.disable()
  372. stats = Stats(profiler)
  373. stats.sort_stats("time").print_stats(10)
  374. else:
  375. yield
  376. def _api_bc_check(func):
  377. @wraps(func)
  378. def inner_func(*args, **kwargs) -> Any:
  379. if len(args) == 2:
  380. warnings.warn(
  381. f"The argument order of {func.__name__} has been changed. "
  382. "Please check the document to avoid future breakages."
  383. )
  384. sig = inspect.signature(func)
  385. kwonlyargs = [
  386. p.name for p in sig.parameters.values() if p.kind == p.KEYWORD_ONLY
  387. ]
  388. if "storage_writer" in kwonlyargs:
  389. assert "storage_writer" not in kwargs, (args, kwargs)
  390. kwargs["storage_writer"] = args[1]
  391. elif "storage_reader" in kwonlyargs:
  392. assert "storage_reader" not in kwargs, (args, kwargs)
  393. kwargs["storage_reader"] = args[1]
  394. else:
  395. raise RuntimeError(f"Unexpected kwonlyargs = {kwonlyargs}")
  396. return func(args[0], **kwargs)
  397. else:
  398. return func(*args, **kwargs)
  399. return inner_func