metric.py 2.3 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798
  1. # Copyright (c) Megvii Inc. All rights reserved.
  2. # Copyright © Alibaba, Inc. and its affiliates.
  3. import functools
  4. import os
  5. from collections import defaultdict, deque
  6. import numpy as np
  7. import torch
  # Public names exported by ``from <this module> import *``.
  __all__ = ["AverageMeter", "MeterBuffer", "gpu_mem_usage"]
  def gpu_mem_usage(device=None):
      """Return the peak GPU memory allocated by tensors, in MiB.

      This reports ``torch.cuda.max_memory_allocated``: the high-water mark of
      tensor allocations (by default since program start, or since the peak
      statistics were last reset), not the instantaneous usage.

      Args:
          device: CUDA device to query. ``None`` (the default) queries the
              current device, which is what the zero-argument call did before.

      Returns:
          float: Peak allocated memory in MiB (bytes / 1024**2).
      """
      peak_bytes = torch.cuda.max_memory_allocated(device)
      return peak_bytes / (1024 * 1024)
  class AverageMeter:
      """Track a stream of values with a sliding window plus global statistics.

      The most recent ``window_size`` values are kept in a bounded deque and
      back the windowed statistics (``median``, ``avg``, ``latest``). A running
      sum and count over every value passed to ``update`` back ``global_avg``
      and ``total``.

      Args:
          window_size: Maximum number of recent values kept for the windowed
              statistics. Defaults to 50.
      """

      def __init__(self, window_size=50):
          self._deque = deque(maxlen=window_size)
          self._total = 0.0
          self._count = 0

      def update(self, value):
          """Record ``value`` in the window and in the running sum/count."""
          self._deque.append(value)
          self._count += 1
          self._total += value

      @property
      def median(self):
          """Median of the values currently in the window (nan if empty)."""
          if not self._deque:
              # numpy would also give nan here, but with a RuntimeWarning.
              return np.nan
          return np.median(np.array(list(self._deque)))

      @property
      def avg(self):
          """Arithmetic mean of the values currently in the window (nan if empty)."""
          if not self._deque:
              # Same nan numpy would return, minus its "Mean of empty slice" warning.
              return np.nan
          return np.array(list(self._deque)).mean()

      @property
      def global_avg(self):
          """Mean of every value recorded since construction or the last ``reset``.

          Returns 0.0 before any value has been recorded.
          """
          if self._count == 0:
              return 0.0
          return self._total / self._count

      @property
      def latest(self):
          """Most recently recorded value still in the window, or None if empty."""
          return self._deque[-1] if self._deque else None

      @property
      def total(self):
          """Running sum of every value recorded since construction or ``reset``."""
          return self._total

      def reset(self):
          """Forget everything: empty the window and zero the running sum/count."""
          self._deque.clear()
          self._total = 0.0
          self._count = 0

      def clear(self):
          """Empty the window only; ``total`` and ``global_avg`` are preserved."""
          self._deque.clear()
  class MeterBuffer(defaultdict):
      """A dict of named ``AverageMeter`` objects, created on first access.

      Looking up a missing name creates a new ``AverageMeter`` using this
      buffer's ``window_size``.

      Args:
          window_size: Window size given to every meter this buffer creates.
              Defaults to 20.
      """

      def __init__(self, window_size=20):
          factory = functools.partial(AverageMeter, window_size=window_size)
          super().__init__(factory)

      def reset(self):
          """Call ``reset()`` on every meter (window, sum and count)."""
          for meter in self.values():
              meter.reset()

      def get_filtered_meter(self, filter_key='time'):
          """Return a new dict of the meters whose name contains ``filter_key``.

          Matching is a plain substring test, e.g. the default ``'time'``
          matches any name containing ``'time'``.
          """
          return {name: meter for name, meter in self.items() if filter_key in name}

      def update(self, values=None, **kwargs):
          """Record one value per name into the corresponding meter.

          Args:
              values: Optional mapping of meter name to value. It is NOT
                  modified.
              **kwargs: Additional ``name=value`` pairs; on duplicate names
                  they take precedence over ``values``.

          Tensor values are detached from the autograd graph before storage.
          NOTE(review): detached tensors keep their device, and
          ``AverageMeter.avg``/``median`` convert through numpy. Confirm callers
          never read those stats for CUDA-tensor meters.
          """
          # Merge into a fresh dict: updating ``values`` in place would leak
          # kwargs into the caller's dict.
          merged = {} if values is None else dict(values)
          merged.update(kwargs)
          for name, value in merged.items():
              if isinstance(value, torch.Tensor):
                  value = value.detach()
              self[name].update(value)

      def clear_meters(self):
          """Call ``clear()`` on every meter (empties windows, keeps running totals)."""
          for meter in self.values():
              meter.clear()