metric.py 1.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. import paddle
  2. import paddle.nn.functional as F
  3. from collections import OrderedDict
  4. def create_metric(
  5. out,
  6. label,
  7. architecture=None,
  8. topk=5,
  9. classes_num=1000,
  10. use_distillation=False,
  11. mode="train",
  12. ):
  13. """
  14. Create measures of model accuracy, such as top1 and top5
  15. Args:
  16. out(variable): model output variable
  17. feeds(dict): dict of model input variables(included label)
  18. topk(int): usually top5
  19. classes_num(int): num of classes
  20. use_distillation(bool): whether to use distillation training
  21. mode(str): mode, train/valid
  22. Returns:
  23. fetches(dict): dict of measures
  24. """
  25. # if architecture["name"] == "GoogLeNet":
  26. # assert len(out) == 3, "GoogLeNet should have 3 outputs"
  27. # out = out[0]
  28. # else:
  29. # # just need student label to get metrics
  30. # if use_distillation:
  31. # out = out[1]
  32. softmax_out = F.softmax(out)
  33. fetches = OrderedDict()
  34. # set top1 to fetches
  35. top1 = paddle.metric.accuracy(softmax_out, label=label, k=1)
  36. # set topk to fetches
  37. k = min(topk, classes_num)
  38. topk = paddle.metric.accuracy(softmax_out, label=label, k=k)
  39. # multi cards' eval
  40. if mode != "train" and paddle.distributed.get_world_size() > 1:
  41. top1 = (
  42. paddle.distributed.all_reduce(top1, op=paddle.distributed.ReduceOp.SUM)
  43. / paddle.distributed.get_world_size()
  44. )
  45. topk = (
  46. paddle.distributed.all_reduce(topk, op=paddle.distributed.ReduceOp.SUM)
  47. / paddle.distributed.get_world_size()
  48. )
  49. fetches["top1"] = top1
  50. topk_name = "top{}".format(k)
  51. fetches[topk_name] = topk
  52. return fetches