table_question_answering.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457
  1. import collections
  2. import types
  3. import numpy as np
  4. from ..generation import GenerationConfig
  5. from ..utils import (
  6. add_end_docstrings,
  7. is_tf_available,
  8. is_torch_available,
  9. requires_backends,
  10. )
  11. from .base import ArgumentHandler, Dataset, Pipeline, PipelineException, build_pipeline_init_args
  12. if is_torch_available():
  13. import torch
  14. from ..models.auto.modeling_auto import (
  15. MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
  16. MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
  17. )
  18. if is_tf_available():
  19. import tensorflow as tf
  20. from ..models.auto.modeling_tf_auto import (
  21. TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES,
  22. TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES,
  23. )
  24. class TableQuestionAnsweringArgumentHandler(ArgumentHandler):
  25. """
  26. Handles arguments for the TableQuestionAnsweringPipeline
  27. """
  28. def __call__(self, table=None, query=None, **kwargs):
  29. # Returns tqa_pipeline_inputs of shape:
  30. # [
  31. # {"table": pd.DataFrame, "query": list[str]},
  32. # ...,
  33. # {"table": pd.DataFrame, "query" : list[str]}
  34. # ]
  35. requires_backends(self, "pandas")
  36. import pandas as pd
  37. if table is None:
  38. raise ValueError("Keyword argument `table` cannot be None.")
  39. elif query is None:
  40. if isinstance(table, dict) and table.get("query") is not None and table.get("table") is not None:
  41. tqa_pipeline_inputs = [table]
  42. elif isinstance(table, list) and len(table) > 0:
  43. if not all(isinstance(d, dict) for d in table):
  44. raise ValueError(
  45. f"Keyword argument `table` should be a list of dict, but is {(type(d) for d in table)}"
  46. )
  47. if table[0].get("query") is not None and table[0].get("table") is not None:
  48. tqa_pipeline_inputs = table
  49. else:
  50. raise ValueError(
  51. "If keyword argument `table` is a list of dictionaries, each dictionary should have a `table`"
  52. f" and `query` key, but only dictionary has keys {table[0].keys()} `table` and `query` keys."
  53. )
  54. elif Dataset is not None and isinstance(table, Dataset) or isinstance(table, types.GeneratorType):
  55. return table
  56. else:
  57. raise ValueError(
  58. "Invalid input. Keyword argument `table` should be either of type `dict` or `list`, but "
  59. f"is {type(table)})"
  60. )
  61. else:
  62. tqa_pipeline_inputs = [{"table": table, "query": query}]
  63. for tqa_pipeline_input in tqa_pipeline_inputs:
  64. if not isinstance(tqa_pipeline_input["table"], pd.DataFrame):
  65. if tqa_pipeline_input["table"] is None:
  66. raise ValueError("Table cannot be None.")
  67. tqa_pipeline_input["table"] = pd.DataFrame(tqa_pipeline_input["table"])
  68. return tqa_pipeline_inputs
  69. @add_end_docstrings(build_pipeline_init_args(has_tokenizer=True))
  70. class TableQuestionAnsweringPipeline(Pipeline):
  71. """
  72. Table Question Answering pipeline using a `ModelForTableQuestionAnswering`. This pipeline is only available in
  73. PyTorch.
  74. Unless the model you're using explicitly sets these generation parameters in its configuration files
  75. (`generation_config.json`), the following default values will be used:
  76. - max_new_tokens: 256
  77. Example:
  78. ```python
  79. >>> from transformers import pipeline
  80. >>> oracle = pipeline(model="google/tapas-base-finetuned-wtq")
  81. >>> table = {
  82. ... "Repository": ["Transformers", "Datasets", "Tokenizers"],
  83. ... "Stars": ["36542", "4512", "3934"],
  84. ... "Contributors": ["651", "77", "34"],
  85. ... "Programming language": ["Python", "Python", "Rust, Python and NodeJS"],
  86. ... }
  87. >>> oracle(query="How many stars does the transformers repository have?", table=table)
  88. {'answer': 'AVERAGE > 36542', 'coordinates': [(0, 1)], 'cells': ['36542'], 'aggregator': 'AVERAGE'}
  89. ```
  90. Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
  91. This tabular question answering pipeline can currently be loaded from [`pipeline`] using the following task
  92. identifier: `"table-question-answering"`.
  93. The models that this pipeline can use are models that have been fine-tuned on a tabular question answering task.
  94. See the up-to-date list of available models on
  95. [huggingface.co/models](https://huggingface.co/models?filter=table-question-answering).
  96. """
  97. default_input_names = "table,query"
  98. _pipeline_calls_generate = True
  99. _load_processor = False
  100. _load_image_processor = False
  101. _load_feature_extractor = False
  102. _load_tokenizer = True
  103. # Make sure the docstring is updated when the default generation config is changed
  104. _default_generation_config = GenerationConfig(
  105. max_new_tokens=256,
  106. )
  107. def __init__(self, args_parser=TableQuestionAnsweringArgumentHandler(), **kwargs):
  108. super().__init__(**kwargs)
  109. self._args_parser = args_parser
  110. if self.framework == "tf":
  111. mapping = TF_MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES.copy()
  112. mapping.update(TF_MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
  113. else:
  114. mapping = MODEL_FOR_TABLE_QUESTION_ANSWERING_MAPPING_NAMES.copy()
  115. mapping.update(MODEL_FOR_SEQ_TO_SEQ_CAUSAL_LM_MAPPING_NAMES)
  116. self.check_model_type(mapping)
  117. self.aggregate = getattr(self.model.config, "aggregation_labels", None) and getattr(
  118. self.model.config, "num_aggregation_labels", None
  119. )
  120. self.type = "tapas" if hasattr(self.model.config, "aggregation_labels") else None
  121. def batch_inference(self, **inputs):
  122. return self.model(**inputs)
  123. def sequential_inference(self, **inputs):
  124. """
  125. Inference used for models that need to process sequences in a sequential fashion, like the SQA models which
  126. handle conversational query related to a table.
  127. """
  128. if self.framework == "pt":
  129. all_logits = []
  130. all_aggregations = []
  131. prev_answers = None
  132. batch_size = inputs["input_ids"].shape[0]
  133. input_ids = inputs["input_ids"].to(self.device)
  134. attention_mask = inputs["attention_mask"].to(self.device)
  135. token_type_ids = inputs["token_type_ids"].to(self.device)
  136. token_type_ids_example = None
  137. for index in range(batch_size):
  138. # If sequences have already been processed, the token type IDs will be created according to the previous
  139. # answer.
  140. if prev_answers is not None:
  141. prev_labels_example = token_type_ids_example[:, 3] # shape (seq_len,)
  142. model_labels = np.zeros_like(prev_labels_example.cpu().numpy()) # shape (seq_len,)
  143. token_type_ids_example = token_type_ids[index] # shape (seq_len, 7)
  144. for i in range(model_labels.shape[0]):
  145. segment_id = token_type_ids_example[:, 0].tolist()[i]
  146. col_id = token_type_ids_example[:, 1].tolist()[i] - 1
  147. row_id = token_type_ids_example[:, 2].tolist()[i] - 1
  148. if row_id >= 0 and col_id >= 0 and segment_id == 1:
  149. model_labels[i] = int(prev_answers[(col_id, row_id)])
  150. token_type_ids_example[:, 3] = torch.from_numpy(model_labels).type(torch.long).to(self.device)
  151. input_ids_example = input_ids[index]
  152. attention_mask_example = attention_mask[index] # shape (seq_len,)
  153. token_type_ids_example = token_type_ids[index] # shape (seq_len, 7)
  154. outputs = self.model(
  155. input_ids=input_ids_example.unsqueeze(0),
  156. attention_mask=attention_mask_example.unsqueeze(0),
  157. token_type_ids=token_type_ids_example.unsqueeze(0),
  158. )
  159. logits = outputs.logits
  160. if self.aggregate:
  161. all_aggregations.append(outputs.logits_aggregation)
  162. all_logits.append(logits)
  163. dist_per_token = torch.distributions.Bernoulli(logits=logits)
  164. probabilities = dist_per_token.probs * attention_mask_example.type(torch.float32).to(
  165. dist_per_token.probs.device
  166. )
  167. coords_to_probs = collections.defaultdict(list)
  168. for i, p in enumerate(probabilities.squeeze().tolist()):
  169. segment_id = token_type_ids_example[:, 0].tolist()[i]
  170. col = token_type_ids_example[:, 1].tolist()[i] - 1
  171. row = token_type_ids_example[:, 2].tolist()[i] - 1
  172. if col >= 0 and row >= 0 and segment_id == 1:
  173. coords_to_probs[(col, row)].append(p)
  174. prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs}
  175. logits_batch = torch.cat(tuple(all_logits), 0)
  176. return (logits_batch,) if not self.aggregate else (logits_batch, torch.cat(tuple(all_aggregations), 0))
  177. else:
  178. all_logits = []
  179. all_aggregations = []
  180. prev_answers = None
  181. batch_size = inputs["input_ids"].shape[0]
  182. input_ids = inputs["input_ids"]
  183. attention_mask = inputs["attention_mask"]
  184. token_type_ids = inputs["token_type_ids"].numpy()
  185. token_type_ids_example = None
  186. for index in range(batch_size):
  187. # If sequences have already been processed, the token type IDs will be created according to the previous
  188. # answer.
  189. if prev_answers is not None:
  190. prev_labels_example = token_type_ids_example[:, 3] # shape (seq_len,)
  191. model_labels = np.zeros_like(prev_labels_example, dtype=np.int32) # shape (seq_len,)
  192. token_type_ids_example = token_type_ids[index] # shape (seq_len, 7)
  193. for i in range(model_labels.shape[0]):
  194. segment_id = token_type_ids_example[:, 0].tolist()[i]
  195. col_id = token_type_ids_example[:, 1].tolist()[i] - 1
  196. row_id = token_type_ids_example[:, 2].tolist()[i] - 1
  197. if row_id >= 0 and col_id >= 0 and segment_id == 1:
  198. model_labels[i] = int(prev_answers[(col_id, row_id)])
  199. token_type_ids_example[:, 3] = model_labels
  200. input_ids_example = input_ids[index]
  201. attention_mask_example = attention_mask[index] # shape (seq_len,)
  202. token_type_ids_example = token_type_ids[index] # shape (seq_len, 7)
  203. outputs = self.model(
  204. input_ids=np.expand_dims(input_ids_example, axis=0),
  205. attention_mask=np.expand_dims(attention_mask_example, axis=0),
  206. token_type_ids=np.expand_dims(token_type_ids_example, axis=0),
  207. )
  208. logits = outputs.logits
  209. if self.aggregate:
  210. all_aggregations.append(outputs.logits_aggregation)
  211. all_logits.append(logits)
  212. probabilities = tf.math.sigmoid(tf.cast(logits, tf.float32)) * tf.cast(
  213. attention_mask_example, tf.float32
  214. )
  215. coords_to_probs = collections.defaultdict(list)
  216. for i, p in enumerate(tf.squeeze(probabilities).numpy().tolist()):
  217. segment_id = token_type_ids_example[:, 0].tolist()[i]
  218. col = token_type_ids_example[:, 1].tolist()[i] - 1
  219. row = token_type_ids_example[:, 2].tolist()[i] - 1
  220. if col >= 0 and row >= 0 and segment_id == 1:
  221. coords_to_probs[(col, row)].append(p)
  222. prev_answers = {key: np.array(coords_to_probs[key]).mean() > 0.5 for key in coords_to_probs}
  223. logits_batch = tf.concat(tuple(all_logits), 0)
  224. return (logits_batch,) if not self.aggregate else (logits_batch, tf.concat(tuple(all_aggregations), 0))
  225. def __call__(self, *args, **kwargs):
  226. r"""
  227. Answers queries according to a table. The pipeline accepts several types of inputs which are detailed below:
  228. - `pipeline(table, query)`
  229. - `pipeline(table, [query])`
  230. - `pipeline(table=table, query=query)`
  231. - `pipeline(table=table, query=[query])`
  232. - `pipeline({"table": table, "query": query})`
  233. - `pipeline({"table": table, "query": [query]})`
  234. - `pipeline([{"table": table, "query": query}, {"table": table, "query": query}])`
  235. The `table` argument should be a dict or a DataFrame built from that dict, containing the whole table:
  236. Example:
  237. ```python
  238. data = {
  239. "actors": ["brad pitt", "leonardo di caprio", "george clooney"],
  240. "age": ["56", "45", "59"],
  241. "number of movies": ["87", "53", "69"],
  242. "date of birth": ["7 february 1967", "10 june 1996", "28 november 1967"],
  243. }
  244. ```
  245. This dictionary can be passed in as such, or can be converted to a pandas DataFrame:
  246. Example:
  247. ```python
  248. import pandas as pd
  249. table = pd.DataFrame.from_dict(data)
  250. ```
  251. Args:
  252. table (`pd.DataFrame` or `Dict`):
  253. Pandas DataFrame or dictionary that will be converted to a DataFrame containing all the table values.
  254. See above for an example of dictionary.
  255. query (`str` or `list[str]`):
  256. Query or list of queries that will be sent to the model alongside the table.
  257. sequential (`bool`, *optional*, defaults to `False`):
  258. Whether to do inference sequentially or as a batch. Batching is faster, but models like SQA require the
  259. inference to be done sequentially to extract relations within sequences, given their conversational
  260. nature.
  261. padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
  262. Activates and controls padding. Accepts the following values:
  263. - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
  264. sequence if provided).
  265. - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
  266. acceptable input length for the model if that argument is not provided.
  267. - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
  268. lengths).
  269. truncation (`bool`, `str` or [`TapasTruncationStrategy`], *optional*, defaults to `False`):
  270. Activates and controls truncation. Accepts the following values:
  271. - `True` or `'drop_rows_to_fit'`: Truncate to a maximum length specified with the argument `max_length`
  272. or to the maximum acceptable input length for the model if that argument is not provided. This will
  273. truncate row by row, removing rows from the table.
  274. - `False` or `'do_not_truncate'` (default): No truncation (i.e., can output batch with sequence lengths
  275. greater than the model maximum admissible input size).
  276. Return:
  277. A dictionary or a list of dictionaries containing results: Each result is a dictionary with the following
  278. keys:
  279. - **answer** (`str`) -- The answer of the query given the table. If there is an aggregator, the answer will
  280. be preceded by `AGGREGATOR >`.
  281. - **coordinates** (`list[tuple[int, int]]`) -- Coordinates of the cells of the answers.
  282. - **cells** (`list[str]`) -- List of strings made up of the answer cell values.
  283. - **aggregator** (`str`) -- If the model has an aggregator, this returns the aggregator.
  284. """
  285. pipeline_inputs = self._args_parser(*args, **kwargs)
  286. results = super().__call__(pipeline_inputs, **kwargs)
  287. if len(results) == 1:
  288. return results[0]
  289. return results
  290. def _sanitize_parameters(self, sequential=None, padding=None, truncation=None, **kwargs):
  291. preprocess_params = {}
  292. if padding is not None:
  293. preprocess_params["padding"] = padding
  294. if truncation is not None:
  295. preprocess_params["truncation"] = truncation
  296. forward_params = {}
  297. if sequential is not None:
  298. forward_params["sequential"] = sequential
  299. if getattr(self, "assistant_model", None) is not None:
  300. forward_params["assistant_model"] = self.assistant_model
  301. if getattr(self, "assistant_tokenizer", None) is not None:
  302. forward_params["tokenizer"] = self.tokenizer
  303. forward_params["assistant_tokenizer"] = self.assistant_tokenizer
  304. return preprocess_params, forward_params, {}
  305. def preprocess(self, pipeline_input, padding=True, truncation=None):
  306. if truncation is None:
  307. if self.type == "tapas":
  308. truncation = "drop_rows_to_fit"
  309. else:
  310. truncation = "do_not_truncate"
  311. table, query = pipeline_input["table"], pipeline_input["query"]
  312. if table.empty:
  313. raise ValueError("table is empty")
  314. if query is None or query == "":
  315. raise ValueError("query is empty")
  316. inputs = self.tokenizer(table, query, return_tensors=self.framework, truncation=truncation, padding=padding)
  317. inputs["table"] = table
  318. return inputs
  319. def _forward(self, model_inputs, sequential=False, **generate_kwargs):
  320. table = model_inputs.pop("table")
  321. if self.type == "tapas":
  322. if sequential:
  323. outputs = self.sequential_inference(**model_inputs)
  324. else:
  325. outputs = self.batch_inference(**model_inputs)
  326. else:
  327. # User-defined `generation_config` passed to the pipeline call take precedence
  328. if "generation_config" not in generate_kwargs:
  329. generate_kwargs["generation_config"] = self.generation_config
  330. outputs = self.model.generate(**model_inputs, **generate_kwargs)
  331. model_outputs = {"model_inputs": model_inputs, "table": table, "outputs": outputs}
  332. return model_outputs
  333. def postprocess(self, model_outputs):
  334. inputs = model_outputs["model_inputs"]
  335. table = model_outputs["table"]
  336. outputs = model_outputs["outputs"]
  337. if self.type == "tapas":
  338. if self.aggregate:
  339. logits, logits_agg = outputs[:2]
  340. predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits, logits_agg)
  341. answer_coordinates_batch, agg_predictions = predictions
  342. aggregators = {i: self.model.config.aggregation_labels[pred] for i, pred in enumerate(agg_predictions)}
  343. no_agg_label_index = self.model.config.no_aggregation_label_index
  344. aggregators_prefix = {
  345. i: aggregators[i] + " > " for i, pred in enumerate(agg_predictions) if pred != no_agg_label_index
  346. }
  347. else:
  348. logits = outputs[0]
  349. predictions = self.tokenizer.convert_logits_to_predictions(inputs, logits)
  350. answer_coordinates_batch = predictions[0]
  351. aggregators = {}
  352. aggregators_prefix = {}
  353. answers = []
  354. for index, coordinates in enumerate(answer_coordinates_batch):
  355. cells = [table.iat[coordinate] for coordinate in coordinates]
  356. aggregator = aggregators.get(index, "")
  357. aggregator_prefix = aggregators_prefix.get(index, "")
  358. answer = {
  359. "answer": aggregator_prefix + ", ".join(cells),
  360. "coordinates": coordinates,
  361. "cells": [table.iat[coordinate] for coordinate in coordinates],
  362. }
  363. if aggregator:
  364. answer["aggregator"] = aggregator
  365. answers.append(answer)
  366. if len(answer) == 0:
  367. raise PipelineException("Table question answering", self.model.name_or_path, "Empty answer")
  368. else:
  369. answers = [{"answer": answer} for answer in self.tokenizer.batch_decode(outputs, skip_special_tokens=True)]
  370. return answers if len(answers) > 1 else answers[0]