checkpoint.py 67 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599
  1. # mypy: allow-untyped-defs
  2. import contextlib
  3. import platform
  4. import uuid
  5. import warnings
  6. import weakref
  7. from collections import defaultdict
  8. from typing import * # noqa: F403
  9. import enum
  10. from weakref import ReferenceType
  11. import torch
  12. import torch.fx.traceback as fx_traceback
  13. from torch.utils._pytree import tree_map
  14. from torch.testing._internal.logging_tensor import capture_logs, LoggingTensorMode
  15. from torch.utils._python_dispatch import TorchDispatchMode
# Public API of this module: the names exported by a star-import.
__all__ = [
    "checkpoint",
    "checkpoint_sequential",
    "CheckpointError",
    "CheckpointFunction",
    "check_backward_validity",
    "detach_variable",
    "get_device_states",
    "set_device_states",
    "noop_context_fn",
    "set_checkpoint_early_stop",
    "DefaultDeviceType",
    "set_checkpoint_debug_enabled",
    "CheckpointPolicy",
    "SelectiveCheckpointContext",
    "create_selective_checkpoint_contexts",
    "SAC_IGNORED_OPS",
]

# Default value of the ``determinism_check`` argument of ``checkpoint``.
_DEFAULT_DETERMINISM_MODE = "default"

# Global override of checkpoint's ``debug`` flag, managed by
# ``set_checkpoint_debug_enabled``. ``None`` means "defer to the ``debug``
# argument passed to each checkpoint call".
_checkpoint_debug_enabled: Optional[bool] = None
  36. @contextlib.contextmanager
  37. def set_checkpoint_debug_enabled(enabled: Optional[bool]):
  38. """
  39. Context manager that sets whether checkpoint should print additional debug
  40. information when running. See the ``debug`` flag for
  41. :func:`~torch.utils.checkpoint.checkpoint` for more information. Note that
  42. when set, this context manager overrides the value of ``debug`` passed to
  43. checkpoint. To defer to the local setting, pass ``None`` to this context.
  44. Args:
  45. enabled (bool): Whether checkpoint should print debug information.
  46. Default is 'None'.
  47. """
  48. global _checkpoint_debug_enabled
  49. try:
  50. prev = _checkpoint_debug_enabled
  51. _checkpoint_debug_enabled = enabled
  52. yield
  53. finally:
  54. _checkpoint_debug_enabled = prev
  55. def detach_variable(inputs: Tuple[Any, ...]) -> Tuple[torch.Tensor, ...]:
  56. if isinstance(inputs, tuple):
  57. out = []
  58. for inp in inputs:
  59. if not isinstance(inp, torch.Tensor):
  60. out.append(inp)
  61. continue
  62. x = inp.detach()
  63. x.requires_grad = inp.requires_grad
  64. out.append(x)
  65. return tuple(out)
  66. else:
  67. raise RuntimeError(
  68. "Only tuple of tensors is supported. Got Unsupported input type: ",
  69. type(inputs).__name__,
  70. )
  71. def check_backward_validity(inputs: Iterable[Any]) -> None:
  72. if not any(inp.requires_grad for inp in inputs if isinstance(inp, torch.Tensor)):
  73. warnings.warn(
  74. "None of the inputs have requires_grad=True. Gradients will be None"
  75. )
  76. def _get_device_module(device="cuda"):
  77. if device == "meta":
  78. return torch.device("meta")
  79. device_module = getattr(torch, device)
  80. return device_module
  81. class DefaultDeviceType:
  82. r"""
  83. A class that manages the default device type for checkpointing.
  84. If no non-CPU tensors are present, the default device type will
  85. be used. The default value is 'cuda'. The device type is used in
  86. the checkpointing process when determining which device states
  87. to save and restore for recomputation.
  88. """
  89. _default_device_type = "cuda"
  90. @staticmethod
  91. def set_device_type(device: str = "cuda"):
  92. """
  93. Set the default device type for checkpointing.
  94. Args:
  95. device (str): The device type to be set as default. Default is 'cuda'.
  96. """
  97. DefaultDeviceType._default_device_type = device
  98. @staticmethod
  99. def get_device_type() -> str:
  100. """
  101. Get the current default device type for checkpointing.
  102. Returns:
  103. str: The current default device type.
  104. """
  105. return DefaultDeviceType._default_device_type
  106. def _infer_device_type(*args):
  107. device_types = []
  108. def add_device_types(arg):
  109. nonlocal device_types
  110. if isinstance(arg, torch.Tensor) and not arg.device.type == "cpu":
  111. device_types.append(arg.device.type)
  112. tree_map(add_device_types, args)
  113. device_types_set = set(device_types)
  114. if len(device_types_set) > 1:
  115. warnings.warn(
  116. "Tensor arguments, excluding CPU tensors, are detected on at least two types of devices. "
  117. "Device state will only be saved for devices of a single device type, and the remaining "
  118. "devices will be ignored. Consequently, if any checkpointed functions involve randomness, "
  119. "this may result in incorrect gradients. (Note that if CUDA devices are among the devices "
  120. "detected, it will be prioritized; otherwise, the first device encountered will be selected.)"
  121. f"\nDevice types: {sorted(device_types_set)} first device type: {device_types[0]}"
  122. )
  123. if len(device_types) == 0:
  124. return DefaultDeviceType.get_device_type()
  125. elif "cuda" in device_types_set:
  126. return "cuda"
  127. else:
  128. return device_types[0]
# We can't know if the run_fn will internally move some args to different devices,
# which would require logic to preserve rng states for those devices as well.
# We could paranoidly stash and restore ALL the rng states for all visible devices,
# but that seems very wasteful for most cases. Compromise: stash the RNG state only
# for the devices that the Tensor args live on.
#
# To consider: maybe get_device_states and set_device_states should reside in torch/random.py?
  136. def get_device_states(*args) -> Tuple[List[int], List[torch.Tensor]]:
  137. # This will not error out if "arg" is a CPU tensor or a non-tensor type because
  138. # the conditionals short-circuit.
  139. fwd_device_ids = []
  140. def add_device_ids(arg):
  141. nonlocal fwd_device_ids
  142. if isinstance(arg, torch.Tensor) and arg.device.type not in {"cpu", "meta"}:
  143. fwd_device_ids.append(arg.get_device())
  144. tree_map(add_device_ids, args)
  145. fwd_device_states = []
  146. device_module = _get_device_module(_infer_device_type(*args))
  147. for device_id in fwd_device_ids:
  148. with device_module.device(device_id):
  149. fwd_device_states.append(device_module.get_rng_state())
  150. return fwd_device_ids, fwd_device_states
  151. def set_device_states(devices, states, *, device_type=None) -> None:
  152. """Sets random number generator states for the specified devices.
  153. Args:
  154. devices: Device ids to set states for.
  155. states: States to set.
  156. device_type: ``device_type`` of the devices to set states for. Default
  157. is the device returned by a call to ``DefaultDeviceType.get_device_type()``,
  158. which is ``cuda`` if not changed by calling ``DefaultDeviceType::set_device_type()``.
  159. """
  160. if device_type is None:
  161. device_type = DefaultDeviceType.get_device_type()
  162. if device_type == "meta":
  163. return
  164. device_module = _get_device_module(device_type)
  165. for device, state in zip(devices, states):
  166. with device_module.device(device):
  167. device_module.set_rng_state(state)
  168. def _get_autocast_kwargs(device_type="cuda"):
  169. if torch.amp.is_autocast_available(device_type):
  170. device_autocast_kwargs = {
  171. "enabled": torch.is_autocast_enabled(device_type),
  172. "dtype": torch.get_autocast_dtype(device_type),
  173. "cache_enabled": torch.is_autocast_cache_enabled(),
  174. }
  175. else:
  176. device_autocast_kwargs = None
  177. cpu_autocast_kwargs = {
  178. "enabled": torch.is_autocast_enabled('cpu'),
  179. "dtype": torch.get_autocast_dtype('cpu'),
  180. "cache_enabled": torch.is_autocast_cache_enabled(),
  181. }
  182. return device_autocast_kwargs, cpu_autocast_kwargs
