| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475147614771478147914801481148214831484148514861487148814891490149114921493 |
- from abc import ABC, abstractmethod
- from collections.abc import Iterable
- from typing import Any, Optional
- import torch
- from .configuration_utils import PretrainedConfig
- from .utils import (
- is_hqq_available,
- is_quanto_greater,
- is_torch_greater_or_equal,
- is_torchdynamo_compiling,
- logging,
- )
logger = logging.get_logger(__name__)

# Evaluated once at import time; the flag is consulted when creating and entering the prefetch stream.
_is_torch_greater_or_equal_than_2_7 = is_torch_greater_or_equal("2.7", accept_dev=True)

# `hqq` is optional: the quantizer is only bound when the package can be imported.
if is_hqq_available():
    from hqq.core.quantize import Quantizer as HQQQuantizer
class CacheLayerMixin(ABC):
    """Abstract base class describing the key/value cache of one single layer."""

    is_compileable = False

    def __init__(self):
        # Nothing is allocated here: `keys`/`values` stay `None` until a subclass initializes them.
        self.is_initialized = False
        self.keys: Optional[torch.Tensor] = None
        self.values: Optional[torch.Tensor] = None

    def __repr__(self):
        return self.__class__.__name__

    @abstractmethod
    def lazy_initialization(self, key_states: torch.Tensor): ...

    @abstractmethod
    def update(
        self, key_states: torch.Tensor, value_states: torch.Tensor, cache_kwargs: Optional[dict[str, Any]] = None
    ) -> tuple[torch.Tensor, torch.Tensor]: ...

    @abstractmethod
    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]: ...

    @abstractmethod
    def get_seq_length(self) -> int: ...

    @abstractmethod
    def get_max_cache_shape(self) -> int: ...

    def offload(self):
        """Move the keys and values of this layer to the CPU (no-op while the layer is uninitialized)."""
        if not self.is_initialized:
            return
        self.keys = self.keys.to("cpu", non_blocking=True)
        self.values = self.values.to("cpu", non_blocking=True)

    def prefetch(self):
        """Bring offloaded keys and values back to `self.device` ahead of time, if they are not already there."""
        if not self.is_initialized:
            return
        target_device = self.device
        if self.keys.device != target_device:
            self.keys = self.keys.to(target_device, non_blocking=True)
            self.values = self.values.to(target_device, non_blocking=True)

    def reset(self) -> None:
        """Zero the cached tensors in-place (the tensor objects themselves are kept) and reset the token counter."""
        if self.is_initialized:
            for cached in (self.keys, self.values):
                cached.zero_()
        # Only some layer types track `cumulative_length`; reset it when it exists
        if hasattr(self, "cumulative_length"):
            self.cumulative_length = 0

    def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
        """Re-index the batch dimension (dim 0) of the cache with `beam_idx`, as needed by beam search."""
        if not self.get_seq_length() > 0:
            return
        self.keys = self.keys.index_select(0, beam_idx.to(self.keys.device))
        self.values = self.values.index_select(0, beam_idx.to(self.values.device))
