| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272 |
- # -------------------------------------------------------------------------
- # Copyright (R) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- import logging
- import warnings
- import torch
- import torch.nn.functional as F
- from image_encoder import SAM2ImageEncoder, random_sam2_input_image
- from mask_decoder import SAM2MaskDecoder
- from prompt_encoder import SAM2PromptEncoder
- from sam2.modeling.sam2_base import SAM2Base
- from sam2_utils import compare_tensors_with_tolerance
- from torch import nn
- logger = logging.getLogger(__name__)
class SAM2ImageDecoder(nn.Module):
    """SAM2 prompt encoder and mask decoder combined into a single exportable module.

    The module turns image features plus point/box/mask prompts into masks
    resized to the original image resolution.
    """

    def __init__(
        self,
        sam_model: SAM2Base,
        multimask_output: bool,
        dynamic_multimask_via_stability: bool = True,
        return_logits: bool = False,
        mask_threshold: float = 0.0,
    ) -> None:
        """
        Build the prompt encoder and mask decoder wrappers around ``sam_model``.

        Args:
            sam_model (SAM2Base): model handed to SAM2PromptEncoder and SAM2MaskDecoder.
            multimask_output (bool): forwarded to SAM2MaskDecoder.
            dynamic_multimask_via_stability (bool): forwarded to SAM2MaskDecoder.
            return_logits (bool): when False, the upsampled masks are binarized in forward().
            mask_threshold (float): cut-off used for that binarization.
        """
        super().__init__()
        self.prompt_encoder = SAM2PromptEncoder(sam_model)
        self.mask_decoder = SAM2MaskDecoder(
            sam_model,
            multimask_output,
            dynamic_multimask_via_stability,
        )
        self.return_logits = return_logits
        self.mask_threshold = mask_threshold

    @torch.no_grad()
    def forward(
        self,
        image_features_0: torch.Tensor,
        image_features_1: torch.Tensor,
        image_embeddings: torch.Tensor,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        input_masks: torch.Tensor,
        has_input_masks: torch.Tensor,
        original_image_size: torch.Tensor,
        enable_nvtx_profile: bool = False,
    ):
        """
        Run prompt encoding, mask decoding and post-processing for one image (H=W=1024).
        Batches of images are not supported.

        Args:
            image_features_0 (torch.Tensor): [1, 32, H/4, W/4]. Level-0 high resolution features
                produced by the image encoder.
            image_features_1 (torch.Tensor): [1, 64, H/8, W/8]. Level-1 high resolution features
                produced by the image encoder.
            image_embeddings (torch.Tensor): [1, 256, H/16, W/16]. Image embedding produced by
                the image encoder.
            point_coords (torch.Tensor): [L, P, 2], float32. Absolute (x, y) pixel coordinates of
                the P prompt points, expressed in the 1024x1024 input image.
            point_labels (torch.Tensor): [L, P], int32. 1 = positive (foreground),
                0 = negative (background), -1 = padding, 2 = box upper-left corner,
                3 = box lower-right corner.
            input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask prompt, typically
                the output of a previous iteration.
            has_input_masks (torch.Tensor): 1.0 when input_masks should be used, 0.0 otherwise.
            original_image_size (torch.Tensor): [2]. Original image size as (H_o, W_o).
            enable_nvtx_profile (bool): wrap the three stages in NVTX ranges and print latency.

        Returns:
            masks (torch.Tensor): [*, M, H_o, W_o] with M=3 or 1. Masks at original image size;
                boolean (logits > mask_threshold) unless return_logits is True. The leading
                dimension follows low_res_masks (export_decoder_onnx names it "num_labels").
            iou_predictions (torch.Tensor): [*, M]. One score per mask.
            low_res_masks (torch.Tensor): [*, M, H/4, W/4]. Low resolution mask logits,
                clamped to [-32, 32].
        """
        profiler = None
        if enable_nvtx_profile:
            # Imported lazily so the NVTX dependency is only needed when profiling.
            from nvtx_helper import NvtxHelper  # noqa: PLC0415

            profiler = NvtxHelper(["prompt_encoder", "mask_decoder", "post_process"])
        profiling = profiler is not None

        # Stage 1: embed the point / box / mask prompts.
        if profiling:
            profiler.start_profile("prompt_encoder", color="blue")

        prompt_outputs = self.prompt_encoder(point_coords, point_labels, input_masks, has_input_masks)
        sparse_prompt, dense_prompt, positional_encoding = prompt_outputs

        if profiling:
            profiler.stop_profile("prompt_encoder")
            profiler.start_profile("mask_decoder", color="red")

        # Stage 2: predict low resolution mask logits and their IoU scores.
        low_res_masks, iou_predictions = self.mask_decoder(
            image_features_0,
            image_features_1,
            image_embeddings,
            positional_encoding,
            sparse_prompt,
            dense_prompt,
        )

        if profiling:
            profiler.stop_profile("mask_decoder")
            profiler.start_profile("post_process", color="green")

        # Stage 3: upsample the (still unclamped) logits to the original image size.
        target_size = (original_image_size[0], original_image_size[1])
        masks = F.interpolate(
            low_res_masks,
            target_size,
            mode="bilinear",
            # Note: align_corners=True gives fewer mismatches when comparing ORT and PyTorch.
            align_corners=False,
        )

        # Only the returned low resolution logits are clamped; `masks` was computed before this.
        low_res_masks = torch.clamp(low_res_masks, min=-32.0, max=32.0)

        if not self.return_logits:
            masks = masks > self.mask_threshold

        if profiling:
            profiler.stop_profile("post_process")
            profiler.print_latency()

        return masks, iou_predictions, low_res_masks
