functional_adamax.py 4.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123
  1. # mypy: allow-untyped-defs
  2. from typing import Optional
  3. import torch
  4. import torch.optim._functional as F
  5. from torch import Tensor
  6. from torch.distributed.optim._deprecation_warning import (
  7. _scripted_functional_optimizer_deprecation_warning,
  8. )
# Intentionally empty: nothing in this module is public API. The optimizer
# class defined here is private to the distributed optimizer internals.
__all__: list[str] = []
# Define a TorchScript-compatible functional Adamax optimizer,
# where we use this optimizer in a functional way.
# Instead of using `param.grad` when updating parameters,
# we explicitly allow the distributed optimizer to pass gradients to
# the `step` function. In this way, we can separate the gradients
# and parameters and allow a multithreaded trainer to update the
# parameters without data races from accumulating into the same .grad.
# NOTE: This should only be used by distributed optimizer internals
# and is not meant to be exposed to the user.
  19. @torch.jit.script
  20. class _FunctionalAdamax:
  21. def __init__(
  22. self,
  23. params: list[Tensor],
  24. lr: float = 1e-3,
  25. betas: tuple[float, float] = (0.9, 0.999),
  26. eps: float = 1e-8,
  27. weight_decay: float = 0.0,
  28. foreach: bool = False,
  29. maximize: bool = False,
  30. _allow_empty_param_list: bool = False,
  31. ):
  32. _scripted_functional_optimizer_deprecation_warning(stacklevel=2)
  33. if not 0.0 <= lr:
  34. raise ValueError(f"Invalid learning rate: {lr}")
  35. if not 0.0 <= eps:
  36. raise ValueError(f"Invalid epsilon value: {eps}")
  37. if not 0.0 <= betas[0] < 1.0:
  38. raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
  39. if not 0.0 <= betas[1] < 1.0:
  40. raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
  41. if not 0.0 <= weight_decay:
  42. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  43. self.defaults = {
  44. "lr": lr,
  45. "eps": eps,
  46. "beta1": betas[0],
  47. "beta2": betas[1],
  48. "weight_decay": weight_decay,
  49. }
  50. self.foreach = foreach
  51. self.maximize = maximize
  52. self.state = torch.jit.annotate(dict[torch.Tensor, dict[str, torch.Tensor]], {})
  53. if len(params) == 0 and not _allow_empty_param_list:
  54. raise ValueError("optimizer got an empty parameter list")
  55. # NOTE: we only have one param_group and don't allow user to add additional
  56. # param group as it's not a common use case.
  57. self.param_group = {"params": params}
  58. def step(self, gradients: list[Optional[Tensor]]):
  59. params = self.param_group["params"]
  60. params_with_grad = []
  61. grads = []
  62. exp_avgs = []
  63. exp_infs = []
  64. state_steps: list[Tensor] = []
  65. if len(params) != len(gradients):
  66. raise ValueError(
  67. "the gradients passed in does not equal to the size of the parameters!"
  68. + f"Params length: {len(params)}. "
  69. + f"Gradients length: {len(gradients)}"
  70. )
  71. has_complex = False
  72. for param, gradient in zip(self.param_group["params"], gradients):
  73. if gradient is not None:
  74. has_complex |= torch.is_complex(param)
  75. params_with_grad.append(param)
  76. grads.append(gradient)
  77. # Lazy state initialization
  78. if param not in self.state:
  79. self.state[param] = {}
  80. state = self.state[param]
  81. state["step"] = torch.tensor(0.0)
  82. # Exponential moving average of gradient values
  83. state["exp_avg"] = torch.zeros_like(
  84. param, memory_format=torch.preserve_format
  85. )
  86. # Exponential moving average of squared gradient values
  87. state["exp_inf"] = torch.zeros_like(
  88. param, memory_format=torch.preserve_format
  89. )
  90. state = self.state[param]
  91. exp_avgs.append(state["exp_avg"])
  92. exp_infs.append(state["exp_inf"])
  93. state_steps.append(state["step"])
  94. with torch.no_grad():
  95. F.adamax(
  96. params_with_grad,
  97. grads,
  98. exp_avgs,
  99. exp_infs,
  100. state_steps,
  101. eps=self.defaults["eps"],
  102. beta1=self.defaults["beta1"],
  103. beta2=self.defaults["beta2"],
  104. lr=self.defaults["lr"],
  105. weight_decay=self.defaults["weight_decay"],
  106. foreach=self.foreach,
  107. maximize=self.maximize,
  108. has_complex=has_complex,
  109. )