gitlib.py 7.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. import configparser
  2. import logging
  3. import os
  4. from typing import TYPE_CHECKING, Any, Optional
  5. from urllib.parse import urlparse, urlunparse
  6. import wandb
  7. try:
  8. from git import ( # type: ignore
  9. GitCommandError,
  10. InvalidGitRepositoryError,
  11. NoSuchPathError,
  12. Repo,
  13. )
  14. except ImportError:
  15. Repo = None # type: ignore
  16. if TYPE_CHECKING:
  17. from git import Repo
  18. logger = logging.getLogger(__name__)
  19. class GitRepo:
  20. def __init__(
  21. self,
  22. root: Optional[str] = None,
  23. remote: str = "origin",
  24. lazy: bool = True,
  25. remote_url: Optional[str] = None,
  26. commit: Optional[str] = None,
  27. ) -> None:
  28. self.remote_name = remote if remote_url is None else None
  29. self._root = root
  30. self._remote_url = remote_url
  31. self._commit = commit
  32. self._repo = None
  33. self._repo_initialized = False
  34. if not lazy:
  35. self._repo = self._init_repo()
  36. def _init_repo(self) -> Optional[Repo]:
  37. self._repo_initialized = True
  38. if Repo is None:
  39. return None
  40. if self.remote_name is None:
  41. return None
  42. try:
  43. return Repo(self._root or os.getcwd(), search_parent_directories=True)
  44. except FileNotFoundError:
  45. wandb.termwarn("current working directory has been invalidated")
  46. logger.warning("current working directory has been invalidated")
  47. except InvalidGitRepositoryError:
  48. logger.debug("git repository is invalid")
  49. except NoSuchPathError:
  50. wandb.termwarn(f"git root {self._root} does not exist")
  51. logger.warning(f"git root {self._root} does not exist")
  52. return None
  53. @property
  54. def repo(self) -> Optional[Repo]:
  55. if not self._repo_initialized:
  56. self._repo = self._init_repo()
  57. return self._repo
  58. @property
  59. def auto(self) -> bool:
  60. return self._remote_url is None
  61. def is_untracked(self, file_name: str) -> Optional[bool]:
  62. if not self.repo:
  63. return True
  64. try:
  65. return file_name in self.repo.untracked_files
  66. except GitCommandError:
  67. return None
  68. @property
  69. def enabled(self) -> bool:
  70. return bool(self.repo)
  71. @property
  72. def root(self) -> Any:
  73. if not self.repo:
  74. return None
  75. try:
  76. return self.repo.git.rev_parse("--show-toplevel")
  77. except GitCommandError:
  78. # todo: collect telemetry on this
  79. logger.exception("git root error")
  80. return None
  81. @property
  82. def dirty(self) -> Any:
  83. if not self.repo:
  84. return False
  85. try:
  86. return self.repo.is_dirty()
  87. except GitCommandError:
  88. return False
  89. @property
  90. def email(self) -> Optional[str]:
  91. if not self.repo:
  92. return None
  93. try:
  94. return self.repo.config_reader().get_value("user", "email") # type: ignore
  95. except configparser.Error:
  96. return None
  97. @property
  98. def last_commit(self) -> Any:
  99. if self._commit:
  100. return self._commit
  101. if not self.repo:
  102. return None
  103. if not self.repo.head or not self.repo.head.is_valid():
  104. return None
  105. # TODO: Saw a user getting a Unicode decode error when parsing refs,
  106. # more details on implementing a real fix in [WB-4064]
  107. try:
  108. if len(self.repo.refs) > 0: # type: ignore[arg-type]
  109. return self.repo.head.commit.hexsha
  110. else:
  111. return self.repo.git.show_ref("--head").split(" ")[0]
  112. except Exception:
  113. logger.exception("Unable to find most recent commit in git")
  114. return None
  115. @property
  116. def branch(self) -> Any:
  117. if not self.repo:
  118. return None
  119. return self.repo.head.ref.name
  120. @property
  121. def remote(self) -> Any:
  122. if not self.repo:
  123. return None
  124. try:
  125. return self.repo.remotes[self.remote_name] # type: ignore[index]
  126. except IndexError:
  127. return None
  128. # the --submodule=diff option doesn't exist in pre-2.11 versions of git (november 2016)
  129. # https://stackoverflow.com/questions/10757091/git-list-of-all-changed-files-including-those-in-submodules
  130. @property
  131. def has_submodule_diff(self) -> bool:
  132. if not self.repo:
  133. return False
  134. return bool(self.repo.git.version_info >= (2, 11, 0))
  135. @property
  136. def remote_url(self) -> Any:
  137. if self._remote_url:
  138. return self._remote_url
  139. if not self.remote:
  140. return None
  141. parsed = urlparse(self.remote.url)
  142. hostname = parsed.hostname
  143. if parsed.port is not None:
  144. hostname = f"{hostname}:{parsed.port}"
  145. if parsed.password is not None:
  146. return urlunparse(parsed._replace(netloc=f"{parsed.username}:@{hostname}"))
  147. return urlunparse(parsed._replace(netloc=hostname))
  148. @property
  149. def root_dir(self) -> Any:
  150. if not self.repo:
  151. return None
  152. try:
  153. return self.repo.git.rev_parse("--show-toplevel")
  154. except GitCommandError:
  155. return None
  156. def get_upstream_fork_point(self) -> Any:
  157. """Get the most recent ancestor of HEAD that occurs on an upstream branch.
  158. First looks at the current branch's tracking branch, if applicable. If
  159. that doesn't work, looks at every other branch to find the most recent
  160. ancestor of HEAD that occurs on a tracking branch.
  161. Returns:
  162. git.Commit object or None
  163. """
  164. possible_relatives = []
  165. try:
  166. if not self.repo:
  167. return None
  168. try:
  169. active_branch = self.repo.active_branch
  170. except (TypeError, ValueError):
  171. logger.debug("git is in a detached head state")
  172. return None # detached head
  173. else:
  174. tracking_branch = active_branch.tracking_branch()
  175. if tracking_branch:
  176. possible_relatives.append(tracking_branch.commit)
  177. if not possible_relatives:
  178. for branch in self.repo.branches: # type: ignore[attr-defined]
  179. tracking_branch = branch.tracking_branch()
  180. if tracking_branch is not None:
  181. possible_relatives.append(tracking_branch.commit)
  182. head = self.repo.head
  183. most_recent_ancestor = None
  184. for possible_relative in possible_relatives:
  185. # at most one:
  186. for ancestor in self.repo.merge_base(head, possible_relative):
  187. if most_recent_ancestor is None:
  188. most_recent_ancestor = ancestor
  189. elif self.repo.is_ancestor(most_recent_ancestor, ancestor): # type: ignore
  190. most_recent_ancestor = ancestor
  191. except GitCommandError as e:
  192. logger.debug("git remote upstream fork point could not be found")
  193. logger.debug(str(e))
  194. return None
  195. return most_recent_ancestor
  196. def tag(self, name: str, message: Optional[str]) -> Any:
  197. if not self.repo:
  198. return None
  199. try:
  200. return self.repo.create_tag(f"wandb/{name}", message=message, force=True)
  201. except GitCommandError:
  202. logger.debug("Failed to tag repository.")
  203. return None
  204. def push(self, name: str) -> Any:
  205. if not self.remote:
  206. return None
  207. try:
  208. return self.remote.push(f"wandb/{name}", force=True)
  209. except GitCommandError:
  210. logger.debug("failed to push git")
  211. return None