modeling_superglue.py 35 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811
  1. # Copyright 2024 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """PyTorch SuperGlue model."""
  15. import math
  16. from dataclasses import dataclass
  17. from typing import Optional, Union
  18. import torch
  19. from torch import nn
  20. from transformers import PreTrainedModel
  21. from transformers.models.superglue.configuration_superglue import SuperGlueConfig
  22. from ...pytorch_utils import find_pruneable_heads_and_indices, prune_linear_layer
  23. from ...utils import ModelOutput, auto_docstring, logging
  24. from ..auto import AutoModelForKeypointDetection
# Module-level logger, created through the library's own `logging` helper so it is named after this module.
logger = logging.get_logger(__name__)
  26. def concat_pairs(tensor_tuple0: tuple[torch.Tensor], tensor_tuple1: tuple[torch.Tensor]) -> tuple[torch.Tensor]:
  27. """
  28. Concatenate two tuples of tensors pairwise
  29. Args:
  30. tensor_tuple0 (`tuple[torch.Tensor]`):
  31. Tuple of tensors.
  32. tensor_tuple1 (`tuple[torch.Tensor]`):
  33. Tuple of tensors.
  34. Returns:
  35. (`tuple[torch.Tensor]`): Tuple of concatenated tensors.
  36. """
  37. return tuple(torch.cat([tensor0, tensor1]) for tensor0, tensor1 in zip(tensor_tuple0, tensor_tuple1))
  38. def normalize_keypoints(keypoints: torch.Tensor, height: int, width: int) -> torch.Tensor:
  39. """
  40. Normalize keypoints locations based on image image_shape
  41. Args:
  42. keypoints (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`):
  43. Keypoints locations in (x, y) format.
  44. height (`int`):
  45. Image height.
  46. width (`int`):
  47. Image width.
  48. Returns:
  49. Normalized keypoints locations of shape (`torch.Tensor` of shape `(batch_size, num_keypoints, 2)`).
  50. """
  51. size = torch.tensor([width, height], device=keypoints.device, dtype=keypoints.dtype)[None]
  52. center = size / 2
  53. scaling = size.max(1, keepdim=True).values * 0.7
  54. return (keypoints - center[:, None, :]) / scaling[:, None, :]
  55. def log_sinkhorn_iterations(
  56. log_cost_matrix: torch.Tensor,
  57. log_source_distribution: torch.Tensor,
  58. log_target_distribution: torch.Tensor,
  59. num_iterations: int,
  60. ) -> torch.Tensor:
  61. """
  62. Perform Sinkhorn Normalization in Log-space for stability
  63. Args:
  64. log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`):
  65. Logarithm of the cost matrix.
  66. log_source_distribution (`torch.Tensor` of shape `(batch_size, num_rows)`):
  67. Logarithm of the source distribution.
  68. log_target_distribution (`torch.Tensor` of shape `(batch_size, num_columns)`):
  69. Logarithm of the target distribution.
  70. Returns:
  71. log_cost_matrix (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the optimal
  72. transport matrix.
  73. """
  74. log_u_scaling = torch.zeros_like(log_source_distribution)
  75. log_v_scaling = torch.zeros_like(log_target_distribution)
  76. for _ in range(num_iterations):
  77. log_u_scaling = log_source_distribution - torch.logsumexp(log_cost_matrix + log_v_scaling.unsqueeze(1), dim=2)
  78. log_v_scaling = log_target_distribution - torch.logsumexp(log_cost_matrix + log_u_scaling.unsqueeze(2), dim=1)
  79. return log_cost_matrix + log_u_scaling.unsqueeze(2) + log_v_scaling.unsqueeze(1)
  80. def log_optimal_transport(scores: torch.Tensor, reg_param: torch.Tensor, iterations: int) -> torch.Tensor:
  81. """
  82. Perform Differentiable Optimal Transport in Log-space for stability
  83. Args:
  84. scores: (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`):
  85. Cost matrix.
  86. reg_param: (`torch.Tensor` of shape `(batch_size, 1, 1)`):
  87. Regularization parameter.
  88. iterations: (`int`):
  89. Number of Sinkhorn iterations.
  90. Returns:
  91. log_optimal_transport_matrix: (`torch.Tensor` of shape `(batch_size, num_rows, num_columns)`): Logarithm of the
  92. optimal transport matrix.
  93. """
  94. batch_size, num_rows, num_columns = scores.shape
  95. one_tensor = scores.new_tensor(1)
  96. num_rows_tensor, num_columns_tensor = (num_rows * one_tensor).to(scores), (num_columns * one_tensor).to(scores)
  97. source_reg_param = reg_param.expand(batch_size, num_rows, 1)
  98. target_reg_param = reg_param.expand(batch_size, 1, num_columns)
  99. reg_param = reg_param.expand(batch_size, 1, 1)
  100. couplings = torch.cat([torch.cat([scores, source_reg_param], -1), torch.cat([target_reg_param, reg_param], -1)], 1)
  101. log_normalization = -(num_rows_tensor + num_columns_tensor).log()
  102. log_source_distribution = torch.cat(
  103. [log_normalization.expand(num_rows), num_columns_tensor.log()[None] + log_normalization]
  104. )
  105. log_target_distribution = torch.cat(
  106. [log_normalization.expand(num_columns), num_rows_tensor.log()[None] + log_normalization]
  107. )
  108. log_source_distribution, log_target_distribution = (
  109. log_source_distribution[None].expand(batch_size, -1),
  110. log_target_distribution[None].expand(batch_size, -1),
  111. )
  112. log_optimal_transport_matrix = log_sinkhorn_iterations(
  113. couplings, log_source_distribution, log_target_distribution, num_iterations=iterations
  114. )
  115. log_optimal_transport_matrix = log_optimal_transport_matrix - log_normalization # multiply probabilities by M+N
  116. return log_optimal_transport_matrix
  117. def arange_like(x, dim: int) -> torch.Tensor:
  118. return x.new_ones(x.shape[dim]).cumsum(0) - 1
