log_buffer.py 1.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142
  1. # Copyright (c) OpenMMLab. All rights reserved.
  2. # Copyright (c) Alibaba, Inc. and its affiliates.
  3. from collections import OrderedDict
  4. import numpy as np
  5. class LogBuffer:
  6. def __init__(self):
  7. self.val_history = OrderedDict()
  8. self.n_history = OrderedDict()
  9. self.output = OrderedDict()
  10. self.ready = False
  11. def clear(self) -> None:
  12. self.val_history.clear()
  13. self.n_history.clear()
  14. self.clear_output()
  15. def clear_output(self) -> None:
  16. self.output.clear()
  17. self.ready = False
  18. def update(self, vars: dict, count: int = 1) -> None:
  19. assert isinstance(vars, dict)
  20. for key, var in vars.items():
  21. if key not in self.val_history:
  22. self.val_history[key] = []
  23. self.n_history[key] = []
  24. self.val_history[key].append(var)
  25. self.n_history[key].append(count)
  26. def average(self, n: int = 0) -> None:
  27. """Average latest n values or all values."""
  28. assert n >= 0
  29. for key in self.val_history:
  30. values = np.array(self.val_history[key][-n:])
  31. nums = np.array(self.n_history[key][-n:])
  32. avg = np.sum(values * nums) / np.sum(nums)
  33. self.output[key] = avg
  34. self.ready = True