class CheckpointFunction(torch.autograd.Function):
    """Autograd Function implementing the reentrant (``use_reentrant=True``) checkpoint.

    ``forward`` runs ``run_function`` under :func:`torch.no_grad`, so no
    intermediate activations are kept; only the tensor inputs are saved.
    ``backward`` re-runs ``run_function`` with grad enabled on detached copies
    of those inputs, replaying the forward-time RNG and autocast state. It then
    backpropagates through the recomputed graph to obtain input gradients.
    """

    @staticmethod
    def forward(ctx, run_function, preserve_rng_state, *args):
        """Run ``run_function(*args)`` without recording an autograd graph.

        Args:
            ctx: Autograd context used to stash state for ``backward``.
            run_function: The callable being checkpointed.
            preserve_rng_state: If truthy, stash the CPU RNG state and, when
                the device module reports ``_initialized``, the per-device RNG
                states, so that ``backward`` can replay the same randomness.
            *args: Positional inputs. Tensors are saved with
                ``save_for_backward``; non-tensors are kept on ``ctx``.

        Returns:
            Whatever ``run_function`` returns.
        """
        check_backward_validity(args)
        ctx.run_function = run_function
        ctx.preserve_rng_state = preserve_rng_state
        # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
        ctx.device_type = _infer_device_type(*args)
        ctx.device_autocast_kwargs, ctx.cpu_autocast_kwargs = _get_autocast_kwargs(
            ctx.device_type
        )
        if preserve_rng_state:
            ctx.fwd_cpu_state = torch.get_rng_state()
            # Don't eagerly initialize the cuda context by accident.
            # (If the user intends that the context is initialized later, within their
            # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
            # we have no way to anticipate this will happen before we run the function.)
            ctx.had_device_in_fwd = False
            device_module = _get_device_module(ctx.device_type)
            if getattr(device_module, "_initialized", False):
                ctx.had_device_in_fwd = True
                ctx.fwd_devices, ctx.fwd_device_states = get_device_states(*args)

        # Save non-tensor inputs in ctx, keep a placeholder None for tensors
        # to be filled out during the backward.
        ctx.inputs = []
        ctx.tensor_indices = []
        tensor_inputs = []
        for i, arg in enumerate(args):
            if torch.is_tensor(arg):
                tensor_inputs.append(arg)
                ctx.tensor_indices.append(i)
                ctx.inputs.append(None)
            else:
                ctx.inputs.append(arg)

        ctx.save_for_backward(*tensor_inputs)

        # No graph is recorded here; activations are recomputed in backward.
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        """Recompute the forward pass and backpropagate through it.

        Args:
            ctx: Context populated by ``forward``.
            *args: Incoming gradients, positionally matched with the outputs
                of the recomputed ``run_function``.

        Returns:
            ``(None, None)`` for ``run_function`` and ``preserve_rng_state``,
            followed by the ``.grad`` of each detached tensor input (``None``
            for non-tensor inputs).

        Raises:
            RuntimeError: If reentrant checkpointing is not valid for the
                current backward call, or if no recomputed output requires grad.
        """
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError(
                "When use_reentrant=True, torch.utils.checkpoint is incompatible"
                " with .grad() or passing an `inputs` parameter to .backward()."
                " To resolve this error, you can either set use_reentrant=False,"
                " or call .backward() without passing the `inputs` argument."
            )
        # Copy the list to avoid modifying original list.
        inputs = list(ctx.inputs)
        tensor_indices = ctx.tensor_indices
        tensors = ctx.saved_tensors

        # Fill in inputs with appropriate saved tensors.
        for i, idx in enumerate(tensor_indices):
            inputs[idx] = tensors[i]

        # Stash the surrounding rng state, and mimic the state that was
        # present at this time during forward. Restore the surrounding state
        # when we're done.
        rng_devices = []
        if ctx.preserve_rng_state and ctx.had_device_in_fwd:
            rng_devices = ctx.fwd_devices
        with torch.random.fork_rng(
            devices=rng_devices, enabled=ctx.preserve_rng_state, device_type=ctx.device_type
        ):
            if ctx.preserve_rng_state:
                torch.set_rng_state(ctx.fwd_cpu_state)
                if ctx.had_device_in_fwd:
                    set_device_states(ctx.fwd_devices, ctx.fwd_device_states, device_type=ctx.device_type)
            # Detached leaves make the recomputed graph start at these inputs,
            # so their ``.grad`` can be collected after the inner backward.
            detached_inputs = detach_variable(tuple(inputs))

            # Re-enter the autocast configuration captured during forward.
            device_autocast_ctx = torch.amp.autocast(
                device_type=ctx.device_type, **ctx.device_autocast_kwargs
            ) if torch.amp.is_autocast_available(ctx.device_type) else contextlib.nullcontext()
            with torch.enable_grad(), device_autocast_ctx, torch.amp.autocast("cpu", **ctx.cpu_autocast_kwargs):  # type: ignore[attr-defined]
                outputs = ctx.run_function(*detached_inputs)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)

        # run backward() with only tensor that requires grad
        outputs_with_grad = []
        args_with_grad = []
        for i in range(len(outputs)):
            if torch.is_tensor(outputs[i]) and outputs[i].requires_grad:
                outputs_with_grad.append(outputs[i])
                # Pair each grad-requiring output with its incoming gradient.
                args_with_grad.append(args[i])
        if len(outputs_with_grad) == 0:
            raise RuntimeError(
                "none of output has requires_grad=True,"
                " this checkpoint() is not necessary"
            )
        torch.autograd.backward(outputs_with_grad, args_with_grad)
        grads = tuple(
            inp.grad if isinstance(inp, torch.Tensor) else None
            for inp in detached_inputs
        )

        return (None, None) + grads
  276. def noop_context_fn():
  277. return contextlib.nullcontext(), contextlib.nullcontext()
