# -------------------------------------------------------------------------
# Copyright (R) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os

import matplotlib.image as mpimg
import matplotlib.pyplot as plt
import numpy as np
import torch
from matplotlib.patches import Rectangle
from PIL import Image
from sam2.sam2_image_predictor import SAM2ImagePredictor
from sam2_image_onnx_predictor import SAM2ImageOnnxPredictor
from sam2_utils import load_sam2_model

import onnxruntime


def show_mask(mask, ax, random_color=False, borders=True):
    """Overlay a semi-transparent mask (and optionally its outline) on ``ax``.

    Args:
        mask: NumPy array holding one mask; its last two dimensions are taken
            as (height, width). Presumably binary/boolean -- it is cast to
            ``uint8`` before use.
        ax: Matplotlib axes to draw on.
        random_color: When True, use a random RGB color (alpha 0.6) instead of
            the fixed blue.
        borders: When True, draw the external contours of the mask using
            OpenCV (imported lazily so ``cv2`` is only needed in this case).
    """
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])

    h, w = mask.shape[-2:]
    mask = mask.astype(np.uint8)
    # Broadcast the (h, w, 1) mask against the RGBA color to get an (h, w, 4) image.
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)

    if borders:
        import cv2  # noqa: PLC0415

        contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
        # Try to smooth contours
        contours = [cv2.approxPolyDP(contour, epsilon=0.01, closed=True) for contour in contours]
        mask_image = cv2.drawContours(mask_image, contours, -1, (1, 1, 1, 0.5), thickness=2)

    ax.imshow(mask_image)


def show_points(coords, labels, ax, marker_size=375):
    """Scatter prompt points on ``ax``: label 1 in green, label 0 in red.

    Args:
        coords: Array indexed as ``coords[:, 0]`` (x) and ``coords[:, 1]`` (y).
        labels: Array aligned with ``coords``; rows are selected with
            ``labels == 1`` and ``labels == 0``.
        ax: Matplotlib axes to draw on.
        marker_size: Marker area passed to ``ax.scatter`` as ``s``.
    """
    pos_points = coords[labels == 1]
    neg_points = coords[labels == 0]
    ax.scatter(
        pos_points[:, 0], pos_points[:, 1], color="green", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
    )
    ax.scatter(
        neg_points[:, 0], neg_points[:, 1], color="red", marker="*", s=marker_size, edgecolor="white", linewidth=1.25
    )


