comms.py 70 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859
  1. # mypy: allow-untyped-defs
  2. # pyre-strict
  3. from __future__ import annotations
  4. import heapq
  5. import importlib
  6. import itertools
  7. import logging
  8. import operator
  9. import sys
  10. import time
  11. from collections import defaultdict
  12. from dataclasses import dataclass
  13. from typing import Any, Optional, TYPE_CHECKING, Union
  14. import torch
  15. from torch._logging import trace_structured
  16. from torch.multiprocessing.reductions import StorageWeakRef
  17. from torch.utils._ordered_set import OrderedSet
  18. from . import config, ir
  19. from .dependencies import WeakDep
  20. if TYPE_CHECKING:
  21. from .ir import IRNode, Operation
  22. from .scheduler import SchedulerBuffer
  23. from .memory import (
  24. estimate_peak_memory,
  25. estimate_peak_memory_allocfree,
  26. FreeableInputBuffer,
  27. get_freeable_input_buf,
  28. SNodeMemory,
  29. )
  30. from .utils import (
  31. contains_collective,
  32. contains_wait,
  33. find_recursive_deps_of_node,
  34. find_recursive_users_of_node,
  35. is_collective,
  36. is_fallback_op,
  37. is_wait,
  38. )
  39. from .virtualized import V
