muon.py 26 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687
  1. """ Muon Optimizer
  2. Improved Muon optimizer implementation with flexible handling of high-dimensional tensors.
  3. Combines PyTorch-style structure with options for:
  4. - Batched spatial processing for convolutions in addition to flatten
  5. - Optional spatial normalization
  6. - Selectable coefficient presets
  7. - Automatic fallback to AdamW for 1D / scalar parameters (biases, norms, etc.) and optional fallback via param groups
  8. Based on implementation by Keller Jordan, see
  9. - https://github.com/KellerJordan/Muon/blob/master/muon.py
  10. - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
  11. - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py
  12. - https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py
  13. Hacked together by Ross Wightman
  14. """
  15. import logging
  16. import numbers
  17. from typing import List, Mapping, Optional, Sequence, Tuple, Union
  18. import torch
  19. from ._types import ParamsT
  20. from .adamw import adamw
  21. from .nadamw import nadamw
# Module-level logger; Muon.step() uses it to report parameter routing when verbose=True.
_logger = logging.getLogger(__name__)

# Constants from Keller Jordan's Muon
MUON_EPS = 1e-7  # default epsilon (lower clamp on the norm used to normalize the Newton-Schulz input)
DEFAULT_NS_STEPS = 5  # default number of Newton-Schulz iterations

# Named presets of Newton-Schulz coefficients. Each preset is a list of (a, b, c) triples, consumed one triple
# per iteration. zeropower_via_newtonschulz() truncates the list when fewer steps are requested and repeats
# the last triple when more steps are requested than the preset provides.
_COEFFICIENTS = {
    "original": [
        # Keller Jordan's Muon https://kellerjordan.github.io/posts/muon/
        (3.4445, -4.7750, 2.0315),
    ],
    "quintic": [
        # https://leloykun.github.io/ponder/muon-opt-coeffs/#how-do-we-optimize-the-coefficients
        # From https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py#L44
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ],
    "polar_express": [
        # Polar Express https://arxiv.org/abs/2505.16932
        # From https://github.com/NoahAmsel/PolarExpress/tree/main with safety 1e-2
        (8.237312490495555, -23.157747414558198, 16.680568411445915),
        (4.082441999064835, -2.893047735332586, 0.5252849256975648),
        (3.9263479922546582, -2.8547468034765298, 0.5318022422894988),
        (3.2982187133085143, -2.424541981026706, 0.48632008358844075),
        (2.2970369434552573, -1.63662558125903, 0.4002628455953627),
        (1.8763805351440397, -1.2347896577722228, 0.35891887501668385),
        (1.8564423485617974, -1.2132449880935525, 0.3568003487825883),
        (1.8749994008682747, -1.2499988017229169, 0.3749994008546422),
    ],
    "polar_express_safer": [
        # from https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
        # w/ safety 2e-2
        (8.156554524902461, -22.48329292557795, 15.878769915207462),
        (4.0429299351667245, -2.808917465908704, 0.5000178451051299),
        (3.8916678022926563, -2.7724841532176825, 0.5060648178503389),
        (3.285753657755658, -2.3681294933425394, 0.46449024233003117),
        (2.3005307116270983, -1.6111665557258408, 0.3833374427545273),
        (1.8631210546382593, -1.2042160621002727, 0.3421879560523383),
        (1.8382572152247512, -1.1779263289537742, 0.3396513038637379),
        (1.8749999923301852, -1.2499999836060613, 0.374999991275876),
    ],
}

