| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533 |
- from abc import ABC, abstractmethod
- from typing import Optional
- from ..utils import logging
- logger = logging.get_logger(__name__)
- # TODO joao, manuel: remove in v4.58.0
class Constraint(ABC):
    r"""Abstract base class for all constraints that can be applied during generation.

    A constraint describes what generation must produce and how to step towards it. Every subclass must guarantee
    that repeatedly feeding its own suggestion back into it terminates, i.e. that

    ```py
    completed = False
    while not completed:
        _, completed, _ = constraint.update(constraint.advance())
    ```

    always halts.
    """

    def __init__(self):
        message = (
            "Importing `Constraint` classes is deprecated and will be removed in v4.58.0. "
            "Constrained beam search has been moved to the Hub: "
            "https://hf.co/transformers-community/constrained-beam-search. "
            "Please import using `from transformers.generation import Constraint` instead."
        )
        logger.warning_once(message)
        # Self-check that the subclass honours the termination contract described in the class docstring.
        self.test()

    def test(self):
        """
        Sanity-check that this constraint is well defined.

        Feeds `advance()` back into `update()` until the constraint reports completion. `reset()` is called once,
        right before the second step. Once completed, `remaining()` is checked to be zero.

        Raises:
            Exception: if `does_advance(advance())` is false, if `update()` has not reported completion after 10000
                steps, or if `remaining()` is non-zero once completed.
        """
        for step in range(10001):
            if step == 1:
                # Exercise `reset()` once, after some progress has been made.
                self.reset()
            token = self.advance()
            if not self.does_advance(token):
                raise Exception(
                    "Custom Constraint is not defined correctly. self.does_advance(self.advance()) must be true."
                )
            _, finished, _ = self.update(token)
            if step == 10000:
                # The 10001st update is never accepted, even if it happens to complete the constraint.
                raise Exception("update() does not fulfill the constraint.")
            if finished:
                break

        if self.remaining() != 0:
            raise Exception("Custom Constraint is not defined correctly.")

    def _abstract_error(self):
        # Build the error raised by the default bodies of the abstract methods below.
        return NotImplementedError(
            f"{self.__class__} is an abstract class. Only classes inheriting this class can be called."
        )

    @abstractmethod
    def advance(self):
        """
        Return the token(s) that would take this constraint one step closer to being fulfilled.

        Return:
            token_ids (`Union[int, list[int], None]`):
                - a single token id (`int`) that advances the constraint, or
                - a list of token ids, any of which could advance the constraint, or
                - `None` if the constraint is completed or cannot be advanced.
        """
        raise self._abstract_error()

    @abstractmethod
    def does_advance(self, token_id: int):
        """
        Read in a token and return whether it makes progress, without changing this constraint's state.
        """
        raise self._abstract_error()

    @abstractmethod
    def update(self, token_id: int):
        """
        Read in a token and update this constraint's state as if the token had been generated.

        Unlike `does_advance(self, token_id: int)`, this mutates the state. It is not a probe for whether a token
        would help. It records the token, which matters when `token_id` is not the desired token (see the `else`
        branch of `PhrasalConstraint.update`).

        Args:
            token_id (`int`):
                The id of a newly generated token in the beam search.

        Return:
            stepped (`bool`):
                Whether this constraint has become one step closer to being fulfilled.
            completed (`bool`):
                Whether this constraint has been completely fulfilled by this token being generated.
            reset (`bool`):
                Whether this constraint has reset its progress because of this token being generated.
        """
        raise self._abstract_error()

    @abstractmethod
    def reset(self):
        """
        Restore this constraint to its initial state.

        Called when the fulfillment of a constraint is interrupted by an unwanted token.
        """
        raise self._abstract_error()

    @abstractmethod
    def remaining(self):
        """
        Return the number of `advance()` steps still needed to complete this constraint.
        """
        raise self._abstract_error()

    @abstractmethod
    def copy(self, stateful=False):
        """
        Create a new instance of this constraint.

        Args:
            stateful (`bool`):
                Whether to copy the current progress as well, not just the constraint definition.

        Return:
            constraint (`Constraint`): A constraint equivalent to the one being called from.
        """
        raise self._abstract_error()
