hub.py 8.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303
  1. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import os
  15. import shutil
  16. import sys
  17. import zipfile
  18. from paddle.utils.download import get_path_from_url
  19. __all__ = []
  20. DEFAULT_CACHE_DIR = '~/.cache'
  21. VAR_DEPENDENCY = 'dependencies'
  22. MODULE_HUBCONF = 'hubconf.py'
  23. HUB_DIR = os.path.expanduser(os.path.join('~', '.cache', 'paddle', 'hub'))
  24. def _remove_if_exists(path):
  25. if os.path.exists(path):
  26. if os.path.isfile(path):
  27. os.remove(path)
  28. else:
  29. shutil.rmtree(path)
  30. def _import_module(name, repo_dir):
  31. sys.path.insert(0, repo_dir)
  32. try:
  33. hub_module = __import__(name)
  34. sys.modules.pop(name)
  35. except ImportError:
  36. sys.path.remove(repo_dir)
  37. raise RuntimeError(
  38. 'Please make sure config exists or repo error messages above fixed when importing'
  39. )
  40. sys.path.remove(repo_dir)
  41. return hub_module
  42. def _git_archive_link(repo_owner, repo_name, branch, source):
  43. if source == 'github':
  44. return (
  45. f'https://github.com/{repo_owner}/{repo_name}/archive/{branch}.zip'
  46. )
  47. elif source == 'gitee':
  48. return f'https://gitee.com/{repo_owner}/{repo_name}/repository/archive/{branch}.zip'
  49. def _parse_repo_info(repo, source):
  50. branch = 'main' if source == 'github' else 'master'
  51. if ':' in repo:
  52. repo_info, branch = repo.split(':')
  53. else:
  54. repo_info = repo
  55. repo_owner, repo_name = repo_info.split('/')
  56. return repo_owner, repo_name, branch
  57. def _make_dirs(dirname):
  58. try:
  59. from pathlib import Path
  60. except ImportError:
  61. from pathlib2 import Path
  62. Path(dirname).mkdir(exist_ok=True)
  63. def _get_cache_or_reload(repo, force_reload, verbose=True, source='github'):
  64. # Setup hub_dir to save downloaded files
  65. hub_dir = HUB_DIR
  66. _make_dirs(hub_dir)
  67. # Parse github/gitee repo information
  68. repo_owner, repo_name, branch = _parse_repo_info(repo, source)
  69. # Github allows branch name with slash '/',
  70. # this causes confusion with path on both Linux and Windows.
  71. # Backslash is not allowed in Github branch name so no need to
  72. # to worry about it.
  73. normalized_br = branch.replace('/', '_')
  74. # Github renames folder repo/v1.x.x to repo-1.x.x
  75. # We don't know the repo name before downloading the zip file
  76. # and inspect name from it.
  77. # To check if cached repo exists, we need to normalize folder names.
  78. repo_dir = os.path.join(
  79. hub_dir, '_'.join([repo_owner, repo_name, normalized_br])
  80. )
  81. use_cache = (not force_reload) and os.path.exists(repo_dir)
  82. if use_cache:
  83. if verbose:
  84. sys.stderr.write(f'Using cache found in {repo_dir}\n')
  85. else:
  86. cached_file = os.path.join(hub_dir, normalized_br + '.zip')
  87. _remove_if_exists(cached_file)
  88. url = _git_archive_link(repo_owner, repo_name, branch, source=source)
  89. fpath = get_path_from_url(
  90. url,
  91. hub_dir,
  92. check_exist=not force_reload,
  93. decompress=False,
  94. )
  95. shutil.move(fpath, cached_file)
  96. with zipfile.ZipFile(cached_file) as cached_zipfile:
  97. extracted_repo_name = cached_zipfile.infolist()[0].filename
  98. extracted_repo = os.path.join(hub_dir, extracted_repo_name)
  99. _remove_if_exists(extracted_repo)
  100. # Unzip the code and rename the base folder
  101. cached_zipfile.extractall(hub_dir)
  102. _remove_if_exists(cached_file)
  103. _remove_if_exists(repo_dir)
  104. # Rename the repo
  105. shutil.move(extracted_repo, repo_dir)
  106. return repo_dir
  107. def _load_entry_from_hubconf(m, name):
  108. '''load entry from hubconf'''
  109. if not isinstance(name, str):
  110. raise ValueError(
  111. 'Invalid input: model should be a str of function name'
  112. )
  113. func = getattr(m, name, None)
  114. if func is None or not callable(func):
  115. raise RuntimeError(f'Cannot find callable {name} in hubconf')
  116. return func
  117. def _check_module_exists(name):
  118. try:
  119. __import__(name)
  120. return True
  121. except ImportError:
  122. return False
  123. def _check_dependencies(m):
  124. dependencies = getattr(m, VAR_DEPENDENCY, None)
  125. if dependencies is not None:
  126. missing_deps = [
  127. pkg for pkg in dependencies if not _check_module_exists(pkg)
  128. ]
  129. if len(missing_deps):
  130. raise RuntimeError(
  131. 'Missing dependencies: {}'.format(', '.join(missing_deps))
  132. )
  133. def list(repo_dir, source='github', force_reload=False):
  134. r"""
  135. List all entrypoints available in `github` hubconf.
  136. Args:
  137. repo_dir(str): Github or local path.
  138. - github path (str): A string with format "repo_owner/repo_name[:tag_name]" with an optional
  139. tag/branch. The default branch is `main` if not specified.
  140. - local path (str): Local repo path.
  141. source (str): `github` | `gitee` | `local`. Default is `github`.
  142. force_reload (bool, optional): Whether to discard the existing cache and force a fresh download. Default is `False`.
  143. Returns:
  144. entrypoints: A list of available entrypoint names.
  145. Examples:
  146. .. code-block:: python
  147. >>> import paddle
  148. >>> paddle.hub.list('lyuwenyu/paddlehub_demo:main', source='github', force_reload=False)
  149. """
  150. if source not in ('github', 'gitee', 'local'):
  151. raise ValueError(
  152. f'Unknown source: "{source}". Allowed values: "github" | "gitee" | "local".'
  153. )
  154. if source in ('github', 'gitee'):
  155. repo_dir = _get_cache_or_reload(
  156. repo_dir, force_reload, True, source=source
  157. )
  158. hub_module = _import_module(MODULE_HUBCONF.split('.')[0], repo_dir)
  159. entrypoints = [
  160. f
  161. for f in dir(hub_module)
  162. if callable(getattr(hub_module, f)) and not f.startswith('_')
  163. ]
  164. return entrypoints
  165. def help(repo_dir, model, source='github', force_reload=False):
  166. """
  167. Show help information of model
  168. Args:
  169. repo_dir(str): Github or local path.
  170. - github path (str): A string with format "repo_owner/repo_name[:tag_name]" with an optional
  171. tag/branch. The default branch is `main` if not specified.
  172. - local path (str): Local repo path.
  173. model (str): Model name.
  174. source (str): `github` | `gitee` | `local`. Default is `github`.
  175. force_reload (bool, optional): Default is `False`.
  176. Returns:
  177. docs
  178. Examples:
  179. .. code-block:: python
  180. >>> import paddle
  181. >>> paddle.hub.help('lyuwenyu/paddlehub_demo:main', model='MM', source='github')
  182. """
  183. if source not in ('github', 'gitee', 'local'):
  184. raise ValueError(
  185. f'Unknown source: "{source}". Allowed values: "github" | "gitee" | "local".'
  186. )
  187. if source in ('github', 'gitee'):
  188. repo_dir = _get_cache_or_reload(
  189. repo_dir, force_reload, True, source=source
  190. )
  191. hub_module = _import_module(MODULE_HUBCONF.split('.')[0], repo_dir)
  192. entry = _load_entry_from_hubconf(hub_module, model)
  193. return entry.__doc__
  194. def load(repo_dir, model, source='github', force_reload=False, **kwargs):
  195. """
  196. Load model
  197. Args:
  198. repo_dir(str): Github or local path.
  199. - github path (str): A string with format "repo_owner/repo_name[:tag_name]" with an optional
  200. tag/branch. The default branch is `main` if not specified.
  201. - local path (str): Local repo path.
  202. model (str): Model name.
  203. source (str): `github` | `gitee` | `local`. Default is `github`.
  204. force_reload (bool, optional): Default is `False`.
  205. **kwargs: Parameters using for model.
  206. Returns:
  207. paddle model.
  208. Examples:
  209. .. code-block:: python
  210. >>> import paddle
  211. >>> paddle.hub.load('lyuwenyu/paddlehub_demo:main', model='MM', source='github')
  212. """
  213. if source not in ('github', 'gitee', 'local'):
  214. raise ValueError(
  215. f'Unknown source: "{source}". Allowed values: "github" | "gitee" | "local".'
  216. )
  217. if source in ('github', 'gitee'):
  218. repo_dir = _get_cache_or_reload(
  219. repo_dir, force_reload, True, source=source
  220. )
  221. hub_module = _import_module(MODULE_HUBCONF.split('.')[0], repo_dir)
  222. _check_dependencies(hub_module)
  223. entry = _load_entry_from_hubconf(hub_module, model)
  224. return entry(**kwargs)