video_summarization_metric.py 3.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. # Part of the implementation is borrowed and modified from PGL-SUM,
  2. # publicly available at https://github.com/e-apostolidis/PGL-SUM
  3. from typing import Dict
  4. import numpy as np
  5. from modelscope.metainfo import Metrics
  6. from modelscope.models.cv.video_summarization.summarizer import \
  7. generate_summary
  8. from modelscope.utils.registry import default_group
  9. from .base import Metric
  10. from .builder import METRICS, MetricKeys
  11. def evaluate_summary(predicted_summary, user_summary, eval_method):
  12. """ Compare the predicted summary with the user defined one(s).
  13. :param ndarray predicted_summary: The generated summary from our model.
  14. :param ndarray user_summary: The user defined ground truth summaries (or summary).
  15. :param str eval_method: The proposed evaluation method; either 'max' (SumMe) or 'avg' (TVSum).
  16. :return: The reduced fscore based on the eval_method
  17. """
  18. max_len = max(len(predicted_summary), user_summary.shape[1])
  19. S = np.zeros(max_len, dtype=int)
  20. G = np.zeros(max_len, dtype=int)
  21. S[:len(predicted_summary)] = predicted_summary
  22. f_scores = []
  23. for user in range(user_summary.shape[0]):
  24. G[:user_summary.shape[1]] = user_summary[user]
  25. overlapped = S & G
  26. # Compute precision, recall, f-score
  27. precision = sum(overlapped) / sum(S)
  28. recall = sum(overlapped) / sum(G)
  29. if precision + recall == 0:
  30. f_scores.append(0)
  31. else:
  32. f_score = 2 * precision * recall * 100 / (precision + recall)
  33. f_scores.append(f_score)
  34. if eval_method == 'max':
  35. return max(f_scores)
  36. else:
  37. return sum(f_scores) / len(f_scores)
  38. def calculate_f_score(outputs: Dict, inputs: Dict):
  39. scores = outputs['scores']
  40. scores = scores.squeeze(0).cpu().numpy().tolist()
  41. user_summary = inputs['user_summary'].cpu().numpy()[0]
  42. sb = inputs['change_points'].cpu().numpy()[0]
  43. n_frames = inputs['n_frames'].cpu().numpy()[0]
  44. positions = inputs['positions'].cpu().numpy()[0]
  45. summary = generate_summary([sb], [scores], [n_frames], [positions])[0]
  46. f_score = evaluate_summary(summary, user_summary, 'avg')
  47. return f_score
  48. @METRICS.register_module(
  49. group_key=default_group, module_name=Metrics.video_summarization_metric)
  50. class VideoSummarizationMetric(Metric):
  51. """The metric for video summarization task.
  52. """
  53. def __init__(self):
  54. self.inputs = []
  55. self.outputs = []
  56. def add(self, outputs: Dict, inputs: Dict):
  57. self.outputs.append(outputs)
  58. self.inputs.append(inputs)
  59. def evaluate(self):
  60. f_scores = [
  61. calculate_f_score(output, input)
  62. for output, input in zip(self.outputs, self.inputs)
  63. ]
  64. return {MetricKeys.FScore: sum(f_scores) / len(f_scores)}
  65. def merge(self, other: 'VideoSummarizationMetric'):
  66. self.inputs.extend(other.inputs)
  67. self.outputs.extend(other.outputs)
  68. def __getstate__(self):
  69. return self.inputs, self.outputs
  70. def __setstate__(self, state):
  71. self.inputs, self.outputs = state