constants.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106
  1. # Copyright 2022 The HuggingFace Team. All rights reserved.
  2. #
  3. # Licensed under the Apache License, Version 2.0 (the "License");
  4. # you may not use this file except in compliance with the License.
  5. # You may obtain a copy of the License at
  6. #
  7. # http://www.apache.org/licenses/LICENSE-2.0
  8. #
  9. # Unless required by applicable law or agreed to in writing, software
  10. # distributed under the License is distributed on an "AS IS" BASIS,
  11. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. # See the License for the specific language governing permissions and
  13. # limitations under the License.
  14. import operator as op
  15. import torch
  16. SCALER_NAME = "scaler.pt"
  17. MODEL_NAME = "pytorch_model"
  18. SAFE_MODEL_NAME = "model"
  19. RNG_STATE_NAME = "random_states"
  20. OPTIMIZER_NAME = "optimizer"
  21. SCHEDULER_NAME = "scheduler"
  22. SAMPLER_NAME = "sampler"
  23. PROFILE_PATTERN_NAME = "profile_{suffix}.json"
  24. WEIGHTS_NAME = f"{MODEL_NAME}.bin"
  25. WEIGHTS_PATTERN_NAME = "pytorch_model{suffix}.bin"
  26. WEIGHTS_INDEX_NAME = f"{WEIGHTS_NAME}.index.json"
  27. SAFE_WEIGHTS_NAME = f"{SAFE_MODEL_NAME}.safetensors"
  28. SAFE_WEIGHTS_PATTERN_NAME = "model{suffix}.safetensors"
  29. SAFE_WEIGHTS_INDEX_NAME = f"{SAFE_WEIGHTS_NAME}.index.json"
  30. SAGEMAKER_PYTORCH_VERSION = "1.10.2"
  31. SAGEMAKER_PYTHON_VERSION = "py38"
  32. SAGEMAKER_TRANSFORMERS_VERSION = "4.17.0"
  33. SAGEMAKER_PARALLEL_EC2_INSTANCES = ["ml.p3.16xlarge", "ml.p3dn.24xlarge", "ml.p4dn.24xlarge"]
  34. FSDP_SHARDING_STRATEGY = ["FULL_SHARD", "SHARD_GRAD_OP", "NO_SHARD", "HYBRID_SHARD", "HYBRID_SHARD_ZERO2"]
  35. FSDP_AUTO_WRAP_POLICY = ["TRANSFORMER_BASED_WRAP", "SIZE_BASED_WRAP", "NO_WRAP"]
  36. FSDP_BACKWARD_PREFETCH = ["BACKWARD_PRE", "BACKWARD_POST", "NO_PREFETCH"]
  37. FSDP_STATE_DICT_TYPE = ["FULL_STATE_DICT", "LOCAL_STATE_DICT", "SHARDED_STATE_DICT"]
  38. FSDP2_STATE_DICT_TYPE = ["SHARDED_STATE_DICT", "FULL_STATE_DICT"]
  39. FSDP_PYTORCH_VERSION = (
  40. "2.1.0.a0+32f93b1" # Technically should be 2.1.0, but MS-AMP uses this specific prerelease in their Docker image.
  41. )
  42. FSDP2_PYTORCH_VERSION = "2.6.0"
  43. FSDP_MODEL_NAME = "pytorch_model_fsdp"
  44. DEEPSPEED_MULTINODE_LAUNCHERS = ["pdsh", "standard", "openmpi", "mvapich", "mpich", "nossh", "slurm"]
  45. TORCH_DYNAMO_MODES = ["default", "reduce-overhead", "max-autotune"]
  46. ELASTIC_LOG_LINE_PREFIX_TEMPLATE_PYTORCH_VERSION = "2.2.0"
  47. XPU_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.4.0"
  48. MITA_PROFILING_AVAILABLE_PYTORCH_VERSION = "2.1.0"
  49. BETA_TP_AVAILABLE_PYTORCH_VERSION = "2.3.0"
  50. BETA_TP_AVAILABLE_TRANSFORMERS_VERSION = "4.52.0"
  51. BETA_CP_AVAILABLE_PYTORCH_VERSION = "2.6.0"
  52. BETA_SP_AVAILABLE_DEEPSPEED_VERSION = "0.18.2"
  53. STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<=": op.le, "<": op.lt}
  54. # These are the args for `torch.distributed.launch` for pytorch < 1.9
  55. TORCH_LAUNCH_PARAMS = [
  56. "nnodes",
  57. "nproc_per_node",
  58. "rdzv_backend",
  59. "rdzv_endpoint",
  60. "rdzv_id",
  61. "rdzv_conf",
  62. "standalone",
  63. "max_restarts",
  64. "monitor_interval",
  65. "start_method",
  66. "role",
  67. "module",
  68. "m",
  69. "no_python",
  70. "run_path",
  71. "log_dir",
  72. "r",
  73. "redirects",
  74. "t",
  75. "tee",
  76. "node_rank",
  77. "master_addr",
  78. "master_port",
  79. ]
  80. CUDA_DISTRIBUTED_TYPES = ["DEEPSPEED", "MULTI_GPU", "FSDP", "MEGATRON_LM", "TP"]
  81. TORCH_DISTRIBUTED_OPERATION_TYPES = CUDA_DISTRIBUTED_TYPES + [
  82. "MULTI_NPU",
  83. "MULTI_MLU",
  84. "MULTI_SDAA",
  85. "MULTI_MUSA",
  86. "MULTI_XPU",
  87. "MULTI_CPU",
  88. "MULTI_HPU",
  89. ]
  90. SUPPORTED_PYTORCH_LAYERS_FOR_UPCASTING = (
  91. torch.nn.Conv1d,
  92. torch.nn.Conv2d,
  93. torch.nn.Conv3d,
  94. torch.nn.ConvTranspose1d,
  95. torch.nn.ConvTranspose2d,
  96. torch.nn.ConvTranspose3d,
  97. torch.nn.Linear,
  98. )