| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920921922923924925926927928929930931932933934935936937938939940941942943944945946947948949950951952953954955956957958959960961962963964965966967968969970971972973974975976977978979980981982983984985986987988989990991992993994995996997998999100010011002 |
- # coding=utf-8
- # Copyright 2020 The HuggingFace Inc. team
- #
- # Licensed under the Apache License, Version 2.0 (the "License");
- # you may not use this file except in compliance with the License.
- # You may obtain a copy of the License at
- #
- # http://www.apache.org/licenses/LICENSE-2.0
- #
- # Unless required by applicable law or agreed to in writing, software
- # distributed under the License is distributed on an "AS IS" BASIS,
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- # See the License for the specific language governing permissions and
- # limitations under the License.
- from abc import ABC, abstractmethod
- from collections import UserDict
- from typing import Optional, Union
- import numpy as np
- import torch
- from ..utils import add_start_docstrings, logging
- from .beam_constraints import Constraint, ConstraintListState
- logger = logging.get_logger(__name__)
# Shared documentation for `BeamScorer.process` implementations. The returned token and index tensors are
# integer tensors: they are allocated with the dtypes of `next_tokens` / `next_indices`.
PROCESS_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
            Current scores of the top `2 * num_beams` non-finished beam hypotheses.
        next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
            `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
        next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
            Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token.
        eos_token_id (`Union[int, list[int]]`, *optional*):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
        beam_indices (`torch.LongTensor`, *optional*):
            Beam indices indicating to which beam hypothesis each token corresponds.

    Return:
        `UserDict`: A dictionary composed of the fields as defined above:

            - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of all
              non-finished beams.
            - **next_beam_tokens** (`torch.LongTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be added
              to the non-finished beam_hypotheses.
            - **next_beam_indices** (`torch.LongTensor` of shape `(batch_size * num_beams)`) -- Beam indices
              indicating to which beam the next tokens shall be added.

"""
# Shared documentation for `BeamScorer.finalize` implementations. Token and beam-index tensors are integer
# (`LongTensor`) tensors, matching the annotations of the concrete `finalize` methods.
FINALIZE_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        final_beam_scores (`torch.FloatTensor` of shape `(batch_size * num_beams)`):
            The final scores of all non-finished beams.
        final_beam_tokens (`torch.LongTensor` of shape `(batch_size * num_beams)`):
            The last tokens to be added to the non-finished beam_hypotheses.
        final_beam_indices (`torch.LongTensor` of shape `(batch_size * num_beams)`):
            The beam indices indicating to which beam the `final_beam_tokens` shall be added.
        max_length (`int`):
            The maximum length of the sequence to be generated.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token.
        eos_token_id (`Union[int, list[int]]`, *optional*):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.

    Return:
        `torch.LongTensor` of shape `(batch_size * num_return_sequences, sequence_length)`: The generated sequences.
        The second dimension (sequence_length) is either equal to `max_length` or shorter if all batches finished early
        due to the `eos_token_id`.

"""
class BeamScorer(ABC):
    """
    Abstract base class for all beam scorers that are used for [`~PreTrainedModel.beam_search`] and
    [`~PreTrainedModel.beam_sample`].

    A concrete scorer has to provide two methods: `process`, which consumes the candidate continuations of one
    decoding step, and `finalize`, which turns the collected hypotheses into the final output.
    """

    # The method docstrings are supplied through `add_start_docstrings` with the shared module-level templates,
    # so no inline docstring is written here.
    @abstractmethod
    @add_start_docstrings(PROCESS_INPUTS_DOCSTRING)
    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        **kwargs,
    ) -> tuple[torch.Tensor]:
        # Reaching this body means a subclass delegated to the base without implementing the step itself.
        raise NotImplementedError("This is an abstract method.")

    # NOTE(review): the positional parameters are named `next_*` here, while the shared docstring template and
    # `BeamSearchScorer.finalize` name them `final_beam_*`; they are kept as-is to leave the interface untouched.
    @abstractmethod
    @add_start_docstrings(FINALIZE_INPUTS_DOCSTRING)
    def finalize(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        max_length: int,
        **kwargs,
    ) -> torch.LongTensor:
        raise NotImplementedError("This is an abstract method.")
