modular_lightglue.py 50 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import warnings
  15. from dataclasses import dataclass
  16. from typing import Callable, Optional, Union
  17. import numpy as np
  18. import torch
  19. from torch import nn
  20. from torch.nn.utils.rnn import pad_sequence
  21. from ...configuration_utils import PretrainedConfig
  22. from ...image_utils import ImageInput, is_vision_available, to_numpy_array
  23. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  24. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  25. from ...processing_utils import Unpack
  26. from ...utils import ModelOutput, TensorType, auto_docstring, is_matplotlib_available, logging
  27. from ...utils.generic import can_return_tuple
  28. from ..auto import CONFIG_MAPPING, AutoConfig
  29. from ..auto.modeling_auto import AutoModelForKeypointDetection
  30. from ..clip.modeling_clip import CLIPMLP
  31. from ..cohere.modeling_cohere import apply_rotary_pos_emb
  32. from ..llama.modeling_llama import LlamaAttention, eager_attention_forward
  33. from ..superglue.image_processing_superglue import SuperGlueImageProcessor, validate_and_format_image_pairs
  34. from ..superpoint import SuperPointConfig
  35. if is_vision_available():
  36. from PIL import Image, ImageDraw
  37. logger = logging.get_logger(__name__)
  38. class LightGlueConfig(PretrainedConfig):
  39. r"""
  40. This is the configuration class to store the configuration of a [`LightGlueForKeypointMatching`]. It is used to
  41. instantiate a LightGlue model according to the specified arguments, defining the model architecture. Instantiating a
  42. configuration with the defaults will yield a similar configuration to that of the LightGlue
  43. [ETH-CVG/lightglue_superpoint](https://huggingface.co/ETH-CVG/lightglue_superpoint) architecture.
  44. Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
  45. documentation from [`PretrainedConfig`] for more information.
  46. Args:
  47. keypoint_detector_config (`Union[AutoConfig, dict]`, *optional*, defaults to `SuperPointConfig`):
  48. The config object or dictionary of the keypoint detector.
  49. descriptor_dim (`int`, *optional*, defaults to 256):
  50. The dimension of the descriptors.
  51. num_hidden_layers (`int`, *optional*, defaults to 9):
  52. The number of self and cross attention layers.
  53. num_attention_heads (`int`, *optional*, defaults to 4):
  54. The number of heads in the multi-head attention.
  55. num_key_value_heads (`int`, *optional*):
  56. This is the number of key_value heads that should be used to implement Grouped Query Attention. If
  57. `num_key_value_heads=num_attention_heads`, the model will use Multi Head Attention (MHA), if
  58. `num_key_value_heads=1` the model will use Multi Query Attention (MQA) otherwise GQA is used. When
  59. converting a multi-head checkpoint to a GQA checkpoint, each group key and value head should be constructed
  60. by meanpooling all the original heads within that group. For more details checkout [this
  61. paper](https://huggingface.co/papers/2305.13245). If it is not specified, will default to
  62. `num_attention_heads`.
  63. depth_confidence (`float`, *optional*, defaults to 0.95):
  64. The confidence threshold used to perform early stopping
  65. width_confidence (`float`, *optional*, defaults to 0.99):
  66. The confidence threshold used to prune points
  67. filter_threshold (`float`, *optional*, defaults to 0.1):
  68. The confidence threshold used to filter matches
  69. initializer_range (`float`, *optional*, defaults to 0.02):
  70. The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
  71. hidden_act (`str`, *optional*, defaults to `"gelu"`):
  72. The activation function to be used in the hidden layers.
  73. attention_dropout (`float`, *optional*, defaults to 0.0):
  74. The dropout ratio for the attention probabilities.
  75. attention_bias (`bool`, *optional*, defaults to `True`):
  76. Whether to use a bias in the query, key, value and output projection layers during self-attention.
  77. trust_remote_code (`bool`, *optional*, defaults to `False`):
  78. Whether to trust remote code when using other models than SuperPoint as keypoint detector.
  79. Examples:
  80. ```python
  81. >>> from transformers import LightGlueConfig, LightGlueForKeypointMatching
  82. >>> # Initializing a LightGlue style configuration
  83. >>> configuration = LightGlueConfig()
  84. >>> # Initializing a model from the LightGlue style configuration
  85. >>> model = LightGlueForKeypointMatching(configuration)
  86. >>> # Accessing the model configuration
  87. >>> configuration = model.config
  88. ```
  89. """
  90. model_type = "lightglue"
  91. sub_configs = {"keypoint_detector_config": AutoConfig}
  92. def __init__(
  93. self,
  94. keypoint_detector_config: SuperPointConfig = None,
  95. descriptor_dim: int = 256,
  96. num_hidden_layers: int = 9,
  97. num_attention_heads: int = 4,
  98. num_key_value_heads=None,
  99. depth_confidence: float = 0.95,
  100. width_confidence: float = 0.99,
  101. filter_threshold: float = 0.1,
  102. initializer_range: float = 0.02,
  103. hidden_act: str = "gelu",
  104. attention_dropout=0.0,
  105. attention_bias=True,
  106. trust_remote_code: bool = False,
  107. **kwargs,
  108. ):
  109. # LightGlue can be used with other models than SuperPoint as keypoint detector
  110. # We provide the trust_remote_code argument to allow the use of other models
  111. # that are not registered in the CONFIG_MAPPING dictionary (for example DISK)
  112. self.trust_remote_code = trust_remote_code
  113. if descriptor_dim % num_attention_heads != 0:
  114. raise ValueError("descriptor_dim % num_heads is different from zero")
  115. self.descriptor_dim = descriptor_dim
  116. self.num_hidden_layers = num_hidden_layers
  117. self.num_attention_heads = num_attention_heads
  118. # for backward compatibility
  119. if num_key_value_heads is None:
  120. num_key_value_heads = num_attention_heads
  121. self.num_key_value_heads = num_key_value_heads
  122. self.depth_confidence = depth_confidence
  123. self.width_confidence = width_confidence
  124. self.filter_threshold = filter_threshold
  125. self.initializer_range = initializer_range
  126. # Keypoint Detector is forced into eager attention mode because SuperPoint does not have Attention
  127. # See https://github.com/huggingface/transformers/pull/31718#discussion_r2109733153
  128. if isinstance(keypoint_detector_config, dict):
  129. keypoint_detector_config["model_type"] = keypoint_detector_config.get("model_type", "superpoint")
  130. if keypoint_detector_config["model_type"] not in CONFIG_MAPPING:
  131. keypoint_detector_config = AutoConfig.from_pretrained(
  132. keypoint_detector_config["_name_or_path"], trust_remote_code=self.trust_remote_code
  133. )
  134. else:
  135. keypoint_detector_config = CONFIG_MAPPING[keypoint_detector_config["model_type"]](
  136. **keypoint_detector_config, attn_implementation="eager"
  137. )
  138. if keypoint_detector_config is None:
  139. keypoint_detector_config = CONFIG_MAPPING["superpoint"](attn_implementation="eager")
  140. self.keypoint_detector_config = keypoint_detector_config
  141. self.hidden_size = descriptor_dim
  142. self.intermediate_size = descriptor_dim * 2
  143. self.hidden_act = hidden_act
  144. self.attention_dropout = attention_dropout
  145. self.attention_bias = attention_bias
  146. super().__init__(**kwargs)
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching,
    the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the
    batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask
    tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint
    matching information.
    """
)
class LightGlueKeypointMatchingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Loss computed during training.
    matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Index of keypoint matched in the other image.
    matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Scores of predicted matches.
    keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
        Absolute (x, y) coordinates of predicted keypoints in a given image.
    prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
        Pruning mask indicating which keypoints are removed and at which layer.
    mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`):
        Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching
        information.
    hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
        num_keypoints)` returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`
    attentions (`Tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
        num_keypoints)` returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`
    """

    # Every field defaults to None; the docstring above gives the documented shape of each one.
    loss: Optional[torch.FloatTensor] = None
    matches: Optional[torch.FloatTensor] = None
    matching_scores: Optional[torch.FloatTensor] = None
    keypoints: Optional[torch.FloatTensor] = None
    prune: Optional[torch.IntTensor] = None
    # NOTE(review): annotated as FloatTensor here but documented as `torch.BoolTensor` above - confirm which is right.
    mask: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
  189. class LightGlueImageProcessor(SuperGlueImageProcessor):
  190. def post_process_keypoint_matching(
  191. self,
  192. outputs: LightGlueKeypointMatchingOutput,
  193. target_sizes: Union[TensorType, list[tuple]],
  194. threshold: float = 0.0,
  195. ) -> list[dict[str, torch.Tensor]]:
  196. return super().post_process_keypoint_matching(outputs, target_sizes, threshold)
  197. # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor.visualize_keypoint_matching with EfficientLoFTR->LightGlue
  198. def visualize_keypoint_matching(
  199. self,
  200. images: ImageInput,
  201. keypoint_matching_output: list[dict[str, torch.Tensor]],
  202. ) -> list["Image.Image"]:
  203. """
  204. Plots the image pairs side by side with the detected keypoints as well as the matching between them.
  205. Args:
  206. images (`ImageInput`):
  207. Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2
  208. images or a list of list of 2 images list with pixel values ranging from 0 to 255.
  209. keypoint_matching_output (List[Dict[str, torch.Tensor]]]):
  210. A post processed keypoint matching output
  211. Returns:
  212. `List[PIL.Image.Image]`: A list of PIL images, each containing the image pairs side by side with the detected
  213. keypoints as well as the matching between them.
  214. """
  215. images = validate_and_format_image_pairs(images)
  216. images = [to_numpy_array(image) for image in images]
  217. image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
  218. results = []
  219. for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
  220. height0, width0 = image_pair[0].shape[:2]
  221. height1, width1 = image_pair[1].shape[:2]
  222. plot_image = np.zeros((max(height0, height1), width0 + width1, 3), dtype=np.uint8)
  223. plot_image[:height0, :width0] = image_pair[0]
  224. plot_image[:height1, width0:] = image_pair[1]
  225. plot_image_pil = Image.fromarray(plot_image)
  226. draw = ImageDraw.Draw(plot_image_pil)
  227. keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
  228. keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
  229. for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
  230. keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
  231. ):
  232. color = self._get_color(matching_score)
  233. draw.line(
  234. (keypoint0_x, keypoint0_y, keypoint1_x + width0, keypoint1_y),
  235. fill=color,
  236. width=3,
  237. )
  238. draw.ellipse((keypoint0_x - 2, keypoint0_y - 2, keypoint0_x + 2, keypoint0_y + 2), fill="black")
  239. draw.ellipse(
  240. (keypoint1_x + width0 - 2, keypoint1_y - 2, keypoint1_x + width0 + 2, keypoint1_y + 2),
  241. fill="black",
  242. )
  243. results.append(plot_image_pil)
  244. return results
  245. # Copied from transformers.models.efficientloftr.image_processing_efficientloftr.EfficientLoFTRImageProcessor._get_color
  246. def _get_color(self, score):
  247. """Maps a score to a color."""
  248. r = int(255 * (1 - score))
  249. g = int(255 * score)
  250. b = 0
  251. return (r, g, b)
  252. def plot_keypoint_matching(self, images: ImageInput, keypoint_matching_output: LightGlueKeypointMatchingOutput):
  253. """
  254. Plots the image pairs side by side with the detected keypoints as well as the matching between them. Requires
  255. matplotlib to be installed.
  256. .. deprecated::
  257. `plot_keypoint_matching` is deprecated and will be removed in a future version. Use `visualize_keypoint_matching` instead.
  258. Args:
  259. images (`ImageInput`):
  260. Image pairs to plot. Same as `LightGlueImageProcessor.preprocess`. Expects either a list of 2 images or
  261. a list of list of 2 images list with pixel values ranging from 0 to 255.
  262. keypoint_matching_output ([`LightGlueKeypointMatchingOutput`]):
  263. Raw outputs of the model.
  264. """
  265. warnings.warn(
  266. "`plot_keypoint_matching` is deprecated and will be removed in transformers v. "
  267. "Use `visualize_keypoint_matching` instead.",
  268. FutureWarning,
  269. )
  270. if is_matplotlib_available():
  271. import matplotlib.pyplot as plt
  272. else:
  273. raise ImportError("Please install matplotlib to use `plot_keypoint_matching` method")
  274. images = validate_and_format_image_pairs(images)
  275. images = [to_numpy_array(image) for image in images]
  276. image_pairs = [images[i : i + 2] for i in range(0, len(images), 2)]
  277. for image_pair, pair_output in zip(image_pairs, keypoint_matching_output):
  278. height0, width0 = image_pair[0].shape[:2]
  279. height1, width1 = image_pair[1].shape[:2]
  280. plot_image = np.zeros((max(height0, height1), width0 + width1, 3))
  281. plot_image[:height0, :width0] = image_pair[0] / 255.0
  282. plot_image[:height1, width0:] = image_pair[1] / 255.0
  283. plt.imshow(plot_image)
  284. plt.axis("off")
  285. keypoints0_x, keypoints0_y = pair_output["keypoints0"].unbind(1)
  286. keypoints1_x, keypoints1_y = pair_output["keypoints1"].unbind(1)
  287. for keypoint0_x, keypoint0_y, keypoint1_x, keypoint1_y, matching_score in zip(
  288. keypoints0_x, keypoints0_y, keypoints1_x, keypoints1_y, pair_output["matching_scores"]
  289. ):
  290. plt.plot(
  291. [keypoint0_x, keypoint1_x + width0],
  292. [keypoint0_y, keypoint1_y],
  293. color=plt.get_cmap("RdYlGn")(matching_score.item()),
  294. alpha=0.9,
  295. linewidth=0.5,
  296. )
  297. plt.scatter(keypoint0_x, keypoint0_y, c="black", s=2)
  298. plt.scatter(keypoint1_x + width0, keypoint1_y, c="black", s=2)
  299. plt.show()
  300. class LightGluePositionalEncoder(nn.Module):
  301. def __init__(self, config: LightGlueConfig):
  302. super().__init__()
  303. self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False)
  304. def forward(
  305. self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
  306. ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
  307. projected_keypoints = self.projector(keypoints)
  308. embeddings = projected_keypoints.repeat_interleave(2, dim=-1)
  309. cosines = torch.cos(embeddings)
  310. sines = torch.sin(embeddings)
  311. embeddings = (cosines, sines)
  312. output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,)
  313. return output
  314. class LightGlueAttention(LlamaAttention):
  315. def forward(
  316. self,
  317. hidden_states: torch.Tensor,
  318. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  319. attention_mask: Optional[torch.Tensor] = None,
  320. encoder_hidden_states: Optional[torch.Tensor] = None,
  321. encoder_attention_mask: Optional[torch.Tensor] = None,
  322. **kwargs: Unpack[FlashAttentionKwargs],
  323. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  324. input_shape = hidden_states.shape[:-1]
  325. hidden_shape = (*input_shape, -1, self.head_dim)
  326. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  327. is_cross_attention = encoder_hidden_states is not None
  328. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  329. current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
  330. key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2)
  331. value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2)
  332. if position_embeddings is not None:
  333. cos, sin = position_embeddings
  334. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  335. attention_interface: Callable = eager_attention_forward
  336. if self.config._attn_implementation != "eager":
  337. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  338. attn_output, attn_weights = attention_interface(
  339. self,
  340. query_states,
  341. key_states,
  342. value_states,
  343. current_attention_mask,
  344. dropout=0.0 if not self.training else self.attention_dropout,
  345. scaling=self.scaling,
  346. **kwargs,
  347. )
  348. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  349. attn_output = self.o_proj(attn_output)
  350. return attn_output, attn_weights
  351. class LightGlueMLP(CLIPMLP):
  352. def __init__(self, config: LightGlueConfig):
  353. super().__init__(config)
  354. self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size)
  355. self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True)
  356. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  357. hidden_states = self.fc1(hidden_states)
  358. hidden_states = self.layer_norm(hidden_states)
  359. hidden_states = self.activation_fn(hidden_states)
  360. hidden_states = self.fc2(hidden_states)
  361. return hidden_states
  362. class LightGlueTransformerLayer(nn.Module):
  363. def __init__(self, config: LightGlueConfig, layer_idx: int):
  364. super().__init__()
  365. self.self_attention = LightGlueAttention(config, layer_idx)
  366. self.self_mlp = LightGlueMLP(config)
  367. self.cross_attention = LightGlueAttention(config, layer_idx)
  368. self.cross_mlp = LightGlueMLP(config)
  369. def forward(
  370. self,
  371. descriptors: torch.Tensor,
  372. keypoints: torch.Tensor,
  373. attention_mask: torch.Tensor,
  374. output_hidden_states: Optional[bool] = False,
  375. output_attentions: Optional[bool] = False,
  376. ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]:
  377. all_hidden_states = () if output_hidden_states else None
  378. all_attentions = () if output_attentions else None
  379. if output_hidden_states:
  380. all_hidden_states = all_hidden_states + (descriptors,)
  381. batch_size, num_keypoints, descriptor_dim = descriptors.shape
  382. # Self attention block
  383. attention_output, self_attentions = self.self_attention(
  384. descriptors,
  385. position_embeddings=keypoints,
  386. attention_mask=attention_mask,
  387. output_attentions=output_attentions,
  388. )
  389. intermediate_states = torch.cat([descriptors, attention_output], dim=-1)
  390. output_states = self.self_mlp(intermediate_states)
  391. self_attention_descriptors = descriptors + output_states
  392. if output_hidden_states:
  393. self_attention_hidden_states = (intermediate_states, output_states)
  394. # Reshape hidden_states to group by image_pairs :
  395. # (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim)
  396. # Flip dimension 1 to perform cross attention :
  397. # (image0, image1) -> (image1, image0)
  398. # Reshape back to original shape :
  399. # (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim)
  400. encoder_hidden_states = (
  401. self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim)
  402. .flip(1)
  403. .reshape(batch_size, num_keypoints, descriptor_dim)
  404. )
  405. # Same for mask
  406. encoder_attention_mask = (
  407. attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints)
  408. if attention_mask is not None
  409. else None
  410. )
  411. # Cross attention block
  412. cross_attention_output, cross_attentions = self.cross_attention(
  413. self_attention_descriptors,
  414. encoder_hidden_states=encoder_hidden_states,
  415. encoder_attention_mask=encoder_attention_mask,
  416. output_attentions=output_attentions,
  417. )
  418. cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1)
  419. cross_output_states = self.cross_mlp(cross_intermediate_states)
  420. descriptors = self_attention_descriptors + cross_output_states
  421. if output_hidden_states:
  422. cross_attention_hidden_states = (cross_intermediate_states, cross_output_states)
  423. all_hidden_states = (
  424. all_hidden_states
  425. + (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
  426. + self_attention_hidden_states
  427. + (descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
  428. + cross_attention_hidden_states
  429. )
  430. if output_attentions:
  431. all_attentions = all_attentions + (self_attentions,) + (cross_attentions,)
  432. return descriptors, all_hidden_states, all_attentions
  433. def sigmoid_log_double_softmax(
  434. similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor
  435. ) -> torch.Tensor:
  436. """create the log assignment matrix from logits and similarity"""
  437. batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape
  438. certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2)
  439. scores0 = nn.functional.log_softmax(similarity, 2)
  440. scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
  441. scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0)
  442. scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties
  443. scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1))
  444. scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1))
  445. return scores
  446. class LightGlueMatchAssignmentLayer(nn.Module):
  447. def __init__(self, config: LightGlueConfig):
  448. super().__init__()
  449. self.descriptor_dim = config.descriptor_dim
  450. self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True)
  451. self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True)
  452. def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
  453. batch_size, num_keypoints, descriptor_dim = descriptors.shape
  454. # Final projection and similarity computation
  455. m_descriptors = self.final_projection(descriptors)
  456. m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25
  457. m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim)
  458. m_descriptors0 = m_descriptors[:, 0]
  459. m_descriptors1 = m_descriptors[:, 1]
  460. similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2)
  461. if mask is not None:
  462. mask = mask.reshape(batch_size // 2, 2, num_keypoints)
  463. mask0 = mask[:, 0].unsqueeze(-1)
  464. mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2)
  465. mask = mask0 * mask1
  466. similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min)
  467. # Compute matchability of descriptors
  468. matchability = self.matchability(descriptors)
  469. matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1)
  470. matchability_0 = matchability[:, 0]
  471. matchability_1 = matchability[:, 1]
  472. # Compute scores from similarity and matchability
  473. scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1)
  474. return scores
  475. def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor:
  476. """Get matchability of descriptors as a probability"""
  477. matchability = self.matchability(descriptors)
  478. matchability = nn.functional.sigmoid(matchability).squeeze(-1)
  479. return matchability
  480. class LightGlueTokenConfidenceLayer(nn.Module):
  481. def __init__(self, config: LightGlueConfig):
  482. super().__init__()
  483. self.token = nn.Linear(config.descriptor_dim, 1)
  484. def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
  485. token = self.token(descriptors.detach())
  486. token = nn.functional.sigmoid(token).squeeze(-1)
  487. return token
