| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151215221532154215521562157215821592160216121622163216421652166216721682169217021712172217321742175217621772178217921802181218221832184218521862187218821892190219121922193219421952196219721982199220022012202220322042205220622072208220922102211221222132214221522162217221822192220222122222223222422252226222722282229223022312232223322342235223622372238223922402241224222432244224522462247224822492250225122522253225422552256225722582259226022612262226322642265226622672268226922702271227222732274227522762
27722782279228022812282228322842285228622872288228922902291229222932294229522962297229822992300230123022303230423052306230723082309231023112312231323142315231623172318231923202321232223232324232523262327232823292330233123322333233423352336233723382339234023412342234323442345234623472348234923502351235223532354235523562357235823592360236123622363236423652366236723682369237023712372237323742375237623772378237923802381238223832384238523862387238823892390239123922393239423952396239723982399240024012402240324042405240624072408240924102411241224132414241524162417241824192420242124222423242424252426242724282429243024312432243324342435243624372438243924402441244224432444244524462447244824492450245124522453245424552456245724582459246024612462246324642465246624672468246924702471247224732474247524762477247824792480248124822483248424852486248724882489249024912492249324942495249624972498249925002501250225032504250525062507250825092510251125122513251425152516251725182519252025212522252325242525252625272528252925302531253225332534253525362537253825392540254125422543254425452546254725482549255025512552255325542555255625572558255925602561256225632564256525662567256825692570257125722573257425752576257725782579258025812582258325842585258625872588258925902591259225932594259525962597259825992600260126022603260426052606260726082609261026112612261326142615261626172618261926202621262226232624262526262627262826292630263126322633263426352636263726382639264026412642264326442645264626472648264926502651265226532654265526562657265826592660266126622663266426652666266726682669267026712672267326742675267626772678267926802681268226832684268526862687268826892690269126922693269426952696269726982699270027012702270327042705270627072708270927102711271227132714271527162717271827192720272127222723272427252726272727282729273027312732273327342735273627372738273927402741274227432744274527462747274827492750275127522753275427552756275727582759276027612762276327642765276627672768276927702771277227732774277527762
777277827792780278127822783278427852786278727882789279027912792279327942795279627972798279928002801280228032804280528062807280828092810281128122813281428152816281728182819282028212822282328242825282628272828282928302831283228332834283528362837283828392840284128422843284428452846284728482849285028512852285328542855285628572858285928602861286228632864286528662867286828692870287128722873287428752876287728782879288028812882288328842885288628872888288928902891289228932894289528962897289828992900290129022903290429052906290729082909291029112912291329142915291629172918291929202921292229232924292529262927292829292930293129322933293429352936293729382939294029412942294329442945294629472948294929502951295229532954295529562957295829592960296129622963296429652966296729682969297029712972297329742975297629772978297929802981298229832984298529862987298829892990299129922993299429952996299729982999300030013002300330043005300630073008300930103011301230133014301530163017301830193020302130223023302430253026302730283029303030313032303330343035303630373038303930403041304230433044304530463047304830493050305130523053305430553056305730583059306030613062306330643065306630673068306930703071307230733074307530763077307830793080308130823083308430853086308730883089309030913092309330943095309630973098309931003101310231033104310531063107310831093110311131123113311431153116311731183119312031213122312331243125312631273128312931303131313231333134313531363137313831393140314131423143314431453146314731483149315031513152315331543155315631573158315931603161316231633164316531663167316831693170317131723173317431753176317731783179318031813182318331843185318631873188318931903191319231933194319531963197319831993200320132023203320432053206320732083209 |
- # mypy: allow-untyped-defs
- # Copyright (c) Meta Platforms, Inc. and affiliates
- import copy
- import csv
- import itertools
- import logging
- import re
- from abc import ABC, abstractmethod
- from collections import Counter, defaultdict
- from enum import Enum
- from functools import lru_cache
- from typing import Any, Callable, NamedTuple, Optional, Union
- import torch
- import torch.distributed as dist
- from torch._dynamo import OptimizedModule
- from torch.distributed.fsdp import FSDPModule, UnshardHandle
- from torch.nn.modules.loss import _Loss
- from torch.profiler import record_function
- from ._utils import generate_rank_to_stage_mapping, generate_stage_to_rank_mapping
- from .microbatch import merge_chunks, split_args_kwargs_into_chunks, TensorChunkSpec
- from .stage import _PipelineStageBase
# Public API of this module: the schedule lookup helper and base classes,
# followed by the concrete schedules users instantiate directly.
__all__ = [
    "get_schedule_class",
    "PipelineScheduleSingle",
    "PipelineScheduleMulti",
] + [
    "Schedule1F1B",
    "ScheduleGPipe",
    "ScheduleInterleaved1F1B",
    "ScheduleLoopedBFS",
    "ScheduleInterleavedZeroBubble",
    "ScheduleZBVZeroBubble",
    "ScheduleDualPipeV",
]

# Module-level logger shared by every schedule defined in this file.
logger = logging.getLogger(name=__name__)
class _ComputationType(Enum):
    """Kinds of action that can appear in a pipeline schedule.

    Covers compute actions (forward, input/weight/full backward, and the fused
    OVERLAP_F_B), UNSHARD/RESHARD, and point-to-point SEND/RECV actions.
    ``str()`` gives the code used in schedule strings; ``from_str`` inverts it.
    """

    # TODO(whc) rename to _ActType?
    FORWARD = 1
    BACKWARD_INPUT = 2
    BACKWARD_WEIGHT = 3
    UNSHARD = 4
    RESHARD = 5
    SEND_F = 6
    RECV_F = 7
    SEND_B = 8
    RECV_B = 9
    FULL_BACKWARD = 10
    OVERLAP_F_B = 11

    def __str__(self):
        # The four basic compute actions print as single letters (the
        # 'simple schedule format'); every other member prints as its name.
        short_codes = {
            "FORWARD": "F",
            "BACKWARD_INPUT": "I",
            "BACKWARD_WEIGHT": "W",
            "FULL_BACKWARD": "B",
        }
        return short_codes.get(self.name, self.name)

    @staticmethod
    def from_str(action):
        """Return the member whose ``str()`` equals ``action``.

        Raises:
            RuntimeError: if no member prints as ``action``.
        """
        for member in _ComputationType:
            if str(member) == action:
                return member
        raise RuntimeError(f"Invalid computation type {action}")
# Module-level aliases so schedule code can write e.g. FORWARD instead of
# _ComputationType.FORWARD.

# Compute actions.
FORWARD, BACKWARD_INPUT, BACKWARD_WEIGHT, FULL_BACKWARD, OVERLAP_F_B = (
    _ComputationType.FORWARD,
    _ComputationType.BACKWARD_INPUT,
    _ComputationType.BACKWARD_WEIGHT,
    _ComputationType.FULL_BACKWARD,
    _ComputationType.OVERLAP_F_B,
)
# Unshard / reshard actions.
UNSHARD, RESHARD = _ComputationType.UNSHARD, _ComputationType.RESHARD
# Point-to-point communication actions.
SEND_F, RECV_F, SEND_B, RECV_B = (
    _ComputationType.SEND_F,
    _ComputationType.RECV_F,
    _ComputationType.SEND_B,
    _ComputationType.RECV_B,
)

# Single-letter shorthands for the compute actions used in the
# 'simple schedule format'.
F, I, W, B = FORWARD, BACKWARD_INPUT, BACKWARD_WEIGHT, FULL_BACKWARD

# Parses a single action string such as "1F0" into the groups
# (stage_index, computation_type, microbatch_index). The microbatch group may
# be empty (e.g. "1UNSHARD"). OVERLAP_F_B is not part of this alternation; the
# parenthesized "(...)OVERLAP_F_B" container form is parsed separately by
# _Action.from_str.
_action_regex = re.compile(
    r"(\d+)(F|I|B|W|UNSHARD|RESHARD|SEND_F|RECV_F|SEND_B|RECV_B)(\d*)"
)
class _Action(NamedTuple):
    """One entry of a pipeline schedule.

    Prints (and parses) as ``[stage][action type][(microbatch)]``, e.g.
    ``2F0``, ``1UNSHARD`` or ``3SEND_F1``. An action carrying ``sub_actions``
    is a container printed as ``(sub1;sub2;...)TYPE``; for a container,
    ``stage_index`` and ``microbatch_index`` carry no meaning.
    """

    stage_index: int
    computation_type: _ComputationType
    microbatch_index: Optional[int] = None
    # Only set for container actions (see the class docstring).
    sub_actions: Optional[tuple["_Action", ...]] = None

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        if self.sub_actions is not None:
            # Containers print their children recursively, ';'-separated.
            sub_action_reprs = [repr(sub_action) for sub_action in self.sub_actions]
            return f"({';'.join(sub_action_reprs)}){self.computation_type}"
        repr_str = str(self.stage_index) + str(self.computation_type)
        if self.microbatch_index is not None:
            repr_str += str(self.microbatch_index)
        return repr_str

    @property
    def is_compute_op(self) -> bool:
        """True for F, B, I, W and OVERLAP_F_B actions; False for
        unshard/reshard and send/recv actions."""
        return self.computation_type in (
            FORWARD,
            FULL_BACKWARD,
            BACKWARD_INPUT,
            BACKWARD_WEIGHT,
            OVERLAP_F_B,
        )

    @staticmethod
    def from_str(action_string: str):
        """
        Reverse of __repr__

        String should be formatted as [stage][action type][(microbatch)]
            e.g. `2F0`, `1UNSHARD`, `3SEND_F1`
        Container actions use `(sub1;sub2;...)TYPE`,
            e.g. `(0F0;1B1)OVERLAP_F_B`.

        Returns None for an empty (or whitespace-only) string.

        Raises:
            RuntimeError: if the string is not a valid action. The entire
                string must match, so trailing characters are rejected.
        """
        action_string = action_string.strip()
        if action_string == "":
            return None

        # Container format: (sub_action1;sub_action2;...)ComputationType
        if action_string.startswith("(") and ")" in action_string:
            bracket_end = action_string.find(")")
            # Text between '(' and the first ')' holds the sub-actions ...
            sub_part = action_string[1:bracket_end]
            # ... and the text after ')' names the container's type.
            computation_type_part = action_string[bracket_end + 1 :].strip()

            sub_actions = []
            for sub_str in sub_part.split(";"):
                # from_str strips whitespace and returns None for empty parts.
                sub_action = _Action.from_str(sub_str)
                if sub_action is not None:
                    sub_actions.append(sub_action)

            return _Action(
                stage_index=-1,  # Placeholder, not meaningful for sub_actions container
                computation_type=_ComputationType.from_str(computation_type_part),
                microbatch_index=None,
                sub_actions=tuple(sub_actions) if sub_actions else None,
            )

        # Regular single-action format. fullmatch (not match) so that e.g.
        # "1FOO" is rejected instead of silently parsing as "1F".
        if match := _action_regex.fullmatch(action_string):
            stage_index, computation_type, microbatch_index = match.groups()
            return _Action(
                int(stage_index),
                _ComputationType.from_str(computation_type),
                int(microbatch_index) if microbatch_index else None,
            )

        raise RuntimeError(
            f"Invalid action string: {action_string}, should be formatted as [stage][action type][(microbatch)] e.g. 2F0"
        )
@lru_cache
def _get_profiler_function_name(action: _Action) -> str:
    """Return the profiler label for ``action``: ``"PP:"`` followed by its
    string form. Memoized with functools' default LRU bound (maxsize=128)."""
    return "PP:" + str(action)
def _format_pipeline_order(
    pipeline_order: dict[int, list[Optional["_Action"]]],
    error_step_number: Optional[int] = None,
) -> str:
    """
    Formats the pipeline order in a timestep (row) x rank (column) grid of actions
    and returns the formatted string.

    Columns are ordered and labelled by the rank keys of ``pipeline_order``,
    so non-contiguous rank keys are labelled correctly. Ranks with fewer
    actions than the longest one are padded with empty cells, and ``None``
    entries (idle steps) render as empty cells. Each column is as wide as its
    widest cell or its "Rank N" header.

    If `error_step_number` is passed in, an additional label will be added to signify which step
    that it is erroring on.

    The input dict and its lists are not modified.
    """
    ranks = sorted(pipeline_order)
    # One column of rendered cells per rank; None (an idle step) renders empty.
    columns = [
        ["" if action is None else str(action) for action in pipeline_order[rank]]
        for rank in ranks
    ]
    num_steps = max((len(column) for column in columns), default=0)
    for column in columns:
        column.extend([""] * (num_steps - len(column)))

    # Zero-pad step numbers so every step label has the same width.
    step_digits = len(str(num_steps - 1)) if num_steps else 1
    step_labels = ["Step " + str(i).zfill(step_digits) for i in range(num_steps)]
    rank_labels = [f"Rank {rank}" for rank in ranks]

    widths = [
        max(len(label), max((len(cell) for cell in column), default=0))
        for label, column in zip(rank_labels, columns)
    ]

    # Indent the header past the "Step N: " prefix of the body rows.
    label_width = len(step_labels[0]) if step_labels else 0
    header_row = " " * (label_width + 2) + " ".join(
        f"{label:<{width}}" for label, width in zip(rank_labels, widths)
    )

    formatted_rows = []
    for step, label in enumerate(step_labels):
        cells = " ".join(
            f"{column[step]:<{width}}" for column, width in zip(columns, widths)
        )
        marker = " <-- ERROR HERE" if step == error_step_number else ""
        formatted_rows.append(f"{label}: {cells}{marker}")

    return header_row + "\n" + "\n".join(formatted_rows) + "\n"
