tf_logits_process.py 28 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600
  1. # coding=utf-8
  2. # Copyright 2022 The HuggingFace Inc. team
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import inspect
  16. import numpy as np
  17. import tensorflow as tf
  18. from ..tf_utils import stable_softmax
  19. from ..utils import add_start_docstrings
  20. from ..utils.logging import get_logger
  logger = get_logger(__name__)


  # Shared `Args`/`Return` documentation for processor `__call__` methods; it is attached to them below through
  # `@add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)`.
  TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING = r"""
      Args:
          input_ids (`tf.Tensor` of shape `(batch_size, sequence_length)`):
              Indices of input sequence tokens in the vocabulary.

              Indices can be obtained using [`PreTrainedTokenizer`]. See [`PreTrainedTokenizer.encode`] and
              [`PreTrainedTokenizer.__call__`] for details.

              [What are input IDs?](../glossary#input-ids)
          scores (`tf.Tensor` of shape `(batch_size, config.vocab_size)`):
              Prediction scores of a language modeling head. These can be logits for each vocabulary when not using beam
              search or log softmax for each vocabulary token when using beam search.
          cur_len (`int`):
              The current length of valid input sequence tokens. In the TF implementation, the input_ids' sequence length
              is the maximum length generate can produce, and we need to know which of its tokens are valid.
          kwargs (`dict[str, Any]`, *optional*):
              Additional logits processor specific kwargs.

      Return:
          `tf.Tensor` of shape `(batch_size, config.vocab_size)`: The processed prediction scores.
  """
  class TFLogitsProcessor:
      """Abstract base class for all logit processors that can be applied during generation.

      Subclasses implement `__call__`, which receives `input_ids`, the current `scores` and `cur_len`, and returns the
      processed scores.
      """

      @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          """TF method for processing logits."""
          error_message = f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
          raise NotImplementedError(error_message)
  class TFLogitsWarper:
      """Abstract base class for all logit warpers that can be applied during generation with multinomial sampling.

      Subclasses implement `__call__`, which receives `input_ids`, the current `scores` and `cur_len`, and returns the
      warped scores.
      """

      @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          """TF method for warping logits."""
          error_message = f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
          raise NotImplementedError(error_message)
  class TFLogitsProcessorList(list):
      """
      This class can be used to create a list of [`TFLogitsProcessor`] to subsequently process a `scores` input tensor.
      This class inherits from list and adds a specific *__call__* method to apply each [`TFLogitsProcessor`] to the
      inputs.
      """

      @add_start_docstrings(TF_LOGITS_PROCESSOR_INPUTS_DOCSTRING)
      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int, **kwargs) -> tf.Tensor:
          # Processors are applied in list order, each one receiving the scores produced by the previous one.
          for processor in self:
              function_args = inspect.signature(processor.__call__).parameters
              if len(function_args) > 3:
                  # `processor.__call__` is a bound method, so (for processors following the `TFLogitsProcessor`
                  # signature) its first three parameters are `input_ids`, `scores` and `cur_len`, which are passed
                  # positionally below and never appear in `kwargs`. Only the named parameters after them must be
                  # supplied through `kwargs`; `*args`/`**kwargs` catch-alls are not names a caller can provide.
                  required_args = [
                      name
                      for name, parameter in list(function_args.items())[3:]
                      if parameter.kind not in (inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD)
                  ]
                  if not all(arg in kwargs for arg in required_args):
                      raise ValueError(
                          f"Make sure that all the required parameters: {list(function_args.keys())} for "
                          f"{processor.__class__} are passed to the logits processor."
                      )
                  scores = processor(input_ids, scores, cur_len, **kwargs)
              else:
                  scores = processor(input_ids, scores, cur_len)
          return scores
  class TFTemperatureLogitsWarper(TFLogitsWarper):
      r"""
      [`TFLogitsWarper`] for temperature (exponential scaling output probability distribution).

      Args:
          temperature (`float`):
              Strictly positive value the logits are divided by, used to modulate the logits distribution.
      """

      def __init__(self, temperature: float):
          is_valid = isinstance(temperature, float) and temperature > 0
          if not is_valid:
              raise ValueError(f"`temperature` has to be a strictly positive float, but is {temperature}")
          self.temperature = temperature

      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          # `input_ids` and `cur_len` are part of the processor interface but unused here.
          return scores / self.temperature
  class TFTopKLogitsWarper(TFLogitsWarper):
      r"""
      [`TFLogitsWarper`] that performs top-k, i.e. restricting to the k highest probability elements.

      Args:
          top_k (`int`):
              The number of highest probability vocabulary tokens to keep for top-k-filtering.
          filter_value (`float`, *optional*, defaults to -inf):
              All filtered values will be set to this float value.
          min_tokens_to_keep (`int`, *optional*, defaults to 1):
              Minimum number of tokens that cannot be filtered.
      """

      def __init__(self, top_k: int, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
          if not (isinstance(top_k, int) and top_k > 0):
              raise ValueError(f"`top_k` has to be a strictly positive integer, but is {top_k}")
          # `min_tokens_to_keep` acts as a lower bound on the number of kept tokens.
          self.top_k = max(top_k, min_tokens_to_keep)
          self.filter_value = filter_value

      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          # Clamp k so `tf.math.top_k` never asks for more values than the last axis holds.
          k = min(self.top_k, scores.shape[-1])
          top_values, _ = tf.math.top_k(scores, k=k)
          # Smallest surviving score per row, kept as a size-1 trailing axis so it broadcasts against `scores`.
          threshold = top_values[..., -1:]
          return tf.where(scores < threshold, self.filter_value, scores)
  class TFTopPLogitsWarper(TFLogitsWarper):
      """
      [`TFLogitsWarper`] that performs top-p, i.e. restricting to the smallest set of top tokens whose cumulative
      probability reaches `top_p`.

      Args:
          top_p (`float`):
              If set to < 1, only the smallest set of most probable tokens with probabilities that add up to `top_p` or
              higher are kept for generation.
          filter_value (`float`, *optional*, defaults to -inf):
              All filtered values will be set to this float value.
          min_tokens_to_keep (`int`, *optional*, defaults to 1):
              Minimum number of tokens that cannot be filtered.
      """

      def __init__(self, top_p: float, filter_value: float = -float("Inf"), min_tokens_to_keep: int = 1):
          if not isinstance(top_p, float) or (top_p < 0 or top_p > 1.0):
              raise ValueError(f"`top_p` has to be a float >= 0 and <= 1, but is {top_p}")
          if not isinstance(min_tokens_to_keep, int) or (min_tokens_to_keep < 1):
              raise ValueError(f"`min_tokens_to_keep` has to be a positive integer, but is {min_tokens_to_keep}")
          self.top_p = top_p
          self.filter_value = filter_value
          self.min_tokens_to_keep = min_tokens_to_keep

      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          vocab_size = scores.shape[-1]
          # The batch size is read dynamically so the warper also works when it is unknown at trace time.
          batch_size = tf.shape(scores)[0]

          topk_scores, topk_indices = tf.math.top_k(scores, vocab_size)
          # Filler built in the dtype of `scores`, so `tf.where` below also works for float16/bfloat16 scores.
          mask_scores = tf.fill(tf.shape(scores), tf.cast(self.filter_value, scores.dtype))

          cumulative_probs = tf.math.cumsum(stable_softmax(topk_scores, axis=-1), axis=-1)
          score_mask = cumulative_probs < self.top_p
          # Also include the token that is higher than top_p (the first false = shift and insert a True on the left)
          score_mask = tf.concat((tf.ones([batch_size, 1], dtype=tf.bool), score_mask[:, :-1]), axis=-1)
          # Ensure min tokens to keep; clamped to the vocabulary size so the mask keeps exactly `vocab_size` columns.
          min_tokens_to_keep = min(self.min_tokens_to_keep, vocab_size)
          score_mask = tf.concat(
              (
                  tf.ones([batch_size, min_tokens_to_keep], dtype=tf.bool),
                  score_mask[:, min_tokens_to_keep:],
              ),
              axis=-1,
          )

          # Mask the values that do not fit the criteria
          topk_next_scores = tf.where(score_mask, topk_scores, mask_scores)

          # Undo the topk sorting: converts the 2D matrix of per-row original indices of shape (batch_size, vocab_size)
          # to a 3D tensor of shape (batch_size, vocab_size, 2) containing the original score coordinate, from which we
          # can scatter (i.e. `scatter_indices[row, col, :]` is a tensor containing `[row, topk_indices[row, col]]`)
          scatter_rows = tf.tile(tf.expand_dims(tf.range(batch_size), axis=-1), [1, vocab_size])
          scatter_indices = tf.stack((scatter_rows, topk_indices), axis=-1)
          next_scores = tf.scatter_nd(scatter_indices, topk_next_scores, shape=tf.shape(topk_next_scores))
          return next_scores
  class TFMinLengthLogitsProcessor(TFLogitsProcessor):
      r"""
      [`TFLogitsProcessor`] enforcing a min-length by setting EOS probability to 0.

      Args:
          min_length (`int`):
              The minimum length below which the score of `eos_token_id` is set to `-float("Inf")`.
          eos_token_id (`int`):
              The id of the *end-of-sequence* token.
      """

      def __init__(self, min_length: int, eos_token_id: int):
          if not (isinstance(min_length, int) and min_length >= 0):
              raise ValueError(f"`min_length` has to be a positive integer, but is {min_length}")
          if not (isinstance(eos_token_id, int) and eos_token_id >= 0):
              raise ValueError(f"`eos_token_id` has to be a positive integer, but is {eos_token_id}")
          self.min_length = min_length
          self.eos_token_id = eos_token_id

      def _apply_eos_token_mask(self, scores: tf.Tensor) -> tf.Tensor:
          """Returns `scores` with its `eos_token_id` column replaced by `-inf`."""
          is_eos_column = tf.range(scores.shape[-1]) == self.eos_token_id
          return tf.where(is_eos_column, float("-inf"), scores)

      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          # `tf.cond` picks the masking branch only while `cur_len` is still below `min_length`.
          below_min_length = tf.less(cur_len, self.min_length)
          return tf.cond(
              below_min_length,
              lambda: self._apply_eos_token_mask(scores),
              lambda: tf.identity(scores),
          )
  class TFRepetitionPenaltyLogitsProcessor(TFLogitsProcessor):
      r"""
      [`TFLogitsProcessor`] enforcing an exponential penalty on repeated sequences.

      Args:
          penalty (`float`):
              The parameter for repetition penalty. 1.0 means no penalty. See [this
              paper](https://huggingface.co/papers/1909.05858) for more details.
      """

      def __init__(self, penalty: float):
          if not isinstance(penalty, float) or not (penalty > 0):
              raise ValueError(f"`penalty` has to be a strictly positive float, but is {penalty}")
          self.penalty = penalty

      def _create_score_penalties(self, input_ids: tf.Tensor, logits: tf.Tensor) -> tf.Tensor:
          """
          Builds a tensor shaped like `logits` holding one multiplicative factor per token. For tokens that appear in
          `input_ids`, the factor is `1 / penalty` where the logit is positive, `penalty` where it is negative, and `0`
          where it is zero (so those scores stay 0). Every other token gets `1`.
          """
          # We want to populate the penalties in the positions of `input_ids`. Since XLA can't handle shapes unknown
          # before runtime, `tf.unique` can't be used. Therefore, we may have redundant updates, when a given row has
          # the same token multiple times.

          # Gathers the penalties to apply
          logit_penalties = tf.gather(logits, input_ids, axis=1, batch_dims=1)
          logit_penalties = tf.where(logit_penalties > 0, 1 / self.penalty, logit_penalties)
          logit_penalties = tf.where(logit_penalties < 0, self.penalty, logit_penalties)

          # Scatters the penalties. `tf.ones_like` keeps the dtype of `logits`, so the scattered updates (gathered from
          # `logits`) and the caller's multiplication agree on dtype also for float16/bfloat16 logits. Both sizes are
          # read dynamically, so an unknown batch dimension is handled as well.
          token_penalties = tf.ones_like(logits)
          batch_size = tf.shape(input_ids)[0]
          seq_len = tf.shape(input_ids)[1]  # the sequence length has dynamic size, hence the dynamic shape
          # One `[row, token_id]` coordinate per entry of `input_ids`, in row-major order to match the flattened updates.
          indexable_prev_input_ids = tf.concat(
              (
                  tf.expand_dims(tf.repeat(tf.range(batch_size), seq_len), axis=-1),
                  tf.expand_dims(tf.reshape(input_ids, [-1]), axis=-1),
              ),
              axis=1,
          )
          token_penalties = tf.tensor_scatter_nd_update(
              token_penalties, indices=indexable_prev_input_ids, updates=tf.reshape(logit_penalties, [-1])
          )
          return token_penalties

      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          # Only the first `cur_len` positions of `input_ids` hold valid tokens.
          score_penalties = self._create_score_penalties(input_ids[:, :cur_len], scores)
          scores = tf.math.multiply(scores, score_penalties)
          return scores
  class TFNoBadWordsLogitsProcessor(TFLogitsProcessor):
      """
      [`TFLogitsProcessor`] that enforces that specified sequences will never be sampled.

      Args:
          bad_words_ids (`list[list[int]]`):
              List of list of token ids that are not allowed to be generated. In order to get the tokens of the words
              that should not appear in the generated text, make sure to set `add_prefix_space=True` when initializing
              the tokenizer, and use `tokenizer(bad_words, add_special_tokens=False).input_ids`. The `add_prefix_space`
              argument is only supported for some slow tokenizers, as fast tokenizers' prefixing behaviours come from
              `pre tokenizers`. Read more [here](https://huggingface.co/docs/tokenizers/api/pre-tokenizers).
          eos_token_id (`int`):
              The id of the *end-of-sequence* token. It is not used by this implementation.
      """

      def __init__(self, bad_words_ids: list[list[int]], eos_token_id: int):
          if not isinstance(bad_words_ids, list) or len(bad_words_ids) == 0:
              raise ValueError(f"`bad_words_ids` has to be a non-empty list, but is {bad_words_ids}.")
          if any(not isinstance(bad_word_ids, list) for bad_word_ids in bad_words_ids):
              raise ValueError(f"`bad_words_ids` has to be a list of lists, but is {bad_words_ids}.")
          if any(
              any((not isinstance(token_id, (int, np.integer)) or token_id < 0) for token_id in bad_word_ids)
              for bad_word_ids in bad_words_ids
          ):
              raise ValueError(
                  f"Each list in `bad_words_ids` has to be a list of positive integers, but is {bad_words_ids}."
              )

          # stores the information about bad words in three tensors:
          # 1. a rectangular tensor with the forbidden sequences (padded with `-1`), for full data comparisons
          self.bad_word_seqs_ids = tf.ragged.constant(bad_words_ids).to_tensor(default_value=-1)
          # 2. a tensor with the unpadded length of each forbidden sequence, for quick length comparisons
          bad_word_seqs_len = [len(bad_words) for bad_words in bad_words_ids]
          if any(word_len == 0 for word_len in bad_word_seqs_len):
              raise ValueError(f"Banned words token sequences {bad_words_ids} cannot have an empty list")
          self.bad_word_seqs_len = tf.convert_to_tensor(bad_word_seqs_len, dtype=tf.int32)
          # 3. a tensor containing the last token for each sequence, for easy access to the tokens that may be banned
          self.seq_forbidden_tokens = tf.convert_to_tensor([bad_words[-1] for bad_words in bad_words_ids])

      def _calc_row_banned_bad_tokens(self, row_input_ids: tf.Tensor) -> tf.Tensor:
          """
          Returns the last token of every bad word sequence whose preceding tokens equal the trailing tokens of
          `row_input_ids`. Single-token sequences always match. `row_input_ids` is expected to be the 1D tensor of
          valid tokens of one row.
          """

          def _tokens_match(bad_word_seq_number):
              def _len_one():
                  # If the bad sequence only has one token, always mask it
                  return tf.cond(
                      tf.math.equal(self.bad_word_seqs_len[bad_word_seq_number], 1),
                      lambda: tf.ones((), dtype=tf.bool),
                      _len_greater_than_cur_len,
                  )

              def _len_greater_than_cur_len():
                  # A sequence of length L bans its last token when its first L - 1 tokens are the last L - 1 tokens
                  # of the row, which is impossible only when L - 1 exceeds the number of tokens in the row.
                  return tf.cond(
                      tf.math.greater(self.bad_word_seqs_len[bad_word_seq_number] - 1, tf.shape(row_input_ids)[0]),
                      lambda: tf.zeros((), dtype=tf.bool),
                      _match_found,
                  )

              def _match_found():
                  # Finally, runs the actual comparison. Can only be called if the previous comparisons do not yield
                  # an answer (otherwise we get indexing exceptions)
                  compare_len = self.bad_word_seqs_len[bad_word_seq_number] - 1
                  return tf.cond(
                      tf.math.reduce_all(
                          tf.math.equal(
                              row_input_ids[-compare_len:], self.bad_word_seqs_ids[bad_word_seq_number, :compare_len]
                          )
                      ),
                      lambda: tf.ones((), dtype=tf.bool),
                      lambda: tf.zeros((), dtype=tf.bool),
                  )

              return _len_one()

          # Compares the current row against all bad word sequences, obtaining a mask with the matches.
          match_mask = tf.map_fn(_tokens_match, tf.range(self.bad_word_seqs_ids.shape[0]), fn_output_signature=tf.bool)
          row_banned_tokens = self.seq_forbidden_tokens[match_mask]
          return row_banned_tokens

      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          # We want to mask some banned tokens, at a score level. Since the banned tokens depend on the previous
          # `input_ids`, they may have a different length for each row, and they may even be empty for some rows.
          # To remain simple and XLA-compatible, we work on a per-row fashion.
          # TODO (Joao): this function might trigger XLA retracing as `cur_len` increases. Fix it if it becomes
          # a frequent choke point. (make `cur_len` a tensor?)
          def _get_row_updated_score(row_inputs: tuple[tf.Tensor]) -> tf.Tensor:
              row_input_ids, row_score = row_inputs
              banned_tokens = self._calc_row_banned_bad_tokens(row_input_ids[:cur_len])
              banned_tokens_mask = tf.scatter_nd(
                  indices=tf.expand_dims(banned_tokens, axis=-1),
                  updates=tf.ones_like(banned_tokens, dtype=tf.bool),
                  shape=row_score.shape,
              )
              row_score = tf.where(banned_tokens_mask, -float("inf"), row_score)
              return row_score

          # The output signature follows `scores`, so non-float32 scores (e.g. float16) are supported.
          scores = tf.map_fn(_get_row_updated_score, (input_ids, scores), fn_output_signature=scores.dtype)
          return scores
  class TFNoRepeatNGramLogitsProcessor(TFLogitsProcessor):
      r"""
      [`TFLogitsProcessor`] that enforces no repetition of n-grams. See
      [Fairseq](https://github.com/pytorch/fairseq/blob/a07cb6f40480928c9e0548b737aadd36ee66ac76/fairseq/sequence_generator.py#L345).

      Args:
          ngram_size (`int`):
              All ngrams of size `ngram_size` can only occur once.
      """

      def __init__(self, ngram_size: int):
          if not isinstance(ngram_size, int) or ngram_size <= 0:
              raise ValueError(f"`ngram_size` has to be a strictly positive integer, but is {ngram_size}")
          self.ngram_size = ngram_size

      def calc_banned_ngram_tokens(self, input_ids, num_hypos, cur_len):
          """
          For each of the first `num_hypos` rows of `input_ids`, lists the tokens that would complete an n-gram of size
          `ngram_size` already present in that row's first `cur_len` tokens. Requires eager tensors (uses `.numpy()`).
          """
          # Copied from fairseq for no_repeat_ngram in beam_search
          if cur_len + 1 < self.ngram_size:
              # return no banned tokens if we haven't generated ngram_size tokens yet
              return [[] for _ in range(num_hypos)]
          # Converts the valid prefix to Python lists once, instead of calling `.numpy()` per row and per lookup.
          prev_input_ids = input_ids[:, :cur_len].numpy().tolist()
          # The last `ngram_size - 1` tokens form the prefix of the n-gram that the next token would complete.
          start_idx = cur_len + 1 - self.ngram_size
          banned_tokens = []
          for hypo_idx in range(num_hypos):
              gen_tokens = prev_input_ids[hypo_idx]
              # Maps each (ngram_size - 1)-token prefix to the tokens that followed it in this row.
              generated_ngram = {}
              for ngram in zip(*[gen_tokens[i:] for i in range(self.ngram_size)]):
                  generated_ngram.setdefault(tuple(ngram[:-1]), []).append(ngram[-1])
              # Before decoding the next token, prevent decoding of ngrams that have already appeared
              banned_tokens.append(generated_ngram.get(tuple(gen_tokens[start_idx:cur_len]), []))
          return banned_tokens

      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          # TODO (joao): enable XLA on this logits processor. See discussion and attempts in
          # https://github.com/huggingface/transformers/pull/16974
          if not tf.executing_eagerly():
              raise NotImplementedError("TFNoRepeatNGramLogitsProcessor is only implemented for eager execution.")

          batch_size, vocab_size = scores.shape
          banned_tokens = self.calc_banned_ngram_tokens(input_ids, batch_size, cur_len)

          # create banned_tokens boolean mask with NumPy fancy indexing, instead of testing every vocabulary id for
          # membership in each row's banned list in Python.
          banned_tokens_indices_mask = np.zeros((batch_size, vocab_size), dtype=bool)
          for row, banned_tokens_slice in enumerate(banned_tokens):
              # Only ids that name a vocabulary column can be masked.
              in_vocab_tokens = [token for token in banned_tokens_slice if 0 <= token < vocab_size]
              banned_tokens_indices_mask[row, in_vocab_tokens] = True

          scores = tf.where(tf.convert_to_tensor(banned_tokens_indices_mask, dtype=tf.bool), -float("inf"), scores)
          return scores
  class TFForcedBOSTokenLogitsProcessor(TFLogitsProcessor):
      r"""
      [`TFLogitsProcessor`] that enforces the specified token as the first generated token.

      Args:
          bos_token_id (`int`):
              The id of the token to force as the first generated token.
      """

      def __init__(self, bos_token_id: int):
          if bos_token_id < 0:
              raise ValueError(f"The forced bos token id must be a non-negative integer, got {bos_token_id}")
          self.bos_token_id = bos_token_id

      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          if cur_len == 1:
              num_tokens = scores.shape[-1]
              # Score 0 in the `bos_token_id` column and -inf everywhere else. Both values are created in the dtype of
              # `scores`, so the output keeps the input dtype (e.g. float16), and the row always has exactly
              # `num_tokens` columns.
              forced_row = tf.where(
                  tf.range(num_tokens) == self.bos_token_id,
                  tf.zeros((), dtype=scores.dtype),
                  tf.constant(-float("inf"), dtype=scores.dtype),
              )
              scores = tf.broadcast_to(forced_row, tf.shape(scores))
          return scores
  class TFForcedEOSTokenLogitsProcessor(TFLogitsProcessor):
      r"""
      [`TFLogitsProcessor`] that enforces the specified token as the last generated token when `max_length` is reached.

      Args:
          max_length (`int`):
              The maximum length of the sequence to be generated.
          eos_token_id (`int`):
              The id of the token to force as the last generated token when `max_length` is reached.
      """

      def __init__(self, max_length: int, eos_token_id: int):
          self.max_length = max_length
          if eos_token_id < 0:
              raise ValueError(f"The forced eos token id must be a non-negative integer, got {eos_token_id}")
          self.eos_token_id = eos_token_id

      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          # Only the step with `cur_len == max_length - 1` is forced.
          if cur_len == self.max_length - 1:
              num_tokens = scores.shape[-1]
              # Score 0 in the `eos_token_id` column and -inf everywhere else. Both values are created in the dtype of
              # `scores`, so the output keeps the input dtype (e.g. float16), and the row always has exactly
              # `num_tokens` columns.
              forced_row = tf.where(
                  tf.range(num_tokens) == self.eos_token_id,
                  tf.zeros((), dtype=scores.dtype),
                  tf.constant(-float("inf"), dtype=scores.dtype),
              )
              scores = tf.broadcast_to(forced_row, tf.shape(scores))
          return scores
  class TFSuppressTokensAtBeginLogitsProcessor(TFLogitsProcessor):
      r"""
      [`TFSuppressTokensAtBeginLogitsProcessor`] suppresses a list of tokens as soon as the `generate` function starts
      generating using `begin_index` tokens. This should ensure that the tokens defined by `begin_suppress_tokens` are
      not sampled at the beginning of the generation.
      """

      def __init__(self, begin_suppress_tokens, begin_index):
          self.begin_suppress_tokens = list(begin_suppress_tokens)
          self.begin_index = begin_index

      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          batch_size, vocab_size = scores.shape[0], scores.shape[-1]
          suppressed_indices = []
          for token in self.begin_suppress_tokens:
              # Skip ids that cannot index a vocabulary column (e.g. beyond the vocab size).
              if 0 <= token < vocab_size:
                  suppressed_indices.extend([[i, token] for i in range(batch_size)])

          if len(suppressed_indices) > 0:
              # Exactly one `-inf` update per collected index: skipped tokens have no index, so the number of updates
              # must follow `suppressed_indices` rather than `begin_suppress_tokens`.
              updates = [-float("inf")] * len(suppressed_indices)
              scores = tf.cond(
                  tf.equal(cur_len, self.begin_index),
                  lambda: tf.tensor_scatter_nd_update(scores, indices=suppressed_indices, updates=updates),
                  lambda: scores,
              )
          return scores
  class TFSuppressTokensLogitsProcessor(TFLogitsProcessor):
      r"""This processor can be used to suppress a list of tokens. The processor will set their log probs to `-inf` so that they
      are not sampled."""

      def __init__(self, suppress_tokens):
          self.suppress_tokens = list(suppress_tokens)

      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          batch_size, vocab_size = scores.shape[0], scores.shape[-1]
          suppressed_indices = []
          for token in self.suppress_tokens:
              # Skip ids that cannot index a vocabulary column (e.g. beyond the vocab size).
              if 0 <= token < vocab_size:
                  suppressed_indices.extend([[i, token] for i in range(batch_size)])

          if len(suppressed_indices) > 0:
              # Scatter only the filtered indices, with exactly one `-inf` update per index.
              scores = tf.tensor_scatter_nd_update(
                  scores,
                  indices=suppressed_indices,
                  updates=[-float("inf")] * len(suppressed_indices),
              )
          return scores
  class TFForceTokensLogitsProcessor(TFLogitsProcessor):
      r"""This processor takes a list of pairs of integers which indicates a mapping from generation indices to token
      indices that will be forced before sampling. The processor will set their log probs to `0` and all other tokens to
      the lowest finite value of the scores' dtype, so that they are sampled at their corresponding index."""

      def __init__(self, force_token_map: list[list[int]]):
          force_token_map = dict(force_token_map)
          # Converts the dictionary of format {index: token} containing the tokens to be forced to an array, where the
          # index of the array corresponds to the index of the token to be forced, for XLA compatibility.
          # Indexes without forced tokens will have a negative value. An empty map yields a single "no forced token"
          # entry, so the processor is a no-op and the array it indexes is never empty.
          array_len = max(force_token_map.keys(), default=0) + 1
          force_token_array = np.full((array_len,), -1, dtype=np.int32)
          for index, token in force_token_map.items():
              if token is not None:
                  force_token_array[index] = token
          self.force_token_array = tf.convert_to_tensor(force_token_array, dtype=tf.int32)

      def __call__(self, input_ids: tf.Tensor, scores: tf.Tensor, cur_len: int) -> tf.Tensor:
          def _force_token(generation_idx):
              batch_size = tf.shape(scores)[0]
              current_token = self.force_token_array[generation_idx]

              # Every entry starts at the lowest finite value of the scores' dtype, created directly in that dtype so
              # non-float32 scores work too. The forced token's column is then reset to 0 in every row.
              new_scores = tf.fill(tf.shape(scores), tf.constant(scores.dtype.min, dtype=scores.dtype))
              indices = tf.stack((tf.range(batch_size), tf.tile([current_token], [batch_size])), axis=1)
              updates = tf.zeros((batch_size,), dtype=scores.dtype)
              new_scores = tf.tensor_scatter_nd_update(new_scores, indices, updates)
              return new_scores

          scores = tf.cond(
              tf.greater_equal(cur_len, tf.shape(self.force_token_array)[0]),
              # If the current length is geq than the length of force_token_array, the processor does nothing.
              lambda: tf.identity(scores),
              # Otherwise, it may force a certain token.
              lambda: tf.cond(
                  tf.greater_equal(self.force_token_array[cur_len], 0),
                  # Only valid (non-negative) tokens are forced
                  lambda: _force_token(cur_len),
                  # Otherwise, the processor does nothing.
                  lambda: scores,
              ),
          )
          return scores