class BeamSearchScorer(BeamScorer):
    r"""
    [`BeamScorer`] implementing standard beam search decoding.

    Adapted in part from [Facebook's XLM beam search
    code](https://github.com/facebookresearch/XLM/blob/9e6f6814d17be4fe5b15f2e6c43eb2b2d76daeb4/src/model/transformer.py#L529).

    Reference for the diverse beam search algorithm and implementation [Ashwin Kalyan's DBS
    implementation](https://github.com/ashwinkalyan/dbs/blob/master/dbs/beam_utils.lua)

    Args:
        batch_size (`int`):
            Batch Size of `input_ids` for which standard beam search decoding is run in parallel.
        num_beams (`int`):
            Number of beams for beam search.
        device (`torch.device`):
            Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `BeamSearchScorer` will be
            allocated.
        length_penalty (`float`, *optional*, defaults to 1.0):
            Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
            the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
            likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
            `length_penalty` < 0.0 encourages shorter sequences.
        do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
            Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
            `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
            heuristic is applied and the generation stops when is it very unlikely to find better candidates;
            `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
            beam search algorithm).
        num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
            The number of beam hypotheses that shall be returned upon calling
            [`~transformers.BeamSearchScorer.finalize`].
        num_beam_groups (`int`, *optional*, defaults to 1):
            Number of groups to divide `num_beams` into in order to ensure diversity among different groups of beams.
            See [this paper](https://huggingface.co/papers/1610.02424) for more details.
        max_length (`int`, *optional*):
            The maximum length of the sequence to be generated.

    Raises:
        ValueError: If `num_beams` is not an integer greater than 1, or if `num_beam_groups` is not a positive
            integer that divides `num_beams`.
    """

    def __init__(
        self,
        batch_size: int,
        num_beams: int,
        device: torch.device,
        length_penalty: float = 1.0,
        do_early_stopping: Union[bool, str] = False,
        num_beam_hyps_to_keep: int = 1,
        num_beam_groups: int = 1,
        max_length: Optional[int] = None,
    ):
        logger.warning_once(
            "`BeamSearchScorer` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
        )

        # Validate *before* deriving `group_size`: otherwise `num_beam_groups == 0` surfaces as a
        # ZeroDivisionError and a non-integer argument as a TypeError instead of the documented ValueError.
        if not isinstance(num_beams, int) or num_beams <= 1:
            raise ValueError(
                f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
                " one should make use of `greedy_search` instead."
            )
        if (
            not isinstance(num_beam_groups, int)
            or num_beam_groups < 1  # zero or negative group counts would yield a meaningless `group_size`
            or num_beam_groups > num_beams
            or num_beams % num_beam_groups != 0
        ):
            raise ValueError(
                "`num_beam_groups` has to be an integer smaller or equal than `num_beams` and `num_beams` has to be"
                f" divisible by `num_beam_groups`, but is {num_beam_groups} with `num_beams` being {num_beams}."
            )

        self.num_beams = num_beams
        self.device = device
        self.length_penalty = length_penalty
        self.do_early_stopping = do_early_stopping
        self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
        self.num_beam_groups = num_beam_groups
        self.group_size = self.num_beams // self.num_beam_groups

        self._is_init = False
        # self._beam_hyps[i*self.num_beam_groups+j] is the beam_hyps of the j-th group in the i-th mini-batch.
        # If group_beam_search is not used, the list consists of `batch_size` beam_hyps.
        self._beam_hyps = [
            BeamHypotheses(
                num_beams=self.group_size,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
                max_length=max_length,
            )
            for _ in range(batch_size * self.num_beam_groups)
        ]
        # self._done[i*self.num_beam_groups+j] indicates whether the generation of the beam_hyps of the j-th group
        # in the i-th mini-batch is complete.
        self._done = torch.tensor(
            [False] * (batch_size * self.num_beam_groups), dtype=torch.bool, device=self.device
        )

    @staticmethod
    def _eos_as_tensor(eos_token_id):
        """Return `eos_token_id` as a tensor; `None` and tensors pass through, ints and lists are converted."""
        if eos_token_id is None or isinstance(eos_token_id, torch.Tensor):
            return eos_token_id
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        return torch.tensor(eos_token_id)

    @property
    def is_done(self) -> bool:
        """`True` once every (batch item, beam group) pair has been flagged as finished."""
        return self._done.all().item()

    def process(
        self,
        input_ids: torch.LongTensor,
        next_scores: torch.FloatTensor,
        next_tokens: torch.LongTensor,
        next_indices: torch.LongTensor,
        pad_token_id: Optional[Union[int, torch.Tensor]] = None,
        eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
        group_index: int = 0,
        decoder_prompt_len: int = 0,
    ) -> UserDict:
        """
        Consume the ranked candidates of one decoding step for beam group `group_index`.

        Candidates ending in an EOS token are stored as finished hypotheses (only if they rank within the first
        `group_size` candidates); the remaining ones fill the `group_size` beams of the next step.

        Args:
            input_ids: Sequences generated so far; row `batch_idx * group_size + beam` belongs to batch item
                `batch_idx`.
            next_scores / next_tokens / next_indices: Per batch item, the ranked candidate scores, token ids and
                the in-group beam index each candidate extends.
            pad_token_id: Token written into the next-step tokens of batch items that are already done.
            eos_token_id: One or several end-of-sequence ids (int, list of ints or tensor).
            beam_indices: Optional per-row beam history; indexed per row and extended by concatenation with a
                one-element tuple.
            group_index: Index of the beam group being processed.
            decoder_prompt_len: Prompt length that is subtracted from the current length to obtain the
                generated length passed to the hypothesis container.

        Returns:
            `UserDict` with the flattened `next_beam_scores`, `next_beam_tokens` and `next_beam_indices`.

        Raises:
            ValueError: If the number of rows of `input_ids` does not match the scorer, if a finished batch item
                has to be padded but `eos_token_id` or `pad_token_id` is missing, or if fewer than `group_size`
                non-EOS candidates are available.
        """
        # add up to the length which the next_scores is calculated on (including decoder prompt)
        cur_len = input_ids.shape[-1] + 1
        batch_size = len(self._beam_hyps) // self.num_beam_groups

        if batch_size != (input_ids.shape[0] // self.group_size):
            if self.num_beam_groups > 1:
                raise ValueError(
                    f"A group beam size of {input_ids.shape[0]} is used as the input, but a group beam "
                    f"size of {self.group_size} is expected by the beam scorer."
                )
            else:
                raise ValueError(
                    f"A beam size of {input_ids.shape[0]} is used as the input, but a beam size of "
                    f"{self.group_size} is expected by the beam scorer."
                )

        device = input_ids.device
        next_beam_scores = torch.zeros((batch_size, self.group_size), dtype=next_scores.dtype, device=device)
        next_beam_tokens = torch.zeros((batch_size, self.group_size), dtype=next_tokens.dtype, device=device)
        next_beam_indices = torch.zeros((batch_size, self.group_size), dtype=next_indices.dtype, device=device)

        eos_token_id = self._eos_as_tensor(eos_token_id)

        for batch_idx in range(batch_size):
            batch_group_idx = batch_idx * self.num_beam_groups + group_index
            group_hyps = self._beam_hyps[batch_group_idx]

            if self._done[batch_group_idx]:
                # NOTE(review): the comparison looks inverted with respect to the message ("at least"); it is
                # kept unchanged because the hypothesis container's guarantees are not visible here.
                if self.num_beams < len(group_hyps):
                    raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
                if eos_token_id is None or pad_token_id is None:
                    raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
                # pad the batch
                next_beam_scores[batch_idx, :] = 0
                next_beam_tokens[batch_idx, :] = pad_token_id
                next_beam_indices[batch_idx, :] = 0
                continue

            # next tokens for this sentence
            beam_idx = 0
            for beam_token_rank, (next_token, next_score, next_index) in enumerate(
                zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
            ):
                batch_beam_idx = batch_idx * self.group_size + next_index
                # add to generated hypotheses if end of sentence
                if (eos_token_id is not None) and (next_token.item() in eos_token_id):
                    # if beam_token does not belong to top num_beams tokens, it should not be added
                    if beam_token_rank >= self.group_size:
                        continue
                    if beam_indices is not None:
                        beam_index = beam_indices[batch_beam_idx]
                        beam_index = beam_index + (batch_beam_idx,)
                    else:
                        beam_index = None

                    group_hyps.add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                        beam_indices=beam_index,
                        generated_len=cur_len - decoder_prompt_len,
                    )
                else:
                    # add next predicted token since it is not eos_token
                    next_beam_scores[batch_idx, beam_idx] = next_score
                    next_beam_tokens[batch_idx, beam_idx] = next_token
                    next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                    beam_idx += 1

                # once the beam for next step is full, don't add more tokens to it.
                if beam_idx == self.group_size:
                    break

            if beam_idx < self.group_size:
                raise ValueError(
                    f"At most {self.group_size} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
                    f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
                )

            # Check if we are done so that we can save a pad step if all(done)
            self._done[batch_group_idx] = self._done[batch_group_idx] or group_hyps.is_done(
                next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
            )

        return UserDict(
            {
                "next_beam_scores": next_beam_scores.view(-1),
                "next_beam_tokens": next_beam_tokens.view(-1),
                "next_beam_indices": next_beam_indices.view(-1),
            }
        )

    def finalize(
        self,
        input_ids: torch.LongTensor,
        final_beam_scores: torch.FloatTensor,
        final_beam_tokens: torch.LongTensor,
        final_beam_indices: torch.LongTensor,
        max_length: int,
        pad_token_id: Optional[Union[int, torch.Tensor]] = None,
        eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
        beam_indices: Optional[torch.LongTensor] = None,
        decoder_prompt_len: int = 0,
    ) -> UserDict:
        """
        Close all still-open beams and assemble the `num_beam_hyps_to_keep` best hypotheses per batch item.

        Args:
            input_ids: Sequences generated so far, one row per beam.
            final_beam_scores: Score of every row of `input_ids`.
            final_beam_tokens / final_beam_indices: Accepted for interface compatibility; not read here.
            max_length: Upper bound for the width of the returned `sequences` (ignored when `None`).
            pad_token_id: Required when the selected hypotheses differ in length.
            eos_token_id: End-of-sequence id(s); the first one is appended after a hypothesis when it fits.
                When `None`, no EOS slot is reserved and nothing is appended.
            beam_indices: Optional per-row beam history stored alongside each hypothesis.
            decoder_prompt_len: Prompt length subtracted from the sequence length to get the generated length.

        Returns:
            `UserDict` with `sequences`, `sequence_scores` and `beam_indices` (`None` when no beam history was
            tracked).

        Raises:
            ValueError: If hypotheses of different lengths have to be padded but `pad_token_id` is `None`.
        """
        batch_size = len(self._beam_hyps) // self.num_beam_groups
        num_kept = batch_size * self.num_beam_hyps_to_keep

        eos_token_id = self._eos_as_tensor(eos_token_id)

        # finalize all open beam hypotheses and add to generated hypotheses
        for batch_group_idx, beam_hyp in enumerate(self._beam_hyps):
            if self._done[batch_group_idx]:
                continue

            # all open beam hypotheses are added to the beam hypothesis
            # beam hypothesis class automatically keeps the best beams
            for index_per_group in range(self.group_size):
                batch_beam_idx = batch_group_idx * self.group_size + index_per_group
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                generated_len = final_tokens.shape[-1] - decoder_prompt_len
                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)

        # select the best hypotheses; every slot of `sent_lengths` is written in the loop below
        sent_lengths = input_ids.new_empty(num_kept)
        best = []
        best_indices = []
        best_scores = torch.zeros(num_kept, device=self.device, dtype=torch.float32)

        # retrieve best hypotheses
        for i in range(batch_size):
            beam_hyps_in_batch = self._beam_hyps[i * self.num_beam_groups : (i + 1) * self.num_beam_groups]
            candidate_beams = [beam for beam_hyp in beam_hyps_in_batch for beam in beam_hyp.beams]
            # ascending by score, so popping from the end yields the best remaining hypothesis
            sorted_hyps = sorted(candidate_beams, key=lambda x: x[0])
            for j in range(self.num_beam_hyps_to_keep):
                best_hyp_tuple = sorted_hyps.pop()
                best_score = best_hyp_tuple[0]
                best_hyp = best_hyp_tuple[1]
                best_index = best_hyp_tuple[2]
                slot = self.num_beam_hyps_to_keep * i + j
                sent_lengths[slot] = len(best_hyp)

                best.append(best_hyp)
                best_indices.append(best_index)
                best_scores[slot] = best_score

        # Reserve one extra position for the EOS token only when there is an EOS token to insert. Without this
        # guard a missing `eos_token_id` crashed on `eos_token_id[0]` (or left an uninitialised column).
        longest = sent_lengths.max().item()
        sent_lengths_max = longest + 1 if eos_token_id is not None else longest
        sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
        decoded: torch.LongTensor = input_ids.new_empty((num_kept, sent_max_len))

        if len(best_indices) > 0 and best_indices[0] is not None:
            indices: torch.LongTensor = input_ids.new_empty((num_kept, sent_max_len))
        else:
            indices = None

        # shorter batches are padded if needed
        if sent_lengths.min().item() != sent_lengths.max().item():
            if pad_token_id is None:
                raise ValueError("`pad_token_id` has to be defined")
            decoded.fill_(pad_token_id)

        # -1 marks positions without a beam index; this also initialises the otherwise uninitialised buffer
        if indices is not None:
            indices.fill_(-1)

        # a 0-dim tensor cannot be indexed with [0], so flatten before picking the first EOS id
        first_eos = eos_token_id.reshape(-1)[0] if eos_token_id is not None else None

        # fill with hypotheses and eos_token_id if the latter fits in
        for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
            decoded[i, : sent_lengths[i]] = hypo

            if indices is not None:
                indices[i, : len(best_idx)] = torch.tensor(best_idx)

            if first_eos is not None and sent_lengths[i] < sent_max_len:
                # inserting only the first eos_token_id
                decoded[i, sent_lengths[i]] = first_eos

        return UserDict(
            {
                "sequences": decoded,
                "sequence_scores": best_scores,
                "beam_indices": indices,
            }
        )
