modeling_efficientloftr.py 57 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
7127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from dataclasses import dataclass
  15. from typing import Callable, Optional, Union
  16. import torch
  17. from torch import nn
  18. from ...activations import ACT2CLS, ACT2FN
  19. from ...modeling_layers import GradientCheckpointingLayer
  20. from ...modeling_outputs import BackboneOutput
  21. from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
  22. from ...modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
  23. from ...processing_utils import Unpack
  24. from ...pytorch_utils import compile_compatible_method_lru_cache
  25. from ...utils import (
  26. ModelOutput,
  27. TransformersKwargs,
  28. auto_docstring,
  29. can_return_tuple,
  30. torch_int,
  31. )
  32. from ...utils.generic import check_model_inputs
  33. from .configuration_efficientloftr import EfficientLoFTRConfig
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of keypoint matching models. Due to the nature of keypoint detection and matching, the number
    of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of
    images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is
    used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint matching
    information.
    """
)
class KeypointMatchingOutput(ModelOutput):
    r"""
    matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Index of keypoint matched in the other image.
    matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Scores of predicted matches.
    keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
        Absolute (x, y) coordinates of predicted keypoints in a given image.
    hidden_states (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
        num_keypoints)`, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`)
    attentions (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
        num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`)
    """

    # Every field defaults to None so an output can be built with only the entries that were computed.
    matches: Optional[torch.FloatTensor] = None
    matching_scores: Optional[torch.FloatTensor] = None
    keypoints: Optional[torch.FloatTensor] = None
    # Variable-length tuples (one tensor per stage / per layer), as stated in the docstring above.
    hidden_states: Optional[tuple[torch.FloatTensor, ...]] = None
    attentions: Optional[tuple[torch.FloatTensor, ...]] = None
  65. @compile_compatible_method_lru_cache(maxsize=32)
  66. def compute_embeddings(inv_freq: torch.Tensor, embed_height: int, embed_width: int, hidden_size: int) -> torch.Tensor:
  67. i_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype, device=inv_freq.device)
  68. j_indices = torch.ones(embed_height, embed_width, dtype=inv_freq.dtype, device=inv_freq.device)
  69. i_indices = i_indices.cumsum(0).unsqueeze(-1)
  70. j_indices = j_indices.cumsum(1).unsqueeze(-1)
  71. emb = torch.zeros(1, embed_height, embed_width, hidden_size // 2, dtype=inv_freq.dtype, device=inv_freq.device)
  72. emb[:, :, :, 0::2] = i_indices * inv_freq
  73. emb[:, :, :, 1::2] = j_indices * inv_freq
  74. return emb
  75. class EfficientLoFTRRotaryEmbedding(nn.Module):
  76. inv_freq: torch.Tensor # fix linting for `register_buffer`
  77. def __init__(self, config: EfficientLoFTRConfig, device=None):
  78. super().__init__()
  79. self.config = config
  80. self.rope_type = config.rope_scaling["rope_type"]
  81. self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
  82. inv_freq, _ = self.rope_init_fn(self.config, device)
  83. inv_freq_expanded = inv_freq[None, None, None, :].float().expand(1, 1, 1, -1)
  84. self.register_buffer("inv_freq", inv_freq_expanded, persistent=False)
  85. @torch.no_grad()
  86. def forward(
  87. self, x: torch.Tensor, position_ids: Optional[tuple[torch.LongTensor, torch.LongTensor]] = None
  88. ) -> tuple[torch.Tensor, torch.Tensor]:
  89. feats_height, feats_width = x.shape[-2:]
  90. embed_height = (feats_height - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
  91. embed_width = (feats_width - self.config.q_aggregation_kernel_size) // self.config.q_aggregation_stride + 1
  92. device_type = x.device.type if isinstance(x.device.type, str) and x.device.type != "mps" else "cpu"
  93. with torch.autocast(device_type=device_type, enabled=False): # Force float32
  94. emb = compute_embeddings(self.inv_freq, embed_height, embed_width, self.config.hidden_size)
  95. sin = emb.sin()
  96. cos = emb.cos()
  97. sin = sin.repeat_interleave(2, dim=-1)
  98. cos = cos.repeat_interleave(2, dim=-1)
  99. sin = sin.to(device=x.device, dtype=x.dtype)
  100. cos = cos.to(device=x.device, dtype=x.dtype)
  101. return cos, sin
  102. # Copied from transformers.models.rt_detr_v2.modeling_rt_detr_v2.RTDetrV2ConvNormLayer with RTDetrV2->EfficientLoFTR
  103. class EfficientLoFTRConvNormLayer(nn.Module):
  104. def __init__(self, config, in_channels, out_channels, kernel_size, stride, padding=None, activation=None):
  105. super().__init__()
  106. self.conv = nn.Conv2d(
  107. in_channels,
  108. out_channels,
  109. kernel_size,
  110. stride,
  111. padding=(kernel_size - 1) // 2 if padding is None else padding,
  112. bias=False,
  113. )
  114. self.norm = nn.BatchNorm2d(out_channels, config.batch_norm_eps)
  115. self.activation = nn.Identity() if activation is None else ACT2CLS[activation]()
  116. def forward(self, hidden_state):
  117. hidden_state = self.conv(hidden_state)
  118. hidden_state = self.norm(hidden_state)
  119. hidden_state = self.activation(hidden_state)
  120. return hidden_state
  121. class EfficientLoFTRRepVGGBlock(GradientCheckpointingLayer):
  122. """
  123. RepVGG architecture block introduced by the work "RepVGG: Making VGG-style ConvNets Great Again".
  124. """
  125. def __init__(self, config: EfficientLoFTRConfig, stage_idx: int, block_idx: int):
  126. super().__init__()
  127. in_channels = config.stage_block_in_channels[stage_idx][block_idx]
  128. out_channels = config.stage_block_out_channels[stage_idx][block_idx]
  129. stride = config.stage_block_stride[stage_idx][block_idx]
  130. activation = config.activation_function
  131. self.conv1 = EfficientLoFTRConvNormLayer(
  132. config, in_channels, out_channels, kernel_size=3, stride=stride, padding=1
  133. )
  134. self.conv2 = EfficientLoFTRConvNormLayer(
  135. config, in_channels, out_channels, kernel_size=1, stride=stride, padding=0
  136. )
  137. self.identity = nn.BatchNorm2d(in_channels) if in_channels == out_channels and stride == 1 else None
  138. self.activation = nn.Identity() if activation is None else ACT2FN[activation]
  139. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  140. if self.identity is not None:
  141. identity_out = self.identity(hidden_states)
  142. else:
  143. identity_out = 0
  144. hidden_states = self.conv1(hidden_states) + self.conv2(hidden_states) + identity_out
  145. hidden_states = self.activation(hidden_states)
  146. return hidden_states
  147. class EfficientLoFTRRepVGGStage(nn.Module):
  148. def __init__(self, config: EfficientLoFTRConfig, stage_idx: int):
  149. super().__init__()
  150. self.blocks = nn.ModuleList([])
  151. for block_idx in range(config.stage_num_blocks[stage_idx]):
  152. self.blocks.append(
  153. EfficientLoFTRRepVGGBlock(
  154. config,
  155. stage_idx,
  156. block_idx,
  157. )
  158. )
  159. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  160. for block in self.blocks:
  161. hidden_states = block(hidden_states)
  162. return hidden_states
  163. class EfficientLoFTRepVGG(nn.Module):
  164. def __init__(self, config: EfficientLoFTRConfig):
  165. super().__init__()
  166. self.stages = nn.ModuleList([])
  167. for stage_idx in range(len(config.stage_stride)):
  168. stage = EfficientLoFTRRepVGGStage(config, stage_idx)
  169. self.stages.append(stage)
  170. def forward(self, hidden_states: torch.Tensor) -> list[torch.Tensor]:
  171. outputs = []
  172. for stage in self.stages:
  173. hidden_states = stage(hidden_states)
  174. outputs.append(hidden_states)
  175. # Exclude first stage in outputs
  176. outputs = outputs[1:]
  177. return outputs
  178. class EfficientLoFTRAggregationLayer(nn.Module):
  179. def __init__(self, config: EfficientLoFTRConfig):
  180. super().__init__()
  181. hidden_size = config.hidden_size
  182. self.q_aggregation = nn.Conv2d(
  183. hidden_size,
  184. hidden_size,
  185. kernel_size=config.q_aggregation_kernel_size,
  186. padding=0,
  187. stride=config.q_aggregation_stride,
  188. bias=False,
  189. groups=hidden_size,
  190. )
  191. self.kv_aggregation = torch.nn.MaxPool2d(
  192. kernel_size=config.kv_aggregation_kernel_size, stride=config.kv_aggregation_stride
  193. )
  194. self.norm = nn.LayerNorm(hidden_size)
  195. def forward(
  196. self,
  197. hidden_states: torch.Tensor,
  198. encoder_hidden_states: Optional[torch.Tensor] = None,
  199. ) -> tuple[torch.Tensor, torch.Tensor]:
  200. query_states = hidden_states
  201. is_cross_attention = encoder_hidden_states is not None
  202. kv_states = encoder_hidden_states if is_cross_attention else hidden_states
  203. query_states = self.q_aggregation(query_states)
  204. kv_states = self.kv_aggregation(kv_states)
  205. query_states = query_states.permute(0, 2, 3, 1)
  206. kv_states = kv_states.permute(0, 2, 3, 1)
  207. hidden_states = self.norm(query_states)
  208. encoder_hidden_states = self.norm(kv_states)
  209. return hidden_states, encoder_hidden_states
  210. # Copied from transformers.models.cohere.modeling_cohere.rotate_half
  211. def rotate_half(x):
  212. # Split and rotate. Note that this function is different from e.g. Llama.
  213. x1 = x[..., ::2]
  214. x2 = x[..., 1::2]
  215. rot_x = torch.stack([-x2, x1], dim=-1).flatten(-2)
  216. return rot_x
  217. # Copied from transformers.models.cohere.modeling_cohere.apply_rotary_pos_emb
  218. def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
  219. """Applies Rotary Position Embedding to the query and key tensors.
  220. Args:
  221. q (`torch.Tensor`): The query tensor.
  222. k (`torch.Tensor`): The key tensor.
  223. cos (`torch.Tensor`): The cosine part of the rotary embedding.
  224. sin (`torch.Tensor`): The sine part of the rotary embedding.
  225. position_ids (`torch.Tensor`, *optional*):
  226. Deprecated and unused.
  227. unsqueeze_dim (`int`, *optional*, defaults to 1):
  228. The 'unsqueeze_dim' argument specifies the dimension along which to unsqueeze cos[position_ids] and
  229. sin[position_ids] so that they can be properly broadcasted to the dimensions of q and k. For example, note
  230. that cos[position_ids] and sin[position_ids] have the shape [batch_size, seq_len, head_dim]. Then, if q and
  231. k have the shape [batch_size, heads, seq_len, head_dim], then setting unsqueeze_dim=1 makes
  232. cos[position_ids] and sin[position_ids] broadcastable to the shapes of q and k. Similarly, if q and k have
  233. the shape [batch_size, seq_len, heads, head_dim], then set unsqueeze_dim=2.
  234. Returns:
  235. `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
  236. """
  237. dtype = q.dtype
  238. q = q.float()
  239. k = k.float()
  240. cos = cos.unsqueeze(unsqueeze_dim)
  241. sin = sin.unsqueeze(unsqueeze_dim)
  242. q_embed = (q * cos) + (rotate_half(q) * sin)
  243. k_embed = (k * cos) + (rotate_half(k) * sin)
  244. return q_embed.to(dtype=dtype), k_embed.to(dtype=dtype)
  245. # Copied from transformers.models.cohere.modeling_cohere.repeat_kv
  246. def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
  247. """
  248. This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
  249. num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
  250. """
  251. batch, num_key_value_heads, slen, head_dim = hidden_states.shape
  252. if n_rep == 1:
  253. return hidden_states
  254. hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
  255. return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
  256. # Copied from transformers.models.llama.modeling_llama.eager_attention_forward
  257. def eager_attention_forward(
  258. module: nn.Module,
  259. query: torch.Tensor,
  260. key: torch.Tensor,
  261. value: torch.Tensor,
  262. attention_mask: Optional[torch.Tensor],
  263. scaling: float,
  264. dropout: float = 0.0,
  265. **kwargs: Unpack[TransformersKwargs],
  266. ):
  267. key_states = repeat_kv(key, module.num_key_value_groups)
  268. value_states = repeat_kv(value, module.num_key_value_groups)
  269. attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
  270. if attention_mask is not None:
  271. causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
  272. attn_weights = attn_weights + causal_mask
  273. attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
  274. attn_weights = nn.functional.dropout(attn_weights, p=dropout, training=module.training)
  275. attn_output = torch.matmul(attn_weights, value_states)
  276. attn_output = attn_output.transpose(1, 2).contiguous()
  277. return attn_output, attn_weights
  278. class EfficientLoFTRAttention(nn.Module):
  279. """Multi-headed attention from 'Attention Is All You Need' paper"""
  280. def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
  281. super().__init__()
  282. self.config = config
  283. self.layer_idx = layer_idx
  284. self.head_dim = getattr(config, "head_dim", config.hidden_size // config.num_attention_heads)
  285. self.num_key_value_groups = config.num_attention_heads // config.num_key_value_heads
  286. self.scaling = self.head_dim**-0.5
  287. self.attention_dropout = config.attention_dropout
  288. self.is_causal = False
  289. self.q_proj = nn.Linear(
  290. config.hidden_size, config.num_attention_heads * self.head_dim, bias=config.attention_bias
  291. )
  292. self.k_proj = nn.Linear(
  293. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  294. )
  295. self.v_proj = nn.Linear(
  296. config.hidden_size, config.num_key_value_heads * self.head_dim, bias=config.attention_bias
  297. )
  298. self.o_proj = nn.Linear(
  299. config.num_attention_heads * self.head_dim, config.hidden_size, bias=config.attention_bias
  300. )
  301. def forward(
  302. self,
  303. hidden_states: torch.Tensor,
  304. encoder_hidden_states: Optional[torch.Tensor] = None,
  305. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  306. **kwargs: Unpack[TransformersKwargs],
  307. ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
  308. batch_size, seq_len, dim = hidden_states.shape
  309. input_shape = hidden_states.shape[:-1]
  310. query_states = self.q_proj(hidden_states).view(batch_size, seq_len, -1, dim)
  311. is_cross_attention = encoder_hidden_states is not None
  312. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  313. key_states = self.k_proj(current_states).view(batch_size, seq_len, -1, dim)
  314. value_states = self.v_proj(current_states).view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)
  315. if position_embeddings is not None:
  316. cos, sin = position_embeddings
  317. query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, unsqueeze_dim=2)
  318. query_states = query_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)
  319. key_states = key_states.view(batch_size, seq_len, -1, self.head_dim).transpose(1, 2)
  320. attention_interface: Callable = eager_attention_forward
  321. if self.config._attn_implementation != "eager":
  322. attention_interface = ALL_ATTENTION_FUNCTIONS[self.config._attn_implementation]
  323. attn_output, attn_weights = attention_interface(
  324. self,
  325. query_states,
  326. key_states,
  327. value_states,
  328. attention_mask=None,
  329. dropout=0.0 if not self.training else self.attention_dropout,
  330. scaling=self.scaling,
  331. **kwargs,
  332. )
  333. attn_output = attn_output.reshape(*input_shape, -1).contiguous()
  334. attn_output = self.o_proj(attn_output)
  335. return attn_output, attn_weights
  336. class EfficientLoFTRMLP(nn.Module):
  337. def __init__(self, config: EfficientLoFTRConfig):
  338. super().__init__()
  339. hidden_size = config.hidden_size
  340. intermediate_size = config.intermediate_size
  341. self.fc1 = nn.Linear(hidden_size * 2, intermediate_size, bias=False)
  342. self.activation = ACT2FN[config.mlp_activation_function]
  343. self.fc2 = nn.Linear(intermediate_size, hidden_size, bias=False)
  344. self.layer_norm = nn.LayerNorm(hidden_size)
  345. def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
  346. hidden_states = self.fc1(hidden_states)
  347. hidden_states = self.activation(hidden_states)
  348. hidden_states = self.fc2(hidden_states)
  349. hidden_states = self.layer_norm(hidden_states)
  350. return hidden_states
  351. class EfficientLoFTRAggregatedAttention(nn.Module):
  352. def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
  353. super().__init__()
  354. self.q_aggregation_kernel_size = config.q_aggregation_kernel_size
  355. self.aggregation = EfficientLoFTRAggregationLayer(config)
  356. self.attention = EfficientLoFTRAttention(config, layer_idx)
  357. self.mlp = EfficientLoFTRMLP(config)
  358. def forward(
  359. self,
  360. hidden_states: torch.Tensor,
  361. encoder_hidden_states: Optional[torch.Tensor] = None,
  362. position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None,
  363. **kwargs: Unpack[TransformersKwargs],
  364. ) -> torch.Tensor:
  365. batch_size, embed_dim, _, _ = hidden_states.shape
  366. # Aggregate features
  367. aggregated_hidden_states, aggregated_encoder_hidden_states = self.aggregation(
  368. hidden_states, encoder_hidden_states
  369. )
  370. _, aggregated_h, aggregated_w, _ = aggregated_hidden_states.shape
  371. # Multi-head attention
  372. aggregated_hidden_states = aggregated_hidden_states.reshape(batch_size, -1, embed_dim)
  373. aggregated_encoder_hidden_states = aggregated_encoder_hidden_states.reshape(batch_size, -1, embed_dim)
  374. attn_output, _ = self.attention(
  375. aggregated_hidden_states,
  376. aggregated_encoder_hidden_states,
  377. position_embeddings=position_embeddings,
  378. **kwargs,
  379. )
  380. # Upsample features
  381. # (batch_size, seq_len, embed_dim) -> (batch_size, embed_dim, h, w) with seq_len = h * w
  382. attn_output = attn_output.permute(0, 2, 1)
  383. attn_output = attn_output.reshape(batch_size, embed_dim, aggregated_h, aggregated_w)
  384. attn_output = torch.nn.functional.interpolate(
  385. attn_output, scale_factor=self.q_aggregation_kernel_size, mode="bilinear", align_corners=False
  386. )
  387. intermediate_states = torch.cat([hidden_states, attn_output], dim=1)
  388. intermediate_states = intermediate_states.permute(0, 2, 3, 1)
  389. output_states = self.mlp(intermediate_states)
  390. output_states = output_states.permute(0, 3, 1, 2)
  391. hidden_states = hidden_states + output_states
  392. return hidden_states
  393. class EfficientLoFTRLocalFeatureTransformerLayer(GradientCheckpointingLayer):
  394. def __init__(self, config: EfficientLoFTRConfig, layer_idx: int):
  395. super().__init__()
  396. self.self_attention = EfficientLoFTRAggregatedAttention(config, layer_idx)
  397. self.cross_attention = EfficientLoFTRAggregatedAttention(config, layer_idx)
  398. def forward(
  399. self,
  400. hidden_states: torch.Tensor,
  401. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  402. **kwargs: Unpack[TransformersKwargs],
  403. ) -> torch.Tensor:
  404. batch_size, _, embed_dim, height, width = hidden_states.shape
  405. hidden_states = hidden_states.reshape(-1, embed_dim, height, width)
  406. hidden_states = self.self_attention(hidden_states, position_embeddings=position_embeddings, **kwargs)
  407. ###
  408. # Implementation of a bug in the original implementation regarding the cross-attention
  409. # See : https://github.com/zju3dv/MatchAnything/issues/26
  410. hidden_states = hidden_states.reshape(-1, 2, embed_dim, height, width)
  411. features_0 = hidden_states[:, 0]
  412. features_1 = hidden_states[:, 1]
  413. features_0 = self.cross_attention(features_0, features_1, **kwargs)
  414. features_1 = self.cross_attention(features_1, features_0, **kwargs)
  415. hidden_states = torch.stack((features_0, features_1), dim=1)
  416. ###
  417. return hidden_states
  418. class EfficientLoFTRLocalFeatureTransformer(nn.Module):
  419. def __init__(self, config: EfficientLoFTRConfig):
  420. super().__init__()
  421. self.layers = nn.ModuleList(
  422. [
  423. EfficientLoFTRLocalFeatureTransformerLayer(config, layer_idx=i)
  424. for i in range(config.num_attention_layers)
  425. ]
  426. )
  427. def forward(
  428. self,
  429. hidden_states: torch.Tensor,
  430. position_embeddings: tuple[torch.Tensor, torch.Tensor],
  431. **kwargs: Unpack[TransformersKwargs],
  432. ) -> torch.Tensor:
  433. for layer in self.layers:
  434. hidden_states = layer(hidden_states, position_embeddings=position_embeddings, **kwargs)
  435. return hidden_states
  436. class EfficientLoFTROutConvBlock(nn.Module):
  437. def __init__(self, config: EfficientLoFTRConfig, hidden_size: int, intermediate_size: int):
  438. super().__init__()
  439. self.out_conv1 = nn.Conv2d(hidden_size, intermediate_size, kernel_size=1, stride=1, padding=0, bias=False)
  440. self.out_conv2 = nn.Conv2d(
  441. intermediate_size, intermediate_size, kernel_size=3, stride=1, padding=1, bias=False
  442. )
  443. self.batch_norm = nn.BatchNorm2d(intermediate_size)
  444. self.activation = ACT2CLS[config.mlp_activation_function]()
  445. self.out_conv3 = nn.Conv2d(intermediate_size, hidden_size, kernel_size=3, stride=1, padding=1, bias=False)
  446. def forward(self, hidden_states: torch.Tensor, residual_states: torch.Tensor) -> torch.Tensor:
  447. residual_states = self.out_conv1(residual_states)
  448. residual_states = residual_states + hidden_states
  449. residual_states = self.out_conv2(residual_states)
  450. residual_states = self.batch_norm(residual_states)
  451. residual_states = self.activation(residual_states)
  452. residual_states = self.out_conv3(residual_states)
  453. residual_states = nn.functional.interpolate(
  454. residual_states, scale_factor=2.0, mode="bilinear", align_corners=False
  455. )
  456. return residual_states
  457. class EfficientLoFTRFineFusionLayer(nn.Module):
  458. def __init__(self, config: EfficientLoFTRConfig):
  459. super().__init__()
  460. self.fine_kernel_size = config.fine_kernel_size
  461. fine_fusion_dims = config.fine_fusion_dims
  462. self.out_conv = nn.Conv2d(
  463. fine_fusion_dims[0], fine_fusion_dims[0], kernel_size=1, stride=1, padding=0, bias=False
  464. )
  465. self.out_conv_layers = nn.ModuleList()
  466. for i in range(1, len(fine_fusion_dims)):
  467. out_conv = EfficientLoFTROutConvBlock(config, fine_fusion_dims[i], fine_fusion_dims[i - 1])
  468. self.out_conv_layers.append(out_conv)
  469. def forward_pyramid(
  470. self,
  471. hidden_states: torch.Tensor,
  472. residual_states: list[torch.Tensor],
  473. ) -> torch.Tensor:
  474. hidden_states = self.out_conv(hidden_states)
  475. hidden_states = nn.functional.interpolate(
  476. hidden_states, scale_factor=2.0, mode="bilinear", align_corners=False
  477. )
  478. for i, layer in enumerate(self.out_conv_layers):
  479. hidden_states = layer(hidden_states, residual_states[i])
  480. return hidden_states
  481. def forward(
  482. self,
  483. coarse_features: torch.Tensor,
  484. residual_features: list[torch.Tensor],
  485. ) -> tuple[torch.Tensor, torch.Tensor]:
  486. """
  487. For each image pair, compute the fine features of pixels.
  488. In both images, compute a patch of fine features center cropped around each coarse pixel.
  489. In the first image, the feature patch is kernel_size large and long.
  490. In the second image, it is (kernel_size + 2) large and long.
  491. """
  492. batch_size, _, embed_dim, coarse_height, coarse_width = coarse_features.shape
  493. coarse_features = coarse_features.reshape(-1, embed_dim, coarse_height, coarse_width)
  494. residual_features = list(reversed(residual_features))
  495. # 1. Fine feature extraction
  496. fine_features = self.forward_pyramid(coarse_features, residual_features)
  497. _, fine_embed_dim, fine_height, fine_width = fine_features.shape
  498. fine_features = fine_features.reshape(batch_size, 2, fine_embed_dim, fine_height, fine_width)
  499. fine_features_0 = fine_features[:, 0]
  500. fine_features_1 = fine_features[:, 1]
  501. # 2. Unfold all local windows in crops
  502. stride = int(fine_height // coarse_height)
  503. fine_features_0 = nn.functional.unfold(
  504. fine_features_0, kernel_size=self.fine_kernel_size, stride=stride, padding=0
  505. )
  506. _, _, seq_len = fine_features_0.shape
  507. fine_features_0 = fine_features_0.reshape(batch_size, -1, self.fine_kernel_size**2, seq_len)
  508. fine_features_0 = fine_features_0.permute(0, 3, 2, 1)
  509. fine_features_1 = nn.functional.unfold(
  510. fine_features_1, kernel_size=self.fine_kernel_size + 2, stride=stride, padding=1
  511. )
  512. fine_features_1 = fine_features_1.reshape(batch_size, -1, (self.fine_kernel_size + 2) ** 2, seq_len)
  513. fine_features_1 = fine_features_1.permute(0, 3, 2, 1)
  514. return fine_features_0, fine_features_1
  515. @auto_docstring
  516. class EfficientLoFTRPreTrainedModel(PreTrainedModel):
  517. """
  518. An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
  519. models.
  520. """
  521. config_class = EfficientLoFTRConfig
  522. base_model_prefix = "efficientloftr"
  523. main_input_name = "pixel_values"
  524. supports_gradient_checkpointing = True
  525. _supports_flash_attn = True
  526. _supports_sdpa = True
  527. _can_record_outputs = {
  528. "hidden_states": EfficientLoFTRRepVGGBlock,
  529. "attentions": EfficientLoFTRAttention,
  530. }
  531. def _init_weights(self, module: nn.Module) -> None:
  532. """Initialize the weights"""
  533. if isinstance(module, (nn.Linear, nn.Conv2d, nn.Conv1d, nn.BatchNorm2d)):
  534. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  535. if module.bias is not None:
  536. module.bias.data.zero_()
  537. elif isinstance(module, nn.LayerNorm):
  538. module.bias.data.zero_()
  539. module.weight.data.fill_(1.0)
  540. # Copied from transformers.models.superpoint.modeling_superpoint.SuperPointPreTrainedModel.extract_one_channel_pixel_values with SuperPoint->EfficientLoFTR
  541. def extract_one_channel_pixel_values(self, pixel_values: torch.FloatTensor) -> torch.FloatTensor:
  542. """
  543. Assuming pixel_values has shape (batch_size, 3, height, width), and that all channels values are the same,
  544. extract the first channel value to get a tensor of shape (batch_size, 1, height, width) for EfficientLoFTR. This is
  545. a workaround for the issue discussed in :
  546. https://github.com/huggingface/transformers/pull/25786#issuecomment-1730176446
  547. Args:
  548. pixel_values: torch.FloatTensor of shape (batch_size, 3, height, width)
  549. Returns:
  550. pixel_values: torch.FloatTensor of shape (batch_size, 1, height, width)
  551. """
  552. return pixel_values[:, 0, :, :][:, None, :, :]
  553. @auto_docstring(
  554. custom_intro="""
  555. EfficientLoFTR model taking images as inputs and outputting the features of the images.
  556. """
  557. )
  558. class EfficientLoFTRModel(EfficientLoFTRPreTrainedModel):
  559. def __init__(self, config: EfficientLoFTRConfig):
  560. super().__init__(config)
  561. self.config = config
  562. self.backbone = EfficientLoFTRepVGG(config)
  563. self.local_feature_transformer = EfficientLoFTRLocalFeatureTransformer(config)
  564. self.rotary_emb = EfficientLoFTRRotaryEmbedding(config=config)
  565. self.post_init()
  566. @check_model_inputs()
  567. @auto_docstring
  568. def forward(
  569. self,
  570. pixel_values: torch.FloatTensor,
  571. labels: Optional[torch.LongTensor] = None,
  572. **kwargs: Unpack[TransformersKwargs],
  573. ) -> BackboneOutput:
  574. r"""
  575. Examples:
  576. ```python
  577. >>> from transformers import AutoImageProcessor, AutoModel
  578. >>> import torch
  579. >>> from PIL import Image
  580. >>> import requests
  581. >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
  582. >>> image1 = Image.open(requests.get(url, stream=True).raw)
  583. >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
  584. >>> image2 = Image.open(requests.get(url, stream=True).raw)
  585. >>> images = [image1, image2]
  586. >>> processor = AutoImageProcessor.from_pretrained("zju-community/efficient_loftr")
  587. >>> model = AutoModel.from_pretrained("zju-community/efficient_loftr")
  588. >>> with torch.no_grad():
  589. >>> inputs = processor(images, return_tensors="pt")
  590. >>> outputs = model(**inputs)
  591. ```"""
  592. if labels is not None:
  593. raise ValueError("EfficientLoFTR is not trainable, no labels should be provided.")
  594. if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
  595. raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
  596. batch_size, _, channels, height, width = pixel_values.shape
  597. pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
  598. pixel_values = self.extract_one_channel_pixel_values(pixel_values)
  599. # 1. Local Feature CNN
  600. features = self.backbone(pixel_values)
  601. # Last stage outputs are coarse outputs
  602. coarse_features = features[-1]
  603. # Rest is residual features used in EfficientLoFTRFineFusionLayer
  604. residual_features = features[:-1]
  605. coarse_embed_dim, coarse_height, coarse_width = coarse_features.shape[-3:]
  606. # 2. Coarse-level LoFTR module
  607. cos, sin = self.rotary_emb(coarse_features)
  608. cos = cos.expand(batch_size * 2, -1, -1, -1).reshape(batch_size * 2, -1, coarse_embed_dim)
  609. sin = sin.expand(batch_size * 2, -1, -1, -1).reshape(batch_size * 2, -1, coarse_embed_dim)
  610. position_embeddings = (cos, sin)
  611. coarse_features = coarse_features.reshape(batch_size, 2, coarse_embed_dim, coarse_height, coarse_width)
  612. coarse_features = self.local_feature_transformer(
  613. coarse_features, position_embeddings=position_embeddings, **kwargs
  614. )
  615. features = (coarse_features,) + tuple(residual_features)
  616. return BackboneOutput(feature_maps=features)
  617. def mask_border(tensor: torch.Tensor, border_margin: int, value: Union[bool, float, int]) -> torch.Tensor:
  618. """
  619. Mask a tensor border with a given value
  620. Args:
  621. tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
  622. The tensor to mask
  623. border_margin (`int`) :
  624. The size of the border
  625. value (`Union[bool, int, float]`):
  626. The value to place in the tensor's borders
  627. Returns:
  628. tensor (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
  629. The masked tensor
  630. """
  631. if border_margin <= 0:
  632. return tensor
  633. tensor[:, :border_margin] = value
  634. tensor[:, :, :border_margin] = value
  635. tensor[:, :, :, :border_margin] = value
  636. tensor[:, :, :, :, :border_margin] = value
  637. tensor[:, -border_margin:] = value
  638. tensor[:, :, -border_margin:] = value
  639. tensor[:, :, :, -border_margin:] = value
  640. tensor[:, :, :, :, -border_margin:] = value
  641. return tensor
  642. def create_meshgrid(
  643. height: Union[int, torch.Tensor],
  644. width: Union[int, torch.Tensor],
  645. normalized_coordinates: bool = False,
  646. device: Optional[torch.device] = None,
  647. dtype: Optional[torch.dtype] = None,
  648. ) -> torch.Tensor:
  649. """
  650. Copied from kornia library : kornia/kornia/utils/grid.py:26
  651. Generate a coordinate grid for an image.
  652. When the flag ``normalized_coordinates`` is set to True, the grid is
  653. normalized to be in the range :math:`[-1,1]` to be consistent with the pytorch
  654. function :py:func:`torch.nn.functional.grid_sample`.
  655. Args:
  656. height (`int`):
  657. The image height (rows).
  658. width (`int`):
  659. The image width (cols).
  660. normalized_coordinates (`bool`):
  661. Whether to normalize coordinates in the range :math:`[-1,1]` in order to be consistent with the
  662. PyTorch function :py:func:`torch.nn.functional.grid_sample`.
  663. device (`torch.device`):
  664. The device on which the grid will be generated.
  665. dtype (`torch.dtype`):
  666. The data type of the generated grid.
  667. Return:
  668. grid (`torch.Tensor` of shape `(1, height, width, 2)`):
  669. The grid tensor.
  670. Example:
  671. >>> create_meshgrid(2, 2)
  672. tensor([[[[-1., -1.],
  673. [ 1., -1.]],
  674. <BLANKLINE>
  675. [[-1., 1.],
  676. [ 1., 1.]]]])
  677. >>> create_meshgrid(2, 2, normalized_coordinates=False)
  678. tensor([[[[0., 0.],
  679. [1., 0.]],
  680. <BLANKLINE>
  681. [[0., 1.],
  682. [1., 1.]]]])
  683. """
  684. xs = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
  685. ys = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
  686. if normalized_coordinates:
  687. xs = (xs / (width - 1) - 0.5) * 2
  688. ys = (ys / (height - 1) - 0.5) * 2
  689. grid = torch.stack(torch.meshgrid(ys, xs, indexing="ij"), dim=-1)
  690. grid = grid.permute(1, 0, 2).unsqueeze(0)
  691. return grid
  692. def spatial_expectation2d(input: torch.Tensor, normalized_coordinates: bool = True) -> torch.Tensor:
  693. r"""
  694. Copied from kornia library : kornia/geometry/subpix/dsnt.py:76
  695. Compute the expectation of coordinate values using spatial probabilities.
  696. The input heatmap is assumed to represent a valid spatial probability distribution,
  697. which can be achieved using :func:`~kornia.geometry.subpixel.spatial_softmax2d`.
  698. Args:
  699. input (`torch.Tensor` of shape `(batch_size, embed_dim, height, width)`):
  700. The input tensor representing dense spatial probabilities.
  701. normalized_coordinates (`bool`):
  702. Whether to return the coordinates normalized in the range of :math:`[-1, 1]`. Otherwise, it will return
  703. the coordinates in the range of the input shape.
  704. Returns:
  705. output (`torch.Tensor` of shape `(batch_size, embed_dim, 2)`)
  706. Expected value of the 2D coordinates. Output order of the coordinates is (x, y).
  707. Examples:
  708. >>> heatmaps = torch.tensor([[[
  709. ... [0., 0., 0.],
  710. ... [0., 0., 0.],
  711. ... [0., 1., 0.]]]])
  712. >>> spatial_expectation2d(heatmaps, False)
  713. tensor([[[1., 2.]]])
  714. """
  715. batch_size, embed_dim, height, width = input.shape
  716. # Create coordinates grid.
  717. grid = create_meshgrid(height, width, normalized_coordinates, input.device)
  718. grid = grid.to(input.dtype)
  719. pos_x = grid[..., 0].reshape(-1)
  720. pos_y = grid[..., 1].reshape(-1)
  721. input_flat = input.view(batch_size, embed_dim, -1)
  722. # Compute the expectation of the coordinates.
  723. expected_y = torch.sum(pos_y * input_flat, -1, keepdim=True)
  724. expected_x = torch.sum(pos_x * input_flat, -1, keepdim=True)
  725. output = torch.cat([expected_x, expected_y], -1)
  726. return output.view(batch_size, embed_dim, 2)
  727. @auto_docstring(
  728. custom_intro="""
  729. EfficientLoFTR model taking images as inputs and outputting the matching of them.
  730. """
  731. )
  732. class EfficientLoFTRForKeypointMatching(EfficientLoFTRPreTrainedModel):
  733. """EfficientLoFTR dense image matcher
  734. Given two images, we determine the correspondences by:
  735. 1. Extracting coarse and fine features through a backbone
  736. 2. Transforming coarse features through self and cross attention
  737. 3. Matching coarse features to obtain coarse coordinates of matches
  738. 4. Obtaining full resolution fine features by fusing transformed and backbone coarse features
  739. 5. Refining the coarse matches using fine feature patches centered at each coarse match in a two-stage refinement
  740. Yifan Wang, Xingyi He, Sida Peng, Dongli Tan and Xiaowei Zhou.
  741. Efficient LoFTR: Semi-Dense Local Feature Matching with Sparse-Like Speed
  742. In CVPR, 2024. https://huggingface.co/papers/2403.04765
  743. """
    def __init__(self, config: EfficientLoFTRConfig):
        """
        Build the feature model and the fine-level fusion layer from `config`.

        Args:
            config (`EfficientLoFTRConfig`):
                Configuration passed to the parent class and to both sub-modules.
        """
        super().__init__(config)

        self.config = config
        # Called in `forward` to produce the coarse and residual feature maps
        self.efficientloftr = EfficientLoFTRModel(config)
        # Called in `forward` on the coarse and residual features to obtain the fine feature windows
        self.refinement_layer = EfficientLoFTRFineFusionLayer(config)

        self.post_init()
  750. def _get_matches_from_scores(self, scores: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
  751. """
  752. Based on a keypoint score matrix, compute the best keypoint matches between the first and second image.
  753. Since each image pair can have different number of matches, the matches are concatenated together for all pair
  754. in the batch and a batch_indices tensor is returned to specify which match belong to which element in the batch.
  755. Note:
  756. This step can be done as a postprocessing step, because does not involve any model weights/params.
  757. However, we keep it in the modeling code for consistency with other keypoint matching models AND for
  758. easier torch.compile/torch.export (all ops are in torch).
  759. Args:
  760. scores (`torch.Tensor` of shape `(batch_size, height_0, width_0, height_1, width_1)`):
  761. Scores of keypoints
  762. Returns:
  763. matched_indices (`torch.Tensor` of shape `(2, num_matches)`):
  764. Indices representing which pixel in the first image matches which pixel in the second image
  765. matching_scores (`torch.Tensor` of shape `(num_matches,)`):
  766. Scores of each match
  767. """
  768. batch_size, height0, width0, height1, width1 = scores.shape
  769. scores = scores.view(batch_size, height0 * width0, height1 * width1)
  770. # For each keypoint, get the best match
  771. max_0 = scores.max(2, keepdim=True).values
  772. max_1 = scores.max(1, keepdim=True).values
  773. # 1. Thresholding
  774. mask = scores > self.config.coarse_matching_threshold
  775. # 2. Border removal
  776. mask = mask.reshape(batch_size, height0, width0, height1, width1)
  777. mask = mask_border(mask, self.config.coarse_matching_border_removal, False)
  778. mask = mask.reshape(batch_size, height0 * width0, height1 * width1)
  779. # 3. Mutual nearest neighbors
  780. mask = mask * (scores == max_0) * (scores == max_1)
  781. # 4. Fine coarse matches
  782. masked_scores = scores * mask
  783. matching_scores_0, max_indices_0 = masked_scores.max(1)
  784. matching_scores_1, max_indices_1 = masked_scores.max(2)
  785. matching_indices = torch.cat([max_indices_0, max_indices_1]).reshape(batch_size, 2, -1)
  786. matching_scores = torch.stack([matching_scores_0, matching_scores_1], dim=1)
  787. # For the keypoints not meeting the threshold score, set the indices to -1 which corresponds to no matches found
  788. matching_indices = torch.where(matching_scores > 0, matching_indices, -1)
  789. return matching_indices, matching_scores
  790. def _coarse_matching(
  791. self, coarse_features: torch.Tensor, coarse_scale: float
  792. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  793. """
  794. For each image pair, compute the matching confidence between each coarse element (by default (image_height / 8)
  795. * (image_width / 8 elements)) from the first image to the second image.
  796. Note:
  797. This step can be done as a postprocessing step, because does not involve any model weights/params.
  798. However, we keep it in the modeling code for consistency with other keypoint matching models AND for
  799. easier torch.compile/torch.export (all ops are in torch).
  800. Args:
  801. coarse_features (`torch.Tensor` of shape `(batch_size, 2, hidden_size, coarse_height, coarse_width)`):
  802. Coarse features
  803. coarse_scale (`float`): Scale between the image size and the coarse size
  804. Returns:
  805. keypoints (`torch.Tensor` of shape `(batch_size, 2, num_matches, 2)`):
  806. Keypoints coordinates.
  807. matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_matches)`):
  808. The confidence matching score of each keypoint.
  809. matched_indices (`torch.Tensor` of shape `(batch_size, 2, num_matches)`):
  810. Indices which indicates which keypoint in an image matched with which keypoint in the other image. For
  811. both image in the pair.
  812. """
  813. batch_size, _, embed_dim, height, width = coarse_features.shape
  814. # (batch_size, 2, embed_dim, height, width) -> (batch_size, 2, height * width, embed_dim)
  815. coarse_features = coarse_features.permute(0, 1, 3, 4, 2)
  816. coarse_features = coarse_features.reshape(batch_size, 2, -1, embed_dim)
  817. coarse_features = coarse_features / coarse_features.shape[-1] ** 0.5
  818. coarse_features_0 = coarse_features[:, 0]
  819. coarse_features_1 = coarse_features[:, 1]
  820. similarity = coarse_features_0 @ coarse_features_1.transpose(-1, -2)
  821. similarity = similarity / self.config.coarse_matching_temperature
  822. if self.config.coarse_matching_skip_softmax:
  823. confidence = similarity
  824. else:
  825. confidence = nn.functional.softmax(similarity, 1) * nn.functional.softmax(similarity, 2)
  826. confidence = confidence.view(batch_size, height, width, height, width)
  827. matched_indices, matching_scores = self._get_matches_from_scores(confidence)
  828. keypoints = torch.stack([matched_indices % width, matched_indices // width], dim=-1) * coarse_scale
  829. return keypoints, matching_scores, matched_indices
  830. def _get_first_stage_fine_matching(
  831. self,
  832. fine_confidence: torch.Tensor,
  833. coarse_matched_keypoints: torch.Tensor,
  834. fine_window_size: int,
  835. fine_scale: float,
  836. ) -> tuple[torch.Tensor, torch.Tensor]:
  837. """
  838. For each coarse pixel, retrieve the highest fine confidence score and index.
  839. The index represents the matching between a pixel position in the fine window in the first image and a pixel
  840. position in the fine window of the second image.
  841. For example, for a fine_window_size of 64 (8 * 8), the index 2474 represents the matching between the index 38
  842. (2474 // 64) in the fine window of the first image, and the index 42 in the second image. This means that 38
  843. which corresponds to the position (4, 6) (4 // 8 and 4 % 8) is matched with the position (5, 2). In this example
  844. the coarse matched coordinate will be shifted to the matched fine coordinates in the first and second image.
  845. Note:
  846. This step can be done as a postprocessing step, because does not involve any model weights/params.
  847. However, we keep it in the modeling code for consistency with other keypoint matching models AND for
  848. easier torch.compile/torch.export (all ops are in torch).
  849. Args:
  850. fine_confidence (`torch.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`):
  851. First stage confidence of matching fine features between the first and the second image
  852. coarse_matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`):
  853. Coarse matched keypoint between the first and the second image.
  854. fine_window_size (`int`):
  855. Size of the window used to refine matches
  856. fine_scale (`float`):
  857. Scale between the size of fine features and coarse features
  858. Returns:
  859. indices (`torch.Tensor` of shape `(2, num_matches, 1)`):
  860. Indices of the fine coordinate matched in the fine window
  861. fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`):
  862. Coordinates of matched keypoints after the first fine stage
  863. """
  864. batch_size, num_keypoints, _, _ = fine_confidence.shape
  865. fine_kernel_size = torch_int(fine_window_size**0.5)
  866. fine_confidence = fine_confidence.reshape(batch_size, num_keypoints, -1)
  867. values, indices = torch.max(fine_confidence, dim=-1)
  868. indices = indices[..., None]
  869. indices_0 = indices // fine_window_size
  870. indices_1 = indices % fine_window_size
  871. grid = create_meshgrid(
  872. fine_kernel_size,
  873. fine_kernel_size,
  874. normalized_coordinates=False,
  875. device=fine_confidence.device,
  876. dtype=fine_confidence.dtype,
  877. )
  878. grid = grid - (fine_kernel_size // 2) + 0.5
  879. grid = grid.reshape(1, 1, -1, 2).expand(batch_size, num_keypoints, -1, -1)
  880. delta_0 = torch.gather(grid, 1, indices_0.unsqueeze(-1).expand(-1, -1, -1, 2)).squeeze(2)
  881. delta_1 = torch.gather(grid, 1, indices_1.unsqueeze(-1).expand(-1, -1, -1, 2)).squeeze(2)
  882. fine_matches_0 = coarse_matched_keypoints[:, 0] + delta_0 * fine_scale
  883. fine_matches_1 = coarse_matched_keypoints[:, 1] + delta_1 * fine_scale
  884. indices = torch.stack([indices_0, indices_1], dim=1)
  885. fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=1)
  886. return indices, fine_matches
  887. def _get_second_stage_fine_matching(
  888. self,
  889. indices: torch.Tensor,
  890. fine_matches: torch.Tensor,
  891. fine_confidence: torch.Tensor,
  892. fine_window_size: int,
  893. fine_scale: float,
  894. ) -> torch.Tensor:
  895. """
  896. For the given position in their respective fine windows, retrieve the 3x3 fine confidences around this position.
  897. After applying softmax to these confidences, compute the 2D spatial expected coordinates.
  898. Shift the first stage fine matching with these expected coordinates.
  899. Note:
  900. This step can be done as a postprocessing step, because does not involve any model weights/params.
  901. However, we keep it in the modeling code for consistency with other keypoint matching models AND for
  902. easier torch.compile/torch.export (all ops are in torch).
  903. Args:
  904. indices (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
  905. Indices representing the position of each keypoint in the fine window
  906. fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`):
  907. Coordinates of matched keypoints after the first fine stage
  908. fine_confidence (`torch.Tensor` of shape `(num_matches, fine_window_size, fine_window_size)`):
  909. Second stage confidence of matching fine features between the first and the second image
  910. fine_window_size (`int`):
  911. Size of the window used to refine matches
  912. fine_scale (`float`):
  913. Scale between the size of fine features and coarse features
  914. Returns:
  915. fine_matches (`torch.Tensor` of shape `(2, num_matches, 2)`):
  916. Coordinates of matched keypoints after the second fine stage
  917. """
  918. batch_size, num_keypoints, _, _ = fine_confidence.shape
  919. fine_kernel_size = torch_int(fine_window_size**0.5)
  920. indices_0 = indices[:, 0]
  921. indices_1 = indices[:, 1]
  922. indices_1_i = indices_1 // fine_kernel_size
  923. indices_1_j = indices_1 % fine_kernel_size
  924. # matches_indices, indices_0, indices_1_i, indices_1_j of shape (num_matches, 3, 3)
  925. batch_indices = torch.arange(batch_size, device=indices_0.device).reshape(batch_size, 1, 1, 1)
  926. matches_indices = torch.arange(num_keypoints, device=indices_0.device).reshape(1, num_keypoints, 1, 1)
  927. indices_0 = indices_0[..., None]
  928. indices_1_i = indices_1_i[..., None]
  929. indices_1_j = indices_1_j[..., None]
  930. delta = create_meshgrid(3, 3, normalized_coordinates=True, device=indices_0.device).to(torch.long)
  931. delta = delta[None, ...]
  932. indices_1_i = indices_1_i + delta[..., 1]
  933. indices_1_j = indices_1_j + delta[..., 0]
  934. fine_confidence = fine_confidence.reshape(
  935. batch_size, num_keypoints, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2
  936. )
  937. # (batch_size, seq_len, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2) -> (batch_size, seq_len, 3, 3)
  938. fine_confidence = fine_confidence[batch_indices, matches_indices, indices_0, indices_1_i, indices_1_j]
  939. fine_confidence = fine_confidence.reshape(batch_size, num_keypoints, 9)
  940. fine_confidence = nn.functional.softmax(
  941. fine_confidence / self.config.fine_matching_regress_temperature, dim=-1
  942. )
  943. heatmap = fine_confidence.reshape(batch_size, num_keypoints, 3, 3)
  944. fine_coordinates_normalized = spatial_expectation2d(heatmap, True)[0]
  945. fine_matches_0 = fine_matches[:, 0]
  946. fine_matches_1 = fine_matches[:, 1] + (fine_coordinates_normalized * (3 // 2) * fine_scale)
  947. fine_matches = torch.stack([fine_matches_0, fine_matches_1], dim=1)
  948. return fine_matches
  949. def _fine_matching(
  950. self,
  951. fine_features_0: torch.Tensor,
  952. fine_features_1: torch.Tensor,
  953. coarse_matched_keypoints: torch.Tensor,
  954. fine_scale: float,
  955. ) -> torch.Tensor:
  956. """
  957. For each coarse pixel with a corresponding window of fine features, compute the matching confidence between fine
  958. features in the first image and the second image.
  959. Fine features are sliced in two part :
  960. - The first part used for the first stage are the first fine_hidden_size - config.fine_matching_slicedim (64 - 8
  961. = 56 by default) features.
  962. - The second part used for the second stage are the last config.fine_matching_slicedim (8 by default) features.
  963. Each part is used to compute a fine confidence tensor of the following shape :
  964. (batch_size, (coarse_height * coarse_width), fine_window_size, fine_window_size)
  965. They correspond to the score between each fine pixel in the first image and each fine pixel in the second image.
  966. Args:
  967. fine_features_0 (`torch.Tensor` of shape `(num_matches, fine_kernel_size ** 2, fine_kernel_size ** 2)`):
  968. Fine features from the first image
  969. fine_features_1 (`torch.Tensor` of shape `(num_matches, (fine_kernel_size + 2) ** 2, (fine_kernel_size + 2)
  970. ** 2)`):
  971. Fine features from the second image
  972. coarse_matched_keypoints (`torch.Tensor` of shape `(2, num_matches, 2)`):
  973. Keypoint coordinates found in coarse matching for the first and second image
  974. fine_scale (`int`):
  975. Scale between the size of fine features and coarse features
  976. Returns:
  977. fine_coordinates (`torch.Tensor` of shape `(2, num_matches, 2)`):
  978. Matched keypoint between the first and the second image. All matched keypoints are concatenated in the
  979. second dimension.
  980. """
  981. batch_size, num_keypoints, fine_window_size, fine_embed_dim = fine_features_0.shape
  982. fine_matching_slice_dim = self.config.fine_matching_slice_dim
  983. fine_kernel_size = torch_int(fine_window_size**0.5)
  984. # Split fine features into first and second stage features
  985. split_fine_features_0 = torch.split(fine_features_0, fine_embed_dim - fine_matching_slice_dim, -1)
  986. split_fine_features_1 = torch.split(fine_features_1, fine_embed_dim - fine_matching_slice_dim, -1)
  987. # Retrieve first stage fine features
  988. fine_features_0 = split_fine_features_0[0]
  989. fine_features_1 = split_fine_features_1[0]
  990. # Normalize first stage fine features
  991. fine_features_0 = fine_features_0 / fine_features_0.shape[-1] ** 0.5
  992. fine_features_1 = fine_features_1 / fine_features_1.shape[-1] ** 0.5
  993. # Compute first stage confidence
  994. fine_confidence = fine_features_0 @ fine_features_1.transpose(-1, -2)
  995. fine_confidence = nn.functional.softmax(fine_confidence, 1) * nn.functional.softmax(fine_confidence, 2)
  996. fine_confidence = fine_confidence.reshape(
  997. batch_size, num_keypoints, fine_window_size, fine_kernel_size + 2, fine_kernel_size + 2
  998. )
  999. fine_confidence = fine_confidence[..., 1:-1, 1:-1]
  1000. first_stage_fine_confidence = fine_confidence.reshape(
  1001. batch_size, num_keypoints, fine_window_size, fine_window_size
  1002. )
  1003. fine_indices, fine_matches = self._get_first_stage_fine_matching(
  1004. first_stage_fine_confidence,
  1005. coarse_matched_keypoints,
  1006. fine_window_size,
  1007. fine_scale,
  1008. )
  1009. # Retrieve second stage fine features
  1010. fine_features_0 = split_fine_features_0[1]
  1011. fine_features_1 = split_fine_features_1[1]
  1012. # Normalize second stage fine features
  1013. fine_features_1 = fine_features_1 / fine_matching_slice_dim**0.5
  1014. # Compute second stage fine confidence
  1015. second_stage_fine_confidence = fine_features_0 @ fine_features_1.transpose(-1, -2)
  1016. fine_coordinates = self._get_second_stage_fine_matching(
  1017. fine_indices,
  1018. fine_matches,
  1019. second_stage_fine_confidence,
  1020. fine_window_size,
  1021. fine_scale,
  1022. )
  1023. return fine_coordinates
  1024. @auto_docstring
  1025. @can_return_tuple
  1026. def forward(
  1027. self,
  1028. pixel_values: torch.FloatTensor,
  1029. labels: Optional[torch.LongTensor] = None,
  1030. **kwargs: Unpack[TransformersKwargs],
  1031. ) -> KeypointMatchingOutput:
  1032. r"""
  1033. Examples:
  1034. ```python
  1035. >>> from transformers import AutoImageProcessor, AutoModel
  1036. >>> import torch
  1037. >>> from PIL import Image
  1038. >>> import requests
  1039. >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
  1040. >>> image1 = Image.open(requests.get(url, stream=True).raw)
  1041. >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
  1042. >>> image2 = Image.open(requests.get(url, stream=True).raw)
  1043. >>> images = [image1, image2]
  1044. >>> processor = AutoImageProcessor.from_pretrained("zju-community/efficient_loftr")
  1045. >>> model = AutoModel.from_pretrained("zju-community/efficient_loftr")
  1046. >>> with torch.no_grad():
  1047. >>> inputs = processor(images, return_tensors="pt")
  1048. >>> outputs = model(**inputs)
  1049. ```"""
  1050. if labels is not None:
  1051. raise ValueError("SuperGlue is not trainable, no labels should be provided.")
  1052. # 1. Extract coarse and residual features
  1053. model_outputs: BackboneOutput = self.efficientloftr(pixel_values, **kwargs)
  1054. features = model_outputs.feature_maps
  1055. # 2. Compute coarse-level matching
  1056. coarse_features = features[0]
  1057. coarse_embed_dim, coarse_height, coarse_width = coarse_features.shape[-3:]
  1058. batch_size, _, channels, height, width = pixel_values.shape
  1059. coarse_scale = height / coarse_height
  1060. coarse_keypoints, coarse_matching_scores, coarse_matched_indices = self._coarse_matching(
  1061. coarse_features, coarse_scale
  1062. )
  1063. # 3. Fine-level refinement
  1064. residual_features = features[1:]
  1065. coarse_features = coarse_features / self.config.hidden_size**0.5
  1066. fine_features_0, fine_features_1 = self.refinement_layer(coarse_features, residual_features)
  1067. # Filter fine features with coarse matches indices
  1068. _, _, num_keypoints = coarse_matching_scores.shape
  1069. batch_indices = torch.arange(batch_size)[..., None]
  1070. fine_features_0 = fine_features_0[batch_indices, coarse_matched_indices[:, 0]]
  1071. fine_features_1 = fine_features_1[batch_indices, coarse_matched_indices[:, 1]]
  1072. # 4. Computer fine-level matching
  1073. fine_height = torch_int(coarse_height * coarse_scale)
  1074. fine_scale = height / fine_height
  1075. matching_keypoints = self._fine_matching(fine_features_0, fine_features_1, coarse_keypoints, fine_scale)
  1076. matching_keypoints[:, :, :, 0] = matching_keypoints[:, :, :, 0] / width
  1077. matching_keypoints[:, :, :, 1] = matching_keypoints[:, :, :, 1] / height
  1078. return KeypointMatchingOutput(
  1079. matches=coarse_matched_indices,
  1080. matching_scores=coarse_matching_scores,
  1081. keypoints=matching_keypoints,
  1082. hidden_states=model_outputs.hidden_states,
  1083. attentions=model_outputs.attentions,
  1084. )
# Names exported by `from <module> import *`
__all__ = ["EfficientLoFTRPreTrainedModel", "EfficientLoFTRModel", "EfficientLoFTRForKeypointMatching"]