| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687 |
- """ Muon Optimizer
- Improved Muon optimizer implementation with flexible handling of high-dimensional tensors.
- Combines PyTorch-style structure with options for:
- - Batched spatial processing for convolutions in addition to flatten
- - Optional spatial normalization
- - Selectable coefficient presets
- - Automatic fallback to AdamW for 1D / scalar parameters (biases, norms, etc.) and optional fallback via param groups
- Based on implementation by Keller Jordan, see
- - https://github.com/KellerJordan/Muon/blob/master/muon.py
- - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
- - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py
- - https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py
- Hacked together by Ross Wightman
- """
- import logging
- import numbers
- from typing import List, Mapping, Optional, Sequence, Tuple, Union
- import torch
- from ._types import ParamsT
- from .adamw import adamw
- from .nadamw import nadamw
# Module logger; Muon.step uses it for the optional (verbose) parameter-routing summary.
_logger = logging.getLogger(__name__)

# Constants from Keller Jordan's Muon
# Default lower bound on the norm used to normalize the input of zeropower_via_newtonschulz
# (guards against dividing by zero for an all-zero update). It is also the default `eps` of the
# Muon optimizer, which passes the same value to its AdamW/NAdamW fallback.
MUON_EPS = 1e-7
# Default number of Newton-Schulz iterations per optimizer step.
DEFAULT_NS_STEPS = 5

# Named Newton-Schulz coefficient schedules, resolved via resolve_ns_coefficients().
# Each preset is a list of (a, b, c) triples for one iteration of
#     X <- a * X + b * (X X^T) X + c * (X X^T)^2 X
# zeropower_via_newtonschulz consumes the triples in order: a list longer than the requested number
# of steps is truncated (e.g. with DEFAULT_NS_STEPS=5 only the first 5 'polar_express' triples are used),
# and a shorter list has its last triple repeated (so 'original' applies one triple at every step).
_COEFFICIENTS = {
    "original": [
        # Keller Jordan's Muon https://kellerjordan.github.io/posts/muon/
        (3.4445, -4.7750, 2.0315),
    ],
    "quintic": [
        # https://leloykun.github.io/ponder/muon-opt-coeffs/#how-do-we-optimize-the-coefficients
        # From https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt_medium.py#L44
        (4.0848, -6.8946, 2.9270),
        (3.9505, -6.3029, 2.6377),
        (3.7418, -5.5913, 2.3037),
        (2.8769, -3.1427, 1.2046),
        (2.8366, -3.0525, 1.2012),
    ],
    "polar_express": [
        # Polar Express https://arxiv.org/abs/2505.16932
        # From https://github.com/NoahAmsel/PolarExpress/tree/main with safety 1e-2
        (8.237312490495555, -23.157747414558198, 16.680568411445915),
        (4.082441999064835, -2.893047735332586, 0.5252849256975648),
        (3.9263479922546582, -2.8547468034765298, 0.5318022422894988),
        (3.2982187133085143, -2.424541981026706, 0.48632008358844075),
        (2.2970369434552573, -1.63662558125903, 0.4002628455953627),
        (1.8763805351440397, -1.2347896577722228, 0.35891887501668385),
        (1.8564423485617974, -1.2132449880935525, 0.3568003487825883),
        (1.8749994008682747, -1.2499988017229169, 0.3749994008546422),
    ],
    "polar_express_safer": [
        # from https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py
        # w/ safety 2e-2
        (8.156554524902461, -22.48329292557795, 15.878769915207462),
        (4.0429299351667245, -2.808917465908704, 0.5000178451051299),
        (3.8916678022926563, -2.7724841532176825, 0.5060648178503389),
        (3.285753657755658, -2.3681294933425394, 0.46449024233003117),
        (2.3005307116270983, -1.6111665557258408, 0.3833374427545273),
        (1.8631210546382593, -1.2042160621002727, 0.3421879560523383),
        (1.8382572152247512, -1.1779263289537742, 0.3396513038637379),
        (1.8749999923301852, -1.2499999836060613, 0.374999991275876),
    ],
}