# Output container returned by the keypoint matching model defined in this file.
# NOTE(review): `matches` is annotated and documented as a `FloatTensor`, but `_match_image_pair` in this file
# builds it from integer indices (with -1 meaning "no match") - confirm the intended dtype.
@dataclass
@auto_docstring(
    custom_intro="""
    Base class for outputs of keypoint matching models. Due to the nature of keypoint detection and matching, the number
    of keypoints is not fixed and can vary from image to image, which makes batching non-trivial. In the batch of
    images, the maximum number of matches is set as the dimension of the matches and matching scores. The mask tensor is
    used to indicate which values in the keypoints, matches and matching_scores tensors are keypoint matching
    information.
    """
)
class KeypointMatchingOutput(ModelOutput):
    r"""
    loss (`torch.FloatTensor` of shape `(1,)`, *optional*):
        Loss computed during training.
    matches (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Index of keypoint matched in the other image.
    matching_scores (`torch.FloatTensor` of shape `(batch_size, 2, num_matches)`):
        Scores of predicted matches.
    keypoints (`torch.FloatTensor` of shape `(batch_size, num_keypoints, 2)`):
        Absolute (x, y) coordinates of predicted keypoints in a given image.
    mask (`torch.IntTensor` of shape `(batch_size, num_keypoints)`):
        Mask indicating which values in matches and matching_scores are keypoint matching information.
    hidden_states (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(batch_size, 2, num_channels,
        num_keypoints)`, returned when `output_hidden_states=True` is passed or when
        `config.output_hidden_states=True`)
    attentions (`tuple[torch.FloatTensor, ...]`, *optional*):
        Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, 2, num_heads, num_keypoints,
        num_keypoints)`, returned when `output_attentions=True` is passed or when `config.output_attentions=True`)
    """

    # Every field defaults to None, so an output can be built with only the values that were computed.
    loss: Optional[torch.FloatTensor] = None
    matches: Optional[torch.FloatTensor] = None
    matching_scores: Optional[torch.FloatTensor] = None
    keypoints: Optional[torch.FloatTensor] = None
    mask: Optional[torch.IntTensor] = None
    hidden_states: Optional[tuple[torch.FloatTensor]] = None
    attentions: Optional[tuple[torch.FloatTensor]] = None
  156. class SuperGlueMultiLayerPerceptron(nn.Module):
  157. def __init__(self, config: SuperGlueConfig, in_channels: int, out_channels: int) -> None:
  158. super().__init__()
  159. self.linear = nn.Linear(in_channels, out_channels)
  160. self.batch_norm = nn.BatchNorm1d(out_channels)
  161. self.activation = nn.ReLU()
  162. def forward(self, hidden_state: torch.Tensor) -> torch.Tensor:
  163. hidden_state = self.linear(hidden_state)
  164. hidden_state = hidden_state.transpose(-1, -2)
  165. hidden_state = self.batch_norm(hidden_state)
  166. hidden_state = hidden_state.transpose(-1, -2)
  167. hidden_state = self.activation(hidden_state)
  168. return hidden_state
  169. class SuperGlueKeypointEncoder(nn.Module):
  170. def __init__(self, config: SuperGlueConfig) -> None:
  171. super().__init__()
  172. layer_sizes = config.keypoint_encoder_sizes
  173. hidden_size = config.hidden_size
  174. # 3 here consists of 2 for the (x, y) coordinates and 1 for the score of the keypoint
  175. encoder_channels = [3] + layer_sizes + [hidden_size]
  176. layers = [
  177. SuperGlueMultiLayerPerceptron(config, encoder_channels[i - 1], encoder_channels[i])
  178. for i in range(1, len(encoder_channels) - 1)
  179. ]
  180. layers.append(nn.Linear(encoder_channels[-2], encoder_channels[-1]))
  181. self.encoder = nn.ModuleList(layers)
  182. def forward(
  183. self,
  184. keypoints: torch.Tensor,
  185. scores: torch.Tensor,
  186. output_hidden_states: Optional[bool] = False,
  187. ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]]]:
  188. scores = scores.unsqueeze(2)
  189. hidden_state = torch.cat([keypoints, scores], dim=2)
  190. all_hidden_states = () if output_hidden_states else None
  191. for layer in self.encoder:
  192. hidden_state = layer(hidden_state)
  193. if output_hidden_states:
  194. all_hidden_states = all_hidden_states + (hidden_state,)
  195. return hidden_state, all_hidden_states
  196. class SuperGlueSelfAttention(nn.Module):
  197. def __init__(self, config, position_embedding_type=None):
  198. super().__init__()
  199. if config.hidden_size % config.num_attention_heads != 0 and not hasattr(config, "embedding_size"):
  200. raise ValueError(
  201. f"The hidden size ({config.hidden_size}) is not a multiple of the number of attention "
  202. f"heads ({config.num_attention_heads})"
  203. )
  204. self.num_attention_heads = config.num_attention_heads
  205. self.attention_head_size = int(config.hidden_size / config.num_attention_heads)
  206. self.all_head_size = self.num_attention_heads * self.attention_head_size
  207. self.query = nn.Linear(config.hidden_size, self.all_head_size)
  208. self.key = nn.Linear(config.hidden_size, self.all_head_size)
  209. self.value = nn.Linear(config.hidden_size, self.all_head_size)
  210. self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
  211. self.position_embedding_type = position_embedding_type or getattr(
  212. config, "position_embedding_type", "absolute"
  213. )
  214. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  215. self.max_position_embeddings = config.max_position_embeddings
  216. self.distance_embedding = nn.Embedding(2 * config.max_position_embeddings - 1, self.attention_head_size)
  217. self.is_decoder = config.is_decoder
  218. def forward(
  219. self,
  220. hidden_states: torch.Tensor,
  221. attention_mask: Optional[torch.FloatTensor] = None,
  222. head_mask: Optional[torch.FloatTensor] = None,
  223. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  224. encoder_attention_mask: Optional[torch.FloatTensor] = None,
  225. output_attentions: Optional[bool] = False,
  226. ) -> tuple[torch.Tensor]:
  227. # If this is instantiated as a cross-attention module, the keys
  228. # and values come from an encoder; the attention mask needs to be
  229. # such that the encoder's padding tokens are not attended to.
  230. is_cross_attention = encoder_hidden_states is not None
  231. current_states = encoder_hidden_states if is_cross_attention else hidden_states
  232. attention_mask = encoder_attention_mask if is_cross_attention else attention_mask
  233. batch_size = hidden_states.shape[0]
  234. key_layer = (
  235. self.key(current_states)
  236. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  237. .transpose(1, 2)
  238. )
  239. value_layer = (
  240. self.value(current_states)
  241. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  242. .transpose(1, 2)
  243. )
  244. query_layer = (
  245. self.query(hidden_states)
  246. .view(batch_size, -1, self.num_attention_heads, self.attention_head_size)
  247. .transpose(1, 2)
  248. )
  249. # Take the dot product between "query" and "key" to get the raw attention scores.
  250. attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))
  251. if self.position_embedding_type == "relative_key" or self.position_embedding_type == "relative_key_query":
  252. query_length, key_length = query_layer.shape[2], key_layer.shape[2]
  253. position_ids_l = torch.arange(query_length, dtype=torch.long, device=hidden_states.device).view(-1, 1)
  254. position_ids_r = torch.arange(key_length, dtype=torch.long, device=hidden_states.device).view(1, -1)
  255. distance = position_ids_l - position_ids_r
  256. positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
  257. positional_embedding = positional_embedding.to(dtype=query_layer.dtype) # fp16 compatibility
  258. if self.position_embedding_type == "relative_key":
  259. relative_position_scores = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  260. attention_scores = attention_scores + relative_position_scores
  261. elif self.position_embedding_type == "relative_key_query":
  262. relative_position_scores_query = torch.einsum("bhld,lrd->bhlr", query_layer, positional_embedding)
  263. relative_position_scores_key = torch.einsum("bhrd,lrd->bhlr", key_layer, positional_embedding)
  264. attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
  265. attention_scores = attention_scores / math.sqrt(self.attention_head_size)
  266. if attention_mask is not None:
  267. # Apply the attention mask is (precomputed for all layers in SuperGlueModel forward() function)
  268. attention_scores = attention_scores + attention_mask
  269. # Normalize the attention scores to probabilities.
  270. attention_probs = nn.functional.softmax(attention_scores, dim=-1)
  271. # This is actually dropping out entire tokens to attend to, which might
  272. # seem a bit unusual, but is taken from the original Transformer paper.
  273. attention_probs = self.dropout(attention_probs)
  274. # Mask heads if we want to
  275. if head_mask is not None:
  276. attention_probs = attention_probs * head_mask
  277. context_layer = torch.matmul(attention_probs, value_layer)
  278. context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
  279. new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
  280. context_layer = context_layer.view(new_context_layer_shape)
  281. outputs = (context_layer, attention_probs) if output_attentions else (context_layer,)
  282. if self.is_decoder:
  283. outputs = outputs + (None,)
  284. return outputs
  285. class SuperGlueSelfOutput(nn.Module):
  286. def __init__(self, config: SuperGlueConfig):
  287. super().__init__()
  288. self.dense = nn.Linear(config.hidden_size, config.hidden_size)
  289. def forward(self, hidden_states: torch.Tensor, *args) -> torch.Tensor:
  290. hidden_states = self.dense(hidden_states)
  291. return hidden_states
