util.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107
# The implementation is adapted from mmcv,
  2. # made publicly available under the Apache 2.0 License at
  3. # https://github.com/open-mmlab/mmcv/blob/master/mmcv/utils
  4. import os
  5. import os.path as osp
  6. import sys
  7. from collections.abc import Iterable
  8. from shutil import get_terminal_size
  9. from .timer import Timer
  10. def check_file_exist(filename, msg_tmpl='file "{}" does not exist'):
  11. if not osp.isfile(filename):
  12. raise FileNotFoundError(msg_tmpl.format(filename))
  13. def mkdir_or_exist(dir_name, mode=0o777):
  14. if dir_name == '':
  15. return
  16. dir_name = osp.expanduser(dir_name)
  17. os.makedirs(dir_name, mode=mode, exist_ok=True)
  18. class ProgressBar:
  19. """A progress bar which can print the progress."""
  20. def __init__(self, task_num=0, bar_width=50, start=True, file=sys.stdout):
  21. self.task_num = task_num
  22. self.bar_width = bar_width
  23. self.completed = 0
  24. self.file = file
  25. if start:
  26. self.start()
  27. @property
  28. def terminal_width(self):
  29. width, _ = get_terminal_size()
  30. return width
  31. def start(self):
  32. if self.task_num > 0:
  33. self.file.write(f'[{" " * self.bar_width}] 0/{self.task_num}, '
  34. 'elapsed: 0s, ETA:')
  35. else:
  36. self.file.write('completed: 0, elapsed: 0s')
  37. self.file.flush()
  38. self.timer = Timer()
  39. def update(self, num_tasks=1):
  40. assert num_tasks > 0
  41. self.completed += num_tasks
  42. elapsed = self.timer.since_start()
  43. if elapsed > 0:
  44. fps = self.completed / elapsed
  45. else:
  46. fps = float('inf')
  47. if self.task_num > 0:
  48. percentage = self.completed / float(self.task_num)
  49. eta = int(elapsed * (1 - percentage) / percentage + 0.5)
  50. msg = f'\r[{{}}] {self.completed}/{self.task_num}, ' \
  51. f'{fps:.1f} task/s, elapsed: {int(elapsed + 0.5)}s, ' \
  52. f'ETA: {eta:5}s'
  53. bar_width = min(self.bar_width,
  54. int(self.terminal_width - len(msg)) + 2,
  55. int(self.terminal_width * 0.6))
  56. bar_width = max(2, bar_width)
  57. mark_width = int(bar_width * percentage)
  58. bar_chars = '>' * mark_width + ' ' * (bar_width - mark_width)
  59. self.file.write(msg.format(bar_chars))
  60. else:
  61. self.file.write(
  62. f'completed: {self.completed}, elapsed: {int(elapsed + 0.5)}s,'
  63. f' {fps:.1f} tasks/s')
  64. self.file.flush()
  65. def track_progress(func, tasks, bar_width=50, file=sys.stdout, **kwargs):
  66. """Track the progress of tasks execution with a progress bar.
  67. Tasks are done with a simple for-loop.
  68. Args:
  69. func (callable): The function to be applied to each task.
  70. tasks (list or tuple[Iterable, int]): A list of tasks or
  71. (tasks, total num).
  72. bar_width (int): Width of progress bar.
  73. Returns:
  74. list: The task results.
  75. """
  76. if isinstance(tasks, tuple):
  77. assert len(tasks) == 2
  78. assert isinstance(tasks[0], Iterable)
  79. assert isinstance(tasks[1], int)
  80. task_num = tasks[1]
  81. tasks = tasks[0]
  82. elif isinstance(tasks, Iterable):
  83. task_num = len(tasks)
  84. else:
  85. raise TypeError(
  86. '"tasks" must be an iterable object or a (iterator, int) tuple')
  87. prog_bar = ProgressBar(task_num, bar_width, file=file)
  88. results = []
  89. for task in tasks:
  90. results.append(func(task, **kwargs))
  91. prog_bar.update()
  92. prog_bar.file.write('\n')
  93. return results