class PhrasalConstraint(Constraint):
    r"""
    [`Constraint`] enforcing that an ordered sequence of tokens is included in the output.

    Args:
        token_ids (`list[int]`):
            The ids of the tokens that must be generated, in this order, by the output.
    """

    def __init__(self, token_ids: list[int]):
        # `super(Constraint, self)` starts the method lookup *after* `Constraint` in the MRO. As a result,
        # `Constraint.__init__` (its deprecation warning and its `self.test()` self-check) is not run.
        # `self.test()` could not run here anyway: `advance()` reads `self.completed`, which is not set yet.
        super(Constraint, self).__init__()

        if not isinstance(token_ids, list) or len(token_ids) == 0:
            raise ValueError(f"`token_ids` has to be a non-empty list, but is {token_ids}.")
        if any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids):
            raise ValueError(f"Each list in `token_ids` has to be a list of positive integers, but is {token_ids}.")

        self.token_ids = token_ids
        self.seqlen = len(self.token_ids)  # number of tokens in the phrase
        self.fulfilled_idx = -1  # index of the last fulfilled token; -1 means no progress yet
        self.completed = False

    def advance(self):
        """Return the next token of the phrase, or `None` once the phrase is complete."""
        if self.completed:
            return None
        return self.token_ids[self.fulfilled_idx + 1]

    def does_advance(self, token_id: int):
        """
        Return whether `token_id` is the next token of the phrase. Always `False` once completed.

        Raises:
            TypeError: if `token_id` is not an `int`.
        """
        if not isinstance(token_id, int):
            raise TypeError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")

        if self.completed:
            return False

        return token_id == self.token_ids[self.fulfilled_idx + 1]

    def update(self, token_id: int):
        """
        Record `token_id` as generated.

        A matching token advances the phrase by one step. Any other token resets all progress.

        Returns:
            `tuple(bool, bool, bool)`: `(stepped, completed, reset)`, as described in `Constraint.update`.

        Raises:
            TypeError: if `token_id` is not an `int`.
        """
        if not isinstance(token_id, int):
            raise TypeError(f"`token_id` has to be an `int`, but is {token_id} of type {type(token_id)}")

        stepped = False
        completed = False
        reset = False

        if self.does_advance(token_id):
            self.fulfilled_idx += 1
            stepped = True
            completed = self.fulfilled_idx == (self.seqlen - 1)
            self.completed = completed
        else:
            # failed to make progress: start over from the first token of the phrase.
            reset = True
            self.reset()
        return stepped, completed, reset

    def reset(self):
        """Return to the freshly-constructed state (no token of the phrase fulfilled)."""
        self.completed = False
        # Must match the value set in `__init__`. Using 0 here would mark the first token as fulfilled, so that
        # `advance()` skips it, `remaining()` is off by one, and a one-token phrase indexes out of range.
        self.fulfilled_idx = -1

    def remaining(self):
        """Return how many tokens of the phrase are still to be generated."""
        return self.seqlen - (self.fulfilled_idx + 1)

    def copy(self, stateful=False):
        """
        Return a new `PhrasalConstraint` over the same `token_ids`.

        Args:
            stateful (`bool`): if `True`, the copy also carries over the current progress.
        """
        new_constraint = PhrasalConstraint(self.token_ids)

        if stateful:
            # `seqlen` is already derived from `token_ids` by the constructor; only the progress needs copying.
            new_constraint.fulfilled_idx = self.fulfilled_idx
            new_constraint.completed = self.completed

        return new_constraint
