configuration_efficientloftr.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import Optional
  15. from ...configuration_utils import PretrainedConfig
  16. from ...modeling_rope_utils import rope_config_validation
  17. class EfficientLoFTRConfig(PretrainedConfig):
  18. r"""
  19. This is the configuration class to store the configuration of a [`EfficientLoFTRFromKeypointMatching`].
  20. It is used to instantiate a EfficientLoFTR model according to the specified arguments, defining the model
  21. architecture. Instantiating a configuration with the defaults will yield a similar configuration to that of the
  22. EfficientLoFTR [zju-community/efficientloftr](https://huggingface.co/zju-community/efficientloftr) architecture.
  23. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  24. documentation from [`PretrainedConfig`] for more information.
  25. Args:
  26. stage_num_blocks (`List`, *optional*, defaults to [1, 2, 4, 14]):
  27. The number of blocks in each stages
  28. out_features (`List`, *optional*, defaults to [64, 64, 128, 256]):
  29. The number of channels in each stage
  30. stage_stride (`List`, *optional*, defaults to [2, 1, 2, 2]):
  31. The stride used in each stage
  32. hidden_size (`int`, *optional*, defaults to 256):
  33. The dimension of the descriptors.
  34. activation_function (`str`, *optional*, defaults to `"relu"`):
  35. The activation function used in the backbone
  36. q_aggregation_kernel_size (`int`, *optional*, defaults to 4):
  37. The kernel size of the aggregation of query states in the fusion network
  38. kv_aggregation_kernel_size (`int`, *optional*, defaults to 4):
  39. The kernel size of the aggregation of key and value states in the fusion network
  40. q_aggregation_stride (`int`, *optional*, defaults to 4):
  41. The stride of the aggregation of query states in the fusion network
  42. kv_aggregation_stride (`int`, *optional*, defaults to 4):
  43. The stride of the aggregation of key and value states in the fusion network
  44. num_attention_layers (`int`, *optional*, defaults to 4):
  45. Number of attention layers in the LocalFeatureTransformer
  46. num_attention_heads (`int`, *optional*, defaults to 8):
  47. The number of heads in the GNN layers.
  48. attention_dropout (`float`, *optional*, defaults to 0.0):
  49. The dropout ratio for the attention probabilities.
  50. attention_bias (`bool`, *optional*, defaults to `False`):
  51. Whether to use a bias in the query, key, value and output projection layers during attention.
  52. mlp_activation_function (`str`, *optional*, defaults to `"leaky_relu"`):
  53. Activation function used in the attention mlp layer.
  54. coarse_matching_skip_softmax (`bool`, *optional*, defaults to `False`):
  55. Whether to skip softmax or not at the coarse matching step.
  56. coarse_matching_threshold (`float`, *optional*, defaults to 0.2):
  57. The threshold for the minimum score required for a match.
  58. coarse_matching_temperature (`float`, *optional*, defaults to 0.1):
  59. The temperature to apply to the coarse similarity matrix
  60. coarse_matching_border_removal (`int`, *optional*, defaults to 2):
  61. The size of the border to remove during coarse matching
  62. fine_kernel_size (`int`, *optional*, defaults to 8):
  63. Kernel size used for the fine feature matching
  64. batch_norm_eps (`float`, *optional*, defaults to 1e-05):
  65. The epsilon used by the batch normalization layers.
  66. rope_theta (`float`, *optional*, defaults to 10000.0):
  67. The base period of the RoPE embeddings.
  68. partial_rotary_factor (`float`, *optional*, defaults to 4.0):
  69. Dim factor for the RoPE embeddings, in EfficientLoFTR, frequencies should be generated for
  70. the whole hidden_size, so this factor is used to compensate.
  71. rope_scaling (`Dict`, *optional*):
  72. Dictionary containing the scaling configuration for the RoPE embeddings. NOTE: if you apply new rope type
  73. and you expect the model to work on longer `max_position_embeddings`, we recommend you to update this value
  74. accordingly.
  75. Expected contents:
  76. `rope_type` (`str`):
  77. The sub-variant of RoPE to use. Can be one of ['default', 'linear', 'dynamic', 'yarn', 'longrope',
  78. 'llama3', '2d'], with 'default' being the original RoPE implementation.
  79. `dim` (`int`): The dimension of the RoPE embeddings.
  80. fine_matching_slice_dim (`int`, *optional*, defaults to 8):
  81. The size of the slice used to divide the fine features for the first and second fine matching stages.
  82. fine_matching_regress_temperature (`float`, *optional*, defaults to 10.0):
  83. The temperature to apply to the fine similarity matrix
  84. initializer_range (`float`, *optional*, defaults to 0.02):
  85. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  86. Examples:
  87. ```python
  88. >>> from transformers import EfficientLoFTRConfig, EfficientLoFTRForKeypointMatching
  89. >>> # Initializing a EfficientLoFTR configuration
  90. >>> configuration = EfficientLoFTRConfig()
  91. >>> # Initializing a model from the EfficientLoFTR configuration
  92. >>> model = EfficientLoFTRForKeypointMatching(configuration)
  93. >>> # Accessing the model configuration
  94. >>> configuration = model.config
  95. ```
  96. """
  97. model_type = "efficientloftr"
  98. def __init__(
  99. self,
  100. stage_num_blocks: Optional[list[int]] = None,
  101. out_features: Optional[list[int]] = None,
  102. stage_stride: Optional[list[int]] = None,
  103. hidden_size: int = 256,
  104. activation_function: str = "relu",
  105. q_aggregation_kernel_size: int = 4,
  106. kv_aggregation_kernel_size: int = 4,
  107. q_aggregation_stride: int = 4,
  108. kv_aggregation_stride: int = 4,
  109. num_attention_layers: int = 4,
  110. num_attention_heads: int = 8,
  111. attention_dropout: float = 0.0,
  112. attention_bias: bool = False,
  113. mlp_activation_function: str = "leaky_relu",
  114. coarse_matching_skip_softmax: bool = False,
  115. coarse_matching_threshold: float = 0.2,
  116. coarse_matching_temperature: float = 0.1,
  117. coarse_matching_border_removal: int = 2,
  118. fine_kernel_size: int = 8,
  119. batch_norm_eps: float = 1e-5,
  120. rope_theta: float = 10000.0,
  121. partial_rotary_factor: float = 4.0,
  122. rope_scaling: Optional[dict] = None,
  123. fine_matching_slice_dim: int = 8,
  124. fine_matching_regress_temperature: float = 10.0,
  125. initializer_range: float = 0.02,
  126. **kwargs,
  127. ):
  128. # Stage level of RepVGG
  129. self.stage_num_blocks = stage_num_blocks if stage_num_blocks is not None else [1, 2, 4, 14]
  130. self.stage_stride = stage_stride if stage_stride is not None else [2, 1, 2, 2]
  131. self.out_features = out_features if out_features is not None else [64, 64, 128, 256]
  132. self.stage_in_channels = [1] + self.out_features[:-1]
  133. # Block level of RepVGG
  134. self.stage_block_stride = [
  135. [stride] + [1] * (num_blocks - 1) for stride, num_blocks in zip(self.stage_stride, self.stage_num_blocks)
  136. ]
  137. self.stage_block_out_channels = [
  138. [self.out_features[stage_idx]] * num_blocks for stage_idx, num_blocks in enumerate(self.stage_num_blocks)
  139. ]
  140. self.stage_block_in_channels = [
  141. [self.stage_in_channels[stage_idx]] + self.stage_block_out_channels[stage_idx][:-1]
  142. for stage_idx in range(len(self.stage_num_blocks))
  143. ]
  144. # Fine matching level of EfficientLoFTR
  145. self.fine_fusion_dims = list(reversed(self.out_features))[:-1]
  146. self.hidden_size = hidden_size
  147. if self.hidden_size != self.out_features[-1]:
  148. raise ValueError(
  149. f"hidden_size should be equal to the last value in out_features. hidden_size = {self.hidden_size}, out_features = {self.out_features[-1]}"
  150. )
  151. self.activation_function = activation_function
  152. self.q_aggregation_kernel_size = q_aggregation_kernel_size
  153. self.kv_aggregation_kernel_size = kv_aggregation_kernel_size
  154. self.q_aggregation_stride = q_aggregation_stride
  155. self.kv_aggregation_stride = kv_aggregation_stride
  156. self.num_attention_layers = num_attention_layers
  157. self.num_attention_heads = num_attention_heads
  158. self.attention_dropout = attention_dropout
  159. self.attention_bias = attention_bias
  160. self.intermediate_size = self.hidden_size * 2
  161. self.mlp_activation_function = mlp_activation_function
  162. self.coarse_matching_skip_softmax = coarse_matching_skip_softmax
  163. self.coarse_matching_threshold = coarse_matching_threshold
  164. self.coarse_matching_temperature = coarse_matching_temperature
  165. self.coarse_matching_border_removal = coarse_matching_border_removal
  166. self.fine_kernel_size = fine_kernel_size
  167. self.batch_norm_eps = batch_norm_eps
  168. self.fine_matching_slice_dim = fine_matching_slice_dim
  169. self.fine_matching_regress_temperature = fine_matching_regress_temperature
  170. self.num_key_value_heads = num_attention_heads
  171. self.rope_theta = rope_theta
  172. self.rope_scaling = rope_scaling if rope_scaling is not None else {"rope_type": "default"}
  173. # for compatibility with "default" rope type
  174. self.partial_rotary_factor = partial_rotary_factor
  175. rope_config_validation(self)
  176. self.initializer_range = initializer_range
  177. super().__init__(**kwargs)
# Names exported by a star-import of this module: only the configuration class.
__all__ = ["EfficientLoFTRConfig"]