step_prepare.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179
  1. """Batching file prepare requests to our API."""
  2. import queue
  3. import threading
  4. import time
  5. from typing import (
  6. TYPE_CHECKING,
  7. Callable,
  8. Dict,
  9. List,
  10. Mapping,
  11. NamedTuple,
  12. Optional,
  13. Sequence,
  14. Tuple,
  15. Union,
  16. )
  17. if TYPE_CHECKING:
  18. from wandb.sdk.internal.internal_api import (
  19. Api,
  20. CreateArtifactFileSpecInput,
  21. CreateArtifactFilesResponseFile,
  22. )
# Request for a file to be prepared.
class RequestPrepare(NamedTuple):
    """Ask the batching worker thread to prepare a single file.

    The worker answers by putting the matching ResponsePrepare onto
    ``response_channel`` after the batch containing this request is sent.
    """

    # Spec forwarded to the backend; its "name" entry is used to look up the
    # reply for this request in the batch response.
    file_spec: "CreateArtifactFileSpecInput"
    # Per-request queue on which the worker delivers the reply.
    response_channel: "queue.Queue[ResponsePrepare]"
class RequestFinish(NamedTuple):
    """Sentinel telling the batching worker to send what it has gathered and exit."""

    pass
class ResponsePrepare(NamedTuple):
    """Backend reply for one RequestPrepare, as built by prepare_response()."""

    # Artifact ID taken from the backend reply (response["artifact"]["id"]).
    birth_artifact_id: str
    # Upload target for the file; the _prepare_batch docstring states this is
    # None when the file already exists and needs no upload.
    upload_url: Optional[str]
    # Headers returned with the upload URL (response["uploadHeaders"]).
    upload_headers: Sequence[str]
    # Multipart upload ID; None when the reply carries no multipart info.
    upload_id: Optional[str]
    # Storage path reported by the backend, if any.
    storage_path: Optional[str]
    # Part number -> upload URL for multipart uploads; None when there are no parts.
    multipart_upload_urls: Optional[Dict[int, str]]
# Anything the batching worker reads from its request queue: a file to
# prepare, or the RequestFinish sentinel that makes the worker stop.
Request = Union[RequestPrepare, RequestFinish]
  37. def _clamp(x: float, low: float, high: float) -> float:
  38. return max(low, min(x, high))
  39. def gather_batch(
  40. request_queue: "queue.Queue[Request]",
  41. batch_time: float,
  42. inter_event_time: float,
  43. max_batch_size: int,
  44. clock: Callable[[], float] = time.monotonic,
  45. ) -> Tuple[bool, Sequence[RequestPrepare]]:
  46. batch_start_time = clock()
  47. remaining_time = batch_time
  48. first_request = request_queue.get()
  49. if isinstance(first_request, RequestFinish):
  50. return True, []
  51. batch: List[RequestPrepare] = [first_request]
  52. while remaining_time > 0 and len(batch) < max_batch_size:
  53. try:
  54. request = request_queue.get(
  55. timeout=_clamp(
  56. x=inter_event_time,
  57. low=1e-12, # 0 = "block forever", so just use something tiny
  58. high=remaining_time,
  59. ),
  60. )
  61. if isinstance(request, RequestFinish):
  62. return True, batch
  63. batch.append(request)
  64. remaining_time = batch_time - (clock() - batch_start_time)
  65. except queue.Empty:
  66. break
  67. return False, batch
  68. def prepare_response(response: "CreateArtifactFilesResponseFile") -> ResponsePrepare:
  69. multipart_resp = response.get("uploadMultipartUrls")
  70. part_list = multipart_resp["uploadUrlParts"] if multipart_resp else []
  71. multipart_parts = {u["partNumber"]: u["uploadUrl"] for u in part_list} or None
  72. return ResponsePrepare(
  73. birth_artifact_id=response["artifact"]["id"],
  74. upload_url=response["uploadUrl"],
  75. upload_headers=response["uploadHeaders"],
  76. upload_id=multipart_resp and multipart_resp.get("uploadID"),
  77. storage_path=response.get("storagePath"),
  78. multipart_upload_urls=multipart_parts,
  79. )
  80. class StepPrepare:
  81. """A thread that batches requests to our file prepare API.
  82. Any number of threads may call prepare() in parallel. The PrepareBatcher thread
  83. will batch requests up and send them all to the backend at once.
  84. """
  85. def __init__(
  86. self,
  87. api: "Api",
  88. batch_time: float,
  89. inter_event_time: float,
  90. max_batch_size: int,
  91. request_queue: Optional["queue.Queue[Request]"] = None,
  92. ) -> None:
  93. self._api = api
  94. self._inter_event_time = inter_event_time
  95. self._batch_time = batch_time
  96. self._max_batch_size = max_batch_size
  97. self._request_queue: queue.Queue[Request] = request_queue or queue.Queue()
  98. self._thread = threading.Thread(target=self._thread_body)
  99. self._thread.daemon = True
  100. def _thread_body(self) -> None:
  101. while True:
  102. finish, batch = gather_batch(
  103. request_queue=self._request_queue,
  104. batch_time=self._batch_time,
  105. inter_event_time=self._inter_event_time,
  106. max_batch_size=self._max_batch_size,
  107. )
  108. if batch:
  109. batch_response = self._prepare_batch(batch)
  110. # send responses
  111. for prepare_request in batch:
  112. name = prepare_request.file_spec["name"]
  113. response_file = batch_response[name]
  114. response = prepare_response(response_file)
  115. prepare_request.response_channel.put(response)
  116. if finish:
  117. break
  118. def _prepare_batch(
  119. self, batch: Sequence[RequestPrepare]
  120. ) -> Mapping[str, "CreateArtifactFilesResponseFile"]:
  121. """Execute the prepareFiles API call.
  122. Args:
  123. batch: List of RequestPrepare objects
  124. Returns:
  125. dict of (save_name: ResponseFile) pairs where ResponseFile is a dict with
  126. an uploadUrl key. The value of the uploadUrl key is None if the file
  127. already exists, or a url string if the file should be uploaded.
  128. """
  129. return self._api.create_artifact_files([req.file_spec for req in batch])
  130. def prepare(
  131. self, file_spec: "CreateArtifactFileSpecInput"
  132. ) -> "queue.Queue[ResponsePrepare]":
  133. response_queue: queue.Queue[ResponsePrepare] = queue.Queue()
  134. self._request_queue.put(RequestPrepare(file_spec, response_queue))
  135. return response_queue
  136. def start(self) -> None:
  137. self._thread.start()
  138. def finish(self) -> None:
  139. self._request_queue.put(RequestFinish())
  140. def is_alive(self) -> bool:
  141. return self._thread.is_alive()
  142. def shutdown(self) -> None:
  143. self.finish()
  144. self._thread.join()