configuration_colpali.py 4.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102
  1. # coding=utf-8
  2. # Copyright 2024 The HuggingFace Inc. team.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. """ColPali model configuration"""
  16. import logging
  17. from copy import deepcopy
  18. from ...configuration_utils import PretrainedConfig
  19. from ..auto import CONFIG_MAPPING, AutoConfig
  20. logger = logging.getLogger(__name__)
  21. class ColPaliConfig(PretrainedConfig):
  22. r"""
  23. Configuration class to store the configuration of a [`ColPaliForRetrieval`]. It is used to instantiate an instance
  24. of `ColPaliForRetrieval` according to the specified arguments, defining the model architecture following the methodology
  25. from the "ColPali: Efficient Document Retrieval with Vision Language Models" paper.
  26. Creating a configuration with the default settings will result in a configuration where the VLM backbone is set to the
  27. default PaliGemma configuration, i.e the one from [vidore/colpali-v1.2](https://huggingface.co/vidore/colpali-v1.2).
  28. Note that contrarily to what the class name suggests (actually the name refers to the ColPali **methodology**), you can
  29. use a different VLM backbone model than PaliGemma by passing the corresponding VLM configuration to the class constructor.
  30. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  31. documentation from [`PretrainedConfig`] for more information.
  32. Args:
  33. vlm_config (`PretrainedConfig`, *optional*):
  34. Configuration of the VLM backbone model.
  35. text_config (`PretrainedConfig`, *optional*):
  36. Configuration of the text backbone model. Overrides the `text_config` attribute of the `vlm_config` if provided.
  37. embedding_dim (`int`, *optional*, defaults to 128):
  38. Dimension of the multi-vector embeddings produced by the model.
  39. Example:
  40. ```python
  41. from transformers.models.colpali import ColPaliConfig, ColPaliForRetrieval
  42. config = ColPaliConfig()
  43. model = ColPaliForRetrieval(config)
  44. ```
  45. """
  46. model_type = "colpali"
  47. sub_configs = {"vlm_config": PretrainedConfig, "text_config": AutoConfig}
  48. def __init__(
  49. self,
  50. vlm_config=None,
  51. text_config=None,
  52. embedding_dim: int = 128,
  53. **kwargs,
  54. ):
  55. if vlm_config is None:
  56. vlm_config = CONFIG_MAPPING["paligemma"]()
  57. logger.info(
  58. "`vlm_config` is `None`. Initializing `vlm_config` with the `PaliGemmaConfig` with default values."
  59. )
  60. elif isinstance(vlm_config, dict):
  61. vlm_config = deepcopy(vlm_config)
  62. if "model_type" not in vlm_config:
  63. raise KeyError(
  64. "The `model_type` key is missing in the `vlm_config` dictionary. Please provide the model type."
  65. )
  66. elif vlm_config["model_type"] not in CONFIG_MAPPING:
  67. raise ValueError(
  68. f"The model type `{vlm_config['model_type']}` is not supported. Please provide a valid model type."
  69. )
  70. vlm_config = CONFIG_MAPPING[vlm_config["model_type"]](**vlm_config)
  71. elif not isinstance(vlm_config, PretrainedConfig):
  72. raise TypeError(
  73. f"Invalid type for `vlm_config`. Expected `PretrainedConfig`, `dict`, or `None`, but got {type(vlm_config)}."
  74. )
  75. self.vlm_config = vlm_config
  76. self.text_config = text_config if text_config is not None else vlm_config.text_config
  77. if isinstance(self.text_config, dict):
  78. text_config["model_type"] = text_config.get("model_type", "gemma")
  79. self.text_config = CONFIG_MAPPING[text_config["model_type"]](**text_config)
  80. self.embedding_dim = embedding_dim
  81. super().__init__(**kwargs)
  82. __all__ = ["ColPaliConfig"]