ipython.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140
  1. import logging
  2. import sys
  3. import warnings
  4. from typing import Literal, Optional
  5. import wandb
  6. PythonType = Literal["python", "ipython", "jupyter"]
  7. logger = logging.getLogger(__name__)
  8. def toggle_button(what="run"):
  9. """Returns the HTML for a button used to reveal the element following it.
  10. The element immediately after the button must have `display: none`.
  11. """
  12. return (
  13. "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">"
  14. f"Display W&B {what}"
  15. "</button>"
  16. )
  17. def _get_python_type() -> PythonType:
  18. if "IPython" not in sys.modules:
  19. return "python"
  20. try:
  21. from IPython import get_ipython # type: ignore
  22. # Calling get_ipython can cause an ImportError
  23. if get_ipython() is None:
  24. return "python"
  25. except ImportError:
  26. return "python"
  27. # jupyter-based environments (e.g. jupyter itself, colab, kaggle, etc) have a connection file
  28. ip_kernel_app_connection_file = (
  29. (get_ipython().config.get("IPKernelApp", {}) or {})
  30. .get("connection_file", "")
  31. .lower()
  32. ) or (
  33. (get_ipython().config.get("ColabKernelApp", {}) or {})
  34. .get("connection_file", "")
  35. .lower()
  36. )
  37. if (
  38. ("terminal" in get_ipython().__module__)
  39. or ("jupyter" not in ip_kernel_app_connection_file)
  40. or ("spyder" in sys.modules)
  41. ):
  42. return "ipython"
  43. else:
  44. return "jupyter"
  45. def in_jupyter() -> bool:
  46. """Returns True if we're in a Jupyter notebook."""
  47. return _get_python_type() == "jupyter"
  48. def in_ipython() -> bool:
  49. """Returns True if we're running in IPython in the terminal."""
  50. return _get_python_type() == "ipython"
  51. def in_notebook() -> bool:
  52. """Returns True if we're running in Jupyter or IPython."""
  53. return _get_python_type() != "python"
  54. def in_vscode_notebook() -> bool:
  55. """Returns True if we're in a VSCode notebook."""
  56. try:
  57. from IPython import get_ipython
  58. except ModuleNotFoundError:
  59. return False
  60. ipython = get_ipython()
  61. if not ipython:
  62. return False
  63. return ipython.kernel.shell.user_ns.get("__vsc_ipynb_file__") is not None
  64. class ProgressWidget:
  65. """A simple wrapper to render a nice progress bar with a label."""
  66. def __init__(self, widgets, min, max):
  67. from IPython import display
  68. self._ipython_display = display
  69. self.widgets = widgets
  70. self._progress = widgets.FloatProgress(min=min, max=max)
  71. self._label = widgets.Label()
  72. self._widget = self.widgets.VBox([self._label, self._progress])
  73. self._displayed = False
  74. self._disabled = False
  75. def update(self, value: float, label: str) -> None:
  76. if self._disabled:
  77. return
  78. try:
  79. self._progress.value = value
  80. self._label.value = label
  81. if not self._displayed:
  82. self._displayed = True
  83. self._ipython_display.display(self._widget)
  84. except Exception:
  85. logger.exception("Error in ProgressWidget.update()")
  86. self._disabled = True
  87. wandb.termwarn(
  88. "Unable to render progress bar, see the user log for details"
  89. )
  90. def close(self) -> None:
  91. if self._disabled or not self._displayed:
  92. return
  93. self._widget.close()
  94. def jupyter_progress_bar(min: float = 0, max: float = 1.0) -> Optional[ProgressWidget]:
  95. """Return an ipywidget progress bar or None if we can't import it."""
  96. widgets = wandb.util.get_module("ipywidgets")
  97. try:
  98. if widgets is None:
  99. # TODO: this currently works in iPython but it's deprecated since 4.0
  100. with warnings.catch_warnings():
  101. warnings.simplefilter("ignore")
  102. from IPython.html import widgets # type: ignore
  103. assert hasattr(widgets, "VBox")
  104. assert hasattr(widgets, "Label")
  105. assert hasattr(widgets, "FloatProgress")
  106. return ProgressWidget(widgets, min=min, max=max)
  107. except (ImportError, AssertionError):
  108. return None