| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189 |
- # -------------------------------------------------------------------------
- # Copyright (R) Microsoft Corporation. All rights reserved.
- # Licensed under the MIT License.
- # --------------------------------------------------------------------------
- import logging
- import torch
- from sam2.modeling.sam2_base import SAM2Base
- from sam2_utils import compare_tensors_with_tolerance
- from torch import nn
- logger = logging.getLogger(__name__)
class SAM2PromptEncoder(nn.Module):
    """Wrapper around the prompt encoder of a SAM2 model.

    Point/box prompts and mask prompts are embedded with plain tensor
    arithmetic: every candidate embedding is multiplied by a boolean label
    mask and the results are summed, so the computation contains no
    data-dependent Python branching.
    """

    def __init__(self, sam_model: SAM2Base):
        """Keep references to the SAM2 model and its prompt encoder.

        Args:
            sam_model: SAM2 model providing ``sam_prompt_encoder`` and ``image_size``.
        """
        super().__init__()
        self.prompt_encoder = sam_model.sam_prompt_encoder
        self.model = sam_model

    @torch.no_grad()
    def forward(
        self,
        point_coords: torch.Tensor,
        point_labels: torch.Tensor,
        input_masks: torch.Tensor,
        has_input_masks: torch.Tensor,
    ):
        """Encode prompts.

        Args:
            point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel
                                         coordinate in (x, y) format of the P input points in image of size 1024x1024.
            point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
                                         positive (foreground), 0 means negative (background), -1 means padding,
                                         2 (box left upper corner), 3 (box right bottom corner).
            input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
                                        Typically coming from a previous iteration.
            has_input_masks (torch.Tensor): [1] or [L]. 1.0 if input_masks is used, 0.0 otherwise.
                                            A single value applies to every label.

        Returns:
            sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
            dense_embeddings (torch.Tensor): [L, 256, 64, 64]. embedding for input masks.
            image_pe (torch.Tensor): [1, 256, 64, 64]. image positional encoding.
        """
        sparse_embeddings = self._embed_points(point_coords, point_labels)
        dense_embeddings = self._embed_masks(input_masks, has_input_masks)
        image_pe = self.prompt_encoder.get_dense_pe()
        return sparse_embeddings, dense_embeddings, image_pe

    def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
        """Embed point prompts, appending one padding point (label -1) per row.

        Args:
            point_coords: [L, P, 2] absolute pixel coordinates.
            point_labels: [L, P] point labels.

        Returns:
            Point embeddings for the P input points plus the appended padding point.
        """
        # Shift by half a pixel (presumably to address pixel centers). This also
        # creates a new tensor, so the caller's coordinates are never modified.
        point_coords = point_coords + 0.5

        # Padding uses the dtypes of the inputs so that concatenation does not
        # silently promote the labels to floating point.
        padding_point = torch.zeros(
            (point_coords.shape[0], 1, 2), dtype=point_coords.dtype, device=point_coords.device
        )
        padding_label = -torch.ones(
            (point_labels.shape[0], 1), dtype=point_labels.dtype, device=point_labels.device
        )
        point_coords = torch.cat([point_coords, padding_point], dim=1)
        point_labels = torch.cat([point_labels, padding_label], dim=1)

        # Note that the input coordinates are based on image size 1024x1024. Here we normalize it to [0.0, 1.0).
        # Both x and y are divided by the same size, so a single out-of-place
        # division replaces two in-place slice assignments.
        point_coords = point_coords / self.model.image_size

        point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
        point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)

        # Padding points (label -1) lose their positional encoding and receive
        # the "not a point" embedding instead.
        point_embedding = point_embedding * (point_labels != -1)
        point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)

        # Add the learned embedding of label i to every point carrying label i.
        for i in range(self.prompt_encoder.num_point_embeddings):
            point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)

        return point_embedding

    def _embed_masks(self, input_masks: torch.Tensor, has_input_masks: torch.Tensor) -> torch.Tensor:
        """Embed mask prompts, falling back to the "no mask" embedding.

        Args:
            input_masks: [L, 1, H/4, W/4] low resolution masks.
            has_input_masks: [1] or [L] flags; 1.0 selects the downscaled mask
                embedding and 0.0 selects the "no mask" embedding.

        Returns:
            Dense embeddings blended according to ``has_input_masks``.
        """
        mask_embedding = self.prompt_encoder.mask_downscaling(input_masks)
        no_mask_embedding = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
        logger.info("no_mask_embedding.shape: %s", no_mask_embedding.shape)

        # Put the flag on the batch axis. A 1-D tensor would otherwise broadcast
        # against the last (width) axis, which only works for a single element.
        has_input_masks = has_input_masks.reshape(-1, 1, 1, 1)
        mask_embedding = has_input_masks * mask_embedding + (1.0 - has_input_masks) * no_mask_embedding
        logger.info("mask_embedding.shape: %s", mask_embedding.shape)
        return mask_embedding
