planner_helpers.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490
  1. # mypy: allow-untyped-defs
  2. import io
  3. from typing import Any, Callable, cast
  4. import torch
  5. import torch.distributed as dist
  6. from torch._utils import _get_device_module
  7. from torch.distributed._shard.metadata import ShardMetadata
  8. from torch.distributed._shard.sharded_tensor import ShardedTensor
  9. from torch.distributed.tensor import DTensor
  10. from torch.distributed.tensor._utils import compute_local_shape_and_global_offset
  11. from .metadata import (
  12. BytesStorageMetadata,
  13. ChunkStorageMetadata,
  14. MetadataIndex,
  15. STATE_DICT_TYPE,
  16. STORAGE_TYPES,
  17. TensorProperties,
  18. TensorStorageMetadata,
  19. )
  20. from .planner import (
  21. LoadItemType,
  22. ReadItem,
  23. SavePlan,
  24. TensorWriteData,
  25. WriteItem,
  26. WriteItemType,
  27. )
  28. from .resharding import (
  29. _check_shard_metadata_pair_overlap,
  30. _shards_get_overlap_region_wrt_saved_tensor,
  31. )
  32. __all__: list[str] = ["create_read_items_for_chunk_list"]
  33. def _compare_save_plans(plan: SavePlan, other_plan: SavePlan) -> bool:
  34. """
  35. Compare the two Save plans and return True if they are equal.
  36. Args:
  37. plan (SavePlan): First SavePlan to compare.
  38. other_plan (SavePlan): Second SavePlan to compare.
  39. Returns:
  40. True if the two plans are equal, False otherwise.
  41. """
  42. if plan.usable != other_plan.usable:
  43. return False
  44. # Both the plans should have the same number of items
  45. if len(plan.items) != len(other_plan.items):
  46. return False
  47. # Both the plans should have the same write items.
  48. for plan_item, other_plan_item in zip(plan.items, other_plan.items):
  49. # Write item type should be same
  50. if plan_item.type != other_plan_item.type:
  51. return False
  52. plan_metadata_index = plan_item.index
  53. other_plan_metadata_index = other_plan_item.index
  54. # Write item metadata_index should be same
  55. if (
  56. plan_metadata_index.fqn != other_plan_metadata_index.fqn
  57. or plan_metadata_index.offset != other_plan_metadata_index.offset
  58. or plan_metadata_index.index != other_plan_metadata_index.index
  59. ):
  60. return False
  61. # Write item tensor_data should be present in both the write items plans, if it exists in either of them.
  62. tensor_data = plan_item.tensor_data
  63. other_tensor_data = other_plan_item.tensor_data
  64. if (tensor_data and not other_tensor_data) or (
  65. not tensor_data and other_tensor_data
  66. ):
  67. return False
  68. if tensor_data and other_tensor_data:
  69. # Write item tensor_data size should be same
  70. if tensor_data.size != other_tensor_data.size:
  71. return False
  72. # Write item tensor_data chunk should be present in both the write items, if it exists in either of them.
  73. chunk = tensor_data.chunk
  74. other_chunk = other_tensor_data.chunk
  75. if (chunk and not other_chunk) or (not chunk and other_chunk):
  76. return False
  77. # Write item tensor_data chunk offsets and sizes should be same
  78. if chunk and other_chunk:
  79. if (
  80. chunk.offsets != other_chunk.offsets
  81. or chunk.sizes != other_chunk.sizes
  82. ):
  83. return False
  84. return True
  85. def _contains_usable_plan(delta_plans: list[SavePlan]) -> bool:
  86. """
  87. Check if any delta plan is usable, indicating the plan has changed.
  88. Args:
  89. delta_plans (List[SavePlan]): A list of delta plans to check.
  90. Returns:
  91. True if any delta plan is usable, False otherwise.
  92. """
  93. return any(delta_plan and delta_plan.usable for delta_plan in delta_plans)
  94. def _merge_delta_local_plans(
  95. cached_plans: list[SavePlan],
  96. delta_plans: list[SavePlan],
  97. ) -> list[SavePlan]:
  98. """
  99. Merge a list of delta plans into a single plan.
  100. Args:
  101. cached_plans (List[SavePlan]): A list of cached plans.
  102. delta_plans (List[SavePlan]): A list of delta plans to merge. It can contain empty plans
  103. Returns:
  104. A single merged plan. If a delta plan is not usable, use the cached plan. Otherwise, use the delta plan.
  105. """
  106. merged_plans = []
  107. for cached_plan, delta_plan in zip(cached_plans, delta_plans):
  108. if delta_plan and not delta_plan.usable:
  109. merged_plans.append(cached_plan)
  110. else:
  111. merged_plans.append(delta_plan)
  112. return merged_plans
  113. def _create_chunk_from_tensor(tensor: torch.Tensor) -> ChunkStorageMetadata:
  114. return ChunkStorageMetadata(
  115. offsets=torch.Size([0] * len(tensor.size())), sizes=tensor.size()
  116. )
  117. def _chunk_for_shard(shard_md: ShardMetadata) -> ChunkStorageMetadata:
  118. return ChunkStorageMetadata(
  119. offsets=torch.Size(shard_md.shard_offsets),
  120. sizes=torch.Size(shard_md.shard_sizes),
  121. )
  122. def _sharded_tensor_metadata(
  123. sharded_tensor: ShardedTensor, shard_md: ShardMetadata
  124. ) -> TensorWriteData:
  125. shard_properties = sharded_tensor.metadata().tensor_properties
  126. properties = TensorProperties(
  127. dtype=shard_properties.dtype,
  128. layout=shard_properties.layout,
  129. requires_grad=shard_properties.requires_grad,
  130. memory_format=shard_properties.memory_format,
  131. pin_memory=shard_properties.pin_memory,
  132. )
  133. return TensorWriteData(
  134. chunk=_chunk_for_shard(shard_md),
  135. properties=properties,
  136. size=sharded_tensor.metadata().size,
  137. )
  138. def _create_write_items_for_dtensor(fqn: str, tensor: DTensor) -> WriteItem:
  139. sizes, offsets = compute_local_shape_and_global_offset(
  140. tensor.shape, tensor.device_mesh, tensor.placements
  141. )
  142. sizes, offsets = torch.Size(sizes), torch.Size(offsets)
  143. return WriteItem(
  144. index=MetadataIndex(fqn, offsets),
  145. type=WriteItemType.SHARD,
  146. tensor_data=TensorWriteData(
  147. chunk=ChunkStorageMetadata(
  148. offsets=offsets,
  149. sizes=sizes,
  150. ),
  151. properties=TensorProperties.create_from_tensor(tensor.to_local()),
  152. size=tensor.size(),
  153. ),
  154. )
  155. def _create_write_item_for_shard(
  156. fqn: str, sharded_tensor: ShardedTensor, shard_md: ShardMetadata
  157. ) -> WriteItem:
  158. offsets = torch.Size(shard_md.shard_offsets)
  159. return WriteItem(
  160. index=MetadataIndex(fqn, offsets),
  161. type=WriteItemType.SHARD,
  162. tensor_data=_sharded_tensor_metadata(sharded_tensor, shard_md),
  163. )
  164. def _create_write_item_for_tensor(fqn: str, tensor: torch.Tensor) -> WriteItem:
  165. offsets = torch.Size([0] * len(tensor.size()))
  166. return WriteItem(
  167. index=MetadataIndex(fqn, offsets),
  168. type=WriteItemType.TENSOR,
  169. tensor_data=TensorWriteData(
  170. chunk=ChunkStorageMetadata(offsets=offsets, sizes=tensor.size()),
  171. properties=TensorProperties.create_from_tensor(tensor),
  172. size=tensor.size(),
  173. ),
  174. )
  175. def _create_write_item_for_bytesio(fqn: str, bytes: Any):
  176. return WriteItem(
  177. index=MetadataIndex(fqn),
  178. type=WriteItemType.BYTE_IO,
  179. )
  180. def _create_read_item_for_byteio(
  181. dest_index, dest_offset, storage_index, storage_offset, length
  182. ):
  183. return ReadItem(
  184. type=LoadItemType.BYTE_IO,
  185. dest_index=dest_index,
  186. dest_offsets=torch.Size((dest_offset,)),
  187. storage_index=storage_index,
  188. storage_offsets=torch.Size((storage_offset,)),
  189. lengths=torch.Size((length,)),
  190. )
  191. def _create_read_item_for_tensor(
  192. dest_index, dest_offsets, storage_index, storage_offsets, lengths
  193. ):
  194. return ReadItem(
  195. type=LoadItemType.TENSOR,
  196. dest_index=dest_index,
  197. dest_offsets=torch.Size(dest_offsets),
  198. storage_index=storage_index,
  199. storage_offsets=torch.Size(storage_offsets),
  200. lengths=torch.Size(lengths),
  201. )
  202. def create_read_items_for_chunk_list(
  203. fqn: str,
  204. checkpoint_md: TensorStorageMetadata,
  205. local_chunks: list[ChunkStorageMetadata],
  206. ) -> list[ReadItem]:
  207. """
  208. Create a list of ``ReadItem`` based on the checkpoint and local chunks.
  209. This applies the resharding algorithm and computes the reads needed
  210. to satisfy ``local_chunks`` with a checkpoint described by ``checkpoint_md``.
  211. Args:
  212. fqn (str) : The state_dict FQN to pass to ``ReadItem``.
  213. checkpoint_md (TensorStorageMetadata): metadata for a given tensor
  214. from a checkpoint.
  215. local_chunks (List[ChunkStorageMetadata]): Local chunks that needs to be
  216. loaded.
  217. Returns:
  218. A list of ``ReadItem`` that will satisfy all input chunks.
  219. """
  220. read_items = []
  221. # this is a naive quadratic algo that can be optimized later
  222. for idx, shard in enumerate(local_chunks):
  223. for storage_idx, storage_md in enumerate(checkpoint_md.chunks):
  224. if not _check_shard_metadata_pair_overlap(shard, storage_md):
  225. continue
  226. storage_offsets = []
  227. dest_offsets = []
  228. lengths = []
  229. for (
  230. _dim,
  231. offset_for_saved_tensor,
  232. offset_for_current_tensor,
  233. length,
  234. ) in _shards_get_overlap_region_wrt_saved_tensor(
  235. saved_shard=storage_md, current_shard=shard
  236. ):
  237. storage_offsets.append(offset_for_saved_tensor)
  238. dest_offsets.append(offset_for_current_tensor)
  239. lengths.append(length)
  240. read_items.append(
  241. _create_read_item_for_tensor(
  242. dest_index=MetadataIndex(fqn, shard.offsets, idx),
  243. dest_offsets=dest_offsets,
  244. storage_index=MetadataIndex(fqn, storage_md.offsets, storage_idx),
  245. storage_offsets=storage_offsets,
  246. lengths=lengths,
  247. )
  248. )
  249. return read_items
  250. def _create_default_metadata_only_plan(state_dict: STATE_DICT_TYPE) -> SavePlan:
  251. requests = []
  252. for fqn, obj in state_dict.items():
  253. if isinstance(obj, DTensor):
  254. requests.append(_create_write_items_for_dtensor(fqn, obj))
  255. elif isinstance(obj, ShardedTensor):
  256. requests.extend(
  257. _create_write_item_for_shard(fqn, obj, shard_md)
  258. for shard_md in obj.metadata().shards_metadata
  259. )
  260. elif isinstance(obj, torch.Tensor):
  261. requests.append(_create_write_item_for_tensor(fqn, obj))
  262. else:
  263. requests.append(_create_write_item_for_bytesio(fqn, obj))
  264. return SavePlan(requests)
  265. def _create_write_items(fqn: str, object: Any) -> list[WriteItem]:
  266. if hasattr(object, "__create_write_items__"):
  267. # DTensor implements _Checkpointable
  268. return object.__create_write_items__(fqn, object)
  269. elif isinstance(object, ShardedTensor):
  270. return [
  271. _create_write_item_for_shard(fqn, object, shard.metadata)
  272. for shard in object.local_shards()
  273. ]
  274. elif isinstance(object, torch.Tensor):
  275. return [_create_write_item_for_tensor(fqn, object)]
  276. else:
  277. return [_create_write_item_for_bytesio(fqn, object)]
  278. def _create_chunk_from_dtensor(tensor: DTensor) -> ChunkStorageMetadata:
  279. sizes, offsets = compute_local_shape_and_global_offset(
  280. tensor.shape, tensor.device_mesh, tensor.placements
  281. )
  282. sizes, offsets = torch.Size(sizes), torch.Size(offsets)
  283. return ChunkStorageMetadata(
  284. offsets=offsets,
  285. sizes=sizes,
  286. )
  287. def _create_chunk_list(tensor: torch.Tensor) -> list[ChunkStorageMetadata]:
  288. if hasattr(tensor, "__create_chunk_list__"):
  289. # DTensor implements _Checkpointable
  290. local_chunks = tensor.__create_chunk_list__() # type: ignore[attr-defined]
  291. elif isinstance(tensor, ShardedTensor):
  292. local_chunks = [
  293. _chunk_for_shard(shard.metadata) for shard in tensor.local_shards()
  294. ]
  295. elif isinstance(tensor, torch.Tensor):
  296. local_chunks = [_create_chunk_from_tensor(tensor)]
  297. else:
  298. raise ValueError(
  299. "Unsupported Type, expecting one of [Tensor, DTensor, ShardedTensor] "
  300. f",but got {type(tensor)}"
  301. )
  302. return local_chunks
  303. def _create_read_items(fqn: str, md: STORAGE_TYPES, obj: Any) -> list[ReadItem]:
  304. if not isinstance(md, BytesStorageMetadata):
  305. try:
  306. local_chunks = _create_chunk_list(obj)
  307. except ValueError as ex:
  308. raise ValueError(
  309. f"Invalid checkpoint metadata for {fqn}, "
  310. + f"expected BytesStorageMetadata but found {type(md)}",
  311. ) from ex
  312. return create_read_items_for_chunk_list(fqn, md, local_chunks)
  313. else:
  314. return [
  315. _create_read_item_for_byteio(
  316. dest_index=MetadataIndex(fqn),
  317. dest_offset=0,
  318. storage_index=MetadataIndex(fqn),
  319. storage_offset=0,
  320. length=0,
  321. )
  322. ]
  323. def _init_state_dict(state_dict: dict[str, Any]) -> Any:
  324. """
  325. Initializes meta tensor if the meta tensor is DTensor or torch.Tensor.
  326. """
  327. def dtensor_func(value: DTensor):
  328. device = getattr(value, "device", None)
  329. if device == torch.device("meta"):
  330. device_type = dist.distributed_c10d._get_pg_default_device().type
  331. device = cast(
  332. torch.device, _get_device_module(device_type).current_device()
  333. )
  334. new_local_tensor = torch.empty_like(value.to_local(), device=device)
  335. # We need to pass shape and stride explicitly, since DTensor might be
  336. # sharded unevenly.
  337. dtensor = DTensor.from_local(
  338. new_local_tensor,
  339. device_mesh=value.device_mesh,
  340. placements=value.placements,
  341. shape=value.size(),
  342. stride=value.stride(),
  343. )
  344. return dtensor
  345. else:
  346. return value
  347. def sharded_tensor_func(value: Any):
  348. device = getattr(value, "device", None)
  349. if device == torch.device("meta"):
  350. raise RuntimeError(
  351. f"Found unsupported type {type(value)} for meta device loading."
  352. )
  353. else:
  354. return value
  355. def tensor_func(value: torch.Tensor):
  356. device = getattr(value, "device", None)
  357. if device == torch.device("meta"):
  358. device_type = dist.distributed_c10d._get_pg_default_device().type
  359. device = cast(
  360. torch.device, _get_device_module(device_type).current_device()
  361. )
  362. tensor = torch.empty_like(value, device=device)
  363. return tensor
  364. else:
  365. return value
  366. _iterate_state_dict(
  367. state_dict,
  368. dtensor_func,
  369. sharded_tensor_func,
  370. tensor_func,
  371. )
  372. def _iterate_state_dict(
  373. iter_object: Any,
  374. dtensor_func: Callable,
  375. sharded_tensor_func: Callable,
  376. tensor_func: Callable,
  377. ):
  378. """
  379. Iterate through the state dict, applying the given functions to each tensor type
  380. and update the state dict in place.
  381. Args:
  382. iter_object (Any): the target state_dict.
  383. sharded_tensor_func (Callable): the function to apply to ShardedTensor
  384. dtensor_func (Callable): the function to apply to DTensor
  385. tensor_func (Callable): the function to apply to Tensor
  386. # TODO: let state_dict_util._iterate_state_dict() to support in place option
  387. so we don't need to have two versions of _iterate_state_dict.
  388. """
  389. if isinstance(iter_object, DTensor):
  390. return dtensor_func(iter_object)
  391. elif isinstance(iter_object, ShardedTensor):
  392. return sharded_tensor_func(iter_object)
  393. elif isinstance(iter_object, torch.Tensor):
  394. return tensor_func(iter_object)
  395. elif (
  396. isinstance(iter_object, (int, float, str, bytes, io.BytesIO))
  397. or iter_object is None
  398. ):
  399. return iter_object
  400. elif isinstance(iter_object, dict):
  401. for key, value in iter_object.items():
  402. iter_object[key] = _iterate_state_dict(
  403. value, dtensor_func, sharded_tensor_func, tensor_func
  404. )
  405. return iter_object
  406. elif isinstance(iter_object, (list, tuple)):
  407. ret = [
  408. _iterate_state_dict(v, dtensor_func, sharded_tensor_func, tensor_func)
  409. for v in iter_object
  410. ]
  411. if isinstance(iter_object, tuple):
  412. ret = tuple(ret) # type: ignore[assignment]
  413. return ret