beam_search.py 47 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002
  1. # coding=utf-8
  2. # Copyright 2020 The HuggingFace Inc. team
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from abc import ABC, abstractmethod
  16. from collections import UserDict
  17. from typing import Optional, Union
  18. import numpy as np
  19. import torch
  20. from ..utils import add_start_docstrings, logging
  21. from .beam_constraints import Constraint, ConstraintListState
# Module-level logger; the scorer classes in this module use it for their one-time deprecation warnings.
logger = logging.get_logger(__name__)
  23. PROCESS_INPUTS_DOCSTRING = r"""
  24. Args:
  25. input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
  26. Indices of input sequence tokens in the vocabulary.
  27. Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
  28. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  29. [What are input IDs?](../glossary#input-ids)
  30. next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
  31. Current scores of the top `2 * num_beams` non-finished beam hypotheses.
  32. next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
  33. `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
  34. next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
  35. Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
  36. pad_token_id (`int`, *optional*):
  37. The id of the *padding* token.
  38. eos_token_id (`Union[int, list[int]]`, *optional*):
  39. The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
  40. beam_indices (`torch.LongTensor`, *optional*):
  41. Beam indices indicating to which beam hypothesis each token correspond.
  42. Return:
  43. `UserDict`: A dictionary composed of the fields as defined above:
  44. - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
  45. non-finished beams.
  46. - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
  47. to the non-finished beam_hypotheses.
  48. - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
  49. indicating to which beam the next tokens shall be added.
  50. """
  51. FINALIZE_INPUTS_DOCSTRING = r"""
  52. Args:
  53. input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
  54. Indices of input sequence tokens in the vocabulary.
  55. Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
  56. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  57. [What are input IDs?](../glossary#input-ids)
  58. final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
  59. The final scores of all non-finished beams.
  60. final_beam_tokens (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
  61. The last tokens to be added to the non-finished beam_hypotheses.
  62. final_beam_indices (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
  63. The beam indices indicating to which beam the `final_beam_tokens` shall be added.
  64. pad_token_id (`int`, *optional*):
  65. The id of the *padding* token.
  66. eos_token_id (`Union[int, list[int]]`, *optional*):
  67. The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
  68. Return:
  69. `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
  70. The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
  71. due to the `eos_token_id`.
  72. """
  73. class BeamScorer(ABC):
  74. """
  75. Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
  76. [`~PreTrainedModel.beam_sample`].
  77. """
  78. @abstractmethod
  79. @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
  80. def process(
  81. self,
  82. input_ids: torch.LongTensor,
  83. next_scores: torch.FloatTensor,
  84. next_tokens: torch.LongTensor,
  85. next_indices: torch.LongTensor,
  86. **kwargs,
  87. ) -> tuple[torch.Tensor]:
  88. raise NotImplementedError("This is an abstract method.")
  89. @abstractmethod
  90. @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
  91. def finalize(
  92. self,
  93. input_ids: torch.LongTensor,
  94. next_scores: torch.FloatTensor,
  95. next_tokens: torch.LongTensor,
  96. next_indices: torch.LongTensor,
  97. max_length: int,
  98. **kwargs,
  99. ) -> torch.LongTensor:
  100. raise NotImplementedError("This is an abstract method.")
  101. class BeamSearchScorer(BeamScorer):
  102. r"""
  103. [`BeamScorer`] implementing standard beam search decoding.
  104. Adapted in part from [Facebook's XLM beam search
  105. code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).
  106. Reference for the diverse beam search algorithm and implementation [Ashwin Kalyan's DBS
  107. implementation](https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua)
  108. Args:
  109. batch_size (`int`):
  110. Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
  111. num_beams (`int`):
  112. Number of beams for beam search.
  113. device (`torch.device`):
  114. Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
  115. allocated.
  116. length_penalty (`float`, *optional*, defaults to 1.0):
  117. Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
  118. the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
  119. likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
  120. `length_penalty` < 0.0 encourages shorter sequences.
  121. do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
  122. Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
  123. `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
  124. heuristic is applied and the generation stops when is it very unlikely to find better candidates;
  125. `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
  126. beam search algorithm).
  127. num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
  128. The number of beam hypotheses that shall be returned upon calling
  129. [`~transformers.BeamSearchScorer.finalize`].
  130. num_beam_groups (`int`, *optional*, defaults to 1):
  131. Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
  132. See [this paper](https://huggingface.co/papers/1610.02424) for more details.
  133. max_length (`int`, *optional*):
  134. The maximum length of the sequence to be generated.
  135. """
  136. def __init__(
  137. self,
  138. batch_size: int,
  139. num_beams: int,
  140. device: torch.device,
  141. length_penalty: float = 1.0,
  142. do_early_stopping: Union[bool, str] = False,
  143. num_beam_hyps_to_keep: int = 1,
  144. num_beam_groups: int = 1,
  145. max_length: Optional[int] = None,
  146. ):
  147. logger.warning_once(
  148. "`BeamSearchScorer` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
  149. )
  150. self.num_beams = num_beams
  151. self.device = device
  152. self.length_penalty = length_penalty
  153. self.do_early_stopping = do_early_stopping
  154. self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
  155. self.num_beam_groups = num_beam_groups
  156. self.group_size = self.num_beams // self.num_beam_groups
  157. self._is_init = False
  158. # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
  159. # If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
  160. self._beam_hyps = [
  161. BeamHypotheses(
  162. num_beams=self.group_size,
  163. length_penalty=self.length_penalty,
  164. early_stopping=self.do_early_stopping,
  165. max_length=max_length,
  166. )
  167. for _ in range(batch_size * self.num_beam_groups)
  168. ]
  169. # self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
  170. # in the i-th mini-batch is complete.
  171. self._done = torch.tensor(
  172. [False for _ in range(batch_size * self.num_beam_groups)], dtype=torch.bool, device=self.device
  173. )
  174. if not isinstance(num_beams, int) or num_beams <= 1:
  175. raise ValueError(
  176. f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
  177. " one should make use of `greedy_search` instead."
  178. )
  179. if not isinstance(num_beam_groups, int) or (num_beam_groups > num_beams) or (num_beams % num_beam_groups != 0):
  180. raise ValueError(
  181. "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
  182. f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
  183. )
  184. @property
  185. def is_done(self) -> bool:
  186. return self._done.all().item()
  187. def process(
  188. self,
  189. input_ids: torch.LongTensor,
  190. next_scores: torch.FloatTensor,
  191. next_tokens: torch.LongTensor,
  192. next_indices: torch.LongTensor,
  193. pad_token_id: Optional[Union[int, torch.Tensor]] = None,
  194. eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
  195. beam_indices: Optional[torch.LongTensor] = None,
  196. group_index: int = 0,
  197. decoder_prompt_len: int = 0,
  198. ) -> dict[str, torch.Tensor]:
  199. # add up to the length which the next_scores is calculated on (including decoder prompt)
  200. cur_len = input_ids.shape[-1] + 1
  201. batch_size = len(self._beam_hyps) // self.num_beam_groups
  202. if batch_size != (input_ids.shape[0] // self.group_size):
  203. if self.num_beam_groups > 1:
  204. raise ValueError(
  205. f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
  206. f"size of {self.group_size} is expected by the beam scorer."
  207. )
  208. else:
  209. raise ValueError(
  210. f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
  211. f"{self.group_size} is expected by the beam scorer."
  212. )
  213. device = input_ids.device
  214. next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
  215. next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
  216. next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)
  217. if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
  218. if isinstance(eos_token_id, int):
  219. eos_token_id = [eos_token_id]
  220. eos_token_id = torch.tensor(eos_token_id)
  221. for batch_idx in range(batch_size):
  222. batch_group_idx = batch_idx * self.num_beam_groups + group_index
  223. if self._done[batch_group_idx]:
  224. if self.num_beams < len(self._beam_hyps[batch_group_idx]):
  225. raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
  226. if eos_token_id is None or pad_token_id is None:
  227. raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
  228. # pad the batch
  229. next_beam_scores[batch_idx, :] = 0
  230. next_beam_tokens[batch_idx, :] = pad_token_id
  231. next_beam_indices[batch_idx, :] = 0
  232. continue
  233. # next tokens for this sentence
  234. beam_idx = 0
  235. for beam_token_rank, (next_token, next_score, next_index) in enumerate(
  236. zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
  237. ):
  238. batch_beam_idx = batch_idx * self.group_size + next_index
  239. # add to generated hypotheses if end of sentence
  240. if (eos_token_id is not None) and (next_token.item() in eos_token_id):
  241. # if beam_token does not belong to top num_beams tokens, it should not be added
  242. is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.group_size
  243. if is_beam_token_worse_than_top_num_beams:
  244. continue
  245. if beam_indices is not None:
  246. beam_index = beam_indices[batch_beam_idx]
  247. beam_index = beam_index + (batch_beam_idx,)
  248. else:
  249. beam_index = None
  250. self._beam_hyps[batch_group_idx].add(
  251. input_ids[batch_beam_idx].clone(),
  252. next_score.item(),
  253. beam_indices=beam_index,
  254. generated_len=cur_len - decoder_prompt_len,
  255. )
  256. else:
  257. # add next predicted token since it is not eos_token
  258. next_beam_scores[batch_idx, beam_idx] = next_score
  259. next_beam_tokens[batch_idx, beam_idx] = next_token
  260. next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
  261. beam_idx += 1
  262. # once the beam for next step is full, don't add more tokens to it.
  263. if beam_idx == self.group_size:
  264. break
  265. if beam_idx < self.group_size:
  266. raise ValueError(
  267. f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
  268. f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
  269. )
  270. # Check if we are done so that we can save a pad step if all(done)
  271. self._done[batch_group_idx] = self._done[batch_group_idx] or self._beam_hyps[batch_group_idx].is_done(
  272. next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
  273. )
  274. return UserDict(
  275. {
  276. "next_beam_scores": next_beam_scores.view(-1),
  277. "next_beam_tokens": next_beam_tokens.view(-1),
  278. "next_beam_indices": next_beam_indices.view(-1),
  279. }
  280. )
  281. def finalize(
  282. self,
  283. input_ids: torch.LongTensor,
  284. final_beam_scores: torch.FloatTensor,
  285. final_beam_tokens: torch.LongTensor,
  286. final_beam_indices: torch.LongTensor,
  287. max_length: int,
  288. pad_token_id: Optional[Union[int, torch.Tensor]] = None,
  289. eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
  290. beam_indices: Optional[torch.LongTensor] = None,
  291. decoder_prompt_len: int = 0,
  292. ) -> tuple[torch.LongTensor]:
  293. batch_size = len(self._beam_hyps) // self.num_beam_groups
  294. if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
  295. if isinstance(eos_token_id, int):
  296. eos_token_id = [eos_token_id]
  297. eos_token_id = torch.tensor(eos_token_id)
  298. # finalize all open beam hypotheses and add to generated hypotheses
  299. for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
  300. if self._done[batch_group_idx]:
  301. continue
  302. # all open beam hypotheses are added to the beam hypothesis
  303. # beam hypothesis class automatically keeps the best beams
  304. for index_per_group in range(self.group_size):
  305. batch_beam_idx = batch_group_idx * self.group_size + index_per_group
  306. final_score = final_beam_scores[batch_beam_idx].item()
  307. final_tokens = input_ids[batch_beam_idx]
  308. beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
  309. generated_len = final_tokens.shape[-1] - decoder_prompt_len
  310. beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
  311. # select the best hypotheses
  312. sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
  313. best = []
  314. best_indices = []
  315. best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
  316. # retrieve best hypotheses
  317. for i in range(batch_size):
  318. beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
  319. candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
  320. sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
  321. for j in range(self.num_beam_hyps_to_keep):
  322. best_hyp_tuple = sorted_hyps.pop()
  323. best_score = best_hyp_tuple[0]
  324. best_hyp = best_hyp_tuple[1]
  325. best_index = best_hyp_tuple[2]
  326. sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
  327. # append hyp to lists
  328. best.append(best_hyp)
  329. # append indices to list
  330. best_indices.append(best_index)
  331. best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
  332. # prepare for adding eos
  333. sent_lengths_max = sent_lengths.max().item() + 1
  334. sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
  335. decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
  336. if len(best_indices) > 0 and best_indices[0] is not None:
  337. indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
  338. else:
  339. indices = None
  340. # shorter batches are padded if needed
  341. if sent_lengths.min().item() != sent_lengths.max().item():
  342. if pad_token_id is None:
  343. raise ValueError("`pad_token_id` has to be defined")
  344. decoded.fill_(pad_token_id)
  345. if indices is not None:
  346. indices.fill_(-1)
  347. # fill with hypotheses and eos_token_id if the latter fits in
  348. for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
  349. decoded[i, : sent_lengths[i]] = hypo
  350. if indices is not None:
  351. indices[i, : len(best_idx)] = torch.tensor(best_idx)
  352. if sent_lengths[i] < sent_max_len:
  353. # inserting only the first eos_token_id
  354. decoded[i, sent_lengths[i]] = eos_token_id[0]
  355. return UserDict(
  356. {
  357. "sequences": decoded,
  358. "sequence_scores": best_scores,
  359. "beam_indices": indices,
  360. }
  361. )
  362. class ConstrainedBeamSearchScorer(BeamScorer):
  363. r"""
  364. [`BeamScorer`] implementing constrained beam search decoding.
  365. Args:
  366. batch_size (`int`):
  367. Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
  368. num_beams (`int`):
  369. Number of beams for beam search.
  370. constraints (`list[Constraint]`):
  371. A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
  372. output. For more information, the documentation of [`Constraint`] should be read.
  373. device (`torch.device`):
  374. Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
  375. allocated.
  376. length_penalty (`float`, *optional*, defaults to 1.0):
  377. Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
  378. the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
  379. likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
  380. `length_penalty` < 0.0 encourages shorter sequences.
  381. do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
  382. Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
  383. `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
  384. heuristic is applied and the generation stops when is it very unlikely to find better candidates;
  385. `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
  386. beam search algorithm).
  387. num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
  388. The number of beam hypotheses that shall be returned upon calling
  389. [`~transformers.BeamSearchScorer.finalize`].
  390. max_length (`int`, *optional*):
  391. The maximum length of the sequence to be generated.
  392. """
  393. def __init__(
  394. self,
  395. batch_size: int,
  396. num_beams: int,
  397. constraints: list[Constraint],
  398. device: torch.device,
  399. length_penalty: float = 1.0,
  400. do_early_stopping: Union[bool, str] = False,
  401. num_beam_hyps_to_keep: int = 1,
  402. max_length: Optional[int] = None,
  403. ):
  404. logger.warning_once(
  405. "`ConstrainedBeamSearchScorer` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
  406. )
  407. self.num_beams = num_beams
  408. self.device = device
  409. self.length_penalty = length_penalty
  410. self.do_early_stopping = do_early_stopping
  411. self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
  412. self.constraints = constraints
  413. self._is_init = False
  414. self._beam_hyps = [
  415. BeamHypotheses(
  416. num_beams=self.num_beams,
  417. length_penalty=self.length_penalty,
  418. early_stopping=self.do_early_stopping,
  419. max_length=max_length,
  420. )
  421. for _ in range(batch_size)
  422. ]
  423. self._done = torch.tensor([False for _ in range(batch_size)], dtype=torch.bool, device=self.device)
  424. if not isinstance(num_beams, int) or num_beams <= 1:
  425. raise ValueError(
  426. f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
  427. " one should make use of `greedy_search` instead."
  428. )
  429. @property
  430. def is_done(self) -> bool:
  431. return self._done.all().item()
  432. def make_constraint_states(self, n):
  433. return [ConstraintListState([constraint.copy() for constraint in self.constraints]) for _ in range(n)]
  434. def check_completes_constraints(self, sequence):
  435. new_state = self.make_constraint_states(1)[0]
  436. new_state.reset(sequence)
  437. return new_state.completed
  438. def process(
  439. self,
  440. input_ids: torch.LongTensor,
  441. next_scores: torch.FloatTensor,
  442. next_tokens: torch.LongTensor,
  443. next_indices: torch.LongTensor,
  444. scores_for_all_vocab: torch.FloatTensor,
  445. pad_token_id: Optional[Union[int, torch.Tensor]] = None,
  446. eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
  447. beam_indices: Optional[torch.LongTensor] = None,
  448. decoder_prompt_len: int = 0,
  449. ) -> tuple[torch.Tensor]:
  450. r"""
  451. Args:
  452. input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
  453. Indices of input sequence tokens in the vocabulary.
  454. Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
  455. [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.
  456. [What are input IDs?](../glossary#input-ids)
  457. next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
  458. Current scores of the top `2 * num_beams` non-finished beam hypotheses.
  459. next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
  460. `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
  461. next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
  462. Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
  463. scores_for_all_vocab (`torch.FloatTensor` of shape `(batch_size * num_beams, sequence_length)`):
  464. The scores of all tokens in the vocabulary for each of the beam hypotheses.
  465. pad_token_id (`int`, *optional*):
  466. The id of the *padding* token.
  467. eos_token_id (`Union[int, list[int]]`, *optional*):
  468. The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
  469. beam_indices (`torch.LongTensor`, *optional*):
  470. Beam indices indicating to which beam hypothesis each token correspond.
  471. decoder_prompt_len (`int`, *optional*):
  472. The length of prompt that is included in the input to decoder.
  473. Return:
  474. `UserDict`: A dictionary composed of the fields as defined above:
  475. - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
  476. all
  477. non-finished beams.
  478. - **next_beam_tokens** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
  479. added
  480. to the non-finished beam_hypotheses.
  481. - **next_beam_indices** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Beam indices
  482. indicating to which beam the next tokens shall be added.
  483. """
  484. # add up to the length which the next_scores is calculated on (including decoder prompt)
  485. cur_len = input_ids.shape[-1] + 1
  486. batch_size = len(self._beam_hyps)
  487. device = input_ids.device
  488. next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
  489. next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
  490. next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)
  491. if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
  492. if isinstance(eos_token_id, int):
  493. eos_token_id = [eos_token_id]
  494. eos_token_id = torch.tensor(eos_token_id)
  495. for batch_idx, beam_hyp in enumerate(self._beam_hyps):
  496. if self._done[batch_idx]:
  497. if self.num_beams < len(beam_hyp):
  498. raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
  499. if eos_token_id is None or pad_token_id is None:
  500. raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
  501. # pad the batch
  502. next_beam_scores[batch_idx, :] = 0
  503. next_beam_tokens[batch_idx, :] = pad_token_id
  504. next_beam_indices[batch_idx, :] = 0
  505. continue
  506. # next tokens for this sentence.
  507. beam_idx = 0
  508. for beam_token_rank, (next_token, next_score, next_index) in enumerate(
  509. zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
  510. ):
  511. batch_beam_idx = batch_idx * self.num_beams + next_index
  512. # add to generated hypotheses if end of sentence
  513. if (eos_token_id is not None) and (next_token.item() in eos_token_id):
  514. # if beam_token does not belong to top num_beams tokens, it should not be added
  515. is_beam_token_worse_than_top_num_beams = beam_token_rank >= self.num_beams
  516. if is_beam_token_worse_than_top_num_beams:
  517. continue
  518. completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].tolist())
  519. if completes_constraint:
  520. if beam_indices is not None:
  521. beam_index = beam_indices[batch_beam_idx]
  522. beam_index = beam_index + (batch_beam_idx,)
  523. else:
  524. beam_index = None
  525. beam_hyp.add(
  526. input_ids[batch_beam_idx].clone(),
  527. next_score.item(),
  528. beam_indices=beam_index,
  529. generated_len=cur_len - decoder_prompt_len,
  530. )
  531. else:
  532. # add next predicted token since it is not eos_token
  533. next_beam_scores[batch_idx, beam_idx] = next_score
  534. next_beam_tokens[batch_idx, beam_idx] = next_token
  535. next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
  536. beam_idx += 1
  537. # once the beam for next step is full, don't add more tokens to it.
  538. if beam_idx == self.num_beams:
  539. break
  540. new_scores, new_tokens, new_indices = self.step_sentence_constraint(
  541. batch_idx,
  542. input_ids,
  543. scores_for_all_vocab,
  544. next_beam_scores[batch_idx],
  545. next_beam_tokens[batch_idx],
  546. next_beam_indices[batch_idx],
  547. )
  548. next_beam_scores[batch_idx] = new_scores
  549. next_beam_tokens[batch_idx] = new_tokens
  550. next_beam_indices[batch_idx] = new_indices
  551. if beam_idx < self.num_beams:
  552. raise ValueError(
  553. f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
  554. f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
  555. )
  556. # Check if we are done so that we can save a pad step if all(done)
  557. self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
  558. next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
  559. )
  560. return UserDict(
  561. {
  562. "next_beam_scores": next_beam_scores.view(-1),
  563. "next_beam_tokens": next_beam_tokens.view(-1),
  564. "next_beam_indices": next_beam_indices.view(-1),
  565. }
  566. )
  567. def step_sentence_constraint(
  568. self,
  569. batch_idx: int,
  570. input_ids: torch.LongTensor,
  571. vocab_scores: torch.FloatTensor,
  572. sent_beam_scores: torch.FloatTensor,
  573. sent_beam_tokens: torch.LongTensor,
  574. sent_beam_indices: torch.LongTensor,
  575. push_progress: bool = False,
  576. ):
  577. # sent_beam_tokens are the next {num_beams} number of tokens that are under consideration for this beam
  578. # (candidate next tokens)
  579. # 1. Adding "advance_tokens"
  580. # using ConstraintStateList.advance(), we propose new tokens to be added into this "candidate list" that will
  581. # advance us in fulfilling the constraints.
  582. # 2. Selecting best candidates such that we end up with highest probable candidates
  583. # that fulfill our constraints.
  584. orig_len = sent_beam_indices.size(0)
  585. device = sent_beam_indices.device
  586. # initialize states
  587. topk_constraint_states = self.make_constraint_states(orig_len)
  588. advance_constraint_states = self.make_constraint_states(orig_len)
  589. sidx, eidx = batch_idx * orig_len, (batch_idx + 1) * orig_len
  590. this_batch_input_ids = input_ids[sidx:eidx]
  591. this_batch_token_scores = vocab_scores[sidx:eidx]
  592. full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1)
  593. # need to make new hypothesis that advance the constraints
  594. track_new = {
  595. "new_seqs": full_hypotheses.tolist(),
  596. "new_states": [],
  597. "new_indices": [],
  598. "new_tokens": [],
  599. "new_scores": [],
  600. }
  601. for seq_idx, pre_seq in enumerate(this_batch_input_ids):
  602. # pre_seq = ith sequence generated before this step.
  603. # input_ids -> (topk) generic beam search best model next tokens
  604. # -> (advance) constraints forcing the next token
  605. # either way, we need to sort them into "banks" later, so store a "ConstraintListState" for all types of
  606. # hypotheses.
  607. topk_state = topk_constraint_states[seq_idx]
  608. topk_state.reset(full_hypotheses[seq_idx].tolist())
  609. advance_state = advance_constraint_states[seq_idx]
  610. advance_state.reset(pre_seq.tolist())
  611. if not advance_state.completed:
  612. advance_tokens = torch.tensor(advance_state.advance(), dtype=torch.long, device=device)
  613. for advance_token in advance_tokens:
  614. # since adding each `advance_token` leads to a different hypothesis, create new state instance.
  615. new_state = advance_state.copy(stateful=True)
  616. new_state.add(advance_token.tolist())
  617. advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).tolist()
  618. if advance_seq not in track_new["new_seqs"]:
  619. # prevent duplicates, which are basically bound to happen in this process.
  620. track_new["new_seqs"].append(advance_seq)
  621. track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches
  622. track_new["new_tokens"].append(advance_token)
  623. track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
  624. track_new["new_states"].append(new_state)
  625. elif push_progress:
  626. # Basically, `sent_beam_indices` often chooses very little among `input_ids` the generated sequences that
  627. # actually fulfill our constraints. For example, let constraints == ["loves pies"] and
  628. # pre_seq_1 = "The child loves pies and" pre_seq_2 = "The child plays in the playground and"
  629. # Without this step, if `sent_beam_indices` is something like [1,1], then
  630. # 1. `pre_seq_1` won't be added to the list of (topk) hypothesis since it's not in the indices and
  631. # 2. it won't be added to the list of (advance) hypothesis since it's completed already. (this is
  632. # the else part of `if constraints_completed[seq_idx]`)
  633. # 3. it ends up simply getting removed from consideration.
  634. # #3 might be fine and actually desired, since it's likely that it's a low-probability output anyways,
  635. # especially if it's not in the list of `sent_beam_indices`. But this often leads to lengthened beam
  636. # search times, since completed sequences keep getting removed after all this effort for constrained
  637. # generation.
  638. # Here, we basically take `pre_seq_1` and to "push" it into the considered list of hypotheses, by simply
  639. # appending the next likely token in the vocabulary and adding it to the list of hypotheses.
  640. new_score, new_token = torch.max(this_batch_token_scores[seq_idx], 0) # some next probable token
  641. advance_seq = torch.cat((pre_seq, new_token.unsqueeze(0)), -1)
  642. advance_state = advance_constraint_states[seq_idx]
  643. advance_seq = advance_seq.tolist()
  644. advance_state.reset(advance_seq)
  645. if advance_seq not in track_new["new_seqs"]:
  646. # but still don't want to have duplicates
  647. track_new["new_seqs"].append(advance_seq)
  648. track_new["new_indices"].append(seq_idx)
  649. track_new["new_tokens"].append(new_token)
  650. track_new["new_scores"].append(new_score)
  651. track_new["new_states"].append(advance_state)
  652. if len(track_new["new_indices"]) > 0:
  653. new_indices = torch.tensor(track_new["new_indices"], device=device)
  654. new_tokens = torch.stack(track_new["new_tokens"]).to(device)
  655. new_scores = torch.stack(track_new["new_scores"]).to(device)
  656. all_states = topk_constraint_states + track_new["new_states"]
  657. all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1)
  658. all_scores = torch.cat((sent_beam_scores, new_scores), -1)
  659. all_banks = torch.tensor([one.get_bank() for one in all_states], device=device)
  660. zipped = all_banks * 100 + all_scores
  661. indices = zipped.sort(descending=True).indices
  662. sorted_banks = all_banks[indices]
  663. # Then we end up with {sorted among bank C}, {sorted among bank C-1}, ..., {sorted among bank 0}
  664. counter = -1
  665. cur_bank = sorted_banks[0]
  666. increments = []
  667. for bank in sorted_banks:
  668. if bank == cur_bank:
  669. counter += 1
  670. else:
  671. counter = 0
  672. cur_bank = bank
  673. increments.append(counter)
  674. rearrangers = torch.tensor(np.argsort(increments, kind="mergesort"))
  675. indices = indices[rearrangers][:orig_len]
  676. sent_beam_scores = all_scores[indices]
  677. sent_beam_tokens = all_tokens[indices]
  678. sent_beam_indices = torch.cat((sent_beam_indices, new_indices))[indices]
  679. return sent_beam_scores, sent_beam_tokens, sent_beam_indices
  680. def finalize(
  681. self,
  682. input_ids: torch.LongTensor,
  683. final_beam_scores: torch.FloatTensor,
  684. final_beam_tokens: torch.LongTensor,
  685. final_beam_indices: torch.LongTensor,
  686. max_length: int,
  687. pad_token_id: Optional[Union[int, torch.Tensor]] = None,
  688. eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
  689. beam_indices: Optional[torch.LongTensor] = None,
  690. decoder_prompt_len: int = 0,
  691. ) -> tuple[torch.LongTensor]:
  692. batch_size = len(self._beam_hyps)
  693. if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
  694. if isinstance(eos_token_id, int):
  695. eos_token_id = [eos_token_id]
  696. eos_token_id = torch.tensor(eos_token_id)
  697. # finalize all open beam hypotheses and add to generated hypotheses
  698. for batch_idx, beam_hyp in enumerate(self._beam_hyps):
  699. if self._done[batch_idx]:
  700. continue
  701. # all open beam hypotheses are added to the beam hypothesis
  702. # beam hypothesis class automatically keeps the best beams
  703. ids_collect = []
  704. for beam_id in range(self.num_beams):
  705. batch_beam_idx = batch_idx * self.num_beams + beam_id
  706. final_score = final_beam_scores[batch_beam_idx].item()
  707. final_tokens = input_ids[batch_beam_idx]
  708. completes_constraint = self.check_completes_constraints(final_tokens.tolist())
  709. if completes_constraint:
  710. beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
  711. generated_len = final_tokens.shape[-1] - decoder_prompt_len
  712. beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
  713. ids_collect.append(beam_id)
  714. # due to overly complex constraints or other factors, sometimes we can't guarantee a successful
  715. # generation. In these cases we simply return the highest scoring outputs.
  716. if len(ids_collect) < self.num_beam_hyps_to_keep:
  717. for beam_id in range(self.num_beams):
  718. if beam_id not in ids_collect:
  719. batch_beam_idx = batch_idx * self.num_beams + beam_id
  720. final_score = final_beam_scores[batch_beam_idx].item()
  721. final_tokens = input_ids[batch_beam_idx]
  722. generated_len = final_tokens.shape[-1] - decoder_prompt_len
  723. beam_hyp.add(final_tokens, final_score, generated_len=generated_len)
  724. if len(ids_collect) >= self.num_beam_hyps_to_keep:
  725. break
  726. # select the best hypotheses
  727. sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
  728. best = []
  729. best_indices = []
  730. best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)
  731. # retrieve best hypotheses
  732. for i, beam_hyp in enumerate(self._beam_hyps):
  733. sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
  734. for j in range(self.num_beam_hyps_to_keep):
  735. best_hyp_tuple = sorted_hyps.pop()
  736. best_score = best_hyp_tuple[0]
  737. best_hyp = best_hyp_tuple[1]
  738. best_index = best_hyp_tuple[2]
  739. sent_lengths[self.num_beam_hyps_to_keep * i + j] = len(best_hyp)
  740. # append to lists
  741. best.append(best_hyp)
  742. # append indices to list
  743. best_indices.append(best_index)
  744. best_scores[i * self.num_beam_hyps_to_keep + j] = best_score
  745. # prepare for adding eos
  746. sent_lengths_max = sent_lengths.max().item() + 1
  747. sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
  748. decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
  749. if len(best_indices) > 0 and best_indices[0] is not None:
  750. indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
  751. else:
  752. indices = None
  753. # shorter batches are padded if needed
  754. if sent_lengths.min().item() != sent_lengths.max().item():
  755. if pad_token_id is None:
  756. raise ValueError("`pad_token_id` has to be defined")
  757. decoded.fill_(pad_token_id)
  758. if indices is not None:
  759. indices.fill_(-1)
  760. # fill with hypotheses and eos_token_id if the latter fits in
  761. for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
  762. decoded[i, : sent_lengths[i]] = hypo
  763. if indices is not None:
  764. indices[i, : len(best_idx)] = torch.tensor(best_idx)
  765. if sent_lengths[i] < sent_max_len:
  766. # inserting only the first eos_token_id
  767. decoded[i, sent_lengths[i]] = eos_token_id[0]
  768. return UserDict(
  769. {
  770. "sequences": decoded,
  771. "sequence_scores": best_scores,
  772. "beam_indices": indices,
  773. }
  774. )
  775. class BeamHypotheses:
  776. def __init__(
  777. self, num_beams: int, length_penalty: float, early_stopping: Union[bool, str], max_length: Optional[int] = None
  778. ):
  779. """
  780. Initialize n-best list of hypotheses.
  781. """
  782. logger.warning_once(
  783. "`BeamHypotheses` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
  784. )
  785. self.length_penalty = length_penalty
  786. self.early_stopping = early_stopping
  787. self.max_length = max_length
  788. self.num_beams = num_beams
  789. self.beams = []
  790. self.worst_score = 1e9
  791. if not isinstance(self.early_stopping, bool) and self.max_length is None:
  792. raise ValueError(
  793. "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
  794. " BeamScorer class instance at initialization time."
  795. )
  796. def __len__(self):
  797. """
  798. Number of hypotheses in the list.
  799. """
  800. return len(self.beams)
  801. def add(
  802. self,
  803. hyp: torch.LongTensor,
  804. sum_logprobs: float,
  805. beam_indices: Optional[torch.LongTensor] = None,
  806. generated_len: Optional[int] = None,
  807. ):
  808. """
  809. Add a new hypothesis to the list.
  810. """
  811. if generated_len is not None:
  812. score = sum_logprobs / (generated_len**self.length_penalty)
  813. # This 'else' case exists for retrocompatibility
  814. else:
  815. score = sum_logprobs / (hyp.shape[-1] ** self.length_penalty)
  816. if len(self) < self.num_beams or score > self.worst_score:
  817. self.beams.append((score, hyp, beam_indices))
  818. if len(self) > self.num_beams:
  819. sorted_next_scores = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
  820. del self.beams[sorted_next_scores[0][1]]
  821. self.worst_score = sorted_next_scores[1][0]
  822. else:
  823. self.worst_score = min(score, self.worst_score)
  824. def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: int = 0) -> bool:
  825. """
  826. If there are enough hypotheses and that none of the hypotheses being generated can become better than the worst
  827. one in the heap, then we are done with this sentence.
  828. """
  829. if len(self) < self.num_beams:
  830. return False
  831. # `True`: stop as soon as at least `num_beams` hypotheses are finished
  832. if self.early_stopping is True:
  833. return True
  834. # `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
  835. # when `length_penalty` is positive. See the discussion below for more details.
  836. # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
  837. elif self.early_stopping is False:
  838. highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
  839. ret = self.worst_score >= highest_attainable_score
  840. return ret
  841. # `"never"`: compute the best possible score, depending on the signal of `length_penalty`
  842. else:
  843. # `length_penalty` > 0.0 -> max denominator is obtaned from `max_length`, not from `cur_len` -> min
  844. # abs(`highest_attainable_score`) is obtained -> `highest_attainable_score` is negative, hence we obtain
  845. # its max this way
  846. if self.length_penalty > 0.0:
  847. if self.max_length <= decoder_prompt_len:
  848. raise ValueError("max_length is not larger than decoder prompt length")
  849. highest_attainable_score = (
  850. best_sum_logprobs / (self.max_length - decoder_prompt_len) ** self.length_penalty
  851. )
  852. # the opposite logic applies here (max `highest_attainable_score` from `cur_len`)
  853. else:
  854. highest_attainable_score = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
  855. ret = self.worst_score >= highest_attainable_score
  856. return ret