# Copyright (c) Alibaba, Inc. and its affiliates.
import math
import os

import cv2
import matplotlib
import matplotlib.cm as cm
import matplotlib.pyplot as plt
import numpy as np
import torch.nn.functional as F
from PIL import Image

from modelscope.outputs import OutputKeys
from modelscope.preprocessors.image import load_image
from modelscope.utils import logger as logging

logger = logging.get_logger()

# Skeleton tables used by ``draw_joints``, keyed by the number of joints.
# Each entry is (parent id per joint, left-side joint ids, right-side joint
# ids). A joint whose parent id equals its own index has no bone drawn.
_SKELETONS = {
    17: ([0, 0, 0, 1, 2, 0, 0, 5, 6, 7, 8, 5, 6, 11, 12, 13, 14],
         frozenset([1, 3, 5, 7, 9, 11, 13, 15]),
         frozenset([2, 4, 6, 8, 10, 12, 14, 16])),
    15: ([0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1],
         frozenset([2, 3, 4, 8, 9, 10]),
         frozenset([5, 6, 7, 11, 12, 13])),
}

# Index chains joined by line segments in ``draw_106face_keypoints``.
# Chains that repeat their first index at the end are closed loops.
_FACE_106_POLYLINES = (
    list(range(0, 33)),  # face contour
    list(range(33, 42)) + [33],  # left eyebrow
    list(range(42, 51)) + [42],  # right eyebrow
    list(range(66, 74)) + [66],  # left eye
    list(range(75, 83)) + [75],  # right eye
    list(range(51, 55)),  # nose bridge
    list(range(55, 66)),  # nose contour
    list(range(84, 96)) + [84],  # outer mouth
    list(range(96, 104)) + [96],  # inner mouth
)


class InputPadder:
    """Pads inputs so that their last two dimensions are divisible by 8.

    Args:
        dims: Shape of the inputs; only the last two entries (height, width)
            are read.
        mode (str): With ``'sintel'`` the vertical padding is split between
            top and bottom; with any other value all vertical padding goes to
            the bottom. Horizontal padding is always split left/right.
    """

    def __init__(self, dims, mode='sintel'):
        self.ht, self.wd = dims[-2:]
        # Amount needed to reach the next multiple of 8 (0 if already there).
        pad_ht = (-self.ht) % 8
        pad_wd = (-self.wd) % 8
        # Stored as [left, right, top, bottom], the order F.pad expects.
        if mode == 'sintel':
            self._pad = [
                pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2,
                pad_ht - pad_ht // 2
            ]
        else:
            self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        """Return a list with every input padded using replicate mode."""
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self, x):
        """Crop the padding added by ``pad`` from the last two dimensions."""
        ht, wd = x.shape[-2:]
        left, right, top, bottom = self._pad
        return x[..., top:ht - bottom, left:wd - right]


def numpy_to_cv2img(img_array):
    """Convert an array of shape (h, w) to a JET color-mapped cv2 image.

    Values are min-max normalized to [0, 255] before the color map is applied.

    Args:
        img_array (np.array): input data

    Returns:
        cv2 img
    """
    value_range = img_array.max() - img_array.min() + 1e-5  # avoid div by 0
    img_array = (img_array - img_array.min()) / value_range
    img_array = (img_array * 255).astype(np.uint8)
    img_array = cv2.applyColorMap(img_array, cv2.COLORMAP_JET)
    return img_array


def draw_joints(image, np_kps, score, threshold=0.2):
    """Draw a 17- or 15-joint skeleton (bones and joints) on ``image`` in place.

    A bone between a joint and its parent is drawn only when both of their
    scores reach ``threshold``; a joint marker is drawn when its own score
    reaches ``threshold``.

    Args:
        image: Image to draw on (modified in place).
        np_kps (np.ndarray): Keypoints; ``np_kps[i, 0]`` / ``np_kps[i, 1]``
            are read as the x / y of joint ``i``.
        score: Per-joint confidence, indexable by joint id.
        threshold (float): Minimum score for a joint to be drawn.

    Raises:
        ValueError: If the number of keypoints is neither 17 nor 15.
    """
    num_joints = np_kps.shape[0]
    if num_joints not in _SKELETONS:
        raise ValueError(
            f'Unsupported number of keypoints: {num_joints}, expected one of '
            f'{sorted(_SKELETONS)}')
    parent_ids, left_ids, right_ids = _SKELETONS[num_joints]

    for i, pid in enumerate(parent_ids):
        if i == pid:
            continue
        # Both ends of the bone must be confident enough.
        if score[i] < threshold or score[pid] < threshold:
            continue
        if i in left_ids and pid in left_ids:
            color = (0, 255, 0)
        elif i in right_ids and pid in right_ids:
            color = (255, 0, 0)
        else:
            color = (0, 255, 255)
        cv2.line(image, (int(np_kps[i, 0]), int(np_kps[i, 1])),
                 (int(np_kps[pid, 0]), int(np_kps[pid, 1])), color, 3)

    for i in range(num_joints):
        if score[i] < threshold:
            continue
        cv2.circle(image, (int(np_kps[i, 0]), int(np_kps[i, 1])), 5,
                   (0, 0, 255), -1)


