| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
2771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629 |
- # Copyright 2025 NXAI GmbH. All rights reserved.
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- """PyTorch xLSTM Model."""
- from dataclasses import dataclass
- from typing import Optional, Union
- import torch
- import torch.nn.functional as F
- from torch import nn
- from torch.nn import CrossEntropyLoss
- from ...generation import GenerationMixin
- from ...modeling_utils import PreTrainedModel
- from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_xlstm_available
- from .configuration_xlstm import xLSTMConfig
- if is_xlstm_available():
- from xlstm.xlstm_large.model import RMSNorm as xLSTMRMSNorm
- from xlstm.xlstm_large.model import mLSTMBlock as xLSTMBlock
- from xlstm.xlstm_large.model import mLSTMStateType, soft_cap
- external_xlstm = True
- else:
- from functools import partial
- from typing import Callable, Literal
- from .configuration_xlstm import round_up_to_next_multiple_of
- mLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- mLSTMStateType = dict[int, mLSTMLayerStateType]
- external_xlstm = False
def soft_cap(values: torch.Tensor, cap_value: Optional[Union[float, torch.Tensor]] = None) -> torch.Tensor:
    """
    Smoothly squash a tensor so that its magnitude is bounded by ``cap_value``.

    The input is divided by the cap, passed through ``tanh`` and multiplied by the cap again. This is a common
    trick for attention scores and language-model head logits to keep single large values from dominating the
    softmax, see e.g. Gemma2: https://huggingface.co/papers/2408.00118

    Args:
        values: The tensor to cap.
        cap_value: The bound to apply. When None the input is returned untouched.

    Returns:
        The capped tensor, or `values` itself when no cap is given.
    """
    if cap_value is not None:
        return cap_value * torch.tanh(values / cap_value)
    return values