def export_prompt_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
):
    """Export the prompt encoder of a SAM2 model to an ONNX file.

    The encoder is moved to CPU, run once on dummy inputs (the resulting
    shapes are logged) and then exported with opset 18. The label and point
    axes are exported as dynamic axes.

    Args:
        sam2_model: SAM2 model whose prompt encoder is exported.
        onnx_model_path: Destination path of the ONNX model.
    """
    encoder = SAM2PromptEncoder(sam2_model).cpu()

    # Dummy inputs used to run the module once and to trace the export.
    label_count, point_count = 2, 3
    point_coords = torch.randint(low=0, high=1024, size=(label_count, point_count, 2), dtype=torch.float)
    point_labels = torch.randint(low=0, high=1, size=(label_count, point_count), dtype=torch.int32)
    input_masks = torch.zeros(label_count, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)
    example_inputs = (point_coords, point_labels, input_masks, has_input_masks)

    sparse_embeddings, dense_embeddings, image_pe = encoder(*example_inputs)

    shape_reports = (
        ("point_coords.shape: %s", point_coords),
        ("point_labels.shape: %s", point_labels),
        ("input_masks.shape: %s", input_masks),
        ("has_input_masks.shape: %s", has_input_masks),
        ("sparse_embeddings.shape: %s", sparse_embeddings),
        ("dense_embeddings.shape: %s", dense_embeddings),
        ("image_pe.shape: %s", image_pe),
    )
    for message, tensor in shape_reports:
        logger.info(message, tensor.shape)

    # Only the label and point axes are dynamic; has_input_masks and image_pe
    # keep the shapes of the dummy run.
    label_axis = {0: "num_labels"}
    label_and_point_axes = {0: "num_labels", 1: "num_points"}
    dynamic_axes = {
        "point_coords": dict(label_and_point_axes),
        "point_labels": dict(label_and_point_axes),
        "input_masks": dict(label_axis),
        "sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
        "dense_embeddings": dict(label_axis),
    }

    torch.onnx.export(
        encoder,
        example_inputs,
        onnx_model_path,
        export_params=True,
        opset_version=18,
        do_constant_folding=True,
        input_names=["point_coords", "point_labels", "input_masks", "has_input_masks"],
        output_names=["sparse_embeddings", "dense_embeddings", "image_pe"],
        dynamic_axes=dynamic_axes,
    )

    print("prompt encoder onnx model saved to ", onnx_model_path)
def test_prompt_encoder_onnx(
    sam2_model: SAM2Base,
    onnx_model_path: str,
):
    """Check an exported prompt encoder against the PyTorch implementation.

    Random prompts are run through both ``SAM2PromptEncoder`` and the ONNX
    model (with ONNX Runtime on CPU). Each of the three outputs is compared
    with ``compare_tensors_with_tolerance`` and the verdict is printed.

    Args:
        sam2_model: SAM2 model providing the reference prompt encoder.
        onnx_model_path: Path of the ONNX model to verify.
    """
    sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()

    num_labels = 1
    num_points = 5
    point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
    # torch.randint excludes `high`, so (low=0, high=1) would only ever yield
    # label 0. Sample from -1..3, the label values documented in
    # SAM2PromptEncoder.forward, so every embedding branch is compared.
    point_labels = torch.randint(low=-1, high=4, size=(num_labels, num_points), dtype=torch.int32)
    input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
    has_input_masks = torch.ones(1, dtype=torch.float)

    sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
        point_coords, point_labels, input_masks, has_input_masks
    )

    import onnxruntime  # noqa: PLC0415

    ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])

    input_names = [model_input.name for model_input in ort_session.get_inputs()]
    logger.info("input_names: %s", input_names)

    output_names = [model_output.name for model_output in ort_session.get_outputs()]
    logger.info("output_names: %s", output_names)

    outputs = ort_session.run(
        output_names,
        {
            "point_coords": point_coords.numpy(),
            "point_labels": point_labels.numpy(),
            "input_masks": input_masks.numpy(),
            "has_input_masks": has_input_masks.numpy(),
        },
    )

    for output_name, output in zip(output_names, outputs):
        logger.info("output %s shape: %s", output_name, output.shape)

    # The model is expected to return exactly these three outputs, in this order.
    ort_sparse_embeddings, ort_dense_embeddings, ort_image_pe = outputs
    comparisons = (
        ("sparse_embeddings", sparse_embeddings, ort_sparse_embeddings),
        ("dense_embeddings", dense_embeddings, ort_dense_embeddings),
        ("image_pe", image_pe, ort_image_pe),
    )

    # all() over a generator stops at the first failing comparison.
    verified = all(
        compare_tensors_with_tolerance(
            name,
            expected,
            torch.tensor(actual),
            mismatch_percentage_tolerance=0.2,
        )
        for name, expected, actual in comparisons
    )

    if verified:
        print(f"onnx model has been verified: {onnx_model_path}")
    else:
        print(f"onnx model verification failed: {onnx_model_path}")
|