def draw_box(image, box):
    """Draw the rectangle ``box`` = (x1, y1, x2, y2) on ``image`` in place."""
    cv2.rectangle(image, (int(box[0]), int(box[1])),
                  (int(box[2]), int(box[3])), (0, 0, 255), 2)


def realtime_object_detection_bbox_vis(image, bboxes):
    """Draw every (x1, y1, x2, y2) box of ``bboxes`` on ``image``.

    Coordinates are cast to ``int`` so float boxes are accepted as well.

    Returns:
        The same ``image`` object, modified in place.
    """
    for bbox in bboxes:
        cv2.rectangle(image, (int(bbox[0]), int(bbox[1])),
                      (int(bbox[2]), int(bbox[3])), (255, 0, 0), 2)
    return image


def draw_attribute(image, box, labels):
    """Draw ``box`` and a list of attribute labels on ``image`` in place.

    Each label is prefixed with the attribute title at the same position and
    written on its own text line (20 px apart) starting near the top-left
    corner of the box. The text color is chosen at random per call.

    Args:
        image: Image to draw on (modified in place).
        box: (x1, y1, x2, y2) rectangle.
        labels: Attribute values as strings; at most 12 are supported, one
            per title.
    """
    cv2.rectangle(image, (int(box[0]), int(box[1])),
                  (int(box[2]), int(box[3])), (0, 0, 255), 2)
    title = [
        'gender      : ', 'age         : ', 'orient      : ',
        'hat         : ', 'glass       : ', 'hand_bag    : ',
        'shoulder_bag: ', 'back_pack   : ', 'upper_wear  : ',
        'lower_wear  : ', 'upper_color : ', 'lower_color : '
    ]
    clr = (np.random.randint(0, 255), np.random.randint(0, 255),
           np.random.randint(0, 255))
    point = (int(box[0] + 5), int(box[1] + 20))
    for idx, lb in enumerate(labels):
        sz = title[idx] + lb
        cv2.putText(image, f'{sz}', (point[0], point[1] + idx * 20),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, clr, 1)


def draw_keypoints(output, original_image):
    """Draw boxes and skeletons from a keypoint result onto an image file.

    Args:
        output (dict): Must contain ``OutputKeys.KEYPOINTS``,
            ``OutputKeys.SCORES`` and ``OutputKeys.BOXES`` with equal lengths.
        original_image (str): Path of the image, read unchanged by OpenCV.

    Returns:
        The annotated image.
    """
    poses = np.array(output[OutputKeys.KEYPOINTS])
    scores = np.array(output[OutputKeys.SCORES])
    boxes = np.array(output[OutputKeys.BOXES])
    assert len(poses) == len(scores) and len(poses) == len(boxes)
    image = cv2.imread(original_image, -1)
    assert image is not None, f"Can't read img: {original_image}"
    for pose, score, box in zip(poses, scores, boxes):
        draw_box(image, np.array(box))
        draw_joints(image, np.array(pose), np.array(score))
    return image


def draw_pedestrian_attribute(output, original_image):
    """Draw boxes and attribute labels from a result onto an image file.

    Args:
        output (dict): Must contain ``OutputKeys.LABELS`` and
            ``OutputKeys.BOXES`` with equal lengths.
        original_image (str): Path of the image, read unchanged by OpenCV.

    Returns:
        The annotated image.
    """
    labels = np.array(output[OutputKeys.LABELS])
    boxes = np.array(output[OutputKeys.BOXES])
    assert len(labels) == len(boxes)
    image = cv2.imread(original_image, -1)
    assert image is not None, f"Can't read img: {original_image}"
    for box, box_labels in zip(boxes, labels):
        draw_attribute(image, np.array(box), box_labels)
    return image


def draw_106face_keypoints(in_path,
                           keypoints,
                           boxes,
                           scale=4.0,
                           save_path=None):
    """Draw 106-point face landmarks on an up-scaled copy of an image.

    Boxes are drawn at the original resolution, then the image is resized by
    ``scale`` and the landmark polylines, indices and markers are drawn at
    the scaled coordinates.

    Args:
        in_path (str): Path of the input image.
        keypoints: One landmark set per face; ``keypoints[f][k]`` holds the
            (x, y) of landmark ``k`` in original-image coordinates.
        boxes: One (x1, y1, x2, y2) box per face.
        scale (float): Resize factor applied to the output image.
        save_path (str, optional): If given, the result is also written here.

    Returns:
        The annotated, resized image.
    """
    img = cv2.imread(in_path)
    assert img is not None, f"Can't read img: {in_path}"
    for box in boxes:
        draw_box(img, np.array(box))

    image = cv2.resize(img, dsize=None, fx=scale, fy=scale)

    def scaled(pt):
        # Map an original-image point to integer pixels of the resized image.
        return int(pt[0] * scale), int(pt[1] * scale)

    for points in keypoints:
        for polyline in _FACE_106_POLYLINES:
            for cur_index, next_index in zip(polyline[:-1], polyline[1:]):
                cv2.line(
                    image,
                    scaled(points[cur_index]),
                    scaled(points[next_index]), (0, 0, 255),
                    thickness=2)

        for idx, pt in enumerate(points):
            center = scaled(pt)
            cv2.putText(image, str(idx), center, cv2.FONT_HERSHEY_SIMPLEX,
                        0.5, (0, 255, 0), 1)
            cv2.circle(image, center, 2, (0, 255, 0), cv2.FILLED)

    if save_path is not None:
        cv2.imwrite(save_path, image)

    return image


