functional_sgd.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. # mypy: allow-untyped-defs
  2. from typing import Optional
  3. import torch
  4. import torch.optim._functional as F
  5. from torch import Tensor
  6. from torch.distributed.optim._deprecation_warning import (
  7. _scripted_functional_optimizer_deprecation_warning,
  8. )
# Intentionally empty: `from ... import *` exports nothing from this module;
# `_FunctionalSGD` is meant only for distributed optimizer internals.
__all__: list[str] = []
# Define a TorchScript-compatible functional SGD optimizer, i.e. one that is
# used in a functional way. Instead of reading `param.grad` when updating
# parameters, we explicitly allow the distributed optimizer to pass gradients
# to the `step` function. This separates gradients from parameters and lets a
# multithreaded trainer update the parameters without data races from
# accumulating into the same `.grad`.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
  19. @torch.jit.script
  20. class _FunctionalSGD:
  21. def __init__(
  22. self,
  23. params: list[Tensor],
  24. lr: float = 1e-2,
  25. momentum: float = 0.0,
  26. dampening: float = 0.0,
  27. weight_decay: float = 0.0,
  28. nesterov: bool = False,
  29. maximize: bool = False,
  30. foreach: bool = False,
  31. fused: bool = False,
  32. _allow_empty_param_list: bool = False,
  33. ):
  34. _scripted_functional_optimizer_deprecation_warning(stacklevel=2)
  35. self.defaults = {
  36. "lr": lr,
  37. "momentum": momentum,
  38. "dampening": dampening,
  39. "weight_decay": weight_decay,
  40. }
  41. self.nesterov = nesterov
  42. self.maximize = maximize
  43. self.foreach = foreach
  44. self.fused = fused
  45. self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
  46. if len(params) == 0 and not _allow_empty_param_list:
  47. raise ValueError("optimizer got an empty parameter list")
  48. # NOTE: we only have one param_group and don't allow user to add additional
  49. # param group as it's not a common use case.
  50. self.param_group = {"params": params}
  51. def step_param(self, param: Tensor, grad: Optional[Tensor]):
  52. """Similar to self.step, but operates on a single parameter and
  53. its gradient.
  54. """
  55. # TODO: Once step_param interface is robust, refactor step to call
  56. # step param on each param.
  57. weight_decay = self.defaults["weight_decay"]
  58. momentum = self.defaults["momentum"]
  59. dampening = self.defaults["dampening"]
  60. lr = self.defaults["lr"]
  61. params = [param]
  62. momentum_buffer_list: list[Optional[Tensor]] = []
  63. grads = []
  64. has_sparse_grad = False
  65. if grad is not None:
  66. grads.append(grad)
  67. if grad.is_sparse:
  68. has_sparse_grad = True
  69. if param not in self.state:
  70. self.state[param] = {}
  71. state = self.state[param]
  72. if "momentum_buffer" not in state:
  73. momentum_buffer_list.append(None)
  74. else:
  75. momentum_buffer_list.append(state["momentum_buffer"])
  76. with torch.no_grad():
  77. F.sgd(
  78. params,
  79. grads,
  80. momentum_buffer_list,
  81. weight_decay=weight_decay,
  82. momentum=momentum,
  83. lr=lr,
  84. dampening=dampening,
  85. nesterov=self.nesterov,
  86. maximize=self.maximize,
  87. has_sparse_grad=has_sparse_grad,
  88. foreach=self.foreach,
  89. fused=self.fused,
  90. grad_scale=None,
  91. found_inf=None,
  92. )
  93. # update momentum_buffer in state
  94. state = self.state[param]
  95. momentum_buffer = momentum_buffer_list[0]
  96. if momentum_buffer is not None:
  97. state["momentum_buffer"] = momentum_buffer
  98. def step(self, gradients: list[Optional[Tensor]]):
  99. params = self.param_group["params"]
  100. params_with_grad = []
  101. grads = []
  102. momentum_buffer_list: list[Optional[Tensor]] = []
  103. lr = self.defaults["lr"]
  104. weight_decay = self.defaults["weight_decay"]
  105. momentum = self.defaults["momentum"]
  106. dampening = self.defaults["dampening"]
  107. if len(params) != len(gradients):
  108. raise ValueError(
  109. "the gradients passed in does not equal to the size of the parameters!"
  110. + f"Params length: {len(params)}. "
  111. + f"Gradients length: {len(gradients)}"
  112. )
  113. has_sparse_grad = False
  114. for param, gradient in zip(params, gradients):
  115. if gradient is not None:
  116. params_with_grad.append(param)
  117. grads.append(gradient)
  118. if gradient.is_sparse:
  119. has_sparse_grad = True
  120. if param not in self.state:
  121. self.state[param] = {}
  122. state = self.state[param]
  123. if "momentum_buffer" not in state:
  124. momentum_buffer_list.append(None)
  125. else:
  126. momentum_buffer_list.append(state["momentum_buffer"])
  127. with torch.no_grad():
  128. F.sgd(
  129. params_with_grad,
  130. grads,
  131. momentum_buffer_list,
  132. weight_decay=weight_decay,
  133. momentum=momentum,
  134. lr=lr,
  135. dampening=dampening,
  136. nesterov=self.nesterov,
  137. maximize=self.maximize,
  138. has_sparse_grad=has_sparse_grad,
  139. foreach=self.foreach,
  140. fused=self.fused,
  141. grad_scale=None,
  142. found_inf=None,
  143. )
  144. # update momentum_buffers in state
  145. for i, p in enumerate(params_with_grad):
  146. state = self.state[p]
  147. momentum_buffer = momentum_buffer_list[i]
  148. if momentum_buffer is not None:
  149. state["momentum_buffer"] = momentum_buffer