| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
703704705706707 |
- import inspect
- import types
- import warnings
- from collections.abc import Iterable
- from typing import TYPE_CHECKING, Optional, Union
- import numpy as np
- from ..data import SquadExample, SquadFeatures, squad_convert_examples_to_features
- from ..modelcard import ModelCard
- from ..tokenization_utils import PreTrainedTokenizer
- from ..utils import (
- PaddingStrategy,
- add_end_docstrings,
- is_tf_available,
- is_tokenizers_available,
- is_torch_available,
- logging,
- )
- from .base import ArgumentHandler, ChunkPipeline, build_pipeline_init_args
- logger = logging.get_logger(__name__)
- if TYPE_CHECKING:
- from ..modeling_tf_utils import TFPreTrainedModel
- from ..modeling_utils import PreTrainedModel
- if is_tokenizers_available():
- import tokenizers
- if is_tf_available():
- import tensorflow as tf
- from ..models.auto.modeling_tf_auto import TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
- Dataset = None
- if is_torch_available():
- import torch
- from torch.utils.data import Dataset
- from ..models.auto.modeling_auto import MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
def decode_spans(
    start: np.ndarray, end: np.ndarray, topk: int, max_answer_len: int, undesired_tokens: np.ndarray
) -> tuple:
    """
    Take the output of any `ModelForQuestionAnswering` and generate probabilities for each span to be the actual
    answer.

    In addition, it filters out some unwanted/impossible cases like answer len being greater than max_answer_len or
    answer end position being before the starting position. The method supports outputting the k-best answers through
    the topk argument.

    Args:
        start (`np.ndarray`): Individual start probabilities for each token.
        end (`np.ndarray`): Individual end probabilities for each token.
        topk (`int`): Indicates how many possible answer span(s) to extract from the model output.
        max_answer_len (`int`): Maximum size of the answer to extract from the model's output.
        undesired_tokens (`np.ndarray`): Despite its name, a mask with non-zero entries for the tokens that *can* be
            part of the answer. Only the indices along its last (token) axis are used.

    Returns:
        `tuple(np.ndarray, np.ndarray, np.ndarray)`: Start token indices, end token indices and span scores of the
        kept candidates, ordered by decreasing score (at most `topk` of them).
    """
    # Ensure we have batch axis
    if start.ndim == 1:
        start = start[None]
    if end.ndim == 1:
        end = end[None]

    # Compute the score of each tuple(start, end) to be the real answer
    outer = np.matmul(np.expand_dims(start, -1), np.expand_dims(end, 1))

    # Remove candidate with end < start and end - start > max_answer_len
    candidates = np.tril(np.triu(outer), max_answer_len - 1)

    #  Inspired by Chen & al. (https://github.com/facebookresearch/DrQA)
    scores_flat = candidates.flatten()
    if topk == 1:
        idx_sort = [np.argmax(scores_flat)]
    elif len(scores_flat) <= topk:
        # `np.argpartition` requires kth < len(array); when we want every candidate, just sort them all.
        idx_sort = np.argsort(-scores_flat)
    else:
        idx = np.argpartition(-scores_flat, topk)[0:topk]
        idx_sort = idx[np.argsort(-scores_flat[idx])]

    starts, ends = np.unravel_index(idx_sort, candidates.shape)[1:]

    # Token positions allowed in an answer. Only keep the indices along the last (token) axis: passing the whole
    # `nonzero()` tuple would also include the batch-axis indices (all 0) and wrongly whitelist token 0.
    allowed_positions = np.nonzero(undesired_tokens)[-1]
    desired_spans = np.isin(starts, allowed_positions) & np.isin(ends, allowed_positions)
    starts = starts[desired_spans]
    ends = ends[desired_spans]

    scores = candidates[0, starts, ends]

    return starts, ends, scores
