adamax.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482
  1. # mypy: allow-untyped-defs
  2. from typing import cast, Optional, Union
  3. import torch
  4. from torch import Tensor
  5. from .optimizer import (
  6. _capturable_doc,
  7. _default_to_fused_or_foreach,
  8. _differentiable_doc,
  9. _disable_dynamo_if_unsupported,
  10. _foreach_doc,
  11. _get_capturable_supported_devices,
  12. _get_scalar_dtype,
  13. _get_value,
  14. _maximize_doc,
  15. _params_doc,
  16. _to_scalar,
  17. _use_grad_for_differentiable,
  18. _view_as_real,
  19. Optimizer,
  20. ParamsT,
  21. )
# Public names of this module: the optimizer class and its functional form.
__all__ = ["Adamax", "adamax"]
  23. class Adamax(Optimizer):
  24. def __init__(
  25. self,
  26. params: ParamsT,
  27. lr: Union[float, Tensor] = 2e-3,
  28. betas: tuple[float, float] = (0.9, 0.999),
  29. eps: float = 1e-8,
  30. weight_decay: float = 0,
  31. foreach: Optional[bool] = None,
  32. *,
  33. maximize: bool = False,
  34. differentiable: bool = False,
  35. capturable: bool = False,
  36. ):
  37. if isinstance(lr, Tensor) and lr.numel() != 1:
  38. raise ValueError("Tensor lr must be 1-element")
  39. if not 0.0 <= lr:
  40. raise ValueError(f"Invalid learning rate: {lr}")
  41. if not 0.0 <= eps:
  42. raise ValueError(f"Invalid epsilon value: {eps}")
  43. if not 0.0 <= betas[0] < 1.0:
  44. raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
  45. if not 0.0 <= betas[1] < 1.0:
  46. raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
  47. if not 0.0 <= weight_decay:
  48. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  49. defaults = {
  50. "lr": lr,
  51. "betas": betas,
  52. "eps": eps,
  53. "weight_decay": weight_decay,
  54. "foreach": foreach,
  55. "maximize": maximize,
  56. "differentiable": differentiable,
  57. "capturable": capturable,
  58. }
  59. super().__init__(params, defaults)
  60. def __setstate__(self, state):
  61. super().__setstate__(state)
  62. for group in self.param_groups:
  63. group.setdefault("foreach", None)
  64. group.setdefault("maximize", False)
  65. group.setdefault("differentiable", False)
  66. group.setdefault("capturable", False)
  67. for p in group["params"]:
  68. p_state = self.state.get(p, [])
  69. if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
  70. step_val = float(p_state["step"])
  71. p_state["step"] = (
  72. torch.tensor(
  73. step_val, dtype=_get_scalar_dtype(), device=p.device
  74. )
  75. if group["capturable"]
  76. else torch.tensor(step_val, dtype=_get_scalar_dtype())
  77. )
  78. def _init_group(
  79. self, group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
  80. ):
  81. has_complex = False
  82. for p in group["params"]:
  83. if p.grad is None:
  84. continue
  85. has_complex |= torch.is_complex(p)
  86. params_with_grad.append(p)
  87. if p.grad.is_sparse:
  88. raise RuntimeError("Adamax does not support sparse gradients")
  89. grads.append(p.grad)
  90. state = self.state[p]
  91. # State initialization
  92. if len(state) == 0:
  93. state["step"] = (
  94. torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
  95. if group["capturable"]
  96. else torch.tensor(0.0, dtype=_get_scalar_dtype())
  97. )
  98. state["exp_avg"] = torch.zeros_like(
  99. p, memory_format=torch.preserve_format
  100. )
  101. state["exp_inf"] = torch.zeros_like(
  102. p, memory_format=torch.preserve_format
  103. )
  104. exp_avgs.append(state["exp_avg"])
  105. exp_infs.append(state["exp_inf"])
  106. state_steps.append(state["step"])
  107. return has_complex
  108. @_use_grad_for_differentiable
  109. def step(self, closure=None):
  110. """Performs a single optimization step.
  111. Args:
  112. closure (Callable, optional): A closure that reevaluates the model
  113. and returns the loss.
  114. """
  115. self._cuda_graph_capture_health_check()
  116. loss = None
  117. if closure is not None:
  118. with torch.enable_grad():
  119. loss = closure()
  120. for group in self.param_groups:
  121. params_with_grad: list[Tensor] = []
  122. grads: list[Tensor] = []
  123. exp_avgs: list[Tensor] = []
  124. exp_infs: list[Tensor] = []
  125. state_steps: list[Tensor] = []
  126. beta1, beta2 = group["betas"]
  127. eps = group["eps"]
  128. lr = group["lr"]
  129. weight_decay = group["weight_decay"]
  130. foreach = group["foreach"]
  131. maximize = group["maximize"]
  132. differentiable = group["differentiable"]
  133. capturable = group["capturable"]
  134. has_complex = self._init_group(
  135. group, params_with_grad, grads, exp_avgs, exp_infs, state_steps
  136. )
  137. adamax(
  138. params_with_grad,
  139. grads,
  140. exp_avgs,
  141. exp_infs,
  142. state_steps,
  143. eps=eps,
  144. beta1=beta1,
  145. beta2=beta2,
  146. lr=lr,
  147. weight_decay=weight_decay,
  148. foreach=foreach,
  149. maximize=maximize,
  150. differentiable=differentiable,
  151. capturable=capturable,
  152. has_complex=has_complex,
  153. )
  154. return loss
  155. Adamax.__doc__ = (
  156. r"""Implements Adamax algorithm (a variant of Adam based on infinity norm).
  157. .. math::
  158. \begin{aligned}
  159. &\rule{110mm}{0.4pt} \\
  160. &\textbf{input} : \gamma \text{ (lr)}, \beta_1, \beta_2
  161. \text{ (betas)},\theta_0 \text{ (params)},f(\theta) \text{ (objective)},
  162. \: \lambda \text{ (weight decay)}, \\
  163. &\hspace{13mm} \epsilon \text{ (epsilon)} \\
  164. &\textbf{initialize} : m_0 \leftarrow 0 \text{ ( first moment)},
  165. u_0 \leftarrow 0 \text{ ( infinity norm)} \\[-1.ex]
  166. &\rule{110mm}{0.4pt} \\
  167. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  168. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  169. &\hspace{5mm}if \: \lambda \neq 0 \\
  170. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  171. &\hspace{5mm}m_t \leftarrow \beta_1 m_{t-1} + (1 - \beta_1) g_t \\
  172. &\hspace{5mm}u_t \leftarrow \mathrm{max}(\beta_2 u_{t-1}, |g_{t}|+\epsilon) \\
  173. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \frac{\gamma m_t}{(1-\beta^t_1) u_t} \\
  174. &\rule{110mm}{0.4pt} \\[-1.ex]
  175. &\bf{return} \: \theta_t \\[-1.ex]
  176. &\rule{110mm}{0.4pt} \\[-1.ex]
  177. \end{aligned}
  178. For further details regarding the algorithm we refer to `Adam: A Method for Stochastic Optimization`_.
  179. """
  180. + rf"""
  181. Args:
  182. {_params_doc}
  183. lr (float, Tensor, optional): learning rate (default: 2e-3)
  184. betas (Tuple[float, float], optional): coefficients used for computing
  185. running averages of gradient and its square
  186. eps (float, optional): term added to the denominator to improve
  187. numerical stability (default: 1e-8)
  188. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  189. {_foreach_doc}
  190. {_maximize_doc}
  191. {_differentiable_doc}
  192. {_capturable_doc}
  193. .. _Adam\: A Method for Stochastic Optimization:
  194. https://arxiv.org/abs/1412.6980
  195. """
  196. )
  197. def _single_tensor_adamax(
  198. params: list[Tensor],
  199. grads: list[Tensor],
  200. exp_avgs: list[Tensor],
  201. exp_infs: list[Tensor],
  202. state_steps: list[Tensor],
  203. *,
  204. eps: float,
  205. beta1: float,
  206. beta2: float,
  207. lr: float,
  208. weight_decay: float,
  209. maximize: bool,
  210. differentiable: bool,
  211. capturable: bool,
  212. has_complex: bool,
  213. ):
  214. if not torch.jit.is_scripting():
  215. lr = _to_scalar(lr)
  216. for i, param in enumerate(params):
  217. grad = grads[i]
  218. grad = grad if not maximize else -grad
  219. exp_avg = exp_avgs[i]
  220. exp_inf = exp_infs[i]
  221. step_t = state_steps[i]
  222. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  223. if not torch.compiler.is_compiling() and capturable:
  224. capturable_supported_devices = _get_capturable_supported_devices()
  225. assert (
  226. param.device.type == step_t.device.type
  227. and param.device.type in capturable_supported_devices
  228. ), (
  229. f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  230. )
  231. # update step
  232. step_t += 1
  233. if weight_decay != 0:
  234. grad = grad.add(param, alpha=weight_decay)
  235. if torch.is_complex(param):
  236. param = torch.view_as_real(param)
  237. grad = torch.view_as_real(grad)
  238. exp_avg = torch.view_as_real(exp_avg)
  239. exp_inf = torch.view_as_real(exp_inf)
  240. # Update biased first moment estimate.
  241. exp_avg.lerp_(grad, 1 - beta1)
  242. # Update the exponentially weighted infinity norm.
  243. if not differentiable:
  244. torch.maximum(
  245. exp_inf.mul_(beta2),
  246. grad.abs().add_(eps),
  247. out=exp_inf,
  248. )
  249. else:
  250. norm_buf = torch.cat(
  251. [exp_inf.mul_(beta2).unsqueeze(0), grad.abs().add_(eps).unsqueeze_(0)],
  252. 0,
  253. )
  254. exp_inf.copy_(torch.amax(norm_buf, 0, keepdim=False))
  255. if capturable:
  256. # why jump through extra hoops and negate bias_correction? check out #121238
  257. # once fixed, we should use bias_correction with addcdiv value=-1 for readability
  258. neg_bias_correction = beta1**step_t - 1
  259. neg_bias_correction.div_(lr)
  260. denom = exp_inf * neg_bias_correction
  261. param.addcdiv_(exp_avg, denom)
  262. else:
  263. bias_correction = 1 - beta1 ** _get_value(step_t)
  264. clr = lr / bias_correction
  265. param.addcdiv_(exp_avg, exp_inf, value=-clr)
