schedules.py 130 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
7277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197319831993200320132023203320432053206320732083209
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. import copy
  4. import csv
  5. import itertools
  6. import logging
  7. import re
  8. from abc import ABC, abstractmethod
  9. from collections import Counter, defaultdict
  10. from enum import Enum
  11. from functools import lru_cache
  12. from typing import Any, Callable, NamedTuple, Optional, Union
  13. import torch
  14. import torch.distributed as dist
  15. from torch._dynamo import OptimizedModule
  16. from torch.distributed.fsdp import FSDPModule, UnshardHandle
  17. from torch.nn.modules.loss import _Loss
  18. from torch.profiler import record_function
  19. from ._utils import generate_rank_to_stage_mapping, generate_stage_to_rank_mapping
  20. from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
  21. from .stage import _PipelineStageBase
# Names exported by `from <this module> import *`: the schedule classes and the
# schedule-lookup helper intended for use outside this module.
__all__ = [
    "get_schedule_class",
    "PipelineScheduleSingle",
    "PipelineScheduleMulti",
    "Schedule1F1B",
    "ScheduleGPipe",
    "ScheduleInterleaved1F1B",
    "ScheduleLoopedBFS",
    "ScheduleInterleavedZeroBubble",
    "ScheduleZBVZeroBubble",
    "ScheduleDualPipeV",
]
# Module-level logger shared by the schedule implementations in this file.
logger = logging.getLogger(__name__)
  35. class _ComputationType(Enum):
  36. # TODO(whc) rename to _ActType?
  37. FORWARD = 1
  38. BACKWARD_INPUT = 2
  39. BACKWARD_WEIGHT = 3
  40. UNSHARD = 4
  41. RESHARD = 5
  42. SEND_F = 6
  43. RECV_F = 7
  44. SEND_B = 8
  45. RECV_B = 9
  46. FULL_BACKWARD = 10
  47. OVERLAP_F_B = 11
  48. def __str__(self):
  49. str_map = {
  50. _ComputationType.FORWARD: "F",
  51. _ComputationType.BACKWARD_INPUT: "I",
  52. _ComputationType.BACKWARD_WEIGHT: "W",
  53. _ComputationType.UNSHARD: "UNSHARD",
  54. _ComputationType.RESHARD: "RESHARD",
  55. _ComputationType.SEND_F: "SEND_F",
  56. _ComputationType.RECV_F: "RECV_F",
  57. _ComputationType.SEND_B: "SEND_B",
  58. _ComputationType.RECV_B: "RECV_B",
  59. _ComputationType.FULL_BACKWARD: "B",
  60. _ComputationType.OVERLAP_F_B: "OVERLAP_F_B",
  61. }
  62. return str_map[self]
  63. @staticmethod
  64. def from_str(action):
  65. if action == "F":
  66. return _ComputationType.FORWARD
  67. elif action == "I":
  68. return _ComputationType.BACKWARD_INPUT
  69. elif action == "W":
  70. return _ComputationType.BACKWARD_WEIGHT
  71. elif action == "UNSHARD":
  72. return _ComputationType.UNSHARD
  73. elif action == "RESHARD":
  74. return _ComputationType.RESHARD
  75. elif action == "SEND_F":
  76. return _ComputationType.SEND_F
  77. elif action == "RECV_F":
  78. return _ComputationType.RECV_F
  79. elif action == "SEND_B":
  80. return _ComputationType.SEND_B
  81. elif action == "RECV_B":
  82. return _ComputationType.RECV_B
  83. elif action == "B":
  84. return _ComputationType.FULL_BACKWARD
  85. elif action == "OVERLAP_F_B":
  86. return _ComputationType.OVERLAP_F_B
  87. else:
  88. raise RuntimeError(f"Invalid computation type {action}")
  89. FORWARD = _ComputationType.FORWARD
  90. BACKWARD_INPUT = _ComputationType.BACKWARD_INPUT
  91. BACKWARD_WEIGHT = _ComputationType.BACKWARD_WEIGHT
  92. UNSHARD = _ComputationType.UNSHARD
  93. RESHARD = _ComputationType.RESHARD
  94. SEND_F = _ComputationType.SEND_F
  95. RECV_F = _ComputationType.RECV_F
  96. SEND_B = _ComputationType.SEND_B
  97. RECV_B = _ComputationType.RECV_B
  98. FULL_BACKWARD = _ComputationType.FULL_BACKWARD
  99. OVERLAP_F_B = _ComputationType.OVERLAP_F_B
  100. # Convenience shorthand for compute actions only since they are used in 'simple schedule format'
  101. F = FORWARD
  102. I = BACKWARD_INPUT
  103. W = BACKWARD_WEIGHT
  104. B = FULL_BACKWARD
  105. # Helper to parse an action string like 1F0 into a tuple of (stage_index, computation_type, microbatch_index)
  106. _action_regex = re.compile(
  107. r"(\d+)(F|I|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)"
  108. )
  109. class _Action(NamedTuple):
  110. stage_index: int
  111. computation_type: _ComputationType
  112. microbatch_index: Optional[int] = None
  113. sub_actions: Optional[tuple["_Action", ...]] = None
  114. def __str__(self):
  115. return self.__repr__()
  116. def __repr__(self):
  117. if self.sub_actions is not None:
  118. # Use recursive repr for sub_actions
  119. sub_action_reprs = [repr(sub_action) for sub_action in self.sub_actions]
  120. return f"({';'.join(sub_action_reprs)}){self.computation_type}"
  121. else:
  122. repr_str = str(self.stage_index)
  123. repr_str += str(self.computation_type)
  124. if self.microbatch_index is not None:
  125. repr_str += str(self.microbatch_index)
  126. return repr_str
  127. @property
  128. def is_compute_op(self) -> bool:
  129. return self.computation_type in (
  130. FORWARD,
  131. FULL_BACKWARD,
  132. BACKWARD_INPUT,
  133. BACKWARD_WEIGHT,
  134. OVERLAP_F_B,
  135. )
  136. @staticmethod
  137. def from_str(action_string: str):
  138. """
  139. Reverse of __repr__
  140. String should be formatted as [stage][action type][(microbatch)]
  141. e.g. `2F0`, `1UNSHARD`, `3SEND_F1`
  142. """
  143. action_string = action_string.strip()
  144. if action_string == "":
  145. return None
  146. # Check for sub_actions format: [sub_action1;sub_action2;...]ComputationType
  147. if action_string.startswith("(") and ")" in action_string:
  148. # Find the closing bracket to separate sub_actions from computation type
  149. bracket_end = action_string.find(")")
  150. sub_part = action_string[
  151. 1:bracket_end
  152. ] # Remove '[' and get content before ']'
  153. computation_type_part = action_string[
  154. bracket_end + 1 :
  155. ] # Get part after ']'
  156. # Parse sub_actions
  157. sub_actions = []
  158. if sub_part.strip():
  159. for sub_str in sub_part.split(";"):
  160. sub_action = _Action.from_str(sub_str.strip())
  161. if sub_action is not None:
  162. sub_actions.append(sub_action)
  163. # For sub_actions format, we create an action with just the computation type
  164. # The stage_index and microbatch_index are not meaningful for the container action
  165. return _Action(
  166. stage_index=-1, # Placeholder, not meaningful for sub_actions container
  167. computation_type=_ComputationType.from_str(computation_type_part),
  168. microbatch_index=None,
  169. sub_actions=tuple(sub_actions) if sub_actions else None,
  170. )
  171. # Handle regular single action format
  172. if match := _action_regex.match(action_string):
  173. stage_index, computation_type, microbatch_index = match.groups()
  174. return _Action(
  175. int(stage_index),
  176. _ComputationType.from_str(computation_type),
  177. int(microbatch_index) if len(microbatch_index) else None,
  178. )
  179. elif action_string == "":
  180. return None
  181. raise RuntimeError(
  182. f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
  183. )
  184. @lru_cache
  185. def _get_profiler_function_name(action: _Action) -> str:
  186. return f"PP:{str(action)}"
  187. def _format_pipeline_order(
  188. pipeline_order: dict[int, list[Optional[_Action]]],
  189. error_step_number: Optional[int] = None,
  190. ) -> str:
  191. """
  192. Formats the pipeline order in a timestep (row) x rank (column) grid of actions
  193. and returns the formatted string.
  194. If `error_step_number` is passed in, an additional label will be added to signify which step
  195. that it is erroring on.
  196. """
  197. # don't mutate the original
  198. pipeline_order = copy.deepcopy(pipeline_order)
  199. # Replace None with ""
  200. for rank in pipeline_order:
  201. for i in range(len(pipeline_order[rank])):
  202. if pipeline_order[rank][i] is None:
  203. # TODO make a real 'None action' that prints as empty string and make mypy happy
  204. pipeline_order[rank][i] = "" # type: ignore[call-overload]
  205. # Calculate the maximum number of steps across all ranks
  206. num_steps = max(len(actions) for actions in pipeline_order.values())
  207. step_labels = [
  208. "Step " + str(i).zfill(len(str(num_steps - 1))) for i in range(num_steps)
  209. ]
  210. # Sorting the dictionary by keys and retrieving values in that order
  211. rank_actions = [
  212. pipeline_order.get(key, [""] * num_steps) for key in sorted(pipeline_order)
  213. ]
  214. # Transpose the list of lists (rows to columns)
  215. transposed_actions = list(itertools.zip_longest(*rank_actions, fillvalue=""))
  216. # Generate column labels for ranks
  217. num_ranks = len(pipeline_order)
  218. rank_labels = ["Rank " + str(i) for i in range(num_ranks)]
  219. # Calculate the maximum length of each column, considering labels
  220. max_lengths = [
  221. max(len(str(item)) if item is not None else 0 for item in col)
  222. for col in zip(step_labels, *transposed_actions)
  223. ]
  224. # Format the header row with rank labels
  225. header_row = " " * (len(step_labels[0]) + 2) + " ".join(
  226. f"{label:<{max_lengths[i]}}" for i, label in enumerate(rank_labels)
  227. )
  228. # Format each row with its corresponding label
  229. formatted_rows = [
  230. f"{label}: "
  231. + " ".join(f"{str(item):<{max_lengths[i]}}" for i, item in enumerate(row))
  232. + (
  233. " <-- ERROR HERE"
  234. if error_step_number is not None
  235. and int(label.split()[1]) == error_step_number
  236. else ""
  237. )
  238. for label, row in zip(step_labels, transposed_actions)
  239. ]
  240. # Join the rows into a single string
  241. formatted_table = header_row + "\n" + "\n".join(formatted_rows) + "\n"
  242. return formatted_table
  243. class _PipelineSchedule(ABC):
  244. def __init__(
  245. self,
  246. n_microbatches: int,
  247. loss_fn: Optional[Callable[..., torch.Tensor]] = None,
  248. args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
  249. kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
  250. output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
  251. scale_grads: bool = True,
  252. ):
  253. # From arguments
  254. self._n_microbatches = n_microbatches
  255. self._loss_fn = loss_fn
  256. # See documentation in `PipelineScheduleSingle` / `PipelineScheduleMulti`
  257. self.scale_grads = scale_grads
  258. # Chunking specification for positional inputs. (default: `None`)
  259. self._args_chunk_spec = args_chunk_spec
  260. # Chunking specification for keyword inputs. (default: `None`)
  261. self._kwargs_chunk_spec = kwargs_chunk_spec
  262. self._output_merge_spec = output_merge_spec
  263. """
  264. # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs.
  265. # They are used to convert batch to microbatches in `step(x)`. See
  266. # `TensorChunkSpec` for helper methods for creating them.
  267. """
  268. # Derived
  269. self._has_backward = self._loss_fn is not None
  270. # Holds the losses for each microbatch.
  271. self._internal_losses: list[torch.Tensor] = []
  272. logger.info("Using %s", self.__class__.__name__)
  273. def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
  274. if stage.is_last and self._loss_fn is not None:
  275. loss = self._compute_loss(output, target_mbs[mb_index]) # type: ignore[index]
  276. self._internal_losses.append(loss)
  277. def _maybe_get_loss(self, stage, mb_index):
  278. valid_index = 0 <= mb_index < len(self._internal_losses)
  279. if stage.is_last and self._loss_fn is not None and valid_index:
  280. return self._internal_losses[mb_index]
  281. elif len(self._internal_losses) != 0 and not valid_index:
  282. raise RuntimeError(
  283. f"Loss for microbatch {mb_index} is not available. "
  284. f"Available losses for microbatches: {self._internal_losses}"
  285. )
  286. else:
  287. return None
  288. def _update_losses(self, stages, losses):
  289. """
  290. Update the losses to those in the internal state
  291. """
  292. # if stages not a list turn into a list
  293. if not isinstance(stages, list):
  294. stages = [stages]
  295. contains_last_stage = any(stage.is_last for stage in stages)
  296. # Return losses if there is a container passed in
  297. if contains_last_stage and losses is not None:
  298. if len(self._internal_losses) != self._n_microbatches:
  299. raise RuntimeError(
  300. f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
  301. )
  302. # Clean external container first
  303. losses.clear()
  304. # Copy internal losses to external container
  305. losses.extend(self._internal_losses)
  306. self._internal_losses.clear()
  307. @abstractmethod
  308. def _step_microbatches(
  309. self,
  310. arg_mbs: Optional[list] = None,
  311. kwarg_mbs: Optional[list] = None,
  312. target_mbs: Optional[list] = None,
  313. losses: Optional[list] = None,
  314. ):
  315. """
  316. Run one iteration of the pipeline schedule with list of microbatches.
  317. Will go through all the microbatches according to the schedule
  318. implementation.
  319. Args:
  320. microbatches: list of microbatch args.
  321. """
  322. raise NotImplementedError
  323. @abstractmethod
  324. def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
  325. """
  326. Run one iteration of the pipeline schedule with *whole-batch* input.
  327. Will chunk the input into microbatches automatically, and go through the
  328. microbatches according to the schedule implementation.
  329. args: positional arguments to the model (as in non-pipeline case).
  330. kwargs: keyword arguments to the model (as in non-pipeline case).
  331. target: target for the loss function.
  332. losses: a list to store the losses for each microbatch.
  333. """
  334. raise NotImplementedError
  335. def eval(self, *args, target=None, losses: Optional[list] = None, **kwargs):
  336. """
  337. Run one iteration of the pipeline schedule with *whole-batch* input.
  338. Will chunk the input into microbatches automatically, and go through the
  339. microbatches, calling forward only.
  340. args: positional arguments to the model (as in non-pipeline case).
  341. kwargs: keyword arguments to the model (as in non-pipeline case).
  342. target: target values for the loss function.
  343. losses: a list to store the losses for each microbatch.
  344. """
  345. # Save the original has_backward state
  346. original_has_backward = self._has_backward
  347. try:
  348. self._has_backward = False
  349. return self.step(*args, target=target, losses=losses, **kwargs)
  350. finally:
  351. # Restore the original state
  352. self._has_backward = original_has_backward
  353. def _check_inputs(
  354. self,
  355. arg_mbs: Optional[list] = None,
  356. kwarg_mbs: Optional[list] = None,
  357. target_mbs: Optional[list] = None,
  358. losses: Optional[list] = None,
  359. ):
  360. """
  361. Pre-process/check inputs
  362. """
  363. def check_type_and_len(mbs, name: str):
  364. if not isinstance(mbs, list):
  365. raise TypeError(f"{name} must be a list but got a {type(mbs)}")
  366. if len(mbs) != self._n_microbatches:
  367. raise ValueError(
  368. f"Expecting {self._n_microbatches} {name} but got {len(mbs)}"
  369. )
  370. if arg_mbs is not None:
  371. check_type_and_len(arg_mbs, "arg_mbs")
  372. else:
  373. arg_mbs = [()] * self._n_microbatches
  374. if kwarg_mbs is not None:
  375. check_type_and_len(kwarg_mbs, "kwarg_mbs")
  376. else:
  377. kwarg_mbs = [{}] * self._n_microbatches
  378. if target_mbs is not None:
  379. check_type_and_len(target_mbs, "target_mbs")
  380. if losses is not None:
  381. if not isinstance(losses, list):
  382. raise TypeError(f"losses must be a list but got a {type(losses)}")
  383. return arg_mbs, kwarg_mbs
  384. def _compute_loss(self, output, target):
  385. return self._loss_fn(output, target) # type: ignore[misc]
  386. def _split_inputs(
  387. self,
  388. args: tuple[Any, ...],
  389. kwargs: Optional[dict[str, Any]] = None,
  390. ):
  391. """
  392. Splits a full-batch input into chunks (i.e. microbatches) and returns
  393. the chunks
  394. """
  395. if args or kwargs:
  396. args_split, kwargs_split = split_args_kwargs_into_chunks(
  397. args,
  398. kwargs,
  399. self._n_microbatches,
  400. self._args_chunk_spec,
  401. self._kwargs_chunk_spec,
  402. )
  403. return args_split, kwargs_split
  404. else:
  405. # Empty inputs (e.g. when called on middle stages)
  406. # Return a list of empty tuples/dicts with matching length as chunks
  407. return [()] * self._n_microbatches, [{}] * self._n_microbatches
  408. def _merge_outputs(self, output_chunks: list[Any]) -> Any:
  409. """
  410. Merge output chunks back to a batch state.
  411. If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
  412. """
  413. return merge_chunks(
  414. output_chunks,
  415. self._output_merge_spec,
  416. )
  417. def _batch_p2p(
  418. p2p_ops: list[dist.P2POp], desc: Optional[str] = None
  419. ) -> list[dist.Work]:
  420. """
  421. Simple wrapper over batch_isend_irecv from torch.distributed, which just adds a descriptive logger on top.
  422. """
  423. if len(p2p_ops) == 0:
  424. return []
  425. desc_str = f"{desc}, " if desc else ""
  426. logger.debug("batch_p2p %s%s", desc_str, p2p_ops)
  427. return dist.batch_isend_irecv(p2p_ops)
  428. def _sorted_batch_p2p(
  429. p2p_ops: list[dist.P2POp], desc: Optional[str] = None
  430. ) -> dict[int, list[dist.Work]]:
  431. """
  432. Sorts the list of P2P ops by the peer rank, and then calls
  433. batch_isend_irecv. Return a dictionary of works by peer rank. This function
  434. helps us avoid hangs in case of skip connections.
  435. """
  436. # Arrange p2p_ops by peer rank:
  437. # int is the peer rank;
  438. # List is the list of ops towards the peer
  439. ops_by_peer: dict[int, list[dist.P2POp]] = defaultdict(list)
  440. work_by_peer: dict[int, list[dist.Work]] = {}
  441. if len(p2p_ops) == 0:
  442. return work_by_peer
  443. # Classify the ops by peer rank
  444. for op in p2p_ops:
  445. ops_by_peer[op.peer].append(op)
  446. # Call batch_isend_irecv per peer, in sorted order of the peers (to avoid hangs)
  447. for peer, ops in sorted(ops_by_peer.items()):
  448. work_by_peer[peer] = _batch_p2p(ops, desc=desc)
  449. return work_by_peer
  450. def _wait_batch_p2p(work: list[dist.Work]):
  451. """
  452. Waits for a list of dist.Work (typically from _batch_p2p / _sorted_batch_p2p).
  453. """
  454. for w in work:
  455. w.wait()
  456. class PipelineScheduleSingle(_PipelineSchedule):
  457. """
  458. Base class for single-stage schedules.
  459. Implements the `step` method.
  460. Derived classes should implement `_step_microbatches`.
  461. Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting
  462. should match the configuration of your loss_fn, which may either average losses (scale_grads=True)
  463. or sum losses (scale_grads=False).
  464. """
  465. def __init__(
  466. self,
  467. stage: _PipelineStageBase,
  468. n_microbatches: int,
  469. loss_fn: Optional[Callable] = None,
  470. args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
  471. kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
  472. output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
  473. scale_grads: bool = True,
  474. ):
  475. # Init parent
  476. super().__init__(
  477. n_microbatches=n_microbatches,
  478. loss_fn=loss_fn,
  479. args_chunk_spec=args_chunk_spec,
  480. kwargs_chunk_spec=kwargs_chunk_spec,
  481. output_merge_spec=output_merge_spec,
  482. scale_grads=scale_grads,
  483. )
  484. # Self attributes
  485. self._stage = stage
  486. self._num_stages = stage.num_stages
  487. self._stage_initialized = False
  488. if n_microbatches < self._num_stages:
  489. raise ValueError(
  490. f"Number of microbatches ({n_microbatches}) must be greater than \
  491. or equal to the number of stages ({self._num_stages})."
  492. )
  493. self.pipeline_order: Optional[dict[int, list[Optional[_Action]]]] = (
  494. self._get_pipeline_order()
  495. )
  496. def _initialize_stage(self, args, kwargs):
  497. # Prepare the communication needed for the pipeline schedule execution
  498. # This is needed because during execution we always perform a series of batch P2P ops
  499. # The first call of the batched P2P needs to involve the global group
  500. all_ops: list[dist.P2POp] = []
  501. all_ops.extend(self._stage._get_init_p2p_neighbors_ops())
  502. _wait_batch_p2p(_batch_p2p(all_ops))
  503. self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
  504. if self._has_backward:
  505. self._stage._prepare_backward_infra(self._n_microbatches)
  506. self._stage_initialized = True
  507. def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
  508. """
  509. Run one iteration of the pipeline schedule with *whole-batch* input.
  510. Will chunk the input into microbatches automatically, and go through the
  511. microbatches according to the schedule implementation.
  512. args: positional arguments to the model (as in non-pipeline case).
  513. kwargs: keyword arguments to the model (as in non-pipeline case).
  514. target: target for the loss function.
  515. losses: a list to store the losses for each microbatch.
  516. """
  517. if self._has_backward and not torch.is_grad_enabled():
  518. raise RuntimeError(
  519. "step() requires gradients to be enabled for backward computation; "
  520. "it should not be used under torch.no_grad() context. "
  521. "Please call eval() instead."
  522. )
  523. # Set the same has_backward flag for stage object
  524. self._stage.has_backward = self._has_backward
  525. # Clean per iteration
  526. self._stage.clear_runtime_states()
  527. # Split inputs into microbatches
  528. args_split, kwargs_split = self._split_inputs(args, kwargs)
  529. # Split target into microbatches
  530. if target is not None:
  531. targets_split = list(torch.tensor_split(target, self._n_microbatches))
  532. else:
  533. targets_split = None
  534. # Run microbatches
  535. self._step_microbatches(args_split, kwargs_split, targets_split, losses)
  536. # Return merged results per original format
  537. if self._stage.is_last:
  538. return self._merge_outputs(self._stage.output_chunks)
  539. else:
  540. return None
  541. def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
  542. """
  543. Returns the pipeline execution order as a schedule IR.
  544. The returned IR is a dictionary mapping rank IDs to lists of actions.
  545. Each action is either an _Action object representing computation to perform,
  546. or None representing a deliberate idle step.
  547. The None values are used to represent pipeline bubbles where a rank
  548. must wait for dependencies from other ranks before proceeding. However
  549. during execution, with the _PipelineScheduleRuntime, these Nones are
  550. skipped since the relevant communication (send/recv) will be scheduled and waited on.
  551. Returns:
  552. A dictionary mapping rank -> list of actions
  553. """
  554. return None
  555. class _ScheduleForwardOnly(PipelineScheduleSingle):
  556. """
  557. The forward-only schedule.
  558. Will go through all the microbatches and perform only the forward pass
  559. """
  560. def _step_microbatches(
  561. self,
  562. arg_mbs: Optional[list] = None,
  563. kwarg_mbs: Optional[list] = None,
  564. target_mbs: Optional[list] = None,
  565. losses: Optional[list] = None,
  566. ):
  567. """
  568. Run one iteration of the pipeline schedule
  569. """
  570. if target_mbs is not None or losses is not None:
  571. raise RuntimeError(
  572. "Forward-only schedule does not support loss computation"
  573. )
  574. arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
  575. if not self._stage_initialized:
  576. self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
  577. # Delay send waits
  578. fwd_sends_to_wait: list[list[dist.Work]] = []
  579. # Run microbatches
  580. for i in range(self._n_microbatches):
  581. with record_function(f"Forward {i}"):
  582. ops = self._stage.get_fwd_recv_ops(i)
  583. works = _sorted_batch_p2p(ops, desc="fwd_recv")
  584. for work in works.values():
  585. _wait_batch_p2p(work)
  586. self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index]
  587. ops = self._stage.get_fwd_send_ops(i)
  588. works = _sorted_batch_p2p(ops, desc="fwd_send")
  589. fwd_sends_to_wait.extend(works.values())
  590. logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)
  591. # Wait for all forward sends to finish
  592. # This should not have performance impact because by the time the first
  593. # backward arrives all the forward sends should have been finished.
  594. for work in fwd_sends_to_wait:
  595. _wait_batch_p2p(work)
  596. class ScheduleGPipe(PipelineScheduleSingle):
  597. """
  598. The GPipe schedule.
  599. Will go through all the microbatches in a fill-drain manner.
  600. """
  601. def _step_microbatches(
  602. self,
  603. arg_mbs: Optional[list] = None,
  604. kwarg_mbs: Optional[list] = None,
  605. target_mbs: Optional[list] = None,
  606. losses: Optional[list] = None,
  607. ):
  608. """
  609. Run one iteration of the pipeline schedule with list of microbatches.
  610. Will go through all the microbatches according to the GPipe schedule.
  611. Args:
  612. microbatches: list of microbatch args.
  613. """
  614. arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
  615. if not self._stage_initialized:
  616. self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
  617. # Delay send waits
  618. fwd_sends_to_wait: list[list[dist.Work]] = []
  619. # Run microbatches
  620. for i in range(self._n_microbatches):
  621. with record_function(f"Forward {i}"):
  622. ops = self._stage.get_fwd_recv_ops(i)
  623. works = _sorted_batch_p2p(ops, desc="fwd_recv")
  624. for work in works.values():
  625. _wait_batch_p2p(work)
  626. output = self._stage.forward_one_chunk(i, arg_mbs[i], kwarg_mbs[i]) # type: ignore[index]
  627. ops = self._stage.get_fwd_send_ops(i)
  628. works = _sorted_batch_p2p(ops, desc="fwd_send")
  629. fwd_sends_to_wait.extend(works.values())
  630. logger.debug("[%s] Forwarded microbatch %s", self._stage.stage_index, i)
  631. self._maybe_compute_loss(self._stage, output, target_mbs, i)
  632. # Wait for all forward sends to finish
  633. # This should not have performance impact because by the time the first
  634. # backward arrives all the forward sends should have been finished.
  635. for work in fwd_sends_to_wait:
  636. _wait_batch_p2p(work)
  637. # Run backward
  638. # Delay send waits
  639. bwd_sends_to_wait: list[list[dist.Work]] = []
  640. for i in range(self._n_microbatches):
  641. with record_function(f"Backward {i}"):
  642. ops = self._stage.get_bwd_recv_ops(i)
  643. works = _sorted_batch_p2p(ops, desc="bwd_recv")
  644. for work in works.values():
  645. _wait_batch_p2p(work)
  646. loss = self._maybe_get_loss(self._stage, i)
  647. self._stage.backward_one_chunk(
  648. i,
  649. loss=loss,
  650. last_backward=i == self._n_microbatches - 1,
  651. )
  652. ops = self._stage.get_bwd_send_ops(i)
  653. works = _sorted_batch_p2p(ops, desc="bwd_send")
  654. bwd_sends_to_wait.extend(works.values())
  655. logger.debug("[%s] Backwarded microbatch %s", self._stage.stage_index, i)
  656. self._stage.scale_grads(
  657. grad_scale_factor=self._n_microbatches if self.scale_grads else 1
  658. )
  659. # Wait for all backward sends to finish
  660. for work in bwd_sends_to_wait:
  661. _wait_batch_p2p(work)
  662. # Update losses if there is a container passed in
  663. self._update_losses(self._stage, losses)
  664. def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
  665. """
  666. Returns the pipeline order for GPipe schedule.
  667. See base method in PipelineScheduleSingle for details on the schedule IR format.
  668. """
  669. pipeline_order = {}
  670. pp_group_size = self._num_stages
  671. for rank in range(pp_group_size):
  672. actions: list[Optional[_Action]] = []
  673. # 1. Initial delay based on rank position
  674. warmup_delay = rank
  675. actions.extend([None] * warmup_delay)
  676. # 2. Forward passes for all microbatches
  677. for mb_idx in range(self._n_microbatches):
  678. actions.append(_Action(rank, _ComputationType.FORWARD, mb_idx))
  679. # 3. Wait period before backward passes can begin
  680. backward_delay = 3 * (pp_group_size - 1 - rank)
  681. actions.extend([None] * backward_delay)
  682. # 4. Backward passes for all microbatches
  683. for mb_idx in range(self._n_microbatches):
  684. actions.append(_Action(rank, _ComputationType.FULL_BACKWARD, mb_idx))
  685. pipeline_order[rank] = actions
  686. return pipeline_order
  687. class Schedule1F1B(PipelineScheduleSingle):
  688. """
  689. The 1F1B schedule.
  690. Will perform one forward and one backward on the microbatches in steady state.
  691. """
  692. def _step_microbatches(
  693. self,
  694. arg_mbs: Optional[list] = None,
  695. kwarg_mbs: Optional[list] = None,
  696. target_mbs: Optional[list] = None,
  697. losses: Optional[list] = None,
  698. ):
  699. """
  700. Run one iteration of the pipeline schedule with list of microbatches.
  701. Will go through all the microbatches according to the 1F1B schedule.
  702. Args:
  703. microbatches: list of microbatch args.
  704. """
  705. arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
  706. if not self._stage_initialized:
  707. self._initialize_stage(arg_mbs[0], kwarg_mbs[0])
  708. # Last stage has 1 warmup, second-to-last 2 warmups, ...
  709. # first stage `num_stages` warmups
  710. warmup_chunks = min(
  711. self._n_microbatches,
  712. self._num_stages - self._stage.stage_index,
  713. )
  714. # Chunk counters
  715. fwd_mb_index = 0
  716. bwd_mb_index = 0
  717. # Warmup phase
  718. send_work: list[dist.Work] = []
  719. fwd_sends = []
  720. for _ in range(warmup_chunks):
  721. # Receive activations
  722. fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
  723. _wait_batch_p2p(_batch_p2p(fwd_recvs, desc="fwd_recv"))
  724. # Compute
  725. output = self._stage.forward_one_chunk(
  726. fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
  727. ) # type: ignore[index]
  728. # Clear previous chunk's forward sends (hopefully they have well
  729. # finished, otherwise, we are heavily communication bound, in which
  730. # case it doesn't create a lot of benefit to compute next chunk
  731. # eagerly either)
  732. _wait_batch_p2p(send_work)
  733. # Send activations
  734. fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
  735. if fwd_mb_index != warmup_chunks - 1:
  736. # Safe to fire
  737. send_work = _batch_p2p(fwd_sends, desc="fwd_send")
  738. # otherwise:
  739. # The last forward send is left for fuse with first 1B in 1B1F below
  740. # Compute loss
  741. self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
  742. fwd_mb_index += 1
  743. # Now we should have send ops left over, to be fused with first 1B of 1B1F phase below.
  744. # 1B1F phase
  745. while True: # Don't worry, we have a break inside
  746. # We actually do 1B first as the `1B1F` name indicates, so prepare its recv ops
  747. bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
  748. # Now, we need to fire the fwd_sends and bwd_recvs together
  749. _wait_batch_p2p(_batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"))
  750. # Backward one chunk
  751. loss = self._maybe_get_loss(self._stage, bwd_mb_index)
  752. self._stage.backward_one_chunk(
  753. bwd_mb_index,
  754. loss=loss,
  755. last_backward=bwd_mb_index == self._n_microbatches - 1,
  756. )
  757. # Get the bwd send ops, but don't fire, to be fused with the 1F below
  758. bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
  759. bwd_mb_index += 1
  760. if fwd_mb_index == self._n_microbatches:
  761. # We are done with 1B1F, so break with some left-over bwd_sends
  762. break
  763. # We prepare 1F of the `1B1F`
  764. fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
  765. # Fuse it with bwd_sends above
  766. _wait_batch_p2p(_batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"))
  767. # Now do the fwd
  768. output = self._stage.forward_one_chunk(
  769. fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
  770. ) # type: ignore[index]
  771. # Compute loss
  772. self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
  773. # Get the fwd send ops, but don't fire, leave it for the next iter (wrap-around)
  774. fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
  775. fwd_mb_index += 1
  776. # Remember we still have some bwd_sends left over after the break? Now it is time to fire it
  777. send_work = _batch_p2p(bwd_sends, desc="bwd_send")
  778. # Cooldown
  779. while bwd_mb_index < self._n_microbatches:
  780. # prepare bwd recv ops
  781. bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
  782. _wait_batch_p2p(_batch_p2p(bwd_recvs, desc="bwd_recv"))
  783. # Backward one chunk
  784. loss = self._maybe_get_loss(self._stage, bwd_mb_index)
  785. self._stage.backward_one_chunk(
  786. bwd_mb_index,
  787. loss=loss,
  788. last_backward=bwd_mb_index == self._n_microbatches - 1,
  789. )
  790. # Clear previous chunk's backward sends (hopefully they have well finished)
  791. _wait_batch_p2p(send_work)
  792. # Get the bwd send ops, fire it
  793. bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
  794. send_work = _batch_p2p(bwd_sends, desc="bwd_send")
  795. bwd_mb_index += 1
  796. self._stage.scale_grads(
  797. grad_scale_factor=self._n_microbatches if self.scale_grads else 1
  798. )
  799. # Wait for the last backward send to finish
  800. _wait_batch_p2p(send_work)
  801. # Return losses if there is a container passed in
  802. self._update_losses(self._stage, losses)
  803. def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
  804. """
  805. Returns the pipeline order for 1F1B schedule.
  806. See base method in PipelineScheduleSingle for details on the schedule IR format.
  807. """
  808. pipeline_order = {}
  809. pp_group_size = self._num_stages
  810. for rank in range(pp_group_size):
  811. actions: list[Optional[_Action]] = []
  812. # 1. Warmup phase: initial delay based on rank
  813. actions.extend([None] * rank)
  814. # 2. Initial forward passes before 1F1B phase
  815. num_forward = (pp_group_size - 1) - rank
  816. forward_mb = 0
  817. for i in range(num_forward):
  818. actions.append(_Action(rank, _ComputationType.FORWARD, i))
  819. forward_mb = i
  820. # 3. Wait for backward to be ready
  821. wait_for_1f1b = max(0, 2 * (pp_group_size - 1 - rank))
  822. actions.extend([None] * wait_for_1f1b)
  823. # 4. 1F1B steady state phase
  824. backward_mb = 0
  825. remaining_forward = self._n_microbatches - num_forward
  826. while remaining_forward > 0:
  827. # One forward
  828. forward_mb += 1
  829. actions.append(_Action(rank, _ComputationType.FORWARD, forward_mb))
  830. remaining_forward -= 1
  831. # One backward
  832. actions.append(
  833. _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
  834. )
  835. backward_mb += 1
  836. # 5. Cooldown phase: remaining backward passes
  837. remaining_backward = self._n_microbatches - backward_mb
  838. while remaining_backward > 0:
  839. # Add None and backward actions in alternating pattern
  840. # based on distance from the last stage
  841. if (pp_group_size - rank) > 0:
  842. actions.append(None)
  843. # Decrement the wait counter only if we still have backward passes to do
  844. if remaining_backward > 0:
  845. actions.append(
  846. _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
  847. )
  848. backward_mb += 1
  849. remaining_backward -= 1
  850. else:
  851. # If we're at the last stage, just add backward actions without None
  852. actions.append(
  853. _Action(rank, _ComputationType.FULL_BACKWARD, backward_mb)
  854. )
  855. backward_mb += 1
  856. remaining_backward -= 1
  857. pipeline_order[rank] = actions
  858. return pipeline_order
  859. def _add_unshard_reshard(
  860. compute_actions: list[Optional[_Action]],
  861. max_active_stages: int = 3,
  862. ) -> list[_Action]:
  863. """Given a basic schedule involving only compute actions (F,B,W,OVERLAP_F_B), add UNSHARD/RESHARD actions for FSDP.
  864. UNSHARD refers to fetching the full contents of an FSDP-sharded layer, requiring an all-gather operation.
  865. RESHARD does the opposite, releasing memory (but doing no communication)
  866. We abandon the "timestep lock" during lowering
  867. max_active_stages controls how many prefetches we allow. It should be measured in mb and tuneable but in practice
  868. 3 stages is probably the thing we want?
  869. (to account for having one f and one b active, and something else prefetching?)
  870. """
  871. def next_stage_indices(
  872. count: int, next_actions: list[Optional[_Action]]
  873. ) -> list[int]:
  874. """Remove duplicates (same stage, different microbatch), find next 'count' stages that will do compute."""
  875. seen: set[int] = set()
  876. ret: list[int] = []
  877. for a in next_actions:
  878. if a is not None:
  879. # Handle OVERLAP_F_B actions by checking their sub_actions
  880. if a.computation_type == OVERLAP_F_B and a.sub_actions is not None:
  881. for sub_action in a.sub_actions:
  882. if sub_action.stage_index not in seen:
  883. seen.add(sub_action.stage_index)
  884. ret.append(sub_action.stage_index)
  885. if len(ret) == count:
  886. break
  887. if len(ret) == count:
  888. break
  889. else:
  890. # Regular action
  891. if a.stage_index not in seen:
  892. seen.add(a.stage_index)
  893. ret.append(a.stage_index)
  894. if len(ret) == count:
  895. break
  896. return ret
  897. active_stages: set[int] = set()
  898. fsdp_aware_actions: list[_Action] = []
  899. def _unshard(stage_index: int):
  900. active_stages.add(stage_index)
  901. fsdp_aware_actions.append(_Action(stage_index, UNSHARD, None))
  902. def _reshard(stage_index: int):
  903. active_stages.remove(stage_index)
  904. fsdp_aware_actions.append(_Action(stage_index, RESHARD, None))
  905. for i, action in enumerate(compute_actions):
  906. if action is None:
  907. continue
  908. # We prefetch the next N stages we'll see, dropping existing stages to make room
  909. next_n = next_stage_indices(max_active_stages, compute_actions[i:])
  910. # Fetch needs to be ordered correctly, so don't use a set
  911. fetch = list(filter(lambda s: s not in active_stages, next_n))
  912. # Unclear what the best policy is for eviction, but we can maintain order so we do
  913. evict = list(filter(lambda s: s not in next_n, active_stages))
  914. # logger.debug(
  915. # "_add_unshard_reshard Step %d active: %s fetch %s, evict %s",
  916. # i,
  917. # active_stages,
  918. # fetch,
  919. # evict,
  920. # )
  921. for stage in evict:
  922. _reshard(stage)
  923. for stage in fetch:
  924. _unshard(stage)
  925. fsdp_aware_actions.append(action)
  926. return fsdp_aware_actions
  927. def _merge_bw(
  928. compute_actions: list[Optional[_Action]],
  929. ) -> list[_Action]:
  930. """Given a basic schedule involving only compute actions (F,I,W), merge adjacent I and W ops into B ops.
  931. (note: I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD)
  932. B refers to running the whole backward (not separating grad_input and grad_weight), which can be more efficient
  933. in some cases.
  934. """
  935. merged_actions = []
  936. while compute_actions:
  937. action = compute_actions.pop(0)
  938. if action is None:
  939. continue
  940. # Remove any None actions and find the next non-None action
  941. while len(compute_actions) and compute_actions[0] is None:
  942. compute_actions.pop(0)
  943. # Get the next action if it exists
  944. next_action = compute_actions[0] if len(compute_actions) > 0 else None
  945. if (
  946. action.computation_type == BACKWARD_INPUT
  947. and next_action is not None
  948. and next_action.computation_type == BACKWARD_WEIGHT
  949. and action.stage_index == next_action.stage_index
  950. and action.microbatch_index == next_action.microbatch_index
  951. ):
  952. merged_actions.append(
  953. _Action(action.stage_index, FULL_BACKWARD, action.microbatch_index)
  954. )
  955. compute_actions.pop(0)
  956. else:
  957. merged_actions.append(action)
  958. return merged_actions
  959. def _add_send_recv(
  960. compute_actions: dict[int, list[_Action]],
  961. stage_to_rank: Callable[[int], int],
  962. num_stages: int,
  963. ) -> dict[int, list[_Action]]:
  964. """
  965. Transforms a compute-only schedule into a complete schedule with communication actions.
  966. """
  967. comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions}
  968. prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions}
  969. def _has_comms(action: _Action) -> bool:
  970. if action.computation_type == F:
  971. return action.stage_index != num_stages - 1 and stage_to_rank(
  972. action.stage_index + 1
  973. ) != stage_to_rank(action.stage_index)
  974. elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
  975. return action.stage_index != 0 and stage_to_rank(
  976. action.stage_index - 1
  977. ) != stage_to_rank(action.stage_index)
  978. return False
  979. def _get_comms(action: _Action) -> tuple[_Action, _Action]:
  980. assert _has_comms(action), f"{action} is not a valid comm action"
  981. stage_idx = action.stage_index
  982. ctype = action.computation_type
  983. mb_idx = action.microbatch_index
  984. send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx)
  985. recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1
  986. recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
  987. return send, recv
  988. def _ready_to_schedule(
  989. action: Optional[_Action], prev_actions: set[_Action]
  990. ) -> bool:
  991. """We don't put our own recv ops in the schedule, we let a sender on another rank put our recv ops in place.
  992. This helps ensure a sane (non-hanging) ordering of sends and recvs.
  993. But it also means we might not be able to schedule our next compute action yet.
  994. """
  995. if action is None:
  996. return True
  997. elif action.computation_type == F and not action.stage_index == 0:
  998. if (
  999. _Action(action.stage_index, RECV_F, action.microbatch_index)
  1000. in prev_actions
  1001. ):
  1002. return True
  1003. elif (
  1004. _Action(action.stage_index - 1, F, action.microbatch_index)
  1005. in prev_actions
  1006. ):
  1007. return True
  1008. return False
  1009. elif (
  1010. action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD)
  1011. and not action.stage_index == num_stages - 1
  1012. ):
  1013. if (
  1014. _Action(action.stage_index, RECV_B, action.microbatch_index)
  1015. in prev_actions
  1016. ):
  1017. return True
  1018. elif (
  1019. _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index)
  1020. in prev_actions
  1021. ):
  1022. return True
  1023. elif (
  1024. _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index)
  1025. in prev_actions
  1026. ):
  1027. return True
  1028. return False
  1029. else:
  1030. return True
  1031. # TODO: For now we are splitting OVERLAP_F_B into replacing it to
  1032. # its forward and backward components
  1033. # We need to figure out how to do the communication
  1034. for rank in compute_actions:
  1035. new_actions: list[_Action] = []
  1036. for action in compute_actions[rank]:
  1037. if action is not None and action.sub_actions is not None:
  1038. # Replace OVERLAP_F_B action with its sub_actions
  1039. new_actions.extend(action.sub_actions)
  1040. else:
  1041. new_actions.append(action)
  1042. compute_actions[rank] = new_actions
  1043. while compute_actions:
  1044. progress = False
  1045. # go in order of ranks even if dict keys aren't ordered
  1046. for rank in sorted(compute_actions):
  1047. assert len(compute_actions[rank]) > 0, (
  1048. f"{rank=}, {len(compute_actions[rank])=}"
  1049. )
  1050. action = compute_actions[rank][0]
  1051. if not _ready_to_schedule(action, prev_actions[rank]):
  1052. continue
  1053. if action is not None:
  1054. comm_actions[rank].append(action)
  1055. prev_actions[rank].add(action)
  1056. if _has_comms(action):
  1057. send, recv = _get_comms(action)
  1058. # TODO we can avoid send/recv if the 2 stages are on the same rank.
  1059. # should we avoid that in the runtime or here?
  1060. comm_actions[rank].append(send)
  1061. prev_actions[rank].add(send)
  1062. comm_actions[stage_to_rank(recv.stage_index)].append(recv)
  1063. prev_actions[stage_to_rank(recv.stage_index)].add(recv)
  1064. compute_actions[rank].pop(0)
  1065. if len(compute_actions[rank]) == 0:
  1066. del compute_actions[rank]
  1067. progress = True
  1068. assert progress, "Malformed compute schedule, can't schedule sends/recvs"
  1069. return comm_actions
  1070. def _validate_schedule(
  1071. actions: dict[int, list[Optional[_Action]]],
  1072. pp_group_size: int,
  1073. num_stages: int,
  1074. num_microbatches: int,
  1075. ) -> dict[int, int]:
  1076. assert len(actions) == pp_group_size, (
  1077. f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
  1078. )
  1079. for rank in range(pp_group_size):
  1080. assert rank in actions, f"Schedule is missing actions for rank {rank}"
  1081. # We will count all the actions per stage and ensure they happen in a valid order
  1082. # (e.g. F before (B, I) before W for a given microbatch)
  1083. stage_actions: dict[int, dict[_ComputationType, set]] = {
  1084. stage_id: {
  1085. F: set(),
  1086. B: set(),
  1087. I: set(),
  1088. W: set(),
  1089. }
  1090. for stage_id in range(num_stages)
  1091. }
  1092. stage_index_to_rank_mapping = {}
  1093. def _process_action(action: _Action, rank: int, step: int):
  1094. """Process a single action and update stage_actions and stage_index_to_rank_mapping"""
  1095. s_id = action.stage_index
  1096. ctype = action.computation_type
  1097. mb_id = action.microbatch_index
  1098. if ctype == F:
  1099. stage_actions[s_id][F].add(mb_id)
  1100. elif ctype == B:
  1101. if mb_id not in stage_actions[s_id][F]:
  1102. error_msg = (
  1103. f"Rank {rank}, step {step}: Running Full Backward for stage {s_id}, "
  1104. f"microbatch {mb_id} without first running Forward"
  1105. )
  1106. formatted_schedule = _format_pipeline_order(
  1107. actions, error_step_number=step
  1108. )
  1109. full_error_msg = (
  1110. f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}"
  1111. )
  1112. raise AssertionError(full_error_msg)
  1113. stage_actions[s_id][B].add(mb_id)
  1114. elif ctype == I:
  1115. if mb_id not in stage_actions[s_id][F]:
  1116. error_msg = (
  1117. f"Rank {rank}, step {step}: Running Backward Input for stage {s_id}, "
  1118. f"microbatch {mb_id} without first running Forward"
  1119. )
  1120. formatted_schedule = _format_pipeline_order(
  1121. actions, error_step_number=step
  1122. )
  1123. full_error_msg = (
  1124. f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}"
  1125. )
  1126. raise AssertionError(full_error_msg)
  1127. stage_actions[s_id][I].add(mb_id)
  1128. elif ctype == W:
  1129. if mb_id not in stage_actions[s_id][I]:
  1130. error_msg = (
  1131. f"Rank {rank}, step {step}: Running Backward Weight for stage {s_id}, "
  1132. f"microbatch {mb_id} without first running Backward Input"
  1133. )
  1134. formatted_schedule = _format_pipeline_order(
  1135. actions, error_step_number=step
  1136. )
  1137. full_error_msg = (
  1138. f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}"
  1139. )
  1140. raise AssertionError(full_error_msg)
  1141. stage_actions[s_id][W].add(mb_id)
  1142. if s_id not in stage_index_to_rank_mapping:
  1143. stage_index_to_rank_mapping[s_id] = rank
  1144. else:
  1145. existing_rank = stage_index_to_rank_mapping[s_id]
  1146. assert rank == existing_rank, (
  1147. f"Rank {rank}, step {step}: Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}"
  1148. )
  1149. for rank in actions:
  1150. for step, action in enumerate(actions[rank]):
  1151. if action is None:
  1152. continue
  1153. assert isinstance(action, _Action), (
  1154. f"Rank {rank}, step {step}: Got an invalid action: {action}, expected instance of _Action"
  1155. )
  1156. # Check if action has sub_actions
  1157. if action.sub_actions is not None:
  1158. # Process each sub_action instead of the main action
  1159. for sub_action in action.sub_actions:
  1160. _process_action(sub_action, rank, step)
  1161. else:
  1162. # Process the main action normally
  1163. _process_action(action, rank, step)
  1164. for s_id in stage_actions:
  1165. f_mb = len(stage_actions[s_id][F])
  1166. b_mb = len(stage_actions[s_id][B])
  1167. i_mb = len(stage_actions[s_id][I])
  1168. w_mb = len(stage_actions[s_id][W])
  1169. assert f_mb == num_microbatches, (
  1170. f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
  1171. )
  1172. assert i_mb == w_mb, (
  1173. f"Invalid backward microbatches for stage {s_id}: I and W must have equal counts, \
  1174. but got I={i_mb}, W={w_mb}"
  1175. )
  1176. assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, (
  1177. f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, \
  1178. but got B={b_mb}, I={i_mb}, W={w_mb}"
  1179. )
  1180. return stage_index_to_rank_mapping
  1181. class PipelineScheduleMulti(_PipelineSchedule):
  1182. """
  1183. Base class for multi-stage schedules.
  1184. Implements the `step` method.
  1185. Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting
  1186. should match the configuration of your loss_fn, which may either average losses (scale_grads=True)
  1187. or sum losses (scale_grads=False).
  1188. """
  1189. def __init__(
  1190. self,
  1191. stages: list[_PipelineStageBase],
  1192. n_microbatches: int,
  1193. loss_fn: Optional[Callable] = None,
  1194. args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
  1195. kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
  1196. output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
  1197. use_full_backward: Optional[bool] = None,
  1198. scale_grads: bool = True,
  1199. ):
  1200. # Init parent
  1201. super().__init__(
  1202. n_microbatches=n_microbatches,
  1203. loss_fn=loss_fn,
  1204. args_chunk_spec=args_chunk_spec,
  1205. kwargs_chunk_spec=kwargs_chunk_spec,
  1206. output_merge_spec=output_merge_spec,
  1207. scale_grads=scale_grads,
  1208. )
  1209. # Self attributes
  1210. self._stages = stages
  1211. self._num_stages = stages[0].num_stages
  1212. self.pp_group_size = stages[0].group_size
  1213. self.rank = stages[0].group_rank
  1214. # Set the pipeline stage states
  1215. self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
  1216. self.pp_group_size, self._num_stages
  1217. )
  1218. for stage in self._stages:
  1219. stage.stage_index_to_group_rank = self.stage_index_to_group_rank
  1220. self._stages_initialized = False
  1221. # avoid putting a reference to 'self' inside the lambda, it creates a ref cycle
  1222. has_loss: bool = self._loss_fn is not None
  1223. self._should_compute_loss = lambda stage: stage.is_last and has_loss
  1224. # This will be set during init of derived schedules
  1225. self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
  1226. if use_full_backward is not None:
  1227. logger.warning(
  1228. "Deprecation warning: 'use_full_backward' is no longer supported. "
  1229. "Simply stop passing it, and everything should still work fine."
  1230. )
  1231. def _initialize_stages(self, args: tuple[Any, ...], kwargs):
  1232. # Prepare the communication needed for the pipeline schedule execution
  1233. # This is needed because during execution we always perform a series of batch P2P ops
  1234. # The first call of the batched P2P needs to involve the global group
  1235. all_ops: list[dist.P2POp] = []
  1236. for stage in self._stages:
  1237. all_ops.extend(stage._get_init_p2p_neighbors_ops())
  1238. _wait_batch_p2p(_batch_p2p(all_ops))
  1239. # may be 'none' value (if this stage sends its output shapes to the next stage via P2P)
  1240. # or real value (if this stage and next stage are on the same device)
  1241. next_stage_args: tuple[Any, ...] = tuple()
  1242. for stage in self._stages:
  1243. if stage.is_first:
  1244. next_stage_args = stage._prepare_forward_infra(
  1245. self._n_microbatches, args, kwargs
  1246. )
  1247. else:
  1248. next_stage_args = stage._prepare_forward_infra(
  1249. self._n_microbatches, next_stage_args, kwargs
  1250. )
  1251. if self._has_backward:
  1252. stage._prepare_backward_infra(self._n_microbatches)
  1253. self._stages_initialized = True
  1254. def _validate_and_set_stage_mapping(
  1255. self, actions: dict[int, list[Optional[_Action]]]
  1256. ) -> None:
  1257. """
  1258. Allocates the stage index to rank mapping which is needed for communication
  1259. """
  1260. self.stage_index_to_group_rank = _validate_schedule(
  1261. actions,
  1262. self.pp_group_size,
  1263. self._num_stages,
  1264. self._n_microbatches,
  1265. )
  1266. for stage in self._stages:
  1267. stage.stage_index_to_group_rank = self.stage_index_to_group_rank
  1268. def _dump_csv(self, filename):
  1269. """Dump a CSV representation of the schedule into a file with the provided filename."""
  1270. with open(filename, "w", newline="") as csvfile:
  1271. writer = csv.writer(csvfile)
  1272. for rank in self.pipeline_order:
  1273. writer.writerow(self.pipeline_order[rank])
  1274. def _load_csv(self, filename, format="compute_only"):
  1275. """Load a CSV representation of the schedule from a file with the provided filename.
  1276. This API will most likely get renamed/refactored so is marked as internal for now.
  1277. format must be "compute_only" for PipelineScheduleMulti.
  1278. """
  1279. assert format == "compute_only"
  1280. with open(filename, newline="") as csvfile:
  1281. reader = csv.reader(csvfile)
  1282. for rank, row in enumerate(reader):
  1283. self.pipeline_order[rank] = [_Action.from_str(s) for s in row]
  1284. # Validates the order of the pipeline actions and infers the stage_to_rank_mapping.
  1285. # This will overwrite the default stage_to_rank_mapping created in the constructor
  1286. self._validate_and_set_stage_mapping(self.pipeline_order)
  1287. def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
  1288. """
  1289. Run one iteration of the pipeline schedule with *whole-batch* input.
  1290. Will chunk the input into microbatches automatically, and go through the
  1291. microbatches according to the schedule implementation.
  1292. args: positional arguments to the model (as in non-pipeline case).
  1293. kwargs: keyword arguments to the model (as in non-pipeline case).
  1294. target: target for the loss function.
  1295. losses: a list to store the losses for each microbatch.
  1296. """
  1297. if self._has_backward and not torch.is_grad_enabled():
  1298. raise RuntimeError(
  1299. "step() requires gradients to be enabled for backward computation; "
  1300. "it should not be used under torch.no_grad() context. "
  1301. "Please call eval() instead."
  1302. )
  1303. # Set the same has_backward flag for stage object
  1304. for stage in self._stages:
  1305. stage.has_backward = self._has_backward
  1306. # Clean per iteration
  1307. for stage in self._stages:
  1308. stage.clear_runtime_states()
  1309. # Split inputs into microbatches
  1310. args_split, kwargs_split = self._split_inputs(args, kwargs)
  1311. # Split target into microbatches
  1312. if target is not None:
  1313. targets_split = list(torch.tensor_split(target, self._n_microbatches))
  1314. else:
  1315. targets_split = None
  1316. # Run microbatches
  1317. self._step_microbatches(args_split, kwargs_split, targets_split, losses)
  1318. # Return merged results per original format
  1319. for stage in self._stages:
  1320. if stage.is_last:
  1321. return self._merge_outputs(stage.output_chunks)
  1322. # Does not contain the last stage
  1323. return None
    def _step_microbatches(
        self,
        arg_mbs: Optional[list] = None,
        kwarg_mbs: Optional[list] = None,
        target_mbs: Optional[list] = None,
        losses: Optional[list] = None,
    ):
        """
        Operate on the microbatches for looped schedules (multiple stages on each rank).

        Walks ``self.pipeline_order[self.rank]`` one time step at a time. At each step:
          1. run this rank's own compute action (if any) and queue the matching send ops;
          2. look at what the neighboring ranks (those hosting stage_index +/- 1 of a
             local stage) do at the same time step, and queue the matching recv ops;
          3. issue all queued ops as one batched P2P call and wait on it.

        Args:
            arg_mbs: per-microbatch positional args (normalized by ``_check_inputs``).
            kwarg_mbs: per-microbatch keyword args (normalized by ``_check_inputs``).
            target_mbs: per-microbatch targets, passed to ``_maybe_compute_loss``.
            losses: optional container, passed to ``_check_inputs`` and ``_update_losses``.

        Raises:
            ValueError: for an unknown computation type in this rank's or a neighbor's
                schedule. Any exception raised during a time step is logged together
                with the formatted pipeline order and then re-raised.

        TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
        not support models with skip connections.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)

        if not self._stages_initialized:
            self._initialize_stages(arg_mbs[0], kwarg_mbs[0])

        # Based on the plan in Step 1 created in __init__:
        # 2. Perform communication based on the pipeline_order
        stage_index_to_stage: dict[int, _PipelineStageBase] = {
            stage.stage_index: stage for stage in self._stages
        }

        # determine prev_rank and next_rank based on which ranks are next to
        # the stages in the pipeline_order
        all_prev_ranks: set[int] = set()
        all_next_ranks: set[int] = set()
        for stage_index in stage_index_to_stage.keys():
            # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections)
            if stage_index > 0:
                all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1])
            if stage_index < self._num_stages - 1:
                all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])
        # count either full_backward or backward_weight together, to determine when to sync DP grads
        backward_counter: Counter[int] = Counter()
        for time_step, action in enumerate(self.pipeline_order[self.rank]):
            try:
                # P2P ops (sends for our own compute, recvs for neighbors' compute)
                # collected for this time step and issued together at the end.
                ops: list[dist.P2POp] = []
                if action is not None:
                    computation_type = action.computation_type
                    mb_index = action.microbatch_index
                    stage_index = action.stage_index
                    assert mb_index is not None, (
                        "All currently supported action types require valid microbatch_index"
                    )
                    if computation_type == _ComputationType.FORWARD:
                        # perform forward computation
                        stage = stage_index_to_stage[stage_index]
                        output = stage.forward_one_chunk(
                            mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
                        )
                        self._maybe_compute_loss(stage, output, target_mbs, mb_index)
                        ops.extend(stage.get_fwd_send_ops(mb_index))
                    elif computation_type == _ComputationType.FULL_BACKWARD:
                        # perform backward computation
                        stage = stage_index_to_stage[stage_index]
                        loss = self._maybe_get_loss(stage, mb_index)
                        backward_counter[stage_index] += 1
                        # The n-th backward for this stage is the last one of the step;
                        # grads are rescaled only then.
                        last_backward = (
                            backward_counter[stage_index] == self._n_microbatches
                        )
                        grad_scale_factor = (
                            self._n_microbatches if self.scale_grads else 1
                        )
                        stage.backward_one_chunk(
                            mb_index,
                            loss=loss,
                            full_backward=True,
                            last_backward=last_backward,
                        )
                        if last_backward:
                            stage.scale_grads(grad_scale_factor)

                        ops.extend(stage.get_bwd_send_ops(mb_index))
                    elif computation_type == _ComputationType.BACKWARD_INPUT:
                        # perform backward computation
                        # (input-grad half only; never counted as the "last" backward)
                        stage = stage_index_to_stage[stage_index]
                        loss = self._maybe_get_loss(stage, mb_index)
                        stage.backward_one_chunk(
                            mb_index,
                            loss=loss,
                            full_backward=False,
                            last_backward=False,
                        )
                        ops.extend(stage.get_bwd_send_ops(mb_index))
                    elif computation_type == _ComputationType.BACKWARD_WEIGHT:
                        # perform weight update
                        # (weight-grad half; shares backward_counter with FULL_BACKWARD)
                        stage = stage_index_to_stage[stage_index]
                        backward_counter[stage_index] += 1
                        last_backward = (
                            backward_counter[stage_index] == self._n_microbatches
                        )
                        grad_scale_factor = (
                            self._n_microbatches if self.scale_grads else 1
                        )
                        stage.backward_weight_one_chunk(
                            mb_index,
                            last_backward=last_backward,
                        )
                        if last_backward:
                            stage.scale_grads(grad_scale_factor)
                    else:
                        raise ValueError(f"Unknown computation type {computation_type}")

                # Look at the neighboring ranks for this current timestep and determine whether
                # this current rank needs to do any recv communication
                for prev_rank in all_prev_ranks:
                    prev_rank_ops = self.pipeline_order[prev_rank]
                    prev_rank_action = None
                    # A neighbor's schedule may be shorter than ours; treat missing steps as idle.
                    if time_step < len(prev_rank_ops):
                        prev_rank_action = prev_rank_ops[time_step]
                    if prev_rank_action is not None:
                        computation_type = prev_rank_action.computation_type
                        mb_index = prev_rank_action.microbatch_index
                        stage_index = prev_rank_action.stage_index
                        assert mb_index is not None, (
                            "All currently supported action types require valid microbatch_index"
                        )
                        # Only a forward on the previous rank requires a matching recv here
                        # (of its forward activations into our stage_index + 1).
                        if computation_type == _ComputationType.FORWARD:
                            # If not the last stage, then receive fwd activations
                            if stage_index + 1 in stage_index_to_stage:
                                # TODO: We are assuming that stage will always receive from stage-1
                                # however that is not necessarily true of get_fwd_recv_ops
                                stage = stage_index_to_stage[stage_index + 1]
                                ops.extend(stage.get_fwd_recv_ops(mb_index))
                        elif computation_type in (
                            FULL_BACKWARD,
                            BACKWARD_INPUT,
                            BACKWARD_WEIGHT,
                        ):
                            # Previous rank doing backward has no influence for the current rank forward recv
                            pass
                        else:
                            raise ValueError(
                                f"Unknown computation type {computation_type}"
                            )
                for next_rank in all_next_ranks:
                    next_rank_ops = self.pipeline_order[next_rank]
                    next_rank_action = None
                    # A neighbor's schedule may be shorter than ours; treat missing steps as idle.
                    if time_step < len(next_rank_ops):
                        next_rank_action = next_rank_ops[time_step]
                    if next_rank_action is not None:
                        computation_type = next_rank_action.computation_type
                        mb_index = next_rank_action.microbatch_index
                        stage_index = next_rank_action.stage_index
                        assert mb_index is not None, (
                            "All currently supported action types require valid microbatch_index"
                        )
                        # Only handle receives for the backwards from a next rank
                        if computation_type in (FORWARD, BACKWARD_WEIGHT):
                            # Next rank doing forward or weight update has no influence for the current rank backward recv
                            pass
                        elif computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
                            # If not the first stage, then receive bwd gradients
                            if stage_index - 1 in stage_index_to_stage:
                                # TODO: We are assuming that stage will always receive from stage+1
                                # however that is not necessarily true of get_bwd_recv_ops
                                stage = stage_index_to_stage[stage_index - 1]
                                ops.extend(stage.get_bwd_recv_ops(mb_index))
                        else:
                            raise ValueError(
                                f"Unknown computation type {computation_type}"
                            )

                # do the communication
                _wait_batch_p2p(_batch_p2p(ops))
            except Exception as e:
                logger.error(
                    "[Rank %s] pipeline schedule %s caught the following exception '%s' \
at time_step %s when running action %s",
                    self.rank,
                    self.__class__.__name__,
                    str(e),
                    time_step,
                    action,
                )
                logger.error(
                    "%s",
                    _format_pipeline_order(
                        self.pipeline_order, error_step_number=time_step
                    ),
                )
                raise e
        # Return losses if there is a container passed in
        self._update_losses(self._stages, losses)
  1504. class _PipelineScheduleRuntime(PipelineScheduleMulti):
  1505. """
  1506. Provides a simple runtime that requires a 'schedule IR' including specified communication operations.
  1507. Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be
  1508. subclassed and the subclass can be responsible for creating a schedule IR.
  1509. """
  1510. def _prepare_schedule_with_comms(
  1511. self,
  1512. actions: dict[int, list[Optional[_Action]]],
  1513. format: str = "compute_only",
  1514. ):
  1515. """
  1516. Given an in-memory representation for a simple compute-only schedule, lower it to a complex schedule including
  1517. communication actions. Stores the schedule in self, and must be called before running step_mo()
  1518. """
  1519. # validate the provided actions are valid and overrides the default stage_index_to_group_rank
  1520. super()._validate_and_set_stage_mapping(actions)
  1521. self.pipeline_order_with_comms: dict[int, list[_Action]] = {}
  1522. if format == "compute_comms":
  1523. for rank in actions:
  1524. self.pipeline_order_with_comms[rank] = []
  1525. for action in actions[rank]:
  1526. assert action is not None
  1527. self.pipeline_order_with_comms[rank].append(action)
  1528. # TODO what level of validation should we offer for compute+comms schedule?
  1529. elif format == "compute_only":
  1530. # Validate that the schedule does not have comms already added to it
  1531. for rank, action_list in actions.items():
  1532. for i, action in enumerate(action_list):
  1533. if action is not None and not action.is_compute_op:
  1534. raise ValueError(
  1535. f"Expected compute-only schedule but found communication action "
  1536. f"'{action}' at rank {rank}, position {i}. "
  1537. f"Communication actions (e.g. SEND_F, RECV_F, etc.) "
  1538. f"should not be present when format='compute_only'."
  1539. )
  1540. # Perform schedule lowering
  1541. for rank in actions:
  1542. self.pipeline_order_with_comms[rank] = _add_unshard_reshard(
  1543. actions[rank]
  1544. )
  1545. self.pipeline_order_with_comms = _add_send_recv(
  1546. self.pipeline_order_with_comms,
  1547. stage_to_rank=lambda s: self.stage_index_to_group_rank[s],
  1548. num_stages=self._num_stages,
  1549. )
  1550. else:
  1551. raise NotImplementedError(f"{format=} is not implemented")
  1552. def _load_csv(self, filename: str, format: str = "compute_only"):
  1553. """Loads a csv in simple format and then lowers it to include communication actions
  1554. format must be either "compute_only" or "compute_comms". If compute_only, the lowering passes
  1555. will automatically be run to generate a compute_comms schedule.
  1556. """
  1557. if format == "compute_only":
  1558. # this will populate self.pipeline_order
  1559. super()._load_csv(filename)
  1560. # this will populate self.pipeline_order_with_comms
  1561. self._prepare_schedule_with_comms(self.pipeline_order)
  1562. elif format == "compute_comms":
  1563. actions = {}
  1564. with open(filename, newline="") as csvfile:
  1565. reader = csv.reader(csvfile)
  1566. for rank, row in enumerate(reader):
  1567. actions[rank] = [_Action.from_str(s) for s in row]
  1568. self._prepare_schedule_with_comms(actions, format=format)
  1569. else:
  1570. raise NotImplementedError(f"{format=} is not implemented")
  1571. def _dump_csv(self, filename: str, format: str = "compute_comms"):
  1572. """Dump a CSV representation of the schedule into a file with the provided filename."""
  1573. if format == "compute_only":
  1574. assert self.pipeline_order is not None, (
  1575. "Compute only schedule must be available"
  1576. )
  1577. with open(filename, "w", newline="") as csvfile:
  1578. writer = csv.writer(csvfile)
  1579. for rank in self.pipeline_order:
  1580. writer.writerow(self.pipeline_order[rank])
  1581. elif format == "compute_comms":
  1582. assert self.pipeline_order_with_comms is not None, (
  1583. "Must initialize compute_comms schedule before dump_csv"
  1584. )
  1585. with open(filename, "w", newline="") as csvfile:
  1586. writer = csv.writer(csvfile)
  1587. for rank in self.pipeline_order_with_comms:
  1588. writer.writerow(self.pipeline_order_with_comms[rank])
  1589. def _simulate(self):
  1590. return _simulate_comms_compute(
  1591. self.pipeline_order_with_comms,
  1592. lambda s: self.stage_index_to_group_rank[s],
  1593. self._num_stages,
  1594. )
  1595. def _step_microbatches(
  1596. self,
  1597. arg_mbs: Optional[list] = None,
  1598. kwarg_mbs: Optional[list] = None,
  1599. target_mbs: Optional[list] = None,
  1600. losses: Optional[list] = None,
  1601. ):
  1602. """
  1603. Operate on the microbatches for looped schedules (multiple stages on each rank).
  1604. TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
  1605. not support models with skip connections.
  1606. """
  1607. arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
  1608. if not self._stages_initialized:
  1609. self._initialize_stages(arg_mbs[0], kwarg_mbs[0])
  1610. # Based on the plan in Step 1 created in __init__:
  1611. # 2. Perform communication based on the pipeline_order
  1612. stage_index_to_stage: dict[int, _PipelineStageBase] = {
  1613. stage.stage_index: stage for stage in self._stages
  1614. }
  1615. assert self.pipeline_order_with_comms is not None, (
  1616. "Must call _prepare_schedule_with_comms() before calling _step_microbatches()"
  1617. )
  1618. # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
  1619. bwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {}
  1620. fwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {}
  1621. # send ops should be waited on before step() exists, mainly for hygiene
  1622. send_ops: list[list[dist.Work]] = []
  1623. # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages
  1624. unshard_ops: dict[int, UnshardHandle] = {}
  1625. unsharded_stages = set()
  1626. def _assert_unsharded(stage_idx: int):
  1627. """If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unshared."""
  1628. if stage_idx in unshard_ops:
  1629. unshard_ops[stage_idx].wait()
  1630. del unshard_ops[stage_idx]
  1631. unsharded_stages.add(stage_idx)
  1632. assert stage_idx in unsharded_stages, (
  1633. f"Attempted to compute on sharded {stage_idx=}"
  1634. )
  1635. # count either full_backward or backward_weight together, to determine when to sync DP grads
  1636. backward_counter: Counter[int] = Counter()
  1637. for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]):
  1638. try:
  1639. comp_type = action.computation_type
  1640. mb_index: int = (
  1641. action.microbatch_index
  1642. if action.microbatch_index is not None
  1643. else -1
  1644. )
  1645. assert mb_index >= 0 or comp_type in (
  1646. UNSHARD,
  1647. RESHARD,
  1648. ), f"{action=} missing mb_index"
  1649. stage_idx = action.stage_index
  1650. stage = stage_index_to_stage[stage_idx]
  1651. stage_uses_fsdp = isinstance(stage.submod, FSDPModule)
  1652. # see [Note: V-schedule special case]
  1653. is_next_stage_on_this_rank = stage_idx + 1 in stage_index_to_stage
  1654. is_prev_stage_on_this_rank = stage_idx - 1 in stage_index_to_stage
  1655. logger.debug(
  1656. "_PipelineScheduleRuntime running time_step %d, action %s",
  1657. time_step,
  1658. action,
  1659. )
  1660. with record_function(_get_profiler_function_name(action)):
  1661. # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections,
  1662. # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be
  1663. # safe to use instead.
  1664. # However, I was wondering if I should avoid calling batched operators at all in the case that there is
  1665. # only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them.
  1666. if comp_type == SEND_F:
  1667. send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index)))
  1668. elif comp_type == SEND_B:
  1669. send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index)))
  1670. elif comp_type == RECV_F:
  1671. assert (
  1672. stage_idx,
  1673. mb_index,
  1674. ) not in fwd_recv_ops, (
  1675. "Recv twice for {stage_idx=} {mb_index=} without executing forward"
  1676. )
  1677. fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
  1678. stage.get_fwd_recv_ops(mb_index)
  1679. )
  1680. elif comp_type == RECV_B:
  1681. assert (
  1682. stage_idx,
  1683. mb_index,
  1684. ) not in bwd_recv_ops, (
  1685. "Recv twice for {stage_idx=} {mb_index=} without executing backward"
  1686. )
  1687. bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
  1688. stage.get_bwd_recv_ops(mb_index)
  1689. )
  1690. elif comp_type == UNSHARD:
  1691. if stage_uses_fsdp:
  1692. assert (
  1693. stage_idx not in unsharded_stages
  1694. and stage_idx not in unshard_ops
  1695. ), f"Unsharding the same {stage_idx=} twice"
  1696. unshard_ops[stage_idx] = stage.submod.unshard(async_op=True) # type: ignore[operator]
  1697. elif comp_type == RESHARD:
  1698. if stage_uses_fsdp:
  1699. assert stage_idx in unsharded_stages, (
  1700. f"Resharding {stage_idx=} without unsharding"
  1701. )
  1702. assert stage_idx not in unshard_ops, (
  1703. f"Resharding {stage_idx=} before finishing unshard"
  1704. )
  1705. stage.submod.reshard() # type: ignore[operator]
  1706. elif comp_type == FORWARD:
  1707. if stage_uses_fsdp:
  1708. _assert_unsharded(stage_idx)
  1709. if (
  1710. not stage.is_first
  1711. # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
  1712. and not is_prev_stage_on_this_rank
  1713. ):
  1714. assert (
  1715. stage_idx,
  1716. mb_index,
  1717. ) in fwd_recv_ops, (
  1718. f"Computing {action=} before receiving input"
  1719. )
  1720. _wait_batch_p2p(fwd_recv_ops.pop((stage_idx, mb_index)))
  1721. output = stage.forward_one_chunk(
  1722. mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
  1723. )
  1724. self._maybe_compute_loss(stage, output, target_mbs, mb_index)
  1725. # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
  1726. # see [Note: V-schedule special case]
  1727. if is_next_stage_on_this_rank:
  1728. stage_index_to_stage[stage_idx + 1].set_local_fwd_input(
  1729. output, mb_index
  1730. )
  1731. elif comp_type == FULL_BACKWARD:
  1732. if stage_uses_fsdp:
  1733. _assert_unsharded(stage_idx)
  1734. if (
  1735. not stage.is_last
  1736. # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
  1737. and not is_next_stage_on_this_rank
  1738. ):
  1739. assert (
  1740. stage_idx,
  1741. mb_index,
  1742. ) in bwd_recv_ops, (
  1743. f"Attempted to run compute {action=} before receiving input"
  1744. )
  1745. _wait_batch_p2p(bwd_recv_ops.pop((stage_idx, mb_index)))
  1746. loss = self._maybe_get_loss(stage, mb_index)
  1747. backward_counter[stage_idx] += 1
  1748. last_backward = (
  1749. backward_counter[stage_idx] == self._n_microbatches
  1750. )
  1751. grad_scale_factor = (
  1752. self._n_microbatches if self.scale_grads else 1
  1753. )
  1754. stage.backward_one_chunk(
  1755. mb_index,
  1756. loss=loss,
  1757. full_backward=True,
  1758. last_backward=last_backward,
  1759. )
  1760. if last_backward:
  1761. stage.scale_grads(grad_scale_factor)
  1762. # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
  1763. # see [Note: V-schedule special case]
  1764. if is_prev_stage_on_this_rank:
  1765. stage_index_to_stage[stage_idx - 1].set_local_bwd_input(
  1766. stage.get_local_bwd_output(mb_index), mb_index
  1767. )
  1768. elif comp_type == BACKWARD_INPUT:
  1769. if stage_uses_fsdp:
  1770. _assert_unsharded(stage_idx)
  1771. if not stage.is_last and not is_next_stage_on_this_rank:
  1772. assert (
  1773. stage_idx,
  1774. mb_index,
  1775. ) in bwd_recv_ops, (
  1776. f"Attempted to run compute {action=} before receiving input"
  1777. )
  1778. _wait_batch_p2p(bwd_recv_ops.pop((stage_idx, mb_index)))
  1779. loss = self._maybe_get_loss(stage, mb_index)
  1780. stage.backward_one_chunk(
  1781. mb_index,
  1782. loss=loss,
  1783. full_backward=False,
  1784. last_backward=False,
  1785. )
  1786. # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
  1787. # see [Note: V-schedule special case]
  1788. if is_prev_stage_on_this_rank:
  1789. stage_index_to_stage[stage_idx - 1].set_local_bwd_input(
  1790. stage.get_local_bwd_output(mb_index), mb_index
  1791. )
  1792. elif comp_type == BACKWARD_WEIGHT:
  1793. if stage_uses_fsdp:
  1794. _assert_unsharded(stage_idx)
  1795. backward_counter[stage_idx] += 1
  1796. stage.backward_weight_one_chunk(
  1797. mb_index,
  1798. last_backward=backward_counter[stage_idx]
  1799. == self._n_microbatches,
  1800. )
  1801. else:
  1802. raise ValueError(f"{action=} is unknown or unsupported")
  1803. except Exception as e:
  1804. logger.error(
  1805. "_PipelineScheduleRuntime caught exception at step %s when running action %s. Full Schedule:",
  1806. time_step,
  1807. action,
  1808. )
  1809. # TODO(whc) what is the best practice for printing a multiline log?
  1810. # logger will split it into multiple log lines, but this makes it hard to read (too wide)
  1811. print(
  1812. _format_pipeline_order(
  1813. self.pipeline_order_with_comms, # type: ignore[arg-type]
  1814. error_step_number=time_step,
  1815. )
  1816. )
  1817. raise e
  1818. # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them
  1819. while len(send_ops):
  1820. _wait_batch_p2p(send_ops.pop())
  1821. assert len(unshard_ops) == 0, "Unused unshard operations"
  1822. # Return losses if there is a container passed in
  1823. self._update_losses(self._stages, losses)
  1824. class ScheduleLoopedBFS(PipelineScheduleMulti):
  1825. """
  1826. Breadth-First Pipeline Parallelism.
  1827. See https://arxiv.org/abs/2211.05953 for details.
  1828. Similar to Interleaved 1F1B, Looped BFS supports multiple stages per rank.
  1829. What is different is that when microbatches are ready for multiple local
  1830. stages, Loops BFS will prioritizes the earlier stage, running all available
  1831. microbatches at once.
  1832. """
  1833. def __init__(
  1834. self,
  1835. stages: list[_PipelineStageBase],
  1836. n_microbatches: int,
  1837. loss_fn: Optional[Union[Callable, _Loss]] = None,
  1838. output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
  1839. scale_grads: bool = True,
  1840. ):
  1841. super().__init__(
  1842. stages=stages,
  1843. n_microbatches=n_microbatches,
  1844. loss_fn=loss_fn,
  1845. output_merge_spec=output_merge_spec,
  1846. scale_grads=scale_grads,
  1847. )
  1848. # 1. Create the pipeline_order (all ranks do this calculation)
  1849. # This will be used to keep track of the current state of the entire pipeline
  1850. # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
  1851. self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
  1852. # ========================================================================
  1853. for rank in range(self.pp_group_size):
  1854. rank_ops = self._calculate_single_rank_operations(rank)
  1855. self.pipeline_order[rank] = rank_ops
  1856. def _calculate_single_rank_operations(self, rank):
  1857. n_local_stages = len(self._stages)
  1858. stage_indices = range(
  1859. rank, self.pp_group_size * n_local_stages, self.pp_group_size
  1860. )
  1861. # Store the list of operations used for that rank
  1862. # Pre-padding, rank starts with no-ops based on the warmup.
  1863. rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
  1864. for stage_index in stage_indices:
  1865. rank_ops.extend(
  1866. _Action(stage_index, _ComputationType.FORWARD, mb_index)
  1867. for mb_index in range(self._n_microbatches)
  1868. )
  1869. # wait for the first backward to trickle up
  1870. # which is 2 for every hop away
  1871. post_warmup_ops = 2 * (self.pp_group_size - 1 - rank)
  1872. rank_ops.extend([None] * post_warmup_ops)
  1873. for stage_index in reversed(stage_indices):
  1874. rank_ops.extend(
  1875. _Action(stage_index, _ComputationType.FULL_BACKWARD, mb_index)
  1876. for mb_index in reversed(range(self._n_microbatches))
  1877. )
  1878. return rank_ops
  1879. def _get_1f1b_rank_ops(
  1880. n_local_stages,
  1881. pp_group_size,
  1882. warmup_ops,
  1883. fwd_bwd_ops,
  1884. cooldown_ops,
  1885. rank,
  1886. forward_stage_index,
  1887. backward_stage_index,
  1888. num_1f1b_microbatches=0,
  1889. enable_zero_bubble=False,
  1890. ):
  1891. # All stages start with handling microbatch 0
  1892. fwd_stage_mb_index: dict[int, int] = defaultdict(int)
  1893. bwd_stage_mb_index: dict[int, int] = defaultdict(int)
  1894. weight_stage_mb_index: dict[int, int] = defaultdict(int)
  1895. # Store the list of operations used for that rank
  1896. # Pre-padding, rank starts with no-ops based on the warmup.
  1897. rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
  1898. # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
  1899. # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
  1900. # Formula:
  1901. # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
  1902. # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
  1903. # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
  1904. # warmup_ops = calculated above
  1905. post_warmup_ops = (
  1906. n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
  1907. ) - (warmup_ops + rank)
  1908. if enable_zero_bubble:
  1909. post_warmup_ops = pp_group_size - rank - 1
  1910. total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
  1911. backward_op_ids = []
  1912. weight_op_count = 0
  1913. FULL_BACKWARD_OR_BACKWARD_INPUT = (
  1914. BACKWARD_INPUT if enable_zero_bubble else FULL_BACKWARD
  1915. )
  1916. for op in range(total_ops):
  1917. # Warmup phase
  1918. if op < warmup_ops:
  1919. fwd_stage_index = forward_stage_index(op)
  1920. # This will assign the current microbatch index and update it as well
  1921. fwd_stage_mb_index[fwd_stage_index] = (
  1922. mb_index := fwd_stage_mb_index[fwd_stage_index]
  1923. ) + 1
  1924. rank_ops.append(
  1925. _Action(fwd_stage_index, _ComputationType.FORWARD, mb_index)
  1926. )
  1927. if op == warmup_ops - 1:
  1928. # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
  1929. rank_ops.extend([None] * post_warmup_ops)
  1930. # 1F1B Phase (forward and backward)
  1931. elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
  1932. fwd_stage_index = forward_stage_index(op)
  1933. fwd_stage_mb_index[fwd_stage_index] = (
  1934. fwd_mb_index := fwd_stage_mb_index[fwd_stage_index]
  1935. ) + 1
  1936. rank_ops.append(
  1937. _Action(fwd_stage_index, _ComputationType.FORWARD, fwd_mb_index)
  1938. )
  1939. bwd_stage_index = backward_stage_index(op)
  1940. bwd_stage_mb_index[bwd_stage_index] = (
  1941. bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
  1942. ) + 1
  1943. rank_ops.append(
  1944. _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index)
  1945. )
  1946. backward_op_ids.append(op)
  1947. if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
  1948. weight_stage_index = backward_stage_index(
  1949. backward_op_ids[weight_op_count]
  1950. )
  1951. weight_stage_mb_index[weight_stage_index] = (
  1952. weight_mb_index := weight_stage_mb_index[weight_stage_index]
  1953. ) + 1
  1954. rank_ops.append(
  1955. _Action(
  1956. weight_stage_index,
  1957. _ComputationType.BACKWARD_WEIGHT,
  1958. weight_mb_index,
  1959. )
  1960. )
  1961. weight_op_count += 1
  1962. # Cooldown phase
  1963. else:
  1964. # During cooldown phase, we need steps to align with 1f1b happening in other ranks
  1965. # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
  1966. if not enable_zero_bubble:
  1967. rank_ops.append(None)
  1968. bwd_stage_index = backward_stage_index(op)
  1969. bwd_stage_mb_index[bwd_stage_index] = (
  1970. bwd_mb_index := bwd_stage_mb_index[bwd_stage_index]
  1971. ) + 1
  1972. rank_ops.append(
  1973. _Action(bwd_stage_index, FULL_BACKWARD_OR_BACKWARD_INPUT, bwd_mb_index)
  1974. )
  1975. backward_op_ids.append(op)
  1976. if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
  1977. weight_stage_index = backward_stage_index(
  1978. backward_op_ids[weight_op_count]
  1979. )
  1980. weight_stage_mb_index[weight_stage_index] = (
  1981. weight_mb_index := weight_stage_mb_index[weight_stage_index]
  1982. ) + 1
  1983. rank_ops.append(
  1984. _Action(
  1985. weight_stage_index,
  1986. _ComputationType.BACKWARD_WEIGHT,
  1987. weight_mb_index,
  1988. )
  1989. )
  1990. weight_op_count += 1
  1991. while enable_zero_bubble and weight_op_count < len(backward_op_ids):
  1992. weight_stage_index = backward_stage_index(backward_op_ids[weight_op_count])
  1993. weight_stage_mb_index[weight_stage_index] = (
  1994. weight_mb_index := weight_stage_mb_index[weight_stage_index]
  1995. ) + 1
  1996. rank_ops.append(
  1997. _Action(
  1998. weight_stage_index, _ComputationType.BACKWARD_WEIGHT, weight_mb_index
  1999. )
  2000. )
  2001. weight_op_count += 1
  2002. return rank_ops
  2003. class ScheduleInterleaved1F1B(PipelineScheduleMulti):
  2004. """
  2005. The Interleaved 1F1B schedule.
  2006. See https://arxiv.org/pdf/2104.04473 for details.
  2007. Will perform one forward and one backward on the microbatches in steady
  2008. state and supports multiple stages per rank. When microbatches are ready for
  2009. multiple local stages, Interleaved 1F1B prioritizes the earlier microbatch
  2010. (also called "depth first").
  2011. This schedule is mostly similar to the original paper.
  2012. It differs by being relaxing the requirement of num_microbatch % pp_size == 0.
  2013. Using the flex_pp schedule, we will have num_rounds = max(1, n_microbatches // pp_group_size) and
  2014. it works as long as n_microbatches % num_rounds is 0. As a few examples, support
  2015. 1. pp_group_size = 4, n_microbatches = 10. We will have num_rounds = 2 and n_microbatches % 2 is 0.
  2016. 2. pp_group_size = 4, n_microbatches = 3. We will have num_rounds = 1 and n_microbatches % 1 is 0.
  2017. """
  2018. def __init__(
  2019. self,
  2020. stages: list[_PipelineStageBase],
  2021. n_microbatches: int,
  2022. loss_fn: Optional[Callable] = None,
  2023. args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
  2024. kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
  2025. output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
  2026. scale_grads: bool = True,
  2027. ):
  2028. self.pp_group_size = stages[0].group_size
  2029. super().__init__(
  2030. stages=stages,
  2031. n_microbatches=n_microbatches,
  2032. loss_fn=loss_fn,
  2033. args_chunk_spec=args_chunk_spec,
  2034. kwargs_chunk_spec=kwargs_chunk_spec,
  2035. output_merge_spec=output_merge_spec,
  2036. scale_grads=scale_grads,
  2037. )
  2038. self.n_local_stages = len(stages)
  2039. self.rank = stages[0].group_rank
  2040. self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
  2041. self.microbatches_per_round = n_microbatches // self.number_of_rounds
  2042. if n_microbatches % self.number_of_rounds != 0:
  2043. raise ValueError(
  2044. "Interleaved 1F1B requires the number of microbatches to be a "
  2045. f"multiple of the number of rounds ({self.number_of_rounds}), "
  2046. f"but got {n_microbatches}."
  2047. )
  2048. # 1. Create the pipeline_order (all ranks do this calculation)
  2049. # This will be used to keep track of the current state of the entire pipeline
  2050. # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
  2051. self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
  2052. for rank in range(self.pp_group_size):
  2053. rank_ops = self._calculate_single_rank_operations(rank)
  2054. self.pipeline_order[rank] = rank_ops
  2055. def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
  2056. def get_rank_warmup_ops(rank):
  2057. # Warms up operations for last stage
  2058. warmups_ops_last_stage = (
  2059. self.n_local_stages - 1
  2060. ) * self.microbatches_per_round
  2061. # Increment warmup operations by 2 for each hop away from the last stage
  2062. multiply_factor = 2
  2063. warmup_ops = warmups_ops_last_stage + multiply_factor * (
  2064. (self.pp_group_size - 1) - rank
  2065. )
  2066. # We cannot have more warmup operations than there are number of microbatches, so cap it there
  2067. return min(warmup_ops, self._n_microbatches * self.n_local_stages)
  2068. warmup_ops = get_rank_warmup_ops(rank)
  2069. microbatch_ops = self.n_local_stages * self._n_microbatches
  2070. # fwd_bwd_ops should encompass the remaining forwards
  2071. fwd_bwd_ops = microbatch_ops - warmup_ops
  2072. # cooldown_ops should encompass the remaining backwards
  2073. cooldown_ops = microbatch_ops - fwd_bwd_ops
  2074. # total ops encompass both forward and backward ops
  2075. total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
  2076. # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
  2077. logger.debug(
  2078. "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
  2079. rank,
  2080. warmup_ops,
  2081. fwd_bwd_ops,
  2082. cooldown_ops,
  2083. total_ops,
  2084. )
  2085. # Calculates the stage index based on step and pp_group_size
  2086. def forward_stage_index(step):
  2087. # Get the local index from 0 to n_local_stages-1
  2088. local_index = (step // self.microbatches_per_round) % self.n_local_stages
  2089. return (local_index * self.pp_group_size) + rank
  2090. def backward_stage_index(step):
  2091. local_index = (
  2092. self.n_local_stages
  2093. - 1
  2094. - ((step - warmup_ops) // self.microbatches_per_round)
  2095. % self.n_local_stages
  2096. )
  2097. return (local_index * self.pp_group_size) + rank
  2098. return _get_1f1b_rank_ops(
  2099. self.n_local_stages,
  2100. self.pp_group_size,
  2101. warmup_ops,
  2102. fwd_bwd_ops,
  2103. cooldown_ops,
  2104. rank,
  2105. forward_stage_index,
  2106. backward_stage_index,
  2107. )
  2108. class ScheduleInterleavedZeroBubble(PipelineScheduleMulti):
  2109. """
  2110. The Interleaved Zero Bubble schedule.
  2111. See https://arxiv.org/pdf/2401.10241 for details.
  2112. Will perform one forward and one backward on inputs for the microbatches in steady
  2113. state and supports multiple stages per rank. Uses the backward for weights to fill in
  2114. the pipeline bubble.
  2115. In particular this is implementing the ZB1P schedule in the paper.
  2116. """
  2117. def __init__(
  2118. self,
  2119. stages: list[_PipelineStageBase],
  2120. n_microbatches: int,
  2121. loss_fn: Optional[Callable] = None,
  2122. args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
  2123. kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
  2124. output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
  2125. scale_grads: bool = True,
  2126. ):
  2127. # TODO: we don't support Zero Bubble with torch.compile so we
  2128. # should disable it for now
  2129. for stage in stages:
  2130. if isinstance(stage.submod, OptimizedModule):
  2131. raise RuntimeError(
  2132. "The Zero Bubble schedule is not supported with \
  2133. stage modules that have used torch.compile"
  2134. )
  2135. self.pp_group_size = stages[0].group_size
  2136. super().__init__(
  2137. stages=stages,
  2138. n_microbatches=n_microbatches,
  2139. loss_fn=loss_fn,
  2140. args_chunk_spec=args_chunk_spec,
  2141. kwargs_chunk_spec=kwargs_chunk_spec,
  2142. output_merge_spec=output_merge_spec,
  2143. scale_grads=scale_grads,
  2144. )
  2145. self.n_local_stages = len(stages)
  2146. self.rank = stages[0].group_rank
  2147. self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
  2148. self.microbatches_per_round = n_microbatches // self.number_of_rounds
  2149. if n_microbatches % self.number_of_rounds != 0:
  2150. raise ValueError(
  2151. "Zero bubble requires the number of microbatches to be a "
  2152. f"multiple of the number of rounds ({self.number_of_rounds}), "
  2153. f"but got {n_microbatches}."
  2154. )
  2155. # 1. Create the pipeline_order (all ranks do this calculation)
  2156. # This will be used to keep track of the current state of the entire pipeline
  2157. # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
  2158. self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
  2159. for rank in range(self.pp_group_size):
  2160. rank_ops = self._calculate_single_rank_operations(rank)
  2161. self.pipeline_order[rank] = rank_ops
  2162. # This function add bubbles to the generated schedule based on dependencies of actions
  2163. # Note that the ZB1P schedule will not require bubbles to be manually added and it is
  2164. # only useful when n_microbatches <= microbatches_per_round
  2165. self.pipeline_order = self._add_bubbles_to_actions(
  2166. self.n_local_stages * self.pp_group_size,
  2167. )
  2168. def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
  2169. def get_rank_warmup_ops(rank):
  2170. # Warms up operations for last stage
  2171. warmups_ops_last_stage = (
  2172. self.n_local_stages - 1
  2173. ) * self.microbatches_per_round
  2174. # Increment warmup operations by 2 for each hop away from the last stage
  2175. multiply_factor = 1
  2176. warmup_ops = warmups_ops_last_stage + multiply_factor * (
  2177. (self.pp_group_size - 1) - rank
  2178. )
  2179. # We cannot have more warmup operations than there are number of microbatches, so cap it there
  2180. return min(warmup_ops, self._n_microbatches * self.n_local_stages)
  2181. warmup_ops = get_rank_warmup_ops(rank)
  2182. microbatch_ops = self.n_local_stages * self._n_microbatches
  2183. # fwd_bwd_ops should encompass the remaining forwards
  2184. fwd_bwd_ops = microbatch_ops - warmup_ops
  2185. # cooldown_ops should encompass the remaining backwards
  2186. cooldown_ops = microbatch_ops - fwd_bwd_ops
  2187. # total ops encompass both forward and backward ops
  2188. total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
  2189. # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
  2190. logger.debug(
  2191. "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
  2192. rank,
  2193. warmup_ops,
  2194. fwd_bwd_ops,
  2195. cooldown_ops,
  2196. total_ops,
  2197. )
  2198. # Calculates the stage index based on step and pp_group_size
  2199. def forward_stage_index(step):
  2200. # Get the local index from 0 to n_local_stages-1
  2201. local_index = (step // self.microbatches_per_round) % self.n_local_stages
  2202. return (local_index * self.pp_group_size) + rank
  2203. def backward_stage_index(step):
  2204. local_index = (
  2205. self.n_local_stages
  2206. - 1
  2207. - ((step - warmup_ops) // self.microbatches_per_round)
  2208. % self.n_local_stages
  2209. )
  2210. return (local_index * self.pp_group_size) + rank
  2211. num_1f1b_microbatches = rank
  2212. return _get_1f1b_rank_ops(
  2213. self.n_local_stages,
  2214. self.pp_group_size,
  2215. warmup_ops,
  2216. fwd_bwd_ops,
  2217. cooldown_ops,
  2218. rank,
  2219. forward_stage_index,
  2220. backward_stage_index,
  2221. num_1f1b_microbatches,
  2222. enable_zero_bubble=True,
  2223. )
  2224. def _add_bubbles_to_actions(self, num_stages_global):
  2225. actions = self.pipeline_order
  2226. def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
  2227. if op == _ComputationType.FORWARD:
  2228. if stage != 0 and (stage - 1, op, microbatch) not in seen_ops:
  2229. return True
  2230. elif op == _ComputationType.FULL_BACKWARD:
  2231. if stage == num_stages_global - 1:
  2232. return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops
  2233. return (stage + 1, op, microbatch) not in seen_ops
  2234. return False
  2235. seen_ops: set[tuple[int, _ComputationType, int]] = set()
  2236. result: dict[int, list[Optional[_Action]]] = {}
  2237. next_pointer: dict[int, int] = {}
  2238. bubbles_added: dict[int, int] = {}
  2239. total_bubbles_added = 0
  2240. for rank in range(self.pp_group_size):
  2241. result[rank] = []
  2242. next_pointer[rank] = 0
  2243. bubbles_added[rank] = 0
  2244. while True:
  2245. should_stop = True
  2246. temp_seen_ops: set[tuple[int, _ComputationType, int]] = set()
  2247. for rank in range(self.pp_group_size):
  2248. timestamp = next_pointer[rank]
  2249. if timestamp >= len(actions[rank]):
  2250. continue
  2251. should_stop = False
  2252. if actions[rank][timestamp] is not None:
  2253. temp_action = actions[rank][timestamp]
  2254. assert temp_action is not None
  2255. stage_index, op, microbatch, _ = temp_action
  2256. if not need_bubble(
  2257. stage_index, op, microbatch, num_stages_global, seen_ops
  2258. ):
  2259. result[rank].append(actions[rank][timestamp])
  2260. if microbatch is not None:
  2261. temp_seen_ops.add((stage_index, op, microbatch))
  2262. next_pointer[rank] += 1
  2263. else:
  2264. result[rank].append(None)
  2265. bubbles_added[rank] += 1
  2266. else:
  2267. next_pointer[rank] += 1
  2268. result[rank].append(None)
  2269. seen_ops.update(temp_seen_ops)
  2270. if should_stop:
  2271. break
  2272. if total_bubbles_added > 0:
  2273. logger.warning(
  2274. "Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s",
  2275. total_bubbles_added,
  2276. bubbles_added,
  2277. )
  2278. return result
class ScheduleZBVZeroBubble(PipelineScheduleMulti):
    """
    The Zero Bubble schedule (ZBV variant).
    See https://arxiv.org/pdf/2401.10241 Section 6 for details.

    This schedule requires exactly two stages per rank. They are placed in a
    "V" layout: rank ``r`` hosts stage ``r`` (chunk 0) and stage
    ``num_stages - 1 - r`` (chunk 1).

    In steady state, each rank performs one forward and one backward with
    respect to inputs per microbatch. Backward with respect to weights is used
    to fill in the pipeline bubble.

    This ZB-V schedule has the "zero bubble" property only if
    time forward == time backward input == time backward weights.
    In practice this is unlikely to hold for real models, so a greedy scheduler
    could be implemented as an alternative for unequal/unbalanced timings.

    Raises:
        ValueError: if the number of local stages is not 2.
    """

    def __init__(
        self,
        stages: list[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
    ):
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
        )
        # Use the "v" stage->rank mapping and push it into every local stage,
        # replacing whatever mapping the stages had before.
        self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
            self.pp_group_size, self._num_stages, style="v"
        )
        for stage in self._stages:
            stage.stage_index_to_group_rank = self.stage_index_to_group_rank

        self.n_local_stages = len(stages)
        if self.n_local_stages != 2:
            raise ValueError(
                "ZBV requires exactly 2 stages per rank, but got "
                f"{self.n_local_stages}."
            )

        self.rank = stages[0].group_rank
        self.num_stages = stages[0].num_stages

        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [_Action(stage_index, computation_type, microbatch_index), ...]
        self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
        """Return the ordered actions (``None`` = idle slot) for ``rank``.

        The list is built in three phases: warm-up, stable, and cool-down.
        Chunk 0 is stage ``rank`` and chunk 1 is stage ``num_stages - 1 - rank``.
        Counters ``f*/b*/w*`` track the next microbatch index for F / I / W
        actions on each chunk.
        """
        # max(2 * self.pp_group_size - 1, ...) ensures the number of microbatches
        # is at least as large as the number needed to fully utilize the pipeline.
        # Actions on the extra (padding) microbatches are replaced by None at the end.
        n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches)
        # Leading idle slots: rank r starts r steps late.
        rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]

        # Forward and backward action counts for stage chunk 0 and chunk 1
        f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0
        # warm-up phase
        # (a) forwards on chunk 0 only
        warmup_n1 = 2 * (self.pp_group_size - rank) - 1
        stage_id_chunk0 = rank
        stage_id_chunk1 = self.num_stages - 1 - rank

        for _ in range(warmup_n1):
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt)
            )
            f0_cnt += 1
        # (b) alternate forwards: chunk 1 then chunk 0
        warmup_n2 = rank
        for _ in range(warmup_n2):
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt)
            )
            f1_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt)
            )
            f0_cnt += 1
        # (c) chunk 1: forward, then input-backward and weight-backward together
        warmup_n3 = self.pp_group_size - rank
        for _ in range(warmup_n3):
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt)
            )
            f1_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt)
            )
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt)
            )
            b1_cnt += 1
        # stable phase
        # Each iteration: [F0] I0 W0 F1 I1 W1. F0 is skipped once all n_micro
        # chunk-0 forwards are issued. The loop runs until chunk 1 has caught up.
        while f1_cnt < f0_cnt or f0_cnt < n_micro:
            if f0_cnt < n_micro:
                rank_ops.append(
                    _Action(
                        stage_id_chunk0, computation_type=F, microbatch_index=f0_cnt
                    )
                )
                f0_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt)
            )
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=W, microbatch_index=b0_cnt)
            )
            b0_cnt += 1

            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=F, microbatch_index=f1_cnt)
            )
            f1_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt)
            )
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=W, microbatch_index=b1_cnt)
            )
            b1_cnt += 1
        # cool-down phase
        # Weight counters resume from where the stable phase left off. The
        # remaining input-backwards are issued first and weights are deferred.
        w0_cnt, w1_cnt = b0_cnt, b1_cnt
        cooldown_n1 = rank
        for _ in range(cooldown_n1):
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt)
            )
            b0_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=I, microbatch_index=b1_cnt)
            )
            b1_cnt += 1
        cooldown_n2 = self.pp_group_size - rank
        for _ in range(cooldown_n2):
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=I, microbatch_index=b0_cnt)
            )
            b0_cnt += 1
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt)
            )
            w0_cnt += 1
        # Drain the deferred weight-backwards: chunk 1 first, then chunk 0.
        while w1_cnt < b1_cnt:
            rank_ops.append(
                _Action(stage_id_chunk1, computation_type=W, microbatch_index=w1_cnt)
            )
            w1_cnt += 1
        while w0_cnt < b0_cnt:
            rank_ops.append(
                _Action(stage_id_chunk0, computation_type=W, microbatch_index=w0_cnt)
            )
            w0_cnt += 1

        # Every microbatch must have received exactly one F, I and W per chunk.
        assert w0_cnt == b0_cnt and b0_cnt == f0_cnt
        assert w1_cnt == b1_cnt and b1_cnt == f1_cnt
        # We use max() in the n_micro computation above, so we may need to
        # remove redundant microbatches
        rank_ops = [
            (
                action
                if action is not None
                and action.microbatch_index is not None
                and action.microbatch_index < self._n_microbatches
                else None
            )
            for action in rank_ops
        ]
        return rank_ops
  2444. class ScheduleDualPipeV(_PipelineScheduleRuntime):
  2445. """
  2446. The DualPipeV schedule. A more efficient schedule variant based on the
  2447. DualPipe schedule introduced by DeepSeek in https://arxiv.org/pdf/2412.19437
  2448. Based on the open sourced code from https://github.com/deepseek-ai/DualPipe
  2449. """
  2450. def __init__(
  2451. self,
  2452. stages: list[_PipelineStageBase],
  2453. n_microbatches: int,
  2454. loss_fn: Optional[Callable] = None,
  2455. args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
  2456. kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
  2457. output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
  2458. scale_grads: bool = True,
  2459. ):
  2460. self.pp_group_size = stages[0].group_size
  2461. super().__init__(
  2462. stages=stages,
  2463. n_microbatches=n_microbatches,
  2464. loss_fn=loss_fn,
  2465. args_chunk_spec=args_chunk_spec,
  2466. kwargs_chunk_spec=kwargs_chunk_spec,
  2467. output_merge_spec=output_merge_spec,
  2468. scale_grads=scale_grads,
  2469. )
  2470. self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
  2471. self.pp_group_size, self._num_stages, style="v"
  2472. )
  2473. for stage in self._stages:
  2474. stage.stage_index_to_group_rank = self.stage_index_to_group_rank
  2475. self.n_local_stages = len(stages)
  2476. if self.n_local_stages != 2:
  2477. raise ValueError(
  2478. "ZBV requires exactly 2 stages per rank, but got "
  2479. f"{self.n_local_stages}."
  2480. )
  2481. if n_microbatches < self._num_stages:
  2482. raise ValueError(
  2483. "DualPipeV requires at least as many microbatches as stages, but got "
  2484. f"{n_microbatches} microbatches and {self._num_stages} stages."
  2485. )
  2486. self.rank = stages[0].group_rank
  2487. self.num_stages = stages[0].num_stages
  2488. # 1. Create the pipeline_order (all ranks do this calculation)
  2489. # This will be used to keep track of the current state of the entire pipeline
  2490. # pipeline_order[rank] = [Action(computation_type, microbatch_index, stage_index), ...]
  2491. self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
  2492. for rank in range(self.pp_group_size):
  2493. rank_ops = self._calculate_single_rank_operations(rank)
  2494. self.pipeline_order[rank] = rank_ops
  2495. # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
  2496. self._prepare_schedule_with_comms(self.pipeline_order)
  2497. def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
  2498. actions: list[Optional[_Action]] = []
  2499. counters: dict[
  2500. tuple[int, _ComputationType], int
  2501. ] = {} # (stage_index, computation_type) -> mb_index
  2502. weight_queue = [] # Queue of (stage_index, mb_index) for pending weight actions
  2503. num_ranks = self.pp_group_size
  2504. num_chunks = self._n_microbatches
  2505. rank_to_stages = generate_rank_to_stage_mapping(
  2506. num_ranks, num_ranks * 2, style="v"
  2507. )
  2508. stage0_index, stage1_index = rank_to_stages[rank]
  2509. def increment_backward_counts(stage_index: int):
  2510. """Helper method to increment BACKWARD_INPUT and BACKWARD_WEIGHT counters when FULL_BACKWARD is used."""
  2511. input_key = (stage_index, BACKWARD_INPUT)
  2512. weight_key = (stage_index, BACKWARD_WEIGHT)
  2513. counters[input_key] = counters.get(input_key, 0) + 1
  2514. counters[weight_key] = counters.get(weight_key, 0) + 1
  2515. def add_overlap_f_b(
  2516. actions: list,
  2517. forward_stage: int,
  2518. backward_stage: int,
  2519. ):
  2520. """Helper method to add an overlapped forward+backward action which tracks microbatch index."""
  2521. # Create new overlapped forward+backward action with sub_actions
  2522. forward_key = (forward_stage, FORWARD)
  2523. backward_key = (backward_stage, BACKWARD_INPUT)
  2524. forward_mb = counters.get(forward_key, 0)
  2525. backward_mb = counters.get(backward_key, 0)
  2526. sub_actions = (
  2527. _Action(forward_stage, FORWARD, forward_mb),
  2528. _Action(backward_stage, FULL_BACKWARD, backward_mb),
  2529. )
  2530. actions.append(_Action(-1, OVERLAP_F_B, None, sub_actions))
  2531. # Update counters for sub_actions
  2532. counters[forward_key] = forward_mb + 1
  2533. increment_backward_counts(backward_stage)
  2534. def add_action(
  2535. actions: list,
  2536. stage_index: int,
  2537. computation_type: _ComputationType,
  2538. ):
  2539. # Regular single action, for FULL_BACKWARD we only use the BACKWARD_INPUT counter
  2540. key = (
  2541. (stage_index, computation_type)
  2542. if computation_type != FULL_BACKWARD
  2543. else (stage_index, BACKWARD_INPUT)
  2544. )
  2545. mb_index = counters.get(key, 0)
  2546. actions.append(_Action(stage_index, computation_type, mb_index))
  2547. # If FULL_BACKWARD is used, just increment the separate BACKWARD_INPUT and BACKWARD_WEIGHT counters
  2548. if computation_type == FULL_BACKWARD:
  2549. increment_backward_counts(stage_index)
  2550. else:
  2551. # If BACKWARD_INPUT is updated, add corresponding weight action to queue
  2552. if computation_type == BACKWARD_INPUT:
  2553. # Add weight action to queue for later processing
  2554. weight_queue.append((stage_index, mb_index))
  2555. counters[key] = mb_index + 1
  2556. def add_weight_action_if_pending(actions: list):
  2557. """Helper method to add a weight action from the queue."""
  2558. if not weight_queue:
  2559. return # No pending weight actions, skip
  2560. # Pop the oldest weight action from the queue
  2561. actual_stage_index, weight_mb_index = weight_queue.pop(0)
  2562. actions.append(
  2563. _Action(
  2564. actual_stage_index,
  2565. BACKWARD_WEIGHT,
  2566. weight_mb_index,
  2567. )
  2568. )
  2569. # Update the counter for the actual stage that was processed
  2570. weight_key = (actual_stage_index, BACKWARD_WEIGHT)
  2571. counters[weight_key] = counters.get(weight_key, 0) + 1
  2572. # Step 1: F0
  2573. step_1 = (num_ranks - rank - 1) * 2
  2574. for _ in range(step_1):
  2575. add_action(actions, stage0_index, FORWARD)
  2576. # Step 2: F0F1
  2577. step_2 = rank + 1
  2578. for _ in range(step_2):
  2579. add_action(actions, stage0_index, FORWARD)
  2580. add_action(actions, stage1_index, FORWARD)
  2581. # Step 3: I1W1F1 (Use zero bubble)
  2582. step_3 = num_ranks - rank - 1
  2583. for _ in range(step_3):
  2584. add_action(actions, stage1_index, BACKWARD_INPUT)
  2585. add_weight_action_if_pending(actions)
  2586. add_action(actions, stage1_index, FORWARD)
  2587. # Step 4 (Main step): F0B1-F1B0 (combined, overlapped forward+backward)
  2588. step_4 = num_chunks - num_ranks * 2 + rank + 1
  2589. for i in range(step_4):
  2590. if i == 0 and rank == num_ranks - 1:
  2591. # NOTE: We don't overlap these two chunks to further reduce bubble size.
  2592. add_action(actions, stage0_index, FORWARD)
  2593. add_action(actions, stage1_index, FULL_BACKWARD)
  2594. else:
  2595. add_overlap_f_b(
  2596. actions,
  2597. forward_stage=stage0_index,
  2598. backward_stage=stage1_index,
  2599. )
  2600. add_overlap_f_b(
  2601. actions,
  2602. forward_stage=stage1_index,
  2603. backward_stage=stage0_index,
  2604. )
  2605. # Step 5: B1-F1B0
  2606. step_5 = num_ranks - rank - 1
  2607. for _ in range(step_5):
  2608. add_action(actions, stage1_index, FULL_BACKWARD)
  2609. add_overlap_f_b(
  2610. actions,
  2611. forward_stage=stage1_index,
  2612. backward_stage=stage0_index,
  2613. )
  2614. # Step 6: B1B0 (The second half of the chunks use zero bubble)
  2615. step_6 = rank + 1
  2616. enable_zb = False
  2617. for i in range(step_6):
  2618. if i == step_6 // 2 and rank % 2 == 1:
  2619. enable_zb = True
  2620. comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
  2621. add_action(actions, stage1_index, comp_type)
  2622. if i == step_6 // 2 and rank % 2 == 0:
  2623. enable_zb = True
  2624. comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
  2625. add_action(actions, stage0_index, comp_type)
  2626. # Step 7: W0B0
  2627. step_7 = num_ranks - rank - 1
  2628. for _ in range(step_7):
  2629. add_weight_action_if_pending(actions)
  2630. comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
  2631. add_action(actions, stage0_index, comp_type)
  2632. # Step 8: W0
  2633. step_8 = rank + 1
  2634. for _ in range(step_8):
  2635. add_weight_action_if_pending(actions)
  2636. return actions
  2637. def get_schedule_class(schedule_name: str):
  2638. """
  2639. Maps a schedule name (case insensitive) to its corresponding class object.
  2640. Args:
  2641. schedule_name (str): The name of the schedule.
  2642. """
  2643. schedule_map = {
  2644. "1F1B": Schedule1F1B,
  2645. "Interleaved1F1B": ScheduleInterleaved1F1B,
  2646. "GPipe": ScheduleGPipe,
  2647. "LoopedBFS": ScheduleLoopedBFS,
  2648. "InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
  2649. "PipelineScheduleSingle": PipelineScheduleSingle,
  2650. "PipelineScheduleMulti": PipelineScheduleMulti,
  2651. "ZBVZeroBubble": ScheduleZBVZeroBubble,
  2652. "DualPipeV": ScheduleDualPipeV,
  2653. }
  2654. lowercase_keys = {k.lower(): k for k in schedule_map.keys()}
  2655. lowercase_schedule_name = schedule_name.lower()
  2656. if lowercase_schedule_name not in lowercase_keys:
  2657. raise ValueError(
  2658. f"Unknown schedule name '{schedule_name}'. The valid options are {list(schedule_map.keys())}"
  2659. )
  2660. return schedule_map[lowercase_keys[lowercase_schedule_name]]
  2661. def _simulate_comms_compute(
  2662. pipeline_order, stage_to_rank: Callable[[int], int], num_stages: int
  2663. ):
  2664. """This function dry-run simulates the actions in the schedule from the perspective of all ranks, and flags
  2665. any deadlocks caused by missing or misordered communications. It also simulates any bubbles in time where a rank
  2666. can not execute any action due to waiting for unmet dependencies. The total number of simulator steps can be used
  2667. as a metric for unit tests involving IR optimization passes as reordering and merging of IR can reduce the number
  2668. of simulated steps.
  2669. The simulation is not high-fidelity and does not model overlapping of compute and communication, or cuda streams.
  2670. Future work may be to enhance this and model the compute time, comms overlap, and even memory.
  2671. """
  2672. pipeline_order = {
  2673. rank: [a for a in pipeline_order[rank] if a is not None]
  2674. for rank in sorted(pipeline_order)
  2675. }
  2676. _schedule: dict[int, list[_Action | None]] = {
  2677. rank: [] for rank in sorted(pipeline_order)
  2678. }
  2679. _prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule}
  2680. def add_to_schedule(rank: int, action: Optional[_Action]):
  2681. _schedule[rank].append(action)
  2682. if action is not None:
  2683. _prev_ops_rank[rank].add(action)
  2684. def _ready_to_schedule(action: Optional[_Action]) -> bool:
  2685. if action is None:
  2686. return True
  2687. stage_idx = action.stage_index
  2688. prev_ops = _prev_ops_rank[stage_to_rank(stage_idx)]
  2689. if action.computation_type == F:
  2690. if action.stage_index == 0:
  2691. return True
  2692. elif (
  2693. _Action(action.stage_index, RECV_F, action.microbatch_index) in prev_ops
  2694. ):
  2695. return True
  2696. elif (
  2697. _Action(action.stage_index - 1, F, action.microbatch_index) in prev_ops
  2698. ):
  2699. return True
  2700. return False
  2701. elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
  2702. if action.stage_index == num_stages - 1:
  2703. return True
  2704. if _Action(action.stage_index, RECV_B, action.microbatch_index) in prev_ops:
  2705. return True
  2706. if (
  2707. _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index)
  2708. in prev_ops
  2709. ):
  2710. return True
  2711. if (
  2712. _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index)
  2713. in prev_ops
  2714. ):
  2715. return True
  2716. return False
  2717. elif action.computation_type == BACKWARD_WEIGHT:
  2718. return True
  2719. elif action.computation_type == SEND_F:
  2720. expected_f = _Action(action.stage_index, F, action.microbatch_index)
  2721. return expected_f in prev_ops
  2722. elif action.computation_type == RECV_F:
  2723. peer_stage_idx = stage_idx - 1
  2724. expected_send = _Action(peer_stage_idx, SEND_F, action.microbatch_index)
  2725. return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)]
  2726. elif action.computation_type == SEND_B:
  2727. expected_b = _Action(
  2728. action.stage_index, BACKWARD_INPUT, action.microbatch_index
  2729. )
  2730. expected_bw = _Action(
  2731. action.stage_index, FULL_BACKWARD, action.microbatch_index
  2732. )
  2733. return expected_b in prev_ops or expected_bw in prev_ops
  2734. elif action.computation_type == RECV_B:
  2735. peer_stage_idx = stage_idx + 1
  2736. expected_send = _Action(peer_stage_idx, SEND_B, action.microbatch_index)
  2737. return expected_send in _prev_ops_rank[stage_to_rank(peer_stage_idx)]
  2738. else:
  2739. raise ValueError(f"Unsupported action type {action}")
  2740. while pipeline_order:
  2741. progress = False
  2742. for rank in sorted(pipeline_order):
  2743. if len(pipeline_order[rank]) == 0:
  2744. continue
  2745. action = pipeline_order[rank][0]
  2746. if _ready_to_schedule(action):
  2747. if action is not None:
  2748. add_to_schedule(rank, action)
  2749. pipeline_order[rank].pop(0)
  2750. progress = True
  2751. else:
  2752. add_to_schedule(rank, None)
  2753. for i in sorted(pipeline_order, reverse=True):
  2754. if len(pipeline_order[i]) == 0:
  2755. del pipeline_order[i]
  2756. # hacky, but do a second pass to replace any 'none' at this timestep with a real action, if it got unblocked
  2757. # by one of the later ranks
  2758. for rank in sorted(pipeline_order):
  2759. if len(pipeline_order[rank]) == 0:
  2760. continue
  2761. if _schedule[rank][-1] is not None:
  2762. continue
  2763. action = pipeline_order[rank][0]
  2764. if _ready_to_schedule(action):
  2765. if action is not None:
  2766. _schedule[rank][-1] = action
  2767. _prev_ops_rank[rank].add(action)
  2768. pipeline_order[rank].pop(0)
  2769. for i in sorted(pipeline_order, reverse=True):
  2770. if len(pipeline_order[i]) == 0:
  2771. del pipeline_order[i]
  2772. if not progress:
  2773. print("WIP comms schedule:\n", _format_pipeline_order(_schedule))
  2774. for rank in pipeline_order:
  2775. print(f"{rank=} next action= {pipeline_order[rank][0]}")
  2776. raise ValueError("Schedule is not progressing")
  2777. return _schedule
  2778. def _dump_chrometrace(schedule, filename):
  2779. """
  2780. This function dumps a schedule IR into a chrometrace format so it can be visualized.
  2781. It is currently very basic and only serves as a graphical alternative to dumping the schedule IR as text.
  2782. As future work we may extend this to include more accurate heuristics for durations, or let users input durations,
  2783. add 'flow events' to let the UI show the connection between sends and recvs, and model cuda streams for comm/compute
  2784. as separate streams on the chrometrace view.
  2785. """
  2786. events = []
  2787. for rank in sorted(schedule):
  2788. for timestep, action in enumerate(schedule[rank]):
  2789. if action is None:
  2790. continue
  2791. events.append(
  2792. {
  2793. "name": str(action),
  2794. "cat": (
  2795. "computation"
  2796. if action.computation_type in (F, B, W)
  2797. else "communication"
  2798. ),
  2799. "ph": "X",
  2800. "pid": rank,
  2801. "tid": rank,
  2802. "ts": timestep,
  2803. "dur": 1,
  2804. }
  2805. )
  2806. import json
  2807. with open(filename, "w") as f:
  2808. json.dump({"traceEvents": events}, f)