class _PipelineSchedule(ABC):
    """
    Abstract base for all pipeline schedules.

    Stores the settings shared by every schedule (microbatch count, loss
    function, input-chunking and output-merging specs, gradient scaling) and
    the losses recorded during a step. Provides the helpers subclasses use
    to split inputs, validate microbatches, compute losses and merge outputs.
    Subclasses implement `step` and `_step_microbatches`.
    """

    def __init__(
        self,
        n_microbatches: int,
        loss_fn: Optional[Callable[..., torch.Tensor]] = None,
        args_chunk_spec: Optional[tuple["TensorChunkSpec", ...]] = None,
        kwargs_chunk_spec: Optional[dict[str, "TensorChunkSpec"]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
    ):
        # From arguments
        self._n_microbatches = n_microbatches
        self._loss_fn = loss_fn

        # See documentation in `PipelineScheduleSingle` / `PipelineScheduleMulti`
        self.scale_grads = scale_grads

        # args_chunk_spec and kwargs_chunk_spec specify how to chunk inputs.
        # They are used to convert batch to microbatches in `step(x)`. See
        # `TensorChunkSpec` for helper methods for creating them.
        # Chunking specification for positional inputs. (default: `None`)
        self._args_chunk_spec = args_chunk_spec
        # Chunking specification for keyword inputs. (default: `None`)
        self._kwargs_chunk_spec = kwargs_chunk_spec
        # Passed to `merge_chunks` by `_merge_outputs`. (default: `None`)
        self._output_merge_spec = output_merge_spec

        # Derived
        # Backward runs only when a loss_fn was given; `eval()` temporarily
        # overrides this to False.
        self._has_backward = self._loss_fn is not None

        # Holds the losses for each microbatch.
        self._internal_losses: list[torch.Tensor] = []
        logger.info("Using %s", self.__class__.__name__)

    def _maybe_compute_loss(self, stage, output, target_mbs, mb_index):
        """
        If ``stage`` is the last stage and a loss function is set, compute the
        loss of ``output`` against ``target_mbs[mb_index]`` and append it to
        the internal loss list.
        """
        if stage.is_last and self._loss_fn is not None:
            loss = self._compute_loss(output, target_mbs[mb_index])  # type: ignore[index]
            self._internal_losses.append(loss)

    def _maybe_get_loss(self, stage, mb_index):
        """
        Return the recorded loss at position ``mb_index`` when ``stage`` is
        the last stage and a loss function is set; otherwise return None.

        NOTE(review): losses are indexed by append order, which presumably
        matches microbatch order. Confirm for schedules that reorder work.

        Raises:
            RuntimeError: if some losses are recorded but ``mb_index`` is out
                of range for them.
        """
        valid_index = 0 <= mb_index < len(self._internal_losses)
        if stage.is_last and self._loss_fn is not None and valid_index:
            return self._internal_losses[mb_index]
        elif len(self._internal_losses) != 0 and not valid_index:
            raise RuntimeError(
                f"Loss for microbatch {mb_index} is not available. "
                f"Available losses for microbatches: {self._internal_losses}"
            )
        else:
            return None

    def _update_losses(self, stages, losses):
        """
        Update the losses to those in the internal state

        If ``stages`` (one stage or a list of stages) contains the last stage
        and a ``losses`` list is given, that list's contents are replaced by
        the recorded losses. The internal loss list is then cleared.

        Raises:
            RuntimeError: if the number of recorded losses differs from the
                number of microbatches.
        """
        # if stages not a list turn into a list
        if not isinstance(stages, list):
            stages = [stages]
        contains_last_stage = any(stage.is_last for stage in stages)

        # Return losses if there is a container passed in
        if contains_last_stage and losses is not None:
            if len(self._internal_losses) != self._n_microbatches:
                raise RuntimeError(
                    f"Expecting {self._n_microbatches} losses but got {len(self._internal_losses)}"
                )

            # Clean external container first
            losses.clear()
            # Copy internal losses to external container
            losses.extend(self._internal_losses)

        self._internal_losses.clear()

    @abstractmethod
    def _step_microbatches(
        self,
        arg_mbs: Optional[list] = None,
        kwarg_mbs: Optional[list] = None,
        target_mbs: Optional[list] = None,
        losses: Optional[list] = None,
    ):
        """
        Run one iteration of the pipeline schedule with list of microbatches.
        Will go through all the microbatches according to the schedule
        implementation.

        Args:
            arg_mbs: per-microbatch positional arguments.
            kwarg_mbs: per-microbatch keyword arguments.
            target_mbs: per-microbatch targets for the loss function.
            losses: optional list that receives the per-microbatch losses.
        """
        raise NotImplementedError

    @abstractmethod
    def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.
        """
        raise NotImplementedError

    def eval(self, *args, target=None, losses: Optional[list] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches, calling forward only.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target values for the loss function.
        losses: a list to store the losses for each microbatch.
        """
        # Save the original has_backward state
        original_has_backward = self._has_backward
        try:
            self._has_backward = False
            return self.step(*args, target=target, losses=losses, **kwargs)
        finally:
            # Restore the original state
            self._has_backward = original_has_backward

    def _check_inputs(
        self,
        arg_mbs: Optional[list] = None,
        kwarg_mbs: Optional[list] = None,
        target_mbs: Optional[list] = None,
        losses: Optional[list] = None,
    ):
        """
        Pre-process/check inputs

        Validates that each given microbatch list is a list with
        ``n_microbatches`` entries, and substitutes empty args/kwargs for
        missing ones.

        Returns:
            ``(arg_mbs, kwarg_mbs)``, each a list of ``n_microbatches`` entries.

        Raises:
            TypeError: if a microbatch input or ``losses`` is not a list.
            ValueError: if a microbatch list has the wrong length.
        """

        def check_type_and_len(mbs, name: str):
            if not isinstance(mbs, list):
                raise TypeError(f"{name} must be a list but got a {type(mbs)}")
            if len(mbs) != self._n_microbatches:
                raise ValueError(
                    f"Expecting {self._n_microbatches} {name} but got {len(mbs)}"
                )

        if arg_mbs is not None:
            check_type_and_len(arg_mbs, "arg_mbs")
        else:
            # Tuples are immutable, so sharing one `()` is safe.
            arg_mbs = [()] * self._n_microbatches

        if kwarg_mbs is not None:
            check_type_and_len(kwarg_mbs, "kwarg_mbs")
        else:
            # One distinct dict per microbatch: `[{}] * n` would alias a single
            # dict, so mutating one microbatch's kwargs would change them all.
            kwarg_mbs = [{} for _ in range(self._n_microbatches)]

        if target_mbs is not None:
            check_type_and_len(target_mbs, "target_mbs")

        if losses is not None:
            if not isinstance(losses, list):
                raise TypeError(f"losses must be a list but got a {type(losses)}")

        return arg_mbs, kwarg_mbs

    def _compute_loss(self, output, target):
        return self._loss_fn(output, target)  # type: ignore[misc]

    def _split_inputs(
        self,
        args: tuple[Any, ...],
        kwargs: Optional[dict[str, Any]] = None,
    ):
        """
        Splits a full-batch input into chunks (i.e. microbatches) and returns
        the chunks
        """
        if args or kwargs:
            args_split, kwargs_split = split_args_kwargs_into_chunks(
                args,
                kwargs,
                self._n_microbatches,
                self._args_chunk_spec,
                self._kwargs_chunk_spec,
            )
            return args_split, kwargs_split
        else:
            # Empty inputs (e.g. when called on middle stages)
            # Return a list of empty tuples/dicts with matching length as chunks.
            # Each microbatch gets its own dict so the chunks are independent.
            return (
                [()] * self._n_microbatches,
                [{} for _ in range(self._n_microbatches)],
            )

    def _merge_outputs(self, output_chunks: list[Any]) -> Any:
        """
        Merge output chunks back to a batch state.
        If output_merge_spec is None, the utility will merge output chunks by dimension 0 (batch dim).
        """
        return merge_chunks(
            output_chunks,
            self._output_merge_spec,
        )
def _batch_p2p(
    p2p_ops: list[dist.P2POp], desc: Optional[str] = None
) -> list[dist.Work]:
    """
    Issue ``p2p_ops`` through ``dist.batch_isend_irecv`` and return its work
    handles. The ops are logged at debug level first, tagged with ``desc``
    when one is given. An empty op list is a no-op that returns ``[]``.
    """
    if len(p2p_ops) == 0:
        return []
    prefix = "" if not desc else f"{desc}, "
    logger.debug("batch_p2p %s%s", prefix, p2p_ops)
    return dist.batch_isend_irecv(p2p_ops)
def _sorted_batch_p2p(
    p2p_ops: list[dist.P2POp], desc: Optional[str] = None
) -> dict[int, list[dist.Work]]:
    """
    Group ``p2p_ops`` by peer rank and issue one ``batch_isend_irecv`` call
    per peer, in ascending peer order. Ordering by peer helps avoid hangs in
    case of skip connections.

    Returns:
        A dict mapping each peer rank to the work handles of its batch; an
        empty op list yields an empty dict.
    """
    if len(p2p_ops) == 0:
        return {}

    grouped: dict[int, list[dist.P2POp]] = defaultdict(list)
    for p2p_op in p2p_ops:
        grouped[p2p_op.peer].append(p2p_op)

    # Issue each peer's batch in sorted peer order.
    return {peer: _batch_p2p(grouped[peer], desc=desc) for peer in sorted(grouped)}
def _wait_batch_p2p(work: list[dist.Work]):
    """
    Block on every handle in ``work`` (as returned by ``_batch_p2p`` or one
    entry of ``_sorted_batch_p2p``), one after another in list order.
    """
    for handle in work:
        handle.wait()
class PipelineScheduleSingle(_PipelineSchedule):
    """
    Base class for single-stage schedules.
    Implements the `step` method.
    Derived classes should implement `_step_microbatches`.

    Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True.  This setting
    should match the configuration of your loss_fn, which may either average losses (scale_grads=True)
    or sum losses (scale_grads=False).
    """

    def __init__(
        self,
        stage: _PipelineStageBase,
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
    ):
        """
        Args:
            stage: the single pipeline stage driven by this schedule.
            n_microbatches: microbatches per `step`; must be at least
                ``stage.num_stages``.
            loss_fn: loss applied on the last stage; backward runs only when
                one is given.
            args_chunk_spec / kwargs_chunk_spec: how to split positional /
                keyword inputs into microbatches.
            output_merge_spec: how to merge output chunks back into a batch.
            scale_grads: see the class docstring.

        Raises:
            ValueError: if ``n_microbatches`` is smaller than the number of
                stages.
        """
        # Init parent
        super().__init__(
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
        )
        # Self attributes
        self._stage = stage
        self._num_stages = stage.num_stages
        # Flipped to True by `_initialize_stage`.
        self._stage_initialized = False

        if n_microbatches < self._num_stages:
            # Implicit concatenation: a backslash continuation inside the
            # f-string would embed the next line's indentation in the message.
            raise ValueError(
                f"Number of microbatches ({n_microbatches}) must be greater than "
                f"or equal to the number of stages ({self._num_stages})."
            )

        self.pipeline_order: Optional[dict[int, list[Optional[_Action]]]] = (
            self._get_pipeline_order()
        )

    def _initialize_stage(self, args, kwargs):
        """
        One-time stage setup. First exchange the stage's initial neighbor P2P
        ops, then prepare its forward infrastructure from ``args``/``kwargs``,
        and its backward infrastructure if backward is enabled.
        """
        # Prepare the communication needed for the pipeline schedule execution
        # This is needed because during execution we always perform a series of batch P2P ops
        # The first call of the batched P2P needs to involve the global group
        all_ops: list[dist.P2POp] = []
        all_ops.extend(self._stage._get_init_p2p_neighbors_ops())
        _wait_batch_p2p(_batch_p2p(all_ops))

        self._stage._prepare_forward_infra(self._n_microbatches, args, kwargs)
        if self._has_backward:
            self._stage._prepare_backward_infra(self._n_microbatches)
        self._stage_initialized = True

    def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
        """
        Run one iteration of the pipeline schedule with *whole-batch* input.
        Will chunk the input into microbatches automatically, and go through the
        microbatches according to the schedule implementation.

        args: positional arguments to the model (as in non-pipeline case).
        kwargs: keyword arguments to the model (as in non-pipeline case).
        target: target for the loss function.
        losses: a list to store the losses for each microbatch.

        Returns the merged output on the last stage, and None on other stages.
        """
        if self._has_backward and not torch.is_grad_enabled():
            raise RuntimeError(
                "step() requires gradients to be enabled for backward computation; "
                "it should not be used under torch.no_grad() context. "
                "Please call eval() instead."
            )

        # Set the same has_backward flag for stage object
        self._stage.has_backward = self._has_backward

        # Clean per iteration
        self._stage.clear_runtime_states()

        # Split inputs into microbatches
        args_split, kwargs_split = self._split_inputs(args, kwargs)

        # Split target into microbatches
        if target is not None:
            targets_split = list(torch.tensor_split(target, self._n_microbatches))
        else:
            targets_split = None

        # Run microbatches
        self._step_microbatches(args_split, kwargs_split, targets_split, losses)

        # Return merged results per original format
        if self._stage.is_last:
            return self._merge_outputs(self._stage.output_chunks)
        else:
            return None

    def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
        """
        Returns the pipeline execution order as a schedule IR.

        The returned IR is a dictionary mapping rank IDs to lists of actions.
        Each action is either an _Action object representing computation to perform,
        or None representing a deliberate idle step.

        The None values are used to represent pipeline bubbles where a rank
        must wait for dependencies from other ranks before proceeding. However
        during execution, with the _PipelineScheduleRuntime, these Nones are
        skipped since the relevant communication (send/recv) will be scheduled and waited on.

        This base implementation returns None; subclasses may override it.

        Returns:
            A dictionary mapping rank -> list of actions
        """
        return None
class _ScheduleForwardOnly(PipelineScheduleSingle):
    """
    Forward-only pipeline schedule.

    Every microbatch is pushed through the local stage's forward pass in order;
    no loss is computed and no backward pass runs.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[list] = None,
        kwarg_mbs: Optional[list] = None,
        target_mbs: Optional[list] = None,
        losses: Optional[list] = None,
    ):
        """
        Run a single forward-only iteration over pre-split microbatches.

        Args:
            arg_mbs: per-microbatch positional inputs.
            kwarg_mbs: per-microbatch keyword inputs.
            target_mbs: must be ``None`` (no loss is computed).
            losses: must be ``None`` (no loss is computed).

        Raises:
            RuntimeError: if ``target_mbs`` or ``losses`` is provided.
        """
        if not (target_mbs is None and losses is None):
            raise RuntimeError(
                "Forward-only schedule does not support loss computation"
            )

        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
        if not self._stage_initialized:
            self._initialize_stage(arg_mbs[0], kwarg_mbs[0])

        stage = self._stage
        # Sends are fired eagerly; their completion is only awaited after the loop.
        pending_sends: list[list[dist.Work]] = []

        for mb in range(self._n_microbatches):
            with record_function(f"Forward {mb}"):
                recv_works = _sorted_batch_p2p(stage.get_fwd_recv_ops(mb), desc="fwd_recv")
                for recv_work in recv_works.values():
                    _wait_batch_p2p(recv_work)

                stage.forward_one_chunk(mb, arg_mbs[mb], kwarg_mbs[mb])  # type: ignore[index]

                send_works = _sorted_batch_p2p(stage.get_fwd_send_ops(mb), desc="fwd_send")
                pending_sends.extend(send_works.values())

            logger.debug("[%s] Forwarded microbatch %s", stage.stage_index, mb)

        # Block until every outstanding forward send has completed.
        for send_work in pending_sends:
            _wait_batch_p2p(send_work)
class ScheduleGPipe(PipelineScheduleSingle):
    """
    GPipe (fill-drain) schedule.

    Every microbatch is run forward first (fill); afterwards every microbatch
    is run backward (drain).
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[list] = None,
        kwarg_mbs: Optional[list] = None,
        target_mbs: Optional[list] = None,
        losses: Optional[list] = None,
    ):
        """
        Run one GPipe iteration over pre-split microbatches.

        Args:
            arg_mbs: per-microbatch positional inputs.
            kwarg_mbs: per-microbatch keyword inputs.
            target_mbs: per-microbatch loss targets (consumed where a loss is computed).
            losses: optional list that receives the per-microbatch losses.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
        if not self._stage_initialized:
            self._initialize_stage(arg_mbs[0], kwarg_mbs[0])

        stage = self._stage
        n_mbs = self._n_microbatches

        # ---- Fill: forward every microbatch, deferring waits on the sends ----
        pending_fwd_sends: list[list[dist.Work]] = []
        for mb in range(n_mbs):
            with record_function(f"Forward {mb}"):
                recv_works = _sorted_batch_p2p(stage.get_fwd_recv_ops(mb), desc="fwd_recv")
                for recv_work in recv_works.values():
                    _wait_batch_p2p(recv_work)

                output = stage.forward_one_chunk(mb, arg_mbs[mb], kwarg_mbs[mb])  # type: ignore[index]

                send_works = _sorted_batch_p2p(stage.get_fwd_send_ops(mb), desc="fwd_send")
                pending_fwd_sends.extend(send_works.values())

            logger.debug("[%s] Forwarded microbatch %s", stage.stage_index, mb)
            self._maybe_compute_loss(stage, output, target_mbs, mb)

        # Waiting here should be cheap: the forward sends are expected to have
        # finished by the time the first backward arrives.
        for send_work in pending_fwd_sends:
            _wait_batch_p2p(send_work)

        # ---- Drain: backward every microbatch, deferring waits on the sends ----
        pending_bwd_sends: list[list[dist.Work]] = []
        for mb in range(n_mbs):
            with record_function(f"Backward {mb}"):
                recv_works = _sorted_batch_p2p(stage.get_bwd_recv_ops(mb), desc="bwd_recv")
                for recv_work in recv_works.values():
                    _wait_batch_p2p(recv_work)

                mb_loss = self._maybe_get_loss(stage, mb)
                stage.backward_one_chunk(
                    mb,
                    loss=mb_loss,
                    last_backward=mb == n_mbs - 1,
                )

                send_works = _sorted_batch_p2p(stage.get_bwd_send_ops(mb), desc="bwd_send")
                pending_bwd_sends.extend(send_works.values())

            logger.debug("[%s] Backwarded microbatch %s", stage.stage_index, mb)

        stage.scale_grads(grad_scale_factor=n_mbs if self.scale_grads else 1)

        for send_work in pending_bwd_sends:
            _wait_batch_p2p(send_work)

        # Fill the caller's container with the per-microbatch losses, if given.
        self._update_losses(stage, losses)

    def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
        """
        Return the GPipe execution order as schedule IR (rank -> list of steps).

        Per rank ``r`` (one stage per rank), the order is:
          * ``r`` idle slots,
          * forwards for microbatches ``0 .. n-1``,
          * ``3 * (num_stages - 1 - r)`` idle slots,
          * full backwards for microbatches ``0 .. n-1``.

        ``None`` entries are idle slots; see the base method in
        ``PipelineScheduleSingle`` for the IR format.
        """
        n_ranks = self._num_stages
        n_mbs = self._n_microbatches
        order: dict[int, list[Optional[_Action]]] = {}

        for rank in range(n_ranks):
            lead_in: list[Optional[_Action]] = [None] * rank
            forwards: list[Optional[_Action]] = [
                _Action(rank, _ComputationType.FORWARD, mb) for mb in range(n_mbs)
            ]
            gap: list[Optional[_Action]] = [None] * (3 * (n_ranks - 1 - rank))
            backwards: list[Optional[_Action]] = [
                _Action(rank, _ComputationType.FULL_BACKWARD, mb) for mb in range(n_mbs)
            ]
            order[rank] = lead_in + forwards + gap + backwards

        return order
class Schedule1F1B(PipelineScheduleSingle):
    """
    The 1F1B schedule.

    After a warmup of forwards, each stage alternates one forward with one
    backward in steady state, then drains the remaining backwards.
    """

    def _step_microbatches(
        self,
        arg_mbs: Optional[list] = None,
        kwarg_mbs: Optional[list] = None,
        target_mbs: Optional[list] = None,
        losses: Optional[list] = None,
    ):
        """
        Run one 1F1B iteration over pre-split microbatches.

        Phases:
          1. warmup: ``min(n_microbatches, num_stages - stage_index)`` forwards;
          2. "1B1F": alternate one backward and one forward, fusing each
             pending send with the next receive into a single batched P2P call;
          3. cooldown: run the remaining backwards.

        Args:
            arg_mbs: per-microbatch positional inputs.
            kwarg_mbs: per-microbatch keyword inputs.
            target_mbs: per-microbatch loss targets (consumed where a loss is computed).
            losses: optional list that receives the per-microbatch losses.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
        if not self._stage_initialized:
            self._initialize_stage(arg_mbs[0], kwarg_mbs[0])

        # Last stage has 1 warmup, second-to-last 2 warmups, ...
        # first stage `num_stages` warmups
        warmup_chunks = min(
            self._n_microbatches,
            self._num_stages - self._stage.stage_index,
        )

        # Chunk counters
        fwd_mb_index = 0
        bwd_mb_index = 0

        # Warmup phase
        send_work: list[dist.Work] = []
        fwd_sends = []
        for _ in range(warmup_chunks):
            # Receive activations
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)
            _wait_batch_p2p(_batch_p2p(fwd_recvs, desc="fwd_recv"))

            # Compute
            output = self._stage.forward_one_chunk(
                fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
            )  # type: ignore[index]

            # Clear the previous chunk's forward sends (hopefully they have
            # finished; otherwise we are heavily communication bound, in which
            # case computing the next chunk eagerly would not help much either)
            _wait_batch_p2p(send_work)

            # Send activations
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            if fwd_mb_index != warmup_chunks - 1:
                # Safe to fire
                send_work = _batch_p2p(fwd_sends, desc="fwd_send")
            # otherwise:
            #   the last forward send is left to be fused with the first 1B of
            #   the 1B1F phase below

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)
            fwd_mb_index += 1

        # `fwd_sends` still holds the un-fired sends of the last warmup chunk;
        # they are fused with the first 1B of the 1B1F phase below.

        # 1B1F phase
        while True:  # exits via the `break` below
            # The 1B comes first, as the `1B1F` name indicates; prepare its recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)

            # Fire the pending fwd_sends and these bwd_recvs together
            _wait_batch_p2p(_batch_p2p(fwd_sends + bwd_recvs, desc="fwd_send_bwd_recv"))

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(
                bwd_mb_index,
                loss=loss,
                last_backward=bwd_mb_index == self._n_microbatches - 1,
            )

            # Get the bwd send ops, but don't fire; they are fused with the 1F below
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            bwd_mb_index += 1

            if fwd_mb_index == self._n_microbatches:
                # Done with 1B1F; break with some left-over bwd_sends
                break

            # Prepare the 1F of the `1B1F`
            fwd_recvs = self._stage.get_fwd_recv_ops(fwd_mb_index)

            # Fuse it with the bwd_sends from above
            _wait_batch_p2p(_batch_p2p(bwd_sends + fwd_recvs, desc="bwd_send_fwd_recv"))

            # Now do the fwd
            output = self._stage.forward_one_chunk(
                fwd_mb_index, arg_mbs[fwd_mb_index], kwarg_mbs[fwd_mb_index]
            )  # type: ignore[index]

            # Compute loss
            self._maybe_compute_loss(self._stage, output, target_mbs, fwd_mb_index)

            # Get the fwd send ops, but don't fire; they go out with the next
            # iteration's bwd recvs (wrap-around)
            fwd_sends = self._stage.get_fwd_send_ops(fwd_mb_index)
            fwd_mb_index += 1

        # Fire the bwd_sends left over from the final 1B1F iteration
        send_work = _batch_p2p(bwd_sends, desc="bwd_send")

        # Cooldown
        while bwd_mb_index < self._n_microbatches:
            # Prepare bwd recv ops
            bwd_recvs = self._stage.get_bwd_recv_ops(bwd_mb_index)
            _wait_batch_p2p(_batch_p2p(bwd_recvs, desc="bwd_recv"))

            # Backward one chunk
            loss = self._maybe_get_loss(self._stage, bwd_mb_index)
            self._stage.backward_one_chunk(
                bwd_mb_index,
                loss=loss,
                last_backward=bwd_mb_index == self._n_microbatches - 1,
            )

            # Clear the previous chunk's backward sends (hopefully finished)
            _wait_batch_p2p(send_work)

            # Get the bwd send ops and fire them
            bwd_sends = self._stage.get_bwd_send_ops(bwd_mb_index)
            send_work = _batch_p2p(bwd_sends, desc="bwd_send")
            bwd_mb_index += 1

        self._stage.scale_grads(
            grad_scale_factor=self._n_microbatches if self.scale_grads else 1
        )

        # Wait for the last backward send to finish
        _wait_batch_p2p(send_work)

        # Fill the caller's container with the per-microbatch losses, if given
        self._update_losses(self._stage, losses)

    def _get_pipeline_order(self) -> Optional[dict[int, list[Optional[_Action]]]]:
        """
        Return the 1F1B execution order as schedule IR (rank -> list of steps).

        Per rank ``r`` (one stage per rank) with ``n`` microbatches:
          * ``r`` idle slots;
          * ``k = min(num_stages - 1 - r, n)`` warmup forwards (mb ``0 .. k-1``);
          * ``2 * (num_stages - 1 - r)`` idle slots;
          * steady state: forward of the next microbatch followed by a backward,
            until every microbatch has been run forward;
          * cooldown: each remaining backward, preceded by an idle slot.

        Every microbatch ``0 .. n-1`` appears exactly once as a forward and once
        as a full backward, and each backward follows its forward.
        ``None`` entries are idle slots; see the base method in
        ``PipelineScheduleSingle`` for the IR format.
        """
        pipeline_order: dict[int, list[Optional[_Action]]] = {}
        pp_group_size = self._num_stages
        n_mbs = self._n_microbatches

        for rank in range(pp_group_size):
            actions: list[Optional[_Action]] = []

            # 1. Warmup delay based on rank position
            actions.extend([None] * rank)

            # 2. Warmup forwards before the 1F1B phase. Clamp to the number of
            #    microbatches so no forward is emitted for a non-existent mb.
            num_warmup_fwd = min(pp_group_size - 1 - rank, n_mbs)
            # Index of the next microbatch to run forward. Tracking it directly
            # (rather than "last forwarded mb") keeps the last rank, which has
            # no warmup forwards, starting at microbatch 0.
            next_fwd_mb = 0
            for _ in range(num_warmup_fwd):
                actions.append(_Action(rank, _ComputationType.FORWARD, next_fwd_mb))
                next_fwd_mb += 1

            # 3. Wait for the first backward to become ready
            actions.extend([None] * (2 * (pp_group_size - 1 - rank)))

            # 4. 1F1B steady state: one forward, then one backward
            next_bwd_mb = 0
            while next_fwd_mb < n_mbs:
                actions.append(_Action(rank, _ComputationType.FORWARD, next_fwd_mb))
                next_fwd_mb += 1
                actions.append(
                    _Action(rank, _ComputationType.FULL_BACKWARD, next_bwd_mb)
                )
                next_bwd_mb += 1

            # 5. Cooldown: remaining backwards, each preceded by an idle slot
            while next_bwd_mb < n_mbs:
                actions.append(None)
                actions.append(
                    _Action(rank, _ComputationType.FULL_BACKWARD, next_bwd_mb)
                )
                next_bwd_mb += 1

            pipeline_order[rank] = actions

        return pipeline_order
