requests.py 8.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import time
  16. from dataclasses import dataclass, field
  17. from enum import Enum
  18. from typing import Optional
  19. import torch
  20. from ...utils.logging import logging
  21. from ...utils.metrics import traced
# The logger is created once here, under a fixed name, so that every part of continuous batching shares it and
# log output can be coordinated with the progress bar.
logger = logging.getLogger("ContinuousBatchingLogger")
# Uncomment to force INFO-level output from this logger while debugging:
# logger.setLevel(logging.INFO)
  25. def get_device_and_memory_breakdown() -> tuple[torch.device, int, int, int]:
  26. if torch.cuda.is_available():
  27. device = torch.device("cuda")
  28. torch.cuda.empty_cache()
  29. torch.cuda.synchronize()
  30. total_memory = torch.cuda.get_device_properties(device).total_memory
  31. reserved_memory = torch.cuda.memory_reserved(device)
  32. allocated_memory = torch.cuda.memory_allocated(device)
  33. elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
  34. device = torch.device("mps")
  35. # MPS memory reporting (PyTorch 2.0+)
  36. total_memory = torch.mps.driver_allocated_memory()
  37. allocated_memory = total_memory - torch.mps.recommended_max_memory()
  38. reserved_memory = 0 # MPS does not track reserved separately
  39. else:
  40. device = torch.device("cpu")
  41. total_memory = None
  42. reserved_memory = 0
  43. allocated_memory = 0
  44. return device, total_memory, reserved_memory, allocated_memory
  45. class RequestStatus(Enum):
  46. """Status of a generation request through its lifecycle."""
  47. PENDING = "pending"
  48. PREFILLING = "prefilling"
  49. PREFILLING_SPLIT = "prefilling_split"
  50. SPLIT_PENDING_REMAINDER = "split_pending_remainder"
  51. DECODING = "decoding"
  52. FINISHED = "finished"
  53. FAILED = "failed"
  54. @dataclass
  55. class GenerationOutput:
  56. """Tracks the output of a generation request.
  57. Attributes:
  58. request_id (str): The ID of the generation request.
  59. prompt_ids (list[int]): The IDs of the prompt tokens.
  60. generated_tokens (list[int]): The generated tokens.
  61. logprobs (list[float]): The log probabilities of the generated tokens.
  62. error (Optional[str]): Any error message associated with the request. When None, the request was successful.
  63. status (RequestStatus): The status of the request.
  64. created_time (float): The time the request was created.
  65. """
  66. request_id: str
  67. prompt_ids: list[int] = field(default_factory=list)
  68. generated_tokens: list[int] = field(default_factory=list)
  69. logprobs: list[float] = field(default_factory=list)
  70. error: Optional[str] = None
  71. status: RequestStatus = RequestStatus.PENDING
  72. created_time: float = field(default_factory=time.time)
  73. @dataclass
  74. class RequestState:
  75. """Tracks the state of a generation request through its lifecycle.
  76. Attributes:
  77. request_id (str): The ID of the generation request.
  78. full_prompt_ids (list[int] | None): The tokens IDs of the full prompt.
  79. prompt_ids (list[int] | None): The tokens IDs currently being processed.
  80. remaining_prompt_ids (list[int]): The tokens IDs remaining to be processed (for split requests).
  81. static_outputs (list[int]): The generated tokens.
  82. allocated_blocks (int): The number of blocks allocated to the request.
  83. position_offset (int): The current position in the sequence for position_ids.
  84. status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT,
  85. SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED
  86. max_new_tokens (int): The maximum number of new tokens to generate.
  87. eos_token_id (int): The ID of the end-of-sequence token.
  88. created_time (float): The time the request was created.
  89. error (Optional[str]): Any error message associated with the request. When None, has had no error yet.
  90. """
  91. # Required fields
  92. request_id: str
  93. full_prompt_ids: Optional[list[int]] = None # Full initial prompt
  94. prompt_ids: Optional[list[int]] = None # Tokens IDs currently being processed (initial + generated)
  95. remaining_prompt_ids: list[int] = field(default_factory=list) # For split requests, prefill left to process
  96. static_outputs: list[int] = field(default_factory=list) # Generated tokens
  97. allocated_blocks: int = 0 # Number of blocks allocated to the request
  98. position_offset: int = 0 # Current position in the sequence for position_ids
  99. _status: RequestStatus = RequestStatus.PENDING # Status of the request, hidden behind a property
  100. max_new_tokens: int = 20 # Maximum number of new tokens to generate
  101. eos_token_id: int = -1 # ID of the end-of-sequence token
  102. created_time: float = field(default_factory=time.time) # Time the request was created
  103. error: Optional[str] = None # Error message if the request failed
  104. lifespan: tuple[float, float] = (-1, -1) # (time request was no longer pending, time request finished)
  105. @property
  106. def status(self) -> RequestStatus:
  107. return self._status
  108. @status.setter
  109. def status(self, value: RequestStatus):
  110. if self._status == RequestStatus.PENDING:
  111. self.lifespan = (time.time(), -1)
  112. elif value == RequestStatus.FINISHED:
  113. self.lifespan = (self.lifespan[0], time.time())
  114. self.log_end_of_request()
  115. self._status = value
  116. def log_end_of_request(self):
  117. prefill_len = len(self.full_prompt_ids)
  118. decode_len = self.generated_len()
  119. start_time = self.lifespan[0] - self.created_time
  120. end_time = self.lifespan[1] - self.created_time
  121. logger.info(
  122. f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }"
  123. )
  124. def current_len(self) -> int:
  125. """Get the current length of the sequence (prompt + generated tokens)."""
  126. return self.position_offset
  127. def generated_len(self) -> int:
  128. """Get the number of tokens generated so far."""
  129. return len(self.static_outputs)
  130. # TODO: this logic seems one token off, check it out
  131. @traced
  132. def update_with_token(self, token_id: int) -> bool:
  133. """Update the request with a newly generated token and check for completion.
  134. Args:
  135. token_id: The token ID to add to the output sequence
  136. Returns:
  137. bool: True if the request is now complete, False otherwise
  138. """
  139. # Only update if we're in decoding state
  140. if self.status != RequestStatus.DECODING:
  141. return False
  142. is_eos = token_id == self.eos_token_id and self.eos_token_id != -1
  143. is_max_len = self.generated_len() >= self.max_new_tokens
  144. # Only add the token if we're not finishing due to max length
  145. # (EOS tokens should still be added to the output)
  146. if not (is_max_len and not is_eos):
  147. self.static_outputs.extend([token_id])
  148. if is_eos or is_max_len:
  149. self.status = RequestStatus.FINISHED
  150. return True
  151. return False
  152. def __repr__(self):
  153. msg = [
  154. f"request_id={self.request_id}",
  155. f"status={self._status}",
  156. f"out_tokens={self.generated_len()}",
  157. f"query_length={len(self.prompt_ids)}",
  158. f"remaining_tokens={len(self.remaining_prompt_ids)}",
  159. f"kv_length={self.position_offset}",
  160. f"full_prompt_length={len(self.full_prompt_ids)}",
  161. f"allocated_blocks={self.allocated_blocks}",
  162. f"generated_tokens={self.static_outputs}",
  163. ]
  164. return "RequestState(\n\t" + ",\n\t".join(msg) + "\n)"
  165. def to_generation_output(self):
  166. """Convert the request state to a GenerationOutput object."""
  167. return GenerationOutput(
  168. request_id=self.request_id,
  169. prompt_ids=self.full_prompt_ids,
  170. status=self.status,
  171. generated_tokens=self.static_outputs,
  172. logprobs=[],
  173. error=self.error,
  174. )