# configuration_marian.py (18 KB)
  1. # coding=utf-8
  2. # Copyright 2021 The Marian Team Authors and The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """Marian model configuration"""
  16. from collections import OrderedDict
  17. from collections.abc import Mapping
  18. from typing import Any, Optional
  19. from ... import PreTrainedTokenizer
  20. from ...configuration_utils import PretrainedConfig
  21. from ...onnx import OnnxConfig, OnnxConfigWithPast, OnnxSeq2SeqConfigWithPast
  22. from ...onnx.utils import compute_effective_axis_dimension
  23. from ...utils import TensorType, is_torch_available, logging
  24. logger = logging.get_logger(__name__)
  25. class MarianConfig(PretrainedConfig):
  26. r"""
  27. This is the configuration class to store the configuration of a [`MarianModel`]. It is used to instantiate an
  28. Marian model according to the specified arguments, defining the model architecture. Instantiating a configuration
  29. with the defaults will yield a similar configuration to that of the Marian
  30. [Helsinki-NLP/opus-mt-en-de](https://huggingface.co/Helsinki-NLP/opus-mt-en-de) architecture.
  31. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  32. documentation from [`PretrainedConfig`] for more information.
  33. Args:
  34. vocab_size (`int`, *optional*, defaults to 58101):
  35. Vocabulary size of the Marian model. Defines the number of different tokens that can be represented by the
  36. `inputs_ids` passed when calling [`MarianModel`] or [`TFMarianModel`].
  37. d_model (`int`, *optional*, defaults to 1024):
  38. Dimensionality of the layers and the pooler layer.
  39. encoder_layers (`int`, *optional*, defaults to 12):
  40. Number of encoder layers.
  41. decoder_layers (`int`, *optional*, defaults to 12):
  42. Number of decoder layers.
  43. encoder_attention_heads (`int`, *optional*, defaults to 16):
  44. Number of attention heads for each attention layer in the Transformer encoder.
  45. decoder_attention_heads (`int`, *optional*, defaults to 16):
  46. Number of attention heads for each attention layer in the Transformer decoder.
  47. decoder_ffn_dim (`int`, *optional*, defaults to 4096):
  48. Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
  49. encoder_ffn_dim (`int`, *optional*, defaults to 4096):
  50. Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
  51. activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
  52. The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
  53. `"relu"`, `"silu"` and `"gelu_new"` are supported.
  54. dropout (`float`, *optional*, defaults to 0.1):
  55. The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
  56. attention_dropout (`float`, *optional*, defaults to 0.0):
  57. The dropout ratio for the attention probabilities.
  58. activation_dropout (`float`, *optional*, defaults to 0.0):
  59. The dropout ratio for activations inside the fully connected layer.
  60. max_position_embeddings (`int`, *optional*, defaults to 1024):
  61. The maximum sequence length that this model might ever be used with. Typically set this to something large
  62. just in case (e.g., 512 or 1024 or 2048).
  63. init_std (`float`, *optional*, defaults to 0.02):
  64. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  65. encoder_layerdrop (`float`, *optional*, defaults to 0.0):
  66. The LayerDrop probability for the encoder. See the [LayerDrop paper](see https://huggingface.co/papers/1909.11556)
  67. for more details.
  68. decoder_layerdrop (`float`, *optional*, defaults to 0.0):
  69. The LayerDrop probability for the decoder. See the [LayerDrop paper](see https://huggingface.co/papers/1909.11556)
  70. for more details.
  71. scale_embedding (`bool`, *optional*, defaults to `False`):
  72. Scale embeddings by diving by sqrt(d_model).
  73. use_cache (`bool`, *optional*, defaults to `True`):
  74. Whether or not the model should return the last key/values attentions (not used by all models)
  75. forced_eos_token_id (`int`, *optional*, defaults to 0):
  76. The id of the token to force as the last generated token when `max_length` is reached. Usually set to
  77. `eos_token_id`.
  78. Examples:
  79. ```python
  80. >>> from transformers import MarianModel, MarianConfig
  81. >>> # Initializing a Marian Helsinki-NLP/opus-mt-en-de style configuration
  82. >>> configuration = MarianConfig()
  83. >>> # Initializing a model from the Helsinki-NLP/opus-mt-en-de style configuration
  84. >>> model = MarianModel(configuration)
  85. >>> # Accessing the model configuration
  86. >>> configuration = model.config
  87. ```"""
  88. model_type = "marian"
  89. keys_to_ignore_at_inference = ["past_key_values"]
  90. attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
  91. def __init__(
  92. self,
  93. vocab_size=58101,
  94. decoder_vocab_size=None,
  95. max_position_embeddings=1024,
  96. encoder_layers=12,
  97. encoder_ffn_dim=4096,
  98. encoder_attention_heads=16,
  99. decoder_layers=12,
  100. decoder_ffn_dim=4096,
  101. decoder_attention_heads=16,
  102. encoder_layerdrop=0.0,
  103. decoder_layerdrop=0.0,
  104. use_cache=True,
  105. is_encoder_decoder=True,
  106. activation_function="gelu",
  107. d_model=1024,
  108. dropout=0.1,
  109. attention_dropout=0.0,
  110. activation_dropout=0.0,
  111. init_std=0.02,
  112. decoder_start_token_id=58100,
  113. scale_embedding=False,
  114. pad_token_id=58100,
  115. eos_token_id=0,
  116. forced_eos_token_id=0,
  117. share_encoder_decoder_embeddings=True,
  118. **kwargs,
  119. ):
  120. self.vocab_size = vocab_size
  121. self.decoder_vocab_size = decoder_vocab_size or vocab_size
  122. self.max_position_embeddings = max_position_embeddings
  123. self.d_model = d_model
  124. self.encoder_ffn_dim = encoder_ffn_dim
  125. self.encoder_layers = encoder_layers
  126. self.encoder_attention_heads = encoder_attention_heads
  127. self.decoder_ffn_dim = decoder_ffn_dim
  128. self.decoder_layers = decoder_layers
  129. self.decoder_attention_heads = decoder_attention_heads
  130. self.dropout = dropout
  131. self.attention_dropout = attention_dropout
  132. self.activation_dropout = activation_dropout
  133. self.activation_function = activation_function
  134. self.init_std = init_std
  135. self.encoder_layerdrop = encoder_layerdrop
  136. self.decoder_layerdrop = decoder_layerdrop
  137. self.use_cache = use_cache
  138. self.num_hidden_layers = encoder_layers
  139. self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
  140. self.share_encoder_decoder_embeddings = share_encoder_decoder_embeddings
  141. super().__init__(
  142. pad_token_id=pad_token_id,
  143. eos_token_id=eos_token_id,
  144. is_encoder_decoder=is_encoder_decoder,
  145. decoder_start_token_id=decoder_start_token_id,
  146. forced_eos_token_id=forced_eos_token_id,
  147. **kwargs,
  148. )
  149. class MarianOnnxConfig(OnnxSeq2SeqConfigWithPast):
  150. @property
  151. # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.inputs
  152. def inputs(self) -> Mapping[str, Mapping[int, str]]:
  153. if self.task in ["default", "seq2seq-lm"]:
  154. common_inputs = OrderedDict(
  155. [
  156. ("input_ids", {0: "batch", 1: "encoder_sequence"}),
  157. ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
  158. ]
  159. )
  160. if self.use_past:
  161. common_inputs["decoder_input_ids"] = {0: "batch"}
  162. common_inputs["decoder_attention_mask"] = {0: "batch", 1: "past_decoder_sequence + sequence"}
  163. else:
  164. common_inputs["decoder_input_ids"] = {0: "batch", 1: "decoder_sequence"}
  165. common_inputs["decoder_attention_mask"] = {0: "batch", 1: "decoder_sequence"}
  166. if self.use_past:
  167. self.fill_with_past_key_values_(common_inputs, direction="inputs")
  168. elif self.task == "causal-lm":
  169. # TODO: figure this case out.
  170. common_inputs = OrderedDict(
  171. [
  172. ("input_ids", {0: "batch", 1: "encoder_sequence"}),
  173. ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
  174. ]
  175. )
  176. if self.use_past:
  177. num_encoder_layers, _ = self.num_layers
  178. for i in range(num_encoder_layers):
  179. common_inputs[f"past_key_values.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
  180. common_inputs[f"past_key_values.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
  181. else:
  182. common_inputs = OrderedDict(
  183. [
  184. ("input_ids", {0: "batch", 1: "encoder_sequence"}),
  185. ("attention_mask", {0: "batch", 1: "encoder_sequence"}),
  186. ("decoder_input_ids", {0: "batch", 1: "decoder_sequence"}),
  187. ("decoder_attention_mask", {0: "batch", 1: "decoder_sequence"}),
  188. ]
  189. )
  190. return common_inputs
  191. @property
  192. # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig.outputs
  193. def outputs(self) -> Mapping[str, Mapping[int, str]]:
  194. if self.task in ["default", "seq2seq-lm"]:
  195. common_outputs = super().outputs
  196. else:
  197. common_outputs = super(OnnxConfigWithPast, self).outputs
  198. if self.use_past:
  199. num_encoder_layers, _ = self.num_layers
  200. for i in range(num_encoder_layers):
  201. common_outputs[f"present.{i}.key"] = {0: "batch", 2: "past_sequence + sequence"}
  202. common_outputs[f"present.{i}.value"] = {0: "batch", 2: "past_sequence + sequence"}
  203. return common_outputs
  204. def _generate_dummy_inputs_for_default_and_seq2seq_lm(
  205. self,
  206. tokenizer: PreTrainedTokenizer,
  207. batch_size: int = -1,
  208. seq_length: int = -1,
  209. is_pair: bool = False,
  210. framework: Optional[TensorType] = None,
  211. ) -> Mapping[str, Any]:
  212. encoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
  213. tokenizer, batch_size, seq_length, is_pair, framework
  214. )
  215. # Generate decoder inputs
  216. decoder_seq_length = seq_length if not self.use_past else 1
  217. decoder_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
  218. tokenizer, batch_size, decoder_seq_length, is_pair, framework
  219. )
  220. decoder_inputs = {f"decoder_{name}": tensor for name, tensor in decoder_inputs.items()}
  221. common_inputs = dict(**encoder_inputs, **decoder_inputs)
  222. if self.use_past:
  223. if not is_torch_available():
  224. raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
  225. else:
  226. import torch
  227. batch, encoder_seq_length = common_inputs["input_ids"].shape
  228. decoder_seq_length = common_inputs["decoder_input_ids"].shape[1]
  229. num_encoder_attention_heads, num_decoder_attention_heads = self.num_attention_heads
  230. encoder_shape = (
  231. batch,
  232. num_encoder_attention_heads,
  233. encoder_seq_length,
  234. self._config.hidden_size // num_encoder_attention_heads,
  235. )
  236. decoder_past_length = decoder_seq_length + 3
  237. decoder_shape = (
  238. batch,
  239. num_decoder_attention_heads,
  240. decoder_past_length,
  241. self._config.hidden_size // num_decoder_attention_heads,
  242. )
  243. common_inputs["decoder_attention_mask"] = torch.cat(
  244. [common_inputs["decoder_attention_mask"], torch.ones(batch, decoder_past_length)], dim=1
  245. )
  246. common_inputs["past_key_values"] = []
  247. # If the number of encoder and decoder layers are present in the model configuration, both are considered
  248. num_encoder_layers, num_decoder_layers = self.num_layers
  249. min_num_layers = min(num_encoder_layers, num_decoder_layers)
  250. max_num_layers = max(num_encoder_layers, num_decoder_layers) - min_num_layers
  251. remaining_side_name = "encoder" if num_encoder_layers > num_decoder_layers else "decoder"
  252. for _ in range(min_num_layers):
  253. common_inputs["past_key_values"].append(
  254. (
  255. torch.zeros(decoder_shape),
  256. torch.zeros(decoder_shape),
  257. torch.zeros(encoder_shape),
  258. torch.zeros(encoder_shape),
  259. )
  260. )
  261. # TODO: test this.
  262. shape = encoder_shape if remaining_side_name == "encoder" else decoder_shape
  263. for _ in range(min_num_layers, max_num_layers):
  264. common_inputs["past_key_values"].append((torch.zeros(shape), torch.zeros(shape)))
  265. return common_inputs
  266. def _generate_dummy_inputs_for_causal_lm(
  267. self,
  268. tokenizer: PreTrainedTokenizer,
  269. batch_size: int = -1,
  270. seq_length: int = -1,
  271. is_pair: bool = False,
  272. framework: Optional[TensorType] = None,
  273. ) -> Mapping[str, Any]:
  274. common_inputs = self._generate_dummy_inputs_for_encoder_and_decoder(
  275. tokenizer, batch_size, seq_length, is_pair, framework
  276. )
  277. if self.use_past:
  278. if not is_torch_available():
  279. raise ValueError("Cannot generate dummy past_keys inputs without PyTorch installed.")
  280. else:
  281. import torch
  282. batch, seqlen = common_inputs["input_ids"].shape
  283. # Not using the same length for past_key_values
  284. past_key_values_length = seqlen + 2
  285. num_encoder_layers, _ = self.num_layers
  286. num_encoder_attention_heads, _ = self.num_attention_heads
  287. past_shape = (
  288. batch,
  289. num_encoder_attention_heads,
  290. past_key_values_length,
  291. self._config.hidden_size // num_encoder_attention_heads,
  292. )
  293. mask_dtype = common_inputs["attention_mask"].dtype
  294. common_inputs["attention_mask"] = torch.cat(
  295. [common_inputs["attention_mask"], torch.ones(batch, past_key_values_length, dtype=mask_dtype)], dim=1
  296. )
  297. common_inputs["past_key_values"] = [
  298. (torch.zeros(past_shape), torch.zeros(past_shape)) for _ in range(num_encoder_layers)
  299. ]
  300. return common_inputs
  301. # Copied from BartOnnxConfig._generate_dummy_inputs_for_sequence_classification_and_question_answering
  302. # We renamed this function because Marian models do not have a sequence classification or question answering head
  303. def _generate_dummy_inputs_for_encoder_and_decoder(
  304. self,
  305. tokenizer: PreTrainedTokenizer,
  306. batch_size: int = -1,
  307. seq_length: int = -1,
  308. is_pair: bool = False,
  309. framework: Optional[TensorType] = None,
  310. ) -> Mapping[str, Any]:
  311. # Copied from OnnxConfig.generate_dummy_inputs
  312. # Did not use super(OnnxConfigWithPast, self).generate_dummy_inputs for code clarity.
  313. # If dynamic axis (-1) we forward with a fixed dimension of 2 samples to avoid optimizations made by ONNX
  314. batch_size = compute_effective_axis_dimension(
  315. batch_size, fixed_dimension=OnnxConfig.default_fixed_batch, num_token_to_add=0
  316. )
  317. # If dynamic axis (-1) we forward with a fixed dimension of 8 tokens to avoid optimizations made by ONNX
  318. token_to_add = tokenizer.num_special_tokens_to_add(is_pair)
  319. seq_length = compute_effective_axis_dimension(
  320. seq_length, fixed_dimension=OnnxConfig.default_fixed_sequence, num_token_to_add=token_to_add
  321. )
  322. # Generate dummy inputs according to compute batch and sequence
  323. dummy_input = [" ".join([tokenizer.unk_token]) * seq_length] * batch_size
  324. common_inputs = dict(tokenizer(dummy_input, return_tensors=framework))
  325. return common_inputs
  326. def generate_dummy_inputs(
  327. self,
  328. tokenizer: PreTrainedTokenizer,
  329. batch_size: int = -1,
  330. seq_length: int = -1,
  331. is_pair: bool = False,
  332. framework: Optional[TensorType] = None,
  333. ) -> Mapping[str, Any]:
  334. if self.task in ["default", "seq2seq-lm"]:
  335. common_inputs = self._generate_dummy_inputs_for_default_and_seq2seq_lm(
  336. tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
  337. )
  338. else:
  339. common_inputs = self._generate_dummy_inputs_for_causal_lm(
  340. tokenizer, batch_size=batch_size, seq_length=seq_length, is_pair=is_pair, framework=framework
  341. )
  342. return common_inputs
  343. # Copied from transformers.models.bart.configuration_bart.BartOnnxConfig._flatten_past_key_values_
  344. def _flatten_past_key_values_(self, flattened_output, name, idx, t):
  345. if self.task in ["default", "seq2seq-lm"]:
  346. flattened_output = super()._flatten_past_key_values_(flattened_output, name, idx, t)
  347. else:
  348. flattened_output = super(OnnxSeq2SeqConfigWithPast, self)._flatten_past_key_values_(
  349. flattened_output, name, idx, t
  350. )
  351. @property
  352. def atol_for_validation(self) -> float:
  353. return 1e-4
  354. __all__ = ["MarianConfig", "MarianOnnxConfig"]