partitioners.py 111 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
727782779278027812782278327842785278627872788278927902791279227932794279527962797279827992800280128022803280428052806280728082809281028112812281328142815281628172818281928202821282228232824282528262827282828292830283128322833283428352836283728382839284028412842284328442845284628472848284928502851285228532854285528562857285828592860286128622863286428652866286728682869287028712872287328742875287628772878287928802881288228832884288528862887288828892890289128922893289428952896289728982899290029012902290329042905290629072908290929102911291229132914291529162917291829192920
  1. # mypy: allow-untyped-defs
  2. import copy
  3. import functools
  4. import hashlib
  5. import heapq
  6. import itertools
  7. import logging
  8. import math
  9. import operator
  10. import os
  11. import os.path
  12. from collections import defaultdict
  13. from dataclasses import dataclass, replace
  14. from typing import Any, Callable, Optional, TYPE_CHECKING, Union
  15. import torch
  16. import torch._inductor.inductor_prims
  17. import torch.distributed
  18. import torch.fx as fx
  19. import torch.utils._pytree as pytree
  20. from torch._dynamo.utils import counters, is_node_meta_valid
  21. from torch._functorch._activation_checkpointing.ac_logging_utils import (
  22. create_structured_trace_for_min_cut_info,
  23. )
  24. from torch._inductor import config as inductor_config
  25. from torch._logging import trace_structured
  26. from torch._subclasses.fake_tensor import extract_tensor_metadata
  27. from torch.fx.experimental._backward_state import BackwardState
  28. from torch.fx.experimental.proxy_tensor import is_sym_node, py_sym_types
  29. from torch.fx.experimental.sym_node import magic_methods, method_to_operator
  30. from torch.fx.experimental.symbolic_shapes import (
  31. find_symbol_binding_fx_nodes,
  32. free_symbols,
  33. hint_int,
  34. is_symbol_binding_fx_node,
  35. statically_known_false,
  36. statically_known_true,
  37. )
  38. from torch.fx.passes import graph_drawer
  39. from torch.utils._ordered_set import OrderedSet
  40. from torch.utils.checkpoint import CheckpointPolicy
  41. from . import config
  42. from ._activation_checkpointing.graph_info_provider import GraphInfoProvider
  43. from ._activation_checkpointing.knapsack import (
  44. dp_knapsack,
  45. greedy_knapsack,
  46. ilp_knapsack,
  47. )
  48. from ._activation_checkpointing.knapsack_evaluator import KnapsackEvaluator
  49. from ._aot_autograd.descriptors import AOTOutput, SavedForBackwardsAOTOutput
  50. from ._aot_autograd.logging_utils import get_aot_graph_name
  51. from ._aot_autograd.utils import get_cuda_generator_meta_val, is_with_effects
  52. from .compile_utils import fx_graph_cse, get_aten_target, raise_getitems
  53. if TYPE_CHECKING:
  54. import sympy
  55. AOT_PARTITIONER_DEBUG: bool = config.debug_partitioner
  56. log: logging.Logger = logging.getLogger(__name__)
  57. aten = torch.ops.aten
  58. prims = torch.ops.prims
  59. @dataclass
  60. class OpTypes:
  61. """Class for keeping track of different operator categories"""
  62. fusible_ops: OrderedSet[Callable]
  63. compute_intensive_ops: OrderedSet[Callable]
  64. random_ops: OrderedSet[Callable]
  65. view_ops: OrderedSet[Callable]
  66. recomputable_ops: OrderedSet[Callable]
  67. def is_fusible(self, node: fx.Node):
  68. return get_aten_target(node) in self.fusible_ops
  69. def is_compute_intensive(self, node: fx.Node):
  70. return get_aten_target(node) in self.compute_intensive_ops
  71. def is_random(self, node: fx.Node):
  72. return get_aten_target(node) in self.random_ops
  73. def is_view(self, node: fx.Node):
  74. return get_aten_target(node) in self.view_ops
  75. def is_recomputable(self, node: fx.Node):
  76. return get_aten_target(node) in self.recomputable_ops
  77. @dataclass
  78. class NodeInfo:
  79. # Be careful about iterating over these explicitly, as their order may not
  80. # be deterministic
  81. inputs: list[fx.Node]
  82. _required_fw_nodes: OrderedSet[fx.Node]
  83. required_bw_nodes: OrderedSet[fx.Node]
  84. unclaimed_nodes: OrderedSet[fx.Node]
  85. fw_order: dict[fx.Node, int]
  86. # Effectively maps to which of our primals are parameters
  87. static_lifetime_input_nodes: OrderedSet[fx.Node]
  88. @functools.cached_property
  89. def required_fw_nodes(self) -> list[fx.Node]:
  90. return sorted(
  91. (n for n in self._required_fw_nodes), key=lambda n: self.fw_order[n]
  92. )
  93. def is_required_fw(self, n: fx.Node) -> bool:
  94. return n in self._required_fw_nodes
  95. def is_required_bw(self, n: fx.Node) -> bool:
  96. return n in self.required_bw_nodes
  97. def is_unclaimed(self, n: fx.Node) -> bool:
  98. return n in self.unclaimed_nodes
  99. def get_fw_order(self, n: fx.Node) -> int:
  100. assert n in self._required_fw_nodes, f"Node {n} not in fw nodes!"
  101. return self.fw_order[n]
  102. @dataclass
  103. class MinCutOptions:
  104. ban_if_used_far_apart: bool
  105. ban_if_long_fusible_chains: bool
  106. ban_if_materialized_backward: bool
  107. ban_if_not_in_allowlist: bool
  108. ban_if_reduction: bool
  109. def must_recompute(node: fx.Node) -> bool:
  110. return node.meta.get("recompute", None) in [
  111. CheckpointPolicy.MUST_RECOMPUTE,
  112. CheckpointPolicy.PREFER_RECOMPUTE,
  113. ]
  114. def has_recomputable_ops(fx_g: fx.GraphModule) -> bool:
  115. for node in fx_g.graph.nodes:
  116. if must_recompute(node):
  117. return True
  118. return False
  119. def has_recomputable_rng_ops(fx_g: fx.GraphModule) -> bool:
  120. for node in fx_g.graph.nodes:
  121. if (
  122. must_recompute(node)
  123. and hasattr(node.target, "tags")
  124. and torch.Tag.nondeterministic_seeded in node.target.tags
  125. ):
  126. return True
  127. return False
  128. def sym_node_size(node: fx.Node) -> int:
  129. if isinstance(node.meta["val"], (torch.SymInt, torch.SymBool)):
  130. return 1
  131. assert isinstance(node.meta["val"], torch.SymFloat)
  132. return 4
  133. class InvalidNodeBase:
  134. def __repr__(self):
  135. return "Invalid Node"
  136. InvalidNode = InvalidNodeBase()
  137. def _extract_graph_with_inputs_outputs(
  138. joint_graph: fx.Graph,
  139. inputs: list[fx.Node],
  140. outputs: list[fx.Node],
  141. outputs_descs: list[AOTOutput],
  142. subgraph: Optional[str] = None,
  143. ) -> fx.Graph:
  144. """
  145. Given a graph, extracts out a subgraph that takes the specified nodes as
  146. inputs and returns the specified outputs.
  147. This includes specifying non-placeholder nodes as inputs.
  148. The general strategy is to initialize all inputs with proxies as we
  149. encounter them, and trace through the graph, only keeping values which take
  150. in valid proxies. Then, all dead code is eliminated.
  151. """
  152. new_graph = fx.Graph()
  153. env = {}
  154. # Add new placeholder nodes in the order specified by the inputs
  155. for node in inputs:
  156. new_node = new_graph.placeholder(node.name)
  157. # Can't use node_copy here as we may be turning previous call_function into placeholders
  158. new_node.meta = node.meta
  159. env[node] = new_node
  160. for node in joint_graph.nodes:
  161. if _must_be_in_backward(node) and subgraph != "backward":
  162. env[node] = InvalidNode # type: ignore[assignment]
  163. continue
  164. if _must_be_in_forward(node) and subgraph != "forward":
  165. env[node] = InvalidNode # type: ignore[assignment]
  166. continue
  167. if node in env:
  168. # Node must be one of our inputs. (Any member of env which wasn't an
  169. # input to start must have been created by this loop and won't be in
  170. # joint_graph.nodes).
  171. continue
  172. elif node.op == "placeholder":
  173. env[node] = InvalidNode # type: ignore[assignment]
  174. elif node.op == "call_function":
  175. all_args = pytree.arg_tree_leaves(*node.args, **node.kwargs)
  176. all_args = [
  177. isinstance(env[x], InvalidNodeBase)
  178. for x in all_args
  179. if isinstance(x, fx.Node)
  180. ]
  181. if any(all_args):
  182. env[node] = InvalidNode # type: ignore[assignment]
  183. continue
  184. env[node] = new_graph.node_copy(node, lambda x: env[x])
  185. elif node.op == "get_attr":
  186. env[node] = new_graph.node_copy(node, lambda x: env[x])
  187. elif node.op == "output":
  188. pass
  189. output_values = []
  190. for x in outputs:
  191. if isinstance(x, fx.Node):
  192. if x not in env:
  193. raise RuntimeError(f"Node {x} couldn't be found in env")
  194. assert not isinstance(env[x], InvalidNodeBase), (
  195. f"Node {x} was invalid, but is output"
  196. )
  197. output_values.append(env[x])
  198. else:
  199. output_values.append(x)
  200. out = new_graph.output(tuple(output_values))
  201. out.meta["desc"] = outputs_descs
  202. new_graph.eliminate_dead_code()
  203. new_graph.lint()
  204. return new_graph
  205. def _is_primal(node: fx.Node) -> bool:
  206. return (
  207. node.op == "placeholder"
  208. and "tangents" not in str(node.target)
  209. and not _is_bwd_seed_offset(node)
  210. and not _is_fwd_seed_offset(node)
  211. )
  212. def _is_tangent(node: fx.Node) -> bool:
  213. return node.op == "placeholder" and "tangents" in str(node.target)
  214. def _is_bwd_seed_offset(node: fx.Node) -> bool:
  215. return node.op == "placeholder" and (
  216. "bwd_seed" in str(node.target) or "bwd_base_offset" in str(node.target)
  217. )
  218. def _is_fwd_seed_offset(node: fx.Node) -> bool:
  219. return node.op == "placeholder" and (
  220. "fwd_seed" in str(node.target) or "fwd_base_offset" in str(node.target)
  221. )
  222. def _is_backward_state(node: fx.Node) -> bool:
  223. return node.op == "placeholder" and isinstance(node.meta.get("val"), BackwardState)
  224. def _has_tag_is_backward(node: fx.Node) -> bool:
  225. return node.meta.get("partitioner_tag", None) == "is_backward"
  226. def _has_tag_must_be_in_forward(node: fx.Node) -> bool:
  227. return node.meta.get("partitioner_tag", None) == "must_be_in_forward"
  228. def _has_tag_must_be_in_backward(node: fx.Node) -> bool:
  229. return node.meta.get("partitioner_tag", None) == "must_be_in_backward"
  230. def _must_be_in_forward(node: fx.Node) -> bool:
  231. return _has_tag_must_be_in_forward(node)
  232. def _must_be_in_backward(node: fx.Node) -> bool:
  233. return _has_tag_must_be_in_backward(node) or (
  234. _has_tag_is_backward(node) and is_with_effects(node)
  235. )
  236. def _extract_fwd_bwd_outputs(
  237. joint_module: fx.GraphModule, *, num_fwd_outputs
  238. ) -> tuple[list[fx.Node], list[fx.Node], list[AOTOutput], list[AOTOutput]]:
  239. outputs = pytree.arg_tree_leaves(
  240. *(node.args for node in joint_module.graph.find_nodes(op="output"))
  241. )
  242. outputs_descs = pytree.arg_tree_leaves(
  243. next(iter(joint_module.graph.find_nodes(op="output"))).meta.get(
  244. "desc", [None] * len(outputs)
  245. )
  246. )
  247. fwd_outputs = outputs[:num_fwd_outputs]
  248. bwd_outputs = outputs[num_fwd_outputs:]
  249. fwd_outputs_descs = outputs_descs[:num_fwd_outputs]
  250. bwd_outputs_descs = outputs_descs[num_fwd_outputs:]
  251. return fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs
  252. def _remove_by_name(saved_values: list[fx.Node], name: str):
  253. for saved_value in saved_values:
  254. if saved_value.name == name:
  255. saved_values.remove(saved_value)
  256. break
  257. def find_first_sym_node(
  258. fwd_module_outputs: Union[list[fx.Node], tuple[fx.Node]],
  259. ) -> int:
  260. idx = len(fwd_module_outputs)
  261. for i in range(len(fwd_module_outputs) - 1, -1, -1):
  262. if not is_sym_node(fwd_module_outputs[i]):
  263. idx = i + 1
  264. break
  265. return idx
  266. def calculate_quantization_scaling(
  267. graph: torch.fx.Graph,
  268. node: torch.fx.Node,
  269. max: float = 57344.0,
  270. min: float = 1e-12,
  271. ):
  272. with graph.inserting_after(node):
  273. abs_node = graph.call_function(
  274. torch.ops.aten.abs.default,
  275. args=(node,),
  276. )
  277. abs_node.meta["val"] = torch.ops.aten.abs.default(node.meta["val"])
  278. abs_node.meta["tensor_meta"] = extract_tensor_metadata(abs_node.meta["val"])
  279. with graph.inserting_after(abs_node):
  280. amax_node = graph.call_function(
  281. torch.ops.aten.amax.default,
  282. args=(abs_node, [-1], True),
  283. )
  284. amax_node.meta["val"] = torch.ops.aten.amax.default(
  285. abs_node.meta["val"], [-1], True
  286. )
  287. amax_node.meta["tensor_meta"] = extract_tensor_metadata(amax_node.meta["val"])
  288. with graph.inserting_after(amax_node):
  289. amax_64_node = graph.call_function(
  290. torch.ops.prims.convert_element_type.default,
  291. args=(amax_node, torch.float64),
  292. )
  293. amax_64_node.meta["val"] = torch.ops.prims.convert_element_type.default(
  294. amax_node.meta["val"], torch.float64
  295. )
  296. amax_64_node.meta["tensor_meta"] = extract_tensor_metadata(
  297. amax_64_node.meta["val"]
  298. )
  299. with graph.inserting_after(amax_64_node):
  300. clamp_min_node = graph.call_function(
  301. torch.ops.aten.clamp_min.default,
  302. args=(amax_64_node, min),
  303. )
  304. clamp_min_node.meta["val"] = torch.ops.aten.clamp_min.default(
  305. amax_64_node.meta["val"], min
  306. )
  307. clamp_min_node.meta["tensor_meta"] = extract_tensor_metadata(
  308. clamp_min_node.meta["val"]
  309. )
  310. with graph.inserting_after(clamp_min_node):
  311. reciprocal_node = graph.call_function(
  312. torch.ops.aten.reciprocal.default,
  313. args=(clamp_min_node,),
  314. )
  315. reciprocal_node.meta["val"] = torch.ops.aten.reciprocal.default(
  316. clamp_min_node.meta["val"]
  317. )
  318. reciprocal_node.meta["tensor_meta"] = extract_tensor_metadata(
  319. reciprocal_node.meta["val"]
  320. )
  321. with graph.inserting_after(reciprocal_node):
  322. mul_node = graph.call_function(
  323. torch.ops.aten.mul.Tensor,
  324. args=(reciprocal_node, max),
  325. )
  326. mul_node.meta["val"] = torch.ops.aten.mul.Tensor(
  327. reciprocal_node.meta["val"], max
  328. )
  329. mul_node.meta["tensor_meta"] = extract_tensor_metadata(mul_node.meta["val"])
  330. with graph.inserting_after(mul_node):
  331. scale_node = graph.call_function(
  332. torch.ops.prims.convert_element_type.default,
  333. args=(mul_node, torch.float32),
  334. name="fp8_scale_" + str(node.name),
  335. )
  336. scale_node.meta["val"] = torch.ops.prims.convert_element_type.default(
  337. mul_node.meta["val"], torch.float32
  338. )
  339. scale_node.meta["tensor_meta"] = extract_tensor_metadata(scale_node.meta["val"])
  340. return scale_node
  341. def perform_quantization(
  342. graph: torch.fx.Graph,
  343. node: torch.fx.Node,
  344. scale_node: torch.fx.Node,
  345. quant_type: torch.dtype,
  346. clamp_min: float,
  347. clamp_max: float,
  348. ) -> torch.fx.Node:
  349. with graph.inserting_after(scale_node):
  350. target_node_32 = graph.call_function(
  351. torch.ops.prims.convert_element_type.default,
  352. args=(node, torch.float32),
  353. )
  354. target_node_32.meta["val"] = torch.ops.prims.convert_element_type.default(
  355. node.meta["val"], torch.float32
  356. )
  357. target_node_32.meta["tensor_meta"] = extract_tensor_metadata(
  358. target_node_32.meta["val"]
  359. )
  360. with graph.inserting_after(target_node_32):
  361. scaled_target_node = graph.call_function(
  362. torch.ops.aten.mul.Tensor,
  363. args=(target_node_32, scale_node),
  364. )
  365. scaled_target_node.meta["val"] = torch.ops.aten.mul.Tensor(
  366. target_node_32.meta["val"], scale_node.meta["val"]
  367. )
  368. scaled_target_node.meta["tensor_meta"] = extract_tensor_metadata(
  369. scaled_target_node.meta["val"]
  370. )
  371. with graph.inserting_after(scaled_target_node):
  372. clamp_min_scaled_node = graph.call_function(
  373. torch.ops.aten.clamp_min.default,
  374. args=(scaled_target_node, clamp_min),
  375. )
  376. clamp_min_scaled_node.meta["val"] = torch.ops.aten.clamp_min.default(
  377. scaled_target_node.meta["val"], clamp_min
  378. )
  379. clamp_min_scaled_node.meta["tensor_meta"] = extract_tensor_metadata(
  380. clamp_min_scaled_node.meta["val"]
  381. )
  382. with graph.inserting_after(clamp_min_scaled_node):
  383. clamp_max_scaled_node = graph.call_function(
  384. torch.ops.aten.clamp_max.default,
  385. args=(clamp_min_scaled_node, clamp_max),
  386. )
  387. clamp_max_scaled_node.meta["val"] = torch.ops.aten.clamp_max.default(
  388. clamp_min_scaled_node.meta["val"], clamp_max
  389. )
  390. clamp_max_scaled_node.meta["tensor_meta"] = extract_tensor_metadata(
  391. clamp_max_scaled_node.meta["val"]
  392. )
  393. with graph.inserting_after(clamp_max_scaled_node):
  394. quant_activation_node = graph.call_function(
  395. torch.ops.prims.convert_element_type.default,
  396. args=(clamp_max_scaled_node, quant_type),
  397. name="fp8_quant_" + str(node.name),
  398. )
  399. quant_activation_node.meta["val"] = (
  400. torch.ops.prims.convert_element_type.default(
  401. clamp_max_scaled_node.meta["val"], quant_type
  402. )
  403. )
  404. quant_activation_node.meta["tensor_meta"] = extract_tensor_metadata(
  405. quant_activation_node.meta["val"]
  406. )
  407. return quant_activation_node
  408. def calculate_tensor_size(tensor: torch.Tensor) -> float:
  409. """
  410. Calculate the size of a PyTorch tensor in megabytes (MB).
  411. Args:
  412. tensor (torch.Tensor): Input tensor
  413. Returns:
  414. float: Memory size in MB
  415. """
  416. # Get number of elements and size per element
  417. num_elements = tensor.numel()
  418. element_size = tensor.element_size()
  419. return (num_elements * element_size) / (1024 * 1024)
  420. def get_allowed_dtypes() -> list[torch.dtype]:
  421. allowed_dtypes = torch._inductor.config.post_grad_fusion_options[
  422. "activation_quantization_aten_pass"
  423. ].get("allowed_dtypes", "torch.bfloat16")
  424. allowed_dtypes = [
  425. getattr(torch, dtype.split(".")[-1]) for dtype in allowed_dtypes.split(";")
  426. ]
  427. return allowed_dtypes
  428. def should_quantize(node: torch.fx.Node) -> bool:
  429. allowed_dtypes = get_allowed_dtypes()
  430. if not is_node_meta_valid(node) or node.meta["val"].dtype not in allowed_dtypes:
  431. return False
  432. size_threshold = torch._inductor.config.post_grad_fusion_options[
  433. "activation_quantization_aten_pass"
  434. ].get("size_in_mb", 100)
  435. # calculate the size of the node
  436. size_in_mb = calculate_tensor_size(node.meta["val"])
  437. if not torch._inductor.config.post_grad_fusion_options[
  438. "activation_quantization_aten_pass"
  439. ].get("skip_dynamo_guards", False):
  440. return size_in_mb >= size_threshold
  441. else:
  442. # case 1: we always quantize tensors with dynamic shapes
  443. if torch._inductor.config.post_grad_fusion_options[
  444. "activation_quantization_aten_pass"
  445. ].get("quantize_dynamic_shape", False):
  446. return statically_known_true(
  447. size_in_mb >= size_threshold
  448. ) or not statically_known_false(size_in_mb >= size_threshold)
  449. else:
  450. # case 2: we always not quantize tensors with dynamic shapes
  451. return statically_known_true(size_in_mb >= size_threshold)
  452. def get_quant_type() -> torch.dtype:
  453. quant_type = torch._inductor.config.post_grad_fusion_options[
  454. "activation_quantization_aten_pass"
  455. ].get("quant_type", "torch.float8_e5m2")
  456. return getattr(torch, quant_type.split(".")[-1])
  457. def calculate_range(dtype: torch.dtype) -> tuple:
  458. """
  459. Calculate the range of values for a given torch.dtype.
  460. Args:
  461. dtype (torch.dtype): The input dtype.
  462. Returns:
  463. tuple: A tuple containing the minimum and maximum values.
  464. """
  465. info = torch.finfo(dtype)
  466. return info.min, info.max
  467. def quantize_activation_fw(graph: torch.fx.Graph) -> None:
  468. output = graph.find_nodes(op="output")[0]
  469. fwd_outputs = output.args[0]
  470. quant_type = get_quant_type()
  471. clamp_min, clamp_max = calculate_range(quant_type)
  472. node_to_quant = dict()
  473. tensor_scale_nodes, sym_scale_nodes = [], []
  474. for node in fwd_outputs:
  475. # check if the activation node is the node saved for quantization
  476. if node.meta.get("saved_for_quantization", False):
  477. # case: use scaling
  478. if torch._inductor.config.post_grad_fusion_options[
  479. "activation_quantization_aten_pass"
  480. ].get("use_scaling", True):
  481. # calculating the scale
  482. scale_node = calculate_quantization_scaling(
  483. graph, node, clamp_max, 1e-12
  484. )
  485. # converting to fp8
  486. quant_node = perform_quantization(
  487. graph, node, scale_node, quant_type, clamp_min, clamp_max
  488. )
  489. if not is_sym_node(scale_node):
  490. tensor_scale_nodes.append(scale_node)
  491. else:
  492. sym_scale_nodes.append(scale_node)
  493. else:
  494. # case: do not use scaling
  495. with graph.inserting_after(node):
  496. quant_node = graph.call_function(
  497. torch.ops.prims.convert_element_type.default,
  498. args=(node, quant_type),
  499. name="fp8_quant_" + str(node.name),
  500. )
  501. quant_node.meta["val"] = (
  502. torch.ops.prims.convert_element_type.default(
  503. node.meta["val"], quant_type
  504. )
  505. )
  506. quant_node.meta["tensor_meta"] = extract_tensor_metadata(
  507. quant_node.meta["val"]
  508. )
  509. node_to_quant[node] = quant_node
  510. # only update the return node args, and remain all other users unchanged
  511. output_updated_args = [
  512. node_to_quant[node] if node in node_to_quant else node for node in fwd_outputs
  513. ]
  514. # add the scale nodes to the output find the first sym_node in the output
  515. idx = find_first_sym_node(output_updated_args)
  516. scale_nodes = tensor_scale_nodes + sym_scale_nodes
  517. if scale_nodes:
  518. output_updated_args = (
  519. output_updated_args[:idx] + scale_nodes + output_updated_args[idx:]
  520. )
  521. output.update_arg(0, tuple(output_updated_args))
  522. counters["inductor"]["activation_quantization_fwd_aten_pass"] += 1
  523. def quantize_activation_bw(graph: torch.fx.Graph) -> None:
  524. bw_inputs = [node for node in graph.nodes if node.op == "placeholder"]
  525. activation_node = None
  526. for node in bw_inputs:
  527. if node.meta.get("saved_for_quantization", False):
  528. node.meta.pop("saved_for_quantization")
  529. dequant_type = node.meta.pop("dequant_type")
  530. # dequantize the node
  531. if torch._inductor.config.post_grad_fusion_options[
  532. "activation_quantization_aten_pass"
  533. ].get("use_scaling", False):
  534. # case: use scaling
  535. with graph.inserting_after(node):
  536. # find corresponding scale node
  537. scale_name = "fp8_scale_" + node.name.replace("fp8_quant_", "")
  538. scale_node = next(
  539. bwd_input
  540. for bwd_input in bw_inputs
  541. if bwd_input.name == scale_name
  542. )
  543. with graph.inserting_after(scale_node):
  544. activation_node = graph.call_function(
  545. torch.ops.prims.convert_element_type.default,
  546. args=(node, dequant_type),
  547. )
  548. activation_node.meta["val"] = (
  549. torch.ops.prims.convert_element_type.default(
  550. node.meta["val"], dequant_type
  551. )
  552. )
  553. activation_node.meta["tensor_meta"] = extract_tensor_metadata(
  554. activation_node.meta["val"]
  555. )
  556. with graph.inserting_after(activation_node):
  557. divided_target_node_32 = graph.call_function(
  558. torch.ops.aten.div.Tensor,
  559. args=(activation_node, scale_node),
  560. )
  561. divided_target_node_32.meta["val"] = torch.ops.aten.div.Tensor(
  562. activation_node.meta["val"], scale_node.meta["val"]
  563. )
  564. divided_target_node_32.meta["tensor_meta"] = (
  565. extract_tensor_metadata(divided_target_node_32.meta["val"])
  566. )
  567. with graph.inserting_after(divided_target_node_32):
  568. dequant_node = graph.call_function(
  569. torch.ops.prims.convert_element_type.default,
  570. args=(divided_target_node_32, dequant_type),
  571. )
  572. dequant_node.meta["val"] = (
  573. torch.ops.prims.convert_element_type.default(
  574. divided_target_node_32.meta["val"], dequant_type
  575. )
  576. )
  577. dequant_node.meta["tensor_meta"] = extract_tensor_metadata(
  578. dequant_node.meta["val"]
  579. )
  580. else:
  581. with graph.inserting_after(node):
  582. dequant_node = graph.call_function(
  583. torch.ops.prims.convert_element_type.default,
  584. args=(node, dequant_type),
  585. name="dequant_" + str(node.name),
  586. )
  587. dequant_node.meta["val"] = (
  588. torch.ops.prims.convert_element_type.default(
  589. node.meta["val"], dequant_type
  590. )
  591. )
  592. dequant_node.meta["tensor_meta"] = extract_tensor_metadata(
  593. dequant_node.meta["val"]
  594. )
  595. # find the users of the node and replace them with the new node except the dequant_node
  596. for user in list(node.users.keys()):
  597. if user != dequant_node and user != activation_node:
  598. user.replace_input_with(node, dequant_node)
  599. counters["inductor"]["activation_quantization_bwd_aten_pass"] += 1
  600. def perform_fp8_activation_quantization(
  601. fwd_module: fx.GraphModule,
  602. bwd_module: fx.GraphModule,
  603. bwd_module_inputs: dict[str, fx.Node],
  604. ) -> None:
  605. trace_structured(
  606. "artifact",
  607. metadata_fn=lambda: {
  608. "name": "before_activation_quantization_fwd_aten_pass",
  609. "encoding": "string",
  610. },
  611. payload_fn=lambda: fwd_module.print_readable(
  612. print_output=False, include_stride=True, include_device=True
  613. ),
  614. )
  615. quantize_activation_fw(fwd_module.graph)
  616. trace_structured(
  617. "artifact",
  618. metadata_fn=lambda: {
  619. "name": "after_activation_quantization_fwd_aten_pass",
  620. "encoding": "string",
  621. },
  622. payload_fn=lambda: fwd_module.print_readable(
  623. print_output=False, include_stride=True, include_device=True
  624. ),
  625. )
  626. trace_structured(
  627. "artifact",
  628. metadata_fn=lambda: {
  629. "name": "before_activation_quantization_bwd_aten_pass",
  630. "encoding": "string",
  631. },
  632. payload_fn=lambda: bwd_module.print_readable(
  633. print_output=False, include_stride=True, include_device=True
  634. ),
  635. )
  636. quant_fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
  637. # update the corresponding bwd_inputs due to the fwd_outputs quantization
  638. for fwd_node in quant_fwd_module_outputs:
  639. if "fp8_quant_" in fwd_node.name:
  640. bwd_input = bwd_module_inputs[fwd_node.name.replace("fp8_quant_", "")]
  641. with bwd_module.graph.inserting_after(bwd_input):
  642. quant_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name)
  643. dequant_type = bwd_input.meta["dequant_type"]
  644. quant_bwd_input.meta.update(fwd_node.meta)
  645. quant_bwd_input.meta["saved_for_quantization"] = True
  646. quant_bwd_input.meta["dequant_type"] = dequant_type
  647. bwd_input.replace_all_uses_with(quant_bwd_input)
  648. bwd_module.graph.erase_node(bwd_input)
  649. # update the bwd_inputs if quantization with scaling is used
  650. if torch._inductor.config.post_grad_fusion_options[
  651. "activation_quantization_aten_pass"
  652. ].get("use_scaling", True):
  653. quant_bwd_module_inputs = list(bwd_module.graph.find_nodes(op="placeholder"))
  654. # update the corresponding bwd input nodes find the last non-tangent node
  655. bwd_input_loc = quant_bwd_module_inputs[-1]
  656. for bw_input in reversed(quant_bwd_module_inputs):
  657. if not _is_tangent(bw_input):
  658. bwd_input_loc = bw_input
  659. break
  660. scaled_fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
  661. for fwd_node in scaled_fwd_module_outputs:
  662. if "fp8_scale_" in fwd_node.name:
  663. # fwd node is a scale node
  664. with bwd_module.graph.inserting_after(bwd_input_loc):
  665. scale_bwd_input = bwd_module.graph.placeholder(name=fwd_node.name)
  666. scale_bwd_input.meta.update(fwd_node.meta)
  667. bwd_input_loc = scale_bwd_input
  668. quantize_activation_bw(bwd_module.graph)
  669. trace_structured(
  670. "artifact",
  671. metadata_fn=lambda: {
  672. "name": "after_activation_quantization_bwd_aten_pass",
  673. "encoding": "string",
  674. },
  675. payload_fn=lambda: bwd_module.print_readable(
  676. print_output=False, include_stride=True, include_device=True
  677. ),
  678. )
  679. def enable_activation_quantization(
  680. saved_values: list[fx.Node],
  681. fwd_module: fx.GraphModule,
  682. bwd_module: fx.GraphModule,
  683. static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None,
  684. ) -> None:
  685. if (
  686. inductor_config.post_grad_fusion_options.get(
  687. "activation_quantization_aten_pass", None
  688. )
  689. is None
  690. ):
  691. return
  692. static_input_names = (
  693. [node.name for node in static_lifetime_input_nodes]
  694. if static_lifetime_input_nodes
  695. else []
  696. )
  697. saved_values_names = {node.name: node for node in saved_values}
  698. if torch._inductor.config.post_grad_fusion_options[
  699. "activation_quantization_aten_pass"
  700. ].get("exclude_primals", False):
  701. saved_values_names = {
  702. node.name: node for node in saved_values if "primals" not in node.name
  703. }
  704. fwd_module_outputs = fwd_module.graph.find_nodes(op="output")[0].args[0]
  705. bwd_module_inputs = {
  706. node.name: node for node in bwd_module.graph.find_nodes(op="placeholder")
  707. }
  708. should_perform_fp8_quant = False
  709. for node in fwd_module_outputs:
  710. if node.name in saved_values_names and should_quantize(node):
  711. if node.name in static_input_names:
  712. log.debug("Skipping quantization of static input %s: ", node.name)
  713. continue
  714. node.meta["saved_for_quantization"] = True
  715. node.meta["dequant_type"] = node.meta["val"].dtype
  716. # some of the fwd outputs and bwd inputs are not share the same object
  717. bwd_module_inputs[node.name].meta["saved_for_quantization"] = True
  718. bwd_module_inputs[node.name].meta["dequant_type"] = node.meta["val"].dtype
  719. should_perform_fp8_quant = True
  720. if should_perform_fp8_quant:
  721. perform_fp8_activation_quantization(fwd_module, bwd_module, bwd_module_inputs)
  722. def _extract_fwd_bwd_modules(
  723. joint_module: fx.GraphModule,
  724. saved_values: list[fx.Node],
  725. saved_sym_nodes: list[fx.Node],
  726. *,
  727. num_fwd_outputs: int,
  728. static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None,
  729. ) -> tuple[fx.GraphModule, fx.GraphModule]:
  730. fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
  731. _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
  732. )
  733. placeholders = joint_module.graph.find_nodes(op="placeholder")
  734. primal_inputs = [*filter(_is_primal, placeholders)]
  735. tangent_inputs = [*filter(_is_tangent, placeholders)]
  736. fwd_seed_offset_inputs = [*filter(_is_fwd_seed_offset, placeholders)]
  737. bwd_seed_offset_inputs = [*filter(_is_bwd_seed_offset, placeholders)]
  738. backward_state_inputs = [*filter(_is_backward_state, placeholders)]
  739. bwd_graph = _extract_graph_with_inputs_outputs(
  740. joint_module.graph,
  741. saved_sym_nodes + saved_values + tangent_inputs + bwd_seed_offset_inputs,
  742. bwd_outputs,
  743. bwd_outputs_descs,
  744. "backward",
  745. )
  746. distributed_enabled = torch.distributed.is_available()
  747. for node in bwd_graph.find_nodes(op="placeholder"):
  748. # This is to filter out saved values that don't actually end up being used by the backwards pass
  749. if not node.users:
  750. _remove_by_name(saved_values, node.name)
  751. _remove_by_name(saved_sym_nodes, node.name)
  752. # wait_tensor is a bit special: if we have a "dead activation" that is not used in the bw,
  753. # but this dead activation is actually a collective,
  754. # then the collective will generally by followed by a wait_tensor() call.
  755. # we need to peak one node further to see if this wait_tensor is dead as well.
  756. elif distributed_enabled and all(
  757. n.target is torch.ops._c10d_functional.wait_tensor.default
  758. and len(n.users) == 0
  759. for n in node.users
  760. ):
  761. _remove_by_name(saved_values, node.name)
  762. _remove_by_name(saved_sym_nodes, node.name)
  763. elif _is_backward_state(node):
  764. # BackwardState is saved directly
  765. _remove_by_name(saved_values, node.name)
  766. assert backward_state_inputs
  767. # Now that we have the finalized list of saved values, we need to ensure
  768. # we propagate all symbols which are referenced by backwards inputs.
  769. # These are not directly used in the graph but are required for downstream
  770. # sizevar assignment
  771. saved_symbols: OrderedSet[sympy.Symbol] = OrderedSet()
  772. saved_sym_nodes_binding = []
  773. saved_sym_nodes_derived = []
  774. # Some symbols may already be bound in the directly saved_sym_nodes,
  775. # keep track of them so we don't re-bind them
  776. for node in saved_sym_nodes:
  777. symbol = is_symbol_binding_fx_node(node)
  778. if symbol:
  779. saved_symbols.add(symbol)
  780. saved_sym_nodes_binding.append(node)
  781. else:
  782. saved_sym_nodes_derived.append(node)
  783. # Now go through all of the prospective backward inputs and track any
  784. # other symbols we need to bind
  785. symbol_bindings = find_symbol_binding_fx_nodes(joint_module.graph)
  786. for node in itertools.chain(saved_sym_nodes_derived, saved_values, tangent_inputs):
  787. if "val" not in node.meta:
  788. continue
  789. new_symbols = free_symbols(node.meta["val"]) - saved_symbols
  790. # NB: Deterministic order please!
  791. for s in sorted(new_symbols, key=lambda s: s.name):
  792. # NB: For well formed graphs, the symbol should always be present,
  793. # but we also have ways to produce ill-formed graphs, e.g., direct
  794. # make_fx usages, so don't choke in this case
  795. if s not in symbol_bindings:
  796. continue
  797. saved_sym_nodes_binding.append(symbol_bindings[s])
  798. saved_symbols |= new_symbols
  799. # Update saved_sym_nodes that are now reordered to have all bindings at
  800. # front. This can also be used later on to figure out the position of saved
  801. # sym nodes in the output of fwd graph.
  802. saved_sym_nodes.clear()
  803. saved_sym_nodes.extend(saved_sym_nodes_binding + saved_sym_nodes_derived)
  804. # Now, we re-generate the fwd/bwd graphs.
  805. # NB: This might increase compilation time, but I doubt it matters
  806. fwd_graph = _extract_graph_with_inputs_outputs(
  807. joint_module.graph,
  808. primal_inputs + fwd_seed_offset_inputs,
  809. fwd_outputs + saved_values + saved_sym_nodes,
  810. fwd_outputs_descs
  811. + [
  812. SavedForBackwardsAOTOutput(i)
  813. for i in range(len(saved_values) + len(saved_sym_nodes))
  814. ],
  815. "forward",
  816. )
  817. bwd_graph = _extract_graph_with_inputs_outputs(
  818. joint_module.graph,
  819. saved_sym_nodes
  820. + saved_values
  821. + tangent_inputs
  822. + bwd_seed_offset_inputs
  823. + backward_state_inputs,
  824. bwd_outputs,
  825. bwd_outputs_descs,
  826. "backward",
  827. )
  828. fwd_module = fx._lazy_graph_module._make_graph_module(joint_module, fwd_graph)
  829. bwd_module = fx._lazy_graph_module._make_graph_module(joint_module, bwd_graph)
  830. enable_activation_quantization(
  831. saved_values, fwd_module, bwd_module, static_lifetime_input_nodes
  832. )
  833. return fwd_module, bwd_module
  834. def default_partition(
  835. joint_module: fx.GraphModule,
  836. _joint_inputs,
  837. *,
  838. num_fwd_outputs,
  839. static_lifetime_input_indices: Optional[list[int]] = None,
  840. static_lifetime_input_nodes: Optional[OrderedSet[fx.Node]] = None,
  841. ) -> tuple[fx.GraphModule, fx.GraphModule]:
  842. """
  843. Partitions the :attr:`joint_module` in a manner that closely resembles the
  844. behavior observed in the original ``.forward()`` and ``.backward()`` of the
  845. callable, i.e., the resulting forward graph contains those operators that
  846. are executed in the original ``.forward()`` callable passed to
  847. :func:`aot_function`.
  848. The default partitioner collects the operators that are between the forward
  849. inputs and the forward outputs. This helps in finding the tensors which have
  850. to be stashed for the backward pass. These stashed tensors become the output
  851. of the generated forward graph. The remaining operators are then placed in
  852. the backward graph.
  853. .. warning::
  854. This API is experimental and likely to change.
  855. Args:
  856. joint_module(fx.GraphModule): The joint forward and backward graph. This
  857. is the result of AOT Autograd tracing.
  858. Returns:
  859. Returns the generated forward and backward Fx graph modules.
  860. """
  861. if has_recomputable_ops(joint_module):
  862. return min_cut_rematerialization_partition(
  863. joint_module,
  864. _joint_inputs,
  865. num_fwd_outputs=num_fwd_outputs,
  866. static_lifetime_input_indices=static_lifetime_input_indices,
  867. )
  868. primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
  869. fwd_seed_offset_inputs = list(filter(_is_fwd_seed_offset, joint_module.graph.nodes))
  870. inputs = primal_inputs + fwd_seed_offset_inputs
  871. fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
  872. _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
  873. )
  874. forward_only_graph = _extract_graph_with_inputs_outputs(
  875. joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
  876. )
  877. forward_node_names = OrderedSet(
  878. node.name for node in forward_only_graph.nodes if node.op != "output"
  879. )
  880. saved_values = []
  881. saved_sym_nodes = []
  882. for node in joint_module.graph.nodes:
  883. if node.name not in forward_node_names:
  884. continue
  885. if is_sym_node(node):
  886. # Symints must be kept separate from tensors so that PythonFunction only calls
  887. # save_for_backward on tensors and stashes symints in autograd .ctx
  888. saved_sym_nodes.append(node)
  889. elif "tensor_meta" not in node.meta and node.op == "call_function":
  890. # Since we can't save tuple of tensor values, we need to flatten out what we're saving
  891. users = node.users
  892. assert all(user.target == operator.getitem for user in users)
  893. saved_values.extend(users)
  894. else:
  895. backward_usages = [
  896. n for n in node.users if n.name not in forward_node_names
  897. ]
  898. if "tensor_meta" in node.meta and all(
  899. is_sym_node(n) for n in backward_usages
  900. ):
  901. # If we have a tensor in the forward, where only its sizes/strides are needed in the backward,
  902. # and not the actual tensor data,
  903. # then it will be a lot cheaper to save only the sizes/strides, and not the actual tensor.
  904. #
  905. # Note that saving the tensor could also cause compilation problems:
  906. # If the user mutated an input in the forward and uses its sizes/strides in the backward,
  907. # then we would be obligated to clone the input before saving it to appease autograd.
  908. # (This is how we originally found this bug).
  909. saved_sym_nodes.extend(backward_usages)
  910. else:
  911. saved_values.append(node)
  912. saved_values = list(dict.fromkeys(saved_values).keys())
  913. saved_sym_nodes = list(dict.fromkeys(saved_sym_nodes).keys())
  914. return _extract_fwd_bwd_modules(
  915. joint_module,
  916. saved_values,
  917. saved_sym_nodes=saved_sym_nodes,
  918. num_fwd_outputs=num_fwd_outputs,
  919. static_lifetime_input_nodes=static_lifetime_input_nodes,
  920. )
  921. INT_INF = int(1e6)
  922. def _tensor_nbytes(numel: int, dtype) -> int:
  923. return numel * dtype.itemsize
  924. def _size_of(node: fx.Node) -> int:
  925. def object_nbytes(x) -> int:
  926. if not isinstance(x, torch.Tensor):
  927. return 0
  928. return _tensor_nbytes(hint_int(x.numel(), fallback=4096), x.dtype)
  929. if "val" in node.meta:
  930. val = node.meta["val"]
  931. if isinstance(val, py_sym_types):
  932. return 1
  933. # NB: The fallback values here are meaningless, maybe we should respect
  934. # torch._inductor.config.unbacked_symint_fallback (but this is a
  935. # layering violation)
  936. elif isinstance(val, (list, tuple)):
  937. return sum(object_nbytes(n) for n in val)
  938. elif isinstance(val, dict):
  939. return sum(object_nbytes(n) for _, n in val.items())
  940. elif isinstance(val, torch.Tensor):
  941. return object_nbytes(val)
  942. raise RuntimeError(f"Unknown metadata type {type(val)} on node {node}")
  943. if node.op == "get_attr" or node.target is torch.ops.aten._assert_scalar.default:
  944. return 0
  945. raise RuntimeError(
  946. f"Node {node} didn't have `val` metadata; we should always have `val` metadata on the nodes."
  947. )
  948. # Used for some investigative purposes
  949. def _count_ops(graph: fx.Graph):
  950. from collections import defaultdict
  951. cnt: dict[str, int] = defaultdict(int)
  952. for node in graph.nodes:
  953. if node.op == "call_function":
  954. cnt[node.target.__name__] += 1
  955. log.info("%s", sorted(cnt.items(), key=operator.itemgetter(1), reverse=True))
  956. @functools.cache
  957. def pointwise_ops():
  958. ops = []
  959. for attr_name in dir(torch.ops.aten):
  960. opoverloadpacket = getattr(torch.ops.aten, attr_name)
  961. if not isinstance(opoverloadpacket, torch._ops.OpOverloadPacket):
  962. continue
  963. for overload in opoverloadpacket.overloads():
  964. op_overload = getattr(opoverloadpacket, overload)
  965. if torch.Tag.pointwise in op_overload.tags:
  966. # currently aot autograd uses packet not overload
  967. ops.append(opoverloadpacket)
  968. break
  969. return ops
  970. def sort_depths(args, depth_map: dict[fx.Node, int]) -> list[tuple[fx.Node, int]]:
  971. arg_depths = {
  972. arg: depth_map[arg] for arg in args if isinstance(arg, torch.fx.node.Node)
  973. }
  974. return sorted(arg_depths.items(), key=operator.itemgetter(1), reverse=True)
  975. def reordering_to_mimic_autograd_engine(gm: fx.GraphModule) -> fx.GraphModule:
  976. """
  977. This pass finds the first bwd node in the graph (by looking at users of
  978. tangents) and then reorders the graph by walking from this node to all the
  979. way to the end of the graph. At each op in this traversal, we insert this op
  980. in a new graph and try to bring only the relevant subgraph from the other
  981. non-bwd edges relevant for this op. This closely mimics the behavior of
  982. autograd engine.
  983. Why is this pass required in the first place?
  984. This is an artifact of how partitioners work today. The starting point of
  985. partitioner is a joint graph, which is fwd and then bwd graph. In the case
  986. of checkpointing, we keep portions of fwd graph in their original place in
  987. the joint graph, while obtaining a bwd graph. As a result, the resulting bwd
  988. graph has copies of recomputed fwd subgraphs followed by the original bwd
  989. graph. If we run this naively, this leads to bad memory footprint, because
  990. the fwd subgraphs are live for way longer duration than necessary. This pass
  991. reorders the operations such that we prioritize the ops for the original bwd
  992. graph while only realizing those ops from the fwd graph that are necessary
  993. at any given point in the graph.
  994. """
  995. new_graph = fx.Graph()
  996. env: dict[fx.Node, fx.Node] = {}
  997. # Add new placeholder nodes in the order specified by the inputs
  998. for node in gm.graph.find_nodes(op="placeholder"):
  999. env[node] = new_graph.node_copy(node, lambda x: env[x])
  1000. order = {node: idx for idx, node in enumerate(gm.graph.nodes)}
  1001. def insert_node_in_graph(node):
  1002. cur_nodes = [node]
  1003. insertable_nodes: OrderedSet[fx.Node] = OrderedSet()
  1004. while len(cur_nodes) > 0:
  1005. node = cur_nodes.pop()
  1006. if node in insertable_nodes or node in env:
  1007. continue
  1008. insertable_nodes.add(node)
  1009. # Bias traversal towards the nodes that have higher depth - prioritizes
  1010. # critical path first.
  1011. cur_nodes += node.all_input_nodes
  1012. insertable_nodes = sorted(insertable_nodes, key=lambda n: order[n])
  1013. for node in insertable_nodes:
  1014. env[node] = new_graph.node_copy(node, lambda x: env[x])
  1015. # Find first bwd node in the graph
  1016. tangent_inputs = list(filter(_is_tangent, gm.graph.nodes))
  1017. first_node_in_bwd = None
  1018. minimum_order = math.inf
  1019. for tangent in tangent_inputs:
  1020. for user in tangent.users:
  1021. if order[user] < minimum_order:
  1022. minimum_order = order[user]
  1023. first_node_in_bwd = user
  1024. # If gradInp does not depend upon gradOut, we may not find any nodes in the "backwards pass"
  1025. if first_node_in_bwd is None:
  1026. return gm
  1027. # Build the graph op-by-op by starting from the node all the way to the end
  1028. # copy_ can be not using tangents at all, we must copy it.
  1029. for node in list(gm.graph.nodes)[: order[first_node_in_bwd]]:
  1030. if node.op == "call_function" and node.target == torch.ops.aten.copy_.default:
  1031. insert_node_in_graph(node)
  1032. for node in list(gm.graph.nodes)[order[first_node_in_bwd] :]:
  1033. insert_node_in_graph(node)
  1034. # The output node is already built by the traversal.
  1035. new_gm = torch.fx.GraphModule(gm, new_graph)
  1036. return new_gm
  1037. def apply_graphsafe_rng_functionalization(
  1038. fw_module: torch.fx.GraphModule,
  1039. bw_module: torch.fx.GraphModule,
  1040. fw_node: torch.fx.Node,
  1041. bw_node: torch.fx.Node,
  1042. device: torch.device,
  1043. rng_count: int,
  1044. last_fwd_input: torch.fx.Node,
  1045. last_bwd_input: torch.fx.Node,
  1046. ):
  1047. """
  1048. Note [CUDA Graph Safe RNG Functionalization]
  1049. CUDA Graph capture doesn't work with get_rng_state and set_rng_state because these functions operate on CPU values,
  1050. while CUDA Graph RNG capture uses on-device CUDA tensors. To solve this, we use graphsafe_set_state with a
  1051. CUDA Generator registered to the CUDA Graph before capture begins. graphsafe_set_state updates the generator's pointer
  1052. to reference a different GeneratorImpl, ensuring subsequent calls are correctly forwarded to the desired generator
  1053. (and its cuda-tensor RNG state during graph capture).
  1054. For each RNG operation's forward/backward pair:
  1055. - We create two generators initialized with identical values
  1056. - Each forward and backward call advances its respective generator equally
  1057. - This keeps generators synchronized so forward and backward operations use matching RNG values
  1058. When forward is called multiple times before backward (causing desynchronization):
  1059. - We save the forward RNG state
  1060. - We update the backward Generator's state before executing backward
  1061. Before each CUDA Graph replay, replay_prologue updates captured RNG pointers with current states, ensuring backward Generator
  1062. changes are reflected during replay.
  1063. This function modifies both forward and backward computation graphs by:
  1064. Creating RNG state placeholders for both passes
  1065. Updating the forward node to use graph-safe RNG state
  1066. Updating the backward node to use graph-safe RNG state
  1067. For more details: https://github.com/pytorch/pytorch/issues/113541
  1068. """
  1069. device_idx = device.index
  1070. assert device_idx is not None
  1071. fw_graph = fw_module.graph
  1072. bw_graph = bw_module.graph
  1073. graphsafe_run_with_rng_state = torch._prims.rng_prims.graphsafe_run_with_rng_state
  1074. # Handle forward pass
  1075. # Note: [Generator arguments in AOTDispatcher]
  1076. # Generator arguments in AOTDispatcher are added to support graphsafe rng
  1077. # functionalization. See note above [CUDA Graph Safe RNG Functionalization]
  1078. with fw_module.graph.inserting_after(last_fwd_input):
  1079. fwd_rng_state = fw_module.graph.placeholder(f"fwd_rng_state_{rng_count}")
  1080. fwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx)
  1081. last_fwd_input = fwd_rng_state
  1082. # Handle backward pass
  1083. with bw_module.graph.inserting_after(last_bwd_input):
  1084. bwd_rng_state = bw_module.graph.placeholder(f"bwd_rng_state_{rng_count}")
  1085. # as above, clone so that meta val generator will not contain tensors
  1086. bwd_rng_state.meta["val"] = get_cuda_generator_meta_val(device_idx)
  1087. last_bwd_input = bwd_rng_state
  1088. # Update forward node
  1089. fw_kwargs = dict(fw_node.kwargs)
  1090. fw_kwargs["rng_state"] = fwd_rng_state
  1091. with fw_module.graph.inserting_after(fw_node):
  1092. functional_fw_node = fw_graph.create_node(
  1093. "call_function",
  1094. graphsafe_run_with_rng_state,
  1095. args=(fw_node.target, *fw_node.args), # type: ignore[arg-type]
  1096. kwargs=fw_kwargs,
  1097. )
  1098. fw_node.replace_all_uses_with(functional_fw_node)
  1099. fw_graph.erase_node(fw_node)
  1100. # Update backward node
  1101. bwd_kwargs = dict(bw_node.kwargs)
  1102. bwd_kwargs["rng_state"] = bwd_rng_state
  1103. with bw_graph.inserting_before(bw_node):
  1104. rng_output = bw_graph.create_node(
  1105. "call_function",
  1106. graphsafe_run_with_rng_state,
  1107. args=(bw_node.target, *bw_node.args), # type: ignore[arg-type]
  1108. kwargs=bwd_kwargs,
  1109. )
  1110. bw_node.replace_all_uses_with(rng_output)
  1111. bw_graph.erase_node(bw_node)
  1112. return last_fwd_input, last_bwd_input
  1113. def functionalize_rng_ops(
  1114. joint_module: fx.GraphModule,
  1115. fw_module: fx.GraphModule,
  1116. bw_module: fx.GraphModule,
  1117. num_sym_nodes: int,
  1118. ) -> tuple[fx.GraphModule, fx.GraphModule]:
  1119. # During user-driven activation checkpointing, we have to ensure that a rng
  1120. # op in fwd yields the same output as the recomputed rng op in the bwd. To
  1121. # do this, we use functionalize wrappers to wrap the random ops and share
  1122. # rng state between the fwd and bwd graphs.
  1123. # There are 3 main steps to do this
  1124. # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd.
  1125. # Step 2 - Modify the fwd pass such that
  1126. # 1) Replace rand with run_and_save_rng_state wrapper
  1127. # 2) Replace the users of the original op with the output[1] of this op.
  1128. # 3) Collect all the rng_state - output[0] of each op, and make them
  1129. # output nodes. Special care needs to be taken here because fwd outputs
  1130. # has symints at the very end.
  1131. # Step 3 - Modify the bwd pass such that
  1132. # 1) Add the input nodes just before the tangents for the stashed rng states
  1133. # 2) Replace rand with run_with_save_rng_state wrappers
  1134. # 3) Use the stashed states as inputs to these ops
  1135. # Unique id to generate name
  1136. uid = itertools.count()
  1137. def get_rng_ops(gmod):
  1138. random_nodes = {}
  1139. for node in gmod.graph.nodes:
  1140. if (
  1141. node.op == "call_function"
  1142. and hasattr(node.target, "tags")
  1143. and torch.Tag.nondeterministic_seeded in node.target.tags
  1144. ):
  1145. random_nodes[node.name] = node
  1146. return random_nodes
  1147. def get_device(node) -> Optional[torch.device]:
  1148. """
  1149. Check the example value of the node outputs to find the device type.
  1150. """
  1151. if "val" not in node.meta:
  1152. return None
  1153. candidates = node.meta["val"]
  1154. if not isinstance(candidates, tuple):
  1155. candidates = (candidates,)
  1156. for candidate in candidates:
  1157. if isinstance(candidate, torch.Tensor):
  1158. if candidate.device.type == "cuda":
  1159. return candidate.device
  1160. return torch.device("cpu")
  1161. def get_sample_rng_state(device: Optional[torch.device]):
  1162. from torch._guards import detect_fake_mode # noqa: F401
  1163. fake_mode = detect_fake_mode()
  1164. assert fake_mode is not None
  1165. with fake_mode:
  1166. if device is not None and device.type == "cuda":
  1167. return fake_mode.from_tensor(torch.cuda.get_rng_state())
  1168. return fake_mode.from_tensor(torch.get_rng_state())
  1169. # Step 1 - Construct a mapping of rng node between the fwd and its counterpart in bwd.
  1170. joint_graph_rng_ops = get_rng_ops(joint_module)
  1171. fw_graph_rng_ops = get_rng_ops(fw_module)
  1172. bw_graph_rng_ops = get_rng_ops(bw_module)
  1173. recomputable_rng_ops_map = {}
  1174. for node in joint_module.graph.nodes:
  1175. if (
  1176. must_recompute(node)
  1177. and hasattr(node.target, "tags")
  1178. and torch.Tag.nondeterministic_seeded in node.target.tags
  1179. ):
  1180. base_node = joint_graph_rng_ops[node.name]
  1181. fw_node = fw_graph_rng_ops[node.name]
  1182. bw_node = bw_graph_rng_ops[node.name]
  1183. recomputable_rng_ops_map[base_node] = {"fwd": fw_node, "bwd": bw_node}
  1184. run_and_save_rng = torch._prims.rng_prims.run_and_save_rng_state
  1185. run_with_rng_state = torch._prims.rng_prims.run_with_rng_state
  1186. bw_tangent_start_node = None
  1187. for node in bw_module.graph.find_nodes(op="placeholder"):
  1188. if "tangent" in node.name:
  1189. bw_tangent_start_node = node
  1190. break
  1191. if bw_tangent_start_node is None:
  1192. raise RuntimeError(
  1193. "Couldn't find tangent node in graph inputs. This is unexpected, please file a bug if you see this"
  1194. )
  1195. fw_rng_state_outputs = []
  1196. last_fwd_input = next(reversed(fw_module.graph.find_nodes(op="placeholder")))
  1197. last_bwd_input = next(reversed(bw_module.graph.find_nodes(op="placeholder")))
  1198. devices = OrderedSet(
  1199. get_device(node_pair["fwd"]) for node_pair in recomputable_rng_ops_map.values()
  1200. )
  1201. devices.discard(torch.device("cpu"))
  1202. # multiple cuda devices won't work with cudagraphs anyway,
  1203. # fallback to non graphsafe rng checkpointing
  1204. multi_cuda_devices = len(devices) > 1
  1205. # this changes numerics, so if fallback_random is set we will not use it
  1206. ind_config = torch._inductor.config
  1207. use_rng_graphsafe_rng_functionalization = (
  1208. config.graphsafe_rng_functionalization
  1209. and not multi_cuda_devices
  1210. and (
  1211. not ind_config.fallback_random
  1212. or ind_config.test_configs.graphsafe_rng_func_ignores_fallback_random
  1213. )
  1214. )
  1215. for rng_count, (base_node, node_pair) in enumerate(
  1216. recomputable_rng_ops_map.items()
  1217. ):
  1218. # Step 2 - Modify the fwd pass such that
  1219. fw_node = node_pair["fwd"]
  1220. bw_node = node_pair["bwd"]
  1221. device = get_device(fw_node)
  1222. fw_graph = fw_module.graph
  1223. bw_graph = bw_module.graph
  1224. if (
  1225. use_rng_graphsafe_rng_functionalization
  1226. and device is not None
  1227. and device.type == "cuda"
  1228. ):
  1229. last_fwd_input, last_bwd_input = apply_graphsafe_rng_functionalization(
  1230. fw_module,
  1231. bw_module,
  1232. fw_node,
  1233. bw_node,
  1234. device,
  1235. rng_count,
  1236. last_fwd_input,
  1237. last_bwd_input,
  1238. )
  1239. else:
  1240. with fw_graph.inserting_before(fw_node):
  1241. functional_fw_node = fw_graph.create_node(
  1242. "call_function",
  1243. run_and_save_rng,
  1244. args=(fw_node.target, *fw_node.args),
  1245. kwargs=fw_node.kwargs,
  1246. )
  1247. state = fw_graph.create_node(
  1248. "call_function",
  1249. operator.getitem,
  1250. args=(functional_fw_node, 0),
  1251. kwargs={},
  1252. )
  1253. state.meta["val"] = get_sample_rng_state(device)
  1254. rng_output = fw_graph.create_node(
  1255. "call_function",
  1256. operator.getitem,
  1257. args=(
  1258. functional_fw_node,
  1259. 1,
  1260. ),
  1261. kwargs={},
  1262. )
  1263. # Copy the meta data from the original node
  1264. rng_output.meta = copy.copy(fw_node.meta)
  1265. fw_node.replace_all_uses_with(rng_output)
  1266. fw_graph.erase_node(fw_node)
  1267. fw_rng_state_outputs.append(state)
  1268. # Step 3 - Modify the bwd pass such that
  1269. with bw_graph.inserting_before(bw_tangent_start_node):
  1270. state_name = f"rng_state_output_{next(uid)}"
  1271. bw_rng_state_node = bw_graph.placeholder(state_name)
  1272. bw_rng_state_node.meta["val"] = get_sample_rng_state(device)
  1273. with bw_graph.inserting_before(bw_node):
  1274. rng_output = bw_graph.create_node(
  1275. "call_function",
  1276. run_with_rng_state,
  1277. args=(bw_rng_state_node, bw_node.target, *bw_node.args),
  1278. kwargs=bw_node.kwargs,
  1279. )
  1280. bw_node.replace_all_uses_with(rng_output)
  1281. bw_graph.erase_node(bw_node)
  1282. # Add the rng states in the output of the fwd graph. AOT Autograd assumes
  1283. # that symints are at the end of forward graph outputs. So, insert the new
  1284. # rng states accordingly.
  1285. if fw_rng_state_outputs:
  1286. fw_output_node = next(iter(fw_module.graph.find_nodes(op="output")))
  1287. fw_outputs = fw_output_node.args[0]
  1288. sym_node_start_idx = len(fw_outputs) - num_sym_nodes
  1289. outputs = (
  1290. fw_outputs[:sym_node_start_idx]
  1291. + tuple(fw_rng_state_outputs)
  1292. + fw_outputs[sym_node_start_idx:]
  1293. )
  1294. fw_module.graph.output(outputs)
  1295. fw_module.graph.erase_node(fw_output_node)
  1296. fw_module.recompile()
  1297. bw_module.recompile()
  1298. return fw_module, bw_module
  1299. def force_save_collectives(joint_module: fx.GraphModule) -> None:
  1300. """
  1301. By default, the partitioner is not allowed to recompute collectives
  1302. unless they come from a user-annotated AC region.
  1303. See Note [Recomputing collectives in the partitioner]
  1304. """
  1305. for node in joint_module.graph.nodes:
  1306. if (
  1307. isinstance(node.target, torch._ops.OpOverload)
  1308. and node.target.namespace == "_c10d_functional"
  1309. and not must_recompute(node)
  1310. ):
  1311. node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
  1312. def force_save_bw_mutation_src(joint_module: fx.GraphModule) -> None:
  1313. # If we have mutations of the same primal in forward and backward,
  1314. # We must not recompute the source of mutation to not apply twice.
  1315. has_mutation_in_bw: OrderedSet[torch.fx.Node] = OrderedSet()
  1316. for node in reversed(joint_module.graph.nodes):
  1317. if node.op == "output":
  1318. continue
  1319. is_copy_ = node.target == torch.ops.aten.copy_.default
  1320. if is_copy_:
  1321. if _has_tag_must_be_in_backward(node):
  1322. has_mutation_in_bw.add(node.args[0])
  1323. if _has_tag_must_be_in_forward(node) and node.args[0] in has_mutation_in_bw:
  1324. node.args[1].meta["recompute"] = CheckpointPolicy.MUST_SAVE
  1325. else:
  1326. # We use invariant of aotdispatch joint graph,
  1327. # That we emit copy_ only in the end of it.
  1328. # We do not want to iterate through all the joint graph,
  1329. # so break at the first non-output, non-copy_ node.
  1330. break
  1331. def cleanup_recompute_tags(joint_module: fx.GraphModule) -> fx.GraphModule:
  1332. """
  1333. If there are two consecutive checkpointed blocks with no operator in
  1334. between, we would still want to stash the tensor at the boundary of
  1335. checkpointed blocks. The following pass makes the last output node
  1336. non-recomputable to allow for that.
  1337. """
  1338. for node in joint_module.graph.nodes:
  1339. if must_recompute(node):
  1340. for user in node.users:
  1341. if (
  1342. must_recompute(user)
  1343. and user.meta["ac_graph_id"] > node.meta["ac_graph_id"]
  1344. ):
  1345. node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
  1346. if node.meta.get("has_backward_hook", False) and not any(
  1347. must_recompute(user) for user in node.users
  1348. ):
  1349. # If node is AC region output and has a backward hook on it, we intentionally choose to save it.
  1350. # This is to work around circular dependencies in Traceable FSDP2+AC.
  1351. # Example:
  1352. # ```
  1353. # out = fully_shard(utils.checkpoint(module))(x)
  1354. # norm_out = layer_norm(out)
  1355. # ```
  1356. # Here there is a circular dependency:
  1357. # 1. In backward, grad_input of layer_norm aka. `out_grad` is actually dependent on `out`.
  1358. # 2. `out` depends on `out`'s backward hook created by FSDP2 (which does all-gather for `module` weights)
  1359. # in order to be recomputed.
  1360. # 3. `out`'s backward hook, as is the case for all eager backward hooks, depends on `out_grad`
  1361. # -> circular dependency with (1)!
  1362. #
  1363. # Solution: check whether `out` has a backward hook, and if so, intentionally save `out`
  1364. # in forward graph outputs. With this, we can break the above circular dependency.
  1365. node.meta["recompute"] = CheckpointPolicy.MUST_SAVE
  1366. return joint_module
