| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236 |
- # -------------------------------------------------------------------------
- # Copyright (R) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- import logging
- import warnings
- import torch
- from sam2.modeling.sam2_base import SAM2Base
- from sam2_utils import compare_tensors_with_tolerance, random_sam2_input_image
- from torch import nn
- import onnxruntime
- logger = logging.getLogger(__name__)
- class SAM2ImageEncoder(nn.Module):
- def __init__(self, sam_model: SAM2Base) -> None:
- super().__init__()
- self.model = sam_model
- self.image_encoder = sam_model.image_encoder
- self.no_mem_embed = sam_model.no_mem_embed
- def forward(
- self,
- image: torch.Tensor,
- enable_nvtx_profile: bool = False,
- ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
- """
- Encodes images into features.
- Only supports H=W=1024. If you want to use different image sizes like 512x512,
- see https://github.com/facebookresearch/segment-anything-2/issues/138.
- Args:
- image (torch.Tensor): images of shape [B, 3, H, W], B is batch size, H and W are height and width.
- enable_nvtx_profile (bool): enable NVTX profiling.
- Returns:
- image_features_0: image features of shape [B, 32, H/4, W/4] - high resolution features of level 0
- image_features_1: image features of shape [B, 64, H/8, W/8] - high resolution features of level 1
- image_embeddings: image features of shape [B, 256, H/16, W/16] - 16 is the backbone_stride
- """
- nvtx_helper = None
- if enable_nvtx_profile:
- from nvtx_helper import NvtxHelper # noqa: PLC0415
- nvtx_helper = NvtxHelper(["image_encoder", "post_process"])
- if nvtx_helper is not None:
- nvtx_helper.start_profile("image_encoder")
- backbone_out = self.image_encoder(image)
- if nvtx_helper is not None:
- nvtx_helper.stop_profile("image_encoder")
- nvtx_helper.start_profile("post_process")
- # precompute projected level 0 and level 1 features in SAM decoder
- # to avoid running it again on every SAM click
- backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
- backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
- # Prepare and flatten visual features.
- feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels :]
- vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels :]
- feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
- # flatten NxCxHxW to HWxNxC
- # TODO: we should avoid this transpose since it will be transposed back to NCHW later.
- vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
- vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
- feats = [
- feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
- for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1], strict=False)
- ][::-1]
- if nvtx_helper is not None:
- nvtx_helper.stop_profile("post_process")
- nvtx_helper.print_latency()
- return feats[0], feats[1], feats[2]
# Output names of the exported encoder graph, in the order returned by SAM2ImageEncoder.forward.
_IMAGE_ENCODER_OUTPUT_NAMES = ("image_features_0", "image_features_1", "image_embeddings")


def _export_image_encoder_torchscript(
    sam2_encoder: nn.Module,
    image: torch.Tensor,
    onnx_model_path: str,
    dynamic_batch_axes: bool,
) -> None:
    """Export the encoder with the non-dynamo ``torch.onnx.export`` path."""
    dynamic_axes = None
    if dynamic_batch_axes:
        # Mark axis 0 of the input and of every output as the dynamic batch axis.
        dynamic_axes = {name: {0: "batch_size"} for name in ("image", *_IMAGE_ENCODER_OUTPUT_NAMES)}

    torch.onnx.export(
        sam2_encoder,
        image,
        onnx_model_path,
        export_params=True,
        opset_version=17,
        do_constant_folding=True,
        input_names=["image"],
        output_names=list(_IMAGE_ENCODER_OUTPUT_NAMES),
        dynamic_axes=dynamic_axes,
    )


def _export_image_encoder_dynamo(
    sam2_encoder: nn.Module,
    image: torch.Tensor,
    onnx_model_path: str,
    dynamic_batch_axes: bool,
    clear_dynamo_metadata: bool,
) -> None:
    """Export the encoder through ``torch.export`` and the dynamo ONNX exporter, then post-process it."""
    # NOTE: this changes a process-wide torch._dynamo setting and does not restore it afterwards.
    torch._dynamo.config.capture_scalar_outputs = True
    ep = torch.export.export(
        sam2_encoder,
        args=(image,),
        strict=False,
        dynamic_shapes=[
            {0: torch.export.Dim.AUTO},
        ],
    )

    onnx_program = torch.onnx.export(
        ep,
        (),
        opset_version=17,
        input_names=["image"],
        output_names=list(_IMAGE_ENCODER_OUTPUT_NAMES),
        dynamo=True,
    )
    onnx_program.optimize()

    # Intermediate model written next to the final one; it is not removed afterwards.
    dynamo_model_path = onnx_model_path + ".dynamo.onnx"
    onnx_program.save(dynamo_model_path, external_data=False)

    import onnx  # noqa: PLC0415

    from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper  # noqa: PLC0415

    onnx_model = onnx.load_model(dynamo_model_path, load_external_data=True)
    if dynamic_batch_axes:
        # Fix labels of dynamic axes since they can't be specified during Dynamo export currently
        onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "batch_size"
        for i in range(len(_IMAGE_ENCODER_OUTPUT_NAMES)):
            onnx_model.graph.output[i].type.tensor_type.shape.dim[0].dim_param = "batch_size"

    onnx_model_helper = DynamoOnnxHelper(onnx_model)
    onnx_model_helper.convert_constants_to_initializers()
    if clear_dynamo_metadata:
        onnx_model_helper.clear_metadata()

    import os  # noqa: PLC0415

    # Remove a previous model and its external data file before saving the new one.
    if os.path.exists(onnx_model_path):
        os.remove(onnx_model_path)
    if os.path.exists(onnx_model_path + ".data"):
        os.remove(onnx_model_path + ".data")
    onnx_model_helper.model.save_model_to_file(
        onnx_model_path, use_external_data_format=True, all_tensors_to_one_file=True, convert_attribute=True
    )


