_muon.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. # mypy: allow-untyped-defs
  2. # mypy: disable-error-code=arg-type
  3. """Implementation of the Muon optimizer."""
  4. import math
  5. from collections.abc import MutableMapping
  6. from typing import Optional
  7. import torch
  8. from torch import Tensor
  9. from .optimizer import (
  10. _disable_dynamo_if_unsupported,
  11. _params_doc,
  12. _to_scalar,
  13. Optimizer,
  14. ParamsT,
  15. )
  16. __all__ = ["Muon"]
  17. # Constants from Keller Jordan's Muon post: https://kellerjordan.github.io/posts/muon/
  18. # github permlink: https://github.com/KellerJordan/Muon/blob/f90a42b28e00b8d9d2d05865fe90d9f39abcbcbd/muon.py#L16
  19. EPS = 1e-7
  20. DEFAULT_A = 3.4445
  21. DEFAULT_B = -4.7750
  22. DEFAULT_C = 2.0315
  23. DEFAULT_NS_STEPS = 5
  24. def _zeropower_via_newtonschulz(
  25. grad: Tensor, ns_coefficients: tuple[float, float, float], ns_steps: int, eps: float
  26. ) -> Tensor:
  27. """
  28. Newton-Schulz iteration to compute the zeroth power / orthogonalization of G. We opt to use a
  29. quintic iteration whose coefficients are selected to maximize the slope at zero. For the purpose
  30. of minimizing steps, it turns out to be empirically effective to keep increasing the slope at
  31. zero even beyond the point where the iteration no longer converges all the way to one everywhere
  32. on the interval. This iteration therefore does not produce UV^T but rather something like US'V^T
  33. where S' is diagonal with S_{ii}' ~ Uniform(0.5, 1.5), which turns out not to hurt model
  34. performance at all relative to UV^T, where USV^T = G is the SVD.
  35. Implementation reference: https://github.com/KellerJordan/Muon/blob/master/muon.py
  36. with suggestions by @jxbz, @leloykun, and @YouJiacheng.
  37. """
  38. if ns_steps >= 100:
  39. raise ValueError(
  40. "Number of steps must be less than 100 for computational efficiency"
  41. )
  42. if len(grad.shape) != 2:
  43. raise ValueError("Input tensor gradient must be a 2D matrix")
  44. if len(ns_coefficients) != 3:
  45. raise ValueError("Coefficients must be a tuple of exactly 3 values")
  46. a, b, c = ns_coefficients
  47. ortho_grad = grad.bfloat16()
  48. if grad.size(0) > grad.size(1):
  49. ortho_grad = ortho_grad.T
  50. # Ensure spectral norm is at most 1
  51. ortho_grad.div_(ortho_grad.norm().clamp(min=eps))
  52. # Perform the NS iterations
  53. for _ in range(ns_steps):
  54. gram_matrix = ortho_grad @ ortho_grad.T
  55. gram_update = torch.addmm(
  56. gram_matrix, gram_matrix, gram_matrix, beta=b, alpha=c
  57. )
  58. ortho_grad = torch.addmm(ortho_grad, gram_update, ortho_grad, beta=a)
  59. if grad.size(0) > grad.size(1):
  60. ortho_grad = ortho_grad.T
  61. return ortho_grad
  62. def _adjust_lr(
  63. lr: float, adjust_lr_fn: Optional[str], param_shape: torch.Size
  64. ) -> float:
  65. """Default learning rate adjustment used by Muon."""
  66. A, B = param_shape[:2]
  67. if adjust_lr_fn is None or adjust_lr_fn == "original":
  68. adjusted_ratio = math.sqrt(max(1, A / B))
  69. elif adjust_lr_fn == "match_rms_adamw":
  70. adjusted_ratio = 0.2 * math.sqrt(max(A, B))
  71. else:
  72. adjusted_ratio = 1.0
  73. return lr * adjusted_ratio
  74. class Muon(Optimizer):
  75. def __init__(
  76. self,
  77. params: ParamsT,
  78. lr: float = 1e-3,
  79. weight_decay: float = 0.1,
  80. momentum: float = 0.95,
  81. nesterov: bool = True,
  82. ns_coefficients: tuple[float, float, float] = (DEFAULT_A, DEFAULT_B, DEFAULT_C),
  83. eps: float = EPS,
  84. ns_steps: int = DEFAULT_NS_STEPS,
  85. adjust_lr_fn: Optional[str] = None,
  86. ) -> None:
  87. if isinstance(lr, Tensor) and lr.numel() != 1:
  88. raise ValueError("Tensor lr must be 1-element")
  89. if not 0.0 <= lr:
  90. raise ValueError(f"Learning rate should be >= 0 but is: {lr}")
  91. if not 0.0 <= momentum:
  92. raise ValueError(f"momentum should be >= 0 but is: {momentum}")
  93. if not 0.0 <= weight_decay:
  94. raise ValueError(f"weight decay should be >= 0 but is: {weight_decay}")
  95. if adjust_lr_fn is not None and adjust_lr_fn not in [
  96. "original",
  97. "match_rms_adamw",
  98. ]:
  99. raise ValueError(
  100. f"Adjust learning rate function {adjust_lr_fn} is not supported"
  101. )
  102. defaults = {
  103. "lr": lr,
  104. "weight_decay": weight_decay,
  105. "momentum": momentum,
  106. "nesterov": nesterov,
  107. "ns_coefficients": ns_coefficients,
  108. "eps": eps,
  109. "ns_steps": ns_steps,
  110. "adjust_lr_fn": adjust_lr_fn,
  111. }
  112. super().__init__(params, defaults)
  113. for group in self.param_groups:
  114. for p in group["params"]:
  115. if p.ndim != 2:
  116. raise ValueError(
  117. f"Muon only supports 2D parameters whereas we found a parameter with size: {p.size()}"
  118. )
  119. def _init_group(
  120. self,
  121. group: MutableMapping,
  122. params_with_grad: list[Tensor],
  123. grads: list[Tensor],
  124. muon_momentum_bufs: list[Tensor],
  125. ):
  126. for p in group["params"]:
  127. if p.grad is None:
  128. continue
  129. if torch.is_complex(p):
  130. raise RuntimeError("Muon does not support complex parameters")
  131. if p.grad.is_sparse:
  132. raise RuntimeError("Muon does not support sparse gradients")
  133. params_with_grad.append(p)
  134. grads.append(p.grad)
  135. state = self.state[p]
  136. if "momentum_buffer" not in state:
  137. state["momentum_buffer"] = torch.zeros_like(
  138. p.grad, memory_format=torch.preserve_format
  139. )
  140. muon_momentum_bufs.append(state["momentum_buffer"])
  141. return False # has_complex
  142. @torch.no_grad()
  143. def step(self, closure=None):
  144. """Performs a single optimization step."""
  145. loss = None
  146. if closure is not None:
  147. with torch.enable_grad():
  148. loss = closure()
  149. for group in self.param_groups:
  150. lr = group["lr"]
  151. weight_decay = group["weight_decay"]
  152. momentum = group["momentum"]
  153. params_with_grad: list[Tensor] = []
  154. grads: list[Tensor] = []
  155. muon_momentum_bufs: list[Tensor] = []
  156. has_complex = self._init_group(
  157. group,
  158. params_with_grad,
  159. grads,
  160. muon_momentum_bufs,
  161. )
  162. muon(
  163. params_with_grad,
  164. grads,
  165. muon_momentum_bufs,
  166. lr=lr,
  167. weight_decay=weight_decay,
  168. momentum=momentum,
  169. nesterov=group["nesterov"],
  170. ns_coefficients=group["ns_coefficients"],
  171. eps=group["eps"],
  172. ns_steps=group["ns_steps"],
  173. adjust_lr_fn=group["adjust_lr_fn"],
  174. has_complex=has_complex,
  175. )
  176. return loss
  177. Muon.__doc__ = (
  178. r"""Implements Muon algorithm.
  179. .. math::
  180. \begin{aligned}
  181. &\rule{110mm}{0.4pt} \\
  182. &\textbf{input} : \gamma \text{ (lr)},\ \lambda \text{ (weight decay)},\
  183. \mu \text{ (momentum)},\ \textit{nesterov}\in\{True,False\},\\
  184. &\hspace{13mm}(a,b,c)\ \text{ (NS coefficients)},\
  185. \varepsilon \text{ (epsilon)},\ k \text{ (NS steps)},\
  186. \theta_0 \text{ (params)},\ f(\theta) \text{ (objective)} \\
  187. &\textbf{initialize} : B_0 \leftarrow 0 \text{ (momentum buffer)} \\[-1.ex]
  188. &\rule{110mm}{0.4pt} \\
  189. &\textbf{for}\ t=1\ \textbf{to}\ \ldots\ \textbf{do} \\[0.25ex]
  190. &\hspace{5mm} g_t \leftarrow \nabla_{\theta} f_t(\theta_{t-1}) \\[0.25ex]
  191. &\hspace{5mm} B_t \leftarrow \mu B_{t-1} + g_t \\[0.25ex]
  192. &\hspace{5mm} \widetilde{B}_t \leftarrow
  193. \begin{cases}
  194. g_t + \mu B_t, & \text{if nesterov}=True \\
  195. B_t, & \text{if nesterov}=False
  196. \end{cases} \\[1.0ex]
  197. &\hspace{5mm} O_t \leftarrow \mathrm{NS}^{(a,b,c)}_{k}\!\big(\widetilde{B}_t;\ \varepsilon\big) \\[0.5ex]
  198. &\hspace{5mm} \theta_t \leftarrow \theta_{t-1} - \gamma\,\lambda\,\theta_{t-1}
  199. \quad\text{(decoupled weight decay)} \\[0.25ex]
  200. &\hspace{5mm} \gamma \leftarrow \mathrm{AdjustLR}\!\big(\gamma;\ \mathrm{shape}\!\big(\theta_t \big) \big) \\[0.25ex]
  201. &\hspace{5mm} \theta_t \leftarrow \theta_t - \gamma\, O_t \\
  202. &\rule{110mm}{0.4pt} \\[-1.ex]
  203. &\mathbf{return}\ \theta_t \\[-1.ex]
  204. &\rule{110mm}{0.4pt}s
  205. \end{aligned}
  206. Here, :math:`\mathrm{NS}^{(a,b,c)}_{k}(\cdot;\varepsilon)` denotes :math:`k` iterations of the
  207. Newton–Schulz orthogonalization operator parameterized by coefficients :math:`(a,b,c)`
  208. with numerical stabilization :math:`\varepsilon`.
  209. The purpose for :math:`\mathrm{AdjustLR}\!\big(\gamma;\ \mathrm{shape}\!\big(\theta_t \big) \big)`
  210. is to make the orthogonalized update have a consistent :math:`RMS` across rectangular matrices.
  211. Keller's original implementation scales the update by :math:`\sqrt{\max\!\left(1, \frac{A}{B}\right)}`,
  212. where :math:`A` and :math:`B` are dimension of the matrix being optimized.
  213. Moonshot's implementation also focuses on matching :math:`RMS` of AdamW. The adjustment is computed as:
  214. :math:`\gamma \leftarrow {0.2}\gamma\,\sqrt{\max\!\left({A}, {B}\right)}`
  215. The method is adopted from `Muon is Scalable for LLM Training`_. Research
  216. results show that with this adjustment Muon can directly reuse the learning rate
  217. and weight decay tuned for AdamW.
  218. We provide two options for the learning rate adjustment: "original", which follows Keller's
  219. implementation, and "match_rms_adamw", which refers to Moonshot's implementation. This gives users the
  220. flexibility to choose between the two. If `adjust_lr_fn` is not specified, the default is "original".
  221. For further details regarding the algorithm we refer to `Muon: An optimizer for hidden layers in neural networks`_
  222. and `Muon is Scalable for LLM Training`_.
  223. """
  224. + rf"""
  225. Args:
  226. {_params_doc}. Note that Muon is an optimizer for 2D parameters of neural network hidden layers. Other
  227. parameters, such as bias, and embedding, should be optimized by a standard method such as AdamW.
  228. lr (float, Tensor, optional): learning rate (default: 1e-3).
  229. weight_decay (float, optional): weight decay (L2 penalty). (default: 0.1)
  230. momentum (float, optional): momentum factor (default: 0.95)
  231. nesterov (bool, optional): enables Nesterov momentum. Only applicable
  232. when momentum is non-zero
  233. ns_coefficients (tuple of three floats, optional): coefficients \(a,b,c\) for the
  234. Newton–Schulz orthogonalization polynomial (default: ({DEFAULT_A}, {DEFAULT_B}, {DEFAULT_C}))
  235. eps (float, optional): term added to the denominator for numerical stability. (default: {EPS})
  236. ns_steps (int, optional): number of Newton–Schulz iteration steps. (default: {DEFAULT_NS_STEPS})
  237. adjust_lr_fn (str, optional): function to adjust learning rate. One of "original" and "match_rms_adamw".
  238. If not specified, we will default to use "original". (default: None)
  239. .. _Muon\: An optimizer for hidden layers in neural networks:
  240. https://kellerjordan.github.io/posts/muon/
  241. .. _Muon is Scalable for LLM Training:
  242. https://arxiv.org/pdf/2502.16982
  243. """
  244. )
  245. def _single_tensor_muon(
  246. params: list[Tensor],
  247. grads: list[Tensor],
  248. muon_momentum_bufs: list[Tensor],
  249. *,
  250. lr: float,
  251. weight_decay: float,
  252. momentum: float,
  253. nesterov: bool,
  254. ns_coefficients: tuple[float, float, float],
  255. ns_steps: int,
  256. eps: float,
  257. adjust_lr_fn: Optional[str],
  258. has_complex: bool,
  259. ) -> None:
  260. lr = _to_scalar(lr)
  261. if has_complex:
  262. raise ValueError("Complex parameters are not supported")
  263. for i, param in enumerate(params):
  264. grad = grads[i]
  265. if grad.ndim != 2:
  266. raise ValueError("Param gradient must be a 2D matrix")
  267. buf = muon_momentum_bufs[i]
  268. buf.lerp_(grad, 1 - momentum)
  269. update = grad.lerp(buf, momentum) if nesterov else buf
  270. update = _zeropower_via_newtonschulz(update, ns_coefficients, ns_steps, eps)
  271. adjusted_lr = _adjust_lr(lr, adjust_lr_fn, param.shape)
  272. param.mul_(1 - lr * weight_decay)
  273. param.add_(update, alpha=-adjusted_lr)
  274. @_disable_dynamo_if_unsupported(single_tensor_fn=_single_tensor_muon)
  275. def muon(
  276. params: list[Tensor],
  277. grads: list[Tensor],
  278. muon_momentum_bufs: list[Tensor],
  279. *,
  280. foreach: Optional[bool] = None,
  281. lr: float,
  282. weight_decay: float,
  283. momentum: float,
  284. nesterov: bool,
  285. ns_coefficients: tuple[float, float, float],
  286. ns_steps: int,
  287. eps: float,
  288. adjust_lr_fn: Optional[str],
  289. has_complex: bool,
  290. ):
  291. r"""Functional API that performs Muon algorithm computation.
  292. See :class:`~torch.optim.Muon` for details.
  293. """
  294. if foreach is not None and foreach:
  295. raise RuntimeError("Foreach is not supported for Muon yet")
  296. func = _single_tensor_muon
  297. func(
  298. params,
  299. grads,
  300. muon_momentum_bufs,
  301. lr=lr,
  302. weight_decay=weight_decay,
  303. momentum=momentum,
  304. nesterov=nesterov,
  305. ns_coefficients=ns_coefficients,
  306. ns_steps=ns_steps,
  307. eps=eps,
  308. adjust_lr_fn=adjust_lr_fn,
  309. has_complex=has_complex,
  310. )