optimizer.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400
  1. # mypy: ignore-errors
  2. """
  3. This module implements variable tracking for PyTorch optimizers during Dynamo tracing.
  4. The OptimizerVariable class provides specialized handling for optimizer instances by:
  5. - Optimizing the tracing of expensive optimizer initialization
  6. - Managing optimizer state and parameter group tracking
  7. - Handling tensor sources and guards for optimizer state tensors
  8. - Supporting CUDA graph execution through static tensor address management
  9. - Providing special handling for parameter gradients and optimizer state tensors
  10. Key features include:
  11. - Efficient initialization tracing via _init_group optimization
  12. - Automatic marking of optimizer state tensors as static for CUDA graphs
  13. - Proper source tracking for parameter groups, gradients, and state tensors
  14. - Guard installation for optimizer state structure
  15. - Support for both CPU and GPU tensor handling
  16. - Cleanup of static tensor references via finalizers
  17. The module integrates with Dynamo's broader tracing system while providing
  18. optimizer-specific optimizations and safety guarantees.
  19. """
  20. import logging
  21. import weakref
  22. from typing import TYPE_CHECKING
  23. import torch
  24. from torch._logging import getArtifactLogger
  25. from torch.utils._pytree import tree_map_only
  26. from ..guards import GuardBuilder, install_guard
  27. from ..source import (
  28. AttrSource,
  29. ConstDictKeySource,
  30. DictGetItemSource,
  31. GetItemSource,
  32. GlobalWeakRefSource,
  33. GradSource,
  34. )
  35. from ..utils import GLOBAL_KEY_PREFIX
  36. from .base import VariableTracker
  37. from .constant import ConstantVariable
  38. from .dicts import ConstDictVariable
  39. from .lists import ListVariable
  40. from .misc import GetAttrVariable
  41. from .user_defined import UserDefinedObjectVariable
  42. if TYPE_CHECKING:
  43. from torch._dynamo.symbolic_convert import InstructionTranslator
  44. class ArgMappingException(Exception):
  45. pass
  46. class GuardInstallException(Exception):
  47. pass
