# Copyright (c) Alibaba, Inc. and its affiliates.
from abc import ABC, abstractmethod
from typing import Dict


  4. class Metric(ABC):
  5. """The metric base class for computing metrics.
  6. The subclasses can either compute a single metric like 'accuracy', or compute the
  7. complex metrics for a specific task with or without other Metric subclasses.
  8. """
  9. def __init__(self, *args, **kwargs):
  10. pass
  11. @abstractmethod
  12. def add(self, outputs: Dict, inputs: Dict):
  13. """ Append logits and labels within an eval loop.
  14. Will be called after every batch finished to gather the model predictions and the labels.
  15. Args:
  16. outputs: The model prediction outputs.
  17. inputs: The mini batch inputs from the dataloader.
  18. Returns: None
  19. """
  20. pass
  21. @abstractmethod
  22. def evaluate(self):
  23. """Evaluate the metrics after the eval finished.
  24. Will be called after the whole validation finished.
  25. Returns: The actual metric dict with standard names.
  26. """
  27. pass
  28. @abstractmethod
  29. def merge(self, other: 'Metric'):
  30. """ When using data parallel, the data required for different metric calculations
  31. are stored in their respective Metric classes,
  32. and we need to merge these data to uniformly calculate metric.
  33. Args:
  34. other: Another Metric instance.
  35. Returns: None
  36. """
  37. pass