def draw_face_detection_no_lm_result(img_path, detection_result):
    """Draw face boxes and scores (no landmarks) onto an image file.

    Args:
        img_path (str): Path of the input image.
        detection_result (dict): Must contain ``OutputKeys.BOXES`` and
            ``OutputKeys.SCORES``.

    Returns:
        The annotated image.
    """
    bboxes = np.array(detection_result[OutputKeys.BOXES])
    scores = np.array(detection_result[OutputKeys.SCORES])
    img = cv2.imread(img_path)
    assert img is not None, f"Can't read img: {img_path}"
    for i, score in enumerate(scores):
        x1, y1, x2, y2 = (int(v) for v in bboxes[i].astype(np.int32))
        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
        cv2.putText(
            img,
            f'{score:.2f}', (x1, y2),
            1,
            1.0, (0, 255, 0),
            thickness=1,
            lineType=8)
    print(f'Found {len(scores)} faces')
    return img


def draw_facial_expression_result(img_path, facial_expression_result):
    """Write the highest-scoring facial expression label onto an image file.

    Args:
        img_path (str): Path of the input image.
        facial_expression_result (dict): Must contain ``OutputKeys.SCORES``
            and ``OutputKeys.LABELS`` with matching order.

    Returns:
        The annotated image.
    """
    scores = facial_expression_result[OutputKeys.SCORES]
    labels = facial_expression_result[OutputKeys.LABELS]
    label = labels[np.argmax(scores)]
    img = cv2.imread(img_path)
    assert img is not None, f"Can't read img: {img_path}"
    cv2.putText(
        img,
        'facial expression: {}'.format(label), (10, 10),
        1,
        1.0, (0, 255, 0),
        thickness=1,
        lineType=8)
    print('facial expression: {}'.format(label))
    return img


def draw_face_attribute_result(img_path, face_attribute_result):
    """Write the predicted gender and age interval onto an image file.

    Entry 0 of the scores/labels is used for gender and entry 1 for the age
    interval; in both cases the highest-scoring label is chosen.

    Args:
        img_path (str): Path of the input image.
        face_attribute_result (dict): Must contain ``OutputKeys.SCORES`` and
            ``OutputKeys.LABELS``.

    Returns:
        The annotated image.
    """
    scores = face_attribute_result[OutputKeys.SCORES]
    labels = face_attribute_result[OutputKeys.LABELS]
    label_gender = labels[0][np.argmax(scores[0])]
    label_age = labels[1][np.argmax(scores[1])]
    img = cv2.imread(img_path)
    assert img is not None, f"Can't read img: {img_path}"
    cv2.putText(
        img,
        'face gender: {}'.format(label_gender), (10, 10),
        1,
        1.0, (0, 255, 0),
        thickness=1,
        lineType=8)

    cv2.putText(
        img,
        'face age interval: {}'.format(label_age), (10, 40),
        1,
        1.0, (255, 0, 0),
        thickness=1,
        lineType=8)
    logger.info('face gender: {}'.format(label_gender))
    logger.info('face age interval: {}'.format(label_age))
    return img


def draw_face_detection_result(img_path, detection_result):
    """Draw face boxes, landmarks and scores onto an image file.

    Args:
        img_path (str): Path of the input image.
        detection_result (dict): Must contain ``OutputKeys.BOXES``,
            ``OutputKeys.KEYPOINTS`` (flattened x, y pairs per face) and
            ``OutputKeys.SCORES``.

    Returns:
        The annotated image.
    """
    bboxes = np.array(detection_result[OutputKeys.BOXES])
    kpss = np.array(detection_result[OutputKeys.KEYPOINTS])
    scores = np.array(detection_result[OutputKeys.SCORES])
    img = cv2.imread(img_path)
    assert img is not None, f"Can't read img: {img_path}"
    for i, score in enumerate(scores):
        x1, y1, x2, y2 = (int(v) for v in bboxes[i].astype(np.int32))
        kps = kpss[i].reshape(-1, 2).astype(np.int32)
        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
        for kp in kps:
            cv2.circle(img, (int(kp[0]), int(kp[1])), 1, (0, 0, 255), 1)
        cv2.putText(
            img,
            f'{score:.2f}', (x1, y2),
            1,
            1.0, (0, 255, 0),
            thickness=1,
            lineType=8)
    print(f'Found {len(scores)} faces')
    return img