# Note: [torch.compile and checkpoint]
# TorchDynamo does not step inside the utils.checkpoint function. The flow
# looks like this:
#  1) TorchDynamo tries to wrap utils.checkpoint in a HigherOrderOp by
#     speculatively checking if the forward function is safe to trace.
#  2) If yes, then the Dynamo-generated FX graph has the wrapped higher
#     order op. As a result, TorchDynamo does not look inside utils.checkpoint.
#  3) If not, then TorchDynamo falls back to eager by performing a graph
#     break. In that case, the following disable wrapper ensures that
#     TorchDynamo does not trigger again on the frames created by
#     utils.checkpoint innards.
  289. @torch._disable_dynamo
  290. def checkpoint(
  291. function,
  292. *args,
  293. use_reentrant: Optional[bool] = None,
  294. context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
  295. determinism_check: str = _DEFAULT_DETERMINISM_MODE,
  296. debug: bool = False,
  297. early_stop: bool = True,
  298. **kwargs
  299. ):
  300. r"""Checkpoint a model or part of the model.
  301. Activation checkpointing is a technique that trades compute for memory.
  302. Instead of keeping tensors needed for backward alive until they are used in
  303. gradient computation during backward, forward computation in checkpointed
  304. regions omits saving tensors for backward and recomputes them during the
  305. backward pass. Activation checkpointing can be applied to any part of a
  306. model.
  307. There are currently two checkpointing implementations available, determined
  308. by the :attr:`use_reentrant` parameter. It is recommended that you use
  309. ``use_reentrant=False``. Please refer the note below for a discussion of
  310. their differences.
  311. .. warning::
  312. If the :attr:`function` invocation during the backward pass differs
  313. from the forward pass, e.g., due to a global variable, the checkpointed
  314. version may not be equivalent, potentially causing an
  315. error being raised or leading to silently incorrect gradients.
  316. .. warning::
  317. The ``use_reentrant`` parameter should be passed explicitly. In version
  318. 2.9 we will raise an exception if ``use_reentrant`` is not passed.
  319. If you are using the ``use_reentrant=True`` variant, please refer to the
  320. note below for important considerations and potential limitations.
  321. .. note::
  322. The reentrant variant of checkpoint (``use_reentrant=True``) and
  323. the non-reentrant variant of checkpoint (``use_reentrant=False``)
  324. differ in the following ways:
  325. * Non-reentrant checkpoint stops recomputation as soon as all needed
  326. intermediate activations have been recomputed. This feature is enabled
  327. by default, but can be disabled with :func:`set_checkpoint_early_stop`.
  328. Reentrant checkpoint always recomputes :attr:`function` in its
  329. entirety during the backward pass.
  330. * The reentrant variant does not record the autograd graph during the
  331. forward pass, as it runs with the forward pass under
  332. :func:`torch.no_grad`. The non-reentrant version does record the
  333. autograd graph, allowing one to perform backward on the graph within
  334. checkpointed regions.
  335. * The reentrant checkpoint only supports the
  336. :func:`torch.autograd.backward` API for the backward pass without its
  337. `inputs` argument, while the non-reentrant version supports all ways
  338. of performing the backward pass.
  339. * At least one input and output must have ``requires_grad=True`` for the
  340. reentrant variant. If this condition is unmet, the checkpointed part
  341. of the model will not have gradients. The non-reentrant version does
  342. not have this requirement.
  343. * The reentrant version does not consider tensors in nested structures
  344. (e.g., custom objects, lists, dicts, etc) as participating in
  345. autograd, while the non-reentrant version does.
  346. * The reentrant checkpoint does not support checkpointed regions with
  347. detached tensors from the computational graph, whereas the
  348. non-reentrant version does. For the reentrant variant, if the
  349. checkpointed segment contains tensors detached using ``detach()`` or
  350. with :func:`torch.no_grad`, the backward pass will raise an error.
  351. This is because ``checkpoint`` makes all the outputs require gradients
  352. and this causes issues when a tensor is defined to have no gradient in
  353. the model. To avoid this, detach the tensors outside of the
  354. ``checkpoint`` function.
  355. Args:
  356. function: describes what to run in the forward pass of the model or
  357. part of the model. It should also know how to handle the inputs
  358. passed as the tuple. For example, in LSTM, if user passes
  359. ``(activation, hidden)``, :attr:`function` should correctly use the
  360. first input as ``activation`` and the second input as ``hidden``
  361. args: tuple containing inputs to the :attr:`function`
  362. Keyword args:
  363. preserve_rng_state(bool, optional): Omit stashing and restoring
  364. the RNG state during each checkpoint. Note that under torch.compile,
  365. this flag doesn't take effect and we always preserve RNG state.
  366. Default: ``True``
  367. use_reentrant(bool):
  368. specify whether to use the activation checkpoint variant that
  369. requires reentrant autograd. This parameter should be passed
  370. explicitly. In version 2.9 we will raise an exception if
  371. ``use_reentrant`` is not passed. If ``use_reentrant=False``,
  372. ``checkpoint`` will use an implementation that does not require
  373. reentrant autograd. This allows ``checkpoint`` to support additional
  374. functionality, such as working as expected with
  375. ``torch.autograd.grad`` and support for keyword arguments input into
  376. the checkpointed function.
  377. context_fn(Callable, optional): A callable returning a tuple of two
  378. context managers. The function and its recomputation will be run
  379. under the first and second context managers respectively.
  380. This argument is only supported if ``use_reentrant=False``.
  381. determinism_check(str, optional): A string specifying the determinism
  382. check to perform. By default it is set to ``"default"`` which
  383. compares the shapes, dtypes, and devices of the recomputed tensors
  384. against those the saved tensors. To turn off this check, specify
  385. ``"none"``. Currently these are the only two supported values.
  386. Please open an issue if you would like to see more determinism
  387. checks. This argument is only supported if ``use_reentrant=False``,
  388. if ``use_reentrant=True``, the determinism check is always disabled.
  389. debug(bool, optional): If ``True``, error messages will also include
  390. a trace of the operators ran during the original forward computation
  391. as well as the recomputation. This argument is only supported if
  392. ``use_reentrant=False``.
  393. early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
  394. recomputation as soon as it has computed all needed Tensors. This
  395. argument is ignored if ``use_reentrant=True``. Can be overridden
  396. globally using :func:`set_checkpoint_early_stop` context manager.
  397. Default: ``True``.
  398. Returns:
  399. Output of running :attr:`function` on :attr:`*args`
  400. """
  401. if use_reentrant is None:
  402. warnings.warn(
  403. "torch.utils.checkpoint: the use_reentrant parameter should be "
  404. "passed explicitly. Starting in PyTorch 2.9, calling checkpoint "
  405. "without use_reentrant will raise an exception. use_reentrant=False is "
  406. "recommended, but if you need to preserve the current default "
  407. "behavior, you can pass use_reentrant=True. Refer to docs for more "
  408. "details on the differences between the two variants.",
  409. stacklevel=2
  410. )
  411. use_reentrant = True
  412. # Hack to mix *args with **kwargs in a python 2.7-compliant way
  413. preserve = kwargs.pop("preserve_rng_state", True)
  414. if kwargs and use_reentrant:
  415. raise ValueError(
  416. "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
  417. )
  418. if use_reentrant:
  419. if context_fn is not noop_context_fn or debug is not False:
  420. raise ValueError(
  421. "Passing `context_fn` or `debug` is only supported when "
  422. "use_reentrant=False."
  423. )
  424. return CheckpointFunction.apply(function, preserve, *args)
  425. else:
  426. gen = _checkpoint_without_reentrant_generator(
  427. function, preserve, context_fn, determinism_check, debug, early_stop, *args, **kwargs
  428. )
  429. # Runs pre-forward logic
  430. next(gen)
  431. ret = function(*args, **kwargs)
  432. # Runs post-forward logic
  433. try:
  434. next(gen)
  435. except StopIteration:
  436. return ret
  437. def checkpoint_sequential(functions, segments, input, use_reentrant=None, **kwargs):
  438. r"""Checkpoint a sequential model to save memory.
  439. Sequential models execute a list of modules/functions in order
  440. (sequentially). Therefore, we can divide such a model in various segments
  441. and checkpoint each segment. All segments except the last will not store
  442. the intermediate activations. The inputs of each checkpointed segment will
  443. be saved for re-running the segment in the backward pass.
  444. .. warning::
  445. The ``use_reentrant`` parameter should be passed explicitly. In version
  446. 2.9 we will raise an exception if ``use_reentrant`` is not passed.
  447. If you are using the ``use_reentrant=True` variant, please see
  448. :func:`~torch.utils.checkpoint.checkpoint` for
  449. the important considerations and limitations of this variant. It is
  450. recommended that you use ``use_reentrant=False``.
  451. .. warning:
  452. Since PyTorch 1.4, it allows only one Tensor as the input and
  453. intermediate outputs, just like :class:`torch.nn.Sequential`.
  454. Args:
  455. functions: A :class:`torch.nn.Sequential` or the list of modules or
  456. functions (comprising the model) to run sequentially.
  457. segments: Number of chunks to create in the model
  458. input: A Tensor that is input to :attr:`functions`
  459. preserve_rng_state(bool, optional): Omit stashing and restoring
  460. the RNG state during each checkpoint.
  461. Default: ``True``
  462. use_reentrant(bool):
  463. specify whether to use the activation checkpoint variant that
  464. requires reentrant autograd. This parameter should be passed
  465. explicitly. In version 2.5 we will raise an exception if
  466. ``use_reentrant`` is not passed. If ``use_reentrant=False``,
  467. ``checkpoint`` will use an implementation that does not require
  468. reentrant autograd. This allows ``checkpoint`` to support additional
  469. functionality, such as working as expected with
  470. ``torch.autograd.grad`` and support for keyword arguments input into
  471. the checkpointed function.
  472. Returns:
  473. Output of running :attr:`functions` sequentially on :attr:`*inputs`
  474. Example:
  475. >>> # xdoctest: +SKIP("stub")
  476. >>> model = nn.Sequential(...)
  477. >>> input_var = checkpoint_sequential(model, chunks, input_var)
  478. """
  479. if use_reentrant is None:
  480. warnings.warn(
  481. "torch.utils.checkpoint.checkpoint_sequential: the use_reentrant "
  482. "parameter should be passed explicitly. "
  483. "In version 2.9 we will raise an exception if use_reentrant "
  484. "is not passed. use_reentrant=False is "
  485. "recommended, but if you need to preserve the current default "
  486. "behavior, you can pass use_reentrant=True. Refer to docs for more "
  487. "details on the differences between the two variants."
  488. )
  489. use_reentrant = True
  490. # Hack for keyword-only parameter in a python 2.7-compliant way
  491. preserve = kwargs.pop("preserve_rng_state", True)
  492. if kwargs:
  493. raise ValueError(
  494. "Unexpected keyword arguments: " + ",".join(arg for arg in kwargs)
  495. )
  496. def run_function(start, end, functions):
  497. def forward(input):
  498. for j in range(start, end + 1):
  499. input = functions[j](input)
  500. return input
  501. return forward
  502. if isinstance(functions, torch.nn.Sequential):
  503. functions = list(functions.children())
  504. segment_size = len(functions) // segments
  505. # the last chunk has to be non-volatile
  506. end = -1
  507. for start in range(0, segment_size * (segments - 1), segment_size):
  508. end = start + segment_size - 1
  509. input = checkpoint(
  510. run_function(start, end, functions),
  511. input,
  512. use_reentrant=use_reentrant,
  513. preserve_rng_state=preserve,
  514. )
  515. return run_function(end + 1, len(functions) - 1, functions)(input)
  516. def _internal_assert(cond):
  517. if not cond:
  518. raise AssertionError(
  519. "Something went unexpectedly wrong in activation checkpoint. "
  520. "Please report this bug by filing an issue to PyTorch."
  521. )
  522. # NOTE [ Nestable Checkpoint ]
  523. #
  524. # The semantics of nested checkpoint can be defined by two basic rules.
  525. # Following the two rules leads to an important implication that is central
  526. # to motivating the design.
  527. #
  528. # Rule 1. Saved tensors are managed by inner-most checkpoint only and hidden
  529. # from any outer layers of checkpoint.
  530. #
  531. # Rule 2. The inputs of inner checkpoints are treated as tensors saved to its
  532. # parent checkpoint.
  533. #
  534. # Implication: To recompute any given saved tensor, we need to recompute all of
  535. # the checkpoints wrapping it.
  536. #
  537. # Why is this implied? To unpack a saved tensor X during backward we need to
  538. # recompute the inner-most checkpoint (#1), and in order to recompute that
  539. # checkpoint I need to have its inputs, which are managed by that checkpoint's
  540. # parent (#2), which thus also needs to be recomputed first. Continue this line
  541. # of reasoning and we realize that in order to unpack X, all checkpoints that
  542. # were active at the time X was saved need to be recomputed. (unless we have
  543. # already done so in that backward for some other saved tensor).
  544. #
  545. # In practice, we use a noop autograd Function to save inputs as saved tensors.
  546. # During unpack calling ctx.saved_tensor triggers the parent checkpoint to
  547. # recompute.
  548. #
  549. # Rule 3. We should start recomputation as if there are no checkpoints currently
  550. # active. Checkpoints encountered during recomputation are still
  551. # respected.
  552. #
  553. # When we start recomputation, we push the saved variable hook meant for
  554. # recomputation on the stack. See examples in Rule 6 for more context.
  555. #
  556. # * * * *
  557. #
  558. # Beyond the basic semantics specific to nested checkpoint, we impose several
  559. # more constraints that may apply to checkpointing in general.
  560. #
  561. # Rule 4. Lifetime of recomputed tensors
  562. #
  563. # Recomputed tensors are considered specific to particular invocations
