adagrad.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574
  1. # mypy: allow-untyped-defs
  2. from typing import cast, Optional, Union
  3. import torch
  4. from torch import Tensor
  5. from .optimizer import (
  6. _default_to_fused_or_foreach,
  7. _device_dtype_check_for_fused,
  8. _differentiable_doc,
  9. _foreach_doc,
  10. _get_scalar_dtype,
  11. _get_value,
  12. _maximize_doc,
  13. _params_doc,
  14. _to_scalar,
  15. _use_grad_for_differentiable,
  16. _view_as_real,
  17. Optimizer,
  18. ParamsT,
  19. )
  # Public API of this module: the optimizer class and its functional counterpart.
  __all__ = [
      "Adagrad",
      "adagrad",
  ]
  class Adagrad(Optimizer):
      # The user-facing class docstring is attached after the class body through
      # ``Adagrad.__doc__`` so that it can interpolate shared doc fragments.

      def __init__(
          self,
          params: ParamsT,
          lr: Union[float, Tensor] = 1e-2,
          lr_decay: float = 0,
          weight_decay: float = 0,
          initial_accumulator_value: float = 0,
          eps: float = 1e-10,
          foreach: Optional[bool] = None,
          *,
          maximize: bool = False,
          differentiable: bool = False,
          fused: Optional[bool] = None,
      ):
          # A tensor learning rate must hold exactly one element.
          if isinstance(lr, Tensor) and lr.numel() != 1:
              raise ValueError("Tensor lr must be 1-element")

          # Every numeric hyperparameter must be non-negative. The `not 0.0 <= x`
          # spelling (instead of `x < 0.0`) also rejects NaN.
          if not 0.0 <= lr:
              raise ValueError(f"Invalid learning rate: {lr}")
          if not 0.0 <= lr_decay:
              raise ValueError(f"Invalid lr_decay value: {lr_decay}")
          if not 0.0 <= weight_decay:
              raise ValueError(f"Invalid weight_decay value: {weight_decay}")
          if not 0.0 <= initial_accumulator_value:
              raise ValueError(
                  f"Invalid initial_accumulator_value value: {initial_accumulator_value}"
              )
          if not 0.0 <= eps:
              raise ValueError(f"Invalid epsilon value: {eps}")

          hyperparams = {
              "lr": lr,
              "lr_decay": lr_decay,
              "eps": eps,
              "weight_decay": weight_decay,
              "initial_accumulator_value": initial_accumulator_value,
              "foreach": foreach,
              "maximize": maximize,
              "differentiable": differentiable,
              "fused": fused,
          }
          super().__init__(params, hyperparams)

          if fused:
              if differentiable:
                  raise RuntimeError("`fused` does not support `differentiable`")
              if foreach:
                  raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
              # Read and cleared by _init_group, which validates device/dtype support
              # for the fused kernel on the first parameter that has a gradient.
              self._need_device_dtype_check_for_fused = True

          # Eagerly create per-parameter state: a scalar step counter and the running
          # sum of squared gradients, seeded with initial_accumulator_value.
          for group in self.param_groups:
              group_is_fused = group["fused"]
              for p in group["params"]:
                  param_state = self.state[p]
                  if group_is_fused:
                      # Fused groups keep the step counter on the parameter's device.
                      param_state["step"] = torch.zeros(
                          (),
                          dtype=_get_scalar_dtype(is_fused=group_is_fused),
                          device=p.device,
                      )
                  else:
                      param_state["step"] = torch.tensor(0.0, dtype=_get_scalar_dtype())

                  # Complex params seed both the real and imaginary parts.
                  if torch.is_complex(p):
                      seed = complex(initial_accumulator_value, initial_accumulator_value)
                  else:
                      seed = initial_accumulator_value
                  param_state["sum"] = torch.full_like(
                      p, seed, memory_format=torch.preserve_format
                  )

      def __setstate__(self, state):
          super().__setstate__(state)
          # Fill in options that may be missing from previously saved state.
          # `fused` is pre-bound so static checkers do not report it as possibly
          # undefined; after the loop it holds the last param group's value.
          fused = None
          for group in self.param_groups:
              for option, default in (
                  ("foreach", None),
                  ("maximize", False),
                  ("differentiable", False),
              ):
                  group.setdefault(option, default)
              fused = group.setdefault("fused", None)

          # Convert non-tensor `step` entries to tensors. Only the first state entry
          # is inspected to decide whether any conversion is needed.
          state_values = list(self.state.values())
          if state_values and torch.is_tensor(state_values[0]["step"]):
              return
          for entry in state_values:
              entry["step"] = torch.tensor(
                  float(entry["step"]), dtype=_get_scalar_dtype(is_fused=fused)
              )

      def share_memory(self):
          """Move each parameter's ``sum`` accumulator to shared memory via ``share_memory_()``."""
          for group in self.param_groups:
              for p in group["params"]:
                  self.state[p]["sum"].share_memory_()

      def _init_group(self, group, params_with_grad, grads, state_sums, state_steps):
          """Collect the parameters of ``group`` that currently have a gradient.

          The four output lists are appended to in lockstep. Returns a
          ``(has_sparse_grad, has_complex)`` pair describing the collected params.
          """
          has_sparse_grad = False
          has_complex = False
          for p in group["params"]:
              grad = p.grad
              if grad is None:
                  continue
              # The fused device/dtype check runs once; the flag is cleared afterwards.
              if group["fused"] and getattr(
                  self,
                  "_need_device_dtype_check_for_fused",
                  True,
              ):
                  _device_dtype_check_for_fused(p, cuda_unsupported=True)
                  self._need_device_dtype_check_for_fused = False
              has_sparse_grad |= grad.is_sparse
              has_complex |= torch.is_complex(p)
              params_with_grad.append(p)
              grads.append(grad)
              param_state = self.state[p]
              state_sums.append(param_state["sum"])
              state_steps.append(param_state["step"])
          return has_sparse_grad, has_complex

      @_use_grad_for_differentiable
      def step(self, closure=None):
          """Perform a single optimization step.

          Args:
              closure (Callable, optional): A closure that reevaluates the model
                  and returns the loss.
          """
          loss = None
          if closure is not None:
              with torch.enable_grad():
                  loss = closure()

          for group in self.param_groups:
              params_with_grad: list[Tensor] = []
              grads: list[Tensor] = []
              state_sums: list[Tensor] = []
              state_steps: list[Tensor] = []
              has_sparse_grad, has_complex = self._init_group(
                  group, params_with_grad, grads, state_sums, state_steps
              )
              # Delegate the actual update to the functional implementation.
              adagrad(
                  params_with_grad,
                  grads,
                  state_sums,
                  state_steps,
                  lr=group["lr"],
                  weight_decay=group["weight_decay"],
                  lr_decay=group["lr_decay"],
                  eps=group["eps"],
                  has_sparse_grad=has_sparse_grad,
                  foreach=group["foreach"],
                  maximize=group["maximize"],
                  differentiable=group["differentiable"],
                  has_complex=has_complex,
                  fused=group["fused"],
                  grad_scale=getattr(self, "grad_scale", None),
                  found_inf=getattr(self, "found_inf", None),
              )
          return loss
  # Assemble the public class docstring at import time. The algorithm description
  # is a plain raw string; the argument list is an rf-string that interpolates the
  # shared parameter docs imported from .optimizer.
  Adagrad.__doc__ = (
      r"""Implements Adagrad algorithm.

      .. math::
         \begin{aligned}
              &\rule{110mm}{0.4pt}                                                                 \\
              &\textbf{input}      : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
                  \text{ (objective)}, \: \lambda \text{ (weight decay)},                          \\
              &\hspace{12mm}    \tau \text{ (initial accumulator value)}, \: \eta\text{ (lr decay)},
                  \: \textit{maximize}                                                             \\
              &\textbf{initialize} :  state\_sum_0 \leftarrow \tau                          \\[-1.ex]
              &\rule{110mm}{0.4pt}                                                                 \\
              &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do}                         \\
              &\hspace{5mm}\textbf{if} \: \textit{maximize}:                                       \\
              &\hspace{10mm}g_t           \leftarrow   -\nabla_{\theta} f_t (\theta_{t-1})         \\
              &\hspace{5mm}\textbf{else}                                                           \\
              &\hspace{10mm}g_t           \leftarrow   \nabla_{\theta} f_t (\theta_{t-1})          \\
              &\hspace{5mm} \tilde{\gamma}    \leftarrow \gamma / (1 +(t-1) \eta)                  \\
              &\hspace{5mm} \textbf{if} \: \lambda \neq 0                                          \\
              &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1}                             \\
              &\hspace{5mm}state\_sum_t  \leftarrow  state\_sum_{t-1} + g^2_t                      \\
              &\hspace{5mm}\theta_t \leftarrow
                  \theta_{t-1}- \tilde{\gamma} \frac{g_t}{\sqrt{state\_sum_t}+\epsilon}            \\
              &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
              &\bf{return} \:  \theta_t                                                     \\[-1.ex]
              &\rule{110mm}{0.4pt}                                                          \\[-1.ex]
         \end{aligned}

      For further details regarding the algorithm we refer to `Adaptive Subgradient Methods for Online Learning
      and Stochastic Optimization`_.
      """
      + rf"""
      Args:
          {_params_doc}
          lr (float, Tensor, optional): learning rate (default: 1e-2)
          lr_decay (float, optional): learning rate decay (default: 0)
          weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
          initial_accumulator_value (float, optional): initial value of the
              sum of squares of gradients (default: 0)
          eps (float, optional): term added to the denominator to improve
              numerical stability (default: 1e-10)
          {_foreach_doc}
          {_maximize_doc}
          {_differentiable_doc}
          fused (bool, optional): whether the fused implementation (CPU only) is used.
              Currently, `torch.float64`, `torch.float32`, `torch.float16`, and `torch.bfloat16`
              are supported. (default: None). Please note that the fused implementation does not
              support sparse or complex gradients.

      .. _Adaptive Subgradient Methods for Online Learning and Stochastic
          Optimization: http://jmlr.org/papers/v12/duchi11a.html

      """
  )
  def adagrad(
      params: list[Tensor],
      grads: list[Tensor],
      state_sums: list[Tensor],
      state_steps: list[Tensor],
      fused: Optional[bool] = None,
      grad_scale: Optional[Tensor] = None,
      found_inf: Optional[Tensor] = None,
      # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
      # setting these as kwargs for now as functional API is compiled by torch/distributed/optim
      has_sparse_grad: bool = False,
      foreach: Optional[bool] = None,
      differentiable: bool = False,
      has_complex: bool = False,
      *,
      lr: float,
      weight_decay: float,
      lr_decay: float,
      eps: float,
      maximize: bool,
  ):
      r"""Functional API that performs Adagrad algorithm computation.

      Chooses between the fused, foreach (multi-tensor) and single-tensor
      implementations and forwards all arguments to it.
      See :class:`~torch.optim.Adagrad` for details.
      """
      # Step counters must already be tensors (one per parameter).
      if not all(isinstance(t, torch.Tensor) for t in state_steps):
          raise RuntimeError(
              "API has changed, `state_steps` argument must contain a list of singleton tensors"
          )

      # Only pick a default implementation when the caller left both `fused` and
      # `foreach` unspecified. use_fused=False is intentional: the fused kernel
      # stays opt-in for bake-in time, even though it is typically faster.
      if fused is None and foreach is None:
          _, foreach = _default_to_fused_or_foreach(
              params, differentiable, use_fused=False
          )
      if fused is None:
          fused = False
      if foreach is None:
          foreach = False

      # Under TorchScript only the single-tensor implementation is usable.
      if torch.jit.is_scripting():
          if foreach:
              raise RuntimeError("torch.jit.script not supported with foreach optimizers")
          if fused:
              raise RuntimeError("torch.jit.script not supported with fused optimizers")

      # The explicit `not torch.jit.is_scripting()` guards keep the non-scriptable
      # implementations out of the TorchScript-compiled path.
      if fused and not torch.jit.is_scripting():
          update_fn = _fused_adagrad
      elif foreach and not torch.jit.is_scripting():
          update_fn = _multi_tensor_adagrad
      else:
          update_fn = _single_tensor_adagrad

      update_fn(
          params,
          grads,
          state_sums,
          state_steps,
          grad_scale=grad_scale,
          found_inf=found_inf,
          lr=lr,
          weight_decay=weight_decay,
          lr_decay=lr_decay,
          eps=eps,
          has_sparse_grad=has_sparse_grad,
          maximize=maximize,
          differentiable=differentiable,
          has_complex=has_complex,
      )
  282. def _make_sparse(grad, grad_indices, values):
  283. size = grad.size()
  284. return torch.sparse_coo_tensor(grad_indices, values, size)
  def _single_tensor_adagrad(
      params: list[Tensor],
      grads: list[Tensor],
      state_sums: list[Tensor],
      state_steps: list[Tensor],
      grad_scale: Optional[Tensor],
      found_inf: Optional[Tensor],
      *,
      lr: float,
      weight_decay: float,
      lr_decay: float,
      eps: float,
      has_sparse_grad: bool,
      maximize: bool,
      differentiable: bool,
      has_complex: bool,
  ):
      """Reference Adagrad update that processes one parameter at a time.

      Handles sparse and dense gradients; complex parameters are updated through
      real views. Parameters, accumulators (``state_sums``) and step counters
      (``state_steps``) are all modified in place.

      Raises:
          RuntimeError: if ``weight_decay`` is non-zero and a gradient is sparse.
      """
      # grad_scale / found_inf are only consumed by the fused implementation.
      assert grad_scale is None and found_inf is None

      if not torch.jit.is_scripting():
          lr = _to_scalar(lr)

      for param, grad, state_sum, step_t in zip(params, grads, state_sums, state_steps):
          # update step
          step_t += 1
          step = _get_value(step_t)

          grad = grad if not maximize else -grad

          if weight_decay != 0:
              if grad.is_sparse:
                  raise RuntimeError(
                      "weight_decay option is not compatible with sparse gradients"
                  )
              grad = grad.add(param, alpha=weight_decay)

          # Decayed learning rate: lr / (1 + (t - 1) * lr_decay).
          clr = lr / (1 + (step - 1) * lr_decay)

          if grad.is_sparse:
              grad = grad.coalesce()  # the update is non-linear so indices must be unique
              grad_indices = grad._indices()
              grad_values = grad._values()
              state_sum.add_(_make_sparse(grad, grad_indices, grad_values.pow(2)))
              # Only the accumulator entries touched by this gradient are read back.
              std = state_sum.sparse_mask(grad)
              std_values = std._values().sqrt_().add_(eps)
              param.add_(
                  _make_sparse(grad, grad_indices, grad_values / std_values), alpha=-clr
              )
          else:
              if torch.is_complex(param):
                  # In-place ops on these real views write through to the complex
                  # tensors, so nothing has to be converted back afterwards.
                  grad = torch.view_as_real(grad)
                  state_sum = torch.view_as_real(state_sum)
                  param = torch.view_as_real(param)
              state_sum.addcmul_(grad, grad, value=1)
              if differentiable:
                  # Out-of-place: sqrt() output must not be modified in place
                  # while autograd is tracking it.
                  std = state_sum.sqrt() + eps
              else:
                  std = state_sum.sqrt().add_(eps)
              param.addcdiv_(grad, std, value=-clr)
  def _multi_tensor_adagrad(
      params: list[Tensor],
      grads: list[Tensor],
      state_sums: list[Tensor],
      state_steps: list[Tensor],
      grad_scale: Optional[Tensor],
      found_inf: Optional[Tensor],
      *,
      lr: float,
      weight_decay: float,
      lr_decay: float,
      eps: float,
      has_sparse_grad: bool,
      maximize: bool,
      differentiable: bool,
      has_complex: bool,
  ):
      """Adagrad update built on ``torch._foreach_*`` kernels.

      Tensors are bucketed by device and dtype, and each bucket is updated with a
      few foreach calls. Buckets containing a sparse gradient are handed to
      ``_single_tensor_adagrad`` instead. Parameters, accumulators and step
      counters are modified in place.
      """
      assert not differentiable, "_foreach ops don't support autograd"
      assert grad_scale is None and found_inf is None

      # Foreach functions will throw errors if given empty lists
      if not params:
          return

      lr = _to_scalar(lr)

      buckets = Optimizer._group_tensors_by_device_and_dtype(
          [params, grads, state_sums, state_steps]  # type: ignore[list-item]
      )
      for tensorlists, _ in buckets.values():
          (
              bucket_params_,
              bucket_grads_,
              bucket_sums_,
              bucket_steps_,
          ) = tensorlists
          bucket_params = cast(list[Tensor], bucket_params_)
          bucket_grads = cast(list[Tensor], bucket_grads_)
          bucket_sums = cast(list[Tensor], bucket_sums_)
          bucket_steps = cast(list[Tensor], bucket_steps_)

          # Sparse gradients go through the per-tensor path, which has a sparse branch.
          if has_sparse_grad and any(g.is_sparse for g in bucket_grads):
              _single_tensor_adagrad(
                  bucket_params,
                  bucket_grads,
                  bucket_sums,
                  bucket_steps,
                  lr=lr,
                  weight_decay=weight_decay,
                  lr_decay=lr_decay,
                  eps=eps,
                  has_sparse_grad=True,
                  maximize=maximize,
                  differentiable=differentiable,
                  has_complex=has_complex,
                  grad_scale=grad_scale,
                  found_inf=found_inf,
              )
              continue

          # Handle complex parameters (via _view_as_real).
          if has_complex:
              _view_as_real(bucket_params, bucket_grads, bucket_sums)

          # Becomes True once `bucket_grads` holds freshly allocated tensors that can
          # be overwritten in place without touching the caller's gradients.
          grads_are_scratch = False
          if maximize:
              bucket_grads = torch._foreach_neg(bucket_grads)  # type: ignore[assignment]
              grads_are_scratch = True

          # Increment step counters. For CPU steps, wrap the increment in a tensor
          # once: the foreach slow path loops over t.add(1) and would otherwise
          # re-wrap the scalar for every tensor. alpha selects the right overload.
          if not torch.compiler.is_compiling() and bucket_steps[0].is_cpu:
              torch._foreach_add_(
                  bucket_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
              )
          else:
              torch._foreach_add_(bucket_steps, 1)

          if weight_decay != 0:
              if grads_are_scratch:
                  torch._foreach_add_(bucket_grads, bucket_params, alpha=weight_decay)
              else:
                  bucket_grads = torch._foreach_add(  # type: ignore[assignment]
                      bucket_grads, bucket_params, alpha=weight_decay
                  )
                  grads_are_scratch = True

          # Negated, decayed learning rate per tensor: -lr / (1 + (t - 1) * lr_decay).
          neg_step_sizes = [
              -lr / (1 + (_get_value(step) - 1) * lr_decay) for step in bucket_steps
          ]

          torch._foreach_addcmul_(bucket_sums, bucket_grads, bucket_grads, value=1)
          std = torch._foreach_sqrt(bucket_sums)
          torch._foreach_add_(std, eps)

          if grads_are_scratch:
              torch._foreach_mul_(bucket_grads, neg_step_sizes)
              numerator = bucket_grads
          else:
              numerator = torch._foreach_mul(bucket_grads, neg_step_sizes)  # type: ignore[assignment]
          torch._foreach_addcdiv_(bucket_params, numerator, std)
  def _fused_adagrad(
      params: list[Tensor],
      grads: list[Tensor],
      state_sums: list[Tensor],
      state_steps: list[Tensor],
      grad_scale: Optional[Tensor],
      found_inf: Optional[Tensor],
      *,
      lr: float,
      weight_decay: float,
      lr_decay: float,
      eps: float,
      has_sparse_grad: bool,
      maximize: bool,
      differentiable: bool,
      has_complex: bool,
  ) -> None:
      """Adagrad update using the fused ``torch._fused_adagrad_`` kernel.

      Tensors are bucketed by device and dtype. ``grad_scale`` / ``found_inf``
      (when given) are copied to each bucket's device at most once per device and
      passed to the kernel; afterwards the step counters are decremented by
      ``found_inf``.

      Raises:
          RuntimeError: if ``has_sparse_grad`` or ``has_complex`` is set, or if
              ``differentiable`` is True.
      """
      if not params:
          return
      if has_sparse_grad or has_complex:
          raise RuntimeError("`fused` does not support sparse grad or complex param")
      if differentiable:
          raise RuntimeError(
              "adagrad with fused=True does not support differentiable=True"
          )

      lr = _to_scalar(lr)

      # Per-device caches of grad_scale / found_inf, seeded with the originals.
      grad_scale_dict = (
          {grad_scale.device: grad_scale} if grad_scale is not None else None
      )
      found_inf_dict = {found_inf.device: found_inf} if found_inf is not None else None

      grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
          [params, grads, state_sums, state_steps]  # type: ignore[list-item]
      )
      for (device, _), (
          (
              device_params_,
              device_grads_,
              device_state_sums_,
              device_state_steps_,
          ),
          _,
      ) in grouped_tensors.items():
          device_params = cast(list[Tensor], device_params_)
          device_grads = cast(list[Tensor], device_grads_)
          device_state_sums = cast(list[Tensor], device_state_sums_)
          device_state_steps = cast(list[Tensor], device_state_steps_)

          device_grad_scale, device_found_inf = None, None
          if grad_scale is not None and grad_scale_dict is not None:
              if device not in grad_scale_dict:
                  grad_scale_dict[device] = grad_scale.to(device, non_blocking=True)  # type: ignore[index]
              device_grad_scale = grad_scale_dict[device]  # type: ignore[index]
          if found_inf is not None and found_inf_dict is not None:
              # The cache is keyed by device, exactly like grad_scale_dict above, so
              # several dtype buckets on one device share a single transfer.
              if device not in found_inf_dict:
                  found_inf_dict[device] = found_inf.to(device, non_blocking=True)  # type: ignore[index]
              device_found_inf = found_inf_dict[device]  # type: ignore[index]

          torch._foreach_add_(device_state_steps, 1)
          torch._fused_adagrad_(
              device_params,
              device_grads,
              device_state_sums,
              device_state_steps,
              lr=lr,
              lr_decay=lr_decay,
              weight_decay=weight_decay,
              eps=eps,
              maximize=maximize,
              grad_scale=device_grad_scale,
              found_inf=device_found_inf,
          )
          if device_found_inf is not None:
              # NOTE(review): presumably found_inf is 1 when the step was skipped
              # (inf/NaN grads) and 0 otherwise, so this rolls back only skipped steps.
              torch._foreach_sub_(
                  device_state_steps, [device_found_inf] * len(device_state_steps)
              )