ciderD.py 1.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657
  1. # Filename: ciderD.py
  2. #
  3. # Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric
  4. # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
  5. #
  6. # Creation Date: Sun Feb 8 14:16:54 2015
  7. #
  8. # Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu>
  9. from __future__ import absolute_import, division, print_function
  10. from .ciderD_scorer import CiderScorer
  11. class CiderD:
  12. """
  13. Main Class to compute the CIDEr metric
  14. """
  15. def __init__(self, n=4, sigma=6.0, df='corpus'):
  16. # set cider to sum over 1 to 4-grams
  17. self._n = n
  18. # set the standard deviation parameter for gaussian penalty
  19. self._sigma = sigma
  20. # set which where to compute document frequencies from
  21. self._df = df
  22. self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df)
  23. def compute_score(self, gts, res):
  24. """
  25. Main function to compute CIDEr score
  26. :param hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence>
  27. ref_for_image (dict) : dictionary with key <image> and value <tokenized reference sentence>
  28. :return: cider (float) : computed CIDEr score for the corpus
  29. """ # noqa
  30. # clear all the previous hypos and refs
  31. tmp_cider_scorer = self.cider_scorer.copy_empty()
  32. tmp_cider_scorer.clear()
  33. for res_id in res:
  34. hypo = res_id['caption']
  35. ref = gts[res_id['image_id']]
  36. # Sanity check.
  37. assert (type(hypo) is list)
  38. assert (len(hypo) == 1)
  39. assert (type(ref) is list)
  40. assert (len(ref) > 0)
  41. tmp_cider_scorer += (hypo[0], ref)
  42. (score, scores) = tmp_cider_scorer.compute_score()
  43. return score, scores
  44. def method(self):
  45. return 'CIDEr-D'