def draw_card_detection_result(img_path, detection_result):
    """Draw card detections and return rectified crops of every card.

    For each detection the four corner keypoints are used to warp the card
    to an upright rectangle (taken from the un-annotated pixels of that
    card's region at the time of processing), then the box, corners and
    score are drawn on the full image.

    Args:
        img_path (str): Path of the input image.
        detection_result (dict): Must contain ``OutputKeys.BOXES``,
            ``OutputKeys.KEYPOINTS`` (four corner points per card) and
            ``OutputKeys.SCORES``.

    Returns:
        list: The annotated full image followed by one warped image per card.
    """

    def warp_img(src_img, kps, ratio):
        # The shorter side of the output is fixed to 500 px; the other side
        # follows from the width / height ratio.
        short_size = 500
        if ratio > 1:
            obj_h = short_size
            obj_w = int(obj_h * ratio)
        else:
            obj_w = short_size
            obj_h = int(obj_w / ratio)
        input_pts = np.float32([kps[0], kps[1], kps[2], kps[3]])
        # Corner 0 -> bottom-left, 1 -> top-left, 2 -> top-right,
        # 3 -> bottom-right of the output.
        output_pts = np.float32([[0, obj_h - 1], [0, 0], [obj_w - 1, 0],
                                 [obj_w - 1, obj_h - 1]])
        M = cv2.getPerspectiveTransform(input_pts, output_pts)
        obj_img = cv2.warpPerspective(src_img, M, (obj_w, obj_h))
        return obj_img

    bboxes = np.array(detection_result[OutputKeys.BOXES])
    kpss = np.array(detection_result[OutputKeys.KEYPOINTS])
    scores = np.array(detection_result[OutputKeys.SCORES])
    ver_col = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 255, 255)]
    img = cv2.imread(img_path)
    # Validate before the image is stored or used anywhere.
    assert img is not None, f"Can't read img: {img_path}"
    img_list = [img]
    for i, score in enumerate(scores):
        x1, y1, x2, y2 = (int(v) for v in bboxes[i].astype(np.int32))
        kps = kpss[i].reshape(-1, 2).astype(np.int32)
        # Squared edge lengths: corner 0-3 is treated as the width edge and
        # corner 0-1 as the height edge.
        _w = (kps[0][0] - kps[3][0])**2 + (kps[0][1] - kps[3][1])**2
        _h = (kps[0][0] - kps[1][0])**2 + (kps[0][1] - kps[1][1])**2
        # 1.59 is the card aspect ratio used for landscape cards; portrait
        # cards use its inverse.
        ratio = 1.59 if _w >= _h else 1 / 1.59
        card_img = warp_img(img, kps, ratio)
        img_list.append(card_img)
        cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 4)
        for k, kp in enumerate(kps):
            cv2.circle(
                img, (int(kp[0]), int(kp[1])),
                1,
                color=ver_col[k],
                thickness=10)
        cv2.putText(
            img,
            f'{score:.2f}', (x1, y2),
            1,
            1.0, (0, 255, 0),
            thickness=1,
            lineType=8)
    return img_list


def created_boxed_image(image_in, box):
    """Load an image, convert it from RGB to BGR and draw ``box`` on it.

    Args:
        image_in: Anything accepted by ``load_image``.
        box: (x1, y1, x2, y2) rectangle.

    Returns:
        The BGR image with the rectangle drawn.
    """
    image = load_image(image_in)
    img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
    cv2.rectangle(img, (int(box[0]), int(box[1])),
                  (int(box[2]), int(box[3])), (0, 255, 0), 3)
    return img


def _create_mjpg_writer(cap, frame, video_save_path):
    """Create an MJPG ``cv2.VideoWriter`` sized like ``frame`` at ``cap``'s FPS."""
    size = (frame.shape[1], frame.shape[0])
    fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
    return cv2.VideoWriter(video_save_path, fourcc,
                           cap.get(cv2.CAP_PROP_FPS), size, True)


def show_video_tracking_result(video_in_path, bboxes, video_save_path):
    """Write a copy of a video with one tracking box drawn per frame.

    Only the first ``len(bboxes)`` frames are read and written. The capture
    and the writer are always released, also when decoding fails.

    Args:
        video_in_path (str): Path of the input video.
        bboxes: One (x1, y1, x2, y2) box per frame.
        video_save_path (str): Path of the MJPG-encoded output video.

    Raises:
        Exception: If a frame cannot be decoded by OpenCV.
    """
    cap = cv2.VideoCapture(video_in_path)
    video_writer = None
    try:
        for box in bboxes:
            success, frame = cap.read()
            if success is False:
                raise Exception(video_in_path,
                                ' can not be correctly decoded by OpenCV.')
            if video_writer is None:
                # The writer needs the frame size, so it is created lazily.
                video_writer = _create_mjpg_writer(cap, frame,
                                                   video_save_path)
            cv2.rectangle(frame, (int(box[0]), int(box[1])),
                          (int(box[2]), int(box[3])), (0, 255, 0), 5)
            video_writer.write(frame)
    finally:
        if video_writer is not None:
            video_writer.release()
        cap.release()


