| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465 |
- # Copyright (c) Alibaba, Inc. and its affiliates.
- import math
- from typing import Dict, Union
- import numpy as np
- import torch
- import torch.nn.functional as F
- from modelscope.metainfo import Metrics
- from modelscope.outputs import OutputKeys
- from modelscope.utils.registry import default_group
- from .base import Metric
- from .builder import METRICS, MetricKeys
@METRICS.register_module(group_key=default_group, module_name=Metrics.PPL)
class PplMetric(Metric):
    """The metric computation class for perplexity.

    Keeps a running weighted average of the per-call cross-entropy loss and
    reports ``exp(avg_loss)`` as the perplexity. Each call to ``add`` is
    weighted by the size of the first dimension of its labels (presumably
    the batch size -- confirm against the caller).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Running weighted mean of the losses seen so far.
        self.avg_loss: float = 0.
        # Accumulated weight (sum of first-dimension sizes) behind avg_loss.
        self.batch_num: int = 0

    def add(self, outputs: Dict, inputs: Dict):
        """Fold the loss of one model output into the running average.

        Args:
            outputs: Mapping that holds the logits under
                ``OutputKeys.LOGITS``.
            inputs: Mapping that holds the labels under
                ``OutputKeys.LABELS``.
        """
        logits = outputs[OutputKeys.LOGITS]
        labels = inputs[OutputKeys.LABELS]
        in_loss = self._get_loss(logits, labels)
        in_batch_num = self._get_batch_num(labels)
        # The average must be refreshed before the counter, because
        # _average_loss reads the old self.batch_num as the existing weight.
        self.avg_loss = self._average_loss(in_loss, in_batch_num)
        self.batch_num += in_batch_num

    @staticmethod
    def _get_loss(logits: torch.Tensor, labels: torch.Tensor) -> float:
        """Return the cross-entropy between ``logits`` and ``labels``.

        Labels are flattened to one dimension and logits are reshaped to
        one row per label before ``F.cross_entropy`` is applied with its
        default arguments.
        """
        # reshape (rather than view) also accepts non-contiguous tensors,
        # e.g. logits/labels that were sliced or shifted upstream.
        labels = labels.reshape(-1)
        logits = logits.reshape(labels.shape[0], -1)
        return F.cross_entropy(logits, labels).item()

    @staticmethod
    def _get_batch_num(matrix: Union[np.ndarray, torch.Tensor]) -> int:
        """Return the size of the first dimension of ``matrix``."""
        return matrix.shape[0]

    def _average_loss(self, in_loss: float, in_batch_num: int) -> float:
        """Combine the running average with a new loss value.

        Args:
            in_loss: Loss value to fold in.
            in_batch_num: Weight of ``in_loss``.

        Returns:
            The weighted mean of ``self.avg_loss`` (weight
            ``self.batch_num``) and ``in_loss`` (weight ``in_batch_num``).
            When ``in_batch_num`` is zero the current average is returned
            unchanged.
        """
        if in_batch_num == 0:
            # Nothing to fold in. This also avoids a ZeroDivisionError when
            # both sides are empty (e.g. merging two unused metrics).
            return self.avg_loss
        total_num = self.batch_num + in_batch_num
        weighted_sum = self.avg_loss * self.batch_num + in_loss * in_batch_num
        return weighted_sum / total_num

    def evaluate(self) -> Dict[str, float]:
        """Return the perplexity, ``exp(avg_loss)``, under ``MetricKeys.PPL``."""
        return {MetricKeys.PPL: math.exp(self.avg_loss)}

    def merge(self, other: 'PplMetric'):
        """Fold the statistics accumulated by ``other`` into this metric."""
        self.avg_loss = self._average_loss(other.avg_loss, other.batch_num)
        self.batch_num += other.batch_num

    def __getstate__(self):
        """Return the pickled state: ``(avg_loss, batch_num)``."""
        return self.avg_loss, self.batch_num

    def __setstate__(self, state):
        """Re-initialise the metric, then restore ``(avg_loss, batch_num)``."""
        self.__init__()
        self.avg_loss, self.batch_num = state
|