cache.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from collections import deque
  16. from math import floor, gcd, sqrt
  17. from typing import Optional, Union
  18. import torch
  19. from ...configuration_utils import PretrainedConfig
  20. from ...generation.configuration_utils import GenerationConfig
  21. from ...utils.metrics import attach_tracer, traced
  22. from .cache_manager import CacheAllocator, FullAttentionCacheAllocator, SlidingAttentionCacheAllocator
  23. from .requests import get_device_and_memory_breakdown, logger
  24. def group_layers_by_attn_type(config: PretrainedConfig) -> tuple[list[list[int]], list[str]]:
  25. """
  26. Group layers depending on the attention mix, according to VLLM's hybrid allocator rules:
  27. - Layers in each group need to have the same type of attention
  28. - All groups have the same number of layers
  29. For a model with the following layer types: ["sliding", "full", "full", "sliding", "full", "full", "full", "full"]
  30. We would get two groups: [0, 3] and [1, 2], [4,5], [6,7].
  31. """
  32. # If the config has no layer_type attribute, it means all layers are the same attention type
  33. layer_types = getattr(config, "layer_types", None)
  34. if layer_types is None:
  35. attn_type = "sliding_attention" if getattr(config, "sliding_window", None) is not None else "full_attention"
  36. layer_types = [attn_type for _ in range(config.num_hidden_layers)]
  37. # We then count the number of layers of each type
  38. layer_counts = {}
  39. for i, layer_type in enumerate(layer_types):
  40. layer_counts[layer_type] = layer_counts.get(layer_type, []) + [i]
  41. # The size of all groups is the greatest common divisor of the number of layers of each type
  42. group_size = gcd(*[len(indices) for indices in layer_counts.values()])
  43. # We then group the layers by type
  44. layer_groups = []
  45. for layer_type, indices in layer_counts.items():
  46. for i in range(0, len(indices), group_size):
  47. layer_groups.append(indices[i : i + group_size])
  48. # And note the layer types
  49. group_types = [layer_types[lg[0]] for lg in layer_groups]
  50. return layer_groups, group_types
  51. @attach_tracer()
  52. class PagedAttentionCache:
  53. """
  54. Manages the cache for a paged attention mechanism, inspired by VLLM's hybrid allocator. The cache relies on making
  55. groups of layers to reduce the complexity of cache management and fragmentation.
  56. The cache uses a three-level hierarchy:
  57. - Pages: The smallest unit of cache, a page has a size of [num_heads, head_size], which is the space needed to
  58. store the key or value states for one token and one layer. For a model with only full-attention layers, to store
  59. the KV cache of one token, we need `2 * num_layers` pages: key and values each take `num_layers` pages.
  60. Pages are grouped into blocks:
  61. - Blocks: A block is a collection of `block_size` pages, serving as the allocation unit to reduce management
  62. complexity and fragmentation. Cache is allocated and freed block by block, not page by page. One block is
  63. allocated to one layer group, which only has one attention type, like full-attention or sliding-attention.
  64. If all layers in the model have the same attention type, then all layers will be in the same group. There is
  65. more than one group if and only if the model has a mixed attention types, like layers with full-attention and
  66. layers with sliding-attention.
  67. - Cache tensors: The physical supports for the cache. There are as many cache tensors as there are layer in a
  68. layer group, and the shape of the cache tensor is `[num_blocks * block_size, num_heads, head_size]`.
  69. Grouping layers into groups is useful because when we allocate one block to a group N, the block allocated is the
  70. same for all layers in group N, equivalently it is allocated across all cache tensors. This allows us to
  71. efficiently allocate and free blocks, and to efficiently read and write key and value states.
  72. For instance, imagine we have 8 blocks of cache and a model with two layer groups: a full-attention group with 3
  73. layers and a sliding-attention group with 3 layers. At creation time, the physical cache tensors look like this:
  74. cache_tensor_0: □ □ □ □ □ □ □ □
  75. cache_tensor_1: □ □ □ □ □ □ □ □
  76. cache_tensor_2: □ □ □ □ □ □ □ □
  77. where □ means the blocks is not allocated to any layer group yet. We have 3 cache tensors because there are
  78. 3 layers per group.
  79. We allocate 1 block to each group, after allocation, the cache tensors look like this:
  80. cache_tensor_0: ✖ ◉ □ □ □ □ □ □
  81. cache_tensor_1: ✖ ◉ □ □ □ □ □ □
  82. cache_tensor_2: ✖ ◉ □ □ □ □ □ □
  83. where ✖ means the block is allocated to the full-attention group, and ◉ means the block is allocated to the
  84. sliding-attention group.
  85. Now, if we continue to generate, and the sliding window has been reached, we only need to allocate a new block
  86. for the full-attention group, and the cache tensors look like this:
  87. cache_tensor_0: ✖ ◉ ✖ □ □ □ □ □
  88. cache_tensor_1: ✖ ◉ ✖ □ □ □ □ □
  89. cache_tensor_2: ✖ ◉ ✖ □ □ □ □ □
  90. And after further generation, when we need a new block allocated:
  91. cache_tensor_0: ✖ ◉ ✖ ✖ □ □ □ □
  92. cache_tensor_1: ✖ ◉ ✖ ✖ □ □ □ □
  93. cache_tensor_2: ✖ ◉ ✖ ✖ □ □ □ □
  94. This would not have been possible if all layers were in the same group: we would have had to allocate a new block
  95. for the sliding-attention group, although it is not needed.
  96. """
  97. # TODO: this init is quite long, maybe a refactor is in order
  98. def __init__(
  99. self,
  100. config: PretrainedConfig,
  101. generation_config: GenerationConfig,
  102. device: torch.device,
  103. dtype: torch.dtype = torch.float16,
  104. layer_device_map: Optional[dict[int, Union[str, torch.device, int]]] = None,
  105. tp_size: Optional[int] = None,
  106. ) -> None:
  107. """Initialize a paged attention cache for efficient memory usage.
  108. Args:
  109. config: Model configuration
  110. generation_config: Generation configuration containing cache parameters
  111. device: Device for the cache tensors
  112. dtype: Data type of the cache
  113. layer_device_map: Optional mapping of layer indices to devices
  114. tp_size: Tensor parallelism size
  115. """
  116. self.config = config
  117. self.dtype = dtype
  118. self.device = device
  119. # Extract model dimensions
  120. kv_heads = getattr(config, "num_key_value_heads", None)
  121. self.num_key_value_heads: int = kv_heads if kv_heads is not None else config.num_attention_heads
  122. head_dim = getattr(config, "head_dim", None)
  123. self.head_dim: int = head_dim if head_dim is not None else config.hidden_size // config.num_attention_heads
  124. # Extract cache dimensions
  125. self.block_size = getattr(generation_config, "block_size", 32)
  126. # Group layers depending on the attention mix
  127. layer_groups, group_types = group_layers_by_attn_type(config)
  128. group_size = len(layer_groups[0])
  129. self.num_groups = len(layer_groups)
  130. self.sliding_windows = {}
  131. self.layer_index_to_group_indices = {}
  132. for i, group in enumerate(layer_groups):
  133. sliding_window = config.sliding_window if group_types[i] == "sliding_attention" else 1
  134. for j, layer in enumerate(group):
  135. self.layer_index_to_group_indices[layer] = (i, j)
  136. self.sliding_windows[layer] = sliding_window
  137. # Handle TP (or dont)
  138. if tp_size is not None and tp_size > 1:
  139. if self.num_key_value_heads % tp_size != 0:
  140. raise ValueError(
  141. f"Number of key value heads {self.num_key_value_heads} must be divisible by tensor parallel size {tp_size}."
  142. )
  143. # If the model is using tensor parallelism, we need to adjust the number of heads accordingly.
  144. # self.num_key_value_heads //= tp_size # TODO: why is this commented out?
  145. # Infer number of blocks and max batch tokens
  146. page_size = self.head_dim * self.num_key_value_heads
  147. if getattr(config, "attn_implementation", None) == "paged_attention":
  148. num_attention_masks = 0
  149. else:
  150. # TODO: when we generalize to allow for block-attn, we can use `num_attention_masks=sum(set(group_types))`
  151. num_attention_masks = 2 if "sliding_attention" in group_types else 1
  152. memory_handler = PagedAttentionMemoryHandler(
  153. block_size=self.block_size,
  154. page_size=page_size,
  155. num_groups=self.num_groups,
  156. group_size=group_size,
  157. peak_activation_per_token=(config.hidden_size + config.vocab_size),
  158. num_attention_masks=num_attention_masks,
  159. )
  160. num_blocks, max_batch_tokens = memory_handler.infer_num_blocks_and_max_batch_tokens(
  161. num_blocks=getattr(generation_config, "num_blocks", None),
  162. max_batch_tokens=getattr(generation_config, "max_batch_tokens", None),
  163. max_memory_percent=getattr(generation_config, "max_memory", 0.9),
  164. cache_dtype=self.dtype,
  165. )
  166. # Add the inferred attributes to the class
  167. self.num_blocks = num_blocks
  168. self.max_batch_tokens = max_batch_tokens
  169. logger.info(
  170. f"PagedAttentionCache initialized with {self.num_blocks = }, {self.block_size = }, {page_size = }, "
  171. f"{self.max_batch_tokens = } {num_attention_masks = }"
  172. )
  173. # Initialize the cache
  174. self.key_cache: list[torch.Tensor] = []
  175. self.value_cache: list[torch.Tensor] = []
  176. # We add one extra token to the cache to handle padding and generally discard unwanted tokens
  177. self.cache_shape = (num_blocks * self.block_size + 1, self.num_key_value_heads, self.head_dim)
  178. for _ in range(group_size):
  179. new_layer_key_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
  180. new_layer_value_cache = torch.empty(self.cache_shape, dtype=self.dtype, device=self.device)
  181. torch._dynamo.mark_static_address(new_layer_key_cache)
  182. torch._dynamo.mark_static_address(new_layer_value_cache)
  183. self.key_cache.append(new_layer_key_cache)
  184. self.value_cache.append(new_layer_value_cache)
  185. logger.info(f"{self.cache_shape = } {self.key_cache[0].shape = } {self.key_cache[0].numel() = }")
  186. # Block management data structures
  187. self._free_blocks = deque(range(num_blocks))
  188. self.group_cache_managers: list[CacheAllocator] = []
  189. for i, group_type in enumerate(group_types):
  190. if group_type == "full_attention":
  191. cm = FullAttentionCacheAllocator(i, self.block_size)
  192. elif group_type == "sliding_attention":
  193. cm = SlidingAttentionCacheAllocator(i, self.block_size, config.sliding_window)
  194. else:
  195. raise ValueError(f"Invalid group type: {group_type}")
  196. self.group_cache_managers.append(cm)
  197. @traced
  198. def allocate_blocks(self, n_blocks: int, request_id: str) -> int:
  199. """Allocate cache blocks across all layer groups for a given request. Actual allocation is done by the cache
  200. managers, and this method only returns the maximum number of blocks actually allocated across all managers."""
  201. max_allocated = 0
  202. for cm in self.group_cache_managers:
  203. allocated = cm.allocate_blocks(n_blocks, request_id, self._free_blocks)
  204. if allocated is None:
  205. return None
  206. max_allocated = max(max_allocated, allocated)
  207. return max_allocated
  208. @traced
  209. def free_blocks(self, request_id: str) -> None:
  210. """Free all allocated cache blocks for a given request across all layer groups. Actual deallocation is done
  211. by the cache managers."""
  212. for cm in self.group_cache_managers:
  213. cm.free_blocks(request_id, self._free_blocks)
  214. def get_num_free_blocks(self) -> int:
  215. """Get the current number of unallocated blocks available for new requests."""
  216. return len(self._free_blocks)
  217. @traced
  218. def extend_read_indices(
  219. self, request_id: str, past_length: int, query_length: int, read_index: list[list[int]]
  220. ) -> None:
  221. """Retrieve physical cache indices for reading KV states in the cache across all layer groups. This method
  222. coordinates with all cache managers to build the complete set of read indices needed for attention computation.
  223. """
  224. for cm, read_indices in zip(self.group_cache_managers, read_index):
  225. indices = cm.get_read_indices(request_id, past_length, query_length)
  226. read_indices.extend(indices)
  227. @traced
  228. def extend_write_indices(
  229. self, request_id: str, past_length: int, query_length: int, write_index: list[list[int]]
  230. ) -> None:
  231. """Retrieve physical cache indices for writing new KV states to the cache across all layer groups. This method
  232. coordinates with all cache managers to build the complete set of write indices needed to store computed KV
  233. states."""
  234. for cm, write_indices in zip(self.group_cache_managers, write_index):
  235. indices = cm.get_write_indices(request_id, past_length, query_length)
  236. write_indices.extend(indices)
  237. @traced
  238. def get_seqlens_k(self, request_id: str, past_length: int, query_length: int) -> dict[str, int]:
  239. """Retrieve the key sequence length for the given request_id across all layer types. Returns a dictionary of
  240. layer types to their corresponding key sequence lengths."""
  241. seqlens_k = {}
  242. for cm in self.group_cache_managers:
  243. attn_type, seqlen_k = cm.get_seqlens_k(request_id, past_length, query_length)
  244. seqlens_k[attn_type] = seqlen_k
  245. return seqlens_k
  246. @traced
  247. def update(
  248. self,
  249. key_states: torch.Tensor, # shape [1, num_kv_heads, seqlen_kv, head_dim]
  250. value_states: torch.Tensor, # shape [1, num_kv_heads, seqlen_kv, head_dim]
  251. layer_idx: int,
  252. read_index: list[torch.Tensor], # shape [num_layer_groups, seqlen_kv + past_length]
  253. write_index: list[torch.Tensor], # shape [num_layer_groups, seqlen_q]
  254. **kwargs,
  255. ) -> tuple[torch.Tensor, torch.Tensor]: # shape [seqlen_kv + past_length, num_kv_heads, head_dim]
  256. """Update the cache with new key-value states for a specific layer. This method writes new KV states to the
  257. appropriate cache locations. The behavior differs based on the layer's attention type:
  258. - Full attention: New KV states are written to cache, then complete sequence is read from cache
  259. - Sliding window: Old KV is read from cache along with extra spaces for the new KV, then new KV is written to
  260. cache. This is because new KV might overwrite the old KV, so we need to read the old KV first.
  261. Returns the complete KV states (cached + new) for attention computation.
  262. """
  263. # Retrieve the layer read and write indices, and if there is a sliding window
  264. group_idx, layer_idx_in_group = self.layer_index_to_group_indices[layer_idx]
  265. layer_read_index = read_index[group_idx]
  266. layer_write_index = write_index[group_idx]
  267. # Select the correct cache
  268. k_cache = self.key_cache[layer_idx_in_group]
  269. v_cache = self.value_cache[layer_idx_in_group]
  270. # Transpose the key and value states to match the cache shape, after which shape is [seqlen_kv, num_kv_heads, head_dim]
  271. key_states = key_states.transpose(1, 2).squeeze(0)
  272. value_states = value_states.transpose(1, 2).squeeze(0)
  273. # Case: full attention
  274. sliding_window = self.sliding_windows[layer_idx]
  275. if sliding_window == 1:
  276. k_cache[layer_write_index, :, :] = key_states
  277. v_cache[layer_write_index, :, :] = value_states
  278. key_states_with_cache = k_cache[layer_read_index, :, :]
  279. value_states_with_cache = v_cache[layer_read_index, :, :]
  280. # Case: sliding window -- we need to be careful of read/write order because of chunked prefill, because it's
  281. # the only case where you may write over cache you need to use
  282. else:
  283. # Add the cache to the key and value states
  284. mask = layer_read_index == -1 # TODO: can this can be efficiently precomputed?
  285. key_states_with_cache = k_cache[layer_read_index, :, :]
  286. key_states_with_cache[mask] = key_states
  287. value_states_with_cache = v_cache[layer_read_index, :, :]
  288. value_states_with_cache[mask] = value_states
  289. # Write new KV values to the cache
  290. k_cache[layer_write_index, :, :] = key_states
  291. v_cache[layer_write_index, :, :] = value_states
  292. # Return the new KV values
  293. return key_states_with_cache, value_states_with_cache
  294. # TODO: rework computation with the groups and their sizes
  295. class PagedAttentionMemoryHandler:
  296. """A helper class to determine the best number of pages and maximum number of tokens per batch for the paged
  297. attention cache, providing automatic sizing based on available GPU memory.
  298. The helper works using the number of pages, which is tied to the number of blocks by:
  299. num_blocks = num_pages // block_size
  300. The memory footprint consists of three main components:
  301. - Cache memory: the space needed to store the cache tensors:
  302. 2 * layer_group_size * [num_pages, page_size] * cache_dtype
  303. - Activation memory: the space temporarily taken by the largest activation during the model forward pass:
  304. peak_activation_per_token * max_tokens_per_batch * activation_dtype_size
  305. - Static tensors: the space taken by the input/output buffers and metadata tensors for batch processing, sum of:
  306. - inputs_ids + outputs_ids + position_ids + logits_indices: 4 * max_tokens_per_batch * int32_size
  307. - attention_mask: num_attention_masks * num_pages * max_tokens_per_batch * activation_dtype_size
  308. - cumulative_seqlens_q + cumulative_seqlens_k: (1 + 2) * max_tokens_per_batch * int32_size
  309. - write_index_tensor: num_groups * max_tokens_per_batch * int32_size
  310. - read_index_tensor: num_groups * (num_pages + max_tokens_per_batch) * int32_size
  311. The handler can operate in three modes:
  312. 1. Auto-sizing: Determines both number of pages and maximum number of tokens per batch using quadratic optimization
  313. 2. Fixed cache: Calculates max batch tokens given a fixed number of pages
  314. 3. Fixed batch: Calculates number of pages given a fixed maximum batch size
  315. """
  316. _activation_dtype = torch.bfloat16
  317. _input_dtype = torch.int32
  318. _upper_bound_max_batch_tokens = 256
  319. _upper_bound_num_blocks = 4096
  320. def __init__(
  321. self,
  322. block_size: int,
  323. page_size: int,
  324. num_groups: int,
  325. group_size: int,
  326. peak_activation_per_token: int,
  327. num_attention_masks: int,
  328. ) -> None:
  329. """Initialize the memory handler with the parameters that cannot be automatically inferred.
  330. Args:
  331. block_size: Size of the cache blocks
  332. page_size: Size of the cache pages
  333. num_groups: Number of layer groups
  334. group_size: Number of layers per layer group
  335. peak_activation_per_token: Maximum size of activation tensor per token, = hidden_size + vocab_size
  336. num_attention_masks: Number of attention masks, 0 if no attention mask is used, 2 if hybrid model, else 1
  337. """
  338. self.block_size = block_size
  339. self.page_size = page_size
  340. self.num_groups = num_groups
  341. self.group_size = group_size
  342. self.peak_activation_per_token = peak_activation_per_token
  343. self.num_attention_masks = num_attention_masks
  344. @staticmethod
  345. def get_available_memory(max_memory_percent: float = 1.0) -> int:
  346. """Calculate available GPU memory for cache allocation, accounting for already allocated tensors.
  347. This method queries the current memory state and applies the specified percentage limit to determine
  348. how much memory can be safely used for the paged attention cache.
  349. Args:
  350. max_memory_percent: Fraction of available memory to use (0.0-1.0). 1.0 means use all available memory.
  351. Returns:
  352. int: Available memory in bytes for cache allocation
  353. """
  354. _, total, reserved, allocated = get_device_and_memory_breakdown()
  355. available_memory = total - max(allocated, reserved)
  356. available_memory = int(available_memory * max_memory_percent)
  357. return available_memory
  358. def infer_num_blocks_and_max_batch_tokens(
  359. self,
  360. num_blocks: Optional[int] = None,
  361. max_batch_tokens: Optional[int] = None,
  362. max_memory_percent: float = 0.9,
  363. cache_dtype: torch.dtype = torch.float16,
  364. ) -> tuple[int, int]:
  365. """Determine optimal number of blocks and maximum number of tokens per batch based on available memory and
  366. constraints. Check the class docstring for more details. Naming the number of pages as N and the maximum number
  367. of tokens per batch as M, the equation solved is:
  368. available_memory = sum([
  369. MN * num_attention_masks * activation_dtype_size,
  370. 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
  371. M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
  372. ])
  373. where we already simplified int32_size = 4.
  374. """
  375. # If neither num_blocks nor max_batch_tokens are provided, we use a second-order polynomial
  376. if num_blocks is None and max_batch_tokens is None:
  377. num_blocks, max_batch_tokens = self.compute_num_blocks_and_max_batch_tokens(
  378. max_memory_percent, cache_dtype
  379. )
  380. # If only num_blocks is provided, we infer the max_batch_tokens
  381. elif num_blocks is not None and max_batch_tokens is None:
  382. max_batch_tokens = self.compute_max_batch_tokens(num_blocks, max_memory_percent, cache_dtype)
  383. # If only max_batch_tokens is provided, we infer the num_blocks
  384. elif max_batch_tokens is not None and num_blocks is None:
  385. num_blocks = self.compute_num_blocks(max_batch_tokens, max_memory_percent, cache_dtype)
  386. # We check if the memory footprint is too large in all cases
  387. available_memory = self.get_available_memory(max_memory_percent)
  388. memory_footprint = self.compute_memory_footprint(
  389. max_batch_tokens=max_batch_tokens,
  390. num_blocks=num_blocks,
  391. cache_dtype=cache_dtype,
  392. )
  393. if memory_footprint > available_memory:
  394. raise MemoryError(f"Memory footprint {memory_footprint} is more than available memory {available_memory}")
  395. return num_blocks, max_batch_tokens
  396. def compute_num_blocks_and_max_batch_tokens(
  397. self,
  398. max_memory_percent: float = 0.9,
  399. cache_dtype: torch.dtype = torch.float16,
  400. m: float = 0.01,
  401. ) -> tuple[int, int]:
  402. """Calculate optimal number of blocks and maximum number of tokens per batch using quadratic optimization when
  403. neither is fixed. This method assumes a relationship M = m * N where m is a small ratio below 1 and solves the
  404. resulting quadratic equation to find the optimal N that maximizes utilization within memory constraints. m is
  405. the amount of cache we can fill with one batch: m=0.01 means a batch fills at most 1% of the cache. The equation
  406. to solve is:
  407. available_memory = sum([
  408. m * N^2 * num_attention_masks * activation_dtype_size,
  409. 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
  410. m * N * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
  411. ])
  412. """
  413. cache_memory = self.get_available_memory(max_memory_percent)
  414. logger.info(f"Cache memory: {cache_memory}")
  415. # Compute second-degree polynomial coefficients
  416. a = m * self.num_attention_masks * self._activation_dtype.itemsize
  417. b = 2 * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
  418. b += m * (self.peak_activation_per_token * self._activation_dtype.itemsize + 28 + 4 * self.num_groups)
  419. c = -cache_memory
  420. logger.debug(f"Coefficients of 2nd degree polynomial: {a = }, {b = }, {c = }")
  421. # Compute discriminant and greatest solution
  422. discriminant = b**2 - 4 * a * c
  423. if discriminant < 0:
  424. raise ValueError(f"Discriminant is negative: {discriminant = }")
  425. greatest_solution = (-b + sqrt(discriminant)) / (2 * a)
  426. if greatest_solution < 0:
  427. raise ValueError(f"Greatest solution is negative: {greatest_solution = }")
  428. # Infer number of blocks and max batch tokens
  429. num_pages = floor(greatest_solution)
  430. num_blocks = num_pages // self.block_size
  431. if num_blocks > self._upper_bound_num_blocks:
  432. logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
  433. num_blocks = self._upper_bound_num_blocks
  434. max_batch_tokens = int(greatest_solution * m)
  435. if max_batch_tokens > self._upper_bound_max_batch_tokens:
  436. logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
  437. max_batch_tokens = self._upper_bound_max_batch_tokens
  438. return num_blocks, max_batch_tokens
  439. def compute_max_batch_tokens(
  440. self,
  441. num_blocks: int,
  442. max_memory_percent: float = 0.9,
  443. cache_dtype: torch.dtype = torch.float16,
  444. ) -> int:
  445. """Calculate maximum batch tokens M given a fixed number of cache blocks. The formula for M is given by:
  446. M = (available_memory - 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group))
  447. / (activation_dtype_size * (N * num_attention_masks + peak_activation_per_token) + 28 + 4 * num_group)
  448. """
  449. cache_memory = self.get_available_memory(max_memory_percent)
  450. num_pages = num_blocks * self.block_size
  451. # Compute numerator
  452. num = cache_memory
  453. num -= 2 * num_pages * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
  454. # Compute denominator
  455. denum = self._activation_dtype.itemsize * (
  456. num_pages * self.num_attention_masks + self.peak_activation_per_token
  457. )
  458. denum += 28 + 4 * self.num_groups
  459. # Compute max batch tokens and return
  460. max_batch_tokens = floor(num / denum)
  461. if max_batch_tokens > self._upper_bound_max_batch_tokens:
  462. logger.info(f"{max_batch_tokens = } is too large, setting to {self._upper_bound_max_batch_tokens = }")
  463. max_batch_tokens = self._upper_bound_max_batch_tokens
  464. return max_batch_tokens
  465. def compute_num_blocks(
  466. self,
  467. max_batch_tokens: int,
  468. max_memory_percent: float = 0.9,
  469. cache_dtype: torch.dtype = torch.float16,
  470. ) -> int:
  471. """Calculate number of cache blocks N given a fixed maximum token per token M. The formula for N is given by:
  472. N = (available_memory - M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group))
  473. / (2 * (layer_group_size * page_size * cache_dtype + 2 * num_group) + M * (num_attention_masks * activation_dtype_size))
  474. """
  475. cache_memory = self.get_available_memory(max_memory_percent)
  476. # Compute numerator
  477. num = cache_memory
  478. num -= max_batch_tokens * self.peak_activation_per_token * self._activation_dtype.itemsize
  479. num -= max_batch_tokens * (28 + 4 * self.num_groups)
  480. # Compute denominator
  481. denum = 2 * (self.group_size * self.page_size * cache_dtype.itemsize + 2 * self.num_groups)
  482. denum += max_batch_tokens * (self.num_attention_masks * self._activation_dtype.itemsize)
  483. denum += max_batch_tokens * self._activation_dtype.itemsize
  484. # Compute cache size and return number of blocks
  485. num_pages = floor(num / denum)
  486. num_blocks = num_pages // self.block_size
  487. if num_blocks > self._upper_bound_num_blocks:
  488. logger.info(f"{num_blocks = } is too large, setting to {self._upper_bound_num_blocks = }")
  489. num_blocks = self._upper_bound_num_blocks
  490. return num_blocks
  491. def compute_memory_footprint(
  492. self,
  493. num_blocks: Optional[int] = None,
  494. max_batch_tokens: Optional[int] = None,
  495. cache_dtype: torch.dtype = torch.float16,
  496. ) -> tuple[int, int, int]:
  497. """Calculate the memory footprint breakdown for a given number of blocks and maximum batch tokens. The memory
  498. footprint is given by:
  499. available_memory = sum([
  500. MN * num_attention_masks * activation_dtype_size,
  501. 2N * (layer_group_size * page_size * cache_dtype + 2 * num_group),
  502. M * (peak_activation_per_token * activation_dtype + 28 + 4 * num_group),
  503. ])
  504. but is broken down below.
  505. """
  506. num_pages = num_blocks * self.block_size
  507. cache_memory_footprint = 2 * self.group_size * num_pages * self.page_size * cache_dtype.itemsize
  508. activation_memory_footprint = self.peak_activation_per_token * self._activation_dtype.itemsize
  509. activation_memory_footprint *= max_batch_tokens
  510. inputs_outputs_positions_and_logits_memory_footprint = 4 * max_batch_tokens * 4 # second 4 is for int32 size
  511. attention_memory_footprint = self.num_attention_masks * self._activation_dtype.itemsize
  512. attention_memory_footprint *= num_pages * max_batch_tokens
  513. cumulative_seqlens_memory_footprint = 3 * max_batch_tokens * 4 # 4 is for int32 size
  514. write_index_memory_footprint = self.num_groups * max_batch_tokens * 4 # 4 is for int32 size
  515. read_index_memory_footprint = self.num_groups * (num_pages + max_batch_tokens) * 4 # 4 is for int32 size
  516. total_memory_footprint = sum(
  517. [
  518. cache_memory_footprint,
  519. activation_memory_footprint,
  520. inputs_outputs_positions_and_logits_memory_footprint,
  521. attention_memory_footprint,
  522. cumulative_seqlens_memory_footprint,
  523. write_index_memory_footprint,
  524. read_index_memory_footprint,
  525. ]
  526. )
  527. return total_memory_footprint