- class ConstrainedBeamSearchScorer(BeamScorer):
- r"""
- [`BeamScorer`] implementing constrained beam search decoding.
- Args:
- batch_size (`int`):
- Batch Size of `input_ids` for which constrained beam search decoding is run in parallel.
- num_beams (`int`):
- Number of beams for beam search.
- constraints (`list[Constraint]`):
- A list of positive constraints represented as `Constraint` objects that must be fulfilled in the generation
- output. For more information, the documentation of [`Constraint`] should be read.
- device (`torch.device`):
- Defines the device type (*e.g.*, `"cpu"` or `"cuda"`) on which this instance of `ConstrainedBeamSearchScorer` will be
- allocated.
- length_penalty (`float`, *optional*, defaults to 1.0):
- Exponential penalty to the length that is used with beam-based generation. It is applied as an exponent to
- the sequence length, which in turn is used to divide the score of the sequence. Since the score is the log
- likelihood of the sequence (i.e. negative), `length_penalty` > 0.0 promotes longer sequences, while
- `length_penalty` < 0.0 encourages shorter sequences.
- do_early_stopping (`bool` or `str`, *optional*, defaults to `False`):
- Controls the stopping condition for beam-based methods, like beam-search. It accepts the following values:
- `True`, where the generation stops as soon as there are `num_beams` complete candidates; `False`, where an
- heuristic is applied and the generation stops when is it very unlikely to find better candidates;
- `"never"`, where the beam search procedure only stops when there cannot be better candidates (canonical
- beam search algorithm).
- num_beam_hyps_to_keep (`int`, *optional*, defaults to 1):
- The number of beam hypotheses that shall be returned upon calling
- [`~transformers.ConstrainedBeamSearchScorer.finalize`].
- max_length (`int`, *optional*):
- The maximum length of the sequence to be generated.
- """
def __init__(
    self,
    batch_size: int,
    num_beams: int,
    constraints: list[Constraint],
    device: torch.device,
    length_penalty: float = 1.0,
    do_early_stopping: Union[bool, str] = False,
    num_beam_hyps_to_keep: int = 1,
    max_length: Optional[int] = None,
):
    """
    Store the search configuration and allocate one hypothesis container and one "done" flag per batch item.

    Raises:
        ValueError: If `num_beams` is not an integer strictly greater than 1.
    """
    logger.warning_once(
        "`ConstrainedBeamSearchScorer` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
    )

    # plain configuration
    self.num_beams = num_beams
    self.device = device
    self.length_penalty = length_penalty
    self.do_early_stopping = do_early_stopping
    self.num_beam_hyps_to_keep = num_beam_hyps_to_keep
    self.constraints = constraints

    self._is_init = False

    # per-batch-item state: the collected hypotheses and whether that item is finished
    hypotheses = []
    for _ in range(batch_size):
        hypotheses.append(
            BeamHypotheses(
                num_beams=self.num_beams,
                length_penalty=self.length_penalty,
                early_stopping=self.do_early_stopping,
                max_length=max_length,
            )
        )
    self._beam_hyps = hypotheses
    self._done = torch.tensor([False] * batch_size, dtype=torch.bool, device=self.device)

    # the beam count is checked last, exactly as before, once the state above has been created
    has_valid_beam_count = isinstance(num_beams, int) and num_beams > 1
    if not has_valid_beam_count:
        raise ValueError(
            f"`num_beams` has to be an integer strictly greater than 1, but is {num_beams}. For `num_beams` == 1,"
            " one should make use of `greedy_search` instead."
        )