# Accepted forms for `ns_coefficients`: a preset name (key of _COEFFICIENTS), a single (a, b, c)
# triple, or a list of triples. Normalized to a list of float triples by resolve_ns_coefficients().
NSCoeff = Union[str, Tuple[float, float, float], List[Tuple[float, float, float]]]
def zeropower_via_newtonschulz(
        G: torch.Tensor,
        steps: int,
        coefficients: List[Tuple[float, float, float]],
        eps: float = MUON_EPS,
        safety_factor: float = 1.0,
        dtype: torch.dtype = torch.bfloat16,
) -> torch.Tensor:
    """Newton-Schulz quintic iteration to compute the zeroth power / orthogonalization of gradient.

    Supports batched operation over leading dimensions.

    Each iteration applies X <- a * X + b * (X X^T) X + c * (X X^T)^2 X using one (a, b, c) triple.

    See
    - https://github.com/KellerJordan/Muon/blob/master/muon.py
    - https://github.com/NoahAmsel/PolarExpress/blob/main/polar_express.py
    - https://github.com/KellerJordan/modded-nanogpt/blob/master/train_gpt.py

    Args:
        G: Input gradient tensor of shape (m, n) or (batch, m, n). G is not modified; the iteration
            runs on a copy.
        steps: Number of Newton-Schulz iterations
        coefficients: List of (a, b, c) triples, consumed one per iteration. If it has more than `steps`
            entries it is truncated; if fewer, the last triple is repeated. It is concatenated with a
            Python list in that case, so pass a list (as returned by resolve_ns_coefficients).
        eps: Lower bound on the normalizing norm (avoids division by zero for an all-zero G)
        safety_factor: Multiplicative safety factor for norm (1.01 is common safety value in 'polar express' variants)
        dtype: Computation dtype. The result is returned in this dtype, not in G's dtype.

    Returns:
        Orthogonalized tensor of same shape as G (in `dtype`)
    """
    assert G.ndim in (2, 3), f"Input must be 2D or 3D, got {G.ndim}D. Flatten batch dims first."
    num_cs = len(coefficients)
    assert num_cs >= 1 and len(coefficients[0]) == 3
    # match coefficients with # of steps, truncate or repeat last
    coeff_sequence = coefficients[:steps] if steps <= num_cs else \
        coefficients + [coefficients[-1]] * (steps - num_cs)

    # Private working copy in the compute dtype; the in-place ops below never touch G
    X = G.to(dtype=dtype, copy=True)

    # Transpose if needed (operate on dimension with fewer elements)
    # so rows <= cols and the Gram matrix A = X @ X.mT is the smaller (rows x rows) product
    transposed = X.size(-2) > X.size(-1)
    if transposed:
        X = X.mT

    # Normalize spectral norm to at most 1
    # norm(2) over the last two dims is the Frobenius norm, an upper bound on the spectral norm
    X.div_(X.norm(2, dim=(-2, -1), keepdim=True).mul(safety_factor).clamp_min(eps))

    # Batched vs unbatched fused MM
    mm_fn = torch.baddbmm if X.ndim > 2 else torch.addmm

    # Pre-allocate
    # A, B: (..., rows, rows) Gram-sized buffers; C: same shape as X, target of each update
    X = X.contiguous()
    A = torch.empty((*X.shape[:-1], X.size(-2)), device=X.device, dtype=X.dtype)
    B = torch.empty_like(A)
    C = torch.empty_like(X)

    # Perform Newton-Schulz iterations
    for a, b, c in coeff_sequence:
        # beta=0.0 makes (b)addmm ignore its `input`, so passing the uninitialized A as both input and out is safe
        mm_fn(A, X, X.mT, beta=0.0, alpha=1.0, out=A)  # A = X @ X.mT
        mm_fn(A, A, A, beta=b, alpha=c, out=B)  # B = b * A + c * A @ A
        mm_fn(X, B, X, beta=a, alpha=1.0, out=C)  # C = a * X + B @ X
        X, C = C, X  # swap refs to avoid copy

    if transposed:
        X = X.mT

    return X
