utils.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import time
  3. from typing import Dict, List, Optional, Tuple, Union
  4. import cv2
  5. import numpy as np
  6. import torch
  7. import torch.nn.functional as F
  8. from einops import rearrange
# Public helpers exported by ``from <this module> import *``.
# NOTE(review): pad_to_size / unpad_from_size defined in this module are not
# listed here, so star-imports will not pick them up — confirm intentional.
__all__ = [
    'gen_diffuse_mask', 'get_crop_bbox', 'get_roi_without_padding',
    'patch_aggregation_overlap', 'patch_partition_overlap', 'preprocess_roi',
    'resize_on_long_side', 'roi_to_tensor', 'smooth_border_mg', 'whiten_img'
]
  14. def resize_on_long_side(img, long_side=800):
  15. src_height = img.shape[0]
  16. src_width = img.shape[1]
  17. if src_height > src_width:
  18. scale = long_side * 1.0 / src_height
  19. _img = cv2.resize(
  20. img, (int(src_width * scale), long_side),
  21. interpolation=cv2.INTER_LINEAR)
  22. else:
  23. scale = long_side * 1.0 / src_width
  24. _img = cv2.resize(
  25. img, (long_side, int(src_height * scale)),
  26. interpolation=cv2.INTER_LINEAR)
  27. return _img, scale
  28. def get_crop_bbox(detecting_results):
  29. boxes = []
  30. for anno in detecting_results:
  31. if anno['score'] == -1:
  32. break
  33. boxes.append({
  34. 'x1': anno['bbox'][0],
  35. 'y1': anno['bbox'][1],
  36. 'x2': anno['bbox'][2],
  37. 'y2': anno['bbox'][3]
  38. })
  39. face_count = len(boxes)
  40. suitable_bboxes = []
  41. for i in range(face_count):
  42. face_bbox = boxes[i]
  43. face_bbox_width = abs(face_bbox['x2'] - face_bbox['x1'])
  44. face_bbox_height = abs(face_bbox['y2'] - face_bbox['y1'])
  45. face_bbox_center = ((face_bbox['x1'] + face_bbox['x2']) / 2,
  46. (face_bbox['y1'] + face_bbox['y2']) / 2)
  47. square_bbox_length = face_bbox_height if face_bbox_height > face_bbox_width else face_bbox_width
  48. enlarge_ratio = 1.5
  49. square_bbox_length = int(enlarge_ratio * square_bbox_length)
  50. sideScale = 1
  51. square_bbox = {
  52. 'x1':
  53. int(face_bbox_center[0] - sideScale * square_bbox_length / 2),
  54. 'x2':
  55. int(face_bbox_center[0] + sideScale * square_bbox_length / 2),
  56. 'y1':
  57. int(face_bbox_center[1] - sideScale * square_bbox_length / 2),
  58. 'y2': int(face_bbox_center[1] + sideScale * square_bbox_length / 2)
  59. }
  60. suitable_bboxes.append(square_bbox)
  61. return suitable_bboxes
  62. def get_roi_without_padding(img, bbox):
  63. crop_t = max(bbox['y1'], 0)
  64. crop_b = min(bbox['y2'], img.shape[0])
  65. crop_l = max(bbox['x1'], 0)
  66. crop_r = min(bbox['x2'], img.shape[1])
  67. roi = img[crop_t:crop_b, crop_l:crop_r]
  68. return roi, 0, [crop_t, crop_b, crop_l, crop_r]
  69. def roi_to_tensor(img):
  70. img = torch.from_numpy(img.transpose((2, 0, 1)))[None, ...]
  71. return img
  72. def preprocess_roi(img):
  73. img = img.float() / 255.0
  74. img = (img - 0.5) * 2
  75. return img
  76. def patch_partition_overlap(image, p1, p2, padding=32):
  77. B, C, H, W = image.size()
  78. h, w = H // p1, W // p2
  79. image = F.pad(
  80. image,
  81. pad=(padding, padding, padding, padding, 0, 0),
  82. mode='constant',
  83. value=0)
  84. patch_list = []
  85. for i in range(h):
  86. for j in range(w):
  87. patch = image[:, :, p1 * i:p1 * (i + 1) + padding * 2,
  88. p2 * j:p2 * (j + 1) + padding * 2]
  89. patch_list.append(patch)
  90. output = torch.cat(
  91. patch_list, dim=0) # (b h w) c (p1 + 2 * padding) (p2 + 2 * padding)
  92. return output
  93. def patch_aggregation_overlap(image, h, w, padding=32):
  94. image = image[:, :, padding:-padding, padding:-padding]
  95. output = rearrange(image, '(b h w) c p1 p2 -> b c (h p1) (w p2)', h=h, w=w)
  96. return output
  97. def smooth_border_mg(diffuse_mask, mg):
  98. mg = mg - 0.5
  99. diffuse_mask = F.interpolate(
  100. diffuse_mask, mg.shape[:2], mode='bilinear')[0].permute(1, 2, 0)
  101. mg = mg * diffuse_mask
  102. mg = mg + 0.5
  103. return mg
  104. def whiten_img(image, skin_mask, whitening_degree, flag_bigKernal=False):
  105. """
  106. image: rgb
  107. """
  108. dilate_kernalsize = 30
  109. if flag_bigKernal:
  110. dilate_kernalsize = 80
  111. new_kernel1 = cv2.getStructuringElement(
  112. cv2.MORPH_ELLIPSE, (dilate_kernalsize, dilate_kernalsize))
  113. new_kernel2 = cv2.getStructuringElement(
  114. cv2.MORPH_ELLIPSE, (dilate_kernalsize, dilate_kernalsize))
  115. if len(skin_mask.shape) == 3:
  116. skin_mask = skin_mask[:, :, -1]
  117. skin_mask = cv2.dilate(skin_mask, new_kernel1, 1)
  118. skin_mask = cv2.erode(skin_mask, new_kernel2, 1)
  119. skin_mask = cv2.blur(skin_mask, (20, 20)) / 255.0
  120. skin_mask = skin_mask.squeeze()
  121. skin_mask = torch.from_numpy(skin_mask).to(image.device)
  122. skin_mask = torch.stack([skin_mask, skin_mask, skin_mask], dim=0)[None,
  123. ...]
  124. skin_mask[:, 1:, :, :] *= 0.75
  125. whiten_mg = skin_mask * 0.2 * whitening_degree + 0.5
  126. assert len(whiten_mg.shape) == 4
  127. whiten_mg = F.interpolate(
  128. whiten_mg, image.shape[:2], mode='bilinear')[0].permute(1, 2,
  129. 0).half()
  130. output_pred = image.half()
  131. output_pred = output_pred / 255.0
  132. output_pred = (
  133. -2 * whiten_mg + 1
  134. ) * output_pred * output_pred + 2 * whiten_mg * output_pred # value: 0~1
  135. output_pred = output_pred * 255.0
  136. output_pred = output_pred.byte()
  137. output_pred = output_pred.cpu().numpy()
  138. return output_pred
  139. def gen_diffuse_mask(out_channels=3):
  140. mask_size = 500
  141. diffuse_with = 20
  142. a = np.ones(shape=(mask_size, mask_size), dtype=np.float32)
  143. for i in range(mask_size):
  144. for j in range(mask_size):
  145. if i >= diffuse_with and i <= (
  146. mask_size - diffuse_with) and j >= diffuse_with and j <= (
  147. mask_size - diffuse_with):
  148. a[i, j] = 1.0
  149. elif i <= diffuse_with:
  150. a[i, j] = i * 1.0 / diffuse_with
  151. elif i > (mask_size - diffuse_with):
  152. a[i, j] = (mask_size - i) * 1.0 / diffuse_with
  153. for i in range(mask_size):
  154. for j in range(mask_size):
  155. if j <= diffuse_with:
  156. a[i, j] = min(a[i, j], j * 1.0 / diffuse_with)
  157. elif j > (mask_size - diffuse_with):
  158. a[i, j] = min(a[i, j], (mask_size - j) * 1.0 / diffuse_with)
  159. a = np.dstack([a] * out_channels)
  160. return a
  161. def pad_to_size(
  162. target_size: Tuple[int, int],
  163. image: np.array,
  164. bboxes: Optional[np.ndarray] = None,
  165. keypoints: Optional[np.ndarray] = None,
  166. ) -> Dict[str, Union[np.ndarray, Tuple[int, int, int, int]]]:
  167. """Pads the image on the sides to the target_size
  168. Args:
  169. target_size: (target_height, target_width)
  170. image:
  171. bboxes: np.array with shape (num_boxes, 4). Each row: [x_min, y_min, x_max, y_max]
  172. keypoints: np.array with shape (num_keypoints, 2), each row: [x, y]
  173. Returns:
  174. {
  175. "image": padded_image,
  176. "pads": (x_min_pad, y_min_pad, x_max_pad, y_max_pad),
  177. "bboxes": shifted_boxes,
  178. "keypoints": shifted_keypoints
  179. }
  180. """
  181. target_height, target_width = target_size
  182. image_height, image_width = image.shape[:2]
  183. if target_width < image_width:
  184. raise ValueError(f'Target width should bigger than image_width'
  185. f'We got {target_width} {image_width}')
  186. if target_height < image_height:
  187. raise ValueError(f'Target height should bigger than image_height'
  188. f'We got {target_height} {image_height}')
  189. if image_height == target_height:
  190. y_min_pad = 0
  191. y_max_pad = 0
  192. else:
  193. y_pad = target_height - image_height
  194. y_min_pad = y_pad // 2
  195. y_max_pad = y_pad - y_min_pad
  196. if image_width == target_width:
  197. x_min_pad = 0
  198. x_max_pad = 0
  199. else:
  200. x_pad = target_width - image_width
  201. x_min_pad = x_pad // 2
  202. x_max_pad = x_pad - x_min_pad
  203. result = {
  204. 'pads': (x_min_pad, y_min_pad, x_max_pad, y_max_pad),
  205. 'image':
  206. cv2.copyMakeBorder(image, y_min_pad, y_max_pad, x_min_pad, x_max_pad,
  207. cv2.BORDER_CONSTANT),
  208. }
  209. if bboxes is not None:
  210. bboxes[:, 0] += x_min_pad
  211. bboxes[:, 1] += y_min_pad
  212. bboxes[:, 2] += x_min_pad
  213. bboxes[:, 3] += y_min_pad
  214. result['bboxes'] = bboxes
  215. if keypoints is not None:
  216. keypoints[:, 0] += x_min_pad
  217. keypoints[:, 1] += y_min_pad
  218. result['keypoints'] = keypoints
  219. return result
  220. def unpad_from_size(
  221. pads: Tuple[int, int, int, int],
  222. image: Optional[np.array] = None,
  223. bboxes: Optional[np.ndarray] = None,
  224. keypoints: Optional[np.ndarray] = None,
  225. ) -> Dict[str, np.ndarray]:
  226. """Crops patch from the center so that sides are equal to pads.
  227. Args:
  228. image:
  229. pads: (x_min_pad, y_min_pad, x_max_pad, y_max_pad)
  230. bboxes: np.array with shape (num_boxes, 4). Each row: [x_min, y_min, x_max, y_max]
  231. keypoints: np.array with shape (num_keypoints, 2), each row: [x, y]
  232. Returns: cropped image
  233. {
  234. "image": cropped_image,
  235. "bboxes": shifted_boxes,
  236. "keypoints": shifted_keypoints
  237. }
  238. """
  239. x_min_pad, y_min_pad, x_max_pad, y_max_pad = pads
  240. result = {}
  241. if image is not None:
  242. height, width = image.shape[:2]
  243. result['image'] = image[y_min_pad:height - y_max_pad,
  244. x_min_pad:width - x_max_pad]
  245. if bboxes is not None:
  246. bboxes[:, 0] -= x_min_pad
  247. bboxes[:, 1] -= y_min_pad
  248. bboxes[:, 2] -= x_min_pad
  249. bboxes[:, 3] -= y_min_pad
  250. result['bboxes'] = bboxes
  251. if keypoints is not None:
  252. keypoints[:, 0] -= x_min_pad
  253. keypoints[:, 1] -= y_min_pad
  254. result['keypoints'] = keypoints
  255. return result