image_utils.py 27 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import os
  3. import cv2
  4. import matplotlib
  5. import matplotlib.cm as cm
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. import torch.nn.functional as F
  9. from PIL import Image
  10. from modelscope.outputs import OutputKeys
  11. from modelscope.preprocessors.image import load_image
  12. from modelscope.utils import logger as logging
  13. logger = logging.get_logger()
  14. class InputPadder:
  15. """ Pads images such that dimensions are divisible by 8 """
  16. def __init__(self, dims, mode='sintel'):
  17. self.ht, self.wd = dims[-2:]
  18. pad_ht = (((self.ht // 8) + 1) * 8 - self.ht) % 8
  19. pad_wd = (((self.wd // 8) + 1) * 8 - self.wd) % 8
  20. if mode == 'sintel':
  21. self._pad = [
  22. pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2,
  23. pad_ht - pad_ht // 2
  24. ]
  25. else:
  26. self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]
  27. def pad(self, *inputs):
  28. return [F.pad(x, self._pad, mode='replicate') for x in inputs]
  29. def unpad(self, x):
  30. ht, wd = x.shape[-2:]
  31. c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
  32. return x[..., c[0]:c[1], c[2]:c[3]]
  33. def numpy_to_cv2img(img_array):
  34. """to convert a np.array with shape(h, w) to cv2 img
  35. Args:
  36. img_array (np.array): input data
  37. Returns:
  38. cv2 img
  39. """
  40. img_array = (img_array - img_array.min()) / (
  41. img_array.max() - img_array.min() + 1e-5)
  42. img_array = (img_array * 255).astype(np.uint8)
  43. img_array = cv2.applyColorMap(img_array, cv2.COLORMAP_JET)
  44. return img_array
  45. def draw_joints(image, np_kps, score, threshold=0.2):
  46. lst_parent_ids_17 = [0, 0, 0, 1, 2, 0, 0, 5, 6, 7, 8, 5, 6, 11, 12, 13, 14]
  47. lst_left_ids_17 = [1, 3, 5, 7, 9, 11, 13, 15]
  48. lst_right_ids_17 = [2, 4, 6, 8, 10, 12, 14, 16]
  49. lst_parent_ids_15 = [0, 0, 1, 2, 3, 1, 5, 6, 14, 8, 9, 14, 11, 12, 1]
  50. lst_left_ids_15 = [2, 3, 4, 8, 9, 10]
  51. lst_right_ids_15 = [5, 6, 7, 11, 12, 13]
  52. if np_kps.shape[0] == 17:
  53. lst_parent_ids = lst_parent_ids_17
  54. lst_left_ids = lst_left_ids_17
  55. lst_right_ids = lst_right_ids_17
  56. elif np_kps.shape[0] == 15:
  57. lst_parent_ids = lst_parent_ids_15
  58. lst_left_ids = lst_left_ids_15
  59. lst_right_ids = lst_right_ids_15
  60. for i in range(len(lst_parent_ids)):
  61. pid = lst_parent_ids[i]
  62. if i == pid:
  63. continue
  64. if (score[i] < threshold or score[1] < threshold):
  65. continue
  66. if i in lst_left_ids and pid in lst_left_ids:
  67. color = (0, 255, 0)
  68. elif i in lst_right_ids and pid in lst_right_ids:
  69. color = (255, 0, 0)
  70. else:
  71. color = (0, 255, 255)
  72. cv2.line(image, (int(np_kps[i, 0]), int(np_kps[i, 1])),
  73. (int(np_kps[pid][0]), int(np_kps[pid, 1])), color, 3)
  74. for i in range(np_kps.shape[0]):
  75. if score[i] < threshold:
  76. continue
  77. cv2.circle(image, (int(np_kps[i, 0]), int(np_kps[i, 1])), 5,
  78. (0, 0, 255), -1)
  79. def draw_box(image, box):
  80. cv2.rectangle(image, (int(box[0]), int(box[1])),
  81. (int(box[2]), int(box[3])), (0, 0, 255), 2)
  82. def realtime_object_detection_bbox_vis(image, bboxes):
  83. for bbox in bboxes:
  84. cv2.rectangle(image, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
  85. (255, 0, 0), 2)
  86. return image
  87. def draw_attribute(image, box, labels):
  88. cv2.rectangle(image, (int(box[0]), int(box[1])),
  89. (int(box[2]), int(box[3])), (0, 0, 255), 2)
  90. title = [
  91. 'gender : ', 'age : ', 'orient : ', 'hat : ',
  92. 'glass : ', 'hand_bag : ', 'shoulder_bag: ', 'back_pack : ',
  93. 'upper_wear : ', 'lower_wear : ', 'upper_color : ', 'lower_color : '
  94. ]
  95. clr = (np.random.randint(0, 255), np.random.randint(0, 255),
  96. np.random.randint(0, 255))
  97. point = (int(box[0] + 5), int(box[1] + 20))
  98. for idx, lb in enumerate(labels):
  99. sz = title[idx] + lb
  100. cv2.putText(image, f'{sz}', (point[0], point[1] + idx * 20),
  101. cv2.FONT_HERSHEY_SIMPLEX, 0.5, clr, 1)
  102. def draw_keypoints(output, original_image):
  103. poses = np.array(output[OutputKeys.KEYPOINTS])
  104. scores = np.array(output[OutputKeys.SCORES])
  105. boxes = np.array(output[OutputKeys.BOXES])
  106. assert len(poses) == len(scores) and len(poses) == len(boxes)
  107. image = cv2.imread(original_image, -1)
  108. for i in range(len(poses)):
  109. draw_box(image, np.array(boxes[i]))
  110. draw_joints(image, np.array(poses[i]), np.array(scores[i]))
  111. return image
  112. def draw_pedestrian_attribute(output, original_image):
  113. labels = np.array(output[OutputKeys.LABELS])
  114. boxes = np.array(output[OutputKeys.BOXES])
  115. assert len(labels) == len(boxes)
  116. image = cv2.imread(original_image, -1)
  117. for i in range(len(boxes)):
  118. draw_attribute(image, np.array(boxes[i]), labels[i])
  119. return image
  120. def draw_106face_keypoints(in_path,
  121. keypoints,
  122. boxes,
  123. scale=4.0,
  124. save_path=None):
  125. face_contour_point_index = [
  126. 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19,
  127. 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32
  128. ]
  129. left_eye_brow_point_index = [33, 34, 35, 36, 37, 38, 39, 40, 41, 33]
  130. right_eye_brow_point_index = [42, 43, 44, 45, 46, 47, 48, 49, 50, 42]
  131. left_eye_point_index = [66, 67, 68, 69, 70, 71, 72, 73, 66]
  132. right_eye_point_index = [75, 76, 77, 78, 79, 80, 81, 82, 75]
  133. nose_bridge_point_index = [51, 52, 53, 54]
  134. nose_contour_point_index = [55, 56, 57, 58, 59, 60, 61, 62, 63, 64, 65]
  135. mouth_outer_point_index = [
  136. 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, 84
  137. ]
  138. mouth_inter_point_index = [96, 97, 98, 99, 100, 101, 102, 103, 96]
  139. img = cv2.imread(in_path)
  140. for i in range(len(boxes)):
  141. draw_box(img, np.array(boxes[i]))
  142. image = cv2.resize(img, dsize=None, fx=scale, fy=scale)
  143. def draw_line(point_index, image, point):
  144. for i in range(len(point_index) - 1):
  145. cur_index = point_index[i]
  146. next_index = point_index[i + 1]
  147. cur_pt = (int(point[cur_index][0] * scale),
  148. int(point[cur_index][1] * scale))
  149. next_pt = (int(point[next_index][0] * scale),
  150. int(point[next_index][1] * scale))
  151. cv2.line(image, cur_pt, next_pt, (0, 0, 255), thickness=2)
  152. for i in range(len(keypoints)):
  153. points = keypoints[i]
  154. draw_line(face_contour_point_index, image, points)
  155. draw_line(left_eye_brow_point_index, image, points)
  156. draw_line(right_eye_brow_point_index, image, points)
  157. draw_line(left_eye_point_index, image, points)
  158. draw_line(right_eye_point_index, image, points)
  159. draw_line(nose_bridge_point_index, image, points)
  160. draw_line(nose_contour_point_index, image, points)
  161. draw_line(mouth_outer_point_index, image, points)
  162. draw_line(mouth_inter_point_index, image, points)
  163. size = len(points)
  164. for i in range(size):
  165. x = int(points[i][0])
  166. y = int(points[i][1])
  167. cv2.putText(image, str(i), (int(x * scale), int(y * scale)),
  168. cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 1)
  169. cv2.circle(image, (int(x * scale), int(y * scale)), 2, (0, 255, 0),
  170. cv2.FILLED)
  171. if save_path is not None:
  172. cv2.imwrite(save_path, image)
  173. return image
  174. def draw_face_detection_no_lm_result(img_path, detection_result):
  175. bboxes = np.array(detection_result[OutputKeys.BOXES])
  176. scores = np.array(detection_result[OutputKeys.SCORES])
  177. img = cv2.imread(img_path)
  178. assert img is not None, f"Can't read img: {img_path}"
  179. for i in range(len(scores)):
  180. bbox = bboxes[i].astype(np.int32)
  181. x1, y1, x2, y2 = bbox
  182. score = scores[i]
  183. cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
  184. cv2.putText(
  185. img,
  186. f'{score:.2f}', (x1, y2),
  187. 1,
  188. 1.0, (0, 255, 0),
  189. thickness=1,
  190. lineType=8)
  191. print(f'Found {len(scores)} faces')
  192. return img
  193. def draw_facial_expression_result(img_path, facial_expression_result):
  194. scores = facial_expression_result[OutputKeys.SCORES]
  195. labels = facial_expression_result[OutputKeys.LABELS]
  196. label = labels[np.argmax(scores)]
  197. img = cv2.imread(img_path)
  198. assert img is not None, f"Can't read img: {img_path}"
  199. cv2.putText(
  200. img,
  201. 'facial expression: {}'.format(label), (10, 10),
  202. 1,
  203. 1.0, (0, 255, 0),
  204. thickness=1,
  205. lineType=8)
  206. print('facial expression: {}'.format(label))
  207. return img
  208. def draw_face_attribute_result(img_path, face_attribute_result):
  209. scores = face_attribute_result[OutputKeys.SCORES]
  210. labels = face_attribute_result[OutputKeys.LABELS]
  211. label_gender = labels[0][np.argmax(scores[0])]
  212. label_age = labels[1][np.argmax(scores[1])]
  213. img = cv2.imread(img_path)
  214. assert img is not None, f"Can't read img: {img_path}"
  215. cv2.putText(
  216. img,
  217. 'face gender: {}'.format(label_gender), (10, 10),
  218. 1,
  219. 1.0, (0, 255, 0),
  220. thickness=1,
  221. lineType=8)
  222. cv2.putText(
  223. img,
  224. 'face age interval: {}'.format(label_age), (10, 40),
  225. 1,
  226. 1.0, (255, 0, 0),
  227. thickness=1,
  228. lineType=8)
  229. logger.info('face gender: {}'.format(label_gender))
  230. logger.info('face age interval: {}'.format(label_age))
  231. return img
  232. def draw_face_detection_result(img_path, detection_result):
  233. bboxes = np.array(detection_result[OutputKeys.BOXES])
  234. kpss = np.array(detection_result[OutputKeys.KEYPOINTS])
  235. scores = np.array(detection_result[OutputKeys.SCORES])
  236. img = cv2.imread(img_path)
  237. assert img is not None, f"Can't read img: {img_path}"
  238. for i in range(len(scores)):
  239. bbox = bboxes[i].astype(np.int32)
  240. kps = kpss[i].reshape(-1, 2).astype(np.int32)
  241. score = scores[i]
  242. x1, y1, x2, y2 = bbox
  243. cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 2)
  244. for kp in kps:
  245. cv2.circle(img, tuple(kp), 1, (0, 0, 255), 1)
  246. cv2.putText(
  247. img,
  248. f'{score:.2f}', (x1, y2),
  249. 1,
  250. 1.0, (0, 255, 0),
  251. thickness=1,
  252. lineType=8)
  253. print(f'Found {len(scores)} faces')
  254. return img
  255. def draw_card_detection_result(img_path, detection_result):
  256. def warp_img(src_img, kps, ratio):
  257. short_size = 500
  258. if ratio > 1:
  259. obj_h = short_size
  260. obj_w = int(obj_h * ratio)
  261. else:
  262. obj_w = short_size
  263. obj_h = int(obj_w / ratio)
  264. input_pts = np.float32([kps[0], kps[1], kps[2], kps[3]])
  265. output_pts = np.float32([[0, obj_h - 1], [0, 0], [obj_w - 1, 0],
  266. [obj_w - 1, obj_h - 1]])
  267. M = cv2.getPerspectiveTransform(input_pts, output_pts)
  268. obj_img = cv2.warpPerspective(src_img, M, (obj_w, obj_h))
  269. return obj_img
  270. bboxes = np.array(detection_result[OutputKeys.BOXES])
  271. kpss = np.array(detection_result[OutputKeys.KEYPOINTS])
  272. scores = np.array(detection_result[OutputKeys.SCORES])
  273. img_list = []
  274. ver_col = [(255, 0, 0), (0, 255, 0), (0, 0, 255), (0, 255, 255)]
  275. img = cv2.imread(img_path)
  276. img_list += [img]
  277. assert img is not None, f"Can't read img: {img_path}"
  278. for i in range(len(scores)):
  279. bbox = bboxes[i].astype(np.int32)
  280. kps = kpss[i].reshape(-1, 2).astype(np.int32)
  281. _w = (kps[0][0] - kps[3][0])**2 + (kps[0][1] - kps[3][1])**2
  282. _h = (kps[0][0] - kps[1][0])**2 + (kps[0][1] - kps[1][1])**2
  283. ratio = 1.59 if _w >= _h else 1 / 1.59
  284. card_img = warp_img(img, kps, ratio)
  285. img_list += [card_img]
  286. score = scores[i]
  287. x1, y1, x2, y2 = bbox
  288. cv2.rectangle(img, (x1, y1), (x2, y2), (255, 0, 0), 4)
  289. for k, kp in enumerate(kps):
  290. cv2.circle(img, tuple(kp), 1, color=ver_col[k], thickness=10)
  291. cv2.putText(
  292. img,
  293. f'{score:.2f}', (x1, y2),
  294. 1,
  295. 1.0, (0, 255, 0),
  296. thickness=1,
  297. lineType=8)
  298. return img_list
  299. def created_boxed_image(image_in, box):
  300. image = load_image(image_in)
  301. img = cv2.cvtColor(np.asarray(image), cv2.COLOR_RGB2BGR)
  302. cv2.rectangle(img, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])),
  303. (0, 255, 0), 3)
  304. return img
  305. def show_video_tracking_result(video_in_path, bboxes, video_save_path):
  306. cap = cv2.VideoCapture(video_in_path)
  307. for i in range(len(bboxes)):
  308. box = bboxes[i]
  309. success, frame = cap.read()
  310. if success is False:
  311. raise Exception(video_in_path,
  312. ' can not be correctly decoded by OpenCV.')
  313. if i == 0:
  314. size = (frame.shape[1], frame.shape[0])
  315. fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
  316. video_writer = cv2.VideoWriter(video_save_path, fourcc,
  317. cap.get(cv2.CAP_PROP_FPS), size,
  318. True)
  319. cv2.rectangle(frame, (box[0], box[1]), (box[2], box[3]), (0, 255, 0),
  320. 5)
  321. video_writer.write(frame)
  322. video_writer.release
  323. cap.release()
  324. def show_video_object_detection_result(video_in_path, bboxes_list, labels_list,
  325. video_save_path):
  326. PALETTE = {
  327. 'person': [128, 0, 0],
  328. 'bicycle': [128, 128, 0],
  329. 'car': [64, 0, 0],
  330. 'motorcycle': [0, 128, 128],
  331. 'bus': [64, 128, 0],
  332. 'truck': [192, 128, 0],
  333. 'traffic light': [64, 0, 128],
  334. 'stop sign': [192, 0, 128],
  335. }
  336. from tqdm import tqdm
  337. import math
  338. cap = cv2.VideoCapture(video_in_path)
  339. with tqdm(total=len(bboxes_list)) as pbar:
  340. pbar.set_description(
  341. 'Writing results to video: {}'.format(video_save_path))
  342. for i in range(len(bboxes_list)):
  343. bboxes = bboxes_list[i].astype(int)
  344. labels = labels_list[i]
  345. success, frame = cap.read()
  346. if success is False:
  347. raise Exception(video_in_path,
  348. ' can not be correctly decoded by OpenCV.')
  349. if i == 0:
  350. size = (frame.shape[1], frame.shape[0])
  351. fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
  352. video_writer = cv2.VideoWriter(video_save_path, fourcc,
  353. cap.get(cv2.CAP_PROP_FPS), size,
  354. True)
  355. FONT_SCALE = 1e-3 # Adjust for larger font size in all images
  356. THICKNESS_SCALE = 1e-3 # Adjust for larger thickness in all images
  357. TEXT_Y_OFFSET_SCALE = 1e-2 # Adjust for larger Y-offset of text and bounding box
  358. H, W, _ = frame.shape
  359. zeros_mask = np.zeros((frame.shape)).astype(np.uint8)
  360. for bbox, l in zip(bboxes, labels):
  361. cv2.rectangle(frame, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
  362. PALETTE[l], 1)
  363. cv2.putText(
  364. frame,
  365. l, (bbox[0], bbox[1] - int(TEXT_Y_OFFSET_SCALE * H)),
  366. fontFace=cv2.FONT_HERSHEY_TRIPLEX,
  367. fontScale=min(H, W) * FONT_SCALE,
  368. thickness=math.ceil(min(H, W) * THICKNESS_SCALE),
  369. color=PALETTE[l])
  370. zeros_mask = cv2.rectangle(
  371. zeros_mask, (bbox[0], bbox[1]), (bbox[2], bbox[3]),
  372. color=PALETTE[l],
  373. thickness=-1)
  374. frame = cv2.addWeighted(frame, 1., zeros_mask, .65, 0)
  375. video_writer.write(frame)
  376. pbar.update(1)
  377. video_writer.release
  378. cap.release()
  379. def panoptic_seg_masks_to_image(masks):
  380. draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3])
  381. from mmdet.core.visualization.palette import get_palette
  382. mask_palette = get_palette('coco', 133)
  383. from mmdet.core.visualization.image import _get_bias_color
  384. taken_colors = set([0, 0, 0])
  385. for i, mask in enumerate(masks):
  386. color_mask = mask_palette[i]
  387. while tuple(color_mask) in taken_colors:
  388. color_mask = _get_bias_color(color_mask)
  389. taken_colors.add(tuple(color_mask))
  390. mask = mask.astype(bool)
  391. draw_img[mask] = color_mask
  392. return draw_img
  393. def semantic_seg_masks_to_image(masks):
  394. from mmdet.core.visualization.palette import get_palette
  395. mask_palette = get_palette('coco', 133)
  396. draw_img = np.zeros([masks[0].shape[0], masks[0].shape[1], 3])
  397. for i, mask in enumerate(masks):
  398. color_mask = mask_palette[i]
  399. mask = mask.astype(bool)
  400. draw_img[mask] = color_mask
  401. return draw_img
  402. def show_video_summarization_result(video_in_path, result, video_save_path):
  403. frame_indexes = result[OutputKeys.OUTPUT]
  404. cap = cv2.VideoCapture(video_in_path)
  405. for i in range(len(frame_indexes)):
  406. idx = frame_indexes[i]
  407. success, frame = cap.read()
  408. if success is False:
  409. raise Exception(video_in_path,
  410. ' can not be correctly decoded by OpenCV.')
  411. if i == 0:
  412. size = (frame.shape[1], frame.shape[0])
  413. fourcc = cv2.VideoWriter_fourcc('M', 'J', 'P', 'G')
  414. video_writer = cv2.VideoWriter(video_save_path, fourcc,
  415. cap.get(cv2.CAP_PROP_FPS), size,
  416. True)
  417. if idx == 1:
  418. video_writer.write(frame)
  419. video_writer.release()
  420. cap.release()
  421. def show_image_object_detection_auto_result(img_path,
  422. detection_result,
  423. save_path=None):
  424. scores = detection_result[OutputKeys.SCORES]
  425. labels = detection_result[OutputKeys.LABELS]
  426. bboxes = detection_result[OutputKeys.BOXES]
  427. img = cv2.imread(img_path)
  428. assert img is not None, f"Can't read img: {img_path}"
  429. for (score, label, box) in zip(scores, labels, bboxes):
  430. cv2.rectangle(img, (int(box[0]), int(box[1])),
  431. (int(box[2]), int(box[3])), (0, 0, 255), 2)
  432. cv2.putText(
  433. img,
  434. f'{score:.2f}', (int(box[0]), int(box[1])),
  435. 1,
  436. 1.0, (0, 255, 0),
  437. thickness=1,
  438. lineType=8)
  439. cv2.putText(
  440. img,
  441. label, (int(box[0]), int(box[3])),
  442. 1,
  443. 1.0, (0, 255, 0),
  444. thickness=1,
  445. lineType=8)
  446. if save_path is not None:
  447. cv2.imwrite(save_path, img)
  448. return img
  449. def depth_to_color(depth):
  450. colormap = plt.get_cmap('plasma')
  451. depth_color = (colormap(
  452. (depth.max() - depth) / depth.max()) * 2**8).astype(np.uint8)[:, :, :3]
  453. depth_color = cv2.cvtColor(depth_color, cv2.COLOR_RGB2BGR)
  454. return depth_color
  455. def make_colorwheel():
  456. """
  457. Generates a color wheel for optical flow visualization as presented in:
  458. Baker et al. "A Database and Evaluation Methodology for Optical Flow" (ICCV, 2007)
  459. URL: http://vision.middlebury.edu/flow/flowEval-iccv07.pdf
  460. Code follows the original C++ source code of Daniel Scharstein.
  461. Code follows the the Matlab source code of Deqing Sun.
  462. Returns:
  463. np.ndarray: Color wheel
  464. """
  465. RY = 15
  466. YG = 6
  467. GC = 4
  468. CB = 11
  469. BM = 13
  470. MR = 6
  471. ncols = RY + YG + GC + CB + BM + MR
  472. colorwheel = np.zeros((ncols, 3))
  473. col = 0
  474. # RY
  475. colorwheel[0:RY, 0] = 255
  476. colorwheel[0:RY, 1] = np.floor(255 * np.arange(0, RY) / RY)
  477. col = col + RY
  478. # YG
  479. colorwheel[col:col + YG, 0] = 255 - np.floor(255 * np.arange(0, YG) / YG)
  480. colorwheel[col:col + YG, 1] = 255
  481. col = col + YG
  482. # GC
  483. colorwheel[col:col + GC, 1] = 255
  484. colorwheel[col:col + GC, 2] = np.floor(255 * np.arange(0, GC) / GC)
  485. col = col + GC
  486. # CB
  487. colorwheel[col:col + CB, 1] = 255 - np.floor(255 * np.arange(CB) / CB)
  488. colorwheel[col:col + CB, 2] = 255
  489. col = col + CB
  490. # BM
  491. colorwheel[col:col + BM, 2] = 255
  492. colorwheel[col:col + BM, 0] = np.floor(255 * np.arange(0, BM) / BM)
  493. col = col + BM
  494. # MR
  495. colorwheel[col:col + MR, 2] = 255 - np.floor(255 * np.arange(MR) / MR)
  496. colorwheel[col:col + MR, 0] = 255
  497. return colorwheel
  498. def flow_uv_to_colors(u, v, convert_to_bgr=False):
  499. """
  500. Applies the flow color wheel to (possibly clipped) flow components u and v.
  501. According to the C++ source code of Daniel Scharstein
  502. According to the Matlab source code of Deqing Sun
  503. Args:
  504. u (np.ndarray): Input horizontal flow of shape [H,W]
  505. v (np.ndarray): Input vertical flow of shape [H,W]
  506. convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
  507. Returns:
  508. np.ndarray: Flow visualization image of shape [H,W,3]
  509. """
  510. flow_image = np.zeros((u.shape[0], u.shape[1], 3), np.uint8)
  511. colorwheel = make_colorwheel() # shape [55x3]
  512. ncols = colorwheel.shape[0]
  513. rad = np.sqrt(np.square(u) + np.square(v))
  514. a = np.arctan2(-v, -u) / np.pi
  515. fk = (a + 1) / 2 * (ncols - 1)
  516. k0 = np.floor(fk).astype(np.int32)
  517. k1 = k0 + 1
  518. k1[k1 == ncols] = 0
  519. f = fk - k0
  520. for i in range(colorwheel.shape[1]):
  521. tmp = colorwheel[:, i]
  522. col0 = tmp[k0] / 255.0
  523. col1 = tmp[k1] / 255.0
  524. col = (1 - f) * col0 + f * col1
  525. idx = (rad <= 1)
  526. col[idx] = 1 - rad[idx] * (1 - col[idx])
  527. col[~idx] = col[~idx] * 0.75 # out of range
  528. # Note the 2-i => BGR instead of RGB
  529. ch_idx = 2 - i if convert_to_bgr else i
  530. flow_image[:, :, ch_idx] = np.floor(255 * col)
  531. return flow_image
  532. def flow_to_image(flow_uv, clip_flow=None, convert_to_bgr=False):
  533. """
  534. Expects a two dimensional flow image of shape.
  535. Args:
  536. flow_uv (np.ndarray): Flow UV image of shape [H,W,2]
  537. clip_flow (float, optional): Clip maximum of flow values. Defaults to None.
  538. convert_to_bgr (bool, optional): Convert output image to BGR. Defaults to False.
  539. Returns:
  540. np.ndarray: Flow visualization image of shape [H,W,3]
  541. """
  542. assert flow_uv.ndim == 3, 'input flow must have three dimensions'
  543. assert flow_uv.shape[2] == 2, 'input flow must have shape [H,W,2]'
  544. if clip_flow is not None:
  545. flow_uv = np.clip(flow_uv, 0, clip_flow)
  546. u = flow_uv[:, :, 0]
  547. v = flow_uv[:, :, 1]
  548. rad = np.sqrt(np.square(u) + np.square(v))
  549. rad_max = np.max(rad)
  550. epsilon = 1e-5
  551. u = u / (rad_max + epsilon)
  552. v = v / (rad_max + epsilon)
  553. return flow_uv_to_colors(u, v, convert_to_bgr)
  554. def flow_to_color(flow):
  555. flow = flow[0].permute(1, 2, 0).cpu().numpy()
  556. flow_color = flow_to_image(flow)
  557. return flow_color
  558. def show_video_depth_estimation_result(depths, video_save_path):
  559. height, width, layers = depths[0].shape
  560. out = cv2.VideoWriter(video_save_path, cv2.VideoWriter_fourcc(*'MP4V'), 25,
  561. (width, height))
  562. for (i, img) in enumerate(depths):
  563. out.write(cv2.cvtColor(img.astype(np.uint8), cv2.COLOR_RGB2BGR))
  564. out.release()
  565. def show_image_driving_perception_result(img,
  566. results,
  567. out_file='result.jpg',
  568. if_draw=[1, 1, 1]):
  569. bboxes = results.get(OutputKeys.BOXES)
  570. if if_draw[0]:
  571. for x in bboxes:
  572. c1, c2 = (int(x[0]), int(x[1])), (int(x[2]), int(x[3]))
  573. cv2.rectangle(
  574. img, c1, c2, [255, 255, 0], thickness=2, lineType=cv2.LINE_AA)
  575. result = results.get(OutputKeys.MASKS)
  576. color_area = np.zeros((result[0].shape[0], result[0].shape[1], 3),
  577. dtype=np.uint8)
  578. if if_draw[1]:
  579. color_area[result[0] == 1] = [0, 255, 0]
  580. if if_draw[2]:
  581. color_area[result[1] == 1] = [255, 0, 0]
  582. color_seg = color_area
  583. color_mask = np.mean(color_seg, 2)
  584. msk_idx = color_mask != 0
  585. img[msk_idx] = img[msk_idx] * 0.5 + color_seg[msk_idx] * 0.5
  586. if out_file is not None:
  587. cv2.imwrite(out_file, img[:, :, ::-1])
  588. return img
  589. def masks_visualization(masks, palette):
  590. vis_masks = []
  591. for f in range(masks.shape[0]):
  592. img_E = Image.fromarray(masks[f])
  593. img_E.putpalette(palette)
  594. vis_masks.append(img_E)
  595. return vis_masks
  596. # This implementation is adopted from LoFTR,
