__init__.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362
  1. #!/usr/bin/env python
  2. # Copyright 2021 The HuggingFace Inc. team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from functools import lru_cache
  16. from huggingface_hub import get_full_repo_name # for backward compatibility
  17. from huggingface_hub.constants import HF_HUB_DISABLE_TELEMETRY as DISABLE_TELEMETRY # for backward compatibility
  18. from packaging import version
  19. from .. import __version__
  20. from .auto_docstring import (
  21. ClassAttrs,
  22. ClassDocstring,
  23. ImageProcessorArgs,
  24. ModelArgs,
  25. ModelOutputArgs,
  26. auto_class_docstring,
  27. auto_docstring,
  28. get_args_doc_from_source,
  29. parse_docstring,
  30. set_min_indent,
  31. )
  32. from .backbone_utils import BackboneConfigMixin, BackboneMixin
  33. from .chat_template_utils import DocstringParsingException, TypeHintParsingException, get_json_schema
  34. from .constants import IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD, IMAGENET_STANDARD_MEAN, IMAGENET_STANDARD_STD
  35. from .doc import (
  36. add_code_sample_docstrings,
  37. add_end_docstrings,
  38. add_start_docstrings,
  39. add_start_docstrings_to_model_forward,
  40. copy_func,
  41. replace_return_docstrings,
  42. )
  43. from .generic import (
  44. ContextManagers,
  45. ExplicitEnum,
  46. ModelOutput,
  47. PaddingStrategy,
  48. TensorType,
  49. TransformersKwargs,
  50. can_return_loss,
  51. can_return_tuple,
  52. expand_dims,
  53. filter_out_non_signature_kwargs,
  54. find_labels,
  55. flatten_dict,
  56. infer_framework,
  57. is_jax_tensor,
  58. is_numpy_array,
  59. is_tensor,
  60. is_tf_symbolic_tensor,
  61. is_tf_tensor,
  62. is_timm_config_dict,
  63. is_timm_local_checkpoint,
  64. is_torch_device,
  65. is_torch_dtype,
  66. is_torch_tensor,
  67. reshape,
  68. squeeze,
  69. strtobool,
  70. tensor_size,
  71. to_numpy,
  72. to_py_obj,
  73. torch_float,
  74. torch_int,
  75. transpose,
  76. working_or_temp_dir,
  77. )
  78. from .hub import (
  79. CHAT_TEMPLATE_DIR,
  80. CHAT_TEMPLATE_FILE,
  81. CLOUDFRONT_DISTRIB_PREFIX,
  82. HF_MODULES_CACHE,
  83. HUGGINGFACE_CO_PREFIX,
  84. HUGGINGFACE_CO_RESOLVE_ENDPOINT,
  85. LEGACY_PROCESSOR_CHAT_TEMPLATE_FILE,
  86. PYTORCH_PRETRAINED_BERT_CACHE,
  87. PYTORCH_TRANSFORMERS_CACHE,
  88. S3_BUCKET_PREFIX,
  89. TRANSFORMERS_CACHE,
  90. TRANSFORMERS_DYNAMIC_MODULE_NAME,
  91. EntryNotFoundError,
  92. PushInProgress,
  93. PushToHubMixin,
  94. RepositoryNotFoundError,
  95. RevisionNotFoundError,
  96. cached_file,
  97. default_cache_path,
  98. define_sagemaker_information,
  99. download_url,
  100. extract_commit_hash,
  101. has_file,
  102. http_user_agent,
  103. is_offline_mode,
  104. is_remote_url,
  105. list_repo_templates,
  106. try_to_load_from_cache,
  107. )
  108. from .import_utils import (
  109. ACCELERATE_MIN_VERSION,
  110. ENV_VARS_TRUE_AND_AUTO_VALUES,
  111. ENV_VARS_TRUE_VALUES,
  112. GGUF_MIN_VERSION,
  113. TORCH_FX_REQUIRED_VERSION,
  114. TRITON_MIN_VERSION,
  115. USE_JAX,
  116. USE_TF,
  117. USE_TORCH,
  118. XLA_FSDPV2_MIN_VERSION,
  119. DummyObject,
  120. OptionalDependencyNotAvailable,
  121. _LazyModule,
  122. ccl_version,
  123. check_torch_load_is_safe,
  124. direct_transformers_import,
  125. get_torch_version,
  126. is_accelerate_available,
  127. is_apex_available,
  128. is_apollo_torch_available,
  129. is_aqlm_available,
  130. is_auto_awq_available,
  131. is_auto_gptq_available,
  132. is_auto_round_available,
  133. is_av_available,
  134. is_bitsandbytes_available,
  135. is_bitsandbytes_multi_backend_available,
  136. is_bs4_available,
  137. is_ccl_available,
  138. is_coloredlogs_available,
  139. is_compressed_tensors_available,
  140. is_cuda_platform,
  141. is_cv2_available,
  142. is_cython_available,
  143. is_datasets_available,
  144. is_decord_available,
  145. is_detectron2_available,
  146. is_eetq_available,
  147. is_essentia_available,
  148. is_faiss_available,
  149. is_fbgemm_gpu_available,
  150. is_flash_attn_2_available,
  151. is_flash_attn_3_available,
  152. is_flash_attn_greater_or_equal,
  153. is_flash_attn_greater_or_equal_2_10,
  154. is_flax_available,
  155. is_flute_available,
  156. is_fp_quant_available,
  157. is_fsdp_available,
  158. is_ftfy_available,
  159. is_g2p_en_available,
  160. is_galore_torch_available,
  161. is_gguf_available,
  162. is_gptqmodel_available,
  163. is_grokadamw_available,
  164. is_habana_gaudi1,
  165. is_hadamard_available,
  166. is_hqq_available,
  167. is_huggingface_hub_greater_or_equal,
  168. is_in_notebook,
  169. is_ipex_available,
  170. is_jinja_available,
  171. is_jumanpp_available,
  172. is_kenlm_available,
  173. is_keras_nlp_available,
  174. is_kernels_available,
  175. is_levenshtein_available,
  176. is_libcst_available,
  177. is_librosa_available,
  178. is_liger_kernel_available,
  179. is_lomo_available,
  180. is_matplotlib_available,
  181. is_mistral_common_available,
  182. is_mlx_available,
  183. is_natten_available,
  184. is_ninja_available,
  185. is_nltk_available,
  186. is_num2words_available,
  187. is_onnx_available,
  188. is_openai_available,
  189. is_optimum_available,
  190. is_optimum_quanto_available,
  191. is_pandas_available,
  192. is_peft_available,
  193. is_phonemizer_available,
  194. is_pretty_midi_available,
  195. is_protobuf_available,
  196. is_psutil_available,
  197. is_py3nvml_available,
  198. is_pyctcdecode_available,
  199. is_pytesseract_available,
  200. is_pytest_available,
  201. is_pytorch_quantization_available,
  202. is_quanto_greater,
  203. is_quark_available,
  204. is_qutlass_available,
  205. is_rich_available,
  206. is_rjieba_available,
  207. is_rocm_platform,
  208. is_sacremoses_available,
  209. is_safetensors_available,
  210. is_sagemaker_dp_enabled,
  211. is_sagemaker_mp_enabled,
  212. is_schedulefree_available,
  213. is_scipy_available,
  214. is_sentencepiece_available,
  215. is_seqio_available,
  216. is_sklearn_available,
  217. is_soundfile_available,
  218. is_spacy_available,
  219. is_speech_available,
  220. is_spqr_available,
  221. is_sudachi_available,
  222. is_sudachi_projection_available,
  223. is_tensorflow_probability_available,
  224. is_tensorflow_text_available,
  225. is_tf2onnx_available,
  226. is_tf_available,
  227. is_tiktoken_available,
  228. is_timm_available,
  229. is_tokenizers_available,
  230. is_torch_accelerator_available,
  231. is_torch_available,
  232. is_torch_bf16_available,
  233. is_torch_bf16_available_on_device,
  234. is_torch_bf16_cpu_available,
  235. is_torch_bf16_gpu_available,
  236. is_torch_compile_available,
  237. is_torch_cuda_available,
  238. is_torch_deterministic,
  239. is_torch_flex_attn_available,
  240. is_torch_fp16_available_on_device,
  241. is_torch_fx_available,
  242. is_torch_fx_proxy,
  243. is_torch_greater_or_equal,
  244. is_torch_hpu_available,
  245. is_torch_mlu_available,
  246. is_torch_mps_available,
  247. is_torch_musa_available,
  248. is_torch_neuroncore_available,
  249. is_torch_npu_available,
  250. is_torch_optimi_available,
  251. is_torch_sdpa_available,
  252. is_torch_tensorrt_fx_available,
  253. is_torch_tf32_available,
  254. is_torch_xla_available,
  255. is_torch_xpu_available,
  256. is_torchao_available,
  257. is_torchaudio_available,
  258. is_torchcodec_available,
  259. is_torchdistx_available,
  260. is_torchdynamo_available,
  261. is_torchdynamo_compiling,
  262. is_torchdynamo_exporting,
  263. is_torchvision_available,
  264. is_torchvision_v2_available,
  265. is_training_run_on_sagemaker,
  266. is_triton_available,
  267. is_uroman_available,
  268. is_vision_available,
  269. is_vptq_available,
  270. is_xlstm_available,
  271. is_yt_dlp_available,
  272. requires_backends,
  273. torch_only_method,
  274. )
  275. from .peft_utils import (
  276. ADAPTER_CONFIG_NAME,
  277. ADAPTER_SAFE_WEIGHTS_NAME,
  278. ADAPTER_WEIGHTS_NAME,
  279. check_peft_version,
  280. find_adapter_config_file,
  281. )
  282. WEIGHTS_NAME = "pytorch_model.bin"
  283. WEIGHTS_INDEX_NAME = "pytorch_model.bin.index.json"
  284. TF2_WEIGHTS_NAME = "tf_model.h5"
  285. TF2_WEIGHTS_INDEX_NAME = "tf_model.h5.index.json"
  286. TF_WEIGHTS_NAME = "model.ckpt"
  287. FLAX_WEIGHTS_NAME = "flax_model.msgpack"
  288. FLAX_WEIGHTS_INDEX_NAME = "flax_model.msgpack.index.json"
  289. SAFE_WEIGHTS_NAME = "model.safetensors"
  290. SAFE_WEIGHTS_INDEX_NAME = "model.safetensors.index.json"
  291. CONFIG_NAME = "config.json"
  292. FEATURE_EXTRACTOR_NAME = "preprocessor_config.json"
  293. IMAGE_PROCESSOR_NAME = "preprocessor_config.json"
  294. VIDEO_PROCESSOR_NAME = "video_preprocessor_config.json"
  295. AUDIO_TOKENIZER_NAME = "audio_tokenizer_config.json"
  296. PROCESSOR_NAME = "processor_config.json"
  297. GENERATION_CONFIG_NAME = "generation_config.json"
  298. MODEL_CARD_NAME = "modelcard.json"
  299. SENTENCEPIECE_UNDERLINE = "▁"
  300. SPIECE_UNDERLINE = SENTENCEPIECE_UNDERLINE # Kept for backward compatibility
  301. MULTIPLE_CHOICE_DUMMY_INPUTS = [
  302. [[0, 1, 0, 1], [1, 0, 0, 1]]
  303. ] * 2 # Needs to have 0s and 1s only since XLM uses it for langs too.
  304. DUMMY_INPUTS = [[7, 6, 0, 0, 1], [1, 2, 3, 0, 0], [0, 0, 0, 4, 5]]
  305. DUMMY_MASK = [[1, 1, 1, 1, 1], [1, 1, 1, 0, 0], [0, 0, 0, 1, 1]]
  306. def check_min_version(min_version):
  307. if version.parse(__version__) < version.parse(min_version):
  308. if "dev" in min_version:
  309. error_message = (
  310. "This example requires a source install from HuggingFace Transformers (see "
  311. "`https://huggingface.co/docs/transformers/installation#install-from-source`),"
  312. )
  313. else:
  314. error_message = f"This example requires a minimum version of {min_version},"
  315. error_message += f" but the version found is {__version__}.\n"
  316. raise ImportError(
  317. error_message
  318. + "Check out https://github.com/huggingface/transformers/tree/main/examples#important-note for the examples corresponding to other "
  319. "versions of HuggingFace Transformers."
  320. )
  321. @lru_cache
  322. def get_available_devices() -> frozenset[str]:
  323. """
  324. Returns a frozenset of devices available for the current PyTorch installation.
  325. """
  326. devices = {"cpu"} # `cpu` is always supported as a device in PyTorch
  327. if is_torch_cuda_available():
  328. devices.add("cuda")
  329. if is_torch_mps_available():
  330. devices.add("mps")
  331. if is_torch_xpu_available():
  332. devices.add("xpu")
  333. if is_torch_npu_available():
  334. devices.add("npu")
  335. if is_torch_hpu_available():
  336. devices.add("hpu")
  337. if is_torch_mlu_available():
  338. devices.add("mlu")
  339. if is_torch_musa_available():
  340. devices.add("musa")
  341. return frozenset(devices)