- @property
- def is_done(self) -> bool:
- return self._done.all().item()
def make_constraint_states(self, n):
    """Return a list of `n` `ConstraintListState` objects, each built from fresh copies of `self.constraints`."""
    states = []
    for _ in range(n):
        # every state gets its own constraint copies so that the states do not share constraint objects
        constraint_copies = [constraint.copy() for constraint in self.constraints]
        states.append(ConstraintListState(constraint_copies))
    return states
def check_completes_constraints(self, sequence):
    """Reset a fresh constraint state with `sequence` and return that state's `completed` attribute."""
    fresh_states = self.make_constraint_states(1)
    probe = fresh_states[0]
    probe.reset(sequence)
    return probe.completed
def process(
    self,
    input_ids: torch.LongTensor,
    next_scores: torch.FloatTensor,
    next_tokens: torch.LongTensor,
    next_indices: torch.LongTensor,
    scores_for_all_vocab: torch.FloatTensor,
    pad_token_id: Optional[Union[int, torch.Tensor]] = None,
    eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
    beam_indices: Optional[torch.LongTensor] = None,
    decoder_prompt_len: int = 0,
) -> UserDict:
    r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size * num_beams, sequence_length)`):
            Indices of input sequence tokens in the vocabulary.

            Indices can be obtained using any class inheriting from [`PreTrainedTokenizer`]. See
            [`PreTrainedTokenizer.encode`] and [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        next_scores (`torch.FloatTensor` of shape `(batch_size, 2 * num_beams)`):
            Current scores of the top `2 * num_beams` non-finished beam hypotheses.
        next_tokens (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
            `input_ids` of the tokens corresponding to the top `2 * num_beams` non-finished beam hypotheses.
        next_indices (`torch.LongTensor` of shape `(batch_size, 2 * num_beams)`):
            Beam indices indicating to which beam hypothesis the `next_tokens` correspond.
        scores_for_all_vocab (`torch.FloatTensor` of shape `(batch_size * num_beams, vocab_size)`):
            The scores of all tokens in the vocabulary for each of the beam hypotheses. Rows are looked up by
            token id when constraint-advancing candidates are scored.
        pad_token_id (`int`, *optional*):
            The id of the *padding* token.
        eos_token_id (`Union[int, list[int]]`, *optional*):
            The id of the *end-of-sequence* token. Optionally, use a list to set multiple *end-of-sequence* tokens.
        beam_indices (`torch.LongTensor`, *optional*):
            Beam indices indicating to which beam hypothesis each token corresponds.
        decoder_prompt_len (`int`, *optional*):
            The length of prompt that is included in the input to decoder.

    Return:
        `UserDict`: A dictionary composed of the fields as defined above:

            - **next_beam_scores** (`torch.FloatTensor` of shape `(batch_size * num_beams)`) -- Updated scores of
              all non-finished beams.
            - **next_beam_tokens** (`torch.LongTensor` of shape `(batch_size * num_beams)`) -- Next tokens to be
              added to the non-finished beam_hypotheses.
            - **next_beam_indices** (`torch.LongTensor` of shape `(batch_size * num_beams)`) -- Beam indices
              indicating to which beam the next tokens shall be added.

    Raises:
        ValueError: If a finished batch item has to be padded but `eos_token_id` or `pad_token_id` is missing,
            or if fewer than `num_beams` non-EOS candidates are available for a batch item.
    """
    # add up to the length which the next_scores is calculated on (including decoder prompt)
    cur_len = input_ids.shape[-1] + 1
    batch_size = len(self._beam_hyps)

    device = input_ids.device
    next_beam_scores = torch.zeros((batch_size, self.num_beams), dtype=next_scores.dtype, device=device)
    next_beam_tokens = torch.zeros((batch_size, self.num_beams), dtype=next_tokens.dtype, device=device)
    next_beam_indices = torch.zeros((batch_size, self.num_beams), dtype=next_indices.dtype, device=device)

    # normalise an int / list of ints into a tensor so that the membership test below is uniform
    if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id = torch.tensor(eos_token_id)

    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            # NOTE(review): the comparison looks inverted with respect to the message ("at least"); it is kept
            # unchanged because the hypothesis container's guarantees are not visible here.
            if self.num_beams < len(beam_hyp):
                raise ValueError(f"Batch can only be done if at least {self.num_beams} beams have been generated")
            if eos_token_id is None or pad_token_id is None:
                raise ValueError("Generated beams >= num_beams -> eos_token_id and pad_token have to be defined")
            # pad the batch
            next_beam_scores[batch_idx, :] = 0
            next_beam_tokens[batch_idx, :] = pad_token_id
            next_beam_indices[batch_idx, :] = 0
            continue

        # next tokens for this sentence.
        beam_idx = 0
        for beam_token_rank, (next_token, next_score, next_index) in enumerate(
            zip(next_tokens[batch_idx], next_scores[batch_idx], next_indices[batch_idx])
        ):
            batch_beam_idx = batch_idx * self.num_beams + next_index
            # add to generated hypotheses if end of sentence
            if (eos_token_id is not None) and (next_token.item() in eos_token_id):
                # if beam_token does not belong to top num_beams tokens, it should not be added
                if beam_token_rank >= self.num_beams:
                    continue

                # an EOS candidate is only stored when its sequence fulfils every constraint;
                # otherwise it is dropped without occupying a beam of the next step
                completes_constraint = self.check_completes_constraints(input_ids[batch_beam_idx].tolist())
                if completes_constraint:
                    if beam_indices is not None:
                        beam_index = beam_indices[batch_beam_idx]
                        beam_index = beam_index + (batch_beam_idx,)
                    else:
                        beam_index = None

                    beam_hyp.add(
                        input_ids[batch_beam_idx].clone(),
                        next_score.item(),
                        beam_indices=beam_index,
                        generated_len=cur_len - decoder_prompt_len,
                    )
            else:
                # add next predicted token since it is not eos_token
                next_beam_scores[batch_idx, beam_idx] = next_score
                next_beam_tokens[batch_idx, beam_idx] = next_token
                next_beam_indices[batch_idx, beam_idx] = batch_beam_idx
                beam_idx += 1

            # once the beam for next step is full, don't add more tokens to it.
            if beam_idx == self.num_beams:
                break

        # let the constraint logic revise this item's candidates, then write its selection back
        new_scores, new_tokens, new_indices = self.step_sentence_constraint(
            batch_idx,
            input_ids,
            scores_for_all_vocab,
            next_beam_scores[batch_idx],
            next_beam_tokens[batch_idx],
            next_beam_indices[batch_idx],
        )
        next_beam_scores[batch_idx] = new_scores
        next_beam_tokens[batch_idx] = new_tokens
        next_beam_indices[batch_idx] = new_indices

        if beam_idx < self.num_beams:
            raise ValueError(
                f"At most {self.num_beams} tokens in {next_tokens[batch_idx]} can be equal to `eos_token_id:"
                f" {eos_token_id}`. Make sure {next_tokens[batch_idx]} are corrected."
            )

        # Check if we are done so that we can save a pad step if all(done)
        self._done[batch_idx] = self._done[batch_idx] or beam_hyp.is_done(
            next_scores[batch_idx].max().item(), cur_len, decoder_prompt_len
        )

    return UserDict(
        {
            "next_beam_scores": next_beam_scores.view(-1),
            "next_beam_tokens": next_beam_tokens.view(-1),
            "next_beam_indices": next_beam_indices.view(-1),
        }
    )
