handler.py 31 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843844845846847848849850851852853854
  1. """Handle Manager."""
  2. import json
  3. import logging
  4. import math
  5. import numbers
  6. import time
  7. from collections import defaultdict
  8. from queue import Queue
  9. from threading import Event
  10. from typing import (
  11. TYPE_CHECKING,
  12. Any,
  13. Callable,
  14. Dict,
  15. Iterable,
  16. List,
  17. Optional,
  18. Sequence,
  19. Tuple,
  20. cast,
  21. )
  22. from wandb.errors.links import url_registry
  23. from wandb.proto.wandb_internal_pb2 import (
  24. HistoryRecord,
  25. InternalMessages,
  26. MetricRecord,
  27. Record,
  28. Result,
  29. RunRecord,
  30. SampledHistoryItem,
  31. SummaryItem,
  32. SummaryRecord,
  33. SummaryRecordRequest,
  34. )
  35. from ..interface.interface_queue import InterfaceQueue
  36. from ..lib import handler_util, proto_util
  37. from . import context, sample, tb_watcher
  38. from .settings_static import SettingsStatic
  39. if TYPE_CHECKING:
  40. from wandb.proto.wandb_internal_pb2 import MetricSummary
  41. SummaryDict = Dict[str, Any]
  42. logger = logging.getLogger(__name__)
  43. # Update (March 5, 2024): Since ~2020/2021, when constructing the summary
  44. # object, we had replaced the artifact path for media types with the latest
  45. # artifact path. The primary purpose of this was to support live updating of
  46. # media objects in the UI (since the default artifact path was fully qualified
  47. # and would not update). However, in March of 2024, a bug was discovered with
  48. # this approach which causes this path to be incorrect in cases where the media
  49. # object is logged to another artifact before being logged to the run. Setting
  50. # this to `False` disables this copy behavior. The impact is that users will
  51. # need to refresh to see updates. Ironically, this updating behavior is not
  52. # currently supported in the UI, so the impact of this change is minimal.
  53. REPLACE_SUMMARY_ART_PATH_WITH_LATEST = False
  54. def _dict_nested_set(target: Dict[str, Any], key_list: Sequence[str], v: Any) -> None:
  55. # recurse down the dictionary structure:
  56. for k in key_list[:-1]:
  57. target.setdefault(k, {})
  58. new_target = target.get(k)
  59. if TYPE_CHECKING:
  60. new_target = cast(Dict[str, Any], new_target)
  61. target = new_target
  62. # use the last element of the key to write the leaf:
  63. target[key_list[-1]] = v
  64. class HandleManager:
  65. _consolidated_summary: SummaryDict
  66. _sampled_history: Dict[str, sample.UniformSampleAccumulator]
  67. _partial_history: Dict[str, Any]
  68. _run_proto: Optional[RunRecord]
  69. _settings: SettingsStatic
  70. _record_q: "Queue[Record]"
  71. _result_q: "Queue[Result]"
  72. _stopped: Event
  73. _writer_q: "Queue[Record]"
  74. _interface: InterfaceQueue
  75. _tb_watcher: Optional[tb_watcher.TBWatcher]
  76. _metric_defines: Dict[str, MetricRecord]
  77. _metric_globs: Dict[str, MetricRecord]
  78. _metric_track: Dict[Tuple[str, ...], float]
  79. _metric_copy: Dict[Tuple[str, ...], Any]
  80. _track_time: Optional[float]
  81. _accumulate_time: float
  82. _run_start_time: Optional[float]
  83. _context_keeper: context.ContextKeeper
  84. def __init__(
  85. self,
  86. settings: SettingsStatic,
  87. record_q: "Queue[Record]",
  88. result_q: "Queue[Result]",
  89. stopped: Event,
  90. writer_q: "Queue[Record]",
  91. interface: InterfaceQueue,
  92. context_keeper: context.ContextKeeper,
  93. ) -> None:
  94. self._settings = settings
  95. self._record_q = record_q
  96. self._result_q = result_q
  97. self._stopped = stopped
  98. self._writer_q = writer_q
  99. self._interface = interface
  100. self._context_keeper = context_keeper
  101. self._tb_watcher = None
  102. self._step = 0
  103. self._track_time = None
  104. self._accumulate_time = 0
  105. self._run_start_time = None
  106. # keep track of summary from key/val updates
  107. self._consolidated_summary = dict()
  108. self._sampled_history = defaultdict(sample.UniformSampleAccumulator)
  109. self._run_proto = None
  110. self._partial_history = dict()
  111. self._metric_defines = defaultdict(MetricRecord)
  112. self._metric_globs = defaultdict(MetricRecord)
  113. self._metric_track = dict()
  114. self._metric_copy = dict()
  115. self._internal_messages = InternalMessages()
  116. self._dropped_history = False
  117. def __len__(self) -> int:
  118. return self._record_q.qsize()
  119. def handle(self, record: Record) -> None:
  120. self._context_keeper.add_from_record(record)
  121. record_type = record.WhichOneof("record_type")
  122. assert record_type
  123. handler_str = "handle_" + record_type
  124. handler: Callable[[Record], None] = getattr(self, handler_str, None) # type: ignore
  125. assert handler, f"unknown handle: {handler_str}" # type: ignore
  126. handler(record)
  127. def handle_request(self, record: Record) -> None:
  128. request_type = record.request.WhichOneof("request_type")
  129. assert request_type
  130. handler_str = "handle_request_" + request_type
  131. handler: Callable[[Record], None] = getattr(self, handler_str, None) # type: ignore
  132. if request_type != "network_status":
  133. logger.debug(f"handle_request: {request_type}")
  134. assert handler, f"unknown handle: {handler_str}" # type: ignore
  135. handler(record)
  136. def _dispatch_record(self, record: Record, always_send: bool = False) -> None:
  137. if always_send:
  138. record.control.always_send = True
  139. self._writer_q.put(record)
  140. def _respond_result(self, result: Result) -> None:
  141. context_id = context.context_id_from_result(result)
  142. self._context_keeper.release(context_id)
  143. self._result_q.put(result)
  144. def debounce(self) -> None:
  145. pass
  146. def handle_request_cancel(self, record: Record) -> None:
  147. self._dispatch_record(record)
  148. def handle_request_defer(self, record: Record) -> None:
  149. defer = record.request.defer
  150. state = defer.state
  151. logger.info(f"handle defer: {state}")
  152. if state == defer.FLUSH_TB:
  153. if self._tb_watcher:
  154. # shutdown tensorboard workers so we get all metrics flushed
  155. self._tb_watcher.finish()
  156. self._tb_watcher = None
  157. elif state == defer.FLUSH_PARTIAL_HISTORY:
  158. self._flush_partial_history()
  159. elif state == defer.FLUSH_SUM:
  160. self._save_summary(self._consolidated_summary, flush=True)
  161. # defer is used to drive the sender finish state machine
  162. self._dispatch_record(record, always_send=True)
  163. def handle_request_python_packages(self, record: Record) -> None:
  164. self._dispatch_record(record)
  165. def handle_run(self, record: Record) -> None:
  166. if self._settings._offline:
  167. self._run_proto = record.run
  168. result = proto_util._result_from_record(record)
  169. result.run_result.run.CopyFrom(record.run)
  170. self._respond_result(result)
  171. self._dispatch_record(record)
  172. def handle_stats(self, record: Record) -> None:
  173. self._dispatch_record(record)
  174. def handle_config(self, record: Record) -> None:
  175. self._dispatch_record(record)
  176. def handle_output(self, record: Record) -> None:
  177. self._dispatch_record(record)
  178. def handle_output_raw(self, record: Record) -> None:
  179. self._dispatch_record(record)
  180. def handle_files(self, record: Record) -> None:
  181. self._dispatch_record(record)
  182. def handle_request_link_artifact(self, record: Record) -> None:
  183. self._dispatch_record(record)
  184. def handle_use_artifact(self, record: Record) -> None:
  185. self._dispatch_record(record)
  186. def handle_artifact(self, record: Record) -> None:
  187. self._dispatch_record(record)
  188. def handle_alert(self, record: Record) -> None:
  189. self._dispatch_record(record)
  190. def _save_summary(self, summary_dict: SummaryDict, flush: bool = False) -> None:
  191. summary = SummaryRecord()
  192. for k, v in summary_dict.items():
  193. update = summary.update.add()
  194. update.key = k
  195. update.value_json = json.dumps(v)
  196. if flush:
  197. record = Record(summary=summary)
  198. self._dispatch_record(record)
  199. elif not self._settings._offline:
  200. # Send this summary update as a request since we aren't persisting every update
  201. summary_record = SummaryRecordRequest(summary=summary)
  202. request_record = self._interface._make_request(
  203. summary_record=summary_record
  204. )
  205. self._dispatch_record(request_record)
  206. def _save_history(
  207. self,
  208. history: HistoryRecord,
  209. ) -> None:
  210. for item in history.item:
  211. # TODO(jhr) save nested keys?
  212. k = item.key
  213. v = json.loads(item.value_json)
  214. if isinstance(v, numbers.Real):
  215. self._sampled_history[k].add(v)
  216. def _update_summary_metrics(
  217. self,
  218. s: "MetricSummary",
  219. kl: List[str],
  220. v: "numbers.Real",
  221. float_v: float,
  222. goal_max: Optional[bool],
  223. ) -> bool:
  224. updated = False
  225. best_key: Optional[Tuple[str, ...]] = None
  226. if s.none:
  227. return False
  228. if s.copy:
  229. # non-key list copy already done in _update_summary
  230. if len(kl) > 1:
  231. _dict_nested_set(self._consolidated_summary, kl, v)
  232. return True
  233. if s.last:
  234. last_key = tuple(kl + ["last"])
  235. old_last = self._metric_track.get(last_key)
  236. if old_last is None or float_v != old_last:
  237. self._metric_track[last_key] = float_v
  238. _dict_nested_set(self._consolidated_summary, last_key, v)
  239. updated = True
  240. if s.best:
  241. best_key = tuple(kl + ["best"])
  242. if s.max or best_key and goal_max:
  243. max_key = tuple(kl + ["max"])
  244. old_max = self._metric_track.get(max_key)
  245. if old_max is None or float_v > old_max:
  246. self._metric_track[max_key] = float_v
  247. if s.max:
  248. _dict_nested_set(self._consolidated_summary, max_key, v)
  249. updated = True
  250. if best_key:
  251. _dict_nested_set(self._consolidated_summary, best_key, v)
  252. updated = True
  253. # defaulting to minimize if goal is not specified
  254. if s.min or best_key and not goal_max:
  255. min_key = tuple(kl + ["min"])
  256. old_min = self._metric_track.get(min_key)
  257. if old_min is None or float_v < old_min:
  258. self._metric_track[min_key] = float_v
  259. if s.min:
  260. _dict_nested_set(self._consolidated_summary, min_key, v)
  261. updated = True
  262. if best_key:
  263. _dict_nested_set(self._consolidated_summary, best_key, v)
  264. updated = True
  265. if s.mean:
  266. tot_key = tuple(kl + ["tot"])
  267. num_key = tuple(kl + ["num"])
  268. avg_key = tuple(kl + ["mean"])
  269. tot = self._metric_track.get(tot_key, 0.0)
  270. num = self._metric_track.get(num_key, 0)
  271. tot += float_v
  272. num += 1
  273. self._metric_track[tot_key] = tot
  274. self._metric_track[num_key] = num
  275. _dict_nested_set(self._consolidated_summary, avg_key, tot / num)
  276. updated = True
  277. return updated
  278. def _update_summary_leaf(
  279. self,
  280. kl: List[str],
  281. v: Any,
  282. d: Optional[MetricRecord] = None,
  283. ) -> bool:
  284. has_summary = d and d.HasField("summary")
  285. if len(kl) == 1:
  286. copy_key = tuple(kl)
  287. old_copy = self._metric_copy.get(copy_key)
  288. if old_copy is None or v != old_copy:
  289. self._metric_copy[copy_key] = v
  290. # Store copy metric if not specified, or copy behavior
  291. if not has_summary or (d and d.summary.copy):
  292. self._consolidated_summary[kl[0]] = v
  293. return True
  294. if not d:
  295. return False
  296. if not has_summary:
  297. return False
  298. if not isinstance(v, numbers.Real):
  299. return False
  300. if math.isnan(v):
  301. return False
  302. float_v = float(v)
  303. goal_max = None
  304. if d.goal:
  305. goal_max = d.goal == d.GOAL_MAXIMIZE
  306. if self._update_summary_metrics(
  307. d.summary, kl=kl, v=v, float_v=float_v, goal_max=goal_max
  308. ):
  309. return True
  310. return False
  311. def _update_summary_list(
  312. self,
  313. kl: List[str],
  314. v: Any,
  315. d: Optional[MetricRecord] = None,
  316. ) -> bool:
  317. metric_key = ".".join([k.replace(".", "\\.") for k in kl])
  318. d = self._metric_defines.get(metric_key, d)
  319. # if the dict has _type key, it's a wandb table object
  320. if isinstance(v, dict) and not handler_util.metric_is_wandb_dict(v):
  321. updated = False
  322. for nk, nv in v.items():
  323. if self._update_summary_list(kl=kl[:] + [nk], v=nv, d=d):
  324. updated = True
  325. return updated
  326. # If the dict is a media object, update the pointer to the latest alias
  327. elif (
  328. REPLACE_SUMMARY_ART_PATH_WITH_LATEST
  329. and isinstance(v, dict)
  330. and handler_util.metric_is_wandb_dict(v)
  331. ):
  332. if "_latest_artifact_path" in v and "artifact_path" in v:
  333. # TODO: Make non-destructive?
  334. v["artifact_path"] = v["_latest_artifact_path"]
  335. updated = self._update_summary_leaf(kl=kl, v=v, d=d)
  336. return updated
  337. def _update_summary_media_objects(self, v: Dict[str, Any]) -> Dict[str, Any]:
  338. # For now, non-recursive - just top level
  339. for nk, nv in v.items():
  340. if REPLACE_SUMMARY_ART_PATH_WITH_LATEST and (
  341. isinstance(nv, dict)
  342. and handler_util.metric_is_wandb_dict(nv)
  343. and "_latest_artifact_path" in nv
  344. and "artifact_path" in nv
  345. ):
  346. # TODO: Make non-destructive?
  347. nv["artifact_path"] = nv["_latest_artifact_path"]
  348. v[nk] = nv
  349. return v
  350. def _update_summary(self, history_dict: Dict[str, Any]) -> List[str]:
  351. # keep old behavior fast path if no define metrics have been used
  352. if not self._metric_defines:
  353. history_dict = self._update_summary_media_objects(history_dict)
  354. self._consolidated_summary.update(history_dict)
  355. return list(history_dict.keys())
  356. updated_keys = []
  357. for k, v in history_dict.items():
  358. if self._update_summary_list(kl=[k], v=v):
  359. updated_keys.append(k)
  360. return updated_keys
  361. def _history_assign_step(
  362. self,
  363. history: HistoryRecord,
  364. history_dict: Dict[str, Any],
  365. ) -> None:
  366. has_step = history.HasField("step")
  367. item = history.item.add()
  368. item.key = "_step"
  369. if has_step:
  370. step = history.step.num
  371. history_dict["_step"] = step
  372. item.value_json = json.dumps(step)
  373. self._step = step + 1
  374. else:
  375. history_dict["_step"] = self._step
  376. item.value_json = json.dumps(self._step)
  377. self._step += 1
  378. def _history_define_metric(self, hkey: str) -> Optional[MetricRecord]:
  379. """Check for hkey match in glob metrics and return the defined metric."""
  380. # Dont define metric for internal metrics
  381. if hkey.startswith("_"):
  382. return None
  383. for k, mglob in self._metric_globs.items():
  384. if k.endswith("*"):
  385. if hkey.startswith(k[:-1]):
  386. m = MetricRecord()
  387. m.CopyFrom(mglob)
  388. m.ClearField("glob_name")
  389. m.options.defined = False
  390. m.name = hkey
  391. return m
  392. return None
  393. def _history_update_leaf(
  394. self,
  395. kl: List[str],
  396. v: Any,
  397. history_dict: Dict[str, Any],
  398. update_history: Dict[str, Any],
  399. ) -> None:
  400. hkey = ".".join([k.replace(".", "\\.") for k in kl])
  401. m = self._metric_defines.get(hkey)
  402. if not m:
  403. m = self._history_define_metric(hkey)
  404. if not m:
  405. return
  406. mr = Record()
  407. mr.metric.CopyFrom(m)
  408. mr.control.local = True # Dont store this, just send it
  409. self._handle_defined_metric(mr)
  410. if m.options.step_sync and m.step_metric:
  411. if m.step_metric not in history_dict:
  412. copy_key = tuple([m.step_metric])
  413. step = self._metric_copy.get(copy_key)
  414. if step is not None:
  415. update_history[m.step_metric] = step
  416. def _history_update_list(
  417. self,
  418. kl: List[str],
  419. v: Any,
  420. history_dict: Dict[str, Any],
  421. update_history: Dict[str, Any],
  422. ) -> None:
  423. if isinstance(v, dict):
  424. for nk, nv in v.items():
  425. self._history_update_list(
  426. kl=kl[:] + [nk],
  427. v=nv,
  428. history_dict=history_dict,
  429. update_history=update_history,
  430. )
  431. return
  432. self._history_update_leaf(
  433. kl=kl, v=v, history_dict=history_dict, update_history=update_history
  434. )
  435. def _history_update(
  436. self,
  437. history: HistoryRecord,
  438. history_dict: Dict[str, Any],
  439. ) -> None:
  440. # if syncing an old run, we can skip this logic
  441. if history_dict.get("_step") is None:
  442. self._history_assign_step(history, history_dict)
  443. update_history: Dict[str, Any] = {}
  444. # Look for metric matches
  445. if self._metric_defines or self._metric_globs:
  446. for hkey, hval in history_dict.items():
  447. self._history_update_list([hkey], hval, history_dict, update_history)
  448. if update_history:
  449. history_dict.update(update_history)
  450. for k, v in update_history.items():
  451. item = history.item.add()
  452. item.key = k
  453. item.value_json = json.dumps(v)
  454. def handle_history(self, record: Record) -> None:
  455. history_dict = proto_util.dict_from_proto_list(record.history.item)
  456. # Inject _runtime if it is not present
  457. if history_dict is not None:
  458. if "_runtime" not in history_dict:
  459. self._history_assign_runtime(record.history, history_dict)
  460. self._history_update(record.history, history_dict)
  461. self._dispatch_record(record)
  462. self._save_history(record.history)
  463. # update summary from history
  464. updated_keys = self._update_summary(history_dict)
  465. if updated_keys:
  466. updated_items = {k: self._consolidated_summary[k] for k in updated_keys}
  467. self._save_summary(updated_items)
  468. def _flush_partial_history(
  469. self,
  470. step: Optional[int] = None,
  471. ) -> None:
  472. if not self._partial_history:
  473. return
  474. history = HistoryRecord()
  475. for k, v in self._partial_history.items():
  476. item = history.item.add()
  477. item.key = k
  478. item.value_json = json.dumps(v)
  479. if step is not None:
  480. history.step.num = step
  481. self.handle_history(Record(history=history))
  482. self._partial_history = {}
  483. def handle_request_sender_mark_report(self, record: Record) -> None:
  484. self._dispatch_record(record, always_send=True)
  485. def handle_request_status_report(self, record: Record) -> None:
  486. self._dispatch_record(record, always_send=True)
  487. def handle_request_partial_history(self, record: Record) -> None:
  488. partial_history = record.request.partial_history
  489. flush = None
  490. if partial_history.HasField("action"):
  491. flush = partial_history.action.flush
  492. step = None
  493. if partial_history.HasField("step"):
  494. step = partial_history.step.num
  495. history_dict = proto_util.dict_from_proto_list(partial_history.item)
  496. if step is not None:
  497. if step < self._step:
  498. if not self._dropped_history:
  499. message = (
  500. "Step only supports monotonically increasing values, use define_metric to set a custom x "
  501. f"axis. For details see: {url_registry.url('define-metric')}"
  502. )
  503. self._internal_messages.warning.append(message)
  504. self._dropped_history = True
  505. message = (
  506. f"(User provided step: {step} is less than current step: {self._step}. "
  507. f"Dropping entry: {history_dict})."
  508. )
  509. self._internal_messages.warning.append(message)
  510. return
  511. elif step > self._step:
  512. self._flush_partial_history()
  513. self._step = step
  514. elif flush is None:
  515. flush = True
  516. self._partial_history.update(history_dict)
  517. if flush:
  518. self._flush_partial_history(self._step)
  519. def handle_summary(self, record: Record) -> None:
  520. summary = record.summary
  521. for item in summary.update:
  522. if len(item.nested_key) > 0:
  523. # we use either key or nested_key -- not both
  524. assert item.key == ""
  525. key = tuple(item.nested_key)
  526. else:
  527. # no counter-assertion here, because technically
  528. # summary[""] is valid
  529. key = (item.key,)
  530. target = self._consolidated_summary
  531. # recurse down the dictionary structure:
  532. for prop in key[:-1]:
  533. target = target[prop]
  534. # use the last element of the key to write the leaf:
  535. target[key[-1]] = json.loads(item.value_json)
  536. for item in summary.remove:
  537. if len(item.nested_key) > 0:
  538. # we use either key or nested_key -- not both
  539. assert item.key == ""
  540. key = tuple(item.nested_key)
  541. else:
  542. # no counter-assertion here, because technically
  543. # summary[""] is valid
  544. key = (item.key,)
  545. target = self._consolidated_summary
  546. # recurse down the dictionary structure:
  547. for prop in key[:-1]:
  548. target = target[prop]
  549. # use the last element of the key to erase the leaf:
  550. del target[key[-1]]
  551. self._save_summary(self._consolidated_summary)
  552. def handle_exit(self, record: Record) -> None:
  553. if self._track_time is not None:
  554. self._accumulate_time += time.time() - self._track_time
  555. record.exit.runtime = int(self._accumulate_time)
  556. self._dispatch_record(record, always_send=True)
  557. def handle_final(self, record: Record) -> None:
  558. self._dispatch_record(record, always_send=True)
  559. def handle_preempting(self, record: Record) -> None:
  560. self._dispatch_record(record)
  561. def handle_header(self, record: Record) -> None:
  562. self._dispatch_record(record)
  563. def handle_footer(self, record: Record) -> None:
  564. self._dispatch_record(record)
  565. def handle_metadata(self, record: Record) -> None:
  566. self._dispatch_record(record)
  567. def handle_request_attach(self, record: Record) -> None:
  568. result = proto_util._result_from_record(record)
  569. attach_id = record.request.attach.attach_id
  570. assert attach_id
  571. assert self._run_proto
  572. result.response.attach_response.run.CopyFrom(self._run_proto)
  573. self._respond_result(result)
  574. def handle_request_log_artifact(self, record: Record) -> None:
  575. self._dispatch_record(record)
  576. def handle_telemetry(self, record: Record) -> None:
  577. self._dispatch_record(record)
  578. def handle_request_run_start(self, record: Record) -> None:
  579. run_start = record.request.run_start
  580. assert run_start
  581. assert run_start.run
  582. self._run_proto = run_start.run
  583. self._run_start_time = run_start.run.start_time.ToMicroseconds() / 1e6
  584. self._track_time = time.time()
  585. if run_start.run.resumed and run_start.run.runtime:
  586. self._accumulate_time = run_start.run.runtime
  587. else:
  588. self._accumulate_time = 0
  589. self._tb_watcher = tb_watcher.TBWatcher(
  590. self._settings, interface=self._interface, run_proto=run_start.run
  591. )
  592. if run_start.run.resumed or run_start.run.forked:
  593. self._step = run_start.run.starting_step
  594. result = proto_util._result_from_record(record)
  595. self._respond_result(result)
  596. def handle_request_resume(self, record: Record) -> None:
  597. if self._track_time is not None:
  598. self._accumulate_time += time.time() - self._track_time
  599. self._track_time = time.time()
  600. def handle_request_pause(self, record: Record) -> None:
  601. if self._track_time is not None:
  602. self._accumulate_time += time.time() - self._track_time
  603. self._track_time = None
  604. def handle_request_poll_exit(self, record: Record) -> None:
  605. self._dispatch_record(record, always_send=True)
  606. def handle_request_stop_status(self, record: Record) -> None:
  607. self._dispatch_record(record)
  608. def handle_request_network_status(self, record: Record) -> None:
  609. self._dispatch_record(record)
  610. def handle_request_internal_messages(self, record: Record) -> None:
  611. result = proto_util._result_from_record(record)
  612. result.response.internal_messages_response.messages.CopyFrom(
  613. self._internal_messages
  614. )
  615. self._internal_messages.Clear()
  616. self._respond_result(result)
  617. def handle_request_status(self, record: Record) -> None:
  618. result = proto_util._result_from_record(record)
  619. self._respond_result(result)
  620. def handle_request_get_summary(self, record: Record) -> None:
  621. result = proto_util._result_from_record(record)
  622. for key, value in self._consolidated_summary.items():
  623. item = SummaryItem()
  624. item.key = key
  625. item.value_json = json.dumps(value)
  626. result.response.get_summary_response.item.append(item)
  627. self._respond_result(result)
  628. def handle_tbrecord(self, record: Record) -> None:
  629. logger.info("handling tbrecord: %s", record)
  630. if self._tb_watcher:
  631. tbrecord = record.tbrecord
  632. self._tb_watcher.add(tbrecord.log_dir, tbrecord.save, tbrecord.root_dir)
  633. self._dispatch_record(record)
  634. def _handle_defined_metric(self, record: Record) -> None:
  635. metric = record.metric
  636. if metric._control.overwrite:
  637. self._metric_defines[metric.name].CopyFrom(metric)
  638. else:
  639. self._metric_defines[metric.name].MergeFrom(metric)
  640. # before dispatching, make sure step_metric is defined, if not define it and
  641. # dispatch it locally first
  642. metric = self._metric_defines[metric.name]
  643. if metric.step_metric and metric.step_metric not in self._metric_defines:
  644. m = MetricRecord(name=metric.step_metric)
  645. self._metric_defines[metric.step_metric] = m
  646. mr = Record()
  647. mr.metric.CopyFrom(m)
  648. mr.control.local = True # Don't store this, just send it
  649. self._dispatch_record(mr)
  650. self._dispatch_record(record)
  651. def _handle_glob_metric(self, record: Record) -> None:
  652. metric = record.metric
  653. if metric._control.overwrite:
  654. self._metric_globs[metric.glob_name].CopyFrom(metric)
  655. else:
  656. self._metric_globs[metric.glob_name].MergeFrom(metric)
  657. self._dispatch_record(record)
  658. def handle_metric(self, record: Record) -> None:
  659. """Handle MetricRecord.
  660. Walkthrough of the life of a MetricRecord:
  661. Metric defined:
  662. - run.define_metric() parses arguments create wandb_metric.Metric
  663. - build MetricRecord publish to interface
  664. - handler (this function) keeps list of metrics published:
  665. - self._metric_defines: Fully defined metrics
  666. - self._metric_globs: metrics that have a wildcard
  667. - dispatch writer and sender thread
  668. - writer: records are saved to persistent store
  669. - sender: fully defined metrics get mapped into metadata for UI
  670. History logged:
  671. - handle_history
  672. - check if metric matches _metric_defines
  673. - if not, check if metric matches _metric_globs
  674. - if _metric globs match, generate defined metric and call _handle_metric
  675. Args:
  676. record (Record): Metric record to process
  677. """
  678. if record.metric.name:
  679. self._handle_defined_metric(record)
  680. elif record.metric.glob_name:
  681. self._handle_glob_metric(record)
  682. def handle_request_sampled_history(self, record: Record) -> None:
  683. result = proto_util._result_from_record(record)
  684. for key, sampled in self._sampled_history.items():
  685. item = SampledHistoryItem()
  686. item.key = key
  687. values: Iterable[Any] = sampled.get()
  688. if all(isinstance(i, numbers.Integral) for i in values):
  689. try:
  690. item.values_int.extend(values)
  691. except ValueError:
  692. # it is safe to ignore these as this is for display information
  693. pass
  694. elif all(isinstance(i, numbers.Real) for i in values):
  695. item.values_float.extend(values)
  696. result.response.sampled_history_response.item.append(item)
  697. self._respond_result(result)
  698. def handle_request_keepalive(self, record: Record) -> None:
  699. """Handle a keepalive request.
  700. Keepalive is a noop, we just want to verify transport is alive.
  701. """
  702. def handle_request_run_status(self, record: Record) -> None:
  703. self._dispatch_record(record, always_send=True)
  704. def handle_request_shutdown(self, record: Record) -> None:
  705. # TODO(jhr): should we drain things and stop new requests from coming in?
  706. result = proto_util._result_from_record(record)
  707. self._respond_result(result)
  708. self._stopped.set()
  709. def handle_request_operations(self, record: Record) -> None:
  710. """No-op. Not implemented for the legacy-service."""
  711. self._respond_result(proto_util._result_from_record(record))
  712. def finish(self) -> None:
  713. logger.info("shutting down handler")
  714. if self._tb_watcher:
  715. self._tb_watcher.finish()
  716. # self._context_keeper._debug_print_orphans()
  717. def __next__(self) -> Record:
  718. return self._record_q.get(block=True)
  719. next = __next__
  720. def _history_assign_runtime(
  721. self,
  722. history: HistoryRecord,
  723. history_dict: Dict[str, Any],
  724. ) -> None:
  725. # _runtime calculation is meaningless if there is no _timestamp
  726. if "_timestamp" not in history_dict:
  727. return
  728. # if it is offline sync, self._run_start_time is None
  729. # in that case set it to the first tfevent timestamp
  730. if self._run_start_time is None:
  731. self._run_start_time = history_dict["_timestamp"]
  732. history_dict["_runtime"] = history_dict["_timestamp"] - self._run_start_time
  733. item = history.item.add()
  734. item.key = "_runtime"
  735. item.value_json = json.dumps(history_dict[item.key])