def select_starts_ends(
    start,
    end,
    p_mask,
    attention_mask,
    min_null_score=1000000,
    top_k=1,
    handle_impossible_answer=False,
    max_answer_len=15,
):
    """
    Takes the raw output of any `ModelForQuestionAnswering` and first normalizes its outputs and then uses
    `decode_spans()` to generate probabilities for each span to be the actual answer.

    Args:
        start (`np.ndarray`): Individual start logits for each token.
        end (`np.ndarray`): Individual end logits for each token.
        p_mask (`np.ndarray`): A mask with 1 for values that cannot be in the answer
        attention_mask (`np.ndarray`): The attention mask generated by the tokenizer
        min_null_score(`float`): The minimum null (empty) answer score seen so far.
        top_k (`int`): Indicates how many possible answer span(s) to extract from the model output.
        handle_impossible_answer(`bool`): Whether to allow null (empty) answers
        max_answer_len (`int`): Maximum size of the answer to extract from the model's output.

    Returns:
        `tuple`: `(starts, ends, scores, min_null_score)` where the first three come from `decode_spans()` and
        `min_null_score` is the (possibly lowered) null-answer score.
    """
    # Ensure padded tokens & question tokens cannot belong to the set of candidate answers.
    # Despite the name, `undesired_tokens` holds 1 for tokens that MAY be part of the answer and 0 otherwise.
    undesired_tokens = np.abs(np.array(p_mask) - 1)

    if attention_mask is not None:
        undesired_tokens = undesired_tokens & attention_mask

    # Generate mask
    undesired_tokens_mask = undesired_tokens == 0.0

    # Make sure non-context indexes in the tensor cannot contribute to the softmax
    start = np.where(undesired_tokens_mask, -10000.0, start)
    end = np.where(undesired_tokens_mask, -10000.0, end)

    # Normalize logits and spans to retrieve the answer: a numerically stable softmax over the token (last) axis,
    # normalized per row so that every row of a multi-row input sums to 1.
    start = np.exp(start - start.max(axis=-1, keepdims=True))
    start = start / start.sum(axis=-1, keepdims=True)

    end = np.exp(end - end.max(axis=-1, keepdims=True))
    end = end / end.sum(axis=-1, keepdims=True)

    if handle_impossible_answer:
        # The (0, 0) span is used as the "no answer" candidate.
        min_null_score = min(min_null_score, (start[0, 0] * end[0, 0]).item())

    # Mask CLS
    start[0, 0] = end[0, 0] = 0.0

    starts, ends, scores = decode_spans(start, end, top_k, max_answer_len, undesired_tokens)
    return starts, ends, scores, min_null_score