# Module-level logger.
log = logging.getLogger(__name__)
# Artifact logger for the "overlap" artifact; the peak-memory-preserving
# reorder pass in this module writes its per-collective summary table here.
overlap_log = torch._logging.getArtifactLogger(__name__, "overlap")
  42. if TYPE_CHECKING:
  43. from torch._inductor.scheduler import BaseSchedulerNode
  44. def align_runtime_estimations_across_all_distributed_ranks(
  45. snodes: list[BaseSchedulerNode],
  46. ):
  47. runtime_estimations = {}
  48. for snode in snodes:
  49. runtime_estimations[snode] = snode.get_estimated_runtime()
  50. import torch.distributed as dist
  51. from torch.distributed.distributed_c10d import _get_default_group
  52. world_size = dist.get_world_size()
  53. pg = _get_default_group()
  54. gathered_runtime_estimations: list[list[float]] = [[] for _ in range(world_size)]
  55. dist.all_gather_object(
  56. gathered_runtime_estimations, list(runtime_estimations.values()), pg
  57. )
  58. median_runtime_estimations = torch.median(
  59. torch.tensor(gathered_runtime_estimations), dim=0
  60. ).values.tolist()
  61. for i in range(len(snodes)):
  62. snodes[i].override_estimated_runtime = median_runtime_estimations[i]
  63. def sink_waits(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
  64. """
  65. Greedily schedules waits as late as possible.
  66. """
  67. return _schedule_for_comm(
  68. snodes, raise_comms=False, sink_waits=True, reorder_for_overlap=False
  69. )
  70. def raise_comms(snodes: list[BaseSchedulerNode]) -> list[BaseSchedulerNode]:
  71. """
  72. Greedily schedules comms as early as possible.
  73. """
  74. return _schedule_for_comm(
  75. snodes, raise_comms=True, sink_waits=False, reorder_for_overlap=False
  76. )
  77. def reorder_compute_for_overlap(
  78. snodes: list[BaseSchedulerNode],
  79. ) -> list[BaseSchedulerNode]:
  80. """
  81. This achieves the following overall scheduling procedure:
  82. Step 1: Given that we've currently scheduled comm N, we now schedule all compute nodes
  83. that are required for comm N + 1 but do not depend on comm N, to run at the same time with comm N.
  84. Step 2: If all those compute nodes are sufficient to overlap comm N, we're done.
  85. Otherwise, we now need to look elsewhere to find compute that overlaps with comm N.
  86. We prioritize compute nodes that are needed sooner.
  87. Step 3: We schedule the compute nodes dependent on comm N and required for comm N + 1.
  88. Step 4: We schedule comm N + 1.
  89. Repeat this for subsequent comm nodes.
  90. """
  91. return _schedule_for_comm(
  92. snodes, raise_comms=True, sink_waits=True, reorder_for_overlap=True
  93. )
  94. def reorder_communication_preserving_peak_memory(
  95. snodes: list[BaseSchedulerNode],
  96. ) -> list[BaseSchedulerNode]:
  97. """
  98. Reorders communication ops relative to computation ops to improve communication-compute overlapping and hide comm
  99. latency. Stops moving a particular op if it reaches a point that would have increased the peak memory footprint.
  100. Currently, follows these heuristics (subject to change or tune):
  101. - never reorders collectives relative to one another, for SPMD safety
  102. - has an option for per-collective prefetch limit, but does not enable it by default
  103. - limits the total number of reorder steps to some factor of the graph size to prevent worst-case quadratic
  104. performance
  105. Prerequisite: sink_comms_and_waits - ensure comm and wait nodes are scheduled as late as possible, respecting data
  106. dependencies. That allows reorder_communication_preserving_peak_memory to take a best case peak-memory snapshot,
  107. and then monotonically improve latency by moving collectives backward in time.
  108. Peak memory impact is computed in an iterative fashion. First, memory use at each timestep is computed, and global
  109. peak memory is computed as a max over timesteps. Then, when swapping any two adjacent nodes, only the curr-memory
  110. for the earlier of the nodes after the swap is affected. This enables checking step by step whether a swap is
  111. peak-memory-safe, and bailing out if not. Example:
  112. 0 n0 C0
  113. 1 n1 C0 + Allocs(n1) - Frees(n1)
  114. 2 n2 C0 + Allocs(n1) - Frees(n1) + Allocs(n2) - Frees(n2)
  115. 0 n0 C0
  116. 1 n2 C0 + Allocs(n2) - Frees(n2) <-- After moving n2 to Time 1, only time1 memory changes
  117. 2 n1 C0 + Allocs(n2) - Frees(n2) + Allocs(n1) - Frees(n1)
  118. """
  119. reordered_snodes, node_stats = (
  120. _reorder_communication_preserving_peak_memory_internal(snodes)
  121. )
  122. return reordered_snodes
  123. @dataclass
  124. class ReorderInfo:
  125. """
  126. Debug info describing how an individual snode was reordered
  127. """
  128. initial_exposed: float = -1
  129. final_exposed: float = -1
  130. limiting_factor: str = "None"
  131. moves: int = 0
  132. grouped: int = 0
  133. grouped_info: str = ""
  134. @property
  135. def improvement(self):
  136. return self.initial_exposed - self.final_exposed
  137. def is_gemm_like(node: Optional[Union[IRNode, Operation]]) -> bool:
  138. if node is None:
  139. return False
  140. if is_fallback_op(
  141. node, # type: ignore[arg-type]
  142. torch.ops.aten._scaled_dot_product_flash_attention.default,
  143. ):
  144. return True
  145. if (
  146. python_kernel_name := getattr(node, "python_kernel_name", None)
  147. ) and "extern_kernels" in python_kernel_name:
  148. return True
  149. return False
  150. def contains_gemm_like(snode: BaseSchedulerNode) -> bool:
  151. from torch._inductor.scheduler import GroupedSchedulerNode
  152. if isinstance(snode, GroupedSchedulerNode):
  153. return any(contains_gemm_like(x) for x in snode.snodes)
  154. else:
  155. return is_gemm_like(snode.node)
  156. def _temp_group_visit_leaves(snode, fn):
  157. from torch._inductor.scheduler import GroupedSchedulerNode
  158. if isinstance(snode, GroupedSchedulerNode) and snode.temp_grouping:
  159. for _snode in snode.snodes:
  160. fn(_snode)
  161. else:
  162. fn(snode)
  163. def _group_name(snode, with_bufs=False) -> str:
  164. ret = ""
  165. for n in snode.snodes:
  166. if ret:
  167. ret += "_"
  168. ret += n.get_name()
  169. if with_bufs:
  170. ret += f"{list(snode.get_buffer_names())}"
  171. return ret
  172. def _is_fake_dep(d):
  173. return isinstance(d, WeakDep) and d.is_fake
  174. def _group_names(gns: list[BaseSchedulerNode]) -> str:
  175. return "~".join([gn.get_name() for gn in gns])
  176. def _initialize_memory_tracking(snodes, graph_inputs, graph_outputs):
  177. """Initialize memory tracking data structures"""
  178. name_to_freeable_input_buf = get_freeable_input_buf(snodes, graph_inputs)
  179. peak_memory, snodes_curr_memory, snodes_allocfree, buf_to_snode_last_use = (
  180. estimate_peak_memory_allocfree(
  181. snodes, name_to_freeable_input_buf, graph_outputs
  182. )
  183. )
  184. _curr_memory = dict(zip(snodes, snodes_curr_memory))
  185. _curr_memory[None] = (0, 0)
  186. return (
  187. peak_memory,
  188. _curr_memory,
  189. snodes_allocfree,
  190. buf_to_snode_last_use,
  191. name_to_freeable_input_buf,
  192. )
  193. def _initialize_double_linked_list(
  194. snodes: list[BaseSchedulerNode],
  195. ) -> tuple[
  196. dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  197. dict[BaseSchedulerNode, Optional[BaseSchedulerNode]],
  198. BaseSchedulerNode,
  199. ]:
  200. """Create double-linked list structure from snodes"""
  201. _prev = {}
  202. _next = {}
  203. for i, snode in enumerate(snodes):
  204. _prev[snode] = snodes[i - 1] if i > 0 else None
  205. _next[snode] = snodes[i + 1] if i < len(snodes) - 1 else None
  206. _head = snodes[0]
  207. return _prev, _next, _head
  208. def _reorder_communication_preserving_peak_memory_internal(
  209. snodes: list[BaseSchedulerNode],
  210. ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, ReorderInfo]]:
  211. """
  212. Internal testing helper that also returns debug info.
  213. Returns:
  214. - reordered snodes list
  215. - dict {snode: ReorderInfo}
  216. """
  217. has_collectives = False
  218. for snode in snodes:
  219. if contains_collective(snode):
  220. has_collectives = True
  221. break
  222. if not has_collectives:
  223. return snodes, {}
  224. from torch._inductor.scheduler import GroupedSchedulerNode
  225. original_snodes_num = len(snodes)
  226. # heuristic to avoid degenerating to quadratic time
  227. graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
  228. graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
  229. (
  230. peak_memory,
  231. _curr_memory,
  232. snodes_allocfree,
  233. buf_to_snode_last_use,
  234. name_to_freeable_input_buf,
  235. ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
  236. runtimes: dict[BaseSchedulerNode, float] = {
  237. snode: estimate_op_runtime(snode) for snode in snodes
  238. }
  239. # debug stats
  240. stats: dict[BaseSchedulerNode, ReorderInfo] = {}
  241. def exposed_communication_time(
  242. collective_snode: BaseSchedulerNode, remaining_snodes: list[BaseSchedulerNode]
  243. ) -> float:
  244. # assumes a linear schedule and computes the overlap of the collective with the remaining nodes
  245. comm_time = estimate_op_runtime(collective_snode)
  246. compute_time = 0.0
  247. for snode in remaining_snodes:
  248. if contains_collective(snode):
  249. continue
  250. if contains_wait(snode):
  251. # TODO - if the wait is for a collective that started before this collective or on another stream,
  252. # we can ignore it. Otherwise, it's the end of the road for overlap opportunities
  253. break
  254. def accumulate_time(_snode: BaseSchedulerNode) -> None:
  255. nonlocal compute_time
  256. compute_time += runtimes[_snode]
  257. _temp_group_visit_leaves(snode, accumulate_time)
  258. return max(0, comm_time - compute_time)
  259. total_moves = 0
  260. _prev, _next, _head = _initialize_double_linked_list(snodes)
  261. def _group_nodes(
  262. head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
  263. ) -> list[BaseSchedulerNode]:
  264. ret = []
  265. n = head
  266. while True:
  267. if n is not None:
  268. ret.append(n)
  269. if n == tail:
  270. break
  271. n = _next[n] # type: ignore[index]
  272. return ret
  273. def _perform_double_linked_list_swap(candidate, group_head, group_tail):
  274. # swap (candidate, group_head...group_tail)
  275. # Before:
  276. # candidate_prev -0-> candidate -1-> group_head...group_tail -2-> group_tail_next
  277. # After:
  278. # candidate_prev -0-> group_head...group_tail -1-> candidate -2-> group_tail_next
  279. # 0
  280. candidate_prev = _prev[candidate]
  281. if candidate_prev:
  282. _next[candidate_prev] = group_head
  283. _prev[group_head] = candidate_prev
  284. # 2
  285. group_tail_next = _next[group_tail]
  286. if group_tail_next:
  287. _prev[group_tail_next] = candidate
  288. _next[candidate] = group_tail_next
  289. # 1
  290. _prev[candidate] = group_tail
  291. _next[group_tail] = candidate
  292. nonlocal _head
  293. if _head == candidate:
  294. _head = group_head
  295. def _calculate_potential_peak_memory(
  296. candidate, group_ns, group_n_to_bufs_after_swap_dealloc_by_candidate
  297. ):
  298. # Caching calculations of memory for group nodes and candidate,
  299. # to apply without recalculation after swap.
  300. _post_alloc_update: dict[BaseSchedulerNode, int] = {}
  301. potential_peak: int = 0
  302. if not group_n_to_bufs_after_swap_dealloc_by_candidate:
  303. # Not accounting for buffers last use change
  304. potential_peak = max(
  305. group_peak_memory - candidate_delta_mem,
  306. _curr_memory[group_tail][1]
  307. - candidate_delta_mem
  308. + candidate_allocfree.size_alloc,
  309. )
  310. return potential_peak, _post_alloc_update
  311. # If candidate will be after group, the starting memory level of group nodes
  312. # changes to the -(candidate.size_alloc - candidate.size_free)
  313. mem_after_reorder_delta: int = -candidate_delta_mem
  314. for gn in gns:
  315. gn_post_alloc_mem = _curr_memory[gn][0] + mem_after_reorder_delta
  316. _post_alloc_update[gn] = gn_post_alloc_mem
  317. potential_peak = max(potential_peak, gn_post_alloc_mem)
  318. bufs = group_n_to_bufs_after_swap_dealloc_by_candidate.get(gn, None)
  319. if bufs is not None:
  320. for buf in bufs:
  321. # Candidate will deallocate those buffers
  322. mem_after_reorder_delta += buf.mpi_buffer.size_free
  323. candidate_mem_post_alloc = (
  324. _curr_memory[group_tail][1]
  325. + mem_after_reorder_delta
  326. + candidate_allocfree.size_alloc
  327. )
  328. _post_alloc_update[candidate] = candidate_mem_post_alloc
  329. potential_peak = max(potential_peak, candidate_mem_post_alloc)
  330. return potential_peak, _post_alloc_update
  331. def _update_memory_tracking_after_swap(
  332. candidate,
  333. gns,
  334. group_n_to_bufs_after_swap_dealloc_by_candidate,
  335. _post_alloc_update,
  336. ):
  337. if not group_n_to_bufs_after_swap_dealloc_by_candidate:
  338. for gn in gns:
  339. cm = _curr_memory[gn]
  340. _curr_memory[gn] = (
  341. cm[0] - candidate_delta_mem,
  342. cm[1] - candidate_delta_mem,
  343. )
  344. _candidate_post_alloc_mem = (
  345. _curr_memory[group_tail][1] + candidate_allocfree.size_alloc
  346. )
  347. _candidate_post_free_mem = (
  348. _candidate_post_alloc_mem - candidate_allocfree.size_free
  349. )
  350. _curr_memory[candidate] = (
  351. _candidate_post_alloc_mem,
  352. _candidate_post_free_mem,
  353. )
  354. return
  355. # Candidate becomes last use of some bufs
  356. for (
  357. gn,
  358. bufs,
  359. ) in group_n_to_bufs_after_swap_dealloc_by_candidate.items():
  360. for buf in bufs:
  361. buf_to_snode_last_use[buf] = candidate
  362. size_free_to_move_to_candidate_sum: int = 0
  363. for n in gns:
  364. _gn_post_alloc_mem: int = _post_alloc_update[n]
  365. size_free_to_move_to_candidate: int = sum(
  366. buf.mpi_buffer.size_free
  367. for buf in group_n_to_bufs_after_swap_dealloc_by_candidate[n]
  368. )
  369. size_free_to_move_to_candidate_sum += size_free_to_move_to_candidate
  370. # group node does not deallocate this after swap
  371. snodes_allocfree[n].size_free -= size_free_to_move_to_candidate
  372. gn_post_free_mem: int = _gn_post_alloc_mem - snodes_allocfree[n].size_free
  373. _curr_memory[n] = (_gn_post_alloc_mem, gn_post_free_mem)
  374. _candidate_post_alloc_mem = _post_alloc_update[candidate]
  375. snodes_allocfree[candidate].size_free += size_free_to_move_to_candidate_sum
  376. candidate_post_free_mem = (
  377. _candidate_post_alloc_mem - snodes_allocfree[candidate].size_free
  378. )
  379. _curr_memory[candidate] = (
  380. _candidate_post_alloc_mem,
  381. candidate_post_free_mem,
  382. )
  383. debug_num_collectives_to_reorder: Optional[int] = (
  384. config.reorder_iterative_debug_limit_to_reorder
  385. )
  386. num_processed_collectives: int = 0
  387. curr = _head
  388. debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
  389. iterative_recompute_error = False
  390. while _next[curr] is not None:
  391. if iterative_recompute_error:
  392. break
  393. if contains_collective(curr):
  394. if debug_num_collectives_to_reorder is not None and (
  395. num_processed_collectives >= debug_num_collectives_to_reorder
  396. ):
  397. break
  398. num_processed_collectives += 1
  399. info = stats[curr] = ReorderInfo()
  400. info.initial_exposed = info.final_exposed = exposed_communication_time(
  401. curr, _group_nodes(_next[curr], None)
  402. )
  403. candidate = _prev[curr]
  404. group_head = curr
  405. group_tail = curr
  406. group_peak_memory = _curr_memory[curr][0] # post_alloc memory
  407. while candidate is not None:
  408. if contains_collective(candidate):
  409. info.limiting_factor = "collective ordering"
  410. break
  411. gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
  412. group = GroupedSchedulerNode(
  413. curr.scheduler,
  414. gns,
  415. temp_grouping=True,
  416. )
  417. # We can have multiple deps with the same name.
  418. # As we ignore WeakDep(is_fake=True) =>
  419. # filter them out first to avoid overwriting of real dep.
  420. data_deps = {
  421. d.name: d for d in group.unmet_dependencies if not _is_fake_dep(d)
  422. }
  423. candidate_outs = candidate.get_outputs()
  424. data_dep = None
  425. for o in candidate_outs:
  426. if d := data_deps.get(o.get_name(), None):
  427. data_dep = d
  428. break
  429. if data_dep is not None:
  430. def is_groupable(
  431. candidate: BaseSchedulerNode,
  432. ) -> tuple[bool, Optional[str]]:
  433. # preserve ordering
  434. if contains_collective(candidate):
  435. return False, "contains_collective"
  436. if contains_gemm_like(candidate):
  437. return False, "contains_gemm_like"
  438. return True, None
  439. is_groupable_result, grouping_reason = is_groupable(candidate)
  440. if is_groupable_result:
  441. group_head = candidate
  442. group_peak_memory = max(
  443. group_peak_memory, _curr_memory[candidate][0]
  444. )
  445. info.grouped += 1
  446. info.grouped_info = _group_names(gns)
  447. candidate = _prev[candidate]
  448. continue
  449. else:
  450. msg = (
  451. f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
  452. f"\n candidate:{candidate.get_name()}(outs:{[candidate.get_buffer_names()]})"
  453. f"dep on {_group_names(gns)}"
  454. f"\n non_group_reason:{grouping_reason}"
  455. )
  456. info.limiting_factor = msg
  457. break
  458. candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
  459. candidate_delta_mem: int = (
  460. candidate_allocfree.size_alloc - candidate_allocfree.size_free
  461. )
  462. # candidate and one of group nodes are successors of the same buffer
  463. # and last use of the buffer happen in group nodes.
  464. # This last use deallocates it.
  465. # If we swap [candidate [group]] to [[group] candidate],
  466. # candidate becomes the last use
  467. # and deallocated this buffer instead of group node.
  468. # we need to update size_free accordingly to group_node and candidate,
  469. # and recalculate post_alloc, post_free for them.
  470. #
  471. # Buf that changes its last use snode,
  472. # after swap will be deallocated only by candidate,
  473. # while before it was deallocated by group node.
  474. group_n_to_bufs_after_swap_dealloc_by_candidate: dict[
  475. BaseSchedulerNode, list[Union[FreeableInputBuffer, Any]]
  476. ] = defaultdict(list)
  477. for (
  478. buf,
  479. snode_last_use,
  480. ) in buf_to_snode_last_use.items():
  481. succ_nodes = buf.mpi_buffer.succ_nodes
  482. if candidate not in succ_nodes:
  483. continue
  484. if not any(gn == snode_last_use for gn in gns):
  485. continue
  486. group_n_to_bufs_after_swap_dealloc_by_candidate[
  487. snode_last_use
  488. ].append(buf)
  489. potential_peak, _post_alloc_update = _calculate_potential_peak_memory(
  490. candidate, gns, group_n_to_bufs_after_swap_dealloc_by_candidate
  491. )
  492. if potential_peak > peak_memory:
  493. info.limiting_factor = (
  494. f"peak memory new:{potential_peak} vs base:{peak_memory}"
  495. )
  496. break
  497. info.moves += 1
  498. total_moves += 1
  499. _perform_double_linked_list_swap(candidate, group_head, group_tail)
  500. info.final_exposed = exposed_communication_time(
  501. curr, _group_nodes(_next[curr], None)
  502. )
  503. _update_memory_tracking_after_swap(
  504. candidate,
  505. gns,
  506. group_n_to_bufs_after_swap_dealloc_by_candidate,
  507. _post_alloc_update,
  508. )
  509. if debug_iterative_memory_recompute:
  510. # Compare iteratively recomputed memory data
  511. # with full run of estimate_peak_memory
  512. from .comms_debug import _debug_iterative_memory_recompute
  513. iterative_recompute_error = _debug_iterative_memory_recompute(
  514. candidate,
  515. gns,
  516. _group_names(gns),
  517. _group_nodes(_head, None),
  518. name_to_freeable_input_buf,
  519. graph_outputs,
  520. peak_memory,
  521. _curr_memory,
  522. snodes_allocfree,
  523. "reorder_communication_preserving_peak_memory",
  524. group_n_to_bufs_after_swap_dealloc_by_candidate,
  525. )
  526. if iterative_recompute_error:
  527. break
  528. candidate = _prev[group_head]
  529. curr = _next[curr] # type: ignore[assignment]
  530. node_stats = stats
  531. improvement = {snode: node_stats[snode].improvement for snode in node_stats}
  532. total_improvement = sum([improvement[snode] for snode in improvement])
  533. total_moves = sum([node_stats[snode].moves for snode in node_stats])
  534. reorder_log_str = (
  535. f"reorder_communication_preserving_peak_memory improved overlap by {total_improvement} ns"
  536. f" after {total_moves} reorders.\n"
  537. )
  538. headers = [
  539. "Collective node",
  540. "initial exposed",
  541. "final exposed",
  542. "improvement",
  543. "limiting factor",
  544. "moves",
  545. "grouped",
  546. "grouped_info",
  547. ]
  548. rows = [
  549. [
  550. node_summary(snode),
  551. node_info.initial_exposed,
  552. node_info.final_exposed,
  553. node_info.improvement,
  554. node_info.limiting_factor,
  555. node_info.moves,
  556. node_info.grouped,
  557. node_info.grouped_info,
  558. ]
  559. for snode, node_info in node_stats.items()
  560. ]
  561. if importlib.util.find_spec("tabulate"):
  562. from tabulate import tabulate
  563. reorder_log_str += tabulate(
  564. rows,
  565. headers=headers,
  566. )
  567. else:
  568. reorder_log_str += (
  569. "Please `pip install tabulate` to nicely render overlap stats.\n"
  570. )
  571. reorder_log_str += str(headers) + "\n"
  572. reorder_log_str += "\n".join(map(str, rows))
  573. new_snodes = _group_nodes(_head, None)
  574. assert len(new_snodes) == original_snodes_num
  575. new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
  576. new_snodes, name_to_freeable_input_buf, graph_outputs
  577. )
  578. reorder_log_str += f"\n peak_memory_before:{peak_memory}"
  579. reorder_log_str += f"\n peak_memory_after:{new_peak_memory}"
  580. overlap_log.info(reorder_log_str)
  581. trace_structured(
  582. "artifact",
  583. metadata_fn=lambda: {
  584. "name": "reorder_communication_preserving_peak_memory",
  585. "encoding": "string",
  586. },
  587. payload_fn=lambda: reorder_log_str,
  588. )
  589. return new_snodes, stats
  590. def _schedule_for_comm(
  591. snodes: list[BaseSchedulerNode],
  592. raise_comms: bool,
  593. sink_waits: bool,
  594. reorder_for_overlap: bool,
  595. ) -> list[BaseSchedulerNode]:
  596. """
  597. Schedule `snodes` for various comm optimization objectives.
  598. Args:
  599. snodes: the nodes to be scheduled.
  600. raise_comms: whether to greedily schedule collectives as early as possible
  601. sink_wait: whether to greedily schedule waits as late as possible
  602. reorder_compute_for_overlap: whether to reorder compute nodes to
  603. optimize for compute/communication overlapping.
  604. Returns:
  605. The new schedule order.
  606. Some notes on the synergy between different options:
  607. - `raise_comms` provides more overlapping oppurtunies for `reorder_compute_for_overlap`.
  608. - When both `raise_comms` and `sink_waits` is `True`, `raise_comms` is prioritized.
  609. """
  610. # We assign each node a tuple of scores (score_0, score_1, score_2),
  611. # decreasing in importance, with a lower value indicating a higher ranking:
  612. #
  613. # - score_0: the lowest comm_idx among the comm nodes that the node blocks.
  614. # If a node doesn't block any comm nodes, its score_0 is set to
  615. # sys.maxsize. This score ensures that comm nodes get scheduled as early as
  616. # possible.
  617. # - score_1: 1 if the node is a wait node, 0 otherwise. This score ensures
  618. # that wait nodes are deferred as late as possible.
  619. # - score_2: the index of the node in the original topological order. This
  620. # score provides stability in case of ties.
  621. #
  622. # When only raise_comms is True, only score_0 and score_2 are considered.
  623. # When only sink_waits is True, only score_1 and score_2 are considered.
  624. # When neither is True, the original order is yielded.
  625. buf_name_to_snode = {}
  626. name_to_fused_node = {}
  627. scores_0, scores_1, scores_2 = {}, {}, {}
  628. for idx, snode in enumerate(snodes):
  629. for buf_name in snode.get_buffer_names():
  630. buf_name_to_snode[buf_name] = snode
  631. for op_name in snode.get_operation_names():
  632. name_to_fused_node[op_name] = snode
  633. name_to_fused_node[snode.get_name()] = snode
  634. node_name = snode.get_name()
  635. scores_0[node_name] = sys.maxsize
  636. scores_1[node_name] = 0
  637. scores_2[node_name] = idx
  638. comm_idx = 0
  639. for snode in snodes:
  640. if raise_comms and contains_collective(snode):
  641. scores_0[snode.get_name()] = comm_idx
  642. for ancestor in snode.ancestors:
  643. anc_fused_name = name_to_fused_node[ancestor].get_name()
  644. scores_0[anc_fused_name] = min(scores_0[anc_fused_name], comm_idx)
  645. comm_idx += 1
  646. elif sink_waits and contains_wait(snode):
  647. scores_1[snode.get_name()] = 1
  648. class Runnable:
  649. def __init__(self, snode) -> None:
  650. self.snode = snode
  651. name = next(iter(snode.get_operation_names()))
  652. fused_name = name_to_fused_node[name].get_name()
  653. self.score = (
  654. scores_0[fused_name],
  655. scores_1[fused_name],
  656. scores_2[fused_name],
  657. )
  658. def __lt__(self, other):
  659. return self.score < other.score
  660. unmet_deps: dict[BaseSchedulerNode, OrderedSet[str]] = {
  661. snode: OrderedSet(dep.name for dep in snode.unmet_dependencies)
  662. for snode in snodes
  663. }
  664. ready: list[Runnable] = []
  665. buffer_users: dict[str, OrderedSet[BaseSchedulerNode]] = defaultdict(OrderedSet)
  666. snode_to_cost = {snode: estimate_op_runtime(snode) for snode in snodes}
  667. for snode, deps in unmet_deps.items():
  668. if len(deps) == 0:
  669. heapq.heappush(ready, Runnable(snode))
  670. for dep in deps:
  671. buffer_users[dep].add(snode)
  672. scheduled = []
  673. def schedule(snode):
  674. """
  675. Schedules `snode` and put all unblocked nodes onto the ready queue.
  676. """
  677. scheduled.append(snode)
  678. for buf_name in snode.get_buffer_names():
  679. for snode in buffer_users[buf_name]:
  680. unmet_deps[snode].remove(buf_name)
  681. if len(unmet_deps[snode]) == 0:
  682. heapq.heappush(ready, Runnable(snode))
  683. def get_overlapping_candidate():
  684. """
  685. Return the next node in the ready queue that's neither a collective or
  686. a wait.
  687. """
  688. candidates = [
  689. x
  690. for x in ready
  691. if not contains_collective(x.snode) and not contains_wait(x.snode)
  692. ]
  693. if len(candidates) == 0:
  694. return None
  695. return min(candidates, key=lambda x: x.score)
  696. def schedule_collective_for_overlap(snode):
  697. """
  698. Schedules collective node `snode`, along with one or more compute nodes
  699. to overlap with it. The strategy is described in the comment of
  700. `reorder_compute_for_overlap`.
  701. """
  702. assert contains_collective(snode)
  703. schedule(snode)
  704. collective_cost = snode_to_cost[snode]
  705. while (
  706. collective_cost > 0
  707. and (candidate := get_overlapping_candidate()) is not None
  708. ):
  709. ready.remove(candidate)
  710. schedule(candidate.snode)
  711. collective_cost -= snode_to_cost[candidate.snode]
  712. heapq.heapify(ready)
  713. while len(ready):
  714. snode = heapq.heappop(ready).snode
  715. if reorder_for_overlap and contains_collective(snode):
  716. schedule_collective_for_overlap(snode)
  717. else:
  718. schedule(snode)
  719. for snode, deps in unmet_deps.items():
  720. assert len(deps) == 0, (
  721. f"Detected unscheduled nodes. Nodes with unmet dependencies: {unmet_deps}"
  722. )
  723. return scheduled
  724. def decide_global_ordering_of_comms(
  725. nodes: list[BaseSchedulerNode], name_to_buf, name_to_fused_node
  726. ) -> list[BaseSchedulerNode]:
  727. """
  728. Decide global ordering of comms, by just enforcing the ordering that's in the input graph
  729. (might not be the same ordering as the eager mode program).
  730. TODO: Come up with a better approach
  731. """
  732. if not torch.distributed.is_available():
  733. return nodes
  734. comm_nodes = [n for n in nodes if contains_collective(n)]
  735. for i in range(1, len(comm_nodes)):
  736. # Enforce ordering by making previous comm a `WeakDep` dependency of the next comm
  737. mutating_buf = next(iter(comm_nodes[i].get_buffer_names()))
  738. for buf in comm_nodes[i - 1].get_buffer_names():
  739. comm_nodes[i].add_fake_dep(
  740. WeakDep(buf, mutating_buf=mutating_buf, is_fake=True)
  741. )
  742. return nodes
  743. @dataclass
  744. class SinkWaitInfo:
  745. grouped: int = 0
  746. grouped_info: str = ""
  747. moves: int = 0
  748. moves_info: str = ""
  749. limiting_factor: str = "None"
  750. def _sink_waits_iterative_internal(
  751. snodes: list[BaseSchedulerNode],
  752. ) -> tuple[list[BaseSchedulerNode], dict[BaseSchedulerNode, SinkWaitInfo]]:
  753. from torch._inductor.scheduler import GroupedSchedulerNode
  754. original_snodes_num = len(snodes)
  755. if original_snodes_num == 0:
  756. return snodes, {}
  757. graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
  758. graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
  759. (
  760. peak_memory,
  761. _curr_memory,
  762. snodes_allocfree,
  763. buf_to_snode_last_use,
  764. name_to_freeable_input_buf,
  765. ) = _initialize_memory_tracking(snodes, graph_inputs, graph_outputs)
  766. _prev, _next, _head = _initialize_double_linked_list(snodes)
  767. stats: dict[BaseSchedulerNode, SinkWaitInfo] = {}
  768. def _group_nodes(
  769. head: Optional[BaseSchedulerNode], tail: Optional[BaseSchedulerNode]
  770. ) -> list[BaseSchedulerNode]:
  771. ret = []
  772. n = head
  773. while True:
  774. if n is not None:
  775. ret.append(n)
  776. if n == tail:
  777. break
  778. n = _next[n] # type: ignore[index]
  779. return ret
  780. def _calculate_potential_peak_memory(
  781. candidate, group_ns, group_n_to_bufs_after_swap_dealloc_instead_of_candidate
  782. ):
  783. pre_group_mem = (
  784. _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
  785. )
  786. # Stash memory tracing updates to not recompute them after swap
  787. _post_alloc_update: dict[BaseSchedulerNode, int] = {}
  788. _size_free_delta_update: dict[BaseSchedulerNode, int] = {}
  789. potential_peak = 0
  790. if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
  791. # Not accounting for buffers liveliness change
  792. potential_peak = max(
  793. group_peak_memory + candidate_delta_mem,
  794. pre_group_mem + candidate_allocfree.size_alloc,
  795. )
  796. return potential_peak, _post_alloc_update, _size_free_delta_update
  797. candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
  798. _post_alloc_update[candidate] = candidate_post_alloc
  799. potential_peak = candidate_post_alloc
  800. candidate_size_free_to_move = sum(
  801. buf.mpi_buffer.size_free # type: ignore[attr-defined]
  802. for buf in itertools.chain.from_iterable(
  803. group_n_to_bufs_after_swap_dealloc_instead_of_candidate.values()
  804. )
  805. )
  806. _size_free_delta_update[candidate] = -candidate_size_free_to_move
  807. delta_mem = candidate_delta_mem + candidate_size_free_to_move
  808. for gn in gns:
  809. gn_post_alloc = _curr_memory[gn][0] + delta_mem
  810. _post_alloc_update[gn] = gn_post_alloc
  811. potential_peak = max(potential_peak, gn_post_alloc)
  812. gn_size_free_to_add = 0
  813. if gn in group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
  814. bufs = group_n_to_bufs_after_swap_dealloc_instead_of_candidate[gn]
  815. for buf in bufs:
  816. gn_size_free_to_add += buf.mpi_buffer.size_free
  817. _size_free_delta_update[gn] = gn_size_free_to_add
  818. delta_mem -= gn_size_free_to_add
  819. return potential_peak, _post_alloc_update, _size_free_delta_update
  820. def _perform_double_linked_list_swap(candidate, group_head, group_tail):
  821. # group_head_prev -0-> candidate -1-> group_head...group_tail -2-> candidate_next
  822. # 0:
  823. group_head_prev = _prev[group_head]
  824. if group_head_prev:
  825. _next[group_head_prev] = candidate
  826. _prev[candidate] = group_head_prev
  827. # 2:
  828. candidate_next = _next[candidate]
  829. if candidate_next:
  830. _prev[candidate_next] = group_tail
  831. _next[group_tail] = candidate_next
  832. # 1:
  833. _prev[group_head] = candidate
  834. _next[candidate] = group_head
  835. nonlocal _head
  836. if group_head == _head:
  837. _head = candidate
  838. def _update_memory_tracking_after_swap(
  839. candidate,
  840. gns,
  841. group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
  842. _post_alloc_update,
  843. _size_free_delta_update,
  844. ):
  845. group_head = gns[0]
  846. pre_group_mem = (
  847. _curr_memory[group_head][0] - snodes_allocfree[group_head].size_alloc
  848. )
  849. if not group_n_to_bufs_after_swap_dealloc_instead_of_candidate:
  850. candidate_post_alloc = pre_group_mem + candidate_allocfree.size_alloc
  851. _curr_memory[candidate] = (
  852. candidate_post_alloc,
  853. candidate_post_alloc - candidate_allocfree.size_free,
  854. )
  855. for gn in gns:
  856. cm = _curr_memory[gn]
  857. _curr_memory[gn] = (
  858. cm[0] + candidate_delta_mem,
  859. cm[1] + candidate_delta_mem,
  860. )
  861. return
  862. for n in [candidate, *gns]:
  863. post_alloc = _post_alloc_update[n]
  864. snodes_allocfree[n].size_free += _size_free_delta_update[n]
  865. _curr_memory[n] = (
  866. post_alloc,
  867. post_alloc - snodes_allocfree[n].size_free,
  868. )
  869. curr = snodes[-1]
  870. processed_waits = OrderedSet() # type: ignore[var-annotated]
  871. debug_iterative_memory_recompute = config.reorder_iterative_debug_memory_recompute
  872. debug_num_sink_waits_to_reorder: Optional[int] = (
  873. config.sink_waits_iterative_debug_limit_to_sink
  874. )
  875. iterative_recompute_error = False
  876. while _prev[curr] is not None:
  877. if iterative_recompute_error:
  878. break
  879. if (
  880. debug_num_sink_waits_to_reorder is not None
  881. and len(processed_waits) >= debug_num_sink_waits_to_reorder
  882. ):
  883. break
  884. if contains_wait(curr) and curr not in processed_waits:
  885. processed_waits.add(curr)
  886. info = stats[curr] = SinkWaitInfo()
  887. candidate = _next[curr]
  888. wait_snode = curr
  889. group_head = curr
  890. group_tail = curr
  891. group_peak_memory = _curr_memory[curr][0]
  892. while candidate is not None:
  893. if iterative_recompute_error:
  894. break
  895. gns: list[BaseSchedulerNode] = _group_nodes(group_head, group_tail)
  896. group = GroupedSchedulerNode(
  897. wait_snode.scheduler,
  898. gns,
  899. temp_grouping=True,
  900. )
  901. # We can have multiple deps with the same name.
  902. # As we ignore WeakDep(is_fake=True) =>
  903. # filter them out first to avoid overwriting of real dep.
  904. data_deps = {
  905. d.name: d
  906. for d in candidate.unmet_dependencies
  907. if not _is_fake_dep(d)
  908. }
  909. group_outs = group.get_outputs()
  910. data_dep = None
  911. for o in group_outs:
  912. if d := data_deps.get(o.get_name(), None):
  913. data_dep = d
  914. break
  915. # 1. If we have data_dep - we can not swap => trying to group
  916. # 2. If swap candidate and current node both contain collectives => trying to group
  917. if data_dep is not None or (
  918. both_contain_comms := (
  919. contains_collective(group) and contains_collective(candidate)
  920. )
  921. ):
  922. def is_groupable(snode):
  923. # We do not want to group with collectives to not reorder them forward.
  924. if contains_collective(snode):
  925. return (
  926. False,
  927. f"candidate contains collective {snode.get_name()}",
  928. )
  929. if contains_gemm_like(snode):
  930. return (
  931. False,
  932. f"candidate contains gemm_like {snode.get_name()}",
  933. )
  934. return True, None
  935. is_grp, grp_reason = is_groupable(candidate)
  936. if is_grp:
  937. group_tail = candidate
  938. group_peak_memory = max(
  939. group_peak_memory, _curr_memory[candidate][0]
  940. )
  941. info.grouped += 1
  942. info.grouped_info = _group_names(gns)
  943. candidate = _next[candidate]
  944. continue
  945. elif (data_dep is None) and both_contain_comms:
  946. info.limiting_factor = (
  947. f"collective ordering {_group_names(gns)}"
  948. f" with candidate:{candidate.get_name()}"
  949. )
  950. break
  951. else:
  952. info.limiting_factor = (
  953. f"data dependency {data_dep}(dep_names:{list(data_deps.keys())})"
  954. f"\n candidate:{candidate.get_name()}(os:{[candidate.get_buffer_names()]})"
  955. f"dep on {gns}"
  956. f"\n outs:{[o.get_name() for o in group_outs]}"
  957. f"\n non_group_reason:{grp_reason}"
  958. )
  959. break
  960. candidate_allocfree: SNodeMemory = snodes_allocfree[candidate]
  961. candidate_delta_mem = (
  962. candidate_allocfree.size_alloc - candidate_allocfree.size_free
  963. )
  964. # [group] candidate -> candidate [group]
  965. # Check for buffers with successors in group and candidate last successor
  966. #
  967. # Buf that changes its last use snode,
  968. # It was deallocated by candidate,
  969. # but after swap it will be deallocated by group node.
  970. group_n_to_bufs_after_swap_dealloc_instead_of_candidate: dict[
  971. BaseSchedulerNode, list[Union[FreeableInputBuffer, SchedulerBuffer]]
  972. ] = defaultdict(list)
  973. for (
  974. buf,
  975. snode_last_use,
  976. ) in buf_to_snode_last_use.items():
  977. succ_nodes = buf.mpi_buffer.succ_nodes
  978. if snode_last_use != candidate: # noqa: E711
  979. continue
  980. # candidate is last use of buf
  981. last_succ_gn = None
  982. for gn in gns:
  983. if gn in succ_nodes:
  984. last_succ_gn = gn
  985. if last_succ_gn is None:
  986. continue
  987. # gn has successors of buf that after potential swap will become
  988. # last use of buf and start deallocating buf instead of candidate
  989. group_n_to_bufs_after_swap_dealloc_instead_of_candidate[
  990. last_succ_gn
  991. ].append(buf)
  992. potential_peak, _post_alloc_update, _size_free_delta_update = (
  993. _calculate_potential_peak_memory(
  994. candidate,
  995. gns,
  996. group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
  997. )
  998. )
  999. if potential_peak > peak_memory:
  1000. info.limiting_factor = (
  1001. f"peak memory new:{potential_peak} vs base:{peak_memory}"
  1002. )
  1003. break
  1004. info.moves += 1
  1005. info.moves_info += f"+{candidate.get_name()}"
  1006. _perform_double_linked_list_swap(candidate, group_head, group_tail)
  1007. _update_memory_tracking_after_swap(
  1008. candidate,
  1009. gns,
  1010. group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
  1011. _post_alloc_update,
  1012. _size_free_delta_update,
  1013. )
  1014. if debug_iterative_memory_recompute:
  1015. from .comms_debug import _debug_iterative_memory_recompute
  1016. iterative_recompute_error = _debug_iterative_memory_recompute(
  1017. candidate,
  1018. gns,
  1019. _group_names(gns),
  1020. _group_nodes(_head, None),
  1021. name_to_freeable_input_buf,
  1022. graph_outputs,
  1023. peak_memory,
  1024. _curr_memory,
  1025. snodes_allocfree,
  1026. "sink_waits_iterative",
  1027. group_n_to_bufs_after_swap_dealloc_instead_of_candidate,
  1028. )
  1029. if iterative_recompute_error:
  1030. break
  1031. candidate = _next[group_tail]
  1032. curr = _prev[curr] # type: ignore[assignment]
  1033. headers = [
  1034. "Wait node",
  1035. "grouped",
  1036. "grouped_info",
  1037. "moves",
  1038. "moves_info",
  1039. "limiting factor",
  1040. ]
  1041. rows = [
  1042. [
  1043. node_summary(snode),
  1044. info.grouped,
  1045. info.grouped_info,
  1046. info.moves,
  1047. info.moves_info,
  1048. info.limiting_factor,
  1049. ]
  1050. for snode, info in stats.items()
  1051. ]
  1052. log_str = ""
  1053. if importlib.util.find_spec("tabulate"):
  1054. from tabulate import tabulate
  1055. log_str += tabulate(
  1056. rows,
  1057. headers=headers,
  1058. )
  1059. else:
  1060. log_str += "Please `pip install tabulate` to nicely render overlap stats.\n"
  1061. log_str += str(headers) + "\n"
  1062. log_str += "\n".join(map(str, rows))
  1063. overlap_log.info(log_str)
  1064. new_snodes = _group_nodes(_head, None)
  1065. assert len(new_snodes) == original_snodes_num
  1066. new_peak_memory, _, _, _ = estimate_peak_memory_allocfree(
  1067. new_snodes, name_to_freeable_input_buf, graph_outputs
  1068. )
  1069. log_str += f"\n sink_waits_iterative peak_memory_before:{peak_memory}"
  1070. log_str += f"\n sink_waits_iterative peak_memory_after:{new_peak_memory}"
  1071. trace_structured(
  1072. "artifact",
  1073. metadata_fn=lambda: {
  1074. "name": "sink_waits_iterative_info",
  1075. "encoding": "string",
  1076. },
  1077. payload_fn=lambda: log_str,
  1078. )
  1079. return new_snodes, stats
  1080. def sink_waits_iterative(
  1081. snodes: list[BaseSchedulerNode],
  1082. ) -> list[BaseSchedulerNode]:
  1083. return _sink_waits_iterative_internal(snodes)[0]
  1084. def estimate_op_runtime(snode: BaseSchedulerNode) -> float:
  1085. """
  1086. Returns estimated op runtime in nanoseconds (ns)
  1087. """
  1088. if config.estimate_op_runtime == "default":
  1089. runtime = snode.get_estimated_runtime()
  1090. else:
  1091. assert callable(config.estimate_op_runtime)
  1092. runtime = config.estimate_op_runtime(snode)
  1093. return runtime
  1094. def node_summary(snode):
  1095. snodes = snode.get_nodes()
  1096. if len(snodes) == 1:
  1097. detail = ""
  1098. if isinstance(snode.node, (ir.ExternKernelOut, ir._CollectiveKernel)):
  1099. outs_str = f"outs:{[o.get_name() for o in snode.get_outputs()]}"
  1100. ins_str = f"ins:{[d.name for d in snode.unmet_dependencies]}"
  1101. detail = f" {snode.get_name()} ({snode.node.python_kernel_name})\n {outs_str}\n ({ins_str})"
  1102. layouts = [child.node.get_output_spec() for child in snode.get_nodes()]
  1103. out_tensor_info = ",".join(
  1104. [
  1105. f" (size={layout.size}, stride={layout.stride})"
  1106. if isinstance(layout, ir.Layout)
  1107. else ""
  1108. for layout in layouts
  1109. ]
  1110. )
  1111. try:
  1112. node_name = snode.node.maybe_get_name()
  1113. except AttributeError:
  1114. # TODO: node_summary was written without FusedSchedulerNode in mind, generally needs to be hardened
  1115. node_name = ""
  1116. return f"{snode.node.__class__.__name__}{detail}{out_tensor_info} ({node_name} ({snode.get_estimated_runtime():.0f} ns)"
  1117. # Flatten the summaries for Fused/Foreach/Grouped nodes
  1118. summaries = []
  1119. for child_snode in snodes:
  1120. summaries.append(node_summary(child_snode))
  1121. return f"{snode.__class__.__name__}: {', '.join(summaries)}"
  1122. def visualize_overlap(order):
  1123. # TODO - this function probably doesn't do a very good job estimating the runtime because it doesn't carefully model
  1124. # streams and overlap. For now its mostly useful as a debug visualization.
  1125. total_est_runtime: float = 0.0
  1126. cur_comm_node = None
  1127. def step_log(step, msg):
  1128. overlap_log.debug(f"{step:>6}: {msg}") # noqa: G004
  1129. for step, snode in enumerate(order):
  1130. if cur_comm_node is None:
  1131. if contains_collective(snode):
  1132. total_est_runtime += estimate_op_runtime(snode)
  1133. cur_comm_node = snode.node
  1134. elif is_wait(snode.node):
  1135. # raise AssertionError(
  1136. # "Wait is not expected when there is no collective running"
  1137. # )
  1138. pass
  1139. else: # exposed compute op
  1140. total_est_runtime += estimate_op_runtime(snode)
  1141. step_log(step, f"{node_summary(snode)}")
  1142. else: # cur_comm_node is not None
  1143. if contains_collective(snode):
  1144. total_est_runtime += estimate_op_runtime(snode)
  1145. cur_comm_node = snode.node
  1146. step_log(step, f"{node_summary(snode)}") # noqa: G004
  1147. elif is_wait(snode.node): # end of this comm op
  1148. step_log(step, f"{node_summary(snode)}")
  1149. cur_comm_node = None
  1150. else: # overlapped compute op
  1151. step_log(step, f"| {node_summary(snode)}")
  1152. overlap_log.debug(
  1153. f"Est. runtime (ms): {total_est_runtime / 1000 / 1000}" # noqa: G004
  1154. )
  1155. def reorder_compute_and_comm_for_overlap(
  1156. snodes: list[BaseSchedulerNode],
  1157. ) -> list[BaseSchedulerNode]:
  1158. order = snodes
  1159. graph_inputs: OrderedSet[str] = OrderedSet(V.graph.graph_inputs.keys())
  1160. graph_outputs: OrderedSet[str] = OrderedSet(V.graph.get_output_names())
  1161. for p in config.reorder_for_compute_comm_overlap_passes:
  1162. if isinstance(p, str) and p in globals():
  1163. p = globals()[p] # it is a builtin pass
  1164. assert callable(p), (
  1165. f"Invalid reorder_compute_and_comm_for_overlap pass: {p} is not callable"
  1166. )
  1167. peak_memory, _ = estimate_peak_memory(
  1168. snodes, get_freeable_input_buf(snodes, graph_inputs), graph_outputs
  1169. )
  1170. if torch.distributed.get_rank() == 0:
  1171. overlap_log.debug(
  1172. f"==== Visualize overlap before reordering pass {p}, {peak_memory=} ====" # noqa: G004
  1173. )
  1174. try:
  1175. visualize_overlap(order)
  1176. except Exception as e:
  1177. overlap_log.debug("", exc_info=e)
  1178. t0 = time.time()
  1179. order = p(order) # type: ignore[operator]
  1180. t = time.time() - t0
  1181. if torch.distributed.get_rank() == 0:
  1182. overlap_log.debug(
  1183. f"==== Visualize overlap after reordering pass {p} (ran in {t} sec)====" # noqa: G004
  1184. )
  1185. try:
  1186. visualize_overlap(order)
  1187. except Exception as e:
  1188. overlap_log.debug("", exc_info=e)
  1189. peak_memory, _ = estimate_peak_memory(
  1190. snodes, get_freeable_input_buf(snodes, graph_inputs), graph_outputs
  1191. )
  1192. print(f"final {peak_memory=}")
  1193. return order
