_utils.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import logging
  4. from dataclasses import dataclass
  5. from typing import Union
  6. import torch
  7. from torch import fx
  8. logger = logging.getLogger(__name__)
  9. def flatten_args_detach(args):
  10. """
  11. Flatten the args into a list form and detach the tensors from computational graph.
  12. """
  13. flat_detached_args = []
  14. def extract_tensor_args(a):
  15. nonlocal flat_detached_args
  16. if isinstance(a, torch.Tensor):
  17. val = a.detach().requires_grad_(a.requires_grad)
  18. flat_detached_args.append(val)
  19. return val
  20. else:
  21. flat_detached_args.append(a)
  22. return a
  23. new_args = fx.node.map_aggregate(
  24. args,
  25. extract_tensor_args,
  26. )
  27. return new_args, flat_detached_args
  28. def flatten_args(args):
  29. """
  30. Flatten the args into a list form.
  31. """
  32. flat_args = []
  33. def extract_tensor_args(a):
  34. nonlocal flat_args
  35. flat_args.append(a)
  36. return a
  37. fx.node.map_aggregate(
  38. args,
  39. extract_tensor_args,
  40. )
  41. return flat_args
  42. class PipeliningShapeError(RuntimeError):
  43. """Shape mismatch between configured and runtime values."""
  44. def validate_tensor_metadata(desc, expected, given):
  45. if not expected.shape == given.shape:
  46. raise PipeliningShapeError(
  47. f"{desc} has a shape mismatch: expected {expected.shape} actual {given.shape}"
  48. )
  49. if not expected.dtype == given.dtype:
  50. raise PipeliningShapeError(
  51. f"{desc} has a dtype mismatch: expected {expected.dtype} actual {given.dtype}"
  52. )
  53. if not expected.stride() == given.stride():
  54. raise PipeliningShapeError(
  55. f"{desc} has a stride mismatch: expected {expected.stride()} actual {given.stride()}"
  56. )
  57. def validate_tensors_metadata(
  58. desc,
  59. expected_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
  60. actual_tensors: Union[list[torch.Tensor], tuple[torch.Tensor, ...]],
  61. ):
  62. if len(expected_tensors) != len(actual_tensors):
  63. raise PipeliningShapeError(
  64. f"{desc}: Number of values ({len(actual_tensors)}) does not match expected number ({len(expected_tensors)})"
  65. )
  66. for i in range(len(expected_tensors)):
  67. validate_tensor_metadata(
  68. f"{desc}: value {i}", expected_tensors[i], actual_tensors[i]
  69. )
  70. def generate_stage_to_rank_mapping(
  71. pp_size: int, num_stages: int, style: str = "loop"
  72. ) -> dict[int, int]:
  73. """
  74. Compute the stage id to rank mapping for either a looped or V-style schedule.
  75. Most commonly num_stages == pp_size * 2, but this function can be used to
  76. compute the mapping for any number of stages per rank.
  77. """
  78. mapping = {}
  79. if style == "loop":
  80. for stage_index in range(num_stages):
  81. mapping[stage_index] = stage_index % pp_size
  82. elif style == "v":
  83. if num_stages % pp_size != 0:
  84. raise ValueError(
  85. f"num_stages {num_stages} must be evenly divisible by pp_size {pp_size} for V schedules"
  86. )
  87. rank_index = 0
  88. for stage_index in range(num_stages):
  89. mapping[stage_index] = rank_index
  90. # dont change rank if we are on the border (to keep v shape)
  91. if (stage_index + 1) % pp_size == 0:
  92. continue
  93. if (stage_index // pp_size) % 2 == 0:
  94. rank_index += 1
  95. else:
  96. rank_index -= 1
  97. else:
  98. raise ValueError(f"Style {style} is not supported.")
  99. return mapping
  100. def generate_rank_to_stage_mapping(
  101. pp_size: int, num_stages: int, style: str = "loop"
  102. ) -> dict[int, list[int]]:
  103. """
  104. Compute the rank to stage id mapping for either a looped or V-style schedule.
  105. This function inverts the stage_to_rank_mapping to get which stages are assigned to each rank.
  106. Returns a dictionary mapping rank -> list of stage indices assigned to that rank.
  107. """
  108. stage_to_rank = generate_stage_to_rank_mapping(pp_size, num_stages, style)
  109. # Invert the mapping: rank -> list of stages
  110. rank_to_stages: dict[int, list[int]] = {}
  111. for stage_id, rank in stage_to_rank.items():
  112. if rank not in rank_to_stages:
  113. rank_to_stages[rank] = []
  114. rank_to_stages[rank].append(stage_id)
  115. # Sort the stage lists for each rank to ensure consistent ordering
  116. for stages in rank_to_stages.values():
  117. stages.sort()
  118. return rank_to_stages
  119. @dataclass
  120. class PipeInfo:
  121. """
  122. Captures information for a pipeline (`Pipe` object).
  123. """
  124. graph: fx.Graph
  125. num_stages: int
  126. has_loss_and_backward: bool