- def step_sentence_constraint(
- self,
- batch_idx: int,
- input_ids: torch.LongTensor,
- vocab_scores: torch.FloatTensor,
- sent_beam_scores: torch.FloatTensor,
- sent_beam_tokens: torch.LongTensor,
- sent_beam_indices: torch.LongTensor,
- push_progress: bool = False,
- ):
- # sent_beam_tokens are the next {num_beams} number of tokens that are under consideration for this beam
- # (candidate next tokens)
- # 1. Adding "advance_tokens"
- # using ConstraintStateList.advance(), we propose new tokens to be added into this "candidate list" that will
- # advance us in fulfilling the constraints.
- # 2. Selecting best candidates such that we end up with highest probable candidates
- # that fulfill our constraints.
- orig_len = sent_beam_indices.size(0)
- device = sent_beam_indices.device
- # initialize states
- topk_constraint_states = self.make_constraint_states(orig_len)
- advance_constraint_states = self.make_constraint_states(orig_len)
- sidx, eidx = batch_idx * orig_len, (batch_idx + 1) * orig_len
- this_batch_input_ids = input_ids[sidx:eidx]
- this_batch_token_scores = vocab_scores[sidx:eidx]
- full_hypotheses = torch.cat((input_ids[sent_beam_indices], sent_beam_tokens.unsqueeze(-1)), dim=-1)
- # need to make new hypothesis that advance the constraints
- track_new = {
- "new_seqs": full_hypotheses.tolist(),
- "new_states": [],
- "new_indices": [],
- "new_tokens": [],
- "new_scores": [],
- }
- for seq_idx, pre_seq in enumerate(this_batch_input_ids):
- # pre_seq = ith sequence generated before this step.
- # input_ids -> (topk) generic beam search best model next tokens
- # -> (advance) constraints forcing the next token
- # either way, we need to sort them into "banks" later, so store a "ConstraintListState" for all types of
- # hypotheses.
- topk_state = topk_constraint_states[seq_idx]
- topk_state.reset(full_hypotheses[seq_idx].tolist())
- advance_state = advance_constraint_states[seq_idx]
- advance_state.reset(pre_seq.tolist())
- if not advance_state.completed:
- advance_tokens = torch.tensor(advance_state.advance(), dtype=torch.long, device=device)
- for advance_token in advance_tokens:
- # since adding each `advance_token` leads to a different hypothesis, create new state instance.
- new_state = advance_state.copy(stateful=True)
- new_state.add(advance_token.tolist())
- advance_seq = torch.cat((pre_seq, advance_token.unsqueeze(0)), -1).tolist()
- if advance_seq not in track_new["new_seqs"]:
- # prevent duplicates, which are basically bound to happen in this process.
- track_new["new_seqs"].append(advance_seq)
- track_new["new_indices"].append(sidx + seq_idx) # idx -> global idx across all the batches
- track_new["new_tokens"].append(advance_token)
- track_new["new_scores"].append(this_batch_token_scores[seq_idx].take(advance_token))
- track_new["new_states"].append(new_state)
- elif push_progress:
- # Basically, `sent_beam_indices` often chooses very little among `input_ids` the generated sequences that
- # actually fulfill our constraints. For example, let constraints == ["loves pies"] and
- # pre_seq_1 = "The child loves pies and" pre_seq_2 = "The child plays in the playground and"
- # Without this step, if `sent_beam_indices` is something like [1,1], then
- # 1. `pre_seq_1` won't be added to the list of (topk) hypothesis since it's not in the indices and
- # 2. it won't be added to the list of (advance) hypothesis since it's completed already. (this is
- # the else part of `if constraints_completed[seq_idx]`)
- # 3. it ends up simply getting removed from consideration.
- # #3 might be fine and actually desired, since it's likely that it's a low-probability output anyways,
- # especially if it's not in the list of `sent_beam_indices`. But this often leads to lengthened beam
- # search times, since completed sequences keep getting removed after all this effort for constrained
- # generation.
- # Here, we basically take `pre_seq_1` and to "push" it into the considered list of hypotheses, by simply
- # appending the next likely token in the vocabulary and adding it to the list of hypotheses.
- new_score, new_token = torch.max(this_batch_token_scores[seq_idx], 0) # some next probable token
- advance_seq = torch.cat((pre_seq, new_token.unsqueeze(0)), -1)
- advance_state = advance_constraint_states[seq_idx]
- advance_seq = advance_seq.tolist()
- advance_state.reset(advance_seq)
- if advance_seq not in track_new["new_seqs"]:
- # but still don't want to have duplicates
- track_new["new_seqs"].append(advance_seq)
- track_new["new_indices"].append(seq_idx)
- track_new["new_tokens"].append(new_token)
- track_new["new_scores"].append(new_score)
- track_new["new_states"].append(advance_state)
- if len(track_new["new_indices"]) > 0:
- new_indices = torch.tensor(track_new["new_indices"], device=device)
- new_tokens = torch.stack(track_new["new_tokens"]).to(device)
- new_scores = torch.stack(track_new["new_scores"]).to(device)
- all_states = topk_constraint_states + track_new["new_states"]
- all_tokens = torch.cat((sent_beam_tokens, new_tokens), -1)
- all_scores = torch.cat((sent_beam_scores, new_scores), -1)
- all_banks = torch.tensor([one.get_bank() for one in all_states], device=device)
- zipped = all_banks * 100 + all_scores
- indices = zipped.sort(descending=True).indices
- sorted_banks = all_banks[indices]
- # Then we end up with {sorted among bank C}, {sorted among bank C-1}, ..., {sorted among bank 0}
- counter = -1
- cur_bank = sorted_banks[0]
- increments = []
- for bank in sorted_banks:
- if bank == cur_bank:
- counter += 1
- else:
- counter = 0
- cur_bank = bank
- increments.append(counter)
- rearrangers = torch.tensor(np.argsort(increments, kind="mergesort"))
- indices = indices[rearrangers][:orig_len]
- sent_beam_scores = all_scores[indices]
- sent_beam_tokens = all_tokens[indices]
- sent_beam_indices = torch.cat((sent_beam_indices, new_indices))[indices]
- return sent_beam_scores, sent_beam_tokens, sent_beam_indices
def finalize(
    self,
    input_ids: torch.LongTensor,
    final_beam_scores: torch.FloatTensor,
    final_beam_tokens: torch.LongTensor,
    final_beam_indices: torch.LongTensor,
    max_length: int,
    pad_token_id: Optional[Union[int, torch.Tensor]] = None,
    eos_token_id: Optional[Union[int, list[int], torch.Tensor]] = None,
    beam_indices: Optional[torch.LongTensor] = None,
    decoder_prompt_len: int = 0,
) -> tuple[torch.LongTensor]:
    """
    Close every still-open beam and assemble the best hypotheses of each batch entry into output tensors.

    For each batch entry not flagged in `self._done`, beams whose tokens pass
    `self.check_completes_constraints` are added to that entry's hypothesis container. When fewer than
    `self.num_beam_hyps_to_keep` beams pass, all remaining beams are added as well so that there is always
    something to return. The `self.num_beam_hyps_to_keep` highest-scoring hypotheses per entry are then copied
    into the returned tensors.

    Args:
        input_ids: Token ids of the current beams; row `batch_idx * self.num_beams + beam_id` is read per beam.
        final_beam_scores: Per-beam scores, indexed the same way as `input_ids`.
        final_beam_tokens: Not read by this method.
        final_beam_indices: Not read by this method.
        max_length: Cap on the width of the returned `sequences`; `None` disables the cap.
        pad_token_id: Fill value for `sequences` when the selected hypotheses differ in length.
        eos_token_id: One or several EOS ids. Only the first one is written, directly after each hypothesis
            that is shorter than the output width.
        beam_indices: Optional per-beam index history; the entry of each added beam is forwarded to the
            hypothesis container.
        decoder_prompt_len: Subtracted from the sequence length to obtain the `generated_len` passed to the
            hypothesis container.

    Returns:
        A `UserDict` with the keys `"sequences"`, `"sequence_scores"` and `"beam_indices"`. The latter is
        `None` when the first selected hypothesis carries no beam indices.

    Raises:
        ValueError: If the selected hypotheses differ in length and `pad_token_id` is `None`.
    """
    batch_size = len(self._beam_hyps)

    # Normalize `eos_token_id` to a tensor so that `eos_token_id[0]` below works for int and list inputs alike.
    if eos_token_id is not None and not isinstance(eos_token_id, torch.Tensor):
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        eos_token_id = torch.tensor(eos_token_id)

    # finalize all open beam hypotheses and add to generated hypotheses
    for batch_idx, beam_hyp in enumerate(self._beam_hyps):
        if self._done[batch_idx]:
            continue

        # all open beam hypotheses are added to the beam hypothesis
        # beam hypothesis class automatically keeps the best beams
        ids_collect = []
        for beam_id in range(self.num_beams):
            batch_beam_idx = batch_idx * self.num_beams + beam_id
            final_score = final_beam_scores[batch_beam_idx].item()
            final_tokens = input_ids[batch_beam_idx]

            if self.check_completes_constraints(final_tokens.tolist()):
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                generated_len = final_tokens.shape[-1] - decoder_prompt_len
                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)
                ids_collect.append(beam_id)

        # due to overly complex constraints or other factors, sometimes we can't guarantee a successful
        # generation. In these cases we simply return the highest scoring outputs.
        if len(ids_collect) < self.num_beam_hyps_to_keep:
            for beam_id in range(self.num_beams):
                if beam_id in ids_collect:
                    continue
                batch_beam_idx = batch_idx * self.num_beams + beam_id
                final_score = final_beam_scores[batch_beam_idx].item()
                final_tokens = input_ids[batch_beam_idx]
                # Forward the beam indices here as well: otherwise the container would hold a mix of
                # hypotheses with and without indices, which breaks the `beam_indices` output below.
                beam_index = beam_indices[batch_beam_idx] if beam_indices is not None else None
                generated_len = final_tokens.shape[-1] - decoder_prompt_len
                beam_hyp.add(final_tokens, final_score, beam_indices=beam_index, generated_len=generated_len)

    # select the best hypotheses
    sent_lengths = input_ids.new(batch_size * self.num_beam_hyps_to_keep)
    best = []
    best_indices = []
    best_scores = torch.zeros(batch_size * self.num_beam_hyps_to_keep, device=self.device, dtype=torch.float32)

    # retrieve best hypotheses
    for i, beam_hyp in enumerate(self._beam_hyps):
        # ascending sort by score, so `pop()` yields the best remaining hypothesis
        sorted_hyps = sorted(beam_hyp.beams, key=lambda x: x[0])
        for j in range(self.num_beam_hyps_to_keep):
            best_score, best_hyp, best_index = sorted_hyps.pop()[:3]
            flat_idx = self.num_beam_hyps_to_keep * i + j
            sent_lengths[flat_idx] = len(best_hyp)
            best.append(best_hyp)
            best_indices.append(best_index)
            best_scores[flat_idx] = best_score

    # prepare for adding eos: reserve one extra position, capped by `max_length`
    sent_lengths_max = sent_lengths.max().item() + 1
    sent_max_len = min(sent_lengths_max, max_length) if max_length is not None else sent_lengths_max
    decoded: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)

    if len(best_indices) > 0 and best_indices[0] is not None:
        indices: torch.LongTensor = input_ids.new(batch_size * self.num_beam_hyps_to_keep, sent_max_len)
    else:
        indices = None

    # shorter batches are padded if needed
    if sent_lengths.min().item() != sent_lengths.max().item():
        if pad_token_id is None:
            raise ValueError("`pad_token_id` has to be defined")
        decoded.fill_(pad_token_id)

    if indices is not None:
        # -1 marks positions that hold no beam index
        indices.fill_(-1)

    # fill with hypotheses and eos_token_id if the latter fits in
    for i, (hypo, best_idx) in enumerate(zip(best, best_indices)):
        decoded[i, : sent_lengths[i]] = hypo

        if indices is not None:
            indices[i, : len(best_idx)] = torch.tensor(best_idx)

        if sent_lengths[i] < sent_max_len:
            # inserting only the first eos_token_id
            # NOTE(review): `eos_token_id` is indexed unconditionally, so it must not be `None` whenever a
            # hypothesis is shorter than `sent_max_len`.
            decoded[i, sent_lengths[i]] = eos_token_id[0]

    return UserDict(
        {
            "sequences": decoded,
            "sequence_scores": best_scores,
            "beam_indices": indices,
        }
    )