# of backward and are always cleared immediately as they are unpacked.
  565. # Particularly, we require this to happen even if retain_graph=True.
  566. #
  567. # [ Implementation details of Rule 4 ]
  568. #
  569. # If we were okay with recomputed tensors staying alive after backward is run
  570. # with retain_graph=True, we would store recomputed variables as the values of a
  571. # WeakKeyDictionary and pack strong references to the keys, so that as we
  572. # backward, those packed keys would be cleared as long as retain_graph=False.
  573. # Clearing the packed key clears the corresponding entry in the WKD.
  574. #
  575. # If we wish recomputed variables to be immediately cleared as we unpack them in
  576. # the retain_graph=True case, we cannot rely on the packed keys to be cleared by
  577. # backward automatically. Instead of packing the strong reference to the key
  578. # directly, we pack a container object, which we manually clear as we unpack.
  579. #
  580. # An important detail is that if a second backward happens, the second
  581. # recomputation needs to reset the container with a newly created key.
  582. #
  583. # Rule 5. Stop recomputation as soon as we've recomputed the saved tensors we
  584. # know we need.
  585. #
  586. # [ Implementation details of Rule 5 ]
  587. #
  588. # During recomputation, raise an exception if the number of recomputed tensors
  589. # matches the number of tensors that we expected to recompute. We wrap the
  590. # recomputation call with a try-catch to catch this specific exception. See
  591. # Rule #6 below for some examples.
  592. #
  593. # Rule 6. We support doing backward inside checkpoint context
  594. #
# [ retain_graph is True ]
  596. #
# def fn(x):
#     y = x.sin()
#     z = y.cos()
#     gx, = torch.autograd.grad(z, x, retain_graph=True)
#     return gx, z
  602. #
  603. # out = checkpoint(fn)(inp)
  604. # out.backward()
  605. #
# Because y is saved by cos (and x by sin) while checkpoint is enabled, they
# would not actually be saved, and so the .grad() call inside must trigger a
# recomputation.
  608. #
  609. # During recomputation the "inner pack hook" has two responsibilities:
  610. #
  611. # 1) As usual, populating the WeakKeyDictionary storing recomputed tensors
  612. # 2) Pack the actual tensor (detached) so that one may perform backward on the
  613. # recomputed graph. The tensors saved to this graph will live until the end
  614. # of recomputation, or die earlier if someone performs backward with
  615. # retain_graph=False.
  616. #
  617. # More generally performing backward on the recomputed graph occurs in the
  618. # following cases:
  619. # - If backward is performed inside forward,
  620. # - During the original forward IF early-stop is disabled
  621. # - During the original backward
  622. # - If there are multiple .grad()/.backward() calls, we would perform backward
  623. # on the recomputed graph even if early-stop is enabled (see the example below)
  624. #
  625. # [ retain_graph is False ]
  626. #
  627. # The example below shows what happens if during recomputation we find that some
  628. # of the tensors we are trying to recompute have already been cleared.
  629. #
  630. # Spoiler: we don't do anything special, we just skip over them!
  631. #
