sr_metric.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161
  1. # copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """
  15. https://github.com/FudanVI/FudanOCR/blob/main/text-gestalt/utils/ssim_psnr.py
  16. """
  17. from math import exp
  18. import paddle
  19. import paddle.nn.functional as F
  20. import paddle.nn as nn
  21. import string
  22. class SSIM(nn.Layer):
  23. def __init__(self, window_size=11, size_average=True):
  24. super(SSIM, self).__init__()
  25. self.window_size = window_size
  26. self.size_average = size_average
  27. self.channel = 1
  28. self.window = self.create_window(window_size, self.channel)
  29. def gaussian(self, window_size, sigma):
  30. gauss = paddle.to_tensor(
  31. [
  32. exp(-((x - window_size // 2) ** 2) / float(2 * sigma**2))
  33. for x in range(window_size)
  34. ]
  35. )
  36. return gauss / gauss.sum()
  37. def create_window(self, window_size, channel):
  38. _1D_window = self.gaussian(window_size, 1.5).unsqueeze(1)
  39. _2D_window = _1D_window.mm(_1D_window.t()).unsqueeze(0).unsqueeze(0)
  40. window = _2D_window.expand([channel, 1, window_size, window_size])
  41. return window
  42. def _ssim(self, img1, img2, window, window_size, channel, size_average=True):
  43. mu1 = F.conv2d(img1, window, padding=window_size // 2, groups=channel)
  44. mu2 = F.conv2d(img2, window, padding=window_size // 2, groups=channel)
  45. mu1_sq = mu1.pow(2)
  46. mu2_sq = mu2.pow(2)
  47. mu1_mu2 = mu1 * mu2
  48. sigma1_sq = (
  49. F.conv2d(img1 * img1, window, padding=window_size // 2, groups=channel)
  50. - mu1_sq
  51. )
  52. sigma2_sq = (
  53. F.conv2d(img2 * img2, window, padding=window_size // 2, groups=channel)
  54. - mu2_sq
  55. )
  56. sigma12 = (
  57. F.conv2d(img1 * img2, window, padding=window_size // 2, groups=channel)
  58. - mu1_mu2
  59. )
  60. C1 = 0.01**2
  61. C2 = 0.03**2
  62. ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / (
  63. (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
  64. )
  65. if size_average:
  66. return ssim_map.mean()
  67. else:
  68. return ssim_map.mean([1, 2, 3])
  69. def ssim(self, img1, img2, window_size=11, size_average=True):
  70. (_, channel, _, _) = img1.shape
  71. window = self.create_window(window_size, channel)
  72. return self._ssim(img1, img2, window, window_size, channel, size_average)
  73. def forward(self, img1, img2):
  74. (_, channel, _, _) = img1.shape
  75. if channel == self.channel and self.window.dtype == img1.dtype:
  76. window = self.window
  77. else:
  78. window = self.create_window(self.window_size, channel)
  79. self.window = window
  80. self.channel = channel
  81. return self._ssim(
  82. img1, img2, window, self.window_size, channel, self.size_average
  83. )
  84. class SRMetric(object):
  85. def __init__(self, main_indicator="all", **kwargs):
  86. self.main_indicator = main_indicator
  87. self.eps = 1e-5
  88. self.psnr_result = []
  89. self.ssim_result = []
  90. self.calculate_ssim = SSIM()
  91. self.reset()
  92. def reset(self):
  93. self.correct_num = 0
  94. self.all_num = 0
  95. self.norm_edit_dis = 0
  96. self.psnr_result = []
  97. self.ssim_result = []
  98. def calculate_psnr(self, img1, img2):
  99. # img1 and img2 have range [0, 1]
  100. mse = ((img1 * 255 - img2 * 255) ** 2).mean()
  101. if mse == 0:
  102. return float("inf")
  103. return 20 * paddle.log10(255.0 / paddle.sqrt(mse))
  104. def _normalize_text(self, text):
  105. text = "".join(
  106. filter(lambda x: x in (string.digits + string.ascii_letters), text)
  107. )
  108. return text.lower()
  109. def __call__(self, pred_label, *args, **kwargs):
  110. metric = {}
  111. images_sr = pred_label["sr_img"]
  112. images_hr = pred_label["hr_img"]
  113. psnr = self.calculate_psnr(images_sr, images_hr)
  114. ssim = self.calculate_ssim(images_sr, images_hr)
  115. self.psnr_result.append(psnr)
  116. self.ssim_result.append(ssim)
  117. def get_metric(self):
  118. """
  119. return metrics {
  120. 'acc': 0,
  121. 'norm_edit_dis': 0,
  122. }
  123. """
  124. self.psnr_avg = sum(self.psnr_result) / len(self.psnr_result)
  125. self.psnr_avg = round(self.psnr_avg.item(), 6)
  126. self.ssim_avg = sum(self.ssim_result) / len(self.ssim_result)
  127. self.ssim_avg = round(self.ssim_avg.item(), 6)
  128. self.all_avg = self.psnr_avg + self.ssim_avg
  129. self.reset()
  130. return {
  131. "psnr_avg": self.psnr_avg,
  132. "ssim_avg": self.ssim_avg,
  133. "all": self.all_avg,
  134. }