# Maps an attention implementation name (looked up with `config._attn_implementation` in `SuperGlueAttention`)
# to the class implementing it. Only the eager implementation is registered.
SUPERGLUE_SELF_ATTENTION_CLASSES = {
    "eager": SuperGlueSelfAttention,
}
  295. class SuperGlueAttention(nn.Module):
  296. def __init__(self, config, position_embedding_type=None):
  297. super().__init__()
  298. self.self = SUPERGLUE_SELF_ATTENTION_CLASSES[config._attn_implementation](
  299. config,
  300. position_embedding_type=position_embedding_type,
  301. )
  302. self.output = SuperGlueSelfOutput(config)
  303. self.pruned_heads = set()
  304. def prune_heads(self, heads):
  305. if len(heads) == 0:
  306. return
  307. heads, index = find_pruneable_heads_and_indices(
  308. heads, self.self.num_attention_heads, self.self.attention_head_size, self.pruned_heads
  309. )
  310. # Prune linear layers
  311. self.self.query = prune_linear_layer(self.self.query, index)
  312. self.self.key = prune_linear_layer(self.self.key, index)
  313. self.self.value = prune_linear_layer(self.self.value, index)
  314. self.output.dense = prune_linear_layer(self.output.dense, index, dim=1)
  315. # Update hyper params and store pruned heads
  316. self.self.num_attention_heads = self.self.num_attention_heads - len(heads)
  317. self.self.all_head_size = self.self.attention_head_size * self.self.num_attention_heads
  318. self.pruned_heads = self.pruned_heads.union(heads)
  319. def forward(
  320. self,
  321. hidden_states: torch.Tensor,
  322. attention_mask: Optional[torch.FloatTensor] = None,
  323. head_mask: Optional[torch.FloatTensor] = None,
  324. encoder_hidden_states: Optional[torch.FloatTensor] = None,
  325. encoder_attention_mask: Optional[torch.Tensor] = None,
  326. output_attentions: Optional[bool] = False,
  327. ) -> tuple[torch.Tensor]:
  328. self_outputs = self.self(
  329. hidden_states,
  330. attention_mask=attention_mask,
  331. head_mask=head_mask,
  332. encoder_hidden_states=encoder_hidden_states,
  333. encoder_attention_mask=encoder_attention_mask,
  334. output_attentions=output_attentions,
  335. )
  336. attention_output = self.output(self_outputs[0], hidden_states)
  337. outputs = (attention_output,) + self_outputs[1:] # add attentions if we output them
  338. return outputs
  339. class SuperGlueAttentionalPropagation(nn.Module):
  340. def __init__(self, config: SuperGlueConfig) -> None:
  341. super().__init__()
  342. hidden_size = config.hidden_size
  343. self.attention = SuperGlueAttention(config)
  344. mlp_channels = [hidden_size * 2, hidden_size * 2, hidden_size]
  345. layers = [
  346. SuperGlueMultiLayerPerceptron(config, mlp_channels[i - 1], mlp_channels[i])
  347. for i in range(1, len(mlp_channels) - 1)
  348. ]
  349. layers.append(nn.Linear(mlp_channels[-2], mlp_channels[-1]))
  350. self.mlp = nn.ModuleList(layers)
  351. def forward(
  352. self,
  353. descriptors: torch.Tensor,
  354. attention_mask: Optional[torch.Tensor] = None,
  355. encoder_hidden_states: Optional[torch.Tensor] = None,
  356. encoder_attention_mask: Optional[torch.Tensor] = None,
  357. output_attentions: bool = False,
  358. output_hidden_states: bool = False,
  359. ) -> tuple[torch.Tensor, Optional[tuple[torch.Tensor]], Optional[tuple[torch.Tensor]]]:
  360. attention_outputs = self.attention(
  361. descriptors,
  362. attention_mask=attention_mask,
  363. encoder_hidden_states=encoder_hidden_states,
  364. encoder_attention_mask=encoder_attention_mask,
  365. output_attentions=output_attentions,
  366. )
  367. output = attention_outputs[0]
  368. attention = attention_outputs[1:]
  369. hidden_state = torch.cat([descriptors, output], dim=2)
  370. all_hidden_states = () if output_hidden_states else None
  371. for layer in self.mlp:
  372. hidden_state = layer(hidden_state)
  373. if output_hidden_states:
  374. all_hidden_states = all_hidden_states + (hidden_state,)
  375. return hidden_state, all_hidden_states, attention
  376. class SuperGlueAttentionalGNN(nn.Module):
  377. def __init__(self, config: SuperGlueConfig) -> None:
  378. super().__init__()
  379. self.hidden_size = config.hidden_size
  380. self.layers_types = config.gnn_layers_types
  381. self.layers = nn.ModuleList([SuperGlueAttentionalPropagation(config) for _ in range(len(self.layers_types))])
  382. def forward(
  383. self,
  384. descriptors: torch.Tensor,
  385. mask: Optional[torch.Tensor] = None,
  386. output_attentions: bool = False,
  387. output_hidden_states: Optional[bool] = False,
  388. ) -> tuple[torch.Tensor, Optional[tuple], Optional[tuple]]:
  389. all_hidden_states = () if output_hidden_states else None
  390. all_attentions = () if output_attentions else None
  391. batch_size, num_keypoints, _ = descriptors.shape
  392. if output_hidden_states:
  393. all_hidden_states = all_hidden_states + (descriptors,)
  394. for gnn_layer, layer_type in zip(self.layers, self.layers_types):
  395. encoder_hidden_states = None
  396. encoder_attention_mask = None
  397. if layer_type == "cross":
  398. encoder_hidden_states = (
  399. descriptors.reshape(-1, 2, num_keypoints, self.hidden_size)
  400. .flip(1)
  401. .reshape(batch_size, num_keypoints, self.hidden_size)
  402. )
  403. encoder_attention_mask = (
  404. mask.reshape(-1, 2, 1, 1, num_keypoints).flip(1).reshape(batch_size, 1, 1, num_keypoints)
  405. if mask is not None
  406. else None
  407. )
  408. gnn_outputs = gnn_layer(
  409. descriptors,
  410. attention_mask=mask,
  411. encoder_hidden_states=encoder_hidden_states,
  412. encoder_attention_mask=encoder_attention_mask,
  413. output_hidden_states=output_hidden_states,
  414. output_attentions=output_attentions,
  415. )
  416. delta = gnn_outputs[0]
  417. if output_hidden_states:
  418. all_hidden_states = all_hidden_states + gnn_outputs[1]
  419. if output_attentions:
  420. all_attentions = all_attentions + gnn_outputs[2]
  421. descriptors = descriptors + delta
  422. return descriptors, all_hidden_states, all_attentions
  423. class SuperGlueFinalProjection(nn.Module):
  424. def __init__(self, config: SuperGlueConfig) -> None:
  425. super().__init__()
  426. hidden_size = config.hidden_size
  427. self.final_proj = nn.Linear(hidden_size, hidden_size, bias=True)
  428. def forward(self, descriptors: torch.Tensor) -> torch.Tensor:
  429. return self.final_proj(descriptors)
  430. @auto_docstring
  431. class SuperGluePreTrainedModel(PreTrainedModel):
  432. config: SuperGlueConfig
  433. base_model_prefix = "superglue"
  434. main_input_name = "pixel_values"
  435. def _init_weights(self, module: nn.Module) -> None:
  436. """Initialize the weights"""
  437. if isinstance(module, (nn.Linear, nn.Conv2d)):
  438. # Slightly different from the TF version which uses truncated_normal for initialization
  439. # cf https://github.com/pytorch/pytorch/pull/5617
  440. module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
  441. if module.bias is not None:
  442. module.bias.data.zero_()
  443. elif isinstance(module, nn.BatchNorm1d):
  444. module.bias.data.zero_()
  445. module.weight.data.fill_(1.0)
  446. if hasattr(module, "bin_score"):
  447. module.bin_score.data.fill_(1.0)
  448. @auto_docstring(
  449. custom_intro="""
  450. SuperGlue model taking images as inputs and outputting the matching of them.
  451. """
  452. )
  453. class SuperGlueForKeypointMatching(SuperGluePreTrainedModel):
  454. """SuperGlue feature matching middle-end
  455. Given two sets of keypoints and locations, we determine the
  456. correspondences by:
  457. 1. Keypoint Encoding (normalization + visual feature and location fusion)
  458. 2. Graph Neural Network with multiple self and cross-attention layers
  459. 3. Final projection layer
  460. 4. Optimal Transport Layer (a differentiable Hungarian matching algorithm)
  461. 5. Thresholding matrix based on mutual exclusivity and a match_threshold
  462. The correspondence ids use -1 to indicate non-matching points.
  463. Paul-Edouard Sarlin, Daniel DeTone, Tomasz Malisiewicz, and Andrew
  464. Rabinovich. SuperGlue: Learning Feature Matching with Graph Neural
  465. Networks. In CVPR, 2020. https://huggingface.co/papers/1911.11763
  466. """
  467. def __init__(self, config: SuperGlueConfig) -> None:
  468. super().__init__(config)
  469. self.keypoint_detector = AutoModelForKeypointDetection.from_config(config.keypoint_detector_config)
  470. self.keypoint_encoder = SuperGlueKeypointEncoder(config)
  471. self.gnn = SuperGlueAttentionalGNN(config)
  472. self.final_projection = SuperGlueFinalProjection(config)
  473. bin_score = torch.nn.Parameter(torch.tensor(1.0))
  474. self.register_parameter("bin_score", bin_score)
  475. self.post_init()
  476. def _match_image_pair(
  477. self,
  478. keypoints: torch.Tensor,
  479. descriptors: torch.Tensor,
  480. scores: torch.Tensor,
  481. height: int,
  482. width: int,
  483. mask: Optional[torch.Tensor] = None,
  484. output_attentions: Optional[bool] = None,
  485. output_hidden_states: Optional[bool] = None,
  486. ) -> tuple[torch.Tensor, torch.Tensor, tuple, tuple]:
  487. """
  488. Perform keypoint matching between two images.
  489. Args:
  490. keypoints (`torch.Tensor` of shape `(batch_size, 2, num_keypoints, 2)`):
  491. Keypoints detected in the pair of image.
  492. descriptors (`torch.Tensor` of shape `(batch_size, 2, descriptor_dim, num_keypoints)`):
  493. Descriptors of the keypoints detected in the image pair.
  494. scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
  495. Confidence scores of the keypoints detected in the image pair.
  496. height (`int`): Image height.
  497. width (`int`): Image width.
  498. mask (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`, *optional*):
  499. Mask indicating which values in the keypoints, matches and matching_scores tensors are keypoint matching
  500. information.
  501. output_attentions (`bool`, *optional*):
  502. Whether or not to return the attentions tensors. Default to `config.output_attentions`.
  503. output_hidden_states (`bool`, *optional*):
  504. Whether or not to return the hidden states of all layers. Default to `config.output_hidden_states`.
  505. Returns:
  506. matches (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
  507. For each image pair, for each keypoint in image0, the index of the keypoint in image1 that was matched
  508. with. And for each keypoint in image1, the index of the keypoint in image0 that was matched with.
  509. matching_scores (`torch.Tensor` of shape `(batch_size, 2, num_keypoints)`):
  510. Scores of predicted matches for each image pair
  511. all_hidden_states (`tuple(torch.FloatTensor)`, *optional*):
  512. Tuple of `torch.FloatTensor` (one for the output of each stage) of shape `(1, 2, num_keypoints,
  513. num_channels)`.
  514. all_attentions (`tuple(torch.FloatTensor)`, *optional*):
  515. Tuple of `torch.FloatTensor` (one for each layer) of shape `(1, 2, num_heads, num_keypoints,
  516. num_keypoints)`.
  517. """
  518. all_hidden_states = () if output_hidden_states else None
  519. all_attentions = () if output_attentions else None
  520. if keypoints.shape[2] == 0: # no keypoints
  521. shape = keypoints.shape[:-1]
  522. return (
  523. keypoints.new_full(shape, -1, dtype=torch.int),
  524. keypoints.new_zeros(shape),
  525. all_hidden_states,
  526. all_attentions,
  527. )
  528. batch_size, _, num_keypoints, _ = keypoints.shape
  529. # (batch_size, 2, num_keypoints, 2) -> (batch_size * 2, num_keypoints, 2)
  530. keypoints = keypoints.reshape(batch_size * 2, num_keypoints, 2)
  531. descriptors = descriptors.reshape(batch_size * 2, num_keypoints, self.config.hidden_size)
  532. scores = scores.reshape(batch_size * 2, num_keypoints)
  533. mask = mask.reshape(batch_size * 2, num_keypoints) if mask is not None else None
  534. # Keypoint normalization
  535. keypoints = normalize_keypoints(keypoints, height, width)
  536. encoded_keypoints = self.keypoint_encoder(keypoints, scores, output_hidden_states=output_hidden_states)
  537. last_hidden_state = encoded_keypoints[0]
  538. # Keypoint MLP encoder.
  539. descriptors = descriptors + last_hidden_state
  540. if mask is not None:
  541. input_shape = descriptors.size()
  542. extended_attention_mask = self.get_extended_attention_mask(mask, input_shape)
  543. else:
  544. extended_attention_mask = torch.ones((batch_size, num_keypoints), device=keypoints.device)
  545. # Multi-layer Transformer network.
  546. gnn_outputs = self.gnn(
  547. descriptors,
  548. mask=extended_attention_mask,
  549. output_hidden_states=output_hidden_states,
  550. output_attentions=output_attentions,
  551. )
  552. descriptors = gnn_outputs[0]
  553. # Final MLP projection.
  554. projected_descriptors = self.final_projection(descriptors)
  555. # (batch_size * 2, num_keypoints, descriptor_dim) -> (batch_size, 2, num_keypoints, descriptor_dim)
  556. final_descriptors = projected_descriptors.reshape(batch_size, 2, num_keypoints, self.config.hidden_size)
  557. final_descriptors0 = final_descriptors[:, 0]
  558. final_descriptors1 = final_descriptors[:, 1]
  559. # Compute matching descriptor distance.
  560. scores = final_descriptors0 @ final_descriptors1.transpose(1, 2)
  561. scores = scores / self.config.hidden_size**0.5
  562. if mask is not None:
  563. mask = mask.reshape(batch_size, 2, num_keypoints)
  564. mask0 = mask[:, 0].unsqueeze(2)
  565. mask1 = mask[:, 1].unsqueeze(1)
  566. mask = torch.logical_and(mask0, mask1)
  567. scores = scores.masked_fill(mask == 0, torch.finfo(scores.dtype).min)
  568. # Run the optimal transport.
  569. scores = log_optimal_transport(scores, self.bin_score, iterations=self.config.sinkhorn_iterations)
  570. # Get the matches with score above "match_threshold".
  571. max0 = scores[:, :-1, :-1].max(2)
  572. max1 = scores[:, :-1, :-1].max(1)
  573. indices0 = max0.indices
  574. indices1 = max1.indices
  575. mutual0 = arange_like(indices0, 1)[None] == indices1.gather(1, indices0)
  576. mutual1 = arange_like(indices1, 1)[None] == indices0.gather(1, indices1)
  577. zero = scores.new_tensor(0)
  578. matching_scores0 = torch.where(mutual0, max0.values.exp(), zero)
  579. matching_scores0 = torch.where(matching_scores0 > self.config.matching_threshold, matching_scores0, zero)
  580. matching_scores1 = torch.where(mutual1, matching_scores0.gather(1, indices1), zero)
  581. valid0 = mutual0 & (matching_scores0 > zero)
  582. valid1 = mutual1 & valid0.gather(1, indices1)
  583. matches0 = torch.where(valid0, indices0, indices0.new_tensor(-1))
  584. matches1 = torch.where(valid1, indices1, indices1.new_tensor(-1))
  585. matches = torch.cat([matches0, matches1], dim=1).reshape(batch_size, 2, -1)
  586. matching_scores = torch.cat([matching_scores0, matching_scores1], dim=1).reshape(batch_size, 2, -1)
  587. if output_hidden_states:
  588. all_hidden_states = all_hidden_states + encoded_keypoints[1]
  589. all_hidden_states = all_hidden_states + gnn_outputs[1]
  590. all_hidden_states = all_hidden_states + (projected_descriptors,)
  591. all_hidden_states = tuple(
  592. x.reshape(batch_size, 2, num_keypoints, -1).transpose(-1, -2) for x in all_hidden_states
  593. )
  594. if output_attentions:
  595. all_attentions = all_attentions + gnn_outputs[2]
  596. all_attentions = tuple(x.reshape(batch_size, 2, -1, num_keypoints, num_keypoints) for x in all_attentions)
  597. return (
  598. matches,
  599. matching_scores,
  600. all_hidden_states,
  601. all_attentions,
  602. )
  603. @auto_docstring
  604. def forward(
  605. self,
  606. pixel_values: torch.FloatTensor,
  607. labels: Optional[torch.LongTensor] = None,
  608. output_attentions: Optional[bool] = None,
  609. output_hidden_states: Optional[bool] = None,
  610. return_dict: Optional[bool] = None,
  611. ) -> Union[tuple, KeypointMatchingOutput]:
  612. r"""
  613. Examples:
  614. ```python
  615. >>> from transformers import AutoImageProcessor, AutoModel
  616. >>> import torch
  617. >>> from PIL import Image
  618. >>> import requests
  619. >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_78916675_4568141288.jpg?raw=true"
  620. >>> image1 = Image.open(requests.get(url, stream=True).raw)
  621. >>> url = "https://github.com/magicleap/SuperGluePretrainedNetwork/blob/master/assets/phototourism_sample_images/london_bridge_19481797_2295892421.jpg?raw=true"
  622. >>> image2 = Image.open(requests.get(url, stream=True).raw)
  623. >>> images = [image1, image2]
  624. >>> processor = AutoImageProcessor.from_pretrained("magic-leap-community/superglue_outdoor")
  625. >>> model = AutoModel.from_pretrained("magic-leap-community/superglue_outdoor")
  626. >>> with torch.no_grad():
  627. >>> inputs = processor(images, return_tensors="pt")
  628. >>> outputs = model(**inputs)
  629. ```"""
  630. loss = None
  631. if labels is not None:
  632. raise ValueError("SuperGlue is not trainable, no labels should be provided.")
  633. output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
  634. output_hidden_states = (
  635. output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
  636. )
  637. return_dict = return_dict if return_dict is not None else self.config.use_return_dict
  638. if pixel_values.ndim != 5 or pixel_values.size(1) != 2:
  639. raise ValueError("Input must be a 5D tensor of shape (batch_size, 2, num_channels, height, width)")
  640. batch_size, _, channels, height, width = pixel_values.shape
  641. pixel_values = pixel_values.reshape(batch_size * 2, channels, height, width)
  642. keypoint_detections = self.keypoint_detector(pixel_values)
  643. keypoints, scores, descriptors, mask = keypoint_detections[:4]
  644. keypoints = keypoints.reshape(batch_size, 2, -1, 2).to(pixel_values)
  645. scores = scores.reshape(batch_size, 2, -1).to(pixel_values)
  646. descriptors = descriptors.reshape(batch_size, 2, -1, self.config.hidden_size).to(pixel_values)
  647. mask = mask.reshape(batch_size, 2, -1)
  648. absolute_keypoints = keypoints.clone()
  649. absolute_keypoints[:, :, :, 0] = absolute_keypoints[:, :, :, 0] * width
  650. absolute_keypoints[:, :, :, 1] = absolute_keypoints[:, :, :, 1] * height
  651. matches, matching_scores, hidden_states, attentions = self._match_image_pair(
  652. absolute_keypoints,
  653. descriptors,
  654. scores,
  655. height,
  656. width,
  657. mask=mask,
  658. output_attentions=output_attentions,
  659. output_hidden_states=output_hidden_states,
  660. )
  661. if not return_dict:
  662. return tuple(
  663. v
  664. for v in [loss, matches, matching_scores, keypoints, mask, hidden_states, attentions]
  665. if v is not None
  666. )
  667. return KeypointMatchingOutput(
  668. loss=loss,
  669. matches=matches,
  670. matching_scores=matching_scores,
  671. keypoints=keypoints,
  672. mask=mask,
  673. hidden_states=hidden_states,
  674. attentions=attentions,
  675. )
  676. __all__ = ["SuperGluePreTrainedModel", "SuperGlueForKeypointMatching"]