pyagent.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385
  1. """Agent - Agent object.
  2. Manage wandb agent.
  3. """
  4. import ctypes
  5. import logging
  6. import os
  7. import queue
  8. import socket
  9. import sys
  10. import threading
  11. import time
  12. import traceback
  13. import wandb
  14. from wandb.apis import InternalApi
  15. from wandb.sdk.launch.sweeps import utils as sweep_utils
  16. from wandb.sdk.lib import config_util
  17. logger = logging.getLogger(__name__)
  18. def _terminate_thread(thread):
  19. if not thread.is_alive():
  20. return
  21. if hasattr(thread, "_terminated"):
  22. return
  23. thread._terminated = True
  24. tid = getattr(thread, "_thread_id", None)
  25. if tid is None:
  26. for k, v in threading._active.items():
  27. if v is thread:
  28. tid = k
  29. if tid is None:
  30. # This should never happen
  31. return
  32. logger.debug(f"Terminating thread: {tid}")
  33. res = ctypes.pythonapi.PyThreadState_SetAsyncExc(
  34. ctypes.c_long(tid), ctypes.py_object(Exception)
  35. )
  36. if res == 0:
  37. # This should never happen
  38. return
  39. elif res != 1:
  40. # Revert
  41. logger.debug(f"Termination failed for thread {tid}")
  42. ctypes.pythonapi.PyThreadState_SetAsyncExc(ctypes.c_long(tid), None)
  43. class Job:
  44. def __init__(self, command):
  45. self.command = command
  46. job_type = command.get("type")
  47. self.type = job_type
  48. self.run_id = command.get("run_id")
  49. self.config = command.get("args")
  50. def __repr__(self):
  51. if self.type == "run":
  52. return f"Job({self.run_id},{self.config})"
  53. elif self.type == "stop":
  54. return f"stop({self.run_id})"
  55. else:
  56. return "exit"
  57. class RunStatus:
  58. QUEUED = "QUEUED"
  59. RUNNING = "RUNNING"
  60. STOPPED = "STOPPED"
  61. ERRORED = "ERRORED"
  62. DONE = "DONE"
  63. class Agent:
  64. FLAPPING_MAX_SECONDS = 60
  65. FLAPPING_MAX_FAILURES = 3
  66. MAX_INITIAL_FAILURES = 5
  67. def __init__(
  68. self, sweep_id=None, project=None, entity=None, function=None, count=None
  69. ):
  70. self._sweep_path = sweep_id
  71. self._sweep_id = None
  72. self._project = project
  73. self._entity = entity
  74. self._function = function
  75. self._count = count
  76. # glob_config = os.path.expanduser('~/.config/wandb/settings')
  77. # loc_config = 'wandb/settings'
  78. # files = (glob_config, loc_config)
  79. self._api = InternalApi()
  80. self._agent_id = None
  81. self._max_initial_failures = wandb.env.get_agent_max_initial_failures(
  82. self.MAX_INITIAL_FAILURES
  83. )
  84. # if the directory to log to is not set, set it
  85. if os.environ.get(wandb.env.DIR) is None:
  86. os.environ[wandb.env.DIR] = os.path.abspath(os.getcwd())
  87. def _init(self):
  88. # These are not in constructor so that Agent instance can be rerun
  89. self._run_threads = {}
  90. self._run_status = {}
  91. self._queue = queue.Queue()
  92. self._exit_flag = False
  93. self._exceptions = {}
  94. self._start_time = time.time()
  95. def _register(self):
  96. logger.debug("Agent._register()")
  97. agent = self._api.register_agent(socket.gethostname(), sweep_id=self._sweep_id)
  98. self._agent_id = agent["id"]
  99. logger.debug(f"agent_id = {self._agent_id}")
  100. def _setup(self):
  101. logger.debug("Agent._setup()")
  102. self._init()
  103. parts = dict(entity=self._entity, project=self._project, name=self._sweep_path)
  104. err = sweep_utils.parse_sweep_id(parts)
  105. if err:
  106. wandb.termerror(err)
  107. return
  108. entity = parts.get("entity") or self._entity
  109. project = parts.get("project") or self._project
  110. sweep_id = parts.get("name") or self._sweep_id
  111. if sweep_id:
  112. os.environ[wandb.env.SWEEP_ID] = sweep_id
  113. if entity:
  114. wandb.env.set_entity(entity)
  115. if project:
  116. wandb.env.set_project(project)
  117. if sweep_id:
  118. self._sweep_id = sweep_id
  119. self._register()
  120. def _stop_run(self, run_id):
  121. logger.debug(f"Stopping run {run_id}.")
  122. self._run_status[run_id] = RunStatus.STOPPED
  123. thread = self._run_threads.get(run_id)
  124. if thread:
  125. _terminate_thread(thread)
  126. def _stop_all_runs(self):
  127. logger.debug("Stopping all runs.")
  128. for run in list(self._run_threads.keys()):
  129. self._stop_run(run)
  130. def _exit(self):
  131. self._stop_all_runs()
  132. self._exit_flag = True
  133. # _terminate_thread(self._main_thread)
  134. def _heartbeat(self):
  135. while True:
  136. if self._exit_flag:
  137. return
  138. # if not self._main_thread.is_alive():
  139. # return
  140. run_status = {
  141. run: True
  142. for run, status in self._run_status.items()
  143. if status in (RunStatus.QUEUED, RunStatus.RUNNING)
  144. }
  145. commands = self._api.agent_heartbeat(self._agent_id, {}, run_status)
  146. if commands:
  147. job = Job(commands[0])
  148. logger.debug(f"Job received: {job}")
  149. if job.type in ["run", "resume"]:
  150. self._queue.put(job)
  151. self._run_status[job.run_id] = RunStatus.QUEUED
  152. elif job.type == "stop":
  153. self._stop_run(job.run_id)
  154. elif job.type == "exit":
  155. self._exit()
  156. return
  157. time.sleep(5)
  158. def _run_jobs_from_queue(self):
  159. global _INSTANCES
  160. _INSTANCES += 1
  161. try:
  162. waiting = False
  163. count = 0
  164. while True:
  165. if self._exit_flag:
  166. return
  167. try:
  168. try:
  169. job = self._queue.get(timeout=5)
  170. if self._exit_flag:
  171. logger.debug("Exiting main loop due to exit flag.")
  172. wandb.termlog("Sweep Agent: Exiting.")
  173. return
  174. except queue.Empty:
  175. if not waiting:
  176. logger.debug("Paused.")
  177. wandb.termlog("Sweep Agent: Waiting for job.")
  178. waiting = True
  179. time.sleep(5)
  180. if self._exit_flag:
  181. logger.debug("Exiting main loop due to exit flag.")
  182. wandb.termlog("Sweep Agent: Exiting.")
  183. return
  184. continue
  185. if waiting:
  186. logger.debug("Resumed.")
  187. wandb.termlog("Job received.")
  188. waiting = False
  189. count += 1
  190. run_id = job.run_id
  191. if self._run_status[run_id] == RunStatus.STOPPED:
  192. continue
  193. logger.debug(f"Spawning new thread for run {run_id}.")
  194. thread = threading.Thread(target=self._run_job, args=(job,))
  195. self._run_threads[run_id] = thread
  196. thread.start()
  197. self._run_status[run_id] = RunStatus.RUNNING
  198. thread.join()
  199. logger.debug(f"Thread joined for run {run_id}.")
  200. if self._run_status[run_id] == RunStatus.RUNNING:
  201. self._run_status[run_id] = RunStatus.DONE
  202. elif self._run_status[run_id] == RunStatus.ERRORED:
  203. exc = self._exceptions[run_id]
  204. # Extract to reduce a decision point to avoid ruff c901
  205. log_str, term_str = _get_exception_logger_and_term_strs(exc)
  206. logger.error(f"Run {run_id} errored:\n{log_str}")
  207. wandb.termerror(f"Run {run_id} errored:{term_str}")
  208. if os.getenv(wandb.env.AGENT_DISABLE_FLAPPING) == "true":
  209. self._exit_flag = True
  210. return
  211. elif (
  212. time.time() - self._start_time < self.FLAPPING_MAX_SECONDS
  213. ) and (len(self._exceptions) >= self.FLAPPING_MAX_FAILURES):
  214. msg = f"Detected {self.FLAPPING_MAX_FAILURES} failed runs in the first {self.FLAPPING_MAX_SECONDS} seconds, killing sweep."
  215. logger.error(msg)
  216. wandb.termerror(msg)
  217. wandb.termlog(
  218. "To disable this check set WANDB_AGENT_DISABLE_FLAPPING=true"
  219. )
  220. self._exit_flag = True
  221. return
  222. if (
  223. self._max_initial_failures < len(self._exceptions)
  224. and len(self._exceptions) >= count
  225. ):
  226. msg = f"Detected {self._max_initial_failures} failed runs in a row at start, killing sweep."
  227. logger.error(msg)
  228. wandb.termerror(msg)
  229. wandb.termlog(
  230. "To change this value set WANDB_AGENT_MAX_INITIAL_FAILURES=val"
  231. )
  232. self._exit_flag = True
  233. return
  234. if self._count and self._count == count:
  235. logger.debug("Exiting main loop because max count reached.")
  236. self._exit_flag = True
  237. return
  238. except KeyboardInterrupt:
  239. logger.debug("Ctrl + C detected. Stopping sweep.")
  240. wandb.termlog("Ctrl + C detected. Stopping sweep.")
  241. self._exit()
  242. return
  243. except Exception:
  244. if self._exit_flag:
  245. logger.debug("Exiting main loop due to exit flag.")
  246. wandb.termlog("Sweep Agent: Killed.")
  247. return
  248. else:
  249. raise
  250. finally:
  251. _INSTANCES -= 1
  252. def _run_job(self, job):
  253. try:
  254. run_id = job.run_id
  255. config_file = os.path.join(
  256. "wandb", "sweep-" + self._sweep_id, "config-" + run_id + ".yaml"
  257. )
  258. os.environ[wandb.env.RUN_ID] = run_id
  259. base_dir = os.environ.get(wandb.env.DIR, "")
  260. sweep_param_path = os.path.join(base_dir, config_file)
  261. os.environ[wandb.env.SWEEP_PARAM_PATH] = sweep_param_path
  262. config_util.save_config_file_from_dict(sweep_param_path, job.config)
  263. os.environ[wandb.env.SWEEP_ID] = self._sweep_id
  264. wandb.teardown()
  265. wandb.termlog(f"Agent Starting Run: {run_id} with config:")
  266. for k, v in job.config.items():
  267. wandb.termlog("\t{}: {}".format(k, v["value"]))
  268. try:
  269. self._function()
  270. except KeyboardInterrupt:
  271. raise
  272. except Exception as e:
  273. # Log the run's exceptions directly to stderr to match CLI case, and wrap so we
  274. # can identify it as coming from the job later later. This will get automatically
  275. # logged by console_capture.py. Exception handler below will also handle exceptions
  276. # in setup code.
  277. exc_repr = _format_exception_traceback(e)
  278. print(exc_repr, file=sys.stderr) # noqa: T201
  279. raise _JobError(f"Run threw exception: {str(e)}") from e
  280. wandb.finish()
  281. except KeyboardInterrupt:
  282. raise
  283. except Exception as e:
  284. wandb.finish(exit_code=1)
  285. if self._run_status[run_id] == RunStatus.RUNNING:
  286. self._run_status[run_id] = RunStatus.ERRORED
  287. self._exceptions[run_id] = e
  288. finally:
  289. # clean up the environment changes made
  290. os.environ.pop(wandb.env.RUN_ID, None)
  291. os.environ.pop(wandb.env.SWEEP_ID, None)
  292. os.environ.pop(wandb.env.SWEEP_PARAM_PATH, None)
  293. def run(self):
  294. logger.info(
  295. f"Starting sweep agent: entity={self._entity}, project={self._project}, count={self._count}"
  296. )
  297. self._setup()
  298. # self._main_thread = threading.Thread(target=self._run_jobs_from_queue)
  299. self._heartbeat_thread = threading.Thread(target=self._heartbeat)
  300. self._heartbeat_thread.daemon = True
  301. # self._main_thread.start()
  302. self._heartbeat_thread.start()
  303. # self._main_thread.join()
  304. self._run_jobs_from_queue()
  305. def pyagent(sweep_id, function, entity=None, project=None, count=None):
  306. """Generic agent entrypoint, used for CLI or jupyter.
  307. Args:
  308. sweep_id (dict): Sweep ID generated by CLI or sweep API
  309. function (func, optional): A function to call instead of the "program"
  310. entity (str, optional): W&B Entity
  311. project (str, optional): W&B Project
  312. count (int, optional): the number of trials to run.
  313. """
  314. if not callable(function):
  315. raise TypeError("function parameter must be callable!")
  316. agent = Agent(
  317. sweep_id,
  318. function=function,
  319. entity=entity,
  320. project=project,
  321. count=count,
  322. )
  323. agent.run()
  324. def _format_exception_traceback(exc):
  325. return "".join(traceback.format_exception(type(exc), exc, exc.__traceback__))
  326. class _JobError(Exception):
  327. """Exception raised when a job fails during execution."""
  328. pass
  329. def _get_exception_logger_and_term_strs(exc):
  330. if isinstance(exc, _JobError) and exc.__cause__:
  331. # If it's a JobException, get the original exception for display
  332. job_exc = exc.__cause__
  333. log_str = _format_exception_traceback(job_exc)
  334. # Don't long full stacktrace to terminal again because we already
  335. # printed it to stderr.
  336. term_str = " " + str(job_exc)
  337. else:
  338. log_str = _format_exception_traceback(exc)
  339. term_str = "\n" + log_str
  340. return log_str, term_str
  341. _INSTANCES = 0
  342. def is_running():
  343. return bool(_INSTANCES)