def solve_min_cut(
    joint_graph: fx.Graph,
    node_info: NodeInfo,
    min_cut_options: MinCutOptions,
    dont_ban: Optional[OrderedSet[fx.Node]] = None,
):
    """
    Choose which values to save for backward by solving a min-cut problem.

    A networkx ``DiGraph`` is built in which every joint-graph node ``X``
    (other than the output) is split into ``X_in`` and ``X_out``, joined by an
    edge whose capacity is the cost of saving ``X``. Infinite-capacity edges
    connect ``"source"`` to the ``_in`` side of nodes banned from
    recomputation, connect backward-required / must-recompute nodes to
    ``"sink"``, and connect each ``X_out`` to the ``_in`` side of its users.
    The ``X_in -> X_out`` edges crossing the cut returned by
    ``nx.minimum_cut`` identify the nodes to save.

    Args:
        joint_graph: the joint forward/backward graph to analyze.
        node_info: queried for forward/backward membership, forward order,
            inputs and static-lifetime input nodes.
        min_cut_options: flags enabling the individual "ban recomputation"
            heuristics used below.
        dont_ban: nodes that must not be banned from recomputation. Collectives
            in this set are still banned unless
            ``config.unsafe_allow_optimization_of_collectives`` is set.
            Defaults to an empty set.

    Returns:
        A tuple ``(saved_values, banned_nodes)``: the nodes on the cut, sorted
        by their position in ``joint_graph``, and the set of nodes that were
        banned from recomputation while building the flow graph.

    Raises:
        RuntimeError: if ``networkx`` cannot be imported.
    """
    if dont_ban is None:
        dont_ban = OrderedSet()
    op_types = get_default_op_list()

    if AOT_PARTITIONER_DEBUG:
        # Debug aid: report which op packets in the graph are not in the
        # recomputable allowlist.
        joint_module_ops = OrderedSet(
            str(node.target._overloadpacket)
            for node in joint_graph.nodes
            if node.op == "call_function" and hasattr(node.target, "_overloadpacket")
        )
        ops_ignored = joint_module_ops - OrderedSet(
            str(i) for i in op_types.recomputable_ops
        )
        log.info("Ops banned from re-materialization: %s", ops_ignored)

    def can_fuse_into_auto_functionalized(a, b):
        """Return True if `b` is an auto_functionalized call and `a` is one of its mutable args."""
        if b.target != torch.ops.higher_order.auto_functionalized:
            return False
        mutable_op = b.args[0]
        (
            mutable_arg_names,
            _,
        ) = torch._higher_order_ops.auto_functionalize.get_mutable_args(mutable_op)
        for name in mutable_arg_names:
            arg = b.kwargs[name]
            if a is arg:
                return True
            if isinstance(arg, list):
                if a in arg:
                    return True
        return False

    def can_fuse_into_triton_kernel_wrapper_functional(a, b):
        """Return True if `b` is a triton_kernel_wrapper_functional call and `a` is one of its `tensors_to_clone`."""
        if b.target != torch.ops.higher_order.triton_kernel_wrapper_functional:
            return False
        mutable_arg_names = b.kwargs["tensors_to_clone"]
        for name in mutable_arg_names:
            arg = b.kwargs["kwargs"][name]
            if a is arg:
                return True
        return False

    def is_fusible(a, b):
        """Return True if producer `a` is considered fusible into its user `b`."""
        # We can perform "memory fusion" into a cat, but cat cannot be a
        # producer to a fusion
        if get_aten_target(b) == aten.cat:
            return True
        if can_fuse_into_auto_functionalized(a, b):
            return True
        if can_fuse_into_triton_kernel_wrapper_functional(a, b):
            return True
        if (
            a.target is operator.getitem
            and a.args[0].target
            is torch.ops.higher_order.triton_kernel_wrapper_functional
        ):
            # if a is the output of a user triton kernel,
            # then (by default) we will not be able to fuse b into it
            return False
        return op_types.is_fusible(a) and op_types.is_fusible(b)

    try:
        import networkx as nx
    except ImportError as e:
        raise RuntimeError(
            "Need networkx installed to perform smart recomputation heuristics"
        ) from e

    def is_materialized_backwards(node):
        """
        Return True if `node` (or a view reached from it through view users)
        has a user that is not required in forward and is not fusible with
        its producer. View nodes themselves always return False.
        """
        if op_types.is_view(node):
            return False
        cur_nodes = OrderedSet([node])
        while len(cur_nodes) > 0:
            cur = cur_nodes.pop()
            for user in cur.users:
                if not node_info.is_required_fw(user) and not is_fusible(cur, user):
                    return True
                if op_types.is_view(user):
                    # Keep following the chain through views of this value.
                    cur_nodes.add(user)

        return False

    def should_ban_recomputation(node):
        """Apply the per-node heuristics and return True if `node` should be banned from recomputation."""
        if node.op != "call_function":
            return False
        if node.target == operator.getitem:
            return False
        if node.meta.get("recompute", None) == CheckpointPolicy.MUST_SAVE:
            return True
        if config.recompute_views and op_types.is_view(node):
            return False
        if node.target in [aten.lift_fresh_copy.default, aten.lift_fresh.default]:
            return False

        if min_cut_options.ban_if_not_in_allowlist:
            if not op_types.is_recomputable(node):
                return True
        else:
            if op_types.is_random(node) or op_types.is_compute_intensive(node):
                return True

        # If a node *must* be materialized in the backwards pass, then we
        # should never recompute it. This is a pretty subtle point. In
        # general, the assumption we make is that recomputing a node in the
        # backwards pass is "free". However, if a node must be materialized
        # in the backwards pass, then recomputing it is never free.
        if min_cut_options.ban_if_materialized_backward and is_materialized_backwards(
            node
        ):
            log.debug("materialized backwards: %s %s", node, tuple(node.users))
            return True

        # Arbitrary hack that sometimes seems to help things. The above
        # modification appears to have made this heuristic a lot less critical
        # for performance.
        # NB: As of PR #121692, this hack no longer seems necessary.
        if node.dist_from_bw < 1000 and node.dist_from_bw > config.max_dist_from_bw:
            return True

        # If the output of an op is 4x smaller (arbitrary choice),
        # then we don't allow recomputation. The idea here is that for
        # things like reductions, saving the output of the reduction is very
        # cheap/small, and it makes sure we don't do things like recompute
        # normalizations in the backwards.
        if min_cut_options.ban_if_reduction:
            input_tensors_size = sum(
                _size_of(i) for i in node.args if isinstance(i, fx.Node)
            )
            output_size = _size_of(node)
            return output_size * 4 < input_tensors_size
        return False

    def is_materialized(node):
        """Return True for placeholders and for nodes with at least one user they cannot fuse into."""
        if node.op == "placeholder":
            return True

        return not all(is_fusible(node, user) for user in node.users)

    def get_node_weight(node, static_lifetime_input_nodes) -> float:
        """Return the capacity of the `node_in -> node_out` edge, i.e. the cost of saving `node`."""
        if (
            config.treat_parameters_as_free_to_save
            and node in static_lifetime_input_nodes
        ):
            return 0
        mem_sz = _size_of(node)
        if config.recompute_views and op_types.is_view(node):
            # If `config.recompute_views=True`, we don't save views. This is generally
            # a good idea since views are free to recompute, and it makes it a bit simpler
            # to analyze.
            # NB: If they're not free to recompute (e.g. nested tensors)... I
            # think we should modify checks for view_ops to `is_view` and check
            # that. Basically, with nested tensors, `aten.view` is not a "view
            # op".
            return math.inf

        if isinstance(node.meta["val"], py_sym_types):
            # We never want to save symfloats
            if not isinstance(node.meta["val"], torch.SymInt):
                return INT_INF

        # Heuristic to bias towards nodes closer to the backwards pass
        # Complete guess about current value
        mem_sz = int(mem_sz * (1.1 ** max(min(node.dist_from_bw, 100), 1)))
        if is_materialized(node):
            return mem_sz
        else:
            # Non-materialized nodes are charged twice their (biased) size.
            return mem_sz * 2

    nx_graph = nx.DiGraph()
    banned_nodes: OrderedSet[fx.Node] = OrderedSet()

    def ban_recomputation_if_allowed(node):
        """
        Ban `node` from recomputation unless something forbids the ban.

        Returns True (and records the node in `banned_nodes` and the flow
        graph) when the ban was applied, False otherwise.
        """
        if op_types.is_view(node):
            return False
        if node in dont_ban:
            # collectives are *always* banned from recompute, overriding `dont_ban`
            # (in particular, the activation memory budget logic is not allowed to recompute collectives)
            is_collective = (
                isinstance(node.target, torch._ops.OpOverload)
                and node.target.namespace == "_c10d_functional"
            )
            if config.unsafe_allow_optimization_of_collectives or not is_collective:
                return False
        # This bans recomputation of the node unless we've been forced not to by
        # user annotation
        if must_recompute(node):
            return False

        if "val" in node.meta and isinstance(node.meta["val"], torch.SymFloat):
            return False
        banned_nodes.add(node)
        # A node will only ever be recomputed if there is a path from an
        # ancestor of this node to the backwards path through this node that
        # doesn't go through any saved value. If this node is saved, then that
        # condition is not possible.
        nx_graph.add_edge("source", node.name + "_in", capacity=math.inf)
        return True

    # Build the flow graph: one `_in`/`_out` pair per joint-graph node.
    for node in joint_graph.nodes:
        if node.op == "output":
            continue

        if node in node_info.required_bw_nodes:
            if node not in node_info.inputs:
                nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
                continue
            # If someone saves an input for backward as-is and backward
            # returns that tensor as-is as a grad input, then the node x would
            # be both a required_bw_node and an input. In this case we
            # (1) connect x_in to the source, (2) x_out to the sink, and
            # (3) assign the proper weight to the x_in-x_out edge, so that
            # x would be part of cut nodes. A case where this happens is if
            # NestedTensor saves a offset tensor as part of the singleton int
            # in sizes.
            nx_graph.add_edge(node.name + "_out", "sink", capacity=math.inf)

        if must_recompute(node):
            # If user explicitly says they want to recompute a node, we honor it
            # by adding an inf-capacity edge from X_in to the sink.
            # This way, X_in node is guaranteed to be part of the subgraph that contains "sink"
            # after the cut, thus guaranteeing that X op will be recomputed.
            nx_graph.add_edge(node.name + "_in", "sink", capacity=math.inf)
            continue

        if _is_primal(node) or _is_fwd_seed_offset(node):
            ban_recomputation_if_allowed(node)

        # If a node can't be recomputed (too expensive or involves randomness),
        # we prevent it from being recomputed by adding an inf edge to the source
        # We only need to ban nodes in the fw pass, as those are the only ones that would be recomputed.
        if node_info.is_required_fw(node) and should_ban_recomputation(node):
            ban_recomputation_if_allowed(node)

        # Checks if a node is actually a tuple. Can be simplified to just an isinstance check if we always use faketensors.
        is_non_tensor_node = (
            "val" not in node.meta and "tensor_meta" not in node.meta
        ) or ("val" in node.meta and not isinstance(node.meta["val"], torch.Tensor))

        if is_sym_node(node):
            weight = float(sym_node_size(node))
        elif is_non_tensor_node:
            # BackwardState values are free to save; any other non-tensor
            # value gets infinite capacity so it never ends up on the cut.
            weight = (
                0.0 if isinstance(node.meta.get("val"), BackwardState) else math.inf
            )
        else:
            weight = get_node_weight(node, node_info.static_lifetime_input_nodes)
        # Creates the weights on the "node" edge
        nx_graph.add_edge(node.name + "_in", node.name + "_out", capacity=weight)
        for user in node.users:
            nx_graph.add_edge(node.name + "_out", user.name + "_in", capacity=math.inf)

    # todo(chilli): This is the most questionable of the 3 heuristics for banning recompute.
    # Some example models to look at where this helps perf: poolformer_m36,
    # mixer_b16_224, cait_m36_384

    # The "rough" idea here is that if you have some node that is used by both a
    # node nearby downstream as well as a node far downstream, if we recompute
    # both of the downstream nodes, we're unlikely to be able to fuse both
    # downstream nodes together.

    # Thus, we shouldn't aim to recompute far downstream nodes that depend on
    # this node. That intuition of "far downstream" is captured by whether
    # there's an unfusible op along the chain somewhere

    # It could probably be improved by properly analyzing what's going on in the
    # backwards pass instead of only relying on whether it's unfusible in the
    # forwards.

    def find_first_unfusible(start_nodes: list[fx.Node], max_range: int) -> int:
        """
        Finds the first unfusible node in the chain of nodes starting from
        `start_nodes` and returns its position.

        Nodes are visited in forward order via a heap; users whose forward
        order exceeds `max_range` are not explored. Returns `max_range` when
        no unfusible node is found.
        """
        sorted_nodes: list[tuple[int, fx.Node, bool]] = []
        for n in start_nodes:
            heapq.heappush(sorted_nodes, (node_info.get_fw_order(n), n, True))

        while len(sorted_nodes) > 0:
            _, node, node_is_fusible = heapq.heappop(sorted_nodes)
            if not node_is_fusible:
                return node_info.get_fw_order(node)
            for user in node.users:
                if node_info.is_required_fw(user):
                    if node_info.get_fw_order(user) > max_range:
                        continue
                    val: tuple[int, fx.Node, bool] = (
                        node_info.get_fw_order(user),
                        user,
                        is_fusible(node, user),
                    )
                    # Avoid pushing an entry that is already queued.
                    if val not in sorted_nodes:
                        heapq.heappush(sorted_nodes, val)
        return max_range

    if min_cut_options.ban_if_used_far_apart:
        for used_node in node_info.required_fw_nodes:
            orders = [
                node_info.get_fw_order(user)
                for user in used_node.users
                if node_info.is_required_fw(user)
            ]
            fw_users = [
                user for user in used_node.users if node_info.is_required_fw(user)
            ]
            if len(orders) > 0:
                first_unfusible_use = find_first_unfusible(fw_users, max(orders))
                for user in tuple(used_node.users):
                    if (
                        node_info.is_required_fw(user)
                        and node_info.get_fw_order(user) > first_unfusible_use
                        and is_fusible(used_node, user)
                    ):
                        if user in banned_nodes:
                            continue
                        log.info(
                            "used above/below fusible %s:(%s) -> %s -> %s:(%s)",
                            used_node,
                            node_info.get_fw_order(used_node),
                            first_unfusible_use,
                            user,
                            node_info.get_fw_order(user),
                        )
                        ban_recomputation_if_allowed(user)

    # This heuristic is fairly straightforward. The idea is that although it is
    # cheap to recompute bandwidth-bound ops, we don't want to end up in a situation
    # where we have a long chain of pointwise ops from the beginning to the end
    # of the model (like say, residual connections)

    # todo: I'm not totally sure why this heuristic matters. It's possible that this is
    # working around Inductor fusion decisions, or that it's a patch over
    # suboptimal partitioning decisions

    # Some models it improves perf on are cait_m36_384, mixer_b16_224, poolformer_m36
    if min_cut_options.ban_if_long_fusible_chains:
        visited: OrderedSet[fx.Node] = OrderedSet()
        for start_node in joint_graph.nodes:
            if not node_info.is_required_fw(start_node):
                continue
            fusible: list[tuple[int, fx.Node]] = [
                (node_info.get_fw_order(start_node), start_node)
            ]
            start_order = node_info.get_fw_order(start_node)
            while len(fusible) > 0:
                _, cur = heapq.heappop(fusible)
                if cur in visited:
                    continue
                visited.add(cur)
                # 100 is arbitrary choice to try and prevent degenerate cases
                if (
                    node_info.get_fw_order(cur) > start_order + 100
                    and len(fusible) == 0
                ):
                    log.info(
                        "too long %s %s %s %s",
                        cur,
                        start_node,
                        node_info.get_fw_order(cur),
                        node_info.get_fw_order(start_node),
                    )
                    ban_recomputation_if_allowed(cur)
                    break

                for user in cur.users:
                    if (
                        node_info.is_required_fw(user)
                        and is_fusible(cur, user)
                        and user not in banned_nodes
                    ):
                        heapq.heappush(fusible, (node_info.get_fw_order(user), user))

    try:
        cut_value, partition = nx.minimum_cut(nx_graph, "source", "sink")
    except Exception:
        # Dump the graph (as an edge list and as an SVG) before re-raising.
        log.info("Failed to compute min-cut on following graph:")
        log.info("\n".join(nx.readwrite.edgelist.generate_edgelist(nx_graph)))
        visualize_min_cut_graph(nx_graph)
        raise

    # Collect the edges that go from the source side to the sink side of the cut.
    reachable, non_reachable = partition
    cutset: OrderedSet[tuple[str, str]] = OrderedSet()
    for u, nbrs in ((n, nx_graph[n]) for n in reachable):
        cutset.update((u, v) for v in nbrs if v in non_reachable)

    cut_nodes: OrderedSet[str] = OrderedSet()
    for node_in, node_out in cutset:
        # Every cut edge must be an `X_in -> X_out` edge; strip the suffixes
        # ("_in" is 3 chars, "_out" is 4) to recover the node name.
        assert node_in[:-3] == node_out[:-4]
        node_name = node_in[:-3]
        cut_nodes.add(node_name)

    name_to_node = get_name_to_node(joint_graph)
    # To make this stuff deterministic
    node_idx = {node: idx for idx, node in enumerate(joint_graph.nodes)}
    saved_values = sorted(
        (name_to_node[node] for node in cut_nodes), key=lambda x: node_idx[x]
    )
    return saved_values, banned_nodes
  1728. def visualize_min_cut_graph(nx_graph):
  1729. import networkx as nx
  1730. import pydot
  1731. dot_format = nx.nx_pydot.to_pydot(nx_graph).to_string()
  1732. dot_graph = pydot.graph_from_dot_data(dot_format)[0] # type: ignore[index]
  1733. for edge in dot_graph.get_edges():
  1734. weight = nx_graph[edge.get_source()][edge.get_destination()]["capacity"]
  1735. # Set edge label to weight
  1736. edge.set_label(str(weight)) # type: ignore[union-attr]
  1737. # Color edges with weight 'inf' as red
  1738. if weight == float("inf"):
  1739. edge.set_color("red") # type: ignore[union-attr]
  1740. log.info("Visualizing the failed graph to min_cut_failed.svg")
  1741. dot_graph.write_svg("min_cut_failed.svg") # type: ignore[union-attr]
