modeling_lightglue.py 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854855856857858859860861862863864865866867868869870871872873874875876877878879880881882883884885886887888889890891892893894895896897898899900901902903904905906907908909910911912913914915916917918919920
  1. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  2. # This file was automatically generated from src/transformers/models/lightglue/modular_lightglue.py.
  3. # Do NOT edit this file manually as any edits will be overwritten by the generation of
  4. # the file from the modular. If any change should be done, please apply the change to the
  5. # modular_lightglue.py file directly. One of our CI enforces this.
  6. # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
  7. # Copyright 2025 The HuggingFace Team. All rights reserved.
  8. #
  9. # Licensed under the Apache License, Version 2.0 (the "License");
  10. # you may not use this file except in compliance with the License.
  11. # You may obtain a copy of the License at
  12. #
  13. # http://www.apache.org/licenses/LICENSE-2.0
  14. #
  15. # Unless required by applicable law or agreed to in writing, software
  16. # distributed under the License is distributed on an "AS IS" BASIS,
  17. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  18. # See the License for the specific language governing permissions and
  19. # limitations under the License.
  20. from dataclasses import dataclass
  21. from typing import Callable, Optional, Union
  22. import numpy as np
  23. import torch
  24. from torch import nn
  25. from torch.nn.utils.rnn import pad_sequence
  26. from ...activations import ACT2FN
  27. from ...modeling_flash_attention_utils import FlashAttentionKwargs
  28. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  29. from ...processing_utils import Unpack
  30. from ...utils import ModelOutput, TransformersKwargs, auto_docstring
  31. from ...utils.deprecation import deprecate_kwarg
  32. from ...utils.generic import can_return_tuple
  33. from ..auto.modeling_auto import AutoModelForKeypointDetection
  34. from .configuration_lightglue import LightGlueConfig
