test.py 2.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465
#!/usr/bin/env python
# Copyright 2021 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
  15. import argparse
  16. from accelerate.test_utils import execute_subprocess_async, path_in_accelerate_package
  17. def test_command_parser(subparsers=None):
  18. if subparsers is not None:
  19. parser = subparsers.add_parser("test")
  20. else:
  21. parser = argparse.ArgumentParser("Accelerate test command")
  22. parser.add_argument(
  23. "--config_file",
  24. default=None,
  25. help=(
  26. "The path to use to store the config file. Will default to a file named default_config.yaml in the cache "
  27. "location, which is the content of the environment `HF_HOME` suffixed with 'accelerate', or if you don't have "
  28. "such an environment variable, your cache directory ('~/.cache' or the content of `XDG_CACHE_HOME`) suffixed "
  29. "with 'huggingface'."
  30. ),
  31. )
  32. if subparsers is not None:
  33. parser.set_defaults(func=test_command)
  34. return parser
  35. def test_command(args):
  36. script_name = path_in_accelerate_package("test_utils", "scripts", "test_script.py")
  37. if args.config_file is None:
  38. test_args = [script_name]
  39. else:
  40. test_args = f"--config_file={args.config_file} {script_name}".split()
  41. cmd = ["accelerate-launch"] + test_args
  42. result = execute_subprocess_async(cmd)
  43. if result.returncode == 0:
  44. print("Test is a success! You are ready for your distributed training!")
  45. def main():
  46. parser = test_command_parser()
  47. args = parser.parse_args()
  48. test_command(args)
  49. if __name__ == "__main__":
  50. main()