def _multi_tensor_adamax(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_infs: list[Tensor],
    state_steps: list[Tensor],
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
    maximize: bool,
    differentiable: bool,
    capturable: bool,
    has_complex: bool,
):
    """Apply one Adamax update using batched ``torch._foreach_*`` kernels.

    Tensors are bucketed with ``Optimizer._group_tensors_by_device_and_dtype``.
    Each bucket is then updated with one set of foreach calls. ``params``,
    ``exp_avgs``, ``exp_infs`` and ``state_steps`` are modified in place. The
    in-place ``abs`` / ``+eps`` ops below only ever run on buffers this
    function allocated itself (via ``_foreach_neg``, ``_foreach_add`` or
    ``_foreach_abs``), never on the incoming gradient tensors.

    Raises:
        AssertionError: if ``differentiable`` is True, or (outside
            torch.compile) if ``capturable`` is True and a param/step pair is
            not on a supported device.
    """
    assert not differentiable, "_foreach ops don't support autograd"

    # Nothing to update.
    if len(params) == 0:
        return

    # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
    if not torch.compiler.is_compiling() and capturable:
        capturable_supported_devices = _get_capturable_supported_devices(
            supports_xla=False
        )
        assert all(
            p.device.type == step.device.type
            and p.device.type in capturable_supported_devices
            for p, step in zip(params, state_steps)
        ), (
            f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
        )

    # ``lr`` may arrive as a Tensor (Adamax accepts Union[float, Tensor]);
    # normalize it once via the shared helper.
    lr = _to_scalar(lr)

    # Bucket every per-parameter list by (device, dtype); each bucket is
    # updated below with its own foreach calls.
    grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
        [params, grads, exp_avgs, exp_infs, state_steps]  # type: ignore[list-item]
    )
    for (
        grouped_params_,
        grouped_grads_,
        grouped_exp_avgs_,
        grouped_exp_infs_,
        grouped_state_steps_,
    ), _ in grouped_tensors.values():
        grouped_params = cast(list[Tensor], grouped_params_)
        grouped_grads = cast(list[Tensor], grouped_grads_)
        grouped_exp_avgs = cast(list[Tensor], grouped_exp_avgs_)
        grouped_exp_infs = cast(list[Tensor], grouped_exp_infs_)
        grouped_state_steps = cast(list[Tensor], grouped_state_steps_)

        # NOTE(review): presumably swaps complex entries for their real views,
        # mirroring the explicit torch.view_as_real calls in the single-tensor path.
        if has_complex:
            _view_as_real(
                grouped_params, grouped_grads, grouped_exp_avgs, grouped_exp_infs
            )

        # g <- -g. _foreach_neg returns new tensors, so from here on
        # grouped_grads no longer aliases the incoming gradients.
        if maximize:
            grouped_grads = torch._foreach_neg(grouped_grads)  # type: ignore[assignment]

        # Update steps
        # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
        # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
        # wrapped it once now. The alpha is required to assure we go to the right overload.
        if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
            torch._foreach_add_(
                grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
            )
        else:
            torch._foreach_add_(grouped_state_steps, 1)

        # L2 penalty: g <- g + weight_decay * param.
        if weight_decay != 0:
            if maximize:
                # Reuse the intermediate memory (grouped_grads) already allocated for maximize
                torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
            else:
                # Out-of-place, so the incoming gradients stay untouched.
                grouped_grads = torch._foreach_add(  # type: ignore[assignment]
                    grouped_grads, grouped_params, alpha=weight_decay
                )

        # Update biased first moment estimate.
        torch._foreach_lerp_(grouped_exp_avgs, grouped_grads, 1 - beta1)

        # Update the exponentially weighted infinity norm.
        torch._foreach_mul_(grouped_exp_infs, beta2)

        # in this case, we need to introduce a copy of the grads
        # since one has not been introduced previously
        # (in the other cases grouped_grads is already a buffer allocated above,
        # so taking abs in place is safe).
        if not maximize and weight_decay == 0:
            grouped_grads = torch._foreach_abs(grouped_grads)  # type: ignore[assignment]
        else:
            torch._foreach_abs_(grouped_grads)

        # exp_inf <- max(beta2 * exp_inf, |g| + eps), elementwise.
        torch._foreach_add_(grouped_grads, eps)
        torch._foreach_maximum_(grouped_exp_infs, grouped_grads)

        bias_corrections: Union[tuple[Tensor, ...], list[Tensor]]
        if capturable:
            # Tensor path: bias_corrections becomes (beta1**t - 1) / lr, so
            # dividing by exp_inf * bias_corrections in a plain addcdiv_ applies
            # param - lr / (1 - beta1**t) * exp_avg / exp_inf.
            bias_corrections = torch._foreach_pow(beta1, grouped_state_steps)
            # foreach_sub doesn't allow a scalar as the first arg
            torch._foreach_sub_(bias_corrections, 1)
            torch._foreach_div_(bias_corrections, lr)

            denom = torch._foreach_mul(grouped_exp_infs, bias_corrections)
            torch._foreach_addcdiv_(grouped_params, grouped_exp_avgs, denom)
        else:
            # One step size per tensor: -lr / (1 - beta1**t).
            bias_corrections = [
                1 - beta1 ** _get_value(step) for step in grouped_state_steps
            ]
            step_size = [(_get_value(lr) / bc) * -1 for bc in bias_corrections]
            torch._foreach_addcdiv_(
                grouped_params, grouped_exp_avgs, grouped_exp_infs, step_size
            )
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_adamax)
def adamax(
    params: list[Tensor],
    grads: list[Tensor],
    exp_avgs: list[Tensor],
    exp_infs: list[Tensor],
    state_steps: list[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    eps: float,
    beta1: float,
    beta2: float,
    lr: float,
    weight_decay: float,
):
    r"""Functional API that performs adamax algorithm computation.

    See :class:`~torch.optim.Adamax` for details.

    Dispatches to the foreach (multi-tensor) implementation when ``foreach``
    is true and not running under TorchScript. Otherwise it uses the
    single-tensor implementation. When ``foreach`` is None, the choice comes
    from ``_default_to_fused_or_foreach``.

    Raises:
        RuntimeError: if an entry of ``state_steps`` is not a Tensor (checked
            only outside torch.compile), or if ``foreach`` is requested under
            ``torch.jit.script``.
    """
    # Step counters must be tensors; plain Python numbers are rejected.
    if not torch.compiler.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # The scripting case already raised above. The repeated is_scripting()
    # check presumably lets TorchScript statically drop the foreach branch.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_adamax
    else:
        func = _single_tensor_adamax

    func(
        params,
        grads,
        exp_avgs,
        exp_infs,
        state_steps,
        eps=eps,
        beta1=beta1,
        beta2=beta2,
        lr=lr,
        weight_decay=weight_decay,
        maximize=maximize,
        differentiable=differentiable,
        has_complex=has_complex,
        capturable=capturable,
    )