def _add_unshard_reshard(
    compute_actions: list[Optional[_Action]],
    max_active_stages: int = 3,
) -> list[_Action]:
    """Insert FSDP UNSHARD/RESHARD actions into a compute-only schedule.

    The input holds compute actions (F, B, W, OVERLAP_F_B) plus ``None`` idle
    slots. UNSHARD means fetching the full contents of an FSDP-sharded layer,
    which requires an all-gather. RESHARD does the opposite: it releases that
    memory and does no communication.

    Before each compute action, the next ``max_active_stages`` distinct stages
    that will compute are made "active":
      * active stages outside that window get a RESHARD;
      * window stages not yet active get an UNSHARD, in order of first use.

    ``None`` slots are dropped from the output: the per-timestep lockstep is
    not preserved. ``max_active_stages`` bounds prefetching. Ideally it would
    be measured in microbatches and tunable, but 3 is a practical default: one
    F and one B active, plus one stage prefetching.
    """

    def next_stage_indices(
        count: int, next_actions: list[Optional[_Action]]
    ) -> list[int]:
        """Return the first ``count`` distinct stage indices that will compute.

        OVERLAP_F_B actions are expanded into their sub-actions' stages.
        """
        seen: set[int] = set()
        ordered: list[int] = []
        for act in next_actions:
            if act is None:
                continue
            if act.computation_type == OVERLAP_F_B and act.sub_actions is not None:
                candidates = [sub.stage_index for sub in act.sub_actions]
            else:
                candidates = [act.stage_index]
            for stage_idx in candidates:
                if stage_idx in seen:
                    continue
                seen.add(stage_idx)
                ordered.append(stage_idx)
                if len(ordered) == count:
                    return ordered
        return ordered

    active_stages: set[int] = set()
    fsdp_aware_actions: list[_Action] = []

    for step_idx, action in enumerate(compute_actions):
        if action is None:
            continue

        upcoming = next_stage_indices(max_active_stages, compute_actions[step_idx:])
        # Both lists are computed before any mutation of active_stages.
        # The fetch order follows first use, so it is kept as a list, not a set.
        to_fetch = [s for s in upcoming if s not in active_stages]
        to_evict = [s for s in active_stages if s not in upcoming]

        for stage_idx in to_evict:
            active_stages.remove(stage_idx)
            fsdp_aware_actions.append(_Action(stage_idx, RESHARD, None))
        for stage_idx in to_fetch:
            active_stages.add(stage_idx)
            fsdp_aware_actions.append(_Action(stage_idx, UNSHARD, None))

        fsdp_aware_actions.append(action)

    return fsdp_aware_actions
def _merge_bw(
    compute_actions: list[Optional[_Action]],
) -> list[_Action]:
    """Merge adjacent I and W actions into B actions in a compute-only schedule.

    (I = BACKWARD_INPUT, W = BACKWARD_WEIGHT, B = FULL_BACKWARD.)

    ``None`` idle slots are dropped. After that, a BACKWARD_INPUT directly
    followed by a BACKWARD_WEIGHT for the same stage and microbatch becomes a
    single FULL_BACKWARD. Running the whole backward at once, instead of
    splitting grad_input and grad_weight, can be more efficient in some cases.
    All other actions are kept in order.

    The caller's ``compute_actions`` list is not modified.

    Returns:
        A new list of non-None actions with eligible I/W pairs merged.
    """
    # Work on a None-free copy. The previous implementation drained the
    # caller's list with O(n) ``pop(0)`` calls, mutating the argument and
    # costing O(n^2) overall.
    actions = [a for a in compute_actions if a is not None]
    merged_actions: list[_Action] = []

    idx = 0
    while idx < len(actions):
        action = actions[idx]
        next_action = actions[idx + 1] if idx + 1 < len(actions) else None
        if (
            action.computation_type == BACKWARD_INPUT
            and next_action is not None
            and next_action.computation_type == BACKWARD_WEIGHT
            and action.stage_index == next_action.stage_index
            and action.microbatch_index == next_action.microbatch_index
        ):
            merged_actions.append(
                _Action(action.stage_index, FULL_BACKWARD, action.microbatch_index)
            )
            idx += 2  # consumed both the I and the W
        else:
            merged_actions.append(action)
            idx += 1

    return merged_actions