def remove_fsdp2_unsharded_param_graph_input_usage(graph: torch.fx.Graph):
    """
    This FX graph pass replaces uses of FSDP2 unsharded params with their corresponding
    graph intermediates that were fsdp.copy_ into the unsharded params in the original graph.

    NOTE: Can only apply this pass to any of the FSDP2 unsharded params that have this pattern
    (or repetition of): `resize_(full) -> copy_ -> resize_(0)`. Because of this, for partial-graph case
    where `resize_(full) -> copy_` is in one graph and `resize_(0)` is in another graph, we can't
    remove these resize and copy ops and thus we will have worse performance there.

    In other words, "do we try to remove all the resize_(full) -> copy_ -> resize_(0) nodes for this unsharded param"
    is actually a per-unsharded-param decision, since for each unsharded param, we look at its resize sequence pattern
    (in `check_resize_pattern()`) to determine if its set of resize and copy nodes can be removed.

    Steps performed (the graph is mutated in place; nothing is returned):
      1. Record, per graph input, the node indices of `inductor.resize_storage_bytes_`
         calls, split into resize-to-full (new size > 0) and resize-to-0.
      2. Collect `fsdp.copy_(unsharded_param, Y)` node indices for params whose
         resize sequence passes `check_resize_pattern()`.
      3. Assert that no other mutable op writes to any of those params (checked
         by storage, so aliases sharing the storage are included).
      4. Between consecutive `fsdp.copy_` nodes of a param, rewrite positional
         args referring to the param to the corresponding `Y`.
      5. Erase those `fsdp.copy_` nodes and every resize node of those params.
    """
    node_list = list(graph.nodes)

    # Find all graph inputs and their resize counts
    graph_input_to_resized_to_full_node_idxes = defaultdict(list)
    graph_input_to_resized_to_0_node_idxes = defaultdict(list)
    for idx, node in enumerate(node_list):
        if (
            node.op == "call_function"
            and node.target == torch.ops.inductor.resize_storage_bytes_.default
        ):
            assert node.args[0].op == "placeholder", f"""\
Resize can only operate on graph inputs, but got {node} which is resizing non-graph-input {node.args[0]}
"""
            graph_input = node.args[0]
            new_size = node.args[1]
            if new_size > 0:
                graph_input_to_resized_to_full_node_idxes[graph_input].append(idx)
            else:
                graph_input_to_resized_to_0_node_idxes[graph_input].append(idx)

    def check_resize_pattern(graph_input):
        # Check the number of resize-to-full and resize-to-0 nodes are equal,
        # and that for each (resize-to-full, resize-to-0) pair, the resize-to-full node
        # always happens before the resize-to-0 node.
        # This is the precondition for being able to remove all the resize and copy nodes
        # for this specific unsharded param.
        # Returns False (after logging a warning) when the pattern does not hold.
        resized_to_full_idxes = graph_input_to_resized_to_full_node_idxes.get(
            graph_input, []
        )
        resized_to_0_idxes = graph_input_to_resized_to_0_node_idxes.get(graph_input, [])

        if not len(resized_to_full_idxes) == len(resized_to_0_idxes):
            log.warning(
                f"""
Unequal number of resize-to-full and resize-to-0 nodes for graph input {graph_input}:
{len(resized_to_full_idxes)} vs. {len(resized_to_0_idxes)}.
Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass.
"""  # noqa: G004
            )
            return False

        # Check the sequence: (resize_to_full -> resize_to_0)+
        # Pairs are formed positionally: the k-th resize-to-full with the k-th resize-to-0.
        for resize_to_full_idx, resize_to_0_idx in zip(
            resized_to_full_idxes, resized_to_0_idxes
        ):
            if resize_to_full_idx >= resize_to_0_idx:
                log.warning(
                    f"""
For graph input {graph_input}: resize-to-full node {node_list[resize_to_full_idx]} at index {resize_to_full_idx}
happens after resize-to-0 node {node_list[resize_to_0_idx]} at index {resize_to_0_idx}.
Skipping `remove_fsdp2_unsharded_param_graph_input_usage` FX graph pass for that unsharded param.
"""  # noqa: G004
                )
                return False
        return True

    # Find all eligible unsharded params and their corresponding graph intermediates.
    # NOTE(review): check_resize_pattern runs once per fsdp.copy_ node, so a param
    # with several copies and a bad pattern logs the same warning several times.
    unsharded_param_to_fsdp_copy_node_idxes = defaultdict(list)
    for idx, node in enumerate(node_list):
        if node.op == "call_function" and node.target == torch.ops.fsdp.copy_.default:
            fsdp_copy_node = node
            unsharded_param = node.args[0]
            assert unsharded_param.op == "placeholder", f"""
Assumed all FSDP2 `unsharded_param`s to be graph input, but it's not true!
Offending node: {unsharded_param}. Graph: {graph}
"""
            if check_resize_pattern(unsharded_param):
                unsharded_param_to_fsdp_copy_node_idxes[unsharded_param].append(idx)

    def is_allowed_mutation(node):
        # The only mutations this pass expects on unsharded params.
        return (
            node.target == torch.ops.fsdp.copy_.default
            or node.target == torch.ops.inductor.resize_storage_bytes_.default
        )

    def is_node_mutating_unsharded_param_or_its_alias(node, unsharded_params):
        # Check whether the node is mutating any of the unsharded params or their aliases.
        # Written args come from the op schema's alias annotations (is_write); they
        # are compared by untyped storage, so any tensor sharing a param's storage counts.
        # NOTE(review): `node.args[i]` assumes written args are passed positionally;
        # args passed as kwargs are not inspected here.
        mutated_arg_idxes = (
            [
                i
                for i, x in enumerate(node.target._schema.arguments)
                if x.alias_info is not None and x.alias_info.is_write
            ]
            if isinstance(node.target, torch._ops.OpOverload)
            else []
        )
        mutated_node_arg_storages = OrderedSet(
            [
                StorageWeakRef(node.args[i].meta["val"].untyped_storage())
                for i in mutated_arg_idxes
            ]
        )
        storages_of_unsharded_params = OrderedSet(
            [
                StorageWeakRef(unsharded_param.meta["val"].untyped_storage())
                for unsharded_param in unsharded_params
            ]
        )
        return len(mutated_node_arg_storages & storages_of_unsharded_params) > 0

    # Check no user mutation on any unsharded_param
    for node in node_list:
        if (
            node.op == "call_function"
            and isinstance(node.target, torch._ops.OpOverload)
            and node.target._schema.is_mutable
            and not is_allowed_mutation(node)
        ):
            assert not is_node_mutating_unsharded_param_or_its_alias(
                node, unsharded_param_to_fsdp_copy_node_idxes.keys()
            ), f"""\
User mutation on FSDP2 unsharded param is not allowed when Traceable FSDP2 is used. Violating node: {node}
"""

    # For each `fsdp.copy_(unsharded_param, Y)`, replace downstream usage of `unsharded_param` with `Y`.
    #
    # NOTE: Because of "layer reuse" use case, there could be multiple `fsdp.copy_` to the same `unsharded_param` graph input.
    # e.g.
    # ```
    #     fsdp_copy_1 = fsdp.copy_(unsharded_param_1, Y1)
    #     ... (use of unsharded_param_1)                     -> Subgraph 1
    #     fsdp_copy_2 = fsdp.copy_(unsharded_param_1, Y2)
    #     ... (use of unsharded_param_1)                     -> Subgraph 2
    #     fsdp_copy_3 = fsdp.copy_(unsharded_param_1, Y3)
    #     ... (use of unsharded_param_1)                     -> Subgraph 3
    # ```
    # We must do the replacement only within each subgraph.
    for (
        unsharded_param,
        fsdp_copy_node_idxes,
    ) in unsharded_param_to_fsdp_copy_node_idxes.items():
        for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
            fsdp_copy_node = node_list[fsdp_copy_node_idx]
            assert fsdp_copy_node.args[0] is unsharded_param
            _, replacement = fsdp_copy_node.args
            # subgraph_start_idx is exclusive
            subgraph_start_idx = fsdp_copy_node_idx + 1
            # subgraph_end_idx is exclusive (also intentionally don't replace args in return op)
            # The last subgraph runs up to (but excluding) the final node of node_list.
            subgraph_end_idx = (
                fsdp_copy_node_idxes[i + 1]
                if i < len(fsdp_copy_node_idxes) - 1
                else len(node_list) - 1
            )
            subgraph_nodes = node_list[subgraph_start_idx:subgraph_end_idx]
            assert not any(
                is_node_mutating_unsharded_param_or_its_alias(node, [unsharded_param])
                for node in subgraph_nodes
            ), f"""\
Assumed no ops mutating unsharded param {unsharded_param} in subgraph {subgraph_nodes}, but it's not true!
Graph: {graph}
"""
            for node in subgraph_nodes:
                # Resize nodes keep referring to the param; they are erased below.
                if (
                    node.op == "call_function"
                    and unsharded_param in node.args
                    and node.target != torch.ops.inductor.resize_storage_bytes_.default
                ):  # TODO(yf225): implement replacement in kwargs
                    new_args = tuple(
                        replacement if arg is unsharded_param else arg
                        for arg in node.args
                    )
                    node.args = new_args

    # Delete `fsdp.copy_(unsharded_param, Y)` nodes
    for (
        unsharded_param,
        fsdp_copy_node_idxes,
    ) in unsharded_param_to_fsdp_copy_node_idxes.items():
        for i, fsdp_copy_node_idx in enumerate(fsdp_copy_node_idxes):
            fsdp_copy_node = node_list[fsdp_copy_node_idx]
            graph.erase_node(fsdp_copy_node)

    # Delete `resize_(unsharded_param, ...)` nodes
    # Only for params that passed check_resize_pattern (the dict's keys).
    for node in node_list:
        if (
            node.op == "call_function"
            and node.target == torch.ops.inductor.resize_storage_bytes_.default
            and node.args[0] in unsharded_param_to_fsdp_copy_node_idxes
        ):
            graph.erase_node(node)
  1375. def reinplace_fsdp_all_gather(graph: torch.fx.Graph) -> None:
  1376. try:
  1377. import torch.distributed.fsdp._fully_shard._fsdp_collectives
  1378. assert torch.distributed.is_available()
  1379. # Assert existence of these ops
  1380. assert (
  1381. torch.ops._c10d_functional.all_gather_into_tensor
  1382. and torch.ops._c10d_functional.all_gather_into_tensor_out
  1383. )
  1384. except (ImportError, AttributeError, AssertionError):
  1385. return
  1386. from .pattern_matcher import (
  1387. CallFunction,
  1388. KeywordArg,
  1389. Match,
  1390. PatternMatcherPass,
  1391. register_graph_pattern,
  1392. )
  1393. """
  1394. all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
  1395. getitem = all_gather_copy_in[0];
  1396. (getitem_1 = all_gather_copy_in[1];) # optional
  1397. all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor.default(getitem, ...);
  1398. ->
  1399. all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(...);
  1400. getitem = all_gather_copy_in[0];
  1401. getitem_1 = all_gather_copy_in[1];
  1402. all_gather_into_tensor = torch.ops._c10d_functional.all_gather_into_tensor_out.default(getitem, ..., out=getitem_1);
  1403. """
  1404. def remove_unused_getitem(g):
  1405. # Remove `getitem_X = all_gather_copy_in[1]` which is never used.
  1406. node_list = list(g.nodes)
  1407. for n in node_list:
  1408. if (
  1409. n.target == operator.getitem
  1410. and n.args[0].target is torch.ops.fsdp.all_gather_copy_in.default
  1411. and n.args[1] == 1
  1412. ):
  1413. g.erase_node(n)
  1414. graph_pass = PatternMatcherPass()
  1415. @register_graph_pattern(
  1416. CallFunction(
  1417. torch.ops._c10d_functional.all_gather_into_tensor.default,
  1418. CallFunction(
  1419. operator.getitem,
  1420. CallFunction(
  1421. torch.ops.fsdp.all_gather_copy_in.default,
  1422. KeywordArg("all_gather_inputs"),
  1423. KeywordArg("all_gather_output"),
  1424. KeywordArg("inp_split_sizes"),
  1425. KeywordArg("all_gather_input_numel"),
  1426. KeywordArg("rank"),
  1427. ),
  1428. KeywordArg("item_idx"),
  1429. ),
  1430. KeywordArg("group_size"),
  1431. KeywordArg("group_name"),
  1432. ),
  1433. pass_dict=graph_pass,
  1434. extra_check=lambda match: match.kwargs["item_idx"] == 0,
  1435. )
  1436. def reinplace_all_gather(match: Match, *args, **kwargs):
  1437. def repl(
  1438. *args,
  1439. ):
  1440. copy_in_args = args[:-2]
  1441. group_size = args[-2]
  1442. group_name = args[-1]
  1443. all_gather_copy_in = torch.ops.fsdp.all_gather_copy_in.default(
  1444. *copy_in_args
  1445. )
  1446. getitem = all_gather_copy_in[0]
  1447. getitem_1 = all_gather_copy_in[1]
  1448. all_gather_into_tensor = (
  1449. torch.ops._c10d_functional.all_gather_into_tensor_out.default(
  1450. getitem, group_size, group_name, out=getitem_1
  1451. )
  1452. )
  1453. return all_gather_into_tensor
  1454. match.replace_by_example(
  1455. repl,
  1456. [
  1457. kwargs["all_gather_inputs"],
  1458. kwargs["all_gather_output"],
  1459. kwargs["inp_split_sizes"],
  1460. kwargs["all_gather_input_numel"],
  1461. kwargs["rank"],
  1462. kwargs["group_size"],
  1463. kwargs["group_name"],
  1464. ],
  1465. )
  1466. remove_unused_getitem(graph)
  1467. graph_pass.apply(graph) # type: ignore[arg-type]
  1468. def get_op_idx(snode):
  1469. assert not isinstance(
  1470. snode,
  1471. (
  1472. torch._inductor.scheduler.FusedSchedulerNode,
  1473. torch._inductor.scheduler.GroupedSchedulerNode,
  1474. ),
  1475. )
  1476. return int(snode.get_name()[2:])
