dynamic_module_utils.py 36 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814815816817818819820821822823824825826827828829830831832833834835836837838839840841842843
  1. # Copyright 2021 The HuggingFace Inc. team.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. """Utilities to dynamically load objects from the Hub."""
  15. import ast
  16. import filecmp
  17. import hashlib
  18. import importlib
  19. import importlib.metadata
  20. import importlib.util
  21. import keyword
  22. import os
  23. import re
  24. import shutil
  25. import signal
  26. import sys
  27. import threading
  28. import warnings
  29. from pathlib import Path
  30. from types import ModuleType
  31. from typing import Any, Optional, Union
  32. from huggingface_hub import try_to_load_from_cache
  33. from packaging import version
  34. from .utils import (
  35. HF_MODULES_CACHE,
  36. TRANSFORMERS_DYNAMIC_MODULE_NAME,
  37. cached_file,
  38. extract_commit_hash,
  39. is_offline_mode,
  40. logging,
  41. )
  42. from .utils.import_utils import VersionComparison, split_package_version
# Module-level logger obtained from the library's `logging` helper; used for the warnings/info messages below.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
  44. def _sanitize_module_name(name: str) -> str:
  45. r"""
  46. Tries to sanitize a module name so that it can be used as a Python module.
  47. The following transformations are applied:
  48. 1. Replace `.` in module names with `_dot_`.
  49. 2. Replace `-` in module names with `_hyphen_`.
  50. 3. If the module name starts with a digit, prepend it with `_`.
  51. 4. Warn if the sanitized name is a Python reserved keyword or not a valid identifier.
  52. If the input name is already a valid identifier, it is returned unchanged.
  53. """
  54. # We not replacing `\W` characters with `_` to avoid collisions. Because `_` is a very common
  55. # separator used in module names, replacing `\W` with `_` would create too many collisions.
  56. # Once a module is imported, it is cached in `sys.modules` and the second import would return
  57. # the first module, which might not be the expected behavior if name collisions happen.
  58. new_name = name.replace(".", "_dot_").replace("-", "_hyphen_")
  59. if new_name and new_name[0].isdigit():
  60. new_name = f"_{new_name}"
  61. if keyword.iskeyword(new_name):
  62. logger.warning(
  63. f"The module name {new_name} (originally {name}) is a reserved keyword in Python. "
  64. "Please rename the original module to avoid import issues."
  65. )
  66. elif not new_name.isidentifier():
  67. logger.warning(
  68. f"The module name {new_name} (originally {name}) is not a valid Python identifier. "
  69. "Please rename the original module to avoid import issues."
  70. )
  71. return new_name
