sam2_utils.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147
  1. # -------------------------------------------------------------------------
  2. # Copyright (R) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. import os
  7. import sys
  8. from collections.abc import Mapping
  9. import torch
  10. from sam2.build_sam import build_sam2
  11. from sam2.modeling.sam2_base import SAM2Base
  12. logger = logging.getLogger(__name__)
  13. def _get_model_cfg(model_type) -> str:
  14. assert model_type in ["sam2_hiera_tiny", "sam2_hiera_small", "sam2_hiera_large", "sam2_hiera_base_plus"]
  15. if model_type == "sam2_hiera_tiny":
  16. model_cfg = "sam2_hiera_t.yaml"
  17. elif model_type == "sam2_hiera_small":
  18. model_cfg = "sam2_hiera_s.yaml"
  19. elif model_type == "sam2_hiera_base_plus":
  20. model_cfg = "sam2_hiera_b+.yaml"
  21. else:
  22. model_cfg = "sam2_hiera_l.yaml"
  23. return model_cfg
  24. def load_sam2_model(sam2_dir, model_type, device: str | torch.device = "cpu") -> SAM2Base:
  25. checkpoints_dir = os.path.join(sam2_dir, "checkpoints")
  26. sam2_config_dir = os.path.join(sam2_dir, "sam2_configs")
  27. if not os.path.exists(sam2_dir):
  28. raise FileNotFoundError(f"{sam2_dir} does not exist. Please specify --sam2_dir correctly.")
  29. if not os.path.exists(checkpoints_dir):
  30. raise FileNotFoundError(f"{checkpoints_dir} does not exist. Please specify --sam2_dir correctly.")
  31. if not os.path.exists(sam2_config_dir):
  32. raise FileNotFoundError(f"{sam2_config_dir} does not exist. Please specify --sam2_dir correctly.")
  33. checkpoint_path = os.path.join(checkpoints_dir, f"{model_type}.pt")
  34. if not os.path.exists(checkpoint_path):
  35. raise FileNotFoundError(f"{checkpoint_path} does not exist. Please download checkpoints under the directory.")
  36. if sam2_dir not in sys.path:
  37. sys.path.append(sam2_dir)
  38. model_cfg = _get_model_cfg(model_type)
  39. sam2_model = build_sam2(model_cfg, checkpoint_path, device=device)
  40. return sam2_model
  41. def sam2_onnx_path(output_dir, model_type, component, multimask_output=False, suffix=""):
  42. if component == "image_encoder":
  43. return os.path.join(output_dir, f"{model_type}_image_encoder{suffix}.onnx")
  44. elif component == "mask_decoder":
  45. return os.path.join(output_dir, f"{model_type}_mask_decoder{suffix}.onnx")
  46. elif component == "prompt_encoder":
  47. return os.path.join(output_dir, f"{model_type}_prompt_encoder{suffix}.onnx")
  48. else:
  49. assert component == "image_decoder"
  50. return os.path.join(
  51. output_dir, f"{model_type}_image_decoder" + ("_multi" if multimask_output else "") + f"{suffix}.onnx"
  52. )
  53. def encoder_shape_dict(batch_size: int, height: int, width: int) -> Mapping[str, list[int]]:
  54. assert height == 1024 and width == 1024, "Only 1024x1024 images are supported."
  55. return {
  56. "image": [batch_size, 3, height, width],
  57. "image_features_0": [batch_size, 32, height // 4, width // 4],
  58. "image_features_1": [batch_size, 64, height // 8, width // 8],
  59. "image_embeddings": [batch_size, 256, height // 16, width // 16],
  60. }
  61. def decoder_shape_dict(
  62. original_image_height: int,
  63. original_image_width: int,
  64. num_labels: int = 1,
  65. max_points: int = 16,
  66. num_masks: int = 1,
  67. ) -> dict:
  68. height: int = 1024
  69. width: int = 1024
  70. return {
  71. "image_features_0": [1, 32, height // 4, width // 4],
  72. "image_features_1": [1, 64, height // 8, width // 8],
  73. "image_embeddings": [1, 256, height // 16, width // 16],
  74. "point_coords": [num_labels, max_points, 2],
  75. "point_labels": [num_labels, max_points],
  76. "input_masks": [num_labels, 1, height // 4, width // 4],
  77. "has_input_masks": [num_labels],
  78. "original_image_size": [2],
  79. "masks": [num_labels, num_masks, original_image_height, original_image_width],
  80. "iou_predictions": [num_labels, num_masks],
  81. "low_res_masks": [num_labels, num_masks, height // 4, width // 4],
  82. }
  83. def compare_tensors_with_tolerance(
  84. name: str,
  85. tensor1: torch.Tensor,
  86. tensor2: torch.Tensor,
  87. atol=5e-3,
  88. rtol=1e-4,
  89. mismatch_percentage_tolerance=0.1,
  90. ) -> bool:
  91. assert tensor1.shape == tensor2.shape
  92. a = tensor1.clone().float()
  93. b = tensor2.clone().float()
  94. differences = torch.abs(a - b)
  95. mismatch_count = (differences > (rtol * torch.max(torch.abs(a), torch.abs(b)) + atol)).sum().item()
  96. total_elements = a.numel()
  97. mismatch_percentage = (mismatch_count / total_elements) * 100
  98. passed = mismatch_percentage < mismatch_percentage_tolerance
  99. log_func = logger.error if not passed else logger.info
  100. log_func(
  101. "%s: mismatched elements percentage %.2f (%d/%d). Verification %s (threshold=%.2f).",
  102. name,
  103. mismatch_percentage,
  104. mismatch_count,
  105. total_elements,
  106. "passed" if passed else "failed",
  107. mismatch_percentage_tolerance,
  108. )
  109. return passed
  110. def random_sam2_input_image(batch_size=1, image_height=1024, image_width=1024) -> torch.Tensor:
  111. image = torch.randn(batch_size, 3, image_height, image_width, dtype=torch.float32).cpu()
  112. return image
  113. def setup_logger(verbose=True):
  114. if verbose:
  115. logging.basicConfig(format="[%(filename)s:%(lineno)s - %(funcName)20s()] %(message)s")
  116. logging.getLogger().setLevel(logging.INFO)
  117. else:
  118. logging.basicConfig(format="[%(message)s")
  119. logging.getLogger().setLevel(logging.WARNING)