configuration_mbart.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391
  1. # coding=utf-8
  2. # Copyright 2021, The Facebook AI Research Team and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """MBART model configuration"""
  16. from collections import OrderedDict
  17. from collections.abc import Mapping
  18. from typing import Any, Optional
  19. from ... import PreTrainedTokenizer
  20. from ...configuration_utils import PretrainedConfig
  21. from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
  22. from ...onnx.utils import compute_effective_axis_dimension
  23. from ...utils import TensorType, is_torch_available, logging
# Module-level logger obtained from the library's `logging` utilities (imported from `...utils` above).
logger = logging.get_logger(__name__)
  25. class MBartConfig(PretrainedConfig):
  26. r"""
  27. This is the configuration class to store the configuration of a [`MBartModel`]. It is used to instantiate an MBART
  28. model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
  29. defaults will yield a similar configuration to that of the MBART
  30. [facebook/mbart-large-cc25](https://huggingface.co/facebook/mbart-large-cc25) architecture.
  31. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  32. documentation from [`PretrainedConfig`] for more information.
  33. Args:
  34. vocab_size (`int`, *optional*, defaults to 50265):
  35. Vocabulary size of the MBART model. Defines the number of different tokens that can be represented by the
  36. `inputs_ids` passed when calling [`MBartModel`] or [`TFMBartModel`].
  37. d_model (`int`, *optional*, defaults to 1024):
  38. Dimensionality of the layers and the pooler layer.
  39. encoder_layers (`int`, *optional*, defaults to 12):
  40. Number of encoder layers.
  41. decoder_layers (`int`, *optional*, defaults to 12):
  42. Number of decoder layers.
  43. encoder_attention_heads (`int`, *optional*, defaults to 16):
  44. Number of attention heads for each attention layer in the Transformer encoder.
  45. decoder_attention_heads (`int`, *optional*, defaults to 16):
  46. Number of attention heads for each attention layer in the Transformer decoder.
  47. decoder_ffn_dim (`int`, *optional*, defaults to 4096):
  48. Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
  49. encoder_ffn_dim (`int`, *optional*, defaults to 4096):
  50. Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
  51. activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
  52. The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
  53. `"relu"`, `"silu"` and `"gelu_new"` are supported.
  54. dropout (`float`, *optional*, defaults to 0.1):
  55. The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
  56. attention_dropout (`float`, *optional*, defaults to 0.0):
  57. The dropout ratio for the attention probabilities.
  58. activation_dropout (`float`, *optional*, defaults to 0.0):
  59. The dropout ratio for activations inside the fully connected layer.
  60. classifier_dropout (`float`, *optional*, defaults to 0.0):
  61. The dropout ratio for classifier.
  62. max_position_embeddings (`int`, *optional*, defaults to 1024):
  63. The maximum sequence length that this model might ever be used with. Typically set this to something large
  64. just in case (e.g., 512 or 1024 or 2048).
  65. init_std (`float`, *optional*, defaults to 0.02):
  66. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  67. encoder_layerdrop (`float`, *optional*, defaults to 0.0):
  68. The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://huggingface.co/papers/1909.11556)
  69. for more details.
  70. decoder_layerdrop (`float`, *optional*, defaults to 0.0):
  71. The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://huggingface.co/papers/1909.11556)
  72. for more details.
  73. scale_embedding (`bool`, *optional*, defaults to `False`):
  74. Scale embeddings by diving by sqrt(d_model).
  75. use_cache (`bool`, *optional*, defaults to `True`):
  76. Whether or not the model should return the last key/values attentions (not used by all models)
  77. forced_eos_token_id (`int`, *optional*, defaults to 2):
  78. The id of the token to force as the last generated token when `max_length` is reached. Usually set to
  79. `eos_token_id`.
  80. Example:
  81. ```python
  82. >>> from transformers import MBartConfig, MBartModel
  83. >>> # Initializing a MBART facebook/mbart-large-cc25 style configuration
  84. >>> configuration = MBartConfig()
  85. >>> # Initializing a model (with random weights) from the facebook/mbart-large-cc25 style configuration
  86. >>> model = MBartModel(configuration)
  87. >>> # Accessing the model configuration
  88. >>> configuration = model.config
  89. ```"""
  90. model_type = "mbart"
  91. keys_to_ignore_at_inference = ["past_key_values"]
  92. attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
  93. def __init__(
  94. self,
  95. vocab_size=50265,
  96. max_position_embeddings=1024,
  97. encoder_layers=12,
  98. encoder_ffn_dim=4096,
  99. encoder_attention_heads=16,
  100. decoder_layers=12,
  101. decoder_ffn_dim=4096,
  102. decoder_attention_heads=16,
  103. encoder_layerdrop=0.0,
  104. decoder_layerdrop=0.0,
  105. use_cache=True,
  106. is_encoder_decoder=True,
  107. activation_function="gelu",
  108. d_model=1024,
  109. dropout=0.1,
  110. attention_dropout=0.0,
  111. activation_dropout=0.0,
  112. init_std=0.02,
  113. classifier_dropout=0.0,
  114. scale_embedding=False,
  115. pad_token_id=1,
  116. bos_token_id=0,
  117. eos_token_id=2,
  118. forced_eos_token_id=2,
  119. **kwargs,
  120. ):
  121. self.vocab_size = vocab_size
  122. self.max_position_embeddings = max_position_embeddings
  123. self.d_model = d_model
  124. self.encoder_ffn_dim = encoder_ffn_dim
  125. self.encoder_layers = encoder_layers
  126. self.encoder_attention_heads = encoder_attention_heads
  127. self.decoder_ffn_dim = decoder_ffn_dim
  128. self.decoder_layers = decoder_layers
  129. self.decoder_attention_heads = decoder_attention_heads
  130. self.dropout = dropout
  131. self.attention_dropout = attention_dropout
  132. self.activation_dropout = activation_dropout
  133. self.activation_function = activation_function
  134. self.init_std = init_std
  135. self.encoder_layerdrop = encoder_layerdrop
  136. self.decoder_layerdrop = decoder_layerdrop
  137. self.classifier_dropout = classifier_dropout
  138. self.use_cache = use_cache
  139. self.num_hidden_layers = encoder_layers
  140. self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
  141. super().__init__(
  142. pad_token_id=pad_token_id,
  143. bos_token_id=bos_token_id,
  144. eos_token_id=eos_token_id,
  145. is_encoder_decoder=is_encoder_decoder,
  146. forced_eos_token_id=forced_eos_token_id,
  147. **kwargs,
  148. )
  149. # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig with Bart->MBart
  150. class MBartOnnxConfig(OnnxSeq2SeqConfigWithPast):
  151. @property
  152. def inputs(self) -> Mapping[str, Mapping[int, str]]:
  153. if self.task in ["default", "seq2seq-lm"]:
  154. common_inputs = OrderedDict(
  155. [
  156. ("input_ids", {0: "batch", 1: "encoder_sequence"}),
  157. ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
  158. ]
  159. )
  160. if self.use_past:
  161. common_inputs["decoder_input_ids"] = {0: "batch"}
  162. common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
  163. else:
  164. common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
  165. common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
  166. if self.use_past:
  167. self.fill_with_past_key_values_(common_inputs, direction="inputs")
  168. elif self.task == "causal-lm":
  169. # TODO: figure this case out.
  170. common_inputs = OrderedDict(
  171. [
  172. ("input_ids", {0: "batch", 1: "encoder_sequence"}),
  173. ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
  174. ]
  175. )
  176. if self.use_past:
  177. num_encoder_layers, _ = self.num_layers
  178. for i in range(num_encoder_layers):
  179. common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
  180. common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
  181. else:
  182. common_inputs = OrderedDict(
  183. [
  184. ("input_ids", {0: "batch", 1: "encoder_sequence"}),
  185. ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
  186. ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
  187. ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
  188. ]
  189. )
  190. return common_inputs
  191. @property
  192. def outputs(self) -> Mapping[str, Mapping[int, str]]:
  193. if self.task in ["default", "seq2seq-lm"]:
  194. common_outputs = super().outputs
  195. else:
  196. common_outputs = super(OnnxConfigWithPast, self).outputs
  197. if self.use_past:
  198. num_encoder_layers, _ = self.num_layers
  199. for i in range(num_encoder_layers):
  200. common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
  201. common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
  202. return common_outputs
  203. def _generate_dummy_inputs_for_default_and_seq2seq_lm(
  204. self,
  205. tokenizer: PreTrainedTokenizer,
  206. batch_size: int = -1,
  207. seq_length: int = -1,
  208. is_pair: bool = False,
  209. framework: Optional[TensorType] = None,
  210. ) -> Mapping[str, Any]:
  211. encoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
  212. tokenizer, batch_size, seq_length, is_pair, framework
  213. )
  214. # Generate decoder inputs
  215. decoder_seq_length = seq_length if not self.use_past else 1
  216. decoder_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
  217. tokenizer, batch_size, decoder_seq_length, is_pair, framework
  218. )
  219. decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
  220. common_inputs = dict(**encoder_inputs, **decoder_inputs)
  221. if self.use_past:
  222. if not is_torch_available():
  223. raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
  224. else:
  225. import torch
  226. batch, encoder_seq_length = common_inputs["input_ids"].shape
  227. decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
  228. num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
  229. encoder_shape = (
  230. batch,
  231. num_encoder_attention_heads,
  232. encoder_seq_length,
  233. self._config.hidden_size // num_encoder_attention_heads,
  234. )
  235. decoder_past_length = decoder_seq_length + 3
  236. decoder_shape = (
  237. batch,
  238. num_decoder_attention_heads,
  239. decoder_past_length,
  240. self._config.hidden_size // num_decoder_attention_heads,
  241. )
  242. common_inputs["decoder_attention_mask"] = torch.cat(
  243. [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
  244. )
  245. common_inputs["past_key_values"] = []
  246. # If the number of encoder and decoder layers are present in the model configuration, both are considered
  247. num_encoder_layers, num_decoder_layers = self.num_layers
  248. min_num_layers = min(num_encoder_layers, num_decoder_layers)
  249. max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
  250. remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
  251. for _ in range(min_num_layers):
  252. common_inputs["past_key_values"].append(
  253. (
  254. torch.zeros(decoder_shape),
  255. torch.zeros(decoder_shape),
  256. torch.zeros(encoder_shape),
  257. torch.zeros(encoder_shape),
  258. )
  259. )
  260. # TODO: test this.
  261. shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
  262. for _ in range(min_num_layers, max_num_layers):
  263. common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
  264. return common_inputs
  265. def _generate_dummy_inputs_for_causal_lm(
  266. self,
  267. tokenizer: PreTrainedTokenizer,
  268. batch_size: int = -1,
  269. seq_length: int = -1,
  270. is_pair: bool = False,
  271. framework: Optional[TensorType] = None,
  272. ) -> Mapping[str, Any]:
  273. common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
  274. tokenizer, batch_size, seq_length, is_pair, framework
  275. )
  276. if self.use_past:
  277. if not is_torch_available():
  278. raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
  279. else:
  280. import torch
  281. batch, seqlen = common_inputs["input_ids"].shape
  282. # Not using the same length for past_key_values
  283. past_key_values_length = seqlen + 2
  284. num_encoder_layers, _ = self.num_layers
  285. num_encoder_attention_heads, _ = self.num_attention_heads
  286. past_shape = (
  287. batch,
  288. num_encoder_attention_heads,
  289. past_key_values_length,
  290. self._config.hidden_size // num_encoder_attention_heads,
  291. )
  292. mask_dtype = common_inputs["attention_mask"].dtype
  293. common_inputs["attention_mask"] = torch.cat(
  294. [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
  295. )
  296. common_inputs["past_key_values"] = [
  297. (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
  298. ]
  299. return common_inputs
  300. def _generate_dummy_inputs_for_sequence_classification_and_question_answering(
  301. self,
  302. tokenizer: PreTrainedTokenizer,
  303. batch_size: int = -1,
  304. seq_length: int = -1,
  305. is_pair: bool = False,
  306. framework: Optional[TensorType] = None,
  307. ) -> Mapping[str, Any]:
  308. # Copied from OnnxConfig.generate_dummy_inputs
  309. # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
  310. # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
  311. batch_size = compute_effective_axis_dimension(
  312. batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
  313. )
  314. # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
  315. token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
  316. seq_length = compute_effective_axis_dimension(
  317. seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
  318. )
  319. # Generate dummy inputs according to compute batch and sequence
  320. dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
  321. common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
  322. return common_inputs
  323. def generate_dummy_inputs(
  324. self,
  325. tokenizer: PreTrainedTokenizer,
  326. batch_size: int = -1,
  327. seq_length: int = -1,
  328. is_pair: bool = False,
  329. framework: Optional[TensorType] = None,
  330. ) -> Mapping[str, Any]:
  331. if self.task in ["default", "seq2seq-lm"]:
  332. common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
  333. tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
  334. )
  335. elif self.task == "causal-lm":
  336. common_inputs = self._generate_dummy_inputs_for_causal_lm(
  337. tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
  338. )
  339. else:
  340. common_inputs = self._generate_dummy_inputs_for_sequence_classification_and_question_answering(
  341. tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
  342. )
  343. return common_inputs
  344. def _flatten_past_key_values_(self, flattened_output, name, idx, t):
  345. if self.task in ["default", "seq2seq-lm"]:
  346. flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
  347. else:
  348. flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
  349. flattened_output, name, idx, t
  350. )
  351. __all__ = ["MBartConfig", "MBartOnnxConfig"]