| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204 |
- # coding=utf-8
- # Copyright 2025 The HuggingFace Inc. team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- import time
- from dataclasses import dataclass, field
- from enum import Enum
- from typing import Optional
- import torch
- from ...utils.logging import logging
- from ...utils.metrics import traced
- # We centralize the logger here to coordinate between logging and progress bar
- logger = logging.getLogger("ContinuousBatchingLogger")
- # logger.setLevel(logging.INFO)
def get_device_and_memory_breakdown() -> tuple[torch.device, Optional[int], int, int]:
    """Select the torch device to use and report its memory figures.

    Backends are probed in order: CUDA first, then MPS, with CPU as the fallback.

    Returns:
        A `(device, total_memory, reserved_memory, allocated_memory)` tuple. The memory values are whatever the
        backend's torch API reports. On the CPU fallback nothing is measured: `total_memory` is `None` and the
        two other values are 0.
    """
    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Drop unused cached blocks and wait for pending work so the readings below are current
        torch.cuda.empty_cache()
        torch.cuda.synchronize()
        total_memory = torch.cuda.get_device_properties(device).total_memory
        reserved_memory = torch.cuda.memory_reserved(device)
        allocated_memory = torch.cuda.memory_allocated(device)
    elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
        device = torch.device("mps")
        # MPS memory reporting (PyTorch 2.0+)
        # NOTE(review): `driver_allocated_memory()` is the memory the Metal driver has allocated for this process,
        # not the device capacity, so `total_memory` and the derived `allocated_memory` may not mean what callers
        # expect -- confirm against the callers before relying on these numbers.
        total_memory = torch.mps.driver_allocated_memory()
        allocated_memory = total_memory - torch.mps.recommended_max_memory()
        reserved_memory = 0  # MPS does not track reserved separately
    else:
        # No accelerator available: there is no memory to report
        device = torch.device("cpu")
        total_memory = None
        reserved_memory = 0
        allocated_memory = 0
    return device, total_memory, reserved_memory, allocated_memory
class RequestStatus(Enum):
    """Enumerates the stages a generation request can be in; each member is backed by a lowercase string."""

    # Initial value used by request objects in this file
    PENDING = "pending"

    # Prompt-processing stages; the "split" variants presumably cover prompts prefilled in several chunks
    PREFILLING = "prefilling"
    PREFILLING_SPLIT = "prefilling_split"
    SPLIT_PENDING_REMAINDER = "split_pending_remainder"

    # Token-by-token generation
    DECODING = "decoding"

    # End states
    FINISHED = "finished"
    FAILED = "failed"
@dataclass
class GenerationOutput:
    """Result record for a single generation request.

    Only `request_id` is mandatory; every other field has a default, and each list field gets its own fresh
    empty list per instance.

    Attributes:
        request_id (str): Identifier of the generation request.
        prompt_ids (list[int]): Token IDs of the prompt.
        generated_tokens (list[int]): Tokens that were generated.
        logprobs (list[float]): Log probabilities of the generated tokens.
        error (Optional[str]): Error message attached to the request; `None` means the request succeeded.
        status (RequestStatus): Status of the request, `RequestStatus.PENDING` by default.
        created_time (float): Creation time of the request, taken from `time.time()` by default.
    """

    # Identification
    request_id: str

    # Token-level payload
    prompt_ids: list[int] = field(default_factory=list)
    generated_tokens: list[int] = field(default_factory=list)
    logprobs: list[float] = field(default_factory=list)

    # Outcome and bookkeeping
    error: Optional[str] = None
    status: RequestStatus = RequestStatus.PENDING
    created_time: float = field(default_factory=time.time)
