step_checksum.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142
  1. """Batching file prepare requests to our API."""
  2. import concurrent.futures
  3. import functools
  4. import os
  5. import queue
  6. import shutil
  7. import threading
  8. from typing import TYPE_CHECKING, NamedTuple, Optional, Union, cast
  9. from wandb.filesync import step_upload
  10. from wandb.sdk.lib import filesystem, runid
  11. from wandb.sdk.lib.paths import LogicalPath
  12. if TYPE_CHECKING:
  13. import tempfile
  14. from wandb.filesync import stats
  15. from wandb.sdk.artifacts.artifact_manifest import ArtifactManifest
  16. from wandb.sdk.artifacts.artifact_saver import SaveFn
  17. from wandb.sdk.internal import internal_api
class RequestUpload(NamedTuple):
    """Request to pass a single file on to the upload step."""

    # Filesystem location of the file to upload.
    path: str
    # Name reported to the stats tracker and forwarded with the upload request.
    save_name: LogicalPath
    # When true, StepChecksum copies the file into its temporary directory
    # first and forwards the copy's path instead of `path`.
    copy: bool
class RequestStoreManifestFiles(NamedTuple):
    """Request to upload the files listed in an artifact manifest."""

    # Manifest whose `entries` are iterated; only entries with a truthy
    # `local_path` are forwarded by StepChecksum.
    manifest: "ArtifactManifest"
    # Artifact identifier forwarded with every resulting upload request.
    artifact_id: str
    # Callable bound to each entry (via functools.partial) and forwarded with
    # that entry's upload request.
    save_fn: "SaveFn"
class RequestCommitArtifact(NamedTuple):
    """Request to commit an artifact.

    StepChecksum forwards all four fields unchanged, in this order, to
    ``step_upload.RequestCommitArtifact``.
    """

    artifact_id: str
    finalize: bool
    before_commit: step_upload.PreCommitFn
    result_future: "concurrent.futures.Future[None]"
class RequestFinish(NamedTuple):
    """Request that ends the StepChecksum worker loop."""

    # Forwarded to `step_upload.RequestFinish` when the loop exits; may be None.
    callback: Optional[step_upload.OnRequestFinishFn]
# Every request type that may be placed on StepChecksum's request queue.
Event = Union[
    RequestUpload, RequestStoreManifestFiles, RequestCommitArtifact, RequestFinish
]
  36. class StepChecksum:
  37. def __init__(
  38. self,
  39. api: "internal_api.Api",
  40. tempdir: "tempfile.TemporaryDirectory",
  41. request_queue: "queue.Queue[Event]",
  42. output_queue: "queue.Queue[step_upload.Event]",
  43. stats: "stats.Stats",
  44. ) -> None:
  45. self._api = api
  46. self._tempdir = tempdir
  47. self._request_queue = request_queue
  48. self._output_queue = output_queue
  49. self._stats = stats
  50. self._thread = threading.Thread(target=self._thread_body)
  51. self._thread.daemon = True
  52. def _thread_body(self) -> None:
  53. while True:
  54. req = self._request_queue.get()
  55. if isinstance(req, RequestUpload):
  56. path = req.path
  57. if req.copy:
  58. path = os.path.join(
  59. self._tempdir.name,
  60. f"{runid.generate_id()}-{req.save_name}",
  61. )
  62. filesystem.mkdir_exists_ok(os.path.dirname(path))
  63. try:
  64. # certain linux distros throw an exception when copying
  65. # large files: https://bugs.python.org/issue43743
  66. shutil.copy2(req.path, path)
  67. except OSError:
  68. shutil._USE_CP_SENDFILE = False # type: ignore[attr-defined]
  69. shutil.copy2(req.path, path)
  70. self._stats.init_file(req.save_name, os.path.getsize(path))
  71. self._output_queue.put(
  72. step_upload.RequestUpload(
  73. path,
  74. req.save_name,
  75. None,
  76. None,
  77. req.copy,
  78. None,
  79. None,
  80. )
  81. )
  82. elif isinstance(req, RequestStoreManifestFiles):
  83. for entry in req.manifest.entries.values():
  84. if entry.local_path:
  85. self._stats.init_file(
  86. entry.local_path,
  87. cast(int, entry.size),
  88. is_artifact_file=True,
  89. )
  90. self._output_queue.put(
  91. step_upload.RequestUpload(
  92. entry.local_path,
  93. entry.path,
  94. req.artifact_id,
  95. entry.digest,
  96. False,
  97. functools.partial(req.save_fn, entry),
  98. entry.digest,
  99. )
  100. )
  101. elif isinstance(req, RequestCommitArtifact):
  102. self._output_queue.put(
  103. step_upload.RequestCommitArtifact(
  104. req.artifact_id,
  105. req.finalize,
  106. req.before_commit,
  107. req.result_future,
  108. )
  109. )
  110. elif isinstance(req, RequestFinish):
  111. break
  112. else:
  113. raise TypeError
  114. self._output_queue.put(step_upload.RequestFinish(req.callback))
  115. def start(self) -> None:
  116. self._thread.start()
  117. def is_alive(self) -> bool:
  118. return self._thread.is_alive()
  119. def finish(self) -> None:
  120. self._request_queue.put(RequestFinish(None))