sam2_demo.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321
  1. # -------------------------------------------------------------------------
  2. # Copyright (R) Microsoft Corporation. All rights reserved.
  3. # Licensed under the MIT License.
  4. # --------------------------------------------------------------------------
  5. import os
  6. import matplotlib.image as mpimg
  7. import matplotlib.pyplot as plt
  8. import numpy as np
  9. import torch
  10. from matplotlib.patches import Rectangle
  11. from PIL import Image
  12. from sam2.sam2_image_predictor import SAM2ImagePredictor
  13. from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor
  14. from sam2_utils import load_sam2_model
  15. import onnxruntime
  16. def show_mask(mask, ax, random_color=False, borders=True):
  17. if random_color:
  18. color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
  19. else:
  20. color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
  21. h, w = mask.shape[-2:]
  22. mask = mask.astype(np.uint8)
  23. mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
  24. if borders:
  25. import cv2 # noqa: PLC0415
  26. contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
  27. # Try to smooth contours
  28. contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
  29. mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)
  30. ax.imshow(mask_image)
  31. def show_points(coords, labels, ax, marker_size=375):
  32. pos_points = coords[labels == 1]
  33. neg_points = coords[labels == 0]
  34. ax.scatter(
  35. pos_points[:, 0], pos_points[:, 1], color="green", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
  36. )
  37. ax.scatter(
  38. neg_points[:, 0], neg_points[:, 1], color="red", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
  39. )
  40. def show_box(box, ax):
  41. x0, y0 = box[0], box[1]
  42. w, h = box[2] - box[0], box[3] - box[1]
  43. ax.add_patch(Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))
  44. def show_masks(
  45. image,
  46. masks,
  47. scores,
  48. point_coords=None,
  49. box_coords=None,
  50. input_labels=None,
  51. borders=True,
  52. output_image_file_prefix=None,
  53. image_files=None,
  54. ):
  55. for i, (mask, score) in enumerate(zip(masks, scores, strict=False)):
  56. plt.figure(figsize=(10, 10))
  57. plt.imshow(image)
  58. show_mask(mask, plt.gca(), borders=borders)
  59. if point_coords is not None:
  60. assert input_labels is not None
  61. show_points(point_coords, input_labels, plt.gca())
  62. if box_coords is not None:
  63. show_box(box_coords, plt.gca())
  64. if len(scores) > 1:
  65. plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)
  66. plt.axis("off")
  67. if output_image_file_prefix:
  68. filename = f"{output_image_file_prefix}_{i}.png"
  69. if os.path.exists(filename):
  70. os.remove(filename)
  71. plt.savefig(filename, format="png", bbox_inches="tight", pad_inches=0)
  72. if isinstance(image_files, list):
  73. image_files.append(filename)
  74. plt.show(block=False)
  75. plt.close()
  76. def get_predictor(
  77. sam2_dir: str,
  78. device: str | torch.device,
  79. dtype: torch.dtype,
  80. model_type="sam2_hiera_large",
  81. engine="torch",
  82. image_encoder_onnx_path: str = "",
  83. image_decoder_onnx_path: str = "",
  84. image_decoder_multi_onnx_path: str = "",
  85. provider: str = "CUDAExecutionProvider",
  86. ):
  87. sam2_model = load_sam2_model(sam2_dir, model_type, device=device)
  88. if engine == "torch":
  89. predictor = SAM2ImagePredictor(sam2_model)
  90. else:
  91. predictor = SAM2ImageOnnxPredictor(
  92. sam2_model,
  93. image_encoder_onnx_path=image_encoder_onnx_path,
  94. image_decoder_onnx_path=image_decoder_onnx_path,
  95. image_decoder_multi_onnx_path=image_decoder_multi_onnx_path,
  96. provider=provider,
  97. device=device,
  98. onnx_dtype=dtype,
  99. )
  100. return predictor
  101. def run_demo(
  102. sam2_dir: str,
  103. model_type: str = "sam2_hiera_large",
  104. engine: str = "torch",
  105. dtype: torch.dtype = torch.float32,
  106. image_encoder_onnx_path: str = "",
  107. image_decoder_onnx_path: str = "",
  108. image_decoder_multi_onnx_path: str = "",
  109. use_gpu: bool = True,
  110. enable_batch: bool = False,
  111. ):
  112. if use_gpu:
  113. assert torch.cuda.is_available()
  114. assert "CUDAExecutionProvider" in onnxruntime.get_available_providers()
  115. provider = "CUDAExecutionProvider"
  116. else:
  117. provider = "CPUExecutionProvider"
  118. device = torch.device("cuda" if use_gpu else "cpu")
  119. if use_gpu and engine == "torch" and torch.cuda.get_device_properties(0).major >= 8:
  120. # Turn on tfloat32 for Ampere GPUs.
  121. torch.backends.cuda.matmul.allow_tf32 = True
  122. torch.backends.cudnn.allow_tf32 = True
  123. np.random.seed(3)
  124. image = Image.open("truck.jpg")
  125. image = np.array(image.convert("RGB"))
  126. predictor = get_predictor(
  127. sam2_dir,
  128. device,
  129. dtype,
  130. model_type,
  131. engine,
  132. image_encoder_onnx_path,
  133. image_decoder_onnx_path,
  134. image_decoder_multi_onnx_path,
  135. provider=provider,
  136. )
  137. predictor.set_image(image)
  138. prefix = f"sam2_demo_{engine}_"
  139. # The model returns masks, quality predictions for those masks,
  140. # and low resolution mask logits that can be passed to the next iteration of prediction.
  141. # With multimask_output=True (the default setting), SAM 2 outputs 3 masks, where
  142. # scores gives the model's own estimation of the quality of these masks.
  143. # For ambiguous prompts such as a single point, it is recommended to use multimask_output=True
  144. # even if only a single mask is desired;
  145. input_point = np.array([[500, 375]])
  146. input_label = np.array([1])
  147. masks, scores, logits = predictor.predict(
  148. point_coords=input_point,
  149. point_labels=input_label,
  150. multimask_output=True,
  151. )
  152. sorted_ind = np.argsort(scores)[::-1]
  153. masks = masks[sorted_ind]
  154. scores = scores[sorted_ind]
  155. logits = logits[sorted_ind]
  156. image_files = []
  157. show_masks(
  158. image,
  159. masks,
  160. scores,
  161. point_coords=input_point,
  162. input_labels=input_label,
  163. borders=True,
  164. output_image_file_prefix=prefix + "multimask",
  165. image_files=image_files,
  166. )
  167. # Multiple points.
  168. input_point = np.array([[500, 375], [1125, 625]])
  169. input_label = np.array([1, 1])
  170. mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
  171. masks, scores, _ = predictor.predict(
  172. point_coords=input_point,
  173. point_labels=input_label,
  174. mask_input=mask_input[None, :, :],
  175. multimask_output=False,
  176. )
  177. show_masks(
  178. image,
  179. masks,
  180. scores,
  181. point_coords=input_point,
  182. input_labels=input_label,
  183. output_image_file_prefix=prefix + "multi_points",
  184. image_files=image_files,
  185. )
  186. # Specify a window and a background point.
  187. input_point = np.array([[500, 375], [1125, 625]])
  188. input_label = np.array([1, 0])
  189. mask_input = logits[np.argmax(scores), :, :] # Choose the model's best mask
  190. masks, scores, _ = predictor.predict(
  191. point_coords=input_point,
  192. point_labels=input_label,
  193. mask_input=mask_input[None, :, :],
  194. multimask_output=False,
  195. )
  196. show_masks(
  197. image,
  198. masks,
  199. scores,
  200. point_coords=input_point,
  201. input_labels=input_label,
  202. output_image_file_prefix=prefix + "background_point",
  203. image_files=image_files,
  204. )
  205. # Take a box as input
  206. input_box = np.array([425, 600, 700, 875])
  207. masks, scores, _ = predictor.predict(
  208. point_coords=None,
  209. point_labels=None,
  210. box=input_box[None, :],
  211. multimask_output=False,
  212. )
  213. show_masks(
  214. image,
  215. masks,
  216. scores,
  217. box_coords=input_box,
  218. output_image_file_prefix=prefix + "box",
  219. image_files=image_files,
  220. )
  221. # Combining points and boxes
  222. input_box = np.array([425, 600, 700, 875])
  223. input_point = np.array([[575, 750]])
  224. input_label = np.array([0])
  225. masks, scores, logits = predictor.predict(
  226. point_coords=input_point,
  227. point_labels=input_label,
  228. box=input_box,
  229. multimask_output=False,
  230. )
  231. show_masks(
  232. image,
  233. masks,
  234. scores,
  235. box_coords=input_box,
  236. point_coords=input_point,
  237. input_labels=input_label,
  238. output_image_file_prefix=prefix + "box_and_point",
  239. image_files=image_files,
  240. )
  241. # TODO: support batched prompt inputs
  242. if enable_batch:
  243. input_boxes = np.array(
  244. [
  245. [75, 275, 1725, 850],
  246. [425, 600, 700, 875],
  247. [1375, 550, 1650, 800],
  248. [1240, 675, 1400, 750],
  249. ]
  250. )
  251. masks, scores, _ = predictor.predict(
  252. point_coords=None,
  253. point_labels=None,
  254. box=input_boxes,
  255. multimask_output=False,
  256. )
  257. plt.figure(figsize=(10, 10))
  258. plt.imshow(image)
  259. for mask in masks:
  260. show_mask(mask.squeeze(0), plt.gca(), random_color=True)
  261. for box in input_boxes:
  262. show_box(box, plt.gca())
  263. plt.axis("off")
  264. plt.show()
  265. plt.savefig(prefix + "batch_prompt.png")
  266. image_files.append(prefix + "batch_prompt.png")
  267. return image_files
  268. def show_all_images(left_images, right_images, suffix=""):
  269. # Show images in two rows since display screen is horizontal in most cases.
  270. fig, axes = plt.subplots(nrows=2, ncols=len(left_images), figsize=(19.20, 10.80))
  271. for i, (left_img_path, right_img_path) in enumerate(zip(left_images, right_images, strict=False)):
  272. left_img = mpimg.imread(left_img_path)
  273. right_img = mpimg.imread(right_img_path)
  274. axes[0, i].imshow(left_img)
  275. axes[0, i].set_title(left_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
  276. axes[0, i].axis("off")
  277. axes[0, i].set_aspect(left_img.shape[1] / left_img.shape[0])
  278. axes[1, i].imshow(right_img)
  279. axes[1, i].set_title(right_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
  280. axes[1, i].axis("off")
  281. axes[1, i].set_aspect(right_img.shape[1] / right_img.shape[0])
  282. plt.tight_layout()
  283. plt.savefig(f"sam2_demo{suffix}.png", format="png", bbox_inches="tight", dpi=1000)
  284. plt.show()