sgd.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538
  1. # mypy: allow-untyped-defs
  2. r"""Implementation for Stochastic Gradient Descent optimizer."""
  3. from typing import cast, Optional, Union
  4. import torch
  5. from torch import Tensor
  6. from .optimizer import (
  7. _default_to_fused_or_foreach,
  8. _device_dtype_check_for_fused,
  9. _differentiable_doc,
  10. _foreach_doc,
  11. _fused_doc,
  12. _maximize_doc,
  13. _params_doc,
  14. _to_scalar,
  15. _use_grad_for_differentiable,
  16. DeviceDict,
  17. Optimizer,
  18. ParamsT,
  19. )
  20. __all__ = ["SGD", "sgd"]
  21. class SGD(Optimizer): # noqa: D101
  22. def __init__(
  23. self,
  24. params: ParamsT,
  25. lr: Union[float, Tensor] = 1e-3,
  26. momentum: float = 0,
  27. dampening: float = 0,
  28. weight_decay: Union[float, Tensor] = 0,
  29. nesterov: bool = False,
  30. *,
  31. maximize: bool = False,
  32. foreach: Optional[bool] = None,
  33. differentiable: bool = False,
  34. fused: Optional[bool] = None,
  35. ): # noqa: D107
  36. if isinstance(lr, Tensor) and lr.numel() != 1:
  37. raise ValueError("Tensor lr must be 1-element")
  38. if lr < 0.0:
  39. raise ValueError(f"Invalid learning rate: {lr}")
  40. if momentum < 0.0:
  41. raise ValueError(f"Invalid momentum value: {momentum}")
  42. if weight_decay < 0.0:
  43. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  44. defaults = {
  45. "lr": lr,
  46. "momentum": momentum,
  47. "dampening": dampening,
  48. "weight_decay": weight_decay,
  49. "nesterov": nesterov,
  50. "maximize": maximize,
  51. "foreach": foreach,
  52. "differentiable": differentiable,
  53. "fused": fused,
  54. }
  55. if nesterov and (momentum <= 0 or dampening != 0):
  56. raise ValueError("Nesterov momentum requires a momentum and zero dampening")
  57. super().__init__(params, defaults)
  58. if fused:
  59. self._step_supports_amp_scaling = True
  60. self._need_device_dtype_check_for_fused = True
  61. if differentiable:
  62. raise RuntimeError("`fused` does not support `differentiable`")
  63. if foreach:
  64. raise RuntimeError("`fused` and `foreach` cannot be `True` together.")
  65. def __setstate__(self, state): # noqa: D105
  66. super().__setstate__(state)
  67. for group in self.param_groups:
  68. group.setdefault("nesterov", False)
  69. group.setdefault("maximize", False)
  70. group.setdefault("foreach", None)
  71. group.setdefault("differentiable", False)
  72. group.setdefault("fused", False)
  73. def _init_group(self, group, params, grads, momentum_buffer_list):
  74. has_sparse_grad = False
  75. for p in group["params"]:
  76. if p.grad is not None:
  77. if group["fused"] and getattr(
  78. self, "_need_device_dtype_check_for_fused", True
  79. ):
  80. _device_dtype_check_for_fused(p)
  81. self._need_device_dtype_check_for_fused = False
  82. params.append(p)
  83. grads.append(p.grad)
  84. if p.grad.is_sparse:
  85. has_sparse_grad = True
  86. if group["momentum"] != 0:
  87. state = self.state[p]
  88. momentum_buffer_list.append(state.get("momentum_buffer"))
  89. return has_sparse_grad
  90. @_use_grad_for_differentiable
  91. def step(self, closure=None):
  92. """Perform a single optimization step.
  93. Args:
  94. closure (Callable, optional): A closure that reevaluates the model
  95. and returns the loss.
  96. """
  97. loss = None
  98. if closure is not None:
  99. with torch.enable_grad():
  100. loss = closure()
  101. for group in self.param_groups:
  102. params: list[Tensor] = []
  103. grads: list[Tensor] = []
  104. momentum_buffer_list: list[Optional[Tensor]] = []
  105. has_sparse_grad = self._init_group(
  106. group, params, grads, momentum_buffer_list
  107. )
  108. sgd(
  109. params,
  110. grads,
  111. momentum_buffer_list,
  112. weight_decay=group["weight_decay"],
  113. momentum=group["momentum"],
  114. lr=group["lr"],
  115. dampening=group["dampening"],
  116. nesterov=group["nesterov"],
  117. maximize=group["maximize"],
  118. has_sparse_grad=has_sparse_grad,
  119. foreach=group["foreach"],
  120. fused=group["fused"],
  121. grad_scale=getattr(self, "grad_scale", None),
  122. found_inf=getattr(self, "found_inf", None),
  123. )
  124. if group["momentum"] != 0:
  125. # update momentum_buffers in state
  126. for p, momentum_buffer in zip(params, momentum_buffer_list):
  127. state = self.state[p]
  128. state["momentum_buffer"] = momentum_buffer
  129. return loss
  130. SGD.__doc__ = (
  131. r"""Implements stochastic gradient descent (optionally with momentum).
  132. .. math::
  133. \begin{aligned}
  134. &\rule{110mm}{0.4pt} \\
  135. &\textbf{input} : \gamma \text{ (lr)}, \: \theta_0 \text{ (params)}, \: f(\theta)
  136. \text{ (objective)}, \: \lambda \text{ (weight decay)}, \\
  137. &\hspace{13mm} \:\mu \text{ (momentum)}, \:\tau \text{ (dampening)},
  138. \:\textit{ nesterov,}\:\textit{ maximize} \\[-1.ex]
  139. &\rule{110mm}{0.4pt} \\
  140. &\textbf{for} \: t=1 \: \textbf{to} \: \ldots \: \textbf{do} \\
  141. &\hspace{5mm}\textbf{if} \: \textit{maximize}: \\
  142. &\hspace{10mm}g_t \leftarrow -\nabla_{\theta} f_t (\theta_{t-1}) \\
  143. &\hspace{5mm}\textbf{else} \\
  144. &\hspace{10mm}g_t \leftarrow \nabla_{\theta} f_t (\theta_{t-1}) \\
  145. &\hspace{5mm}\textbf{if} \: \lambda \neq 0 \\
  146. &\hspace{10mm} g_t \leftarrow g_t + \lambda \theta_{t-1} \\
  147. &\hspace{5mm}\textbf{if} \: \mu \neq 0 \\
  148. &\hspace{10mm}\textbf{if} \: t > 1 \\
  149. &\hspace{15mm} \textbf{b}_t \leftarrow \mu \textbf{b}_{t-1} + (1-\tau) g_t \\
  150. &\hspace{10mm}\textbf{else} \\
  151. &\hspace{15mm} \textbf{b}_t \leftarrow g_t \\
  152. &\hspace{10mm}\textbf{if} \: \textit{nesterov} \\
  153. &\hspace{15mm} g_t \leftarrow g_{t} + \mu \textbf{b}_t \\
  154. &\hspace{10mm}\textbf{else} \\[-1.ex]
  155. &\hspace{15mm} g_t \leftarrow \textbf{b}_t \\
  156. &\hspace{5mm}\theta_t \leftarrow \theta_{t-1} - \gamma g_t \\[-1.ex]
  157. &\rule{110mm}{0.4pt} \\[-1.ex]
  158. &\bf{return} \: \theta_t \\[-1.ex]
  159. &\rule{110mm}{0.4pt} \\[-1.ex]
  160. \end{aligned}
  161. Nesterov momentum is based on the formula from
  162. `On the importance of initialization and momentum in deep learning`__.
  163. """
  164. + rf"""
  165. Args:
  166. {_params_doc}
  167. lr (float, Tensor, optional): learning rate (default: 1e-3)
  168. momentum (float, optional): momentum factor (default: 0)
  169. dampening (float, optional): dampening for momentum (default: 0)
  170. weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
  171. nesterov (bool, optional): enables Nesterov momentum. Only applicable
  172. when momentum is non-zero. (default: False)
  173. {_maximize_doc}
  174. {_foreach_doc}
  175. {_differentiable_doc}
  176. {_fused_doc}
  177. """
  178. + r"""
  179. Example:
  180. >>> # xdoctest: +SKIP
  181. >>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
  182. >>> optimizer.zero_grad()
  183. >>> loss_fn(model(input), target).backward()
  184. >>> optimizer.step()
  185. __ http://www.cs.toronto.edu/%7Ehinton/absps/momentum.pdf
  186. .. note::
  187. The implementation of SGD with Momentum/Nesterov subtly differs from
  188. Sutskever et al. and implementations in some other frameworks.
  189. Considering the specific case of Momentum, the update can be written as
  190. .. math::
  191. \begin{aligned}
  192. v_{t+1} & = \mu * v_{t} + g_{t+1}, \\
  193. p_{t+1} & = p_{t} - \text{lr} * v_{t+1},
  194. \end{aligned}
  195. where :math:`p`, :math:`g`, :math:`v` and :math:`\mu` denote the
  196. parameters, gradient, velocity, and momentum respectively.
  197. This is in contrast to Sutskever et al. and
  198. other frameworks which employ an update of the form
  199. .. math::
  200. \begin{aligned}
  201. v_{t+1} & = \mu * v_{t} + \text{lr} * g_{t+1}, \\
  202. p_{t+1} & = p_{t} - v_{t+1}.
  203. \end{aligned}
  204. The Nesterov version is analogously modified.
  205. Moreover, the initial value of the momentum buffer is set to the
  206. gradient value at the first step. This is in contrast to some other
  207. frameworks that initialize it to all zeros. One notable side effect
  208. of this decision is that the first momentum value will not be scaled
  209. by dampening. Dampening will be applied starting at the second step.
  210. """
  211. )
  212. def sgd(
  213. params: list[Tensor],
  214. d_p_list: list[Tensor],
  215. momentum_buffer_list: list[Optional[Tensor]],
  216. # kwonly args with defaults are not supported by functions compiled with torchscript issue #70627
  217. # setting this as kwarg for now as functional API is compiled by torch/distributed/optim
  218. has_sparse_grad: bool = False,
  219. foreach: Optional[bool] = None,
  220. fused: Optional[bool] = None,
  221. grad_scale: Optional[Tensor] = None,
  222. found_inf: Optional[Tensor] = None,
  223. *,
  224. weight_decay: float,
  225. momentum: float,
  226. lr: float,
  227. dampening: float,
  228. nesterov: bool,
  229. maximize: bool,
  230. ):
  231. r"""Functional API that performs SGD algorithm computation.
  232. See :class:`~torch.optim.SGD` for details.
  233. """
  234. # Respect when the user inputs False/True for foreach or fused. We only want to change
  235. # the default when neither have been user-specified. Note that we default to foreach
  236. # and pass False to use_fused. This is not a mistake--we want to give the fused impl
  237. # bake-in time before making it the default, even if it is typically faster.
  238. if foreach is None and fused is None:
  239. # why must we be explicit about an if statement for torch.jit.is_scripting here?
  240. # because JIT can't handle Optionals nor fancy conditionals when scripting
  241. if not torch.jit.is_scripting():
  242. fused, foreach = _default_to_fused_or_foreach(
  243. params, differentiable=False, use_fused=False
  244. )
  245. else:
  246. foreach = False
  247. fused = False
  248. if foreach is None:
  249. foreach = False
  250. if fused is None:
  251. fused = False
  252. if foreach and torch.jit.is_scripting():
  253. raise RuntimeError("torch.jit.script not supported with foreach optimizers")
  254. if fused and torch.jit.is_scripting():
  255. raise RuntimeError("torch.jit.script not supported with fused optimizers")
  256. if foreach and not torch.jit.is_scripting():
  257. func = _multi_tensor_sgd
  258. elif fused and not torch.jit.is_scripting():
  259. func = _fused_sgd
  260. else:
  261. func = _single_tensor_sgd
  262. func(
  263. params,
  264. d_p_list,
  265. momentum_buffer_list,
  266. weight_decay=weight_decay,
  267. momentum=momentum,
  268. lr=lr,
  269. dampening=dampening,
  270. nesterov=nesterov,
  271. has_sparse_grad=has_sparse_grad,
  272. maximize=maximize,
  273. grad_scale=grad_scale,
  274. found_inf=found_inf,
  275. )
  276. def _single_tensor_sgd(
  277. params: list[Tensor],
  278. grads: list[Tensor],
  279. momentum_buffer_list: list[Optional[Tensor]],
  280. grad_scale: Optional[Tensor],
  281. found_inf: Optional[Tensor],
  282. *,
  283. weight_decay: float,
  284. momentum: float,
  285. lr: float,
  286. dampening: float,
  287. nesterov: bool,
  288. maximize: bool,
  289. has_sparse_grad: bool,
  290. ):
  291. assert grad_scale is None and found_inf is None
  292. if not torch.jit.is_scripting():
  293. lr = _to_scalar(lr)
  294. for i, param in enumerate(params):
  295. grad = grads[i] if not maximize else -grads[i]
  296. if weight_decay != 0:
  297. # Nested if is necessary to bypass jitscript rules
  298. if isinstance(weight_decay, Tensor):
  299. if weight_decay.requires_grad:
  300. # usually this is the differentiable path, which is why the param.clone() is needed
  301. grad = grad.addcmul_(param.clone(), weight_decay)
  302. else:
  303. grad = grad.add(param, alpha=weight_decay)
  304. else:
  305. grad = grad.add(param, alpha=weight_decay)
  306. if momentum != 0:
  307. buf = momentum_buffer_list[i]
  308. if buf is None:
  309. buf = grad.detach().clone()
  310. momentum_buffer_list[i] = buf
  311. else:
  312. buf.mul_(momentum).add_(grad, alpha=1 - dampening)
  313. if nesterov:
  314. grad = grad.add(buf, alpha=momentum)
  315. else:
  316. grad = buf
  317. # Nested if is necessary to bypass jitscript rules
  318. if isinstance(lr, Tensor):
  319. if lr.requires_grad:
  320. param.addcmul_(grad, lr, value=-1)
  321. else:
  322. param.add_(grad, alpha=-lr)
  323. else:
  324. param.add_(grad, alpha=-lr)
  325. def _multi_tensor_sgd(
  326. params: list[Tensor],
  327. grads: list[Tensor],
  328. momentum_buffer_list: list[Optional[Tensor]],
  329. grad_scale: Optional[Tensor],
  330. found_inf: Optional[Tensor],
  331. *,
  332. weight_decay: float,
  333. momentum: float,
  334. lr: float,
  335. dampening: float,
  336. nesterov: bool,
  337. maximize: bool,
  338. has_sparse_grad: bool,
  339. ):
  340. assert grad_scale is None and found_inf is None
  341. if len(params) == 0:
  342. return
  343. lr = _to_scalar(lr)
  344. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  345. [params, grads, momentum_buffer_list], # type: ignore[list-item]
  346. with_indices=True,
  347. )
  348. for (
  349. device_params_,
  350. device_grads_,
  351. device_momentum_buffer_list,
  352. ), indices in grouped_tensors.values():
  353. device_params: list[Tensor] = cast(list[Tensor], device_params_)
  354. device_grads: list[Tensor] = cast(list[Tensor], device_grads_)
  355. device_has_sparse_grad = has_sparse_grad and any(
  356. grad.is_sparse for grad in device_grads
  357. )
  358. if maximize:
  359. device_grads = torch._foreach_neg(device_grads) # type: ignore[assignment]
  360. if weight_decay != 0:
  361. # Reuse the intermediate memory (device_grads) already allocated for maximize
  362. if maximize:
  363. torch._foreach_add_(device_grads, device_params, alpha=weight_decay)
  364. else:
  365. device_grads = torch._foreach_add( # type: ignore[assignment]
  366. device_grads, device_params, alpha=weight_decay
  367. )
  368. if momentum != 0:
  369. bufs: list[Tensor] = []
  370. all_states_with_momentum_buffer = True
  371. for i in range(len(device_momentum_buffer_list)):
  372. if device_momentum_buffer_list[i] is None:
  373. all_states_with_momentum_buffer = False
  374. break
  375. else:
  376. bufs.append(cast(Tensor, device_momentum_buffer_list[i]))
  377. if all_states_with_momentum_buffer:
  378. torch._foreach_mul_(bufs, momentum)
  379. torch._foreach_add_(bufs, device_grads, alpha=1 - dampening)
  380. else:
  381. bufs = []
  382. for i in range(len(device_momentum_buffer_list)):
  383. if device_momentum_buffer_list[i] is None:
  384. buf = device_momentum_buffer_list[i] = momentum_buffer_list[
  385. indices[i]
  386. ] = device_grads[i].detach().clone()
  387. else:
  388. buf = cast(Tensor, device_momentum_buffer_list[i])
  389. buf.mul_(momentum).add_(device_grads[i], alpha=1 - dampening)
  390. bufs.append(buf)
  391. if nesterov:
  392. torch._foreach_add_(device_grads, bufs, alpha=momentum)
  393. else:
  394. device_grads = bufs
  395. if not device_has_sparse_grad:
  396. # handle internal item() call if lr is a tensor
  397. if isinstance(lr, torch.Tensor) and torch.compiler.is_compiling():
  398. grads_x_lr = torch._foreach_mul(device_grads, -lr)
  399. torch._foreach_add_(device_params, grads_x_lr)
  400. else:
  401. torch._foreach_add_(device_params, device_grads, alpha=-lr)
  402. else:
  403. # foreach APIs don't support sparse
  404. for i in range(len(device_params)):
  405. device_params[i].add_(device_grads[i], alpha=-lr)
  406. def _fused_sgd(
  407. params: list[Tensor],
  408. grads: list[Tensor],
  409. momentum_buffer_list: list[Optional[Tensor]],
  410. grad_scale: Optional[Tensor],
  411. found_inf: Optional[Tensor],
  412. *,
  413. weight_decay: float,
  414. momentum: float,
  415. lr: float,
  416. dampening: float,
  417. nesterov: bool,
  418. maximize: bool,
  419. has_sparse_grad: bool,
  420. ) -> None:
  421. if not params:
  422. return
  423. if has_sparse_grad:
  424. raise RuntimeError("`_fused_sgd` does not support sparse gradients")
  425. grad_scale_dict: DeviceDict = (
  426. {grad_scale.device: grad_scale} if grad_scale is not None else {}
  427. )
  428. found_inf_dict: DeviceDict = (
  429. {found_inf.device: found_inf} if found_inf is not None else {}
  430. )
  431. no_momentum_buffer = momentum == 0
  432. is_first_step = (
  433. all(t is None for t in momentum_buffer_list) and not no_momentum_buffer
  434. )
  435. if is_first_step:
  436. for i, g in enumerate(grads):
  437. momentum_buffer_list[i] = torch.empty_like(g)
  438. grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
  439. [params, grads, momentum_buffer_list], # type: ignore[list-item]
  440. with_indices=False,
  441. )
  442. for (device, _), (
  443. (device_params_, device_grads_, device_momentum_buffer_list),
  444. _,
  445. ) in grouped_tensors.items():
  446. device_params: list[Tensor] = cast(list[Tensor], device_params_)
  447. device_grads: list[Tensor] = cast(list[Tensor], device_grads_)
  448. device_grad_scale, device_found_inf = None, None
  449. if grad_scale is not None:
  450. device_grad_scale = grad_scale_dict.setdefault(
  451. device, grad_scale.to(device)
  452. )
  453. if found_inf_dict is not None and found_inf is not None:
  454. device_found_inf = found_inf_dict.setdefault(device, found_inf.to(device))
  455. torch._fused_sgd_(
  456. device_params,
  457. device_grads,
  458. []
  459. if no_momentum_buffer
  460. else cast(list[Tensor], device_momentum_buffer_list),
  461. weight_decay=weight_decay,
  462. momentum=momentum,
  463. lr=lr,
  464. dampening=dampening,
  465. nesterov=nesterov,
  466. maximize=maximize,
  467. is_first_step=is_first_step,
  468. grad_scale=device_grad_scale,
  469. found_inf=device_found_inf,
  470. )