prompt_encoder.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
  1. # -------------------------------------------------------------------------
  2. # Copyright (R) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import logging
  6. import torch
  7. from sam2.modeling.sam2_base import SAM2Base
  8. from sam2_utils import compare_tensors_with_tolerance
  9. from torch import nn
  10. logger = logging.getLogger(__name__)
  11. class SAM2PromptEncoder(nn.Module):
  12. def __init__(self, sam_model: SAM2Base):
  13. super().__init__()
  14. self.prompt_encoder = sam_model.sam_prompt_encoder
  15. self.model = sam_model
  16. @torch.no_grad()
  17. def forward(
  18. self,
  19. point_coords: torch.Tensor,
  20. point_labels: torch.Tensor,
  21. input_masks: torch.Tensor,
  22. has_input_masks: torch.Tensor,
  23. ):
  24. """Encode prompts.
  25. Args:
  26. point_coords (torch.Tensor): [L, P, 2] shape and float32 dtype and contains the absolute pixel
  27. coordinate in (x, y) format of the P input points in image of size 1024x1024.
  28. point_labels (torch.Tensor): shape [L, P] and int32 dtype, where 1 means
  29. positive (foreground), 0 means negative (background), -1 means padding,
  30. 2 (box left upper corner), 3 (box right bottom corner).
  31. input_masks (torch.Tensor): [L, 1, H/4, W/4]. Low resolution mask input to the model.
  32. Typically coming from a previous iteration.
  33. has_input_masks (torch.Tensor): [L]. 1.0 if input_masks is used, 0.0 otherwise.
  34. Returns:
  35. sparse_embeddings (torch.Tensor): [L, P+1, 256], embedding for points and boxes.
  36. dense_embeddings (torch.Tensor): [L, 256, 64, 64]. embedding for input masks.
  37. image_pe (torch.Tensor, optional): [1, 256, 64, 64]. image positional encoding.
  38. """
  39. sparse_embeddings = self._embed_points(point_coords, point_labels)
  40. dense_embeddings = self._embed_masks(input_masks, has_input_masks)
  41. image_pe = self.prompt_encoder.get_dense_pe()
  42. return sparse_embeddings, dense_embeddings, image_pe
  43. def _embed_points(self, point_coords: torch.Tensor, point_labels: torch.Tensor) -> torch.Tensor:
  44. point_coords = point_coords + 0.5
  45. padding_point = torch.zeros((point_coords.shape[0], 1, 2), device=point_coords.device)
  46. padding_label = -torch.ones((point_labels.shape[0], 1), device=point_labels.device)
  47. point_coords = torch.cat([point_coords, padding_point], dim=1)
  48. point_labels = torch.cat([point_labels, padding_label], dim=1)
  49. # Note that the input coordinates are based on image size 1024x1024. Here we normalize it to [0.0, 1.0).
  50. point_coords[:, :, 0] = point_coords[:, :, 0] / self.model.image_size
  51. point_coords[:, :, 1] = point_coords[:, :, 1] / self.model.image_size
  52. point_embedding = self.prompt_encoder.pe_layer._pe_encoding(point_coords)
  53. point_labels = point_labels.unsqueeze(-1).expand_as(point_embedding)
  54. point_embedding = point_embedding * (point_labels != -1)
  55. point_embedding = point_embedding + self.prompt_encoder.not_a_point_embed.weight * (point_labels == -1)
  56. for i in range(self.prompt_encoder.num_point_embeddings):
  57. point_embedding = point_embedding + self.prompt_encoder.point_embeddings[i].weight * (point_labels == i)
  58. return point_embedding
  59. def _embed_masks(self, input_masks: torch.Tensor, has_input_masks: torch.Tensor) -> torch.Tensor:
  60. mask_embedding = self.prompt_encoder.mask_downscaling(input_masks)
  61. no_mask_embedding = self.prompt_encoder.no_mask_embed.weight.reshape(1, -1, 1, 1)
  62. logger.info("no_mask_embedding.shape: %s", no_mask_embedding.shape)
  63. mask_embedding = has_input_masks * mask_embedding + (1.0 - has_input_masks) * no_mask_embedding
  64. logger.info("mask_embedding.shape: %s", mask_embedding.shape)
  65. return mask_embedding
  66. def export_prompt_encoder_onnx(
  67. sam2_model: SAM2Base,
  68. onnx_model_path: str,
  69. ):
  70. sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
  71. num_labels = 2
  72. num_points = 3
  73. point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
  74. point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
  75. input_masks = torch.zeros(num_labels, 1, 256, 256, dtype=torch.float)
  76. has_input_masks = torch.ones(1, dtype=torch.float)
  77. sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
  78. point_coords, point_labels, input_masks, has_input_masks
  79. )
  80. logger.info("point_coords.shape: %s", point_coords.shape)
  81. logger.info("point_labels.shape: %s", point_labels.shape)
  82. logger.info("input_masks.shape: %s", input_masks.shape)
  83. logger.info("has_input_masks.shape: %s", has_input_masks.shape)
  84. logger.info("sparse_embeddings.shape: %s", sparse_embeddings.shape)
  85. logger.info("dense_embeddings.shape: %s", dense_embeddings.shape)
  86. logger.info("image_pe.shape: %s", image_pe.shape)
  87. torch.onnx.export(
  88. sam2_prompt_encoder,
  89. (point_coords, point_labels, input_masks, has_input_masks),
  90. onnx_model_path,
  91. export_params=True,
  92. opset_version=18,
  93. do_constant_folding=True,
  94. input_names=["point_coords", "point_labels", "input_masks", "has_input_masks"],
  95. output_names=["sparse_embeddings", "dense_embeddings", "image_pe"],
  96. dynamic_axes={
  97. "point_coords": {0: "num_labels", 1: "num_points"},
  98. "point_labels": {0: "num_labels", 1: "num_points"},
  99. "input_masks": {0: "num_labels"},
  100. "sparse_embeddings": {0: "num_labels", 1: "num_points+1"},
  101. "dense_embeddings": {0: "num_labels"},
  102. },
  103. )
  104. print("prompt encoder onnx model saved to ", onnx_model_path)
  105. def test_prompt_encoder_onnx(
  106. sam2_model: SAM2Base,
  107. onnx_model_path: str,
  108. ):
  109. sam2_prompt_encoder = SAM2PromptEncoder(sam2_model).cpu()
  110. num_labels = 1
  111. num_points = 5
  112. point_coords = torch.randint(low=0, high=1024, size=(num_labels, num_points, 2), dtype=torch.float)
  113. point_labels = torch.randint(low=0, high=1, size=(num_labels, num_points), dtype=torch.int32)
  114. input_masks = torch.rand(num_labels, 1, 256, 256, dtype=torch.float)
  115. has_input_masks = torch.ones(1, dtype=torch.float)
  116. sparse_embeddings, dense_embeddings, image_pe = sam2_prompt_encoder(
  117. point_coords, point_labels, input_masks, has_input_masks
  118. )
  119. import onnxruntime # noqa: PLC0415
  120. ort_session = onnxruntime.InferenceSession(onnx_model_path, providers=["CPUExecutionProvider"])
  121. model_inputs = ort_session.get_inputs()
  122. input_names = [model_inputs[i].name for i in range(len(model_inputs))]
  123. logger.info("input_names: %s", input_names)
  124. model_outputs = ort_session.get_outputs()
  125. output_names = [model_outputs[i].name for i in range(len(model_outputs))]
  126. logger.info("output_names: %s", output_names)
  127. outputs = ort_session.run(
  128. output_names,
  129. {
  130. "point_coords": point_coords.numpy(),
  131. "point_labels": point_labels.numpy(),
  132. "input_masks": input_masks.numpy(),
  133. "has_input_masks": has_input_masks.numpy(),
  134. },
  135. )
  136. for i, output_name in enumerate(output_names):
  137. logger.info("output %s shape: %s", output_name, outputs[i].shape)
  138. ort_sparse_embeddings, ort_dense_embeddings, ort_image_pe = outputs
  139. if (
  140. compare_tensors_with_tolerance(
  141. "sparse_embeddings",
  142. sparse_embeddings,
  143. torch.tensor(ort_sparse_embeddings),
  144. mismatch_percentage_tolerance=0.2,
  145. )
  146. and compare_tensors_with_tolerance(
  147. "dense_embeddings", dense_embeddings, torch.tensor(ort_dense_embeddings), mismatch_percentage_tolerance=0.2
  148. )
  149. and compare_tensors_with_tolerance(
  150. "image_pe", image_pe, torch.tensor(ort_image_pe), mismatch_percentage_tolerance=0.2
  151. )
  152. ):
  153. print(f"onnx model has been verified: {onnx_model_path}")
  154. else:
  155. print(f"onnx model verification failed: {onnx_model_path}")