class DynamicLayer(CacheLayerMixin):
    """
    A cache layer that grows dynamically as more tokens are generated. This is the default for generative models.
    It stores the key and value states as tensors of shape `[batch_size, num_heads, seq_len, head_dim]`.
    """

    is_sliding = False

    def lazy_initialization(self, key_states: torch.Tensor):
        """Record dtype/device from `key_states` and start from empty key and value tensors."""
        self.dtype, self.device = key_states.dtype, key_states.device
        self.keys = torch.tensor([], dtype=self.dtype, device=self.device)
        self.values = torch.tensor([], dtype=self.dtype, device=self.device)
        self.is_initialized = True

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Append the new states to the cache along the sequence dimension, and return the full cached states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.
            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache (unused here).

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states)

        self.keys = torch.cat([self.keys, key_states], dim=-2)
        self.values = torch.cat([self.values, value_states], dim=-2)
        return self.keys, self.values

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the mask"""
        kv_offset = 0
        query_length = cache_position.shape[0]
        kv_length = self.get_seq_length() + query_length
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        if not self.is_initialized or self.keys.numel() == 0:
            return 0
        return self.keys.shape[-2]

    def get_max_cache_shape(self) -> int:
        """Returns the maximum sequence length of the cache object. DynamicLayer does not have a maximum length."""
        return -1

    def crop(self, max_length: int) -> None:
        """
        Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be negative
        to remove `max_length` tokens. Asking to remove more tokens than are cached empties the cache.
        """
        current_length = self.get_seq_length()
        if max_length < 0:
            # Clamp at 0: a negative bound left in the slice below would count from the end again and keep
            # tokens that should have been removed (and would index `None` on an uninitialized layer).
            max_length = max(current_length - abs(max_length), 0)

        if current_length <= max_length:
            return

        self.keys = self.keys[..., :max_length, :]
        self.values = self.values[..., :max_length, :]

    def batch_repeat_interleave(self, repeats: int) -> None:
        """Repeat the cache `repeats` times in the batch dimension."""
        if self.get_seq_length() > 0:
            self.keys = self.keys.repeat_interleave(repeats, dim=0)
            self.values = self.values.repeat_interleave(repeats, dim=0)

    def batch_select_indices(self, indices: torch.Tensor) -> None:
        """Only keep the `indices` in the batch dimension of the cache."""
        if self.get_seq_length() > 0:
            self.keys = self.keys[indices, ...]
            self.values = self.values[indices, ...]
class DynamicSlidingWindowLayer(DynamicLayer):
    """
    A cache layer that grows dynamically as more tokens are generated, up until the sliding window size.
    It stores the key and value states as tensors of shape `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
    """

    is_sliding = True

    def __init__(self, sliding_window: int):
        super().__init__()
        self.sliding_window = sliding_window
        # Total number of tokens seen so far (can exceed what is physically stored in `keys`/`values`)
        self.cumulative_length = 0

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.
            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache (unused here).

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states (previously cached states followed by
            the new ones, i.e. possibly more than what remains stored afterwards).
        """
        # Lazy initialization
        if not self.is_initialized:
            self.lazy_initialization(key_states)

        self.cumulative_length += key_states.shape[-2]

        # Compute the full states
        full_key_states = torch.cat([self.keys, key_states], dim=-2)
        full_value_states = torch.cat([self.values, value_states], dim=-2)

        # Only cache the last `self.sliding_window - 1` tokens (or all of them if lower than that)
        num_kept = self.sliding_window - 1
        if num_kept > 0:
            self.keys = full_key_states[:, :, -num_kept:, :]
            self.values = full_value_states[:, :, -num_kept:, :]
        else:
            # With nothing to keep, a `-0:` slice would silently retain every token, so take an explicit
            # empty slice instead.
            self.keys = full_key_states[:, :, :0, :]
            self.values = full_value_states[:, :, :0, :]

        # Return the full states
        return full_key_states, full_value_states

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        query_length = cache_position.shape[0]
        is_full = self.cumulative_length >= self.sliding_window

        kv_offset = max(self.cumulative_length - self.sliding_window + 1, 0)
        if is_full:
            kv_length = self.sliding_window - 1 + query_length
        else:
            kv_length = self.cumulative_length + query_length

        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the number of tokens seen so far (not the number of tokens physically stored)."""
        return self.cumulative_length

    def get_max_cache_shape(self) -> int:
        """Return the maximum cache shape of the cache"""
        return self.sliding_window

    def crop(self, max_length: int) -> None:
        """
        Crop the past key values up to a new `max_length` in terms of tokens. `max_length` can also be
        negative to remove `max_length` tokens.

        Raises:
            ValueError: If the layer has already seen at least `sliding_window` tokens, as older states were
                dropped and the cropped cache could not be reconstructed.
        """
        if self.get_seq_length() >= self.sliding_window:
            raise ValueError(
                "Cannot `crop` a `DynamicSlidingWindowLayer` after it has seen more tokens than its "
                "sliding window (otherwise some states are lost)"
            )
        super().crop(max_length)
        # Re-sync the counter with what is actually stored. Before any `update`, `keys` is either `None` or a
        # 1-dim empty tensor with no sequence dimension to read, and the counter is already correct.
        if self.is_initialized and self.keys.dim() > 1:
            self.cumulative_length = self.keys.shape[-2]
class StaticLayer(CacheLayerMixin):
    """
    A static cache layer that stores the key and value states as static tensors of shape
    `[batch_size, num_heads, max_cache_len, head_dim]`. The backing tensors are allocated lazily, at full size,
    and are then only mutated in-place. Built for `torch.compile` support.

    Args:
        max_cache_len (`int`):
            Maximum number of tokens that can be stored, used for tensor preallocation.
    """

    is_compileable = True
    is_sliding = False

    def __init__(self, max_cache_len: int):
        super().__init__()
        self.max_cache_len = max_cache_len

    def lazy_initialization(self, key_states: torch.Tensor):
        """
        Allocate the zero-filled key and value tensors, reading batch size, number of heads, head dim, dtype and
        device directly from `key_states`. Getting these properties at runtime avoids moving devices or dtypes
        later on for each `update` (which could break the static dynamo addresses as well).

        If lazy allocation is unwanted, `early_initialization(...)` can be called on the Cache directly to run this
        function ahead-of-time (required for `torch.export` for example). For `compile`, the prefill is not
        compiled internally, so this is guaranteed to have been called already when compiling. Compiling the
        prefill as well (e.g. `model.compile(...)` before `generate` with a static cache) is supported in general,
        but without guarantees depending on the compilation options (cuda graphs, i.e. `mode="reduce-overhead"`,
        is known to fail), and prefill should not be compiled anyway for performances.
        """
        self.max_batch_size, self.num_heads, _, self.head_dim = key_states.shape
        self.dtype, self.device = key_states.dtype, key_states.device

        cache_shape = (self.max_batch_size, self.num_heads, self.max_cache_len, self.head_dim)
        self.keys = torch.zeros(cache_shape, dtype=self.dtype, device=self.device)
        self.values = torch.zeros(cache_shape, dtype=self.dtype, device=self.device)

        # `mark_static_address` tags the cache as a fixed data pointer, which prevents compiled graph breaks when
        # updating it. It is not supported while tracing the graph, so it is skipped in that case. As prefill
        # should never be compiled, it will normally still run (unless users compile prefill explicitly, which
        # should be avoided).
        if not is_torchdynamo_compiling():
            torch._dynamo.mark_static_address(self.keys)
            torch._dynamo.mark_static_address(self.values)

        self.is_initialized = True

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Write the new states into the preallocated tensors in-place, and return the whole cache tensors.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.
            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. Only the
                `"cache_position"` entry is read.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        if not self.is_initialized:
            self.lazy_initialization(key_states)

        # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as
        # cross-attention. In that case the states are written from position 0 onwards.
        cache_position = None
        if cache_kwargs is not None:
            cache_position = cache_kwargs.get("cache_position")
        if cache_position is None:
            cache_position = torch.arange(key_states.shape[-2], device=self.device)

        try:
            self.keys.index_copy_(2, cache_position, key_states)
            self.values.index_copy_(2, cache_position, value_states)
        except NotImplementedError:
            # Fallback for devices like MPS where index_copy_ might not be supported.
            self.keys[:, :, cache_position] = key_states
            self.values[:, :, cache_position] = value_states

        return self.keys, self.values

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        return self.max_cache_len, 0

    def get_seq_length(self) -> int:
        """Returns the sequence length of the cached states."""
        if not self.is_initialized:
            return 0
        # A slot counts as occupied when it holds any non-zero value. To save on compute, only the first batch
        # member and the first head are inspected.
        return (self.keys[0, 0].any(dim=-1)).sum()

    def get_max_cache_shape(self) -> int:
        """Return the maximum cache shape of the cache"""
        return self.max_cache_len
class StaticSlidingWindowLayer(StaticLayer):
    """
    A static cache layer that stores the key and value states as static tensors of shape
    `[batch_size, num_heads, min(max_cache_len, sliding_window), head_dim]`. The backing tensors are allocated
    lazily, at full size, and are then only mutated in-place. Built for `torch.compile` support.

    Args:
        max_cache_len (`int`):
            Maximum number of tokens that can be stored, used for tensor preallocation.
        sliding_window (`int`):
            The size of the sliding window.
    """

    is_sliding = True

    def __init__(self, max_cache_len: int, sliding_window: int):
        # The parent only ever sees the effective size, i.e. the smaller of the two limits
        super().__init__(max_cache_len=min(sliding_window, max_cache_len))
        self.cumulative_length = 0

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches in-place, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.
            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache. Only the
                `"cache_position"` entry is read.

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states. These are the cache tensors
            themselves, except when the update overflows the window with more than one new token (or fills it for
            the first time), in which case the concatenated states are returned.
        """
        if not self.is_initialized:
            self.lazy_initialization(key_states)

        # Some old models give None for `cache_position` or even omit passing `cache_kwargs` when used as
        # cross-attention. In that case the states are written from position 0 onwards.
        cache_position = None if cache_kwargs is None else cache_kwargs.get("cache_position")
        if cache_position is None:
            cache_position = torch.arange(key_states.shape[-2], device=self.device)

        # Snapshot the state before this update, then advance the counter
        num_new_tokens = key_states.shape[-2]
        previous_length = self.cumulative_length
        was_full = previous_length >= self.max_cache_len
        self.cumulative_length += num_new_tokens

        # Case 1: already full, single new token. A simple `cat` would be preferable here as well, but dynamo is
        # currently bugged when doing it - see https://github.com/pytorch/pytorch/issues/159855 for more details
        if was_full and num_new_tokens == 1:
            # Shift everything one slot to the left
            shifted_keys = self.keys.roll(-1, dims=-2)
            shifted_values = self.values.roll(-1, dims=-2)
            # Overwrite the last slot with the new states
            # (note: very important to use a tensor to index here, see https://github.com/pytorch/pytorch/issues/159855)
            last_slot = torch.tensor([-1], dtype=int, device=self.device)
            shifted_keys[:, :, last_slot] = key_states
            shifted_values[:, :, last_slot] = value_states
            # Copy back into `self` (do not just assign again) in order to keep the static dynamo address
            self.keys.copy_(shifted_keys)
            self.values.copy_(shifted_values)
            # Very important to return the `self` tensors here, as they have the static dynamo address
            return self.keys, self.values

        # Case 2: not full, and the new tokens still fit in the window -> plain positional write
        if not was_full and previous_length + num_new_tokens <= self.max_cache_len:
            try:
                self.keys.index_copy_(2, cache_position, key_states)
                self.values.index_copy_(2, cache_position, value_states)
            except NotImplementedError:
                self.keys[:, :, cache_position] = key_states
                self.values[:, :, cache_position] = value_states
            # Very important to return the `self` tensors here, as they have the static dynamo address
            return self.keys, self.values

        # Case 3: the window overflows on this update
        if was_full:
            # Already full but using more than 1 new token (e.g. prefill caching, chat continuation, etc...)
            full_key_states = torch.cat((self.keys[:, :, 1:, :], key_states), dim=-2)
            full_value_states = torch.cat((self.values[:, :, 1:, :], value_states), dim=-2)
        elif previous_length == 0:
            # Fast prefill path, no need to cat() in this case, as the cache is currently empty
            full_key_states = key_states
            full_value_states = value_states
        else:
            # Not yet full, but becoming full on this update
            full_key_states = torch.cat((self.keys[:, :, :previous_length, :], key_states), dim=-2)
            full_value_states = torch.cat((self.values[:, :, :previous_length, :], value_states), dim=-2)

        # We only cache the last `sliding_window` tokens
        self.keys.copy_(full_key_states[:, :, -self.max_cache_len :, :])
        self.values.copy_(full_value_states[:, :, -self.max_cache_len :, :])
        # we should return the whole states instead of `self.keys/values` here, as otherwise we lose some context
        return full_key_states, full_value_states

    def get_mask_sizes(self, cache_position: torch.Tensor) -> tuple[int, int]:
        """Return the length and offset of the cache, used to generate the attention mask"""
        query_length = cache_position.shape[0]
        window = self.max_cache_len
        seen_tokens = self.cumulative_length

        kv_offset = max(seen_tokens - window + 1, 0)
        if seen_tokens >= window:
            # The cache is already full
            kv_length = window + query_length - 1
        elif seen_tokens + query_length > window:
            # Not yet full, but becoming full on this update
            kv_length = seen_tokens + query_length
        else:
            # Here the Cache is still smaller than the local size, but we return the local size as it's static
            kv_length = window
        return kv_length, kv_offset

    def get_seq_length(self) -> int:
        """Returns the number of tokens seen so far."""
        return self.cumulative_length
class QuantizedLayer(DynamicLayer):
    """
    A quantized layer similar to what is described in the [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
    Quantizing the key and value caches lets the model generate longer sequences without allocating too much memory.

    Two storages are kept: one in original precision (`keys`/`values`) and one quantized. `residual_length` is the
    maximum capacity of the original precision storage. When the length goes beyond that capacity, the original
    precision cache is folded into the quantized cache and emptied. The quantization is done per-channel with a
    set `q_group_size` for both Keys and Values, in contrast to what was described in the paper.

    Subclasses provide the backend through `_quantize` and `_dequantize`.
    """

    def __init__(
        self,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        super().__init__()
        self.cumulative_length = 0
        self.residual_length = residual_length
        self.q_group_size = q_group_size
        self.axis_value = axis_value
        self.axis_key = axis_key
        self.nbits = nbits

    def update(
        self,
        key_states: torch.Tensor,
        value_states: torch.Tensor,
        cache_kwargs: Optional[dict[str, Any]] = None,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Update the key and value caches, and return the necessary keys and value states.

        Args:
            key_states (`torch.Tensor`): The new key states to cache.
            value_states (`torch.Tensor`): The new value states to cache.
            cache_kwargs (`dict[str, Any]`, *optional*): Additional arguments for the cache (unused here).

        Returns:
            tuple[`torch.Tensor`, `torch.Tensor`]: The key and value states.
        """
        self.cumulative_length += key_states.shape[-2]

        if not self.is_initialized:
            # First call: the incoming states go straight to the quantized storage and are returned untouched
            self.lazy_initialization(key_states)
            self._quantized_keys = self._quantize(key_states.contiguous(), axis=self.axis_key)
            self._quantized_values = self._quantize(value_states.contiguous(), axis=self.axis_value)
            return key_states, value_states

        # Full states = dequantized storage + original precision residual + new states
        keys_to_return = torch.cat([self._dequantize(self._quantized_keys), self.keys, key_states], dim=-2)
        values_to_return = torch.cat([self._dequantize(self._quantized_values), self.values, value_states], dim=-2)

        residual_is_full = self.keys.dim() == 4 and self.keys.shape[-2] + 1 >= self.residual_length
        if not residual_is_full:
            # Still room in the original precision storage
            self.keys = torch.cat([self.keys, key_states], dim=-2)
            self.values = torch.cat([self.values, value_states], dim=-2)
        else:
            # Re-quantize everything and empty the original precision storage
            self._quantized_keys = self._quantize(keys_to_return.contiguous(), axis=self.axis_key)
            self._quantized_values = self._quantize(values_to_return.contiguous(), axis=self.axis_value)
            self.keys = torch.tensor([], dtype=key_states.dtype, device=key_states.device)
            self.values = torch.tensor([], dtype=key_states.dtype, device=key_states.device)

        return keys_to_return, values_to_return

    @abstractmethod
    def _quantize(self, tensor, axis): ...

    @abstractmethod
    def _dequantize(self, q_tensor): ...

    def get_seq_length(self) -> int:
        """Returns the number of tokens seen so far."""
        return self.cumulative_length
class QuantoQuantizedLayer(QuantizedLayer):
    """Quantized cache layer using `optimum.quanto` to quantize and dequantize the states."""

    def __init__(
        self,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        super().__init__(
            nbits=nbits,
            axis_key=axis_key,
            axis_value=axis_value,
            q_group_size=q_group_size,
            residual_length=residual_length,
        )

        if not is_quanto_greater("0.2.5", accept_dev=True):
            raise ImportError(
                "You need optimum-quanto package version to be greater or equal than 0.2.5 to use `QuantoQuantizedCache`. "
            )
        # We need to import quanto here to avoid circular imports due to optimum/quanto/models/transformers_models.py
        from optimum.quanto import MaxOptimizer, qint2, qint4

        if self.nbits not in (2, 4):
            raise ValueError(f"`nbits` for `quanto` backend has to be one of [`2`, `4`] but got {self.nbits}")

        if self.axis_key not in (0, -1):
            raise ValueError(f"`axis_key` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_key}")

        if self.axis_value not in (0, -1):
            raise ValueError(
                f"`axis_value` for `quanto` backend has to be one of [`0`, `-1`] but got {self.axis_value}"
            )

        if self.nbits == 4:
            self.qtype = qint4
        else:
            self.qtype = qint2
        self.optimizer = MaxOptimizer()  # hardcode as it's the only one for per-channel quantization

    def _quantize(self, tensor, axis):
        """Quantize `tensor` along `axis` with the configured qtype and group size."""
        from optimum.quanto import quantize_weight

        scale, zeropoint = self.optimizer(tensor, self.qtype, axis, self.q_group_size)
        return quantize_weight(tensor, self.qtype, axis, scale, zeropoint, self.q_group_size)

    def _dequantize(self, qtensor):
        """Return the dequantized version of `qtensor`."""
        return qtensor.dequantize()
class HQQQuantizedLayer(QuantizedLayer):
    """Quantized cache layer using the `hqq` quantizer to quantize and dequantize the states."""

    def __init__(
        self,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        super().__init__(
            nbits=nbits,
            axis_key=axis_key,
            axis_value=axis_value,
            q_group_size=q_group_size,
            residual_length=residual_length,
        )

        if not is_hqq_available():
            raise ImportError("You need to install `hqq` to use `HQQQuantizedLayer`")

        if self.nbits not in (1, 2, 3, 4, 8):
            raise ValueError(
                f"`nbits` for `HQQ` backend has to be one of [`1`, `2`, `3`, `4`, `8`] but got {self.nbits}"
            )

        if self.axis_key not in (0, 1):
            raise ValueError(f"`axis_key` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_key}")

        if self.axis_value not in (0, 1):
            raise ValueError(f"`axis_value` for `HQQ` backend has to be one of [`0`, `1`] but got {self.axis_value}")

        self.quantizer = HQQQuantizer

    def _quantize(self, tensor, axis):
        """Quantize `tensor` along `axis`, returning the quantized tensor together with its metadata dict."""
        # Device and dtype are taken from the original precision storage of this layer
        target_device, target_dtype = self.keys.device, self.keys.dtype
        qtensor, meta = self.quantizer.quantize(
            tensor,
            axis=axis,
            device=target_device,
            compute_dtype=target_dtype,
            nbits=self.nbits,
            group_size=self.q_group_size,
        )
        meta["compute_dtype"] = target_dtype
        self.quantizer.cuda(qtensor, meta=meta, device=target_device)  # Move to device and cast to dtype
        for entry in ("scale", "zero"):
            meta[entry] = meta[entry].to(qtensor.device)
        return qtensor, meta

    def _dequantize(self, qtensor):
        """Dequantize a `(quantized tensor, metadata)` pair as returned by `_quantize`."""
        quant_tensor, meta = qtensor
        return self.quantizer.dequantize(quant_tensor, meta)
- class Cache:
- """
- A `Cache` is mostly a list of `CacheLayerMixin` objects, one per model layer. It serves as a container for
- the Cache of each layer.
- Args:
- layers (`Optional`, *optional*):
- A list of pre-created `CacheLayerMixin`. If omitted (`None`), then `layer_class_to_replicate` will
- be used.
- layer_class_to_replicate (`type[CacheLayerMixin]`, *optional*):
- Only used if `layers` is omitted (`None`), in which case it will be used as the base class for each layer,
- and the layers will be added lazily as soon as `update` is called with a `layer_idx` greater than the current
- list of layers.
- offloading (`bool`, *optional*, defaults to `False`):
- Whether to perform offloading of the layers to `cpu`, to save GPU memory.
- offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
- If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
- usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).
- """
def __init__(
    self,
    layers: Optional[list[CacheLayerMixin]] = None,
    layer_class_to_replicate: Optional[type[CacheLayerMixin]] = None,
    offloading: bool = False,
    offload_only_non_sliding: bool = True,
):
    """
    Build the cache from either a list of pre-created layers or a layer class to replicate (exactly one of them).

    Raises:
        ValueError: If both `layers` and `layer_class_to_replicate` are given, or if neither is.
    """
    has_layers = layers is not None
    has_layer_class = layer_class_to_replicate is not None

    if has_layers and has_layer_class:
        raise ValueError(
            "You can construct a Cache either from a list `layers` of all the predefined `CacheLayer`, or from a "
            "`layer_class_to_replicate`, in which case the Cache will append a new layer corresponding to "
            "`layer_class_to_replicate` for each new call to `update` with an idx not already in the Cache."
        )
    if not (has_layers or has_layer_class):
        raise ValueError(
            "You should provide exactly one of `layers` or `layer_class_to_replicate` to initialize a Cache."
        )

    self.layers = layers if has_layers else []
    self.layer_class_to_replicate = layer_class_to_replicate
    self.offloading = offloading

    # The offloading attributes (including the side stream) only exist when offloading is enabled
    if self.offloading:
        self.only_non_sliding = offload_only_non_sliding
        if _is_torch_greater_or_equal_than_2_7:
            self.prefetch_stream = torch.Stream()
        else:
            self.prefetch_stream = torch.cuda.Stream()
- def __repr__(self):
- return f"{self.__class__.__name__}(layers={self.layers})"
def prefetch(self, layer_idx: int, only_non_sliding: bool = True):
    """
    Prefetch a given layer on its device. If `only_non_sliding` is True, the first non-sliding layer at or after
    `layer_idx` is prefetched instead, circling back to the beginning when there is none. Otherwise, an
    out-of-range `layer_idx` falls back to the first layer.
    The prefetch runs on the dedicated prefetch stream rather than the default one, to avoid blocking.
    """
    if not only_non_sliding:
        if layer_idx >= len(self.layers):
            layer_idx = 0
    else:
        try:
            # Distance to the next non-sliding layer, searching from `layer_idx` onwards
            layer_idx = layer_idx + self.is_sliding[layer_idx:].index(False)
        except ValueError:
            # None left after `layer_idx`: circle back to the first non-sliding layer
            layer_idx = self.is_sliding.index(False)

    if _is_torch_greater_or_equal_than_2_7:
        stream_context = self.prefetch_stream
    else:
        stream_context = torch.cuda.stream(self.prefetch_stream)

    with stream_context:
        self.layers[layer_idx].prefetch()
def offload(self, layer_idx: int, only_non_sliding: bool = True):
    """
    Offload the layer at `layer_idx`. When `only_non_sliding` is True, sliding layers are left untouched and
    only a non-sliding layer is offloaded. This runs on the default stream, so that all earlier computation
    in the layer's `update` method is finished.
    """
    # Sliding layers are skipped in `only_non_sliding` mode
    if only_non_sliding and self.is_sliding[layer_idx]:
        return
    self.layers[layer_idx].offload()
def update(
    self,
    key_states: torch.Tensor,
    value_states: torch.Tensor,
    layer_idx: int,
    cache_kwargs: Optional[dict[str, Any]] = None,
) -> tuple[torch.Tensor, torch.Tensor]:
    """
    Store the new `key_states` and `value_states` in the layer `layer_idx` and return that layer's result.

    Parameters:
        key_states (`torch.Tensor`):
            The new key states to cache.
        value_states (`torch.Tensor`):
            The new value states to cache.
        layer_idx (`int`):
            The index of the layer to cache the states for.
        cache_kwargs (`dict[str, Any]`, *optional*):
            Additional arguments for the cache subclass. These are specific to each subclass and allow new types of
            cache to be created.

    Return:
        A tuple containing the updated key and value states.
    """
    # Lazy mode (no predefined `layers`): create as many layers as needed to make `layer_idx` valid
    if self.layer_class_to_replicate is not None:
        missing_layers = layer_idx + 1 - len(self.layers)
        for _ in range(missing_layers):
            self.layers.append(self.layer_class_to_replicate())

    if self.offloading:
        # Make the default stream wait for any pending prefetch, then start prefetching the next layer
        default_stream = torch.cuda.default_stream(key_states.device)
        default_stream.wait_stream(self.prefetch_stream)
        self.prefetch(layer_idx + 1, self.only_non_sliding)

    current_layer = self.layers[layer_idx]
    updated_keys, updated_values = current_layer.update(key_states, value_states, cache_kwargs)

    if self.offloading:
        self.offload(layer_idx, self.only_non_sliding)

    return updated_keys, updated_values
def early_initialization(
    self, batch_size: int, num_heads: int, head_dim: int, dtype: torch.dtype, device: torch.device
):
    """
    Eagerly initialize every layer instead of waiting for the first `update` call.
    This is useful for our `export` recipes, as `export` needs everything in advance.
    """
    # The layers are initialized from a tensor: build a placeholder carrying the right dimensions, dtype and
    # device. Its -2 dimension has size 0, so the placeholder itself holds no data.
    placeholder_shape = (batch_size, num_heads, 0, head_dim)
    placeholder_keys = torch.zeros(placeholder_shape, dtype=dtype, device=device)
    for cache_layer in self.layers:
        cache_layer.lazy_initialization(placeholder_keys)
def get_seq_length(self, layer_idx: int = 0) -> int:
    """Return the sequence length stored for `layer_idx`, or 0 if that layer does not exist yet."""
    if layer_idx < len(self.layers):
        return self.layers[layer_idx].get_seq_length()
    return 0
def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
    """
    Return the tuple (kv_length, kv_offset) for the layer at `layer_idx`, i.e. the length and offset used to
    prepare the mask of that layer.
    """
    if layer_idx < len(self.layers):
        return self.layers[layer_idx].get_mask_sizes(cache_position)
    # The layer was not created yet (layers created at runtime, as in DynamicCache): the size is simply
    # the first dimension of `cache_position`, without offset
    query_length = cache_position.shape[0]
    return query_length, 0
def get_max_cache_shape(self, layer_idx: int = 0) -> int:
    """Return the maximum sequence length of the layer at `layer_idx`, or -1 if that layer does not exist yet."""
    if layer_idx < len(self.layers):
        return self.layers[layer_idx].get_max_cache_shape()
    # Layer not created yet (layers created at runtime, as in DynamicCache): report -1
    return -1
def reset(self):
    """Call `reset` on every layer of the cache."""
    for cache_layer in self.layers:
        cache_layer.reset()
def reorder_cache(self, beam_idx: torch.LongTensor):
    """Forward `beam_idx` to the `reorder_cache` method of every layer (used for beam search)."""
    for cache_layer in self.layers:
        cache_layer.reorder_cache(beam_idx)
def crop(self, max_length: int):
    """Forward `max_length` to the `crop` method of every layer."""
    for cache_layer in self.layers:
        cache_layer.crop(max_length)
def batch_repeat_interleave(self, repeats: int):
    """Forward `repeats` to the `batch_repeat_interleave` method of every layer."""
    for cache_layer in self.layers:
        cache_layer.batch_repeat_interleave(repeats)
def batch_select_indices(self, indices: torch.Tensor):
    """Forward `indices` to the `batch_select_indices` method of every layer."""
    for cache_layer in self.layers:
        cache_layer.batch_select_indices(indices)
@property
def max_batch_size(self) -> int:
    """
    Return the `max_batch_size` shared by all layers.

    Raises:
        ValueError: If the layers do not all report the same `max_batch_size`.
    """
    batch_sizes = []
    for cache_layer in self.layers:
        batch_sizes.append(cache_layer.max_batch_size)
    distinct_sizes = set(batch_sizes)
    if len(distinct_sizes) > 1:
        raise ValueError(f"Max batch size is not consistent across layers: {batch_sizes}")
    return batch_sizes[0]
@property
def max_cache_len(self) -> int:
    """Return the largest `max_cache_len` found among the layers."""
    return max(cache_layer.max_cache_len for cache_layer in self.layers)
@property
def is_compileable(self) -> bool:
    """Return True only if the cache has at least one layer and every layer is compileable."""
    # The explicit length check matters for caches creating their layers lazily: `all([])` alone is True
    return len(self.layers) > 0 and all(cache_layer.is_compileable for cache_layer in self.layers)
@property
def is_initialized(self) -> bool:
    """Return True only if the cache has at least one layer and every layer is initialized."""
    if len(self.layers) == 0:
        return False
    return all(cache_layer.is_initialized for cache_layer in self.layers)
@property
def is_sliding(self) -> list[bool]:
    """Return, for each layer, its `is_sliding` attribute (False for layers that do not define it)."""
    sliding_flags = []
    for cache_layer in self.layers:
        sliding_flags.append(getattr(cache_layer, "is_sliding", False))
    return sliding_flags
- def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
- """
- Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the
- sequence length.
- """
- if layer_idx < len(self.layers):
- return self.layers[layer_idx].keys, self.layers[layer_idx].values
- else:
- raise KeyError(
- f"Cache only has {len(self.layers)} layers, attempted to access layer with index {layer_idx}"
- )
- def __iter__(self):
- """
- Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over
- keys and values
- """
- for layer_idx in range(len(self)):
- yield (self.layers[layer_idx].keys, self.layers[layer_idx].values)
- def __len__(self):
- """
- This value corresponds to the number of layers in the model.
- """
- # Note: for DynamicCache, layers are initialized lazily, so this will not be accurate before the first
- # forward through all the layers
- return len(self.layers)
class DynamicCache(Cache):
    """
    A cache that grows dynamically as more tokens are generated. This is the default for generative models.
    It stores the key and value states as a list of `CacheLayer`, one for each layer. The expected shape for each tensor
    in the `CacheLayer`s is `[batch_size, num_heads, seq_len, head_dim]`.
    If a config is passed, it will additionally check for sliding or hybrid cache structure, greatly reducing the
    memory requirement of the cached tensors to `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Args:
        ddp_cache_data (`Iterable[tuple[torch.Tensor, torch.Tensor]]`, *optional*):
            It was originally added for compatibility with `torch.distributed` (DDP). In a nutshell, it is
            `map(gather_map, zip(*caches))`, i.e. each item in the iterable contains the key and value states
            for a layer gathered across replicas by torch.distributed (shape=[global batch size, num_heads, seq_len, head_dim]).
            Note: it needs to be the 1st arg as well to work correctly
        config (`PretrainedConfig`, *optional*):
            The config of the model for which this Cache will be used. If passed, it will be used to check for sliding
            or hybrid layer structure, greatly reducing the memory requirement of the cached tensors to
            `[batch_size, num_heads, min(seq_len, sliding_window), head_dim]`.
        offloading (`bool`, *optional*, defaults to `False`):
            Whether to perform offloading of the layers to `cpu`, to save GPU memory.
        offload_only_non_sliding (`bool`, *optional*, defaults to `False`):
            If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
            usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, DynamicCache

    >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
    >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")

    >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> past_key_values = DynamicCache(config=model.config)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> outputs.past_key_values # access cache filled with key/values from generation
    ```
    """

    def __init__(
        self,
        ddp_cache_data: Optional[Iterable[tuple[torch.Tensor, torch.Tensor]]] = None,
        config: Optional[PretrainedConfig] = None,
        offloading: bool = False,
        offload_only_non_sliding: bool = False,
    ):
        layers = []
        # If a config is passed, use it to infer the layer types and initialize accordingly
        if config is not None:
            decoder_config = config.get_text_config(decoder=True)
            sliding_window = getattr(decoder_config, "sliding_window", None) or getattr(
                decoder_config, "attention_chunk_size", None
            )
            layer_types = getattr(decoder_config, "layer_types", None)
            if layer_types is None:
                layer_types = [
                    "sliding_attention" if sliding_window is not None else "full_attention"
                    for _ in range(decoder_config.num_hidden_layers)
                ]
            # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
            num_kv_shared_layers = getattr(decoder_config, "num_kv_shared_layers", 0)
            # Only slice for a non-zero count: `layer_types[:-0]` is an empty list and would drop every layer
            if num_kv_shared_layers:
                layer_types = layer_types[:-num_kv_shared_layers]

            for layer_type in layer_types:
                # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
                # states they should return - only the mask changes to make them different at the end!
                if layer_type in ("sliding_attention", "chunked_attention"):
                    layers.append(DynamicSlidingWindowLayer(sliding_window=sliding_window))
                else:
                    layers.append(DynamicLayer())

        # In this case, use the passed data to already fill in the Cache
        if ddp_cache_data is not None:
            # Init all the layers with the data
            for layer_idx, (key_states, value_states) in enumerate(ddp_cache_data):
                # If the config was not passed above, initialize a DynamicLayer for each entry of the ddp_data
                if config is None:
                    layers.append(DynamicLayer())
                # Update the layer with the data
                _, _ = layers[layer_idx].update(key_states, value_states)

        # If neither of config nor ddp_data was passed, then simply lazy init a full cache of DynamicLayer
        if len(layers) == 0:
            super().__init__(
                layer_class_to_replicate=DynamicLayer,
                offloading=offloading,
                offload_only_non_sliding=offload_only_non_sliding,
            )
        else:
            super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)

    def to_legacy_cache(self) -> tuple[tuple[torch.Tensor, torch.Tensor]]:
        """
        Converts the `Cache` instance into its equivalent in the legacy cache format, i.e. a tuple holding
        one `(keys, values)` tuple per layer. Used for backward compatibility.
        """
        return tuple((layer.keys, layer.values) for layer in self.layers)

    @classmethod
    def from_legacy_cache(cls, past_key_values: tuple[tuple[torch.Tensor, torch.Tensor]]) -> "DynamicCache":
        """
        Converts a cache in the legacy cache format into an equivalent `Cache`. Used for
        backward compatibility.

        If `past_key_values` is `None`, a warning is logged and an empty cache is returned.
        """
        cache = cls()
        if past_key_values is None:
            logger.warning_once("past_key_values should not be None in from_legacy_cache()")
        else:
            for layer_idx, (key_states, value_states) in enumerate(past_key_values):
                cache.update(key_states, value_states, layer_idx)
        return cache
class StaticCache(Cache):
    """
    Static Cache class to be used with `torch.compile(model)` and `torch.export()`. It will check the `config`
    for potential hybrid cache structure, and initialize each layer accordingly.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Args:
        config (`PretrainedConfig`):
            The config of the model for which this Cache will be used. It will be used to check for sliding
            or hybrid layer structure, and initialize each layer accordingly.
        max_cache_len (`int`):
            The maximum number of tokens that this Cache should hold.
        offloading (`bool`, *optional*, defaults to `False`):
            Whether to perform offloading of the layers to `cpu`, to save GPU memory.
        offload_only_non_sliding (`bool`, *optional*, defaults to `True`):
            If `offloading` is `True`, this further decides if only the non-sliding layers will be offloaded (because
            usually the sliding layers are small in size, so there is no need to offload them, and skipping it is faster).

    Example:

    ```python
    >>> from transformers import AutoTokenizer, AutoModelForCausalLM, StaticCache

    >>> model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-chat-hf")
    >>> tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-chat-hf")

    >>> inputs = tokenizer(text="My name is Llama", return_tensors="pt")

    >>> # Prepare a cache class and pass it to model's forward
    >>> # Leave empty space for 10 new tokens, which can be used when calling forward iteratively 10 times to generate
    >>> max_generated_length = inputs.input_ids.shape[1] + 10
    >>> past_key_values = StaticCache(config=model.config, max_cache_len=max_generated_length)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> outputs.past_key_values # access cache filled with key/values from generation
    StaticCache()
    ```
    """

    # Pass-in kwargs as well to avoid crashing for BC (it used more arguments before)
    def __init__(
        self,
        config: PretrainedConfig,
        max_cache_len: int,
        offloading: bool = False,
        offload_only_non_sliding: bool = True,
        **kwargs,
    ):
        config = config.get_text_config(decoder=True)
        layer_types = getattr(config, "layer_types", None)
        # If `layer_types` is not explicitly provided, infer if the model is fully sliding
        if layer_types is None:
            if getattr(config, "sliding_window", None) is not None:
                layer_types = ["sliding_attention" for _ in range(config.num_hidden_layers)]
            elif getattr(config, "attention_chunk_size", None) is not None:
                layer_types = ["chunked_attention" for _ in range(config.num_hidden_layers)]
            else:
                layer_types = ["full_attention" for _ in range(config.num_hidden_layers)]

        # Some models have shared layers thus no cache is needed for them (e.g. Gemma3n)
        num_kv_shared_layers = getattr(config, "num_kv_shared_layers", 0)
        # Only slice for a non-zero count: `layer_types[:-0]` is an empty list and would leave the cache
        # without any layer
        if num_kv_shared_layers:
            layer_types = layer_types[:-num_kv_shared_layers]

        layers = []
        for layer_type in layer_types:
            if layer_type == "sliding_attention":
                layer = StaticSlidingWindowLayer(max_cache_len=max_cache_len, sliding_window=config.sliding_window)
            elif layer_type == "chunked_attention":
                # From a cache point of view, both sliding and chunked are the same in how they should behave and how many
                # states they should return - only the mask changes to make them different at the end!
                layer = StaticSlidingWindowLayer(
                    max_cache_len=max_cache_len, sliding_window=config.attention_chunk_size
                )
            else:
                layer = StaticLayer(max_cache_len=max_cache_len)
            layers.append(layer)

        super().__init__(layers=layers, offloading=offloading, offload_only_non_sliding=offload_only_non_sliding)
class QuantizedCache(Cache):
    """
    A quantizer cache similar to what is described in the
    [KIVI: A Tuning-Free Asymmetric 2bit Quantization for KV Cache paper](https://huggingface.co/papers/2402.02750).
    It allows the model to generate longer sequence length without allocating too much memory for keys and values
    by applying quantization.

    The cache has two types of storage, one for original precision and one for the
    quantized cache. A `residual length` is set as a maximum capacity for the original precision cache. When the
    length goes beyond maximum capacity, the original precision cache is discarded and moved into the quantized cache.
    The quantization is done per-channel with a set `q_group_size` for both keys and values, in contrast to what was
    described in the paper.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Args:
        backend (`str`):
            The quantization backend to use. One of `("quanto", "hqq")`.
        config (`PretrainedConfig`):
            The config of the model for which this Cache will be used.
        nbits (`int`, *optional*, defaults to 4):
            The number of bits for quantization.
        axis_key (`int`, *optional*, defaults to 0):
            The axis on which to quantize the keys.
        axis_value (`int`, *optional*, defaults to 0):
            The axis on which to quantize the values.
        q_group_size (`int`, *optional*, defaults to 64):
            Quantization is done per-channel according to a set `q_group_size` for both keys and values.
        residual_length (`int`, *optional*, defaults to 128):
            Maximum capacity for the original precision cache
    """

    def __init__(
        self,
        backend: str,
        config: PretrainedConfig,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        # Reject unsupported backends before touching the config
        if backend not in ("quanto", "hqq"):
            raise ValueError(f"Unknown quantization backend `{backend}`")
        layer_class = QuantoQuantizedLayer if backend == "quanto" else HQQQuantizedLayer

        config = config.get_text_config(decoder=True)
        # One quantized layer per hidden layer of the (text decoder) config
        layers = []
        for _ in range(config.num_hidden_layers):
            layers.append(layer_class(nbits, axis_key, axis_value, q_group_size, residual_length))
        super().__init__(layers=layers)
class EncoderDecoderCache(Cache):
    """
    Base, abstract class for all encoder-decoder caches. Can be used to hold combinations of self-attention and
    cross-attention caches.

    See `Cache` for details on common methods that are implemented by all cache classes.

    Args:
        caches (`Iterable`):
            Usually an iterable of length 2, containing 2 `Cache` objects, the first one for self-attention, the
            second one for cross-attention. Can optionally also be an iterable of length 1, containing a
            `tuple[tuple[torch.Tensor]]` (usually used for compatibility with torch dp and ddp).

    Example:

    ```python
    >>> from transformers import AutoProcessor, AutoModelForCausalLM, DynamicCache, EncoderDecoderCache

    >>> model = AutoModelForCausalLM.from_pretrained("openai/whisper-small")
    >>> processor = AutoProcessor.from_pretrained("openai/whisper-small")

    >>> inputs = processor(audio=YOUR-AUDIO, return_tensors="pt")

    >>> # Prepare cache classes for encoder and decoder and pass it to model's forward
    >>> self_attention_cache = DynamicCache(config=model.config)
    >>> cross_attention_cache = DynamicCache(config=model.config)
    >>> past_key_values = EncoderDecoderCache(self_attention_cache, cross_attention_cache)
    >>> outputs = model(**inputs, past_key_values=past_key_values, use_cache=True)
    >>> outputs.past_key_values # access cache filled with key/values from generation
    EncoderDecoderCache()
    ```
    """

    def __init__(self, *caches) -> None:
        """
        Wrap a self-attention and a cross-attention cache.

        Raises:
            TypeError: If two arguments are passed and one of them is not a `Cache`.
            ValueError: If the number of arguments is neither 1 nor 2.
        """
        # For dp and ddp support, if only one argument is passed, it should be an iterable of tuples of tensors
        if len(caches) == 1:
            self.self_attention_cache = DynamicCache()
            self.cross_attention_cache = DynamicCache()
            # Populate cache from the iterable
            for layer_idx, key_value_states in enumerate(caches[0]):
                key_states, value_states = key_value_states[:2]
                self.self_attention_cache.update(key_states, value_states, layer_idx)
                # Entries beyond the first two are the cross-attention states
                if len(key_value_states) > 2:
                    key_states, value_states = key_value_states[2:]
                    self.cross_attention_cache.update(key_states, value_states, layer_idx)
        # Otherwise, we should get two arguments, a self-attention cache and a cross-attention cache
        elif len(caches) == 2:
            if not isinstance(caches[0], Cache) or not isinstance(caches[1], Cache):
                raise TypeError(f"One of the two arguments is not a Cache: {type(caches[0]) = }, {type(caches[1]) = }")
            self.self_attention_cache = caches[0]
            self.cross_attention_cache = caches[1]
        # Error case
        else:
            raise ValueError(f"Expected 1 or 2 arguments, got {len(caches)}")

        # Per-layer flag: True when the cross-attention cache already holds states for that layer
        self.is_updated = {}
        for layer_idx in range(len(self.cross_attention_cache)):
            self.is_updated[layer_idx] = bool(self.cross_attention_cache.get_seq_length(layer_idx) > 0)

    def __repr__(self) -> str:
        return (
            f"{self.__class__.__name__}(self_attention_cache={self.self_attention_cache}, cross_attention_cache="
            f"{self.cross_attention_cache})"
        )

    def __iter__(self):
        """
        Support for backwards-compatible `past_key_values` iteration, e.g. `for x in past_key_values:` to iterate over
        keys and values. Each item holds the self-attention keys/values followed by the cross-attention keys/values.
        """
        for layer_idx in range(len(self)):
            yield (
                self.self_attention_cache.layers[layer_idx].keys,
                self.self_attention_cache.layers[layer_idx].values,
                self.cross_attention_cache.layers[layer_idx].keys,
                self.cross_attention_cache.layers[layer_idx].values,
            )

    def __getitem__(self, layer_idx: int) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Support for backwards-compatible `past_key_values` indexing, e.g. `past_key_values[0][0].shape[2]` to get the
        sequence length.

        Raises:
            KeyError: If `layer_idx` is not smaller than the number of layers.
        """
        if layer_idx < len(self):
            return (
                self.self_attention_cache.layers[layer_idx].keys,
                self.self_attention_cache.layers[layer_idx].values,
                self.cross_attention_cache.layers[layer_idx].keys,
                self.cross_attention_cache.layers[layer_idx].values,
            )
        else:
            raise KeyError(f"Cache only has {len(self)} layers, attempted to access layer with index {layer_idx}")

    def __len__(self):
        """
        Support for backwards-compatible `past_key_values` length, e.g. `len(past_key_values)`. This value corresponds
        to the number of layers in the model (it is the length of the self-attention cache).
        """
        return len(self.self_attention_cache)

    def to_legacy_cache(self) -> tuple[tuple[torch.Tensor]]:
        """Converts the `EncoderDecoderCache` instance into its equivalent in the legacy cache format."""
        # Without cross-attention states, the legacy format is simply the one of the self-attention cache
        if len(self.cross_attention_cache) == 0:
            return self.self_attention_cache.to_legacy_cache()
        return tuple(
            self_attn + cross_attn
            for self_attn, cross_attn in zip(
                self.self_attention_cache.to_legacy_cache(), self.cross_attention_cache.to_legacy_cache()
            )
        )

    @classmethod
    def from_legacy_cache(
        cls, past_key_values: Optional[Iterable[tuple[torch.FloatTensor, ...]]]
    ) -> "EncoderDecoderCache":
        """
        Converts a cache in the legacy cache format into an equivalent `EncoderDecoderCache`.

        If `past_key_values` is `None`, a warning is logged and an empty cache is returned.
        """
        cache = cls(DynamicCache(), DynamicCache())
        if past_key_values is None:
            logger.warning_once("past_key_values should not be None in from_legacy_cache()")
        else:
            for layer_idx, key_value_states in enumerate(past_key_values):
                key_states, value_states = key_value_states[:2]
                cache.self_attention_cache.update(key_states, value_states, layer_idx)
                # Entries beyond the first two are the cross-attention states
                if len(key_value_states) > 2:
                    key_states, value_states = key_value_states[2:]
                    cache.cross_attention_cache.update(key_states, value_states, layer_idx)
                    cache.is_updated[layer_idx] = True
        return cache

    def get_seq_length(self, layer_idx: int = 0) -> int:
        """Returns the sequence length of the cached states. A layer index can be optionally passed."""
        return self.self_attention_cache.get_seq_length(layer_idx)

    def reset(self):
        """Reset both wrapped caches and mark every tracked layer as not updated."""
        self.self_attention_cache.reset()
        self.cross_attention_cache.reset()
        for layer_idx in self.is_updated:
            self.is_updated[layer_idx] = False

    def reorder_cache(self, beam_idx: torch.LongTensor):
        """Reorders the cache for beam search, given the selected beam indices."""
        self.self_attention_cache.reorder_cache(beam_idx)
        self.cross_attention_cache.reorder_cache(beam_idx)

    def check_dynamic_cache(self, method: str):
        """
        Ensure both wrapped caches are `DynamicCache` instances.

        Raises:
            ValueError: If either wrapped cache is not a `DynamicCache`; `method` is named in the message.
        """
        if not (
            isinstance(self.self_attention_cache, DynamicCache)
            and isinstance(self.cross_attention_cache, DynamicCache)
        ):
            raise ValueError(
                f"`{method}` is only defined for dynamic cache, got {self.self_attention_cache.__str__()} for the self "
                f"attention cache and {self.cross_attention_cache.__str__()} for the cross attention cache."
            )

    # TODO(gante, sanchit-gandhi): move following functionality into `.generate`
    def crop(self, maximum_length: int):
        """
        Crop the past key values up to a new `maximum_length` in terms of tokens. `maximum_length` can also be
        negative to remove `maximum_length` tokens. This is used in assisted decoding and contrastive search (on the Hub).
        Only the self-attention cache is cropped.
        """
        self.check_dynamic_cache(self.crop.__name__)
        self.self_attention_cache.crop(maximum_length)

    def batch_split(self, full_batch_size: int, split_size: int) -> "list[EncoderDecoderCache]":
        """
        Split the current instance into a list of `DynamicCache` by the batch size. This will be used by
        `_split_model_inputs()` in `generation.utils`
        """
        self.check_dynamic_cache(self.batch_split.__name__)
        # NOTE(review): this relies on the wrapped caches exposing `batch_split`; no such method is visible on
        # `DynamicCache` in this part of the file - TODO confirm it is defined elsewhere
        self_attention_cache = self.self_attention_cache.batch_split(full_batch_size, split_size)
        cross_attention_cache = self.cross_attention_cache.batch_split(full_batch_size, split_size)

        out = []
        for self_attn, cross_attn in zip(self_attention_cache, cross_attention_cache):
            out.append(EncoderDecoderCache(self_attn, cross_attn))
        return out

    def batch_repeat_interleave(self, repeats: int):
        """Repeat the cache `repeats` times in the batch dimension. Used in contrastive search (on the Hub)."""
        self.check_dynamic_cache(self.batch_repeat_interleave.__name__)
        self.self_attention_cache.batch_repeat_interleave(repeats)
        self.cross_attention_cache.batch_repeat_interleave(repeats)

    def batch_select_indices(self, indices: torch.Tensor):
        """Only keep the `indices` in the batch dimension of the cache. Used in contrastive search (on the Hub)."""
        self.check_dynamic_cache(self.batch_select_indices.__name__)
        self.self_attention_cache.batch_select_indices(indices)
        self.cross_attention_cache.batch_select_indices(indices)

    def get_max_cache_shape(self, layer_idx: int = 0) -> int:
        """
        Returns the maximum sequence length (i.e. max capacity) of the cache object, as reported by the
        self-attention cache. A layer index can be optionally passed, as for `Cache.get_max_cache_shape`.
        """
        return self.self_attention_cache.get_max_cache_shape(layer_idx)

    def get_mask_sizes(self, cache_position: torch.Tensor, layer_idx: int) -> tuple[int, int]:
        """Return the (kv_length, kv_offset) mask sizes of the self-attention cache for `layer_idx`."""
        return self.self_attention_cache.get_mask_sizes(cache_position, layer_idx)

    @property
    def is_sliding(self):
        """Return the `is_sliding` flags of the self-attention cache."""
        return self.self_attention_cache.is_sliding

    @property
    def is_compileable(self) -> bool:
        """Return whether the self-attention cache is compileable."""
        return self.self_attention_cache.is_compileable
- ### Deprecated classes
class SlidingWindowLayer(StaticSlidingWindowLayer):
    """Deprecated name for `StaticSlidingWindowLayer`; instantiating it logs a deprecation warning."""

    def __init__(self, max_cache_len: int, sliding_window: int):
        deprecation_message = (
            "`SlidingWindowLayer` is deprecated and will be removed in version v4.59 "
            "Use `StaticSlidingWindowLayer` instead, which is a better name for it."
        )
        logger.warning_once(deprecation_message)
        super().__init__(max_cache_len, sliding_window)
class ChunkedSlidingLayer(StaticSlidingWindowLayer):
    """Deprecated name for `StaticSlidingWindowLayer`; instantiating it logs a deprecation warning."""

    def __init__(self, max_cache_len: int, sliding_window: int):
        deprecation_message = (
            "`ChunkedSlidingLayer` is deprecated and will be removed in version v4.59 "
            "Use `StaticSlidingWindowLayer` instead, which has the exact same functionalities."
        )
        logger.warning_once(deprecation_message)
        super().__init__(max_cache_len, sliding_window)
class OffloadedCache(DynamicCache):
    """Deprecated: a `DynamicCache` created with `offloading=True`; instantiating it logs a deprecation warning."""

    def __init__(self) -> None:
        deprecation_message = (
            "`OffloadedCache` is deprecated and will be removed in version v4.59 "
            "Use `DynamicCache(offloading=True)` instead"
        )
        logger.warning_once(deprecation_message)
        super().__init__(offloading=True)
class OffloadedStaticCache(StaticCache):
    """Deprecated: a `StaticCache` created with `offloading=True`; instantiating it logs a deprecation warning."""

    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
        # Extra positional and keyword arguments are accepted but not forwarded
        deprecation_message = (
            "`OffloadedStaticCache` is deprecated and will be removed in version v4.59 "
            "Use `StaticCache(..., offloading=True)` instead"
        )
        logger.warning_once(deprecation_message)
        super().__init__(config, max_cache_len, offloading=True)
class SlidingWindowCache(StaticCache):
    """Deprecated name for `StaticCache`; instantiating it logs a deprecation warning."""

    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
        # Extra positional and keyword arguments are accepted but not forwarded
        deprecation_message = (
            "`SlidingWindowCache` is deprecated and will be removed in version v4.59 "
            "Use `StaticCache(...)` instead which will correctly infer the type of each layer."
        )
        logger.warning_once(deprecation_message)
        super().__init__(config, max_cache_len)
class HybridCache(StaticCache):
    """Deprecated name for `StaticCache`; instantiating it logs a deprecation warning."""

    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
        # Extra positional and keyword arguments are accepted but not forwarded
        deprecation_message = (
            "`HybridCache` is deprecated and will be removed in version v4.59 "
            "Use `StaticCache(...)` instead which will correctly infer the type of each layer."
        )
        logger.warning_once(deprecation_message)
        super().__init__(config, max_cache_len)
class HybridChunkedCache(StaticCache):
    """Deprecated name for `StaticCache`; instantiating it logs a deprecation warning."""

    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
        # Extra positional and keyword arguments are accepted but not forwarded
        deprecation_message = (
            "`HybridChunkedCache` is deprecated and will be removed in version v4.59 "
            "Use `StaticCache(...)` instead which will correctly infer the type of each layer."
        )
        logger.warning_once(deprecation_message)
        super().__init__(config, max_cache_len)
class OffloadedHybridCache(StaticCache):
    """Deprecated: a `StaticCache` created with `offloading=True`; instantiating it logs a deprecation warning."""

    def __init__(self, config: PretrainedConfig, max_cache_len: int, *args, **kwargs):
        # Extra positional and keyword arguments are accepted but not forwarded.
        # The message must name the real keyword (`offloading`): `StaticCache` accepts arbitrary `**kwargs`,
        # so following a hint to pass `offload=True` would not enable offloading.
        logger.warning_once(
            "`OffloadedHybridCache` is deprecated and will be removed in version v4.59 "
            "Use `StaticCache(..., offloading=True)` instead which will correctly infer the type of each layer."
        )
        super().__init__(config=config, max_cache_len=max_cache_len, offloading=True)
class QuantoQuantizedCache(QuantizedCache):
    """Deprecated: a `QuantizedCache` using the `"quanto"` backend; instantiating it logs a deprecation warning."""

    def __init__(
        self,
        config: PretrainedConfig,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        deprecation_message = (
            "`QuantoQuantizedCache` is deprecated and will be removed in version v4.59 "
            "Use `QuantizedCache(backend='quanto', ...)` instead."
        )
        logger.warning_once(deprecation_message)
        super().__init__(
            backend="quanto",
            config=config,
            nbits=nbits,
            axis_key=axis_key,
            axis_value=axis_value,
            q_group_size=q_group_size,
            residual_length=residual_length,
        )
class HQQQuantizedCache(QuantizedCache):
    """Deprecated: a `QuantizedCache` using the `"hqq"` backend; instantiating it logs a deprecation warning."""

    def __init__(
        self,
        config: PretrainedConfig,
        nbits: int = 4,
        axis_key: int = 0,
        axis_value: int = 0,
        q_group_size: int = 64,
        residual_length: int = 128,
    ):
        deprecation_message = (
            "`HQQQuantizedCache` is deprecated and will be removed in version v4.59 "
            "Use `QuantizedCache(backend='hqq', ...)` instead."
        )
        logger.warning_once(deprecation_message)
        super().__init__(
            backend="hqq",
            config=config,
            nbits=nbits,
            axis_key=axis_key,
            axis_value=axis_value,
            q_group_size=q_group_size,
            residual_length=residual_length,
        )
class SinkCache(Cache):
    """
    It is now a `custom_generate` repository on the Hub: https://huggingface.co/transformers-community/sink_cache.
    See [these docs](https://huggingface.co/docs/transformers/generation_strategies#custom-decoding-methods) for
    general `custom_generate` usage.

    Instantiating this class always raises `NotImplementedError`.
    """

    # TODO (joao, manuel): Remove this class in v4.59.0
    def __init__(self, **kwargs) -> None:
        moved_message = (
            "`SinkCache` has been moved as a `custom_generate` repository on the Hub: "
            "https://huggingface.co/transformers-community/sink_cache. See the repository for usage examples."
        )
        raise NotImplementedError(moved_message)
|