def enforce_comm_ordering_for_fsdp(
    snodes: list[torch._inductor.scheduler.BaseSchedulerNode],
    name_to_buf: dict[str, torch._inductor.scheduler.SchedulerBuffer],
    name_to_fused_node: dict[str, BaseSchedulerNode],
) -> list[torch._inductor.scheduler.BaseSchedulerNode]:
    """Group FSDP all-gather / reduce-scatter ops and serialize them across layers.

    All-gather grouping:
        Applies to every ``all_gather_into_tensor_out`` node with an
        ``fsdp.all_gather_copy_in`` ancestor. Its related nodes are collected
        and sorted by op index. They are then split at the first
        ``ir._WaitKernel`` into two ``GroupedSchedulerNode``s: one for
        "cast + copy_in + getitem + all_gather" and one for
        "all_gather wait + copy_out".

    Reduce-scatter grouping:
        Every ``fsdp.chunk_cat`` node and its recursive users are split the
        same way. This gives a "copy-in + reduce_scatter" group and a
        "wait + outputs" group.

    The returned order is ``snodes`` with each grouped member replaced by its
    group node. The group node is emitted once, at the position of its first
    member. Weak fake dependencies are then added so that each all-gather
    group depends on the outputs of the previous all-gather's wait group.
    "Previous" means in group-creation order. Reduce-scatter groups are
    chained the same way.

    Args:
        snodes: scheduler nodes in their current order.
        name_to_buf: buffer name -> SchedulerBuffer; forwarded to the recursive
            dependency/user search helpers.
        name_to_fused_node: node name -> (possibly fused) scheduler node.

    Returns:
        The new list of scheduler nodes (group nodes in place of their members).

    Raises:
        AssertionError: if no group was created at all, or if a related-node
            block contains no wait kernel after its first node.
    """
    from . import scheduler

    new_order: list[BaseSchedulerNode] = []
    scheduled = OrderedSet[Any]()
    ag_exists = False
    rs_exists = False
    # Maps "copy_in + collective" group -> its "wait + ..." group; insertion order
    # is the order the collectives were encountered in `snodes`.
    ag_grouped_node_to_wait_grouped_node = {}
    rs_grouped_node_to_wait_grouped_node = {}
    # Member (and group) name -> the group node that replaces it in the schedule.
    snode_name_to_final_snode = {}

    def _create_group_node(snodes_to_group):
        group_node = scheduler.GroupedSchedulerNode.create(snodes_to_group)
        for snode in snodes_to_group:
            snode_name_to_final_snode[snode.get_name()] = group_node
        snode_name_to_final_snode[group_node.get_name()] = group_node
        return group_node

    # Create grouped nodes for specific sets of ops
    for snode in snodes:
        # Case 1: Handle AllGather
        if is_collective(
            snode.node, op=torch.ops._c10d_functional.all_gather_into_tensor_out.default
        ) and any(
            is_fallback_op(
                name_to_fused_node[x].node, torch.ops.fsdp.all_gather_copy_in.default
            )
            for x in snode.ancestors
        ):
            ag_exists = True
            ag_snode = snode
            ag_related_snode_set: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet()

            # Find the "cast + copy_in + getitem + all_gather" code block
            find_recursive_deps_of_node(
                ag_snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # Find the "all_gather + all_gather_wait_tensor + copy_out" code block
            allowed_ops = OrderedSet(
                [
                    torch.ops._c10d_functional.all_gather_into_tensor_out.default,
                    torch.ops._c10d_functional.wait_tensor.default,
                    torch.ops.fsdp.split_with_sizes_copy.default,
                ]
            )
            # NOTE(review): presumably the traversal stops at nodes for which
            # criteria_cb returns True, i.e. anything that is neither a no-op
            # kernel nor one of `allowed_ops` -- confirm in find_recursive_users_of_node.
            find_recursive_users_of_node(
                ag_snode,
                ag_related_snode_set,
                name_to_buf,
                name_to_fused_node,
                criteria_cb=lambda x: not (
                    isinstance(x, scheduler.NopKernelSchedulerNode)
                    or (
                        isinstance(x, scheduler.ExternKernelSchedulerNode)
                        and x.node.op_overload in allowed_ops  # type: ignore[union-attr]
                    )
                ),
            )

            # sort nodes by original operation order
            ag_related_snodes = sorted(
                ag_related_snode_set, key=lambda x: get_op_idx(x)
            )

            # In the "reuse layer" case, some ops in the 2nd all-gather code block could also
            # depend on ops in the 1st all-gather code block, and we don't want to group them together.
            # The block is therefore cut just before the second copy-out
            # (split_with_sizes_copy) encountered in op order.
            end_idx_of_current_ag_block = len(ag_related_snodes)
            copy_out_count = 0
            for i in range(len(ag_related_snodes)):
                cur_snode = ag_related_snodes[i]
                if is_fallback_op(
                    cur_snode.node, torch.ops.fsdp.split_with_sizes_copy.default
                ):
                    copy_out_count += 1
                if copy_out_count > 1:
                    end_idx_of_current_ag_block = i
                    break

            ag_related_snodes = ag_related_snodes[:end_idx_of_current_ag_block]

            # Group "cast + copy_in + getitem + all_gather" into one GroupedSchedulerNode
            # The split point is the first wait kernel after index 0.
            wait_node_idx = None
            for i in range(len(ag_related_snodes) - 1):
                if isinstance(ag_related_snodes[i + 1].node, ir._WaitKernel):
                    wait_node_idx = i + 1
                    break
            assert wait_node_idx is not None
            ag_group_node = _create_group_node(ag_related_snodes[:wait_node_idx])

            # Group "all_gather_wait_tensor + copy_out" into one GroupedSchedulerNode
            ag_wait_group_node = _create_group_node(ag_related_snodes[wait_node_idx:])

            ag_grouped_node_to_wait_grouped_node[ag_group_node] = ag_wait_group_node

        # Case 2: Handle ReduceScatter
        elif is_fallback_op(snode.node, torch.ops.fsdp.chunk_cat.default):
            rs_exists = True
            rs_snode = snode

            # Find the "reduce_scatter copy-in + reduce_scatter comm + reduce_scatter wait" code block
            rs_related_snode_set: OrderedSet[scheduler.BaseSchedulerNode] = OrderedSet()
            find_recursive_users_of_node(
                rs_snode,
                rs_related_snode_set,
                name_to_buf,
                name_to_fused_node,
            )

            # sort nodes by original operation order
            rs_related_snodes = sorted(
                rs_related_snode_set, key=lambda x: get_op_idx(x)
            )

            # Group "reduce_scatter copy-in + reduce_scatter comm" into one GroupedSchedulerNode
            wait_node_idx = None
            for i in range(len(rs_related_snodes) - 1):
                if isinstance(rs_related_snodes[i + 1].node, ir._WaitKernel):
                    wait_node_idx = i + 1
                    break
            assert wait_node_idx is not None
            rs_group_node = _create_group_node(rs_related_snodes[:wait_node_idx])

            # Group "reduce_scatter wait + related output nodes" into one GroupedSchedulerNode
            rs_wait_group_node = _create_group_node(rs_related_snodes[wait_node_idx:])

            rs_grouped_node_to_wait_grouped_node[rs_group_node] = rs_wait_group_node

    # Sanity checks: at least one group was formed, and each detected collective
    # kind produced at least one (group, wait-group) pair.
    assert len(snode_name_to_final_snode) > 0
    if ag_exists:
        assert len(ag_grouped_node_to_wait_grouped_node) > 0
    if rs_exists:
        assert len(rs_grouped_node_to_wait_grouped_node) > 0

    # Build the new node schedule, taking GroupedSchedulerNode into account
    # Each member is replaced by its group node; the group is emitted only at
    # its first member's position.
    for snode in snodes:
        if snode.get_name() in snode_name_to_final_snode:
            snode = snode_name_to_final_snode[snode.get_name()]
        if snode in scheduled:
            continue
        new_order.append(snode)
        scheduled.add(snode)

    # Enforce AllGather ordering: previous AllGather's "wait then copy_out" group node must run
    # before next AllGather's "copy_in then AG" group node
    prev_ag_wait = None
    for ag_group_node, wait_group_node in ag_grouped_node_to_wait_grouped_node.items():
        if prev_ag_wait is not None:
            # The first buffer name of the dependent group is used as the
            # WeakDep's mutating_buf.
            mutating_buf = next(iter(ag_group_node.get_buffer_names()))
            for o in prev_ag_wait.get_outputs():
                ag_group_node.add_fake_dep(
                    WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
                )
        prev_ag_wait = wait_group_node

    # Enforce ReduceScatter ordering: previous ReduceScatter's "wait" group node must run
    # before next ReduceScatter's "copy_in then RS" group node
    prev_rs_wait = None
    for rs_group_node, wait_group_node in rs_grouped_node_to_wait_grouped_node.items():
        if prev_rs_wait is not None:
            mutating_buf = next(iter(rs_group_node.get_buffer_names()))
            for o in prev_rs_wait.get_outputs():
                rs_group_node.add_fake_dep(
                    WeakDep(o.get_name(), mutating_buf=mutating_buf, is_fake=True)
                )
        prev_rs_wait = wait_group_node

    return new_order  # type: ignore[return-value]