retry.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396
  1. import abc
  2. import asyncio
  3. import datetime
  4. import functools
  5. import logging
  6. import os
  7. import random
  8. import threading
  9. import time
  10. from typing import Any, Awaitable, Callable, Generic, Optional, Tuple, Type, TypeVar
  11. import wandb
  12. import wandb.errors
  13. from wandb.util import CheckRetryFnType
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# To let tests mock out the retry logic's now()/sleep() funcs, this file
# should only use these variables, not call the stdlib funcs directly.
NOW_FN = datetime.datetime.now  # returns the current (naive, local) datetime
SLEEP_FN = time.sleep  # blocking sleep used by the synchronous retry path
SLEEP_ASYNC_FN = asyncio.sleep  # awaitable sleep used by retry_async
class RetryCancelledError(wandb.errors.Error):
    """A retry did not occur because it was cancelled.

    Raised while waiting between attempts when the cancel event associated
    with the retry becomes set.
    """
  22. class TransientError(Exception):
  23. """Exception type designated for errors that may only be temporary.
  24. Can have its own message and/or wrap another exception.
  25. """
  26. def __init__(
  27. self, msg: Optional[str] = None, exc: Optional[BaseException] = None
  28. ) -> None:
  29. super().__init__(msg)
  30. self.message = msg
  31. self.exception = exc
  32. _R = TypeVar("_R")
  33. class Retry(Generic[_R]):
  34. """Create a retryable version of a function.
  35. Calling this will call the passed function, retrying if any exceptions in
  36. retryable_exceptions are caught, with exponential backoff.
  37. """
  38. MAX_SLEEP_SECONDS = 5 * 60
  39. def __init__(
  40. self,
  41. call_fn: Callable[..., _R],
  42. retry_timedelta: Optional[datetime.timedelta] = None,
  43. retry_cancel_event: Optional[threading.Event] = None,
  44. num_retries: Optional[int] = None,
  45. check_retry_fn: CheckRetryFnType = lambda e: True,
  46. retryable_exceptions: Optional[Tuple[Type[Exception], ...]] = None,
  47. error_prefix: str = "Network error",
  48. retry_callback: Optional[Callable[[int, str], Any]] = None,
  49. ) -> None:
  50. self._call_fn = call_fn
  51. self._check_retry_fn = check_retry_fn
  52. self._error_prefix = error_prefix
  53. self._last_print = datetime.datetime.now() - datetime.timedelta(minutes=1)
  54. self._retry_timedelta = retry_timedelta
  55. self._retry_cancel_event = retry_cancel_event
  56. self._num_retries = num_retries
  57. if retryable_exceptions is not None:
  58. self._retryable_exceptions = retryable_exceptions
  59. else:
  60. self._retryable_exceptions = (TransientError,)
  61. self.retry_callback = retry_callback
  62. self._num_iter = 0
  63. def _sleep_check_cancelled(
  64. self, wait_seconds: float, cancel_event: Optional[threading.Event]
  65. ) -> bool:
  66. if not cancel_event:
  67. SLEEP_FN(wait_seconds)
  68. return False
  69. cancelled = cancel_event.wait(wait_seconds)
  70. return cancelled
  71. @property
  72. def num_iters(self) -> int:
  73. """The number of iterations the previous __call__ retried."""
  74. return self._num_iter
  75. def __call__(
  76. self,
  77. *args: Any,
  78. num_retries: Optional[int] = None,
  79. retry_timedelta: Optional[datetime.timedelta] = None,
  80. retry_sleep_base: Optional[float] = None,
  81. retry_cancel_event: Optional[threading.Event] = None,
  82. check_retry_fn: Optional[CheckRetryFnType] = None,
  83. **kwargs: Any,
  84. ) -> _R:
  85. """Call the wrapped function, with retries.
  86. Args:
  87. num_retries: The number of retries after which to give up.
  88. retry_timedelta: An amount of time after which to give up.
  89. retry_sleep_base: Number of seconds to sleep for the first retry.
  90. This is used as the base for exponential backoff.
  91. retry_cancel_event: An event that causes this to raise
  92. a RetryCancelledException on the next attempted retry.
  93. check_retry_fn: A custom check for deciding whether an exception
  94. should be retried. Retrying is prevented if this returns a falsy
  95. value, even if more retries are left. This may also return a
  96. timedelta that represents a shorter timeout: retrying is
  97. prevented if the value is less than the amount of time that has
  98. passed since the last timedelta was returned.
  99. """
  100. if os.environ.get("WANDB_TEST"):
  101. max_retries = 0
  102. elif num_retries is not None:
  103. max_retries = num_retries
  104. elif self._num_retries is not None:
  105. max_retries = self._num_retries
  106. else:
  107. max_retries = 1000000
  108. if retry_timedelta is not None:
  109. timeout = retry_timedelta
  110. elif self._retry_timedelta is not None:
  111. timeout = self._retry_timedelta
  112. else:
  113. timeout = datetime.timedelta(days=365)
  114. if retry_sleep_base is not None:
  115. initial_sleep = retry_sleep_base
  116. else:
  117. initial_sleep = 1
  118. retry_loop = _RetryLoop(
  119. max_retries=max_retries,
  120. timeout=timeout,
  121. initial_sleep=initial_sleep,
  122. max_sleep=self.MAX_SLEEP_SECONDS,
  123. cancel_event=retry_cancel_event or self._retry_cancel_event,
  124. retry_check=check_retry_fn or self._check_retry_fn,
  125. )
  126. start_time = NOW_FN()
  127. self._num_iter = 0
  128. while True:
  129. try:
  130. result = self._call_fn(*args, **kwargs)
  131. except self._retryable_exceptions as e:
  132. if not retry_loop.should_retry(e):
  133. raise
  134. if self._num_iter == 2:
  135. logger.info("Retry attempt failed:", exc_info=e)
  136. self._print_entered_retry_loop(e)
  137. retry_loop.wait_before_retry()
  138. self._num_iter += 1
  139. else:
  140. if self._num_iter > 2:
  141. self._print_recovered(start_time)
  142. return result
  143. def _print_entered_retry_loop(self, exception: Exception) -> None:
  144. """Emit a message saying we've begun retrying.
  145. Either calls the retry callback or prints a warning to console.
  146. Args:
  147. exception: The most recent exception we will retry.
  148. """
  149. from requests import HTTPError
  150. if (
  151. isinstance(exception, HTTPError)
  152. and exception.response is not None
  153. and self.retry_callback is not None
  154. ):
  155. self.retry_callback(
  156. exception.response.status_code,
  157. exception.response.text,
  158. )
  159. else:
  160. wandb.termlog(
  161. f"{self._error_prefix}"
  162. + f" ({exception.__class__.__name__}), entering retry loop."
  163. )
  164. def _print_recovered(self, start_time: datetime.datetime) -> None:
  165. """Emit a message saying we've recovered after retrying.
  166. Args:
  167. start_time: When we started retrying.
  168. """
  169. if not self.retry_callback:
  170. return
  171. now = NOW_FN()
  172. if now - self._last_print < datetime.timedelta(minutes=1):
  173. return
  174. self._last_print = now
  175. time_to_recover = now - start_time
  176. self.retry_callback(
  177. 200,
  178. (
  179. f"{self._error_prefix} resolved after"
  180. f" {time_to_recover}, resuming normal operation."
  181. ),
  182. )
  183. class _RetryLoop:
  184. """An invocation of a Retry instance."""
  185. def __init__(
  186. self,
  187. *,
  188. max_retries: int,
  189. timeout: datetime.timedelta,
  190. initial_sleep: float,
  191. max_sleep: float,
  192. cancel_event: Optional[threading.Event],
  193. retry_check: CheckRetryFnType,
  194. ) -> None:
  195. """Start a new call of a Retry instance.
  196. Args:
  197. max_retries: The number of retries after which to give up.
  198. timeout: An amount of time after which to give up.
  199. initial_sleep: Number of seconds to sleep for the first retry.
  200. This is used as the base for exponential backoff.
  201. max_sleep: Maximum number of seconds to sleep between retries.
  202. cancel_event: An event that's set when the function is cancelled.
  203. retry_check: A custom check for deciding whether an exception should
  204. be retried. Retrying is prevented if this returns a falsy value,
  205. even if more retries are left. This may also return a timedelta
  206. that represents a shorter timeout: retrying is prevented if the
  207. value is less than the amount of time that has passed since the
  208. last timedelta was returned.
  209. """
  210. self._max_retries = max_retries
  211. self._total_retries = 0
  212. self._timeout = timeout
  213. self._start_time = NOW_FN()
  214. self._next_sleep_time = initial_sleep
  215. self._max_sleep = max_sleep
  216. self._cancel_event = cancel_event
  217. self._retry_check = retry_check
  218. self._last_custom_timeout: Optional[datetime.datetime] = None
  219. def should_retry(self, exception: Exception) -> bool:
  220. """Returns whether an exception should be retried."""
  221. if self._total_retries >= self._max_retries:
  222. return False
  223. self._total_retries += 1
  224. now = NOW_FN()
  225. if now - self._start_time >= self._timeout:
  226. return False
  227. retry_check_result = self._retry_check(exception)
  228. if not retry_check_result:
  229. return False
  230. if isinstance(retry_check_result, datetime.timedelta):
  231. if not self._last_custom_timeout:
  232. self._last_custom_timeout = now
  233. if now - self._last_custom_timeout >= retry_check_result:
  234. return False
  235. return True
  236. def wait_before_retry(self) -> None:
  237. """Block until the next retry should happen.
  238. Raises:
  239. RetryCancelledError: If the operation is cancelled.
  240. """
  241. sleep_amount = self._next_sleep_time * (1 + random.random() * 0.25)
  242. if self._cancel_event:
  243. cancelled = self._cancel_event.wait(sleep_amount)
  244. if cancelled:
  245. raise RetryCancelledError("Cancelled while retrying.")
  246. else:
  247. SLEEP_FN(sleep_amount)
  248. self._next_sleep_time *= 2
  249. if self._next_sleep_time > self._max_sleep:
  250. self._next_sleep_time = self._max_sleep
  251. _F = TypeVar("_F", bound=Callable)
  252. def retriable(*args: Any, **kargs: Any) -> Callable[[_F], _F]:
  253. def decorator(fn: _F) -> _F:
  254. retrier: Retry[Any] = Retry(fn, *args, **kargs)
  255. @functools.wraps(fn)
  256. def wrapped_fn(*args: Any, **kargs: Any) -> Any:
  257. return retrier(*args, **kargs)
  258. return wrapped_fn # type: ignore
  259. return decorator
  260. class Backoff(abc.ABC):
  261. """A backoff strategy: decides whether to sleep or give up when an exception is raised."""
  262. @abc.abstractmethod
  263. def next_sleep_or_reraise(self, exc: Exception) -> datetime.timedelta:
  264. raise NotImplementedError # pragma: no cover
  265. class ExponentialBackoff(Backoff):
  266. """Jittered exponential backoff: sleep times increase ~exponentially up to some limit."""
  267. def __init__(
  268. self,
  269. initial_sleep: datetime.timedelta,
  270. max_sleep: datetime.timedelta,
  271. max_retries: Optional[int] = None,
  272. timeout_at: Optional[datetime.datetime] = None,
  273. ) -> None:
  274. self._next_sleep = min(max_sleep, initial_sleep)
  275. self._max_sleep = max_sleep
  276. self._remaining_retries = max_retries
  277. self._timeout_at = timeout_at
  278. def next_sleep_or_reraise(self, exc: Exception) -> datetime.timedelta:
  279. if self._remaining_retries is not None:
  280. if self._remaining_retries <= 0:
  281. raise exc
  282. self._remaining_retries -= 1
  283. if self._timeout_at is not None and NOW_FN() > self._timeout_at:
  284. raise exc
  285. result, self._next_sleep = (
  286. self._next_sleep,
  287. min(self._max_sleep, self._next_sleep * (1 + random.random())),
  288. )
  289. return result
  290. class FilteredBackoff(Backoff):
  291. """Re-raise any exceptions that fail a predicate; delegate others to another Backoff."""
  292. def __init__(self, filter: Callable[[Exception], bool], wrapped: Backoff) -> None:
  293. self._filter = filter
  294. self._wrapped = wrapped
  295. def next_sleep_or_reraise(self, exc: Exception) -> datetime.timedelta:
  296. if not self._filter(exc):
  297. raise exc
  298. return self._wrapped.next_sleep_or_reraise(exc)
  299. async def retry_async(
  300. backoff: Backoff,
  301. fn: Callable[..., Awaitable[_R]],
  302. *args: Any,
  303. on_exc: Optional[Callable[[Exception], None]] = None,
  304. **kwargs: Any,
  305. ) -> _R:
  306. """Call `fn` repeatedly until either it succeeds, or `backoff` decides we should give up.
  307. Each time `fn` fails, `on_exc` is called with the exception.
  308. """
  309. while True:
  310. try:
  311. return await fn(*args, **kwargs)
  312. except Exception as e:
  313. if on_exc is not None:
  314. on_exc(e)
  315. await SLEEP_ASYNC_FN(backoff.next_sleep_or_reraise(e).total_seconds())