@dataclass
class RequestState:
    """Tracks the state of a generation request through its lifecycle.

    The status is stored in `_status` and exposed through the `status` property, whose setter also maintains the
    `lifespan` timestamps and logs a summary when the request finishes.

    Attributes:
        request_id (str): The ID of the generation request.
        full_prompt_ids (list[int] | None): The tokens IDs of the full prompt.
        prompt_ids (list[int] | None): The tokens IDs currently being processed.
        remaining_prompt_ids (list[int]): The tokens IDs remaining to be processed (for split requests).
        static_outputs (list[int]): The generated tokens.
        allocated_blocks (int): The number of blocks allocated to the request.
        position_offset (int): The current position in the sequence for position_ids.
        status (RequestStatus): The status of the request: can be one of PENDING, PREFILLING, PREFILLING_SPLIT,
            SPLIT_PENDING_REMAINDER, DECODING, FINISHED, FAILED
        max_new_tokens (int): The maximum number of new tokens to generate.
        eos_token_id (int): The ID of the end-of-sequence token. -1 disables end-of-sequence detection.
        created_time (float): The time the request was created.
        error (Optional[str]): Any error message associated with the request. When None, has had no error yet.
        lifespan (tuple[float, float]): (time the request was no longer pending, time the request finished). A
            value of -1 means the corresponding timestamp has not been recorded yet.
    """

    # Required fields
    request_id: str
    full_prompt_ids: Optional[list[int]] = None  # Full initial prompt
    prompt_ids: Optional[list[int]] = None  # Tokens IDs currently being processed (initial + generated)
    remaining_prompt_ids: list[int] = field(default_factory=list)  # For split requests, prefill left to process
    static_outputs: list[int] = field(default_factory=list)  # Generated tokens
    allocated_blocks: int = 0  # Number of blocks allocated to the request
    position_offset: int = 0  # Current position in the sequence for position_ids
    _status: RequestStatus = RequestStatus.PENDING  # Status of the request, hidden behind a property
    max_new_tokens: int = 20  # Maximum number of new tokens to generate
    eos_token_id: int = -1  # ID of the end-of-sequence token
    created_time: float = field(default_factory=time.time)  # Time the request was created
    error: Optional[str] = None  # Error message if the request failed
    lifespan: tuple[float, float] = (-1, -1)  # (time request was no longer pending, time request finished)

    @property
    def status(self) -> RequestStatus:
        """The current status of the request."""
        return self._status

    @status.setter
    def status(self, value: RequestStatus):
        """Update the status and keep the `lifespan` timestamps in sync.

        Leaving PENDING records the start timestamp; reaching FINISHED records the end timestamp and logs a
        summary. The two checks are independent so that a request going straight from PENDING to FINISHED still
        gets its end timestamp and its log line.
        """
        if self._status == RequestStatus.PENDING:
            self.lifespan = (time.time(), -1)
        if value == RequestStatus.FINISHED:
            self.lifespan = (self.lifespan[0], time.time())
            self.log_end_of_request()
        self._status = value

    def log_end_of_request(self):
        """Log the prompt length, generated length and timings (relative to `created_time`) of the request."""
        # `full_prompt_ids` defaults to None: report that instead of failing on `len(None)`
        prefill_len = None if self.full_prompt_ids is None else len(self.full_prompt_ids)
        decode_len = self.generated_len()
        start_time = self.lifespan[0] - self.created_time
        end_time = self.lifespan[1] - self.created_time
        logger.info(
            f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }"
        )

    def current_len(self) -> int:
        """Get the current length of the sequence (prompt + generated tokens)."""
        return self.position_offset

    def generated_len(self) -> int:
        """Get the number of tokens generated so far."""
        return len(self.static_outputs)

    @traced
    def update_with_token(self, token_id: int) -> bool:
        """Update the request with a newly generated token and check for completion.

        The request completes when the token is the end-of-sequence token (if one is configured) or when the
        number of generated tokens reaches `max_new_tokens`.

        Args:
            token_id: The token ID to add to the output sequence

        Returns:
            bool: True if the request is now complete, False otherwise
        """
        # Only update if we're in decoding state
        if self.status != RequestStatus.DECODING:
            return False

        is_eos = token_id == self.eos_token_id and self.eos_token_id != -1

        # A request that already used its whole budget only accepts an EOS token
        # (EOS tokens should still be added to the output)
        if is_eos or self.generated_len() < self.max_new_tokens:
            self.static_outputs.append(token_id)

        # The budget is checked after the token is stored, so the request finishes on the step that produces its
        # last allowed token instead of requiring one extra decoding step
        if is_eos or self.generated_len() >= self.max_new_tokens:
            self.status = RequestStatus.FINISHED
            return True
        return False

    def __repr__(self):
        # Both prompt lists default to None: show None rather than raising inside repr()
        query_length = None if self.prompt_ids is None else len(self.prompt_ids)
        full_prompt_length = None if self.full_prompt_ids is None else len(self.full_prompt_ids)
        msg = [
            f"request_id={self.request_id}",
            f"status={self._status}",
            f"out_tokens={self.generated_len()}",
            f"query_length={query_length}",
            f"remaining_tokens={len(self.remaining_prompt_ids)}",
            f"kv_length={self.position_offset}",
            f"full_prompt_length={full_prompt_length}",
            f"allocated_blocks={self.allocated_blocks}",
            f"generated_tokens={self.static_outputs}",
        ]
        return "RequestState(\n\t" + ",\n\t".join(msg) + "\n)"

    def to_generation_output(self) -> "GenerationOutput":
        """Convert the request state to a GenerationOutput object.

        `full_prompt_ids` and `static_outputs` are passed through without copying, so the returned object shares
        those lists with this state. `logprobs` is always an empty list.
        """
        return GenerationOutput(
            request_id=self.request_id,
            prompt_ids=self.full_prompt_ids,
            status=self.status,
            generated_tokens=self.static_outputs,
            logprobs=[],
            error=self.error,
        )
|