# def fn(x):
#     y = x.sin()                      # (1)
#     z = y.cos()                      # (2)
#     gx, = torch.autograd.grad(z, x)  # (3)
#     return x.cos() * gx              # (4)
  637. #
  638. # out = checkpoint(fn)(inp)
  639. # out.backward() # (5)
  640. #
  641. # 1, 2. Don't save x and y since we are inside a checkpoint.
  642. # 3. Trigger a recompute of fn since x and y weren't saved.
  643. # And depending on whether early stop is enabled, either stop at (2) or
  644. # continue running the function.
  645. # Because we are running backward with retain_graph=False, we clear x and y's
  646. # holders.
  647. # 4. Don't save x since we are inside a checkpoint.
  648. # 5. Calling backward triggers another recompute of fn. During recompute, we see
  649. # that x and y have already been cleared in the original graph as indicated
  650. # by holder=None. We skip over them. We still save x at (4) (since its holder
  651. # is still alive.)
  652. _enable_checkpoint_early_stop: Optional[bool] = None
  653. @contextlib.contextmanager
  654. def set_checkpoint_early_stop(enable: bool):
  655. """Context manager that sets whether checkpoint should stop recomputation early.
  656. By default, non-reentrant checkpoint stops recomputation as soon as it
  657. has computed all needed Tensors. This context manager can be used to disable
  658. that feature if it is problematic for your specific application.
  659. This context manager only needs to be active when forward is run. It does
  660. not need to be active during backward.
  661. Example::
  662. >>> # xdoctest: +SKIP(failing)
  663. >>> message = "saved tensors default hooks are disabled"
  664. >>> with set_checkpoint_early_stop(False):
  665. ... # Any checkpoint under this context manager will respect this
  666. ... # context manager, even if its backward is performed outside.
  667. ... out = checkpoint(fn, inputs)
  668. ...
  669. >>> out.backward()
  670. """
  671. global _enable_checkpoint_early_stop
  672. try:
  673. prev = _enable_checkpoint_early_stop
  674. _enable_checkpoint_early_stop = enable
  675. yield
  676. finally:
  677. _enable_checkpoint_early_stop = prev
  678. class _Handle:
  679. pass
  680. class _Holder:
  681. def __init__(self):
  682. self.handles: Dict[int, Optional[_Handle]] = {}
  683. class _NoopSaveInputs(torch.autograd.Function):
  684. @staticmethod
  685. def forward(*args):
  686. return torch.empty((0,))
  687. @staticmethod
  688. def setup_context(ctx: Any, inputs: Tuple[Any, ...], output: Any) -> None:
  689. # Only tensors can be saved with ctx.save_for_backward, everything else
  690. # is captured by get_args, which is saved directly on ctx
  691. tensor_indices, tensors = zip(
  692. *[(i, o) for i, o in enumerate(inputs) if isinstance(o, torch.Tensor)]
  693. )
  694. idx2saved_idx = {b: a for a, b in enumerate(tensor_indices)}
  695. # args but with tensors replaced with None as placeholders
  696. args = [None if isinstance(o, torch.Tensor) else o for o in inputs]
  697. def get_args(saved_tensors):
  698. # restore the placeholders with the original tensors grabbed from
  699. # ctx.saved_tensors (which may be saved on a parent checkpoint if
  700. # this checkpoint is nested, and that would trigger a recursive
  701. # unpack!)
  702. ret = [
  703. saved_tensors[idx2saved_idx[i]] if i in tensor_indices else o
  704. for i, o in enumerate(args)
  705. ]
  706. # grab the tail since we also saved the dummy to avoid having to explicitly
  707. # handle the case where there are no tensor inputs
  708. return ret[1:]
  709. ctx.get_args = get_args
  710. ctx.save_for_backward(*tensors)
  711. @staticmethod
  712. def backward(ctx, *grad_outputs):
  713. raise AssertionError("Did not expect to backward on this graph")
  714. class _CheckpointFrame:
  715. def __init__(self, recompute_fn, early_stop, unpack_error_cb, metadata_fn):
  716. self.recompute_fn = recompute_fn
  717. self.input_saver = None
  718. self.weak_holders: List[ReferenceType] = []
  719. # We store this as a weakkeydictionary so that in the case of a partial
  720. # backward, the entries in the dict are cleared alongside the Holder
  721. # which will be removed when the SavedVariable is cleared.
  722. self.recomputed: DefaultDict[
  723. int, weakref.WeakKeyDictionary[_Handle, torch.Tensor]
  724. ] = defaultdict(weakref.WeakKeyDictionary)
  725. # We need both recomp_counter and recomputed since they can diverge
  726. # https://github.com/pytorch/pytorch/pull/90105#discussion_r1135889885
  727. self.recomp_counter: DefaultDict[int, int] = defaultdict(int)
  728. self.is_recomputed: DefaultDict[int, bool] = defaultdict(bool)
  729. # See Rule 5
  730. self.early_stop = early_stop
  731. # Debugging
  732. self.metadata_fn = metadata_fn
  733. self.unpack_error_cb = unpack_error_cb
  734. self.x_metadatas = []
  735. self.forward_completed = False
  736. self.ignore_saved_mismatch = False
  737. def check_recomputed_tensors_match(self, gid):
  738. if self.ignore_saved_mismatch:
  739. # TODO: we can probably make this check stricter by checking that
  740. # the metadata of the first tensors still match.
  741. return
  742. # NOTE [ Error handling for checkpoint ]
  743. #
  744. # At a high level, we need to check that the tensors saved
  745. # during original forward matches tensors saved during recompute
  746. # This means handling 3 cases:
  747. #
  748. # 1. During recompute, more tensors were saved.
  749. #
  750. # Usually this is hidden due to the StopRecomputationError
  751. # but if early stop is not enabled, or we would have errored
  752. # anyway because there aren't enough weak_holders. But we
  753. # do want to have a nice error. See the _recomputation_hook
  754. # for details.
  755. if not len(self.weak_holders) == self.recomp_counter[gid]:
  756. # 2. During recompute, fewer tensors were saved
  757. #
  758. # We know that every time we save something do original forward
  759. # we append to weak_holder, and every time we save a tensor
  760. # during recompute we increment recompute_counter.
  761. raise CheckpointError(
  762. "torch.utils.checkpoint: A different number of tensors was saved "
  763. "during the original forward and recomputation.\n"
  764. f"Number of tensors saved during forward: {len(self.weak_holders)}\n"
  765. f"Number of tensors saved during recomputation: {self.recomp_counter[gid]}.\n"
  766. f"{_debug_tip_msg}"
  767. )
  768. # 3. During recompute, the same tensors were saved, but they
  769. # have different metadata
  770. nb_meta_different = []
  771. for idx, weak_holder in enumerate(self.weak_holders):
  772. holder = weak_holder()
  773. if holder is None:
  774. continue
  775. # We've seen all holders since we iterate over them in order
  776. # For every holder that is still alive now, it must've been
  777. # alive when we saw it during recompute, therefore, the
  778. # gid must be set.
  779. _internal_assert(gid in holder.handles)
  780. # We know this is the first unpack, so it couldn't have been set
  781. # to None yet.
  782. _internal_assert(holder.handles[gid] is not None)
  783. # We always set these together in the recomputation hook
  784. _internal_assert(holder.handles[gid] in self.recomputed[gid])
  785. # see pack hook, x_metadata is 1:1 with weak_holders.
  786. x_meta = self.x_metadatas[idx]
  787. recomputed_x = self.recomputed[gid][holder.handles[gid]]
  788. if x_meta != self.metadata_fn(recomputed_x):
  789. nb_meta_different.append((idx, x_meta, self.metadata_fn(recomputed_x)))
  790. if len(nb_meta_different) > 0:
  791. mismatched_tensors = ""
  792. for idx, x_meta, recomputed_meta in nb_meta_different:
  793. mismatched_tensors += (
  794. f"tensor at position {idx}:\n"
  795. f"saved metadata: {x_meta}\n"
  796. f"recomputed metadata: {recomputed_meta}\n"
  797. )
  798. raise CheckpointError(
  799. "torch.utils.checkpoint: Recomputed values for the following tensors "
  800. "have different metadata than during the forward pass.\n"
  801. f"{mismatched_tensors}.\n"
  802. f"{_debug_tip_msg}"
  803. )
  804. _debug_tip_msg = """
  805. Tip: To see a more detailed error message, either pass `debug=True` to
  806. `torch.utils.checkpoint.checkpoint(...)` or wrap the code block
  807. with `with torch.utils.checkpoint.set_checkpoint_debug_enabled(True):` to
  808. enable checkpoint‑debug mode globally.
  809. """
  810. _checkpoint_error_template = """ \
  811. An error happened while unpacking tensors; dumping logs of latest computation
  812. because you passed `debug=True` to `torch.utils.checkpoint.checkpoint()`.
  813. Scroll all the way down for guidance on how to navigate these logs.
  814. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
  815. | 1. Stack traces of the operators that ran in the original forward |
  816. +------------------------------------------------------------------------------+
  817. {forward_traces}
  818. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
  819. | 2. Stack traces of the operators that ran during recomputation |
  820. +------------------------------------------------------------------------------+
  821. {recompute_traces}
  822. +~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~+
  823. | 3. Log of operators in the original forward and recomputation |
  824. +------------------------------------------------------------------------------+
  825. (Scroll up to correlate stack traces with each operation listed below. This
  826. helps identify their source in the code.)
  827. IMPORTANT: Differences in "detach" calls between the original forward and the
  828. recomputation are expected. They are introduced by the checkpointing
  829. mechanism and can be ignored.
  830. Operations executed during the original forward:
  831. {forward_ops}
  832. Operations executed during recomputation:
  833. {recompute_ops}
  834. +------------------------------------------------------------------------------+
  835. ERROR: Detected non-determinism while running activation checkpointing
  836. You are seeing this error because you passed `debug=True` to checkpoint and
  837. tensors to be saved during the original forward and differ between those saved
  838. during recomputation. This can happen if different operators were ran in the
  839. original forward and in the recomputation.
  840. To identify where the mismatch may be coming from, you can do the following:
  841. 1) Compare the operators ran during original forward and recomputation to
  842. see where they differ. These operators are printed above in the order they
  843. were executed.
  844. 2) Review the stack trace for each operator to locate its invocation source.
  845. Each operator's stack trace is printed in their execution order.
  846. Note that the logs can be quite long. Here's how they are structured:
  847. (Tip: you can Ctrl-f for these headers)
  848. 1. Stack traces of the operators that ran in the original forward
  849. 2. Stack traces of the operators that ran during recomputation
  850. 3. Log of operators in the original forward and recomputation
  851. 4. Error message <--- You are here
  852. --------------------------------------------------------------------------------
  853. """
  854. class CheckpointError(RuntimeError):
  855. pass
  856. def _get_debug_context_and_cb() -> Tuple[Callable[[], Any], Callable[[CheckpointError], None]]:
  857. # This function returns the context_fn and error_cb to be used by the
  858. # checkpointing mechanism. error_cb is invoked when an error is detected
  859. # during unpack.
  860. # record_context_cpp is not support on non-linux non-x86_64 platforms
  861. cpp_tb = platform.machine() == 'x86_64' and platform.system() == 'Linux'
  862. class CaptureLogs:
  863. def __init__(self):
  864. self.logs = None
  865. self.tbs = None
  866. def get_context_manager(self):
  867. @contextlib.contextmanager
  868. def logging_mode():
  869. with LoggingTensorMode(), \
  870. capture_logs(True, python_tb=True, script_tb=True, cpp_tb=cpp_tb) as logs_and_tb:
  871. self.logs, self.tbs = logs_and_tb
  872. yield logs_and_tb
  873. return logging_mode()
  874. capture_logs_fwd = CaptureLogs()
  875. capture_logs_recompute = CaptureLogs()
  876. def unpack_error_cb(e: CheckpointError):
  877. def get_str_tb(label, capture_logs):
  878. out = ""
  879. total_len = len(capture_logs.logs)
  880. for i, (log, tb) in enumerate(zip(capture_logs.logs, capture_logs.tbs)):
  881. out += f"{log} ({i + 1} of {total_len} in {label})\n\n"
  882. found_torch_dispatch = False
  883. for line in tb:
  884. # Start printing stack trace only after __torch_dispatch__ is found
  885. is_torch_dispatch = line['name'] == '__torch_dispatch__'
  886. if not found_torch_dispatch and not is_torch_dispatch:
  887. continue
  888. elif is_torch_dispatch:
  889. found_torch_dispatch = True
  890. continue
  891. out += f"{line['filename']}:{line['line']}:{line['name']}\n"
  892. out += "\n\n"
  893. return out
  894. assert capture_logs_fwd.logs is not None
  895. assert capture_logs_recompute.logs is not None
  896. raise CheckpointError(
  897. _checkpoint_error_template.format(
  898. forward_traces=get_str_tb("original", capture_logs_fwd),
  899. recompute_traces=get_str_tb("recompute", capture_logs_recompute),
  900. forward_ops="\n".join(capture_logs_fwd.logs),
  901. recompute_ops="\n".join(capture_logs_recompute.logs)
  902. )
  903. ) from e
  904. def context_fn():
  905. return capture_logs_fwd.get_context_manager(), capture_logs_recompute.get_context_manager()
  906. return context_fn, unpack_error_cb
  907. def _default_meta_extractor(x: torch.Tensor) -> Dict[str, Any]:
  908. # These properties are fast to check, easy to understand
  909. return {
  910. "shape": x.shape,
  911. "dtype": x.dtype,
  912. "device": x.device
  913. }
  914. _allowed_determinism_checks_to_fns: Dict[str, Callable[[torch.Tensor], Any]] = {
  915. _DEFAULT_DETERMINISM_MODE: _default_meta_extractor,
  916. "none": lambda _: None,
  917. }
  918. # See Rule 5
  919. class _StopRecomputationError(Exception):
  920. pass
