configuration_superpoint.py 4.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ...configuration_utils import PretrainedConfig
  15. from ...utils import logging
# Module-level logger, named after this module via the library's logging helper.
logger = logging.get_logger(__name__)
  17. class SuperPointConfig(PretrainedConfig):
  18. r"""
  19. This is the configuration class to store the configuration of a [`SuperPointForKeypointDetection`]. It is used to instantiate a
  20. SuperPoint model according to the specified arguments, defining the model architecture. Instantiating a
  21. configuration with the defaults will yield a similar configuration to that of the SuperPoint
  22. [magic-leap-community/superpoint](https://huggingface.co/magic-leap-community/superpoint) architecture.
  23. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  24. documentation from [`PretrainedConfig`] for more information.
  25. Args:
  26. encoder_hidden_sizes (`List`, *optional*, defaults to `[64, 64, 128, 128]`):
  27. The number of channels in each convolutional layer in the encoder.
  28. decoder_hidden_size (`int`, *optional*, defaults to 256): The hidden size of the decoder.
  29. keypoint_decoder_dim (`int`, *optional*, defaults to 65): The output dimension of the keypoint decoder.
  30. descriptor_decoder_dim (`int`, *optional*, defaults to 256): The output dimension of the descriptor decoder.
  31. keypoint_threshold (`float`, *optional*, defaults to 0.005):
  32. The threshold to use for extracting keypoints.
  33. max_keypoints (`int`, *optional*, defaults to -1):
  34. The maximum number of keypoints to extract. If `-1`, will extract all keypoints.
  35. nms_radius (`int`, *optional*, defaults to 4):
  36. The radius for non-maximum suppression.
  37. border_removal_distance (`int`, *optional*, defaults to 4):
  38. The distance from the border to remove keypoints.
  39. initializer_range (`float`, *optional*, defaults to 0.02):
  40. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  41. Example:
  42. ```python
  43. >>> from transformers import SuperPointConfig, SuperPointForKeypointDetection
  44. >>> # Initializing a SuperPoint superpoint style configuration
  45. >>> configuration = SuperPointConfig()
  46. >>> # Initializing a model from the superpoint style configuration
  47. >>> model = SuperPointForKeypointDetection(configuration)
  48. >>> # Accessing the model configuration
  49. >>> configuration = model.config
  50. ```"""
  51. model_type = "superpoint"
  52. def __init__(
  53. self,
  54. encoder_hidden_sizes: list[int] = [64, 64, 128, 128],
  55. decoder_hidden_size: int = 256,
  56. keypoint_decoder_dim: int = 65,
  57. descriptor_decoder_dim: int = 256,
  58. keypoint_threshold: float = 0.005,
  59. max_keypoints: int = -1,
  60. nms_radius: int = 4,
  61. border_removal_distance: int = 4,
  62. initializer_range=0.02,
  63. **kwargs,
  64. ):
  65. self.encoder_hidden_sizes = encoder_hidden_sizes
  66. self.decoder_hidden_size = decoder_hidden_size
  67. self.keypoint_decoder_dim = keypoint_decoder_dim
  68. self.descriptor_decoder_dim = descriptor_decoder_dim
  69. self.keypoint_threshold = keypoint_threshold
  70. self.max_keypoints = max_keypoints
  71. self.nms_radius = nms_radius
  72. self.border_removal_distance = border_removal_distance
  73. self.initializer_range = initializer_range
  74. super().__init__(**kwargs)
# Public names exported by this module on `from ... import *`.
__all__ = ["SuperPointConfig"]