metrics.py 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # Adapted from score written by wkentaro
  2. # https://github.com/wkentaro/pytorch-fcn/blob/master/torchfcn/utils.py
  3. import numpy as np
  4. class runningScore(object):
  5. def __init__(self, n_classes):
  6. self.n_classes = n_classes
  7. self.confusion_matrix = np.zeros((n_classes, n_classes))
  8. def _fast_hist(self, label_true, label_pred, n_class):
  9. mask = (label_true >= 0) & (label_true < n_class)
  10. if np.sum((label_pred[mask] < 0)) > 0:
  11. print(label_pred[label_pred < 0])
  12. hist = np.bincount(
  13. n_class * label_true[mask].astype(int) + label_pred[mask],
  14. minlength=n_class**2,
  15. ).reshape(n_class, n_class)
  16. return hist
  17. def update(self, label_trues, label_preds):
  18. # print label_trues.dtype, label_preds.dtype
  19. for lt, lp in zip(label_trues, label_preds):
  20. try:
  21. self.confusion_matrix += self._fast_hist(
  22. lt.flatten(), lp.flatten(), self.n_classes
  23. )
  24. except:
  25. pass
  26. def get_scores(self):
  27. """Returns accuracy score evaluation result.
  28. - overall accuracy
  29. - mean accuracy
  30. - mean IU
  31. - fwavacc
  32. """
  33. hist = self.confusion_matrix
  34. acc = np.diag(hist).sum() / (hist.sum() + 0.0001)
  35. acc_cls = np.diag(hist) / (hist.sum(axis=1) + 0.0001)
  36. acc_cls = np.nanmean(acc_cls)
  37. iu = np.diag(hist) / (
  38. hist.sum(axis=1) + hist.sum(axis=0) - np.diag(hist) + 0.0001
  39. )
  40. mean_iu = np.nanmean(iu)
  41. freq = hist.sum(axis=1) / (hist.sum() + 0.0001)
  42. fwavacc = (freq[freq > 0] * iu[freq > 0]).sum()
  43. cls_iu = dict(zip(range(self.n_classes), iu))
  44. return {
  45. "Overall Acc": acc,
  46. "Mean Acc": acc_cls,
  47. "FreqW Acc": fwavacc,
  48. "Mean IoU": mean_iu,
  49. }, cls_iu
  50. def reset(self):
  51. self.confusion_matrix = np.zeros((self.n_classes, self.n_classes))