class _recomputation_hook(torch.autograd.graph.saved_tensors_hooks):
    """Saved-tensor hooks active while a checkpointed region is recomputed.

    The pack hook pairs the N-th tensor saved during recomputation with the
    N-th holder recorded during the original forward. It then stores the
    tensor (detached if it requires grad) in ``target_frame.recomputed[gid]``
    under a fresh ``_Handle``. When early stop is enabled, it raises
    ``_StopRecomputationError`` once as many tensors have been saved as
    there are holders (Rule 5).

    Both hooks also hand the tensor back unchanged, so the graph built during
    recomputation can itself be backwarded (Rule 6).

    Args:
        target_frame_ref: weak reference to the ``_CheckpointFrame`` being
            recomputed.
        gid: graph-task id of the backward that triggered recomputation;
            recomputed tensors and counters are kept per ``gid``.
    """

    def __init__(self, target_frame_ref: ReferenceType, gid: int):
        def pack_hook(x):
            # Tensors that require grad are stored detached; others as-is.
            x = x.detach() if x.requires_grad else x
            target_frame = target_frame_ref()
            assert target_frame is not None  # appease mypy
            # Index of this tensor in recomputation save order; pairs with
            # weak_holders[recomp_idx] recorded during the original forward.
            recomp_idx = target_frame.recomp_counter[gid]
            target_frame.recomp_counter[gid] += 1

            if recomp_idx >= len(target_frame.weak_holders):
                # More saves than in the original forward. With early stop
                # enabled we would already have raised _StopRecomputationError
                # below when the counter reached len(weak_holders).
                assert not target_frame.early_stop
                if not target_frame.forward_completed:
                    # We run into this case when early stop is not enabled and do
                    # grad within checkpoint.
                    # We need to set this flag, so we don't error out later when
                    # we check if the number of tensors saved during forward and
                    # recomputation match.
                    target_frame.ignore_saved_mismatch = True
                    return x
                raise CheckpointError(
                    "torch.utils.checkpoint: trying to save more tensors during "
                    "recomputation than during the original forward pass.\n"
                    f"{_debug_tip_msg}"
                )

            holder = target_frame.weak_holders[recomp_idx]()

            # This holder may have been cleared because someone may have called
            # backward within forward. If so, we don't need to save.
            if holder is not None:
                # Each holder is filled at most once per graph task.
                _internal_assert(holder.handles.get(gid, None) is None)
                holder.handles[gid] = _Handle()
                target_frame.recomputed[gid][holder.handles[gid]] = x

            if target_frame.early_stop and target_frame.recomp_counter[gid] == len(
                target_frame.weak_holders
            ):
                # Rule 5: every tensor from the original forward has been
                # recomputed; abort the rest of the recompute function.
                raise _StopRecomputationError
            # See Rule 6: [ retain_graph is True ] above
            return x

        def unpack_hook(x):
            # See Rule 6: [ retain_graph is True ] above for an example of when
            # the graph created during recomputation could be backwarded.
            return x

        super().__init__(pack_hook, unpack_hook)
  962. # torch._disable_dynamo creates a reference cycle with decorated function
  963. # This function is used to ensure that the decorated function does not have
  964. # a closure, so that other objects aren't also kept alive.
  965. # https://github.com/pytorch/pytorch/issues/154642
  966. # Note: does not work when fn is compiled
  967. @torch._disable_dynamo
  968. def _run_fn_with_dynamo_disabled(fn, *args, **kwargs):
  969. return fn(*args, **kwargs)
  970. class _checkpoint_hook(torch.autograd.graph.saved_tensors_hooks):
  971. def __init__(self, frame):
  972. def pack_hook(x):
  973. # See Rule 4 above
  974. holder = _Holder()
  975. frame.weak_holders.append(weakref.ref(holder))
  976. # Save metadata to detect non-determinism
  977. if frame.metadata_fn is not None:
  978. with torch.no_grad():
  979. frame.x_metadatas.append(frame.metadata_fn(x))
  980. return holder
  981. def unpack_hook(holder):
  982. gid = torch._C._current_graph_task_id()
  983. if gid == -1:
  984. # generate a temporary id if we trigger unpack outside of a backward call
  985. gid = int(uuid.uuid4())
  986. if not frame.is_recomputed[gid]:
  987. ctx = frame.input_saver.grad_fn
  988. args = ctx.get_args(ctx.saved_tensors)
  989. try:
  990. with _recomputation_hook(
  991. weakref.ref(frame), gid
  992. ), torch.autograd.enable_grad():
  993. # See Note: [compiled autograd and checkpoint unpack hook]
  994. _run_fn_with_dynamo_disabled(frame.recompute_fn, *args)
  995. except _StopRecomputationError:
  996. pass
  997. frame.is_recomputed[gid] = True
  998. frame.check_recomputed_tensors_match(gid)
  999. _internal_assert(gid in holder.handles)
  1000. if holder.handles[gid] is None:
  1001. raise CheckpointError(
  1002. "torch.utils.checkpoint: Unpack is being triggered for a tensor that was already "
  1003. "unpacked once. If you are calling ctx.saved_tensors in backward, make sure to do "
  1004. "so only once. Otherwise please open an issue with details on your use case."
  1005. )
  1006. _internal_assert(holder.handles[gid] in frame.recomputed[gid])
  1007. ret = frame.recomputed[gid][holder.handles[gid]]
  1008. holder.handles[gid] = None
  1009. return ret
  1010. if frame.unpack_error_cb is not None:
  1011. def unpack_hook_with_error_cb(holder):
  1012. try:
  1013. return unpack_hook(holder)
  1014. except CheckpointError as e:
  1015. frame.unpack_error_cb(e)
  1016. super().__init__(pack_hook, unpack_hook_with_error_cb)
  1017. else:
  1018. super().__init__(pack_hook, unpack_hook)
  1019. def _is_compiling(func, args, kwargs):
  1020. # Check if we are under AOTAutograd tracing
  1021. # Checking that a functional mode is active should always do what we want
  1022. return torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.FUNCTIONAL) is not None
  1023. class _VersionWrapper:
  1024. # Check that cached tensors are not mutated.
  1025. def __init__(self, val):
  1026. self.val: Union[torch.Tensor, Any] = val
  1027. self.version: Optional[int] = val._version if isinstance(val, torch.Tensor) else None
  1028. def get_val(self, allow_cache_entry_mutation):
  1029. if self.version is not None and not allow_cache_entry_mutation:
  1030. if self.val._version != self.version:
  1031. # Can we give user a stack trace of where the mutation happened?
  1032. raise RuntimeError(
  1033. "Tensor cached during selective activation checkpoint has been mutated"
  1034. )
  1035. return self.val
  1036. def _maybe_detach(x, any_ret_has_alias_info):
  1037. # We detach for two separate reasons:
  1038. # - For view ops, we need to ensure that when the tensor is returned from
  1039. # CachedDispatchMode, as_view sees that the AutogradMeta is nullptr
  1040. # - Avoid reference cycles
  1041. # For case 1, it is not enough to check whether x has differentiable dtype
  1042. # because non-differentiable dtype can have non-nullptr AutogradMeta, e.g.
  1043. # when the tensor is a view.
  1044. if isinstance(x, torch.Tensor) and (x.is_floating_point() or x.is_complex() or any_ret_has_alias_info):
  1045. with torch._C._SetExcludeDispatchKeyGuard(torch._C.DispatchKey.ADInplaceOrView, False):
  1046. # Ensure that view performed beneath autograd properly propagates
  1047. # version counter. TODO: Use reentrant_dispatch instead of
  1048. # manually manipulating dispatch keys. Using reentrant_dispatch
  1049. # would respect inference_mode, though that is not relevant for
  1050. # this case.
  1051. x = x.detach()
  1052. return x
  1053. class SelectiveCheckpointContext:
  1054. """
  1055. Context passed to policy function during selective checkpointing.
  1056. This class is used to pass relevant metadata to the policy function during
  1057. selective checkpointing. The metadata includes whether the current invocation
  1058. of the policy function is during recomputation or not.
  1059. Example:
  1060. >>> # xdoctest: +SKIP(stub)
  1061. >>>
  1062. >>> def policy_fn(ctx, op, *args, **kwargs):
  1063. >>> print(ctx.is_recompute)
  1064. >>>
  1065. >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
  1066. >>>
  1067. >>> out = torch.utils.checkpoint.checkpoint(
  1068. >>> fn, x, y,
  1069. >>> use_reentrant=False,
  1070. >>> context_fn=context_fn,
  1071. >>> )
  1072. """
  1073. def __init__(self, *, is_recompute):
  1074. self.is_recompute = is_recompute
  1075. class CheckpointPolicy(enum.Enum):
  1076. """
  1077. Enum for specifying the policy for checkpointing during backpropagation.
  1078. The following policies are supported:
  1079. - ``{MUST,PREFER}_SAVE``: The operation's output will be saved during the forward
  1080. pass and will not be recomputed during the backward pass
  1081. - ``{MUST,PREFER}_RECOMPUTE``: The operation's output will not be saved during the
  1082. forward pass and will be recomputed during the backward pass
  1083. Use ``MUST_*`` over ``PREFER_*`` to indicate that the policy should not be overridden
  1084. by other subsystems like `torch.compile`.
  1085. .. note::
  1086. A policy function that always returns ``PREFER_RECOMPUTE`` is
  1087. equivalent to vanilla checkpointing.
  1088. A policy function that returns ``PREFER_SAVE`` every op is
  1089. NOT equivalent to not using checkpointing. Using such a policy would
  1090. save additional tensors not limited to ones that are actually needed for
  1091. gradient computation.
  1092. """
  1093. MUST_SAVE = 0
  1094. PREFER_SAVE = 1
  1095. MUST_RECOMPUTE = 2
  1096. PREFER_RECOMPUTE = 3
  1097. def _policy_from_bool(b):
  1098. # For backward compatibility
  1099. return CheckpointPolicy.MUST_SAVE if b else CheckpointPolicy.PREFER_RECOMPUTE
  1100. SAC_IGNORED_OPS = {
  1101. # AC inserts different number of detach during forward and recompute.
  1102. torch.ops.aten.detach.default,
  1103. # AC's determinism check invokes additional metadata ops during forward.
  1104. # With subclasses involved, these metadata ops become dispatchable, this
  1105. # can result in incorrectness if these ops are selected cached.
  1106. torch.ops.prim.device.default,
  1107. } | set(torch._subclasses.functional_tensor.FunctionalTensor.metadata_fns)
  1108. class _CachingTorchDispatchMode(TorchDispatchMode):
  1109. # Used together with _CachedTorchDispatchMode to implement SAC.
  1110. def __init__(self, policy_fn, storage):
  1111. self.policy_fn = policy_fn
  1112. self.storage = storage
  1113. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  1114. if func in SAC_IGNORED_OPS:
  1115. return func(*args, **kwargs)
  1116. kwargs = {} if kwargs is None else kwargs
  1117. policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=False),
  1118. func, *args, **kwargs)
  1119. if isinstance(policy, bool):
  1120. policy = _policy_from_bool(policy)
  1121. is_compiling = _is_compiling(func, args, kwargs)
  1122. if is_compiling:
  1123. # Overwrite each node's "recompute" tag to add in the user annotation.
  1124. fx_traceback.current_meta["recompute"] = policy
  1125. out = func(*args, **kwargs)
  1126. # HOPs don't support func._schema
  1127. # HOPs don't alias -> this is always true today and will be always true for a long time
  1128. # TODO HOPs don't mutate -> this is always true today but will not be true forever
  1129. if isinstance(func, torch._ops.HigherOrderOperator):
  1130. any_ret_has_alias_info = False
  1131. else:
  1132. any_ret_has_alias_info = any(ret.alias_info is not None for ret in func._schema.returns)
  1133. if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
  1134. self.storage[func].append(tree_map(lambda x: _VersionWrapper(_maybe_detach(x, any_ret_has_alias_info)), out))
  1135. return out
  1136. class _CachedTorchDispatchMode(TorchDispatchMode):
  1137. # Used together with _CachedTorchDispatchMode to implement SAC.
  1138. def __init__(self, policy_fn, storage, allow_cache_entry_mutation):
  1139. self.policy_fn = policy_fn
  1140. self.storage = storage
  1141. self.allow_cache_entry_mutation = allow_cache_entry_mutation
  1142. def __torch_dispatch__(self, func, types, args=(), kwargs=None):
  1143. if func in SAC_IGNORED_OPS:
  1144. return func(*args, **kwargs)
  1145. kwargs = {} if kwargs is None else kwargs
  1146. policy = self.policy_fn(SelectiveCheckpointContext(is_recompute=True),
  1147. func, *args, **kwargs)
  1148. if isinstance(policy, bool):
  1149. policy = _policy_from_bool(policy)
  1150. is_compiling = _is_compiling(func, args, kwargs)
  1151. if policy in (CheckpointPolicy.MUST_SAVE, CheckpointPolicy.PREFER_SAVE) or is_compiling:
  1152. storage = self.storage.get(func)
  1153. if storage is None:
  1154. raise RuntimeError(f"{func} encountered during backward, but not found in storage")
  1155. if len(storage) == 0:
  1156. raise RuntimeError(
  1157. "Trying to backward an extra time. You are only allowed to backward once "
  1158. "on any region computed under selective activation checkpoint."
  1159. )
  1160. out = tree_map(lambda x: x.get_val(self.allow_cache_entry_mutation), storage.pop(0))
  1161. else:
  1162. out = func(*args, **kwargs)
  1163. return out
  1164. def create_selective_checkpoint_contexts(policy_fn_or_list, allow_cache_entry_mutation=False):
  1165. """
  1166. Helper to avoid recomputing certain ops during activation checkpointing.
  1167. Use this with `torch.utils.checkpoint.checkpoint` to control which
  1168. operations are recomputed during the backward pass.
  1169. Args:
  1170. policy_fn_or_list (Callable or List):
  1171. - If a policy function is provided, it should accept a
  1172. :class:`SelectiveCheckpointContext`, the :class:`OpOverload`, args and
  1173. kwargs to the op, and return a :class:`CheckpointPolicy` enum value
  1174. indicating whether the execution of the op should be recomputed or not.
  1175. - If a list of operations is provided, it is equivalent to a policy
  1176. returning `CheckpointPolicy.MUST_SAVE` for the specified
  1177. operations and `CheckpointPolicy.PREFER_RECOMPUTE` for all other
  1178. operations.
  1179. allow_cache_entry_mutation (bool, optional): By default, an error is
  1180. raised if any tensors cached by selective activation checkpoint are
  1181. mutated in order to ensure correctness. If set to `True`, this check
  1182. is disabled.
  1183. Returns:
  1184. A tuple of two context managers.
  1185. Example:
  1186. >>> # xdoctest: +REQUIRES(LINUX)
  1187. >>> import functools
  1188. >>>
  1189. >>> x = torch.rand(10, 10, requires_grad=True)
  1190. >>> y = torch.rand(10, 10, requires_grad=True)
  1191. >>>
  1192. >>> ops_to_save = [
  1193. >>> torch.ops.aten.mm.default,
  1194. >>> ]
  1195. >>>
  1196. >>> def policy_fn(ctx, op, *args, **kwargs):
  1197. >>> if op in ops_to_save:
  1198. >>> return CheckpointPolicy.MUST_SAVE
  1199. >>> else:
  1200. >>> return CheckpointPolicy.PREFER_RECOMPUTE
  1201. >>>
  1202. >>> context_fn = functools.partial(create_selective_checkpoint_contexts, policy_fn)
  1203. >>>
  1204. >>> # or equivalently
  1205. >>> context_fn = functools.partial(create_selective_checkpoint_contexts, ops_to_save)
  1206. >>>
  1207. >>> def fn(x, y):
  1208. >>> return torch.sigmoid(torch.matmul(torch.matmul(x, y), y)) * y
  1209. >>>
  1210. >>> out = torch.utils.checkpoint.checkpoint(
  1211. >>> fn, x, y,
  1212. >>> use_reentrant=False,
  1213. >>> context_fn=context_fn,
  1214. >>> )
  1215. """
  1216. # NB: If grad_mode is disabled, checkpoint would not run forward under
  1217. # context_fn anyway, so proceed as usual.
  1218. if isinstance(policy_fn_or_list, list):
  1219. for op in policy_fn_or_list:
  1220. if not isinstance(op, (torch._ops.OpOverload, torch._ops.HigherOrderOperator)):
  1221. _extra_msg = (
  1222. "Please update the OpOverloadPacket to a specific OpOverload."
  1223. "For example, if you have `torch.ops.aten.mm`, change it to `torch.ops.aten.mm.default`."
  1224. ) if isinstance(op, torch._ops.OpOverloadPacket) else ""
  1225. raise ValueError(
  1226. f"Expected op in `op_list` to be an OpOverload but got: {op} "
  1227. f"of type {type(op)}. {_extra_msg}"
  1228. )
  1229. def policy_fn(ctx, op, *args, **kwargs):
  1230. if op in policy_fn_or_list:
  1231. return CheckpointPolicy.MUST_SAVE
  1232. else:
  1233. return CheckpointPolicy.PREFER_RECOMPUTE
  1234. elif callable(policy_fn_or_list):
  1235. policy_fn = policy_fn_or_list
  1236. else:
  1237. raise TypeError("policy_fn_or_list must be either a function or a list of ops.")
  1238. storage: Dict[Any, List[Any]] = defaultdict(list)
  1239. return (
  1240. _CachingTorchDispatchMode(policy_fn, storage),
  1241. _CachedTorchDispatchMode(policy_fn, storage, allow_cache_entry_mutation),
  1242. )
