candidate_generator.py 62 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265
  1. # coding=utf-8
  2. # Copyright 2023 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import copy
  16. import weakref
  17. from typing import TYPE_CHECKING, Any, Optional
  18. import numpy as np
  19. import torch
  20. import torch.nn as nn
  21. from ..pytorch_utils import prune_linear_layer
  22. from ..utils import is_sklearn_available
  23. if is_sklearn_available():
  24. from sklearn.metrics import roc_curve
  25. from ..pytorch_utils import isin_mps_friendly
  26. from .logits_process import LogitsProcessorList, MinLengthLogitsProcessor, SuppressTokensLogitsProcessor
  27. if TYPE_CHECKING:
  28. from ..modeling_utils import PreTrainedModel
  29. from ..tokenization_utils_base import PreTrainedTokenizerBase
  30. from .configuration_utils import GenerationConfig
  # ---- Abstract interface implemented by every assisted-generation candidate generator ----
class CandidateGenerator:
    """
    Interface for the objects that propose draft tokens during assisted generation.

    Concrete generators must override both `get_candidates` and `update_candidate_strategy`; invoking either one on
    this base class raises `NotImplementedError`.
    """

    def _abstract_call_error(self, method_name: str) -> NotImplementedError:
        """Build the `NotImplementedError` raised when `method_name` is called on the abstract base."""
        return NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can call `{method_name}`."
        )

    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Propose the candidate tokens to be checked for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and, optionally, a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated to each candidate.

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise self._abstract_call_error("get_candidates")

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Adapt the candidate-generation strategy using the outcome of the last verification step.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token when using beam search
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.

        Raises:
            NotImplementedError: always, on this base class.
        """
        raise self._abstract_call_error("update_candidate_strategy")
  # ---- Assisted generation with a draft model that shares the target model's tokenizer ----
class AssistedCandidateGenerator(CandidateGenerator):
    """
    `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
    candidates through the use of a smaller model. Read the following blog post for more information:
    https://huggingface.co/blog/assisted-generation

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        assistant_model (`PreTrainedModel`):
            The model to be used for generating candidates. This model should be smaller than the main model.
        generation_config (`~generation.GenerationConfig`):
            The generation configuration to be used as base parametrization for the generation call. It is
            deep-copied; the caller's object is not modified.
        model_kwargs (`Dict`):
            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
            model as well.
        inputs_tensor (`torch.Tensor`, *optional*):
            The model input tensor. In encoder-decoder models, this is the encoder input.
        logits_processor (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
            Must not contain a `MinLengthLogitsProcessor`.

    Note:
        Construction overwrites `assistant_model.generation_config.eos_token_id` with the target configuration's
        `eos_token_id`, i.e. it mutates the assistant model's own generation config.
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        generation_config: "GenerationConfig",
        model_kwargs: dict,
        inputs_tensor: Optional[torch.Tensor] = None,
        logits_processor: Optional["LogitsProcessorList"] = None,
    ):
        """
        Raises:
            ValueError: if `logits_processor` contains a `MinLengthLogitsProcessor`.
        """
        # Make sure all data at the same device as assistant model
        device = assistant_model.device
        input_ids = input_ids.to(device)
        if inputs_tensor is not None:
            inputs_tensor = inputs_tensor.to(device)

        # Prepare the assistant and the starting number of candidate tokens
        self.assistant_model = assistant_model
        self.num_assistant_tokens = assistant_model.generation_config.num_assistant_tokens
        self.assistant_confidence_threshold = assistant_model.generation_config.assistant_confidence_threshold

        # Set eos in assistant same as in target model (mutates the assistant model's generation config)
        self.assistant_model.generation_config.eos_token_id = generation_config.eos_token_id

        # Prepare the kwargs for the assistant model
        assistant_kwargs = {}
        for key, value in model_kwargs.items():  # deepcopy crashes if we attempt to copy encoder outputs with grads
            if key not in ("encoder_outputs", "past_key_values"):
                assistant_kwargs[key] = (
                    value.detach().to(device) if isinstance(value, torch.Tensor) else copy.deepcopy(value)
                )

        # Remove potential default "logits_to_keep" key
        if "logits_to_keep" in assistant_kwargs and not assistant_model._supports_logits_to_keep():
            del assistant_kwargs["logits_to_keep"]

        # If the assistant is an encoder-decoder model, assume the encoder is different on the assistant.
        if assistant_model.config.is_encoder_decoder:
            inputs_tensor, model_input_name, assistant_kwargs = assistant_model._prepare_model_inputs(
                inputs_tensor, assistant_model.generation_config.bos_token_id, assistant_kwargs
            )
            assistant_kwargs = assistant_model._prepare_encoder_decoder_kwargs_for_generation(
                inputs_tensor, assistant_kwargs, model_input_name, assistant_model.generation_config
            )
        elif "encoder_outputs" in model_kwargs:
            assistant_kwargs["encoder_outputs"] = model_kwargs["encoder_outputs"]
        self.assistant_kwargs = assistant_kwargs

        # Prepare assistant model's keys of inputs
        if assistant_model.config.is_encoder_decoder:
            # both are encoder-decoder
            self.input_ids_key = "decoder_input_ids"
        elif "encoder_outputs" in assistant_kwargs:
            # special case for encoder-decoder with decoder-only assistant (like DistilWhisper)
            self.input_ids_key = "input_ids"
            self.assistant_kwargs["attention_mask"] = self.assistant_kwargs.get(
                "decoder_attention_mask",
                torch.ones((input_ids.shape[0], 1), device=input_ids.device, dtype=torch.long),
            )
        else:
            # both are decoder-only
            self.input_ids_key = "input_ids"

        # Prepare generation-related options.
        self.logits_processor = logits_processor if logits_processor is not None else LogitsProcessorList()
        self.generation_config = copy.deepcopy(generation_config)

        self.generation_config.return_dict_in_generate = True
        self.generation_config.output_scores = True
        self.generation_config.assistant_confidence_threshold = self.assistant_confidence_threshold
        # this flag allow us set the confidence stopping criteria for assistant model generation.
        self.generation_config.is_assistant = True

        # avoid unnecessary warnings that min_length is larger than max_new_tokens
        # remove the `MinLengthLogitsProcessor` if exists (NOTE: no need to check for `MinNewTokensLogitsProcessor`)
        self.main_model_min_length = self.generation_config.min_length
        self.generation_config.min_length = 0
        self.generation_config.min_new_tokens = None
        for processor in self.logits_processor:
            if isinstance(processor, MinLengthLogitsProcessor):
                raise ValueError(
                    "Passing `MinLengthLogitsProcessor` when using `assisted_generation` is disabled. "
                    "Please pass in `min_length` into `.generate()` instead"
                )

        # We need to roll back the cache in assisted generation, only DynamicCache is supported
        self.generation_config.cache_implementation = "dynamic_full"

        # Per-draft-token acceptance statistics used to adapt the confidence threshold (see
        # `update_candidate_strategy`); only tracked when that adaptation is active.
        if self._is_confidence_threshold_adaptive():
            self.probs = []
            self.matches = []

    def _is_confidence_threshold_adaptive(self) -> bool:
        """
        Whether the assistant confidence threshold is tuned from observed acceptances.

        Requires scikit-learn (for `roc_curve`), a truthy `assistant_confidence_threshold` on the assistant's
        generation config, and this exact class (subclasses such as the different-tokenizer variant are excluded).
        """
        return bool(
            is_sklearn_available()
            and self.assistant_model.generation_config.assistant_confidence_threshold
            and type(self) is AssistedCandidateGenerator
        )

    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences to be
            assessed by the model and a `torch.FloatTensor` of shape `(batch_size, candidate_length,
            vocabulary_size)` containing the logits associated to each candidate. When no new token can be drafted,
            `input_ids` is returned unchanged together with `None`.
        """
        input_ids = input_ids.to(self.assistant_model.device)

        # Calculate new tokens to generate
        min_new_tokens, max_new_tokens = self._calculate_new_tokens(input_ids)
        # Non-positive when `int(num_assistant_tokens) <= 0` or the input is already at `max_length - 1` or longer;
        # there is nothing to draft, so don't call `generate` with an invalid budget.
        if max_new_tokens <= 0:
            return input_ids, None

        # Update past key values and masks
        self._update_past_and_masks(input_ids)

        # Generate candidates
        generation_args = self._prepare_generation_args(input_ids, min_new_tokens, max_new_tokens)
        candidate_ids, candidate_logits = self._generate_candidates(generation_args)
        return candidate_ids, candidate_logits

    def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
        """
        Updates the candidate generation strategy based on the outcomes.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
            scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                beam search or log softmax for each vocabulary token when using beam search
            num_matches (`int`):
                The number of matches between the candidate sequences and the model predictions.
        """
        # Adjust the max number of assistant tokens to use in the next iteration. This is a simple heuristic,
        # probably can be improved -- we want to balance the benefits of getting assistant tokens correct with the
        # cost of forecasting incorrect assistant tokens.
        if self.assistant_model.generation_config.num_assistant_tokens_schedule in {
            "heuristic",
            "heuristic_transient",
        }:
            # len(scores[0])-1 is the number of candidates according to the target tokenizer.
            if num_matches == len(scores[0]) - 1:
                self.num_assistant_tokens += 2.0
            else:
                self.num_assistant_tokens = max(1.0, self.num_assistant_tokens - 1.0)

        # The assistant's confidence threshold is adjusted throughout the speculative iterations to reduce the number
        # of unnecessary draft and target forward passes. The costs are estimated based on the ROC curve, which
        # considers the probability of the draft token and its match with the target. A cost of 25% is assigned to
        # false positives and 75% to false negatives.
        # This adaptation is not compatible with UAG, as it relies on the number of matched tokens based on the draft
        # vocabulary, which is unavailable in UAG.
        if self._is_confidence_threshold_adaptive():
            # Each accepted draft token is a positive sample; the first non-accepted drafted token (if any) is a
            # negative sample.
            self.matches.extend([1] * num_matches)
            if len(self.probs) > len(self.matches):
                self.matches.append(0)

            # Drop probabilities of drafted tokens that have no label (those after the first rejection).
            excess_length = len(self.probs) - len(self.matches)
            if excess_length > 0:
                del self.probs[-excess_length:]

            # Require more than 5 samples and at least one positive and one negative sample for the ROC curve.
            if len(self.probs) > 5 and {0, 1}.issubset(self.matches):
                fpr, tpr, thresholds = roc_curve(self.matches, self.probs)
                fnr = 1 - tpr

                # Calculate the cost for each threshold (false negatives weigh 3x false positives)
                costs = fpr + 3 * fnr

                # Find the threshold that minimizes the cost
                optimal_threshold_index = np.argmin(costs)
                best_threshold = thresholds[optimal_threshold_index]

                self.assistant_model.generation_config.assistant_confidence_threshold = best_threshold
                # `self.generation_config` is the config actually passed to the assistant's `generate` call, and it
                # only received a copy of the threshold in `__init__`. Keep it in sync so the adapted value is used.
                self.generation_config.assistant_confidence_threshold = best_threshold

    def _calculate_new_tokens(self, input_ids: torch.LongTensor) -> tuple[int, int]:
        """
        Calculate the minimum and maximum number of new tokens to generate.

        `max_new_tokens` is `int(num_assistant_tokens)` capped at `max_length - cur_len - 1`. `min_new_tokens` is
        whatever the main model's `min_length` still requires, clamped to `[0, max_new_tokens]` when
        `max_new_tokens >= 0`.
        """
        new_cur_len = input_ids.shape[-1]
        max_new_tokens = min(int(self.num_assistant_tokens), self.generation_config.max_length - new_cur_len - 1)
        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - new_cur_len), 0)
        return min_new_tokens, max_new_tokens

    def _update_past_and_masks(
        self, input_ids: torch.LongTensor, remove_from_pkv: int = 0, num_added_tokens: int = 1
    ) -> bool:
        """
        Update past key values and attention masks for subsequent generation rounds.

        Args:
            input_ids: the assistant input for this round.
            remove_from_pkv: additional cached positions to discard (used when earlier tokens were replaced).
            num_added_tokens: forwarded to the crop size computation below.

        Returns:
            `True` if a cache from a previous round was present (and has been cropped), `False` otherwise.
        """
        has_past_key_values = self.assistant_kwargs.get("past_key_values", None) is not None
        if has_past_key_values:
            # Crop the cache to `len(input_ids) - 1 - remove_from_pkv - num_added_tokens` positions.
            new_cache_size = input_ids.shape[-1] - 1 - remove_from_pkv
            self.assistant_kwargs["past_key_values"].crop(new_cache_size - num_added_tokens)
            self.assistant_kwargs = _prepare_attention_mask(
                self.assistant_kwargs, input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
            )
            self.assistant_kwargs = _prepare_token_type_ids(self.assistant_kwargs, input_ids.shape[-1])

            # This unsets `dynamic_full`, needed to initialize a new cache for the assistant. After the first forward
            # pass on each generation, we reuse the cache instead.
            self.generation_config.cache_implementation = None

        return has_past_key_values

    def _prepare_generation_args(self, input_ids: torch.LongTensor, min_new_tokens: int, max_new_tokens: int) -> dict:
        """Prepare arguments for the generation call (the ids go under `decoder_input_ids` or `input_ids`)."""
        return {
            self.input_ids_key: input_ids,
            "min_new_tokens": min_new_tokens,
            "max_new_tokens": max_new_tokens,
            "generation_config": self.generation_config,
            "logits_processor": self.logits_processor,
        }

    def _generate_candidates(self, generation_args: dict) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Generate candidate sequences using the assistant model.

        Stores the returned cache in `assistant_kwargs` for reuse in the next round and, when threshold adaptation
        is active, records the assistant's probability for each drafted token in `self.probs`.
        """
        assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values

        if self._is_confidence_threshold_adaptive():
            # NOTE(review): stacks per-step scores along dim 0 and reads ids from the last row only, which is only
            # meaningful for batch_size == 1 -- confirm callers never batch here.
            scores_tensor = torch.cat(assistant_output.scores, dim=0)
            scores_softmax = torch.softmax(scores_tensor, dim=-1)
            ids = assistant_output.sequences[-1, -len(assistant_output.scores) :]
            p = scores_softmax[range(len(ids)), ids]
            self.probs.extend(p.tolist())

        candidate_logits = torch.stack(assistant_output.scores, dim=1)
        candidate_ids = assistant_output.sequences
        return candidate_ids, candidate_logits
  # ---- Universal Assisted Generation: assistant and target models use different tokenizers ----
class AssistedCandidateGeneratorDifferentTokenizers(AssistedCandidateGenerator):
    """
    `CandidateGenerator` class to be used for Universal Assisted Generation (UAG): assisted generation with different
    tokenizers for the assistant and main models. This class generates candidates through the use of a smaller
    model.

    The main model input tokens are re-encoded into assistant model tokens, then candidate tokens are generated in the
    assistant encoding, which are in turn re-encoded into main model candidate tokens. Validation then proceeds as in
    regular assisted generation.

    The re-encoding steps involve decoding token ids into text and then encoding the text using a different tokenizer.
    Since re-encoding the tokens may result in tokenization discrepancies, UAG finds the longest common contiguous run
    of tokens (the longest diagonal of the token-equality matrix) between the source and target encodings, to ensure
    the new tokens include the correct prompt suffix.

    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
        assistant_model (`PreTrainedModel`):
            The model to be used for generating candidates. This model should be smaller than the main model.
        target_tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for the target model.
        assistant_tokenizer (`PreTrainedTokenizerBase`):
            The tokenizer used for the assistant model.
        generation_config (`~generation.GenerationConfig`):
            The generation configuration to be used as base parametrization for the generation call.
        model_kwargs (`Dict`):
            The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
            model as well.
        inputs_tensor (`torch.Tensor`, *optional*):
            The model input tensor. In encoder-decoder models, this is the encoder input.
        logits_processor (`LogitsProcessorList`, *optional*):
            An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
            used to modify the prediction scores of the language modeling head applied at each generation step.
    """

    def __init__(
        self,
        input_ids: torch.LongTensor,
        assistant_model: "PreTrainedModel",
        target_tokenizer: "PreTrainedTokenizerBase",
        assistant_tokenizer: "PreTrainedTokenizerBase",
        generation_config: "GenerationConfig",
        model_kwargs: dict,
        inputs_tensor: Optional[torch.Tensor] = None,
        logits_processor: Optional["LogitsProcessorList"] = None,
    ):
        super().__init__(input_ids, assistant_model, generation_config, model_kwargs, inputs_tensor, logits_processor)

        self.target_tokenizer = target_tokenizer
        self.assistant_tokenizer = assistant_tokenizer
        # Length of the target sequence seen in the previous round (None before the first round).
        self.prev_target_ids_len: Optional[int] = None
        # Assistant-vocabulary ids from the previous round (None before the first round).
        self.prev_assistant_ids = None
        # Number of trailing tokens re-encoded on each side when reconciling the two tokenizations.
        self.target_lookbehind = assistant_model.generation_config.target_lookbehind
        self.assistant_lookbehind = assistant_model.generation_config.assistant_lookbehind

    @staticmethod
    def _get_longest_diag_dict(input_matrix, nonzero_idx):
        """
        Calculates, for each non-zero cell, the length of the run of 1s along its down-right diagonal.

        Cells already covered by an earlier (longer) run on the same diagonal are skipped, so each diagonal run is
        reported once, from its first cell in `nonzero_idx` order.

        Args:
            input_matrix (torch.Tensor): 2D 0/1 matrix.
            nonzero_idx (torch.Tensor): The indices of the non-zero elements in the matrix, one `[row, col]` per row.

        Returns:
            dict: maps each run's start index (the `[row, col]` tensor taken from `nonzero_idx`; tensor keys hash by
            identity) to the run length.
        """
        visited = set()
        diags = {}
        for idx in nonzero_idx:
            start_idx = torch.clone(idx)
            tuple_start_idx = tuple(start_idx.tolist())

            if tuple_start_idx in visited:
                continue

            visited.add(tuple_start_idx)
            cur_diag_len = 1
            start_idx += 1
            while start_idx[0] < input_matrix.shape[0] and start_idx[1] < input_matrix.shape[1]:
                tuple_start_idx = tuple(start_idx.tolist())
                visited.add(tuple_start_idx)

                if input_matrix[start_idx[0], start_idx[1]] == 1:
                    cur_diag_len += 1
                    start_idx += 1
                else:
                    break

            diags[idx] = cur_diag_len
        return diags

    @staticmethod
    def _get_longest_diag_index(input_matrix):
        """
        Returns the start index and length of the longest diagonal run of 1s in the given input.

        Args:
            input_matrix (torch.Tensor): 2D 0/1 matrix with at least one non-zero element. A numpy array is not
                supported: its `.nonzero()` returns per-axis arrays rather than `[row, col]` rows.

        Returns:
            tuple: `(start_index, length)`, where `start_index` is a `[row, col]` tensor. Ties resolve to the first run
            found (row-major order), as `np.argmax` returns the first maximum.
        """
        diags = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_dict(
            input_matrix, input_matrix.nonzero()
        )
        diags_values = list(diags.values())
        diags_keys = list(diags.keys())
        best_diag = np.argmax(diags_values)
        diag_start_index = diags_keys[best_diag]
        diag_start_length = diags_values[best_diag]
        return diag_start_index, diag_start_length

    @staticmethod
    def _get_tokens_diag(prompt, prompt_plus_new_tokens):
        """
        Aligns a re-encoded window against the previous tokens via their longest common contiguous run.

        Input:
            prompt: 2D tensor of shape (batch_size, prompt_length), represents the original prompt tokens
            prompt_plus_new_tokens: 2D tensor of shape (batch_size, window_length), represents the suffix of the
                original prompt (possibly tokenized differently), with additional new tokens.
            The comparison `prompt_plus_new_tokens.T == prompt` broadcasts to a (window_length, prompt_length) matrix
            only when batch_size == 1.

        Output:
            discrepancy_length: int, represents the number of tokens that need to be replaced from prompt
            new_tokens_only: 2D tensor of shape (batch_size, new_token_length), represents the new tokens that are
                not in prompt
            discrepancy_only: 2D tensor of shape (batch_size, discrepancy_length), represents the new tokens that are
                in prompt but not in prompt_plus_new_tokens
            All three are None when the two inputs share no token.
        """
        compare_mat = prompt_plus_new_tokens.T == prompt
        if not torch.is_tensor(compare_mat):
            compare_mat = torch.tensor(compare_mat)

        compare_mat_int = compare_mat.to(int)

        if not compare_mat_int.any().item():
            # empty intersection between prompt and prompt_plus_new_tokens
            return None, None, None

        longest_location, longest_diag_length = AssistedCandidateGeneratorDifferentTokenizers._get_longest_diag_index(
            compare_mat_int
        )
        # Rows index `prompt_plus_new_tokens`, columns index `prompt`.
        new_token_start_index = longest_location[0] + longest_diag_length
        discrepancy_with_old = longest_location[1] + longest_diag_length
        discrepancy_length = (prompt.shape[1] - discrepancy_with_old).item()
        new_tokens_only = prompt_plus_new_tokens[:, new_token_start_index + discrepancy_length :]
        discrepancy_only = prompt_plus_new_tokens[
            :, new_token_start_index : new_token_start_index + discrepancy_length
        ]
        return discrepancy_length, new_tokens_only, discrepancy_only

    def convert_source_tokens_to_target_tokens(
        self,
        input_ids,
        source_tokenizer,
        destination_tokenizer,
    ):
        """
        Convert token IDs from one tokenizer to another by decoding to text and re-encoding.

        Args:
            input_ids: The input token IDs.
            source_tokenizer: The source tokenizer.
            destination_tokenizer: The destination tokenizer.

        Returns:
            The converted token IDs, on the same device as `input_ids`. Special tokens are skipped when decoding and
            added again by the destination tokenizer when encoding.
        """
        text = source_tokenizer.batch_decode(input_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
        dest_ids = destination_tokenizer(text, add_special_tokens=True, return_tensors="pt")["input_ids"]
        return dest_ids.to(input_ids.device)

    def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
        """
        Fetches the candidates to be tried for the current input.

        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

        Return:
            `torch.LongTensor` of shape `(batch_size, candidate_length)` containing the candidate sequences (in the
            target vocabulary) to be assessed by the model, and `None`: no candidate logits are returned by this
            generator. When no new target token is produced, `input_ids` is returned unchanged.
        """
        max_new_tokens = int(self.num_assistant_tokens)
        # Nothing to draft with a non-positive budget; don't call `generate` with an invalid value.
        if max_new_tokens <= 0:
            return input_ids, None

        input_ids = input_ids.to(self.assistant_model.device)

        # Re-encode the target ids into the assistant vocabulary (reusing the previous round where possible).
        assistant_input_ids, remove_from_pkv = self._prepare_assistant_input_ids(input_ids)
        # `_process_assistant_outputs` measures its look-behind window from this prefix length.
        self.prev_assistant_ids = assistant_input_ids

        min_new_tokens = max(min(max_new_tokens, self.main_model_min_length - assistant_input_ids.shape[-1]), 0)

        self._update_past_and_masks(assistant_input_ids, remove_from_pkv)
        generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)
        # NOTE(review): the attention mask is dropped before calling the assistant, presumably because masks derived
        # from the target inputs do not line up with the re-tokenized assistant ids -- confirm.
        self.assistant_kwargs.pop("attention_mask", None)

        assistant_output = self.assistant_model.generate(**generation_args, **self.assistant_kwargs)
        new_target_ids = self._process_assistant_outputs(input_ids, assistant_output.sequences)

        # Update state for the next round
        self.prev_target_ids_len = input_ids.shape[1]
        self.assistant_kwargs["past_key_values"] = assistant_output.past_key_values
        self.prev_assistant_ids = assistant_output.sequences

        if self.prev_target_ids_len >= new_target_ids.shape[1]:
            return input_ids, None

        return new_target_ids, None

    def _prepare_assistant_input_ids(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, int]:
        """
        Converts target input IDs to assistant input IDs, handling discrepancies.

        On the first round, or while the previous target length does not exceed `target_lookbehind`, the whole target
        sequence is re-encoded. Otherwise, only the target tokens from `target_lookbehind` positions before the
        previous length onward are re-encoded. They are aligned against the previous assistant ids, and the previous
        ids are patched with any re-tokenized tail and extended with the new tokens.

        Returns:
            `(assistant_input_ids, remove_from_pkv)`, where `remove_from_pkv` is the number of previous assistant
            tokens that were replaced (and whose cache entries must therefore be discarded).
        """
        convert_kwargs = {
            "source_tokenizer": self.target_tokenizer,
            "destination_tokenizer": self.assistant_tokenizer,
        }
        remove_from_pkv = 0

        if self.prev_assistant_ids is not None and self.prev_target_ids_len > self.target_lookbehind:
            # input_ids contains all target prompt input ids and some new target input ids
            start_index_in_target_window = self.prev_target_ids_len - self.target_lookbehind

            new_assistant_ids = self.convert_source_tokens_to_target_tokens(
                input_ids[:, start_index_in_target_window:], **convert_kwargs
            )
            prompt_use_length = new_assistant_ids.shape[1]
            prompt_use = self.prev_assistant_ids[:, -prompt_use_length:]

            discrepancy_length, new_tokens_only, discrepancy_only = self._get_tokens_diag(
                prompt_use, new_assistant_ids
            )
            assistant_input_ids = self.prev_assistant_ids

            if new_tokens_only is not None:
                if discrepancy_length > 0 and discrepancy_only.shape[1] > 0:
                    if discrepancy_length == discrepancy_only.shape[1]:
                        # Same number of tokens: overwrite the tail in place.
                        assistant_input_ids[:, -discrepancy_length:] = discrepancy_only

                    elif discrepancy_length > discrepancy_only.shape[1]:
                        # Fewer replacement tokens: shorten the tail first, then overwrite it.
                        discrepancy_length_diff = discrepancy_length - discrepancy_only.shape[1]
                        assistant_input_ids = assistant_input_ids[:, :-discrepancy_length_diff]
                        assistant_input_ids[:, -discrepancy_only.shape[1] :] = discrepancy_only

                    remove_from_pkv = discrepancy_length

                if new_tokens_only.shape[1] > 0:
                    assistant_input_ids = torch.cat([assistant_input_ids, new_tokens_only], dim=-1)
            else:
                # edge case: in case of no intersection between prompt and new_assistant_ids
                assistant_input_ids = torch.cat([assistant_input_ids, new_assistant_ids], dim=-1)
        else:
            assistant_input_ids = self.convert_source_tokens_to_target_tokens(input_ids, **convert_kwargs)
            self.prev_target_ids_len = input_ids.shape[1]

        return assistant_input_ids, remove_from_pkv

    def _process_assistant_outputs(
        self, input_ids: torch.LongTensor, assistant_sequences: torch.LongTensor
    ) -> torch.LongTensor:
        """
        Processes assistant outputs to obtain target input IDs.

        Re-encodes the tail of `assistant_sequences` into the target vocabulary, starting `assistant_lookbehind`
        tokens before the end of `self.prev_assistant_ids` (clamped to the start of the sequence). It then aligns that
        window with the tail of `input_ids` and appends the tokens that are new. The result is truncated to
        `generation_config.max_length`.
        """
        num_prev_assistant = self.prev_assistant_ids.shape[1]
        # Clamp at 0: a negative start would slice from the *end* of `assistant_sequences` and could drop the prompt
        # overlap (or even newly generated tokens) when the assistant prefix is shorter than the look-behind window.
        start_assistant_look_index = max(0, num_prev_assistant - self.assistant_lookbehind)

        new_target_ids_from_window = self.convert_source_tokens_to_target_tokens(
            assistant_sequences[:, start_assistant_look_index:],
            source_tokenizer=self.assistant_tokenizer,
            destination_tokenizer=self.target_tokenizer,
        )
        target_prompt_use_length = new_target_ids_from_window.shape[1]

        target_prompt_use = input_ids[:, -target_prompt_use_length:]

        _, target_new_tokens_only, _ = self._get_tokens_diag(target_prompt_use, new_target_ids_from_window)

        new_target_ids = input_ids

        if target_new_tokens_only is not None:
            if target_new_tokens_only.shape[1] > 0:
                new_target_ids = torch.cat([new_target_ids, target_new_tokens_only], dim=-1)
        else:
            # edge case: in case of no intersection between prompt and new_target_ids
            new_target_ids = torch.cat([new_target_ids, new_target_ids_from_window], dim=-1)

        if hasattr(self.generation_config, "max_length"):
            new_target_ids = new_target_ids[:, : self.generation_config.max_length]

        return new_target_ids
  522. class _PruneReindexingLMHead(nn.Module):
  523. """
  524. A class to prune and reindex the language model head.
  525. This class prunes the language model head to only include the specified token IDs and reindexes the logits
  526. to map back to the original vocabulary.
  527. Args:
  528. original_lm_head (nn.Module): The original language model head.
  529. token_ids (list[int]): The list of token IDs to keep.
  530. """
  531. def __init__(self, original_lm_head, assistant_overlap_token_ids):
  532. super().__init__()
  533. self.pruned_lm_head = prune_linear_layer(original_lm_head, assistant_overlap_token_ids).to(
  534. original_lm_head.weight.dtype
  535. )
  536. def forward(self, hidden_states):
  537. pruned_logits = self.pruned_lm_head(hidden_states)
  538. return pruned_logits
  539. class _MapInputEmbedding(nn.Module):
  540. def __init__(self, original_embedding: nn.Embedding, assistant_overlap_token_ids):
  541. """
  542. Wraps an existing embedding layer and remaps token IDs before lookup.
  543. Args:
  544. original_embedding (nn.Embedding): Pre-trained or existing embedding layer.
  545. assistant_overlap_token_ids (dict): Mapping from original token IDs to new token IDs.
  546. Example: {old_id: new_id}
  547. """
  548. super().__init__()
  549. self.original_embedding = original_embedding
  550. self.weight = original_embedding.weight
  551. self.assistant_overlap_token_ids = assistant_overlap_token_ids
  552. self.map = False
  553. def forward(self, input_ids: torch.LongTensor) -> torch.FloatTensor:
  554. """
  555. Args:
  556. input_ids (torch.LongTensor): Tensor of token IDs (batch_size, seq_len).
  557. Returns:
  558. torch.FloatTensor: Corresponding input embeddings.
  559. """
  560. if self.map:
  561. # Get the last item from input_ids
  562. my_input_ids = self.assistant_overlap_token_ids[input_ids[0, -1]].unsqueeze(0).unsqueeze(0)
  563. else:
  564. self.map = True
  565. my_input_ids = input_ids
  566. return self.original_embedding(my_input_ids)
  class AssistantToTargetTranslator:
      """
      Translates token ids and logits between assistant and target model vocabularies. This class is used to handle
      vocabulary mismatches when using different tokenizers for the assistant and target models in speculative decoding,
      as introduced in the paper "Lossless Speculative Decoding Algorithms for Heterogeneous Vocabularies"
      (https://huggingface.co/papers/2502.05202).
      It maintains mappings between the two vocabularies and handles token/logit conversion.

      Args:
          target_tokenizer (`PreTrainedTokenizerBase`):
              The tokenizer used by the target (main) model.
          assistant_tokenizer (`PreTrainedTokenizerBase`):
              The tokenizer used by the assistant model.
          target_vocab_size (`int`):
              The size of the target model's vocabulary, i.e. the last dimension of the logits returned by
              `get_target_logits`. Required, because it can differ from the length of `target_tokenizer.get_vocab()`.
          assistant_model (`PreTrainedModel`, *optional*):
              The assistant model. Provides the device for the mapping tensors (defaults to `"cpu"` when `None`) and
              is modified in place when pruning is applied. Defaults to None for backward compatibility.
          assistant_prune_lm_head (`bool`, *optional*, defaults to `False`):
              Whether to prune the assistant model's language model head to the tokens shared with the target
              vocabulary. Pruning only happens when `assistant_model` is provided AND the assistant vocabulary contains
              tokens missing from the target vocabulary. Otherwise there is nothing to prune and
              `self.assistant_prune_lm_head` stays `False`.
      """

      FILTER_VALUE: float = -float("Inf")  # The value used to filter out unmapped tokens in the logits.
      SUPPRESS_TOKEN_ID: int = -1  # The ID used to mark suppressed tokens in the mapping.

      def __init__(
          self,
          target_tokenizer: "PreTrainedTokenizerBase",
          assistant_tokenizer: "PreTrainedTokenizerBase",
          target_vocab_size: int,  # required since target_vocab_size can be different from the length of target_tokenizer.get_vocab()
          assistant_model: Optional["PreTrainedModel"] = None,
          assistant_prune_lm_head: bool = False,
      ):
          self._target_tokenizer: PreTrainedTokenizerBase = target_tokenizer
          self._assistant_tokenizer: PreTrainedTokenizerBase = assistant_tokenizer
          self._assistant_model_device = assistant_model.device if assistant_model is not None else "cpu"
          self.target_vocab_size: int = target_vocab_size
          self._assistant_to_target_input_ids, self.target_to_assistant_input_ids = (
              self._get_assistant_to_target_input_ids()
          )
          self._suppress_input_ids: torch.Tensor = self._get_suppress_input_ids()
          self.logits_processors: Optional[LogitsProcessorList] = None
          # Becomes True only once the assistant model has actually been pruned. `get_target_ids`,
          # `get_target_logits` and `unmap_input_ids` rely on the attributes that pruning creates.
          self.assistant_prune_lm_head = False
          if len(self._suppress_input_ids) > 0:
              # the assistant vocab is not a subset of the target vocab
              if assistant_prune_lm_head and assistant_model is not None:
                  self._prune_assistant_model(assistant_model)
              else:
                  self.logits_processors = LogitsProcessorList(
                      [SuppressTokensLogitsProcessor(self._suppress_input_ids, self._assistant_model_device)]
                  )

      def _prune_assistant_model(self, assistant_model: "PreTrainedModel") -> None:
          """Replace the assistant's LM head and input embeddings with versions restricted to the shared vocabulary."""
          # Kept assistant ids in ASCENDING order, selected with the same mask `get_target_logits` uses to scatter
          # pruned logits. Pruned logit `i` and candidate id `i` thus both refer to the i-th smallest mapped assistant
          # id, whatever order `get_vocab()` returned.
          self.assistant_overlap_token_ids = torch.where(
              self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
          )[0].to(dtype=torch.long, device=self._assistant_model_device)

          original_lm_head = assistant_model.get_output_embeddings()
          pruned_lm_head = _PruneReindexingLMHead(original_lm_head, self.assistant_overlap_token_ids)
          del original_lm_head
          assistant_model.set_output_embeddings(pruned_lm_head)

          original_input_embeddings = assistant_model.get_input_embeddings()
          map_input_embeddings = _MapInputEmbedding(original_input_embeddings, self.assistant_overlap_token_ids)
          del original_input_embeddings
          assistant_model.set_input_embeddings(map_input_embeddings)
          self.map_input_embeddings = map_input_embeddings
          self.assistant_prune_lm_head = True

      def unmap_input_ids(self):
          """
          Disables the mapping of input ids despite the assistant pruning for the language model head being enabled.
          This method is required for the first forward pass of `_MapInputEmbedding` where input ids are already in the
          assistant vocabulary space. By disabling the mapping, it ensures that the input ids are processed correctly
          without remapping. It is a no-op when no pruning was applied.
          """
          if self.assistant_prune_lm_head:
              self.map_input_embeddings.map = False

      def _get_assistant_to_target_input_ids(self):
          """
          Build the vocabulary mappings from matching token strings.

          Returns:
              A tuple `(assistant_to_target, target_to_assistant)`:
              - `assistant_to_target` is a 1-D tensor indexed by assistant id, holding the target id or
                `SUPPRESS_TOKEN_ID` when the token string is absent from the target vocabulary;
              - `target_to_assistant` is a dict `{target_id: assistant_id}` over the shared tokens.
          """
          target_vocab = self._target_tokenizer.get_vocab()
          assistant_vocab = self._assistant_tokenizer.get_vocab()

          space_str = " "
          target_space_ids = self._target_tokenizer(space_str, add_special_tokens=False)["input_ids"]
          if len(target_space_ids) > 0:
              target_space_sign = self._target_tokenizer.convert_ids_to_tokens(target_space_ids)[0][0]

              assistant_space_ids = self._assistant_tokenizer(space_str, add_special_tokens=False)["input_ids"]
              if len(assistant_space_ids) > 0:
                  assistant_space_sign = self._assistant_tokenizer.convert_ids_to_tokens(assistant_space_ids)[0][0]

                  if target_space_sign != assistant_space_sign:
                      # If the assistant tokenizer has a different space sign than the target tokenizer,
                      # we need to replace the assistant space sign with the target space sign in the assistant_vocab.
                      assistant_vocab = {
                          (
                              tok.replace(assistant_space_sign, target_space_sign, 1)
                              if tok.startswith(assistant_space_sign)
                              else tok
                          ): idx
                          for tok, idx in assistant_vocab.items()
                      }

          max_assistant_index = max(assistant_vocab.values())
          assistant_to_target_input_ids = torch.full((max_assistant_index + 1,), self.SUPPRESS_TOKEN_ID, dtype=int)
          target_to_assistant_input_ids: dict[int, int] = {}
          for tok, assistant_id in assistant_vocab.items():
              target_id = target_vocab.get(tok)
              if target_id is not None:
                  assistant_to_target_input_ids[assistant_id] = target_id
                  target_to_assistant_input_ids[target_id] = assistant_id
          return assistant_to_target_input_ids.to(self._assistant_model_device), target_to_assistant_input_ids

      def _get_suppress_input_ids(self) -> torch.Tensor:
          """
          Get the input ids that are in the assistant vocab but not in the target vocab (as a 1-D tensor).
          """
          return torch.where(self._assistant_to_target_input_ids == self.SUPPRESS_TOKEN_ID)[0]

      def get_target_ids(
          self, assistant_input_ids, target_input_ids, assistant_candidate_ids: torch.LongTensor
      ) -> torch.LongTensor:
          """
          Return the target candidate ids that correspond to the assistant candidate ids.
          Note that we have already the target ids for the prompt and we only need to find the target ids for the new tokens.
          Moreover, assistant ids of the original prompt does not necessarily appear in _assistant_to_target_input_ids.
          """
          num_new_tokens = len(assistant_candidate_ids[0]) - assistant_input_ids.shape[1]
          if num_new_tokens == 0:
              return target_input_ids

          # Get last `num_new_tokens` candidate IDs
          last_candidate_ids = assistant_candidate_ids[0, -num_new_tokens:]
          if self.assistant_prune_lm_head:
              # Pruned-vocabulary positions -> original assistant IDs
              last_candidate_ids = self.assistant_overlap_token_ids[last_candidate_ids]
          transformed_slice = self._assistant_to_target_input_ids[last_candidate_ids]
          return torch.cat((target_input_ids, transformed_slice.unsqueeze(0)), dim=1)

      def get_target_logits(self, assistant_logits: torch.FloatTensor) -> torch.FloatTensor:
          """
          Return the target logits that correspond to the assistant logits.

          Target positions with no assistant counterpart are set to `FILTER_VALUE`.
          """
          target_shape: tuple[int, ...] = (*assistant_logits.shape[:-1], self.target_vocab_size)
          target_logits: torch.FloatTensor = torch.full(
              target_shape, self.FILTER_VALUE, device=self._assistant_model_device
          )
          # Mask for valid indices
          assistant_indices_mask = self._assistant_to_target_input_ids != self.SUPPRESS_TOKEN_ID
          # Target ids of the mapped assistant tokens, in ascending assistant-id order
          target_logits_supported_indices = self._assistant_to_target_input_ids[assistant_indices_mask]
          if self.assistant_prune_lm_head:
              # The pruned head already emits exactly one logit per mapped assistant id, in the same ascending order.
              target_logits[..., target_logits_supported_indices] = assistant_logits
          else:
              valid_assistant_logits = assistant_logits[..., : self._assistant_to_target_input_ids.shape[0]]
              target_logits[..., target_logits_supported_indices] = valid_assistant_logits[..., assistant_indices_mask]
          return target_logits
  class AssistantVocabTranslatorCache:
      """
      Cache for `AssistantToTargetTranslator` instances. The instances are computed at
      pre-processing time, and this cache allows us to avoid recomputing them.

      The cache is a two-level `weakref.WeakKeyDictionary`: target tokenizer -> assistant tokenizer -> translator.
      NOTE(review): each translator stores both tokenizers as attributes (strong references). While a translator is
      cached, its weak keys therefore cannot be collected — confirm whether that is intended.
      """

      _cache = weakref.WeakKeyDictionary()

      @classmethod
      def get_translator(
          cls,
          target_tokenizer: "PreTrainedTokenizerBase",
          assistant_tokenizer: "PreTrainedTokenizerBase",
          target_vocab_size: int,
          assistant_model: Optional["PreTrainedModel"] = None,
          assistant_prune_lm_head: bool = False,
      ) -> "AssistantToTargetTranslator":
          """
          Return the translator for this tokenizer pair, building and caching it on first use.

          NOTE(review): only the tokenizer pair is part of the key. `target_vocab_size`, `assistant_model` and
          `assistant_prune_lm_head` are used when the translator is first built and ignored on later cache hits.
          """
          assistant_dict = cls._cache.get(target_tokenizer)
          if assistant_dict is None:
              assistant_dict = weakref.WeakKeyDictionary()
              cls._cache[target_tokenizer] = assistant_dict

          mapping = assistant_dict.get(assistant_tokenizer)
          if mapping is None:
              mapping = AssistantToTargetTranslator(
                  target_tokenizer,
                  assistant_tokenizer,
                  target_vocab_size,
                  assistant_model,
                  assistant_prune_lm_head,
              )
              assistant_dict[assistant_tokenizer] = mapping

          return mapping

      @classmethod
      def cleanup(cls):
          """
          Clean up stale entries in the cache.

          Entries whose tokenizer has been garbage collected are dropped by `weakref.WeakKeyDictionary` itself; its
          iteration never yields dead keys. What can linger are outer entries whose inner (assistant) dictionary
          became empty after all its assistant tokenizers were collected. Those are removed here.
          """
          # Snapshot first: we delete from the dictionary while walking it.
          for target_tokenizer, assistant_dict in list(cls._cache.items()):
              if len(assistant_dict) == 0:
                  cls._cache.pop(target_tokenizer, None)
  class UniversalSpeculativeDecodingGenerator(AssistedCandidateGeneratorDifferentTokenizers):
      """
      `CandidateGenerator` class to be used for Universal Speculative Decoding (USD): speculative decoding with different tokenizers
      for the assistant and main models. This class generates candidates through the use of a smaller model.

      Token ids and logits are translated between the two vocabularies by an `AssistantToTargetTranslator`.
      """

      def __init__(
          self,
          input_ids: torch.LongTensor,
          assistant_model: "PreTrainedModel",
          target_tokenizer: "PreTrainedTokenizerBase",
          assistant_tokenizer: "PreTrainedTokenizerBase",
          generation_config: "GenerationConfig",
          model_kwargs: dict,
          atm_translator: AssistantToTargetTranslator,
          inputs_tensor: Optional[torch.Tensor] = None,
          logits_processor: Optional["LogitsProcessorList"] = None,
      ):
          # Initialize translator before parent class
          self._atm_translator = atm_translator
          super().__init__(
              input_ids,
              assistant_model,
              target_tokenizer,
              assistant_tokenizer,
              generation_config,
              model_kwargs,
              inputs_tensor,
              logits_processor,
          )
          # Target length (prompt + candidates) returned by the last `get_candidates` call; 0 before the first call.
          self._target_seq_len_with_candidates: int = 0
          # Assistant sequences returned by the last assistant generation; None before the first call.
          self._prev_assistant_ids: Optional[torch.LongTensor] = None

      def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
          """
          Simplified version of get_candidates that uses the translator cache for token conversion.

          Returns the target-vocabulary candidate ids and the assistant logits mapped onto the target vocabulary,
          or `(input_ids, None)` when no new tokens may be generated.
          """
          target_input_ids = input_ids.to(self.assistant_model.device)
          assistant_input_ids, num_added_tokens = self._prepare_assistant_input_ids(target_input_ids)
          min_new_tokens, max_new_tokens = self._calculate_new_tokens(target_input_ids)

          if max_new_tokens == 0:
              return input_ids, None

          self._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)
          generation_args = self._prepare_generation_args(assistant_input_ids, min_new_tokens, max_new_tokens)

          # Ensure scores are returned
          generation_args["generation_config"].output_scores = True
          generation_args["generation_config"].return_dict_in_generate = True

          # Generate and process outputs using translator
          if self._atm_translator.logits_processors is not None:
              generation_args["logits_processor"] = self._atm_translator.logits_processors
          self._prev_assistant_ids, assistant_candidate_logits = self._generate_candidates(generation_args)

          # Use translator to convert tokens and logits
          target_candidate_ids = self._atm_translator.get_target_ids(
              assistant_input_ids, target_input_ids, self._prev_assistant_ids
          )
          self._target_seq_len_with_candidates = target_candidate_ids.shape[-1]
          target_candidate_logits = self._atm_translator.get_target_logits(assistant_candidate_logits)

          return target_candidate_ids, target_candidate_logits

      def _update_past_and_masks(self, assistant_input_ids: torch.LongTensor, num_added_tokens: int = 1) -> bool:
          """Prepare the attention mask on the first call, then defer to the parent implementation."""
          if self._prev_assistant_ids is None:
              # Prepare attention mask for the first generation.
              # For subsequent generations, the attention mask is updated in super()._update_past_and_masks.
              self.assistant_kwargs = _prepare_attention_mask(
                  self.assistant_kwargs, assistant_input_ids.shape[-1], self.assistant_model.config.is_encoder_decoder
              )

          return super()._update_past_and_masks(assistant_input_ids, num_added_tokens=num_added_tokens)

      def _prepare_assistant_input_ids(self, target_input_ids: torch.LongTensor) -> tuple[torch.LongTensor, int]:
          """
          Simplified token conversion that only processes new tokens.

          On the first call the whole target sequence is converted by decoding/re-encoding. Afterwards only the last
          target token is converted: a direct vocabulary lookup is tried first, with text re-encoding as the fallback.
          Previous assistant ids for candidates that were not kept are cut off before the new ids are appended.

          Returns:
              `tuple[torch.LongTensor, int]`: the assistant input ids and the number of newly converted assistant ids.
          """
          # Calculate new tokens since last call
          target_seq_len = target_input_ids.shape[-1]
          if self._target_seq_len_with_candidates == 0:
              new_token_count = target_seq_len
          else:
              new_token_count = 1
          target_new_ids = target_input_ids[:, -new_token_count:]

          # Convert the new tokens
          assistant_new_ids = None
          if self._target_seq_len_with_candidates > 0:
              # we have only one new token and we can directly convert it
              assistant_new_ids = self._atm_translator.target_to_assistant_input_ids.get(target_new_ids[0].item())
          if assistant_new_ids is None:
              target_new_text = self.target_tokenizer.batch_decode(
                  target_new_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True
              )
              assistant_new_ids = self.assistant_tokenizer(
                  target_new_text, add_special_tokens=False, return_tensors="pt"
              )["input_ids"].to(self.assistant_model.device)
          else:
              assistant_new_ids = torch.tensor([[assistant_new_ids]], device=self.assistant_model.device)

          # Update or initialize assistant IDs
          if self._prev_assistant_ids is None:
              assistant_input_ids = assistant_new_ids
          else:
              # Previous assistant output ends with all proposed candidates. The target accepted only some of them
              # (plus one token of its own), so drop the trailing ids for the rejected candidates.
              tokens_to_remove = self._target_seq_len_with_candidates + 1 - target_seq_len
              if tokens_to_remove > 0:
                  self._prev_assistant_ids = self._prev_assistant_ids[:, :-tokens_to_remove]
              assistant_input_ids = torch.cat([self._prev_assistant_ids, assistant_new_ids], dim=-1)
          assistant_input_ids = assistant_input_ids.to(dtype=torch.long)
          # The ids above are in the assistant's original vocabulary, so the next embedding pass must not remap them.
          self._atm_translator.unmap_input_ids()

          return assistant_input_ids, len(assistant_new_ids[0])
  class PromptLookupCandidateGenerator(CandidateGenerator):
      """
      `CandidateGenerator` class to be used for prompt lookup generation. This class generates candidates by looking up
      likely continuations in the provided prompt (input_ids) itself.
      Read the following blog post for more information: https://github.com/apoorvumang/prompt-lookup-decoding

      Args:
          eos_token_id (`torch.Tensor`, *optional*):
              The token id of the end of sequence token. When `None`, candidates are not cropped at EOS.
          num_output_tokens (`int`, *optional*, defaults to 10):
              The maximum number of tokens to be output as candidate tokens.
          max_matching_ngram_size (`int`, *optional*, defaults to 2):
              The maximum ngram size to be considered for matching in the prompt
          max_length (`int`, *optional*, defaults to 20):
              The number of total maximum tokens that can be generated. For decoder-only models that includes the
              prompt length. Defaults to 20, which is the max length used as default in generation config.
              Candidates are capped so that `input_length + num_candidates <= max_length - 1`.
          logits_processor (`LogitsProcessorList`, *optional*):
              An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
              used to modify the prediction scores of the language modeling head applied at each generation step. In
              prompt lookup assisted generation, they are not used to manipulate probabilities, but rather to find
              forbidden tokens (p = -inf) and block them from being valid candidates.
          vocab_size (`int`, *optional*):
              The size of the vocabulary. Required if `logits_processor` is provided.
      """

      def __init__(
          self,
          eos_token_id: Optional[torch.Tensor] = None,
          num_output_tokens: int = 10,
          max_matching_ngram_size: int = 2,
          max_length: int = 20,
          logits_processor: Optional["LogitsProcessorList"] = None,
          vocab_size: Optional[int] = None,
      ):
          self.num_output_tokens = num_output_tokens
          self.max_matching_ngram_size = max_matching_ngram_size
          self.max_length = max_length
          self.eos_token_id = eos_token_id
          self.logits_processor = logits_processor
          self.vocab_size = vocab_size

          if self.max_matching_ngram_size <= 0 or self.num_output_tokens <= 0:
              raise ValueError("Invalid max_matching_ngram_size or num_output_tokens")

      def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
          """
          Fetches the candidates to be tried for the current input.

          Args:
              input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                  Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)

          Return:
              `tuple[torch.LongTensor, None]`: `input_ids` extended with the candidate tokens (or `input_ids` unchanged
              when no candidate is found), and `None` since prompt lookup produces no logits.
          """
          _, input_length = input_ids.shape

          # Don't generate more than `max_length - 1` candidates since the target model generates one extra token.
          max_candidates = self.max_length - input_length - 1
          if max_candidates <= 0:
              return input_ids, None
          num_candidates = min(self.num_output_tokens, max_candidates)

          chosen_ids = None
          match_found = False
          for ngram_size in range(min(self.max_matching_ngram_size, input_length - 1), 0, -1):
              # Create sliding windows of size ngram_size
              windows = input_ids.unfold(dimension=1, size=ngram_size, step=1)
              # Convert ngram to a tensor for comparison
              ngram_tensor = input_ids[0, -ngram_size:]
              # Find where the windows match the ngram
              matches = (windows == ngram_tensor).all(dim=2)
              # Get the indices of matches
              match_indices = matches.nonzero(as_tuple=True)[1]

              # Iterate through match indices to find a valid continuation
              # TODO (joao): this finds the first valid candidates (left to right), but perhaps we should find the
              # longest valid candidates?
              for idx in match_indices:
                  start_idx = idx + ngram_size
                  end_idx = min(start_idx + num_candidates, input_length)
                  if start_idx < end_idx:
                      chosen_ids = input_ids[0, start_idx:end_idx]
                      if self.logits_processor is not None:
                          chosen_ids = self._crop_forbidden_candidates(input_ids, chosen_ids)
                          # no valid candidate tokens -> look for a different match
                          if chosen_ids.shape[0] == 0:
                              continue

                      match_found = True
                      chosen_ids = self._crop_at_eos(chosen_ids)
                      break
              if match_found:
                  break

          # In case we didn't find a match return the input sequence unchanged, reverts back to autoregressive decoding
          if not match_found or len(chosen_ids) == 0:
              return input_ids, None

          # Now need extend input_ids with chosen_ids
          chosen_ids = chosen_ids.unsqueeze(0)
          candidate_input_ids = torch.cat((input_ids, chosen_ids), dim=1)
          # assisted_generation expects logits as well, but we don't have those here, so returning None
          return candidate_input_ids, None

      def _crop_forbidden_candidates(self, input_ids: torch.LongTensor, chosen_ids: torch.LongTensor) -> torch.LongTensor:
          """
          Truncate `chosen_ids` just before the first token that `self.logits_processor` forbids.

          For each candidate, the logits processor is applied to constant fake logits given the sequence so far.
          A candidate whose output logit is -inf (or the dtype minimum) is forbidden.
          """
          sequence_with_candidate = input_ids
          fake_input_logits = torch.ones(
              (input_ids.shape[0], self.vocab_size), device=input_ids.device, dtype=torch.float32
          )
          for candidate_idx, new_candidate_token in enumerate(chosen_ids):
              fake_output_logits = self.logits_processor(sequence_with_candidate, fake_input_logits)
              fake_candidate_logits = fake_output_logits[0, new_candidate_token]
              # next candidate token is forbidden -> crop chosen_ids accordingly
              if fake_candidate_logits in (-float("Inf"), torch.finfo(fake_candidate_logits.dtype).min):
                  return chosen_ids[:candidate_idx]
              sequence_with_candidate = torch.cat((input_ids, chosen_ids[: candidate_idx + 1].unsqueeze(0)), dim=1)
          return chosen_ids

      def _crop_at_eos(self, chosen_ids: torch.LongTensor) -> torch.LongTensor:
          """
          Drop everything from the first EOS token on (no-op when `self.eos_token_id` is None).

          Otherwise the target model may accept eos and the rest as valid, thus not stopping generation after "eos".
          NOTE: written based on the fact that assisted decoding supports only bs=1.
          """
          if self.eos_token_id is None:
              return chosen_ids
          mask = isin_mps_friendly(chosen_ids, self.eos_token_id)
          match_indices_eos = torch.nonzero(mask)
          if match_indices_eos.numel() > 0:
              first_eos_index = match_indices_eos[0].item()
              chosen_ids = chosen_ids[:first_eos_index]
          return chosen_ids

      def update_candidate_strategy(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, num_matches: int):
          """
          Updates the candidate generation strategy based on the outcomes.

          Args:
              input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                  Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
              scores (`torch.FloatTensor` of shape `(batch_size, candidate_length, config.vocab_size)`):
                  Prediction scores of a language modeling head. These can be logits for each vocabulary when not using
                  beam search or log softmax for each vocabulary token when using beam search
              num_matches (`int`):
                  The number of matches between the candidate sequences and the model predictions.
          """
          # Currently does nothing
          return
  class EarlyExitCandidateGenerator(AssistedCandidateGenerator):
      """
      `CandidateGenerator` class to be used for assisted generation and speculative decoding. This class generates
      candidates through the use of **the model itself**, exiting early. Can only be used with models that support early
      exit, e.g., `facebook/layerskip-llama3.2-1B`.

      Args:
          input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
              Indices of input sequence tokens in the vocabulary. [What are input IDs?](../glossary#input-ids)
          assistant_model (`PreTrainedModel`):
              The original model. This model must support early exit (i.e. is trained to compute logits in earlier
              layers).
          generation_config (`~generation.GenerationConfig`, *optional*):
              The generation configuration to be used as base parametrization for the generation call.
          logits_processor (`LogitsProcessorList`):
              An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
              used to modify the prediction scores of the language modeling head applied at each generation step.
          model_kwargs (`Dict`):
              The keyword arguments that will be passed to the main model, and are used as base inputs for the assistant
              model as well.
          inputs_tensor (`torch.Tensor`, *optional*):
              The model input tensor. In encoder-decoder models, this is the encoder input.
      """

      def __init__(
          self,
          input_ids: torch.LongTensor,
          assistant_model: "PreTrainedModel",
          generation_config: "GenerationConfig",
          model_kwargs: dict,
          inputs_tensor: Optional[torch.Tensor] = None,
          logits_processor: Optional["LogitsProcessorList"] = None,
      ):
          super().__init__(
              input_ids=input_ids,
              assistant_model=assistant_model,
              generation_config=generation_config,
              model_kwargs=model_kwargs,
              inputs_tensor=inputs_tensor,
              logits_processor=logits_processor,
          )
          # We have to move early exit out of the generation config, otherwise the assistant will also call `generate`
          # with early exit
          self.assistant_early_exit = self.generation_config.assistant_early_exit
          self.generation_config.assistant_early_exit = None

      def get_candidates(self, input_ids: torch.LongTensor) -> tuple[torch.LongTensor, Optional[torch.FloatTensor]]:
          """
          Generate candidates with the model temporarily truncated to `self.assistant_early_exit` hidden layers.

          The original layer count is restored in a `finally` block. An exception during candidate generation
          therefore cannot leave the (shared) model permanently truncated.
          """
          base_model = getattr(self.assistant_model, self.assistant_model.base_model_prefix)
          original_num_hidden_layers = base_model.config.num_hidden_layers
          base_model.config.num_hidden_layers = self.assistant_early_exit
          try:
              candidate_ids, candidate_logits = super().get_candidates(input_ids)
          finally:
              base_model.config.num_hidden_layers = original_num_hidden_layers
          return candidate_ids, candidate_logits
  1038. def _prepare_attention_mask(model_kwargs: dict[str, Any], new_length: int, is_encoder_decoder: bool) -> dict[str, Any]:
  1039. """Expands or crops the model's mask for decoding purposes, to the defined length"""
  1040. mask_key = "decoder_attention_mask" if is_encoder_decoder else "attention_mask"
  1041. if mask_key not in model_kwargs:
  1042. return model_kwargs
  1043. mask = model_kwargs[mask_key]
  1044. mask_length_diff = new_length - mask.shape[1]
  1045. if mask_length_diff < 0:
  1046. model_kwargs[mask_key] = mask[:, :mask_length_diff]
  1047. elif mask_length_diff > 0:
  1048. model_kwargs[mask_key] = torch.cat([mask, mask.new_ones((mask.shape[0], mask_length_diff))], dim=-1)
  1049. # Handle cross attention models
  1050. if "cross_attention_mask" in model_kwargs:
  1051. # Mllama case
  1052. cross_mask = model_kwargs["cross_attention_mask"]
  1053. if mask_length_diff < 0:
  1054. model_kwargs["cross_attention_mask"] = cross_mask[:, :mask_length_diff]
  1055. elif mask_length_diff > 0:
  1056. new_mask = cross_mask[:, -1:, :, :].repeat(1, mask_length_diff, 1, 1)
  1057. model_kwargs["cross_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
  1058. elif "image_attention_mask" in model_kwargs:
  1059. # IDEFICS case
  1060. cross_mask = model_kwargs["image_attention_mask"]
  1061. if mask_length_diff < 0:
  1062. model_kwargs["image_attention_mask"] = cross_mask[:, :mask_length_diff]
  1063. elif mask_length_diff > 0:
  1064. new_mask = cross_mask[:, -1:, :].repeat(1, mask_length_diff, 1)
  1065. model_kwargs["image_attention_mask"] = torch.cat([cross_mask, new_mask], dim=1)
  1066. return model_kwargs
  1067. def _prepare_token_type_ids(model_kwargs: dict[str, Any], new_length: int) -> dict[str, Any]:
  1068. """Expands or crops the model's token_type_ids for decoding purposes, to the defined length"""
  1069. if "token_type_ids" not in model_kwargs or model_kwargs["token_type_ids"] is None:
  1070. return model_kwargs
  1071. token_type_ids = model_kwargs["token_type_ids"]
  1072. final_token_type = token_type_ids[:, -1].unsqueeze(-1)
  1073. type_length_diff = new_length - token_type_ids.shape[1]
  1074. if type_length_diff < 0:
  1075. token_type_ids = token_type_ids[:, :type_length_diff]
  1076. elif type_length_diff > 0:
  1077. token_type_copies = final_token_type.repeat(1, type_length_diff)
  1078. model_kwargs["token_type_ids"] = torch.cat([model_kwargs["token_type_ids"], token_type_copies], dim=-1)
  1079. return model_kwargs