image_denoise_metric.py 9.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283
  1. # ------------------------------------------------------------------------
  2. # Copyright (c) Alibaba, Inc. and its affiliates.
  3. # ------------------------------------------------------------------------
  4. # modified from https://github.com/megvii-research/NAFNet/blob/main/basicsr/metrics/psnr_ssim.py
  5. # ------------------------------------------------------------------------
  6. from typing import Dict
  7. import cv2
  8. import numpy as np
  9. import torch
  10. from modelscope.metainfo import Metrics
  11. from modelscope.utils.registry import default_group
  12. from .base import Metric
  13. from .builder import METRICS, MetricKeys
  14. @METRICS.register_module(
  15. group_key=default_group, module_name=Metrics.image_denoise_metric)
  16. class ImageDenoiseMetric(Metric):
  17. """The metric computation class for image denoise classes.
  18. """
  19. pred_name = 'pred'
  20. label_name = 'target'
  21. def __init__(self):
  22. super(ImageDenoiseMetric, self).__init__()
  23. self.preds = []
  24. self.labels = []
  25. def add(self, outputs: Dict, inputs: Dict):
  26. ground_truths = outputs[ImageDenoiseMetric.label_name]
  27. eval_results = outputs[ImageDenoiseMetric.pred_name]
  28. self.preds.append(eval_results)
  29. self.labels.append(ground_truths)
  30. def evaluate(self):
  31. psnr_list, ssim_list = [], []
  32. for (pred, label) in zip(self.preds, self.labels):
  33. psnr_list.append(calculate_psnr(label[0], pred[0], crop_border=0))
  34. ssim_list.append(calculate_ssim(label[0], pred[0], crop_border=0))
  35. return {
  36. MetricKeys.PSNR: np.mean(psnr_list),
  37. MetricKeys.SSIM: np.mean(ssim_list)
  38. }
  39. def merge(self, other: 'ImageDenoiseMetric'):
  40. self.preds.extend(other.preds)
  41. self.labels.extend(other.labels)
  42. def __getstate__(self):
  43. return self.preds, self.labels
  44. def __setstate__(self, state):
  45. self.__init__()
  46. self.preds, self.labels = state
  47. def reorder_image(img, input_order='HWC'):
  48. """Reorder images to 'HWC' order.
  49. If the input_order is (h, w), return (h, w, 1);
  50. If the input_order is (c, h, w), return (h, w, c);
  51. If the input_order is (h, w, c), return as it is.
  52. Args:
  53. img (ndarray): Input image.
  54. input_order (str): Whether the input order is 'HWC' or 'CHW'.
  55. If the input image shape is (h, w), input_order will not have
  56. effects. Default: 'HWC'.
  57. Returns:
  58. ndarray: reordered image.
  59. """
  60. if input_order not in ['HWC', 'CHW']:
  61. raise ValueError(
  62. f"Wrong input_order {input_order}. Supported input_orders are 'HWC' and 'CHW'"
  63. )
  64. if len(img.shape) == 2:
  65. img = img[..., None]
  66. if input_order == 'CHW':
  67. img = img.transpose(1, 2, 0)
  68. return img
  69. def calculate_psnr(img1, img2, crop_border, input_order='HWC'):
  70. """Calculate PSNR (Peak Signal-to-Noise Ratio).
  71. Ref: https://en.wikipedia.org/wiki/Peak_signal-to-noise_ratio
  72. Args:
  73. img1 (ndarray/tensor): Images with range [0, 255]/[0, 1].
  74. img2 (ndarray/tensor): Images with range [0, 255]/[0, 1].
  75. crop_border (int): Cropped pixels in each edge of an image. These
  76. pixels are not involved in the PSNR calculation.
  77. input_order (str): Whether the input order is 'HWC' or 'CHW'.
  78. Default: 'HWC'.
  79. test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
  80. Returns:
  81. float: psnr result.
  82. """
  83. assert img1.shape == img2.shape, (
  84. f'Image shapes are different: {img1.shape}, {img2.shape}.')
  85. if input_order not in ['HWC', 'CHW']:
  86. raise ValueError(
  87. f'Wrong input_order {input_order}. Supported input_orders are '
  88. '"HWC" and "CHW"')
  89. if type(img1) == torch.Tensor:
  90. if len(img1.shape) == 4:
  91. img1 = img1.squeeze(0)
  92. img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
  93. if type(img2) == torch.Tensor:
  94. if len(img2.shape) == 4:
  95. img2 = img2.squeeze(0)
  96. img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)
  97. img1 = reorder_image(img1, input_order=input_order)
  98. img2 = reorder_image(img2, input_order=input_order)
  99. img1 = img1.astype(np.float64)
  100. img2 = img2.astype(np.float64)
  101. if crop_border != 0:
  102. img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
  103. img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
  104. def _psnr(img1, img2):
  105. mse = np.mean((img1 - img2)**2)
  106. if mse == 0:
  107. return float('inf')
  108. max_value = 1. if img1.max() <= 1 else 255.
  109. return 20. * np.log10(max_value / np.sqrt(mse))
  110. return _psnr(img1, img2)
  111. def calculate_ssim(img1, img2, crop_border, input_order='HWC', ssim3d=True):
  112. """Calculate SSIM (structural similarity).
  113. Ref:
  114. Image quality assessment: From error visibility to structural similarity
  115. The results are the same as that of the official released MATLAB code in
  116. https://ece.uwaterloo.ca/~z70wang/research/ssim/.
  117. For three-channel images, SSIM is calculated for each channel and then
  118. averaged.
  119. Args:
  120. img1 (ndarray): Images with range [0, 255].
  121. img2 (ndarray): Images with range [0, 255].
  122. crop_border (int): Cropped pixels in each edge of an image. These
  123. pixels are not involved in the SSIM calculation.
  124. input_order (str): Whether the input order is 'HWC' or 'CHW'.
  125. Default: 'HWC'.
  126. test_y_channel (bool): Test on Y channel of YCbCr. Default: False.
  127. Returns:
  128. float: ssim result.
  129. """
  130. assert img1.shape == img2.shape, (
  131. f'Image shapes are different: {img1.shape}, {img2.shape}.')
  132. if input_order not in ['HWC', 'CHW']:
  133. raise ValueError(
  134. f'Wrong input_order {input_order}. Supported input_orders are '
  135. '"HWC" and "CHW"')
  136. if type(img1) == torch.Tensor:
  137. if len(img1.shape) == 4:
  138. img1 = img1.squeeze(0)
  139. img1 = img1.detach().cpu().numpy().transpose(1, 2, 0)
  140. if type(img2) == torch.Tensor:
  141. if len(img2.shape) == 4:
  142. img2 = img2.squeeze(0)
  143. img2 = img2.detach().cpu().numpy().transpose(1, 2, 0)
  144. img1 = reorder_image(img1, input_order=input_order)
  145. img2 = reorder_image(img2, input_order=input_order)
  146. img1 = img1.astype(np.float64)
  147. img2 = img2.astype(np.float64)
  148. if crop_border != 0:
  149. img1 = img1[crop_border:-crop_border, crop_border:-crop_border, ...]
  150. img2 = img2[crop_border:-crop_border, crop_border:-crop_border, ...]
  151. def _cal_ssim(img1, img2):
  152. ssims = []
  153. max_value = 1 if img1.max() <= 1 else 255
  154. with torch.no_grad():
  155. final_ssim = _ssim_3d(img1, img2, max_value) if ssim3d else _ssim(
  156. img1, img2, max_value)
  157. ssims.append(final_ssim)
  158. return np.array(ssims).mean()
  159. return _cal_ssim(img1, img2)
  160. def _ssim(img, img2, max_value):
  161. """Calculate SSIM (structural similarity) for one channel images.
  162. It is called by func:`calculate_ssim`.
  163. Args:
  164. img (ndarray): Images with range [0, 255] with order 'HWC'.
  165. img2 (ndarray): Images with range [0, 255] with order 'HWC'.
  166. Returns:
  167. float: SSIM result.
  168. """
  169. c1 = (0.01 * max_value)**2
  170. c2 = (0.03 * max_value)**2
  171. img = img.astype(np.float64)
  172. img2 = img2.astype(np.float64)
  173. kernel = cv2.getGaussianKernel(11, 1.5)
  174. window = np.outer(kernel, kernel.transpose())
  175. mu1 = cv2.filter2D(img, -1, window)[5:-5,
  176. 5:-5] # valid mode for window size 11
  177. mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
  178. mu1_sq = mu1**2
  179. mu2_sq = mu2**2
  180. mu1_mu2 = mu1 * mu2
  181. sigma1_sq = cv2.filter2D(img**2, -1, window)[5:-5, 5:-5] - mu1_sq
  182. sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
  183. sigma12 = cv2.filter2D(img * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
  184. tmp1 = (2 * mu1_mu2 + c1) * (2 * sigma12 + c2)
  185. tmp2 = (mu1_sq + mu2_sq + c1) * (sigma1_sq + sigma2_sq + c2)
  186. ssim_map = tmp1 / tmp2
  187. return ssim_map.mean()
  188. def _3d_gaussian_calculator(img, conv3d):
  189. out = conv3d(img.unsqueeze(0).unsqueeze(0)).squeeze(0).squeeze(0)
  190. return out
  191. def _generate_3d_gaussian_kernel():
  192. kernel = cv2.getGaussianKernel(11, 1.5)
  193. window = np.outer(kernel, kernel.transpose())
  194. kernel_3 = cv2.getGaussianKernel(11, 1.5)
  195. kernel = torch.tensor(np.stack([window * k for k in kernel_3], axis=0))
  196. conv3d = torch.nn.Conv3d(
  197. 1,
  198. 1, (11, 11, 11),
  199. stride=1,
  200. padding=(5, 5, 5),
  201. bias=False,
  202. padding_mode='replicate')
  203. conv3d.weight.requires_grad = False
  204. conv3d.weight[0, 0, :, :, :] = kernel
  205. return conv3d
  206. def _ssim_3d(img1, img2, max_value):
  207. assert len(img1.shape) == 3 and len(img2.shape) == 3
  208. """Calculate SSIM (structural similarity) for one channel images.
  209. It is called by func:`calculate_ssim`.
  210. Args:
  211. img1 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
  212. img2 (ndarray): Images with range [0, 255]/[0, 1] with order 'HWC'.
  213. Returns:
  214. float: ssim result.
  215. """
  216. C1 = (0.01 * max_value)**2
  217. C2 = (0.03 * max_value)**2
  218. img1 = img1.astype(np.float64)
  219. img2 = img2.astype(np.float64)
  220. kernel = _generate_3d_gaussian_kernel().cuda()
  221. img1 = torch.tensor(img1).float().cuda()
  222. img2 = torch.tensor(img2).float().cuda()
  223. mu1 = _3d_gaussian_calculator(img1, kernel)
  224. mu2 = _3d_gaussian_calculator(img2, kernel)
  225. mu1_sq = mu1**2
  226. mu2_sq = mu2**2
  227. mu1_mu2 = mu1 * mu2
  228. sigma1_sq = _3d_gaussian_calculator(img1**2, kernel) - mu1_sq
  229. sigma2_sq = _3d_gaussian_calculator(img2**2, kernel) - mu2_sq
  230. sigma12 = _3d_gaussian_calculator(img1 * img2, kernel) - mu1_mu2
  231. tmp1 = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
  232. tmp2 = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
  233. ssim_map = tmp1 / tmp2
  234. return float(ssim_map.mean())