def get_lr_scale(
        param_shape: torch.Size,
        adjust_lr_fn: Optional[str] = "match_rms_adamw"
) -> float:
    """Return a multiplicative learning-rate scale for a Muon update of the given shape.

    The last two dims of ``param_shape`` are treated as (out_chs, in_chs); shapes with fewer
    than two dims use (1, 1).

    Args:
        param_shape: Shape of the (orthogonalized) update; any sequence of ints works.
        adjust_lr_fn: Scaling rule:
            - "original": sqrt(max(1, out / in))  (Keller Jordan's Muon)
            - "match_rms_adamw": 0.2 * sqrt(max(out, in))  (Kimi / Moonlight)
            - "rms_to_rms": sqrt(out / in)  (Scion, Bernstein et al.)
            - None: no shape-based adjustment, the scale is 1.0

    Returns:
        The LR scale as a float.

    Raises:
        ValueError: If ``adjust_lr_fn`` is not one of the options above.
    """
    if adjust_lr_fn is None:
        # The optimizer signatures type this argument as Optional[str]; None disables the adjustment
        # instead of failing at the first step.
        return 1.0

    out_chs, in_chs = (param_shape[-2], param_shape[-1]) if len(param_shape) > 1 else (1., 1.)

    if adjust_lr_fn == "original":
        # Original Muon impl (https://kellerjordan.github.io/posts/muon/)
        return max(1, out_chs / in_chs) ** 0.5
    elif adjust_lr_fn == "match_rms_adamw":
        # Kimi (https://arxiv.org/abs/2502.16982)
        return 0.2 * max(out_chs, in_chs) ** 0.5
    elif adjust_lr_fn == "rms_to_rms":
        # Scion (https://arxiv.org/abs/2502.07529, https://github.com/LIONS-EPFL/scion)
        # Bernstein et al. (https://jeremybernste.in/writing/deriving-muon)
        return (out_chs / in_chs) ** 0.5
    # Explicit raise: an `assert False` here is stripped under `python -O` and would return None.
    raise ValueError(f'Invalid scaling function "{adjust_lr_fn}"')