class QuestionAnsweringArgumentHandler(ArgumentHandler):
    """
    QuestionAnsweringPipeline requires the user to provide multiple arguments (i.e. question & context) to be mapped to
    internal [`SquadExample`].

    QuestionAnsweringArgumentHandler manages all the possible ways to create a [`SquadExample`] from the command-line
    supplied arguments.
    """

    # NOTE(review): these loading flags look like they are meant for the pipeline class rather than the argument
    # handler; kept unchanged here in case other code reads them — confirm.
    _load_processor = False
    _load_image_processor = False
    _load_feature_extractor = False
    _load_tokenizer = True

    def normalize(self, item):
        """
        Convert a single user input into a [`SquadExample`].

        Args:
            item (`SquadExample` or `dict`): Either an already-built example (returned unchanged) or a dict holding
                non-empty `question` and `context` entries, forwarded to `QuestionAnsweringPipeline.create_sample`.

        Raises:
            KeyError: If the dict is missing the `question` or `context` key.
            ValueError: If one of those values is `None` or an empty string, or if `item` has an unsupported type.
        """
        if isinstance(item, SquadExample):
            return item
        elif isinstance(item, dict):
            for k in ["question", "context"]:
                if k not in item:
                    raise KeyError("You need to provide a dictionary with keys {question:..., context:...}")
                elif item[k] is None:
                    raise ValueError(f"`{k}` cannot be None")
                elif isinstance(item[k], str) and len(item[k]) == 0:
                    raise ValueError(f"`{k}` cannot be empty")

            return QuestionAnsweringPipeline.create_sample(**item)
        raise ValueError(f"{item} argument needs to be of type (SquadExample, dict)")

    def __call__(self, *args, **kwargs):
        """
        Gather the pipeline inputs from positional or keyword arguments and normalize them.

        Returns:
            Either the generator / `Dataset` passed by the user, untouched, or a list where each element went through
            `normalize`.

        Raises:
            ValueError: When the arguments cannot be interpreted as question/context pairs.
        """
        # Detect where the actual inputs are
        if args:
            if len(args) == 1:
                inputs = args[0]
            elif len(args) == 2 and {type(el) for el in args} == {str}:
                inputs = [{"question": args[0], "context": args[1]}]
            else:
                inputs = list(args)
        # Generic compatibility with sklearn and Keras
        # Batched data
        elif "X" in kwargs:
            warnings.warn(
                "Passing the `X` argument to the pipeline is deprecated and will be removed in v5. Inputs should be passed using the `question` and `context` keyword arguments instead.",
                FutureWarning,
            )
            inputs = kwargs["X"]
        elif "data" in kwargs:
            warnings.warn(
                "Passing the `data` argument to the pipeline is deprecated and will be removed in v5. Inputs should be passed using the `question` and `context` keyword arguments instead.",
                FutureWarning,
            )
            inputs = kwargs["data"]
        elif "question" in kwargs and "context" in kwargs:
            if isinstance(kwargs["question"], list) and isinstance(kwargs["context"], str):
                inputs = [{"question": Q, "context": kwargs["context"]} for Q in kwargs["question"]]
            elif isinstance(kwargs["question"], list) and isinstance(kwargs["context"], list):
                if len(kwargs["question"]) != len(kwargs["context"]):
                    raise ValueError("Questions and contexts don't have the same lengths")

                inputs = [{"question": Q, "context": C} for Q, C in zip(kwargs["question"], kwargs["context"])]
            elif isinstance(kwargs["question"], str) and isinstance(kwargs["context"], str):
                inputs = [{"question": kwargs["question"], "context": kwargs["context"]}]
            else:
                raise ValueError("Arguments can't be understood")
        else:
            raise ValueError(f"Unknown arguments {kwargs}")

        # When user is sending a generator we need to trust it's a valid example
        generator_types = (types.GeneratorType, Dataset) if Dataset is not None else (types.GeneratorType,)
        if isinstance(inputs, generator_types):
            return inputs

        # Normalize inputs
        if isinstance(inputs, dict):
            inputs = [inputs]
        elif isinstance(inputs, str):
            # A bare string is Iterable, but `list(...)` would split it into single characters (or into nothing at
            # all for ""), so reject it up front with the same message `normalize` uses.
            raise ValueError(f"{inputs} argument needs to be of type (SquadExample, dict)")
        elif isinstance(inputs, Iterable):
            # Copy to avoid overriding arguments
            inputs = list(inputs)
        else:
            raise ValueError(f"Invalid arguments {kwargs}")

        for i, item in enumerate(inputs):
            inputs[i] = self.normalize(item)

        return inputs