def _add_send_recv(
    compute_actions: dict[int, list[_Action]],
    stage_to_rank: Callable[[int], int],
    num_stages: int,
) -> dict[int, list[_Action]]:
    """
    Transform a compute-only schedule into a complete schedule with communication actions.

    Each rank's compute actions are emitted in their original order. When a
    forward (F) or backward (BACKWARD_INPUT / FULL_BACKWARD) crosses a rank
    boundary, a SEND is appended right after it on the sending rank, and the
    matching RECV is appended to the receiving rank's list.

    A compute action whose input comes from another stage is held back until
    either its RECV or the producing action (when that stage is on the same
    rank) has already been emitted on its rank. Ranks are visited repeatedly
    in ascending order until every action has been placed.

    OVERLAP_F_B actions (actions carrying ``sub_actions``) are currently
    replaced by their sub-actions before scheduling.

    Args:
        compute_actions: rank -> ordered compute actions (``None`` = idle slot).
            This mapping and its lists are not modified.
        stage_to_rank: maps a stage index to the rank that owns it.
        num_stages: total number of pipeline stages.

    Returns:
        rank -> ordered list of compute and SEND/RECV actions.

    Raises:
        AssertionError: if no rank can make progress (malformed schedule).
    """
    comm_actions: dict[int, list[_Action]] = {rank: [] for rank in compute_actions}
    prev_actions: dict[int, set[_Action]] = {rank: set() for rank in compute_actions}

    def _has_comms(action: _Action) -> bool:
        """True if ``action``'s output must be sent to a stage on another rank."""
        if action.computation_type == F:
            return action.stage_index != num_stages - 1 and stage_to_rank(
                action.stage_index + 1
            ) != stage_to_rank(action.stage_index)
        elif action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
            return action.stage_index != 0 and stage_to_rank(
                action.stage_index - 1
            ) != stage_to_rank(action.stage_index)
        return False

    def _get_comms(action: _Action) -> tuple[_Action, _Action]:
        """Build the (send, recv) pair for a cross-rank F or I/B action."""
        assert _has_comms(action), f"{action} is not a valid comm action"
        stage_idx = action.stage_index
        ctype = action.computation_type
        mb_idx = action.microbatch_index
        send = _Action(stage_idx, SEND_F if ctype == F else SEND_B, mb_idx)
        recv_stage_idx = stage_idx + 1 if ctype == F else stage_idx - 1
        recv = _Action(recv_stage_idx, RECV_F if ctype == F else RECV_B, mb_idx)
        return send, recv

    def _ready_to_schedule(
        action: Optional[_Action], scheduled: set[_Action]
    ) -> bool:
        """Whether ``action``'s inputs are available given this rank's emitted actions.

        We never put our own recv ops in the schedule; a sender on another rank
        places them for us. This keeps the ordering of sends and recvs sane
        (non-hanging), but means our next compute action may not be
        schedulable yet.
        """
        if action is None:
            return True
        elif action.computation_type == F and not action.stage_index == 0:
            if (
                _Action(action.stage_index, RECV_F, action.microbatch_index)
                in scheduled
            ):
                return True
            elif (
                _Action(action.stage_index - 1, F, action.microbatch_index)
                in scheduled
            ):
                return True
            return False
        elif (
            action.computation_type in (BACKWARD_INPUT, FULL_BACKWARD)
            and not action.stage_index == num_stages - 1
        ):
            if (
                _Action(action.stage_index, RECV_B, action.microbatch_index)
                in scheduled
            ):
                return True
            elif (
                _Action(action.stage_index + 1, BACKWARD_INPUT, action.microbatch_index)
                in scheduled
            ):
                return True
            elif (
                _Action(action.stage_index + 1, FULL_BACKWARD, action.microbatch_index)
                in scheduled
            ):
                return True
            return False
        else:
            return True

    # Build a private, flattened copy of each rank's actions so the caller's
    # dict and lists are left untouched.
    # TODO: For now OVERLAP_F_B is split into its forward and backward
    # components; we still need to figure out how to do its communication.
    pending: dict[int, list[Optional[_Action]]] = {}
    for rank, rank_actions in compute_actions.items():
        flat: list[Optional[_Action]] = []
        for action in rank_actions:
            if action is not None and action.sub_actions is not None:
                flat.extend(action.sub_actions)
            else:
                flat.append(action)
        # A rank with nothing to do simply contributes no actions.
        if flat:
            pending[rank] = flat

    # Index of the next unscheduled action per rank (avoids O(n) list.pop(0)).
    cursor: dict[int, int] = dict.fromkeys(pending, 0)

    while pending:
        progress = False
        # Go in order of ranks even if dict keys aren't ordered.
        for rank in sorted(pending):
            rank_actions = pending[rank]
            action = rank_actions[cursor[rank]]
            if not _ready_to_schedule(action, prev_actions[rank]):
                continue

            if action is not None:
                comm_actions[rank].append(action)
                prev_actions[rank].add(action)
                if _has_comms(action):
                    send, recv = _get_comms(action)
                    # TODO we can avoid send/recv if the 2 stages are on the same rank.
                    # should we avoid that in the runtime or here?
                    comm_actions[rank].append(send)
                    prev_actions[rank].add(send)
                    recv_rank = stage_to_rank(recv.stage_index)
                    comm_actions[recv_rank].append(recv)
                    prev_actions[recv_rank].add(recv)

            cursor[rank] += 1
            if cursor[rank] == len(rank_actions):
                del pending[rank]
            progress = True

        assert progress, "Malformed compute schedule, can't schedule sends/recvs"

    return comm_actions
def _validate_schedule(
    actions: dict[int, list[Optional[_Action]]],
    pp_group_size: int,
    num_stages: int,
    num_microbatches: int,
) -> dict[int, int]:
    """Validate a compute schedule and infer its stage -> rank mapping.

    Checks that:
      * ``actions`` has exactly ``pp_group_size`` ranks, keyed ``0..pp_group_size-1``;
      * every non-``None`` entry is an ``_Action`` (actions with ``sub_actions``
        are validated through their sub-actions);
      * for each stage/microbatch, F runs before B or I, and I runs before W;
      * each stage runs ``num_microbatches`` distinct forwards, equal numbers of
        I and W, and ``num_microbatches`` backwards in total (B plus I/W pairs);
      * every stage is assigned to a single rank.

    Returns:
        Mapping from stage index to the rank that runs it.

    Raises:
        AssertionError: on any violation. Ordering violations include the full
            formatted schedule in the message.
    """
    assert len(actions) == pp_group_size, (
        f"Schedule has incorrect number of ranks - expected {pp_group_size}, actual {len(actions)}"
    )
    for rank in range(pp_group_size):
        assert rank in actions, f"Schedule is missing actions for rank {rank}"

    # Microbatch ids seen per stage and computation type, used to check
    # ordering (F before B/I before W for a given microbatch) and final counts.
    stage_actions: dict[int, dict[_ComputationType, set]] = {
        stage_id: {
            F: set(),
            B: set(),
            I: set(),
            W: set(),
        }
        for stage_id in range(num_stages)
    }
    stage_index_to_rank_mapping: dict[int, int] = {}

    def _require_prior(
        s_id: int,
        mb_id,
        prereq: _ComputationType,
        rank: int,
        step: int,
        action_name: str,
        prereq_name: str,
    ) -> None:
        """Raise (with the full schedule attached) unless ``prereq`` already ran for this mb."""
        if mb_id in stage_actions[s_id][prereq]:
            return
        error_msg = (
            f"Rank {rank}, step {step}: Running {action_name} for stage {s_id}, "
            f"microbatch {mb_id} without first running {prereq_name}"
        )
        formatted_schedule = _format_pipeline_order(actions, error_step_number=step)
        raise AssertionError(
            f"{error_msg}\n\nFull pipeline schedule:\n{formatted_schedule}"
        )

    def _process_action(action: _Action, rank: int, step: int):
        """Process a single action and update stage_actions and stage_index_to_rank_mapping"""
        s_id = action.stage_index
        ctype = action.computation_type
        mb_id = action.microbatch_index
        if ctype == F:
            stage_actions[s_id][F].add(mb_id)
        elif ctype == B:
            _require_prior(s_id, mb_id, F, rank, step, "Full Backward", "Forward")
            stage_actions[s_id][B].add(mb_id)
        elif ctype == I:
            _require_prior(s_id, mb_id, F, rank, step, "Backward Input", "Forward")
            stage_actions[s_id][I].add(mb_id)
        elif ctype == W:
            _require_prior(s_id, mb_id, I, rank, step, "Backward Weight", "Backward Input")
            stage_actions[s_id][W].add(mb_id)

        # Every action type (not only compute) pins its stage to this rank.
        existing_rank = stage_index_to_rank_mapping.setdefault(s_id, rank)
        assert rank == existing_rank, (
            f"Rank {rank}, step {step}: Stage {s_id} is assigned to both rank {rank} and rank {existing_rank}"
        )

    for rank in actions:
        for step, action in enumerate(actions[rank]):
            if action is None:
                continue
            assert isinstance(action, _Action), (
                f"Rank {rank}, step {step}: Got an invalid action: {action}, expected instance of _Action"
            )
            if action.sub_actions is not None:
                # Composite action: validate each sub-action instead
                for sub_action in action.sub_actions:
                    _process_action(sub_action, rank, step)
            else:
                _process_action(action, rank, step)

    for s_id in stage_actions:
        f_mb = len(stage_actions[s_id][F])
        b_mb = len(stage_actions[s_id][B])
        i_mb = len(stage_actions[s_id][I])
        w_mb = len(stage_actions[s_id][W])

        assert f_mb == num_microbatches, (
            f"Got {f_mb} {F} microbatches for stage {s_id}, expected {num_microbatches}"
        )
        # Implicit string concatenation: the previous backslash continuations
        # embedded the source indentation into these messages.
        assert i_mb == w_mb, (
            f"Invalid backward microbatches for stage {s_id}: I and W must have equal counts, "
            f"but got I={i_mb}, W={w_mb}"
        )
        assert b_mb + (i_mb + w_mb) // 2 == num_microbatches, (
            f"Invalid backward microbatches for stage {s_id}: expected {num_microbatches} total backwards, "
            f"but got B={b_mb}, I={i_mb}, W={w_mb}"
        )

    return stage_index_to_rank_mapping
- class PipelineScheduleMulti(_PipelineSchedule):
- """
- Base class for multi-stage schedules.
- Implements the `step` method.
- Gradients are scaled by num_microbatches depending on the `scale_grads` argument, defaulting to True. This setting
- should match the configuration of your loss_fn, which may either average losses (scale_grads=True)
- or sum losses (scale_grads=False).
- """
def __init__(
    self,
    stages: list[_PipelineStageBase],
    n_microbatches: int,
    loss_fn: Optional[Callable] = None,
    args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
    kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
    output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
    use_full_backward: Optional[bool] = None,
    scale_grads: bool = True,
):
    """
    Set up a multi-stage schedule over the stages hosted on this rank.

    Pipeline-wide facts (number of stages, group size, this rank) are read
    from ``stages[0]``. A default stage -> rank mapping is built and shared
    with every local stage. ``use_full_backward`` is deprecated: passing any
    non-None value only logs a warning.
    """
    super().__init__(
        n_microbatches=n_microbatches,
        loss_fn=loss_fn,
        args_chunk_spec=args_chunk_spec,
        kwargs_chunk_spec=kwargs_chunk_spec,
        output_merge_spec=output_merge_spec,
        scale_grads=scale_grads,
    )

    head_stage = stages[0]
    self._stages = stages
    self._num_stages = head_stage.num_stages
    self.pp_group_size = head_stage.group_size
    self.rank = head_stage.group_rank

    # Default stage -> rank mapping, shared by reference with every local stage.
    default_mapping = generate_stage_to_rank_mapping(
        self.pp_group_size, self._num_stages
    )
    self.stage_index_to_group_rank = default_mapping
    for stage in self._stages:
        stage.stage_index_to_group_rank = default_mapping
    self._stages_initialized = False

    # The lambda captures a plain bool, not ``self``, to avoid a reference cycle.
    has_loss: bool = self._loss_fn is not None
    self._should_compute_loss = lambda stage: stage.is_last and has_loss

    # Populated later by the derived schedules.
    self.pipeline_order: dict[int, list[Optional[_Action]]] = {}

    if use_full_backward is not None:
        logger.warning(
            "Deprecation warning: 'use_full_backward' is no longer supported. "
            "Simply stop passing it, and everything should still work fine."
        )
def _initialize_stages(self, args: tuple[Any, ...], kwargs):
    """
    One-time setup of every local stage before the first iteration.

    First performs a single batched P2P exchange with all neighbors of all
    local stages; execution always uses batched P2P ops, and the first such
    call must involve the global group. Then prepares forward (and, if the
    schedule runs backward, backward) infrastructure for each stage in order.

    Args:
        args: positional inputs of the first microbatch, fed to a first stage.
        kwargs: keyword inputs of the first microbatch, passed to every stage.
    """
    init_ops: list[dist.P2POp] = [
        op for stage in self._stages for op in stage._get_init_p2p_neighbors_ops()
    ]
    _wait_batch_p2p(_batch_p2p(init_ops))

    # Carries each stage's forward-prep result into the next local stage.
    # It may be a 'none' value (if the stage sends its output shapes to the
    # next stage via P2P) or a real value (if both stages share a device).
    carried_args: tuple[Any, ...] = tuple()
    for stage in self._stages:
        stage_inputs = args if stage.is_first else carried_args
        carried_args = stage._prepare_forward_infra(
            self._n_microbatches, stage_inputs, kwargs
        )
        if self._has_backward:
            stage._prepare_backward_infra(self._n_microbatches)

    self._stages_initialized = True
def _validate_and_set_stage_mapping(
    self, actions: dict[int, list[Optional[_Action]]]
) -> None:
    """
    Validate ``actions`` and adopt the stage -> rank mapping it implies.

    The mapping (needed for communication) is stored on the schedule and
    shared by reference with every local stage.
    """
    inferred_mapping = _validate_schedule(
        actions,
        self.pp_group_size,
        self._num_stages,
        self._n_microbatches,
    )
    self.stage_index_to_group_rank = inferred_mapping
    for stage in self._stages:
        stage.stage_index_to_group_rank = inferred_mapping
def _dump_csv(self, filename):
    """Dump a CSV representation of the schedule into a file with the provided filename.

    One row is written per rank, in ascending rank order, so that row ``i``
    holds rank ``i``'s actions. ``_load_csv`` relies on this layout because it
    assigns ranks by row index. Previously rows followed dict insertion order,
    which only matched when ``pipeline_order`` had been filled rank by rank.
    ``None`` (idle) entries become empty cells (csv module behavior).
    """
    with open(filename, "w", newline="") as csvfile:
        writer = csv.writer(csvfile)
        writer.writerows(
            self.pipeline_order[rank] for rank in sorted(self.pipeline_order)
        )
def _load_csv(self, filename, format="compute_only"):
    """Load a CSV representation of the schedule from a file with the provided filename.

    This API will most likely get renamed/refactored so is marked as internal for now.

    Row ``i`` of the file is parsed (one ``_Action.from_str`` per cell) as
    rank ``i``'s actions. The whole file is parsed before ``pipeline_order``
    is touched. Its contents are then replaced in place, so ranks from a
    previous order cannot linger, and a parse error leaves the existing order
    intact. The loaded order is validated and the stage -> rank mapping is
    re-inferred, overwriting the default mapping created in the constructor.

    Args:
        filename: path of the CSV file to read.
        format: must be "compute_only" for PipelineScheduleMulti.
    """
    assert format == "compute_only", (
        f"PipelineScheduleMulti only supports format='compute_only', got {format!r}"
    )
    with open(filename, newline="") as csvfile:
        loaded = {
            rank: [_Action.from_str(s) for s in row]
            for rank, row in enumerate(csv.reader(csvfile))
        }

    # Replace in place (keeping the same dict object) only after a full parse.
    self.pipeline_order.clear()
    self.pipeline_order.update(loaded)

    self._validate_and_set_stage_mapping(self.pipeline_order)
def step(self, *args, target=None, losses: Optional[list] = None, **kwargs):
    """
    Execute one pipeline iteration over a *whole* (un-split) batch.

    The inputs are split into ``self._n_microbatches`` microbatches and driven
    through every local stage according to the schedule implementation.

    Args:
        args: positional model inputs, as in the non-pipelined case.
        kwargs: keyword model inputs, as in the non-pipelined case.
        target: optional loss target; split with ``torch.tensor_split`` into
            ``self._n_microbatches`` pieces.
        losses: optional list that receives the per-microbatch losses.

    Returns:
        The merged outputs of the first local stage that is the last pipeline
        stage, or ``None`` if this rank does not host the last stage.

    Raises:
        RuntimeError: if the schedule performs backward but autograd is
            disabled (e.g. under ``torch.no_grad()``).
    """
    if self._has_backward and not torch.is_grad_enabled():
        raise RuntimeError(
            "step() requires gradients to be enabled for backward computation; "
            "it should not be used under torch.no_grad() context. "
            "Please call eval() instead."
        )

    # First propagate the backward flag to every local stage...
    for stage in self._stages:
        stage.has_backward = self._has_backward
    # ...then reset each stage's per-iteration state.
    for stage in self._stages:
        stage.clear_runtime_states()

    args_split, kwargs_split = self._split_inputs(args, kwargs)
    targets_split = (
        None
        if target is None
        else list(torch.tensor_split(target, self._n_microbatches))
    )

    self._step_microbatches(args_split, kwargs_split, targets_split, losses)

    final_stage = next((stage for stage in self._stages if stage.is_last), None)
    if final_stage is None:
        # This rank does not host the last pipeline stage.
        return None
    return self._merge_outputs(final_stage.output_chunks)
