__init__.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. from ..parallelism_config import ParallelismConfig
  15. from .ao import convert_model_to_fp8_ao, filter_first_and_last_linear_layers, has_ao_layers
  16. from .constants import (
  17. MITA_PROFILING_AVAILABLE_PYTORCH_VERSION,
  18. MODEL_NAME,
  19. OPTIMIZER_NAME,
  20. PROFILE_PATTERN_NAME,
  21. RNG_STATE_NAME,
  22. SAFE_MODEL_NAME,
  23. SAFE_WEIGHTS_INDEX_NAME,
  24. SAFE_WEIGHTS_NAME,
  25. SAFE_WEIGHTS_PATTERN_NAME,
  26. SAMPLER_NAME,
  27. SCALER_NAME,
  28. SCHEDULER_NAME,
  29. TORCH_DISTRIBUTED_OPERATION_TYPES,
  30. TORCH_LAUNCH_PARAMS,
  31. WEIGHTS_INDEX_NAME,
  32. WEIGHTS_NAME,
  33. WEIGHTS_PATTERN_NAME,
  34. XPU_PROFILING_AVAILABLE_PYTORCH_VERSION,
  35. )
  36. from .dataclasses import (
  37. AORecipeKwargs,
  38. AutocastKwargs,
  39. BnbQuantizationConfig,
  40. ComputeEnvironment,
  41. CustomDtype,
  42. DataLoaderConfiguration,
  43. DDPCommunicationHookType,
  44. DeepSpeedPlugin,
  45. DeepSpeedSequenceParallelConfig,
  46. DistributedDataParallelKwargs,
  47. DistributedType,
  48. DynamoBackend,
  49. FP8RecipeKwargs,
  50. FullyShardedDataParallelPlugin,
  51. GradientAccumulationPlugin,
  52. GradScalerKwargs,
  53. InitProcessGroupKwargs,
  54. KwargsHandler,
  55. LoggerType,
  56. MegatronLMPlugin,
  57. MSAMPRecipeKwargs,
  58. PrecisionType,
  59. ProfileKwargs,
  60. ProjectConfiguration,
  61. RNGType,
  62. SageMakerDistributedType,
  63. TensorInformation,
  64. TERecipeKwargs,
  65. TorchContextParallelConfig,
  66. TorchDynamoPlugin,
  67. TorchTensorParallelConfig,
  68. TorchTensorParallelPlugin,
  69. add_model_config_to_megatron_parser,
  70. )
  71. from .environment import (
  72. are_libraries_initialized,
  73. check_cuda_fp8_capability,
  74. check_cuda_p2p_ib_support,
  75. clear_environment,
  76. convert_dict_to_env_variables,
  77. get_cpu_distributed_information,
  78. get_current_device_type,
  79. get_gpu_info,
  80. get_int_from_env,
  81. parse_choice_from_env,
  82. parse_flag_from_env,
  83. patch_environment,
  84. purge_accelerate_environment,
  85. set_numa_affinity,
  86. str_to_bool,
  87. )
  88. from .imports import (
  89. deepspeed_required,
  90. get_ccl_version,
  91. is_4bit_bnb_available,
  92. is_8bit_bnb_available,
  93. is_aim_available,
  94. is_bf16_available,
  95. is_bitsandbytes_multi_backend_available,
  96. is_bnb_available,
  97. is_boto3_available,
  98. is_ccl_available,
  99. is_clearml_available,
  100. is_comet_ml_available,
  101. is_cuda_available,
  102. is_datasets_available,
  103. is_deepspeed_available,
  104. is_dvclive_available,
  105. is_fp8_available,
  106. is_fp16_available,
  107. is_habana_gaudi1,
  108. is_hpu_available,
  109. is_import_timer_available,
  110. is_ipex_available,
  111. is_lomo_available,
  112. is_matplotlib_available,
  113. is_megatron_lm_available,
  114. is_mlflow_available,
  115. is_mlu_available,
  116. is_mps_available,
  117. is_msamp_available,
  118. is_musa_available,
  119. is_npu_available,
  120. is_pandas_available,
  121. is_peft_available,
  122. is_pippy_available,
  123. is_pynvml_available,
  124. is_pytest_available,
  125. is_rich_available,
  126. is_sagemaker_available,
  127. is_schedulefree_available,
  128. is_sdaa_available,
  129. is_swanlab_available,
  130. is_tensorboard_available,
  131. is_timm_available,
  132. is_torch_xla_available,
  133. is_torchao_available,
  134. is_torchdata_available,
  135. is_torchdata_stateful_dataloader_available,
  136. is_torchvision_available,
  137. is_trackio_available,
  138. is_transformer_engine_available,
  139. is_transformer_engine_mxfp8_available,
  140. is_transformers_available,
  141. is_triton_available,
  142. is_wandb_available,
  143. is_weights_only_available,
  144. is_xccl_available,
  145. is_xpu_available,
  146. torchao_required,
  147. )
  148. from .modeling import (
  149. align_module_device,
  150. calculate_maximum_sizes,
  151. check_device_map,
  152. check_tied_parameters_in_config,
  153. check_tied_parameters_on_same_device,
  154. compute_module_sizes,
  155. convert_file_size_to_int,
  156. dtype_byte_size,
  157. find_tied_parameters,
  158. get_balanced_memory,
  159. get_grad_scaler,
  160. get_max_layer_size,
  161. get_max_memory,
  162. get_mixed_precision_context_manager,
  163. has_offloaded_params,
  164. id_tensor_storage,
  165. infer_auto_device_map,
  166. is_peft_model,
  167. load_checkpoint_in_model,
  168. load_offloaded_weights,
  169. load_state_dict,
  170. named_module_tensors,
  171. retie_parameters,
  172. set_module_tensor_to_device,
  173. )
  174. from .offload import (
  175. OffloadedWeightsLoader,
  176. PrefixedDataset,
  177. extract_submodules_state_dict,
  178. load_offloaded_weight,
  179. offload_state_dict,
  180. offload_weight,
  181. save_offload_index,
  182. )
  183. from .operations import (
  184. CannotPadNestedTensorWarning,
  185. GatheredParameters,
  186. broadcast,
  187. broadcast_object_list,
  188. concatenate,
  189. convert_outputs_to_fp32,
  190. convert_to_fp32,
  191. copy_tensor_to_devices,
  192. find_batch_size,
  193. find_device,
  194. gather,
  195. gather_object,
  196. get_data_structure,
  197. honor_type,
  198. ignorant_find_batch_size,
  199. initialize_tensors,
  200. is_namedtuple,
  201. is_tensor_information,
  202. is_torch_tensor,
  203. listify,
  204. pad_across_processes,
  205. pad_input_tensors,
  206. recursively_apply,
  207. reduce,
  208. send_to_device,
  209. slice_tensors,
  210. )
  211. from .versions import compare_versions, is_torch_version
  212. if is_deepspeed_available():
  213. from .deepspeed import (
  214. DeepSpeedEngineWrapper,
  215. DeepSpeedOptimizerWrapper,
  216. DeepSpeedSchedulerWrapper,
  217. DummyOptim,
  218. DummyScheduler,
  219. HfDeepSpeedConfig,
  220. get_active_deepspeed_plugin,
  221. map_pytorch_optim_to_deepspeed,
  222. )
  223. from .bnb import has_4bit_bnb_layers, load_and_quantize_model
  224. from .fsdp_utils import (
  225. disable_fsdp_ram_efficient_loading,
  226. enable_fsdp_ram_efficient_loading,
  227. ensure_weights_retied,
  228. fsdp2_apply_ac,
  229. fsdp2_canonicalize_names,
  230. fsdp2_load_full_state_dict,
  231. fsdp2_prepare_model,
  232. fsdp2_switch_optimizer_parameters,
  233. get_fsdp2_grad_scaler,
  234. load_fsdp_model,
  235. load_fsdp_optimizer,
  236. merge_fsdp_weights,
  237. save_fsdp_model,
  238. save_fsdp_optimizer,
  239. )
  240. from .launch import (
  241. PrepareForLaunch,
  242. _filter_args,
  243. prepare_deepspeed_cmd_env,
  244. prepare_multi_gpu_env,
  245. prepare_sagemager_args_inputs,
  246. prepare_simple_launcher_cmd_env,
  247. prepare_tpu,
  248. )
  249. # For docs
  250. from .megatron_lm import (
  251. AbstractTrainStep,
  252. BertTrainStep,
  253. GPTTrainStep,
  254. MegatronLMDummyDataLoader,
  255. MegatronLMDummyScheduler,
  256. T5TrainStep,
  257. avg_losses_across_data_parallel_group,
  258. )
  259. if is_megatron_lm_available():
  260. from .megatron_lm import (
  261. MegatronEngine,
  262. MegatronLMOptimizerWrapper,
  263. MegatronLMSchedulerWrapper,
  264. gather_across_data_parallel_groups,
  265. )
  266. from .megatron_lm import initialize as megatron_lm_initialize
  267. from .megatron_lm import prepare_data_loader as megatron_lm_prepare_data_loader
  268. from .megatron_lm import prepare_model_optimizer_scheduler as megatron_lm_prepare_model_optimizer_scheduler
  269. from .megatron_lm import prepare_optimizer as megatron_lm_prepare_optimizer
  270. from .megatron_lm import prepare_scheduler as megatron_lm_prepare_scheduler
  271. from .memory import find_executable_batch_size, release_memory
  272. from .other import (
  273. check_os_kernel,
  274. clean_state_dict_for_safetensors,
  275. compile_regions,
  276. compile_regions_deepspeed,
  277. convert_bytes,
  278. extract_model_from_parallel,
  279. get_module_children_bottom_up,
  280. get_pretty_name,
  281. has_compiled_regions,
  282. is_compiled_module,
  283. is_port_in_use,
  284. load,
  285. merge_dicts,
  286. model_has_dtensor,
  287. recursive_getattr,
  288. save,
  289. wait_for_everyone,
  290. write_basic_config,
  291. )
  292. from .random import set_seed, synchronize_rng_state, synchronize_rng_states
  293. from .torch_xla import install_xla
  294. from .tqdm import tqdm
  295. from .transformer_engine import (
  296. apply_fp8_autowrap,
  297. contextual_fp8_autocast,
  298. convert_model,
  299. has_transformer_engine_layers,
  300. )