metadata.py 2.2 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364
  1. # mypy: allow-untyped-defs
  2. from dataclasses import dataclass
  3. from functools import reduce
  4. from typing import Optional, Union
  5. from torch.distributed.remote_device import _remote_device
  6. @dataclass
  7. class ShardMetadata:
  8. """
  9. Represents a shard of the overall Tensor including its
  10. offsets, lengths and device placement.
  11. Args:
  12. shard_offsets(List[int]): Offsets in the original tensor indicating
  13. the start offsets for this shard. Should have the same rank as
  14. the original tensor.
  15. shard_sizes(List[int]): Integers indicating the size of each
  16. dimension for this shard. Should have the same rank as the
  17. original tensor.
  18. placement(:class:`torch.distributed._remote_device`):
  19. Specifies the placement of this shard.
  20. """
  21. __slots__ = ["shard_offsets", "shard_sizes", "placement"]
  22. shard_offsets: list[int]
  23. shard_sizes: list[int]
  24. placement: Optional[_remote_device]
  25. def __init__(
  26. self,
  27. shard_offsets: list[int],
  28. shard_sizes: list[int],
  29. placement: Optional[Union[str, _remote_device]] = None,
  30. ):
  31. self.shard_offsets = shard_offsets
  32. self.shard_sizes = shard_sizes
  33. if isinstance(placement, str):
  34. self.placement = _remote_device(placement)
  35. else:
  36. self.placement = placement
  37. if len(self.shard_offsets) != len(self.shard_sizes):
  38. raise ValueError(
  39. f"shard_offsets and shard_sizes should have "
  40. f"the same number of elements, found {len(self.shard_offsets)} "
  41. f"and {self.shard_sizes} respectively"
  42. )
  43. for i in range(len(self.shard_offsets)):
  44. if self.shard_offsets[i] < 0:
  45. raise ValueError("shard_offsets should be >=0")
  46. if self.shard_sizes[i] < 0:
  47. raise ValueError("shard_sizes should be >= 0")
  48. def __hash__(self):
  49. def _hash_reduce(a, b):
  50. return (a << 8) + hash(b)
  51. res = reduce(_hash_reduce, self.shard_offsets, 37)
  52. res = reduce(_hash_reduce, self.shard_sizes, res)
  53. res = _hash_reduce(res, self.placement)
  54. return res