@add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
class QuestionAnsweringPipeline(ChunkPipeline):
    """
    Question Answering pipeline using any `ModelForQuestionAnswering`. See the [question answering
    examples](../task_summary#question-answering) for more information.

    Example:

    ```python
    >>> from transformers import pipeline

    >>> oracle = pipeline(model="deepset/roberta-base-squad2")
    >>> oracle(question="Where do I live?", context="My name is Wolfgang and I live in Berlin")
    {'score': 0.9191, 'start': 34, 'end': 40, 'answer': 'Berlin'}
    ```

    Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)

    This question answering pipeline can currently be loaded from [`pipeline`] using the following task identifier:
    `"question-answering"`.

    The models that this pipeline can use are models that have been fine-tuned on a question answering task. See the
    up-to-date list of available models on
    [huggingface.co/models](https://huggingface.co/models?filter=question-answering).
    """

    default_input_names = "question,context"
    handle_impossible_answer = False

    def __init__(
        self,
        model: Union["PreTrainedModel", "TFPreTrainedModel"],
        tokenizer: PreTrainedTokenizer,
        modelcard: Optional[ModelCard] = None,
        framework: Optional[str] = None,
        task: str = "",
        **kwargs,
    ):
        super().__init__(
            model=model,
            tokenizer=tokenizer,
            modelcard=modelcard,
            framework=framework,
            task=task,
            **kwargs,
        )

        self._args_parser = QuestionAnsweringArgumentHandler()
        # Validate that the loaded model is a question-answering architecture for the active framework.
        self.check_model_type(
            TF_MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
            if self.framework == "tf"
            else MODEL_FOR_QUESTION_ANSWERING_MAPPING_NAMES
        )

    @staticmethod
    def create_sample(
        question: Union[str, list[str]], context: Union[str, list[str]]
    ) -> Union[SquadExample, list[SquadExample]]:
        """
        QuestionAnsweringPipeline leverages the [`SquadExample`] internally. This helper method encapsulates all the
        logic for converting question(s) and context(s) to [`SquadExample`].

        We currently support extractive question answering.

        Arguments:
            question (`str` or `list[str]`): The question(s) asked.
            context (`str` or `list[str]`): The context(s) in which we will look for the answer. When `question` is a
                list and `context` is a single string, that context is shared by every question.

        Returns:
            One or a list of [`SquadExample`]: The corresponding [`SquadExample`] grouping question and context.

        Raises:
            ValueError: If `question` and `context` are both lists of different lengths.
        """
        if isinstance(question, list):
            if isinstance(context, str):
                # Broadcast one context over all questions; zipping a str would pair each question with one character.
                contexts = [context] * len(question)
            else:
                contexts = context
                if len(contexts) != len(question):
                    raise ValueError("Questions and contexts don't have the same lengths")
            return [SquadExample(None, q, c, None, None, None) for q, c in zip(question, contexts)]
        else:
            return SquadExample(None, question, context, None, None, None)

    def _sanitize_parameters(
        self,
        padding=None,
        topk=None,
        top_k=None,
        doc_stride=None,
        max_answer_len=None,
        max_seq_len=None,
        max_question_len=None,
        handle_impossible_answer=None,
        align_to_words=None,
        **kwargs,
    ):
        """
        Split user-provided call arguments into `(preprocess, forward, postprocess)` parameter dicts.

        Only explicitly provided (non-`None`) values are forwarded, so the stage defaults apply otherwise. The
        deprecated `topk` is mapped onto `top_k` when the latter is not given.

        Raises:
            ValueError: If `top_k` or `max_answer_len` is smaller than 1.
        """
        # Set defaults values
        preprocess_params = {}
        if padding is not None:
            preprocess_params["padding"] = padding
        if doc_stride is not None:
            preprocess_params["doc_stride"] = doc_stride
        if max_question_len is not None:
            preprocess_params["max_question_len"] = max_question_len
        if max_seq_len is not None:
            preprocess_params["max_seq_len"] = max_seq_len

        postprocess_params = {}
        if topk is not None and top_k is None:
            warnings.warn("topk parameter is deprecated, use top_k instead", UserWarning)
            top_k = topk
        if top_k is not None:
            if top_k < 1:
                raise ValueError(f"top_k parameter should be >= 1 (got {top_k})")
            postprocess_params["top_k"] = top_k
        if max_answer_len is not None:
            if max_answer_len < 1:
                raise ValueError(f"max_answer_len parameter should be >= 1 (got {max_answer_len})")
            postprocess_params["max_answer_len"] = max_answer_len
        if handle_impossible_answer is not None:
            postprocess_params["handle_impossible_answer"] = handle_impossible_answer
        if align_to_words is not None:
            postprocess_params["align_to_words"] = align_to_words
        return preprocess_params, {}, postprocess_params

    def __call__(self, *args, **kwargs):
        """
        Answer the question(s) given as inputs by using the context(s).

        Args:
            question (`str` or `list[str]`):
                One or several question(s) (must be used in conjunction with the `context` argument).
            context (`str` or `list[str]`):
                One or several context(s) associated with the question(s) (must be used in conjunction with the
                `question` argument).
            top_k (`int`, *optional*, defaults to 1):
                The number of answers to return (will be chosen by order of likelihood). Note that we return less than
                top_k answers if there are not enough options available within the context.
            doc_stride (`int`, *optional*, defaults to 128):
                If the context is too long to fit with the question for the model, it will be split in several chunks
                with some overlap. This argument controls the size of that overlap.
            max_answer_len (`int`, *optional*, defaults to 15):
                The maximum length of predicted answers (e.g., only answers with a shorter length are considered).
            max_seq_len (`int`, *optional*, defaults to 384):
                The maximum length of the total sentence (context + question) in tokens of each chunk passed to the
                model. The context will be split in several chunks (using `doc_stride` as overlap) if needed.
            max_question_len (`int`, *optional*, defaults to 64):
                The maximum length of the question after tokenization. It will be truncated if needed.
            handle_impossible_answer (`bool`, *optional*, defaults to `False`):
                Whether or not we accept impossible as an answer.
            align_to_words (`bool`, *optional*, defaults to `True`):
                Attempts to align the answer to real words. Improves quality on space separated languages. Might hurt on
                non-space-separated languages (like Japanese or Chinese)

        Return:
            A `dict` or a list of `dict`: Each result comes as a dictionary with the following keys:

            - **score** (`float`) -- The probability associated to the answer.
            - **start** (`int`) -- The character start index of the answer (in the tokenized version of the input).
            - **end** (`int`) -- The character end index of the answer (in the tokenized version of the input).
            - **answer** (`str`) -- The answer to the question.
        """

        # Convert inputs to features
        if args:
            warnings.warn(
                "Passing a list of SQuAD examples to the pipeline is deprecated and will be removed in v5. Inputs should be passed using the `question` and `context` keyword arguments instead.",
                FutureWarning,
            )

        examples = self._args_parser(*args, **kwargs)
        # A single example is unwrapped so that a single result (not a one-element list) is returned.
        if isinstance(examples, (list, tuple)) and len(examples) == 1:
            return super().__call__(examples[0], **kwargs)
        return super().__call__(examples, **kwargs)

    def preprocess(self, example, padding="do_not_pad", doc_stride=None, max_question_len=64, max_seq_len=None):
        """
        Tokenize one example into one or more overlapping spans and yield model-ready inputs for each of them.

        Args:
            example (`SquadExample` or `dict`): The question/context pair.
            padding: Padding strategy forwarded to the fast tokenizer (the slow-tokenizer path always pads to
                `max_seq_len`).
            doc_stride (`int`, *optional*): Overlap between spans; defaults to `min(max_seq_len // 2, 128)`.
            max_question_len (`int`): Only used by the slow-tokenizer path (`squad_convert_examples_to_features`).
            max_seq_len (`int`, *optional*): Defaults to `min(tokenizer.model_max_length, 384)`.

        Yields:
            `dict`: One entry per span, holding `example`, `is_last`, framework tensors (with a batch axis of 1) for
            the model input names plus `p_mask`/`token_type_ids`, and the remaining feature attributes as-is.

        Raises:
            ValueError: If `doc_stride` is larger than `max_seq_len`.
        """
        # XXX: This is special, args_parser will not handle anything generator or dataset like
        # For those we expect user to send a simple valid example either directly as a SquadExample or simple dict.
        # So we still need a little sanitation here.
        if isinstance(example, dict):
            example = SquadExample(None, example["question"], example["context"], None, None, None)

        if max_seq_len is None:
            max_seq_len = min(self.tokenizer.model_max_length, 384)
        if doc_stride is None:
            doc_stride = min(max_seq_len // 2, 128)

        if doc_stride > max_seq_len:
            raise ValueError(f"`doc_stride` ({doc_stride}) is larger than `max_seq_len` ({max_seq_len})")

        if not self.tokenizer.is_fast:
            features = squad_convert_examples_to_features(
                examples=[example],
                tokenizer=self.tokenizer,
                max_seq_length=max_seq_len,
                doc_stride=doc_stride,
                max_query_length=max_question_len,
                padding_strategy=PaddingStrategy.MAX_LENGTH,
                is_training=False,
                tqdm_enabled=False,
            )
        else:
            # Define the side we want to truncate / pad and the text/pair sorting
            question_first = self.tokenizer.padding_side == "right"

            encoded_inputs = self.tokenizer(
                text=example.question_text if question_first else example.context_text,
                text_pair=example.context_text if question_first else example.question_text,
                padding=padding,
                truncation="only_second" if question_first else "only_first",
                max_length=max_seq_len,
                stride=doc_stride,
                return_token_type_ids=True,
                return_overflowing_tokens=True,
                return_offsets_mapping=True,
                return_special_tokens_mask=True,
            )
            # When the input is too long, it's converted in a batch of inputs with overflowing tokens
            # and a stride of overlap between the inputs. If a batch of inputs is given, a special output
            # "overflow_to_sample_mapping" indicate which member of the encoded batch belong to which original batch sample.
            # Here we tokenize examples one-by-one so we don't need to use "overflow_to_sample_mapping".
            # "num_span" is the number of output samples generated from the overflowing tokens.
            num_spans = len(encoded_inputs["input_ids"])

            # p_mask: mask with 1 for tokens that cannot be in the answer (0 for tokens which can be in an answer).
            # We put 0 on the tokens from the context and 1 everywhere else (question, special and padding tokens).
            # The context is the `text_pair` (sequence id 1) when the question comes first, and the `text`
            # (sequence id 0) otherwise; comparing against the right id keeps question tokens masked in both layouts.
            context_sequence_id = 1 if question_first else 0
            p_mask = [
                [tok != context_sequence_id for tok in encoded_inputs.sequence_ids(span_id)]
                for span_id in range(num_spans)
            ]

            features = []
            for span_idx in range(num_spans):
                input_ids_span_idx = encoded_inputs["input_ids"][span_idx]
                attention_mask_span_idx = (
                    encoded_inputs["attention_mask"][span_idx] if "attention_mask" in encoded_inputs else None
                )
                token_type_ids_span_idx = (
                    encoded_inputs["token_type_ids"][span_idx] if "token_type_ids" in encoded_inputs else None
                )
                # keep the cls_token unmasked (some models use it to indicate unanswerable questions)
                if self.tokenizer.cls_token_id is not None:
                    cls_indices = np.nonzero(np.array(input_ids_span_idx) == self.tokenizer.cls_token_id)[0]
                    for cls_index in cls_indices:
                        p_mask[span_idx][cls_index] = 0
                submask = p_mask[span_idx]
                features.append(
                    SquadFeatures(
                        input_ids=input_ids_span_idx,
                        attention_mask=attention_mask_span_idx,
                        token_type_ids=token_type_ids_span_idx,
                        p_mask=submask,
                        encoding=encoded_inputs[span_idx],
                        # We don't use the rest of the values - and actually
                        # for Fast tokenizer we could totally avoid using SquadFeatures and SquadExample
                        cls_index=None,
                        token_to_orig_map={},
                        example_index=0,
                        unique_id=0,
                        paragraph_len=0,
                        token_is_max_context=0,
                        tokens=[],
                        start_position=0,
                        end_position=0,
                        is_impossible=False,
                        qas_id=None,
                    )
                )

        for i, feature in enumerate(features):
            fw_args = {}
            others = {}
            model_input_names = self.tokenizer.model_input_names + ["p_mask", "token_type_ids"]

            for k, v in feature.__dict__.items():
                if k in model_input_names:
                    if self.framework == "tf":
                        tensor = tf.constant(v)
                        if tensor.dtype == tf.int64:
                            tensor = tf.cast(tensor, tf.int32)
                        fw_args[k] = tf.expand_dims(tensor, 0)
                    elif self.framework == "pt":
                        tensor = torch.tensor(v)
                        if tensor.dtype == torch.int32:
                            tensor = tensor.long()
                        fw_args[k] = tensor.unsqueeze(0)
                else:
                    others[k] = v

            is_last = i == len(features) - 1
            yield {"example": example, "is_last": is_last, **fw_args, **others}

    def _forward(self, inputs):
        """
        Run the model on one preprocessed span and return its start/end logits together with the span's inputs.
        """
        example = inputs["example"]
        model_inputs = {k: inputs[k] for k in self.tokenizer.model_input_names}
        # `XXXForQuestionAnswering` models should not use `use_cache=True` even if it's supported
        model_forward = self.model.forward if self.framework == "pt" else self.model.call
        if "use_cache" in inspect.signature(model_forward).parameters:
            model_inputs["use_cache"] = False
        output = self.model(**model_inputs)
        if isinstance(output, dict):
            return {"start": output["start_logits"], "end": output["end_logits"], "example": example, **inputs}
        else:
            start, end = output[:2]
            return {"start": start, "end": end, "example": example, **inputs}

    def postprocess(
        self,
        model_outputs,
        top_k=1,
        handle_impossible_answer=False,
        max_answer_len=15,
        align_to_words=True,
    ):
        """
        Turn the per-span model outputs of one example into scored answer strings.

        Answers with the same text (case-insensitive) found in several spans have their scores summed. When
        `handle_impossible_answer` is set, an empty answer scored with the lowest null score seen is added.

        Returns:
            A single answer `dict` when only one answer remains after keeping the `top_k` best, else a list of them.
        """
        min_null_score = 1000000  # large and positive
        answers = []
        for output in model_outputs:
            if self.framework == "pt" and output["start"].dtype == torch.bfloat16:
                # NumPy has no bfloat16, so upcast before converting.
                start_ = output["start"].to(torch.float32)
                end_ = output["end"].to(torch.float32)
            else:
                start_ = output["start"]
                end_ = output["end"]
            example = output["example"]
            p_mask = output["p_mask"]
            attention_mask = (
                output["attention_mask"].numpy() if output.get("attention_mask", None) is not None else None
            )

            pre_topk = (
                top_k * 2 + 10 if align_to_words else top_k
            )  # Some candidates may be deleted if we align to words

            starts, ends, scores, min_null_score = select_starts_ends(
                start_,
                end_,
                p_mask,
                attention_mask,
                min_null_score,
                pre_topk,
                handle_impossible_answer,
                max_answer_len,
            )

            if not self.tokenizer.is_fast:
                char_to_word = np.array(example.char_to_word_offset)

                # Convert the answer (tokens) back to the original text
                # Score: score from the model
                # Start: Index of the first character of the answer in the context string
                # End: Index of the character following the last character of the answer in the context string
                # Answer: Plain text of the answer
                for s, e, score in zip(starts, ends, scores):
                    token_to_orig_map = output["token_to_orig_map"]
                    answers.append(
                        {
                            "score": score.item(),
                            "start": np.where(char_to_word == token_to_orig_map[s])[0][0].item(),
                            "end": np.where(char_to_word == token_to_orig_map[e])[0][-1].item(),
                            "answer": " ".join(example.doc_tokens[token_to_orig_map[s] : token_to_orig_map[e] + 1]),
                        }
                    )
            else:
                # Convert the answer (tokens) back to the original text
                # Score: score from the model
                # Start: Index of the first character of the answer in the context string
                # End: Index of the character following the last character of the answer in the context string
                # Answer: Plain text of the answer
                question_first = self.tokenizer.padding_side == "right"
                enc = output["encoding"]

                # Encoding was *not* padded, input_ids *might*.
                # It doesn't make a difference unless we're padding on
                # the left hand side, since now we have different offsets
                # everywhere.
                if self.tokenizer.padding_side == "left":
                    offset = (output["input_ids"] == self.tokenizer.pad_token_id).numpy().sum()
                else:
                    offset = 0

                # Sometimes the max probability token is in the middle of a word so:
                # - we start by finding the right word containing the token with `token_to_word`
                # - then we convert this word in a character span with `word_to_chars`
                sequence_index = 1 if question_first else 0
                for s, e, score in zip(starts, ends, scores):
                    s = s - offset
                    e = e - offset

                    start_index, end_index = self.get_indices(enc, s, e, sequence_index, align_to_words)

                    target_answer = example.context_text[start_index:end_index]
                    answer = self.get_answer(answers, target_answer)

                    if answer:
                        answer["score"] += score.item()
                    else:
                        answers.append(
                            {
                                "score": score.item(),
                                "start": start_index,
                                "end": end_index,
                                "answer": example.context_text[start_index:end_index],
                            }
                        )

        if handle_impossible_answer:
            answers.append({"score": min_null_score, "start": 0, "end": 0, "answer": ""})
        answers = sorted(answers, key=lambda x: x["score"], reverse=True)[:top_k]
        if len(answers) == 1:
            return answers[0]
        return answers

    def get_answer(self, answers: list[dict], target: str) -> Optional[dict]:
        """Return the first answer in `answers` whose text equals `target` case-insensitively, or `None`."""
        for answer in answers:
            if answer["answer"].lower() == target.lower():
                return answer
        return None

    def get_indices(
        self, enc: "tokenizers.Encoding", s: int, e: int, sequence_index: int, align_to_words: bool
    ) -> tuple[int, int]:
        """
        Map the token span `[s, e]` of `enc` to a `(start, end)` character span of the sequence `sequence_index`.

        With `align_to_words`, the span is widened to whole words; if the word lookup raises (some tokenizers don't
        handle words) it falls back to the raw token offsets.
        """
        if align_to_words:
            try:
                start_word = enc.token_to_word(s)
                end_word = enc.token_to_word(e)
                start_index = enc.word_to_chars(start_word, sequence_index=sequence_index)[0]
                end_index = enc.word_to_chars(end_word, sequence_index=sequence_index)[1]
            except Exception:
                # Some tokenizers don't really handle words. Keep to offsets then.
                start_index = enc.offsets[s][0]
                end_index = enc.offsets[e][1]
        else:
            start_index = enc.offsets[s][0]
            end_index = enc.offsets[e][1]
        return start_index, end_index

    def span_to_answer(self, text: str, start: int, end: int) -> dict[str, Union[str, int]]:
        """
        When decoding from token probabilities, this method maps token indexes to actual word in the initial context.

        Args:
            text (`str`): The actual context to extract the answer from.
            start (`int`): The answer starting token index.
            end (`int`): The answer end token index.

        Returns:
            Dictionary like `{'answer': str, 'start': int, 'end': int}`
        """
        words = []
        token_idx = char_start_idx = char_end_idx = chars_idx = 0

        for word in text.split(" "):
            token = self.tokenizer.tokenize(word)

            # Append words if they are in the span
            if start <= token_idx <= end:
                if token_idx == start:
                    char_start_idx = chars_idx

                if token_idx == end:
                    char_end_idx = chars_idx + len(word)

                words += [word]

            # Stop if we went over the end of the answer
            if token_idx > end:
                break

            # Append the subtokenization length to the running index
            token_idx += len(token)
            chars_idx += len(word) + 1

        # Join text with spaces
        return {
            "answer": " ".join(words),
            "start": max(0, char_start_idx),
            "end": min(len(text), char_end_idx),
        }
|