tpu.py 5.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157
  1. #!/usr/bin/env python
  2. # Copyright 2022 The HuggingFace Team. All rights reserved.
  3. #
  4. # Licensed under the Apache License, Version 2.0 (the "License");
  5. # you may not use this file except in compliance with the License.
  6. # You may obtain a copy of the License at
  7. #
  8. # http://www.apache.org/licenses/LICENSE-2.0
  9. #
  10. # Unless required by applicable law or agreed to in writing, software
  11. # distributed under the License is distributed on an "AS IS" BASIS,
  12. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  13. # See the License for the specific language governing permissions and
  14. # limitations under the License.
  15. import argparse
  16. import os
  17. import subprocess
  18. from packaging.version import Version, parse
  19. from accelerate.commands.config.config_args import default_config_file, load_config_from_file
  20. _description = "Run commands across TPU VMs for initial setup before running `accelerate launch`."
  21. def tpu_command_parser(subparsers=None):
  22. if subparsers is not None:
  23. parser = subparsers.add_parser("tpu-config", description=_description)
  24. else:
  25. parser = argparse.ArgumentParser("Accelerate tpu-config command", description=_description)
  26. # Core arguments
  27. config_args = parser.add_argument_group(
  28. "Config Arguments", "Arguments that can be configured through `accelerate config`."
  29. )
  30. config_args.add_argument(
  31. "--config_file",
  32. type=str,
  33. default=None,
  34. help="Path to the config file to use for accelerate.",
  35. )
  36. config_args.add_argument(
  37. "--tpu_name",
  38. default=None,
  39. help="The name of the TPU to use. If not specified, will use the TPU specified in the config file.",
  40. )
  41. config_args.add_argument(
  42. "--tpu_zone",
  43. default=None,
  44. help="The zone of the TPU to use. If not specified, will use the zone specified in the config file.",
  45. )
  46. pod_args = parser.add_argument_group("TPU Arguments", "Arguments for options ran inside the TPU.")
  47. pod_args.add_argument(
  48. "--use_alpha",
  49. action="store_true",
  50. help="Whether to use `gcloud alpha` when running the TPU training script instead of `gcloud`.",
  51. )
  52. pod_args.add_argument(
  53. "--command_file",
  54. default=None,
  55. help="The path to the file containing the commands to run on the pod on startup.",
  56. )
  57. pod_args.add_argument(
  58. "--command",
  59. action="append",
  60. nargs="+",
  61. help="A command to run on the pod. Can be passed multiple times.",
  62. )
  63. pod_args.add_argument(
  64. "--install_accelerate",
  65. action="store_true",
  66. help="Whether to install accelerate on the pod. Defaults to False.",
  67. )
  68. pod_args.add_argument(
  69. "--accelerate_version",
  70. default="latest",
  71. help="The version of accelerate to install on the pod. If not specified, will use the latest pypi version. Specify 'dev' to install from GitHub.",
  72. )
  73. pod_args.add_argument(
  74. "--debug", action="store_true", help="If set, will print the command that would be run instead of running it."
  75. )
  76. if subparsers is not None:
  77. parser.set_defaults(func=tpu_command_launcher)
  78. return parser
  79. def tpu_command_launcher(args):
  80. defaults = None
  81. # Get the default from the config file if it exists.
  82. if args.config_file is not None or os.path.isfile(default_config_file):
  83. defaults = load_config_from_file(args.config_file)
  84. if not args.command_file and defaults.command_file is not None and not args.command:
  85. args.command_file = defaults.command_file
  86. if not args.command and defaults.commands is not None:
  87. args.command = defaults.commands
  88. if not args.tpu_name:
  89. args.tpu_name = defaults.tpu_name
  90. if not args.tpu_zone:
  91. args.tpu_zone = defaults.tpu_zone
  92. if args.accelerate_version == "dev":
  93. args.accelerate_version = "git+https://github.com/huggingface/accelerate.git"
  94. elif args.accelerate_version == "latest":
  95. args.accelerate_version = "accelerate -U"
  96. elif isinstance(parse(args.accelerate_version), Version):
  97. args.accelerate_version = f"accelerate=={args.accelerate_version}"
  98. if not args.command_file and not args.command:
  99. raise ValueError("You must specify either a command file or a command to run on the pod.")
  100. if args.command_file:
  101. with open(args.command_file) as f:
  102. args.command = [f.read().splitlines()]
  103. # To turn list of lists into list of strings
  104. if isinstance(args.command[0], list):
  105. args.command = [line for cmd in args.command for line in cmd]
  106. # Default to the shared folder and install accelerate
  107. new_cmd = ["cd /usr/share"]
  108. if args.install_accelerate:
  109. new_cmd += [f"pip install {args.accelerate_version}"]
  110. new_cmd += args.command
  111. args.command = "; ".join(new_cmd)
  112. # Then send it to gcloud
  113. # Eventually try to use google-api-core to do this instead of subprocess
  114. cmd = ["gcloud"]
  115. if args.use_alpha:
  116. cmd += ["alpha"]
  117. cmd += [
  118. "compute",
  119. "tpus",
  120. "tpu-vm",
  121. "ssh",
  122. args.tpu_name,
  123. "--zone",
  124. args.tpu_zone,
  125. "--command",
  126. args.command,
  127. "--worker",
  128. "all",
  129. ]
  130. if args.debug:
  131. print(f"Running {' '.join(cmd)}")
  132. return
  133. subprocess.run(cmd)
  134. print("Successfully setup pod.")
  135. def main():
  136. parser = tpu_command_parser()
  137. args = parser.parse_args()
  138. tpu_command_launcher(args)