| 1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495969798 |
- # Copyright (c) Megvii Inc. All rights reserved.
- # Copyright © Alibaba, Inc. and its affiliates.
- import functools
- import os
- from collections import defaultdict, deque
- import numpy as np
- import torch
# Names exported by ``from <module> import *``.
__all__ = ['AverageMeter', 'MeterBuffer', 'gpu_mem_usage']
def gpu_mem_usage(device=None):
    """Return the peak GPU memory occupied by tensors, in MiB.

    This reports ``torch.cuda.max_memory_allocated`` -- the high-water mark of
    tensor allocations since the program started (or since the peak stats were
    last reset with ``torch.cuda.reset_peak_memory_stats``). It is not the
    instantaneous usage.

    Args:
        device: Device to query; anything ``torch.cuda.max_memory_allocated``
            accepts. ``None`` (the default) queries the current device, which
            matches the previous behaviour of this function.

    Returns:
        float: Peak allocated memory in MiB (bytes / 1024**2).
    """
    peak_bytes = torch.cuda.max_memory_allocated(device)
    return peak_bytes / (1024 * 1024)
class AverageMeter:
    """Track a series of values and expose smoothed statistics.

    Two views are kept:

    * A bounded window of the most recent ``window_size`` values. It backs
      :attr:`median`, :attr:`avg` and :attr:`latest`.
    * A running sum and count over every value passed to :meth:`update` since
      construction or the last :meth:`reset`. They back :attr:`total` and
      :attr:`global_avg`.

    Args:
        window_size (int): Maximum number of recent values kept in the window.
    """

    def __init__(self, window_size=50):
        self._deque = deque(maxlen=window_size)
        self._total = 0.0
        self._count = 0

    def update(self, value):
        """Record ``value`` in the window and in the running sum/count."""
        self._deque.append(value)
        self._count += 1
        self._total += value

    @property
    def median(self):
        """Median of the values in the window, or NaN if the window is empty."""
        if not self._deque:
            # numpy would also yield NaN here, but it warns on every call.
            return np.nan
        return np.median(np.array(list(self._deque)))

    @property
    def avg(self):
        """Mean of the values in the window, or NaN if the window is empty."""
        if not self._deque:
            # Avoid numpy's "Mean of empty slice" RuntimeWarning.
            return np.nan
        return np.array(list(self._deque)).mean()

    @property
    def global_avg(self):
        """Mean of all values recorded since construction or :meth:`reset`.

        Returns 0.0 when nothing has been recorded yet.
        """
        if self._count == 0:
            return 0.0
        return self._total / self._count

    @property
    def latest(self):
        """Most recently recorded value still in the window, or ``None``."""
        return self._deque[-1] if len(self._deque) > 0 else None

    @property
    def total(self):
        """Sum of all values recorded since construction or :meth:`reset`."""
        return self._total

    def reset(self):
        """Empty the window and zero the running sum and count."""
        self._deque.clear()
        self._total = 0.0
        self._count = 0

    def clear(self):
        """Empty the window only; the running sum and count are kept."""
        self._deque.clear()
class MeterBuffer(defaultdict):
    """A ``defaultdict`` of ``AverageMeter`` objects keyed by metric name.

    Looking up a missing name creates a new ``AverageMeter`` built with this
    buffer's ``window_size``, so metrics can be recorded without being
    registered first.

    Args:
        window_size (int): Window size given to every meter this buffer creates.
    """

    def __init__(self, window_size=20):
        factory = functools.partial(AverageMeter, window_size=window_size)
        super().__init__(factory)

    def reset(self):
        """Call ``reset()`` on every meter currently in the buffer."""
        for meter in self.values():
            meter.reset()

    def get_filtered_meter(self, filter_key='time'):
        """Return a new dict of the meters whose name contains ``filter_key``."""
        return {name: meter for name, meter in self.items() if filter_key in name}

    def update(self, values=None, **kwargs):
        """Record one value per metric name.

        Args:
            values (dict, optional): Mapping of metric name to value.
            **kwargs: Extra name/value pairs. They override entries in
                ``values`` that have the same name.

        Tensor values are detached before being recorded. The caller's
        ``values`` mapping is left unmodified.
        """
        # Merge into a fresh dict; merging into ``values`` in place would leak
        # ``kwargs`` into the caller's dictionary.
        merged = dict(values) if values is not None else {}
        merged.update(kwargs)
        for name, value in merged.items():
            if isinstance(value, torch.Tensor):
                # Detach so recorded tensors hold no reference to the autograd graph.
                value = value.detach()
            self[name].update(value)

    def clear_meters(self):
        """Call ``clear()`` on every meter currently in the buffer."""
        for meter in self.values():
            meter.clear()
|