video_frame_interpolation_metric.py 5.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183
  1. # ------------------------------------------------------------------------
  2. # Copyright (c) Alibaba, Inc. and its affiliates.
  3. # ------------------------------------------------------------------------
  4. import math
  5. from math import exp
  6. from typing import Dict
  7. import lpips
  8. import numpy as np
  9. import torch
  10. import torch.nn.functional as F
  11. from modelscope.metainfo import Metrics
  12. from modelscope.metrics.base import Metric
  13. from modelscope.metrics.builder import METRICS, MetricKeys
  14. from modelscope.utils.registry import default_group
  15. @METRICS.register_module(
  16. group_key=default_group,
  17. module_name=Metrics.video_frame_interpolation_metric)
  18. class VideoFrameInterpolationMetric(Metric):
  19. """The metric computation class for video frame interpolation,
  20. which will return PSNR, SSIM and LPIPS.
  21. """
  22. pred_name = 'pred'
  23. label_name = 'target'
  24. def __init__(self):
  25. super(VideoFrameInterpolationMetric, self).__init__()
  26. self.preds = []
  27. self.labels = []
  28. self.loss_fn_alex = lpips.LPIPS(net='alex').cuda()
  29. def add(self, outputs: Dict, inputs: Dict):
  30. ground_truths = outputs[VideoFrameInterpolationMetric.label_name]
  31. eval_results = outputs[VideoFrameInterpolationMetric.pred_name]
  32. self.preds.append(eval_results)
  33. self.labels.append(ground_truths)
  34. def evaluate(self):
  35. psnr_list, ssim_list, lpips_list = [], [], []
  36. with torch.no_grad():
  37. for (pred, label) in zip(self.preds, self.labels):
  38. # norm to 0-1
  39. height, width = label.size(2), label.size(3)
  40. pred = pred[:, :, 0:height, 0:width]
  41. psnr_list.append(calculate_psnr(label, pred))
  42. ssim_list.append(calculate_ssim(label, pred))
  43. lpips_list.append(
  44. calculate_lpips(label, pred, self.loss_fn_alex))
  45. return {
  46. MetricKeys.PSNR: np.mean(psnr_list),
  47. MetricKeys.SSIM: np.mean(ssim_list),
  48. MetricKeys.LPIPS: np.mean(lpips_list)
  49. }
  50. def merge(self, other: 'VideoFrameInterpolationMetric'):
  51. self.preds.extend(other.preds)
  52. self.labels.extend(other.labels)
  53. def __getstate__(self):
  54. return self.preds, self.labels
  55. def __setstate__(self, state):
  56. self.__init__()
  57. self.preds, self.labels = state
  58. def gaussian(window_size, sigma):
  59. gauss = torch.Tensor([
  60. exp(-(x - window_size // 2)**2 / float(2 * sigma**2))
  61. for x in range(window_size)
  62. ])
  63. return gauss / gauss.sum()
  64. def create_window_3d(window_size, channel=1, device=None):
  65. _1D_window = gaussian(window_size, 1.5).unsqueeze(1)
  66. _2D_window = _1D_window.mm(_1D_window.t())
  67. _3D_window = _2D_window.unsqueeze(2) @ (_1D_window.t())
  68. window = _3D_window.expand(1, channel, window_size, window_size,
  69. window_size).contiguous().to(device)
  70. return window
  71. def calculate_psnr(img1, img2):
  72. psnr = -10 * math.log10(
  73. torch.mean((img1[0] - img2[0]) * (img1[0] - img2[0])).cpu().data)
  74. return psnr
  75. def calculate_ssim(img1,
  76. img2,
  77. window_size=11,
  78. window=None,
  79. size_average=True,
  80. full=False,
  81. val_range=None):
  82. # Value range can be different from 255. Other common ranges are 1 (sigmoid) and 2 (tanh).
  83. if val_range is None:
  84. if torch.max(img1) > 128:
  85. max_val = 255
  86. else:
  87. max_val = 1
  88. if torch.min(img1) < -0.5:
  89. min_val = -1
  90. else:
  91. min_val = 0
  92. L = max_val - min_val
  93. else:
  94. L = val_range
  95. padd = 0
  96. (_, _, height, width) = img1.size()
  97. if window is None:
  98. real_size = min(window_size, height, width)
  99. window = create_window_3d(
  100. real_size, channel=1, device=img1.device).to(img1.device)
  101. # Channel is set to 1 since we consider color images as volumetric images
  102. img1 = img1.unsqueeze(1)
  103. img2 = img2.unsqueeze(1)
  104. mu1 = F.conv3d(
  105. F.pad(img1, (5, 5, 5, 5, 5, 5), mode='replicate'),
  106. window,
  107. padding=padd,
  108. groups=1)
  109. mu2 = F.conv3d(
  110. F.pad(img2, (5, 5, 5, 5, 5, 5), mode='replicate'),
  111. window,
  112. padding=padd,
  113. groups=1)
  114. mu1_sq = mu1.pow(2)
  115. mu2_sq = mu2.pow(2)
  116. mu1_mu2 = mu1 * mu2
  117. sigma1_sq = F.conv3d(
  118. F.pad(img1 * img1, (5, 5, 5, 5, 5, 5), 'replicate'),
  119. window,
  120. padding=padd,
  121. groups=1) - mu1_sq
  122. sigma2_sq = F.conv3d(
  123. F.pad(img2 * img2, (5, 5, 5, 5, 5, 5), 'replicate'),
  124. window,
  125. padding=padd,
  126. groups=1) - mu2_sq
  127. sigma12 = F.conv3d(
  128. F.pad(img1 * img2, (5, 5, 5, 5, 5, 5), 'replicate'),
  129. window,
  130. padding=padd,
  131. groups=1) - mu1_mu2
  132. C1 = (0.01 * L)**2
  133. C2 = (0.03 * L)**2
  134. v1 = 2.0 * sigma12 + C2
  135. v2 = sigma1_sq + sigma2_sq + C2
  136. cs = torch.mean(v1 / v2) # contrast sensitivity
  137. ssim_map = ((2 * mu1_mu2 + C1) * v1) / ((mu1_sq + mu2_sq + C1) * v2)
  138. if size_average:
  139. ret = ssim_map.mean()
  140. else:
  141. ret = ssim_map.mean(1).mean(1).mean(1)
  142. if full:
  143. return ret, cs
  144. return ret.cpu()
  145. def calculate_lpips(img1, img2, loss_fn_alex):
  146. img1 = img1 * 2 - 1
  147. img2 = img2 * 2 - 1
  148. d = loss_fn_alex(img1, img2)
  149. return d.cpu().item()