configuration_bart.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406
  1. # coding=utf-8
  2. # Copyright 2021 The Fairseq Authors and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """BART model configuration"""
  16. import warnings
  17. from collections import OrderedDict
  18. from collections.abc import Mapping
  19. from typing import Any, Optional
  20. from ... import PreTrainedTokenizer
  21. from ...configuration_utils import PretrainedConfig
  22. from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
  23. from ...onnx.utils import compute_effective_axis_dimension
  24. from ...utils import TensorType, is_torch_available, logging
  logger = logging.get_logger(__name__)


  class BartConfig(PretrainedConfig):
      r"""
      This is the configuration class to store the configuration of a [`BartModel`]. It is used to instantiate a BART
      model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
      defaults will yield a similar configuration to that of the BART
      [facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.

      Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
      documentation from [`PretrainedConfig`] for more information.


      Args:
          vocab_size (`int`, *optional*, defaults to 50265):
              Vocabulary size of the BART model. Defines the number of different tokens that can be represented by the
              `input_ids` passed when calling [`BartModel`] or [`TFBartModel`].
          d_model (`int`, *optional*, defaults to 1024):
              Dimensionality of the layers and the pooler layer.
          encoder_layers (`int`, *optional*, defaults to 12):
              Number of encoder layers.
          decoder_layers (`int`, *optional*, defaults to 12):
              Number of decoder layers.
          encoder_attention_heads (`int`, *optional*, defaults to 16):
              Number of attention heads for each attention layer in the Transformer encoder.
          decoder_attention_heads (`int`, *optional*, defaults to 16):
              Number of attention heads for each attention layer in the Transformer decoder.
          decoder_ffn_dim (`int`, *optional*, defaults to 4096):
              Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
          encoder_ffn_dim (`int`, *optional*, defaults to 4096):
              Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
          activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
              The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
              `"relu"`, `"silu"` and `"gelu_new"` are supported.
          dropout (`float`, *optional*, defaults to 0.1):
              The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
          attention_dropout (`float`, *optional*, defaults to 0.0):
              The dropout ratio for the attention probabilities.
          activation_dropout (`float`, *optional*, defaults to 0.0):
              The dropout ratio for activations inside the fully connected layer.
          classifier_dropout (`float`, *optional*, defaults to 0.0):
              The dropout ratio for classifier.
          max_position_embeddings (`int`, *optional*, defaults to 1024):
              The maximum sequence length that this model might ever be used with. Typically set this to something large
              just in case (e.g., 512 or 1024 or 2048).
          init_std (`float`, *optional*, defaults to 0.02):
              The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
          encoder_layerdrop (`float`, *optional*, defaults to 0.0):
              The LayerDrop probability for the encoder. See the
              [LayerDrop paper](https://huggingface.co/papers/1909.11556) for more details.
          decoder_layerdrop (`float`, *optional*, defaults to 0.0):
              The LayerDrop probability for the decoder. See the
              [LayerDrop paper](https://huggingface.co/papers/1909.11556) for more details.
          scale_embedding (`bool`, *optional*, defaults to `False`):
              Whether to scale the token embeddings by a factor of sqrt(d_model).
          use_cache (`bool`, *optional*, defaults to `True`):
              Whether or not the model should return the last key/values attentions (not used by all models).
          num_labels (`int`, *optional*, defaults to 3):
              The number of labels to use in [`BartForSequenceClassification`].
          pad_token_id (`int`, *optional*, defaults to 1):
              Id of the padding token. Forwarded to [`PretrainedConfig`].
          bos_token_id (`int`, *optional*, defaults to 0):
              Id of the beginning-of-sequence token. Forwarded to [`PretrainedConfig`].
          eos_token_id (`int`, *optional*, defaults to 2):
              Id of the end-of-sequence token. Forwarded to [`PretrainedConfig`].
          is_encoder_decoder (`bool`, *optional*, defaults to `True`):
              Whether the model is an encoder-decoder. Forwarded to [`PretrainedConfig`].
          decoder_start_token_id (`int`, *optional*, defaults to 2):
              Id of the token the decoder starts from. Forwarded to [`PretrainedConfig`].
          forced_eos_token_id (`int`, *optional*, defaults to 2):
              The id of the token to force as the last generated token when `max_length` is reached. Usually set to
              `eos_token_id`.
          kwargs (*optional*):
              Remaining keyword arguments are forwarded to [`PretrainedConfig`]. For backward compatibility, a legacy
              `force_bos_token_to_be_generated=True` sets `forced_bos_token_id` to `bos_token_id` (and emits a warning)
              when `forced_bos_token_id` is not otherwise set.

      Example:

      ```python
      >>> from transformers import BartConfig, BartModel

      >>> # Initializing a BART facebook/bart-large style configuration
      >>> configuration = BartConfig()

      >>> # Initializing a model (with random weights) from the facebook/bart-large style configuration
      >>> model = BartModel(configuration)

      >>> # Accessing the model configuration
      >>> configuration = model.config
      ```"""

      model_type = "bart"
      # Output keys to ignore at inference time (consumed by code outside this file).
      keys_to_ignore_at_inference = ["past_key_values"]
      # Generic attribute names aliased to their BART-specific counterparts (e.g. `hidden_size` -> `d_model`).
      attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}

      def __init__(
          self,
          vocab_size=50265,
          max_position_embeddings=1024,
          encoder_layers=12,
          encoder_ffn_dim=4096,
          encoder_attention_heads=16,
          decoder_layers=12,
          decoder_ffn_dim=4096,
          decoder_attention_heads=16,
          encoder_layerdrop=0.0,
          decoder_layerdrop=0.0,
          activation_function="gelu",
          d_model=1024,
          dropout=0.1,
          attention_dropout=0.0,
          activation_dropout=0.0,
          init_std=0.02,
          classifier_dropout=0.0,
          scale_embedding=False,
          use_cache=True,
          num_labels=3,
          pad_token_id=1,
          bos_token_id=0,
          eos_token_id=2,
          is_encoder_decoder=True,
          decoder_start_token_id=2,
          forced_eos_token_id=2,
          **kwargs,
      ):
          self.vocab_size = vocab_size
          self.max_position_embeddings = max_position_embeddings
          self.d_model = d_model
          self.encoder_ffn_dim = encoder_ffn_dim
          self.encoder_layers = encoder_layers
          self.encoder_attention_heads = encoder_attention_heads
          self.decoder_ffn_dim = decoder_ffn_dim
          self.decoder_layers = decoder_layers
          self.decoder_attention_heads = decoder_attention_heads
          self.dropout = dropout
          self.attention_dropout = attention_dropout
          self.activation_dropout = activation_dropout
          self.activation_function = activation_function
          self.init_std = init_std
          self.encoder_layerdrop = encoder_layerdrop
          self.decoder_layerdrop = decoder_layerdrop
          self.classifier_dropout = classifier_dropout
          self.use_cache = use_cache
          # Generic layer-count attribute; mirrors `encoder_layers`.
          self.num_hidden_layers = encoder_layers
          self.scale_embedding = scale_embedding  # scale factor will be sqrt(d_model) if True

          super().__init__(
              num_labels=num_labels,
              pad_token_id=pad_token_id,
              bos_token_id=bos_token_id,
              eos_token_id=eos_token_id,
              is_encoder_decoder=is_encoder_decoder,
              decoder_start_token_id=decoder_start_token_id,
              forced_eos_token_id=forced_eos_token_id,
              **kwargs,
          )

          # ensure backward compatibility for BART CNN models: older configs used the boolean
          # `force_bos_token_to_be_generated` instead of an explicit `forced_bos_token_id`.
          if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
              self.forced_bos_token_id = self.bos_token_id
              warnings.warn(
                  f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
                  "The config can simply be saved and uploaded again to be fixed."
              )
  class BartOnnxConfig(OnnxSeq2SeqConfigWithPast):
      """ONNX export configuration for BART.

      The declared inputs/outputs and the generated dummy inputs depend on `self.task`:

      - `"default"` / `"seq2seq-lm"`: encoder and decoder inputs, optionally with past key/values.
      - `"causal-lm"`: encoder-side inputs only (`input_ids`, `attention_mask`), optionally with past key/values
        sized by the encoder layer count.
      - any other task: encoder and decoder token inputs, without past key/values in the declared inputs.
      """

      @property
      def inputs(self) -> Mapping[str, Mapping[int, str]]:
          """Dynamic axes of the ONNX inputs, keyed by input name (axis index -> symbolic axis name)."""
          if self.task in ["default", "seq2seq-lm"]:
              common_inputs = OrderedDict(
                  [
                      ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                      ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                  ]
              )

              if self.use_past:
                  # With past key/values the decoder input's sequence axis is left static (presumably one new
                  # token per step), while the mask spans past + current positions.
                  common_inputs["decoder_input_ids"] = {0: "batch"}
                  common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
              else:
                  common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
                  common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}

              if self.use_past:
                  self.fill_with_past_key_values_(common_inputs, direction="inputs")
          elif self.task == "causal-lm":
              # TODO: figure this case out.
              common_inputs = OrderedDict(
                  [
                      ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                      ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                  ]
              )
              if self.use_past:
                  num_encoder_layers, _ = self.num_layers
                  for i in range(num_encoder_layers):
                      common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                      common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
          else:
              common_inputs = OrderedDict(
                  [
                      ("input_ids", {0: "batch", 1: "encoder_sequence"}),
                      ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
                      ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
                      ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
                  ]
              )

          return common_inputs

      @property
      def outputs(self) -> Mapping[str, Mapping[int, str]]:
          """Dynamic axes of the ONNX outputs, adding `present.*` key/values for non-seq2seq tasks with `use_past`."""
          if self.task in ["default", "seq2seq-lm"]:
              common_outputs = super().outputs
          else:
              # Start the MRO lookup after `OnnxConfigWithPast`, skipping the past-aware parents' `outputs`.
              common_outputs = super(OnnxConfigWithPast, self).outputs
              if self.use_past:
                  num_encoder_layers, _ = self.num_layers
                  for i in range(num_encoder_layers):
                      common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
                      common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
          return common_outputs

      def _generate_dummy_inputs_for_default_and_seq2seq_lm(
          self,
          tokenizer: PreTrainedTokenizer,
          batch_size: int = -1,
          seq_length: int = -1,
          is_pair: bool = False,
          framework: Optional[TensorType] = None,
      ) -> Mapping[str, Any]:
          """Build encoder + decoder dummy inputs, plus zero-filled `past_key_values` when `use_past` is set.

          When `use_past` is set, one past entry is produced per layer of the deeper side: layers shared by encoder
          and decoder get a 4-tuple, and the extra layers of the deeper side get a 2-tuple.

          Raises:
              ValueError: if `use_past` is set but PyTorch is not installed.
          """
          encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
              tokenizer, batch_size, seq_length, is_pair, framework
          )

          # Generate decoder inputs. With `use_past`, request a 1-token decoder sequence (the real length is re-read
          # from the tensor below).
          decoder_seq_length = seq_length if not self.use_past else 1
          decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
              tokenizer, batch_size, decoder_seq_length, is_pair, framework
          )
          decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
          common_inputs = dict(**encoder_inputs, **decoder_inputs)

          if self.use_past:
              if not is_torch_available():
                  raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
              import torch

              # Assumes `framework` produced torch tensors (`.shape`, `torch.cat` below).
              batch, encoder_seq_length = common_inputs["input_ids"].shape
              decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
              num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
              encoder_shape = (
                  batch,
                  num_encoder_attention_heads,
                  encoder_seq_length,
                  self._config.hidden_size // num_encoder_attention_heads,
              )
              # Deliberately differs from the current decoder length (as in the causal-LM path).
              decoder_past_length = decoder_seq_length + 3
              decoder_shape = (
                  batch,
                  num_decoder_attention_heads,
                  decoder_past_length,
                  self._config.hidden_size // num_decoder_attention_heads,
              )

              # Extend the decoder mask over the past positions. Keep the tokenizer's mask dtype (as the causal-LM
              # path does) instead of mixing in the default float dtype of `torch.ones`.
              mask_dtype = common_inputs["decoder_attention_mask"].dtype
              common_inputs["decoder_attention_mask"] = torch.cat(
                  [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length, dtype=mask_dtype)],
                  dim=1,
              )

              common_inputs["past_key_values"] = []
              # If the number of encoder and decoder layers are present in the model configuration, both are considered
              num_encoder_layers, num_decoder_layers = self.num_layers
              min_num_layers = min(num_encoder_layers, num_decoder_layers)
              # Number of layers that exist only on the deeper side (encoder or decoder).
              num_remaining_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
              remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"

              for _ in range(min_num_layers):
                  # (decoder-shaped key, value, encoder-shaped key, value); presumably the self-attention cache
                  # followed by the cross-attention cache.
                  common_inputs["past_key_values"].append(
                      (
                          torch.zeros(decoder_shape),
                          torch.zeros(decoder_shape),
                          torch.zeros(encoder_shape),
                          torch.zeros(encoder_shape),
                      )
                  )
              # TODO: test this.
              shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
              for _ in range(num_remaining_layers):
                  common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))

          return common_inputs

      def _generate_dummy_inputs_for_causal_lm(
          self,
          tokenizer: PreTrainedTokenizer,
          batch_size: int = -1,
          seq_length: int = -1,
          is_pair: bool = False,
          framework: Optional[TensorType] = None,
      ) -> Mapping[str, Any]:
          """Build encoder-side dummy inputs, plus zero-filled `past_key_values` (one pair per encoder layer) when
          `use_past` is set.

          Raises:
              ValueError: if `use_past` is set but PyTorch is not installed.
          """
          common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
              tokenizer, batch_size, seq_length, is_pair, framework
          )

          if self.use_past:
              if not is_torch_available():
                  raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
              import torch

              batch, seqlen = common_inputs["input_ids"].shape
              # Not using the same length for past_key_values
              past_key_values_length = seqlen + 2
              num_encoder_layers, _ = self.num_layers
              num_encoder_attention_heads, _ = self.num_attention_heads
              past_shape = (
                  batch,
                  num_encoder_attention_heads,
                  past_key_values_length,
                  self._config.hidden_size // num_encoder_attention_heads,
              )

              # Extend the mask over the past positions, keeping the tokenizer's mask dtype.
              mask_dtype = common_inputs["attention_mask"].dtype
              common_inputs["attention_mask"] = torch.cat(
                  [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
              )
              common_inputs["past_key_values"] = [
                  (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
              ]
          return common_inputs

      def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
          self,
          tokenizer: PreTrainedTokenizer,
          batch_size: int = -1,
          seq_length: int = -1,
          is_pair: bool = False,
          framework: Optional[TensorType] = None,
      ) -> Mapping[str, Any]:
          """Tokenize a batch of unknown-token strings as plain dummy inputs.

          Dynamic axes (-1) for `batch_size` / `seq_length` are replaced by `OnnxConfig.default_fixed_batch` /
          `OnnxConfig.default_fixed_sequence` (the sequence size is reduced by the number of special tokens the
          tokenizer adds). Returns a plain dict of whatever the tokenizer produced for `framework`.
          """
          # Copied from OnnxConfig.generate_dummy_inputs
          # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
          # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
          batch_size = compute_effective_axis_dimension(
              batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
          )

          # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
          token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
          seq_length = compute_effective_axis_dimension(
              seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
          )

          # Generate dummy inputs according to compute batch and sequence.
          # NOTE: `" ".join([unk])` is just `unk`, so the unknown tokens are concatenated without separators; kept
          # as-is because this block is marked as copied from `OnnxConfig.generate_dummy_inputs`.
          dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
          common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
          return common_inputs

      def generate_dummy_inputs(
          self,
          tokenizer: PreTrainedTokenizer,
          batch_size: int = -1,
          seq_length: int = -1,
          is_pair: bool = False,
          framework: Optional[TensorType] = None,
      ) -> Mapping[str, Any]:
          """Generate dummy inputs for export, dispatching on `self.task`.

          `"default"`/`"seq2seq-lm"` get encoder + decoder inputs, `"causal-lm"` gets encoder-side inputs, and any
          other task gets the plain tokenized inputs; see the corresponding private helpers.
          """
          if self.task in ["default", "seq2seq-lm"]:
              common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
                  tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
              )
          elif self.task == "causal-lm":
              common_inputs = self._generate_dummy_inputs_for_causal_lm(
                  tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
              )
          else:
              common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
                  tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
              )

          return common_inputs

      def _flatten_past_key_values_(self, flattened_output, name, idx, t):
          """Flatten one layer's past key/values `t` into `flattened_output` under `name`/`idx`.

          Seq2seq tasks delegate to the seq2seq parent; other tasks skip it and use the next class in the MRO. The
          parent is presumably updating `flattened_output` in place; this method returns None.
          """
          if self.task in ["default", "seq2seq-lm"]:
              super()._flatten_past_key_values_(flattened_output, name, idx, t)
          else:
              super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(flattened_output, name, idx, t)


  __all__ = ["BartConfig", "BartOnnxConfig"]