- def _step_microbatches(
- self,
- arg_mbs: Optional[list] = None,
- kwarg_mbs: Optional[list] = None,
- target_mbs: Optional[list] = None,
- losses: Optional[list] = None,
- ):
- """
- Operate on the microbatches for looped schedules (multiple stages on each rank).
- TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
- not support models with skip connections.
- """
- arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
- if not self._stages_initialized:
- self._initialize_stages(arg_mbs[0], kwarg_mbs[0])
- # Based on the plan in Step 1 created in __init__:
- # 2. Perform communication based on the pipeline_order
- stage_index_to_stage: dict[int, _PipelineStageBase] = {
- stage.stage_index: stage for stage in self._stages
- }
- # determine prev_rank and next_rank based on which ranks are next to
- # the stages in the pipeline_order
- all_prev_ranks: set[int] = set()
- all_next_ranks: set[int] = set()
- for stage_index in stage_index_to_stage.keys():
- # TODO: assumption that stages only communicate from distances of +1/-1 (no skip connections)
- if stage_index > 0:
- all_prev_ranks.add(self.stage_index_to_group_rank[stage_index - 1])
- if stage_index < self._num_stages - 1:
- all_next_ranks.add(self.stage_index_to_group_rank[stage_index + 1])
- # count either full_backward or backward_weight together, to determine when to sync DP grads
- backward_counter: Counter[int] = Counter()
- for time_step, action in enumerate(self.pipeline_order[self.rank]):
- try:
- ops: list[dist.P2POp] = []
- if action is not None:
- computation_type = action.computation_type
- mb_index = action.microbatch_index
- stage_index = action.stage_index
- assert mb_index is not None, (
- "All currently supported action types require valid microbatch_index"
- )
- if computation_type == _ComputationType.FORWARD:
- # perform forward computation
- stage = stage_index_to_stage[stage_index]
- output = stage.forward_one_chunk(
- mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
- )
- self._maybe_compute_loss(stage, output, target_mbs, mb_index)
- ops.extend(stage.get_fwd_send_ops(mb_index))
- elif computation_type == _ComputationType.FULL_BACKWARD:
- # perform backward computation
- stage = stage_index_to_stage[stage_index]
- loss = self._maybe_get_loss(stage, mb_index)
- backward_counter[stage_index] += 1
- last_backward = (
- backward_counter[stage_index] == self._n_microbatches
- )
- grad_scale_factor = (
- self._n_microbatches if self.scale_grads else 1
- )
- stage.backward_one_chunk(
- mb_index,
- loss=loss,
- full_backward=True,
- last_backward=last_backward,
- )
- if last_backward:
- stage.scale_grads(grad_scale_factor)
- ops.extend(stage.get_bwd_send_ops(mb_index))
- elif computation_type == _ComputationType.BACKWARD_INPUT:
- # perform backward computation
- stage = stage_index_to_stage[stage_index]
- loss = self._maybe_get_loss(stage, mb_index)
- stage.backward_one_chunk(
- mb_index,
- loss=loss,
- full_backward=False,
- last_backward=False,
- )
- ops.extend(stage.get_bwd_send_ops(mb_index))
- elif computation_type == _ComputationType.BACKWARD_WEIGHT:
- # perform weight update
- stage = stage_index_to_stage[stage_index]
- backward_counter[stage_index] += 1
- last_backward = (
- backward_counter[stage_index] == self._n_microbatches
- )
- grad_scale_factor = (
- self._n_microbatches if self.scale_grads else 1
- )
- stage.backward_weight_one_chunk(
- mb_index,
- last_backward=last_backward,
- )
- if last_backward:
- stage.scale_grads(grad_scale_factor)
- else:
- raise ValueError(f"Unknown computation type {computation_type}")
- # Look at the neighboring ranks for this current timestep and determine whether
- # this current rank needs to do any recv communication
- for prev_rank in all_prev_ranks:
- prev_rank_ops = self.pipeline_order[prev_rank]
- prev_rank_action = None
- if time_step < len(prev_rank_ops):
- prev_rank_action = prev_rank_ops[time_step]
- if prev_rank_action is not None:
- computation_type = prev_rank_action.computation_type
- mb_index = prev_rank_action.microbatch_index
- stage_index = prev_rank_action.stage_index
- assert mb_index is not None, (
- "All currently supported action types require valid microbatch_index"
- )
- # Only handle sends for the forward from a previous rank
- if computation_type == _ComputationType.FORWARD:
- # If not the last stage, then receive fwd activations
- if stage_index + 1 in stage_index_to_stage:
- # TODO: We are assuming that stage will always receive from stage-1
- # however that is not necessarily true of get_fwd_recv_ops
- stage = stage_index_to_stage[stage_index + 1]
- ops.extend(stage.get_fwd_recv_ops(mb_index))
- elif computation_type in (
- FULL_BACKWARD,
- BACKWARD_INPUT,
- BACKWARD_WEIGHT,
- ):
- # Previous rank doing backward has no influence for the current rank forward recv
- pass
- else:
- raise ValueError(
- f"Unknown computation type {computation_type}"
- )
- for next_rank in all_next_ranks:
- next_rank_ops = self.pipeline_order[next_rank]
- next_rank_action = None
- if time_step < len(next_rank_ops):
- next_rank_action = next_rank_ops[time_step]
- if next_rank_action is not None:
- computation_type = next_rank_action.computation_type
- mb_index = next_rank_action.microbatch_index
- stage_index = next_rank_action.stage_index
- assert mb_index is not None, (
- "All currently supported action types require valid microbatch_index"
- )
- # Only handle receives for the backwards from a next rank
- if computation_type in (FORWARD, BACKWARD_WEIGHT):
- # Next rank doing forward or weight update has no influence for the current rank backward recv
- pass
- elif computation_type in (BACKWARD_INPUT, FULL_BACKWARD):
- # If not the first stage, then receive bwd gradients
- if stage_index - 1 in stage_index_to_stage:
- # TODO: We are assuming that stage will always receive from stage+1
- # however that is not necessarily true of get_bwd_recv_ops
- stage = stage_index_to_stage[stage_index - 1]
- ops.extend(stage.get_bwd_recv_ops(mb_index))
- else:
- raise ValueError(
- f"Unknown computation type {computation_type}"
- )
- # do the communication
- _wait_batch_p2p(_batch_p2p(ops))
- except Exception as e:
- logger.error(
- "[Rank %s] pipeline schedule %s caught the following exception '%s' \
- at time_step %s when running action %s",
- self.rank,
- self.__class__.__name__,
- str(e),
- time_step,
- action,
- )
- logger.error(
- "%s",
- _format_pipeline_order(
- self.pipeline_order, error_step_number=time_step
- ),
- )
- raise e
- # Return losses if there is a container passed in
- self._update_losses(self._stages, losses)
class _PipelineScheduleRuntime(PipelineScheduleMulti):
    """
    Provides a simple runtime that requires a 'schedule IR' including specified communication operations.

    Can be instantiated directly by creating _PipelineScheduleRuntime and calling load_csv, or can be
    subclassed and the subclass can be responsible for creating a schedule IR.
    """

    def _prepare_schedule_with_comms(
        self,
        actions: dict[int, list[Optional[_Action]]],
        format: str = "compute_only",
    ):
        """
        Store a per-rank schedule, including communication actions, in
        ``self.pipeline_order_with_comms``.

        Must be called before running ``_step_microbatches()``.

        Args:
            actions: mapping of rank -> ordered list of actions for that rank.
            format: ``"compute_only"`` if ``actions`` contains only compute actions. The
                unshard/reshard and send/recv lowering passes are then run to insert
                communication actions. ``"compute_comms"`` if ``actions`` already contains
                communication actions. It is then stored as-is, and ``None`` entries are
                rejected by an assertion.

        Raises:
            ValueError: if ``format="compute_only"`` but a communication action is present.
            NotImplementedError: if ``format`` is not one of the two supported values.
        """
        # validate the provided actions are valid and overrides the default stage_index_to_group_rank
        super()._validate_and_set_stage_mapping(actions)

        self.pipeline_order_with_comms: dict[int, list[_Action]] = {}
        if format == "compute_comms":
            for rank in actions:
                self.pipeline_order_with_comms[rank] = []
                for action in actions[rank]:
                    assert action is not None
                    self.pipeline_order_with_comms[rank].append(action)
            # TODO what level of validation should we offer for compute+comms schedule?
        elif format == "compute_only":
            # Validate that the schedule does not have comms already added to it
            for rank, action_list in actions.items():
                for i, action in enumerate(action_list):
                    if action is not None and not action.is_compute_op:
                        raise ValueError(
                            f"Expected compute-only schedule but found communication action "
                            f"'{action}' at rank {rank}, position {i}. "
                            f"Communication actions (e.g. SEND_F, RECV_F, etc.) "
                            f"should not be present when format='compute_only'."
                        )

            # Perform schedule lowering
            for rank in actions:
                self.pipeline_order_with_comms[rank] = _add_unshard_reshard(
                    actions[rank]
                )

            self.pipeline_order_with_comms = _add_send_recv(
                self.pipeline_order_with_comms,
                stage_to_rank=lambda s: self.stage_index_to_group_rank[s],
                num_stages=self._num_stages,
            )
        else:
            raise NotImplementedError(f"{format=} is not implemented")

    def _load_csv(self, filename: str, format: str = "compute_only"):
        """Load a csv schedule and lower it to include communication actions.

        ``format`` must be either "compute_only" or "compute_comms". If compute_only, the
        parent class loader populates ``self.pipeline_order`` and the lowering passes are then
        run to generate a compute_comms schedule. If compute_comms, each csv row is parsed
        with ``_Action.from_str`` and the row index is used as the rank.

        Raises:
            NotImplementedError: if ``format`` is not supported.
        """
        if format == "compute_only":
            # this will populate self.pipeline_order
            super()._load_csv(filename)
            # this will populate self.pipeline_order_with_comms
            self._prepare_schedule_with_comms(self.pipeline_order)
        elif format == "compute_comms":
            actions = {}
            with open(filename, newline="") as csvfile:
                reader = csv.reader(csvfile)
                for rank, row in enumerate(reader):
                    actions[rank] = [_Action.from_str(s) for s in row]
                self._prepare_schedule_with_comms(actions, format=format)
        else:
            raise NotImplementedError(f"{format=} is not implemented")

    def _dump_csv(self, filename: str, format: str = "compute_comms"):
        """Dump a CSV representation of the schedule into a file with the provided filename.

        Writes one row per rank, in the iteration order of the schedule dict.

        Args:
            filename: destination path. It is opened in ``"w"`` mode, so it is overwritten.
            format: ``"compute_only"`` dumps ``self.pipeline_order``. ``"compute_comms"`` dumps
                ``self.pipeline_order_with_comms``.

        Raises:
            NotImplementedError: if ``format`` is not supported. Previously an unknown format
                silently wrote nothing.
        """
        if format == "compute_only":
            assert self.pipeline_order is not None, (
                "Compute only schedule must be available"
            )
            with open(filename, "w", newline="") as csvfile:
                writer = csv.writer(csvfile)
                for rank in self.pipeline_order:
                    writer.writerow(self.pipeline_order[rank])
        elif format == "compute_comms":
            assert self.pipeline_order_with_comms is not None, (
                "Must initialize compute_comms schedule before dump_csv"
            )
            with open(filename, "w", newline="") as csvfile:
                writer = csv.writer(csvfile)
                for rank in self.pipeline_order_with_comms:
                    writer.writerow(self.pipeline_order_with_comms[rank])
        else:
            raise NotImplementedError(f"{format=} is not implemented")

    def _simulate(self):
        """Run ``_simulate_comms_compute`` on the comms+compute schedule and return its result."""
        return _simulate_comms_compute(
            self.pipeline_order_with_comms,
            lambda s: self.stage_index_to_group_rank[s],
            self._num_stages,
        )

    def _step_microbatches(
        self,
        arg_mbs: Optional[list] = None,
        kwarg_mbs: Optional[list] = None,
        target_mbs: Optional[list] = None,
        losses: Optional[list] = None,
    ):
        """
        Operate on the microbatches for looped schedules (multiple stages on each rank).

        Executes ``self.pipeline_order_with_comms[self.rank]`` one action at a time.
        Communication is driven by the explicit SEND_*/RECV_* actions in the schedule.
        UNSHARD/RESHARD actions only act on stages whose submodule is an ``FSDPModule``.
        Received inputs are waited on lazily, right before the compute action that consumes
        them. Outstanding sends are waited on once, after the whole schedule has run.

        TODO: Does not use sorted_batch_isend_irecv(). As a result, this schedule does
        not support models with skip connections.

        Args:
            arg_mbs / kwarg_mbs: per-microbatch positional / keyword inputs (validated and
                normalized by ``_check_inputs``).
            target_mbs: per-microbatch targets, forwarded to ``_maybe_compute_loss``.
            losses: optional container, filled via ``_update_losses``.
        """
        arg_mbs, kwarg_mbs = self._check_inputs(arg_mbs, kwarg_mbs, target_mbs, losses)
        if not self._stages_initialized:
            self._initialize_stages(arg_mbs[0], kwarg_mbs[0])

        # Based on the plan in Step 1 created in __init__:
        # 2. Perform communication based on the pipeline_order
        stage_index_to_stage: dict[int, _PipelineStageBase] = {
            stage.stage_index: stage for stage in self._stages
        }

        assert self.pipeline_order_with_comms is not None, (
            "Must call _prepare_schedule_with_comms() before calling _step_microbatches()"
        )

        # recv ops indexed by (stage_idx, mb_idx) need to be waited on before use
        bwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {}
        fwd_recv_ops: dict[tuple[int, int], list[dist.Work]] = {}

        # send ops should be waited on before step() exits, mainly for hygiene
        send_ops: list[list[dist.Work]] = []

        # we track which stages are 'active' when used with FSDP, and wait on unshard ops before computing on stages
        unshard_ops: dict[int, UnshardHandle] = {}
        unsharded_stages = set()

        def _assert_unsharded(stage_idx: int):
            """If an unshard is active for `stage_idx`, wait() it and mark `stage_idx` unsharded."""
            if stage_idx in unshard_ops:
                unshard_ops[stage_idx].wait()
                del unshard_ops[stage_idx]
                unsharded_stages.add(stage_idx)
            assert stage_idx in unsharded_stages, (
                f"Attempted to compute on sharded {stage_idx=}"
            )

        # count either full_backward or backward_weight together, to determine when to sync DP grads
        backward_counter: Counter[int] = Counter()
        for time_step, action in enumerate(self.pipeline_order_with_comms[self.rank]):
            try:
                comp_type = action.computation_type
                mb_index: int = (
                    action.microbatch_index
                    if action.microbatch_index is not None
                    else -1
                )
                assert mb_index >= 0 or comp_type in (
                    UNSHARD,
                    RESHARD,
                ), f"{action=} missing mb_index"
                stage_idx = action.stage_index
                stage = stage_index_to_stage[stage_idx]
                stage_uses_fsdp = isinstance(stage.submod, FSDPModule)
                # see [Note: V-schedule special case]
                is_next_stage_on_this_rank = stage_idx + 1 in stage_index_to_stage
                is_prev_stage_on_this_rank = stage_idx - 1 in stage_index_to_stage

                logger.debug(
                    "_PipelineScheduleRuntime running time_step %d, action %s",
                    time_step,
                    action,
                )

                with record_function(_get_profiler_function_name(action)):
                    # TODO(whc) it's not actually safe to use _batch_p2p here in the uncommon case the model has skip-connections,
                    # since we do not want to batch up ops between more than a pair of ranks. _sorted_batch_p2p would be
                    # safe to use instead.
                    # However, I was wondering if I should avoid calling batched operators at all in the case that there is
                    # only one operator per batch. I could iterate through the 'fwd_send_ops' one by one and run them.
                    if comp_type == SEND_F:
                        send_ops.append(_batch_p2p(stage.get_fwd_send_ops(mb_index)))
                    elif comp_type == SEND_B:
                        send_ops.append(_batch_p2p(stage.get_bwd_send_ops(mb_index)))
                    elif comp_type == RECV_F:
                        assert (
                            stage_idx,
                            mb_index,
                        ) not in fwd_recv_ops, (
                            f"Recv twice for {stage_idx=} {mb_index=} without executing forward"
                        )
                        fwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
                            stage.get_fwd_recv_ops(mb_index)
                        )
                    elif comp_type == RECV_B:
                        assert (
                            stage_idx,
                            mb_index,
                        ) not in bwd_recv_ops, (
                            f"Recv twice for {stage_idx=} {mb_index=} without executing backward"
                        )
                        bwd_recv_ops[(stage_idx, mb_index)] = _batch_p2p(
                            stage.get_bwd_recv_ops(mb_index)
                        )
                    elif comp_type == UNSHARD:
                        if stage_uses_fsdp:
                            assert (
                                stage_idx not in unsharded_stages
                                and stage_idx not in unshard_ops
                            ), f"Unsharding the same {stage_idx=} twice"
                            unshard_ops[stage_idx] = stage.submod.unshard(async_op=True)  # type: ignore[operator]
                    elif comp_type == RESHARD:
                        if stage_uses_fsdp:
                            assert stage_idx in unsharded_stages, (
                                f"Resharding {stage_idx=} without unsharding"
                            )
                            assert stage_idx not in unshard_ops, (
                                f"Resharding {stage_idx=} before finishing unshard"
                            )
                            stage.submod.reshard()  # type: ignore[operator]
                    elif comp_type == FORWARD:
                        if stage_uses_fsdp:
                            _assert_unsharded(stage_idx)

                        if (
                            not stage.is_first
                            # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
                            and not is_prev_stage_on_this_rank
                        ):
                            assert (
                                stage_idx,
                                mb_index,
                            ) in fwd_recv_ops, (
                                f"Computing {action=} before receiving input"
                            )
                            _wait_batch_p2p(fwd_recv_ops.pop((stage_idx, mb_index)))

                        output = stage.forward_one_chunk(
                            mb_index, arg_mbs[mb_index], kwarg_mbs[mb_index]
                        )
                        self._maybe_compute_loss(stage, output, target_mbs, mb_index)

                        # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
                        # see [Note: V-schedule special case]
                        if is_next_stage_on_this_rank:
                            stage_index_to_stage[stage_idx + 1].set_local_fwd_input(
                                output, mb_index
                            )

                    elif comp_type == FULL_BACKWARD:
                        if stage_uses_fsdp:
                            _assert_unsharded(stage_idx)

                        if (
                            not stage.is_last
                            # no recv op expected for V-schedule special case (see [Note: V-schedule special case])
                            and not is_next_stage_on_this_rank
                        ):
                            assert (
                                stage_idx,
                                mb_index,
                            ) in bwd_recv_ops, (
                                f"Attempted to run compute {action=} before receiving input"
                            )
                            _wait_batch_p2p(bwd_recv_ops.pop((stage_idx, mb_index)))
                        loss = self._maybe_get_loss(stage, mb_index)
                        backward_counter[stage_idx] += 1
                        last_backward = (
                            backward_counter[stage_idx] == self._n_microbatches
                        )
                        grad_scale_factor = (
                            self._n_microbatches if self.scale_grads else 1
                        )
                        stage.backward_one_chunk(
                            mb_index,
                            loss=loss,
                            full_backward=True,
                            last_backward=last_backward,
                        )
                        if last_backward:
                            stage.scale_grads(grad_scale_factor)
                        # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
                        # see [Note: V-schedule special case]
                        if is_prev_stage_on_this_rank:
                            stage_index_to_stage[stage_idx - 1].set_local_bwd_input(
                                stage.get_local_bwd_output(mb_index), mb_index
                            )
                    elif comp_type == BACKWARD_INPUT:
                        if stage_uses_fsdp:
                            _assert_unsharded(stage_idx)

                        if not stage.is_last and not is_next_stage_on_this_rank:
                            assert (
                                stage_idx,
                                mb_index,
                            ) in bwd_recv_ops, (
                                f"Attempted to run compute {action=} before receiving input"
                            )
                            _wait_batch_p2p(bwd_recv_ops.pop((stage_idx, mb_index)))
                        loss = self._maybe_get_loss(stage, mb_index)
                        # input-grad only: does not count towards backward_counter,
                        # the matching BACKWARD_WEIGHT does
                        stage.backward_one_chunk(
                            mb_index,
                            loss=loss,
                            full_backward=False,
                            last_backward=False,
                        )
                        # SEND/RECV op are avoided for special case with 2 adjacent stages on same rank
                        # see [Note: V-schedule special case]
                        if is_prev_stage_on_this_rank:
                            stage_index_to_stage[stage_idx - 1].set_local_bwd_input(
                                stage.get_local_bwd_output(mb_index), mb_index
                            )
                    elif comp_type == BACKWARD_WEIGHT:
                        if stage_uses_fsdp:
                            _assert_unsharded(stage_idx)
                        backward_counter[stage_idx] += 1
                        last_backward = (
                            backward_counter[stage_idx] == self._n_microbatches
                        )
                        grad_scale_factor = (
                            self._n_microbatches if self.scale_grads else 1
                        )
                        stage.backward_weight_one_chunk(
                            mb_index,
                            last_backward=last_backward,
                        )
                        # Mirror FULL_BACKWARD: when a weight pass is this stage's final
                        # backward, it must scale the accumulated grads, otherwise
                        # schedules ending in BACKWARD_WEIGHT never scale them.
                        if last_backward:
                            stage.scale_grads(grad_scale_factor)
                    else:
                        raise ValueError(f"{action=} is unknown or unsupported")
            except Exception:
                logger.error(
                    "_PipelineScheduleRuntime caught exception at step %s when running action %s.  Full Schedule:",
                    time_step,
                    action,
                )
                # TODO(whc) what is the best practice for printing a multiline log?
                # logger will split it into multiple log lines, but this makes it hard to read (too wide)
                print(
                    _format_pipeline_order(
                        self.pipeline_order_with_comms,  # type: ignore[arg-type]
                        error_step_number=time_step,
                    )
                )
                raise

        # Mostly these operations should have finished long ago, but there isn't an obvious time when to wait for them
        while len(send_ops):
            _wait_batch_p2p(send_ops.pop())

        assert len(unshard_ops) == 0, "Unused unshard operations"

        # Return losses if there is a container passed in
        self._update_losses(self._stages, losses)