class BeamHypotheses:
    """
    Bounded n-best list of finished hypotheses, stored as `(score, hyp, beam_indices)` tuples in `self.beams`.

    At most `num_beams` entries are kept; `worst_score` tracks the lowest score currently stored.
    """

    def __init__(
        self, num_beams: int, length_penalty: float, early_stopping: Union[bool, str], max_length: Optional[int] = None
    ):
        """
        Set up an empty n-best list.

        Raises:
            ValueError: If `early_stopping` is not a bool and `max_length` is `None`.
        """
        logger.warning_once(
            "`BeamHypotheses` is deprecated and will be removed in v4.62.0, as constrained beam search has been moved to the Hub: https://hf.co/transformers-community/constrained-beam-search."
        )
        self.num_beams = num_beams
        self.max_length = max_length
        self.early_stopping = early_stopping
        self.length_penalty = length_penalty
        self.worst_score = 1e9
        self.beams = []

        # A non-bool `early_stopping` needs `max_length` in `is_done`.
        if self.max_length is None and not isinstance(self.early_stopping, bool):
            raise ValueError(
                "When `do_early_stopping` is set to a string, `max_length` must be defined. Ensure it is passed to the"
                " BeamScorer class instance at initialization time."
            )

    def __len__(self):
        """
        Return how many hypotheses are currently stored.
        """
        return len(self.beams)

    def add(
        self,
        hyp: torch.LongTensor,
        sum_logprobs: float,
        beam_indices: Optional[torch.LongTensor] = None,
        generated_len: Optional[int] = None,
    ):
        """
        Offer a hypothesis to the list.

        The score is `sum_logprobs` divided by the length raised to `length_penalty`. The hypothesis is stored
        if the list is not full yet or if it beats the current worst score; in the latter case the lowest
        scoring entry is evicted.
        """
        # Falling back to the full hypothesis length exists for retrocompatibility.
        effective_len = hyp.shape[-1] if generated_len is None else generated_len
        score = sum_logprobs / (effective_len**self.length_penalty)

        # List is full and the candidate does not beat the worst entry: nothing to do.
        if len(self) >= self.num_beams and not (score > self.worst_score):
            return

        self.beams.append((score, hyp, beam_indices))
        if len(self) <= self.num_beams:
            self.worst_score = min(score, self.worst_score)
            return

        # Over capacity: drop the lowest-scoring entry; the runner-up becomes the new worst score.
        ranked = sorted([(s, idx) for idx, (s, _, _) in enumerate(self.beams)])
        evict_idx = ranked[0][1]
        runner_up_score = ranked[1][0]
        del self.beams[evict_idx]
        self.worst_score = runner_up_score

    def is_done(self, best_sum_logprobs: float, cur_len: int, decoder_prompt_len: int = 0) -> bool:
        """
        Tell whether this sentence is finished: the list is full and no hypothesis still being generated can
        score better than the worst stored one.

        Raises:
            ValueError: With a non-bool `early_stopping` and positive `length_penalty`, if `max_length` is not
                larger than `decoder_prompt_len`.
        """
        if len(self) < self.num_beams:
            return False

        # `True`: stop as soon as at least `num_beams` hypotheses are finished
        if self.early_stopping is True:
            return True

        # `False`: heuristic -- compute best possible score from `cur_len`, even though it is not entirely accurate
        # when `length_penalty` is positive. See the discussion below for more details.
        # https://github.com/huggingface/transformers/pull/20901#issuecomment-1369845565
        if self.early_stopping is False:
            heuristic_best = best_sum_logprobs / (cur_len - decoder_prompt_len) ** self.length_penalty
            return self.worst_score >= heuristic_best

        # `"never"`: compute the best possible score, depending on the sign of `length_penalty`
        if self.length_penalty > 0.0:
            # `length_penalty` > 0.0 -> the max denominator is obtained from `max_length`, not from `cur_len`.
            # Since `best_sum_logprobs` is negative, the largest denominator yields the highest attainable score.
            if self.max_length <= decoder_prompt_len:
                raise ValueError("max_length is not larger than decoder prompt length")
            reference_len = self.max_length - decoder_prompt_len
        else:
            # the opposite logic applies here (max attainable score is obtained from `cur_len`)
            reference_len = cur_len - decoder_prompt_len

        attainable_best = best_sum_logprobs / reference_len**self.length_penalty
        return self.worst_score >= attainable_best
|