def mlstm_chunkwise_recurrent_fw_C(
    matK: torch.Tensor,
    matV: torch.Tensor,
    vecB: torch.Tensor,
    vecI: torch.Tensor,
    matC_states: Optional[torch.Tensor] = None,
    vecN_states: Optional[torch.Tensor] = None,
    scaMinter_states: Optional[torch.Tensor] = None,
    matC_initial: Optional[torch.Tensor] = None,
    vecN_initial: Optional[torch.Tensor] = None,
    scaMinter_initial: Optional[torch.Tensor] = None,
    qk_scale: Optional[float] = None,
    chunk_size: int = 64,
    num_chunks: int = 1,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Run the recurrence over chunks and materialize the (C, n, m) state at every chunk boundary.

    Args:
        matK: Keys of shape (batch_size, nh, sequence_length, dhqk).
        matV: Values; the last dimension is dhhv. Sliced along dim 2 like the keys.
        vecB: Indexed as (batch_size, nh, num_chunks, chunk_size); the callers in this file pass the
            per-chunk cumulative sum of the log forget gates.
        vecI: Input gate pre-activations, same layout as `vecB`.
        matC_states: Optional output buffer of shape (batch_size, nh, (num_chunks + 1) * dhqk, dhhv).
        vecN_states: Optional output buffer of shape (batch_size, nh, (num_chunks + 1) * dhqk).
        scaMinter_states: Optional output buffer of shape (batch_size, nh, num_chunks + 1).
        matC_initial: Optional initial C state (batch_size, nh, dhqk, dhhv); zeros when None.
        vecN_initial: Optional initial n state (batch_size, nh, dhqk); zeros when None.
        scaMinter_initial: Optional initial m state (batch_size, nh, 1); zeros when None.
        qk_scale: Resolved to dhqk**-0.5 when None, but not applied to the keys in this function.
        chunk_size: Number of time steps per chunk.
        num_chunks: Number of chunks to iterate over.

    Returns:
        (matC_states, vecN_states, scaMinter_states). Slot ``k`` holds the state *before* chunk ``k``
        is processed; the last slot holds the state after the final chunk.
    """
    batch_size, nh, _, dhqk = matK.shape
    dhhv = matV.shape[-1]
    factory = {"dtype": matK.dtype, "device": matK.device}

    # kept for interface parity; the keys are intentionally left unscaled here
    if qk_scale is None:
        qk_scale = dhqk**-0.5

    # allocate the output buffers that were not handed in by the caller
    n_slots = num_chunks + 1
    if matC_states is None:
        matC_states = torch.zeros((batch_size, nh, n_slots * dhqk, dhhv), **factory)
    if vecN_states is None:
        vecN_states = torch.zeros((batch_size, nh, n_slots * dhqk), **factory)
    if scaMinter_states is None:
        scaMinter_states = torch.zeros((batch_size, nh, n_slots), **factory)

    # running states start from the initial states (or zeros)
    if matC_initial is None:
        running_C = torch.zeros((batch_size, nh, dhqk, dhhv), **factory)
    else:
        running_C = matC_initial
    if vecN_initial is None:
        running_N = torch.zeros((batch_size, nh, dhqk), **factory)
    else:
        running_N = vecN_initial
    if scaMinter_initial is None:
        running_m = torch.zeros((batch_size, nh, 1), **factory)
    else:
        running_m = scaMinter_initial
    running_m = running_m.squeeze(-1)

    # per-position log weights relative to the end of each chunk, and the per-chunk total
    vecA = vecB[..., -1, None] - vecB + vecI
    scaG = vecB[..., -1]
    scaA_max = vecA.max(-1).values

    for chunk_idx in range(num_chunks):
        state_rows = slice(chunk_idx * dhqk, (chunk_idx + 1) * dhqk)
        time_steps = slice(chunk_idx * chunk_size, (chunk_idx + 1) * chunk_size)

        # record the state entering this chunk (the initial state on the first pass)
        matC_states[:, :, state_rows, :] = running_C
        vecN_states[:, :, state_rows] = running_N
        scaMinter_states[:, :, chunk_idx] = running_m

        # stabilizer (max state) update
        g_k = scaG[:, :, chunk_idx]
        next_m = torch.max(g_k + running_m, scaA_max[:, :, chunk_idx])

        # gate the keys of this chunk and rescale the carried state
        key_weights = torch.exp(vecA[:, :, chunk_idx, :] - next_m[..., None])[:, :, :, None]
        gated_keys_t = (matK[:, :, time_steps, :] * key_weights).transpose(-2, -1)
        carry = torch.exp(g_k + running_m - next_m)[:, :, None]

        # NOTE: out-of-place on purpose, in-place accumulation breaks autograd backward
        running_C = carry[..., None] * running_C + gated_keys_t @ (matV[:, :, time_steps, :])
        running_N = carry * running_N + gated_keys_t.sum(-1)
        running_m = next_m

    # the final slot holds the state after the last chunk
    matC_states[:, :, -dhqk:, :] = running_C
    vecN_states[:, :, -dhqk:] = running_N
    scaMinter_states[:, :, -1] = running_m

    return matC_states, vecN_states, scaMinter_states
def mlstm_chunkwise_parallel_fw_H(
    matQ: torch.Tensor,
    matK: torch.Tensor,
    matV: torch.Tensor,
    # these states must be all states up to the last chunk, i.e. :-1
    matC_states: torch.Tensor,
    vecN_states: torch.Tensor,
    scaMinter_states: torch.Tensor,
    vecI: torch.Tensor,
    vecB: torch.Tensor,
    qk_scale: float,
    chunk_size: int = 64,
    num_chunks: int = 1,
    eps: float = 1e-6,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute the mLSTM outputs of all chunks in parallel from the per-chunk entry states.

    Each output combines an intra-chunk part (causally masked, gated Q K^T V) with an inter-chunk part
    (queries applied to the state that enters the chunk).

    Args:
        matQ: Queries of shape (batch_size, nh, num_chunks * chunk_size, dhqk).
        matK: Keys, same shape as `matQ`.
        matV: Values of shape (batch_size, nh, num_chunks * chunk_size, dhv).
        matC_states: Entry C states of all chunks stacked along dim 2,
            shape (batch_size, nh, num_chunks * dhqk, dhv).
        vecN_states: Entry n states stacked along dim 2, shape (batch_size, nh, num_chunks * dhqk).
        scaMinter_states: Entry m states, shape (batch_size, nh, num_chunks).
        vecI: Input gate pre-activations, shape (batch_size, nh, num_chunks, chunk_size).
        vecB: Same layout as `vecI`; the callers in this file pass the per-chunk cumulative sum of the
            log forget gates.
        qk_scale: Factor applied to the query-key products.
        chunk_size: Number of time steps per chunk.
        num_chunks: Number of chunks.
        eps: Added to the denominator before the division.

    Returns:
        (matH_out, vecN_out, vecM_out) with shapes (batch_size, nh, sequence_length, dhv),
        (batch_size, nh, sequence_length) and (batch_size, nh, sequence_length): the outputs, the
        denominators that were used, and the combined max states.
    """
    _device = matQ.device
    nc = num_chunks
    batch_size, nh, dqk_stacked, dhv = matC_states.shape
    # matC_states stacks one (dhqk, dhv) matrix per chunk along dim 2, so the per-chunk key
    # dimension is the stacked size divided by the number of chunks
    dqk = dqk_stacked // nc
    matC_k_states = matC_states.reshape(batch_size, nh, nc, dqk, dhv)
    vecN_k_states = vecN_states.reshape(batch_size, nh, nc, dqk)
    scaMinter_k_states = scaMinter_states

    # split the sequence dimension into (chunk, position-in-chunk)
    matQ = matQ.reshape(batch_size, nh, nc, chunk_size, dqk)
    matK = matK.reshape(batch_size, nh, nc, chunk_size, dqk)
    matV = matV.reshape(batch_size, nh, nc, chunk_size, dhv)

    # causal mask within a chunk
    ltr = torch.tril(
        torch.ones(
            (chunk_size, chunk_size),
            dtype=torch.bool,
            device=_device,
        )
    )

    # Compute intra chunk contribution: H_intra
    matF_logsig_chunk = vecB[:, :, :, :, None] - vecB[:, :, :, None, :]
    matF_logsig_mask_chunk = torch.where(ltr, matF_logsig_chunk, -float("inf"))
    matLogD_chunk = matF_logsig_mask_chunk + vecI[:, :, :, None, :]

    # max_state intra
    vecMintra_k = torch.max(matLogD_chunk, dim=-1, keepdim=False).values

    # max_state combined
    vecM_b_inter = vecB + scaMinter_k_states[:, :, :, None]
    vecM_k_combine = torch.maximum(vecM_b_inter, vecMintra_k)

    vecM_k_combine = vecM_k_combine[:, :, :, :, None]
    vecM_b_inter = vecM_b_inter[:, :, :, :, None]

    # subtract the combined max before exponentiating for numerical stability
    matLogD_stabilized_chunk = matLogD_chunk - vecM_k_combine
    matD_chunk = torch.exp(matLogD_stabilized_chunk)

    matS_chunk = (matQ @ matK.transpose(-2, -1)) * qk_scale
    matM_chunk = matS_chunk * matD_chunk

    # Combine H_intra with H_inter
    vecBbar = torch.exp(vecM_b_inter - vecM_k_combine)
    matQ_chunk_gated = matQ * vecBbar * qk_scale

    matNumerator_common = matQ_chunk_gated @ matC_k_states + matM_chunk @ matV
    vecDenom_l_common = matQ_chunk_gated @ vecN_k_states.unsqueeze(-1) + matM_chunk.sum(dim=-1, keepdim=True)
    vecDenom_max_common = torch.maximum(torch.abs(vecDenom_l_common), torch.exp(-vecM_k_combine))

    matH_k_chunk = matNumerator_common / (vecDenom_max_common + eps)
    matH_out = matH_k_chunk.reshape(batch_size, nh, nc * chunk_size, dhv)

    # we need the denominator and the overall max state for the backward pass
    vecN_out = vecDenom_max_common.reshape(batch_size, nh, nc * chunk_size)
    vecM_out = vecM_k_combine.reshape(batch_size, nh, nc * chunk_size)
    return matH_out, vecN_out, vecM_out
def mlstm_chunkwise_fw(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    cstate: Optional[torch.Tensor] = None,
    nstate: Optional[torch.Tensor] = None,
    mstate: Optional[torch.Tensor] = None,
    qk_scale: Optional[float] = None,
    return_last_states: bool = False,
    return_all_states: bool = False,
    chunk_size: int = 64,
    eps: float = 1e-6,
) -> tuple[
    torch.Tensor,
    torch.Tensor,
    torch.Tensor,
    Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
    Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
]:
    """Chunkwise mLSTM forward pass that also exposes the normalizer/max outputs and the states.

    Args:
        query: Queries of shape (batch_size, nh, sequence_length, dhqk).
        key: Keys, passed to the state recurrence and the parallel output computation.
        value: Values, passed to the state recurrence and the parallel output computation.
        igate: Input gate pre-activations, viewed as (batch_size, nh, num_chunks, chunk_size).
        fgate: Forget gate pre-activations, viewed as (batch_size, nh, num_chunks, chunk_size).
        cstate: Optional initial C state.
        nstate: Optional initial n state.
        mstate: Optional initial m state.
        qk_scale: Scale for the query-key products; dhqk**-0.5 when None.
        return_last_states: If True, the fourth tuple entry holds the states after the last chunk.
        return_all_states: If True, the fifth tuple entry holds the states at all chunk boundaries.
        chunk_size: Number of time steps per chunk; must divide the sequence length.
        eps: Added to the output denominator.

    Returns:
        A 5-tuple (matH_out, vecN_out, vecM_out, last_states or None, all_states or None).

    Raises:
        ValueError: If the sequence length is not divisible by `chunk_size`.
    """
    batch_size, nh, sequence_length, dhqk = query.shape
    if sequence_length % chunk_size != 0:
        raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")
    nc = sequence_length // chunk_size

    vecI = igate.view(batch_size, nh, nc, chunk_size)
    vecF = fgate.view(batch_size, nh, nc, chunk_size)

    # compute the gates, the g and the a and b vectors
    # (log-sigmoid lives in torch.nn.functional; it is not a Tensor method)
    vecF_logsig = F.logsigmoid(vecF)
    vecB = vecF_logsig.cumsum(-1)

    if qk_scale is None:
        qk_scale = dhqk**-0.5

    # materialize the C_k, n_k, m_k states for each chunk
    matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise_recurrent_fw_C(
        matK=key,
        matV=value,
        vecB=vecB,
        vecI=vecI,
        matC_initial=cstate,
        vecN_initial=nstate,
        scaMinter_initial=mstate,
        qk_scale=qk_scale,
        chunk_size=chunk_size,
        num_chunks=nc,
    )

    # compute the outputs within each chunk; only the states entering each chunk are needed,
    # so the final slot (state after the last chunk) is dropped
    matH_out, vecN_out, vecM_out = mlstm_chunkwise_parallel_fw_H(
        matQ=query,
        matK=key,
        matV=value,
        matC_states=matC_k_states[:, :, :-dhqk, :],
        vecN_states=vecN_k_states[:, :, :-dhqk],
        scaMinter_states=scaMinter_k_states[:, :, :-1],
        vecI=vecI,
        vecB=vecB,
        qk_scale=qk_scale,
        chunk_size=chunk_size,
        num_chunks=nc,
        eps=eps,
    )

    ret_tuple = (matH_out, vecN_out, vecM_out)
    if return_last_states:
        ret_tuple += (
            (matC_k_states[:, :, -dhqk:, :], vecN_k_states[:, :, -dhqk:], scaMinter_k_states[:, :, -1:]),
        )
    else:
        ret_tuple += (None,)

    if return_all_states:
        ret_tuple += ((matC_k_states, vecN_k_states, scaMinter_k_states),)
    else:
        ret_tuple += (None,)

    return ret_tuple
def mlstm_chunkwise_native_autograd(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    c_initial: Optional[torch.Tensor] = None,
    n_initial: Optional[torch.Tensor] = None,
    m_initial: Optional[torch.Tensor] = None,
    return_last_states: bool = False,
    eps: float = 1e-6,
    chunk_size: int = 64,
    **kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """Chunkwise mLSTM forward pass in plain PyTorch ops, differentiable through autograd.

    Args:
        query: Queries of shape (batch_size, nh, sequence_length, dhqk).
        key: Keys, forwarded to the state recurrence and the parallel output computation.
        value: Values, forwarded to the state recurrence and the parallel output computation.
        igate: Input gate pre-activations, viewed as (batch_size, nh, num_chunks, chunk_size).
        fgate: Forget gate pre-activations, viewed as (batch_size, nh, num_chunks, chunk_size).
        c_initial: Optional initial C state.
        n_initial: Optional initial n state.
        m_initial: Optional initial m state.
        return_last_states: If True, also return the states after the last chunk.
        eps: Added to the output denominator.
        chunk_size: Number of time steps per chunk; must divide the sequence length.
        **kwargs: Accepted and ignored.

    Returns:
        The outputs, or (outputs, (c_state, n_state, m_state)) when `return_last_states` is True.

    Raises:
        ValueError: If the sequence length is not divisible by `chunk_size`.
    """
    batch_size, nh, sequence_length, dhqk = query.shape
    num_chunks, leftover = divmod(sequence_length, chunk_size)
    if leftover != 0:
        raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")

    chunked_shape = (batch_size, nh, num_chunks, chunk_size)
    vecI = igate.view(*chunked_shape)
    # cumulative log forget gate inside every chunk
    vecB = F.logsigmoid(fgate.view(*chunked_shape)).cumsum(-1)

    scale = dhqk**-0.5

    # states at every chunk boundary (num_chunks + 1 slots)
    all_C, all_N, all_M = mlstm_chunkwise_recurrent_fw_C(
        matK=key,
        matV=value,
        vecB=vecB,
        vecI=vecI,
        matC_initial=c_initial,
        vecN_initial=n_initial,
        scaMinter_initial=m_initial,
        qk_scale=scale,
        chunk_size=chunk_size,
        num_chunks=num_chunks,
    )

    # outputs inside each chunk; the slot after the final chunk is not an entry state, so drop it
    matH_out, _, _ = mlstm_chunkwise_parallel_fw_H(
        matQ=query,
        matK=key,
        matV=value,
        matC_states=all_C[:, :, :-dhqk, :],
        vecN_states=all_N[:, :, :-dhqk],
        scaMinter_states=all_M[:, :, :-1],
        vecI=vecI,
        vecB=vecB,
        qk_scale=scale,
        chunk_size=chunk_size,
        num_chunks=num_chunks,
        eps=eps,
    )

    if not return_last_states:
        return matH_out

    final_states = (all_C[:, :, -dhqk:, :], all_N[:, :, -dhqk:], all_M[:, :, -1:])
    return matH_out, final_states
def mlstm_recurrent_step_native(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    cstate: torch.Tensor,
    nstate: torch.Tensor,
    mstate: torch.Tensor,
    eps: float = 1e-6,
    dtype_state: torch.dtype = torch.float32,
    **kwargs,
) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
    """Advance the mLSTM by exactly one time step (recurrent formulation).

    Args:
        query: Query of shape (batch_size, nh, dhqk).
        key: Key, must have the same shape as `query`.
        value: Value of shape (batch_size, nh, dhhv).
        igate: Input gate pre-activation of shape (batch_size, nh, 1).
        fgate: Forget gate pre-activation of shape (batch_size, nh, 1).
        cstate: Previous C state of shape (batch_size, nh, dhqk, dhhv).
        nstate: Previous n state of shape (batch_size, nh, dhqk).
        mstate: Previous m state of shape (batch_size, nh, 1).
        eps: Added to the output denominator.
        dtype_state: Dtype the states are cast to on entry and on return.
        **kwargs: Accepted and ignored.

    Returns:
        (h, (c_new, n_new, m_new)), where h has shape (batch_size, nh, dhhv) and the dtype of `query`.

    Raises:
        ValueError: If any input does not have the shape listed above.
    """
    dtype_qkv = query.dtype
    matC_old = cstate.to(dtype=dtype_state)
    vecN_old = nstate.to(dtype=dtype_state)
    scaM_old = mstate.to(dtype=dtype_state)

    batch_size, nh, dhqk = query.shape
    _, _, dhhv = value.shape

    if query.shape != key.shape:
        raise ValueError("query and key must have the same shape")
    gate_shape = (batch_size, nh, 1)
    shape_checks = (
        ("matC_old", matC_old, (batch_size, nh, dhqk, dhhv)),
        ("vecN_old", vecN_old, (batch_size, nh, dhqk)),
        ("scaM_old", scaM_old, gate_shape),
        ("scaI", igate, gate_shape),
        ("scaF", fgate, gate_shape),
    )
    for label, tensor, expected in shape_checks:
        if tensor.shape != expected:
            raise ValueError(f"{label} has wrong shape, got {tensor.shape}")

    # log forget gate and the new stabilizer (max) state
    log_f = torch.nn.functional.logsigmoid(fgate)
    m_new = torch.max(log_f + scaM_old, igate)

    # stabilized gate activations
    f_act = torch.exp(log_f + scaM_old - m_new)
    i_act = torch.exp(igate - m_new)

    q_scaled = query * (dhqk ** (-0.5))
    q_row = q_scaled[:, :, None, :]

    # state updates: decay the old state, add the gated key/value outer product
    kv_outer = key[:, :, :, None] @ value[:, :, None, :]
    c_new = f_act[:, :, :, None] * matC_old + i_act[:, :, :, None] * kv_outer
    n_new = f_act * vecN_old + i_act * key

    # numerator and denominator of the output; matmuls run in the q/k/v dtype
    h_num = (q_row @ c_new.to(dtype=dtype_qkv)).squeeze(2).to(dtype=dtype_state)
    qn = (q_row @ n_new[:, :, :, None].to(dtype=dtype_qkv)).squeeze(2)
    lower_bound = torch.exp(-m_new)
    h_denom = (torch.maximum(qn.abs(), lower_bound) + eps).to(dtype=dtype_state)

    h = (h_num / h_denom).to(dtype=dtype_qkv)

    new_states = (
        c_new.to(dtype=dtype_state),
        n_new.to(dtype=dtype_state),
        m_new.to(dtype=dtype_state),
    )
    return h, new_states
def mlstm_recurrent_sequence_native(
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    igate: torch.Tensor,
    fgate: torch.Tensor,
    c_initial: Optional[torch.Tensor] = None,
    n_initial: Optional[torch.Tensor] = None,
    m_initial: Optional[torch.Tensor] = None,
    return_last_states: bool = False,
    eps: float = 1e-6,
    dtype_state: torch.dtype = torch.float32,
    **kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """Run the mLSTM over a sequence by looping the single-step recurrence in Python.

    Args:
        query: Queries of shape (batch_size, nh, sequence_length, dhqk).
        key: Keys, indexed per time step along dim 2.
        value: Values, indexed per time step along dim 2; the last dimension is dhv.
        igate: Input gate pre-activations, indexed as (batch_size, nh, sequence_length).
        fgate: Forget gate pre-activations, indexed as (batch_size, nh, sequence_length).
        c_initial: Optional initial C state. When given, `n_initial` and `m_initial` are required too.
        n_initial: Optional initial n state; only read when `c_initial` is given.
        m_initial: Optional initial m state; only read when `c_initial` is given.
        return_last_states: If True, also return the states after the last time step.
        eps: Forwarded to the step function.
        dtype_state: Dtype of the recurrent states.
        **kwargs: Forwarded to the step function.

    Returns:
        The per-step outputs stacked along dim -2, or (outputs, (c_state, n_state, m_state)) when
        `return_last_states` is True.

    Raises:
        ValueError: If `c_initial` is given without `n_initial` or `m_initial`.
    """
    batch_size, nh, sequence_length, dhqk = query.shape
    dhv = value.shape[-1]
    device = query.device

    if c_initial is not None:
        if n_initial is None or m_initial is None:
            raise ValueError("Initial states must be provided together.")
        matC_state, vecN_state, vecM_state = (
            c_initial.to(dtype=dtype_state),
            n_initial.to(dtype=dtype_state),
            m_initial.to(dtype=dtype_state),
        )
    else:
        # memory state
        matC_state = torch.zeros((batch_size, nh, dhqk, dhv), dtype=dtype_state, device=device)
        # normalizer state
        vecN_state = torch.zeros((batch_size, nh, dhqk), dtype=dtype_state, device=device)
        # max state
        vecM_state = torch.zeros((batch_size, nh, 1), dtype=dtype_state, device=device)

    vecH_list = []
    for t in range(sequence_length):
        # gates: keep a trailing singleton dim, the step function expects (batch_size, nh, 1)
        vecF_t, vecI_t = fgate[:, :, t, None], igate[:, :, t, None]

        # projections
        vecQ_t, vecK_t, vecV_t = query[:, :, t, :], key[:, :, t, :], value[:, :, t, :]

        # step
        vecH, (matC_state, vecN_state, vecM_state) = mlstm_recurrent_step_native(
            cstate=matC_state,
            nstate=vecN_state,
            mstate=vecM_state,
            query=vecQ_t,
            key=vecK_t,
            value=vecV_t,
            igate=vecI_t,
            fgate=vecF_t,
            eps=eps,
            dtype_state=dtype_state,
            **kwargs,
        )
        vecH_list.append(vecH)

    # re-insert the sequence dimension in front of the feature dimension
    matH = torch.stack(vecH_list, dim=-2)

    if return_last_states:
        return matH, (matC_state, vecN_state, vecM_state)
    else:
        return matH
def wrap_chunkwise_pad_zeros(
    mlstm_chunkwise_kernel: Callable,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    fgate: torch.Tensor,
    igate: torch.Tensor,
    c_initial: Optional[torch.Tensor] = None,
    n_initial: Optional[torch.Tensor] = None,
    m_initial: Optional[torch.Tensor] = None,
    return_last_states: bool = False,
    eps: float = 1e-6,
    autocast_kernel_dtype: torch.dtype = torch.bfloat16,
    chunk_size: int = 64,
    **kwargs,
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """Call a chunkwise kernel on inputs that are zero-padded up to a multiple of `chunk_size`.

    Args:
        mlstm_chunkwise_kernel: Kernel to call; it receives all tensors and options as keyword arguments.
        query: Queries of shape (batch_size, nh, sequence_length, dhqk).
        key: Keys, padded along dim 2.
        value: Values, padded along dim 2.
        fgate: Forget gate pre-activations of shape (batch_size, nh, sequence_length).
        igate: Input gate pre-activations of shape (batch_size, nh, sequence_length).
        c_initial: Optional initial C state, forwarded to the kernel.
        n_initial: Optional initial n state, forwarded to the kernel.
        m_initial: Optional initial m state, forwarded to the kernel.
        return_last_states: Must be False; states computed on a padded sequence are not the true last states.
        eps: Forwarded to the kernel.
        autocast_kernel_dtype: Forwarded to the kernel.
        chunk_size: Forwarded to the kernel; the sequence is padded to a multiple of it.
        **kwargs: Forwarded to the kernel.

    Returns:
        The kernel output cut back to the original sequence length along dim 2.

    Raises:
        ValueError: If `return_last_states` is True.
    """
    if return_last_states:
        # a single message string; two comma-separated strings would render as a tuple
        raise ValueError(
            "We are padding zeros, so we cannot return last states, as they would be not the true last states."
        )

    _, _, sequence_length, _ = query.shape

    # number of time steps missing to the next multiple of chunk_size (0 if already aligned)
    pad_len = -sequence_length % chunk_size
    if pad_len:
        # F.pad takes (left, right) pairs starting from the last dimension:
        # q/k/v are padded at the end of dim -2, the gates at the end of dim -1
        q_pad = F.pad(query, (0, 0, 0, pad_len))
        k_pad = F.pad(key, (0, 0, 0, pad_len))
        v_pad = F.pad(value, (0, 0, 0, pad_len))
        i_pad = F.pad(igate, (0, pad_len))
        f_pad = F.pad(fgate, (0, pad_len))
    else:
        q_pad, k_pad, v_pad, i_pad, f_pad = query, key, value, igate, fgate

    matH = mlstm_chunkwise_kernel(
        query=q_pad,
        key=k_pad,
        value=v_pad,
        igate=i_pad,
        fgate=f_pad,
        c_initial=c_initial,
        n_initial=n_initial,
        m_initial=m_initial,
        return_last_states=return_last_states,
        eps=eps,
        autocast_kernel_dtype=autocast_kernel_dtype,
        chunk_size=chunk_size,
        **kwargs,
    )
    # drop the outputs that belong to padded positions
    return matH[:, :, :sequence_length, :]
def wrap_chunkwise_arbitrary_sequence_length(
    mlstm_chunkwise_kernel: Callable,
    mlstm_sequence_kernel: Callable,
    mlstm_step_kernel: Callable,
    query: torch.Tensor,
    key: torch.Tensor,
    value: torch.Tensor,
    fgate: torch.Tensor,
    igate: torch.Tensor,
    c_initial: Optional[torch.Tensor] = None,
    n_initial: Optional[torch.Tensor] = None,
    m_initial: Optional[torch.Tensor] = None,
    return_last_states: bool = True,
    eps: float = 1e-6,
    autocast_kernel_dtype: torch.dtype = torch.bfloat16,
    chunk_size: int = 64,
    enable_logging: bool = False,
) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
    """Compute the mLSTM outputs and last states for any sequence length by combining three kernels.

    Dispatch:
        - sequence_length == 1: the step kernel is called once, without a sequence dimension.
        - sequence_length > 1: the longest prefix that is a multiple of `chunk_size` goes through the
          chunkwise kernel in one call; whatever remains (fewer than `chunk_size` steps) goes through the
          sequence kernel, which continues from the states the chunkwise kernel returned.

    Args:
        mlstm_chunkwise_kernel: Kernel that processes whole chunks in parallel. Must return
            (h, (c, n, m)) when called with return_last_states=True.
        mlstm_sequence_kernel: Kernel for the remainder of the sequence. Must return (h, (c, n, m)) when
            called with return_last_states=True.
        mlstm_step_kernel: Kernel for a single time step. Must return (h, (c, n, m)).
        query: The query tensor (batch_size, nh, sequence_length, dhqk)
        key: The key tensor (batch_size, nh, sequence_length, dhqk)
        value: The value tensor (batch_size, nh, sequence_length, dhhv)
        fgate: The forget gate tensor (batch_size, nh, sequence_length)
        igate: The input gate tensor (batch_size, nh, sequence_length)
        c_initial: The initial C state (batch_size, nh, dhqk, dhhv); float32 zeros when None
        n_initial: The initial n state (batch_size, nh, dhqk); float32 zeros when None
        m_initial: The initial m state (batch_size, nh, 1); float32 zeros when None
        return_last_states: If True, the last states are returned next to the outputs
        eps: The epsilon value forwarded to the kernels
        autocast_kernel_dtype: Forwarded to the chunkwise kernel
        chunk_size: The chunk size used for the chunkwise kernel
        enable_logging: Currently unused.

    Returns:
        The outputs (batch_size, nh, sequence_length, dhhv), or a tuple of the outputs and the last states
        (c_state, n_state, m_state).

    Raises:
        ValueError: If the sequence is empty.
    """
    batch_size, nh, sequence_length, dhqk = key.shape
    dhhv = value.shape[-1]

    def _zero_state(*shape: int) -> torch.Tensor:
        return torch.zeros(*shape, device=key.device, dtype=torch.float32)

    c_state = _zero_state(batch_size, nh, dhqk, dhhv) if c_initial is None else c_initial
    n_state = _zero_state(batch_size, nh, dhqk) if n_initial is None else n_initial
    m_state = _zero_state(batch_size, nh, 1) if m_initial is None else m_initial

    if sequence_length < 1:
        raise ValueError(
            f"Received empty sequence (sequence_length={sequence_length}), require at least single element in the sequence."
        )

    if sequence_length == 1:
        # Single token: call the step kernel directly instead of going through the sequence path.
        # It takes q/k/v without the sequence dimension, (batch_size, nh, dhqk/dhv), while the gates
        # already have the (batch_size, nh, 1) layout it expects.
        h_out, (c_state, n_state, m_state) = mlstm_step_kernel(
            query=query.squeeze(2),
            key=key.squeeze(2),
            value=value.squeeze(2),
            igate=igate,
            fgate=fgate,
            cstate=c_state,
            nstate=n_state,
            mstate=m_state,
            eps=eps,
        )
        h_out = h_out[:, :, None, :]
    else:
        pieces = []
        processed = 0

        # part 1: as many full chunks as fit, in one chunkwise call
        full_chunks = sequence_length // chunk_size
        if full_chunks > 0:
            processed = chunk_size * full_chunks
            chunk_h, (c_state, n_state, m_state) = mlstm_chunkwise_kernel(
                query=query[..., :processed, :].contiguous(),
                key=key[..., :processed, :].contiguous(),
                value=value[..., :processed, :].contiguous(),
                fgate=fgate[..., :processed].contiguous(),
                igate=igate[..., :processed].contiguous(),
                c_initial=c_state,
                n_initial=n_state,
                m_initial=m_state,
                chunk_size=chunk_size,
                return_last_states=True,
                autocast_kernel_dtype=autocast_kernel_dtype,
                eps=eps,
            )
            pieces.append(chunk_h)

        # part 2: the tail that does not fill a chunk, continuing from the states of part 1
        if sequence_length - processed > 0:
            tail_h, (c_state, n_state, m_state) = mlstm_sequence_kernel(
                query=query[..., processed:sequence_length, :].contiguous(),
                key=key[..., processed:sequence_length, :].contiguous(),
                value=value[..., processed:sequence_length, :].contiguous(),
                igate=igate[..., processed:sequence_length].contiguous(),
                fgate=fgate[..., processed:sequence_length].contiguous(),
                c_initial=c_state,
                n_initial=n_state,
                m_initial=m_state,
                return_last_states=True,
                eps=eps,
            )
            pieces.append(tail_h)

        h_out = torch.concatenate(pieces, dim=2)

    if not return_last_states:
        return h_out
    return h_out, (c_state, n_state, m_state)
class xLSTMBackend(nn.Module):
    """xLSTM Backend Module for PyTorch.

    This module wraps the xLSTM kernels and provides a high-level interface for training and inference.

    Two callables are prepared at construction time:
      * `_train_fn`: the chunkwise kernel, additionally wrapped with `wrap_chunkwise_pad_zeros` when the
        configured mode contains "with_padding".
      * `_inference_fn`: `wrap_chunkwise_arbitrary_sequence_length` combining the chunkwise, sequence and
        step kernels; it is bound with `return_last_states=True`.
    """

    config_class = xLSTMConfig

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config
        self.chunkwise_kernel_fn = mlstm_chunkwise_native_autograd
        self.sequence_kernel_fn = mlstm_recurrent_sequence_native
        self.step_kernel_fn = mlstm_recurrent_step_native

        # Dtype names are stored as strings in the config and resolved on the torch module.
        state_dtype = getattr(torch, config.inference_state_dtype)
        kernel_dtype = getattr(torch, config.autocast_kernel_dtype)

        self._inference_fn = partial(
            wrap_chunkwise_arbitrary_sequence_length,
            mlstm_chunkwise_kernel=self.chunkwise_kernel_fn,
            mlstm_sequence_kernel=partial(
                self.sequence_kernel_fn,
                dtype_state=state_dtype,
            ),
            mlstm_step_kernel=partial(
                self.step_kernel_fn,
                dtype_state=state_dtype,
            ),
            chunk_size=config.chunk_size,
            eps=config.eps,
            autocast_kernel_dtype=kernel_dtype,
            return_last_states=True,
        )

        train_kernel_fn = partial(
            self.chunkwise_kernel_fn,
            autocast_kernel_dtype=kernel_dtype,
            eps=config.eps,
            chunk_size=config.chunk_size,
        )
        if "with_padding" in config.mode:
            train_kernel_fn = partial(wrap_chunkwise_pad_zeros, mlstm_chunkwise_kernel=train_kernel_fn)
        self._train_fn = train_kernel_fn

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        igate: torch.Tensor,
        fgate: torch.Tensor,
        c_initial: Optional[torch.Tensor] = None,
        n_initial: Optional[torch.Tensor] = None,
        m_initial: Optional[torch.Tensor] = None,
        return_last_states: Optional[bool] = False,
        mode: Optional[Literal["train", "inference"]] = None,
    ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
        """Forward pass of the mLSTM backend.

        Depending on the configured mode, this method will call the appropriate kernel function.

        Args:
            query: The query tensor of shape (batch_size, nh, sequence_length, dhqk).
            key: The key tensor of shape (batch_size, nh, sequence_length, dhqk).
            value: The value tensor of shape (batch_size, nh, sequence_length, dhhv).
            igate: The input gate preactivation tensor of shape (batch_size, nh, sequence_length).
            fgate: The forget gate preactivation tensor of shape (batch_size, nh, sequence_length).
            c_initial: The initial cell state tensor of shape (batch_size, nh, dhqk, dhhv).
                Defaults to None.
            n_initial: The initial normalizer state tensor of shape (batch_size, nh, dhqk). Defaults to None.
            m_initial: The initial max state tensor of shape (batch_size, nh, 1). Defaults to None.
            return_last_states: Whether to return the last states of the sequence. Only consulted on the
                train path. Defaults to False; when None is passed explicitly, the value from the config
                is used. The inference path always returns the last states.
            mode: Optional override for `config.mode`. A value containing "train" dispatches to the
                training kernel, a value containing "inference" to the inference wrapper.

        Returns:
            hidden states of shape (batch_size, nh, sequence_length, dhhv)
            or, when last states are returned, a tuple of the hidden states and the last states:
            the cell state cstate (batch_size, nh, dhqk, dhhv), the normalizer state nstate
            (batch_size, nh, dhqk), and the max state mstate (batch_size, nh, 1)

        Raises:
            ValueError: If `mode` matches neither "train" nor "inference", or if last states are requested
                while the module is configured for "train_with_padding".
        """
        if mode is None:
            mode = self.config.mode

        if "train" in mode:
            if return_last_states is None:
                return_last_states = self.config.return_last_states

            # `_train_fn` was built from `config.mode`, hence the check against the config here.
            if self.config.mode == "train_with_padding":
                if return_last_states:
                    raise ValueError("return_last_states=True is not supported with train_with_padding mode.")

            return self._train_fn(
                query=query,
                key=key,
                value=value,
                igate=igate,
                fgate=fgate,
                c_initial=c_initial,
                n_initial=n_initial,
                m_initial=m_initial,
                return_last_states=return_last_states,
            )

        elif "inference" in mode:
            # inference mode always returns the last states
            return self._inference_fn(
                query=query,
                key=key,
                value=value,
                igate=igate,
                fgate=fgate,
                c_initial=c_initial,
                n_initial=n_initial,
                m_initial=m_initial,
            )
        else:
            # Report the mode that was actually dispatched on (it may be a per-call override).
            raise ValueError(f"Unknown mode: {mode}")

    def extra_repr(self) -> str:
        return f"{self.config}"
class xLSTMRMSNorm(nn.Module):
    """Root mean square normalization layer implementation similar
    to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.

    The input is divided by the root mean square of its last dimension; an optional
    learnable scale and an optional learnable shift are applied afterwards.

    Args:
        num_features: The number of features in the input tensor.
        eps: A small value to avoid division by zero.
        use_weight: Whether to use a learnable weight.
        use_bias: Whether to use a learnable bias.
        force_float32_reductions: Whether to force float32 reductions.
    """

    def __init__(
        self,
        num_features: int,
        eps: float = 1e-6,
        use_weight: bool = True,
        use_bias: bool = False,
        force_float32_reductions: bool = True,
    ):
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.force_float32_reductions = force_float32_reductions
        # Disabled affine terms are stored as plain `None` attributes.
        self.weight = nn.Parameter(torch.ones(num_features)) if use_weight else None
        self.bias = nn.Parameter(torch.zeros(num_features)) if use_bias else None

    def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
        """Scale by `weight` and shift by `bias`, skipping whichever is disabled."""
        scaled = x if self.weight is None else x * self.weight
        return scaled if self.bias is None else scaled + self.bias

    def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Normalize over the last dimension and return the result in the input dtype."""
        input_dtype = x.dtype
        values = x.float() if self.force_float32_reductions else x
        mean_square = values.pow(2).mean(dim=-1, keepdim=True)
        normalized = values * torch.rsqrt(mean_square + self.eps)
        return normalized.to(input_dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self._apply_weight_bias(self._rms_normalize(x))
class xLSTMMultiHeadLayerNorm(nn.Module):
    """Multi-head version of the LayerNorm layer.

    Every head is normalized independently over its own feature dimension.
    The input is assumed to have the shape (batch_size, sequence_length, nh, DH), where:
    batch_size: batch size
    sequence_length: sequence length
    nh: number of heads
    DH: head dimension

    The normalization is applied over the last dimension (DH) of the input tensor; the heads are
    then flattened and the optional weight / bias of size nh * DH are applied.

    Args:
        num_heads: The number of heads.
        head_dim: The head dimension.
        eps: A small value to avoid division by zero.
        use_weight: Whether to use a learnable weight.
        use_bias: Whether to use a learnable bias.
        force_float32_reductions: Whether to force float32 reductions

    Returns:
        The normalized tensor with the shape (batch_size, sequence_length, nh * DH).
    """

    def __init__(
        self,
        num_heads: int,
        head_dim: int,
        eps: float = 1e-6,
        use_weight: bool = True,
        use_bias: bool = False,
        force_float32_reductions: bool = True,
    ):
        super().__init__()
        self.num_features = num_heads * head_dim
        self.eps = eps
        self.force_float32_reductions = force_float32_reductions
        # Affine parameters span all heads at once (nh * DH entries).
        self.weight = nn.Parameter(torch.ones(self.num_features)) if use_weight else None
        self.bias = nn.Parameter(torch.zeros(self.num_features)) if use_bias else None
        self.num_heads = num_heads
        self.head_dim = head_dim

    def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
        """Scale by `weight` and shift by `bias`, skipping whichever is disabled."""
        scaled = x if self.weight is None else x * self.weight
        return scaled if self.bias is None else scaled + self.bias

    def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
        """Zero-mean / unit-variance normalize the last dimension, returning the input dtype."""
        input_dtype = x.dtype
        values = x.float() if self.force_float32_reductions else x
        centered = values - values.mean(dim=-1, keepdim=True)
        # Biased (population) variance.
        variance = values.var(dim=-1, keepdim=True, unbiased=False)
        normalized = centered * torch.rsqrt(variance + self.eps)
        return normalized.to(input_dtype)

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        batch_size, sequence_length, nh, DH = x.shape
        if nh != self.num_heads:
            raise ValueError(f"Expected {self.num_heads} heads, got {nh}, input shape: {x.shape}")
        if DH != self.head_dim:
            raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}")

        # Normalize per head, then merge the head and feature axes before the affine transform.
        flattened = self._layer_normalize(x).reshape(batch_size, sequence_length, -1)
        return self._apply_weight_bias(flattened)
class xLSTMFeedForward(nn.Module):
    """Gated feed-forward block: `proj_down(SiLU(gate(x)) * up(x))`.

    The gate and up projections are either two separate linear layers (`weight_mode == "single"`)
    or one fused linear layer whose output is split in half (`weight_mode == "fused"`).

    Raises:
        ValueError: If `config.weight_mode` is neither "single" nor "fused".
    """

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config

        self.up_proj_dim = round_up_to_next_multiple_of(
            config.hidden_size * config.ffn_proj_factor,
            config.ffn_round_up_to_multiple_of,
        )

        if self.config.weight_mode == "single":
            self.proj_up_gate = nn.Linear(
                in_features=config.hidden_size,
                out_features=self.up_proj_dim,
                bias=self.config.use_bias,
            )
            self.proj_up = nn.Linear(
                in_features=config.hidden_size,
                out_features=self.up_proj_dim,
                bias=self.config.use_bias,
            )
        elif self.config.weight_mode == "fused":
            self.proj_up_gate_z = nn.Linear(
                in_features=config.hidden_size,
                out_features=2 * self.up_proj_dim,
                bias=self.config.use_bias,
            )
        else:
            # Without this, the layer would be built with no up-projection at all and only fail
            # later (or silently skip the gating) inside forward.
            raise ValueError(
                f"Unknown weight_mode: {self.config.weight_mode!r}, expected 'single' or 'fused'"
            )

        self.proj_down = nn.Linear(
            in_features=self.up_proj_dim,
            out_features=config.hidden_size,
            bias=self.config.use_bias,
        )

        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated up-projection followed by the down-projection to `x`."""
        if self.config.weight_mode == "single":
            x = self.act_fn(self.proj_up_gate(x)) * self.proj_up(x)
        elif self.config.weight_mode == "fused":
            x = self.proj_up_gate_z(x)
            # First half of the fused output is the gate, second half the up-projected value.
            gate, z = torch.tensor_split(x, (self.up_proj_dim,), dim=-1)
            x = self.act_fn(gate) * z
        else:
            raise ValueError(
                f"Unknown weight_mode: {self.config.weight_mode!r}, expected 'single' or 'fused'"
            )

        y = self.proj_down(x)
        return y
class xLSTMLayer(nn.Module):
    """mLSTM layer: q/k/v, gate projections, the mLSTM backend, per-head norm, output gate and projection.

    With `weight_mode == "single"` every projection is its own linear layer; with
    `weight_mode == "fused"` q, k, v and the output-gate preactivation share one linear layer and
    the input / forget gate preactivations share another.

    Raises:
        ValueError: If `config.weight_mode` is neither "single" nor "fused".
    """

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config

        self.v_dim = int(config.hidden_size * config.v_dim_factor)
        self.qk_dim = int(config.hidden_size * config.qk_dim_factor)

        if self.config.weight_mode == "single":
            self.q = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=self.qk_dim,
                bias=self.config.use_bias,
            )
            self.k = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=self.qk_dim,
                bias=self.config.use_bias,
            )
            self.v = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=self.v_dim,
                bias=self.config.use_bias,
            )

            self.ogate_preact = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=self.v_dim,
                bias=self.config.use_bias,
            )
            # The input / forget gate projections always carry a bias, independent of `use_bias`.
            self.igate_preact = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=self.config.num_heads,
                bias=True,
            )
            self.fgate_preact = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=self.config.num_heads,
                bias=True,
            )
        elif self.config.weight_mode == "fused":
            self.qkv_opreact = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=2 * self.qk_dim + 2 * self.v_dim,
                bias=self.config.use_bias,
            )
            self.ifgate_preact = nn.Linear(
                in_features=self.config.hidden_size,
                out_features=2 * self.config.num_heads,
                bias=True,
            )
        else:
            raise ValueError(
                f"Unknown weight_mode: {self.config.weight_mode!r}, expected 'single' or 'fused'"
            )

        self.ogate_act_fn = nn.Sigmoid()
        self.mlstm_backend = xLSTMBackend(config=self.config)

        self.multihead_norm = xLSTMMultiHeadLayerNorm(
            num_heads=self.config.num_heads,
            head_dim=self.v_dim // self.config.num_heads,
            eps=self.config.norm_eps,
            use_weight=True,
            use_bias=self.config.use_bias,
            force_float32_reductions=self.config.norm_reduction_force_float32,
        )
        self.out_proj = nn.Linear(
            in_features=self.v_dim,
            out_features=self.config.hidden_size,
            bias=self.config.use_bias,
        )

    def forward(
        self, x: torch.Tensor, state: Optional[mLSTMLayerStateType] = None
    ) -> tuple[torch.Tensor, Optional[mLSTMLayerStateType]]:
        """Run the layer on `x`.

        Args:
            x: Input of shape [batch_size, sequence_length, HD].
            state: Optional `(c, n, m)` tuple passed to the backend as its initial states.

        Returns:
            The layer output (same leading shape as `x`, projected back to `hidden_size`) and the
            last states reported by the backend, or `None` when the backend did not return any.

        Raises:
            ValueError: If `x` is not 3-dimensional, if `config.weight_mode` is unknown, or if the
                backend output does not have shape (batch_size, nh, sequence_length, v_dim // nh).
        """
        if x.ndim != 3:
            raise ValueError(f"Input must have shape [batch_size, sequence_length, HD], got {x.shape}")
        batch_size, sequence_length, _ = x.shape
        num_heads = self.config.num_heads

        if self.config.weight_mode == "single":
            query = self.q(x)
            key = self.k(x)
            value = self.v(x)
            o_preact = self.ogate_preact(x)
            i_preact = soft_cap(self.igate_preact(x), cap_value=self.config.gate_soft_cap)
            f_preact = soft_cap(self.fgate_preact(x), cap_value=self.config.gate_soft_cap)
        elif self.config.weight_mode == "fused":
            qkv_opreact = self.qkv_opreact(x)
            # Layout of the fused projection: [q | k | v | output-gate preactivation].
            query, key, value, o_preact = torch.tensor_split(
                qkv_opreact,
                (
                    self.qk_dim,
                    2 * self.qk_dim,
                    2 * self.qk_dim + self.v_dim,
                ),
                dim=-1,
            )

            # Layout of the fused gate projection: [input gate | forget gate].
            if_preact = soft_cap(self.ifgate_preact(x), cap_value=self.config.gate_soft_cap)
            i_preact, f_preact = torch.tensor_split(if_preact, (num_heads,), dim=-1)
        else:
            raise ValueError(
                f"Unknown weight_mode: {self.config.weight_mode!r}, expected 'single' or 'fused'"
            )

        # (batch, seq, features) -> (batch, nh, seq, head_dim); gates -> (batch, nh, seq).
        query = query.reshape(batch_size, sequence_length, num_heads, -1).transpose(1, 2)
        key = key.reshape(batch_size, sequence_length, num_heads, -1).transpose(1, 2)
        value = value.reshape(batch_size, sequence_length, num_heads, -1).transpose(1, 2)
        i_preact = i_preact.transpose(1, 2)
        f_preact = f_preact.transpose(1, 2)

        if state is None:
            c_initial, n_initial, m_initial = None, None, None
        else:
            c_initial, n_initial, m_initial = state

        backend_out = self.mlstm_backend(
            query=query,
            key=key,
            value=value,
            igate=i_preact,
            fgate=f_preact,
            c_initial=c_initial,
            n_initial=n_initial,
            m_initial=m_initial,
        )
        # The backend returns `(h, states)` only when last states are returned; otherwise it
        # returns the bare hidden-state tensor, which must not be unpacked along its batch axis.
        if isinstance(backend_out, tuple):
            h, state = backend_out
        else:
            h, state = backend_out, None

        expected_h_shape = (
            batch_size,
            num_heads,
            sequence_length,
            self.v_dim // num_heads,
        )
        if h.shape != expected_h_shape:
            raise ValueError(f"Got {h.shape}, expected {expected_h_shape}")

        # Back to (batch, seq, nh, head_dim) for the per-head normalization.
        h = h.transpose(1, 2)
        h_norm = self.multihead_norm(h)
        h_norm = h_norm.reshape(batch_size, sequence_length, -1)

        h_out = self.ogate_act_fn(o_preact) * h_norm

        y = self.out_proj(h_out)
        return y, state
