_tensorboard_logger.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190
  1. # Copyright 2023 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Contains a logger to push training logs to the Hub, using Tensorboard."""
  15. from pathlib import Path
  16. from typing import List, Optional, Union
  17. from ._commit_scheduler import CommitScheduler
  18. from .errors import EntryNotFoundError
  19. from .repocard import ModelCard
  20. from .utils import experimental
  21. # Depending on user's setup, SummaryWriter can come either from 'tensorboardX'
  22. # or from 'torch.utils.tensorboard'. Both are compatible so let's try to load
  23. # from either of them.
  24. try:
  25. from tensorboardX import SummaryWriter as _RuntimeSummaryWriter
  26. is_summary_writer_available = True
  27. except ImportError:
  28. try:
  29. from torch.utils.tensorboard import SummaryWriter as _RuntimeSummaryWriter
  30. is_summary_writer_available = True
  31. except ImportError:
  32. # Dummy class to avoid failing at import. Will raise on instance creation.
  33. class _DummySummaryWriter:
  34. pass
  35. _RuntimeSummaryWriter = _DummySummaryWriter # type: ignore[assignment]
  36. is_summary_writer_available = False
  37. class HFSummaryWriter(_RuntimeSummaryWriter):
  38. """
  39. Wrapper around the tensorboard's `SummaryWriter` to push training logs to the Hub.
  40. Data is logged locally and then pushed to the Hub asynchronously. Pushing data to the Hub is done in a separate
  41. thread to avoid blocking the training script. In particular, if the upload fails for any reason (e.g. a connection
  42. issue), the main script will not be interrupted. Data is automatically pushed to the Hub every `commit_every`
  43. minutes (default to every 5 minutes).
  44. > [!WARNING]
  45. > `HFSummaryWriter` is experimental. Its API is subject to change in the future without prior notice.
  46. Args:
  47. repo_id (`str`):
  48. The id of the repo to which the logs will be pushed.
  49. logdir (`str`, *optional*):
  50. The directory where the logs will be written. If not specified, a local directory will be created by the
  51. underlying `SummaryWriter` object.
  52. commit_every (`int` or `float`, *optional*):
  53. The frequency (in minutes) at which the logs will be pushed to the Hub. Defaults to 5 minutes.
  54. squash_history (`bool`, *optional*):
  55. Whether to squash the history of the repo after each commit. Defaults to `False`. Squashing commits is
  56. useful to avoid degraded performances on the repo when it grows too large.
  57. repo_type (`str`, *optional*):
  58. The type of the repo to which the logs will be pushed. Defaults to "model".
  59. repo_revision (`str`, *optional*):
  60. The revision of the repo to which the logs will be pushed. Defaults to "main".
  61. repo_private (`bool`, *optional*):
  62. Whether to make the repo private. If `None` (default), the repo will be public unless the organization's default is private. This value is ignored if the repo already exists.
  63. path_in_repo (`str`, *optional*):
  64. The path to the folder in the repo where the logs will be pushed. Defaults to "tensorboard/".
  65. repo_allow_patterns (`List[str]` or `str`, *optional*):
  66. A list of patterns to include in the upload. Defaults to `"*.tfevents.*"`. Check out the
  67. [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
  68. repo_ignore_patterns (`List[str]` or `str`, *optional*):
  69. A list of patterns to exclude in the upload. Check out the
  70. [upload guide](https://huggingface.co/docs/huggingface_hub/guides/upload#upload-a-folder) for more details.
  71. token (`str`, *optional*):
  72. Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
  73. details
  74. kwargs:
  75. Additional keyword arguments passed to `SummaryWriter`.
  76. Examples:
  77. ```diff
  78. # Taken from https://pytorch.org/docs/stable/tensorboard.html
  79. - from torch.utils.tensorboard import SummaryWriter
  80. + from huggingface_hub import HFSummaryWriter
  81. import numpy as np
  82. - writer = SummaryWriter()
  83. + writer = HFSummaryWriter(repo_id="username/my-trained-model")
  84. for n_iter in range(100):
  85. writer.add_scalar('Loss/train', np.random.random(), n_iter)
  86. writer.add_scalar('Loss/test', np.random.random(), n_iter)
  87. writer.add_scalar('Accuracy/train', np.random.random(), n_iter)
  88. writer.add_scalar('Accuracy/test', np.random.random(), n_iter)
  89. ```
  90. ```py
  91. >>> from huggingface_hub import HFSummaryWriter
  92. # Logs are automatically pushed every 15 minutes (5 by default) + when exiting the context manager
  93. >>> with HFSummaryWriter(repo_id="test_hf_logger", commit_every=15) as logger:
  94. ... logger.add_scalar("a", 1)
  95. ... logger.add_scalar("b", 2)
  96. ```
  97. """
  98. @experimental
  99. def __new__(cls, *args, **kwargs) -> "HFSummaryWriter":
  100. if not is_summary_writer_available:
  101. raise ImportError(
  102. "You must have `tensorboard` installed to use `HFSummaryWriter`. Please run `pip install --upgrade"
  103. " tensorboardX` first."
  104. )
  105. return super().__new__(cls)
  106. def __init__(
  107. self,
  108. repo_id: str,
  109. *,
  110. logdir: Optional[str] = None,
  111. commit_every: Union[int, float] = 5,
  112. squash_history: bool = False,
  113. repo_type: Optional[str] = None,
  114. repo_revision: Optional[str] = None,
  115. repo_private: Optional[bool] = None,
  116. path_in_repo: Optional[str] = "tensorboard",
  117. repo_allow_patterns: Optional[Union[List[str], str]] = "*.tfevents.*",
  118. repo_ignore_patterns: Optional[Union[List[str], str]] = None,
  119. token: Optional[str] = None,
  120. **kwargs,
  121. ):
  122. # Initialize SummaryWriter
  123. super().__init__(logdir=logdir, **kwargs)
  124. # Check logdir has been correctly initialized and fail early otherwise. In practice, SummaryWriter takes care of it.
  125. if not isinstance(self.logdir, str):
  126. raise ValueError(f"`self.logdir` must be a string. Got '{self.logdir}' of type {type(self.logdir)}.")
  127. # Append logdir name to `path_in_repo`
  128. if path_in_repo is None or path_in_repo == "":
  129. path_in_repo = Path(self.logdir).name
  130. else:
  131. path_in_repo = path_in_repo.strip("/") + "/" + Path(self.logdir).name
  132. # Initialize scheduler
  133. self.scheduler = CommitScheduler(
  134. folder_path=self.logdir,
  135. path_in_repo=path_in_repo,
  136. repo_id=repo_id,
  137. repo_type=repo_type,
  138. revision=repo_revision,
  139. private=repo_private,
  140. token=token,
  141. allow_patterns=repo_allow_patterns,
  142. ignore_patterns=repo_ignore_patterns,
  143. every=commit_every,
  144. squash_history=squash_history,
  145. )
  146. # Exposing some high-level info at root level
  147. self.repo_id = self.scheduler.repo_id
  148. self.repo_type = self.scheduler.repo_type
  149. self.repo_revision = self.scheduler.revision
  150. # Add `hf-summary-writer` tag to the model card metadata
  151. try:
  152. card = ModelCard.load(repo_id_or_path=self.repo_id, repo_type=self.repo_type)
  153. except EntryNotFoundError:
  154. card = ModelCard("")
  155. tags = card.data.get("tags", [])
  156. if "hf-summary-writer" not in tags:
  157. tags.append("hf-summary-writer")
  158. card.data["tags"] = tags
  159. card.push_to_hub(repo_id=self.repo_id, repo_type=self.repo_type)
  160. def __exit__(self, exc_type, exc_val, exc_tb):
  161. """Push to hub in a non-blocking way when exiting the logger's context manager."""
  162. super().__exit__(exc_type, exc_val, exc_tb)
  163. future = self.scheduler.trigger()
  164. future.result()