| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657 |
- # Filename: ciderD.py
- #
- # Description: Describes the class to compute the CIDEr-D (Consensus-Based Image Description Evaluation) Metric
- # by Vedantam, Zitnick, and Parikh (http://arxiv.org/abs/1411.5726)
- #
- # Creation Date: Sun Feb 8 14:16:54 2015
- #
- # Authors: Ramakrishna Vedantam <vrama91@vt.edu> and Tsung-Yi Lin <tl483@cornell.edu>
- from __future__ import absolute_import, division, print_function
- from .ciderD_scorer import CiderScorer
- class CiderD:
- """
- Main Class to compute the CIDEr metric
- """
- def __init__(self, n=4, sigma=6.0, df='corpus'):
- # set cider to sum over 1 to 4-grams
- self._n = n
- # set the standard deviation parameter for gaussian penalty
- self._sigma = sigma
- # set which where to compute document frequencies from
- self._df = df
- self.cider_scorer = CiderScorer(n=self._n, df_mode=self._df)
- def compute_score(self, gts, res):
- """
- Main function to compute CIDEr score
- :param hypo_for_image (dict) : dictionary with key <image> and value <tokenized hypothesis / candidate sentence>
- ref_for_image (dict) : dictionary with key <image> and value <tokenized reference sentence>
- :return: cider (float) : computed CIDEr score for the corpus
- """ # noqa
- # clear all the previous hypos and refs
- tmp_cider_scorer = self.cider_scorer.copy_empty()
- tmp_cider_scorer.clear()
- for res_id in res:
- hypo = res_id['caption']
- ref = gts[res_id['image_id']]
- # Sanity check.
- assert (type(hypo) is list)
- assert (len(hypo) == 1)
- assert (type(ref) is list)
- assert (len(ref) > 0)
- tmp_cider_scorer += (hypo[0], ref)
- (score, scores) = tmp_cider_scorer.compute_score()
- return score, scores
- def method(self):
- return 'CIDEr-D'
|