modeling_xlstm.py 65 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002100310041005100610071008100910101011101210131014101510161017101810191020102110221023102410251026102710281029103010311032103310341035103610371038103910401041104210431044104510461047104810491050105110521053105410551056105710581059106010611062106310641065106610671068106910701071107210731074107510761077107810791080108110821083108410851086108710881089109010911092109310941095109610971098109911001101110211031104110511061107110811091110111111121113111411151116111711181119112011211122112311241125112611271128112911301131113211331134113511361137113811391140114111421143114411451146114711481149115011511152115311541155115611571158115911601161116211631164116511661167116811691170117111721173117411751176117711781179118011811182118311841185118611871188118911901191119211931194119511961197119811991200120112021203120412051206120712081209121012111212121312141215121612171218121912201221122212231224122512261227122812291230123112321233123412351236123712381239124012411242124312441245124612471248124912501251125212531254125512561257125812591260126112621263126412651266126712681269127012711272127312741275127612771278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629
  1. # Copyright 2025 NXAI GmbH. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch xLSTM Model."""
  15. from dataclasses import dataclass
  16. from typing import Optional, Union
  17. import torch
  18. import torch.nn.functional as F
  19. from torch import nn
  20. from torch.nn import CrossEntropyLoss
  21. from ...generation import GenerationMixin
  22. from ...modeling_utils import PreTrainedModel
  23. from ...utils import ModelOutput, auto_docstring, can_return_tuple, is_xlstm_available
  24. from .configuration_xlstm import xLSTMConfig
  25. if is_xlstm_available():
  26. from xlstm.xlstm_large.model import RMSNorm as xLSTMRMSNorm
  27. from xlstm.xlstm_large.model import mLSTMBlock as xLSTMBlock
  28. from xlstm.xlstm_large.model import mLSTMStateType, soft_cap
  29. external_xlstm = True
  30. else:
  31. from functools import partial
  32. from typing import Callable, Literal
  33. from .configuration_xlstm import round_up_to_next_multiple_of
  34. mLSTMLayerStateType = tuple[torch.Tensor, torch.Tensor, torch.Tensor]
  35. mLSTMStateType = dict[int, mLSTMLayerStateType]
  36. external_xlstm = False
  37. def soft_cap(values: torch.Tensor, cap_value: Optional[Union[float, torch.Tensor]] = None) -> torch.Tensor:
  38. """
  39. Soft caps a tensor to a value.
  40. Performs a tanh operation on the logits and scales the result to the cap value. Common technique in attention
  41. and output language heads to prevent large logits from dominating the softmax. See for example Gemma2:
  42. https://huggingface.co/papers/2408.00118
  43. Args:
  44. values: The tensor to cap.
  45. cap_value: The value to cap the values to. If None, no cap is applied.
  46. Returns:
  47. The capped values.
  48. """
  49. if cap_value is None:
  50. return values
  51. return cap_value * torch.tanh(values / cap_value)
  52. def mlstm_chunkwise_recurrent_fw_C(
  53. matK: torch.Tensor,
  54. matV: torch.Tensor,
  55. vecB: torch.Tensor,
  56. vecI: torch.Tensor,
  57. matC_states: Optional[torch.Tensor] = None,
  58. vecN_states: Optional[torch.Tensor] = None,
  59. scaMinter_states: Optional[torch.Tensor] = None,
  60. matC_initial: Optional[torch.Tensor] = None,
  61. vecN_initial: Optional[torch.Tensor] = None,
  62. scaMinter_initial: Optional[torch.Tensor] = None,
  63. qk_scale: Optional[float] = None,
  64. chunk_size: int = 64,
  65. num_chunks: int = 1,
  66. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  67. batch_size, nh, _, dhqk, dhhv = *matK.shape, matV.shape[-1]
  68. nc = num_chunks
  69. _dtype, _device = matK.dtype, matK.device
  70. if qk_scale is None:
  71. qk_scale = dhqk**-0.5
  72. # initialize the states tensors
  73. if matC_states is None:
  74. matC_states = torch.zeros((batch_size, nh, (nc + 1) * dhqk, dhhv), dtype=_dtype, device=_device)
  75. if vecN_states is None:
  76. vecN_states = torch.zeros((batch_size, nh, (nc + 1) * dhqk), dtype=_dtype, device=_device)
  77. if scaMinter_states is None:
  78. scaMinter_states = torch.zeros((batch_size, nh, (nc + 1)), dtype=_dtype, device=_device)
  79. # assign the initial states to the running states
  80. matC_k = (
  81. torch.zeros((batch_size, nh, dhqk, dhhv), dtype=_dtype, device=_device)
  82. if matC_initial is None
  83. else matC_initial
  84. )
  85. vecN_k = (
  86. torch.zeros((batch_size, nh, dhqk), dtype=_dtype, device=_device) if vecN_initial is None else vecN_initial
  87. )
  88. scaM_inter_k = (
  89. torch.zeros((batch_size, nh, 1), dtype=_dtype, device=_device)
  90. if scaMinter_initial is None
  91. else scaMinter_initial
  92. )
  93. vecA = vecB[..., -1, None] - vecB + vecI
  94. scaG = vecB[..., -1]
  95. scaA_max = vecA.max(-1).values
  96. scaM_inter_k = scaM_inter_k.squeeze(-1)
  97. for key in range(0, num_chunks):
  98. # store the states from the previous iteration before updating them
  99. # in the first iteration, these are the initial states
  100. matC_states[:, :, key * dhqk : (key + 1) * dhqk, :] = matC_k
  101. vecN_states[:, :, key * dhqk : (key + 1) * dhqk] = vecN_k
  102. scaMinter_states[:, :, key] = scaM_inter_k
  103. # m_k update
  104. scaA_max_k = scaA_max[:, :, key]
  105. scaG_k = scaG[:, :, key]
  106. scaM_inter_k_next = torch.max(scaG_k + scaM_inter_k, scaA_max_k)
  107. # C_k update
  108. matK_chunk = matK[:, :, key * chunk_size : (key + 1) * chunk_size, :] # * qk_scale
  109. matV_chunk = matV[:, :, key * chunk_size : (key + 1) * chunk_size, :]
  110. vecA_k = vecA[:, :, key, :]
  111. vecAbar_k = torch.exp(vecA_k - scaM_inter_k_next[..., None])[:, :, :, None]
  112. matK_chunk_gated = matK_chunk * vecAbar_k
  113. scaGbar_k = torch.exp(scaG_k + scaM_inter_k - scaM_inter_k_next)[:, :, None]
  114. # NOTE: no update in-place (i.e. +=) as this gives error for autograd backward
  115. matC_k_next = scaGbar_k[..., None] * matC_k + matK_chunk_gated.transpose(-2, -1) @ (matV_chunk)
  116. # n_k update
  117. vecN_k_next = scaGbar_k * vecN_k + matK_chunk_gated.transpose(-2, -1).sum(-1)
  118. # move to the next iteration
  119. scaM_inter_k = scaM_inter_k_next
  120. matC_k = matC_k_next
  121. vecN_k = vecN_k_next
  122. # store the states from the last iteration
  123. matC_states[:, :, -dhqk:, :] = matC_k
  124. vecN_states[:, :, -dhqk:] = vecN_k
  125. scaMinter_states[:, :, -1] = scaM_inter_k
  126. return matC_states, vecN_states, scaMinter_states
  127. def mlstm_chunkwise_parallel_fw_H(
  128. matQ: torch.Tensor,
  129. matK: torch.Tensor,
  130. matV: torch.Tensor,
  131. # these states must be all states up to the last chunk, i.e. :-1
  132. matC_states: torch.Tensor,
  133. vecN_states: torch.Tensor,
  134. scaMinter_states: torch.Tensor,
  135. vecI: torch.Tensor,
  136. vecB: torch.Tensor,
  137. qk_scale: float,
  138. chunk_size: int = 64,
  139. num_chunks: int = 1,
  140. eps: float = 1e-6,
  141. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  142. _device = matQ.device
  143. nc = num_chunks
  144. batch_size, nh, dqk, dhv = matC_states.shape
  145. matC_k_states = matC_states.view(batch_size, nh, nc, dqk // nc, dhv)
  146. vecN_k_states = vecN_states.view(batch_size, nh, nc, dqk // nc)
  147. scaMinter_k_states = scaMinter_states
  148. matQ = matQ.view(batch_size, nh, nc, chunk_size, dqk)
  149. matK = matK.view(batch_size, nh, nc, chunk_size, dqk)
  150. matV = matV.view(batch_size, nh, nc, chunk_size, dhv)
  151. ltr = torch.tril(
  152. torch.ones(
  153. (chunk_size, chunk_size),
  154. dtype=torch.bool,
  155. device=_device,
  156. )
  157. )
  158. # Compute intra chunk contribution: H_intra
  159. matF_logsig_chunk = vecB[:, :, :, :, None] - vecB[:, :, :, None, :]
  160. matF_logsig_mask_chunk = torch.where(ltr, matF_logsig_chunk, -float("inf"))
  161. matLogD_chunk = matF_logsig_mask_chunk + vecI[:, :, :, None, :]
  162. # max_state intra
  163. vecMintra_k = torch.max(matLogD_chunk, dim=-1, keepdim=False).values
  164. # max_state combined
  165. vecM_b_inter = vecB + scaMinter_k_states[:, :, :, None]
  166. vecM_k_combine = torch.maximum(vecM_b_inter, vecMintra_k)
  167. vecM_k_combine = vecM_k_combine[:, :, :, :, None]
  168. vecM_b_inter = vecM_b_inter[:, :, :, :, None]
  169. matLogD_stabilized_chunk = matLogD_chunk - vecM_k_combine
  170. matD_chunk = torch.exp(matLogD_stabilized_chunk)
  171. matS_chunk = (matQ @ matK.transpose(-2, -1)) * qk_scale
  172. matM_chunk = matS_chunk * matD_chunk
  173. # ? Combine H_intra with H_inter
  174. vecBbar = torch.exp(vecM_b_inter - vecM_k_combine)
  175. matQ_chunk_gated = matQ * vecBbar * qk_scale
  176. matNumerator_common = matQ_chunk_gated @ matC_k_states + matM_chunk @ matV
  177. vecDenom_l_common = matQ_chunk_gated @ vecN_k_states.unsqueeze(-1) + matM_chunk.sum(dim=-1, keepdim=True)
  178. vecDenom_max_common = torch.maximum(torch.abs(vecDenom_l_common), torch.exp(-vecM_k_combine))
  179. matH_k_chunk = matNumerator_common / (vecDenom_max_common + eps)
  180. matH_out = matH_k_chunk.view(batch_size, nh, nc * chunk_size, dhv)
  181. # we need the denominator and the overall max state for the backward pass
  182. vecN_out = vecDenom_max_common.reshape(batch_size, nh, nc * chunk_size)
  183. vecM_out = vecM_k_combine(batch_size, nh, nc * chunk_size)
  184. return matH_out, vecN_out, vecM_out
  185. def mlstm_chunkwise_fw(
  186. query: torch.Tensor,
  187. key: torch.Tensor,
  188. value: torch.Tensor,
  189. igate: torch.Tensor,
  190. fgate: torch.Tensor,
  191. cstate: Optional[torch.Tensor] = None,
  192. nstate: Optional[torch.Tensor] = None,
  193. mstate: Optional[torch.Tensor] = None,
  194. qk_scale: Optional[float] = None,
  195. return_last_states: bool = False,
  196. return_all_states: bool = False,
  197. chunk_size: int = 64,
  198. eps: float = 1e-6,
  199. ) -> tuple[
  200. torch.Tensor,
  201. torch.Tensor,
  202. torch.Tensor,
  203. Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
  204. Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
  205. ]:
  206. batch_size, nh, sequence_length, dhqk = query.shape
  207. if sequence_length % chunk_size != 0:
  208. raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")
  209. nc = sequence_length // chunk_size
  210. vecI = igate.view(batch_size, nh, nc, chunk_size)
  211. vecF = fgate.view(batch_size, nh, nc, chunk_size)
  212. # compute the gates, the g and the a and b vectors
  213. vecF_logsig = fgate.logsigmoid(vecF)
  214. vecB = vecF_logsig.cumsum(-1)
  215. if qk_scale is None:
  216. qk_scale = dhqk**-0.5
  217. #! materialize the C_k, n_k, m_k states for each chunk
  218. matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise_recurrent_fw_C(
  219. matK=key,
  220. matV=value,
  221. vecB=vecB,
  222. vecI=vecI,
  223. matC_initial=cstate,
  224. vecN_initial=nstate,
  225. scaMinter_initial=mstate,
  226. qk_scale=qk_scale,
  227. chunk_size=chunk_size,
  228. num_chunks=nc,
  229. )
  230. #! compute the outputs within each chunk
  231. matH_out, vecN_out, vecM_out = mlstm_chunkwise_parallel_fw_H(
  232. matQ=query,
  233. matK=key,
  234. matV=value,
  235. matC_states=matC_k_states[:, :, :-dhqk, :],
  236. vecN_states=vecN_k_states[:, :, :-dhqk],
  237. scaMinter_states=scaMinter_k_states[:, :, :-1],
  238. vecI=vecI,
  239. vecB=vecB,
  240. qk_scale=qk_scale,
  241. chunk_size=chunk_size,
  242. num_chunks=nc,
  243. eps=eps,
  244. )
  245. ret_tuple = (matH_out, vecN_out, vecM_out)
  246. if return_last_states:
  247. ret_tuple += (
  248. (matC_k_states[:, :, -dhqk:, :], vecN_k_states[:, :, -dhqk:], scaMinter_k_states[:, :, -1:]),
  249. )
  250. else:
  251. ret_tuple += (None,)
  252. if return_all_states:
  253. ret_tuple += ((matC_k_states, vecN_k_states, scaMinter_k_states),)
  254. else:
  255. ret_tuple += (None,)
  256. return ret_tuple
  257. def mlstm_chunkwise_native_autograd(
  258. query: torch.Tensor,
  259. key: torch.Tensor,
  260. value: torch.Tensor,
  261. igate: torch.Tensor,
  262. fgate: torch.Tensor,
  263. c_initial: Optional[torch.Tensor] = None,
  264. n_initial: Optional[torch.Tensor] = None,
  265. m_initial: Optional[torch.Tensor] = None,
  266. return_last_states: bool = False,
  267. eps: float = 1e-6,
  268. chunk_size: int = 64,
  269. **kwargs,
  270. ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
  271. batch_size, nh, sequence_length, dhqk = query.shape
  272. if sequence_length % chunk_size != 0:
  273. raise ValueError(f"Sequence length {sequence_length} is not divisible by chunk size {chunk_size}.")
  274. nc = sequence_length // chunk_size
  275. vecI = igate.view(batch_size, nh, nc, chunk_size)
  276. vecF = fgate.view(batch_size, nh, nc, chunk_size)
  277. # compute the gates, the g and the a and b vectors
  278. vecF_logsig = F.logsigmoid(vecF)
  279. vecB = vecF_logsig.cumsum(-1)
  280. qk_scale = dhqk**-0.5
  281. #! materialize the C_k, n_k, m_k states for each chunk
  282. matC_k_states, vecN_k_states, scaMinter_k_states = mlstm_chunkwise_recurrent_fw_C(
  283. matK=key,
  284. matV=value,
  285. vecB=vecB,
  286. vecI=vecI,
  287. matC_initial=c_initial,
  288. vecN_initial=n_initial,
  289. scaMinter_initial=m_initial,
  290. qk_scale=qk_scale,
  291. chunk_size=chunk_size,
  292. num_chunks=nc,
  293. )
  294. #! compute the outputs within each chunk
  295. matH_out, vecN_out, vecM_out = mlstm_chunkwise_parallel_fw_H(
  296. matQ=query,
  297. matK=key,
  298. matV=value,
  299. matC_states=matC_k_states[:, :, :-dhqk, :],
  300. vecN_states=vecN_k_states[:, :, :-dhqk],
  301. scaMinter_states=scaMinter_k_states[:, :, :-1],
  302. vecI=vecI,
  303. vecB=vecB,
  304. qk_scale=qk_scale,
  305. chunk_size=chunk_size,
  306. num_chunks=nc,
  307. eps=eps,
  308. )
  309. last_states = (matC_k_states[:, :, -dhqk:, :], vecN_k_states[:, :, -dhqk:], scaMinter_k_states[:, :, -1:])
  310. if return_last_states:
  311. return matH_out, last_states
  312. else:
  313. return matH_out
  314. def mlstm_recurrent_step_native(
  315. query: torch.Tensor,
  316. key: torch.Tensor,
  317. value: torch.Tensor,
  318. igate: torch.Tensor,
  319. fgate: torch.Tensor,
  320. cstate: torch.Tensor,
  321. nstate: torch.Tensor,
  322. mstate: torch.Tensor,
  323. eps: float = 1e-6,
  324. dtype_state: torch.dtype = torch.float32,
  325. **kwargs,
  326. ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]:
  327. """This is a single step of the mLSTM operation in recurrent form."""
  328. dtype_qkv = query.dtype
  329. matC_old = cstate.to(dtype=dtype_state)
  330. vecN_old = nstate.to(dtype=dtype_state)
  331. scaM_old = mstate.to(dtype=dtype_state)
  332. batch_size, nh, dhqk = query.shape
  333. _, _, dhhv = value.shape
  334. if query.shape != key.shape:
  335. raise ValueError("query and key must have the same shape")
  336. if matC_old.shape != (batch_size, nh, dhqk, dhhv):
  337. raise ValueError(f"matC_old has wrong shape, got {matC_old.shape}")
  338. if vecN_old.shape != (batch_size, nh, dhqk):
  339. raise ValueError(f"vecN_old has wrong shape, got {vecN_old.shape}")
  340. if scaM_old.shape != (batch_size, nh, 1):
  341. raise ValueError(f"scaM_old has wrong shape, got {scaM_old.shape}")
  342. if igate.shape != (batch_size, nh, 1):
  343. raise ValueError(f"scaI has wrong shape, got {igate.shape}")
  344. if fgate.shape != (batch_size, nh, 1):
  345. raise ValueError(f"scaF has wrong shape, got {fgate.shape}")
  346. # gates
  347. scaF_log = torch.nn.functional.logsigmoid(fgate)
  348. # update rule
  349. scaM_state_new = torch.max(scaF_log + scaM_old, igate)
  350. scaF_act = torch.exp(scaF_log + scaM_old - scaM_state_new)
  351. scaI_act = torch.exp(igate - scaM_state_new)
  352. vecQ_scaled = query * (dhqk ** (-0.5))
  353. matC_state_new = scaF_act[:, :, :, None] * matC_old + scaI_act[:, :, :, None] * (
  354. key[:, :, :, None] @ value[:, :, None, :]
  355. )
  356. vecN_state_new = scaF_act * vecN_old + scaI_act * key
  357. h_num = vecQ_scaled[:, :, None, :] @ matC_state_new.to(dtype=dtype_qkv)
  358. h_num = h_num.squeeze(2).to(dtype=dtype_state)
  359. qn_dotproduct = vecQ_scaled[:, :, None, :] @ vecN_state_new[:, :, :, None].to(dtype=dtype_qkv)
  360. qn_dotproduct = qn_dotproduct.squeeze(2)
  361. max_val = torch.exp(-scaM_state_new)
  362. h_denom = (torch.maximum(qn_dotproduct.abs(), max_val) + eps).to(dtype=dtype_state)
  363. h = h_num / h_denom
  364. h = h.to(dtype=dtype_qkv)
  365. matC_state_new = matC_state_new.to(dtype=dtype_state)
  366. vecN_state_new = vecN_state_new.to(dtype=dtype_state)
  367. scaM_state_new = scaM_state_new.to(dtype=dtype_state)
  368. return h, (matC_state_new, vecN_state_new, scaM_state_new)
  369. def mlstm_recurrent_sequence_native(
  370. query: torch.Tensor,
  371. key: torch.Tensor,
  372. value: torch.Tensor,
  373. igate: torch.Tensor,
  374. fgate: torch.Tensor,
  375. c_initial: Optional[torch.Tensor] = None,
  376. n_initial: Optional[torch.Tensor] = None,
  377. m_initial: Optional[torch.Tensor] = None,
  378. return_last_states: bool = False,
  379. eps: float = 1e-6,
  380. dtype_state: torch.dtype = torch.float32,
  381. **kwargs,
  382. ) -> tuple[
  383. torch.Tensor,
  384. torch.Tensor,
  385. torch.Tensor,
  386. Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
  387. Optional[tuple[torch.Tensor, torch.Tensor, torch.Tensor]],
  388. ]:
  389. batch_size, nh, sequence_length, dhqk = query.shape
  390. dhv = value.shape[-1]
  391. device = query.device
  392. if c_initial is not None:
  393. if n_initial is None or m_initial is None:
  394. raise ValueError("Initial states must be provided together.")
  395. if n_initial is None or m_initial is None:
  396. raise ValueError("Initial states must be provided together.")
  397. matC_state, vecN_state, vecM_state = (
  398. c_initial.to(dtype=dtype_state),
  399. n_initial.to(dtype=dtype_state),
  400. m_initial.to(dtype=dtype_state),
  401. )
  402. else:
  403. # memory state
  404. matC_state = torch.zeros((batch_size, nh, dhqk, dhv), dtype=dtype_state, device=device)
  405. # normalizer state
  406. vecN_state = torch.zeros((batch_size, nh, dhqk), dtype=dtype_state, device=device)
  407. # max state
  408. vecM_state = torch.zeros((batch_size, nh, 1), dtype=dtype_state, device=device)
  409. vecH_list = []
  410. for t in range(sequence_length):
  411. # gates
  412. vecF_t, vecI_t = fgate[:, :, t, None], igate[:, :, t, None]
  413. # projections
  414. vecQ_t, vecK_t, vecV_t = query[:, :, t, :], key[:, :, t, :], value[:, :, t, :]
  415. # step
  416. vecH, (matC_state, vecN_state, vecM_state) = mlstm_recurrent_step_native(
  417. cstate=matC_state,
  418. nstate=vecN_state,
  419. mstate=vecM_state,
  420. query=vecQ_t,
  421. key=vecK_t,
  422. value=vecV_t,
  423. igate=vecI_t,
  424. fgate=vecF_t,
  425. eps=eps,
  426. dtype_state=dtype_state,
  427. **kwargs,
  428. )
  429. vecH_list.append(vecH)
  430. matH = torch.stack(vecH_list, dim=-2)
  431. if return_last_states:
  432. return matH, (matC_state, vecN_state, vecM_state)
  433. else:
  434. return matH
  435. def wrap_chunkwise_pad_zeros(
  436. mlstm_chunkwise_kernel: Callable,
  437. query: torch.Tensor,
  438. key: torch.Tensor,
  439. value: torch.Tensor,
  440. fgate: torch.Tensor,
  441. igate: torch.Tensor,
  442. c_initial: Optional[torch.Tensor] = None,
  443. n_initial: Optional[torch.Tensor] = None,
  444. m_initial: Optional[torch.Tensor] = None,
  445. return_last_states: bool = False,
  446. eps: float = 1e-6,
  447. autocast_kernel_dtype: torch.dtype = torch.bfloat16,
  448. chunk_size: int = 64,
  449. **kwargs,
  450. ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
  451. if return_last_states:
  452. raise ValueError(
  453. "We are padding zeros, so we cannot return last states,",
  454. "as they would be not the true last states.",
  455. )
  456. batch_size, nh, sequence_length, dhqk = query.shape
  457. S_unpadded = sequence_length
  458. # padding to chunk size for kernels
  459. if sequence_length % chunk_size != 0:
  460. S_padded = ((sequence_length + chunk_size - 1) // chunk_size) * chunk_size
  461. q_pad = query.new_zeros(batch_size, nh, S_padded, query.shape[3])
  462. k_pad = key.new_zeros(batch_size, nh, S_padded, key.shape[3])
  463. v_pad = value.new_zeros(batch_size, nh, S_padded, value.shape[3])
  464. i_pad = igate.new_zeros(batch_size, nh, S_padded)
  465. f_pad = fgate.new_zeros(batch_size, nh, S_padded)
  466. q_pad[:, :, :S_unpadded, :] = query
  467. k_pad[:, :, :S_unpadded, :] = key
  468. v_pad[:, :, :S_unpadded, :] = value
  469. i_pad[:, :, :S_unpadded] = igate
  470. f_pad[:, :, :S_unpadded] = fgate
  471. else:
  472. q_pad = query
  473. k_pad = key
  474. v_pad = value
  475. i_pad = igate
  476. f_pad = fgate
  477. matH = mlstm_chunkwise_kernel(
  478. query=q_pad,
  479. key=k_pad,
  480. value=v_pad,
  481. igate=i_pad,
  482. fgate=f_pad,
  483. c_initial=c_initial,
  484. n_initial=n_initial,
  485. m_initial=m_initial,
  486. return_last_states=return_last_states,
  487. eps=eps,
  488. autocast_kernel_dtype=autocast_kernel_dtype,
  489. chunk_size=chunk_size,
  490. **kwargs,
  491. )
  492. matH = matH[:, :, :S_unpadded, :]
  493. return matH
  494. def wrap_chunkwise_arbitrary_sequence_length(
  495. mlstm_chunkwise_kernel: Callable,
  496. mlstm_sequence_kernel: Callable,
  497. mlstm_step_kernel: Callable,
  498. query: torch.Tensor,
  499. key: torch.Tensor,
  500. value: torch.Tensor,
  501. fgate: torch.Tensor,
  502. igate: torch.Tensor,
  503. c_initial: Optional[torch.Tensor] = None,
  504. n_initial: Optional[torch.Tensor] = None,
  505. m_initial: Optional[torch.Tensor] = None,
  506. return_last_states: bool = True,
  507. eps: float = 1e-6,
  508. autocast_kernel_dtype: torch.dtype = torch.bfloat16,
  509. chunk_size: int = 64,
  510. enable_logging: bool = False,
  511. ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
  512. """This function computes the last hidden state and matH outputs of the mLSTM, independently of the sequence length.
  513. For this it uses three kernels:
  514. - mlstm_chunkwise_kernel: mlstm chunkwise kernels that processes chunks of a given chunk size in parallel.
  515. - mlstm_sequence_kernel: mlstm kernel that processes the remaining sequence length in a single step recurrence.
  516. - mlstm_step_kernel: mlstm kernel that processes a sequence length of 1 in a single step.
  517. It tries to maximize the chunksizes to improve performance.
  518. It will start with the given chunk size and then divides the chunksize by 2 until the chunk size is smaller than 16.
  519. At every chunksize it will process the maximal number of chunks that fit into the remaining sequence length.
  520. E.g. for chunk_size = 64, this function will try the chunksizes [64, 32, 16] if necessary.
  521. For the remaining sequence length, which is smaller than 16, we use a different kernel that computes the mLSTM
  522. in a single step and loop over this in pytorch.
  523. Args:
  524. mlstm_chunkwise_kernel: The mLSTM chunkwise kernel that processes chunks of a given chunk size in parallel
  525. mlstm_sequence_kernel: The mLSTM kernel that processes the remaining sequence length in a single step recurrence
  526. query: The query tensor (batch_size, nh, sequence_length, dhqk)
  527. key: The key tensor (batch_size, nh, sequence_length, dhqk)
  528. value: The value tensor (batch_size, nh, sequence_length, dhhv)
  529. fgate: The forget gate tensor (batch_size, nh, sequence_length)
  530. igate: The input gate tensor (batch_size, nh, sequence_length)
  531. c_initial: The initial cell state tensor (batch_size, nh, dhqk, dhhv)
  532. n_initial: The initial hidden state tensor (batch_size, nh, dhqk)
  533. m_initial: The initial memory state tensor (batch_size, nh, 1)
  534. return_last_states: If True, the function will return the last states of the mLSTM
  535. eps: The epsilon value used for numerical stability
  536. autocast_kernel_dtype: The dtype used for the kernel computation
  537. chunk_size: The chunk size used for the chunkwise kernel
  538. enable_logging: If True, the function will log debug information. Default is False.
  539. Returns:
  540. The last hidden state tensor (batch_size, nh, sequence_length, dhhv) or a tuple containing the last hidden state tensor and the last states of the mLSTM
  541. Last states are (cstate (batch_size, nh, dhqk, dhhv), nstate (batch_size, nh, dhqk), mstate (batch_size, nh, 1)).
  542. """
  543. batch_size, nh, sequence_length, dhqk = key.shape
  544. dhhv = value.shape[-1]
  545. c_state = (
  546. c_initial
  547. if c_initial is not None
  548. else torch.zeros(batch_size, nh, dhqk, dhhv, device=key.device, dtype=torch.float32)
  549. )
  550. n_state = (
  551. n_initial
  552. if n_initial is not None
  553. else torch.zeros(batch_size, nh, dhqk, device=key.device, dtype=torch.float32)
  554. )
  555. m_state = (
  556. m_initial
  557. if m_initial is not None
  558. else torch.zeros(batch_size, nh, 1, device=key.device, dtype=torch.float32)
  559. )
  560. if sequence_length > 1:
  561. # process the sequence length in chunks
  562. h_outs = []
  563. seq_len_start_idx = 0
  564. remaining_seq_len = sequence_length - seq_len_start_idx
  565. num_chunks = remaining_seq_len // chunk_size
  566. if num_chunks > 0:
  567. iter_seq_len = chunk_size * num_chunks
  568. seq_len_idx = seq_len_start_idx + iter_seq_len
  569. h_out, (c_state, n_state, m_state) = mlstm_chunkwise_kernel(
  570. query=query[..., seq_len_start_idx:seq_len_idx, :].contiguous(),
  571. key=key[..., seq_len_start_idx:seq_len_idx, :].contiguous(),
  572. value=value[..., seq_len_start_idx:seq_len_idx, :].contiguous(),
  573. fgate=fgate[..., seq_len_start_idx:seq_len_idx].contiguous(),
  574. igate=igate[..., seq_len_start_idx:seq_len_idx].contiguous(),
  575. c_initial=c_state,
  576. n_initial=n_state,
  577. m_initial=m_state,
  578. chunk_size=chunk_size,
  579. return_last_states=True,
  580. autocast_kernel_dtype=autocast_kernel_dtype,
  581. eps=eps,
  582. )
  583. seq_len_start_idx += iter_seq_len
  584. h_outs.append(h_out)
  585. remaining_seq_len = sequence_length - seq_len_start_idx
  586. if remaining_seq_len > 0:
  587. # we use here matK as query as this kernel does not need a query, since we do not care about the outputs only about the last state
  588. h_out, (c_state, n_state, m_state) = mlstm_sequence_kernel(
  589. query=query[..., seq_len_start_idx:sequence_length, :].contiguous(),
  590. key=key[..., seq_len_start_idx:sequence_length, :].contiguous(),
  591. value=value[..., seq_len_start_idx:sequence_length, :].contiguous(),
  592. igate=igate[..., seq_len_start_idx:sequence_length].contiguous(),
  593. fgate=fgate[..., seq_len_start_idx:sequence_length].contiguous(),
  594. c_initial=c_state,
  595. n_initial=n_state,
  596. m_initial=m_state,
  597. return_last_states=True,
  598. eps=eps,
  599. )
  600. h_outs.append(h_out)
  601. h_out = torch.concatenate(h_outs, dim=2)
  602. else:
  603. if sequence_length != 1:
  604. raise ValueError(
  605. f"Received empty sequence (sequence_length={sequence_length}), require at least single element in the sequence."
  606. )
  607. # process the sequence length in a single step
  608. # while this case is also captured by the regular mode above,
  609. # it avoids the overhead of the loop and calls the step kernel directly
  610. # The step function does not want a sequence dimension
  611. # qkv shape is (batch_size, nh, dhqk/dhv)
  612. # igate, fgate shape is (batch_size, nh, 1)
  613. h_out, (c_state, n_state, m_state) = mlstm_step_kernel(
  614. query=query.squeeze(2),
  615. key=key.squeeze(2),
  616. value=value.squeeze(2),
  617. igate=igate,
  618. fgate=fgate,
  619. cstate=c_state,
  620. nstate=n_state,
  621. mstate=m_state,
  622. eps=eps,
  623. )
  624. h_out = h_out[:, :, None, :]
  625. if return_last_states:
  626. return h_out, (c_state, n_state, m_state)
  627. else:
  628. return h_out
  629. class xLSTMBackend(nn.Module):
  630. """xLSTM Backend Module for PyTorch.
  631. This module wraps the xLSTM kernels and provides a high-level interface for training and inference.
  632. """
  633. config_class = xLSTMConfig
  634. def __init__(self, config: xLSTMConfig):
  635. super().__init__()
  636. self.config = config
  637. self.chunkwise_kernel_fn = mlstm_chunkwise_native_autograd
  638. self.sequence_kernel_fn = mlstm_recurrent_sequence_native
  639. self.step_kernel_fn = mlstm_recurrent_step_native
  640. self._inference_fn = partial(
  641. wrap_chunkwise_arbitrary_sequence_length,
  642. mlstm_chunkwise_kernel=self.chunkwise_kernel_fn,
  643. mlstm_sequence_kernel=partial(
  644. self.sequence_kernel_fn,
  645. dtype_state=getattr(torch, config.inference_state_dtype),
  646. ),
  647. mlstm_step_kernel=partial(
  648. self.step_kernel_fn,
  649. dtype_state=getattr(torch, config.inference_state_dtype),
  650. ),
  651. chunk_size=config.chunk_size,
  652. eps=config.eps,
  653. autocast_kernel_dtype=getattr(torch, config.autocast_kernel_dtype),
  654. return_last_states=True,
  655. )
  656. train_kernel_fn = partial(
  657. self.chunkwise_kernel_fn,
  658. autocast_kernel_dtype=getattr(torch, config.autocast_kernel_dtype),
  659. eps=config.eps,
  660. chunk_size=config.chunk_size,
  661. )
  662. if "with_padding" in config.mode:
  663. train_kernel_fn = partial(wrap_chunkwise_pad_zeros, mlstm_chunkwise_kernel=train_kernel_fn)
  664. self._train_fn = train_kernel_fn
  665. def forward(
  666. self,
  667. query: torch.Tensor,
  668. key: torch.Tensor,
  669. value: torch.Tensor,
  670. igate: torch.Tensor,
  671. fgate: torch.Tensor,
  672. c_initial: Optional[torch.Tensor] = None,
  673. n_initial: Optional[torch.Tensor] = None,
  674. m_initial: Optional[torch.Tensor] = None,
  675. return_last_states: bool = False,
  676. mode: Optional[Literal["train", "inference"]] = None,
  677. ) -> Union[torch.Tensor, tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor, torch.Tensor]]]:
  678. """Forward pass of the mLSTM backend.
  679. Depending on the configured mode, this method will call the appropriate kernel function.
  680. Args:
  681. query: The query tensor of shape (batch_size, nh, sequence_length, dhqk).
  682. key: The key tensor of shape (batch_size, nh, sequence_length, dhqk).
  683. value: The value tensor of shape (batch_size, nh, sequence_length, dhhv).
  684. igate: The input gate preactivation tensor of shape (batch_size, nh, sequence_length).
  685. fgate: The forget gate preactivation tensor of shape (batch_size, nh, sequence_length).
  686. c_initial: The initial cell state tensor of shape (batch_size, nh, dhqk, dhhv).
  687. Defaults to None.
  688. n_initial: The initial hidden state tensor of shape (batch_size, nh, dhqk). Defaults to None.
  689. m_initial: The initial memory tensor of shape (batch_size, nh, 1). Defaults to None.
  690. return_last_states: Whether to return the last states of the sequence. Defaults to None.
  691. If None, the value from the config is used.
  692. Returns:
  693. hidden states of shape (batch_size, nh, sequence_length, dhhv)
  694. hidden states and last states the last states are the cell state cstate (batch_size, nh, dhqk, dhhv),
  695. the normalizer state nstate (batch_size, nh, dhqk), and the max state mstate (batch_size, nh, 1)
  696. """
  697. if mode is None:
  698. mode = self.config.mode
  699. if "train" in mode:
  700. if return_last_states is None:
  701. return_last_states = self.config.return_last_states
  702. if self.config.mode == "train_with_padding":
  703. if return_last_states:
  704. raise ValueError("return_last_states=True is not supported with train_with_padding mode.")
  705. return self._train_fn(
  706. query=query,
  707. key=key,
  708. value=value,
  709. igate=igate,
  710. fgate=fgate,
  711. c_initial=c_initial,
  712. n_initial=n_initial,
  713. m_initial=m_initial,
  714. return_last_states=return_last_states,
  715. )
  716. elif "inference" in mode:
  717. # inference mode always returns the last states
  718. return self._inference_fn(
  719. query=query,
  720. key=key,
  721. value=value,
  722. igate=igate,
  723. fgate=fgate,
  724. c_initial=c_initial,
  725. n_initial=n_initial,
  726. m_initial=m_initial,
  727. )
  728. else:
  729. raise ValueError(f"Unknown mode: {self.config.mode}")
  730. def extra_repr(self) -> str:
  731. return f"{self.config}"
  732. class xLSTMRMSNorm(nn.Module):
  733. """Root mean square normalization layer implementation similar
  734. to https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html.
  735. It normalizes the input tensor by the root mean square of the last dimension.
  736. Args:
  737. num_features: The number of features in the input tensor.
  738. eps: A small value to avoid division by zero.
  739. use_weight: Whether to use a learnable weight.
  740. use_bias: Whether to use a learnable bias.
  741. force_float32_reductions: Whether to force float32 reductions.
  742. """
  743. def __init__(
  744. self,
  745. num_features: int,
  746. eps: float = 1e-6,
  747. use_weight: bool = True,
  748. use_bias: bool = False,
  749. force_float32_reductions: bool = True,
  750. ):
  751. super().__init__()
  752. self.num_features = num_features
  753. self.eps = eps
  754. self.force_float32_reductions = force_float32_reductions
  755. if use_weight:
  756. self.weight = nn.Parameter(torch.ones(num_features))
  757. else:
  758. self.weight = None
  759. if use_bias:
  760. self.bias = nn.Parameter(torch.zeros(num_features))
  761. else:
  762. self.bias = None
  763. def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
  764. if self.weight is not None:
  765. x = x * self.weight
  766. if self.bias is not None:
  767. x = x + self.bias
  768. return x
  769. def _rms_normalize(self, x: torch.Tensor) -> torch.Tensor:
  770. # apply rms norm over the last dimension, i.e. HD dimension
  771. in_dtype = x.dtype
  772. if self.force_float32_reductions:
  773. x = x.float()
  774. x = x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
  775. return x.to(in_dtype)
  776. def forward(self, x: torch.Tensor) -> torch.Tensor:
  777. x = self._rms_normalize(x)
  778. x = self._apply_weight_bias(x)
  779. return x
  780. class xLSTMMultiHeadLayerNorm(nn.Module):
  781. """Multi-head version of the LayerNorm layer.
  782. It normalizes the last dimension of the input tensor.
  783. The input is assumed to have the shape (batch_size, sequence_length, nh, DH), where:
  784. batch_size: batch size
  785. sequence_length: sequence length
  786. nh: number of heads
  787. DH: head dimension
  788. The normalization is applied over the last dimension (DH) of the input tensor.
  789. Args:
  790. num_heads: The number of heads.
  791. head_dim: The head dimension.
  792. eps: A small value to avoid division by zero.
  793. use_weight: Whether to use a learnable weight.
  794. use_bias: Whether to use a learnable bias.
  795. force_float32_reductions: Whether to force float32 reductions
  796. Returns:
  797. The normalized tensor with the shape (batch_size, sequence_length, nh * DH).
  798. """
  799. def __init__(
  800. self,
  801. num_heads: int,
  802. head_dim: int,
  803. eps: float = 1e-6,
  804. use_weight: bool = True,
  805. use_bias: bool = False,
  806. force_float32_reductions: bool = True,
  807. ):
  808. super().__init__()
  809. self.num_features = num_heads * head_dim
  810. self.eps = eps
  811. self.force_float32_reductions = force_float32_reductions
  812. if use_weight:
  813. self.weight = nn.Parameter(torch.ones(self.num_features))
  814. else:
  815. self.weight = None
  816. if use_bias:
  817. self.bias = nn.Parameter(torch.zeros(self.num_features))
  818. else:
  819. self.bias = None
  820. self.num_heads = num_heads
  821. self.head_dim = head_dim
  822. def _apply_weight_bias(self, x: torch.Tensor) -> torch.Tensor:
  823. if self.weight is not None:
  824. x = x * self.weight
  825. if self.bias is not None:
  826. x = x + self.bias
  827. return x
  828. def _layer_normalize(self, x: torch.Tensor) -> torch.Tensor:
  829. # apply layer norm over the last dimension, i.e. HD dimension
  830. in_dtype = x.dtype
  831. if self.force_float32_reductions:
  832. x = x.float()
  833. x_centered = x - x.mean(dim=-1, keepdim=True)
  834. y = x_centered * torch.rsqrt(x.var(dim=-1, keepdim=True, unbiased=False) + self.eps)
  835. return y.to(in_dtype)
  836. def forward(
  837. self,
  838. x: torch.Tensor,
  839. ) -> torch.Tensor:
  840. batch_size, sequence_length, nh, DH = x.shape
  841. if nh != self.num_heads:
  842. raise ValueError(f"Expected {self.num_heads} heads, got {nh}, input shape: {x.shape}")
  843. if DH != self.head_dim:
  844. raise ValueError(f"Expected {self.head_dim} head dimension, got {DH}, input shape: {x.shape}")
  845. x = self._layer_normalize(x)
  846. x = x.reshape(batch_size, sequence_length, -1)
  847. x = self._apply_weight_bias(x)
  848. return x
  849. class xLSTMFeedForward(nn.Module):
  850. def __init__(self, config: xLSTMConfig):
  851. super().__init__()
  852. self.config = config
  853. self.up_proj_dim = round_up_to_next_multiple_of(
  854. config.hidden_size * config.ffn_proj_factor,
  855. config.ffn_round_up_to_multiple_of,
  856. )
  857. if self.config.weight_mode == "single":
  858. self.proj_up_gate = nn.Linear(
  859. in_features=config.hidden_size,
  860. out_features=self.up_proj_dim,
  861. bias=self.config.use_bias,
  862. )
  863. self.proj_up = nn.Linear(
  864. in_features=config.hidden_size,
  865. out_features=self.up_proj_dim,
  866. bias=self.config.use_bias,
  867. )
  868. elif self.config.weight_mode == "fused":
  869. self.proj_up_gate_z = nn.Linear(
  870. in_features=config.hidden_size,
  871. out_features=2 * self.up_proj_dim,
  872. bias=self.config.use_bias,
  873. )
  874. self.proj_down = nn.Linear(
  875. in_features=self.up_proj_dim,
  876. out_features=config.hidden_size,
  877. bias=self.config.use_bias,
  878. )
  879. self.act_fn = nn.SiLU()
  880. def forward(self, x: torch.Tensor) -> torch.Tensor:
  881. if self.config.weight_mode == "single":
  882. x = self.act_fn(self.proj_up_gate(x)) * self.proj_up(x)
  883. elif self.config.weight_mode == "fused":
  884. x = self.proj_up_gate_z(x)
  885. gate, z = torch.tensor_split(x, (self.up_proj_dim,), dim=-1)
  886. x = self.act_fn(gate) * z
  887. y = self.proj_down(x)
  888. return y
  889. class xLSTMLayer(nn.Module):
  890. def __init__(self, config: xLSTMConfig):
  891. super().__init__()
  892. self.config = config
  893. self.v_dim = int(config.hidden_size * config.v_dim_factor)
  894. self.qk_dim = int(config.hidden_size * config.qk_dim_factor)
  895. if self.config.weight_mode == "single":
  896. self.q = nn.Linear(
  897. in_features=self.config.hidden_size,
  898. out_features=self.qk_dim,
  899. bias=self.config.use_bias,
  900. )
  901. self.k = nn.Linear(
  902. in_features=self.config.hidden_size,
  903. out_features=self.qk_dim,
  904. bias=self.config.use_bias,
  905. )
  906. self.v = nn.Linear(
  907. in_features=self.config.hidden_size,
  908. out_features=self.v_dim,
  909. bias=self.config.use_bias,
  910. )
  911. self.ogate_preact = nn.Linear(
  912. in_features=self.config.hidden_size,
  913. out_features=self.v_dim,
  914. bias=self.config.use_bias,
  915. )
  916. self.igate_preact = nn.Linear(
  917. in_features=self.config.hidden_size,
  918. out_features=self.config.num_heads,
  919. bias=True,
  920. )
  921. self.fgate_preact = nn.Linear(
  922. in_features=self.config.hidden_size,
  923. out_features=self.config.num_heads,
  924. bias=True,
  925. )
  926. elif self.config.weight_mode == "fused":
  927. self.qkv_opreact = nn.Linear(
  928. in_features=self.config.hidden_size,
  929. out_features=2 * self.qk_dim + 2 * self.v_dim,
  930. bias=self.config.use_bias,
  931. )
  932. self.ifgate_preact = nn.Linear(
  933. in_features=self.config.hidden_size,
  934. out_features=2 * self.config.num_heads,
  935. bias=True,
  936. )
  937. self.ogate_act_fn = nn.Sigmoid()
  938. self.mlstm_backend = xLSTMBackend(config=self.config)
  939. self.multihead_norm = xLSTMMultiHeadLayerNorm(
  940. num_heads=self.config.num_heads,
  941. head_dim=self.v_dim // self.config.num_heads,
  942. eps=self.config.norm_eps,
  943. use_weight=True,
  944. use_bias=self.config.use_bias,
  945. force_float32_reductions=self.config.norm_reduction_force_float32,
  946. )
  947. self.out_proj = nn.Linear(
  948. in_features=self.v_dim,
  949. out_features=self.config.hidden_size,
  950. bias=self.config.use_bias,
  951. )
  952. def forward(
  953. self, x: torch.Tensor, state: Optional[mLSTMLayerStateType] = None
  954. ) -> tuple[torch.Tensor, Optional[mLSTMLayerStateType]]:
  955. if x.ndim != 3:
  956. raise ValueError(f"Input must have shape [batch_size, sequence_length, HD], got {x.shape}")
  957. batch_size, sequence_length, _ = x.shape
  958. if self.config.weight_mode == "single":
  959. query = self.q(x)
  960. key = self.k(x)
  961. value = self.v(x)
  962. o_preact = self.ogate_preact(x)
  963. i_preact = soft_cap(self.igate_preact(x), cap_value=self.config.gate_soft_cap)
  964. f_preact = soft_cap(self.fgate_preact(x), cap_value=self.config.gate_soft_cap)
  965. elif self.config.weight_mode == "fused":
  966. qkv_opreact = self.qkv_opreact(x)
  967. query, key, value, o_preact = torch.tensor_split(
  968. qkv_opreact,
  969. (
  970. self.qk_dim,
  971. 2 * self.qk_dim,
  972. 2 * self.qk_dim + self.v_dim,
  973. ),
  974. dim=-1,
  975. )
  976. if_preact = soft_cap(self.ifgate_preact(x), cap_value=self.config.gate_soft_cap)
  977. i_preact, f_preact = torch.tensor_split(if_preact, (self.config.num_heads,), dim=-1)
  978. query = query.reshape(batch_size, sequence_length, self.config.num_heads, -1).transpose(1, 2)
  979. key = key.reshape(batch_size, sequence_length, self.config.num_heads, -1).transpose(1, 2)
  980. value = value.reshape(batch_size, sequence_length, self.config.num_heads, -1).transpose(1, 2)
  981. i_preact = i_preact.transpose(1, 2)
  982. f_preact = f_preact.transpose(1, 2)
  983. if state is None:
  984. c_initial, n_initial, m_initial = None, None, None
  985. else:
  986. c_initial, n_initial, m_initial = state
  987. h, state = self.mlstm_backend(
  988. query=query,
  989. key=key,
  990. value=value,
  991. igate=i_preact,
  992. fgate=f_preact,
  993. c_initial=c_initial,
  994. n_initial=n_initial,
  995. m_initial=m_initial,
  996. )
  997. expected_h_shape = (
  998. batch_size,
  999. self.config.num_heads,
  1000. sequence_length,
  1001. self.v_dim // self.config.num_heads,
  1002. )
  1003. if h.shape != expected_h_shape:
  1004. raise ValueError(f"Got {h.shape}, expected {expected_h_shape}")
  1005. h = h.transpose(1, 2)
  1006. h_norm = self.multihead_norm(h)
  1007. h_norm = h_norm.reshape(batch_size, sequence_length, -1)
  1008. h_out = self.ogate_act_fn(o_preact) * h_norm
  1009. y = self.out_proj(h_out)
  1010. return y, state
  1011. class xLSTMBlock(nn.Module):
  1012. def __init__(self, config: xLSTMConfig):
  1013. super().__init__()
  1014. self.config = config
  1015. self.norm_mlstm = xLSTMRMSNorm(
  1016. num_features=config.hidden_size,
  1017. eps=config.norm_eps,
  1018. use_weight=True,
  1019. use_bias=config.use_bias,
  1020. force_float32_reductions=config.norm_reduction_force_float32,
  1021. )
  1022. self.mlstm_layer = xLSTMLayer(config)
  1023. self.norm_ffn = xLSTMRMSNorm(
  1024. num_features=config.hidden_size,
  1025. eps=config.norm_eps,
  1026. use_weight=True,
  1027. use_bias=config.use_bias,
  1028. force_float32_reductions=config.norm_reduction_force_float32,
  1029. )
  1030. self.ffn = xLSTMFeedForward(config)
  1031. def forward(
  1032. self, x: torch.Tensor, state: Optional[mLSTMStateType] = None
  1033. ) -> tuple[torch.Tensor, mLSTMStateType]:
  1034. x_mlstm = self.norm_mlstm(x)
  1035. x_mlstm, state = self.mlstm_layer(x_mlstm, state)
  1036. x = x + x_mlstm
  1037. x_ffn = self.norm_ffn(x)
  1038. x_ffn = self.ffn(x_ffn)
  1039. x = x + x_ffn
  1040. return x, state
  1041. def small_init_method(dim):
  1042. """
  1043. Adapted from: https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
  1044. Fills the input Tensor with values according to the method described in Transformers without Tears: Improving
  1045. the Normalization of Self-Attention - Nguyen, T. & Salazar, J. (2019), using a normal distribution."""
  1046. std = (2 / (5 * dim)) ** (1 / 2)
  1047. def init_(tensor):
  1048. return torch.nn.init.normal_(tensor, mean=0.0, std=std)
  1049. return init_
  1050. def wang_init_method(n_layers, dim):
  1051. """
  1052. Adapted from https://github.com/EleutherAI/gpt-neox/blob/main/megatron/model/init_functions.py
  1053. """
  1054. std = 2 / n_layers / dim ** (1 / 2)
  1055. def init_(tensor):
  1056. return torch.nn.init.normal_(tensor, mean=0.0, std=std)
  1057. return init_
  1058. class xLSTMPreTrainedModel(PreTrainedModel):
  1059. """
  1060. An abstract class for an interface to loading a pre-trained xLSTM model.
  1061. """
  1062. config_class = xLSTMConfig
  1063. base_model_prefix = "backbone"
  1064. _no_split_modules = ["xLSTMBlock"]
  1065. supports_gradient_checkpointing = True
  1066. _is_stateful = True
  1067. def _module_name_map(self, module):
  1068. for name, mod in self.named_modules():
  1069. if mod is module:
  1070. return name
  1071. return ""
  1072. def _init_weights(self, module):
  1073. if isinstance(module, nn.Embedding):
  1074. small_init_method(self.config.hidden_size)(self.embeddings.weight)
  1075. elif isinstance(module, nn.Linear):
  1076. if module.bias is not None:
  1077. torch.nn.init.zeros_(module.bias)
  1078. if self.config.weight_mode == "single" and "gate" in self._module_name_map(module):
  1079. torch.nn.init.zeros_(module.weight)
  1080. with torch.no_grad():
  1081. if "igate" in self._module_name_map(module):
  1082. module.bias.copy_(-10.0 * torch.ones_like(module.bias))
  1083. elif "fgate" in self._module_name_map(module):
  1084. module.bias.copy_(
  1085. torch.linspace(
  1086. 3.0,
  1087. 6.0,
  1088. module.bias.shape[-1],
  1089. ).to(
  1090. device=module.bias.device,
  1091. dtype=module.bias.dtype,
  1092. )
  1093. )
  1094. elif self.config.weight_mode == "fused" and "gate" in self._module_name_map(module):
  1095. torch.nn.init.zeros_(module.weight)
  1096. with torch.no_grad():
  1097. module.bias[: self.config.num_heads] += -module.bias[
  1098. : self.config.num_heads
  1099. ] - 10.0 * torch.ones_like(module.bias)
  1100. module.bias[: self.config.num_heads] += -module.bias[self.config.num_heads :] + torch.linspace(
  1101. 3.0,
  1102. 6.0,
  1103. module.bias.shape[-1],
  1104. ).to(
  1105. device=module.bias.device,
  1106. dtype=module.bias.dtype,
  1107. )
  1108. elif "proj_down" in self._module_name_map(module):
  1109. wang_init_method(dim=module.weight.shape[1], n_layers=self.config.num_hidden_layers)(module.weight)
  1110. elif "out_proj" in self._module_name_map(module):
  1111. wang_init_method(dim=self.config.hidden_size, n_layers=self.config.num_hidden_layers)(module.weight)
  1112. elif module.weight is not None:
  1113. small_init_method(self.config.hidden_size)(module.weight)
  1114. elif isinstance(module, xLSTMRMSNorm) or hasattr(module, "_layer_normalize"):
  1115. torch.nn.init.ones_(module.weight)
  1116. if hasattr(module, "bias") and module.bias is not None:
  1117. torch.nn.init.zeros_(module.bias)
  1118. class xLSTMCache:
  1119. """
  1120. Cache for xLSTM model which does not have attention mechanism and key value states.
  1121. Arguments:
  1122. config (`PretrainedConfig):
  1123. The configuration file defining the shape-related attributes required to initialize the static cache.
  1124. max_batch_size (`int`):
  1125. The batch size with which the model will be used.
  1126. dtype (`torch.dtype`, *optional*, defaults to `torch.bfloat16`):
  1127. The default `dtype` to use when initializing the layer.
  1128. device (`torch.device` or `str`, *optional*):
  1129. The device on which the cache should be initialized. Should be the same as the layer.
  1130. Attributes:
  1131. seqlen_offset: int
  1132. dtype: torch.dtype
  1133. Example:
  1134. ```python
  1135. >>> from transformers import AutoTokenizer, xLSTMForCausalLM, xLSTMCache
  1136. >>> model = xLSTMForCausalLM.from_pretrained("NX-AI/xLSTM-7b")
  1137. >>> tokenizer = xLSTMTokenizer.from_pretrained("NX-AI/xLSTM-7b")
  1138. >>> inputs = tokenizer(text="I am an xLSTM", return_tensors="pt")
  1139. >>> # Prepare a cache class and pass it to model's forward
  1140. >>> cache_params = xLSTMCache(config=model.config, max_batch_size=1, device=model.device, dtype=model.dtype)
  1141. >>> outputs = model(**inputs, cache_params=cache_params, use_cache=True)
  1142. >>> outputs.cache_params
  1143. xLSTMCache()
  1144. """
  1145. def __init__(
  1146. self,
  1147. config: xLSTMConfig,
  1148. max_batch_size: int,
  1149. dtype: torch.dtype = torch.bfloat16,
  1150. device: Optional[str] = None,
  1151. **kwargs,
  1152. ):
  1153. self.seqlen_offset = 0
  1154. self.dtype = dtype
  1155. self.config = config
  1156. self.rnn_state = {
  1157. layer: (
  1158. torch.zeros(
  1159. [max_batch_size, config.num_heads, config.qk_head_dim, config.v_head_dim],
  1160. dtype=dtype,
  1161. device=device,
  1162. ),
  1163. torch.zeros([max_batch_size, config.num_heads, config.qk_head_dim], dtype=dtype, device=device),
  1164. torch.zeros([max_batch_size, config.num_heads, 1], dtype=dtype, device=device),
  1165. )
  1166. for layer in range(config.num_hidden_layers)
  1167. }
  1168. def reset(self):
  1169. self.rnn_state = {
  1170. layer: (
  1171. torch.zeros_like(self.rnn_state[layer][0]),
  1172. torch.zeros_like(self.rnn_state[layer][1]),
  1173. torch.zeros_like(self.rnn_state[layer][2]),
  1174. )
  1175. for layer in self.rnn_state
  1176. }
  1177. @dataclass
  1178. @auto_docstring
  1179. class xLSTMOutput(ModelOutput):
  1180. r"""
  1181. cache_params (`xLSTMCache`):
  1182. The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
  1183. avoid providing the old `input_ids`.
  1184. """
  1185. last_hidden_state: Optional[torch.FloatTensor]
  1186. cache_params: Optional[xLSTMCache] = None
  1187. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  1188. @auto_docstring
  1189. class xLSTMModel(xLSTMPreTrainedModel):
  1190. def __init__(self, config):
  1191. super().__init__(config)
  1192. # use embbeding_dim and num_blocks once here to make use of them
  1193. self.embeddings = nn.Embedding(config.vocab_size, config.embedding_dim)
  1194. self.blocks = nn.ModuleList([xLSTMBlock(config) for _ in range(config.num_blocks)])
  1195. self.out_norm = xLSTMRMSNorm(config.hidden_size, eps=config.norm_eps)
  1196. self.gradient_checkpointing = False
  1197. # Initialize weights and apply final processing
  1198. self.post_init()
  1199. def get_input_embeddings(self):
  1200. return self.embeddings
  1201. def set_input_embeddings(self, new_embedding):
  1202. self.embeddings = new_embedding
  1203. @can_return_tuple
  1204. @auto_docstring
  1205. def forward(
  1206. self,
  1207. input_ids: Optional[torch.LongTensor] = None,
  1208. inputs_embeds: Optional[torch.LongTensor] = None,
  1209. cache_params: Optional[xLSTMCache] = None,
  1210. use_cache: Optional[bool] = None,
  1211. output_hidden_states: Optional[bool] = None,
  1212. **kwargs,
  1213. ) -> Union[tuple, xLSTMOutput]:
  1214. r"""
  1215. cache_params (`xLSTMCache`, *optional*):
  1216. The xLSTMCache that carries the RNN states.
  1217. """
  1218. output_hidden_states = (
  1219. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  1220. )
  1221. use_cache = use_cache if use_cache is not None else (self.config.use_cache if not self.training else False)
  1222. if self.gradient_checkpointing and self.training and use_cache:
  1223. use_cache = False
  1224. if (input_ids is None) ^ (inputs_embeds is not None):
  1225. raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
  1226. if inputs_embeds is None:
  1227. inputs_embeds = self.embeddings(input_ids)
  1228. if use_cache and cache_params is None:
  1229. cache_params = xLSTMCache(
  1230. self.config, inputs_embeds.size(0), device=inputs_embeds.device, dtype=inputs_embeds.dtype
  1231. )
  1232. hidden_states = inputs_embeds
  1233. if (
  1234. not self.training
  1235. and self.config.max_inference_chunksize < hidden_states.shape[1]
  1236. and not output_hidden_states
  1237. ):
  1238. offset = 0
  1239. with torch.no_grad():
  1240. if cache_params is None:
  1241. cache_params = xLSTMCache(config=self.config, max_batch_size=hidden_states.shape[0])
  1242. final_state = torch.zeros_like(hidden_states)
  1243. while offset < hidden_states.shape[1]:
  1244. hidden_states_chunk = hidden_states[
  1245. :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
  1246. ]
  1247. for layer_idx, xlstm_block in enumerate(self.blocks):
  1248. hidden_states_chunk, rnn_state = xlstm_block(
  1249. hidden_states_chunk,
  1250. state=cache_params.rnn_state[layer_idx],
  1251. )
  1252. for state_idx in range(len(cache_params.rnn_state[layer_idx])):
  1253. local_rnn_state = rnn_state[state_idx]
  1254. cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
  1255. cache_params.rnn_state_initial = False
  1256. final_state[
  1257. :, offset : min(offset + self.config.max_inference_chunksize, hidden_states.shape[1])
  1258. ] = hidden_states_chunk
  1259. offset += self.config.max_inference_chunksize
  1260. hidden_states = final_state
  1261. else:
  1262. all_hidden_states = () if output_hidden_states else None
  1263. for layer_idx, xlstm_block in enumerate(self.blocks):
  1264. if self.gradient_checkpointing and self.training:
  1265. hidden_states, rnn_state = self._gradient_checkpointing_func(
  1266. xlstm_block.__call__,
  1267. hidden_states,
  1268. cache_params.rnn_state[layer_idx] if cache_params is not None else None,
  1269. )
  1270. else:
  1271. hidden_states, rnn_state = xlstm_block(
  1272. hidden_states,
  1273. state=cache_params.rnn_state[layer_idx] if cache_params is not None else None,
  1274. )
  1275. if cache_params:
  1276. for state_idx in range(len(cache_params.rnn_state[layer_idx])):
  1277. local_rnn_state = rnn_state[state_idx]
  1278. cache_params.rnn_state[layer_idx][state_idx].copy_(local_rnn_state)
  1279. cache_params.rnn_state_initial = False
  1280. if output_hidden_states:
  1281. all_hidden_states = all_hidden_states + (hidden_states,)
  1282. if use_cache:
  1283. cache_params.seqlen_offset += inputs_embeds.shape[1]
  1284. hidden_states = self.out_norm(hidden_states)
  1285. if output_hidden_states:
  1286. all_hidden_states = all_hidden_states + (hidden_states,)
  1287. return xLSTMOutput(
  1288. last_hidden_state=hidden_states,
  1289. cache_params=cache_params,
  1290. hidden_states=all_hidden_states,
  1291. )
  1292. @dataclass
  1293. @auto_docstring
  1294. class xLSTMCausalLMOutput(ModelOutput):
  1295. r"""
  1296. loss (`torch.FloatTensor` of shape `(1,)`, *optional*, returned when `labels` is provided):
  1297. Language modeling loss (for next-token prediction).
  1298. logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
  1299. Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
  1300. cache_params (`xLSTMCache`, *optional*, carrying the RNN states):
  1301. The state of the model at the last time step. Can be used in a forward method with the next `input_ids` to
  1302. avoid providing the old `input_ids`.
  1303. """
  1304. loss: Optional[torch.FloatTensor] = None
  1305. logits: Optional[torch.FloatTensor] = None
  1306. cache_params: Optional[xLSTMCache] = None
  1307. hidden_states: Optional[tuple[torch.FloatTensor]] = None
  1308. @auto_docstring
  1309. class xLSTMForCausalLM(xLSTMPreTrainedModel, GenerationMixin):
  1310. def __init__(self, config):
  1311. super().__init__(config)
  1312. self.backbone = xLSTMModel(config)
  1313. self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
  1314. # Initialize weights and apply final processing
  1315. self.post_init()
  1316. def get_output_embeddings(self):
  1317. return self.lm_head
  1318. def set_output_embeddings(self, new_embeddings):
  1319. self.lm_head = new_embeddings
  1320. def get_input_embeddings(self):
  1321. return self.backbone.get_input_embeddings()
  1322. def set_input_embeddings(self, new_embeddings):
  1323. return self.backbone.set_input_embeddings(new_embeddings)
  1324. def prepare_inputs_for_generation(
  1325. self,
  1326. input_ids,
  1327. attention_mask=None, # not used but needed, otherwise generate complains when passing tokenizer inputs
  1328. inputs_embeds=None,
  1329. use_cache=None,
  1330. cache_params: Optional[xLSTMCache] = None,
  1331. **kwargs,
  1332. ):
  1333. if use_cache and cache_params is not None:
  1334. # If the first cache position is non-zero, we assume we are in generation mode.
  1335. # Thus, the cache_params state is assumed to be the state before the last token
  1336. # (lastly generated token), and all previous tokens are already ingested.
  1337. # This should as well support generation from scratch with the [BOS] token inserted first.
  1338. input_ids = input_ids[:, -1:]
  1339. if inputs_embeds is not None:
  1340. inputs_embeds = inputs_embeds[:, -1:]
  1341. if inputs_embeds is not None and cache_params is None:
  1342. model_inputs = {"inputs_embeds": inputs_embeds}
  1343. else:
  1344. model_inputs = {"input_ids": input_ids}
  1345. model_inputs.update({"cache_params": cache_params, "use_cache": use_cache})
  1346. # Forward ALL kwargs that are uninitialized (e.g. `use_cache`).
  1347. for key, value in kwargs.items():
  1348. if key not in model_inputs:
  1349. model_inputs[key] = value
  1350. return model_inputs
  1351. @can_return_tuple
  1352. @auto_docstring
  1353. def forward(
  1354. self,
  1355. input_ids: Optional[torch.LongTensor] = None,
  1356. inputs_embeds: Optional[torch.FloatTensor] = None,
  1357. cache_params: Optional[xLSTMCache] = None,
  1358. labels: Optional[torch.LongTensor] = None,
  1359. use_cache: Optional[bool] = None,
  1360. output_hidden_states: Optional[bool] = None,
  1361. **kwargs,
  1362. ) -> Union[tuple, xLSTMCausalLMOutput]:
  1363. r"""
  1364. cache_params (`xLSTMCache`, *optional*):
  1365. The xLSTMCache that carries the RNN states.
  1366. """
  1367. xlstm_outputs = self.backbone(
  1368. input_ids,
  1369. cache_params=cache_params,
  1370. inputs_embeds=inputs_embeds,
  1371. use_cache=use_cache,
  1372. output_hidden_states=output_hidden_states,
  1373. **kwargs,
  1374. )
  1375. hidden_states = xlstm_outputs[0]
  1376. logits = self.lm_head(hidden_states.to(self.lm_head.weight.dtype)).float()
  1377. if not self.training and self.config.max_inference_chunksize < logits.shape[1]:
  1378. offset = 0
  1379. with torch.no_grad():
  1380. while offset < logits.shape[1]:
  1381. logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])] = soft_cap(
  1382. logits[:, offset : min(offset + self.config.max_inference_chunksize, logits.shape[1])],
  1383. self.config.output_logit_soft_cap,
  1384. )
  1385. offset += self.config.max_inference_chunksize
  1386. else:
  1387. logits = soft_cap(logits, self.config.output_logit_soft_cap)
  1388. loss = None
  1389. if labels is not None:
  1390. # move labels to correct device to enable model parallelism
  1391. labels = labels.to(logits.device)
  1392. # Shift so that tokens < nstate predict nstate
  1393. shift_logits = logits[..., :-1, :].contiguous()
  1394. shift_labels = labels[..., 1:].contiguous()
  1395. # Flatten the tokens
  1396. loss_fct = CrossEntropyLoss()
  1397. loss = loss_fct(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
  1398. return xLSTMCausalLMOutput(
  1399. loss=loss,
  1400. logits=logits,
  1401. cache_params=xlstm_outputs.cache_params,
  1402. hidden_states=xlstm_outputs.hidden_states,
  1403. )
  1404. __all__ = [
  1405. "xLSTMForCausalLM",
  1406. "xLSTMModel",
  1407. "xLSTMPreTrainedModel",
  1408. ]