# Lock held by `get_class_in_module` for the whole time it reads/updates `sys.modules` and executes a dynamic module.
_HF_REMOTE_CODE_LOCK = threading.Lock()
  73. def init_hf_modules():
  74. """
  75. Creates the cache directory for modules with an init, and adds it to the Python path.
  76. """
  77. # This function has already been executed if HF_MODULES_CACHE already is in the Python path.
  78. if HF_MODULES_CACHE in sys.path:
  79. return
  80. sys.path.append(HF_MODULES_CACHE)
  81. os.makedirs(HF_MODULES_CACHE, exist_ok=True)
  82. init_path = Path(HF_MODULES_CACHE) / "__init__.py"
  83. if not init_path.exists():
  84. init_path.touch()
  85. importlib.invalidate_caches()
  86. def create_dynamic_module(name: Union[str, os.PathLike]) -> None:
  87. """
  88. Creates a dynamic module in the cache directory for modules.
  89. Args:
  90. name (`str` or `os.PathLike`):
  91. The name of the dynamic module to create.
  92. """
  93. init_hf_modules()
  94. dynamic_module_path = (Path(HF_MODULES_CACHE) / name).resolve()
  95. # If the parent module does not exist yet, recursively create it.
  96. if not dynamic_module_path.parent.exists():
  97. create_dynamic_module(dynamic_module_path.parent)
  98. os.makedirs(dynamic_module_path, exist_ok=True)
  99. init_path = dynamic_module_path / "__init__.py"
  100. if not init_path.exists():
  101. init_path.touch()
  102. # It is extremely important to invalidate the cache when we change stuff in those modules, or users end up
  103. # with errors about module that do not exist. Same for all other `invalidate_caches` in this file.
  104. importlib.invalidate_caches()
  105. def get_relative_imports(module_file: Union[str, os.PathLike]) -> list[str]:
  106. """
  107. Get the list of modules that are relatively imported in a module file.
  108. Args:
  109. module_file (`str` or `os.PathLike`): The module file to inspect.
  110. Returns:
  111. `list[str]`: The list of relative imports in the module.
  112. """
  113. with open(module_file, encoding="utf-8") as f:
  114. content = f.read()
  115. # Imports of the form `import .xxx`
  116. relative_imports = re.findall(r"^\s*import\s+\.(\S+)\s*$", content, flags=re.MULTILINE)
  117. # Imports of the form `from .xxx import yyy`
  118. relative_imports += re.findall(r"^\s*from\s+\.(\S+)\s+import", content, flags=re.MULTILINE)
  119. # Unique-ify
  120. return list(set(relative_imports))
  121. def get_relative_import_files(module_file: Union[str, os.PathLike]) -> list[str]:
  122. """
  123. Get the list of all files that are needed for a given module. Note that this function recurses through the relative
  124. imports (if a imports b and b imports c, it will return module files for b and c).
  125. Args:
  126. module_file (`str` or `os.PathLike`): The module file to inspect.
  127. Returns:
  128. `list[str]`: The list of all relative imports a given module needs (recursively), which will give us the list
  129. of module files a given module needs.
  130. """
  131. no_change = False
  132. files_to_check = [module_file]
  133. all_relative_imports = []
  134. # Let's recurse through all relative imports
  135. while not no_change:
  136. new_imports = []
  137. for f in files_to_check:
  138. new_imports.extend(get_relative_imports(f))
  139. module_path = Path(module_file).parent
  140. new_import_files = [f"{str(module_path / m)}.py" for m in new_imports]
  141. files_to_check = [f for f in new_import_files if f not in all_relative_imports]
  142. no_change = len(files_to_check) == 0
  143. all_relative_imports.extend(files_to_check)
  144. return all_relative_imports
  145. def get_imports(filename: Union[str, os.PathLike]) -> list[str]:
  146. """
  147. Extracts all the libraries (not relative imports this time) that are imported in a file.
  148. Args:
  149. filename (`str` or `os.PathLike`): The module file to inspect.
  150. Returns:
  151. `list[str]`: The list of all packages required to use the input module.
  152. """
  153. with open(filename, encoding="utf-8") as f:
  154. content = f.read()
  155. imported_modules = set()
  156. import transformers.utils
  157. def recursive_look_for_imports(node):
  158. if isinstance(node, ast.Try):
  159. return # Don't recurse into Try blocks and ignore imports in them
  160. elif isinstance(node, ast.If):
  161. test = node.test
  162. for condition_node in ast.walk(test):
  163. if isinstance(condition_node, ast.Call):
  164. check_function = getattr(condition_node.func, "id", "")
  165. if (
  166. check_function.endswith("available")
  167. and check_function.startswith("is_flash_attn")
  168. or hasattr(transformers.utils.import_utils, check_function)
  169. ):
  170. # Don't recurse into "if flash_attn_available()" or any "if library_available" blocks
  171. # that appears in `transformers.utils.import_utils` and ignore imports in them
  172. return
  173. elif isinstance(node, ast.Import):
  174. # Handle 'import x' statements
  175. for alias in node.names:
  176. top_module = alias.name.split(".")[0]
  177. if top_module:
  178. imported_modules.add(top_module)
  179. elif isinstance(node, ast.ImportFrom):
  180. # Handle 'from x import y' statements, ignoring relative imports
  181. if node.level == 0 and node.module:
  182. top_module = node.module.split(".")[0]
  183. if top_module:
  184. imported_modules.add(top_module)
  185. # Recursively visit all children
  186. for child in ast.iter_child_nodes(node):
  187. recursive_look_for_imports(child)
  188. tree = ast.parse(content)
  189. recursive_look_for_imports(tree)
  190. return sorted(imported_modules)
  191. def check_imports(filename: Union[str, os.PathLike]) -> list[str]:
  192. """
  193. Check if the current Python environment contains all the libraries that are imported in a file. Will raise if a
  194. library is missing.
  195. Args:
  196. filename (`str` or `os.PathLike`): The module file to check.
  197. Returns:
  198. `list[str]`: The list of relative imports in the file.
  199. """
  200. imports = get_imports(filename)
  201. missing_packages = []
  202. for imp in imports:
  203. try:
  204. importlib.import_module(imp)
  205. except ImportError as exception:
  206. logger.warning(f"Encountered exception while importing {imp}: {exception}")
  207. # Some packages can fail with an ImportError because of a dependency issue.
  208. # This check avoids hiding such errors.
  209. # See https://github.com/huggingface/transformers/issues/33604
  210. if "No module named" in str(exception):
  211. missing_packages.append(imp)
  212. else:
  213. raise
  214. if len(missing_packages) > 0:
  215. raise ImportError(
  216. "This modeling file requires the following packages that were not found in your environment: "
  217. f"{', '.join(missing_packages)}. Run `pip install {' '.join(missing_packages)}`"
  218. )
  219. return get_relative_imports(filename)
  220. def get_class_in_module(
  221. class_name: str,
  222. module_path: Union[str, os.PathLike],
  223. *,
  224. force_reload: bool = False,
  225. ) -> type:
  226. """
  227. Import a module on the cache directory for modules and extract a class from it.
  228. Args:
  229. class_name (`str`): The name of the class to import.
  230. module_path (`str` or `os.PathLike`): The path to the module to import.
  231. force_reload (`bool`, *optional*, defaults to `False`):
  232. Whether to reload the dynamic module from file if it already exists in `sys.modules`.
  233. Otherwise, the module is only reloaded if the file has changed.
  234. Returns:
  235. `typing.Type`: The class looked for.
  236. """
  237. name = os.path.normpath(module_path)
  238. name = name.removesuffix(".py")
  239. name = name.replace(os.path.sep, ".")
  240. module_file: Path = Path(HF_MODULES_CACHE) / module_path
  241. with _HF_REMOTE_CODE_LOCK:
  242. if force_reload:
  243. sys.modules.pop(name, None)
  244. importlib.invalidate_caches()
  245. cached_module: Optional[ModuleType] = sys.modules.get(name)
  246. module_spec = importlib.util.spec_from_file_location(name, location=module_file)
  247. # Hash the module file and all its relative imports to check if we need to reload it
  248. module_files: list[Path] = [module_file] + sorted(map(Path, get_relative_import_files(module_file)))
  249. module_hash: str = hashlib.sha256(b"".join(bytes(f) + f.read_bytes() for f in module_files)).hexdigest()
  250. module: ModuleType
  251. if cached_module is None:
  252. module = importlib.util.module_from_spec(module_spec)
  253. # insert it into sys.modules before any loading begins
  254. sys.modules[name] = module
  255. else:
  256. module = cached_module
  257. # reload in both cases, unless the module is already imported and the hash hits
  258. if getattr(module, "__transformers_module_hash__", "") != module_hash:
  259. module_spec.loader.exec_module(module)
  260. module.__transformers_module_hash__ = module_hash
  261. return getattr(module, class_name)
  262. def get_cached_module_file(
  263. pretrained_model_name_or_path: Union[str, os.PathLike],
  264. module_file: str,
  265. cache_dir: Optional[Union[str, os.PathLike]] = None,
  266. force_download: bool = False,
  267. resume_download: Optional[bool] = None,
  268. proxies: Optional[dict[str, str]] = None,
  269. token: Optional[Union[bool, str]] = None,
  270. revision: Optional[str] = None,
  271. local_files_only: bool = False,
  272. repo_type: Optional[str] = None,
  273. _commit_hash: Optional[str] = None,
  274. **deprecated_kwargs,
  275. ) -> str:
  276. """
  277. Prepares Downloads a module from a local folder or a distant repo and returns its path inside the cached
  278. Transformers module.
  279. Args:
  280. pretrained_model_name_or_path (`str` or `os.PathLike`):
  281. This can be either:
  282. - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
  283. huggingface.co.
  284. - a path to a *directory* containing a configuration file saved using the
  285. [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
  286. module_file (`str`):
  287. The name of the module file containing the class to look for.
  288. cache_dir (`str` or `os.PathLike`, *optional*):
  289. Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
  290. cache should not be used.
  291. force_download (`bool`, *optional*, defaults to `False`):
  292. Whether or not to force to (re-)download the configuration files and override the cached versions if they
  293. exist.
  294. resume_download:
  295. Deprecated and ignored. All downloads are now resumed by default when possible.
  296. Will be removed in v5 of Transformers.
  297. proxies (`dict[str, str]`, *optional*):
  298. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  299. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
  300. token (`str` or *bool*, *optional*):
  301. The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
  302. when running `hf auth login` (stored in `~/.huggingface`).
  303. revision (`str`, *optional*, defaults to `"main"`):
  304. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  305. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  306. identifier allowed by git.
  307. local_files_only (`bool`, *optional*, defaults to `False`):
  308. If `True`, will only try to load the tokenizer configuration from local files.
  309. repo_type (`str`, *optional*):
  310. Specify the repo type (useful when downloading from a space for instance).
  311. <Tip>
  312. Passing `token=True` is required when you want to use a private model.
  313. </Tip>
  314. Returns:
  315. `str`: The path to the module inside the cache.
  316. """
  317. use_auth_token = deprecated_kwargs.pop("use_auth_token", None)
  318. if use_auth_token is not None:
  319. warnings.warn(
  320. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  321. FutureWarning,
  322. )
  323. if token is not None:
  324. raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
  325. token = use_auth_token
  326. if is_offline_mode() and not local_files_only:
  327. logger.info("Offline mode: forcing local_files_only=True")
  328. local_files_only = True
  329. # Download and cache module_file from the repo `pretrained_model_name_or_path` of grab it if it's a local file.
  330. pretrained_model_name_or_path = str(pretrained_model_name_or_path)
  331. is_local = os.path.isdir(pretrained_model_name_or_path)
  332. if is_local:
  333. submodule = _sanitize_module_name(os.path.basename(pretrained_model_name_or_path))
  334. else:
  335. submodule = os.path.sep.join(map(_sanitize_module_name, pretrained_model_name_or_path.split("/")))
  336. cached_module = try_to_load_from_cache(
  337. pretrained_model_name_or_path, module_file, cache_dir=cache_dir, revision=_commit_hash, repo_type=repo_type
  338. )
  339. new_files = []
  340. try:
  341. # Load from URL or cache if already cached
  342. resolved_module_file = cached_file(
  343. pretrained_model_name_or_path,
  344. module_file,
  345. cache_dir=cache_dir,
  346. force_download=force_download,
  347. proxies=proxies,
  348. resume_download=resume_download,
  349. local_files_only=local_files_only,
  350. token=token,
  351. revision=revision,
  352. repo_type=repo_type,
  353. _commit_hash=_commit_hash,
  354. )
  355. if not is_local and cached_module != resolved_module_file:
  356. new_files.append(module_file)
  357. except OSError:
  358. logger.info(f"Could not locate the {module_file} inside {pretrained_model_name_or_path}.")
  359. raise
  360. # Check we have all the requirements in our environment
  361. modules_needed = check_imports(resolved_module_file)
  362. # Now we move the module inside our cached dynamic modules.
  363. full_submodule = TRANSFORMERS_DYNAMIC_MODULE_NAME + os.path.sep + submodule
  364. create_dynamic_module(full_submodule)
  365. submodule_path = Path(HF_MODULES_CACHE) / full_submodule
  366. if submodule == _sanitize_module_name(os.path.basename(pretrained_model_name_or_path)):
  367. # We copy local files to avoid putting too many folders in sys.path. This copy is done when the file is new or
  368. # has changed since last copy.
  369. if not (submodule_path / module_file).exists() or not filecmp.cmp(
  370. resolved_module_file, str(submodule_path / module_file)
  371. ):
  372. (submodule_path / module_file).parent.mkdir(parents=True, exist_ok=True)
  373. shutil.copy(resolved_module_file, submodule_path / module_file)
  374. importlib.invalidate_caches()
  375. for module_needed in modules_needed:
  376. module_needed = Path(module_file).parent / f"{module_needed}.py"
  377. module_needed_file = os.path.join(pretrained_model_name_or_path, module_needed)
  378. if not (submodule_path / module_needed).exists() or not filecmp.cmp(
  379. module_needed_file, str(submodule_path / module_needed)
  380. ):
  381. shutil.copy(module_needed_file, submodule_path / module_needed)
  382. importlib.invalidate_caches()
  383. else:
  384. # Get the commit hash
  385. commit_hash = extract_commit_hash(resolved_module_file, _commit_hash)
  386. # The module file will end up being placed in a subfolder with the git hash of the repo. This way we get the
  387. # benefit of versioning.
  388. submodule_path = submodule_path / commit_hash
  389. full_submodule = full_submodule + os.path.sep + commit_hash
  390. full_submodule_module_file_path = os.path.join(full_submodule, module_file)
  391. create_dynamic_module(Path(full_submodule_module_file_path).parent)
  392. if not (submodule_path / module_file).exists():
  393. shutil.copy(resolved_module_file, submodule_path / module_file)
  394. importlib.invalidate_caches()
  395. # Make sure we also have every file with relative
  396. for module_needed in modules_needed:
  397. if not ((submodule_path / module_file).parent / f"{module_needed}.py").exists():
  398. get_cached_module_file(
  399. pretrained_model_name_or_path,
  400. f"{Path(module_file).parent / module_needed}.py",
  401. cache_dir=cache_dir,
  402. force_download=force_download,
  403. resume_download=resume_download,
  404. proxies=proxies,
  405. token=token,
  406. revision=revision,
  407. local_files_only=local_files_only,
  408. _commit_hash=commit_hash,
  409. )
  410. new_files.append(f"{module_needed}.py")
  411. if len(new_files) > 0 and revision is None:
  412. new_files = "\n".join([f"- {f}" for f in new_files])
  413. repo_type_str = "" if repo_type is None else f"{repo_type}s/"
  414. url = f"https://huggingface.co/{repo_type_str}{pretrained_model_name_or_path}"
  415. logger.warning(
  416. f"A new version of the following files was downloaded from {url}:\n{new_files}"
  417. "\n. Make sure to double-check they do not contain any added malicious code. To avoid downloading new "
  418. "versions of the code file, you can pin a revision."
  419. )
  420. return os.path.join(full_submodule, module_file)
  421. def get_class_from_dynamic_module(
  422. class_reference: str,
  423. pretrained_model_name_or_path: Union[str, os.PathLike],
  424. cache_dir: Optional[Union[str, os.PathLike]] = None,
  425. force_download: bool = False,
  426. resume_download: Optional[bool] = None,
  427. proxies: Optional[dict[str, str]] = None,
  428. token: Optional[Union[bool, str]] = None,
  429. revision: Optional[str] = None,
  430. local_files_only: bool = False,
  431. repo_type: Optional[str] = None,
  432. code_revision: Optional[str] = None,
  433. **kwargs,
  434. ) -> type:
  435. """
  436. Extracts a class from a module file, present in the local folder or repository of a model.
  437. <Tip warning={true}>
  438. Calling this function will execute the code in the module file found locally or downloaded from the Hub. It should
  439. therefore only be called on trusted repos.
  440. </Tip>
  441. Args:
  442. class_reference (`str`):
  443. The full name of the class to load, including its module and optionally its repo.
  444. pretrained_model_name_or_path (`str` or `os.PathLike`):
  445. This can be either:
  446. - a string, the *model id* of a pretrained model configuration hosted inside a model repo on
  447. huggingface.co.
  448. - a path to a *directory* containing a configuration file saved using the
  449. [`~PreTrainedTokenizer.save_pretrained`] method, e.g., `./my_model_directory/`.
  450. This is used when `class_reference` does not specify another repo.
  451. module_file (`str`):
  452. The name of the module file containing the class to look for.
  453. class_name (`str`):
  454. The name of the class to import in the module.
  455. cache_dir (`str` or `os.PathLike`, *optional*):
  456. Path to a directory in which a downloaded pretrained model configuration should be cached if the standard
  457. cache should not be used.
  458. force_download (`bool`, *optional*, defaults to `False`):
  459. Whether or not to force to (re-)download the configuration files and override the cached versions if they
  460. exist.
  461. resume_download:
  462. Deprecated and ignored. All downloads are now resumed by default when possible.
  463. Will be removed in v5 of Transformers.
  464. proxies (`dict[str, str]`, *optional*):
  465. A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
  466. 'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
  467. token (`str` or `bool`, *optional*):
  468. The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
  469. when running `hf auth login` (stored in `~/.huggingface`).
  470. revision (`str`, *optional*, defaults to `"main"`):
  471. The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
  472. git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
  473. identifier allowed by git.
  474. local_files_only (`bool`, *optional*, defaults to `False`):
  475. If `True`, will only try to load the tokenizer configuration from local files.
  476. repo_type (`str`, *optional*):
  477. Specify the repo type (useful when downloading from a space for instance).
  478. code_revision (`str`, *optional*, defaults to `"main"`):
  479. The specific revision to use for the code on the Hub, if the code leaves in a different repository than the
  480. rest of the model. It can be a branch name, a tag name, or a commit id, since we use a git-based system for
  481. storing models and other artifacts on huggingface.co, so `revision` can be any identifier allowed by git.
  482. <Tip>
  483. Passing `token=True` is required when you want to use a private model.
  484. </Tip>
  485. Returns:
  486. `typing.Type`: The class, dynamically imported from the module.
  487. Examples:
  488. ```python
  489. # Download module `modeling.py` from huggingface.co and cache then extract the class `MyBertModel` from this
  490. # module.
  491. cls = get_class_from_dynamic_module("modeling.MyBertModel", "sgugger/my-bert-model")
  492. # Download module `modeling.py` from a given repo and cache then extract the class `MyBertModel` from this
  493. # module.
  494. cls = get_class_from_dynamic_module("sgugger/my-bert-model--modeling.MyBertModel", "sgugger/another-bert-model")
  495. ```"""
  496. use_auth_token = kwargs.pop("use_auth_token", None)
  497. if use_auth_token is not None:
  498. warnings.warn(
  499. "The `use_auth_token` argument is deprecated and will be removed in v5 of Transformers. Please use `token` instead.",
  500. FutureWarning,
  501. )
  502. if token is not None:
  503. raise ValueError("`token` and `use_auth_token` are both specified. Please set only the argument `token`.")
  504. token = use_auth_token
  505. # Catch the name of the repo if it's specified in `class_reference`
  506. if "--" in class_reference:
  507. repo_id, class_reference = class_reference.split("--")
  508. else:
  509. repo_id = pretrained_model_name_or_path
  510. module_file, class_name = class_reference.split(".")
  511. if code_revision is None and pretrained_model_name_or_path == repo_id:
  512. code_revision = revision
  513. # And lastly we get the class inside our newly created module
  514. final_module = get_cached_module_file(
  515. repo_id,
  516. module_file + ".py",
  517. cache_dir=cache_dir,
  518. force_download=force_download,
  519. resume_download=resume_download,
  520. proxies=proxies,
  521. token=token,
  522. revision=code_revision,
  523. local_files_only=local_files_only,
  524. repo_type=repo_type,
  525. )
  526. return get_class_in_module(class_name, final_module, force_reload=force_download)
  527. def custom_object_save(obj: Any, folder: Union[str, os.PathLike], config: Optional[dict] = None) -> list[str]:
  528. """
  529. Save the modeling files corresponding to a custom model/configuration/tokenizer etc. in a given folder. Optionally
  530. adds the proper fields in a config.
  531. Args:
  532. obj (`Any`): The object for which to save the module files.
  533. folder (`str` or `os.PathLike`): The folder where to save.
  534. config (`PretrainedConfig` or dictionary, `optional`):
  535. A config in which to register the auto_map corresponding to this custom object.
  536. Returns:
  537. `list[str]`: The list of files saved.
  538. """
  539. if obj.__module__ == "__main__":
  540. logger.warning(
  541. f"We can't save the code defining {obj} in {folder} as it's been defined in __main__. You should put "
  542. "this code in a separate module so we can include it in the saved folder and make it easier to share via "
  543. "the Hub."
  544. )
  545. return
  546. def _set_auto_map_in_config(_config):
  547. module_name = obj.__class__.__module__
  548. last_module = module_name.split(".")[-1]
  549. full_name = f"{last_module}.{obj.__class__.__name__}"
  550. # Special handling for tokenizers
  551. if "Tokenizer" in full_name:
  552. slow_tokenizer_class = None
  553. fast_tokenizer_class = None
  554. if obj.__class__.__name__.endswith("Fast"):
  555. # Fast tokenizer: we have the fast tokenizer class and we may have the slow one has an attribute.
  556. fast_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
  557. if getattr(obj, "slow_tokenizer_class", None) is not None:
  558. slow_tokenizer = getattr(obj, "slow_tokenizer_class")
  559. slow_tok_module_name = slow_tokenizer.__module__
  560. last_slow_tok_module = slow_tok_module_name.split(".")[-1]
  561. slow_tokenizer_class = f"{last_slow_tok_module}.{slow_tokenizer.__name__}"
  562. else:
  563. # Slow tokenizer: no way to have the fast class
  564. slow_tokenizer_class = f"{last_module}.{obj.__class__.__name__}"
  565. full_name = (slow_tokenizer_class, fast_tokenizer_class)
  566. if isinstance(_config, dict):
  567. auto_map = _config.get("auto_map", {})
  568. auto_map[obj._auto_class] = full_name
  569. _config["auto_map"] = auto_map
  570. elif getattr(_config, "auto_map", None) is not None:
  571. _config.auto_map[obj._auto_class] = full_name
  572. else:
  573. _config.auto_map = {obj._auto_class: full_name}
  574. # Add object class to the config auto_map
  575. if isinstance(config, (list, tuple)):
  576. for cfg in config:
  577. _set_auto_map_in_config(cfg)
  578. elif config is not None:
  579. _set_auto_map_in_config(config)
  580. result = []
  581. # Copy module file to the output folder.
  582. object_file = sys.modules[obj.__module__].__file__
  583. dest_file = Path(folder) / (Path(object_file).name)
  584. shutil.copy(object_file, dest_file)
  585. result.append(dest_file)
  586. # Gather all relative imports recursively and make sure they are copied as well.
  587. for needed_file in get_relative_import_files(object_file):
  588. dest_file = Path(folder) / (Path(needed_file).name)
  589. shutil.copy(needed_file, dest_file)
  590. result.append(dest_file)
  591. return result
  592. def _raise_timeout_error(signum, frame):
  593. raise ValueError(
  594. "Loading this model requires you to execute custom code contained in the model repository on your local "
  595. "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
  596. )
