to_fsdp2.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172
  1. #!/usr/bin/env python
  2. # Copyright 2025 The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import enum
  16. import logging
  17. from pathlib import Path
  18. import yaml
  19. from accelerate.commands.utils import CustomArgumentParser
  20. class ConversionStatus(enum.Enum):
  21. NOT_YET_IMPLEMENTED = 0
  22. REMOVED = -1
# Lookup table consulted by `convert_config_to_fsdp2` for every key of the
# FSDP1 `fsdp_config` section. Each value is either:
#   - a string: the key name to use in the FSDP2 config (the same name for most
#     keys; `fsdp_sharding_strategy` is renamed to `fsdp_reshard_after_forward`), or
#   - a `ConversionStatus` member: the argument is flagged as removed or as not
#     yet implemented in FSDP2.
ARGUMENT_KEY_MAPPING = {
    # New keys in FSDP2
    "fsdp_version": "fsdp_version",
    "fsdp_reshard_after_forward": "fsdp_reshard_after_forward",
    # https://github.com/pytorch/torchtitan/blob/main/docs/fsdp.md
    # https://huggingface.co/docs/accelerate/en/usage_guides/fsdp
    "fsdp_auto_wrap_policy": "fsdp_auto_wrap_policy",
    "fsdp_backward_prefetch": ConversionStatus.REMOVED,
    "fsdp_forward_prefetch": ConversionStatus.NOT_YET_IMPLEMENTED,
    "fsdp_cpu_ram_efficient_loading": "fsdp_cpu_ram_efficient_loading",
    "fsdp_offload_params": "fsdp_offload_params",
    # Renamed: the strategy name is also translated to a boolean via ARGUMENT_VALUE_MAPPING.
    "fsdp_sharding_strategy": "fsdp_reshard_after_forward",
    "fsdp_state_dict_type": "fsdp_state_dict_type",
    "fsdp_sync_module_states": ConversionStatus.REMOVED,
    "fsdp_transformer_layer_cls_to_wrap": "fsdp_transformer_layer_cls_to_wrap",
    "fsdp_min_num_params": "fsdp_min_num_params",
    "fsdp_use_orig_params": ConversionStatus.REMOVED,
    "fsdp_activation_checkpointing": "fsdp_activation_checkpointing",
}
  42. ARGUMENT_VALUE_MAPPING = {
  43. "fsdp_sharding_strategy": {
  44. "FULL_SHARD": True,
  45. "SHARD_GRAD_OP": False,
  46. "HYBRID_SHARD": True,
  47. "HYBRID_SHARD_ZERO2": False,
  48. "NO_SHARD": False,
  49. },
  50. "fsdp_reshard_after_forward": { # Needed to convert newly created configs using FSDP1 to FSDP2
  51. "FULL_SHARD": True,
  52. "SHARD_GRAD_OP": False,
  53. "HYBRID_SHARD": True,
  54. "HYBRID_SHARD_ZERO2": False,
  55. "NO_SHARD": False,
  56. },
  57. }
  58. logger = logging.getLogger(__name__)
  59. def _validate_to_fsdp2_args(args):
  60. if not Path(args.config_file).exists():
  61. raise FileNotFoundError(f"Config file {args.config_file} not found")
  62. if not args.overwrite and args.output_file is None:
  63. raise ValueError("If --overwrite is not set, --output_file must be provided")
  64. if not args.overwrite and Path(args.output_file).exists():
  65. raise FileExistsError(f"Output file {args.output_file} already exists and --overwrite is not set")
  66. def convert_config_to_fsdp2(config: dict) -> dict:
  67. fsdp_config = config.get("fsdp_config", {})
  68. if not fsdp_config:
  69. logger.info("No FSDP config found in the config file, skipping conversion...")
  70. return config
  71. new_fsdp_config = {}
  72. if fsdp_config.get("fsdp_version", 1) == 2:
  73. logger.warning("Config already specifies FSDP2, skipping conversion...")
  74. logger.warning(
  75. "If the config doesn't use new argument names, change `fsdp_version` to `1` and rerun the command."
  76. )
  77. return config
  78. for key, value in fsdp_config.items():
  79. conversion_status = ARGUMENT_KEY_MAPPING.get(key, None)
  80. if isinstance(conversion_status, ConversionStatus) or conversion_status is None:
  81. conversion_status = key
  82. new_fsdp_config[conversion_status] = value
  83. continue
  84. if conversion_status == ConversionStatus.REMOVED:
  85. logger.warning(f"Argument {key} has been removed in FSDP2, skipping this key...")
  86. continue
  87. if conversion_status == ConversionStatus.NOT_YET_IMPLEMENTED:
  88. logger.warning(f"Argument {key} is not yet implemented in FSDP2, skipping this key...")
  89. continue
  90. if conversion_status is None:
  91. logger.warning(f"Argument {key} is not being converted, skipping this key...")
  92. new_fsdp_config[key] = value
  93. else:
  94. if key in ARGUMENT_VALUE_MAPPING:
  95. value = ARGUMENT_VALUE_MAPPING[key].get(value, value)
  96. new_fsdp_config[ARGUMENT_KEY_MAPPING[key]] = value
  97. new_fsdp_config["fsdp_version"] = 2
  98. config["fsdp_config"] = new_fsdp_config
  99. return config
  100. def to_fsdp2_command_parser(subparsers=None):
  101. description = "Convert an Accelerate config from FSDP1 to FSDP2"
  102. if subparsers is not None:
  103. parser = subparsers.add_parser("to-fsdp2", description=description)
  104. else:
  105. parser = CustomArgumentParser(description=description)
  106. parser.add_argument("--config_file", type=str, help="The config file to convert to FSDP2", required=True)
  107. parser.add_argument(
  108. "--overwrite",
  109. action="store_true",
  110. help="Overwrite the config file if it exists",
  111. default=False,
  112. )
  113. parser.add_argument(
  114. "--output_file",
  115. type=str,
  116. help="The path to the output file to write the converted config to. If not provided, the input file will be overwritten (if --overwrite is set)",
  117. default=None,
  118. )
  119. if subparsers is not None:
  120. parser.set_defaults(func=to_fsdp2_command)
  121. return parser
  122. def load_config(config_file: str) -> dict:
  123. with open(config_file) as f:
  124. config = yaml.safe_load(f)
  125. if not config:
  126. raise ValueError("Config file is empty")
  127. return config
  128. def to_fsdp2_command(args):
  129. _validate_to_fsdp2_args(args)
  130. config = load_config(args.config_file)
  131. if args.overwrite and args.output_file is None:
  132. args.output_file = args.config_file
  133. new_config = convert_config_to_fsdp2(config)
  134. with open(args.output_file, "w") as f:
  135. yaml.dump(new_config, f)