class xLSTMBlock(nn.Module):
    """Pre-norm residual block: an mLSTM layer followed by a gated feed-forward network."""

    def __init__(self, config: xLSTMConfig):
        super().__init__()
        self.config = config

        # Both RMS norms are configured identically.
        norm_kwargs = {
            "num_features": config.hidden_size,
            "eps": config.norm_eps,
            "use_weight": True,
            "use_bias": config.use_bias,
            "force_float32_reductions": config.norm_reduction_force_float32,
        }
        self.norm_mlstm = xLSTMRMSNorm(**norm_kwargs)
        self.mlstm_layer = xLSTMLayer(config)
        self.norm_ffn = xLSTMRMSNorm(**norm_kwargs)
        self.ffn = xLSTMFeedForward(config)

    def forward(
        self, x: torch.Tensor, state: Optional[mLSTMStateType] = None
    ) -> tuple[torch.Tensor, mLSTMStateType]:
        """Apply both residual sub-layers to `x` and return the result with the mLSTM layer's state."""
        mlstm_out, state = self.mlstm_layer(self.norm_mlstm(x), state)
        after_mlstm = x + mlstm_out
        ffn_out = self.ffn(self.norm_ffn(after_mlstm))
        return after_mlstm + ffn_out, state
def small_init_method(dim):
    """
    Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
    Build an initializer following Transformers without Tears: Improving the Normalization of
    Self-Attention - Nguyen, T. & Salazar, J. (2019). The returned callable fills a tensor in place
    from a normal distribution with mean 0 and std sqrt(2 / (5 * dim)) and returns that tensor."""
    scale = 2 / (5 * dim)
    std = scale ** (1 / 2)

    def fill_normal_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return fill_normal_