# Number of seconds passed to `signal.alarm` while waiting for the user to answer the "run custom code?" prompt.
# When set to 0 (as done on the CI), no prompt is shown and remote code without an explicit opt-in is an error.
TIME_OUT_REMOTE_CODE = 15
  598. def resolve_trust_remote_code(
  599. trust_remote_code, model_name, has_local_code, has_remote_code, error_message=None, upstream_repo=None
  600. ):
  601. """
  602. Resolves the `trust_remote_code` argument. If there is remote code to be loaded, the user must opt-in to loading
  603. it.
  604. Args:
  605. trust_remote_code (`bool` or `None`):
  606. User-defined `trust_remote_code` value.
  607. model_name (`str`):
  608. The name of the model repository in huggingface.co.
  609. has_local_code (`bool`):
  610. Whether the model has local code.
  611. has_remote_code (`bool`):
  612. Whether the model has remote code.
  613. error_message (`str`, *optional*):
  614. Custom error message to display if there is remote code to load and the user didn't opt-in. If unset, the error
  615. message will be regarding loading a model with custom code.
  616. Returns:
  617. The resolved `trust_remote_code` value.
  618. """
  619. if error_message is None:
  620. if upstream_repo is not None:
  621. error_message = (
  622. f"The repository {model_name} references custom code contained in {upstream_repo} which "
  623. f"must be executed to correctly load the model. You can inspect the repository "
  624. f"content at https://hf.co/{upstream_repo} .\n"
  625. )
  626. elif os.path.isdir(model_name):
  627. error_message = (
  628. f"The repository {model_name} contains custom code which must be executed "
  629. f"to correctly load the model. You can inspect the repository "
  630. f"content at {os.path.abspath(model_name)} .\n"
  631. )
  632. else:
  633. error_message = (
  634. f"The repository {model_name} contains custom code which must be executed "
  635. f"to correctly load the model. You can inspect the repository "
  636. f"content at https://hf.co/{model_name} .\n"
  637. )
  638. if trust_remote_code is None:
  639. if has_local_code:
  640. trust_remote_code = False
  641. elif has_remote_code and TIME_OUT_REMOTE_CODE > 0:
  642. prev_sig_handler = None
  643. try:
  644. prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
  645. signal.alarm(TIME_OUT_REMOTE_CODE)
  646. while trust_remote_code is None:
  647. answer = input(
  648. f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
  649. f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
  650. f"Do you wish to run the custom code? [y/N] "
  651. )
  652. if answer.lower() in ["yes", "y", "1"]:
  653. trust_remote_code = True
  654. elif answer.lower() in ["no", "n", "0", ""]:
  655. trust_remote_code = False
  656. signal.alarm(0)
  657. except Exception:
  658. # OS which does not support signal.SIGALRM
  659. raise ValueError(
  660. f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
  661. f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
  662. )
  663. finally:
  664. if prev_sig_handler is not None:
  665. signal.signal(signal.SIGALRM, prev_sig_handler)
  666. signal.alarm(0)
  667. elif has_remote_code:
  668. # For the CI which puts the timeout at 0
  669. _raise_timeout_error(None, None)
  670. if has_remote_code and not has_local_code and not trust_remote_code:
  671. raise ValueError(
  672. f"{error_message} You can inspect the repository content at https://hf.co/{model_name}.\n"
  673. f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
  674. )
  675. return trust_remote_code
  676. def check_python_requirements(path_or_repo_id, requirements_file="requirements.txt", **kwargs):
  677. """
  678. Tries to locate `requirements_file` in a local folder or repo, and confirms that the environment has all the
  679. python dependencies installed.
  680. Args:
  681. path_or_repo_id (`str` or `os.PathLike`):
  682. This can be either:
  683. - a string, the *model id* of a model repo on huggingface.co.
  684. - a path to a *directory* potentially containing the file.
  685. kwargs (`dict[str, Any]`, *optional*):
  686. Additional arguments to pass to `cached_file`.
  687. """
  688. failed = [] # error messages regarding requirements
  689. try:
  690. requirements = cached_file(path_or_repo_id=path_or_repo_id, filename=requirements_file, **kwargs)
  691. with open(requirements, "r") as f:
  692. requirements = f.readlines()
  693. for requirement in requirements:
  694. requirement = requirement.strip()
  695. if not requirement or requirement.startswith("#"): # skip empty lines and comments
  696. continue
  697. try:
  698. # e.g. "torch>2.6.0" -> "torch", ">", "2.6.0"
  699. package_name, delimiter, version_number = split_package_version(requirement)
  700. except ValueError: # e.g. "torch", as opposed to "torch>2.6.0"
  701. package_name = requirement
  702. delimiter, version_number = None, None
  703. try:
  704. local_package_version = importlib.metadata.version(package_name)
  705. except importlib.metadata.PackageNotFoundError:
  706. failed.append(f"{requirement} (installed: None)")
  707. continue
  708. if delimiter is not None and version_number is not None:
  709. is_satisfied = VersionComparison.from_string(delimiter)(
  710. version.parse(local_package_version), version.parse(version_number)
  711. )
  712. else:
  713. is_satisfied = True
  714. if not is_satisfied:
  715. failed.append(f"{requirement} (installed: {local_package_version})")
  716. except OSError: # no requirements.txt
  717. pass
  718. if failed:
  719. raise ImportError(
  720. f"Missing requirements in your local environment for `{path_or_repo_id}`:\n" + "\n".join(failed)
  721. )