config.py 32 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744
  1. # Copyright 2021 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import copy
  15. import dataclasses
  16. import warnings
  17. from abc import ABC, abstractmethod
  18. from collections import OrderedDict
  19. from collections.abc import Iterable, Mapping
  20. from typing import TYPE_CHECKING, Any, Callable, Optional, Union
  21. import numpy as np
  22. from packaging import version
  23. from ..utils import TensorType, is_torch_available, is_vision_available, logging
  24. from .utils import ParameterFormat, compute_effective_axis_dimension, compute_serialized_parameters_size
  25. if TYPE_CHECKING:
  26. from ..configuration_utils import PretrainedConfig
  27. from ..feature_extraction_utils import FeatureExtractionMixin
  28. from ..image_processing_utils import ImageProcessingMixin
  29. from ..tokenization_utils_base import PreTrainedTokenizerBase
  30. if is_vision_available():
  31. from PIL import Image
  32. logger = logging.get_logger(__name__)
  33. DEFAULT_ONNX_OPSET = 11
  34. # 2 Gb
  35. EXTERNAL_DATA_FORMAT_SIZE_LIMIT = 2 * 1024 * 1024 * 1024
  36. @dataclasses.dataclass
  37. class PatchingSpec:
  38. """
  39. Data class that holds patching specifications.
  40. Args:
  41. o: Module / object where the op to patch is located
  42. name: Name of the op to monkey patch
  43. custom_op: Custom op that patches the original op
  44. orig_op: Original op that is being patched
  45. op_wrapper: Wrapper (optional) that wraps both the original and custom ops.
  46. It is useful for ops that are class or static methods for instance.
  47. """
  48. o: Any
  49. name: str
  50. custom_op: Callable
  51. orig_op: Optional[Callable] = None
  52. op_wrapper: Optional[Callable] = None
  53. class OnnxConfig(ABC):
  54. """
  55. Base class for ONNX exportable model describing metadata on how to export the model through the ONNX format.
  56. """
  57. default_fixed_batch = 2
  58. default_fixed_sequence = 8
  59. default_fixed_num_choices = 4
  60. torch_onnx_minimum_version = version.parse("1.8")
  61. _tasks_to_common_outputs = {
  62. "causal-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
  63. "default": OrderedDict({"last_hidden_state": {0: "batch", 1: "sequence"}}),
  64. "image-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
  65. "image-segmentation": OrderedDict(
  66. {
  67. "logits": {0: "batch", 1: "sequence"},
  68. "pred_boxes": {0: "batch", 1: "sequence"},
  69. "pred_masks": {0: "batch", 1: "sequence"},
  70. }
  71. ),
  72. "masked-im": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
  73. "masked-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
  74. "multiple-choice": OrderedDict({"logits": {0: "batch"}}),
  75. "object-detection": OrderedDict(
  76. {
  77. "logits": {0: "batch", 1: "sequence"},
  78. "pred_boxes": {0: "batch", 1: "sequence"},
  79. }
  80. ),
  81. "question-answering": OrderedDict(
  82. {
  83. "start_logits": {0: "batch", 1: "sequence"},
  84. "end_logits": {0: "batch", 1: "sequence"},
  85. }
  86. ),
  87. "semantic-segmentation": OrderedDict({"logits": {0: "batch", 1: "num_labels", 2: "height", 3: "width"}}),
  88. "seq2seq-lm": OrderedDict({"logits": {0: "batch", 1: "decoder_sequence"}}),
  89. "sequence-classification": OrderedDict({"logits": {0: "batch"}}),
  90. "token-classification": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
  91. "vision2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
  92. "speech2seq-lm": OrderedDict({"logits": {0: "batch", 1: "sequence"}}),
  93. }
  94. def __init__(
  95. self, config: "PretrainedConfig", task: str = "default", patching_specs: Optional[list[PatchingSpec]] = None
  96. ):
  97. self._config = config
  98. if task not in self._tasks_to_common_outputs:
  99. raise ValueError(
  100. f"{task} is not a supported task, supported tasks: {self._tasks_to_common_outputs.keys()}"
  101. )
  102. self.task = task
  103. self._patching_specs = []
  104. for spec in patching_specs if patching_specs is not None else []:
  105. final_spec = spec
  106. if spec.orig_op is None:
  107. final_spec = dataclasses.replace(spec, orig_op=getattr(spec.o, spec.name))
  108. self._patching_specs.append(final_spec)
  109. @classmethod
  110. def from_model_config(cls, config: "PretrainedConfig", task: str = "default") -> "OnnxConfig":
  111. """
  112. Instantiate a OnnxConfig for a specific model
  113. Args:
  114. config: The model's configuration to use when exporting to ONNX
  115. Returns:
  116. OnnxConfig for this model
  117. """
  118. return cls(config, task=task)
  119. @property
  120. @abstractmethod
  121. def inputs(self) -> Mapping[str, Mapping[int, str]]:
  122. """
  123. Mapping containing the axis definition of the input tensors to provide to the model
  124. Returns:
  125. For each input: its name associated to the axes symbolic name and the axis position within the tensor
  126. """
  127. raise NotImplementedError()
  128. @property
  129. def outputs(self) -> Mapping[str, Mapping[int, str]]:
  130. """
  131. Mapping containing the axis definition of the output tensors to provide to the model
  132. Returns:
  133. For each output: its name associated to the axes symbolic name and the axis position within the tensor
  134. """
  135. common_outputs = self._tasks_to_common_outputs[self.task]
  136. return copy.deepcopy(common_outputs)
  137. @property
  138. def values_override(self) -> Optional[Mapping[str, Any]]:
  139. """
  140. Dictionary of keys to override in the model's config before exporting
  141. Returns:
  142. Dictionary with the keys (and their corresponding values) to override
  143. """
  144. if hasattr(self._config, "use_cache"):
  145. return {"use_cache": False}
  146. return None
  147. @property
  148. def default_batch_size(self) -> int:
  149. """
  150. The default batch size to use if no other indication
  151. Returns:
  152. Integer > 0
  153. """
  154. # Using 2 avoid ONNX making assumption about single sample batch
  155. return OnnxConfig.default_fixed_batch
  156. @property
  157. def default_sequence_length(self) -> int:
  158. """
  159. The default sequence length to use if no other indication
  160. Returns:
  161. Integer > 0
  162. """
  163. return OnnxConfig.default_fixed_sequence
  164. @property
  165. def default_num_choices(self) -> int:
  166. """
  167. The default number of choices to use if no other indication
  168. Returns:
  169. Integer > 0
  170. """
  171. return OnnxConfig.default_fixed_num_choices
  172. @property
  173. def default_onnx_opset(self) -> int:
  174. """
  175. Which onnx opset to use when exporting the model
  176. Returns:
  177. Integer ONNX Opset version
  178. """
  179. return DEFAULT_ONNX_OPSET
  180. @property
  181. def atol_for_validation(self) -> float:
  182. """
  183. What absolute tolerance value to use during model conversion validation.
  184. Returns:
  185. Float absolute tolerance value.
  186. """
  187. return 1e-5
  188. @property
  189. def is_torch_support_available(self) -> bool:
  190. """
  191. The minimum PyTorch version required to export the model.
  192. Returns:
  193. `bool`: Whether the installed version of PyTorch is compatible with the model.
  194. """
  195. if is_torch_available():
  196. from transformers.utils import get_torch_version
  197. return version.parse(get_torch_version()) >= self.torch_onnx_minimum_version
  198. else:
  199. return False
  200. @staticmethod
  201. def use_external_data_format(num_parameters: int) -> bool:
  202. """
  203. Flag indicating if the model requires using external data format
  204. Args:
  205. num_parameters: Number of parameter on the model
  206. Returns:
  207. True if model.num_parameters() * size_of(float32) >= 2Gb False otherwise
  208. """
  209. return (
  210. compute_serialized_parameters_size(num_parameters, ParameterFormat.Float)
  211. >= EXTERNAL_DATA_FORMAT_SIZE_LIMIT
  212. )
  213. def _generate_dummy_images(
  214. self, batch_size: int = 2, num_channels: int = 3, image_height: int = 40, image_width: int = 40
  215. ):
  216. images = []
  217. for _ in range(batch_size):
  218. data = np.random.rand(image_height, image_width, num_channels) * 255
  219. images.append(Image.fromarray(data.astype("uint8")).convert("RGB"))
  220. return images
  221. def _generate_dummy_audio(
  222. self, batch_size: int = 2, sampling_rate: int = 22050, time_duration: float = 5.0, frequency: int = 220
  223. ):
  224. audio_data = []
  225. for _ in range(batch_size):
  226. # time variable
  227. t = np.linspace(0, time_duration, int(time_duration * sampling_rate), endpoint=False)
  228. # generate pure sine wave at `frequency` Hz
  229. audio_data.append(0.5 * np.sin(2 * np.pi * frequency * t))
  230. return audio_data
  231. def generate_dummy_inputs(
  232. self,
  233. preprocessor: Union["PreTrainedTokenizerBase", "FeatureExtractionMixin", "ImageProcessingMixin"],
  234. batch_size: int = -1,
  235. seq_length: int = -1,
  236. num_choices: int = -1,
  237. is_pair: bool = False,
  238. framework: Optional[TensorType] = None,
  239. num_channels: int = 3,
  240. image_width: int = 40,
  241. image_height: int = 40,
  242. sampling_rate: int = 22050,
  243. time_duration: float = 5.0,
  244. frequency: int = 220,
  245. tokenizer: Optional["PreTrainedTokenizerBase"] = None,
  246. ) -> Mapping[str, Any]:
  247. """
  248. Generate inputs to provide to the ONNX exporter for the specific framework
  249. Args:
  250. preprocessor: ([`PreTrainedTokenizerBase`], [`FeatureExtractionMixin`], or [`ImageProcessingMixin`]):
  251. The preprocessor associated with this model configuration.
  252. batch_size (`int`, *optional*, defaults to -1):
  253. The batch size to export the model for (-1 means dynamic axis).
  254. num_choices (`int`, *optional*, defaults to -1):
  255. The number of candidate answers provided for multiple choice task (-1 means dynamic axis).
  256. seq_length (`int`, *optional*, defaults to -1):
  257. The sequence length to export the model for (-1 means dynamic axis).
  258. is_pair (`bool`, *optional*, defaults to `False`):
  259. Indicate if the input is a pair (sentence 1, sentence 2)
  260. framework (`TensorType`, *optional*, defaults to `None`):
  261. The framework (PyTorch or TensorFlow) that the tokenizer will generate tensors for.
  262. num_channels (`int`, *optional*, defaults to 3):
  263. The number of channels of the generated images.
  264. image_width (`int`, *optional*, defaults to 40):
  265. The width of the generated images.
  266. image_height (`int`, *optional*, defaults to 40):
  267. The height of the generated images.
  268. sampling_rate (`int`, *optional* defaults to 22050)
  269. The sampling rate for audio data generation.
  270. time_duration (`float`, *optional* defaults to 5.0)
  271. Total seconds of sampling for audio data generation.
  272. frequency (`int`, *optional* defaults to 220)
  273. The desired natural frequency of generated audio.
  274. Returns:
  275. Mapping[str, Tensor] holding the kwargs to provide to the model's forward function
  276. """
  277. from ..feature_extraction_utils import FeatureExtractionMixin
  278. from ..image_processing_utils import ImageProcessingMixin
  279. from ..tokenization_utils_base import PreTrainedTokenizerBase
  280. if isinstance(preprocessor, PreTrainedTokenizerBase) and tokenizer is not None:
  281. raise ValueError("You cannot provide both a tokenizer and a preprocessor to generate dummy inputs.")
  282. if tokenizer is not None:
  283. warnings.warn(
  284. "The `tokenizer` argument is deprecated and will be removed in version 5 of Transformers. Use"
  285. " `preprocessor` instead.",
  286. FutureWarning,
  287. )
  288. logger.warning("Overwriting the `preprocessor` argument with `tokenizer` to generate dummy inputs.")
  289. preprocessor = tokenizer
  290. if isinstance(preprocessor, PreTrainedTokenizerBase):
  291. # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
  292. batch_size = compute_effective_axis_dimension(
  293. batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
  294. )
  295. # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
  296. token_to_add = preprocessor.num_special_tokens_to_add(is_pair)
  297. seq_length = compute_effective_axis_dimension(
  298. seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
  299. )
  300. # Generate dummy inputs according to compute batch and sequence
  301. input_token = (
  302. preprocessor.unk_token
  303. if (preprocessor.unk_token is not None and len(preprocessor.unk_token) > 0)
  304. else "0"
  305. )
  306. dummy_input = [" ".join([input_token]) * seq_length] * batch_size
  307. if self.task == "multiple-choice":
  308. # If dynamic axis (-1) we forward with a fixed dimension of 4 candidate answers to avoid optimizations
  309. # made by ONNX
  310. num_choices = compute_effective_axis_dimension(
  311. num_choices, fixed_dimension=OnnxConfig.default_fixed_num_choices, num_token_to_add=0
  312. )
  313. dummy_input = dummy_input * num_choices
  314. # The shape of the tokenized inputs values is [batch_size * num_choices, seq_length]
  315. tokenized_input = preprocessor(dummy_input, text_pair=dummy_input)
  316. # Unflatten the tokenized inputs values expanding it to the shape [batch_size, num_choices, seq_length]
  317. for k, v in tokenized_input.items():
  318. tokenized_input[k] = [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
  319. return dict(tokenized_input.convert_to_tensors(tensor_type=framework))
  320. return dict(preprocessor(dummy_input, return_tensors=framework))
  321. elif isinstance(preprocessor, ImageProcessingMixin):
  322. if preprocessor.model_input_names[0] != "pixel_values":
  323. raise ValueError(
  324. f"The `preprocessor` is an image processor ({preprocessor.__class__.__name__}) and expects"
  325. f' `model_input_names[0]` to be "pixel_values", but got {preprocessor.model_input_names[0]}'
  326. )
  327. # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
  328. batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
  329. dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
  330. return dict(preprocessor(images=dummy_input, return_tensors=framework))
  331. elif isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "pixel_values":
  332. # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
  333. batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
  334. dummy_input = self._generate_dummy_images(batch_size, num_channels, image_height, image_width)
  335. return dict(preprocessor(images=dummy_input, return_tensors=framework))
  336. elif (
  337. isinstance(preprocessor, FeatureExtractionMixin) and preprocessor.model_input_names[0] == "input_features"
  338. ):
  339. # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
  340. batch_size = compute_effective_axis_dimension(batch_size, fixed_dimension=OnnxConfig.default_fixed_batch)
  341. dummy_input = self._generate_dummy_audio(batch_size, sampling_rate, time_duration, frequency)
  342. return dict(preprocessor(dummy_input, return_tensors=framework))
  343. else:
  344. raise ValueError(
  345. "Unable to generate dummy inputs for the model. Please provide a tokenizer or a preprocessor."
  346. )
  347. def generate_dummy_inputs_onnxruntime(self, reference_model_inputs: Mapping[str, Any]) -> Mapping[str, Any]:
  348. """
  349. Generate inputs for ONNX Runtime using the reference model inputs. Override this to run inference with seq2seq
  350. models which have the encoder and decoder exported as separate ONNX files.
  351. Args:
  352. reference_model_inputs ([`Mapping[str, Tensor]`):
  353. Reference inputs for the model.
  354. Returns:
  355. `Mapping[str, Tensor]`: The mapping holding the kwargs to provide to the model's forward function
  356. """
  357. return reference_model_inputs
  358. def patch_ops(self):
  359. for spec in self._patching_specs:
  360. custom_op = spec.custom_op if spec.op_wrapper is None else spec.op_wrapper(spec.custom_op)
  361. setattr(spec.o, spec.name, custom_op)
  362. def restore_ops(self):
  363. for spec in self._patching_specs:
  364. orig_op = spec.orig_op if spec.op_wrapper is None else spec.op_wrapper(spec.orig_op)
  365. setattr(spec.o, spec.name, orig_op)
  366. @classmethod
  367. def flatten_output_collection_property(cls, name: str, field: Iterable[Any]) -> dict[str, Any]:
  368. """
  369. Flatten any potential nested structure expanding the name of the field with the index of the element within the
  370. structure.
  371. Args:
  372. name: The name of the nested structure
  373. field: The structure to, potentially, be flattened
  374. Returns:
  375. (dict[str, Any]): Outputs with flattened structure and key mapping this new structure.
  376. """
  377. from itertools import chain
  378. return {f"{name}.{idx}": item for idx, item in enumerate(chain.from_iterable(field))}
  379. class OnnxConfigWithPast(OnnxConfig, ABC):
  380. def __init__(
  381. self,
  382. config: "PretrainedConfig",
  383. task: str = "default",
  384. patching_specs: Optional[list[PatchingSpec]] = None,
  385. use_past: bool = False,
  386. ):
  387. super().__init__(config, task=task, patching_specs=patching_specs)
  388. self.use_past = use_past
  389. @classmethod
  390. def with_past(cls, config: "PretrainedConfig", task: str = "default") -> "OnnxConfigWithPast":
  391. """
  392. Instantiate a OnnxConfig with `use_past` attribute set to True
  393. Args:
  394. config: The underlying model's config to use when exporting to ONNX
  395. Returns:
  396. OnnxConfig with `.use_past = True`
  397. """
  398. return cls(config, task=task, use_past=True)
  399. @property
  400. def outputs(self) -> Mapping[str, Mapping[int, str]]:
  401. common_outputs = super().outputs
  402. if self.use_past:
  403. self.fill_with_past_key_values_(common_outputs, direction="outputs")
  404. return common_outputs
  405. @property
  406. def values_override(self) -> Optional[Mapping[str, Any]]:
  407. if hasattr(self._config, "use_cache"):
  408. return {"use_cache": self.use_past}
  409. return None
  410. @property
  411. def num_layers(self) -> int:
  412. """
  413. The number of layers attribute retrieved from the model config. Override this for model configs where the
  414. number of layers attribute is not called `num_layers`.
  415. """
  416. if not hasattr(self._config, "num_layers"):
  417. raise AttributeError(
  418. "could not find the number of layers attribute in the model configuration, override the num_layers"
  419. " property of the model OnnxConfig to solve this"
  420. )
  421. return self._config.num_layers
  422. @property
  423. def num_attention_heads(self) -> int:
  424. """
  425. The number of attention heads attribute retrieved from the model config. Override this for model configs where
  426. the number of attention heads attribute is not called `num_attention_heads`.
  427. """
  428. if not hasattr(self._config, "num_attention_heads"):
  429. raise AttributeError(
  430. "could not find the number of attention heads attribute in the model configuration, override the"
  431. " num_attention_heads property of the model OnnxConfig to solve this"
  432. )
  433. return self._config.num_attention_heads
  434. def generate_dummy_inputs(
  435. self,
  436. tokenizer: "PreTrainedTokenizerBase",
  437. batch_size: int = -1,
  438. seq_length: int = -1,
  439. is_pair: bool = False,
  440. framework: Optional[TensorType] = None,
  441. ) -> Mapping[str, Any]:
  442. # TODO: should we set seq_length = 1 when self.use_past = True?
  443. common_inputs = super().generate_dummy_inputs(
  444. tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
  445. )
  446. if self.use_past:
  447. if not is_torch_available():
  448. raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
  449. else:
  450. import torch
  451. batch, seqlen = common_inputs["input_ids"].shape
  452. # Not using the same length for past_key_values
  453. past_key_values_length = seqlen + 2
  454. shape = (
  455. batch,
  456. self.num_attention_heads,
  457. past_key_values_length,
  458. self._config.hidden_size // self.num_attention_heads,
  459. )
  460. if "attention_mask" in common_inputs:
  461. mask_dtype = common_inputs["attention_mask"].dtype
  462. common_inputs["attention_mask"] = torch.cat(
  463. [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)],
  464. dim=1,
  465. )
  466. common_inputs["past_key_values"] = []
  467. for _ in range(self.num_layers):
  468. common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
  469. return common_inputs
  470. def fill_with_past_key_values_(
  471. self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str, inverted_values_shape: bool = False
  472. ):
  473. """
  474. Fill the input_or_outputs mapping with past_key_values dynamic axes considering.
  475. Args:
  476. inputs_or_outputs: The mapping to fill.
  477. direction: either "inputs" or "outputs", it specifies whether input_or_outputs is the input mapping or the
  478. output mapping, this is important for axes naming.
  479. inverted_values_shape:
  480. If `True`, store values on dynamic axis 1, else on axis 2.
  481. """
  482. if direction not in ["inputs", "outputs"]:
  483. raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
  484. name = "past_key_values" if direction == "inputs" else "present"
  485. for i in range(self.num_layers):
  486. inputs_or_outputs[f"{name}.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
  487. if inverted_values_shape:
  488. inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 1: "past_sequence + sequence"}
  489. else:
  490. inputs_or_outputs[f"{name}.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
  491. def _flatten_past_key_values_(self, flattened_output, name, idx, t):
  492. flattened_output[f"{name}.{idx}.key"] = t[0]
  493. flattened_output[f"{name}.{idx}.value"] = t[1]
  494. def flatten_output_collection_property(self, name: str, field: Iterable[Any]) -> dict[str, Any]:
  495. flattened_output = {}
  496. if name in ["present", "past_key_values"]:
  497. for idx, t in enumerate(field):
  498. self._flatten_past_key_values_(flattened_output, name, idx, t)
  499. else:
  500. flattened_output = super().flatten_output_collection_property(name, field)
  501. return flattened_output
  502. class OnnxSeq2SeqConfigWithPast(OnnxConfigWithPast):
  503. @property
  504. def outputs(self) -> Mapping[str, Mapping[int, str]]:
  505. common_outputs = super(OnnxConfigWithPast, self).outputs
  506. # Renaming the outputs axes properly.
  507. for name, axes_names in common_outputs.items():
  508. sequence_name = "encoder_sequence" if "encoder" in name else "decoder_sequence"
  509. for axis_idx, name in axes_names.items():
  510. if "sequence" in name:
  511. axes_names[axis_idx] = sequence_name
  512. # We reset the value as the order in common_outputs (OrderedDict) is lost otherwise
  513. else:
  514. axes_names[axis_idx] = name
  515. if self.use_past:
  516. self.fill_with_past_key_values_(common_outputs, direction="outputs")
  517. return common_outputs
  518. @property
  519. def num_layers(self) -> tuple[int, ...]:
  520. try:
  521. num_layers = super().num_layers
  522. num_layers = (num_layers, num_layers)
  523. except AttributeError:
  524. if hasattr(self._config, "encoder_layers") and hasattr(self._config, "decoder_layers"):
  525. num_layers = (self._config.encoder_layers, self._config.decoder_layers)
  526. else:
  527. raise AttributeError(
  528. "could not find the number of encoder and decoder layers attributes in the model configuration,"
  529. " override the num_layers property of the model OnnxConfig to solve this"
  530. )
  531. return num_layers
  532. @property
  533. def num_attention_heads(self) -> tuple[int, ...]:
  534. try:
  535. num_attention_heads = super().num_attention_heads
  536. num_attention_heads = (num_attention_heads, num_attention_heads)
  537. except AttributeError:
  538. if hasattr(self._config, "encoder_attention_heads") and hasattr(self._config, "decoder_attention_heads"):
  539. num_attention_heads = (self._config.encoder_attention_heads, self._config.decoder_attention_heads)
  540. else:
  541. raise AttributeError(
  542. "could not find the number of attention heads for the encoder and the decoder attributes in the"
  543. " model configuration, override the num_attention_heads property of the model OnnxConfig to solve"
  544. " this"
  545. )
  546. return num_attention_heads
  547. def generate_dummy_inputs(
  548. self,
  549. tokenizer: Optional["PreTrainedTokenizerBase"],
  550. batch_size: int = -1,
  551. seq_length: int = -1,
  552. is_pair: bool = False,
  553. framework: Optional[TensorType] = None,
  554. ) -> Mapping[str, Any]:
  555. encoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
  556. tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
  557. )
  558. # Generate decoder inputs
  559. decoder_seq_length = seq_length if not self.use_past else 1
  560. decoder_inputs = super(OnnxConfigWithPast, self).generate_dummy_inputs(
  561. tokenizer, batch_size=batch_size, seq_length=decoder_seq_length, is_pair=is_pair, framework=framework
  562. )
  563. decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
  564. common_inputs = dict(**encoder_inputs, **decoder_inputs)
  565. if self.use_past:
  566. if not is_torch_available():
  567. raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
  568. else:
  569. import torch
  570. batch = common_inputs["input_ids"].shape[0]
  571. encoder_seq_length = common_inputs["input_ids"].shape[1]
  572. decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
  573. num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
  574. encoder_shape = (
  575. batch,
  576. num_encoder_attention_heads,
  577. encoder_seq_length,
  578. self._config.hidden_size // num_encoder_attention_heads,
  579. )
  580. decoder_shape = (
  581. batch,
  582. num_decoder_attention_heads,
  583. # Not using the same length for past_key_values
  584. decoder_seq_length + 3,
  585. self._config.hidden_size // num_decoder_attention_heads,
  586. )
  587. common_inputs["past_key_values"] = []
  588. # If the number of encoder and decoder layers are present in the model configuration, both are considered
  589. num_encoder_layers, num_decoder_layers = self.num_layers
  590. min_num_layers = min(num_encoder_layers, num_decoder_layers)
  591. max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
  592. remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
  593. for _ in range(min_num_layers):
  594. # For encoder-decoder models, past_key_values contains pre-computed values for both the encoder and the
  595. # decoder layers, hence a tuple of 4 tensors instead of 2
  596. common_inputs["past_key_values"].append(
  597. (
  598. torch.zeros(decoder_shape),
  599. torch.zeros(decoder_shape),
  600. torch.zeros(encoder_shape),
  601. torch.zeros(encoder_shape),
  602. )
  603. )
  604. # TODO: test this.
  605. shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
  606. for _ in range(min_num_layers, max_num_layers):
  607. common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
  608. return common_inputs
  609. def fill_with_past_key_values_(self, inputs_or_outputs: Mapping[str, Mapping[int, str]], direction: str):
  610. if direction not in ["inputs", "outputs"]:
  611. raise ValueError(f'direction must either be "inputs" or "outputs", but {direction} was given')
  612. name = "past_key_values" if direction == "inputs" else "present"
  613. # If the number of encoder and decoder layers are present in the model configuration, both are considered
  614. num_encoder_layers, num_decoder_layers = self.num_layers
  615. min_num_layers = min(num_encoder_layers, num_decoder_layers)
  616. max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
  617. remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
  618. encoder_sequence = "past_encoder_sequence"
  619. decoder_sequence = "past_decoder_sequence" if direction == "inputs" else "past_decoder_sequence + sequence"
  620. for i in range(min_num_layers):
  621. inputs_or_outputs[f"{name}.{i}.decoder.key"] = {0: "batch", 2: decoder_sequence}
  622. inputs_or_outputs[f"{name}.{i}.decoder.value"] = {0: "batch", 2: decoder_sequence}
  623. inputs_or_outputs[f"{name}.{i}.encoder.key"] = {0: "batch", 2: encoder_sequence}
  624. inputs_or_outputs[f"{name}.{i}.encoder.value"] = {0: "batch", 2: encoder_sequence}
  625. for i in range(min_num_layers, max_num_layers):
  626. if remaining_side_name == "encoder":
  627. axes_info = {0: "batch", 2: encoder_sequence}
  628. else:
  629. axes_info = {0: "batch", 2: decoder_sequence}
  630. inputs_or_outputs[f"{name}.{i}.{remaining_side_name}.key"] = axes_info
  631. def _flatten_past_key_values_(self, flattened_output, name, idx, t):
  632. flattened_output[f"{name}.{idx}.decoder.key"] = t[0]
  633. flattened_output[f"{name}.{idx}.decoder.value"] = t[1]
  634. flattened_output[f"{name}.{idx}.encoder.key"] = t[2]
  635. flattened_output[f"{name}.{idx}.encoder.value"] = t[3]