class DisjunctiveTrie:
    def __init__(self, nested_token_ids: list[list[int]], no_subsets=True):
        r"""
        A helper class that builds a trie with the words represented in `nested_token_ids`.

        The trie is stored in `self.trie` as nested dicts. Each key is a token id and maps to the dict of tokens that
        may follow it. Leaves (empty dicts) are word ends.

        Args:
            nested_token_ids (`list[list[int]]`):
                The words, each given as a list of token ids.
            no_subsets (`bool`, defaults to `True`):
                If `True`, reject inputs where one word is a prefix of another (this includes duplicate words).

        Raises:
            ValueError: if `no_subsets` is set and some word is a prefix of another, or if `nested_token_ids` is
                empty (raised by `max()`).
        """
        # length of the longest word
        self.max_height = max(len(word) for word in nested_token_ids)

        root = {}
        for token_ids in nested_token_ids:
            level = root
            for token_id in token_ids:
                # descend into the child for `token_id`, creating it if this prefix is new
                level = level.setdefault(token_id, {})

        if no_subsets and self.has_subsets(root, nested_token_ids):
            raise ValueError(
                "Each list in `nested_token_ids` can't be a complete subset of another list, but is"
                f" {nested_token_ids}."
            )

        self.trie = root

    def next_tokens(self, current_seq):
        """
        The next possible tokens that will progress the trie, given the current sequence of tokens in `current_seq`.

        Raises:
            KeyError: if `current_seq` is not a path from the root of the trie.
        """
        node = self.trie
        for token in current_seq:
            node = node[token]
        return list(node.keys())

    def reached_leaf(self, current_seq):
        """Return whether no token can follow `current_seq`, i.e. it ends at a leaf of the trie."""
        return not self.next_tokens(current_seq)

    def count_leaves(self, root):
        """Return the number of leaves in the sub-trie `root` (an empty trie counts as one leaf)."""
        children = list(root.values())
        if not children:
            return 1
        return sum(self.count_leaves(child) for child in children)

    def has_subsets(self, trie, nested_token_ids):
        """
        Return `True` if some word in `nested_token_ids` is a prefix (or a duplicate) of another word.

        Each such word ends at an internal node or at a shared leaf. That makes the number of leaves differ from the
        number of words.
        """
        leaf_count = self.count_leaves(trie)
        return len(nested_token_ids) != leaf_count
class DisjunctiveConstraint(Constraint):
    r"""
    A special [`Constraint`] that is fulfilled by fulfilling just one of several constraints.

    Args:
        nested_token_ids (`list[list[int]]`):
            A list of words, where each word is a list of ids. This constraint is fulfilled by generating just one from
            the list of words.
    """

    def __init__(self, nested_token_ids: list[list[int]]):
        # `super(Constraint, self)` starts the method lookup *after* `Constraint` in the MRO. As a result,
        # `Constraint.__init__` (its deprecation warning and its `self.test()` self-check) is not run.
        super(Constraint, self).__init__()

        if not isinstance(nested_token_ids, list) or len(nested_token_ids) == 0:
            raise ValueError(f"`nested_token_ids` has to be a non-empty list, but is {nested_token_ids}.")
        if any(not isinstance(token_ids, list) for token_ids in nested_token_ids):
            raise ValueError(f"`nested_token_ids` has to be a list of lists, but is {nested_token_ids}.")
        if any(
            any((not isinstance(token_id, int) or token_id < 0) for token_id in token_ids)
            for token_ids in nested_token_ids
        ):
            raise ValueError(
                f"Each list in `nested_token_ids` has to be a list of positive integers, but is {nested_token_ids}."
            )

        self.trie = DisjunctiveTrie(nested_token_ids)
        self.token_ids = nested_token_ids

        self.seqlen = self.trie.max_height  # length of the longest candidate word
        self.current_seq = []  # tokens matched so far along one path of the trie
        self.completed = False

    def advance(self):
        """Return every token that continues the current path, or `None` when a word has been completed."""
        token_list = self.trie.next_tokens(self.current_seq)

        if len(token_list) == 0:
            return None
        return token_list

    def does_advance(self, token_id: int):
        """
        Return whether `token_id` continues the current path in the trie.

        Raises:
            TypeError: if `token_id` is not an `int`.
        """
        if not isinstance(token_id, int):
            raise TypeError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")

        next_tokens = self.trie.next_tokens(self.current_seq)

        return token_id in next_tokens

    def update(self, token_id: int):
        """
        Record `token_id` as generated.

        A token that continues the current path extends it. Any other token resets all progress. The constraint is
        completed once the path reaches a leaf of the trie.

        Returns:
            `tuple(bool, bool, bool)`: `(stepped, completed, reset)`, as described in `Constraint.update`.

        Raises:
            TypeError: if `token_id` is not an `int`.
        """
        if not isinstance(token_id, int):
            raise TypeError(f"`token_id` is supposed to be type `int`, but is {token_id} of type {type(token_id)}")

        stepped = False
        reset = False

        if self.does_advance(token_id):
            self.current_seq.append(token_id)
            stepped = True
        else:
            reset = True
            self.reset()

        completed = self.trie.reached_leaf(self.current_seq)
        self.completed = completed

        return stepped, completed, reset

    def reset(self):
        """Return to the freshly-constructed state (empty path)."""
        self.completed = False
        self.current_seq = []

    def remaining(self):
        """Return how many more tokens are needed, measured against the longest word unless already completed."""
        if self.completed:
            # a shorter word can complete the constraint before `seqlen` is reached
            return 0
        return self.seqlen - len(self.current_seq)

    def copy(self, stateful=False):
        """
        Return a new `DisjunctiveConstraint` over the same words.

        Args:
            stateful (`bool`): if `True`, the copy also carries over the current progress.
        """
        new_constraint = DisjunctiveConstraint(self.token_ids)

        if stateful:
            # Copy the list rather than share it. `update()` appends to `current_seq` in place, so a shared list would
            # let progress on one constraint silently leak into the other.
            new_constraint.current_seq = list(self.current_seq)
            new_constraint.completed = self.completed

        return new_constraint
