progress.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358
  1. """Defines an object for printing run progress at the end of a script."""
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import time
  6. from typing import Iterator, NoReturn
  7. from wandb.proto import wandb_internal_pb2 as pb
  8. from wandb.sdk.interface import interface
  9. from wandb.sdk.lib import asyncio_compat
  10. from . import printer as p
# Indentation prepended to nested lines: the contents of a labeled group and
# each additional level of subtasks.
_INDENT = " "

# Line budget passed to the renderers that fill the dynamic text area.
_MAX_LINES_TO_PRINT = 6

# Maximum number of operation descriptions included in the single line of
# static text printed when no dynamic text area is available.
_MAX_OPS_TO_PRINT = 5
  14. async def loop_printing_operation_stats(
  15. progress: ProgressPrinter,
  16. interface: interface.InterfaceBase,
  17. ) -> None:
  18. """Poll and display ongoing tasks in the internal service process.
  19. This never returns and must be cancelled. This is meant to be used with
  20. `mailbox.wait_with_progress()`.
  21. Args:
  22. progress: The printer to update with operation stats.
  23. interface: The interface to use to poll for updates.
  24. Raises:
  25. HandleAbandonedError: If the mailbox associated with the interface
  26. becomes closed.
  27. Exception: Any other problem communicating with the service process.
  28. """
  29. stats: pb.OperationStats | None = None
  30. async def loop_update_screen() -> NoReturn:
  31. while True:
  32. if stats:
  33. progress.update(stats)
  34. await asyncio.sleep(0.1)
  35. async def loop_poll_stats() -> NoReturn:
  36. nonlocal stats
  37. while True:
  38. start_time = time.monotonic()
  39. handle = await interface.deliver_async(
  40. pb.Record(
  41. request=pb.Request(operations=pb.OperationStatsRequest()),
  42. )
  43. )
  44. result = await handle.wait_async(timeout=None)
  45. stats = result.response.operations_response.operation_stats
  46. elapsed_time = time.monotonic() - start_time
  47. if elapsed_time < 0.5:
  48. await asyncio.sleep(0.5 - elapsed_time)
  49. async with asyncio_compat.open_task_group() as task_group:
  50. task_group.start_soon(loop_update_screen())
  51. task_group.start_soon(loop_poll_stats())
  52. @contextlib.contextmanager
  53. def progress_printer(
  54. printer: p.Printer,
  55. default_text: str,
  56. ) -> Iterator[ProgressPrinter]:
  57. """Context manager providing an object for printing run progress.
  58. Args:
  59. printer: The printer to use.
  60. default_text: The text to show if no information is available.
  61. """
  62. with printer.dynamic_text() as text_area:
  63. try:
  64. yield ProgressPrinter(
  65. printer,
  66. text_area,
  67. default_text=default_text,
  68. )
  69. finally:
  70. printer.progress_close()
  71. class ProgressPrinter:
  72. """Displays PollExitResponse results to the user."""
  73. def __init__(
  74. self,
  75. printer: p.Printer,
  76. progress_text_area: p.DynamicText | None,
  77. default_text: str,
  78. ) -> None:
  79. self._printer = printer
  80. self._progress_text_area = progress_text_area
  81. self._default_text = default_text
  82. self._tick = -1
  83. self._last_printed_line = ""
  84. def update(
  85. self,
  86. stats_or_list: pb.OperationStats | list[pb.OperationStats],
  87. ) -> None:
  88. """Update the displayed information.
  89. Args:
  90. stats_or_list: A single group of operations, or zero or more
  91. labeled operation groups.
  92. """
  93. self._tick += 1
  94. if not self._progress_text_area:
  95. line = self._to_static_text(stats_or_list)
  96. if line and line != self._last_printed_line:
  97. self._printer.display(line)
  98. self._last_printed_line = line
  99. return
  100. lines = self._to_dynamic_text(stats_or_list)
  101. if not lines:
  102. loading_symbol = self._printer.loading_symbol(self._tick)
  103. if loading_symbol:
  104. lines = [f"{loading_symbol} {self._default_text}"]
  105. else:
  106. lines = [self._default_text]
  107. self._progress_text_area.set_text("\n".join(lines))
  108. def _to_dynamic_text(
  109. self,
  110. stats_or_list: pb.OperationStats | list[pb.OperationStats],
  111. ) -> list[str]:
  112. """Returns text to show in a dynamic text area."""
  113. loading_symbol = self._printer.loading_symbol(self._tick)
  114. if isinstance(stats_or_list, list):
  115. return _GroupedOperationStatsPrinter(
  116. self._printer,
  117. _MAX_LINES_TO_PRINT,
  118. loading_symbol,
  119. ).render(stats_or_list)
  120. else:
  121. return _OperationStatsPrinter(
  122. self._printer,
  123. _MAX_LINES_TO_PRINT,
  124. loading_symbol,
  125. ).render(stats_or_list)
  126. def _to_static_text(
  127. self,
  128. stats_or_list: pb.OperationStats | list[pb.OperationStats],
  129. ) -> str:
  130. """Returns a single line of text to print out."""
  131. if isinstance(stats_or_list, list):
  132. sorted_prefixed_stats = list(
  133. (f"[{stats.label}] " if stats.label else "", stats)
  134. for stats in sorted(stats_or_list, key=lambda s: s.label)
  135. )
  136. else:
  137. sorted_prefixed_stats = [("", stats_or_list)]
  138. group_strs: list[str] = []
  139. total_operations = 0
  140. total_printed = 0
  141. for prefix, stats in sorted_prefixed_stats:
  142. total_operations += stats.total_operations
  143. if not stats.operations:
  144. continue
  145. group_ops: list[str] = []
  146. i = 0
  147. while total_printed < _MAX_OPS_TO_PRINT and i < len(stats.operations):
  148. group_ops.append(stats.operations[i].desc)
  149. total_printed += 1
  150. i += 1
  151. if group_ops:
  152. group_strs.append(prefix + "; ".join(group_ops))
  153. line = "; ".join(group_strs)
  154. remaining = total_operations - total_printed
  155. if total_printed > 0 and remaining > 0:
  156. line += f" (+ {remaining} more)"
  157. return line
  158. class _GroupedOperationStatsPrinter:
  159. """Renders a list of labeled operation stats groups into lines of text."""
  160. def __init__(
  161. self,
  162. printer: p.Printer,
  163. max_lines: int,
  164. loading_symbol: str,
  165. ) -> None:
  166. self._printer = printer
  167. self._max_lines = max_lines
  168. self._loading_symbol = loading_symbol
  169. def render(self, stats_list: list[pb.OperationStats]) -> list[str]:
  170. """Convert labeled operation stats groups into text to display.
  171. Args:
  172. stats_list: A list of labeled operation stats.
  173. Returns:
  174. The lines of text to print. The lines do not end with the newline
  175. character. Returns an empty list if there are no operations.
  176. """
  177. lines: list[str] = []
  178. for stats in sorted(stats_list, key=lambda s: s.label):
  179. # Don't display empty groups.
  180. if not stats.operations:
  181. continue
  182. if stats.label:
  183. remaining_non_header_lines = self._max_lines - len(lines) - 1
  184. indent = _INDENT
  185. else:
  186. remaining_non_header_lines = self._max_lines - len(lines)
  187. indent = ""
  188. # Ensure enough space left for at least one line of content
  189. # after the header.
  190. if remaining_non_header_lines < 1:
  191. break
  192. # Group header (if not empty).
  193. if stats.label:
  194. lines.append(stats.label)
  195. # Group content.
  196. stats_lines = _OperationStatsPrinter(
  197. printer=self._printer,
  198. max_lines=remaining_non_header_lines,
  199. loading_symbol=self._loading_symbol,
  200. ).render(stats)
  201. for line in stats_lines:
  202. lines.append(f"{indent}{line}")
  203. return lines
  204. class _OperationStatsPrinter:
  205. """Renders operation stats into lines of text."""
  206. def __init__(
  207. self,
  208. printer: p.Printer,
  209. max_lines: int,
  210. loading_symbol: str,
  211. ) -> None:
  212. self._printer = printer
  213. self._max_lines = max_lines
  214. self._loading_symbol = loading_symbol
  215. self._lines: list[str] = []
  216. self._ops_shown = 0
  217. def render(self, stats: pb.OperationStats) -> list[str]:
  218. """Convert the stats into a list of lines to display.
  219. Args:
  220. stats: Collection of operations to display.
  221. Returns:
  222. The lines of text to print. The lines do not end with the newline
  223. character. Returns an empty list if there are no operations.
  224. """
  225. for op in stats.operations:
  226. self._add_operation(op, is_subtask=False, indent="")
  227. if self._ops_shown < stats.total_operations:
  228. if 1 <= self._max_lines <= len(self._lines):
  229. self._ops_shown -= 1
  230. self._lines.pop()
  231. remaining = stats.total_operations - self._ops_shown
  232. self._lines.append(f"+ {remaining} more task(s)")
  233. return self._lines
  234. def _add_operation(self, op: pb.Operation, is_subtask: bool, indent: str) -> None:
  235. """Add the operation to `self._lines`."""
  236. if len(self._lines) >= self._max_lines:
  237. return
  238. if not is_subtask:
  239. self._ops_shown += 1
  240. status_indent_level = 0 # alignment for the status message, if any
  241. parts: list[str] = []
  242. # Subtask indicator.
  243. if is_subtask and self._printer.supports_unicode:
  244. status_indent_level += 2 # +1 for space
  245. parts.append("↳")
  246. # Loading symbol.
  247. if self._loading_symbol:
  248. status_indent_level += 2 # +1 for space
  249. parts.append(self._loading_symbol)
  250. # Task name.
  251. parts.append(op.desc)
  252. # Progress information.
  253. if op.progress:
  254. parts.append(f"{op.progress}")
  255. # Task duration.
  256. parts.append(f"({_time_to_string(seconds=op.runtime_seconds)})")
  257. # Error status.
  258. self._lines.append(indent + " ".join(parts))
  259. if op.error_status:
  260. error_word = self._printer.error("ERROR")
  261. error_desc = self._printer.secondary_text(op.error_status)
  262. status_indent = " " * status_indent_level
  263. self._lines.append(
  264. f"{indent}{status_indent}{error_word} {error_desc}",
  265. )
  266. # Subtasks.
  267. if op.subtasks:
  268. subtask_indent = indent + _INDENT
  269. for task in op.subtasks:
  270. self._add_operation(
  271. task,
  272. is_subtask=True,
  273. indent=subtask_indent,
  274. )
  275. def _time_to_string(seconds: float) -> str:
  276. """Returns a short string representing the duration."""
  277. if seconds < 10:
  278. return f"{seconds:.1f}s"
  279. if seconds < 60:
  280. return f"{seconds:.0f}s"
  281. if seconds < 60 * 60:
  282. minutes = seconds / 60
  283. return f"{minutes:.1f}m"
  284. hours = int(seconds / (60 * 60))
  285. minutes = int((seconds / 60) % 60)
  286. return f"{hours}h{minutes}m"