def wang_init_method(n_layers, dim):
    """
    Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
    Build an initializer that fills a tensor in place from a normal distribution with mean 0 and
    std 2 / n_layers / sqrt(dim), and returns that tensor.
    """
    per_layer = 2 / n_layers
    std = per_layer / dim ** (1 / 2)

    def fill_normal_(tensor):
        return torch.nn.init.normal_(tensor, mean=0.0, std=std)

    return fill_normal_
class xLSTMPreTrainedModel(PreTrainedModel):
    """
    An abstract class for an interface to loading a pre-trained xLSTM model.
    """

    config_class = xLSTMConfig
    base_model_prefix = "backbone"
    _no_split_modules = ["xLSTMBlock"]
    supports_gradient_checkpointing = True
    _is_stateful = True

    def _module_name_map(self, module):
        """Return the qualified name of `module` inside this model, or "" if it is not a submodule."""
        for name, mod in self.named_modules():
            if mod is module:
                return name
        return ""

    def _init_weights(self, module):
        """Initialize `module` according to its type and its qualified name within the model.

        * Embeddings and generic linear weights use `small_init_method(hidden_size)`.
        * Linear biases are zeroed first; linear layers whose name contains "gate" get zero weights,
          input-gate biases are set to -10 and forget-gate biases to `linspace(3, 6, num_heads)`.
        * "proj_down" / "out_proj" weights use `wang_init_method`.
        * Norm layers get weight 1 and bias 0.
        """
        if isinstance(module, nn.Embedding):
            # Initialize the module being visited: `self.embeddings` only exists on the backbone,
            # not on wrapper models that hold the backbone as a submodule.
            small_init_method(self.config.hidden_size)(module.weight)
        elif isinstance(module, nn.Linear):
            if module.bias is not None:
                torch.nn.init.zeros_(module.bias)

            # Resolve the name once; each lookup walks every submodule of the model.
            name = self._module_name_map(module)
            num_heads = self.config.num_heads

            if self.config.weight_mode == "single" and "gate" in name:
                torch.nn.init.zeros_(module.weight)
                with torch.no_grad():
                    if "igate" in name:
                        module.bias.fill_(-10.0)
                    elif "fgate" in name:
                        module.bias.copy_(
                            torch.linspace(
                                3.0,
                                6.0,
                                module.bias.shape[-1],
                            ).to(
                                device=module.bias.device,
                                dtype=module.bias.dtype,
                            )
                        )
            elif self.config.weight_mode == "fused" and "gate" in name:
                # NOTE(review): this also zeroes the fused FFN projection ("proj_up_gate_z"), i.e. both
                # its gate and its up-projection half - confirm that this is intended.
                torch.nn.init.zeros_(module.weight)
                # Only the fused input/forget gate projection carries gate biases; its output is
                # split as [input gate | forget gate], each `num_heads` wide.
                if "ifgate" in name and module.bias is not None:
                    with torch.no_grad():
                        module.bias[:num_heads].fill_(-10.0)
                        module.bias[num_heads:].copy_(
                            torch.linspace(
                                3.0,
                                6.0,
                                module.bias.shape[-1] - num_heads,
                            ).to(
                                device=module.bias.device,
                                dtype=module.bias.dtype,
                            )
                        )
            elif "proj_down" in name:
                wang_init_method(dim=module.weight.shape[1], n_layers=self.config.num_hidden_layers)(module.weight)
            elif "out_proj" in name:
                wang_init_method(dim=self.config.hidden_size, n_layers=self.config.num_hidden_layers)(module.weight)
            elif module.weight is not None:
                small_init_method(self.config.hidden_size)(module.weight)
        elif isinstance(module, xLSTMRMSNorm) or hasattr(module, "_layer_normalize"):
            # Norm layers may be built without a learnable weight and/or bias.
            if module.weight is not None:
                torch.nn.init.ones_(module.weight)
            if hasattr(module, "bias") and module.bias is not None:
                torch.nn.init.zeros_(module.bias)
