progressbar.py 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  14. import os
  15. import struct
  16. import sys
  17. import time
  18. import numpy as np
# An empty __all__ means ``from <this module> import *`` exports no names;
# users must import ProgressBar by name.
__all__ = []
  20. class ProgressBar:
  21. """progress bar"""
  22. def __init__(
  23. self,
  24. num=None,
  25. width=30,
  26. verbose=1,
  27. start=True,
  28. file=sys.stdout,
  29. name='step',
  30. ):
  31. self._num = num
  32. if isinstance(num, int) and num <= 0:
  33. raise TypeError('num should be None or integer (> 0)')
  34. max_width = self._get_max_width()
  35. self._width = width if width <= max_width else max_width
  36. self._total_width = 0
  37. self._verbose = verbose
  38. self.file = file
  39. self._values = {}
  40. self._values_order = []
  41. if start:
  42. self._start = time.time()
  43. self._last_update = 0
  44. self.name = name
  45. self._dynamic_display = (
  46. (hasattr(self.file, 'isatty') and self.file.isatty())
  47. or 'ipykernel' in sys.modules
  48. or 'posix' in sys.modules
  49. or 'PYCHARM_HOSTED' in os.environ
  50. )
  51. def _get_max_width(self):
  52. from shutil import get_terminal_size
  53. terminal_width, _ = get_terminal_size()
  54. terminal_width = terminal_width if terminal_width > 0 else 80
  55. max_width = min(int(terminal_width * 0.6), terminal_width - 50)
  56. return max_width
  57. def start(self):
  58. self.file.flush()
  59. self._start = time.time()
  60. def update(self, current_num, values={}):
  61. now = time.time()
  62. def convert_uint16_to_float(in_list):
  63. in_list = np.asarray(in_list)
  64. out = np.vectorize(
  65. lambda x: struct.unpack('<f', struct.pack('<I', x << 16))[0],
  66. otypes=[np.float32],
  67. )(in_list.flat)
  68. return np.reshape(out, in_list.shape)
  69. for i, (k, val) in enumerate(values):
  70. if k == "loss":
  71. if isinstance(val, list):
  72. scalar_val = val[0]
  73. else:
  74. scalar_val = val
  75. if isinstance(scalar_val, np.uint16):
  76. values[i] = ("loss", list(convert_uint16_to_float(val)))
  77. if current_num:
  78. time_per_unit = (now - self._start) / current_num
  79. else:
  80. time_per_unit = 0
  81. if time_per_unit >= 1 or time_per_unit == 0:
  82. fps = f' - {time_per_unit:.0f}s/{self.name}'
  83. elif time_per_unit >= 1e-3:
  84. fps = f' - {time_per_unit * 1e3:.0f}ms/{self.name}'
  85. else:
  86. fps = f' - {time_per_unit * 1e6:.0f}us/{self.name}'
  87. info = ''
  88. if self._verbose == 1:
  89. prev_total_width = self._total_width
  90. if self._dynamic_display:
  91. sys.stdout.write('\b' * prev_total_width)
  92. sys.stdout.write('\r')
  93. else:
  94. sys.stdout.write('\n')
  95. if self._num is not None:
  96. numdigits = int(np.log10(self._num)) + 1
  97. bar_chars = (self.name + ' %' + str(numdigits) + 'd/%d [') % (
  98. current_num,
  99. self._num,
  100. )
  101. prog = float(current_num) / self._num
  102. prog_width = int(self._width * prog)
  103. if prog_width > 0:
  104. bar_chars += '=' * (prog_width - 1)
  105. if current_num < self._num:
  106. bar_chars += '>'
  107. else:
  108. bar_chars += '='
  109. bar_chars += '.' * (self._width - prog_width)
  110. bar_chars += ']'
  111. else:
  112. bar_chars = self.name + ' %3d' % current_num
  113. self._total_width = len(bar_chars)
  114. sys.stdout.write(bar_chars)
  115. for k, val in values:
  116. info += ' - %s:' % k
  117. val = val if isinstance(val, list) else [val]
  118. for i, v in enumerate(val):
  119. if isinstance(v, (float, np.float32, np.float64)):
  120. if abs(v) > 1e-3:
  121. info += ' %.4f' % v
  122. else:
  123. info += ' %.4e' % v
  124. else:
  125. info += ' %s' % v
  126. if self._num is not None and current_num < self._num:
  127. eta = time_per_unit * (self._num - current_num)
  128. if eta > 3600:
  129. eta_format = '%d:%02d:%02d' % (
  130. eta // 3600,
  131. (eta % 3600) // 60,
  132. eta % 60,
  133. )
  134. elif eta > 60:
  135. eta_format = '%d:%02d' % (eta // 60, eta % 60)
  136. else:
  137. eta_format = '%ds' % eta
  138. info += ' - ETA: %s' % eta_format
  139. info += fps
  140. self._total_width += len(info)
  141. if prev_total_width > self._total_width:
  142. info += ' ' * (prev_total_width - self._total_width)
  143. # newline for another epoch
  144. if self._num is not None and current_num >= self._num:
  145. info += '\n'
  146. if self._num is None:
  147. info += '\n'
  148. sys.stdout.write(info)
  149. sys.stdout.flush()
  150. self._last_update = now
  151. elif self._verbose == 2 or self._verbose == 3:
  152. if self._num:
  153. numdigits = int(np.log10(self._num)) + 1
  154. count = (self.name + ' %' + str(numdigits) + 'd/%d') % (
  155. current_num,
  156. self._num,
  157. )
  158. else:
  159. count = self.name + ' %3d' % current_num
  160. info = count + info
  161. for k, val in values:
  162. info += ' - %s:' % k
  163. val = val if isinstance(val, list) else [val]
  164. for v in val:
  165. if isinstance(v, (float, np.float32, np.float64)):
  166. if abs(v) > 1e-3:
  167. info += ' %.4f' % v
  168. else:
  169. info += ' %.4e' % v
  170. elif (
  171. isinstance(v, np.ndarray)
  172. and v.size == 1
  173. and v.dtype in [np.float32, np.float64]
  174. ):
  175. if abs(v.item()) > 1e-3:
  176. info += ' %.4f' % v.item()
  177. else:
  178. info += ' %.4e' % v.item()
  179. else:
  180. info += ' %s' % v
  181. info += fps
  182. info += '\n'
  183. sys.stdout.write(info)
  184. sys.stdout.flush()