zero_shot_classification.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271
  1. import inspect
  2. from typing import Union
  3. import numpy as np
  4. from ..tokenization_utils import TruncationStrategy
  5. from ..utils import add_end_docstrings, logging
  6. from .base import ArgumentHandler, ChunkPipeline, build_pipeline_init_args
# Module-level logger; used below for the missing-entailment-label warning, the pad-token fallback and the
# `multi_class` deprecation notice.
logger = logging.get_logger(__name__)
  8. class ZeroShotClassificationArgumentHandler(ArgumentHandler):
  9. """
  10. Handles arguments for zero-shot for text classification by turning each possible label into an NLI
  11. premise/hypothesis pair.
  12. """
  13. def _parse_labels(self, labels):
  14. if isinstance(labels, str):
  15. labels = [label.strip() for label in labels.split(",") if label.strip()]
  16. return labels
  17. def __call__(self, sequences, labels, hypothesis_template):
  18. if len(labels) == 0 or len(sequences) == 0:
  19. raise ValueError("You must include at least one label and at least one sequence.")
  20. if hypothesis_template.format(labels[0]) == hypothesis_template:
  21. raise ValueError(
  22. f'The provided hypothesis_template "{hypothesis_template}" was not able to be formatted with the target labels. '
  23. "Make sure the passed template includes formatting syntax such as {} where the label should go."
  24. )
  25. if isinstance(sequences, str):
  26. sequences = [sequences]
  27. sequence_pairs = []
  28. for sequence in sequences:
  29. sequence_pairs.extend([[sequence, hypothesis_template.format(label)] for label in labels])
  30. return sequence_pairs, sequences
  31. @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
  32. class ZeroShotClassificationPipeline(ChunkPipeline):
  33. """
  34. NLI-based zero-shot classification pipeline using a `ModelForSequenceClassification` trained on NLI (natural
  35. language inference) tasks. Equivalent of `text-classification` pipelines, but these models don't require a
  36. hardcoded number of potential classes, they can be chosen at runtime. It usually means it's slower but it is
  37. **much** more flexible.
  38. Any combination of sequences and labels can be passed and each combination will be posed as a premise/hypothesis
  39. pair and passed to the pretrained model. Then, the logit for *entailment* is taken as the logit for the candidate
  40. label being valid. Any NLI model can be used, but the id of the *entailment* label must be included in the model
  41. config's :attr:*~transformers.PretrainedConfig.label2id*.
  42. Example:
  43. ```python
  44. >>> from transformers import pipeline
  45. >>> oracle = pipeline(model="facebook/bart-large-mnli")
  46. >>> oracle(
  47. ... "I have a problem with my iphone that needs to be resolved asap!!",
  48. ... candidate_labels=["urgent", "not urgent", "phone", "tablet", "computer"],
  49. ... )
  50. {'sequence': 'I have a problem with my iphone that needs to be resolved asap!!', 'labels': ['urgent', 'phone', 'computer', 'not urgent', 'tablet'], 'scores': [0.504, 0.479, 0.013, 0.003, 0.002]}
  51. >>> oracle(
  52. ... "I have a problem with my iphone that needs to be resolved asap!!",
  53. ... candidate_labels=["english", "german"],
  54. ... )
  55. {'sequence': 'I have a problem with my iphone that needs to be resolved asap!!', 'labels': ['english', 'german'], 'scores': [0.814, 0.186]}
  56. ```
  57. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
  58. This NLI pipeline can currently be loaded from [`pipeline`] using the following task identifier:
  59. `"zero-shot-classification"`.
  60. The models that this pipeline can use are models that have been fine-tuned on an NLI task. See the up-to-date list
  61. of available models on [huggingface.co/models](https://huggingface.co/models?search=nli).
  62. """
  63. _load_processor = False
  64. _load_image_processor = False
  65. _load_feature_extractor = False
  66. _load_tokenizer = True
  67. def __init__(self, args_parser=ZeroShotClassificationArgumentHandler(), **kwargs):
  68. self._args_parser = args_parser
  69. super().__init__(**kwargs)
  70. if self.entailment_id == -1:
  71. logger.warning(
  72. "Failed to determine 'entailment' label id from the label2id mapping in the model config. Setting to "
  73. "-1. Define a descriptive label2id mapping in the model config to ensure correct outputs."
  74. )
  75. @property
  76. def entailment_id(self):
  77. for label, ind in self.model.config.label2id.items():
  78. if label.lower().startswith("entail"):
  79. return ind
  80. return -1
  81. def _parse_and_tokenize(
  82. self, sequence_pairs, padding=True, add_special_tokens=True, truncation=TruncationStrategy.ONLY_FIRST, **kwargs
  83. ):
  84. """
  85. Parse arguments and tokenize only_first so that hypothesis (label) is not truncated
  86. """
  87. return_tensors = self.framework
  88. if self.tokenizer.pad_token is None:
  89. # Override for tokenizers not supporting padding
  90. logger.error(
  91. "Tokenizer was not supporting padding necessary for zero-shot, attempting to use "
  92. " `pad_token=eos_token`"
  93. )
  94. self.tokenizer.pad_token = self.tokenizer.eos_token
  95. try:
  96. inputs = self.tokenizer(
  97. sequence_pairs,
  98. add_special_tokens=add_special_tokens,
  99. return_tensors=return_tensors,
  100. padding=padding,
  101. truncation=truncation,
  102. )
  103. except Exception as e:
  104. if "too short" in str(e):
  105. # tokenizers might yell that we want to truncate
  106. # to a value that is not even reached by the input.
  107. # In that case we don't want to truncate.
  108. # It seems there's not a really better way to catch that
  109. # exception.
  110. inputs = self.tokenizer(
  111. sequence_pairs,
  112. add_special_tokens=add_special_tokens,
  113. return_tensors=return_tensors,
  114. padding=padding,
  115. truncation=TruncationStrategy.DO_NOT_TRUNCATE,
  116. )
  117. else:
  118. raise e
  119. return inputs
  120. def _sanitize_parameters(self, **kwargs):
  121. if kwargs.get("multi_class") is not None:
  122. kwargs["multi_label"] = kwargs["multi_class"]
  123. logger.warning(
  124. "The `multi_class` argument has been deprecated and renamed to `multi_label`. "
  125. "`multi_class` will be removed in a future version of Transformers."
  126. )
  127. preprocess_params = {}
  128. if "candidate_labels" in kwargs:
  129. preprocess_params["candidate_labels"] = self._args_parser._parse_labels(kwargs["candidate_labels"])
  130. if "hypothesis_template" in kwargs:
  131. preprocess_params["hypothesis_template"] = kwargs["hypothesis_template"]
  132. postprocess_params = {}
  133. if "multi_label" in kwargs:
  134. postprocess_params["multi_label"] = kwargs["multi_label"]
  135. return preprocess_params, {}, postprocess_params
  136. def __call__(
  137. self,
  138. sequences: Union[str, list[str]],
  139. *args,
  140. **kwargs,
  141. ):
  142. """
  143. Classify the sequence(s) given as inputs. See the [`ZeroShotClassificationPipeline`] documentation for more
  144. information.
  145. Args:
  146. sequences (`str` or `list[str]`):
  147. The sequence(s) to classify, will be truncated if the model input is too large.
  148. candidate_labels (`str` or `list[str]`):
  149. The set of possible class labels to classify each sequence into. Can be a single label, a string of
  150. comma-separated labels, or a list of labels.
  151. hypothesis_template (`str`, *optional*, defaults to `"This example is {}."`):
  152. The template used to turn each label into an NLI-style hypothesis. This template must include a {} or
  153. similar syntax for the candidate label to be inserted into the template. For example, the default
  154. template is `"This example is {}."` With the candidate label `"sports"`, this would be fed into the
  155. model like `"<cls> sequence to classify <sep> This example is sports . <sep>"`. The default template
  156. works well in many cases, but it may be worthwhile to experiment with different templates depending on
  157. the task setting.
  158. multi_label (`bool`, *optional*, defaults to `False`):
  159. Whether or not multiple candidate labels can be true. If `False`, the scores are normalized such that
  160. the sum of the label likelihoods for each sequence is 1. If `True`, the labels are considered
  161. independent and probabilities are normalized for each candidate by doing a softmax of the entailment
  162. score vs. the contradiction score.
  163. Return:
  164. A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:
  165. - **sequence** (`str`) -- The sequence for which this is the output.
  166. - **labels** (`list[str]`) -- The labels sorted by order of likelihood.
  167. - **scores** (`list[float]`) -- The probabilities for each of the labels.
  168. """
  169. if len(args) == 0:
  170. pass
  171. elif len(args) == 1 and "candidate_labels" not in kwargs:
  172. kwargs["candidate_labels"] = args[0]
  173. else:
  174. raise ValueError(f"Unable to understand extra arguments {args}")
  175. return super().__call__(sequences, **kwargs)
  176. def preprocess(self, inputs, candidate_labels=None, hypothesis_template="This example is {}."):
  177. sequence_pairs, sequences = self._args_parser(inputs, candidate_labels, hypothesis_template)
  178. for i, (candidate_label, sequence_pair) in enumerate(zip(candidate_labels, sequence_pairs)):
  179. model_input = self._parse_and_tokenize([sequence_pair])
  180. yield {
  181. "candidate_label": candidate_label,
  182. "sequence": sequences[0],
  183. "is_last": i == len(candidate_labels) - 1,
  184. **model_input,
  185. }
  186. def _forward(self, inputs):
  187. candidate_label = inputs["candidate_label"]
  188. sequence = inputs["sequence"]
  189. model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names}
  190. # `XXXForSequenceClassification` models should not use `use_cache=True` even if it's supported
  191. model_forward = self.model.forward if self.framework == "pt" else self.model.call
  192. if "use_cache" in inspect.signature(model_forward).parameters:
  193. model_inputs["use_cache"] = False
  194. outputs = self.model(**model_inputs)
  195. model_outputs = {
  196. "candidate_label": candidate_label,
  197. "sequence": sequence,
  198. "is_last": inputs["is_last"],
  199. **outputs,
  200. }
  201. return model_outputs
  202. def postprocess(self, model_outputs, multi_label=False):
  203. candidate_labels = [outputs["candidate_label"] for outputs in model_outputs]
  204. sequences = [outputs["sequence"] for outputs in model_outputs]
  205. if self.framework == "pt":
  206. logits = np.concatenate([output["logits"].float().numpy() for output in model_outputs])
  207. else:
  208. logits = np.concatenate([output["logits"].numpy() for output in model_outputs])
  209. N = logits.shape[0]
  210. n = len(candidate_labels)
  211. num_sequences = N // n
  212. reshaped_outputs = logits.reshape((num_sequences, n, -1))
  213. if multi_label or len(candidate_labels) == 1:
  214. # softmax over the entailment vs. contradiction dim for each label independently
  215. entailment_id = self.entailment_id
  216. contradiction_id = -1 if entailment_id == 0 else 0
  217. entail_contr_logits = reshaped_outputs[..., [contradiction_id, entailment_id]]
  218. scores = np.exp(entail_contr_logits) / np.exp(entail_contr_logits).sum(-1, keepdims=True)
  219. scores = scores[..., 1]
  220. else:
  221. # softmax the "entailment" logits over all candidate labels
  222. entail_logits = reshaped_outputs[..., self.entailment_id]
  223. scores = np.exp(entail_logits) / np.exp(entail_logits).sum(-1, keepdims=True)
  224. top_inds = list(reversed(scores[0].argsort()))
  225. return {
  226. "sequence": sequences[0],
  227. "labels": [candidate_labels[i] for i in top_inds],
  228. "scores": scores[0, top_inds].tolist(),
  229. }