configuration_superglue.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING, Optional
  15. from ...configuration_utils import PretrainedConfig
  16. from ...utils import logging
  17. from ..auto import CONFIG_MAPPING
  18. if TYPE_CHECKING:
  19. from ..superpoint import SuperPointConfig
# Module-level logger, named after this module, obtained from the library's logging utilities.
logger = logging.get_logger(__name__)
  21. class SuperGlueConfig(PretrainedConfig):
  22. r"""
  23. This is the configuration class to store the configuration of a [`SuperGlueModel`]. It is used to instantiate a
  24. SuperGlue model according to the specified arguments, defining the model architecture. Instantiating a
  25. configuration with the defaults will yield a similar configuration to that of the SuperGlue
  26. [magic-leap-community/superglue_indoor](https://huggingface.co/magic-leap-community/superglue_indoor) architecture.
  27. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  28. documentation from [`PretrainedConfig`] for more information.
  29. Args:
  30. keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`):
  31. The config object or dictionary of the keypoint detector.
  32. hidden_size (`int`, *optional*, defaults to 256):
  33. The dimension of the descriptors.
  34. keypoint_encoder_sizes (`list[int]`, *optional*, defaults to `[32, 64, 128, 256]`):
  35. The sizes of the keypoint encoder layers.
  36. gnn_layers_types (`list[str]`, *optional*, defaults to `['self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross', 'self', 'cross']`):
  37. The types of the GNN layers. Must be either 'self' or 'cross'.
  38. num_attention_heads (`int`, *optional*, defaults to 4):
  39. The number of heads in the GNN layers.
  40. sinkhorn_iterations (`int`, *optional*, defaults to 100):
  41. The number of Sinkhorn iterations.
  42. matching_threshold (`float`, *optional*, defaults to 0.0):
  43. The matching threshold.
  44. initializer_range (`float`, *optional*, defaults to 0.02):
  45. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  46. Examples:
  47. ```python
  48. >>> from transformers import SuperGlueConfig, SuperGlueModel
  49. >>> # Initializing a SuperGlue superglue style configuration
  50. >>> configuration = SuperGlueConfig()
  51. >>> # Initializing a model from the superglue style configuration
  52. >>> model = SuperGlueModel(configuration)
  53. >>> # Accessing the model configuration
  54. >>> configuration = model.config
  55. ```
  56. """
  57. model_type = "superglue"
  58. def __init__(
  59. self,
  60. keypoint_detector_config: "SuperPointConfig" = None,
  61. hidden_size: int = 256,
  62. keypoint_encoder_sizes: Optional[list[int]] = None,
  63. gnn_layers_types: Optional[list[str]] = None,
  64. num_attention_heads: int = 4,
  65. sinkhorn_iterations: int = 100,
  66. matching_threshold: float = 0.0,
  67. initializer_range: float = 0.02,
  68. **kwargs,
  69. ):
  70. self.gnn_layers_types = gnn_layers_types if gnn_layers_types is not None else ["self", "cross"] * 9
  71. # Check whether all gnn_layers_types are either 'self' or 'cross'
  72. if not all(layer_type in ["self", "cross"] for layer_type in self.gnn_layers_types):
  73. raise ValueError("All gnn_layers_types must be either 'self' or 'cross'")
  74. if hidden_size % num_attention_heads != 0:
  75. raise ValueError("hidden_size % num_attention_heads is different from zero")
  76. self.keypoint_encoder_sizes = (
  77. keypoint_encoder_sizes if keypoint_encoder_sizes is not None else [32, 64, 128, 256]
  78. )
  79. self.hidden_size = hidden_size
  80. self.keypoint_encoder_sizes = keypoint_encoder_sizes
  81. self.gnn_layers_types = gnn_layers_types
  82. self.num_attention_heads = num_attention_heads
  83. self.sinkhorn_iterations = sinkhorn_iterations
  84. self.matching_threshold = matching_threshold
  85. if isinstance(keypoint_detector_config, dict):
  86. keypoint_detector_config["model_type"] = keypoint_detector_config.get("model_type", "superpoint")
  87. keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]](
  88. **keypoint_detector_config
  89. )
  90. if keypoint_detector_config is None:
  91. keypoint_detector_config = CONFIG_MAPPING["superpoint"]()
  92. self.keypoint_detector_config = keypoint_detector_config
  93. self.initializer_range = initializer_range
  94. self.attention_probs_dropout_prob = 0
  95. self.is_decoder = False
  96. super().__init__(**kwargs)
  97. @property
  98. def sub_configs(self):
  99. return {"keypoint_detector_config": type(self.keypoint_detector_config)}
# Public names of this module: only the configuration class is listed.
__all__ = ["SuperGlueConfig"]