# NB: this helper wraps fn before calling checkpoint_impl. kwargs and
# saving/restoring of global state are handled here.
def _checkpoint_without_reentrant_generator(
    fn,
    preserve_rng_state=True,
    context_fn: Callable[[], Tuple[ContextManager, ContextManager]] = noop_context_fn,
    determinism_check: str = _DEFAULT_DETERMINISM_MODE,
    debug: bool = False,
    early_stop: bool = True,
    *args,
    **kwargs
):
    """Checkpointing without reentrant autograd.

    This is a generator. It validates its arguments and stashes the state
    needed for recomputation, then yields exactly once. Whatever the consumer
    does while the generator is suspended at that ``yield`` runs inside
    ``_checkpoint_hook(new_frame)`` and the forward context returned by
    ``context_fn``. The exception is when grad mode is disabled: then nothing
    is wrapped.

    ``fn`` itself is not called here during the forward. It is only invoked
    from the inner ``recompute_fn`` when activations are recomputed. The
    consumer presumably runs ``fn(*args, **kwargs)`` while this generator is
    suspended (TODO confirm at the call site).

    Args:
        fn: describes what to run in the forward pass of the model or
            part of the model. It should also know how to handle the inputs
            passed as the tuple. For example, in LSTM, if user passes
            ``(activation, hidden)``, ``fn`` should correctly use the
            first input as ``activation`` and the second input as ``hidden``
        preserve_rng_state(bool, optional): If ``True``, stash the CPU RNG
            state before the forward. If the device module was already
            initialized, the device RNG states are stashed as well. Both are
            restored before recomputation. Pass ``False`` to omit this
            stashing and restoring.
            Default: ``True``
        context_fn(Callable, optional): A callable returning a tuple of two
            context managers. The function and its recomputation will be run
            under the first and second context managers respectively.
            It must be left at its default when debug mode is enabled. Under
            ``torch.compile``, a non-default ``context_fn`` must return two
            ``TorchDispatchMode`` instances (checked with an ``assert``).
        determinism_check(str, optional): A string specifying the determinism
            check to perform. By default it is set to ``"default"`` which
            compares the shapes, dtypes, and devices of the recomputed tensors
            against those of the saved tensors. To turn off this check, specify
            ``"none"``. Currently these are the only two supported values.
            Please open an issue if you would like to see more determinism
            checks.
        debug(bool, optional): If ``True``, error messages will also include
            a trace of the operators run during the original forward computation
            as well as the recomputation. A non-``None`` module-level
            ``_checkpoint_debug_enabled`` takes precedence over this argument.
        early_stop(bool, optional): If ``True``, non-reentrant checkpoint stops
            recomputation as soon as it has computed all needed Tensors. Can be
            overridden globally using :func:`set_checkpoint_early_stop` context
            manager. Default: ``True``.
        *args: Arguments to pass in to the given ``fn``.
        **kwargs: Keyword arguments to pass into the given ``fn``.

    Yields:
        ``None``, exactly once.

    Raises:
        ValueError: if debug mode is enabled together with a non-default
            ``context_fn``, or if ``determinism_check`` is not a supported mode.
        RuntimeError: if ``preserve_rng_state`` is set and the device module
            became initialized during the forward. In that case its RNG state
            was never stashed.
    """
    unpack_error_cb = None

    # The module-level override, when set, wins over the `debug` argument.
    if _checkpoint_debug_enabled if _checkpoint_debug_enabled is not None else debug:
        if context_fn != noop_context_fn:
            raise ValueError(
                "debug=True is incompatible with non-default context_fn"
            )
        # Debug mode supplies its own context pair plus an error callback.
        context_fn, unpack_error_cb = _get_debug_context_and_cb()

    if determinism_check in _allowed_determinism_checks_to_fns:
        metadata_fn = _allowed_determinism_checks_to_fns[determinism_check]
    else:
        raise ValueError(
            f"determinism_check should be one of {list(_allowed_determinism_checks_to_fns.keys())}, "
            f"but got {determinism_check}"
        )

    device_type = _infer_device_type(*args)
    device_module = _get_device_module(device_type)
    # Both contexts are created eagerly. forward_context is entered around
    # the yield below; recompute_context is entered inside recompute_fn.
    forward_context, recompute_context = context_fn()
    # NOTE(review): this is an `assert`, so the check disappears under `python -O`.
    if _is_compiling(fn, args, kwargs) and context_fn != noop_context_fn:
        assert (
            isinstance(forward_context, TorchDispatchMode) and
            isinstance(recompute_context, TorchDispatchMode)
        ), \
            "In torch.compile mode, `context_fn` arg passed to `torch.utils.checkpoint` " + \
            "must generate a tuple of two `TorchDispatchMode`s."
    # Accommodates the (remote) possibility that autocast is enabled for cpu AND gpu.
    device_autocast_kwargs, cpu_autocast_kwargs = _get_autocast_kwargs(device_type=device_type)

    # fwd_cpu_state, had_device_in_fwd, fwd_devices and fwd_device_states are
    # only bound inside this branch. Every later read is guarded by
    # `preserve_rng_state` (and `had_device_in_fwd`) first.
    if preserve_rng_state:
        fwd_cpu_state = torch.get_rng_state()
        # Don't eagerly initialize the cuda context by accident.
        # (If the user intends that the context is initialized later, within their
        # run_function, we SHOULD actually stash the cuda state here. Unfortunately,
        # we have no way to anticipate this will happen before we run the function.
        # If they do so, we raise an error.)
        had_device_in_fwd = False
        if getattr(device_module, "_initialized", False):
            had_device_in_fwd = True
            fwd_devices, fwd_device_states = get_device_states(*args)

    def recompute_fn(*inputs):
        # `inputs` is presumably the tuple saved by `_NoopSaveInputs.apply`
        # below (kwargs first, then the positional args). TODO confirm in
        # `_CheckpointFrame`.
        kwargs, *args = inputs
        # This will be called later during recomputation. This wrapping enables
        # the necessary global state to be captured.
        rng_devices = []
        if preserve_rng_state and had_device_in_fwd:
            rng_devices = fwd_devices
        # fork_rng is expected to restore the surrounding RNG state on exit
        # (see its docs). Replaying the forward's RNG state inside it should
        # therefore not leak into the caller.
        with torch.random.fork_rng(
            devices=rng_devices, enabled=preserve_rng_state, device_type=device_type
        ):
            if preserve_rng_state:
                torch.set_rng_state(fwd_cpu_state)
                if had_device_in_fwd:
                    set_device_states(fwd_devices, fwd_device_states, device_type=device_type)

            # Re-enter the autocast settings captured at forward time so the
            # recomputation sees the same autocast state.
            device_autocast_ctx = torch.amp.autocast(
                device_type=device_type, **device_autocast_kwargs
            ) if torch.amp.is_autocast_available(device_type) else contextlib.nullcontext()
            with device_autocast_ctx, torch.amp.autocast("cpu", **cpu_autocast_kwargs), recompute_context:  # type: ignore[attr-defined]
                fn(*args, **kwargs)

    new_frame = _CheckpointFrame(
        recompute_fn,
        # The module-level override, when set, wins over `early_stop`.
        _enable_checkpoint_early_stop if _enable_checkpoint_early_stop is not None else early_stop,
        unpack_error_cb,
        metadata_fn
    )
    # A zero-element tensor that requires grad. Presumably it exists so the
    # custom Function below records a grad_fn whenever grad mode is enabled,
    # even if no real input requires grad.
    dummy = torch.empty((0,), requires_grad=True)
    new_frame.input_saver = _NoopSaveInputs.apply(dummy, kwargs, *args)

    # When ambient grad_mode is False
    if new_frame.input_saver.grad_fn is None:
        yield
        return

    with _checkpoint_hook(new_frame), forward_context:
        yield
    # Only reached if the consumer resumed the generator normally. An
    # exception thrown in at the yield skips this line.
    new_frame.forward_completed = True

    # `preserve_rng_state` is tested before `had_device_in_fwd`. The latter is
    # therefore never read when it was not bound above.
    if getattr(device_module, "_initialized", False) and \
       preserve_rng_state and not had_device_in_fwd:  # type: ignore[possibly-undefined]
        # Device was not initialized before running the forward, so we didn't
        # stash the device state.
        raise RuntimeError(
            "PyTorch's device state was initialized in the forward pass "
            "of a Checkpoint, which is not allowed. Please open an issue "
            "if you need this feature."
        )

    return
# Note: [compiled autograd and checkpoint unpack hook]
# When tracing via compiled autograd, this hook will be visible to the
# compiler if the forward of this checkpointed region ran in eager.
# If the forward had run under compile, it would have been wrapped in a
# higher order op. See Note: [torch.compile and checkpoint].
#
# Since we run the recomputation hook under an enable_grad context,
# AOTDispatch will trace a joint graph for this hook, and may
# save different activations than in eager. This conflicts with the
# strict activation count checks in `frame.check_recomputed_tensors_match`.
# So, we disable this hook to force it to recompute eager checkpointed regions
# in eager. This could be removed if we can disable the partitioner for this
# graph segment.