class xLSTMCache:
    """
    Cache for xLSTM model which does not have attention mechanism and key value states.

    Arguments:
        config (`PretrainedConfig`):
            The configuration file defining the shape-related attributes required to initialize the static cache.
        max_batch_size (`int`):
            The batch size with which the model will be used.
        dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
            The default `dtype` to use when initializing the layer.
        device (`torch.device` or `str`, *optional*):
            The device on which the cache should be initialized. Should be the same as the layer.

    Attributes:
        seqlen_offset: int
        dtype: torch.dtype
        rnn_state: dict mapping each layer index to a tuple of three zero-initialized state tensors of
            shapes (max_batch_size, num_heads, qk_head_dim, v_head_dim),
            (max_batch_size, num_heads, qk_head_dim) and (max_batch_size, num_heads, 1)
        rnn_state_initial: bool, `True` until the model has written states into the cache

    Example:

        ```python
        >>> from transformers import AutoTokenizer, xLSTMForCausalLM, xLSTMCache

        >>> model = xLSTMForCausalLM.from_pretrained("NX-AI/xLSTM-7b")
        >>> tokenizer = AutoTokenizer.from_pretrained("NX-AI/xLSTM-7b")

        >>> inputs = tokenizer(text="I am an xLSTM", return_tensors="pt")

        >>> # Prepare a cache class and pass it to model's forward
        >>> cache_params = xLSTMCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
        >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True)
        >>> outputs.cache_params
        xLSTMCache()
        ```
    """

    def __init__(
        self,
        config: xLSTMConfig,
        max_batch_size: int,
        dtype: torch.dtype = torch.bfloat16,
        device: Union[torch.device, str, None] = None,
        **kwargs,
    ):
        self.seqlen_offset = 0
        self.dtype = dtype
        self.config = config
        # The model's forward sets this flag to False after storing states; define it up front so
        # it can always be read.
        self.rnn_state_initial = True
        self.rnn_state = {
            layer: (
                torch.zeros(
                    [max_batch_size, config.num_heads, config.qk_head_dim, config.v_head_dim],
                    dtype=dtype,
                    device=device,
                ),
                torch.zeros([max_batch_size, config.num_heads, config.qk_head_dim], dtype=dtype, device=device),
                torch.zeros([max_batch_size, config.num_heads, 1], dtype=dtype, device=device),
            )
            for layer in range(config.num_hidden_layers)
        }

    def reset(self):
        """Replace every layer's state tensors with fresh zeros of the same shape, dtype and device."""
        self.rnn_state = {
            layer: tuple(torch.zeros_like(state) for state in states)
            for layer, states in self.rnn_state.items()
        }
        self.rnn_state_initial = True
