rmsprop.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539
  1. # mypy: allow-untyped-defs
  2. r"""Implementation for the RMSprop algorithm."""
  3. from typing import cast, Optional, Union
  4. import torch
  5. from torch import Tensor
  6. from .optimizer import (
  7. _capturable_doc,
  8. _default_to_fused_or_foreach,
  9. _differentiable_doc,
  10. _disable_dynamo_if_unsupported,
  11. _foreach_doc,
  12. _get_capturable_supported_devices,
  13. _get_scalar_dtype,
  14. _maximize_doc,
  15. _params_doc,
  16. _to_scalar,
  17. _use_grad_for_differentiable,
  18. _view_as_real,
  19. Optimizer,
  20. ParamsT,
  21. )
  22. __all__ = ["RMSprop", "rmsprop"]
  23. class RMSprop(Optimizer): # noqa: D101
  24. def __init__(
  25. self,
  26. params: ParamsT,
  27. lr: Union[float, Tensor] = 1e-2,
  28. alpha: float = 0.99,
  29. eps: float = 1e-8,
  30. weight_decay: float = 0,
  31. momentum: float = 0,
  32. centered: bool = False,
  33. capturable: bool = False,
  34. foreach: Optional[bool] = None,
  35. maximize: bool = False,
  36. differentiable: bool = False,
  37. ): # noqa: D107
  38. if isinstance(lr, Tensor) and lr.numel() != 1:
  39. raise ValueError("Tensor lr must be 1-element")
  40. if not 0.0 <= lr:
  41. raise ValueError(f"Invalid learning rate: {lr}")
  42. if not 0.0 <= eps:
  43. raise ValueError(f"Invalid epsilon value: {eps}")
  44. if not 0.0 <= momentum:
  45. raise ValueError(f"Invalid momentum value: {momentum}")
  46. if not 0.0 <= weight_decay:
  47. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  48. if not 0.0 <= alpha:
  49. raise ValueError(f"Invalid alpha value: {alpha}")
  50. defaults = {
  51. "lr": lr,
  52. "momentum": momentum,
  53. "alpha": alpha,
  54. "eps": eps,
  55. "centered": centered,
  56. "weight_decay": weight_decay,
  57. "capturable": capturable,
  58. "foreach": foreach,
  59. "maximize": maximize,
  60. "differentiable": differentiable,
  61. }
  62. super().__init__(params, defaults)
  63. def __setstate__(self, state): # noqa: D105
  64. super().__setstate__(state)
  65. for group in self.param_groups:
  66. group.setdefault("momentum", 0)
  67. group.setdefault("centered", False)
  68. group.setdefault("foreach", None)
  69. group.setdefault("maximize", False)
  70. group.setdefault("differentiable", False)
  71. group.setdefault("capturable", False)
  72. for p in group["params"]:
  73. p_state = self.state.get(p, [])
  74. if len(p_state) != 0 and not torch.is_tensor(p_state["step"]):
  75. step_val = float(p_state["step"])
  76. p_state["step"] = (
  77. torch.tensor(
  78. step_val, dtype=_get_scalar_dtype(), device=p.device
  79. )
  80. if group["capturable"]
  81. else torch.tensor(step_val, dtype=_get_scalar_dtype())
  82. )
  83. def _init_group(
  84. self,
  85. group,
  86. params_with_grad,
  87. grads,
  88. square_avgs,
  89. momentum_buffer_list,
  90. grad_avgs,
  91. state_steps,
  92. ):
  93. has_complex = False
  94. for p in group["params"]:
  95. if p.grad is None:
  96. continue
  97. has_complex |= torch.is_complex(p)
  98. params_with_grad.append(p)
  99. if p.grad.is_sparse:
  100. raise RuntimeError("RMSprop does not support sparse gradients")
  101. grads.append(p.grad)
  102. state = self.state[p]
  103. # State initialization
  104. if len(state) == 0:
  105. state["step"] = (
  106. torch.zeros((), dtype=_get_scalar_dtype(), device=p.device)
  107. if group["capturable"]
  108. else torch.zeros((), dtype=_get_scalar_dtype())
  109. )
  110. state["square_avg"] = torch.zeros_like(
  111. p, memory_format=torch.preserve_format
  112. )
  113. if group["momentum"] > 0:
  114. state["momentum_buffer"] = torch.zeros_like(
  115. p, memory_format=torch.preserve_format
  116. )
  117. if group["centered"]:
  118. state["grad_avg"] = torch.zeros_like(
  119. p, memory_format=torch.preserve_format
  120. )
  121. square_avgs.append(state["square_avg"])
  122. state_steps.append(state["step"])
  123. if group["momentum"] > 0:
  124. momentum_buffer_list.append(state["momentum_buffer"])
  125. if group["centered"]:
  126. grad_avgs.append(state["grad_avg"])
  127. return has_complex
  128. @_use_grad_for_differentiable
  129. def step(self, closure=None):
  130. """Perform a single optimization step.
  131. Args:
  132. closure (Callable, optional): A closure that reevaluates the model
  133. and returns the loss.
  134. """
  135. self._cuda_graph_capture_health_check()
  136. loss = None
  137. if closure is not None:
  138. with torch.enable_grad():
  139. loss = closure()
  140. for group in self.param_groups:
  141. params_with_grad: list[Tensor] = []
  142. grads: list[Tensor] = []
  143. square_avgs: list[Tensor] = []
  144. grad_avgs: list[Tensor] = []
  145. momentum_buffer_list: list[Tensor] = []
  146. state_steps: list[Tensor] = []
  147. has_complex = self._init_group(
  148. group,
  149. params_with_grad,
  150. grads,
  151. square_avgs,
  152. momentum_buffer_list,
  153. grad_avgs,
  154. state_steps,
  155. )
  156. rmsprop(
  157. params_with_grad,
  158. grads,
  159. square_avgs,
  160. grad_avgs,
  161. momentum_buffer_list,
  162. state_steps,
  163. lr=group["lr"],
  164. alpha=group["alpha"],
  165. eps=group["eps"],
  166. weight_decay=group["weight_decay"],
  167. momentum=group["momentum"],
  168. centered=group["centered"],
  169. foreach=group["foreach"],
  170. maximize=group["maximize"],
  171. differentiable=group["differentiable"],
  172. capturable=group["capturable"],
  173. has_complex=has_complex,
  174. )
  175. return loss
  176. RMSprop.__doc__ = (
  177. r"""Implements RMSprop algorithm.
  178. .. math::
  179. \begin{aligned}
  180. &\rule{110mm}{0.4pt} \\
  181. &\textbf{input} : \alpha \text{ (alpha)}, \: \gamma \text{ (lr)},
  182. \: \theta_0 \text{ (params)}, \: f(\theta) \text{ (objective)} \\
  183. &\hspace{13mm} \lambda \text{ (weight decay)},\: \mu \text{ (momentum)},
  184. \: centered, \: \epsilon \text{ (epsilon)} \\
  185. &\textbf{initialize} : v_0 \leftarrow 0 \text{ (square average)}, \:
  186. \textbf{b}_0 \leftarrow 0 \text{ (buffer)}, \: g^{ave}_0 \leftarrow 0 \\[-1.ex]
  187. &\rule{110mm}{0.4pt} \\
  188. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  189. &\hspace{5mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  190. &\hspace{5mm}if \: \lambda \neq 0 \\
  191. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  192. &\hspace{5mm}v_t \leftarrow \alpha v_{t-1} + (1 - \alpha) g^2_t
  193. \hspace{8mm} \\
  194. &\hspace{5mm} \tilde{v_t} \leftarrow v_t \\
  195. &\hspace{5mm}if \: centered \\
  196. &\hspace{10mm} g^{ave}_t \leftarrow g^{ave}_{t-1} \alpha + (1-\alpha) g_t \\
  197. &\hspace{10mm} \tilde{v_t} \leftarrow \tilde{v_t} - \big(g^{ave}_{t} \big)^2 \\
  198. &\hspace{5mm}if \: \mu > 0 \\
  199. &\hspace{10mm} \textbf{b}_t\leftarrow \mu \textbf{b}_{t-1} +
  200. g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \\
  201. &\hspace{10mm} \theta_t \leftarrow \theta_{t-1} - \gamma \textbf{b}_t \\
  202. &\hspace{5mm} else \\
  203. &\hspace{10mm}\theta_t \leftarrow \theta_{t-1} -
  204. \gamma g_t/ \big(\sqrt{\tilde{v_t}} + \epsilon \big) \hspace{3mm} \\
  205. &\rule{110mm}{0.4pt} \\[-1.ex]
  206. &\bf{return} \: \theta_t \\[-1.ex]
  207. &\rule{110mm}{0.4pt} \\[-1.ex]
  208. \end{aligned}
  209. For further details regarding the algorithm we refer to
  210. `lecture notes <https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf>`_ by G. Hinton.
  211. and centered version `Generating Sequences
  212. With Recurrent Neural Networks <https://arxiv.org/pdf/1308.0850v5.pdf>`_.
  213. The implementation here takes the square root of the gradient average before
  214. adding epsilon (note that TensorFlow interchanges these two operations). The effective
  215. learning rate is thus :math:`\gamma/(\sqrt{v} + \epsilon)` where :math:`\gamma`
  216. is the scheduled learning rate and :math:`v` is the weighted moving average
  217. of the squared gradient.
  218. """
  219. + rf"""
  220. Args:
  221. {_params_doc}
  222. lr (float, Tensor, optional): learning rate (default: 1e-2)
  223. alpha (float, optional): smoothing constant (default: 0.99)
  224. eps (float, optional): term added to the denominator to improve
  225. numerical stability (default: 1e-8)
  226. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  227. momentum (float, optional): momentum factor (default: 0)
  228. centered (bool, optional) : if ``True``, compute the centered RMSProp,
  229. the gradient is normalized by an estimation of its variance
  230. {_capturable_doc}
  231. {_foreach_doc}
  232. {_maximize_doc}
  233. {_differentiable_doc}
  234. """
  235. )
  236. def _single_tensor_rmsprop(
  237. params: list[Tensor],
  238. grads: list[Tensor],
  239. square_avgs: list[Tensor],
  240. grad_avgs: list[Tensor],
  241. momentum_buffer_list: list[Tensor],
  242. state_steps: list[Tensor],
  243. *,
  244. lr: float,
  245. alpha: float,
  246. eps: float,
  247. weight_decay: float,
  248. momentum: float,
  249. centered: bool,
  250. maximize: bool,
  251. differentiable: bool,
  252. capturable: bool,
  253. has_complex: bool,
  254. ):
  255. if not torch.jit.is_scripting():
  256. lr = _to_scalar(lr)
  257. for i, param in enumerate(params):
  258. step = state_steps[i]
  259. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  260. if not torch.compiler.is_compiling() and capturable:
  261. capturable_supported_devices = _get_capturable_supported_devices()
  262. assert (
  263. param.device.type == step.device.type
  264. and param.device.type in capturable_supported_devices
  265. ), (
  266. f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  267. )
  268. grad = grads[i]
  269. grad = grad if not maximize else -grad
  270. square_avg = square_avgs[i]
  271. step += 1
  272. if weight_decay != 0:
  273. grad = grad.add(param, alpha=weight_decay)
  274. is_complex_param = torch.is_complex(param)
  275. if is_complex_param:
  276. param = torch.view_as_real(param)
  277. grad = torch.view_as_real(grad)
  278. square_avg = torch.view_as_real(square_avg)
  279. square_avg.mul_(alpha).addcmul_(grad, grad, value=1 - alpha)
  280. if centered:
  281. grad_avg = grad_avgs[i]
  282. if is_complex_param:
  283. grad_avg = torch.view_as_real(grad_avg)
  284. grad_avg.lerp_(grad, 1 - alpha)
  285. avg = square_avg.addcmul(grad_avg, grad_avg, value=-1).sqrt_()
  286. else:
  287. avg = square_avg.sqrt()
  288. if differentiable:
  289. avg = avg.add(eps)
  290. else:
  291. avg = avg.add_(eps)
  292. if momentum > 0:
  293. buf = momentum_buffer_list[i]
  294. if is_complex_param:
  295. buf = torch.view_as_real(buf)
  296. buf.mul_(momentum).addcdiv_(grad, avg)
  297. param.add_(buf, alpha=-lr)
  298. else:
  299. param.addcdiv_(grad, avg, value=-lr)
  300. def _multi_tensor_rmsprop(
  301. params: list[Tensor],
  302. grads: list[Tensor],
  303. square_avgs: list[Tensor],
  304. grad_avgs: list[Tensor],
  305. momentum_buffer_list: list[Tensor],
  306. state_steps: list[Tensor],
  307. *,
  308. lr: float,
  309. alpha: float,
  310. eps: float,
  311. weight_decay: float,
  312. momentum: float,
  313. centered: bool,
  314. maximize: bool,
  315. differentiable: bool,
  316. capturable: bool,
  317. has_complex: bool,
  318. ):
  319. if len(params) == 0:
  320. return
  321. assert not differentiable, "_foreach ops don't support autograd"
  322. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  323. if not torch.compiler.is_compiling() and capturable:
  324. capturable_supported_devices = _get_capturable_supported_devices()
  325. assert all(
  326. p.device.type == step.device.type
  327. and p.device.type in capturable_supported_devices
  328. for p, step in zip(params, state_steps)
  329. ), (
  330. f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
  331. )
  332. lr = _to_scalar(lr)
  333. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  334. [params, grads, square_avgs, grad_avgs, momentum_buffer_list, state_steps] # type: ignore[list-item]
  335. )
  336. for (
  337. (
  338. grouped_params_,
  339. grouped_grads_,
  340. grouped_square_avgs_,
  341. grouped_grad_avgs_,
  342. grouped_momentum_buffer_list_,
  343. grouped_state_steps_,
  344. )
  345. ), _ in grouped_tensors.values():
  346. grouped_params = cast(list[Tensor], grouped_params_)
  347. grouped_grads = cast(list[Tensor], grouped_grads_)
  348. grouped_square_avgs = cast(list[Tensor], grouped_square_avgs_)
  349. grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
  350. if has_complex:
  351. state_and_grads = [grouped_grads, grouped_square_avgs]
  352. if momentum > 0:
  353. grouped_momentum_buffer_list = cast(
  354. list[Tensor], grouped_momentum_buffer_list_
  355. )
  356. state_and_grads.append(grouped_momentum_buffer_list)
  357. if centered:
  358. grouped_grad_avgs = cast(list[Tensor], grouped_grad_avgs_)
  359. state_and_grads.append(grouped_grad_avgs)
  360. _view_as_real(grouped_params, *state_and_grads)
  361. if maximize:
  362. grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment]
  363. # Update steps
  364. # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
  365. # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
  366. # wrapped it once now. The alpha is required to assure we go to the right overload.
  367. if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
  368. torch._foreach_add_(
  369. grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
  370. )
  371. else:
  372. torch._foreach_add_(grouped_state_steps, 1)
  373. if weight_decay != 0:
  374. # Reuse the intermediate memory (grouped_grads) already allocated for maximize
  375. if maximize:
  376. torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
  377. else:
  378. grouped_grads = torch._foreach_add( # type: ignore[assignment]
  379. grouped_grads, grouped_params, alpha=weight_decay
  380. )
  381. torch._foreach_mul_(grouped_square_avgs, alpha)
  382. torch._foreach_addcmul_(
  383. grouped_square_avgs, grouped_grads, grouped_grads, value=1 - alpha
  384. )
  385. if centered:
  386. grouped_grad_avgs = cast(list[Tensor], grouped_grad_avgs_)
  387. torch._foreach_lerp_(grouped_grad_avgs, grouped_grads, 1 - alpha)
  388. avg = torch._foreach_addcmul(
  389. grouped_square_avgs, grouped_grad_avgs, grouped_grad_avgs, value=-1
  390. )
  391. torch._foreach_sqrt_(avg)
  392. torch._foreach_add_(avg, eps)
  393. else:
  394. avg = torch._foreach_sqrt(grouped_square_avgs)
  395. torch._foreach_add_(avg, eps)
  396. if momentum > 0:
  397. grouped_momentum_buffer_list = cast(
  398. list[Tensor], grouped_momentum_buffer_list_
  399. )
  400. torch._foreach_mul_(grouped_momentum_buffer_list, momentum)
  401. torch._foreach_addcdiv_(grouped_momentum_buffer_list, grouped_grads, avg)
  402. # If LR is a tensor, the else branch will internally call item()
  403. # which will cause silent incorrectness if we are capturing
  404. if capturable and isinstance(lr, torch.Tensor):
  405. momentum_lr = torch._foreach_mul(grouped_momentum_buffer_list, -lr)
  406. torch._foreach_add_(grouped_params, momentum_lr)
  407. else:
  408. torch._foreach_add_(
  409. grouped_params, grouped_momentum_buffer_list, alpha=-lr
  410. )
  411. else:
  412. # If LR is a tensor, the else branch will internally call item()
  413. # which will cause silent incorrectness if we are capturing
  414. if capturable and isinstance(lr, torch.Tensor):
  415. torch._foreach_div_(avg, -lr)
  416. torch._foreach_addcdiv_(grouped_params, grouped_grads, avg)
  417. else:
  418. torch._foreach_addcdiv_(grouped_params, grouped_grads, avg, value=-lr)
