beta_sync.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. """Implements `wandb sync` using wandb-core."""
  2. from __future__ import annotations
  3. import asyncio
  4. import contextlib
  5. import pathlib
  6. import time
  7. from itertools import filterfalse
  8. from typing import Iterable, Iterator
  9. import wandb
  10. from wandb.errors import term
  11. from wandb.proto.wandb_sync_pb2 import ServerSyncResponse
  12. from wandb.sdk import wandb_setup
  13. from wandb.sdk.lib import asyncio_compat
  14. from wandb.sdk.lib.printer import Printer, new_printer
  15. from wandb.sdk.lib.progress import progress_printer
  16. from wandb.sdk.lib.service.service_connection import ServiceConnection
  17. from wandb.sdk.mailbox.mailbox_handle import MailboxHandle
  18. _MAX_LIST_LINES = 20
  19. _POLL_WAIT_SECONDS = 0.1
  20. _SLEEP = asyncio.sleep # patched in tests
  21. def sync(
  22. paths: list[pathlib.Path],
  23. *,
  24. live: bool,
  25. entity: str,
  26. project: str,
  27. run_id: str,
  28. dry_run: bool,
  29. skip_synced: bool,
  30. verbose: bool,
  31. parallelism: int,
  32. ) -> None:
  33. """Replay one or more .wandb files.
  34. Args:
  35. live: Whether to enable 'live' mode, which indefinitely retries reading
  36. incomplete transaction logs.
  37. entity: The entity override for all paths, or an empty string.
  38. project: The project override for all paths, or an empty string.
  39. run_id: The run ID override for all paths, or an empty string.
  40. paths: One or more .wandb files, run directories containing
  41. .wandb files, and wandb directories containing run directories.
  42. dry_run: If true, just prints what it would do and exits.
  43. skip_synced: If true, skips files that have already been synced
  44. as indicated by a .wandb.synced marker file in the same directory.
  45. verbose: Verbose mode for printing more info.
  46. parallelism: Max number of runs to sync at a time.
  47. """
  48. singleton = wandb_setup.singleton()
  49. try:
  50. cwd = pathlib.Path.cwd()
  51. except OSError:
  52. cwd = None
  53. ask_for_confirmation = False
  54. if not paths:
  55. paths = [pathlib.Path(singleton.settings.wandb_dir)]
  56. ask_for_confirmation = True
  57. wandb_files = _to_unique_files(
  58. (
  59. wandb_file
  60. for path in paths
  61. for wandb_file in _find_wandb_files(path, skip_synced=skip_synced)
  62. ),
  63. verbose=verbose,
  64. )
  65. if not wandb_files:
  66. term.termlog("No runs to sync.")
  67. return
  68. if dry_run:
  69. term.termlog(f"Would sync {len(wandb_files)} run(s):")
  70. _print_sorted_paths(wandb_files, verbose=verbose, root=cwd)
  71. return
  72. term.termlog(f"Syncing {len(wandb_files)} run(s):")
  73. _print_sorted_paths(wandb_files, verbose=verbose, root=cwd)
  74. if ask_for_confirmation and not term.confirm("Sync the listed runs?"):
  75. return
  76. service = singleton.ensure_service()
  77. printer = new_printer()
  78. singleton.asyncer.run(
  79. lambda: _do_sync(
  80. wandb_files,
  81. cwd=cwd,
  82. live=live,
  83. service=service,
  84. entity=entity,
  85. project=project,
  86. run_id=run_id,
  87. settings=singleton.settings,
  88. printer=printer,
  89. parallelism=parallelism,
  90. )
  91. )
  92. def _to_unique_files(
  93. paths: Iterator[pathlib.Path],
  94. *,
  95. verbose: bool,
  96. ) -> set[pathlib.Path]:
  97. """Returns paths with duplicates removed.
  98. Determines file equality the same way as os.path.samefile().
  99. """
  100. id_to_path: dict[tuple[int, int], pathlib.Path] = dict()
  101. # Sort in reverse so that the last path written to the map is
  102. # alphabetically earliest.
  103. for path in sorted(paths, reverse=True):
  104. try:
  105. stat = path.stat()
  106. except OSError as e:
  107. term.termerror(f"Failed to stat {path}: {e}")
  108. continue
  109. id = (stat.st_ino, stat.st_dev)
  110. if verbose and (other_path := id_to_path.get(id)):
  111. term.termlog(f"{path} is the same as {other_path}")
  112. id_to_path[id] = path
  113. return set(id_to_path.values())
  114. async def _do_sync(
  115. wandb_files: set[pathlib.Path],
  116. *,
  117. cwd: pathlib.Path | None,
  118. live: bool,
  119. service: ServiceConnection,
  120. entity: str,
  121. project: str,
  122. run_id: str,
  123. settings: wandb.Settings,
  124. printer: Printer,
  125. parallelism: int,
  126. ) -> None:
  127. """Sync the specified files.
  128. This is factored out to make the progress animation testable.
  129. """
  130. init_handle = await service.init_sync(
  131. wandb_files,
  132. settings,
  133. cwd=cwd,
  134. live=live,
  135. entity=entity,
  136. project=project,
  137. run_id=run_id,
  138. )
  139. init_result = await init_handle.wait_async(timeout=5)
  140. sync_handle = await service.sync(init_result.id, parallelism=parallelism)
  141. await _SyncStatusLoop(
  142. init_result.id,
  143. service,
  144. printer,
  145. ).wait_with_progress(sync_handle)
  146. class _SyncStatusLoop:
  147. """Displays a sync operation's status until it completes."""
  148. def __init__(
  149. self,
  150. id: str,
  151. service: ServiceConnection,
  152. printer: Printer,
  153. ) -> None:
  154. self._id = id
  155. self._service = service
  156. self._printer = printer
  157. self._rate_limit_last_time: float | None = None
  158. self._done = asyncio.Event()
  159. async def wait_with_progress(
  160. self,
  161. handle: MailboxHandle[ServerSyncResponse],
  162. ) -> None:
  163. """Display status updates until the handle completes."""
  164. async with asyncio_compat.open_task_group() as group:
  165. group.start_soon(self._wait_then_mark_done(handle))
  166. group.start_soon(self._show_progress_until_done())
  167. async def _wait_then_mark_done(
  168. self,
  169. handle: MailboxHandle[ServerSyncResponse],
  170. ) -> None:
  171. response = await handle.wait_async(timeout=None)
  172. for msg in response.messages:
  173. self._printer.display(msg.content, level=msg.severity)
  174. self._done.set()
  175. async def _show_progress_until_done(self) -> None:
  176. """Show rate-limited status updates until _done is set."""
  177. with progress_printer(self._printer, "Syncing...") as progress:
  178. while not await self._rate_limit_check_done():
  179. handle = await self._service.sync_status(self._id)
  180. response = await handle.wait_async(timeout=None)
  181. for msg in response.new_messages:
  182. self._printer.display(msg.content, level=msg.severity)
  183. progress.update(list(response.stats))
  184. async def _rate_limit_check_done(self) -> bool:
  185. """Wait for rate limit and return whether _done is set."""
  186. now = time.monotonic()
  187. last_time = self._rate_limit_last_time
  188. self._rate_limit_last_time = now
  189. if last_time and (time_since_last := now - last_time) < _POLL_WAIT_SECONDS:
  190. await asyncio_compat.race(
  191. _SLEEP(_POLL_WAIT_SECONDS - time_since_last),
  192. self._done.wait(),
  193. )
  194. return self._done.is_set()
  195. def _find_wandb_files(
  196. path: pathlib.Path,
  197. *,
  198. skip_synced: bool,
  199. ) -> Iterator[pathlib.Path]:
  200. """Returns paths to the .wandb files to sync."""
  201. if skip_synced:
  202. yield from filterfalse(_is_synced, _expand_wandb_files(path))
  203. else:
  204. yield from _expand_wandb_files(path)
  205. def _expand_wandb_files(
  206. path: pathlib.Path,
  207. ) -> Iterator[pathlib.Path]:
  208. """Iterate over .wandb files selected by the path."""
  209. if path.suffix == ".wandb":
  210. yield path
  211. return
  212. files_in_run_directory = path.glob("*.wandb")
  213. try:
  214. first_file = next(files_in_run_directory)
  215. except StopIteration:
  216. pass
  217. else:
  218. yield first_file
  219. yield from files_in_run_directory
  220. return
  221. yield from path.glob("*/*.wandb")
  222. def _is_synced(path: pathlib.Path) -> bool:
  223. """Returns whether the .wandb file is synced."""
  224. return path.with_suffix(".wandb.synced").exists()
  225. def _print_sorted_paths(
  226. paths: Iterable[pathlib.Path],
  227. verbose: bool,
  228. *,
  229. root: pathlib.Path | None,
  230. ) -> None:
  231. """Print file paths, sorting them and truncating the list if needed.
  232. Args:
  233. paths: Paths to print. Must be absolute with symlinks resolved.
  234. verbose: If true, doesn't truncate paths.
  235. root: A root directory for making paths relative.
  236. """
  237. # Prefer to print paths relative to the current working directory.
  238. formatted_paths: list[str] = []
  239. for path in paths:
  240. formatted_path = str(path)
  241. if root:
  242. with contextlib.suppress(ValueError):
  243. formatted_path = str(path.relative_to(root))
  244. formatted_paths.append(formatted_path)
  245. sorted_paths = sorted(formatted_paths)
  246. max_lines = len(sorted_paths) if verbose else _MAX_LIST_LINES
  247. for i in range(min(len(sorted_paths), max_lines)):
  248. term.termlog(f" {sorted_paths[i]}")
  249. if len(sorted_paths) > max_lines:
  250. remaining = len(sorted_paths) - max_lines
  251. term.termlog(f" +{remaining:,d} more (pass --verbose to see all)")