ppl_metric.py 2.0 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
  1. # Copyright (c) Alibaba, Inc. and its affiliates.
  2. import math
  3. from typing import Dict, Union
  4. import numpy as np
  5. import torch
  6. import torch.nn.functional as F
  7. from modelscope.metainfo import Metrics
  8. from modelscope.outputs import OutputKeys
  9. from modelscope.utils.registry import default_group
  10. from .base import Metric
  11. from .builder import METRICS, MetricKeys
  12. @METRICS.register_module(group_key=default_group, module_name=Metrics.PPL)
  13. class PplMetric(Metric):
  14. """The metric computation class for any classes.
  15. This metric class calculates perplexity for the whole input batches.
  16. """
  17. def __init__(self, *args, **kwargs):
  18. super().__init__(*args, **kwargs)
  19. self.avg_loss: float = 0.
  20. self.batch_num: int = 0
  21. def add(self, outputs: Dict, inputs: Dict):
  22. logits = outputs[OutputKeys.LOGITS]
  23. labels = inputs[OutputKeys.LABELS]
  24. in_loss = self._get_loss(logits, labels)
  25. in_batch_num = self._get_batch_num(inputs[OutputKeys.LABELS])
  26. self.avg_loss = self._average_loss(in_loss, in_batch_num)
  27. self.batch_num += in_batch_num
  28. @staticmethod
  29. def _get_loss(logits: torch.Tensor, labels: torch.Tensor) -> float:
  30. labels = labels.view(-1)
  31. logits = logits.view(labels.shape[0], -1)
  32. return F.cross_entropy(logits, labels).item()
  33. @staticmethod
  34. def _get_batch_num(matrix: Union[np.ndarray, torch.Tensor]) -> int:
  35. return matrix.shape[0]
  36. def _average_loss(self, in_loss: float, in_batch_num):
  37. return (self.avg_loss * self.batch_num + in_loss * in_batch_num) \
  38. / (self.batch_num + in_batch_num)
  39. def evaluate(self) -> Dict[str, float]:
  40. return {MetricKeys.PPL: math.exp(self.avg_loss)}
  41. def merge(self, other: 'PplMetric'):
  42. self.avg_loss = self._average_loss(other.avg_loss, other.batch_num)
  43. self.batch_num += other.batch_num
  44. def __getstate__(self):
  45. return self.avg_loss, self.batch_num
  46. def __setstate__(self, state):
  47. self.__init__()
  48. self.avg_loss, self.batch_num = state