context.py 2.5 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889
  1. """Context Keeper."""
  2. import logging
  3. import threading
  4. from typing import Dict, Optional
  5. from wandb.proto.wandb_internal_pb2 import Record, Result
  6. logger = logging.getLogger(__name__)
  7. class Context:
  8. _cancel_event: threading.Event
  9. # TODO(debug_context) add debug setting to enable this
  10. # _debug_record: Optional[Record]
  11. def __init__(self) -> None:
  12. self._cancel_event = threading.Event()
  13. # TODO(debug_context) see above
  14. # self._debug_record = None
  15. def cancel(self) -> None:
  16. self._cancel_event.set()
  17. @property
  18. def cancel_event(self) -> threading.Event:
  19. return self._cancel_event
  20. def context_id_from_record(record: Record) -> str:
  21. context_id = record.control.mailbox_slot
  22. return context_id
  23. def context_id_from_result(result: Result) -> str:
  24. context_id = result.control.mailbox_slot
  25. return context_id
  26. class ContextKeeper:
  27. _active_items: Dict[str, Context]
  28. def __init__(self) -> None:
  29. self._active_items = {}
  30. def add_from_record(self, record: Record) -> Optional[Context]:
  31. context_id = context_id_from_record(record)
  32. if not context_id:
  33. return None
  34. context_obj = self.add(context_id)
  35. # TODO(debug_context) see above
  36. # context_obj._debug_record = record
  37. return context_obj
  38. def add(self, context_id: str) -> Context:
  39. assert context_id
  40. context_obj = Context()
  41. self._active_items[context_id] = context_obj
  42. return context_obj
  43. def get(self, context_id: str) -> Optional[Context]:
  44. item = self._active_items.get(context_id)
  45. return item
  46. def release(self, context_id: str) -> None:
  47. if not context_id:
  48. return
  49. _ = self._active_items.pop(context_id, None)
  50. def cancel(self, context_id: str) -> bool:
  51. item = self.get(context_id)
  52. if item:
  53. item.cancel()
  54. return True
  55. return False
  56. # TODO(debug_context) see above
  57. # def _debug_print_orphans(self, print_to_stdout: bool) -> None:
  58. # for context_id, context in self._active_items.items():
  59. # record = context._debug_record
  60. # record_type = record.WhichOneof("record_type") if record else "unknown"
  61. # message = (
  62. # f"Context: {context_id} {context.cancel_event.is_set()} {record_type}"
  63. # )
  64. # logger.warning(message)
  65. # if print_to_stdout:
  66. # print(message)