loss_metric.py 1.3 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. from typing import Dict
  3. import numpy as np
  4. from sklearn.metrics import accuracy_score, f1_score
  5. from modelscope.metainfo import Metrics
  6. from modelscope.outputs import OutputKeys
  7. from modelscope.utils.registry import default_group
  8. from modelscope.utils.tensor_utils import (torch_nested_detach,
  9. torch_nested_numpify)
  10. from .base import Metric
  11. from .builder import METRICS, MetricKeys
  12. @METRICS.register_module(
  13. group_key=default_group, module_name=Metrics.loss_metric)
  14. class LossMetric(Metric):
  15. """The metric class to calculate average loss of batches.
  16. Args:
  17. loss_key: The key of loss
  18. """
  19. def __init__(self, loss_key=OutputKeys.LOSS, *args, **kwargs):
  20. super().__init__(*args, **kwargs)
  21. self.loss_key = loss_key
  22. self.losses = []
  23. def add(self, outputs: Dict, inputs: Dict):
  24. loss = outputs[self.loss_key]
  25. self.losses.append(torch_nested_numpify(torch_nested_detach(loss)))
  26. def evaluate(self):
  27. return {OutputKeys.LOSS: float(np.average(self.losses))}
  28. def merge(self, other: 'LossMetric'):
  29. self.losses.extend(other.losses)
  30. def __getstate__(self):
  31. return self.losses
  32. def __setstate__(self, state):
  33. self.__init__()
  34. self.losses = state