def show_video_object_detection_result(video_in_path, bboxes_list, labels_list,
                                       video_save_path):
    """Write a copy of a video with labelled, tinted detection boxes.

    For every frame each box is outlined, captioned with its label and
    overlaid with a semi-transparent fill in the label's palette color.
    The capture and the writer are always released, also on failure.

    Args:
        video_in_path (str): Path of the input video.
        bboxes_list: Per-frame numpy arrays of (x1, y1, x2, y2) boxes.
        labels_list: Per-frame label strings, aligned with ``bboxes_list``.
        video_save_path (str): Path of the MJPG-encoded output video.

    Raises:
        Exception: If a frame cannot be decoded by OpenCV.
    """
    PALETTE = {
        'person': [128, 0, 0],
        'bicycle': [128, 128, 0],
        'car': [64, 0, 0],
        'motorcycle': [0, 128, 128],
        'bus': [64, 128, 0],
        'truck': [192, 128, 0],
        'traffic light': [64, 0, 128],
        'stop sign': [192, 0, 128],
    }
    # Used for labels missing from PALETTE so one unexpected class does not
    # abort the whole export.
    DEFAULT_COLOR = [0, 0, 192]
    FONT_SCALE = 1e-3  # Adjust for larger font size in all images
    THICKNESS_SCALE = 1e-3  # Adjust for larger thickness in all images
    TEXT_Y_OFFSET_SCALE = 1e-2  # Adjust for larger Y-offset of text and bounding box

    from tqdm import tqdm

    cap = cv2.VideoCapture(video_in_path)
    video_writer = None
    try:
        with tqdm(total=len(bboxes_list)) as pbar:
            pbar.set_description(
                'Writing results to video: {}'.format(video_save_path))
            for i, frame_bboxes in enumerate(bboxes_list):
                bboxes = frame_bboxes.astype(int)
                labels = labels_list[i]
                success, frame = cap.read()
                if success is False:
                    raise Exception(
                        video_in_path,
                        ' can not be correctly decoded by OpenCV.')
                if video_writer is None:
                    # The writer needs the frame size; create it lazily.
                    video_writer = _create_mjpg_writer(
                        cap, frame, video_save_path)

                H, W, _ = frame.shape
                zeros_mask = np.zeros((frame.shape)).astype(np.uint8)
                for bbox, label in zip(bboxes, labels):
                    color = PALETTE.get(label, DEFAULT_COLOR)
                    top_left = (int(bbox[0]), int(bbox[1]))
                    bottom_right = (int(bbox[2]), int(bbox[3]))
                    cv2.rectangle(frame, top_left, bottom_right, color, 1)
                    cv2.putText(
                        frame,
                        label,
                        (top_left[0],
                         top_left[1] - int(TEXT_Y_OFFSET_SCALE * H)),
                        fontFace=cv2.FONT_HERSHEY_TRIPLEX,
                        fontScale=min(H, W) * FONT_SCALE,
                        thickness=math.ceil(min(H, W) * THICKNESS_SCALE),
                        color=color)
                    zeros_mask = cv2.rectangle(
                        zeros_mask,
                        top_left,
                        bottom_right,
                        color=color,
                        thickness=-1)

                # Blend the filled boxes over the frame.
                frame = cv2.addWeighted(frame, 1., zeros_mask, .65, 0)
                video_writer.write(frame)
                pbar.update(1)
    finally:
        if video_writer is not None:
            video_writer.release()
        cap.release()


def panoptic_seg_masks_to_image(masks):
    """Paint panoptic masks into one color image with a unique color each.

    Colors come from the 133-class 'coco' palette of mmdet; a color that is
    already taken (or black, which is the background) is shifted with
    mmdet's ``_get_bias_color`` until it is unique.

    Args:
        masks: Sequence of 2-D masks of identical shape; non-zero marks the
            pixels of the segment.

    Returns:
        np.ndarray: Float image of shape (H, W, 3).
    """
    draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3])
    from mmdet.core.visualization.palette import get_palette
    mask_palette = get_palette('coco', 133)

    from mmdet.core.visualization.image import _get_bias_color
    # Reserve black so no segment is rendered in the background color.
    taken_colors = {(0, 0, 0)}
    for i, mask in enumerate(masks):
        color_mask = mask_palette[i]
        while tuple(color_mask) in taken_colors:
            color_mask = _get_bias_color(color_mask)
        taken_colors.add(tuple(color_mask))

        mask = mask.astype(bool)
        draw_img[mask] = color_mask
    return draw_img


