__init__.py 9.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310
  1. # Copyright 2023 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from typing import TYPE_CHECKING
  15. from ..utils import OptionalDependencyNotAvailable, _LazyModule, is_torch_available, is_torch_greater_or_equal
  16. _import_structure = {
  17. "aqlm": ["replace_with_aqlm_linear"],
  18. "awq": [
  19. "fuse_awq_modules",
  20. "post_init_awq_exllama_modules",
  21. "post_init_awq_ipex_modules",
  22. "replace_quantization_scales",
  23. "replace_with_awq_linear",
  24. ],
  25. "bitnet": [
  26. "BitLinear",
  27. "pack_weights",
  28. "replace_with_bitnet_linear",
  29. "unpack_weights",
  30. ],
  31. "bitsandbytes": [
  32. "dequantize_and_replace",
  33. "get_keys_to_not_convert",
  34. "replace_8bit_linear",
  35. "replace_with_bnb_linear",
  36. "set_module_8bit_tensor_to_device",
  37. "set_module_quantized_tensor_to_device",
  38. "validate_bnb_backend_availability",
  39. ],
  40. "deepspeed": [
  41. "HfDeepSpeedConfig",
  42. "HfTrainerDeepSpeedConfig",
  43. "deepspeed_config",
  44. "deepspeed_init",
  45. "deepspeed_load_checkpoint",
  46. "deepspeed_optim_sched",
  47. "is_deepspeed_available",
  48. "is_deepspeed_zero3_enabled",
  49. "set_hf_deepspeed_config",
  50. "unset_hf_deepspeed_config",
  51. ],
  52. "eetq": ["replace_with_eetq_linear"],
  53. "fbgemm_fp8": ["FbgemmFp8Linear", "FbgemmFp8Llama4TextExperts", "replace_with_fbgemm_fp8_linear"],
  54. "finegrained_fp8": ["FP8Linear", "replace_with_fp8_linear"],
  55. "fsdp": ["is_fsdp_enabled", "is_fsdp_managed_module"],
  56. "ggml": [
  57. "GGUF_CONFIG_MAPPING",
  58. "GGUF_TOKENIZER_MAPPING",
  59. "_gguf_parse_value",
  60. "load_dequant_gguf_tensor",
  61. "load_gguf",
  62. ],
  63. "higgs": [
  64. "HiggsLinear",
  65. "dequantize_higgs",
  66. "quantize_with_higgs",
  67. "replace_with_higgs_linear",
  68. ],
  69. "hqq": ["prepare_for_hqq_linear"],
  70. "hub_kernels": [
  71. "LayerRepository",
  72. "register_kernel_mapping",
  73. "replace_kernel_forward_from_hub",
  74. "use_kernel_forward_from_hub",
  75. ],
  76. "integration_utils": [
  77. "INTEGRATION_TO_CALLBACK",
  78. "AzureMLCallback",
  79. "ClearMLCallback",
  80. "CodeCarbonCallback",
  81. "CometCallback",
  82. "DagsHubCallback",
  83. "DVCLiveCallback",
  84. "FlyteCallback",
  85. "MLflowCallback",
  86. "NeptuneCallback",
  87. "NeptuneMissingConfiguration",
  88. "SwanLabCallback",
  89. "TensorBoardCallback",
  90. "TrackioCallback",
  91. "WandbCallback",
  92. "get_available_reporting_integrations",
  93. "get_reporting_integration_callbacks",
  94. "hp_params",
  95. "is_azureml_available",
  96. "is_clearml_available",
  97. "is_codecarbon_available",
  98. "is_comet_available",
  99. "is_dagshub_available",
  100. "is_dvclive_available",
  101. "is_flyte_deck_standard_available",
  102. "is_flytekit_available",
  103. "is_mlflow_available",
  104. "is_neptune_available",
  105. "is_optuna_available",
  106. "is_ray_available",
  107. "is_ray_tune_available",
  108. "is_sigopt_available",
  109. "is_swanlab_available",
  110. "is_tensorboard_available",
  111. "is_trackio_available",
  112. "is_wandb_available",
  113. "rewrite_logs",
  114. "run_hp_search_optuna",
  115. "run_hp_search_ray",
  116. "run_hp_search_sigopt",
  117. "run_hp_search_wandb",
  118. ],
  119. "mxfp4": [
  120. "Mxfp4GptOssExperts",
  121. "convert_moe_packed_tensors",
  122. "dequantize",
  123. "load_and_swizzle_mxfp4",
  124. "quantize_to_mxfp4",
  125. "replace_with_mxfp4_linear",
  126. "swizzle_mxfp4",
  127. ],
  128. "peft": ["PeftAdapterMixin"],
  129. "quanto": ["replace_with_quanto_layers"],
  130. "spqr": ["replace_with_spqr_linear"],
  131. "vptq": ["replace_with_vptq_linear"],
  132. }
  133. try:
  134. if not is_torch_available():
  135. raise OptionalDependencyNotAvailable()
  136. except OptionalDependencyNotAvailable:
  137. pass
  138. else:
  139. _import_structure["executorch"] = [
  140. "TorchExportableModuleWithStaticCache",
  141. "convert_and_export_with_cache",
  142. ]
  143. try:
  144. if not is_torch_greater_or_equal("2.3"):
  145. raise OptionalDependencyNotAvailable()
  146. except OptionalDependencyNotAvailable:
  147. pass
  148. else:
  149. _import_structure["tensor_parallel"] = [
  150. "shard_and_distribute_module",
  151. "ALL_PARALLEL_STYLES",
  152. "translate_to_torch_parallel_style",
  153. ]
  154. try:
  155. if not is_torch_greater_or_equal("2.5"):
  156. raise OptionalDependencyNotAvailable()
  157. except OptionalDependencyNotAvailable:
  158. pass
  159. else:
  160. _import_structure["flex_attention"] = [
  161. "make_flex_block_causal_mask",
  162. ]
  163. if TYPE_CHECKING:
  164. from .aqlm import replace_with_aqlm_linear
  165. from .awq import (
  166. fuse_awq_modules,
  167. post_init_awq_exllama_modules,
  168. post_init_awq_ipex_modules,
  169. replace_quantization_scales,
  170. replace_with_awq_linear,
  171. )
  172. from .bitnet import (
  173. BitLinear,
  174. pack_weights,
  175. replace_with_bitnet_linear,
  176. unpack_weights,
  177. )
  178. from .bitsandbytes import (
  179. dequantize_and_replace,
  180. get_keys_to_not_convert,
  181. replace_8bit_linear,
  182. replace_with_bnb_linear,
  183. set_module_8bit_tensor_to_device,
  184. set_module_quantized_tensor_to_device,
  185. validate_bnb_backend_availability,
  186. )
  187. from .deepspeed import (
  188. HfDeepSpeedConfig,
  189. HfTrainerDeepSpeedConfig,
  190. deepspeed_config,
  191. deepspeed_init,
  192. deepspeed_load_checkpoint,
  193. deepspeed_optim_sched,
  194. is_deepspeed_available,
  195. is_deepspeed_zero3_enabled,
  196. set_hf_deepspeed_config,
  197. unset_hf_deepspeed_config,
  198. )
  199. from .eetq import replace_with_eetq_linear
  200. from .fbgemm_fp8 import FbgemmFp8Linear, FbgemmFp8Llama4TextExperts, replace_with_fbgemm_fp8_linear
  201. from .finegrained_fp8 import FP8Linear, replace_with_fp8_linear
  202. from .fsdp import is_fsdp_enabled, is_fsdp_managed_module
  203. from .ggml import (
  204. GGUF_CONFIG_MAPPING,
  205. GGUF_TOKENIZER_MAPPING,
  206. _gguf_parse_value,
  207. load_dequant_gguf_tensor,
  208. load_gguf,
  209. )
  210. from .higgs import HiggsLinear, dequantize_higgs, quantize_with_higgs, replace_with_higgs_linear
  211. from .hqq import prepare_for_hqq_linear
  212. from .hub_kernels import (
  213. LayerRepository,
  214. register_kernel_mapping,
  215. replace_kernel_forward_from_hub,
  216. use_kernel_forward_from_hub,
  217. )
  218. from .integration_utils import (
  219. INTEGRATION_TO_CALLBACK,
  220. AzureMLCallback,
  221. ClearMLCallback,
  222. CodeCarbonCallback,
  223. CometCallback,
  224. DagsHubCallback,
  225. DVCLiveCallback,
  226. FlyteCallback,
  227. MLflowCallback,
  228. NeptuneCallback,
  229. NeptuneMissingConfiguration,
  230. SwanLabCallback,
  231. TensorBoardCallback,
  232. TrackioCallback,
  233. WandbCallback,
  234. get_available_reporting_integrations,
  235. get_reporting_integration_callbacks,
  236. hp_params,
  237. is_azureml_available,
  238. is_clearml_available,
  239. is_codecarbon_available,
  240. is_comet_available,
  241. is_dagshub_available,
  242. is_dvclive_available,
  243. is_flyte_deck_standard_available,
  244. is_flytekit_available,
  245. is_mlflow_available,
  246. is_neptune_available,
  247. is_optuna_available,
  248. is_ray_available,
  249. is_ray_tune_available,
  250. is_sigopt_available,
  251. is_swanlab_available,
  252. is_tensorboard_available,
  253. is_trackio_available,
  254. is_wandb_available,
  255. rewrite_logs,
  256. run_hp_search_optuna,
  257. run_hp_search_ray,
  258. run_hp_search_sigopt,
  259. run_hp_search_wandb,
  260. )
  261. from .mxfp4 import (
  262. Mxfp4GptOssExperts,
  263. dequantize,
  264. load_and_swizzle_mxfp4,
  265. quantize_to_mxfp4,
  266. replace_with_mxfp4_linear,
  267. swizzle_mxfp4,
  268. )
  269. from .peft import PeftAdapterMixin
  270. from .quanto import replace_with_quanto_layers
  271. from .spqr import replace_with_spqr_linear
  272. from .vptq import replace_with_vptq_linear
  273. try:
  274. if not is_torch_available():
  275. raise OptionalDependencyNotAvailable()
  276. except OptionalDependencyNotAvailable:
  277. pass
  278. else:
  279. from .executorch import TorchExportableModuleWithStaticCache, convert_and_export_with_cache
  280. try:
  281. if not is_torch_greater_or_equal("2.3"):
  282. raise OptionalDependencyNotAvailable()
  283. except OptionalDependencyNotAvailable:
  284. pass
  285. else:
  286. from .tensor_parallel import (
  287. ALL_PARALLEL_STYLES,
  288. shard_and_distribute_module,
  289. translate_to_torch_parallel_style,
  290. )
  291. try:
  292. if not is_torch_greater_or_equal("2.5"):
  293. raise OptionalDependencyNotAvailable()
  294. except OptionalDependencyNotAvailable:
  295. pass
  296. else:
  297. from .flex_attention import make_flex_block_causal_mask
  298. else:
  299. import sys
  300. sys.modules[__name__] = _LazyModule(__name__, globals()["__file__"], _import_structure, module_spec=__spec__)