# Logger for the "perf_hints" logging artifact (torch._logging). It is used to
# warn when parameter gradients will be copied during cudagraphs execution.
perf_hint_log = getArtifactLogger(__name__, "perf_hints")
  49. def _is_static_for_cudagraphs(x):
  50. from torch._inductor.cudagraph_trees import get_manager
  51. if x.is_cuda:
  52. manager = get_manager(x.device.index, False)
  53. is_static_address = torch._dynamo.utils.get_static_address_type(x) is not None
  54. if manager:
  55. return (
  56. is_static_address
  57. or manager.current_node._is_cuda_graph_recorded_tensor(x)
  58. )
  59. else:
  60. return is_static_address
  61. else:
  62. # Don't print a warning for non-cuda tensors
  63. return True
  64. class OptimizerVariable(UserDefinedObjectVariable):
  65. _nonvar_fields = {
  66. "grad_to_source",
  67. "tensor_to_source",
  68. "static_tensor_names",
  69. *UserDefinedObjectVariable._nonvar_fields,
  70. }
  71. def __init__(
  72. self,
  73. value,
  74. grad_to_source=None,
  75. static_tensor_names=None,
  76. tensor_to_source=None,
  77. **kwargs,
  78. ) -> None:
  79. super().__init__(value, **kwargs)
  80. self.grad_to_source = grad_to_source or {}
  81. self.tensor_to_source = tensor_to_source or {}
  82. self.static_tensor_names = static_tensor_names or set()
  83. def call_method(
  84. self,
  85. tx,
  86. name,
  87. args: "list[VariableTracker]",
  88. kwargs: "dict[str, VariableTracker]",
  89. ) -> "VariableTracker":
  90. """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
  91. if name == "_init_group":
  92. try:
  93. self.graph_break_if_pending_mutation(tx)
  94. self.move_step_if_cpu()
  95. py_args, py_kwargs = self.get_python_args(*args, **kwargs)
  96. ret_val = self.value._init_group(*py_args, **py_kwargs)
  97. self.map_sources_and_install_guards(tx)
  98. self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
  99. # stash a weak_ptr to optimizer to invalidate code
  100. # if the optimizer object dies
  101. mangled_name = f"__optimizer_{id(self.value)}"
  102. tx.store_global_weakref_by_id(mangled_name, self.value)
  103. self.create_finalizer(tx)
  104. # This is currently safe only because the only actual `ret_val`s returned
  105. # by the `_init_group` of existing optimizers are properties that are invariant
  106. # to the input tensors (e.g. dtype, layout). Changing these would trigger a
  107. # recompilation and hence never result in the wrong specialization of `ret_val`.
  108. return ConstantVariable.create(ret_val)
  109. except (ArgMappingException, GuardInstallException) as _:
  110. # trace normally if we can't map args or install guards correctly
  111. pass
  112. return super().call_method(tx, name, args, kwargs)
  113. def var_getattr(self, tx: "InstructionTranslator", name):
  114. # Note: this allows us to intercept the call in call_method
  115. # in the typical case, we return a UserMethodVariable
  116. # which will directly inline
  117. if name in ("_init_group", "step"):
  118. return GetAttrVariable(self, name, source=AttrSource(self.source, name))
  119. if name == "param_groups":
  120. from ..decorators import mark_static_address
  121. for group in self.value.param_groups:
  122. for p in group["params"]:
  123. mark_static_address(p, guard=True)
  124. self._set_capturable(tx)
  125. return super().var_getattr(tx, name)
  126. def graph_break_if_pending_mutation(self, tx):
  127. # If there are pending mutations on a parameter (due to using closure)
  128. # then we need to graph break to allow the python version of the parameter
  129. # to update, so that running _init_group will initialize the states with
  130. # the correct values
  131. for g in self.value.param_groups:
  132. for p in g["params"]:
  133. side_effects = tx.output.side_effects
  134. variable = side_effects.id_to_variable.get(id(p), None)
  135. if variable and side_effects.has_pending_mutation(variable):
  136. from ..exc import Unsupported
  137. raise Unsupported("Pending mutation on parameter")
  138. def _set_capturable(self, tx):
  139. from . import LazyVariableTracker
  140. # We only set capturable if params are on cuda
  141. # and the state is not initialized
  142. def safe_to_set_capturable(group):
  143. all_uninitialized = True
  144. all_gpu = True
  145. for p in group.get("params", []):
  146. all_gpu &= p.is_cuda or p.is_xpu
  147. all_uninitialized &= p not in self.value.state
  148. return "capturable" in group and all_uninitialized and all_gpu
  149. # track indices to not set so we don't need to
  150. # in the variable tracker realize the whole state
  151. # we handle guarding the state specially
  152. for group in self.value.param_groups:
  153. if safe_to_set_capturable(group):
  154. group["capturable"] = True
  155. source = self.source and AttrSource(self.source, "param_groups")
  156. param_groups_vt = LazyVariableTracker.realize_all(
  157. VariableTracker.build(tx, self.value.param_groups, source)
  158. )
  159. for param_group_vt in param_groups_vt.items:
  160. key = ConstDictVariable._HashableTracker(
  161. ConstantVariable.create("capturable")
  162. )
  163. param_group_vt.items[key] = ConstantVariable.create(True)
  164. def get_python_args(self, *args, **kwargs):
  165. """Get python values equivalent to the variable tracker args"""
  166. def map_arg(arg):
  167. if isinstance(arg, ConstantVariable):
  168. return arg.as_python_constant()
  169. elif isinstance(arg, ListVariable) and not arg.items:
  170. return []
  171. elif (
  172. isinstance(arg, ConstDictVariable)
  173. and isinstance(arg.source, GetItemSource)
  174. and isinstance(arg.source.base, AttrSource)
  175. and arg.source.base.member == "param_groups"
  176. ):
  177. return self.value.param_groups[arg.source.index]
  178. raise ArgMappingException
  179. new_args = [map_arg(arg) for arg in args]
  180. new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}
  181. return new_args, new_kwargs
  182. # If users load an old state dictionary,
  183. # it's possible that step could be on the cpu
  184. # if this is the case, move it to the GPU
  185. # corresponding to the parameter
  186. # in most cases this is a no-op because the state is empty
  187. def move_step_if_cpu(self):
  188. for p, state in self.value.state.items():
  189. if "step" in state and state["step"].is_cpu:
  190. state["step"] = state["step"].to(p.device)
  191. def map_sources_and_install_guards(self, tx):
  192. from ..decorators import mark_static_address
  193. from .lazy import LazyVariableTracker
  194. self.grad_to_source = {}
  195. self.tensor_to_source = {}
  196. def mark_static(x):
  197. mark_static_address(x, guard=True)
  198. tree_map_only(torch.Tensor, mark_static, self.value.state)
  199. # Recursively realize the variable trackers for optim.state and
  200. # optim.param_groups, which recursively install the necessary guards.
  201. params_groups_source = self.source and AttrSource(self.source, "param_groups")
  202. param_groups_vt = LazyVariableTracker.realize_all(
  203. VariableTracker.build(tx, self.value.param_groups, params_groups_source)
  204. )
  205. state_source = self.source and AttrSource(self.source, "state")
  206. state_vt = VariableTracker.build(tx, self.value.state, state_source)
  207. # We need to realize the top level state dict to populate
  208. # the guard locals
  209. state_vt.realize()
  210. tx.output.guard_on_key_order.add(state_source)
  211. # Populate self.grad_to_source and self.tensor_to_source so that we can
  212. # manually update_list_args
  213. for group, group_vt in zip(self.value.param_groups, param_groups_vt.items):
  214. # we assume here that all params within a param group
  215. # are initialized similarly
  216. if len(group["params"]) > 0:
  217. for param in group["params"]:
  218. if param.grad is not None:
  219. key_index = None
  220. for i, k in enumerate(self.value.state.keys()):
  221. if k is param:
  222. key_index = i
  223. break
  224. if key_index:
  225. LazyVariableTracker.realize_all(
  226. VariableTracker.build(
  227. tx,
  228. self.value.state[param],
  229. DictGetItemSource(
  230. state_source,
  231. ConstDictKeySource(state_source, key_index),
  232. ),
  233. )
  234. )
  235. break
  236. params_vt = group_vt.getitem_const(tx, ConstantVariable.create("params"))
  237. all_static = True
  238. non_static_grads = []
  239. for p_ind, (p, p_vt) in enumerate(
  240. zip(group["params"], params_vt.unpack_var_sequence(tx))
  241. ):
  242. param_source = p_vt.source
  243. self.tensor_to_source[p] = param_source
  244. grad_source = GradSource(
  245. param_source,
  246. "grad",
  247. )
  248. if p.grad is not None:
  249. self.grad_to_source[p.grad] = grad_source
  250. if not _is_static_for_cudagraphs(p.grad):
  251. all_static = False
  252. non_static_grads.append(grad_source)
  253. else:
  254. install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))
  255. # Note: to avoid spam logs only warn if perf hint artifact is enabled
  256. # (NB: artifacts are only enabled at the debug or warning level)
  257. if not all_static and perf_hint_log.isEnabledFor(logging.DEBUG):
  258. non_static_grads = [src.name() for src in non_static_grads]
  259. perf_hint_log.warning(
  260. (
  261. "Grad tensors %s will be copied during cudagraphs execution."
  262. "If using cudagraphs and the grad tensor addresses will be the same across runs,"
  263. " use torch._dynamo.decorators.mark_static_address to elide this copy.",
  264. ),
  265. non_static_grads,
  266. )
  267. # We have to again iterate over the state dict to collect the
  268. # tensor_to_source dict. This is used for the finalizer.
  269. for idx, (p, value) in enumerate(self.value.state.items()):
  270. p_state_source = DictGetItemSource(
  271. state_source, ConstDictKeySource(state_source, idx)
  272. )
  273. tx.output.guard_on_key_order.add(p_state_source)
  274. for inner_idx, (k, v) in enumerate(value.items()):
  275. if (
  276. isinstance(v, torch.Tensor)
  277. and v not in self.grad_to_source
  278. and v not in self.tensor_to_source
  279. ):
  280. self.tensor_to_source[v] = DictGetItemSource(
  281. p_state_source, ConstDictKeySource(p_state_source, inner_idx)
  282. )
  283. def wrap_tensor(self, tx: "InstructionTranslator", tensor_value):
  284. """Wrap state tensor in a TensorVariable"""
  285. from ..decorators import mark_static_address
  286. # If we have a source for a tensor already use it,
  287. # if we have not seen a tensor before, stash and use a
  288. # global weak ref source, since it must be an optimizer tensor
  289. # that we have missed
  290. if tensor_value in self.tensor_to_source:
  291. # mark these tensors as static for cudagraphs
  292. mark_static_address(tensor_value, guard=True)
  293. source = self.tensor_to_source[tensor_value]
  294. self.static_tensor_names.add(tx.output.module_key_name(source.name()))
  295. elif tensor_value in self.grad_to_source:
  296. source = self.grad_to_source[tensor_value]
  297. else:
  298. # mark these tensors as static for cudagraphs
  299. mark_static_address(tensor_value, guard=True)
  300. global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
  301. source = GlobalWeakRefSource(global_name)
  302. self.static_tensor_names.add(tx.output.module_key_name(source.name()))
  303. return VariableTracker.build(tx, tensor_value, source)
  304. def update_list_args(
  305. self, tx: "InstructionTranslator", args, kwargs, py_args, py_kwargs
  306. ):
  307. """Update the args and kwargs to the traced optimizer call"""
  308. for arg, py_arg in zip(args, py_args):
  309. if isinstance(arg, ListVariable):
  310. assert isinstance(py_arg, list), (
  311. "py_arg should be a list in optimizer variable"
  312. )
  313. for i, val in enumerate(py_arg):
  314. tx.output.side_effects.mutation(arg)
  315. if isinstance(val, torch.Tensor):
  316. arg.items.append(self.wrap_tensor(tx, val))
  317. else:
  318. source = arg.source and GetItemSource(arg.source, i)
  319. arg.items.append(VariableTracker.build(tx, val, source))
  320. def create_finalizer(self, tx):
  321. names_to_delete = self.static_tensor_names
  322. value = self.value
  323. tc = tx.output.tracing_context
  324. def init_finalizer(gm):
  325. def clear_static_tensor_refs():
  326. for name in names_to_delete:
  327. gm._buffers.pop(name, None)
  328. gm._parameters.pop(name, None)
  329. if tc.params_flat:
  330. tc.params_flat.clear()
  331. if tc.params_flat_unwrap_subclasses:
  332. tc.params_flat_unwrap_subclasses.clear()
  333. weakref.finalize(value, clear_static_tensor_refs)
  334. tx.output.add_graph_finalizer(init_finalizer)