mask_decoder.py 8.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208
  1. # -------------------------------------------------------------------------
  2. # Copyright (R) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. import warnings
  7. import torch
  8. from image_encoder import SAM2ImageEncoder, random_sam2_input_image
  9. from prompt_encoder import SAM2PromptEncoder
  10. from sam2.modeling.sam2_base import SAM2Base
  11. from torch import nn
  12. logger = logging.getLogger(__name__)
  13. class SAM2MaskDecoder(nn.Module):
  14. def __init__(
  15. self,
  16. sam_model: SAM2Base,
  17. multimask_output: bool,
  18. dynamic_multimask_via_stability: bool = True,
  19. ) -> None:
  20. super().__init__()
  21. self.mask_decoder = sam_model.sam_mask_decoder
  22. self.prompt_encoder = sam_model.sam_prompt_encoder
  23. self.model = sam_model
  24. self.multimask_output = multimask_output
  25. self.dynamic_multimask_via_stability = dynamic_multimask_via_stability
  26. @torch.no_grad()
  27. def forward(
  28. self,
  29. image_features_0: torch.Tensor,
  30. image_features_1: torch.Tensor,
  31. image_embeddings: torch.Tensor,
  32. image_pe: torch.Tensor,
  33. sparse_embeddings: torch.Tensor,
  34. dense_embeddings: torch.Tensor,
  35. ):
  36. """
  37. Decode masks from image and prompt embeddings. Only support H=W=1024.
  38. Args:
  39. image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. high resolution features of level 0 from image encoder.
  40. image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. high resolution features of level 1 from image encoder.
  41. image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. image embedding from image encoder.
  42. image_pe (torch.Tensor): [1, 256, H/16, W/16]. image positional encoding.
  43. sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
  44. dense_embeddings (torch.Tensor): [L, 256, H/16, W/16]. embedding for input masks.
  45. Returns:
  46. low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. low resolution masks.
  47. iou_predictions (torch.Tensor): [1, M]. scores for M masks.
  48. """
  49. low_res_masks, iou_predictions, _, _ = self.mask_decoder.predict_masks(
  50. image_embeddings=image_embeddings,
  51. image_pe=image_pe,
  52. sparse_prompt_embeddings=sparse_embeddings,
  53. dense_prompt_embeddings=dense_embeddings,
  54. repeat_image=sparse_embeddings.shape[0] > 1, # batch mode
  55. high_res_features=[image_features_0, image_features_1],
  56. )
  57. if self.multimask_output:
  58. low_res_masks = low_res_masks[:, 1:, :, :]
  59. iou_predictions = iou_predictions[:, 1:]
  60. elif self.dynamic_multimask_via_stability:
  61. # When outputting a single mask, if the stability score from the current single-mask
  62. # output (based on output token 0) falls below a threshold, we instead select from
  63. # multi-mask outputs (based on output token 1~3) the mask with the highest predicted IoU score.
  64. low_res_masks, iou_predictions = self.mask_decoder._dynamic_multimask_via_stability(
  65. low_res_masks, iou_predictions
  66. )
  67. else:
  68. low_res_masks = low_res_masks[:, 0:1, :, :]
  69. iou_predictions = iou_predictions[:, 0:1]
  70. return low_res_masks, iou_predictions
  71. def export_mask_decoder_onnx(
  72. sam2_model: SAM2Base,
  73. onnx_model_path: str,
  74. multimask_output: bool,
  75. dynamic_multimask_via_stability: bool = True,
  76. verbose=False,
  77. ):
  78. sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
  79. image = random_sam2_input_image()
  80. sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
  81. image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
  82. logger.info("image_features_0.shape: %s", image_features_0.shape)
  83. logger.info("image_features_1.shape: %s", image_features_1.shape)
  84. logger.info("image_embeddings.shape: %s", image_embeddings.shape)
  85. # encode an random prompt
  86. num_labels = 2
  87. num_points = 3
  88. point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
  89. point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
  90. input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
  91. has_input_masks = torch.ones(1, dtype=torch.float)
  92. sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
  93. point_coords, point_labels, input_masks, has_input_masks
  94. )
  95. logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
  96. logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
  97. logger.info("image_pe.shape: %s", image_pe.shape)
  98. sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
  99. inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
  100. low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
  101. logger.info("low_res_masks.shape: %s", low_res_masks.shape)
  102. logger.info("iou_predictions.shape: %s", iou_predictions.shape)
  103. with warnings.catch_warnings():
  104. if not verbose:
  105. warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
  106. warnings.filterwarnings("ignore", category=UserWarning)
  107. torch.onnx.export(
  108. sam2_mask_decoder,
  109. inputs,
  110. onnx_model_path,
  111. export_params=True,
  112. opset_version=18,
  113. do_constant_folding=True,
  114. input_names=[
  115. "image_features_0",
  116. "image_features_1",
  117. "image_embeddings",
  118. "image_pe",
  119. "sparse_embeddings",
  120. "dense_embeddings",
  121. ],
  122. output_names=["low_res_masks", "iou_predictions"],
  123. dynamic_axes={
  124. "sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
  125. "dense_embeddings": {0: "num_labels"},
  126. "low_res_masks": {0: "num_labels"},
  127. "iou_predictions": {0: "num_labels"},
  128. },
  129. )
  130. print("mask decoder onnx model saved to", onnx_model_path)
  131. def test_mask_decoder_onnx(
  132. sam2_model: SAM2Base,
  133. onnx_model_path: str,
  134. multimask_output: bool,
  135. dynamic_multimask_via_stability: bool,
  136. ):
  137. sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
  138. image = random_sam2_input_image()
  139. sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
  140. image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
  141. num_labels = 1
  142. num_points = 5
  143. point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
  144. point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.float)
  145. input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
  146. has_input_masks = torch.ones(1, dtype=torch.float)
  147. sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
  148. point_coords, point_labels, input_masks, has_input_masks
  149. )
  150. sam2_mask_decoder = SAM2MaskDecoder(sam2_model, multimask_output, dynamic_multimask_via_stability)
  151. inputs = (image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings)
  152. low_res_masks, iou_predictions = sam2_mask_decoder(*inputs)
  153. import onnxruntime # noqa: PLC0415
  154. ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
  155. model_inputs = ort_session.get_inputs()
  156. input_names = [model_inputs[i].name for i in range(len(model_inputs))]
  157. logger.info("input_names: %s", input_names)
  158. model_outputs = ort_session.get_outputs()
  159. output_names = [model_outputs[i].name for i in range(len(model_outputs))]
  160. logger.info("output_names: %s", output_names)
  161. outputs = ort_session.run(
  162. output_names,
  163. {
  164. "image_features_0": image_features_0.numpy(),
  165. "image_features_1": image_features_1.numpy(),
  166. "image_embeddings": image_embeddings.numpy(),
  167. "image_pe": image_pe.numpy(),
  168. "sparse_embeddings": sparse_embeddings.numpy(),
  169. "dense_embeddings": dense_embeddings.numpy(),
  170. },
  171. )
  172. for i, output_name in enumerate(output_names):
  173. logger.info("output %s shape: %s", output_name, outputs[i].shape)
  174. ort_low_res_masks, ort_iou_predictions = outputs
  175. torch.testing.assert_close(low_res_masks, torch.tensor(ort_low_res_masks), atol=5e-3, rtol=1e-4)
  176. torch.testing.assert_close(iou_predictions, torch.tensor(ort_iou_predictions), atol=5e-3, rtol=1e-4)
  177. print(f"onnx model has been verified: {onnx_model_path}")