# Accepted ways to specify coefficients: a preset name (key of _COEFFICIENTS), a single (a, b, c) triple,
# or a list of such triples. resolve_ns_coefficients() normalizes all three to a list of triples.
NSCoeff = Union[str, Tuple[float, float, float], List[Tuple[float, float, float]]]
  66. def zeropower_via_newtonschulz(
  67. G: torch.Tensor,
  68. steps: int,
  69. coefficients: List[Tuple[float, float, float]],
  70. eps: float = MUON_EPS,
  71. safety_factor: float = 1.0,
  72. dtype: torch.dtype = torch.bfloat16,
  73. ) -> torch.Tensor:
  74. """Newton-Schulz quintic iteration to compute the zeroth power / orthogonalization of gradient.
  75. Supports batched operation over leading dimensions.
  76. See
  77. - https://github.com/KellerJordan/Muon/blob/master/muon.py
  78. - https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py
  79. - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
  80. Args:
  81. G: Input gradient tensor of shape (m, n) or (batch, m, n)
  82. steps: Number of Newton-Schulz iterations
  83. coefficients: Coefficients (a, b, c) for the iteration
  84. eps: Numerical stability epsilon for norm
  85. safety_factor: Multiplicative safety factor for norm (1.01 is common safety value in 'polar express' variants)
  86. dtype: Computation dtype
  87. Returns:
  88. Orthogonalized tensor of same shape as G
  89. """
  90. assert G.ndim in (2, 3), f"Input must be 2D or 3D, got {G.ndim}D. Flatten batch dims first."
  91. num_cs = len(coefficients)
  92. assert num_cs >= 1 and len(coefficients[0]) == 3
  93. # match coefficients with # of steps, truncate or repeat last
  94. coeff_sequence = coefficients[:steps] if steps <= num_cs else \
  95. coefficients + [coefficients[-1]] * (steps - num_cs)
  96. X = G.to(dtype=dtype, copy=True)
  97. # Transpose if needed (operate on dimension with fewer elements)
  98. transposed = X.size(-2) > X.size(-1)
  99. if transposed:
  100. X = X.mT
  101. # Normalize spectral norm to at most 1
  102. X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).clamp_min(eps))
  103. # Batched vs unbatched fused MM
  104. mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm
  105. # Pre-allocate
  106. X = X.contiguous()
  107. A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
  108. B = torch.empty_like(A)
  109. C = torch.empty_like(X)
  110. # Perform Newton-Schulz iterations
  111. for a, b, c in coeff_sequence:
  112. mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A) # A = X @ X.mT
  113. mm_fn(A, A, A, beta=b, alpha=c, out=B) # B = b * A + c * A @ A
  114. mm_fn(X, B, X, beta=a, alpha=1.0, out=C) # C = a * X + B @ X
  115. X, C = C, X # swap refs to avoid copy
  116. if transposed:
  117. X = X.mT
  118. return X
  119. def get_lr_scale(
  120. param_shape: torch.Size,
  121. adjust_lr_fn: str = "match_rms_adamw"
  122. ) -> float:
  123. """Adjust learning rate based on parameter shape."""
  124. out_chs, in_chs = (param_shape[-2], param_shape[-1]) if len(param_shape) > 1 else (1., 1.)
  125. if adjust_lr_fn == "original":
  126. # Original Muon impl (https://kellerjordan.github.io/posts/muon/)
  127. return max(1, out_chs / in_chs) ** 0.5
  128. elif adjust_lr_fn == "match_rms_adamw":
  129. # Kimi (https://arxiv.org/abs/2502.16982)
  130. return 0.2 * max(out_chs, in_chs) ** 0.5
  131. elif adjust_lr_fn == "rms_to_rms":
  132. # Scion (https://arxiv.org/abs/2502.07529, https://github.com/LIONS-EPFL/scion)
  133. # Bernstein et al. (https://jeremybernste.in/writing/deriving-muon)
  134. return (out_chs / in_chs) ** 0.5
  135. else:
  136. assert False, f'Invalid scaling function "{adjust_lr_fn}"'
  137. def _is_suitable_for_muon(
  138. param: torch.Tensor,
  139. min_dim_size: int = 4,
  140. max_aspect_ratio: float = 128.,
  141. return_reason: bool = False,
  142. ) -> Union[bool, Tuple[bool, str]]:
  143. """Check if a parameter is suitable for Muon optimization.
  144. Args:
  145. param: Parameter tensor
  146. min_dim_size: Minimum size for non-unit dimensions
  147. max_aspect_ratio: Maximum allowed aspect ratio
  148. return_reason: If True, return (bool, reason_string), else just bool (faster)
  149. Returns:
  150. If return_reason=False: bool indicating suitability
  151. If return_reason=True: Tuple of (is_suitable, reason_string)
  152. Examples:
  153. (64, 128) -> True (or (True, "ok") if return_reason=True)
  154. (96, 3, 4, 4) -> True - will be flattened to (96, 48)
  155. (4, 2048) -> False - extreme aspect ratio
  156. (64,) -> False - insufficient dims
  157. (1, 196, 768) -> False - leading unit dims
  158. NOTE: these rules were created to balance complexity with covering common timm model cases
  159. Please let me know if there are non-optimal cases that you run into.
  160. """
  161. s = param.shape
  162. # Must have at least 2 non-unit dimensions
  163. if param.ndim < 2 or sum(1 for dim_size in s if dim_size > 1) < 2:
  164. return (False, "insufficient_dims") if return_reason else False
  165. # Unit dimension in first two positions indicates:
  166. # - Position embeddings (1, seq, dim)
  167. # - Depthwise convs (out, 1, h, w)
  168. # - Other degenerate cases possibly not caught by first rule
  169. if s[0] == 1 or s[1] == 1:
  170. return (False, "leading_unit_dims") if return_reason else False
  171. if param.ndim >= 3:
  172. # For 3D+ tensors, check what dimensions will be AFTER flattening
  173. # since that's what gets passed to Newton-Schulz iteration
  174. # Flatten mode: (out, in, *spatial) -> (out, in * spatial_prod)
  175. out_ch = s[0]
  176. in_ch_with_spatial = 1
  177. for d in s[1:]:
  178. in_ch_with_spatial *= d
  179. check_dims = (out_ch, in_ch_with_spatial)
  180. else:
  181. # For 2D tensors, check as-is
  182. check_dims = s
  183. # Both dims should be >= minimum size
  184. min_size = min(check_dims)
  185. if min_size < min_dim_size:
  186. if return_reason:
  187. return False, f"min_dim_too_small:{min_size}"
  188. return False
  189. # Aspect ratio shouldn't be too extreme
  190. max_size = max(check_dims)
  191. aspect_ratio = max_size / min_size
  192. if aspect_ratio > max_aspect_ratio:
  193. if return_reason:
  194. return False, f"extreme_aspect_ratio:{aspect_ratio:.1f}"
  195. return False
  196. return (True, "ok") if return_reason else True
  197. def reshape_for_muon(
  198. tensor: torch.Tensor,
  199. mode: str = "flatten",
  200. ) -> Tuple[torch.Tensor, torch.Size]:
  201. """Reshape high-dimensional tensor for Muon processing.
  202. Args:
  203. tensor: Input tensor of shape (out, in, *spatial)
  204. mode: How to handle spatial dimensions
  205. - "flatten": Flatten spatial into output dimension (out, in*H*W)
  206. - "batched": Batch over spatial positions (spatial_prod, out, in) for per-position orthogonalization
  207. Returns:
  208. Reshaped tensor and original shape for restoration
  209. """
  210. original_shape = tensor.shape
  211. if tensor.ndim == 2:
  212. return tensor, original_shape
  213. if tensor.ndim < 2:
  214. raise ValueError(f"Tensor must have at least 2 dimensions, got {tensor.ndim}")
  215. out_ch, in_ch = tensor.shape[:2]
  216. if mode == "flatten":
  217. # Flatten: (out, in, *spatial) -> (out, in * spatial_prod)
  218. return tensor.reshape(out_ch, -1), original_shape
  219. elif mode == "batched":
  220. # Batched: (out, in, *spatial) -> (spatial_prod, out, in)
  221. # Move spatial dimension to front so zeropower_via_newtonschulz batches over it
  222. reshaped = tensor.reshape(out_ch, in_ch, -1) # (out, in, spatial_prod)
  223. reshaped = reshaped.permute(2, 0, 1) # (spatial_prod, out, in)
  224. return reshaped, original_shape
  225. else:
  226. raise ValueError(f"Unknown mode: {mode}")
  227. def muon(
  228. params: List[torch.Tensor],
  229. grads: List[torch.Tensor],
  230. momentum_bufs: List[torch.Tensor],
  231. *,
  232. lr: float,
  233. weight_decay: float,
  234. momentum: float,
  235. nesterov: bool,
  236. ns_steps: int,
  237. ns_coefficients: NSCoeff,
  238. eps: float,
  239. safety_factor: float,
  240. adjust_lr_fn: Optional[str],
  241. conv_mode: str,
  242. normalize_spatial: bool,
  243. ) -> None:
  244. """Functional API that performs Muon algorithm computation."""
  245. _single_tensor_muon(
  246. params,
  247. grads,
  248. momentum_bufs,
  249. lr=lr,
  250. weight_decay=weight_decay,
  251. momentum=momentum,
  252. nesterov=nesterov,
  253. ns_steps=ns_steps,
  254. ns_coefficients=ns_coefficients,
  255. eps=eps,
  256. safety_factor=safety_factor,
  257. adjust_lr_fn=adjust_lr_fn,
  258. conv_mode=conv_mode,
  259. normalize_spatial=normalize_spatial,
  260. )
  261. def _single_tensor_muon(
  262. params: List[torch.Tensor],
  263. grads: List[torch.Tensor],
  264. momentum_bufs: List[torch.Tensor],
  265. *,
  266. lr: float,
  267. weight_decay: float,
  268. momentum: float,
  269. nesterov: bool,
  270. ns_steps: int,
  271. ns_coefficients: NSCoeff,
  272. eps: float,
  273. safety_factor: float,
  274. adjust_lr_fn: Optional[str],
  275. conv_mode: str,
  276. normalize_spatial: bool,
  277. ) -> None:
  278. """Single tensor Muon update."""
  279. ns_coefficients = resolve_ns_coefficients(ns_coefficients, _COEFFICIENTS)
  280. for i, param in enumerate(params):
  281. grad = grads[i]
  282. momentum_buf = momentum_bufs[i]
  283. # Apply weight decay
  284. param.mul_(1 - lr * weight_decay)
  285. # Update momentum buffer
  286. momentum_buf.lerp_(grad, 1. - momentum)
  287. update = grad.lerp_(momentum_buf, momentum) if nesterov else momentum_buf.clone()
  288. # Reshape for processing (handle 3D+ tensors like conv weights)
  289. if update.ndim >= 3:
  290. update_reshaped, original_shape = reshape_for_muon(update, mode=conv_mode)
  291. else:
  292. update_reshaped = update
  293. original_shape = update.shape
  294. # Apply Newton-Schulz orthogonalization
  295. update_ortho = zeropower_via_newtonschulz(
  296. update_reshaped,
  297. ns_steps,
  298. ns_coefficients,
  299. eps=eps,
  300. safety_factor=safety_factor,
  301. #dtype=torch.bfloat16, # wire to arg?
  302. )
  303. # Adjust learning rate based on parameter shape
  304. scale = get_lr_scale(update_ortho.shape, adjust_lr_fn)
  305. # Apply spatial normalization and permute back if in batched mode
  306. if conv_mode == "batched" and update_ortho.ndim >= 3:
  307. if normalize_spatial:
  308. scale *= update_ortho.shape[0] ** -0.5
  309. # Permute back: (spatial_prod, out, in) -> (out, in, spatial_prod)
  310. update_ortho = update_ortho.permute(1, 2, 0)
  311. # Reshape back to original shape
  312. update_ortho = update_ortho.reshape(original_shape)
  313. # Apply update
  314. param.add_(update_ortho, alpha=-lr * scale)
  315. class Muon(torch.optim.Optimizer):
  316. """Muon - MomentUm Orthogonalized by Newton-schulz
  317. Combines Muon for 2D+ parameters (weight matrices) with AdamW for 1D parameters (biases, norms) and
  318. parameter groups with 'use_fallback=True' set (or 'use_muon=False' for compatibility).
  319. """
  320. def __init__(
  321. self,
  322. params: ParamsT,
  323. lr: float = 0.02,
  324. weight_decay: float = 0,
  325. momentum: float = 0.95,
  326. nesterov: bool = False,
  327. ns_steps: int = DEFAULT_NS_STEPS,
  328. ns_coefficients: NSCoeff = "quintic",
  329. eps: float = MUON_EPS,
  330. safety_factor: float = 1.0,
  331. adjust_lr_fn: Optional[str] = "match_rms_adamw",
  332. conv_mode: str = "flatten",
  333. normalize_spatial: bool = True,
  334. adamw_lr: Optional[float] = None,
  335. betas: Tuple[float, float] = (0.9, 0.95),
  336. verbose: bool = False,
  337. ):
  338. """ Create Muon optimizer.
  339. Args:
  340. params: Iterable of parameters or dicts defining parameter groups
  341. lr: Learning rate (default: 0.02 for Muon parameters)
  342. weight_decay: Weight decay coefficient
  343. momentum: Momentum factor for Muon
  344. nesterov: Whether to use Nesterov momentum
  345. ns_steps: Number of Newton-Schulz iterations
  346. ns_coefficients: Coefficients for NS iteration
  347. eps: Numerical stability epsilon
  348. safety_factor: Multiplicative safety factor for NS norm
  349. adjust_lr_fn: LR adjustment function - "original" or "match_rms_adamw"
  350. conv_mode: How to handle convolutions - "flatten" or "batched"
  351. normalize_spatial: Whether to normalize by sqrt(spatial_size) in batched mode
  352. adamw_lr: Learning rate for AdamW (1D params), defaults to lr if not specified
  353. betas: AdamW beta coefficients
  354. verbose: Log parameter routing decisions (Muon vs AdamW)
  355. Example:
  356. ```python
  357. # Simple usage - automatically uses Muon for 2D+ params, AdamW for 1D
  358. optimizer = Muon(model.parameters(), lr=0.02)
  359. # Manual control over parameter groups
  360. optimizer = Muon([
  361. {'params': weight_matrices, 'lr': 0.02},
  362. {'params': biases, 'use_fallback': True, 'lr': 3e-4}, # use AdamW if use_fallback=True
  363. ])
  364. ```
  365. """
  366. if not 0.0 <= lr:
  367. raise ValueError(f"Invalid learning rate: {lr}")
  368. if not 0.0 <= weight_decay:
  369. raise ValueError(f"Invalid weight_decay value: {weight_decay}")
  370. if not 0.0 <= momentum < 1.0:
  371. raise ValueError(f"Invalid momentum value: {momentum}")
  372. if not 0.0 <= eps:
  373. raise ValueError(f"Invalid epsilon value: {eps}")
  374. if conv_mode not in ["flatten", "batched"]:
  375. raise ValueError(f"Invalid conv_mode: {conv_mode}")
  376. defaults = dict(
  377. lr=lr,
  378. weight_decay=weight_decay,
  379. momentum=momentum,
  380. nesterov=nesterov,
  381. ns_steps=ns_steps,
  382. ns_coefficients=ns_coefficients,
  383. eps=eps,
  384. safety_factor=safety_factor,
  385. adjust_lr_fn=adjust_lr_fn,
  386. conv_mode=conv_mode,
  387. normalize_spatial=normalize_spatial,
  388. adamw_lr=adamw_lr if adamw_lr is not None else lr,
  389. betas=betas,
  390. verbose=verbose,
  391. )
  392. super().__init__(params, defaults)
  393. @torch.no_grad()
  394. def step(self, closure=None):
  395. """Performs a single optimization step."""
  396. loss = None
  397. if closure is not None:
  398. with torch.enable_grad():
  399. loss = closure()
  400. verbose = self.defaults.get("verbose", False)
  401. # Tracking for logging (populated on first encounter of each param)
  402. muon_count = 0
  403. adamw_count = 0
  404. routing_reasons = {} if verbose else None
  405. for group in self.param_groups:
  406. # Separate params into Muon and AdamW groups
  407. muon_params = []
  408. muon_grads = []
  409. muon_momentum_bufs = []
  410. adamw_params = []
  411. adamw_grads = []
  412. adamw_exp_avgs = []
  413. adamw_exp_avg_sqs = []
  414. adamw_state_steps = []
  415. for p in group["params"]:
  416. if p.grad is None:
  417. continue
  418. if p.grad.is_sparse:
  419. raise RuntimeError("Muon does not support sparse gradients")
  420. state = self.state[p]
  421. # Determine routing on first encounter (cache in state)
  422. if "use_muon" not in state:
  423. # Check explicit flags first (support both 'use_fallback' and 'use_muon' for compatibility)
  424. reason = None
  425. if group.get("use_fallback", False):
  426. # use_fallback=True means use AdamW (use_muon=False)
  427. state["use_muon"] = False
  428. if verbose:
  429. reason = "use_fallback_flag"
  430. elif "use_muon" in group:
  431. # Explicit use_muon flag for compatibility with other Muon implementations
  432. state["use_muon"] = group["use_muon"]
  433. if verbose:
  434. reason = "use_muon_flag"
  435. else:
  436. # Check shape suitability
  437. if verbose:
  438. suitable, reason = _is_suitable_for_muon(p, return_reason=True)
  439. else:
  440. suitable = _is_suitable_for_muon(p, return_reason=False)
  441. state["use_muon"] = suitable
  442. # Track routing decision for logging
  443. if routing_reasons is not None and reason is not None:
  444. shape_str = "x".join(str(s) for s in p.shape)
  445. if shape_str not in routing_reasons:
  446. routing_reasons[shape_str] = []
  447. routing_reasons[shape_str].append(reason)
  448. # Use cached routing decision
  449. use_muon = state["use_muon"]
  450. if use_muon:
  451. # Collect Muon params
  452. muon_params.append(p)
  453. muon_grads.append(p.grad)
  454. muon_count += 1
  455. # State initialization for Muon
  456. if "momentum_buffer" not in state:
  457. state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
  458. muon_momentum_bufs.append(state["momentum_buffer"])
  459. else:
  460. # Collect AdamW/NAdamW params
  461. adamw_params.append(p)
  462. adamw_grads.append(p.grad)
  463. adamw_count += 1
  464. # State initialization for AdamW
  465. if "step" not in state:
  466. state["step"] = torch.tensor(0.)
  467. state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
  468. state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
  469. adamw_exp_avgs.append(state["exp_avg"])
  470. adamw_exp_avg_sqs.append(state["exp_avg_sq"])
  471. adamw_state_steps.append(state["step"])
  472. # Apply Muon updates
  473. if muon_params:
  474. muon(
  475. muon_params,
  476. muon_grads,
  477. muon_momentum_bufs,
  478. lr=group["lr"],
  479. weight_decay=group["weight_decay"],
  480. momentum=group["momentum"],
  481. nesterov=group["nesterov"],
  482. ns_steps=group["ns_steps"],
  483. ns_coefficients=group["ns_coefficients"],
  484. eps=group["eps"],
  485. safety_factor=group["safety_factor"],
  486. adjust_lr_fn=group["adjust_lr_fn"],
  487. conv_mode=group["conv_mode"],
  488. normalize_spatial=group["normalize_spatial"],
  489. )
  490. # Apply AdamW updates
  491. if adamw_params:
  492. beta1, beta2 = group["betas"]
  493. if group["nesterov"]:
  494. # use nadamw for fallback optimizer if nesterov is enabled
  495. nadamw(
  496. adamw_params,
  497. adamw_grads,
  498. adamw_exp_avgs,
  499. adamw_exp_avg_sqs,
  500. adamw_state_steps,
  501. foreach=None,
  502. beta1=beta1,
  503. beta2=beta2,
  504. lr=group["adamw_lr"],
  505. weight_decay=group["weight_decay"],
  506. eps=group["eps"],
  507. caution=False,
  508. maximize=False,
  509. capturable=False,
  510. max_lr=None,
  511. )
  512. else:
  513. adamw(
  514. adamw_params,
  515. adamw_grads,
  516. adamw_exp_avgs,
  517. adamw_exp_avg_sqs,
  518. [], # max_exp_avg_sqs (not using amsgrad)
  519. adamw_state_steps,
  520. foreach=None,
  521. amsgrad=False,
  522. beta1=beta1,
  523. beta2=beta2,
  524. lr=group["adamw_lr"],
  525. weight_decay=group["weight_decay"],
  526. eps=group["eps"],
  527. caution=False,
  528. maximize=False,
  529. capturable=False,
  530. max_lr=None,
  531. )
  532. # Log routing summary when we have new routing decisions
  533. if routing_reasons and len(routing_reasons) > 0:
  534. # Concise summary
  535. _logger.info(f"Muon parameter routing: {muon_count} Muon, {adamw_count} AdamW")
  536. # Group by reason for detailed breakdown
  537. reason_groups = {}
  538. for shape_str, reasons in sorted(routing_reasons.items()):
  539. for reason in reasons:
  540. if reason not in reason_groups:
  541. reason_groups[reason] = []
  542. reason_groups[reason].append(shape_str)
  543. # Log summary counts per reason
  544. reason_summary = []
  545. for reason, shapes in sorted(reason_groups.items()):
  546. reason_summary.append(f"{reason}={len(shapes)}")
  547. _logger.info(f" Breakdown: {', '.join(reason_summary)}")
  548. # Detailed breakdown at INFO level
  549. if _logger.isEnabledFor(logging.INFO):
  550. for reason, shapes in sorted(reason_groups.items()):
  551. optimizer_name = "Muon" if reason == "ok" else "AdamW"
  552. _logger.info(f" {reason} -> {optimizer_name}:")
  553. for shape in shapes[:10]:
  554. _logger.info(f" {shape}")
  555. if len(shapes) > 10:
  556. _logger.info(f" ... and {len(shapes) - 10} more")
  557. return loss
  558. def resolve_ns_coefficients(
  559. value: Union[str, Sequence[float], Sequence[Sequence[float]]],
  560. presets: Mapping[str, Sequence[Sequence[float]]]
  561. ) -> List[Tuple[float, float, float]]:
  562. # tiny helpers (kept inline for succinctness)
  563. is_seq = lambda x: isinstance(x, Sequence) and not isinstance(x, (str, bytes))
  564. is_real = lambda x: isinstance(x, numbers.Real) and not isinstance(x, bool)
  565. def as_coeff(x: Sequence[float]) -> Tuple[float, float, float]:
  566. if not is_seq(x) or len(x) != 3 or not all(is_real(v) for v in x):
  567. raise ValueError(f"Coefficient must be length-3 of real numbers, got: {x!r}")
  568. a, b, c = x # type: ignore[misc]
  569. return float(a), float(b), float(c)
  570. if isinstance(value, str):
  571. if value not in presets:
  572. valid = ", ".join(sorted(presets.keys()))
  573. raise ValueError(f"Unknown coefficients preset '{value}'. Valid options: {valid}")
  574. seq = presets[value]
  575. if not is_seq(seq) or len(seq) == 0:
  576. raise ValueError(f"Preset '{value}' is empty or invalid")
  577. return [as_coeff(item) for item in seq] # validate & cast
  578. if not is_seq(value):
  579. raise TypeError(
  580. "Coefficients must be a preset name (str), a 3-sequence (a,b,c), "
  581. "or a sequence of 3-sequences."
  582. )
  583. # Decide single triple vs list-of-triples by structure
  584. if len(value) == 3 and all(is_real(v) for v in value): # type: ignore[index]
  585. return [as_coeff(value)] # single triple -> wrap
  586. # Otherwise treat as list/tuple of triples
  587. out = []
  588. for i, item in enumerate(value): # type: ignore[assignment]
  589. if not is_seq(item):
  590. raise TypeError(f"Item {i} is not a sequence: {item!r}")
  591. out.append(as_coeff(item))
  592. if not out:
  593. raise ValueError("Coefficient list cannot be empty")
  594. return out