# The implementation is adopted from mmcv,
# made publicly available under the Apache 2.0 License at
# https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils
import os
import os.path as osp
import sys
from collections.abc import Iterable
from shutil import get_terminal_size

from .timer import Timer


def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
    """Raise an error unless ``filename`` is an existing file.

    Args:
        filename (str): Path to check.
        msg_tmpl (str): Error message template; ``{}`` is replaced by
            ``filename``.

    Raises:
        FileNotFoundError: If ``osp.isfile(filename)`` is false.
    """
    if not osp.isfile(filename):
        raise FileNotFoundError(msg_tmpl.format(filename))


def mkdir_or_exist(dir_name, mode=0o777):
    """Create ``dir_name`` and any missing parents; no-op if it exists.

    Args:
        dir_name (str): Directory to create. An empty string is ignored and
            a leading ``~`` is expanded to the user's home directory.
        mode (int): Permission bits passed to ``os.makedirs``.
    """
    if dir_name == '':
        return
    dir_name = osp.expanduser(dir_name)
    os.makedirs(dir_name, mode=mode, exist_ok=True)


class ProgressBar:
    """A progress bar which can print the progress.

    Two display modes exist. With ``task_num > 0`` a bar with
    ``completed/total``, throughput, elapsed time and ETA is drawn. With
    ``task_num == 0`` (total unknown) only a running count, elapsed time and
    throughput are printed.

    Args:
        task_num (int): Total number of tasks, or 0 if unknown.
        bar_width (int): Maximum width of the bar, in characters.
        start (bool): Whether to print the initial state and start timing
            immediately.
        file: Writable text stream providing ``write`` and ``flush``.
    """

    def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
        self.task_num = task_num
        self.bar_width = bar_width
        self.completed = 0
        self.file = file
        # Set by start(); None means the bar has not been started yet.
        self.timer = None
        if start:
            self.start()

    @property
    def terminal_width(self):
        """Current terminal width in columns."""
        width, _ = get_terminal_size()
        return width

    def start(self):
        """Print the initial (empty) state and start the timer."""
        if self.task_num > 0:
            self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
                            'elapsed: 0s, ETA:')
        else:
            self.file.write('completed: 0, elapsed: 0s')
        self.file.flush()
        self.timer = Timer()

    def update(self, num_tasks=1):
        """Record ``num_tasks`` finished tasks and redraw the progress line.

        Args:
            num_tasks (int): Number of tasks completed since the last call;
                must be positive.
        """
        assert num_tasks > 0
        # A bar created with ``start=False`` has no timer yet; start it on
        # first use instead of failing on the missing attribute.
        if self.timer is None:
            self.start()
        self.completed += num_tasks
        elapsed = self.timer.since_start()
        fps = self.completed / elapsed if elapsed > 0 else float('inf')
        if self.task_num > 0:
            # Clamp to 1.0 so that completing more tasks than announced
            # yields neither a negative ETA nor an overlong bar.
            percentage = min(1.0, self.completed / float(self.task_num))
            eta = int(elapsed * (1 - percentage) / percentage + 0.5)
            # ``{{}}`` leaves a literal ``{}`` placeholder that is filled
            # with the bar characters below.
            msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
                  f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
                  f'ETA: {eta:5}s'

            # Shrink the bar so the whole line fits the terminal, but never
            # below two characters.
            bar_width = min(self.bar_width,
                            int(self.terminal_width - len(msg)) + 2,
                            int(self.terminal_width * 0.6))
            bar_width = max(2, bar_width)
            mark_width = int(bar_width * percentage)
            bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
            self.file.write(msg.format(bar_chars))
        else:
            # Leading ``\r`` returns to the start of the line so this status
            # overwrites the previous one, matching the bar branch above.
            self.file.write(
                f'\rcompleted: {self.completed}, '
                f'elapsed: {int(elapsed + 0.5)}s,'
                f' {fps:.1f} tasks/s')
        self.file.flush()


def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
    """Track the progress of tasks execution with a progress bar.

    Tasks are done with a simple for-loop.

    Args:
        func (callable): The function to be applied to each task.
        tasks (Iterable or tuple[Iterable, int]): The tasks, or a
            ``(tasks, total num)`` pair. Any tuple is interpreted as such a
            pair. An iterable without a length (e.g. a generator) is
            tracked with a running count instead of a bar.
        bar_width (int): Width of progress bar.
        file: Writable text stream the progress is printed to.
        **kwargs: Extra keyword arguments forwarded to every ``func`` call.

    Returns:
        list: The task results.

    Raises:
        TypeError: If ``tasks`` is neither an iterable nor a tuple.
    """
    if isinstance(tasks, tuple):
        assert len(tasks) == 2
        assert isinstance(tasks[0], Iterable)
        assert isinstance(tasks[1], int)
        task_num = tasks[1]
        tasks = tasks[0]
    elif isinstance(tasks, Iterable):
        try:
            task_num = len(tasks)
        except TypeError:
            # No ``__len__`` (generator, iterator): total is unknown, so let
            # the bar fall back to its count-only mode.
            task_num = 0
    else:
        raise TypeError(
            '"tasks" must be an iterable object or a (iterator, int) tuple')
    prog_bar = ProgressBar(task_num, bar_width, file=file)
    results = []
    for task in tasks:
        results.append(func(task, **kwargs))
        prog_bar.update()
    prog_bar.file.write('\n')
    return results