inbatch_recall_metric.py 2.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict
  3. import numpy as np
  4. import torch
  5. from modelscope.metainfo import Metrics
  6. from modelscope.outputs import OutputKeys
  7. from modelscope.utils.registry import default_group
  8. from .base import Metric
  9. from .builder import METRICS, MetricKeys
  10. @METRICS.register_module(
  11. group_key=default_group, module_name=Metrics.inbatch_recall)
  12. class InbatchRecallMetric(Metric):
  13. """The metric computation class for in-batch retrieval classes.
  14. This metric class calculates in-batch image recall@1 for each input batch.
  15. """
  16. def __init__(self, *args, **kwargs):
  17. super().__init__(*args, **kwargs)
  18. self.inbatch_t2i_hitcnts = []
  19. self.batch_sizes = []
  20. def add(self, outputs: Dict, inputs: Dict):
  21. image_features = outputs[OutputKeys.IMG_EMBEDDING]
  22. text_features = outputs[OutputKeys.TEXT_EMBEDDING]
  23. assert type(image_features) == torch.Tensor and type(
  24. text_features) == torch.Tensor
  25. with torch.no_grad():
  26. logits_per_image = image_features @ text_features.t()
  27. logits_per_text = logits_per_image.t()
  28. batch_size = logits_per_image.shape[0]
  29. ground_truth = torch.arange(batch_size).long()
  30. ground_truth = ground_truth.to(image_features.device)
  31. inbatch_t2i_hitcnt = (logits_per_text.argmax(-1) == ground_truth
  32. ).sum().float().item()
  33. self.inbatch_t2i_hitcnts.append(inbatch_t2i_hitcnt)
  34. self.batch_sizes.append(batch_size)
  35. def evaluate(self):
  36. assert len(self.inbatch_t2i_hitcnts) == len(
  37. self.batch_sizes) and len(self.batch_sizes) > 0
  38. return {
  39. MetricKeys.BatchAcc:
  40. sum(self.inbatch_t2i_hitcnts) / sum(self.batch_sizes)
  41. }
  42. def merge(self, other: 'InbatchRecallMetric'):
  43. self.inbatch_t2i_hitcnts.extend(other.inbatch_t2i_hitcnts)
  44. self.batch_sizes.extend(other.batch_sizes)
  45. def __getstate__(self):
  46. return self.inbatch_t2i_hitcnts, self.batch_sizes
  47. def __setstate__(self, state):
  48. self.__init__()
  49. self.inbatch_t2i_hitcnts, self.batch_sizes = state