step_upload.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287
  1. """Batching file prepare requests to our API."""
  2. import concurrent.futures
  3. import logging
  4. import queue
  5. import sys
  6. import threading
  7. from typing import (
  8. TYPE_CHECKING,
  9. Callable,
  10. MutableMapping,
  11. MutableSequence,
  12. MutableSet,
  13. NamedTuple,
  14. Optional,
  15. Union,
  16. )
  17. from wandb.errors.term import termerror
  18. from wandb.filesync import upload_job
  19. from wandb.sdk.lib.paths import LogicalPath
  20. if TYPE_CHECKING:
  21. from typing import TypedDict
  22. from wandb.filesync import stats
  23. from wandb.sdk.internal import file_stream, internal_api, progress
  24. from wandb.sdk.internal.settings_static import SettingsStatic
  class ArtifactStatus(TypedDict):
      """Per-artifact bookkeeping kept by `StepUpload`, keyed by artifact ID."""

      # Whether `commit_artifact` should be called on the API once the
      # artifact is ready to be committed.
      finalize: bool
      # Upload requests received for this artifact that have not yet
      # completed successfully.
      pending_count: int
      # Set once a `RequestCommitArtifact` has been received.
      commit_requested: bool
      # Callbacks invoked right before the artifact is committed.
      pre_commit_callbacks: MutableSet["PreCommitFn"]
      # Futures resolved (or failed) with the outcome of the commit.
      result_futures: MutableSet["concurrent.futures.Future[None]"]
  # Zero-argument callback run right before an artifact is committed.
  PreCommitFn = Callable[[], None]
  # Zero-argument callback run once a finish request has been fully processed.
  OnRequestFinishFn = Callable[[], None]
  # Custom save routine handed to the upload job; it receives a progress
  # callback and returns a bool (presumably whether the file was saved by
  # the routine itself -- confirm against `upload_job.UploadJob`).
  SaveFn = Callable[["progress.ProgressFn"], bool]

  logger = logging.getLogger(__name__)
  class RequestUpload(NamedTuple):
      """Request to upload a single file.

      `save_name` identifies the file: uploads that share a `save_name` are
      never run concurrently. When `artifact_id` is not None the upload counts
      toward that artifact's pending uploads. The remaining fields are passed
      through unchanged to `upload_job.UploadJob`.
      """

      path: str
      save_name: LogicalPath
      artifact_id: Optional[str]
      md5: Optional[str]
      copied: bool
      save_fn: Optional[SaveFn]
      digest: Optional[str]
  class RequestCommitArtifact(NamedTuple):
      """Request to commit an artifact once all of its uploads have succeeded.

      `before_commit` is called right before the commit. When `finalize` is
      true the API's `commit_artifact` is called afterwards. `result_future`
      is resolved with None on success, or receives the exception if an
      upload or the commit itself fails.
      """

      artifact_id: str
      finalize: bool
      before_commit: PreCommitFn
      result_future: "concurrent.futures.Future[None]"
  class RequestFinish(NamedTuple):
      """Request to shut the uploader down after outstanding work is done.

      `callback`, if given, is invoked once the event queue is empty and no
      upload job is still running.
      """

      callback: Optional[OnRequestFinishFn]
  class EventJobDone(NamedTuple):
      """Notification that an upload job has finished.

      `job` is the request that was being served; `exc` is the exception the
      upload raised, or None if it completed without raising.
      """

      job: RequestUpload
      exc: Optional[BaseException]
  # Everything that may be placed on the uploader's event queue.
  Event = Union[RequestUpload, RequestCommitArtifact, RequestFinish, EventJobDone]
  class StepUpload:
      """Consume upload/commit events from a queue and run them.

      Events are handled one at a time on a dedicated daemon thread (see
      `_thread_body`); the uploads themselves run on a thread pool and report
      back by putting an `EventJobDone` on the same queue.
      """

      def __init__(
          self,
          api: "internal_api.Api",
          stats: "stats.Stats",
          event_queue: "queue.Queue[Event]",
          max_threads: int,
          file_stream: "file_stream.FileStreamApi",
          settings: Optional["SettingsStatic"] = None,
      ) -> None:
          """Set up the worker thread and upload pool; nothing runs until `start()`.

          Args:
              api: API client; used to commit artifacts and handed to upload jobs.
              stats: Upload statistics object handed to upload jobs.
              event_queue: Queue this object reads events from, and onto which
                  finished uploads post `EventJobDone`.
              max_threads: Maximum number of concurrent upload threads.
              file_stream: File stream API handed to upload jobs.
              settings: Optional settings; only `settings.silent` is read here.
          """
          self._api = api
          self._stats = stats
          self._event_queue = event_queue
          self._file_stream = file_stream

          self._thread = threading.Thread(target=self._thread_body)
          self._thread.daemon = True

          self._pool = concurrent.futures.ThreadPoolExecutor(
              thread_name_prefix="wandb-upload",
              max_workers=max_threads,
          )

          # Indexed by files' `save_name`'s, which are their ID's in the Run.
          self._running_jobs: MutableMapping[LogicalPath, RequestUpload] = {}
          # Uploads waiting for a running job with the same `save_name` to finish.
          self._pending_jobs: MutableSequence[RequestUpload] = []

          self._artifacts: MutableMapping[str, ArtifactStatus] = {}
          # First upload error seen per artifact. A failed upload never
          # decrements `pending_count`, so such an artifact can never be
          # committed; commit requests arriving later are failed with this
          # error instead of being left unresolved forever.
          self._failed_artifacts: MutableMapping[str, BaseException] = {}

          self.silent = bool(settings.silent) if settings else False

      def _thread_body(self) -> None:
          """Process events until a finish request arrives, then drain and stop."""
          event: Optional[Event]
          # Wait for event in the queue, and process one by one until a
          # finish event is received
          finish_callback = None
          while True:
              event = self._event_queue.get()
              if isinstance(event, RequestFinish):
                  finish_callback = event.callback
                  break
              self._handle_event(event)

          # We've received a finish event. At this point, further Upload requests
          # are invalid.

          # After a finish event is received, iterate through the event queue
          # one by one and process all remaining events.
          while True:
              try:
                  event = self._event_queue.get(True, 0.2)
              except queue.Empty:
                  event = None
              if event is not None:
                  self._handle_event(event)
              elif not self._running_jobs:
                  # Queue was empty and no jobs left.
                  self._pool.shutdown(wait=False)
                  if finish_callback:
                      finish_callback()
                  break

      def _handle_event(self, event: Event) -> None:
          """Dispatch a single event.

          Raises:
              TypeError: if `event` is not an `EventJobDone`,
                  `RequestCommitArtifact` or `RequestUpload`.
          """
          if isinstance(event, EventJobDone):
              job = event.job

              if event.exc is not None:
                  logger.exception(
                      "Failed to upload file: %s", job.path, exc_info=event.exc
                  )

              # Same test as the `RequestUpload` branch below, so every upload
              # that was counted in `pending_count` is also accounted for here.
              if job.artifact_id is not None:
                  if event.exc is None:
                      self._artifacts[job.artifact_id]["pending_count"] -= 1
                      self._maybe_commit_artifact(job.artifact_id)
                  else:
                      if not self.silent:
                          termerror(
                              "Uploading artifact file failed. Artifact won't be committed."
                          )
                      # Remember the failure for commit requests that have not
                      # arrived yet, then fail the ones already waiting.
                      self._failed_artifacts.setdefault(job.artifact_id, event.exc)
                      self._fail_artifact_futures(job.artifact_id, event.exc)

              self._running_jobs.pop(job.save_name)
              # If we have any pending jobs, start one now
              if self._pending_jobs:
                  next_job = self._pending_jobs.pop(0)
                  self._start_upload_job(next_job)
          elif isinstance(event, RequestCommitArtifact):
              if event.artifact_id not in self._artifacts:
                  self._init_artifact(event.artifact_id)
              status = self._artifacts[event.artifact_id]
              status["commit_requested"] = True
              status["finalize"] = event.finalize
              status["pre_commit_callbacks"].add(event.before_commit)
              status["result_futures"].add(event.result_future)

              upload_exc = self._failed_artifacts.get(event.artifact_id)
              if upload_exc is not None:
                  # An upload for this artifact already failed, so it will
                  # never be committed: report that now rather than leaving
                  # the future pending.
                  self._fail_artifact_futures(event.artifact_id, upload_exc)
              else:
                  self._maybe_commit_artifact(event.artifact_id)
          elif isinstance(event, RequestUpload):
              if event.artifact_id is not None:
                  if event.artifact_id not in self._artifacts:
                      self._init_artifact(event.artifact_id)
                  self._artifacts[event.artifact_id]["pending_count"] += 1
              self._start_upload_job(event)
          else:
              raise TypeError(f"Event has unexpected type: {event!s}")

      def _start_upload_job(self, event: RequestUpload) -> None:
          """Start the upload now, or queue it behind a running upload of the same file."""
          # Operations on a single backend file must be serialized. if
          # we're already uploading this file, put the event on the
          # end of the queue
          if event.save_name in self._running_jobs:
              self._pending_jobs.append(event)
              return
          self._spawn_upload(event)

      def _spawn_upload(self, event: RequestUpload) -> None:
          """Spawn an upload job and handle the bookkeeping of `self._running_jobs`.

          Context: it's important that, whenever we add an entry to `self._running_jobs`,
          we ensure that a corresponding `EventJobDone` message will eventually get handled;
          otherwise, the `_running_jobs` entry will never get removed, and the StepUpload
          will never shut down.

          The sole purpose of this function is to make sure that the code that adds an entry
          to `self._running_jobs` is textually right next to the code that eventually enqueues
          the `EventJobDone` message. This should help keep them in sync.
          """
          # Adding the entry to `self._running_jobs` MUST happen in the thread that
          # handles events, NOT in the job that gets submitted to the thread-pool,
          # to guard against this sequence of events:
          # - StepUpload receives a RequestUpload
          #     ...and therefore spawns a thread to do the upload
          # - StepUpload receives a RequestFinish
          #     ...and checks `self._running_jobs` to see if there are any tasks to wait for...
          #     ...and there are none, because the addition to `self._running_jobs` happens in
          #        the background thread, which the scheduler hasn't yet run...
          #     ...so the StepUpload shuts down. Even though we haven't uploaded the file!
          #
          # This would be very bad!
          # So, this line has to happen _outside_ the `pool.submit()`.
          self._running_jobs[event.save_name] = event

          def run_and_notify() -> None:
              try:
                  self._do_upload(event)
              finally:
                  # `sys.exc_info()[1]` is the exception currently propagating out
                  # of `_do_upload`, or None when the upload finished normally.
                  self._event_queue.put(EventJobDone(event, exc=sys.exc_info()[1]))

          self._pool.submit(run_and_notify)

      def _do_upload(self, event: RequestUpload) -> None:
          """Build an `UploadJob` from the request and run it in the calling thread."""
          job = upload_job.UploadJob(
              self._stats,
              self._api,
              self._file_stream,
              self.silent,
              event.save_name,
              event.path,
              event.artifact_id,
              event.md5,
              event.copied,
              event.save_fn,
              event.digest,
          )
          job.run()

      def _init_artifact(self, artifact_id: str) -> None:
          """Create a fresh status entry for `artifact_id`."""
          self._artifacts[artifact_id] = {
              "finalize": False,
              "pending_count": 0,
              "commit_requested": False,
              "pre_commit_callbacks": set(),
              "result_futures": set(),
          }

      def _maybe_commit_artifact(self, artifact_id: str) -> None:
          """Commit the artifact if a commit was requested and no uploads are pending.

          Runs the pre-commit callbacks, calls `commit_artifact` on the API when
          the artifact is to be finalized, and then resolves the waiting futures.
          Any exception raised along the way is reported and set on the futures
          instead of being propagated.
          """
          artifact_status = self._artifacts[artifact_id]
          ready = (
              artifact_status["pending_count"] == 0
              and artifact_status["commit_requested"]
          )
          if not ready:
              return

          try:
              for pre_callback in artifact_status["pre_commit_callbacks"]:
                  pre_callback()
              if artifact_status["finalize"]:
                  self._api.commit_artifact(artifact_id)
          except Exception as exc:
              termerror(
                  f"Committing artifact failed. Artifact {artifact_id} won't be finalized."
              )
              termerror(str(exc))
              self._fail_artifact_futures(artifact_id, exc)
          else:
              self._resolve_artifact_futures(artifact_id)

      def _fail_artifact_futures(self, artifact_id: str, exc: BaseException) -> None:
          """Set `exc` on every future waiting on the artifact, then forget them."""
          futures = self._artifacts[artifact_id]["result_futures"]
          for result_future in futures:
              result_future.set_exception(exc)
          futures.clear()

      def _resolve_artifact_futures(self, artifact_id: str) -> None:
          """Resolve every future waiting on the artifact with None, then forget them."""
          futures = self._artifacts[artifact_id]["result_futures"]
          for result_future in futures:
              result_future.set_result(None)
          futures.clear()

      def start(self) -> None:
          """Start the event-handling thread."""
          self._thread.start()

      def is_alive(self) -> bool:
          """Return whether the event-handling thread is still running."""
          return self._thread.is_alive()