cache_utils.py 66 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493
  1. from abc import ABC, abstractmethod
  2. from collections.abc import Iterable
  3. from typing import Any, Optional
  4. import torch
  5. from .configuration_utils import PretrainedConfig
  6. from .utils import (
  7. is_hqq_available,
  8. is_quanto_greater,
  9. is_torch_greater_or_equal,
  10. is_torchdynamo_compiling,
  11. logging,
  12. )
  13. if is_hqq_available():
  14. from hqq.core.quantize import Quantizer as HQQQuantizer
  logger = logging.get_logger(__name__)
  # Torch version gate, evaluated once when the module is imported (dev builds accepted).
  _is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)
  class CacheLayerMixin(ABC):
      """
      Abstract interface for the key/value cache of a single model layer.

      `keys` and `values` start out as `None` and `is_initialized` as `False`; concrete layers fill them in through
      `lazy_initialization`. `prefetch` reads `self.device`, which is not set here, so subclasses must define it when
      they initialize.
      """

      is_compileable = False

      def __init__(self):
          self.keys: Optional[torch.Tensor] = None
          self.values: Optional[torch.Tensor] = None
          self.is_initialized = False

      def __repr__(self):
          return type(self).__name__

      @abstractmethod
      def lazy_initialization(self, key_states: torch.Tensor):
          """Allocate the layer's storage based on the first incoming `key_states`."""

      @abstractmethod
      def update(
          self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """Store the new states and return the key/value tensors attention should use."""

      @abstractmethod
      def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
          """Return `(kv_length, kv_offset)` for building the attention mask."""

      @abstractmethod
      def get_seq_length(self) -> int:
          """Return how many tokens the layer currently accounts for."""

      @abstractmethod
      def get_max_cache_shape(self) -> int:
          """Return the maximum sequence length the layer can hold."""

      def _move_kv(self, device) -> None:
          # Shared by `offload` and `prefetch`: non-blocking copy of both tensors to `device`.
          self.keys = self.keys.to(device, non_blocking=True)
          self.values = self.values.to(device, non_blocking=True)

      def offload(self):
          """Move this layer's keys/values to the CPU; does nothing before initialization."""
          if not self.is_initialized:
              return
          self._move_kv("cpu")

      def prefetch(self):
          """Bring offloaded keys/values back to `self.device` ahead of use; does nothing if already there."""
          if not self.is_initialized or self.keys.device == self.device:
              return
          self._move_kv(self.device)

      def reset(self) -> None:
          """Zero the cached tensors in place (keeping their allocations) and rewind `cumulative_length` if present."""
          if self.is_initialized:
              for buffer in (self.keys, self.values):
                  buffer.zero_()
          # Several layer types track the number of seen tokens in `cumulative_length`.
          if hasattr(self, "cumulative_length"):
              self.cumulative_length = 0

      def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
          """Reorder the batch dimension (dim 0) of keys/values by `beam_idx`, e.g. for beam search."""
          if not (self.get_seq_length() > 0):
              return
          self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
          self.values = self.values.index_select(0, beam_idx.to(self.values.device))
  class DynamicLayer(CacheLayerMixin):
      """
      A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
      It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`.
      """

      is_sliding = False

      def lazy_initialization(self, key_states: torch.Tensor):
          """Record dtype/device from the first `key_states` and start with empty key/value tensors."""
          self.dtype, self.device = key_states.dtype, key_states.device
          # 1-D empty placeholders: `torch.cat` accepts a `(0,)`-shaped tensor next to the 4-D states, so the
          # first `update` can concatenate without a special case.
          self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
          self.values = torch.tensor([], dtype=self.dtype, device=self.device)
          self.is_initialized = True

      def update(
          self,
          key_states: torch.Tensor,
          value_states: torch.Tensor,
          cache_kwargs: Optional[dict[str, Any]] = None,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """
          Update the key and value caches in-place, and return the necessary keys and value states.

          Args:
              key_states (`torch.Tensor`): The new key states to cache.
              value_states (`torch.Tensor`): The new value states to cache.
              cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache (unused here).

          Returns:
              tuple[`torch.Tensor`, `torch.Tensor`]: All cached keys and values, new states included.
          """
          # Lazy initialization
          if not self.is_initialized:
              self.lazy_initialization(key_states)
          # Append along the sequence dimension.
          self.keys = torch.cat([self.keys, key_states], dim=-2)
          self.values = torch.cat([self.values, value_states], dim=-2)
          return self.keys, self.values

      def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
          """Return the length and offset of the cache, used to generate the mask"""
          kv_offset = 0
          query_length = cache_position.shape[0]
          # The mask covers every cached token plus the incoming queries.
          kv_length = self.get_seq_length() + query_length
          return kv_length, kv_offset

      def get_seq_length(self) -> int:
          """Returns the sequence length of the cached states (0 before the first update or when empty)."""
          if not self.is_initialized or self.keys.numel() == 0:
              return 0
          return self.keys.shape[-2]

      def get_max_cache_shape(self) -> int:
          """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
          return -1

      def crop(self, max_length: int) -> None:
          """
          Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative
          to remove `max_length` tokens; removing at least as many tokens as are cached empties the cache.
          """
          if max_length < 0:
              # Clamp at 0: without it, removing more tokens than are cached produced a negative slice bound that
              # kept tokens from the *front* of the cache (and crashed on a never-initialized layer, whose
              # `keys` is still `None`).
              max_length = max(self.get_seq_length() - abs(max_length), 0)
          if self.get_seq_length() <= max_length:
              return
          self.keys = self.keys[..., :max_length, :]
          self.values = self.values[..., :max_length, :]

      def batch_repeat_interleave(self, repeats: int) -> None:
          """Repeat the cache `repeats` times in the batch dimension."""
          if self.get_seq_length() > 0:
              self.keys = self.keys.repeat_interleave(repeats, dim=0)
              self.values = self.values.repeat_interleave(repeats, dim=0)

      def batch_select_indices(self, indices: torch.Tensor) -> None:
          """Only keep the `indices` in the batch dimension of the cache."""
          if self.get_seq_length() > 0:
              self.keys = self.keys[indices, ...]
              self.values = self.values[indices, ...]
  class DynamicSlidingWindowLayer(DynamicLayer):
      """
      A cache layer that grows dynamically as more tokens are generated, up until the sliding window size.
      It stores the key and value states as tensors of shape
      `[batch_size, num_heads, min(seq_len, sliding_window - 1), head_dim]`: only the last `sliding_window - 1`
      tokens need to be kept, since each new query attends to at most that many past tokens plus itself.
      """

      is_sliding = True

      def __init__(self, sliding_window: int):
          super().__init__()
          self.sliding_window = sliding_window
          # Total number of tokens ever seen, which can exceed what is actually stored.
          self.cumulative_length = 0

      def update(
          self,
          key_states: torch.Tensor,
          value_states: torch.Tensor,
          cache_kwargs: Optional[dict[str, Any]] = None,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """
          Update the key and value caches in-place, and return the necessary keys and value states.

          Args:
              key_states (`torch.Tensor`): The new key states to cache.
              value_states (`torch.Tensor`): The new value states to cache.
              cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache (unused here).

          Returns:
              tuple[`torch.Tensor`, `torch.Tensor`]: The stored states concatenated with the new ones (i.e. before
              trimming to the window), so the current queries see their full context.
          """
          # Lazy initialization
          if not self.is_initialized:
              self.lazy_initialization(key_states)
          self.cumulative_length += key_states.shape[-2]

          # Compute the full states
          full_key_states = torch.cat([self.keys, key_states], dim=-2)
          full_value_states = torch.cat([self.values, value_states], dim=-2)

          # Only cache the last `self.sliding_window - 1` tokens (or all of them if fewer). The start index is
          # computed explicitly because the negative bound `-self.sliding_window + 1` becomes `0` when
          # `sliding_window == 1`, which would keep *every* token instead of none.
          start = max(full_key_states.shape[-2] - (self.sliding_window - 1), 0)
          self.keys = full_key_states[:, :, start:, :]
          self.values = full_value_states[:, :, start:, :]

          # Return the full states
          return full_key_states, full_value_states

      def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
          """Return the length and offset of the cache, used to generate the attention mask"""
          query_length = cache_position.shape[0]
          is_full = self.cumulative_length >= self.sliding_window
          # Absolute position of the oldest stored token.
          kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0)
          if is_full:
              kv_length = self.sliding_window - 1 + query_length
          else:
              kv_length = self.cumulative_length + query_length
          return kv_length, kv_offset

      def get_seq_length(self) -> int:
          """Returns the total number of tokens seen so far (not the number currently stored)."""
          return self.cumulative_length

      def get_max_cache_shape(self) -> int:
          """Return the maximum cache shape of the cache"""
          return self.sliding_window

      def crop(self, max_length: int) -> None:
          """
          Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
          negative to remove `max_length` tokens (removing more than are cached empties the cache).

          Raises:
              ValueError: if the layer has already seen `sliding_window` tokens or more, since earlier states were
                  discarded and cannot be restored.
          """
          if self.get_seq_length() >= self.sliding_window:
              raise ValueError(
                  "Cannot `crop` a `DynamicSlidingWindowLayer` after it has seen more tokens than its "
                  "sliding window (otherwise some states are lost)"
              )
          if max_length < 0:
              # Resolve relative crops here, clamped at 0, so over-removal empties the cache instead of producing a
              # negative slice bound.
              max_length = max(self.get_seq_length() - abs(max_length), 0)
          super().crop(max_length)
          # A never-initialized layer has `keys is None`; its count is already 0.
          if self.is_initialized:
              self.cumulative_length = self.keys.shape[-2]
  class StaticLayer(CacheLayerMixin):
      """
      Cache layer backed by fixed-size key/value tensors of shape `[batch_size, num_heads, max_cache_len, head_dim]`.

      The backing tensors are allocated lazily on the first `update` and are afterwards only written in place.
      This layer is the one built for `torch.compile` support.

      Args:
          max_cache_len (`int`):
              Maximum number of tokens that can be stored, used for tensor preallocation.
      """

      is_compileable = True
      is_sliding = False

      def __init__(self, max_cache_len: int):
          super().__init__()
          self.max_cache_len = max_cache_len

      def lazy_initialization(self, key_states: torch.Tensor):
          """
          Allocate zero-filled key/value buffers whose batch size, head count, head dim, dtype and device come from
          `key_states`.

          Deferring this to runtime picks up properties such as the per-rank head count under tensor parallelism,
          and avoids later device/dtype moves on every `update`, which could also break static dynamo addresses.
          Callers that need the buffers ahead of time (e.g. for `torch.export`) can trigger this via the Cache's
          `early_initialization(...)`. With `compile`, the prefill is not compiled internally, so this has already
          run by the time compilation happens. Compiling the prefill too (e.g. `model.compile(...)` before `generate`
          with a static cache) generally works but is not guaranteed for every option (cuda graphs, i.e.
          `mode="reduce-overhead"`, is known to fail), and it is slower anyway.
          """
          batch_size, num_heads, _, head_dim = key_states.shape
          self.max_batch_size, self.num_heads, self.head_dim = batch_size, num_heads, head_dim
          self.dtype, self.device = key_states.dtype, key_states.device

          buffer_shape = (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim)
          self.keys = torch.zeros(buffer_shape, dtype=self.dtype, device=self.device)
          self.values = torch.zeros(buffer_shape, dtype=self.dtype, device=self.device)

          # Tag the buffers as fixed data pointers so in-place cache updates do not cause graph breaks. This is not
          # supported while dynamo is tracing, so it is skipped then (only relevant if the prefill itself is compiled).
          if not is_torchdynamo_compiling():
              for buffer in (self.keys, self.values):
                  torch._dynamo.mark_static_address(buffer)
          self.is_initialized = True

      def update(
          self,
          key_states: torch.Tensor,
          value_states: torch.Tensor,
          cache_kwargs: Optional[dict[str, Any]] = None,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """
          Write the new states into the preallocated buffers at `cache_kwargs["cache_position"]`.

          Args:
              key_states (`torch.Tensor`): The new key states to cache.
              value_states (`torch.Tensor`): The new value states to cache.
              cache_kwargs (`dict[str, Any]`, *optional*): May carry `cache_position`, the slots to write along dim 2.

          Returns:
              tuple[`torch.Tensor`, `torch.Tensor`]: The full key and value buffers (the same tensor objects every call).
          """
          if not self.is_initialized:
              self.lazy_initialization(key_states)

          # Some older models (e.g. cross-attention) omit `cache_kwargs` or pass `cache_position=None`; the new states
          # are then written from slot 0 onward (they are expected to span the whole layer in that case).
          positions = None if cache_kwargs is None else cache_kwargs.get("cache_position")
          if positions is None:
              positions = torch.arange(key_states.shape[-2], device=self.device)

          try:
              self.keys.index_copy_(2, positions, key_states)
              self.values.index_copy_(2, positions, value_states)
          except NotImplementedError:
              # `index_copy_` may be unavailable on some devices (e.g. MPS); advanced-index assignment fills the same slots.
              self.keys[:, :, positions] = key_states
              self.values[:, :, positions] = value_states
          return self.keys, self.values

      def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
          """Return `(kv_length, kv_offset)`: the mask always spans the whole buffer, starting at offset 0."""
          return self.max_cache_len, 0

      def get_seq_length(self) -> int:
          """Count occupied slots along the sequence dimension (0 before initialization)."""
          if not self.is_initialized:
              return 0
          # A slot is "occupied" when any head_dim entry is non-zero. Only batch 0 / head 0 is inspected to save
          # compute. Note that `.sum()` yields a 0-dim tensor here rather than a Python int.
          return self.keys[0, 0].any(dim=-1).sum()

      def get_max_cache_shape(self) -> int:
          """Return the maximum cache shape of the cache"""
          return self.max_cache_len
  class StaticSlidingWindowLayer(StaticLayer):
      """
      Sliding-window variant of `StaticLayer`. The preallocated buffers have shape
      `[batch_size, num_heads, min(max_cache_len, sliding_window), head_dim]` and are mutated in place, which keeps
      the layer `torch.compile`-friendly.

      Args:
          max_cache_len (`int`):
              Maximum number of tokens that can be stored, used for tensor preallocation.
          sliding_window (`int`):
              The size of the sliding window.
      """

      is_sliding = True

      def __init__(self, max_cache_len: int, sliding_window: int):
          # The buffer never needs to be longer than the window itself; `self.max_cache_len` ends up being that
          # effective size.
          super().__init__(max_cache_len=min(sliding_window, max_cache_len))
          # Total number of tokens ever written, which can exceed the buffer length.
          self.cumulative_length = 0

      def update(
          self,
          key_states: torch.Tensor,
          value_states: torch.Tensor,
          cache_kwargs: Optional[dict[str, Any]] = None,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """
          Store the new states in the window buffers and return the keys/values attention should use.

          Args:
              key_states (`torch.Tensor`): The new key states to cache.
              value_states (`torch.Tensor`): The new value states to cache.
              cache_kwargs (`dict[str, Any]`, *optional*): May carry `cache_position`, used while the window is not full.

          Returns:
              tuple[`torch.Tensor`, `torch.Tensor`]: Either the buffers themselves (single-token decode on a full
              window, or while the window is not yet reached), or the un-trimmed concatenation of stored and new
              states (when the window overflows with several new tokens), so no context is lost for this step.
          """
          if not self.is_initialized:
              self.lazy_initialization(key_states)

          # Older models may omit `cache_kwargs` or pass `cache_position=None` (e.g. cross-attention); positions then
          # default to 0..N-1.
          cache_position = None if cache_kwargs is None else cache_kwargs.get("cache_position")
          if cache_position is None:
              cache_position = torch.arange(key_states.shape[-2], device=self.device)

          window = self.max_cache_len
          seen_before = self.cumulative_length
          num_new = key_states.shape[-2]
          self.cumulative_length = seen_before + num_new

          if seen_before >= window:
              # Window already full. A plain `cat` would be simpler for every size, but dynamo currently mishandles it
              # (https://github.com/pytorch/pytorch/issues/159855), so single-token decode gets a dedicated path.
              if num_new == 1:
                  rolled_keys = self.keys.roll(-1, dims=-2)
                  rolled_values = self.values.roll(-1, dims=-2)
                  # Index with a tensor rather than a Python int (same issue as above).
                  last_slot = torch.tensor([-1], dtype=int, device=self.device)
                  rolled_keys[:, :, last_slot] = key_states
                  rolled_values[:, :, last_slot] = value_states
                  # `copy_` (not re-assignment) keeps the buffers' static dynamo addresses; return those very tensors.
                  self.keys.copy_(rolled_keys)
                  self.values.copy_(rolled_values)
                  return self.keys, self.values
              # Full window and several new tokens (e.g. prefill caching, chat continuation, ...).
              full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
              full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
          elif seen_before + key_states.shape[2] > window:
              # Not yet full, but overflowing on this update.
              if seen_before == 0:
                  # Empty cache: the new states are already the full states, no `cat` needed.
                  full_key_states, full_value_states = key_states, value_states
              else:
                  full_key_states = torch.cat((self.keys[:, :, :seen_before, :], key_states), dim=-2)
                  full_value_states = torch.cat((self.values[:, :, :seen_before, :], value_states), dim=-2)
          else:
              # Still fits: write in place at the requested positions.
              try:
                  self.keys.index_copy_(2, cache_position, key_states)
                  self.values.index_copy_(2, cache_position, value_states)
              except NotImplementedError:
                  self.keys[:, :, cache_position] = key_states
                  self.values[:, :, cache_position] = value_states
              # Return the buffers themselves: they carry the static dynamo address.
              return self.keys, self.values

          # Keep only the most recent `window` tokens in the buffers...
          self.keys.copy_(full_key_states[:, :, -window:, :])
          self.values.copy_(full_value_states[:, :, -window:, :])
          # ...but hand back the whole states so the current queries do not lose context.
          return full_key_states, full_value_states

      def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
          """Return the length and offset of the cache, used to generate the attention mask"""
          query_length = cache_position.shape[0]
          window = self.max_cache_len
          seen = self.cumulative_length
          kv_offset = max(seen - window + 1, 0)
          if seen >= window:
              # Already full.
              kv_length = window + query_length - 1
          elif seen + query_length > window:
              # Becoming full on this update.
              kv_length = seen + query_length
          else:
              # Still below the window: report the static buffer size.
              kv_length = window
          return kv_length, kv_offset

      def get_seq_length(self) -> int:
          """Returns the total number of tokens written so far."""
          return self.cumulative_length
  class QuantizedLayer(DynamicLayer):
      """
      A quantized layer similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
      It allows the model to generate longer sequences without allocating too much memory for the key and value caches,
      by applying quantization.

      The cache has two types of storage, one for original precision and one for the quantized cache. A `residual length`
      is set as a maximum capacity for the original precision cache. When the length goes beyond maximum capacity, the original
      precision cache is discarded and moved into the quantized cache. The quantization is done per-channel with a set `q_group_size`
      for both Keys and Values, in contrast to what was described in the paper.

      Subclasses implement `_quantize` / `_dequantize` for a specific backend.
      """

      def __init__(
          self,
          nbits: int = 4,
          axis_key: int = 0,
          axis_value: int = 0,
          q_group_size: int = 64,
          residual_length: int = 128,
      ):
          super().__init__()
          self.nbits = nbits
          self.axis_key = axis_key
          self.axis_value = axis_value
          self.q_group_size = q_group_size
          self.residual_length = residual_length
          # Total tokens seen (quantized + full-precision residual + current step).
          self.cumulative_length = 0

      def update(
          self,
          key_states: torch.Tensor,
          value_states: torch.Tensor,
          cache_kwargs: Optional[dict[str, Any]] = None,
      ) -> tuple[torch.Tensor, torch.Tensor]:
          """
          Update the key and value caches in-place, and return the necessary keys and value states.

          Args:
              key_states (`torch.Tensor`): The new key states to cache.
              value_states (`torch.Tensor`): The new value states to cache.
              cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache (unused here).

          Returns:
              tuple[`torch.Tensor`, `torch.Tensor`]: Dequantized history, full-precision residual and new states,
              concatenated along the sequence dimension.
          """
          self.cumulative_length += key_states.shape[-2]

          # First call: quantize the incoming states directly; the full-precision residual (`self.keys` /
          # `self.values`) starts as the empty placeholder created by `lazy_initialization`.
          if not self.is_initialized:
              self.lazy_initialization(key_states)
              self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key)
              self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value)
              return key_states, value_states

          dequant_keys = self._dequantize(self._quantized_keys)
          dequant_values = self._dequantize(self._quantized_values)
          keys_to_return = torch.cat([dequant_keys, self.keys, key_states], dim=-2)
          values_to_return = torch.cat([dequant_values, self.values, value_states], dim=-2)

          # `dim() == 4` rules out the empty 1-D placeholder. NOTE(review): the `+ 1` appears to assume one new token
          # per decode step; larger updates can push the residual past `residual_length` — confirm this is intended.
          if self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length:
              # Residual is full: re-quantize everything and restart with empty full-precision buffers.
              self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
              self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
              self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
              # The values placeholder follows `value_states` (it previously reused the key dtype/device).
              self.values = torch.tensor([], dtype=value_states.dtype, device=value_states.device)
          else:
              self.keys = torch.cat([self.keys, key_states], dim=-2)
              self.values = torch.cat([self.values, value_states], dim=-2)

          return keys_to_return, values_to_return

      @abstractmethod
      def _quantize(self, tensor, axis): ...

      @abstractmethod
      def _dequantize(self, q_tensor): ...

      def get_seq_length(self) -> int:
          """Returns the sequence length of the cached states."""
          return self.cumulative_length
  class QuantoQuantizedLayer(QuantizedLayer):
      """`QuantizedLayer` backed by `optimum-quanto` (2- or 4-bit integers, axis 0 or -1, `MaxOptimizer`)."""

      def __init__(
          self,
          nbits: int = 4,
          axis_key: int = 0,
          axis_value: int = 0,
          q_group_size: int = 64,
          residual_length: int = 128,
      ):
          super().__init__(
              nbits=nbits,
              axis_key=axis_key,
              axis_value=axis_value,
              q_group_size=q_group_size,
              residual_length=residual_length,
          )

          # Version gate first, then a local import: importing quanto at module level would be circular
          # (optimum/quanto/models/transformers_models.py).
          if not is_quanto_greater("0.2.5", accept_dev=True):
              raise ImportError(
                  "You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. "
              )
          from optimum.quanto import MaxOptimizer, qint2, qint4

          if self.nbits not in (2, 4):
              raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")
          if self.axis_key not in (0, -1):
              raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")
          if self.axis_value not in (0, -1):
              raise ValueError(
                  f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
              )

          self.qtype = qint2 if self.nbits == 2 else qint4
          # Hard-coded: it is the only optimizer suitable for per-channel quantization.
          self.optimizer = MaxOptimizer()

      def _quantize(self, tensor, axis):
          from optimum.quanto import quantize_weight

          scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
          return quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)

      def _dequantize(self, qtensor):
          return qtensor.dequantize()
  class HQQQuantizedLayer(QuantizedLayer):
      """`QuantizedLayer` backed by the `hqq` library's `Quantizer` (1/2/3/4/8-bit, axis 0 or 1)."""

      def __init__(
          self,
          nbits: int = 4,
          axis_key: int = 0,
          axis_value: int = 0,
          q_group_size: int = 64,
          residual_length: int = 128,
      ):
          super().__init__(
              nbits=nbits,
              axis_key=axis_key,
              axis_value=axis_value,
              q_group_size=q_group_size,
              residual_length=residual_length,
          )

          # `HQQQuantizer` only exists at module level when hqq could be imported.
          if not is_hqq_available():
              raise ImportError("You need to install `hqq` to use `HQQQuantizedLayer`")
          if self.nbits not in (1, 2, 3, 4, 8):
              raise ValueError(
                  f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
              )
          if self.axis_key not in (0, 1):
              raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")
          if self.axis_value not in (0, 1):
              raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")

          self.quantizer = HQQQuantizer

      def _quantize(self, tensor, axis):
          """Quantize `tensor` along `axis`; returns the `(quantized_tensor, meta)` pair expected by `_dequantize`."""
          # Device and compute dtype are taken from the layer's full-precision key buffer.
          target_device, compute_dtype = self.keys.device, self.keys.dtype
          qtensor, meta = self.quantizer.quantize(
              tensor,
              axis=axis,
              device=target_device,
              compute_dtype=compute_dtype,
              nbits=self.nbits,
              group_size=self.q_group_size,
          )
          meta["compute_dtype"] = compute_dtype
          # Move to device and cast to dtype.
          self.quantizer.cuda(qtensor, meta=meta, device=target_device)
          for field in ("scale", "zero"):
              meta[field] = meta[field].to(qtensor.device)
          return qtensor, meta

      def _dequantize(self, qtensor):
          packed, meta = qtensor
          return self.quantizer.dequantize(packed, meta)
  544. class Cache:
  545. """
  546. A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for
  547. the Cache of each layer.
  548. Args:
  549. layers (`Optional`, *optional*):
  550. A list of pre-created `CacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` will
  551. be used.
  552. layer_class_to_replicate (`type[CacheLayerMixin]`, *optional*):
  553. Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer,
  554. and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current
  555. list of layers.
  556. offloading (`bool`, *optional*, defaults to `False`):
  557. Whether to perform offloading of the layers to `cpu`, to save GPU memory.
  558. offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
  559. If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
  560. usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
  561. """
  562. def __init__(
  563. self,
  564. layers: Optional[list[CacheLayerMixin]] = None,
  565. layer_class_to_replicate: Optional[type[CacheLayerMixin]] = None,
  566. offloading: bool = False,
  567. offload_only_non_sliding: bool = True,
  568. ):
  569. if layers is not None and layer_class_to_replicate is not None:
  570. raise ValueError(
  571. "You can construct a Cache either from a list `layers` of all the predefined `CacheLayer`, or from a "
  572. "`layer_class_to_replicate`, in which case the Cache will append a new layer corresponding to "
  573. "`layer_class_to_replicate` for each new call to `update` with an idx not already in the Cache."
  574. )
  575. if layers is None and layer_class_to_replicate is None:
  576. raise ValueError(
  577. "You should provide exactly one of `layers` or `layer_class_to_replicate` to initialize a Cache."
  578. )
  579. self.layers = layers if layers is not None else []
  580. self.layer_class_to_replicate = layer_class_to_replicate
  581. self.offloading = offloading
  582. if self.offloading:
  583. self.only_non_sliding = offload_only_non_sliding
  584. self.prefetch_stream = torch.Stream() if _is_torch_greater_or_equal_than_2_7 else torch.cuda.Stream()
  585. def __repr__(self):
  586. return f"{self.__class__.__name__}(layers={self.layers})"
  587. def prefetch(self, layer_idx: int, only_non_sliding: bool = True):
  588. """
  589. Prefetch a given layer on its device. If `only_non_sliding` is True, it will try to prefetch only the layers
  590. which are non-sliding. If the `layer_idx` is outside the range, this will circle back to the first layers.
  591. Note that we use a non-default stream for this, to avoid blocking.
  592. """
  593. if only_non_sliding:
  594. # Try to find next non-sliding, starting at `layer_idx`
  595. try:
  596. layer_idx = layer_idx + self.is_sliding[layer_idx:].index(False)
  597. # In this case, we need to circle back to the beginning
  598. except ValueError:
  599. layer_idx = self.is_sliding.index(False)
  600. else:
  601. layer_idx = layer_idx if layer_idx < len(self.layers) else 0
  602. # Prefetch
  603. with self.prefetch_stream if _is_torch_greater_or_equal_than_2_7 else torch.cuda.stream(self.prefetch_stream):
  604. self.layers[layer_idx].prefetch()
  605. def offload(self, layer_idx: int, only_non_sliding: bool = True):
  606. """
  607. Offload a given `layer_idx`. If `only_non_sliding` is True, it will offload `layer_idx` only if it is a
  608. non-sliding layer. Note that we do it on the default stream, so that we ensure all earlier
  609. computation in the layer's `update` methods are finished.
  610. """
  611. if not (only_non_sliding and self.is_sliding[layer_idx]):
  612. self.layers[layer_idx].offload()
  613. def update(
  614. self,
  615. key_states: torch.Tensor,
  616. value_states: torch.Tensor,
  617. layer_idx: int,
  618. cache_kwargs: Optional[dict[str, Any]] = None,
  619. ) -> tuple[torch.Tensor, torch.Tensor]:
  620. """
  621. Updates the cache with the new `key_states` and `value_states` for the layer `layer_idx`.
  622. Parameters:
  623. key_states (`torch.Tensor`):
  624. The new key states to cache.
  625. value_states (`torch.Tensor`):
  626. The new value states to cache.
  627. layer_idx (`int`):
  628. The index of the layer to cache the states for.
  629. cache_kwargs (`dict[str, Any]`, *optional*):
  630. Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
  631. cache to be created.
  632. Return:
  633. A tuple containing the updated key and value states.
  634. """
  635. # In this case, the `layers` were not provided, and we must append as much as `layer_idx`
  636. if self.layer_class_to_replicate is not None:
  637. while len(self.layers) <= layer_idx:
  638. self.layers.append(self.layer_class_to_replicate())
  639. if self.offloading:
  640. # Wait for the stream to finish if needed, and start prefetching the next layer
  641. torch.cuda.default_stream(key_states.device).wait_stream(self.prefetch_stream)
  642. self.prefetch(layer_idx + 1, self.only_non_sliding)
  643. keys, values = self.layers[layer_idx].update(key_states, value_states, cache_kwargs)
  644. if self.offloading:
  645. self.offload(layer_idx, self.only_non_sliding)
  646. return keys, values
  647. def early_initialization(
  648. self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device
  649. ):
  650. """
  651. Initialize all the layers in advance (it's otherwise lazily initialized on the first `update` call).
  652. This is useful for our `export` recipes, as `export` needs everything in advance.
  653. """
  654. # Note that the initialization needs all dimensions (except -2), as well as device and dtype, so we use
  655. # this fake tensor approach. It has size 0 on the -2 dimension, so it does not allocate any data (it only
  656. # creates an empty tensor with correct shape, dtype and device), which is very efficient and practical
  657. fake_keys_tensor = torch.zeros((batch_size, num_heads, 0, head_dim), dtype=dtype, device=device)
  658. # Init all layers
  659. for layer in self.layers:
  660. layer.lazy_initialization(fake_keys_tensor)
  661. def get_seq_length(self, layer_idx: int = 0) -> int:
  662. """Returns the sequence length of the cache for the given layer."""
  663. if layer_idx >= len(self.layers):
  664. return 0
  665. return self.layers[layer_idx].get_seq_length()
  666. def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
  667. """
  668. Return a tuple (kv_length, kv_offset) corresponding to the length and offset that will be returned for
  669. the given layer at `layer_idx`.
  670. The masks are then prepared according to the given lengths (kv_length, kv_offset) and patterns for each layer.
  671. """
  672. # For DynamicCache, where the layers are created at runtime -> if it was not yet created, the size is
  673. # simply the shape of `cache_position`
  674. if layer_idx >= len(self.layers):
  675. return cache_position.shape[0], 0
  676. return self.layers[layer_idx].get_mask_sizes(cache_position)
  677. def get_max_cache_shape(self, layer_idx: int = 0) -> int:
  678. """Returns maximum sequence length of the cache object. Dynamic caches do not have a maximum length."""
  679. # For DynamicCache, where the layers are created at runtime -> if it was not yet created, return -1
  680. # as DynamicLayer does
  681. if layer_idx >= len(self.layers):
  682. return -1
  683. return self.layers[layer_idx].get_max_cache_shape()
  684. def reset(self):
  685. """Recursively reset all layers tensors"""
  686. for layer_idx in range(len(self.layers)):
  687. self.layers[layer_idx].reset()
  688. def reorder_cache(self, beam_idx: torch.LongTensor):
  689. """Reorder the cache for beam search"""
  690. for layer_idx in range(len(self.layers)):
  691. self.layers[layer_idx].reorder_cache(beam_idx)
  692. def crop(self, max_length: int):
  693. """Crop the cache to the given length"""
  694. for layer_idx in range(len(self.layers)):
  695. self.layers[layer_idx].crop(max_length)
  696. def batch_repeat_interleave(self, repeats: int):
  697. """Repeat and interleave the cache"""
  698. for layer_idx in range(len(self.layers)):
  699. self.layers[layer_idx].batch_repeat_interleave(repeats)
  700. def batch_select_indices(self, indices: torch.Tensor):
  701. """Select indices from the cache"""
  702. for layer_idx in range(len(self.layers)):
  703. self.layers[layer_idx].batch_select_indices(indices)
  704. @property
  705. def max_batch_size(self) -> int:
  706. """Return the maximum batch size of the cache"""
  707. values = [layer.max_batch_size for layer in self.layers]
  708. if len(set(values)) > 1:
  709. raise ValueError(f"Max batch size is not consistent across layers: {values}")
  710. return values[0]
  711. @property
  712. def max_cache_len(self) -> int:
  713. """Return the maximum cache length of the cache"""
  714. values = [layer.max_cache_len for layer in self.layers]
  715. return max(values)
  716. @property
  717. def is_compileable(self) -> bool:
  718. """Return whether the cache is compileable"""
  719. # For DynamicCache dispatching the layers lazily (otherwise, all([]) is True)
  720. if len(self.layers) == 0:
  721. return False
  722. return all(layer.is_compileable for layer in self.layers)
  723. @property
  724. def is_initialized(self) -> bool:
  725. """Return whether the cache data is initialized"""
  726. return len(self.layers) > 0 and all(layer.is_initialized for layer in self.layers)
  727. @property
  728. def is_sliding(self) -> list[bool]:
  729. """Return whether the layers of the cache are sliding window"""
  730. return [getattr(layer, "is_sliding", False) for layer in self.layers]
  731. def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
  732. """
  733. Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the
  734. sequence length.
  735. """
  736. if layer_idx < len(self.layers):
  737. return self.layers[layer_idx].keys, self.layers[layer_idx].values
  738. else:
  739. raise KeyError(
  740. f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}"
  741. )
  742. def __iter__(self):
  743. """
  744. Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over
  745. keys and values
  746. """
  747. for layer_idx in range(len(self)):
  748. yield (self.layers[layer_idx].keys, self.layers[layer_idx].values)
  749. def __len__(self):
  750. """
  751. This value corresponds to the number of layers in the model.
  752. """
  753. # Note: for DynamicCache, layers are initialized lazily, so this will not be accurate before the first
  754. # forward through all the layers
  755. return len(self.layers)
  756. class DynamicCache(Cache):
  757. """
  758. A cache that grows dynamically as more tokens are generated. This is the default for generative models.
  759. It stores the key and value states as a list of `CacheLayer`, one for each layer. The expected shape for each tensor
  760. in the `CacheLayer`s is `[batch_size, num_heads, seq_len, head_dim]`.
  761. If a config is passed, it will additionally check for sliding or hybrid cache structure, greatly reducing the
  762. memory requirement of the cached tensors to `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
  763. See `Cache` for details on common methods that are implemented by all cache classes.
  764. Args:
  765. ddp_cache_data (`Iterable[tuple[torch.Tensor, torch.Tensor]]`, *optional*):
  766. It was originally added for compatibility with `torch.distributed` (DDP). In a nutshell, it is
  767. `map(gather_map, zip(*caches))`, i.e. each item in the iterable contains the key and value states
  768. for a layer gathered across replicas by torch.distributed (shape=[global batch size, num_heads, seq_len, head_dim]).
  769. Note: it needs to be the 1st arg as well to work correctly
  770. config (`PretrainedConfig`, *optional*):
  771. The config of the model for which this Cache will be used. If passed, it will be used to check for sliding
  772. or hybrid layer structure, greatly reducing the memory requirement of the cached tensors to
  773. `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
  774. offloading (`bool`, *optional*, defaults to `False`):
  775. Whether to perform offloading of the layers to `cpu`, to save GPU memory.
  776. offload_only_non_sliding (`bool`, *optional*, defaults to `False`):
  777. If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
  778. usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
  779. Example:
  780. ```python
  781. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache
  782. >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
  783. >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
  784. >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
  785. >>> # Prepare a cache class and pass it to model's forward
  786. >>> past_key_values = DynamicCache(config=model.config)
  787. >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
  788. >>> outputs.past_key_values # access cache filled with key/values from generation
  789. ```
  790. """
  791. def __init__(
  792. self,
  793. ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
  794. config: Optional[PretrainedConfig] = None,
  795. offloading: bool = False,
  796. offload_only_non_sliding: bool = False,
  797. ):
  798. layers = []
  799. # If a config is passed, use it to infer the layer types and initialize accordingly
  800. if config is not None:
  801. decoder_config = config.get_text_config(decoder=True)
  802. sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
  803. decoder_config, "attention_chunk_size", None
  804. )
  805. layer_types = getattr(decoder_config, "layer_types", None)
  806. if layer_types is None:
  807. layer_types = [
  808. "sliding_attention" if sliding_window is not None else "full_attention"
  809. for _ in range(decoder_config.num_hidden_layers)
  810. ]
  811. # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
  812. if hasattr(decoder_config, "num_kv_shared_layers"):
  813. layer_types = layer_types[: -decoder_config.num_kv_shared_layers]
  814. for layer_type in layer_types:
  815. # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
  816. # states they should return - only the mask changes to make them different at the end!
  817. if layer_type in ("sliding_attention", "chunked_attention"):
  818. layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
  819. else:
  820. layers.append(DynamicLayer())
  821. # In this case, use the passed data to already fill in the Cache
  822. if ddp_cache_data is not None:
  823. # Init all the layers with the data
  824. for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
  825. # If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
  826. if config is None:
  827. layers.append(DynamicLayer())
  828. # Update the layer with the data
  829. _, _ = layers[layer_idx].update(key_states, value_states)
  830. # If neither of config nor ddp_data was passed, then simply lazy init a full cache of DynamicLayer
  831. if len(layers) == 0:
  832. super().__init__(
  833. layer_class_to_replicate=DynamicLayer,
  834. offloading=offloading,
  835. offload_only_non_sliding=offload_only_non_sliding,
  836. )
  837. else:
  838. super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
  839. def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
  840. """
  841. Converts the `Cache` instance into the its equivalent in the legacy cache format. Used for
  842. backward compatibility.
  843. """
  844. legacy_cache = ()
  845. for layer in self.layers:
  846. legacy_cache += ((layer.keys, layer.values),)
  847. return legacy_cache
  848. @classmethod
  849. def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tensor]]) -> "DynamicCache":
  850. """
  851. Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
  852. backward compatibility.
  853. """
  854. cache = cls()
  855. if past_key_values is None:
  856. logger.warning_once("past_key_values should not be None in from_legacy_cache()")
  857. if past_key_values is not None:
  858. for layer_idx in range(len(past_key_values)):
  859. key_states, value_states = past_key_values[layer_idx]
  860. cache.update(key_states, value_states, layer_idx)
  861. return cache
  862. class StaticCache(Cache):
  863. """
  864. Static Cache class to be used with `torch.compile(model)` and `torch.export()`. It will check the `config`
  865. for potential hybrid cache structure, and initialize each layer accordingly.
  866. See `Cache` for details on common methods that are implemented by all cache classes.
  867. Args:
  868. config (`PretrainedConfig`):
  869. The config of the model for which this Cache will be used. It will be used to check for sliding
  870. or hybrid layer structure, and initialize each layer accordingly.
  871. max_cache_len (`int`):
  872. The maximum number of tokens that this Cache should hold.
  873. offloading (`bool`, *optional*, defaults to `False`):
  874. Whether to perform offloading of the layers to `cpu`, to save GPU memory.
  875. offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
  876. If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
  877. usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
  878. Example:
  879. ```python
  880. >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache
  881. >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
  882. >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
  883. >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")
  884. >>> # Prepare a cache class and pass it to model's forward
  885. >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
  886. >>> max_generated_length = inputs.input_ids.shape[1] + 10
  887. >>> past_key_values = StaticCache(config=model.config, max_cache_len=max_generated_length)
  888. >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
  889. >>> outputs.past_key_values # access cache filled with key/values from generation
  890. StaticCache()
  891. ```
  892. """
  893. # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
  894. def __init__(
  895. self,
  896. config: PretrainedConfig,
  897. max_cache_len: int,
  898. offloading: bool = False,
  899. offload_only_non_sliding: bool = True,
  900. **kwargs,
  901. ):
  902. config = config.get_text_config(decoder=True)
  903. layer_types = getattr(config, "layer_types", None)
  904. # If `layer_types` is not explicitly provided, infer if the model is fully sliding
  905. if layer_types is None:
  906. if getattr(config, "sliding_window", None) is not None:
  907. layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)]
  908. elif getattr(config, "attention_chunk_size", None) is not None:
  909. layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)]
  910. else:
  911. layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]
  912. # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
  913. if hasattr(config, "num_kv_shared_layers"):
  914. layer_types = layer_types[: -config.num_kv_shared_layers]
  915. layers = []
  916. for layer_type in layer_types:
  917. if layer_type == "sliding_attention":
  918. layer = StaticSlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
  919. elif layer_type == "chunked_attention":
  920. # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
  921. # states they should return - only the mask changes to make them different at the end!
  922. layer = StaticSlidingWindowLayer(
  923. max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size
  924. )
  925. else:
  926. layer = StaticLayer(max_cache_len=max_cache_len)
  927. layers.append(layer)
  928. super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
  929. class QuantizedCache(Cache):
  930. """
  931. A quantizer cache similar to what is described in the
  932. [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
  933. It allows the model to generate longer sequence length without allocating too much memory for keys and values
  934. by applying quantization.
  935. The cache has two types of storage, one for original precision and one for the
  936. quantized cache. A `residual length` is set as a maximum capacity for the original precision cache. When the
  937. length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache.
  938. The quantization is done per-channel with a set `q_group_size` for both keys and values, in contrast to what was
  939. described in the paper.
  940. See `Cache` for details on common methods that are implemented by all cache classes.
  941. Args:
  942. backend (`str`):
  943. The quantization backend to use. One of `("quanto", "hqq").
  944. config (`PretrainedConfig`):
  945. The config of the model for which this Cache will be used.
  946. nbits (`int`, *optional*, defaults to 4):
  947. The number of bits for quantization.
  948. axis_key (`int`, *optional*, defaults to 0):
  949. The axis on which to quantize the keys.
  950. axis_value (`int`, *optional*, defaults to 0):
  951. The axis on which to quantize the values.
  952. q_group_size (`int`, *optional*, defaults to 64):
  953. Quantization is done per-channel according to a set `q_group_size` for both keys and values.
  954. residual_length (`int`, *optional*, defaults to 128):
  955. Maximum capacity for the original precision cache
  956. """
  957. def __init__(
  958. self,
  959. backend: str,
  960. config: PretrainedConfig,
  961. nbits: int = 4,
  962. axis_key: int = 0,
  963. axis_value: int = 0,
  964. q_group_size: int = 64,
  965. residual_length: int = 128,
  966. ):
  967. if backend == "quanto":
  968. layer_class = QuantoQuantizedLayer
  969. elif backend == "hqq":
  970. layer_class = HQQQuantizedLayer
  971. else:
  972. raise ValueError(f"Unknown quantization backend `{backend}`")
  973. config = config.get_text_config(decoder=True)
  974. layers = [
  975. layer_class(nbits, axis_key, axis_value, q_group_size, residual_length)
  976. for _ in range(config.num_hidden_layers)
  977. ]
  978. super().__init__(layers=layers)
  979. class EncoderDecoderCache(Cache):
  980. """
  981. Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
  982. cross-attention caches.
  983. See `Cache` for details on common methods that are implemented by all cache classes.
  984. Args:
  985. caches (`Iterable`):
  986. Usually an iterable of length 2, containing 2 `Cache` objects, the first one for self-attention, the
  987. second one for cross-attention. Can optionally also be an iterable of length 1, containing a
  988. `tuple[tuple[torch.Tensor]]` (usually used for compatibility with torch dp and ddp).
  989. Example:
  990. ```python
  991. >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache
  992. >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
  993. >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")
  994. >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")
  995. >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
  996. >>> self_attention_cache = DynamicCache(config=self.config)
  997. >>> cross_attention_cache = DynamicCache(config=self.config)
  998. >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
  999. >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
  1000. >>> outputs.past_key_values # access cache filled with key/values from generation
  1001. EncoderDecoderCache()
  1002. ```
  1003. """
  1004. def __init__(self, *caches) -> None:
  1005. # For dp and ddp support, if only one argument is passed, it should be an iterable of tuples of tensors
  1006. if len(caches) == 1:
  1007. self.self_attention_cache = DynamicCache()
  1008. self.cross_attention_cache = DynamicCache()
  1009. # Populate cache from the iterable
  1010. for layer_idx, key_value_states in enumerate(caches[0]):
  1011. key_states, value_states = key_value_states[:2]
  1012. self.self_attention_cache.update(key_states, value_states, layer_idx)
  1013. if len(key_value_states) > 2:
  1014. key_states, value_states = key_value_states[2:]
  1015. self.cross_attention_cache.update(key_states, value_states, layer_idx)
  1016. # Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
  1017. elif len(caches) == 2:
  1018. if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
  1019. raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }")
  1020. self.self_attention_cache = caches[0]
  1021. self.cross_attention_cache = caches[1]
  1022. # Error case
  1023. else:
  1024. raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")
  1025. self.is_updated = {}
  1026. for layer_idx in range(len(self.cross_attention_cache)):
  1027. self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)
  1028. def __repr__(self) -> str:
  1029. return (
  1030. f"{self.__class__.__name__}(self_attention_cache={self.self_attention_cache}, cross_attention_cache="
  1031. f"{self.cross_attention_cache})"
  1032. )
  1033. def __iter__(self):
  1034. """
  1035. Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over
  1036. keys and values
  1037. """
  1038. for layer_idx in range(len(self)):
  1039. yield (
  1040. self.self_attention_cache.layers[layer_idx].keys,
  1041. self.self_attention_cache.layers[layer_idx].values,
  1042. self.cross_attention_cache.layers[layer_idx].keys,
  1043. self.cross_attention_cache.layers[layer_idx].values,
  1044. )
  1045. def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
  1046. """
  1047. Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the
  1048. sequence length.
  1049. """
  1050. if layer_idx < len(self):
  1051. return (
  1052. self.self_attention_cache.layers[layer_idx].keys,
  1053. self.self_attention_cache.layers[layer_idx].values,
  1054. self.cross_attention_cache.layers[layer_idx].keys,
  1055. self.cross_attention_cache.layers[layer_idx].values,
  1056. )
  1057. else:
  1058. raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")
  1059. def __len__(self):
  1060. """
  1061. Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds
  1062. to the number of layers in the model.
  1063. """
  1064. return len(self.self_attention_cache)
  1065. def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]:
  1066. """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
  1067. legacy_cache = ()
  1068. if len(self.cross_attention_cache) > 0:
  1069. for self_attn, cross_attn in zip(
  1070. self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
  1071. ):
  1072. legacy_cache += (self_attn + cross_attn,)
  1073. else:
  1074. legacy_cache = self.self_attention_cache.to_legacy_cache()
  1075. return legacy_cache
  1076. @classmethod
  1077. def from_legacy_cache(
  1078. cls, past_key_values: Optional[Iterable[tuple[torch.FloatTensor, ...]]]
  1079. ) -> "EncoderDecoderCache":
  1080. """Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`."""
  1081. cache = cls(DynamicCache(), DynamicCache())
  1082. if past_key_values is None:
  1083. logger.warning_once("past_key_values should not be None in from_legacy_cache()")
  1084. else:
  1085. for layer_idx, key_value_states in enumerate(past_key_values):
  1086. key_states, value_states = key_value_states[:2]
  1087. cache.self_attention_cache.update(key_states, value_states, layer_idx)
  1088. if len(key_value_states) > 2:
  1089. key_states, value_states = key_value_states[2:]
  1090. cache.cross_attention_cache.update(key_states, value_states, layer_idx)
  1091. cache.is_updated[layer_idx] = True
  1092. return cache
  1093. def get_seq_length(self, layer_idx: int = 0) -> int:
  1094. """Returns the sequence length of the cached states. A layer index can be optionally passed."""
  1095. return self.self_attention_cache.get_seq_length(layer_idx)
  1096. def reset(self):
  1097. self.self_attention_cache.reset()
  1098. self.cross_attention_cache.reset()
  1099. for layer_idx in self.is_updated:
  1100. self.is_updated[layer_idx] = False
  1101. def reorder_cache(self, beam_idx: torch.LongTensor):
  1102. """Reorders the cache for beam search, given the selected beam indices."""
  1103. self.self_attention_cache.reorder_cache(beam_idx)
  1104. self.cross_attention_cache.reorder_cache(beam_idx)
  1105. def check_dynamic_cache(self, method: str):
  1106. if not (
  1107. isinstance(self.self_attention_cache, DynamicCache)
  1108. and isinstance(self.cross_attention_cache, DynamicCache)
  1109. ):
  1110. raise ValueError(
  1111. f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
  1112. f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
  1113. )
  1114. # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
  1115. def crop(self, maximum_length: int):
  1116. """
  1117. Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
  1118. negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search (on the Hub).
  1119. """
  1120. self.check_dynamic_cache(self.crop.__name__)
  1121. self.self_attention_cache.crop(maximum_length)
  1122. def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]":
  1123. """
  1124. Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
  1125. `_split_model_inputs()` in `generation.utils`
  1126. """
  1127. self.check_dynamic_cache(self.batch_split.__name__)
  1128. self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
  1129. cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)
  1130. out = []
  1131. for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
  1132. out.append(EncoderDecoderCache(self_attn, cross_attn))
  1133. return out
  1134. def batch_repeat_interleave(self, repeats: int):
  1135. """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search (on the Hub)."""
  1136. self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
  1137. self.self_attention_cache.batch_repeat_interleave(repeats)
  1138. self.cross_attention_cache.batch_repeat_interleave(repeats)
  1139. def batch_select_indices(self, indices: torch.Tensor):
  1140. """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search (on the Hub)."""
  1141. self.check_dynamic_cache(self.batch_select_indices.__name__)
  1142. self.self_attention_cache.batch_select_indices(indices)
  1143. self.cross_attention_cache.batch_select_indices(indices)
  1144. def get_max_cache_shape(self) -> int:
  1145. """Returns the maximum sequence length (i.e. max capacity) of the cache object"""
  1146. return self.self_attention_cache.get_max_cache_shape()
  1147. def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
  1148. return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)
  1149. @property
  1150. def is_sliding(self):
  1151. return self.self_attention_cache.is_sliding
  1152. @property
  1153. def is_compileable(self) -> bool:
  1154. return self.self_attention_cache.is_compileable
  1155. ### Deprecated classes
  1156. class SlidingWindowLayer(StaticSlidingWindowLayer):
  1157. def __init__(self, max_cache_len: int, sliding_window: int):
  1158. logger.warning_once(
  1159. "`SlidingWindowLayer` is deprecated and will be removed in version v4.59 "
  1160. "Use `StaticSlidingWindowLayer` instead, which is a better name for it."
  1161. )
  1162. super().__init__(max_cache_len, sliding_window)
  1163. class ChunkedSlidingLayer(StaticSlidingWindowLayer):
  1164. def __init__(self, max_cache_len: int, sliding_window: int):
  1165. logger.warning_once(
  1166. "`ChunkedSlidingLayer` is deprecated and will be removed in version v4.59 "
  1167. "Use `StaticSlidingWindowLayer` instead, which has the exact same functionalities."
  1168. )
  1169. super().__init__(max_cache_len, sliding_window)
  1170. class OffloadedCache(DynamicCache):
  1171. def __init__(self) -> None:
  1172. logger.warning_once(
  1173. "`OffloadedCache` is deprecated and will be removed in version v4.59 "
  1174. "Use `DynamicCache(offloading=True)` instead"
  1175. )
  1176. super().__init__(offloading=True)
  1177. class OffloadedStaticCache(StaticCache):
  1178. def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
  1179. logger.warning_once(
  1180. "`OffloadedStaticCache` is deprecated and will be removed in version v4.59 "
  1181. "Use `StaticCache(..., offloading=True)` instead"
  1182. )
  1183. super().__init__(config=config, max_cache_len=max_cache_len, offloading=True)
  1184. class SlidingWindowCache(StaticCache):
  1185. def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
  1186. logger.warning_once(
  1187. "`SlidingWindowCache` is deprecated and will be removed in version v4.59 "
  1188. "Use `StaticCache(...)` instead which will correctly infer the type of each layer."
  1189. )
  1190. super().__init__(config=config, max_cache_len=max_cache_len)
  1191. class HybridCache(StaticCache):
  1192. def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
  1193. logger.warning_once(
  1194. "`HybridCache` is deprecated and will be removed in version v4.59 "
  1195. "Use `StaticCache(...)` instead which will correctly infer the type of each layer."
  1196. )
  1197. super().__init__(config=config, max_cache_len=max_cache_len)
  1198. class HybridChunkedCache(StaticCache):
  1199. def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
  1200. logger.warning_once(
  1201. "`HybridChunkedCache` is deprecated and will be removed in version v4.59 "
  1202. "Use `StaticCache(...)` instead which will correctly infer the type of each layer."
  1203. )
  1204. super().__init__(config=config, max_cache_len=max_cache_len)
  1205. class OffloadedHybridCache(StaticCache):
  1206. def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
  1207. logger.warning_once(
  1208. "`OffloadedHybridCache` is deprecated and will be removed in version v4.59 "
  1209. "Use `StaticCache(..., offload=True)` instead which will correctly infer the type of each layer."
  1210. )
  1211. super().__init__(config=config, max_cache_len=max_cache_len, offloading=True)
  1212. class QuantoQuantizedCache(QuantizedCache):
  1213. def __init__(
  1214. self,
  1215. config: PretrainedConfig,
  1216. nbits: int = 4,
  1217. axis_key: int = 0,
  1218. axis_value: int = 0,
  1219. q_group_size: int = 64,
  1220. residual_length: int = 128,
  1221. ):
  1222. logger.warning_once(
  1223. "`QuantoQuantizedCache` is deprecated and will be removed in version v4.59 "
  1224. "Use `QuantizedCache(backend='quanto', ...)` instead."
  1225. )
  1226. super().__init__("quanto", config, nbits, axis_key, axis_value, q_group_size, residual_length)
  1227. class HQQQuantizedCache(QuantizedCache):
  1228. def __init__(
  1229. self,
  1230. config: PretrainedConfig,
  1231. nbits: int = 4,
  1232. axis_key: int = 0,
  1233. axis_value: int = 0,
  1234. q_group_size: int = 64,
  1235. residual_length: int = 128,
  1236. ):
  1237. logger.warning_once(
  1238. "`HQQQuantizedCache` is deprecated and will be removed in version v4.59 "
  1239. "Use `QuantizedCache(backend='hqq', ...)` instead."
  1240. )
  1241. super().__init__("hqq", config, nbits, axis_key, axis_value, q_group_size, residual_length)
  1242. class SinkCache(Cache):
  1243. """
  1244. It is now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache.
  1245. See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for
  1246. general `custom_generate`usage.
  1247. """
  1248. # TODO (joao, manuel): Remove this class in v4.59.0
  1249. def __init__(self, **kwargs) -> None:
  1250. raise NotImplementedError(
  1251. "`SinkCache` has been moved as a `custom_generate` repository on the Hub: "
  1252. "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples."
  1253. )