def show_box(box, ax):
    """Draw a green, unfilled rectangle for ``box`` given as (x0, y0, x1, y1)."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(Rectangle((x0, y0), w, h, edgecolor="green", facecolor=(0, 0, 0, 0), lw=2))


def show_masks(
    image,
    masks,
    scores,
    point_coords=None,
    box_coords=None,
    input_labels=None,
    borders=True,
    output_image_file_prefix=None,
    image_files=None,
):
    """Render each mask over ``image`` in its own figure, optionally saving it.

    Args:
        image: Image passed to ``plt.imshow`` as the background.
        masks: Iterable of masks, paired element-wise with ``scores``.
        scores: Sized iterable of per-mask scores; shown in the title only
            when there is more than one score.
        point_coords: Optional prompt points to draw; requires ``input_labels``.
        box_coords: Optional prompt box (x0, y0, x1, y1) to draw.
        input_labels: Labels for ``point_coords``.
        borders: Forwarded to ``show_mask``.
        output_image_file_prefix: When truthy, figure ``i`` is saved to
            ``f"{output_image_file_prefix}_{i}.png"`` (an existing file is
            removed first).
        image_files: When this is a list, each saved file name is appended to it.
    """
    for i, (mask, score) in enumerate(zip(masks, scores, strict=False)):
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        show_mask(mask, plt.gca(), borders=borders)

        if point_coords is not None:
            assert input_labels is not None
            show_points(point_coords, input_labels, plt.gca())

        if box_coords is not None:
            show_box(box_coords, plt.gca())

        if len(scores) > 1:
            plt.title(f"Mask {i + 1}, Score: {score:.3f}", fontsize=18)

        plt.axis("off")

        if output_image_file_prefix:
            filename = f"{output_image_file_prefix}_{i}.png"
            if os.path.exists(filename):
                os.remove(filename)
            # Save before showing/closing so the written file contains the figure.
            plt.savefig(filename, format="png", bbox_inches="tight", pad_inches=0)
            if isinstance(image_files, list):
                image_files.append(filename)

        plt.show(block=False)
        plt.close()


def get_predictor(
    sam2_dir: str,
    device: str | torch.device,
    dtype: torch.dtype,
    model_type="sam2_hiera_large",
    engine="torch",
    image_encoder_onnx_path: str = "",
    image_decoder_onnx_path: str = "",
    image_decoder_multi_onnx_path: str = "",
    provider: str = "CUDAExecutionProvider",
):
    """Load the SAM2 model and wrap it in a torch or ONNX image predictor.

    Args:
        sam2_dir: Forwarded to ``load_sam2_model``.
        device: Device the model is loaded on (also passed to the ONNX predictor).
        dtype: Passed to the ONNX predictor as ``onnx_dtype``; unused for the
            torch engine.
        model_type: Model variant name forwarded to ``load_sam2_model``.
        engine: ``"torch"`` selects ``SAM2ImagePredictor``; any other value
            selects ``SAM2ImageOnnxPredictor``.
        image_encoder_onnx_path: ONNX image encoder path (ONNX engine only).
        image_decoder_onnx_path: ONNX image decoder path (ONNX engine only).
        image_decoder_multi_onnx_path: ONNX multi-mask decoder path (ONNX engine only).
        provider: Execution provider name passed to the ONNX predictor.

    Returns:
        The constructed predictor.
    """
    sam2_model = load_sam2_model(sam2_dir, model_type, device=device)
    if engine == "torch":
        predictor = SAM2ImagePredictor(sam2_model)
    else:
        predictor = SAM2ImageOnnxPredictor(
            sam2_model,
            image_encoder_onnx_path=image_encoder_onnx_path,
            image_decoder_onnx_path=image_decoder_onnx_path,
            image_decoder_multi_onnx_path=image_decoder_multi_onnx_path,
            provider=provider,
            device=device,
            onnx_dtype=dtype,
        )
    return predictor


def run_demo(
    sam2_dir: str,
    model_type: str = "sam2_hiera_large",
    engine: str = "torch",
    dtype: torch.dtype = torch.float32,
    image_encoder_onnx_path: str = "",
    image_decoder_onnx_path: str = "",
    image_decoder_multi_onnx_path: str = "",
    use_gpu: bool = True,
    enable_batch: bool = False,
):
    """Run the demo prompts on ``truck.jpg`` and save the visualizations.

    The image is read from ``truck.jpg`` in the current working directory.
    Output files are named ``sam2_demo_{engine}_<prompt>_<i>.png``.

    Args:
        sam2_dir: Forwarded to ``get_predictor``.
        model_type: Model variant name.
        engine: ``"torch"`` or anything else for the ONNX predictor.
        dtype: Data type forwarded to ``get_predictor``.
        image_encoder_onnx_path: ONNX image encoder path.
        image_decoder_onnx_path: ONNX image decoder path.
        image_decoder_multi_onnx_path: ONNX multi-mask decoder path.
        use_gpu: When True, require CUDA for both torch and onnxruntime and run
            on ``cuda``; otherwise run on ``cpu`` with ``CPUExecutionProvider``.
        enable_batch: When True, additionally run the batched-boxes prompt.

    Returns:
        List of the image file names written by this run.

    Raises:
        AssertionError: If ``use_gpu`` is True but CUDA is unavailable to torch
            or ``CUDAExecutionProvider`` is unavailable to onnxruntime.
    """
    if use_gpu:
        assert torch.cuda.is_available(), "use_gpu=True but torch.cuda.is_available() is False"
        assert "CUDAExecutionProvider" in onnxruntime.get_available_providers(), (
            "use_gpu=True but CUDAExecutionProvider is not among onnxruntime's available providers"
        )
        provider = "CUDAExecutionProvider"
    else:
        provider = "CPUExecutionProvider"

    device = torch.device("cuda" if use_gpu else "cpu")

    if use_gpu and engine == "torch" and torch.cuda.get_device_properties(0).major >= 8:
        # Turn on tfloat32 for Ampere GPUs.
        torch.backends.cuda.matmul.allow_tf32 = True
        torch.backends.cudnn.allow_tf32 = True

    np.random.seed(3)

    # Use a context manager so the underlying file handle is released promptly.
    with Image.open("truck.jpg") as image_file:
        image = np.array(image_file.convert("RGB"))

    predictor = get_predictor(
        sam2_dir,
        device,
        dtype,
        model_type,
        engine,
        image_encoder_onnx_path,
        image_decoder_onnx_path,
        image_decoder_multi_onnx_path,
        provider=provider,
    )

    predictor.set_image(image)
    prefix = f"sam2_demo_{engine}_"

    # The model returns masks, quality predictions for those masks,
    # and low resolution mask logits that can be passed to the next iteration of prediction.
    # With multimask_output=True (the default setting), SAM 2 outputs 3 masks, where
    # scores gives the model's own estimation of the quality of these masks.
    # For ambiguous prompts such as a single point, it is recommended to use multimask_output=True
    # even if only a single mask is desired; the results are sorted by score below,
    # so the highest-scoring mask comes first.
    input_point = np.array([[500, 375]])
    input_label = np.array([1])
    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        multimask_output=True,
    )

    sorted_ind = np.argsort(scores)[::-1]
    masks = masks[sorted_ind]
    scores = scores[sorted_ind]
    logits = logits[sorted_ind]

    image_files = []
    show_masks(
        image,
        masks,
        scores,
        point_coords=input_point,
        input_labels=input_label,
        borders=True,
        output_image_file_prefix=prefix + "multimask",
        image_files=image_files,
    )

    # Multiple points.
    input_point = np.array([[500, 375], [1125, 625]])
    input_label = np.array([1, 1])
    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )
    show_masks(
        image,
        masks,
        scores,
        point_coords=input_point,
        input_labels=input_label,
        output_image_file_prefix=prefix + "multi_points",
        image_files=image_files,
    )

    # Specify a window and a background point.
    input_point = np.array([[500, 375], [1125, 625]])
    input_label = np.array([1, 0])
    mask_input = logits[np.argmax(scores), :, :]  # Choose the model's best mask
    masks, scores, _ = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        mask_input=mask_input[None, :, :],
        multimask_output=False,
    )
    show_masks(
        image,
        masks,
        scores,
        point_coords=input_point,
        input_labels=input_label,
        output_image_file_prefix=prefix + "background_point",
        image_files=image_files,
    )

    # Take a box as input
    input_box = np.array([425, 600, 700, 875])
    masks, scores, _ = predictor.predict(
        point_coords=None,
        point_labels=None,
        box=input_box[None, :],
        multimask_output=False,
    )
    show_masks(
        image,
        masks,
        scores,
        box_coords=input_box,
        output_image_file_prefix=prefix + "box",
        image_files=image_files,
    )

    # Combining points and boxes
    input_box = np.array([425, 600, 700, 875])
    input_point = np.array([[575, 750]])
    input_label = np.array([0])

    masks, scores, logits = predictor.predict(
        point_coords=input_point,
        point_labels=input_label,
        box=input_box,
        multimask_output=False,
    )
    show_masks(
        image,
        masks,
        scores,
        box_coords=input_box,
        point_coords=input_point,
        input_labels=input_label,
        output_image_file_prefix=prefix + "box_and_point",
        image_files=image_files,
    )

    # TODO: support batched prompt inputs
    if enable_batch:
        input_boxes = np.array(
            [
                [75, 275, 1725, 850],
                [425, 600, 700, 875],
                [1375, 550, 1650, 800],
                [1240, 675, 1400, 750],
            ]
        )
        masks, scores, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )
        plt.figure(figsize=(10, 10))
        plt.imshow(image)
        for mask in masks:
            show_mask(mask.squeeze(0), plt.gca(), random_color=True)
        for box in input_boxes:
            show_box(box, plt.gca())
        plt.axis("off")

        # Save before plt.show(): with an interactive backend the figure is
        # discarded once the window is closed, and a later savefig would write
        # an empty image.
        batch_file = prefix + "batch_prompt.png"
        plt.savefig(batch_file)
        image_files.append(batch_file)
        plt.show()
        plt.close()

    return image_files


def show_all_images(left_images, right_images, suffix=""):
    """Show two sets of demo images in a 2-row grid and save it.

    Images are paired positionally; ``left_images`` fill the top row and
    ``right_images`` the bottom row. The grid is saved to
    ``f"sam2_demo{suffix}.png"`` and then displayed. Nothing is drawn or saved
    when there are no pairs.

    Args:
        left_images: Sequence of image file paths for the top row.
        right_images: Sequence of image file paths for the bottom row.
        suffix: Text inserted into the output file name.
    """
    pairs = list(zip(left_images, right_images, strict=False))
    if not pairs:
        return

    # Show images in two rows since display screen is horizontal in most cases.
    # squeeze=False keeps `axes` two-dimensional even when there is one column,
    # so the axes[row, col] indexing below always works.
    _, axes = plt.subplots(nrows=2, ncols=len(pairs), squeeze=False, figsize=(19.20, 10.80))
    for i, (left_img_path, right_img_path) in enumerate(pairs):
        left_img = mpimg.imread(left_img_path)
        right_img = mpimg.imread(right_img_path)

        axes[0, i].imshow(left_img)
        axes[0, i].set_title(left_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
        axes[0, i].axis("off")
        axes[0, i].set_aspect(left_img.shape[1] / left_img.shape[0])

        axes[1, i].imshow(right_img)
        axes[1, i].set_title(right_img_path.replace("sam2_demo_", "").replace(".png", ""), fontsize=10)
        axes[1, i].axis("off")
        axes[1, i].set_aspect(right_img.shape[1] / right_img.shape[0])

    plt.tight_layout()
    plt.savefig(f"sam2_demo{suffix}.png", format="png", bbox_inches="tight", dpi=1000)
    plt.show()