def export_image_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    dynamic_batch_axes: bool = False,
    verbose: bool = False,
    dynamo: bool = False,
    clear_dynamo_metadata: bool = False,
):
    """Export the SAM2 image encoder to an ONNX model file.

    The encoder is wrapped in ``SAM2ImageEncoder``, moved to CPU and run once
    on a sample image from ``random_sam2_input_image()`` to log the input and
    output shapes before exporting.

    Args:
        sam2_model: SAM2 model whose image encoder is exported.
        onnx_model_path: path of the ONNX model to write. With ``dynamo=True``
            an intermediate ``<onnx_model_path>.dynamo.onnx`` is also written,
            any existing ``<onnx_model_path>`` and ``<onnx_model_path>.data``
            are deleted, and the final model is saved with external data.
        dynamic_batch_axes: label axis 0 of the input and the three outputs
            as ``batch_size``.
        verbose: when False, ``torch.jit.TracerWarning`` and ``UserWarning``
            are suppressed during export.
        dynamo: use ``torch.export`` plus the dynamo ONNX exporter instead of
            the default ``torch.onnx.export`` path.
        clear_dynamo_metadata: clear metadata of the dynamo-exported model
            (only used when ``dynamo`` is True).
    """
    image = random_sam2_input_image()

    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
    image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
    logger.info("image.shape: %s", image.shape)
    logger.info("image_features_0.shape: %s", image_features_0.shape)
    logger.info("image_features_1.shape: %s", image_features_1.shape)
    logger.info("image_embeddings.shape: %s", image_embeddings.shape)

    # The warning filters only apply inside this context and are restored on exit.
    with warnings.catch_warnings():
        if not verbose:
            warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
            warnings.filterwarnings("ignore", category=UserWarning)

        if dynamo:
            _export_image_encoder_dynamo(
                sam2_encoder, image, onnx_model_path, dynamic_batch_axes, clear_dynamo_metadata
            )
        else:
            _export_image_encoder_torchscript(sam2_encoder, image, onnx_model_path, dynamic_batch_axes)

    print("encoder onnx model saved to", onnx_model_path)
def test_image_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
    dynamic_batch_axes=False,
):
    """Compare the exported ONNX image encoder against the PyTorch encoder.

    The ONNX model is run with ONNX Runtime on the CPU execution provider and
    each of its three outputs is checked against the PyTorch output with
    ``compare_tensors_with_tolerance``. The verdict is printed per batch size;
    nothing is returned or raised on mismatch.

    Args:
        sam2_model: SAM2 model used to compute the PyTorch reference outputs.
        onnx_model_path: path of the exported ONNX model.
        dynamic_batch_axes: when True, batch sizes 1 and 2 are verified;
            otherwise only batch size 1.
    """
    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])

    input_names = [model_input.name for model_input in ort_session.get_inputs()]
    logger.info("input_names: %s", input_names)

    output_names = [model_output.name for model_output in ort_session.get_outputs()]
    logger.info("output_names: %s", output_names)

    # The wrapper does not depend on the batch size, so build it once for all iterations.
    sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()

    batch_sizes = [1, 2] if dynamic_batch_axes else [1]
    for batch_size in batch_sizes:
        image = random_sam2_input_image(batch_size)

        # The PyTorch encoder gets a copy; the original image is fed to ONNX Runtime below.
        image_features_0, image_features_1, image_embeddings = sam2_encoder(image.clone())
        logger.info("image.shape: %s", image.shape)
        logger.info("image_features_0.shape: %s", image_features_0.shape)
        logger.info("image_features_1.shape: %s", image_features_1.shape)
        logger.info("image_embeddings.shape: %s", image_embeddings.shape)

        outputs = ort_session.run(output_names, {"image": image.numpy()})
        for output_name, output in zip(output_names, outputs, strict=True):
            logger.info("output %s shape %s", output_name, output.shape)
        ort_image_features_0, ort_image_features_1, ort_image_embeddings = outputs

        comparisons = (
            ("image_features_0", image_features_0, ort_image_features_0),
            ("image_features_1", image_features_1, ort_image_features_1),
            ("image_embeddings", image_embeddings, ort_image_embeddings),
        )

        # ONNX Runtime and PyTorch have about 0.75% mismatched elements, but that does not
        # seem to impact segmentation results, hence the 1% tolerance.
        # all() over a generator stops at the first failing comparison.
        verified = all(
            compare_tensors_with_tolerance(
                name,
                expected,
                torch.tensor(actual),
                mismatch_percentage_tolerance=1,
            )
            for name, expected, actual in comparisons
        )

        if verified:
            print(f"onnx model has been verified for batch_size={batch_size}: {onnx_model_path}")
        else:
            print(f"onnx model verification failed for batch_size={batch_size}: {onnx_model_path}")
|