class ConstraintListState:
    r"""
    A class for beam scorers to track their progress through a list of constraints.

    Progress is kept in `complete_constraints` (fulfilled), `inprogress_constraint` (partially fulfilled, at most one)
    and `pending_constraints` (not started yet). `completed` becomes `True` once nothing is pending or in progress.

    Args:
        constraints (`list[Constraint]`):
            A list of [`Constraint`] objects that must be fulfilled by the beam scorer.
    """

    def __init__(self, constraints: list[Constraint]):
        # Only used as templates: `init_state` works on copies, and `copy` hands them to the new state unchanged.
        self.constraints = constraints

        # max # of steps required to fulfill a given constraint
        self.max_seqlen = max([c.seqlen for c in constraints])
        self.n_constraints = len(constraints)

        self.init_state()

    def init_state(self):
        """Discard all progress: nothing completed, nothing in progress, fresh copies of every constraint pending."""
        self.complete_constraints = []
        self.inprogress_constraint = None
        self.pending_constraints = [constraint.copy(stateful=False) for constraint in self.constraints]
        # Clear the overall flag too. Otherwise `reset()` on a previously completed state would wipe the progress
        # above but keep reporting completion, and `add()` would then ignore every replayed token.
        self.completed = False

    def get_bank(self):
        """
        Score the progress made so far.

        Each completed constraint is worth `max_seqlen`. An in-progress constraint adds `max_seqlen` minus the number
        of steps it still needs.
        """
        bank = len(self.complete_constraints) * self.max_seqlen
        if self.inprogress_constraint is not None:
            # extra points for having a constraint mid-fulfilled
            bank += self.max_seqlen - self.inprogress_constraint.remaining()
        return bank

    def advance(self):
        """The list of tokens to generate such that we can make progress.

        By "list" we don't mean the list of tokens that will fully fulfill a constraint.

        Given constraints `c_i = {t_ij | j == # of tokens}`, if we're not in the middle of progressing through a
        specific constraint `c_i`, we return:

        `[t_k1 for k in indices of unfulfilled constraints]`

        If we are in the middle of a constraint, we return only the next token(s) of that constraint, e.g. `[t_ij]`,
        where `i` is the index of the in-progress constraint and `j` is its next step. A disjunctive constraint may
        contribute several candidate tokens.

        Though we don't care which constraint is fulfilled first, if we are in the process of fulfilling a
        constraint, that's the only one we'll return.

        Returns `None` when no constraint can suggest a token.
        """
        if self.inprogress_constraint is None:
            sources = self.pending_constraints  # "pending" == "unfulfilled yet"
        else:
            sources = [self.inprogress_constraint]

        token_list = []
        for constraint in sources:
            suggestion = constraint.advance()
            if isinstance(suggestion, int):
                token_list.append(suggestion)
            elif isinstance(suggestion, list):
                token_list.extend(suggestion)

        return token_list or None

    def reset(self, token_ids: Optional[list[int]]):
        """
        Rebuild the progress from scratch by replaying `token_ids`, the tokens generated thus far.

        Args:
            token_ids (`list[int]`, *optional*):
                Tokens to replay through `add()`. If `None`, the state is only re-initialized.
        """
        self.init_state()

        if token_ids is None:
            return

        for token in token_ids:
            # completes or steps **one** constraint
            self.add(token)

            # the entire list of constraints is fulfilled; later tokens cannot change that
            if self.completed:
                break

    def add(self, token_id: int):
        """
        Feed one generated token through the constraints.

        Returns:
            `tuple(bool, bool)`: `(complete, stepped)`. `complete` is `True` if this token completed a constraint (or
            every constraint was already fulfilled). `stepped` is `True` if this token made progress on a constraint.

        Raises:
            TypeError: if `token_id` is not an `int`.
        """
        if not isinstance(token_id, int):
            raise TypeError(f"`token_id` should be an `int`, but is `{token_id}`.")

        if self.completed:
            # every constraint is already fulfilled; nothing can step any more
            return True, False

        complete, stepped = False, False

        if self.inprogress_constraint is not None:
            # In the middle of fulfilling a constraint. If the `token_id` *does* make incremental progress on the
            # current job, simply update the state.
            stepped, complete, reset = self.inprogress_constraint.update(token_id)
            if reset:
                # 1. If the next token breaks the progress, then we must restart this constraint.
                #    e.g. constraint = "I love pies" and sequence so far is "I love" but `token_id` == "books".
                #    We don't self.init_state(): only this particular constraint is reset, not the full list.
                # NOTE(review): the breaking token is not re-checked against the pending constraints here.
                self.pending_constraints.append(self.inprogress_constraint.copy(stateful=False))
                self.inprogress_constraint = None

            if complete:
                # 2. If the next token completes the constraint, move it to the completed list and clear
                #    in-progress. If nothing is pending either, the full list of constraints is complete.
                self.complete_constraints.append(self.inprogress_constraint)
                self.inprogress_constraint = None

                if len(self.pending_constraints) == 0:
                    # we're done!
                    self.completed = True
        else:
            # Not in the middle of fulfilling a constraint. Does this `token_id` step towards any pending one?
            for cidx, pending_constraint in enumerate(self.pending_constraints):
                if not pending_constraint.does_advance(token_id):
                    continue

                stepped, complete, _ = pending_constraint.update(token_id)

                if not stepped:
                    raise Exception(
                        "`constraint.update(token_id)` is not yielding incremental progress, "
                        "even though `constraint.does_advance(token_id)` is true."
                    )

                if complete:
                    self.complete_constraints.append(pending_constraint)
                else:
                    self.inprogress_constraint = pending_constraint

                # Having made progress, it is no longer a "pending constraint". Rebind rather than mutate in place.
                self.pending_constraints = self.pending_constraints[:cidx] + self.pending_constraints[cidx + 1 :]

                if len(self.pending_constraints) == 0 and self.inprogress_constraint is None:
                    # No longer any pending and no in-progress either, so we must be complete.
                    self.completed = True

                break  # prevent accidentally stepping through multiple constraints with just one token.

        return complete, stepped

    def copy(self, stateful=True):
        """
        Create a new `ConstraintListState` over the same `constraints`.

        Args:
            stateful (`bool`, defaults to `True`):
                If `True`, also copy the current progress, including the overall `completed` flag. Otherwise the new
                state starts from scratch.
        """
        # `self.constraints` is never stepped directly (only copies of it are), so it is still in its initial state
        # and can be shared with the new instance.
        new_state = ConstraintListState(self.constraints)

        if stateful:
            new_state.complete_constraints = [
                constraint.copy(stateful=True) for constraint in self.complete_constraints
            ]
            if self.inprogress_constraint is not None:
                new_state.inprogress_constraint = self.inprogress_constraint.copy(stateful=True)
            new_state.pending_constraints = [constraint.copy() for constraint in self.pending_constraints]
            # Without this, a copy of a fully completed state would report `completed == False`.
            new_state.completed = self.completed

        return new_state
|