@dataclass
@auto_docstring
class xLSTMOutput(ModelOutput):
    r"""
    cache_params (`xLSTMCache`):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.
    """

    # The only field without a default; `xLSTMModel.forward` fills it with the output of its final norm.
    last_hidden_state: Optional[torch.FloatTensor]
    # Cache object carrying the per-layer RNN states.
    cache_params: Optional[xLSTMCache] = None
    # Tuple of hidden states collected by `xLSTMModel.forward` when `output_hidden_states` is set.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
@auto_docstring
class xLSTMModel(xLSTMPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        # use embedding_dim and num_blocks once here to make use of them
        self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
        self.blocks = nn.ModuleList([xLSTMBlock(config) for _ in range(config.num_blocks)])
        self.out_norm = xLSTMRMSNorm(config.hidden_size, eps=config.norm_eps)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embeddings

    def set_input_embeddings(self, new_embedding):
        self.embeddings = new_embedding

    @staticmethod
    def _store_layer_state(cache_params, layer_idx, rnn_state):
        """Copy one layer's returned states into the cache tensors in place and mark the cache as used."""
        if rnn_state is None:
            # The layer returned no states, so there is nothing to store.
            return
        for cached_state, new_state in zip(cache_params.rnn_state[layer_idx], rnn_state):
            cached_state.copy_(new_state)
        cache_params.rnn_state_initial = False

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[xLSTMCache] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple, xLSTMOutput]:
        r"""
        cache_params (`xLSTMCache`, *optional*):
            The xLSTMCache that carries the RNN states.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)

        if self.gradient_checkpointing and self.training and use_cache:
            use_cache = False

        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")

        if inputs_embeds is None:
            inputs_embeds = self.embeddings(input_ids)

        if use_cache and cache_params is None:
            cache_params = xLSTMCache(
                self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
            )

        hidden_states = inputs_embeds
        # Defined before branching so that the return statement works on every path.
        all_hidden_states = () if output_hidden_states else None

        sequence_length = hidden_states.shape[1]
        chunk_size = self.config.max_inference_chunksize

        if not self.training and chunk_size < sequence_length and not output_hidden_states:
            # Long inference inputs are processed chunk by chunk; the RNN states are carried from
            # one chunk to the next through the cache.
            with torch.no_grad():
                if cache_params is None:
                    # Create the cache on the device / dtype of the activations, like the cache above.
                    cache_params = xLSTMCache(
                        config=self.config,
                        max_batch_size=hidden_states.shape[0],
                        device=hidden_states.device,
                        dtype=hidden_states.dtype,
                    )
                final_state = torch.zeros_like(hidden_states)
                offset = 0
                while offset < sequence_length:
                    chunk_end = min(offset + chunk_size, sequence_length)
                    hidden_states_chunk = hidden_states[:, offset:chunk_end]
                    for layer_idx, xlstm_block in enumerate(self.blocks):
                        hidden_states_chunk, rnn_state = xlstm_block(
                            hidden_states_chunk,
                            state=cache_params.rnn_state[layer_idx],
                        )
                        self._store_layer_state(cache_params, layer_idx, rnn_state)
                    final_state[:, offset:chunk_end] = hidden_states_chunk
                    offset += chunk_size
                hidden_states = final_state
        else:
            for layer_idx, xlstm_block in enumerate(self.blocks):
                layer_state = cache_params.rnn_state[layer_idx] if cache_params is not None else None
                if self.gradient_checkpointing and self.training:
                    hidden_states, rnn_state = self._gradient_checkpointing_func(
                        xlstm_block.__call__,
                        hidden_states,
                        layer_state,
                    )
                else:
                    hidden_states, rnn_state = xlstm_block(
                        hidden_states,
                        state=layer_state,
                    )
                if cache_params is not None:
                    self._store_layer_state(cache_params, layer_idx, rnn_state)

                if output_hidden_states:
                    all_hidden_states = all_hidden_states + (hidden_states,)

        if use_cache:
            cache_params.seqlen_offset += inputs_embeds.shape[1]

        hidden_states = self.out_norm(hidden_states)

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (hidden_states,)

        return xLSTMOutput(
            last_hidden_state=hidden_states,
            cache_params=cache_params,
            hidden_states=all_hidden_states,
        )
@dataclass
@auto_docstring
class xLSTMCausalLMOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
        Language modeling loss (for next-token prediction).
    logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
        Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
    cache_params (`xLSTMCache`, *optional*, carrying the RNN states):
        The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
        avoid providing the old `input_ids`.
    """

    # Stays None unless `labels` were passed to `xLSTMForCausalLM.forward`.
    loss: Optional[torch.FloatTensor] = None
    # `xLSTMForCausalLM.forward` stores the float-cast, soft-capped LM-head output here.
    logits: Optional[torch.FloatTensor] = None
    # Forwarded unchanged from the backbone output.
    cache_params: Optional[xLSTMCache] = None
    # Forwarded unchanged from the backbone output.
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
@auto_docstring
class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
    def __init__(self, config):
        super().__init__(config)
        self.backbone = xLSTMModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Initialize weights and apply final processing
        self.post_init()

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def get_input_embeddings(self):
        return self.backbone.get_input_embeddings()

    def set_input_embeddings(self, new_embeddings):
        return self.backbone.set_input_embeddings(new_embeddings)

    def prepare_inputs_for_generation(
        self,
        input_ids,
        attention_mask=None,  # not used but needed, otherwise generate complains when passing tokenizer inputs
        inputs_embeds=None,
        use_cache=None,
        cache_params: Optional[xLSTMCache] = None,
        **kwargs,
    ):
        """Assemble the keyword arguments for the next forward call during generation."""
        has_cache = cache_params is not None

        if use_cache and has_cache:
            # With an existing cache we assume we are in generation mode: the cache holds the state
            # before the last (most recently generated) token and all earlier tokens are already
            # ingested, so only the last position is fed.
            # This should as well support generation from scratch with the [BOS] token inserted first.
            input_ids = input_ids[:, -1:]
            if inputs_embeds is not None:
                inputs_embeds = inputs_embeds[:, -1:]

        # Embeddings are only used when there is no cache yet; otherwise token ids are used.
        if inputs_embeds is None or has_cache:
            model_inputs = {"input_ids": input_ids}
        else:
            model_inputs = {"inputs_embeds": inputs_embeds}

        model_inputs["cache_params"] = cache_params
        model_inputs["use_cache"] = use_cache

        # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
        for name, value in kwargs.items():
            model_inputs.setdefault(name, value)

        return model_inputs

    @can_return_tuple
    @auto_docstring
    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        cache_params: Optional[xLSTMCache] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        **kwargs,
    ) -> Union[tuple, xLSTMCausalLMOutput]:
        r"""
        cache_params (`xLSTMCache`, *optional*):
            The xLSTMCache that carries the RNN states.
        """
        backbone_outputs = self.backbone(
            input_ids,
            cache_params=cache_params,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            **kwargs,
        )
        last_hidden = backbone_outputs[0]

        logits = self.lm_head(last_hidden.to(self.lm_head.weight.dtype)).float()

        sequence_length = logits.shape[1]
        if self.training or not (self.config.max_inference_chunksize < sequence_length):
            logits = soft_cap(logits, self.config.output_logit_soft_cap)
        else:
            # Long inference outputs are soft-capped chunk by chunk, writing back in place.
            start = 0
            with torch.no_grad():
                while start < logits.shape[1]:
                    stop = min(start + self.config.max_inference_chunksize, logits.shape[1])
                    logits[:, start:stop] = soft_cap(
                        logits[:, start:stop],
                        self.config.output_logit_soft_cap,
                    )
                    start += self.config.max_inference_chunksize

        loss = None
        if labels is not None:
            # move labels to correct device to enable model parallelism
            labels = labels.to(logits.device)
            # Position t predicts token t + 1: drop the last logit and the first label.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_labels = labels[..., 1:].contiguous()
            # Flatten the tokens
            vocab_dim = shift_logits.size(-1)
            criterion = CrossEntropyLoss()
            loss = criterion(shift_logits.view(-1, vocab_dim), shift_labels.view(-1))

        return xLSTMCausalLMOutput(
            loss=loss,
            logits=logits,
            cache_params=backbone_outputs.cache_params,
            hidden_states=backbone_outputs.hidden_states,
        )
# Names exported by a star-import of this module.
__all__ = [
    "xLSTMForCausalLM",
    "xLSTMModel",
    "xLSTMPreTrainedModel",
]
|