configuration_flava.py 33 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697
  1. # coding=utf-8
  2. # Copyright 2022 Meta Platforms authors and The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """FLAVA model configurations"""
  16. from typing import Any, Optional
  17. from ...configuration_utils import PretrainedConfig
  18. from ...utils import logging
# Module-level logger; `FlavaConfig.__init__` uses it to report how the sub-configs are resolved.
logger = logging.get_logger(__name__)
  20. class FlavaImageConfig(PretrainedConfig):
  21. r"""
  22. This is the configuration class to store the configuration of a [`FlavaImageModel`]. It is used to instantiate an
  23. FLAVA model according to the specified arguments, defining the model architecture.
  24. Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
  25. [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
  26. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  27. documentation from [`PretrainedConfig`] for more information.
  28. Args:
  29. hidden_size (`int`, *optional*, defaults to 768):
  30. Dimensionality of the encoder layers and the pooler layer.
  31. num_hidden_layers (`int`, *optional*, defaults to 12):
  32. Number of hidden layers in the Transformer encoder.
  33. num_attention_heads (`int`, *optional*, defaults to 12):
  34. Number of attention heads for each attention layer in the Transformer encoder.
  35. intermediate_size (`int`, *optional*, defaults to 3072):
  36. Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
  37. hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
  38. The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
  39. `"relu"`, `"selu"` and `"gelu_new"` are supported.
  40. hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
  41. The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
  42. attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
  43. The dropout ratio for the attention probabilities.
  44. initializer_range (`float`, *optional*, defaults to 0.02):
  45. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  46. layer_norm_eps (`float`, *optional*, defaults to 1e-12):
  47. The epsilon used by the layer normalization layers.
  48. image_size (`int`, *optional*, defaults to 224):
  49. The size (resolution) of each image.
  50. patch_size (`int`, *optional*, defaults to 16):
  51. The size (resolution) of each patch.
  52. num_channels (`int`, *optional*, defaults to 3):
  53. The number of input channels.
  54. qkv_bias (`bool`, *optional*, defaults to `True`):
  55. Whether to add a bias to the queries, keys and values.
  56. mask_token (`bool`, *optional*, defaults to `True`):
  57. Whether to use a mask token or not. Used in MIM (Masked Image Modeling) loss for FLAVA.
  58. vocab_size (`int`, *optional*, defaults to 8192):
  59. Vocabulary size of the [`FlavaImageCodebook`] used in conjunction with [`FlavaImageModel`] for MIM (Masked
  60. Image Modeling) loss for FLAVA.
  61. Example:
  62. ```python
  63. >>> from transformers import FlavaImageConfig, FlavaImageModel
  64. >>> # Initializing a FlavaImageModel with style configuration
  65. >>> configuration = FlavaImageConfig()
  66. >>> # Initializing a FlavaImageModel model (with random weights) from the style configuration
  67. >>> model = FlavaImageModel(configuration)
  68. >>> # Accessing the model configuration
  69. >>> configuration = model.config
  70. ```"""
  71. model_type = "flava_image_model"
  72. base_config_key = "image_config"
  73. def __init__(
  74. self,
  75. hidden_size: int = 768,
  76. num_hidden_layers: int = 12,
  77. num_attention_heads: int = 12,
  78. intermediate_size: int = 3072,
  79. hidden_act: int = "gelu",
  80. hidden_dropout_prob: float = 0.0,
  81. attention_probs_dropout_prob: float = 0.0,
  82. initializer_range: float = 0.02,
  83. layer_norm_eps: float = 1e-12,
  84. image_size: int = 224,
  85. patch_size: int = 16,
  86. num_channels: int = 3,
  87. qkv_bias: bool = True,
  88. mask_token: bool = True,
  89. vocab_size: int = 8192,
  90. **kwargs,
  91. ):
  92. super().__init__(**kwargs)
  93. self.hidden_size = hidden_size
  94. self.num_hidden_layers = num_hidden_layers
  95. self.num_attention_heads = num_attention_heads
  96. self.intermediate_size = intermediate_size
  97. self.hidden_act = hidden_act
  98. self.hidden_dropout_prob = hidden_dropout_prob
  99. self.attention_probs_dropout_prob = attention_probs_dropout_prob
  100. self.initializer_range = initializer_range
  101. self.layer_norm_eps = layer_norm_eps
  102. self.image_size = image_size
  103. self.patch_size = patch_size
  104. self.num_channels = num_channels
  105. self.qkv_bias = qkv_bias
  106. self.mask_token = mask_token
  107. self.vocab_size = vocab_size
  108. class FlavaTextConfig(PretrainedConfig):
  109. r"""
  110. This is the configuration class to store the configuration of a [`FlavaTextModel`]. It is used to instantiate an
  111. FLAVA model according to the specified arguments, defining the model architecture.
  112. Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
  113. [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
  114. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  115. documentation from [`PretrainedConfig`] for more information.
  116. Args:
  117. vocab_size (`int`, *optional*, defaults to 30522):
  118. Vocabulary size of the BERT model. Defines the number of different tokens that can be represented by the
  119. `inputs_ids` passed when calling [`FlavaTextModel`].
  120. type_vocab_size (`int`, *optional*, defaults to 2):
  121. The vocabulary size of the `token_type_ids` passed when calling [`FlavaTextModel`]. Note that even though
  122. text encoder allows `token_type_ids`'s value as 2, for text-only pretraining and fine-tuning, only 1 is
  123. used similar to RoBERTa.
  124. max_position_embeddings (`int`, *optional*, defaults to 512):
  125. The maximum sequence length that this model might ever be used with. Typically set this to something large
  126. just in case (e.g., 512 or 1024 or 2048). For VL, max_length passed to model is 77.
  127. position_embedding_type (`str`, *optional*, defaults to `"absolute"`):
  128. Type of position embedding. Choose one of `"absolute"`, `"relative_key"`, `"relative_key_query"`. For
  129. positional embeddings use `"absolute"`. For more information on `"relative_key"`, please refer to
  130. [Self-Attention with Relative Position Representations (Shaw et al.)](https://huggingface.co/papers/1803.02155).
  131. For more information on `"relative_key_query"`, please refer to *Method 4* in [Improve Transformer Models
  132. with Better Relative Position Embeddings (Huang et al.)](https://huggingface.co/papers/2009.13658).
  133. hidden_size (`int`, *optional*, defaults to 768):
  134. Dimensionality of the encoder layers and the pooler layer.
  135. num_hidden_layers (`int`, *optional*, defaults to 12):
  136. Number of hidden layers in the Transformer encoder.
  137. num_attention_heads (`int`, *optional*, defaults to 12):
  138. Number of attention heads for each attention layer in the Transformer encoder.
  139. intermediate_size (`int`, *optional*, defaults to 3072):
  140. Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
  141. hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
  142. The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
  143. `"relu"`, `"selu"` and `"gelu_new"` are supported.
  144. hidden_dropout_prob (`float`, *optional*, defaults to 0.1):
  145. The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
  146. attention_probs_dropout_prob (`float`, *optional*, defaults to 0.1):
  147. The dropout ratio for the attention probabilities.
  148. initializer_range (`float`, *optional*, defaults to 0.02):
  149. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  150. layer_norm_eps (`float`, *optional*, defaults to 1e-12):
  151. The epsilon used by the layer normalization layers.
  152. image_size (`int`, *optional*, defaults to 224):
  153. The size (resolution) of each image.
  154. patch_size (`int`, *optional*, defaults to 16):
  155. The size (resolution) of each patch.
  156. num_channels (`int`, *optional*, defaults to 3):
  157. The number of input channels.
  158. qkv_bias (`bool`, *optional*, defaults to `True`):
  159. Whether to add a bias to the queries, keys and values.
  160. Example:
  161. ```python
  162. >>> from transformers import FlavaTextConfig, FlavaTextModel
  163. >>> # Initializing a FlavaTextModel with style configuration
  164. >>> configuration = FlavaTextConfig()
  165. >>> # Initializing a FlavaTextModel model (with random weights) from the style configuration
  166. >>> model = FlavaTextModel(configuration)
  167. >>> # Accessing the model configuration
  168. >>> configuration = model.config
  169. ```"""
  170. model_type = "flava_text_model"
  171. base_config_key = "text_config"
  172. def __init__(
  173. self,
  174. vocab_size: int = 30522,
  175. type_vocab_size: int = 2,
  176. max_position_embeddings: int = 512,
  177. position_embedding_type: str = "absolute",
  178. hidden_size: int = 768,
  179. num_hidden_layers: int = 12,
  180. num_attention_heads: int = 12,
  181. intermediate_size: int = 3072,
  182. hidden_act: str = "gelu",
  183. hidden_dropout_prob: float = 0.0,
  184. attention_probs_dropout_prob: float = 0.0,
  185. initializer_range: float = 0.02,
  186. layer_norm_eps: float = 1e-12,
  187. pad_token_id: int = 0,
  188. qkv_bias: bool = True,
  189. **kwargs,
  190. ):
  191. super().__init__(**kwargs)
  192. self.vocab_size = vocab_size
  193. self.type_vocab_size = type_vocab_size
  194. self.max_position_embeddings = max_position_embeddings
  195. self.position_embedding_type = position_embedding_type
  196. self.hidden_size = hidden_size
  197. self.num_hidden_layers = num_hidden_layers
  198. self.num_attention_heads = num_attention_heads
  199. self.intermediate_size = intermediate_size
  200. self.hidden_act = hidden_act
  201. self.hidden_dropout_prob = hidden_dropout_prob
  202. self.attention_probs_dropout_prob = attention_probs_dropout_prob
  203. self.initializer_range = initializer_range
  204. self.layer_norm_eps = layer_norm_eps
  205. self.qkv_bias = qkv_bias
  206. self.pad_token_id = pad_token_id
  207. class FlavaMultimodalConfig(PretrainedConfig):
  208. r"""
  209. This is the configuration class to store the configuration of a [`FlavaMultimodalModel`]. It is used to instantiate
  210. an FLAVA model according to the specified arguments, defining the model architecture.
  211. Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
  212. [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
  213. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  214. documentation from [`PretrainedConfig`] for more information.
  215. Args:
  216. hidden_size (`int`, *optional*, defaults to 768):
  217. Dimensionality of the encoder layers and the pooler layer.
  218. num_hidden_layers (`int`, *optional*, defaults to 6):
  219. Number of hidden layers in the Transformer encoder.
  220. num_attention_heads (`int`, *optional*, defaults to 12):
  221. Number of attention heads for each attention layer in the Transformer encoder.
  222. intermediate_size (`int`, *optional*, defaults to 3072):
  223. Dimensionality of the "intermediate" (i.e., feed-forward) layer in the Transformer encoder.
  224. hidden_act (`str` or `function`, *optional*, defaults to `"gelu"`):
  225. The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
  226. `"relu"`, `"selu"` and `"gelu_new"` are supported.
  227. hidden_dropout_prob (`float`, *optional*, defaults to 0.0):
  228. The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
  229. attention_probs_dropout_prob (`float`, *optional*, defaults to 0.0):
  230. The dropout ratio for the attention probabilities.
  231. initializer_range (`float`, *optional*, defaults to 0.02):
  232. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  233. layer_norm_eps (`float`, *optional*, defaults to 1e-12):
  234. The epsilon used by the layer normalization layers.
  235. qkv_bias (`bool`, *optional*, defaults to `True`):
  236. Whether to add a bias to the queries, keys and values.
  237. use_cls_token (`bool`, *optional*, defaults to `True`):
  238. Whether to use an extra CLS token for multimodal settings. Usually needed by the FLAVA model.
  239. Example:
  240. ```python
  241. >>> from transformers import FlavaMultimodalConfig, FlavaMultimodalModel
  242. >>> # Initializing a FlavaMultimodalModel with style configuration
  243. >>> configuration = FlavaMultimodalConfig()
  244. >>> # Initializing a FlavaMultimodalModel model (with random weights) from the style configuration
  245. >>> model = FlavaMultimodalModel(configuration)
  246. >>> # Accessing the model configuration
  247. >>> configuration = model.config
  248. ```"""
  249. model_type = "flava_multimodal_model"
  250. base_config_key = "multimodal_config"
  251. def __init__(
  252. self,
  253. hidden_size: int = 768,
  254. num_hidden_layers: int = 6,
  255. num_attention_heads: int = 12,
  256. intermediate_size: int = 3072,
  257. hidden_act: int = "gelu",
  258. hidden_dropout_prob: int = 0.0,
  259. attention_probs_dropout_prob: int = 0.0,
  260. initializer_range: float = 0.02,
  261. layer_norm_eps: float = 1e-12,
  262. qkv_bias: bool = True,
  263. use_cls_token: bool = True,
  264. **kwargs,
  265. ):
  266. super().__init__(**kwargs)
  267. self.hidden_size = hidden_size
  268. self.num_hidden_layers = num_hidden_layers
  269. self.num_attention_heads = num_attention_heads
  270. self.intermediate_size = intermediate_size
  271. self.hidden_act = hidden_act
  272. self.hidden_dropout_prob = hidden_dropout_prob
  273. self.attention_probs_dropout_prob = attention_probs_dropout_prob
  274. self.initializer_range = initializer_range
  275. self.layer_norm_eps = layer_norm_eps
  276. self.qkv_bias = qkv_bias
  277. self.use_cls_token = use_cls_token
  278. class FlavaImageCodebookConfig(PretrainedConfig):
  279. model_type = "flava_image_codebook"
  280. base_config_key = "image_codebook_config"
  281. r"""
  282. [`FlavaImageCodebookConfig`] is the configuration class to store the configuration of a [`FlavaImageCodebook`]. It
  283. is used to instantiate an FLAVA model according to the specified arguments, defining the model architecture.
  284. Instantiating a configuration with the defaults will yield a similar configuration to that of the FLAVA
  285. [facebook/flava-image-codebook](https://huggingface.co/facebook/flava-image-codebook) architecture.
  286. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  287. documentation from [`PretrainedConfig`] for more information.
  288. Args:
  289. num_groups (`int`, *optional*, defaults to 4):
  290. Number of groups to be created. This parameter as of now doesn't affect the model and is used for some
  291. internal calculation and estimations.
  292. input_channels (`int`, *optional*, defaults to 3):
  293. Number of channels in the image to be passed.
  294. num_blocks_per_group (`int`, *optional*, defaults to 2):
  295. Number of conv-based blocks per group.
  296. hidden_size (`int`, *optional*, defaults to 256):
  297. Size of hidden dim for the blocks.
  298. vocab_size (`int`, *optional*, defaults to 8192):
  299. Size of the output vocabulary for the codebook.
  300. freeze (`bool`, defaults to `True`):
  301. Whether to freeze the weights of the model.
  302. initializer_range (`float`, *optional*, defaults to 0.02):
  303. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  304. kwargs (*optional*):
  305. Dictionary of keyword arguments.
  306. Example:
  307. ```python
  308. >>> from transformers import FlavaImageCodebookConfig, FlavaImageCodebook
  309. >>> # Initializing a FlavaImageCodebook with style configuration
  310. >>> configuration = FlavaImageCodebookConfig()
  311. >>> # Initializing a FlavaImageCodebook model (with random weights) from the style configuration
  312. >>> model = FlavaImageCodebook(configuration)
  313. >>> # Accessing the model configuration
  314. >>> configuration = model.config
  315. ```
  316. """
  317. def __init__(
  318. self,
  319. num_groups: int = 4,
  320. input_channels: int = 3,
  321. num_blocks_per_group: int = 2,
  322. hidden_size: int = 256,
  323. vocab_size: int = 8192,
  324. freeze: int = True,
  325. initializer_range: float = 0.02,
  326. **kwargs,
  327. ):
  328. super().__init__(**kwargs)
  329. self.num_groups = num_groups
  330. self.input_channels = input_channels
  331. self.num_blocks_per_group = num_blocks_per_group
  332. self.hidden_size = hidden_size
  333. self.vocab_size = vocab_size
  334. self.freeze = freeze
  335. self.initializer_range = initializer_range
  336. class FlavaConfig(PretrainedConfig):
  337. r"""
  338. [`FlavaConfig`] is the configuration class to store the configuration of a [`FlavaModel`]. It is used to
  339. instantiate FLAVA model according to the specified arguments, defining the text model, image model, image codebook
  340. and multimodal model configs. Instantiating a configuration with the defaults will yield a similar configuration to
  341. that of the FLAVA [facebook/flava-full](https://huggingface.co/facebook/flava-full) architecture.
  342. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  343. documentation from [`PretrainedConfig`] for more information.
  344. Args:
  345. text_config (`dict`, *optional*):
  346. Dictionary of configuration options used to initialize [`FlavaTextConfig`].
  347. image_config (`dict`, *optional*):
  348. Dictionary of configuration options used to initialize [`FlavaImageConfig`].
  349. multimodal_config (`dict`, *optional*):
  350. Dictionary of configuration options used to initialize [`FlavaMultimodalConfig`].
  351. hidden_size (`int`, *optional*, defaults to 768):
  352. Dimensionality of the encoder layers and the pooler layer.
  353. layer_norm_eps (`float`, *optional*, defaults to 1e-12):
  354. The epsilon used by the layer normalization layers.
  355. projection_dim (`int`, *optional*, defaults to 512):
  356. Dimensionality of text and image projection layers.
  357. logit_scale_init_value (`float`, *optional*, defaults to 2.6592):
  358. The initial value of the *logit_scale* parameter. Default is used as per the original FLAVA/CLIP
  359. implementation.
  360. initializer_range (`float`, *optional*, defaults to 0.02):
  361. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  362. ce_ignore_index (`int`, *optional*, defaults to -100):
  363. Cross entropy index to ignore.
  364. mim_weight (`float`, *optional*, defaults to 1.0):
  365. Weight to be assigned to MIM (Masked Image Modeling) unimodal loss
  366. mlm_weight (`float`, *optional*, defaults to 1.0):
  367. Weight to be assigned to MLM (Masked Language Modeling) unimodal loss
  368. global_contrastive_weight (`float`, *optional*, defaults to 1.0):
  369. Weight to be assigned to global contrastive cross-alignment loss.
  370. itm_weight (`float`, *optional*, defaults to 1.0):
  371. Weight to be assigned to image-text matching multimodal loss.
  372. mmm_image_weight (`float`, *optional*, defaults to 1.0):
  373. Weight to be assigned to MMM loss's image part.
  374. mmm_text_weight (`float`, *optional*, defaults to 1.0):
  375. Weight to be assigned to MMM loss's text part.
  376. global_backprop_contrastive (`bool`, *optional*, defaults to `True`):
  377. Whether to use global backpropgation through all workers in contrastive loss.
  378. skip_unmasked_multimodal_encoder (`bool`, *optional*, defaults to `True`):
  379. Whether to skip running unmasked multimodal encoder whose outputs are not used by FLAVA losses.
  380. return_loss (`bool`, *optional*, defaults to `True`):
  381. Whether to return loss or not
  382. kwargs (*optional*):
  383. Dictionary of keyword arguments.
  384. Example:
  385. ```python
  386. >>> from transformers import FlavaConfig, FlavaModel, FlavaForPreTraining
  387. >>> # Initializing a FlavaConfig with style configuration
  388. >>> configuration = FlavaConfig()
  389. >>> # Initializing a FlavaModel and FlavaForPreTraining model (with random weights) from the style configuration
  390. >>> model = FlavaModel(configuration)
  391. >>> model_pre = FlavaForPreTraining(configuration)
  392. >>> # Accessing the model configuration
  393. >>> configuration = model.config
  394. >>> configuration_pre = model_pre.config
  395. ```
  396. """
  397. model_type = "flava"
  398. sub_configs = {
  399. "text_config": FlavaTextConfig,
  400. "image_config": FlavaImageConfig,
  401. "multimodal_config": FlavaMultimodalConfig,
  402. "image_codebook_config": FlavaImageCodebookConfig,
  403. }
  404. def __init__(
  405. self,
  406. image_config: Optional[dict[str, Any]] = None,
  407. text_config: Optional[dict[str, Any]] = None,
  408. multimodal_config: Optional[dict[str, Any]] = None,
  409. image_codebook_config: Optional[dict[str, Any]] = None,
  410. hidden_size: int = 768,
  411. layer_norm_eps: float = 1e-12,
  412. projection_dim: int = 768,
  413. init_codebook: bool = True,
  414. logit_scale_init_value: float = 2.6592,
  415. initializer_range: float = 0.02,
  416. ce_ignore_index: int = -100,
  417. mim_weight: float = 1.0,
  418. mlm_weight: float = 1.0,
  419. global_contrastive_weight: float = 1.0,
  420. itm_weight: float = 1.0,
  421. mmm_image_weight: float = 1.0,
  422. mmm_text_weight: float = 1.0,
  423. global_backprop_contrastive: bool = True,
  424. skip_unmasked_multimodal_encoder: bool = True,
  425. return_loss: bool = True,
  426. **kwargs,
  427. ):
  428. # If `_config_dict` exist, we use them for the backward compatibility.
  429. # We pop out these 2 attributes before calling `super().__init__` to avoid them being saved (which causes a lot
  430. # of confusion!).
  431. text_config_dict = kwargs.pop("text_config_dict", None)
  432. image_config_dict = kwargs.pop("image_config_dict", None)
  433. multimodal_config_dict = kwargs.pop("multimodal_config_dict", None)
  434. image_codebook_config_dict = kwargs.pop("image_codebook_config_dict", None)
  435. super().__init__(**kwargs)
  436. # Instead of simply assigning `[text|vision]_config_dict` to `[text|vision]_config`, we use the values in
  437. # `[text|vision]_config_dict` to update the values in `[text|vision]_config`. The values should be same in most
  438. # cases, but we don't want to break anything regarding `_config_dict` that existed before commit `8827e1b2`.
  439. if text_config_dict is not None:
  440. if text_config is None:
  441. text_config = {}
  442. # This is the complete result when using `text_config_dict`.
  443. _text_config_dict = FlavaTextConfig(**text_config_dict).to_dict()
  444. # Give a warning if the values exist in both `_text_config_dict` and `text_config` but being different.
  445. for key, value in _text_config_dict.items():
  446. if key in text_config and value != text_config[key] and key != "transformers_version":
  447. # If specified in `text_config_dict`
  448. if key in text_config_dict:
  449. message = (
  450. f"`{key}` is found in both `text_config_dict` and `text_config` but with different values. "
  451. f'The value `text_config_dict["{key}"]` will be used instead.'
  452. )
  453. # If inferred from default argument values (just to be super careful)
  454. else:
  455. message = (
  456. f"`text_config_dict` is provided which will be used to initialize `FlavaTextConfig`. The "
  457. f'value `text_config["{key}"]` will be overridden.'
  458. )
  459. logger.info(message)
  460. # Update all values in `text_config` with the ones in `_text_config_dict`.
  461. text_config.update(_text_config_dict)
  462. if image_config_dict is not None:
  463. if image_config is None:
  464. image_config = {}
  465. # This is the complete result when using `image_config_dict`.
  466. _image_config_dict = FlavaImageConfig(**image_config_dict).to_dict()
  467. # convert keys to string instead of integer
  468. if "id2label" in _image_config_dict:
  469. _image_config_dict["id2label"] = {
  470. str(key): value for key, value in _image_config_dict["id2label"].items()
  471. }
  472. # Give a warning if the values exist in both `_image_config_dict` and `image_config` but being different.
  473. for key, value in _image_config_dict.items():
  474. if key in image_config and value != image_config[key] and key != "transformers_version":
  475. # If specified in `image_config_dict`
  476. if key in image_config_dict:
  477. message = (
  478. f"`{key}` is found in both `image_config_dict` and `image_config` but with different "
  479. f'values. The value `image_config_dict["{key}"]` will be used instead.'
  480. )
  481. # If inferred from default argument values (just to be super careful)
  482. else:
  483. message = (
  484. f"`image_config_dict` is provided which will be used to initialize `FlavaImageConfig`. "
  485. f'The value `image_config["{key}"]` will be overridden.'
  486. )
  487. logger.info(message)
  488. # Update all values in `image_config` with the ones in `_image_config_dict`.
  489. image_config.update(_image_config_dict)
  490. if multimodal_config_dict is not None:
  491. if multimodal_config is None:
  492. multimodal_config = {}
  493. # This is the complete result when using `multimodal_config_dict`.
  494. _multimodal_config_dict = FlavaMultimodalConfig(**multimodal_config_dict).to_dict()
  495. # Give a warning if the values exist in both `_multimodal_config_dict` and `multimodal_config` but being
  496. # different.
  497. for key, value in _multimodal_config_dict.items():
  498. if key in multimodal_config and value != multimodal_config[key] and key != "transformers_version":
  499. # If specified in `multimodal_config_dict`
  500. if key in multimodal_config_dict:
  501. message = (
  502. f"`{key}` is found in both `multimodal_config_dict` and `multimodal_config` but with "
  503. f'different values. The value `multimodal_config_dict["{key}"]` will be used instead.'
  504. )
  505. # If inferred from default argument values (just to be super careful)
  506. else:
  507. message = (
  508. f"`multimodal_config_dict` is provided which will be used to initialize "
  509. f'`FlavaMultimodalConfig`. The value `multimodal_config["{key}"]` will be overridden.'
  510. )
  511. logger.info(message)
  512. # Update all values in `multimodal_config` with the ones in `_multimodal_config_dict`.
  513. multimodal_config.update(_multimodal_config_dict)
  514. if image_codebook_config_dict is not None:
  515. if image_codebook_config is None:
  516. image_codebook_config = {}
  517. # This is the complete result when using `image_codebook_config_dict`.
  518. _image_codebook_config_dict = FlavaImageCodebookConfig(**image_codebook_config_dict).to_dict()
  519. # Give a warning if the values exist in both `_image_codebook_config_dict` and `image_codebook_config` but
  520. # being different.
  521. for key, value in _image_codebook_config_dict.items():
  522. if (
  523. key in image_codebook_config
  524. and value != image_codebook_config[key]
  525. and key != "transformers_version"
  526. ):
  527. # If specified in `image_codebook_config_dict`
  528. if key in image_codebook_config_dict:
  529. message = (
  530. f"`{key}` is found in both `image_codebook_config_dict` and `image_codebook_config` but "
  531. f'with different values. The value `image_codebook_config_dict["{key}"]` will be used '
  532. "instead."
  533. )
  534. # If inferred from default argument values (just to be super careful)
  535. else:
  536. message = (
  537. f"`image_codebook_config_dict` is provided which will be used to initialize "
  538. f'`FlavaImageCodebookConfig`. The value `image_codebook_config["{key}"]` will be overridden.'
  539. )
  540. logger.info(message)
  541. # Update all values in `image_codebook_config` with the ones in `_image_codebook_config_dict`.
  542. image_codebook_config.update(_image_codebook_config_dict)
  543. if image_config is None:
  544. image_config = {}
  545. logger.info("`image_config` is `None`. initializing the `FlavaImageConfig` with default values.")
  546. if text_config is None:
  547. text_config = {}
  548. logger.info("`text_config` is `None`. Initializing the `FlavaTextConfig` with default values.")
  549. if multimodal_config is None:
  550. multimodal_config = {}
  551. logger.info("`multimodal_config` is `None`. initializing the `FlavaMultimodalConfig` with default values.")
  552. if image_codebook_config is None:
  553. image_codebook_config = {}
  554. logger.info(
  555. "`image_codebook_config` is `None`. initializing the `FlavaImageCodebookConfig` with default values."
  556. )
  557. self.image_config = FlavaImageConfig(**image_config)
  558. self.text_config = FlavaTextConfig(**text_config)
  559. self.multimodal_config = FlavaMultimodalConfig(**multimodal_config)
  560. self.image_codebook_config = FlavaImageCodebookConfig(**image_codebook_config)
  561. self.projection_dim = projection_dim
  562. self.init_codebook = init_codebook
  563. self.hidden_size = hidden_size
  564. self.layer_norm_eps = layer_norm_eps
  565. self.initializer_range = initializer_range
  566. self.logit_scale_init_value = logit_scale_init_value
  567. self.initializer_factor = 1.0
  568. self.ce_ignore_index = ce_ignore_index
  569. self.mim_weight = mim_weight
  570. self.mlm_weight = mlm_weight
  571. self.global_contrastive_weight = global_contrastive_weight
  572. self.itm_weight = itm_weight
  573. self.mmm_image_weight = mmm_image_weight
  574. self.mmm_text_weight = mmm_text_weight
  575. self.global_backprop_contrastive = global_backprop_contrastive
  576. self.skip_unmasked_multimodal_encoder = skip_unmasked_multimodal_encoder
  577. self.return_loss = return_loss
  578. @classmethod
  579. def from_configs(
  580. cls,
  581. image_config: FlavaImageConfig,
  582. text_config: FlavaTextConfig,
  583. multimodal_config: FlavaMultimodalConfig,
  584. image_codebook_config: FlavaImageCodebookConfig,
  585. **kwargs,
  586. ):
  587. r"""
  588. Instantiate a [`FlavaConfig`] (or a derived class) from flava text model configuration, flava image model
  589. configuration, flava multimodal model and flava codebook model configuration.
  590. Returns:
  591. [`FlavaConfig`]: An instance of a configuration object
  592. """
  593. return cls(
  594. image_config=image_config.to_dict(),
  595. text_config=text_config.to_dict(),
  596. multimodal_config=multimodal_config.to_dict(),
  597. image_codebook_config=image_codebook_config.to_dict(),
  598. **kwargs,
  599. )
# Public API of this module: the composite FLAVA config and its four sub-configs.
__all__ = ["FlavaConfig", "FlavaImageCodebookConfig", "FlavaImageConfig", "FlavaMultimodalConfig", "FlavaTextConfig"]