keypoint_matching.py 7.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178
  1. # Copyright 2025 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from collections.abc import Sequence
  15. from typing import Any, TypedDict, Union
  16. from typing_extensions import TypeAlias, overload
  17. from ..image_utils import is_pil_image
  18. from ..utils import is_vision_available, requires_backends
  19. from .base import Pipeline
  20. if is_vision_available():
  21. from PIL import Image
  22. from ..image_utils import load_image
  23. ImagePair: TypeAlias = Sequence[Union["Image.Image", str]]
  24. class Keypoint(TypedDict):
  25. x: float
  26. y: float
  27. class Match(TypedDict):
  28. keypoint_image_0: Keypoint
  29. keypoint_image_1: Keypoint
  30. score: float
  31. def validate_image_pairs(images: Any) -> Sequence[Sequence[ImagePair]]:
  32. error_message = (
  33. "Input images must be a one of the following :",
  34. " - A pair of images.",
  35. " - A list of pairs of images.",
  36. )
  37. def _is_valid_image(image):
  38. """images is a PIL Image or a string."""
  39. return is_pil_image(image) or isinstance(image, str)
  40. if isinstance(images, Sequence):
  41. if len(images) == 2 and all((_is_valid_image(image)) for image in images):
  42. return [images]
  43. if all(
  44. isinstance(image_pair, Sequence)
  45. and len(image_pair) == 2
  46. and all(_is_valid_image(image) for image in image_pair)
  47. for image_pair in images
  48. ):
  49. return images
  50. raise ValueError(error_message)
  51. class KeypointMatchingPipeline(Pipeline):
  52. """
  53. Keypoint matching pipeline using any `AutoModelForKeypointMatching`. This pipeline matches keypoints between two images.
  54. """
  55. _load_processor = False
  56. _load_image_processor = True
  57. _load_feature_extractor = False
  58. _load_tokenizer = False
  59. def __init__(self, *args, **kwargs):
  60. super().__init__(*args, **kwargs)
  61. requires_backends(self, "vision")
  62. if self.framework != "pt":
  63. raise ValueError("Keypoint matching pipeline only supports PyTorch (framework='pt').")
  64. def _sanitize_parameters(self, threshold=None, timeout=None):
  65. preprocess_params = {}
  66. if timeout is not None:
  67. preprocess_params["timeout"] = timeout
  68. postprocess_params = {}
  69. if threshold is not None:
  70. postprocess_params["threshold"] = threshold
  71. return preprocess_params, {}, postprocess_params
  72. @overload
  73. def __call__(self, inputs: ImagePair, threshold: float = 0.0, **kwargs: Any) -> list[Match]: ...
  74. @overload
  75. def __call__(self, inputs: list[ImagePair], threshold: float = 0.0, **kwargs: Any) -> list[list[Match]]: ...
  76. def __call__(
  77. self,
  78. inputs: Union[list[ImagePair], ImagePair],
  79. threshold: float = 0.0,
  80. **kwargs: Any,
  81. ) -> Union[list[Match], list[list[Match]]]:
  82. """
  83. Find matches between keypoints in two images.
  84. Args:
  85. inputs (`str`, `list[str]`, `PIL.Image` or `list[PIL.Image]`):
  86. The pipeline handles three types of images:
  87. - A string containing a http link pointing to an image
  88. - A string containing a local path to an image
  89. - An image loaded in PIL directly
  90. The pipeline accepts either a single pair of images or a batch of image pairs, which must then be passed as a string.
  91. Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
  92. images.
  93. threshold (`float`, *optional*, defaults to 0.0):
  94. The threshold to use for keypoint matching. Keypoints matched with a lower matching score will be filtered out.
  95. A value of 0 means that all matched keypoints will be returned.
  96. kwargs:
  97. `timeout (`float`, *optional*, defaults to None)`
  98. The maximum time in seconds to wait for fetching images from the web. If None, no timeout is set and
  99. the call may block forever.
  100. Return:
  101. Union[list[Match], list[list[Match]]]:
  102. A list of matches or a list if a single image pair is provided, or of lists of matches if a batch
  103. of image pairs is provided. Each match is a dictionary containing the following keys:
  104. - **keypoint_image_0** (`Keypoint`): The keypoint in the first image (x, y coordinates).
  105. - **keypoint_image_1** (`Keypoint`): The keypoint in the second image (x, y coordinates).
  106. - **score** (`float`): The matching score between the two keypoints.
  107. """
  108. if inputs is None:
  109. raise ValueError("Cannot call the keypoint-matching pipeline without an inputs argument!")
  110. formatted_inputs = validate_image_pairs(inputs)
  111. outputs = super().__call__(formatted_inputs, threshold=threshold, **kwargs)
  112. if len(formatted_inputs) == 1:
  113. return outputs[0]
  114. return outputs
  115. def preprocess(self, images, timeout=None):
  116. images = [load_image(image, timeout=timeout) for image in images]
  117. model_inputs = self.image_processor(images=images, return_tensors=self.framework)
  118. model_inputs = model_inputs.to(self.dtype)
  119. target_sizes = [image.size for image in images]
  120. preprocess_outputs = {"model_inputs": model_inputs, "target_sizes": target_sizes}
  121. return preprocess_outputs
  122. def _forward(self, preprocess_outputs):
  123. model_inputs = preprocess_outputs["model_inputs"]
  124. model_outputs = self.model(**model_inputs)
  125. forward_outputs = {"model_outputs": model_outputs, "target_sizes": [preprocess_outputs["target_sizes"]]}
  126. return forward_outputs
  127. def postprocess(self, forward_outputs, threshold=0.0) -> list[Match]:
  128. model_outputs = forward_outputs["model_outputs"]
  129. target_sizes = forward_outputs["target_sizes"]
  130. postprocess_outputs = self.image_processor.post_process_keypoint_matching(
  131. model_outputs, target_sizes=target_sizes, threshold=threshold
  132. )
  133. postprocess_outputs = postprocess_outputs[0]
  134. pair_result = []
  135. for kp_0, kp_1, score in zip(
  136. postprocess_outputs["keypoints0"],
  137. postprocess_outputs["keypoints1"],
  138. postprocess_outputs["matching_scores"],
  139. ):
  140. kp_0 = Keypoint(x=kp_0[0].item(), y=kp_0[1].item())
  141. kp_1 = Keypoint(x=kp_1[0].item(), y=kp_1[1].item())
  142. pair_result.append(Match(keypoint_image_0=kp_0, keypoint_image_1=kp_1, score=score.item()))
  143. pair_result = sorted(pair_result, key=lambda x: x["score"], reverse=True)
  144. return pair_result