video_stabilization_metric.py 8.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269
  1. # Part of the implementation is borrowed and modified from DIFRINT,
  2. # publicly available at https://github.com/jinsc37/DIFRINT/blob/master/metrics.py
  3. import os
  4. import sys
  5. import tempfile
  6. from typing import Dict
  7. import cv2
  8. import numpy as np
  9. from tqdm import tqdm
  10. from modelscope.metainfo import Metrics
  11. from modelscope.models.cv.video_stabilization.utils.WarpUtils import \
  12. warpListImage
  13. from modelscope.utils.registry import default_group
  14. from .base import Metric
  15. from .builder import METRICS, MetricKeys
  16. @METRICS.register_module(
  17. group_key=default_group, module_name=Metrics.video_stabilization_metric)
  18. class VideoStabilizationMetric(Metric):
  19. """The metric for video summarization task.
  20. """
  21. def __init__(self):
  22. self.inputs = []
  23. self.outputs = []
  24. def add(self, outputs: Dict, inputs: Dict):
  25. out = video_merger(warpprocess(outputs))
  26. self.outputs.append(out['video'])
  27. self.inputs.append(inputs['input'][0])
  28. def evaluate(self):
  29. CR = []
  30. DV = []
  31. SS = []
  32. for output, input in zip(self.outputs, self.inputs):
  33. cropping_ratio, distortion_value, stability_score = \
  34. metrics(input, output)
  35. if cropping_ratio <= 1 and distortion_value <= 1 and stability_score <= 1:
  36. CR.append(cropping_ratio)
  37. DV.append(distortion_value)
  38. SS.append(stability_score)
  39. else:
  40. print('Removed one error item when computing metrics.')
  41. return {
  42. MetricKeys.CROPPING_RATIO: sum(CR) / len(CR),
  43. MetricKeys.DISTORTION_VALUE: sum(DV) / len(DV),
  44. MetricKeys.STABILITY_SCORE: sum(SS) / len(SS),
  45. }
  46. def merge(self, other: 'VideoStabilizationMetric'):
  47. self.inputs.extend(other.inputs)
  48. self.outputs.extend(other.outputs)
  49. def __getstate__(self):
  50. return self.inputs, self.outputs
  51. def __setstate__(self, state):
  52. self.inputs, self.outputs = state
  53. def warpprocess(inputs):
  54. """ video stabilization postprocess
  55. Args:
  56. inputs: input data
  57. Return:
  58. dict of results: a dict containing outputs of model.
  59. """
  60. x_paths = inputs['origin_motion'][:, :, :, 0]
  61. y_paths = inputs['origin_motion'][:, :, :, 1]
  62. sx_paths = inputs['smooth_path'][:, :, :, 0]
  63. sy_paths = inputs['smooth_path'][:, :, :, 1]
  64. new_x_motion_meshes = sx_paths - x_paths
  65. new_y_motion_meshes = sy_paths - y_paths
  66. out_images = warpListImage(inputs['ori_images'], new_x_motion_meshes,
  67. new_y_motion_meshes, inputs['width'],
  68. inputs['height'])
  69. return {
  70. 'output': out_images,
  71. 'fps': inputs['fps'],
  72. 'width': inputs['width'],
  73. 'height': inputs['height'],
  74. 'base_crop_width': inputs['base_crop_width']
  75. }
  76. def video_merger(inputs):
  77. out_images = inputs['output'].numpy().astype(np.uint8)
  78. out_images = [
  79. np.transpose(out_images[idx], (1, 2, 0))
  80. for idx in range(out_images.shape[0])
  81. ]
  82. output_video_path = tempfile.NamedTemporaryFile(suffix='.mp4').name
  83. fourcc = cv2.VideoWriter_fourcc(*'mp4v')
  84. w = inputs['width']
  85. h = inputs['height']
  86. base_crop_width = inputs['base_crop_width']
  87. video_writer = cv2.VideoWriter(output_video_path, fourcc, inputs['fps'],
  88. (w, h))
  89. for idx, frame in enumerate(out_images):
  90. horizontal_border = int(base_crop_width * w / 1280)
  91. vertical_border = int(horizontal_border * h / w)
  92. new_frame = frame[vertical_border:-vertical_border,
  93. horizontal_border:-horizontal_border]
  94. new_frame = cv2.resize(new_frame, (w, h))
  95. video_writer.write(new_frame)
  96. video_writer.release()
  97. return {'video': output_video_path}
  98. def metrics(original_v, pred_v):
  99. # Create brute-force matcher object
  100. bf = cv2.BFMatcher()
  101. sift = cv2.SIFT_create()
  102. # Apply the homography transformation if we have enough good matches
  103. MIN_MATCH_COUNT = 10
  104. ratio = 0.7
  105. thresh = 5.0
  106. CR_seq = []
  107. DV_seq = []
  108. Pt = np.eye(3)
  109. P_seq = []
  110. vc_o = cv2.VideoCapture(original_v)
  111. vc_p = cv2.VideoCapture(pred_v)
  112. rval_o = vc_o.isOpened()
  113. rval_p = vc_p.isOpened()
  114. imgs1 = []
  115. imgs1o = []
  116. while (rval_o and rval_p):
  117. rval_o, img1 = vc_o.read()
  118. rval_p, img1o = vc_p.read()
  119. if rval_o and rval_p:
  120. imgs1.append(img1)
  121. imgs1o.append(img1o)
  122. is_got_bad_item = False
  123. print('processing ' + original_v.split('/')[-1] + ':')
  124. for i in tqdm(range(len(imgs1))):
  125. # Load the images in gray scale
  126. img1 = imgs1[i]
  127. img1o = imgs1o[i]
  128. # Detect the SIFT key points and compute the descriptors for the two images
  129. keyPoints1, descriptors1 = sift.detectAndCompute(img1, None)
  130. keyPoints1o, descriptors1o = sift.detectAndCompute(img1o, None)
  131. # Match the descriptors
  132. matches = bf.knnMatch(descriptors1, descriptors1o, k=2)
  133. # Select the good matches using the ratio test
  134. goodMatches = []
  135. for m, n in matches:
  136. if m.distance < ratio * n.distance:
  137. goodMatches.append(m)
  138. if len(goodMatches) > MIN_MATCH_COUNT:
  139. # Get the good key points positions
  140. sourcePoints = np.float32([
  141. keyPoints1[m.queryIdx].pt for m in goodMatches
  142. ]).reshape(-1, 1, 2)
  143. destinationPoints = np.float32([
  144. keyPoints1o[m.trainIdx].pt for m in goodMatches
  145. ]).reshape(-1, 1, 2)
  146. # Obtain the homography matrix
  147. M, _ = cv2.findHomography(
  148. sourcePoints,
  149. destinationPoints,
  150. method=cv2.RANSAC,
  151. ransacReprojThreshold=thresh)
  152. else:
  153. is_got_bad_item = True
  154. # end
  155. if not is_got_bad_item:
  156. # Obtain Scale, Translation, Rotation, Distortion value
  157. # Based on https://math.stackexchange.com/questions/78137/decomposition-of-a-nonsquare-affine-matrix
  158. scaleRecovered = np.sqrt(M[0, 1]**2 + M[0, 0]**2)
  159. w, _ = np.linalg.eig(M[0:2, 0:2])
  160. w = np.sort(w)[::-1]
  161. DV = w[1] / w[0]
  162. CR_seq.append(1 / scaleRecovered)
  163. DV_seq.append(DV)
  164. # For Stability score calculation
  165. if i + 1 < len(imgs1):
  166. img2o = imgs1o[i + 1]
  167. keyPoints2o, descriptors2o = sift.detectAndCompute(img2o, None)
  168. matches = bf.knnMatch(descriptors1o, descriptors2o, k=2)
  169. goodMatches = []
  170. for m, n in matches:
  171. if m.distance < ratio * n.distance:
  172. goodMatches.append(m)
  173. if len(goodMatches) > MIN_MATCH_COUNT:
  174. # Get the good key points positions
  175. sourcePoints = np.float32([
  176. keyPoints1o[m.queryIdx].pt for m in goodMatches
  177. ]).reshape(-1, 1, 2)
  178. destinationPoints = np.float32([
  179. keyPoints2o[m.trainIdx].pt for m in goodMatches
  180. ]).reshape(-1, 1, 2)
  181. # Obtain the homography matrix
  182. M, _ = cv2.findHomography(
  183. sourcePoints,
  184. destinationPoints,
  185. method=cv2.RANSAC,
  186. ransacReprojThreshold=thresh)
  187. # end
  188. P_seq.append(np.matmul(Pt, M))
  189. Pt = np.matmul(Pt, M)
  190. # end
  191. # end
  192. if is_got_bad_item:
  193. return -1, -1, -1
  194. # Make 1D temporal signals
  195. P_seq_t = []
  196. P_seq_r = []
  197. for Mp in P_seq:
  198. transRecovered = np.sqrt(Mp[0, 2]**2 + Mp[1, 2]**2)
  199. # Based on https://math.stackexchange.com/questions/78137/decomposition-of-a-nonsquare-affine-matrix
  200. thetaRecovered = np.arctan2(Mp[1, 0], Mp[0, 0]) * 180 / np.pi
  201. P_seq_t.append(transRecovered)
  202. P_seq_r.append(thetaRecovered)
  203. # FFT
  204. fft_t = np.fft.fft(P_seq_t)
  205. fft_r = np.fft.fft(P_seq_r)
  206. fft_t = np.abs(fft_t)**2
  207. fft_r = np.abs(fft_r)**2
  208. fft_t = np.delete(fft_t, 0)
  209. fft_r = np.delete(fft_r, 0)
  210. fft_t = fft_t[:len(fft_t) // 2]
  211. fft_r = fft_r[:len(fft_r) // 2]
  212. SS_t = np.sum(fft_t[:5]) / np.sum(fft_t)
  213. SS_r = np.sum(fft_r[:5]) / np.sum(fft_r)
  214. return np.min([np.mean(CR_seq),
  215. 1]), np.absolute(np.min(DV_seq)), (SS_t + SS_r) / 2