@dataclass
@auto_docstring(
    custom_intro="""
Base class for outputs of LightGlue keypoint matching models. Due to the nature of keypoint detection and matching,
the number of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the
batch of images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask
tensor is used to indicate which values in the keypoints, matches, matching_scores and prune tensors are keypoint
matching information.
"""
)
class LightGlueKeypointMatchingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Loss computed during training.
    matches (`torch.LongTensor` of shape `(batch_size, 2, num_matches)`):
        Index of the keypoint matched in the other image of the pair, or -1 when the keypoint has no match.
    matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Scores of predicted matches (0 for keypoints without a mutual match).
    keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
        Absolute (x, y) coordinates of predicted keypoints in a given image.
    prune (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
        Pruning mask indicating which keypoints are removed and at which layer.
    mask (`torch.BoolTensor` of shape `(batch_size, num_keypoints)`):
        Mask indicating which values in matches, matching_scores, keypoints and prune are keypoint matching
        information.
    hidden_states (`Tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
        num_keypoints)` returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`
    attentions (`Tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
        num_keypoints)` returned when `output_attentions=True` is passed or when
        `config.output_attentions=True`
    """

    loss: Optional[torch.FloatTensor] = None
    # Integer keypoint indices into the other image of the pair; -1 marks an unmatched keypoint.
    matches: Optional[torch.LongTensor] = None
    matching_scores: Optional[torch.FloatTensor] = None
    keypoints: Optional[torch.FloatTensor] = None
    prune: Optional[torch.IntTensor] = None
    # NOTE(review): documented as a BoolTensor but annotated FloatTensor — confirm the dtype actually produced.
    mask: Optional[torch.FloatTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
  77. class LightGluePositionalEncoder(nn.Module):
  78. def __init__(self, config: LightGlueConfig):
  79. super().__init__()
  80. self.projector = nn.Linear(2, config.descriptor_dim // config.num_attention_heads // 2, bias=False)
  81. def forward(
  82. self, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
  83. ) -> Union[tuple[torch.Tensor], tuple[torch.Tensor, torch.Tensor]]:
  84. projected_keypoints = self.projector(keypoints)
  85. embeddings = projected_keypoints.repeat_interleave(2, dim=-1)
  86. cosines = torch.cos(embeddings)
  87. sines = torch.sin(embeddings)
  88. embeddings = (cosines, sines)
  89. output = (embeddings, projected_keypoints) if output_hidden_states else (embeddings,)
  90. return output
  91. def rotate_half(x):
  92. # Split and rotate. Note that this function is different from e.g. Llama.
  93. x1 = x[..., ::2]
  94. x2 = x[..., 1::2]
  95. rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
  96. return rot_x
  97. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  98. """Applies Rotary Position Embedding to the query and key tensors.
  99. Args:
  100. q (`torch.Tensor`): The query tensor.
  101. k (`torch.Tensor`): The key tensor.
  102. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  103. sin (`torch.Tensor`): The sine part of the rotary embedding.
  104. position_ids (`torch.Tensor`, *optional*):
  105. Deprecated and unused.
  106. unsqueeze_dim (`int`, *optional*, defaults to 1):
  107. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  108. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  109. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  110. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  111. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  112. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  113. Returns:
  114. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  115. """
  116. dtype = q.dtype
  117. q = q.float()
  118. k = k.float()
  119. cos = cos.unsqueeze(unsqueeze_dim)
  120. sin = sin.unsqueeze(unsqueeze_dim)
  121. q_embed = (q * cos) + (rotate_half(q) * sin)
  122. k_embed = (k * cos) + (rotate_half(k) * sin)
  123. return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
  124. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  125. """
  126. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  127. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  128. """
  129. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  130. if n_rep == 1:
  131. return hidden_states
  132. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  133. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  134. def eager_attention_forward(
  135. module: nn.Module,
  136. query: torch.Tensor,
  137. key: torch.Tensor,
  138. value: torch.Tensor,
  139. attention_mask: Optional[torch.Tensor],
  140. scaling: float,
  141. dropout: float = 0.0,
  142. **kwargs: Unpack[TransformersKwargs],
  143. ):
  144. key_states = repeat_kv(key, module.num_key_value_groups)
  145. value_states = repeat_kv(value, module.num_key_value_groups)
  146. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  147. if attention_mask is not None:
  148. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  149. attn_weights = attn_weights + causal_mask
  150. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  151. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  152. attn_output = torch.matmul(attn_weights, value_states)
  153. attn_output = attn_output.transpose(1, 2).contiguous()
  154. return attn_output, attn_weights
  155. class LightGlueAttention(nn.Module):
  156. """Multi-headed attention from 'Attention Is All You Need' paper"""
  157. def __init__(self, config: LightGlueConfig, layer_idx: int):
  158. super().__init__()
  159. self.config = config
  160. self.layer_idx = layer_idx
  161. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  162. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  163. self.scaling = self.head_dim**-0.5
  164. self.attention_dropout = config.attention_dropout
  165. self.is_causal = True
  166. self.q_proj = nn.Linear(
  167. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  168. )
  169. self.k_proj = nn.Linear(
  170. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  171. )
  172. self.v_proj = nn.Linear(
  173. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  174. )
  175. self.o_proj = nn.Linear(
  176. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  177. )
  178. @deprecate_kwarg("past_key_value", new_name="past_key_values", version="4.58")
  179. def forward(
  180. self,
  181. hidden_states: torch.Tensor,
  182. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  183. attention_mask: Optional[torch.Tensor] = None,
  184. encoder_hidden_states: Optional[torch.Tensor] = None,
  185. encoder_attention_mask: Optional[torch.Tensor] = None,
  186. **kwargs: Unpack[FlashAttentionKwargs],
  187. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  188. input_shape = hidden_states.shape[:-1]
  189. hidden_shape = (*input_shape, -1, self.head_dim)
  190. query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
  191. is_cross_attention = encoder_hidden_states is not None
  192. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  193. current_attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
  194. key_states = self.k_proj(current_states).view(hidden_shape).transpose(1, 2)
  195. value_states = self.v_proj(current_states).view(hidden_shape).transpose(1, 2)
  196. if position_embeddings is not None:
  197. cos, sin = position_embeddings
  198. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
  199. attention_interface: Callable = eager_attention_forward
  200. if self.config._attn_implementation != "eager":
  201. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  202. attn_output, attn_weights = attention_interface(
  203. self,
  204. query_states,
  205. key_states,
  206. value_states,
  207. current_attention_mask,
  208. dropout=0.0 if not self.training else self.attention_dropout,
  209. scaling=self.scaling,
  210. **kwargs,
  211. )
  212. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  213. attn_output = self.o_proj(attn_output)
  214. return attn_output, attn_weights
  215. class LightGlueMLP(nn.Module):
  216. def __init__(self, config: LightGlueConfig):
  217. super().__init__()
  218. self.config = config
  219. self.activation_fn = ACT2FN[config.hidden_act]
  220. self.fc1 = nn.Linear(config.intermediate_size, config.intermediate_size)
  221. self.fc2 = nn.Linear(config.intermediate_size, config.hidden_size)
  222. self.layer_norm = nn.LayerNorm(config.intermediate_size, elementwise_affine=True)
  223. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  224. hidden_states = self.fc1(hidden_states)
  225. hidden_states = self.layer_norm(hidden_states)
  226. hidden_states = self.activation_fn(hidden_states)
  227. hidden_states = self.fc2(hidden_states)
  228. return hidden_states
class LightGlueTransformerLayer(nn.Module):
    """One LightGlue layer: a self-attention update followed by a cross-attention update of the descriptors.

    Each update concatenates the descriptors with the attention output, passes the result through an MLP and adds
    the MLP output back to the descriptors. The batch must hold the two images of every pair as consecutive entries
    (`2 * i`, `2 * i + 1`), because cross attention pairs them up by reshaping to `(-1, 2, ...)`.
    """

    def __init__(self, config: LightGlueConfig, layer_idx: int):
        super().__init__()
        self.self_attention = LightGlueAttention(config, layer_idx)
        self.self_mlp = LightGlueMLP(config)
        self.cross_attention = LightGlueAttention(config, layer_idx)
        self.cross_mlp = LightGlueMLP(config)

    def forward(
        self,
        descriptors: torch.Tensor,
        keypoints: tuple[torch.Tensor, torch.Tensor],
        attention_mask: torch.Tensor,
        output_hidden_states: Optional[bool] = False,
        output_attentions: Optional[bool] = False,
    ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]:
        """
        Args:
            descriptors (`torch.Tensor` of shape `(batch_size, num_keypoints, descriptor_dim)`):
                Descriptors of all images; `batch_size` counts images (two per pair).
            keypoints (`tuple[torch.Tensor, torch.Tensor]`):
                Rotary `(cos, sin)` position embeddings, passed as `position_embeddings` to the self-attention only.
            attention_mask (`torch.Tensor`, *optional*):
                Mask for the self-attention. For cross attention it is regrouped by pair, flipped and reshaped to
                `(batch_size, 1, 1, num_keypoints)` (presumably it already has that shape — TODO confirm with the
                caller). May be `None`.
            output_hidden_states (`bool`, *optional*):
                Whether to collect intermediate hidden states.
            output_attentions (`bool`, *optional*):
                Whether to collect attention weights; also forwarded to the attention modules as a keyword argument.

        Returns:
            `tuple` of:
            - the updated descriptors;
            - `None`, or the hidden states in this order: input descriptors, descriptors after self-attention, the
              self-attention MLP input and output, descriptors after cross-attention, the cross-attention MLP input
              and output;
            - `None`, or `(self_attention_weights, cross_attention_weights)`.
        """
        all_hidden_states = () if output_hidden_states else None
        all_attentions = () if output_attentions else None

        if output_hidden_states:
            all_hidden_states = all_hidden_states + (descriptors,)

        batch_size, num_keypoints, descriptor_dim = descriptors.shape

        # Self attention block
        attention_output, self_attentions = self.self_attention(
            descriptors,
            position_embeddings=keypoints,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
        )
        # The MLP sees the descriptors concatenated with the attention output along the feature axis.
        intermediate_states = torch.cat([descriptors, attention_output], dim=-1)
        output_states = self.self_mlp(intermediate_states)
        # Residual update.
        self_attention_descriptors = descriptors + output_states

        if output_hidden_states:
            self_attention_hidden_states = (intermediate_states, output_states)

        # Reshape hidden_states to group by image_pairs :
        # (batch_size, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim)
        # Flip dimension 1 to perform cross attention :
        # (image0, image1) -> (image1, image0)
        # Reshape back to original shape :
        # (batch_size, 2, num_keypoints, descriptor_dim) -> (batch_size, num_keypoints, descriptor_dim)
        encoder_hidden_states = (
            self_attention_descriptors.reshape(-1, 2, num_keypoints, descriptor_dim)
            .flip(1)
            .reshape(batch_size, num_keypoints, descriptor_dim)
        )
        # Same for mask
        encoder_attention_mask = (
            attention_mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints)
            if attention_mask is not None
            else None
        )

        # Cross attention block (no position embeddings: positions of the other image are not comparable)
        cross_attention_output, cross_attentions = self.cross_attention(
            self_attention_descriptors,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            output_attentions=output_attentions,
        )
        cross_intermediate_states = torch.cat([self_attention_descriptors, cross_attention_output], dim=-1)
        cross_output_states = self.cross_mlp(cross_intermediate_states)
        descriptors = self_attention_descriptors + cross_output_states

        if output_hidden_states:
            cross_attention_hidden_states = (cross_intermediate_states, cross_output_states)
            # The reshapes below are presumably no-ops (the tensors already have this shape).
            all_hidden_states = (
                all_hidden_states
                + (self_attention_descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
                + self_attention_hidden_states
                + (descriptors.reshape(batch_size, num_keypoints, descriptor_dim),)
                + cross_attention_hidden_states
            )

        if output_attentions:
            all_attentions = all_attentions + (self_attentions,) + (cross_attentions,)

        return descriptors, all_hidden_states, all_attentions
  300. def sigmoid_log_double_softmax(
  301. similarity: torch.Tensor, matchability0: torch.Tensor, matchability1: torch.Tensor
  302. ) -> torch.Tensor:
  303. """create the log assignment matrix from logits and similarity"""
  304. batch_size, num_keypoints_0, num_keypoints_1 = similarity.shape
  305. certainties = nn.functional.logsigmoid(matchability0) + nn.functional.logsigmoid(matchability1).transpose(1, 2)
  306. scores0 = nn.functional.log_softmax(similarity, 2)
  307. scores1 = nn.functional.log_softmax(similarity.transpose(-1, -2).contiguous(), 2).transpose(-1, -2)
  308. scores = similarity.new_full((batch_size, num_keypoints_0 + 1, num_keypoints_1 + 1), 0)
  309. scores[:, :num_keypoints_0, :num_keypoints_1] = scores0 + scores1 + certainties
  310. scores[:, :-1, -1] = nn.functional.logsigmoid(-matchability0.squeeze(-1))
  311. scores[:, -1, :-1] = nn.functional.logsigmoid(-matchability1.squeeze(-1))
  312. return scores
  313. class LightGlueMatchAssignmentLayer(nn.Module):
  314. def __init__(self, config: LightGlueConfig):
  315. super().__init__()
  316. self.descriptor_dim = config.descriptor_dim
  317. self.final_projection = nn.Linear(self.descriptor_dim, self.descriptor_dim, bias=True)
  318. self.matchability = nn.Linear(self.descriptor_dim, 1, bias=True)
  319. def forward(self, descriptors: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
  320. batch_size, num_keypoints, descriptor_dim = descriptors.shape
  321. # Final projection and similarity computation
  322. m_descriptors = self.final_projection(descriptors)
  323. m_descriptors = m_descriptors / torch.tensor(self.descriptor_dim, device=m_descriptors.device) ** 0.25
  324. m_descriptors = m_descriptors.reshape(batch_size // 2, 2, num_keypoints, descriptor_dim)
  325. m_descriptors0 = m_descriptors[:, 0]
  326. m_descriptors1 = m_descriptors[:, 1]
  327. similarity = m_descriptors0 @ m_descriptors1.transpose(-1, -2)
  328. if mask is not None:
  329. mask = mask.reshape(batch_size // 2, 2, num_keypoints)
  330. mask0 = mask[:, 0].unsqueeze(-1)
  331. mask1 = mask[:, 1].unsqueeze(-1).transpose(-1, -2)
  332. mask = mask0 * mask1
  333. similarity = similarity.masked_fill(mask == 0, torch.finfo(similarity.dtype).min)
  334. # Compute matchability of descriptors
  335. matchability = self.matchability(descriptors)
  336. matchability = matchability.reshape(batch_size // 2, 2, num_keypoints, 1)
  337. matchability_0 = matchability[:, 0]
  338. matchability_1 = matchability[:, 1]
  339. # Compute scores from similarity and matchability
  340. scores = sigmoid_log_double_softmax(similarity, matchability_0, matchability_1)
  341. return scores
  342. def get_matchability(self, descriptors: torch.Tensor) -> torch.Tensor:
  343. """Get matchability of descriptors as a probability"""
  344. matchability = self.matchability(descriptors)
  345. matchability = nn.functional.sigmoid(matchability).squeeze(-1)
  346. return matchability
  347. class LightGlueTokenConfidenceLayer(nn.Module):
  348. def __init__(self, config: LightGlueConfig):
  349. super().__init__()
  350. self.token = nn.Linear(config.descriptor_dim, 1)
  351. def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
  352. token = self.token(descriptors.detach())
  353. token = nn.functional.sigmoid(token).squeeze(-1)
  354. return token
@auto_docstring
class LightGluePreTrainedModel(PreTrainedModel):
    """
    An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
    models.
    """

    config: LightGlueConfig
    base_model_prefix = "lightglue"
    # The models consume images; keypoints are produced internally by a keypoint detector.
    main_input_name = "pixel_values"
    supports_gradient_checkpointing = False
    # Non-eager attention backends that may be selected through `config._attn_implementation`.
    _supports_flash_attn = True
    _supports_sdpa = True
  367. def get_matches_from_scores(scores: torch.Tensor, threshold: float) -> tuple[torch.Tensor, torch.Tensor]:
  368. """obtain matches from a score matrix [Bx M+1 x N+1]"""
  369. batch_size, _, _ = scores.shape
  370. # For each keypoint, get the best match
  371. max0 = scores[:, :-1, :-1].max(2)
  372. max1 = scores[:, :-1, :-1].max(1)
  373. matches0 = max0.indices
  374. matches1 = max1.indices
  375. # Mutual check for matches
  376. indices0 = torch.arange(matches0.shape[1], device=matches0.device)[None]
  377. indices1 = torch.arange(matches1.shape[1], device=matches1.device)[None]
  378. mutual0 = indices0 == matches1.gather(1, matches0)
  379. mutual1 = indices1 == matches0.gather(1, matches1)
  380. # Get matching scores and filter based on mutual check and thresholding
  381. max0 = max0.values.exp()
  382. zero = max0.new_tensor(0)
  383. matching_scores0 = torch.where(mutual0, max0, zero)
  384. matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, matches1), zero)
  385. valid0 = mutual0 & (matching_scores0 > threshold)
  386. valid1 = mutual1 & valid0.gather(1, matches1)
  387. # Filter matches based on mutual check and thresholding of scores
  388. matches0 = torch.where(valid0, matches0, -1)
  389. matches1 = torch.where(valid1, matches1, -1)
  390. matches = torch.stack([matches0, matches1]).transpose(0, 1).reshape(batch_size * 2, -1)
  391. matching_scores = torch.stack([matching_scores0, matching_scores1]).transpose(0, 1).reshape(batch_size * 2, -1)
  392. return matches, matching_scores
  393. def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
  394. """
  395. Normalize keypoints locations based on image image_shape
  396. Args:
  397. keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
  398. Keypoints locations in (x, y) format.
  399. height (`int`):
  400. Image height.
  401. width (`int`):
  402. Image width.
  403. Returns:
  404. Normalized keypoints locations of shape (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
  405. """
  406. size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
  407. shift = size / 2
  408. scale = size.max(-1).values / 2
  409. keypoints = (keypoints - shift[..., None, :]) / scale[..., None, None]
  410. return keypoints
  411. @auto_docstring(
  412. custom_intro="""
  413. LightGlue model taking images as inputs and outputting the matching of them.
  414. """
  415. )
  416. class LightGlueForKeypointMatching(LightGluePreTrainedModel):
  417. """
  418. LightGlue is a model matching keypoints in images by leveraging detections from a keypoint detector such as
  419. SuperPoint. It is based on the SuperGlue architecture and is designed to be lightweight and efficient.
  420. It consists of :
  421. 1. Keypoint Encoder
  422. 2. A Graph Neural Network with self and cross attention layers
  423. 3. Matching Assignment layers
  424. The correspondence ids use -1 to indicate non-matching points.
  425. Philipp Lindenberger, Paul-Edouard Sarlin and Marc Pollefeys. LightGlue: Local Feature Matching at Light Speed.
  426. In ICCV 2023. https://huggingface.co/papers/2306.13643
  427. """
    def __init__(self, config: LightGlueConfig):
        super().__init__(config)

        # Keypoint detector (e.g. SuperPoint) built from its own sub-config; its descriptor size is read below.
        # NOTE(review): `trust_remote_code` comes from the config — with it enabled, building the detector may
        # execute remote modeling code, so only use trusted configs with it set.
        self.keypoint_detector = AutoModelForKeypointDetection.from_config(
            config.keypoint_detector_config, trust_remote_code=config.trust_remote_code
        )

        self.keypoint_detector_descriptor_dim = config.keypoint_detector_config.descriptor_decoder_dim
        self.descriptor_dim = config.descriptor_dim
        self.num_layers = config.num_hidden_layers
        # Minimum match probability for a mutual match to be kept.
        self.filter_threshold = config.filter_threshold
        # A pair stops early when its ratio of confident keypoints exceeds this value.
        self.depth_confidence = config.depth_confidence
        # Keypoints whose matchability exceeds `1 - width_confidence` survive pruning.
        self.width_confidence = config.width_confidence

        # Only learn a projection when the detector's descriptor size differs from LightGlue's.
        if self.descriptor_dim != self.keypoint_detector_descriptor_dim:
            self.input_projection = nn.Linear(self.keypoint_detector_descriptor_dim, self.descriptor_dim, bias=True)
        else:
            self.input_projection = nn.Identity()

        self.positional_encoder = LightGluePositionalEncoder(config)

        self.transformer_layers = nn.ModuleList(
            [LightGlueTransformerLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]
        )
        # One match-assignment head per layer, so matches can be read out after any layer.
        self.match_assignment_layers = nn.ModuleList(
            [LightGlueMatchAssignmentLayer(config) for _ in range(config.num_hidden_layers)]
        )
        # One token-confidence head per layer except the last (at the last layer every pair stops anyway).
        self.token_confidence = nn.ModuleList(
            [LightGlueTokenConfidenceLayer(config) for _ in range(config.num_hidden_layers - 1)]
        )

        self.post_init()
  454. def _get_confidence_threshold(self, layer_index: int) -> float:
  455. """scaled confidence threshold for a given layer"""
  456. threshold = 0.8 + 0.1 * np.exp(-4.0 * layer_index / self.num_layers)
  457. return np.clip(threshold, 0, 1)
  458. def _keypoint_processing(
  459. self, descriptors: torch.Tensor, keypoints: torch.Tensor, output_hidden_states: Optional[bool] = False
  460. ) -> tuple[torch.Tensor, tuple[torch.Tensor, torch.Tensor]]:
  461. descriptors = descriptors.detach().contiguous()
  462. projected_descriptors = self.input_projection(descriptors)
  463. keypoint_encoding_output = self.positional_encoder(keypoints, output_hidden_states=output_hidden_states)
  464. return projected_descriptors, keypoint_encoding_output
  465. def _get_early_stopped_image_pairs(
  466. self, keypoint_confidences: torch.Tensor, layer_index: int, mask: torch.Tensor, num_points: torch.Tensor
  467. ) -> torch.Tensor:
  468. """evaluate whether we should stop inference based on the confidence of the keypoints"""
  469. batch_size, _ = mask.shape
  470. if layer_index < self.num_layers - 1:
  471. # If the current layer is not the last layer, we compute the confidence of the keypoints and check
  472. # if we should stop the forward pass through the transformer layers for each pair of images.
  473. keypoint_confidences = keypoint_confidences.masked_fill(mask == 0, 1)
  474. keypoint_confidences = keypoint_confidences.reshape(batch_size // 2, -1)
  475. threshold = self._get_confidence_threshold(layer_index)
  476. ratio_confident = 1.0 - (keypoint_confidences < threshold).float().sum(dim=1) / num_points
  477. early_stopped_pairs = ratio_confident > self.depth_confidence
  478. else:
  479. # If the current layer is the last layer, we stop the forward pass through the transformer layers for
  480. # all pairs of images.
  481. early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
  482. return early_stopped_pairs
  483. def _get_keypoint_matching(self, descriptors, mask, layer_index, early_stops=None):
  484. if early_stops is not None:
  485. descriptors = descriptors[early_stops]
  486. mask = mask[early_stops]
  487. scores = self.match_assignment_layers[layer_index](descriptors, mask)
  488. matches, matching_scores = get_matches_from_scores(scores, self.filter_threshold)
  489. return matches, matching_scores
  490. def _get_pruning_mask(self, confidences: torch.Tensor, scores: torch.Tensor, layer_index: int) -> torch.Tensor:
  491. """mask points which should be removed"""
  492. keep = scores > (1 - self.width_confidence)
  493. if confidences is not None: # Low-confidence points are never pruned.
  494. keep |= confidences <= self._get_confidence_threshold(layer_index)
  495. return keep
    def _do_layer_keypoint_pruning(
        self,
        descriptors: torch.Tensor,
        keypoints: torch.Tensor,
        mask: torch.Tensor,
        indices: torch.Tensor,
        prune_output: torch.Tensor,
        keypoint_confidences: torch.Tensor,
        layer_index: int,
    ):
        """
        For a given layer, prune keypoints based on the confidence of the keypoints and the matchability of the
        descriptors.

        Args:
            descriptors (`torch.Tensor` of shape `(batch_size, num_keypoints, descriptor_dim)`):
                Current descriptors.
            keypoints:
                Indexed as `keypoints[0]` / `keypoints[1]`, each with one row per image; presumably the rotary
                `(cos, sin)` embeddings — TODO confirm with the caller.
            mask (`torch.Tensor` of shape `(batch_size, num_keypoints)`):
                Zero for masked-out entries, which are always dropped.
            indices (`torch.Tensor` of shape `(batch_size, num_keypoints)`):
                Original keypoint index of every current entry.
            prune_output (`torch.Tensor`):
                Per-image counters, incremented IN PLACE at the original index of every kept keypoint.
            keypoint_confidences (`torch.Tensor`):
                Per-keypoint confidences passed to `_get_pruning_mask`.
            layer_index (`int`):
                Layer whose match-assignment head provides the matchability.

        Returns:
            `tuple` of `(pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output)`. The kept
            entries of each image are right-padded to a common length with `pad_sequence`: indices are padded with
            -1, everything else with 0/False.
        """
        batch_size, _, _ = descriptors.shape
        descriptors_matchability = self.match_assignment_layers[layer_index].get_matchability(descriptors)
        # True = keep.
        pruned_keypoints_mask = self._get_pruning_mask(keypoint_confidences, descriptors_matchability, layer_index)
        pruned_keypoints_mask = pruned_keypoints_mask.masked_fill(mask == 0, torch.tensor(False))
        # For each image, we extract the pruned indices and the corresponding descriptors and keypoints.
        # (Inside the comprehension `mask` is one image's row of `pruned_keypoints_mask`, shadowing the argument.
        # Selecting the mask with itself yields an all-True row per image.)
        pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask, pruned_indices = (
            [t[mask] for t, mask in zip(tensor, pruned_keypoints_mask)]
            for tensor in [descriptors, keypoints[0], keypoints[1], pruned_keypoints_mask, indices]
        )
        # Count every kept keypoint at its original index (presumably, across layers, how many pruning steps it
        # survived).
        for i in range(batch_size):
            prune_output[i, pruned_indices[i]] += 1
        # Pad the pruned descriptors, keypoints, indices and mask to have the same shape across the batch.
        pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask = (
            pad_sequence(pruned_tensor, batch_first=True)
            for pruned_tensor in [pruned_descriptors, pruned_keypoints_0, pruned_keypoints_1, pruned_mask]
        )
        pruned_keypoints = (pruned_keypoints_0, pruned_keypoints_1)
        pruned_indices = pad_sequence(pruned_indices, batch_first=True, padding_value=-1)
        return pruned_descriptors, pruned_keypoints, pruned_indices, pruned_mask, prune_output
  529. def _concat_early_stopped_outputs(
  530. self,
  531. early_stops_indices,
  532. final_pruned_keypoints_indices,
  533. final_pruned_keypoints_iterations,
  534. matches,
  535. matching_scores,
  536. ):
  537. early_stops_indices = torch.stack(early_stops_indices)
  538. # Rearrange tensors to have the same order as the input batch
  539. ids = torch.arange(early_stops_indices.shape[0])
  540. order_indices = early_stops_indices[ids]
  541. early_stops_indices = early_stops_indices[order_indices]
  542. matches, final_pruned_keypoints_indices = (
  543. pad_sequence(tensor, batch_first=True, padding_value=-1)
  544. for tensor in [matches, final_pruned_keypoints_indices]
  545. )
  546. matching_scores, final_pruned_keypoints_iterations = (
  547. pad_sequence(tensor, batch_first=True, padding_value=0)
  548. for tensor in [matching_scores, final_pruned_keypoints_iterations]
  549. )
  550. matches, matching_scores, final_pruned_keypoints_indices, final_pruned_keypoints_iterations = (
  551. tensor[early_stops_indices]
  552. for tensor in [
  553. matches,
  554. matching_scores,
  555. final_pruned_keypoints_indices,
  556. final_pruned_keypoints_iterations,
  557. ]
  558. )
  559. return final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores
  560. def _do_final_keypoint_pruning(
  561. self,
  562. indices: torch.Tensor,
  563. matches: torch.Tensor,
  564. matching_scores: torch.Tensor,
  565. num_keypoints: torch.Tensor,
  566. ) -> tuple[torch.Tensor, torch.Tensor]:
  567. # (batch_size, num_keypoints) -> (batch_size // 2, 2, num_keypoints) -> 2 * (batch_size // 2, num_keypoints) to
  568. # have tensors from
  569. batch_size, _ = indices.shape
  570. indices, matches, matching_scores = (
  571. tensor.reshape(batch_size // 2, 2, -1) for tensor in [indices, matches, matching_scores]
  572. )
  573. indices0 = indices[:, 0]
  574. indices1 = indices[:, 1]
  575. matches0 = matches[:, 0]
  576. matches1 = matches[:, 1]
  577. matching_scores0 = matching_scores[:, 0]
  578. matching_scores1 = matching_scores[:, 1]
  579. # Prepare final matches and matching scores
  580. _matches = torch.full((batch_size // 2, 2, num_keypoints), -1, device=indices.device, dtype=matches.dtype)
  581. _matching_scores = torch.zeros(
  582. (batch_size // 2, 2, num_keypoints), device=indices.device, dtype=matching_scores.dtype
  583. )
  584. # Fill the matches and matching scores for each image pair
  585. for i in range(batch_size // 2):
  586. _matches[i, 0, indices0[i]] = torch.where(
  587. matches0[i] == -1, -1, indices1[i].gather(0, matches0[i].clamp(min=0))
  588. )
  589. _matches[i, 1, indices1[i]] = torch.where(
  590. matches1[i] == -1, -1, indices0[i].gather(0, matches1[i].clamp(min=0))
  591. )
  592. _matching_scores[i, 0, indices0[i]] = matching_scores0[i]
  593. _matching_scores[i, 1, indices1[i]] = matching_scores1[i]
  594. return _matches, _matching_scores
  595. def _match_image_pair(
  596. self,
  597. keypoints: torch.Tensor,
  598. descriptors: torch.Tensor,
  599. height: int,
  600. width: int,
  601. mask: Optional[torch.Tensor] = None,
  602. output_attentions: Optional[bool] = None,
  603. output_hidden_states: Optional[bool] = None,
  604. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, tuple, tuple]:
  605. all_hidden_states = () if output_hidden_states else None
  606. all_attentions = () if output_attentions else None
  607. if keypoints.shape[2] == 0: # no keypoints
  608. shape = keypoints.shape[:-1]
  609. return (
  610. keypoints.new_full(shape, -1, dtype=torch.int),
  611. keypoints.new_zeros(shape),
  612. keypoints.new_zeros(shape),
  613. all_hidden_states,
  614. all_attentions,
  615. )
  616. device = keypoints.device
  617. batch_size, _, initial_num_keypoints, _ = keypoints.shape
  618. num_points_per_pair = torch.sum(mask.reshape(batch_size, -1), dim=1)
  619. # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
  620. keypoints = keypoints.reshape(batch_size * 2, initial_num_keypoints, 2)
  621. mask = mask.reshape(batch_size * 2, initial_num_keypoints) if mask is not None else None
  622. descriptors = descriptors.reshape(batch_size * 2, initial_num_keypoints, self.keypoint_detector_descriptor_dim)
  623. image_indices = torch.arange(batch_size * 2, device=device)
  624. # Keypoint normalization
  625. keypoints = normalize_keypoints(keypoints, height, width)
  626. descriptors, keypoint_encoding_output = self._keypoint_processing(
  627. descriptors, keypoints, output_hidden_states=output_hidden_states
  628. )
  629. keypoints = keypoint_encoding_output[0]
  630. # Early stop consists of stopping the forward pass through the transformer layers when the confidence of the
  631. # keypoints is above a certain threshold.
  632. do_early_stop = self.depth_confidence > 0
  633. # Keypoint pruning consists of removing keypoints from the input of the transformer layers when the confidence of
  634. # the keypoints is below a certain threshold.
  635. do_keypoint_pruning = self.width_confidence > 0
  636. early_stops_indices = []
  637. matches = []
  638. matching_scores = []
  639. final_pruned_keypoints_indices = []
  640. final_pruned_keypoints_iterations = []
  641. pruned_keypoints_indices = torch.arange(0, initial_num_keypoints, device=device).expand(batch_size * 2, -1)
  642. pruned_keypoints_iterations = torch.ones_like(pruned_keypoints_indices)
  643. for layer_index in range(self.num_layers):
  644. input_shape = descriptors.size()
  645. if mask is not None:
  646. extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
  647. else:
  648. extended_attention_mask = torch.ones((batch_size, input_shape[-2]), device=keypoints.device)
  649. layer_output = self.transformer_layers[layer_index](
  650. descriptors,
  651. keypoints,
  652. attention_mask=extended_attention_mask,
  653. output_hidden_states=output_hidden_states,
  654. output_attentions=output_attentions,
  655. )
  656. descriptors, hidden_states, attention = layer_output
  657. if output_hidden_states:
  658. all_hidden_states = all_hidden_states + hidden_states
  659. if output_attentions:
  660. all_attentions = all_attentions + attention
  661. if do_early_stop:
  662. if layer_index < self.num_layers - 1:
  663. # Get the confidence of the keypoints for the current layer
  664. keypoint_confidences = self.token_confidence[layer_index](descriptors)
  665. # Determine which pairs of images should be early stopped based on the confidence of the keypoints for
  666. # the current layer.
  667. early_stopped_pairs = self._get_early_stopped_image_pairs(
  668. keypoint_confidences, layer_index, mask, num_points=num_points_per_pair
  669. )
  670. else:
  671. # Early stopping always occurs at the last layer
  672. early_stopped_pairs = torch.ones(batch_size, dtype=torch.bool)
  673. if torch.any(early_stopped_pairs):
  674. # If a pair of images is considered early stopped, we compute the matches for the remaining
  675. # keypoints and stop the forward pass through the transformer layers for this pair of images.
  676. early_stops = early_stopped_pairs.repeat_interleave(2)
  677. early_stopped_image_indices = image_indices[early_stops]
  678. early_stopped_matches, early_stopped_matching_scores = self._get_keypoint_matching(
  679. descriptors, mask, layer_index, early_stops=early_stops
  680. )
  681. early_stops_indices.extend(list(early_stopped_image_indices))
  682. matches.extend(list(early_stopped_matches))
  683. matching_scores.extend(list(early_stopped_matching_scores))
  684. if do_keypoint_pruning:
  685. final_pruned_keypoints_indices.extend(list(pruned_keypoints_indices[early_stops]))
  686. final_pruned_keypoints_iterations.extend(list(pruned_keypoints_iterations[early_stops]))
  687. # Remove image pairs that have been early stopped from the forward pass
  688. num_points_per_pair = num_points_per_pair[~early_stopped_pairs]
  689. descriptors, keypoints_0, keypoint_1, mask, image_indices = tuple(
  690. tensor[~early_stops]
  691. for tensor in [descriptors, keypoints[0], keypoints[1], mask, image_indices]
  692. )
  693. keypoints = (keypoints_0, keypoint_1)
  694. if do_keypoint_pruning:
  695. pruned_keypoints_indices, pruned_keypoints_iterations, keypoint_confidences = tuple(
  696. tensor[~early_stops]
  697. for tensor in [
  698. pruned_keypoints_indices,
  699. pruned_keypoints_iterations,
  700. keypoint_confidences,
  701. ]
  702. )
  703. # If all pairs of images are early stopped, we stop the forward pass through the transformer
  704. # layers for all pairs of images.
  705. if torch.all(early_stopped_pairs):
  706. break
  707. if do_keypoint_pruning:
  708. # Prune keypoints from the input of the transformer layers for the next iterations if the confidence of
  709. # the keypoints is below a certain threshold.
  710. descriptors, keypoints, pruned_keypoints_indices, mask, pruned_keypoints_iterations = (
  711. self._do_layer_keypoint_pruning(
  712. descriptors,
  713. keypoints,
  714. mask,
  715. pruned_keypoints_indices,
  716. pruned_keypoints_iterations,
  717. keypoint_confidences,
  718. layer_index,
  719. )
  720. )
  721. if do_early_stop and do_keypoint_pruning:
  722. # Concatenate early stopped outputs together and perform final keypoint pruning
  723. final_pruned_keypoints_indices, final_pruned_keypoints_iterations, matches, matching_scores = (
  724. self._concat_early_stopped_outputs(
  725. early_stops_indices,
  726. final_pruned_keypoints_indices,
  727. final_pruned_keypoints_iterations,
  728. matches,
  729. matching_scores,
  730. )
  731. )
  732. matches, matching_scores = self._do_final_keypoint_pruning(
  733. final_pruned_keypoints_indices,
  734. matches,
  735. matching_scores,
  736. initial_num_keypoints,
  737. )
  738. else:
  739. matches, matching_scores = self._get_keypoint_matching(descriptors, mask, self.num_layers - 1)
  740. final_pruned_keypoints_iterations = torch.ones_like(matching_scores) * self.num_layers
  741. final_pruned_keypoints_iterations = final_pruned_keypoints_iterations.reshape(
  742. batch_size, 2, initial_num_keypoints
  743. )
  744. return (
  745. matches,
  746. matching_scores,
  747. final_pruned_keypoints_iterations,
  748. all_hidden_states,
  749. all_attentions,
  750. )
  @can_return_tuple
  @auto_docstring
  def forward(
      self,
      pixel_values: torch.FloatTensor,
      labels: Optional[torch.LongTensor] = None,
      output_attentions: Optional[bool] = None,
      output_hidden_states: Optional[bool] = None,
  ) -> Union[tuple, LightGlueKeypointMatchingOutput]:
      # Detect keypoints in both images of every pair, then match them with `_match_image_pair`.
      # The model cannot be trained, so labels are rejected up front and no loss is ever produced.
      if labels is not None:
          raise ValueError("LightGlue is not trainable, no labels should be provided.")
      loss = None

      # Unspecified output flags fall back to the model config.
      if output_attentions is None:
          output_attentions = self.config.output_attentions
      if output_hidden_states is None:
          output_hidden_states = self.config.output_hidden_states

      # Expect a batch of image pairs: (batch_size, 2, num_channels, height, width).
      if not (pixel_values.ndim == 5 and pixel_values.size(1) == 2):
          raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
      batch_size, _, channels, height, width = pixel_values.shape

      # The detector sees every image of every pair as one flat batch.
      images = pixel_values.reshape(batch_size * 2, channels, height, width)
      keypoints, _, descriptors, mask = self.keypoint_detector(images)[:4]

      # Regroup detector outputs per pair and match the input's dtype/device.
      descriptor_dim = self.keypoint_detector_descriptor_dim
      keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(images)
      descriptors = descriptors.reshape(batch_size, 2, -1, descriptor_dim).to(images)
      mask = mask.reshape(batch_size, 2, -1)

      # Scale x / y keypoint coordinates by the image width / height for the matcher; the unscaled keypoints are
      # what the output reports.
      absolute_keypoints = keypoints.clone()
      absolute_keypoints[..., 0] *= width
      absolute_keypoints[..., 1] *= height

      matches, matching_scores, prune, hidden_states, attentions = self._match_image_pair(
          absolute_keypoints,
          descriptors,
          height,
          width,
          mask=mask,
          output_attentions=output_attentions,
          output_hidden_states=output_hidden_states,
      )

      return LightGlueKeypointMatchingOutput(
          loss=loss,
          matches=matches,
          matching_scores=matching_scores,
          keypoints=keypoints,
          prune=prune,
          mask=mask,
          hidden_states=hidden_states,
          attentions=attentions,
      )
  # Public names exported by this module.
  __all__ = [
      "LightGluePreTrainedModel",
      "LightGlueForKeypointMatching",
  ]