@_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_rmsprop)
def rmsprop(
    params: list[Tensor],
    grads: list[Tensor],
    square_avgs: list[Tensor],
    grad_avgs: list[Tensor],
    momentum_buffer_list: list[Tensor],
    state_steps: list[Tensor],
    # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
    # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
    foreach: Optional[bool] = None,
    maximize: bool = False,
    differentiable: bool = False,
    capturable: bool = False,
    has_complex: bool = False,
    *,
    lr: float,
    alpha: float,
    eps: float,
    weight_decay: float,
    momentum: float,
    centered: bool,
):
    r"""Functional API that performs rmsprop algorithm computation.

    See :class:`~torch.optim.RMSprop` for details.

    Chooses between the multi-tensor (``foreach``) and single-tensor
    implementations and forwards every argument to the chosen one. That
    implementation updates ``params`` and the state tensors in place.

    Args:
        foreach: use the ``torch._foreach_*`` implementation. ``None`` lets
            ``_default_to_fused_or_foreach`` decide (fused is never requested).

    Raises:
        RuntimeError: if ``state_steps`` holds a non-tensor entry (checked only
            when not compiling), or if ``foreach`` is enabled under TorchScript.
    """
    # this check is slow during compilation, so we skip it
    # if it's strictly needed we can add this check back in dynamo
    if not torch.compiler.is_compiling() and not all(
        isinstance(t, torch.Tensor) for t in state_steps
    ):
        raise RuntimeError(
            "API has changed, `state_steps` argument must contain a list of singleton tensors"
        )

    if foreach is None:
        _, foreach = _default_to_fused_or_foreach(
            params, differentiable, use_fused=False
        )

    if foreach and torch.jit.is_scripting():
        raise RuntimeError("torch.jit.script not supported with foreach optimizers")

    # NOTE(review): the repeated ``not torch.jit.is_scripting()`` looks deliberate,
    # presumably so TorchScript treats the foreach branch as dead code; confirm
    # before simplifying.
    if foreach and not torch.jit.is_scripting():
        func = _multi_tensor_rmsprop
    else:
        func = _single_tensor_rmsprop

    func(
        params,
        grads,
        square_avgs,
        grad_avgs,
        momentum_buffer_list,
        state_steps,
        lr=lr,
        alpha=alpha,
        eps=eps,
        weight_decay=weight_decay,
        momentum=momentum,
        centered=centered,
        maximize=maximize,
        capturable=capturable,
        differentiable=differentiable,
        has_complex=has_complex,
    )