- def _is_suitable_for_muon(
- param: torch.Tensor,
- min_dim_size: int = 4,
- max_aspect_ratio: float = 128.,
- return_reason: bool = False,
- ) -> Union[bool, Tuple[bool, str]]:
- """Check if a parameter is suitable for Muon optimization.
- Args:
- param: Parameter tensor
- min_dim_size: Minimum size for non-unit dimensions
- max_aspect_ratio: Maximum allowed aspect ratio
- return_reason: If True, return (bool, reason_string), else just bool (faster)
- Returns:
- If return_reason=False: bool indicating suitability
- If return_reason=True: Tuple of (is_suitable, reason_string)
- Examples:
- (64, 128) -> True (or (True, "ok") if return_reason=True)
- (96, 3, 4, 4) -> True - will be flattened to (96, 48)
- (4, 2048) -> False - extreme aspect ratio
- (64,) -> False - insufficient dims
- (1, 196, 768) -> False - leading unit dims
- NOTE: these rules were created to balance complexity with covering common timm model cases
- Please let me know if there are non-optimal cases that you run into.
- """
- s = param.shape
- # Must have at least 2 non-unit dimensions
- if param.ndim < 2 or sum(1 for dim_size in s if dim_size > 1) < 2:
- return (False, "insufficient_dims") if return_reason else False
- # Unit dimension in first two positions indicates:
- # - Position embeddings (1, seq, dim)
- # - Depthwise convs (out, 1, h, w)
- # - Other degenerate cases possibly not caught by first rule
- if s[0] == 1 or s[1] == 1:
- return (False, "leading_unit_dims") if return_reason else False
- if param.ndim >= 3:
- # For 3D+ tensors, check what dimensions will be AFTER flattening
- # since that's what gets passed to Newton-Schulz iteration
- # Flatten mode: (out, in, *spatial) -> (out, in * spatial_prod)
- out_ch = s[0]
- in_ch_with_spatial = 1
- for d in s[1:]:
- in_ch_with_spatial *= d
- check_dims = (out_ch, in_ch_with_spatial)
- else:
- # For 2D tensors, check as-is
- check_dims = s
- # Both dims should be >= minimum size
- min_size = min(check_dims)
- if min_size < min_dim_size:
- if return_reason:
- return False, f"min_dim_too_small:{min_size}"
- return False
- # Aspect ratio shouldn't be too extreme
- max_size = max(check_dims)
- aspect_ratio = max_size / min_size
- if aspect_ratio > max_aspect_ratio:
- if return_reason:
- return False, f"extreme_aspect_ratio:{aspect_ratio:.1f}"
- return False
- return (True, "ok") if return_reason else True
def reshape_for_muon(
        tensor: torch.Tensor,
        mode: str = "flatten",
) -> Tuple[torch.Tensor, torch.Size]:
    """Reshape a 2D+ tensor into the 2D or 3D layout consumed by Newton-Schulz.

    Args:
        tensor: Tensor of shape (out, in, *spatial). 2D tensors are returned unchanged
            (``mode`` is not checked for them).
        mode: How trailing (spatial) dimensions are handled:
            - "flatten": fold them into the input dimension -> (out, in * prod(spatial))
            - "batched": move them to a leading batch dim -> (prod(spatial), out, in),
              so each spatial position is orthogonalized independently

    Returns:
        (reshaped tensor, original shape); the shape lets the caller restore the layout.

    Raises:
        ValueError: If ``tensor`` has fewer than 2 dims, or ``mode`` is unknown.
    """
    shape = tensor.shape
    ndim = tensor.ndim
    if ndim < 2:
        raise ValueError(f"Tensor must have at least 2 dimensions, got {ndim}")
    if ndim == 2:
        return tensor, shape

    n_out, n_in = shape[0], shape[1]
    if mode == "flatten":
        return tensor.reshape(n_out, -1), shape
    if mode == "batched":
        # (out, in, *spatial) -> (out, in, S) -> (S, out, in): NS batches over the leading dim
        return tensor.reshape(n_out, n_in, -1).permute(2, 0, 1), shape
    raise ValueError(f"Unknown mode: {mode}")
def muon(
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        momentum_bufs: List[torch.Tensor],
        *,
        lr: float,
        weight_decay: float,
        momentum: float,
        nesterov: bool,
        ns_steps: int,
        ns_coefficients: NSCoeff,
        eps: float,
        safety_factor: float,
        adjust_lr_fn: Optional[str],
        conv_mode: str,
        normalize_spatial: bool,
) -> None:
    """Functional API that performs Muon algorithm computation.

    Used by ``Muon.step``. It always dispatches to the per-tensor loop in ``_single_tensor_muon``,
    forwarding every hyperparameter unchanged; ``params`` and ``momentum_bufs`` are updated in place.
    """
    hyperparams = dict(
        lr=lr,
        weight_decay=weight_decay,
        momentum=momentum,
        nesterov=nesterov,
        ns_steps=ns_steps,
        ns_coefficients=ns_coefficients,
        eps=eps,
        safety_factor=safety_factor,
        adjust_lr_fn=adjust_lr_fn,
        conv_mode=conv_mode,
        normalize_spatial=normalize_spatial,
    )
    _single_tensor_muon(params, grads, momentum_bufs, **hyperparams)