def export_decoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    multimask_output: bool = False,
    verbose: bool = False,
):
    """
    Trace a SAM2ImageDecoder on random example inputs and write it to an ONNX file.

    Args:
        sam2_model (SAM2Base): model wrapped by the encoder (for example features) and the decoder.
        onnx_model_path (str): destination path of the exported ONNX model.
        multimask_output (bool): forwarded to SAM2ImageDecoder.
        verbose (bool): when True, run the decoder once in PyTorch and log the output shapes,
            and do not suppress tracer/user warnings during export.
    """
    # Example image features come from running the encoder on one random image (batch size 1).
    image = random_sam2_input_image(1)
    encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = encoder(image)

    for message, feature in (
        ("image_features_0.shape: %s", image_features_0),
        ("image_features_1.shape: %s", image_features_1),
        ("image_embeddings.shape: %s", image_embeddings),
    ):
        logger.info(message, feature.shape)

    decoder = SAM2ImageDecoder(
        sam2_model,
        multimask_output=multimask_output,
        dynamic_multimask_via_stability=True,
    ).cpu()

    # Example prompts: 2 labels with 3 points each.
    label_count, point_count = 2, 3
    point_coords = torch.randint(low=0, high=1024, size=(label_count, point_count, 2), dtype=torch.float)
    # NOTE(review): `high` is exclusive, so every example label is 0 — confirm this is intended.
    point_labels = torch.randint(low=0, high=1, size=(label_count, point_count), dtype=torch.int32)
    input_masks = torch.zeros(label_count, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)
    original_image_size = torch.tensor([1200, 1800], dtype=torch.int32)

    # One ordered mapping is the source of both the ONNX input names and the positional inputs.
    named_inputs = {
        "image_features_0": image_features_0,
        "image_features_1": image_features_1,
        "image_embeddings": image_embeddings,
        "point_coords": point_coords,
        "point_labels": point_labels,
        "input_masks": input_masks,
        "has_input_masks": has_input_masks,
        "original_image_size": original_image_size,
    }
    input_names = list(named_inputs)
    example_inputs = tuple(named_inputs.values())

    for message, prompt in (
        ("point_coords.shape: %s", point_coords),
        ("point_labels.shape: %s", point_labels),
        ("input_masks.shape: %s", input_masks),
        ("has_input_masks.shape: %s", has_input_masks),
        ("original_image_size.shape: %s", original_image_size),
    ):
        logger.info(message, prompt.shape)

    if verbose:
        reference_outputs = decoder(*example_inputs)
        for message, reference in zip(
            ("masks.shape: %s", "iou_predictions.shape: %s", "low_res_masks.shape: %s"),
            reference_outputs,
        ):
            logger.info(message, reference.shape)

    output_names = ["masks", "iou_predictions", "low_res_masks"]

    # Axes that may vary at inference time; everything else keeps the traced size.
    dynamic_axes = {
        "point_coords": {0: "num_labels", 1: "num_points"},
        "point_labels": {0: "num_labels", 1: "num_points"},
        "input_masks": {0: "num_labels"},
        "has_input_masks": {0: "num_labels"},
        "masks": {0: "num_labels", 2: "original_image_height", 3: "original_image_width"},
        "low_res_masks": {0: "num_labels"},
        "iou_predictions": {0: "num_labels"},
    }

    with warnings.catch_warnings():
        if not verbose:
            # Keep the export quiet unless the caller asked for details.
            for category in (torch.jit.TracerWarning, UserWarning):
                warnings.filterwarnings("ignore", category=category)

        torch.onnx.export(
            decoder,
            example_inputs,
            onnx_model_path,
            export_params=True,
            opset_version=16,
            do_constant_folding=True,
            input_names=input_names,
            output_names=output_names,
            dynamic_axes=dynamic_axes,
        )

    logger.info("decoder onnx model saved to %s", onnx_model_path)