def get_default_op_list() -> OpTypes:
    """
    Build the default ``OpTypes`` classification used by the partitioner.

    The returned ``OpTypes`` is constructed positionally from, in order: the
    fusible ops (recomputable ops plus random ops), the compute-intensive
    ops, the random ops, the view ops, and the recomputable ops. Every
    collection is an ``OrderedSet``, so duplicate entries in the lists below
    (e.g. ``aten._to_copy``) collapse.
    """
    # Base allowlist of ops that may be recomputed.
    default_recomputable_ops: list[Callable] = [
        aten.add,
        aten.sub,
        aten.div,
        aten.atan2,
        aten.mul,
        aten.max,
        aten.min,
        aten.pow,
        aten.remainder,
        aten.fmod,
        aten.__and__,
        aten.__or__,
        aten.__xor__,
        aten.__lshift__,
        aten.__rshift__,
        aten.eq,
        aten.ne,
        aten.ge,
        aten.gt,
        aten.le,
        aten.lt,
        aten.abs,
        aten.bitwise_not,
        aten.ceil,
        aten.floor,
        aten.frac,
        aten.neg,
        aten.relu,
        aten.round,
        aten.silu,
        aten.trunc,
        aten.log,
        aten.log10,
        aten.log1p,
        aten.log2,
        aten.lgamma,
        aten.exp,
        aten.expm1,
        aten.erf,
        aten.erfc,
        aten.cos,
        aten.acos,
        aten.cosh,
        aten.sin,
        aten.asin,
        aten.sinh,
        aten.tan,
        aten.atan,
        aten.tanh,
        aten.atanh,
        aten.sqrt,
        aten.rsqrt,
        aten.reciprocal,
        aten.sigmoid,
        aten.softplus,
        aten.threshold,
        aten.threshold_backward,
        aten.clamp,
        aten.where,
        aten.lerp,
        aten.addcmul,
        aten.gelu,
        aten.gelu_backward,
        aten.sum,
        aten.mean,
        aten._grad_sum_to_size,
        aten.sum_to_size,
        aten.amax,
        aten.to,
        aten.type_as,
        operator.getitem,
        aten.squeeze,
        aten.unsqueeze,
        aten.rsub,
        aten._to_copy,
    ]  # noqa: E501,B950
    # Ops treated as views; this list becomes `OpTypes`' view-op set.
    recomputable_view_ops = [aten.squeeze, aten.unsqueeze, aten.alias]
    recomputable_view_ops += [
        aten.view,
        aten.slice,
        aten.t,
        prims.broadcast_in_dim,
        aten.expand,
        aten.as_strided,
        aten.permute,
        aten.select,
        aten.split,
    ]
    # NB: `view_ops` is an alias of the same list object, not a copy.
    view_ops = recomputable_view_ops
    default_recomputable_ops += [
        prims.div,
        prims.convert_element_type,
        aten.clone,
        aten._to_copy,
        aten.full_like,
        prims.var,
        prims.sum,
        aten.var,
        aten.std,
        prims.broadcast_in_dim,
        aten.select,
        aten._unsafe_view,
        aten.view,
        aten.expand,
        aten.slice,
        aten.reshape,
        aten.broadcast_tensors,
        aten.scalar_tensor,
        aten.ones,
        aten.new_zeros,
        aten.lift_fresh_copy,
        aten.arange,
        aten.triu,
        aten.var_mean,
        aten.isinf,
        aten.any,
        aten.full,
        aten.as_strided,
        aten.zeros,
        aten.empty,
        aten.empty_like,
        aten.argmax,
        aten.maximum,
        prims.iota,
        prims._low_memory_max_pool_offsets_to_indices,
    ]  # noqa: E501,B950
    # Natalia said that we should allow recomputing indexing :)
    default_recomputable_ops += [aten.index, aten.gather]
    # All view ops are recomputable as well.
    default_recomputable_ops += view_ops

    default_recomputable_ops += pointwise_ops()

    default_recomputable_ops += [
        aten.zeros_like,
    ]

    default_recomputable_ops += [method_to_operator(m) for m in magic_methods]
    recomputable_ops = OrderedSet(default_recomputable_ops)

    random_ops = OrderedSet[Callable[..., Any]](
        [aten.native_dropout, aten.rand_like, aten.randn_like]
    )
    compute_intensive_ops = [
        aten.mm,
        aten.convolution,
        aten.convolution_backward,
        aten.bmm,
        aten.addmm,
        aten._scaled_dot_product_flash_attention,
        aten._scaled_dot_product_efficient_attention,
        aten._flash_attention_forward,
        aten._efficient_attention_forward,
        aten.upsample_bilinear2d,
        aten._scaled_mm,
    ]  # noqa: E501,B950

    # Random ops are fusible even though they are not in the recomputable set.
    fusible_ops = recomputable_ops | random_ops
    return OpTypes(
        fusible_ops,
        OrderedSet(compute_intensive_ops),
        random_ops,
        OrderedSet(view_ops),
        recomputable_ops,
    )
  1903. def get_name_to_node(graph: fx.Graph):
  1904. name_to_node = {}
  1905. for node in graph.nodes:
  1906. name_to_node[node.name] = node
  1907. return name_to_node
  1908. def _optimize_runtime_with_given_memory(
  1909. joint_graph: fx.Graph,
  1910. memory: list[float],
  1911. runtimes: list[float],
  1912. max_memory: float,
  1913. node_info: NodeInfo,
  1914. all_recomputable_banned_nodes: list[fx.Node],
  1915. ) -> tuple[float, list[int], list[int]]:
  1916. SOLVER = config.activation_memory_budget_solver
  1917. if SOLVER == "greedy":
  1918. return greedy_knapsack(memory, runtimes, max_memory)
  1919. elif SOLVER == "ilp":
  1920. return ilp_knapsack(memory, runtimes, max_memory)
  1921. elif SOLVER == "dp":
  1922. return dp_knapsack(memory, runtimes, max_memory)
  1923. elif SOLVER == "dynamic_memory_budget_dp":
  1924. log.warning(
  1925. "dynamic_memory_budget_dp is an experimental solver. "
  1926. "It does not guarantee performance improvements. "
  1927. "Additionally, it is not guaranteed to be stable."
  1928. )
  1929. graph_info_provider = GraphInfoProvider.inialize_from_graph(
  1930. joint_graph=joint_graph,
  1931. all_recomputable_banned_nodes=all_recomputable_banned_nodes,
  1932. recorded_knapsack_input_memories=memory,
  1933. recorded_knapsack_input_runtimes=runtimes,
  1934. )
  1935. return dp_knapsack(
  1936. memory,
  1937. runtimes,
  1938. KnapsackEvaluator(
  1939. graph_info_provider=graph_info_provider,
  1940. ).get_knee_point_memory_budget(
  1941. knapsack_algo=dp_knapsack,
  1942. max_mem_budget=max_memory,
  1943. ),
  1944. )
  1945. elif callable(SOLVER):
  1946. saved_node_idx, recomp_node_idx = SOLVER(
  1947. memory, joint_graph, max_memory, node_info, all_recomputable_banned_nodes
  1948. )
  1949. return (0.0, saved_node_idx, recomp_node_idx)
  1950. else:
  1951. raise RuntimeError(f"Not aware of memory budget knapsack solver: {SOLVER}")
  1952. from torch.utils._mode_utils import no_dispatch
  1953. # replace symbols in size and strides with their hints without guarding.
  1954. def _remove_symbols_without_guarding(x: torch.Tensor, fallback: int) -> torch.Tensor:
  1955. shape = list(x.shape)
  1956. def realize_symbol(d):
  1957. return hint_int(d, fallback=fallback)
  1958. shape = [realize_symbol(s) for s in shape]
  1959. stride = [realize_symbol(s) for s in x.stride()]
  1960. return x.new_empty_strided(shape, stride=stride)
  1961. def estimate_runtime(node):
  1962. RUNTIME_MODE = config.activation_memory_budget_runtime_estimator
  1963. def materialize_arg(x):
  1964. if isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.Tensor):
  1965. return _remove_symbols_without_guarding(x.meta["val"], fallback=4096)
  1966. elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymInt):
  1967. return hint_int(x.meta["val"], fallback=4096)
  1968. elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymFloat):
  1969. return 1.0
  1970. elif isinstance(x, fx.Node) and isinstance(x.meta["val"], torch.SymBool):
  1971. return True
  1972. else:
  1973. return x
  1974. if RUNTIME_MODE == "testing":
  1975. return 1
  1976. elif RUNTIME_MODE == "profile":
  1977. with no_dispatch():
  1978. from torch._inductor.runtime.benchmarking import benchmarker
  1979. args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs))
  1980. ms = benchmarker.benchmark_gpu(lambda: node.target(*args, **kwargs))
  1981. return ms
  1982. elif RUNTIME_MODE == "flops":
  1983. # todo(chilli): Normalize this to also return ms
  1984. from torch.utils.flop_counter import FlopCounterMode
  1985. args, kwargs = pytree.tree_map(materialize_arg, (node.args, node.kwargs))
  1986. with FlopCounterMode(display=False) as mode:
  1987. node.target(*args, **kwargs)
  1988. counted_flops = mode.get_total_flops()
  1989. return max(counted_flops, 1)
  1990. else:
  1991. raise RuntimeError(f"Not aware of runtime estimator: {RUNTIME_MODE}")
  1992. def choose_saved_values_set(
  1993. joint_graph: fx.Graph,
  1994. node_info: NodeInfo,
  1995. memory_budget=1,
  1996. ) -> list[fx.Node]:
  1997. if memory_budget > 1 or memory_budget < 0:
  1998. raise RuntimeError(
  1999. f"The valid ranges for memory budget are 0 <= m <= 1. The provided value is {memory_budget}"
  2000. )
  2001. min_cut_options = MinCutOptions(
  2002. ban_if_used_far_apart=config.ban_recompute_used_far_apart,
  2003. ban_if_long_fusible_chains=config.ban_recompute_long_fusible_chains,
  2004. ban_if_materialized_backward=config.ban_recompute_materialized_backward,
  2005. ban_if_not_in_allowlist=config.ban_recompute_not_in_allowlist,
  2006. ban_if_reduction=config.ban_recompute_reductions,
  2007. )
  2008. if config.aggressive_recomputation:
  2009. min_cut_options = replace(
  2010. min_cut_options,
  2011. ban_if_used_far_apart=False,
  2012. ban_if_long_fusible_chains=False,
  2013. ban_if_materialized_backward=False,
  2014. ban_if_not_in_allowlist=False,
  2015. )
  2016. if memory_budget == 0:
  2017. return node_info.inputs
  2018. runtime_optimized_saved_values, _ = solve_min_cut(
  2019. joint_graph,
  2020. node_info,
  2021. min_cut_options,
  2022. )
  2023. # return runtime_optimized_saved_values
  2024. if memory_budget == 1:
  2025. return runtime_optimized_saved_values
  2026. def estimate_activations_size(saved_values: list[fx.Node]) -> float:
  2027. return sum(map(_size_of, saved_values)) / 1e9
  2028. min_act_size = estimate_activations_size(node_info.inputs)
  2029. max_act_size = estimate_activations_size(runtime_optimized_saved_values)
  2030. # The optimized choice is smaller than the inputs anyways
  2031. if max_act_size <= min_act_size:
  2032. return runtime_optimized_saved_values
  2033. def get_normalized_size(sz):
  2034. return (sz / 1e9) / (max_act_size - min_act_size)
  2035. def get_mem_ratio(activations: list[fx.Node]):
  2036. return (estimate_activations_size(activations) - min_act_size) / (
  2037. max_act_size - min_act_size
  2038. )
  2039. more_aggressive_options = replace(
  2040. min_cut_options,
  2041. ban_if_used_far_apart=False,
  2042. ban_if_long_fusible_chains=False,
  2043. ban_if_materialized_backward=False,
  2044. )
  2045. more_aggressive_saved_values, _ = solve_min_cut(
  2046. joint_graph, node_info, more_aggressive_options
  2047. )
  2048. if get_mem_ratio(more_aggressive_saved_values) < memory_budget:
  2049. return more_aggressive_saved_values
  2050. aggressive_options = replace(
  2051. more_aggressive_options,
  2052. ban_if_not_in_allowlist=False,
  2053. )
  2054. aggressive_recomputation_saved_values, banned_nodes = solve_min_cut(
  2055. joint_graph, node_info, aggressive_options
  2056. )
  2057. if get_mem_ratio(aggressive_recomputation_saved_values) < memory_budget:
  2058. return aggressive_recomputation_saved_values
  2059. from torch._inductor.fx_utils import get_node_storage
  2060. input_storages = OrderedSet(get_node_storage(node) for node in node_info.inputs)
  2061. def get_recomputable_banned_nodes(
  2062. banned_nodes: OrderedSet[fx.Node],
  2063. ) -> list[fx.Node]:
  2064. return [
  2065. i
  2066. for i in banned_nodes
  2067. if (
  2068. # Only allow recomputing nodes that are actually required for BW
  2069. i.dist_from_bw < int(1e9) # type: ignore[attr-defined]
  2070. and get_node_storage(i) not in input_storages
  2071. )
  2072. ]
  2073. recomputable_banned_nodes = get_recomputable_banned_nodes(banned_nodes)
  2074. must_save_nodes = [
  2075. i
  2076. for i in recomputable_banned_nodes
  2077. if i.meta.get("recompute", False) == CheckpointPolicy.MUST_SAVE
  2078. ]
  2079. recomputable_banned_nodes = [
  2080. i for i in recomputable_banned_nodes if i not in must_save_nodes
  2081. ]
  2082. # default: runtime_optimized_saved_values
  2083. # more aggressive: more_aggressive_saved_values
  2084. # full aggressive: aggressive_recomputation_saved_values
  2085. all_recomputable_banned_nodes = sorted(
  2086. recomputable_banned_nodes, key=_size_of, reverse=True
  2087. )
  2088. if len(all_recomputable_banned_nodes) == 0:
  2089. return node_info.inputs + must_save_nodes
  2090. memories_banned_nodes = [
  2091. get_normalized_size(_size_of(i)) for i in all_recomputable_banned_nodes
  2092. ]
  2093. runtimes_banned_nodes = [
  2094. estimate_runtime(node) for node in all_recomputable_banned_nodes
  2095. ]
  2096. from torch.utils._mode_utils import no_dispatch
  2097. def get_saved_values_knapsack(memory_budget, node_info, joint_graph):
  2098. with no_dispatch():
  2099. (
  2100. expected_runtime,
  2101. saved_node_idxs,
  2102. recomputable_node_idxs,
  2103. ) = _optimize_runtime_with_given_memory(
  2104. joint_graph,
  2105. memories_banned_nodes,
  2106. runtimes_banned_nodes,
  2107. max(memory_budget, 0),
  2108. node_info,
  2109. all_recomputable_banned_nodes,
  2110. )
  2111. dont_ban: OrderedSet[fx.Node] = OrderedSet()
  2112. for idx in recomputable_node_idxs:
  2113. # if idx in all_recomputable_banned_nodes:
  2114. try:
  2115. dont_ban.add(all_recomputable_banned_nodes[idx])
  2116. except BaseException: # noqa: B036
  2117. pass
  2118. assert dont_ban.issubset(all_recomputable_banned_nodes)
  2119. saved_values, _ = solve_min_cut(
  2120. joint_graph,
  2121. node_info,
  2122. aggressive_options,
  2123. dont_ban,
  2124. )
  2125. if AOT_PARTITIONER_DEBUG:
  2126. create_structured_trace_for_min_cut_info(
  2127. joint_graph=joint_graph,
  2128. all_recomputable_banned_nodes=all_recomputable_banned_nodes,
  2129. saved_node_idxs=saved_node_idxs,
  2130. recomputable_node_idxs=recomputable_node_idxs,
  2131. expected_runtime=expected_runtime,
  2132. memories_banned_nodes=memories_banned_nodes,
  2133. runtimes_banned_nodes=runtimes_banned_nodes,
  2134. min_cut_saved_values=saved_values,
  2135. )
  2136. return saved_values, expected_runtime
  2137. if config.visualize_memory_budget_pareto:
  2138. def estimate_for_budget(b):
  2139. saved_values, expected_runtime = get_saved_values_knapsack(
  2140. b, node_info=node_info, joint_graph=joint_graph
  2141. )
  2142. return (
  2143. b,
  2144. sum(runtimes_banned_nodes) - expected_runtime,
  2145. get_mem_ratio(saved_values),
  2146. )
  2147. options = [estimate_for_budget(0.0), estimate_for_budget(1.0)]
  2148. if options[0][1:] != options[1][1:]:
  2149. bisects = [(options[0], options[1])]
  2150. while bisects:
  2151. lhs, rhs = bisects.pop()
  2152. if rhs[0] - lhs[0] < 1e-3:
  2153. options.append(lhs)
  2154. options.append(rhs)
  2155. continue
  2156. mid = estimate_for_budget((lhs[0] + rhs[0]) / 2)
  2157. if mid[1:] != lhs[1:]:
  2158. bisects.append((lhs, mid))
  2159. if mid[1:] != rhs[1:]:
  2160. bisects.append((mid, rhs))
  2161. options.sort()
  2162. import matplotlib.pyplot as plt
  2163. x_values = [item[2] for item in options]
  2164. y_values = [item[1] for item in options]
  2165. # Plotting the values with updated axis labels and chart title
  2166. plt.figure(figsize=(10, 6))
  2167. plt.plot(x_values, y_values, marker="o")
  2168. # Adding labels for each point
  2169. for i, txt in enumerate(x_values):
  2170. plt.annotate(
  2171. f"{txt:.4f}",
  2172. (txt, y_values[i]),
  2173. textcoords="offset points",
  2174. xytext=(0, 10),
  2175. ha="center",
  2176. )
  2177. plt.xlabel("Memory Budget")
  2178. plt.ylabel("Runtime of Recomputed Components")
  2179. plt.title("Pareto Frontier of Memory Budget vs. Recomputation Runtime")
  2180. plt.grid(True)
  2181. fig = plt.gcf()
  2182. plt.show()
  2183. fig_dir = os.getcwd()
  2184. if config.memory_budget_pareto_dir is not None:
  2185. fig_dir = config.memory_budget_pareto_dir
  2186. os.makedirs(fig_dir, exist_ok=True)
  2187. rank_suffix = ""
  2188. if torch.distributed.is_available() and torch.distributed.is_initialized():
  2189. rank_suffix = f"_rank_{torch.distributed.get_rank()}"
  2190. fig_name = os.path.join(
  2191. fig_dir, f"memory_budget_pareto{rank_suffix}_{get_aot_graph_name()}.svg"
  2192. )
  2193. fig.savefig(fig_name)
  2194. log.warning("Generated Pareto frontier curve at %s", fig_name)
  2195. # todo(chilli): Estimated doesn't align exactly with actual - actual is
  2196. # usually less memory than estimated. i'm guessing (actually quite
  2197. # unsure about this) that's because estimated is just only including
  2198. # tensors we actually banned from recompute, but there may be other
  2199. # tensors that we choose to save.
  2200. return get_saved_values_knapsack(
  2201. memory_budget=memory_budget, node_info=node_info, joint_graph=joint_graph
  2202. )[0]
  2203. def _sync_decision_cross_ranks(
  2204. joint_graph: torch.fx.Graph, saved_values: list[torch.fx.Node]
  2205. ):
  2206. # use the same policy across different GPUs
  2207. from torch._subclasses.fake_tensor import unset_fake_temporarily
  2208. def has_collectives(joint_graph):
  2209. for node in joint_graph.nodes:
  2210. if isinstance(
  2211. node.target, torch._ops.OpOverload
  2212. ) and node.target.namespace in {"_c10d_functional", "c10d_functional"}:
  2213. return True
  2214. return False
  2215. def has_same_nodes(joint_graph):
  2216. # proxy to check if the graph is the same across different GPUs.
  2217. # We only consider the name and order of nodes. A more robust way
  2218. # would be to check the hash of the whole graph (disregarding input shapes),
  2219. # this is is a reasonable first-order approximation.
  2220. node_str = "/".join(x.name for x in joint_graph.nodes)
  2221. inputs = hashlib.sha256(node_str.encode("utf-8")).hexdigest()
  2222. all_inputs = [None for _ in range(torch.distributed.get_world_size())]
  2223. with no_dispatch(), unset_fake_temporarily():
  2224. # TODO: maybe use a different process group?
  2225. torch.distributed.all_gather_object(all_inputs, inputs)
  2226. return all(all_inputs[0] == x for x in all_inputs)
  2227. if (
  2228. torch.distributed.is_available()
  2229. and torch.distributed.is_initialized()
  2230. and torch.distributed.get_world_size() > 1
  2231. and has_collectives(joint_graph)
  2232. and has_same_nodes(joint_graph)
  2233. ):
  2234. with no_dispatch(), unset_fake_temporarily():
  2235. objects = [[x.name for x in saved_values]]
  2236. saved_ops_names_all_ranks: list[list[str]] = [
  2237. [] for _ in range(torch.distributed.get_world_size())
  2238. ]
  2239. torch.distributed.all_gather_object(saved_ops_names_all_ranks, objects[0])
  2240. name_to_node = get_name_to_node(joint_graph)
  2241. saved_sizes: list[int] = []
  2242. saved_ops_with_sizes: dict[str, int] = {}
  2243. for idx, saved_ops_names in enumerate(saved_ops_names_all_ranks):
  2244. saved_nodes = [name_to_node[op_name] for op_name in saved_ops_names]
  2245. saved_size = 0
  2246. for node in saved_nodes:
  2247. size_of_node = _size_of(node)
  2248. saved_size += size_of_node
  2249. if idx == torch.distributed.get_rank():
  2250. saved_ops_with_sizes[node.name] = size_of_node
  2251. saved_ops_with_sizes["total size"] = saved_size
  2252. saved_sizes.append(saved_size)
  2253. saved_sizes_tensor = torch.tensor(
  2254. saved_sizes,
  2255. device=torch.distributed.distributed_c10d._get_object_coll_device(),
  2256. )
  2257. torch.distributed.all_reduce(
  2258. saved_sizes_tensor, op=torch.distributed.distributed_c10d.ReduceOp.MAX
  2259. )
  2260. picked_rank_idx = int(torch.argmin(saved_sizes_tensor).item())
  2261. sync_decision_cross_ranks_str = f"picked_rank_idx={picked_rank_idx}, saved_nodes of current rank={saved_ops_with_sizes}"
  2262. trace_structured(
  2263. "artifact",
  2264. metadata_fn=lambda: {
  2265. "name": "aot_joint_graph_sync_decision_cross_ranks",
  2266. "encoding": "string",
  2267. },
  2268. payload_fn=lambda: sync_decision_cross_ranks_str,
  2269. )
  2270. saved_values = [
  2271. name_to_node[n] for n in saved_ops_names_all_ranks[picked_rank_idx]
  2272. ]
  2273. return saved_values
  2274. def thread_graphsafe_rng_from_hops(module, is_backward):
  2275. """
  2276. Graph-safe RNG lets torch.compile use CUDA Graphs for graphs with RNG ops.
  2277. For graphs without HOPs, the partitioner adds placeholder nodes
  2278. fwd_rng_state_* and bw_rng_state_* to the forward and backward graphs. At
  2279. runtime, the AOTDispatcher retrieves these RNG states and passes them to the
  2280. compiled graphs.
  2281. This works well for no-HOP graphs. With HOPs, the partitioner runs
  2282. recursively: it first partitions the HOP (producing forward/backward HOP
  2283. subgraphs) and then stitches them back into the outer joint graph. For HOPs
  2284. that contain RNG ops, the outer joint graph now includes HOP subgraph
  2285. modules with extra RNG placeholders. We must thread these placeholders
  2286. through the outer module partitioned forward and backward graphs—this
  2287. function does exactly that. It collects the RNG placeholder nodes from the
  2288. HOPs and creates corresponding placeholders in the outer forward and
  2289. backward graphs.
  2290. There is a catch: for a short period, the joint graph is in a “bad” state.
  2291. The HOP subgraphs expect additional inputs (because of the new
  2292. placeholders), but the outer graph call sites don't yet provide them. We
  2293. can't fix this in the joint graph because the joint graph's input signature
  2294. is fixed (primals, tangents). As a compromise, we keep the joint graph in
  2295. somewhat of a bad state for some time and, once the outer forward and
  2296. backward graphs are partitioned, insert the corresponding RNG placeholders
  2297. and wire up the calls.
  2298. """
  2299. rng_count = 0
  2300. rng_string = "bwd_rng_state" if is_backward else "fwd_rng_state"
  2301. last_input = next(reversed(module.graph.find_nodes(op="placeholder")))
  2302. for hop_node in module.graph.find_nodes(
  2303. op="call_function", target=torch.ops.higher_order.invoke_subgraph
  2304. ):
  2305. subgraph = getattr(module, hop_node.args[0].target)
  2306. if isinstance(subgraph, fx.GraphModule):
  2307. new_rng_inputs = []
  2308. for idx, placeholder_node in enumerate(
  2309. subgraph.graph.find_nodes(op="placeholder")
  2310. ):
  2311. if rng_string in placeholder_node.name:
  2312. # Found a rng state placeholder in the hop graph, lets add
  2313. # the corresponding node in the outer graph
  2314. with module.graph.inserting_after(last_input):
  2315. rng_state = module.graph.placeholder(
  2316. f"{rng_string}_{rng_count}"
  2317. )
  2318. rng_count += 1
  2319. rng_state.meta["val"] = placeholder_node.meta["val"]
  2320. last_input = rng_state
  2321. new_rng_inputs.append(rng_state)
  2322. if new_rng_inputs:
  2323. # Pass on the new args that include the new_rng_inputs
  2324. with module.graph.inserting_after(hop_node):
  2325. new_hop_node_with_fixed_args = module.graph.create_node(
  2326. "call_function",
  2327. torch.ops.higher_order.invoke_subgraph,
  2328. (*hop_node.args, *new_rng_inputs), # type: ignore[arg-type]
  2329. {},
  2330. )
  2331. hop_node.replace_all_uses_with(
  2332. new_hop_node_with_fixed_args, propagate_meta=True
  2333. )
  2334. # Setup the eager_input_vals
  2335. eager_vals = hop_node.meta.get("eager_input_vals")
  2336. if eager_vals:
  2337. eager_args, eager_kwargs = eager_vals
  2338. new_eager_args = (
  2339. *eager_args,
  2340. *[inp.meta["val"] for inp in new_rng_inputs],
  2341. )
  2342. new_hop_node_with_fixed_args.meta["eager_input_vals"] = (
  2343. new_eager_args,
  2344. eager_kwargs,
  2345. )
  2346. module.graph.erase_node(hop_node)
  2347. return module
