merge.py 2.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869
  1. #!/usr/bin/env python
  2. # Copyright 2024 The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. from accelerate.commands.utils import CustomArgumentParser
  16. from accelerate.utils import merge_fsdp_weights
  17. description = """Utility to merge the weights from multiple FSDP checkpoints into a single combined checkpoint. Should be used if
  18. `SHARDED_STATE_DICT` was used for the model. Weights will be saved to `{output_path}`.
  19. This is a CPU-bound process and requires enough RAM to load the entire model state dict."""
  20. def merge_command(args):
  21. merge_fsdp_weights(
  22. args.checkpoint_directory, args.output_path, not args.unsafe_serialization, args.remove_checkpoint_dir
  23. )
  24. def merge_command_parser(subparsers=None):
  25. if subparsers is not None:
  26. parser = subparsers.add_parser("merge-weights", description=description)
  27. else:
  28. parser = CustomArgumentParser(description=description)
  29. parser.add_argument("checkpoint_directory", type=str, help="A directory containing sharded weights saved by FSDP.")
  30. parser.add_argument(
  31. "output_path",
  32. type=str,
  33. help="The path to save the merged weights. Defaults to the current directory. ",
  34. )
  35. parser.add_argument(
  36. "--unsafe_serialization",
  37. action="store_true",
  38. default=False,
  39. help="Whether to save the merged weights as `.bin` rather than `.safetensors` (not recommended).",
  40. )
  41. parser.add_argument(
  42. "--remove_checkpoint_dir",
  43. action="store_true",
  44. help="Whether to remove the checkpoint directory after merging.",
  45. default=False,
  46. )
  47. if subparsers is not None:
  48. parser.set_defaults(func=merge_command)
  49. return parser
  50. def main():
  51. parser = merge_command_parser()
  52. args = parser.parse_args()
  53. merge_command(args)
# Run the merge only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()