class ScheduleLoopedBFS(PipelineScheduleMulti):
    """
    Breadth-First Pipeline Parallelism.
    See https://arxiv.org/abs/2211.05953 for details.

    Like Interleaved 1F1B, Looped BFS supports multiple stages per rank.
    The difference is in prioritization: when microbatches are ready for several
    local stages, Looped BFS favors the earlier stage and runs every available
    microbatch for it in one go.
    """

    def __init__(
        self,
        stages: list[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Union[Callable, _Loss]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
    ):
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
        )

        # Every rank computes the full schedule for every rank, so the whole
        # pipeline's state is known locally.
        # pipeline_order[rank] = [_Action(stage_index, computation_type, microbatch_index) or None, ...]
        self.pipeline_order: dict[int, list[Optional[_Action]]] = {
            group_rank: self._calculate_single_rank_operations(group_rank)
            for group_rank in range(self.pp_group_size)
        }

    def _calculate_single_rank_operations(self, rank):
        """Return the ordered action list (``None`` = idle slot) for ``rank``.

        The rank owns stages ``rank, rank + G, rank + 2G, ...`` (G = pp_group_size).
        It idles for ``rank`` steps, then runs all forwards stage by stage.
        It idles again while the first backward propagates back.
        Finally it runs all backwards, with stages and microbatches in reverse order.
        """
        group = self.pp_group_size
        num_mbs = self._n_microbatches
        owned_stages = range(rank, group * len(self._stages), group)

        # Leading idle slots for the warmup offset of this rank.
        schedule: list[Optional[_Action]] = [None] * rank

        # Forward phase: finish every microbatch of a stage before the next stage.
        for stage_id in owned_stages:
            for mb in range(num_mbs):
                schedule.append(_Action(stage_id, _ComputationType.FORWARD, mb))

        # Wait for the first backward to arrive: two steps per hop from the last rank.
        schedule += [None] * (2 * (group - 1 - rank))

        # Backward phase: reverse stage order, reverse microbatch order.
        for stage_id in reversed(owned_stages):
            for mb in reversed(range(num_mbs)):
                schedule.append(_Action(stage_id, _ComputationType.FULL_BACKWARD, mb))

        return schedule
def _get_1f1b_rank_ops(
    n_local_stages,
    pp_group_size,
    warmup_ops,
    fwd_bwd_ops,
    cooldown_ops,
    rank,
    forward_stage_index,
    backward_stage_index,
    num_1f1b_microbatches=0,
    enable_zero_bubble=False,
):
    """
    Build the compute-only action list for one rank of an interleaved 1F1B schedule,
    optionally in its zero-bubble form.

    Args:
        n_local_stages: number of stages hosted on this rank.
        pp_group_size: number of ranks in the pipeline group.
        warmup_ops: number of forward-only steps before the steady state.
        fwd_bwd_ops: number of steady-state steps, each emitting one forward followed by
            one backward.
        cooldown_ops: number of backward-only steps after the steady state.
        rank: this rank's index, which is also the number of leading idle (``None``) slots.
        forward_stage_index: callable mapping an op step to the stage index of its forward.
        backward_stage_index: callable mapping an op step to the stage index of its backward.
        num_1f1b_microbatches: zero-bubble only. Number of steady-state steps executed before
            BACKWARD_WEIGHT ops start being interleaved after each backward.
        enable_zero_bubble: if True, emit BACKWARD_INPUT instead of FULL_BACKWARD and emit
            one BACKWARD_WEIGHT per backward. Weight ops not interleaved are appended at the end.

    Returns:
        list of ``_Action`` or ``None`` (idle slot), in execution order.
    """
    # Per-stage "next microbatch" counters; all stages start with microbatch 0.
    fwd_stage_mb_index: dict[int, int] = defaultdict(int)
    bwd_stage_mb_index: dict[int, int] = defaultdict(int)
    weight_stage_mb_index: dict[int, int] = defaultdict(int)

    # Store the list of operations used for that rank
    # Pre-padding, rank starts with no-ops based on the warmup.
    rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]
    # These are used to calculate the number of slots to fill with no-ops, to account for the delay in warmup
    # when we want to wait for the backward to trickle back up and start 1f1b to align all ranks.
    # Formula:
    # pre-padding + warmup_ops + post_warmup_ops = earliest time step of first backward
    # post_warmup_ops = [earliest time step of first backward] - (warmup_ops + pre-padding)
    # earliest time step of first backward = [local_stages * group_size + 2 * (group_size - 1 - rank)]
    # warmup_ops = calculated above
    post_warmup_ops = (
        n_local_stages * pp_group_size + 2 * (pp_group_size - 1 - rank)
    ) - (warmup_ops + rank)

    if enable_zero_bubble:
        post_warmup_ops = pp_group_size - rank - 1

    total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops

    # Op steps that emitted a backward; weight ops are replayed against these in order.
    backward_op_ids: list[int] = []
    weight_op_count = 0

    FULL_BACKWARD_OR_BACKWARD_INPUT = (
        BACKWARD_INPUT if enable_zero_bubble else FULL_BACKWARD
    )

    def _take_mb(counters: dict[int, int], stage_index: int) -> int:
        """Return the next microbatch index for `stage_index` and advance its counter."""
        mb_index = counters[stage_index]
        counters[stage_index] = mb_index + 1
        return mb_index

    def _append_forward(op: int) -> None:
        """Append the forward action for op step `op`."""
        stage_index = forward_stage_index(op)
        rank_ops.append(
            _Action(
                stage_index,
                _ComputationType.FORWARD,
                _take_mb(fwd_stage_mb_index, stage_index),
            )
        )

    def _append_backward(op: int) -> None:
        """Append the (full or input) backward action for op step `op` and remember `op`."""
        stage_index = backward_stage_index(op)
        rank_ops.append(
            _Action(
                stage_index,
                FULL_BACKWARD_OR_BACKWARD_INPUT,
                _take_mb(bwd_stage_mb_index, stage_index),
            )
        )
        backward_op_ids.append(op)

    def _append_weight() -> None:
        """Append the weight-grad action for the oldest backward still lacking one."""
        nonlocal weight_op_count
        stage_index = backward_stage_index(backward_op_ids[weight_op_count])
        rank_ops.append(
            _Action(
                stage_index,
                _ComputationType.BACKWARD_WEIGHT,
                _take_mb(weight_stage_mb_index, stage_index),
            )
        )
        weight_op_count += 1

    for op in range(total_ops):
        # Warmup phase
        if op < warmup_ops:
            _append_forward(op)
            if op == warmup_ops - 1:
                # This is the last step in the warmup phase, so we need to wait for the backward to trickle back up
                rank_ops.extend([None] * post_warmup_ops)
        # 1F1B Phase (forward and backward)
        elif warmup_ops <= op < warmup_ops + fwd_bwd_ops:
            _append_forward(op)
            _append_backward(op)
            if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
                _append_weight()
        # Cooldown phase
        else:
            # During cooldown phase, we need steps to align with 1f1b happening in other ranks
            # TODO: we don't need to always append, after all 1f1b are finished we can stop appending None
            if not enable_zero_bubble:
                rank_ops.append(None)
            _append_backward(op)
            if enable_zero_bubble and op - warmup_ops >= num_1f1b_microbatches:
                _append_weight()

    # Drain any weight ops that were not interleaved above.
    while enable_zero_bubble and weight_op_count < len(backward_op_ids):
        _append_weight()

    return rank_ops