def semantic_seg_masks_to_image(masks):
    """Paint semantic masks into one color image using mmdet's coco palette.

    Mask ``i`` is painted with palette entry ``i``; later masks overwrite
    earlier ones where they overlap.

    Args:
        masks: Sequence of 2-D masks of identical shape; non-zero marks the
            pixels of the class.

    Returns:
        np.ndarray: Float image of shape (H, W, 3).
    """
    from mmdet.core.visualization.palette import get_palette
    mask_palette = get_palette('coco', 133)

    draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3])

    for i, mask in enumerate(masks):
        color_mask = mask_palette[i]
        mask = mask.astype(bool)
        draw_img[mask] = color_mask
    return draw_img


def show_video_summarization_result(video_in_path, result, video_save_path):
    """Write the frames selected by a summarization result to a new video.

    ``result[OutputKeys.OUTPUT]`` is read as one flag per frame; a frame is
    kept when its flag equals 1. The capture and the writer are always
    released, also when decoding fails.

    Args:
        video_in_path (str): Path of the input video.
        result (dict): Must contain ``OutputKeys.OUTPUT``.
        video_save_path (str): Path of the MJPG-encoded output video.

    Raises:
        Exception: If a frame cannot be decoded by OpenCV.
    """
    frame_indexes = result[OutputKeys.OUTPUT]
    cap = cv2.VideoCapture(video_in_path)
    video_writer = None
    try:
        for idx in frame_indexes:
            success, frame = cap.read()
            if success is False:
                raise Exception(video_in_path,
                                ' can not be correctly decoded by OpenCV.')
            if video_writer is None:
                # The writer needs the frame size, so it is created lazily.
                video_writer = _create_mjpg_writer(cap, frame,
                                                   video_save_path)
            if idx == 1:
                video_writer.write(frame)
    finally:
        if video_writer is not None:
            video_writer.release()
        cap.release()


def show_image_object_detection_auto_result(img_path,
                                            detection_result,
                                            save_path=None):
    """Draw detection boxes with score and label onto an image file.

    The score is written at the top-left corner of each box and the label at
    its bottom-left corner.

    Args:
        img_path (str): Path of the input image.
        detection_result (dict): Must contain ``OutputKeys.SCORES``,
            ``OutputKeys.LABELS`` and ``OutputKeys.BOXES``.
        save_path (str, optional): If given, the result is also written here.

    Returns:
        The annotated image.
    """
    scores = detection_result[OutputKeys.SCORES]
    labels = detection_result[OutputKeys.LABELS]
    bboxes = detection_result[OutputKeys.BOXES]
    img = cv2.imread(img_path)
    assert img is not None, f"Can't read img: {img_path}"

    for (score, label, box) in zip(scores, labels, bboxes):
        cv2.rectangle(img, (int(box[0]), int(box[1])),
                      (int(box[2]), int(box[3])), (0, 0, 255), 2)
        cv2.putText(
            img,
            f'{score:.2f}', (int(box[0]), int(box[1])),
            1,
            1.0, (0, 255, 0),
            thickness=1,
            lineType=8)
        cv2.putText(
            img,
            label, (int(box[0]), int(box[3])),
            1,
            1.0, (0, 255, 0),
            thickness=1,
            lineType=8)

    if save_path is not None:
        cv2.imwrite(save_path, img)
    return img


def depth_to_color(depth):
    """Map a depth map to a BGR image using the 'plasma' color map.

    Depth is inverted and normalized by its maximum before the lookup, so
    the largest depth maps to the start of the color map.

    Args:
        depth (np.ndarray): 2-D depth map. NOTE(review): a map whose maximum
            is 0 leads to a division by zero here - confirm callers never
            pass one.

    Returns:
        np.ndarray: uint8 BGR image of shape (H, W, 3).
    """
    colormap = plt.get_cmap('plasma')
    inverted = (depth.max() - depth) / depth.max()
    # The color map returns RGBA in [0, 1]; the alpha channel is dropped.
    depth_color = (colormap(inverted) * 2**8).astype(np.uint8)[:, :, :3]
    depth_color = cv2.cvtColor(depth_color, cv2.COLOR_RGB2BGR)
    return depth_color


def make_colorwheel():
    """
    Generates a color wheel for optical flow visualization as presented in:
        Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
        URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
    Code follows the original C++ source code of Daniel Scharstein.
    Code follows the Matlab source code of Deqing Sun.

    Returns:
        np.ndarray: Color wheel of shape [55, 3] with values in [0, 255]
    """

    # Number of steps of each hue transition (red-yellow, yellow-green, ...).
    RY = 15
    YG = 6
    GC = 4
    CB = 11
    BM = 13
    MR = 6

    ncols = RY + YG + GC + CB + BM + MR
    colorwheel = np.zeros((ncols, 3))
    col = 0

    # RY
    colorwheel[0:RY, 0] = 255
    colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
    col = col + RY
    # YG
    colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
    colorwheel[col:col + YG, 1] = 255
    col = col + YG
    # GC
    colorwheel[col:col + GC, 1] = 255
    colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
    col = col + GC
    # CB
    colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
    colorwheel[col:col + CB, 2] = 255
    col = col + CB
    # BM
    colorwheel[col:col + BM, 2] = 255
    colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
    col = col + BM
    # MR
    colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
    colorwheel[col:col + MR, 0] = 255
    return colorwheel


