| 123456789101112131415161718192021222324252627282930313233343536373839404142 |
- # Copyright (c) OpenMMLab. All rights reserved.
- # Copyright (c) Alibaba, Inc. and its affiliates.
- from collections import OrderedDict
- import numpy as np
class LogBuffer:
    """Accumulate logged values per key and compute count-weighted averages.

    Each call to :meth:`update` appends one value per key to ``val_history``
    and the matching weight (``count``) to ``n_history``. :meth:`average`
    writes the weighted mean of each key's recent history into ``output``
    and sets ``ready`` to True.
    """

    def __init__(self):
        # Per-key list of raw logged values, in the order they were added.
        self.val_history = OrderedDict()
        # Per-key list of weights, parallel to the lists in ``val_history``.
        self.n_history = OrderedDict()
        # Per-key averaged results produced by ``average``.
        self.output = OrderedDict()
        # True once ``average`` has filled ``output``; reset by ``clear_output``.
        self.ready = False

    def clear(self) -> None:
        """Drop all recorded history and any averaged output."""
        self.val_history.clear()
        self.n_history.clear()
        self.clear_output()

    def clear_output(self) -> None:
        """Drop the averaged output but keep the recorded history."""
        self.output.clear()
        self.ready = False

    def update(self, vars: dict, count: int = 1) -> None:
        """Record one value per key, all sharing the same weight.

        Args:
            vars: Mapping from key to the value to record. (The name shadows
                the ``vars`` builtin but is kept for keyword-argument callers.)
            count: Weight attached to every value in this update.
        """
        assert isinstance(vars, dict)
        for key, var in vars.items():
            if key not in self.val_history:
                self.val_history[key] = []
                self.n_history[key] = []
            self.val_history[key].append(var)
            self.n_history[key].append(count)

    def average(self, n: int = 0) -> None:
        """Average latest n values or all values.

        For every key, computes ``sum(value_i * count_i) / sum(count_i)`` over
        the selected entries and stores it in ``output[key]``. Array-valued
        entries (all of the same shape) are averaged element-wise. If the
        selected weights sum to zero, numpy's division yields nan/inf with a
        RuntimeWarning.

        Args:
            n: Number of most recent entries per key to average. ``0`` means
                all entries, because the slice ``[-0:]`` equals ``[0:]``.
        """
        assert n >= 0
        for key in self.val_history:
            values = np.array(self.val_history[key][-n:])
            nums = np.array(self.n_history[key][-n:])
            # Align weights with axis 0 (one weight per logged entry) so that
            # array-valued entries are not broadcast against their last axis.
            # For scalar entries values.ndim == 1 and this is a no-op reshape.
            weights = nums.reshape((-1,) + (1,) * (values.ndim - 1))
            avg = np.sum(values * weights, axis=0) / np.sum(nums)
            self.output[key] = avg
        self.ready = True
|