@auto_docstring
class LightGluePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    # Class-level settings; they are presumably read by the `PreTrainedModel` machinery, which is not visible in
    # this file - verify their exact meaning there.
    config: LightGlueConfig
    base_model_prefix = "lightglue"
    main_input_name = "pixel_values"
    # Gradient checkpointing is declared unsupported, flash-attention and SDPA as supported.
    supports_gradient_checkpointing = False
    _supports_flash_attn = True
    _supports_sdpa = True
  500. def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
  501. """obtain matches from a score matrix [Bx M+1 x N+1]"""
  502. batch_size, _, _ = scores.shape
  503. # For each keypoint, get the best match
  504. max0 = scores[:, :-1, :-1].max(2)
  505. max1 = scores[:, :-1, :-1].max(1)
  506. matches0 = max0.indices
  507. matches1 = max1.indices
  508. # Mutual check for matches
  509. indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None]
  510. indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None]
  511. mutual0 = indices0 == matches1.gather(1, matches0)
  512. mutual1 = indices1 == matches0.gather(1, matches1)
  513. # Get matching scores and filter based on mutual check and thresholding
  514. max0 = max0.values.exp()
  515. zero = max0.new_tensor(0)
  516. matching_scores0 = torch.where(mutual0, max0, zero)
  517. matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero)
  518. valid0 = mutual0 & (matching_scores0 > threshold)
  519. valid1 = mutual1 & valid0.gather(1, matches1)
  520. # Filter matches based on mutual check and thresholding of scores
  521. matches0 = torch.where(valid0, matches0, -1)
  522. matches1 = torch.where(valid1, matches1, -1)
  523. matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1)
  524. matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1)
  525. return matches, matching_scores
  526. def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
  527. """
  528. Normalize keypoints locations based on image image_shape
  529. Args:
  530. keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
  531. Keypoints locations in (x, y) format.
  532. height (`int`):
  533. Image height.
  534. width (`int`):
  535. Image width.
  536. Returns:
  537. Normalized keypoints locations of shape (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
  538. """
  539. size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
  540. shift = size / 2
  541. scale = size.max(-1).values / 2
  542. keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None]
  543. return keypoints
  544. @auto_docstring(
  545. custom_intro="""
  546. LightGlue model taking images as inputs and outputting the matching of them.
  547. """
  548. )
  549. class LightGlueForKeypointMatching(LightGluePreTrainedModel):
  550. """
  551. LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as
  552. SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient.
  553. It consists of :
  554. 1. Keypoint Encoder
  555. 2. A Graph Neural Network with self and cross attention layers
  556. 3. Matching Assignment layers
  557. The correspondence ids use -1 to indicate non-matching points.
  558. Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed.
  559. In ICCV 2023. https://huggingface.co/papers/2306.13643
  560. """
  561. def __init__(self, config: LightGlueConfig):
  562. super().__init__(config)
  563. self.keypoint_detector = AutoModelForKeypointDetection.from_config(
  564. config.keypoint_detector_config, trust_remote_code=config.trust_remote_code
  565. )
  566. self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
  567. self.descriptor_dim = config.descriptor_dim
  568. self.num_layers = config.num_hidden_layers
  569. self.filter_threshold = config.filter_threshold
  570. self.depth_confidence = config.depth_confidence
  571. self.width_confidence = config.width_confidence
  572. if self.descriptor_dim != self.keypoint_detector_descriptor_dim:
  573. self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True)
  574. else:
  575. self.input_projection = nn.Identity()
  576. self.positional_encoder = LightGluePositionalEncoder(config)
  577. self.transformer_layers = nn.ModuleList(
  578. [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
  579. )
  580. self.match_assignment_layers = nn.ModuleList(
  581. [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)]
  582. )
  583. self.token_confidence = nn.ModuleList(
  584. [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)]
  585. )
  586. self.post_init()
  587. def _get_confidence_threshold(self, layer_index: int) -> float:
  588. """scaled confidence threshold for a given layer"""
  589. threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers)
  590. return np.clip(threshold, 0, 1)
  591. def _keypoint_processing(
  592. self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
  593. ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
  594. descriptors = descriptors.detach().contiguous()
  595. projected_descriptors = self.input_projection(descriptors)
  596. keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states)
  597. return projected_descriptors, keypoint_encoding_output
  598. def _get_early_stopped_image_pairs(
  599. self, keypoint_confidences: torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor
  600. ) -> torch.Tensor:
  601. """evaluate whether we should stop inference based on the confidence of the keypoints"""
  602. batch_size, _ = mask.shape
  603. if layer_index < self.num_layers - 1:
  604. # If the current layer is not the last layer, we compute the confidence of the keypoints and check
  605. # if we should stop the forward pass through the transformer layers for each pair of images.
  606. keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1)
  607. keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1)
  608. threshold = self._get_confidence_threshold(layer_index)
  609. ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points
  610. early_stopped_pairs = ratio_confident > self.depth_confidence
  611. else:
  612. # If the current layer is the last layer, we stop the forward pass through the transformer layers for
  613. # all pairs of images.
  614. early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
  615. return early_stopped_pairs
  616. def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None):
  617. if early_stops is not None:
  618. descriptors = descriptors[early_stops]
  619. mask = mask[early_stops]
  620. scores = self.match_assignment_layers[layer_index](descriptors, mask)
  621. matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold)
  622. return matches, matching_scores
  623. def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor:
  624. """mask points which should be removed"""
  625. keep = scores > (1 - self.width_confidence)
  626. if confidences is not None: # Low-confidence points are never pruned.
  627. keep |= confidences <= self._get_confidence_threshold(layer_index)
  628. return keep
  629. def _do_layer_keypoint_pruning(
  630. self,
  631. descriptors: torch.Tensor,
  632. keypoints: torch.Tensor,
  633. mask: torch.Tensor,
  634. indices: torch.Tensor,
  635. prune_output: torch.Tensor,
  636. keypoint_confidences: torch.Tensor,
  637. layer_index: int,
  638. ):
  639. """
  640. For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the
  641. descriptors.
  642. """
  643. batch_size, _, _ = descriptors.shape
  644. descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors)
  645. pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index)
  646. pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False))
  647. # For each image, we extract the pruned indices and the corresponding descriptors and keypoints.
  648. pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = (
  649. [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)]
  650. for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices]
  651. )
  652. for i in range(batch_size):
  653. prune_output[i, pruned_indices[i]] += 1
  654. # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch.
  655. pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = (
  656. pad_sequence(pruned_tensor, batch_first=True)
  657. for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask]
  658. )
  659. pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1)
  660. pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1)
  661. return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output
  662. def _concat_early_stopped_outputs(
  663. self,
  664. early_stops_indices,
  665. final_pruned_keypoints_indices,
  666. final_pruned_keypoints_iterations,
  667. matches,
  668. matching_scores,
  669. ):
  670. early_stops_indices = torch.stack(early_stops_indices)
  671. # Rearrange tensors to have the same order as the input batch
  672. ids = torch.arange(early_stops_indices.shape[0])
  673. order_indices = early_stops_indices[ids]
  674. early_stops_indices = early_stops_indices[order_indices]
  675. matches, final_pruned_keypoints_indices = (
  676. pad_sequence(tensor, batch_first=True, padding_value=-1)
  677. for tensor in [matches, final_pruned_keypoints_indices]
  678. )
  679. matching_scores, final_pruned_keypoints_iterations = (
  680. pad_sequence(tensor, batch_first=True, padding_value=0)
  681. for tensor in [matching_scores, final_pruned_keypoints_iterations]
  682. )
  683. matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = (
  684. tensor[early_stops_indices]
  685. for tensor in [
  686. matches,
  687. matching_scores,
  688. final_pruned_keypoints_indices,
  689. final_pruned_keypoints_iterations,
  690. ]
  691. )
  692. return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores
  693. def _do_final_keypoint_pruning(
  694. self,
  695. indices: torch.Tensor,
  696. matches: torch.Tensor,
  697. matching_scores: torch.Tensor,
  698. num_keypoints: torch.Tensor,
  699. ) -> tuple[torch.Tensor, torch.Tensor]:
  700. # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to
  701. # have tensors from
  702. batch_size, _ = indices.shape
  703. indices, matches, matching_scores = (
  704. tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores]
  705. )
  706. indices0 = indices[:, 0]
  707. indices1 = indices[:, 1]
  708. matches0 = matches[:, 0]
  709. matches1 = matches[:, 1]
  710. matching_scores0 = matching_scores[:, 0]
  711. matching_scores1 = matching_scores[:, 1]
  712. # Prepare final matches and matching scores
  713. _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype)
  714. _matching_scores = torch.zeros(
  715. (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype
  716. )
  717. # Fill the matches and matching scores for each image pair
  718. for i in range(batch_size // 2):
  719. _matches[i, 0, indices0[i]] = torch.where(
  720. matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0))
  721. )
  722. _matches[i, 1, indices1[i]] = torch.where(
  723. matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0))
  724. )
  725. _matching_scores[i, 0, indices0[i]] = matching_scores0[i]
  726. _matching_scores[i, 1, indices1[i]] = matching_scores1[i]
  727. return _matches, _matching_scores
  728. def _match_image_pair(
  729. self,
  730. keypoints: torch.Tensor,
  731. descriptors: torch.Tensor,
  732. height: int,
  733. width: int,
  734. mask: Optional[torch.Tensor] = None,
  735. output_attentions: Optional[bool] = None,
  736. output_hidden_states: Optional[bool] = None,
  737. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]:
  738. all_hidden_states = () if output_hidden_states else None
  739. all_attentions = () if output_attentions else None
  740. if keypoints.shape[2] == 0: # no keypoints
  741. shape = keypoints.shape[:-1]
  742. return (
  743. keypoints.new_full(shape, -1, dtype=torch.int),
  744. keypoints.new_zeros(shape),
  745. keypoints.new_zeros(shape),
  746. all_hidden_states,
  747. all_attentions,
  748. )
  749. device = keypoints.device
  750. batch_size, _, initial_num_keypoints, _ = keypoints.shape
  751. num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1)
  752. # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
  753. keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
  754. mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None
  755. descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim)
  756. image_indices = torch.arange(batch_size * 2, device=device)
  757. # Keypoint normalization
  758. keypoints = normalize_keypoints(keypoints, height, width)
  759. descriptors, keypoint_encoding_output = self._keypoint_processing(
  760. descriptors, keypoints, output_hidden_states=output_hidden_states
  761. )
  762. keypoints = keypoint_encoding_output[0]
  763. # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the
  764. # keypoints is above a certain threshold.
  765. do_early_stop = self.depth_confidence > 0
  766. # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of
  767. # the keypoints is below a certain threshold.
  768. do_keypoint_pruning = self.width_confidence > 0
  769. early_stops_indices = []
  770. matches = []
  771. matching_scores = []
  772. final_pruned_keypoints_indices = []
  773. final_pruned_keypoints_iterations = []
  774. pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1)
  775. pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices)
  776. for layer_index in range(self.num_layers):
  777. input_shape = descriptors.size()
  778. if mask is not None:
  779. extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
  780. else:
  781. extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device)
  782. layer_output = self.transformer_layers[layer_index](
  783. descriptors,
  784. keypoints,
  785. attention_mask=extended_attention_mask,
  786. output_hidden_states=output_hidden_states,
  787. output_attentions=output_attentions,
  788. )
  789. descriptors, hidden_states, attention = layer_output
  790. if output_hidden_states:
  791. all_hidden_states = all_hidden_states + hidden_states
  792. if output_attentions:
  793. all_attentions = all_attentions + attention
  794. if do_early_stop:
  795. if layer_index < self.num_layers - 1:
  796. # Get the confidence of the keypoints for the current layer
  797. keypoint_confidences = self.token_confidence[layer_index](descriptors)
  798. # Determine which pairs of images should be early stopped based on the confidence of the keypoints for
  799. # the current layer.
  800. early_stopped_pairs = self._get_early_stopped_image_pairs(
  801. keypoint_confidences, layer_index, mask, num_points=num_points_per_pair
  802. )
  803. else:
  804. # Early stopping always occurs at the last layer
  805. early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
  806. if torch.any(early_stopped_pairs):
  807. # If a pair of images is considered early stopped, we compute the matches for the remaining
  808. # keypoints and stop the forward pass through the transformer layers for this pair of images.
  809. early_stops = early_stopped_pairs.repeat_interleave(2)
  810. early_stopped_image_indices = image_indices[early_stops]
  811. early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching(
  812. descriptors, mask, layer_index, early_stops=early_stops
  813. )
  814. early_stops_indices.extend(list(early_stopped_image_indices))
  815. matches.extend(list(early_stopped_matches))
  816. matching_scores.extend(list(early_stopped_matching_scores))
  817. if do_keypoint_pruning:
  818. final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops]))
  819. final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops]))
  820. # Remove image pairs that have been early stopped from the forward pass
  821. num_points_per_pair = num_points_per_pair[~early_stopped_pairs]
  822. descriptors, keypoints_0, keypoint_1, mask, image_indices = tuple(
  823. tensor[~early_stops]
  824. for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices]
  825. )
  826. keypoints = (keypoints_0, keypoint_1)
  827. if do_keypoint_pruning:
  828. pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple(
  829. tensor[~early_stops]
  830. for tensor in [
  831. pruned_keypoints_indices,
  832. pruned_keypoints_iterations,
  833. keypoint_confidences,
  834. ]
  835. )
  836. # If all pairs of images are early stopped, we stop the forward pass through the transformer
  837. # layers for all pairs of images.
  838. if torch.all(early_stopped_pairs):
  839. break
  840. if do_keypoint_pruning:
  841. # Prune keypoints from the input of the transformer layers for the next iterations if the confidence of
  842. # the keypoints is below a certain threshold.
  843. descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = (
  844. self._do_layer_keypoint_pruning(
  845. descriptors,
  846. keypoints,
  847. mask,
  848. pruned_keypoints_indices,
  849. pruned_keypoints_iterations,
  850. keypoint_confidences,
  851. layer_index,
  852. )
  853. )
  854. if do_early_stop and do_keypoint_pruning:
  855. # Concatenate early stopped outputs together and perform final keypoint pruning
  856. final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = (
  857. self._concat_early_stopped_outputs(
  858. early_stops_indices,
  859. final_pruned_keypoints_indices,
  860. final_pruned_keypoints_iterations,
  861. matches,
  862. matching_scores,
  863. )
  864. )
  865. matches, matching_scores = self._do_final_keypoint_pruning(
  866. final_pruned_keypoints_indices,
  867. matches,
  868. matching_scores,
  869. initial_num_keypoints,
  870. )
  871. else:
  872. matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1)
  873. final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers
  874. final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape(
  875. batch_size, 2, initial_num_keypoints
  876. )
  877. return (
  878. matches,
  879. matching_scores,
  880. final_pruned_keypoints_iterations,
  881. all_hidden_states,
  882. all_attentions,
  883. )
  884. @can_return_tuple
  885. @auto_docstring
  886. def forward(
  887. self,
  888. pixel_values: torch.FloatTensor,
  889. labels: Optional[torch.LongTensor] = None,
  890. output_attentions: Optional[bool] = None,
  891. output_hidden_states: Optional[bool] = None,
  892. ) -> Union[tuple, LightGlueKeypointMatchingOutput]:
  893. loss = None
  894. if labels is not None:
  895. raise ValueError("LightGlue is not trainable, no labels should be provided.")
  896. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  897. output_hidden_states = (
  898. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  899. )
  900. if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
  901. raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
  902. batch_size, _, channels, height, width = pixel_values.shape
  903. pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
  904. keypoint_detections = self.keypoint_detector(pixel_values)
  905. keypoints, _, descriptors, mask = keypoint_detections[:4]
  906. keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
  907. descriptors = descriptors.reshape(batch_size, 2, -1, self.keypoint_detector_descriptor_dim).to(pixel_values)
  908. mask = mask.reshape(batch_size, 2, -1)
  909. absolute_keypoints = keypoints.clone()
  910. absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width
  911. absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height
  912. matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair(
  913. absolute_keypoints,
  914. descriptors,
  915. height,
  916. width,
  917. mask=mask,
  918. output_attentions=output_attentions,
  919. output_hidden_states=output_hidden_states,
  920. )
  921. return LightGlueKeypointMatchingOutput(
  922. loss=loss,
  923. matches=matches,
  924. matching_scores=matching_scores,
  925. keypoints=keypoints,
  926. prune=prune,
  927. mask=mask,
  928. hidden_states=hidden_states,
  929. attentions=attentions,
  930. )
  931. __all__ = ["LightGluePreTrainedModel", "LightGlueForKeypointMatching", "LightGlueConfig", "LightGlueImageProcessor"]