image_encoder.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236
  1. # -------------------------------------------------------------------------
  2. # Copyright (R) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. import warnings
  7. import torch
  8. from sam2.modeling.sam2_base import SAM2Base
  9. from sam2_utils import compare_tensors_with_tolerance, random_sam2_input_image
  10. from torch import nn
  11. import onnxruntime
# Module-level logger; the export and verification functions in this file log tensor shapes through it.
logger = logging.getLogger(__name__)
  13. class SAM2ImageEncoder(nn.Module):
  14. def __init__(self, sam_model: SAM2Base) -> None:
  15. super().__init__()
  16. self.model = sam_model
  17. self.image_encoder = sam_model.image_encoder
  18. self.no_mem_embed = sam_model.no_mem_embed
  19. def forward(
  20. self,
  21. image: torch.Tensor,
  22. enable_nvtx_profile: bool = False,
  23. ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
  24. """
  25. Encodes images into features.
  26. Only supports H=W=1024. If you want to use different image sizes like 512x512,
  27. see https://github.com/facebookresearch/segment-anything-2/issues/138.
  28. Args:
  29. image (torch.Tensor): images of shape [B, 3, H, W], B is batch size, H and W are height and width.
  30. enable_nvtx_profile (bool): enable NVTX profiling.
  31. Returns:
  32. image_features_0: image features of shape [B, 32, H/4, W/4] - high resolution features of level 0
  33. image_features_1: image features of shape [B, 64, H/8, W/8] - high resolution features of level 1
  34. image_embeddings: image features of shape [B, 256, H/16, W/16] - 16 is the backbone_stride
  35. """
  36. nvtx_helper = None
  37. if enable_nvtx_profile:
  38. from nvtx_helper import NvtxHelper # noqa: PLC0415
  39. nvtx_helper = NvtxHelper(["image_encoder", "post_process"])
  40. if nvtx_helper is not None:
  41. nvtx_helper.start_profile("image_encoder")
  42. backbone_out = self.image_encoder(image)
  43. if nvtx_helper is not None:
  44. nvtx_helper.stop_profile("image_encoder")
  45. nvtx_helper.start_profile("post_process")
  46. # precompute projected level 0 and level 1 features in SAM decoder
  47. # to avoid running it again on every SAM click
  48. backbone_out["backbone_fpn"][0] = self.model.sam_mask_decoder.conv_s0(backbone_out["backbone_fpn"][0])
  49. backbone_out["backbone_fpn"][1] = self.model.sam_mask_decoder.conv_s1(backbone_out["backbone_fpn"][1])
  50. # Prepare and flatten visual features.
  51. feature_maps = backbone_out["backbone_fpn"][-self.model.num_feature_levels :]
  52. vision_pos_embeds = backbone_out["vision_pos_enc"][-self.model.num_feature_levels :]
  53. feat_sizes = [(x.shape[-2], x.shape[-1]) for x in vision_pos_embeds]
  54. # flatten NxCxHxW to HWxNxC
  55. # TODO: we should avoid this transpose since it will be transposed back to NCHW later.
  56. vision_feats = [x.flatten(2).permute(2, 0, 1) for x in feature_maps]
  57. vision_feats[-1] = vision_feats[-1] + self.no_mem_embed
  58. feats = [
  59. feat.permute(1, 2, 0).reshape(1, -1, *feat_size)
  60. for feat, feat_size in zip(vision_feats[::-1], feat_sizes[::-1], strict=False)
  61. ][::-1]
  62. if nvtx_helper is not None:
  63. nvtx_helper.stop_profile("post_process")
  64. nvtx_helper.print_latency()
  65. return feats[0], feats[1], feats[2]
  66. def export_image_encoder_onnx(
  67. sam2_model: SAM2Base,
  68. onnx_model_path: str,
  69. dynamic_batch_axes: bool = False,
  70. verbose: bool = False,
  71. dynamo: bool = False,
  72. clear_dynamo_metadata: bool = False,
  73. ):
  74. image = random_sam2_input_image()
  75. sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
  76. image_features_0, image_features_1, image_embeddings = sam2_encoder(image)
  77. logger.info("image.shape: %s", image.shape)
  78. logger.info("image_features_0.shape: %s", image_features_0.shape)
  79. logger.info("image_features_1.shape: %s", image_features_1.shape)
  80. logger.info("image_embeddings.shape: %s", image_embeddings.shape)
  81. dynamic_axes = None
  82. if dynamic_batch_axes:
  83. dynamic_axes = {
  84. "image": {0: "batch_size"},
  85. "image_features_0": {0: "batch_size"},
  86. "image_features_1": {0: "batch_size"},
  87. "image_embeddings": {0: "batch_size"},
  88. }
  89. with warnings.catch_warnings():
  90. if not verbose:
  91. warnings.filterwarnings("ignore", category=torch.jit.TracerWarning)
  92. warnings.filterwarnings("ignore", category=UserWarning)
  93. if not dynamo:
  94. torch.onnx.export(
  95. sam2_encoder,
  96. image,
  97. onnx_model_path,
  98. export_params=True,
  99. opset_version=17,
  100. do_constant_folding=True,
  101. input_names=["image"],
  102. output_names=["image_features_0", "image_features_1", "image_embeddings"],
  103. dynamic_axes=dynamic_axes,
  104. )
  105. else:
  106. torch._dynamo.config.capture_scalar_outputs = True
  107. ep = torch.export.export(
  108. sam2_encoder,
  109. args=(image,),
  110. strict=False,
  111. dynamic_shapes=[
  112. {0: torch.export.Dim.AUTO},
  113. ],
  114. )
  115. onnx_program = torch.onnx.export(
  116. ep,
  117. (),
  118. opset_version=17,
  119. input_names=["image"],
  120. output_names=["image_features_0", "image_features_1", "image_embeddings"],
  121. dynamo=True,
  122. )
  123. onnx_program.optimize()
  124. onnx_program.save(onnx_model_path + ".dynamo.onnx", external_data=False)
  125. import onnx # noqa: PLC0415
  126. from onnxruntime.transformers.dynamo_onnx_helper import DynamoOnnxHelper # noqa: PLC0415
  127. onnx_model = onnx.load_model(onnx_model_path + ".dynamo.onnx", load_external_data=True)
  128. if dynamic_batch_axes:
  129. # Fix labels of dynamic axes since they can't be specified during Dynamo export currently
  130. onnx_model.graph.input[0].type.tensor_type.shape.dim[0].dim_param = "batch_size"
  131. for i in range(3):
  132. onnx_model.graph.output[i].type.tensor_type.shape.dim[0].dim_param = "batch_size"
  133. onnx_model_helper = DynamoOnnxHelper(onnx_model)
  134. onnx_model_helper.convert_constants_to_initializers()
  135. if clear_dynamo_metadata:
  136. onnx_model_helper.clear_metadata()
  137. import os # noqa: PLC0415
  138. if os.path.exists(onnx_model_path):
  139. os.remove(onnx_model_path)
  140. if os.path.exists(onnx_model_path + ".data"):
  141. os.remove(onnx_model_path + ".data")
  142. onnx_model_helper.model.save_model_to_file(
  143. onnx_model_path, use_external_data_format=True, all_tensors_to_one_file=True, convert_attribute=True
  144. )
  145. print("encoder onnx model saved to", onnx_model_path)
  146. def test_image_encoder_onnx(
  147. sam2_model: SAM2Base,
  148. onnx_model_path: str,
  149. dynamic_batch_axes=False,
  150. ):
  151. ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
  152. model_inputs = ort_session.get_inputs()
  153. input_names = [model_inputs[i].name for i in range(len(model_inputs))]
  154. logger.info("input_names: %s", input_names)
  155. model_outputs = ort_session.get_outputs()
  156. output_names = [model_outputs[i].name for i in range(len(model_outputs))]
  157. logger.info("output_names: %s", output_names)
  158. batch_sizes = [1, 2] if dynamic_batch_axes else [1]
  159. for batch_size in batch_sizes:
  160. image = random_sam2_input_image(batch_size)
  161. sam2_encoder = SAM2ImageEncoder(sam2_model).cpu()
  162. image_features_0, image_features_1, image_embeddings = sam2_encoder(image.clone())
  163. logger.info("image.shape: %s", image.shape)
  164. logger.info("image_features_0.shape: %s", image_features_0.shape)
  165. logger.info("image_features_1.shape: %s", image_features_1.shape)
  166. logger.info("image_embeddings.shape: %s", image_embeddings.shape)
  167. outputs = ort_session.run(output_names, {"image": image.numpy()})
  168. for i, output_name in enumerate(output_names):
  169. logger.info("output %s shape %s", output_name, outputs[i].shape)
  170. ort_image_features_0, ort_image_features_1, ort_image_embeddings = outputs
  171. # ONNXRuntime and PyTorch has about 0.75% mismatched elements, but seems not impacting segmentation results.
  172. if (
  173. compare_tensors_with_tolerance(
  174. "image_features_0",
  175. image_features_0,
  176. torch.tensor(ort_image_features_0),
  177. mismatch_percentage_tolerance=1,
  178. )
  179. and compare_tensors_with_tolerance(
  180. "image_features_1",
  181. image_features_1,
  182. torch.tensor(ort_image_features_1),
  183. mismatch_percentage_tolerance=1,
  184. )
  185. and compare_tensors_with_tolerance(
  186. "image_embeddings",
  187. image_embeddings,
  188. torch.tensor(ort_image_embeddings),
  189. mismatch_percentage_tolerance=1,
  190. )
  191. ):
  192. print(f"onnx model has been verified for batch_size={batch_size}: {onnx_model_path}")
  193. else:
  194. print(f"onnx model verification failed for batch_size={batch_size}: {onnx_model_path}")