def flow_uv_to_colors(u, v, convert_to_bgr=False):
    """
    Applies the flow color wheel to (possibly clipped) flow components u and v.
    According to the C++ source code of Daniel Scharstein
    According to the Matlab source code of Deqing Sun

    The flow direction selects the hue (interpolated between neighbouring
    color wheel entries) and the flow magnitude controls the saturation.

    Args:
        u (np.ndarray): Input horizontal flow of shape [H,W]
        v (np.ndarray): Input vertical flow of shape [H,W]
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
    colorwheel = make_colorwheel()  # shape [55x3]
    ncols = colorwheel.shape[0]
    rad = np.sqrt(np.square(u) + np.square(v))
    a = np.arctan2(-v, -u) / np.pi  # angle in [-1, 1]
    fk = (a + 1) / 2 * (ncols - 1)  # fractional color wheel position
    k0 = np.floor(fk).astype(np.int32)
    k1 = k0 + 1
    k1[k1 == ncols] = 0  # wrap around the wheel
    f = fk - k0
    for i in range(colorwheel.shape[1]):
        tmp = colorwheel[:, i]
        col0 = tmp[k0] / 255.0
        col1 = tmp[k1] / 255.0
        col = (1 - f) * col0 + f * col1
        idx = (rad <= 1)
        # Small magnitudes fade towards white.
        col[idx] = 1 - rad[idx] * (1 - col[idx])
        col[~idx] = col[~idx] * 0.75  # out of range
        # Note the 2-i => BGR instead of RGB
        ch_idx = 2 - i if convert_to_bgr else i
        flow_image[:, :, ch_idx] = np.floor(255 * col)
    return flow_image


def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
    """
    Converts a two-channel flow field of shape [H,W,2] to a color image.

    The flow is normalized by its largest magnitude before the color wheel
    is applied.

    Args:
        flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
        clip_flow (float, optional): Clip the flow components to
            [-clip_flow, clip_flow]. Defaults to None (no clipping).
        convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    assert flow_uv.ndim == 3, 'input flow must have three dimensions'
    assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
    if clip_flow is not None:
        # Flow components are signed; clip symmetrically so negative motion
        # is not flattened to zero.
        flow_uv = np.clip(flow_uv, -clip_flow, clip_flow)
    u = flow_uv[:, :, 0]
    v = flow_uv[:, :, 1]
    rad = np.sqrt(np.square(u) + np.square(v))
    rad_max = np.max(rad)
    epsilon = 1e-5  # avoids division by zero for an all-zero flow
    u = u / (rad_max + epsilon)
    v = v / (rad_max + epsilon)
    return flow_uv_to_colors(u, v, convert_to_bgr)


def flow_to_color(flow):
    """Visualize the first flow field of a batched tensor as a color image.

    Args:
        flow: Tensor whose first element has shape [2, H, W]; it is moved to
            the CPU and converted to a numpy array of shape [H, W, 2].

    Returns:
        np.ndarray: Flow visualization image of shape [H,W,3]
    """
    flow = flow[0].permute(1, 2, 0).cpu().numpy()
    flow_color = flow_to_image(flow)
    return flow_color


def show_video_depth_estimation_result(depths, video_save_path):
    """Write a sequence of depth visualizations to an MP4V video at 25 FPS.

    Args:
        depths: Non-empty sequence of images of identical (H, W, 3) shape;
            each is cast to uint8 and converted from RGB to BGR.
        video_save_path (str): Path of the output video.
    """
    height, width = depths[0].shape[:2]
    out = cv2.VideoWriter(video_save_path, cv2.VideoWriter_fourcc(*'MP4V'), 25,
                          (width, height))
    try:
        for img in depths:
            out.write(cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR))
    finally:
        out.release()


