# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
  5. import logging
  6. import warnings
  7. import torch
  8. import torch.nn.functional as F
  9. from image_encoder import SAM2ImageEncoder, random_sam2_input_image
  10. from mask_decoder import SAM2MaskDecoder
  11. from prompt_encoder import SAM2PromptEncoder
  12. from sam2.modeling.sam2_base import SAM2Base
  13. from sam2_utils import compare_tensors_with_tolerance
  14. from torch import nn
  15. logger = logging.getLogger(__name__)
  16. class SAM2ImageDecoder(nn.Module):
  17. def __init__(
  18. self,
  19. sam_model: SAM2Base,
  20. multimask_output: bool,
  21. dynamic_multimask_via_stability: bool = True,
  22. return_logits: bool = False,
  23. mask_threshold: float = 0.0,
  24. ) -> None:
  25. super().__init__()
  26. self.prompt_encoder = SAM2PromptEncoder(sam_model)
  27. self.mask_decoder = SAM2MaskDecoder(sam_model, multimask_output, dynamic_multimask_via_stability)
  28. self.return_logits = return_logits
  29. self.mask_threshold = mask_threshold
  30. @torch.no_grad()
  31. def forward(
  32. self,
  33. image_features_0: torch.Tensor,
  34. image_features_1: torch.Tensor,
  35. image_embeddings: torch.Tensor,
  36. point_coords: torch.Tensor,
  37. point_labels: torch.Tensor,
  38. input_masks: torch.Tensor,
  39. has_input_masks: torch.Tensor,
  40. original_image_size: torch.Tensor,
  41. enable_nvtx_profile: bool = False,
  42. ):
  43. """
  44. Decode masks from image features and prompts. Batched images are not supported. H=W=1024.
  45. Args:
  46. image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. high resolution features of level 0 from image encoder.
  47. image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. high resolution features of level 1 from image encoder.
  48. image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. image embedding from image encoder.
  49. point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel
  50. coordinate in (x, y) format of the P input points in image of size 1024x1024.
  51. point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
  52. positive (foreground), 0 means negative (background), -1 means padding,
  53. 2 (box left upper corner), 3 (box right bottom corner).
  54. input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
  55. Typically coming from a previous iteration.
  56. has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
  57. original_image_size(torch.Tensor): [2]. original image size H_o, W_o.
  58. enable_nvtx_profile (bool): enable NVTX profiling.
  59. Returns:
  60. masks (torch.Tensor): [1, M, H_o, W_o] where M=3 or 1. Masks of original image size.
  61. iou_predictions (torch.Tensor): [1, M]. scores for M masks.
  62. low_res_masks (torch.Tensor, optional): [1, M, H/4, W/4]. low resolution masks.
  63. """
  64. nvtx_helper = None
  65. if enable_nvtx_profile:
  66. from nvtx_helper import NvtxHelper # noqa: PLC0415
  67. nvtx_helper = NvtxHelper(["prompt_encoder", "mask_decoder", "post_process"])
  68. if nvtx_helper is not None:
  69. nvtx_helper.start_profile("prompt_encoder", color="blue")
  70. sparse_embeddings, dense_embeddings, image_pe = self.prompt_encoder(
  71. point_coords, point_labels, input_masks, has_input_masks
  72. )
  73. if nvtx_helper is not None:
  74. nvtx_helper.stop_profile("prompt_encoder")
  75. nvtx_helper.start_profile("mask_decoder", color="red")
  76. low_res_masks, iou_predictions = self.mask_decoder(
  77. image_features_0, image_features_1, image_embeddings, image_pe, sparse_embeddings, dense_embeddings
  78. )
  79. if nvtx_helper is not None:
  80. nvtx_helper.stop_profile("mask_decoder")
  81. nvtx_helper.start_profile("post_process", color="green")
  82. # Interpolate the low resolution masks back to the original image size.
  83. masks = F.interpolate(
  84. low_res_masks,
  85. (original_image_size[0], original_image_size[1]),
  86. mode="bilinear",
  87. align_corners=False, # Note that align_corners=True has less mismatches during comparing ORT and PyTorch.
  88. )
  89. low_res_masks = torch.clamp(low_res_masks, -32.0, 32.0)
  90. if not self.return_logits:
  91. masks = masks > self.mask_threshold
  92. if nvtx_helper is not None:
  93. nvtx_helper.stop_profile("post_process")
  94. nvtx_helper.print_latency()
  95. return masks, iou_predictions, low_res_masks
  96. def export_decoder_onnx(
  97. sam2_model: SAM2Base,
  98. onnx_model_path: str,
  99. multimask_output: bool = False,
  100. verbose: bool = False,
  101. ):
  102. batch_size = 1
  103. image = random_sam2_input_image(batch_size)
  104. sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
  105. image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
  106. logger.info("image_features_0.shape: %s", image_features_0.shape)
  107. logger.info("image_features_1.shape: %s", image_features_1.shape)
  108. logger.info("image_embeddings.shape: %s", image_embeddings.shape)
  109. sam2_decoder = SAM2ImageDecoder(
  110. sam2_model,
  111. multimask_output=multimask_output,
  112. dynamic_multimask_via_stability=True,
  113. ).cpu()
  114. num_labels = 2
  115. num_points = 3
  116. point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
  117. point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
  118. input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
  119. has_input_masks = torch.ones(1, dtype=torch.float)
  120. original_image_size = torch.tensor([1200, 1800], dtype=torch.int32)
  121. example_inputs = (
  122. image_features_0,
  123. image_features_1,
  124. image_embeddings,
  125. point_coords,
  126. point_labels,
  127. input_masks,
  128. has_input_masks,
  129. original_image_size,
  130. )
  131. logger.info("point_coords.shape: %s", point_coords.shape)
  132. logger.info("point_labels.shape: %s", point_labels.shape)
  133. logger.info("input_masks.shape: %s", input_masks.shape)
  134. logger.info("has_input_masks.shape: %s", has_input_masks.shape)
  135. logger.info("original_image_size.shape: %s", original_image_size.shape)
  136. if verbose:
  137. masks, iou_predictions, low_res_masks = sam2_decoder(*example_inputs)
  138. logger.info("masks.shape: %s", masks.shape)
  139. logger.info("iou_predictions.shape: %s", iou_predictions.shape)
  140. logger.info("low_res_masks.shape: %s", low_res_masks.shape)
  141. input_names = [
  142. "image_features_0",
  143. "image_features_1",
  144. "image_embeddings",
  145. "point_coords",
  146. "point_labels",
  147. "input_masks",
  148. "has_input_masks",
  149. "original_image_size",
  150. ]
  151. output_names = ["masks", "iou_predictions", "low_res_masks"]
  152. dynamic_axes = {
  153. "point_coords": {0: "num_labels", 1: "num_points"},
  154. "point_labels": {0: "num_labels", 1: "num_points"},
  155. "input_masks": {0: "num_labels"},
  156. "has_input_masks": {0: "num_labels"},
  157. "masks": {0: "num_labels", 2: "original_image_height", 3: "original_image_width"},
  158. "low_res_masks": {0: "num_labels"},
  159. "iou_predictions": {0: "num_labels"},
  160. }
  161. with warnings.catch_warnings():
  162. if not verbose:
  163. warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
  164. warnings.filterwarnings("ignore", category=UserWarning)
  165. torch.onnx.export(
  166. sam2_decoder,
  167. example_inputs,
  168. onnx_model_path,
  169. export_params=True,
  170. opset_version=16,
  171. do_constant_folding=True,
  172. input_names=input_names,
  173. output_names=output_names,
  174. dynamic_axes=dynamic_axes,
  175. )
  176. logger.info("decoder onnx model saved to %s", onnx_model_path)
  177. def test_decoder_onnx(
  178. sam2_model: SAM2Base,
  179. onnx_model_path: str,
  180. multimask_output=False,
  181. ):
  182. batch_size = 1
  183. image = random_sam2_input_image(batch_size)
  184. sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
  185. image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
  186. sam2_image_decoder = SAM2ImageDecoder(
  187. sam2_model,
  188. multimask_output=multimask_output,
  189. dynamic_multimask_via_stability=True,
  190. ).cpu()
  191. num_labels = 1
  192. num_points = 5
  193. point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
  194. point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
  195. input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
  196. has_input_masks = torch.zeros(1, dtype=torch.float)
  197. original_image_size = torch.tensor([1500, 1500], dtype=torch.int32)
  198. example_inputs = (
  199. image_features_0,
  200. image_features_1,
  201. image_embeddings,
  202. point_coords,
  203. point_labels,
  204. input_masks,
  205. has_input_masks,
  206. original_image_size,
  207. )
  208. masks, iou_predictions, low_res_masks = sam2_image_decoder(*example_inputs)
  209. import onnxruntime # noqa: PLC0415
  210. ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
  211. model_inputs = ort_session.get_inputs()
  212. input_names = [model_inputs[i].name for i in range(len(model_inputs))]
  213. logger.info("input_names: %s", input_names)
  214. model_outputs = ort_session.get_outputs()
  215. output_names = [model_outputs[i].name for i in range(len(model_outputs))]
  216. logger.info("output_names: %s", output_names)
  217. inputs = {model_inputs[i].name: example_inputs[i].numpy() for i in range(len(model_inputs))}
  218. outputs = ort_session.run(output_names, inputs)
  219. for i, output_name in enumerate(output_names):
  220. logger.info(f"{output_name}.shape: %s", outputs[i].shape)
  221. ort_masks, ort_iou_predictions, ort_low_res_masks = outputs
  222. if (
  223. compare_tensors_with_tolerance("masks", masks.float(), torch.tensor(ort_masks).float())
  224. and compare_tensors_with_tolerance("iou_predictions", iou_predictions, torch.tensor(ort_iou_predictions))
  225. and compare_tensors_with_tolerance("low_res_masks", low_res_masks, torch.tensor(ort_low_res_masks))
  226. ):
  227. print("onnx model has been verified:", onnx_model_path)
  228. else:
  229. print("onnx model verification failed:", onnx_model_path)