def min_cut_rematerialization_partition(
    joint_module: fx.GraphModule,
    _joint_inputs,
    compiler="inductor",
    *,
    num_fwd_outputs,
    static_lifetime_input_indices: Optional[list[int]] = None,
) -> tuple[fx.GraphModule, fx.GraphModule]:
    """
    Partitions the joint graph such that the backward recomputes the forward.
    Recomputing helps in trading off memory bandwidth with computation.

    To create the fwd and bwd graph, we copy the joint graph, manually set the
    outputs to just original forward or backward outputs. And then we run the
    resulting graphs through dead code elimination.

    .. warning::
        This API is experimental and likely to change.

    Args:
        joint_module(fx.GraphModule): The joint forward and backward graph. This
            is the result of AOT Autograd tracing. It is modified in place
            (dead code elimination, recompilation, optional CSE, and
            per-node ``dist_from_bw`` attributes).
        _joint_inputs: The inputs to the joint graph. Only forwarded to
            ``default_partition`` when there are no required backward nodes;
            otherwise unused here.
        compiler: Not read anywhere in this function's body.
        num_fwd_outputs: The number of outputs from the forward graph.
        static_lifetime_input_indices: Optional positions within the primal
            inputs; the primals at these positions are collected as
            ``static_lifetime_input_nodes``. ``None`` is treated as an empty
            list.

    Returns:
        Returns the generated forward and backward Fx graph modules.
    """

    joint_module.graph.eliminate_dead_code()
    joint_module.recompile()

    fx_g = joint_module.graph

    #  add the CSE pass
    if config.cse:
        cse_graph = fx_graph_cse(fx_g)
        joint_module.graph = cse_graph
    # Re-read the graph: it may have just been replaced by the CSE'd graph.
    joint_graph = joint_module.graph

    graph_has_recomputable_ops = has_recomputable_ops(joint_module)
    graph_has_recomputable_rng_ops = has_recomputable_rng_ops(joint_module)
    if graph_has_recomputable_ops:
        joint_module = cleanup_recompute_tags(joint_module)
    if not config.unsafe_allow_optimization_of_collectives:
        force_save_collectives(joint_module)
    force_save_bw_mutation_src(joint_module)

    def classify_nodes(joint_module, static_lifetime_input_indices):
        # Split the joint graph's nodes into required-forward,
        # required-backward and unclaimed sets, and bundle them in a NodeInfo.
        name_to_node = get_name_to_node(joint_module.graph)
        required_bw_nodes: OrderedSet[fx.Node] = OrderedSet()
        for node in joint_module.graph.nodes:
            if node.op == "placeholder" and "tangents" in node.target:
                required_bw_nodes.add(node)
            elif _must_be_in_backward(node):
                required_bw_nodes.add(node)

            # Anything consuming a required-backward node is itself
            # required in the backward.
            if node in required_bw_nodes:
                required_bw_nodes.update(node.users)

        primal_inputs = list(filter(_is_primal, joint_module.graph.nodes))
        fwd_seed_offset_inputs = list(
            filter(_is_fwd_seed_offset, joint_module.graph.nodes)
        )
        inputs = primal_inputs + fwd_seed_offset_inputs
        fwd_outputs, bwd_outputs, fwd_outputs_descs, bwd_outputs_descs = (
            _extract_fwd_bwd_outputs(joint_module, num_fwd_outputs=num_fwd_outputs)
        )
        required_bw_nodes.update(
            o for o in bwd_outputs if o is not None and o.op != "output"
        )
        forward_only_graph = _extract_graph_with_inputs_outputs(
            joint_module.graph, inputs, fwd_outputs, fwd_outputs_descs, "forward"
        )
        # Map the extracted forward graph's nodes back to the joint graph's
        # nodes by name.
        required_fw_nodes: OrderedSet[fx.Node] = OrderedSet(
            name_to_node[node.name]
            for node in forward_only_graph.nodes
            if node.op != "output"
        )
        unclaimed_nodes: OrderedSet[fx.Node] = OrderedSet(
            node
            for node in joint_module.graph.nodes
            if node not in required_fw_nodes and node not in required_bw_nodes
        )
        static_lifetime_input_nodes = OrderedSet(
            p for i, p in enumerate(primal_inputs) if i in static_lifetime_input_indices
        )
        # Position of each required-forward node in joint-graph order.
        fw_cnt = 0
        fw_order = {}
        for node in joint_module.graph.nodes:
            if node in required_fw_nodes:
                fw_order[node] = fw_cnt
                fw_cnt += 1
        return NodeInfo(
            inputs,
            required_fw_nodes,
            required_bw_nodes,
            unclaimed_nodes,
            fw_order,
            static_lifetime_input_nodes,
        )

    if static_lifetime_input_indices is None:
        static_lifetime_input_indices = []
    node_info = classify_nodes(joint_module, static_lifetime_input_indices)

    # networkx blows up on graphs with no required backward nodes
    # Since there's nothing to partition anyway, and the default partitioner can "handle"
    # this case, send our graph over to the default partitioner.
    if len(node_info.required_bw_nodes) == 0:
        return default_partition(
            joint_module,
            _joint_inputs,
            num_fwd_outputs=num_fwd_outputs,
            static_lifetime_input_indices=static_lifetime_input_indices,
            static_lifetime_input_nodes=node_info.static_lifetime_input_nodes,
        )

    # Annotate every node with dist_from_bw. Nodes are visited in reverse so
    # that a node's users have already been annotated when it is reached.
    # int(1e9) is the starting value for the output node and for
    # required-forward nodes; nodes that are not required-forward get 0.
    for node in reversed(joint_module.graph.nodes):
        if node.op == "output":
            node.dist_from_bw = int(1e9)
        elif not node_info.is_required_fw(node):
            node.dist_from_bw = 0
        else:
            node.dist_from_bw = int(1e9)
            for user in node.users:
                node.dist_from_bw = min(node.dist_from_bw, user.dist_from_bw + 1)

    # The first node whose meta carries a float "memory_budget" overrides the
    # configured budget.
    memory_budget = config.activation_memory_budget
    for node in joint_graph.nodes:
        if isinstance(node.meta.get("memory_budget", None), float):
            memory_budget = node.meta["memory_budget"]
            break
    saved_values = choose_saved_values_set(
        joint_graph,
        node_info,
        memory_budget=memory_budget,
    )
    if config._sync_decision_cross_ranks:
        saved_values = _sync_decision_cross_ranks(joint_graph, saved_values)
    # save_for_backward on tensors and stashes symints in autograd .ctx
    saved_sym_nodes = list(filter(is_sym_node, saved_values))
    saved_values = list(filter(lambda n: not is_sym_node(n), saved_values))

    # NB: saved_sym_nodes will be mutated to reflect the actual saved symbols
    fw_module, bw_module = _extract_fwd_bwd_modules(
        joint_module,
        saved_values,
        saved_sym_nodes=saved_sym_nodes,
        num_fwd_outputs=num_fwd_outputs,
        static_lifetime_input_nodes=node_info.static_lifetime_input_nodes,
    )
    if graph_has_recomputable_ops:
        if graph_has_recomputable_rng_ops:
            fw_module, bw_module = functionalize_rng_ops(
                joint_module, fw_module, bw_module, len(saved_sym_nodes)
            )
    bw_module = reordering_to_mimic_autograd_engine(bw_module)

    # raise all getitem ops to as early as possible
    # this is helpful for memory, especially in the case of aot_eager backend
    fw_module = raise_getitems(fw_module)
    bw_module = raise_getitems(bw_module)

    fw_module = thread_graphsafe_rng_from_hops(fw_module, is_backward=False)
    bw_module = thread_graphsafe_rng_from_hops(bw_module, is_backward=True)

    if AOT_PARTITIONER_DEBUG:
        # Calculate sorted sizes of saved values
        sorted_sizes = sorted([(_size_of(i), str(i)) for i in saved_values])

        # Log total theoretical activations stored
        total_activations_size_gb = sum(_size_of(i) for i in saved_values) / 1e9
        log.info("Theoretical Activations Stored: %.2f GB", total_activations_size_gb)

        # Log theoretical per activation storage sizes
        log.info("Theoretical Per Activation Storage Sizes: %s", sorted_sizes)
        fw_module_nodes = OrderedSet(
            node.name for node in fw_module.graph.nodes if node.op == "call_function"
        )
        bw_module_nodes = OrderedSet(
            node.name for node in bw_module.graph.nodes if node.op == "call_function"
        )
        # call_function node names present in both graphs are counted as
        # rematerialized.
        remat_nodes = fw_module_nodes & bw_module_nodes

        counts: dict[str, int] = defaultdict(int)
        for node in fw_module.graph.nodes:
            if node.name in remat_nodes and hasattr(node.target, "_overloadpacket"):
                counts[str(node.target._overloadpacket)] += 1
        log.info(
            "# remat/fw/bw: %d/%d/%d",
            len(remat_nodes),
            len(fw_module_nodes),
            len(bw_module_nodes),
        )
        rematerialized_ops = sorted(
            counts.items(), key=operator.itemgetter(1), reverse=True
        )
        log.info("Count of Ops Rematerialized: %s", rematerialized_ops)
    return fw_module, bw_module
  2531. def draw_graph(
  2532. traced: torch.fx.GraphModule,
  2533. fname: str,
  2534. figname: str = "fx_graph",
  2535. clear_meta: bool = True,
  2536. prog: Optional[Union[str, list[str]]] = None,
  2537. parse_stack_trace: bool = False,
  2538. dot_graph_shape: Optional[str] = None,
  2539. ) -> None:
  2540. if clear_meta:
  2541. new_graph = copy.deepcopy(traced.graph)
  2542. traced = fx.GraphModule(traced, new_graph)
  2543. for node in traced.graph.nodes:
  2544. node.meta = {}
  2545. base, ext = os.path.splitext(fname)
  2546. if not ext:
  2547. ext = "." + config.torch_compile_graph_format
  2548. log.info("Writing FX graph to file: %s%s", base, ext)
  2549. g = graph_drawer.FxGraphDrawer(
  2550. traced,
  2551. figname,
  2552. parse_stack_trace=parse_stack_trace,
  2553. dot_graph_shape=dot_graph_shape,
  2554. )
  2555. x = g.get_main_dot_graph()
  2556. write_method = getattr(x, "write_" + ext.lstrip("."))
  2557. fname = f"{base}{ext}"
  2558. if prog is None:
  2559. write_method(fname)
  2560. else:
  2561. write_method(fname, prog=prog)