def show_image_driving_perception_result(img,
                                         results,
                                         out_file='result.jpg',
                                         if_draw=(1, 1, 1)):
    """Overlay driving perception results (boxes and two masks) on ``img``.

    Args:
        img (np.ndarray): Image to draw on (modified in place). It is written
            with its channels reversed, so it is presumably RGB - confirm
            against the caller.
        results (dict): ``OutputKeys.BOXES`` holds (x1, y1, x2, y2, ...)
            boxes and ``OutputKeys.MASKS`` holds two masks in which the
            value 1 marks the pixels to color.
        out_file (str, optional): If not None, the result is written here.
        if_draw (sequence): Three flags enabling the boxes, the first mask
            and the second mask respectively.

    Returns:
        np.ndarray: The annotated ``img``.
    """
    bboxes = results.get(OutputKeys.BOXES)
    if if_draw[0] and bboxes is not None:
        for x in bboxes:
            c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
            cv2.rectangle(
                img, c1, c2, [255, 255, 0], thickness=2, lineType=cv2.LINE_AA)

    result = results.get(OutputKeys.MASKS)

    color_area = np.zeros((result[0].shape[0], result[0].shape[1], 3),
                          dtype=np.uint8)

    if if_draw[1]:
        color_area[result[0] == 1] = [0, 255, 0]
    if if_draw[2]:
        color_area[result[1] == 1] = [255, 0, 0]
    color_seg = color_area

    # Blend only where some mask color was painted.
    color_mask = np.mean(color_seg, 2)
    msk_idx = color_mask != 0
    img[msk_idx] = img[msk_idx] * 0.5 + color_seg[msk_idx] * 0.5
    if out_file is not None:
        cv2.imwrite(out_file, img[:, :, ::-1])
    return img


def masks_visualization(masks, palette):
    """Convert a stack of index masks to palettized PIL images.

    Args:
        masks (np.ndarray): Masks stacked along the first axis.
        palette: Palette passed to ``PIL.Image.Image.putpalette``.

    Returns:
        list: One PIL image per mask.
    """
    vis_masks = []
    for f in range(masks.shape[0]):
        img_E = Image.fromarray(masks[f])
        img_E.putpalette(palette)
        vis_masks.append(img_E)
    return vis_masks


# This implementation is adopted from LoFTR,
# made public available under the Apache License, Version 2.0,
# at https://github.com/zju3dv/LoFTR


def make_matching_figure(img0,
                         img1,
                         mkpts0,
                         mkpts1,
                         color,
                         kpts0=None,
                         kpts1=None,
                         text=(),
                         dpi=75,
                         path=None):
    """Draw two images side by side with lines connecting matched keypoints.

    Args:
        img0, img1: The two images to show.
        mkpts0, mkpts1 (np.ndarray): Matched keypoints of shape [N, 2] in the
            pixel coordinates of ``img0`` / ``img1``.
        color: One color per match.
        kpts0, kpts1 (np.ndarray, optional): All detected keypoints; drawn
            as small white dots when given (both must be given together).
        text (sequence of str): Lines of text placed in the top-left corner.
        dpi (int): Resolution of the figure.
        path (optional): If given, the figure is saved there and closed.

    Returns:
        The matplotlib figure, or None when ``path`` is given.
    """
    # draw image pair
    assert mkpts0.shape[0] == mkpts1.shape[
        0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
    fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
    axes[0].imshow(img0, cmap='gray')
    axes[1].imshow(img1, cmap='gray')
    for i in range(2):  # clear all frames
        axes[i].get_yaxis().set_ticks([])
        axes[i].get_xaxis().set_ticks([])
        for spine in axes[i].spines.values():
            spine.set_visible(False)
    plt.tight_layout(pad=1)

    if kpts0 is not None:
        assert kpts1 is not None
        axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
        axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)

    # draw matches
    if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
        # The canvas must be drawn once so the data transforms are final.
        fig.canvas.draw()
        transFigure = fig.transFigure.inverted()
        fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
        fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
        fig.lines = [
            matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
                                    (fkpts0[i, 1], fkpts1[i, 1]),
                                    transform=fig.transFigure,
                                    c=color[i],
                                    linewidth=1) for i in range(len(mkpts0))
        ]

        axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
        axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)

    # put txts
    # Dark text on a bright top-left corner, white text otherwise.
    txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
    fig.text(
        0.01,
        0.99,
        '\n'.join(text),
        transform=fig.axes[0].transAxes,
        fontsize=15,
        va='top',
        ha='left',
        color=txt_color)

    # save or return figure
    if path:
        fig.savefig(str(path), bbox_inches='tight', pad_inches=0)
        plt.close(fig)
    else:
        return fig


def match_pair_visualization(img_name0,
                             img_name1,
                             kpts0,
                             kpts1,
                             conf,
                             output_filename='quadtree_match.png',
                             method='QuadTreeAttention'):
    """Save a figure visualizing the keypoint matches between two images.

    Args:
        img_name0, img_name1: Paths of the two images.
        kpts0, kpts1 (np.ndarray): Matched keypoints of shape [N, 2].
        conf: Per-match confidence, mapped to line colors with the jet
            color map.
        output_filename: Path of the saved figure.
        method (str): Name written in the figure caption.
    """
    print(f'Found {len(kpts0)} matches')

    # visualize the matches
    img0 = cv2.imread(str(img_name0))
    img1 = cv2.imread(str(img_name1))

    # Draw
    color = cm.jet(conf)
    text = [
        method,
        'Matches: {}'.format(len(kpts0)),
    ]
    fig = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)

    # save the figure
    fig.savefig(str(output_filename), dpi=300, bbox_inches='tight')
    # Figures created through pyplot stay alive until closed explicitly.
    plt.close(fig)