class ScheduleInterleaved1F1B(PipelineScheduleMulti):
    """
    The Interleaved 1F1B schedule.
    See https://arxiv.org/pdf/2104.04473 for details.

    In steady state it performs one forward and one backward per step, and it
    supports multiple stages per rank. When microbatches are ready for several
    local stages, Interleaved 1F1B favors the earlier microbatch ("depth first").

    This schedule mostly follows the original paper. It differs by relaxing the
    requirement that num_microbatch % pp_size == 0. With the flex_pp schedule,
    num_rounds = max(1, n_microbatches // pp_group_size), and the schedule works as
    long as n_microbatches % num_rounds == 0. Supported examples:

    1. pp_group_size = 4, n_microbatches = 10 -> num_rounds = 2 and n_microbatches % 2 is 0.
    2. pp_group_size = 4, n_microbatches = 3 -> num_rounds = 1 and n_microbatches % 1 is 0.
    """

    def __init__(
        self,
        stages: list[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
    ):
        # Set before delegating to the base constructor (presumably the base relies
        # on it being available early -- TODO confirm).
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank

        rounds = max(1, n_microbatches // self.pp_group_size)
        self.number_of_rounds = rounds
        self.microbatches_per_round = n_microbatches // rounds
        if n_microbatches % rounds:
            raise ValueError(
                "Interleaved 1F1B requires the number of microbatches to be a "
                f"multiple of the number of rounds ({self.number_of_rounds}), "
                f"but got {n_microbatches}."
            )

        # Every rank computes the full schedule for every rank.
        # pipeline_order[rank] = [_Action(stage_index, computation_type, microbatch_index) or None, ...]
        self.pipeline_order: dict[int, list[Optional[_Action]]] = {
            group_rank: self._calculate_single_rank_operations(group_rank)
            for group_rank in range(self.pp_group_size)
        }

    def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
        """Return the ordered action list (``None`` = idle slot) for ``rank``."""
        n_local = self.n_local_stages
        per_round = self.microbatches_per_round
        group = self.pp_group_size
        microbatch_ops = n_local * self._n_microbatches

        # Warmup: the last rank runs (n_local - 1) rounds of forwards first, and each hop
        # further from the last rank adds 2 more. It can never exceed the number of
        # forwards this rank runs in total.
        warmup_ops = min(
            (n_local - 1) * per_round + 2 * ((group - 1) - rank),
            microbatch_ops,
        )
        # Steady-state 1F1B steps cover the remaining forwards ...
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # ... and cooldown covers the remaining backwards.
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2

        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        def forward_stage_index(step):
            # Local stage slot cycles every `per_round` steps, going forward through local stages.
            slot = (step // per_round) % n_local
            return slot * group + rank

        def backward_stage_index(step):
            # Same cycling, but starting from the last local stage and counted from the end of warmup.
            slot = n_local - 1 - (((step - warmup_ops) // per_round) % n_local)
            return slot * group + rank

        return _get_1f1b_rank_ops(
            n_local,
            group,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
        )
class ScheduleInterleavedZeroBubble(PipelineScheduleMulti):
    """
    The Interleaved Zero Bubble schedule.
    See https://arxiv.org/pdf/2401.10241 for details.

    Will perform one forward and one backward on inputs for the microbatches in steady
    state and supports multiple stages per rank. Uses the backward for weights to fill in
    the pipeline bubble.

    In particular this is implementing the ZB1P schedule in the paper.
    """

    def __init__(
        self,
        stages: list[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
    ):
        # TODO: we don't support Zero Bubble with torch.compile so we
        # should disable it for now
        for stage in stages:
            if isinstance(stage.submod, OptimizedModule):
                # Implicit concatenation: a backslash continuation inside the literal
                # would embed the next line's indentation into the message.
                raise RuntimeError(
                    "The Zero Bubble schedule is not supported with "
                    "stage modules that have used torch.compile"
                )

        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
        )
        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.number_of_rounds = max(1, n_microbatches // self.pp_group_size)
        self.microbatches_per_round = n_microbatches // self.number_of_rounds
        if n_microbatches % self.number_of_rounds != 0:
            raise ValueError(
                "Zero bubble requires the number of microbatches to be a "
                f"multiple of the number of rounds ({self.number_of_rounds}), "
                f"but got {n_microbatches}."
            )
        # 1. Create the pipeline_order (all ranks do this calculation)
        # This will be used to keep track of the current state of the entire pipeline
        # pipeline_order[rank] = [_Action(stage_index, computation_type, microbatch_index) or None, ...]
        self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

        # This function add bubbles to the generated schedule based on dependencies of actions
        # Note that the ZB1P schedule will not require bubbles to be manually added and it is
        # only useful when n_microbatches <= microbatches_per_round
        self.pipeline_order = self._add_bubbles_to_actions(
            self.n_local_stages * self.pp_group_size,
        )

    def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
        """Return the ordered ZB1P action list (``None`` = idle slot) for ``rank``."""

        def get_rank_warmup_ops(rank):
            # Warms up operations for last stage
            warmups_ops_last_stage = (
                self.n_local_stages - 1
            ) * self.microbatches_per_round
            # Increment warmup operations by 1 for each hop away from the last stage
            multiply_factor = 1
            warmup_ops = warmups_ops_last_stage + multiply_factor * (
                (self.pp_group_size - 1) - rank
            )

            # We cannot have more warmup operations than there are number of microbatches, so cap it there
            return min(warmup_ops, self._n_microbatches * self.n_local_stages)

        warmup_ops = get_rank_warmup_ops(rank)
        microbatch_ops = self.n_local_stages * self._n_microbatches
        # fwd_bwd_ops should encompass the remaining forwards
        fwd_bwd_ops = microbatch_ops - warmup_ops
        # cooldown_ops should encompass the remaining backwards
        cooldown_ops = microbatch_ops - fwd_bwd_ops
        # total ops encompass both forward and backward ops
        total_ops = warmup_ops + fwd_bwd_ops + cooldown_ops
        # warmup_ops + fwd_bwd_ops * 2 + cooldown_ops == microbatch_ops * 2
        logger.debug(
            "rank %s, warmup_ops %s, 1f1b %s, cooldown_ops %s total_ops %s",
            rank,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            total_ops,
        )

        # Calculates the stage index based on step and pp_group_size

        def forward_stage_index(step):
            # Get the local index from 0 to n_local_stages-1
            local_index = (step // self.microbatches_per_round) % self.n_local_stages
            return (local_index * self.pp_group_size) + rank

        def backward_stage_index(step):
            local_index = (
                self.n_local_stages
                - 1
                - ((step - warmup_ops) // self.microbatches_per_round)
                % self.n_local_stages
            )
            return (local_index * self.pp_group_size) + rank

        # Rank r runs r steady-state steps before weight ops start being interleaved.
        num_1f1b_microbatches = rank

        return _get_1f1b_rank_ops(
            self.n_local_stages,
            self.pp_group_size,
            warmup_ops,
            fwd_bwd_ops,
            cooldown_ops,
            rank,
            forward_stage_index,
            backward_stage_index,
            num_1f1b_microbatches,
            enable_zero_bubble=True,
        )

    def _add_bubbles_to_actions(self, num_stages_global):
        """
        Re-time ``self.pipeline_order`` by inserting ``None`` slots where needed.

        All ranks are advanced in lockstep. An action may only be placed once the
        action it depends on was placed at a strictly earlier time step:

        - FORWARD of stage s > 0 depends on FORWARD of stage s - 1.
        - FULL_BACKWARD of the last stage depends on its own FORWARD.
        - FULL_BACKWARD of any other stage depends on FULL_BACKWARD of stage s + 1.

        NOTE(review): BACKWARD_INPUT / BACKWARD_WEIGHT (the ops this schedule actually
        emits for backward) are never delayed by ``need_bubble``. Confirm this is intended.

        Args:
            num_stages_global: total number of pipeline stages across all ranks.

        Returns:
            new rank -> action-list mapping with bubbles inserted.

        Raises:
            RuntimeError: if in some round every unfinished rank needs a bubble. The state
                would then never change and the original loop would spin forever.
        """
        actions = self.pipeline_order

        def need_bubble(stage, op, microbatch, num_stages_global, seen_ops):
            if op == _ComputationType.FORWARD:
                if stage != 0 and (stage - 1, op, microbatch) not in seen_ops:
                    return True
            elif op == _ComputationType.FULL_BACKWARD:
                if stage == num_stages_global - 1:
                    return (stage, _ComputationType.FORWARD, microbatch) not in seen_ops
                return (stage + 1, op, microbatch) not in seen_ops
            return False

        seen_ops: set[tuple[int, _ComputationType, int]] = set()
        result: dict[int, list[Optional[_Action]]] = {}
        next_pointer: dict[int, int] = {}
        bubbles_added: dict[int, int] = {}
        total_bubbles_added = 0

        for rank in range(self.pp_group_size):
            result[rank] = []
            next_pointer[rank] = 0
            bubbles_added[rank] = 0

        while True:
            should_stop = True
            made_progress = False

            # Ops placed in this round only become visible to dependents next round.
            temp_seen_ops: set[tuple[int, _ComputationType, int]] = set()

            for rank in range(self.pp_group_size):
                timestamp = next_pointer[rank]
                if timestamp >= len(actions[rank]):
                    continue

                should_stop = False

                temp_action = actions[rank][timestamp]
                if temp_action is None:
                    next_pointer[rank] += 1
                    result[rank].append(None)
                    made_progress = True
                    continue

                stage_index, op, microbatch, _ = temp_action
                if not need_bubble(
                    stage_index, op, microbatch, num_stages_global, seen_ops
                ):
                    result[rank].append(temp_action)
                    if microbatch is not None:
                        temp_seen_ops.add((stage_index, op, microbatch))
                    next_pointer[rank] += 1
                    made_progress = True
                else:
                    result[rank].append(None)
                    bubbles_added[rank] += 1
                    total_bubbles_added += 1

            seen_ops.update(temp_seen_ops)
            if should_stop:
                break
            if not made_progress:
                raise RuntimeError(
                    "Unable to add bubbles to the schedule: every remaining action is "
                    "waiting on a dependency that is never scheduled "
                    f"(next_pointer={next_pointer})"
                )

        if total_bubbles_added > 0:
            logger.warning(
                "Non zero bubbles added: total_bubbles_added=%s bubbles_added=%s",
                total_bubbles_added,
                bubbles_added,
            )
        return result
class ScheduleZBVZeroBubble(PipelineScheduleMulti):
    """
    The Zero Bubble schedule (ZBV variant).
    See https://arxiv.org/pdf/2401.10241 Section 6 for details.

    This schedule requires exactly two stages per rank, laid out in a "V":
    rank ``r`` runs stage ``r`` (chunk 0) and stage ``num_stages - 1 - r``
    (chunk 1).

    In steady state it performs one forward and one backward on inputs for
    the microbatches of each local stage. It uses backward with respect to
    weights to fill in the pipeline bubble.

    This ZB-V schedule would have the "zero bubble" property only if time
    forward == time backward input == time backward weights. In practice,
    this is not likely true for real models, so alternatively a greedy
    scheduler could be implemented for unequal/unbalanced time.
    """

    def __init__(
        self,
        stages: list[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
    ):
        # Validate before indexing ``stages[0]`` or running base-class setup.
        # Any wrong stage count (including an empty list) then fails fast
        # with a ValueError.
        if len(stages) != 2:
            raise ValueError(
                "ZBV requires exactly 2 stages per rank, but got "
                f"{len(stages)}."
            )
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
        )
        # Replace the stage -> rank mapping with the "v" placement.
        self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
            self.pp_group_size, self._num_stages, style="v"
        )
        for stage in self._stages:
            stage.stage_index_to_group_rank = self.stage_index_to_group_rank

        self.n_local_stages = len(stages)
        self.rank = stages[0].group_rank
        self.num_stages = stages[0].num_stages

        # 1. Create the pipeline_order (all ranks do this calculation).
        # It tracks the state of the entire pipeline:
        # pipeline_order[rank] = [_Action(stage_index, computation_type, microbatch_index), ...]
        self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

    def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
        """Compute the ZB-V action list for one rank.

        Rank ``rank`` owns chunk 0 (stage ``rank``) and chunk 1 (stage
        ``num_stages - 1 - rank``). The list layout is:

        * ``rank`` leading idle (``None``) slots,
        * warm-up forwards,
        * a steady phase of F/I/W per chunk,
        * a cool-down that drains the remaining I and W actions.

        Actions are generated for at least ``2 * pp_group_size - 1``
        microbatches so the pattern is complete. Any action whose microbatch
        index is ``>= self._n_microbatches`` is then replaced by ``None`` in
        place, so the timeline keeps its length.

        Args:
            rank: Pipeline rank to compute the actions for.

        Returns:
            The per-timestep list of actions, with ``None`` for idle slots.
        """
        # max(2 * self.pp_group_size - 1, ...) ensures the number of
        # microbatches is at least as large as the number needed to fully
        # utilize the pipeline; the surplus is filtered out below.
        n_micro = max(2 * self.pp_group_size - 1, self._n_microbatches)
        rank_ops: list[Optional[_Action]] = [None for _ in range(rank)]

        def add(stage_index: int, computation_type, microbatch_index: int) -> None:
            rank_ops.append(
                _Action(
                    stage_index,
                    computation_type=computation_type,
                    microbatch_index=microbatch_index,
                )
            )

        # Forward (f) and backward-input (b) counters for chunk 0 and chunk 1.
        f0_cnt, f1_cnt, b0_cnt, b1_cnt = 0, 0, 0, 0
        stage_id_chunk0 = rank
        stage_id_chunk1 = self.num_stages - 1 - rank

        # warm-up phase
        warmup_n1 = 2 * (self.pp_group_size - rank) - 1
        for _ in range(warmup_n1):
            add(stage_id_chunk0, F, f0_cnt)
            f0_cnt += 1
        warmup_n2 = rank
        for _ in range(warmup_n2):
            add(stage_id_chunk1, F, f1_cnt)
            f1_cnt += 1
            add(stage_id_chunk0, F, f0_cnt)
            f0_cnt += 1
        warmup_n3 = self.pp_group_size - rank
        for _ in range(warmup_n3):
            add(stage_id_chunk1, F, f1_cnt)
            f1_cnt += 1
            add(stage_id_chunk1, I, b1_cnt)
            add(stage_id_chunk1, W, b1_cnt)
            b1_cnt += 1

        # stable phase: keep going until chunk 0 has issued n_micro forwards
        # and chunk 1 has caught up with chunk 0.
        while f1_cnt < f0_cnt or f0_cnt < n_micro:
            if f0_cnt < n_micro:
                add(stage_id_chunk0, F, f0_cnt)
                f0_cnt += 1
            add(stage_id_chunk0, I, b0_cnt)
            add(stage_id_chunk0, W, b0_cnt)
            b0_cnt += 1

            add(stage_id_chunk1, F, f1_cnt)
            f1_cnt += 1
            add(stage_id_chunk1, I, b1_cnt)
            add(stage_id_chunk1, W, b1_cnt)
            b1_cnt += 1

        # cool-down phase: weight counters start where the paired W's stopped.
        w0_cnt, w1_cnt = b0_cnt, b1_cnt
        cooldown_n1 = rank
        for _ in range(cooldown_n1):
            add(stage_id_chunk0, I, b0_cnt)
            b0_cnt += 1
            add(stage_id_chunk1, I, b1_cnt)
            b1_cnt += 1
        cooldown_n2 = self.pp_group_size - rank
        for _ in range(cooldown_n2):
            add(stage_id_chunk0, I, b0_cnt)
            b0_cnt += 1
            add(stage_id_chunk0, W, w0_cnt)
            w0_cnt += 1
        while w1_cnt < b1_cnt:
            add(stage_id_chunk1, W, w1_cnt)
            w1_cnt += 1
        while w0_cnt < b0_cnt:
            add(stage_id_chunk0, W, w0_cnt)
            w0_cnt += 1

        # Internal consistency: every forward got exactly one I and one W.
        assert w0_cnt == b0_cnt and b0_cnt == f0_cnt
        assert w1_cnt == b1_cnt and b1_cnt == f1_cnt

        # We use max() in the n_micro computation above, so we may need to
        # remove redundant microbatches (replaced by idle slots).
        rank_ops = [
            (
                action
                if action is not None
                and action.microbatch_index is not None
                and action.microbatch_index < self._n_microbatches
                else None
            )
            for action in rank_ops
        ]
        return rank_ops
class ScheduleDualPipeV(_PipelineScheduleRuntime):
    """
    The DualPipeV schedule. A more efficient schedule variant based on the
    DualPipe schedule introduced by DeepSeek in https://arxiv.org/pdf/2412.19437

    Based on the open sourced code from https://github.com/deepseek-ai/DualPipe

    Requires exactly two stages per rank (placed with the "v" stage-to-rank
    mapping) and at least as many microbatches as there are stages in total.
    """

    def __init__(
        self,
        stages: list[_PipelineStageBase],
        n_microbatches: int,
        loss_fn: Optional[Callable] = None,
        args_chunk_spec: Optional[tuple[TensorChunkSpec, ...]] = None,
        kwargs_chunk_spec: Optional[dict[str, TensorChunkSpec]] = None,
        output_merge_spec: Optional[Union[dict[str, Any], tuple[Any]]] = None,
        scale_grads: bool = True,
    ):
        # Validate before indexing ``stages[0]`` or running base-class setup.
        # Any wrong stage count (including an empty list) then fails fast
        # with a ValueError.
        if len(stages) != 2:
            raise ValueError(
                "DualPipeV requires exactly 2 stages per rank, but got "
                f"{len(stages)}."
            )
        self.pp_group_size = stages[0].group_size
        super().__init__(
            stages=stages,
            n_microbatches=n_microbatches,
            loss_fn=loss_fn,
            args_chunk_spec=args_chunk_spec,
            kwargs_chunk_spec=kwargs_chunk_spec,
            output_merge_spec=output_merge_spec,
            scale_grads=scale_grads,
        )
        # Replace the stage -> rank mapping with the "v" placement.
        self.stage_index_to_group_rank = generate_stage_to_rank_mapping(
            self.pp_group_size, self._num_stages, style="v"
        )
        for stage in self._stages:
            stage.stage_index_to_group_rank = self.stage_index_to_group_rank

        self.n_local_stages = len(stages)
        if n_microbatches < self._num_stages:
            raise ValueError(
                "DualPipeV requires at least as many microbatches as stages, but got "
                f"{n_microbatches} microbatches and {self._num_stages} stages."
            )

        self.rank = stages[0].group_rank
        self.num_stages = stages[0].num_stages

        # 1. Create the pipeline_order (all ranks do this calculation).
        # It tracks the state of the entire pipeline:
        # pipeline_order[rank] = [_Action(stage_index, computation_type, microbatch_index), ...]
        self.pipeline_order: dict[int, list[Optional[_Action]]] = {}
        for rank in range(self.pp_group_size):
            rank_ops = self._calculate_single_rank_operations(rank)
            self.pipeline_order[rank] = rank_ops

        # Initialize the pipeline order with communication necessary to run with _PipelineScheduleRuntime
        self._prepare_schedule_with_comms(self.pipeline_order)

    def _calculate_single_rank_operations(self, rank) -> list[Optional[_Action]]:
        """Compute the DualPipeV action list for one rank.

        The rank's two stages come from the "v" rank-to-stage mapping.
        Microbatch indices are handed out by per-(stage, computation type)
        counters.

        * A BACKWARD_INPUT action queues a matching BACKWARD_WEIGHT action.
          That action is emitted later, in FIFO order, by
          ``add_weight_action_if_pending``.
        * A FULL_BACKWARD (standalone or inside an OVERLAP_F_B action)
          advances both the input and weight counters and queues nothing.

        Args:
            rank: Pipeline rank to compute the actions for.

        Returns:
            The list of actions for this rank. This method itself never
            appends ``None`` entries.
        """
        actions: list[Optional[_Action]] = []
        counters: dict[
            tuple[int, _ComputationType], int
        ] = {}  # (stage_index, computation_type) -> mb_index
        weight_queue = []  # Queue of (stage_index, mb_index) for pending weight actions

        num_ranks = self.pp_group_size
        num_chunks = self._n_microbatches

        rank_to_stages = generate_rank_to_stage_mapping(
            num_ranks, num_ranks * 2, style="v"
        )
        stage0_index, stage1_index = rank_to_stages[rank]

        def increment_backward_counts(stage_index: int):
            """Helper method to increment BACKWARD_INPUT and BACKWARD_WEIGHT counters when FULL_BACKWARD is used."""
            input_key = (stage_index, BACKWARD_INPUT)
            weight_key = (stage_index, BACKWARD_WEIGHT)
            counters[input_key] = counters.get(input_key, 0) + 1
            counters[weight_key] = counters.get(weight_key, 0) + 1

        def add_overlap_f_b(
            actions: list,
            forward_stage: int,
            backward_stage: int,
        ):
            """Helper method to add an overlapped forward+backward action which tracks microbatch index."""
            # Create new overlapped forward+backward action with sub_actions
            forward_key = (forward_stage, FORWARD)
            backward_key = (backward_stage, BACKWARD_INPUT)

            forward_mb = counters.get(forward_key, 0)
            backward_mb = counters.get(backward_key, 0)

            sub_actions = (
                _Action(forward_stage, FORWARD, forward_mb),
                _Action(backward_stage, FULL_BACKWARD, backward_mb),
            )
            # The wrapper itself carries no real stage/microbatch (-1 / None);
            # the sub_actions hold them.
            actions.append(_Action(-1, OVERLAP_F_B, None, sub_actions))

            # Update counters for sub_actions
            counters[forward_key] = forward_mb + 1
            increment_backward_counts(backward_stage)

        def add_action(
            actions: list,
            stage_index: int,
            computation_type: _ComputationType,
        ):
            # Regular single action, for FULL_BACKWARD we only use the BACKWARD_INPUT counter
            key = (
                (stage_index, computation_type)
                if computation_type != FULL_BACKWARD
                else (stage_index, BACKWARD_INPUT)
            )
            mb_index = counters.get(key, 0)
            actions.append(_Action(stage_index, computation_type, mb_index))

            # If FULL_BACKWARD is used, just increment the separate BACKWARD_INPUT and BACKWARD_WEIGHT counters
            if computation_type == FULL_BACKWARD:
                increment_backward_counts(stage_index)
            else:
                # If BACKWARD_INPUT is updated, add corresponding weight action to queue
                if computation_type == BACKWARD_INPUT:
                    # Add weight action to queue for later processing
                    weight_queue.append((stage_index, mb_index))
                counters[key] = mb_index + 1

        def add_weight_action_if_pending(actions: list):
            """Helper method to add a weight action from the queue."""
            if not weight_queue:
                return  # No pending weight actions, skip
            # Pop the oldest weight action from the queue
            actual_stage_index, weight_mb_index = weight_queue.pop(0)
            actions.append(
                _Action(
                    actual_stage_index,
                    BACKWARD_WEIGHT,
                    weight_mb_index,
                )
            )
            # Update the counter for the actual stage that was processed
            weight_key = (actual_stage_index, BACKWARD_WEIGHT)
            counters[weight_key] = counters.get(weight_key, 0) + 1

        # Step 1: F0
        step_1 = (num_ranks - rank - 1) * 2
        for _ in range(step_1):
            add_action(actions, stage0_index, FORWARD)

        # Step 2: F0F1
        step_2 = rank + 1
        for _ in range(step_2):
            add_action(actions, stage0_index, FORWARD)
            add_action(actions, stage1_index, FORWARD)

        # Step 3: I1W1F1 (Use zero bubble)
        step_3 = num_ranks - rank - 1
        for _ in range(step_3):
            add_action(actions, stage1_index, BACKWARD_INPUT)
            add_weight_action_if_pending(actions)
            add_action(actions, stage1_index, FORWARD)

        # Step 4 (Main step): F0B1-F1B0 (combined, overlapped forward+backward)
        step_4 = num_chunks - num_ranks * 2 + rank + 1
        for i in range(step_4):
            if i == 0 and rank == num_ranks - 1:
                # NOTE: We don't overlap these two chunks to further reduce bubble size.
                add_action(actions, stage0_index, FORWARD)
                add_action(actions, stage1_index, FULL_BACKWARD)
            else:
                add_overlap_f_b(
                    actions,
                    forward_stage=stage0_index,
                    backward_stage=stage1_index,
                )
            add_overlap_f_b(
                actions,
                forward_stage=stage1_index,
                backward_stage=stage0_index,
            )

        # Step 5: B1-F1B0
        step_5 = num_ranks - rank - 1
        for _ in range(step_5):
            add_action(actions, stage1_index, FULL_BACKWARD)
            add_overlap_f_b(
                actions,
                forward_stage=stage1_index,
                backward_stage=stage0_index,
            )

        # Step 6: B1B0 (The second half of the chunks use zero bubble)
        # Switching to BACKWARD_INPUT happens at i == step_6 // 2: before the
        # stage-1 backward on odd ranks, before the stage-0 backward on even
        # ranks. Once enabled it stays enabled for step 7.
        step_6 = rank + 1
        enable_zb = False
        for i in range(step_6):
            if i == step_6 // 2 and rank % 2 == 1:
                enable_zb = True
            comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
            add_action(actions, stage1_index, comp_type)
            if i == step_6 // 2 and rank % 2 == 0:
                enable_zb = True
            comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
            add_action(actions, stage0_index, comp_type)

        # Step 7: W0B0
        step_7 = num_ranks - rank - 1
        for _ in range(step_7):
            add_weight_action_if_pending(actions)
            comp_type = BACKWARD_INPUT if enable_zb else FULL_BACKWARD
            add_action(actions, stage0_index, comp_type)

        # Step 8: W0
        step_8 = rank + 1
        for _ in range(step_8):
            add_weight_action_if_pending(actions)

        return actions
def get_schedule_class(schedule_name: str):
    """
    Look up a pipeline schedule class by name, ignoring case.

    Args:
        schedule_name (str): The name of the schedule, e.g. ``"1F1B"`` or
            ``"interleaved1f1b"``.

    Returns:
        The schedule class registered under that name.

    Raises:
        ValueError: If no registered schedule matches ``schedule_name``.
    """
    schedule_map = {
        "1F1B": Schedule1F1B,
        "Interleaved1F1B": ScheduleInterleaved1F1B,
        "GPipe": ScheduleGPipe,
        "LoopedBFS": ScheduleLoopedBFS,
        "InterleavedZeroBubble": ScheduleInterleavedZeroBubble,
        "PipelineScheduleSingle": PipelineScheduleSingle,
        "PipelineScheduleMulti": PipelineScheduleMulti,
        "ZBVZeroBubble": ScheduleZBVZeroBubble,
        "DualPipeV": ScheduleDualPipeV,
    }
    # Case-insensitive index: lower-cased name -> schedule class.
    by_lowercase_name = {name.lower(): cls for name, cls in schedule_map.items()}

    schedule_cls = by_lowercase_name.get(schedule_name.lower())
    if schedule_cls is None:
        raise ValueError(
            f"Unknown schedule name '{schedule_name}'. The valid options are {list(schedule_map.keys())}"
        )
    return schedule_cls
def _simulate_comms_compute(
    pipeline_order, stage_to_rank: Callable[[int], int], num_stages: int
):
    """Dry-run a schedule on every rank to detect deadlocks and idle steps.

    Each simulated step tries to run the next pending action on every rank.
    A rank whose next action still has unmet dependencies (e.g. a missing or
    misordered send/recv) records ``None`` for that step, i.e. a bubble. The
    total number of simulated steps can be used as a metric in unit tests of
    IR optimization passes, since reordering and merging of IR can reduce it.

    The simulation is not high-fidelity: it does not model overlapping of
    compute and communication, or cuda streams. Future work may model compute
    time, comms overlap, and even memory.

    Args:
        pipeline_order: Mapping of rank -> list of actions; ``None`` entries
            are ignored.
        stage_to_rank: Maps a stage index to the rank that owns the stage.
        num_stages: Total number of pipeline stages.

    Returns:
        Mapping of rank -> list with one entry per simulated step: the action
        executed at that step, or ``None`` if the rank was blocked.

    Raises:
        ValueError: If a step completes with no rank making progress, or if
            an action has an unsupported computation type.
    """
    # Private working copy with idle slots stripped out.
    pending = {
        rank: [a for a in pipeline_order[rank] if a is not None]
        for rank in sorted(pipeline_order)
    }
    _schedule: dict[int, list[_Action | None]] = {rank: [] for rank in pending}
    # Actions already executed, per rank.
    _prev_ops_rank: dict[int, set[_Action]] = {rank: set() for rank in _schedule}

    def _record(rank: int, action: Optional[_Action]):
        _schedule[rank].append(action)
        if action is not None:
            _prev_ops_rank[rank].add(action)

    def _ready_to_schedule(action: Optional[_Action]) -> bool:
        if action is None:
            return True
        stage_idx = action.stage_index
        mb = action.microbatch_index
        ctype = action.computation_type
        done_here = _prev_ops_rank[stage_to_rank(stage_idx)]

        if ctype == F:
            # First stage, received activation, or produced by a local stage.
            return (
                stage_idx == 0
                or _Action(stage_idx, RECV_F, mb) in done_here
                or _Action(stage_idx - 1, F, mb) in done_here
            )
        if ctype in (BACKWARD_INPUT, FULL_BACKWARD):
            # Last stage, received gradient, or produced by a local stage.
            return (
                stage_idx == num_stages - 1
                or _Action(stage_idx, RECV_B, mb) in done_here
                or _Action(stage_idx + 1, BACKWARD_INPUT, mb) in done_here
                or _Action(stage_idx + 1, FULL_BACKWARD, mb) in done_here
            )
        if ctype == BACKWARD_WEIGHT:
            return True
        if ctype == SEND_F:
            return _Action(stage_idx, F, mb) in done_here
        if ctype == RECV_F:
            src_stage = stage_idx - 1
            return (
                _Action(src_stage, SEND_F, mb)
                in _prev_ops_rank[stage_to_rank(src_stage)]
            )
        if ctype == SEND_B:
            return (
                _Action(stage_idx, BACKWARD_INPUT, mb) in done_here
                or _Action(stage_idx, FULL_BACKWARD, mb) in done_here
            )
        if ctype == RECV_B:
            src_stage = stage_idx + 1
            return (
                _Action(src_stage, SEND_B, mb)
                in _prev_ops_rank[stage_to_rank(src_stage)]
            )
        raise ValueError(f"Unsupported action type {action}")

    while pending:
        progress = False
        # First pass: every rank either runs its next action or idles.
        for rank in sorted(pending):
            queue = pending[rank]
            if not queue:
                continue
            head = queue[0]
            if _ready_to_schedule(head):
                _record(rank, head)
                queue.pop(0)
                progress = True
            else:
                _record(rank, None)
        pending = {rank: queue for rank, queue in pending.items() if queue}

        # Hacky second pass: a rank that idled this step may have been
        # unblocked by a later rank in the same step; if so, replace its
        # 'None' with the now-ready action.
        for rank in sorted(pending):
            queue = pending[rank]
            if not queue or _schedule[rank][-1] is not None:
                continue
            head = queue[0]
            if _ready_to_schedule(head):
                _schedule[rank][-1] = head
                _prev_ops_rank[rank].add(head)
                queue.pop(0)
        pending = {rank: queue for rank, queue in pending.items() if queue}

        if not progress:
            print("WIP comms schedule:\n", _format_pipeline_order(_schedule))
            for rank in pending:
                print(f"{rank=} next action= {pending[rank][0]}")
            raise ValueError("Schedule is not progressing")

    return _schedule
def _dump_chrometrace(schedule, filename):
    """
    Dump a schedule IR into the chrometrace (trace-event JSON) format so it
    can be visualized.

    It is currently very basic and only serves as a graphical alternative to
    dumping the schedule IR as text. Every non-``None`` action becomes one
    complete event of duration 1 at its timestep, on a track per rank.
    Future work may add more accurate duration heuristics (or user-supplied
    durations), 'flow events' connecting sends and recvs, and separate
    comm/compute cuda streams.

    Args:
        schedule: Mapping of rank -> list of actions, where ``None`` marks an
            idle timestep (skipped in the output).
        filename: Path of the JSON file to write (overwritten if present).
    """
    import json

    # Compute actions: forward, backward-input, backward-weight, full
    # backward, and overlapped forward+backward. All other action types
    # (sends, recvs, ...) are tagged as communication.
    compute_types = (F, I, W, B, OVERLAP_F_B)
    events = [
        {
            "name": str(action),
            "cat": (
                "computation"
                if action.computation_type in compute_types
                else "communication"
            ),
            "ph": "X",
            "pid": rank,
            "tid": rank,
            "ts": timestep,
            "dur": 1,
        }
        for rank in sorted(schedule)
        for timestep, action in enumerate(schedule[rank])
        if action is not None
    ]

    with open(filename, "w") as f:
        json.dump({"traceEvents": events}, f)
|