scheduler.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300
  1. # coding=utf-8
  2. # Copyright 2025 The HuggingFace Inc. team
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import threading
  16. from abc import ABC, abstractmethod
  17. from collections import deque
  18. from ...utils.metrics import attach_tracer, traced
  19. from .cache import PagedAttentionCache
  20. from .requests import RequestState, RequestStatus
  21. class Scheduler(ABC):
  22. """
  23. Abstract base class for scheduling requests in the continuous batch processor. Schedulers manage the lifecycle of
  24. requests from when they are added to the waiting queue to when they are scheduled for processing. Different
  25. schedulers implement different strategies for prioritizing and batching requests.
  26. """
  27. def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False):
  28. self.active_requests: dict[str, RequestState] = {}
  29. self.waiting_requests: dict[str, RequestState] = {}
  30. self.waiting_requests_order: deque[str] = deque()
  31. self.cache = cache
  32. self.retain_cache_on_finish = retain_cache_on_finish
  33. self._cancellation_lock = threading.Lock()
  34. self._requests_to_cancel: set[str] = set()
  35. @traced
  36. def add_waiting_request(self, state: RequestState):
  37. """Adds a request to the waiting list."""
  38. if self.retain_cache_on_finish and state.request_id in self.active_requests:
  39. old_state = self.active_requests.pop(state.request_id)
  40. state.prompt_ids = state.prompt_ids[len(old_state.full_prompt_ids) :] # XXX: check for indexing error?
  41. state.allocated_blocks = old_state.allocated_blocks
  42. state.position_offset = old_state.position_offset
  43. self.waiting_requests[state.request_id] = state
  44. self.waiting_requests_order.append(state.request_id)
  45. @abstractmethod
  46. def schedule_batch(self, token_budget: int) -> list[RequestState]:
  47. """Schedules requests for the next batch based on available token budget. This method selects which requests
  48. should be processed in the current batch, considering the token budget and the scheduler's prioritization rules.
  49. The token_budget is the maximum number of tokens that can be processed in this batch."""
  50. pass
  51. @traced
  52. def has_pending_requests(self) -> bool:
  53. """Checks if there are requests ready to be processed."""
  54. return len(self.active_requests) or len(self.waiting_requests)
  55. @traced
  56. def finish_request(self, request_id: str, evict_from_cache: bool = True):
  57. """Completes processing of a request and optionally frees its allocated cache blocks. This method is called
  58. when a request has finished generation or encountered an error.
  59. """
  60. if evict_from_cache:
  61. self.cache.free_blocks(request_id)
  62. if request_id in self.active_requests:
  63. del self.active_requests[request_id]
  64. @traced
  65. def get_active_request_static_outputs(self, request_id: str) -> list[int]:
  66. """Gets generated tokens for an active request."""
  67. if request_id in self.active_requests:
  68. return self.active_requests[request_id].static_outputs
  69. return []
  70. @traced
  71. def set_request_cancellation(self, request_id: str):
  72. """Marks a request for cancellation."""
  73. with self._cancellation_lock:
  74. self._requests_to_cancel.add(request_id)
  75. @traced
  76. def clear_cancelled_requests(self):
  77. """Remove all cancelled requests from active and waiting queues."""
  78. with self._cancellation_lock:
  79. for request_id in self._requests_to_cancel:
  80. if request_id in self.active_requests:
  81. del self.active_requests[request_id]
  82. if request_id in self.waiting_requests:
  83. del self.waiting_requests[request_id]
  84. if request_id in self.waiting_requests_order:
  85. self.waiting_requests_order.remove(request_id)
  86. self.cache.free_blocks(request_id)
  87. self._requests_to_cancel = set()
  88. @traced
  89. def request_is_cancelled(self, request_id: str) -> bool:
  90. """Checks if a request has been cancelled or removed."""
  91. return request_id in self._requests_to_cancel or (
  92. request_id not in self.active_requests and request_id not in self.waiting_requests
  93. )
  94. @traced
  95. def _allocate_blocks_if_needed(self, state: RequestState, len_next_tokens: int) -> bool:
  96. """Allocate additional cache blocks for a request if the currently allocated blocks are insufficient to
  97. accommodate the next tokens. It calculates how many blocks are needed based on the request's current
  98. cache occupancy and the number of tokens to be processed. The allocation itself is done by the CacheAllocator
  99. objects. Returns a boolean indicating if the allocation was successful or not.
  100. """
  101. # 1. we check that the occupancy is less than the requested length
  102. # 2. we allocate enough blocks to cover the requested length
  103. current_len = state.current_len()
  104. occupancy = state.allocated_blocks * self.cache.block_size - current_len
  105. if occupancy < len_next_tokens or state.allocated_blocks == 0:
  106. blocks_needed = ((len_next_tokens - occupancy + 1) // self.cache.block_size) + 1
  107. allocated = self.cache.allocate_blocks(blocks_needed, state.request_id)
  108. if allocated is None:
  109. return False
  110. state.allocated_blocks += allocated
  111. return True
  112. @traced(span_name="prepare_request")
  113. def _prepare_request_for_processing(
  114. self, state: RequestState, token_budget: int, request_ids_to_remove_from_waiting: set[str]
  115. ):
  116. """Prepares a request for processing in the current batch."""
  117. request_tokens = (
  118. state.remaining_prompt_ids if state.status == RequestStatus.SPLIT_PENDING_REMAINDER else state.prompt_ids
  119. )
  120. if len(request_tokens) < token_budget:
  121. # Can process the entire prompt/remainder
  122. if state.status == RequestStatus.PENDING:
  123. self.active_requests[state.request_id] = state
  124. state.status = RequestStatus.PREFILLING
  125. request_ids_to_remove_from_waiting.add(state.request_id)
  126. elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
  127. state.status = RequestStatus.PREFILLING
  128. state.prompt_ids = state.remaining_prompt_ids
  129. state.remaining_prompt_ids = []
  130. else:
  131. # Need to split the request
  132. if state.status == RequestStatus.PENDING:
  133. self.active_requests[state.request_id] = state
  134. state.status = RequestStatus.PREFILLING_SPLIT
  135. request_ids_to_remove_from_waiting.add(state.request_id)
  136. elif state.status == RequestStatus.SPLIT_PENDING_REMAINDER:
  137. state.status = RequestStatus.PREFILLING_SPLIT
  138. state.remaining_prompt_ids = request_tokens[token_budget:]
  139. state.prompt_ids = request_tokens[:token_budget]
  140. @attach_tracer()
  141. class FIFOScheduler(Scheduler):
  142. """This scheduler processes requests in the order they arrive, meaning decoding requests has priority over
  143. prefilling requests. Additionally, it includes a safety margin mechanism to prevent cache exhaustion. By default,
  144. when 80% of the cache is full, new requests will not be scheduled to prioritize decoding active requests."""
  145. def __init__(self, cache: PagedAttentionCache, retain_cache_on_finish: bool = False, safety_margin: float = 0.2):
  146. """Initializes the FIFO scheduler. The safety margin is the percentage of free blocks under which we stop
  147. scheduling new prefill requests, so safety_margin = 0.1 means that when there is less than 10% of free blocks,
  148. or equivalently when more than 90% of blocks are already allocated, we stop scheduling new prefill requests.
  149. """
  150. super().__init__(cache, retain_cache_on_finish)
  151. self.safety_margin = safety_margin
  152. @traced
  153. def schedule_batch(self, token_budget: int) -> list[RequestState]:
  154. priority_states: list[RequestState] = []
  155. second_priority_states: list[RequestState] = []
  156. scheduled_requests = []
  157. for state in self.active_requests.values():
  158. if state.status == RequestStatus.DECODING:
  159. priority_states.append(state)
  160. if state.status in [RequestStatus.SPLIT_PENDING_REMAINDER, RequestStatus.PREFILLING_SPLIT]:
  161. second_priority_states.append(state)
  162. # Add waiting requests to second priority
  163. for req_id in self.waiting_requests_order:
  164. second_priority_states.append(self.waiting_requests[req_id])
  165. candidates = priority_states + second_priority_states
  166. request_ids_to_remove_from_waiting = set()
  167. safety_margins = self.safety_margin * self.cache.num_blocks
  168. for state in candidates:
  169. # If we are out the safety margin, we only accept decoding requests or the first prefill request
  170. num_free_blocks = self.cache.get_num_free_blocks()
  171. outside_safety_margin = num_free_blocks < safety_margins
  172. if outside_safety_margin and scheduled_requests and state.status != RequestStatus.DECODING:
  173. break
  174. self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
  175. request_len = len(state.prompt_ids)
  176. if not self._allocate_blocks_if_needed(
  177. state, len(state.prompt_ids)
  178. ): # don't schedule if we can't allocate blocks
  179. if len(self.cache._free_blocks) == 0:
  180. break
  181. continue
  182. @traced
  183. def _add_to_scheduled_requests(state: RequestState):
  184. scheduled_requests.append(state)
  185. _add_to_scheduled_requests(state)
  186. token_budget -= request_len
  187. @traced
  188. def _remove_from_waiting_requests(state: RequestState):
  189. req_id = state.request_id
  190. if req_id in self.waiting_requests:
  191. del self.waiting_requests[req_id]
  192. request_ids_to_remove_from_waiting.add(req_id)
  193. _remove_from_waiting_requests(state)
  194. if token_budget == 0:
  195. break
  196. self.waiting_requests_order = deque(
  197. [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
  198. )
  199. return scheduled_requests
  200. # FIXME: prioritize adding from waiting reqs before scheduling `RequestStatus.DECODING` when cache space allows it
  201. @attach_tracer()
  202. class PrefillFirstScheduler(Scheduler):
  203. """Scheduler that prioritizes split prefill requests over decoding requests. This scheduler ensures that split
  204. prefill requests (which are continuations of partially processed prompts) are completed before processing new
  205. decoding requests."""
  206. @traced
  207. def schedule_batch(self, token_budget: int) -> list[RequestState]:
  208. priority_states: list[RequestState] = []
  209. second_priority_states: list[RequestState] = []
  210. scheduled_requests = []
  211. for state in self.active_requests.values():
  212. # XXX: when cache is full, state can stay on `PREFILLING_SPLIT` so we need to take those into account
  213. if state.status in [RequestStatus.PREFILLING_SPLIT, RequestStatus.SPLIT_PENDING_REMAINDER]:
  214. priority_states.append(state)
  215. elif state.status == RequestStatus.DECODING:
  216. second_priority_states.append(state)
  217. for req_id in self.waiting_requests_order:
  218. second_priority_states.append(self.waiting_requests[req_id])
  219. candidates = priority_states + second_priority_states
  220. request_ids_to_remove_from_waiting = set()
  221. for state in candidates:
  222. self._prepare_request_for_processing(state, token_budget, request_ids_to_remove_from_waiting)
  223. request_len = len(state.prompt_ids)
  224. if not self._allocate_blocks_if_needed(
  225. state, len(state.prompt_ids)
  226. ): # don't schedule if we can't allocate blocks
  227. if len(self.cache._free_blocks) == 0:
  228. break
  229. continue
  230. @traced
  231. def _add_to_scheduled_requests(state: RequestState):
  232. scheduled_requests.append(state)
  233. _add_to_scheduled_requests(state)
  234. token_budget -= request_len
  235. @traced
  236. def _remove_from_waiting_requests(state: RequestState):
  237. req_id = state.request_id
  238. if req_id in self.waiting_requests:
  239. del self.waiting_requests[req_id]
  240. request_ids_to_remove_from_waiting.add(req_id)
  241. _remove_from_waiting_requests(state)
  242. if token_budget == 0:
  243. break
  244. self.waiting_requests_order = deque(
  245. [req_id for req_id in self.waiting_requests_order if req_id not in request_ids_to_remove_from_waiting]
  246. )
  247. return scheduled_requests
  248. SCHEDULER_MAPPING = {
  249. "fifo": FIFOScheduler,
  250. "prefill_first": PrefillFirstScheduler,
  251. }