def test_decoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    multimask_output: bool = False,
):
    """
    Check an exported decoder ONNX model against the PyTorch SAM2ImageDecoder.

    Random example inputs are run through both the PyTorch decoder and an ONNX Runtime
    CPU session, and the three outputs are compared with compare_tensors_with_tolerance.
    The verdict is printed to stdout; nothing is returned.

    Args:
        sam2_model (SAM2Base): model wrapped by the encoder (for example features) and the decoder.
        onnx_model_path (str): path of the ONNX model to verify.
        multimask_output (bool): forwarded to SAM2ImageDecoder; presumably must match the
            value used at export time for the comparison to be meaningful.
    """
    batch_size = 1
    image = random_sam2_input_image(batch_size)
    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)

    sam2_image_decoder = SAM2ImageDecoder(
        sam2_model,
        multimask_output=multimask_output,
        dynamic_multimask_via_stability=True,
    ).cpu()

    # Example prompts: 1 label with 5 points and no input mask.
    num_labels = 1
    num_points = 5
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    # NOTE(review): `high` is exclusive, so every label is 0 here — confirm this is intended.
    point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
    input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.zeros(1, dtype=torch.float)
    original_image_size = torch.tensor([1500, 1500], dtype=torch.int32)

    example_inputs = (
        image_features_0,
        image_features_1,
        image_embeddings,
        point_coords,
        point_labels,
        input_masks,
        has_input_masks,
        original_image_size,
    )

    # PyTorch reference outputs.
    masks, iou_predictions, low_res_masks = sam2_image_decoder(*example_inputs)

    # Imported lazily so onnxruntime is only required when verifying a model.
    import onnxruntime  # noqa: PLC0415

    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])

    input_names = [model_input.name for model_input in ort_session.get_inputs()]
    logger.info("input_names: %s", input_names)

    output_names = [model_output.name for model_output in ort_session.get_outputs()]
    logger.info("output_names: %s", output_names)

    # Session inputs are paired with example_inputs by position.
    inputs = {name: example_inputs[i].numpy() for i, name in enumerate(input_names)}
    outputs = ort_session.run(output_names, inputs)

    for output_name, output in zip(output_names, outputs):
        # Pass the name as a lazy %-argument so it is never interpreted as a format string.
        logger.info("%s.shape: %s", output_name, output.shape)

    ort_masks, ort_iou_predictions, ort_low_res_masks = outputs
    if (
        compare_tensors_with_tolerance("masks", masks.float(), torch.tensor(ort_masks).float())
        and compare_tensors_with_tolerance("iou_predictions", iou_predictions, torch.tensor(ort_iou_predictions))
        and compare_tensors_with_tolerance("low_res_masks", low_res_masks, torch.tensor(ort_low_res_masks))
    ):
        print("onnx model has been verified:", onnx_model_path)
    else:
        print("onnx model verification failed:", onnx_model_path)
|