adam.py 38 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973
  1. # mypy: allow-untyped-defs
  2. from typing import cast, Optional, Union
  3. import torch
  4. from torch import Tensor
  5. from .optimizer import (
  6. _capturable_doc,
  7. _default_to_fused_or_foreach,
  8. _device_dtype_check_for_fused,
  9. _differentiable_doc,
  10. _disable_dynamo_if_unsupported,
  11. _foreach_doc,
  12. _fused_doc,
  13. _get_capturable_supported_devices,
  14. _get_scalar_dtype,
  15. _get_value,
  16. _maximize_doc,
  17. _params_doc,
  18. _stack_if_compiling,
  19. _to_scalar,
  20. _use_grad_for_differentiable,
  21. _view_as_real,
  22. DeviceDict,
  23. DeviceDtypeDict,
  24. Optimizer,
  25. ParamsT,
  26. )
  27. __all__ = ["Adam", "adam"]
  28. class Adam(Optimizer):
  29. def __init__(
  30. self,
  31. params: ParamsT,
  32. lr: Union[float, Tensor] = 1e-3,
  33. betas: tuple[Union[float, Tensor], Union[float, Tensor]] = (0.9, 0.999),
  34. eps: float = 1e-8,
  35. weight_decay: float = 0,
  36. amsgrad: bool = False,
  37. *,
  38. foreach: Optional[bool] = None,
  39. maximize: bool = False,
  40. capturable: bool = False,
  41. differentiable: bool = False,
  42. fused: Optional[bool] = None,
  43. decoupled_weight_decay: bool = False,
  44. ):
  45. if isinstance(lr, Tensor):
  46. if foreach and not capturable:
  47. raise ValueError(
  48. "lr as a Tensor is not supported for capturable=False and foreach=True"
  49. )
  50. if lr.numel() != 1:
  51. raise ValueError("Tensor lr must be 1-element")
  52. if not 0.0 <= lr:
  53. raise ValueError(f"Invalid learning rate: {lr}")
  54. if not 0.0 <= eps:
  55. raise ValueError(f"Invalid epsilon value: {eps}")
  56. if not 0.0 <= betas[0] < 1.0:
  57. raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
  58. if not 0.0 <= betas[1] < 1.0:
  59. raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
  60. if not 0.0 <= weight_decay:
  61. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  62. if not (
  63. (isinstance(betas[0], float) and isinstance(betas[1], float))
  64. or (isinstance(betas[0], Tensor) and isinstance(betas[1], Tensor))
  65. ):
  66. raise ValueError("betas must be either both floats or both Tensors")
  67. if isinstance(betas[0], Tensor):
  68. if not capturable and foreach:
  69. raise ValueError(
  70. "betas[0] as a Tensor is not supported for capturable=False and foreach=True"
  71. )
  72. if betas[0].numel() != 1:
  73. raise ValueError("Tensor betas[0] must be 1-element")
  74. if isinstance(betas[1], Tensor):
  75. if not capturable and foreach:
  76. raise ValueError(
  77. "betas[1] as a Tensor is not supported for capturable=False and foreach=True"
  78. )
  79. if betas[1].numel() != 1:
  80. raise ValueError("Tensor betas[1] must be 1-element")
  81. defaults = {
  82. "lr": lr,
  83. "betas": betas,
  84. "eps": eps,
  85. "weight_decay": weight_decay,
  86. "amsgrad": amsgrad,
  87. "maximize": maximize,
  88. "foreach": foreach,
  89. "capturable": capturable,
  90. "differentiable": differentiable,
  91. "fused": fused,
  92. "decoupled_weight_decay": decoupled_weight_decay,
  93. }
  94. super().__init__(params, defaults)
  95. if fused:
  96. if differentiable:
  97. raise RuntimeError("`fused` does not support `differentiable`")
  98. self._step_supports_amp_scaling = True
  99. # TODO(crcrpar): [low prec params & their higher prec copy]
  100. # Support AMP with FP16/BF16 model params which would need
  101. # higher prec copy of params to do update math in higher prec to
  102. # alleviate the loss of information.
  103. if foreach:
  104. raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
  105. def __setstate__(self, state):
  106. super().__setstate__(state)
  107. for group in self.param_groups:
  108. group.setdefault("amsgrad", False)
  109. group.setdefault("maximize", False)
  110. group.setdefault("foreach", None)
  111. group.setdefault("capturable", False)
  112. group.setdefault("differentiable", False)
  113. group.setdefault("decoupled_weight_decay", False)
  114. fused = group.setdefault("fused", None)
  115. for p in group["params"]:
  116. p_state = self.state.get(p, [])
  117. if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
  118. step_val = float(p_state["step"])
  119. p_state["step"] = (
  120. torch.tensor(
  121. step_val,
  122. dtype=_get_scalar_dtype(is_fused=fused),
  123. device=p.device,
  124. )
  125. if group["capturable"] or group["fused"]
  126. else torch.tensor(step_val, dtype=_get_scalar_dtype())
  127. )
  128. def _init_group(
  129. self,
  130. group,
  131. params_with_grad,
  132. grads,
  133. exp_avgs,
  134. exp_avg_sqs,
  135. max_exp_avg_sqs,
  136. state_steps,
  137. ):
  138. has_complex = False
  139. for p in group["params"]:
  140. if p.grad is not None:
  141. has_complex |= torch.is_complex(p)
  142. params_with_grad.append(p)
  143. if p.grad.is_sparse:
  144. raise RuntimeError(
  145. "Adam does not support sparse gradients, please consider SparseAdam instead"
  146. )
  147. grads.append(p.grad)
  148. state = self.state[p]
  149. # Lazy state initialization
  150. if len(state) == 0:
  151. if group["fused"]:
  152. _device_dtype_check_for_fused(p)
  153. # note(crcrpar): [special device hosting for step]
  154. # Deliberately host `step` on CPU if both capturable and fused are off.
  155. # This is because kernel launches are costly on CUDA and XLA.
  156. state["step"] = (
  157. torch.zeros(
  158. (),
  159. dtype=_get_scalar_dtype(is_fused=group["fused"]),
  160. device=p.device,
  161. )
  162. if group["capturable"] or group["fused"]
  163. else torch.tensor(0.0, dtype=_get_scalar_dtype())
  164. )
  165. # Exponential moving average of gradient values
  166. state["exp_avg"] = torch.zeros_like(
  167. p, memory_format=torch.preserve_format
  168. )
  169. # Exponential moving average of squared gradient values
  170. state["exp_avg_sq"] = torch.zeros_like(
  171. p, memory_format=torch.preserve_format
  172. )
  173. if group["amsgrad"]:
  174. # Maintains max of all exp. moving avg. of sq. grad. values
  175. state["max_exp_avg_sq"] = torch.zeros_like(
  176. p, memory_format=torch.preserve_format
  177. )
  178. exp_avgs.append(state["exp_avg"])
  179. exp_avg_sqs.append(state["exp_avg_sq"])
  180. if group["amsgrad"]:
  181. max_exp_avg_sqs.append(state["max_exp_avg_sq"])
  182. if group["differentiable"] and state["step"].requires_grad:
  183. raise RuntimeError(
  184. "`requires_grad` is not supported for `step` in differentiable mode"
  185. )
  186. # Foreach without capturable does not support a tensor lr
  187. if (
  188. group["foreach"]
  189. and torch.is_tensor(group["lr"])
  190. and not group["capturable"]
  191. ):
  192. raise RuntimeError(
  193. "lr as a Tensor is not supported for capturable=False and foreach=True"
  194. )
  195. state_steps.append(state["step"])
  196. return has_complex
  197. @_use_grad_for_differentiable
  198. def step(self, closure=None):
  199. """Perform a single optimization step.
  200. Args:
  201. closure (Callable, optional): A closure that reevaluates the model
  202. and returns the loss.
  203. """
  204. self._cuda_graph_capture_health_check()
  205. loss = None
  206. if closure is not None:
  207. with torch.enable_grad():
  208. loss = closure()
  209. for group in self.param_groups:
  210. params_with_grad: list[Tensor] = []
  211. grads: list[Tensor] = []
  212. exp_avgs: list[Tensor] = []
  213. exp_avg_sqs: list[Tensor] = []
  214. max_exp_avg_sqs: list[Tensor] = []
  215. state_steps: list[Tensor] = []
  216. beta1, beta2 = group["betas"]
  217. has_complex = self._init_group(
  218. group,
  219. params_with_grad,
  220. grads,
  221. exp_avgs,
  222. exp_avg_sqs,
  223. max_exp_avg_sqs,
  224. state_steps,
  225. )
  226. adam(
  227. params_with_grad,
  228. grads,
  229. exp_avgs,
  230. exp_avg_sqs,
  231. max_exp_avg_sqs,
  232. state_steps,
  233. amsgrad=group["amsgrad"],
  234. has_complex=has_complex,
  235. beta1=beta1,
  236. beta2=beta2,
  237. lr=group["lr"],
  238. weight_decay=group["weight_decay"],
  239. eps=group["eps"],
  240. maximize=group["maximize"],
  241. foreach=group["foreach"],
  242. capturable=group["capturable"],
  243. differentiable=group["differentiable"],
  244. fused=group["fused"],
  245. grad_scale=getattr(self, "grad_scale", None),
  246. found_inf=getattr(self, "found_inf", None),
  247. decoupled_weight_decay=group["decoupled_weight_decay"],
  248. )
  249. return loss
  250. Adam.__doc__ = (
  251. r"""Implements Adam algorithm.
  252. .. math::
  253. \begin{aligned}
  254. &\rule{110mm}{0.4pt} \\
  255. &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
  256. \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)} \\
  257. &\hspace{13mm} \lambda \text{ (weight decay)}, \: \textit{amsgrad},
  258. \:\textit{maximize}, \: \epsilon \text{ (epsilon)} \\
  259. &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
  260. v_0\leftarrow 0 \text{ (second moment)},\: v_0^{max}\leftarrow 0 \\[-1.ex]
  261. &\rule{110mm}{0.4pt} \\
  262. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  263. &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
  264. &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
  265. &\hspace{5mm}\textbf{else} \\
  266. &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  267. &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
  268. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  269. &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  270. &\hspace{5mm}v_t \leftarrow \beta_2 v_{t-1} + (1-\beta_2) g^2_t \\
  271. &\hspace{5mm}\widehat{m_t} \leftarrow m_t/\big(1-\beta_1^t \big) \\
  272. &\hspace{5mm}\textbf{if} \: amsgrad \\
  273. &\hspace{10mm} v_t^{max} \leftarrow \mathrm{max}(v_{t-1}^{max},v_t) \\
  274. &\hspace{10mm}\widehat{v_t} \leftarrow v_t^{max}/\big(1-\beta_2^t \big) \\
  275. &\hspace{5mm}\textbf{else} \\
  276. &\hspace{10mm}\widehat{v_t} \leftarrow v_t/\big(1-\beta_2^t \big) \\
  277. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma \widehat{m_t}/
  278. \big(\sqrt{\widehat{v_t}} + \epsilon \big) \\
  279. &\rule{110mm}{0.4pt} \\[-1.ex]
  280. &\bf{return} \: \theta_t \\[-1.ex]
  281. &\rule{110mm}{0.4pt} \\[-1.ex]
  282. \end{aligned}
  283. For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
  284. """
  285. + rf"""
  286. Args:
  287. {_params_doc}
  288. lr (float, Tensor, optional): learning rate (default: 1e-3). A tensor LR
  289. is not yet supported for all our implementations. Please use a float
  290. LR if you are not also specifying fused=True or capturable=True.
  291. betas (Tuple[float, float], optional): coefficients used for computing
  292. running averages of gradient and its square (default: (0.9, 0.999))
  293. eps (float, optional): term added to the denominator to improve
  294. numerical stability (default: 1e-8)
  295. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  296. decoupled_weight_decay (bool, optional): if True, this optimizer is
  297. equivalent to AdamW and the algorithm will not accumulate weight
  298. decay in the momentum nor variance. (default: False)
  299. amsgrad (bool, optional): whether to use the AMSGrad variant of this
  300. algorithm from the paper `On the Convergence of Adam and Beyond`_
  301. (default: False)
  302. {_foreach_doc}
  303. {_maximize_doc}
  304. {_capturable_doc}
  305. {_differentiable_doc}
  306. {_fused_doc}
  307. .. Note::
  308. A prototype implementation of Adam and AdamW for MPS supports `torch.float32` and `torch.float16`.
  309. .. _Adam\: A Method for Stochastic Optimization:
  310. https://arxiv.org/abs/1412.6980
  311. .. _On the Convergence of Adam and Beyond:
  312. https://openreview.net/forum?id=ryQu7f-RZ
  313. """
  314. )
  315. def _single_tensor_adam(
  316. params: list[Tensor],
  317. grads: list[Tensor],
  318. exp_avgs: list[Tensor],
  319. exp_avg_sqs: list[Tensor],
  320. max_exp_avg_sqs: list[Tensor],
  321. state_steps: list[Tensor],
  322. grad_scale: Optional[Tensor],
  323. found_inf: Optional[Tensor],
  324. *,
  325. amsgrad: bool,
  326. has_complex: bool,
  327. beta1: Union[float, Tensor],
  328. beta2: Union[float, Tensor],
  329. lr: Union[float, Tensor],
  330. weight_decay: float,
  331. eps: float,
  332. maximize: bool,
  333. capturable: bool,
  334. differentiable: bool,
  335. decoupled_weight_decay: bool,
  336. ):
  337. assert grad_scale is None and found_inf is None
  338. if torch.jit.is_scripting():
  339. # this assert is due to JIT being dumb and not realizing that the ops below
  340. # have overloads to handle both float and Tensor lrs, so we just assert it's
  341. # a float since most people using JIT are using floats
  342. assert isinstance(lr, float)
  343. assert isinstance(beta1, float)
  344. assert isinstance(beta2, float)
  345. else:
  346. lr = _to_scalar(lr)
  347. # TODO: Support nonzero-dim Tensor betas, see #147921
  348. # We only shuffle around the beta when it is a Tensor, otherwise, we prefer
  349. # treating it as a scalar.
  350. # Note: ensure type declaration is under conditional check for isinstance
  351. # or else torchscript will get cranky about the DeviceDict type.
  352. if isinstance(beta1, Tensor):
  353. beta1_dict: Optional[DeviceDtypeDict] = {(beta1.device, beta1.dtype): beta1}
  354. else:
  355. beta1_dict = None
  356. for i, param in enumerate(params):
  357. grad = grads[i] if not maximize else -grads[i]
  358. exp_avg = exp_avgs[i]
  359. exp_avg_sq = exp_avg_sqs[i]
  360. step_t = state_steps[i]
  361. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  362. if not torch.compiler.is_compiling() and capturable:
  363. capturable_supported_devices = _get_capturable_supported_devices()
  364. assert (
  365. param.device.type == step_t.device.type
  366. and param.device.type in capturable_supported_devices
  367. ), (
  368. f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  369. )
  370. # update step
  371. step_t += 1
  372. if weight_decay != 0:
  373. if decoupled_weight_decay:
  374. # Perform stepweight decay
  375. param.mul_(1 - lr * weight_decay)
  376. else:
  377. # Nested if is necessary to bypass jitscript rules
  378. if differentiable and isinstance(weight_decay, Tensor):
  379. if weight_decay.requires_grad:
  380. grad = grad.addcmul_(param.clone(), weight_decay)
  381. else:
  382. grad = grad.add(param, alpha=weight_decay)
  383. else:
  384. grad = grad.add(param, alpha=weight_decay)
  385. if torch.is_complex(param):
  386. grad = torch.view_as_real(grad)
  387. exp_avg = torch.view_as_real(exp_avg)
  388. exp_avg_sq = torch.view_as_real(exp_avg_sq)
  389. if amsgrad:
  390. max_exp_avg_sqs[i] = torch.view_as_real(max_exp_avg_sqs[i])
  391. param = torch.view_as_real(param)
  392. device = param.device
  393. if beta1_dict is not None:
  394. dtype = param.dtype # type: ignore[union-attr]
  395. # cast to workaround https://github.com/pytorch/pytorch/issues/140601
  396. key = (device, dtype)
  397. if key not in beta1_dict:
  398. beta1_dict[key] = beta1.to( # type: ignore[union-attr]
  399. device=device, dtype=dtype, non_blocking=True
  400. )
  401. device_beta1: Union[float, Tensor] = beta1_dict[key]
  402. else:
  403. device_beta1 = beta1
  404. # Decay the first and second moment running average coefficient
  405. exp_avg.lerp_(grad, 1 - device_beta1)
  406. # Nested if is necessary to bypass jitscript rules
  407. if differentiable and isinstance(beta2, Tensor):
  408. if beta2.requires_grad:
  409. # Using lerp to only use 2 operations bc addcmul's value cannot be a tensor
  410. # Showing equivalence of differentiable path and nondifferentiable path
  411. # expavg * b2 + grad^2 * (1-b2)
  412. # add expavg * (1-b2) - expavg * (1-b2) = 0
  413. # expavg * b2 + expavg * (1-b2) - expavg * (1-b2) + grad^2 * (1-b2)
  414. # expavg - expavg * (1-b2) + grad^2 * (1-b2)
  415. # expavg + (grad^2 - expavg) * (1-b2)
  416. # expavg.lerp(grad^2, 1-beta2)
  417. exp_avg_sq.lerp_(torch.square(grad), weight=1 - beta2)
  418. else:
  419. exp_avg_sq.mul_(beta2).addcmul_(
  420. grad, grad, value=cast(float, 1 - beta2)
  421. )
  422. else:
  423. exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2) # type: ignore[arg-type]
  424. if capturable or differentiable:
  425. step = step_t
  426. # Nested if is necessary to bypass jitscript rules
  427. if differentiable and isinstance(beta1, Tensor):
  428. if beta1.requires_grad:
  429. bias_correction1 = 1 - beta1 ** step.clone()
  430. else:
  431. bias_correction1 = 1 - beta1**step
  432. else:
  433. bias_correction1 = 1 - beta1**step
  434. # Nested if is necessary to bypass jitscript rules
  435. if differentiable and isinstance(beta2, Tensor):
  436. if beta2.requires_grad:
  437. bias_correction2 = 1 - beta2 ** step.clone()
  438. else:
  439. bias_correction2 = 1 - beta2**step
  440. else:
  441. bias_correction2 = 1 - beta2**step
  442. step_size = lr / bias_correction1
  443. step_size_neg = step_size.neg()
  444. bias_correction2_sqrt = bias_correction2.sqrt()
  445. if amsgrad:
  446. # Maintains the maximum of all 2nd moment running avg. till now
  447. if differentiable:
  448. max_exp_avg_sq = max_exp_avg_sqs[i].clone()
  449. else:
  450. max_exp_avg_sq = max_exp_avg_sqs[i]
  451. max_exp_avg_sqs[i].copy_(torch.maximum(max_exp_avg_sq, exp_avg_sq))
  452. # Uses the max. for normalizing running avg. of gradient
  453. # Folds in (admittedly ugly) 1-elem step_size math here to avoid extra param-set-sized read+write
  454. # (can't fold it into addcdiv_ below because addcdiv_ requires value is a Number, not a Tensor)
  455. denom = (
  456. max_exp_avg_sqs[i].sqrt() / (bias_correction2_sqrt * step_size_neg)
  457. ).add_(eps / step_size_neg)
  458. else:
  459. denom = (
  460. exp_avg_sq.sqrt() / (bias_correction2_sqrt * step_size_neg)
  461. ).add_(eps / step_size_neg)
  462. if differentiable:
  463. param.addcdiv_(exp_avg.clone(), denom)
  464. else:
  465. param.addcdiv_(exp_avg, denom)
  466. else:
  467. step = _get_value(step_t)
  468. bias_correction1 = 1 - beta1**step
  469. bias_correction2 = 1 - beta2**step
  470. step_size = lr / bias_correction1
  471. bias_correction2_sqrt = bias_correction2**0.5
  472. if amsgrad:
  473. # Maintains the maximum of all 2nd moment running avg. till now
  474. torch.maximum(max_exp_avg_sqs[i], exp_avg_sq, out=max_exp_avg_sqs[i])
  475. # Use the max. for normalizing running avg. of gradient
  476. denom = (max_exp_avg_sqs[i].sqrt() / bias_correction2_sqrt).add_(eps)
  477. else:
  478. denom = (exp_avg_sq.sqrt() / bias_correction2_sqrt).add_(eps)
  479. param.addcdiv_(exp_avg, denom, value=-step_size) # type: ignore[arg-type]
  480. # Lastly, switch back to complex view
  481. if amsgrad and torch.is_complex(params[i]):
  482. max_exp_avg_sqs[i] = torch.view_as_complex(max_exp_avg_sqs[i])
  483. def _multi_tensor_adam(
  484. params: list[Tensor],
  485. grads: list[Tensor],
  486. exp_avgs: list[Tensor],
  487. exp_avg_sqs: list[Tensor],
  488. max_exp_avg_sqs: list[Tensor],
  489. state_steps: list[Tensor],
  490. grad_scale: Optional[Tensor],
  491. found_inf: Optional[Tensor],
  492. *,
  493. amsgrad: bool,
  494. has_complex: bool,
  495. beta1: Union[float, Tensor],
  496. beta2: Union[float, Tensor],
  497. lr: Union[float, Tensor],
  498. weight_decay: float,
  499. eps: float,
  500. maximize: bool,
  501. capturable: bool,
  502. differentiable: bool,
  503. decoupled_weight_decay: bool,
  504. ):
  505. if len(params) == 0:
  506. return
  507. if isinstance(lr, Tensor):
  508. if not capturable:
  509. raise RuntimeError(
  510. "lr as a Tensor is not supported for capturable=False and foreach=True"
  511. )
  512. if lr.numel() != 1:
  513. raise ValueError("Tensor lr must be 1-element")
  514. if isinstance(beta1, Tensor):
  515. if not capturable:
  516. raise ValueError(
  517. "beta1 as a Tensor is not supported for capturable=False and foreach=True"
  518. )
  519. if beta1.numel() != 1:
  520. raise ValueError("Tensor beta1 must be 1-element")
  521. if isinstance(beta2, Tensor):
  522. if not capturable:
  523. raise ValueError(
  524. "beta2 as a Tensor is not supported for capturable=False and foreach=True"
  525. )
  526. if beta2.numel() != 1:
  527. raise ValueError("Tensor beta2 must be 1-element")
  528. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  529. if not torch.compiler.is_compiling() and capturable:
  530. capturable_supported_devices = _get_capturable_supported_devices(
  531. supports_xla=False
  532. )
  533. assert all(
  534. p.device.type == step.device.type
  535. and p.device.type in capturable_supported_devices
  536. for p, step in zip(params, state_steps)
  537. ), (
  538. f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  539. )
  540. assert grad_scale is None and found_inf is None
  541. assert not differentiable, "_foreach ops don't support autograd"
  542. lr = _to_scalar(lr)
  543. # TODO: Support nonzero-dim Tensor betas, see #147921
  544. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  545. [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item]
  546. )
  547. # We only shuffle around the beta when it is a Tensor and on CUDA, otherwise, we prefer
  548. # treating it as a scalar.
  549. beta1_dict: Optional[DeviceDict] = ( # type: ignore[attr-defined]
  550. {beta1.device: beta1}
  551. if isinstance(beta1, Tensor) and str(beta1.device) != "cpu"
  552. else None
  553. )
  554. for (
  555. device_params_,
  556. device_grads_,
  557. device_exp_avgs_,
  558. device_exp_avg_sqs_,
  559. device_max_exp_avg_sqs_,
  560. device_state_steps_,
  561. ), _ in grouped_tensors.values():
  562. device_params = cast(list[Tensor], device_params_)
  563. device_grads = cast(list[Tensor], device_grads_)
  564. device_exp_avgs = cast(list[Tensor], device_exp_avgs_)
  565. device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_)
  566. device_state_steps = cast(list[Tensor], device_state_steps_)
  567. device = device_params[0].device
  568. if beta1_dict is not None and device not in beta1_dict:
  569. beta1_dict[device] = beta1.to(device=device, non_blocking=True) # type: ignore[union-attr, attr-defined]
  570. device_beta1 = beta1_dict[device] if beta1_dict else beta1
  571. # Handle complex parameters
  572. if has_complex:
  573. if amsgrad:
  574. device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
  575. _view_as_real(
  576. device_params,
  577. device_grads,
  578. device_exp_avgs,
  579. device_exp_avg_sqs,
  580. device_max_exp_avg_sqs,
  581. )
  582. else:
  583. _view_as_real(
  584. device_params, device_grads, device_exp_avgs, device_exp_avg_sqs
  585. )
  586. if maximize:
  587. device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
  588. # Update steps
  589. # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
  590. # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
  591. # wrapped it once now. The alpha is required to assure we go to the right overload.
  592. if not torch.compiler.is_compiling() and device_state_steps[0].is_cpu:
  593. torch._foreach_add_(
  594. device_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
  595. )
  596. else:
  597. torch._foreach_add_(device_state_steps, 1)
  598. if weight_decay != 0:
  599. if decoupled_weight_decay:
  600. # Perform stepweight decay
  601. torch._foreach_mul_(device_params, 1 - lr * weight_decay)
  602. else:
  603. # Reuse the intermediate memory (device_grads) already allocated for maximize
  604. if maximize:
  605. torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
  606. else:
  607. device_grads = torch._foreach_add( # type: ignore[assignment]
  608. device_grads, device_params, alpha=weight_decay
  609. )
  610. # Decay the first and second moment running average coefficient
  611. # Use device beta1 if beta1 is a tensor to ensure all
  612. # tensors are on the same device
  613. torch._foreach_lerp_(
  614. device_exp_avgs, device_grads, cast(float, 1 - device_beta1)
  615. )
  616. torch._foreach_mul_(device_exp_avg_sqs, beta2)
  617. # Due to the strictness of the _foreach_addcmul API, we can't have a single
  618. # tensor scalar as the scalar arg (only python number is supported there)
  619. # as a result, separate out the value mul
  620. # Filed https://github.com/pytorch/pytorch/issues/139795
  621. if isinstance(beta2, torch.Tensor):
  622. scaled_device_grads = torch._foreach_mul(device_grads, 1 - beta2) # type: ignore[assignment]
  623. value = 1.0
  624. else:
  625. scaled_device_grads = device_grads # type: ignore[assignment]
  626. value = 1 - beta2
  627. torch._foreach_addcmul_(
  628. device_exp_avg_sqs, scaled_device_grads, device_grads, value
  629. )
  630. # Delete the local intermediate(s) since they won't be used anymore to save on peak memory
  631. del device_grads
  632. del scaled_device_grads
  633. bias_correction1: Union[tuple[Tensor, ...], list[Tensor]]
  634. bias_correction2: Union[tuple[Tensor, ...], list[Tensor]]
  635. bias_correction2_sqrt: Union[tuple[Tensor, ...], list[Tensor]]
  636. if capturable:
  637. bias_correction1 = torch._foreach_pow(beta1, device_state_steps) # type: ignore[arg-type]
  638. bias_correction2 = torch._foreach_pow(beta2, device_state_steps) # type: ignore[arg-type]
  639. # foreach_sub doesn't allow a scalar as the first arg
  640. torch._foreach_sub_(bias_correction1, 1)
  641. torch._foreach_sub_(bias_correction2, 1)
  642. # we do not negate bias_correction1 as it'll need to be negated later anyway
  643. torch._foreach_neg_(bias_correction2)
  644. # foreach_div doesn't allow a scalar as the first arg
  645. torch._foreach_div_(bias_correction1, lr)
  646. torch._foreach_reciprocal_(bias_correction1)
  647. torch._foreach_sqrt_(bias_correction2)
  648. # Re-assign for clarity as we maintain minimal intermediates: we'll have
  649. # step_size = - lr / (1 - beta1 ^ t) where t = num_steps
  650. # bias_correction2_sqrt = sqrt(1 - beta2 ^ t)
  651. step_size = bias_correction1
  652. bias_correction2_sqrt = bias_correction2
  653. if amsgrad:
  654. device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
  655. # Maintains the maximum of all 2nd moment running avg. till now
  656. torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs) # type: ignore[assignment]
  657. # Set intermediate to the max. for normalizing running avg. of gradient when amsgrad
  658. exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
  659. else:
  660. exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
  661. torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
  662. torch._foreach_add_(exp_avg_sq_sqrt, eps)
  663. torch._foreach_div_(exp_avg_sq_sqrt, step_size)
  664. # at this point, exp_avg_sq_sqrt = - (1 - beta^t) * [sqrt(exp_avg_sq / (1 - beta2^t)) + eps] / lr
  665. torch._foreach_addcdiv_(device_params, device_exp_avgs, exp_avg_sq_sqrt)
  666. else:
  667. bias_correction1 = [
  668. 1 - beta1 ** _get_value(step) for step in device_state_steps
  669. ]
  670. bias_correction2 = [
  671. 1 - beta2 ** _get_value(step) for step in device_state_steps
  672. ]
  673. step_size = _stack_if_compiling([(lr / bc) * -1 for bc in bias_correction1])
  674. bias_correction2_sqrt = [bc**0.5 for bc in bias_correction2] # type: ignore[arg-type]
  675. if amsgrad:
  676. device_max_exp_avg_sqs = cast(list[Tensor], device_max_exp_avg_sqs_)
  677. # Maintains the maximum of all 2nd moment running avg. till now
  678. torch._foreach_maximum_(device_max_exp_avg_sqs, device_exp_avg_sqs)
  679. # Use the max. for normalizing running avg. of gradient
  680. exp_avg_sq_sqrt = torch._foreach_sqrt(device_max_exp_avg_sqs)
  681. else:
  682. exp_avg_sq_sqrt = torch._foreach_sqrt(device_exp_avg_sqs)
  683. torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
  684. torch._foreach_add_(exp_avg_sq_sqrt, eps)
  685. torch._foreach_addcdiv_(
  686. device_params,
  687. device_exp_avgs,
  688. exp_avg_sq_sqrt,
  689. step_size, # type: ignore[arg-type]
  690. )
  691. def _fused_adam(
  692. params: list[Tensor],
  693. grads: list[Tensor],
  694. exp_avgs: list[Tensor],
  695. exp_avg_sqs: list[Tensor],
  696. max_exp_avg_sqs: list[Tensor],
  697. state_steps: list[Tensor],
  698. grad_scale: Optional[Tensor],
  699. found_inf: Optional[Tensor],
  700. *,
  701. amsgrad: bool,
  702. has_complex: bool, # Needed for consistency.
  703. beta1: float,
  704. beta2: float,
  705. lr: Union[float, Tensor],
  706. weight_decay: float,
  707. eps: float,
  708. maximize: bool,
  709. capturable: bool, # Needed for consistency.
  710. differentiable: bool,
  711. decoupled_weight_decay: bool,
  712. ) -> None:
  713. if not params:
  714. return
  715. if differentiable:
  716. raise RuntimeError("Adam with fused=True does not support differentiable=True")
  717. grad_scale_dict: DeviceDict = (
  718. {grad_scale.device: grad_scale} if grad_scale is not None else {}
  719. )
  720. found_inf_dict: DeviceDict = (
  721. {found_inf.device: found_inf} if found_inf is not None else {}
  722. )
  723. # We only shuffle around the lr when it is a Tensor and on CUDA, otherwise, we prefer
  724. # treating it as a scalar.
  725. lr_dict: Optional[DeviceDict] = (
  726. {lr.device: lr} if isinstance(lr, Tensor) and str(lr.device) != "cpu" else None
  727. )
  728. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  729. [params, grads, exp_avgs, exp_avg_sqs, max_exp_avg_sqs, state_steps] # type: ignore[list-item]
  730. )
  731. for (device, _), (
  732. (
  733. device_params_,
  734. device_grads_,
  735. device_exp_avgs_,
  736. device_exp_avg_sqs_,
  737. device_max_exp_avg_sqs,
  738. device_state_steps_,
  739. ),
  740. _,
  741. ) in grouped_tensors.items():
  742. device_params = cast(list[Tensor], device_params_)
  743. device_grads = cast(list[Tensor], device_grads_)
  744. device_exp_avgs = cast(list[Tensor], device_exp_avgs_)
  745. device_exp_avg_sqs = cast(list[Tensor], device_exp_avg_sqs_)
  746. device_state_steps = cast(list[Tensor], device_state_steps_)
  747. device_grad_scale, device_found_inf = None, None
  748. if grad_scale is not None:
  749. device_grad_scale = grad_scale_dict.setdefault(
  750. device, grad_scale.to(device, non_blocking=True)
  751. )
  752. if found_inf is not None:
  753. device_found_inf = found_inf_dict.setdefault(
  754. device, found_inf.to(device, non_blocking=True)
  755. )
  756. if lr_dict is not None and device not in lr_dict:
  757. lr_dict[device] = lr.to(device=device, non_blocking=True) # type: ignore[union-attr]
  758. lr = lr_dict[device]
  759. torch._foreach_add_(device_state_steps, 1)
  760. func = torch._fused_adam_ if not decoupled_weight_decay else torch._fused_adamw_
  761. func(
  762. device_params,
  763. device_grads,
  764. device_exp_avgs,
  765. device_exp_avg_sqs,
  766. device_max_exp_avg_sqs, # type: ignore[arg-type]
  767. device_state_steps,
  768. amsgrad=amsgrad,
  769. lr=lr, # type: ignore[arg-type]
  770. beta1=beta1,
  771. beta2=beta2,
  772. weight_decay=weight_decay,
  773. eps=eps,
  774. maximize=maximize,
  775. grad_scale=device_grad_scale,
  776. found_inf=device_found_inf,
  777. )
  778. if device_found_inf is not None:
  779. torch._foreach_sub_(
  780. device_state_steps, [device_found_inf] * len(device_state_steps)
  781. )
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adam)
def adam(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_avg_sqs: list[Tensor],
    max_exp_avg_sqs: list[Tensor],
    state_steps: list[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    capturable: bool = False,
    differentiable: bool = False,
    fused: Optional[bool] = None,
    grad_scale: Optional[Tensor] = None,
    found_inf: Optional[Tensor] = None,
    has_complex: bool = False,
    decoupled_weight_decay: bool = False,
    *,
    amsgrad: bool,
    beta1: float,
    beta2: float,
    lr: Union[float, Tensor],
    weight_decay: float,
    eps: float,
    maximize: bool,
):
    r"""Functional API that performs Adam algorithm computation.

    See :class:`~torch.optim.Adam` for details.

    This function only selects an implementation and forwards every argument
    (other than ``foreach``/``fused``) to it unchanged. It picks
    ``_fused_adam`` when ``fused`` is truthy, else ``_multi_tensor_adam``
    when ``foreach`` is truthy, else ``_single_tensor_adam``.

    Args:
        foreach: ``None`` means "not specified". When ``fused`` is also
            ``None``, the value comes from
            ``_default_to_fused_or_foreach(params, differentiable, use_fused=False)``.
            It is then forced to ``False`` if ``lr`` is a Tensor and
            ``capturable`` is False. Any remaining ``None`` becomes ``False``.
        fused: ``None`` becomes ``False``. If both ``fused`` and ``foreach``
            are truthy, the fused implementation is used.
        state_steps: must contain Tensors. This is checked only when not
            running under ``torch.compiler.is_compiling()``.
        capturable, differentiable, grad_scale, found_inf, has_complex,
        decoupled_weight_decay, amsgrad, beta1, beta2, lr, weight_decay,
        eps, maximize: forwarded to the selected implementation.

    Raises:
        RuntimeError: if ``state_steps`` contains a non-Tensor (outside
            compilation), or if ``foreach`` or ``fused`` is truthy while
            ``torch.jit.is_scripting()`` is True.
    """
    # Respect when the user inputs False/True for foreach or fused. We only want to change
    # the default when neither have been user-specified. Note that we default to foreach
    # and pass False to use_fused. This is not a mistake--we want to give the fused impl
    # bake-in time before making it the default, even if it is typically faster.
    if fused is None and foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )
        # Do not flip on foreach for the unsupported case where lr is a Tensor and capturable=False.
        if foreach and isinstance(lr, Tensor) and not capturable:
            foreach = False
    # Collapse any still-unspecified flag to False so the dispatch below sees plain bools.
    if fused is None:
        fused = False
    if foreach is None:
        foreach = False

    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch.compiler.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    # Explicitly reject the foreach/fused paths under TorchScript.
    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")
    if fused and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with fused optimizers")

    # The repeated `not torch.jit.is_scripting()` guards are presumably there so
    # the TorchScript compiler can drop the foreach/fused branches; keep this form.
    if fused and not torch.jit.is_scripting():
        func = _fused_adam
    elif foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adam
    else:
        func = _single_tensor_adam

    func(
        params,
        grads,
        exp_avgs,
        exp_avg_sqs,
        max_exp_avg_sqs,
        state_steps,
        amsgrad=amsgrad,
        has_complex=has_complex,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        eps=eps,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        grad_scale=grad_scale,
        found_inf=found_inf,
        decoupled_weight_decay=decoupled_weight_decay,
    )