fastai_utils.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415
  1. import json
  2. import os
  3. from pathlib import Path
  4. from pickle import DEFAULT_PROTOCOL, PicklingError
  5. from typing import Any, Dict, List, Optional, Union
  6. from packaging import version
  7. from huggingface_hub import constants, snapshot_download
  8. from huggingface_hub.hf_api import HfApi
  9. from huggingface_hub.utils import (
  10. SoftTemporaryDirectory,
  11. get_fastai_version,
  12. get_fastcore_version,
  13. get_python_version,
  14. )
  15. from .utils import logging, validate_hf_hub_args
  16. from .utils._runtime import _PY_VERSION # noqa: F401 # for backward compatibility...
# Module-level logger; used by the pyproject version checks in this module to emit warnings
# when a repository does not declare its fastai/fastcore versions.
logger = logging.get_logger(__name__)
  18. def _check_fastai_fastcore_versions(
  19. fastai_min_version: str = "2.4",
  20. fastcore_min_version: str = "1.3.27",
  21. ):
  22. """
  23. Checks that the installed fastai and fastcore versions are compatible for pickle serialization.
  24. Args:
  25. fastai_min_version (`str`, *optional*):
  26. The minimum fastai version supported.
  27. fastcore_min_version (`str`, *optional*):
  28. The minimum fastcore version supported.
  29. > [!TIP]
  30. > Raises the following error:
  31. >
  32. > - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
  33. > if the fastai or fastcore libraries are not available or are of an invalid version.
  34. """
  35. if (get_fastcore_version() or get_fastai_version()) == "N/A":
  36. raise ImportError(
  37. f"fastai>={fastai_min_version} and fastcore>={fastcore_min_version} are"
  38. f" required. Currently using fastai=={get_fastai_version()} and"
  39. f" fastcore=={get_fastcore_version()}."
  40. )
  41. current_fastai_version = version.Version(get_fastai_version())
  42. current_fastcore_version = version.Version(get_fastcore_version())
  43. if current_fastai_version < version.Version(fastai_min_version):
  44. raise ImportError(
  45. "`push_to_hub_fastai` and `from_pretrained_fastai` require a"
  46. f" fastai>={fastai_min_version} version, but you are using fastai version"
  47. f" {get_fastai_version()} which is incompatible. Upgrade with `pip install"
  48. " fastai==2.5.6`."
  49. )
  50. if current_fastcore_version < version.Version(fastcore_min_version):
  51. raise ImportError(
  52. "`push_to_hub_fastai` and `from_pretrained_fastai` require a"
  53. f" fastcore>={fastcore_min_version} version, but you are using fastcore"
  54. f" version {get_fastcore_version()} which is incompatible. Upgrade with"
  55. " `pip install fastcore==1.3.27`."
  56. )
  57. def _check_fastai_fastcore_pyproject_versions(
  58. storage_folder: str,
  59. fastai_min_version: str = "2.4",
  60. fastcore_min_version: str = "1.3.27",
  61. ):
  62. """
  63. Checks that the `pyproject.toml` file in the directory `storage_folder` has fastai and fastcore versions
  64. that are compatible with `from_pretrained_fastai` and `push_to_hub_fastai`. If `pyproject.toml` does not exist
  65. or does not contain versions for fastai and fastcore, then it logs a warning.
  66. Args:
  67. storage_folder (`str`):
  68. Folder to look for the `pyproject.toml` file.
  69. fastai_min_version (`str`, *optional*):
  70. The minimum fastai version supported.
  71. fastcore_min_version (`str`, *optional*):
  72. The minimum fastcore version supported.
  73. > [!TIP]
  74. > Raises the following errors:
  75. >
  76. > - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
  77. > if the `toml` module is not installed.
  78. > - [`ImportError`](https://docs.python.org/3/library/exceptions.html#ImportError)
  79. > if the `pyproject.toml` indicates a lower than minimum supported version of fastai or fastcore.
  80. """
  81. try:
  82. import toml
  83. except ModuleNotFoundError:
  84. raise ImportError(
  85. "`push_to_hub_fastai` and `from_pretrained_fastai` require the toml module."
  86. " Install it with `pip install toml`."
  87. )
  88. # Checks that a `pyproject.toml`, with `build-system` and `requires` sections, exists in the repository. If so, get a list of required packages.
  89. if not os.path.isfile(f"{storage_folder}/pyproject.toml"):
  90. logger.warning(
  91. "There is no `pyproject.toml` in the repository that contains the fastai"
  92. " `Learner`. The `pyproject.toml` would allow us to verify that your fastai"
  93. " and fastcore versions are compatible with those of the model you want to"
  94. " load."
  95. )
  96. return
  97. pyproject_toml = toml.load(f"{storage_folder}/pyproject.toml")
  98. if "build-system" not in pyproject_toml.keys():
  99. logger.warning(
  100. "There is no `build-system` section in the pyproject.toml of the repository"
  101. " that contains the fastai `Learner`. The `build-system` would allow us to"
  102. " verify that your fastai and fastcore versions are compatible with those"
  103. " of the model you want to load."
  104. )
  105. return
  106. build_system_toml = pyproject_toml["build-system"]
  107. if "requires" not in build_system_toml.keys():
  108. logger.warning(
  109. "There is no `requires` section in the pyproject.toml of the repository"
  110. " that contains the fastai `Learner`. The `requires` would allow us to"
  111. " verify that your fastai and fastcore versions are compatible with those"
  112. " of the model you want to load."
  113. )
  114. return
  115. package_versions = build_system_toml["requires"]
  116. # Extracts contains fastai and fastcore versions from `pyproject.toml` if available.
  117. # If the package is specified but not the version (e.g. "fastai" instead of "fastai=2.4"), the default versions are the highest.
  118. fastai_packages = [pck for pck in package_versions if pck.startswith("fastai")]
  119. if len(fastai_packages) == 0:
  120. logger.warning("The repository does not have a fastai version specified in the `pyproject.toml`.")
  121. # fastai_version is an empty string if not specified
  122. else:
  123. fastai_version = str(fastai_packages[0]).partition("=")[2]
  124. if fastai_version != "" and version.Version(fastai_version) < version.Version(fastai_min_version):
  125. raise ImportError(
  126. "`from_pretrained_fastai` requires"
  127. f" fastai>={fastai_min_version} version but the model to load uses"
  128. f" {fastai_version} which is incompatible."
  129. )
  130. fastcore_packages = [pck for pck in package_versions if pck.startswith("fastcore")]
  131. if len(fastcore_packages) == 0:
  132. logger.warning("The repository does not have a fastcore version specified in the `pyproject.toml`.")
  133. # fastcore_version is an empty string if not specified
  134. else:
  135. fastcore_version = str(fastcore_packages[0]).partition("=")[2]
  136. if fastcore_version != "" and version.Version(fastcore_version) < version.Version(fastcore_min_version):
  137. raise ImportError(
  138. "`from_pretrained_fastai` requires"
  139. f" fastcore>={fastcore_min_version} version, but you are using fastcore"
  140. f" version {fastcore_version} which is incompatible."
  141. )
  142. README_TEMPLATE = """---
  143. tags:
  144. - fastai
  145. ---
  146. # Amazing!
  147. 🥳 Congratulations on hosting your fastai model on the Hugging Face Hub!
  148. # Some next steps
  149. 1. Fill out this model card with more information (see the template below and the [documentation here](https://huggingface.co/docs/hub/model-repos))!
  150. 2. Create a demo in Gradio or Streamlit using 🤗 Spaces ([documentation here](https://huggingface.co/docs/hub/spaces)).
  151. 3. Join the fastai community on the [Fastai Discord](https://discord.com/invite/YKrxeNn)!
  152. Greetings fellow fastlearner 🤝! Don't forget to delete this content from your model card.
  153. ---
  154. # Model card
  155. ## Model description
  156. More information needed
  157. ## Intended uses & limitations
  158. More information needed
  159. ## Training and evaluation data
  160. More information needed
  161. """
  162. PYPROJECT_TEMPLATE = f"""[build-system]
  163. requires = ["setuptools>=40.8.0", "wheel", "python={get_python_version()}", "fastai={get_fastai_version()}", "fastcore={get_fastcore_version()}"]
  164. build-backend = "setuptools.build_meta:__legacy__"
  165. """
  166. def _create_model_card(repo_dir: Path):
  167. """
  168. Creates a model card for the repository.
  169. Args:
  170. repo_dir (`Path`):
  171. Directory where model card is created.
  172. """
  173. readme_path = repo_dir / "README.md"
  174. if not readme_path.exists():
  175. with readme_path.open("w", encoding="utf-8") as f:
  176. f.write(README_TEMPLATE)
  177. def _create_model_pyproject(repo_dir: Path):
  178. """
  179. Creates a `pyproject.toml` for the repository.
  180. Args:
  181. repo_dir (`Path`):
  182. Directory where `pyproject.toml` is created.
  183. """
  184. pyproject_path = repo_dir / "pyproject.toml"
  185. if not pyproject_path.exists():
  186. with pyproject_path.open("w", encoding="utf-8") as f:
  187. f.write(PYPROJECT_TEMPLATE)
  188. def _save_pretrained_fastai(
  189. learner,
  190. save_directory: Union[str, Path],
  191. config: Optional[Dict[str, Any]] = None,
  192. ):
  193. """
  194. Saves a fastai learner to `save_directory` in pickle format using the default pickle protocol for the version of python used.
  195. Args:
  196. learner (`Learner`):
  197. The `fastai.Learner` you'd like to save.
  198. save_directory (`str` or `Path`):
  199. Specific directory in which you want to save the fastai learner.
  200. config (`dict`, *optional*):
  201. Configuration object. Will be uploaded as a .json file. Example: 'https://huggingface.co/espejelomar/fastai-pet-breeds-classification/blob/main/config.json'.
  202. > [!TIP]
  203. > Raises the following error:
  204. >
  205. > - [`RuntimeError`](https://docs.python.org/3/library/exceptions.html#RuntimeError)
  206. > if the config file provided is not a dictionary.
  207. """
  208. _check_fastai_fastcore_versions()
  209. os.makedirs(save_directory, exist_ok=True)
  210. # if the user provides config then we update it with the fastai and fastcore versions in CONFIG_TEMPLATE.
  211. if config is not None:
  212. if not isinstance(config, dict):
  213. raise RuntimeError(f"Provided config should be a dict. Got: '{type(config)}'")
  214. path = os.path.join(save_directory, constants.CONFIG_NAME)
  215. with open(path, "w") as f:
  216. json.dump(config, f)
  217. _create_model_card(Path(save_directory))
  218. _create_model_pyproject(Path(save_directory))
  219. # learner.export saves the model in `self.path`.
  220. learner.path = Path(save_directory)
  221. os.makedirs(save_directory, exist_ok=True)
  222. try:
  223. learner.export(
  224. fname="model.pkl",
  225. pickle_protocol=DEFAULT_PROTOCOL,
  226. )
  227. except PicklingError:
  228. raise PicklingError(
  229. "You are using a lambda function, i.e., an anonymous function. `pickle`"
  230. " cannot pickle function objects and requires that all functions have"
  231. " names. One possible solution is to name the function."
  232. )
  233. @validate_hf_hub_args
  234. def from_pretrained_fastai(
  235. repo_id: str,
  236. revision: Optional[str] = None,
  237. ):
  238. """
  239. Load pretrained fastai model from the Hub or from a local directory.
  240. Args:
  241. repo_id (`str`):
  242. The location where the pickled fastai.Learner is. It can be either of the two:
  243. - Hosted on the Hugging Face Hub. E.g.: 'espejelomar/fatai-pet-breeds-classification' or 'distilgpt2'.
  244. You can add a `revision` by appending `@` at the end of `repo_id`. E.g.: `dbmdz/bert-base-german-cased@main`.
  245. Revision is the specific model version to use. Since we use a git-based system for storing models and other
  246. artifacts on the Hugging Face Hub, it can be a branch name, a tag name, or a commit id.
  247. - Hosted locally. `repo_id` would be a directory containing the pickle and a pyproject.toml
  248. indicating the fastai and fastcore versions used to build the `fastai.Learner`. E.g.: `./my_model_directory/`.
  249. revision (`str`, *optional*):
  250. Revision at which the repo's files are downloaded. See documentation of `snapshot_download`.
  251. Returns:
  252. The `fastai.Learner` model in the `repo_id` repo.
  253. """
  254. _check_fastai_fastcore_versions()
  255. # Load the `repo_id` repo.
  256. # `snapshot_download` returns the folder where the model was stored.
  257. # `cache_dir` will be the default '/root/.cache/huggingface/hub'
  258. if not os.path.isdir(repo_id):
  259. storage_folder = snapshot_download(
  260. repo_id=repo_id,
  261. revision=revision,
  262. library_name="fastai",
  263. library_version=get_fastai_version(),
  264. )
  265. else:
  266. storage_folder = repo_id
  267. _check_fastai_fastcore_pyproject_versions(storage_folder)
  268. from fastai.learner import load_learner # type: ignore
  269. return load_learner(os.path.join(storage_folder, "model.pkl"))
  270. @validate_hf_hub_args
  271. def push_to_hub_fastai(
  272. learner,
  273. *,
  274. repo_id: str,
  275. commit_message: str = "Push FastAI model using huggingface_hub.",
  276. private: Optional[bool] = None,
  277. token: Optional[str] = None,
  278. config: Optional[dict] = None,
  279. branch: Optional[str] = None,
  280. create_pr: Optional[bool] = None,
  281. allow_patterns: Optional[Union[List[str], str]] = None,
  282. ignore_patterns: Optional[Union[List[str], str]] = None,
  283. delete_patterns: Optional[Union[List[str], str]] = None,
  284. api_endpoint: Optional[str] = None,
  285. ):
  286. """
  287. Upload learner checkpoint files to the Hub.
  288. Use `allow_patterns` and `ignore_patterns` to precisely filter which files should be pushed to the hub. Use
  289. `delete_patterns` to delete existing remote files in the same commit. See [`upload_folder`] reference for more
  290. details.
  291. Args:
  292. learner (`Learner`):
  293. The `fastai.Learner' you'd like to push to the Hub.
  294. repo_id (`str`):
  295. The repository id for your model in Hub in the format of "namespace/repo_name". The namespace can be your individual account or an organization to which you have write access (for example, 'stanfordnlp/stanza-de').
  296. commit_message (`str`, *optional*):
  297. Message to commit while pushing. Will default to :obj:`"add model"`.
  298. private (`bool`, *optional*):
  299. Whether or not the repository created should be private.
  300. If `None` (default), will default to been public except if the organization's default is private.
  301. token (`str`, *optional*):
  302. The Hugging Face account token to use as HTTP bearer authorization for remote files. If :obj:`None`, the token will be asked by a prompt.
  303. config (`dict`, *optional*):
  304. Configuration object to be saved alongside the model weights.
  305. branch (`str`, *optional*):
  306. The git branch on which to push the model. This defaults to
  307. the default branch as specified in your repository, which
  308. defaults to `"main"`.
  309. create_pr (`boolean`, *optional*):
  310. Whether or not to create a Pull Request from `branch` with that commit.
  311. Defaults to `False`.
  312. api_endpoint (`str`, *optional*):
  313. The API endpoint to use when pushing the model to the hub.
  314. allow_patterns (`List[str]` or `str`, *optional*):
  315. If provided, only files matching at least one pattern are pushed.
  316. ignore_patterns (`List[str]` or `str`, *optional*):
  317. If provided, files matching any of the patterns are not pushed.
  318. delete_patterns (`List[str]` or `str`, *optional*):
  319. If provided, remote files matching any of the patterns will be deleted from the repo.
  320. Returns:
  321. The url of the commit of your model in the given repository.
  322. > [!TIP]
  323. > Raises the following error:
  324. >
  325. > - [`ValueError`](https://docs.python.org/3/library/exceptions.html#ValueError)
  326. > if the user is not log on to the Hugging Face Hub.
  327. """
  328. _check_fastai_fastcore_versions()
  329. api = HfApi(endpoint=api_endpoint)
  330. repo_id = api.create_repo(repo_id=repo_id, token=token, private=private, exist_ok=True).repo_id
  331. # Push the files to the repo in a single commit
  332. with SoftTemporaryDirectory() as tmp:
  333. saved_path = Path(tmp) / repo_id
  334. _save_pretrained_fastai(learner, saved_path, config=config)
  335. return api.upload_folder(
  336. repo_id=repo_id,
  337. token=token,
  338. folder_path=saved_path,
  339. commit_message=commit_message,
  340. revision=branch,
  341. create_pr=create_pr,
  342. allow_patterns=allow_patterns,
  343. ignore_patterns=ignore_patterns,
  344. delete_patterns=delete_patterns,
  345. )