asgd.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474
  1. # mypy: allow-untyped-defs
  2. from typing import cast, Optional, Union
  3. import torch
  4. from torch import Tensor
  5. from .optimizer import (
  6. _capturable_doc,
  7. _default_to_fused_or_foreach,
  8. _differentiable_doc,
  9. _disable_dynamo_if_unsupported,
  10. _foreach_doc,
  11. _get_capturable_supported_devices,
  12. _get_scalar_dtype,
  13. _get_value,
  14. _maximize_doc,
  15. _params_doc,
  16. _to_scalar,
  17. _use_grad_for_differentiable,
  18. _view_as_real,
  19. Optimizer,
  20. ParamsT,
  21. )
  22. __all__ = ["ASGD", "asgd"]
  23. class ASGD(Optimizer):
  24. def __init__(
  25. self,
  26. params: ParamsT,
  27. lr: Union[float, Tensor] = 1e-2,
  28. lambd: float = 1e-4,
  29. alpha: float = 0.75,
  30. t0: float = 1e6,
  31. weight_decay: float = 0,
  32. foreach: Optional[bool] = None,
  33. maximize: bool = False,
  34. differentiable: bool = False,
  35. capturable: bool = False,
  36. ):
  37. if isinstance(lr, Tensor) and lr.numel() != 1:
  38. raise ValueError("Tensor lr must be 1-element")
  39. if not 0.0 <= lr:
  40. raise ValueError(f"Invalid learning rate: {lr}")
  41. if not 0.0 <= weight_decay:
  42. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  43. defaults = {
  44. "lr": lr,
  45. "lambd": lambd,
  46. "alpha": alpha,
  47. "t0": t0,
  48. "weight_decay": weight_decay,
  49. "foreach": foreach,
  50. "maximize": maximize,
  51. "differentiable": differentiable,
  52. "capturable": capturable,
  53. }
  54. super().__init__(params, defaults)
  55. def __setstate__(self, state):
  56. super().__setstate__(state)
  57. for group in self.param_groups:
  58. group.setdefault("foreach", None)
  59. group.setdefault("maximize", False)
  60. group.setdefault("differentiable", False)
  61. group.setdefault("capturable", False)
  62. for p in group["params"]:
  63. p_state = self.state.get(p, [])
  64. if len(p_state) != 0:
  65. if not torch.is_tensor(p_state["step"]):
  66. step_val = float(p_state["step"])
  67. p_state["step"] = torch.tensor(
  68. step_val, dtype=_get_scalar_dtype(), device=p.device
  69. )
  70. if not torch.is_tensor(p_state["eta"]):
  71. p_state["eta"] = torch.tensor(
  72. p_state["eta"], dtype=_get_scalar_dtype(), device=p.device
  73. )
  74. if not torch.is_tensor(p_state["mu"]):
  75. p_state["mu"] = torch.tensor(
  76. p_state["mu"], dtype=_get_scalar_dtype(), device=p.device
  77. )
  78. def _init_group(self, group, params_with_grad, grads, mus, axs, etas, state_steps):
  79. has_complex = False
  80. for p in group["params"]:
  81. if p.grad is not None:
  82. has_complex |= torch.is_complex(p)
  83. params_with_grad.append(p)
  84. if p.grad.is_sparse:
  85. raise RuntimeError("ASGD does not support sparse gradients")
  86. grads.append(p.grad)
  87. state = self.state[p]
  88. # State initialization
  89. if len(state) == 0:
  90. state["step"] = torch.zeros(
  91. (), device=p.device, dtype=_get_scalar_dtype()
  92. )
  93. state["eta"] = (
  94. torch.as_tensor(
  95. _to_scalar(group["lr"]),
  96. device=p.device,
  97. dtype=_get_scalar_dtype(),
  98. )
  99. .clone()
  100. .detach()
  101. )
  102. state["mu"] = torch.ones(
  103. (), device=p.device, dtype=_get_scalar_dtype()
  104. )
  105. state["ax"] = torch.zeros_like(
  106. p, memory_format=torch.preserve_format
  107. )
  108. mus.append(state["mu"])
  109. axs.append(state["ax"])
  110. etas.append(state["eta"])
  111. state_steps.append(state["step"])
  112. return has_complex
  113. @_use_grad_for_differentiable
  114. def step(self, closure=None):
  115. """Perform a single optimization step.
  116. Args:
  117. closure (Callable, optional): A closure that reevaluates the model
  118. and returns the loss.
  119. """
  120. self._cuda_graph_capture_health_check()
  121. loss = None
  122. if closure is not None:
  123. with torch.enable_grad():
  124. loss = closure()
  125. for group in self.param_groups:
  126. params_with_grad: list[Tensor] = []
  127. grads: list[Tensor] = []
  128. mus: list[Tensor] = []
  129. axs: list[Tensor] = []
  130. etas: list[Tensor] = []
  131. state_steps: list[Tensor] = []
  132. has_complex = self._init_group(
  133. group, params_with_grad, grads, mus, axs, etas, state_steps
  134. )
  135. asgd(
  136. params_with_grad,
  137. grads,
  138. axs,
  139. mus,
  140. etas,
  141. state_steps,
  142. lambd=group["lambd"],
  143. lr=group["lr"],
  144. t0=group["t0"],
  145. alpha=group["alpha"],
  146. weight_decay=group["weight_decay"],
  147. foreach=group["foreach"],
  148. maximize=group["maximize"],
  149. differentiable=group["differentiable"],
  150. capturable=group["capturable"],
  151. has_complex=has_complex,
  152. )
  153. return loss
  154. ASGD.__doc__ = rf"""Implements Averaged Stochastic Gradient Descent.
  155. It has been proposed in `Acceleration of stochastic approximation by
  156. averaging`_.
  157. Args:
  158. {_params_doc}
  159. lr (float, Tensor, optional): learning rate (default: 1e-2)
  160. lambd (float, optional): decay term (default: 1e-4)
  161. alpha (float, optional): power for eta update (default: 0.75)
  162. t0 (float, optional): point at which to start averaging (default: 1e6)
  163. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  164. {_foreach_doc}
  165. {_maximize_doc}
  166. {_differentiable_doc}
  167. {_capturable_doc}
  168. .. _Acceleration of stochastic approximation by averaging:
  169. https://meyn.ece.ufl.edu/wp-content/uploads/sites/77/archive/spm_files/Courses/ECE555-2011/555media/poljud92.pdf
  170. """
  171. def _single_tensor_asgd(
  172. params: list[Tensor],
  173. grads: list[Tensor],
  174. axs: list[Tensor],
  175. mus: list[Tensor],
  176. etas: list[Tensor],
  177. state_steps: list[Tensor],
  178. *,
  179. lambd: float,
  180. lr: float,
  181. t0: float,
  182. alpha: float,
  183. weight_decay: float,
  184. maximize: bool,
  185. differentiable: bool,
  186. capturable: bool,
  187. has_complex: bool,
  188. ):
  189. if not torch.jit.is_scripting():
  190. lr = _to_scalar(lr)
  191. for i, param in enumerate(params):
  192. grad = grads[i]
  193. grad = grad if not maximize else -grad
  194. mu = mus[i]
  195. ax = axs[i]
  196. eta = etas[i]
  197. step_t = state_steps[i]
  198. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  199. if not torch.compiler.is_compiling() and capturable:
  200. capturable_supported_devices = _get_capturable_supported_devices()
  201. assert (
  202. param.device.type
  203. == mu.device.type
  204. == eta.device.type
  205. == step_t.device.type
  206. and param.device.type in capturable_supported_devices
  207. ), (
  208. f"If capturable=True, params, mus, etas, and state_steps must be "
  209. f"on supported devices: {capturable_supported_devices}."
  210. )
  211. if torch.is_complex(param):
  212. grad = torch.view_as_real(grad)
  213. param = torch.view_as_real(param)
  214. ax = torch.view_as_real(ax)
  215. # update step
  216. step_t += 1
  217. if weight_decay != 0:
  218. grad = grad.add(param, alpha=weight_decay)
  219. if capturable:
  220. param.mul_(1 - lambd * eta)
  221. param.addcmul_(grad, eta, value=-1) # update parameter
  222. else:
  223. eta_value = _get_value(eta)
  224. param.mul_(1 - lambd * eta_value) # decay term
  225. param.add_(grad, alpha=-eta_value) # update parameter
  226. # averaging
  227. if capturable or mu.item() != 1:
  228. ax.add_(param.sub(ax).mul_(mu))
  229. else:
  230. ax.copy_(param)
  231. if capturable:
  232. eta.copy_(lr / ((1 + lambd * lr * step_t) ** alpha))
  233. mu.copy_(1 / torch.maximum(step_t - t0, torch.ones_like(step_t)))
  234. else:
  235. step = _get_value(step_t)
  236. new_eta = torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha))
  237. eta.copy_(new_eta)
  238. new_mu = torch.as_tensor(1 / max(1, step - t0))
  239. mu.copy_(new_mu)
  240. def _multi_tensor_asgd(
  241. params: list[Tensor],
  242. grads: list[Tensor],
  243. axs: list[Tensor],
  244. mus: list[Tensor],
  245. etas: list[Tensor],
  246. state_steps: list[Tensor],
  247. *,
  248. lambd: float,
  249. lr: float,
  250. t0: float,
  251. alpha: float,
  252. weight_decay: float,
  253. maximize: bool,
  254. differentiable: bool,
  255. capturable: bool,
  256. has_complex: bool,
  257. ):
  258. if len(params) == 0:
  259. return
  260. assert not differentiable, "_foreach ops don't support autograd"
  261. # If compiling, the compiler will handle cudagraph checks, see note [torch.compile x capturable]
  262. if not torch.compiler.is_compiling() and capturable:
  263. capturable_supported_devices = _get_capturable_supported_devices(
  264. supports_xla=False
  265. )
  266. assert all(
  267. p.device.type == mu.device.type == eta.device.type == step.device.type
  268. and p.device.type in capturable_supported_devices
  269. for p, mu, eta, step in zip(params, mus, etas, state_steps)
  270. ), (
  271. f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."
  272. )
  273. lr = _to_scalar(lr)
  274. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  275. [params, grads, axs, mus, etas, state_steps] # type: ignore[list-item]
  276. )
  277. for (device, _), (
  278. (
  279. grouped_params_,
  280. grouped_grads_,
  281. grouped_axs_,
  282. grouped_mus_,
  283. grouped_etas_,
  284. grouped_state_steps_,
  285. ),
  286. _,
  287. ) in grouped_tensors.items():
  288. grouped_params = cast(list[Tensor], grouped_params_)
  289. grouped_grads = cast(list[Tensor], grouped_grads_)
  290. grouped_axs = cast(list[Tensor], grouped_axs_)
  291. grouped_mus = cast(list[Tensor], grouped_mus_)
  292. grouped_etas = cast(list[Tensor], grouped_etas_)
  293. grouped_state_steps = cast(list[Tensor], grouped_state_steps_)
  294. if has_complex:
  295. _view_as_real(grouped_params, grouped_grads, grouped_axs)
  296. if maximize:
  297. grouped_grads = torch._foreach_neg(grouped_grads) # type: ignore[assignment]
  298. # Update steps
  299. # If steps are on CPU, foreach will fall back to the slow path, which is a for-loop calling t.add(1) over
  300. # and over. 1 will then be wrapped into a Tensor over and over again, which is slower than if we just
  301. # wrapped it once now. The alpha is required to assure we go to the right overload.
  302. if not torch.compiler.is_compiling() and grouped_state_steps[0].is_cpu:
  303. torch._foreach_add_(
  304. grouped_state_steps, torch.tensor(1.0, device="cpu"), alpha=1.0
  305. )
  306. else:
  307. torch._foreach_add_(grouped_state_steps, 1)
  308. # intermediate = grad + param * lambd
  309. intermediate: Union[tuple[Tensor, ...], list[Tensor]]
  310. if weight_decay != 0:
  311. if maximize:
  312. torch._foreach_add_(grouped_grads, grouped_params, alpha=weight_decay)
  313. intermediate = grouped_grads
  314. else:
  315. intermediate = torch._foreach_add(
  316. grouped_grads, grouped_params, alpha=weight_decay
  317. )
  318. torch._foreach_add_(intermediate, grouped_params, alpha=lambd)
  319. else:
  320. intermediate = torch._foreach_add(
  321. grouped_grads, grouped_params, alpha=lambd
  322. )
  323. # update param
  324. # param * (1 - lambd * eta) - eta * grad
  325. # => param - param * lambd * eta - eta * grad
  326. # => param - eta * intermediate
  327. torch._foreach_addcmul_(grouped_params, intermediate, grouped_etas, value=-1)
  328. del intermediate
  329. # update grouped_axs
  330. # averaging: ax = ax + mu * (param - ax)
  331. # Note (mlazos): We can't use lerp here since it requires weight to be float64
  332. # and our grouping code requires dtypes to match for all tensors in a group (and it should, since
  333. # we use the mus in other places)
  334. # all dtypes need to match, so we could introduce a cast in a loop
  335. # but since this only adds one additional kernel launch, this looks like the cleaner
  336. # and faster solution
  337. intermediate = torch._foreach_sub(grouped_params, grouped_axs)
  338. torch._foreach_addcmul_(grouped_axs, intermediate, grouped_mus)
  339. del intermediate
  340. new_etas: Union[tuple[Tensor, ...], list[Tensor]]
  341. new_mus: Union[tuple[Tensor, ...], list[Tensor]]
  342. if capturable:
  343. # update grouped_mus
  344. new_mus = torch._foreach_sub(grouped_state_steps, t0)
  345. torch._foreach_maximum_(new_mus, 1.0)
  346. torch._foreach_reciprocal_(new_mus)
  347. torch._foreach_copy_(grouped_mus, new_mus)
  348. del new_mus
  349. # update eta = lr / ((1 + lambd * lr * step)^alpha)
  350. new_etas = torch._foreach_mul(grouped_state_steps, lambd)
  351. torch._foreach_mul_(new_etas, lr)
  352. torch._foreach_add_(new_etas, 1)
  353. torch._foreach_pow_(new_etas, alpha)
  354. torch._foreach_reciprocal_(new_etas)
  355. torch._foreach_mul_(new_etas, lr)
  356. torch._foreach_copy_(grouped_etas, new_etas)
  357. else:
  358. new_etas = [
  359. torch.as_tensor(lr / ((1 + lambd * lr * step) ** alpha), device=device)
  360. for step in grouped_state_steps
  361. ]
  362. new_mus = [
  363. torch.as_tensor(1 / max(1, _get_value(step) - t0), device=device)
  364. for step in grouped_state_steps
  365. ]
  366. torch._foreach_copy_(grouped_etas, new_etas)
  367. torch._foreach_copy_(grouped_mus, new_mus)
  368. @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_asgd)
  369. def asgd(
  370. params: list[Tensor],
  371. grads: list[Tensor],
  372. axs: list[Tensor],
  373. mus: list[Tensor],
  374. etas: list[Tensor],
  375. state_steps: list[Tensor],
  376. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  377. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  378. foreach: Optional[bool] = None,
  379. maximize: bool = False,
  380. differentiable: bool = False,
  381. capturable: bool = False,
  382. has_complex: bool = False,
  383. *,
  384. lambd: float,
  385. lr: float,
  386. t0: float,
  387. alpha: float,
  388. weight_decay: float,
  389. ):
  390. r"""Functional API that performs asgd algorithm computation.
  391. See :class:`~torch.optim.ASGD` for details.
  392. """
  393. if foreach is None:
  394. _, foreach = _default_to_fused_or_foreach(
  395. params, differentiable, use_fused=False
  396. )
  397. if foreach and torch.jit.is_scripting():
  398. raise RuntimeError("torch.jit.script not supported with foreach optimizers")
  399. if foreach and not torch.jit.is_scripting():
  400. func = _multi_tensor_asgd
  401. else:
  402. func = _single_tensor_asgd
  403. func(
  404. params,
  405. grads,
  406. axs,
  407. mus,
  408. etas,
  409. state_steps,
  410. lambd=lambd,
  411. lr=lr,
  412. t0=t0,
  413. alpha=alpha,
  414. weight_decay=weight_decay,
  415. maximize=maximize,
  416. differentiable=differentiable,
  417. capturable=capturable,
  418. has_complex=has_complex,
  419. )