def _single_tensor_muon(
        params: List[torch.Tensor],
        grads: List[torch.Tensor],
        momentum_bufs: List[torch.Tensor],
        *,
        lr: float,
        weight_decay: float,
        momentum: float,
        nesterov: bool,
        ns_steps: int,
        ns_coefficients: NSCoeff,
        eps: float,
        safety_factor: float,
        adjust_lr_fn: Optional[str],
        conv_mode: str,
        normalize_spatial: bool,
) -> None:
    """Single tensor Muon update.

    For each (param, grad, momentum_buf):
        1. decoupled weight decay:  param *= 1 - lr * weight_decay
        2. momentum:                buf = momentum * buf + (1 - momentum) * grad
        3. update:                  buf, or the Nesterov look-ahead (1 - momentum) * grad + momentum * buf
        4. 3D+ updates are reshaped per ``conv_mode`` and orthogonalized with Newton-Schulz
        5. param -= lr * scale * update, with scale from ``get_lr_scale`` (times 1/sqrt(#spatial
           positions) in batched mode when ``normalize_spatial`` is set)

    ``params`` and ``momentum_bufs`` are modified in place; ``grads`` are left untouched.
    The three lists are expected to be aligned element-wise.
    """
    ns_coefficients = resolve_ns_coefficients(ns_coefficients, _COEFFICIENTS)

    for param, grad, momentum_buf in zip(params, grads, momentum_bufs):
        # Apply weight decay
        param.mul_(1 - lr * weight_decay)

        # Update momentum buffer
        momentum_buf.lerp_(grad, 1. - momentum)
        if nesterov:
            # Out-of-place: an in-place grad.lerp_ would overwrite the caller's p.grad
            update = grad.lerp(momentum_buf, momentum)
        else:
            # No clone needed: zeropower_via_newtonschulz orthogonalizes a private copy of its input
            update = momentum_buf

        # Reshape for processing (handle 3D+ tensors like conv weights)
        if update.ndim >= 3:
            update_reshaped, original_shape = reshape_for_muon(update, mode=conv_mode)
        else:
            update_reshaped, original_shape = update, update.shape

        # Apply Newton-Schulz orthogonalization
        # NOTE: runs in zeropower_via_newtonschulz's default compute dtype; not exposed as an option yet
        update_ortho = zeropower_via_newtonschulz(
            update_reshaped,
            ns_steps,
            ns_coefficients,
            eps=eps,
            safety_factor=safety_factor,
        )

        # Adjust learning rate based on the shape NS actually operated on
        scale = get_lr_scale(update_ortho.shape, adjust_lr_fn)

        # Apply spatial normalization and permute back if in batched mode
        if conv_mode == "batched" and update_ortho.ndim >= 3:
            if normalize_spatial:
                scale *= update_ortho.shape[0] ** -0.5
            # Permute back: (spatial_prod, out, in) -> (out, in, spatial_prod)
            update_ortho = update_ortho.permute(1, 2, 0)

        # Reshape back to original shape and apply update
        update_ortho = update_ortho.reshape(original_shape)
        param.add_(update_ortho, alpha=-lr * scale)