# made publicly available under the Apache License, Version 2.0,
  598. # at https://github.com/zju3dv/LoFTR
  599. def make_matching_figure(img0,
  600. img1,
  601. mkpts0,
  602. mkpts1,
  603. color,
  604. kpts0=None,
  605. kpts1=None,
  606. text=[],
  607. dpi=75,
  608. path=None):
  609. # draw image pair
  610. assert mkpts0.shape[0] == mkpts1.shape[
  611. 0], f'mkpts0: {mkpts0.shape[0]} v.s. mkpts1: {mkpts1.shape[0]}'
  612. fig, axes = plt.subplots(1, 2, figsize=(10, 6), dpi=dpi)
  613. axes[0].imshow(img0, cmap='gray')
  614. axes[1].imshow(img1, cmap='gray')
  615. for i in range(2): # clear all frames
  616. axes[i].get_yaxis().set_ticks([])
  617. axes[i].get_xaxis().set_ticks([])
  618. for spine in axes[i].spines.values():
  619. spine.set_visible(False)
  620. plt.tight_layout(pad=1)
  621. if kpts0 is not None:
  622. assert kpts1 is not None
  623. axes[0].scatter(kpts0[:, 0], kpts0[:, 1], c='w', s=2)
  624. axes[1].scatter(kpts1[:, 0], kpts1[:, 1], c='w', s=2)
  625. # draw matches
  626. if mkpts0.shape[0] != 0 and mkpts1.shape[0] != 0:
  627. fig.canvas.draw()
  628. transFigure = fig.transFigure.inverted()
  629. fkpts0 = transFigure.transform(axes[0].transData.transform(mkpts0))
  630. fkpts1 = transFigure.transform(axes[1].transData.transform(mkpts1))
  631. fig.lines = [
  632. matplotlib.lines.Line2D((fkpts0[i, 0], fkpts1[i, 0]),
  633. (fkpts0[i, 1], fkpts1[i, 1]),
  634. transform=fig.transFigure,
  635. c=color[i],
  636. linewidth=1) for i in range(len(mkpts0))
  637. ]
  638. axes[0].scatter(mkpts0[:, 0], mkpts0[:, 1], c=color, s=4)
  639. axes[1].scatter(mkpts1[:, 0], mkpts1[:, 1], c=color, s=4)
  640. # put txts
  641. txt_color = 'k' if img0[:100, :200].mean() > 200 else 'w'
  642. fig.text(
  643. 0.01,
  644. 0.99,
  645. '\n'.join(text),
  646. transform=fig.axes[0].transAxes,
  647. fontsize=15,
  648. va='top',
  649. ha='left',
  650. color=txt_color)
  651. # save or return figure
  652. if path:
  653. plt.savefig(str(path), bbox_inches='tight', pad_inches=0)
  654. plt.close()
  655. else:
  656. return fig
  657. def match_pair_visualization(img_name0,
  658. img_name1,
  659. kpts0,
  660. kpts1,
  661. conf,
  662. output_filename='quadtree_match.png',
  663. method='QuadTreeAttention'):
  664. print(f'Found {len(kpts0)} matches')
  665. # visualize the matches
  666. img0 = cv2.imread(str(img_name0))
  667. img1 = cv2.imread(str(img_name1))
  668. # Draw
  669. color = cm.jet(conf)
  670. text = [
  671. method,
  672. 'Matches: {}'.format(len(kpts0)),
  673. ]
  674. fig = make_matching_figure(img0, img1, kpts0, kpts1, color, text=text)
  675. # save the figure
  676. fig.savefig(str(output_filename), dpi=300, bbox_inches='tight')