functional_adam.py 7.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202
  1. # mypy: allow-untyped-defs
  2. from typing import Optional
  3. import torch
  4. import torch.optim._functional as F
  5. from torch import Tensor
  6. from torch.distributed.optim._deprecation_warning import (
  7. _scripted_functional_optimizer_deprecation_warning,
  8. )
  # Empty on purpose: a star-import of this internal module exposes no names.
__all__: list[str] = list()
  10. # Define a TorchScript compatible Functional Adam Optimizer
  11. # where we use these optimizer in a functional way.
  12. # Instead of using the `param.grad` when updating parameters,
  13. # we explicitly allow the distributed optimizer pass gradients to
  14. # the `step` function. In this way, we could separate the gradients
  15. # and parameters and allow multithreaded trainer to update the
  16. # parameters without data traces on accumulating to the same .grad.
  17. # NOTE: This should be only used by distributed optimizer internals
  18. # and not meant to expose to the user.
  19. @torch.jit.script
  20. class _FunctionalAdam:
  21. def __init__(
  22. self,
  23. params: list[Tensor],
  24. lr: float = 1e-3,
  25. betas: tuple[float, float] = (0.9, 0.999),
  26. eps: float = 1e-8,
  27. weight_decay: float = 0.0,
  28. amsgrad: bool = False,
  29. maximize: bool = False,
  30. foreach: bool = False,
  31. fused: bool = False,
  32. _allow_empty_param_list: bool = False,
  33. ):
  34. _scripted_functional_optimizer_deprecation_warning(stacklevel=2)
  35. if not 0.0 <= lr:
  36. raise ValueError(f"Invalid learning rate: {lr}")
  37. if not 0.0 <= eps:
  38. raise ValueError(f"Invalid epsilon value: {eps}")
  39. if not 0.0 <= betas[0] < 1.0:
  40. raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
  41. if not 0.0 <= betas[1] < 1.0:
  42. raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
  43. if not 0.0 <= weight_decay:
  44. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  45. self.defaults = {
  46. "lr": lr,
  47. "eps": eps,
  48. "beta1": betas[0],
  49. "beta2": betas[1],
  50. "weight_decay": weight_decay,
  51. }
  52. self.amsgrad = amsgrad
  53. self.maximize = maximize
  54. self.foreach = foreach
  55. self.fused = fused
  56. self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
  57. if len(params) == 0 and not _allow_empty_param_list:
  58. raise ValueError("optimizer got an empty parameter list")
  59. # NOTE: we only have one param_group and don't allow user to add additional
  60. # param group as it's not a common use case.
  61. self.param_group = {"params": params}
  62. def step_param(self, param: Tensor, grad: Optional[Tensor]):
  63. """
  64. Similar to step, but operates on a single parameter and optionally a
  65. gradient tensor.
  66. """
  67. params_with_grad = []
  68. grads = []
  69. exp_avgs = []
  70. exp_avg_sqs = []
  71. max_exp_avg_sqs = []
  72. state_steps: list[Tensor] = []
  73. has_complex = torch.is_complex(param)
  74. if grad is not None:
  75. params_with_grad.append(param)
  76. grads.append(grad)
  77. if param not in self.state:
  78. self.state[param] = {}
  79. state = self.state[param]
  80. state["step"] = torch.tensor(0.0)
  81. state["exp_avg"] = torch.zeros_like(
  82. param, memory_format=torch.preserve_format
  83. )
  84. state["exp_avg_sq"] = torch.zeros_like(
  85. param, memory_format=torch.preserve_format
  86. )
  87. if self.amsgrad:
  88. state["max_exp_avg_sq"] = torch.zeros_like(
  89. param, memory_format=torch.preserve_format
  90. )
  91. state = self.state[param]
  92. exp_avgs.append(state["exp_avg"])
  93. exp_avg_sqs.append(state["exp_avg_sq"])
  94. if self.amsgrad:
  95. max_exp_avg_sqs.append(state["max_exp_avg_sq"])
  96. state_steps.append(state["step"])
  97. with torch.no_grad():
  98. F.adam(
  99. params_with_grad,
  100. grads,
  101. exp_avgs,
  102. exp_avg_sqs,
  103. max_exp_avg_sqs,
  104. state_steps,
  105. amsgrad=self.amsgrad,
  106. has_complex=has_complex,
  107. maximize=self.maximize,
  108. beta1=self.defaults["beta1"],
  109. beta2=self.defaults["beta2"],
  110. lr=self.defaults["lr"],
  111. weight_decay=self.defaults["weight_decay"],
  112. eps=self.defaults["eps"],
  113. foreach=self.foreach,
  114. fused=self.fused,
  115. grad_scale=None,
  116. found_inf=None,
  117. )
  118. def step(self, gradients: list[Optional[Tensor]]):
  119. params = self.param_group["params"]
  120. params_with_grad = []
  121. grads = []
  122. exp_avgs = []
  123. exp_avg_sqs = []
  124. max_exp_avg_sqs = []
  125. state_steps: list[Tensor] = []
  126. has_complex = False
  127. if len(params) != len(gradients):
  128. raise ValueError(
  129. "the gradients passed in does not equal to the size of the parameters!"
  130. + f"Params length: {len(params)}. "
  131. + f"Gradients length: {len(gradients)}"
  132. )
  133. for param, gradient in zip(self.param_group["params"], gradients):
  134. if gradient is not None:
  135. has_complex |= torch.is_complex(param)
  136. params_with_grad.append(param)
  137. grads.append(gradient)
  138. # Lazy state initialization
  139. if param not in self.state:
  140. self.state[param] = {}
  141. state = self.state[param]
  142. state["step"] = torch.tensor(0.0)
  143. # Exponential moving average of gradient values
  144. state["exp_avg"] = torch.zeros_like(
  145. param, memory_format=torch.preserve_format
  146. )
  147. # Exponential moving average of squared gradient values
  148. state["exp_avg_sq"] = torch.zeros_like(
  149. param, memory_format=torch.preserve_format
  150. )
  151. if self.amsgrad:
  152. # Maintains max of all exp. moving avg. of sq. grad. values
  153. state["max_exp_avg_sq"] = torch.zeros_like(
  154. param, memory_format=torch.preserve_format
  155. )
  156. state = self.state[param]
  157. exp_avgs.append(state["exp_avg"])
  158. exp_avg_sqs.append(state["exp_avg_sq"])
  159. if self.amsgrad:
  160. max_exp_avg_sqs.append(state["max_exp_avg_sq"])
  161. state_steps.append(state["step"])
  162. with torch.no_grad():
  163. F.adam(
  164. params_with_grad,
  165. grads,
  166. exp_avgs,
  167. exp_avg_sqs,
  168. max_exp_avg_sqs,
  169. state_steps,
  170. amsgrad=self.amsgrad,
  171. has_complex=has_complex,
  172. maximize=self.maximize,
  173. beta1=self.defaults["beta1"],
  174. beta2=self.defaults["beta2"],
  175. lr=self.defaults["lr"],
  176. weight_decay=self.defaults["weight_decay"],
  177. eps=self.defaults["eps"],
  178. foreach=self.foreach,
  179. fused=self.fused,
  180. grad_scale=None,
  181. found_inf=None,
  182. )