class Muon(torch.optim.Optimizer):
    """Muon - MomentUm Orthogonalized by Newton-schulz

    Combines Muon for 2D+ parameters (weight matrices) with AdamW for 1D parameters (biases, norms) and
    parameter groups with 'use_fallback=True' set (or 'use_muon=False' for compatibility).

    The Muon-vs-AdamW decision for a parameter is made the first time it has a gradient and is cached in
    its optimizer state under "use_muon". When ``nesterov=True`` the fallback uses NAdamW instead of AdamW.
    """

    def __init__(
            self,
            params: ParamsT,
            lr: float = 0.02,
            weight_decay: float = 0,
            momentum: float = 0.95,
            nesterov: bool = False,
            ns_steps: int = DEFAULT_NS_STEPS,
            ns_coefficients: NSCoeff = "quintic",
            eps: float = MUON_EPS,
            safety_factor: float = 1.0,
            adjust_lr_fn: Optional[str] = "match_rms_adamw",
            conv_mode: str = "flatten",
            normalize_spatial: bool = True,
            adamw_lr: Optional[float] = None,
            betas: Tuple[float, float] = (0.9, 0.95),
            verbose: bool = False,
    ):
        """ Create Muon optimizer.

        Args:
            params: Iterable of parameters or dicts defining parameter groups
            lr: Learning rate (default: 0.02 for Muon parameters)
            weight_decay: Weight decay coefficient (shared by Muon and the AdamW fallback)
            momentum: Momentum factor for Muon
            nesterov: Whether to use Nesterov momentum (also switches the fallback from AdamW to NAdamW)
            ns_steps: Number of Newton-Schulz iterations
            ns_coefficients: Coefficients for NS iteration - preset name, (a, b, c) triple or list of triples
            eps: Numerical stability epsilon (NS norm floor and AdamW epsilon)
            safety_factor: Multiplicative safety factor for NS norm
            adjust_lr_fn: Shape-based LR scaling rule applied to Muon updates; see `get_lr_scale`
                ("original", "match_rms_adamw" or "rms_to_rms")
            conv_mode: How to handle convolutions - "flatten" or "batched"
            normalize_spatial: Whether to normalize by sqrt(spatial_size) in batched mode
            adamw_lr: Learning rate for the AdamW fallback params. If None (default), each param group's
                own 'lr' is used, so per-group 'lr' overrides and LR schedulers also apply to the fallback.
            betas: AdamW beta coefficients
            verbose: Log parameter routing decisions (Muon vs AdamW)

        Raises:
            ValueError: On an invalid hyperparameter value or coefficient preset.
            TypeError: If ``ns_coefficients`` has an invalid structure.

        Example:
            ```python
            # Simple usage - automatically uses Muon for 2D+ params, AdamW for 1D
            optimizer = Muon(model.parameters(), lr=0.02)

            # Manual control over parameter groups
            optimizer = Muon([
                {'params': weight_matrices, 'lr': 0.02},
                {'params': biases, 'use_fallback': True, 'lr': 3e-4},  # use AdamW if use_fallback=True
            ])
            ```
        """
        if not 0.0 <= lr:
            raise ValueError(f"Invalid learning rate: {lr}")
        if adamw_lr is not None and not 0.0 <= adamw_lr:
            raise ValueError(f"Invalid adamw_lr: {adamw_lr}")
        if not 0.0 <= weight_decay:
            raise ValueError(f"Invalid weight_decay value: {weight_decay}")
        if not 0.0 <= momentum < 1.0:
            raise ValueError(f"Invalid momentum value: {momentum}")
        if not 0.0 <= eps:
            raise ValueError(f"Invalid epsilon value: {eps}")
        if ns_steps < 0:
            raise ValueError(f"Invalid ns_steps value: {ns_steps}")
        if not 0.0 <= betas[0] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 0: {betas[0]}")
        if not 0.0 <= betas[1] < 1.0:
            raise ValueError(f"Invalid beta parameter at index 1: {betas[1]}")
        if conv_mode not in ["flatten", "batched"]:
            raise ValueError(f"Invalid conv_mode: {conv_mode}")
        if adjust_lr_fn is not None and adjust_lr_fn not in ("original", "match_rms_adamw", "rms_to_rms"):
            raise ValueError(f"Invalid adjust_lr_fn: {adjust_lr_fn}")
        # Fail fast on an unknown preset / malformed coefficients instead of at the first step()
        resolve_ns_coefficients(ns_coefficients, _COEFFICIENTS)

        defaults = dict(
            lr=lr,
            weight_decay=weight_decay,
            momentum=momentum,
            nesterov=nesterov,
            ns_steps=ns_steps,
            ns_coefficients=ns_coefficients,
            eps=eps,
            safety_factor=safety_factor,
            adjust_lr_fn=adjust_lr_fn,
            conv_mode=conv_mode,
            normalize_spatial=normalize_spatial,
            # None -> resolved per group at step time (falls back to the group's 'lr')
            adamw_lr=adamw_lr,
            betas=betas,
            verbose=verbose,
        )
        super().__init__(params, defaults)

    @staticmethod
    def _route_param(p, group, verbose):
        """Return (use_muon, reason) for a parameter seen for the first time; reason is None unless verbose."""
        # Explicit group flags win ('use_fallback', or 'use_muon' for compatibility with other Muon impls)
        if group.get("use_fallback", False):
            return False, ("use_fallback_flag" if verbose else None)
        if "use_muon" in group:
            return group["use_muon"], ("use_muon_flag" if verbose else None)
        # Otherwise decide from the parameter's shape
        if verbose:
            return _is_suitable_for_muon(p, return_reason=True)
        return _is_suitable_for_muon(p, return_reason=False), None

    @staticmethod
    def _fallback_step(group, params, grads, exp_avgs, exp_avg_sqs, state_steps) -> None:
        """Apply the AdamW (or NAdamW if group['nesterov']) update to a group's non-Muon params."""
        beta1, beta2 = group["betas"]
        # An explicit adamw_lr wins; otherwise follow the group's own lr, which honors per-group
        # 'lr' overrides and LR schedulers (they update group['lr']).
        lr = group.get("adamw_lr")
        if lr is None:
            lr = group["lr"]

        if group["nesterov"]:
            # use nadamw for fallback optimizer if nesterov is enabled
            nadamw(
                params,
                grads,
                exp_avgs,
                exp_avg_sqs,
                state_steps,
                foreach=None,
                beta1=beta1,
                beta2=beta2,
                lr=lr,
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                caution=False,
                maximize=False,
                capturable=False,
                max_lr=None,
            )
        else:
            adamw(
                params,
                grads,
                exp_avgs,
                exp_avg_sqs,
                [],  # max_exp_avg_sqs (not using amsgrad)
                state_steps,
                foreach=None,
                amsgrad=False,
                beta1=beta1,
                beta2=beta2,
                lr=lr,
                weight_decay=group["weight_decay"],
                eps=group["eps"],
                caution=False,
                maximize=False,
                capturable=False,
                max_lr=None,
            )

    @staticmethod
    def _log_routing(new_routes, muon_count, adamw_count) -> None:
        """Log first-time routing decisions grouped by (reason, target optimizer)."""
        _logger.info("Muon parameter routing: %d Muon, %d AdamW", muon_count, adamw_count)

        # Summary counts per reason
        reason_counts = {}
        for (reason, _), shapes in new_routes.items():
            reason_counts[reason] = reason_counts.get(reason, 0) + len(shapes)
        summary = ", ".join(f"{reason}={count}" for reason, count in sorted(reason_counts.items()))
        _logger.info(" Breakdown: %s", summary)

        # Detailed breakdown, at most 10 shapes per (reason, optimizer)
        for (reason, optimizer_name), shapes in sorted(new_routes.items()):
            _logger.info(" %s -> %s:", reason, optimizer_name)
            shapes = sorted(shapes)
            for shape in shapes[:10]:
                _logger.info(" %s", shape)
            if len(shapes) > 10:
                _logger.info(" ... and %d more", len(shapes) - 10)

    @torch.no_grad()
    def step(self, closure=None):
        """Performs a single optimization step.

        Args:
            closure: Optional callable that re-evaluates the model and returns the loss;
                it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or None if no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        verbose = self.defaults.get("verbose", False)
        # Params (with grads) sent to each optimizer during this step, reported in the routing log
        muon_count = 0
        adamw_count = 0
        # (reason, optimizer name) -> shape strings of params routed for the first time this step.
        # Only collected when verbose.
        new_routes = {} if verbose else None

        for group in self.param_groups:
            muon_params = []
            muon_grads = []
            muon_momentum_bufs = []
            adamw_params = []
            adamw_grads = []
            adamw_exp_avgs = []
            adamw_exp_avg_sqs = []
            adamw_state_steps = []

            for p in group["params"]:
                if p.grad is None:
                    continue
                if p.grad.is_sparse:
                    raise RuntimeError("Muon does not support sparse gradients")

                state = self.state[p]

                # Determine routing on first encounter (cache in state)
                if "use_muon" not in state:
                    use_muon, reason = self._route_param(p, group, verbose)
                    state["use_muon"] = use_muon
                    if new_routes is not None and reason is not None:
                        # Label by the actual decision (a use_muon=True group goes to Muon, not AdamW)
                        key = (reason, "Muon" if use_muon else "AdamW")
                        new_routes.setdefault(key, []).append("x".join(str(s) for s in p.shape))

                if state["use_muon"]:
                    muon_params.append(p)
                    muon_grads.append(p.grad)
                    muon_count += 1
                    # State initialization for Muon
                    if "momentum_buffer" not in state:
                        state["momentum_buffer"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    muon_momentum_bufs.append(state["momentum_buffer"])
                else:
                    adamw_params.append(p)
                    adamw_grads.append(p.grad)
                    adamw_count += 1
                    # State initialization for AdamW
                    if "step" not in state:
                        state["step"] = torch.tensor(0.)
                        state["exp_avg"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                        state["exp_avg_sq"] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    adamw_exp_avgs.append(state["exp_avg"])
                    adamw_exp_avg_sqs.append(state["exp_avg_sq"])
                    adamw_state_steps.append(state["step"])

            # Apply Muon updates
            if muon_params:
                muon(
                    muon_params,
                    muon_grads,
                    muon_momentum_bufs,
                    lr=group["lr"],
                    weight_decay=group["weight_decay"],
                    momentum=group["momentum"],
                    nesterov=group["nesterov"],
                    ns_steps=group["ns_steps"],
                    ns_coefficients=group["ns_coefficients"],
                    eps=group["eps"],
                    safety_factor=group["safety_factor"],
                    adjust_lr_fn=group["adjust_lr_fn"],
                    conv_mode=group["conv_mode"],
                    normalize_spatial=group["normalize_spatial"],
                )

            # Apply AdamW / NAdamW updates
            if adamw_params:
                self._fallback_step(
                    group,
                    adamw_params,
                    adamw_grads,
                    adamw_exp_avgs,
                    adamw_exp_avg_sqs,
                    adamw_state_steps,
                )

        # Log routing summary when we have new routing decisions
        if new_routes:
            self._log_routing(new_routes, muon_count, adamw_count)

        return loss
def resolve_ns_coefficients(
        value: Union[str, Sequence[float], Sequence[Sequence[float]]],
        presets: Mapping[str, Sequence[Sequence[float]]]
) -> List[Tuple[float, float, float]]:
    """Normalize an ``ns_coefficients`` spec into a list of float (a, b, c) triples.

    Accepted forms:
        - a preset name (a key of ``presets``)
        - a single 3-sequence of real numbers (wrapped into a one-element list)
        - a non-empty sequence of 3-sequences of real numbers

    Raises:
        ValueError: Unknown preset name, empty/invalid preset, empty list, or a triple that is
            not exactly three real (non-bool) numbers.
        TypeError: ``value`` (or one of its items) is not a sequence.
    """
    def _is_sequence(obj) -> bool:
        # str/bytes are Sequences too, but never valid coefficient containers
        return isinstance(obj, Sequence) and not isinstance(obj, (str, bytes))

    def _is_real_number(obj) -> bool:
        # bool subclasses numbers.Real; reject it explicitly
        return isinstance(obj, numbers.Real) and not isinstance(obj, bool)

    def _to_triple(item) -> Tuple[float, float, float]:
        well_formed = _is_sequence(item) and len(item) == 3 and all(_is_real_number(v) for v in item)
        if not well_formed:
            raise ValueError(f"Coefficient must be length-3 of real numbers, got: {item!r}")
        first, second, third = item  # type: ignore[misc]
        return float(first), float(second), float(third)

    # Preset lookup by name
    if isinstance(value, str):
        if value not in presets:
            options = ", ".join(sorted(presets.keys()))
            raise ValueError(f"Unknown coefficients preset '{value}'. Valid options: {options}")
        preset = presets[value]
        if not _is_sequence(preset) or len(preset) == 0:
            raise ValueError(f"Preset '{value}' is empty or invalid")
        return [_to_triple(entry) for entry in preset]

    if not _is_sequence(value):
        raise TypeError(
            "Coefficients must be a preset name (str), a 3-sequence (a,b,c), "
            "or a sequence of 3-sequences."
        )

    # A flat run of three numbers is a single triple
    if len(value) == 3 and all(_is_real_number(v) for v in value):  # type: ignore[index]
        return [_to_triple(value)]

    # Otherwise a list/tuple of triples
    if len(value) == 0:
        raise ValueError("Coefficient list cannot be empty")

    def _entry_to_triple(idx, entry) -> Tuple[float, float, float]:
        if not _is_sequence(entry):
            raise TypeError(f"Item {idx} is not a sequence: {entry!r}")
        return _to_triple(entry)

    return [_entry_to_triple(idx, entry) for idx, entry in enumerate(value)]  # type: ignore[arg-type]
|