skychange.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import numbers
  3. import os
  4. import pdb
  5. from collections import deque
  6. import cv2
  7. import json
  8. import numpy as np
  9. import torch
  10. import torch.nn.functional as F
  11. from PIL import Image
  12. from torchvision import transforms
# Turn on the cuDNN backend flag for torch as a module-level side effect.
torch.backends.cudnn.enabled = True
# The next three limits are not referenced by the functions visible in this
# part of the file; presumably input-size validation bounds — TODO confirm.
IMAGE_MAX_DIM = 3000
IMAGE_MIN_DIM = 50
IMAGE_MAX_RATIO = 10.0
# Factor applied to the scene mask's width/height in blend_merge before the
# mask is blurred, dilated and eroded.
IMAGE_BLENDER_MASK_RESIZE_SCALE = 10.0
# extract_sky_image rescales the sky mask when either side exceeds this.
IMAGE_BLENDER_INNER_RECT_MAX_DIM = 256
# Base side length of the closing kernel in extract_sky_image; it is
# multiplied by the resize scale and never drops below 3.
IMAGE_BLENDER_DILATE_KERNEL_SIZE = 7
# Mask values strictly above this count as valid sky in extract_sky_image.
IMAGE_BLENDER_VALID_MASK_THRESHOLD = 100
# Minimum rectangle area accepted by extract_sky_image and blend_merge.
IMAGE_BLENDER_MIN_VALID_SKY_AREA = 100
# Lower bound for each side of the resized scene mask in blend_merge.
IMAGE_BLENDER_MIN_RESIZE_DIM = 10
# Side length of both the box blur and the dilate/erode kernel in blend_merge.
IMAGE_BLENDER_BLUR_KERNEL_SIZE = 5
  24. def extract_sky_image(in_sky_image, in_sky_mask):
  25. scale = 1.0
  26. resize_mask = in_sky_mask.copy()
  27. rows, cols = resize_mask.shape[0:2]
  28. # src size: (512, 640), target size: (256,256), then scale to size (256, 320)
  29. if (rows > IMAGE_BLENDER_INNER_RECT_MAX_DIM
  30. or cols > IMAGE_BLENDER_INNER_RECT_MAX_DIM):
  31. height_scale = IMAGE_BLENDER_INNER_RECT_MAX_DIM / float(rows)
  32. width_scale = IMAGE_BLENDER_INNER_RECT_MAX_DIM / float(cols)
  33. scale = height_scale if height_scale > width_scale else width_scale
  34. new_size = (max(int(cols * scale), 1), max(int(rows * scale),
  35. 1)) # w, h
  36. resize_mask = cv2.resize(resize_mask, new_size, cv2.INTER_LINEAR)
  37. kernelSize = max(3, int(scale * IMAGE_BLENDER_DILATE_KERNEL_SIZE + 0.5))
  38. element = cv2.getStructuringElement(cv2.MORPH_RECT,
  39. (kernelSize, kernelSize))
  40. resize_mask = cv2.morphologyEx(resize_mask, cv2.MORPH_CLOSE, element)
  41. max_inner_rect, area = get_max_inner_rect(
  42. resize_mask, IMAGE_BLENDER_VALID_MASK_THRESHOLD, True)
  43. if area < IMAGE_BLENDER_MIN_VALID_SKY_AREA:
  44. raise Exception(
  45. '[extractSkyImage]failed!! Valid sky region is too small')
  46. scale = 1.0 / scale
  47. # max_inner_rect: left top(x,y), right bottome(x,y); raw_inner_rect:left top x,y,w(of bbox),h(of bbox)
  48. raw_inner_rect = scale_rect(max_inner_rect, in_sky_mask, scale)
  49. out_sky_image = in_sky_image[raw_inner_rect[1]:raw_inner_rect[1]
  50. + raw_inner_rect[3] + 1,
  51. raw_inner_rect[0]:raw_inner_rect[0]
  52. + raw_inner_rect[2] + 1, ].copy()
  53. return out_sky_image
  54. def blend(scene_image, scene_mask, sky_image, sky_mask, inBlendLevelNum=10):
  55. if torch.cuda.is_available():
  56. scene_image = scene_image.cpu().numpy()
  57. sky_image = sky_image.cpu().numpy()
  58. else:
  59. scene_image = scene_image.numpy()
  60. sky_image = sky_image.numpy()
  61. sky_image_h, sky_image_w = sky_image.shape[0:2]
  62. sky_mask_h, sky_mask_w = sky_mask.shape[0:2]
  63. scene_image_h, scene_image_w = scene_image.shape[0:2]
  64. scene_mask_h, scene_mask_w = scene_mask.shape[0:2]
  65. if sky_image_h != sky_mask_h or sky_image_w != sky_mask_w:
  66. raise Exception(
  67. '[blend]failed!! sky_image shape not equal with sky_image_mask shape'
  68. )
  69. if scene_image_h != scene_mask_h or scene_image_w != scene_mask_w:
  70. raise Exception(
  71. '[blend]failed!! scene_image shape not equal with scene_image_mask shape'
  72. )
  73. valid_sky_image = extract_sky_image(sky_image, sky_mask)
  74. out_blend_image = blend_merge(scene_image, scene_mask, valid_sky_image,
  75. inBlendLevelNum)
  76. return out_blend_image
  77. def get_max_inner_rect(in_image_mask, in_alpha_threshold, is_bigger_valid):
  78. res = 0
  79. row, col = in_image_mask.shape[0:2]
  80. i0, j0, i1, j1 = 0, 0, 0, 0
  81. height = [0] * (col + 1)
  82. for i in range(0, row):
  83. s = deque()
  84. for j in range(0, col + 1):
  85. if j < col:
  86. if is_bigger_valid:
  87. height[j] = (
  88. height[j]
  89. + 1 if in_image_mask[i, j] > in_alpha_threshold else 0)
  90. else:
  91. height[j] = (
  92. height[j] + 1
  93. if in_image_mask[i, j] <= in_alpha_threshold else 0)
  94. while len(s) != 0 and height[s[-1]] >= height[j]:
  95. cur = s[-1]
  96. s.pop()
  97. _h = height[cur]
  98. _w = j if len(s) == 0 else j - s[-1] - 1
  99. curArea = _h * _w
  100. if curArea > res:
  101. res = curArea
  102. i1 = i
  103. i0 = i1 - _h + 1
  104. j1 = j - 1
  105. j0 = j1 - _w + 1
  106. s.append(j)
  107. out_rect = (
  108. j0,
  109. i0,
  110. j1 - j0 + 1,
  111. i1 - i0 + 1,
  112. )
  113. return out_rect, res
  114. def scale_rect(in_rect, in_image_size, in_scale):
  115. tlX = int(in_rect[0] * in_scale + 0.5)
  116. tlY = int(in_rect[1] * in_scale + 0.5)
  117. in_image_size_h, in_image_size_w = in_image_size.shape[0:2]
  118. brX = min(int(in_rect[2] * in_scale + 0.5), in_image_size_w)
  119. brY = min(int(in_rect[3] * in_scale + 0.5), in_image_size_h)
  120. out_rect = (tlX, tlY, brX - tlX, brY - tlY)
  121. return out_rect
  122. def get_fast_valid_rect(in_mask, in_threshold=0):
  123. # mask: np.array [0~1]
  124. in_mask = in_mask > in_threshold
  125. locations = cv2.findNonZero(in_mask.astype(np.uint8))
  126. output_rect = cv2.boundingRect(locations) # x,y,w,h
  127. return output_rect
  128. def min_size_match(in_image, in_min_size, type=cv2.INTER_LINEAR):
  129. resize_image = in_image.copy()
  130. width, height = in_min_size
  131. resize_img_height, resize_img_width = in_image.shape[0:2]
  132. height_scale = height / resize_img_height
  133. widht_scale = width / resize_img_width
  134. scale = height_scale if height_scale > widht_scale else widht_scale
  135. new_size = (
  136. max(int(resize_img_width * scale + 0.5), 1),
  137. max(int(resize_img_height * scale + 0.5), 1),
  138. )
  139. resize_image = cv2.resize(resize_image, new_size, 0, 0, type)
  140. return resize_image
  141. def center_crop(in_image, in_size):
  142. in_size_w, in_size_h = in_size
  143. in_image_h, in_image_w = in_image.shape[0:2]
  144. half_height = (in_image_h - in_size_h) // 2
  145. half_width = (in_image_w - in_size_w) // 2
  146. out_crop_image = in_image.copy()
  147. out_crop_image = out_crop_image[half_height:half_height + in_size_h,
  148. half_width:half_width + in_size_w]
  149. return out_crop_image
  150. def safe_roi_pad(in_pad_image, in_rect, out_base_image):
  151. in_rect_x, in_rect_y, in_rect_w, in_rect_h = in_rect
  152. if in_rect_x < 0 or in_rect_y < 0 or in_rect_w <= 0 or in_rect_h <= 0:
  153. raise Exception('[safe_roi_pad] Failed!! x,y,w,h of rect are illegal')
  154. if in_rect_w != in_pad_image.shape[1] or in_rect_h != in_pad_image.shape[0]:
  155. raise Exception('[safe_roi_pad] Failed!!')
  156. if (in_rect_x + in_rect_w > out_base_image.shape[1]
  157. or in_rect_y + in_rect_h > out_base_image.shape[0]):
  158. raise Exception('[safe_roi_pad] Failed!!')
  159. out_base_image[in_rect_y:in_rect_y + in_rect_h,
  160. in_rect_x:in_rect_x + in_rect_w] = in_pad_image
  161. def merge_image(in_base_image, in_merge_image, in_merge_mask, in_point):
  162. if in_merge_image.shape[0:2] != in_merge_mask.shape[0:2]:
  163. raise Exception(
  164. '[merge_image] Failed!! in_merge_image.shape != in_merge_mask.shape!!'
  165. )
  166. in_point_x, in_point_y = in_point
  167. in_merge_image_rows, in_merge_image_cols = in_merge_image.shape[0:2]
  168. in_base_image_rows, in_base_image_cols = in_base_image.shape[0:2]
  169. if (in_point_x + in_merge_image_cols > in_base_image_cols
  170. or in_point_y + in_merge_image_rows > in_base_image_rows):
  171. raise Exception(
  172. '[merge_image] Failed!! merge_image:image rect not in image')
  173. base_roi_image = in_base_image[in_point_y:in_point_y + in_merge_image_rows,
  174. in_point_x:in_point_x
  175. + in_merge_image_cols, ]
  176. merge_image = in_merge_image.copy()
  177. merge_alpha = in_merge_mask.copy()
  178. base_roi_image = np.float32(base_roi_image)
  179. merge_alpha = np.repeat(merge_alpha[:, :, np.newaxis], 3, axis=2)
  180. merge_alpha = merge_alpha / 255.0
  181. base_roi_image = (
  182. 1 - merge_alpha) * base_roi_image + merge_alpha * merge_image
  183. base_roi_image = np.clip(base_roi_image, 0, 255)
  184. base_roi_image = base_roi_image.astype('uint8')
  185. roi_rect = (in_point_x, in_point_y, in_merge_image_cols,
  186. in_merge_image_rows)
  187. safe_roi_pad(base_roi_image, roi_rect, in_base_image)
  188. return in_base_image
  189. def blend_merge(in_scene_image,
  190. in_scene_mask,
  191. in_valid_sky_image,
  192. inBlendLevelNum=5):
  193. scene_sky_rect = get_fast_valid_rect(in_scene_mask, 1)
  194. area = scene_sky_rect[2] * scene_sky_rect[3]
  195. if area < IMAGE_BLENDER_MIN_VALID_SKY_AREA:
  196. raise Exception(
  197. '[blend_merge] Failed!! Scene Image Valid sky region is too small')
  198. valid_sky_image = min_size_match(in_valid_sky_image, scene_sky_rect[2:])
  199. valid_sky_image = center_crop(valid_sky_image, scene_sky_rect[2:])
  200. # resizeSceneMask
  201. sky_size = (
  202. max(
  203. int(in_scene_mask.shape[1] * IMAGE_BLENDER_MASK_RESIZE_SCALE
  204. + 0.5),
  205. IMAGE_BLENDER_MIN_RESIZE_DIM,
  206. ),
  207. max(
  208. int(in_scene_mask.shape[0] * IMAGE_BLENDER_MASK_RESIZE_SCALE
  209. + 0.5),
  210. IMAGE_BLENDER_MIN_RESIZE_DIM,
  211. ),
  212. )
  213. resize_scene_mask = cv2.resize(in_scene_mask, sky_size, cv2.INTER_LINEAR)
  214. resize_scene_mask = cv2.blur(
  215. resize_scene_mask,
  216. (IMAGE_BLENDER_BLUR_KERNEL_SIZE, IMAGE_BLENDER_BLUR_KERNEL_SIZE),
  217. )
  218. element = cv2.getStructuringElement(
  219. cv2.MORPH_RECT,
  220. (IMAGE_BLENDER_BLUR_KERNEL_SIZE, IMAGE_BLENDER_BLUR_KERNEL_SIZE))
  221. sky_mask = cv2.dilate(resize_scene_mask, element) # enlarge sky region
  222. scene_mask = cv2.erode(resize_scene_mask, element) # enlarge scene region
  223. scene_mask = 255 - scene_mask
  224. sky_mask = cv2.resize(sky_mask, in_scene_mask.shape[0:2][::-1])
  225. scene_mask = cv2.resize(scene_mask, in_scene_mask.shape[0:2][::-1])
  226. x, y, w, h = scene_sky_rect
  227. valid_sky_mask = sky_mask[y:y + h, x:x + w]
  228. pano_sky_image = in_scene_image.copy()
  229. pano_sky_image = merge_image(pano_sky_image, valid_sky_image,
  230. valid_sky_mask, scene_sky_rect[0:2])
  231. blend_images = []
  232. blend_images.append(in_scene_image)
  233. blend_images.append(pano_sky_image)
  234. blend_masks = []
  235. blend_masks.append(scene_mask.astype(np.uint8))
  236. blend_masks.append(sky_mask.astype(np.uint8))
  237. panorama_rect = (0, 0, in_scene_image.shape[1], in_scene_image.shape[0])
  238. blender = cv2.detail_MultiBandBlender(1, inBlendLevelNum)
  239. blender.prepare(panorama_rect)
  240. for i in range(0, len(blend_images)):
  241. blender.feed(blend_images[i], blend_masks[i], (0, 0))
  242. pano_mask = (
  243. np.ones(
  244. (in_scene_image.shape[1], in_scene_image.shape[0]), dtype='uint8')
  245. * 255)
  246. out_blend_image = np.zeros_like(in_scene_image)
  247. result = blender.blend(out_blend_image, pano_mask)
  248. return result[0]